Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions linopy/testing.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,29 @@
from __future__ import annotations

import numpy as np
from xarray.testing import assert_equal

from linopy.constants import TERM_DIM
from linopy.constraints import ConstraintBase, _con_unwrap
from linopy.expressions import LinearExpression, QuadraticExpression, _expr_unwrap
from linopy.model import Model
from linopy.variables import Variable, _var_unwrap


def _sort_by_vars_along_term(expr: LinearExpression) -> LinearExpression:
"""Sort a linear expression's terms by variable labels along _term."""
ds = expr.data
if TERM_DIM not in ds.dims:
return expr
order = np.argsort(ds["vars"].values, axis=-1, kind="stable")
sorted_vars = np.take_along_axis(ds["vars"].values, order, axis=-1)
sorted_coeffs = np.take_along_axis(ds["coeffs"].values, order, axis=-1)
new_ds = ds.copy()
new_ds["vars"] = (ds["vars"].dims, sorted_vars)
new_ds["coeffs"] = (ds["coeffs"].dims, sorted_coeffs)
return LinearExpression(new_ds, expr.model)


def assert_varequal(a: Variable, b: Variable) -> None:
"""Assert that two variables are equal."""
return assert_equal(_var_unwrap(a), _var_unwrap(b))
Expand All @@ -16,10 +32,18 @@ def assert_varequal(a: Variable, b: Variable) -> None:
def assert_linequal(
a: LinearExpression | QuadraticExpression, b: LinearExpression | QuadraticExpression
) -> None:
"""Assert that two linear expressions are equal."""
"""
Assert that two linear expressions are semantically equal.

Terms are sorted by variable labels along _term before comparing,
so expressions with different term orderings but identical mathematical
meaning are considered equal.
"""
assert isinstance(a, LinearExpression)
assert isinstance(b, LinearExpression)
return assert_equal(_expr_unwrap(a), _expr_unwrap(b))
a_sorted = _sort_by_vars_along_term(a)
b_sorted = _sort_by_vars_along_term(b)
return assert_equal(_expr_unwrap(a_sorted), _expr_unwrap(b_sorted))


def assert_quadequal(
Expand Down