diff --git a/linopy/io.py b/linopy/io.py index 4dc4dc02..657d2c19 100644 --- a/linopy/io.py +++ b/linopy/io.py @@ -5,6 +5,7 @@ from __future__ import annotations +import copy as _copy import json import logging import shutil @@ -845,7 +846,19 @@ def to_netcdf(m: Model, *args: Any, **kwargs: Any) -> None: Arguments passed to ``xarray.Dataset.to_netcdf``. **kwargs : TYPE Keyword arguments passed to ``xarray.Dataset.to_netcdf``. + + Raises + ------ + RuntimeError + If the model has an active SOS reformulation. Call + :meth:`Model.undo_sos_reformulation` before serializing. """ + if m._sos_reformulation_state is not None: + raise RuntimeError( + "Cannot serialize a model with an active SOS reformulation. " + "Call `model.undo_sos_reformulation()` first to restore the " + "original SOS form before saving." + ) def with_prefix(ds: xr.Dataset, prefix: str) -> xr.Dataset: to_rename = set([*ds.dims, *ds.coords, *ds]) @@ -1117,6 +1130,9 @@ def copy(m: Model, include_solution: bool = False, deep: bool = True) -> Model: if include_solution or attr not in SOLVE_STATE_ATTRS: setattr(new_model, attr, getattr(m, attr)) + if m._sos_reformulation_state is not None: + new_model._sos_reformulation_state = _copy.deepcopy(m._sos_reformulation_state) + return new_model diff --git a/linopy/model.py b/linopy/model.py index 03fd9479..03450d62 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -90,6 +90,7 @@ available_solvers, ) from linopy.sos_reformulation import ( + SOSReformulationResult, reformulate_sos_constraints, undo_sos_reformulation, ) @@ -240,6 +241,7 @@ class Model: "_relaxed_registry", "_piecewise_formulations", "_solver", + "_sos_reformulation_state", "__weakref__", ) @@ -310,6 +312,7 @@ def __init__( gettempdir() if solver_dir is None else solver_dir ) self._solver: solvers.Solver | None = None + self._sos_reformulation_state: SOSReformulationResult | None = None @property def solver(self) -> solvers.Solver | None: @@ -1221,6 +1224,44 @@ def remove_sos_constraints(self, variable: Variable) -> None: reformulate_sos_constraints = reformulate_sos_constraints + def apply_sos_reformulation(self) -> None: + """ + Reformulate SOS constraints into binary + linear form, in place. + + The reformulation token is stored on the model so it can be reverted + with :meth:`undo_sos_reformulation`. This is the stateful counterpart + to :func:`linopy.sos_reformulation.reformulate_sos_constraints`, where + the caller owns the token. + + Raises + ------ + RuntimeError + If a reformulation has already been applied and not undone. + """ + if self._sos_reformulation_state is not None: + raise RuntimeError( + "SOS reformulation has already been applied to this model. " + "Call `undo_sos_reformulation()` before applying again." + ) + self._sos_reformulation_state = reformulate_sos_constraints(self) + + def undo_sos_reformulation(self) -> None: + """ + Revert a previously applied SOS reformulation. + + Raises + ------ + RuntimeError + If no reformulation is currently applied. + """ + if self._sos_reformulation_state is None: + raise RuntimeError( + "No SOS reformulation is currently applied to this model." + ) + state = self._sos_reformulation_state + self._sos_reformulation_state = None + undo_sos_reformulation(self, state) + def _check_sos_unmasked(self) -> None: """ Reject the model if any SOS variable has masked entries. @@ -1642,12 +1683,6 @@ def solve( sanitize_zeros=sanitize_zeros, sanitize_infinities=sanitize_infinities ) - if self.objective.expression.empty: - raise ValueError( - "No objective has been set on the model. Use `m.add_objective(...)` " - "first (e.g. `m.add_objective(0 * x)` for a pure feasibility problem)." - ) - # check io_api if io_api is not None and io_api not in IO_APIS: raise ValueError( @@ -1655,6 +1690,16 @@ def solve( ) if remote is not None: + # The remote branch short-circuits before reaching Solver.solve(), + # which is where the empty-objective check normally fires. Replicate + # it here. This duplication becomes obsolete once OETC is folded + # into the Solver pipeline (see PyPSA/linopy#683). + if self.objective.expression.empty: + raise ValueError( + "No objective has been set on the model. Use " + "`m.add_objective(...)` first (e.g. `m.add_objective(0 * x)` " + "for a pure feasibility problem)." + ) if isinstance(remote, OetcHandler): solved = remote.solve_on_oetc( self, solver_name=solver_name, **solver_options @@ -1720,26 +1765,13 @@ def solve( else: solution_fn = self.get_solution_file() - if sanitize_zeros: - self.constraints.sanitize_zeros() - - if sanitize_infinities: - self.constraints.sanitize_infinities() - - if self.is_quadratic and not solver_class.supports( - SolverFeature.QUADRATIC_OBJECTIVE - ): - raise ValueError( - f"Solver {solver_name} does not support quadratic problems." - ) - if reformulate_sos not in (True, False, "auto"): raise ValueError( f"Invalid value for reformulate_sos: {reformulate_sos!r}. " "Must be True, False, or 'auto'." ) - sos_reform_result = None + applied_sos_reformulation_here = False if self.variables.sos: supports_sos = solver_class.supports(SolverFeature.SOS_CONSTRAINTS) should_reformulate = reformulate_sos is True or ( @@ -1748,67 +1780,90 @@ def solve( if should_reformulate: logger.info(f"Reformulating SOS constraints for solver {solver_name}") - sos_reform_result = reformulate_sos_constraints(self) - elif reformulate_sos is False and not supports_sos: - raise ValueError( - f"Solver {solver_name} does not support SOS constraints. " - "Use reformulate_sos=True or 'auto', or a solver that supports SOS." - ) - - if self.variables.semi_continuous: - if not solver_class.supports(SolverFeature.SEMI_CONTINUOUS_VARIABLES): - raise ValueError( - f"Solver {solver_name} does not support semi-continuous variables. " - "Use a solver that supports them (gurobi, cplex, highs)." - ) + self.apply_sos_reformulation() + applied_sos_reformulation_here = True + # If SOS is present and the solver doesn't support it (and the user + # didn't ask for reformulation), Solver._build() will raise. try: - self.solver = None # closes any previous solver - if io_api == "direct": - if set_names is None: - set_names = self.set_names_in_solver_io - build_kwargs: dict[str, Any] = { - "explicit_coordinate_names": explicit_coordinate_names, - "set_names": set_names, - "log_fn": to_path(log_fn), - } - if env is not None: - build_kwargs["env"] = env - else: - build_kwargs = { - "explicit_coordinate_names": explicit_coordinate_names, - "slice_size": slice_size, - "progress": progress, - "problem_fn": to_path(problem_fn), - } - self.solver = solver = solvers.Solver.from_name( - solver_name, - model=self, - io_api=io_api, - options=solver_options, - **build_kwargs, - ) - if io_api != "direct": - problem_fn = solver._problem_fn - result = solver.solve( - solution_fn=to_path(solution_fn), - log_fn=to_path(log_fn), - warmstart_fn=to_path(warmstart_fn), - basis_fn=to_path(basis_fn), - env=env, - ) - finally: - for fn in (problem_fn, solution_fn): - if fn is not None and (os.path.exists(fn) and not keep_files): - os.remove(fn) + if sanitize_zeros: + self.constraints.sanitize_zeros() + if sanitize_infinities: + self.constraints.sanitize_infinities() + + try: + self.solver = None # closes any previous solver + if io_api == "direct": + if set_names is None: + set_names = self.set_names_in_solver_io + build_kwargs: dict[str, Any] = { + "explicit_coordinate_names": explicit_coordinate_names, + "set_names": set_names, + "log_fn": to_path(log_fn), + } + if env is not None: + build_kwargs["env"] = env + else: + build_kwargs = { + "explicit_coordinate_names": explicit_coordinate_names, + "slice_size": slice_size, + "progress": progress, + "problem_fn": to_path(problem_fn), + } + self.solver = solver = solvers.Solver.from_name( + solver_name, + model=self, + io_api=io_api, + options=solver_options, + **build_kwargs, + ) + if io_api != "direct": + problem_fn = solver._problem_fn + result = solver.solve( + solution_fn=to_path(solution_fn), + log_fn=to_path(log_fn), + warmstart_fn=to_path(warmstart_fn), + basis_fn=to_path(basis_fn), + env=env, + ) + finally: + for fn in (problem_fn, solution_fn): + if fn is not None and (os.path.exists(fn) and not keep_files): + os.remove(fn) - try: return self.assign_result(result) finally: - if sos_reform_result is not None: - undo_sos_reformulation(self, sos_reform_result) + if applied_sos_reformulation_here: + self.undo_sos_reformulation() + + def assign_result( + self, + result: Result, + solver: solvers.Solver | None = None, + ) -> tuple[str, str]: + """ + Write a solver Result back onto the model. + + Copies primal / dual values onto variables / constraints, sets + :attr:`status`, :attr:`termination_condition`, and + :attr:`objective.value`. When ``solver`` is provided, also stores it on + ``self.solver`` so post-solve introspection (``model.solver_model``, + ``compute_infeasibilities()``) works. + + Parameters + ---------- + result : Result + The :class:`linopy.constants.Result` returned by + :meth:`linopy.solvers.Solver.solve`. + solver : Solver, optional + The solver instance that produced the result. Pass it on the + low-level ``Solver.from_name(...).solve()`` path to attach it as + ``self.solver`` for post-solve introspection. ``Model.solve()`` + attaches the solver itself and does not pass this argument. + """ + if solver is not None: + self.solver = solver - def assign_result(self, result: Result) -> tuple[str, str]: result.info() if result.solution is not None: diff --git a/linopy/solvers.py b/linopy/solvers.py index 9466db0f..d6cc50e6 100644 --- a/linopy/solvers.py +++ b/linopy/solvers.py @@ -504,15 +504,52 @@ def from_model( return instance def _build(self, **build_kwargs: Any) -> None: - """Dispatch to direct or file build based on ``io_api``.""" + """ + Dispatch to direct or file build based on ``io_api``. + + The Solver never mutates ``self.model``. Constraint sanitization + (``model.constraints.sanitize_zeros()`` / + ``.sanitize_infinities()``) and SOS reformulation + (``model.apply_sos_reformulation()``) are Model-level operations + the caller applies first; this builder consumes whatever shape it + is handed. + """ if self.model is None: raise RuntimeError("Solver has no model attached; cannot build.") + self._validate_model() self.model._check_sos_unmasked() if self.io_api == "direct": self._build_direct(**build_kwargs) else: self._build_file(**build_kwargs) + def _validate_model(self) -> None: + """Pre-build checks on whether this solver can handle ``self.model``.""" + model = self.model + assert model is not None + solver_name = self.solver_name.value + cls = type(self) + + if model.is_quadratic and not cls.supports(SolverFeature.QUADRATIC_OBJECTIVE): + raise ValueError( + f"Solver {solver_name} does not support quadratic problems." + ) + + if model.variables.semi_continuous and not cls.supports( + SolverFeature.SEMI_CONTINUOUS_VARIABLES + ): + raise ValueError( + f"Solver {solver_name} does not support semi-continuous variables. " + "Use a solver that supports them (gurobi, cplex, highs)." + ) + + if model.variables.sos and not cls.supports(SolverFeature.SOS_CONSTRAINTS): + raise ValueError( + f"Solver {solver_name} does not support SOS constraints. " + "Reformulate first via `Model.solve(reformulate_sos=True)` or " + "`model.apply_sos_reformulation()`, or use a solver that supports SOS." + ) + def _build_direct(self, **build_kwargs: Any) -> None: """Build the native solver model from ``self.model``. Override per-solver.""" raise NotImplementedError( @@ -553,7 +590,30 @@ def _build_file(self, **build_kwargs: Any) -> None: self._cache_model_sizes(model) def solve(self, **run_kwargs: Any) -> Result: - """Run the prepared solver and return a :class:`Result`.""" + """ + Run the prepared solver and return a :class:`Result`. + + The canonical low-level pattern is:: + + solver = Solver.from_name("gurobi", model, io_api="direct") + result = solver.solve() + model.assign_result(result, solver=solver) + + Passing ``solver=`` to :meth:`Model.assign_result` wires + ``model.solver`` so post-solve helpers like + :meth:`Model.compute_infeasibilities` keep working. + + Raises + ------ + ValueError + If the attached model has no objective set. Submit-time check + shared by both ``Model.solve()`` and direct-Solver callers. + """ + if self.model is not None and self.model.objective.expression.empty: + raise ValueError( + "No objective has been set on the model. Use `m.add_objective(...)` " + "first (e.g. `m.add_objective(0 * x)` for a pure feasibility problem)." + ) if self.io_api == "direct" or self.solver_model is not None: return self._run_direct(**run_kwargs) if self._problem_fn is not None: diff --git a/linopy/sos_reformulation.py b/linopy/sos_reformulation.py index 8ccb7613..0c677216 100644 --- a/linopy/sos_reformulation.py +++ b/linopy/sos_reformulation.py @@ -233,8 +233,10 @@ def reformulate_sos_constraints( 1. If custom big_m was specified in add_sos_constraints(), use that 2. Otherwise, use the variable bounds (tightest valid Big-M) - Note: This permanently mutates the model. To solve with automatic - undo, use ``model.solve(reformulate_sos=True)`` instead. + Note: This permanently mutates the model and returns a token the caller + owns. For a stateful, reversible API use ``model.apply_sos_reformulation()`` + / ``model.undo_sos_reformulation()``; for automatic undo around a single + solve use ``model.solve(reformulate_sos=True)``. Parameters ---------- diff --git a/test/test_solvers.py b/test/test_solvers.py index db894137..1b6bd9a9 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -23,6 +23,14 @@ from linopy.solvers import _installed_version_in +@pytest.fixture +def lp_only_solver() -> str: + for name in ("glpk", "cbc"): + if name in solvers.available_solvers: + return name + pytest.skip("Need an LP-only solver (glpk or cbc) installed") + + @pytest.fixture def simple_model() -> Model: m = Model(chunk=None) @@ -464,3 +472,95 @@ def test_xpress_gpu_feature_reflects_installed_version() -> None: assert solvers.Xpress.supports( SolverFeature.GPU_ACCELERATION ) == _installed_version_in("xpress", ">=9.8.0") + + +class TestValidateModelOnBuild: + """Solver._build() runs solver-feature checks regardless of entry point.""" + + def test_quadratic_without_qp_support_raises(self, lp_only_solver: str) -> None: + m = Model() + x = m.add_variables(name="x", lower=0, upper=10) + m.add_objective(x * x, sense="min") + + with pytest.raises(ValueError, match="does not support quadratic"): + solvers.Solver.from_name(lp_only_solver, m, io_api="lp") + + def test_semi_continuous_without_support_raises(self, lp_only_solver: str) -> None: + m = Model() + x = m.add_variables(name="x", lower=1, upper=10, semi_continuous=True) + m.add_objective(x) + + with pytest.raises(ValueError, match="does not support semi-continuous"): + solvers.Solver.from_name(lp_only_solver, m, io_api="lp") + + @pytest.mark.skipif( + "highs" not in solvers.available_solvers, reason="HiGHS not installed" + ) + def test_solve_without_objective_raises(self) -> None: + m = Model() + m.add_variables(name="x", lower=0, upper=10) + # No objective added — both entry points should raise the same error. + with pytest.raises(ValueError, match="No objective has been set"): + solvers.Solver.from_name("highs", m, io_api="lp").solve() + with pytest.raises(ValueError, match="No objective has been set"): + m.solve("highs") + + +class TestSolverDoesNotMutateModel: + """Solver.from_model() must not mutate model state (sanitize stays Model-level).""" + + @pytest.mark.skipif( + "highs" not in solvers.available_solvers, reason="HiGHS not installed" + ) + def test_from_model_leaves_constraints_untouched(self) -> None: + m = Model() + x = m.add_variables(name="x", lower=0, upper=10) + # Constraint with a near-zero coefficient — would be sanitized away if + # the Solver path were sanitizing on build. + m.add_constraints(1e-12 * x + x >= 0, name="c") + m.add_objective(x) + + before = m.constraints["c"].coeffs.values.copy() + solvers.Solver.from_name("highs", m, io_api="lp") + after = m.constraints["c"].coeffs.values + + assert np.allclose(before, after, equal_nan=True), ( + "Solver.from_model() must not mutate model constraints. " + "Sanitization is a Model-level primitive; call " + "model.constraints.sanitize_zeros() / .sanitize_infinities() " + "explicitly before building." + ) + + +class TestAssignResultWiring: + """assign_result(result, solver=...) populates model.solver.""" + + @pytest.mark.skipif( + "highs" not in solvers.available_solvers, reason="HiGHS not installed" + ) + def test_assign_result_with_solver_wires_model_solver(self) -> None: + m = Model() + x = m.add_variables(name="x", lower=0, upper=10) + m.add_objective(x, sense="min") + + assert m.solver is None + solver = solvers.Solver.from_name("highs", m, io_api="lp") + result = solver.solve() + m.assign_result(result, solver=solver) + + assert m.solver is solver + assert m.solver_model is solver.solver_model + + @pytest.mark.skipif( + "highs" not in solvers.available_solvers, reason="HiGHS not installed" + ) + def test_assign_result_without_solver_kwarg_leaves_solver_unset(self) -> None: + m = Model() + x = m.add_variables(name="x", lower=0, upper=10) + m.add_objective(x, sense="min") + + solver = solvers.Solver.from_name("highs", m, io_api="lp") + result = solver.solve() + m.assign_result(result) # no solver kwarg + + assert m.solver is None diff --git a/test/test_sos_constraints.py b/test/test_sos_constraints.py index 30b2d767..a9529dc0 100644 --- a/test/test_sos_constraints.py +++ b/test/test_sos_constraints.py @@ -316,7 +316,7 @@ def test_unsupported_solver_raises_error() -> None: m.solve(solver_name=solver) -def test_to_highspy_raises_not_implemented() -> None: +def test_to_highspy_raises_when_sos_present() -> None: pytest.importorskip("highspy") m = Model() @@ -324,8 +324,5 @@ def test_to_highspy_raises_not_implemented() -> None: build = m.add_variables(coords=[locations], name="build", binary=True) m.add_sos_constraints(build, sos_type=1, sos_dim="locations") - with pytest.raises( - NotImplementedError, - match="SOS constraints are not supported by the HiGHS direct API", - ): + with pytest.raises(ValueError, match="does not support SOS constraints"): m.to_highspy() diff --git a/test/test_sos_reformulation.py b/test/test_sos_reformulation.py index 24ba62b3..b244d9b6 100644 --- a/test/test_sos_reformulation.py +++ b/test/test_sos_reformulation.py @@ -3,6 +3,8 @@ from __future__ import annotations import logging +from collections.abc import Callable +from pathlib import Path import numpy as np import pandas as pd @@ -312,6 +314,157 @@ def test_reformulate_inplace(self) -> None: assert "_sos_reform_x_y" in m.variables +class TestApplyUndoSOSReformulation: + """Tests for Model.apply_sos_reformulation / undo_sos_reformulation.""" + + @staticmethod + def _build_sos1_model() -> Model: + m = Model() + idx = pd.Index([0, 1, 2], name="i") + x = m.add_variables(lower=0, upper=1, coords=[idx], name="x") + m.add_sos_constraints(x, sos_type=1, sos_dim="i") + return m + + def test_apply_stashes_state(self) -> None: + m = self._build_sos1_model() + assert m._sos_reformulation_state is None + + m.apply_sos_reformulation() + + assert m._sos_reformulation_state is not None + assert m._sos_reformulation_state.reformulated == ["x"] + assert len(list(m.variables.sos)) == 0 + assert "_sos_reform_x_y" in m.variables + + def test_undo_restores_and_clears_state(self) -> None: + m = self._build_sos1_model() + m.apply_sos_reformulation() + + m.undo_sos_reformulation() + + assert m._sos_reformulation_state is None + assert list(m.variables.sos) == ["x"] + assert "_sos_reform_x_y" not in m.variables + + def test_double_apply_raises(self) -> None: + m = self._build_sos1_model() + m.apply_sos_reformulation() + + with pytest.raises(RuntimeError, match="already been applied"): + m.apply_sos_reformulation() + + def test_undo_without_apply_raises(self) -> None: + m = self._build_sos1_model() + + with pytest.raises(RuntimeError, match="No SOS reformulation"): + m.undo_sos_reformulation() + + @pytest.mark.parametrize( + "copy_fn", + [ + pytest.param(lambda m: m.copy(), id="model.copy()"), + pytest.param(lambda m: __import__("copy").copy(m), id="copy.copy(model)"), + pytest.param( + lambda m: __import__("copy").deepcopy(m), id="copy.deepcopy(model)" + ), + ], + ) + def test_copy_persists_state_and_undo_works_on_copy( + self, copy_fn: Callable[[Model], Model] + ) -> None: + m = self._build_sos1_model() + m.apply_sos_reformulation() + + c = copy_fn(m) + + # State is carried over but is an independent object + assert c._sos_reformulation_state is not None + assert c._sos_reformulation_state is not m._sos_reformulation_state + # Aux vars/cons exist on the copy (they were copied as part of the + # reformulated model state) + assert "_sos_reform_x_y" in c.variables + assert "_sos_reform_x_upper" in c.constraints + assert "_sos_reform_x_card" in c.constraints + # SOS attrs are not on the copy's "x" yet (still in reformulated form) + assert "x" not in list(c.variables.sos) + + # Undo on the copy fully restores the original SOS form + c.undo_sos_reformulation() + assert c._sos_reformulation_state is None + assert list(c.variables.sos) == ["x"] + assert "_sos_reform_x_y" not in c.variables + assert "_sos_reform_x_upper" not in c.constraints + assert "_sos_reform_x_card" not in c.constraints + + # Original is entirely unaffected + assert m._sos_reformulation_state is not None + assert "_sos_reform_x_y" in m.variables + assert len(list(m.variables.sos)) == 0 + + def test_to_netcdf_raises_when_state_active(self, tmp_path: Path) -> None: + m = self._build_sos1_model() + m.apply_sos_reformulation() + + with pytest.raises(RuntimeError, match="active SOS reformulation"): + m.to_netcdf(tmp_path / "m.nc") + + def test_to_netcdf_works_after_undo(self, tmp_path: Path) -> None: + m = self._build_sos1_model() + m.apply_sos_reformulation() + m.undo_sos_reformulation() + + m.to_netcdf(tmp_path / "m.nc") # should not raise + + +@pytest.mark.skipif("highs" not in available_solvers, reason="HiGHS not installed") +class TestSolverPathSOSCheck: + """Solver._build() must raise on SOS-bearing model with non-SOS solver.""" + + def test_solver_from_name_raises_without_reformulation(self) -> None: + from linopy import solvers + + m = Model() + idx = pd.Index([0, 1, 2], name="i") + x = m.add_variables(lower=0, upper=1, coords=[idx], name="x") + m.add_sos_constraints(x, sos_type=1, sos_dim="i") + m.add_objective(x.sum(), sense="max") + + with pytest.raises(ValueError, match="does not support SOS"): + solvers.Solver.from_name("highs", m, io_api="lp") + + +@pytest.mark.skipif("highs" not in available_solvers, reason="HiGHS not installed") +class TestSolveAutoUndoOnFailure: + """Model.solve must auto-undo SOS reformulation when build/solve raises.""" + + def test_state_restored_when_build_raises( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + from linopy import solvers + + m = Model() + idx = pd.Index([0, 1, 2], name="i") + x = m.add_variables(lower=0, upper=1, coords=[idx], name="x") + m.add_sos_constraints(x, sos_type=1, sos_dim="i") + m.add_objective(x.sum(), sense="max") + + def boom(*args: object, **kwargs: object) -> None: + raise RuntimeError("simulated build failure") + + monkeypatch.setattr(solvers.Solver, "from_name", boom) + + with pytest.raises(RuntimeError, match="simulated build failure"): + m.solve(solver_name="highs", reformulate_sos=True) + + assert m._sos_reformulation_state is None + assert list(m.variables.sos) == ["x"] + assert "_sos_reform_x_y" not in m.variables + + # A subsequent real solve must not hit "already applied" + monkeypatch.undo() + m.solve(solver_name="highs", reformulate_sos=True) + + @pytest.mark.skipif("highs" not in available_solvers, reason="HiGHS not installed") class TestSolveWithReformulation: """Tests for solving with SOS reformulation."""