diff --git a/doc/release_notes.rst b/doc/release_notes.rst index 13df0267..c393043d 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -4,6 +4,7 @@ Release Notes Upcoming Version ---------------- +* Fix warning when multiplying variables with pd.Series containing time-zone aware index * Fix docs (pick highs solver) * Add the `sphinx-copybutton` to the documentation * Add ``auto_mask`` parameter to ``Model`` class that automatically masks variables and constraints where bounds, coefficients, or RHS values contain NaN. This eliminates the need to manually create mask arrays when working with sparse or incomplete data. @@ -20,7 +21,6 @@ Version 0.6.0 -------------- **Features** - * Add ``mock_solve`` option to ``Model.solve()`` for quick testing without actual solving * Add support for SOS1 and SOS2 (Special Ordered Sets) constraints via ``Model.add_sos_constraints()`` and ``Model.remove_sos_constraints()`` * Add ``simplify`` method to ``LinearExpression`` to combine duplicate terms diff --git a/linopy/common.py b/linopy/common.py index e6eef583..ba38c67e 100644 --- a/linopy/common.py +++ b/linopy/common.py @@ -12,12 +12,13 @@ from collections.abc import Callable, Generator, Hashable, Iterable, Sequence from functools import partial, reduce, wraps from pathlib import Path -from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload +from typing import TYPE_CHECKING, Any, Generic, Literal, ParamSpec, TypeVar, overload from warnings import warn import numpy as np import pandas as pd import polars as pl +import xarray as xr from numpy import arange, signedinteger from xarray import DataArray, Dataset, apply_ufunc, broadcast from xarray import align as xr_align @@ -45,6 +46,48 @@ from linopy.variables import Variable +class CoordAlignWarning(UserWarning): ... + + +class TimezoneAlignError(ValueError): ... 
+ + +P = ParamSpec("P") +R = TypeVar("R") + + +class CatchDatetimeTypeError: + """Context manager that catches datetime-related TypeErrors and re-raises as TimezoneAlignError.""" + + def __enter__(self) -> CatchDatetimeTypeError: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> Literal[False]: + if exc_type is TypeError and exc_val is not None: + if "Cannot interpret 'datetime" in str(exc_val): + raise TimezoneAlignError( + "Timezone information across datetime coordinates not aligned." + ) from exc_val + return False + + +def catch_datetime_type_error_and_re_raise(func: Callable[P, R]) -> Callable[P, R]: + """Decorator that catches datetime-related TypeErrors and re-raises as TimezoneAlignError.""" + + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + with CatchDatetimeTypeError(): + result = func(*args, **kwargs) + return result + + return wrapper + + def set_int_index(series: pd.Series) -> pd.Series: """ Convert string index to int index. 
def try_to_convert_to_pd_datetime_index(
    coord: xr.DataArray | Sequence | pd.Index | Any,
) -> pd.DatetimeIndex | xr.DataArray | Sequence | pd.Index | Any:
    """
    Best-effort conversion of a coordinate to a ``pd.DatetimeIndex``.

    Parameters
    ----------
    coord : xr.DataArray, Sequence, pd.Index or any
        Coordinate values to convert.

    Returns
    -------
    pd.DatetimeIndex or the unchanged input
        ``coord`` itself if it already is a ``pd.DatetimeIndex``. For a
        ``DataArray``, its ``to_index()`` result if that is a
        ``pd.DatetimeIndex``. For anything else, ``pd.DatetimeIndex(coord)``.
        If the conversion fails or does not produce a ``pd.DatetimeIndex``,
        ``coord`` is returned unchanged.
    """
    if isinstance(coord, pd.DatetimeIndex):
        return coord
    try:
        if isinstance(coord, xr.DataArray):
            index = coord.to_index()
            # Explicit check instead of ``assert``: asserts are stripped under
            # ``python -O``, which would let a non-datetime index slip through.
            return index if isinstance(index, pd.DatetimeIndex) else coord
        return pd.DatetimeIndex(coord)
    except Exception:
        # Deliberate best-effort: any conversion failure means "leave as is".
        return coord
Perform outer join.", - UserWarning, + CoordAlignWarning, ) arrs = xr_align(*dataarrays, join="outer") if integer_dtype: @@ -485,6 +548,7 @@ def save_join(*dataarrays: DataArray, integer_dtype: bool = False) -> Dataset: return Dataset({ds.name: ds for ds in arrs}) +@catch_datetime_type_error_and_re_raise def assign_multiindex_safe(ds: Dataset, **fields: Any) -> Dataset: """ Assign a field to a xarray Dataset while being safe against warnings about multiindex corruption. diff --git a/linopy/expressions.py b/linopy/expressions.py index 649989f7..79ffd173 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -47,6 +47,7 @@ LocIndexer, as_dataarray, assign_multiindex_safe, + catch_datetime_type_error_and_re_raise, check_common_keys_values, check_has_nulls, check_has_nulls_polars, @@ -506,6 +507,7 @@ def __neg__(self: GenericExpression) -> GenericExpression: """ return self.assign_multiindex_safe(coeffs=-self.coeffs, const=-self.const) + @catch_datetime_type_error_and_re_raise def _multiply_by_linear_expression( self, other: LinearExpression | ScalarLinearExpression ) -> LinearExpression | QuadraticExpression: @@ -533,6 +535,7 @@ def _multiply_by_linear_expression( res = res + self.reset_const() * other.const return res + @catch_datetime_type_error_and_re_raise def _multiply_by_constant( self: GenericExpression, other: ConstantLike ) -> GenericExpression: @@ -1456,7 +1459,7 @@ def to_polars(self) -> pl.DataFrame: The resulting DataFrame represents a long table format of the all non-masked expressions with non-zero coefficients. It contains the - columns `coeffs`, `vars`, `const`. The coeffs and vars columns will be null if the expression is constant. + columns `vars`, `coeffs`, `const`. The coeffs and vars columns will be null if the expression is constant. 
def test_as_datarray_with_tz_aware_series_index() -> None:
    """
    Check ``as_dataarray`` on a Series with a tz-aware datetime index.

    The ``time`` coordinate of the result must equal the Series index for
    every tested ``coords`` argument. DataArray coordinates built from an
    integer index must additionally emit a ``CoordAlignWarning``.
    """
    tz_index = pd.date_range(
        start=datetime(2025, 1, 1),
        periods=4,
        freq="15min",
        tz=UTC,
        name="time",
    )
    int_index = pd.Index(name="time", data=[0, 1, 2, 3])
    series = pd.Series(index=tz_index, data=1.0)

    def keeps_tz_index(converted: DataArray) -> bool:
        return tz_index.equals(converted.coords["time"].to_index())

    # Coordinates taken from a DataArray indexed by the same tz-aware index.
    aligned = xr.DataArray(data=[0, 1, 2, 3], coords=[tz_index])
    assert keeps_tz_index(as_dataarray(arr=series, coords=aligned.coords))

    # Coordinates taken from a DataArray indexed by plain integers.
    misaligned = xr.DataArray(data=[0, 1, 2, 3], coords=[int_index])
    with pytest.warns(CoordAlignWarning):
        converted = as_dataarray(arr=series, coords=misaligned.coords)
    assert keeps_tz_index(converted)

    # Coordinates given as plain mappings.
    for mapping in ({"time": tz_index}, {"time": [0, 1, 2, 3]}):
        assert keeps_tz_index(as_dataarray(arr=series, coords=mapping))


def test_as_datarray_with_tz_aware_dataframe_columns_index() -> None:
    """
    Check ``as_dataarray`` on a DataFrame with tz-aware datetime columns.

    The ``time`` coordinate of the result must equal the DataFrame columns
    for every tested ``coords`` argument. DataArray coordinates built from an
    integer index must additionally emit a ``CoordAlignWarning``.
    """
    tz_index = pd.date_range(
        start=datetime(2025, 1, 1),
        periods=4,
        freq="15min",
        tz=UTC,
        name="time",
    )
    int_index = pd.Index(name="time", data=[0, 1, 2, 3])
    row_index = pd.Index([0, 1, 2, 3], name="x")
    frame = pd.DataFrame(index=row_index, columns=tz_index, data=1.0)

    def keeps_tz_index(converted: DataArray) -> bool:
        return tz_index.equals(converted.coords["time"].to_index())

    # Coordinates taken from a DataArray indexed by the same tz-aware index.
    aligned = xr.DataArray(data=[0, 1, 2, 3], coords=[tz_index])
    assert keeps_tz_index(as_dataarray(arr=frame, coords=aligned.coords))

    # Coordinates taken from a DataArray indexed by plain integers.
    misaligned = xr.DataArray(data=[0, 1, 2, 3], coords=[int_index])
    with pytest.warns(CoordAlignWarning):
        converted = as_dataarray(arr=frame, coords=misaligned.coords)
    assert keeps_tz_index(converted)

    # Coordinates given as plain mappings.
    for mapping in ({"time": tz_index}, {"time": [0, 1, 2, 3]}):
        assert keeps_tz_index(as_dataarray(arr=frame, coords=mapping))
def test_timezone_alignment_failure() -> None:
    """
    Multiplying a UTC-indexed expression by a tz-naive Series must fail clearly.

    The expected failure is ``TimezoneAlignError`` rather than a
    not-implemented error claiming the two types cannot be multiplied.
    """
    # Both indexes cover the same wall-clock range and differ only in ``tz``.
    shared = dict(start=datetime(2025, 1, 1), freq="15min", periods=4, name="time")
    aware_index = pd.date_range(tz=UTC, **shared)
    naive_index = pd.date_range(tz=None, **shared)

    naive_series = pd.Series(index=naive_index, data=1.0)
    variable = Model().add_variables(coords=[aware_index], name="var1")
    expr = variable * 1.0

    with pytest.raises(TimezoneAlignError):
        _ = expr * naive_series
def test_timezone_alignment_with_multiplication() -> None:
    """
    Multiplying a variable by a Series sharing its UTC index stays aligned.

    ``CoordAlignWarning`` is escalated to an error during the multiplication,
    and the resulting ``time`` coordinate must equal the UTC index and keep
    the ``UTC`` tzinfo object.
    """
    aware_index = pd.date_range(
        start=datetime(2025, 1, 1),
        periods=4,
        freq="15min",
        tz=UTC,
        name="time",
    )
    aware_series = pd.Series(index=aware_index, data=1.0)
    variable = Model().add_variables(coords=[aware_index], name="var1")

    with warnings.catch_warnings():
        # Any coordinate-alignment warning now fails the test.
        warnings.simplefilter("error", CoordAlignWarning)
        product = variable * aware_series

    result_index: pd.DatetimeIndex = product.coords["time"].to_index()
    assert result_index.equals(aware_index)
    assert result_index.tzinfo is UTC


def test_timezone_alignment_failure() -> None:
    """
    Multiplying a UTC-indexed variable by a tz-naive Series must fail clearly.

    The expected failure is ``TimezoneAlignError`` rather than a
    not-implemented error claiming the two types cannot be multiplied.
    """
    # Both indexes cover the same wall-clock range and differ only in ``tz``.
    shared = dict(start=datetime(2025, 1, 1), freq="15min", periods=4, name="time")
    aware_index = pd.date_range(tz=UTC, **shared)
    naive_index = pd.date_range(tz=None, **shared)

    naive_series = pd.Series(index=naive_index, data=1.0)
    variable = Model().add_variables(coords=[aware_index], name="var1")

    with pytest.raises(TimezoneAlignError):
        _ = variable * naive_series