diff --git a/alembic/versions/20260303_886921687770_region_datasets_join_table_and_.py b/alembic/versions/20260303_886921687770_region_datasets_join_table_and_.py new file mode 100644 index 0000000..f95d0b6 --- /dev/null +++ b/alembic/versions/20260303_886921687770_region_datasets_join_table_and_.py @@ -0,0 +1,66 @@ +"""region_datasets_join_table_and_simulation_year + +Revision ID: 886921687770 +Revises: 963e91da9298 +Create Date: 2026-03-03 18:56:13.551288 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '886921687770' +down_revision: Union[str, Sequence[str], None] = '963e91da9298' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # Create the region_datasets join table + op.create_table('region_datasets', + sa.Column('region_id', sa.Uuid(), nullable=False), + sa.Column('dataset_id', sa.Uuid(), nullable=False), + sa.ForeignKeyConstraint(['dataset_id'], ['datasets.id'], ), + sa.ForeignKeyConstraint(['region_id'], ['regions.id'], ), + sa.PrimaryKeyConstraint('region_id', 'dataset_id') + ) + + # Migrate existing region->dataset links into the join table + op.execute(""" + INSERT INTO region_datasets (region_id, dataset_id) + SELECT id, dataset_id FROM regions + WHERE dataset_id IS NOT NULL + """) + + # Drop the old FK and column from regions + op.drop_constraint(op.f('regions_dataset_id_fkey'), 'regions', type_='foreignkey') + op.drop_column('regions', 'dataset_id') + + # Add year column to simulations + op.add_column('simulations', sa.Column('year', sa.Integer(), nullable=True)) + + +def downgrade() -> None: + """Downgrade schema.""" + op.drop_column('simulations', 'year') + op.add_column('regions', sa.Column('dataset_id', sa.UUID(), autoincrement=False, nullable=True)) + + # Migrate join table data back to the FK column (pick one dataset per region) + op.execute(""" + UPDATE regions r + SET dataset_id = rd.dataset_id + FROM ( + SELECT DISTINCT ON (region_id) region_id, dataset_id + FROM region_datasets + ORDER BY region_id + ) rd + WHERE r.id = rd.region_id + """) + + op.alter_column('regions', 'dataset_id', nullable=False) + op.create_foreign_key(op.f('regions_dataset_id_fkey'), 'regions', 'datasets', ['dataset_id'], ['id']) + op.drop_table('region_datasets') diff --git a/import_state_datasets.py b/import_state_datasets.py new file mode 100644 index 0000000..96fd7fa --- /dev/null +++ b/import_state_datasets.py @@ -0,0 +1,471 @@ +"""Download, convert, and upload state & congressional district datasets. + +One-off script to migrate state/district datasets from GCS (old format) +to Supabase (new yearly entity-level format). + +Downloads raw h5 files from GCS, converts them to yearly entity-level files +using policyengine's create_datasets(), uploads to Supabase, and creates +database records. + +Usage: + python import_state_datasets.py AL # State + all AL districts + python import_state_datasets.py CA NY TX # Multiple states + districts + python import_state_datasets.py --all # All 51 states + 436 districts + python import_state_datasets.py AL --state-only # State only, no districts + python import_state_datasets.py --years 2025,2026 + python import_state_datasets.py --skip-upload # Convert only + +Must be run from the policyengine-api-v2-alpha project root (where .env lives). 
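A note on the downgrade path: `DISTINCT ON (region_id)` with only `ORDER BY region_id` keeps whichever row Postgres happens to return first for each region, so the surviving dataset is arbitrary, as the inline comment says. A minimal in-memory sketch of that selection rule (stand-in IDs, not real data):

```python
# Stand-ins for rows in region_datasets: (region_id, dataset_id) pairs.
links = [
    ("region-a", "ds-2025"),
    ("region-a", "ds-2024"),
    ("region-b", "ds-2024"),
]

picked: dict[str, str] = {}
for region_id, dataset_id in links:  # iteration order = whatever the DB returns
    # DISTINCT ON keeps the first row seen per region_id; later rows are dropped.
    picked.setdefault(region_id, dataset_id)

assert picked == {"region-a": "ds-2025", "region-b": "ds-2024"}
```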
+""" + +import argparse +import json +import logging +import subprocess +import sys +import time +import warnings +from datetime import datetime, timezone +from pathlib import Path + +logging.basicConfig(level=logging.ERROR) +logging.getLogger("sqlalchemy").setLevel(logging.ERROR) +warnings.filterwarnings("ignore") + +# Add src to path for policyengine_api imports +sys.path.insert(0, str(Path(__file__).parent / "src")) + +from rich.console import Console +from sqlmodel import Session, create_engine, select + +from policyengine_api.config.settings import settings +from policyengine_api.models import Dataset, TaxBenefitModel +from policyengine_api.services.storage import upload_dataset_for_seeding +from policyengine.countries.us.data import DISTRICT_COUNTS + +console = Console() + +GCS_BUCKET = "gs://policyengine-us-data" +TMP_DIR = Path("/tmp/pe_state_data") +DEFAULT_YEARS = list(range(2024, 2036)) + +ALL_STATES = list(DISTRICT_COUNTS.keys()) + + +def fmt_duration(seconds: float) -> str: + """Format seconds into a human-readable duration.""" + if seconds < 60: + return f"{seconds:.1f}s" + minutes = int(seconds // 60) + secs = seconds % 60 + if minutes < 60: + return f"{minutes}m {secs:.0f}s" + hours = int(minutes // 60) + mins = minutes % 60 + return f"{hours}h {mins}m {secs:.0f}s" + + +def get_session() -> Session: + engine = create_engine(settings.database_url, echo=False) + return Session(engine) + + +def download_from_gcs(gcs_path: str, local_path: Path) -> bool: + """Download a file from GCS using gsutil. Skips if already exists locally.""" + if local_path.exists() and local_path.stat().st_size > 0: + return True + local_path.parent.mkdir(parents=True, exist_ok=True) + result = subprocess.run( + ["gsutil", "cp", gcs_path, str(local_path)], + capture_output=True, + text=True, + ) + if result.returncode != 0: + console.print(f" [red]gsutil error: {result.stderr.strip()}[/red]") + return False + return True + + +def convert_dataset( + raw_h5_path: str, output_folder: str, years: list[int] +) -> dict: + """Convert a raw h5 file to yearly entity-level h5 files. + + Skips conversion if all yearly output files already exist. + Returns dict mapping dataset_key -> PolicyEngineUSDataset. + """ + from policyengine.tax_benefit_models.us.datasets import ( + create_datasets, + load_datasets, + ) + + stem = Path(raw_h5_path).stem + all_exist = all( + Path(f"{output_folder}/{stem}_year_{year}.h5").exists() + for year in years + ) + if all_exist: + return load_datasets( + datasets=[raw_h5_path], + years=years, + data_folder=output_folder, + ) + + return create_datasets( + datasets=[raw_h5_path], + years=years, + data_folder=output_folder, + ) + + +def process_file( + file_info: dict, + years: list[int], + data_folder: Path, + skip_upload: bool, + session, + us_model, + file_index: int, + total_files: int, +) -> tuple[int, int, int, dict]: + """Process a single raw h5 file (state or district). + + Returns (datasets_created, datasets_skipped, errors, timing). + Region-to-dataset wiring is handled by seed_regions.py, not here. 
+ """ + code = file_info["code"] + prefix = f" [{file_index}/{total_files}] {code}" + datasets_created = 0 + datasets_skipped = 0 + errors = 0 + timing = {"code": code, "type": file_info["type"]} + + # Step 1: Download + t0 = time.time() + console.print(f"{prefix}: downloading from GCS...") + if not download_from_gcs(file_info["gcs_path"], file_info["local_path"]): + console.print(f"{prefix}: [red]download failed, skipping[/red]") + timing["status"] = "download_failed" + return 0, 0, 1, timing + dl_time = time.time() - t0 + size_mb = file_info["local_path"].stat().st_size / (1024 * 1024) + timing["download_seconds"] = round(dl_time, 2) + timing["raw_size_mb"] = round(size_mb, 1) + console.print(f"{prefix}: downloaded ({size_mb:.1f} MB, {fmt_duration(dl_time)})") + + # Step 2: Convert + t0 = time.time() + console.print(f"{prefix}: converting to {len(years)} yearly datasets...") + output_folder = str(data_folder / file_info["output_subfolder"]) + try: + converted = convert_dataset( + str(file_info["local_path"]), output_folder, years + ) + except Exception as e: + console.print(f"{prefix}: [red]conversion failed: {e}[/red]") + timing["status"] = "conversion_failed" + timing["error"] = str(e) + return 0, 0, 1, timing + conv_time = time.time() - t0 + timing["conversion_seconds"] = round(conv_time, 2) + timing["datasets_converted"] = len(converted) + console.print(f"{prefix}: converted {len(converted)} datasets ({fmt_duration(conv_time)})") + + # Step 3: Upload + create DB records + if skip_upload: + datasets_skipped += len(converted) + timing["upload_seconds"] = 0 + timing["status"] = "upload_skipped" + console.print(f"{prefix}: [yellow]upload skipped[/yellow]") + else: + t0 = time.time() + console.print(f"{prefix}: uploading to Supabase...") + for _, pe_dataset in converted.items(): + existing = session.exec( + select(Dataset).where(Dataset.name == pe_dataset.name) + ).first() + + if existing: + datasets_skipped += 1 + continue + + object_name = f"{file_info['supabase_prefix']}/{pe_dataset.name}.h5" + + try: + upload_dataset_for_seeding(pe_dataset.filepath, object_name=object_name) + except Exception as e: + console.print(f"{prefix}: [red]upload failed for {pe_dataset.name}: {e}[/red]") + errors += 1 + continue + + db_dataset = Dataset( + name=pe_dataset.name, + description=pe_dataset.description, + filepath=object_name, + year=pe_dataset.year, + tax_benefit_model_id=us_model.id, + ) + session.add(db_dataset) + session.commit() + session.refresh(db_dataset) + datasets_created += 1 + + upload_time = time.time() - t0 + timing["upload_seconds"] = round(upload_time, 2) + console.print( + f"{prefix}: uploaded {datasets_created} datasets, " + f"{datasets_skipped} already existed ({fmt_duration(upload_time)})" + ) + + timing["datasets_created"] = datasets_created + timing["datasets_skipped"] = datasets_skipped + timing["errors"] = errors + timing["status"] = timing.get("status", "ok") + + return datasets_created, datasets_skipped, errors, timing + + +def process_state( + state_code: str, + years: list[int], + data_folder: Path, + skip_upload: bool, + state_only: bool, + session, + us_model, +) -> tuple[int, int, int, list[dict]]: + """Process one state: its state-level file and all district files. + + Returns (created, skipped, errors, file_timings). + Region-to-dataset wiring is handled by seed_regions.py, not here. 
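Each call to `process_file` contributes one timing entry to the report; the keys below match the assignments above, while the values are purely illustrative (12 datasets corresponds to the default 2024-2035 year range):

```python
timing_entry = {
    "code": "AL-01",              # file_info["code"]
    "type": "district",           # "state" or "district"
    "download_seconds": 3.42,
    "raw_size_mb": 18.7,
    "conversion_seconds": 41.05,
    "datasets_converted": 12,
    "upload_seconds": 22.31,      # 0 when --skip-upload is set
    "datasets_created": 12,
    "datasets_skipped": 0,
    "errors": 0,
    "status": "ok",               # or download_failed / conversion_failed / upload_skipped
}
```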
+    """
+    district_count = DISTRICT_COUNTS.get(state_code, 0)
+
+    files_to_process = []
+
+    # State file
+    files_to_process.append({
+        "type": "state",
+        "code": state_code,
+        "gcs_path": f"{GCS_BUCKET}/states/{state_code}.h5",
+        "local_path": TMP_DIR / "states" / f"{state_code}.h5",
+        "output_subfolder": "states",
+        "supabase_prefix": f"states/{state_code}",
+    })
+
+    # District files
+    if not state_only:
+        for i in range(1, district_count + 1):
+            district_code = f"{state_code}-{i:02d}"
+            files_to_process.append({
+                "type": "district",
+                "code": district_code,
+                "gcs_path": f"{GCS_BUCKET}/districts/{district_code}.h5",
+                "local_path": TMP_DIR / "districts" / f"{district_code}.h5",
+                "output_subfolder": "districts",
+                "supabase_prefix": f"districts/{district_code}",
+            })
+
+    total_files = len(files_to_process)
+    total_created = 0
+    total_skipped = 0
+    total_errors = 0
+    file_timings = []
+
+    for i, file_info in enumerate(files_to_process, 1):
+        created, skipped, errs, timing = process_file(
+            file_info=file_info,
+            years=years,
+            data_folder=data_folder,
+            skip_upload=skip_upload,
+            session=session,
+            us_model=us_model,
+            file_index=i,
+            total_files=total_files,
+        )
+        total_created += created
+        total_skipped += skipped
+        total_errors += errs
+        file_timings.append(timing)
+
+    return total_created, total_skipped, total_errors, file_timings
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Import state & district datasets from GCS to Supabase"
+    )
+    parser.add_argument(
+        "states",
+        nargs="*",
+        help="State codes (e.g., CA NY TX). Two-letter codes; case-insensitive.",
+    )
+    parser.add_argument(
+        "--all",
+        action="store_true",
+        dest="all_states",
+        help="Process all 51 state-level units (50 states + DC)",
+    )
+    parser.add_argument(
+        "--state-only",
+        action="store_true",
+        help="Skip district processing, only do state-level datasets",
+    )
+    parser.add_argument(
+        "--years",
+        type=str,
+        default=None,
+        help="Comma-separated years (default: 2024-2035)",
+    )
+    parser.add_argument(
+        "--skip-upload",
+        action="store_true",
+        help="Convert locally without uploading to Supabase or creating DB records",
+    )
+    parser.add_argument(
+        "--data-folder",
+        type=str,
+        default=None,
+        help="Local directory for converted files (default: ./data)",
+    )
+    args = parser.parse_args()
+
+    # Determine which states to process
+    if args.all_states:
+        states = ALL_STATES
+    elif args.states:
+        states = [s.upper() for s in args.states]
+    else:
+        parser.error("Provide state codes or use --all")
+        return
+
+    # Validate state codes
+    invalid = [s for s in states if s not in DISTRICT_COUNTS]
+    if invalid:
+        console.print(f"[red]Invalid state codes: {', '.join(invalid)}[/red]")
+        sys.exit(1)
+
+    years = DEFAULT_YEARS
+    if args.years:
+        years = [int(y.strip()) for y in args.years.split(",")]
+
+    data_folder = Path(args.data_folder) if args.data_folder else Path(__file__).parent / "data"
+
+    total_districts = sum(DISTRICT_COUNTS[s] for s in states) if not args.state_only else 0
+    total_files = len(states) + total_districts
+    total_yearly = total_files * len(years)
+
+    console.print()
+    console.print("[bold green]State & District Dataset Import[/bold green]")
+    console.print(f"  States: {len(states)} ({', '.join(states)})")
+    if not args.state_only:
+        console.print(f"  Districts: {total_districts}")
+    console.print(f"  Years: {years}")
+    console.print(f"  Raw files to process: {total_files}")
+    console.print(f"  Yearly datasets to produce: {total_yearly}")
+    if args.skip_upload:
+        console.print("  [yellow]Upload skipped
(--skip-upload)[/yellow]") + console.print() + + grand_created = 0 + grand_skipped = 0 + grand_errors = 0 + + session = None + us_model = None + + if not args.skip_upload: + session = get_session() + us_model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-us") + ).first() + if not us_model: + console.print("[red]Error: US model not found. Run seed_models.py first.[/red]") + sys.exit(1) + + script_start = time.time() + timing_report = { + "started_at": datetime.now(timezone.utc).isoformat(), + "args": { + "states": states, + "years": years, + "state_only": args.state_only, + "skip_upload": args.skip_upload, + "data_folder": str(data_folder), + }, + "states": [], + } + + for state_idx, state_code in enumerate(states, 1): + district_count = DISTRICT_COUNTS[state_code] + file_count = 1 + (district_count if not args.state_only else 0) + console.print( + f"[bold]({state_idx}/{len(states)}) Processing {state_code} " + f"({file_count} files)[/bold]" + ) + + state_start = time.time() + + created, skipped, errs, file_timings = process_state( + state_code=state_code, + years=years, + data_folder=data_folder, + skip_upload=args.skip_upload, + state_only=args.state_only, + session=session, + us_model=us_model, + ) + + state_time = time.time() - state_start + console.print( + f"[bold]({state_idx}/{len(states)}) {state_code} complete: " + f"{created} created, {skipped} skipped" + f"{f', {errs} errors' if errs else ''} " + f"({fmt_duration(state_time)})[/bold]" + ) + console.print() + + timing_report["states"].append({ + "state": state_code, + "total_seconds": round(state_time, 2), + "datasets_created": created, + "datasets_skipped": skipped, + "errors": errs, + "files": file_timings, + }) + + # Write timing file after each state so partial results are preserved + timing_path = data_folder / "import_timing.json" + timing_path.parent.mkdir(parents=True, exist_ok=True) + timing_path.write_text(json.dumps(timing_report, indent=2)) + + grand_created += created + grand_skipped += skipped + grand_errors += errs + + if session: + session.close() + + total_time = time.time() - script_start + + # Write final timing report + timing_report["finished_at"] = datetime.now(timezone.utc).isoformat() + timing_report["total_seconds"] = round(total_time, 2) + timing_report["totals"] = { + "datasets_created": grand_created, + "datasets_skipped": grand_skipped, + "errors": grand_errors, + } + timing_path = data_folder / "import_timing.json" + timing_path.write_text(json.dumps(timing_report, indent=2)) + + console.print(f"[bold green]Import complete ({fmt_duration(total_time)})[/bold green]") + console.print(f" Datasets created: {grand_created}") + console.print(f" Datasets skipped (already exist): {grand_skipped}") + if grand_errors: + console.print(f" [red]Errors: {grand_errors}[/red]") + console.print(f" Timing report: {timing_path}") + + +if __name__ == "__main__": + main() diff --git a/scripts/seed_regions.py b/scripts/seed_regions.py index 060fb2f..f0e05ed 100644 --- a/scripts/seed_regions.py +++ b/scripts/seed_regions.py @@ -5,7 +5,11 @@ - UK: National and 4 countries (England, Scotland, Wales, Northern Ireland) Regions are sourced from policyengine.py's region registries and linked -to the appropriate datasets in the database. +to the appropriate datasets via the region_datasets join table. + +This script is the SOLE source of truth for region-to-dataset wiring. 
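Back in `import_state_datasets.py`: because the report is rewritten after every state, a crash mid-run still leaves a valid `import_timing.json` on disk. The finished file looks roughly like this (values illustrative):

```python
timing_report = {
    "started_at": "2026-03-03T19:01:07+00:00",
    "args": {
        "states": ["AL"],
        "years": [2024, 2025],
        "state_only": False,
        "skip_upload": False,
        "data_folder": "./data",
    },
    "states": [
        {
            "state": "AL",
            "total_seconds": 512.4,
            "datasets_created": 16,   # 8 raw files (1 state + 7 districts) x 2 years
            "datasets_skipped": 0,
            "errors": 0,
            "files": [],              # one per-file timing entry each, as sketched above
        },
    ],
    "finished_at": "2026-03-03T19:09:40+00:00",  # present only on the final write
    "total_seconds": 513.1,
    "totals": {"datasets_created": 16, "datasets_skipped": 0, "errors": 0},
}
```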
+After importing datasets with import_state_datasets.py, re-run this script +to link regions to any newly available datasets. Usage: python scripts/seed_regions.py # Seed all US and UK regions @@ -24,14 +28,101 @@ from seed_utils import console, get_session # Import after seed_utils sets up path -from policyengine_api.models import Dataset, Region, TaxBenefitModel # noqa: E402 +from policyengine_api.models import ( # noqa: E402 + Dataset, + Region, + RegionDatasetLink, + TaxBenefitModel, +) + + +def _group_us_datasets( + session: Session, us_model_id, +) -> tuple[list[Dataset], dict[str, list[Dataset]], dict[str, list[Dataset]]]: + """Pre-fetch and group all US datasets by type. + + Returns: + (national_datasets, state_datasets_by_code, district_datasets_by_code) + """ + all_datasets = session.exec( + select(Dataset).where(Dataset.tax_benefit_model_id == us_model_id) + ).all() + + national = [] + by_state: dict[str, list[Dataset]] = {} + by_district: dict[str, list[Dataset]] = {} + + for d in all_datasets: + if d.filepath and d.filepath.startswith("states/"): + # filepath = "states/AL/AL-year-2024.h5" + parts = d.filepath.split("/") + if len(parts) >= 2: + by_state.setdefault(parts[1], []).append(d) + elif d.filepath and d.filepath.startswith("districts/"): + # filepath = "districts/AL-01/AL-01-year-2024.h5" + parts = d.filepath.split("/") + if len(parts) >= 2: + by_district.setdefault(parts[1], []).append(d) + elif "cps" in d.name.lower(): + national.append(d) + + return national, by_state, by_district + + +def _get_datasets_for_us_region( + pe_region, + national_datasets: list[Dataset], + state_datasets: dict[str, list[Dataset]], + district_datasets: dict[str, list[Dataset]], +) -> list[Dataset]: + """Determine which datasets a US region should link to.""" + if pe_region.region_type == "national": + return national_datasets + + elif pe_region.region_type == "state": + # "state/ca" -> "CA" + state_code = pe_region.code.split("/")[1].upper() + return state_datasets.get(state_code, national_datasets) + + elif pe_region.region_type == "congressional_district": + # "congressional_district/CA-12" -> "CA-12" + district_code = pe_region.code.split("/")[1].upper() + return district_datasets.get(district_code, national_datasets) + + elif pe_region.region_type == "place": + # Places use parent state's datasets (filter at runtime) + if pe_region.state_code: + return state_datasets.get(pe_region.state_code, national_datasets) + return national_datasets + + return national_datasets + + +def _link_datasets( + region_id, + datasets: list[Dataset], + existing_link_set: set[tuple], + session: Session, +) -> int: + """Create RegionDatasetLink entries for missing links. + + Returns the number of new links created. + """ + created = 0 + for dataset in datasets: + key = (region_id, dataset.id) + if key not in existing_link_set: + session.add(RegionDatasetLink(region_id=region_id, dataset_id=dataset.id)) + existing_link_set.add(key) + created += 1 + return created def seed_us_regions( session: Session, skip_places: bool = False, skip_districts: bool = False, -) -> tuple[int, int]: +) -> tuple[int, int, int]: """Seed US regions from policyengine.py registry. 
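`_group_us_datasets` routes purely on the filepath prefixes laid down by `import_state_datasets.py`. A self-contained sketch of the same rules (dataset names illustrative):

```python
def route(filepath: str | None, name: str) -> tuple[str, str | None]:
    """Mirror of _group_us_datasets' routing: (bucket, grouping key)."""
    if filepath and filepath.startswith("states/"):
        return "state", filepath.split("/")[1]      # "states/AL/AL-year-2024.h5" -> "AL"
    if filepath and filepath.startswith("districts/"):
        return "district", filepath.split("/")[1]   # "districts/AL-01/..." -> "AL-01"
    if "cps" in name.lower():
        return "national", None
    return "unrouted", None                         # anything else is silently ignored

assert route("states/AL/AL-year-2024.h5", "AL 2024") == ("state", "AL")
assert route("districts/AL-01/AL-01-year-2024.h5", "AL-01 2024") == ("district", "AL-01")
assert route(None, "enhanced_cps_2024") == ("national", None)
```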
Args: @@ -40,7 +131,7 @@ def seed_us_regions( skip_districts: Skip congressional districts Returns: - Tuple of (created_count, skipped_count) + Tuple of (created_count, skipped_count, links_created) """ from policyengine.countries.us.regions import us_region_registry @@ -51,22 +142,24 @@ def seed_us_regions( if not us_model: console.print("[red]Error: US model not found. Run seed.py first.[/red]") - return 0, 0 + return 0, 0, 0 - # Get US national dataset (CPS) - us_dataset = session.exec( - select(Dataset) - .where(Dataset.tax_benefit_model_id == us_model.id) - .where(Dataset.name.contains("cps")) # type: ignore - .order_by(Dataset.year.desc()) # type: ignore - ).first() + # Pre-fetch and group all US datasets + national_datasets, state_datasets, district_datasets = _group_us_datasets( + session, us_model.id + ) + + if not national_datasets: + console.print("[red]Error: No US CPS datasets found. Run seed.py first.[/red]") + return 0, 0, 0 - if not us_dataset: - console.print("[red]Error: US dataset not found. Run seed.py first.[/red]") - return 0, 0 + # Pre-fetch existing dataset links for efficiency + existing_links = session.exec(select(RegionDatasetLink)).all() + existing_link_set = {(l.region_id, l.dataset_id) for l in existing_links} created = 0 skipped = 0 + links_created = 0 # Filter regions based on options regions_to_seed = [] @@ -90,47 +183,54 @@ def seed_us_regions( for pe_region in regions_to_seed: progress.update(task, description=f"US: {pe_region.label}") - # Check if region already exists + # Find existing or create new region existing = session.exec( select(Region).where(Region.code == pe_region.code) ).first() if existing: + db_region = existing skipped += 1 - progress.advance(task) - continue - - # Create region record - db_region = Region( - code=pe_region.code, - label=pe_region.label, - region_type=pe_region.region_type, - requires_filter=pe_region.requires_filter, - filter_field=pe_region.filter_field, - filter_value=pe_region.filter_value, - parent_code=pe_region.parent_code, - state_code=pe_region.state_code, - state_name=pe_region.state_name, - dataset_id=us_dataset.id, # All US regions use the national dataset - tax_benefit_model_id=us_model.id, + else: + db_region = Region( + code=pe_region.code, + label=pe_region.label, + region_type=pe_region.region_type, + requires_filter=pe_region.requires_filter, + filter_field=pe_region.filter_field, + filter_value=pe_region.filter_value, + parent_code=pe_region.parent_code, + state_code=pe_region.state_code, + state_name=pe_region.state_name, + tax_benefit_model_id=us_model.id, + ) + session.add(db_region) + session.flush() # Get the ID assigned + created += 1 + + # Link datasets for this region + datasets = _get_datasets_for_us_region( + pe_region, national_datasets, state_datasets, district_datasets ) - session.add(db_region) - created += 1 + links_created += _link_datasets( + db_region.id, datasets, existing_link_set, session + ) + progress.advance(task) session.commit() - return created, skipped + return created, skipped, links_created -def seed_uk_regions(session: Session) -> tuple[int, int]: +def seed_uk_regions(session: Session) -> tuple[int, int, int]: """Seed UK regions from policyengine.py registry. 
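One property of `_get_datasets_for_us_region` worth spelling out: every branch falls back to the national CPS list, so a region is never left with zero links as long as national data exists. With illustrative stand-ins:

```python
national = ["cps_2024"]
state_datasets = {"CA": ["CA-year-2024", "CA-year-2025"]}

# Dedicated data wins; a state with no imported datasets falls back to national.
assert state_datasets.get("CA", national) == ["CA-year-2024", "CA-year-2025"]
assert state_datasets.get("WY", national) == ["cps_2024"]
```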
Args: session: Database session Returns: - Tuple of (created_count, skipped_count) + Tuple of (created_count, skipped_count, links_created) """ from policyengine.countries.uk.regions import uk_region_registry @@ -143,24 +243,28 @@ def seed_uk_regions(session: Session) -> tuple[int, int]: console.print( "[yellow]Warning: UK model not found. Skipping UK regions.[/yellow]" ) - return 0, 0 + return 0, 0, 0 - # Get UK national dataset (FRS) - uk_dataset = session.exec( + # Get all UK FRS datasets + uk_datasets = session.exec( select(Dataset) .where(Dataset.tax_benefit_model_id == uk_model.id) .where(Dataset.name.contains("frs")) # type: ignore - .order_by(Dataset.year.desc()) # type: ignore - ).first() + ).all() - if not uk_dataset: + if not uk_datasets: console.print( - "[yellow]Warning: UK dataset not found. Skipping UK regions.[/yellow]" + "[yellow]Warning: No UK FRS datasets found. Skipping UK regions.[/yellow]" ) - return 0, 0 + return 0, 0, 0 + + # Pre-fetch existing dataset links + existing_links = session.exec(select(RegionDatasetLink)).all() + existing_link_set = {(l.region_id, l.dataset_id) for l in existing_links} created = 0 skipped = 0 + links_created = 0 with Progress( SpinnerColumn(), @@ -172,37 +276,41 @@ def seed_uk_regions(session: Session) -> tuple[int, int]: for pe_region in uk_region_registry.regions: progress.update(task, description=f"UK: {pe_region.label}") - # Check if region already exists + # Find existing or create new region existing = session.exec( select(Region).where(Region.code == pe_region.code) ).first() if existing: + db_region = existing skipped += 1 - progress.advance(task) - continue - - # Create region record - db_region = Region( - code=pe_region.code, - label=pe_region.label, - region_type=pe_region.region_type, - requires_filter=pe_region.requires_filter, - filter_field=pe_region.filter_field, - filter_value=pe_region.filter_value, - parent_code=pe_region.parent_code, - state_code=None, # UK regions don't have state_code - state_name=None, - dataset_id=uk_dataset.id, # All UK regions use the national dataset - tax_benefit_model_id=uk_model.id, + else: + db_region = Region( + code=pe_region.code, + label=pe_region.label, + region_type=pe_region.region_type, + requires_filter=pe_region.requires_filter, + filter_field=pe_region.filter_field, + filter_value=pe_region.filter_value, + parent_code=pe_region.parent_code, + state_code=None, + state_name=None, + tax_benefit_model_id=uk_model.id, + ) + session.add(db_region) + session.flush() + created += 1 + + # All UK regions link to FRS datasets (they filter at runtime) + links_created += _link_datasets( + db_region.id, uk_datasets, existing_link_set, session ) - session.add(db_region) - created += 1 + progress.advance(task) session.commit() - return created, skipped + return created, skipped, links_created def main(): @@ -234,34 +342,42 @@ def main(): start = time.time() total_created = 0 total_skipped = 0 + total_links = 0 with get_session() as session: # Seed US regions if not args.uk_only: console.print("[bold]US Regions[/bold]") - us_created, us_skipped = seed_us_regions( + us_created, us_skipped, us_links = seed_us_regions( session, skip_places=args.skip_places, skip_districts=args.skip_districts, ) total_created += us_created total_skipped += us_skipped + total_links += us_links console.print( - f"[green]✓[/green] US regions: {us_created} created, {us_skipped} skipped\n" + f"[green]\u2713[/green] US regions: {us_created} created, " + f"{us_skipped} skipped, {us_links} dataset links added\n" ) # Seed UK 
regions if not args.us_only: console.print("[bold]UK Regions[/bold]") - uk_created, uk_skipped = seed_uk_regions(session) + uk_created, uk_skipped, uk_links = seed_uk_regions(session) total_created += uk_created total_skipped += uk_skipped + total_links += uk_links console.print( - f"[green]✓[/green] UK regions: {uk_created} created, {uk_skipped} skipped\n" + f"[green]\u2713[/green] UK regions: {uk_created} created, " + f"{uk_skipped} skipped, {uk_links} dataset links added\n" ) elapsed = time.time() - start - console.print(f"[bold]Total: {total_created} created, {total_skipped} skipped[/bold]") + console.print( + f"[bold]Total: {total_created} created, {total_skipped} skipped, " + f"{total_links} dataset links added[/bold]" + ) console.print(f"[bold]Time: {elapsed:.1f}s[/bold]") diff --git a/src/policyengine_api/api/analysis.py b/src/policyengine_api/api/analysis.py index b741750..91f1ca0 100644 --- a/src/policyengine_api/api/analysis.py +++ b/src/policyengine_api/api/analysis.py @@ -51,6 +51,7 @@ ProgramStatistics, ProgramStatisticsRead, Region, + RegionDatasetLink, Report, ReportStatus, Simulation, @@ -160,6 +161,10 @@ class EconomicImpactRequest(BaseModel): dynamic_id: UUID | None = Field( default=None, description="Optional behavioural response specification ID" ) + year: int | None = Field( + default=None, + description="Year for the analysis (e.g., 2026). Selects the dataset for that year. Uses latest available if omitted.", + ) class SimulationInfo(BaseModel): @@ -269,6 +274,7 @@ def _get_or_create_simulation( filter_field: str | None = None, filter_value: str | None = None, region_id: UUID | None = None, + year: int | None = None, ) -> Simulation: """Get existing simulation or create a new one.""" sim_id = _get_deterministic_simulation_id( @@ -298,6 +304,7 @@ def _get_or_create_simulation( filter_field=filter_field, filter_value=filter_value, region_id=region_id, + year=year, ) session.add(simulation) session.commit() @@ -1062,6 +1069,10 @@ def _resolve_dataset_and_region( ) -> tuple[Dataset, Region | None]: """Resolve dataset from request, optionally via region lookup. + When a region is provided, the dataset is resolved from the region_datasets + join table. If request.year is set, the dataset for that year is selected; + otherwise the latest available year is used. + Returns: Tuple of (dataset, region) where region is None if dataset_id was provided directly. 
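For callers, the change surfaces as a single optional field. A hypothetical request body (the endpoint path and the exact model-name field are assumptions; neither appears in this hunk):

```python
# Pinning the 2026 dataset for California; omit "year" to get the latest linked year.
payload = {
    "tax_benefit_model_name": "policyengine-us",  # assumed field name
    "region": "state/ca",
    "policy_id": None,
    "dynamic_id": None,
    "year": 2026,
}
```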
""" @@ -1081,11 +1092,23 @@ def _resolve_dataset_and_region( detail=f"Region '{request.region}' not found for model {model_name}", ) - dataset = session.get(Dataset, region.dataset_id) + # Resolve dataset from join table, filtered by year if provided + query = ( + select(Dataset) + .join(RegionDatasetLink) + .where(RegionDatasetLink.region_id == region.id) + ) + if request.year: + query = query.where(Dataset.year == request.year) + else: + query = query.order_by(Dataset.year.desc()) # type: ignore + dataset = session.exec(query).first() + if not dataset: + year_msg = f" for year {request.year}" if request.year else "" raise HTTPException( status_code=404, - detail=f"Dataset for region '{request.region}' not found", + detail=f"No dataset found for region '{request.region}'{year_msg}", ) return dataset, region @@ -1143,6 +1166,7 @@ def economic_impact( filter_field=filter_field, filter_value=filter_value, region_id=region.id if region else None, + year=dataset.year, ) reform_sim = _get_or_create_simulation( @@ -1155,6 +1179,7 @@ def economic_impact( filter_field=filter_field, filter_value=filter_value, region_id=region.id if region else None, + year=dataset.year, ) # Get or create report @@ -1230,6 +1255,10 @@ class EconomyCustomRequest(BaseModel): dynamic_id: UUID | None = Field( default=None, description="Optional behavioural response specification ID" ) + year: int | None = Field( + default=None, + description="Year for the analysis. Uses latest available if omitted.", + ) modules: list[str] = Field( description="List of module names to compute (see GET /analysis/options)" ) @@ -1296,6 +1325,7 @@ def economy_custom( region=request.region, policy_id=request.policy_id, dynamic_id=request.dynamic_id, + year=request.year, ) dataset, region_obj = _resolve_dataset_and_region(impact_request, session) @@ -1319,6 +1349,7 @@ def economy_custom( filter_field=filter_field, filter_value=filter_value, region_id=region_obj.id if region_obj else None, + year=dataset.year, ) reform_sim = _get_or_create_simulation( @@ -1331,6 +1362,7 @@ def economy_custom( filter_field=filter_field, filter_value=filter_value, region_id=region_obj.id if region_obj else None, + year=dataset.year, ) label = f"Custom analysis: {request.tax_benefit_model_name}" diff --git a/src/policyengine_api/api/simulations.py b/src/policyengine_api/api/simulations.py index 9c63ae2..a44cc15 100644 --- a/src/policyengine_api/api/simulations.py +++ b/src/policyengine_api/api/simulations.py @@ -20,6 +20,7 @@ Household, Policy, Region, + RegionDatasetLink, Simulation, SimulationRead, SimulationStatus, @@ -89,6 +90,10 @@ class EconomySimulationRequest(BaseModel): default=None, description="Optional behavioural response specification ID", ) + year: int | None = Field( + default=None, + description="Year for the simulation. Uses latest available if omitted.", + ) class EconomySimulationResponse(BaseModel): @@ -115,8 +120,14 @@ def _resolve_economy_dataset( region_code: str | None, dataset_id: UUID | None, session: Session, + year: int | None = None, ) -> tuple[Dataset, Region | None]: - """Resolve dataset from region code or dataset_id for economy simulations.""" + """Resolve dataset from region code or dataset_id for economy simulations. + + When a region is provided, the dataset is resolved from the region_datasets + join table. If year is set, the dataset for that year is selected; + otherwise the latest available year is used. 
+ """ if region_code: model_name = tax_benefit_model_name.replace("_", "-") region = session.exec( @@ -130,11 +141,24 @@ def _resolve_economy_dataset( status_code=404, detail=f"Region '{region_code}' not found for model {model_name}", ) - dataset = session.get(Dataset, region.dataset_id) + + # Resolve dataset from join table + query = ( + select(Dataset) + .join(RegionDatasetLink) + .where(RegionDatasetLink.region_id == region.id) + ) + if year: + query = query.where(Dataset.year == year) + else: + query = query.order_by(Dataset.year.desc()) # type: ignore + dataset = session.exec(query).first() + if not dataset: + year_msg = f" for year {year}" if year else "" raise HTTPException( status_code=404, - detail=f"Dataset for region '{region_code}' not found", + detail=f"No dataset found for region '{region_code}'{year_msg}", ) return dataset, region @@ -298,6 +322,7 @@ def create_economy_simulation( request.region, request.dataset_id, session, + year=request.year, ) # Validate policy exists (if provided) @@ -327,6 +352,7 @@ def create_economy_simulation( filter_field=filter_field, filter_value=filter_value, region_id=region.id if region else None, + year=dataset.year, ) return _build_economy_response(simulation, region) diff --git a/src/policyengine_api/models/__init__.py b/src/policyengine_api/models/__init__.py index ae660a3..c6ea679 100644 --- a/src/policyengine_api/models/__init__.py +++ b/src/policyengine_api/models/__init__.py @@ -61,6 +61,7 @@ from .policy import Policy, PolicyCreate, PolicyRead from .poverty import Poverty, PovertyCreate, PovertyRead from .region import Region, RegionCreate, RegionRead +from .region_dataset_link import RegionDatasetLink from .program_statistics import ( ProgramStatistics, ProgramStatisticsCreate, @@ -174,6 +175,7 @@ "PovertyRead", "Region", "RegionCreate", + "RegionDatasetLink", "RegionRead", "ProgramStatistics", "ProgramStatisticsCreate", diff --git a/src/policyengine_api/models/region.py b/src/policyengine_api/models/region.py index 7c87a00..6672309 100644 --- a/src/policyengine_api/models/region.py +++ b/src/policyengine_api/models/region.py @@ -6,6 +6,8 @@ from sqlmodel import Field, Relationship, SQLModel +from .region_dataset_link import RegionDatasetLink + if TYPE_CHECKING: from .dataset import Dataset from .tax_benefit_model import TaxBenefitModel @@ -23,7 +25,6 @@ class RegionBase(SQLModel): parent_code: str | None = None # e.g., "us", "state/ca" state_code: str | None = None # For US regions state_name: str | None = None # For US regions - dataset_id: UUID = Field(foreign_key="datasets.id") tax_benefit_model_id: UUID = Field(foreign_key="tax_benefit_models.id") @@ -32,7 +33,8 @@ class Region(RegionBase, table=True): Regions represent geographic areas for analysis, from countries down to states, congressional districts, cities, etc. - Each region has a dataset (either dedicated or filtered from parent). + Each region links to multiple datasets (one per year) via the + region_datasets join table. 
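Given the `datasets` relationship declared just below, callers can also stay in ORM land rather than writing the join by hand. A sketch assuming a live session and seeded data:

```python
from sqlmodel import select

region = session.exec(select(Region).where(Region.code == "state/ca")).one()

by_year = {d.year: d for d in region.datasets}            # one linked dataset per year
latest = max(region.datasets, key=lambda d: d.year or 0)  # mirrors the API's default
```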
""" __tablename__ = "regions" @@ -42,7 +44,7 @@ class Region(RegionBase, table=True): updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) # Relationships - dataset: "Dataset" = Relationship() + datasets: list["Dataset"] = Relationship(link_model=RegionDatasetLink) tax_benefit_model: "TaxBenefitModel" = Relationship() diff --git a/src/policyengine_api/models/region_dataset_link.py b/src/policyengine_api/models/region_dataset_link.py new file mode 100644 index 0000000..9801cfc --- /dev/null +++ b/src/policyengine_api/models/region_dataset_link.py @@ -0,0 +1,19 @@ +"""Link table for many-to-many relationship between regions and datasets.""" + +from uuid import UUID + +from sqlmodel import Field, SQLModel + + +class RegionDatasetLink(SQLModel, table=True): + """Join table linking regions to their available datasets. + + Each region can have multiple datasets (one per year), and each + dataset can be shared across multiple regions (e.g., a state dataset + used by both the state region and its place/city regions). + """ + + __tablename__ = "region_datasets" + + region_id: UUID = Field(foreign_key="regions.id", primary_key=True) + dataset_id: UUID = Field(foreign_key="datasets.id", primary_key=True) diff --git a/src/policyengine_api/models/simulation.py b/src/policyengine_api/models/simulation.py index f2711cc..e38ac18 100644 --- a/src/policyengine_api/models/simulation.py +++ b/src/policyengine_api/models/simulation.py @@ -60,6 +60,8 @@ class SimulationBase(SQLModel): description="Value to match when filtering (e.g., '44000', 'ENGLAND')", ) + year: int | None = None + class Simulation(SimulationBase, table=True): """Simulation database model.""" diff --git a/test_fixtures/fixtures_regions.py b/test_fixtures/fixtures_regions.py index e95e0d8..8cbf6b0 100644 --- a/test_fixtures/fixtures_regions.py +++ b/test_fixtures/fixtures_regions.py @@ -7,6 +7,7 @@ from policyengine_api.models import ( Dataset, Region, + RegionDatasetLink, Simulation, SimulationStatus, TaxBenefitModel, @@ -116,7 +117,7 @@ def create_region( filter_field: str | None = None, filter_value: str | None = None, ) -> Region: - """Create and persist a Region.""" + """Create and persist a Region with a dataset link.""" region = Region( code=code, label=label, @@ -124,12 +125,17 @@ def create_region( requires_filter=requires_filter, filter_field=filter_field, filter_value=filter_value, - dataset_id=dataset.id, tax_benefit_model_id=model.id, ) session.add(region) session.commit() session.refresh(region) + + # Create the join table link + link = RegionDatasetLink(region_id=region.id, dataset_id=dataset.id) + session.add(link) + session.commit() + return region diff --git a/test_fixtures/fixtures_simulations_standalone.py b/test_fixtures/fixtures_simulations_standalone.py index 314afe5..21f6253 100644 --- a/test_fixtures/fixtures_simulations_standalone.py +++ b/test_fixtures/fixtures_simulations_standalone.py @@ -7,6 +7,7 @@ Household, Policy, Region, + RegionDatasetLink, Simulation, SimulationStatus, SimulationType, @@ -110,7 +111,7 @@ def create_region( filter_field: str | None = None, filter_value: str | None = None, ) -> Region: - """Create and persist a Region record.""" + """Create and persist a Region record with a dataset link.""" region = Region( code=code, label=label, @@ -118,12 +119,17 @@ def create_region( requires_filter=requires_filter, filter_field=filter_field, filter_value=filter_value, - dataset_id=dataset.id, tax_benefit_model_id=model.id, ) session.add(region) session.commit() 
     session.refresh(region)
+
+    # Create the join table link
+    link = RegionDatasetLink(region_id=region.id, dataset_id=dataset.id)
+    session.add(link)
+    session.commit()
+
     return region
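Both fixtures link exactly one dataset per region, so tests that exercise the year-filter path need a second link. A hypothetical helper in the same style as the fixtures above:

```python
def link_extra_dataset(session, region: Region, dataset: Dataset) -> None:
    """Attach an additional yearly dataset to an already-created fixture region."""
    session.add(RegionDatasetLink(region_id=region.id, dataset_id=dataset.id))
    session.commit()
```

With two linked years in place, a test can assert that requests with and without `year` resolve to different datasets.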