diff --git a/linopy/common.py b/linopy/common.py index 278f2c61..207645d6 100644 --- a/linopy/common.py +++ b/linopy/common.py @@ -18,7 +18,7 @@ import numpy as np import pandas as pd import polars as pl -from numpy import arange, signedinteger +from numpy import arange, nan, signedinteger from xarray import DataArray, Dataset, apply_ufunc, broadcast from xarray import align as xr_align from xarray.core import dtypes, indexing @@ -1393,3 +1393,51 @@ def is_constant(x: SideLike) -> bool: "Expected a constant, variable, or expression on the constraint side, " f"got {type(x)}." ) + + +def series_to_lookup_array(s: pd.Series) -> np.ndarray: + """ + Convert an integer-indexed Series to a dense numpy lookup array. + + Non-negative indices are placed at their corresponding positions; + negative indices are ignored. Gaps are filled with NaN. + + Parameters + ---------- + s : pd.Series + Series with an integer index. + + Returns + ------- + np.ndarray + Dense array of length ``max(index) + 1``. + """ + max_idx = max(int(s.index.max()), 0) + arr = np.full(max_idx + 1, nan) + mask = s.index >= 0 + arr[s.index[mask]] = s.values[mask] + return arr + + +def lookup_vals(arr: np.ndarray, idx: np.ndarray) -> np.ndarray: + """ + Look up values from a dense array by integer labels. + + Negative labels and labels beyond the array length map to NaN. + + Parameters + ---------- + arr : np.ndarray + Dense lookup array (e.g. from :func:`series_to_lookup_array`). + idx : np.ndarray + Integer label indices. + + Returns + ------- + np.ndarray + Array of looked-up values with the same shape as *idx*. + """ + valid = (idx >= 0) & (idx < len(arr)) + vals = np.full(idx.shape, nan) + vals[valid] = arr[idx[valid]] + return vals diff --git a/linopy/model.py b/linopy/model.py index fbc9ebc0..a1fdbb08 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -31,8 +31,10 @@ assign_multiindex_safe, best_int, broadcast_mask, + lookup_vals, maybe_replace_signs, replace_by_map, + series_to_lookup_array, set_int_index, to_path, ) @@ -1591,26 +1593,24 @@ def solve( sol = set_int_index(sol) sol.loc[-1] = nan - for name, var in self.variables.items(): - idx = np.ravel(var.labels) - try: - vals = sol[idx].values.reshape(var.labels.shape) - except KeyError: - vals = sol.reindex(idx).values.reshape(var.labels.shape) - var.solution = xr.DataArray(vals, var.coords) + sol_arr = series_to_lookup_array(sol) + + for _, var in self.variables.items(): + vals = lookup_vals(sol_arr, np.ravel(var.labels)) + var.solution = xr.DataArray(vals.reshape(var.labels.shape), var.coords) if not result.solution.dual.empty: dual = result.solution.dual.copy() dual = set_int_index(dual) dual.loc[-1] = nan - for name, con in self.constraints.items(): - idx = np.ravel(con.labels) - try: - vals = dual[idx].values.reshape(con.labels.shape) - except KeyError: - vals = dual.reindex(idx).values.reshape(con.labels.shape) - con.dual = xr.DataArray(vals, con.labels.coords) + dual_arr = series_to_lookup_array(dual) + + for _, con in self.constraints.items(): + vals = lookup_vals(dual_arr, np.ravel(con.labels)) + con.dual = xr.DataArray( + vals.reshape(con.labels.shape), con.labels.coords + ) return result.status.status.value, result.status.termination_condition.value finally: diff --git a/test/test_solution_lookup.py b/test/test_solution_lookup.py new file mode 100644 index 00000000..7dd9643f --- /dev/null +++ b/test/test_solution_lookup.py @@ -0,0 +1,73 @@ +import numpy as np +import pandas as pd +from numpy import nan + +from linopy.common import lookup_vals, series_to_lookup_array + + +class TestSeriesToLookupArray: + def test_basic(self) -> None: + s = pd.Series([10.0, 20.0, 30.0], index=pd.Index([0, 1, 2])) + arr = series_to_lookup_array(s) + np.testing.assert_array_equal(arr, [10.0, 20.0, 30.0]) + + def test_with_negative_index(self) -> None: + s = pd.Series([nan, 10.0, 20.0], index=pd.Index([-1, 0, 2])) + arr = series_to_lookup_array(s) + assert arr[0] == 10.0 + assert np.isnan(arr[1]) + assert arr[2] == 20.0 + + def test_sparse_index(self) -> None: + s = pd.Series([5.0, 7.0], index=pd.Index([0, 100])) + arr = series_to_lookup_array(s) + assert len(arr) == 101 + assert arr[0] == 5.0 + assert arr[100] == 7.0 + assert np.isnan(arr[50]) + + def test_only_negative_index(self) -> None: + s = pd.Series([nan], index=pd.Index([-1])) + arr = series_to_lookup_array(s) + assert len(arr) == 1 + assert np.isnan(arr[0]) + + +class TestLookupVals: + def test_basic(self) -> None: + arr = np.array([10.0, 20.0, 30.0]) + idx = np.array([0, 1, 2]) + result = lookup_vals(arr, idx) + np.testing.assert_array_equal(result, [10.0, 20.0, 30.0]) + + def test_negative_labels_become_nan(self) -> None: + arr = np.array([10.0, 20.0]) + idx = np.array([0, -1, 1, -1]) + result = lookup_vals(arr, idx) + assert result[0] == 10.0 + assert np.isnan(result[1]) + assert result[2] == 20.0 + assert np.isnan(result[3]) + + def test_out_of_range_labels_become_nan(self) -> None: + arr = np.array([10.0, 20.0]) + idx = np.array([0, 1, 999]) + result = lookup_vals(arr, idx) + assert result[0] == 10.0 + assert result[1] == 20.0 + assert np.isnan(result[2]) + + def test_all_negative(self) -> None: + arr = np.array([10.0]) + idx = np.array([-1, -1, -1]) + result = lookup_vals(arr, idx) + assert all(np.isnan(result)) + + def test_no_mutation_of_source(self) -> None: + arr = np.array([10.0, 20.0, 30.0]) + idx1 = np.array([-1, 1]) + idx2 = np.array([0, 2]) + lookup_vals(arr, idx1) + result2 = lookup_vals(arr, idx2) + np.testing.assert_array_equal(result2, [10.0, 30.0]) + np.testing.assert_array_equal(arr, [10.0, 20.0, 30.0])