diff --git a/doc/release_notes.rst b/doc/release_notes.rst index 882552dd..52c08dad 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -3,6 +3,11 @@ Release Notes .. Upcoming Version +Upcoming Version +---------------- + +* Fix multiplication of constant-only ``LinearExpression`` with other expressions + Version 0.6.1 -------------- diff --git a/linopy/expressions.py b/linopy/expressions.py index 10e243de..848067cf 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -13,7 +13,7 @@ from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence from dataclasses import dataclass, field from itertools import product, zip_longest -from typing import TYPE_CHECKING, Any, TypeVar, overload +from typing import TYPE_CHECKING, Any, TypeVar, cast, overload from warnings import warn import numpy as np @@ -507,12 +507,18 @@ def __neg__(self: GenericExpression) -> GenericExpression: def _multiply_by_linear_expression( self, other: LinearExpression | ScalarLinearExpression - ) -> QuadraticExpression: + ) -> LinearExpression | QuadraticExpression: if isinstance(other, ScalarLinearExpression): other = other.to_linexpr() if other.nterm > 1: raise TypeError("Multiplication of multiple terms is not supported.") + + if other.is_constant: + return cast(LinearExpression, self._multiply_by_constant(other.const)) + if self.is_constant: + return cast(LinearExpression, other._multiply_by_constant(self.const)) + # multiplication: (v1 + c1) * (v2 + c2) = v1 * v2 + c1 * v2 + c2 * v1 + c1 * c2 # with v being the variables and c the constants # merge on factor dimension only returns v1 * v2 + c1 * c2 diff --git a/linopy/variables.py b/linopy/variables.py index e2570b5d..d90a4775 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -14,6 +14,7 @@ from typing import ( TYPE_CHECKING, Any, + cast, overload, ) from warnings import warn @@ -420,7 +421,9 @@ def __pow__(self, other: int) -> QuadraticExpression: return NotImplemented if other == 2: expr = self.to_linexpr() - return expr._multiply_by_linear_expression(expr) + return cast( + "QuadraticExpression", expr._multiply_by_linear_expression(expr) + ) raise ValueError("Can only raise to the power of 2") @overload diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index a75ace3f..0da9ec7f 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -1313,3 +1313,89 @@ def test_simplify_partial_cancellation(x: Variable, y: Variable) -> None: assert all(simplified.coeffs.values == 3.0), ( f"Expected coefficient 3.0, got {simplified.coeffs.values}" ) + + +def test_constant_only_expression_mul_dataarray(m: Model) -> None: + const_arr = xr.DataArray([2, 3], dims=["dim_0"]) + const_expr = LinearExpression(const_arr, m) + assert const_expr.is_constant + assert const_expr.nterm == 0 + + data_arr = xr.DataArray([10, 20], dims=["dim_0"]) + expected_const = const_arr * data_arr + + result = const_expr * data_arr + assert isinstance(result, LinearExpression) + assert result.is_constant + assert (result.const == expected_const).all() + + result_rev = data_arr * const_expr + assert isinstance(result_rev, LinearExpression) + assert result_rev.is_constant + assert (result_rev.const == expected_const).all() + + +def test_constant_only_expression_mul_linexpr_with_vars(m: Model, x: Variable) -> None: + const_arr = xr.DataArray([2, 3], dims=["dim_0"]) + const_expr = LinearExpression(const_arr, m) + assert const_expr.is_constant + assert const_expr.nterm == 0 + + expr_with_vars = 1 * x + 5 + expected_coeffs = const_arr + expected_const = const_arr * 5 + + result = const_expr * expr_with_vars + assert isinstance(result, LinearExpression) + assert (result.coeffs == expected_coeffs).all() + assert (result.const == expected_const).all() + + result_rev = expr_with_vars * const_expr + assert isinstance(result_rev, LinearExpression) + assert (result_rev.coeffs == expected_coeffs).all() + assert (result_rev.const == expected_const).all() + + +def test_constant_only_expression_mul_constant_only(m: Model) -> None: + const_arr = xr.DataArray([2, 3], dims=["dim_0"]) + const_arr2 = xr.DataArray([4, 5], dims=["dim_0"]) + const_expr = LinearExpression(const_arr, m) + const_expr2 = LinearExpression(const_arr2, m) + assert const_expr.is_constant + assert const_expr2.is_constant + + expected_const = const_arr * const_arr2 + + result = const_expr * const_expr2 + assert isinstance(result, LinearExpression) + assert result.is_constant + assert (result.const == expected_const).all() + + result_rev = const_expr2 * const_expr + assert isinstance(result_rev, LinearExpression) + assert result_rev.is_constant + assert (result_rev.const == expected_const).all() + + +def test_constant_only_expression_mul_linexpr_with_vars_and_const( + m: Model, x: Variable +) -> None: + const_arr = xr.DataArray([2, 3], dims=["dim_0"]) + const_expr = LinearExpression(const_arr, m) + assert const_expr.is_constant + + expr_with_vars_and_const = 4 * x + 10 + expected_coeffs = const_arr * 4 + expected_const = const_arr * 10 + + result = const_expr * expr_with_vars_and_const + assert isinstance(result, LinearExpression) + assert not result.is_constant + assert (result.coeffs == expected_coeffs).all() + assert (result.const == expected_const).all() + + result_rev = expr_with_vars_and_const * const_expr + assert isinstance(result_rev, LinearExpression) + assert not result_rev.is_constant + assert (result_rev.coeffs == expected_coeffs).all() + assert (result_rev.const == expected_const).all()