diff --git a/.claude/skills/database-migrations.md b/.claude/skills/database-migrations.md new file mode 100644 index 0000000..fedbef8 --- /dev/null +++ b/.claude/skills/database-migrations.md @@ -0,0 +1,301 @@ +# Database Migration Guidelines + +## Overview + +This project uses **Alembic** for database migrations with **SQLModel** models. Alembic is the industry-standard migration tool for SQLAlchemy/SQLModel projects. + +**CRITICAL**: SQL migrations are the single source of truth for database schema. All table creation and schema changes MUST go through Alembic migrations. + +## Architecture + +``` +┌─────────────────────────────────────────────────────────────┐ +│ SQLModel Models (src/policyengine_api/models/) │ +│ - Define Python classes │ +│ - Used for ORM queries │ +│ - NOT the source of truth for schema │ +└─────────────────────────────────────────────────────────────┘ + │ + │ alembic revision --autogenerate + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ Alembic Migrations (alembic/versions/) │ +│ - Create/alter tables │ +│ - Add indexes, constraints │ +│ - SOURCE OF TRUTH for schema │ +└─────────────────────────────────────────────────────────────┘ + │ + │ alembic upgrade head + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ PostgreSQL Database (Supabase) │ +│ - Actual schema │ +│ - Tracked by alembic_version table │ +└─────────────────────────────────────────────────────────────┘ +``` + +## Essential Rules + +### 1. NEVER use SQLModel.metadata.create_all() for schema creation + +The old pattern of using `SQLModel.metadata.create_all()` is deprecated. All tables are created via Alembic migrations. + +### 2. Every schema change requires a migration + +When you modify a SQLModel model (add column, change type, add index), you MUST: +1. Update the model in `src/policyengine_api/models/` +2. Generate a migration: `uv run alembic revision --autogenerate -m "Description"` +3. 
**Read and verify the generated migration** (see below) +4. Apply it: `uv run alembic upgrade head` + +### 3. ALWAYS verify auto-generated migrations before applying + +**This is critical for AI agents.** After running `alembic revision --autogenerate`, you MUST: + +1. **Read the generated migration file** in `alembic/versions/` +2. **Verify the `upgrade()` function** contains the expected changes: + - Correct table/column names + - Correct column types (e.g., `sa.String()`, `sa.Uuid()`, `sa.Integer()`) + - Proper foreign key references + - Appropriate nullable settings +3. **Verify the `downgrade()` function** properly reverses the changes +4. **Check for Alembic autogenerate limitations:** + - It may miss renamed columns (shows as drop + add instead) + - It may not detect some index changes + - It doesn't handle data migrations +5. **Edit the migration if needed** before applying + +Example verification: +```python +# Generated migration - verify this looks correct: +def upgrade() -> None: + op.add_column('users', sa.Column('phone', sa.String(), nullable=True)) + +def downgrade() -> None: + op.drop_column('users', 'phone') +``` + +**Never blindly apply a migration without reading it first.** + +### 4. Migrations must be self-contained + +Each migration should: +- Create tables it needs (never assume they exist from Python) +- Include both `upgrade()` and `downgrade()` functions +- Be idempotent where possible (use `IF NOT EXISTS` patterns) + +### 5. Never use conditional logic based on table existence + +Migrations should NOT check if tables exist. 
Instead: +- Ensure migrations run in the correct order (use `down_revision`) +- The initial migration creates all base tables +- Subsequent migrations build on that foundation + +## Common Commands + +```bash +# Apply all pending migrations +uv run alembic upgrade head + +# Generate migration from model changes +uv run alembic revision --autogenerate -m "Add users email index" + +# Create empty migration (for manual SQL) +uv run alembic revision -m "Add custom index" + +# Check current migration state +uv run alembic current + +# Show migration history +uv run alembic history + +# Downgrade one revision +uv run alembic downgrade -1 + +# Downgrade to specific revision +uv run alembic downgrade +``` + +## Local Development Workflow + +```bash +# 1. Start Supabase +supabase start + +# 2. Initialize database (runs migrations + applies RLS policies) +uv run python scripts/init.py + +# 3. Seed data +uv run python scripts/seed.py +``` + +### Reset database (DESTRUCTIVE) + +```bash +uv run python scripts/init.py --reset +``` + +## Adding a New Model + +1. Create the model in `src/policyengine_api/models/` + +```python +# src/policyengine_api/models/my_model.py +from sqlmodel import SQLModel, Field +from uuid import UUID, uuid4 + +class MyModel(SQLModel, table=True): + __tablename__ = "my_models" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + name: str +``` + +2. Export in `__init__.py`: + +```python +# src/policyengine_api/models/__init__.py +from .my_model import MyModel +``` + +3. Generate migration: + +```bash +uv run alembic revision --autogenerate -m "Add my_models table" +``` + +4. Review the generated migration in `alembic/versions/` + +5. Apply the migration: + +```bash +uv run alembic upgrade head +``` + +6. Update `scripts/init.py` to include the table in RLS policies if needed. + +## Adding an Index + +1. Generate a migration: + +```bash +uv run alembic revision -m "Add index on users.email" +``` + +2. 
Edit the migration: + +```python +def upgrade() -> None: + op.create_index("idx_users_email", "users", ["email"]) + +def downgrade() -> None: + op.drop_index("idx_users_email", "users") +``` + +3. Apply: + +```bash +uv run alembic upgrade head +``` + +## Production Considerations + +### Applying migrations to production + +1. Migrations are automatically applied when deploying +2. Always test migrations locally first +3. For data migrations, consider running during low-traffic periods + +### Transitioning production from old system to Alembic + +Production databases that were created before Alembic (using the old `SQLModel.metadata.create_all()` approach or raw Supabase migrations) need special handling. Running `alembic upgrade head` would fail because the tables already exist. + +**The solution: `alembic stamp`** + +The `alembic stamp` command marks a migration as "already applied" without actually running it. This tells Alembic "the database is already at this state, start tracking from here." + +**How it works:** + +1. `alembic stamp ` inserts a row into the `alembic_version` table with the specified revision ID +2. Alembic now thinks that migration (and all migrations before it) have been applied +3. Future migrations will run normally starting from that point + +**Step-by-step production transition:** + +```bash +# 1. Connect to production database +# (set SUPABASE_DB_URL or other connection env vars) + +# 2. Check if alembic_version table exists +# If not, Alembic will create it automatically + +# 3. Verify production schema matches the initial migration +# Compare tables/columns in production against alembic/versions/20260204_d6e30d3b834d_initial_schema.py + +# 4. Stamp the initial migration as applied +uv run alembic stamp d6e30d3b834d + +# 5. If production also has the indexes from the second migration, stamp that too +uv run alembic stamp a17ac554f4aa + +# 6. Verify the stamp worked +uv run alembic current +# Should show: a17ac554f4aa (head) + +# 7. 
From now on, new migrations will apply normally +uv run alembic upgrade head +``` + +**Handling partially applied migrations:** + +If production has some but not all changes from a migration: + +1. Manually apply the missing changes via SQL +2. Then stamp that migration as complete +3. Or: create a new migration that only adds the missing pieces + +**After stamping:** + +- All future schema changes go through Alembic migrations +- Developers generate migrations with `alembic revision --autogenerate` +- Deployments run `alembic upgrade head` to apply pending migrations +- The `alembic_version` table tracks what's been applied + +## File Structure + +``` +alembic/ +├── env.py # Alembic configuration (imports models, sets DB URL) +├── script.py.mako # Template for new migrations +├── versions/ # Migration files +│ ├── 20260204_d6e30d3b834d_initial_schema.py +│ └── 20260204_a17ac554f4aa_add_parameter_values_indexes.py +alembic.ini # Alembic settings + +supabase/ +├── migrations/ # Supabase-specific migrations (storage only) +│ ├── 20241119000000_storage_bucket.sql +│ └── 20241121000000_storage_policies.sql +└── migrations_archived/ # Old table migrations (now in Alembic) +``` + +## Troubleshooting + +### "Target database is not up to date" + +Run `alembic upgrade head` to apply pending migrations. + +### "Can't locate revision" + +The alembic_version table has a revision that doesn't exist in your migrations folder. This can happen if someone deleted a migration file. Fix by stamping to a known revision: + +```bash +alembic stamp head # If tables are current +alembic stamp d6e30d3b834d # If at initial schema +``` + +### "Table already exists" + +The migration is trying to create a table that already exists. Options: +1. If this is a fresh setup, drop and recreate: `uv run python scripts/init.py --reset` +2. 
If in production, stamp the migration as applied: `alembic stamp ` diff --git a/.env.example b/.env.example index f2d1db2..282ba16 100644 --- a/.env.example +++ b/.env.example @@ -61,11 +61,22 @@ AGENT_USE_MODAL=false POLICYENGINE_API_URL=http://localhost:8000 # ============================================================================= -# MODAL SECRETS (for production) +# MODAL SERVERLESS COMPUTE # ============================================================================= -# Modal secrets are NOT set via .env - they're managed via Modal CLI: -# -# 1. modal secret create policyengine-db \ +# Modal environment to use (main, staging, testing). +# Only relevant when AGENT_USE_MODAL=true. +# The Modal SDK authenticates via ~/.modal.toml (from `modal setup`). +# For production (Cloud Run), set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET instead. +MODAL_ENVIRONMENT=main + +# For production (Cloud Run) only: +# MODAL_TOKEN_ID=ak-... +# MODAL_TOKEN_SECRET=as-... + +# ============================================================================= +# MODAL SECRETS (managed via Modal CLI, not .env) +# ============================================================================= +# 1. modal secret create policyengine-db [--env testing] \ # DATABASE_URL='postgresql://...' \ # SUPABASE_URL='https://...' \ # SUPABASE_KEY='...' \ @@ -75,5 +86,5 @@ POLICYENGINE_API_URL=http://localhost:8000 # 2. modal secret create anthropic-api-key \ # ANTHROPIC_API_KEY='sk-ant-...' # -# 3. modal secret create logfire-token \ +# 3. modal secret create policyengine-logfire \ # LOGFIRE_TOKEN='...' diff --git a/.gitignore b/.gitignore index 2f8c3e6..bb80186 100644 --- a/.gitignore +++ b/.gitignore @@ -52,3 +52,6 @@ docs/.env.local data/ *.h5 *.db + +# macOS +.DS_Store diff --git a/CLAUDE.md b/CLAUDE.md index 2df55fc..d6fb240 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -75,7 +75,21 @@ Use `gh` CLI for GitHub operations to ensure Actions run correctly. ## Database -`make init` resets tables and storage. 
`make seed` populates UK/US models with variables, parameters, and datasets. +This project uses **Alembic** for database migrations. See `.claude/skills/database-migrations.md` for detailed guidelines. + +**Key rules:** +- All schema changes go through Alembic migrations (never use `SQLModel.metadata.create_all()`) +- After modifying a model: `uv run alembic revision --autogenerate -m "Description"` +- Apply migrations: `uv run alembic upgrade head` + +**Local development:** +```bash +supabase start # Start local Supabase +uv run python scripts/init.py # Run migrations + apply RLS policies +uv run python scripts/seed.py # Seed data +``` + +`scripts/init.py --reset` drops and recreates everything (destructive). ## Modal sandbox + Claude Code CLI gotchas diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000..ed54635 --- /dev/null +++ b/alembic.ini @@ -0,0 +1,145 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts. +# this is typically a path given in POSIX (e.g. forward slashes) +# format, relative to the token %(here)s which refers to the location of this +# ini file +script_location = %(here)s/alembic + +# template used to generate migration file names +# Prepend with date for easier chronological ordering +file_template = %%(year)d%%(month).2d%%(day).2d_%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. for multiple paths, the path separator +# is defined by "path_separator" below. +prepend_sys_path = . + + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the tzdata library which can be installed by adding +# `alembic[tz]` to the pip requirements. 
+# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to /versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "path_separator" +# below. +# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions + +# path_separator; This indicates what character is used to split lists of file +# paths, including version_locations and prepend_sys_path within configparser +# files such as alembic.ini. +# The default rendered in new alembic.ini files is "os", which uses os.pathsep +# to provide os-dependent path splitting. +# +# Note that in order to support legacy alembic.ini files, this default does NOT +# take place if path_separator is not present in alembic.ini. If this +# option is omitted entirely, fallback logic is as follows: +# +# 1. Parsing of the version_locations option falls back to using the legacy +# "version_path_separator" key, which if absent then falls back to the legacy +# behavior of splitting on spaces and/or commas. +# 2. Parsing of the prepend_sys_path option falls back to the legacy +# behavior of splitting on spaces, commas, or colons. +# +# Valid values for path_separator are: +# +# path_separator = : +# path_separator = ; +# path_separator = space +# path_separator = newline +# +# Use os.pathsep. Default configuration used for new projects. 
+path_separator = os + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +# database URL - This is overridden by env.py which reads from application settings. +# The placeholder below is only used if env.py doesn't set it. +sqlalchemy.url = postgresql://placeholder:placeholder@localhost/placeholder + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module +# NOTE: ruff is in dev dependencies, so this hook only works when dev deps are installed +# hooks = ruff +# ruff.type = module +# ruff.module = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Alternatively, use the exec runner to execute a binary found on your PATH +# hooks = ruff +# ruff.type = exec +# ruff.executable = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Logging configuration. This is also consumed by the user-maintained +# env.py script only. 
+[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARNING +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARNING +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/README b/alembic/README new file mode 100644 index 0000000..98e4f9c --- /dev/null +++ b/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 0000000..f930498 --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,87 @@ +"""Alembic environment configuration for SQLModel migrations. + +This module configures Alembic to: +1. Use the database URL from application settings +2. Import all SQLModel models for autogenerate support +3. Run migrations in both offline and online modes +""" + +import sys +from logging.config import fileConfig +from pathlib import Path + +from sqlalchemy import engine_from_config, pool +from sqlmodel import SQLModel + +from alembic import context + +# Add src to path so we can import policyengine_api +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +# Import all models to register them with SQLModel.metadata +# This is required for autogenerate to detect model changes +from policyengine_api import models # noqa: F401 +from policyengine_api.config.settings import settings + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Override sqlalchemy.url with the actual database URL from settings +config.set_main_option("sqlalchemy.url", settings.database_url) + +# Interpret the config file for Python logging. 
+if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# SQLModel metadata for autogenerate support +# This allows Alembic to detect changes in your SQLModel models +target_metadata = SQLModel.metadata + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 0000000..1101630 --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. 
+revision: str = ${repr(up_revision)} +down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + """Upgrade schema.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Downgrade schema.""" + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/20260205_0002_add_user_policies.py b/alembic/versions/20260205_0002_add_user_policies.py new file mode 100644 index 0000000..5f061de --- /dev/null +++ b/alembic/versions/20260205_0002_add_user_policies.py @@ -0,0 +1,94 @@ +"""add_user_policies + +Revision ID: 0002_user_policies +Revises: 36f9d434e95b +Create Date: 2026-02-05 + +This migration adds: +1. tax_benefit_model_id foreign key to policies table +2. user_policies table for user-policy associations + +Note: user_id in user_policies is NOT a foreign key to users table. +It's a client-generated UUID stored in localStorage, allowing anonymous +users to save policies without authentication. +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "0002_user_policies" +down_revision: Union[str, Sequence[str], None] = "36f9d434e95b" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add user_policies table and policy.tax_benefit_model_id.""" + # Add tax_benefit_model_id to policies table + op.add_column( + "policies", sa.Column("tax_benefit_model_id", sa.Uuid(), nullable=False) + ) + op.create_index( + op.f("ix_policies_tax_benefit_model_id"), + "policies", + ["tax_benefit_model_id"], + unique=False, + ) + op.create_foreign_key( + "fk_policies_tax_benefit_model_id", + "policies", + "tax_benefit_models", + ["tax_benefit_model_id"], + ["id"], + ) + + # Create user_policies table + # Note: user_id is NOT a foreign key - it's a client-generated UUID from localStorage + op.create_table( + "user_policies", + sa.Column("user_id", sa.Uuid(), nullable=False), + sa.Column("policy_id", sa.Uuid(), nullable=False), + sa.Column("country_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("label", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(["policy_id"], ["policies.id"]), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_user_policies_policy_id"), + "user_policies", + ["policy_id"], + unique=False, + ) + op.create_index( + op.f("ix_user_policies_user_id"), "user_policies", ["user_id"], unique=False + ) + op.create_index( + op.f("ix_user_policies_country_id"), + "user_policies", + ["country_id"], + unique=False, + ) + + +def downgrade() -> None: + """Remove user_policies table and policy.tax_benefit_model_id.""" + # Drop user_policies table + op.drop_index(op.f("ix_user_policies_country_id"), table_name="user_policies") + op.drop_index(op.f("ix_user_policies_user_id"), 
table_name="user_policies") + op.drop_index(op.f("ix_user_policies_policy_id"), table_name="user_policies") + op.drop_table("user_policies") + + # Remove tax_benefit_model_id from policies + op.drop_constraint( + "fk_policies_tax_benefit_model_id", "policies", type_="foreignkey" + ) + op.drop_index(op.f("ix_policies_tax_benefit_model_id"), table_name="policies") + op.drop_column("policies", "tax_benefit_model_id") diff --git a/alembic/versions/20260207_36f9d434e95b_initial_schema.py b/alembic/versions/20260207_36f9d434e95b_initial_schema.py new file mode 100644 index 0000000..8e707ca --- /dev/null +++ b/alembic/versions/20260207_36f9d434e95b_initial_schema.py @@ -0,0 +1,498 @@ +"""initial schema + +Revision ID: 36f9d434e95b +Revises: +Create Date: 2026-02-07 01:52:16.497121 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "36f9d434e95b" +down_revision: Union[str, Sequence[str], None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "dynamics", + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "policies", + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "tax_benefit_models", + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "users", + sa.Column("first_name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("last_name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("email", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True) + op.create_table( + "datasets", + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("filepath", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("year", sa.Integer(), nullable=False), + sa.Column("is_output_dataset", sa.Boolean(), nullable=False), + sa.Column("tax_benefit_model_id", 
sa.Uuid(), nullable=False), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["tax_benefit_model_id"], + ["tax_benefit_models.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "household_jobs", + sa.Column( + "tax_benefit_model_name", sqlmodel.sql.sqltypes.AutoString(), nullable=False + ), + sa.Column("request_data", sa.JSON(), nullable=True), + sa.Column("policy_id", sa.Uuid(), nullable=True), + sa.Column("dynamic_id", sa.Uuid(), nullable=True), + sa.Column( + "status", + sa.Enum( + "PENDING", "RUNNING", "COMPLETED", "FAILED", name="householdjobstatus" + ), + nullable=False, + ), + sa.Column("error_message", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("result", sa.JSON(), nullable=True), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("started_at", sa.DateTime(), nullable=True), + sa.Column("completed_at", sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint( + ["dynamic_id"], + ["dynamics.id"], + ), + sa.ForeignKeyConstraint( + ["policy_id"], + ["policies.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "tax_benefit_model_versions", + sa.Column("model_id", sa.Uuid(), nullable=False), + sa.Column("version", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["model_id"], + ["tax_benefit_models.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "dataset_versions", + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("dataset_id", sa.Uuid(), nullable=False), + 
sa.Column("tax_benefit_model_id", sa.Uuid(), nullable=False), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["dataset_id"], + ["datasets.id"], + ), + sa.ForeignKeyConstraint( + ["tax_benefit_model_id"], + ["tax_benefit_models.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "parameters", + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("label", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("data_type", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("unit", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("tax_benefit_model_version_id", sa.Uuid(), nullable=False), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["tax_benefit_model_version_id"], + ["tax_benefit_model_versions.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "simulations", + sa.Column("dataset_id", sa.Uuid(), nullable=False), + sa.Column("policy_id", sa.Uuid(), nullable=True), + sa.Column("dynamic_id", sa.Uuid(), nullable=True), + sa.Column("tax_benefit_model_version_id", sa.Uuid(), nullable=False), + sa.Column("output_dataset_id", sa.Uuid(), nullable=True), + sa.Column( + "status", + sa.Enum( + "PENDING", "RUNNING", "COMPLETED", "FAILED", name="simulationstatus" + ), + nullable=False, + ), + sa.Column("error_message", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.Column("started_at", sa.DateTime(), nullable=True), + sa.Column("completed_at", sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint( + ["dataset_id"], + ["datasets.id"], + ), + 
sa.ForeignKeyConstraint( + ["dynamic_id"], + ["dynamics.id"], + ), + sa.ForeignKeyConstraint( + ["output_dataset_id"], + ["datasets.id"], + ), + sa.ForeignKeyConstraint( + ["policy_id"], + ["policies.id"], + ), + sa.ForeignKeyConstraint( + ["tax_benefit_model_version_id"], + ["tax_benefit_model_versions.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "variables", + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("entity", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("data_type", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("possible_values", sa.JSON(), nullable=True), + sa.Column("tax_benefit_model_version_id", sa.Uuid(), nullable=False), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["tax_benefit_model_version_id"], + ["tax_benefit_model_versions.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "parameter_values", + sa.Column("parameter_id", sa.Uuid(), nullable=False), + sa.Column("value_json", sa.JSON(), nullable=True), + sa.Column("start_date", sa.DateTime(), nullable=False), + sa.Column("end_date", sa.DateTime(), nullable=True), + sa.Column("policy_id", sa.Uuid(), nullable=True), + sa.Column("dynamic_id", sa.Uuid(), nullable=True), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["dynamic_id"], + ["dynamics.id"], + ), + sa.ForeignKeyConstraint( + ["parameter_id"], + ["parameters.id"], + ), + sa.ForeignKeyConstraint( + ["policy_id"], + ["policies.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "reports", + sa.Column("label", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("user_id", 
sa.Uuid(), nullable=True), + sa.Column("markdown", sa.Text(), nullable=True), + sa.Column("parent_report_id", sa.Uuid(), nullable=True), + sa.Column( + "status", + sa.Enum("PENDING", "RUNNING", "COMPLETED", "FAILED", name="reportstatus"), + nullable=False, + ), + sa.Column("error_message", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("baseline_simulation_id", sa.Uuid(), nullable=True), + sa.Column("reform_simulation_id", sa.Uuid(), nullable=True), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["baseline_simulation_id"], + ["simulations.id"], + ), + sa.ForeignKeyConstraint( + ["parent_report_id"], + ["reports.id"], + ), + sa.ForeignKeyConstraint( + ["reform_simulation_id"], + ["simulations.id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "aggregates", + sa.Column("simulation_id", sa.Uuid(), nullable=False), + sa.Column("user_id", sa.Uuid(), nullable=True), + sa.Column("report_id", sa.Uuid(), nullable=True), + sa.Column("variable", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column( + "aggregate_type", + sa.Enum("SUM", "MEAN", "COUNT", name="aggregatetype"), + nullable=False, + ), + sa.Column("entity", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("filter_config", sa.JSON(), nullable=True), + sa.Column( + "status", + sa.Enum( + "PENDING", "RUNNING", "COMPLETED", "FAILED", name="aggregatestatus" + ), + nullable=False, + ), + sa.Column("error_message", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("result", sa.Float(), nullable=True), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["report_id"], + ["reports.id"], + ), + sa.ForeignKeyConstraint( + ["simulation_id"], + ["simulations.id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + 
), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "change_aggregates", + sa.Column("baseline_simulation_id", sa.Uuid(), nullable=False), + sa.Column("reform_simulation_id", sa.Uuid(), nullable=False), + sa.Column("user_id", sa.Uuid(), nullable=True), + sa.Column("report_id", sa.Uuid(), nullable=True), + sa.Column("variable", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column( + "aggregate_type", + sa.Enum("SUM", "MEAN", "COUNT", name="changeaggregatetype"), + nullable=False, + ), + sa.Column("entity", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("filter_config", sa.JSON(), nullable=True), + sa.Column("change_geq", sa.Float(), nullable=True), + sa.Column("change_leq", sa.Float(), nullable=True), + sa.Column( + "status", + sa.Enum( + "PENDING", + "RUNNING", + "COMPLETED", + "FAILED", + name="changeaggregatestatus", + ), + nullable=False, + ), + sa.Column("error_message", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("result", sa.Float(), nullable=True), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["baseline_simulation_id"], + ["simulations.id"], + ), + sa.ForeignKeyConstraint( + ["reform_simulation_id"], + ["simulations.id"], + ), + sa.ForeignKeyConstraint( + ["report_id"], + ["reports.id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "decile_impacts", + sa.Column("baseline_simulation_id", sa.Uuid(), nullable=False), + sa.Column("reform_simulation_id", sa.Uuid(), nullable=False), + sa.Column("report_id", sa.Uuid(), nullable=True), + sa.Column( + "income_variable", sqlmodel.sql.sqltypes.AutoString(), nullable=False + ), + sa.Column("entity", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("decile", sa.Integer(), nullable=False), + sa.Column("quantiles", sa.Integer(), nullable=False), + sa.Column("baseline_mean", sa.Float(), 
nullable=True), + sa.Column("reform_mean", sa.Float(), nullable=True), + sa.Column("absolute_change", sa.Float(), nullable=True), + sa.Column("relative_change", sa.Float(), nullable=True), + sa.Column("count_better_off", sa.Float(), nullable=True), + sa.Column("count_worse_off", sa.Float(), nullable=True), + sa.Column("count_no_change", sa.Float(), nullable=True), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["baseline_simulation_id"], + ["simulations.id"], + ), + sa.ForeignKeyConstraint( + ["reform_simulation_id"], + ["simulations.id"], + ), + sa.ForeignKeyConstraint( + ["report_id"], + ["reports.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "inequality", + sa.Column("simulation_id", sa.Uuid(), nullable=False), + sa.Column("report_id", sa.Uuid(), nullable=True), + sa.Column( + "income_variable", sqlmodel.sql.sqltypes.AutoString(), nullable=False + ), + sa.Column("entity", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("gini", sa.Float(), nullable=True), + sa.Column("top_10_share", sa.Float(), nullable=True), + sa.Column("top_1_share", sa.Float(), nullable=True), + sa.Column("bottom_50_share", sa.Float(), nullable=True), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["report_id"], + ["reports.id"], + ), + sa.ForeignKeyConstraint( + ["simulation_id"], + ["simulations.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "poverty", + sa.Column("simulation_id", sa.Uuid(), nullable=False), + sa.Column("report_id", sa.Uuid(), nullable=True), + sa.Column("poverty_type", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("entity", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("filter_variable", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("headcount", sa.Float(), nullable=True), + 
sa.Column("total_population", sa.Float(), nullable=True), + sa.Column("rate", sa.Float(), nullable=True), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["report_id"], + ["reports.id"], + ), + sa.ForeignKeyConstraint( + ["simulation_id"], + ["simulations.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "program_statistics", + sa.Column("baseline_simulation_id", sa.Uuid(), nullable=False), + sa.Column("reform_simulation_id", sa.Uuid(), nullable=False), + sa.Column("report_id", sa.Uuid(), nullable=True), + sa.Column("program_name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("entity", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("is_tax", sa.Boolean(), nullable=False), + sa.Column("baseline_total", sa.Float(), nullable=True), + sa.Column("reform_total", sa.Float(), nullable=True), + sa.Column("change", sa.Float(), nullable=True), + sa.Column("baseline_count", sa.Float(), nullable=True), + sa.Column("reform_count", sa.Float(), nullable=True), + sa.Column("winners", sa.Float(), nullable=True), + sa.Column("losers", sa.Float(), nullable=True), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["baseline_simulation_id"], + ["simulations.id"], + ), + sa.ForeignKeyConstraint( + ["reform_simulation_id"], + ["simulations.id"], + ), + sa.ForeignKeyConstraint( + ["report_id"], + ["reports.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_table("program_statistics") + op.drop_table("poverty") + op.drop_table("inequality") + op.drop_table("decile_impacts") + op.drop_table("change_aggregates") + op.drop_table("aggregates") + op.drop_table("reports") + op.drop_table("parameter_values") + op.drop_table("variables") + op.drop_table("simulations") + op.drop_table("parameters") + op.drop_table("dataset_versions") + op.drop_table("tax_benefit_model_versions") + op.drop_table("household_jobs") + op.drop_table("datasets") + op.drop_index(op.f("ix_users_email"), table_name="users") + op.drop_table("users") + op.drop_table("tax_benefit_models") + op.drop_table("policies") + op.drop_table("dynamics") + # ### end Alembic commands ### diff --git a/alembic/versions/20260207_f419b5f4acba_add_household_support.py b/alembic/versions/20260207_f419b5f4acba_add_household_support.py new file mode 100644 index 0000000..f781194 --- /dev/null +++ b/alembic/versions/20260207_f419b5f4acba_add_household_support.py @@ -0,0 +1,123 @@ +"""add household support + +Revision ID: f419b5f4acba +Revises: 36f9d434e95b +Create Date: 2026-02-07 01:56:31.064511 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +from sqlalchemy.dialects import postgresql + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "f419b5f4acba" +down_revision: Union[str, Sequence[str], None] = "36f9d434e95b" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "households", + sa.Column( + "tax_benefit_model_name", sqlmodel.sql.sqltypes.AutoString(), nullable=False + ), + sa.Column("year", sa.Integer(), nullable=False), + sa.Column("label", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("household_data", sa.JSON(), nullable=False), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "user_household_associations", + sa.Column("user_id", sa.Uuid(), nullable=False), + sa.Column("household_id", sa.Uuid(), nullable=False), + sa.Column("country_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("label", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["household_id"], + ["households.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_user_household_associations_household_id"), + "user_household_associations", + ["household_id"], + unique=False, + ) + op.create_index( + op.f("ix_user_household_associations_user_id"), + "user_household_associations", + ["user_id"], + unique=False, + ) + op.add_column( + "reports", + sa.Column("report_type", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + ) + # Create enum type first + simulationtype = postgresql.ENUM( + "HOUSEHOLD", "ECONOMY", name="simulationtype", create_type=False + ) + simulationtype.create(op.get_bind(), checkfirst=True) + op.add_column( + "simulations", + sa.Column( + "simulation_type", + sa.Enum("HOUSEHOLD", "ECONOMY", name="simulationtype", create_type=False), + nullable=False, + ), + ) + op.add_column("simulations", sa.Column("household_id", sa.Uuid(), nullable=True)) + op.add_column( + "simulations", + sa.Column( + 
"household_result", postgresql.JSON(astext_type=sa.Text()), nullable=True
+        ),
+    )
+    op.alter_column("simulations", "dataset_id", existing_type=sa.UUID(), nullable=True)
+    op.create_foreign_key("simulations_household_id_fkey", "simulations", "households", ["household_id"], ["id"])
+    op.add_column("variables", sa.Column("default_value", sa.JSON(), nullable=True))
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    """Downgrade schema."""
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column("variables", "default_value")
+    op.drop_constraint("simulations_household_id_fkey", "simulations", type_="foreignkey")
+    op.alter_column(
+        "simulations", "dataset_id", existing_type=sa.UUID(), nullable=False
+    )
+    op.drop_column("simulations", "household_result")
+    op.drop_column("simulations", "household_id")
+    op.drop_column("simulations", "simulation_type")
+    # Drop enum type
+    postgresql.ENUM("HOUSEHOLD", "ECONOMY", name="simulationtype").drop(
+        op.get_bind(), checkfirst=True
+    )
+    op.drop_column("reports", "report_type")
+    op.drop_index(
+        op.f("ix_user_household_associations_user_id"),
+        table_name="user_household_associations",
+    )
+    op.drop_index(
+        op.f("ix_user_household_associations_household_id"),
+        table_name="user_household_associations",
+    )
+    op.drop_table("user_household_associations")
+    op.drop_table("households")
+    # ### end Alembic commands ###
diff --git a/alembic/versions/20260210_add_regions_table.py b/alembic/versions/20260210_add_regions_table.py
new file mode 100644
index 0000000..390d355
--- /dev/null
+++ b/alembic/versions/20260210_add_regions_table.py
@@ -0,0 +1,63 @@
+"""add regions table
+
+Revision ID: a1b2c3d4e5f6
+Revises: f419b5f4acba
+Create Date: 2026-02-10 12:00:00.000000
+
+"""
+
+from typing import Sequence, Union
+
+import sqlalchemy as sa
+import sqlmodel.sql.sqltypes
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "a1b2c3d4e5f6"
+down_revision: Union[str, Sequence[str], None] = "f419b5f4acba"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    """Create regions table."""
+    op.create_table(
+        "regions",
+        sa.Column("code", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+        sa.Column("label", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+        sa.Column("region_type", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+        sa.Column("requires_filter", sa.Boolean(), nullable=False, server_default=sa.false()),
+        sa.Column("filter_field", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
+        sa.Column("filter_value", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
+        sa.Column("parent_code", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
+        sa.Column("state_code", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
+        sa.Column("state_name", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
+        sa.Column("dataset_id", sa.Uuid(), nullable=False),
+        sa.Column("tax_benefit_model_id", sa.Uuid(), nullable=False),
+        sa.Column("id", sa.Uuid(), nullable=False),
+        sa.Column("created_at", sa.DateTime(), nullable=False),
+        sa.Column("updated_at", sa.DateTime(), nullable=False),
+        sa.ForeignKeyConstraint(
+            ["dataset_id"],
+            ["datasets.id"],
+        ),
+        sa.ForeignKeyConstraint(
+            ["tax_benefit_model_id"],
+            ["tax_benefit_models.id"],
+        ),
+        sa.PrimaryKeyConstraint("id"),
+    )
+    # Create unique constraint on (code, tax_benefit_model_id)
+    op.create_index(
+        "ix_regions_code_model",
+        "regions",
+        ["code", "tax_benefit_model_id"],
+        unique=True,
+    )
+
+
+def downgrade() -> None:
+    """Drop regions table."""
+    op.drop_index("ix_regions_code_model", table_name="regions")
+    op.drop_table("regions")
diff --git a/alembic/versions/20260218_add_simulation_filter_columns.py b/alembic/versions/20260218_add_simulation_filter_columns.py
new file mode 100644
index 0000000..d7121e8
--- /dev/null
+++ 
b/alembic/versions/20260218_add_simulation_filter_columns.py @@ -0,0 +1,41 @@ +"""add filter_field and filter_value to simulations + +Revision ID: add_sim_filters +Revises: merge_001 +Create Date: 2026-02-18 + +The Simulation model already has filter_field and filter_value fields +(used for regional economy simulations), but no migration added them +to the database. This brings the schema in line with the model. +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "add_sim_filters" +down_revision: Union[str, Sequence[str], None] = "merge_001" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add filter_field and filter_value columns to simulations table.""" + op.add_column( + "simulations", + sa.Column("filter_field", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + ) + op.add_column( + "simulations", + sa.Column("filter_value", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + ) + + +def downgrade() -> None: + """Remove filter_field and filter_value columns from simulations table.""" + op.drop_column("simulations", "filter_value") + op.drop_column("simulations", "filter_field") diff --git a/alembic/versions/20260218_drop_parent_report_id.py b/alembic/versions/20260218_drop_parent_report_id.py new file mode 100644 index 0000000..7a5b0f4 --- /dev/null +++ b/alembic/versions/20260218_drop_parent_report_id.py @@ -0,0 +1,42 @@ +"""drop parent_report_id from reports + +Revision ID: drop_parent_report +Revises: add_sim_filters +Create Date: 2026-02-18 + +Remove the unused self-referential parent_report_id foreign key from +the reports table. No code reads or writes this column. +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "drop_parent_report" +down_revision: Union[str, Sequence[str], None] = "add_sim_filters" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Drop parent_report_id column and its FK constraint.""" + op.drop_constraint("reports_parent_report_id_fkey", "reports", type_="foreignkey") + op.drop_column("reports", "parent_report_id") + + +def downgrade() -> None: + """Re-add parent_report_id column and FK constraint.""" + op.add_column( + "reports", + sa.Column("parent_report_id", sa.Uuid(), nullable=True), + ) + op.create_foreign_key( + "reports_parent_report_id_fkey", + "reports", + "reports", + ["parent_report_id"], + ["id"], + ) diff --git a/alembic/versions/20260218_merge_user_policies_and_household_support.py b/alembic/versions/20260218_merge_user_policies_and_household_support.py new file mode 100644 index 0000000..3d7c3e5 --- /dev/null +++ b/alembic/versions/20260218_merge_user_policies_and_household_support.py @@ -0,0 +1,30 @@ +"""merge user_policies and household_support branches + +Revision ID: merge_001 +Revises: 0002_user_policies, a1b2c3d4e5f6 +Create Date: 2026-02-18 + +Merge the two migration branches that diverged from the initial schema: +- 0002_user_policies: added user_policies table + policy.tax_benefit_model_id +- f419b5f4acba → a1b2c3d4e5f6: added household support + regions table + +No schema changes — both branches modify independent tables. +""" + +from typing import Sequence, Union + +# revision identifiers, used by Alembic. 
+revision: str = "merge_001" +down_revision: tuple[str, str] = ("0002_user_policies", "a1b2c3d4e5f6") +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """No schema changes — merge only.""" + pass + + +def downgrade() -> None: + """No schema changes — merge only.""" + pass diff --git a/alembic/versions/20260219_621977f3b1aa_add_user_simulation_associations_table.py b/alembic/versions/20260219_621977f3b1aa_add_user_simulation_associations_table.py new file mode 100644 index 0000000..3103120 --- /dev/null +++ b/alembic/versions/20260219_621977f3b1aa_add_user_simulation_associations_table.py @@ -0,0 +1,68 @@ +"""add user_simulation_associations table + +Revision ID: 621977f3b1aa +Revises: drop_parent_report +Create Date: 2026-02-19 00:37:43.378088 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "621977f3b1aa" +down_revision: Union[str, Sequence[str], None] = "drop_parent_report" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "user_simulation_associations", + sa.Column("user_id", sa.Uuid(), nullable=False), + sa.Column("simulation_id", sa.Uuid(), nullable=False), + sa.Column("country_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("label", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["simulation_id"], + ["simulations.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_user_simulation_associations_simulation_id"), + "user_simulation_associations", + ["simulation_id"], + unique=False, + ) + op.create_index( + op.f("ix_user_simulation_associations_user_id"), + "user_simulation_associations", + ["user_id"], + unique=False, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index( + op.f("ix_user_simulation_associations_user_id"), + table_name="user_simulation_associations", + ) + op.drop_index( + op.f("ix_user_simulation_associations_simulation_id"), + table_name="user_simulation_associations", + ) + op.drop_table("user_simulation_associations") + # ### end Alembic commands ### diff --git a/alembic/versions/20260219_9daa015274dd_add_user_report_associations_table.py b/alembic/versions/20260219_9daa015274dd_add_user_report_associations_table.py new file mode 100644 index 0000000..7a56753 --- /dev/null +++ b/alembic/versions/20260219_9daa015274dd_add_user_report_associations_table.py @@ -0,0 +1,69 @@ +"""add user_report_associations table + +Revision ID: 9daa015274dd +Revises: 621977f3b1aa +Create Date: 2026-02-19 16:58:03.157551 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "9daa015274dd" +down_revision: Union[str, Sequence[str], None] = "621977f3b1aa" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "user_report_associations", + sa.Column("user_id", sa.Uuid(), nullable=False), + sa.Column("report_id", sa.Uuid(), nullable=False), + sa.Column("country_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("label", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("last_run_at", sa.DateTime(), nullable=True), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["report_id"], + ["reports.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_user_report_associations_report_id"), + "user_report_associations", + ["report_id"], + unique=False, + ) + op.create_index( + op.f("ix_user_report_associations_user_id"), + "user_report_associations", + ["user_id"], + unique=False, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_index( + op.f("ix_user_report_associations_user_id"), + table_name="user_report_associations", + ) + op.drop_index( + op.f("ix_user_report_associations_report_id"), + table_name="user_report_associations", + ) + op.drop_table("user_report_associations") + # ### end Alembic commands ### diff --git a/alembic/versions/20260220_75e5dca14603_add_intra_decile_impacts_table.py b/alembic/versions/20260220_75e5dca14603_add_intra_decile_impacts_table.py new file mode 100644 index 0000000..c79bad4 --- /dev/null +++ b/alembic/versions/20260220_75e5dca14603_add_intra_decile_impacts_table.py @@ -0,0 +1,59 @@ +"""add intra_decile_impacts table + +Revision ID: 75e5dca14603 +Revises: e243279f952f +Create Date: 2026-02-20 16:39:57.387711 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "75e5dca14603" +down_revision: Union[str, Sequence[str], None] = "e243279f952f" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "intra_decile_impacts", + sa.Column("baseline_simulation_id", sa.Uuid(), nullable=False), + sa.Column("reform_simulation_id", sa.Uuid(), nullable=False), + sa.Column("report_id", sa.Uuid(), nullable=True), + sa.Column("decile", sa.Integer(), nullable=False), + sa.Column("lose_more_than_5pct", sa.Float(), nullable=True), + sa.Column("lose_less_than_5pct", sa.Float(), nullable=True), + sa.Column("no_change", sa.Float(), nullable=True), + sa.Column("gain_less_than_5pct", sa.Float(), nullable=True), + sa.Column("gain_more_than_5pct", sa.Float(), nullable=True), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["baseline_simulation_id"], + ["simulations.id"], + ), + sa.ForeignKeyConstraint( + ["reform_simulation_id"], + ["simulations.id"], + ), + sa.ForeignKeyConstraint( + ["report_id"], + ["reports.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("intra_decile_impacts") + # ### end Alembic commands ### diff --git a/alembic/versions/20260220_83ceeb90bd74_add_constituency_impacts_table.py b/alembic/versions/20260220_83ceeb90bd74_add_constituency_impacts_table.py new file mode 100644 index 0000000..3daf51d --- /dev/null +++ b/alembic/versions/20260220_83ceeb90bd74_add_constituency_impacts_table.py @@ -0,0 +1,65 @@ +"""add_constituency_impacts_table + +Revision ID: 83ceeb90bd74 +Revises: a5e4144467e5 +Create Date: 2026-02-20 21:44:14.780195 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +import sqlmodel + +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "83ceeb90bd74" +down_revision: Union[str, Sequence[str], None] = "a5e4144467e5" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "constituency_impacts", + sa.Column("baseline_simulation_id", sa.Uuid(), nullable=False), + sa.Column("reform_simulation_id", sa.Uuid(), nullable=False), + sa.Column("report_id", sa.Uuid(), nullable=True), + sa.Column( + "constituency_code", sqlmodel.sql.sqltypes.AutoString(), nullable=False + ), + sa.Column( + "constituency_name", sqlmodel.sql.sqltypes.AutoString(), nullable=False + ), + sa.Column("x", sa.Integer(), nullable=False), + sa.Column("y", sa.Integer(), nullable=False), + sa.Column("average_household_income_change", sa.Float(), nullable=False), + sa.Column("relative_household_income_change", sa.Float(), nullable=False), + sa.Column("population", sa.Float(), nullable=False), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["baseline_simulation_id"], + ["simulations.id"], + ), + sa.ForeignKeyConstraint( + ["reform_simulation_id"], + ["simulations.id"], + ), + sa.ForeignKeyConstraint( + ["report_id"], + ["reports.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_table("constituency_impacts") + # ### end Alembic commands ### diff --git a/alembic/versions/20260220_8d54837f0fcd_add_decile_type_to_intra_decile_impacts.py b/alembic/versions/20260220_8d54837f0fcd_add_decile_type_to_intra_decile_impacts.py new file mode 100644 index 0000000..bdad920 --- /dev/null +++ b/alembic/versions/20260220_8d54837f0fcd_add_decile_type_to_intra_decile_impacts.py @@ -0,0 +1,42 @@ +"""add_decile_type_to_intra_decile_impacts + +Revision ID: 8d54837f0fcd +Revises: a4ee5758d272 +Create Date: 2026-02-20 22:50:00.125733 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +import sqlmodel + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "8d54837f0fcd" +down_revision: Union[str, Sequence[str], None] = "a4ee5758d272" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "intra_decile_impacts", + sa.Column( + "decile_type", + sqlmodel.sql.sqltypes.AutoString(), + server_default="income", + nullable=False, + ), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column("intra_decile_impacts", "decile_type") + # ### end Alembic commands ### diff --git a/alembic/versions/20260220_a4ee5758d272_add_local_authority_impacts_table.py b/alembic/versions/20260220_a4ee5758d272_add_local_authority_impacts_table.py new file mode 100644 index 0000000..4a86ce5 --- /dev/null +++ b/alembic/versions/20260220_a4ee5758d272_add_local_authority_impacts_table.py @@ -0,0 +1,65 @@ +"""add_local_authority_impacts_table + +Revision ID: a4ee5758d272 +Revises: 83ceeb90bd74 +Create Date: 2026-02-20 22:20:00.037965 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +import sqlmodel + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "a4ee5758d272" +down_revision: Union[str, Sequence[str], None] = "83ceeb90bd74" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "local_authority_impacts", + sa.Column("baseline_simulation_id", sa.Uuid(), nullable=False), + sa.Column("reform_simulation_id", sa.Uuid(), nullable=False), + sa.Column("report_id", sa.Uuid(), nullable=True), + sa.Column( + "local_authority_code", sqlmodel.sql.sqltypes.AutoString(), nullable=False + ), + sa.Column( + "local_authority_name", sqlmodel.sql.sqltypes.AutoString(), nullable=False + ), + sa.Column("x", sa.Integer(), nullable=False), + sa.Column("y", sa.Integer(), nullable=False), + sa.Column("average_household_income_change", sa.Float(), nullable=False), + sa.Column("relative_household_income_change", sa.Float(), nullable=False), + sa.Column("population", sa.Float(), nullable=False), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["baseline_simulation_id"], + ["simulations.id"], + ), + sa.ForeignKeyConstraint( + ["reform_simulation_id"], + ["simulations.id"], + ), + sa.ForeignKeyConstraint( + ["report_id"], + ["reports.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("local_authority_impacts") + # ### end Alembic commands ### diff --git a/alembic/versions/20260220_a5e4144467e5_add_congressional_district_impacts_table.py b/alembic/versions/20260220_a5e4144467e5_add_congressional_district_impacts_table.py new file mode 100644 index 0000000..0c82389 --- /dev/null +++ b/alembic/versions/20260220_a5e4144467e5_add_congressional_district_impacts_table.py @@ -0,0 +1,59 @@ +"""add_congressional_district_impacts_table + +Revision ID: a5e4144467e5 +Revises: 75e5dca14603 +Create Date: 2026-02-20 20:52:53.243197 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "a5e4144467e5" +down_revision: Union[str, Sequence[str], None] = "75e5dca14603" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "congressional_district_impacts", + sa.Column("baseline_simulation_id", sa.Uuid(), nullable=False), + sa.Column("reform_simulation_id", sa.Uuid(), nullable=False), + sa.Column("report_id", sa.Uuid(), nullable=True), + sa.Column("district_geoid", sa.Integer(), nullable=False), + sa.Column("state_fips", sa.Integer(), nullable=False), + sa.Column("district_number", sa.Integer(), nullable=False), + sa.Column("average_household_income_change", sa.Float(), nullable=False), + sa.Column("relative_household_income_change", sa.Float(), nullable=False), + sa.Column("population", sa.Float(), nullable=False), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["baseline_simulation_id"], + ["simulations.id"], + ), + sa.ForeignKeyConstraint( + ["reform_simulation_id"], + ["simulations.id"], + ), + sa.ForeignKeyConstraint( + ["report_id"], + ["reports.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_table("congressional_district_impacts") + # ### end Alembic commands ### diff --git a/alembic/versions/20260220_e243279f952f_add_budget_summary_table.py b/alembic/versions/20260220_e243279f952f_add_budget_summary_table.py new file mode 100644 index 0000000..cf49616 --- /dev/null +++ b/alembic/versions/20260220_e243279f952f_add_budget_summary_table.py @@ -0,0 +1,59 @@ +"""add budget_summary table + +Revision ID: e243279f952f +Revises: 9daa015274dd +Create Date: 2026-02-20 01:50:46.010955 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "e243279f952f" +down_revision: Union[str, Sequence[str], None] = "9daa015274dd" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "budget_summary", + sa.Column("baseline_simulation_id", sa.Uuid(), nullable=False), + sa.Column("reform_simulation_id", sa.Uuid(), nullable=False), + sa.Column("report_id", sa.Uuid(), nullable=True), + sa.Column("variable_name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("entity", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("baseline_total", sa.Float(), nullable=True), + sa.Column("reform_total", sa.Float(), nullable=True), + sa.Column("change", sa.Float(), nullable=True), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["baseline_simulation_id"], + ["simulations.id"], + ), + sa.ForeignKeyConstraint( + ["reform_simulation_id"], + ["simulations.id"], + ), + sa.ForeignKeyConstraint( + ["report_id"], + ["reports.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade 
schema."""
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_table("budget_summary")
+    # ### end Alembic commands ###
diff --git a/alembic/versions/20260227_963e91da9298_add_region_id_to_simulations.py b/alembic/versions/20260227_963e91da9298_add_region_id_to_simulations.py
new file mode 100644
index 0000000..96184d6
--- /dev/null
+++ b/alembic/versions/20260227_963e91da9298_add_region_id_to_simulations.py
@@ -0,0 +1,40 @@
+"""add region_id to simulations
+
+Revision ID: 963e91da9298
+Revises: 8d54837f0fcd
+Create Date: 2026-02-27 22:47:47.740784
+
+"""
+
+from typing import Sequence, Union
+
+import sqlalchemy as sa
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "963e91da9298"
+down_revision: Union[str, Sequence[str], None] = "8d54837f0fcd"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    """Upgrade schema."""
+    op.add_column("simulations", sa.Column("region_id", sa.Uuid(), nullable=True))
+    # Name the FK explicitly: autogenerate emitted None here, which lets the
+    # database pick a name and leaves downgrade() unable to reference it.
+    op.create_foreign_key(
+        op.f("simulations_region_id_fkey"),
+        "simulations",
+        "regions",
+        ["region_id"],
+        ["id"],
+    )
+
+
+def downgrade() -> None:
+    """Downgrade schema."""
+    # drop_constraint(None, ...) raises "Constraint must have a name" in
+    # Alembic; drop by the explicit name created in upgrade().
+    op.drop_constraint(op.f("simulations_region_id_fkey"), "simulations", type_="foreignkey")
+    op.drop_column("simulations", "region_id")
diff --git a/alembic/versions/20260303_886921687770_region_datasets_join_table_and_.py b/alembic/versions/20260303_886921687770_region_datasets_join_table_and_.py
new file mode 100644
index 0000000..6db7ee4
--- /dev/null
+++ b/alembic/versions/20260303_886921687770_region_datasets_join_table_and_.py
@@ -0,0 +1,79 @@
+"""region_datasets_join_table_and_simulation_year
+
+Revision ID: 886921687770
+Revises: 963e91da9298
+Create Date: 2026-03-03 18:56:13.551288
+
+"""
+
+from typing import Sequence, Union
+
+import sqlalchemy as sa
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "886921687770" +down_revision: Union[str, Sequence[str], None] = "963e91da9298" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # Create the region_datasets join table + op.create_table( + "region_datasets", + sa.Column("region_id", sa.Uuid(), nullable=False), + sa.Column("dataset_id", sa.Uuid(), nullable=False), + sa.ForeignKeyConstraint( + ["dataset_id"], + ["datasets.id"], + ), + sa.ForeignKeyConstraint( + ["region_id"], + ["regions.id"], + ), + sa.PrimaryKeyConstraint("region_id", "dataset_id"), + ) + + # Migrate existing region->dataset links into the join table + op.execute(""" + INSERT INTO region_datasets (region_id, dataset_id) + SELECT id, dataset_id FROM regions + WHERE dataset_id IS NOT NULL + """) + + # Drop the old FK and column from regions + op.drop_constraint(op.f("regions_dataset_id_fkey"), "regions", type_="foreignkey") + op.drop_column("regions", "dataset_id") + + # Add year column to simulations + op.add_column("simulations", sa.Column("year", sa.Integer(), nullable=True)) + + +def downgrade() -> None: + """Downgrade schema.""" + op.drop_column("simulations", "year") + op.add_column( + "regions", + sa.Column("dataset_id", sa.UUID(), autoincrement=False, nullable=True), + ) + + # Migrate join table data back to the FK column (pick one dataset per region) + op.execute(""" + UPDATE regions r + SET dataset_id = rd.dataset_id + FROM ( + SELECT DISTINCT ON (region_id) region_id, dataset_id + FROM region_datasets + ORDER BY region_id + ) rd + WHERE r.id = rd.region_id + """) + + op.alter_column("regions", "dataset_id", nullable=False) + op.create_foreign_key( + op.f("regions_dataset_id_fkey"), "regions", "datasets", ["dataset_id"], ["id"] + ) + op.drop_table("region_datasets") diff --git a/docker-compose.yml b/docker-compose.yml index 60e8645..60aa598 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -5,10 +5,10 
@@ services: ports: - "${API_PORT:-8000}:${API_PORT:-8000}" environment: - SUPABASE_URL: http://supabase_kong_policyengine-api-v2:8000 + SUPABASE_URL: http://supabase_kong_policyengine-api-v2-alpha:8000 SUPABASE_KEY: ${SUPABASE_KEY} SUPABASE_SERVICE_KEY: ${SUPABASE_SERVICE_KEY} - SUPABASE_DB_URL: postgresql://postgres:postgres@supabase_db_policyengine-api-v2:5432/postgres + SUPABASE_DB_URL: postgresql://postgres:postgres@supabase_db_policyengine-api-v2-alpha:5432/postgres LOGFIRE_TOKEN: ${LOGFIRE_TOKEN} DEBUG: "false" API_PORT: ${API_PORT:-8000} @@ -19,7 +19,7 @@ services: - ./src:/app/src - ./docs/out:/app/docs/out networks: - - supabase_network_policyengine-api-v2 + - supabase_network_policyengine-api-v2-alpha healthcheck: test: ["CMD", "python", "-c", "import httpx; exit(0 if httpx.get('http://localhost:${API_PORT:-8000}/health', timeout=2).status_code == 200 else 1)"] interval: 5s @@ -31,16 +31,16 @@ services: build: . command: pytest tests/ -v environment: - SUPABASE_URL: http://supabase_kong_policyengine-api-v2:8000 + SUPABASE_URL: http://supabase_kong_policyengine-api-v2-alpha:8000 SUPABASE_KEY: ${SUPABASE_KEY} SUPABASE_SERVICE_KEY: ${SUPABASE_SERVICE_KEY} - SUPABASE_DB_URL: postgresql://postgres:postgres@supabase_db_policyengine-api-v2:5432/postgres + SUPABASE_DB_URL: postgresql://postgres:postgres@supabase_db_policyengine-api-v2-alpha:5432/postgres LOGFIRE_TOKEN: ${LOGFIRE_TOKEN} volumes: - ./src:/app/src - ./tests:/app/tests networks: - - supabase_network_policyengine-api-v2 + - supabase_network_policyengine-api-v2-alpha depends_on: api: condition: service_healthy @@ -48,5 +48,5 @@ services: - test networks: - supabase_network_policyengine-api-v2: + supabase_network_policyengine-api-v2-alpha: external: true diff --git a/import_state_datasets.py b/import_state_datasets.py new file mode 100644 index 0000000..3797b7b --- /dev/null +++ b/import_state_datasets.py @@ -0,0 +1,484 @@ +"""Download, convert, and upload state & congressional district datasets. 
+ +One-off script to migrate state/district datasets from GCS (old format) +to Supabase (new yearly entity-level format). + +Downloads raw h5 files from GCS, converts them to yearly entity-level files +using policyengine's create_datasets(), uploads to Supabase, and creates +database records. + +Usage: + python import_state_datasets.py AL # State + all AL districts + python import_state_datasets.py CA NY TX # Multiple states + districts + python import_state_datasets.py --all # All 51 states + 436 districts + python import_state_datasets.py AL --state-only # State only, no districts + python import_state_datasets.py --years 2025,2026 + python import_state_datasets.py --skip-upload # Convert only + +Must be run from the policyengine-api-v2-alpha project root (where .env lives). +""" + +import argparse +import json +import logging +import subprocess +import sys +import time +import warnings +from datetime import datetime, timezone +from pathlib import Path + +logging.basicConfig(level=logging.ERROR) +logging.getLogger("sqlalchemy").setLevel(logging.ERROR) +warnings.filterwarnings("ignore") + +# Add src to path for policyengine_api imports +sys.path.insert(0, str(Path(__file__).parent / "src")) + +from policyengine.countries.us.data import DISTRICT_COUNTS +from rich.console import Console +from sqlmodel import Session, create_engine, select + +from policyengine_api.config.settings import settings +from policyengine_api.models import Dataset, TaxBenefitModel +from policyengine_api.services.storage import upload_dataset_for_seeding + +console = Console() + +GCS_BUCKET = "gs://policyengine-us-data" +TMP_DIR = Path("/tmp/pe_state_data") +DEFAULT_YEARS = list(range(2024, 2036)) + +ALL_STATES = list(DISTRICT_COUNTS.keys()) + + +def fmt_duration(seconds: float) -> str: + """Format seconds into a human-readable duration.""" + if seconds < 60: + return f"{seconds:.1f}s" + minutes = int(seconds // 60) + secs = seconds % 60 + if minutes < 60: + return f"{minutes}m {secs:.0f}s" + 
hours = int(minutes // 60) + mins = minutes % 60 + return f"{hours}h {mins}m {secs:.0f}s" + + +def get_session() -> Session: + engine = create_engine(settings.database_url, echo=False) + return Session(engine) + + +def download_from_gcs(gcs_path: str, local_path: Path) -> bool: + """Download a file from GCS using gsutil. Skips if already exists locally.""" + if local_path.exists() and local_path.stat().st_size > 0: + return True + local_path.parent.mkdir(parents=True, exist_ok=True) + result = subprocess.run( + ["gsutil", "cp", gcs_path, str(local_path)], + capture_output=True, + text=True, + ) + if result.returncode != 0: + console.print(f" [red]gsutil error: {result.stderr.strip()}[/red]") + return False + return True + + +def convert_dataset(raw_h5_path: str, output_folder: str, years: list[int]) -> dict: + """Convert a raw h5 file to yearly entity-level h5 files. + + Skips conversion if all yearly output files already exist. + Returns dict mapping dataset_key -> PolicyEngineUSDataset. + """ + from policyengine.tax_benefit_models.us.datasets import ( + create_datasets, + load_datasets, + ) + + stem = Path(raw_h5_path).stem + all_exist = all( + Path(f"{output_folder}/{stem}_year_{year}.h5").exists() for year in years + ) + if all_exist: + return load_datasets( + datasets=[raw_h5_path], + years=years, + data_folder=output_folder, + ) + + return create_datasets( + datasets=[raw_h5_path], + years=years, + data_folder=output_folder, + ) + + +def process_file( + file_info: dict, + years: list[int], + data_folder: Path, + skip_upload: bool, + session, + us_model, + file_index: int, + total_files: int, +) -> tuple[int, int, int, dict]: + """Process a single raw h5 file (state or district). + + Returns (datasets_created, datasets_skipped, errors, timing). + Region-to-dataset wiring is handled by seed_regions.py, not here. 
+ """ + code = file_info["code"] + prefix = f" [{file_index}/{total_files}] {code}" + datasets_created = 0 + datasets_skipped = 0 + errors = 0 + timing = {"code": code, "type": file_info["type"]} + + # Step 1: Download + t0 = time.time() + console.print(f"{prefix}: downloading from GCS...") + if not download_from_gcs(file_info["gcs_path"], file_info["local_path"]): + console.print(f"{prefix}: [red]download failed, skipping[/red]") + timing["status"] = "download_failed" + return 0, 0, 1, timing + dl_time = time.time() - t0 + size_mb = file_info["local_path"].stat().st_size / (1024 * 1024) + timing["download_seconds"] = round(dl_time, 2) + timing["raw_size_mb"] = round(size_mb, 1) + console.print(f"{prefix}: downloaded ({size_mb:.1f} MB, {fmt_duration(dl_time)})") + + # Step 2: Convert + t0 = time.time() + console.print(f"{prefix}: converting to {len(years)} yearly datasets...") + output_folder = str(data_folder / file_info["output_subfolder"]) + try: + converted = convert_dataset(str(file_info["local_path"]), output_folder, years) + except Exception as e: + console.print(f"{prefix}: [red]conversion failed: {e}[/red]") + timing["status"] = "conversion_failed" + timing["error"] = str(e) + return 0, 0, 1, timing + conv_time = time.time() - t0 + timing["conversion_seconds"] = round(conv_time, 2) + timing["datasets_converted"] = len(converted) + console.print( + f"{prefix}: converted {len(converted)} datasets ({fmt_duration(conv_time)})" + ) + + # Step 3: Upload + create DB records + if skip_upload: + datasets_skipped += len(converted) + timing["upload_seconds"] = 0 + timing["status"] = "upload_skipped" + console.print(f"{prefix}: [yellow]upload skipped[/yellow]") + else: + t0 = time.time() + console.print(f"{prefix}: uploading to Supabase...") + for _, pe_dataset in converted.items(): + existing = session.exec( + select(Dataset).where(Dataset.name == pe_dataset.name) + ).first() + + if existing: + datasets_skipped += 1 + continue + + object_name = 
f"{file_info['supabase_prefix']}/{pe_dataset.name}.h5" + + try: + upload_dataset_for_seeding(pe_dataset.filepath, object_name=object_name) + except Exception as e: + console.print( + f"{prefix}: [red]upload failed for {pe_dataset.name}: {e}[/red]" + ) + errors += 1 + continue + + db_dataset = Dataset( + name=pe_dataset.name, + description=pe_dataset.description, + filepath=object_name, + year=pe_dataset.year, + tax_benefit_model_id=us_model.id, + ) + session.add(db_dataset) + session.commit() + session.refresh(db_dataset) + datasets_created += 1 + + upload_time = time.time() - t0 + timing["upload_seconds"] = round(upload_time, 2) + console.print( + f"{prefix}: uploaded {datasets_created} datasets, " + f"{datasets_skipped} already existed ({fmt_duration(upload_time)})" + ) + + timing["datasets_created"] = datasets_created + timing["datasets_skipped"] = datasets_skipped + timing["errors"] = errors + timing["status"] = timing.get("status", "ok") + + return datasets_created, datasets_skipped, errors, timing + + +def process_state( + state_code: str, + years: list[int], + data_folder: Path, + skip_upload: bool, + state_only: bool, + session, + us_model, +) -> tuple[int, int, int, list[dict]]: + """Process one state: its state-level file and all district files. + + Returns (created, skipped, errors, file_timings). + Region-to-dataset wiring is handled by seed_regions.py, not here. 
+ """ + district_count = DISTRICT_COUNTS.get(state_code, 0) + + files_to_process = [] + + # State file + files_to_process.append( + { + "type": "state", + "code": state_code, + "gcs_path": f"{GCS_BUCKET}/states/{state_code}.h5", + "local_path": TMP_DIR / "states" / f"{state_code}.h5", + "output_subfolder": "states", + "supabase_prefix": f"states/{state_code}", + } + ) + + # District files + if not state_only: + for i in range(1, district_count + 1): + district_code = f"{state_code}-{i:02d}" + files_to_process.append( + { + "type": "district", + "code": district_code, + "gcs_path": f"{GCS_BUCKET}/districts/{district_code}.h5", + "local_path": TMP_DIR / "districts" / f"{district_code}.h5", + "output_subfolder": "districts", + "supabase_prefix": f"districts/{district_code}", + } + ) + + total_files = len(files_to_process) + total_created = 0 + total_skipped = 0 + total_errors = 0 + file_timings = [] + + for i, file_info in enumerate(files_to_process, 1): + created, skipped, errs, timing = process_file( + file_info=file_info, + years=years, + data_folder=data_folder, + skip_upload=skip_upload, + session=session, + us_model=us_model, + file_index=i, + total_files=total_files, + ) + total_created += created + total_skipped += skipped + total_errors += errs + file_timings.append(timing) + + return total_created, total_skipped, total_errors, file_timings + + +def main(): + parser = argparse.ArgumentParser( + description="Import state & district datasets from GCS to Supabase" + ) + parser.add_argument( + "states", + nargs="*", + help="State codes (e.g., CA NY TX). 
Uppercase 2-letter codes.",
+    )
+    parser.add_argument(
+        "--all",
+        action="store_true",
+        dest="all_states",
+        help="Process all 51 states + DC",
+    )
+    parser.add_argument(
+        "--state-only",
+        action="store_true",
+        help="Skip district processing, only do state-level datasets",
+    )
+    parser.add_argument(
+        "--years",
+        type=str,
+        default=None,
+        help="Comma-separated years (default: 2024-2035)",
+    )
+    parser.add_argument(
+        "--skip-upload",
+        action="store_true",
+        help="Convert locally without uploading to Supabase or creating DB records",
+    )
+    parser.add_argument(
+        "--data-folder",
+        type=str,
+        default=None,
+        help="Local directory for converted files (default: ./data)",
+    )
+    args = parser.parse_args()
+
+    # Determine which states to process
+    if args.all_states:
+        states = ALL_STATES
+    elif args.states:
+        states = [s.upper() for s in args.states]
+    else:
+        parser.error("Provide state codes or use --all")
+        return
+
+    # Validate state codes
+    invalid = [s for s in states if s not in DISTRICT_COUNTS]
+    if invalid:
+        console.print(f"[red]Invalid state codes: {', '.join(invalid)}[/red]")
+        sys.exit(1)
+
+    years = DEFAULT_YEARS
+    if args.years:
+        years = [int(y.strip()) for y in args.years.split(",")]
+
+    data_folder = (
+        Path(args.data_folder) if args.data_folder else Path(__file__).parent / "data"
+    )
+
+    total_districts = (
+        sum(DISTRICT_COUNTS[s] for s in states) if not args.state_only else 0
+    )
+    total_files = len(states) + total_districts
+    total_yearly = total_files * len(years)
+
+    console.print()
+    console.print("[bold green]State & District Dataset Import[/bold green]")
+    console.print(f" States: {len(states)} ({', '.join(states)})")
+    if not args.state_only:
+        console.print(f" Districts: {total_districts}")
+    console.print(f" Years: {years}")
+    console.print(f" Raw files to process: {total_files}")
+    console.print(f" Yearly datasets to produce: {total_yearly}")
+    if args.skip_upload:
+        console.print(" [yellow]Upload skipped
(--skip-upload)[/yellow]") + console.print() + + grand_created = 0 + grand_skipped = 0 + grand_errors = 0 + + session = None + us_model = None + + if not args.skip_upload: + session = get_session() + us_model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-us") + ).first() + if not us_model: + console.print( + "[red]Error: US model not found. Run seed_models.py first.[/red]" + ) + sys.exit(1) + + script_start = time.time() + timing_report = { + "started_at": datetime.now(timezone.utc).isoformat(), + "args": { + "states": states, + "years": years, + "state_only": args.state_only, + "skip_upload": args.skip_upload, + "data_folder": str(data_folder), + }, + "states": [], + } + + for state_idx, state_code in enumerate(states, 1): + district_count = DISTRICT_COUNTS[state_code] + file_count = 1 + (district_count if not args.state_only else 0) + console.print( + f"[bold]({state_idx}/{len(states)}) Processing {state_code} " + f"({file_count} files)[/bold]" + ) + + state_start = time.time() + + created, skipped, errs, file_timings = process_state( + state_code=state_code, + years=years, + data_folder=data_folder, + skip_upload=args.skip_upload, + state_only=args.state_only, + session=session, + us_model=us_model, + ) + + state_time = time.time() - state_start + console.print( + f"[bold]({state_idx}/{len(states)}) {state_code} complete: " + f"{created} created, {skipped} skipped" + f"{f', {errs} errors' if errs else ''} " + f"({fmt_duration(state_time)})[/bold]" + ) + console.print() + + timing_report["states"].append( + { + "state": state_code, + "total_seconds": round(state_time, 2), + "datasets_created": created, + "datasets_skipped": skipped, + "errors": errs, + "files": file_timings, + } + ) + + # Write timing file after each state so partial results are preserved + timing_path = data_folder / "import_timing.json" + timing_path.parent.mkdir(parents=True, exist_ok=True) + timing_path.write_text(json.dumps(timing_report, indent=2)) + 
+ grand_created += created + grand_skipped += skipped + grand_errors += errs + + if session: + session.close() + + total_time = time.time() - script_start + + # Write final timing report + timing_report["finished_at"] = datetime.now(timezone.utc).isoformat() + timing_report["total_seconds"] = round(total_time, 2) + timing_report["totals"] = { + "datasets_created": grand_created, + "datasets_skipped": grand_skipped, + "errors": grand_errors, + } + timing_path = data_folder / "import_timing.json" + timing_path.write_text(json.dumps(timing_report, indent=2)) + + console.print( + f"[bold green]Import complete ({fmt_duration(total_time)})[/bold green]" + ) + console.print(f" Datasets created: {grand_created}") + console.print(f" Datasets skipped (already exist): {grand_skipped}") + if grand_errors: + console.print(f" [red]Errors: {grand_errors}[/red]") + console.print(f" Timing report: {timing_path}") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 27eb310..d624a6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,11 @@ dependencies = [ "psycopg2-binary>=2.9.10", "supabase>=2.10.0", "storage3>=0.8.1", - "policyengine>=3.1.15", + # IMPORTANT: Before merging app-v2-migration into main, replace this git + # dependency with the production PyPI version of policyengine (e.g., "policyengine>=X.Y.Z"). + # The git ref is used here because the app-v2-migration branch contains fixes + # (US reform application, regions support) not yet released to PyPI. 
+ "policyengine @ git+https://github.com/PolicyEngine/policyengine.py.git@app-v2-migration", "policyengine-uk>=2.0.0", "policyengine-us>=1.0.0", "pydantic>=2.9.2", @@ -24,6 +28,7 @@ dependencies = [ "fastapi-mcp>=0.4.0", "modal>=0.68.0", "anthropic>=0.40.0", + "alembic>=1.13.0", ] [project.optional-dependencies] @@ -38,6 +43,9 @@ dev = [ requires = ["hatchling"] build-backend = "hatchling.build" +[tool.hatch.metadata] +allow-direct-references = true + [tool.hatch.build.targets.wheel] packages = ["src/policyengine_api"] diff --git a/scripts/init.py b/scripts/init.py index cf7a04a..f7a64eb 100644 --- a/scripts/init.py +++ b/scripts/init.py @@ -1,12 +1,19 @@ -"""Initialise Supabase: reset database, recreate tables, buckets, and permissions. +"""Initialise Supabase database with tables, buckets, and permissions. -This script performs a complete reset of the Supabase instance: -1. Drops and recreates the public schema (all tables) -2. Deletes and recreates the storage bucket -3. Creates all tables from SQLModel definitions -4. Applies RLS policies and storage permissions +This script can run in two modes: +1. Init mode (default): Creates tables via Alembic, applies RLS policies +2. Reset mode (--reset): Drops everything and recreates from scratch (DESTRUCTIVE) + +Usage: + uv run python scripts/init.py # Safe init (creates if not exists) + uv run python scripts/init.py --reset # Destructive reset (drops everything) + +For local development after `supabase start`, use init mode. +For production, use init mode to ensure tables and policies exist. +Reset mode should only be used when you need a completely fresh database. 
""" +import subprocess import sys from pathlib import Path @@ -14,16 +21,14 @@ from rich.console import Console from rich.panel import Panel -from sqlmodel import SQLModel, create_engine +from sqlmodel import create_engine -# Import all models to register them with SQLModel.metadata -from policyengine_api import models # noqa: F401 from policyengine_api.config.settings import settings from policyengine_api.services.storage import get_service_role_client console = Console() -MIGRATIONS_DIR = Path(__file__).parent.parent / "supabase" / "migrations" +PROJECT_ROOT = Path(__file__).parent.parent def reset_storage_bucket(): @@ -57,30 +62,61 @@ def reset_storage_bucket(): console.print(f"[yellow]⚠ Warning with storage bucket: {e}[/yellow]") +def ensure_storage_bucket(): + """Ensure storage bucket exists (non-destructive).""" + console.print("[bold blue]Ensuring storage bucket exists...") + + try: + supabase = get_service_role_client() + bucket_name = settings.storage_bucket + + # Try to get bucket info + try: + supabase.storage.get_bucket(bucket_name) + console.print(f"[green]✓[/green] Bucket '{bucket_name}' exists") + except Exception: + # Bucket doesn't exist, create it + supabase.storage.create_bucket(bucket_name, options={"public": True}) + console.print(f"[green]✓[/green] Created bucket '{bucket_name}'") + + except Exception as e: + console.print(f"[yellow]⚠ Warning with storage bucket: {e}[/yellow]") + + def reset_database(): - """Drop and recreate all tables.""" - console.print("[bold blue]Resetting database...") + """Drop and recreate the public schema (DESTRUCTIVE).""" + console.print("[bold red]Dropping database schema...") engine = create_engine(settings.database_url, echo=False) - # Drop and recreate public schema - console.print(" Dropping public schema...") with engine.begin() as conn: conn.exec_driver_sql("DROP SCHEMA public CASCADE") conn.exec_driver_sql("CREATE SCHEMA public") conn.exec_driver_sql("GRANT ALL ON SCHEMA public TO postgres") 
conn.exec_driver_sql("GRANT ALL ON SCHEMA public TO public") - # Create all tables from SQLModel - console.print(" Creating tables...") - SQLModel.metadata.create_all(engine) + console.print("[green]✓[/green] Schema dropped and recreated") + return engine - tables = list(SQLModel.metadata.tables.keys()) - console.print(f"[green]✓[/green] Created {len(tables)} tables:") - for table in sorted(tables): - console.print(f" {table}") - return engine +def run_alembic_migrations(): + """Run Alembic migrations to create/update tables.""" + console.print("[bold blue]Running Alembic migrations...") + + result = subprocess.run( + ["uv", "run", "alembic", "upgrade", "head"], + cwd=PROJECT_ROOT, + capture_output=True, + text=True, + ) + + if result.returncode != 0: + console.print("[red]✗ Alembic migration failed:[/red]") + console.print(result.stderr) + raise RuntimeError("Alembic migration failed") + + console.print("[green]✓[/green] Alembic migrations complete") + console.print(result.stdout) def apply_storage_policies(engine): @@ -158,6 +194,10 @@ def apply_rls_policies(engine): "parameter_values", "users", "household_jobs", + "households", + "user_household_associations", + "poverty", + "inequality", ] # Read-only tables (public can read, only service role can write) @@ -178,6 +218,7 @@ def apply_rls_policies(engine): "dynamics", "reports", "household_jobs", + "households", ] # Read-only results tables @@ -186,6 +227,8 @@ def apply_rls_policies(engine): "change_aggregates", "decile_impacts", "program_statistics", + "poverty", + "inequality", ] sql_parts = [] @@ -230,6 +273,13 @@ def apply_rls_policies(engine): FOR SELECT TO anon, authenticated USING (true); """) + # User-household associations need special handling + sql_parts.append(""" + DROP POLICY IF EXISTS "Users can manage own associations" ON user_household_associations; + CREATE POLICY "Users can manage own associations" ON user_household_associations + FOR ALL TO anon, authenticated USING (true) WITH CHECK (true); 
+ """) + sql = "\n".join(sql_parts) conn = engine.raw_connection() @@ -246,30 +296,53 @@ def apply_rls_policies(engine): def main(): - """Run full Supabase initialisation.""" - console.print( - Panel.fit( - "[bold red]⚠ WARNING: This will DELETE ALL DATA[/bold red]\n" - "This script resets the entire Supabase instance.", - title="Supabase init", + """Run Supabase initialisation.""" + reset_mode = "--reset" in sys.argv + + if reset_mode: + console.print( + Panel.fit( + "[bold red]⚠ WARNING: This will DELETE ALL DATA[/bold red]\n" + "This script will reset the entire Supabase instance.", + title="Supabase RESET", + ) ) - ) - # Confirm unless running non-interactively - if sys.stdin.isatty(): - response = console.input("\nType 'yes' to continue: ") - if response.lower() != "yes": - console.print("[yellow]Aborted[/yellow]") - return + # Confirm unless running non-interactively + if sys.stdin.isatty(): + response = console.input("\nType 'yes' to continue: ") + if response.lower() != "yes": + console.print("[yellow]Aborted[/yellow]") + return + + console.print() + + # Reset storage bucket + reset_storage_bucket() + console.print() + + # Drop database schema + engine = reset_database() + console.print() + else: + console.print( + Panel.fit( + "[bold blue]Initialising Supabase[/bold blue]\n" + "This will create tables if they don't exist (safe/idempotent).\n" + "Use [cyan]--reset[/cyan] flag to drop and recreate everything.", + title="Supabase init", + ) + ) + console.print() - console.print() + # Ensure storage bucket exists + ensure_storage_bucket() + console.print() - # Reset storage bucket - reset_storage_bucket() - console.print() + engine = create_engine(settings.database_url, echo=False) - # Reset database and create tables - engine = reset_database() + # Run Alembic migrations to create/update tables + run_alembic_migrations() console.print() # Apply storage policies diff --git a/scripts/seed.py b/scripts/seed.py index f3fbfa8..c34ee02 100644 --- a/scripts/seed.py 
+++ b/scripts/seed.py @@ -1,649 +1,415 @@ -"""Seed database with UK and US models, variables, parameters, datasets.""" +"""Seed PolicyEngine database with models, datasets, policies, and regions. + +This is the main orchestrator script that calls individual seed scripts +based on the selected preset. + +Presets: + full - Everything (default) + lite - Both countries, 2026 datasets only, skip state params, core regions + minimal - Both countries, 2026 datasets only, skip state params, no policies/regions + uk-lite - UK only, 2026 datasets, skip state params + uk-minimal - UK only, 2026 datasets, skip state params, no policies/regions + us-lite - US only, 2026 datasets, skip state params, core regions only + us-minimal - US only, 2026 datasets, skip state params, no policies/regions + testing - US only, ~100 curated variables/parameters, fast local testing + +Usage: + python scripts/seed.py # Full seed (default) + python scripts/seed.py --preset=lite # Lite mode for both countries + python scripts/seed.py --preset=us-lite # US only, lite mode + python scripts/seed.py --preset=minimal # Minimal seed (no policies/regions) + python scripts/seed.py --preset=testing # Fast testing preset (~100 vars/params) +""" import argparse -import json -import logging -import math -import sys -import warnings -from datetime import datetime, timezone -from pathlib import Path -from uuid import uuid4 - -import logfire - -# Disable all SQLAlchemy and database logging BEFORE any imports -logging.basicConfig(level=logging.ERROR) -logging.getLogger("sqlalchemy").setLevel(logging.ERROR) -warnings.filterwarnings("ignore") - -# Add src to path -sys.path.insert(0, str(Path(__file__).parent.parent / "src")) - -from policyengine.tax_benefit_models.uk import uk_latest # noqa: E402 -from policyengine.tax_benefit_models.uk.datasets import ( # noqa: E402 - ensure_datasets as ensure_uk_datasets, -) -from policyengine.tax_benefit_models.us import us_latest # noqa: E402 -from 
policyengine.tax_benefit_models.us.datasets import ( # noqa: E402 - ensure_datasets as ensure_us_datasets, -) -from rich.console import Console # noqa: E402 -from rich.progress import Progress, SpinnerColumn, TextColumn # noqa: E402 -from sqlmodel import Session, create_engine, select # noqa: E402 - -from policyengine_api.config.settings import settings # noqa: E402 -from policyengine_api.models import ( # noqa: E402 - Dataset, - Parameter, - ParameterValue, - Policy, - TaxBenefitModel, - TaxBenefitModelVersion, -) -from policyengine_api.services.storage import ( # noqa: E402 - upload_dataset_for_seeding, -) - -# Configure logfire -if settings.logfire_token: - logfire.configure( - token=settings.logfire_token, - environment=settings.logfire_environment, - console=False, - ) - -console = Console() - - -def get_quiet_session(): - """Get database session with logging disabled.""" - engine = create_engine(settings.database_url, echo=False) - with Session(engine) as session: - yield session - - -def bulk_insert(session, table: str, columns: list[str], rows: list[dict]): - """Fast bulk insert using PostgreSQL COPY via StringIO.""" - if not rows: - return - - import io - - # Get raw psycopg2 connection - need to use the connection from session - # but not commit separately to avoid transaction issues - connection = session.connection() - raw_conn = connection.connection.dbapi_connection - cursor = raw_conn.cursor() - - # Build CSV-like data in memory - output = io.StringIO() - for row in rows: - values = [] - for col in columns: - val = row[col] - if val is None: - values.append("\\N") - elif isinstance(val, str): - # Escape special characters for COPY - val = ( - val.replace("\\", "\\\\").replace("\t", "\\t").replace("\n", "\\n") - ) - values.append(val) - else: - values.append(str(val)) - output.write("\t".join(values) + "\n") - - output.seek(0) - - # COPY is the fastest way to bulk load PostgreSQL - cursor.copy_from(output, table, columns=columns, null="\\N") - # Let 
SQLAlchemy handle the commit via session - session.commit() - - -def seed_model(model_version, session, lite: bool = False) -> TaxBenefitModelVersion: - """Seed a tax-benefit model with its variables and parameters.""" - - with logfire.span( - "seed_model", - model=model_version.model.id, - version=model_version.version, - ): - # Create or get the model - console.print(f"[bold blue]Seeding {model_version.model.id}...") - - existing_model = session.exec( - select(TaxBenefitModel).where( - TaxBenefitModel.name == model_version.model.id - ) - ).first() - - if existing_model: - db_model = existing_model - console.print(f" Using existing model: {db_model.id}") - else: - db_model = TaxBenefitModel( - name=model_version.model.id, - description=model_version.model.description, - ) - session.add(db_model) - session.commit() - session.refresh(db_model) - console.print(f" Created model: {db_model.id}") - - # Create model version - existing_version = session.exec( - select(TaxBenefitModelVersion).where( - TaxBenefitModelVersion.model_id == db_model.id, - TaxBenefitModelVersion.version == model_version.version, - ) - ).first() - - if existing_version: - console.print( - f" Model version {model_version.version} already exists, skipping" - ) - return existing_version - - db_version = TaxBenefitModelVersion( - model_id=db_model.id, - version=model_version.version, - description=f"Version {model_version.version}", - ) - session.add(db_version) - session.commit() - session.refresh(db_version) - console.print(f" Created version: {db_version.version}") - - # Add variables - with logfire.span("add_variables", count=len(model_version.variables)): - var_rows = [] - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - console=console, - ) as progress: - task = progress.add_task( - f"Preparing {len(model_version.variables)} variables", - total=len(model_version.variables), - ) - for var in model_version.variables: - var_rows.append( - { - "id": 
uuid4(), - "name": var.name, - "entity": var.entity, - "description": var.description or "", - "data_type": var.data_type.__name__ - if hasattr(var.data_type, "__name__") - else str(var.data_type), - "possible_values": None, - "tax_benefit_model_version_id": db_version.id, - "created_at": datetime.now(timezone.utc), - } - ) - progress.advance(task) - - console.print(f" Inserting {len(var_rows)} variables...") - bulk_insert( +import time +from dataclasses import dataclass + +# Import seed functions from subscripts +from seed_datasets import seed_uk_datasets, seed_us_datasets +from seed_models import seed_uk_model, seed_us_model +from seed_policies import seed_uk_policy, seed_us_policy +from seed_regions import seed_uk_regions, seed_us_regions +from seed_utils import console, get_session + + +@dataclass +class SeedConfig: + """Configuration for database seeding.""" + + # Countries + seed_uk: bool = True + seed_us: bool = True + + # Models + skip_state_params: bool = False + variable_whitelist: set[str] | None = None # None = all variables + parameter_prefixes: set[str] | None = None # None = all parameters + + # Datasets + dataset_year: int | None = None # None = all years + + # Policies + seed_policies: bool = True + + # Regions + seed_regions: bool = True + skip_places: bool = False + skip_districts: bool = False + + +# Curated variable names for the testing preset (~100 US variables) +TESTING_VARIABLES: set[str] = { + # Person inputs + "age", + "employment_income", + "self_employment_income", + "pension_income", + "social_security", + "unemployment_compensation", + "dividend_income", + "interest_income", + "capital_gains", + "rental_income", + "alimony_income", + "child_support_received", + "is_tax_unit_dependent", + "is_disabled", + "is_blind", + "is_pregnant", + "is_ssi_aged", + "is_ssi_disabled", + "marital_status", + "tax_unit_spouse", + "is_tax_unit_head", + "military_basic_pay", + "farm_income", + "partnership_s_corp_income", + "taxable_pension_income", + # 
Household/geography + "state_name", + "state_code", + "state_fips", + "household_size", + "county", + "in_nyc", + "is_on_tribal_land", + "snap_region", + "medicaid_rating_area", + "fips", + # Tax unit + "adjusted_gross_income", + "taxable_income", + "standard_deduction", + "itemized_deductions", + "filing_status", + "tax_unit_size", + "tax_unit_dependents", + "tax_unit_is_joint", + "income_tax", + "income_tax_before_credits", + "income_tax_refundable_credits", + "income_tax_non_refundable_credits", + "earned_income", + "agi", + "tax_unit_earned_income", + # Federal tax outputs + "federal_income_tax", + "federal_income_tax_before_credits", + "payroll_tax", + "employee_payroll_tax", + "self_employment_tax", + "earned_income_tax_credit", + "child_tax_credit", + "additional_child_tax_credit", + "child_and_dependent_care_credit", + "american_opportunity_credit", + "premium_tax_credit", + "recovery_rebate_credit", + "ctc_qualifying_children", + "eitc_eligible", + "amt_income", + # Benefits + "snap", + "ssi", + "tanf", + "wic", + "school_meal_subsidy", + "free_school_meals", + "reduced_price_school_meals", + "medicaid", + "chip", + "housing_subsidy", + "section_8_income", + "lifeline", + "acp", + "pell_grant", + "ssi_amount_if_eligible", + # Aggregate/summary + "household_net_income", + "household_income", + "household_benefits", + "household_tax", + "household_market_income", + "net_income", + "market_income", + "spm_unit_net_income", + "spm_unit_spm_threshold", + "in_poverty", + "in_deep_poverty", + "poverty_gap", + "deep_poverty_gap", + "disposable_income", + "marginal_tax_rate", +} + +# Parameter name prefixes for the testing preset (~100 US parameters) +TESTING_PARAMETER_PREFIXES: set[str] = { + "gov.irs.income.bracket", + "gov.irs.deductions.standard", + "gov.irs.credits.ctc", + "gov.irs.credits.eitc", + "gov.usda.snap", + "gov.ssa.ssi", + "gov.ssa.social_security", + "gov.irs.payroll", + "gov.irs.fica", + "gov.hhs.tanf", + "gov.irs.income.amt", + 
"gov.irs.capital_gains", + "gov.irs.credits.premium_tax_credit", + "gov.irs.income.exemption", + "gov.hhs.medicaid", + "gov.contrib.ubi_center.basic_income", +} + + +# Preset configurations +PRESETS: dict[str, SeedConfig] = { + "full": SeedConfig( + seed_uk=True, + seed_us=True, + skip_state_params=False, + dataset_year=None, + seed_policies=True, + seed_regions=True, + skip_places=False, + skip_districts=False, + ), + "lite": SeedConfig( + seed_uk=True, + seed_us=True, + skip_state_params=True, + dataset_year=2026, + seed_policies=True, + seed_regions=True, + skip_places=True, + skip_districts=True, + ), + "minimal": SeedConfig( + seed_uk=True, + seed_us=True, + skip_state_params=True, + dataset_year=2026, + seed_policies=False, + seed_regions=False, + ), + "uk-lite": SeedConfig( + seed_uk=True, + seed_us=False, + skip_state_params=True, + dataset_year=2026, + seed_policies=True, + seed_regions=True, + ), + "uk-minimal": SeedConfig( + seed_uk=True, + seed_us=False, + skip_state_params=True, + dataset_year=2026, + seed_policies=False, + seed_regions=False, + ), + "us-lite": SeedConfig( + seed_uk=False, + seed_us=True, + skip_state_params=True, + dataset_year=2026, + seed_policies=True, + seed_regions=True, + skip_places=True, + skip_districts=True, + ), + "us-minimal": SeedConfig( + seed_uk=False, + seed_us=True, + skip_state_params=True, + dataset_year=2026, + seed_policies=False, + seed_regions=False, + ), + "testing": SeedConfig( + seed_uk=False, + seed_us=True, + skip_state_params=True, + dataset_year=2026, + seed_policies=True, + seed_regions=True, + skip_places=True, + skip_districts=True, + variable_whitelist=TESTING_VARIABLES, + parameter_prefixes=TESTING_PARAMETER_PREFIXES, + ), +} + + +def run_seed(config: SeedConfig): + """Run database seeding with the given configuration.""" + start = time.time() + + with get_session() as session: + # Step 1: Seed models + console.print("[bold blue]Step 1: Seeding models...[/bold blue]\n") + + if config.seed_uk: + 
seed_uk_model( session, - "variables", - [ - "id", - "name", - "entity", - "description", - "data_type", - "possible_values", - "tax_benefit_model_version_id", - "created_at", - ], - var_rows, - ) - - console.print( - f" [green]✓[/green] Added {len(model_version.variables)} variables" + skip_state_params=config.skip_state_params, + variable_whitelist=config.variable_whitelist, + parameter_prefixes=config.parameter_prefixes, ) - # Add parameters (only user-facing ones: those with labels) - # Deduplicate by name - keep first occurrence - # In lite mode, exclude US state parameters (gov.states.*) - seen_names = set() - parameters_to_add = [] - skipped_state_params = 0 - for p in model_version.parameters: - if p.label is None or p.name in seen_names: - continue - # In lite mode, skip state-level parameters for faster seeding - if lite and p.name.startswith("gov.states."): - skipped_state_params += 1 - continue - parameters_to_add.append(p) - seen_names.add(p.name) - - filter_msg = f" Filtered to {len(parameters_to_add)} user-facing parameters" - filter_msg += f" (from {len(model_version.parameters)} total, deduplicated by name)" - if lite and skipped_state_params > 0: - filter_msg += f", skipped {skipped_state_params} state params (lite mode)" - console.print(filter_msg) - - with logfire.span("add_parameters", count=len(parameters_to_add)): - # Build list of parameter dicts for bulk insert - param_rows = [] - param_names = [] # Track (pe_id, name, generated_uuid) - - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - console=console, - ) as progress: - task = progress.add_task( - f"Preparing {len(parameters_to_add)} parameters", - total=len(parameters_to_add), - ) - for param in parameters_to_add: - param_uuid = uuid4() - param_rows.append( - { - "id": param_uuid, - "name": param.name, - "label": param.label if hasattr(param, "label") else None, - "description": param.description or "", - "data_type": param.data_type.__name__ 
- if hasattr(param.data_type, "__name__") - else str(param.data_type), - "unit": param.unit, - "tax_benefit_model_version_id": db_version.id, - "created_at": datetime.now(timezone.utc), - } - ) - param_names.append((param.id, param.name, param_uuid)) - progress.advance(task) - - console.print(f" Inserting {len(param_rows)} parameters...") - bulk_insert( + if config.seed_us: + seed_us_model( session, - "parameters", - [ - "id", - "name", - "label", - "description", - "data_type", - "unit", - "tax_benefit_model_version_id", - "created_at", - ], - param_rows, + skip_state_params=config.skip_state_params, + variable_whitelist=config.variable_whitelist, + parameter_prefixes=config.parameter_prefixes, ) - # Build param_id_map from pre-generated UUIDs - param_id_map = {pe_id: db_uuid for pe_id, name, db_uuid in param_names} + # Step 2: Seed datasets + console.print("[bold blue]Step 2: Seeding datasets...[/bold blue]\n") + if config.seed_uk: + console.print("[bold]UK Datasets[/bold]") + uk_created, uk_skipped = seed_uk_datasets(session, year=config.dataset_year) console.print( - f" [green]✓[/green] Added {len(parameters_to_add)} parameters" + f"[green]✓[/green] UK: {uk_created} created, {uk_skipped} skipped\n" ) - # Add parameter values - # Filter to only include values for parameters we added - parameter_values_to_add = [ - pv - for pv in model_version.parameter_values - if pv.parameter.id in param_id_map - ] - console.print(f" Found {len(parameter_values_to_add)} parameter values to add") - - with logfire.span("add_parameter_values", count=len(parameter_values_to_add)): - pv_rows = [] - skipped = 0 - - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - console=console, - ) as progress: - task = progress.add_task( - f"Preparing {len(parameter_values_to_add)} parameter values", - total=len(parameter_values_to_add), - ) - for pv in parameter_values_to_add: - # Handle Infinity values - skip them as they can't be stored in JSON - if 
isinstance(pv.value, float) and ( - math.isinf(pv.value) or math.isnan(pv.value) - ): - skipped += 1 - progress.advance(task) - continue - - # Source data has dates swapped (start > end), fix ordering - # Only swap if both dates are set, otherwise keep original - if pv.start_date and pv.end_date: - start = pv.end_date # Swap: source end is our start - end = pv.start_date # Swap: source start is our end - else: - start = pv.start_date - end = pv.end_date - pv_rows.append( - { - "id": uuid4(), - "parameter_id": param_id_map[pv.parameter.id], - "value_json": json.dumps(pv.value), - "start_date": start, - "end_date": end, - "policy_id": None, - "dynamic_id": None, - "created_at": datetime.now(timezone.utc), - } - ) - progress.advance(task) - - console.print(f" Inserting {len(pv_rows)} parameter values...") - bulk_insert( - session, - "parameter_values", - [ - "id", - "parameter_id", - "value_json", - "start_date", - "end_date", - "policy_id", - "dynamic_id", - "created_at", - ], - pv_rows, - ) - - console.print( - f" [green]✓[/green] Added {len(pv_rows)} parameter values" - + (f" (skipped {skipped} invalid)" if skipped else "") - ) - - return db_version - - -def seed_datasets(session, lite: bool = False): - """Seed datasets and upload to S3.""" - with logfire.span("seed_datasets"): - mode_str = " (lite mode - 2026 only)" if lite else "" - console.print(f"[bold blue]Seeding datasets{mode_str}...") - - # Get UK and US models - uk_model = session.exec( - select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-uk") - ).first() - us_model = session.exec( - select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-us") - ).first() - - if not uk_model or not us_model: + if config.seed_us: + console.print("[bold]US Datasets[/bold]") + us_created, us_skipped = seed_us_datasets(session, year=config.dataset_year) console.print( - "[red]Error: UK or US model not found. 
Run seed_model first.[/red]" + f"[green]✓[/green] US: {us_created} created, {us_skipped} skipped\n" ) - return - - # UK datasets - console.print(" Creating UK datasets...") - data_folder = str(Path(__file__).parent.parent / "data") - uk_datasets = ensure_uk_datasets(data_folder=data_folder) - - # In lite mode, only upload FRS 2026 - if lite: - uk_datasets = { - k: v for k, v in uk_datasets.items() if v.year == 2026 and "frs" in k - } - console.print(f" Lite mode: filtered to {len(uk_datasets)} dataset(s)") - - uk_created = 0 - uk_skipped = 0 - - with logfire.span("seed_uk_datasets", count=len(uk_datasets)): - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - console=console, - ) as progress: - task = progress.add_task("UK datasets", total=len(uk_datasets)) - for _, pe_dataset in uk_datasets.items(): - progress.update(task, description=f"UK: {pe_dataset.name}") - - # Check if dataset already exists - existing = session.exec( - select(Dataset).where(Dataset.name == pe_dataset.name) - ).first() - - if existing: - uk_skipped += 1 - progress.advance(task) - continue - - # Upload to S3 - object_name = upload_dataset_for_seeding(pe_dataset.filepath) - - # Create database record - db_dataset = Dataset( - name=pe_dataset.name, - description=pe_dataset.description, - filepath=object_name, - year=pe_dataset.year, - tax_benefit_model_id=uk_model.id, - ) - session.add(db_dataset) - session.commit() - uk_created += 1 - progress.advance(task) - - console.print( - f" [green]✓[/green] UK: {uk_created} created, {uk_skipped} skipped" - ) - # US datasets - console.print(" Creating US datasets...") - us_datasets = ensure_us_datasets(data_folder=data_folder) - - # In lite mode, only upload CPS 2026 - if lite: - us_datasets = { - k: v for k, v in us_datasets.items() if v.year == 2026 and "cps" in k - } - console.print(f" Lite mode: filtered to {len(us_datasets)} dataset(s)") - - us_created = 0 - us_skipped = 0 - - with 
logfire.span("seed_us_datasets", count=len(us_datasets)): - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - console=console, - ) as progress: - task = progress.add_task("US datasets", total=len(us_datasets)) - for _, pe_dataset in us_datasets.items(): - progress.update(task, description=f"US: {pe_dataset.name}") - - # Check if dataset already exists - existing = session.exec( - select(Dataset).where(Dataset.name == pe_dataset.name) - ).first() - - if existing: - us_skipped += 1 - progress.advance(task) - continue - - # Upload to S3 - object_name = upload_dataset_for_seeding(pe_dataset.filepath) - - # Create database record - db_dataset = Dataset( - name=pe_dataset.name, - description=pe_dataset.description, - filepath=object_name, - year=pe_dataset.year, - tax_benefit_model_id=us_model.id, - ) - session.add(db_dataset) - session.commit() - us_created += 1 - progress.advance(task) - - console.print( - f" [green]✓[/green] US: {us_created} created, {us_skipped} skipped" - ) - console.print( - f"[green]✓[/green] Seeded {uk_created + us_created} datasets total\n" - ) + # Step 3: Seed policies + if config.seed_policies: + console.print("[bold blue]Step 3: Seeding policies...[/bold blue]\n") + if config.seed_uk: + seed_uk_policy(session) -def seed_example_policies(session): - """Seed example policy reforms for UK and US.""" - with logfire.span("seed_example_policies"): - console.print("[bold blue]Seeding example policies...") + if config.seed_us: + seed_us_policy(session) - # Get model versions - uk_model = session.exec( - select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-uk") - ).first() - us_model = session.exec( - select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-us") - ).first() + console.print() - if not uk_model or not us_model: - console.print( - "[red]Error: UK or US model not found. 
Run seed_model first.[/red]" - ) - return - - uk_version = session.exec( - select(TaxBenefitModelVersion) - .where(TaxBenefitModelVersion.model_id == uk_model.id) - .order_by(TaxBenefitModelVersion.created_at.desc()) - ).first() - - us_version = session.exec( - select(TaxBenefitModelVersion) - .where(TaxBenefitModelVersion.model_id == us_model.id) - .order_by(TaxBenefitModelVersion.created_at.desc()) - ).first() - - # UK example policy: raise basic rate to 22p - uk_policy_name = "UK basic rate 22p" - existing_uk_policy = session.exec( - select(Policy).where(Policy.name == uk_policy_name) - ).first() - - if existing_uk_policy: - console.print(f" Policy '{uk_policy_name}' already exists, skipping") - else: - # Find the basic rate parameter - uk_basic_rate_param = session.exec( - select(Parameter).where( - Parameter.name == "gov.hmrc.income_tax.rates.uk[0].rate", - Parameter.tax_benefit_model_version_id == uk_version.id, - ) - ).first() + # Step 4: Seed regions + if config.seed_regions: + console.print("[bold blue]Step 4: Seeding regions...[/bold blue]\n") - if uk_basic_rate_param: - uk_policy = Policy( - name=uk_policy_name, - description="Raise the UK income tax basic rate from 20p to 22p", - ) - session.add(uk_policy) - session.commit() - session.refresh(uk_policy) - - # Add parameter value (22% = 0.22) - uk_param_value = ParameterValue( - parameter_id=uk_basic_rate_param.id, - value_json={"value": 0.22}, - start_date=datetime(2024, 1, 1, tzinfo=timezone.utc), - end_date=None, - policy_id=uk_policy.id, + if config.seed_us: + console.print("[bold]US Regions[/bold]") + us_created, us_skipped = seed_us_regions( + session, + skip_places=config.skip_places, + skip_districts=config.skip_districts, ) - session.add(uk_param_value) - session.commit() - console.print(f" [green]✓[/green] Created UK policy: {uk_policy_name}") - else: console.print( - " [yellow]Warning: UK basic rate parameter not found[/yellow]" + f"[green]✓[/green] US: {us_created} created, {us_skipped} 
skipped\n" ) - # US example policy: raise first bracket rate to 12% - us_policy_name = "US 12% lowest bracket" - existing_us_policy = session.exec( - select(Policy).where(Policy.name == us_policy_name) - ).first() - - if existing_us_policy: - console.print(f" Policy '{us_policy_name}' already exists, skipping") - else: - # Find the first bracket rate parameter - us_first_bracket_param = session.exec( - select(Parameter).where( - Parameter.name == "gov.irs.income.bracket.rates.1", - Parameter.tax_benefit_model_version_id == us_version.id, - ) - ).first() - - if us_first_bracket_param: - us_policy = Policy( - name=us_policy_name, - description="Raise US federal income tax lowest bracket to 12%", - ) - session.add(us_policy) - session.commit() - session.refresh(us_policy) - - # Add parameter value (12% = 0.12) - us_param_value = ParameterValue( - parameter_id=us_first_bracket_param.id, - value_json={"value": 0.12}, - start_date=datetime(2024, 1, 1, tzinfo=timezone.utc), - end_date=None, - policy_id=us_policy.id, - ) - session.add(us_param_value) - session.commit() - console.print(f" [green]✓[/green] Created US policy: {us_policy_name}") - else: + if config.seed_uk: + console.print("[bold]UK Regions[/bold]") + uk_created, uk_skipped = seed_uk_regions(session) console.print( - " [yellow]Warning: US first bracket parameter not found[/yellow]" + f"[green]✓[/green] UK: {uk_created} created, {uk_skipped} skipped\n" ) - console.print("[green]✓[/green] Example policies seeded\n") + elapsed = time.time() - start + console.print("[bold green]✓ Database seeding complete![/bold green]") + console.print(f"[bold]Total time: {elapsed:.1f}s[/bold]") def main(): - """Main seed function.""" - parser = argparse.ArgumentParser(description="Seed PolicyEngine database") + parser = argparse.ArgumentParser( + description="Seed PolicyEngine database", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Presets: + full Everything (default) + lite Both countries, 2026 datasets 
only, skip state params, core regions + minimal Both countries, 2026 datasets only, skip state params, no policies/regions + uk-lite UK only, 2026 datasets, skip state params + uk-minimal UK only, 2026 datasets, skip state params, no policies/regions + us-lite US only, 2026 datasets, skip state params, core regions only + us-minimal US only, 2026 datasets, skip state params, no policies/regions + testing US only, ~100 curated variables/parameters, fast local testing +""", + ) parser.add_argument( - "--lite", - action="store_true", - help="Lite mode: skip US state parameters, only seed FRS 2026 and CPS 2026 datasets", + "--preset", + choices=list(PRESETS.keys()), + default="full", + help="Seeding preset (default: full)", ) args = parser.parse_args() - with logfire.span("database_seeding"): - mode_str = " (lite mode)" if args.lite else "" - console.print(f"[bold green]PolicyEngine database seeding{mode_str}[/bold green]\n") - - with next(get_quiet_session()) as session: - # Seed UK model - uk_version = seed_model(uk_latest, session, lite=args.lite) - console.print(f"[green]✓[/green] UK model seeded: {uk_version.id}\n") + config = PRESETS[args.preset] - # Seed US model - us_version = seed_model(us_latest, session, lite=args.lite) - console.print(f"[green]✓[/green] US model seeded: {us_version.id}\n") + # Build description of what we're doing + countries = [] + if config.seed_uk: + countries.append("UK") + if config.seed_us: + countries.append("US") + country_str = " + ".join(countries) - # Seed datasets - seed_datasets(session, lite=args.lite) + year_str = f", {config.dataset_year} only" if config.dataset_year else "" + state_str = ", skip state params" if config.skip_state_params else "" - # Seed example policies - seed_example_policies(session) + console.print( + f"[bold green]PolicyEngine database seeding[/bold green] " + f"[dim](preset: {args.preset})[/dim]\n" + ) + console.print(f" Countries: {country_str}") + console.print( + f" Datasets: {'all years' if not 
config.dataset_year else config.dataset_year}" + ) + if config.skip_state_params: + console.print(" State params: skipped") + if config.variable_whitelist is not None: + console.print(f" Variables: {len(config.variable_whitelist)} whitelisted") + if config.parameter_prefixes is not None: + console.print(f" Parameter prefixes: {len(config.parameter_prefixes)} active") + console.print(f" Policies: {'yes' if config.seed_policies else 'no'}") + if config.seed_regions: + region_details = [] + if config.skip_places: + region_details.append("no places") + if config.skip_districts: + region_details.append("no districts") + region_str = ( + f"yes ({', '.join(region_details)})" if region_details else "yes (all)" + ) + console.print(f" Regions: {region_str}") + else: + console.print(" Regions: no") + console.print() - console.print("\n[bold green]✓ Database seeding complete![/bold green]") + run_seed(config) if __name__ == "__main__": diff --git a/scripts/seed_datasets.py b/scripts/seed_datasets.py new file mode 100644 index 0000000..2684968 --- /dev/null +++ b/scripts/seed_datasets.py @@ -0,0 +1,226 @@ +"""Seed datasets and upload to S3. + +This script downloads datasets from policyengine.py, uploads them to S3, +and creates database records. 
+ +Usage: + python scripts/seed_datasets.py # Seed UK and US datasets + python scripts/seed_datasets.py --us-only # Seed only US datasets + python scripts/seed_datasets.py --uk-only # Seed only UK datasets + python scripts/seed_datasets.py --year=2026 # Seed only 2026 datasets +""" + +import argparse +from pathlib import Path + +import logfire +from rich.progress import Progress, SpinnerColumn, TextColumn +from seed_utils import console, get_session +from sqlmodel import Session, select + +# Import after seed_utils sets up path +from policyengine_api.models import Dataset, TaxBenefitModel # noqa: E402 +from policyengine_api.services.storage import upload_dataset_for_seeding # noqa: E402 + + +def seed_uk_datasets(session: Session, year: int | None = None) -> tuple[int, int]: + """Seed UK datasets. + + Args: + session: Database session + year: If specified, only seed datasets for this year + + Returns: + Tuple of (created_count, skipped_count) + """ + from policyengine.tax_benefit_models.uk.datasets import ( + ensure_datasets as ensure_uk_datasets, + ) + + uk_model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-uk") + ).first() + + if not uk_model: + console.print("[red]Error: UK model not found. 
Run seed_models.py first.[/red]") + return 0, 0 + + data_folder = str(Path(__file__).parent.parent / "data") + uk_datasets = ensure_uk_datasets(data_folder=data_folder) + + # Filter by year if specified + if year: + uk_datasets = { + k: v for k, v in uk_datasets.items() if v.year == year and "frs" in k + } + console.print(f" Filtered to {len(uk_datasets)} dataset(s) for year {year}") + + created = 0 + skipped = 0 + + with logfire.span("seed_uk_datasets", count=len(uk_datasets)): + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("UK datasets", total=len(uk_datasets)) + for _, pe_dataset in uk_datasets.items(): + progress.update(task, description=f"UK: {pe_dataset.name}") + + # Check if dataset already exists + existing = session.exec( + select(Dataset).where(Dataset.name == pe_dataset.name) + ).first() + + if existing: + skipped += 1 + progress.advance(task) + continue + + # Upload to S3 + object_name = upload_dataset_for_seeding(pe_dataset.filepath) + + # Create database record + db_dataset = Dataset( + name=pe_dataset.name, + description=pe_dataset.description, + filepath=object_name, + year=pe_dataset.year, + tax_benefit_model_id=uk_model.id, + ) + session.add(db_dataset) + session.commit() + created += 1 + progress.advance(task) + + return created, skipped + + +def seed_us_datasets(session: Session, year: int | None = None) -> tuple[int, int]: + """Seed US datasets. + + Args: + session: Database session + year: If specified, only seed datasets for this year + + Returns: + Tuple of (created_count, skipped_count) + """ + from policyengine.tax_benefit_models.us.datasets import ( + ensure_datasets as ensure_us_datasets, + ) + + us_model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-us") + ).first() + + if not us_model: + console.print("[red]Error: US model not found. 
Run seed_models.py first.[/red]") + return 0, 0 + + data_folder = str(Path(__file__).parent.parent / "data") + us_datasets = ensure_us_datasets(data_folder=data_folder) + + # Filter by year if specified + if year: + us_datasets = { + k: v for k, v in us_datasets.items() if v.year == year and "cps" in k + } + console.print(f" Filtered to {len(us_datasets)} dataset(s) for year {year}") + + created = 0 + skipped = 0 + + with logfire.span("seed_us_datasets", count=len(us_datasets)): + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("US datasets", total=len(us_datasets)) + for _, pe_dataset in us_datasets.items(): + progress.update(task, description=f"US: {pe_dataset.name}") + + # Check if dataset already exists + existing = session.exec( + select(Dataset).where(Dataset.name == pe_dataset.name) + ).first() + + if existing: + skipped += 1 + progress.advance(task) + continue + + # Upload to S3 + object_name = upload_dataset_for_seeding(pe_dataset.filepath) + + # Create database record + db_dataset = Dataset( + name=pe_dataset.name, + description=pe_dataset.description, + filepath=object_name, + year=pe_dataset.year, + tax_benefit_model_id=us_model.id, + ) + session.add(db_dataset) + session.commit() + created += 1 + progress.advance(task) + + return created, skipped + + +def main(): + parser = argparse.ArgumentParser(description="Seed datasets") + parser.add_argument( + "--us-only", + action="store_true", + help="Only seed US datasets", + ) + parser.add_argument( + "--uk-only", + action="store_true", + help="Only seed UK datasets", + ) + parser.add_argument( + "--year", + type=int, + default=None, + help="Only seed datasets for this year (e.g., 2026)", + ) + args = parser.parse_args() + + year_str = f" (year {args.year})" if args.year else "" + console.print(f"[bold green]Seeding datasets{year_str}...[/bold green]\n") + + total_created = 0 + total_skipped = 0 + + 
with get_session() as session: + if not args.us_only: + console.print("[bold]UK Datasets[/bold]") + uk_created, uk_skipped = seed_uk_datasets(session, year=args.year) + total_created += uk_created + total_skipped += uk_skipped + console.print( + f"[green]✓[/green] UK: {uk_created} created, {uk_skipped} skipped\n" + ) + + if not args.uk_only: + console.print("[bold]US Datasets[/bold]") + us_created, us_skipped = seed_us_datasets(session, year=args.year) + total_created += us_created + total_skipped += us_skipped + console.print( + f"[green]✓[/green] US: {us_created} created, {us_skipped} skipped\n" + ) + + console.print( + f"[bold green]✓ Dataset seeding complete! " + f"{total_created} created, {total_skipped} skipped[/bold green]" + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/seed_models.py b/scripts/seed_models.py new file mode 100644 index 0000000..9313faa --- /dev/null +++ b/scripts/seed_models.py @@ -0,0 +1,414 @@ +"""Seed tax-benefit models with variables and parameters. + +This script seeds TaxBenefitModel, TaxBenefitModelVersion, Variables, +Parameters, and ParameterValues from policyengine.py. + +Usage: + python scripts/seed_models.py # Seed UK and US models + python scripts/seed_models.py --us-only # Seed only US model + python scripts/seed_models.py --uk-only # Seed only UK model + python scripts/seed_models.py --skip-state-params # Skip US state parameters +""" + +import argparse +import json +import math +from datetime import datetime, timezone +from uuid import uuid4 + +import logfire +from rich.progress import Progress, SpinnerColumn, TextColumn +from seed_utils import bulk_insert, console, get_session +from sqlmodel import Session, select + +# Import after seed_utils sets up path +from policyengine_api.models import ( # noqa: E402 + TaxBenefitModel, + TaxBenefitModelVersion, +) + + +def _get_variable_type_info(var) -> tuple[str, str | None]: + """Extract data_type and possible_values from a policyengine variable. 
+ + For enum variables (those with possible_values), returns ("Enum", json_values). + For other variables, returns (python_type_name, None). + + Returns: + Tuple of (data_type, possible_values_json) + """ + if var.possible_values: + return "Enum", json.dumps(var.possible_values) + + data_type = ( + var.data_type.__name__ + if hasattr(var.data_type, "__name__") + else str(var.data_type) + ) + return data_type, None + + +def seed_model( + model_version, + session: Session, + skip_state_params: bool = False, + variable_whitelist: set[str] | None = None, + parameter_prefixes: set[str] | None = None, +) -> TaxBenefitModelVersion: + """Seed a tax-benefit model with its variables and parameters. + + Args: + model_version: The policyengine.py model version object + session: Database session + skip_state_params: Skip US state-level parameters (gov.states.*) + variable_whitelist: If provided, only seed variables whose name is in this set + parameter_prefixes: If provided, only seed parameters whose name starts with + one of these prefixes + + Returns: + The created or existing TaxBenefitModelVersion + """ + with logfire.span( + "seed_model", + model=model_version.model.id, + version=model_version.version, + ): + console.print(f"[bold blue]Seeding {model_version.model.id}...") + + # Create or get the model + existing_model = session.exec( + select(TaxBenefitModel).where( + TaxBenefitModel.name == model_version.model.id + ) + ).first() + + if existing_model: + db_model = existing_model + console.print(f" Using existing model: {db_model.id}") + else: + db_model = TaxBenefitModel( + name=model_version.model.id, + description=model_version.model.description, + ) + session.add(db_model) + session.commit() + session.refresh(db_model) + console.print(f" Created model: {db_model.id}") + + # Create model version + existing_version = session.exec( + select(TaxBenefitModelVersion).where( + TaxBenefitModelVersion.model_id == db_model.id, + TaxBenefitModelVersion.version == 
model_version.version, + ) + ).first() + + if existing_version: + console.print( + f" Model version {model_version.version} already exists, skipping" + ) + return existing_version + + db_version = TaxBenefitModelVersion( + model_id=db_model.id, + version=model_version.version, + description=f"Version {model_version.version}", + ) + session.add(db_version) + session.commit() + session.refresh(db_version) + console.print(f" Created version: {db_version.version}") + + # Filter variables by whitelist if provided + variables = model_version.variables + if variable_whitelist is not None: + variables = [v for v in variables if v.name in variable_whitelist] + console.print( + f" Filtered to {len(variables)} variables " + f"(from {len(model_version.variables)} total, whitelist applied)" + ) + + # Add variables + with logfire.span("add_variables", count=len(variables)): + var_rows = [] + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task( + f"Preparing {len(variables)} variables", + total=len(variables), + ) + for var in variables: + data_type, possible_values = _get_variable_type_info(var) + var_rows.append( + { + "id": uuid4(), + "name": var.name, + "entity": var.entity, + "description": var.description or "", + "data_type": data_type, + "possible_values": possible_values, + "tax_benefit_model_version_id": db_version.id, + "created_at": datetime.now(timezone.utc), + } + ) + progress.advance(task) + + console.print(f" Inserting {len(var_rows)} variables...") + bulk_insert( + session, + "variables", + [ + "id", + "name", + "entity", + "description", + "data_type", + "possible_values", + "tax_benefit_model_version_id", + "created_at", + ], + var_rows, + ) + + console.print(f" [green]✓[/green] Added {len(variables)} variables") + + # Add parameters (only user-facing ones: those with labels) + # Deduplicate by name - keep first occurrence + seen_names = set() + 
parameters_to_add = [] + skipped_state_params_count = 0 + skipped_prefix_count = 0 + for p in model_version.parameters: + if p.label is None or p.name in seen_names: + continue + # Skip state-level parameters if requested + if skip_state_params and p.name.startswith("gov.states."): + skipped_state_params_count += 1 + continue + # Skip parameters not matching prefix whitelist + if parameter_prefixes is not None and not any( + p.name.startswith(prefix) for prefix in parameter_prefixes + ): + skipped_prefix_count += 1 + continue + parameters_to_add.append(p) + seen_names.add(p.name) + + filter_msg = f" Filtered to {len(parameters_to_add)} user-facing parameters" + filter_msg += ( + f" (from {len(model_version.parameters)} total, deduplicated by name)" + ) + if skip_state_params and skipped_state_params_count > 0: + filter_msg += f", skipped {skipped_state_params_count} state params" + if parameter_prefixes is not None and skipped_prefix_count > 0: + filter_msg += f", skipped {skipped_prefix_count} by prefix filter" + console.print(filter_msg) + + with logfire.span("add_parameters", count=len(parameters_to_add)): + param_rows = [] + param_names = [] # Track (pe_id, name, generated_uuid) + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task( + f"Preparing {len(parameters_to_add)} parameters", + total=len(parameters_to_add), + ) + for param in parameters_to_add: + param_uuid = uuid4() + param_rows.append( + { + "id": param_uuid, + "name": param.name, + "label": param.label if hasattr(param, "label") else None, + "description": param.description or "", + "data_type": param.data_type.__name__ + if hasattr(param.data_type, "__name__") + else str(param.data_type), + "unit": param.unit, + "tax_benefit_model_version_id": db_version.id, + "created_at": datetime.now(timezone.utc), + } + ) + param_names.append((param.id, param.name, param_uuid)) + progress.advance(task) + + 
console.print(f" Inserting {len(param_rows)} parameters...") + bulk_insert( + session, + "parameters", + [ + "id", + "name", + "label", + "description", + "data_type", + "unit", + "tax_benefit_model_version_id", + "created_at", + ], + param_rows, + ) + + # Build param_id_map from pre-generated UUIDs + param_id_map = {pe_id: db_uuid for pe_id, name, db_uuid in param_names} + + console.print( + f" [green]✓[/green] Added {len(parameters_to_add)} parameters" + ) + + # Add parameter values + parameter_values_to_add = [ + pv + for pv in model_version.parameter_values + if pv.parameter.id in param_id_map + ] + console.print(f" Found {len(parameter_values_to_add)} parameter values to add") + + with logfire.span("add_parameter_values", count=len(parameter_values_to_add)): + pv_rows = [] + skipped = 0 + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task( + f"Preparing {len(parameter_values_to_add)} parameter values", + total=len(parameter_values_to_add), + ) + for pv in parameter_values_to_add: + # Handle Infinity values - skip them as they can't be stored in JSON + if isinstance(pv.value, float) and ( + math.isinf(pv.value) or math.isnan(pv.value) + ): + skipped += 1 + progress.advance(task) + continue + + # Source data has dates swapped (start > end), fix ordering + if pv.start_date and pv.end_date: + start = pv.end_date # Swap: source end is our start + end = pv.start_date # Swap: source start is our end + else: + start = pv.start_date + end = pv.end_date + pv_rows.append( + { + "id": uuid4(), + "parameter_id": param_id_map[pv.parameter.id], + "value_json": json.dumps(pv.value), + "start_date": start, + "end_date": end, + "policy_id": None, + "dynamic_id": None, + "created_at": datetime.now(timezone.utc), + } + ) + progress.advance(task) + + console.print(f" Inserting {len(pv_rows)} parameter values...") + bulk_insert( + session, + "parameter_values", + [ + "id", + 
"parameter_id", + "value_json", + "start_date", + "end_date", + "policy_id", + "dynamic_id", + "created_at", + ], + pv_rows, + ) + + console.print( + f" [green]✓[/green] Added {len(pv_rows)} parameter values" + + (f" (skipped {skipped} invalid)" if skipped else "") + ) + + return db_version + + +def seed_uk_model( + session: Session, + skip_state_params: bool = False, + variable_whitelist: set[str] | None = None, + parameter_prefixes: set[str] | None = None, +): + """Seed UK model.""" + from policyengine.tax_benefit_models.uk import uk_latest + + version = seed_model( + uk_latest, + session, + skip_state_params=skip_state_params, + variable_whitelist=variable_whitelist, + parameter_prefixes=parameter_prefixes, + ) + console.print(f"[green]✓[/green] UK model seeded: {version.id}\n") + return version + + +def seed_us_model( + session: Session, + skip_state_params: bool = False, + variable_whitelist: set[str] | None = None, + parameter_prefixes: set[str] | None = None, +): + """Seed US model.""" + from policyengine.tax_benefit_models.us import us_latest + + version = seed_model( + us_latest, + session, + skip_state_params=skip_state_params, + variable_whitelist=variable_whitelist, + parameter_prefixes=parameter_prefixes, + ) + console.print(f"[green]✓[/green] US model seeded: {version.id}\n") + return version + + +def main(): + parser = argparse.ArgumentParser(description="Seed tax-benefit models") + parser.add_argument( + "--us-only", + action="store_true", + help="Only seed US model", + ) + parser.add_argument( + "--uk-only", + action="store_true", + help="Only seed UK model", + ) + parser.add_argument( + "--skip-state-params", + action="store_true", + help="Skip US state-level parameters (gov.states.*)", + ) + args = parser.parse_args() + + console.print("[bold green]Seeding tax-benefit models...[/bold green]\n") + + with get_session() as session: + if not args.us_only: + seed_uk_model(session, skip_state_params=args.skip_state_params) + + if not args.uk_only: + 
seed_us_model(session, skip_state_params=args.skip_state_params) + + console.print("[bold green]✓ Model seeding complete![/bold green]") + + +if __name__ == "__main__": + main() diff --git a/scripts/seed_policies.py b/scripts/seed_policies.py new file mode 100644 index 0000000..8372b98 --- /dev/null +++ b/scripts/seed_policies.py @@ -0,0 +1,192 @@ +"""Seed example policy reforms. + +This script creates example policy reforms for UK and US models. + +Usage: + python scripts/seed_policies.py # Seed UK and US example policies + python scripts/seed_policies.py --us-only # Seed only US example policy + python scripts/seed_policies.py --uk-only # Seed only UK example policy +""" + +import argparse +from datetime import datetime, timezone + +from seed_utils import console, get_session +from sqlmodel import Session, select + +# Import after seed_utils sets up path +from policyengine_api.models import ( # noqa: E402 + Parameter, + ParameterValue, + Policy, + TaxBenefitModel, + TaxBenefitModelVersion, +) + + +def seed_uk_policy(session: Session) -> bool: + """Seed UK example policy: raise basic rate to 22p. + + Returns: + True if created, False if skipped + """ + uk_model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-uk") + ).first() + + if not uk_model: + console.print("[red]Error: UK model not found. Run seed_models.py first.[/red]") + return False + + uk_version = session.exec( + select(TaxBenefitModelVersion) + .where(TaxBenefitModelVersion.model_id == uk_model.id) + .order_by(TaxBenefitModelVersion.created_at.desc()) + ).first() + + if not uk_version: + console.print( + "[red]Error: UK model version not found. 
Run seed_models.py first.[/red]" + ) + return False + + policy_name = "UK basic rate 22p" + existing = session.exec(select(Policy).where(Policy.name == policy_name)).first() + + if existing: + console.print(f" Policy '{policy_name}' already exists, skipping") + return False + + # Find the basic rate parameter + uk_basic_rate_param = session.exec( + select(Parameter).where( + Parameter.name == "gov.hmrc.income_tax.rates.uk[0].rate", + Parameter.tax_benefit_model_version_id == uk_version.id, + ) + ).first() + + if not uk_basic_rate_param: + console.print(" [yellow]Warning: UK basic rate parameter not found[/yellow]") + return False + + uk_policy = Policy( + name=policy_name, + description="Raise the UK income tax basic rate from 20p to 22p", + tax_benefit_model_id=uk_model.id, + ) + session.add(uk_policy) + session.commit() + session.refresh(uk_policy) + + # Add parameter value (22% = 0.22) + uk_param_value = ParameterValue( + parameter_id=uk_basic_rate_param.id, + value_json={"value": 0.22}, + start_date=datetime(2024, 1, 1, tzinfo=timezone.utc), + end_date=None, + policy_id=uk_policy.id, + ) + session.add(uk_param_value) + session.commit() + console.print(f" [green]✓[/green] Created UK policy: {policy_name}") + return True + + +def seed_us_policy(session: Session) -> bool: + """Seed US example policy: raise first bracket to 12%. + + Returns: + True if created, False if skipped + """ + us_model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-us") + ).first() + + if not us_model: + console.print("[red]Error: US model not found. Run seed_models.py first.[/red]") + return False + + us_version = session.exec( + select(TaxBenefitModelVersion) + .where(TaxBenefitModelVersion.model_id == us_model.id) + .order_by(TaxBenefitModelVersion.created_at.desc()) + ).first() + + if not us_version: + console.print( + "[red]Error: US model version not found. 
Run seed_models.py first.[/red]" + ) + return False + + policy_name = "US 12% lowest bracket" + existing = session.exec(select(Policy).where(Policy.name == policy_name)).first() + + if existing: + console.print(f" Policy '{policy_name}' already exists, skipping") + return False + + # Find the first bracket rate parameter + us_first_bracket_param = session.exec( + select(Parameter).where( + Parameter.name == "gov.irs.income.bracket.rates.1", + Parameter.tax_benefit_model_version_id == us_version.id, + ) + ).first() + + if not us_first_bracket_param: + console.print( + " [yellow]Warning: US first bracket parameter not found[/yellow]" + ) + return False + + us_policy = Policy( + name=policy_name, + description="Raise US federal income tax lowest bracket to 12%", + tax_benefit_model_id=us_model.id, + ) + session.add(us_policy) + session.commit() + session.refresh(us_policy) + + # Add parameter value (12% = 0.12) + us_param_value = ParameterValue( + parameter_id=us_first_bracket_param.id, + value_json={"value": 0.12}, + start_date=datetime(2024, 1, 1, tzinfo=timezone.utc), + end_date=None, + policy_id=us_policy.id, + ) + session.add(us_param_value) + session.commit() + console.print(f" [green]✓[/green] Created US policy: {policy_name}") + return True + + +def main(): + parser = argparse.ArgumentParser(description="Seed example policies") + parser.add_argument( + "--us-only", + action="store_true", + help="Only seed US example policy", + ) + parser.add_argument( + "--uk-only", + action="store_true", + help="Only seed UK example policy", + ) + args = parser.parse_args() + + console.print("[bold green]Seeding example policies...[/bold green]\n") + + with get_session() as session: + if not args.us_only: + seed_uk_policy(session) + + if not args.uk_only: + seed_us_policy(session) + + console.print("\n[bold green]✓ Policy seeding complete![/bold green]") + + +if __name__ == "__main__": + main() diff --git a/scripts/seed_regions.py b/scripts/seed_regions.py new file mode 
100644 index 0000000..e180e7a --- /dev/null +++ b/scripts/seed_regions.py @@ -0,0 +1,385 @@ +"""Seed regions for US and UK geographic analysis. + +This script populates the regions table with: +- US: National, 51 states (incl. DC), 436 congressional districts, 333 places/cities +- UK: National and 4 countries (England, Scotland, Wales, Northern Ireland) + +Regions are sourced from policyengine.py's region registries and linked +to the appropriate datasets via the region_datasets join table. + +This script is the SOLE source of truth for region-to-dataset wiring. +After importing datasets with import_state_datasets.py, re-run this script +to link regions to any newly available datasets. + +Usage: + python scripts/seed_regions.py # Seed all US and UK regions + python scripts/seed_regions.py --us-only # Seed only US regions + python scripts/seed_regions.py --uk-only # Seed only UK regions + python scripts/seed_regions.py --skip-places # Exclude US places (cities) + python scripts/seed_regions.py --skip-districts # Exclude US congressional districts +""" + +import argparse +import time + +from rich.progress import Progress, SpinnerColumn, TextColumn +from seed_utils import console, get_session +from sqlmodel import Session, select + +# Import after seed_utils sets up path +from policyengine_api.models import ( # noqa: E402 + Dataset, + Region, + RegionDatasetLink, + TaxBenefitModel, +) + + +def _group_us_datasets( + session: Session, + us_model_id, +) -> tuple[list[Dataset], dict[str, list[Dataset]], dict[str, list[Dataset]]]: + """Pre-fetch and group all US datasets by type. 
+ + Returns: + (national_datasets, state_datasets_by_code, district_datasets_by_code) + """ + all_datasets = session.exec( + select(Dataset).where(Dataset.tax_benefit_model_id == us_model_id) + ).all() + + national = [] + by_state: dict[str, list[Dataset]] = {} + by_district: dict[str, list[Dataset]] = {} + + for d in all_datasets: + if d.filepath and d.filepath.startswith("states/"): + # filepath = "states/AL/AL-year-2024.h5" + parts = d.filepath.split("/") + if len(parts) >= 2: + by_state.setdefault(parts[1], []).append(d) + elif d.filepath and d.filepath.startswith("districts/"): + # filepath = "districts/AL-01/AL-01-year-2024.h5" + parts = d.filepath.split("/") + if len(parts) >= 2: + by_district.setdefault(parts[1], []).append(d) + elif "cps" in d.name.lower(): + national.append(d) + + return national, by_state, by_district + + +def _get_datasets_for_us_region( + pe_region, + national_datasets: list[Dataset], + state_datasets: dict[str, list[Dataset]], + district_datasets: dict[str, list[Dataset]], +) -> list[Dataset]: + """Determine which datasets a US region should link to.""" + if pe_region.region_type == "national": + return national_datasets + + elif pe_region.region_type == "state": + # "state/ca" -> "CA" + state_code = pe_region.code.split("/")[1].upper() + return state_datasets.get(state_code, national_datasets) + + elif pe_region.region_type == "congressional_district": + # "congressional_district/CA-12" -> "CA-12" + district_code = pe_region.code.split("/")[1].upper() + return district_datasets.get(district_code, national_datasets) + + elif pe_region.region_type == "place": + # Places use parent state's datasets (filter at runtime) + if pe_region.state_code: + return state_datasets.get(pe_region.state_code, national_datasets) + return national_datasets + + return national_datasets + + +def _link_datasets( + region_id, + datasets: list[Dataset], + existing_link_set: set[tuple], + session: Session, +) -> int: + """Create RegionDatasetLink entries for 
missing links. + + Returns the number of new links created. + """ + created = 0 + for dataset in datasets: + key = (region_id, dataset.id) + if key not in existing_link_set: + session.add(RegionDatasetLink(region_id=region_id, dataset_id=dataset.id)) + existing_link_set.add(key) + created += 1 + return created + + +def seed_us_regions( + session: Session, + skip_places: bool = False, + skip_districts: bool = False, +) -> tuple[int, int, int]: + """Seed US regions from policyengine.py registry. + + Args: + session: Database session + skip_places: Skip US places (cities over 100K population) + skip_districts: Skip congressional districts + + Returns: + Tuple of (created_count, skipped_count, links_created) + """ + from policyengine.countries.us.regions import us_region_registry + + # Get US model + us_model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-us") + ).first() + + if not us_model: + console.print("[red]Error: US model not found. Run seed.py first.[/red]") + return 0, 0, 0 + + # Pre-fetch and group all US datasets + national_datasets, state_datasets, district_datasets = _group_us_datasets( + session, us_model.id + ) + + if not national_datasets: + console.print("[red]Error: No US CPS datasets found. 
Run seed.py first.[/red]") + return 0, 0, 0 + + # Pre-fetch existing dataset links for efficiency + existing_links = session.exec(select(RegionDatasetLink)).all() + existing_link_set = {(l.region_id, l.dataset_id) for l in existing_links} + + created = 0 + skipped = 0 + links_created = 0 + + # Filter regions based on options + regions_to_seed = [] + for region in us_region_registry.regions: + if region.region_type == "national": + regions_to_seed.append(region) + elif region.region_type == "state": + regions_to_seed.append(region) + elif region.region_type == "congressional_district" and not skip_districts: + regions_to_seed.append(region) + elif region.region_type == "place" and not skip_places: + regions_to_seed.append(region) + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("US regions", total=len(regions_to_seed)) + + for pe_region in regions_to_seed: + progress.update(task, description=f"US: {pe_region.label}") + + # Find existing or create new region + existing = session.exec( + select(Region).where(Region.code == pe_region.code) + ).first() + + if existing: + db_region = existing + skipped += 1 + else: + db_region = Region( + code=pe_region.code, + label=pe_region.label, + region_type=pe_region.region_type, + requires_filter=pe_region.requires_filter, + filter_field=pe_region.filter_field, + filter_value=pe_region.filter_value, + parent_code=pe_region.parent_code, + state_code=pe_region.state_code, + state_name=pe_region.state_name, + tax_benefit_model_id=us_model.id, + ) + session.add(db_region) + session.flush() # Get the ID assigned + created += 1 + + # Link datasets for this region + datasets = _get_datasets_for_us_region( + pe_region, national_datasets, state_datasets, district_datasets + ) + links_created += _link_datasets( + db_region.id, datasets, existing_link_set, session + ) + + progress.advance(task) + + session.commit() + + return 
created, skipped, links_created + + +def seed_uk_regions(session: Session) -> tuple[int, int, int]: + """Seed UK regions from policyengine.py registry. + + Args: + session: Database session + + Returns: + Tuple of (created_count, skipped_count, links_created) + """ + from policyengine.countries.uk.regions import uk_region_registry + + # Get UK model + uk_model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-uk") + ).first() + + if not uk_model: + console.print( + "[yellow]Warning: UK model not found. Skipping UK regions.[/yellow]" + ) + return 0, 0, 0 + + # Get all UK FRS datasets + uk_datasets = session.exec( + select(Dataset) + .where(Dataset.tax_benefit_model_id == uk_model.id) + .where(Dataset.name.contains("frs")) # type: ignore + ).all() + + if not uk_datasets: + console.print( + "[yellow]Warning: No UK FRS datasets found. Skipping UK regions.[/yellow]" + ) + return 0, 0, 0 + + # Pre-fetch existing dataset links + existing_links = session.exec(select(RegionDatasetLink)).all() + existing_link_set = {(l.region_id, l.dataset_id) for l in existing_links} + + created = 0 + skipped = 0 + links_created = 0 + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("UK regions", total=len(uk_region_registry.regions)) + + for pe_region in uk_region_registry.regions: + progress.update(task, description=f"UK: {pe_region.label}") + + # Find existing or create new region + existing = session.exec( + select(Region).where(Region.code == pe_region.code) + ).first() + + if existing: + db_region = existing + skipped += 1 + else: + db_region = Region( + code=pe_region.code, + label=pe_region.label, + region_type=pe_region.region_type, + requires_filter=pe_region.requires_filter, + filter_field=pe_region.filter_field, + filter_value=pe_region.filter_value, + parent_code=pe_region.parent_code, + state_code=None, + state_name=None, + 
tax_benefit_model_id=uk_model.id, + ) + session.add(db_region) + session.flush() + created += 1 + + # All UK regions link to FRS datasets (they filter at runtime) + links_created += _link_datasets( + db_region.id, uk_datasets, existing_link_set, session + ) + + progress.advance(task) + + session.commit() + + return created, skipped, links_created + + +def main(): + parser = argparse.ArgumentParser(description="Seed US and UK regions") + parser.add_argument( + "--us-only", + action="store_true", + help="Only seed US regions", + ) + parser.add_argument( + "--uk-only", + action="store_true", + help="Only seed UK regions", + ) + parser.add_argument( + "--skip-places", + action="store_true", + help="Skip US places (cities over 100K population)", + ) + parser.add_argument( + "--skip-districts", + action="store_true", + help="Skip US congressional districts", + ) + args = parser.parse_args() + + console.print("[bold green]Seeding regions...[/bold green]\n") + + start = time.time() + total_created = 0 + total_skipped = 0 + total_links = 0 + + with get_session() as session: + # Seed US regions + if not args.uk_only: + console.print("[bold]US Regions[/bold]") + us_created, us_skipped, us_links = seed_us_regions( + session, + skip_places=args.skip_places, + skip_districts=args.skip_districts, + ) + total_created += us_created + total_skipped += us_skipped + total_links += us_links + console.print( + f"[green]\u2713[/green] US regions: {us_created} created, " + f"{us_skipped} skipped, {us_links} dataset links added\n" + ) + + # Seed UK regions + if not args.us_only: + console.print("[bold]UK Regions[/bold]") + uk_created, uk_skipped, uk_links = seed_uk_regions(session) + total_created += uk_created + total_skipped += uk_skipped + total_links += uk_links + console.print( + f"[green]\u2713[/green] UK regions: {uk_created} created, " + f"{uk_skipped} skipped, {uk_links} dataset links added\n" + ) + + elapsed = time.time() - start + console.print( + f"[bold]Total: {total_created} 
created, {total_skipped} skipped, " + f"{total_links} dataset links added[/bold]" + ) + console.print(f"[bold]Time: {elapsed:.1f}s[/bold]") + + +if __name__ == "__main__": + main() diff --git a/scripts/seed_utils.py b/scripts/seed_utils.py new file mode 100644 index 0000000..624379f --- /dev/null +++ b/scripts/seed_utils.py @@ -0,0 +1,72 @@ +"""Shared utilities for seed scripts.""" + +import io +import logging +import sys +import warnings +from pathlib import Path + +import logfire +from rich.console import Console +from sqlmodel import Session, create_engine + +# Disable all SQLAlchemy and database logging +logging.basicConfig(level=logging.ERROR) +logging.getLogger("sqlalchemy").setLevel(logging.ERROR) +warnings.filterwarnings("ignore") + +# Add src to path +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +from policyengine_api.config.settings import settings # noqa: E402 + +# Configure logfire +if settings.logfire_token: + logfire.configure( + token=settings.logfire_token, + environment=settings.logfire_environment, + console=False, + ) + +console = Console() + + +def get_session() -> Session: + """Get database session with logging disabled.""" + engine = create_engine(settings.database_url, echo=False) + return Session(engine) + + +def bulk_insert(session: Session, table: str, columns: list[str], rows: list[dict]): + """Fast bulk insert using PostgreSQL COPY via StringIO.""" + if not rows: + return + + # Get raw psycopg2 connection + connection = session.connection() + raw_conn = connection.connection.dbapi_connection + cursor = raw_conn.cursor() + + # Build CSV-like data in memory + output = io.StringIO() + for row in rows: + values = [] + for col in columns: + val = row[col] + if val is None: + values.append("\\N") + elif isinstance(val, str): + # Escape special characters for COPY + val = ( + val.replace("\\", "\\\\").replace("\t", "\\t").replace("\n", "\\n") + ) + values.append(val) + else: + values.append(str(val)) + 
output.write("\t".join(values) + "\n") + + output.seek(0) + + # COPY is the fastest way to bulk load PostgreSQL + cursor.copy_from(output, table, columns=columns, null="\\N") + session.commit() diff --git a/src/policyengine_api/agent_sandbox.py b/src/policyengine_api/agent_sandbox.py index 9d0436c..bcf1ab0 100644 --- a/src/policyengine_api/agent_sandbox.py +++ b/src/policyengine_api/agent_sandbox.py @@ -235,8 +235,7 @@ def openapi_to_claude_tools(spec: dict) -> list[dict]: prop = schema_to_json_schema(spec, param_schema) prop["description"] = ( - param.get("description", "") - + f" (in: {param_in})" + param.get("description", "") + f" (in: {param_in})" ) properties[param_name] = prop @@ -268,16 +267,18 @@ def openapi_to_claude_tools(spec: dict) -> list[dict]: if required: input_schema["required"] = list(set(required)) - tools.append({ - "name": tool_name, - "description": full_desc[:1024], # Claude has limits - "input_schema": input_schema, - "_meta": { - "path": path, - "method": method, - "parameters": operation.get("parameters", []), - }, - }) + tools.append( + { + "name": tool_name, + "description": full_desc[:1024], # Claude has limits + "input_schema": input_schema, + "_meta": { + "path": path, + "method": method, + "parameters": operation.get("parameters", []), + }, + } + ) return tools @@ -347,7 +348,9 @@ def execute_api_tool( url, params=query_params, json=body_data, headers=headers, timeout=60 ) elif method == "delete": - resp = requests.delete(url, params=query_params, headers=headers, timeout=60) + resp = requests.delete( + url, params=query_params, headers=headers, timeout=60 + ) else: return f"Unsupported method: {method}" @@ -415,9 +418,7 @@ def log(msg: str) -> None: tool_lookup = {t["name"]: t for t in tools} # Strip _meta from tools before sending to Claude (it doesn't need it) - claude_tools = [ - {k: v for k, v in t.items() if k != "_meta"} for t in tools - ] + claude_tools = [{k: v for k, v in t.items() if k != "_meta"} for t in tools] # Add 
the sleep tool claude_tools.append(SLEEP_TOOL) @@ -477,11 +478,13 @@ def log(msg: str) -> None: log(f"[TOOL_RESULT] {result[:300]}") - tool_results.append({ - "type": "tool_result", - "tool_use_id": block.id, - "content": result, - }) + tool_results.append( + { + "type": "tool_result", + "tool_use_id": block.id, + "content": result, + } + ) messages.append({"role": "assistant", "content": assistant_content}) diff --git a/src/policyengine_api/api/__init__.py b/src/policyengine_api/api/__init__.py index 881af99..3e1db4a 100644 --- a/src/policyengine_api/api/__init__.py +++ b/src/policyengine_api/api/__init__.py @@ -9,13 +9,20 @@ datasets, dynamics, household, + household_analysis, + households, outputs, parameter_values, parameters, policies, + regions, simulations, tax_benefit_model_versions, tax_benefit_models, + user_household_associations, + user_policies, + user_report_associations, + user_simulation_associations, variables, ) @@ -23,6 +30,7 @@ api_router.include_router(datasets.router) api_router.include_router(policies.router) +api_router.include_router(regions.router) api_router.include_router(simulations.router) api_router.include_router(outputs.router) api_router.include_router(variables.router) @@ -33,7 +41,14 @@ api_router.include_router(tax_benefit_model_versions.router) api_router.include_router(change_aggregates.router) api_router.include_router(household.router) +api_router.include_router(household_analysis.router) +api_router.include_router(households.router) api_router.include_router(analysis.router) api_router.include_router(agent.router) +api_router.include_router(user_household_associations.router) +api_router.include_router(user_policies.router) +api_router.include_router(user_simulation_associations.router) +api_router.include_router(user_report_associations.router) +api_router.include_router(user_report_associations.reports_router) __all__ = ["api_router"] diff --git a/src/policyengine_api/api/agent.py b/src/policyengine_api/api/agent.py index 
7b7d108..6c26e80 100644 --- a/src/policyengine_api/api/agent.py +++ b/src/policyengine_api/api/agent.py @@ -24,6 +24,7 @@ def get_traceparent() -> str | None: TraceContextTextMapPropagator().inject(carrier) return carrier.get("traceparent") + router = APIRouter(prefix="/agent", tags=["agent"]) @@ -93,7 +94,9 @@ def _run_local_agent( from policyengine_api.agent_sandbox import _run_agent_impl try: - history_dicts = [{"role": m.role, "content": m.content} for m in (history or [])] + history_dicts = [ + {"role": m.role, "content": m.content} for m in (history or []) + ] result = _run_agent_impl(question, api_base_url, call_id, history_dicts) _calls[call_id]["status"] = result.get("status", "completed") _calls[call_id]["result"] = result @@ -136,9 +139,15 @@ async def run_agent(request: RunRequest) -> RunResponse: traceparent = get_traceparent() run_fn = modal.Function.from_name("policyengine-sandbox", "run_agent") - history_dicts = [{"role": m.role, "content": m.content} for m in request.history] + history_dicts = [ + {"role": m.role, "content": m.content} for m in request.history + ] call = run_fn.spawn( - request.question, api_base_url, call_id, history_dicts, traceparent=traceparent + request.question, + api_base_url, + call_id, + history_dicts, + traceparent=traceparent, ) _calls[call_id] = { @@ -166,7 +175,12 @@ async def run_agent(request: RunRequest) -> RunResponse: # Run in background using asyncio loop = asyncio.get_event_loop() loop.run_in_executor( - None, _run_local_agent, call_id, request.question, api_base_url, request.history + None, + _run_local_agent, + call_id, + request.question, + api_base_url, + request.history, ) return RunResponse(call_id=call_id, status="running") diff --git a/src/policyengine_api/api/analysis.py b/src/policyengine_api/api/analysis.py index c9aa86d..a46eaf2 100644 --- a/src/policyengine_api/api/analysis.py +++ b/src/policyengine_api/api/analysis.py @@ -22,19 +22,41 @@ import logfire from fastapi import APIRouter, Depends, 
HTTPException from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from sqlmodel import Session, select +from policyengine_api.api.module_registry import ( + MODULE_REGISTRY, + get_modules_for_country, + validate_modules, +) from policyengine_api.models import ( + BudgetSummary, + BudgetSummaryRead, + CongressionalDistrictImpact, + CongressionalDistrictImpactRead, + ConstituencyImpact, + ConstituencyImpactRead, Dataset, DecileImpact, DecileImpactRead, + Inequality, + InequalityRead, + IntraDecileImpact, + IntraDecileImpactRead, + LocalAuthorityImpact, + LocalAuthorityImpactRead, + Poverty, + PovertyRead, ProgramStatistics, ProgramStatisticsRead, + Region, + RegionDatasetLink, Report, ReportStatus, Simulation, SimulationStatus, + SimulationType, TaxBenefitModel, TaxBenefitModelVersion, ) @@ -64,22 +86,73 @@ def _safe_float(value: float | None) -> float | None: router = APIRouter(prefix="/analysis", tags=["analysis"]) +# --------------------------------------------------------------------------- +# GET /analysis/options — list available computation modules +# --------------------------------------------------------------------------- + + +class ModuleOption(BaseModel): + """A single computation module available for economy analysis.""" + + name: str + label: str + description: str + response_fields: list[str] + + +@router.get("/options", response_model=list[ModuleOption]) +def list_analysis_options( + country: str | None = None, +) -> list[ModuleOption]: + """List available economy analysis modules. + + Args: + country: Optional country code ('uk' or 'us') to filter modules. 
+ """ + if country: + modules = get_modules_for_country(country) + else: + modules = list(MODULE_REGISTRY.values()) + + return [ + ModuleOption( + name=m.name, + label=m.label, + description=m.description, + response_fields=list(m.response_fields), + ) + for m in modules + ] + + class EconomicImpactRequest(BaseModel): """Request body for economic impact analysis. - Example: + Example with dataset_id: { "tax_benefit_model_name": "policyengine_uk", "dataset_id": "uuid-from-datasets-endpoint", "policy_id": "uuid-of-reform-policy" } + + Example with region: + { + "tax_benefit_model_name": "policyengine_us", + "region": "state/ca", + "policy_id": "uuid-of-reform-policy" + } """ tax_benefit_model_name: Literal["policyengine_uk", "policyengine_us"] = Field( description="Which country model to use" ) - dataset_id: UUID = Field( - description="Dataset ID from /datasets endpoint containing population microdata" + dataset_id: UUID | None = Field( + default=None, + description="Dataset ID from /datasets endpoint. Either dataset_id or region must be provided.", + ) + region: str | None = Field( + default=None, + description="Region code (e.g., 'state/ca', 'us'). Either dataset_id or region must be provided.", ) policy_id: UUID | None = Field( default=None, @@ -88,6 +161,16 @@ class EconomicImpactRequest(BaseModel): dynamic_id: UUID | None = Field( default=None, description="Optional behavioural response specification ID" ) + year: int | None = Field( + default=None, + description="Year for the analysis (e.g., 2026). Selects the dataset for that year. 
Uses latest available if omitted.", + ) + + @model_validator(mode="after") + def check_dataset_or_region(self) -> "EconomicImpactRequest": + if not self.dataset_id and not self.region: + raise ValueError("Either dataset_id or region must be provided") + return self class SimulationInfo(BaseModel): @@ -98,6 +181,17 @@ class SimulationInfo(BaseModel): error_message: str | None = None +class RegionInfo(BaseModel): + """Region information used in analysis.""" + + code: str + label: str + region_type: str + requires_filter: bool + filter_field: str | None = None + filter_value: str | None = None + + class EconomicImpactResponse(BaseModel): """Response from economic impact analysis.""" @@ -105,9 +199,20 @@ class EconomicImpactResponse(BaseModel): status: ReportStatus baseline_simulation: SimulationInfo reform_simulation: SimulationInfo + region: RegionInfo | None = None error_message: str | None = None decile_impacts: list[DecileImpactRead] | None = None program_statistics: list[ProgramStatisticsRead] | None = None + poverty: list[PovertyRead] | None = None + inequality: list[InequalityRead] | None = None + budget_summary: list[BudgetSummaryRead] | None = None + intra_decile: list[IntraDecileImpactRead] | None = None + detailed_budget: dict[str, dict[str, float | None]] | None = None + congressional_district_impact: list[CongressionalDistrictImpactRead] | None = None + constituency_impact: list[ConstituencyImpactRead] | None = None + local_authority_impact: list[LocalAuthorityImpactRead] | None = None + wealth_decile: list[DecileImpactRead] | None = None + intra_wealth_decile: list[IntraDecileImpactRead] | None = None def _get_model_version( @@ -138,19 +243,26 @@ def _get_model_version( def _get_deterministic_simulation_id( - dataset_id: UUID, + simulation_type: SimulationType, model_version_id: UUID, policy_id: UUID | None, dynamic_id: UUID | None, + dataset_id: UUID | None = None, + household_id: UUID | None = None, + filter_field: str | None = None, + filter_value: 
str | None = None, ) -> UUID: """Generate a deterministic UUID from simulation parameters.""" - key = f"{dataset_id}:{model_version_id}:{policy_id}:{dynamic_id}" + if simulation_type == SimulationType.ECONOMY: + key = f"economy:{dataset_id}:{model_version_id}:{policy_id}:{dynamic_id}:{filter_field}:{filter_value}" + else: + key = f"household:{household_id}:{model_version_id}:{policy_id}:{dynamic_id}" return uuid5(SIMULATION_NAMESPACE, key) def _get_deterministic_report_id( baseline_sim_id: UUID, - reform_sim_id: UUID, + reform_sim_id: UUID | None, ) -> UUID: """Generate a deterministic UUID from report parameters.""" key = f"{baseline_sim_id}:{reform_sim_id}" @@ -158,15 +270,28 @@ def _get_deterministic_report_id( def _get_or_create_simulation( - dataset_id: UUID, + simulation_type: SimulationType, model_version_id: UUID, policy_id: UUID | None, dynamic_id: UUID | None, session: Session, + dataset_id: UUID | None = None, + household_id: UUID | None = None, + filter_field: str | None = None, + filter_value: str | None = None, + region_id: UUID | None = None, + year: int | None = None, ) -> Simulation: """Get existing simulation or create a new one.""" sim_id = _get_deterministic_simulation_id( - dataset_id, model_version_id, policy_id, dynamic_id + simulation_type, + model_version_id, + policy_id, + dynamic_id, + dataset_id=dataset_id, + household_id=household_id, + filter_field=filter_field, + filter_value=filter_value, ) existing = session.get(Simulation, sim_id) @@ -175,22 +300,38 @@ def _get_or_create_simulation( simulation = Simulation( id=sim_id, + simulation_type=simulation_type, dataset_id=dataset_id, + household_id=household_id, tax_benefit_model_version_id=model_version_id, policy_id=policy_id, dynamic_id=dynamic_id, status=SimulationStatus.PENDING, + filter_field=filter_field, + filter_value=filter_value, + region_id=region_id, + year=year, ) + from sqlalchemy.exc import IntegrityError + session.add(simulation) - session.commit() + try: + session.commit() 
+ except IntegrityError: + session.rollback() + existing = session.get(Simulation, sim_id) + if existing: + return existing + raise session.refresh(simulation) return simulation def _get_or_create_report( baseline_sim_id: UUID, - reform_sim_id: UUID, + reform_sim_id: UUID | None, label: str, + report_type: str, session: Session, ) -> Report: """Get existing report or create a new one.""" @@ -203,12 +344,22 @@ def _get_or_create_report( report = Report( id=report_id, label=label, + report_type=report_type, baseline_simulation_id=baseline_sim_id, reform_simulation_id=reform_sim_id, status=ReportStatus.PENDING, ) + from sqlalchemy.exc import IntegrityError + session.add(report) - session.commit() + try: + session.commit() + except IntegrityError: + session.rollback() + existing = session.get(Report, report_id) + if existing: + return existing + raise session.refresh(report) return report @@ -218,10 +369,21 @@ def _build_response( baseline_sim: Simulation, reform_sim: Simulation, session: Session, + region: Region | None = None, ) -> EconomicImpactResponse: """Build response from report and simulations.""" decile_impacts = None program_statistics = None + poverty_records = None + inequality_records = None + budget_summary_records = None + intra_decile_records = None + detailed_budget = None + district_impact_records = None + constituency_impact_records = None + local_authority_impact_records = None + wealth_decile_records = None + intra_wealth_decile_records = None if report.status == ReportStatus.COMPLETED: # Fetch decile impacts for this report @@ -275,6 +437,249 @@ def _build_response( for s in stats ] + # Build detailed_budget: V1-compatible per-program breakdown + # keyed by program name with baseline/reform/difference values. 
+ detailed_budget = { + s.program_name: { + "baseline": _safe_float(s.baseline_total), + "reform": _safe_float(s.reform_total), + "difference": _safe_float(s.change), + } + for s in stats + } + + # Fetch poverty records for this report + pov_rows = session.exec( + select(Poverty).where(Poverty.report_id == report.id) + ).all() + poverty_records = [ + PovertyRead( + id=p.id, + created_at=p.created_at, + simulation_id=p.simulation_id, + report_id=p.report_id, + poverty_type=p.poverty_type, + entity=p.entity, + filter_variable=p.filter_variable, + headcount=_safe_float(p.headcount), + total_population=_safe_float(p.total_population), + rate=_safe_float(p.rate), + ) + for p in pov_rows + ] + + # Fetch inequality records for this report + ineq_rows = session.exec( + select(Inequality).where(Inequality.report_id == report.id) + ).all() + inequality_records = [ + InequalityRead( + id=i.id, + created_at=i.created_at, + simulation_id=i.simulation_id, + report_id=i.report_id, + income_variable=i.income_variable, + entity=i.entity, + gini=_safe_float(i.gini), + top_10_share=_safe_float(i.top_10_share), + top_1_share=_safe_float(i.top_1_share), + bottom_50_share=_safe_float(i.bottom_50_share), + ) + for i in ineq_rows + ] + + # Fetch budget summary records for this report + budget_rows = session.exec( + select(BudgetSummary).where(BudgetSummary.report_id == report.id) + ).all() + budget_summary_records = [ + BudgetSummaryRead( + id=b.id, + created_at=b.created_at, + baseline_simulation_id=b.baseline_simulation_id, + reform_simulation_id=b.reform_simulation_id, + report_id=b.report_id, + variable_name=b.variable_name, + entity=b.entity, + baseline_total=_safe_float(b.baseline_total), + reform_total=_safe_float(b.reform_total), + change=_safe_float(b.change), + ) + for b in budget_rows + ] + + # Fetch intra-decile impact records for this report + intra_rows = session.exec( + select(IntraDecileImpact).where(IntraDecileImpact.report_id == report.id) + ).all() + 
intra_decile_records = [ + IntraDecileImpactRead( + id=r.id, + created_at=r.created_at, + baseline_simulation_id=r.baseline_simulation_id, + reform_simulation_id=r.reform_simulation_id, + report_id=r.report_id, + decile=r.decile, + lose_more_than_5pct=_safe_float(r.lose_more_than_5pct), + lose_less_than_5pct=_safe_float(r.lose_less_than_5pct), + no_change=_safe_float(r.no_change), + gain_less_than_5pct=_safe_float(r.gain_less_than_5pct), + gain_more_than_5pct=_safe_float(r.gain_more_than_5pct), + ) + for r in intra_rows + ] + + # Fetch congressional district impact records for this report + district_rows = session.exec( + select(CongressionalDistrictImpact).where( + CongressionalDistrictImpact.report_id == report.id + ) + ).all() + if district_rows: + district_impact_records = [ + CongressionalDistrictImpactRead( + id=d.id, + created_at=d.created_at, + baseline_simulation_id=d.baseline_simulation_id, + reform_simulation_id=d.reform_simulation_id, + report_id=d.report_id, + district_geoid=d.district_geoid, + state_fips=d.state_fips, + district_number=d.district_number, + average_household_income_change=_safe_float( + d.average_household_income_change + ), + relative_household_income_change=_safe_float( + d.relative_household_income_change + ), + population=_safe_float(d.population), + ) + for d in district_rows + ] + + # Fetch constituency impact records for this report + constituency_rows = session.exec( + select(ConstituencyImpact).where(ConstituencyImpact.report_id == report.id) + ).all() + if constituency_rows: + constituency_impact_records = [ + ConstituencyImpactRead( + id=c.id, + created_at=c.created_at, + baseline_simulation_id=c.baseline_simulation_id, + reform_simulation_id=c.reform_simulation_id, + report_id=c.report_id, + constituency_code=c.constituency_code, + constituency_name=c.constituency_name, + x=c.x, + y=c.y, + average_household_income_change=_safe_float( + c.average_household_income_change + ), + relative_household_income_change=_safe_float( + 
c.relative_household_income_change + ), + population=_safe_float(c.population), + ) + for c in constituency_rows + ] + + # Fetch local authority impact records for this report + la_rows = session.exec( + select(LocalAuthorityImpact).where( + LocalAuthorityImpact.report_id == report.id + ) + ).all() + if la_rows: + local_authority_impact_records = [ + LocalAuthorityImpactRead( + id=la.id, + created_at=la.created_at, + baseline_simulation_id=la.baseline_simulation_id, + reform_simulation_id=la.reform_simulation_id, + report_id=la.report_id, + local_authority_code=la.local_authority_code, + local_authority_name=la.local_authority_name, + x=la.x, + y=la.y, + average_household_income_change=_safe_float( + la.average_household_income_change + ), + relative_household_income_change=_safe_float( + la.relative_household_income_change + ), + population=_safe_float(la.population), + ) + for la in la_rows + ] + + # Fetch wealth decile impact records (UK only) + wealth_decile_rows = session.exec( + select(DecileImpact).where( + DecileImpact.report_id == report.id, + DecileImpact.income_variable == "household_wealth_decile", + ) + ).all() + if wealth_decile_rows: + wealth_decile_records = [ + DecileImpactRead( + id=d.id, + created_at=d.created_at, + baseline_simulation_id=d.baseline_simulation_id, + reform_simulation_id=d.reform_simulation_id, + report_id=d.report_id, + income_variable=d.income_variable, + entity=d.entity, + decile=d.decile, + quantiles=d.quantiles, + baseline_mean=_safe_float(d.baseline_mean), + reform_mean=_safe_float(d.reform_mean), + absolute_change=_safe_float(d.absolute_change), + relative_change=_safe_float(d.relative_change), + count_better_off=_safe_float(d.count_better_off), + count_worse_off=_safe_float(d.count_worse_off), + count_no_change=_safe_float(d.count_no_change), + ) + for d in wealth_decile_rows + ] + + # Fetch intra-wealth-decile records (UK only) + intra_wealth_rows = session.exec( + select(IntraDecileImpact).where( + 
IntraDecileImpact.report_id == report.id, + IntraDecileImpact.decile_type == "wealth", + ) + ).all() + if intra_wealth_rows: + intra_wealth_decile_records = [ + IntraDecileImpactRead( + id=r.id, + created_at=r.created_at, + baseline_simulation_id=r.baseline_simulation_id, + reform_simulation_id=r.reform_simulation_id, + report_id=r.report_id, + decile_type=r.decile_type, + decile=r.decile, + lose_more_than_5pct=_safe_float(r.lose_more_than_5pct), + lose_less_than_5pct=_safe_float(r.lose_less_than_5pct), + no_change=_safe_float(r.no_change), + gain_less_than_5pct=_safe_float(r.gain_less_than_5pct), + gain_more_than_5pct=_safe_float(r.gain_more_than_5pct), + ) + for r in intra_wealth_rows + ] + + region_info = None + if region: + region_info = RegionInfo( + code=region.code, + label=region.label, + region_type=region.region_type, + requires_filter=region.requires_filter, + filter_field=region.filter_field, + filter_value=region.filter_value, + ) + return EconomicImpactResponse( report_id=report.id, status=report.status, @@ -288,9 +693,20 @@ def _build_response( status=reform_sim.status, error_message=reform_sim.error_message, ), + region=region_info, error_message=report.error_message, decile_impacts=decile_impacts, program_statistics=program_statistics, + poverty=poverty_records, + inequality=inequality_records, + budget_summary=budget_summary_records, + intra_decile=intra_decile_records, + detailed_budget=detailed_budget, + congressional_district_impact=district_impact_records, + constituency_impact=constituency_impact_records, + local_authority_impact=local_authority_impact_records, + wealth_decile=wealth_decile_records, + intra_wealth_decile=intra_wealth_decile_records, ) @@ -318,7 +734,9 @@ def _download_dataset_local(filepath: str) -> str: return str(cache_path) -def _run_local_economy_comparison_uk(job_id: str, session: Session) -> None: +def _run_local_economy_comparison_uk( + job_id: str, session: Session, modules: list[str] | None = None +) -> None: """Run 
UK economy comparison analysis locally.""" from datetime import datetime, timezone from uuid import UUID @@ -327,12 +745,8 @@ def _run_local_economy_comparison_uk(job_id: str, session: Session) -> None: from policyengine.core.dynamic import Dynamic as PEDynamic from policyengine.core.policy import ParameterValue as PEParameterValue from policyengine.core.policy import Policy as PEPolicy - from policyengine.outputs import DecileImpact as PEDecileImpact from policyengine.tax_benefit_models.uk import uk_latest from policyengine.tax_benefit_models.uk.datasets import PolicyEngineUKDataset - from policyengine.tax_benefit_models.uk.outputs import ( - ProgrammeStatistics as PEProgrammeStats, - ) from policyengine_api.models import Policy as DBPolicy @@ -365,7 +779,7 @@ def build_policy(policy_id): return None db_policy = session.get(DBPolicy, policy_id) if not db_policy: - return None + raise ValueError(f"Policy {policy_id} not found in database") pe_param_values = [] for pv in db_policy.parameter_values: if not pv.parameter: @@ -432,12 +846,14 @@ def build_dynamic(dynamic_id): year=dataset.year, ) - # Run simulations + # Run simulations (with optional regional filtering) pe_baseline_sim = PESimulation( dataset=pe_dataset, tax_benefit_model_version=pe_model_version, policy=baseline_policy, dynamic=baseline_dynamic, + filter_field=baseline_sim.filter_field, + filter_value=baseline_sim.filter_value, ) pe_baseline_sim.ensure() @@ -446,71 +862,185 @@ def build_dynamic(dynamic_id): tax_benefit_model_version=pe_model_version, policy=reform_policy, dynamic=reform_dynamic, + filter_field=reform_sim.filter_field, + filter_value=reform_sim.filter_value, ) pe_reform_sim.ensure() - # Calculate decile impacts - for decile_num in range(1, 11): - di = PEDecileImpact( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - decile=decile_num, - ) - di.run() - decile_impact = DecileImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - 
report_id=report.id, - income_variable=di.income_variable, - entity=di.entity, - decile=di.decile, - quantiles=di.quantiles, - baseline_mean=di.baseline_mean, - reform_mean=di.reform_mean, - absolute_change=di.absolute_change, - relative_change=di.relative_change, - count_better_off=di.count_better_off, - count_worse_off=di.count_worse_off, - count_no_change=di.count_no_change, - ) - session.add(decile_impact) - - # Calculate program statistics - PEProgrammeStats.model_rebuild(_types_namespace={"Simulation": PESimulation}) - programmes = { - "income_tax": {"entity": "person", "is_tax": True}, - "national_insurance": {"entity": "person", "is_tax": True}, - "universal_credit": {"entity": "person", "is_tax": False}, - "child_benefit": {"entity": "person", "is_tax": False}, - } - for prog_name, prog_info in programmes.items(): - try: - ps = PEProgrammeStats( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - programme_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], + # Run computation modules + from policyengine_api.api.computation_modules import UK_MODULE_DISPATCH, run_modules + + run_modules( + dispatch=UK_MODULE_DISPATCH, + modules=modules, + pe_baseline_sim=pe_baseline_sim, + pe_reform_sim=pe_reform_sim, + baseline_sim_id=baseline_sim.id, + reform_sim_id=reform_sim.id, + report_id=report.id, + session=session, + country_id="uk", + ) + + # Mark completed + baseline_sim.status = SimulationStatus.COMPLETED + baseline_sim.completed_at = datetime.now(timezone.utc) + reform_sim.status = SimulationStatus.COMPLETED + reform_sim.completed_at = datetime.now(timezone.utc) + report.status = ReportStatus.COMPLETED + session.add(baseline_sim) + session.add(reform_sim) + session.add(report) + session.commit() + + +def _run_local_economy_comparison_us( + job_id: str, session: Session, modules: list[str] | None = None +) -> None: + """Run US economy comparison analysis locally.""" + from datetime import datetime, timezone + from 
uuid import UUID + + from policyengine.core import Simulation as PESimulation + from policyengine.core.dynamic import Dynamic as PEDynamic + from policyengine.core.policy import ParameterValue as PEParameterValue + from policyengine.core.policy import Policy as PEPolicy + from policyengine.tax_benefit_models.us import us_latest + from policyengine.tax_benefit_models.us.datasets import PolicyEngineUSDataset + + from policyengine_api.models import Policy as DBPolicy + + # Load report and simulations + report = session.get(Report, UUID(job_id)) + if not report: + raise ValueError(f"Report {job_id} not found") + + baseline_sim = session.get(Simulation, report.baseline_simulation_id) + reform_sim = session.get(Simulation, report.reform_simulation_id) + + if not baseline_sim or not reform_sim: + raise ValueError("Simulations not found") + + # Update status to running + report.status = ReportStatus.RUNNING + session.add(report) + session.commit() + + # Get dataset + dataset = session.get(Dataset, baseline_sim.dataset_id) + if not dataset: + raise ValueError(f"Dataset {baseline_sim.dataset_id} not found") + + pe_model_version = us_latest + param_lookup = {p.name: p for p in pe_model_version.parameters} + + def build_policy(policy_id): + if not policy_id: + return None + db_policy = session.get(DBPolicy, policy_id) + if not db_policy: + raise ValueError(f"Policy {policy_id} not found in database") + pe_param_values = [] + for pv in db_policy.parameter_values: + if not pv.parameter: + continue + pe_param = param_lookup.get(pv.parameter.name) + if not pe_param: + continue + pe_pv = PEParameterValue( + parameter=pe_param, + value=pv.value_json.get("value") + if isinstance(pv.value_json, dict) + else pv.value_json, + start_date=pv.start_date, + end_date=pv.end_date, ) - ps.run() - program_stat = ProgramStatistics( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - program_name=prog_name, - entity=prog_info["entity"], - 
is_tax=prog_info["is_tax"], - baseline_total=ps.baseline_total, - reform_total=ps.reform_total, - change=ps.change, - baseline_count=ps.baseline_count, - reform_count=ps.reform_count, - winners=ps.winners, - losers=ps.losers, + pe_param_values.append(pe_pv) + return PEPolicy( + name=db_policy.name, + description=db_policy.description, + parameter_values=pe_param_values, + ) + + def build_dynamic(dynamic_id): + if not dynamic_id: + return None + from policyengine_api.models import Dynamic as DBDynamic + + db_dynamic = session.get(DBDynamic, dynamic_id) + if not db_dynamic: + return None + pe_param_values = [] + for pv in db_dynamic.parameter_values: + if not pv.parameter: + continue + pe_param = param_lookup.get(pv.parameter.name) + if not pe_param: + continue + pe_pv = PEParameterValue( + parameter=pe_param, + value=pv.value_json.get("value") + if isinstance(pv.value_json, dict) + else pv.value_json, + start_date=pv.start_date, + end_date=pv.end_date, ) - session.add(program_stat) - except KeyError: - pass # Variable not found in model + pe_param_values.append(pe_pv) + return PEDynamic( + name=db_dynamic.name, + description=db_dynamic.description, + parameter_values=pe_param_values, + ) + + baseline_policy = build_policy(baseline_sim.policy_id) + reform_policy = build_policy(reform_sim.policy_id) + baseline_dynamic = build_dynamic(baseline_sim.dynamic_id) + reform_dynamic = build_dynamic(reform_sim.dynamic_id) + + # Download dataset + local_path = _download_dataset_local(dataset.filepath) + pe_dataset = PolicyEngineUSDataset( + name=dataset.name, + description=dataset.description or "", + filepath=local_path, + year=dataset.year, + ) + + # Run simulations (with optional regional filtering) + pe_baseline_sim = PESimulation( + dataset=pe_dataset, + tax_benefit_model_version=pe_model_version, + policy=baseline_policy, + dynamic=baseline_dynamic, + filter_field=baseline_sim.filter_field, + filter_value=baseline_sim.filter_value, + ) + pe_baseline_sim.ensure() + + 
pe_reform_sim = PESimulation( + dataset=pe_dataset, + tax_benefit_model_version=pe_model_version, + policy=reform_policy, + dynamic=reform_dynamic, + filter_field=reform_sim.filter_field, + filter_value=reform_sim.filter_value, + ) + pe_reform_sim.ensure() + + # Run computation modules + from policyengine_api.api.computation_modules import US_MODULE_DISPATCH, run_modules + + run_modules( + dispatch=US_MODULE_DISPATCH, + modules=modules, + pe_baseline_sim=pe_baseline_sim, + pe_reform_sim=pe_reform_sim, + baseline_sim_id=baseline_sim.id, + reform_sim_id=reform_sim.id, + report_id=report.id, + session=session, + country_id="us", + ) # Mark completed baseline_sim.status = SimulationStatus.COMPLETED @@ -525,9 +1055,16 @@ def build_dynamic(dynamic_id): def _trigger_economy_comparison( - job_id: str, tax_benefit_model_name: str, session: Session | None = None + job_id: str, + tax_benefit_model_name: str, + session: Session | None = None, + modules: list[str] | None = None, ) -> None: - """Trigger economy comparison analysis (local or Modal).""" + """Trigger economy comparison analysis (local or Modal). + + Args: + modules: Optional list of module names to run. If None, runs all. 
+ """ from policyengine_api.config import settings traceparent = get_traceparent() @@ -535,23 +1072,107 @@ def _trigger_economy_comparison( if not settings.agent_use_modal and session is not None: # Run locally if tax_benefit_model_name == "policyengine_uk": - _run_local_economy_comparison_uk(job_id, session) + _run_local_economy_comparison_uk(job_id, session, modules=modules) else: - # US not implemented for local yet - fall back to Modal - import modal - - fn = modal.Function.from_name("policyengine", "economy_comparison_us") - fn.spawn(job_id=job_id, traceparent=traceparent) + _run_local_economy_comparison_us(job_id, session, modules=modules) else: - # Use Modal + # Use Modal (modules param passed for future selective computation) import modal if tax_benefit_model_name == "policyengine_uk": - fn = modal.Function.from_name("policyengine", "economy_comparison_uk") + fn = modal.Function.from_name( + "policyengine", + "economy_comparison_uk", + environment_name=settings.modal_environment, + ) + else: + fn = modal.Function.from_name( + "policyengine", + "economy_comparison_us", + environment_name=settings.modal_environment, + ) + + try: + fn.spawn(job_id=job_id, traceparent=traceparent) + except Exception as e: + # Mark report as FAILED so it doesn't stay PENDING forever + if session is not None: + from uuid import UUID + + report = session.get(Report, UUID(job_id)) + if report: + report.status = ReportStatus.FAILED + report.error_message = f"Failed to trigger computation: {e}" + session.add(report) + session.commit() + raise HTTPException( + status_code=502, + detail=f"Failed to trigger computation: {e}", + ) + + +def _resolve_dataset_and_region( + request: EconomicImpactRequest, + session: Session, +) -> tuple[Dataset, Region | None]: + """Resolve dataset from request, optionally via region lookup. + + When a region is provided, the dataset is resolved from the region_datasets + join table. 
If request.year is set, the dataset for that year is selected; + otherwise the latest available year is used. + + Returns: + Tuple of (dataset, region) where region is None if dataset_id was provided directly. + """ + if request.region: + # Look up region by code + model_name = request.tax_benefit_model_name.replace("_", "-") + region = session.exec( + select(Region) + .join(TaxBenefitModel) + .where(Region.code == request.region) + .where(TaxBenefitModel.name == model_name) + ).first() + + if not region: + raise HTTPException( + status_code=404, + detail=f"Region '{request.region}' not found for model {model_name}", + ) + + # Resolve dataset from join table, filtered by year if provided + query = ( + select(Dataset) + .join(RegionDatasetLink) + .where(RegionDatasetLink.region_id == region.id) + ) + if request.year: + query = query.where(Dataset.year == request.year) else: - fn = modal.Function.from_name("policyengine", "economy_comparison_us") + query = query.order_by(Dataset.year.desc()) # type: ignore + dataset = session.exec(query).first() + + if not dataset: + year_msg = f" for year {request.year}" if request.year else "" + raise HTTPException( + status_code=404, + detail=f"No dataset found for region '{request.region}'{year_msg}", + ) + return dataset, region + + elif request.dataset_id: + dataset = session.get(Dataset, request.dataset_id) + if not dataset: + raise HTTPException( + status_code=404, detail=f"Dataset {request.dataset_id} not found" + ) + return dataset, None - fn.spawn(job_id=job_id, traceparent=traceparent) + else: + raise HTTPException( + status_code=400, + detail="Either dataset_id or region must be provided", + ) @router.post("/economic-impact", response_model=EconomicImpactResponse) @@ -567,32 +1188,46 @@ def economic_impact( Results include decile impacts (income changes by income group) and program statistics (budgetary effects of tax/benefit programs). 
+ + You can specify the geographic scope either by: + - dataset_id: Direct dataset reference + - region: Region code (e.g., "state/ca", "us") which resolves to a dataset """ - # Validate dataset exists - dataset = session.get(Dataset, request.dataset_id) - if not dataset: - raise HTTPException( - status_code=404, detail=f"Dataset {request.dataset_id} not found" - ) + # Resolve dataset (and optionally region) + dataset, region = _resolve_dataset_and_region(request, session) + + # Extract filter parameters from region (if present) + filter_field = region.filter_field if region and region.requires_filter else None + filter_value = region.filter_value if region and region.requires_filter else None # Get model version model_version = _get_model_version(request.tax_benefit_model_name, session) - # Get or create simulations + # Get or create simulations using the resolved dataset baseline_sim = _get_or_create_simulation( - dataset_id=request.dataset_id, + simulation_type=SimulationType.ECONOMY, model_version_id=model_version.id, policy_id=None, dynamic_id=request.dynamic_id, session=session, + dataset_id=dataset.id, + filter_field=filter_field, + filter_value=filter_value, + region_id=region.id if region else None, + year=dataset.year, ) reform_sim = _get_or_create_simulation( - dataset_id=request.dataset_id, + simulation_type=SimulationType.ECONOMY, model_version_id=model_version.id, policy_id=request.policy_id, dynamic_id=request.dynamic_id, session=session, + dataset_id=dataset.id, + filter_field=filter_field, + filter_value=filter_value, + region_id=region.id if region else None, + year=dataset.year, ) # Get or create report @@ -600,7 +1235,9 @@ def economic_impact( if request.policy_id: label += f" (policy {request.policy_id})" - report = _get_or_create_report(baseline_sim.id, reform_sim.id, label, session) + report = _get_or_create_report( + baseline_sim.id, reform_sim.id, label, "economy_comparison", session + ) # Trigger computation if report is pending if 
report.status == ReportStatus.PENDING: @@ -609,7 +1246,7 @@ def economic_impact( str(report.id), request.tax_benefit_model_name, session ) - return _build_response(report, baseline_sim, reform_sim, session) + return _build_response(report, baseline_sim, reform_sim, session, region) @router.get("/economic-impact/{report_id}", response_model=EconomicImpactResponse) @@ -631,4 +1268,324 @@ def get_economic_impact_status( if not baseline_sim or not reform_sim: raise HTTPException(status_code=500, detail="Simulation data missing") - return _build_response(report, baseline_sim, reform_sim, session) + region = ( + session.get(Region, baseline_sim.region_id) if baseline_sim.region_id else None + ) + return _build_response(report, baseline_sim, reform_sim, session, region) + + +# --------------------------------------------------------------------------- +# POST /analysis/economy-custom — run selected economy modules +# --------------------------------------------------------------------------- + +_MODEL_TO_COUNTRY = { + "policyengine_uk": "uk", + "policyengine_us": "us", +} + + +class EconomyCustomRequest(BaseModel): + """Request body for custom economy analysis with selected modules.""" + + tax_benefit_model_name: Literal["policyengine_uk", "policyengine_us"] = Field( + description="Which country model to use" + ) + dataset_id: UUID | None = Field( + default=None, + description="Dataset ID. Either dataset_id or region must be provided.", + ) + region: str | None = Field( + default=None, + description="Region code (e.g., 'state/ca', 'us').", + ) + policy_id: UUID | None = Field( + default=None, + description="Reform policy ID to compare against baseline (current law)", + ) + dynamic_id: UUID | None = Field( + default=None, description="Optional behavioural response specification ID" + ) + year: int | None = Field( + default=None, + description="Year for the analysis. 
Uses latest available if omitted.", + ) + modules: list[str] = Field( + description="List of module names to compute (see GET /analysis/options)" + ) + + @model_validator(mode="after") + def check_dataset_or_region(self) -> "EconomyCustomRequest": + if not self.dataset_id and not self.region: + raise ValueError("Either dataset_id or region must be provided") + return self + + +def _build_filtered_response( + full_response: EconomicImpactResponse, + modules: list[str], +) -> EconomicImpactResponse: + """Return a copy of the response with only the fields for requested modules.""" + allowed_fields: set[str] = set() + for name in modules: + module = MODULE_REGISTRY.get(name) + if module: + allowed_fields.update(module.response_fields) + + # Fields that are always included regardless of modules + always_included = { + "report_id", + "status", + "baseline_simulation", + "reform_simulation", + "region", + "error_message", + } + + filtered = {} + for field_name in EconomicImpactResponse.model_fields: + value = getattr(full_response, field_name) + if field_name in always_included: + filtered[field_name] = value + elif field_name in allowed_fields: + filtered[field_name] = value + else: + filtered[field_name] = None + + return EconomicImpactResponse.model_construct(**filtered) + + +@router.post("/economy-custom", response_model=EconomicImpactResponse) +def economy_custom( + request: EconomyCustomRequest, + session: Session = Depends(get_session), +) -> EconomicImpactResponse: + """Run economy-wide analysis with only the selected modules. + + Same async pattern as /analysis/economic-impact but accepts a list of + module names. Only the requested modules' response fields are populated; + the rest are null. + + See GET /analysis/options for available module names. 
+ """ + country = _MODEL_TO_COUNTRY[request.tax_benefit_model_name] + + try: + validate_modules(request.modules, country) + except ValueError as exc: + raise HTTPException(status_code=422, detail=str(exc)) + + # Reuse the same request model for dataset/region resolution + impact_request = EconomicImpactRequest( + tax_benefit_model_name=request.tax_benefit_model_name, + dataset_id=request.dataset_id, + region=request.region, + policy_id=request.policy_id, + dynamic_id=request.dynamic_id, + year=request.year, + ) + + dataset, region_obj = _resolve_dataset_and_region(impact_request, session) + + filter_field = ( + region_obj.filter_field if region_obj and region_obj.requires_filter else None + ) + filter_value = ( + region_obj.filter_value if region_obj and region_obj.requires_filter else None + ) + + model_version = _get_model_version(request.tax_benefit_model_name, session) + + baseline_sim = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=request.dynamic_id, + session=session, + dataset_id=dataset.id, + filter_field=filter_field, + filter_value=filter_value, + region_id=region_obj.id if region_obj else None, + year=dataset.year, + ) + + reform_sim = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + model_version_id=model_version.id, + policy_id=request.policy_id, + dynamic_id=request.dynamic_id, + session=session, + dataset_id=dataset.id, + filter_field=filter_field, + filter_value=filter_value, + region_id=region_obj.id if region_obj else None, + year=dataset.year, + ) + + label = f"Custom analysis: {request.tax_benefit_model_name}" + if request.policy_id: + label += f" (policy {request.policy_id})" + + report = _get_or_create_report( + baseline_sim.id, reform_sim.id, label, "economy_comparison", session + ) + + if report.status == ReportStatus.PENDING: + with logfire.span("trigger_economy_comparison", job_id=str(report.id)): + _trigger_economy_comparison( 
+ str(report.id), + request.tax_benefit_model_name, + session, + modules=request.modules, + ) + + full_response = _build_response( + report, baseline_sim, reform_sim, session, region_obj + ) + return _build_filtered_response(full_response, request.modules) + + +@router.get("/economy-custom/{report_id}", response_model=EconomicImpactResponse) +def get_economy_custom_status( + report_id: UUID, + modules: str | None = None, + session: Session = Depends(get_session), +) -> EconomicImpactResponse: + """Poll for results of custom economy analysis. + + Args: + report_id: The report ID returned by POST /analysis/economy-custom. + modules: Optional comma-separated module names to filter the response. + If omitted, all computed fields are returned. + """ + report = session.get(Report, report_id) + if not report: + raise HTTPException(status_code=404, detail=f"Report {report_id} not found") + + if not report.baseline_simulation_id or not report.reform_simulation_id: + raise HTTPException(status_code=500, detail="Report missing simulation IDs") + + baseline_sim = session.get(Simulation, report.baseline_simulation_id) + reform_sim = session.get(Simulation, report.reform_simulation_id) + + if not baseline_sim or not reform_sim: + raise HTTPException(status_code=500, detail="Simulation data missing") + + region = ( + session.get(Region, baseline_sim.region_id) if baseline_sim.region_id else None + ) + full_response = _build_response(report, baseline_sim, reform_sim, session, region) + + if modules: + module_list = [m.strip() for m in modules.split(",")] + return _build_filtered_response(full_response, module_list) + + return full_response + + +# --------------------------------------------------------------------------- +# POST /analysis/rerun/{report_id} — force-rerun a report +# --------------------------------------------------------------------------- + + +class RerunResponse(BaseModel): + """Response from the rerun endpoint.""" + + report_id: str + status: str + + 
+@router.post("/rerun/{report_id}", response_model=RerunResponse) +def rerun_report( + report_id: UUID, + session: Session = Depends(get_session), +) -> RerunResponse: + """Force-rerun a report from scratch. + + Resets the report and its simulations to PENDING, deletes all + previously computed result records, and re-triggers computation. + Works for both economy and household reports. + """ + from sqlmodel import delete + + from policyengine_api.api.household_analysis import _trigger_household_impact + + # 1. Load report + report = session.get(Report, report_id) + if not report: + raise HTTPException(status_code=404, detail=f"Report {report_id} not found") + + # 2. Load simulations + baseline_sim = ( + session.get(Simulation, report.baseline_simulation_id) + if report.baseline_simulation_id + else None + ) + reform_sim = ( + session.get(Simulation, report.reform_simulation_id) + if report.reform_simulation_id + else None + ) + + if not baseline_sim: + raise HTTPException(status_code=400, detail="Report has no baseline simulation") + + # 3. Derive tax_benefit_model_name from simulation → model version → model + model_version = session.get( + TaxBenefitModelVersion, baseline_sim.tax_benefit_model_version_id + ) + if not model_version: + raise HTTPException(status_code=500, detail="Model version not found") + + model = session.get(TaxBenefitModel, model_version.model_id) + if not model: + raise HTTPException(status_code=500, detail="Tax-benefit model not found") + + tax_benefit_model_name = model.name.replace("-", "_") + + # 4. Delete all result records for this report + result_tables = [ + DecileImpact, + ProgramStatistics, + Poverty, + Inequality, + BudgetSummary, + IntraDecileImpact, + CongressionalDistrictImpact, + ConstituencyImpact, + LocalAuthorityImpact, + ] + for table in result_tables: + session.exec(delete(table).where(table.report_id == report_id)) + + # 5. 
Reset report status + report.status = ReportStatus.PENDING + report.error_message = None + session.add(report) + + # 6. Reset simulation statuses + for sim in [baseline_sim, reform_sim]: + if sim: + sim.status = SimulationStatus.PENDING + sim.error_message = None + sim.completed_at = None + session.add(sim) + + session.commit() + + # 7. Trigger computation based on report type + is_economy = report.report_type and "economy" in report.report_type + is_household = report.report_type and "household" in report.report_type + + if is_economy: + with logfire.span("rerun_economy_comparison", job_id=str(report.id)): + _trigger_economy_comparison(str(report.id), tax_benefit_model_name, session) + elif is_household: + with logfire.span("rerun_household_impact", job_id=str(report.id)): + _trigger_household_impact(str(report.id), tax_benefit_model_name, session) + else: + raise HTTPException( + status_code=400, + detail=f"Unknown report type: {report.report_type}", + ) + + return RerunResponse(report_id=str(report_id), status="pending") diff --git a/src/policyengine_api/api/change_aggregates.py b/src/policyengine_api/api/change_aggregates.py index f706939..56fc62b 100644 --- a/src/policyengine_api/api/change_aggregates.py +++ b/src/policyengine_api/api/change_aggregates.py @@ -13,6 +13,7 @@ from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException from sqlmodel import Session, select +from policyengine_api.config import settings from policyengine_api.models import ( ChangeAggregate, ChangeAggregateCreate, @@ -71,9 +72,17 @@ def _trigger_change_aggregate_computation( traceparent = _get_traceparent() if "uk" in model.name.lower(): - fn = modal.Function.from_name("policyengine", "compute_change_aggregate_uk") + fn = modal.Function.from_name( + "policyengine", + "compute_change_aggregate_uk", + environment_name=settings.modal_environment, + ) else: - fn = modal.Function.from_name("policyengine", "compute_change_aggregate_us") + fn = modal.Function.from_name( + 
"policyengine", + "compute_change_aggregate_us", + environment_name=settings.modal_environment, + ) fn.spawn(change_aggregate_id=change_aggregate_id, traceparent=traceparent) logfire.info( diff --git a/src/policyengine_api/api/computation_modules.py b/src/policyengine_api/api/computation_modules.py new file mode 100644 index 0000000..4c6c52d --- /dev/null +++ b/src/policyengine_api/api/computation_modules.py @@ -0,0 +1,817 @@ +"""Composable computation module functions for economy analysis. + +Each function computes a single module's results and writes DB records. +They share a common signature pattern: + (pe_baseline_sim, pe_reform_sim, baseline_sim_id, reform_sim_id, + report_id, session, **kwargs) -> None + +run_modules() passes country_id as a kwarg. Modules that need it (e.g. +compute_decile_module) accept it explicitly; others accept **_kwargs. + +Used by _run_local_economy_comparison_uk/us to run modules selectively. +""" + +from __future__ import annotations + +from uuid import UUID + +from sqlmodel import Session + +from policyengine_api.models import ( + BudgetSummary, + CongressionalDistrictImpact, + ConstituencyImpact, + DecileImpact, + Inequality, + IntraDecileImpact, + LocalAuthorityImpact, + Poverty, + ProgramStatistics, +) + +# --------------------------------------------------------------------------- +# Shared modules (UK + US) +# --------------------------------------------------------------------------- + + +DECILE_INCOME_VARIABLE: dict[str, str] = { + "us": "household_net_income", + "uk": "equiv_household_net_income", +} + + +def compute_decile_module( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, + country_id: str = "", +) -> None: + """Compute income decile impacts (1-10).""" + from policyengine.outputs import DecileImpact as PEDecileImpact + + if country_id not in DECILE_INCOME_VARIABLE: + raise ValueError( + f"No decile income variable configured for country 
'{country_id}'" + ) + + income_variable = DECILE_INCOME_VARIABLE[country_id] + + for decile_num in range(1, 11): + di = PEDecileImpact( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + decile=decile_num, + income_variable=income_variable, + ) + di.run() + record = DecileImpact( + baseline_simulation_id=baseline_sim_id, + reform_simulation_id=reform_sim_id, + report_id=report_id, + income_variable=di.income_variable, + entity=di.entity, + decile=di.decile, + quantiles=di.quantiles, + baseline_mean=di.baseline_mean, + reform_mean=di.reform_mean, + absolute_change=di.absolute_change, + relative_change=di.relative_change, + count_better_off=di.count_better_off, + count_worse_off=di.count_worse_off, + count_no_change=di.count_no_change, + ) + session.add(record) + + +def compute_intra_decile_module( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, + **_kwargs, +) -> None: + """Compute intra-decile income change distribution (5 bands).""" + from policyengine.outputs.intra_decile_impact import ( + compute_intra_decile_impacts as pe_compute_intra_decile, + ) + + results = pe_compute_intra_decile( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + income_variable="household_net_income", + entity="household", + ) + for r in results.outputs: + record = IntraDecileImpact( + baseline_simulation_id=baseline_sim_id, + reform_simulation_id=reform_sim_id, + report_id=report_id, + decile=r.decile, + lose_more_than_5pct=r.lose_more_than_5pct, + lose_less_than_5pct=r.lose_less_than_5pct, + no_change=r.no_change, + gain_less_than_5pct=r.gain_less_than_5pct, + gain_more_than_5pct=r.gain_more_than_5pct, + ) + session.add(record) + + +# --------------------------------------------------------------------------- +# UK-specific modules +# --------------------------------------------------------------------------- + + +def compute_program_statistics_module_uk( + 
pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, + **_kwargs, +) -> None: + """Compute UK programme statistics.""" + from policyengine.core import Simulation as PESimulation + from policyengine.tax_benefit_models.uk.outputs import ( + ProgrammeStatistics as PEProgrammeStats, + ) + + PEProgrammeStats.model_rebuild(_types_namespace={"Simulation": PESimulation}) + programmes = { + "income_tax": {"entity": "person", "is_tax": True}, + "national_insurance": {"entity": "person", "is_tax": True}, + "vat": {"entity": "household", "is_tax": True}, + "council_tax": {"entity": "household", "is_tax": True}, + "universal_credit": {"entity": "person", "is_tax": False}, + "child_benefit": {"entity": "person", "is_tax": False}, + "pension_credit": {"entity": "person", "is_tax": False}, + "income_support": {"entity": "person", "is_tax": False}, + "working_tax_credit": {"entity": "person", "is_tax": False}, + "child_tax_credit": {"entity": "person", "is_tax": False}, + } + for prog_name, prog_info in programmes.items(): + try: + ps = PEProgrammeStats( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + programme_name=prog_name, + entity=prog_info["entity"], + is_tax=prog_info["is_tax"], + ) + ps.run() + except KeyError: + import logfire + + logfire.warning(f"Program variable not found, skipping: {prog_name}") + continue + record = ProgramStatistics( + baseline_simulation_id=baseline_sim_id, + reform_simulation_id=reform_sim_id, + report_id=report_id, + program_name=prog_name, + entity=prog_info["entity"], + is_tax=prog_info["is_tax"], + baseline_total=ps.baseline_total, + reform_total=ps.reform_total, + change=ps.change, + baseline_count=ps.baseline_count, + reform_count=ps.reform_count, + winners=ps.winners, + losers=ps.losers, + ) + session.add(record) + + +def compute_poverty_module_uk( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: 
UUID, + session: Session, + **_kwargs, +) -> None: + """Compute UK poverty rates (overall, by age, by gender).""" + from policyengine.outputs.poverty import ( + calculate_uk_poverty_by_age, + calculate_uk_poverty_by_gender, + calculate_uk_poverty_rates, + ) + + sim_pairs = [ + (pe_baseline_sim, baseline_sim_id), + (pe_reform_sim, reform_sim_id), + ] + + for calculator in [ + calculate_uk_poverty_rates, + calculate_uk_poverty_by_age, + calculate_uk_poverty_by_gender, + ]: + for pe_sim, db_sim_id in sim_pairs: + results = calculator(pe_sim) + for pov in results.outputs: + record = Poverty( + simulation_id=db_sim_id, + report_id=report_id, + poverty_type=pov.poverty_type, + entity=pov.entity, + filter_variable=pov.filter_variable, + headcount=pov.headcount, + total_population=pov.total_population, + rate=pov.rate, + ) + session.add(record) + + +def compute_inequality_module_uk( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, + **_kwargs, +) -> None: + """Compute UK inequality metrics.""" + from policyengine.outputs.inequality import calculate_uk_inequality + + for pe_sim, db_sim_id in [ + (pe_baseline_sim, baseline_sim_id), + (pe_reform_sim, reform_sim_id), + ]: + ineq = calculate_uk_inequality(pe_sim) + ineq.run() + record = Inequality( + simulation_id=db_sim_id, + report_id=report_id, + income_variable=ineq.income_variable, + entity=ineq.entity, + gini=ineq.gini, + top_10_share=ineq.top_10_share, + top_1_share=ineq.top_1_share, + bottom_50_share=ineq.bottom_50_share, + ) + session.add(record) + + +def compute_budget_summary_module_uk( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, + **_kwargs, +) -> None: + """Compute UK budget summary aggregates.""" + from policyengine.core import Simulation as PESimulation + from policyengine.outputs.aggregate import Aggregate as PEAggregate + from policyengine.outputs.aggregate 
import AggregateType as PEAggregateType + + PEAggregate.model_rebuild(_types_namespace={"Simulation": PESimulation}) + + uk_budget_variables = { + "household_tax": "household", + "household_benefits": "household", + "household_net_income": "household", + } + for var_name, entity in uk_budget_variables.items(): + baseline_agg = PEAggregate( + simulation=pe_baseline_sim, + variable=var_name, + aggregate_type=PEAggregateType.SUM, + entity=entity, + ) + baseline_agg.run() + reform_agg = PEAggregate( + simulation=pe_reform_sim, + variable=var_name, + aggregate_type=PEAggregateType.SUM, + entity=entity, + ) + reform_agg.run() + record = BudgetSummary( + baseline_simulation_id=baseline_sim_id, + reform_simulation_id=reform_sim_id, + report_id=report_id, + variable_name=var_name, + entity=entity, + baseline_total=float(baseline_agg.result), + reform_total=float(reform_agg.result), + change=float(reform_agg.result - baseline_agg.result), + ) + session.add(record) + + # Household count: raw sum of weights (bypasses Aggregate weighting) + baseline_hh_count = float( + pe_baseline_sim.output_dataset.data.household["household_weight"].values.sum() + ) + reform_hh_count = float( + pe_reform_sim.output_dataset.data.household["household_weight"].values.sum() + ) + record = BudgetSummary( + baseline_simulation_id=baseline_sim_id, + reform_simulation_id=reform_sim_id, + report_id=report_id, + variable_name="household_count_total", + entity="household", + baseline_total=baseline_hh_count, + reform_total=reform_hh_count, + change=reform_hh_count - baseline_hh_count, + ) + session.add(record) + + +def compute_constituency_module( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, + **_kwargs, +) -> None: + """Compute UK parliamentary constituency impact.""" + from policyengine.outputs.constituency_impact import ( + compute_uk_constituency_impacts, + ) + + try: + from policyengine_core.tools.google_cloud import 
download as gcs_download + + weight_matrix_path = gcs_download( + gcs_bucket="policyengine-uk-data-private", + gcs_key="parliamentary_constituency_weights.h5", + ) + constituency_csv_path = gcs_download( + gcs_bucket="policyengine-uk-data-private", + gcs_key="constituencies_2024.csv", + ) + impact = compute_uk_constituency_impacts( + pe_baseline_sim, + pe_reform_sim, + weight_matrix_path=weight_matrix_path, + constituency_csv_path=constituency_csv_path, + ) + if impact.constituency_results: + for cr in impact.constituency_results: + record = ConstituencyImpact( + baseline_simulation_id=baseline_sim_id, + reform_simulation_id=reform_sim_id, + report_id=report_id, + constituency_code=cr["constituency_code"], + constituency_name=cr["constituency_name"], + x=cr["x"], + y=cr["y"], + average_household_income_change=cr[ + "average_household_income_change" + ], + relative_household_income_change=cr[ + "relative_household_income_change" + ], + population=cr["population"], + ) + session.add(record) + except FileNotFoundError: + import logfire + + logfire.warning("Weight matrix not available, skipping constituency impact") + + +def compute_local_authority_module( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, + **_kwargs, +) -> None: + """Compute UK local authority impact.""" + from policyengine.outputs.local_authority_impact import ( + compute_uk_local_authority_impacts, + ) + + try: + from policyengine_core.tools.google_cloud import download as gcs_download + + la_weight_matrix_path = gcs_download( + gcs_bucket="policyengine-uk-data-private", + gcs_key="local_authority_weights.h5", + ) + la_csv_path = gcs_download( + gcs_bucket="policyengine-uk-data-private", + gcs_key="local_authorities_2021.csv", + ) + impact = compute_uk_local_authority_impacts( + pe_baseline_sim, + pe_reform_sim, + weight_matrix_path=la_weight_matrix_path, + local_authority_csv_path=la_csv_path, + ) + if 
impact.local_authority_results: + for lr in impact.local_authority_results: + record = LocalAuthorityImpact( + baseline_simulation_id=baseline_sim_id, + reform_simulation_id=reform_sim_id, + report_id=report_id, + local_authority_code=lr["local_authority_code"], + local_authority_name=lr["local_authority_name"], + x=lr["x"], + y=lr["y"], + average_household_income_change=lr[ + "average_household_income_change" + ], + relative_household_income_change=lr[ + "relative_household_income_change" + ], + population=lr["population"], + ) + session.add(record) + except FileNotFoundError: + import logfire + + logfire.warning("Weight matrix not available, skipping local authority impact") + + +def compute_wealth_decile_module( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, + **_kwargs, +) -> None: + """Compute UK wealth decile impact and intra-wealth-decile breakdown.""" + from policyengine.core import Simulation as PESimulation + from policyengine.outputs.decile_impact import DecileImpact as PEDecileImpact + from policyengine.outputs.intra_decile_impact import ( + compute_intra_decile_impacts as pe_compute_intra_decile, + ) + + try: + PEDecileImpact.model_rebuild(_types_namespace={"Simulation": PESimulation}) + for decile_num in range(1, 11): + wealth_di = PEDecileImpact( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + income_variable="household_net_income", + decile_variable="household_wealth_decile", + entity="household", + decile=decile_num, + ) + wealth_di.run() + record = DecileImpact( + baseline_simulation_id=baseline_sim_id, + reform_simulation_id=reform_sim_id, + report_id=report_id, + income_variable="household_wealth_decile", + entity="household", + decile=decile_num, + quantiles=10, + baseline_mean=wealth_di.baseline_mean, + reform_mean=wealth_di.reform_mean, + absolute_change=wealth_di.absolute_change, + relative_change=wealth_di.relative_change, + ) + 
session.add(record) + + # Intra-wealth-decile + intra_wealth_results = pe_compute_intra_decile( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + income_variable="household_net_income", + decile_variable="household_wealth_decile", + entity="household", + ) + for r in intra_wealth_results.outputs: + record = IntraDecileImpact( + baseline_simulation_id=baseline_sim_id, + reform_simulation_id=reform_sim_id, + report_id=report_id, + decile_type="wealth", + decile=r.decile, + lose_more_than_5pct=r.lose_more_than_5pct, + lose_less_than_5pct=r.lose_less_than_5pct, + no_change=r.no_change, + gain_less_than_5pct=r.gain_less_than_5pct, + gain_more_than_5pct=r.gain_more_than_5pct, + ) + session.add(record) + except KeyError: + import logfire + + logfire.warning( + "household_wealth_decile not available, skipping wealth decile impact" + ) + + +# --------------------------------------------------------------------------- +# US-specific modules +# --------------------------------------------------------------------------- + + +def compute_program_statistics_module_us( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, + **_kwargs, +) -> None: + """Compute US program statistics.""" + from policyengine.core import Simulation as PESimulation + from policyengine.tax_benefit_models.us.outputs import ( + ProgramStatistics as PEProgramStats, + ) + + PEProgramStats.model_rebuild(_types_namespace={"Simulation": PESimulation}) + programs = { + "income_tax": {"entity": "tax_unit", "is_tax": True}, + "employee_payroll_tax": {"entity": "person", "is_tax": True}, + "snap": {"entity": "spm_unit", "is_tax": False}, + "tanf": {"entity": "spm_unit", "is_tax": False}, + "ssi": {"entity": "spm_unit", "is_tax": False}, + "social_security": {"entity": "person", "is_tax": False}, + } + for prog_name, prog_info in programs.items(): + try: + ps = PEProgramStats( + baseline_simulation=pe_baseline_sim, + 
reform_simulation=pe_reform_sim, + program_name=prog_name, + entity=prog_info["entity"], + is_tax=prog_info["is_tax"], + ) + ps.run() + except KeyError: + import logfire + + logfire.warning(f"Program variable not found, skipping: {prog_name}") + continue + record = ProgramStatistics( + baseline_simulation_id=baseline_sim_id, + reform_simulation_id=reform_sim_id, + report_id=report_id, + program_name=prog_name, + entity=prog_info["entity"], + is_tax=prog_info["is_tax"], + baseline_total=ps.baseline_total, + reform_total=ps.reform_total, + change=ps.change, + baseline_count=ps.baseline_count, + reform_count=ps.reform_count, + winners=ps.winners, + losers=ps.losers, + ) + session.add(record) + + +def compute_poverty_module_us( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, + **_kwargs, +) -> None: + """Compute US poverty rates (overall, by age, gender, race).""" + from policyengine.outputs.poverty import ( + calculate_us_poverty_by_age, + calculate_us_poverty_by_gender, + calculate_us_poverty_by_race, + calculate_us_poverty_rates, + ) + + sim_pairs = [ + (pe_baseline_sim, baseline_sim_id), + (pe_reform_sim, reform_sim_id), + ] + + for calculator in [ + calculate_us_poverty_rates, + calculate_us_poverty_by_age, + calculate_us_poverty_by_gender, + calculate_us_poverty_by_race, + ]: + for pe_sim, db_sim_id in sim_pairs: + results = calculator(pe_sim) + for pov in results.outputs: + record = Poverty( + simulation_id=db_sim_id, + report_id=report_id, + poverty_type=pov.poverty_type, + entity=pov.entity, + filter_variable=pov.filter_variable, + headcount=pov.headcount, + total_population=pov.total_population, + rate=pov.rate, + ) + session.add(record) + + +def compute_inequality_module_us( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, + **_kwargs, +) -> None: + """Compute US inequality metrics.""" + from 
policyengine.outputs.inequality import calculate_us_inequality + + for pe_sim, db_sim_id in [ + (pe_baseline_sim, baseline_sim_id), + (pe_reform_sim, reform_sim_id), + ]: + ineq = calculate_us_inequality(pe_sim) + ineq.run() + record = Inequality( + simulation_id=db_sim_id, + report_id=report_id, + income_variable=ineq.income_variable, + entity=ineq.entity, + gini=ineq.gini, + top_10_share=ineq.top_10_share, + top_1_share=ineq.top_1_share, + bottom_50_share=ineq.bottom_50_share, + ) + session.add(record) + + +def compute_budget_summary_module_us( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, + **_kwargs, +) -> None: + """Compute US budget summary aggregates.""" + from policyengine.core import Simulation as PESimulation + from policyengine.outputs.aggregate import Aggregate as PEAggregate + from policyengine.outputs.aggregate import AggregateType as PEAggregateType + + PEAggregate.model_rebuild(_types_namespace={"Simulation": PESimulation}) + + us_budget_variables = { + "household_tax": "household", + "household_benefits": "household", + "household_net_income": "household", + "household_state_income_tax": "tax_unit", + } + for var_name, entity in us_budget_variables.items(): + baseline_agg = PEAggregate( + simulation=pe_baseline_sim, + variable=var_name, + aggregate_type=PEAggregateType.SUM, + entity=entity, + ) + baseline_agg.run() + reform_agg = PEAggregate( + simulation=pe_reform_sim, + variable=var_name, + aggregate_type=PEAggregateType.SUM, + entity=entity, + ) + reform_agg.run() + record = BudgetSummary( + baseline_simulation_id=baseline_sim_id, + reform_simulation_id=reform_sim_id, + report_id=report_id, + variable_name=var_name, + entity=entity, + baseline_total=float(baseline_agg.result), + reform_total=float(reform_agg.result), + change=float(reform_agg.result - baseline_agg.result), + ) + session.add(record) + + # Household count: raw sum of weights + baseline_hh_count = float( + 
pe_baseline_sim.output_dataset.data.household["household_weight"].values.sum() + ) + reform_hh_count = float( + pe_reform_sim.output_dataset.data.household["household_weight"].values.sum() + ) + record = BudgetSummary( + baseline_simulation_id=baseline_sim_id, + reform_simulation_id=reform_sim_id, + report_id=report_id, + variable_name="household_count_total", + entity="household", + baseline_total=baseline_hh_count, + reform_total=reform_hh_count, + change=reform_hh_count - baseline_hh_count, + ) + session.add(record) + + +def compute_congressional_district_module( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, + **_kwargs, +) -> None: + """Compute US congressional district impact.""" + from policyengine.outputs.congressional_district_impact import ( + compute_us_congressional_district_impacts, + ) + + try: + impact = compute_us_congressional_district_impacts( + pe_baseline_sim, pe_reform_sim + ) + except KeyError: + import logfire + + logfire.warning( + "congressional_district_geoid not in dataset, skipping congressional district impact" + ) + return + if impact.district_results: + for dr in impact.district_results: + record = CongressionalDistrictImpact( + baseline_simulation_id=baseline_sim_id, + reform_simulation_id=reform_sim_id, + report_id=report_id, + district_geoid=dr["district_geoid"], + state_fips=dr["state_fips"], + district_number=dr["district_number"], + average_household_income_change=dr["average_household_income_change"], + relative_household_income_change=dr["relative_household_income_change"], + population=dr["population"], + ) + session.add(record) + + +# --------------------------------------------------------------------------- +# Dispatch tables: module name -> computation function +# --------------------------------------------------------------------------- + +# Type alias for module computation functions +ModuleFunction = type(compute_decile_module) + 
+UK_MODULE_DISPATCH: dict[str, ModuleFunction] = { + "decile": compute_decile_module, + "program_statistics": compute_program_statistics_module_uk, + "poverty": compute_poverty_module_uk, + "inequality": compute_inequality_module_uk, + "budget_summary": compute_budget_summary_module_uk, + "intra_decile": compute_intra_decile_module, + "constituency": compute_constituency_module, + "local_authority": compute_local_authority_module, + "wealth_decile": compute_wealth_decile_module, +} + +US_MODULE_DISPATCH: dict[str, ModuleFunction] = { + "decile": compute_decile_module, + "program_statistics": compute_program_statistics_module_us, + "poverty": compute_poverty_module_us, + "inequality": compute_inequality_module_us, + "budget_summary": compute_budget_summary_module_us, + "intra_decile": compute_intra_decile_module, + "congressional_district": compute_congressional_district_module, +} + + +def run_modules( + dispatch: dict[str, ModuleFunction], + modules: list[str] | None, + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, + country_id: str = "", +) -> None: + """Run the requested modules (or all if modules is None).""" + to_run = modules if modules is not None else list(dispatch.keys()) + for mod_name in to_run: + fn = dispatch.get(mod_name) + if fn: + fn( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id, + reform_sim_id, + report_id, + session, + country_id=country_id, + ) diff --git a/src/policyengine_api/api/household.py b/src/policyengine_api/api/household.py index 0e89b5e..5fda4b4 100644 --- a/src/policyengine_api/api/household.py +++ b/src/policyengine_api/api/household.py @@ -300,11 +300,13 @@ def _calculate_household_uk( from pathlib import Path import pandas as pd - from policyengine.core import Simulation from microdf import MicroDataFrame + from policyengine.core import Simulation from policyengine.tax_benefit_models.uk import uk_latest - from 
policyengine.tax_benefit_models.uk.datasets import PolicyEngineUKDataset - from policyengine.tax_benefit_models.uk.datasets import UKYearData + from policyengine.tax_benefit_models.uk.datasets import ( + PolicyEngineUKDataset, + UKYearData, + ) n_people = len(people) n_benunits = max(1, len(benunit)) @@ -466,7 +468,14 @@ def _run_local_household_us( try: result = _calculate_household_us( - people, marital_unit, family, spm_unit, tax_unit, household, year, policy_data + people, + marital_unit, + family, + spm_unit, + tax_unit, + household, + year, + policy_data, ) # Update job with result @@ -512,11 +521,13 @@ def _calculate_household_us( from pathlib import Path import pandas as pd - from policyengine.core import Simulation from microdf import MicroDataFrame + from policyengine.core import Simulation from policyengine.tax_benefit_models.us import us_latest - from policyengine.tax_benefit_models.us.datasets import PolicyEngineUSDataset - from policyengine.tax_benefit_models.us.datasets import USYearData + from policyengine.tax_benefit_models.us.datasets import ( + PolicyEngineUSDataset, + USYearData, + ) n_people = len(people) n_households = max(1, len(household)) @@ -672,7 +683,9 @@ def safe_convert(value): except (ValueError, TypeError): return str(value) - def extract_entity_outputs(entity_name: str, entity_data, n_rows: int) -> list[dict]: + def extract_entity_outputs( + entity_name: str, entity_data, n_rows: int + ) -> list[dict]: outputs = [] for i in range(n_rows): row_dict = {} @@ -743,7 +756,11 @@ def _trigger_modal_household( traceparent = get_traceparent() if request.tax_benefit_model_name == "policyengine_uk": - fn = modal.Function.from_name("policyengine", "simulate_household_uk") + fn = modal.Function.from_name( + "policyengine", + "simulate_household_uk", + environment_name=settings.modal_environment, + ) fn.spawn( job_id=job_id, people=request.people, @@ -755,7 +772,11 @@ def _trigger_modal_household( traceparent=traceparent, ) else: - fn = 
modal.Function.from_name("policyengine", "simulate_household_us") + fn = modal.Function.from_name( + "policyengine", + "simulate_household_us", + environment_name=settings.modal_environment, + ) fn.spawn( job_id=job_id, people=request.people, diff --git a/src/policyengine_api/api/household_analysis.py b/src/policyengine_api/api/household_analysis.py new file mode 100644 index 0000000..9c22234 --- /dev/null +++ b/src/policyengine_api/api/household_analysis.py @@ -0,0 +1,755 @@ +"""Household impact analysis endpoints. + +Use these endpoints to analyze household-level effects of policy reforms. +Supports single runs (current law) and comparisons (baseline vs reform). + +WORKFLOW: +1. Create a stored household: POST /households +2. Optionally create a reform policy: POST /policies +3. Run analysis: POST /analysis/household-impact (returns report_id) +4. Poll GET /analysis/household-impact/{report_id} until status="completed" +5. Results include baseline_result, reform_result (if comparison), and impact diff +""" + +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any, Protocol +from uuid import UUID + +import logfire +from fastapi import APIRouter, Depends, HTTPException +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator +from pydantic import BaseModel, Field +from sqlmodel import Session + +from policyengine_api.models import ( + Household, + Policy, + Report, + ReportStatus, + Simulation, + SimulationStatus, + SimulationType, +) +from policyengine_api.services.database import get_session + +from .analysis import ( + _get_model_version, + _get_or_create_report, + _get_or_create_simulation, +) + + +def get_traceparent() -> str | None: + """Get the current W3C traceparent header for distributed tracing.""" + carrier: dict[str, str] = {} + TraceContextTextMapPropagator().inject(carrier) + return carrier.get("traceparent") + + +router = APIRouter(prefix="/analysis", tags=["analysis"]) + + 
# =============================================================================
# Country Strategy Pattern
# =============================================================================


@dataclass(frozen=True)
class CountryConfig:
    """Immutable description of one country's household model.

    Attributes:
        name: Short country code ("uk" or "us").
        entity_types: Entity group names, in the order results are reported.
    """

    name: str
    entity_types: tuple[str, ...]


# UK households group people into benefit units and households.
UK_CONFIG = CountryConfig(
    name="uk",
    entity_types=("person", "benunit", "household"),
)

# US households carry a richer entity hierarchy (tax/SPM/family/marital units).
US_CONFIG = CountryConfig(
    name="us",
    entity_types=(
        "person",
        "tax_unit",
        "spm_unit",
        "family",
        "marital_unit",
        "household",
    ),
)


def get_country_config(tax_benefit_model_name: str) -> CountryConfig:
    """Return the country configuration for a tax-benefit model name.

    Any model name other than "policyengine_uk" is treated as US.
    """
    is_uk = tax_benefit_model_name == "policyengine_uk"
    return UK_CONFIG if is_uk else US_CONFIG


class HouseholdCalculator(Protocol):
    """Structural type for country-specific household calculators.

    A calculator maps (household_data, year, policy_data) to a result dict.
    """

    def __call__(
        self,
        household_data: dict[str, Any],
        year: int,
        policy_data: dict | None,
    ) -> dict: ...
+ + +def calculate_uk_household( + household_data: dict[str, Any], + year: int, + policy_data: dict | None, +) -> dict: + """Calculate UK household using the existing implementation.""" + from policyengine_api.api.household import _calculate_household_uk + + return _calculate_household_uk( + people=household_data.get("people", []), + benunit=_ensure_list(household_data.get("benunit")), + household=_ensure_list(household_data.get("household")), + year=year, + policy_data=policy_data, + ) + + +def calculate_us_household( + household_data: dict[str, Any], + year: int, + policy_data: dict | None, +) -> dict: + """Calculate US household using the existing implementation.""" + from policyengine_api.api.household import _calculate_household_us + + return _calculate_household_us( + people=household_data.get("people", []), + marital_unit=_ensure_list(household_data.get("marital_unit")), + family=_ensure_list(household_data.get("family")), + spm_unit=_ensure_list(household_data.get("spm_unit")), + tax_unit=_ensure_list(household_data.get("tax_unit")), + household=_ensure_list(household_data.get("household")), + year=year, + policy_data=policy_data, + ) + + +def get_calculator(tax_benefit_model_name: str) -> HouseholdCalculator: + """Get the appropriate calculator for a country.""" + if tax_benefit_model_name == "policyengine_uk": + return calculate_uk_household + return calculate_us_household + + +# ============================================================================= +# Data Transformation Helpers +# ============================================================================= + + +def _ensure_list(value: Any) -> list: + """Ensure value is a list; wrap dict in list if needed.""" + if value is None: + return [] + if isinstance(value, list): + return value + return [value] + + +def _extract_policy_data(policy: Policy | None) -> dict | None: + """Extract policy data from a Policy model into calculation format. 
+ + Returns format expected by _calculate_household_us/_calculate_household_uk: + { + "name": "policy name", + "description": "policy description", + "parameter_values": [ + { + "parameter_name": "gov.irs.credits.ctc...", + "value": 0.16, + "start_date": "2024-01-01T00:00:00+00:00", + "end_date": null + } + ] + } + """ + if not policy or not policy.parameter_values: + return None + + parameter_values = [] + for pv in policy.parameter_values: + if not pv.parameter: + continue + + parameter_values.append( + { + "parameter_name": pv.parameter.name, + "value": _extract_value(pv.value_json), + "start_date": _format_date(pv.start_date), + "end_date": _format_date(pv.end_date), + } + ) + + if not parameter_values: + return None + + return { + "name": policy.name, + "description": policy.description or "", + "parameter_values": parameter_values, + } + + +def _extract_value(value_json: Any) -> Any: + """Extract the actual value from value_json.""" + if isinstance(value_json, dict): + return value_json.get("value") + return value_json + + +def _format_date(date: Any) -> str | None: + """Format a date for the policy data structure.""" + if date is None: + return None + if hasattr(date, "isoformat"): + return date.isoformat() + return str(date) + + +# ============================================================================= +# Impact Computation +# ============================================================================= + + +def compute_variable_diff(baseline_val: Any, reform_val: Any) -> dict | None: + """Compute diff for a single variable if both are numeric.""" + if not isinstance(baseline_val, (int, float)): + return None + if not isinstance(reform_val, (int, float)): + return None + + return { + "baseline": baseline_val, + "reform": reform_val, + "change": reform_val - baseline_val, + } + + +def compute_entity_diff(baseline_entity: dict, reform_entity: dict) -> dict: + """Compute per-variable diffs for a single entity instance.""" + entity_diff = {} + + for key, 
baseline_val in baseline_entity.items(): + reform_val = reform_entity.get(key) + if reform_val is None: + continue + + diff = compute_variable_diff(baseline_val, reform_val) + if diff is not None: + entity_diff[key] = diff + + return entity_diff + + +def compute_entity_list_diff( + baseline_list: list[dict], + reform_list: list[dict], +) -> list[dict]: + """Compute diffs for a list of entity instances.""" + return [ + compute_entity_diff(b_entity, r_entity) + for b_entity, r_entity in zip(baseline_list, reform_list) + ] + + +def compute_household_impact( + baseline_result: dict, + reform_result: dict, + config: CountryConfig, +) -> dict[str, Any]: + """Compute difference between baseline and reform for all entity types.""" + impact: dict[str, Any] = {} + + for entity in config.entity_types: + baseline_entities = baseline_result.get(entity) + reform_entities = reform_result.get(entity) + + if baseline_entities is None or reform_entities is None: + continue + + impact[entity] = compute_entity_list_diff(baseline_entities, reform_entities) + + return impact + + +# ============================================================================= +# Simulation Execution +# ============================================================================= + + +def mark_simulation_running(simulation: Simulation, session: Session) -> None: + """Mark a simulation as running.""" + simulation.status = SimulationStatus.RUNNING + simulation.started_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + + +def mark_simulation_completed( + simulation: Simulation, + result: dict, + session: Session, +) -> None: + """Mark a simulation as completed with result.""" + simulation.household_result = result + simulation.status = SimulationStatus.COMPLETED + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + + +def mark_simulation_failed( + simulation: Simulation, + error: Exception, + session: Session, +) -> None: + """Mark a 
simulation as failed with error.""" + simulation.status = SimulationStatus.FAILED + simulation.error_message = str(error) + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + + +def run_household_simulation(simulation_id: UUID, session: Session) -> None: + """Run a single household simulation and store result.""" + simulation = _load_simulation(simulation_id, session) + household = _load_household(simulation.household_id, session) + policy_data = _load_policy_data(simulation.policy_id, session) + + mark_simulation_running(simulation, session) + + try: + calculator = get_calculator(household.tax_benefit_model_name) + result = calculator(household.household_data, household.year, policy_data) + mark_simulation_completed(simulation, result, session) + except Exception as e: + logfire.error( + "Household simulation failed", + simulation_id=str(simulation_id), + error=str(e), + ) + mark_simulation_failed(simulation, e, session) + + +def _load_simulation(simulation_id: UUID, session: Session) -> Simulation: + """Load simulation or raise error.""" + simulation = session.get(Simulation, simulation_id) + if not simulation: + raise ValueError(f"Simulation {simulation_id} not found") + return simulation + + +def _load_household(household_id: UUID | None, session: Session) -> Household: + """Load household or raise error.""" + if not household_id: + raise ValueError("Simulation has no household_id") + + household = session.get(Household, household_id) + if not household: + raise ValueError(f"Household {household_id} not found") + return household + + +def _load_policy_data(policy_id: UUID | None, session: Session) -> dict | None: + """Load and extract policy data if policy_id is set.""" + if not policy_id: + return None + + policy = session.get(Policy, policy_id) + return _extract_policy_data(policy) + + +# ============================================================================= +# Report Orchestration (Async) +# 
============================================================================= + + +def _run_local_household_impact(report_id: str, session: Session) -> None: + """Run household impact analysis locally. + + NOTE: This runs synchronously and blocks the HTTP request when running + locally (agent_use_modal=False). This mirrors the economic impact behavior. + True async execution requires Modal. + """ + report = session.get(Report, UUID(report_id)) + if not report: + raise ValueError(f"Report {report_id} not found for household impact") + + report.status = ReportStatus.RUNNING + session.add(report) + session.commit() + + try: + # Run baseline simulation + if report.baseline_simulation_id: + _run_simulation_in_session(report.baseline_simulation_id, session) + + # Run reform simulation if present + if report.reform_simulation_id: + _run_simulation_in_session(report.reform_simulation_id, session) + + report.status = ReportStatus.COMPLETED + session.add(report) + session.commit() + except Exception as e: + report.status = ReportStatus.FAILED + report.error_message = str(e) + session.add(report) + session.commit() + + +def _run_simulation_in_session(simulation_id: UUID, session: Session) -> None: + """Run a single household simulation within an existing session.""" + simulation = session.get(Simulation, simulation_id) + if not simulation or simulation.status != SimulationStatus.PENDING: + return + + household = session.get(Household, simulation.household_id) + if not household: + raise ValueError(f"Household {simulation.household_id} not found") + + policy_data = _load_policy_data(simulation.policy_id, session) + + simulation.status = SimulationStatus.RUNNING + simulation.started_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + + try: + calculator = get_calculator(household.tax_benefit_model_name) + result = calculator(household.household_data, household.year, policy_data) + + simulation.household_result = result + simulation.status = 
SimulationStatus.COMPLETED + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + except Exception as e: + simulation.status = SimulationStatus.FAILED + simulation.error_message = str(e) + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + raise + + +def _trigger_household_impact( + report_id: str, tax_benefit_model_name: str, session: Session | None = None +) -> None: + """Trigger household impact calculation (local or Modal based on settings).""" + from policyengine_api.config import settings + + traceparent = get_traceparent() + + if not settings.agent_use_modal and session is not None: + # Run locally (blocking - see _run_local_household_impact docstring) + _run_local_household_impact(report_id, session) + else: + # Use Modal + import modal + + if tax_benefit_model_name == "policyengine_uk": + fn = modal.Function.from_name( + "policyengine", + "household_impact_uk", + environment_name=settings.modal_environment, + ) + else: + fn = modal.Function.from_name( + "policyengine", + "household_impact_us", + environment_name=settings.modal_environment, + ) + + try: + fn.spawn(report_id=report_id, traceparent=traceparent) + except Exception as e: + # Mark report as FAILED so it doesn't stay PENDING forever + if session is not None: + report = session.get(Report, UUID(report_id)) + if report: + report.status = ReportStatus.FAILED + report.error_message = f"Failed to trigger computation: {e}" + session.add(report) + session.commit() + raise HTTPException( + status_code=502, + detail=f"Failed to trigger computation: {e}", + ) + + +# Legacy functions kept for compatibility +def _load_report(report_id: UUID, session: Session) -> Report: + """Load report or raise error.""" + report = session.get(Report, report_id) + if not report: + raise ValueError(f"Report {report_id} not found") + return report + + +# ============================================================================= 
+# Request/Response Schemas +# ============================================================================= + + +class HouseholdImpactRequest(BaseModel): + """Request for household impact analysis.""" + + household_id: UUID = Field(description="ID of the household to analyze") + policy_id: UUID | None = Field( + default=None, + description="Reform policy ID. If None, runs single calculation under current law.", + ) + dynamic_id: UUID | None = Field( + default=None, + description="Optional behavioural response specification ID", + ) + + +class HouseholdSimulationInfo(BaseModel): + """Info about a household simulation.""" + + id: UUID + status: SimulationStatus + error_message: str | None = None + + +class HouseholdImpactResponse(BaseModel): + """Response for household impact analysis.""" + + report_id: UUID + report_type: str + status: ReportStatus + baseline_simulation: HouseholdSimulationInfo | None = None + reform_simulation: HouseholdSimulationInfo | None = None + baseline_result: dict | None = None + reform_result: dict | None = None + impact: dict | None = None + error_message: str | None = None + + +# ============================================================================= +# Response Building +# ============================================================================= + + +def build_simulation_info( + simulation: Simulation | None, +) -> HouseholdSimulationInfo | None: + """Build simulation info from a simulation.""" + if not simulation: + return None + + return HouseholdSimulationInfo( + id=simulation.id, + status=simulation.status, + error_message=simulation.error_message, + ) + + +def build_household_response( + report: Report, + baseline_sim: Simulation, + reform_sim: Simulation | None, + session: Session, +) -> HouseholdImpactResponse: + """Build response including computed impact for comparisons.""" + baseline_result = baseline_sim.household_result + reform_result = reform_sim.household_result if reform_sim else None + + impact = 
_compute_impact_if_comparison( + baseline_sim, reform_sim, baseline_result, reform_result, session + ) + + return HouseholdImpactResponse( + report_id=report.id, + report_type=report.report_type or "household_single", + status=report.status, + baseline_simulation=build_simulation_info(baseline_sim), + reform_simulation=build_simulation_info(reform_sim), + baseline_result=baseline_result, + reform_result=reform_result, + impact=impact, + error_message=report.error_message, + ) + + +def _compute_impact_if_comparison( + baseline_sim: Simulation, + reform_sim: Simulation | None, + baseline_result: dict | None, + reform_result: dict | None, + session: Session, +) -> dict | None: + """Compute impact only if this is a comparison with both results.""" + if not reform_sim: + return None + if not baseline_result or not reform_result: + return None + + household = session.get(Household, baseline_sim.household_id) + if not household: + return None + + config = get_country_config(household.tax_benefit_model_name) + return compute_household_impact(baseline_result, reform_result, config) + + +# ============================================================================= +# Validation Helpers +# ============================================================================= + + +def validate_household_exists(household_id: UUID, session: Session) -> Household: + """Validate household exists and return it.""" + household = session.get(Household, household_id) + if not household: + raise HTTPException( + status_code=404, + detail=f"Household {household_id} not found", + ) + return household + + +def validate_policy_exists(policy_id: UUID | None, session: Session) -> None: + """Validate policy exists if provided.""" + if not policy_id: + return + + policy = session.get(Policy, policy_id) + if not policy: + raise HTTPException( + status_code=404, + detail=f"Policy {policy_id} not found", + ) + + +# ============================================================================= +# 
Endpoints +# ============================================================================= + + +@router.post("/household-impact", response_model=HouseholdImpactResponse) +def household_impact( + request: HouseholdImpactRequest, + session: Session = Depends(get_session), +) -> HouseholdImpactResponse: + """Run household impact analysis. + + If policy_id is None: single run under current law. + If policy_id is set: comparison (baseline vs reform). + + This is an async operation. The endpoint returns immediately with a report_id + and status="pending". Poll GET /analysis/household-impact/{report_id} until + status="completed" to get results. + """ + household = validate_household_exists(request.household_id, session) + validate_policy_exists(request.policy_id, session) + + model_version = _get_model_version(household.tax_benefit_model_name, session) + + baseline_sim = _create_baseline_simulation( + household, model_version.id, request.dynamic_id, session + ) + reform_sim = _create_reform_simulation( + household, model_version.id, request.policy_id, request.dynamic_id, session + ) + + report_type = "household_comparison" if request.policy_id else "household_single" + report = _get_or_create_report( + baseline_sim_id=baseline_sim.id, + reform_sim_id=reform_sim.id if reform_sim else None, + label=f"Household impact: {household.tax_benefit_model_name}", + report_type=report_type, + session=session, + ) + + if report.status == ReportStatus.PENDING: + with logfire.span("trigger_household_impact", job_id=str(report.id)): + _trigger_household_impact( + str(report.id), household.tax_benefit_model_name, session + ) + + return build_household_response(report, baseline_sim, reform_sim, session) + + +@router.get("/household-impact/{report_id}", response_model=HouseholdImpactResponse) +def get_household_impact( + report_id: UUID, + session: Session = Depends(get_session), +) -> HouseholdImpactResponse: + """Get household impact analysis status and results.""" + report = 
session.get(Report, report_id) + if not report: + raise HTTPException(status_code=404, detail=f"Report {report_id} not found") + + if not report.baseline_simulation_id: + raise HTTPException( + status_code=500, + detail="Report missing baseline simulation ID", + ) + + baseline_sim = session.get(Simulation, report.baseline_simulation_id) + if not baseline_sim: + raise HTTPException(status_code=500, detail="Baseline simulation data missing") + + reform_sim = None + if report.reform_simulation_id: + reform_sim = session.get(Simulation, report.reform_simulation_id) + + return build_household_response(report, baseline_sim, reform_sim, session) + + +# ============================================================================= +# Simulation Creation Helpers +# ============================================================================= + + +def _create_baseline_simulation( + household: Household, + model_version_id: UUID, + dynamic_id: UUID | None, + session: Session, +) -> Simulation: + """Create baseline simulation (current law, no policy).""" + return _get_or_create_simulation( + simulation_type=SimulationType.HOUSEHOLD, + model_version_id=model_version_id, + policy_id=None, + dynamic_id=dynamic_id, + session=session, + household_id=household.id, + ) + + +def _create_reform_simulation( + household: Household, + model_version_id: UUID, + policy_id: UUID | None, + dynamic_id: UUID | None, + session: Session, +) -> Simulation | None: + """Create reform simulation if policy_id is provided.""" + if not policy_id: + return None + + return _get_or_create_simulation( + simulation_type=SimulationType.HOUSEHOLD, + model_version_id=model_version_id, + policy_id=policy_id, + dynamic_id=dynamic_id, + session=session, + household_id=household.id, + ) diff --git a/src/policyengine_api/api/households.py b/src/policyengine_api/api/households.py new file mode 100644 index 0000000..fdee1f7 --- /dev/null +++ b/src/policyengine_api/api/households.py @@ -0,0 +1,119 @@ +"""Stored 
household CRUD endpoints. + +Households represent saved household definitions that can be reused across +calculations and impact analyses. Create a household once, then reference +it by ID for repeated simulations. + +These endpoints manage stored household *definitions* (people, entity groups, +model name, year). For running calculations on a household, use the +/household/calculate and /household/impact endpoints instead. +""" + +from typing import Any +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session, select + +from policyengine_api.models import Household, HouseholdCreate, HouseholdRead +from policyengine_api.services.database import get_session + +router = APIRouter(prefix="/households", tags=["households"]) + +_ENTITY_GROUP_KEYS = ( + "tax_unit", + "family", + "spm_unit", + "marital_unit", + "household", + "benunit", +) + + +def _pack_household_data(body: HouseholdCreate) -> dict[str, Any]: + """Pack the flat request fields into a single JSON blob for storage.""" + data: dict[str, Any] = {"people": body.people} + for key in _ENTITY_GROUP_KEYS: + val = getattr(body, key) + if val is not None: + data[key] = val + return data + + +def _to_read(record: Household) -> HouseholdRead: + """Unpack the JSON blob back into the flat response shape.""" + data = record.household_data + return HouseholdRead( + id=record.id, + tax_benefit_model_name=record.tax_benefit_model_name, + year=record.year, + label=record.label, + people=data["people"], + tax_unit=data.get("tax_unit"), + family=data.get("family"), + spm_unit=data.get("spm_unit"), + marital_unit=data.get("marital_unit"), + household=data.get("household"), + benunit=data.get("benunit"), + created_at=record.created_at, + updated_at=record.updated_at, + ) + + +@router.post("/", response_model=HouseholdRead, status_code=201) +def create_household(body: HouseholdCreate, session: Session = Depends(get_session)): + """Create a stored household definition. 
+ + The household data (people + entity groups) is persisted so it can be + retrieved later by ID. Use the returned ID with /household/calculate + or /household/impact to run simulations. + """ + record = Household( + tax_benefit_model_name=body.tax_benefit_model_name, + year=body.year, + label=body.label, + household_data=_pack_household_data(body), + ) + session.add(record) + session.commit() + session.refresh(record) + return _to_read(record) + + +@router.get("/", response_model=list[HouseholdRead]) +def list_households( + tax_benefit_model_name: str | None = None, + limit: int = Query(default=50, le=200), + offset: int = Query(default=0, ge=0), + session: Session = Depends(get_session), +): + """List stored households with optional filtering.""" + query = select(Household) + if tax_benefit_model_name is not None: + query = query.where(Household.tax_benefit_model_name == tax_benefit_model_name) + query = query.offset(offset).limit(limit) + records = session.exec(query).all() + return [_to_read(r) for r in records] + + +@router.get("/{household_id}", response_model=HouseholdRead) +def get_household(household_id: UUID, session: Session = Depends(get_session)): + """Get a stored household by ID.""" + record = session.get(Household, household_id) + if not record: + raise HTTPException( + status_code=404, detail=f"Household {household_id} not found" + ) + return _to_read(record) + + +@router.delete("/{household_id}", status_code=204) +def delete_household(household_id: UUID, session: Session = Depends(get_session)): + """Delete a stored household.""" + record = session.get(Household, household_id) + if not record: + raise HTTPException( + status_code=404, detail=f"Household {household_id} not found" + ) + session.delete(record) + session.commit() diff --git a/src/policyengine_api/api/intra_decile.py b/src/policyengine_api/api/intra_decile.py new file mode 100644 index 0000000..668597a --- /dev/null +++ b/src/policyengine_api/api/intra_decile.py @@ -0,0 +1,113 @@ 
+"""Intra-decile income change computation. + +Computes the distribution of income change categories within each income +decile, producing proportions for 5 categories per decile plus an overall +average row. +""" + +from typing import Callable + +import numpy as np + +# The 5-category thresholds and labels (matching V1 structure) +BOUNDS = [-np.inf, -0.05, -1e-3, 1e-3, 0.05, np.inf] +CATEGORY_COLUMNS = [ + "lose_more_than_5pct", + "lose_less_than_5pct", + "no_change", + "gain_less_than_5pct", + "gain_more_than_5pct", +] + + +# --- Income change formula variants --- + + +# NOTE: This formula replicates V1's API (policyengine-api, endpoints/economy/ +# compare.py lines 324-331). It appears to double-count the change because it +# adds absolute_change to the already-changed capped reform income: +# capped_reform = max(reform, 1) + (reform - baseline) +# For the common case (both incomes > 1), this yields: +# income_change = 2 * (reform - baseline) / baseline +# Kept here for reference while confirming with the team. +def _income_change_v1_original( + baseline_income: np.ndarray, + reform_income: np.ndarray, +) -> np.ndarray: + absolute_change = reform_income - baseline_income + capped_baseline = np.maximum(baseline_income, 1) + capped_reform = np.maximum(reform_income, 1) + absolute_change + return (capped_reform - capped_baseline) / capped_baseline + + +def _income_change_corrected( + baseline_income: np.ndarray, + reform_income: np.ndarray, +) -> np.ndarray: + capped_baseline = np.maximum(baseline_income, 1) + return (reform_income - baseline_income) / capped_baseline + + +# Strategy selector — change this to switch formulas +def get_income_change_formula() -> Callable[[np.ndarray, np.ndarray], np.ndarray]: + return _income_change_corrected + + +# --- Main computation --- + + +def compute_intra_decile( + baseline_household_data: dict[str, np.ndarray], + reform_household_data: dict[str, np.ndarray], +) -> list[dict]: + """Compute intra-decile impact proportions. 
+ + Args: + baseline_household_data: Dict with keys "household_net_income", + "household_weight", "household_count_people", + "household_income_decile" — all as raw numpy arrays. + reform_household_data: Same keys for the reform simulation. + + Returns: + List of 11 dicts (deciles 1-10 + overall as decile=0), each with + keys: decile, lose_more_than_5pct, lose_less_than_5pct, no_change, + gain_less_than_5pct, gain_more_than_5pct. + """ + baseline_income = baseline_household_data["household_net_income"] + reform_income = reform_household_data["household_net_income"] + weights = baseline_household_data["household_weight"] + people_per_hh = baseline_household_data["household_count_people"] + decile = baseline_household_data["household_income_decile"] + + # People-weighted count per household + people = people_per_hh * weights + + # Compute percentage income change + formula = get_income_change_formula() + income_change = formula(baseline_income, reform_income) + + # For each decile, compute proportion of people in each category + rows = [] + for decile_num in range(1, 11): + in_decile = decile == decile_num + people_in_decile = people[in_decile].sum() + + proportions = {} + for col, lower, upper in zip(CATEGORY_COLUMNS, BOUNDS[:-1], BOUNDS[1:]): + in_category = (income_change > lower) & (income_change <= upper) + in_both = in_decile & in_category + + if people_in_decile == 0: + proportions[col] = 0.0 + else: + proportions[col] = float(people[in_both].sum() / people_in_decile) + + rows.append({"decile": decile_num, **proportions}) + + # Overall average: mean of the 10 decile proportions (matching V1) + overall = {"decile": 0} + for col in CATEGORY_COLUMNS: + overall[col] = sum(r[col] for r in rows) / 10 + rows.append(overall) + + return rows diff --git a/src/policyengine_api/api/module_registry.py b/src/policyengine_api/api/module_registry.py new file mode 100644 index 0000000..3cd5c6c --- /dev/null +++ b/src/policyengine_api/api/module_registry.py @@ -0,0 +1,126 @@ 
+"""Economy analysis module registry. + +Defines the available computation modules for economy-wide analysis. +Each module maps to a named computation (e.g., "poverty", "decile") with +metadata about which countries support it and which response fields it +populates. + +Used by: +- GET /analysis/options — lists available modules +- POST /analysis/economy-custom — runs selected modules +""" + +from dataclasses import dataclass, field + + +@dataclass(frozen=True) +class ComputationModule: + """A named economy analysis computation module.""" + + name: str + label: str + description: str + countries: list[str] = field(default_factory=list) + response_fields: list[str] = field(default_factory=list) + + +MODULE_REGISTRY: dict[str, ComputationModule] = { + "decile": ComputationModule( + name="decile", + label="Income decile impacts", + description="Relative and average income change by income decile (1-10).", + countries=["uk", "us"], + response_fields=["decile_impacts"], + ), + "program_statistics": ComputationModule( + name="program_statistics", + label="Program statistics", + description="Per-program baseline/reform totals, changes, and winner/loser counts.", + countries=["uk", "us"], + response_fields=["program_statistics", "detailed_budget"], + ), + "poverty": ComputationModule( + name="poverty", + label="Poverty rates", + description="Poverty rates by type, overall and by demographic breakdowns (age, gender, race).", + countries=["uk", "us"], + response_fields=["poverty"], + ), + "inequality": ComputationModule( + name="inequality", + label="Inequality metrics", + description="Gini coefficient, top 10%/1% share, bottom 50% share.", + countries=["uk", "us"], + response_fields=["inequality"], + ), + "budget_summary": ComputationModule( + name="budget_summary", + label="Budget summary", + description="Aggregate tax revenue, benefit spending, net income, and household count.", + countries=["uk", "us"], + response_fields=["budget_summary"], + ), + "intra_decile": 
ComputationModule( + name="intra_decile", + label="Intra-decile breakdown", + description="Distribution of income change categories (5 bands) within each decile.", + countries=["uk", "us"], + response_fields=["intra_decile"], + ), + "congressional_district": ComputationModule( + name="congressional_district", + label="Congressional district impact", + description="Per-district average and relative household income change for US congressional districts.", + countries=["us"], + response_fields=["congressional_district_impact"], + ), + "constituency": ComputationModule( + name="constituency", + label="Parliamentary constituency impact", + description="Per-constituency average and relative household income change for UK parliamentary constituencies.", + countries=["uk"], + response_fields=["constituency_impact"], + ), + "local_authority": ComputationModule( + name="local_authority", + label="Local authority impact", + description="Per-local-authority average and relative household income change for UK local authorities.", + countries=["uk"], + response_fields=["local_authority_impact"], + ), + "wealth_decile": ComputationModule( + name="wealth_decile", + label="Wealth decile impacts", + description="Income change by wealth decile (1-10) and intra-wealth-decile breakdown.", + countries=["uk"], + response_fields=["wealth_decile", "intra_wealth_decile"], + ), +} + + +def get_modules_for_country(country: str) -> list[ComputationModule]: + """Return modules applicable to a country code ('uk' or 'us').""" + return [m for m in MODULE_REGISTRY.values() if country in m.countries] + + +def get_all_module_names() -> list[str]: + """Return all registered module names.""" + return list(MODULE_REGISTRY.keys()) + + +def validate_modules(names: list[str], country: str) -> list[str]: + """Validate module names against the registry for a given country. + + Returns the validated list. Raises ValueError for unknown or + inapplicable modules. 
+ """ + available = {m.name for m in get_modules_for_country(country)} + errors = [] + for name in names: + if name not in MODULE_REGISTRY: + errors.append(f"Unknown module: {name!r}") + elif name not in available: + errors.append(f"Module {name!r} is not available for country {country!r}") + if errors: + raise ValueError("; ".join(errors)) + return names diff --git a/src/policyengine_api/api/outputs.py b/src/policyengine_api/api/outputs.py index f87cf62..96196d3 100644 --- a/src/policyengine_api/api/outputs.py +++ b/src/policyengine_api/api/outputs.py @@ -12,6 +12,7 @@ from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException from sqlmodel import Session, select +from policyengine_api.config import settings from policyengine_api.models import ( AggregateOutput, AggregateOutputCreate, @@ -70,9 +71,17 @@ def _trigger_aggregate_computation( traceparent = _get_traceparent() if "uk" in model.name.lower(): - fn = modal.Function.from_name("policyengine", "compute_aggregate_uk") + fn = modal.Function.from_name( + "policyengine", + "compute_aggregate_uk", + environment_name=settings.modal_environment, + ) else: - fn = modal.Function.from_name("policyengine", "compute_aggregate_us") + fn = modal.Function.from_name( + "policyengine", + "compute_aggregate_us", + environment_name=settings.modal_environment, + ) fn.spawn(aggregate_id=aggregate_id, traceparent=traceparent) logfire.info( diff --git a/src/policyengine_api/api/parameters.py b/src/policyengine_api/api/parameters.py index db029e5..72b64ef 100644 --- a/src/policyengine_api/api/parameters.py +++ b/src/policyengine_api/api/parameters.py @@ -5,12 +5,16 @@ Parameter names are used when creating policy reforms. 
""" -from typing import List +from __future__ import annotations + +from typing import List, Literal from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Query +from pydantic import BaseModel from sqlmodel import Session, select +from policyengine_api.config.constants import COUNTRY_MODEL_NAMES, CountryId from policyengine_api.models import ( Parameter, ParameterRead, @@ -67,6 +71,151 @@ def list_parameters( return parameters +class ParameterByNameRequest(BaseModel): + """Request body for looking up parameters by name.""" + + names: list[str] + country_id: CountryId + + +@router.post("/by-name", response_model=List[ParameterRead]) +def get_parameters_by_name( + request: ParameterByNameRequest, + session: Session = Depends(get_session), +): + """Look up parameters by their exact names. + + Given a list of parameter paths (e.g. "gov.hmrc.income_tax.rates.uk[0].rate"), + returns the full metadata for each matching parameter. Names that don't match + any parameter are silently omitted from the response. + + Use this to fetch metadata for a known set of parameters (e.g. all parameters + referenced in a user's saved policy) without loading the entire parameter catalog. 
+ """ + if not request.names: + return [] + + model_name = COUNTRY_MODEL_NAMES[request.country_id] + + query = ( + select(Parameter) + .join(TaxBenefitModelVersion) + .join(TaxBenefitModel) + .where(TaxBenefitModel.name == model_name) + .where(Parameter.name.in_(request.names)) + .order_by(Parameter.name) + ) + + return session.exec(query).all() + + +class ParameterChild(BaseModel): + """A single child in the parameter tree.""" + + path: str + label: str + type: Literal["node", "parameter"] + child_count: int | None = None + parameter: ParameterRead | None = None + + +class ParameterChildrenResponse(BaseModel): + """Response for the parameter children endpoint.""" + + parent_path: str + children: list[ParameterChild] + + +@router.get("/children", response_model=ParameterChildrenResponse) +def get_parameter_children( + country_id: CountryId = Query(description='Country ID ("us" or "uk")'), + parent_path: str = Query( + default="", description="Parent parameter path (e.g. 'gov' or 'gov.hmrc')" + ), + session: Session = Depends(get_session), +) -> ParameterChildrenResponse: + """Get direct children of a parameter path for tree navigation. + + Returns both intermediate nodes (folders with child_count) and leaf + parameters (with full metadata). Use this to lazily load the parameter + tree one level at a time. + """ + model_name = COUNTRY_MODEL_NAMES[country_id] + prefix = f"{parent_path}." 
if parent_path else "" + + # Fetch all parameters under this path + query = ( + select(Parameter) + .join(TaxBenefitModelVersion) + .join(TaxBenefitModel) + .where(TaxBenefitModel.name == model_name) + .where(Parameter.name.startswith(prefix)) + ) + descendants = session.exec(query).all() + + # Group by direct child path + children_map: dict[str, dict] = {} + prefix_len = len(prefix) + + for param in descendants: + remainder = param.name[prefix_len:] + dot_pos = remainder.find(".") + + if dot_pos == -1: + # Direct child (leaf at this level) + child_path = param.name + if child_path not in children_map: + children_map[child_path] = { + "direct_param": None, + "descendant_count": 0, + } + children_map[child_path]["direct_param"] = param + else: + # Deeper descendant — extract direct child segment + segment = remainder[:dot_pos] + child_path = prefix + segment + if child_path not in children_map: + children_map[child_path] = { + "direct_param": None, + "descendant_count": 0, + } + children_map[child_path]["descendant_count"] += 1 + + # Build response + children = [] + for path in sorted(children_map): + info = children_map[path] + if info["descendant_count"] > 0: + # Node: has children below it + direct_param = info["direct_param"] + label = ( + direct_param.label + if direct_param and direct_param.label + else path.rsplit(".", 1)[-1] + ) + children.append( + ParameterChild( + path=path, + label=label, + type="node", + child_count=info["descendant_count"], + ) + ) + elif info["direct_param"]: + # Leaf parameter + param = info["direct_param"] + children.append( + ParameterChild( + path=path, + label=param.label or path.rsplit(".", 1)[-1], + type="parameter", + parameter=ParameterRead.model_validate(param), + ) + ) + + return ParameterChildrenResponse(parent_path=parent_path, children=children) + + @router.get("/{parameter_id}", response_model=ParameterRead) def get_parameter(parameter_id: UUID, session: Session = Depends(get_session)): """Get a specific parameter.""" 
diff --git a/src/policyengine_api/api/policies.py b/src/policyengine_api/api/policies.py index d0e2ca5..ad8be5f 100644 --- a/src/policyengine_api/api/policies.py +++ b/src/policyengine_api/api/policies.py @@ -27,22 +27,53 @@ 6. Poll GET /analysis/economic-impact/{report_id} until status="completed" """ -from datetime import datetime from typing import List from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy.orm import selectinload from sqlmodel import Session, select from policyengine_api.models import ( Parameter, ParameterValue, + ParameterValueWithName, Policy, PolicyCreate, PolicyRead, + TaxBenefitModel, ) from policyengine_api.services.database import get_session + +def _policy_to_read(policy: Policy) -> PolicyRead: + """Convert a Policy ORM object to PolicyRead with parameter names.""" + pv_with_names = [] + for pv in policy.parameter_values: + pv_with_names.append( + ParameterValueWithName( + id=pv.id, + parameter_id=pv.parameter_id, + value_json=pv.value_json, + start_date=pv.start_date, + end_date=pv.end_date, + policy_id=pv.policy_id, + dynamic_id=pv.dynamic_id, + created_at=pv.created_at, + parameter_name=pv.parameter.name, + ) + ) + return PolicyRead( + id=policy.id, + name=policy.name, + description=policy.description, + tax_benefit_model_id=policy.tax_benefit_model_id, + created_at=policy.created_at, + updated_at=policy.updated_at, + parameter_values=pv_with_names, + ) + + router = APIRouter(prefix="/policies", tags=["policies"]) @@ -67,61 +98,82 @@ def create_policy(policy: PolicyCreate, session: Session = Depends(get_session)) ] } """ + # Validate tax_benefit_model exists + tax_model = session.get(TaxBenefitModel, policy.tax_benefit_model_id) + if not tax_model: + raise HTTPException(status_code=404, detail="Tax benefit model not found") + # Create the policy - db_policy = Policy(name=policy.name, description=policy.description) + db_policy = 
Policy( + name=policy.name, + description=policy.description, + tax_benefit_model_id=policy.tax_benefit_model_id, + ) session.add(db_policy) session.flush() # Get the policy ID before adding parameter values # Create associated parameter values for pv_data in policy.parameter_values: # Validate parameter exists - param = session.get(Parameter, pv_data["parameter_id"]) + param = session.get(Parameter, pv_data.parameter_id) if not param: raise HTTPException( status_code=404, - detail=f"Parameter {pv_data['parameter_id']} not found", + detail=f"Parameter {pv_data.parameter_id} not found", ) - # Parse dates - start_date = ( - datetime.fromisoformat(pv_data["start_date"].replace("Z", "+00:00")) - if isinstance(pv_data["start_date"], str) - else pv_data["start_date"] - ) - end_date = None - if pv_data.get("end_date"): - end_date = ( - datetime.fromisoformat(pv_data["end_date"].replace("Z", "+00:00")) - if isinstance(pv_data["end_date"], str) - else pv_data["end_date"] - ) - - # Create parameter value + # Create parameter value (dates already parsed by Pydantic) db_pv = ParameterValue( - parameter_id=pv_data["parameter_id"], - value_json=pv_data["value_json"], - start_date=start_date, - end_date=end_date, + parameter_id=pv_data.parameter_id, + value_json=pv_data.value_json, + start_date=pv_data.start_date, + end_date=pv_data.end_date, policy_id=db_policy.id, ) session.add(db_pv) session.commit() - session.refresh(db_policy) - return db_policy + + # Re-fetch with eager loading for the response + query = ( + select(Policy) + .where(Policy.id == db_policy.id) + .options( + selectinload(Policy.parameter_values).selectinload(ParameterValue.parameter) + ) + ) + db_policy = session.exec(query).one() + return _policy_to_read(db_policy) @router.get("/", response_model=List[PolicyRead]) -def list_policies(session: Session = Depends(get_session)): - """List all policies.""" - policies = session.exec(select(Policy)).all() - return policies +def list_policies( + tax_benefit_model_id: 
UUID | None = Query( + None, description="Filter by tax benefit model" + ), + session: Session = Depends(get_session), +): + """List all policies, optionally filtered by tax benefit model.""" + query = select(Policy).options( + selectinload(Policy.parameter_values).selectinload(ParameterValue.parameter) + ) + if tax_benefit_model_id: + query = query.where(Policy.tax_benefit_model_id == tax_benefit_model_id) + policies = session.exec(query).all() + return [_policy_to_read(p) for p in policies] @router.get("/{policy_id}", response_model=PolicyRead) def get_policy(policy_id: UUID, session: Session = Depends(get_session)): """Get a specific policy.""" - policy = session.get(Policy, policy_id) + query = ( + select(Policy) + .where(Policy.id == policy_id) + .options( + selectinload(Policy.parameter_values).selectinload(ParameterValue.parameter) + ) + ) + policy = session.exec(query).first() if not policy: raise HTTPException(status_code=404, detail="Policy not found") - return policy + return _policy_to_read(policy) diff --git a/src/policyengine_api/api/regions.py b/src/policyengine_api/api/regions.py new file mode 100644 index 0000000..1d0a34e --- /dev/null +++ b/src/policyengine_api/api/regions.py @@ -0,0 +1,102 @@ +"""Region endpoints for geographic areas used in analysis. + +Regions represent geographic areas from countries down to states, +congressional districts, cities, etc. Each region has an associated +dataset for running simulations. 
+""" + +from typing import List +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session, select + +from policyengine_api.models import Region, RegionRead, TaxBenefitModel +from policyengine_api.services.database import get_session + +router = APIRouter(prefix="/regions", tags=["regions"]) + + +@router.get("/", response_model=List[RegionRead]) +def list_regions( + tax_benefit_model_id: UUID | None = Query( + None, description="Filter by tax-benefit model ID" + ), + tax_benefit_model_name: str | None = Query( + None, description="Filter by tax-benefit model name (e.g., 'policyengine-us')" + ), + region_type: str | None = Query( + None, + description="Filter by region type (e.g., 'state', 'congressional_district')", + ), + session: Session = Depends(get_session), +): + """List available regions. + + Returns regions that can be used with the /analysis/economic-impact endpoint. + Each region represents a geographic area with an associated dataset. + + Args: + tax_benefit_model_id: Filter by tax-benefit model UUID. + tax_benefit_model_name: Filter by model name (e.g., "policyengine-us"). + region_type: Filter by region type (e.g., "state", "congressional_district"). 
+ """ + query = select(Region) + + if tax_benefit_model_id: + query = query.where(Region.tax_benefit_model_id == tax_benefit_model_id) + elif tax_benefit_model_name: + query = query.join(TaxBenefitModel).where( + TaxBenefitModel.name == tax_benefit_model_name + ) + + if region_type: + query = query.where(Region.region_type == region_type) + + regions = session.exec(query).all() + return regions + + +@router.get("/{region_id}", response_model=RegionRead) +def get_region(region_id: UUID, session: Session = Depends(get_session)): + """Get a specific region by ID.""" + region = session.get(Region, region_id) + if not region: + raise HTTPException(status_code=404, detail="Region not found") + return region + + +@router.get("/by-code/{region_code:path}", response_model=RegionRead) +def get_region_by_code( + region_code: str, + tax_benefit_model_id: UUID | None = Query( + None, + description="Tax-benefit model ID (required if multiple models have same region code)", + ), + tax_benefit_model_name: str | None = Query( + None, description="Tax-benefit model name (e.g., 'policyengine-us')" + ), + session: Session = Depends(get_session), +): + """Get a specific region by code. + + Region codes use a prefix format like "state/ca" or "constituency/Sheffield Central". + + Args: + region_code: The region code (e.g., "state/ca", "us"). + tax_benefit_model_id: Filter by tax-benefit model UUID. + tax_benefit_model_name: Filter by model name. 
+ """ + query = select(Region).where(Region.code == region_code) + + if tax_benefit_model_id: + query = query.where(Region.tax_benefit_model_id == tax_benefit_model_id) + elif tax_benefit_model_name: + query = query.join(TaxBenefitModel).where( + TaxBenefitModel.name == tax_benefit_model_name + ) + + region = session.exec(query).first() + if not region: + raise HTTPException(status_code=404, detail="Region not found") + return region diff --git a/src/policyengine_api/api/simulations.py b/src/policyengine_api/api/simulations.py index 633c57c..8477a68 100644 --- a/src/policyengine_api/api/simulations.py +++ b/src/policyengine_api/api/simulations.py @@ -1,36 +1,399 @@ -"""Simulation status endpoints. +"""Simulation endpoints. -Simulations are economy-wide tax-benefit calculations running on population datasets. -They are created automatically when you call /analysis/economic-impact. Use these -endpoints to check simulation status (pending, running, completed, failed). +Simulations are individual tax-benefit calculations. Use these endpoints to: +- Create and run household simulations (single household, single policy) +- Create and run economy simulations (population dataset, single policy) +- Check simulation status and retrieve results + +For baseline-vs-reform comparisons, use the /analysis/ endpoints instead. 
""" -from typing import List +from typing import Any, List, Literal from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Query +from pydantic import BaseModel, Field, model_validator from sqlmodel import Session, select -from policyengine_api.models import Simulation, SimulationRead +from policyengine_api.models import ( + Dataset, + Household, + Policy, + Region, + RegionDatasetLink, + Simulation, + SimulationRead, + SimulationStatus, + SimulationType, + TaxBenefitModel, +) from policyengine_api.services.database import get_session +from .analysis import ( + RegionInfo, + _get_model_version, + _get_or_create_simulation, +) + router = APIRouter(prefix="/simulations", tags=["simulations"]) -@router.get("/", response_model=List[SimulationRead]) -def list_simulations(session: Session = Depends(get_session)): - """List all simulations. +# --------------------------------------------------------------------------- +# Request / Response schemas +# --------------------------------------------------------------------------- + + +class HouseholdSimulationRequest(BaseModel): + """Request body for creating a household simulation.""" + + household_id: UUID = Field(description="ID of the stored household") + policy_id: UUID | None = Field( + default=None, + description="Reform policy ID. 
If None, runs under current law.", + ) + dynamic_id: UUID | None = Field( + default=None, + description="Optional behavioural response specification ID", + ) + + +class HouseholdSimulationResponse(BaseModel): + """Response for a household simulation.""" + + id: UUID + status: SimulationStatus + household_id: UUID | None = None + policy_id: UUID | None = None + household_result: dict[str, Any] | None = None + error_message: str | None = None + + +class EconomySimulationRequest(BaseModel): + """Request body for creating an economy simulation.""" + + tax_benefit_model_name: Literal["policyengine_uk", "policyengine_us"] = Field( + description="Which country model to use" + ) + region: str | None = Field( + default=None, + description="Region code (e.g., 'state/ca', 'us'). Either region or dataset_id must be provided.", + ) + dataset_id: UUID | None = Field( + default=None, + description="Dataset ID. Either region or dataset_id must be provided.", + ) + policy_id: UUID | None = Field( + default=None, + description="Reform policy ID. If None, runs under current law.", + ) + dynamic_id: UUID | None = Field( + default=None, + description="Optional behavioural response specification ID", + ) + year: int | None = Field( + default=None, + description="Year for the simulation. 
Uses latest available if omitted.", + ) + + @model_validator(mode="after") + def check_dataset_or_region(self) -> "EconomySimulationRequest": + if not self.dataset_id and not self.region: + raise ValueError("Either dataset_id or region must be provided") + return self + + +class EconomySimulationResponse(BaseModel): + """Response for an economy simulation.""" + + id: UUID + status: SimulationStatus + dataset_id: UUID | None = None + policy_id: UUID | None = None + output_dataset_id: UUID | None = None + filter_field: str | None = None + filter_value: str | None = None + region: RegionInfo | None = None + error_message: str | None = None + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- - Simulations are created automatically via /analysis/economic-impact. - Check status to see if computation is pending, running, completed, or failed. + +def _resolve_economy_dataset( + tax_benefit_model_name: str, + region_code: str | None, + dataset_id: UUID | None, + session: Session, + year: int | None = None, +) -> tuple[Dataset, Region | None]: + """Resolve dataset from region code or dataset_id for economy simulations. + + When a region is provided, the dataset is resolved from the region_datasets + join table. If year is set, the dataset for that year is selected; + otherwise the latest available year is used. 
""" - simulations = session.exec(select(Simulation)).all() + if region_code: + model_name = tax_benefit_model_name.replace("_", "-") + region = session.exec( + select(Region) + .join(TaxBenefitModel) + .where(Region.code == region_code) + .where(TaxBenefitModel.name == model_name) + ).first() + if not region: + raise HTTPException( + status_code=404, + detail=f"Region '{region_code}' not found for model {model_name}", + ) + + # Resolve dataset from join table + query = ( + select(Dataset) + .join(RegionDatasetLink) + .where(RegionDatasetLink.region_id == region.id) + ) + if year: + query = query.where(Dataset.year == year) + else: + query = query.order_by(Dataset.year.desc()) # type: ignore + dataset = session.exec(query).first() + + if not dataset: + year_msg = f" for year {year}" if year else "" + raise HTTPException( + status_code=404, + detail=f"No dataset found for region '{region_code}'{year_msg}", + ) + return dataset, region + + elif dataset_id: + dataset = session.get(Dataset, dataset_id) + if not dataset: + raise HTTPException( + status_code=404, + detail=f"Dataset {dataset_id} not found", + ) + return dataset, None + + else: + raise HTTPException( + status_code=400, + detail="Either region or dataset_id must be provided", + ) + + +def _build_household_response(simulation: Simulation) -> HouseholdSimulationResponse: + """Build response from a household simulation.""" + return HouseholdSimulationResponse( + id=simulation.id, + status=simulation.status, + household_id=simulation.household_id, + policy_id=simulation.policy_id, + household_result=simulation.household_result, + error_message=simulation.error_message, + ) + + +def _build_economy_response( + simulation: Simulation, region: Region | None = None +) -> EconomySimulationResponse: + """Build response from an economy simulation.""" + region_info = None + if region: + region_info = RegionInfo( + code=region.code, + label=region.label, + region_type=region.region_type, + 
requires_filter=region.requires_filter, + filter_field=region.filter_field, + filter_value=region.filter_value, + ) + + return EconomySimulationResponse( + id=simulation.id, + status=simulation.status, + dataset_id=simulation.dataset_id, + policy_id=simulation.policy_id, + output_dataset_id=simulation.output_dataset_id, + filter_field=simulation.filter_field, + filter_value=simulation.filter_value, + region=region_info, + error_message=simulation.error_message, + ) + + +# --------------------------------------------------------------------------- +# List / generic get (existing endpoints) +# --------------------------------------------------------------------------- + + +@router.get("/", response_model=List[SimulationRead]) +def list_simulations( + limit: int = Query(default=50, le=200), + offset: int = Query(default=0, ge=0), + session: Session = Depends(get_session), +): + """List simulations with pagination.""" + simulations = session.exec(select(Simulation).offset(offset).limit(limit)).all() return simulations +# --------------------------------------------------------------------------- +# Household simulation endpoints +# --------------------------------------------------------------------------- + + +@router.post("/household", response_model=HouseholdSimulationResponse) +def create_household_simulation( + request: HouseholdSimulationRequest, + session: Session = Depends(get_session), +): + """Create a household simulation job. + + Creates a Simulation record for the given household and policy. + Returns immediately with status "pending". + Poll GET /simulations/household/{id} until status is "completed". 
+ """ + # Validate household exists + household = session.get(Household, request.household_id) + if not household: + raise HTTPException( + status_code=404, + detail=f"Household {request.household_id} not found", + ) + + # Validate policy exists (if provided) + if request.policy_id: + policy = session.get(Policy, request.policy_id) + if not policy: + raise HTTPException( + status_code=404, + detail=f"Policy {request.policy_id} not found", + ) + + # Get model version + model_version = _get_model_version(household.tax_benefit_model_name, session) + + # Get or create simulation (deterministic UUID) + simulation = _get_or_create_simulation( + simulation_type=SimulationType.HOUSEHOLD, + model_version_id=model_version.id, + policy_id=request.policy_id, + dynamic_id=request.dynamic_id, + session=session, + household_id=request.household_id, + ) + + return _build_household_response(simulation) + + +@router.get("/household/{simulation_id}", response_model=HouseholdSimulationResponse) +def get_household_simulation( + simulation_id: UUID, + session: Session = Depends(get_session), +): + """Get a household simulation's status and result.""" + simulation = session.get(Simulation, simulation_id) + if not simulation: + raise HTTPException(status_code=404, detail="Simulation not found") + if simulation.simulation_type != SimulationType.HOUSEHOLD: + raise HTTPException( + status_code=400, + detail="Simulation is not a household simulation", + ) + + return _build_household_response(simulation) + + +# --------------------------------------------------------------------------- +# Economy simulation endpoints +# --------------------------------------------------------------------------- + + +@router.post("/economy", response_model=EconomySimulationResponse) +def create_economy_simulation( + request: EconomySimulationRequest, + session: Session = Depends(get_session), +): + """Create a single economy simulation. + + Creates a Simulation record for the given dataset/region and policy. 
+ Poll GET /simulations/economy/{id} until status is "completed". + + Note: standalone economy simulation computation will be connected + in future tasks. For full baseline-vs-reform economy analysis, + use POST /analysis/economic-impact instead. + """ + # Resolve dataset and region + dataset, region = _resolve_economy_dataset( + request.tax_benefit_model_name, + request.region, + request.dataset_id, + session, + year=request.year, + ) + + # Validate policy exists (if provided) + if request.policy_id: + policy = session.get(Policy, request.policy_id) + if not policy: + raise HTTPException( + status_code=404, + detail=f"Policy {request.policy_id} not found", + ) + + # Extract filter parameters from region + filter_field = region.filter_field if region and region.requires_filter else None + filter_value = region.filter_value if region and region.requires_filter else None + + # Get model version + model_version = _get_model_version(request.tax_benefit_model_name, session) + + # Get or create simulation (deterministic UUID) + simulation = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + model_version_id=model_version.id, + policy_id=request.policy_id, + dynamic_id=request.dynamic_id, + session=session, + dataset_id=dataset.id, + filter_field=filter_field, + filter_value=filter_value, + region_id=region.id if region else None, + year=dataset.year, + ) + + return _build_economy_response(simulation, region) + + +@router.get("/economy/{simulation_id}", response_model=EconomySimulationResponse) +def get_economy_simulation( + simulation_id: UUID, + session: Session = Depends(get_session), +): + """Get an economy simulation's status and result.""" + simulation = session.get(Simulation, simulation_id) + if not simulation: + raise HTTPException(status_code=404, detail="Simulation not found") + if simulation.simulation_type != SimulationType.ECONOMY: + raise HTTPException( + status_code=400, + detail="Simulation is not an economy simulation", + ) + + return 
_build_economy_response(simulation) + + +# --------------------------------------------------------------------------- +# Generic get (keep after specific routes to avoid path conflicts) +# --------------------------------------------------------------------------- + + @router.get("/{simulation_id}", response_model=SimulationRead) def get_simulation(simulation_id: UUID, session: Session = Depends(get_session)): - """Get a specific simulation.""" + """Get a specific simulation (any type).""" simulation = session.get(Simulation, simulation_id) if not simulation: raise HTTPException(status_code=404, detail="Simulation not found") diff --git a/src/policyengine_api/api/tax_benefit_models.py b/src/policyengine_api/api/tax_benefit_models.py index 5dda3a1..b4d4921 100644 --- a/src/policyengine_api/api/tax_benefit_models.py +++ b/src/policyengine_api/api/tax_benefit_models.py @@ -9,9 +9,16 @@ from uuid import UUID from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel from sqlmodel import Session, select -from policyengine_api.models import TaxBenefitModel, TaxBenefitModelRead +from policyengine_api.config.constants import COUNTRY_MODEL_NAMES, CountryId +from policyengine_api.models import ( + TaxBenefitModel, + TaxBenefitModelRead, + TaxBenefitModelVersion, + TaxBenefitModelVersionRead, +) from policyengine_api.services.database import get_session router = APIRouter(prefix="/tax-benefit-models", tags=["tax-benefit-models"]) @@ -28,6 +35,55 @@ def list_tax_benefit_models(session: Session = Depends(get_session)): return models +class ModelByCountryResponse(BaseModel): + """Response for the model-by-country endpoint.""" + + model: TaxBenefitModelRead + latest_version: TaxBenefitModelVersionRead + + +@router.get( + "/by-country/{country_id}", + response_model=ModelByCountryResponse, +) +def get_model_by_country( + country_id: CountryId, + session: Session = Depends(get_session), +): + """Get a tax-benefit model and its latest version by country 
ID. + + Returns the model metadata and the most recently created version in a + single response. Use this on page load to check the current model version + for cache invalidation. + """ + model_name = COUNTRY_MODEL_NAMES[country_id] + + model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == model_name) + ).first() + if not model: + raise HTTPException( + status_code=404, + detail=f"No model found for country '{country_id}'", + ) + + latest_version = session.exec( + select(TaxBenefitModelVersion) + .where(TaxBenefitModelVersion.model_id == model.id) + .order_by(TaxBenefitModelVersion.created_at.desc()) + ).first() + if not latest_version: + raise HTTPException( + status_code=404, + detail=f"No versions found for model '{model_name}'", + ) + + return ModelByCountryResponse( + model=TaxBenefitModelRead.model_validate(model), + latest_version=TaxBenefitModelVersionRead.model_validate(latest_version), + ) + + @router.get("/{model_id}", response_model=TaxBenefitModelRead) def get_tax_benefit_model(model_id: UUID, session: Session = Depends(get_session)): """Get a specific tax-benefit model.""" diff --git a/src/policyengine_api/api/user_household_associations.py b/src/policyengine_api/api/user_household_associations.py new file mode 100644 index 0000000..23282d5 --- /dev/null +++ b/src/policyengine_api/api/user_household_associations.py @@ -0,0 +1,143 @@ +"""User-household association endpoints. + +Associations link a user to a stored household definition with metadata +(label, country). A user can have multiple associations to the same +household (e.g. different labels or configurations). 
+""" + +from datetime import datetime, timezone +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session, select + +from policyengine_api.models import ( + Household, + UserHouseholdAssociation, + UserHouseholdAssociationCreate, + UserHouseholdAssociationRead, + UserHouseholdAssociationUpdate, +) +from policyengine_api.services.database import get_session + +router = APIRouter( + prefix="/user-household-associations", + tags=["user-household-associations"], +) + + +@router.post("/", response_model=UserHouseholdAssociationRead, status_code=201) +def create_association( + body: UserHouseholdAssociationCreate, + session: Session = Depends(get_session), +): + """Create a user-household association.""" + household = session.get(Household, body.household_id) + if not household: + raise HTTPException( + status_code=404, + detail=f"Household {body.household_id} not found", + ) + + record = UserHouseholdAssociation( + user_id=body.user_id, + household_id=body.household_id, + country_id=body.country_id, + label=body.label, + ) + session.add(record) + session.commit() + session.refresh(record) + return record + + +@router.get("/user/{user_id}", response_model=list[UserHouseholdAssociationRead]) +def list_by_user( + user_id: UUID, + country_id: str | None = None, + limit: int = Query(default=50, le=200), + offset: int = Query(default=0, ge=0), + session: Session = Depends(get_session), +): + """List all associations for a user, optionally filtered by country.""" + query = select(UserHouseholdAssociation).where( + UserHouseholdAssociation.user_id == user_id + ) + if country_id is not None: + query = query.where(UserHouseholdAssociation.country_id == country_id) + query = query.offset(offset).limit(limit) + return session.exec(query).all() + + +@router.get( + "/{user_id}/{household_id}", + response_model=list[UserHouseholdAssociationRead], +) +def list_by_user_and_household( + user_id: UUID, + household_id: UUID, + 
session: Session = Depends(get_session), +): + """List all associations for a specific user+household pair.""" + query = select(UserHouseholdAssociation).where( + UserHouseholdAssociation.user_id == user_id, + UserHouseholdAssociation.household_id == household_id, + ) + return session.exec(query).all() + + +@router.put("/{association_id}", response_model=UserHouseholdAssociationRead) +def update_association( + association_id: UUID, + body: UserHouseholdAssociationUpdate, + user_id: UUID = Query(..., description="User ID for ownership verification"), + session: Session = Depends(get_session), +): + """Update a user-household association (label). + + Requires user_id to verify ownership - only the owner can update. + """ + record = session.exec( + select(UserHouseholdAssociation).where( + UserHouseholdAssociation.id == association_id, + UserHouseholdAssociation.user_id == user_id, + ) + ).first() + if not record: + raise HTTPException( + status_code=404, + detail=f"Association {association_id} not found", + ) + update_data = body.model_dump(exclude_unset=True) + for key, value in update_data.items(): + setattr(record, key, value) + record.updated_at = datetime.now(timezone.utc) + session.add(record) + session.commit() + session.refresh(record) + return record + + +@router.delete("/{association_id}", status_code=204) +def delete_association( + association_id: UUID, + user_id: UUID = Query(..., description="User ID for ownership verification"), + session: Session = Depends(get_session), +): + """Delete a user-household association. + + Requires user_id to verify ownership - only the owner can delete. 
+    """
+    record = session.exec(
+        select(UserHouseholdAssociation).where(
+            UserHouseholdAssociation.id == association_id,
+            UserHouseholdAssociation.user_id == user_id,
+        )
+    ).first()
+    if not record:
+        raise HTTPException(
+            status_code=404,
+            detail=f"Association {association_id} not found",
+        )
+    session.delete(record)
+    session.commit()
diff --git a/src/policyengine_api/api/user_policies.py b/src/policyengine_api/api/user_policies.py
new file mode 100644
index 0000000..3cc2c08
--- /dev/null
+++ b/src/policyengine_api/api/user_policies.py
@@ -0,0 +1,147 @@
+"""User-policy association endpoints.
+
+Associates users with policies they've saved/created. This enables users to
+maintain a list of their policies across sessions without duplicating the
+underlying policy data.
+
+Note: user_id is a client-generated UUID (via crypto.randomUUID()) stored in
+the browser's localStorage. It is NOT validated against a users table, allowing
+anonymous users to save policies without authentication.
+"""
+
+from datetime import datetime, timezone
+from uuid import UUID
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from sqlmodel import Session, select
+
+from policyengine_api.config.constants import CountryId
+from policyengine_api.models import (
+    Policy,
+    UserPolicy,
+    UserPolicyCreate,
+    UserPolicyRead,
+    UserPolicyUpdate,
+)
+from policyengine_api.services.database import get_session
+
+router = APIRouter(prefix="/user-policies", tags=["user-policies"])
+
+
+@router.post("/", response_model=UserPolicyRead, status_code=201)
+def create_user_policy(
+    user_policy: UserPolicyCreate,
+    session: Session = Depends(get_session),
+):
+    """Create a new user-policy association.
+
+    Associates a user with a policy, allowing them to save it to their list.
+    Duplicates are allowed - users can save the same policy multiple times
+    with different labels (matching FE localStorage behavior).
+
+    Note: user_id is not validated - it's a client-generated UUID from localStorage.
+ Note: country_id is validated via Pydantic Literal type to "us" or "uk". + """ + # Validate policy exists + policy = session.get(Policy, user_policy.policy_id) + if not policy: + raise HTTPException(status_code=404, detail="Policy not found") + + # Create the association (duplicates allowed) + db_user_policy = UserPolicy.model_validate(user_policy) + session.add(db_user_policy) + session.commit() + session.refresh(db_user_policy) + return db_user_policy + + +@router.get("/", response_model=list[UserPolicyRead]) +def list_user_policies( + user_id: UUID = Query(..., description="User ID to filter by"), + country_id: CountryId | None = Query( + None, description="Filter by country ('us' or 'uk')" + ), + session: Session = Depends(get_session), +): + """List all policy associations for a user. + + Returns all policies saved by the specified user. Optionally filter by country. + Country ID is validated via Pydantic Literal type. + """ + query = select(UserPolicy).where(UserPolicy.user_id == user_id) + + if country_id: + query = query.where(UserPolicy.country_id == country_id) + + user_policies = session.exec(query).all() + return user_policies + + +@router.get("/{user_policy_id}", response_model=UserPolicyRead) +def get_user_policy( + user_policy_id: UUID, + session: Session = Depends(get_session), +): + """Get a specific user-policy association by ID.""" + user_policy = session.get(UserPolicy, user_policy_id) + if not user_policy: + raise HTTPException(status_code=404, detail="User-policy association not found") + return user_policy + + +@router.patch("/{user_policy_id}", response_model=UserPolicyRead) +def update_user_policy( + user_policy_id: UUID, + updates: UserPolicyUpdate, + user_id: UUID = Query(..., description="User ID for ownership verification"), + session: Session = Depends(get_session), +): + """Update a user-policy association (e.g., rename label). + + Requires user_id to verify ownership - only the owner can update. 
+ """ + user_policy = session.exec( + select(UserPolicy).where( + UserPolicy.id == user_policy_id, + UserPolicy.user_id == user_id, + ) + ).first() + if not user_policy: + raise HTTPException(status_code=404, detail="User-policy association not found") + + # Apply updates + update_data = updates.model_dump(exclude_unset=True) + for key, value in update_data.items(): + setattr(user_policy, key, value) + + # Update timestamp + user_policy.updated_at = datetime.now(timezone.utc) + + session.add(user_policy) + session.commit() + session.refresh(user_policy) + return user_policy + + +@router.delete("/{user_policy_id}", status_code=204) +def delete_user_policy( + user_policy_id: UUID, + user_id: UUID = Query(..., description="User ID for ownership verification"), + session: Session = Depends(get_session), +): + """Delete a user-policy association. + + This only removes the association, not the underlying policy. + Requires user_id to verify ownership - only the owner can delete. + """ + user_policy = session.exec( + select(UserPolicy).where( + UserPolicy.id == user_policy_id, + UserPolicy.user_id == user_id, + ) + ).first() + if not user_policy: + raise HTTPException(status_code=404, detail="User-policy association not found") + + session.delete(user_policy) + session.commit() diff --git a/src/policyengine_api/api/user_report_associations.py b/src/policyengine_api/api/user_report_associations.py new file mode 100644 index 0000000..5428cd1 --- /dev/null +++ b/src/policyengine_api/api/user_report_associations.py @@ -0,0 +1,500 @@ +"""User-report association endpoints. + +Associates users with reports they've created. This enables users to +maintain a list of their reports across sessions without duplicating +the underlying report data. + +Note: user_id is a client-generated UUID (via crypto.randomUUID()) stored in +the browser's localStorage. It is NOT validated against a users table, allowing +anonymous users to save reports without authentication. 
+"""
+
+from datetime import datetime, timezone
+from uuid import UUID
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from pydantic import BaseModel
+from sqlalchemy.orm import selectinload
+from sqlmodel import Session, select
+
+from policyengine_api.api.analysis import (
+    EconomicImpactResponse,
+    RegionInfo,
+    SimulationInfo,
+)
+from policyengine_api.api.analysis import (
+    _build_response as build_economic_response,
+)
+from policyengine_api.api.household_analysis import (
+    HouseholdImpactResponse,
+    build_household_response,
+)
+from policyengine_api.api.households import _to_read as household_to_read
+from policyengine_api.api.policies import _policy_to_read
+from policyengine_api.config.constants import CountryId
+from policyengine_api.models import (
+    Household,
+    HouseholdRead,
+    ParameterValue,
+    Policy,
+    PolicyRead,
+    Region,
+    Report,
+    ReportRead,
+    Simulation,
+    UserReportAssociation,
+    UserReportAssociationCreate,
+    UserReportAssociationRead,
+    UserReportAssociationUpdate,
+)
+from policyengine_api.services.database import get_session
+
+router = APIRouter(prefix="/user-reports", tags=["user-reports"])
+
+
+@router.post("/", response_model=UserReportAssociationRead, status_code=201)
+def create_user_report(
+    body: UserReportAssociationCreate,
+    session: Session = Depends(get_session),
+):
+    """Create a new user-report association.
+
+    Associates a user with a report, allowing them to save it to their list.
+    Duplicates are allowed - users can save the same report multiple times
+    with different labels.
+ """ + report = session.get(Report, body.report_id) + if not report: + raise HTTPException(status_code=404, detail="Report not found") + + record = UserReportAssociation.model_validate(body) + session.add(record) + session.commit() + session.refresh(record) + return record + + +@router.get("/", response_model=list[UserReportAssociationRead]) +def list_user_reports( + user_id: UUID = Query(..., description="User ID to filter by"), + country_id: CountryId | None = Query( + None, description="Filter by country ('us' or 'uk')" + ), + session: Session = Depends(get_session), +): + """List all report associations for a user. + + Returns all reports saved by the specified user. Optionally filter by country. + """ + query = select(UserReportAssociation).where( + UserReportAssociation.user_id == user_id + ) + + if country_id: + query = query.where(UserReportAssociation.country_id == country_id) + + return session.exec(query).all() + + +@router.get("/{user_report_id}", response_model=UserReportAssociationRead) +def get_user_report( + user_report_id: UUID, + session: Session = Depends(get_session), +): + """Get a specific user-report association by ID.""" + record = session.get(UserReportAssociation, user_report_id) + if not record: + raise HTTPException(status_code=404, detail="User-report association not found") + return record + + +@router.patch("/{user_report_id}", response_model=UserReportAssociationRead) +def update_user_report( + user_report_id: UUID, + updates: UserReportAssociationUpdate, + user_id: UUID = Query(..., description="User ID for ownership verification"), + session: Session = Depends(get_session), +): + """Update a user-report association (e.g., rename label or update last_run_at). + + Requires user_id to verify ownership - only the owner can update. 
+ """ + record = session.exec( + select(UserReportAssociation).where( + UserReportAssociation.id == user_report_id, + UserReportAssociation.user_id == user_id, + ) + ).first() + if not record: + raise HTTPException(status_code=404, detail="User-report association not found") + + update_data = updates.model_dump(exclude_unset=True) + for key, value in update_data.items(): + setattr(record, key, value) + + record.updated_at = datetime.now(timezone.utc) + + session.add(record) + session.commit() + session.refresh(record) + return record + + +@router.delete("/{user_report_id}", status_code=204) +def delete_user_report( + user_report_id: UUID, + user_id: UUID = Query(..., description="User ID for ownership verification"), + session: Session = Depends(get_session), +): + """Delete a user-report association. + + This only removes the association, not the underlying report. + Requires user_id to verify ownership - only the owner can delete. + """ + record = session.exec( + select(UserReportAssociation).where( + UserReportAssociation.id == user_report_id, + UserReportAssociation.user_id == user_id, + ) + ).first() + if not record: + raise HTTPException(status_code=404, detail="User-report association not found") + + session.delete(record) + session.commit() + + +# --------------------------------------------------------------------------- +# GET /user-reports/{user_report_id}/full — read-only composite endpoint +# --------------------------------------------------------------------------- + + +class UserReportFullResponse(BaseModel): + """Complete user-report data in a single response. + + Read-only: does NOT trigger computation. If the report hasn't been + run yet, status will be "pending" and result fields will be null. 
+ """ + + # Association metadata + id: UUID + user_id: UUID + report_id: UUID + country_id: str + label: str | None + created_at: datetime + last_run_at: datetime | None + + # Report + report: ReportRead + + # Simulations (metadata only) + baseline_simulation: SimulationInfo | None = None + reform_simulation: SimulationInfo | None = None + + # Policies with parameter values (null = current law) + baseline_policy: PolicyRead | None = None + reform_policy: PolicyRead | None = None + + # Population (household-type reports only) + household: HouseholdRead | None = None + + # Region (economy-type reports only) + region: RegionInfo | None = None + + # Results — one of these is populated when status == completed + economic_impact: EconomicImpactResponse | None = None + household_impact: HouseholdImpactResponse | None = None + + +@router.get("/{user_report_id}/full", response_model=UserReportFullResponse) +def get_user_report_full( + user_report_id: UUID, + session: Session = Depends(get_session), +): + """Get complete user-report data in a single call. + + Assembles association metadata, report, simulations, policies (with + parameter values), household/region, and results into one response. + + Read-only: does NOT trigger computation. + """ + # 1. Load association + record = session.get(UserReportAssociation, user_report_id) + if not record: + raise HTTPException(status_code=404, detail="User-report association not found") + + # 2. Load report + report = session.get(Report, record.report_id) + if not report: + raise HTTPException(status_code=404, detail="Report not found") + + report_read = ReportRead.model_validate(report) + + # 3. 
Load simulations + baseline_sim = ( + session.get(Simulation, report.baseline_simulation_id) + if report.baseline_simulation_id + else None + ) + reform_sim = ( + session.get(Simulation, report.reform_simulation_id) + if report.reform_simulation_id + else None + ) + + baseline_sim_info = ( + SimulationInfo( + id=baseline_sim.id, + status=baseline_sim.status, + error_message=baseline_sim.error_message, + ) + if baseline_sim + else None + ) + reform_sim_info = ( + SimulationInfo( + id=reform_sim.id, + status=reform_sim.status, + error_message=reform_sim.error_message, + ) + if reform_sim + else None + ) + + # 4. Load policies with parameter values (eager-loaded) + baseline_policy_read = _load_policy_read( + baseline_sim.policy_id if baseline_sim else None, session + ) + reform_policy_read = _load_policy_read( + reform_sim.policy_id if reform_sim else None, session + ) + + # 5. Load household (for household-type reports) + household_read = None + if baseline_sim and baseline_sim.household_id: + household = session.get(Household, baseline_sim.household_id) + if household: + household_read = household_to_read(household) + + # 6. Build region info (for economy-type reports) + region_info = _build_region_info(baseline_sim, session) + + # 7. 
Build results + economic_impact = None + household_impact = None + + is_economy = report.report_type and "economy" in report.report_type + is_household = report.report_type and "household" in report.report_type + + if is_economy and baseline_sim and reform_sim: + # Look up region object for full response + region_obj = ( + session.get(Region, baseline_sim.region_id) + if baseline_sim.region_id + else None + ) + economic_impact = build_economic_response( + report, baseline_sim, reform_sim, session, region_obj + ) + elif is_household and baseline_sim: + household_impact = build_household_response( + report, baseline_sim, reform_sim, session + ) + + return UserReportFullResponse( + id=record.id, + user_id=record.user_id, + report_id=record.report_id, + country_id=record.country_id, + label=record.label, + created_at=record.created_at, + last_run_at=record.last_run_at, + report=report_read, + baseline_simulation=baseline_sim_info, + reform_simulation=reform_sim_info, + baseline_policy=baseline_policy_read, + reform_policy=reform_policy_read, + household=household_read, + region=region_info, + economic_impact=economic_impact, + household_impact=household_impact, + ) + + +def _load_policy_read(policy_id: UUID | None, session: Session) -> PolicyRead | None: + """Load a policy with eager-loaded parameter values, or return None.""" + if not policy_id: + return None + + query = ( + select(Policy) + .where(Policy.id == policy_id) + .options( + selectinload(Policy.parameter_values).selectinload(ParameterValue.parameter) + ) + ) + policy = session.exec(query).first() + if not policy: + return None + + return _policy_to_read(policy) + + +def _build_region_info( + simulation: Simulation | None, session: Session +) -> RegionInfo | None: + """Build RegionInfo from a simulation's region_id FK.""" + if not simulation or not simulation.region_id: + return None + + region = session.get(Region, simulation.region_id) + if not region: + return None + + return RegionInfo( + 
code=region.code, + label=region.label, + region_type=region.region_type, + requires_filter=region.requires_filter, + filter_field=region.filter_field, + filter_value=region.filter_value, + ) + + +# --------------------------------------------------------------------------- +# GET /reports/{report_id}/full — report-level composite endpoint +# --------------------------------------------------------------------------- + +reports_router = APIRouter(prefix="/reports", tags=["reports"]) + + +class ReportFullResponse(BaseModel): + """Complete report data in a single response. + + Read-only: does NOT trigger computation. + Like UserReportFullResponse but keyed by report_id instead of + user-report association ID. + """ + + # Report + report: ReportRead + + # Simulations (metadata only) + baseline_simulation: SimulationInfo | None = None + reform_simulation: SimulationInfo | None = None + + # Policies with parameter values (null = current law) + baseline_policy: PolicyRead | None = None + reform_policy: PolicyRead | None = None + + # Population (household-type reports only) + household: HouseholdRead | None = None + + # Region (economy-type reports only) + region: RegionInfo | None = None + + # Results — one of these is populated when status == completed + economic_impact: EconomicImpactResponse | None = None + household_impact: HouseholdImpactResponse | None = None + + +@reports_router.get("/{report_id}/full", response_model=ReportFullResponse) +def get_report_full( + report_id: UUID, + session: Session = Depends(get_session), +): + """Get complete report data in a single call. + + Assembles report, simulations, policies (with parameter values), + household/region, and results into one response. + + Read-only: does NOT trigger computation. 
+ """ + report = session.get(Report, report_id) + if not report: + raise HTTPException(status_code=404, detail="Report not found") + + report_read = ReportRead.model_validate(report) + + # Load simulations + baseline_sim = ( + session.get(Simulation, report.baseline_simulation_id) + if report.baseline_simulation_id + else None + ) + reform_sim = ( + session.get(Simulation, report.reform_simulation_id) + if report.reform_simulation_id + else None + ) + + baseline_sim_info = ( + SimulationInfo( + id=baseline_sim.id, + status=baseline_sim.status, + error_message=baseline_sim.error_message, + ) + if baseline_sim + else None + ) + reform_sim_info = ( + SimulationInfo( + id=reform_sim.id, + status=reform_sim.status, + error_message=reform_sim.error_message, + ) + if reform_sim + else None + ) + + # Load policies with parameter values + baseline_policy_read = _load_policy_read( + baseline_sim.policy_id if baseline_sim else None, session + ) + reform_policy_read = _load_policy_read( + reform_sim.policy_id if reform_sim else None, session + ) + + # Load household (for household-type reports) + household_read = None + if baseline_sim and baseline_sim.household_id: + household = session.get(Household, baseline_sim.household_id) + if household: + household_read = household_to_read(household) + + # Build region info (for economy-type reports) + region_info = _build_region_info(baseline_sim, session) + + # Build results + economic_impact = None + household_impact = None + + is_economy = report.report_type and "economy" in report.report_type + is_household = report.report_type and "household" in report.report_type + + if is_economy and baseline_sim and reform_sim: + region_obj = ( + session.get(Region, baseline_sim.region_id) + if baseline_sim.region_id + else None + ) + economic_impact = build_economic_response( + report, baseline_sim, reform_sim, session, region_obj + ) + elif is_household and baseline_sim: + household_impact = build_household_response( + report, baseline_sim, 
reform_sim, session
+        )
+
+    return ReportFullResponse(
+        report=report_read,
+        baseline_simulation=baseline_sim_info,
+        reform_simulation=reform_sim_info,
+        baseline_policy=baseline_policy_read,
+        reform_policy=reform_policy_read,
+        household=household_read,
+        region=region_info,
+        economic_impact=economic_impact,
+        household_impact=household_impact,
+    )
diff --git a/src/policyengine_api/api/user_simulation_associations.py b/src/policyengine_api/api/user_simulation_associations.py
new file mode 100644
index 0000000..2341d91
--- /dev/null
+++ b/src/policyengine_api/api/user_simulation_associations.py
@@ -0,0 +1,146 @@
+"""User-simulation association endpoints.
+
+Associates users with simulations they've run. This enables users to
+maintain a list of their simulations across sessions without duplicating
+the underlying simulation data.
+
+Note: user_id is a client-generated UUID (via crypto.randomUUID()) stored in
+the browser's localStorage. It is NOT validated against a users table, allowing
+anonymous users to save simulations without authentication.
+"""
+
+from datetime import datetime, timezone
+from uuid import UUID
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from sqlmodel import Session, select
+
+from policyengine_api.config.constants import CountryId
+from policyengine_api.models import (
+    Simulation,
+    UserSimulationAssociation,
+    UserSimulationAssociationCreate,
+    UserSimulationAssociationRead,
+    UserSimulationAssociationUpdate,
+)
+from policyengine_api.services.database import get_session
+
+router = APIRouter(prefix="/user-simulations", tags=["user-simulations"])
+
+
+@router.post("/", response_model=UserSimulationAssociationRead, status_code=201)
+def create_user_simulation(
+    body: UserSimulationAssociationCreate,
+    session: Session = Depends(get_session),
+):
+    """Create a new user-simulation association.
+
+    Associates a user with a simulation, allowing them to save it to their list.
+ Duplicates are allowed - users can save the same simulation multiple times + with different labels. + """ + simulation = session.get(Simulation, body.simulation_id) + if not simulation: + raise HTTPException(status_code=404, detail="Simulation not found") + + record = UserSimulationAssociation.model_validate(body) + session.add(record) + session.commit() + session.refresh(record) + return record + + +@router.get("/", response_model=list[UserSimulationAssociationRead]) +def list_user_simulations( + user_id: UUID = Query(..., description="User ID to filter by"), + country_id: CountryId | None = Query( + None, description="Filter by country ('us' or 'uk')" + ), + session: Session = Depends(get_session), +): + """List all simulation associations for a user. + + Returns all simulations saved by the specified user. Optionally filter by country. + """ + query = select(UserSimulationAssociation).where( + UserSimulationAssociation.user_id == user_id + ) + + if country_id: + query = query.where(UserSimulationAssociation.country_id == country_id) + + return session.exec(query).all() + + +@router.get("/{user_simulation_id}", response_model=UserSimulationAssociationRead) +def get_user_simulation( + user_simulation_id: UUID, + session: Session = Depends(get_session), +): + """Get a specific user-simulation association by ID.""" + record = session.get(UserSimulationAssociation, user_simulation_id) + if not record: + raise HTTPException( + status_code=404, detail="User-simulation association not found" + ) + return record + + +@router.patch("/{user_simulation_id}", response_model=UserSimulationAssociationRead) +def update_user_simulation( + user_simulation_id: UUID, + updates: UserSimulationAssociationUpdate, + user_id: UUID = Query(..., description="User ID for ownership verification"), + session: Session = Depends(get_session), +): + """Update a user-simulation association (e.g., rename label). + + Requires user_id to verify ownership - only the owner can update. 
+ """ + record = session.exec( + select(UserSimulationAssociation).where( + UserSimulationAssociation.id == user_simulation_id, + UserSimulationAssociation.user_id == user_id, + ) + ).first() + if not record: + raise HTTPException( + status_code=404, detail="User-simulation association not found" + ) + + update_data = updates.model_dump(exclude_unset=True) + for key, value in update_data.items(): + setattr(record, key, value) + + record.updated_at = datetime.now(timezone.utc) + + session.add(record) + session.commit() + session.refresh(record) + return record + + +@router.delete("/{user_simulation_id}", status_code=204) +def delete_user_simulation( + user_simulation_id: UUID, + user_id: UUID = Query(..., description="User ID for ownership verification"), + session: Session = Depends(get_session), +): + """Delete a user-simulation association. + + This only removes the association, not the underlying simulation. + Requires user_id to verify ownership - only the owner can delete. + """ + record = session.exec( + select(UserSimulationAssociation).where( + UserSimulationAssociation.id == user_simulation_id, + UserSimulationAssociation.user_id == user_id, + ) + ).first() + if not record: + raise HTTPException( + status_code=404, detail="User-simulation association not found" + ) + + session.delete(record) + session.commit() diff --git a/src/policyengine_api/api/variables.py b/src/policyengine_api/api/variables.py index d660b1b..04aa512 100644 --- a/src/policyengine_api/api/variables.py +++ b/src/policyengine_api/api/variables.py @@ -9,8 +9,10 @@ from uuid import UUID from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel from sqlmodel import Session, select +from policyengine_api.config.constants import COUNTRY_MODEL_NAMES, CountryId from policyengine_api.models import ( TaxBenefitModel, TaxBenefitModelVersion, @@ -56,9 +58,9 @@ def list_variables( # Case-insensitive search using ILIKE # Note: Variables don't have a label field, only name 
and description search_pattern = f"%{search}%" - search_filter = Variable.name.ilike(search_pattern) | Variable.description.ilike( + search_filter = Variable.name.ilike( search_pattern - ) + ) | Variable.description.ilike(search_pattern) query = query.where(search_filter) variables = session.exec( @@ -67,6 +69,44 @@ def list_variables( return variables +class VariableByNameRequest(BaseModel): + """Request body for looking up variables by name.""" + + names: list[str] + country_id: CountryId + + +@router.post("/by-name", response_model=List[VariableRead]) +def get_variables_by_name( + request: VariableByNameRequest, + session: Session = Depends(get_session), +): + """Look up variables by their exact names. + + Given a list of variable names (e.g. "employment_income", "income_tax"), + returns the full metadata for each matching variable. Names that don't + match any variable are silently omitted from the response. + + Use this to fetch metadata for a known set of variables (e.g. variables + used in a household builder or report output) without loading the entire + variable catalog. 
+ """ + if not request.names: + return [] + + model_name = COUNTRY_MODEL_NAMES[request.country_id] + query = ( + select(Variable) + .join(TaxBenefitModelVersion) + .join(TaxBenefitModel) + .where(TaxBenefitModel.name == model_name) + .where(Variable.name.in_(request.names)) + .order_by(Variable.name) + ) + + return session.exec(query).all() + + @router.get("/{variable_id}", response_model=VariableRead) def get_variable(variable_id: UUID, session: Session = Depends(get_session)): """Get a specific variable.""" diff --git a/src/policyengine_api/config/constants.py b/src/policyengine_api/config/constants.py new file mode 100644 index 0000000..a39d827 --- /dev/null +++ b/src/policyengine_api/config/constants.py @@ -0,0 +1,12 @@ +"""Shared constants for the PolicyEngine API.""" + +from typing import Literal + +# Countries supported by the API +CountryId = Literal["us", "uk"] + +# Mapping from country ID to tax-benefit model name in the database +COUNTRY_MODEL_NAMES: dict[str, str] = { + "uk": "policyengine-uk", + "us": "policyengine-us", +} diff --git a/src/policyengine_api/config/settings.py b/src/policyengine_api/config/settings.py index 76a1ab1..83d6a89 100644 --- a/src/policyengine_api/config/settings.py +++ b/src/policyengine_api/config/settings.py @@ -38,12 +38,26 @@ class Settings(BaseSettings): agent_use_modal: bool = False policyengine_api_url: str = "https://v2.api.policyengine.org" + # Modal + modal_environment: str = "testing" + @property def database_url(self) -> str: - """Get database URL from Supabase.""" + """Get database URL from Supabase. + + For local development, the database runs on port 54322 (not 54321 which is the API). + Use supabase_db_url to override, or rely on the default local URL. 
+ """ + if self.supabase_db_url: + return self.supabase_db_url + + # For local development, default to the standard Supabase local DB port + if "localhost" in self.supabase_url or "127.0.0.1" in self.supabase_url: + return "postgresql://postgres:postgres@127.0.0.1:54322/postgres" + + # For remote Supabase, construct URL from API URL (usually need supabase_db_url set) return ( - self.supabase_db_url - or self.supabase_url.replace( + self.supabase_url.replace( "http://", "postgresql://postgres:postgres@" ).replace("https://", "postgresql://postgres:postgres@") + "/postgres" diff --git a/src/policyengine_api/modal_app.py b/src/policyengine_api/modal_app.py index 1aa8119..2442db3 100644 --- a/src/policyengine_api/modal_app.py +++ b/src/policyengine_api/modal_app.py @@ -18,16 +18,17 @@ # Cache bust: 2026-01-12-v7-service-role-key-fix base_image = ( modal.Image.debian_slim(python_version="3.13") - .apt_install("libhdf5-dev") + .apt_install("libhdf5-dev", "git") .pip_install("uv") .run_commands( "uv pip install --system --upgrade " - "policyengine>=3.1.15 " + "git+https://github.com/PolicyEngine/policyengine.py.git@app-v2-migration " "sqlmodel>=0.0.22 " "psycopg2-binary>=2.9.10 " "supabase>=2.10.0 " "rich>=13.9.4 " "logfire[httpx]>=3.0.0 " + "pydantic-settings>=2.0.0 " "tables>=3.10.0" # pytables - required for HDF5 dataset operations ) # Include the policyengine_api models package (copy=True allows subsequent build steps) @@ -242,13 +243,13 @@ def simulate_household_uk( engine = create_engine(database_url) try: - from policyengine.core import Simulation from microdf import MicroDataFrame + from policyengine.core import Simulation from policyengine.tax_benefit_models.uk import uk_latest from policyengine.tax_benefit_models.uk.datasets import ( PolicyEngineUKDataset, + UKYearData, ) - from policyengine.tax_benefit_models.uk.datasets import UKYearData n_people = len(people) n_benunits = max(1, len(benunit)) @@ -487,13 +488,13 @@ def simulate_household_us( engine = 
create_engine(database_url) try: - from policyengine.core import Simulation from microdf import MicroDataFrame + from policyengine.core import Simulation from policyengine.tax_benefit_models.us import us_latest from policyengine.tax_benefit_models.us.datasets import ( PolicyEngineUSDataset, + USYearData, ) - from policyengine.tax_benefit_models.us.datasets import USYearData n_people = len(people) n_households = max(1, len(household)) @@ -771,7 +772,9 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N with logfire.span("simulate_economy_uk", simulation_id=simulation_id): database_url = get_database_url() supabase_url = os.environ["SUPABASE_URL"] - supabase_key = os.environ["SUPABASE_KEY"] + supabase_key = os.environ.get( + "SUPABASE_SERVICE_KEY", os.environ["SUPABASE_KEY"] + ) storage_bucket = os.environ.get("STORAGE_BUCKET", "datasets") engine = create_engine(database_url) @@ -841,25 +844,30 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N tax_benefit_model_version=pe_model_version, policy=policy, dynamic=dynamic, + filter_field=simulation.filter_field, + filter_value=simulation.filter_value, ) pe_sim.ensure() # Save output dataset with logfire.span("save_output_dataset"): + from policyengine_api.services.storage import ( + output_filepath, + ) from supabase import create_client - output_filename = f"output_{simulation_id}.h5" - output_path = f"/tmp/{output_filename}" + output_storage_path = output_filepath(simulation_id) + output_local_path = f"/tmp/output_{simulation_id}.h5" # Set filepath and save - pe_sim.output_dataset.filepath = output_path + pe_sim.output_dataset.filepath = output_local_path pe_sim.output_dataset.save() # Upload to Supabase storage supabase = create_client(supabase_url, supabase_key) - with open(output_path, "rb") as f: + with open(output_local_path, "rb") as f: supabase.storage.from_(storage_bucket).upload( - output_filename, + output_storage_path, f, { "content-type": 
"application/octet-stream", @@ -871,7 +879,7 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N output_dataset = Dataset( name=f"Output: {dataset.name}", description=f"Output from simulation {simulation_id}", - filepath=output_filename, + filepath=output_storage_path, year=dataset.year, is_output_dataset=True, tax_benefit_model_id=dataset.tax_benefit_model_id, @@ -937,7 +945,9 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N with logfire.span("simulate_economy_us", simulation_id=simulation_id): database_url = get_database_url() supabase_url = os.environ["SUPABASE_URL"] - supabase_key = os.environ["SUPABASE_KEY"] + supabase_key = os.environ.get( + "SUPABASE_SERVICE_KEY", os.environ["SUPABASE_KEY"] + ) storage_bucket = os.environ.get("STORAGE_BUCKET", "datasets") engine = create_engine(database_url) @@ -1007,25 +1017,30 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N tax_benefit_model_version=pe_model_version, policy=policy, dynamic=dynamic, + filter_field=simulation.filter_field, + filter_value=simulation.filter_value, ) pe_sim.ensure() # Save output dataset with logfire.span("save_output_dataset"): + from policyengine_api.services.storage import ( + output_filepath, + ) from supabase import create_client - output_filename = f"output_{simulation_id}.h5" - output_path = f"/tmp/{output_filename}" + output_storage_path = output_filepath(simulation_id) + output_local_path = f"/tmp/output_{simulation_id}.h5" # Set filepath and save - pe_sim.output_dataset.filepath = output_path + pe_sim.output_dataset.filepath = output_local_path pe_sim.output_dataset.save() # Upload to Supabase storage supabase = create_client(supabase_url, supabase_key) - with open(output_path, "rb") as f: + with open(output_local_path, "rb") as f: supabase.storage.from_(storage_bucket).upload( - output_filename, + output_storage_path, f, { "content-type": "application/octet-stream", @@ -1037,7 
+1052,7 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N output_dataset = Dataset( name=f"Output: {dataset.name}", description=f"Output from simulation {simulation_id}", - filepath=output_filename, + filepath=output_storage_path, year=dataset.year, is_output_dataset=True, tax_benefit_model_id=dataset.tax_benefit_model_id, @@ -1073,7 +1088,7 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N ), {"sim_id": simulation_id, "error": str(e)[:1000]}, ) - session.commit() + session.commit() except Exception as db_error: logfire.error("Failed to update DB", error=str(db_error)) raise @@ -1106,27 +1121,24 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None: database_url = get_database_url() supabase_url = os.environ["SUPABASE_URL"] - supabase_key = os.environ["SUPABASE_KEY"] + supabase_key = os.environ.get( + "SUPABASE_SERVICE_KEY", os.environ["SUPABASE_KEY"] + ) storage_bucket = os.environ.get("STORAGE_BUCKET", "datasets") - # Debug: log the key role - import base64 - import json - try: - payload = supabase_key.split('.')[1] - payload += '=' * (4 - len(payload) % 4) - decoded = json.loads(base64.urlsafe_b64decode(payload)) - logfire.info("Supabase key info", role=decoded.get('role', 'unknown')) - except Exception as e: - logfire.warn("Could not decode key", error=str(e)) - engine = create_engine(database_url) try: # Import models inline from policyengine_api.models import ( + BudgetSummary, + ConstituencyImpact, Dataset, DecileImpact, + Inequality, + IntraDecileImpact, + LocalAuthorityImpact, + Poverty, ProgramStatistics, Report, ReportStatus, @@ -1168,6 +1180,12 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None: # Import policyengine from policyengine.core import Simulation as PESimulation from policyengine.outputs import DecileImpact as PEDecileImpact + from policyengine.outputs.aggregate import ( + Aggregate as PEAggregate, + ) + from 
policyengine.outputs.aggregate import ( + AggregateType as PEAggregateType, + ) from policyengine.tax_benefit_models.uk import uk_latest from policyengine.tax_benefit_models.uk.datasets import ( PolicyEngineUKDataset, @@ -1213,6 +1231,8 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None: tax_benefit_model_version=pe_model_version, policy=baseline_policy, dynamic=baseline_dynamic, + filter_field=baseline_sim.filter_field, + filter_value=baseline_sim.filter_value, ) pe_baseline_sim.ensure() @@ -1222,6 +1242,8 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None: tax_benefit_model_version=pe_model_version, policy=reform_policy, dynamic=reform_dynamic, + filter_field=reform_sim.filter_field, + filter_value=reform_sim.filter_value, ) pe_reform_sim.ensure() @@ -1250,20 +1272,21 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None: pass # Variable may not exist in model # Save output datasets for both simulations + from policyengine_api.services.storage import output_filepath from supabase import create_client supabase = create_client(supabase_url, supabase_key) # Save baseline output with logfire.span("save_baseline_output"): - baseline_output_filename = f"output_{baseline_sim.id}.h5" - baseline_output_path = f"/tmp/{baseline_output_filename}" - pe_baseline_sim.output_dataset.filepath = baseline_output_path + baseline_storage_path = output_filepath(str(baseline_sim.id)) + baseline_local_path = f"/tmp/output_{baseline_sim.id}.h5" + pe_baseline_sim.output_dataset.filepath = baseline_local_path pe_baseline_sim.output_dataset.save() - with open(baseline_output_path, "rb") as f: + with open(baseline_local_path, "rb") as f: supabase.storage.from_(storage_bucket).upload( - baseline_output_filename, + baseline_storage_path, f, { "content-type": "application/octet-stream", @@ -1274,7 +1297,7 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None: 
baseline_output_dataset = Dataset( name=f"Output: {dataset.name} (baseline)", description=f"Output from baseline simulation {baseline_sim.id}", - filepath=baseline_output_filename, + filepath=baseline_storage_path, year=dataset.year, is_output_dataset=True, tax_benefit_model_id=dataset.tax_benefit_model_id, @@ -1286,14 +1309,14 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None: # Save reform output with logfire.span("save_reform_output"): - reform_output_filename = f"output_{reform_sim.id}.h5" - reform_output_path = f"/tmp/{reform_output_filename}" - pe_reform_sim.output_dataset.filepath = reform_output_path + reform_storage_path = output_filepath(str(reform_sim.id)) + reform_local_path = f"/tmp/output_{reform_sim.id}.h5" + pe_reform_sim.output_dataset.filepath = reform_local_path pe_reform_sim.output_dataset.save() - with open(reform_output_path, "rb") as f: + with open(reform_local_path, "rb") as f: supabase.storage.from_(storage_bucket).upload( - reform_output_filename, + reform_storage_path, f, { "content-type": "application/octet-stream", @@ -1304,7 +1327,7 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None: reform_output_dataset = Dataset( name=f"Output: {dataset.name} (reform)", description=f"Output from reform simulation {reform_sim.id}", - filepath=reform_output_filename, + filepath=reform_storage_path, year=dataset.year, is_output_dataset=True, tax_benefit_model_id=dataset.tax_benefit_model_id, @@ -1390,6 +1413,339 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None: except KeyError: pass # Variable not in model, skip silently + # Calculate poverty rates for baseline and reform + from policyengine.outputs.poverty import ( + calculate_uk_poverty_by_age, + calculate_uk_poverty_by_gender, + calculate_uk_poverty_rates, + ) + + for pe_sim, db_sim in [ + (pe_baseline_sim, baseline_sim), + (pe_reform_sim, reform_sim), + ]: + poverty_results = 
calculate_uk_poverty_rates(pe_sim) + for pov in poverty_results.outputs: + poverty_record = Poverty( + simulation_id=db_sim.id, + report_id=report.id, + poverty_type=pov.poverty_type, + entity=pov.entity, + filter_variable=pov.filter_variable, + headcount=pov.headcount, + total_population=pov.total_population, + rate=pov.rate, + ) + session.add(poverty_record) + + # Calculate poverty rates by age group + for pe_sim, db_sim in [ + (pe_baseline_sim, baseline_sim), + (pe_reform_sim, reform_sim), + ]: + age_poverty_results = calculate_uk_poverty_by_age(pe_sim) + for pov in age_poverty_results.outputs: + poverty_record = Poverty( + simulation_id=db_sim.id, + report_id=report.id, + poverty_type=pov.poverty_type, + entity=pov.entity, + filter_variable=pov.filter_variable, + headcount=pov.headcount, + total_population=pov.total_population, + rate=pov.rate, + ) + session.add(poverty_record) + + # Calculate poverty rates by gender + for pe_sim, db_sim in [ + (pe_baseline_sim, baseline_sim), + (pe_reform_sim, reform_sim), + ]: + gender_poverty_results = calculate_uk_poverty_by_gender(pe_sim) + for pov in gender_poverty_results.outputs: + poverty_record = Poverty( + simulation_id=db_sim.id, + report_id=report.id, + poverty_type=pov.poverty_type, + entity=pov.entity, + filter_variable=pov.filter_variable, + headcount=pov.headcount, + total_population=pov.total_population, + rate=pov.rate, + ) + session.add(poverty_record) + + # Calculate inequality for baseline and reform + from policyengine.outputs.inequality import ( + calculate_uk_inequality, + ) + + for pe_sim, db_sim in [ + (pe_baseline_sim, baseline_sim), + (pe_reform_sim, reform_sim), + ]: + ineq = calculate_uk_inequality(pe_sim) + ineq.run() + inequality_record = Inequality( + simulation_id=db_sim.id, + report_id=report.id, + income_variable=ineq.income_variable, + entity=ineq.entity, + gini=ineq.gini, + top_10_share=ineq.top_10_share, + top_1_share=ineq.top_1_share, + bottom_50_share=ineq.bottom_50_share, + ) + 
session.add(inequality_record) + + # Calculate budget summary aggregates + # UK budget variables — household-level aggregates + uk_budget_variables = { + "household_tax": "household", + "household_benefits": "household", + "household_net_income": "household", + } + PEAggregate.model_rebuild( + _types_namespace={"Simulation": PESimulation} + ) + for var_name, entity in uk_budget_variables.items(): + baseline_agg = PEAggregate( + simulation=pe_baseline_sim, + variable=var_name, + aggregate_type=PEAggregateType.SUM, + entity=entity, + ) + baseline_agg.run() + reform_agg = PEAggregate( + simulation=pe_reform_sim, + variable=var_name, + aggregate_type=PEAggregateType.SUM, + entity=entity, + ) + reform_agg.run() + budget_record = BudgetSummary( + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + variable_name=var_name, + entity=entity, + baseline_total=float(baseline_agg.result), + reform_total=float(reform_agg.result), + change=float(reform_agg.result - baseline_agg.result), + ) + session.add(budget_record) + + # Household count: bypass Aggregate and compute directly + # from raw numpy values. Using Aggregate(SUM) on + # household_weight would compute sum(weight * weight) + # because MicroSeries.sum() applies weights automatically + # — it's unclear whether Aggregate can be used correctly + # for summing the weight column itself. 
+ baseline_hh_count = float( + pe_baseline_sim.output_dataset.data.household[ + "household_weight" + ].values.sum() + ) + reform_hh_count = float( + pe_reform_sim.output_dataset.data.household[ + "household_weight" + ].values.sum() + ) + budget_record = BudgetSummary( + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + variable_name="household_count_total", + entity="household", + baseline_total=baseline_hh_count, + reform_total=reform_hh_count, + change=reform_hh_count - baseline_hh_count, + ) + session.add(budget_record) + + # Calculate intra-decile impact + from policyengine.outputs.intra_decile_impact import ( + compute_intra_decile_impacts as pe_compute_intra_decile, + ) + + intra_decile_results = pe_compute_intra_decile( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + income_variable="household_net_income", + entity="household", + ) + for r in intra_decile_results.outputs: + record = IntraDecileImpact( + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + decile=r.decile, + lose_more_than_5pct=r.lose_more_than_5pct, + lose_less_than_5pct=r.lose_less_than_5pct, + no_change=r.no_change, + gain_less_than_5pct=r.gain_less_than_5pct, + gain_more_than_5pct=r.gain_more_than_5pct, + ) + session.add(record) + + # Calculate constituency impact + from policyengine.outputs.constituency_impact import ( + compute_uk_constituency_impacts, + ) + + try: + from policyengine_core.tools.google_cloud import ( + download as gcs_download, + ) + + weight_matrix_path = gcs_download( + gcs_bucket="policyengine-uk-data-private", + gcs_key="parliamentary_constituency_weights.h5", + ) + constituency_csv_path = gcs_download( + gcs_bucket="policyengine-uk-data-private", + gcs_key="constituencies_2024.csv", + ) + constituency_impact = compute_uk_constituency_impacts( + pe_baseline_sim, + pe_reform_sim, + weight_matrix_path=weight_matrix_path, + 
constituency_csv_path=constituency_csv_path, + ) + if constituency_impact.constituency_results: + for cr in constituency_impact.constituency_results: + record = ConstituencyImpact( + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + constituency_code=cr["constituency_code"], + constituency_name=cr["constituency_name"], + x=cr["x"], + y=cr["y"], + average_household_income_change=cr[ + "average_household_income_change" + ], + relative_household_income_change=cr[ + "relative_household_income_change" + ], + population=cr["population"], + ) + session.add(record) + except FileNotFoundError: + logfire.warning( + "Weight matrix not available, skipping constituency impact" + ) + + # Calculate local authority impact + from policyengine.outputs.local_authority_impact import ( + compute_uk_local_authority_impacts, + ) + + try: + from policyengine_core.tools.google_cloud import ( + download as gcs_download_la, + ) + + la_weight_matrix_path = gcs_download_la( + gcs_bucket="policyengine-uk-data-private", + gcs_key="local_authority_weights.h5", + ) + la_csv_path = gcs_download_la( + gcs_bucket="policyengine-uk-data-private", + gcs_key="local_authorities_2021.csv", + ) + la_impact = compute_uk_local_authority_impacts( + pe_baseline_sim, + pe_reform_sim, + weight_matrix_path=la_weight_matrix_path, + local_authority_csv_path=la_csv_path, + ) + if la_impact.local_authority_results: + for lr in la_impact.local_authority_results: + record = LocalAuthorityImpact( + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + local_authority_code=lr["local_authority_code"], + local_authority_name=lr["local_authority_name"], + x=lr["x"], + y=lr["y"], + average_household_income_change=lr[ + "average_household_income_change" + ], + relative_household_income_change=lr[ + "relative_household_income_change" + ], + population=lr["population"], + ) + session.add(record) + except FileNotFoundError: + 
logfire.warning( + "Weight matrix not available, skipping local authority impact" + ) + + # Calculate wealth decile impact (UK only) + try: + from policyengine.outputs.decile_impact import ( + DecileImpact as PEDecileImpact, + ) + + PEDecileImpact.model_rebuild( + _types_namespace={"Simulation": PESimulation} + ) + for decile_num in range(1, 11): + wealth_di = PEDecileImpact( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + income_variable="household_net_income", + decile_variable="household_wealth_decile", + entity="household", + decile=decile_num, + ) + wealth_di.run() + record = DecileImpact( + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + income_variable="household_wealth_decile", + entity="household", + decile=decile_num, + quantiles=10, + baseline_mean=wealth_di.baseline_mean, + reform_mean=wealth_di.reform_mean, + absolute_change=wealth_di.absolute_change, + relative_change=wealth_di.relative_change, + ) + session.add(record) + + # Calculate intra-wealth-decile impact + intra_wealth_results = pe_compute_intra_decile( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + income_variable="household_net_income", + decile_variable="household_wealth_decile", + entity="household", + ) + for r in intra_wealth_results.outputs: + record = IntraDecileImpact( + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + decile_type="wealth", + decile=r.decile, + lose_more_than_5pct=r.lose_more_than_5pct, + lose_less_than_5pct=r.lose_less_than_5pct, + no_change=r.no_change, + gain_less_than_5pct=r.gain_less_than_5pct, + gain_more_than_5pct=r.gain_more_than_5pct, + ) + session.add(record) + except KeyError: + logfire.warning( + "household_wealth_decile not available, skipping wealth decile impact" + ) + # Mark simulations and report as completed baseline_sim.status = SimulationStatus.COMPLETED baseline_sim.completed_at = 
datetime.now(timezone.utc) @@ -1429,7 +1785,7 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None: @app.function( image=us_image, secrets=[db_secrets, logfire_secrets], - memory=8192, + memory=24576, cpu=8, timeout=1800, ) @@ -1448,7 +1804,9 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None: with logfire.span("economy_comparison_us", job_id=job_id): database_url = get_database_url() supabase_url = os.environ["SUPABASE_URL"] - supabase_key = os.environ["SUPABASE_KEY"] + supabase_key = os.environ.get( + "SUPABASE_SERVICE_KEY", os.environ["SUPABASE_KEY"] + ) storage_bucket = os.environ.get("STORAGE_BUCKET", "datasets") engine = create_engine(database_url) @@ -1456,8 +1814,13 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None: try: # Import models inline from policyengine_api.models import ( + BudgetSummary, + CongressionalDistrictImpact, Dataset, DecileImpact, + Inequality, + IntraDecileImpact, + Poverty, ProgramStatistics, Report, ReportStatus, @@ -1492,6 +1855,12 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None: # Import policyengine from policyengine.core import Simulation as PESimulation from policyengine.outputs import DecileImpact as PEDecileImpact + from policyengine.outputs.aggregate import ( + Aggregate as PEAggregate, + ) + from policyengine.outputs.aggregate import ( + AggregateType as PEAggregateType, + ) from policyengine.tax_benefit_models.us import us_latest from policyengine.tax_benefit_models.us.datasets import ( PolicyEngineUSDataset, @@ -1535,6 +1904,8 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None: tax_benefit_model_version=pe_model_version, policy=baseline_policy, dynamic=baseline_dynamic, + filter_field=baseline_sim.filter_field, + filter_value=baseline_sim.filter_value, ) pe_baseline_sim.ensure() @@ -1544,24 +1915,51 @@ def economy_comparison_us(job_id: str, traceparent: str | None = 
None) -> None: tax_benefit_model_version=pe_model_version, policy=reform_policy, dynamic=reform_dynamic, + filter_field=reform_sim.filter_field, + filter_value=reform_sim.filter_value, ) pe_reform_sim.ensure() + # Pre-calculate key US variables to include in output + # (PolicyEngine is lazy - variables only calculated when accessed) + us_key_variables = [ + "household_net_income", + "income_tax", + "employee_payroll_tax", + "snap", + "tanf", + "ssi", + "social_security", + "household_benefits", + "household_tax", + "household_weight", + "household_count_people", + "household_income_decile", + ] + with logfire.span("precalculate_variables"): + for var in us_key_variables: + try: + pe_baseline_sim[var] + pe_reform_sim[var] + except Exception: + pass # Variable may not exist in model + # Save output datasets for both simulations + from policyengine_api.services.storage import output_filepath from supabase import create_client supabase = create_client(supabase_url, supabase_key) # Save baseline output with logfire.span("save_baseline_output"): - baseline_output_filename = f"output_{baseline_sim.id}.h5" - baseline_output_path = f"/tmp/{baseline_output_filename}" - pe_baseline_sim.output_dataset.filepath = baseline_output_path + baseline_storage_path = output_filepath(str(baseline_sim.id)) + baseline_local_path = f"/tmp/output_{baseline_sim.id}.h5" + pe_baseline_sim.output_dataset.filepath = baseline_local_path pe_baseline_sim.output_dataset.save() - with open(baseline_output_path, "rb") as f: + with open(baseline_local_path, "rb") as f: supabase.storage.from_(storage_bucket).upload( - baseline_output_filename, + baseline_storage_path, f, { "content-type": "application/octet-stream", @@ -1572,7 +1970,7 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None: baseline_output_dataset = Dataset( name=f"Output: {dataset.name} (baseline)", description=f"Output from baseline simulation {baseline_sim.id}", - filepath=baseline_output_filename, + 
filepath=baseline_storage_path, year=dataset.year, is_output_dataset=True, tax_benefit_model_id=dataset.tax_benefit_model_id, @@ -1584,14 +1982,14 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None: # Save reform output with logfire.span("save_reform_output"): - reform_output_filename = f"output_{reform_sim.id}.h5" - reform_output_path = f"/tmp/{reform_output_filename}" - pe_reform_sim.output_dataset.filepath = reform_output_path + reform_storage_path = output_filepath(str(reform_sim.id)) + reform_local_path = f"/tmp/output_{reform_sim.id}.h5" + pe_reform_sim.output_dataset.filepath = reform_local_path pe_reform_sim.output_dataset.save() - with open(reform_output_path, "rb") as f: + with open(reform_local_path, "rb") as f: supabase.storage.from_(storage_bucket).upload( - reform_output_filename, + reform_storage_path, f, { "content-type": "application/octet-stream", @@ -1602,7 +2000,7 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None: reform_output_dataset = Dataset( name=f"Output: {dataset.name} (reform)", description=f"Output from reform simulation {reform_sim.id}", - filepath=reform_output_filename, + filepath=reform_storage_path, year=dataset.year, is_output_dataset=True, tax_benefit_model_id=dataset.tax_benefit_model_id, @@ -1619,6 +2017,7 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None: baseline_simulation=pe_baseline_sim, reform_simulation=pe_reform_sim, decile=decile_num, + income_variable="household_net_income", ) di.run() @@ -1687,6 +2086,233 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None: except KeyError: pass # Variable not in model, skip silently + # Calculate poverty rates for baseline and reform + from policyengine.outputs.poverty import ( + calculate_us_poverty_by_age, + calculate_us_poverty_by_gender, + calculate_us_poverty_by_race, + calculate_us_poverty_rates, + ) + + for pe_sim, db_sim in [ + (pe_baseline_sim, 
baseline_sim), + (pe_reform_sim, reform_sim), + ]: + poverty_results = calculate_us_poverty_rates(pe_sim) + for pov in poverty_results.outputs: + poverty_record = Poverty( + simulation_id=db_sim.id, + report_id=report.id, + poverty_type=pov.poverty_type, + entity=pov.entity, + filter_variable=pov.filter_variable, + headcount=pov.headcount, + total_population=pov.total_population, + rate=pov.rate, + ) + session.add(poverty_record) + + # Calculate poverty rates by age group + for pe_sim, db_sim in [ + (pe_baseline_sim, baseline_sim), + (pe_reform_sim, reform_sim), + ]: + age_poverty_results = calculate_us_poverty_by_age(pe_sim) + for pov in age_poverty_results.outputs: + poverty_record = Poverty( + simulation_id=db_sim.id, + report_id=report.id, + poverty_type=pov.poverty_type, + entity=pov.entity, + filter_variable=pov.filter_variable, + headcount=pov.headcount, + total_population=pov.total_population, + rate=pov.rate, + ) + session.add(poverty_record) + + # Calculate poverty rates by gender + for pe_sim, db_sim in [ + (pe_baseline_sim, baseline_sim), + (pe_reform_sim, reform_sim), + ]: + gender_poverty_results = calculate_us_poverty_by_gender(pe_sim) + for pov in gender_poverty_results.outputs: + poverty_record = Poverty( + simulation_id=db_sim.id, + report_id=report.id, + poverty_type=pov.poverty_type, + entity=pov.entity, + filter_variable=pov.filter_variable, + headcount=pov.headcount, + total_population=pov.total_population, + rate=pov.rate, + ) + session.add(poverty_record) + + # Calculate poverty rates by race (US only) + for pe_sim, db_sim in [ + (pe_baseline_sim, baseline_sim), + (pe_reform_sim, reform_sim), + ]: + race_poverty_results = calculate_us_poverty_by_race(pe_sim) + for pov in race_poverty_results.outputs: + poverty_record = Poverty( + simulation_id=db_sim.id, + report_id=report.id, + poverty_type=pov.poverty_type, + entity=pov.entity, + filter_variable=pov.filter_variable, + headcount=pov.headcount, + total_population=pov.total_population, + 
rate=pov.rate, + ) + session.add(poverty_record) + + # Calculate inequality for baseline and reform + from policyengine.outputs.inequality import ( + calculate_us_inequality, + ) + + for pe_sim, db_sim in [ + (pe_baseline_sim, baseline_sim), + (pe_reform_sim, reform_sim), + ]: + ineq = calculate_us_inequality(pe_sim) + ineq.run() + inequality_record = Inequality( + simulation_id=db_sim.id, + report_id=report.id, + income_variable=ineq.income_variable, + entity=ineq.entity, + gini=ineq.gini, + top_10_share=ineq.top_10_share, + top_1_share=ineq.top_1_share, + bottom_50_share=ineq.bottom_50_share, + ) + session.add(inequality_record) + + # Calculate budget summary aggregates + # US budget variables — household-level plus state tax + us_budget_variables = { + "household_tax": "household", + "household_benefits": "household", + "household_net_income": "household", + "household_state_income_tax": "tax_unit", + } + PEAggregate.model_rebuild( + _types_namespace={"Simulation": PESimulation} + ) + for var_name, entity in us_budget_variables.items(): + baseline_agg = PEAggregate( + simulation=pe_baseline_sim, + variable=var_name, + aggregate_type=PEAggregateType.SUM, + entity=entity, + ) + baseline_agg.run() + reform_agg = PEAggregate( + simulation=pe_reform_sim, + variable=var_name, + aggregate_type=PEAggregateType.SUM, + entity=entity, + ) + reform_agg.run() + budget_record = BudgetSummary( + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + variable_name=var_name, + entity=entity, + baseline_total=float(baseline_agg.result), + reform_total=float(reform_agg.result), + change=float(reform_agg.result - baseline_agg.result), + ) + session.add(budget_record) + + # Household count: bypass Aggregate and compute directly + # from raw numpy values. 
Using Aggregate(SUM) on + # household_weight would compute sum(weight * weight) + # because MicroSeries.sum() applies weights automatically + # — it's unclear whether Aggregate can be used correctly + # for summing the weight column itself. + baseline_hh_count = float( + pe_baseline_sim.output_dataset.data.household[ + "household_weight" + ].values.sum() + ) + reform_hh_count = float( + pe_reform_sim.output_dataset.data.household[ + "household_weight" + ].values.sum() + ) + budget_record = BudgetSummary( + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + variable_name="household_count_total", + entity="household", + baseline_total=baseline_hh_count, + reform_total=reform_hh_count, + change=reform_hh_count - baseline_hh_count, + ) + session.add(budget_record) + + # Calculate intra-decile impact + from policyengine.outputs.intra_decile_impact import ( + compute_intra_decile_impacts as pe_compute_intra_decile_us, + ) + + intra_decile_results_us = pe_compute_intra_decile_us( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + income_variable="household_net_income", + entity="household", + ) + for r in intra_decile_results_us.outputs: + record = IntraDecileImpact( + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + decile=r.decile, + lose_more_than_5pct=r.lose_more_than_5pct, + lose_less_than_5pct=r.lose_less_than_5pct, + no_change=r.no_change, + gain_less_than_5pct=r.gain_less_than_5pct, + gain_more_than_5pct=r.gain_more_than_5pct, + ) + session.add(record) + + # Calculate congressional district impact + from policyengine.outputs.congressional_district_impact import ( + compute_us_congressional_district_impacts, + ) + + try: + district_impact = compute_us_congressional_district_impacts( + pe_baseline_sim, pe_reform_sim + ) + if district_impact.district_results: + for dr in district_impact.district_results: + record = 
CongressionalDistrictImpact( + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + district_geoid=dr["district_geoid"], + state_fips=dr["state_fips"], + district_number=dr["district_number"], + average_household_income_change=dr[ + "average_household_income_change" + ], + relative_household_income_change=dr[ + "relative_household_income_change" + ], + population=dr["population"], + ) + session.add(record) + except KeyError: + pass # congressional_district_geoid not in dataset + # Mark simulations and report as completed baseline_sim.status = SimulationStatus.COMPLETED baseline_sim.completed_at = datetime.now(timezone.utc) @@ -1723,6 +2349,300 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None: logfire.force_flush() +# --------------------------------------------------------------------------- +# Household impact (report-based) — called by _trigger_household_impact() +# --------------------------------------------------------------------------- + + +@app.function( + image=uk_image, + secrets=[db_secrets, logfire_secrets], + memory=4096, + cpu=4, + timeout=600, +) +def household_impact_uk(report_id: str, traceparent: str | None = None) -> None: + """Run UK household impact analysis for a report. + + Loads the Report and its Simulations from the database, runs household + calculations for each simulation, stores results, and marks the report + as completed. Called via Modal.spawn() from _trigger_household_impact(). 
+ """ + import logfire + + configure_logfire("policyengine-modal-uk", traceparent) + + try: + with logfire.span("household_impact_uk", report_id=report_id): + from datetime import datetime, timezone + from uuid import UUID + + from sqlmodel import Session, create_engine + + database_url = get_database_url() + engine = create_engine(database_url) + + try: + from policyengine_api.api.household import _calculate_household_uk + from policyengine_api.api.household_analysis import ( + _ensure_list, + _extract_policy_data, + ) + from policyengine_api.models import ( + Household, + Report, + ReportStatus, + Simulation, + SimulationStatus, + ) + + with Session(engine) as session: + report = session.get(Report, UUID(report_id)) + if not report: + raise ValueError(f"Report {report_id} not found") + + report.status = ReportStatus.RUNNING + session.add(report) + session.commit() + + # Run each simulation (baseline, then reform if present) + for sim_id in [ + report.baseline_simulation_id, + report.reform_simulation_id, + ]: + if not sim_id: + continue + + simulation = session.get(Simulation, sim_id) + if ( + not simulation + or simulation.status != SimulationStatus.PENDING + ): + continue + + household = session.get(Household, simulation.household_id) + if not household: + raise ValueError( + f"Household {simulation.household_id} not found" + ) + + # Convert policy to calculation format + policy_data = None + if simulation.policy_id: + from policyengine_api.models import Policy + + policy = session.get(Policy, simulation.policy_id) + policy_data = _extract_policy_data(policy) + + # Mark simulation as running + simulation.status = SimulationStatus.RUNNING + simulation.started_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + + try: + hh_data = household.household_data + with logfire.span( + "run_household_calculation", + simulation_id=str(sim_id), + ): + result = _calculate_household_uk( + people=hh_data.get("people", []), + 
benunit=_ensure_list(hh_data.get("benunit")), + household=_ensure_list(hh_data.get("household")), + year=household.year, + policy_data=policy_data, + ) + + simulation.household_result = result + simulation.status = SimulationStatus.COMPLETED + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + except Exception as e: + simulation.status = SimulationStatus.FAILED + simulation.error_message = str(e) + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + raise + + report.status = ReportStatus.COMPLETED + session.add(report) + session.commit() + + except Exception as e: + logfire.error( + "UK household impact failed", + report_id=report_id, + error=str(e), + ) + try: + from sqlmodel import text + + with Session(engine) as session: + session.execute( + text( + "UPDATE reports SET status = 'FAILED', " + "error_message = :error WHERE id = :report_id" + ), + {"report_id": report_id, "error": str(e)[:1000]}, + ) + session.commit() + except Exception as db_error: + logfire.error("Failed to update DB", error=str(db_error)) + raise + finally: + logfire.force_flush() + + +@app.function( + image=us_image, + secrets=[db_secrets, logfire_secrets], + memory=4096, + cpu=4, + timeout=600, +) +def household_impact_us(report_id: str, traceparent: str | None = None) -> None: + """Run US household impact analysis for a report. + + Loads the Report and its Simulations from the database, runs household + calculations for each simulation, stores results, and marks the report + as completed. Called via Modal.spawn() from _trigger_household_impact(). 
+ """ + import logfire + + configure_logfire("policyengine-modal-us", traceparent) + + try: + with logfire.span("household_impact_us", report_id=report_id): + from datetime import datetime, timezone + from uuid import UUID + + from sqlmodel import Session, create_engine + + database_url = get_database_url() + engine = create_engine(database_url) + + try: + from policyengine_api.api.household import _calculate_household_us + from policyengine_api.api.household_analysis import ( + _ensure_list, + _extract_policy_data, + ) + from policyengine_api.models import ( + Household, + Report, + ReportStatus, + Simulation, + SimulationStatus, + ) + + with Session(engine) as session: + report = session.get(Report, UUID(report_id)) + if not report: + raise ValueError(f"Report {report_id} not found") + + report.status = ReportStatus.RUNNING + session.add(report) + session.commit() + + # Run each simulation (baseline, then reform if present) + for sim_id in [ + report.baseline_simulation_id, + report.reform_simulation_id, + ]: + if not sim_id: + continue + + simulation = session.get(Simulation, sim_id) + if ( + not simulation + or simulation.status != SimulationStatus.PENDING + ): + continue + + household = session.get(Household, simulation.household_id) + if not household: + raise ValueError( + f"Household {simulation.household_id} not found" + ) + + # Convert policy to calculation format + policy_data = None + if simulation.policy_id: + from policyengine_api.models import Policy + + policy = session.get(Policy, simulation.policy_id) + policy_data = _extract_policy_data(policy) + + # Mark simulation as running + simulation.status = SimulationStatus.RUNNING + simulation.started_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + + try: + hh_data = household.household_data + with logfire.span( + "run_household_calculation", + simulation_id=str(sim_id), + ): + result = _calculate_household_us( + people=hh_data.get("people", []), + 
marital_unit=_ensure_list( + hh_data.get("marital_unit") + ), + family=_ensure_list(hh_data.get("family")), + spm_unit=_ensure_list(hh_data.get("spm_unit")), + tax_unit=_ensure_list(hh_data.get("tax_unit")), + household=_ensure_list(hh_data.get("household")), + year=household.year, + policy_data=policy_data, + ) + + simulation.household_result = result + simulation.status = SimulationStatus.COMPLETED + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + except Exception as e: + simulation.status = SimulationStatus.FAILED + simulation.error_message = str(e) + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + raise + + report.status = ReportStatus.COMPLETED + session.add(report) + session.commit() + + except Exception as e: + logfire.error( + "US household impact failed", + report_id=report_id, + error=str(e), + ) + try: + from sqlmodel import text + + with Session(engine) as session: + session.execute( + text( + "UPDATE reports SET status = 'FAILED', " + "error_message = :error WHERE id = :report_id" + ), + {"report_id": report_id, "error": str(e)[:1000]}, + ) + session.commit() + except Exception as db_error: + logfire.error("Failed to update DB", error=str(db_error)) + raise + finally: + logfire.force_flush() + + def _get_pe_policy_uk(policy_id, model_version, session): """Convert database Policy to policyengine Policy for UK.""" if policy_id is None: @@ -1836,7 +2756,9 @@ def compute_aggregate_uk(aggregate_id: str, traceparent: str | None = None) -> N with logfire.span("compute_aggregate_uk", aggregate_id=aggregate_id): database_url = get_database_url() supabase_url = os.environ["SUPABASE_URL"] - supabase_key = os.environ["SUPABASE_KEY"] + supabase_key = os.environ.get( + "SUPABASE_SERVICE_KEY", os.environ["SUPABASE_KEY"] + ) storage_bucket = os.environ.get("STORAGE_BUCKET", "datasets") engine = create_engine(database_url) @@ -1930,7 +2852,9 @@ def 
compute_aggregate_uk(aggregate_id: str, traceparent: str | None = None) -> N pe_aggregate = PEAggregate( simulation=pe_sim, variable=aggregate.variable, - aggregate_type=PEAggregateType(aggregate.aggregate_type.value), + aggregate_type=PEAggregateType( + aggregate.aggregate_type.value + ), entity=aggregate.entity, ) pe_aggregate.run() @@ -1988,7 +2912,9 @@ def compute_aggregate_us(aggregate_id: str, traceparent: str | None = None) -> N with logfire.span("compute_aggregate_us", aggregate_id=aggregate_id): database_url = get_database_url() supabase_url = os.environ["SUPABASE_URL"] - supabase_key = os.environ["SUPABASE_KEY"] + supabase_key = os.environ.get( + "SUPABASE_SERVICE_KEY", os.environ["SUPABASE_KEY"] + ) storage_bucket = os.environ.get("STORAGE_BUCKET", "datasets") engine = create_engine(database_url) @@ -2074,7 +3000,9 @@ def compute_aggregate_us(aggregate_id: str, traceparent: str | None = None) -> N pe_aggregate = PEAggregate( simulation=pe_sim, variable=aggregate.variable, - aggregate_type=PEAggregateType(aggregate.aggregate_type.value), + aggregate_type=PEAggregateType( + aggregate.aggregate_type.value + ), entity=aggregate.entity, ) pe_aggregate.run() @@ -2135,7 +3063,9 @@ def compute_change_aggregate_uk( ): database_url = get_database_url() supabase_url = os.environ["SUPABASE_URL"] - supabase_key = os.environ["SUPABASE_KEY"] + supabase_key = os.environ.get( + "SUPABASE_SERVICE_KEY", os.environ["SUPABASE_KEY"] + ) storage_bucket = os.environ.get("STORAGE_BUCKET", "datasets") engine = create_engine(database_url) @@ -2338,7 +3268,9 @@ def compute_change_aggregate_us( ): database_url = get_database_url() supabase_url = os.environ["SUPABASE_URL"] - supabase_key = os.environ["SUPABASE_KEY"] + supabase_key = os.environ.get( + "SUPABASE_SERVICE_KEY", os.environ["SUPABASE_KEY"] + ) storage_bucket = os.environ.get("STORAGE_BUCKET", "datasets") engine = create_engine(database_url) diff --git a/src/policyengine_api/models/__init__.py 
b/src/policyengine_api/models/__init__.py index 4d64c02..838bfeb 100644 --- a/src/policyengine_api/models/__init__.py +++ b/src/policyengine_api/models/__init__.py @@ -1,5 +1,10 @@ """Database models for PolicyEngine API.""" +from .budget_summary import ( + BudgetSummary, + BudgetSummaryCreate, + BudgetSummaryRead, +) from .change_aggregate import ( ChangeAggregate, ChangeAggregateCreate, @@ -7,10 +12,22 @@ ChangeAggregateStatus, ChangeAggregateType, ) +from .congressional_district_impact import ( + CongressionalDistrictImpact, + CongressionalDistrictImpactCreate, + CongressionalDistrictImpactRead, +) +from .constituency_impact import ( + ConstituencyImpact, + ConstituencyImpactCreate, + ConstituencyImpactRead, +) from .dataset import Dataset, DatasetCreate, DatasetRead from .dataset_version import DatasetVersion, DatasetVersionCreate, DatasetVersionRead from .decile_impact import DecileImpact, DecileImpactCreate, DecileImpactRead from .dynamic import Dynamic, DynamicCreate, DynamicRead +from .geographic_impact_base import GeographicImpactBase +from .household import Household, HouseholdCreate, HouseholdRead from .household_job import ( HouseholdJob, HouseholdJobCreate, @@ -18,6 +35,17 @@ HouseholdJobStatus, ) from .inequality import Inequality, InequalityCreate, InequalityRead +from .intra_decile_impact import ( + DecileType, + IntraDecileImpact, + IntraDecileImpactCreate, + IntraDecileImpactRead, +) +from .local_authority_impact import ( + LocalAuthorityImpact, + LocalAuthorityImpactCreate, + LocalAuthorityImpactRead, +) from .output import ( AggregateOutput, AggregateOutputCreate, @@ -26,16 +54,29 @@ AggregateType, ) from .parameter import Parameter, ParameterCreate, ParameterRead -from .parameter_value import ParameterValue, ParameterValueCreate, ParameterValueRead -from .policy import Policy, PolicyCreate, PolicyRead +from .parameter_value import ( + ParameterValue, + ParameterValueCreate, + ParameterValueRead, + ParameterValueWithName, +) +from .policy import 
Policy, PolicyCreate, PolicyParameterValueInput, PolicyRead from .poverty import Poverty, PovertyCreate, PovertyRead from .program_statistics import ( ProgramStatistics, ProgramStatisticsCreate, ProgramStatisticsRead, ) -from .report import Report, ReportCreate, ReportRead, ReportStatus -from .simulation import Simulation, SimulationCreate, SimulationRead, SimulationStatus +from .region import Region, RegionCreate, RegionRead, RegionType +from .region_dataset_link import RegionDatasetLink +from .report import Report, ReportCreate, ReportRead, ReportStatus, ReportType +from .simulation import ( + Simulation, + SimulationCreate, + SimulationRead, + SimulationStatus, + SimulationType, +) from .tax_benefit_model import ( TaxBenefitModel, TaxBenefitModelCreate, @@ -47,14 +88,50 @@ TaxBenefitModelVersionRead, ) from .user import User, UserCreate, UserRead +from .user_household_association import ( + UserHouseholdAssociation, + UserHouseholdAssociationCreate, + UserHouseholdAssociationRead, + UserHouseholdAssociationUpdate, +) +from .user_policy import ( + UserPolicy, + UserPolicyCreate, + UserPolicyRead, + UserPolicyUpdate, +) +from .user_report_association import ( + UserReportAssociation, + UserReportAssociationCreate, + UserReportAssociationRead, + UserReportAssociationUpdate, +) +from .user_simulation_association import ( + UserSimulationAssociation, + UserSimulationAssociationCreate, + UserSimulationAssociationRead, + UserSimulationAssociationUpdate, +) from .variable import Variable, VariableCreate, VariableRead __all__ = [ + "BudgetSummary", + "BudgetSummaryCreate", + "BudgetSummaryRead", "AggregateOutput", "AggregateOutputCreate", "AggregateOutputRead", "AggregateStatus", "AggregateType", + "CongressionalDistrictImpact", + "CongressionalDistrictImpactCreate", + "CongressionalDistrictImpactRead", + "ConstituencyImpact", + "ConstituencyImpactCreate", + "ConstituencyImpactRead", + "LocalAuthorityImpact", + "LocalAuthorityImpactCreate", + "LocalAuthorityImpactRead", 
"ChangeAggregate", "ChangeAggregateCreate", "ChangeAggregateRead", @@ -72,6 +149,9 @@ "Dynamic", "DynamicCreate", "DynamicRead", + "Household", + "HouseholdCreate", + "HouseholdRead", "HouseholdJob", "HouseholdJobCreate", "HouseholdJobRead", @@ -79,18 +159,30 @@ "Inequality", "InequalityCreate", "InequalityRead", + "DecileType", + "GeographicImpactBase", + "IntraDecileImpact", + "IntraDecileImpactCreate", + "IntraDecileImpactRead", "Parameter", "ParameterCreate", "ParameterRead", "ParameterValue", "ParameterValueCreate", "ParameterValueRead", + "ParameterValueWithName", "Policy", "PolicyCreate", + "PolicyParameterValueInput", "PolicyRead", "Poverty", "PovertyCreate", "PovertyRead", + "Region", + "RegionCreate", + "RegionDatasetLink", + "RegionRead", + "RegionType", "ProgramStatistics", "ProgramStatisticsCreate", "ProgramStatisticsRead", @@ -98,10 +190,12 @@ "ReportCreate", "ReportRead", "ReportStatus", + "ReportType", "Simulation", "SimulationCreate", "SimulationRead", "SimulationStatus", + "SimulationType", "TaxBenefitModel", "TaxBenefitModelCreate", "TaxBenefitModelRead", @@ -110,7 +204,23 @@ "TaxBenefitModelVersionRead", "User", "UserCreate", + "UserHouseholdAssociation", + "UserHouseholdAssociationCreate", + "UserHouseholdAssociationRead", + "UserHouseholdAssociationUpdate", "UserRead", + "UserSimulationAssociation", + "UserSimulationAssociationCreate", + "UserSimulationAssociationRead", + "UserSimulationAssociationUpdate", + "UserReportAssociation", + "UserReportAssociationCreate", + "UserReportAssociationRead", + "UserReportAssociationUpdate", + "UserPolicy", + "UserPolicyCreate", + "UserPolicyRead", + "UserPolicyUpdate", "Variable", "VariableCreate", "VariableRead", diff --git a/src/policyengine_api/models/budget_summary.py b/src/policyengine_api/models/budget_summary.py new file mode 100644 index 0000000..0a399fd --- /dev/null +++ b/src/policyengine_api/models/budget_summary.py @@ -0,0 +1,55 @@ +"""Budget summary output model. 
+ +Stores economy-wide fiscal aggregates for a report. Each row represents +a single aggregate variable (e.g. household_tax, household_benefits) +with baseline and reform totals. This is separate from ProgramStatistics, +which stores per-program breakdowns. + +The client can derive V1-compatible budget fields from these rows: + - tax_revenue_impact = household_tax row's change + - benefit_spending_impact = household_benefits row's change + - budgetary_impact = tax change - benefit change + - households = household_count_total row's baseline_total + - baseline_net_income = household_net_income row's baseline_total + - state_tax_revenue_impact = household_state_income_tax row's change (US only) +""" + +from datetime import datetime, timezone +from uuid import UUID, uuid4 + +from sqlmodel import Field, SQLModel + + +class BudgetSummaryBase(SQLModel): + """Base budget summary fields.""" + + baseline_simulation_id: UUID = Field(foreign_key="simulations.id") + reform_simulation_id: UUID = Field(foreign_key="simulations.id") + report_id: UUID | None = Field(default=None, foreign_key="reports.id") + variable_name: str + entity: str + baseline_total: float | None = None + reform_total: float | None = None + change: float | None = None + + +class BudgetSummary(BudgetSummaryBase, table=True): + """Budget summary database model.""" + + __tablename__ = "budget_summary" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +class BudgetSummaryCreate(BudgetSummaryBase): + """Schema for creating budget summary records.""" + + pass + + +class BudgetSummaryRead(BudgetSummaryBase): + """Schema for reading budget summary records.""" + + id: UUID + created_at: datetime diff --git a/src/policyengine_api/models/congressional_district_impact.py b/src/policyengine_api/models/congressional_district_impact.py new file mode 100644 index 0000000..235092a --- /dev/null +++ 
b/src/policyengine_api/models/congressional_district_impact.py @@ -0,0 +1,38 @@ +"""Congressional district impact output model.""" + +from datetime import datetime, timezone +from uuid import UUID, uuid4 + +from sqlmodel import Field + +from .geographic_impact_base import GeographicImpactBase + + +class CongressionalDistrictImpactBase(GeographicImpactBase): + """Base congressional district impact fields.""" + + district_geoid: int + state_fips: int + district_number: int + + +class CongressionalDistrictImpact(CongressionalDistrictImpactBase, table=True): + """Congressional district impact database model.""" + + __tablename__ = "congressional_district_impacts" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +class CongressionalDistrictImpactCreate(CongressionalDistrictImpactBase): + """Schema for creating congressional district impacts.""" + + pass + + +class CongressionalDistrictImpactRead(CongressionalDistrictImpactBase): + """Schema for reading congressional district impacts.""" + + id: UUID + created_at: datetime diff --git a/src/policyengine_api/models/constituency_impact.py b/src/policyengine_api/models/constituency_impact.py new file mode 100644 index 0000000..663b674 --- /dev/null +++ b/src/policyengine_api/models/constituency_impact.py @@ -0,0 +1,39 @@ +"""UK parliamentary constituency impact output model.""" + +from datetime import datetime, timezone +from uuid import UUID, uuid4 + +from sqlmodel import Field + +from .geographic_impact_base import GeographicImpactBase + + +class ConstituencyImpactBase(GeographicImpactBase): + """Base constituency impact fields.""" + + constituency_code: str + constituency_name: str + x: int + y: int + + +class ConstituencyImpact(ConstituencyImpactBase, table=True): + """Constituency impact database model.""" + + __tablename__ = "constituency_impacts" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + 
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +class ConstituencyImpactCreate(ConstituencyImpactBase): + """Schema for creating constituency impacts.""" + + pass + + +class ConstituencyImpactRead(ConstituencyImpactBase): + """Schema for reading constituency impacts.""" + + id: UUID + created_at: datetime diff --git a/src/policyengine_api/models/geographic_impact_base.py b/src/policyengine_api/models/geographic_impact_base.py new file mode 100644 index 0000000..3687a91 --- /dev/null +++ b/src/policyengine_api/models/geographic_impact_base.py @@ -0,0 +1,19 @@ +"""Shared base for geographic impact models.""" + +from uuid import UUID + +from sqlmodel import Field, SQLModel + + +class GeographicImpactBase(SQLModel): + """Shared fields for geographic impact models. + + Used by constituency, local authority, and congressional district impacts. + """ + + baseline_simulation_id: UUID = Field(foreign_key="simulations.id") + reform_simulation_id: UUID = Field(foreign_key="simulations.id") + report_id: UUID | None = Field(default=None, foreign_key="reports.id") + average_household_income_change: float + relative_household_income_change: float + population: float diff --git a/src/policyengine_api/models/household.py b/src/policyengine_api/models/household.py new file mode 100644 index 0000000..8a96850 --- /dev/null +++ b/src/policyengine_api/models/household.py @@ -0,0 +1,54 @@ +"""Stored household definition model.""" + +from datetime import datetime, timezone +from typing import Any, Literal +from uuid import UUID, uuid4 + +from sqlalchemy import JSON +from sqlmodel import Column, Field, SQLModel + + +class HouseholdBase(SQLModel): + """Base household fields.""" + + tax_benefit_model_name: str + year: int + label: str | None = None + household_data: dict[str, Any] = Field(sa_column=Column(JSON, nullable=False)) + + +class Household(HouseholdBase, table=True): + """Stored household database model.""" + + __tablename__ = "households" + + 
id: UUID = Field(default_factory=uuid4, primary_key=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +class HouseholdCreate(SQLModel): + """Schema for creating a stored household. + + Accepts the flat structure matching the frontend Household interface: + people as an array, entity groups as optional dicts. + """ + + tax_benefit_model_name: Literal["policyengine_us", "policyengine_uk"] + year: int + label: str | None = None + people: list[dict[str, Any]] + tax_unit: dict[str, Any] | None = None + family: dict[str, Any] | None = None + spm_unit: dict[str, Any] | None = None + marital_unit: dict[str, Any] | None = None + household: dict[str, Any] | None = None + benunit: dict[str, Any] | None = None + + +class HouseholdRead(HouseholdCreate): + """Schema for reading a stored household.""" + + id: UUID + created_at: datetime + updated_at: datetime diff --git a/src/policyengine_api/models/intra_decile_impact.py b/src/policyengine_api/models/intra_decile_impact.py new file mode 100644 index 0000000..8771d55 --- /dev/null +++ b/src/policyengine_api/models/intra_decile_impact.py @@ -0,0 +1,66 @@ +"""Intra-decile impact output model. + +Stores the distribution of income change categories within each income +decile. Each row represents one decile (1-10) or the overall average +(decile=0), with five proportion columns summing to ~1.0. + +The five categories classify households by their percentage income change: + - lose_more_than_5pct: change <= -5% + - lose_less_than_5pct: -5% < change <= -0.1% + - no_change: -0.1% < change <= 0.1% + - gain_less_than_5pct: 0.1% < change <= 5% + - gain_more_than_5pct: change > 5% + +Proportions are people-weighted (using household_count_people * +household_weight) so they reflect the share of people, not households. 
+""" + +from datetime import datetime, timezone +from enum import Enum +from uuid import UUID, uuid4 + +from sqlmodel import Field, SQLModel + + +class DecileType(str, Enum): + """Type of decile grouping.""" + + INCOME = "income" + WEALTH = "wealth" + + +class IntraDecileImpactBase(SQLModel): + """Base intra-decile impact fields.""" + + baseline_simulation_id: UUID = Field(foreign_key="simulations.id") + reform_simulation_id: UUID = Field(foreign_key="simulations.id") + report_id: UUID | None = Field(default=None, foreign_key="reports.id") + decile_type: DecileType = Field(default=DecileType.INCOME) + decile: int = Field(ge=0, le=10) + lose_more_than_5pct: float | None = Field(default=None, ge=0.0, le=1.0) + lose_less_than_5pct: float | None = Field(default=None, ge=0.0, le=1.0) + no_change: float | None = Field(default=None, ge=0.0, le=1.0) + gain_less_than_5pct: float | None = Field(default=None, ge=0.0, le=1.0) + gain_more_than_5pct: float | None = Field(default=None, ge=0.0, le=1.0) + + +class IntraDecileImpact(IntraDecileImpactBase, table=True): + """Intra-decile impact database model.""" + + __tablename__ = "intra_decile_impacts" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +class IntraDecileImpactCreate(IntraDecileImpactBase): + """Schema for creating intra-decile impact records.""" + + pass + + +class IntraDecileImpactRead(IntraDecileImpactBase): + """Schema for reading intra-decile impact records.""" + + id: UUID + created_at: datetime diff --git a/src/policyengine_api/models/local_authority_impact.py b/src/policyengine_api/models/local_authority_impact.py new file mode 100644 index 0000000..19effa1 --- /dev/null +++ b/src/policyengine_api/models/local_authority_impact.py @@ -0,0 +1,39 @@ +"""UK local authority impact output model.""" + +from datetime import datetime, timezone +from uuid import UUID, uuid4 + +from sqlmodel import Field + +from 
.geographic_impact_base import GeographicImpactBase + + +class LocalAuthorityImpactBase(GeographicImpactBase): + """Base local authority impact fields.""" + + local_authority_code: str + local_authority_name: str + x: int + y: int + + +class LocalAuthorityImpact(LocalAuthorityImpactBase, table=True): + """Local authority impact database model.""" + + __tablename__ = "local_authority_impacts" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +class LocalAuthorityImpactCreate(LocalAuthorityImpactBase): + """Schema for creating local authority impacts.""" + + pass + + +class LocalAuthorityImpactRead(LocalAuthorityImpactBase): + """Schema for reading local authority impacts.""" + + id: UUID + created_at: datetime diff --git a/src/policyengine_api/models/parameter_value.py b/src/policyengine_api/models/parameter_value.py index 8898796..977bde1 100644 --- a/src/policyengine_api/models/parameter_value.py +++ b/src/policyengine_api/models/parameter_value.py @@ -48,3 +48,9 @@ class ParameterValueRead(ParameterValueBase): id: UUID created_at: datetime + + +class ParameterValueWithName(ParameterValueRead): + """Parameter value with the parameter's dotted name included.""" + + parameter_name: str diff --git a/src/policyengine_api/models/policy.py b/src/policyengine_api/models/policy.py index 570320b..4e76206 100644 --- a/src/policyengine_api/models/policy.py +++ b/src/policyengine_api/models/policy.py @@ -1,11 +1,14 @@ from datetime import datetime, timezone -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from uuid import UUID, uuid4 from sqlmodel import Field, Relationship, SQLModel +from .parameter_value import ParameterValueWithName + if TYPE_CHECKING: from .parameter_value import ParameterValue + from .tax_benefit_model import TaxBenefitModel class PolicyBase(SQLModel): @@ -13,6 +16,7 @@ class PolicyBase(SQLModel): name: str description: str | None = 
None + tax_benefit_model_id: UUID = Field(foreign_key="tax_benefit_models.id", index=True) class Policy(PolicyBase, table=True): @@ -26,6 +30,16 @@ class Policy(PolicyBase, table=True): # Relationships parameter_values: list["ParameterValue"] = Relationship(back_populates="policy") + tax_benefit_model: "TaxBenefitModel" = Relationship() + + +class PolicyParameterValueInput(SQLModel): + """Input schema for a parameter value when creating a policy.""" + + parameter_id: UUID + value_json: Any + start_date: datetime + end_date: datetime | None = None class PolicyCreate(PolicyBase): @@ -51,7 +65,7 @@ class PolicyCreate(PolicyBase): } """ - parameter_values: list[dict] = [] + parameter_values: list[PolicyParameterValueInput] = [] class PolicyRead(PolicyBase): @@ -60,3 +74,4 @@ class PolicyRead(PolicyBase): id: UUID created_at: datetime updated_at: datetime + parameter_values: list[ParameterValueWithName] = [] diff --git a/src/policyengine_api/models/region.py b/src/policyengine_api/models/region.py new file mode 100644 index 0000000..29c2785 --- /dev/null +++ b/src/policyengine_api/models/region.py @@ -0,0 +1,84 @@ +"""Region model for geographic areas used in analysis.""" + +from datetime import datetime, timezone +from enum import Enum +from typing import TYPE_CHECKING +from uuid import UUID, uuid4 + +from pydantic import model_validator +from sqlmodel import Field, Relationship, SQLModel + +from .region_dataset_link import RegionDatasetLink + +if TYPE_CHECKING: + from .dataset import Dataset + from .tax_benefit_model import TaxBenefitModel + + +class RegionType(str, Enum): + """Type of geographic region.""" + + NATIONAL = "national" + COUNTRY = "country" + STATE = "state" + CONGRESSIONAL_DISTRICT = "congressional_district" + CONSTITUENCY = "constituency" + LOCAL_AUTHORITY = "local_authority" + CITY = "city" + PLACE = "place" + + +class RegionBase(SQLModel): + """Base region fields.""" + + code: str # e.g., "state/ca", "constituency/Sheffield Central" + label: str # 
e.g., "California", "Sheffield Central" + region_type: RegionType # e.g., RegionType.STATE, RegionType.CONSTITUENCY + requires_filter: bool = False + filter_field: str | None = None # e.g., "state_code", "place_fips" + filter_value: str | None = None # e.g., "CA", "44000" + parent_code: str | None = None # e.g., "us", "state/ca" + state_code: str | None = None # For US regions + state_name: str | None = None # For US regions + tax_benefit_model_id: UUID = Field(foreign_key="tax_benefit_models.id") + + +class Region(RegionBase, table=True): + """Region database model. + + Regions represent geographic areas for analysis, from countries + down to states, congressional districts, cities, etc. + Each region links to multiple datasets (one per year) via the + region_datasets join table. + """ + + __tablename__ = "regions" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + # Relationships + datasets: list["Dataset"] = Relationship(link_model=RegionDatasetLink) + tax_benefit_model: "TaxBenefitModel" = Relationship() + + +class RegionCreate(RegionBase): + """Schema for creating regions.""" + + @model_validator(mode="after") + def check_filter_fields(self) -> "RegionCreate": + if self.requires_filter: + if not self.filter_field or not self.filter_value: + raise ValueError( + "requires_filter=True requires filter_field and filter_value" + ) + return self + + +class RegionRead(RegionBase): + """Schema for reading regions.""" + + id: UUID + created_at: datetime + updated_at: datetime diff --git a/src/policyengine_api/models/region_dataset_link.py b/src/policyengine_api/models/region_dataset_link.py new file mode 100644 index 0000000..9801cfc --- /dev/null +++ b/src/policyengine_api/models/region_dataset_link.py @@ -0,0 +1,19 @@ +"""Link table for many-to-many relationship between regions and 
datasets.""" + +from uuid import UUID + +from sqlmodel import Field, SQLModel + + +class RegionDatasetLink(SQLModel, table=True): + """Join table linking regions to their available datasets. + + Each region can have multiple datasets (one per year), and each + dataset can be shared across multiple regions (e.g., a state dataset + used by both the state region and its place/city regions). + """ + + __tablename__ = "region_datasets" + + region_id: UUID = Field(foreign_key="regions.id", primary_key=True) + dataset_id: UUID = Field(foreign_key="datasets.id", primary_key=True) diff --git a/src/policyengine_api/models/report.py b/src/policyengine_api/models/report.py index ee1b678..b034dcb 100644 --- a/src/policyengine_api/models/report.py +++ b/src/policyengine_api/models/report.py @@ -14,14 +14,22 @@ class ReportStatus(str, Enum): FAILED = "failed" +class ReportType(str, Enum): + """Type of analysis report.""" + + ECONOMY_COMPARISON = "economy_comparison" + HOUSEHOLD_COMPARISON = "household_comparison" + HOUSEHOLD_SINGLE = "household_single" + + class ReportBase(SQLModel): """Base report fields.""" label: str description: str | None = None + report_type: ReportType | None = None user_id: UUID | None = Field(default=None, foreign_key="users.id") markdown: str | None = Field(default=None, sa_column=Column(Text)) - parent_report_id: UUID | None = Field(default=None, foreign_key="reports.id") status: ReportStatus = ReportStatus.PENDING error_message: str | None = None baseline_simulation_id: UUID | None = Field( @@ -41,10 +49,18 @@ class Report(ReportBase, table=True): created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) -class ReportCreate(ReportBase): - """Schema for creating reports.""" +class ReportCreate(SQLModel): + """Schema for creating reports — client-settable fields only. - pass + Excludes server-controlled fields: status, error_message, markdown. 
+ """ + + label: str + description: str | None = None + report_type: ReportType | None = None + user_id: UUID | None = None + baseline_simulation_id: UUID | None = None + reform_simulation_id: UUID | None = None class ReportRead(ReportBase): diff --git a/src/policyengine_api/models/simulation.py b/src/policyengine_api/models/simulation.py index b23141e..132a560 100644 --- a/src/policyengine_api/models/simulation.py +++ b/src/policyengine_api/models/simulation.py @@ -1,14 +1,19 @@ from datetime import datetime, timezone from enum import Enum -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from uuid import UUID, uuid4 +from pydantic import model_validator +from sqlalchemy import Column +from sqlalchemy.dialects.postgresql import JSON from sqlmodel import Field, Relationship, SQLModel if TYPE_CHECKING: from .dataset import Dataset from .dynamic import Dynamic + from .household import Household from .policy import Policy + from .region import Region from .tax_benefit_model_version import TaxBenefitModelVersion @@ -21,10 +26,19 @@ class SimulationStatus(str, Enum): FAILED = "failed" +class SimulationType(str, Enum): + """Type of simulation.""" + + HOUSEHOLD = "household" + ECONOMY = "economy" + + class SimulationBase(SQLModel): """Base simulation fields.""" - dataset_id: UUID = Field(foreign_key="datasets.id") + simulation_type: SimulationType = SimulationType.ECONOMY + dataset_id: UUID | None = Field(default=None, foreign_key="datasets.id") + household_id: UUID | None = Field(default=None, foreign_key="households.id") policy_id: UUID | None = Field(default=None, foreign_key="policies.id") dynamic_id: UUID | None = Field(default=None, foreign_key="dynamics.id") tax_benefit_model_version_id: UUID = Field( @@ -34,6 +48,21 @@ class SimulationBase(SQLModel): status: SimulationStatus = SimulationStatus.PENDING error_message: str | None = None + # Region provenance (which region this simulation targets) + region_id: UUID | None = Field(default=None, 
foreign_key="regions.id") + + # Regional filtering parameters (passed to policyengine.py) + filter_field: str | None = Field( + default=None, + description="Household-level variable to filter dataset by (e.g., 'place_fips', 'country')", + ) + filter_value: str | None = Field( + default=None, + description="Value to match when filtering (e.g., '44000', 'ENGLAND')", + ) + + year: int | None = None + class Simulation(SimulationBase, table=True): """Simulation database model.""" @@ -45,6 +74,9 @@ class Simulation(SimulationBase, table=True): updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) started_at: datetime | None = None completed_at: datetime | None = None + household_result: dict[str, Any] | None = Field( + default=None, sa_column=Column(JSON) + ) # Relationships dataset: "Dataset" = Relationship( @@ -53,7 +85,14 @@ class Simulation(SimulationBase, table=True): "primaryjoin": "Simulation.dataset_id==Dataset.id", } ) + household: "Household" = Relationship( + sa_relationship_kwargs={ + "foreign_keys": "[Simulation.household_id]", + "primaryjoin": "Simulation.household_id==Household.id", + } + ) policy: "Policy" = Relationship() + region: "Region" = Relationship() dynamic: "Dynamic" = Relationship() tax_benefit_model_version: "TaxBenefitModelVersion" = Relationship() output_dataset: "Dataset" = Relationship( @@ -64,10 +103,40 @@ class Simulation(SimulationBase, table=True): ) -class SimulationCreate(SimulationBase): - """Schema for creating simulations.""" - - pass +class SimulationCreate(SQLModel): + """Schema for creating simulations — client-settable fields only. + + Excludes server-controlled fields: status, error_message, output_dataset_id. 
+ """ + + simulation_type: SimulationType = SimulationType.ECONOMY + dataset_id: UUID | None = None + household_id: UUID | None = None + policy_id: UUID | None = None + dynamic_id: UUID | None = None + tax_benefit_model_version_id: UUID + region_id: UUID | None = None + filter_field: str | None = None + filter_value: str | None = None + year: int | None = None + + @model_validator(mode="after") + def check_type_consistency(self) -> "SimulationCreate": + if self.simulation_type == SimulationType.HOUSEHOLD: + if not self.household_id: + raise ValueError("HOUSEHOLD simulation requires household_id") + if self.dataset_id: + raise ValueError("HOUSEHOLD simulation cannot have dataset_id") + elif self.simulation_type == SimulationType.ECONOMY: + if not self.dataset_id: + raise ValueError("ECONOMY simulation requires dataset_id") + if self.household_id: + raise ValueError("ECONOMY simulation cannot have household_id") + if (self.filter_field is None) != (self.filter_value is None): + raise ValueError( + "filter_field and filter_value must both be set or both None" + ) + return self class SimulationRead(SimulationBase): @@ -78,3 +147,4 @@ class SimulationRead(SimulationBase): updated_at: datetime started_at: datetime | None completed_at: datetime | None + household_result: dict[str, Any] | None = None diff --git a/src/policyengine_api/models/user_household_association.py b/src/policyengine_api/models/user_household_association.py new file mode 100644 index 0000000..d2de63d --- /dev/null +++ b/src/policyengine_api/models/user_household_association.py @@ -0,0 +1,53 @@ +"""User-household association model.""" + +from datetime import datetime, timezone +from uuid import UUID, uuid4 + +from sqlmodel import Field, SQLModel + +from policyengine_api.config.constants import CountryId + + +class UserHouseholdAssociationBase(SQLModel): + """Base association fields.""" + + # user_id is a client-generated UUID stored in localStorage, not a foreign key + user_id: UUID = Field(index=True) 
+ household_id: UUID = Field(foreign_key="households.id", index=True) + country_id: str # Stored as string in DB, validated via Pydantic in Create schema + label: str | None = None + + +class UserHouseholdAssociation(UserHouseholdAssociationBase, table=True): + """User-household association database model.""" + + __tablename__ = "user_household_associations" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +class UserHouseholdAssociationCreate(SQLModel): + """Schema for creating a user-household association.""" + + user_id: UUID + household_id: UUID + country_id: CountryId + label: str | None = None + + +class UserHouseholdAssociationUpdate(SQLModel): + """Schema for updating a user-household association.""" + + model_config = {"extra": "forbid"} + + label: str | None = None + + +class UserHouseholdAssociationRead(UserHouseholdAssociationBase): + """Schema for reading a user-household association.""" + + id: UUID + created_at: datetime + updated_at: datetime diff --git a/src/policyengine_api/models/user_policy.py b/src/policyengine_api/models/user_policy.py new file mode 100644 index 0000000..a9a86b6 --- /dev/null +++ b/src/policyengine_api/models/user_policy.py @@ -0,0 +1,64 @@ +from datetime import datetime, timezone +from typing import TYPE_CHECKING +from uuid import UUID, uuid4 + +from sqlmodel import Field, Relationship, SQLModel + +from policyengine_api.config.constants import CountryId + +if TYPE_CHECKING: + from .policy import Policy + + +class UserPolicyBase(SQLModel): + """Base user-policy association fields.""" + + # user_id is a client-generated UUID stored in localStorage, not a foreign key. + # This allows anonymous users to save policies without requiring authentication. 
+ # The UUID is generated once per browser via crypto.randomUUID() and persisted + # in localStorage for stable identity across sessions. + user_id: UUID = Field(index=True) + policy_id: UUID = Field(foreign_key="policies.id", index=True) + country_id: str # Stored as string in DB, validated via Pydantic in Create schema + label: str | None = None + + +class UserPolicy(UserPolicyBase, table=True): + """User-policy association database model.""" + + __tablename__ = "user_policies" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + # Relationships + policy: "Policy" = Relationship() + + +class UserPolicyCreate(SQLModel): + """Schema for creating user-policy associations. + + Uses CountryId Literal type for validation of country_id. + """ + + user_id: UUID + policy_id: UUID + country_id: CountryId # Validated to "us" or "uk" + label: str | None = None + + +class UserPolicyRead(UserPolicyBase): + """Schema for reading user-policy associations.""" + + id: UUID + created_at: datetime + updated_at: datetime + + +class UserPolicyUpdate(SQLModel): + """Schema for updating user-policy associations.""" + + model_config = {"extra": "forbid"} + + label: str | None = None diff --git a/src/policyengine_api/models/user_report_association.py b/src/policyengine_api/models/user_report_association.py new file mode 100644 index 0000000..4f078cb --- /dev/null +++ b/src/policyengine_api/models/user_report_association.py @@ -0,0 +1,63 @@ +"""User-report association model. + +Associates users with reports they've created. This enables users to +maintain a list of their reports across sessions. + +Note: user_id is a client-generated UUID (via crypto.randomUUID()) stored in +the browser's localStorage. It is NOT validated against a users table, allowing +anonymous users to save reports without authentication. 
+""" + +from datetime import datetime, timezone +from uuid import UUID, uuid4 + +from sqlmodel import Field, SQLModel + +from policyengine_api.config.constants import CountryId + + +class UserReportAssociationBase(SQLModel): + """Base association fields.""" + + user_id: UUID = Field(index=True) + report_id: UUID = Field(foreign_key="reports.id", index=True) + country_id: str + label: str | None = None + last_run_at: datetime | None = None + + +class UserReportAssociation(UserReportAssociationBase, table=True): + """User-report association database model.""" + + __tablename__ = "user_report_associations" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +class UserReportAssociationCreate(SQLModel): + """Schema for creating user-report associations.""" + + user_id: UUID + report_id: UUID + country_id: CountryId + label: str | None = None + last_run_at: datetime | None = None + + +class UserReportAssociationRead(UserReportAssociationBase): + """Schema for reading user-report associations.""" + + id: UUID + created_at: datetime + updated_at: datetime + + +class UserReportAssociationUpdate(SQLModel): + """Schema for updating user-report associations.""" + + model_config = {"extra": "forbid"} + + label: str | None = None + last_run_at: datetime | None = None diff --git a/src/policyengine_api/models/user_simulation_association.py b/src/policyengine_api/models/user_simulation_association.py new file mode 100644 index 0000000..9b07d19 --- /dev/null +++ b/src/policyengine_api/models/user_simulation_association.py @@ -0,0 +1,60 @@ +"""User-simulation association model. + +Associates users with simulations they've run. This enables users to +maintain a list of their simulations across sessions. 
+ +Note: user_id is a client-generated UUID (via crypto.randomUUID()) stored in +the browser's localStorage. It is NOT validated against a users table, allowing +anonymous users to save simulations without authentication. +""" + +from datetime import datetime, timezone +from uuid import UUID, uuid4 + +from sqlmodel import Field, SQLModel + +from policyengine_api.config.constants import CountryId + + +class UserSimulationAssociationBase(SQLModel): + """Base association fields.""" + + user_id: UUID = Field(index=True) + simulation_id: UUID = Field(foreign_key="simulations.id", index=True) + country_id: str + label: str | None = None + + +class UserSimulationAssociation(UserSimulationAssociationBase, table=True): + """User-simulation association database model.""" + + __tablename__ = "user_simulation_associations" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +class UserSimulationAssociationCreate(SQLModel): + """Schema for creating user-simulation associations.""" + + user_id: UUID + simulation_id: UUID + country_id: CountryId + label: str | None = None + + +class UserSimulationAssociationRead(UserSimulationAssociationBase): + """Schema for reading user-simulation associations.""" + + id: UUID + created_at: datetime + updated_at: datetime + + +class UserSimulationAssociationUpdate(SQLModel): + """Schema for updating user-simulation associations.""" + + model_config = {"extra": "forbid"} + + label: str | None = None diff --git a/src/policyengine_api/models/variable.py b/src/policyengine_api/models/variable.py index f163577..b147bcf 100644 --- a/src/policyengine_api/models/variable.py +++ b/src/policyengine_api/models/variable.py @@ -1,5 +1,5 @@ from datetime import datetime, timezone -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from uuid import UUID, uuid4 
from sqlmodel import JSON, Column, Field, Relationship, SQLModel @@ -15,9 +15,12 @@ class VariableBase(SQLModel): entity: str description: str | None = None data_type: str | None = None # Store as string representation - possible_values: str | None = Field( + possible_values: list[str] | None = Field( default=None, sa_column=Column(JSON) ) # Store as JSON list + default_value: Any = Field( + default=None, sa_column=Column(JSON) + ) # Store as JSON (handles int, float, bool, str, etc.) tax_benefit_model_version_id: UUID = Field( foreign_key="tax_benefit_model_versions.id" ) diff --git a/src/policyengine_api/services/storage.py b/src/policyengine_api/services/storage.py index 233e11e..4a690d0 100644 --- a/src/policyengine_api/services/storage.py +++ b/src/policyengine_api/services/storage.py @@ -9,6 +9,11 @@ CACHE_DIR = Path("/tmp/policyengine_dataset_cache") +def output_filepath(simulation_id: str) -> str: + """Build the storage path for a simulation output dataset.""" + return f"outputs/output_{simulation_id}.h5" + + def get_supabase_client() -> Client: """Get Supabase client.""" return create_client(settings.supabase_url, settings.supabase_key) diff --git a/supabase/.temp/cli-latest b/supabase/.temp/cli-latest index 8c68db7..1dd6178 100644 --- a/supabase/.temp/cli-latest +++ b/supabase/.temp/cli-latest @@ -1 +1 @@ -v2.67.1 \ No newline at end of file +v2.75.0 \ No newline at end of file diff --git a/supabase/migrations/20251229000000_add_parameter_values_indexes.sql b/supabase/migrations/20251229000000_add_parameter_values_indexes.sql deleted file mode 100644 index c1713d5..0000000 --- a/supabase/migrations/20251229000000_add_parameter_values_indexes.sql +++ /dev/null @@ -1,16 +0,0 @@ --- Add indexes to parameter_values table for query optimization --- This migration improves query performance for filtering by parameter_id and policy_id - --- Composite index for the most common query pattern (filtering by both) -CREATE INDEX IF NOT EXISTS 
idx_parameter_values_parameter_policy -ON parameter_values(parameter_id, policy_id); - --- Single index on policy_id for filtering by policy alone -CREATE INDEX IF NOT EXISTS idx_parameter_values_policy -ON parameter_values(policy_id); - --- Partial index for baseline values (policy_id IS NULL) --- This optimizes the common "get current law values" query -CREATE INDEX IF NOT EXISTS idx_parameter_values_baseline -ON parameter_values(parameter_id) -WHERE policy_id IS NULL; diff --git a/supabase/migrations/20260103000000_add_poverty_inequality.sql b/supabase/migrations/20260103000000_add_poverty_inequality.sql deleted file mode 100644 index f315d93..0000000 --- a/supabase/migrations/20260103000000_add_poverty_inequality.sql +++ /dev/null @@ -1,33 +0,0 @@ --- Add poverty and inequality tables for economic analysis - -CREATE TABLE IF NOT EXISTS poverty ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - simulation_id UUID NOT NULL REFERENCES simulations(id) ON DELETE CASCADE, - report_id UUID REFERENCES reports(id) ON DELETE CASCADE, - poverty_type VARCHAR NOT NULL, - entity VARCHAR NOT NULL DEFAULT 'person', - filter_variable VARCHAR, - headcount FLOAT, - total_population FLOAT, - rate FLOAT, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); - -CREATE TABLE IF NOT EXISTS inequality ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - simulation_id UUID NOT NULL REFERENCES simulations(id) ON DELETE CASCADE, - report_id UUID REFERENCES reports(id) ON DELETE CASCADE, - income_variable VARCHAR NOT NULL, - entity VARCHAR NOT NULL DEFAULT 'household', - gini FLOAT, - top_10_share FLOAT, - top_1_share FLOAT, - bottom_50_share FLOAT, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); - --- Indexes for efficient querying -CREATE INDEX IF NOT EXISTS idx_poverty_simulation_id ON poverty(simulation_id); -CREATE INDEX IF NOT EXISTS idx_poverty_report_id ON poverty(report_id); -CREATE INDEX IF NOT EXISTS idx_inequality_simulation_id ON inequality(simulation_id); -CREATE INDEX IF 
NOT EXISTS idx_inequality_report_id ON inequality(report_id); diff --git a/supabase/migrations/20260111000000_add_aggregate_status.sql b/supabase/migrations/20260111000000_add_aggregate_status.sql deleted file mode 100644 index b190620..0000000 --- a/supabase/migrations/20260111000000_add_aggregate_status.sql +++ /dev/null @@ -1,13 +0,0 @@ --- Add status and error_message columns to aggregates table -ALTER TABLE aggregates -ADD COLUMN IF NOT EXISTS status VARCHAR(20) DEFAULT 'pending', -ADD COLUMN IF NOT EXISTS error_message TEXT; - --- Add status and error_message columns to change_aggregates table -ALTER TABLE change_aggregates -ADD COLUMN IF NOT EXISTS status VARCHAR(20) DEFAULT 'pending', -ADD COLUMN IF NOT EXISTS error_message TEXT; - --- Create indexes for status filtering -CREATE INDEX IF NOT EXISTS idx_aggregates_status ON aggregates(status); -CREATE INDEX IF NOT EXISTS idx_change_aggregates_status ON change_aggregates(status); diff --git a/test_fixtures/fixtures_economic_impact_response.py b/test_fixtures/fixtures_economic_impact_response.py new file mode 100644 index 0000000..51da689 --- /dev/null +++ b/test_fixtures/fixtures_economic_impact_response.py @@ -0,0 +1,490 @@ +"""Fixtures for economic impact response tests. + +Provides factory functions to create completed reports with all output +table records (poverty, inequality, budget_summary, intra_decile, +program_statistics, decile_impacts) for testing _build_response(). 
+""" + + +from sqlmodel import Session + +from policyengine_api.models import ( + BudgetSummary, + CongressionalDistrictImpact, + ConstituencyImpact, + Dataset, + DecileImpact, + Inequality, + IntraDecileImpact, + LocalAuthorityImpact, + Poverty, + ProgramStatistics, + Report, + ReportStatus, + Simulation, + SimulationStatus, + TaxBenefitModel, + TaxBenefitModelVersion, +) + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +UK_PROGRAMS = { + "income_tax": {"entity": "person", "is_tax": True}, + "national_insurance": {"entity": "person", "is_tax": True}, + "vat": {"entity": "household", "is_tax": True}, + "council_tax": {"entity": "household", "is_tax": True}, + "universal_credit": {"entity": "person", "is_tax": False}, + "child_benefit": {"entity": "person", "is_tax": False}, + "pension_credit": {"entity": "person", "is_tax": False}, + "income_support": {"entity": "person", "is_tax": False}, + "working_tax_credit": {"entity": "person", "is_tax": False}, + "child_tax_credit": {"entity": "person", "is_tax": False}, +} + +UK_PROGRAM_COUNT = len(UK_PROGRAMS) + +BUDGET_VARIABLES_UK = [ + ("household_tax", "household"), + ("household_benefits", "household"), + ("household_net_income", "household"), + ("household_count_total", "household"), +] + +SAMPLE_POVERTY_TYPES = ["absolute_bhc", "absolute_ahc"] +SAMPLE_INEQUALITY_INCOME_VAR = "household_net_income" +SAMPLE_GINI = 0.35 +SAMPLE_TOP_10_SHARE = 0.28 +SAMPLE_TOP_1_SHARE = 0.10 +SAMPLE_BOTTOM_50_SHARE = 0.22 + +INTRA_DECILE_DECILE_COUNT = 11 # 10 deciles + overall + + +# --------------------------------------------------------------------------- +# Core factory: report with simulations +# --------------------------------------------------------------------------- + + +def create_report_with_simulations( + session: Session, + status: ReportStatus = ReportStatus.COMPLETED, +) -> tuple[Report, 
Simulation, Simulation]: + """Create a model, version, dataset, two simulations, and a report.""" + model = TaxBenefitModel(name="policyengine-uk", description="UK model") + session.add(model) + session.commit() + session.refresh(model) + + version = TaxBenefitModelVersion( + model_id=model.id, version="1.0.0", description="Test" + ) + session.add(version) + session.commit() + session.refresh(version) + + dataset = Dataset( + name="uk_test", + description="Test dataset", + filepath="test.h5", + year=2024, + tax_benefit_model_id=model.id, + ) + session.add(dataset) + session.commit() + session.refresh(dataset) + + baseline_sim = Simulation( + dataset_id=dataset.id, + tax_benefit_model_version_id=version.id, + status=SimulationStatus.COMPLETED, + ) + reform_sim = Simulation( + dataset_id=dataset.id, + tax_benefit_model_version_id=version.id, + status=SimulationStatus.COMPLETED, + ) + session.add(baseline_sim) + session.add(reform_sim) + session.commit() + session.refresh(baseline_sim) + session.refresh(reform_sim) + + report = Report( + label="Test economic impact report", + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + status=status, + report_type="economy_comparison", + ) + session.add(report) + session.commit() + session.refresh(report) + + return report, baseline_sim, reform_sim + + +# --------------------------------------------------------------------------- +# Output record factories +# --------------------------------------------------------------------------- + + +def add_poverty_records( + session: Session, + report: Report, + baseline_sim: Simulation, + reform_sim: Simulation, + count: int = 4, +) -> list[Poverty]: + """Add poverty records to a report (for baseline and reform).""" + records = [] + for sim in [baseline_sim, reform_sim]: + for i, ptype in enumerate(SAMPLE_POVERTY_TYPES): + rec = Poverty( + simulation_id=sim.id, + report_id=report.id, + poverty_type=ptype, + entity="person", + filter_variable=None, + 
headcount=float(1000 + i * 100), + total_population=10000.0, + rate=float(1000 + i * 100) / 10000.0, + ) + session.add(rec) + records.append(rec) + session.commit() + return records + + +def add_poverty_by_age_records( + session: Session, + report: Report, + baseline_sim: Simulation, + reform_sim: Simulation, +) -> list[Poverty]: + """Add poverty-by-age records with filter_variable set.""" + records = [] + age_groups = [ + ("is_child", True), + ("is_adult", True), + ("is_SP_age", True), + ] + for sim in [baseline_sim, reform_sim]: + for filter_var, _ in age_groups: + for ptype in SAMPLE_POVERTY_TYPES: + rec = Poverty( + simulation_id=sim.id, + report_id=report.id, + poverty_type=ptype, + entity="person", + filter_variable=filter_var, + headcount=500.0, + total_population=3000.0, + rate=500.0 / 3000.0, + ) + session.add(rec) + records.append(rec) + session.commit() + return records + + +def add_inequality_records( + session: Session, + report: Report, + baseline_sim: Simulation, + reform_sim: Simulation, +) -> list[Inequality]: + """Add inequality records for baseline and reform.""" + records = [] + for sim in [baseline_sim, reform_sim]: + rec = Inequality( + simulation_id=sim.id, + report_id=report.id, + income_variable=SAMPLE_INEQUALITY_INCOME_VAR, + entity="household", + gini=SAMPLE_GINI, + top_10_share=SAMPLE_TOP_10_SHARE, + top_1_share=SAMPLE_TOP_1_SHARE, + bottom_50_share=SAMPLE_BOTTOM_50_SHARE, + ) + session.add(rec) + records.append(rec) + session.commit() + return records + + +def add_budget_summary_records( + session: Session, + report: Report, + baseline_sim: Simulation, + reform_sim: Simulation, +) -> list[BudgetSummary]: + """Add budget summary records for UK variables.""" + records = [] + for var_name, entity in BUDGET_VARIABLES_UK: + rec = BudgetSummary( + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + variable_name=var_name, + entity=entity, + baseline_total=1_000_000.0, + 
reform_total=1_050_000.0, + change=50_000.0, + ) + session.add(rec) + records.append(rec) + session.commit() + return records + + +def add_intra_decile_records( + session: Session, + report: Report, + baseline_sim: Simulation, + reform_sim: Simulation, +) -> list[IntraDecileImpact]: + """Add 11 intra-decile impact records (deciles 1-10 + overall).""" + records = [] + for decile_num in list(range(1, 11)) + [0]: + rec = IntraDecileImpact( + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + decile=decile_num, + lose_more_than_5pct=0.0, + lose_less_than_5pct=0.0, + no_change=0.0, + gain_less_than_5pct=1.0, + gain_more_than_5pct=0.0, + ) + session.add(rec) + records.append(rec) + session.commit() + return records + + +def add_program_statistics_records( + session: Session, + report: Report, + baseline_sim: Simulation, + reform_sim: Simulation, + programs: dict | None = None, +) -> list[ProgramStatistics]: + """Add program statistics records. 
Defaults to full UK program list.""" + if programs is None: + programs = UK_PROGRAMS + records = [] + for prog_name, prog_info in programs.items(): + rec = ProgramStatistics( + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + program_name=prog_name, + entity=prog_info["entity"], + is_tax=prog_info["is_tax"], + baseline_total=500_000.0, + reform_total=520_000.0, + change=20_000.0, + baseline_count=10_000.0, + reform_count=10_000.0, + winners=3_000.0, + losers=2_000.0, + ) + session.add(rec) + records.append(rec) + session.commit() + return records + + +def add_congressional_district_records( + session: Session, + report: Report, + baseline_sim: Simulation, + reform_sim: Simulation, +) -> list[CongressionalDistrictImpact]: + """Add congressional district impact records.""" + records = [] + districts = [ + {"district_geoid": 101, "state_fips": 1, "district_number": 1}, + {"district_geoid": 602, "state_fips": 6, "district_number": 2}, + ] + for d in districts: + rec = CongressionalDistrictImpact( + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + district_geoid=d["district_geoid"], + state_fips=d["state_fips"], + district_number=d["district_number"], + average_household_income_change=500.0, + relative_household_income_change=0.01, + population=100000.0, + ) + session.add(rec) + records.append(rec) + session.commit() + return records + + +SAMPLE_DISTRICT_COUNT = 2 + + +def add_constituency_records( + session: Session, + report: Report, + baseline_sim: Simulation, + reform_sim: Simulation, +) -> list[ConstituencyImpact]: + """Add UK constituency impact records.""" + records = [] + constituencies = [ + {"code": "E14000530", "name": "Birmingham, Ladywood", "x": 410, "y": 290}, + { + "code": "E14000639", + "name": "Cities of London and Westminster", + "x": 530, + "y": 180, + }, + ] + for c in constituencies: + rec = ConstituencyImpact( + 
baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + constituency_code=c["code"], + constituency_name=c["name"], + x=c["x"], + y=c["y"], + average_household_income_change=300.0, + relative_household_income_change=0.008, + population=80000.0, + ) + session.add(rec) + records.append(rec) + session.commit() + return records + + +SAMPLE_CONSTITUENCY_COUNT = 2 + + +def add_local_authority_records( + session: Session, + report: Report, + baseline_sim: Simulation, + reform_sim: Simulation, +) -> list[LocalAuthorityImpact]: + """Add UK local authority impact records.""" + records = [] + las = [ + {"code": "E09000001", "name": "City of London", "x": 532, "y": 181}, + {"code": "E09000002", "name": "Barking and Dagenham", "x": 549, "y": 186}, + ] + for la in las: + rec = LocalAuthorityImpact( + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + local_authority_code=la["code"], + local_authority_name=la["name"], + x=la["x"], + y=la["y"], + average_household_income_change=400.0, + relative_household_income_change=0.012, + population=50000.0, + ) + session.add(rec) + records.append(rec) + session.commit() + return records + + +SAMPLE_LA_COUNT = 2 + + +def add_wealth_decile_records( + session: Session, + report: Report, + baseline_sim: Simulation, + reform_sim: Simulation, +) -> list[DecileImpact]: + """Add 10 wealth decile impact records (income_variable=household_wealth_decile).""" + records = [] + for decile_num in range(1, 11): + rec = DecileImpact( + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + income_variable="household_wealth_decile", + entity="household", + decile=decile_num, + quantiles=10, + baseline_mean=float(10000 * decile_num), + reform_mean=float(10000 * decile_num + 500), + absolute_change=500.0, + relative_change=500.0 / (10000 * decile_num), + ) + session.add(rec) + records.append(rec) + session.commit() 
+ return records + + +SAMPLE_WEALTH_DECILE_COUNT = 10 + + +def add_intra_wealth_decile_records( + session: Session, + report: Report, + baseline_sim: Simulation, + reform_sim: Simulation, +) -> list[IntraDecileImpact]: + """Add 11 intra-wealth-decile records (decile_type='wealth').""" + records = [] + for decile_num in list(range(1, 11)) + [0]: + rec = IntraDecileImpact( + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + decile_type="wealth", + decile=decile_num, + lose_more_than_5pct=0.0, + lose_less_than_5pct=0.1, + no_change=0.5, + gain_less_than_5pct=0.3, + gain_more_than_5pct=0.1, + ) + session.add(rec) + records.append(rec) + session.commit() + return records + + +SAMPLE_INTRA_WEALTH_DECILE_COUNT = 11 + + +# --------------------------------------------------------------------------- +# Composite: fully populated report +# --------------------------------------------------------------------------- + + +def create_fully_populated_report( + session: Session, +) -> tuple[Report, Simulation, Simulation]: + """Create a completed report with records in ALL output tables.""" + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_poverty_records(session, report, baseline_sim, reform_sim) + add_poverty_by_age_records(session, report, baseline_sim, reform_sim) + add_inequality_records(session, report, baseline_sim, reform_sim) + add_budget_summary_records(session, report, baseline_sim, reform_sim) + add_intra_decile_records(session, report, baseline_sim, reform_sim) + add_program_statistics_records(session, report, baseline_sim, reform_sim) + add_congressional_district_records(session, report, baseline_sim, reform_sim) + add_constituency_records(session, report, baseline_sim, reform_sim) + add_local_authority_records(session, report, baseline_sim, reform_sim) + add_wealth_decile_records(session, report, baseline_sim, reform_sim) + add_intra_wealth_decile_records(session, report, 
baseline_sim, reform_sim) + return report, baseline_sim, reform_sim diff --git a/test_fixtures/fixtures_household_analysis.py b/test_fixtures/fixtures_household_analysis.py new file mode 100644 index 0000000..85401af --- /dev/null +++ b/test_fixtures/fixtures_household_analysis.py @@ -0,0 +1,365 @@ +"""Fixtures and helpers for household analysis endpoint tests.""" + +from typing import Any +from unittest.mock import patch +from uuid import UUID + +import pytest +from sqlmodel import Session + +from policyengine_api.models import ( + Household, + Parameter, + ParameterValue, + Policy, + TaxBenefitModel, + TaxBenefitModelVersion, +) + +# ============================================================================= +# Sample Calculation Results +# ============================================================================= + + +SAMPLE_UK_BASELINE_RESULT: dict[str, Any] = { + "person": [ + { + "age": 30, + "employment_income": 35000.0, + "income_tax": 4500.0, + "national_insurance": 2800.0, + "net_income": 27700.0, + } + ], + "benunit": [ + { + "universal_credit": 0.0, + "child_benefit": 0.0, + } + ], + "household": [ + { + "region": "LONDON", + "council_tax": 1500.0, + } + ], +} + + +SAMPLE_UK_REFORM_RESULT: dict[str, Any] = { + "person": [ + { + "age": 30, + "employment_income": 35000.0, + "income_tax": 4000.0, + "national_insurance": 2800.0, + "net_income": 28200.0, + } + ], + "benunit": [ + { + "universal_credit": 0.0, + "child_benefit": 0.0, + } + ], + "household": [ + { + "region": "LONDON", + "council_tax": 1500.0, + } + ], +} + + +SAMPLE_US_BASELINE_RESULT: dict[str, Any] = { + "person": [ + { + "age": 30, + "employment_income": 50000.0, + "income_tax": 6000.0, + "fica": 3825.0, + "net_income": 40175.0, + } + ], + "tax_unit": [ + { + "state_code": "CA", + "state_income_tax": 2500.0, + } + ], + "spm_unit": [{"snap": 0.0}], + "family": [{}], + "marital_unit": [{}], + "household": [{"state_fips": 6}], +} + + +SAMPLE_US_REFORM_RESULT: dict[str, Any] = { + 
"person": [ + { + "age": 30, + "employment_income": 50000.0, + "income_tax": 5500.0, + "fica": 3825.0, + "net_income": 40675.0, + } + ], + "tax_unit": [ + { + "state_code": "CA", + "state_income_tax": 2500.0, + } + ], + "spm_unit": [{"snap": 0.0}], + "family": [{}], + "marital_unit": [{}], + "household": [{"state_fips": 6}], +} + + +# ============================================================================= +# Mock Calculator Functions +# ============================================================================= + + +def mock_calculate_uk_household( + household_data: dict[str, Any], + year: int, + policy_data: dict | None, +) -> dict: + """Mock UK calculator that returns sample results.""" + if policy_data: + return SAMPLE_UK_REFORM_RESULT + return SAMPLE_UK_BASELINE_RESULT + + +def mock_calculate_us_household( + household_data: dict[str, Any], + year: int, + policy_data: dict | None, +) -> dict: + """Mock US calculator that returns sample results.""" + if policy_data: + return SAMPLE_US_REFORM_RESULT + return SAMPLE_US_BASELINE_RESULT + + +def mock_calculate_household_failing( + household_data: dict[str, Any], + year: int, + policy_data: dict | None, +) -> dict: + """Mock calculator that raises an exception.""" + raise RuntimeError("Calculation failed") + + +# ============================================================================= +# Pytest Fixtures for Mocking +# ============================================================================= + + +@pytest.fixture +def mock_uk_calculator(): + """Fixture that patches UK calculator with mock.""" + with patch( + "policyengine_api.api.household_analysis.calculate_uk_household", + side_effect=mock_calculate_uk_household, + ) as mock: + yield mock + + +@pytest.fixture +def mock_us_calculator(): + """Fixture that patches US calculator with mock.""" + with patch( + "policyengine_api.api.household_analysis.calculate_us_household", + side_effect=mock_calculate_us_household, + ) as mock: + yield mock + + 
+@pytest.fixture +def mock_calculators(): + """Fixture that patches both UK and US calculators.""" + with ( + patch( + "policyengine_api.api.household_analysis.calculate_uk_household", + side_effect=mock_calculate_uk_household, + ) as uk_mock, + patch( + "policyengine_api.api.household_analysis.calculate_us_household", + side_effect=mock_calculate_us_household, + ) as us_mock, + ): + yield {"uk": uk_mock, "us": us_mock} + + +@pytest.fixture +def mock_failing_calculator(): + """Fixture that patches calculators to fail.""" + with ( + patch( + "policyengine_api.api.household_analysis.calculate_uk_household", + side_effect=mock_calculate_household_failing, + ), + patch( + "policyengine_api.api.household_analysis.calculate_us_household", + side_effect=mock_calculate_household_failing, + ), + ): + yield + + +# ============================================================================= +# Database Factory Functions +# ============================================================================= + + +def create_tax_benefit_model( + session: Session, + name: str = "policyengine-uk", + description: str = "UK tax benefit model", +) -> TaxBenefitModel: + """Create and persist a TaxBenefitModel record.""" + model = TaxBenefitModel( + name=name, + description=description, + ) + session.add(model) + session.commit() + session.refresh(model) + return model + + +def create_model_version( + session: Session, + model_id: UUID, + version: str = "1.0.0", + description: str = "Test version", +) -> TaxBenefitModelVersion: + """Create and persist a TaxBenefitModelVersion record.""" + model_version = TaxBenefitModelVersion( + model_id=model_id, + version=version, + description=description, + ) + session.add(model_version) + session.commit() + session.refresh(model_version) + return model_version + + +def create_parameter( + session: Session, + model_version_id: UUID, + name: str = "test_parameter", + label: str = "Test Parameter", + description: str = "A test parameter", +) -> Parameter: 
+ """Create and persist a Parameter record.""" + param = Parameter( + tax_benefit_model_version_id=model_version_id, + name=name, + label=label, + description=description, + ) + session.add(param) + session.commit() + session.refresh(param) + return param + + +def create_policy( + session: Session, + model_id: UUID, + name: str = "Test Policy", + description: str = "A test policy", +) -> Policy: + """Create and persist a Policy record.""" + policy = Policy( + tax_benefit_model_id=model_id, + name=name, + description=description, + ) + session.add(policy) + session.commit() + session.refresh(policy) + return policy + + +def create_policy_with_parameter_value( + session: Session, + model_id: UUID, + parameter_id: UUID, + value: float, + name: str = "Test Policy", +) -> Policy: + """Create a Policy with an associated ParameterValue.""" + policy = create_policy(session, model_id, name=name) + + param_value = ParameterValue( + policy_id=policy.id, + parameter_id=parameter_id, + value_json={"value": value}, + ) + session.add(param_value) + session.commit() + session.refresh(policy) + return policy + + +def create_household_for_analysis( + session: Session, + tax_benefit_model_name: str = "policyengine_uk", + year: int = 2024, + label: str = "Test household for analysis", +) -> Household: + """Create a household suitable for analysis testing.""" + if tax_benefit_model_name == "policyengine_uk": + household_data = { + "people": [{"age": 30, "employment_income": 35000}], + "benunit": {}, + "household": {"region": "LONDON"}, + } + else: + household_data = { + "people": [{"age": 30, "employment_income": 50000}], + "tax_unit": {"state_code": "CA"}, + "family": {}, + "spm_unit": {}, + "marital_unit": {}, + "household": {"state_fips": 6}, + } + + record = Household( + tax_benefit_model_name=tax_benefit_model_name, + year=year, + label=label, + household_data=household_data, + ) + session.add(record) + session.commit() + session.refresh(record) + return record + + +def 
setup_uk_model_and_version( + session: Session, +) -> tuple[TaxBenefitModel, TaxBenefitModelVersion]: + """Create UK model and version for testing.""" + model = create_tax_benefit_model( + session, name="policyengine-uk", description="UK model" + ) + version = create_model_version(session, model.id) + return model, version + + +def setup_us_model_and_version( + session: Session, +) -> tuple[TaxBenefitModel, TaxBenefitModelVersion]: + """Create US model and version for testing.""" + model = create_tax_benefit_model( + session, name="policyengine-us", description="US model" + ) + version = create_model_version(session, model.id) + return model, version diff --git a/test_fixtures/fixtures_households.py b/test_fixtures/fixtures_households.py new file mode 100644 index 0000000..4e676f4 --- /dev/null +++ b/test_fixtures/fixtures_households.py @@ -0,0 +1,66 @@ +"""Fixtures and helpers for household CRUD tests.""" + +from policyengine_api.models import Household + +# ----------------------------------------------------------------------------- +# Request payloads (match HouseholdCreate schema) +# ----------------------------------------------------------------------------- + +MOCK_US_HOUSEHOLD_CREATE = { + "tax_benefit_model_name": "policyengine_us", + "year": 2024, + "label": "US test household", + "people": [ + {"age": 30, "employment_income": 50000}, + {"age": 28, "employment_income": 30000}, + ], + "tax_unit": {}, + "family": {}, + "household": {"state_name": "CA"}, +} + +MOCK_UK_HOUSEHOLD_CREATE = { + "tax_benefit_model_name": "policyengine_uk", + "year": 2024, + "label": "UK test household", + "people": [ + {"age": 40, "employment_income": 35000}, + ], + "benunit": {"is_married": False}, + "household": {"region": "LONDON"}, +} + +MOCK_HOUSEHOLD_MINIMAL = { + "tax_benefit_model_name": "policyengine_us", + "year": 2024, + "people": [{"age": 25}], +} + + +# ----------------------------------------------------------------------------- +# Factory functions +# 
----------------------------------------------------------------------------- + + +def create_household( + session, + tax_benefit_model_name: str = "policyengine_us", + year: int = 2024, + label: str | None = "Test household", + people: list | None = None, + **entity_groups, +) -> Household: + """Create and persist a Household record.""" + household_data = {"people": people or [{"age": 30}]} + household_data.update(entity_groups) + + record = Household( + tax_benefit_model_name=tax_benefit_model_name, + year=year, + label=label, + household_data=household_data, + ) + session.add(record) + session.commit() + session.refresh(record) + return record diff --git a/test_fixtures/fixtures_intra_decile.py b/test_fixtures/fixtures_intra_decile.py new file mode 100644 index 0000000..fd83c7d --- /dev/null +++ b/test_fixtures/fixtures_intra_decile.py @@ -0,0 +1,77 @@ +"""Fixtures for intra-decile impact tests.""" + +import numpy as np + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +NUM_HOUSEHOLDS = 100 +HOUSEHOLDS_PER_DECILE = NUM_HOUSEHOLDS // 10 + +# Each decile has 10 households; deciles 1-10 +DECILES = np.repeat(np.arange(1, 11), HOUSEHOLDS_PER_DECILE).astype(float) + +UNIFORM_WEIGHTS = np.ones(NUM_HOUSEHOLDS) * 100.0 +UNIFORM_PEOPLE = np.full(NUM_HOUSEHOLDS, 2.0) + +# Income change thresholds (matching intra_decile.py BOUNDS) +THRESHOLD_5PCT = 0.05 +THRESHOLD_0_1PCT = 1e-3 + +CATEGORY_NAMES = [ + "lose_more_than_5pct", + "lose_less_than_5pct", + "no_change", + "gain_less_than_5pct", + "gain_more_than_5pct", +] + +EXPECTED_ROW_COUNT = 11 # 10 deciles + 1 overall (decile=0) +EXPECTED_DECILE_NUMBERS = list(range(1, 11)) + [0] + + +# --------------------------------------------------------------------------- +# Factory functions +# --------------------------------------------------------------------------- + + +def make_baseline_income() -> 
np.ndarray: + """Baseline incomes: decile N earns N * 10,000.""" + return DECILES * 10_000.0 + + +def make_household_data( + baseline_income: np.ndarray, + reform_income: np.ndarray | None = None, + weights: np.ndarray | None = None, + people: np.ndarray | None = None, +) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]: + """Build baseline and reform household data dicts for compute_intra_decile.""" + if reform_income is None: + reform_income = baseline_income.copy() + if weights is None: + weights = UNIFORM_WEIGHTS.copy() + if people is None: + people = UNIFORM_PEOPLE.copy() + + baseline = { + "household_net_income": baseline_income, + "household_weight": weights, + "household_count_people": people, + "household_income_decile": DECILES.copy(), + } + reform = { + "household_net_income": reform_income, + "household_weight": weights, + "household_count_people": people, + "household_income_decile": DECILES.copy(), + } + return baseline, reform + + +def make_single_household_arrays( + baseline_val: float, reform_val: float +) -> tuple[np.ndarray, np.ndarray]: + """Create single-element arrays for formula unit tests.""" + return np.array([baseline_val]), np.array([reform_val]) diff --git a/test_fixtures/fixtures_parameters.py b/test_fixtures/fixtures_parameters.py index ff69b0e..0df134c 100644 --- a/test_fixtures/fixtures_parameters.py +++ b/test_fixtures/fixtures_parameters.py @@ -54,9 +54,15 @@ def create_parameter(session, model_version, name: str, label: str) -> Parameter return param -def create_policy(session, name: str, description: str = "A test policy") -> Policy: +def create_policy( + session, name: str, model_version, description: str = "A test policy" +) -> Policy: """Create and persist a Policy.""" - policy = Policy(name=name, description=description) + policy = Policy( + name=name, + description=description, + tax_benefit_model_id=model_version.model_id, + ) session.add(policy) session.commit() session.refresh(policy) diff --git 
a/test_fixtures/fixtures_regions.py b/test_fixtures/fixtures_regions.py new file mode 100644 index 0000000..db56633 --- /dev/null +++ b/test_fixtures/fixtures_regions.py @@ -0,0 +1,265 @@ +"""Fixtures and helpers for region-related tests.""" + +from uuid import uuid4 + +import pytest + +from policyengine_api.models import ( + Dataset, + Region, + RegionDatasetLink, + Simulation, + SimulationStatus, + TaxBenefitModel, + TaxBenefitModelVersion, +) + +# ----------------------------------------------------------------------------- +# Constants +# ----------------------------------------------------------------------------- + +TEST_UUIDS = { + "DATASET": uuid4(), + "DATASET_UK": uuid4(), + "DATASET_US": uuid4(), + "MODEL_UK": uuid4(), + "MODEL_US": uuid4(), + "MODEL_VERSION_UK": uuid4(), + "MODEL_VERSION_US": uuid4(), + "REGION_UK": uuid4(), + "REGION_US_STATE": uuid4(), + "REGION_US_NATIONAL": uuid4(), + "POLICY": uuid4(), + "DYNAMIC": uuid4(), +} + +REGION_CODES = { + "UK_ENGLAND": "country/england", + "US_CALIFORNIA": "state/ca", + "US_NATIONAL": "us", + "UK_NATIONAL": "uk", +} + +FILTER_FIELDS = { + "UK_COUNTRY": "country", + "US_STATE": "state_code", + "US_FIPS": "place_fips", +} + +FILTER_VALUES = { + "ENGLAND": "ENGLAND", + "CALIFORNIA": "CA", + "CA_FIPS": "06000", +} + + +# ----------------------------------------------------------------------------- +# Factory Functions +# ----------------------------------------------------------------------------- + + +def create_tax_benefit_model( + session, name: str = "policyengine-uk", description: str = "UK model" +) -> TaxBenefitModel: + """Create and persist a TaxBenefitModel.""" + model = TaxBenefitModel(name=name, description=description) + session.add(model) + session.commit() + session.refresh(model) + return model + + +def create_tax_benefit_model_version( + session, model: TaxBenefitModel, version: str = "1.0.0" +) -> TaxBenefitModelVersion: + """Create and persist a TaxBenefitModelVersion.""" + model_version = 
TaxBenefitModelVersion( + model_id=model.id, + version=version, + description=f"Version {version}", + ) + session.add(model_version) + session.commit() + session.refresh(model_version) + return model_version + + +def create_dataset( + session, + model: TaxBenefitModel, + name: str = "test_dataset", + filepath: str = "test/path/dataset.h5", + year: int = 2024, +) -> Dataset: + """Create and persist a Dataset.""" + dataset = Dataset( + name=name, + description=f"Test dataset: {name}", + filepath=filepath, + year=year, + tax_benefit_model_id=model.id, + ) + session.add(dataset) + session.commit() + session.refresh(dataset) + return dataset + + +def create_region( + session, + model: TaxBenefitModel, + dataset: Dataset, + code: str, + label: str, + region_type: str, + requires_filter: bool = False, + filter_field: str | None = None, + filter_value: str | None = None, +) -> Region: + """Create and persist a Region with a dataset link.""" + region = Region( + code=code, + label=label, + region_type=region_type, + requires_filter=requires_filter, + filter_field=filter_field, + filter_value=filter_value, + tax_benefit_model_id=model.id, + ) + session.add(region) + session.commit() + session.refresh(region) + + # Create the join table link + link = RegionDatasetLink(region_id=region.id, dataset_id=dataset.id) + session.add(link) + session.commit() + + return region + + +def create_simulation( + session, + dataset: Dataset, + model_version: TaxBenefitModelVersion, + filter_field: str | None = None, + filter_value: str | None = None, + status: SimulationStatus = SimulationStatus.PENDING, +) -> Simulation: + """Create and persist a Simulation with optional filter parameters.""" + simulation = Simulation( + dataset_id=dataset.id, + tax_benefit_model_version_id=model_version.id, + status=status, + filter_field=filter_field, + filter_value=filter_value, + ) + session.add(simulation) + session.commit() + session.refresh(simulation) + return simulation + + +# 
----------------------------------------------------------------------------- +# Composite Fixtures +# ----------------------------------------------------------------------------- + + +@pytest.fixture +def uk_model_and_version(session): + """Create UK model with version.""" + model = create_tax_benefit_model( + session, name="policyengine-uk", description="UK model" + ) + version = create_tax_benefit_model_version(session, model) + return model, version + + +@pytest.fixture +def us_model_and_version(session): + """Create US model with version.""" + model = create_tax_benefit_model( + session, name="policyengine-us", description="US model" + ) + version = create_tax_benefit_model_version(session, model) + return model, version + + +@pytest.fixture +def uk_dataset(session, uk_model_and_version): + """Create a UK dataset.""" + model, _ = uk_model_and_version + return create_dataset( + session, model, name="uk_enhanced_frs", filepath="uk/enhanced_frs_2024.h5" + ) + + +@pytest.fixture +def us_dataset(session, us_model_and_version): + """Create a US dataset.""" + model, _ = us_model_and_version + return create_dataset(session, model, name="us_cps", filepath="us/cps_2024.h5") + + +@pytest.fixture +def uk_region_national(session, uk_model_and_version, uk_dataset): + """Create UK national region (no filter required).""" + model, _ = uk_model_and_version + return create_region( + session, + model=model, + dataset=uk_dataset, + code="uk", + label="United Kingdom", + region_type="national", + requires_filter=False, + ) + + +@pytest.fixture +def uk_region_england(session, uk_model_and_version, uk_dataset): + """Create England region (filter required).""" + model, _ = uk_model_and_version + return create_region( + session, + model=model, + dataset=uk_dataset, + code="country/england", + label="England", + region_type="country", + requires_filter=True, + filter_field="country", + filter_value="ENGLAND", + ) + + +@pytest.fixture +def us_region_national(session, 
us_model_and_version, us_dataset): + """Create US national region (no filter required).""" + model, _ = us_model_and_version + return create_region( + session, + model=model, + dataset=us_dataset, + code="us", + label="United States", + region_type="national", + requires_filter=False, + ) + + +@pytest.fixture +def us_region_california(session, us_model_and_version, us_dataset): + """Create California state region (filter required).""" + model, _ = us_model_and_version + return create_region( + session, + model=model, + dataset=us_dataset, + code="state/ca", + label="California", + region_type="state", + requires_filter=True, + filter_field="state_code", + filter_value="CA", + ) diff --git a/test_fixtures/fixtures_simulations_standalone.py b/test_fixtures/fixtures_simulations_standalone.py new file mode 100644 index 0000000..c2e397f --- /dev/null +++ b/test_fixtures/fixtures_simulations_standalone.py @@ -0,0 +1,173 @@ +"""Fixtures and helpers for standalone simulation endpoint tests.""" + + +from policyengine_api.models import ( + Dataset, + Household, + Policy, + Region, + RegionDatasetLink, + Simulation, + SimulationStatus, + SimulationType, + TaxBenefitModel, + TaxBenefitModelVersion, +) + + +def create_us_model_and_version( + session, +) -> tuple[TaxBenefitModel, TaxBenefitModelVersion]: + """Create a US tax-benefit model and version.""" + model = TaxBenefitModel(name="policyengine-us", description="US model") + session.add(model) + session.commit() + session.refresh(model) + + version = TaxBenefitModelVersion( + model_id=model.id, version="test", description="Test version" + ) + session.add(version) + session.commit() + session.refresh(version) + + return model, version + + +def create_uk_model_and_version( + session, +) -> tuple[TaxBenefitModel, TaxBenefitModelVersion]: + """Create a UK tax-benefit model and version.""" + model = TaxBenefitModel(name="policyengine-uk", description="UK model") + session.add(model) + session.commit() + session.refresh(model) + + 
version = TaxBenefitModelVersion( + model_id=model.id, version="test", description="Test version" + ) + session.add(version) + session.commit() + session.refresh(version) + + return model, version + + +def create_household( + session, + tax_benefit_model_name: str = "policyengine_us", + year: int = 2024, + label: str = "Test household", +) -> Household: + """Create and persist a Household record.""" + household = Household( + tax_benefit_model_name=tax_benefit_model_name, + year=year, + label=label, + household_data={ + "people": [{"age": {"2024": 30}, "employment_income": {"2024": 50000}}], + "household": [{"state_code": {"2024": "CA"}}], + }, + ) + session.add(household) + session.commit() + session.refresh(household) + return household + + +def create_policy(session, model: TaxBenefitModel) -> Policy: + """Create and persist a Policy record.""" + policy = Policy( + name="Test reform", + description="A test reform policy", + tax_benefit_model_id=model.id, + ) + session.add(policy) + session.commit() + session.refresh(policy) + return policy + + +def create_dataset(session, model: TaxBenefitModel) -> Dataset: + """Create and persist a Dataset record.""" + dataset = Dataset( + name="test_dataset", + description="Test dataset", + filepath="test/path/dataset.h5", + year=2024, + tax_benefit_model_id=model.id, + ) + session.add(dataset) + session.commit() + session.refresh(dataset) + return dataset + + +def create_region( + session, + model: TaxBenefitModel, + dataset: Dataset, + code: str = "us", + label: str = "United States", + region_type: str = "country", + requires_filter: bool = False, + filter_field: str | None = None, + filter_value: str | None = None, +) -> Region: + """Create and persist a Region record with a dataset link.""" + region = Region( + code=code, + label=label, + region_type=region_type, + requires_filter=requires_filter, + filter_field=filter_field, + filter_value=filter_value, + tax_benefit_model_id=model.id, + ) + session.add(region) + 
session.commit() + session.refresh(region) + + # Create the join table link + link = RegionDatasetLink(region_id=region.id, dataset_id=dataset.id) + session.add(link) + session.commit() + + return region + + +def create_economy_simulation( + session, + version: TaxBenefitModelVersion, + dataset: Dataset, +) -> Simulation: + """Create and persist an economy Simulation record.""" + simulation = Simulation( + simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + tax_benefit_model_version_id=version.id, + status=SimulationStatus.COMPLETED, + ) + session.add(simulation) + session.commit() + session.refresh(simulation) + return simulation + + +def create_household_simulation( + session, + version: TaxBenefitModelVersion, + household: Household, +) -> Simulation: + """Create and persist a household Simulation record.""" + simulation = Simulation( + simulation_type=SimulationType.HOUSEHOLD, + household_id=household.id, + tax_benefit_model_version_id=version.id, + status=SimulationStatus.COMPLETED, + household_result={"person": [{"income_tax": {"2024": 5000}}]}, + ) + session.add(simulation) + session.commit() + session.refresh(simulation) + return simulation diff --git a/test_fixtures/fixtures_user_household_associations.py b/test_fixtures/fixtures_user_household_associations.py new file mode 100644 index 0000000..66b0835 --- /dev/null +++ b/test_fixtures/fixtures_user_household_associations.py @@ -0,0 +1,62 @@ +"""Fixtures and helpers for user-household association tests.""" + +from uuid import UUID + +from policyengine_api.models import Household, User, UserHouseholdAssociation + +# ----------------------------------------------------------------------------- +# Factory functions +# ----------------------------------------------------------------------------- + + +def create_user( + session, + first_name: str = "Test", + last_name: str = "User", + email: str = "test@example.com", +) -> User: + """Create and persist a User record.""" + record = 
User(first_name=first_name, last_name=last_name, email=email) + session.add(record) + session.commit() + session.refresh(record) + return record + + +def create_household( + session, + tax_benefit_model_name: str = "policyengine_us", + year: int = 2024, + label: str | None = "Test household", +) -> Household: + """Create and persist a Household record.""" + record = Household( + tax_benefit_model_name=tax_benefit_model_name, + year=year, + label=label, + household_data={"people": [{"age": 30}]}, + ) + session.add(record) + session.commit() + session.refresh(record) + return record + + +def create_association( + session, + user_id: UUID, + household_id: UUID, + country_id: str = "us", + label: str | None = "My household", +) -> UserHouseholdAssociation: + """Create and persist a UserHouseholdAssociation record.""" + record = UserHouseholdAssociation( + user_id=user_id, + household_id=household_id, + country_id=country_id, + label=label, + ) + session.add(record) + session.commit() + session.refresh(record) + return record diff --git a/test_fixtures/fixtures_user_policies.py b/test_fixtures/fixtures_user_policies.py new file mode 100644 index 0000000..1572ca7 --- /dev/null +++ b/test_fixtures/fixtures_user_policies.py @@ -0,0 +1,70 @@ +"""Fixtures and helpers for user-policy association tests.""" + +from uuid import UUID + +from policyengine_api.models import Policy, TaxBenefitModel, UserPolicy + +# ----------------------------------------------------------------------------- +# Constants +# ----------------------------------------------------------------------------- + +US_COUNTRY_ID = "us" +UK_COUNTRY_ID = "uk" + +DEFAULT_POLICY_NAME = "Test policy" +DEFAULT_POLICY_DESCRIPTION = "A test policy" + +# ----------------------------------------------------------------------------- +# Factory functions +# ----------------------------------------------------------------------------- + + +def create_tax_benefit_model( + session, + name: str = "policyengine-us", + 
description: str = "US model", +) -> TaxBenefitModel: + """Create and persist a TaxBenefitModel record.""" + record = TaxBenefitModel(name=name, description=description) + session.add(record) + session.commit() + session.refresh(record) + return record + + +def create_policy( + session, + tax_benefit_model: TaxBenefitModel, + name: str = DEFAULT_POLICY_NAME, + description: str = DEFAULT_POLICY_DESCRIPTION, +) -> Policy: + """Create and persist a Policy record.""" + record = Policy( + name=name, + description=description, + tax_benefit_model_id=tax_benefit_model.id, + ) + session.add(record) + session.commit() + session.refresh(record) + return record + + +def create_user_policy( + session, + user_id: UUID, + policy: Policy, + country_id: str = US_COUNTRY_ID, + label: str | None = None, +) -> UserPolicy: + """Create and persist a UserPolicy association record.""" + record = UserPolicy( + user_id=user_id, + policy_id=policy.id, + country_id=country_id, + label=label, + ) + session.add(record) + session.commit() + session.refresh(record) + return record diff --git a/test_fixtures/fixtures_user_report_associations.py b/test_fixtures/fixtures_user_report_associations.py new file mode 100644 index 0000000..4ef07df --- /dev/null +++ b/test_fixtures/fixtures_user_report_associations.py @@ -0,0 +1,101 @@ +"""Fixtures and helpers for user-report association tests.""" + +from datetime import datetime +from uuid import UUID + +from policyengine_api.models import ( + Dataset, + Report, + ReportStatus, + Simulation, + SimulationStatus, + TaxBenefitModel, + TaxBenefitModelVersion, + UserReportAssociation, +) + + +def create_tax_benefit_model( + session, + name: str = "policyengine-us", + description: str = "US model", +) -> TaxBenefitModel: + """Create and persist a TaxBenefitModel record.""" + record = TaxBenefitModel(name=name, description=description) + session.add(record) + session.commit() + session.refresh(record) + return record + + +def create_report(session, model: 
TaxBenefitModel | None = None) -> Report: + """Create and persist a Report with required simulation dependencies.""" + if model is None: + model = create_tax_benefit_model(session) + + version = TaxBenefitModelVersion( + model_id=model.id, version="test", description="Test version" + ) + session.add(version) + session.commit() + session.refresh(version) + + dataset = Dataset( + name="test_dataset", + description="Test dataset", + filepath="test/path/dataset.h5", + year=2024, + tax_benefit_model_id=model.id, + ) + session.add(dataset) + session.commit() + session.refresh(dataset) + + baseline = Simulation( + dataset_id=dataset.id, + tax_benefit_model_version_id=version.id, + status=SimulationStatus.COMPLETED, + ) + reform = Simulation( + dataset_id=dataset.id, + tax_benefit_model_version_id=version.id, + status=SimulationStatus.COMPLETED, + ) + session.add(baseline) + session.add(reform) + session.commit() + session.refresh(baseline) + session.refresh(reform) + + report = Report( + label="Test report", + status=ReportStatus.COMPLETED, + baseline_simulation_id=baseline.id, + reform_simulation_id=reform.id, + ) + session.add(report) + session.commit() + session.refresh(report) + return report + + +def create_user_report_association( + session, + user_id: UUID, + report: Report, + country_id: str = "us", + label: str | None = None, + last_run_at: datetime | None = None, +) -> UserReportAssociation: + """Create and persist a UserReportAssociation record.""" + record = UserReportAssociation( + user_id=user_id, + report_id=report.id, + country_id=country_id, + label=label, + last_run_at=last_run_at, + ) + session.add(record) + session.commit() + session.refresh(record) + return record diff --git a/test_fixtures/fixtures_user_simulation_associations.py b/test_fixtures/fixtures_user_simulation_associations.py new file mode 100644 index 0000000..c2cbd74 --- /dev/null +++ b/test_fixtures/fixtures_user_simulation_associations.py @@ -0,0 +1,79 @@ +"""Fixtures and helpers for 
user-simulation association tests.""" + +from uuid import UUID + +from policyengine_api.models import ( + Dataset, + Simulation, + SimulationStatus, + TaxBenefitModel, + TaxBenefitModelVersion, + UserSimulationAssociation, +) + + +def create_tax_benefit_model( + session, + name: str = "policyengine-us", + description: str = "US model", +) -> TaxBenefitModel: + """Create and persist a TaxBenefitModel record.""" + record = TaxBenefitModel(name=name, description=description) + session.add(record) + session.commit() + session.refresh(record) + return record + + +def create_simulation(session, model: TaxBenefitModel | None = None) -> Simulation: + """Create and persist a Simulation with required dependencies.""" + if model is None: + model = create_tax_benefit_model(session) + + version = TaxBenefitModelVersion( + model_id=model.id, version="test", description="Test version" + ) + session.add(version) + session.commit() + session.refresh(version) + + dataset = Dataset( + name="test_dataset", + description="Test dataset", + filepath="test/path/dataset.h5", + year=2024, + tax_benefit_model_id=model.id, + ) + session.add(dataset) + session.commit() + session.refresh(dataset) + + simulation = Simulation( + dataset_id=dataset.id, + tax_benefit_model_version_id=version.id, + status=SimulationStatus.COMPLETED, + ) + session.add(simulation) + session.commit() + session.refresh(simulation) + return simulation + + +def create_user_simulation_association( + session, + user_id: UUID, + simulation: Simulation, + country_id: str = "us", + label: str | None = None, +) -> UserSimulationAssociation: + """Create and persist a UserSimulationAssociation record.""" + record = UserSimulationAssociation( + user_id=user_id, + simulation_id=simulation.id, + country_id=country_id, + label=label, + ) + session.add(record) + session.commit() + session.refresh(record) + return record diff --git a/tests/conftest.py b/tests/conftest.py index 8be9b3f..77c29ec 100644 --- a/tests/conftest.py +++ 
b/tests/conftest.py @@ -1,7 +1,5 @@ """Pytest fixtures for tests.""" -from uuid import uuid4 - import pytest from fastapi.testclient import TestClient from fastapi_cache import FastAPICache @@ -48,6 +46,26 @@ def get_session_override(): app.dependency_overrides.clear() +@pytest.fixture(name="tax_benefit_model") +def tax_benefit_model_fixture(session: Session): + """Create a TaxBenefitModel for tests.""" + model = TaxBenefitModel(name="policyengine-us", description="US model") + session.add(model) + session.commit() + session.refresh(model) + return model + + +@pytest.fixture(name="uk_tax_benefit_model") +def uk_tax_benefit_model_fixture(session: Session): + """Create a UK TaxBenefitModel for tests.""" + model = TaxBenefitModel(name="policyengine-uk", description="UK model") + session.add(model) + session.commit() + session.refresh(model) + return model + + @pytest.fixture(name="simulation_id") def simulation_fixture(session: Session): """Create a test simulation with required dependencies.""" diff --git a/tests/test_agent.py b/tests/test_agent.py index 2c591f5..55bb2c2 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -9,6 +9,7 @@ import json from unittest.mock import AsyncMock, MagicMock, patch + from fastapi.testclient import TestClient from policyengine_api.main import app diff --git a/tests/test_agent_policy_questions.py b/tests/test_agent_policy_questions.py index 1550f89..289d73c 100644 --- a/tests/test_agent_policy_questions.py +++ b/tests/test_agent_policy_questions.py @@ -11,10 +11,10 @@ pytestmark = pytest.mark.integration -from policyengine_api.agent_sandbox import _run_agent_impl - import os +from policyengine_api.agent_sandbox import _run_agent_impl + # Use local API by default, override with POLICYENGINE_API_URL env var API_BASE = os.environ.get("POLICYENGINE_API_URL", "http://localhost:8000") @@ -218,4 +218,6 @@ def test_turn_efficiency(self, question, max_expected_turns): print(f"Result: {result['result'][:300]}") if result["turns"] > 
max_expected_turns: - print(f"WARNING: Took {result['turns']} turns, expected <= {max_expected_turns}") + print( + f"WARNING: Took {result['turns']} turns, expected <= {max_expected_turns}" + ) diff --git a/tests/test_analysis.py b/tests/test_analysis.py index 90dbe7c..b13659b 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -1,26 +1,667 @@ -"""Tests for economic impact analysis endpoint. +"""Tests for economic impact analysis (analysis.py). -These tests require a running database with seeded data. -Run with: make integration-test +Unit tests for internal functions (_resolve_dataset_and_region, +_get_deterministic_simulation_id, _get_or_create_simulation) and +integration tests for the /analysis/economic-impact endpoint. """ -import pytest +from uuid import uuid4 -pytestmark = pytest.mark.integration +import pytest +from fastapi import HTTPException from fastapi.testclient import TestClient from sqlmodel import Session, select +from policyengine_api.api.analysis import ( + EconomicImpactRequest, + _get_deterministic_simulation_id, + _get_or_create_simulation, + _resolve_dataset_and_region, +) from policyengine_api.main import app -from policyengine_api.models import Dataset, Simulation, TaxBenefitModel +from policyengine_api.models import ( + Dataset, + Simulation, + SimulationStatus, + SimulationType, + TaxBenefitModel, +) +from test_fixtures.fixtures_regions import ( + create_dataset, + create_region, + create_tax_benefit_model, + create_tax_benefit_model_version, +) client = TestClient(app) +# --------------------------------------------------------------------------- +# _resolve_dataset_and_region +# --------------------------------------------------------------------------- + + +class TestResolveDatasetAndRegion: + """Tests for _resolve_dataset_and_region.""" + + # -- dataset_id path -- + + def test__given_dataset_id__then_region_is_none(self, session: Session): + model = create_tax_benefit_model(session, name="policyengine-uk") + dataset = 
create_dataset(session, model, name="uk_enhanced_frs") + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + dataset_id=dataset.id, + ) + + resolved_dataset, resolved_region = _resolve_dataset_and_region( + request, session + ) + + assert resolved_region is None + + def test__given_dataset_id__then_dataset_is_returned(self, session: Session): + model = create_tax_benefit_model(session, name="policyengine-uk") + dataset = create_dataset(session, model, name="uk_enhanced_frs") + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + dataset_id=dataset.id, + ) + + resolved_dataset, resolved_region = _resolve_dataset_and_region( + request, session + ) + + assert resolved_dataset.id == dataset.id + assert resolved_dataset.name == "uk_enhanced_frs" + + def test__given_dataset_id_and_region__then_region_takes_precedence( + self, session: Session + ): + model = create_tax_benefit_model(session, name="policyengine-uk") + dataset1 = create_dataset(session, model, name="dataset_from_id") + dataset2 = create_dataset(session, model, name="dataset_from_region") + create_region( + session, + model=model, + dataset=dataset2, + code="uk", + label="United Kingdom", + region_type="national", + requires_filter=False, + ) + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + dataset_id=dataset1.id, + region="uk", + ) + + resolved_dataset, resolved_region = _resolve_dataset_and_region( + request, session + ) + + assert resolved_dataset.id == dataset2.id + assert resolved_region is not None + assert resolved_region.code == "uk" + + # -- region with filter -- + + def test__given_region_requires_filter__then_returns_filter_fields( + self, session: Session + ): + model = create_tax_benefit_model(session, name="policyengine-uk") + dataset = create_dataset(session, model, name="uk_enhanced_frs") + create_region( + session, + model=model, + dataset=dataset, + code="country/england", + label="England", + 
region_type="country", + requires_filter=True, + filter_field="country", + filter_value="ENGLAND", + ) + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + region="country/england", + ) + + resolved_dataset, resolved_region = _resolve_dataset_and_region( + request, session + ) + + assert resolved_region is not None + assert resolved_region.filter_field == "country" + assert resolved_region.filter_value == "ENGLAND" + assert resolved_region.requires_filter is True + + def test__given_us_state_region__then_returns_state_filter(self, session: Session): + model = create_tax_benefit_model(session, name="policyengine-us") + dataset = create_dataset(session, model, name="us_cps") + create_region( + session, + model=model, + dataset=dataset, + code="state/ca", + label="California", + region_type="state", + requires_filter=True, + filter_field="state_code", + filter_value="CA", + ) + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_us", + region="state/ca", + ) + + resolved_dataset, resolved_region = _resolve_dataset_and_region( + request, session + ) + + assert resolved_region is not None + assert resolved_region.filter_field == "state_code" + assert resolved_region.filter_value == "CA" + + def test__given_region_with_filter__then_dataset_is_resolved( + self, session: Session + ): + model = create_tax_benefit_model(session, name="policyengine-uk") + dataset = create_dataset(session, model, name="uk_enhanced_frs") + create_region( + session, + model=model, + dataset=dataset, + code="country/england", + label="England", + region_type="country", + requires_filter=True, + filter_field="country", + filter_value="ENGLAND", + ) + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + region="country/england", + ) + + resolved_dataset, resolved_region = _resolve_dataset_and_region( + request, session + ) + + assert resolved_dataset.id == dataset.id + assert resolved_dataset.name == "uk_enhanced_frs" + + # 
-- region without filter -- + + def test__given_national_uk_region__then_filter_params_none(self, session: Session): + model = create_tax_benefit_model(session, name="policyengine-uk") + dataset = create_dataset(session, model, name="uk_enhanced_frs") + create_region( + session, + model=model, + dataset=dataset, + code="uk", + label="United Kingdom", + region_type="national", + requires_filter=False, + ) + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + region="uk", + ) + + resolved_dataset, resolved_region = _resolve_dataset_and_region( + request, session + ) + + assert resolved_region is not None + assert resolved_region.requires_filter is False + assert resolved_region.filter_field is None + assert resolved_region.filter_value is None + + def test__given_national_us_region__then_filter_params_none(self, session: Session): + model = create_tax_benefit_model(session, name="policyengine-us") + dataset = create_dataset(session, model, name="us_cps") + create_region( + session, + model=model, + dataset=dataset, + code="us", + label="United States", + region_type="national", + requires_filter=False, + ) + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_us", + region="us", + ) + + resolved_dataset, resolved_region = _resolve_dataset_and_region( + request, session + ) + + assert resolved_region is not None + assert resolved_region.requires_filter is False + assert resolved_region.filter_field is None + assert resolved_region.filter_value is None + + def test__given_national_region__then_dataset_still_resolved( + self, session: Session + ): + model = create_tax_benefit_model(session, name="policyengine-uk") + dataset = create_dataset(session, model, name="uk_enhanced_frs") + create_region( + session, + model=model, + dataset=dataset, + code="uk", + label="United Kingdom", + region_type="national", + requires_filter=False, + ) + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + 
region="uk", + ) + + resolved_dataset, resolved_region = _resolve_dataset_and_region( + request, session + ) + + assert resolved_dataset.id == dataset.id + assert resolved_dataset.name == "uk_enhanced_frs" + + # -- error cases -- + + def test__given_nonexistent_region_code__then_raises_404(self, session: Session): + model = create_tax_benefit_model(session, name="policyengine-uk") + create_dataset(session, model, name="uk_enhanced_frs") + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + region="nonexistent/region", + ) + + with pytest.raises(HTTPException) as exc_info: + _resolve_dataset_and_region(request, session) + + assert exc_info.value.status_code == 404 + assert "not found" in exc_info.value.detail.lower() + + def test__given_region_for_wrong_model__then_raises_404(self, session: Session): + uk_model = create_tax_benefit_model(session, name="policyengine-uk") + uk_dataset = create_dataset(session, uk_model, name="uk_enhanced_frs") + create_region( + session, + model=uk_model, + dataset=uk_dataset, + code="uk", + label="United Kingdom", + region_type="national", + ) + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_us", + region="uk", + ) + + with pytest.raises(HTTPException) as exc_info: + _resolve_dataset_and_region(request, session) + + assert exc_info.value.status_code == 404 + + def test__given_neither_dataset_nor_region__then_raises_validation_error( + self, session: Session + ): + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="dataset_id or region"): + EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + ) + + def test__given_nonexistent_dataset_id__then_raises_404(self, session: Session): + nonexistent_id = uuid4() + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + dataset_id=nonexistent_id, + ) + + with pytest.raises(HTTPException) as exc_info: + _resolve_dataset_and_region(request, session) + + assert 
exc_info.value.status_code == 404 + assert "not found" in exc_info.value.detail.lower() + + +# --------------------------------------------------------------------------- +# _get_deterministic_simulation_id +# --------------------------------------------------------------------------- + + +class TestGetDeterministicSimulationId: + """Tests for _get_deterministic_simulation_id.""" + + def test__given_same_params__then_same_id_returned(self): + dataset_id = uuid4() + model_version_id = uuid4() + policy_id = uuid4() + dynamic_id = uuid4() + + id1 = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + policy_id, + dynamic_id, + dataset_id=dataset_id, + filter_field="country", + filter_value="ENGLAND", + ) + id2 = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + policy_id, + dynamic_id, + dataset_id=dataset_id, + filter_field="country", + filter_value="ENGLAND", + ) + + assert id1 == id2 + + def test__given_different_filter_field__then_different_id(self): + dataset_id = uuid4() + model_version_id = uuid4() + + id1 = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=dataset_id, + filter_field="country", + filter_value="ENGLAND", + ) + id2 = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=dataset_id, + filter_field="state_code", + filter_value="ENGLAND", + ) + + assert id1 != id2 + + def test__given_different_filter_value__then_different_id(self): + dataset_id = uuid4() + model_version_id = uuid4() + + id1 = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=dataset_id, + filter_field="country", + filter_value="ENGLAND", + ) + id2 = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=dataset_id, + filter_field="country", + filter_value="SCOTLAND", + ) + + assert id1 != id2 + + def 
test__given_filter_none_vs_filter_set__then_different_id(self): + dataset_id = uuid4() + model_version_id = uuid4() + + id_no_filter = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=dataset_id, + filter_field=None, + filter_value=None, + ) + id_with_filter = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=dataset_id, + filter_field="country", + filter_value="ENGLAND", + ) + + assert id_no_filter != id_with_filter + + def test__given_different_dataset__then_different_id(self): + model_version_id = uuid4() + + id1 = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=uuid4(), + filter_field="country", + filter_value="ENGLAND", + ) + id2 = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=uuid4(), + filter_field="country", + filter_value="ENGLAND", + ) + + assert id1 != id2 + + def test__given_null_optional_params__then_consistent_id(self): + dataset_id = uuid4() + model_version_id = uuid4() + + id1 = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=dataset_id, + filter_field=None, + filter_value=None, + ) + id2 = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=dataset_id, + filter_field=None, + filter_value=None, + ) + + assert id1 == id2 + + +# --------------------------------------------------------------------------- +# _get_or_create_simulation +# --------------------------------------------------------------------------- + + +class TestGetOrCreateSimulation: + """Tests for _get_or_create_simulation.""" + + def test__given_existing_simulation_with_filter__then_reuses( + self, session: Session + ): + model = create_tax_benefit_model(session, name="policyengine-uk") + model_version = 
create_tax_benefit_model_version(session, model) + dataset = create_dataset(session, model, name="uk_enhanced_frs") + + first_sim = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field="country", + filter_value="ENGLAND", + ) + second_sim = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field="country", + filter_value="ENGLAND", + ) + + assert first_sim.id == second_sim.id + + def test__given_different_filter__then_creates_new_simulation( + self, session: Session + ): + model = create_tax_benefit_model(session, name="policyengine-uk") + model_version = create_tax_benefit_model_version(session, model) + dataset = create_dataset(session, model, name="uk_enhanced_frs") + + england_sim = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field="country", + filter_value="ENGLAND", + ) + scotland_sim = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field="country", + filter_value="SCOTLAND", + ) + + assert england_sim.id != scotland_sim.id + assert england_sim.filter_value == "ENGLAND" + assert scotland_sim.filter_value == "SCOTLAND" + + def test__given_no_filter_vs_filter__then_creates_separate_simulations( + self, session: Session + ): + model = create_tax_benefit_model(session, name="policyengine-uk") + model_version = create_tax_benefit_model_version(session, model) + dataset = create_dataset(session, model, name="uk_enhanced_frs") + + national_sim = _get_or_create_simulation( + 
simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field=None, + filter_value=None, + ) + filtered_sim = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field="country", + filter_value="ENGLAND", + ) + + assert national_sim.id != filtered_sim.id + assert national_sim.filter_field is None + assert filtered_sim.filter_field == "country" + + def test__given_new_simulation__then_status_is_pending(self, session: Session): + model = create_tax_benefit_model(session, name="policyengine-uk") + model_version = create_tax_benefit_model_version(session, model) + dataset = create_dataset(session, model, name="uk_enhanced_frs") + + simulation = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field="country", + filter_value="ENGLAND", + ) + + assert simulation.status == SimulationStatus.PENDING + + def test__given_filter_params__then_simulation_has_filter_fields( + self, session: Session + ): + model = create_tax_benefit_model(session, name="policyengine-uk") + model_version = create_tax_benefit_model_version(session, model) + dataset = create_dataset(session, model, name="uk_enhanced_frs") + + simulation = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field="country", + filter_value="ENGLAND", + ) + + assert simulation.filter_field == "country" + assert simulation.filter_value == "ENGLAND" + + def test__given_no_filter_params__then_simulation_has_null_filter_fields( + self, session: Session + ): + model = 
create_tax_benefit_model(session, name="policyengine-uk") + model_version = create_tax_benefit_model_version(session, model) + dataset = create_dataset(session, model, name="uk_enhanced_frs") + + simulation = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + ) + + assert simulation.filter_field is None + assert simulation.filter_value is None + + +# --------------------------------------------------------------------------- +# HTTP endpoint validation (no database required) +# --------------------------------------------------------------------------- + + class TestEconomicImpactValidation: """Tests for request validation (no database required).""" def test_invalid_model_name(self): - """Test that invalid model name returns 422.""" response = client.post( "/analysis/economic-impact", json={ @@ -31,7 +672,6 @@ def test_invalid_model_name(self): assert response.status_code == 422 def test_missing_dataset_id(self): - """Test that missing dataset_id returns 422.""" response = client.post( "/analysis/economic-impact", json={ @@ -41,7 +681,6 @@ def test_missing_dataset_id(self): assert response.status_code == 422 def test_invalid_uuid(self): - """Test that invalid UUID returns 422.""" response = client.post( "/analysis/economic-impact", json={ @@ -52,11 +691,11 @@ def test_invalid_uuid(self): assert response.status_code == 422 +@pytest.mark.integration class TestEconomicImpactNotFound: """Tests for 404 responses.""" def test_dataset_not_found(self): - """Test that non-existent dataset returns 404.""" response = client.post( "/analysis/economic-impact", json={ @@ -68,8 +707,11 @@ def test_dataset_not_found(self): assert "not found" in response.json()["detail"].lower() -# Integration tests that require a running database with seeded data -# These are marked with pytest.mark.integration and skipped by default +# 
--------------------------------------------------------------------------- +# Integration tests (require running database with seeded data) +# --------------------------------------------------------------------------- + + @pytest.mark.integration class TestEconomicImpactIntegration: """Integration tests for economic impact analysis. @@ -97,7 +739,6 @@ def uk_dataset_id(self, session: Session): return dataset.id def test_uk_economic_impact_baseline_only(self, uk_dataset_id): - """Test UK economic impact with no reform policy.""" response = client.post( "/analysis/economic-impact", json={ @@ -113,10 +754,8 @@ def test_uk_economic_impact_baseline_only(self, uk_dataset_id): assert "decile_impacts" in data assert "programme_statistics" in data - # Should have 10 deciles assert len(data["decile_impacts"]) == 10 - # Check decile structure for di in data["decile_impacts"]: assert "decile" in di assert "baseline_mean" in di @@ -124,7 +763,6 @@ def test_uk_economic_impact_baseline_only(self, uk_dataset_id): assert "absolute_change" in di def test_simulations_created(self, uk_dataset_id, session: Session): - """Test that simulations are created in the database.""" response = client.post( "/analysis/economic-impact", json={ @@ -135,7 +773,6 @@ def test_simulations_created(self, uk_dataset_id, session: Session): assert response.status_code == 200 data = response.json() - # Check simulations exist in database baseline_sim = session.get(Simulation, data["baseline_simulation_id"]) assert baseline_sim is not None assert baseline_sim.status == "completed" diff --git a/tests/test_analysis_household_impact.py b/tests/test_analysis_household_impact.py new file mode 100644 index 0000000..802562d --- /dev/null +++ b/tests/test_analysis_household_impact.py @@ -0,0 +1,519 @@ +"""Tests for household impact analysis endpoints.""" + +from datetime import date +from uuid import UUID, uuid4 + +from policyengine_api.api.household_analysis import ( + UK_CONFIG, + US_CONFIG, + _ensure_list, + 
_extract_value, + _format_date, + compute_entity_diff, + compute_entity_list_diff, + compute_household_impact, + compute_variable_diff, + get_calculator, + get_country_config, +) +from policyengine_api.models import Report, Simulation, SimulationType +from test_fixtures.fixtures_household_analysis import ( + SAMPLE_UK_BASELINE_RESULT, + SAMPLE_UK_REFORM_RESULT, + SAMPLE_US_BASELINE_RESULT, + SAMPLE_US_REFORM_RESULT, + create_household_for_analysis, + create_policy, + setup_uk_model_and_version, + setup_us_model_and_version, +) + +# --------------------------------------------------------------------------- +# Unit tests for helper functions +# --------------------------------------------------------------------------- + + +class TestEnsureList: + """Tests for _ensure_list helper.""" + + def test_none_returns_empty_list(self): + assert _ensure_list(None) == [] + + def test_list_returns_same_list(self): + input_list = [1, 2, 3] + assert _ensure_list(input_list) == input_list + + def test_dict_wrapped_in_list(self): + input_dict = {"key": "value"} + result = _ensure_list(input_dict) + assert result == [input_dict] + + def test_empty_list_returns_empty_list(self): + assert _ensure_list([]) == [] + + +class TestExtractValue: + """Tests for _extract_value helper.""" + + def test_dict_with_value_key(self): + assert _extract_value({"value": 100}) == 100 + + def test_dict_without_value_key(self): + assert _extract_value({"other": 100}) is None + + def test_non_dict_returns_as_is(self): + assert _extract_value(100) == 100 + assert _extract_value("string") == "string" + assert _extract_value([1, 2]) == [1, 2] + + +class TestFormatDate: + """Tests for _format_date helper.""" + + def test_none_returns_none(self): + assert _format_date(None) is None + + def test_date_object_formatted(self): + d = date(2024, 1, 15) + assert _format_date(d) == "2024-01-15" + + def test_string_returns_string(self): + assert _format_date("2024-01-15") == "2024-01-15" + + +class 
TestComputeVariableDiff: + """Tests for compute_variable_diff helper.""" + + def test_numeric_values_return_diff(self): + result = compute_variable_diff(100, 150) + assert result == {"baseline": 100, "reform": 150, "change": 50} + + def test_negative_change(self): + result = compute_variable_diff(150, 100) + assert result == {"baseline": 150, "reform": 100, "change": -50} + + def test_float_values(self): + result = compute_variable_diff(100.5, 200.5) + assert result == {"baseline": 100.5, "reform": 200.5, "change": 100.0} + + def test_non_numeric_baseline_returns_none(self): + assert compute_variable_diff("string", 100) is None + + def test_non_numeric_reform_returns_none(self): + assert compute_variable_diff(100, "string") is None + + def test_both_non_numeric_returns_none(self): + assert compute_variable_diff("a", "b") is None + + +class TestComputeEntityDiff: + """Tests for compute_entity_diff helper.""" + + def test_computes_diff_for_numeric_keys(self): + baseline = {"income": 1000, "tax": 200, "name": "John"} + reform = {"income": 1000, "tax": 150, "name": "John"} + result = compute_entity_diff(baseline, reform) + + assert "income" in result + assert result["income"]["change"] == 0 + assert "tax" in result + assert result["tax"]["change"] == -50 + assert "name" not in result + + def test_missing_key_in_reform_skipped(self): + baseline = {"income": 1000, "tax": 200} + reform = {"income": 1000} + result = compute_entity_diff(baseline, reform) + + assert "income" in result + assert "tax" not in result + + def test_empty_entities(self): + assert compute_entity_diff({}, {}) == {} + + +class TestComputeEntityListDiff: + """Tests for compute_entity_list_diff helper.""" + + def test_computes_diff_for_each_pair(self): + baseline_list = [{"income": 100}, {"income": 200}] + reform_list = [{"income": 120}, {"income": 180}] + result = compute_entity_list_diff(baseline_list, reform_list) + + assert len(result) == 2 + assert result[0]["income"]["change"] == 20 + assert 
result[1]["income"]["change"] == -20 + + def test_empty_lists(self): + assert compute_entity_list_diff([], []) == [] + + +class TestComputeHouseholdImpact: + """Tests for compute_household_impact helper.""" + + def test_uk_household_impact(self): + result = compute_household_impact( + SAMPLE_UK_BASELINE_RESULT, + SAMPLE_UK_REFORM_RESULT, + UK_CONFIG, + ) + + assert "person" in result + assert "benunit" in result + assert "household" in result + + # Check person income_tax changed + person_diff = result["person"][0] + assert "income_tax" in person_diff + assert person_diff["income_tax"]["baseline"] == 4500.0 + assert person_diff["income_tax"]["reform"] == 4000.0 + assert person_diff["income_tax"]["change"] == -500.0 + + def test_us_household_impact(self): + result = compute_household_impact( + SAMPLE_US_BASELINE_RESULT, + SAMPLE_US_REFORM_RESULT, + US_CONFIG, + ) + + assert "person" in result + assert "tax_unit" in result + assert "spm_unit" in result + assert "family" in result + assert "marital_unit" in result + assert "household" in result + + # Check person income_tax changed + person_diff = result["person"][0] + assert person_diff["income_tax"]["change"] == -500.0 + + def test_missing_entity_skipped(self): + baseline = {"person": [{"income": 100}]} + reform = {"person": [{"income": 120}]} + result = compute_household_impact(baseline, reform, UK_CONFIG) + + assert "person" in result + assert "benunit" not in result + assert "household" not in result + + +class TestGetCountryConfig: + """Tests for get_country_config helper.""" + + def test_uk_model_returns_uk_config(self): + config = get_country_config("policyengine_uk") + assert config == UK_CONFIG + assert config.name == "uk" + assert "benunit" in config.entity_types + + def test_us_model_returns_us_config(self): + config = get_country_config("policyengine_us") + assert config == US_CONFIG + assert config.name == "us" + assert "tax_unit" in config.entity_types + + def test_unknown_model_defaults_to_us(self): + 
config = get_country_config("unknown_model") + assert config == US_CONFIG + + +class TestGetCalculator: + """Tests for get_calculator helper.""" + + def test_uk_model_returns_uk_calculator(self): + from policyengine_api.api.household_analysis import calculate_uk_household + + calc = get_calculator("policyengine_uk") + assert calc == calculate_uk_household + + def test_us_model_returns_us_calculator(self): + from policyengine_api.api.household_analysis import calculate_us_household + + calc = get_calculator("policyengine_us") + assert calc == calculate_us_household + + def test_unknown_model_defaults_to_us(self): + from policyengine_api.api.household_analysis import calculate_us_household + + calc = get_calculator("unknown_model") + assert calc == calculate_us_household + + +# --------------------------------------------------------------------------- +# Validation tests (no database required beyond session fixture) +# --------------------------------------------------------------------------- + + +class TestHouseholdImpactValidation: + """Tests for request validation.""" + + def test_missing_household_id(self, client): + """Test that missing household_id returns 422.""" + response = client.post( + "/analysis/household-impact", + json={}, + ) + assert response.status_code == 422 + + def test_invalid_uuid(self, client): + """Test that invalid UUID returns 422.""" + response = client.post( + "/analysis/household-impact", + json={ + "household_id": "not-a-uuid", + }, + ) + assert response.status_code == 422 + + +# --------------------------------------------------------------------------- +# 404 tests +# --------------------------------------------------------------------------- + + +class TestHouseholdImpactNotFound: + """Tests for 404 responses.""" + + def test_household_not_found(self, client, session): + """Test that non-existent household returns 404.""" + # Need model for the model version lookup + setup_uk_model_and_version(session) + + response = client.post( + 
"/analysis/household-impact", + json={ + "household_id": str(uuid4()), + }, + ) + assert response.status_code == 404 + assert "not found" in response.json()["detail"].lower() + + def test_policy_not_found(self, client, session): + """Test that non-existent policy returns 404.""" + setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + + response = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + "policy_id": str(uuid4()), + }, + ) + assert response.status_code == 404 + assert "not found" in response.json()["detail"].lower() + + def test_get_report_not_found(self, client): + """Test that GET with non-existent report_id returns 404.""" + response = client.get(f"/analysis/household-impact/{uuid4()}") + assert response.status_code == 404 + + +# --------------------------------------------------------------------------- +# Record creation tests +# --------------------------------------------------------------------------- + + +class TestHouseholdImpactRecordCreation: + """Tests for correct record creation.""" + + def test_single_run_creates_one_simulation(self, client, session): + """Single run (no policy_id) creates one simulation.""" + _, version = setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + + response = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + }, + ) + # May fail during calculation since policyengine not available, + # but should create records + data = response.json() + assert "report_id" in data + assert data["report_type"] == "household_single" + assert data["baseline_simulation"] is not None + assert data["reform_simulation"] is None + + def test_comparison_creates_two_simulations(self, client, session): + """Comparison (with policy_id) creates two simulations.""" + model, version = setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + policy = 
create_policy(session, model.id) + + response = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + "policy_id": str(policy.id), + }, + ) + data = response.json() + assert "report_id" in data + assert data["report_type"] == "household_comparison" + assert data["baseline_simulation"] is not None + assert data["reform_simulation"] is not None + + def test_simulation_type_is_household(self, client, session): + """Created simulations have simulation_type=HOUSEHOLD.""" + _, version = setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + + response = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + }, + ) + data = response.json() + + # Check simulation in database (convert string to UUID for query) + sim_id = UUID(data["baseline_simulation"]["id"]) + sim = session.get(Simulation, sim_id) + assert sim is not None + assert sim.simulation_type == SimulationType.HOUSEHOLD + assert sim.household_id == household.id + assert sim.dataset_id is None + + def test_report_links_simulations(self, client, session): + """Report correctly links baseline and reform simulations.""" + model, version = setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + policy = create_policy(session, model.id) + + response = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + "policy_id": str(policy.id), + }, + ) + data = response.json() + + # Check report in database (convert string to UUID for query) + report = session.get(Report, UUID(data["report_id"])) + assert report is not None + assert report.baseline_simulation_id == UUID(data["baseline_simulation"]["id"]) + assert report.reform_simulation_id == UUID(data["reform_simulation"]["id"]) + assert report.report_type == "household_comparison" + + +# --------------------------------------------------------------------------- +# Deduplication tests +# 
--------------------------------------------------------------------------- + + +class TestHouseholdImpactDeduplication: + """Tests for simulation/report deduplication.""" + + def test_same_request_returns_same_simulation(self, client, session): + """Same household + same parameters returns same simulation ID.""" + _, version = setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + + # First request + response1 = client.post( + "/analysis/household-impact", + json={"household_id": str(household.id)}, + ) + data1 = response1.json() + + # Second request with same parameters + response2 = client.post( + "/analysis/household-impact", + json={"household_id": str(household.id)}, + ) + data2 = response2.json() + + # Should return same IDs + assert data1["report_id"] == data2["report_id"] + assert data1["baseline_simulation"]["id"] == data2["baseline_simulation"]["id"] + + def test_different_policy_creates_different_simulation(self, client, session): + """Different policy creates different simulation.""" + model, version = setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + policy1 = create_policy(session, model.id, name="Policy 1") + policy2 = create_policy(session, model.id, name="Policy 2") + + # Request with policy1 + response1 = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + "policy_id": str(policy1.id), + }, + ) + data1 = response1.json() + + # Request with policy2 + response2 = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + "policy_id": str(policy2.id), + }, + ) + data2 = response2.json() + + # Reports should be different + assert data1["report_id"] != data2["report_id"] + # Reform simulations should be different + assert data1["reform_simulation"]["id"] != data2["reform_simulation"]["id"] + # Baseline simulations should be the same (same household, no policy) + assert 
data1["baseline_simulation"]["id"] == data2["baseline_simulation"]["id"] + + +# --------------------------------------------------------------------------- +# GET endpoint tests +# --------------------------------------------------------------------------- + + +class TestGetHouseholdImpact: + """Tests for GET /analysis/household-impact/{report_id}.""" + + def test_get_returns_report_data(self, client, session): + """GET returns report with simulation info.""" + _, version = setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + + # Create report via POST + post_response = client.post( + "/analysis/household-impact", + json={"household_id": str(household.id)}, + ) + report_id = post_response.json()["report_id"] + + # GET the report + get_response = client.get(f"/analysis/household-impact/{report_id}") + assert get_response.status_code == 200 + + data = get_response.json() + assert data["report_id"] == report_id + assert data["report_type"] == "household_single" + assert data["baseline_simulation"] is not None + + +# --------------------------------------------------------------------------- +# US household tests +# --------------------------------------------------------------------------- + + +class TestUSHouseholdImpact: + """Tests specific to US households.""" + + def test_us_household_creates_simulation(self, client, session): + """US household creates simulation with correct model.""" + _, version = setup_us_model_and_version(session) + household = create_household_for_analysis( + session, tax_benefit_model_name="policyengine_us" + ) + + response = client.post( + "/analysis/household-impact", + json={"household_id": str(household.id)}, + ) + data = response.json() + assert "report_id" in data + assert data["baseline_simulation"] is not None diff --git a/tests/test_analysis_options.py b/tests/test_analysis_options.py new file mode 100644 index 0000000..d1b8b1b --- /dev/null +++ b/tests/test_analysis_options.py @@ -0,0 +1,132 
@@ +"""Tests for GET /analysis/options endpoint.""" + +from policyengine_api.api.module_registry import ( + MODULE_REGISTRY, + get_modules_for_country, +) + + +class TestAnalysisOptions: + """Tests for the /analysis/options endpoint.""" + + def test_returns_all_modules(self, client): + response = client.get("/analysis/options") + assert response.status_code == 200 + data = response.json() + assert len(data) == len(MODULE_REGISTRY) + + def test_response_shape(self, client): + response = client.get("/analysis/options") + data = response.json() + for item in data: + assert "name" in item + assert "label" in item + assert "description" in item + assert "response_fields" in item + assert isinstance(item["response_fields"], list) + + def test_all_names_are_strings(self, client): + response = client.get("/analysis/options") + for item in response.json(): + assert isinstance(item["name"], str) + assert len(item["name"]) > 0 + + def test_all_labels_are_non_empty(self, client): + response = client.get("/analysis/options") + for item in response.json(): + assert isinstance(item["label"], str) + assert len(item["label"]) > 0 + + def test_all_descriptions_are_non_empty(self, client): + response = client.get("/analysis/options") + for item in response.json(): + assert isinstance(item["description"], str) + assert len(item["description"]) > 0 + + def test_all_response_fields_are_non_empty_lists(self, client): + response = client.get("/analysis/options") + for item in response.json(): + assert len(item["response_fields"]) > 0 + for field in item["response_fields"]: + assert isinstance(field, str) + + def test_filter_by_uk(self, client): + response = client.get("/analysis/options?country=uk") + assert response.status_code == 200 + data = response.json() + names = [m["name"] for m in data] + assert "constituency" in names + assert "local_authority" in names + assert "wealth_decile" in names + assert "congressional_district" not in names + + def test_filter_by_us(self, client): + 
response = client.get("/analysis/options?country=us") + assert response.status_code == 200 + data = response.json() + names = [m["name"] for m in data] + assert "congressional_district" in names + assert "constituency" not in names + assert "local_authority" not in names + assert "wealth_decile" not in names + + def test_uk_count_matches_registry(self, client): + response = client.get("/analysis/options?country=uk") + data = response.json() + expected = len(get_modules_for_country("uk")) + assert len(data) == expected + + def test_us_count_matches_registry(self, client): + response = client.get("/analysis/options?country=us") + data = response.json() + expected = len(get_modules_for_country("us")) + assert len(data) == expected + + def test_shared_modules_in_both_countries(self, client): + uk_resp = client.get("/analysis/options?country=uk") + us_resp = client.get("/analysis/options?country=us") + uk_names = {m["name"] for m in uk_resp.json()} + us_names = {m["name"] for m in us_resp.json()} + for shared in [ + "decile", + "poverty", + "inequality", + "budget_summary", + "intra_decile", + "program_statistics", + ]: + assert shared in uk_names + assert shared in us_names + + def test_unknown_country_returns_empty(self, client): + response = client.get("/analysis/options?country=fr") + assert response.status_code == 200 + assert response.json() == [] + + def test_program_statistics_has_two_response_fields(self, client): + response = client.get("/analysis/options") + ps_module = next( + m for m in response.json() if m["name"] == "program_statistics" + ) + assert "program_statistics" in ps_module["response_fields"] + assert "detailed_budget" in ps_module["response_fields"] + + def test_wealth_decile_has_two_response_fields(self, client): + response = client.get("/analysis/options?country=uk") + wd_module = next(m for m in response.json() if m["name"] == "wealth_decile") + assert "wealth_decile" in wd_module["response_fields"] + assert "intra_wealth_decile" in 
wd_module["response_fields"] + + def test_no_country_param_returns_all(self, client): + all_resp = client.get("/analysis/options") + data = all_resp.json() + returned_names = {m["name"] for m in data} + assert returned_names == set(MODULE_REGISTRY.keys()) + + def test_response_matches_registry_data(self, client): + response = client.get("/analysis/options") + for item in response.json(): + registry_mod = MODULE_REGISTRY[item["name"]] + assert item["label"] == registry_mod.label + assert item["description"] == registry_mod.description + assert item["response_fields"] == list(registry_mod.response_fields) diff --git a/tests/test_computation_modules.py b/tests/test_computation_modules.py new file mode 100644 index 0000000..b316296 --- /dev/null +++ b/tests/test_computation_modules.py @@ -0,0 +1,295 @@ +"""Tests for the composable computation module dispatch system.""" + +import inspect +from unittest.mock import MagicMock +from uuid import uuid4 + +from policyengine_api.api import computation_modules as cm +from policyengine_api.api.computation_modules import ( + UK_MODULE_DISPATCH, + US_MODULE_DISPATCH, + run_modules, +) +from policyengine_api.api.module_registry import MODULE_REGISTRY + + +class TestDispatchTables: + """Tests for UK_MODULE_DISPATCH and US_MODULE_DISPATCH.""" + + def test_uk_dispatch_keys_match_registry(self): + """Every UK dispatch key should be a valid module in the registry.""" + for key in UK_MODULE_DISPATCH: + assert key in MODULE_REGISTRY, f"UK dispatch key {key!r} not in registry" + + def test_us_dispatch_keys_match_registry(self): + """Every US dispatch key should be a valid module in the registry.""" + for key in US_MODULE_DISPATCH: + assert key in MODULE_REGISTRY, f"US dispatch key {key!r} not in registry" + + def test_uk_dispatch_covers_uk_modules(self): + """UK dispatch should have an entry for every UK-applicable module.""" + uk_module_names = { + name for name, mod in MODULE_REGISTRY.items() if "uk" in mod.countries + } + assert 
set(UK_MODULE_DISPATCH.keys()) == uk_module_names + + def test_us_dispatch_covers_us_modules(self): + """US dispatch should have an entry for every US-applicable module.""" + us_module_names = { + name for name, mod in MODULE_REGISTRY.items() if "us" in mod.countries + } + assert set(US_MODULE_DISPATCH.keys()) == us_module_names + + def test_all_dispatch_values_are_callable(self): + for fn in UK_MODULE_DISPATCH.values(): + assert callable(fn) + for fn in US_MODULE_DISPATCH.values(): + assert callable(fn) + + def test_uk_dispatch_has_9_entries(self): + assert len(UK_MODULE_DISPATCH) == 9 + + def test_us_dispatch_has_7_entries(self): + assert len(US_MODULE_DISPATCH) == 7 + + +class TestSharedModuleFunctions: + """Tests that shared modules reference the same function objects.""" + + def test_decile_function_shared_between_uk_and_us(self): + assert UK_MODULE_DISPATCH["decile"] is US_MODULE_DISPATCH["decile"] + assert UK_MODULE_DISPATCH["decile"] is cm.compute_decile_module + + def test_intra_decile_function_shared_between_uk_and_us(self): + assert UK_MODULE_DISPATCH["intra_decile"] is US_MODULE_DISPATCH["intra_decile"] + assert UK_MODULE_DISPATCH["intra_decile"] is cm.compute_intra_decile_module + + +class TestCountrySpecificFunctions: + """Tests that UK/US specific modules use the correct country-specific functions.""" + + def test_uk_program_statistics(self): + assert ( + UK_MODULE_DISPATCH["program_statistics"] + is cm.compute_program_statistics_module_uk + ) + + def test_us_program_statistics(self): + assert ( + US_MODULE_DISPATCH["program_statistics"] + is cm.compute_program_statistics_module_us + ) + + def test_uk_poverty(self): + assert UK_MODULE_DISPATCH["poverty"] is cm.compute_poverty_module_uk + + def test_us_poverty(self): + assert US_MODULE_DISPATCH["poverty"] is cm.compute_poverty_module_us + + def test_uk_inequality(self): + assert UK_MODULE_DISPATCH["inequality"] is cm.compute_inequality_module_uk + + def test_us_inequality(self): + assert 
US_MODULE_DISPATCH["inequality"] is cm.compute_inequality_module_us + + def test_uk_budget_summary(self): + assert ( + UK_MODULE_DISPATCH["budget_summary"] is cm.compute_budget_summary_module_uk + ) + + def test_us_budget_summary(self): + assert ( + US_MODULE_DISPATCH["budget_summary"] is cm.compute_budget_summary_module_us + ) + + def test_constituency_is_uk_only(self): + assert UK_MODULE_DISPATCH["constituency"] is cm.compute_constituency_module + assert "constituency" not in US_MODULE_DISPATCH + + def test_local_authority_is_uk_only(self): + assert ( + UK_MODULE_DISPATCH["local_authority"] is cm.compute_local_authority_module + ) + assert "local_authority" not in US_MODULE_DISPATCH + + def test_wealth_decile_is_uk_only(self): + assert UK_MODULE_DISPATCH["wealth_decile"] is cm.compute_wealth_decile_module + assert "wealth_decile" not in US_MODULE_DISPATCH + + def test_congressional_district_is_us_only(self): + assert ( + US_MODULE_DISPATCH["congressional_district"] + is cm.compute_congressional_district_module + ) + assert "congressional_district" not in UK_MODULE_DISPATCH + + +class TestModuleFunctionSignatures: + """Tests that all module functions share the expected signature pattern. + + Modules use a common 7-param signature pattern: + (pe_baseline_sim, pe_reform_sim, baseline_sim_id, reform_sim_id, + report_id, session, **kwargs) -> None + + run_modules() passes country_id as a kwarg. Modules that need it (e.g. + compute_decile_module) accept it explicitly; others accept **_kwargs. 
+ """ + + _BASE_PARAMS = [ + "pe_baseline_sim", + "pe_reform_sim", + "baseline_sim_id", + "reform_sim_id", + "report_id", + "session", + ] + # 7th param can be either explicit country_id or **_kwargs + _VALID_7TH_PARAMS = {"country_id", "_kwargs"} + + def _get_all_unique_functions(self): + """Collect all unique module functions from both dispatch tables.""" + seen = set() + fns = [] + for fn in list(UK_MODULE_DISPATCH.values()) + list(US_MODULE_DISPATCH.values()): + if id(fn) not in seen: + seen.add(id(fn)) + fns.append(fn) + return fns + + def test_all_functions_have_7_parameters(self): + for fn in self._get_all_unique_functions(): + sig = inspect.signature(fn) + assert len(sig.parameters) == 7, ( + f"{fn.__name__} has {len(sig.parameters)} params, expected 7" + ) + + def test_all_functions_have_expected_param_names(self): + for fn in self._get_all_unique_functions(): + sig = inspect.signature(fn) + param_names = list(sig.parameters.keys()) + # First 6 params must match exactly + assert param_names[:6] == self._BASE_PARAMS, ( + f"{fn.__name__} first 6 params {param_names[:6]} != {self._BASE_PARAMS}" + ) + # 7th param can be country_id or _kwargs + assert param_names[6] in self._VALID_7TH_PARAMS, ( + f"{fn.__name__} 7th param '{param_names[6]}' not in {self._VALID_7TH_PARAMS}" + ) + + def test_all_functions_return_none(self): + for fn in self._get_all_unique_functions(): + sig = inspect.signature(fn) + # `from __future__ import annotations` makes annotations strings + assert sig.return_annotation in (None, "None", inspect.Parameter.empty), ( + f"{fn.__name__} return annotation is {sig.return_annotation!r}, expected None" + ) + + +class TestRunModules: + """Tests for the run_modules dispatch helper.""" + + def _make_mock_dispatch(self, names): + """Create a dispatch dict with mock functions.""" + return {name: MagicMock(name=f"compute_{name}") for name in names} + + def test_runs_all_when_modules_is_none(self): + dispatch = self._make_mock_dispatch(["a", "b", "c"]) 
+ session = MagicMock() + ids = [uuid4() for _ in range(3)] + + run_modules(dispatch, None, "bl", "rf", ids[0], ids[1], ids[2], session) + + for fn in dispatch.values(): + fn.assert_called_once_with( + "bl", "rf", ids[0], ids[1], ids[2], session, country_id="" + ) + + def test_runs_only_requested_modules(self): + dispatch = self._make_mock_dispatch(["a", "b", "c"]) + session = MagicMock() + ids = [uuid4() for _ in range(3)] + + run_modules(dispatch, ["b"], "bl", "rf", ids[0], ids[1], ids[2], session) + + dispatch["a"].assert_not_called() + dispatch["b"].assert_called_once() + dispatch["c"].assert_not_called() + + def test_ignores_unknown_module_names(self): + dispatch = self._make_mock_dispatch(["a"]) + session = MagicMock() + ids = [uuid4() for _ in range(3)] + + # Should not raise + run_modules( + dispatch, ["a", "nonexistent"], "bl", "rf", ids[0], ids[1], ids[2], session + ) + + dispatch["a"].assert_called_once() + + def test_empty_modules_list_runs_nothing(self): + dispatch = self._make_mock_dispatch(["a", "b"]) + session = MagicMock() + ids = [uuid4() for _ in range(3)] + + run_modules(dispatch, [], "bl", "rf", ids[0], ids[1], ids[2], session) + + for fn in dispatch.values(): + fn.assert_not_called() + + def test_preserves_call_order(self): + """Modules should be called in the order they appear in the modules list.""" + call_order = [] + + def make_tracker(name): + def fn(*args, **kwargs): + call_order.append(name) + + return fn + + dispatch = {name: make_tracker(name) for name in ["a", "b", "c"]} + ids = [uuid4() for _ in range(3)] + + run_modules( + dispatch, ["c", "a", "b"], "bl", "rf", ids[0], ids[1], ids[2], MagicMock() + ) + + assert call_order == ["c", "a", "b"] + + def test_none_modules_runs_all_in_dispatch_key_order(self): + """When modules is None, all dispatch entries run in dict-iteration order.""" + call_order = [] + + def make_tracker(name): + def fn(*args, **kwargs): + call_order.append(name) + + return fn + + dispatch = {name: 
make_tracker(name) for name in ["x", "y", "z"]} + ids = [uuid4() for _ in range(3)] + + run_modules(dispatch, None, "bl", "rf", ids[0], ids[1], ids[2], MagicMock()) + + assert call_order == ["x", "y", "z"] + + def test_passes_all_args_correctly(self): + mock_fn = MagicMock() + dispatch = {"test_mod": mock_fn} + session = MagicMock() + bl, rf, b_id, r_id, rep_id = "baseline", "reform", uuid4(), uuid4(), uuid4() + + run_modules(dispatch, ["test_mod"], bl, rf, b_id, r_id, rep_id, session) + + mock_fn.assert_called_once_with( + bl, rf, b_id, r_id, rep_id, session, country_id="" + ) + + def test_duplicate_module_name_runs_twice(self): + dispatch = self._make_mock_dispatch(["a"]) + session = MagicMock() + ids = [uuid4() for _ in range(3)] + + run_modules(dispatch, ["a", "a"], "bl", "rf", ids[0], ids[1], ids[2], session) + + assert dispatch["a"].call_count == 2 diff --git a/tests/test_economic_impact_response.py b/tests/test_economic_impact_response.py new file mode 100644 index 0000000..18036b8 --- /dev/null +++ b/tests/test_economic_impact_response.py @@ -0,0 +1,685 @@ +"""Tests for _build_response() and _safe_float() in analysis.py. + +Covers all Phase 2 output fields: poverty, inequality, budget_summary, +intra_decile, program_statistics, detailed_budget, and decile_impacts. 
+""" + + +from policyengine_api.api.analysis import _build_response, _safe_float +from policyengine_api.models import ReportStatus +from test_fixtures.fixtures_economic_impact_response import ( + BUDGET_VARIABLES_UK, + INTRA_DECILE_DECILE_COUNT, + SAMPLE_BOTTOM_50_SHARE, + SAMPLE_CONSTITUENCY_COUNT, + SAMPLE_DISTRICT_COUNT, + SAMPLE_GINI, + SAMPLE_INEQUALITY_INCOME_VAR, + SAMPLE_INTRA_WEALTH_DECILE_COUNT, + SAMPLE_LA_COUNT, + SAMPLE_POVERTY_TYPES, + SAMPLE_TOP_1_SHARE, + SAMPLE_TOP_10_SHARE, + SAMPLE_WEALTH_DECILE_COUNT, + UK_PROGRAM_COUNT, + UK_PROGRAMS, + add_budget_summary_records, + add_congressional_district_records, + add_constituency_records, + add_inequality_records, + add_intra_decile_records, + add_intra_wealth_decile_records, + add_local_authority_records, + add_poverty_by_age_records, + add_poverty_records, + add_program_statistics_records, + add_wealth_decile_records, + create_fully_populated_report, + create_report_with_simulations, +) + +# --------------------------------------------------------------------------- +# _safe_float +# --------------------------------------------------------------------------- + + +class TestSafeFloat: + """Tests for the _safe_float helper that sanitizes floats for JSON.""" + + def test__given_normal_float__then_returns_same_value(self): + assert _safe_float(42.5) == 42.5 + + def test__given_none__then_returns_none(self): + assert _safe_float(None) is None + + def test__given_nan__then_returns_none(self): + assert _safe_float(float("nan")) is None + + def test__given_positive_inf__then_returns_none(self): + assert _safe_float(float("inf")) is None + + def test__given_negative_inf__then_returns_none(self): + assert _safe_float(float("-inf")) is None + + def test__given_zero__then_returns_zero(self): + assert _safe_float(0.0) == 0.0 + + def test__given_negative_float__then_returns_same_value(self): + assert _safe_float(-123.456) == -123.456 + + +# --------------------------------------------------------------------------- 
+# _build_response — pending report +# --------------------------------------------------------------------------- + + +class TestBuildResponsePending: + """Tests for _build_response when the report is not yet completed.""" + + def test__given_pending_report__then_all_output_fields_are_none(self, session): + # Given + report, baseline_sim, reform_sim = create_report_with_simulations( + session, status=ReportStatus.PENDING + ) + + # When + response = _build_response(report, baseline_sim, reform_sim, session) + + # Then + assert response.status == ReportStatus.PENDING + assert response.decile_impacts is None + assert response.program_statistics is None + assert response.poverty is None + assert response.inequality is None + assert response.budget_summary is None + assert response.intra_decile is None + assert response.detailed_budget is None + + def test__given_running_report__then_all_output_fields_are_none(self, session): + # Given + report, baseline_sim, reform_sim = create_report_with_simulations( + session, status=ReportStatus.RUNNING + ) + + # When + response = _build_response(report, baseline_sim, reform_sim, session) + + # Then + assert response.status == ReportStatus.RUNNING + assert response.poverty is None + assert response.inequality is None + + +# --------------------------------------------------------------------------- +# _build_response — poverty +# --------------------------------------------------------------------------- + + +class TestBuildResponsePoverty: + """Tests for poverty records in _build_response output.""" + + def test__given_completed_report_with_poverty__then_poverty_list_not_empty( + self, session + ): + # Given + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_poverty_records(session, report, baseline_sim, reform_sim) + + # When + response = _build_response(report, baseline_sim, reform_sim, session) + + # Then + assert response.poverty is not None + assert len(response.poverty) > 0 + + def 
test__given_poverty_records__then_each_has_poverty_type(self, session): + # Given + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_poverty_records(session, report, baseline_sim, reform_sim) + + # When + response = _build_response(report, baseline_sim, reform_sim, session) + + # Then + for p in response.poverty: + assert p.poverty_type in SAMPLE_POVERTY_TYPES + + def test__given_poverty_by_age_records__then_filter_variable_is_set(self, session): + # Given + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_poverty_by_age_records(session, report, baseline_sim, reform_sim) + + # When + response = _build_response(report, baseline_sim, reform_sim, session) + + # Then + assert response.poverty is not None + filter_vars = {p.filter_variable for p in response.poverty} + assert "is_child" in filter_vars + assert "is_adult" in filter_vars + assert "is_SP_age" in filter_vars + + def test__given_poverty_records__then_rate_is_headcount_over_population( + self, session + ): + # Given + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_poverty_records(session, report, baseline_sim, reform_sim) + + # When + response = _build_response(report, baseline_sim, reform_sim, session) + + # Then + for p in response.poverty: + expected_rate = p.headcount / p.total_population + assert abs(p.rate - expected_rate) < 1e-9 + + +# --------------------------------------------------------------------------- +# _build_response — inequality +# --------------------------------------------------------------------------- + + +class TestBuildResponseInequality: + """Tests for inequality records in _build_response output.""" + + def test__given_completed_report_with_inequality__then_two_records(self, session): + # Given — one for baseline, one for reform + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_inequality_records(session, report, baseline_sim, reform_sim) + + # When + 
response = _build_response(report, baseline_sim, reform_sim, session) + + # Then + assert response.inequality is not None + assert len(response.inequality) == 2 + + def test__given_inequality_records__then_gini_matches_input(self, session): + # Given + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_inequality_records(session, report, baseline_sim, reform_sim) + + # When + response = _build_response(report, baseline_sim, reform_sim, session) + + # Then + for ineq in response.inequality: + assert ineq.gini == SAMPLE_GINI + assert ineq.top_10_share == SAMPLE_TOP_10_SHARE + assert ineq.top_1_share == SAMPLE_TOP_1_SHARE + assert ineq.bottom_50_share == SAMPLE_BOTTOM_50_SHARE + + def test__given_inequality_records__then_income_variable_set(self, session): + # Given + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_inequality_records(session, report, baseline_sim, reform_sim) + + # When + response = _build_response(report, baseline_sim, reform_sim, session) + + # Then + for ineq in response.inequality: + assert ineq.income_variable == SAMPLE_INEQUALITY_INCOME_VAR + + +# --------------------------------------------------------------------------- +# _build_response — budget_summary +# --------------------------------------------------------------------------- + + +class TestBuildResponseBudgetSummary: + """Tests for budget_summary records in _build_response output.""" + + def test__given_completed_report_with_budget__then_correct_count(self, session): + # Given + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_budget_summary_records(session, report, baseline_sim, reform_sim) + + # When + response = _build_response(report, baseline_sim, reform_sim, session) + + # Then + assert response.budget_summary is not None + assert len(response.budget_summary) == len(BUDGET_VARIABLES_UK) + + def test__given_budget_records__then_change_equals_reform_minus_baseline( + self, session + ): + 
# Given + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_budget_summary_records(session, report, baseline_sim, reform_sim) + + # When + response = _build_response(report, baseline_sim, reform_sim, session) + + # Then + for b in response.budget_summary: + expected_change = b.reform_total - b.baseline_total + assert abs(b.change - expected_change) < 1e-9 + + def test__given_budget_records__then_variable_names_match_uk_set(self, session): + # Given + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_budget_summary_records(session, report, baseline_sim, reform_sim) + + # When + response = _build_response(report, baseline_sim, reform_sim, session) + + # Then + var_names = {b.variable_name for b in response.budget_summary} + expected_names = {name for name, _ in BUDGET_VARIABLES_UK} + assert var_names == expected_names + + +# --------------------------------------------------------------------------- +# _build_response — intra_decile +# --------------------------------------------------------------------------- + + +class TestBuildResponseIntraDecile: + """Tests for intra_decile records in _build_response output.""" + + def test__given_completed_report_with_intra_decile__then_11_records(self, session): + # Given + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_intra_decile_records(session, report, baseline_sim, reform_sim) + + # When + response = _build_response(report, baseline_sim, reform_sim, session) + + # Then + assert response.intra_decile is not None + assert len(response.intra_decile) == INTRA_DECILE_DECILE_COUNT + + def test__given_intra_decile_records__then_decile_0_present_for_overall( + self, session + ): + # Given + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_intra_decile_records(session, report, baseline_sim, reform_sim) + + # When + response = _build_response(report, baseline_sim, reform_sim, session) + + # Then + 
decile_numbers = {r.decile for r in response.intra_decile} + assert 0 in decile_numbers # overall row + + def test__given_intra_decile_records__then_proportions_sum_to_one(self, session): + # Given + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_intra_decile_records(session, report, baseline_sim, reform_sim) + + # When + response = _build_response(report, baseline_sim, reform_sim, session) + + # Then + for r in response.intra_decile: + total = ( + r.lose_more_than_5pct + + r.lose_less_than_5pct + + r.no_change + + r.gain_less_than_5pct + + r.gain_more_than_5pct + ) + assert abs(total - 1.0) < 1e-9 + + +# --------------------------------------------------------------------------- +# _build_response — program_statistics & detailed_budget +# --------------------------------------------------------------------------- + + +class TestBuildResponseProgramStatistics: + """Tests for program_statistics and detailed_budget in _build_response.""" + + def test__given_completed_report_with_programs__then_correct_count(self, session): + # Given + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_program_statistics_records(session, report, baseline_sim, reform_sim) + + # When + response = _build_response(report, baseline_sim, reform_sim, session) + + # Then + assert response.program_statistics is not None + assert len(response.program_statistics) == UK_PROGRAM_COUNT + + def test__given_uk_programs__then_all_10_programs_present(self, session): + # Given + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_program_statistics_records(session, report, baseline_sim, reform_sim) + + # When + response = _build_response(report, baseline_sim, reform_sim, session) + + # Then + program_names = {s.program_name for s in response.program_statistics} + assert program_names == set(UK_PROGRAMS.keys()) + + def test__given_program_records__then_detailed_budget_has_same_keys(self, session): + # Given + 
report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_program_statistics_records(session, report, baseline_sim, reform_sim) + + # When + response = _build_response(report, baseline_sim, reform_sim, session) + + # Then + assert response.detailed_budget is not None + assert set(response.detailed_budget.keys()) == set(UK_PROGRAMS.keys()) + + def test__given_program_records__then_detailed_budget_has_baseline_reform_difference( + self, session + ): + # Given + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_program_statistics_records(session, report, baseline_sim, reform_sim) + + # When + response = _build_response(report, baseline_sim, reform_sim, session) + + # Then + for prog_name, entry in response.detailed_budget.items(): + assert "baseline" in entry + assert "reform" in entry + assert "difference" in entry + + def test__given_program_records__then_detailed_budget_difference_matches_change( + self, session + ): + # Given + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_program_statistics_records(session, report, baseline_sim, reform_sim) + + # When + response = _build_response(report, baseline_sim, reform_sim, session) + + # Then — difference should equal reform - baseline (from ProgramStatistics.change) + for prog_name, entry in response.detailed_budget.items(): + expected_diff = entry["reform"] - entry["baseline"] + assert abs(entry["difference"] - expected_diff) < 1e-9 + + def test__given_no_program_records__then_detailed_budget_is_empty_dict( + self, session + ): + # Given — completed report with no program statistics + report, baseline_sim, reform_sim = create_report_with_simulations(session) + + # When + response = _build_response(report, baseline_sim, reform_sim, session) + + # Then + assert response.detailed_budget == {} + + def test__given_program_with_nan_values__then_detailed_budget_has_none( + self, session + ): + # Given + from policyengine_api.models import 
ProgramStatistics + + report, baseline_sim, reform_sim = create_report_with_simulations(session) + rec = ProgramStatistics( + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + program_name="test_program", + entity="person", + is_tax=True, + baseline_total=float("nan"), + reform_total=float("nan"), + change=float("nan"), + baseline_count=0.0, + reform_count=0.0, + winners=0.0, + losers=0.0, + ) + session.add(rec) + session.commit() + + # When + response = _build_response(report, baseline_sim, reform_sim, session) + + # Then + assert response.detailed_budget["test_program"]["baseline"] is None + assert response.detailed_budget["test_program"]["reform"] is None + assert response.detailed_budget["test_program"]["difference"] is None + + +# --------------------------------------------------------------------------- +# _build_response — fully populated report +# --------------------------------------------------------------------------- + + +class TestBuildResponseFullyPopulated: + """Tests for _build_response with all output tables populated.""" + + def test__given_fully_populated_report__then_all_fields_present(self, session): + # Given + report, baseline_sim, reform_sim = create_fully_populated_report(session) + + # When + response = _build_response(report, baseline_sim, reform_sim, session) + + # Then + assert response.status == ReportStatus.COMPLETED + assert response.poverty is not None + assert response.inequality is not None + assert response.budget_summary is not None + assert response.intra_decile is not None + assert response.program_statistics is not None + assert response.detailed_budget is not None + assert response.congressional_district_impact is not None + assert response.constituency_impact is not None + assert response.local_authority_impact is not None + assert response.wealth_decile is not None + assert response.intra_wealth_decile is not None + + def 
test__given_fully_populated_report__then_report_id_matches(self, session): + # Given + report, baseline_sim, reform_sim = create_fully_populated_report(session) + + # When + response = _build_response(report, baseline_sim, reform_sim, session) + + # Then + assert response.report_id == report.id + + def test__given_fully_populated_report__then_simulation_ids_match(self, session): + # Given + report, baseline_sim, reform_sim = create_fully_populated_report(session) + + # When + response = _build_response(report, baseline_sim, reform_sim, session) + + # Then + assert response.baseline_simulation.id == baseline_sim.id + assert response.reform_simulation.id == reform_sim.id + + +# --------------------------------------------------------------------------- +# _build_response — congressional_district_impact +# --------------------------------------------------------------------------- + + +class TestBuildResponseCongressionalDistrict: + """Tests for congressional_district_impact in _build_response output.""" + + def test__given_district_records__then_correct_count(self, session): + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_congressional_district_records(session, report, baseline_sim, reform_sim) + + response = _build_response(report, baseline_sim, reform_sim, session) + + assert response.congressional_district_impact is not None + assert len(response.congressional_district_impact) == SAMPLE_DISTRICT_COUNT + + def test__given_no_district_records__then_field_is_none(self, session): + report, baseline_sim, reform_sim = create_report_with_simulations(session) + + response = _build_response(report, baseline_sim, reform_sim, session) + + assert response.congressional_district_impact is None + + def test__given_district_records__then_geoid_fields_populated(self, session): + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_congressional_district_records(session, report, baseline_sim, reform_sim) + + response = 
_build_response(report, baseline_sim, reform_sim, session) + + for d in response.congressional_district_impact: + assert d.district_geoid > 0 + assert d.state_fips >= 0 + assert d.population > 0 + + +# --------------------------------------------------------------------------- +# _build_response — constituency_impact +# --------------------------------------------------------------------------- + + +class TestBuildResponseConstituency: + """Tests for constituency_impact in _build_response output.""" + + def test__given_constituency_records__then_correct_count(self, session): + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_constituency_records(session, report, baseline_sim, reform_sim) + + response = _build_response(report, baseline_sim, reform_sim, session) + + assert response.constituency_impact is not None + assert len(response.constituency_impact) == SAMPLE_CONSTITUENCY_COUNT + + def test__given_no_constituency_records__then_field_is_none(self, session): + report, baseline_sim, reform_sim = create_report_with_simulations(session) + + response = _build_response(report, baseline_sim, reform_sim, session) + + assert response.constituency_impact is None + + def test__given_constituency_records__then_code_and_name_populated(self, session): + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_constituency_records(session, report, baseline_sim, reform_sim) + + response = _build_response(report, baseline_sim, reform_sim, session) + + for c in response.constituency_impact: + assert c.constituency_code is not None + assert c.constituency_name is not None + assert c.population > 0 + + +# --------------------------------------------------------------------------- +# _build_response — local_authority_impact +# --------------------------------------------------------------------------- + + +class TestBuildResponseLocalAuthority: + """Tests for local_authority_impact in _build_response output.""" + + def 
test__given_la_records__then_correct_count(self, session): + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_local_authority_records(session, report, baseline_sim, reform_sim) + + response = _build_response(report, baseline_sim, reform_sim, session) + + assert response.local_authority_impact is not None + assert len(response.local_authority_impact) == SAMPLE_LA_COUNT + + def test__given_no_la_records__then_field_is_none(self, session): + report, baseline_sim, reform_sim = create_report_with_simulations(session) + + response = _build_response(report, baseline_sim, reform_sim, session) + + assert response.local_authority_impact is None + + def test__given_la_records__then_code_and_name_populated(self, session): + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_local_authority_records(session, report, baseline_sim, reform_sim) + + response = _build_response(report, baseline_sim, reform_sim, session) + + for la in response.local_authority_impact: + assert la.local_authority_code is not None + assert la.local_authority_name is not None + assert la.population > 0 + + +# --------------------------------------------------------------------------- +# _build_response — wealth_decile +# --------------------------------------------------------------------------- + + +class TestBuildResponseWealthDecile: + """Tests for wealth_decile in _build_response output.""" + + def test__given_wealth_decile_records__then_correct_count(self, session): + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_wealth_decile_records(session, report, baseline_sim, reform_sim) + + response = _build_response(report, baseline_sim, reform_sim, session) + + assert response.wealth_decile is not None + assert len(response.wealth_decile) == SAMPLE_WEALTH_DECILE_COUNT + + def test__given_no_wealth_decile_records__then_field_is_none(self, session): + report, baseline_sim, reform_sim = 
create_report_with_simulations(session) + + response = _build_response(report, baseline_sim, reform_sim, session) + + assert response.wealth_decile is None + + def test__given_wealth_decile_records__then_income_variable_is_wealth( + self, session + ): + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_wealth_decile_records(session, report, baseline_sim, reform_sim) + + response = _build_response(report, baseline_sim, reform_sim, session) + + for d in response.wealth_decile: + assert d.income_variable == "household_wealth_decile" + + +# --------------------------------------------------------------------------- +# _build_response — intra_wealth_decile +# --------------------------------------------------------------------------- + + +class TestBuildResponseIntraWealthDecile: + """Tests for intra_wealth_decile in _build_response output.""" + + def test__given_intra_wealth_records__then_correct_count(self, session): + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_intra_wealth_decile_records(session, report, baseline_sim, reform_sim) + + response = _build_response(report, baseline_sim, reform_sim, session) + + assert response.intra_wealth_decile is not None + assert len(response.intra_wealth_decile) == SAMPLE_INTRA_WEALTH_DECILE_COUNT + + def test__given_no_intra_wealth_records__then_field_is_none(self, session): + report, baseline_sim, reform_sim = create_report_with_simulations(session) + + response = _build_response(report, baseline_sim, reform_sim, session) + + assert response.intra_wealth_decile is None + + def test__given_intra_wealth_records__then_decile_type_is_wealth(self, session): + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_intra_wealth_decile_records(session, report, baseline_sim, reform_sim) + + response = _build_response(report, baseline_sim, reform_sim, session) + + for r in response.intra_wealth_decile: + assert r.decile_type == "wealth" + + def 
test__given_intra_wealth_records__then_overall_row_present(self, session): + report, baseline_sim, reform_sim = create_report_with_simulations(session) + add_intra_wealth_decile_records(session, report, baseline_sim, reform_sim) + + response = _build_response(report, baseline_sim, reform_sim, session) + + decile_numbers = {r.decile for r in response.intra_wealth_decile} + assert 0 in decile_numbers diff --git a/tests/test_economy_custom.py b/tests/test_economy_custom.py new file mode 100644 index 0000000..fda65fa --- /dev/null +++ b/tests/test_economy_custom.py @@ -0,0 +1,315 @@ +"""Tests for POST /analysis/economy-custom endpoint.""" + +from uuid import uuid4 + +from policyengine_api.api.analysis import ( + EconomicImpactResponse, + SimulationInfo, + _build_filtered_response, +) +from policyengine_api.models import ReportStatus, SimulationStatus + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_stub_response(**overrides) -> EconomicImpactResponse: + """Build a minimal EconomicImpactResponse for testing.""" + defaults = dict( + report_id=uuid4(), + status=ReportStatus.COMPLETED, + baseline_simulation=SimulationInfo( + id=uuid4(), status=SimulationStatus.COMPLETED + ), + reform_simulation=SimulationInfo(id=uuid4(), status=SimulationStatus.COMPLETED), + region=None, + error_message=None, + decile_impacts=[{"fake": "decile"}], + program_statistics=[{"fake": "program"}], + poverty=[{"fake": "poverty"}], + inequality=[{"fake": "inequality"}], + budget_summary=[{"fake": "budget"}], + intra_decile=[{"fake": "intra"}], + detailed_budget={"prog": {"baseline": 1.0}}, + congressional_district_impact=[{"fake": "district"}], + constituency_impact=[{"fake": "constituency"}], + local_authority_impact=[{"fake": "la"}], + wealth_decile=[{"fake": "wealth"}], + intra_wealth_decile=[{"fake": "intra_wealth"}], + ) + defaults.update(overrides) + return 
EconomicImpactResponse.model_construct(**defaults) + + +# All data fields that can be nullified by module filtering +_DATA_FIELDS = { + "decile_impacts", + "program_statistics", + "poverty", + "inequality", + "budget_summary", + "intra_decile", + "detailed_budget", + "congressional_district_impact", + "constituency_impact", + "local_authority_impact", + "wealth_decile", + "intra_wealth_decile", +} + +_ALWAYS_INCLUDED = { + "report_id", + "status", + "baseline_simulation", + "reform_simulation", + "region", + "error_message", +} + + +# --------------------------------------------------------------------------- +# Unit tests for _build_filtered_response +# --------------------------------------------------------------------------- + + +class TestBuildFilteredResponse: + """Tests for response filtering by module list.""" + + def test_single_module_keeps_only_its_fields(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["poverty"]) + assert filtered.poverty is not None + assert filtered.decile_impacts is None + assert filtered.program_statistics is None + assert filtered.inequality is None + + def test_multiple_modules(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["decile", "inequality"]) + assert filtered.decile_impacts is not None + assert filtered.inequality is not None + assert filtered.poverty is None + assert filtered.congressional_district_impact is None + + def test_program_statistics_includes_detailed_budget(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["program_statistics"]) + assert filtered.program_statistics is not None + assert filtered.detailed_budget is not None + assert filtered.decile_impacts is None + + def test_always_included_fields_preserved(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["poverty"]) + assert filtered.report_id == resp.report_id + assert filtered.status == resp.status + assert 
filtered.baseline_simulation is not None + assert filtered.reform_simulation is not None + + def test_region_always_included(self): + from policyengine_api.api.analysis import RegionInfo + + region = RegionInfo( + code="uk", + label="United Kingdom", + region_type="national", + requires_filter=False, + ) + resp = _make_stub_response(region=region) + filtered = _build_filtered_response(resp, ["decile"]) + assert filtered.region is not None + assert filtered.region.code == "uk" + + def test_error_message_always_included(self): + resp = _make_stub_response(error_message="something went wrong") + filtered = _build_filtered_response(resp, ["decile"]) + assert filtered.error_message == "something went wrong" + + def test_empty_modules_nullifies_all_data_fields(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, []) + for field in _DATA_FIELDS: + assert getattr(filtered, field) is None, f"{field} should be None" + assert filtered.report_id == resp.report_id + + def test_empty_modules_preserves_always_included(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, []) + for field in _ALWAYS_INCLUDED: + original = getattr(resp, field) + assert getattr(filtered, field) == original, f"{field} should be preserved" + + def test_wealth_decile_keeps_both_fields(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["wealth_decile"]) + assert filtered.wealth_decile is not None + assert filtered.intra_wealth_decile is not None + assert filtered.decile_impacts is None + assert filtered.intra_decile is None + + def test_intra_decile_keeps_only_intra_decile(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["intra_decile"]) + assert filtered.intra_decile is not None + assert filtered.decile_impacts is None + assert filtered.intra_wealth_decile is None + + def test_congressional_district_keeps_only_district_impact(self): + resp = _make_stub_response() + filtered = 
_build_filtered_response(resp, ["congressional_district"]) + assert filtered.congressional_district_impact is not None + assert filtered.constituency_impact is None + assert filtered.local_authority_impact is None + + def test_constituency_keeps_only_constituency_impact(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["constituency"]) + assert filtered.constituency_impact is not None + assert filtered.congressional_district_impact is None + assert filtered.local_authority_impact is None + + def test_local_authority_keeps_only_la_impact(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["local_authority"]) + assert filtered.local_authority_impact is not None + assert filtered.constituency_impact is None + assert filtered.congressional_district_impact is None + + def test_budget_summary_keeps_only_budget(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["budget_summary"]) + assert filtered.budget_summary is not None + assert filtered.decile_impacts is None + assert filtered.program_statistics is None + + def test_unknown_module_in_list_is_gracefully_ignored(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["poverty", "nonexistent_module"]) + assert filtered.poverty is not None + assert filtered.decile_impacts is None + + def test_all_modules_keeps_all_data_fields(self): + from policyengine_api.api.module_registry import MODULE_REGISTRY + + resp = _make_stub_response() + all_names = list(MODULE_REGISTRY.keys()) + filtered = _build_filtered_response(resp, all_names) + for field in _DATA_FIELDS: + assert getattr(filtered, field) is not None, ( + f"{field} should be preserved when all modules selected" + ) + + def test_returns_economic_impact_response_instance(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["decile"]) + assert isinstance(filtered, EconomicImpactResponse) + + +# 
--------------------------------------------------------------------------- +# Integration tests for the endpoint itself +# --------------------------------------------------------------------------- + + +class TestEconomyCustomEndpoint: + """Tests for POST /analysis/economy-custom validation.""" + + def test_unknown_module_returns_422(self, client): + response = client.post( + "/analysis/economy-custom", + json={ + "tax_benefit_model_name": "policyengine_us", + "region": "us", + "modules": ["nonexistent_module"], + }, + ) + assert response.status_code == 422 + assert "Unknown module" in response.json()["detail"] + + def test_wrong_country_module_returns_422(self, client): + response = client.post( + "/analysis/economy-custom", + json={ + "tax_benefit_model_name": "policyengine_us", + "region": "us", + "modules": ["constituency"], + }, + ) + assert response.status_code == 422 + assert "not available for country" in response.json()["detail"] + + def test_multiple_errors_in_module_validation(self, client): + response = client.post( + "/analysis/economy-custom", + json={ + "tax_benefit_model_name": "policyengine_us", + "region": "us", + "modules": ["nonexistent", "constituency"], + }, + ) + assert response.status_code == 422 + detail = response.json()["detail"] + assert "Unknown module" in detail + assert "not available for country" in detail + + def test_empty_modules_list_passes_validation(self, client): + response = client.post( + "/analysis/economy-custom", + json={ + "tax_benefit_model_name": "policyengine_us", + "region": "us", + "modules": [], + }, + ) + # Empty list passes module validation, so the error should be about + # dataset/region resolution, not about modules + assert ( + response.status_code != 422 + or "module" not in response.json().get("detail", "").lower() + ) + + def test_valid_modules_but_missing_region_returns_404(self, client): + response = client.post( + "/analysis/economy-custom", + json={ + "tax_benefit_model_name": "policyengine_us", + 
"region": "us", + "modules": ["decile", "poverty"], + }, + ) + # Passes validation but region "us" is not in the DB -> 404 + assert response.status_code == 404 + + def test_missing_modules_field_returns_422(self, client): + response = client.post( + "/analysis/economy-custom", + json={ + "tax_benefit_model_name": "policyengine_us", + "region": "us", + }, + ) + assert response.status_code == 422 + + def test_invalid_model_name_returns_422(self, client): + response = client.post( + "/analysis/economy-custom", + json={ + "tax_benefit_model_name": "invalid_model", + "region": "us", + "modules": ["decile"], + }, + ) + assert response.status_code == 422 + + +class TestEconomyCustomPolling: + """Tests for GET /analysis/economy-custom/{report_id}.""" + + def test_not_found(self, client): + fake_id = uuid4() + response = client.get(f"/analysis/economy-custom/{fake_id}") + assert response.status_code == 404 + + def test_invalid_uuid_returns_422(self, client): + response = client.get("/analysis/economy-custom/not-a-uuid") + assert response.status_code == 422 diff --git a/tests/test_household.py b/tests/test_household.py index a7248b3..0cf0288 100644 --- a/tests/test_household.py +++ b/tests/test_household.py @@ -289,5 +289,204 @@ def test_missing_people(self): assert response.status_code == 422 +class TestUSPolicyReform: + """Tests for US household calculations with policy reforms.""" + + def _get_us_model_id(self) -> str: + """Get the US tax benefit model ID.""" + response = client.get("/tax-benefit-models/") + assert response.status_code == 200 + models = response.json() + for model in models: + if "us" in model["name"].lower(): + return model["id"] + raise AssertionError("US model not found") + + def _get_parameter_id(self, model_id: str, param_name: str) -> str: + """Get a parameter ID by name.""" + response = client.get( + f"/parameters/?tax_benefit_model_id={model_id}&limit=10000" + ) + assert response.status_code == 200 + params = response.json() + for param in params: 
+ if param["name"] == param_name: + return param["id"] + raise AssertionError(f"Parameter {param_name} not found") + + def _create_policy(self, param_id: str, value: float) -> str: + """Create a policy with a parameter value.""" + response = client.post( + "/policies/", + json={ + "name": "Test Reform", + "description": "Test reform for household calculation", + "parameter_values": [ + { + "parameter_id": param_id, + "value_json": value, + "start_date": "2024-01-01T00:00:00Z", + } + ], + }, + ) + assert response.status_code == 200 + return response.json()["id"] + + def test_us_reform_changes_household_net_income(self): + """Test that a US policy reform changes household net income. + + This test verifies the fix for the US reform application bug where + reforms were not being applied correctly due to the shared singleton + TaxBenefitSystem in policyengine-us. + """ + # Get the US model and a UBI parameter + model_id = self._get_us_model_id() + param_name = ( + "gov.contrib.ubi_center.basic_income.amount.person.by_age[3].amount" + ) + param_id = self._get_parameter_id(model_id, param_name) + + # Create a policy with $1000 UBI for older adults + policy_id = self._create_policy(param_id, 1000) + + # Run baseline calculation (no policy) + baseline_response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_us", + "people": [{"age": 40, "employment_income": 70000}], + "tax_unit": [{"state_code": "CA"}], + "household": [{"state_fips": 6}], + "year": 2024, + }, + ) + assert baseline_response.status_code == 200 + baseline_data = _poll_job(baseline_response.json()["job_id"]) + baseline_net_income = baseline_data["result"]["household"][0][ + "household_net_income" + ] + + # Run reform calculation (with UBI policy) + reform_response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_us", + "people": [{"age": 40, "employment_income": 70000}], + "tax_unit": [{"state_code": "CA"}], + "household": 
[{"state_fips": 6}], + "year": 2024, + "policy_id": policy_id, + }, + ) + assert reform_response.status_code == 200 + reform_data = _poll_job(reform_response.json()["job_id"]) + reform_net_income = reform_data["result"]["household"][0][ + "household_net_income" + ] + + # Verify the reform increased net income by approximately $1000 + difference = reform_net_income - baseline_net_income + assert abs(difference - 1000) < 1, ( + f"Expected ~$1000 difference, got ${difference:.2f}. " + f"Baseline: ${baseline_net_income:.2f}, Reform: ${reform_net_income:.2f}" + ) + + +class TestUKPolicyReform: + """Tests for UK household calculations with policy reforms.""" + + def _get_uk_model_id(self) -> str | None: + """Get the UK tax benefit model ID, or None if not seeded.""" + response = client.get("/tax-benefit-models/") + assert response.status_code == 200 + models = response.json() + for model in models: + if "uk" in model["name"].lower(): + return model["id"] + return None + + def _get_parameter_id(self, model_id: str, param_name: str) -> str: + """Get a parameter ID by name.""" + response = client.get( + f"/parameters/?tax_benefit_model_id={model_id}&limit=10000" + ) + assert response.status_code == 200 + params = response.json() + for param in params: + if param["name"] == param_name: + return param["id"] + raise AssertionError(f"Parameter {param_name} not found") + + def _create_policy(self, param_id: str, value: float) -> str: + """Create a policy with a parameter value.""" + response = client.post( + "/policies/", + json={ + "name": "Test UK Reform", + "description": "Test reform for UK household calculation", + "parameter_values": [ + { + "parameter_id": param_id, + "value_json": value, + "start_date": "2026-01-01T00:00:00Z", + } + ], + }, + ) + assert response.status_code == 200 + return response.json()["id"] + + def test_uk_reform_changes_household_net_income(self): + """Test that a UK policy reform changes household net income.""" + # Get the UK model and a UBI 
parameter + model_id = self._get_uk_model_id() + if model_id is None: + pytest.skip("UK model not seeded in database") + param_name = "gov.contrib.ubi_center.basic_income.adult" + param_id = self._get_parameter_id(model_id, param_name) + + # Create a policy with £1000 UBI for adults + policy_id = self._create_policy(param_id, 1000) + + # Run baseline calculation (no policy) + baseline_response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [{"age": 30, "employment_income": 30000}], + "year": 2026, + }, + ) + assert baseline_response.status_code == 200 + baseline_data = _poll_job(baseline_response.json()["job_id"]) + baseline_net_income = baseline_data["result"]["household"][0][ + "household_net_income" + ] + + # Run reform calculation (with UBI policy) + reform_response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [{"age": 30, "employment_income": 30000}], + "year": 2026, + "policy_id": policy_id, + }, + ) + assert reform_response.status_code == 200 + reform_data = _poll_job(reform_response.json()["job_id"]) + reform_net_income = reform_data["result"]["household"][0][ + "household_net_income" + ] + + # Verify the reform increased net income + difference = reform_net_income - baseline_net_income + assert difference > 0, ( + f"Expected positive difference, got £{difference:.2f}. " + f"Baseline: £{baseline_net_income:.2f}, Reform: £{reform_net_income:.2f}" + ) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/test_household_calculation.py b/tests/test_household_calculation.py new file mode 100644 index 0000000..e4fc2a5 --- /dev/null +++ b/tests/test_household_calculation.py @@ -0,0 +1,128 @@ +"""Unit tests for household calculation functions. + +These tests verify that the calculation functions work correctly with policy reforms, +without requiring database setup or API calls. 
+""" + +import pytest + +from policyengine_api.api.household import _calculate_household_us + + +class TestUSHouseholdCalculation: + """Unit tests for US household calculation with policy reforms.""" + + @pytest.mark.slow + def test_baseline_calculation(self): + """Test basic US household calculation without policy.""" + result = _calculate_household_us( + people=[{"employment_income": 70000, "age": 40}], + marital_unit=[], + family=[], + spm_unit=[], + tax_unit=[{"state_code": "CA"}], + household=[{"state_fips": 6}], + year=2024, + policy_data=None, + ) + + assert "person" in result + assert "household" in result + assert "tax_unit" in result + assert len(result["person"]) == 1 + assert result["tax_unit"][0]["income_tax"] > 0 + + @pytest.mark.slow + def test_reform_changes_net_income(self): + """Test that a US policy reform changes household net income. + + This test verifies the fix for the US reform application bug where + reforms were not being applied correctly due to the shared singleton + TaxBenefitSystem in policyengine-us. 
+ """ + household_args = { + "people": [{"employment_income": 70000, "age": 40}], + "marital_unit": [], + "family": [], + "spm_unit": [], + "tax_unit": [{"state_code": "CA"}], + "household": [{"state_fips": 6}], + "year": 2024, + } + + # Calculate baseline (no policy) + baseline = _calculate_household_us(**household_args, policy_data=None) + baseline_net_income = baseline["household"][0]["household_net_income"] + + # Calculate with $1000 UBI reform + policy_data = { + "name": "Test UBI", + "description": "Test UBI reform", + "parameter_values": [ + { + "parameter_name": "gov.contrib.ubi_center.basic_income.amount.person.by_age[3].amount", + "value": 1000, + "start_date": "2024-01-01T00:00:00", + "end_date": None, + } + ], + } + reform = _calculate_household_us(**household_args, policy_data=policy_data) + reform_net_income = reform["household"][0]["household_net_income"] + + # Verify the reform increased net income by exactly $1000 + difference = reform_net_income - baseline_net_income + assert abs(difference - 1000) < 1, ( + f"Expected ~$1000 difference, got ${difference:.2f}. " + f"Baseline: ${baseline_net_income:.2f}, Reform: ${reform_net_income:.2f}" + ) + + @pytest.mark.slow + def test_reform_does_not_affect_baseline(self): + """Test that running reform doesn't pollute baseline calculations. + + This is a regression test for the singleton pollution bug where running + a reform calculation would affect subsequent baseline calculations. 
+ """ + household_args = { + "people": [{"employment_income": 70000, "age": 40}], + "marital_unit": [], + "family": [], + "spm_unit": [], + "tax_unit": [{"state_code": "CA"}], + "household": [{"state_fips": 6}], + "year": 2024, + } + + # First baseline + baseline1 = _calculate_household_us(**household_args, policy_data=None) + baseline1_net_income = baseline1["household"][0]["household_net_income"] + + # Run reform + policy_data = { + "name": "Test UBI", + "description": "Test UBI reform", + "parameter_values": [ + { + "parameter_name": "gov.contrib.ubi_center.basic_income.amount.person.by_age[3].amount", + "value": 5000, + "start_date": "2024-01-01T00:00:00", + "end_date": None, + } + ], + } + _calculate_household_us(**household_args, policy_data=policy_data) + + # Second baseline - should be same as first + baseline2 = _calculate_household_us(**household_args, policy_data=None) + baseline2_net_income = baseline2["household"][0]["household_net_income"] + + # Verify baselines are identical + assert abs(baseline1_net_income - baseline2_net_income) < 0.01, ( + f"Baseline changed after reform calculation! 
" + f"Before: ${baseline1_net_income:.2f}, After: ${baseline2_net_income:.2f}" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_households.py b/tests/test_households.py new file mode 100644 index 0000000..4c60062 --- /dev/null +++ b/tests/test_households.py @@ -0,0 +1,155 @@ +"""Tests for stored household CRUD endpoints.""" + +from uuid import uuid4 + +from test_fixtures.fixtures_households import ( + MOCK_HOUSEHOLD_MINIMAL, + MOCK_UK_HOUSEHOLD_CREATE, + MOCK_US_HOUSEHOLD_CREATE, + create_household, +) + +# --------------------------------------------------------------------------- +# POST /households +# --------------------------------------------------------------------------- + + +def test_create_us_household(client): + """Create a US household returns 201 with id and timestamps.""" + response = client.post("/households", json=MOCK_US_HOUSEHOLD_CREATE) + assert response.status_code == 201 + data = response.json() + assert "id" in data + assert "created_at" in data + assert "updated_at" in data + assert data["tax_benefit_model_name"] == "policyengine_us" + assert data["year"] == 2024 + assert data["label"] == "US test household" + + +def test_create_household_returns_people_and_entities(client): + """Created household response includes people and entity groups.""" + response = client.post("/households", json=MOCK_US_HOUSEHOLD_CREATE) + data = response.json() + assert len(data["people"]) == 2 + assert data["people"][0]["age"] == 30 + assert data["people"][0]["employment_income"] == 50000 + assert data["household"] == {"state_name": "CA"} + assert data["tax_unit"] == {} + assert data["family"] == {} + + +def test_create_uk_household(client): + """Create a UK household with benunit.""" + response = client.post("/households", json=MOCK_UK_HOUSEHOLD_CREATE) + assert response.status_code == 201 + data = response.json() + assert data["tax_benefit_model_name"] == "policyengine_uk" + assert data["benunit"] == {"is_married": 
False} + assert data["household"] == {"region": "LONDON"} + + +def test_create_household_minimal(client): + """Create a household with minimal fields.""" + response = client.post("/households", json=MOCK_HOUSEHOLD_MINIMAL) + assert response.status_code == 201 + data = response.json() + assert data["label"] is None + assert data["tax_unit"] is None + assert data["benunit"] is None + + +def test_create_household_invalid_model_name(client): + """Reject invalid tax_benefit_model_name.""" + payload = {**MOCK_HOUSEHOLD_MINIMAL, "tax_benefit_model_name": "invalid"} + response = client.post("/households", json=payload) + assert response.status_code == 422 + + +# --------------------------------------------------------------------------- +# GET /households/{id} +# --------------------------------------------------------------------------- + + +def test_get_household(client, session): + """Get a stored household by ID.""" + record = create_household(session) + response = client.get(f"/households/{record.id}") + assert response.status_code == 200 + data = response.json() + assert data["id"] == str(record.id) + assert data["tax_benefit_model_name"] == "policyengine_us" + + +def test_get_household_not_found(client): + """Get a non-existent household returns 404.""" + fake_id = uuid4() + response = client.get(f"/households/{fake_id}") + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +# --------------------------------------------------------------------------- +# GET /households +# --------------------------------------------------------------------------- + + +def test_list_households_empty(client): + """List households returns empty list when none exist.""" + response = client.get("/households") + assert response.status_code == 200 + assert response.json() == [] + + +def test_list_households_with_data(client, session): + """List households returns all stored households.""" + create_household(session, label="first") + 
create_household(session, label="second") + response = client.get("/households") + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + + +def test_list_households_filter_by_model_name(client, session): + """Filter households by tax_benefit_model_name.""" + create_household(session, tax_benefit_model_name="policyengine_us") + create_household(session, tax_benefit_model_name="policyengine_uk") + response = client.get( + "/households", params={"tax_benefit_model_name": "policyengine_uk"} + ) + data = response.json() + assert len(data) == 1 + assert data[0]["tax_benefit_model_name"] == "policyengine_uk" + + +def test_list_households_limit_and_offset(client, session): + """Respect limit and offset pagination.""" + for i in range(5): + create_household(session, label=f"household-{i}") + response = client.get("/households", params={"limit": 2, "offset": 1}) + data = response.json() + assert len(data) == 2 + + +# --------------------------------------------------------------------------- +# DELETE /households/{id} +# --------------------------------------------------------------------------- + + +def test_delete_household(client, session): + """Delete a household returns 204.""" + record = create_household(session) + response = client.delete(f"/households/{record.id}") + assert response.status_code == 204 + + # Confirm it's gone + response = client.get(f"/households/{record.id}") + assert response.status_code == 404 + + +def test_delete_household_not_found(client): + """Delete a non-existent household returns 404.""" + fake_id = uuid4() + response = client.delete(f"/households/{fake_id}") + assert response.status_code == 404 diff --git a/tests/test_integration.py b/tests/test_integration.py index e044cab..e055423 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -9,6 +9,7 @@ pytestmark = pytest.mark.integration from datetime import datetime, timezone + from rich.console import Console from sqlmodel import Session, 
create_engine, select diff --git a/tests/test_intra_decile.py b/tests/test_intra_decile.py new file mode 100644 index 0000000..854843d --- /dev/null +++ b/tests/test_intra_decile.py @@ -0,0 +1,309 @@ +"""Tests for intra-decile income change computation.""" + +import numpy as np + +from policyengine_api.api.intra_decile import ( + _income_change_corrected, + _income_change_v1_original, + compute_intra_decile, + get_income_change_formula, +) +from test_fixtures.fixtures_intra_decile import ( + CATEGORY_NAMES, + EXPECTED_DECILE_NUMBERS, + EXPECTED_ROW_COUNT, + make_baseline_income, + make_household_data, + make_single_household_arrays, +) + +# --------------------------------------------------------------------------- +# Income change formula variants +# --------------------------------------------------------------------------- + + +class TestIncomeChangeFormulas: + """Tests for the two income change formula variants.""" + + def test__given_both_incomes_above_1__when_v1_formula__then_doubles_percentage( + self, + ): + # Given + baseline, reform = make_single_household_arrays(100.0, 103.0) + + # When + result = _income_change_v1_original(baseline, reform) + + # Then — V1 produces ~6% instead of 3% + assert abs(result[0] - 0.06) < 1e-9 + + def test__given_both_incomes_above_1__when_corrected_formula__then_correct_percentage( + self, + ): + # Given + baseline, reform = make_single_household_arrays(100.0, 103.0) + + # When + result = _income_change_corrected(baseline, reform) + + # Then + assert abs(result[0] - 0.03) < 1e-9 + + def test__given_zero_baseline__when_corrected_formula__then_caps_denominator_at_1( + self, + ): + # Given + baseline, reform = make_single_household_arrays(0.0, 10.0) + + # When + result = _income_change_corrected(baseline, reform) + + # Then — denominator capped at 1, so change = (10 - 0) / 1 = 10.0 + assert abs(result[0] - 10.0) < 1e-9 + + def test__given_negative_baseline__when_corrected_formula__then_caps_denominator_at_1( + self, + ): + # 
Given + baseline, reform = make_single_household_arrays(-5.0, 5.0) + + # When + result = _income_change_corrected(baseline, reform) + + # Then — denominator capped at 1, change = (5 - (-5)) / 1 = 10.0 + assert abs(result[0] - 10.0) < 1e-9 + + def test__given_identical_incomes__when_v1_formula__then_zero_change(self): + # Given + baseline, reform = make_single_household_arrays(50_000.0, 50_000.0) + + # When + result = _income_change_v1_original(baseline, reform) + + # Then + assert result[0] == 0.0 + + def test__given_identical_incomes__when_corrected_formula__then_zero_change(self): + # Given + baseline, reform = make_single_household_arrays(50_000.0, 50_000.0) + + # When + result = _income_change_corrected(baseline, reform) + + # Then + assert result[0] == 0.0 + + def test__given_strategy_selector__then_returns_corrected_formula(self): + # When + formula = get_income_change_formula() + + # Then + assert formula is _income_change_corrected + + +# --------------------------------------------------------------------------- +# compute_intra_decile — structure +# --------------------------------------------------------------------------- + + +class TestComputeIntraDecileStructure: + """Tests for the shape and structure of compute_intra_decile output.""" + + def test__given_any_input__then_returns_11_rows(self): + # Given + income = make_baseline_income() + baseline, reform = make_household_data(income) + + # When + rows = compute_intra_decile(baseline, reform) + + # Then + assert len(rows) == EXPECTED_ROW_COUNT + + def test__given_any_input__then_decile_numbers_are_1_through_10_plus_0(self): + # Given + income = make_baseline_income() + baseline, reform = make_household_data(income) + + # When + rows = compute_intra_decile(baseline, reform) + decile_numbers = [r["decile"] for r in rows] + + # Then + assert decile_numbers == EXPECTED_DECILE_NUMBERS + + def test__given_any_input__then_each_row_has_all_category_columns(self): + # Given + income = make_baseline_income() + 
baseline, reform = make_household_data(income) + + # When + rows = compute_intra_decile(baseline, reform) + + # Then + for row in rows: + for col in CATEGORY_NAMES: + assert col in row, ( + f"Missing column {col} in row for decile {row['decile']}" + ) + + def test__given_any_input__then_proportions_sum_to_approximately_one_per_decile( + self, + ): + # Given — a mix of changes so multiple categories are populated + income = make_baseline_income() + reform_income = income * np.where(np.arange(len(income)) % 3 == 0, 1.03, 0.97) + baseline, reform = make_household_data(income, reform_income) + + # When + rows = compute_intra_decile(baseline, reform) + + # Then + for row in rows: + total = sum(row[col] for col in CATEGORY_NAMES) + assert abs(total - 1.0) < 1e-9, ( + f"Decile {row['decile']} proportions sum to {total}, expected 1.0" + ) + + def test__given_overall_row__then_is_mean_of_decile_proportions(self): + # Given + income = make_baseline_income() + baseline, reform = make_household_data(income, income * 1.03) + + # When + rows = compute_intra_decile(baseline, reform) + decile_rows = [r for r in rows if r["decile"] != 0] + overall_row = [r for r in rows if r["decile"] == 0][0] + + # Then + for col in CATEGORY_NAMES: + expected_mean = sum(r[col] for r in decile_rows) / 10 + assert abs(overall_row[col] - expected_mean) < 1e-9 + + +# --------------------------------------------------------------------------- +# compute_intra_decile — classification +# --------------------------------------------------------------------------- + + +class TestComputeIntraDecileClassification: + """Tests for correct classification of income changes into categories.""" + + def test__given_no_income_change__then_all_in_no_change_category(self): + # Given + income = make_baseline_income() + baseline, reform = make_household_data(income, income) + + # When + rows = compute_intra_decile(baseline, reform) + + # Then + for row in rows: + assert row["no_change"] == 1.0 + assert 
row["gain_less_than_5pct"] == 0.0 + assert row["gain_more_than_5pct"] == 0.0 + assert row["lose_less_than_5pct"] == 0.0 + assert row["lose_more_than_5pct"] == 0.0 + + def test__given_uniform_3pct_raise__then_all_in_gain_less_than_5pct(self): + # Given + income = make_baseline_income() + baseline, reform = make_household_data(income, income * 1.03) + + # When + rows = compute_intra_decile(baseline, reform) + + # Then + for row in rows: + assert row["gain_less_than_5pct"] == 1.0 + + def test__given_uniform_10pct_raise__then_all_in_gain_more_than_5pct(self): + # Given + income = make_baseline_income() + baseline, reform = make_household_data(income, income * 1.10) + + # When + rows = compute_intra_decile(baseline, reform) + + # Then + for row in rows: + assert row["gain_more_than_5pct"] == 1.0 + + def test__given_uniform_3pct_loss__then_all_in_lose_less_than_5pct(self): + # Given + income = make_baseline_income() + baseline, reform = make_household_data(income, income * 0.97) + + # When + rows = compute_intra_decile(baseline, reform) + + # Then + for row in rows: + assert row["lose_less_than_5pct"] == 1.0 + + def test__given_uniform_10pct_loss__then_all_in_lose_more_than_5pct(self): + # Given + income = make_baseline_income() + baseline, reform = make_household_data(income, income * 0.90) + + # When + rows = compute_intra_decile(baseline, reform) + + # Then + for row in rows: + assert row["lose_more_than_5pct"] == 1.0 + + def test__given_boundary_at_exactly_5pct_gain__then_in_gain_less_than_5pct(self): + # Given — BOUNDS uses (lower, upper], so exactly 0.05 falls in gain_less_than_5pct + # because the gain_less_than_5pct interval is (1e-3, 0.05] + income = make_baseline_income() + baseline, reform = make_household_data(income, income * 1.05) + + # When + rows = compute_intra_decile(baseline, reform) + + # Then + for row in rows: + assert row["gain_less_than_5pct"] == 1.0 + + def test__given_boundary_at_exactly_0_1pct_gain__then_in_no_change(self): + # Given — exactly 
0.001 falls in no_change because the no_change + # interval is (-1e-3, 1e-3] and 0.001 == 1e-3 which is the upper bound + income = make_baseline_income() + baseline, reform = make_household_data(income, income * 1.001) + + # When + rows = compute_intra_decile(baseline, reform) + + # Then + for row in rows: + assert row["no_change"] == 1.0 + + +# --------------------------------------------------------------------------- +# compute_intra_decile — edge cases +# --------------------------------------------------------------------------- + + +class TestComputeIntraDecileEdgeCases: + """Tests for edge cases in compute_intra_decile.""" + + def test__given_zero_people_in_decile__then_proportions_are_zero(self): + # Given — remove all households from decile 5 by setting their weight to 0 + income = make_baseline_income() + weights = np.ones(len(income)) * 100.0 + people = np.full(len(income), 2.0) + # Decile 5 is indices 40-49 + people[40:50] = 0.0 + + baseline, reform = make_household_data( + income, income * 1.03, weights=weights, people=people + ) + + # When + rows = compute_intra_decile(baseline, reform) + + # Then — decile 5 should have all-zero proportions + decile_5 = [r for r in rows if r["decile"] == 5][0] + for col in CATEGORY_NAMES: + assert decile_5[col] == 0.0 diff --git a/tests/test_models.py b/tests/test_models.py index 0f84140..e3a83d9 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -6,9 +6,11 @@ AggregateOutput, AggregateType, Dataset, + Household, Policy, Simulation, SimulationStatus, + Variable, ) @@ -66,3 +68,90 @@ def test_aggregate_output_creation(): assert output.simulation_id == simulation_id assert output.aggregate_type == AggregateType.SUM assert output.result is None + + +def test_variable_creation_with_default_value(): + """Test variable model creation with default_value field.""" + model_version_id = uuid4() + variable = Variable( + name="age", + entity="person", + description="Age of the person", + data_type="int", + 
default_value=40, + tax_benefit_model_version_id=model_version_id, + ) + assert variable.name == "age" + assert variable.entity == "person" + assert variable.data_type == "int" + assert variable.default_value == 40 + assert variable.id is not None + + +def test_variable_with_float_default_value(): + """Test variable model with float default value.""" + model_version_id = uuid4() + variable = Variable( + name="employment_income", + entity="person", + data_type="float", + default_value=0.0, + tax_benefit_model_version_id=model_version_id, + ) + assert variable.default_value == 0.0 + + +def test_variable_with_bool_default_value(): + """Test variable model with boolean default value.""" + model_version_id = uuid4() + variable = Variable( + name="is_disabled", + entity="person", + data_type="bool", + default_value=False, + tax_benefit_model_version_id=model_version_id, + ) + assert variable.default_value is False + + +def test_variable_with_string_default_value(): + """Test variable model with string default value (enum).""" + model_version_id = uuid4() + variable = Variable( + name="state_name", + entity="household", + data_type="Enum", + default_value="CA", + possible_values=["CA", "NY", "TX"], + tax_benefit_model_version_id=model_version_id, + ) + assert variable.default_value == "CA" + assert variable.possible_values == ["CA", "NY", "TX"] + + +def test_variable_with_null_default_value(): + """Test variable model with null default value.""" + model_version_id = uuid4() + variable = Variable( + name="optional_field", + entity="person", + data_type="str", + default_value=None, + tax_benefit_model_version_id=model_version_id, + ) + assert variable.default_value is None + + +def test_household_creation(): + """Test household model creation.""" + household = Household( + tax_benefit_model_name="policyengine_us", + year=2024, + label="Test household", + household_data={"people": [{"age": 30}], "household": {}}, + ) + assert household.household_data == {"people": [{"age": 
30}], "household": {}} + assert household.label == "Test household" + assert household.tax_benefit_model_name == "policyengine_us" + assert household.year == 2024 + assert household.id is not None diff --git a/tests/test_models_by_country.py b/tests/test_models_by_country.py new file mode 100644 index 0000000..6df2971 --- /dev/null +++ b/tests/test_models_by_country.py @@ -0,0 +1,146 @@ +"""Tests for GET /tax-benefit-models/by-country/{country_id} endpoint.""" + +from datetime import datetime, timezone + +from policyengine_api.models import TaxBenefitModel, TaxBenefitModelVersion + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def _create_model_and_version(session, name, description, version_str, **version_kw): + """Create a model and a single version, return (model, version).""" + model = TaxBenefitModel(name=name, description=description) + session.add(model) + session.commit() + session.refresh(model) + + version = TaxBenefitModelVersion( + model_id=model.id, + version=version_str, + description=f"{name} {version_str}", + **version_kw, + ) + session.add(version) + session.commit() + session.refresh(version) + return model, version + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestModelByCountry: + """Tests for the by-country lookup.""" + + def test_uk_returns_model_and_version(self, client, session): + """country_id=uk returns the UK model and its latest version.""" + _create_model_and_version(session, "policyengine-uk", "UK model", "2.51.0") + + response = client.get("/tax-benefit-models/by-country/uk") + + assert response.status_code == 200 + data = response.json() + assert data["model"]["name"] == "policyengine-uk" + assert data["latest_version"]["version"] == "2.51.0" + + def 
test_us_returns_model_and_version(self, client, session): + """country_id=us returns the US model and its latest version.""" + _create_model_and_version(session, "policyengine-us", "US model", "1.20.0") + + response = client.get("/tax-benefit-models/by-country/us") + + assert response.status_code == 200 + data = response.json() + assert data["model"]["name"] == "policyengine-us" + assert data["latest_version"]["version"] == "1.20.0" + + def test_multiple_versions_returns_latest(self, client, session): + """When multiple versions exist, returns the most recently created.""" + model = TaxBenefitModel(name="policyengine-uk", description="UK") + session.add(model) + session.commit() + session.refresh(model) + + old = TaxBenefitModelVersion( + model_id=model.id, + version="2.50.0", + description="Old", + created_at=datetime(2026, 1, 1, tzinfo=timezone.utc), + ) + new = TaxBenefitModelVersion( + model_id=model.id, + version="2.51.0", + description="New", + created_at=datetime(2026, 2, 1, tzinfo=timezone.utc), + ) + session.add(old) + session.add(new) + session.commit() + + response = client.get("/tax-benefit-models/by-country/uk") + + assert response.status_code == 200 + assert response.json()["latest_version"]["version"] == "2.51.0" + + def test_no_model_returns_404(self, client): + """When the model doesn't exist in the DB, returns 404.""" + response = client.get("/tax-benefit-models/by-country/uk") + + assert response.status_code == 404 + assert "No model found" in response.json()["detail"] + + def test_model_without_versions_returns_404(self, client, session): + """When the model exists but has no versions, returns 404.""" + model = TaxBenefitModel(name="policyengine-uk", description="UK") + session.add(model) + session.commit() + + response = client.get("/tax-benefit-models/by-country/uk") + + assert response.status_code == 404 + assert "No versions found" in response.json()["detail"] + + def test_invalid_country_id_returns_422(self, client): + """An invalid 
country_id is rejected by Literal validation.""" + response = client.get("/tax-benefit-models/by-country/fr") + + assert response.status_code == 422 + + def test_response_shape(self, client, session): + """Response contains the expected fields for both model and version.""" + _create_model_and_version(session, "policyengine-uk", "UK model", "2.51.0") + + response = client.get("/tax-benefit-models/by-country/uk") + data = response.json() + + # Model fields + model = data["model"] + assert "id" in model + assert "name" in model + assert "description" in model + assert "created_at" in model + + # Version fields + version = data["latest_version"] + assert "id" in version + assert "version" in version + assert "model_id" in version + assert "description" in version + assert "created_at" in version + + def test_country_isolation(self, client, session): + """UK endpoint doesn't return US model data and vice versa.""" + _create_model_and_version(session, "policyengine-uk", "UK", "2.51.0") + _create_model_and_version(session, "policyengine-us", "US", "1.20.0") + + uk_resp = client.get("/tax-benefit-models/by-country/uk") + us_resp = client.get("/tax-benefit-models/by-country/us") + + assert uk_resp.json()["model"]["name"] == "policyengine-uk" + assert uk_resp.json()["latest_version"]["version"] == "2.51.0" + assert us_resp.json()["model"]["name"] == "policyengine-us" + assert us_resp.json()["latest_version"]["version"] == "1.20.0" diff --git a/tests/test_module_registry.py b/tests/test_module_registry.py new file mode 100644 index 0000000..b3f8559 --- /dev/null +++ b/tests/test_module_registry.py @@ -0,0 +1,280 @@ +"""Tests for the economy analysis module registry.""" + +from dataclasses import FrozenInstanceError + +import pytest + +from policyengine_api.api.analysis import EconomicImpactResponse +from policyengine_api.api.module_registry import ( + MODULE_REGISTRY, + ComputationModule, + get_all_module_names, + get_modules_for_country, + validate_modules, +) + + +class 
TestModuleRegistry: + """Tests for MODULE_REGISTRY contents.""" + + def test_registry_is_not_empty(self): + assert len(MODULE_REGISTRY) > 0 + + def test_registry_has_exactly_10_modules(self): + assert len(MODULE_REGISTRY) == 10 + + def test_all_entries_are_computation_modules(self): + for name, module in MODULE_REGISTRY.items(): + assert isinstance(module, ComputationModule) + assert module.name == name + + def test_all_modules_have_countries(self): + for module in MODULE_REGISTRY.values(): + assert len(module.countries) > 0 + for country in module.countries: + assert country in ("uk", "us") + + def test_all_modules_have_response_fields(self): + for module in MODULE_REGISTRY.values(): + assert len(module.response_fields) > 0 + + def test_all_modules_have_non_empty_label(self): + for name, module in MODULE_REGISTRY.items(): + assert module.label, f"Module {name!r} has empty label" + assert len(module.label) > 0 + + def test_all_modules_have_non_empty_description(self): + for name, module in MODULE_REGISTRY.items(): + assert module.description, f"Module {name!r} has empty description" + assert len(module.description) > 0 + + def test_expected_modules_exist(self): + expected = [ + "decile", + "program_statistics", + "poverty", + "inequality", + "budget_summary", + "intra_decile", + "congressional_district", + "constituency", + "local_authority", + "wealth_decile", + ] + for name in expected: + assert name in MODULE_REGISTRY, f"Missing module: {name}" + + def test_no_unexpected_modules(self): + expected = { + "decile", + "program_statistics", + "poverty", + "inequality", + "budget_summary", + "intra_decile", + "congressional_district", + "constituency", + "local_authority", + "wealth_decile", + } + assert set(MODULE_REGISTRY.keys()) == expected + + +class TestComputationModuleFrozen: + """Tests that ComputationModule instances are immutable.""" + + def test_cannot_mutate_name(self): + module = MODULE_REGISTRY["decile"] + with pytest.raises(FrozenInstanceError): + 
module.name = "changed" + + def test_cannot_mutate_countries(self): + module = MODULE_REGISTRY["decile"] + with pytest.raises(FrozenInstanceError): + module.countries = ["fr"] + + def test_cannot_mutate_response_fields(self): + module = MODULE_REGISTRY["poverty"] + with pytest.raises(FrozenInstanceError): + module.response_fields = ["something_else"] + + +class TestResponseFieldsMapping: + """Tests that each module's response_fields reference valid EconomicImpactResponse fields.""" + + def test_all_response_fields_exist_on_response_model(self): + valid_fields = set(EconomicImpactResponse.model_fields.keys()) + for name, module in MODULE_REGISTRY.items(): + for field in module.response_fields: + assert field in valid_fields, ( + f"Module {name!r} references response field {field!r} " + f"which does not exist on EconomicImpactResponse" + ) + + def test_decile_response_fields(self): + assert MODULE_REGISTRY["decile"].response_fields == ["decile_impacts"] + + def test_program_statistics_includes_detailed_budget(self): + fields = MODULE_REGISTRY["program_statistics"].response_fields + assert "program_statistics" in fields + assert "detailed_budget" in fields + + def test_poverty_response_fields(self): + assert MODULE_REGISTRY["poverty"].response_fields == ["poverty"] + + def test_inequality_response_fields(self): + assert MODULE_REGISTRY["inequality"].response_fields == ["inequality"] + + def test_budget_summary_response_fields(self): + assert MODULE_REGISTRY["budget_summary"].response_fields == ["budget_summary"] + + def test_intra_decile_response_fields(self): + assert MODULE_REGISTRY["intra_decile"].response_fields == ["intra_decile"] + + def test_congressional_district_response_fields(self): + assert MODULE_REGISTRY["congressional_district"].response_fields == [ + "congressional_district_impact" + ] + + def test_constituency_response_fields(self): + assert MODULE_REGISTRY["constituency"].response_fields == [ + "constituency_impact" + ] + + def 
test_local_authority_response_fields(self): + assert MODULE_REGISTRY["local_authority"].response_fields == [ + "local_authority_impact" + ] + + def test_wealth_decile_includes_both_fields(self): + fields = MODULE_REGISTRY["wealth_decile"].response_fields + assert "wealth_decile" in fields + assert "intra_wealth_decile" in fields + + +class TestCountryApplicability: + """Tests for country-specific module availability.""" + + def test_us_only_modules(self): + assert "us" in MODULE_REGISTRY["congressional_district"].countries + assert "uk" not in MODULE_REGISTRY["congressional_district"].countries + + def test_uk_only_modules(self): + for name in ("constituency", "local_authority", "wealth_decile"): + module = MODULE_REGISTRY[name] + assert "uk" in module.countries + assert "us" not in module.countries + + def test_shared_modules(self): + shared = [ + "decile", + "program_statistics", + "poverty", + "inequality", + "budget_summary", + "intra_decile", + ] + for name in shared: + module = MODULE_REGISTRY[name] + assert "uk" in module.countries + assert "us" in module.countries + + +class TestGetModulesForCountry: + """Tests for get_modules_for_country().""" + + def test_uk_includes_constituency(self): + uk_modules = get_modules_for_country("uk") + names = [m.name for m in uk_modules] + assert "constituency" in names + assert "local_authority" in names + assert "wealth_decile" in names + + def test_uk_excludes_congressional_district(self): + uk_modules = get_modules_for_country("uk") + names = [m.name for m in uk_modules] + assert "congressional_district" not in names + + def test_uk_has_9_modules(self): + uk_modules = get_modules_for_country("uk") + assert len(uk_modules) == 9 + + def test_us_includes_congressional_district(self): + us_modules = get_modules_for_country("us") + names = [m.name for m in us_modules] + assert "congressional_district" in names + + def test_us_excludes_uk_only(self): + us_modules = get_modules_for_country("us") + names = [m.name for m in 
us_modules] + assert "constituency" not in names + assert "local_authority" not in names + assert "wealth_decile" not in names + + def test_us_has_7_modules(self): + us_modules = get_modules_for_country("us") + assert len(us_modules) == 7 + + def test_unknown_country_returns_empty(self): + assert get_modules_for_country("fr") == [] + + def test_returns_computation_module_instances(self): + for m in get_modules_for_country("uk"): + assert isinstance(m, ComputationModule) + + +class TestGetAllModuleNames: + """Tests for get_all_module_names().""" + + def test_returns_all_names(self): + names = get_all_module_names() + assert set(names) == set(MODULE_REGISTRY.keys()) + + def test_returns_list_of_strings(self): + names = get_all_module_names() + assert isinstance(names, list) + for name in names: + assert isinstance(name, str) + + +class TestValidateModules: + """Tests for validate_modules().""" + + def test_valid_us_modules(self): + result = validate_modules(["decile", "poverty"], "us") + assert result == ["decile", "poverty"] + + def test_valid_uk_modules(self): + result = validate_modules(["constituency", "wealth_decile"], "uk") + assert result == ["constituency", "wealth_decile"] + + def test_empty_list_passes_validation(self): + result = validate_modules([], "us") + assert result == [] + + def test_all_us_modules_pass_validation(self): + us_names = [m.name for m in get_modules_for_country("us")] + result = validate_modules(us_names, "us") + assert result == us_names + + def test_all_uk_modules_pass_validation(self): + uk_names = [m.name for m in get_modules_for_country("uk")] + result = validate_modules(uk_names, "uk") + assert result == uk_names + + def test_unknown_module_raises(self): + with pytest.raises(ValueError, match="Unknown module"): + validate_modules(["nonexistent"], "us") + + def test_wrong_country_raises(self): + with pytest.raises(ValueError, match="not available for country"): + validate_modules(["congressional_district"], "uk") + + def 
test_multiple_errors_combined(self): + with pytest.raises(ValueError, match="Unknown module.*not available"): + validate_modules(["nonexistent", "constituency"], "us") + + def test_returns_original_list_on_success(self): + names = ["poverty", "decile", "inequality"] + result = validate_modules(names, "us") + assert result is names diff --git a/tests/test_parameters.py b/tests/test_parameters.py index f95016b..50bb213 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -107,7 +107,7 @@ def test__given_policy_id_filter__then_returns_only_matching_values( """GET /parameter-values?policy_id=X returns only values for that policy.""" # Given param = create_parameter(session, model_version, "test.param", "Test Param") - policy = create_policy(session, "Test Policy") + policy = create_policy(session, "Test Policy", model_version) create_parameter_value(session, param.id, 100, policy_id=None) # baseline create_parameter_value(session, param.id, 150, policy_id=policy.id) # reform @@ -135,7 +135,7 @@ def test__given_both_parameter_and_policy_filters__then_returns_matching_interse param2 = create_parameter( session, model_version, "test.both.param2", "Test Both Param 2" ) - policy = create_policy(session, "Test Both Policy") + policy = create_policy(session, "Test Both Policy", model_version) create_parameter_value(session, param1.id, 100, policy_id=None) # baseline create_parameter_value(session, param1.id, 150, policy_id=policy.id) # target diff --git a/tests/test_parameters_by_name.py b/tests/test_parameters_by_name.py new file mode 100644 index 0000000..81bd360 --- /dev/null +++ b/tests/test_parameters_by_name.py @@ -0,0 +1,238 @@ +"""Tests for POST /parameters/by-name endpoint.""" + +import pytest + +from policyengine_api.models import ( + Parameter, + TaxBenefitModel, + TaxBenefitModelVersion, +) + +# --------------------------------------------------------------------------- +# Fixtures +# 
--------------------------------------------------------------------------- + + +@pytest.fixture +def us_version(session): + """Create a policyengine-us model and version.""" + model = TaxBenefitModel(name="policyengine-us", description="US model") + session.add(model) + session.commit() + session.refresh(model) + + version = TaxBenefitModelVersion( + model_id=model.id, version="1.0", description="US v1" + ) + session.add(version) + session.commit() + session.refresh(version) + return version + + +def create_parameter(session, model_version, name: str, label: str) -> Parameter: + """Create and persist a Parameter.""" + param = Parameter( + name=name, + label=label, + tax_benefit_model_version_id=model_version.id, + ) + session.add(param) + session.commit() + session.refresh(param) + return param + + +class TestParametersByName: + """Tests for looking up parameters by their exact names.""" + + def test_returns_matching_parameters(self, client, session, us_version): + """Given known parameter names, returns their full metadata.""" + create_parameter(session, us_version, "gov.tax.rate", "Tax rate") + create_parameter(session, us_version, "gov.tax.threshold", "Threshold") + + response = client.post( + "/parameters/by-name", + json={ + "names": ["gov.tax.rate", "gov.tax.threshold"], + "country_id": "us", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + returned_names = {p["name"] for p in data} + assert returned_names == {"gov.tax.rate", "gov.tax.threshold"} + + def test_returns_empty_list_for_empty_names(self, client): + """Given an empty names list, returns an empty list.""" + response = client.post( + "/parameters/by-name", + json={ + "names": [], + "country_id": "us", + }, + ) + + assert response.status_code == 200 + assert response.json() == [] + + def test_returns_empty_list_for_unknown_names(self, client, session, us_version): + """Given names that don't match any parameter, returns an empty list.""" + 
create_parameter(session, us_version, "gov.exists", "Exists") + + response = client.post( + "/parameters/by-name", + json={ + "names": ["gov.does_not_exist", "gov.also_missing"], + "country_id": "us", + }, + ) + + assert response.status_code == 200 + assert response.json() == [] + + def test_returns_only_matching_when_mix_of_known_and_unknown( + self, client, session, us_version + ): + """Given a mix of known and unknown names, returns only the known ones.""" + create_parameter(session, us_version, "gov.real", "Real param") + + response = client.post( + "/parameters/by-name", + json={ + "names": ["gov.real", "gov.fake"], + "country_id": "us", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["name"] == "gov.real" + + def test_filters_by_country(self, client, session): + """Parameters from a different country are excluded.""" + # Create two models + model_uk = TaxBenefitModel(name="policyengine-uk", description="UK") + model_us = TaxBenefitModel(name="policyengine-us", description="US") + session.add(model_uk) + session.add(model_us) + session.commit() + session.refresh(model_uk) + session.refresh(model_us) + + ver_uk = TaxBenefitModelVersion( + model_id=model_uk.id, version="1.0", description="UK v1" + ) + ver_us = TaxBenefitModelVersion( + model_id=model_us.id, version="1.0", description="US v1" + ) + session.add(ver_uk) + session.add(ver_us) + session.commit() + session.refresh(ver_uk) + session.refresh(ver_us) + + # Same parameter name in both models + create_parameter(session, ver_uk, "gov.shared_name", "UK version") + create_parameter(session, ver_us, "gov.shared_name", "US version") + + # Request only UK + response = client.post( + "/parameters/by-name", + json={ + "names": ["gov.shared_name"], + "country_id": "uk", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["label"] == "UK version" + + def 
test_response_shape_matches_parameter_read(self, client, session, us_version): + """Returned objects have the same shape as ParameterRead.""" + create_parameter(session, us_version, "gov.shape_test", "Shape test") + + response = client.post( + "/parameters/by-name", + json={ + "names": ["gov.shape_test"], + "country_id": "us", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + param = data[0] + assert "id" in param + assert "name" in param + assert "label" in param + assert "created_at" in param + assert "tax_benefit_model_version_id" in param + + def test_results_ordered_by_name(self, client, session, us_version): + """Returned parameters are sorted alphabetically by name.""" + create_parameter(session, us_version, "gov.zzz", "Last") + create_parameter(session, us_version, "gov.aaa", "First") + create_parameter(session, us_version, "gov.mmm", "Middle") + + response = client.post( + "/parameters/by-name", + json={ + "names": ["gov.zzz", "gov.aaa", "gov.mmm"], + "country_id": "us", + }, + ) + + assert response.status_code == 200 + names = [p["name"] for p in response.json()] + assert names == ["gov.aaa", "gov.mmm", "gov.zzz"] + + def test_missing_country_id_returns_422(self, client): + """Request without country_id is rejected.""" + response = client.post( + "/parameters/by-name", + json={"names": ["gov.something"]}, + ) + + assert response.status_code == 422 + + def test_invalid_country_id_returns_422(self, client): + """Request with invalid country_id is rejected.""" + response = client.post( + "/parameters/by-name", + json={"names": ["gov.something"], "country_id": "invalid"}, + ) + + assert response.status_code == 422 + + def test_missing_names_field_returns_422(self, client): + """Request without names field is rejected.""" + response = client.post( + "/parameters/by-name", + json={"country_id": "us"}, + ) + + assert response.status_code == 422 + + def test_single_name_lookup(self, client, session, us_version): + 
"""Looking up a single parameter name works.""" + create_parameter(session, us_version, "gov.single", "Single param") + + response = client.post( + "/parameters/by-name", + json={ + "names": ["gov.single"], + "country_id": "us", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["name"] == "gov.single" diff --git a/tests/test_parameters_children.py b/tests/test_parameters_children.py new file mode 100644 index 0000000..d788179 --- /dev/null +++ b/tests/test_parameters_children.py @@ -0,0 +1,399 @@ +"""Tests for GET /parameters/children endpoint.""" + +import pytest + +from policyengine_api.models import ( + Parameter, + TaxBenefitModel, + TaxBenefitModelVersion, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def uk_version(session): + """Create a policyengine-uk model and version.""" + model = TaxBenefitModel(name="policyengine-uk", description="UK model") + session.add(model) + session.commit() + session.refresh(model) + + version = TaxBenefitModelVersion( + model_id=model.id, version="1.0", description="UK v1" + ) + session.add(version) + session.commit() + session.refresh(version) + return version + + +@pytest.fixture +def us_version(session): + """Create a policyengine-us model and version.""" + model = TaxBenefitModel(name="policyengine-us", description="US model") + session.add(model) + session.commit() + session.refresh(model) + + version = TaxBenefitModelVersion( + model_id=model.id, version="1.0", description="US v1" + ) + session.add(version) + session.commit() + session.refresh(version) + return version + + +def _add_params(session, version, names_and_labels): + """Bulk-add parameters. 
names_and_labels is [(name, label), ...].""" + for name, label in names_and_labels: + session.add( + Parameter( + name=name, + label=label, + tax_benefit_model_version_id=version.id, + ) + ) + session.commit() + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestParameterChildrenBasic: + """Basic tree structure tests.""" + + def test_returns_nodes_for_intermediate_paths(self, client, session, uk_version): + """Parameters at gov.hmrc.x and gov.dwp.x produce nodes for hmrc and dwp.""" + _add_params( + session, + uk_version, + [ + ("gov.hmrc.income_tax.rate", "Basic rate"), + ("gov.hmrc.income_tax.threshold", "Threshold"), + ("gov.dwp.uc.amount", "UC amount"), + ], + ) + + response = client.get( + "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + ) + + assert response.status_code == 200 + data = response.json() + assert data["parent_path"] == "gov" + children = data["children"] + assert len(children) == 2 + paths = [c["path"] for c in children] + assert paths == ["gov.dwp", "gov.hmrc"] + for child in children: + assert child["type"] == "node" + assert child["child_count"] > 0 + + def test_returns_leaf_parameters(self, client, session, uk_version): + """Direct child parameters are returned with type='parameter'.""" + _add_params( + session, + uk_version, + [ + ("gov.benefit_uprating_cpi", "Benefit uprating CPI"), + ("gov.hmrc.income_tax.rate", "Basic rate"), + ], + ) + + response = client.get( + "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + ) + + assert response.status_code == 200 + children = response.json()["children"] + assert len(children) == 2 + + leaf = next(c for c in children if c["type"] == "parameter") + assert leaf["path"] == "gov.benefit_uprating_cpi" + assert leaf["label"] == "Benefit uprating CPI" + assert leaf["parameter"] is not None + assert leaf["parameter"]["name"] 
== "gov.benefit_uprating_cpi" + + node = next(c for c in children if c["type"] == "node") + assert node["path"] == "gov.hmrc" + + def test_mixed_nodes_and_leaves(self, client, session, uk_version): + """Both nodes and leaf parameters can appear at the same level.""" + _add_params( + session, + uk_version, + [ + ("gov.hmrc.tax.rate", "Rate"), + ("gov.flat_rate", "Flat rate"), + ("gov.threshold", "Threshold"), + ], + ) + + response = client.get( + "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + ) + + children = response.json()["children"] + types = {c["path"]: c["type"] for c in children} + assert types["gov.hmrc"] == "node" + assert types["gov.flat_rate"] == "parameter" + assert types["gov.threshold"] == "parameter" + + +class TestChildCount: + """Tests for child_count accuracy.""" + + def test_child_count_reflects_total_descendants(self, client, session, uk_version): + """child_count counts all leaf parameters under the node.""" + _add_params( + session, + uk_version, + [ + ("gov.hmrc.income_tax.rate", "Rate"), + ("gov.hmrc.income_tax.threshold", "Threshold"), + ("gov.hmrc.ni.rate", "NI rate"), + ], + ) + + response = client.get( + "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + ) + + children = response.json()["children"] + hmrc = children[0] + assert hmrc["path"] == "gov.hmrc" + assert hmrc["child_count"] == 3 + + def test_nested_child_count(self, client, session, uk_version): + """Querying a deeper level gives accurate child counts.""" + _add_params( + session, + uk_version, + [ + ("gov.hmrc.income_tax.rate", "Rate"), + ("gov.hmrc.income_tax.threshold", "Threshold"), + ("gov.hmrc.ni.rate", "NI rate"), + ], + ) + + response = client.get( + "/parameters/children", + params={"country_id": "uk", "parent_path": "gov.hmrc"}, + ) + + children = response.json()["children"] + assert len(children) == 2 + income_tax = next(c for c in children if c["path"] == "gov.hmrc.income_tax") + ni = next(c for c in children if 
c["path"] == "gov.hmrc.ni") + assert income_tax["child_count"] == 2 + assert ni["child_count"] == 1 + + def test_leaf_has_no_child_count(self, client, session, uk_version): + """Leaf parameters have child_count=None.""" + _add_params(session, uk_version, [("gov.rate", "Rate")]) + + response = client.get( + "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + ) + + children = response.json()["children"] + assert len(children) == 1 + assert children[0]["child_count"] is None + + +class TestCountryFiltering: + """Tests for country_id filtering.""" + + def test_uk_country_id(self, client, session, uk_version): + """country_id=uk returns UK parameters.""" + _add_params(session, uk_version, [("gov.hmrc.rate", "Rate")]) + + response = client.get( + "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + ) + + assert response.status_code == 200 + assert len(response.json()["children"]) == 1 + + def test_us_country_id(self, client, session, us_version): + """country_id=us returns US parameters.""" + _add_params(session, us_version, [("gov.irs.rate", "Rate")]) + + response = client.get( + "/parameters/children", params={"country_id": "us", "parent_path": "gov"} + ) + + assert response.status_code == 200 + assert len(response.json()["children"]) == 1 + + def test_country_isolation(self, client, session, uk_version, us_version): + """Parameters from a different country are excluded.""" + _add_params(session, uk_version, [("gov.hmrc.rate", "UK rate")]) + _add_params(session, us_version, [("gov.irs.rate", "US rate")]) + + uk_response = client.get( + "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + ) + us_response = client.get( + "/parameters/children", params={"country_id": "us", "parent_path": "gov"} + ) + + uk_paths = [c["path"] for c in uk_response.json()["children"]] + us_paths = [c["path"] for c in us_response.json()["children"]] + assert uk_paths == ["gov.hmrc"] + assert us_paths == ["gov.irs"] + + def 
test_invalid_country_id_returns_422(self, client): + """An invalid country_id is rejected by Literal validation.""" + response = client.get( + "/parameters/children", + params={"country_id": "fr", "parent_path": "gov"}, + ) + + assert response.status_code == 422 + + +class TestEdgeCases: + """Tests for edge cases and special inputs.""" + + def test_empty_parent_path(self, client, session, uk_version): + """Empty parent_path returns top-level children.""" + _add_params(session, uk_version, [("gov.hmrc.rate", "Rate")]) + + response = client.get( + "/parameters/children", params={"country_id": "uk", "parent_path": ""} + ) + + assert response.status_code == 200 + children = response.json()["children"] + assert len(children) == 1 + assert children[0]["path"] == "gov" + assert children[0]["type"] == "node" + + def test_nonexistent_parent_returns_empty(self, client, session, uk_version): + """A parent path with no descendants returns empty children list.""" + _add_params(session, uk_version, [("gov.hmrc.rate", "Rate")]) + + response = client.get( + "/parameters/children", + params={"country_id": "uk", "parent_path": "gov.dwp"}, + ) + + assert response.status_code == 200 + assert response.json()["children"] == [] + + def test_children_sorted_by_path(self, client, session, uk_version): + """Children are returned sorted alphabetically by path.""" + _add_params( + session, + uk_version, + [ + ("gov.zzz.param", "Z param"), + ("gov.aaa.param", "A param"), + ("gov.mmm.param", "M param"), + ], + ) + + response = client.get( + "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + ) + + paths = [c["path"] for c in response.json()["children"]] + assert paths == ["gov.aaa", "gov.mmm", "gov.zzz"] + + def test_node_label_from_path_segment(self, client, session, uk_version): + """Node labels default to the last path segment when no parameter exists.""" + _add_params(session, uk_version, [("gov.hmrc.income_tax.rate", "Rate")]) + + response = client.get( + 
"/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + ) + + children = response.json()["children"] + assert children[0]["label"] == "hmrc" + + def test_missing_country_id_returns_422(self, client): + """Request without country_id returns 422.""" + response = client.get("/parameters/children", params={"parent_path": "gov"}) + + assert response.status_code == 422 + + def test_default_parent_path_is_empty(self, client, session, uk_version): + """Omitting parent_path defaults to empty string (root level).""" + _add_params(session, uk_version, [("gov.hmrc.rate", "Rate")]) + + response = client.get("/parameters/children", params={"country_id": "uk"}) + + assert response.status_code == 200 + assert response.json()["parent_path"] == "" + assert len(response.json()["children"]) == 1 + + def test_leaf_parameter_includes_full_metadata(self, client, session, uk_version): + """Leaf parameters include the full ParameterRead shape.""" + _add_params(session, uk_version, [("gov.rate", "The rate")]) + + response = client.get( + "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + ) + + param = response.json()["children"][0]["parameter"] + assert param["name"] == "gov.rate" + assert param["label"] == "The rate" + assert "id" in param + assert "created_at" in param + assert "tax_benefit_model_version_id" in param + + def test_node_has_no_parameter_field(self, client, session, uk_version): + """Nodes do not include the parameter field.""" + _add_params(session, uk_version, [("gov.hmrc.rate", "Rate")]) + + response = client.get( + "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + ) + + node = response.json()["children"][0] + assert node["type"] == "node" + assert node["parameter"] is None + + def test_deep_nesting(self, client, session, uk_version): + """Works correctly with deeply nested parameter paths.""" + _add_params( + session, + uk_version, + [("gov.hmrc.income_tax.rates.uk[0].rate", "Basic rate")], + ) + + # 
Each level should show the correct child + for parent, expected_child in [ + ("gov", "gov.hmrc"), + ("gov.hmrc", "gov.hmrc.income_tax"), + ("gov.hmrc.income_tax", "gov.hmrc.income_tax.rates"), + ("gov.hmrc.income_tax.rates", "gov.hmrc.income_tax.rates.uk[0]"), + ]: + resp = client.get( + "/parameters/children", + params={"country_id": "uk", "parent_path": parent}, + ) + children = resp.json()["children"] + assert len(children) == 1 + assert children[0]["path"] == expected_child + assert children[0]["type"] == "node" + + # Final level should be a leaf + resp = client.get( + "/parameters/children", + params={ + "country_id": "uk", + "parent_path": "gov.hmrc.income_tax.rates.uk[0]", + }, + ) + children = resp.json()["children"] + assert len(children) == 1 + assert children[0]["type"] == "parameter" + assert children[0]["path"] == "gov.hmrc.income_tax.rates.uk[0].rate" diff --git a/tests/test_policies.py b/tests/test_policies.py index f48730b..b4ac25f 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -12,25 +12,46 @@ def test_list_policies_empty(client): assert response.json() == [] -def test_create_policy(client): +def test_create_policy(client, tax_benefit_model): """Create a new policy.""" response = client.post( "/policies", json={ "name": "Test policy", "description": "A test policy", + "tax_benefit_model_id": str(tax_benefit_model.id), }, ) assert response.status_code == 200 data = response.json() assert data["name"] == "Test policy" assert data["description"] == "A test policy" + assert data["tax_benefit_model_id"] == str(tax_benefit_model.id) assert "id" in data -def test_list_policies_with_data(client, session): +def test_create_policy_invalid_tax_benefit_model(client): + """Create policy with non-existent tax_benefit_model returns 404.""" + fake_id = uuid4() + response = client.post( + "/policies", + json={ + "name": "Test policy", + "description": "A test policy", + "tax_benefit_model_id": str(fake_id), + }, + ) + assert 
response.status_code == 404 + assert response.json()["detail"] == "Tax benefit model not found" + + +def test_list_policies_with_data(client, session, tax_benefit_model): """List policies returns all policies.""" - policy = Policy(name="test-policy", description="Test") + policy = Policy( + name="test-policy", + description="Test", + tax_benefit_model_id=tax_benefit_model.id, + ) session.add(policy) session.commit() @@ -41,9 +62,39 @@ def test_list_policies_with_data(client, session): assert data[0]["name"] == "test-policy" -def test_get_policy(client, session): +def test_list_policies_filter_by_tax_benefit_model( + client, session, tax_benefit_model, uk_tax_benefit_model +): + """List policies with tax_benefit_model_id filter.""" + policy1 = Policy( + name="US policy", + description="US", + tax_benefit_model_id=tax_benefit_model.id, + ) + policy2 = Policy( + name="UK policy", + description="UK", + tax_benefit_model_id=uk_tax_benefit_model.id, + ) + session.add(policy1) + session.add(policy2) + session.commit() + + # Filter by US model + response = client.get(f"/policies?tax_benefit_model_id={tax_benefit_model.id}") + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["name"] == "US policy" + + +def test_get_policy(client, session, tax_benefit_model): """Get a specific policy by ID.""" - policy = Policy(name="test-policy", description="Test") + policy = Policy( + name="test-policy", + description="Test", + tax_benefit_model_id=tax_benefit_model.id, + ) session.add(policy) session.commit() session.refresh(policy) diff --git a/tests/test_simulations_standalone.py b/tests/test_simulations_standalone.py new file mode 100644 index 0000000..c900430 --- /dev/null +++ b/tests/test_simulations_standalone.py @@ -0,0 +1,309 @@ +"""Tests for standalone simulation endpoints (/simulations/household, /simulations/economy).""" + +from uuid import uuid4 + +from test_fixtures.fixtures_simulations_standalone import ( + create_dataset, 
+ create_economy_simulation, + create_household, + create_household_simulation, + create_policy, + create_region, + create_us_model_and_version, +) + +# =========================================================================== +# POST /simulations/household +# =========================================================================== + + +def test_create_household_simulation(client, session): + """Create a household simulation returns 200 with pending status.""" + model, version = create_us_model_and_version(session) + household = create_household(session) + + payload = {"household_id": str(household.id)} + response = client.post("/simulations/household", json=payload) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "pending" + assert data["household_id"] == str(household.id) + assert data["household_result"] is None + assert data["policy_id"] is None + + +def test_create_household_simulation_with_policy(client, session): + """Create a household simulation with a reform policy.""" + model, version = create_us_model_and_version(session) + household = create_household(session) + policy = create_policy(session, model) + + payload = { + "household_id": str(household.id), + "policy_id": str(policy.id), + } + response = client.post("/simulations/household", json=payload) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "pending" + assert data["policy_id"] == str(policy.id) + + +def test_create_household_simulation_not_found(client, session): + """Creating with a non-existent household returns 404.""" + model, version = create_us_model_and_version(session) + payload = {"household_id": str(uuid4())} + response = client.post("/simulations/household", json=payload) + + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +def test_create_household_simulation_policy_not_found(client, session): + """Creating with a non-existent policy returns 404.""" + 
model, version = create_us_model_and_version(session) + household = create_household(session) + + payload = { + "household_id": str(household.id), + "policy_id": str(uuid4()), + } + response = client.post("/simulations/household", json=payload) + + assert response.status_code == 404 + assert "Policy" in response.json()["detail"] + + +def test_household_simulation_deduplication(client, session): + """Same inputs produce the same simulation (deterministic UUID).""" + model, version = create_us_model_and_version(session) + household = create_household(session) + + payload = {"household_id": str(household.id)} + response1 = client.post("/simulations/household", json=payload) + response2 = client.post("/simulations/household", json=payload) + + assert response1.status_code == 200 + assert response2.status_code == 200 + assert response1.json()["id"] == response2.json()["id"] + + +# =========================================================================== +# GET /simulations/household/{id} +# =========================================================================== + + +def test_get_household_simulation(client, session): + """Get a household simulation by ID.""" + model, version = create_us_model_and_version(session) + household = create_household(session) + simulation = create_household_simulation(session, version, household) + + response = client.get(f"/simulations/household/{simulation.id}") + + assert response.status_code == 200 + data = response.json() + assert data["id"] == str(simulation.id) + assert data["status"] == "completed" + assert data["household_result"] is not None + + +def test_get_household_simulation_not_found(client, session): + """Get a non-existent household simulation returns 404.""" + response = client.get(f"/simulations/household/{uuid4()}") + assert response.status_code == 404 + + +def test_get_household_simulation_wrong_type(client, session): + """Get an economy simulation via the household endpoint returns 400.""" + model, version = 
create_us_model_and_version(session) + dataset = create_dataset(session, model) + economy_sim = create_economy_simulation(session, version, dataset) + + response = client.get(f"/simulations/household/{economy_sim.id}") + assert response.status_code == 400 + assert "not a household simulation" in response.json()["detail"] + + +# =========================================================================== +# POST /simulations/economy +# =========================================================================== + + +def test_create_economy_simulation_with_region(client, session): + """Create an economy simulation using a region code.""" + model, version = create_us_model_and_version(session) + dataset = create_dataset(session, model) + region = create_region(session, model, dataset, code="us", label="United States") + + payload = { + "tax_benefit_model_name": "policyengine_us", + "region": "us", + } + response = client.post("/simulations/economy", json=payload) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "pending" + assert data["dataset_id"] == str(dataset.id) + assert data["region"]["code"] == "us" + assert data["region"]["label"] == "United States" + + +def test_create_economy_simulation_with_dataset(client, session): + """Create an economy simulation using a dataset_id directly.""" + model, version = create_us_model_and_version(session) + dataset = create_dataset(session, model) + + payload = { + "tax_benefit_model_name": "policyengine_us", + "dataset_id": str(dataset.id), + } + response = client.post("/simulations/economy", json=payload) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "pending" + assert data["dataset_id"] == str(dataset.id) + assert data["region"] is None + + +def test_create_economy_simulation_with_region_filter(client, session): + """Create an economy simulation with a region that requires filtering.""" + model, version = 
create_us_model_and_version(session) + dataset = create_dataset(session, model) + region = create_region( + session, + model, + dataset, + code="state/ca", + label="California", + region_type="state", + requires_filter=True, + filter_field="state_code", + filter_value="CA", + ) + + payload = { + "tax_benefit_model_name": "policyengine_us", + "region": "state/ca", + } + response = client.post("/simulations/economy", json=payload) + + assert response.status_code == 200 + data = response.json() + assert data["filter_field"] == "state_code" + assert data["filter_value"] == "CA" + assert data["region"]["requires_filter"] is True + + +def test_create_economy_simulation_invalid_region(client, session): + """Creating with a non-existent region returns 404.""" + model, version = create_us_model_and_version(session) + + payload = { + "tax_benefit_model_name": "policyengine_us", + "region": "nonexistent/region", + } + response = client.post("/simulations/economy", json=payload) + + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +def test_create_economy_simulation_no_region_or_dataset(client, session): + """Creating without region or dataset_id returns 422 (Pydantic validation).""" + model, version = create_us_model_and_version(session) + + payload = {"tax_benefit_model_name": "policyengine_us"} + response = client.post("/simulations/economy", json=payload) + + assert response.status_code == 422 + + +def test_create_economy_simulation_policy_not_found(client, session): + """Creating with a non-existent policy returns 404.""" + model, version = create_us_model_and_version(session) + dataset = create_dataset(session, model) + + payload = { + "tax_benefit_model_name": "policyengine_us", + "dataset_id": str(dataset.id), + "policy_id": str(uuid4()), + } + response = client.post("/simulations/economy", json=payload) + + assert response.status_code == 404 + assert "Policy" in response.json()["detail"] + + +def 
test_economy_simulation_deduplication(client, session): + """Same inputs produce the same simulation (deterministic UUID).""" + model, version = create_us_model_and_version(session) + dataset = create_dataset(session, model) + + payload = { + "tax_benefit_model_name": "policyengine_us", + "dataset_id": str(dataset.id), + } + response1 = client.post("/simulations/economy", json=payload) + response2 = client.post("/simulations/economy", json=payload) + + assert response1.status_code == 200 + assert response2.status_code == 200 + assert response1.json()["id"] == response2.json()["id"] + + +# =========================================================================== +# GET /simulations/economy/{id} +# =========================================================================== + + +def test_get_economy_simulation(client, session): + """Get an economy simulation by ID.""" + model, version = create_us_model_and_version(session) + dataset = create_dataset(session, model) + simulation = create_economy_simulation(session, version, dataset) + + response = client.get(f"/simulations/economy/{simulation.id}") + + assert response.status_code == 200 + data = response.json() + assert data["id"] == str(simulation.id) + assert data["status"] == "completed" + + +def test_get_economy_simulation_not_found(client, session): + """Get a non-existent economy simulation returns 404.""" + response = client.get(f"/simulations/economy/{uuid4()}") + assert response.status_code == 404 + + +def test_get_economy_simulation_wrong_type(client, session): + """Get a household simulation via the economy endpoint returns 400.""" + model, version = create_us_model_and_version(session) + household = create_household(session) + household_sim = create_household_simulation(session, version, household) + + response = client.get(f"/simulations/economy/{household_sim.id}") + assert response.status_code == 400 + assert "not an economy simulation" in response.json()["detail"] + + +# 
=========================================================================== +# Generic GET /simulations/{id} still works +# =========================================================================== + + +def test_get_simulation_generic(client, session): + """The generic GET /simulations/{id} endpoint still works for any type.""" + model, version = create_us_model_and_version(session) + household = create_household(session) + simulation = create_household_simulation(session, version, household) + + response = client.get(f"/simulations/{simulation.id}") + + assert response.status_code == 200 + assert response.json()["id"] == str(simulation.id) diff --git a/tests/test_user_household_associations.py b/tests/test_user_household_associations.py new file mode 100644 index 0000000..9da0a3b --- /dev/null +++ b/tests/test_user_household_associations.py @@ -0,0 +1,228 @@ +"""Tests for user-household association endpoints.""" + +from uuid import uuid4 + +from test_fixtures.fixtures_user_household_associations import ( + create_association, + create_household, + create_user, +) + +# --------------------------------------------------------------------------- +# POST /user-household-associations +# --------------------------------------------------------------------------- + + +def test_create_association(client, session): + """Create an association returns 201 with id and timestamps.""" + user = create_user(session) + household = create_household(session) + payload = { + "user_id": str(user.id), + "household_id": str(household.id), + "country_id": "us", + "label": "My US household", + } + response = client.post("/user-household-associations", json=payload) + assert response.status_code == 201 + data = response.json() + assert "id" in data + assert "created_at" in data + assert "updated_at" in data + assert data["user_id"] == str(user.id) + assert data["household_id"] == str(household.id) + assert data["country_id"] == "us" + assert data["label"] == "My US household" + + +def 
test_create_association_allows_duplicates(client, session): + """Multiple associations to the same household are allowed.""" + user = create_user(session) + household = create_household(session) + payload = { + "user_id": str(user.id), + "household_id": str(household.id), + "country_id": "us", + "label": "First label", + } + r1 = client.post("/user-household-associations", json=payload) + assert r1.status_code == 201 + + payload["label"] = "Second label" + r2 = client.post("/user-household-associations", json=payload) + assert r2.status_code == 201 + assert r1.json()["id"] != r2.json()["id"] + + +def test_create_association_household_not_found(client, session): + """Creating with a non-existent household returns 404.""" + user = create_user(session) + payload = { + "user_id": str(user.id), + "household_id": str(uuid4()), + "country_id": "us", + } + response = client.post("/user-household-associations", json=payload) + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +# --------------------------------------------------------------------------- +# GET /user-household-associations/user/{user_id} +# --------------------------------------------------------------------------- + + +def test_list_by_user_empty(client): + """List associations for a user with none returns empty list.""" + response = client.get(f"/user-household-associations/user/{uuid4()}") + assert response.status_code == 200 + assert response.json() == [] + + +def test_list_by_user(client, session): + """List all associations for a user.""" + user = create_user(session) + h1 = create_household(session, label="H1") + h2 = create_household(session, label="H2") + create_association(session, user.id, h1.id, label="First") + create_association(session, user.id, h2.id, label="Second") + + response = client.get(f"/user-household-associations/user/{user.id}") + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + + +def 
test_list_by_user_filter_country(client, session): + """Filter associations by country_id.""" + user = create_user(session) + household = create_household(session) + create_association(session, user.id, household.id, country_id="us") + create_association(session, user.id, household.id, country_id="uk") + + response = client.get( + f"/user-household-associations/user/{user.id}", + params={"country_id": "uk"}, + ) + data = response.json() + assert len(data) == 1 + assert data[0]["country_id"] == "uk" + + +# --------------------------------------------------------------------------- +# GET /user-household-associations/{user_id}/{household_id} +# --------------------------------------------------------------------------- + + +def test_list_by_user_and_household(client, session): + """List associations for a specific user+household pair.""" + user = create_user(session) + household = create_household(session) + create_association(session, user.id, household.id, label="Label A") + create_association(session, user.id, household.id, label="Label B") + + response = client.get(f"/user-household-associations/{user.id}/{household.id}") + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + + +def test_list_by_user_and_household_empty(client): + """Returns empty list when no associations exist for the pair.""" + response = client.get(f"/user-household-associations/{uuid4()}/{uuid4()}") + assert response.status_code == 200 + assert response.json() == [] + + +# --------------------------------------------------------------------------- +# PUT /user-household-associations/{association_id} +# --------------------------------------------------------------------------- + + +def test_update_association_label(client, session): + """Update label and verify updated_at changes.""" + user = create_user(session) + household = create_household(session) + assoc = create_association(session, user.id, household.id, label="Old") + + response = client.put( + 
f"/user-household-associations/{assoc.id}", + params={"user_id": str(user.id)}, + json={"label": "New label"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["label"] == "New label" + + +def test_update_association_wrong_user(client, session): + """Update with wrong user_id returns 404 (ownership check).""" + user = create_user(session) + household = create_household(session) + assoc = create_association(session, user.id, household.id, label="Old") + + response = client.put( + f"/user-household-associations/{assoc.id}", + params={"user_id": str(uuid4())}, + json={"label": "Hacked"}, + ) + assert response.status_code == 404 + + +def test_update_association_not_found(client): + """Update a non-existent association returns 404.""" + response = client.put( + f"/user-household-associations/{uuid4()}", + params={"user_id": str(uuid4())}, + json={"label": "Something"}, + ) + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +# --------------------------------------------------------------------------- +# DELETE /user-household-associations/{association_id} +# --------------------------------------------------------------------------- + + +def test_delete_association(client, session): + """Delete an association returns 204.""" + user = create_user(session) + household = create_household(session) + assoc = create_association(session, user.id, household.id) + + response = client.delete( + f"/user-household-associations/{assoc.id}", + params={"user_id": str(user.id)}, + ) + assert response.status_code == 204 + + # Confirm it's gone + response = client.get(f"/user-household-associations/{user.id}/{household.id}") + assert response.json() == [] + + +def test_delete_association_wrong_user(client, session): + """Delete with wrong user_id returns 404 (ownership check).""" + user = create_user(session) + household = create_household(session) + assoc = create_association(session, user.id, household.id) + + 
response = client.delete( + f"/user-household-associations/{assoc.id}", + params={"user_id": str(uuid4())}, + ) + assert response.status_code == 404 + + # Confirm it's still there + response = client.get(f"/user-household-associations/{user.id}/{household.id}") + assert len(response.json()) == 1 + + +def test_delete_association_not_found(client): + """Delete a non-existent association returns 404.""" + response = client.delete( + f"/user-household-associations/{uuid4()}", + params={"user_id": str(uuid4())}, + ) + assert response.status_code == 404 diff --git a/tests/test_user_policies.py b/tests/test_user_policies.py new file mode 100644 index 0000000..3bf7312 --- /dev/null +++ b/tests/test_user_policies.py @@ -0,0 +1,282 @@ +"""Tests for user-policy association endpoints. + +Note: user_id is a client-generated UUID (not validated against users table), +so tests use uuid4() directly rather than creating User records. +""" + +from uuid import uuid4 + +from test_fixtures.fixtures_user_policies import ( + UK_COUNTRY_ID, + US_COUNTRY_ID, + create_policy, + create_user_policy, +) + + +def test_list_user_policies_empty(client): + """List user policies returns empty list when user has no associations.""" + user_id = uuid4() + response = client.get(f"/user-policies?user_id={user_id}") + assert response.status_code == 200 + assert response.json() == [] + + +def test_create_user_policy(client, session, tax_benefit_model): + """Create a new user-policy association.""" + user_id = uuid4() + policy = create_policy(session, tax_benefit_model) + + response = client.post( + "/user-policies", + json={ + "user_id": str(user_id), + "policy_id": str(policy.id), + "country_id": US_COUNTRY_ID, + "label": "My test policy", + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["user_id"] == str(user_id) + assert data["policy_id"] == str(policy.id) + assert data["country_id"] == US_COUNTRY_ID + assert data["label"] == "My test policy" + assert "id" in data + 
assert "created_at" in data + assert "updated_at" in data + + +def test_create_user_policy_without_label(client, session, tax_benefit_model): + """Create a user-policy association without a label.""" + user_id = uuid4() + policy = create_policy(session, tax_benefit_model) + + response = client.post( + "/user-policies", + json={ + "user_id": str(user_id), + "policy_id": str(policy.id), + "country_id": US_COUNTRY_ID, + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["label"] is None + assert data["country_id"] == US_COUNTRY_ID + + +def test_create_user_policy_policy_not_found(client): + """Create user-policy association with non-existent policy returns 404.""" + user_id = uuid4() + fake_policy_id = uuid4() + + response = client.post( + "/user-policies", + json={ + "user_id": str(user_id), + "policy_id": str(fake_policy_id), + "country_id": US_COUNTRY_ID, + }, + ) + assert response.status_code == 404 + assert response.json()["detail"] == "Policy not found" + + +def test_create_user_policy_duplicate_allowed(client, session, tax_benefit_model): + """Creating duplicate user-policy association is allowed (matches FE localStorage behavior).""" + user_id = uuid4() + policy = create_policy(session, tax_benefit_model) + user_policy = create_user_policy(session, user_id, policy, country_id=US_COUNTRY_ID) + + # Create duplicate - should succeed with a new ID + response = client.post( + "/user-policies", + json={ + "user_id": str(user_id), + "policy_id": str(policy.id), + "country_id": US_COUNTRY_ID, + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["id"] != str(user_policy.id) # New association created + assert data["user_id"] == str(user_id) + assert data["policy_id"] == str(policy.id) + + +def test_list_user_policies_with_data( + client, session, tax_benefit_model, uk_tax_benefit_model +): + """List user policies returns all associations for a user.""" + user_id = uuid4() + policy1 = create_policy( + 
session, tax_benefit_model, name="Policy 1", description="First policy" + ) + policy2 = create_policy( + session, uk_tax_benefit_model, name="Policy 2", description="Second policy" + ) + create_user_policy( + session, user_id, policy1, country_id=US_COUNTRY_ID, label="US policy" + ) + create_user_policy( + session, user_id, policy2, country_id=UK_COUNTRY_ID, label="UK policy" + ) + + response = client.get(f"/user-policies?user_id={user_id}") + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + + +def test_list_user_policies_filter_by_country( + client, session, tax_benefit_model, uk_tax_benefit_model +): + """List user policies filtered by country_id.""" + user_id = uuid4() + policy1 = create_policy( + session, tax_benefit_model, name="Policy 1", description="First policy" + ) + policy2 = create_policy( + session, uk_tax_benefit_model, name="Policy 2", description="Second policy" + ) + create_user_policy(session, user_id, policy1, country_id=US_COUNTRY_ID) + create_user_policy(session, user_id, policy2, country_id=UK_COUNTRY_ID) + + response = client.get( + f"/user-policies?user_id={user_id}&country_id={US_COUNTRY_ID}" + ) + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["policy_id"] == str(policy1.id) + assert data[0]["country_id"] == US_COUNTRY_ID + + +def test_get_user_policy(client, session, tax_benefit_model): + """Get a specific user-policy association by ID.""" + user_id = uuid4() + policy = create_policy(session, tax_benefit_model) + user_policy = create_user_policy( + session, user_id, policy, country_id=US_COUNTRY_ID, label="My policy" + ) + + response = client.get(f"/user-policies/{user_policy.id}") + assert response.status_code == 200 + data = response.json() + assert data["id"] == str(user_policy.id) + assert data["label"] == "My policy" + assert data["country_id"] == US_COUNTRY_ID + + +def test_get_user_policy_not_found(client): + """Get a non-existent user-policy 
association returns 404.""" + fake_id = uuid4() + response = client.get(f"/user-policies/{fake_id}") + assert response.status_code == 404 + assert response.json()["detail"] == "User-policy association not found" + + +def test_update_user_policy(client, session, tax_benefit_model): + """Update a user-policy association label.""" + user_id = uuid4() + policy = create_policy(session, tax_benefit_model) + user_policy = create_user_policy( + session, user_id, policy, country_id=US_COUNTRY_ID, label="Old label" + ) + + response = client.patch( + f"/user-policies/{user_policy.id}?user_id={user_id}", + json={"label": "New label"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["label"] == "New label" + assert data["country_id"] == US_COUNTRY_ID + + +def test_update_user_policy_not_found(client): + """Update a non-existent user-policy association returns 404.""" + fake_id = uuid4() + fake_user_id = uuid4() + response = client.patch( + f"/user-policies/{fake_id}?user_id={fake_user_id}", + json={"label": "New label"}, + ) + assert response.status_code == 404 + assert response.json()["detail"] == "User-policy association not found" + + +def test_update_user_policy_wrong_user(client, session, tax_benefit_model): + """Update with wrong user_id returns 404 (ownership check).""" + user_id = uuid4() + wrong_user_id = uuid4() + policy = create_policy(session, tax_benefit_model) + user_policy = create_user_policy( + session, user_id, policy, country_id=US_COUNTRY_ID, label="Original label" + ) + + # Try to update with wrong user_id + response = client.patch( + f"/user-policies/{user_policy.id}?user_id={wrong_user_id}", + json={"label": "Hacked label"}, + ) + assert response.status_code == 404 + + # Verify original label unchanged + response = client.get(f"/user-policies/{user_policy.id}") + assert response.json()["label"] == "Original label" + + +def test_update_user_policy_rejects_extra_fields(client, session, tax_benefit_model): + """Update with 
extra fields returns 422 (extra='forbid').""" + user_id = uuid4() + policy = create_policy(session, tax_benefit_model) + user_policy = create_user_policy( + session, user_id, policy, country_id=US_COUNTRY_ID, label="Original" + ) + + response = client.patch( + f"/user-policies/{user_policy.id}?user_id={user_id}", + json={"label": "New", "user_id": str(uuid4())}, + ) + assert response.status_code == 422 + + +def test_delete_user_policy(client, session, tax_benefit_model): + """Delete a user-policy association.""" + user_id = uuid4() + policy = create_policy(session, tax_benefit_model) + user_policy = create_user_policy(session, user_id, policy, country_id=US_COUNTRY_ID) + + response = client.delete(f"/user-policies/{user_policy.id}?user_id={user_id}") + assert response.status_code == 204 + + # Verify it's deleted + response = client.get(f"/user-policies/{user_policy.id}") + assert response.status_code == 404 + + +def test_delete_user_policy_not_found(client): + """Delete a non-existent user-policy association returns 404.""" + fake_id = uuid4() + fake_user_id = uuid4() + response = client.delete(f"/user-policies/{fake_id}?user_id={fake_user_id}") + assert response.status_code == 404 + assert response.json()["detail"] == "User-policy association not found" + + +def test_delete_user_policy_wrong_user(client, session, tax_benefit_model): + """Delete with wrong user_id returns 404 (ownership check).""" + user_id = uuid4() + wrong_user_id = uuid4() + policy = create_policy(session, tax_benefit_model) + user_policy = create_user_policy(session, user_id, policy, country_id=US_COUNTRY_ID) + + # Try to delete with wrong user_id + response = client.delete(f"/user-policies/{user_policy.id}?user_id={wrong_user_id}") + assert response.status_code == 404 + + # Verify it still exists + response = client.get(f"/user-policies/{user_policy.id}") + assert response.status_code == 200 diff --git a/tests/test_user_report_associations.py b/tests/test_user_report_associations.py new file 
mode 100644 index 0000000..6bfb8a4 --- /dev/null +++ b/tests/test_user_report_associations.py @@ -0,0 +1,244 @@ +"""Tests for user-report association endpoints.""" + +from datetime import datetime, timezone +from uuid import uuid4 + +from test_fixtures.fixtures_user_report_associations import ( + create_report, + create_user_report_association, +) + +# --------------------------------------------------------------------------- +# POST /user-reports +# --------------------------------------------------------------------------- + + +def test_create_association(client, session): + """Create an association returns 200 with id and timestamps.""" + user_id = uuid4() + report = create_report(session) + payload = { + "user_id": str(user_id), + "report_id": str(report.id), + "country_id": "us", + "label": "My US report", + } + response = client.post("/user-reports/", json=payload) + assert response.status_code == 200 + data = response.json() + assert "id" in data + assert "created_at" in data + assert "updated_at" in data + assert data["user_id"] == str(user_id) + assert data["report_id"] == str(report.id) + assert data["country_id"] == "us" + assert data["label"] == "My US report" + assert data["last_run_at"] is None + + +def test_create_association_with_last_run_at(client, session): + """Create an association with last_run_at set.""" + user_id = uuid4() + report = create_report(session) + now = datetime.now(timezone.utc).isoformat() + payload = { + "user_id": str(user_id), + "report_id": str(report.id), + "country_id": "us", + "last_run_at": now, + } + response = client.post("/user-reports/", json=payload) + assert response.status_code == 200 + assert response.json()["last_run_at"] is not None + + +def test_create_association_report_not_found(client): + """Creating with a non-existent report returns 404.""" + payload = { + "user_id": str(uuid4()), + "report_id": str(uuid4()), + "country_id": "us", + } + response = client.post("/user-reports/", json=payload) + assert 
response.status_code == 404 + assert "not found" in response.json()["detail"] + + +def test_create_association_invalid_country(client, session): + """Creating with an invalid country_id returns 422.""" + report = create_report(session) + payload = { + "user_id": str(uuid4()), + "report_id": str(report.id), + "country_id": "invalid", + } + response = client.post("/user-reports/", json=payload) + assert response.status_code == 422 + + +# --------------------------------------------------------------------------- +# GET /user-reports/?user_id=... +# --------------------------------------------------------------------------- + + +def test_list_by_user_empty(client): + """List associations for a user with none returns empty list.""" + response = client.get("/user-reports/", params={"user_id": str(uuid4())}) + assert response.status_code == 200 + assert response.json() == [] + + +def test_list_by_user(client, session): + """List all associations for a user.""" + user_id = uuid4() + r1 = create_report(session) + r2 = create_report(session) + create_user_report_association(session, user_id, r1, label="First") + create_user_report_association(session, user_id, r2, label="Second") + + response = client.get("/user-reports/", params={"user_id": str(user_id)}) + assert response.status_code == 200 + assert len(response.json()) == 2 + + +def test_list_by_user_filter_country(client, session): + """Filter associations by country_id.""" + user_id = uuid4() + report = create_report(session) + create_user_report_association(session, user_id, report, country_id="us") + create_user_report_association(session, user_id, report, country_id="uk") + + response = client.get( + "/user-reports/", + params={"user_id": str(user_id), "country_id": "uk"}, + ) + data = response.json() + assert len(data) == 1 + assert data[0]["country_id"] == "uk" + + +# --------------------------------------------------------------------------- +# GET /user-reports/{id} +# 
--------------------------------------------------------------------------- + + +def test_get_by_id(client, session): + """Get a specific association by ID.""" + user_id = uuid4() + report = create_report(session) + assoc = create_user_report_association(session, user_id, report, label="Test") + + response = client.get(f"/user-reports/{assoc.id}") + assert response.status_code == 200 + assert response.json()["id"] == str(assoc.id) + assert response.json()["label"] == "Test" + + +def test_get_by_id_not_found(client): + """Get a non-existent association returns 404.""" + response = client.get(f"/user-reports/{uuid4()}") + assert response.status_code == 404 + + +# --------------------------------------------------------------------------- +# PATCH /user-reports/{id}?user_id=... +# --------------------------------------------------------------------------- + + +def test_update_label(client, session): + """Update label via PATCH.""" + user_id = uuid4() + report = create_report(session) + assoc = create_user_report_association(session, user_id, report, label="Old") + + response = client.patch( + f"/user-reports/{assoc.id}", + json={"label": "New label"}, + params={"user_id": str(user_id)}, + ) + assert response.status_code == 200 + assert response.json()["label"] == "New label" + + +def test_update_last_run_at(client, session): + """Update last_run_at via PATCH.""" + user_id = uuid4() + report = create_report(session) + assoc = create_user_report_association(session, user_id, report) + + now = datetime.now(timezone.utc).isoformat() + response = client.patch( + f"/user-reports/{assoc.id}", + json={"last_run_at": now}, + params={"user_id": str(user_id)}, + ) + assert response.status_code == 200 + assert response.json()["last_run_at"] is not None + + +def test_update_wrong_user(client, session): + """Update with wrong user_id returns 404.""" + user_id = uuid4() + report = create_report(session) + assoc = create_user_report_association(session, user_id, report, label="Mine") 
+ + response = client.patch( + f"/user-reports/{assoc.id}", + json={"label": "Stolen"}, + params={"user_id": str(uuid4())}, + ) + assert response.status_code == 404 + + +def test_update_not_found(client): + """Update a non-existent association returns 404.""" + response = client.patch( + f"/user-reports/{uuid4()}", + json={"label": "Something"}, + params={"user_id": str(uuid4())}, + ) + assert response.status_code == 404 + + +# --------------------------------------------------------------------------- +# DELETE /user-reports/{id}?user_id=... +# --------------------------------------------------------------------------- + + +def test_delete_association(client, session): + """Delete an association returns 204.""" + user_id = uuid4() + report = create_report(session) + assoc = create_user_report_association(session, user_id, report) + + response = client.delete( + f"/user-reports/{assoc.id}", + params={"user_id": str(user_id)}, + ) + assert response.status_code == 204 + + # Confirm it's gone + response = client.get(f"/user-reports/{assoc.id}") + assert response.status_code == 404 + + +def test_delete_wrong_user(client, session): + """Delete with wrong user_id returns 404.""" + user_id = uuid4() + report = create_report(session) + assoc = create_user_report_association(session, user_id, report) + + response = client.delete( + f"/user-reports/{assoc.id}", + params={"user_id": str(uuid4())}, + ) + assert response.status_code == 404 + + +def test_delete_not_found(client): + """Delete a non-existent association returns 404.""" + response = client.delete( + f"/user-reports/{uuid4()}", + params={"user_id": str(uuid4())}, + ) + assert response.status_code == 404 diff --git a/tests/test_user_simulation_associations.py b/tests/test_user_simulation_associations.py new file mode 100644 index 0000000..60567a2 --- /dev/null +++ b/tests/test_user_simulation_associations.py @@ -0,0 +1,236 @@ +"""Tests for user-simulation association endpoints.""" + +from uuid import uuid4 + +from 
test_fixtures.fixtures_user_simulation_associations import ( + create_simulation, + create_user_simulation_association, +) + +# --------------------------------------------------------------------------- +# POST /user-simulations +# --------------------------------------------------------------------------- + + +def test_create_association(client, session): + """Create an association returns 200 with id and timestamps.""" + user_id = uuid4() + simulation = create_simulation(session) + payload = { + "user_id": str(user_id), + "simulation_id": str(simulation.id), + "country_id": "us", + "label": "My US simulation", + } + response = client.post("/user-simulations/", json=payload) + assert response.status_code == 200 + data = response.json() + assert "id" in data + assert "created_at" in data + assert "updated_at" in data + assert data["user_id"] == str(user_id) + assert data["simulation_id"] == str(simulation.id) + assert data["country_id"] == "us" + assert data["label"] == "My US simulation" + + +def test_create_association_allows_duplicates(client, session): + """Multiple associations to the same simulation are allowed.""" + user_id = uuid4() + simulation = create_simulation(session) + payload = { + "user_id": str(user_id), + "simulation_id": str(simulation.id), + "country_id": "us", + "label": "First label", + } + r1 = client.post("/user-simulations/", json=payload) + assert r1.status_code == 200 + + payload["label"] = "Second label" + r2 = client.post("/user-simulations/", json=payload) + assert r2.status_code == 200 + assert r1.json()["id"] != r2.json()["id"] + + +def test_create_association_simulation_not_found(client): + """Creating with a non-existent simulation returns 404.""" + payload = { + "user_id": str(uuid4()), + "simulation_id": str(uuid4()), + "country_id": "us", + } + response = client.post("/user-simulations/", json=payload) + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +def 
test_create_association_invalid_country(client, session): + """Creating with an invalid country_id returns 422.""" + simulation = create_simulation(session) + payload = { + "user_id": str(uuid4()), + "simulation_id": str(simulation.id), + "country_id": "invalid", + } + response = client.post("/user-simulations/", json=payload) + assert response.status_code == 422 + + +# --------------------------------------------------------------------------- +# GET /user-simulations/?user_id=... +# --------------------------------------------------------------------------- + + +def test_list_by_user_empty(client): + """List associations for a user with none returns empty list.""" + response = client.get("/user-simulations/", params={"user_id": str(uuid4())}) + assert response.status_code == 200 + assert response.json() == [] + + +def test_list_by_user(client, session): + """List all associations for a user.""" + user_id = uuid4() + sim1 = create_simulation(session) + sim2 = create_simulation(session) + create_user_simulation_association(session, user_id, sim1, label="First") + create_user_simulation_association(session, user_id, sim2, label="Second") + + response = client.get("/user-simulations/", params={"user_id": str(user_id)}) + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + + +def test_list_by_user_filter_country(client, session): + """Filter associations by country_id.""" + user_id = uuid4() + simulation = create_simulation(session) + create_user_simulation_association(session, user_id, simulation, country_id="us") + create_user_simulation_association(session, user_id, simulation, country_id="uk") + + response = client.get( + "/user-simulations/", + params={"user_id": str(user_id), "country_id": "uk"}, + ) + data = response.json() + assert len(data) == 1 + assert data[0]["country_id"] == "uk" + + +# --------------------------------------------------------------------------- +# GET /user-simulations/{id} +# 
--------------------------------------------------------------------------- + + +def test_get_by_id(client, session): + """Get a specific association by ID.""" + user_id = uuid4() + simulation = create_simulation(session) + assoc = create_user_simulation_association( + session, user_id, simulation, label="Test" + ) + + response = client.get(f"/user-simulations/{assoc.id}") + assert response.status_code == 200 + assert response.json()["id"] == str(assoc.id) + assert response.json()["label"] == "Test" + + +def test_get_by_id_not_found(client): + """Get a non-existent association returns 404.""" + response = client.get(f"/user-simulations/{uuid4()}") + assert response.status_code == 404 + + +# --------------------------------------------------------------------------- +# PATCH /user-simulations/{id}?user_id=... +# --------------------------------------------------------------------------- + + +def test_update_label(client, session): + """Update label via PATCH.""" + user_id = uuid4() + simulation = create_simulation(session) + assoc = create_user_simulation_association( + session, user_id, simulation, label="Old" + ) + + response = client.patch( + f"/user-simulations/{assoc.id}", + json={"label": "New label"}, + params={"user_id": str(user_id)}, + ) + assert response.status_code == 200 + assert response.json()["label"] == "New label" + + +def test_update_wrong_user(client, session): + """Update with wrong user_id returns 404.""" + user_id = uuid4() + simulation = create_simulation(session) + assoc = create_user_simulation_association( + session, user_id, simulation, label="Mine" + ) + + response = client.patch( + f"/user-simulations/{assoc.id}", + json={"label": "Stolen"}, + params={"user_id": str(uuid4())}, + ) + assert response.status_code == 404 + + +def test_update_not_found(client): + """Update a non-existent association returns 404.""" + response = client.patch( + f"/user-simulations/{uuid4()}", + json={"label": "Something"}, + params={"user_id": str(uuid4())}, 
+ ) + assert response.status_code == 404 + + +# --------------------------------------------------------------------------- +# DELETE /user-simulations/{id}?user_id=... +# --------------------------------------------------------------------------- + + +def test_delete_association(client, session): + """Delete an association returns 204.""" + user_id = uuid4() + simulation = create_simulation(session) + assoc = create_user_simulation_association(session, user_id, simulation) + + response = client.delete( + f"/user-simulations/{assoc.id}", + params={"user_id": str(user_id)}, + ) + assert response.status_code == 204 + + # Confirm it's gone + response = client.get(f"/user-simulations/{assoc.id}") + assert response.status_code == 404 + + +def test_delete_wrong_user(client, session): + """Delete with wrong user_id returns 404.""" + user_id = uuid4() + simulation = create_simulation(session) + assoc = create_user_simulation_association(session, user_id, simulation) + + response = client.delete( + f"/user-simulations/{assoc.id}", + params={"user_id": str(uuid4())}, + ) + assert response.status_code == 404 + + +def test_delete_not_found(client): + """Delete a non-existent association returns 404.""" + response = client.delete( + f"/user-simulations/{uuid4()}", + params={"user_id": str(uuid4())}, + ) + assert response.status_code == 404 diff --git a/tests/test_variables_by_name.py b/tests/test_variables_by_name.py new file mode 100644 index 0000000..3639fea --- /dev/null +++ b/tests/test_variables_by_name.py @@ -0,0 +1,228 @@ +"""Tests for POST /variables/by-name endpoint.""" + +import pytest + +from policyengine_api.models import ( + TaxBenefitModel, + TaxBenefitModelVersion, + Variable, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def uk_version(session): + """Create a policyengine-uk model and version.""" + model = 
TaxBenefitModel(name="policyengine-uk", description="UK model") + session.add(model) + session.commit() + session.refresh(model) + + version = TaxBenefitModelVersion( + model_id=model.id, version="1.0", description="UK v1" + ) + session.add(version) + session.commit() + session.refresh(version) + return version + + +@pytest.fixture +def us_version(session): + """Create a policyengine-us model and version.""" + model = TaxBenefitModel(name="policyengine-us", description="US model") + session.add(model) + session.commit() + session.refresh(model) + + version = TaxBenefitModelVersion( + model_id=model.id, version="1.0", description="US v1" + ) + session.add(version) + session.commit() + session.refresh(version) + return version + + +def _add_var(session, version, name, entity="person", description=None): + """Create and persist a Variable.""" + var = Variable( + name=name, + entity=entity, + description=description, + tax_benefit_model_version_id=version.id, + ) + session.add(var) + session.commit() + session.refresh(var) + return var + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestVariablesByName: + """Tests for looking up variables by their exact names.""" + + def test_returns_matching_variables(self, client, session, uk_version): + """Given known variable names, returns their full metadata.""" + _add_var(session, uk_version, "employment_income") + _add_var(session, uk_version, "income_tax") + + response = client.post( + "/variables/by-name", + json={"names": ["employment_income", "income_tax"], "country_id": "uk"}, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + returned_names = {v["name"] for v in data} + assert returned_names == {"employment_income", "income_tax"} + + def test_returns_empty_list_for_empty_names(self, client): + """Given an empty names list, returns an empty list.""" 
+ response = client.post( + "/variables/by-name", + json={"names": [], "country_id": "uk"}, + ) + + assert response.status_code == 200 + assert response.json() == [] + + def test_returns_empty_list_for_unknown_names(self, client, session, uk_version): + """Given names that don't match any variable, returns an empty list.""" + _add_var(session, uk_version, "employment_income") + + response = client.post( + "/variables/by-name", + json={"names": ["nonexistent_var", "also_missing"], "country_id": "uk"}, + ) + + assert response.status_code == 200 + assert response.json() == [] + + def test_returns_only_matching_when_mix_of_known_and_unknown( + self, client, session, uk_version + ): + """Given a mix of known and unknown names, returns only the known ones.""" + _add_var(session, uk_version, "income_tax") + + response = client.post( + "/variables/by-name", + json={"names": ["income_tax", "fake_var"], "country_id": "uk"}, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["name"] == "income_tax" + + def test_single_name_lookup(self, client, session, uk_version): + """Looking up a single variable name works.""" + _add_var(session, uk_version, "age") + + response = client.post( + "/variables/by-name", + json={"names": ["age"], "country_id": "uk"}, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["name"] == "age" + + def test_results_ordered_by_name(self, client, session, uk_version): + """Returned variables are sorted alphabetically by name.""" + _add_var(session, uk_version, "zzz_var") + _add_var(session, uk_version, "aaa_var") + _add_var(session, uk_version, "mmm_var") + + response = client.post( + "/variables/by-name", + json={ + "names": ["zzz_var", "aaa_var", "mmm_var"], + "country_id": "uk", + }, + ) + + assert response.status_code == 200 + names = [v["name"] for v in response.json()] + assert names == ["aaa_var", "mmm_var", "zzz_var"] + + def 
test_response_shape_matches_variable_read(self, client, session, uk_version): + """Returned objects have the same shape as VariableRead.""" + _add_var(session, uk_version, "income_tax", entity="person", description="Tax") + + response = client.post( + "/variables/by-name", + json={"names": ["income_tax"], "country_id": "uk"}, + ) + + assert response.status_code == 200 + var = response.json()[0] + assert "id" in var + assert "name" in var + assert "entity" in var + assert "description" in var + assert "created_at" in var + assert "tax_benefit_model_version_id" in var + + +class TestVariablesByNameCountryFiltering: + """Tests for country_id filtering.""" + + def test_country_isolation(self, client, session, uk_version, us_version): + """Variables from a different country are excluded.""" + _add_var(session, uk_version, "council_tax") + _add_var(session, us_version, "state_income_tax") + + uk_response = client.post( + "/variables/by-name", + json={"names": ["council_tax", "state_income_tax"], "country_id": "uk"}, + ) + us_response = client.post( + "/variables/by-name", + json={"names": ["council_tax", "state_income_tax"], "country_id": "us"}, + ) + + assert len(uk_response.json()) == 1 + assert uk_response.json()[0]["name"] == "council_tax" + assert len(us_response.json()) == 1 + assert us_response.json()[0]["name"] == "state_income_tax" + + def test_invalid_country_id_returns_422(self, client): + """An invalid country_id is rejected.""" + response = client.post( + "/variables/by-name", + json={"names": ["income_tax"], "country_id": "fr"}, + ) + + assert response.status_code == 422 + + +class TestVariablesByNameValidation: + """Tests for request validation.""" + + def test_missing_country_id_returns_422(self, client): + """Request without country_id is rejected.""" + response = client.post( + "/variables/by-name", + json={"names": ["income_tax"]}, + ) + + assert response.status_code == 422 + + def test_missing_names_field_returns_422(self, client): + """Request 
without names field is rejected.""" + response = client.post( + "/variables/by-name", + json={"country_id": "uk"}, + ) + + assert response.status_code == 422 diff --git a/uv.lock b/uv.lock index 094ebf8..d43e7e4 100644 --- a/uv.lock +++ b/uv.lock @@ -91,6 +91,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" }, ] +[[package]] +name = "alembic" +version = "1.18.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mako" }, + { name = "sqlalchemy" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/79/41/ab8f624929847b49f84955c594b165855efd829b0c271e1a8cac694138e5/alembic-1.18.3.tar.gz", hash = "sha256:1212aa3778626f2b0f0aa6dd4e99a5f99b94bd25a0c1ac0bba3be65e081e50b0", size = 2052564, upload-time = "2026-01-29T20:24:15.124Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/45/8e/d79281f323e7469b060f15bd229e48d7cdd219559e67e71c013720a88340/alembic-1.18.3-py3-none-any.whl", hash = "sha256:12a0359bfc068a4ecbb9b3b02cf77856033abfdb59e4a5aca08b7eacd7b74ddd", size = 262282, upload-time = "2026-01-29T20:24:17.488Z" }, +] + [[package]] name = "annotated-doc" version = "0.0.4" @@ -1057,6 +1071,18 @@ sqlalchemy = [ { name = "opentelemetry-instrumentation-sqlalchemy" }, ] +[[package]] +name = "mako" +version = "1.3.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9e/38/bd5b78a920a64d708fe6bc8e0a2c075e1389d53bef8413725c63ba041535/mako-1.3.10.tar.gz", hash = "sha256:99579a6f39583fa7e5630a28c3c1f440e4e97a414b80372649c0ce338da2ea28", size = 392474, upload-time = "2025-04-10T12:44:31.16Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509, upload-time = "2025-04-10T12:50:53.297Z" }, +] + [[package]] name = "markdown-it-py" version = "4.0.0" @@ -1169,15 +1195,15 @@ wheels = [ [[package]] name = "microdf-python" -version = "1.0.2" +version = "1.2.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, { name = "pandas" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/77/25/55c2b0495ae4c3142d61f1283d675494aac4c254e40ecf1ea4b337a051c7/microdf_python-1.0.2.tar.gz", hash = "sha256:5c845974d485598a7002c151f58ec7438e94c04954fc8fdea9238265e7bf02f5", size = 14826, upload-time = "2025-07-24T12:21:08.17Z" } +sdist = { url = "https://files.pythonhosted.org/packages/91/f0/9689f33e2524b0c0d1cdf0d556ad196bfbb2ec0292f4545f467a37b27773/microdf_python-1.2.2.tar.gz", hash = "sha256:7e5f6adc10b0469de0e6549789ede0a2e6c600d0f5c83eafffc009d1495a7933", size = 20395, upload-time = "2026-02-24T10:47:16.438Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/9c/1a/aac40a7e58de4133a9cc7630913a8b8e6c76326288b168cbb47f7714c4fd/microdf_python-1.0.2-py3-none-any.whl", hash = "sha256:f7883785e4557d1c8822dbf0d69d7eeab9399f8e67a9bdb716f74554c7580ae7", size = 15823, upload-time = "2025-07-24T12:21:07.356Z" }, + { url = "https://files.pythonhosted.org/packages/cc/cc/89cd28dd8ef566a49d353870a88359f5d239d7c254e1bde5e755e067028a/microdf_python-1.2.2-py3-none-any.whl", hash = "sha256:94f8b11b6416be9d04ff86cc311ae9083614bd6d569e7a589d250e89ded3343c", size = 21476, upload-time = "2026-02-24T10:47:15.65Z" }, ] [[package]] @@ -1737,8 +1763,8 @@ wheels = [ [[package]] name = "policyengine" -version = "3.1.15" -source = { registry = "https://pypi.org/simple" } +version = "3.1.16" +source = { git = 
"https://github.com/PolicyEngine/policyengine.py.git?rev=app-v2-migration#6ccb46a8129e54fddc99dd41ae57b913dee46f5b" } dependencies = [ { name = "microdf-python" }, { name = "pandas" }, @@ -1747,16 +1773,13 @@ dependencies = [ { name = "pydantic" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1e/8e/e54302e342fe8995b58a575989d5742ea47ad543d9c069ef7357a90e7415/policyengine-3.1.15.tar.gz", hash = "sha256:44d6b2b74fe58bc1e6438f2809f33c0926e59c7544ec5780e38eb05db7e159b7", size = 181921, upload-time = "2025-12-14T23:52:00.289Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/58/ea/4963ab1c79923f0569d48023236d7453a0e994f7f471b1ea51a002d22f6a/policyengine-3.1.15-py3-none-any.whl", hash = "sha256:063c1eda8355021cfba8adfc98af1f01f9fdb6b21204e8d8ab3f3a41ba06b749", size = 68409, upload-time = "2025-12-14T23:51:58.977Z" }, -] [[package]] name = "policyengine-api-v2" version = "0.1.0" source = { editable = "." } dependencies = [ + { name = "alembic" }, { name = "anthropic" }, { name = "boto3" }, { name = "fastapi" }, @@ -1793,6 +1816,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "alembic", specifier = ">=1.13.0" }, { name = "anthropic", specifier = ">=0.40.0" }, { name = "boto3", specifier = ">=1.41.1" }, { name = "fastapi", specifier = ">=0.115.0" }, @@ -1801,7 +1825,7 @@ requires-dist = [ { name = "httpx", marker = "extra == 'dev'", specifier = ">=0.27.2" }, { name = "logfire", extras = ["fastapi", "httpx", "sqlalchemy"], specifier = ">=0.60.0" }, { name = "modal", specifier = ">=0.68.0" }, - { name = "policyengine", specifier = ">=3.1.15" }, + { name = "policyengine", git = "https://github.com/PolicyEngine/policyengine.py.git?rev=app-v2-migration" }, { name = "policyengine-uk", specifier = ">=2.0.0" }, { name = "policyengine-us", specifier = ">=1.0.0" }, { name = "psycopg2-binary", specifier = ">=2.9.10" },