Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ Release Notes

.. Upcoming Version

Upcoming Version
----------------

* Fix multiplication of constant-only ``LinearExpression`` with other expressions

Version 0.6.1
--------------

Expand Down
10 changes: 8 additions & 2 deletions linopy/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence
from dataclasses import dataclass, field
from itertools import product, zip_longest
from typing import TYPE_CHECKING, Any, TypeVar, overload
from typing import TYPE_CHECKING, Any, TypeVar, cast, overload
from warnings import warn

import numpy as np
Expand Down Expand Up @@ -507,12 +507,18 @@ def __neg__(self: GenericExpression) -> GenericExpression:

def _multiply_by_linear_expression(
self, other: LinearExpression | ScalarLinearExpression
) -> QuadraticExpression:
) -> LinearExpression | QuadraticExpression:
if isinstance(other, ScalarLinearExpression):
other = other.to_linexpr()

if other.nterm > 1:
raise TypeError("Multiplication of multiple terms is not supported.")

if other.is_constant:
return cast(LinearExpression, self._multiply_by_constant(other.const))
if self.is_constant:
return cast(LinearExpression, other._multiply_by_constant(self.const))

# multiplication: (v1 + c1) * (v2 + c2) = v1 * v2 + c1 * v2 + c2 * v1 + c1 * c2
# with v being the variables and c the constants
# merge on factor dimension only returns v1 * v2 + c1 * c2
Expand Down
5 changes: 4 additions & 1 deletion linopy/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import (
TYPE_CHECKING,
Any,
cast,
overload,
)
from warnings import warn
Expand Down Expand Up @@ -420,7 +421,9 @@ def __pow__(self, other: int) -> QuadraticExpression:
return NotImplemented
if other == 2:
expr = self.to_linexpr()
return expr._multiply_by_linear_expression(expr)
return cast(
"QuadraticExpression", expr._multiply_by_linear_expression(expr)
)
raise ValueError("Can only raise to the power of 2")

@overload
Expand Down
86 changes: 86 additions & 0 deletions test/test_linear_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1313,3 +1313,89 @@ def test_simplify_partial_cancellation(x: Variable, y: Variable) -> None:
assert all(simplified.coeffs.values == 3.0), (
f"Expected coefficient 3.0, got {simplified.coeffs.values}"
)


def test_constant_only_expression_mul_dataarray(m: Model) -> None:
const_arr = xr.DataArray([2, 3], dims=["dim_0"])
const_expr = LinearExpression(const_arr, m)
assert const_expr.is_constant
assert const_expr.nterm == 0

data_arr = xr.DataArray([10, 20], dims=["dim_0"])
expected_const = const_arr * data_arr

result = const_expr * data_arr
assert isinstance(result, LinearExpression)
assert result.is_constant
assert (result.const == expected_const).all()

result_rev = data_arr * const_expr
assert isinstance(result_rev, LinearExpression)
assert result_rev.is_constant
assert (result_rev.const == expected_const).all()


def test_constant_only_expression_mul_linexpr_with_vars(m: Model, x: Variable) -> None:
const_arr = xr.DataArray([2, 3], dims=["dim_0"])
const_expr = LinearExpression(const_arr, m)
assert const_expr.is_constant
assert const_expr.nterm == 0

expr_with_vars = 1 * x + 5
expected_coeffs = const_arr
expected_const = const_arr * 5

result = const_expr * expr_with_vars
assert isinstance(result, LinearExpression)
assert (result.coeffs == expected_coeffs).all()
assert (result.const == expected_const).all()

result_rev = expr_with_vars * const_expr
assert isinstance(result_rev, LinearExpression)
assert (result_rev.coeffs == expected_coeffs).all()
assert (result_rev.const == expected_const).all()


def test_constant_only_expression_mul_constant_only(m: Model) -> None:
const_arr = xr.DataArray([2, 3], dims=["dim_0"])
const_arr2 = xr.DataArray([4, 5], dims=["dim_0"])
const_expr = LinearExpression(const_arr, m)
const_expr2 = LinearExpression(const_arr2, m)
assert const_expr.is_constant
assert const_expr2.is_constant

expected_const = const_arr * const_arr2

result = const_expr * const_expr2
assert isinstance(result, LinearExpression)
assert result.is_constant
assert (result.const == expected_const).all()

result_rev = const_expr2 * const_expr
assert isinstance(result_rev, LinearExpression)
assert result_rev.is_constant
assert (result_rev.const == expected_const).all()


def test_constant_only_expression_mul_linexpr_with_vars_and_const(
m: Model, x: Variable
) -> None:
const_arr = xr.DataArray([2, 3], dims=["dim_0"])
const_expr = LinearExpression(const_arr, m)
assert const_expr.is_constant

expr_with_vars_and_const = 4 * x + 10
expected_coeffs = const_arr * 4
expected_const = const_arr * 10

result = const_expr * expr_with_vars_and_const
assert isinstance(result, LinearExpression)
assert not result.is_constant
assert (result.coeffs == expected_coeffs).all()
assert (result.const == expected_const).all()

result_rev = expr_with_vars_and_const * const_expr
assert isinstance(result_rev, LinearExpression)
assert not result_rev.is_constant
assert (result_rev.coeffs == expected_coeffs).all()
assert (result_rev.const == expected_const).all()