From 55981f52a6e1dbc12a321bd907a9fcf431c1b269 Mon Sep 17 00:00:00 2001 From: Vahid Ahmadi Date: Thu, 30 Apr 2026 13:24:31 +0100 Subject: [PATCH] feat: add Simulation.from_situation for situation-JSON input (#51) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a classmethod to the Python wrapper that accepts the PolicyEngine web-app situation-JSON format (people / benunits / households with `members` lists and period-keyed values) and converts it into the three input DataFrames the Rust engine consumes. Closes #51 in part — the small, low-risk piece. Datasets-from-URL and a direct dataframe entry point can follow in subsequent PRs. Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/test.yml | 11 + changelog.d/added/from-situation.md | 1 + .../python/policyengine_uk_compiled/engine.py | 256 ++++++++++++++++ interfaces/python/tests/__init__.py | 0 .../python/tests/test_from_situation.py | 274 ++++++++++++++++++ 5 files changed, 542 insertions(+) create mode 100644 changelog.d/added/from-situation.md create mode 100644 interfaces/python/tests/__init__.py create mode 100644 interfaces/python/tests/test_from_situation.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7d46c4c..7f55424 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -29,3 +29,14 @@ jobs: - name: Test run: cargo test + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: "3.11" + + - name: Install Python wrapper test dependencies + run: pip install pandas pydantic pytest + + - name: Test Python wrapper + run: PYTHONPATH=interfaces/python pytest interfaces/python/tests -q diff --git a/changelog.d/added/from-situation.md b/changelog.d/added/from-situation.md new file mode 100644 index 0000000..78eff1d --- /dev/null +++ b/changelog.d/added/from-situation.md @@ -0,0 +1 @@ +Add `Simulation.from_situation(situation, year)` to the Python wrapper, accepting the 
# Region values accepted in situation dicts → canonical form the Rust engine expects.
# Accepts the upper-snake names used by the PolicyEngine web app and the
# title-case forms used by `PERSON_DEFAULTS`/`HOUSEHOLD_DEFAULTS` and the
# `parse_region` function in `src/data/clean.rs`.
_REGION_CANONICAL = {
    "NORTH_EAST": "North East", "North East": "North East",
    "NORTH_WEST": "North West", "North West": "North West",
    "YORKSHIRE": "Yorkshire", "Yorkshire": "Yorkshire",
    "EAST_MIDLANDS": "East Midlands", "East Midlands": "East Midlands",
    "WEST_MIDLANDS": "West Midlands", "West Midlands": "West Midlands",
    "EAST_OF_ENGLAND": "East of England", "East of England": "East of England",
    "LONDON": "London", "London": "London",
    "SOUTH_EAST": "South East", "South East": "South East",
    "SOUTH_WEST": "South West", "South West": "South West",
    "WALES": "Wales", "Wales": "Wales",
    "SCOTLAND": "Scotland", "Scotland": "Scotland",
    "NORTHERN_IRELAND": "Northern Ireland", "Northern Ireland": "Northern Ireland",
}


def _resolve_period_value(value, year: int):
    """Pick a value out of a period-keyed dict, or return the scalar unchanged.

    Resolution order for dict inputs:

    1. an exact key match on ``str(year)``;
    2. the first key (in dict order) whose first four characters equal the
       year — covers ``"2025-01"`` style entries and integer keys;
    3. the most recent period whose year is not later than ``year``
       (periods within the same year are ordered by their key text, so
       ``"2024-06"`` beats ``"2024-01"``);
    4. the earliest period, when every period is later than ``year``;
    5. the first value in dict order, when no key starts with a four-digit
       year (e.g. ``{"ETERNITY": x}``).

    Args:
        value: A scalar, or a dict mapping period keys to values.
        year: The simulation year to resolve against.

    Returns:
        ``value`` itself when it is not a dict, otherwise the selected entry.

    Raises:
        ValueError: If ``value`` is an empty dict (there is nothing to pick).
    """
    if not isinstance(value, dict):
        return value
    if not value:
        # Previously surfaced as a bare StopIteration from next(iter(...)).
        raise ValueError(
            "period-keyed value is an empty dict; expected at least one period"
        )

    year_str = str(year)
    if year_str in value:
        return value[year_str]
    for period, candidate in value.items():
        if str(period)[:4] == year_str:
            return candidate

    # Numeric-period fallback: keep only keys that start with a 4-digit year.
    dated = []
    for period, candidate in value.items():
        label = str(period)
        try:
            dated.append((int(label[:4]), label, candidate))
        except (ValueError, TypeError):
            continue
    if not dated:
        # Non-period entries only (e.g. {"ETERNITY": x}) → use the first one.
        return next(iter(value.values()))

    # Sort on (year, period text) only. Including the values in the sort key
    # would raise TypeError whenever two same-year periods hold values that
    # cannot be ordered against each other (None vs int, dicts, ...).
    dated.sort(key=lambda entry: entry[:2])
    not_later = [
        candidate for period_year, _, candidate in dated if period_year <= year
    ]
    return not_later[-1] if not_later else dated[0][2]
def _assign_implicit_heads(person_rows, people, flag: str, unit_col: str) -> None:
    """Fill in the head flag ``flag`` for people who did not set it explicitly.

    ``person_rows`` is aligned with ``people`` (same order). A unit (identified
    by ``row[unit_col]``) that already has an explicitly truthy head keeps it
    and gains no implicit head; otherwise the first person in ``people`` order
    who left the flag unset becomes the head. Explicit values are never
    overwritten, and an explicit falsy value does not use up the unit's head
    slot.
    """
    explicit = [flag in (fields or {}) for fields in people.values()]
    # Pre-scan so an explicit head that appears *later* in `people` still
    # prevents an earlier person from being promoted implicitly.
    headed = {
        row[unit_col]
        for row, is_explicit in zip(person_rows, explicit)
        if is_explicit and row.get(flag)
    }
    for row, is_explicit in zip(person_rows, explicit):
        if is_explicit:
            continue
        unit = row[unit_col]
        row[flag] = unit not in headed
        headed.add(unit)


def _situation_to_dataframes(situation: dict, year: int):
    """Convert a PolicyEngine situation-JSON dict into the three input DataFrames.

    See ``Simulation.from_situation`` for the supported dict shape.

    People, benunits and households are numbered from 0 in the order their
    keys appear in ``situation``. Each row starts from ``PERSON_DEFAULTS`` /
    ``BENUNIT_DEFAULTS`` / ``HOUSEHOLD_DEFAULTS`` and is then overlaid with the
    situation's variables, each resolved for ``year`` via
    ``_resolve_period_value``. When ``benunits`` is missing or empty, all
    people are folded into one implicit benunit.

    Args:
        situation: Dict with ``people``, ``households`` and optionally
            ``benunits``; unit entries carry a ``members`` list of people keys.
        year: Year used to pick values out of period-keyed variables.

    Returns:
        A ``(persons, benunits, households)`` tuple of ``pandas.DataFrame``.

    Raises:
        ImportError: If pandas is not available.
        ValueError: If ``people`` or ``households`` is empty, or a person is
            not listed as a member of any benunit or any household.
    """
    if not HAS_PANDAS:
        raise ImportError("pandas is required for from_situation")

    people = situation.get("people") or {}
    benunits = situation.get("benunits") or {}
    households = situation.get("households") or {}

    if not people:
        raise ValueError("situation must contain at least one entry under 'people'")
    if not households:
        raise ValueError("situation must contain at least one entry under 'households'")
    if not benunits:
        # Fold all people into a single implicit benunit so callers don't
        # have to supply one for trivial cases.
        benunits = {"_default": {"members": list(people.keys())}}

    def members_of(fields):
        # A unit declared as None / {} simply has no members; guarding here
        # keeps every access consistent with the `(fields or {})` used below.
        return (fields or {}).get("members") or []

    person_id_map = {pid: i for i, pid in enumerate(people.keys())}
    benunit_id_map = {bid: i for i, bid in enumerate(benunits.keys())}
    household_id_map = {hid: i for i, hid in enumerate(households.keys())}

    # Build reverse lookups: person → benunit, person → household
    person_to_benunit: dict[str, str] = {}
    for bid, fields in benunits.items():
        for member in members_of(fields):
            person_to_benunit[member] = bid
    person_to_household: dict[str, str] = {}
    for hid, fields in households.items():
        for member in members_of(fields):
            person_to_household[member] = hid

    person_rows = []
    for pid, fields in people.items():
        if pid not in person_to_benunit:
            raise ValueError(f"person {pid!r} is not a member of any benunit")
        if pid not in person_to_household:
            raise ValueError(f"person {pid!r} is not a member of any household")
        row = dict(PERSON_DEFAULTS)
        row["person_id"] = person_id_map[pid]
        row["benunit_id"] = benunit_id_map[person_to_benunit[pid]]
        row["household_id"] = household_id_map[person_to_household[pid]]
        for var, val in (fields or {}).items():
            if var == "members":
                continue
            resolved = _resolve_period_value(val, year)
            if var == "gender" and isinstance(resolved, str):
                resolved = resolved.lower()
            row[var] = resolved
        person_rows.append(row)

    # Give every benunit / household exactly one implicit head (the first
    # person, in `people` order, who did not set the flag) unless the
    # situation already names a head for that unit.
    _assign_implicit_heads(person_rows, people, "is_benunit_head", "benunit_id")
    _assign_implicit_heads(person_rows, people, "is_household_head", "household_id")

    benunit_rows = []
    for bid, fields in benunits.items():
        members = members_of(fields)
        member_int_ids = [person_id_map[m] for m in members if m in person_id_map]
        # Single household owns this benunit — pick from the first member.
        if member_int_ids:
            owner_household = next(
                household_id_map[person_to_household[m]]
                for m in members
                if m in person_to_household
            )
        else:
            owner_household = 0
        row = dict(BENUNIT_DEFAULTS)
        row["benunit_id"] = benunit_id_map[bid]
        row["household_id"] = owner_household
        row["person_ids"] = ";".join(str(i) for i in member_int_ids)
        for var, val in (fields or {}).items():
            if var == "members":
                continue
            row[var] = _resolve_period_value(val, year)
        benunit_rows.append(row)

    household_rows = []
    for hid, fields in households.items():
        members = members_of(fields)
        member_int_ids = [person_id_map[m] for m in members if m in person_id_map]
        member_benunits = sorted({
            benunit_id_map[person_to_benunit[m]]
            for m in members
            if m in person_to_benunit
        })
        row = dict(HOUSEHOLD_DEFAULTS)
        row["household_id"] = household_id_map[hid]
        row["person_ids"] = ";".join(str(i) for i in member_int_ids)
        row["benunit_ids"] = ";".join(str(i) for i in member_benunits)
        for var, val in (fields or {}).items():
            if var == "members":
                continue
            resolved = _resolve_period_value(val, year)
            if var == "region" and isinstance(resolved, str):
                # Unknown spellings pass through unchanged.
                resolved = _REGION_CANONICAL.get(resolved, resolved)
            row[var] = resolved
        household_rows.append(row)

    # Propagate `is_in_scotland` from each person's household region unless
    # the situation already set it explicitly.
    region_by_household = {h["household_id"]: h.get("region") for h in household_rows}
    explicit_in_scotland = {
        pid for pid, fields in people.items()
        if "is_in_scotland" in (fields or {})
    }
    for pid, row in zip(people.keys(), person_rows):
        if pid in explicit_in_scotland:
            continue
        row["is_in_scotland"] = region_by_household.get(row["household_id"]) == "Scotland"

    return (
        pd.DataFrame(person_rows),
        pd.DataFrame(benunit_rows),
        pd.DataFrame(household_rows),
    )
@staticmethod
def from_situation(
    situation: dict,
    year: int = 2025,
    **kwargs,
) -> "Simulation":
    """Create a Simulation from a PolicyEngine situation-JSON dict.

    Expected shape (the PolicyEngine web-app format)::

        {
            "people": {"<person_id>": {"<variable>": {"<period>": <value>}, ...}, ...},
            "benunits": {"<benunit_id>": {"members": [...], "<variable>": ..., ...}, ...},
            "households": {"<household_id>": {"members": [...], "<variable>": ..., ...}, ...},
        }

    A variable may be given as a period-keyed dict (``{"2025": 50000}``)
    or as a bare scalar; a scalar is taken to apply to ``year``.

    Variable names are the wrapper's input column names (see
    ``PERSON_DEFAULTS``, ``BENUNIT_DEFAULTS``, ``HOUSEHOLD_DEFAULTS``).
    ``region`` may be written in title case (``"London"``,
    ``"North East"``) or in the web app's upper-snake form (``"LONDON"``,
    ``"NORTH_EAST"``); it is normalised before reaching the Rust engine,
    and ``is_in_scotland`` is derived from it unless set explicitly.
    ``gender`` is lower-cased, so any casing is accepted.

    ``members`` lists refer to the keys of ``situation["people"]``.
    People get integer ``person_id`` values in the order they appear
    under ``people``; benunits and households receive the
    ``person_ids`` / ``benunit_ids`` strings the engine expects.

    Args:
        situation: The situation dict described above.
        year: Simulation year, also used to resolve period-keyed values.
        **kwargs: Forwarded unchanged to the ``Simulation`` constructor.

    Returns:
        A ``Simulation`` built from the converted input DataFrames.

    Example::

        sim = Simulation.from_situation(
            {
                "people": {
                    "you": {"age": 30, "employment_income": {"2025": 50000}},
                },
                "benunits": {"yours": {"members": ["you"]}},
                "households": {"yours": {"members": ["you"], "region": "LONDON"}},
            },
            year=2025,
        )
        result = sim.run()
    """
    persons, benunits, households = _situation_to_dataframes(situation, year)
    frames = {
        "persons": persons,
        "benunits": benunits,
        "households": households,
    }
    return Simulation(year=year, **frames, **kwargs)
class TestResolvePeriodValue:
    """Unit tests for the period-lookup helper ``_resolve_period_value``."""

    def test_scalar_passes_through(self):
        # Anything that is not a dict is handed back untouched.
        for scalar in (42, "LONDON"):
            assert _resolve_period_value(scalar, year=2025) == scalar
        assert _resolve_period_value(None, year=2025) is None

    def test_exact_year_match(self):
        by_period = {"2024": 100, "2025": 200}
        assert _resolve_period_value(by_period, year=2025) == 200

    def test_falls_back_to_most_recent_earlier_year(self):
        # No 2025 entry — the latest period <= 2025 wins.
        by_period = {"2020": 1, "2023": 2}
        assert _resolve_period_value(by_period, year=2025) == 2

    def test_falls_back_to_earliest_when_only_later_years_present(self):
        by_period = {"2030": 9, "2040": 10}
        assert _resolve_period_value(by_period, year=2025) == 9

    def test_handles_year_month_keys(self):
        by_period = {"2025-04": 50}
        assert _resolve_period_value(by_period, year=2025) == 50

    def test_handles_eternity_style_keys(self):
        by_period = {"ETERNITY": "x"}
        assert _resolve_period_value(by_period, year=2025) == "x"
class TestSituationToDataframes:
    """Checks of the situation dict → input DataFrame conversion."""

    def test_minimal_single_person(self):
        persons, benunits, households = _situation_to_dataframes(
            {
                "people": {"alice": {"age": 30, "employment_income": {"2025": 50_000}}},
                "benunits": {"bu_1": {"members": ["alice"]}},
                "households": {"hh_1": {"members": ["alice"], "region": "LONDON"}},
            },
            year=2025,
        )

        assert len(persons) == 1
        for id_col in ("person_id", "benunit_id", "household_id"):
            assert persons.at[0, id_col] == 0
        assert persons.at[0, "age"] == 30
        assert persons.at[0, "employment_income"] == 50_000
        assert bool(persons.at[0, "is_benunit_head"]) is True
        assert bool(persons.at[0, "is_household_head"]) is True
        assert bool(persons.at[0, "is_in_scotland"]) is False

        assert len(benunits) == 1
        assert benunits.at[0, "person_ids"] == "0"
        assert benunits.at[0, "household_id"] == 0

        assert len(households) == 1
        assert households.at[0, "person_ids"] == "0"
        assert households.at[0, "benunit_ids"] == "0"
        assert households.at[0, "region"] == "London"

    def test_matches_single_person_constructor(self):
        """`from_situation` produces the same input frames as `single_person`."""
        # NOTE(review): this unpacking assumes `Simulation.single_person`
        # returns the three input frames — confirm against its definition.
        sp_persons, sp_benunits, sp_households = Simulation.single_person(
            age=40, employment_income=30_000, region="London"
        )
        s_persons, s_benunits, s_households = _situation_to_dataframes(
            {
                "people": {"alice": {"age": 40, "employment_income": 30_000}},
                "benunits": {"bu": {"members": ["alice"]}},
                "households": {"hh": {"members": ["alice"], "region": "London"}},
            },
            year=2025,
        )
        # Compare only the default-backed columns of each entity.
        comparisons = (
            (PERSON_DEFAULTS, sp_persons, s_persons),
            (BENUNIT_DEFAULTS, sp_benunits, s_benunits),
            (HOUSEHOLD_DEFAULTS, sp_households, s_households),
        )
        for defaults, expected, actual in comparisons:
            for col in defaults:
                assert expected.at[0, col] == actual.at[0, col], col

    def test_couple_with_children(self):
        family = ["p1", "p2", "c1", "c2"]
        persons, benunits, households = _situation_to_dataframes(
            {
                "people": {
                    "p1": {"age": 35, "employment_income": {"2025": 40_000}},
                    "p2": {"age": 33, "employment_income": {"2025": 25_000}},
                    "c1": {"age": 6},
                    "c2": {"age": 3},
                },
                "benunits": {"bu": {"members": list(family)}},
                "households": {"hh": {"members": list(family), "region": "South East"}},
            },
            year=2025,
        )

        assert len(persons) == 4
        assert list(persons["person_id"]) == [0, 1, 2, 3]
        # p1 is the implicit head of both units; p2 is head of neither.
        for head_col in ("is_benunit_head", "is_household_head"):
            assert bool(persons.at[0, head_col]) is True
            assert bool(persons.at[1, head_col]) is False
        assert benunits.at[0, "person_ids"] == "0;1;2;3"
        assert households.at[0, "person_ids"] == "0;1;2;3"
        assert households.at[0, "benunit_ids"] == "0"
        assert households.at[0, "region"] == "South East"

    def test_region_normalisation_upper_snake(self):
        frames = _situation_to_dataframes(
            {
                "people": {"p": {"age": 30}},
                "benunits": {"b": {"members": ["p"]}},
                "households": {"h": {"members": ["p"], "region": "NORTH_EAST"}},
            },
            year=2025,
        )
        households = frames[2]
        assert households.at[0, "region"] == "North East"

    def test_region_normalisation_title_case_passthrough(self):
        frames = _situation_to_dataframes(
            {
                "people": {"p": {"age": 30}},
                "benunits": {"b": {"members": ["p"]}},
                "households": {"h": {"members": ["p"], "region": "South West"}},
            },
            year=2025,
        )
        households = frames[2]
        assert households.at[0, "region"] == "South West"

    def test_scotland_sets_is_in_scotland(self):
        persons, _, households = _situation_to_dataframes(
            {
                "people": {"p": {"age": 40}},
                "benunits": {"b": {"members": ["p"]}},
                "households": {"h": {"members": ["p"], "region": "SCOTLAND"}},
            },
            year=2025,
        )
        assert households.at[0, "region"] == "Scotland"
        assert bool(persons.at[0, "is_in_scotland"]) is True

    def test_explicit_is_in_scotland_overrides_region(self):
        persons, _, _ = _situation_to_dataframes(
            {
                "people": {"p": {"age": 40, "is_in_scotland": True}},
                "benunits": {"b": {"members": ["p"]}},
                "households": {"h": {"members": ["p"], "region": "London"}},
            },
            year=2025,
        )
        # The explicit flag wins even though the household is in London.
        assert bool(persons.at[0, "is_in_scotland"]) is True

    def test_gender_lowercased(self):
        persons, _, _ = _situation_to_dataframes(
            {
                "people": {"p": {"age": 30, "gender": "FEMALE"}},
                "benunits": {"b": {"members": ["p"]}},
                "households": {"h": {"members": ["p"]}},
            },
            year=2025,
        )
        assert persons.at[0, "gender"] == "female"

    def test_period_keyed_value_picks_year(self):
        income_by_year = {"2024": 1000, "2025": 2000, "2026": 3000}
        persons, _, _ = _situation_to_dataframes(
            {
                "people": {"p": {"age": 30, "employment_income": income_by_year}},
                "benunits": {"b": {"members": ["p"]}},
                "households": {"h": {"members": ["p"]}},
            },
            year=2025,
        )
        assert persons.at[0, "employment_income"] == 2000

    def test_implicit_benunit_when_omitted(self):
        # No "benunits" key at all: one implicit benunit is synthesised.
        persons, benunits, _ = _situation_to_dataframes(
            {
                "people": {"p": {"age": 30}},
                "households": {"h": {"members": ["p"]}},
            },
            year=2025,
        )
        assert len(benunits) == 1
        assert benunits.at[0, "person_ids"] == "0"
        assert persons.at[0, "benunit_id"] == 0

    def test_multiple_benunits_in_one_household(self):
        persons, benunits, households = _situation_to_dataframes(
            {
                "people": {
                    "lodger": {"age": 28, "employment_income": {"2025": 20_000}},
                    "owner": {"age": 45, "employment_income": {"2025": 60_000}},
                },
                "benunits": {
                    "bu_lodger": {"members": ["lodger"]},
                    "bu_owner": {"members": ["owner"]},
                },
                "households": {
                    "hh": {"members": ["lodger", "owner"], "region": "London"},
                },
            },
            year=2025,
        )
        assert len(benunits) == 2
        # Each person is the sole member, hence head, of their own benunit.
        for idx in (0, 1):
            assert bool(persons.at[idx, "is_benunit_head"]) is True
        # Only the first person in the shared household heads it.
        assert bool(persons.at[0, "is_household_head"]) is True
        assert bool(persons.at[1, "is_household_head"]) is False
        assert households.at[0, "person_ids"] == "0;1"
        assert households.at[0, "benunit_ids"] == "0;1"

    def test_missing_person_membership_raises(self):
        unattached = {
            "people": {"orphan": {"age": 30}},
            "benunits": {"b": {"members": []}},
            "households": {"h": {"members": []}},
        }
        with pytest.raises(ValueError, match="not a member of any benunit"):
            _situation_to_dataframes(unattached, year=2025)

    def test_no_people_raises(self):
        empty_people = {"people": {}, "benunits": {}, "households": {"h": {"members": []}}}
        with pytest.raises(ValueError, match="people"):
            _situation_to_dataframes(empty_people, year=2025)

    def test_no_households_raises(self):
        empty_households = {"people": {"p": {"age": 30}}, "benunits": {}, "households": {}}
        with pytest.raises(ValueError, match="households"):
            _situation_to_dataframes(empty_households, year=2025)
class TestFromSituationClassmethod:
    """Checks of the public ``Simulation.from_situation`` constructor."""

    def test_returns_simulation_with_year(self):
        situation = {
            "people": {"p": {"age": 30}},
            "benunits": {"b": {"members": ["p"]}},
            "households": {"h": {"members": ["p"], "region": "London"}},
        }
        sim = Simulation.from_situation(situation, year=2024)
        assert isinstance(sim, Simulation)
        assert sim.year == 2024

    def test_passes_through_dataframe_to_constructor(self):
        situation = {
            "people": {"p": {"age": 30, "employment_income": 25_000}},
            "benunits": {"b": {"members": ["p"]}},
            "households": {"h": {"members": ["p"], "region": "LONDON"}},
        }
        sim = Simulation.from_situation(situation)
        # The constructor keeps the frames so structural pre-hooks can see them.
        for stored in (sim._persons_df, sim._benunits_df, sim._households_df):
            assert stored is not None
        assert sim._stdin_payload is not None
        assert "===PERSONS===" in sim._stdin_payload