diff --git a/doc/release_notes.rst b/doc/release_notes.rst index 13df0267..d918b9cc 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -4,6 +4,7 @@ Release Notes Upcoming Version ---------------- +* Fix the handling of multiplication between ``LinearExpression`` and constants with a subset of dimensions. Align with ``Variable`` behaviour * Fix docs (pick highs solver) * Add the `sphinx-copybutton` to the documentation * Add ``auto_mask`` parameter to ``Model`` class that automatically masks variables and constraints where bounds, coefficients, or RHS values contain NaN. This eliminates the need to manually create mask arrays when working with sparse or incomplete data. diff --git a/linopy/expressions.py b/linopy/expressions.py index 649989f7..0d4407b4 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -538,7 +538,6 @@ def _multiply_by_constant( ) -> GenericExpression: multiplier = as_dataarray(other, coords=self.coords, dims=self.coord_dims) coeffs = self.coeffs * multiplier - assert all(coeffs.sizes[d] == s for d, s in self.coeffs.sizes.items()) const = self.const * multiplier return self.assign(coeffs=coeffs, const=const) diff --git a/linopy/testing.py b/linopy/testing.py index 0392064e..216ad423 100644 --- a/linopy/testing.py +++ b/linopy/testing.py @@ -1,5 +1,8 @@ from __future__ import annotations +from collections.abc import Iterable + +import numpy as np from xarray.testing import assert_equal from linopy.constraints import Constraint, _con_unwrap @@ -72,3 +75,13 @@ def assert_model_equal(a: Model, b: Model) -> None: assert a.termination_condition == b.termination_condition assert a.type == b.type + + +def assert_lists_equal(x: Iterable[float], b: Iterable[float]) -> None: + x = list(x) + b = list(b) + assert len(x) == len(b) + for xi, bi in zip(x, b): + if np.isnan(xi) and np.isnan(bi): + continue + assert xi == bi diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 0da9ec7f..57519745 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -14,10 +14,10 @@ import xarray as xr from xarray.testing import assert_equal -from linopy import LinearExpression, Model, QuadraticExpression, Variable, merge +from linopy import LinearExpression, Model, QuadraticExpression, Variable from linopy.constants import HELPER_DIMS, TERM_DIM -from linopy.expressions import ScalarLinearExpression -from linopy.testing import assert_linequal, assert_quadequal +from linopy.expressions import ScalarLinearExpression, merge +from linopy.testing import assert_linequal, assert_lists_equal, assert_quadequal from linopy.variables import ScalarVariable @@ -238,6 +238,44 @@ def test_linear_expression_with_multiplication(x: Variable) -> None: assert expr.__rmul__(object()) is NotImplemented +def test_linear_expression_multiplication_with_missing_coords() -> None: + m = Model() + full_index = pd.Index(range(5), name="i") + x = m.add_variables(coords=[full_index]) + nan = float("nan") + scale = xr.DataArray([10.0, 30.0], dims=["i"], coords={"i": [1, 3]}) + + # These two expressions should produce the same result + r1 = x * scale + r2 = (1 * x) * scale + + for result in [r1, r2]: + assert result.coords.equals(x.coords) + assert result.vars.equals(r1.vars) + + # Use pandas to make sure nans are considered equal + expected_coeffs = [nan, 10.0, nan, 30.0, nan] + assert_lists_equal(result.coeffs.values.squeeze(), expected_coeffs) + + +def test_linear_expression_with_missing_coords_in_coeff_and_const() -> None: + m = Model() + full_index = pd.Index(range(5), name="i") + x = m.add_variables(coords=[full_index]) + nan = float("nan") + scale = xr.DataArray([10.0, 30.0], dims=["i"], coords={"i": [1, 3]}) + const = xr.DataArray([1.0, 2.0], dims=["i"], coords={"i": [0, 1]}) + + # These two expressions should produce the same result + result = (x + const) * scale + assert result.coords.equals(x.coords) + + expected_coeffs = [nan, 10.0, nan, 30.0, nan] + expected_const = [nan, 20.0, nan, 0.0, nan] # Constants are filled with zeros + assert_lists_equal(result.coeffs.values.squeeze(), expected_coeffs) + assert_lists_equal(result.const.values.squeeze(), expected_const) + + def test_linear_expression_with_addition(m: Model, x: Variable, y: Variable) -> None: expr = 10 * x + y assert isinstance(expr, LinearExpression)