Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions linopy/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,16 +423,26 @@ def sos_to_file(

for name in names:
var = m.variables[name]
sos_type = var.attrs[SOS_TYPE_ATTR]
sos_dim = var.attrs[SOS_DIM_ATTR]
sos_type = int(var.attrs[SOS_TYPE_ATTR]) # type: ignore[call-overload]
sos_dim = str(var.attrs[SOS_DIM_ATTR])

other_dims = [dim for dim in var.labels.dims if dim != sos_dim]
for var_slice in var.iterate_slices(slice_size, other_dims):
ds = var_slice.labels.to_dataset()
ds["sos_labels"] = ds["labels"].isel({sos_dim: 0})
# Per-set id: max of labels along the SOS dim. Real labels are
# non-negative and globally unique, so max yields a valid,
# unique-per-set id whenever the set has any unmasked slot.
# Fully-masked sets get id -1 and are filtered out below.
ds["sos_labels"] = ds["labels"].max(sos_dim)
ds["weights"] = ds.coords[sos_dim]
df = to_polars(ds)

# Drop masked member rows so the LP file never emits `x-1`, and
# drop any rows belonging to a fully-masked set (sos_labels == -1).
df = df.filter((pl.col("labels") != -1) & (pl.col("sos_labels") != -1))
if df.is_empty():
continue

df = df.group_by("sos_labels").agg(
pl.concat_str(
*print_variable(pl.col("labels")), pl.lit(":"), pl.col("weights")
Expand Down Expand Up @@ -592,8 +602,6 @@ def to_file(
"""
Write out a model to a lp or mps file.
"""
m._check_sos_unmasked()

if fn is None:
fn = Path(m.get_problem_file())
if isinstance(fn, str):
Expand Down
28 changes: 0 additions & 28 deletions linopy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1262,34 +1262,6 @@ def undo_sos_reformulation(self) -> None:
self._sos_reformulation_state = None
undo_sos_reformulation(self, state)

def _check_sos_unmasked(self) -> None:
"""
Reject the model if any SOS variable has masked entries.

The SOS plumbing (both direct-API solvers and the LP file writer) treats
linopy variable labels as solver column indices / names, which breaks as
soon as a label is ``-1`` (linopy's ``FILL_VALUE["labels"]`` for masked
slots). The downstream symptoms are solver-specific — ``IndexError`` on
gurobipy, ``?404 Invalid column number`` on xpress, parse errors on
xpress/cplex LP readers, silent SOS-set corruption on gurobi's LP reader.

Surface a single clear error until #688 lands the proper fix.
"""
if not self.variables.sos:
return
affected = [
name
for name in self.variables.sos
if (self.variables[name].labels.values == -1).any()
]
if affected:
raise NotImplementedError(
f"SOS constraints on masked variables are not yet supported "
f"(affected: {affected}; "
"see https://github.com/PyPSA/linopy/issues/688). "
"Pass reformulate_sos=True as a workaround."
)

def remove_objective(self) -> None:
"""
Remove the objective's linear expression from the model.
Expand Down
114 changes: 74 additions & 40 deletions linopy/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import warnings
from abc import ABC
from collections import namedtuple
from collections.abc import Callable, Generator, Iterator, Sequence
from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
from dataclasses import dataclass, field
from enum import Enum, auto
from importlib.metadata import PackageNotFoundError
Expand Down Expand Up @@ -106,6 +106,75 @@ def _solution_from_labels(
return values_to_lookup_array(np.asarray(values, dtype=float), labels, size=size)


def _sos_set_positions(
labels: np.ndarray, weights: np.ndarray, label_to_pos: np.ndarray
) -> tuple[list[int], list[float]]:
"""
Convert a SOS set's linopy labels to solver column positions.

Direct-API solvers (gurobi, xpress) accept SOS members as 0-based column
positions in the solver's variable array, which corresponds to the active
(non-masked) variable order — i.e., the order of
``model.variables.label_index.vlabels``. Masked entries (label ``-1``) are
dropped along with their weights.

Parameters
----------
labels : np.ndarray
Flat array of linopy labels for the SOS members.
weights : np.ndarray
Matching weights; same length as ``labels``.
label_to_pos : np.ndarray
``model.variables.label_index.label_to_pos`` — lookup of label →
active-variable position.

Returns
-------
tuple[list[int], list[float]]
Solver column positions and matching weights, with masked entries
removed.
"""
mask = labels != -1
return (
label_to_pos[labels[mask]].tolist(),
weights[mask].tolist(),
)


def _iter_sos_sets(model: Model) -> Iterator[tuple[int, list[int], list[float]]]:
"""
Yield ``(sos_type, positions, weights)`` per active SOS set in ``model``.

Iterates 1D SOS variables as a single set and multi-dim SOS variables as
one set per non-SOS-dim coordinate. Masked members are dropped, surviving
linopy labels are resolved to solver column positions via
``_sos_set_positions``, and empty sets are skipped.

Shared between direct-API solvers (Gurobi, Xpress). Each solver only
differs in the vendor ``addSOS`` call.
"""
label_to_pos = model.variables.label_index.label_to_pos
for var_name in model.variables.sos:
var = model.variables.sos[var_name]
sos_type = int(var.attrs[SOS_TYPE_ATTR]) # type: ignore[call-overload]
sos_dim = str(var.attrs[SOS_DIM_ATTR])
others = [d for d in var.labels.dims if d != sos_dim]

if not others:
sets: Iterable[xr.DataArray] = [var.labels]
else:
stacked = var.labels.stack(_sos_group=others)
sets = (s.unstack("_sos_group") for _, s in stacked.groupby("_sos_group"))

for s in sets:
s = s.squeeze()
labels = s.values.flatten()
weights = s.coords[sos_dim].values
positions, kept_weights = _sos_set_positions(labels, weights, label_to_pos)
if positions:
yield sos_type, positions, kept_weights


class SolverFeature(Enum):
"""Enumeration of all solver capabilities tracked by linopy."""

Expand Down Expand Up @@ -517,7 +586,6 @@ def _build(self, **build_kwargs: Any) -> None:
if self.model is None:
raise RuntimeError("Solver has no model attached; cannot build.")
self._validate_model()
self.model._check_sos_unmasked()
if self.io_api == "direct":
self._build_direct(**build_kwargs)
else:
Expand Down Expand Up @@ -1580,25 +1648,8 @@ def _build_solver_model(
names = print_constraints(M.clabels)
c.setAttr("ConstrName", names)

if model.variables.sos:
for var_name in model.variables.sos:
var = model.variables.sos[var_name]
sos_type: int = var.attrs[SOS_TYPE_ATTR] # type: ignore[assignment]
sos_dim: str = var.attrs[SOS_DIM_ATTR] # type: ignore[assignment]

def add_sos(s: xr.DataArray, sos_type: int, sos_dim: str) -> None:
s = s.squeeze()
indices = s.values.flatten().tolist()
weights = s.coords[sos_dim].values.tolist()
gm.addSOS(sos_type, x[indices].tolist(), weights)

others = [dim for dim in var.labels.dims if dim != sos_dim]
if not others:
add_sos(var.labels, sos_type, sos_dim)
else:
stacked = var.labels.stack(_sos_group=others)
for _, s in stacked.groupby("_sos_group"):
add_sos(s.unstack("_sos_group"), sos_type, sos_dim)
for sos_type, positions, weights in _iter_sos_sets(model):
gm.addSOS(sos_type, x[positions].tolist(), weights)

gm.update()
return gm
Expand Down Expand Up @@ -2222,25 +2273,8 @@ def _build_solver_model(
if cnames:
problem.addnames(xpress_Namespaces.ROW, cnames, 0, len(cnames) - 1)

if model.variables.sos:
for var_name in model.variables.sos:
var = model.variables.sos[var_name]
sos_type: int = var.attrs[SOS_TYPE_ATTR] # type: ignore[assignment]
sos_dim: str = var.attrs[SOS_DIM_ATTR] # type: ignore[assignment]

def add_sos(s: xr.DataArray, sos_type: int, sos_dim: str) -> None:
s = s.squeeze()
indices = s.values.flatten().tolist()
weights = s.coords[sos_dim].values.tolist()
problem.addSOS(indices, weights, type=sos_type)

others = [dim for dim in var.labels.dims if dim != sos_dim]
if not others:
add_sos(var.labels, sos_type, sos_dim)
else:
stacked = var.labels.stack(_sos_group=others)
for _, s in stacked.groupby("_sos_group"):
add_sos(s.unstack("_sos_group"), sos_type, sos_dim)
for sos_type, positions, weights in _iter_sos_sets(model):
problem.addSOS(positions, weights, type=sos_type)

return problem

Expand Down
19 changes: 11 additions & 8 deletions test/test_piecewise_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2313,23 +2313,26 @@ def test_lp_per_entity_nan_padding(
Per-entity NaN-padded breakpoints with method='lp': padded
segments must be masked out so they don't create spurious
``y ≤ 0`` constraints (bug-2 regression).

``method='sos2'`` would emit a masked SOS lambda variable, which the
native SOS path doesn't yet support (#688) — exercised separately in
:py:meth:`test_sos2_per_entity_nan_padding_errors`.
"""
m = nan_padded_pwl_model("lp")
m.solve()
# f_b(10) on chord (5,10)→(15,15) is 12.5
assert abs(float(m.solution.sel({"entity": "b"})["y"]) - 12.5) < 1e-3

def test_sos2_per_entity_nan_padding_errors(
def test_sos2_per_entity_nan_padding(
self, nan_padded_pwl_model: Callable[[Method], Model]
) -> None:
"""Masked SOS lambdas hit the #688 guard at solve time."""
"""
Per-entity NaN-padded breakpoints with method='sos2': the SOS
lambda variable's masked entries must flow through both the
direct API (via label→position resolution) and the LP writer
(via masked-member filtering) so the solve returns the same
answer as ``method='lp'``. Regression for #688.
"""
m = nan_padded_pwl_model("sos2")
with pytest.raises(NotImplementedError, match="masked"):
m.solve()
m.solve()
# f_b(10) on chord (5,10)→(15,15) is 12.5 — same oracle as lp variant
assert abs(float(m.solution.sel({"entity": "b"})["y"]) - 12.5) < 1e-3

def test_lp_rejects_decreasing_x_concave_ge(self) -> None:
"""
Expand Down
73 changes: 0 additions & 73 deletions test/test_sos_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,6 @@
import xarray as xr

from linopy import Model, available_solvers
from linopy.solver_capabilities import (
SolverFeature,
get_available_solvers_with_feature,
solver_supports,
)

_direct_sos_solvers = [
s
for s in get_available_solvers_with_feature(
SolverFeature.SOS_CONSTRAINTS, available_solvers
)
if solver_supports(s, SolverFeature.DIRECT_API)
]


def test_add_sos_constraints_registers_variable() -> None:
Expand Down Expand Up @@ -209,66 +196,6 @@ def test_qp_sos1_xpress_direct() -> None:
assert np.isclose(m.objective.value, -25)


@pytest.fixture
def masked_sos_model() -> Model:
"""Tiny model with a single masked SOS1 variable."""
m = Model()
coords = pd.Index([0, 1, 2, 3], name="i")
mask = pd.Series([True, True, False, True], index=coords)
var = m.add_variables(lower=0, upper=1, coords=[coords], mask=mask, name="sos_var")
m.add_sos_constraints(var, sos_type=1, sos_dim="i")
m.add_objective(-var.sum())
return m


@pytest.mark.parametrize("solver_name", _direct_sos_solvers)
def test_direct_api_raises_on_masked_sos(
solver_name: str, masked_sos_model: Model
) -> None:
with pytest.raises(NotImplementedError, match="masked"):
masked_sos_model.solve(solver_name=solver_name, io_api="direct")


def test_lp_writer_raises_on_masked_sos(
masked_sos_model: Model, tmp_path: Path
) -> None:
with pytest.raises(NotImplementedError, match="masked"):
masked_sos_model.to_file(tmp_path / "sos.lp", io_api="lp")


@pytest.mark.parametrize(
"solver_name",
[
pytest.param(
"gurobi",
marks=pytest.mark.skipif(
"gurobi" not in available_solvers, reason="Gurobi not installed"
),
),
pytest.param(
"highs",
marks=pytest.mark.skipif(
"highs" not in available_solvers, reason="HiGHS not installed"
),
),
],
)
def test_reformulate_sos_true_solves_masked_sos(
solver_name: str, masked_sos_model: Model
) -> None:
"""The documented workaround for the masked-SOS bug actually solves."""
masked_sos_model.solve(solver_name=solver_name, reformulate_sos=True)
sol = masked_sos_model.variables["sos_var"].solution.values
# SOS1 over 3 unmasked entries, max sum, each in [0, 1]:
# one entry == 1, others == 0, masked stays NaN.
assert masked_sos_model.objective.value is not None
assert np.isclose(masked_sos_model.objective.value, -1.0)
assert np.isnan(sol[2])
nonzero = np.flatnonzero(~np.isnan(sol) & (sol > 1e-6))
assert len(nonzero) == 1
assert np.isclose(sol[nonzero[0]], 1.0)


@pytest.mark.skipif("gurobi" not in available_solvers, reason="Gurobi not installed")
def test_reformulate_sos_true_reformulates_on_native_solver(tmp_path: Path) -> None:
"""
Expand Down
Loading