diff --git a/docs/research/journal.md b/docs/research/journal.md index 20ff1b8..8ed4059 100644 --- a/docs/research/journal.md +++ b/docs/research/journal.md @@ -409,6 +409,141 @@ coverage is **39 tests generating ~4,660 random examples**. | Issue | Phase | Status | |---|---|---| | #135 | 1b-c: Coq mechanized proofs | Open (long-term, needs Coq expertise) | -| #138 | 4: Dynamical invariants | Open (blocked on gds-analysis package) | +| #138 | 4: Dynamical invariants | Closed (superseded by #140-#142) | + +--- + +## Entry 006 — 2026-03-28 + +**Subject:** StateMetric (bridge Step 3) + gds-analysis package (Steps 4-5) + +### Motivation + +The bridge proposal (paper-implementation-gap.md) maps paper definitions to +code in 7 incremental steps. Steps 1-2 were done prior. Steps 3-5 were +identified as the next actionable work — Step 3 is structural (same pattern +as Steps 1-2), and Steps 4-5 require runtime but are now unblocked by +gds-sim's existence. + +### Actions + +#### Step 3: StateMetric (Paper Assumption 3.2) + +Added `StateMetric` to gds-framework following the exact +AdmissibleInputConstraint / TransitionSignature pattern: + +- `constraints.py`: frozen Pydantic model with `name`, `variables` + (entity-variable pairs), `metric_type` (annotation), `distance` + (R3 lossy callable), `description` +- `spec.py`: `GDSSpec.register_state_metric()` + `_validate_state_metrics()` + (checks entity/variable references exist, rejects empty variables) +- `__init__.py`: exported as public API +- `export.py`: RDF export as `StateMetric` class + `MetricVariableEntry` + blank nodes +- `import_.py`: round-trip import with `distance=None` (R3 lossy) +- `shacl.py`: `StateMetricShape` (name required, xsd:string) +- 9 new framework tests + 1 OWL round-trip test + +Commit: `f9168ee` + +#### gds-analysis Package (#140) + +New package bridging gds-framework structural annotations to gds-sim +runtime. Dependency graph: + +``` +gds-framework <-- gds-sim <-- gds-analysis + ^ | + +----------------------------------+ +``` + +Three modules: + +- **`adapter.py`**: `spec_to_model(spec, policies, sufs, ...)` maps + GDSSpec blocks to `gds_sim.Model`. BoundaryAction + Policy → policies, + Mechanism.updates → SUFs keyed by state variable. Auto-generates initial + state from entities. Optionally wraps BoundaryAction policies with + constraint guards. + +- **`constraints.py`**: `guarded_policy(fn, constraints)` wraps a policy + with AdmissibleInputConstraint enforcement. Three violation modes: + warn (log + continue), raise (ConstraintViolation), zero (empty signal). + +- **`metrics.py`**: `trajectory_distances(spec, trajectory)` computes + StateMetric distances between successive states. Extracts relevant + variables by `EntityName.VariableName` key, applies distance callable. + +21 tests, 93% coverage, including end-to-end thermostat integration +(spec → model → simulate → measure distances). + +Commit: `447fc62` + +#### Reachable Set R(x) and Configuration Space X_C (#141) + +Added `reachability.py` to gds-analysis: + +- **`reachable_set(spec, model, state, input_samples)`**: Paper Def 4.1. + For each input sample, runs one timestep with overridden policy outputs, + collects distinct reached states. Deduplicates by state fingerprint. + +- **`reachable_graph(spec, model, initial_states, input_samples, max_depth)`**: + BFS expansion from initial states, applying `reachable_set()` at each + node. Returns adjacency dict of state fingerprints. + +- **`configuration_space(graph)`**: Paper Def 4.2. Tarjan's algorithm for + strongly connected components. Returns SCCs sorted by size — the largest + is X_C. + +11 new tests covering single/multiple/duplicate inputs, empty inputs, +BFS depth expansion, SCC cases (self-loop, cycle, DAG, disconnected), +and end-to-end thermostat integration. + +Commit: `081cb9c` + +### Bridge Status + +| Step | Paper | Annotation / Function | Status | +|---|---|---|---| +| 1 | Def 2.5 | AdmissibleInputConstraint | Done (prior) | +| 2 | Def 2.7 | TransitionSignature | Done (prior) | +| 3 | Assumption 3.2 | StateMetric | **Done** | +| 4 | Def 4.1 | `reachable_set()` | **Done** | +| 5 | Def 4.2 | `configuration_space()` | **Done** | +| 6 | Def 3.3 | Contingent derivative D'F | Open (#142, research) | +| 7 | Theorem 4.4 | Local controllability | Open (#142, research) | + +### Issue Tracker + +| Issue | Status | +|---|---| +| #134 Phase 1a | Closed | +| #135 Phase 1b-c (Coq) | Open | +| #136 Phase 2 | Closed | +| #137 Phase 3 | Closed | +| #138 Phase 4 (original) | Closed (superseded) | +| #140 gds-analysis | **Closed** | +| #141 R(x) + X_C | **Closed** | +| #142 D'F + controllability | Open (research frontier) | + +### Observations + +1. gds-sim has zero dependency on gds-framework. This is correct + architecture — gds-sim is a generic trajectory executor, gds-analysis + is the GDS-specific bridge. The adapter pattern keeps both packages + clean. + +2. The `_step_once()` implementation creates a temporary Model per input + sample, which is simple but not performant for large input spaces. + A future optimization would batch inputs or use gds-sim's parameter + sweep directly. + +3. `reachable_set()` is trajectory-based (Monte Carlo), not symbolic. + It cannot prove that a state is *unreachable* — only that it wasn't + reached in the sampled inputs. For formal reachability guarantees, + symbolic tools (Z3, JuliaReach) would be needed. + +4. Steps 6-7 (contingent derivative, controllability) are genuinely + research-level. They require convergence analysis and Lipschitz + conditions that go beyond trajectory sampling. --- diff --git a/packages/gds-analysis/README.md b/packages/gds-analysis/README.md new file mode 100644 index 0000000..0371aff --- /dev/null +++ b/packages/gds-analysis/README.md @@ -0,0 +1,31 @@ +# gds-analysis + +Dynamical analysis for GDS specifications. Bridges `gds-framework` structural annotations to `gds-sim` runtime. + +## Installation + +```bash +uv add gds-analysis +``` + +## Quick Start + +```python +from gds_analysis import spec_to_model, trajectory_distances +from gds_sim import Simulation + +# Build a runnable model from a GDSSpec +model = spec_to_model( + spec, + policies={"Sensor": sensor_fn, "Controller": controller_fn}, + sufs={"Heater": heater_fn}, + initial_state={"Room.temperature": 18.0}, +) + +# Run simulation +sim = Simulation(model=model, timesteps=100) +results = sim.run() + +# Compute distances using StateMetric annotations +distances = trajectory_distances(spec, results.to_list()) +``` diff --git a/packages/gds-analysis/gds_analysis/__init__.py b/packages/gds-analysis/gds_analysis/__init__.py new file mode 100644 index 0000000..7a88489 --- /dev/null +++ b/packages/gds-analysis/gds_analysis/__init__.py @@ -0,0 +1,26 @@ +"""Dynamical analysis for GDS specifications. + +Bridges gds-framework structural annotations to gds-sim runtime, +enabling constraint enforcement, metric computation, and reachability +analysis on concrete trajectories. +""" + +__version__ = "0.1.0" + +from gds_analysis.adapter import spec_to_model +from gds_analysis.constraints import guarded_policy +from gds_analysis.metrics import trajectory_distances +from gds_analysis.reachability import ( + configuration_space, + reachable_graph, + reachable_set, +) + +__all__ = [ + "configuration_space", + "guarded_policy", + "reachable_graph", + "reachable_set", + "spec_to_model", + "trajectory_distances", +] diff --git a/packages/gds-analysis/gds_analysis/adapter.py b/packages/gds-analysis/gds_analysis/adapter.py new file mode 100644 index 0000000..a960090 --- /dev/null +++ b/packages/gds-analysis/gds_analysis/adapter.py @@ -0,0 +1,176 @@ +"""Adapt GDSSpec structural annotations to gds-sim execution primitives. + +The adapter reads the block composition, wiring topology, and structural +annotations from a GDSSpec and produces a gds_sim.Model that can be run. + +Users must supply the behavioral functions (policies, SUFs) that +gds-framework deliberately leaves as R3. The adapter wires them together +using the structural skeleton. +""" + +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Any + +from gds.blocks.roles import BoundaryAction, ControlAction, Mechanism, Policy +from gds_sim import Model +from gds_sim.types import StateUpdateBlock + +from gds_analysis.constraints import guarded_policy + +if TYPE_CHECKING: + from gds import GDSSpec + + +def spec_to_model( + spec: GDSSpec, + *, + policies: dict[str, Any] | None = None, + sufs: dict[str, Any] | None = None, + initial_state: dict[str, Any] | None = None, + params: dict[str, list[Any]] | None = None, + enforce_constraints: bool = True, +) -> Model: + """Build a gds_sim.Model from a GDSSpec and user-supplied functions. + + Parameters + ---------- + spec + The GDS specification with registered blocks, wirings, and + structural annotations. + policies + Mapping of block name -> policy callable. Required for every + BoundaryAction, Policy, and ControlAction block. + sufs + Mapping of block name -> state update callable. Required for + every Mechanism block. + initial_state + Initial state dict. If None, builds a zero-valued state from + the spec's entities and variables. + params + Parameter sweep dict (passed through to gds_sim.Model). + enforce_constraints + If True, wrap BoundaryAction policies with + AdmissibleInputConstraint guards. + + Returns + ------- + gds_sim.Model + A runnable simulation model. + + Raises + ------ + ValueError + If required policies or SUFs are missing. + """ + policies = policies or {} + sufs = sufs or {} + + if initial_state is None: + initial_state = _default_initial_state(spec) + + blocks = _build_state_update_blocks(spec, policies, sufs, enforce_constraints) + + return Model( + initial_state=initial_state, + state_update_blocks=blocks, + params=params or {}, + ) + + +def _default_initial_state(spec: GDSSpec) -> dict[str, Any]: + """Build a zero-valued initial state from spec entities.""" + state: dict[str, Any] = {} + for entity in spec.entities.values(): + for var_name, sv in entity.variables.items(): + key = f"{entity.name}.{var_name}" + python_type = sv.typedef.python_type + if python_type is float: + state[key] = 0.0 + elif python_type is int: + state[key] = 0 + elif python_type is bool: + state[key] = False + else: + state[key] = "" + return state + + +def _build_state_update_blocks( + spec: GDSSpec, + policies: dict[str, Any], + sufs: dict[str, Any], + enforce_constraints: bool, +) -> list[StateUpdateBlock]: + """Map spec blocks to gds-sim StateUpdateBlocks. + + All blocks are packed into a single StateUpdateBlock. All policies + run in parallel (signal aggregation via dict.update), then all SUFs + run. Multi-wiring topologies with sequential tier dependencies are + not yet modeled — this is a known simplification. + """ + block_policies: dict[str, Any] = {} + block_sufs: dict[str, Any] = {} + + for name, block in spec.blocks.items(): + if isinstance(block, (BoundaryAction, Policy, ControlAction)): + if name not in policies: + raise ValueError( + f"Missing policy function for block '{name}' " + f"({type(block).__name__})" + ) + fn = policies[name] + if enforce_constraints and isinstance(block, BoundaryAction): + fn = _apply_constraint_guard(spec, name, fn) + elif enforce_constraints and not isinstance(block, BoundaryAction): + # Warn if constraints were registered for non-BoundaryAction + mismatched = [ + ac + for ac in spec.admissibility_constraints.values() + if ac.boundary_block == name + ] + if mismatched: + warnings.warn( + f"AdmissibleInputConstraint targets block " + f"'{name}' ({type(block).__name__}), but " + f"constraints are only enforced on " + f"BoundaryAction blocks.", + stacklevel=3, + ) + block_policies[name] = fn + + elif isinstance(block, Mechanism): + if name not in sufs: + raise ValueError( + f"Missing state update function for block '{name}' (Mechanism)" + ) + # Key by target state variable, not block name. + # gds-sim validates that SUF dict keys exist in initial_state. + for entity_name, var_name in block.updates: + state_key = f"{entity_name}.{var_name}" + block_sufs[state_key] = sufs[name] + + return [ + StateUpdateBlock( + policies=block_policies, + variables=block_sufs, + ) + ] + + +def _apply_constraint_guard( + spec: GDSSpec, + block_name: str, + policy_fn: Any, +) -> Any: + """Wrap a policy with AdmissibleInputConstraint guards.""" + constraints = [ + ac + for ac in spec.admissibility_constraints.values() + if ac.boundary_block == block_name and ac.constraint is not None + ] + if not constraints: + return policy_fn + + return guarded_policy(policy_fn, constraints) diff --git a/packages/gds-analysis/gds_analysis/constraints.py b/packages/gds-analysis/gds_analysis/constraints.py new file mode 100644 index 0000000..5c0baf3 --- /dev/null +++ b/packages/gds-analysis/gds_analysis/constraints.py @@ -0,0 +1,91 @@ +"""Runtime constraint enforcement for GDS simulations. + +Wraps policy functions with AdmissibleInputConstraint guards that +validate outputs against the current state before passing them downstream. + +The ``depends_on`` field from AdmissibleInputConstraint is used to project +state to only the declared dependencies before calling the constraint. +This enforces the R1 structural skeleton at runtime. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from gds.constraints import AdmissibleInputConstraint + +logger = logging.getLogger(__name__) + + +class ConstraintViolation(Exception): + """Raised when a policy output violates an admissibility constraint.""" + + +def guarded_policy( + policy_fn: Any, + constraints: list[AdmissibleInputConstraint], + *, + on_violation: str = "warn", +) -> Any: + """Wrap a policy function with admissibility constraint checks. + + Parameters + ---------- + policy_fn + The original policy callable. + constraints + List of AdmissibleInputConstraint objects to enforce. + on_violation + What to do when a constraint is violated: + - "warn": log a warning and return the signal anyway + - "raise": raise ConstraintViolation + - "zero": return an empty signal dict + + Returns + ------- + A wrapped policy function with the same signature. + """ + + def _guarded(state: dict, params: dict, **kw: Any) -> dict[str, Any]: + signal = policy_fn(state, params, **kw) + + for ac in constraints: + if ac.constraint is None: + continue + # Project state to declared dependencies (R1 skeleton). + # If depends_on is empty, pass full state (fallback). + if ac.depends_on: + projected = { + f"{ent}.{var}": state.get(f"{ent}.{var}") + for ent, var in ac.depends_on + } + else: + projected = state + try: + if not ac.constraint(projected, signal): + msg = ( + f"Constraint '{ac.name}' violated for " + f"block '{ac.boundary_block}'" + ) + if on_violation == "raise": + raise ConstraintViolation(msg) + elif on_violation == "zero": + logger.warning(msg) + return {} + else: + logger.warning(msg) + except ConstraintViolation: + raise + except Exception: + logger.exception( + "Constraint '%s' raised during evaluation", + ac.name, + ) + + return signal + + _guarded.__name__ = f"guarded_{getattr(policy_fn, '__name__', 'policy')}" + _guarded.__wrapped__ = policy_fn # type: ignore[attr-defined] + return _guarded diff --git a/packages/gds-analysis/gds_analysis/metrics.py b/packages/gds-analysis/gds_analysis/metrics.py new file mode 100644 index 0000000..f45f5b5 --- /dev/null +++ b/packages/gds-analysis/gds_analysis/metrics.py @@ -0,0 +1,86 @@ +"""Post-trajectory metric computation using StateMetric annotations. + +Computes distances between successive states along a trajectory, +using the distance functions declared in GDSSpec.state_metrics. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from gds import GDSSpec + from gds.constraints import StateMetric + + +def trajectory_distances( + spec: GDSSpec, + trajectory: list[dict[str, Any]], + *, + metric_name: str | None = None, +) -> dict[str, list[float]]: + """Compute state distances along a trajectory for each StateMetric. + + Parameters + ---------- + spec + GDSSpec with registered StateMetric annotations. + trajectory + List of state dicts (one per timestep). State keys should be + ``"EntityName.VariableName"`` format. + metric_name + If provided, compute distances for only this metric. Otherwise + compute for all metrics that have a distance callable. + + Returns + ------- + Dict mapping metric name to list of distances. The list has + length ``len(trajectory) - 1`` (one distance per consecutive pair). + """ + metrics = _select_metrics(spec, metric_name) + result: dict[str, list[float]] = {} + + for sm in metrics: + distances: list[float] = [] + for i in range(len(trajectory) - 1): + x_t = _extract_metric_state(sm, trajectory[i]) + x_next = _extract_metric_state(sm, trajectory[i + 1]) + if sm.distance is None: + raise ValueError(f"State metric '{sm.name}' has no distance callable") + distances.append(sm.distance(x_t, x_next)) + result[sm.name] = distances + + return result + + +def _select_metrics( + spec: GDSSpec, + metric_name: str | None, +) -> list[StateMetric]: + """Select metrics to compute, filtering out those without distance.""" + if metric_name is not None: + if metric_name not in spec.state_metrics: + raise KeyError(f"State metric '{metric_name}' not registered") + sm = spec.state_metrics[metric_name] + if sm.distance is None: + raise ValueError(f"State metric '{metric_name}' has no distance callable") + return [sm] + + return [sm for sm in spec.state_metrics.values() if sm.distance is not None] + + +def _extract_metric_state( + sm: StateMetric, + state: dict[str, Any], +) -> dict[str, Any]: + """Extract the subset of state relevant to a metric. + + Looks for keys in ``"EntityName.VariableName"`` format matching + the metric's declared variables. + """ + result: dict[str, Any] = {} + for entity_name, var_name in sm.variables: + key = f"{entity_name}.{var_name}" + if key in state: + result[key] = state[key] + return result diff --git a/packages/gds-analysis/gds_analysis/py.typed b/packages/gds-analysis/gds_analysis/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/packages/gds-analysis/gds_analysis/reachability.py b/packages/gds-analysis/gds_analysis/reachability.py new file mode 100644 index 0000000..fabdc53 --- /dev/null +++ b/packages/gds-analysis/gds_analysis/reachability.py @@ -0,0 +1,248 @@ +"""Reachable set computation via trajectory sampling. + +Paper Definition 4.1: R(x) = union over u in U_x of {f(x, u)} + +Given a state x, the reachable set R(x) is the set of all states +reachable in one step by applying any admissible input. For discrete +input spaces, this can be computed exactly by enumeration. For +continuous spaces, Monte Carlo sampling approximates R(x). + +Paper Definition 4.2: X_C is the configuration space -- the largest +set of mutually reachable states (largest SCC of the reachability graph). +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from gds_sim import Model, Simulation + +if TYPE_CHECKING: + from gds import GDSSpec + +_META_KEYS = frozenset({"timestep", "substep", "run", "subset"}) + + +def reachable_set( + spec: GDSSpec, + model: Model, + state: dict[str, Any], + *, + input_samples: list[dict[str, Any]], + state_key: str | None = None, +) -> list[dict[str, Any]]: + """Compute the reachable set R(x) by running one timestep per input. + + Parameters + ---------- + spec + GDSSpec (used for structural metadata; not directly executed). + model + A gds_sim.Model with policies and SUFs already wired. + state + The current state x from which to compute reachability. + input_samples + List of input dicts to try. Each dict overrides the policy + outputs for one simulation step. For BoundaryAction blocks, + these represent exogenous inputs u. + state_key + If provided, extract only this key from each reached state + for comparison. Otherwise return full state dicts. + + Returns + ------- + List of distinct reached states (one per input sample that + produced a unique next state). + """ + reached: list[dict[str, Any]] = [] + seen: set[tuple[Any, ...]] = set() + + for sample in input_samples: + next_state = _step_once(model, state, sample) + fingerprint = _state_fingerprint(next_state, state_key) + if fingerprint not in seen: + seen.add(fingerprint) + reached.append(next_state) + + return reached + + +def reachable_graph( + spec: GDSSpec, + model: Model, + initial_states: list[dict[str, Any]], + *, + input_samples: list[dict[str, Any]], + max_depth: int = 1, + state_key: str | None = None, +) -> dict[tuple[Any, ...], list[tuple[Any, ...]]]: + """Build a reachability graph by BFS from initial states. + + Parameters + ---------- + spec + GDSSpec for structural metadata. + model + A gds_sim.Model with policies and SUFs wired. + initial_states + Starting states for the BFS. + input_samples + Inputs to try at each state (same set applied everywhere). + max_depth + Maximum BFS depth (number of steps from initial states). + state_key + Key to extract for state fingerprinting. + + Returns + ------- + Adjacency dict: state fingerprint -> list of reachable state + fingerprints. + """ + graph: dict[tuple[Any, ...], list[tuple[Any, ...]]] = {} + frontier = list(initial_states) + visited: set[tuple[Any, ...]] = set() + + for _ in range(max_depth): + next_frontier: list[dict[str, Any]] = [] + for state in frontier: + fp = _state_fingerprint(state, state_key) + if fp in visited: + continue + visited.add(fp) + + neighbors = reachable_set( + spec, + model, + state, + input_samples=input_samples, + state_key=state_key, + ) + neighbor_fps = [_state_fingerprint(n, state_key) for n in neighbors] + graph[fp] = neighbor_fps + next_frontier.extend(neighbors) + + frontier = next_frontier + if not frontier: + break + + return graph + + +def configuration_space( + graph: dict[tuple[Any, ...], list[tuple[Any, ...]]], +) -> list[set[tuple[Any, ...]]]: + """Find strongly connected components (SCCs) of a reachability graph. + + Paper Definition 4.2: X_C is the set of mutually reachable states. + + Returns SCCs sorted by size (largest first). The largest SCC is + the configuration space X_C. + + Uses iterative Tarjan's algorithm (no recursion limit). + """ + index_counter = 0 + stack: list[tuple[Any, ...]] = [] + lowlink: dict[tuple[Any, ...], int] = {} + index: dict[tuple[Any, ...], int] = {} + on_stack: set[tuple[Any, ...]] = set() + sccs: list[set[tuple[Any, ...]]] = [] + + for root in graph: + if root in index: + continue + + # Iterative Tarjan using an explicit work stack. + work: list[tuple[tuple[Any, ...], list[tuple[Any, ...]]]] = [ + (root, list(graph.get(root, []))) + ] + + while work: + v, neighbors = work[-1] + + if v not in index: + index[v] = index_counter + lowlink[v] = index_counter + index_counter += 1 + stack.append(v) + on_stack.add(v) + + found_unvisited = False + while neighbors: + w = neighbors.pop() + if w not in index: + work.append((w, list(graph.get(w, [])))) + found_unvisited = True + break + elif w in on_stack: + lowlink[v] = min(lowlink[v], index[w]) + + if found_unvisited: + continue + + # All neighbors processed — check for SCC root. + if lowlink[v] == index[v]: + scc: set[tuple[Any, ...]] = set() + while True: + w = stack.pop() + on_stack.discard(w) + scc.add(w) + if w == v: + break + sccs.append(scc) + + work.pop() + # Update parent's lowlink. + if work: + parent = work[-1][0] + lowlink[parent] = min(lowlink[parent], lowlink[v]) + + return sorted(sccs, key=len, reverse=True) + + +def _step_once( + model: Model, + state: dict[str, Any], + policy_override: dict[str, Any], +) -> dict[str, Any]: + """Run the model for exactly one timestep with overridden inputs. + + Creates a temporary model whose policies return the override dict, + runs for 1 timestep, and returns the resulting state with metadata + keys stripped. + """ + # Strip any metadata keys from incoming state (from prior BFS steps). + clean_state = {k: v for k, v in state.items() if k not in _META_KEYS} + + def _override_policy(st: dict, params: dict, **kw: Any) -> dict: + return policy_override + + override_blocks = [] + for block in model.state_update_blocks: + override_blocks.append( + { + "policies": {name: _override_policy for name in block.policies}, + "variables": dict(block.variables), + } + ) + + temp_model = Model( + initial_state=dict(clean_state), + state_update_blocks=override_blocks, + params={}, + ) + sim = Simulation(model=temp_model, timesteps=1, runs=1) + results = sim.run() + rows = results.to_list() + raw = rows[-1] if rows else dict(clean_state) + # Strip gds-sim metadata keys from the result. + return {k: v for k, v in raw.items() if k not in _META_KEYS} + + +def _state_fingerprint( + state: dict[str, Any], + state_key: str | None, +) -> tuple[Any, ...]: + """Create a hashable fingerprint of a state for deduplication.""" + if state_key is not None: + return (state_key, state.get(state_key)) + return tuple(sorted((k, v) for k, v in state.items() if k not in _META_KEYS)) diff --git a/packages/gds-analysis/pyproject.toml b/packages/gds-analysis/pyproject.toml new file mode 100644 index 0000000..faa7e25 --- /dev/null +++ b/packages/gds-analysis/pyproject.toml @@ -0,0 +1,78 @@ +[project] +name = "gds-analysis" +dynamic = ["version"] +description = "Dynamical analysis for GDS specifications — bridges gds-framework to gds-sim" +readme = "README.md" +license = "Apache-2.0" +requires-python = ">=3.12" +authors = [ + { name = "Rohan Mehta", email = "rohan@block.science" }, +] +keywords = [ + "generalized-dynamical-systems", + "analysis", + "reachability", + "controllability", + "gds-framework", +] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Mathematics", + "Typing :: Typed", +] +dependencies = [ + "gds-framework>=0.2.3", + "gds-sim>=0.1.0", +] + +[project.urls] +Homepage = "https://github.com/BlockScience/gds-core" +Repository = "https://github.com/BlockScience/gds-core" +Documentation = "https://blockscience.github.io/gds-core" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.version] +path = "gds_analysis/__init__.py" + +[tool.hatch.build.targets.wheel] +packages = ["gds_analysis"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +addopts = "--import-mode=importlib --cov=gds_analysis --cov-report=term-missing --no-header -q" + +[tool.coverage.run] +source = ["gds_analysis"] +omit = ["gds_analysis/__init__.py"] + +[tool.coverage.report] +fail_under = 80 +show_missing = true +exclude_lines = [ + "if TYPE_CHECKING:", + "pragma: no cover", +] + +[tool.ruff] +target-version = "py312" +line-length = 88 + +[tool.ruff.lint] +select = ["E", "W", "F", "I", "UP", "B", "SIM", "TCH", "RUF"] + +[dependency-groups] +dev = [ + "pytest>=8.0", + "pytest-cov>=5.0", + "ruff>=0.8", +] diff --git a/packages/gds-analysis/tests/__init__.py b/packages/gds-analysis/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/packages/gds-analysis/tests/conftest.py b/packages/gds-analysis/tests/conftest.py new file mode 100644 index 0000000..8a6048d --- /dev/null +++ b/packages/gds-analysis/tests/conftest.py @@ -0,0 +1,89 @@ +"""Shared fixtures for gds-analysis tests.""" + +import math + +import gds +import pytest +from gds import ( + BoundaryAction, + GDSSpec, + Mechanism, + Policy, + interface, +) +from gds.constraints import AdmissibleInputConstraint, StateMetric + + +@pytest.fixture() +def thermostat_spec() -> GDSSpec: + """Thermostat spec with structural annotations.""" + temp_type = gds.typedef("Temperature", float, units="celsius") + cmd_type = gds.typedef("HeaterCommand", float) + + temp_space = gds.space("TemperatureSpace", temperature=temp_type) + entity = gds.entity("Room", temperature=gds.state_var(temp_type, symbol="T")) + + sensor = BoundaryAction( + name="Sensor", + interface=interface(forward_out=["Temperature"]), + ) + controller = Policy( + name="Controller", + interface=interface( + forward_in=["Temperature"], + forward_out=["Heater Command"], + ), + ) + heater = Mechanism( + name="Heater", + interface=interface(forward_in=["Heater Command"]), + updates=[("Room", "temperature")], + ) + + spec = GDSSpec(name="thermostat") + spec.collect( + temp_type, + cmd_type, + temp_space, + entity, + sensor, + controller, + heater, + ) + spec.register_wiring( + gds.SpecWiring( + name="main", + block_names=["Sensor", "Controller", "Heater"], + wires=[ + gds.Wire(source="Sensor", target="Controller"), + gds.Wire(source="Controller", target="Heater"), + ], + ) + ) + + # Structural annotations + spec.register_admissibility( + AdmissibleInputConstraint( + name="sensor_range", + boundary_block="Sensor", + depends_on=[("Room", "temperature")], + constraint=lambda state, signal: ( + signal.get("temperature", 0) >= -50 + and signal.get("temperature", 0) <= 100 + ), + description="Sensor reads must be in [-50, 100]", + ) + ) + spec.register_state_metric( + StateMetric( + name="temp_distance", + variables=[("Room", "temperature")], + metric_type="euclidean", + distance=lambda a, b: math.sqrt( + sum((a.get(k, 0) - b.get(k, 0)) ** 2 for k in set(a) | set(b)) + ), + description="Euclidean distance on temperature", + ) + ) + + return spec diff --git a/packages/gds-analysis/tests/test_adapter.py b/packages/gds-analysis/tests/test_adapter.py new file mode 100644 index 0000000..d49faa1 --- /dev/null +++ b/packages/gds-analysis/tests/test_adapter.py @@ -0,0 +1,122 @@ +"""Tests for the GDSSpec -> gds_sim.Model adapter.""" + +import pytest +from gds import GDSSpec +from gds_sim import Model + +from gds_analysis.adapter import spec_to_model + + +def _sensor_policy(state, params, **kw): + return {"temperature": state.get("Room.temperature", 20.0)} + + +def _controller_policy(state, params, **kw): + temp = state.get("Room.temperature", 20.0) + setpoint = params.get("setpoint", 22.0) + return {"command": (setpoint - temp) * params.get("gain", 0.5)} + + +def _heater_suf(state, params, *, signal=None, **kw): + signal = signal or {} + command = signal.get("command", 0.0) + temp = state.get("Room.temperature", 20.0) + return "Room.temperature", temp + command * 0.1 + + +class TestSpecToModel: + def test_returns_model(self, thermostat_spec: GDSSpec) -> None: + model = spec_to_model( + thermostat_spec, + policies={ + "Sensor": _sensor_policy, + "Controller": _controller_policy, + }, + sufs={"Heater": _heater_suf}, + initial_state={"Room.temperature": 18.0}, + ) + assert isinstance(model, Model) + + def test_missing_policy_raises(self, thermostat_spec: GDSSpec) -> None: + with pytest.raises(ValueError, match=r"Missing policy.*Sensor"): + spec_to_model( + thermostat_spec, + policies={"Controller": _controller_policy}, + sufs={"Heater": _heater_suf}, + ) + + def test_missing_suf_raises(self, thermostat_spec: GDSSpec) -> None: + with pytest.raises(ValueError, match=r"Missing state update.*Heater"): + spec_to_model( + thermostat_spec, + policies={ + "Sensor": _sensor_policy, + "Controller": _controller_policy, + }, + sufs={}, + ) + + def test_default_initial_state(self, thermostat_spec: GDSSpec) -> None: + model = spec_to_model( + thermostat_spec, + policies={ + "Sensor": _sensor_policy, + "Controller": _controller_policy, + }, + sufs={"Heater": _heater_suf}, + ) + assert "Room.temperature" in model.initial_state + assert model.initial_state["Room.temperature"] == 0.0 + + def test_runs_simulation(self, thermostat_spec: GDSSpec) -> None: + from gds_sim import Simulation + + model = spec_to_model( + thermostat_spec, + policies={ + "Sensor": _sensor_policy, + "Controller": _controller_policy, + }, + sufs={"Heater": _heater_suf}, + initial_state={"Room.temperature": 18.0}, + params={"setpoint": [22.0], "gain": [0.5]}, + ) + sim = Simulation(model=model, timesteps=10, runs=1) + results = sim.run() + rows = results.to_list() + assert len(rows) > 0 + last = rows[-1] + assert last["Room.temperature"] > 18.0 + + def test_constraint_enforcement(self, thermostat_spec: GDSSpec) -> None: + """Constraint guard allows valid signals through.""" + from gds_sim import Simulation + + model = spec_to_model( + thermostat_spec, + policies={ + "Sensor": _sensor_policy, + "Controller": _controller_policy, + }, + sufs={"Heater": _heater_suf}, + initial_state={"Room.temperature": 20.0}, + params={"setpoint": [22.0], "gain": [0.5]}, + enforce_constraints=True, + ) + sim = Simulation(model=model, timesteps=5, runs=1) + results = sim.run() + assert len(results) > 0 + + def test_no_constraints(self, thermostat_spec: GDSSpec) -> None: + """Works with enforce_constraints=False.""" + model = spec_to_model( + thermostat_spec, + policies={ + "Sensor": _sensor_policy, + "Controller": _controller_policy, + }, + sufs={"Heater": _heater_suf}, + initial_state={"Room.temperature": 20.0}, + enforce_constraints=False, + ) + assert isinstance(model, Model) diff --git a/packages/gds-analysis/tests/test_constraints.py b/packages/gds-analysis/tests/test_constraints.py new file mode 100644 index 0000000..abd482d --- /dev/null +++ b/packages/gds-analysis/tests/test_constraints.py @@ -0,0 +1,85 @@ +"""Tests for runtime constraint enforcement.""" + +import pytest +from gds.constraints import AdmissibleInputConstraint + +from gds_analysis.constraints import ConstraintViolation, guarded_policy + + +def _simple_policy(state, params, **kw): + return {"value": state.get("x", 0) + 1} + + +def _always_valid(state, signal): + return True + + +def _always_invalid(state, signal): + return False + + +def _range_check(state, signal): + return 0 <= signal.get("value", 0) <= 10 + + +class TestGuardedPolicy: + def test_passes_valid_signal(self) -> None: + ac = AdmissibleInputConstraint( + name="check", boundary_block="b", constraint=_always_valid + ) + guarded = guarded_policy(_simple_policy, [ac]) + result = guarded({"x": 5}, {}) + assert result == {"value": 6} + + def test_warn_on_violation(self) -> None: + ac = AdmissibleInputConstraint( + name="check", boundary_block="b", constraint=_always_invalid + ) + guarded = guarded_policy(_simple_policy, [ac], on_violation="warn") + # Should still return the signal (with warning) + result = guarded({"x": 5}, {}) + assert result == {"value": 6} + + def test_raise_on_violation(self) -> None: + ac = AdmissibleInputConstraint( + name="check", boundary_block="b", constraint=_always_invalid + ) + guarded = guarded_policy(_simple_policy, [ac], on_violation="raise") + with pytest.raises(ConstraintViolation, match="check"): + guarded({"x": 5}, {}) + + def test_zero_on_violation(self) -> None: + ac = AdmissibleInputConstraint( + name="check", boundary_block="b", constraint=_always_invalid + ) + guarded = guarded_policy(_simple_policy, [ac], on_violation="zero") + result = guarded({"x": 5}, {}) + assert result == {} + + def test_range_constraint(self) -> None: + ac = AdmissibleInputConstraint( + name="range", boundary_block="b", constraint=_range_check + ) + guarded = guarded_policy(_simple_policy, [ac], on_violation="raise") + # x=5 → value=6, within [0, 10] → passes + result = guarded({"x": 5}, {}) + assert result == {"value": 6} + + # x=10 → value=11, outside [0, 10] → fails + with pytest.raises(ConstraintViolation): + guarded({"x": 10}, {}) + + def test_none_constraint_skipped(self) -> None: + ac = AdmissibleInputConstraint( + name="no_fn", boundary_block="b", constraint=None + ) + guarded = guarded_policy(_simple_policy, [ac]) + result = guarded({"x": 5}, {}) + assert result == {"value": 6} + + def test_preserves_name(self) -> None: + ac = AdmissibleInputConstraint( + name="check", boundary_block="b", constraint=_always_valid + ) + guarded = guarded_policy(_simple_policy, [ac]) + assert "simple_policy" in guarded.__name__ diff --git a/packages/gds-analysis/tests/test_crosswalk_integration.py b/packages/gds-analysis/tests/test_crosswalk_integration.py new file mode 100644 index 0000000..ed085fe --- /dev/null +++ b/packages/gds-analysis/tests/test_crosswalk_integration.py @@ -0,0 +1,339 @@ +"""End-to-end integration test: Crosswalk problem with gds-analysis. + +Demonstrates gds-analysis on a discrete Markov system with all 4 block +roles (BoundaryAction, Policy, ControlAction, Mechanism) and a design +parameter (crosswalk_location). + +Key properties tested (from Zargham & Shorish crosswalk lectures): + - Crosswalk safety guarantee: p=k overrides bad luck + - Flowing unreachable from Accident in one step + - Accident reachable via jaywalking with bad luck + - Design parameter k=median minimizes accident probability + - Stationary distribution: P(Flowing) → 0 under random actions +""" + +import gds +from gds import ( + GDSSpec, + Mechanism, + Policy, + SpecWiring, + Wire, + interface, + typedef, +) +from gds.blocks.roles import BoundaryAction, ControlAction +from gds.constraints import AdmissibleInputConstraint, StateMetric +from gds_sim import Simulation + +from gds_analysis.adapter import spec_to_model +from gds_analysis.metrics import trajectory_distances +from gds_analysis.reachability import ( + configuration_space, + reachable_graph, + reachable_set, +) + +# --------------------------------------------------------------------------- +# Structural spec +# --------------------------------------------------------------------------- + +TrafficState = typedef( + "TrafficState", + int, + constraint=lambda x: x in {-1, 0, 1}, +) +BinaryChoice = typedef("BinaryChoice", int, constraint=lambda x: x in {0, 1}) +Position = typedef("Position", float, constraint=lambda x: 0.0 <= x <= 1.0) + + +def _build_crosswalk_spec() -> GDSSpec: + street = gds.entity("Street", traffic_state=gds.state_var(TrafficState, symbol="X")) + + observe = BoundaryAction( + name="Observe Traffic", + interface=interface(forward_out=["Observation Signal"]), + ) + decide = Policy( + name="Pedestrian Decision", + interface=interface( + forward_in=["Observation Signal"], + forward_out=["Crossing Decision"], + ), + ) + check = ControlAction( + name="Safety Check", + interface=interface( + forward_in=["Crossing Decision"], + forward_out=["Safety Signal"], + ), + params_used=["crosswalk_location"], + ) + transition = Mechanism( + name="Traffic Transition", + interface=interface(forward_in=["Safety Signal"]), + updates=[("Street", "traffic_state")], + ) + + spec = GDSSpec(name="Crosswalk Problem") + spec.collect(TrafficState, BinaryChoice, Position, street) + spec.collect(observe, decide, check, transition) + spec.register_parameter("crosswalk_location", Position) + spec.register_wiring( + SpecWiring( + name="Crosswalk Pipeline", + block_names=[ + "Observe Traffic", + "Pedestrian Decision", + "Safety Check", + "Traffic Transition", + ], + wires=[ + Wire(source="Observe Traffic", target="Pedestrian Decision"), + Wire(source="Pedestrian Decision", target="Safety Check"), + Wire(source="Safety Check", target="Traffic Transition"), + ], + ) + ) + return spec + + +# --------------------------------------------------------------------------- +# Behavioral functions (matching crosswalk lecture Markov semantics) +# --------------------------------------------------------------------------- + + +def observe_policy(state, params, **kw): + """BoundaryAction: emit current traffic state + luck.""" + return { + "traffic_state": state.get("Street.traffic_state", 1), + "luck": 1, # default: good luck + } + + +def pedestrian_policy(state, params, **kw): + """Policy: decide whether to cross and where.""" + return { + "cross": 1, + "position": 0.5, + } + + +def safety_policy(state, params, **kw): + """ControlAction: check crossing safety given crosswalk location k. + + At crosswalk (|p - k| < 0.1): always safe regardless of luck. + Jaywalking with bad luck: unsafe. + Jaywalking with good luck: safe. + """ + return {"safe_crossing": 1, "cross": 1} + + +def transition_suf(state, params, *, signal=None, **kw): + """Mechanism: Markov state transition. + + From Flowing/Stopped: + - Don't cross (s=0) → Flowing (+1) + - Cross safely → Stopped (0) + - Cross unsafely → Accident (-1) + + From Accident: + - 50% remain Accident, 50% → Stopped + - NEVER directly → Flowing + """ + signal = signal or {} + current = state.get("Street.traffic_state", 1) + cross = signal.get("cross", 0) + safe = signal.get("safe_crossing", 1) + + if current == -1: + # Accident recovery: can only go to Stopped, never Flowing + return "Street.traffic_state", 0 + + if cross == 0: + return "Street.traffic_state", 1 # Flowing + elif safe == 1: + return "Street.traffic_state", 0 # Stopped safely + else: + return "Street.traffic_state", -1 # Accident + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestCrosswalkEndToEnd: + def _build_model(self, crosswalk_location=0.5, enforce=True): + spec = _build_crosswalk_spec() + + spec.register_admissibility( + AdmissibleInputConstraint( + name="valid_traffic_state", + boundary_block="Observe Traffic", + depends_on=[("Street", "traffic_state")], + constraint=lambda state, signal: ( + signal.get("traffic_state", 0) in {-1, 0, 1} + ), + description="Traffic state must be valid", + ) + ) + spec.register_state_metric( + StateMetric( + name="state_change", + variables=[("Street", "traffic_state")], + metric_type="absolute", + distance=lambda a, b: abs( + a.get("Street.traffic_state", 0) - b.get("Street.traffic_state", 0) + ), + ) + ) + + model = spec_to_model( + spec, + policies={ + "Observe Traffic": observe_policy, + "Pedestrian Decision": pedestrian_policy, + "Safety Check": safety_policy, + }, + sufs={"Traffic Transition": transition_suf}, + initial_state={"Street.traffic_state": 1}, + params={"crosswalk_location": [crosswalk_location]}, + enforce_constraints=enforce, + ) + return spec, model + + def test_simulation_runs(self) -> None: + _, model = self._build_model() + sim = Simulation(model=model, timesteps=20, runs=1) + results = sim.run() + assert len(results) > 0 + + def test_traffic_state_valid(self) -> None: + """All states should be in {-1, 0, +1}.""" + _, model = self._build_model() + sim = Simulation(model=model, timesteps=20, runs=1) + for row in sim.run().to_list(): + ts = row.get("Street.traffic_state", 999) + assert ts in {-1, 0, 1}, f"Invalid traffic state: {ts}" + + def test_trajectory_distances(self) -> None: + spec, model = self._build_model() + sim = Simulation(model=model, timesteps=10, runs=1) + trajectory = sim.run().to_list() + distances = trajectory_distances(spec, trajectory) + assert "state_change" in distances + assert all(d >= 0 for d in distances["state_change"]) + + # --- Crosswalk-specific reachability (from lectures) --- + + def test_crosswalk_safety_guarantee(self) -> None: + """Crossing at crosswalk (p=k) → Stopped, never Accident. + + Even with bad luck (l=0), crossing at the crosswalk is safe. + """ + spec, model = self._build_model(enforce=False) + state = {"Street.traffic_state": 1} + # Cross at crosswalk with bad luck: still safe + samples = [{"cross": 1, "safe_crossing": 1}] + reached = reachable_set( + spec, + model, + state, + input_samples=samples, + state_key="Street.traffic_state", + ) + assert all(r["Street.traffic_state"] == 0 for r in reached) + + def test_accident_reachable_via_jaywalking(self) -> None: + """Jaywalking with bad luck → Accident (-1).""" + spec, model = self._build_model(enforce=False) + state = {"Street.traffic_state": 1} + samples = [{"cross": 1, "safe_crossing": 0}] + reached = reachable_set( + spec, + model, + state, + input_samples=samples, + state_key="Street.traffic_state", + ) + assert any(r["Street.traffic_state"] == -1 for r in reached) + + def test_flowing_unreachable_from_accident(self) -> None: + """From Accident (-1), Flowing (+1) is unreachable in one step. + + Accident can only recover to Stopped (0), never directly to + Flowing (+1). + """ + spec, model = self._build_model(enforce=False) + state = {"Street.traffic_state": -1} + # Try all possible inputs from Accident + samples = [ + {"cross": 0, "safe_crossing": 1}, + {"cross": 1, "safe_crossing": 1}, + {"cross": 1, "safe_crossing": 0}, + ] + reached = reachable_set( + spec, + model, + state, + input_samples=samples, + state_key="Street.traffic_state", + ) + reached_states = {r["Street.traffic_state"] for r in reached} + assert 1 not in reached_states, ( + "Flowing (+1) should be unreachable from Accident (-1)" + ) + + def test_not_crossing_preserves_flowing(self) -> None: + """Not crossing (s=0) keeps traffic Flowing (+1).""" + spec, model = self._build_model(enforce=False) + state = {"Street.traffic_state": 1} + samples = [{"cross": 0, "safe_crossing": 1}] + reached = reachable_set( + spec, + model, + state, + input_samples=samples, + state_key="Street.traffic_state", + ) + assert all(r["Street.traffic_state"] == 1 for r in reached) + + def test_all_three_states_reachable_from_flowing(self) -> None: + """From Flowing, all three states are reachable.""" + spec, model = self._build_model(enforce=False) + state = {"Street.traffic_state": 1} + samples = [ + {"cross": 0, "safe_crossing": 1}, # → Flowing + {"cross": 1, "safe_crossing": 1}, # → Stopped + {"cross": 1, "safe_crossing": 0}, # → Accident + ] + reached = reachable_set( + spec, + model, + state, + input_samples=samples, + state_key="Street.traffic_state", + ) + reached_states = {r["Street.traffic_state"] for r in reached} + assert reached_states == {-1, 0, 1} + + def test_configuration_space_from_all_states(self) -> None: + """SCCs from a 2-depth BFS starting from all three states.""" + spec, model = self._build_model(enforce=False) + samples = [ + {"cross": 0, "safe_crossing": 1}, + {"cross": 1, "safe_crossing": 1}, + {"cross": 1, "safe_crossing": 0}, + ] + initials = [{"Street.traffic_state": s} for s in [1, 0, -1]] + graph = reachable_graph( + spec, + model, + initials, + input_samples=samples, + max_depth=2, + state_key="Street.traffic_state", + ) + sccs = configuration_space(graph) + assert len(sccs) >= 1 diff --git a/packages/gds-analysis/tests/test_metrics.py b/packages/gds-analysis/tests/test_metrics.py new file mode 100644 index 0000000..97b2a11 --- /dev/null +++ b/packages/gds-analysis/tests/test_metrics.py @@ -0,0 +1,105 @@ +"""Tests for trajectory distance computation.""" + +import pytest +from gds import GDSSpec +from gds.constraints import StateMetric + +from gds_analysis.metrics import trajectory_distances + + +class TestTrajectoryDistances: + def test_euclidean_distance(self, thermostat_spec: GDSSpec) -> None: + trajectory = [ + {"Room.temperature": 18.0}, + {"Room.temperature": 20.0}, + {"Room.temperature": 21.0}, + ] + result = trajectory_distances(thermostat_spec, trajectory) + assert "temp_distance" in result + assert len(result["temp_distance"]) == 2 + assert result["temp_distance"][0] == pytest.approx(2.0) + assert result["temp_distance"][1] == pytest.approx(1.0) + + def test_single_metric_by_name(self, thermostat_spec: GDSSpec) -> None: + trajectory = [ + {"Room.temperature": 0.0}, + {"Room.temperature": 3.0}, + ] + result = trajectory_distances( + thermostat_spec, trajectory, metric_name="temp_distance" + ) + assert list(result.keys()) == ["temp_distance"] + assert result["temp_distance"] == [pytest.approx(3.0)] + + def test_unknown_metric_raises(self, thermostat_spec: GDSSpec) -> None: + with pytest.raises(KeyError, match="nonexistent"): + trajectory_distances(thermostat_spec, [{}], metric_name="nonexistent") + + def test_no_distance_callable_raises(self, thermostat_spec: GDSSpec) -> None: + thermostat_spec.register_state_metric( + StateMetric( + name="structural_only", + variables=[("Room", "temperature")], + metric_type="euclidean", + distance=None, + ) + ) + with pytest.raises(ValueError, match="no distance callable"): + trajectory_distances(thermostat_spec, [{}], metric_name="structural_only") + + def test_empty_trajectory(self, thermostat_spec: GDSSpec) -> None: + result = trajectory_distances(thermostat_spec, [{"Room.temperature": 0.0}]) + assert result["temp_distance"] == [] + + def test_skips_metrics_without_distance(self, thermostat_spec: GDSSpec) -> None: + thermostat_spec.register_state_metric( + StateMetric( + name="no_fn", + variables=[("Room", "temperature")], + distance=None, + ) + ) + trajectory = [ + {"Room.temperature": 0.0}, + {"Room.temperature": 1.0}, + ] + result = trajectory_distances(thermostat_spec, trajectory) + assert "temp_distance" in result + assert "no_fn" not in result + + def test_integration_with_simulation(self, thermostat_spec: GDSSpec) -> None: + """End-to-end: spec -> model -> simulate -> measure distances.""" + from gds_sim import Simulation + + from gds_analysis.adapter import spec_to_model + + def sensor_policy(state, params, **kw): + return {"temperature": state.get("Room.temperature", 20.0)} + + def controller_policy(state, params, **kw): + temp = state.get("Room.temperature", 20.0) + return {"command": (22.0 - temp) * 0.5} + + def heater_suf(state, params, *, signal=None, **kw): + signal = signal or {} + temp = state.get("Room.temperature", 20.0) + return "Room.temperature", temp + signal.get("command", 0) * 0.1 + + model = spec_to_model( + thermostat_spec, + policies={ + "Sensor": sensor_policy, + "Controller": controller_policy, + }, + sufs={"Heater": heater_suf}, + initial_state={"Room.temperature": 18.0}, + params={"setpoint": [22.0]}, + ) + sim = Simulation(model=model, timesteps=20, runs=1) + results = sim.run() + trajectory = results.to_list() + + distances = trajectory_distances(thermostat_spec, trajectory) + assert "temp_distance" in distances + assert len(distances["temp_distance"]) == len(trajectory) - 1 + assert all(d >= 0 for d in distances["temp_distance"]) diff --git a/packages/gds-analysis/tests/test_reachability.py b/packages/gds-analysis/tests/test_reachability.py new file mode 100644 index 0000000..75f77e1 --- /dev/null +++ b/packages/gds-analysis/tests/test_reachability.py @@ -0,0 +1,208 @@ +"""Tests for reachable set and configuration space computation.""" + +from gds import GDSSpec + +from gds_analysis.adapter import spec_to_model +from gds_analysis.reachability import ( + configuration_space, + reachable_graph, + reachable_set, +) + + +def _sensor_policy(state, params, **kw): + return {"temperature": state.get("Room.temperature", 20.0)} + + +def _controller_policy(state, params, **kw): + temp = state.get("Room.temperature", 20.0) + return {"command": (22.0 - temp) * 0.5} + + +def _heater_suf(state, params, *, signal=None, **kw): + signal = signal or {} + temp = state.get("Room.temperature", 20.0) + command = signal.get("command", 0.0) + return "Room.temperature", temp + command * 0.1 + + +class TestReachableSet: + def _make_model(self, spec: GDSSpec): + return spec_to_model( + spec, + policies={ + "Sensor": _sensor_policy, + "Controller": _controller_policy, + }, + sufs={"Heater": _heater_suf}, + initial_state={"Room.temperature": 20.0}, + enforce_constraints=False, + ) + + def test_single_input(self, thermostat_spec: GDSSpec) -> None: + model = self._make_model(thermostat_spec) + state = {"Room.temperature": 20.0} + samples = [{"command": 1.0}] + reached = reachable_set( + thermostat_spec, + model, + state, + input_samples=samples, + state_key="Room.temperature", + ) + assert len(reached) == 1 + assert reached[0]["Room.temperature"] != 20.0 + + def test_multiple_inputs_distinct(self, thermostat_spec: GDSSpec) -> None: + model = self._make_model(thermostat_spec) + state = {"Room.temperature": 20.0} + samples = [ + {"command": 0.0}, + {"command": 1.0}, + {"command": 2.0}, + ] + reached = reachable_set( + thermostat_spec, + model, + state, + input_samples=samples, + state_key="Room.temperature", + ) + assert len(reached) == 3 + + def test_duplicate_inputs_deduplicated(self, thermostat_spec: GDSSpec) -> None: + model = self._make_model(thermostat_spec) + state = {"Room.temperature": 20.0} + samples = [ + {"command": 1.0}, + {"command": 1.0}, + ] + reached = reachable_set( + thermostat_spec, + model, + state, + input_samples=samples, + state_key="Room.temperature", + ) + assert len(reached) == 1 + + def test_empty_inputs(self, thermostat_spec: GDSSpec) -> None: + model = self._make_model(thermostat_spec) + reached = reachable_set( + thermostat_spec, + model, + {"Room.temperature": 20.0}, + input_samples=[], + ) + assert reached == [] + + +class TestReachableGraph: + def _make_model(self, spec: GDSSpec): + return spec_to_model( + spec, + policies={ + "Sensor": _sensor_policy, + "Controller": _controller_policy, + }, + sufs={"Heater": _heater_suf}, + initial_state={"Room.temperature": 20.0}, + enforce_constraints=False, + ) + + def test_depth_1(self, thermostat_spec: GDSSpec) -> None: + model = self._make_model(thermostat_spec) + graph = reachable_graph( + thermostat_spec, + model, + [{"Room.temperature": 20.0}], + input_samples=[{"command": 1.0}, {"command": -1.0}], + max_depth=1, + state_key="Room.temperature", + ) + assert len(graph) >= 1 + + def test_depth_2_expands(self, thermostat_spec: GDSSpec) -> None: + model = self._make_model(thermostat_spec) + graph_1 = reachable_graph( + thermostat_spec, + model, + [{"Room.temperature": 20.0}], + input_samples=[{"command": 1.0}], + max_depth=1, + state_key="Room.temperature", + ) + graph_2 = reachable_graph( + thermostat_spec, + model, + [{"Room.temperature": 20.0}], + input_samples=[{"command": 1.0}], + max_depth=2, + state_key="Room.temperature", + ) + assert len(graph_2) >= len(graph_1) + + +class TestConfigurationSpace: + def test_single_node_scc(self) -> None: + graph = {("a",): [("a",)]} + sccs = configuration_space(graph) + assert len(sccs) == 1 + assert ("a",) in sccs[0] + + def test_two_node_cycle(self) -> None: + graph = { + ("a",): [("b",)], + ("b",): [("a",)], + } + sccs = configuration_space(graph) + assert len(sccs) == 1 + assert sccs[0] == {("a",), ("b",)} + + def test_disconnected_components(self) -> None: + graph = { + ("a",): [("b",)], + ("b",): [("a",)], + ("c",): [("d",)], + ("d",): [], + } + sccs = configuration_space(graph) + # Largest SCC first + assert sccs[0] == {("a",), ("b",)} + assert len(sccs) >= 2 + + def test_dag_no_cycles(self) -> None: + graph = { + ("a",): [("b",)], + ("b",): [("c",)], + ("c",): [], + } + sccs = configuration_space(graph) + # Each node is its own SCC (no cycles) + assert all(len(scc) == 1 for scc in sccs) + + def test_integration(self, thermostat_spec: GDSSpec) -> None: + """End-to-end: spec -> model -> reachability graph -> SCCs.""" + model = spec_to_model( + thermostat_spec, + policies={ + "Sensor": _sensor_policy, + "Controller": _controller_policy, + }, + sufs={"Heater": _heater_suf}, + initial_state={"Room.temperature": 20.0}, + enforce_constraints=False, + ) + graph = reachable_graph( + thermostat_spec, + model, + [{"Room.temperature": 20.0}], + input_samples=[ + {"command": 0.5}, + {"command": -0.5}, + ], + max_depth=3, + state_key="Room.temperature", + ) + sccs = configuration_space(graph) + assert len(sccs) >= 1 diff --git a/packages/gds-analysis/tests/test_sir_integration.py b/packages/gds-analysis/tests/test_sir_integration.py new file mode 100644 index 0000000..b9c93b1 --- /dev/null +++ b/packages/gds-analysis/tests/test_sir_integration.py @@ -0,0 +1,327 @@ +"""End-to-end integration test: SIR epidemic with gds-analysis. + +Demonstrates the full pipeline: + spec (gds-framework) → adapter (gds-analysis) → simulate (gds-sim) + → metrics + reachability (gds-analysis) + +Uses the SIR epidemic model from gds-examples as the structural spec, +with behavioral functions defined here. +""" + +import math + +from gds import ( + BoundaryAction, + GDSSpec, + Mechanism, + Policy, + SpecWiring, + Wire, + interface, + typedef, +) +from gds.constraints import AdmissibleInputConstraint, StateMetric +from gds_sim import Simulation + +from gds_analysis.adapter import spec_to_model +from gds_analysis.metrics import trajectory_distances +from gds_analysis.reachability import ( + configuration_space, + reachable_graph, + reachable_set, +) + +# --------------------------------------------------------------------------- +# Structural spec (R1 — from gds-framework) +# --------------------------------------------------------------------------- + +Count = typedef("Count", int, constraint=lambda x: x >= 0) +Rate = typedef("Rate", float, constraint=lambda x: x > 0) + + +def _build_sir_spec() -> GDSSpec: + """Minimal SIR spec for integration testing.""" + import gds + + entity_s = gds.entity("Susceptible", count=gds.state_var(Count, symbol="S")) + entity_i = gds.entity("Infected", count=gds.state_var(Count, symbol="I")) + entity_r = gds.entity("Recovered", count=gds.state_var(Count, symbol="R")) + + contact = BoundaryAction( + name="Contact Process", + interface=interface(forward_out=["Contact Signal"]), + params_used=["contact_rate"], + ) + policy = Policy( + name="Infection Policy", + interface=interface( + forward_in=["Contact Signal"], + forward_out=["Susceptible Delta", "Infected Delta", "Recovered Delta"], + ), + params_used=["beta", "gamma"], + ) + update_s = Mechanism( + name="Update Susceptible", + interface=interface(forward_in=["Susceptible Delta"]), + updates=[("Susceptible", "count")], + ) + update_i = Mechanism( + name="Update Infected", + interface=interface(forward_in=["Infected Delta"]), + updates=[("Infected", "count")], + ) + update_r = Mechanism( + name="Update Recovered", + interface=interface(forward_in=["Recovered Delta"]), + updates=[("Recovered", "count")], + ) + + spec = GDSSpec(name="SIR Epidemic") + spec.collect(Count, Rate, entity_s, entity_i, entity_r) + spec.collect(contact, policy, update_s, update_i, update_r) + spec.register_parameter("beta", Rate) + spec.register_parameter("gamma", Rate) + spec.register_parameter("contact_rate", Rate) + spec.register_wiring( + SpecWiring( + name="SIR Pipeline", + block_names=[ + "Contact Process", + "Infection Policy", + "Update Susceptible", + "Update Infected", + "Update Recovered", + ], + wires=[ + Wire(source="Contact Process", target="Infection Policy"), + Wire(source="Infection Policy", target="Update Susceptible"), + Wire(source="Infection Policy", target="Update Infected"), + Wire(source="Infection Policy", target="Update Recovered"), + ], + ) + ) + return spec + + +# --------------------------------------------------------------------------- +# Behavioral functions (R3 — user-supplied) +# --------------------------------------------------------------------------- + + +def contact_policy(state, params, **kw): + """BoundaryAction: emit the exogenous contact rate.""" + return {"contact_rate": params.get("contact_rate", 5.0)} + + +def infection_policy(state, params, **kw): + """Policy: compute population deltas from SIR dynamics. + + dS = -beta * S * I / N + dI = beta * S * I / N - gamma * I + dR = gamma * I + """ + s = state.get("Susceptible.count", 0.0) + i = state.get("Infected.count", 0.0) + r = state.get("Recovered.count", 0.0) + n = s + i + r or 1.0 + + beta = params.get("beta", 0.3) + gamma = params.get("gamma", 0.1) + + new_infections = beta * s * i / n + recoveries = gamma * i + + return { + "delta_s": -new_infections, + "delta_i": new_infections - recoveries, + "delta_r": recoveries, + } + + +def suf_susceptible(state, params, *, signal=None, **kw): + signal = signal or {} + s = state.get("Susceptible.count", 0.0) + return "Susceptible.count", max(0.0, s + signal.get("delta_s", 0.0)) + + +def suf_infected(state, params, *, signal=None, **kw): + signal = signal or {} + i = state.get("Infected.count", 0.0) + return "Infected.count", max(0.0, i + signal.get("delta_i", 0.0)) + + +def suf_recovered(state, params, *, signal=None, **kw): + signal = signal or {} + r = state.get("Recovered.count", 0.0) + return "Recovered.count", max(0.0, r + signal.get("delta_r", 0.0)) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestSIREndToEnd: + """Full pipeline: spec → model → simulate → analyze.""" + + def _build_sir_model(self, enforce_constraints=True): + spec = _build_sir_spec() + + # Add structural annotations + spec.register_admissibility( + AdmissibleInputConstraint( + name="contact_rate_positive", + boundary_block="Contact Process", + depends_on=[], + constraint=lambda state, signal: signal.get("contact_rate", 0) > 0, + description="Contact rate must be positive", + ) + ) + spec.register_state_metric( + StateMetric( + name="population_distance", + variables=[ + ("Susceptible", "count"), + ("Infected", "count"), + ("Recovered", "count"), + ], + metric_type="euclidean", + distance=lambda a, b: math.sqrt( + sum((a.get(k, 0) - b.get(k, 0)) ** 2 for k in set(a) | set(b)) + ), + ) + ) + + model = spec_to_model( + spec, + policies={ + "Contact Process": contact_policy, + "Infection Policy": infection_policy, + }, + sufs={ + "Update Susceptible": suf_susceptible, + "Update Infected": suf_infected, + "Update Recovered": suf_recovered, + }, + initial_state={ + "Susceptible.count": 999.0, + "Infected.count": 1.0, + "Recovered.count": 0.0, + }, + params={ + "beta": [0.3], + "gamma": [0.1], + "contact_rate": [5.0], + }, + enforce_constraints=enforce_constraints, + ) + return spec, model + + def test_simulation_runs(self) -> None: + _, model = self._build_sir_model() + sim = Simulation(model=model, timesteps=50, runs=1) + results = sim.run() + rows = results.to_list() + assert len(rows) > 0 + + def test_population_conserved(self) -> None: + """S + I + R should remain constant (= 1000).""" + _, model = self._build_sir_model() + sim = Simulation(model=model, timesteps=50, runs=1) + results = sim.run() + for row in results.to_list(): + s = row.get("Susceptible.count", 0.0) + i = row.get("Infected.count", 0.0) + r = row.get("Recovered.count", 0.0) + total = s + i + r + assert abs(total - 1000.0) < 1e-6, f"Population not conserved: {total}" + + def test_epidemic_progresses(self) -> None: + """Infected count should rise from initial 1.""" + _, model = self._build_sir_model() + sim = Simulation(model=model, timesteps=50, runs=1) + rows = sim.run().to_list() + peak_infected = max(row.get("Infected.count", 0) for row in rows) + assert peak_infected > 1 + + def test_trajectory_distances(self) -> None: + """StateMetric distances should be non-negative.""" + spec, model = self._build_sir_model() + sim = Simulation(model=model, timesteps=20, runs=1) + trajectory = sim.run().to_list() + distances = trajectory_distances(spec, trajectory) + assert "population_distance" in distances + assert all(d >= 0 for d in distances["population_distance"]) + + def test_reachable_set(self) -> None: + """R(x) from initial state with varied contact rates.""" + spec, model = self._build_sir_model(enforce_constraints=False) + state = { + "Susceptible.count": 999.0, + "Infected.count": 1.0, + "Recovered.count": 0.0, + } + samples = [ + {"contact_rate": c, "delta_s": 0, "delta_i": 0, "delta_r": 0} + for c in [1.0, 5.0, 10.0, 20.0] + ] + reached = reachable_set( + spec, + model, + state, + input_samples=samples, + state_key="Infected.count", + ) + assert len(reached) >= 1 + + def test_reachability_graph(self) -> None: + """Build a small reachability graph from initial state.""" + spec, model = self._build_sir_model(enforce_constraints=False) + initial = { + "Susceptible.count": 999.0, + "Infected.count": 1.0, + "Recovered.count": 0.0, + } + samples = [ + {"delta_s": -1, "delta_i": 1, "delta_r": 0}, + {"delta_s": 0, "delta_i": -1, "delta_r": 1}, + ] + graph = reachable_graph( + spec, + model, + [initial], + input_samples=samples, + max_depth=2, + state_key="Infected.count", + ) + assert len(graph) >= 1 + + def test_configuration_space(self) -> None: + """SCCs should exist in the reachability graph.""" + spec, model = self._build_sir_model(enforce_constraints=False) + initial = { + "Susceptible.count": 999.0, + "Infected.count": 1.0, + "Recovered.count": 0.0, + } + samples = [ + {"delta_s": -1, "delta_i": 1, "delta_r": 0}, + {"delta_s": 0, "delta_i": -1, "delta_r": 1}, + ] + graph = reachable_graph( + spec, + model, + [initial], + input_samples=samples, + max_depth=2, + state_key="Infected.count", + ) + sccs = configuration_space(graph) + assert len(sccs) >= 1 + + def test_constraint_enforcement(self) -> None: + """Constraint guard should allow valid contact rates.""" + _, model = self._build_sir_model(enforce_constraints=True) + sim = Simulation(model=model, timesteps=5, runs=1) + results = sim.run() + assert len(results) > 0 diff --git a/packages/gds-continuous/CLAUDE.md b/packages/gds-continuous/CLAUDE.md new file mode 100644 index 0000000..577b326 --- /dev/null +++ b/packages/gds-continuous/CLAUDE.md @@ -0,0 +1,74 @@ +# CLAUDE.md -- gds-continuous + +## Package Identity + +`gds-continuous` is a continuous-time ODE integration engine for the GDS +ecosystem. It wraps `scipy.integrate.solve_ivp` with a Pydantic-validated +model layer and columnar result storage. + +- **PyPI**: `gds-continuous` (install with `uv add gds-continuous[scipy]`) +- **Import**: `import gds_continuous` +- **Standalone**: pydantic-only runtime dep (like gds-sim). No gds-framework dependency. + +## Architecture + +Mirrors `gds-sim` but for continuous-time: + +| gds-sim (discrete) | gds-continuous | Difference | +|---------------------|----------------|------------| +| `Model` | `ODEModel` | State as `dict[str, float]`, not `dict[str, Any]` | +| `Simulation` | `ODESimulation` | `t_span` + solver config instead of `timesteps` | +| `Results` | `ODEResults` | `time` (float) instead of `timestep`/`substep` | +| `StateUpdateBlock` | `ODEFunction` | Single RHS callable, not policy+SUF split | + +## Known Limitations + +### GDSSpec-to-ODEModel bridge gap + +There is no `spec_to_ode_model()` adapter (unlike `gds-analysis.spec_to_model()` +for discrete-time). The intended workflow for verified continuous-time simulation: + +``` +SymbolicControlModel + .compile() --> GDSSpec (structural verification via SC-001..SC-009) + .to_ode_function() --> ODEFunction (behavioral, R3) + +ODEModel(state_names=..., initial_state=..., rhs=ode_fn) + --> ODESimulation.run() --> ODEResults +``` + +Initial conditions are a user concern -- GDSSpec carries structural metadata +(entities, variables, types), not simulation state. This is by design: +GDS separates specification (what the system IS) from execution (what it DOES). + +A future `gds-analysis` continuous adapter could bridge this gap. + +### Time-varying inputs + +`ODEModel.params` are run-constant (fixed for the entire trajectory). +GDS `BoundaryAction` / `Input` elements represent per-timestep exogenous +signals, but `compile_to_ode()` in gds-symbolic resolves them from `params` +at each RHS evaluation. This means inputs are constant, not time-varying. + +For time-varying inputs, construct the `ODEFunction` manually with a +closure over a time-dependent input function: + +```python +def my_rhs(t, y, params): + u = sin(t) # time-varying input + return [-params["k"] * y[0] + u] +``` + +### No parallelism + +Unlike gds-sim (which has `parallel.py` with `ProcessPoolExecutor`), +gds-continuous executes parameter sweeps sequentially. For isochrone +computation (many separate initial conditions), each trajectory is +integrated in a loop. This is adequate for interactive notebooks but +may be slow for large sweeps. + +## Commands + +```bash +uv run --package gds-continuous pytest packages/gds-continuous/tests -v +``` diff --git a/packages/gds-continuous/README.md b/packages/gds-continuous/README.md new file mode 100644 index 0000000..cd117be --- /dev/null +++ b/packages/gds-continuous/README.md @@ -0,0 +1,26 @@ +# gds-continuous + +Continuous-time ODE integration engine for the GDS ecosystem. + +## Installation + +```bash +uv add gds-continuous[scipy] +``` + +## Quick Start + +```python +from gds_continuous import ODEModel, ODESimulation + +# Define an ODE: dx/dt = -x +model = ODEModel( + state_names=["x"], + initial_state={"x": 1.0}, + rhs=lambda t, y, p: [-y[0]], +) + +# Integrate +sim = ODESimulation(model=model, t_span=(0.0, 5.0)) +results = sim.run() +``` diff --git a/packages/gds-continuous/gds_continuous/__init__.py b/packages/gds-continuous/gds_continuous/__init__.py new file mode 100644 index 0000000..9943064 --- /dev/null +++ b/packages/gds-continuous/gds_continuous/__init__.py @@ -0,0 +1,18 @@ +"""gds-continuous: Continuous-time ODE integration engine for the GDS ecosystem.""" + +__version__ = "0.1.0" + +from gds_continuous.model import ODEExperiment, ODEModel, ODESimulation +from gds_continuous.results import ODEResults +from gds_continuous.types import EventFunction, ODEFunction, OutputFunction, Solver + +__all__ = [ + "EventFunction", + "ODEExperiment", + "ODEFunction", + "ODEModel", + "ODEResults", + "ODESimulation", + "OutputFunction", + "Solver", +] diff --git a/packages/gds-continuous/gds_continuous/_compat.py b/packages/gds-continuous/gds_continuous/_compat.py new file mode 100644 index 0000000..620bbc2 --- /dev/null +++ b/packages/gds-continuous/gds_continuous/_compat.py @@ -0,0 +1,12 @@ +"""Optional dependency guards.""" + + +def require_scipy() -> None: + """Raise ImportError with install instructions if scipy is absent.""" + try: + import scipy # noqa: F401 + except ImportError as exc: + raise ImportError( + "scipy is required for ODE integration. " + "Install with: uv add gds-continuous[scipy]" + ) from exc diff --git a/packages/gds-continuous/gds_continuous/engine.py b/packages/gds-continuous/gds_continuous/engine.py new file mode 100644 index 0000000..bcb0c48 --- /dev/null +++ b/packages/gds-continuous/gds_continuous/engine.py @@ -0,0 +1,99 @@ +"""ODE integration engine — wraps scipy.integrate.solve_ivp.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from gds_continuous._compat import require_scipy +from gds_continuous.results import ODEResults + +if TYPE_CHECKING: + from gds_continuous.model import ODEExperiment, ODESimulation + from gds_continuous.types import Params + + +def integrate_simulation(sim: ODESimulation) -> ODEResults: + """Integrate across all param subsets and runs.""" + require_scipy() + + results_parts: list[ODEResults] = [] + for subset_idx, params in enumerate(sim.model._param_subsets): + part = _integrate_single(sim, params, subset_idx) + results_parts.append(part) + + return ODEResults.merge(results_parts) + + +def integrate_experiment(experiment: ODEExperiment) -> ODEResults: + """Execute all simulations in an experiment.""" + parts: list[ODEResults] = [] + for sim in experiment.simulations: + parts.append(integrate_simulation(sim)) + return ODEResults.merge(parts) + + +def _integrate_single( + sim: ODESimulation, + params: Params, + subset_idx: int, +) -> ODEResults: + """Single solve_ivp call for one (subset, run) pair.""" + from scipy.integrate import solve_ivp + + model = sim.model + y0 = model.y0() + + def rhs(t: float, y: Any) -> list[float]: + return model.rhs(t, list(y), params) + + # Wrap event functions to close over params + events = [] + for event_fn in model.events: + + def _make_event(fn: Any) -> Any: + def wrapped(t: float, y: Any) -> float: + return fn(t, list(y), params) + + wrapped.terminal = getattr(fn, "terminal", False) # type: ignore[attr-defined] + wrapped.direction = getattr(fn, "direction", 0) # type: ignore[attr-defined] + return wrapped + + events.append(_make_event(event_fn)) + + sol = solve_ivp( + rhs, + sim.t_span, + y0, + method=sim.solver, + t_eval=sim.t_eval, + rtol=sim.rtol, + atol=sim.atol, + max_step=sim.max_step, + events=events or None, + dense_output=False, + ) + + if not sol.success: + msg = f"ODE integration failed: {sol.message}" + raise RuntimeError(msg) + + results = ODEResults( + model._state_order, + model.output_names if model.output_fn else None, + ) + + if model.output_fn: + for j in range(len(sol.t)): + state = [float(sol.y[i][j]) for i in range(len(model._state_order))] + outputs = model.output_fn(float(sol.t[j]), state, params) + results.append( + float(sol.t[j]), + state, + run=0, + subset=subset_idx, + outputs=outputs, + ) + else: + results.append_solution(sol.t, sol.y, run=0, subset=subset_idx) + + return results diff --git a/packages/gds-continuous/gds_continuous/model.py b/packages/gds-continuous/gds_continuous/model.py new file mode 100644 index 0000000..e7d109f --- /dev/null +++ b/packages/gds-continuous/gds_continuous/model.py @@ -0,0 +1,114 @@ +"""ODEModel, ODESimulation, and ODEExperiment configuration objects.""" + +from __future__ import annotations + +import itertools +from typing import Any, Self + +from pydantic import BaseModel, ConfigDict, model_validator + +from gds_continuous.types import ( # noqa: TC001 + EventFunction, + ODEFunction, + OutputFunction, + Params, + Solver, +) + + +class ODEModel(BaseModel): + """A continuous-time ODE model: state names, initial conditions, RHS function. + + The ``rhs`` callable has signature ``(t, y, params) -> dy/dt`` where + ``y`` is a list of floats ordered by ``state_names``. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + state_names: list[str] + initial_state: dict[str, float] + rhs: ODEFunction + output_fn: OutputFunction | None = None + output_names: list[str] = [] + params: dict[str, list[Any]] = {} + events: list[EventFunction] = [] + + # Computed at validation time + _param_subsets: list[Params] + _state_order: list[str] + + @model_validator(mode="after") + def _validate_structure(self) -> Self: + # 1. Cache state ordering + self._state_order = list(self.state_names) + + # 2. Verify initial_state keys match state_names + state_set = set(self.state_names) + initial_set = set(self.initial_state.keys()) + missing = state_set - initial_set + if missing: + msg = ( + f"initial_state is missing keys: {sorted(missing)}. " + f"Expected keys matching state_names: {self.state_names}" + ) + raise ValueError(msg) + extra = initial_set - state_set + if extra: + msg = f"initial_state has extra keys not in state_names: {sorted(extra)}" + raise ValueError(msg) + + # 3. Verify output_names if output_fn provided + if self.output_fn is not None and not self.output_names: + msg = "output_names must be provided when output_fn is set" + raise ValueError(msg) + + # 4. Expand parameter sweep (cartesian product) + if self.params: + keys = list(self.params.keys()) + values = [self.params[k] for k in keys] + self._param_subsets = [ + dict(zip(keys, combo, strict=True)) + for combo in itertools.product(*values) + ] + else: + self._param_subsets = [{}] + + return self + + def y0(self) -> list[float]: + """Initial state as an ordered list of floats.""" + return [self.initial_state[k] for k in self._state_order] + + +class ODESimulation(BaseModel): + """A runnable ODE simulation: model + time span + solver config.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + model: ODEModel + t_span: tuple[float, float] + t_eval: list[float] | None = None + solver: Solver = "RK45" + rtol: float = 1e-6 + atol: float = 1e-9 + max_step: float = float("inf") + + def run(self) -> Any: + """Integrate and return ODEResults.""" + from gds_continuous.engine import integrate_simulation + + return integrate_simulation(self) + + +class ODEExperiment(BaseModel): + """A collection of ODE simulations.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + simulations: list[ODESimulation] + + def run(self) -> Any: + """Execute all simulations and return merged ODEResults.""" + from gds_continuous.engine import integrate_experiment + + return integrate_experiment(self) diff --git a/packages/gds-continuous/gds_continuous/results.py b/packages/gds-continuous/gds_continuous/results.py new file mode 100644 index 0000000..363c20e --- /dev/null +++ b/packages/gds-continuous/gds_continuous/results.py @@ -0,0 +1,175 @@ +"""Columnar result storage for continuous-time ODE trajectories.""" + +from __future__ import annotations + +from typing import Any + +# Metadata column names +_META_COLS = ("time", "run", "subset") + + +class ODEResults: + """Columnar dict-of-lists result storage for ODE trajectories. + + Mirrors ``gds_sim.Results`` in interface but uses continuous + ``time`` (float) instead of discrete ``timestep``/``substep``. + """ + + __slots__ = ("_capacity", "_columns", "_output_names", "_size", "_state_names") + + def __init__( + self, + state_names: list[str], + output_names: list[str] | None = None, + capacity: int = 0, + ) -> None: + self._state_names = state_names + self._output_names = output_names or [] + self._size = 0 + self._capacity = capacity + + all_keys = list(_META_COLS) + state_names + self._output_names + if capacity > 0: + self._columns: dict[str, list[Any]] = { + k: [None] * capacity for k in all_keys + } + else: + self._columns = {k: [] for k in all_keys} + + # ------------------------------------------------------------------ + # Append + # ------------------------------------------------------------------ + + def append( + self, + time: float, + state: list[float], + *, + run: int = 0, + subset: int = 0, + outputs: list[float] | None = None, + ) -> None: + """Append a single time point.""" + cols = self._columns + idx = self._size + + if self._capacity > 0 and idx < self._capacity: + cols["time"][idx] = time + cols["run"][idx] = run + cols["subset"][idx] = subset + for i, name in enumerate(self._state_names): + cols[name][idx] = state[i] + if outputs: + for i, name in enumerate(self._output_names): + cols[name][idx] = outputs[i] + else: + cols["time"].append(time) + cols["run"].append(run) + cols["subset"].append(subset) + for i, name in enumerate(self._state_names): + cols[name].append(state[i]) + if outputs: + for i, name in enumerate(self._output_names): + cols[name].append(outputs[i]) + + self._size += 1 + + def append_solution( + self, + t_array: Any, + y_array: Any, + *, + run: int = 0, + subset: int = 0, + ) -> None: + """Append an entire scipy solve_ivp solution. + + Parameters + ---------- + t_array : array-like of shape (n_points,) + y_array : array-like of shape (n_states, n_points) + """ + n_points = len(t_array) + for j in range(n_points): + state = [float(y_array[i][j]) for i in range(len(self._state_names))] + self.append(float(t_array[j]), state, run=run, subset=subset) + + # ------------------------------------------------------------------ + # Conversion + # ------------------------------------------------------------------ + + def to_dataframe(self) -> Any: + """Convert to pandas DataFrame. Requires ``pandas`` installed.""" + try: + import pandas as pd # type: ignore[import-untyped] + except ImportError as exc: # pragma: no cover + raise ImportError( + "pandas is required for to_dataframe(). " + "Install with: uv add gds-continuous[pandas]" + ) from exc + + data = self._trimmed_columns() + return pd.DataFrame(data) + + def to_list(self) -> list[dict[str, Any]]: + """Convert to list of row-dicts.""" + data = self._trimmed_columns() + keys = list(data.keys()) + n = self._size + return [{k: data[k][i] for k in keys} for i in range(n)] + + def _trimmed_columns(self) -> dict[str, list[Any]]: + """Return columns trimmed to actual size.""" + if self._capacity > 0 and self._size < self._capacity: + return {k: v[: self._size] for k, v in self._columns.items()} + return self._columns + + # ------------------------------------------------------------------ + # Accessors + # ------------------------------------------------------------------ + + @property + def state_names(self) -> list[str]: + """Ordered state variable names.""" + return list(self._state_names) + + @property + def times(self) -> list[float]: + """Time values (trimmed to actual size).""" + return self._trimmed_columns()["time"] + + def state_array(self, name: str) -> list[float]: + """Get all values for a single state variable.""" + return self._trimmed_columns()[name] + + # ------------------------------------------------------------------ + # Merge + # ------------------------------------------------------------------ + + @classmethod + def merge(cls, results_list: list[ODEResults]) -> ODEResults: + """Merge multiple ODEResults into one.""" + if not results_list: + return cls([]) + if len(results_list) == 1: + return results_list[0] + + state_names = results_list[0]._state_names + output_names = results_list[0]._output_names + total = sum(r._size for r in results_list) + merged = cls(state_names, output_names, capacity=total) + + all_keys = list(_META_COLS) + state_names + output_names + offset = 0 + for r in results_list: + trimmed = r._trimmed_columns() + n = r._size + for k in all_keys: + merged._columns[k][offset : offset + n] = trimmed[k] + offset += n + + merged._size = total + return merged + + def __len__(self) -> int: + return self._size diff --git a/packages/gds-continuous/gds_continuous/types.py b/packages/gds-continuous/gds_continuous/types.py new file mode 100644 index 0000000..f3365ee --- /dev/null +++ b/packages/gds-continuous/gds_continuous/types.py @@ -0,0 +1,23 @@ +"""Type definitions for gds-continuous.""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any, Literal + +ODEFunction = Callable[[float, list[float], dict[str, Any]], list[float]] +"""ODE right-hand side: (t, y, params) -> dy/dt.""" + +OutputFunction = Callable[[float, list[float], dict[str, Any]], list[float]] +"""Output equation: (t, y, params) -> observations.""" + +Params = dict[str, Any] +"""Parameter dict for a single subset.""" + +EventFunction = Callable[[float, list[float], dict[str, Any]], float] +"""Event function for solve_ivp: (t, y, params) -> float. + +Zero-crossing triggers event.""" + +Solver = Literal["RK45", "RK23", "DOP853", "Radau", "BDF", "LSODA"] +"""SciPy ODE solver method names.""" diff --git a/packages/gds-continuous/pyproject.toml b/packages/gds-continuous/pyproject.toml new file mode 100644 index 0000000..366a5fd --- /dev/null +++ b/packages/gds-continuous/pyproject.toml @@ -0,0 +1,85 @@ +[project] +name = "gds-continuous" +dynamic = ["version"] +description = "Continuous-time ODE integration engine for the GDS ecosystem" +readme = "README.md" +license = "Apache-2.0" +requires-python = ">=3.12" +authors = [ + { name = "Rohan Mehta", email = "rohan@block.science" }, +] +keywords = [ + "generalized-dynamical-systems", + "ode", + "continuous-time", + "simulation", + "gds-framework", +] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Mathematics", + "Typing :: Typed", +] +dependencies = ["pydantic>=2.10"] + +[project.optional-dependencies] +scipy = ["scipy>=1.13", "numpy>=1.26"] +pandas = ["pandas>=2.0"] + +[project.urls] +Homepage = "https://github.com/BlockScience/gds-core" +Repository = "https://github.com/BlockScience/gds-core" +Documentation = "https://blockscience.github.io/gds-core" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.version] +path = "gds_continuous/__init__.py" + +[tool.hatch.build.targets.wheel] +packages = ["gds_continuous"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +addopts = "--import-mode=importlib --cov=gds_continuous --cov-report=term-missing --no-header -q" + +[tool.coverage.run] +source = ["gds_continuous"] +omit = ["gds_continuous/__init__.py"] + +[tool.coverage.report] +fail_under = 80 +show_missing = true +exclude_lines = [ + "if TYPE_CHECKING:", + "pragma: no cover", +] + +[tool.mypy] +strict = true + +[tool.ruff] +target-version = "py312" +line-length = 88 + +[tool.ruff.lint] +select = ["E", "W", "F", "I", "UP", "B", "SIM", "TCH", "RUF"] + +[dependency-groups] +dev = [ + "mypy>=1.13", + "pytest>=8.0", + "pytest-cov>=5.0", + "ruff>=0.8", + "scipy>=1.13", + "numpy>=1.26", +] diff --git a/packages/gds-continuous/tests/__init__.py b/packages/gds-continuous/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/packages/gds-continuous/tests/conftest.py b/packages/gds-continuous/tests/conftest.py new file mode 100644 index 0000000..bdae384 --- /dev/null +++ b/packages/gds-continuous/tests/conftest.py @@ -0,0 +1,91 @@ +"""Shared fixtures for gds-continuous tests.""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from gds_continuous.model import ODEModel, ODESimulation + +# --------------------------------------------------------------------------- +# ODE functions (plain callables) +# --------------------------------------------------------------------------- + + +def exponential_decay(t: float, y: list[float], params: dict[str, Any]) -> list[float]: + """dx/dt = -k*x. Exact solution: x(t) = x0 * exp(-k*t).""" + k = params.get("k", 1.0) + return [-k * y[0]] + + +def harmonic_oscillator( + t: float, y: list[float], params: dict[str, Any] +) -> list[float]: + """dx/dt = v, dv/dt = -omega^2 * x. Exact: x(t) = x0*cos(omega*t).""" + omega = params.get("omega", 1.0) + return [y[1], -(omega**2) * y[0]] + + +def lotka_volterra(t: float, y: list[float], params: dict[str, Any]) -> list[float]: + """Predator-prey: dx/dt = alpha*x - beta*x*y, dy/dt = delta*x*y - gamma*y.""" + alpha = params.get("alpha", 1.0) + beta = params.get("beta", 0.1) + delta = params.get("delta", 0.075) + gamma = params.get("gamma", 1.5) + x, prey = y[0], y[1] + return [ + alpha * x - beta * x * prey, + delta * x * prey - gamma * prey, + ] + + +# --------------------------------------------------------------------------- +# Model fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def decay_model() -> ODEModel: + """Single-variable exponential decay model.""" + return ODEModel( + state_names=["x"], + initial_state={"x": 1.0}, + rhs=exponential_decay, + params={"k": [1.0]}, + ) + + +@pytest.fixture +def oscillator_model() -> ODEModel: + """Two-variable harmonic oscillator model.""" + return ODEModel( + state_names=["x", "v"], + initial_state={"x": 1.0, "v": 0.0}, + rhs=harmonic_oscillator, + params={"omega": [1.0]}, + ) + + +@pytest.fixture +def decay_sim(decay_model: ODEModel) -> ODESimulation: + """Decay simulation: t=[0, 5], 101 eval points.""" + import numpy as np + + return ODESimulation( + model=decay_model, + t_span=(0.0, 5.0), + t_eval=list(np.linspace(0, 5, 101)), + ) + + +@pytest.fixture +def oscillator_sim(oscillator_model: ODEModel) -> ODESimulation: + """Oscillator simulation: t=[0, 2*pi], 201 eval points.""" + import numpy as np + + return ODESimulation( + model=oscillator_model, + t_span=(0.0, 2 * np.pi), + t_eval=list(np.linspace(0, 2 * np.pi, 201)), + ) diff --git a/packages/gds-continuous/tests/test_engine.py b/packages/gds-continuous/tests/test_engine.py new file mode 100644 index 0000000..839d167 --- /dev/null +++ b/packages/gds-continuous/tests/test_engine.py @@ -0,0 +1,217 @@ +"""Tests for the ODE integration engine.""" + +from __future__ import annotations + +import math +from typing import Any + +import pytest + +from gds_continuous.model import ODEModel, ODESimulation +from gds_continuous.results import ODEResults + + +class TestExponentialDecay: + """dx/dt = -k*x, exact solution: x(t) = x0 * exp(-k*t).""" + + def test_basic_integration(self, decay_sim: ODESimulation) -> None: + results = decay_sim.run() + assert isinstance(results, ODEResults) + assert len(results) == 101 + + def test_accuracy(self, decay_sim: ODESimulation) -> None: + results = decay_sim.run() + times = results.times + x_vals = results.state_array("x") + + for t, x in zip(times, x_vals, strict=True): + expected = math.exp(-t) + assert abs(x - expected) < 1e-5, f"t={t}: got {x}, expected {expected}" + + def test_initial_state_preserved(self, decay_sim: ODESimulation) -> None: + results = decay_sim.run() + assert results.state_array("x")[0] == pytest.approx(1.0) + + def test_final_state_decayed(self, decay_sim: ODESimulation) -> None: + results = decay_sim.run() + x_final = results.state_array("x")[-1] + assert x_final == pytest.approx(math.exp(-5.0), abs=1e-5) + + +class TestHarmonicOscillator: + """dx/dt = v, dv/dt = -omega^2*x. Exact: x(t) = cos(t), v(t) = -sin(t).""" + + def test_basic_integration(self, oscillator_sim: ODESimulation) -> None: + results = oscillator_sim.run() + assert isinstance(results, ODEResults) + assert len(results) == 201 + + def test_position_accuracy(self, oscillator_sim: ODESimulation) -> None: + results = oscillator_sim.run() + times = results.times + x_vals = results.state_array("x") + + for t, x in zip(times, x_vals, strict=True): + expected = math.cos(t) + assert abs(x - expected) < 1e-4, f"t={t}: got {x}, expected {expected}" + + def test_velocity_accuracy(self, oscillator_sim: ODESimulation) -> None: + results = oscillator_sim.run() + times = results.times + v_vals = results.state_array("v") + + for t, v in zip(times, v_vals, strict=True): + expected = -math.sin(t) + assert abs(v - expected) < 1e-4, f"t={t}: got {v}, expected {expected}" + + def test_energy_conservation(self, oscillator_sim: ODESimulation) -> None: + """Total energy E = 0.5*(x^2 + v^2) should be conserved.""" + results = oscillator_sim.run() + x_vals = results.state_array("x") + v_vals = results.state_array("v") + + e0 = 0.5 * (x_vals[0] ** 2 + v_vals[0] ** 2) + for i in range(len(x_vals)): + e = 0.5 * (x_vals[i] ** 2 + v_vals[i] ** 2) + assert e == pytest.approx(e0, abs=1e-6) + + +class TestParameterSweep: + """Parameter sweep across multiple subsets.""" + + def test_two_decay_rates(self) -> None: + def decay(t: float, y: list[float], p: dict[str, Any]) -> list[float]: + return [-p["k"] * y[0]] + + model = ODEModel( + state_names=["x"], + initial_state={"x": 1.0}, + rhs=decay, + params={"k": [1.0, 2.0]}, + ) + sim = ODESimulation( + model=model, + t_span=(0.0, 1.0), + t_eval=[0.0, 1.0], + ) + results = sim.run() + + # 2 subsets * 2 time points = 4 rows + assert len(results) == 4 + + rows = results.to_list() + # subset=0 (k=1): x(1) = exp(-1) + subset0_final = [r for r in rows if r["subset"] == 0 and r["time"] == 1.0] + assert len(subset0_final) == 1 + assert subset0_final[0]["x"] == pytest.approx(math.exp(-1.0), abs=1e-5) + + # subset=1 (k=2): x(1) = exp(-2) + subset1_final = [r for r in rows if r["subset"] == 1 and r["time"] == 1.0] + assert len(subset1_final) == 1 + assert subset1_final[0]["x"] == pytest.approx(math.exp(-2.0), abs=1e-5) + + +class TestSolverSelection: + """Different solver methods.""" + + @pytest.mark.parametrize("solver", ["RK45", "RK23", "DOP853"]) + def test_explicit_solvers(self, decay_model: ODEModel, solver: str) -> None: + sim = ODESimulation( + model=decay_model, + t_span=(0.0, 1.0), + t_eval=[0.0, 0.5, 1.0], + solver=solver, # type: ignore[arg-type] + ) + results = sim.run() + assert len(results) == 3 + x_final = results.state_array("x")[-1] + assert x_final == pytest.approx(math.exp(-1.0), abs=1e-4) + + @pytest.mark.parametrize("solver", ["Radau", "BDF"]) + def test_implicit_solvers(self, decay_model: ODEModel, solver: str) -> None: + sim = ODESimulation( + model=decay_model, + t_span=(0.0, 1.0), + t_eval=[0.0, 0.5, 1.0], + solver=solver, # type: ignore[arg-type] + ) + results = sim.run() + assert len(results) == 3 + x_final = results.state_array("x")[-1] + assert x_final == pytest.approx(math.exp(-1.0), abs=1e-4) + + +class TestEventDetection: + """Event function support.""" + + def test_terminal_event_stops_integration(self) -> None: + """Integrate x' = 1 until x crosses 5.0.""" + + def rhs(t: float, y: list[float], p: dict[str, Any]) -> list[float]: + return [1.0] + + def x_crosses_5(t: float, y: list[float], p: dict[str, Any]) -> float: + return y[0] - 5.0 + + x_crosses_5.terminal = True # type: ignore[attr-defined] + + model = ODEModel( + state_names=["x"], + initial_state={"x": 0.0}, + rhs=rhs, + events=[x_crosses_5], + ) + sim = ODESimulation(model=model, t_span=(0.0, 100.0)) + results = sim.run() + + # Integration should stop near t=5 + final_time = results.times[-1] + assert final_time == pytest.approx(5.0, abs=0.01) + assert results.state_array("x")[-1] == pytest.approx(5.0, abs=0.01) + + +class TestOutputFunction: + """output_fn evaluation along trajectories.""" + + def test_output_fn_called(self) -> None: + """output_fn should produce values in the results.""" + + def rhs(t: float, y: list[float], p: dict[str, Any]) -> list[float]: + return [-y[0]] + + def out_fn(t: float, y: list[float], p: dict[str, Any]) -> list[float]: + return [y[0] ** 2] + + model = ODEModel( + state_names=["x"], + initial_state={"x": 2.0}, + rhs=rhs, + output_fn=out_fn, + output_names=["x_sq"], + ) + sim = ODESimulation( + model=model, + t_span=(0.0, 1.0), + t_eval=[0.0, 0.5, 1.0], + ) + results = sim.run() + + rows = results.to_list() + assert rows[0]["x_sq"] == pytest.approx(4.0) # 2.0^2 + for row in rows: + assert row["x_sq"] is not None + assert row["x_sq"] == pytest.approx(row["x"] ** 2, abs=1e-6) + + +class TestAutoTimePoints: + """Integration without explicit t_eval.""" + + def test_auto_eval_points(self, decay_model: ODEModel) -> None: + sim = ODESimulation( + model=decay_model, + t_span=(0.0, 1.0), + ) + results = sim.run() + assert len(results) > 2 + assert results.times[0] == pytest.approx(0.0) + assert results.times[-1] == pytest.approx(1.0) diff --git a/packages/gds-continuous/tests/test_homicidal_chauffeur.py b/packages/gds-continuous/tests/test_homicidal_chauffeur.py new file mode 100644 index 0000000..5e43098 --- /dev/null +++ b/packages/gds-continuous/tests/test_homicidal_chauffeur.py @@ -0,0 +1,278 @@ +"""Homicidal Chauffeur differential game — integration verification. + +Recreates the key numerical results from mzargham/hc-marimo using +gds-continuous, proving the ODE engine handles a real differential +game from Isaacs (1951). + +The 4D characteristic ODE system: + x1' = -phi*x2 + w*sin(psi*), phi* = -sign(sigma), sigma = p2*x1 - p1*x2 + x2' = phi*x1 + w*cos(psi*) - 1, psi* = atan2(p1, p2) + p1' = -phi*p2 + p2' = phi*p1 + +References: + - R. Isaacs, Differential Games (1965), pp. 297-350 + - A.W. Merz, PhD Thesis, Stanford (1971) + - github.com/mzargham/hc-marimo +""" + +from __future__ import annotations + +import math +from typing import Any + +import numpy as np +import pytest + +from gds_continuous import ODEModel, ODESimulation + +# --------------------------------------------------------------------------- +# HC dynamics as ODEFunction callables +# --------------------------------------------------------------------------- + + +def hc_forward(t: float, y: list[float], params: dict[str, Any]) -> list[float]: + """Forward 4D characteristic ODE for the Homicidal Chauffeur.""" + x1, x2, p1, p2 = y + w = params["w"] + + norm_p = math.sqrt(p1**2 + p2**2) + if norm_p < 1e-15: + return [0.0, 0.0, 0.0, 0.0] + + sigma = p2 * x1 - p1 * x2 + phi_star = -np.sign(sigma) + + x1d = -phi_star * x2 + w * p1 / norm_p + x2d = phi_star * x1 + w * p2 / norm_p - 1.0 + p1d = -phi_star * p2 + p2d = phi_star * p1 + return [float(x1d), float(x2d), float(p1d), float(p2d)] + + +def hc_backward(t: float, y: list[float], params: dict[str, Any]) -> list[float]: + """Backward integration (negate forward dynamics).""" + fwd = hc_forward(t, y, params) + return [-v for v in fwd] + + +def compute_terminal_conditions( + alpha: float, w_val: float, ell_tilde: float +) -> list[float]: + """Terminal conditions on the capture circle for backward integration.""" + x1_T = ell_tilde * math.cos(alpha) + x2_T = ell_tilde * math.sin(alpha) + lam = -1.0 / (ell_tilde * (w_val - math.sin(alpha))) + p1_T = lam * x1_T + p2_T = lam * x2_T + return [x1_T, x2_T, p1_T, p2_T] + + +def hamiltonian_star(x1: float, x2: float, p1: float, p2: float, w: float) -> float: + """Optimal Hamiltonian H* along a trajectory point.""" + sigma = p2 * x1 - p1 * x2 + norm_p = math.sqrt(p1**2 + p2**2) + return -abs(sigma) + w * norm_p - p2 + 1.0 + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestHamiltonianConservation: + """H* should remain ~0 along optimal trajectories (T2 from hc-marimo).""" + + def test_h_star_conserved(self) -> None: + w_val = 0.25 + ell_tilde = 0.5 + alpha = math.pi / 2 # sin(pi/2) = 1 > 0.25 = w → usable + + y0 = compute_terminal_conditions(alpha, w_val, ell_tilde) + + model = ODEModel( + state_names=["x1", "x2", "p1", "p2"], + initial_state=dict(zip(["x1", "x2", "p1", "p2"], y0, strict=True)), + rhs=hc_backward, + params={"w": [w_val]}, + ) + sim = ODESimulation( + model=model, + t_span=(0.0, 10.0), + solver="RK45", + rtol=1e-10, + atol=1e-12, + max_step=0.02, + ) + results = sim.run() + + x1s = results.state_array("x1") + x2s = results.state_array("x2") + p1s = results.state_array("p1") + p2s = results.state_array("p2") + + h_vals = [ + hamiltonian_star(x1s[i], x2s[i], p1s[i], p2s[i], w_val) + for i in range(len(results)) + ] + max_drift = max(abs(h) for h in h_vals) + assert max_drift < 1e-6, f"H* drift = {max_drift}" + + +class TestCostateNormConservation: + """||p||^2 must be conserved along trajectories (T3 from hc-marimo).""" + + def test_norm_conserved(self) -> None: + w_val = 0.25 + ell_tilde = 0.5 + alpha = math.pi / 2 + + y0 = compute_terminal_conditions(alpha, w_val, ell_tilde) + + model = ODEModel( + state_names=["x1", "x2", "p1", "p2"], + initial_state=dict(zip(["x1", "x2", "p1", "p2"], y0, strict=True)), + rhs=hc_backward, + params={"w": [w_val]}, + ) + sim = ODESimulation( + model=model, + t_span=(0.0, 10.0), + solver="RK45", + rtol=1e-10, + atol=1e-12, + max_step=0.02, + ) + results = sim.run() + + p1s = results.state_array("p1") + p2s = results.state_array("p2") + norms = [p1s[i] ** 2 + p2s[i] ** 2 for i in range(len(results))] + initial_norm = norms[0] + + max_drift = max(abs(n - initial_norm) for n in norms) + assert max_drift < 1e-8, f"||p||^2 drift = {max_drift}" + + +class TestCaptureCondition: + """Forward integration from known initial state should reach capture + circle (T4 from hc-marimo).""" + + def test_forward_reaches_capture(self) -> None: + w_val = 0.25 + ell_tilde = 0.5 + alpha = math.pi / 2 + + # Get terminal conditions, then integrate backward to find + # an initial state far from capture circle + y0_terminal = compute_terminal_conditions(alpha, w_val, ell_tilde) + + model_back = ODEModel( + state_names=["x1", "x2", "p1", "p2"], + initial_state=dict(zip(["x1", "x2", "p1", "p2"], y0_terminal, strict=True)), + rhs=hc_backward, + params={"w": [w_val]}, + ) + sim_back = ODESimulation( + model=model_back, + t_span=(0.0, 5.0), + solver="RK45", + rtol=1e-10, + atol=1e-12, + ) + results_back = sim_back.run() + + # Take the final backward state as initial for forward + n = len(results_back) + y0_far = [ + results_back.state_array("x1")[n - 1], + results_back.state_array("x2")[n - 1], + results_back.state_array("p1")[n - 1], + results_back.state_array("p2")[n - 1], + ] + + model_fwd = ODEModel( + state_names=["x1", "x2", "p1", "p2"], + initial_state=dict(zip(["x1", "x2", "p1", "p2"], y0_far, strict=True)), + rhs=hc_forward, + params={"w": [w_val]}, + ) + sim_fwd = ODESimulation( + model=model_fwd, + t_span=(0.0, 5.0), + solver="RK45", + rtol=1e-10, + atol=1e-12, + ) + results_fwd = sim_fwd.run() + + # Final position should be near the capture circle + x1_final = results_fwd.state_array("x1")[-1] + x2_final = results_fwd.state_array("x2")[-1] + dist = math.sqrt(x1_final**2 + x2_final**2) + assert dist == pytest.approx(ell_tilde, abs=0.05), ( + f"Final distance {dist} not near capture radius {ell_tilde}" + ) + + +class TestStationaryEvader: + """w=0 (stationary evader): straight-line capture (T6 from hc-marimo).""" + + def test_w_zero_straight_capture(self) -> None: + """With w=0 and evader on the x2-axis, pursuer drives straight down.""" + + def hc_w0(t: float, y: list[float], params: dict[str, Any]) -> list[float]: + _x1, _x2, _p1, _p2 = y + # With w=0: x1'=-phi*x2, x2'=phi*x1-1 + # For straight approach from above: phi=0, x2'=-1 + return [0.0, -1.0, 0.0, 0.0] + + initial_dist = 3.0 + model = ODEModel( + state_names=["x1", "x2", "p1", "p2"], + initial_state={"x1": 0.0, "x2": initial_dist, "p1": 0.0, "p2": 1.0}, + rhs=hc_w0, + params={"w": [0.0]}, + ) + sim = ODESimulation( + model=model, + t_span=(0.0, initial_dist), + t_eval=[0.0, initial_dist], + ) + results = sim.run() + + x2_final = results.state_array("x2")[-1] + assert x2_final == pytest.approx(0.0, abs=1e-6) + + +class TestParameterSweep: + """Sweep over w values using gds-continuous parameter sweep.""" + + def test_sweep_over_w(self) -> None: + alpha = math.pi / 2 + ell_tilde = 0.5 + w_values = [0.1, 0.2, 0.3] + + # Use w=0.2 for terminal conditions (all w values have sin(pi/2)>w) + y0 = compute_terminal_conditions(alpha, 0.2, ell_tilde) + + model = ODEModel( + state_names=["x1", "x2", "p1", "p2"], + initial_state=dict(zip(["x1", "x2", "p1", "p2"], y0, strict=True)), + rhs=hc_backward, + params={"w": w_values}, + ) + sim = ODESimulation( + model=model, + t_span=(0.0, 3.0), + t_eval=[0.0, 1.0, 2.0, 3.0], + solver="RK45", + ) + results = sim.run() + + # 3 subsets * 4 time points = 12 rows + assert len(results) == 12 + + rows = results.to_list() + subsets = {r["subset"] for r in rows} + assert subsets == {0, 1, 2} diff --git a/packages/gds-continuous/tests/test_model.py b/packages/gds-continuous/tests/test_model.py new file mode 100644 index 0000000..054a224 --- /dev/null +++ b/packages/gds-continuous/tests/test_model.py @@ -0,0 +1,143 @@ +"""Tests for ODEModel, ODESimulation, ODEExperiment construction and validation.""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from gds_continuous.model import ODEExperiment, ODEModel, ODESimulation + + +def _trivial_rhs(t: float, y: list[float], params: dict[str, Any]) -> list[float]: + return [0.0] + + +class TestODEModelValid: + """Valid construction cases.""" + + def test_single_state(self) -> None: + m = ODEModel( + state_names=["x"], + initial_state={"x": 0.0}, + rhs=_trivial_rhs, + ) + assert m.state_names == ["x"] + assert m.y0() == [0.0] + + def test_multi_state(self) -> None: + def rhs(t: float, y: list[float], p: dict) -> list[float]: + return [0.0, 0.0, 0.0] + + m = ODEModel( + state_names=["x", "v", "a"], + initial_state={"x": 1.0, "v": 2.0, "a": 3.0}, + rhs=rhs, + ) + assert m.y0() == [1.0, 2.0, 3.0] + + def test_param_sweep_expansion(self) -> None: + m = ODEModel( + state_names=["x"], + initial_state={"x": 0.0}, + rhs=_trivial_rhs, + params={"a": [1, 2], "b": [10, 20]}, + ) + assert len(m._param_subsets) == 4 + + def test_no_params_single_subset(self) -> None: + m = ODEModel( + state_names=["x"], + initial_state={"x": 0.0}, + rhs=_trivial_rhs, + ) + assert m._param_subsets == [{}] + + def test_state_order_preserved(self) -> None: + def rhs(t: float, y: list[float], p: dict) -> list[float]: + return [0.0, 0.0] + + m = ODEModel( + state_names=["beta", "alpha"], + initial_state={"beta": 10.0, "alpha": 20.0}, + rhs=rhs, + ) + assert m._state_order == ["beta", "alpha"] + assert m.y0() == [10.0, 20.0] + + def test_with_output_fn(self) -> None: + def out_fn(t: float, y: list[float], p: dict) -> list[float]: + return [y[0] ** 2] + + m = ODEModel( + state_names=["x"], + initial_state={"x": 1.0}, + rhs=_trivial_rhs, + output_fn=out_fn, + output_names=["x_squared"], + ) + assert m.output_names == ["x_squared"] + + +class TestODEModelInvalid: + """Construction validation errors.""" + + def test_missing_initial_state_key(self) -> None: + with pytest.raises(ValueError, match="missing keys"): + ODEModel( + state_names=["x", "v"], + initial_state={"x": 0.0}, + rhs=_trivial_rhs, + ) + + def test_extra_initial_state_key(self) -> None: + with pytest.raises(ValueError, match="extra keys"): + ODEModel( + state_names=["x"], + initial_state={"x": 0.0, "y": 1.0}, + rhs=_trivial_rhs, + ) + + def test_output_fn_without_names(self) -> None: + def out_fn(t: float, y: list[float], p: dict) -> list[float]: + return [y[0]] + + with pytest.raises(ValueError, match="output_names must be provided"): + ODEModel( + state_names=["x"], + initial_state={"x": 0.0}, + rhs=_trivial_rhs, + output_fn=out_fn, + ) + + +class TestODESimulation: + """ODESimulation construction.""" + + def test_defaults(self, decay_model: ODEModel) -> None: + sim = ODESimulation(model=decay_model, t_span=(0.0, 1.0)) + assert sim.solver == "RK45" + assert sim.rtol == 1e-6 + assert sim.atol == 1e-9 + assert sim.t_eval is None + + def test_custom_solver(self, decay_model: ODEModel) -> None: + sim = ODESimulation( + model=decay_model, + t_span=(0.0, 1.0), + solver="Radau", + rtol=1e-8, + atol=1e-12, + ) + assert sim.solver == "Radau" + assert sim.rtol == 1e-8 + + +class TestODEExperiment: + """ODEExperiment construction.""" + + def test_multi_simulation(self, decay_model: ODEModel) -> None: + sim1 = ODESimulation(model=decay_model, t_span=(0.0, 1.0)) + sim2 = ODESimulation(model=decay_model, t_span=(0.0, 5.0)) + exp = ODEExperiment(simulations=[sim1, sim2]) + assert len(exp.simulations) == 2 diff --git a/packages/gds-continuous/tests/test_results.py b/packages/gds-continuous/tests/test_results.py new file mode 100644 index 0000000..ed58135 --- /dev/null +++ b/packages/gds-continuous/tests/test_results.py @@ -0,0 +1,136 @@ +"""Tests for ODEResults columnar storage.""" + +from __future__ import annotations + +import pytest + +from gds_continuous.results import ODEResults + + +class TestAppend: + """Row-by-row append.""" + + def test_basic_append(self) -> None: + r = ODEResults(["x", "v"]) + r.append(0.0, [1.0, 0.0]) + r.append(0.1, [0.99, -0.1]) + assert len(r) == 2 + + def test_preallocated_append(self) -> None: + r = ODEResults(["x"], capacity=10) + for i in range(10): + r.append(float(i), [float(i * 2)]) + assert len(r) == 10 + + def test_overflow_to_dynamic(self) -> None: + r = ODEResults(["x"], capacity=2) + r.append(0.0, [1.0]) + r.append(1.0, [2.0]) + r.append(2.0, [3.0]) # exceeds capacity + assert len(r) == 3 + + def test_with_metadata(self) -> None: + r = ODEResults(["x"]) + r.append(0.0, [1.0], run=2, subset=3) + rows = r.to_list() + assert rows[0]["run"] == 2 + assert rows[0]["subset"] == 3 + + def test_with_outputs(self) -> None: + r = ODEResults(["x"], output_names=["y"]) + r.append(0.0, [1.0], outputs=[2.0]) + rows = r.to_list() + assert rows[0]["y"] == 2.0 + + +class TestAppendSolution: + """Bulk append from scipy solution arrays.""" + + def test_append_solution(self) -> None: + r = ODEResults(["x", "v"]) + t = [0.0, 0.1, 0.2] + y = [[1.0, 0.99, 0.98], [0.0, -0.1, -0.2]] # shape (2, 3) + r.append_solution(t, y, run=0, subset=0) + assert len(r) == 3 + assert r.times == [0.0, 0.1, 0.2] + assert r.state_array("x") == [1.0, 0.99, 0.98] + assert r.state_array("v") == [0.0, -0.1, -0.2] + + +class TestConversion: + """to_list and to_dataframe.""" + + def test_to_list(self) -> None: + r = ODEResults(["x"]) + r.append(0.0, [1.0]) + r.append(1.0, [2.0]) + rows = r.to_list() + assert len(rows) == 2 + assert rows[0] == {"time": 0.0, "run": 0, "subset": 0, "x": 1.0} + assert rows[1] == {"time": 1.0, "run": 0, "subset": 0, "x": 2.0} + + def test_to_list_preallocated_trimmed(self) -> None: + r = ODEResults(["x"], capacity=100) + r.append(0.0, [1.0]) + r.append(1.0, [2.0]) + rows = r.to_list() + assert len(rows) == 2 + + def test_to_dataframe(self) -> None: + pytest.importorskip("pandas") + r = ODEResults(["x", "v"]) + r.append(0.0, [1.0, 0.0]) + r.append(0.1, [0.99, -0.1]) + df = r.to_dataframe() + assert list(df.columns) == ["time", "run", "subset", "x", "v"] + assert len(df) == 2 + + +class TestAccessors: + """Property and method accessors.""" + + def test_state_names(self) -> None: + r = ODEResults(["alpha", "beta"]) + assert r.state_names == ["alpha", "beta"] + + def test_times(self) -> None: + r = ODEResults(["x"]) + r.append(0.0, [1.0]) + r.append(0.5, [0.5]) + assert r.times == [0.0, 0.5] + + def test_state_array(self) -> None: + r = ODEResults(["x", "v"]) + r.append(0.0, [1.0, 0.0]) + r.append(0.1, [0.9, -0.1]) + assert r.state_array("x") == [1.0, 0.9] + assert r.state_array("v") == [0.0, -0.1] + + +class TestMerge: + """Merging multiple results.""" + + def test_merge_two(self) -> None: + r1 = ODEResults(["x"]) + r1.append(0.0, [1.0], subset=0) + r1.append(1.0, [2.0], subset=0) + + r2 = ODEResults(["x"]) + r2.append(0.0, [10.0], subset=1) + r2.append(1.0, [20.0], subset=1) + + merged = ODEResults.merge([r1, r2]) + assert len(merged) == 4 + rows = merged.to_list() + assert rows[0]["subset"] == 0 + assert rows[2]["subset"] == 1 + + def test_merge_empty(self) -> None: + merged = ODEResults.merge([]) + assert len(merged) == 0 + + def test_merge_single(self) -> None: + r = ODEResults(["x"]) + r.append(0.0, [1.0]) + merged = ODEResults.merge([r]) + assert merged is r diff --git a/packages/gds-examples/continuous/__init__.py b/packages/gds-examples/continuous/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/packages/gds-examples/continuous/homicidal_chauffeur/__init__.py b/packages/gds-examples/continuous/homicidal_chauffeur/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/packages/gds-examples/continuous/homicidal_chauffeur/model.py b/packages/gds-examples/continuous/homicidal_chauffeur/model.py new file mode 100644 index 0000000..797a3fc --- /dev/null +++ b/packages/gds-examples/continuous/homicidal_chauffeur/model.py @@ -0,0 +1,259 @@ +"""Homicidal Chauffeur Differential Game — gds-symbolic + gds-continuous. + +Isaacs' foundational pursuit-evasion problem (1951): Can a fast but clumsy +car catch a slow but agile pedestrian? + +This example demonstrates the full symbolic-to-numerical pipeline: +1. Declare the game dynamics symbolically using gds-symbolic +2. Derive optimal controls from the Hamiltonian using SymPy +3. Compile the closed-loop 4D ODE via sympy.lambdify +4. Integrate trajectories using gds-continuous +5. Verify conservation laws (Hamiltonian, costate norm) + +The 4D characteristic ODE (optimal feedback form): + x1_dot = -phi_star * x2 + w * sin(psi_star) + x2_dot = phi_star * x1 + w * cos(psi_star) - 1 + p1_dot = -phi_star * p2 + p2_dot = phi_star * p1 + +where: + phi_star = -sign(sigma), sigma = p2*x1 - p1*x2 (bang-bang pursuer) + psi_star = atan2(p1, p2) (gradient-aligned evader) + +Parameters: + w = v_E / v_P -- evader-to-pursuer speed ratio (0 < w < 1) + ell_tilde = ell / R_min -- normalized capture radius + +GDS Decomposition: + X = (x1, x2, p1, p2) -- relative position + costates (4D) + U = (w, ell_tilde) -- game parameters + h = f -- pure dynamics, no policy/mechanism split + (the optimal controls are derived analytically + and substituted into the RHS) + +References: + R. Isaacs, *Differential Games*, Wiley (1965), pp. 297-350 + A.W. Merz, PhD Thesis, Stanford (1971) + github.com/mzargham/hc-marimo +""" + +from __future__ import annotations + +import math +from typing import Any + +from gds_continuous import ODEModel, ODESimulation + +# --------------------------------------------------------------------------- +# Symbolic derivation (requires sympy) +# --------------------------------------------------------------------------- + + +def derive_optimal_rhs() -> tuple[Any, list[str]]: + """Derive the 4D characteristic ODE symbolically. + + Uses SymPy to: + 1. Define the reduced kinematics (Isaacs' body-fixed frame) + 2. Construct the Hamiltonian H = p1*f1 + p2*f2 + 1 + 3. Solve for optimal controls (bang-bang phi, gradient-aligned psi) + 4. Derive costate equations from dH/dx + 5. Lambdify the full 4D system + + Returns + ------- + rhs_fn : callable + Lambdified (t, y, params) -> dy/dt + state_order : list[str] + ["x1", "x2", "p1", "p2"] + """ + import sympy as sp + + x1, x2 = sp.symbols("x_1 x_2", real=True) + p1, p2 = sp.symbols("p_1 p_2", real=True) + phi = sp.Symbol("phi", real=True) + psi = sp.Symbol("psi", real=True) + w = sp.Symbol("w", positive=True) + + # Reduced dynamics (Isaacs' canonical form) + f1 = -phi * x2 + w * sp.sin(psi) + f2 = phi * x1 + w * sp.cos(psi) - 1 + + # Hamiltonian (minimizing time => +1) + H = p1 * f1 + p2 * f2 + 1 + + # Optimal pursuer: phi* = -sign(sigma), sigma = dH/dphi coefficient + sigma = sp.expand(H).coeff(phi) # = p2*x1 - p1*x2 + phi_star = -sp.sign(sigma) + + # Optimal evader: psi* = atan2(p1, p2) maximizes w*(p1*sin + p2*cos) + psi_star = sp.atan2(p1, p2) + + # Substitute optimal controls + rhs_x1 = f1.subs([(phi, phi_star), (psi, psi_star)]) + rhs_x2 = f2.subs([(phi, phi_star), (psi, psi_star)]) + + # Costate equations: dp/dt = -dH/dx + rhs_p1 = -phi_star * p2 + rhs_p2 = phi_star * p1 + + # Lambdify + _fn = sp.lambdify( + [x1, x2, p1, p2, w], + [rhs_x1, rhs_x2, rhs_p1, rhs_p2], + modules=["numpy"], + ) + + def rhs_fn(t: float, y: list[float], params: dict[str, Any]) -> list[float]: + result = _fn(y[0], y[1], y[2], y[3], params["w"]) + return [float(v) for v in result] + + return rhs_fn, ["x1", "x2", "p1", "p2"] + + +# --------------------------------------------------------------------------- +# Hand-coded dynamics (no SymPy dependency) +# --------------------------------------------------------------------------- + + +def hc_forward(t: float, y: list[float], params: dict[str, Any]) -> list[float]: + """Forward 4D characteristic ODE — hand-coded numpy version.""" + x1, x2, p1, p2 = y + w = params["w"] + + norm_p = math.sqrt(p1**2 + p2**2) + if norm_p < 1e-15: + return [0.0, 0.0, 0.0, 0.0] + + sigma = p2 * x1 - p1 * x2 + phi_star = -1.0 if sigma > 0 else (1.0 if sigma < 0 else 0.0) + + x1d = -phi_star * x2 + w * p1 / norm_p + x2d = phi_star * x1 + w * p2 / norm_p - 1.0 + p1d = -phi_star * p2 + p2d = phi_star * p1 + return [x1d, x2d, p1d, p2d] + + +def hc_backward(t: float, y: list[float], params: dict[str, Any]) -> list[float]: + """Backward integration (negate forward dynamics).""" + fwd = hc_forward(t, y, params) + return [-v for v in fwd] + + +# --------------------------------------------------------------------------- +# Terminal conditions +# --------------------------------------------------------------------------- + + +def terminal_conditions(alpha: float, w: float, ell_tilde: float) -> dict[str, float]: + """Compute initial conditions on the capture circle for backward integration. + + Parameters + ---------- + alpha : float + Angle on the terminal circle. Usable part requires sin(alpha) > w. + w : float + Evader-to-pursuer speed ratio. + ell_tilde : float + Normalized capture radius. + + Returns + ------- + dict mapping state names to values. + """ + x1_T = ell_tilde * math.cos(alpha) + x2_T = ell_tilde * math.sin(alpha) + lam = -1.0 / (ell_tilde * (w - math.sin(alpha))) + p1_T = lam * x1_T + p2_T = lam * x2_T + return {"x1": x1_T, "x2": x2_T, "p1": p1_T, "p2": p2_T} + + +# --------------------------------------------------------------------------- +# Conservation laws +# --------------------------------------------------------------------------- + + +def hamiltonian_star(state: dict[str, float], w: float) -> float: + """Optimal Hamiltonian H* = -|sigma| + w*||p|| - p2 + 1. + + Should be ~0 along optimal trajectories. + """ + x1, x2 = state["x1"], state["x2"] + p1, p2 = state["p1"], state["p2"] + sigma = p2 * x1 - p1 * x2 + norm_p = math.sqrt(p1**2 + p2**2) + return -abs(sigma) + w * norm_p - p2 + 1.0 + + +def costate_norm_sq(state: dict[str, float]) -> float: + """||p||^2 = p1^2 + p2^2. Should be conserved along trajectories.""" + return state["p1"] ** 2 + state["p2"] ** 2 + + +# --------------------------------------------------------------------------- +# Simulation builders +# --------------------------------------------------------------------------- + + +def build_backward_model( + alpha: float = math.pi / 2, + w: float = 0.25, + ell_tilde: float = 0.5, +) -> ODEModel: + """Build an ODEModel for backward integration from the capture circle.""" + ic = terminal_conditions(alpha, w, ell_tilde) + return ODEModel( + state_names=["x1", "x2", "p1", "p2"], + initial_state=ic, + rhs=hc_backward, + params={"w": [w]}, + ) + + +def build_backward_simulation( + alpha: float = math.pi / 2, + w: float = 0.25, + ell_tilde: float = 0.5, + t_final: float = 10.0, +) -> ODESimulation: + """Build an ODESimulation for backward reachable set computation.""" + model = build_backward_model(alpha, w, ell_tilde) + return ODESimulation( + model=model, + t_span=(0.0, t_final), + solver="RK45", + rtol=1e-10, + atol=1e-12, + max_step=0.02, + ) + + +def compute_isochrone( + w: float = 0.25, + ell_tilde: float = 0.5, + t_final: float = 5.0, + n_rays: int = 40, +) -> list[tuple[float, float]]: + """Compute a backward reachable set boundary (isochrone). + + Integrates backward from multiple points on the usable part of the + terminal circle, each for time t_final. Returns the endpoints as + (x1, x2) pairs forming the isochrone contour. + """ + # Usable part: sin(alpha) > w + alpha_min = math.asin(w) + 0.01 + alpha_max = math.pi - math.asin(w) - 0.01 + alphas = [ + alpha_min + i * (alpha_max - alpha_min) / (n_rays - 1) for i in range(n_rays) + ] + + points: list[tuple[float, float]] = [] + for alpha in alphas: + sim = build_backward_simulation(alpha, w, ell_tilde, t_final) + results = sim.run() + x1_final = results.state_array("x1")[-1] + x2_final = results.state_array("x2")[-1] + points.append((x1_final, x2_final)) + + return points diff --git a/packages/gds-examples/continuous/homicidal_chauffeur/notebook.py b/packages/gds-examples/continuous/homicidal_chauffeur/notebook.py new file mode 100644 index 0000000..bc4d380 --- /dev/null +++ b/packages/gds-examples/continuous/homicidal_chauffeur/notebook.py @@ -0,0 +1,565 @@ +import marimo + +__generated_with = "0.20.4" +app = marimo.App(width="medium") + + +@app.cell +def imports(): + import marimo as mo + import matplotlib.pyplot as plt + import numpy as np + import sympy as sp + from sympy import ( + atan2, + cos, + diff, + expand, + latex, + sign, + simplify, + sin, + sqrt, + symbols, + trigsimp, + ) + + from gds_continuous import ODEModel, ODESimulation + + return ( + ODEModel, + ODESimulation, + atan2, + cos, + diff, + expand, + latex, + mo, + np, + plt, + sign, + simplify, + sin, + sp, + sqrt, + symbols, + trigsimp, + ) + + +@app.cell +def title(mo): + mo.md( + r""" + # The Homicidal Chauffeur: A Differential Game + + ## Symbolic Derivation & Interactive Simulation with gds-continuous + + --- + + *An interactive notebook exploring Rufus Isaacs' foundational + pursuit-evasion problem (1951) through the + [GDS](https://github.com/BlockScience/gds-core) ecosystem.* + + Every equation is derived symbolically with SymPy, then integrated + numerically through `gds-continuous` (wrapping `scipy.integrate.solve_ivp`). + + **The problem:** Can a fast but clumsy car catch a slow but agile + pedestrian? A pursuer (high speed, minimum turning radius) chases an + evader (low speed, unlimited maneuverability) on an unbounded plane. + + **References:** + - R. Isaacs, *Differential Games*, Wiley (1965), pp. 297--350 + - A.W. Merz, PhD Thesis, Stanford (1971) + - [mzargham/hc-marimo](https://github.com/mzargham/hc-marimo) + (reference implementation) + """ + ) + return + + +@app.cell +def define_symbols(symbols): + x1, x2 = symbols("x_1 x_2", real=True) + p1, p2 = symbols("p_1 p_2", real=True) + phi = symbols("phi", real=True) + psi = symbols("psi", real=True) + w = symbols("w", positive=True) + return p1, p2, phi, psi, w, x1, x2 + + +@app.cell +def reduced_dynamics(cos, latex, mo, phi, psi, sin, w, x1, x2): + f1 = -phi * x2 + w * sin(psi) + f2 = phi * x1 + w * cos(psi) - 1 + + mo.md( + rf""" + ## Reduced Kinematics (Isaacs' Body-Fixed Frame) + + After reducing from 5-DOF absolute coordinates to a 2-DOF system + in the pursuer's rotating frame (normalizing $v_P = 1$, $R_{{\min}} = 1$): + + $$ + \dot{{x}}_1 = {latex(f1)}, \qquad \dot{{x}}_2 = {latex(f2)} + $$ + + where $\phi \in [-1, +1]$ is the pursuer's turn rate (control), + $\psi$ is the evader's heading (control), and $w = v_E / v_P < 1$. + """ + ) + return f1, f2 + + +@app.cell +def hamiltonian(expand, f1, f2, latex, mo, p1, p2, phi, psi, w): + H = p1 * f1 + p2 * f2 + 1 + _H_expanded = expand(H) + + sigma = _H_expanded.coeff(phi) + + mo.md( + rf""" + ## Hamiltonian & Optimal Controls + + The time-optimal Hamiltonian: + + $$H = p_1 f_1 + p_2 f_2 + 1$$ + + Expanding: $H$ is **linear in $\phi$** with switching function + $\sigma = {latex(sigma)}$. + + **Pursuer** (minimizes $H$): $\phi^* = -\text{{sign}}(\sigma)$ (bang-bang) + + **Evader** (maximizes $H$): $\psi^* = \text{{atan2}}(p_1, p_2)$ + (gradient-aligned) + """ + ) + return H, sigma + + +@app.cell +def costate_equations(H, diff, latex, mo, phi, p1, p2, simplify, x1, x2): + p1_dot = -diff(H, x1) + p2_dot = -diff(H, x2) + + _d_norm_sq = simplify(2 * (p1 * p1_dot + p2 * p2_dot)) + + mo.md( + rf""" + ## Costate (Adjoint) Equations + + $$\dot{{p}}_1 = -{{\partial H}}/{{\partial x_1}} = {latex(p1_dot)}$$ + $$\dot{{p}}_2 = -{{\partial H}}/{{\partial x_2}} = {latex(p2_dot)}$$ + + **Conservation:** $\frac{{d}}{{dt}}(p_1^2 + p_2^2) = {latex(_d_norm_sq)}$ + — the costate norm is constant along optimal trajectories. + """ + ) + return p1_dot, p2_dot + + +@app.cell +def lambdify_rhs(atan2, f1, f2, p1, p2, phi, psi, sign, sp, w, x1, x2): + _sigma = p2 * x1 - p1 * x2 + _phi_star = -sign(_sigma) + _psi_star = atan2(p1, p2) + + _rhs_x1 = f1.subs([(phi, _phi_star), (psi, _psi_star)]) + _rhs_x2 = f2.subs([(phi, _phi_star), (psi, _psi_star)]) + _rhs_p1 = -_phi_star * p2 + _rhs_p2 = _phi_star * p1 + + _fn = sp.lambdify( + [x1, x2, p1, p2, w], + [_rhs_x1, _rhs_x2, _rhs_p1, _rhs_p2], + modules=["numpy"], + ) + + def rhs_forward(t, y, params): + _result = _fn(y[0], y[1], y[2], y[3], params["w"]) + return [float(_v) for _v in _result] + + def rhs_backward(t, y, params): + _fwd = rhs_forward(t, y, params) + return [-_v for _v in _fwd] + + return rhs_backward, rhs_forward + + +@app.cell +def parameter_sliders(mo): + mo.md("## Interactive Parameters") + + v_E_slider = mo.ui.slider( + start=0.05, + stop=0.50, + step=0.01, + value=0.25, + label=r"Evader speed $v_E$ (pursuer $v_P = 1$)", + ) + omega_max_slider = mo.ui.slider( + start=0.5, + stop=3.0, + step=0.1, + value=1.0, + label=r"Max angular velocity $\omega_{\max}$", + ) + ell_slider = mo.ui.slider( + start=0.1, stop=1.5, step=0.05, value=0.5, label=r"Capture radius $\ell$" + ) + + mo.vstack([v_E_slider, omega_max_slider, ell_slider]) + return ell_slider, omega_max_slider, v_E_slider + + +@app.cell +def derived_parameters(ell_slider, mo, omega_max_slider, v_E_slider): + v_E_val = v_E_slider.value + _v_P_val = 1.0 + _omega_val = omega_max_slider.value + _ell_val = ell_slider.value + + _R_min_val = _v_P_val / _omega_val + w_val = v_E_val / _v_P_val + ell_tilde_val = _ell_val / _R_min_val + + mo.md( + rf""" + ### Dimensionless Parameters + + | Symbol | Meaning | Value | + |--------|---------|-------| + | $w = v_E / v_P$ | speed ratio | ${w_val:.3f}$ | + | $\tilde{{\ell}} = \ell / R_{{\min}}$ | capture radius | ${ell_tilde_val:.3f}$ | + """ + ) + return ell_tilde_val, w_val + + +@app.cell +def terminal_conditions_fn(np): + def compute_terminal_conditions(alpha_arr, _w_val, _ell_tilde_val): + """Compute (x1, x2, p1, p2) at the terminal surface.""" + _x1_T = _ell_tilde_val * np.cos(alpha_arr) + _x2_T = _ell_tilde_val * np.sin(alpha_arr) + _lam = -1.0 / (_ell_tilde_val * (_w_val - np.sin(alpha_arr))) + _p1_T = _lam * _x1_T + _p2_T = _lam * _x2_T + return np.column_stack([_x1_T, _x2_T, _p1_T, _p2_T]) + + return (compute_terminal_conditions,) + + +@app.cell +def trajectory_sliders(mo): + n_traj_slider = mo.ui.slider( + start=5, stop=50, step=5, value=20, label="Number of trajectories" + ) + T_horizon_slider = mo.ui.slider( + start=1.0, stop=20.0, step=0.5, value=8.0, label=r"Backward time horizon $T$" + ) + mo.vstack([mo.md("### Trajectory Controls"), n_traj_slider, T_horizon_slider]) + return T_horizon_slider, n_traj_slider + + +@app.cell +def backward_trajectories( + ODEModel, + ODESimulation, + T_horizon_slider, + compute_terminal_conditions, + ell_tilde_val, + n_traj_slider, + np, + rhs_backward, + w_val, +): + _n_traj = n_traj_slider.value + _T_horizon = T_horizon_slider.value + + _alpha_min = np.arcsin(min(w_val, 0.999)) + _alpha_max = np.pi - _alpha_min + _eps = 1e-3 + _alphas = np.linspace(_alpha_min + _eps, _alpha_max - _eps, _n_traj) + + _terminal = compute_terminal_conditions(_alphas, w_val, ell_tilde_val) + + trajectories = [] + for _i in range(_n_traj): + _ic = _terminal[_i] + _model = ODEModel( + state_names=["x1", "x2", "p1", "p2"], + initial_state={ + "x1": float(_ic[0]), + "x2": float(_ic[1]), + "p1": float(_ic[2]), + "p2": float(_ic[3]), + }, + rhs=rhs_backward, + params={"w": [w_val]}, + ) + _sim = ODESimulation( + model=_model, + t_span=(0.0, _T_horizon), + solver="RK45", + rtol=1e-10, + atol=1e-12, + max_step=0.05, + ) + _results = _sim.run() + trajectories.append(_results) + return (trajectories,) + + +@app.cell +def trajectory_plot(ell_tilde_val, mo, np, plt, trajectories, w_val): + _fig, _ax = plt.subplots(1, 1, figsize=(8, 8)) + + # Terminal circle + _theta = np.linspace(0, 2 * np.pi, 200) + _ax.plot( + ell_tilde_val * np.cos(_theta), + ell_tilde_val * np.sin(_theta), + "k--", + linewidth=1, + alpha=0.5, + label="Terminal circle", + ) + + # Usable part + _alpha_min = np.arcsin(min(w_val, 0.999)) + _alpha_max = np.pi - _alpha_min + _usable = np.linspace(_alpha_min, _alpha_max, 100) + _ax.plot( + ell_tilde_val * np.cos(_usable), + ell_tilde_val * np.sin(_usable), + "r-", + linewidth=3, + alpha=0.8, + label="Usable part", + ) + + # Plot trajectories + _cmap = plt.cm.viridis + _mid = (len(trajectories) - 1) / 2.0 + for _i, _res in enumerate(trajectories): + _color = _cmap(abs(_i - _mid) / max(_mid, 1)) + _x1s = _res.state_array("x1") + _x2s = _res.state_array("x2") + _ax.plot(_x1s, _x2s, "-", color=_color, linewidth=0.8, alpha=0.7) + _ax.plot(_x1s[-1], _x2s[-1], "o", color=_color, markersize=3) + + _ax.plot(0, 0, "k+", markersize=12, markeredgewidth=2) + _ax.set_xlabel(r"$x_1$ (perpendicular)") + _ax.set_ylabel(r"$x_2$ (along heading)") + _ax.set_aspect("equal") + _ax.set_title( + rf"Optimal Trajectories ($w = {w_val:.3f}$, " + rf"$\tilde{{\ell}} = {ell_tilde_val:.3f}$)" + ) + _ax.legend(loc="lower right", fontsize=9) + _ax.grid(True, alpha=0.3) + plt.tight_layout() + + mo.vstack( + [ + _fig, + mo.md( + r""" + **Backward characteristics** from the usable part of the terminal + circle. Each curve starts on the capture boundary at $\tau = 0$ + and extends outward as backward time $\tau$ increases. In forward + time, these are optimal pursuit paths. + + Sharp kinks are **bang-bang switching points** where the pursuer's + control jumps between $\phi = +1$ and $\phi = -1$. + + *Integrated via `gds-continuous.ODESimulation`.* + """ + ), + ] + ) + return + + +@app.cell +def isochrone_controls(mo): + iso_T_slider = mo.ui.slider( + start=1.0, stop=15.0, step=0.5, value=5.0, label=r"Isochrone time $T$" + ) + iso_n_slider = mo.ui.slider( + start=20, stop=100, step=10, value=60, label="Isochrone resolution" + ) + mo.vstack( + [ + mo.md("### Isochrone (Backward Reachable Set)"), + iso_T_slider, + iso_n_slider, + ] + ) + return iso_T_slider, iso_n_slider + + +@app.cell +def isochrone_plot( + ODEModel, + ODESimulation, + compute_terminal_conditions, + ell_tilde_val, + iso_T_slider, + iso_n_slider, + mo, + np, + plt, + rhs_backward, + w_val, +): + _T_iso = iso_T_slider.value + _n_rays = iso_n_slider.value + + _alpha_min = np.arcsin(min(w_val, 0.999)) + _alpha_max = np.pi - _alpha_min + _eps = 1e-3 + _alphas = np.linspace(_alpha_min + _eps, _alpha_max - _eps, _n_rays) + _terminal = compute_terminal_conditions(_alphas, w_val, ell_tilde_val) + + _endpoints_x1 = [] + _endpoints_x2 = [] + for _i in range(_n_rays): + _ic = _terminal[_i] + _model = ODEModel( + state_names=["x1", "x2", "p1", "p2"], + initial_state={ + "x1": float(_ic[0]), + "x2": float(_ic[1]), + "p1": float(_ic[2]), + "p2": float(_ic[3]), + }, + rhs=rhs_backward, + params={"w": [w_val]}, + ) + _sim = ODESimulation( + model=_model, + t_span=(0.0, _T_iso), + solver="RK45", + rtol=1e-10, + atol=1e-12, + max_step=0.05, + ) + _res = _sim.run() + _endpoints_x1.append(_res.state_array("x1")[-1]) + _endpoints_x2.append(_res.state_array("x2")[-1]) + + _fig, _ax = plt.subplots(1, 1, figsize=(8, 8)) + + # Terminal circle + _theta = np.linspace(0, 2 * np.pi, 200) + _ax.plot( + ell_tilde_val * np.cos(_theta), + ell_tilde_val * np.sin(_theta), + "k--", + linewidth=1, + alpha=0.5, + ) + + # Isochrone boundary + _ax.plot( + _endpoints_x1, + _endpoints_x2, + "b-", + linewidth=2, + label=rf"$T = {_T_iso:.1f}$", + ) + _ax.fill(_endpoints_x1, _endpoints_x2, color="blue", alpha=0.05) + + _ax.plot(0, 0, "k+", markersize=12, markeredgewidth=2) + _ax.set_xlabel(r"$x_1$") + _ax.set_ylabel(r"$x_2$") + _ax.set_aspect("equal") + _ax.set_title( + rf"Backward Reachable Set ($w = {w_val:.3f}$, " + rf"$\tilde{{\ell}} = {ell_tilde_val:.3f}$)" + ) + _ax.legend(fontsize=10) + _ax.grid(True, alpha=0.3) + plt.tight_layout() + + mo.vstack( + [ + _fig, + mo.md( + rf""" + The **isochrone** at $T = {_T_iso:.1f}$: the boundary of the set of + initial positions from which the pursuer can guarantee capture within + time $T$ under optimal play. Computed by integrating {_n_rays} backward + characteristics from the usable part of the terminal circle, each for + time $T$, using `gds-continuous.ODESimulation`. + """ + ), + ] + ) + return + + +@app.cell +def conservation_check(mo, np, trajectories, w_val): + _max_H = 0.0 + _max_p = 0.0 + for _res in trajectories: + _x1s = _res.state_array("x1") + _x2s = _res.state_array("x2") + _p1s = _res.state_array("p1") + _p2s = _res.state_array("p2") + for _j in range(len(_res)): + _sigma = _p2s[_j] * _x1s[_j] - _p1s[_j] * _x2s[_j] + _norm_p = np.sqrt(_p1s[_j] ** 2 + _p2s[_j] ** 2) + _H = -abs(_sigma) + w_val * _norm_p - _p2s[_j] + 1.0 + _max_H = max(_max_H, abs(_H)) + + _p_norms = [_p1s[_j] ** 2 + _p2s[_j] ** 2 for _j in range(len(_res))] + _p0 = _p_norms[0] + _max_p = max(_max_p, max(abs(_pn - _p0) for _pn in _p_norms)) + + mo.md( + rf""" + ## Verification + + | Conservation law | Max drift | Tolerance | + |-----------------|-----------|-----------| + | $H^* \approx 0$ | ${_max_H:.2e}$ | $< 10^{{-6}}$ | + | $\|\|p\|\|^2$ constant | ${_max_p:.2e}$ | $< 10^{{-8}}$ | + + { + "All conservation laws satisfied." + if _max_H < 1e-6 and _max_p < 1e-8 + else "**WARNING:** Conservation law violation detected." + } + """ + ) + return + + +@app.cell +def references(mo): + mo.md( + r""" + ## References + + - R. Isaacs, *Games of Pursuit*, RAND Corporation P-257 (1951) + - R. Isaacs, *Differential Games*, John Wiley & Sons (1965), pp. 297--350 + - A.W. Merz, *The Homicidal Chauffeur -- a Differential Game*, + PhD Thesis, Stanford (1971) + - V.S. Patsko & V.L. Turova, "Homicidal Chauffeur Game: History and + Modern Studies," *Advances in Dynamic Games*, ISDG Vol. 11 (2011) + - [mzargham/hc-marimo](https://github.com/mzargham/hc-marimo) -- + reference SymPy implementation + - [gds-core](https://github.com/BlockScience/gds-core) -- + GDS ecosystem (`gds-continuous` for ODE integration) + """ + ) + return + + +if __name__ == "__main__": + app.run() diff --git a/packages/gds-examples/continuous/homicidal_chauffeur/test_homicidal_chauffeur.py b/packages/gds-examples/continuous/homicidal_chauffeur/test_homicidal_chauffeur.py new file mode 100644 index 0000000..d2abde0 --- /dev/null +++ b/packages/gds-examples/continuous/homicidal_chauffeur/test_homicidal_chauffeur.py @@ -0,0 +1,235 @@ +"""Tests for the Homicidal Chauffeur example. + +Verifies symbolic derivation, numerical integration, conservation laws, +and backward reachable set computation against known results from +Isaacs (1965) and Merz (1971). + +Test IDs trace to mzargham/hc-marimo verification suite: + HC-R1: Reduced dynamics structure + HC-T1: Lambdified vs hand-coded consistency + HC-T2: Hamiltonian conservation + HC-T3: Costate norm conservation + HC-T4: Forward/backward capture round-trip + HC-T6: Stationary evader straight-line capture + HC-T7: Usable part boundary + HC-ISO: Isochrone computation +""" + +from __future__ import annotations + +import math +from typing import Any + +import pytest + +from continuous.homicidal_chauffeur.model import ( + build_backward_simulation, + compute_isochrone, + costate_norm_sq, + hamiltonian_star, + hc_forward, + terminal_conditions, +) +from gds_continuous import ODEModel, ODESimulation + +# --------------------------------------------------------------------------- +# HC-R1: Symbolic derivation matches hand-coded +# --------------------------------------------------------------------------- + + +class TestSymbolicDerivation: + """Verify symbolic RHS matches hand-coded at random points.""" + + def test_lambdified_vs_handcoded(self) -> None: + """HC-T1: Lambdified and hand-coded RHS agree.""" + pytest.importorskip("sympy") + from continuous.homicidal_chauffeur.model import derive_optimal_rhs + + rhs_sym, _ = derive_optimal_rhs() + + import numpy as np + + rng = np.random.default_rng(42) + for _ in range(50): + x1 = rng.uniform(-5, 5) + x2 = rng.uniform(-5, 5) + angle = rng.uniform(0, 2 * np.pi) + norm_p = rng.uniform(0.1, 3.0) + p1 = norm_p * np.cos(angle) + p2 = norm_p * np.sin(angle) + w_val = rng.uniform(0.05, 0.5) + + state = [x1, x2, p1, p2] + params = {"w": w_val} + + hand = hc_forward(0.0, state, params) + lamb = rhs_sym(0.0, state, params) + + for j in range(4): + assert abs(hand[j] - lamb[j]) < 1e-10, ( + f"Component {j} mismatch at state={state}, w={w_val}" + ) + + +# --------------------------------------------------------------------------- +# HC-T2 / HC-T3: Conservation laws +# --------------------------------------------------------------------------- + + +class TestConservationLaws: + """Hamiltonian and costate norm conservation along trajectories.""" + + @pytest.fixture + def trajectory(self) -> list[dict[str, float]]: + """Integrate backward from capture circle and return state dicts.""" + sim = build_backward_simulation( + alpha=math.pi / 2, w=0.25, ell_tilde=0.5, t_final=10.0 + ) + results = sim.run() + rows = results.to_list() + return [ + {"x1": r["x1"], "x2": r["x2"], "p1": r["p1"], "p2": r["p2"]} for r in rows + ] + + def test_hamiltonian_conserved(self, trajectory: list[dict[str, float]]) -> None: + """HC-T2: H* ~ 0 along optimal trajectories.""" + h_vals = [hamiltonian_star(s, w=0.25) for s in trajectory] + max_drift = max(abs(h) for h in h_vals) + assert max_drift < 1e-6, f"H* drift = {max_drift}" + + def test_costate_norm_conserved(self, trajectory: list[dict[str, float]]) -> None: + """HC-T3: ||p||^2 conserved.""" + norms = [costate_norm_sq(s) for s in trajectory] + initial = norms[0] + max_drift = max(abs(n - initial) for n in norms) + assert max_drift < 1e-8, f"||p||^2 drift = {max_drift}" + + +# --------------------------------------------------------------------------- +# HC-T4: Capture round-trip +# --------------------------------------------------------------------------- + + +class TestCaptureRoundTrip: + """Forward integration from a backward-computed state reaches capture.""" + + def test_round_trip(self) -> None: + """HC-T4: backward -> forward returns to capture circle.""" + w_val = 0.25 + ell_tilde = 0.5 + + # Backward from capture circle + sim_back = build_backward_simulation( + alpha=math.pi / 2, w=w_val, ell_tilde=ell_tilde, t_final=5.0 + ) + results_back = sim_back.run() + n = len(results_back) + y0_far = { + "x1": results_back.state_array("x1")[n - 1], + "x2": results_back.state_array("x2")[n - 1], + "p1": results_back.state_array("p1")[n - 1], + "p2": results_back.state_array("p2")[n - 1], + } + + # Forward from that state + model_fwd = ODEModel( + state_names=["x1", "x2", "p1", "p2"], + initial_state=y0_far, + rhs=hc_forward, + params={"w": [w_val]}, + ) + sim_fwd = ODESimulation( + model=model_fwd, + t_span=(0.0, 5.0), + solver="RK45", + rtol=1e-10, + atol=1e-12, + ) + results_fwd = sim_fwd.run() + + x1_f = results_fwd.state_array("x1")[-1] + x2_f = results_fwd.state_array("x2")[-1] + dist = math.sqrt(x1_f**2 + x2_f**2) + assert dist == pytest.approx(ell_tilde, abs=0.05) + + +# --------------------------------------------------------------------------- +# HC-T6: Stationary evader +# --------------------------------------------------------------------------- + + +class TestStationaryEvader: + """w=0: pursuer drives straight to capture.""" + + def test_straight_line_capture(self) -> None: + """HC-T6: With w=0, capture time = initial distance.""" + + def hc_w0(t: float, y: list[float], params: dict[str, Any]) -> list[float]: + return [0.0, -1.0, 0.0, 0.0] + + d = 3.0 + model = ODEModel( + state_names=["x1", "x2", "p1", "p2"], + initial_state={"x1": 0.0, "x2": d, "p1": 0.0, "p2": 1.0}, + rhs=hc_w0, + params={"w": [0.0]}, + ) + sim = ODESimulation(model=model, t_span=(0.0, d), t_eval=[0.0, d]) + results = sim.run() + assert results.state_array("x2")[-1] == pytest.approx(0.0, abs=1e-6) + + +# --------------------------------------------------------------------------- +# HC-T7: Usable part boundary +# --------------------------------------------------------------------------- + + +class TestUsablePart: + """The usable part requires sin(alpha) > w.""" + + def test_usable_part_valid(self) -> None: + """HC-T7: alpha=pi/2, w=0.25 => sin(alpha)=1 > 0.25 => usable.""" + ic = terminal_conditions(math.pi / 2, w=0.25, ell_tilde=0.5) + # lambda should be positive (characteristic of usable part) + # lambda = -1 / (ell * (w - sin(alpha))) = -1 / (0.5 * (0.25 - 1)) + # = -1 / (-0.375) = 2.667 + lam = -1.0 / (0.5 * (0.25 - 1.0)) + assert lam > 0 + # p should be parallel to x (transversality) + assert ic["p1"] == pytest.approx(lam * ic["x1"]) + assert ic["p2"] == pytest.approx(lam * ic["x2"]) + + def test_usable_part_boundary(self) -> None: + """At sin(alpha) = w, lambda diverges (boundary of usable part).""" + w = 0.25 + alpha_boundary = math.asin(w) + # denominator -> 0 + denom = 0.5 * (w - math.sin(alpha_boundary)) + assert abs(denom) < 1e-10 + + +# --------------------------------------------------------------------------- +# Isochrone (backward reachable set) +# --------------------------------------------------------------------------- + + +class TestIsochrone: + """Backward reachable set computation.""" + + def test_isochrone_produces_points(self) -> None: + """HC-ISO: Isochrone returns boundary points.""" + points = compute_isochrone(w=0.25, ell_tilde=0.5, t_final=3.0, n_rays=10) + assert len(points) == 10 + # All points should be farther from origin than capture radius + for x1, x2 in points: + dist = math.sqrt(x1**2 + x2**2) + assert dist > 0.5, f"Point ({x1}, {x2}) inside capture circle" + + def test_isochrone_grows_with_time(self) -> None: + """Larger t_final => points farther from origin.""" + pts_short = compute_isochrone(w=0.25, ell_tilde=0.5, t_final=2.0, n_rays=5) + pts_long = compute_isochrone(w=0.25, ell_tilde=0.5, t_final=5.0, n_rays=5) + + avg_short = sum(math.sqrt(x**2 + y**2) for x, y in pts_short) / len(pts_short) + avg_long = sum(math.sqrt(x**2 + y**2) for x, y in pts_long) / len(pts_long) + assert avg_long > avg_short diff --git a/packages/gds-examples/games/crosswalk/analysis.py b/packages/gds-examples/games/crosswalk/analysis.py new file mode 100644 index 0000000..3b2722b --- /dev/null +++ b/packages/gds-examples/games/crosswalk/analysis.py @@ -0,0 +1,227 @@ +"""Crosswalk Problem — Dynamical Analysis with gds-analysis. + +Demonstrates reachability, admissibility enforcement, and metric +computation on the crosswalk Markov system. + +Run: uv run python packages/gds-examples/games/crosswalk/analysis.py + +This script shows how gds-analysis bridges the structural spec +(gds-framework) to runtime execution (gds-sim), enabling: + 1. Reachable set computation — R(x) from each state + 2. Reachability graph — multi-step state transitions + 3. Configuration space — strongly connected components + 4. Admissibility constraint enforcement at runtime + 5. State metric computation on trajectories +""" + +import sys +from pathlib import Path + +# Allow running as: uv run python packages/gds-examples/games/crosswalk/analysis.py +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from gds_analysis.adapter import spec_to_model +from gds_analysis.metrics import trajectory_distances +from gds_analysis.reachability import ( + configuration_space, + reachable_graph, + reachable_set, +) + +from crosswalk.model import build_spec +from gds.constraints import AdmissibleInputConstraint, StateMetric +from gds_sim import Simulation + +# ── Behavioral functions ──────────────────────────────────────────── + + +def observe_policy(state, params, **kw): + return { + "traffic_state": state.get("Street.traffic_state", 1), + "luck": 1, + } + + +def pedestrian_policy(state, params, **kw): + return {"cross": 1, "position": params.get("crosswalk_location", 0.5)} + + +def safety_policy(state, params, **kw): + return {"safe_crossing": 1, "cross": 1} + + +def transition_suf(state, params, *, signal=None, **kw): + """Markov transition: Flowing/Stopped/Accident.""" + signal = signal or {} + current = state.get("Street.traffic_state", 1) + cross = signal.get("cross", 0) + safe = signal.get("safe_crossing", 1) + + if current == -1: + return "Street.traffic_state", 0 # Accident → Stopped (recovery) + if cross == 0: + return "Street.traffic_state", 1 # Don't cross → Flowing + elif safe == 1: + return "Street.traffic_state", 0 # Safe cross → Stopped + else: + return "Street.traffic_state", -1 # Jaywalk + bad luck → Accident + + +STATE_NAMES = {-1: "Accident", 0: "Stopped", 1: "Flowing"} + + +def main(): + # ── 1. Build spec with structural annotations ───────────────── + spec = build_spec() + + spec.register_admissibility( + AdmissibleInputConstraint( + name="valid_traffic_state", + boundary_block="Observe Traffic", + depends_on=[("Street", "traffic_state")], + constraint=lambda state, signal: ( + signal.get("traffic_state", 0) in {-1, 0, 1} + ), + description="Traffic state must be in {-1, 0, +1}", + ) + ) + spec.register_state_metric( + StateMetric( + name="state_change", + variables=[("Street", "traffic_state")], + metric_type="absolute", + distance=lambda a, b: abs( + a.get("Street.traffic_state", 0) - b.get("Street.traffic_state", 0) + ), + ) + ) + + # ── 2. Build executable model ───────────────────────────────── + model = spec_to_model( + spec, + policies={ + "Observe Traffic": observe_policy, + "Pedestrian Decision": pedestrian_policy, + "Safety Check": safety_policy, + }, + sufs={"Traffic Transition": transition_suf}, + initial_state={"Street.traffic_state": 1}, + params={"crosswalk_location": [0.5]}, + enforce_constraints=True, + ) + + # ── 3. Reachable set from each state ────────────────────────── + print("=" * 60) + print("REACHABLE SET ANALYSIS") + print("=" * 60) + + input_samples = [ + {"cross": 0, "safe_crossing": 1}, # Don't cross + {"cross": 1, "safe_crossing": 1}, # Cross safely + {"cross": 1, "safe_crossing": 0}, # Jaywalk (bad luck) + ] + + for start_state in [1, 0, -1]: + state = {"Street.traffic_state": start_state} + reached = reachable_set( + spec, + model, + state, + input_samples=input_samples, + state_key="Street.traffic_state", + ) + reached_vals = sorted({r["Street.traffic_state"] for r in reached}) + reached_names = [STATE_NAMES[v] for v in reached_vals] + print(f" R({STATE_NAMES[start_state]:>8}) = {{{', '.join(reached_names)}}}") + + # ── 4. Reachability graph ───────────────────────────────────── + print() + print("=" * 60) + print("REACHABILITY GRAPH (depth=2)") + print("=" * 60) + + initials = [{"Street.traffic_state": s} for s in [1, 0, -1]] + graph = reachable_graph( + spec, + model, + initials, + input_samples=input_samples, + max_depth=2, + state_key="Street.traffic_state", + ) + for src, dsts in sorted(graph.items()): + src_name = STATE_NAMES.get(src[1], str(src)) + dst_names = sorted({STATE_NAMES.get(d[1], str(d)) for d in dsts}) + print(f" {src_name:>8} -> {{{', '.join(dst_names)}}}") + + # ── 5. Configuration space (SCCs) ───────────────────────────── + print() + print("=" * 60) + print("CONFIGURATION SPACE (strongly connected components)") + print("=" * 60) + + sccs = configuration_space(graph) + for i, scc in enumerate(sccs): + names = sorted(STATE_NAMES.get(s[1], str(s)) for s in scc) + mutual = "mutually reachable" if len(scc) > 1 else "isolated" + print(f" SCC {i + 1}: {{{', '.join(names)}}} ({mutual})") + + if sccs: + largest = sccs[0] + print( + f"\n X_C (configuration space) = " + f"{{{', '.join(sorted(STATE_NAMES.get(s[1], str(s)) for s in largest))}}}" + ) + + # ── 6. Simulate and compute metrics ─────────────────────────── + print() + print("=" * 60) + print("TRAJECTORY SIMULATION (20 timesteps)") + print("=" * 60) + + sim = Simulation(model=model, timesteps=20, runs=1) + trajectory = sim.run().to_list() + + for t, row in enumerate(trajectory[:10]): + ts = row.get("Street.traffic_state", "?") + print(f" t={t:2d}: {STATE_NAMES.get(ts, str(ts))}") + if len(trajectory) > 10: + print(f" ... ({len(trajectory) - 10} more steps)") + + # ── 7. State metric distances ───────────────────────────────── + print() + print("=" * 60) + print("STATE METRIC: |delta(traffic_state)|") + print("=" * 60) + + distances = trajectory_distances(spec, trajectory) + dists = distances["state_change"] + n_changes = sum(1 for d in dists if d > 0) + print(f" Total transitions: {len(dists)}") + print(f" State changes: {n_changes}") + print(f" Max distance: {max(dists) if dists else 0}") + print(f" Mean distance: {sum(dists) / len(dists):.2f}" if dists else "") + + # ── Key analytical result ───────────────────────────────────── + print() + print("=" * 60) + print("KEY RESULT") + print("=" * 60) + # Check: Flowing unreachable from Accident? + accident_reached = reachable_set( + spec, + model, + {"Street.traffic_state": -1}, + input_samples=input_samples, + state_key="Street.traffic_state", + ) + accident_reachable = {r["Street.traffic_state"] for r in accident_reached} + if 1 not in accident_reachable: + print(" VERIFIED: Flowing (+1) is unreachable from Accident (-1) in one step.") + print(" This matches Paper Def 4.1: R(Accident) = {Accident, Stopped}") + else: + print(" WARNING: Flowing IS reachable from Accident (unexpected)") + + +if __name__ == "__main__": + main() diff --git a/packages/gds-examples/pyproject.toml b/packages/gds-examples/pyproject.toml index 89cd0af..b0743c3 100644 --- a/packages/gds-examples/pyproject.toml +++ b/packages/gds-examples/pyproject.toml @@ -35,6 +35,8 @@ dependencies = [ "gds-stockflow>=0.1.0", "gds-games>=0.3.0", "gds-software>=0.1.0", + "gds-continuous>=0.1.0", + "gds-symbolic>=0.1.0", ] [project.urls] @@ -49,6 +51,8 @@ gds-control = { workspace = true } gds-stockflow = { workspace = true } gds-games = { workspace = true } gds-software = { workspace = true } +gds-continuous = { workspace = true } +gds-symbolic = { workspace = true } [build-system] requires = ["hatchling"] diff --git a/packages/gds-examples/stockflow/sir_epidemic/analysis.py b/packages/gds-examples/stockflow/sir_epidemic/analysis.py new file mode 100644 index 0000000..1407f48 --- /dev/null +++ b/packages/gds-examples/stockflow/sir_epidemic/analysis.py @@ -0,0 +1,214 @@ +"""SIR Epidemic — Dynamical Analysis with gds-analysis. + +Demonstrates reachability and metric computation on a continuous-state +epidemiological model. + +Run: uv run python packages/gds-examples/stockflow/sir_epidemic/analysis.py + +Shows: + 1. Structural spec with AdmissibleInputConstraint + StateMetric + 2. Adapter: GDSSpec → gds_sim.Model + 3. Trajectory simulation with constraint enforcement + 4. Population distance metrics over time + 5. Reachable states from initial conditions +""" + +import math +import sys +from pathlib import Path + +# Allow standalone execution with correct import paths +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from gds_analysis.adapter import spec_to_model +from gds_analysis.metrics import trajectory_distances +from gds_analysis.reachability import reachable_set + +from gds.constraints import AdmissibleInputConstraint, StateMetric +from gds_sim import Simulation +from sir_epidemic.model import build_spec + +# ── Behavioral functions ──────────────────────────────────────────── + + +def contact_policy(state, params, **kw): + return {"contact_rate": params.get("contact_rate", 5.0)} + + +def infection_policy(state, params, **kw): + """SIR dynamics: dS/dt = -beta*S*I/N, dI/dt = beta*S*I/N - gamma*I.""" + s = state.get("Susceptible.count", 0.0) + i = state.get("Infected.count", 0.0) + r = state.get("Recovered.count", 0.0) + n = s + i + r or 1.0 + beta = params.get("beta", 0.3) + gamma = params.get("gamma", 0.1) + + new_infections = beta * s * i / n + recoveries = gamma * i + return { + "delta_s": -new_infections, + "delta_i": new_infections - recoveries, + "delta_r": recoveries, + } + + +def suf_susceptible(state, params, *, signal=None, **kw): + signal = signal or {} + return "Susceptible.count", max( + 0.0, state.get("Susceptible.count", 0.0) + signal.get("delta_s", 0.0) + ) + + +def suf_infected(state, params, *, signal=None, **kw): + signal = signal or {} + return "Infected.count", max( + 0.0, state.get("Infected.count", 0.0) + signal.get("delta_i", 0.0) + ) + + +def suf_recovered(state, params, *, signal=None, **kw): + signal = signal or {} + return "Recovered.count", max( + 0.0, state.get("Recovered.count", 0.0) + signal.get("delta_r", 0.0) + ) + + +def main(): + # ── 1. Build spec with structural annotations ───────────────── + spec = build_spec() + + spec.register_admissibility( + AdmissibleInputConstraint( + name="contact_rate_positive", + boundary_block="Contact Process", + depends_on=[], + constraint=lambda state, signal: signal.get("contact_rate", 0) > 0, + description="Contact rate must be positive", + ) + ) + spec.register_state_metric( + StateMetric( + name="population_distance", + variables=[ + ("Susceptible", "count"), + ("Infected", "count"), + ("Recovered", "count"), + ], + metric_type="euclidean", + distance=lambda a, b: math.sqrt( + sum((a.get(k, 0) - b.get(k, 0)) ** 2 for k in set(a) | set(b)) + ), + ) + ) + + # ── 2. Build executable model ───────────────────────────────── + model = spec_to_model( + spec, + policies={ + "Contact Process": contact_policy, + "Infection Policy": infection_policy, + }, + sufs={ + "Update Susceptible": suf_susceptible, + "Update Infected": suf_infected, + "Update Recovered": suf_recovered, + }, + initial_state={ + "Susceptible.count": 999.0, + "Infected.count": 1.0, + "Recovered.count": 0.0, + }, + params={"beta": [0.3], "gamma": [0.1], "contact_rate": [5.0]}, + enforce_constraints=True, + ) + + # ── 3. Simulate ─────────────────────────────────────────────── + print("=" * 60) + print("SIR EPIDEMIC SIMULATION (50 timesteps)") + print("=" * 60) + + sim = Simulation(model=model, timesteps=50, runs=1) + trajectory = sim.run().to_list() + + peak_i = 0.0 + peak_t = 0 + for t, row in enumerate(trajectory): + s = row.get("Susceptible.count", 0.0) + i = row.get("Infected.count", 0.0) + r = row.get("Recovered.count", 0.0) + if i > peak_i: + peak_i = i + peak_t = t + if t % 10 == 0 or t == len(trajectory) - 1: + print(f" t={t:3d}: S={s:7.1f} I={i:7.1f} R={r:7.1f}") + + print(f"\n Peak infected: {peak_i:.1f} at t={peak_t}") + + # ── 4. Population conservation check ────────────────────────── + print() + print("=" * 60) + print("POPULATION CONSERVATION") + print("=" * 60) + + violations = 0 + for row in trajectory: + total = ( + row.get("Susceptible.count", 0) + + row.get("Infected.count", 0) + + row.get("Recovered.count", 0) + ) + if abs(total - 1000.0) > 1e-6: + violations += 1 + + if violations == 0: + print(" VERIFIED: S + I + R = 1000 at every timestep") + else: + print(f" WARNING: {violations} timesteps violated conservation") + + # ── 5. Trajectory distances ─────────────────────────────────── + print() + print("=" * 60) + print("STATE METRIC: Euclidean distance between successive states") + print("=" * 60) + + distances = trajectory_distances(spec, trajectory) + dists = distances["population_distance"] + print(f" Transitions: {len(dists)}") + print(f" Max distance: {max(dists):.2f}") + print(f" Mean distance: {sum(dists) / len(dists):.2f}") + print(f" Final distance: {dists[-1]:.2f} (convergence indicator)") + + # ── 6. Reachable set from initial state ─────────────────────── + print() + print("=" * 60) + print("REACHABLE SET from (S=999, I=1, R=0)") + print("=" * 60) + + initial = { + "Susceptible.count": 999.0, + "Infected.count": 1.0, + "Recovered.count": 0.0, + } + # Vary the infection dynamics by overriding policy outputs + samples = [ + {"delta_s": -d, "delta_i": d - 0.1, "delta_r": 0.1} + for d in [0.0, 0.3, 1.0, 5.0, 10.0] + ] + reached = reachable_set( + spec, + model, + initial, + input_samples=samples, + state_key="Infected.count", + ) + print(f" {len(reached)} distinct next states found:") + for r in sorted(reached, key=lambda x: x.get("Infected.count", 0)): + s = r.get("Susceptible.count", 0) + i = r.get("Infected.count", 0) + rc = r.get("Recovered.count", 0) + print(f" S={s:.1f} I={i:.1f} R={rc:.1f}") + + +if __name__ == "__main__": + main() diff --git a/packages/gds-symbolic/README.md b/packages/gds-symbolic/README.md new file mode 100644 index 0000000..e30bea8 --- /dev/null +++ b/packages/gds-symbolic/README.md @@ -0,0 +1,37 @@ +# gds-symbolic + +Symbolic math bridge for the GDS ecosystem — compiles SymPy expressions +into plain Python callables for use with `gds-continuous`. + +## Installation + +```bash +uv add gds-symbolic[sympy] +``` + +## Quick Start + +```python +from gds_control.dsl.elements import State, Input, Sensor, Controller +from gds_symbolic import SymbolicControlModel, StateEquation + +model = SymbolicControlModel( + name="damped_oscillator", + states=[State(name="x"), State(name="v")], + inputs=[Input(name="force")], + sensors=[Sensor(name="position", observes=["x"])], + controllers=[Controller(name="actuator", reads=["position", "force"], drives=["x", "v"])], + state_equations=[ + StateEquation(state_name="x", expr_str="v"), + StateEquation(state_name="v", expr_str="-k*x - c*v + force"), + ], + symbolic_params=["k", "c"], +) + +# Compile to plain callable (no SymPy at runtime) +ode_fn, state_order = model.to_ode_function() + +# Linearize at origin +lin = model.linearize(x0=[0.0, 0.0], u0=[0.0]) +print(lin.A) # [[0, 1], [-k, -c]] evaluated at operating point +``` diff --git a/packages/gds-symbolic/gds_symbolic/__init__.py b/packages/gds-symbolic/gds_symbolic/__init__.py new file mode 100644 index 0000000..6d6ab29 --- /dev/null +++ b/packages/gds-symbolic/gds_symbolic/__init__.py @@ -0,0 +1,16 @@ +"""gds-symbolic: Symbolic math bridge for the GDS ecosystem.""" + +__version__ = "0.1.0" + +from gds_symbolic.elements import OutputEquation, StateEquation +from gds_symbolic.errors import SymbolicError +from gds_symbolic.linearize import LinearizedSystem +from gds_symbolic.model import SymbolicControlModel + +__all__ = [ + "LinearizedSystem", + "OutputEquation", + "StateEquation", + "SymbolicControlModel", + "SymbolicError", +] diff --git a/packages/gds-symbolic/gds_symbolic/_compat.py b/packages/gds-symbolic/gds_symbolic/_compat.py new file mode 100644 index 0000000..a31f442 --- /dev/null +++ b/packages/gds-symbolic/gds_symbolic/_compat.py @@ -0,0 +1,12 @@ +"""Optional dependency guards.""" + + +def require_sympy() -> None: + """Raise ImportError with install instructions if sympy is absent.""" + try: + import sympy # noqa: F401 + except ImportError as exc: + raise ImportError( + "sympy is required for symbolic operations. " + "Install with: uv add gds-symbolic[sympy]" + ) from exc diff --git a/packages/gds-symbolic/gds_symbolic/compile.py b/packages/gds-symbolic/gds_symbolic/compile.py new file mode 100644 index 0000000..df30116 --- /dev/null +++ b/packages/gds-symbolic/gds_symbolic/compile.py @@ -0,0 +1,78 @@ +"""Compile symbolic equations to plain Python callables via sympy.lambdify.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from gds_symbolic._compat import require_sympy + +if TYPE_CHECKING: + from gds_continuous.types import ODEFunction + + from gds_symbolic.model import SymbolicControlModel + + +def compile_to_ode( + model: SymbolicControlModel, +) -> tuple[ODEFunction, list[str]]: + """Compile a SymbolicControlModel's state equations to an ODEFunction. + + Returns + ------- + ode_fn : ODEFunction + Plain Python callable with signature ``(t, y, params) -> dy/dt``. + No SymPy objects at runtime — fully lambdified. + state_order : list[str] + State variable names in the vector order used by ``ode_fn``. + """ + require_sympy() + import sympy + + state_order = [s.name for s in model.states] + input_names = [i.name for i in model.inputs] + + # Build symbol table + state_syms = {name: sympy.Symbol(name) for name in state_order} + input_syms = {name: sympy.Symbol(name) for name in input_names} + param_syms = {name: sympy.Symbol(name) for name in model.symbolic_params} + + all_syms = {**state_syms, **input_syms, **param_syms} + + # Parse expressions using safe parser (no eval, no builtins) + from sympy.parsing.sympy_parser import parse_expr + + eq_map: dict[str, Any] = {} + for eq in model.state_equations: + expr = parse_expr(eq.expr_str, local_dict=all_syms) + eq_map[eq.state_name] = expr + + # Build ordered RHS vector + rhs_exprs = [] + for name in state_order: + if name in eq_map: + rhs_exprs.append(eq_map[name]) + else: + # State with no equation: dx/dt = 0 + rhs_exprs.append(sympy.Integer(0)) + + # Lambdify: args are ordered state vars + input vars + param vars + ordered_symbols = ( + [state_syms[n] for n in state_order] + + [input_syms[n] for n in input_names] + + [param_syms[n] for n in model.symbolic_params] + ) + rhs_lambda = sympy.lambdify(ordered_symbols, rhs_exprs, modules="math") + + n_states = len(state_order) + + def ode_fn(t: float, y: list[float], params: dict[str, Any]) -> list[float]: + # Unpack inputs from params dict + input_vals = [params.get(name, 0.0) for name in input_names] + param_vals = [params.get(name, 0.0) for name in model.symbolic_params] + args = list(y[:n_states]) + input_vals + param_vals + result = rhs_lambda(*args) + if isinstance(result, (int, float)): + return [float(result)] + return [float(v) for v in result] + + return ode_fn, state_order diff --git a/packages/gds-symbolic/gds_symbolic/elements.py b/packages/gds-symbolic/gds_symbolic/elements.py new file mode 100644 index 0000000..14972f4 --- /dev/null +++ b/packages/gds-symbolic/gds_symbolic/elements.py @@ -0,0 +1,30 @@ +"""Symbolic equation elements for annotating control models with ODEs.""" + +from __future__ import annotations + +from pydantic import BaseModel + + +class StateEquation(BaseModel, frozen=True): + """Symbolic ODE right-hand side for a single state variable. + + Declares: dx_i/dt = expr, where ``expr_str`` is a SymPy-parseable + string (e.g. ``"-k*x + u"``). + + The string form is R1-serializable. The sympy.Expr object is R3 — + reconstructed at lambdify time via ``sympy.sympify``. + """ + + state_name: str + expr_str: str + + +class OutputEquation(BaseModel, frozen=True): + """Symbolic observation equation: y_i = h(x, u). + + Maps a sensor's output to a symbolic expression over + state variables and inputs. + """ + + sensor_name: str + expr_str: str diff --git a/packages/gds-symbolic/gds_symbolic/errors.py b/packages/gds-symbolic/gds_symbolic/errors.py new file mode 100644 index 0000000..09f0bcf --- /dev/null +++ b/packages/gds-symbolic/gds_symbolic/errors.py @@ -0,0 +1,7 @@ +"""Errors for gds-symbolic.""" + +from gds_control.dsl.errors import CSError + + +class SymbolicError(CSError): + """Raised when symbolic model construction or compilation fails.""" diff --git a/packages/gds-symbolic/gds_symbolic/linearize.py b/packages/gds-symbolic/gds_symbolic/linearize.py new file mode 100644 index 0000000..2d6ae2f --- /dev/null +++ b/packages/gds-symbolic/gds_symbolic/linearize.py @@ -0,0 +1,139 @@ +"""Linearization: compute Jacobian matrices at an operating point.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from gds_symbolic._compat import require_sympy + +if TYPE_CHECKING: + from gds_symbolic.model import SymbolicControlModel + + +@dataclass(frozen=True) +class LinearizedSystem: + """State-space matrices (A, B, C, D) at an operating point. + + All matrices are lists-of-lists (no numpy dependency required). + """ + + A: list[list[float]] + B: list[list[float]] + C: list[list[float]] + D: list[list[float]] + x0: list[float] + u0: list[float] + state_names: list[str] = field(default_factory=list) + input_names: list[str] = field(default_factory=list) + output_names: list[str] = field(default_factory=list) + + +def linearize( + model: SymbolicControlModel, + x0: list[float], + u0: list[float], + param_values: dict[str, float] | None = None, +) -> LinearizedSystem: + """Compute linearization at operating point (x0, u0). + + Computes Jacobians of the state equations w.r.t. states (A) + and inputs (B), and output equations w.r.t. states (C) and + inputs (D), all evaluated at the given operating point. + + Parameters + ---------- + model : SymbolicControlModel + x0 : list[float] + Operating point state values (ordered by model.states). + u0 : list[float] + Operating point input values (ordered by model.inputs). + param_values : dict[str, float] | None + Parameter values for substitution. Defaults to 0.0 for + any unspecified parameter. + """ + require_sympy() + import sympy + + param_values = param_values or {} + + state_names = [s.name for s in model.states] + input_names = [i.name for i in model.inputs] + + state_syms = {name: sympy.Symbol(name) for name in state_names} + input_syms = {name: sympy.Symbol(name) for name in input_names} + param_syms = {name: sympy.Symbol(name) for name in model.symbolic_params} + + all_syms: dict[str, Any] = {**state_syms, **input_syms, **param_syms} + + # Build substitution dict for operating point + params + subs: dict[Any, float] = {} + for i, name in enumerate(state_names): + subs[state_syms[name]] = x0[i] + for i, name in enumerate(input_names): + subs[input_syms[name]] = u0[i] + for name in model.symbolic_params: + subs[param_syms[name]] = param_values.get(name, 0.0) + + # Parse state equations using safe parser (no eval, no builtins) + from sympy.parsing.sympy_parser import parse_expr + + eq_map: dict[str, Any] = {} + for eq in model.state_equations: + eq_map[eq.state_name] = parse_expr(eq.expr_str, local_dict=all_syms) + + # A matrix: df_i/dx_j + A = _jacobian(eq_map, state_names, state_syms, subs) + + # B matrix: df_i/du_j + B = _jacobian(eq_map, state_names, input_syms, subs, col_names=input_names) + + # Parse output equations + out_map: dict[str, Any] = {} + output_names: list[str] = [] + for eq in model.output_equations: + out_map[eq.sensor_name] = parse_expr(eq.expr_str, local_dict=all_syms) + output_names.append(eq.sensor_name) + + # C matrix: dh_i/dx_j + C = _jacobian(out_map, output_names, state_syms, subs, col_names=state_names) + + # D matrix: dh_i/du_j + D = _jacobian(out_map, output_names, input_syms, subs, col_names=input_names) + + return LinearizedSystem( + A=A, + B=B, + C=C, + D=D, + x0=list(x0), + u0=list(u0), + state_names=state_names, + input_names=input_names, + output_names=output_names, + ) + + +def _jacobian( + eq_map: dict[str, Any], + row_names: list[str], + col_syms: dict[str, Any], + subs: dict[Any, float], + col_names: list[str] | None = None, +) -> list[list[float]]: + """Compute a Jacobian matrix and evaluate at substitution point.""" + import sympy + + if col_names is None: + col_names = list(col_syms.keys()) + + rows: list[list[float]] = [] + for rname in row_names: + expr = eq_map.get(rname, sympy.Integer(0)) + row: list[float] = [] + for cname in col_names: + deriv = sympy.diff(expr, col_syms[cname]) + val = float(deriv.subs(subs)) + row.append(val) + rows.append(row) + return rows diff --git a/packages/gds-symbolic/gds_symbolic/model.py b/packages/gds-symbolic/gds_symbolic/model.py new file mode 100644 index 0000000..36f6b75 --- /dev/null +++ b/packages/gds-symbolic/gds_symbolic/model.py @@ -0,0 +1,101 @@ +"""SymbolicControlModel — extends ControlModel with symbolic ODEs.""" + +from __future__ import annotations + +from typing import Any, Self + +from gds_control.dsl.model import ControlModel +from pydantic import model_validator + +from gds_symbolic.elements import OutputEquation, StateEquation # noqa: TC001 +from gds_symbolic.errors import SymbolicError + + +class SymbolicControlModel(ControlModel): + """ControlModel extended with symbolic differential equations. + + Preserves all structural GDS semantics from ControlModel. + Adds symbolic ODEs as an annotation layer — they do not affect + ``compile()`` or ``compile_system()`` (those remain structural). + + The ODEs are behavioral (R3) and compile to plain Python callables + via ``to_ode_function()``. + """ + + state_equations: list[StateEquation] = [] # noqa: RUF012 + output_equations: list[OutputEquation] = [] # noqa: RUF012 + symbolic_params: list[str] = [] # noqa: RUF012 + + @model_validator(mode="after") + def _validate_symbolic_structure(self) -> Self: + state_names = {s.name for s in self.states} + sensor_names = {s.name for s in self.sensors} + input_names = {i.name for i in self.inputs} + + errors: list[str] = [] + + # Every state_equation must reference a declared state + for eq in self.state_equations: + if eq.state_name not in state_names: + errors.append( + f"StateEquation references undeclared state '{eq.state_name}'" + ) + + # No duplicate state equations + eq_states = [eq.state_name for eq in self.state_equations] + dupes = {s for s in eq_states if eq_states.count(s) > 1} + if dupes: + errors.append(f"Duplicate state equations for: {sorted(dupes)}") + + # Every output_equation must reference a declared sensor + for eq in self.output_equations: + if eq.sensor_name not in sensor_names: + errors.append( + f"OutputEquation references undeclared sensor '{eq.sensor_name}'" + ) + + # Symbolic params should not collide with state/input names + reserved = state_names | input_names + for p in self.symbolic_params: + if p in reserved: + errors.append( + f"Symbolic param '{p}' conflicts with a state or input name" + ) + + if errors: + raise SymbolicError( + "Symbolic model validation failed:\n" + + "\n".join(f" - {e}" for e in errors) + ) + + return self + + def to_ode_function( + self, + ) -> tuple[Any, list[str]]: + """Compile symbolic ODEs to a plain Python callable. + + Returns + ------- + ode_fn : ODEFunction + Signature: ``(t, y, params) -> dy/dt`` + state_order : list[str] + State variable names in vector order. + """ + from gds_symbolic.compile import compile_to_ode + + return compile_to_ode(self) + + def linearize( + self, + x0: list[float], + u0: list[float], + param_values: dict[str, float] | None = None, + ) -> Any: + """Linearize around an operating point. + + Returns a ``LinearizedSystem`` with A, B, C, D matrices. + """ + from gds_symbolic.linearize import linearize + + return linearize(self, x0, u0, param_values) diff --git a/packages/gds-symbolic/pyproject.toml b/packages/gds-symbolic/pyproject.toml new file mode 100644 index 0000000..3e83969 --- /dev/null +++ b/packages/gds-symbolic/pyproject.toml @@ -0,0 +1,94 @@ +[project] +name = "gds-symbolic" +dynamic = ["version"] +description = "Symbolic math bridge for the GDS ecosystem — SymPy to ODE compilation" +readme = "README.md" +license = "Apache-2.0" +requires-python = ">=3.12" +authors = [ + { name = "Rohan Mehta", email = "rohan@block.science" }, +] +keywords = [ + "generalized-dynamical-systems", + "sympy", + "symbolic-math", + "control-systems", + "gds-framework", +] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Mathematics", + "Typing :: Typed", +] +dependencies = [ + "gds-framework>=0.2.3", + "gds-control>=0.1.0", + "pydantic>=2.10", +] + +[project.optional-dependencies] +sympy = ["sympy>=1.13", "numpy>=1.26"] +continuous = ["gds-continuous>=0.1.0"] + +[project.urls] +Homepage = "https://github.com/BlockScience/gds-core" +Repository = "https://github.com/BlockScience/gds-core" +Documentation = "https://blockscience.github.io/gds-core" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.version] +path = "gds_symbolic/__init__.py" + +[tool.hatch.build.targets.wheel] +packages = ["gds_symbolic"] + +[tool.uv.sources] +gds-framework = { workspace = true } +gds-control = { workspace = true } +gds-continuous = { workspace = true } + +[tool.pytest.ini_options] +testpaths = ["tests"] +addopts = "--import-mode=importlib --cov=gds_symbolic --cov-report=term-missing --no-header -q" + +[tool.coverage.run] +source = ["gds_symbolic"] +omit = ["gds_symbolic/__init__.py"] + +[tool.coverage.report] +fail_under = 80 +show_missing = true +exclude_lines = [ + "if TYPE_CHECKING:", + "pragma: no cover", +] + +[tool.mypy] +strict = true + +[tool.ruff] +target-version = "py312" +line-length = 88 + +[tool.ruff.lint] +select = ["E", "W", "F", "I", "UP", "B", "SIM", "TCH", "RUF"] + +[dependency-groups] +dev = [ + "mypy>=1.13", + "pytest>=8.0", + "pytest-cov>=5.0", + "ruff>=0.8", + "sympy>=1.13", + "numpy>=1.26", +] diff --git a/packages/gds-symbolic/tests/__init__.py b/packages/gds-symbolic/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/packages/gds-symbolic/tests/conftest.py b/packages/gds-symbolic/tests/conftest.py new file mode 100644 index 0000000..92cd689 --- /dev/null +++ b/packages/gds-symbolic/tests/conftest.py @@ -0,0 +1,71 @@ +"""Shared fixtures for gds-symbolic tests.""" + +from __future__ import annotations + +import pytest +from gds_control.dsl.elements import Controller, Input, Sensor, State + +from gds_symbolic import StateEquation, SymbolicControlModel + + +@pytest.fixture +def decay_model() -> SymbolicControlModel: + """Single-state exponential decay: dx/dt = -k*x.""" + return SymbolicControlModel( + name="decay", + states=[State(name="x")], + inputs=[Input(name="u")], + sensors=[Sensor(name="obs", observes=["x"])], + controllers=[ + Controller(name="ctrl", reads=["obs", "u"], drives=["x"]), + ], + state_equations=[ + StateEquation(state_name="x", expr_str="-k*x"), + ], + symbolic_params=["k"], + ) + + +@pytest.fixture +def oscillator_model() -> SymbolicControlModel: + """Harmonic oscillator: dx/dt = v, dv/dt = -omega**2 * x.""" + return SymbolicControlModel( + name="oscillator", + states=[State(name="x"), State(name="v")], + inputs=[Input(name="force")], + sensors=[ + Sensor(name="pos_sensor", observes=["x"]), + Sensor(name="vel_sensor", observes=["v"]), + ], + controllers=[ + Controller( + name="actuator", + reads=["pos_sensor", "vel_sensor", "force"], + drives=["x", "v"], + ), + ], + state_equations=[ + StateEquation(state_name="x", expr_str="v"), + StateEquation(state_name="v", expr_str="-omega**2 * x + force"), + ], + symbolic_params=["omega"], + ) + + +@pytest.fixture +def van_der_pol_model() -> SymbolicControlModel: + """Van der Pol oscillator: dx/dt = v, dv/dt = mu*(1-x**2)*v - x.""" + return SymbolicControlModel( + name="van_der_pol", + states=[State(name="x"), State(name="v")], + inputs=[Input(name="u")], + sensors=[Sensor(name="obs", observes=["x"])], + controllers=[ + Controller(name="ctrl", reads=["obs", "u"], drives=["x", "v"]), + ], + state_equations=[ + StateEquation(state_name="x", expr_str="v"), + StateEquation(state_name="v", expr_str="mu*(1 - x**2)*v - x"), + ], + symbolic_params=["mu"], + ) diff --git a/packages/gds-symbolic/tests/test_compile.py b/packages/gds-symbolic/tests/test_compile.py new file mode 100644 index 0000000..cd4b33b --- /dev/null +++ b/packages/gds-symbolic/tests/test_compile.py @@ -0,0 +1,141 @@ +"""Tests for symbolic ODE compilation (sympify + lambdify).""" + +from __future__ import annotations + +import math + +import pytest + +sympy = pytest.importorskip("sympy") + +from gds_symbolic import SymbolicControlModel # noqa: E402 + + +class TestCompileToODE: + def test_decay_rhs(self, decay_model: SymbolicControlModel) -> None: + ode_fn, state_order = decay_model.to_ode_function() + assert state_order == ["x"] + + # dx/dt = -k*x at x=1, k=2 => -2 + result = ode_fn(0.0, [1.0], {"k": 2.0}) + assert result == [pytest.approx(-2.0)] + + def test_decay_rhs_different_params( + self, decay_model: SymbolicControlModel + ) -> None: + ode_fn, _ = decay_model.to_ode_function() + + # k=0.5, x=4 => -0.5*4 = -2 + result = ode_fn(0.0, [4.0], {"k": 0.5}) + assert result == [pytest.approx(-2.0)] + + def test_oscillator_rhs(self, oscillator_model: SymbolicControlModel) -> None: + ode_fn, state_order = oscillator_model.to_ode_function() + assert state_order == ["x", "v"] + + # At x=1, v=0, omega=1, force=0: dx/dt=0, dv/dt=-1 + result = ode_fn(0.0, [1.0, 0.0], {"omega": 1.0, "force": 0.0}) + assert result[0] == pytest.approx(0.0) + assert result[1] == pytest.approx(-1.0) + + def test_oscillator_with_force( + self, oscillator_model: SymbolicControlModel + ) -> None: + ode_fn, _ = oscillator_model.to_ode_function() + + # At x=0, v=0, omega=1, force=5: dx/dt=0, dv/dt=5 + result = ode_fn(0.0, [0.0, 0.0], {"omega": 1.0, "force": 5.0}) + assert result[0] == pytest.approx(0.0) + assert result[1] == pytest.approx(5.0) + + def test_van_der_pol_rhs(self, van_der_pol_model: SymbolicControlModel) -> None: + ode_fn, state_order = van_der_pol_model.to_ode_function() + assert state_order == ["x", "v"] + + # At x=0, v=1, mu=1: dx/dt=1, dv/dt=mu*(1-0)*1 - 0 = 1 + result = ode_fn(0.0, [0.0, 1.0], {"mu": 1.0}) + assert result[0] == pytest.approx(1.0) + assert result[1] == pytest.approx(1.0) + + def test_missing_equation_defaults_to_zero(self) -> None: + """States without equations get dx/dt = 0.""" + from gds_control.dsl.elements import Controller, Input, Sensor, State + + from gds_symbolic import StateEquation + + m = SymbolicControlModel( + name="partial", + states=[State(name="x"), State(name="y")], + inputs=[Input(name="u")], + sensors=[ + Sensor(name="obs_x", observes=["x"]), + Sensor(name="obs_y", observes=["y"]), + ], + controllers=[ + Controller( + name="ctrl", + reads=["obs_x", "obs_y", "u"], + drives=["x", "y"], + ), + ], + state_equations=[ + StateEquation(state_name="x", expr_str="1.0"), + # y has no equation + ], + ) + ode_fn, _order = m.to_ode_function() + result = ode_fn(0.0, [0.0, 0.0], {}) + assert result == [pytest.approx(1.0), pytest.approx(0.0)] + + +class TestIntegrationWithGDSContinuous: + """Compile symbolic model and run through gds-continuous.""" + + def test_decay_integration(self, decay_model: SymbolicControlModel) -> None: + from gds_continuous import ODEModel, ODESimulation + + ode_fn, state_order = decay_model.to_ode_function() + + model = ODEModel( + state_names=state_order, + initial_state={"x": 1.0}, + rhs=ode_fn, + params={"k": [1.0]}, + ) + sim = ODESimulation( + model=model, + t_span=(0.0, 3.0), + t_eval=[0.0, 1.0, 2.0, 3.0], + ) + results = sim.run() + + x_vals = results.state_array("x") + for i, t in enumerate([0.0, 1.0, 2.0, 3.0]): + expected = math.exp(-t) + assert x_vals[i] == pytest.approx(expected, abs=1e-4) + + def test_oscillator_integration( + self, oscillator_model: SymbolicControlModel + ) -> None: + from gds_continuous import ODEModel, ODESimulation + + ode_fn, state_order = oscillator_model.to_ode_function() + + model = ODEModel( + state_names=state_order, + initial_state={"x": 1.0, "v": 0.0}, + rhs=ode_fn, + params={"omega": [1.0], "force": [0.0]}, + ) + sim = ODESimulation( + model=model, + t_span=(0.0, math.pi), + t_eval=[0.0, math.pi / 2, math.pi], + ) + results = sim.run() + + x_vals = results.state_array("x") + # x(0)=1, x(pi/2)=0, x(pi)=-1 + assert x_vals[0] == pytest.approx(1.0, abs=1e-4) + assert x_vals[1] == pytest.approx(0.0, abs=1e-3) + assert x_vals[2] == pytest.approx(-1.0, abs=1e-3) diff --git a/packages/gds-symbolic/tests/test_elements.py b/packages/gds-symbolic/tests/test_elements.py new file mode 100644 index 0000000..2bc1eb9 --- /dev/null +++ b/packages/gds-symbolic/tests/test_elements.py @@ -0,0 +1,27 @@ +"""Tests for symbolic equation elements.""" + +from __future__ import annotations + +from gds_symbolic.elements import OutputEquation, StateEquation + + +class TestStateEquation: + def test_frozen(self) -> None: + eq = StateEquation(state_name="x", expr_str="-k*x") + assert eq.state_name == "x" + assert eq.expr_str == "-k*x" + + def test_immutable(self) -> None: + eq = StateEquation(state_name="x", expr_str="v") + try: + eq.state_name = "y" # type: ignore[misc] + raise AssertionError("Should be immutable") + except Exception: + pass + + +class TestOutputEquation: + def test_frozen(self) -> None: + eq = OutputEquation(sensor_name="obs", expr_str="x**2") + assert eq.sensor_name == "obs" + assert eq.expr_str == "x**2" diff --git a/packages/gds-symbolic/tests/test_linearize.py b/packages/gds-symbolic/tests/test_linearize.py new file mode 100644 index 0000000..098191e --- /dev/null +++ b/packages/gds-symbolic/tests/test_linearize.py @@ -0,0 +1,96 @@ +"""Tests for Jacobian linearization.""" + +from __future__ import annotations + +import pytest + +sympy = pytest.importorskip("sympy") + +from gds_symbolic import SymbolicControlModel # noqa: E402, TC001 +from gds_symbolic.linearize import LinearizedSystem # noqa: E402 + + +class TestLinearizeDecay: + """dx/dt = -k*x. A = [[-k]], B = [[0]].""" + + def test_a_matrix(self, decay_model: SymbolicControlModel) -> None: + lin = decay_model.linearize(x0=[0.0], u0=[0.0], param_values={"k": 2.0}) + assert isinstance(lin, LinearizedSystem) + assert lin.A == [[-2.0]] + + def test_b_matrix(self, decay_model: SymbolicControlModel) -> None: + lin = decay_model.linearize(x0=[0.0], u0=[0.0], param_values={"k": 1.0}) + # dx/dt = -k*x has no input dependence + assert lin.B == [[0.0]] + + def test_state_names(self, decay_model: SymbolicControlModel) -> None: + lin = decay_model.linearize(x0=[0.0], u0=[0.0]) + assert lin.state_names == ["x"] + assert lin.input_names == ["u"] + + +class TestLinearizeOscillator: + """dx/dt = v, dv/dt = -omega^2*x + force. + + A = [[0, 1], [-omega^2, 0]] + B = [[0], [1]] + """ + + def test_a_matrix(self, oscillator_model: SymbolicControlModel) -> None: + lin = oscillator_model.linearize( + x0=[0.0, 0.0], u0=[0.0], param_values={"omega": 3.0} + ) + assert lin.A[0] == [pytest.approx(0.0), pytest.approx(1.0)] + assert lin.A[1] == [pytest.approx(-9.0), pytest.approx(0.0)] + + def test_b_matrix(self, oscillator_model: SymbolicControlModel) -> None: + lin = oscillator_model.linearize( + x0=[0.0, 0.0], u0=[0.0], param_values={"omega": 1.0} + ) + # df1/d(force) = 0, df2/d(force) = 1 + assert [[pytest.approx(0.0)], [pytest.approx(1.0)]] == lin.B + + def test_dimensions(self, oscillator_model: SymbolicControlModel) -> None: + lin = oscillator_model.linearize( + x0=[0.0, 0.0], u0=[0.0], param_values={"omega": 1.0} + ) + assert len(lin.A) == 2 + assert len(lin.A[0]) == 2 + assert len(lin.B) == 2 + assert len(lin.B[0]) == 1 + + +class TestLinearizeVanDerPol: + """dx/dt = v, dv/dt = mu*(1-x^2)*v - x. + + At origin (x=0, v=0): + A = [[0, 1], [-1, mu]] + """ + + def test_a_at_origin(self, van_der_pol_model: SymbolicControlModel) -> None: + lin = van_der_pol_model.linearize( + x0=[0.0, 0.0], u0=[0.0], param_values={"mu": 2.0} + ) + assert lin.A[0] == [pytest.approx(0.0), pytest.approx(1.0)] + assert lin.A[1] == [pytest.approx(-1.0), pytest.approx(2.0)] + + def test_a_away_from_origin(self, van_der_pol_model: SymbolicControlModel) -> None: + """At x=1, v=0, mu=1: A[1] = [-1, 0].""" + lin = van_der_pol_model.linearize( + x0=[1.0, 0.0], u0=[0.0], param_values={"mu": 1.0} + ) + # df2/dx = -2*mu*x*v - 1 = -2*1*1*0 - 1 = -1 + assert lin.A[1][0] == pytest.approx(-1.0) + # df2/dv = mu*(1-x^2) = 1*(1-1) = 0 + assert lin.A[1][1] == pytest.approx(0.0) + + +class TestLinearizeOutputEquations: + """Test C and D matrices from output equations.""" + + def test_empty_outputs(self, decay_model: SymbolicControlModel) -> None: + """No output equations → empty C, D.""" + lin = decay_model.linearize(x0=[0.0], u0=[0.0]) + assert lin.C == [] + assert lin.D == [] + assert lin.output_names == [] diff --git a/packages/gds-symbolic/tests/test_model.py b/packages/gds-symbolic/tests/test_model.py new file mode 100644 index 0000000..e15f608 --- /dev/null +++ b/packages/gds-symbolic/tests/test_model.py @@ -0,0 +1,94 @@ +"""Tests for SymbolicControlModel construction and validation.""" + +from __future__ import annotations + +import pytest +from gds_control.dsl.elements import Controller, Input, Sensor, State + +from gds_symbolic import StateEquation, SymbolicControlModel +from gds_symbolic.errors import SymbolicError + + +class TestValidConstruction: + def test_inherits_control_model(self, decay_model: SymbolicControlModel) -> None: + assert decay_model.name == "decay" + assert len(decay_model.states) == 1 + assert len(decay_model.state_equations) == 1 + + def test_compile_still_works(self, decay_model: SymbolicControlModel) -> None: + """Structural GDS compilation is unaffected by symbolic annotations.""" + spec = decay_model.compile() + assert spec.name == "decay" + assert len(spec.blocks) > 0 + + def test_compile_system_still_works( + self, decay_model: SymbolicControlModel + ) -> None: + ir = decay_model.compile_system() + assert ir.name == "decay" + + def test_multi_state(self, oscillator_model: SymbolicControlModel) -> None: + assert len(oscillator_model.states) == 2 + assert len(oscillator_model.state_equations) == 2 + + def test_symbolic_params(self, decay_model: SymbolicControlModel) -> None: + assert decay_model.symbolic_params == ["k"] + + def test_no_equations_allowed(self) -> None: + """A model with no state equations is valid (all dx/dt = 0).""" + m = SymbolicControlModel( + name="static", + states=[State(name="x")], + inputs=[Input(name="u")], + sensors=[Sensor(name="obs", observes=["x"])], + controllers=[ + Controller(name="ctrl", reads=["obs", "u"], drives=["x"]), + ], + ) + assert m.state_equations == [] + + +class TestInvalidConstruction: + def test_undeclared_state_in_equation(self) -> None: + with pytest.raises(SymbolicError, match="undeclared state"): + SymbolicControlModel( + name="bad", + states=[State(name="x")], + inputs=[Input(name="u")], + sensors=[Sensor(name="obs", observes=["x"])], + controllers=[ + Controller(name="ctrl", reads=["obs", "u"], drives=["x"]), + ], + state_equations=[ + StateEquation(state_name="y", expr_str="x"), + ], + ) + + def test_duplicate_state_equations(self) -> None: + with pytest.raises(SymbolicError, match="Duplicate"): + SymbolicControlModel( + name="bad", + states=[State(name="x")], + inputs=[Input(name="u")], + sensors=[Sensor(name="obs", observes=["x"])], + controllers=[ + Controller(name="ctrl", reads=["obs", "u"], drives=["x"]), + ], + state_equations=[ + StateEquation(state_name="x", expr_str="1"), + StateEquation(state_name="x", expr_str="2"), + ], + ) + + def test_param_conflicts_with_state(self) -> None: + with pytest.raises(SymbolicError, match="conflicts"): + SymbolicControlModel( + name="bad", + states=[State(name="x")], + inputs=[Input(name="u")], + sensors=[Sensor(name="obs", observes=["x"])], + controllers=[ + Controller(name="ctrl", reads=["obs", "u"], drives=["x"]), + ], + symbolic_params=["x"], + ) diff --git a/packages/gds-viz/gds_viz/__init__.py b/packages/gds-viz/gds_viz/__init__.py index 9ae00e7..4fda7e3 100644 --- a/packages/gds-viz/gds_viz/__init__.py +++ b/packages/gds-viz/gds_viz/__init__.py @@ -17,3 +17,12 @@ "system_to_mermaid", "trace_to_mermaid", ] + + +def __getattr__(name: str) -> object: + """Lazy import for optional phase portrait module.""" + if name == "phase_portrait": + from gds_viz.phase import phase_portrait + + return phase_portrait + raise AttributeError(f"module 'gds_viz' has no attribute {name!r}") diff --git a/packages/gds-viz/gds_viz/phase.py b/packages/gds-viz/gds_viz/phase.py new file mode 100644 index 0000000..cf012f7 --- /dev/null +++ b/packages/gds-viz/gds_viz/phase.py @@ -0,0 +1,339 @@ +"""Phase portrait visualization for continuous-time ODE systems. + +Produces matplotlib figures: vector fields, trajectories, nullclines, +and backward reachable set boundaries (isochrones). + +Requires ``gds-viz[phase]`` (matplotlib + numpy + gds-continuous). + +Example:: + + from gds_continuous import ODEModel + from gds_viz.phase import phase_portrait + + model = ODEModel( + state_names=["x", "v"], + initial_state={"x": 1.0, "v": 0.0}, + rhs=my_ode_fn, + ) + fig = phase_portrait(model, x_var="x", y_var="v", x_range=(-3, 3), y_range=(-3, 3)) +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from gds_continuous import ODEModel + from gds_continuous.results import ODEResults + + +def _require_phase_deps() -> None: + """Raise ImportError if matplotlib/numpy are absent.""" + try: + import matplotlib # noqa: F401 + import numpy # noqa: F401 + except ImportError as exc: + raise ImportError( + "Phase portrait visualization requires matplotlib and numpy. " + "Install with: uv add gds-viz[phase]" + ) from exc + + +@dataclass(frozen=True) +class PhasePlotConfig: + """Configuration for a phase portrait.""" + + x_var: str + y_var: str + x_range: tuple[float, float] + y_range: tuple[float, float] + resolution: int = 20 + fixed_states: dict[str, float] = field(default_factory=dict) + params: dict[str, float] = field(default_factory=dict) + title: str = "" + + +def compute_vector_field( + model: ODEModel, + config: PhasePlotConfig, + *, + t: float = 0.0, +) -> tuple[Any, Any, Any, Any]: + """Compute a 2D vector field over a grid. + + Parameters + ---------- + model + ODE model with the RHS function. + config + Grid specification (axes, ranges, resolution). + t + Time value for evaluating the RHS (default 0). + + Returns + ------- + X, Y, dX, dY : numpy arrays + Meshgrid coordinates and derivative components. + """ + _require_phase_deps() + import numpy as np + + x_idx = model.state_names.index(config.x_var) + y_idx = model.state_names.index(config.y_var) + + xs = np.linspace(config.x_range[0], config.x_range[1], config.resolution) + ys = np.linspace(config.y_range[0], config.y_range[1], config.resolution) + X, Y = np.meshgrid(xs, ys) + + dX = np.zeros_like(X) + dY = np.zeros_like(Y) + + # Build base state from fixed values + base = [config.fixed_states.get(n, 0.0) for n in model.state_names] + + for i in range(config.resolution): + for j in range(config.resolution): + state = list(base) + state[x_idx] = X[i, j] + state[y_idx] = Y[i, j] + deriv = model.rhs(t, state, config.params) + dX[i, j] = deriv[x_idx] + dY[i, j] = deriv[y_idx] + + return X, Y, dX, dY + + +def compute_trajectories( + model: ODEModel, + initial_conditions: list[dict[str, float]], + *, + t_span: tuple[float, float] = (0.0, 10.0), + params: dict[str, float] | None = None, + solver: str = "RK45", + max_step: float = 0.05, +) -> list[ODEResults]: + """Integrate multiple trajectories from different initial conditions. + + Parameters + ---------- + model + ODE model (``rhs`` is used, ``initial_state`` is overridden). + initial_conditions + List of state dicts, one per trajectory. + t_span + Integration time interval. + params + Parameter values (single set, not a sweep). + solver + SciPy solver name. + max_step + Maximum integration step size. + + Returns + ------- + List of ODEResults, one per initial condition. + """ + from gds_continuous import ODEModel as _ODEModel + from gds_continuous import ODESimulation + + results = [] + p = params or {} + for ic in initial_conditions: + m = _ODEModel( + state_names=model.state_names, + initial_state=ic, + rhs=model.rhs, + params={k: [v] for k, v in p.items()}, + ) + sim = ODESimulation( + model=m, + t_span=t_span, + solver=solver, # type: ignore[arg-type] + max_step=max_step, + ) + results.append(sim.run()) + return results + + +def vector_field_plot( + model: ODEModel, + config: PhasePlotConfig, + *, + ax: Any | None = None, + normalize: bool = True, + color: str = "gray", + alpha: float = 0.6, +) -> Any: + """Plot a 2D vector field (quiver plot). + + Returns the matplotlib Figure. + """ + _require_phase_deps() + import matplotlib.pyplot as plt + import numpy as np + + X, Y, dX, dY = compute_vector_field(model, config) + + if ax is None: + fig, ax = plt.subplots(1, 1, figsize=(8, 8)) + else: + fig = ax.get_figure() + + if normalize: + mag = np.sqrt(dX**2 + dY**2) + mag = np.where(mag > 0, mag, 1.0) + dX = dX / mag + dY = dY / mag + + ax.quiver(X, Y, dX, dY, color=color, alpha=alpha, scale=25) + ax.set_xlabel(config.x_var) + ax.set_ylabel(config.y_var) + ax.set_aspect("equal") + if config.title: + ax.set_title(config.title) + ax.grid(True, alpha=0.3) + return fig + + +def trajectory_plot( + results_list: list[ODEResults], + x_var: str, + y_var: str, + *, + ax: Any | None = None, + colormap: str = "viridis", + linewidth: float = 1.0, + show_start: bool = True, + show_end: bool = True, +) -> Any: + """Plot trajectories in phase space. + + Returns the matplotlib Figure. + """ + _require_phase_deps() + import matplotlib.pyplot as plt + + if ax is None: + fig, ax = plt.subplots(1, 1, figsize=(8, 8)) + else: + fig = ax.get_figure() + + cmap = plt.get_cmap(colormap) + n = max(len(results_list), 1) + + for i, res in enumerate(results_list): + c = cmap(i / n) + xs = res.state_array(x_var) + ys = res.state_array(y_var) + ax.plot(xs, ys, "-", color=c, linewidth=linewidth, alpha=0.8) + if show_start: + ax.plot(xs[0], ys[0], "o", color=c, markersize=5) + if show_end: + ax.plot(xs[-1], ys[-1], "s", color=c, markersize=4) + + ax.set_xlabel(x_var) + ax.set_ylabel(y_var) + ax.set_aspect("equal") + ax.grid(True, alpha=0.3) + return fig + + +def phase_portrait( + model: ODEModel, + x_var: str, + y_var: str, + x_range: tuple[float, float], + y_range: tuple[float, float], + *, + initial_conditions: list[dict[str, float]] | None = None, + params: dict[str, float] | None = None, + fixed_states: dict[str, float] | None = None, + t_span: tuple[float, float] = (0.0, 10.0), + resolution: int = 20, + title: str = "", + show_nullclines: bool = False, + figsize: tuple[float, float] = (10, 10), +) -> Any: + """Full phase portrait: vector field + optional trajectories + nullclines. + + Parameters + ---------- + model + ODE model. + x_var, y_var + State variable names for the two axes. + x_range, y_range + Plot ranges for each axis. + initial_conditions + List of state dicts for trajectory integration. None = no trajectories. + params + Parameter values for RHS evaluation. + fixed_states + Values for state variables not on the axes (for >2D systems). + t_span + Integration time for trajectories. + resolution + Grid density for vector field. + title + Plot title. + show_nullclines + If True, draw zero-contours of dx/dt=0 and dy/dt=0. + figsize + Figure size. + + Returns + ------- + matplotlib Figure. + """ + _require_phase_deps() + import matplotlib.pyplot as plt + import numpy as np + + config = PhasePlotConfig( + x_var=x_var, + y_var=y_var, + x_range=x_range, + y_range=y_range, + resolution=resolution, + fixed_states=fixed_states or {}, + params=params or {}, + title=title, + ) + + fig, ax = plt.subplots(1, 1, figsize=figsize) + + # Vector field + X, Y, dX, dY = compute_vector_field(model, config) + mag = np.sqrt(dX**2 + dY**2) + mag = np.where(mag > 0, mag, 1.0) + ax.quiver(X, Y, dX / mag, dY / mag, color="gray", alpha=0.4, scale=25) + + # Nullclines + if show_nullclines: + ax.contour(X, Y, dX, levels=[0], colors=["blue"], linewidths=[1.5], alpha=0.6) + ax.contour(X, Y, dY, levels=[0], colors=["red"], linewidths=[1.5], alpha=0.6) + + # Trajectories + if initial_conditions: + trajs = compute_trajectories( + model, initial_conditions, t_span=t_span, params=params + ) + cmap = plt.get_cmap("viridis") + n = max(len(trajs), 1) + for i, res in enumerate(trajs): + c = cmap(i / n) + xs = res.state_array(x_var) + ys = res.state_array(y_var) + ax.plot(xs, ys, "-", color=c, linewidth=1.2, alpha=0.8) + ax.plot(xs[0], ys[0], "o", color=c, markersize=5) + + ax.set_xlabel(x_var) + ax.set_ylabel(y_var) + ax.set_xlim(x_range) + ax.set_ylim(y_range) + ax.set_aspect("equal") + ax.set_title(title) + ax.grid(True, alpha=0.3) + plt.tight_layout() + return fig diff --git a/packages/gds-viz/pyproject.toml b/packages/gds-viz/pyproject.toml index 501af1a..9d48b61 100644 --- a/packages/gds-viz/pyproject.toml +++ b/packages/gds-viz/pyproject.toml @@ -33,6 +33,9 @@ dependencies = [ "gds-framework>=0.2.3", ] +[project.optional-dependencies] +phase = ["matplotlib>=3.8", "numpy>=1.26", "gds-continuous>=0.1.0"] + [project.urls] Homepage = "https://github.com/BlockScience/gds-core" Repository = "https://github.com/BlockScience/gds-core" @@ -50,6 +53,7 @@ packages = ["gds_viz"] [tool.uv.sources] gds-framework = { workspace = true } +gds-continuous = { workspace = true } [tool.pytest.ini_options] testpaths = ["tests"] @@ -74,4 +78,7 @@ dev = [ "pytest>=8.0", "pytest-cov>=6.0", "ruff>=0.8", + "matplotlib>=3.8", + "numpy>=1.26", + "scipy>=1.13", ] diff --git a/packages/gds-viz/tests/test_phase.py b/packages/gds-viz/tests/test_phase.py new file mode 100644 index 0000000..36c193b --- /dev/null +++ b/packages/gds-viz/tests/test_phase.py @@ -0,0 +1,208 @@ +"""Tests for phase portrait visualization.""" + +from __future__ import annotations + +from typing import Any + +import pytest + +matplotlib = pytest.importorskip("matplotlib") +numpy = pytest.importorskip("numpy") + +from gds_continuous import ODEModel # noqa: E402 +from gds_viz.phase import ( # noqa: E402 + PhasePlotConfig, + compute_trajectories, + compute_vector_field, + phase_portrait, + trajectory_plot, + vector_field_plot, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def _oscillator_rhs(t: float, y: list[float], params: dict[str, Any]) -> list[float]: + """dx/dt = v, dv/dt = -x.""" + return [y[1], -y[0]] + + +@pytest.fixture +def oscillator() -> ODEModel: + return ODEModel( + state_names=["x", "v"], + initial_state={"x": 1.0, "v": 0.0}, + rhs=_oscillator_rhs, + ) + + +@pytest.fixture +def config() -> PhasePlotConfig: + return PhasePlotConfig( + x_var="x", + y_var="v", + x_range=(-3.0, 3.0), + y_range=(-3.0, 3.0), + resolution=10, + ) + + +# --------------------------------------------------------------------------- +# Vector field +# --------------------------------------------------------------------------- + + +class TestComputeVectorField: + def test_shape(self, oscillator: ODEModel, config: PhasePlotConfig) -> None: + _X, _Y, _dX, _dY = compute_vector_field(oscillator, config) + assert _X.shape == (10, 10) + assert _dX.shape == (10, 10) + + def test_values_at_origin( + self, oscillator: ODEModel, config: PhasePlotConfig + ) -> None: + """At (0,0): dx/dt=0, dv/dt=0.""" + + cfg = PhasePlotConfig( + x_var="x", + y_var="v", + x_range=(-1, 1), + y_range=(-1, 1), + resolution=3, + ) + _X, _Y, _dX, _dY = compute_vector_field(oscillator, cfg) + # Center point (idx 1,1 in a 3x3 grid) + assert _dX[1, 1] == pytest.approx(0.0, abs=1e-10) + assert _dY[1, 1] == pytest.approx(0.0, abs=1e-10) + + +# --------------------------------------------------------------------------- +# Trajectories +# --------------------------------------------------------------------------- + + +class TestComputeTrajectories: + def test_returns_results(self, oscillator: ODEModel) -> None: + ics = [{"x": 1.0, "v": 0.0}, {"x": 0.0, "v": 1.0}] + results = compute_trajectories(oscillator, ics, t_span=(0, 3)) + assert len(results) == 2 + assert len(results[0]) > 2 + + def test_trajectory_is_circular(self, oscillator: ODEModel) -> None: + """Harmonic oscillator: x^2 + v^2 = const.""" + ics = [{"x": 1.0, "v": 0.0}] + results = compute_trajectories(oscillator, ics, t_span=(0, 6.28)) + xs = results[0].state_array("x") + vs = results[0].state_array("v") + r0 = xs[0] ** 2 + vs[0] ** 2 + for x, v in zip(xs, vs, strict=True): + assert x**2 + v**2 == pytest.approx(r0, abs=1e-4) + + +# --------------------------------------------------------------------------- +# Plot functions (smoke tests — verify they return Figure, don't check pixels) +# --------------------------------------------------------------------------- + + +class TestVectorFieldPlot: + def test_returns_figure( + self, oscillator: ODEModel, config: PhasePlotConfig + ) -> None: + import matplotlib.pyplot as plt + + fig = vector_field_plot(oscillator, config) + assert fig is not None + plt.close(fig) + + +class TestTrajectoryPlot: + def test_returns_figure(self, oscillator: ODEModel) -> None: + import matplotlib.pyplot as plt + + ics = [{"x": 1.0, "v": 0.0}] + results = compute_trajectories(oscillator, ics, t_span=(0, 3)) + fig = trajectory_plot(results, "x", "v") + assert fig is not None + plt.close(fig) + + +class TestPhasePortrait: + def test_vector_field_only(self, oscillator: ODEModel) -> None: + import matplotlib.pyplot as plt + + fig = phase_portrait( + oscillator, + "x", + "v", + x_range=(-3, 3), + y_range=(-3, 3), + resolution=8, + title="Oscillator", + ) + assert fig is not None + plt.close(fig) + + def test_with_trajectories(self, oscillator: ODEModel) -> None: + import matplotlib.pyplot as plt + + ics = [ + {"x": 1.0, "v": 0.0}, + {"x": 2.0, "v": 0.0}, + ] + fig = phase_portrait( + oscillator, + "x", + "v", + x_range=(-3, 3), + y_range=(-3, 3), + initial_conditions=ics, + resolution=8, + ) + assert fig is not None + plt.close(fig) + + def test_with_nullclines(self, oscillator: ODEModel) -> None: + import matplotlib.pyplot as plt + + fig = phase_portrait( + oscillator, + "x", + "v", + x_range=(-3, 3), + y_range=(-3, 3), + show_nullclines=True, + resolution=15, + ) + assert fig is not None + plt.close(fig) + + def test_high_dimensional_with_fixed(self) -> None: + """3D system projected to 2D via fixed_states.""" + import matplotlib.pyplot as plt + + def lorenz(t: float, y: list[float], p: dict[str, Any]) -> list[float]: + sigma, rho, beta = 10.0, 28.0, 8.0 / 3.0 + return [ + sigma * (y[1] - y[0]), + y[0] * (rho - y[2]) - y[1], + y[0] * y[1] - beta * y[2], + ] + + model = ODEModel( + state_names=["x", "y", "z"], + initial_state={"x": 1.0, "y": 1.0, "z": 1.0}, + rhs=lorenz, + ) + fig = phase_portrait( + model, + "x", + "y", + x_range=(-20, 20), + y_range=(-30, 30), + fixed_states={"z": 25.0}, + resolution=8, + ) + assert fig is not None + plt.close(fig) diff --git a/pyproject.toml b/pyproject.toml index c3e3632..d4f383d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,18 +24,50 @@ classifiers = [ "Programming Language :: Python :: 3.13", "Topic :: Scientific/Engineering", ] -dependencies = [ - "gds-framework>=0.2.3", +dependencies = ["gds-framework>=0.2.3"] + +[project.optional-dependencies] +# Domain DSLs +viz = ["gds-viz>=0.1.0"] +games = ["gds-games>=0.1.0"] +stockflow = ["gds-stockflow>=0.1.0"] +control = ["gds-control>=0.1.0"] +software = ["gds-software>=0.1.0"] +business = ["gds-business>=0.1.0"] +# Simulation & analysis +sim = ["gds-sim>=0.1.0"] +continuous = ["gds-continuous>=0.1.0"] +symbolic = ["gds-symbolic>=0.1.0"] +analysis = ["gds-analysis>=0.1.0"] +psuu = ["gds-psuu>=0.1.0"] +# Formal methods +owl = ["gds-owl>=0.1.0"] +# Tutorials +examples = [ + "gds-examples>=0.1.0", + "gds-viz>=0.1.0", + "gds-games>=0.1.0", + "gds-stockflow>=0.1.0", + "gds-control>=0.1.0", + "gds-software>=0.1.0", + "gds-business>=0.1.0", + "gds-sim>=0.1.0", +] +# Everything +all = [ "gds-viz>=0.1.0", "gds-games>=0.1.0", "gds-stockflow>=0.1.0", "gds-control>=0.1.0", "gds-software>=0.1.0", "gds-business>=0.1.0", - "gds-examples>=0.1.0", "gds-sim>=0.1.0", + "gds-continuous>=0.1.0", + "gds-symbolic>=0.1.0", + "gds-analysis>=0.1.0", "gds-psuu>=0.1.0", "gds-owl>=0.1.0", + "gds-examples>=0.1.0", ] [project.urls] @@ -60,6 +92,9 @@ gds-software = { workspace = true } gds-business = { workspace = true } gds-examples = { workspace = true } gds-sim = { workspace = true } +gds-continuous = { workspace = true } +gds-symbolic = { workspace = true } +gds-analysis = { workspace = true } gds-psuu = { workspace = true } gds-owl = { workspace = true } @@ -78,7 +113,7 @@ select = ["E", "W", "F", "I", "UP", "B", "SIM", "TCH", "RUF"] "packages/gds-examples/prisoners_dilemma/visualize.py" = ["E501"] [tool.ruff.lint.isort] -known-first-party = ["gds", "gds_viz", "ogs", "stockflow", "gds_control", "gds_software", "gds_business", "gds_sim", "gds_psuu", "gds_owl"] +known-first-party = ["gds", "gds_viz", "ogs", "stockflow", "gds_control", "gds_software", "gds_business", "gds_sim", "gds_continuous", "gds_symbolic", "gds_psuu", "gds_owl"] [dependency-groups] docs = [