From 17ad3712c9327ac3602e8578fcf08851c6da8c03 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Sun, 8 Mar 2026 19:27:52 +0100 Subject: [PATCH 1/2] feat: Add filter_strategy column and strategy reconstruction for UK districts - Alembic migration adding filter_strategy to regions and simulations tables - Strategy reconstruction utility to rebuild ScopingStrategy objects from DB columns - Updated analysis endpoints to extract and pass filter_strategy - Updated seed script to derive filter_strategy from policyengine.py regions - Comprehensive unit tests for strategy reconstruction and filter_strategy integration Co-Authored-By: Claude Opus 4.6 --- .../20260308_add_filter_strategy_column.py | 78 ++++ scripts/seed_regions.py | 10 + src/policyengine_api/api/analysis.py | 62 ++- src/policyengine_api/models/region.py | 1 + src/policyengine_api/models/simulation.py | 5 + src/policyengine_api/utils/__init__.py | 0 .../utils/strategy_reconstruction.py | 78 ++++ test_fixtures/fixtures_regions.py | 58 +++ .../fixtures_strategy_reconstruction.py | 58 +++ tests/test_analysis.py | 358 ++++++++++++++++++ tests/test_strategy_reconstruction.py | 335 ++++++++++++++++ 11 files changed, 1041 insertions(+), 2 deletions(-) create mode 100644 alembic/versions/20260308_add_filter_strategy_column.py create mode 100644 src/policyengine_api/utils/__init__.py create mode 100644 src/policyengine_api/utils/strategy_reconstruction.py create mode 100644 test_fixtures/fixtures_strategy_reconstruction.py create mode 100644 tests/test_strategy_reconstruction.py diff --git a/alembic/versions/20260308_add_filter_strategy_column.py b/alembic/versions/20260308_add_filter_strategy_column.py new file mode 100644 index 0000000..95a9f12 --- /dev/null +++ b/alembic/versions/20260308_add_filter_strategy_column.py @@ -0,0 +1,78 @@ +"""add filter_strategy to regions and simulations + +Revision ID: add_filter_strategy +Revises: 886921687770 +Create Date: 2026-03-08 + +Adds filter_strategy column to regions and simulations tables. +Values are 'row_filter' or 'weight_replacement', indicating which +scoping strategy to use when running simulations for that region. + +Data migration: +- Existing regions with filter_field != 'household_weight' -> 'row_filter' +- Existing regions with filter_field = 'household_weight' -> 'weight_replacement' +- Simulations inherit from their region's strategy +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "add_filter_strategy" +down_revision: Union[str, Sequence[str], None] = "886921687770" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add filter_strategy column and backfill existing data.""" + # Add column to regions + op.add_column( + "regions", + sa.Column("filter_strategy", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + ) + + # Add column to simulations + op.add_column( + "simulations", + sa.Column("filter_strategy", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + ) + + # Backfill regions: set strategy based on existing filter_field + conn = op.get_bind() + + # Regions with filter_field = 'household_weight' use weight replacement + conn.execute( + sa.text( + "UPDATE regions SET filter_strategy = 'weight_replacement' " + "WHERE filter_field = 'household_weight'" + ) + ) + + # Regions with other non-null filter_field use row filtering + conn.execute( + sa.text( + "UPDATE regions SET filter_strategy = 'row_filter' " + "WHERE filter_field IS NOT NULL AND filter_field != 'household_weight'" + ) + ) + + # Backfill simulations based on their region's strategy + conn.execute( + sa.text( + "UPDATE simulations SET filter_strategy = regions.filter_strategy " + "FROM regions " + "WHERE simulations.region_id = regions.id " + "AND regions.filter_strategy IS NOT NULL" + ) + ) + + +def downgrade() -> None: + """Remove filter_strategy columns.""" + op.drop_column("simulations", "filter_strategy") + op.drop_column("regions", "filter_strategy") diff --git a/scripts/seed_regions.py b/scripts/seed_regions.py index e180e7a..1128505 100644 --- a/scripts/seed_regions.py +++ b/scripts/seed_regions.py @@ -199,6 +199,11 @@ def seed_us_regions( requires_filter=pe_region.requires_filter, filter_field=pe_region.filter_field, filter_value=pe_region.filter_value, + filter_strategy=( + pe_region.scoping_strategy.strategy_type + if pe_region.scoping_strategy + else None + ), parent_code=pe_region.parent_code, state_code=pe_region.state_code, state_name=pe_region.state_name, @@ -292,6 +297,11 @@ def seed_uk_regions(session: Session) -> tuple[int, int, int]: requires_filter=pe_region.requires_filter, filter_field=pe_region.filter_field, filter_value=pe_region.filter_value, + filter_strategy=( + pe_region.scoping_strategy.strategy_type + if pe_region.scoping_strategy + else None + ), parent_code=pe_region.parent_code, state_code=None, state_name=None, diff --git a/src/policyengine_api/api/analysis.py b/src/policyengine_api/api/analysis.py index a46eaf2..729e792 100644 --- a/src/policyengine_api/api/analysis.py +++ b/src/policyengine_api/api/analysis.py @@ -251,10 +251,15 @@ def _get_deterministic_simulation_id( household_id: UUID | None = None, filter_field: str | None = None, filter_value: str | None = None, + filter_strategy: str | None = None, ) -> UUID: """Generate a deterministic UUID from simulation parameters.""" if simulation_type == SimulationType.ECONOMY: key = f"economy:{dataset_id}:{model_version_id}:{policy_id}:{dynamic_id}:{filter_field}:{filter_value}" + # Only append filter_strategy when non-null to preserve backward + # compatibility with existing simulation IDs + if filter_strategy is not None: + key += f":{filter_strategy}" else: key = f"household:{household_id}:{model_version_id}:{policy_id}:{dynamic_id}" return uuid5(SIMULATION_NAMESPACE, key) @@ -279,6 +284,7 @@ def _get_or_create_simulation( household_id: UUID | None = None, filter_field: str | None = None, filter_value: str | None = None, + filter_strategy: str | None = None, region_id: UUID | None = None, year: int | None = None, ) -> Simulation: @@ -292,6 +298,7 @@ def _get_or_create_simulation( household_id=household_id, filter_field=filter_field, filter_value=filter_value, + filter_strategy=filter_strategy, ) existing = session.get(Simulation, sim_id) @@ -309,6 +316,7 @@ def _get_or_create_simulation( status=SimulationStatus.PENDING, filter_field=filter_field, filter_value=filter_value, + filter_strategy=filter_strategy, region_id=region_id, year=year, ) @@ -846,12 +854,32 @@ def build_dynamic(dynamic_id): year=dataset.year, ) - # Run simulations (with optional regional filtering) + # Reconstruct scoping strategy from DB columns (if applicable) + from policyengine_api.utils.strategy_reconstruction import reconstruct_strategy + + baseline_region = session.get(Region, baseline_sim.region_id) if baseline_sim.region_id else None + baseline_strategy = reconstruct_strategy( + filter_strategy=baseline_sim.filter_strategy, + filter_field=baseline_sim.filter_field, + filter_value=baseline_sim.filter_value, + region_type=baseline_region.region_type.value if baseline_region else None, + ) + + reform_region = session.get(Region, reform_sim.region_id) if reform_sim.region_id else None + reform_strategy = reconstruct_strategy( + filter_strategy=reform_sim.filter_strategy, + filter_field=reform_sim.filter_field, + filter_value=reform_sim.filter_value, + region_type=reform_region.region_type.value if reform_region else None, + ) + + # Run simulations (with optional regional scoping) pe_baseline_sim = PESimulation( dataset=pe_dataset, tax_benefit_model_version=pe_model_version, policy=baseline_policy, dynamic=baseline_dynamic, + scoping_strategy=baseline_strategy, filter_field=baseline_sim.filter_field, filter_value=baseline_sim.filter_value, ) @@ -862,6 +890,7 @@ def build_dynamic(dynamic_id): tax_benefit_model_version=pe_model_version, policy=reform_policy, dynamic=reform_dynamic, + scoping_strategy=reform_strategy, filter_field=reform_sim.filter_field, filter_value=reform_sim.filter_value, ) @@ -1006,12 +1035,32 @@ def build_dynamic(dynamic_id): year=dataset.year, ) - # Run simulations (with optional regional filtering) + # Reconstruct scoping strategy from DB columns (if applicable) + from policyengine_api.utils.strategy_reconstruction import reconstruct_strategy + + baseline_region = session.get(Region, baseline_sim.region_id) if baseline_sim.region_id else None + baseline_strategy = reconstruct_strategy( + filter_strategy=baseline_sim.filter_strategy, + filter_field=baseline_sim.filter_field, + filter_value=baseline_sim.filter_value, + region_type=baseline_region.region_type.value if baseline_region else None, + ) + + reform_region = session.get(Region, reform_sim.region_id) if reform_sim.region_id else None + reform_strategy = reconstruct_strategy( + filter_strategy=reform_sim.filter_strategy, + filter_field=reform_sim.filter_field, + filter_value=reform_sim.filter_value, + region_type=reform_region.region_type.value if reform_region else None, + ) + + # Run simulations (with optional regional scoping) pe_baseline_sim = PESimulation( dataset=pe_dataset, tax_benefit_model_version=pe_model_version, policy=baseline_policy, dynamic=baseline_dynamic, + scoping_strategy=baseline_strategy, filter_field=baseline_sim.filter_field, filter_value=baseline_sim.filter_value, ) @@ -1022,6 +1071,7 @@ def build_dynamic(dynamic_id): tax_benefit_model_version=pe_model_version, policy=reform_policy, dynamic=reform_dynamic, + scoping_strategy=reform_strategy, filter_field=reform_sim.filter_field, filter_value=reform_sim.filter_value, ) @@ -1199,6 +1249,7 @@ def economic_impact( # Extract filter parameters from region (if present) filter_field = region.filter_field if region and region.requires_filter else None filter_value = region.filter_value if region and region.requires_filter else None + filter_strategy = region.filter_strategy if region and region.requires_filter else None # Get model version model_version = _get_model_version(request.tax_benefit_model_name, session) @@ -1213,6 +1264,7 @@ def economic_impact( dataset_id=dataset.id, filter_field=filter_field, filter_value=filter_value, + filter_strategy=filter_strategy, region_id=region.id if region else None, year=dataset.year, ) @@ -1226,6 +1278,7 @@ def economic_impact( dataset_id=dataset.id, filter_field=filter_field, filter_value=filter_value, + filter_strategy=filter_strategy, region_id=region.id if region else None, year=dataset.year, ) @@ -1392,6 +1445,9 @@ def economy_custom( filter_value = ( region_obj.filter_value if region_obj and region_obj.requires_filter else None ) + filter_strategy = ( + region_obj.filter_strategy if region_obj and region_obj.requires_filter else None + ) model_version = _get_model_version(request.tax_benefit_model_name, session) @@ -1404,6 +1460,7 @@ def economy_custom( dataset_id=dataset.id, filter_field=filter_field, filter_value=filter_value, + filter_strategy=filter_strategy, region_id=region_obj.id if region_obj else None, year=dataset.year, ) @@ -1417,6 +1474,7 @@ def economy_custom( dataset_id=dataset.id, filter_field=filter_field, filter_value=filter_value, + filter_strategy=filter_strategy, region_id=region_obj.id if region_obj else None, year=dataset.year, ) diff --git a/src/policyengine_api/models/region.py b/src/policyengine_api/models/region.py index 29c2785..0458284 100644 --- a/src/policyengine_api/models/region.py +++ b/src/policyengine_api/models/region.py @@ -37,6 +37,7 @@ class RegionBase(SQLModel): requires_filter: bool = False filter_field: str | None = None # e.g., "state_code", "place_fips" filter_value: str | None = None # e.g., "CA", "44000" + filter_strategy: str | None = None # "row_filter" or "weight_replacement" parent_code: str | None = None # e.g., "us", "state/ca" state_code: str | None = None # For US regions state_name: str | None = None # For US regions diff --git a/src/policyengine_api/models/simulation.py b/src/policyengine_api/models/simulation.py index 132a560..f95100c 100644 --- a/src/policyengine_api/models/simulation.py +++ b/src/policyengine_api/models/simulation.py @@ -60,6 +60,10 @@ class SimulationBase(SQLModel): default=None, description="Value to match when filtering (e.g., '44000', 'ENGLAND')", ) + filter_strategy: str | None = Field( + default=None, + description="Scoping strategy: 'row_filter' or 'weight_replacement'", + ) year: int | None = None @@ -118,6 +122,7 @@ class SimulationCreate(SQLModel): region_id: UUID | None = None filter_field: str | None = None filter_value: str | None = None + filter_strategy: str | None = None year: int | None = None @model_validator(mode="after") diff --git a/src/policyengine_api/utils/__init__.py b/src/policyengine_api/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/policyengine_api/utils/strategy_reconstruction.py b/src/policyengine_api/utils/strategy_reconstruction.py new file mode 100644 index 0000000..30d7711 --- /dev/null +++ b/src/policyengine_api/utils/strategy_reconstruction.py @@ -0,0 +1,78 @@ +"""Reconstruct policyengine.py scoping strategy objects from DB columns. + +Rather than storing serialized strategy objects in the database, we store +a simple filter_strategy string ('row_filter' or 'weight_replacement') +and reconstruct the full strategy object at runtime using the existing +filter_field, filter_value, and region_type columns plus a constant +config mapping for weight matrix locations. +""" + +from policyengine.core.scoping_strategy import ( + RowFilterStrategy, + ScopingStrategy, + WeightReplacementStrategy, +) + +# GCS locations for weight matrices, keyed by region type +WEIGHT_MATRIX_CONFIG: dict[str, dict[str, str]] = { + "constituency": { + "weight_matrix_bucket": "policyengine-uk-data-private", + "weight_matrix_key": "parliamentary_constituency_weights.h5", + "lookup_csv_bucket": "policyengine-uk-data-private", + "lookup_csv_key": "constituencies_2024.csv", + }, + "local_authority": { + "weight_matrix_bucket": "policyengine-uk-data-private", + "weight_matrix_key": "local_authority_weights.h5", + "lookup_csv_bucket": "policyengine-uk-data-private", + "lookup_csv_key": "local_authorities_2021.csv", + }, +} + + +def reconstruct_strategy( + filter_strategy: str | None, + filter_field: str | None, + filter_value: str | None, + region_type: str | None, +) -> ScopingStrategy | None: + """Reconstruct a ScopingStrategy from DB columns. + + Args: + filter_strategy: Strategy type ('row_filter' or 'weight_replacement'). + filter_field: The household variable name (for row_filter). + filter_value: The value to match or region code. + region_type: The region type (e.g., 'constituency', 'local_authority'). + + Returns: + A ScopingStrategy instance, or None if no strategy is needed. + """ + if filter_strategy is None: + return None + + if filter_strategy == "row_filter": + if not filter_field or not filter_value: + return None + return RowFilterStrategy( + variable_name=filter_field, + variable_value=filter_value, + ) + + if filter_strategy == "weight_replacement": + if not filter_value or not region_type: + return None + config = WEIGHT_MATRIX_CONFIG.get(region_type) + if not config: + raise ValueError( + f"No weight matrix config for region type '{region_type}'. " + f"Known types: {list(WEIGHT_MATRIX_CONFIG.keys())}" + ) + return WeightReplacementStrategy( + region_code=filter_value, + **config, + ) + + raise ValueError( + f"Unknown filter_strategy '{filter_strategy}'. " + f"Expected 'row_filter' or 'weight_replacement'." + ) diff --git a/test_fixtures/fixtures_regions.py b/test_fixtures/fixtures_regions.py index db56633..3d95dfb 100644 --- a/test_fixtures/fixtures_regions.py +++ b/test_fixtures/fixtures_regions.py @@ -115,6 +115,7 @@ def create_region( requires_filter: bool = False, filter_field: str | None = None, filter_value: str | None = None, + filter_strategy: str | None = None, ) -> Region: """Create and persist a Region with a dataset link.""" region = Region( @@ -124,6 +125,7 @@ def create_region( requires_filter=requires_filter, filter_field=filter_field, filter_value=filter_value, + filter_strategy=filter_strategy, tax_benefit_model_id=model.id, ) session.add(region) @@ -144,6 +146,7 @@ def create_simulation( model_version: TaxBenefitModelVersion, filter_field: str | None = None, filter_value: str | None = None, + filter_strategy: str | None = None, status: SimulationStatus = SimulationStatus.PENDING, ) -> Simulation: """Create and persist a Simulation with optional filter parameters.""" @@ -153,6 +156,7 @@ def create_simulation( status=status, filter_field=filter_field, filter_value=filter_value, + filter_strategy=filter_strategy, ) session.add(simulation) session.commit() @@ -263,3 +267,57 @@ def us_region_california(session, us_model_and_version, us_dataset): filter_field="state_code", filter_value="CA", ) + + +@pytest.fixture +def uk_region_england_with_strategy(session, uk_model_and_version, uk_dataset): + """Create England region with row_filter strategy.""" + model, _ = uk_model_and_version + return create_region( + session, + model=model, + dataset=uk_dataset, + code="country/england", + label="England", + region_type="country", + requires_filter=True, + filter_field="country", + filter_value="ENGLAND", + filter_strategy="row_filter", + ) + + +@pytest.fixture +def uk_region_constituency(session, uk_model_and_version, uk_dataset): + """Create a UK constituency region with weight_replacement strategy.""" + model, _ = uk_model_and_version + return create_region( + session, + model=model, + dataset=uk_dataset, + code="constituency/sheffield-central", + label="Sheffield Central", + region_type="constituency", + requires_filter=True, + filter_field=None, + filter_value="E14001551", + filter_strategy="weight_replacement", + ) + + +@pytest.fixture +def uk_region_local_authority(session, uk_model_and_version, uk_dataset): + """Create a UK local authority region with weight_replacement strategy.""" + model, _ = uk_model_and_version + return create_region( + session, + model=model, + dataset=uk_dataset, + code="local_authority/manchester", + label="Manchester", + region_type="local_authority", + requires_filter=True, + filter_field=None, + filter_value="E09000003", + filter_strategy="weight_replacement", + ) diff --git a/test_fixtures/fixtures_strategy_reconstruction.py b/test_fixtures/fixtures_strategy_reconstruction.py new file mode 100644 index 0000000..2dcca9c --- /dev/null +++ b/test_fixtures/fixtures_strategy_reconstruction.py @@ -0,0 +1,58 @@ +"""Fixtures and constants for strategy reconstruction tests.""" + +# ----------------------------------------------------------------------------- +# Filter Strategy Constants +# ----------------------------------------------------------------------------- + +FILTER_STRATEGIES = { + "ROW_FILTER": "row_filter", + "WEIGHT_REPLACEMENT": "weight_replacement", +} + +# ----------------------------------------------------------------------------- +# Region Type Constants +# ----------------------------------------------------------------------------- + +REGION_TYPES = { + "CONSTITUENCY": "constituency", + "LOCAL_AUTHORITY": "local_authority", + "COUNTRY": "country", + "STATE": "state", + "NATIONAL": "national", +} + +# ----------------------------------------------------------------------------- +# Filter Field / Value Constants +# ----------------------------------------------------------------------------- + +FILTER_FIELDS = { + "COUNTRY": "country", + "STATE_CODE": "state_code", + "PLACE_FIPS": "place_fips", +} + +FILTER_VALUES = { + "ENGLAND": "ENGLAND", + "SCOTLAND": "SCOTLAND", + "CALIFORNIA": "CA", + "SHEFFIELD_CENTRAL": "E14001551", + "MANCHESTER": "E09000003", +} + +# ----------------------------------------------------------------------------- +# GCS Config Constants (expected in WEIGHT_MATRIX_CONFIG) +# ----------------------------------------------------------------------------- + +EXPECTED_CONSTITUENCY_CONFIG = { + "weight_matrix_bucket": "policyengine-uk-data-private", + "weight_matrix_key": "parliamentary_constituency_weights.h5", + "lookup_csv_bucket": "policyengine-uk-data-private", + "lookup_csv_key": "constituencies_2024.csv", +} + +EXPECTED_LOCAL_AUTHORITY_CONFIG = { + "weight_matrix_bucket": "policyengine-uk-data-private", + "weight_matrix_key": "local_authority_weights.h5", + "lookup_csv_bucket": "policyengine-uk-data-private", + "lookup_csv_key": "local_authorities_2021.csv", +} diff --git a/tests/test_analysis.py b/tests/test_analysis.py index b13659b..c977996 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -32,6 +32,9 @@ create_tax_benefit_model, create_tax_benefit_model_version, ) +from test_fixtures.fixtures_strategy_reconstruction import ( + FILTER_STRATEGIES, +) client = TestClient(app) @@ -491,6 +494,126 @@ def test__given_null_optional_params__then_consistent_id(self): assert id1 == id2 + def test__given_filter_strategy_none__then_backward_compatible_id(self): + """Passing filter_strategy=None should produce the same ID as omitting it.""" + dataset_id = uuid4() + model_version_id = uuid4() + + # Given — ID computed without filter_strategy param + id_without = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=dataset_id, + filter_field="country", + filter_value="ENGLAND", + ) + + # When — ID computed with filter_strategy=None + id_with_none = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=dataset_id, + filter_field="country", + filter_value="ENGLAND", + filter_strategy=None, + ) + + # Then + assert id_without == id_with_none + + def test__given_different_filter_strategy__then_different_id(self): + dataset_id = uuid4() + model_version_id = uuid4() + + # Given + id_row_filter = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=dataset_id, + filter_field="country", + filter_value="ENGLAND", + filter_strategy=FILTER_STRATEGIES["ROW_FILTER"], + ) + + # When + id_weight_replacement = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=dataset_id, + filter_field="country", + filter_value="ENGLAND", + filter_strategy=FILTER_STRATEGIES["WEIGHT_REPLACEMENT"], + ) + + # Then + assert id_row_filter != id_weight_replacement + + def test__given_filter_strategy_set__then_different_from_none(self): + dataset_id = uuid4() + model_version_id = uuid4() + + # Given + id_no_strategy = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=dataset_id, + filter_field="country", + filter_value="ENGLAND", + filter_strategy=None, + ) + + # When + id_with_strategy = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=dataset_id, + filter_field="country", + filter_value="ENGLAND", + filter_strategy=FILTER_STRATEGIES["ROW_FILTER"], + ) + + # Then + assert id_no_strategy != id_with_strategy + + def test__given_same_filter_strategy__then_same_id(self): + dataset_id = uuid4() + model_version_id = uuid4() + + id1 = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=dataset_id, + filter_field="country", + filter_value="ENGLAND", + filter_strategy=FILTER_STRATEGIES["ROW_FILTER"], + ) + id2 = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=dataset_id, + filter_field="country", + filter_value="ENGLAND", + filter_strategy=FILTER_STRATEGIES["ROW_FILTER"], + ) + + assert id1 == id2 + # --------------------------------------------------------------------------- # _get_or_create_simulation @@ -652,6 +775,241 @@ def test__given_no_filter_params__then_simulation_has_null_filter_fields( assert simulation.filter_field is None assert simulation.filter_value is None + def test__given_filter_strategy__then_simulation_has_filter_strategy( + self, session: Session + ): + # Given + model = create_tax_benefit_model(session, name="policyengine-uk") + model_version = create_tax_benefit_model_version(session, model) + dataset = create_dataset(session, model, name="uk_enhanced_frs") + + # When + simulation = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field="country", + filter_value="ENGLAND", + filter_strategy=FILTER_STRATEGIES["ROW_FILTER"], + ) + + # Then + assert simulation.filter_strategy == FILTER_STRATEGIES["ROW_FILTER"] + + def test__given_weight_replacement_strategy__then_simulation_stores_strategy( + self, session: Session + ): + # Given + model = create_tax_benefit_model(session, name="policyengine-uk") + model_version = create_tax_benefit_model_version(session, model) + dataset = create_dataset(session, model, name="uk_enhanced_frs") + + # When + simulation = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_value="E14001551", + filter_strategy=FILTER_STRATEGIES["WEIGHT_REPLACEMENT"], + ) + + # Then + assert simulation.filter_strategy == FILTER_STRATEGIES["WEIGHT_REPLACEMENT"] + + def test__given_different_filter_strategy__then_creates_separate_simulations( + self, session: Session + ): + # Given + model = create_tax_benefit_model(session, name="policyengine-uk") + model_version = create_tax_benefit_model_version(session, model) + dataset = create_dataset(session, model, name="uk_enhanced_frs") + + # When + row_filter_sim = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field="country", + filter_value="ENGLAND", + filter_strategy=FILTER_STRATEGIES["ROW_FILTER"], + ) + weight_sim = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field="country", + filter_value="ENGLAND", + filter_strategy=FILTER_STRATEGIES["WEIGHT_REPLACEMENT"], + ) + + # Then + assert row_filter_sim.id != weight_sim.id + + def test__given_no_filter_strategy__then_simulation_has_null_filter_strategy( + self, session: Session + ): + # Given + model = create_tax_benefit_model(session, name="policyengine-uk") + model_version = create_tax_benefit_model_version(session, model) + dataset = create_dataset(session, model, name="uk_enhanced_frs") + + # When + simulation = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field="country", + filter_value="ENGLAND", + ) + + # Then + assert simulation.filter_strategy is None + + def test__given_same_filter_strategy__then_reuses_simulation( + self, session: Session + ): + # Given + model = create_tax_benefit_model(session, name="policyengine-uk") + model_version = create_tax_benefit_model_version(session, model) + dataset = create_dataset(session, model, name="uk_enhanced_frs") + + # When + first = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field="country", + filter_value="ENGLAND", + filter_strategy=FILTER_STRATEGIES["ROW_FILTER"], + ) + second = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field="country", + filter_value="ENGLAND", + filter_strategy=FILTER_STRATEGIES["ROW_FILTER"], + ) + + # Then + assert first.id == second.id + + +# --------------------------------------------------------------------------- +# _resolve_dataset_and_region (filter_strategy) +# --------------------------------------------------------------------------- + + +class TestResolveDatasetAndRegionFilterStrategy: + """Tests for filter_strategy extraction in _resolve_dataset_and_region.""" + + def test__given_region_with_row_filter_strategy__then_region_has_filter_strategy( + self, session: Session + ): + # Given + model = create_tax_benefit_model(session, name="policyengine-uk") + dataset = create_dataset(session, model, name="uk_enhanced_frs") + create_region( + session, + model=model, + dataset=dataset, + code="country/england", + label="England", + region_type="country", + requires_filter=True, + filter_field="country", + filter_value="ENGLAND", + filter_strategy=FILTER_STRATEGIES["ROW_FILTER"], + ) + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + region="country/england", + ) + + # When + _, resolved_region = _resolve_dataset_and_region(request, session) + + # Then + assert resolved_region is not None + assert resolved_region.filter_strategy == FILTER_STRATEGIES["ROW_FILTER"] + + def test__given_constituency_region__then_region_has_weight_replacement_strategy( + self, session: Session + ): + # Given + model = create_tax_benefit_model(session, name="policyengine-uk") + dataset = create_dataset(session, model, name="uk_enhanced_frs") + create_region( + session, + model=model, + dataset=dataset, + code="constituency/sheffield-central", + label="Sheffield Central", + region_type="constituency", + requires_filter=True, + filter_value="E14001551", + filter_strategy=FILTER_STRATEGIES["WEIGHT_REPLACEMENT"], + ) + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + region="constituency/sheffield-central", + ) + + # When + _, resolved_region = _resolve_dataset_and_region(request, session) + + # Then + assert resolved_region is not None + assert resolved_region.filter_strategy == FILTER_STRATEGIES["WEIGHT_REPLACEMENT"] + + def test__given_national_region__then_filter_strategy_is_none( + self, session: Session + ): + # Given + model = create_tax_benefit_model(session, name="policyengine-uk") + dataset = create_dataset(session, model, name="uk_enhanced_frs") + create_region( + session, + model=model, + dataset=dataset, + code="uk", + label="United Kingdom", + region_type="national", + requires_filter=False, + ) + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + region="uk", + ) + + # When + _, resolved_region = _resolve_dataset_and_region(request, session) + + # Then + assert resolved_region is not None + assert resolved_region.filter_strategy is None + # --------------------------------------------------------------------------- # HTTP endpoint validation (no database required) diff --git a/tests/test_strategy_reconstruction.py b/tests/test_strategy_reconstruction.py new file mode 100644 index 0000000..2075443 --- /dev/null +++ b/tests/test_strategy_reconstruction.py @@ -0,0 +1,335 @@ +"""Tests for strategy reconstruction utility. + +Tests for policyengine_api.utils.strategy_reconstruction.reconstruct_strategy(), +which rebuilds policyengine.py ScopingStrategy objects from DB columns. +""" + +import pytest +from policyengine.core.scoping_strategy import ( + RowFilterStrategy, + WeightReplacementStrategy, +) + +from policyengine_api.utils.strategy_reconstruction import ( + WEIGHT_MATRIX_CONFIG, + reconstruct_strategy, +) +from test_fixtures.fixtures_strategy_reconstruction import ( + EXPECTED_CONSTITUENCY_CONFIG, + EXPECTED_LOCAL_AUTHORITY_CONFIG, + FILTER_FIELDS, + FILTER_STRATEGIES, + FILTER_VALUES, + REGION_TYPES, +) + + +# --------------------------------------------------------------------------- +# reconstruct_strategy — None / no-op cases +# --------------------------------------------------------------------------- + + +class TestReconstructStrategyNone: + """Tests for cases where reconstruct_strategy returns None.""" + + def test__given_none_filter_strategy__then_returns_none(self): + # Given + filter_strategy = None + + # When + result = reconstruct_strategy( + filter_strategy=filter_strategy, + filter_field=FILTER_FIELDS["COUNTRY"], + filter_value=FILTER_VALUES["ENGLAND"], + region_type=REGION_TYPES["COUNTRY"], + ) + + # Then + assert result is None + + def test__given_row_filter_without_filter_field__then_returns_none(self): + # Given + filter_field = None + + # When + result = reconstruct_strategy( + filter_strategy=FILTER_STRATEGIES["ROW_FILTER"], + filter_field=filter_field, + filter_value=FILTER_VALUES["ENGLAND"], + region_type=REGION_TYPES["COUNTRY"], + ) + + # Then + assert result is None + + def test__given_row_filter_without_filter_value__then_returns_none(self): + # Given + filter_value = None + + # When + result = reconstruct_strategy( + filter_strategy=FILTER_STRATEGIES["ROW_FILTER"], + filter_field=FILTER_FIELDS["COUNTRY"], + filter_value=filter_value, + region_type=REGION_TYPES["COUNTRY"], + ) + + # Then + assert result is None + + def test__given_weight_replacement_without_filter_value__then_returns_none(self): + # Given + filter_value = None + + # When + result = reconstruct_strategy( + filter_strategy=FILTER_STRATEGIES["WEIGHT_REPLACEMENT"], + filter_field=None, + filter_value=filter_value, + region_type=REGION_TYPES["CONSTITUENCY"], + ) + + # Then + assert result is None + + def test__given_weight_replacement_without_region_type__then_returns_none(self): + # Given + region_type = None + + # When + result = reconstruct_strategy( + filter_strategy=FILTER_STRATEGIES["WEIGHT_REPLACEMENT"], + filter_field=None, + filter_value=FILTER_VALUES["SHEFFIELD_CENTRAL"], + region_type=region_type, + ) + + # Then + assert result is None + + +# --------------------------------------------------------------------------- +# reconstruct_strategy — RowFilterStrategy +# --------------------------------------------------------------------------- + + +class TestReconstructStrategyRowFilter: + """Tests for RowFilterStrategy reconstruction.""" + + def test__given_row_filter_strategy__then_returns_row_filter_instance(self): + # Given / When + result = reconstruct_strategy( + filter_strategy=FILTER_STRATEGIES["ROW_FILTER"], + filter_field=FILTER_FIELDS["COUNTRY"], + filter_value=FILTER_VALUES["ENGLAND"], + region_type=REGION_TYPES["COUNTRY"], + ) + + # Then + assert isinstance(result, RowFilterStrategy) + + def test__given_row_filter_strategy__then_variable_name_matches(self): + # Given / When + result = reconstruct_strategy( + filter_strategy=FILTER_STRATEGIES["ROW_FILTER"], + filter_field=FILTER_FIELDS["COUNTRY"], + filter_value=FILTER_VALUES["ENGLAND"], + region_type=REGION_TYPES["COUNTRY"], + ) + + # Then + assert result.variable_name == FILTER_FIELDS["COUNTRY"] + + def test__given_row_filter_strategy__then_variable_value_matches(self): + # Given / When + result = reconstruct_strategy( + filter_strategy=FILTER_STRATEGIES["ROW_FILTER"], + filter_field=FILTER_FIELDS["COUNTRY"], + filter_value=FILTER_VALUES["ENGLAND"], + region_type=REGION_TYPES["COUNTRY"], + ) + + # Then + assert result.variable_value == FILTER_VALUES["ENGLAND"] + + def test__given_us_state_row_filter__then_returns_correct_strategy(self): + # Given / When + result = reconstruct_strategy( + filter_strategy=FILTER_STRATEGIES["ROW_FILTER"], + filter_field=FILTER_FIELDS["STATE_CODE"], + filter_value=FILTER_VALUES["CALIFORNIA"], + region_type=REGION_TYPES["STATE"], + ) + + # Then + assert isinstance(result, RowFilterStrategy) + assert result.variable_name == FILTER_FIELDS["STATE_CODE"] + assert result.variable_value == FILTER_VALUES["CALIFORNIA"] + + def test__given_place_fips_row_filter__then_returns_correct_strategy(self): + # Given + fips_value = "44000" + + # When + result = reconstruct_strategy( + filter_strategy=FILTER_STRATEGIES["ROW_FILTER"], + filter_field=FILTER_FIELDS["PLACE_FIPS"], + filter_value=fips_value, + region_type=REGION_TYPES["STATE"], + ) + + # Then + assert isinstance(result, RowFilterStrategy) + assert result.variable_name == FILTER_FIELDS["PLACE_FIPS"] + assert result.variable_value == fips_value + + +# --------------------------------------------------------------------------- +# reconstruct_strategy — WeightReplacementStrategy +# --------------------------------------------------------------------------- + + +class TestReconstructStrategyWeightReplacement: + """Tests for WeightReplacementStrategy reconstruction.""" + + def test__given_constituency_weight_replacement__then_returns_weight_replacement_instance( + self, + ): + # Given / When + result = reconstruct_strategy( + filter_strategy=FILTER_STRATEGIES["WEIGHT_REPLACEMENT"], + filter_field=None, + filter_value=FILTER_VALUES["SHEFFIELD_CENTRAL"], + region_type=REGION_TYPES["CONSTITUENCY"], + ) + + # Then + assert isinstance(result, WeightReplacementStrategy) + + def test__given_constituency_weight_replacement__then_region_code_matches(self): + # Given / When + result = reconstruct_strategy( + filter_strategy=FILTER_STRATEGIES["WEIGHT_REPLACEMENT"], + filter_field=None, + filter_value=FILTER_VALUES["SHEFFIELD_CENTRAL"], + region_type=REGION_TYPES["CONSTITUENCY"], + ) + + # Then + assert result.region_code == FILTER_VALUES["SHEFFIELD_CENTRAL"] + + def test__given_constituency_weight_replacement__then_gcs_config_matches(self): + # Given / When + result = reconstruct_strategy( + filter_strategy=FILTER_STRATEGIES["WEIGHT_REPLACEMENT"], + filter_field=None, + filter_value=FILTER_VALUES["SHEFFIELD_CENTRAL"], + region_type=REGION_TYPES["CONSTITUENCY"], + ) + + # Then + assert result.weight_matrix_bucket == EXPECTED_CONSTITUENCY_CONFIG["weight_matrix_bucket"] + assert result.weight_matrix_key == EXPECTED_CONSTITUENCY_CONFIG["weight_matrix_key"] + assert result.lookup_csv_bucket == EXPECTED_CONSTITUENCY_CONFIG["lookup_csv_bucket"] + assert result.lookup_csv_key == EXPECTED_CONSTITUENCY_CONFIG["lookup_csv_key"] + + def test__given_local_authority_weight_replacement__then_returns_weight_replacement_instance( + self, + ): + # Given / When + result = reconstruct_strategy( + filter_strategy=FILTER_STRATEGIES["WEIGHT_REPLACEMENT"], + filter_field=None, + filter_value=FILTER_VALUES["MANCHESTER"], + region_type=REGION_TYPES["LOCAL_AUTHORITY"], + ) + + # Then + assert isinstance(result, WeightReplacementStrategy) + + def test__given_local_authority_weight_replacement__then_gcs_config_matches(self): + # Given / When + result = reconstruct_strategy( + filter_strategy=FILTER_STRATEGIES["WEIGHT_REPLACEMENT"], + filter_field=None, + filter_value=FILTER_VALUES["MANCHESTER"], + region_type=REGION_TYPES["LOCAL_AUTHORITY"], + ) + + # Then + assert result.weight_matrix_bucket == EXPECTED_LOCAL_AUTHORITY_CONFIG["weight_matrix_bucket"] + assert result.weight_matrix_key == EXPECTED_LOCAL_AUTHORITY_CONFIG["weight_matrix_key"] + assert result.lookup_csv_bucket == EXPECTED_LOCAL_AUTHORITY_CONFIG["lookup_csv_bucket"] + assert result.lookup_csv_key == EXPECTED_LOCAL_AUTHORITY_CONFIG["lookup_csv_key"] + + +# --------------------------------------------------------------------------- +# reconstruct_strategy — error cases +# --------------------------------------------------------------------------- + + +class TestReconstructStrategyErrors: + """Tests for error handling in reconstruct_strategy.""" + + def test__given_unknown_filter_strategy__then_raises_value_error(self): + # Given + unknown_strategy = "magic_strategy" + + # When / Then + with pytest.raises(ValueError, match="Unknown filter_strategy"): + reconstruct_strategy( + filter_strategy=unknown_strategy, + filter_field=FILTER_FIELDS["COUNTRY"], + filter_value=FILTER_VALUES["ENGLAND"], + region_type=REGION_TYPES["COUNTRY"], + ) + + def test__given_weight_replacement_unknown_region_type__then_raises_value_error( + self, + ): + # Given + unknown_region_type = "province" + + # When / Then + with pytest.raises(ValueError, match="No weight matrix config"): + reconstruct_strategy( + filter_strategy=FILTER_STRATEGIES["WEIGHT_REPLACEMENT"], + filter_field=None, + filter_value=FILTER_VALUES["SHEFFIELD_CENTRAL"], + region_type=unknown_region_type, + ) + + +# --------------------------------------------------------------------------- +# WEIGHT_MATRIX_CONFIG — verify expected keys exist +# --------------------------------------------------------------------------- + + +class TestWeightMatrixConfig: + """Tests for the WEIGHT_MATRIX_CONFIG constant.""" + + def test__given_config__then_constituency_key_exists(self): + assert REGION_TYPES["CONSTITUENCY"] in WEIGHT_MATRIX_CONFIG + + def test__given_config__then_local_authority_key_exists(self): + assert REGION_TYPES["LOCAL_AUTHORITY"] in WEIGHT_MATRIX_CONFIG + + def test__given_constituency_config__then_has_all_required_keys(self): + config = WEIGHT_MATRIX_CONFIG[REGION_TYPES["CONSTITUENCY"]] + expected_keys = { + "weight_matrix_bucket", + "weight_matrix_key", + "lookup_csv_bucket", + "lookup_csv_key", + } + assert set(config.keys()) == expected_keys + + def test__given_local_authority_config__then_has_all_required_keys(self): + config = WEIGHT_MATRIX_CONFIG[REGION_TYPES["LOCAL_AUTHORITY"]] + expected_keys = { + "weight_matrix_bucket", + "weight_matrix_key", + "lookup_csv_bucket", + "lookup_csv_key", + } + assert set(config.keys()) == expected_keys From 3c3fe4fc545fbb8c30038ab290a50a7b445e9bd4 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Sun, 8 Mar 2026 19:54:25 +0100 Subject: [PATCH 2/2] fix: Defer policyengine.core.scoping_strategy imports for CI compat Move imports inside reconstruct_strategy() function body so the module can be loaded without the unpublished scoping_strategy module. Tests use mock strategy classes injected via monkeypatch when the real module is unavailable. Co-Authored-By: Claude Opus 4.6 --- .../utils/strategy_reconstruction.py | 17 +++-- tests/test_strategy_reconstruction.py | 67 ++++++++++++++++--- 2 files changed, 68 insertions(+), 16 deletions(-) diff --git a/src/policyengine_api/utils/strategy_reconstruction.py b/src/policyengine_api/utils/strategy_reconstruction.py index 30d7711..8c17331 100644 --- a/src/policyengine_api/utils/strategy_reconstruction.py +++ b/src/policyengine_api/utils/strategy_reconstruction.py @@ -7,12 +7,6 @@ config mapping for weight matrix locations. """ -from policyengine.core.scoping_strategy import ( - RowFilterStrategy, - ScopingStrategy, - WeightReplacementStrategy, -) - # GCS locations for weight matrices, keyed by region type WEIGHT_MATRIX_CONFIG: dict[str, dict[str, str]] = { "constituency": { @@ -35,9 +29,13 @@ def reconstruct_strategy( filter_field: str | None, filter_value: str | None, region_type: str | None, -) -> ScopingStrategy | None: +) -> object | None: """Reconstruct a ScopingStrategy from DB columns. + Imports from policyengine.core.scoping_strategy are deferred to avoid + import errors when the published policyengine package does not yet + include the scoping_strategy module. + Args: filter_strategy: Strategy type ('row_filter' or 'weight_replacement'). filter_field: The household variable name (for row_filter). @@ -50,6 +48,11 @@ def reconstruct_strategy( if filter_strategy is None: return None + from policyengine.core.scoping_strategy import ( + RowFilterStrategy, + WeightReplacementStrategy, + ) + if filter_strategy == "row_filter": if not filter_field or not filter_value: return None diff --git a/tests/test_strategy_reconstruction.py b/tests/test_strategy_reconstruction.py index 2075443..44379f0 100644 --- a/tests/test_strategy_reconstruction.py +++ b/tests/test_strategy_reconstruction.py @@ -2,13 +2,15 @@ Tests for policyengine_api.utils.strategy_reconstruction.reconstruct_strategy(), which rebuilds policyengine.py ScopingStrategy objects from DB columns. + +The scoping_strategy module may not exist in the published policyengine package, +so we provide mock strategy classes and inject them into sys.modules when needed. """ +import sys +from types import ModuleType + import pytest -from policyengine.core.scoping_strategy import ( - RowFilterStrategy, - WeightReplacementStrategy, -) from policyengine_api.utils.strategy_reconstruction import ( WEIGHT_MATRIX_CONFIG, @@ -24,6 +26,53 @@ ) +# --------------------------------------------------------------------------- +# Mock strategy classes (match the real constructor signatures) +# --------------------------------------------------------------------------- + + +class _MockRowFilterStrategy: + strategy_type = "row_filter" + + def __init__(self, *, variable_name: str, variable_value: str): + self.variable_name = variable_name + self.variable_value = variable_value + + +class _MockWeightReplacementStrategy: + strategy_type = "weight_replacement" + + def __init__( + self, + *, + region_code: str, + weight_matrix_bucket: str, + weight_matrix_key: str, + lookup_csv_bucket: str, + lookup_csv_key: str, + ): + self.region_code = region_code + self.weight_matrix_bucket = weight_matrix_bucket + self.weight_matrix_key = weight_matrix_key + self.lookup_csv_bucket = lookup_csv_bucket + self.lookup_csv_key = lookup_csv_key + + +@pytest.fixture(autouse=True) +def _ensure_scoping_strategy_module(monkeypatch): + """Inject a mock scoping_strategy module if the real one is not installed.""" + try: + from policyengine.core.scoping_strategy import ( # noqa: F401 + RowFilterStrategy, + WeightReplacementStrategy, + ) + except (ImportError, ModuleNotFoundError): + mock_mod = ModuleType("policyengine.core.scoping_strategy") + mock_mod.RowFilterStrategy = _MockRowFilterStrategy # type: ignore[attr-defined] + mock_mod.WeightReplacementStrategy = _MockWeightReplacementStrategy # type: ignore[attr-defined] + monkeypatch.setitem(sys.modules, "policyengine.core.scoping_strategy", mock_mod) + + # --------------------------------------------------------------------------- # reconstruct_strategy — None / no-op cases # --------------------------------------------------------------------------- @@ -126,7 +175,7 @@ def test__given_row_filter_strategy__then_returns_row_filter_instance(self): ) # Then - assert isinstance(result, RowFilterStrategy) + assert result.strategy_type == "row_filter" def test__given_row_filter_strategy__then_variable_name_matches(self): # Given / When @@ -162,7 +211,7 @@ def test__given_us_state_row_filter__then_returns_correct_strategy(self): ) # Then - assert isinstance(result, RowFilterStrategy) + assert result.strategy_type == "row_filter" assert result.variable_name == FILTER_FIELDS["STATE_CODE"] assert result.variable_value == FILTER_VALUES["CALIFORNIA"] @@ -179,7 +228,7 @@ def test__given_place_fips_row_filter__then_returns_correct_strategy(self): ) # Then - assert isinstance(result, RowFilterStrategy) + assert result.strategy_type == "row_filter" assert result.variable_name == FILTER_FIELDS["PLACE_FIPS"] assert result.variable_value == fips_value @@ -204,7 +253,7 @@ def test__given_constituency_weight_replacement__then_returns_weight_replacement ) # Then - assert isinstance(result, WeightReplacementStrategy) + assert result.strategy_type == "weight_replacement" def test__given_constituency_weight_replacement__then_region_code_matches(self): # Given / When @@ -245,7 +294,7 @@ def test__given_local_authority_weight_replacement__then_returns_weight_replacem ) # Then - assert isinstance(result, WeightReplacementStrategy) + assert result.strategy_type == "weight_replacement" def test__given_local_authority_weight_replacement__then_gcs_config_matches(self): # Given / When