From 8e3a52c8e538f283f1b2194073648ce1a674f5e6 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Wed, 1 Apr 2026 11:47:45 +0200 Subject: [PATCH] fix: make assert_linequal compare semantic equality of expressions Sort both sides by variable labels along _term before comparing, so expressions with different term orderings (e.g. from CSR round-trip with freeze_constraints=True) are correctly recognized as equal. Co-Authored-By: Claude Opus 4.6 (1M context) --- linopy/testing.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/linopy/testing.py b/linopy/testing.py index e6a58a0d..31082284 100644 --- a/linopy/testing.py +++ b/linopy/testing.py @@ -1,13 +1,29 @@ from __future__ import annotations +import numpy as np from xarray.testing import assert_equal +from linopy.constants import TERM_DIM from linopy.constraints import ConstraintBase, _con_unwrap from linopy.expressions import LinearExpression, QuadraticExpression, _expr_unwrap from linopy.model import Model from linopy.variables import Variable, _var_unwrap +def _sort_by_vars_along_term(expr: LinearExpression) -> LinearExpression: + """Sort a linear expression's terms by variable labels along _term.""" + ds = expr.data + if TERM_DIM not in ds.dims: + return expr + order = np.argsort(ds["vars"].values, axis=-1, kind="stable") + sorted_vars = np.take_along_axis(ds["vars"].values, order, axis=-1) + sorted_coeffs = np.take_along_axis(ds["coeffs"].values, order, axis=-1) + new_ds = ds.copy() + new_ds["vars"] = (ds["vars"].dims, sorted_vars) + new_ds["coeffs"] = (ds["coeffs"].dims, sorted_coeffs) + return LinearExpression(new_ds, expr.model) + + def assert_varequal(a: Variable, b: Variable) -> None: """Assert that two variables are equal.""" return assert_equal(_var_unwrap(a), _var_unwrap(b)) @@ -16,10 +32,18 @@ def assert_varequal(a: Variable, b: Variable) -> None: def assert_linequal( a: LinearExpression | QuadraticExpression, b: LinearExpression | QuadraticExpression ) -> None: - """Assert that two linear expressions are equal.""" + """ + Assert that two linear expressions are semantically equal. + + Terms are sorted by variable labels along _term before comparing, + so expressions with different term orderings but identical mathematical + meaning are considered equal. + """ assert isinstance(a, LinearExpression) assert isinstance(b, LinearExpression) - return assert_equal(_expr_unwrap(a), _expr_unwrap(b)) + a_sorted = _sort_by_vars_along_term(a) + b_sorted = _sort_by_vars_along_term(b) + return assert_equal(_expr_unwrap(a_sorted), _expr_unwrap(b_sorted)) def assert_quadequal(