From 26d0d42109f1689d387da6b6ed4d0d8df37eb3c7 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 18 May 2026 16:03:35 +0200 Subject: [PATCH 1/4] refactor(sos): add Model.apply/undo_sos_reformulation methods MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce a stateful pair of methods on Model that own the SOS reformulation lifecycle: - apply_sos_reformulation() stashes the reformulation token on the model (new _sos_reformulation_state attribute). Raises if already applied. - undo_sos_reformulation() reads the stashed token and restores the original SOS form. No-op if nothing is applied. Model.solve(reformulate_sos=...) now delegates to these methods rather than threading the token through local state. The Solver path (which was previously raising via Model.solve's pre-flight check) now gets a clean ValueError directly from Solver._build() when an SOS-bearing model is handed to a solver without native SOS support — making the low-level API safe to use independently of Model.solve. Persistence: - copy() (and copy.copy / copy.deepcopy) carry the reformulation token with a deepcopy, so the copy is independently undoable. - to_netcdf() raises if a reformulation is active; users must undo first to serialize a stable model state. Context: motivated by the same investigation as PyPSA/linopy#688 — while reviewing the new Solver.from_model() API surface introduced by #682, the SOS reformulation lifecycle stood out as load-bearing orchestration that the Solver path couldn't reproduce. Co-Authored-By: Claude Opus 4.7 (1M context) --- linopy/io.py | 17 +++++ linopy/model.py | 52 +++++++++++--- linopy/solvers.py | 8 +++ test/test_sos_constraints.py | 7 +- test/test_sos_reformulation.py | 124 +++++++++++++++++++++++++++++++++ 5 files changed, 194 insertions(+), 14 deletions(-) diff --git a/linopy/io.py b/linopy/io.py index 36d7abb3..ba0400b6 100644 --- a/linopy/io.py +++ b/linopy/io.py @@ -5,6 +5,7 @@ from __future__ import annotations +import copy as _copy import json import logging import shutil @@ -828,7 +829,20 @@ def to_netcdf(m: Model, *args: Any, **kwargs: Any) -> None: Arguments passed to ``xarray.Dataset.to_netcdf``. **kwargs : TYPE Keyword arguments passed to ``xarray.Dataset.to_netcdf``. + + Raises + ------ + RuntimeError + If the model has an active SOS reformulation. Call + :meth:`Model.undo_sos_reformulation` before serializing. """ + if m._sos_reformulation_state is not None: + raise RuntimeError( + "Cannot serialize a model with an active SOS reformulation. " + "Call `model.undo_sos_reformulation()` first to restore the " + "original SOS form, or save the model in its reformulated form " + "after explicitly clearing `model._sos_reformulation_state`." + ) def with_prefix(ds: xr.Dataset, prefix: str) -> xr.Dataset: to_rename = set([*ds.dims, *ds.coords, *ds]) @@ -1100,6 +1114,9 @@ def copy(m: Model, include_solution: bool = False, deep: bool = True) -> Model: if include_solution or attr not in SOLVE_STATE_ATTRS: setattr(new_model, attr, getattr(m, attr)) + if m._sos_reformulation_state is not None: + new_model._sos_reformulation_state = _copy.deepcopy(m._sos_reformulation_state) + return new_model diff --git a/linopy/model.py b/linopy/model.py index 4eb91fc6..be572279 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -89,6 +89,7 @@ available_solvers, ) from linopy.sos_reformulation import ( + SOSReformulationResult, reformulate_sos_constraints, undo_sos_reformulation, ) @@ -239,6 +240,7 @@ class Model: "_relaxed_registry", "_piecewise_formulations", "_solver", + "_sos_reformulation_state", "__weakref__", ) @@ -309,6 +311,7 @@ def __init__( gettempdir() if solver_dir is None else solver_dir ) self._solver: solvers.Solver | None = None + self._sos_reformulation_state: SOSReformulationResult | None = None @property def solver(self) -> solvers.Solver | None: @@ -1220,6 +1223,39 @@ def remove_sos_constraints(self, variable: Variable) -> None: reformulate_sos_constraints = reformulate_sos_constraints + def apply_sos_reformulation(self) -> None: + """ + Reformulate SOS constraints into binary + linear form, in place. + + The reformulation token is stored on the model so it can be reverted + with :meth:`undo_sos_reformulation`. This is the stateful counterpart + to :func:`linopy.sos_reformulation.reformulate_sos_constraints`, where + the caller owns the token. + + Raises + ------ + RuntimeError + If a reformulation has already been applied and not undone. + """ + if self._sos_reformulation_state is not None: + raise RuntimeError( + "SOS reformulation has already been applied to this model. " + "Call `undo_sos_reformulation()` before applying again." + ) + self._sos_reformulation_state = reformulate_sos_constraints(self) + + def undo_sos_reformulation(self) -> None: + """ + Revert a previously applied SOS reformulation. + + No-op if no reformulation is currently applied. + """ + if self._sos_reformulation_state is None: + return + state = self._sos_reformulation_state + self._sos_reformulation_state = None + undo_sos_reformulation(self, state) + def remove_objective(self) -> None: """ Remove the objective's linear expression from the model. @@ -1711,22 +1747,20 @@ def solve( "Must be True, False, or 'auto'." ) - sos_reform_result = None + applied_sos_reformulation_here = False if self.variables.sos: supports_sos = solver_class.supports(SolverFeature.SOS_CONSTRAINTS) if reformulate_sos in (True, "auto") and not supports_sos: logger.info(f"Reformulating SOS constraints for solver {solver_name}") - sos_reform_result = reformulate_sos_constraints(self) + self.apply_sos_reformulation() + applied_sos_reformulation_here = True elif reformulate_sos is True and supports_sos: logger.warning( f"Solver {solver_name} supports SOS natively; " "reformulate_sos=True is ignored." ) - elif reformulate_sos is False and not supports_sos: - raise ValueError( - f"Solver {solver_name} does not support SOS constraints. " - "Use reformulate_sos=True or 'auto', or a solver that supports SOS (gurobi, cplex)." - ) + # If SOS is present and the solver doesn't support it (and the user + # didn't ask for reformulation), Solver._build() will raise. if self.variables.semi_continuous: if not solver_class.supports(SolverFeature.SEMI_CONTINUOUS_VARIABLES): @@ -1778,8 +1812,8 @@ def solve( try: return self.assign_result(result) finally: - if sos_reform_result is not None: - undo_sos_reformulation(self, sos_reform_result) + if applied_sos_reformulation_here: + self.undo_sos_reformulation() def assign_result(self, result: Result) -> tuple[str, str]: result.info() diff --git a/linopy/solvers.py b/linopy/solvers.py index 548db835..0ca5d956 100644 --- a/linopy/solvers.py +++ b/linopy/solvers.py @@ -507,6 +507,14 @@ def _build(self, **build_kwargs: Any) -> None: """Dispatch to direct or file build based on ``io_api``.""" if self.model is None: raise RuntimeError("Solver has no model attached; cannot build.") + if self.model.variables.sos and not type(self).supports( + SolverFeature.SOS_CONSTRAINTS + ): + raise ValueError( + f"Solver {self.solver_name.value} does not support SOS constraints. " + "Call `model.apply_sos_reformulation()` first, or use a solver that " + "supports SOS." + ) if self.io_api == "direct": self._build_direct(**build_kwargs) else: diff --git a/test/test_sos_constraints.py b/test/test_sos_constraints.py index 5d94162e..3c3e79f6 100644 --- a/test/test_sos_constraints.py +++ b/test/test_sos_constraints.py @@ -150,7 +150,7 @@ def test_unsupported_solver_raises_error() -> None: m.solve(solver_name=solver) -def test_to_highspy_raises_not_implemented() -> None: +def test_to_highspy_raises_when_sos_present() -> None: pytest.importorskip("highspy") m = Model() @@ -158,8 +158,5 @@ def test_to_highspy_raises_not_implemented() -> None: build = m.add_variables(coords=[locations], name="build", binary=True) m.add_sos_constraints(build, sos_type=1, sos_dim="locations") - with pytest.raises( - NotImplementedError, - match="SOS constraints are not supported by the HiGHS direct API", - ): + with pytest.raises(ValueError, match="does not support SOS constraints"): m.to_highspy() diff --git a/test/test_sos_reformulation.py b/test/test_sos_reformulation.py index 24ba62b3..f2f48075 100644 --- a/test/test_sos_reformulation.py +++ b/test/test_sos_reformulation.py @@ -3,6 +3,8 @@ from __future__ import annotations import logging +from collections.abc import Callable +from pathlib import Path import numpy as np import pandas as pd @@ -312,6 +314,128 @@ def test_reformulate_inplace(self) -> None: assert "_sos_reform_x_y" in m.variables +class TestApplyUndoSOSReformulation: + """Tests for Model.apply_sos_reformulation / undo_sos_reformulation.""" + + @staticmethod + def _build_sos1_model() -> Model: + m = Model() + idx = pd.Index([0, 1, 2], name="i") + x = m.add_variables(lower=0, upper=1, coords=[idx], name="x") + m.add_sos_constraints(x, sos_type=1, sos_dim="i") + return m + + def test_apply_stashes_state(self) -> None: + m = self._build_sos1_model() + assert m._sos_reformulation_state is None + + m.apply_sos_reformulation() + + assert m._sos_reformulation_state is not None + assert m._sos_reformulation_state.reformulated == ["x"] + assert len(list(m.variables.sos)) == 0 + assert "_sos_reform_x_y" in m.variables + + def test_undo_restores_and_clears_state(self) -> None: + m = self._build_sos1_model() + m.apply_sos_reformulation() + + m.undo_sos_reformulation() + + assert m._sos_reformulation_state is None + assert list(m.variables.sos) == ["x"] + assert "_sos_reform_x_y" not in m.variables + + def test_double_apply_raises(self) -> None: + m = self._build_sos1_model() + m.apply_sos_reformulation() + + with pytest.raises(RuntimeError, match="already been applied"): + m.apply_sos_reformulation() + + def test_undo_without_apply_is_noop(self) -> None: + m = self._build_sos1_model() + assert m._sos_reformulation_state is None + + m.undo_sos_reformulation() # should not raise + + assert m._sos_reformulation_state is None + assert list(m.variables.sos) == ["x"] + + @pytest.mark.parametrize( + "copy_fn", + [ + pytest.param(lambda m: m.copy(), id="model.copy()"), + pytest.param(lambda m: __import__("copy").copy(m), id="copy.copy(model)"), + pytest.param( + lambda m: __import__("copy").deepcopy(m), id="copy.deepcopy(model)" + ), + ], + ) + def test_copy_persists_state_and_undo_works_on_copy( + self, copy_fn: Callable[[Model], Model] + ) -> None: + m = self._build_sos1_model() + m.apply_sos_reformulation() + + c = copy_fn(m) + + # State is carried over but is an independent object + assert c._sos_reformulation_state is not None + assert c._sos_reformulation_state is not m._sos_reformulation_state + # Aux vars/cons exist on the copy (they were copied as part of the + # reformulated model state) + assert "_sos_reform_x_y" in c.variables + assert "_sos_reform_x_upper" in c.constraints + assert "_sos_reform_x_card" in c.constraints + # SOS attrs are not on the copy's "x" yet (still in reformulated form) + assert "x" not in list(c.variables.sos) + + # Undo on the copy fully restores the original SOS form + c.undo_sos_reformulation() + assert c._sos_reformulation_state is None + assert list(c.variables.sos) == ["x"] + assert "_sos_reform_x_y" not in c.variables + assert "_sos_reform_x_upper" not in c.constraints + assert "_sos_reform_x_card" not in c.constraints + + # Original is entirely unaffected + assert m._sos_reformulation_state is not None + assert "_sos_reform_x_y" in m.variables + assert len(list(m.variables.sos)) == 0 + + def test_to_netcdf_raises_when_state_active(self, tmp_path: Path) -> None: + m = self._build_sos1_model() + m.apply_sos_reformulation() + + with pytest.raises(RuntimeError, match="active SOS reformulation"): + m.to_netcdf(tmp_path / "m.nc") + + def test_to_netcdf_works_after_undo(self, tmp_path: Path) -> None: + m = self._build_sos1_model() + m.apply_sos_reformulation() + m.undo_sos_reformulation() + + m.to_netcdf(tmp_path / "m.nc") # should not raise + + +@pytest.mark.skipif("highs" not in available_solvers, reason="HiGHS not installed") +class TestSolverPathSOSCheck: + """Solver._build() must raise on SOS-bearing model with non-SOS solver.""" + + def test_solver_from_name_raises_without_reformulation(self) -> None: + from linopy import solvers + + m = Model() + idx = pd.Index([0, 1, 2], name="i") + x = m.add_variables(lower=0, upper=1, coords=[idx], name="x") + m.add_sos_constraints(x, sos_type=1, sos_dim="i") + m.add_objective(x.sum(), sense="max") + + with pytest.raises(ValueError, match="does not support SOS"): + solvers.Solver.from_name("highs", m, io_api="lp") + + @pytest.mark.skipif("highs" not in available_solvers, reason="HiGHS not installed") class TestSolveWithReformulation: """Tests for solving with SOS reformulation.""" From 68d75cd8fee240d85443d8165b366536a3b78f16 Mon Sep 17 00:00:00 2001 From: Felix <117816358+FBumann@users.noreply.github.com> Date: Mon, 18 May 2026 20:53:06 +0200 Subject: [PATCH 2/4] refactor(solver): validation, sanitize kwargs, and result wiring on Solver path (#691) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor(solver): lift feature checks + sanitize/wiring to Solver path Make Solver.from_name(...).solve() a real first-class entry point that doesn't lose Model.solve()'s safety nets: - Lift solver-feature gates into Solver._build() via a new _validate_model() hook: quadratic models against LP-only solvers and semi-continuous variables against solvers that don't support them. Removed the duplicate checks from Model.solve(). - Add sanitize_zeros / sanitize_infinities kwargs to Solver.from_model() (default True). The kwargs are processed in _build() before dispatch, so both file and direct io_apis honor them. Model.solve() forwards the kwargs through instead of pre-mutating the constraints itself. - Extend Model.assign_result(result, solver=None) so the Solver-path canonical pattern works: solver = Solver.from_name(...); result = solver.solve(); model.assign_result(result, solver=solver). When the solver kwarg is provided, model.solver gets wired the same way Model.solve() wires it, so compute_infeasibilities() and friends keep working through the low-level API. The empty-objective check stays on Model.solve() — to_gurobipy() / to_highspy() and similar build-only converters legitimately work against objectiveless models (gurobi/highs default to a zero objective), so the check belongs at the actual submit point. Co-Authored-By: Claude Opus 4.7 (1M context) * move empty-objective check to Solver.solve() for entry-point parity The empty-objective UX guardrail was previously only on Model.solve(), leaving the lower-level Solver.from_name(...).solve() path with a silent gap. Move it to Solver.solve() — the actual submit primitive that both entry points go through — so the same check fires regardless of which API the user reaches for. Build-time translate-only paths (to_gurobipy(), to_highspy(), to_file()) are unaffected since they don't call solve(). The cost of catching the error after build instead of before is bounded and only hits a programming-error case. Co-Authored-By: Claude Opus 4.7 (1M context) * test: parametrize empty-objective check across both entry points Consolidate the Model.solve() and Solver.from_name(...).solve() tests into one parametrized case — same check, two callers, one assertion. Co-Authored-By: Claude Opus 4.7 (1M context) * test: collapse parametrize to a single test with two raises blocks Same property tested twice — no need for separate test IDs. Co-Authored-By: Claude Opus 4.7 (1M context) * preserve empty-objective check for remote-solve path in Model.solve() The remote-solve branch in Model.solve() short-circuits to a RemoteHandler before reaching Solver.solve(), so the check now in Solver.solve() doesn't cover it. Restore the early raise in Model.solve() so behavior is unchanged for all Model.solve() callers (mock, remote, local) while Solver.solve() still covers direct-Solver callers. Co-Authored-By: Claude Opus 4.7 (1M context) * move remote-path empty-objective check inside the remote branch The early-position check was a workaround: the remote branch short-circuits before Solver.solve() (where the canonical check now lives), so empty-objective with remote=... wouldn't raise. Moving it into the remote branch itself makes the intent local to where it's needed, with a comment pointing at #683 where this duplication disappears once OETC becomes a Solver subclass. Co-Authored-By: Claude Opus 4.7 (1M context) * keep sanitize on Model; Solver.from_model() stays mutation-free Remove the sanitize_zeros / sanitize_infinities kwargs from Solver.from_model(). The Solver builder now never mutates the model. Sanitization is exposed where it has always lived — model.constraints.sanitize_zeros() / .sanitize_infinities() — and Model.solve() calls them inline as part of its orchestration. Rationale: model-state transformations should be Model-level primitives (matches the SOS reformulation pattern from #690). The Solver's job is to translate the model and run; it should not silently change the caller's model on the way in. Users who go through the lower-level Solver path apply sanitize explicitly when they want it. Replaces TestSanitizeKwargs with TestSolverDoesNotMutateModel, pinning the mutation-free invariant: building a Solver against a model with a near-zero coefficient leaves model.constraints["c"].coeffs unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) * address review: SOS hint, lp_only_solver fixture, assign_result doc --------- Co-authored-by: Claude Opus 4.7 (1M context) Co-authored-by: Fabian --- linopy/model.py | 68 ++++++++++++++++++----------- linopy/solvers.py | 72 ++++++++++++++++++++++++++----- test/test_solvers.py | 100 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 204 insertions(+), 36 deletions(-) diff --git a/linopy/model.py b/linopy/model.py index 4a11558a..12a46206 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -1678,12 +1678,6 @@ def solve( sanitize_zeros=sanitize_zeros, sanitize_infinities=sanitize_infinities ) - if self.objective.expression.empty: - raise ValueError( - "No objective has been set on the model. Use `m.add_objective(...)` " - "first (e.g. `m.add_objective(0 * x)` for a pure feasibility problem)." - ) - # check io_api if io_api is not None and io_api not in IO_APIS: raise ValueError( @@ -1691,6 +1685,16 @@ def solve( ) if remote is not None: + # The remote branch short-circuits before reaching Solver.solve(), + # which is where the empty-objective check normally fires. Replicate + # it here. This duplication becomes obsolete once OETC is folded + # into the Solver pipeline (see PyPSA/linopy#683). + if self.objective.expression.empty: + raise ValueError( + "No objective has been set on the model. Use " + "`m.add_objective(...)` first (e.g. `m.add_objective(0 * x)` " + "for a pure feasibility problem)." + ) if isinstance(remote, OetcHandler): solved = remote.solve_on_oetc( self, solver_name=solver_name, **solver_options @@ -1756,19 +1760,6 @@ def solve( else: solution_fn = self.get_solution_file() - if sanitize_zeros: - self.constraints.sanitize_zeros() - - if sanitize_infinities: - self.constraints.sanitize_infinities() - - if self.is_quadratic and not solver_class.supports( - SolverFeature.QUADRATIC_OBJECTIVE - ): - raise ValueError( - f"Solver {solver_name} does not support quadratic problems." - ) - if reformulate_sos not in (True, False, "auto"): raise ValueError( f"Invalid value for reformulate_sos: {reformulate_sos!r}. " @@ -1789,12 +1780,10 @@ def solve( # If SOS is present and the solver doesn't support it (and the user # didn't ask for reformulation), Solver._build() will raise. - if self.variables.semi_continuous: - if not solver_class.supports(SolverFeature.SEMI_CONTINUOUS_VARIABLES): - raise ValueError( - f"Solver {solver_name} does not support semi-continuous variables. " - "Use a solver that supports them (gurobi, cplex, highs)." - ) + if sanitize_zeros: + self.constraints.sanitize_zeros() + if sanitize_infinities: + self.constraints.sanitize_infinities() try: self.solver = None # closes any previous solver @@ -1842,7 +1831,34 @@ def solve( if applied_sos_reformulation_here: self.undo_sos_reformulation() - def assign_result(self, result: Result) -> tuple[str, str]: + def assign_result( + self, + result: Result, + solver: solvers.Solver | None = None, + ) -> tuple[str, str]: + """ + Write a solver Result back onto the model. + + Copies primal / dual values onto variables / constraints, sets + :attr:`status`, :attr:`termination_condition`, and + :attr:`objective.value`. When ``solver`` is provided, also stores it on + ``self.solver`` so post-solve introspection (``model.solver_model``, + ``compute_infeasibilities()``) works. + + Parameters + ---------- + result : Result + The :class:`linopy.constants.Result` returned by + :meth:`linopy.solvers.Solver.solve`. + solver : Solver, optional + The solver instance that produced the result. Pass it on the + low-level ``Solver.from_name(...).solve()`` path to attach it as + ``self.solver`` for post-solve introspection. ``Model.solve()`` + attaches the solver itself and does not pass this argument. + """ + if solver is not None: + self.solver = solver + result.info() if result.solution is not None: diff --git a/linopy/solvers.py b/linopy/solvers.py index 60669b9c..0fbd2c12 100644 --- a/linopy/solvers.py +++ b/linopy/solvers.py @@ -504,23 +504,52 @@ def from_model( return instance def _build(self, **build_kwargs: Any) -> None: - """Dispatch to direct or file build based on ``io_api``.""" + """ + Dispatch to direct or file build based on ``io_api``. + + The Solver never mutates ``self.model``. Constraint sanitization + (``model.constraints.sanitize_zeros()`` / + ``.sanitize_infinities()``) and SOS reformulation + (``model.apply_sos_reformulation()``) are Model-level operations + the caller applies first; this builder consumes whatever shape it + is handed. + """ if self.model is None: raise RuntimeError("Solver has no model attached; cannot build.") self.model._check_sos_unmasked() - if self.model.variables.sos and not type(self).supports( - SolverFeature.SOS_CONSTRAINTS - ): - raise ValueError( - f"Solver {self.solver_name.value} does not support SOS constraints. " - "Call `model.apply_sos_reformulation()` first, or use a solver that " - "supports SOS." - ) + self._validate_model() if self.io_api == "direct": self._build_direct(**build_kwargs) else: self._build_file(**build_kwargs) + def _validate_model(self) -> None: + """Pre-build checks on whether this solver can handle ``self.model``.""" + model = self.model + assert model is not None + solver_name = self.solver_name.value + cls = type(self) + + if model.is_quadratic and not cls.supports(SolverFeature.QUADRATIC_OBJECTIVE): + raise ValueError( + f"Solver {solver_name} does not support quadratic problems." + ) + + if model.variables.semi_continuous and not cls.supports( + SolverFeature.SEMI_CONTINUOUS_VARIABLES + ): + raise ValueError( + f"Solver {solver_name} does not support semi-continuous variables. " + "Use a solver that supports them (gurobi, cplex, highs)." + ) + + if model.variables.sos and not cls.supports(SolverFeature.SOS_CONSTRAINTS): + raise ValueError( + f"Solver {solver_name} does not support SOS constraints. " + "Reformulate first via `Model.solve(reformulate_sos=True)` or " + "`model.apply_sos_reformulation()`, or use a solver that supports SOS." + ) + def _build_direct(self, **build_kwargs: Any) -> None: """Build the native solver model from ``self.model``. Override per-solver.""" raise NotImplementedError( @@ -561,7 +590,30 @@ def _build_file(self, **build_kwargs: Any) -> None: self._cache_model_sizes(model) def solve(self, **run_kwargs: Any) -> Result: - """Run the prepared solver and return a :class:`Result`.""" + """ + Run the prepared solver and return a :class:`Result`. + + The canonical low-level pattern is:: + + solver = Solver.from_name("gurobi", model, io_api="direct") + result = solver.solve() + model.assign_result(result, solver=solver) + + Passing ``solver=`` to :meth:`Model.assign_result` wires + ``model.solver`` so post-solve helpers like + :meth:`Model.compute_infeasibilities` keep working. + + Raises + ------ + ValueError + If the attached model has no objective set. Submit-time check + shared by both ``Model.solve()`` and direct-Solver callers. + """ + if self.model is not None and self.model.objective.expression.empty: + raise ValueError( + "No objective has been set on the model. Use `m.add_objective(...)` " + "first (e.g. `m.add_objective(0 * x)` for a pure feasibility problem)." + ) if self.io_api == "direct" or self.solver_model is not None: return self._run_direct(**run_kwargs) if self._problem_fn is not None: diff --git a/test/test_solvers.py b/test/test_solvers.py index db894137..1b6bd9a9 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -23,6 +23,14 @@ from linopy.solvers import _installed_version_in +@pytest.fixture +def lp_only_solver() -> str: + for name in ("glpk", "cbc"): + if name in solvers.available_solvers: + return name + pytest.skip("Need an LP-only solver (glpk or cbc) installed") + + @pytest.fixture def simple_model() -> Model: m = Model(chunk=None) @@ -464,3 +472,95 @@ def test_xpress_gpu_feature_reflects_installed_version() -> None: assert solvers.Xpress.supports( SolverFeature.GPU_ACCELERATION ) == _installed_version_in("xpress", ">=9.8.0") + + +class TestValidateModelOnBuild: + """Solver._build() runs solver-feature checks regardless of entry point.""" + + def test_quadratic_without_qp_support_raises(self, lp_only_solver: str) -> None: + m = Model() + x = m.add_variables(name="x", lower=0, upper=10) + m.add_objective(x * x, sense="min") + + with pytest.raises(ValueError, match="does not support quadratic"): + solvers.Solver.from_name(lp_only_solver, m, io_api="lp") + + def test_semi_continuous_without_support_raises(self, lp_only_solver: str) -> None: + m = Model() + x = m.add_variables(name="x", lower=1, upper=10, semi_continuous=True) + m.add_objective(x) + + with pytest.raises(ValueError, match="does not support semi-continuous"): + solvers.Solver.from_name(lp_only_solver, m, io_api="lp") + + @pytest.mark.skipif( + "highs" not in solvers.available_solvers, reason="HiGHS not installed" + ) + def test_solve_without_objective_raises(self) -> None: + m = Model() + m.add_variables(name="x", lower=0, upper=10) + # No objective added — both entry points should raise the same error. + with pytest.raises(ValueError, match="No objective has been set"): + solvers.Solver.from_name("highs", m, io_api="lp").solve() + with pytest.raises(ValueError, match="No objective has been set"): + m.solve("highs") + + +class TestSolverDoesNotMutateModel: + """Solver.from_model() must not mutate model state (sanitize stays Model-level).""" + + @pytest.mark.skipif( + "highs" not in solvers.available_solvers, reason="HiGHS not installed" + ) + def test_from_model_leaves_constraints_untouched(self) -> None: + m = Model() + x = m.add_variables(name="x", lower=0, upper=10) + # Constraint with a near-zero coefficient — would be sanitized away if + # the Solver path were sanitizing on build. + m.add_constraints(1e-12 * x + x >= 0, name="c") + m.add_objective(x) + + before = m.constraints["c"].coeffs.values.copy() + solvers.Solver.from_name("highs", m, io_api="lp") + after = m.constraints["c"].coeffs.values + + assert np.allclose(before, after, equal_nan=True), ( + "Solver.from_model() must not mutate model constraints. " + "Sanitization is a Model-level primitive; call " + "model.constraints.sanitize_zeros() / .sanitize_infinities() " + "explicitly before building." + ) + + +class TestAssignResultWiring: + """assign_result(result, solver=...) populates model.solver.""" + + @pytest.mark.skipif( + "highs" not in solvers.available_solvers, reason="HiGHS not installed" + ) + def test_assign_result_with_solver_wires_model_solver(self) -> None: + m = Model() + x = m.add_variables(name="x", lower=0, upper=10) + m.add_objective(x, sense="min") + + assert m.solver is None + solver = solvers.Solver.from_name("highs", m, io_api="lp") + result = solver.solve() + m.assign_result(result, solver=solver) + + assert m.solver is solver + assert m.solver_model is solver.solver_model + + @pytest.mark.skipif( + "highs" not in solvers.available_solvers, reason="HiGHS not installed" + ) + def test_assign_result_without_solver_kwarg_leaves_solver_unset(self) -> None: + m = Model() + x = m.add_variables(name="x", lower=0, upper=10) + m.add_objective(x, sense="min") + + solver = solvers.Solver.from_name("highs", m, io_api="lp") + result = solver.solve() + m.assign_result(result) # no solver kwarg + + assert m.solver is None From e08733cb39bbc2d05777555709e715a384b715c2 Mon Sep 17 00:00:00 2001 From: Fabian Date: Mon, 18 May 2026 21:16:03 +0200 Subject: [PATCH 3/4] refactor(sos): tighten undo semantics and error hints - undo_sos_reformulation() now raises if no state is applied (fail-fast) - to_netcdf error no longer suggests poking the private state slot - Solver._build runs _validate_model before _check_sos_unmasked so SOS on an LP-only solver surfaces the reformulate-first hint - reformulate_sos_constraints docstring points at the stateful API --- linopy/io.py | 3 +-- linopy/model.py | 9 +++++++-- linopy/solvers.py | 2 +- linopy/sos_reformulation.py | 6 ++++-- test/test_sos_reformulation.py | 9 +++------ 5 files changed, 16 insertions(+), 13 deletions(-) diff --git a/linopy/io.py b/linopy/io.py index 064716b9..657d2c19 100644 --- a/linopy/io.py +++ b/linopy/io.py @@ -857,8 +857,7 @@ def to_netcdf(m: Model, *args: Any, **kwargs: Any) -> None: raise RuntimeError( "Cannot serialize a model with an active SOS reformulation. " "Call `model.undo_sos_reformulation()` first to restore the " - "original SOS form, or save the model in its reformulated form " - "after explicitly clearing `model._sos_reformulation_state`." + "original SOS form before saving." ) def with_prefix(ds: xr.Dataset, prefix: str) -> xr.Dataset: diff --git a/linopy/model.py b/linopy/model.py index 12a46206..65e8093f 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -1249,10 +1249,15 @@ def undo_sos_reformulation(self) -> None: """ Revert a previously applied SOS reformulation. - No-op if no reformulation is currently applied. + Raises + ------ + RuntimeError + If no reformulation is currently applied. """ if self._sos_reformulation_state is None: - return + raise RuntimeError( + "No SOS reformulation is currently applied to this model." + ) state = self._sos_reformulation_state self._sos_reformulation_state = None undo_sos_reformulation(self, state) diff --git a/linopy/solvers.py b/linopy/solvers.py index 0fbd2c12..d6cc50e6 100644 --- a/linopy/solvers.py +++ b/linopy/solvers.py @@ -516,8 +516,8 @@ def _build(self, **build_kwargs: Any) -> None: """ if self.model is None: raise RuntimeError("Solver has no model attached; cannot build.") - self.model._check_sos_unmasked() self._validate_model() + self.model._check_sos_unmasked() if self.io_api == "direct": self._build_direct(**build_kwargs) else: diff --git a/linopy/sos_reformulation.py b/linopy/sos_reformulation.py index 8ccb7613..0c677216 100644 --- a/linopy/sos_reformulation.py +++ b/linopy/sos_reformulation.py @@ -233,8 +233,10 @@ def reformulate_sos_constraints( 1. If custom big_m was specified in add_sos_constraints(), use that 2. Otherwise, use the variable bounds (tightest valid Big-M) - Note: This permanently mutates the model. To solve with automatic - undo, use ``model.solve(reformulate_sos=True)`` instead. + Note: This permanently mutates the model and returns a token the caller + owns. For a stateful, reversible API use ``model.apply_sos_reformulation()`` + / ``model.undo_sos_reformulation()``; for automatic undo around a single + solve use ``model.solve(reformulate_sos=True)``. Parameters ---------- diff --git a/test/test_sos_reformulation.py b/test/test_sos_reformulation.py index f2f48075..6b42918a 100644 --- a/test/test_sos_reformulation.py +++ b/test/test_sos_reformulation.py @@ -353,14 +353,11 @@ def test_double_apply_raises(self) -> None: with pytest.raises(RuntimeError, match="already been applied"): m.apply_sos_reformulation() - def test_undo_without_apply_is_noop(self) -> None: + def test_undo_without_apply_raises(self) -> None: m = self._build_sos1_model() - assert m._sos_reformulation_state is None - - m.undo_sos_reformulation() # should not raise - assert m._sos_reformulation_state is None - assert list(m.variables.sos) == ["x"] + with pytest.raises(RuntimeError, match="No SOS reformulation"): + m.undo_sos_reformulation() @pytest.mark.parametrize( "copy_fn", From 9c38ea6d36f02cd13c77c3db11c59f655a60c0fb Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 18 May 2026 21:23:16 +0200 Subject: [PATCH 4/4] fix(sos): auto-undo SOS reformulation when build/solve raises `Model.solve(reformulate_sos=...)` left `_sos_reformulation_state` set if `Solver.from_name`, `solver.solve`, or the file-cleanup `finally` raised, since the undo lived in a second `try` around `assign_result` that those failures never reached. The next solve then hit `RuntimeError: SOS reformulation has already been applied`. Wrap sanitize, build/solve, file cleanup, and assign_result in a single outer try/finally so the undo always runs. Co-Authored-By: Claude Opus 4.7 (1M context) --- linopy/model.py | 88 +++++++++++++++++----------------- test/test_sos_reformulation.py | 32 +++++++++++++ 2 files changed, 76 insertions(+), 44 deletions(-) diff --git a/linopy/model.py b/linopy/model.py index 65e8093f..03450d62 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -1785,52 +1785,52 @@ def solve( # If SOS is present and the solver doesn't support it (and the user # didn't ask for reformulation), Solver._build() will raise. - if sanitize_zeros: - self.constraints.sanitize_zeros() - if sanitize_infinities: - self.constraints.sanitize_infinities() - try: - self.solver = None # closes any previous solver - if io_api == "direct": - if set_names is None: - set_names = self.set_names_in_solver_io - build_kwargs: dict[str, Any] = { - "explicit_coordinate_names": explicit_coordinate_names, - "set_names": set_names, - "log_fn": to_path(log_fn), - } - if env is not None: - build_kwargs["env"] = env - else: - build_kwargs = { - "explicit_coordinate_names": explicit_coordinate_names, - "slice_size": slice_size, - "progress": progress, - "problem_fn": to_path(problem_fn), - } - self.solver = solver = solvers.Solver.from_name( - solver_name, - model=self, - io_api=io_api, - options=solver_options, - **build_kwargs, - ) - if io_api != "direct": - problem_fn = solver._problem_fn - result = solver.solve( - solution_fn=to_path(solution_fn), - log_fn=to_path(log_fn), - warmstart_fn=to_path(warmstart_fn), - basis_fn=to_path(basis_fn), - env=env, - ) - finally: - for fn in (problem_fn, solution_fn): - if fn is not None and (os.path.exists(fn) and not keep_files): - os.remove(fn) + if sanitize_zeros: + self.constraints.sanitize_zeros() + if sanitize_infinities: + self.constraints.sanitize_infinities() + + try: + self.solver = None # closes any previous solver + if io_api == "direct": + if set_names is None: + set_names = self.set_names_in_solver_io + build_kwargs: dict[str, Any] = { + "explicit_coordinate_names": explicit_coordinate_names, + "set_names": set_names, + "log_fn": to_path(log_fn), + } + if env is not None: + build_kwargs["env"] = env + else: + build_kwargs = { + "explicit_coordinate_names": explicit_coordinate_names, + "slice_size": slice_size, + "progress": progress, + "problem_fn": to_path(problem_fn), + } + self.solver = solver = solvers.Solver.from_name( + solver_name, + model=self, + io_api=io_api, + options=solver_options, + **build_kwargs, + ) + if io_api != "direct": + problem_fn = solver._problem_fn + result = solver.solve( + solution_fn=to_path(solution_fn), + log_fn=to_path(log_fn), + warmstart_fn=to_path(warmstart_fn), + basis_fn=to_path(basis_fn), + env=env, + ) + finally: + for fn in (problem_fn, solution_fn): + if fn is not None and (os.path.exists(fn) and not keep_files): + os.remove(fn) - try: return self.assign_result(result) finally: if applied_sos_reformulation_here: diff --git a/test/test_sos_reformulation.py b/test/test_sos_reformulation.py index 6b42918a..b244d9b6 100644 --- a/test/test_sos_reformulation.py +++ b/test/test_sos_reformulation.py @@ -433,6 +433,38 @@ def test_solver_from_name_raises_without_reformulation(self) -> None: solvers.Solver.from_name("highs", m, io_api="lp") +@pytest.mark.skipif("highs" not in available_solvers, reason="HiGHS not installed") +class TestSolveAutoUndoOnFailure: + """Model.solve must auto-undo SOS reformulation when build/solve raises.""" + + def test_state_restored_when_build_raises( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + from linopy import solvers + + m = Model() + idx = pd.Index([0, 1, 2], name="i") + x = m.add_variables(lower=0, upper=1, coords=[idx], name="x") + m.add_sos_constraints(x, sos_type=1, sos_dim="i") + m.add_objective(x.sum(), sense="max") + + def boom(*args: object, **kwargs: object) -> None: + raise RuntimeError("simulated build failure") + + monkeypatch.setattr(solvers.Solver, "from_name", boom) + + with pytest.raises(RuntimeError, match="simulated build failure"): + m.solve(solver_name="highs", reformulate_sos=True) + + assert m._sos_reformulation_state is None + assert list(m.variables.sos) == ["x"] + assert "_sos_reform_x_y" not in m.variables + + # A subsequent real solve must not hit "already applied" + monkeypatch.undo() + m.solve(solver_name="highs", reformulate_sos=True) + + @pytest.mark.skipif("highs" not in available_solvers, reason="HiGHS not installed") class TestSolveWithReformulation: """Tests for solving with SOS reformulation."""