Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,14 @@ jobs:

- name: Test
run: cargo test

- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: "3.11"

- name: Install Python wrapper test dependencies
run: pip install pandas pydantic pytest

- name: Test Python wrapper
run: PYTHONPATH=interfaces/python pytest interfaces/python/tests -q
1 change: 1 addition & 0 deletions changelog.d/added/from-situation.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `Simulation.from_situation(situation, year)` to the Python wrapper, accepting the PolicyEngine web-app situation-JSON format (people / benunits / households with `members` lists and period-keyed values) and converting it into the input DataFrames the Rust engine expects.
256 changes: 256 additions & 0 deletions interfaces/python/policyengine_uk_compiled/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,202 @@ def _parse_stdin_payload(payload: str):
)


# Lookup table normalising the region strings accepted in situation dicts to
# the canonical spelling the Rust engine expects. Each region is reachable
# under two spellings: the upper-snake name used by the PolicyEngine web app
# (e.g. "NORTH_EAST") and the title-case form used by
# `PERSON_DEFAULTS`/`HOUSEHOLD_DEFAULTS` and the `parse_region` function in
# `src/data/clean.rs` (e.g. "North East"), which maps to itself.
_REGION_CANONICAL = {
    alias: canonical
    for canonical in (
        "North East",
        "North West",
        "Yorkshire",
        "East Midlands",
        "West Midlands",
        "East of England",
        "London",
        "South East",
        "South West",
        "Wales",
        "Scotland",
        "Northern Ireland",
    )
    # The upper-snake alias is the canonical name upper-cased with spaces
    # replaced by underscores; the canonical name is also its own alias.
    for alias in (canonical.upper().replace(" ", "_"), canonical)
}


def _resolve_period_value(value, year: int):
    """Pick the value that applies to ``year`` out of a period-keyed dict.

    Anything that is not a dict is returned unchanged. For a dict the lookup
    order is:

    1. the key equal to ``str(year)``;
    2. the first key whose first four characters equal ``str(year)``
       (covers ``"2025-01"`` style entries);
    3. among keys whose first four characters parse as an integer year, the
       latest period that is not later than ``year``;
    4. failing that, the earliest such period;
    5. if no key starts with a year (e.g. ``{"ETERNITY": x}``), the first
       value stored in the dict.

    Raises:
        ValueError: if ``value`` is an empty dict, since there is nothing to
            resolve.
    """
    if not isinstance(value, dict):
        return value
    if not value:
        raise ValueError("period-keyed value must contain at least one period")

    year_str = str(year)
    if year_str in value:
        return value[year_str]
    for period, candidate in value.items():
        if str(period)[:4] == year_str:
            return candidate

    # Fall back to the periods whose label starts with a parseable year.
    dated = []
    for period, candidate in value.items():
        label = str(period)
        try:
            dated.append((int(label[:4]), label, candidate))
        except (ValueError, TypeError):
            continue
    if not dated:
        # Only non-period keys (e.g. {"ETERNITY": x}) -> use the first value.
        return next(iter(value.values()))

    # Order by period only. Sorting the raw tuples would compare the stored
    # values whenever two periods fall in the same year, which raises
    # TypeError for None/dict values and otherwise selects by value size.
    dated.sort(key=lambda entry: entry[:2])
    not_later = [entry for entry in dated if entry[0] <= year]
    chosen = not_later[-1] if not_later else dated[0]
    return chosen[2]


def _unit_members(fields):
    """Return the ``members`` list of a benunit/household entry, or ``[]``."""
    return (fields or {}).get("members") or []


def _index_members(units: dict) -> dict:
    """Map each member key to the key of the unit listing it (last one wins)."""
    owner_of = {}
    for unit_key, fields in units.items():
        for member in _unit_members(fields):
            owner_of[member] = unit_key
    return owner_of


def _assign_default_heads(people: dict, person_rows: list, flag: str, unit_column: str) -> None:
    """Flag the first unflagged member of every head-less unit as its head.

    People whose situation entry sets ``flag`` keep their value untouched.
    Units that already have an explicit truthy head get no implicit one; every
    other unit gets exactly one implicit head: its first member (in ``people``
    order) that did not set the flag.
    """
    explicit = {pid for pid, fields in people.items() if flag in (fields or {})}
    # Collect explicit heads up front so the outcome does not depend on
    # whether the explicitly flagged person appears before or after the
    # unflagged ones.
    units_with_head = {
        row[unit_column]
        for pid, row in zip(people, person_rows)
        if pid in explicit and row[flag]
    }
    for pid, row in zip(people, person_rows):
        if pid in explicit:
            continue
        unit = row[unit_column]
        row[flag] = unit not in units_with_head
        units_with_head.add(unit)


def _situation_to_dataframes(situation: dict, year: int):
    """Convert a PolicyEngine situation-JSON dict into the three input DataFrames.

    See ``Simulation.from_situation`` for the supported dict shape.

    People, benunits and households receive integer ids in the order their
    keys appear in the situation. Every variable value is resolved for
    ``year`` with ``_resolve_period_value``; ``gender`` strings are
    lower-cased and ``region`` strings are normalised through
    ``_REGION_CANONICAL``. When ``benunits`` is missing or empty, all people
    are placed in one implicit benunit. Entries in ``members`` lists that are
    not keys of ``people`` are ignored.

    Returns:
        A ``(persons, benunits, households)`` tuple of ``pandas.DataFrame``.

    Raises:
        ImportError: if pandas is not available.
        ValueError: if ``people`` or ``households`` is missing/empty, or a
            person is not listed as a member of any benunit or household.
    """
    if not HAS_PANDAS:
        raise ImportError("pandas is required for from_situation")

    people = situation.get("people") or {}
    benunits = situation.get("benunits") or {}
    households = situation.get("households") or {}

    if not people:
        raise ValueError("situation must contain at least one entry under 'people'")
    if not households:
        raise ValueError("situation must contain at least one entry under 'households'")
    if not benunits:
        # Fold all people into a single implicit benunit so callers don't
        # have to supply one for trivial cases.
        benunits = {"_default": {"members": list(people)}}

    person_id_map = {pid: i for i, pid in enumerate(people)}
    benunit_id_map = {bid: i for i, bid in enumerate(benunits)}
    household_id_map = {hid: i for i, hid in enumerate(households)}

    # Reverse lookups: person key -> benunit key, person key -> household key.
    person_to_benunit = _index_members(benunits)
    person_to_household = _index_members(households)

    person_rows = []
    for pid, fields in people.items():
        if pid not in person_to_benunit:
            raise ValueError(f"person {pid!r} is not a member of any benunit")
        if pid not in person_to_household:
            raise ValueError(f"person {pid!r} is not a member of any household")
        row = dict(PERSON_DEFAULTS)
        row["person_id"] = person_id_map[pid]
        row["benunit_id"] = benunit_id_map[person_to_benunit[pid]]
        row["household_id"] = household_id_map[person_to_household[pid]]
        for var, val in (fields or {}).items():
            if var == "members":
                continue
            resolved = _resolve_period_value(val, year)
            if var == "gender" and isinstance(resolved, str):
                resolved = resolved.lower()
            row[var] = resolved
        person_rows.append(row)

    # Give every benunit/household a head unless the situation already
    # designated one explicitly.
    _assign_default_heads(people, person_rows, "is_benunit_head", "benunit_id")
    _assign_default_heads(people, person_rows, "is_household_head", "household_id")

    benunit_rows = []
    for bid, fields in benunits.items():
        known = [m for m in _unit_members(fields) if m in person_id_map]
        row = dict(BENUNIT_DEFAULTS)
        row["benunit_id"] = benunit_id_map[bid]
        # A benunit belongs to the household of its first known member. Every
        # known person has a household here because the loop over `people`
        # above raised otherwise. A benunit without members falls back to 0.
        row["household_id"] = (
            household_id_map[person_to_household[known[0]]] if known else 0
        )
        row["person_ids"] = ";".join(str(person_id_map[m]) for m in known)
        for var, val in (fields or {}).items():
            if var == "members":
                continue
            row[var] = _resolve_period_value(val, year)
        benunit_rows.append(row)

    household_rows = []
    for hid, fields in households.items():
        known = [m for m in _unit_members(fields) if m in person_id_map]
        member_benunits = sorted(
            {benunit_id_map[person_to_benunit[m]] for m in known}
        )
        row = dict(HOUSEHOLD_DEFAULTS)
        row["household_id"] = household_id_map[hid]
        row["person_ids"] = ";".join(str(person_id_map[m]) for m in known)
        row["benunit_ids"] = ";".join(str(i) for i in member_benunits)
        for var, val in (fields or {}).items():
            if var == "members":
                continue
            resolved = _resolve_period_value(val, year)
            if var == "region" and isinstance(resolved, str):
                resolved = _REGION_CANONICAL.get(resolved, resolved)
            row[var] = resolved
        household_rows.append(row)

    # Propagate `is_in_scotland` from each person's household region unless
    # the situation already set it explicitly.
    region_by_household = {h["household_id"]: h.get("region") for h in household_rows}
    for (pid, fields), row in zip(people.items(), person_rows):
        if "is_in_scotland" in (fields or {}):
            continue
        row["is_in_scotland"] = (
            region_by_household.get(row["household_id"]) == "Scotland"
        )

    return (
        pd.DataFrame(person_rows),
        pd.DataFrame(benunit_rows),
        pd.DataFrame(household_rows),
    )


def _parse_microdata_stdout(raw: str) -> MicrodataResult:
"""Parse the concatenated CSV protocol output into a MicrodataResult."""
sections = {}
Expand Down Expand Up @@ -612,6 +808,66 @@ def get_baseline_params(self, timeout: int = 10) -> dict:

# ── Convenience constructors for hypothetical households ──────────────

@staticmethod
def from_situation(situation: dict, year: int = 2025, **kwargs) -> "Simulation":
    """Create a Simulation from a PolicyEngine situation-JSON dict.

    ``situation`` follows the PolicyEngine web-app layout::

        {
            "people": {"<id>": {"<var>": {"<period>": <value>}, ...}, ...},
            "benunits": {"<id>": {"members": [...], "<var>": ..., ...}, ...},
            "households": {"<id>": {"members": [...], "<var>": ..., ...}, ...},
        }

    A variable may be given as a period-keyed dict (``{"2025": 50000}``)
    or as a plain scalar; a scalar is taken to apply to ``year``.

    Variable names are the wrapper's input column names (see
    ``PERSON_DEFAULTS``, ``BENUNIT_DEFAULTS``, ``HOUSEHOLD_DEFAULTS``).
    ``region`` may use the title-case spelling (``"London"``,
    ``"North East"``) or the web app's upper-snake spelling
    (``"LONDON"``, ``"NORTH_EAST"``); it is normalised before reaching
    the Rust engine, and ``is_in_scotland`` is filled in automatically.
    ``gender`` is matched case-insensitively.

    The ``members`` lists of benunits and households refer to the keys of
    ``situation["people"]``. People get integer ``person_id`` values in
    the order they are listed under ``people``, and benunits/households
    are given the ``person_ids`` / ``benunit_ids`` strings the engine
    expects.

    Any extra keyword arguments are forwarded to the ``Simulation``
    constructor.

    Example::

        sim = Simulation.from_situation(
            {
                "people": {
                    "you": {"age": 30, "employment_income": {"2025": 50000}},
                },
                "benunits": {"yours": {"members": ["you"]}},
                "households": {"yours": {"members": ["you"], "region": "LONDON"}},
            },
            year=2025,
        )
        result = sim.run()
    """
    persons, benunits, households = _situation_to_dataframes(situation, year)
    frames = {
        "persons": persons,
        "benunits": benunits,
        "households": households,
    }
    return Simulation(year=year, **frames, **kwargs)

@staticmethod
def single_person(
age: float = 30,
Expand Down
Empty file.
Loading