Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions alembic/versions/20260308_add_filter_strategy_column.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""add filter_strategy to regions and simulations

Revision ID: add_filter_strategy
Revises: 886921687770
Create Date: 2026-03-08

Adds filter_strategy column to regions and simulations tables.
Values are 'row_filter' or 'weight_replacement', indicating which
scoping strategy to use when running simulations for that region.

Data migration:
- Existing regions with filter_field != 'household_weight' -> 'row_filter'
- Existing regions with filter_field = 'household_weight' -> 'weight_replacement'
- Simulations inherit from their region's strategy
"""

from typing import Sequence, Union

import sqlalchemy as sa
import sqlmodel.sql.sqltypes

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "add_filter_strategy"
down_revision: Union[str, Sequence[str], None] = "886921687770"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
"""Add filter_strategy column and backfill existing data."""
# Add column to regions
op.add_column(
"regions",
sa.Column("filter_strategy", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
)

# Add column to simulations
op.add_column(
"simulations",
sa.Column("filter_strategy", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
)

# Backfill regions: set strategy based on existing filter_field
conn = op.get_bind()

# Regions with filter_field = 'household_weight' use weight replacement
conn.execute(
sa.text(
"UPDATE regions SET filter_strategy = 'weight_replacement' "
"WHERE filter_field = 'household_weight'"
)
)

# Regions with other non-null filter_field use row filtering
conn.execute(
sa.text(
"UPDATE regions SET filter_strategy = 'row_filter' "
"WHERE filter_field IS NOT NULL AND filter_field != 'household_weight'"
)
)

# Backfill simulations based on their region's strategy
conn.execute(
sa.text(
"UPDATE simulations SET filter_strategy = regions.filter_strategy "
"FROM regions "
"WHERE simulations.region_id = regions.id "
"AND regions.filter_strategy IS NOT NULL"
)
)


def downgrade() -> None:
"""Remove filter_strategy columns."""
op.drop_column("simulations", "filter_strategy")
op.drop_column("regions", "filter_strategy")
10 changes: 10 additions & 0 deletions scripts/seed_regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,11 @@ def seed_us_regions(
requires_filter=pe_region.requires_filter,
filter_field=pe_region.filter_field,
filter_value=pe_region.filter_value,
filter_strategy=(
pe_region.scoping_strategy.strategy_type
if pe_region.scoping_strategy
else None
),
parent_code=pe_region.parent_code,
state_code=pe_region.state_code,
state_name=pe_region.state_name,
Expand Down Expand Up @@ -292,6 +297,11 @@ def seed_uk_regions(session: Session) -> tuple[int, int, int]:
requires_filter=pe_region.requires_filter,
filter_field=pe_region.filter_field,
filter_value=pe_region.filter_value,
filter_strategy=(
pe_region.scoping_strategy.strategy_type
if pe_region.scoping_strategy
else None
),
parent_code=pe_region.parent_code,
state_code=None,
state_name=None,
Expand Down
62 changes: 60 additions & 2 deletions src/policyengine_api/api/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,15 @@ def _get_deterministic_simulation_id(
household_id: UUID | None = None,
filter_field: str | None = None,
filter_value: str | None = None,
filter_strategy: str | None = None,
) -> UUID:
"""Generate a deterministic UUID from simulation parameters."""
if simulation_type == SimulationType.ECONOMY:
key = f"economy:{dataset_id}:{model_version_id}:{policy_id}:{dynamic_id}:{filter_field}:{filter_value}"
# Only append filter_strategy when non-null to preserve backward
# compatibility with existing simulation IDs
if filter_strategy is not None:
key += f":{filter_strategy}"
else:
key = f"household:{household_id}:{model_version_id}:{policy_id}:{dynamic_id}"
return uuid5(SIMULATION_NAMESPACE, key)
Expand All @@ -279,6 +284,7 @@ def _get_or_create_simulation(
household_id: UUID | None = None,
filter_field: str | None = None,
filter_value: str | None = None,
filter_strategy: str | None = None,
region_id: UUID | None = None,
year: int | None = None,
) -> Simulation:
Expand All @@ -292,6 +298,7 @@ def _get_or_create_simulation(
household_id=household_id,
filter_field=filter_field,
filter_value=filter_value,
filter_strategy=filter_strategy,
)

existing = session.get(Simulation, sim_id)
Expand All @@ -309,6 +316,7 @@ def _get_or_create_simulation(
status=SimulationStatus.PENDING,
filter_field=filter_field,
filter_value=filter_value,
filter_strategy=filter_strategy,
region_id=region_id,
year=year,
)
Expand Down Expand Up @@ -846,12 +854,32 @@ def build_dynamic(dynamic_id):
year=dataset.year,
)

# Run simulations (with optional regional filtering)
# Reconstruct scoping strategy from DB columns (if applicable)
from policyengine_api.utils.strategy_reconstruction import reconstruct_strategy

baseline_region = session.get(Region, baseline_sim.region_id) if baseline_sim.region_id else None
baseline_strategy = reconstruct_strategy(
filter_strategy=baseline_sim.filter_strategy,
filter_field=baseline_sim.filter_field,
filter_value=baseline_sim.filter_value,
region_type=baseline_region.region_type.value if baseline_region else None,
)

reform_region = session.get(Region, reform_sim.region_id) if reform_sim.region_id else None
reform_strategy = reconstruct_strategy(
filter_strategy=reform_sim.filter_strategy,
filter_field=reform_sim.filter_field,
filter_value=reform_sim.filter_value,
region_type=reform_region.region_type.value if reform_region else None,
)

# Run simulations (with optional regional scoping)
pe_baseline_sim = PESimulation(
dataset=pe_dataset,
tax_benefit_model_version=pe_model_version,
policy=baseline_policy,
dynamic=baseline_dynamic,
scoping_strategy=baseline_strategy,
filter_field=baseline_sim.filter_field,
filter_value=baseline_sim.filter_value,
)
Expand All @@ -862,6 +890,7 @@ def build_dynamic(dynamic_id):
tax_benefit_model_version=pe_model_version,
policy=reform_policy,
dynamic=reform_dynamic,
scoping_strategy=reform_strategy,
filter_field=reform_sim.filter_field,
filter_value=reform_sim.filter_value,
)
Expand Down Expand Up @@ -1006,12 +1035,32 @@ def build_dynamic(dynamic_id):
year=dataset.year,
)

# Run simulations (with optional regional filtering)
# Reconstruct scoping strategy from DB columns (if applicable)
from policyengine_api.utils.strategy_reconstruction import reconstruct_strategy

baseline_region = session.get(Region, baseline_sim.region_id) if baseline_sim.region_id else None
baseline_strategy = reconstruct_strategy(
filter_strategy=baseline_sim.filter_strategy,
filter_field=baseline_sim.filter_field,
filter_value=baseline_sim.filter_value,
region_type=baseline_region.region_type.value if baseline_region else None,
)

reform_region = session.get(Region, reform_sim.region_id) if reform_sim.region_id else None
reform_strategy = reconstruct_strategy(
filter_strategy=reform_sim.filter_strategy,
filter_field=reform_sim.filter_field,
filter_value=reform_sim.filter_value,
region_type=reform_region.region_type.value if reform_region else None,
)

# Run simulations (with optional regional scoping)
pe_baseline_sim = PESimulation(
dataset=pe_dataset,
tax_benefit_model_version=pe_model_version,
policy=baseline_policy,
dynamic=baseline_dynamic,
scoping_strategy=baseline_strategy,
filter_field=baseline_sim.filter_field,
filter_value=baseline_sim.filter_value,
)
Expand All @@ -1022,6 +1071,7 @@ def build_dynamic(dynamic_id):
tax_benefit_model_version=pe_model_version,
policy=reform_policy,
dynamic=reform_dynamic,
scoping_strategy=reform_strategy,
filter_field=reform_sim.filter_field,
filter_value=reform_sim.filter_value,
)
Expand Down Expand Up @@ -1199,6 +1249,7 @@ def economic_impact(
# Extract filter parameters from region (if present)
filter_field = region.filter_field if region and region.requires_filter else None
filter_value = region.filter_value if region and region.requires_filter else None
filter_strategy = region.filter_strategy if region and region.requires_filter else None

# Get model version
model_version = _get_model_version(request.tax_benefit_model_name, session)
Expand All @@ -1213,6 +1264,7 @@ def economic_impact(
dataset_id=dataset.id,
filter_field=filter_field,
filter_value=filter_value,
filter_strategy=filter_strategy,
region_id=region.id if region else None,
year=dataset.year,
)
Expand All @@ -1226,6 +1278,7 @@ def economic_impact(
dataset_id=dataset.id,
filter_field=filter_field,
filter_value=filter_value,
filter_strategy=filter_strategy,
region_id=region.id if region else None,
year=dataset.year,
)
Expand Down Expand Up @@ -1392,6 +1445,9 @@ def economy_custom(
filter_value = (
region_obj.filter_value if region_obj and region_obj.requires_filter else None
)
filter_strategy = (
region_obj.filter_strategy if region_obj and region_obj.requires_filter else None
)

model_version = _get_model_version(request.tax_benefit_model_name, session)

Expand All @@ -1404,6 +1460,7 @@ def economy_custom(
dataset_id=dataset.id,
filter_field=filter_field,
filter_value=filter_value,
filter_strategy=filter_strategy,
region_id=region_obj.id if region_obj else None,
year=dataset.year,
)
Expand All @@ -1417,6 +1474,7 @@ def economy_custom(
dataset_id=dataset.id,
filter_field=filter_field,
filter_value=filter_value,
filter_strategy=filter_strategy,
region_id=region_obj.id if region_obj else None,
year=dataset.year,
)
Expand Down
1 change: 1 addition & 0 deletions src/policyengine_api/models/region.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class RegionBase(SQLModel):
requires_filter: bool = False
filter_field: str | None = None # e.g., "state_code", "place_fips"
filter_value: str | None = None # e.g., "CA", "44000"
filter_strategy: str | None = None # "row_filter" or "weight_replacement"
parent_code: str | None = None # e.g., "us", "state/ca"
state_code: str | None = None # For US regions
state_name: str | None = None # For US regions
Expand Down
5 changes: 5 additions & 0 deletions src/policyengine_api/models/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ class SimulationBase(SQLModel):
default=None,
description="Value to match when filtering (e.g., '44000', 'ENGLAND')",
)
filter_strategy: str | None = Field(
default=None,
description="Scoping strategy: 'row_filter' or 'weight_replacement'",
)

year: int | None = None

Expand Down Expand Up @@ -118,6 +122,7 @@ class SimulationCreate(SQLModel):
region_id: UUID | None = None
filter_field: str | None = None
filter_value: str | None = None
filter_strategy: str | None = None
year: int | None = None

@model_validator(mode="after")
Expand Down
Empty file.
81 changes: 81 additions & 0 deletions src/policyengine_api/utils/strategy_reconstruction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Reconstruct policyengine.py scoping strategy objects from DB columns.

Rather than storing serialized strategy objects in the database, we store
a simple filter_strategy string ('row_filter' or 'weight_replacement')
and reconstruct the full strategy object at runtime using the existing
filter_field, filter_value, and region_type columns plus a constant
config mapping for weight matrix locations.
"""

# GCS locations for weight matrices, keyed by region type
WEIGHT_MATRIX_CONFIG: dict[str, dict[str, str]] = {
"constituency": {
"weight_matrix_bucket": "policyengine-uk-data-private",
"weight_matrix_key": "parliamentary_constituency_weights.h5",
"lookup_csv_bucket": "policyengine-uk-data-private",
"lookup_csv_key": "constituencies_2024.csv",
},
"local_authority": {
"weight_matrix_bucket": "policyengine-uk-data-private",
"weight_matrix_key": "local_authority_weights.h5",
"lookup_csv_bucket": "policyengine-uk-data-private",
"lookup_csv_key": "local_authorities_2021.csv",
},
}


def reconstruct_strategy(
filter_strategy: str | None,
filter_field: str | None,
filter_value: str | None,
region_type: str | None,
) -> object | None:
"""Reconstruct a ScopingStrategy from DB columns.

Imports from policyengine.core.scoping_strategy are deferred to avoid
import errors when the published policyengine package does not yet
include the scoping_strategy module.

Args:
filter_strategy: Strategy type ('row_filter' or 'weight_replacement').
filter_field: The household variable name (for row_filter).
filter_value: The value to match or region code.
region_type: The region type (e.g., 'constituency', 'local_authority').

Returns:
A ScopingStrategy instance, or None if no strategy is needed.
"""
if filter_strategy is None:
return None

from policyengine.core.scoping_strategy import (
RowFilterStrategy,
WeightReplacementStrategy,
)

if filter_strategy == "row_filter":
if not filter_field or not filter_value:
return None
return RowFilterStrategy(
variable_name=filter_field,
variable_value=filter_value,
)

if filter_strategy == "weight_replacement":
if not filter_value or not region_type:
return None
config = WEIGHT_MATRIX_CONFIG.get(region_type)
if not config:
raise ValueError(
f"No weight matrix config for region type '{region_type}'. "
f"Known types: {list(WEIGHT_MATRIX_CONFIG.keys())}"
)
return WeightReplacementStrategy(
region_code=filter_value,
**config,
)

raise ValueError(
f"Unknown filter_strategy '{filter_strategy}'. "
f"Expected 'row_filter' or 'weight_replacement'."
)
Loading