From 0191d23af1242c030c6cb074526ec765b0f639ad Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 4 Mar 2026 20:18:25 +0100 Subject: [PATCH 1/3] feat: Add version filtering to metadata endpoints and use tax_benefit_model_name consistently All metadata endpoints now default to the latest model version when tax_benefit_model_name is provided, and accept an optional tax_benefit_model_version_id for pinning to a specific older version. The by-name and children endpoints now accept tax_benefit_model_name instead of country_id for consistency with the other metadata endpoints. Closes PolicyEngine/policyengine-api-v2-alpha#98 Co-Authored-By: Claude Opus 4.6 --- src/policyengine_api/api/parameter_values.py | 17 +- src/policyengine_api/api/parameters.py | 67 +++--- src/policyengine_api/api/variables.py | 43 ++-- src/policyengine_api/services/__init__.py | 13 +- .../services/tax_benefit_models.py | 109 +++++++++ tests/test_parameters.py | 116 +++++++++ tests/test_parameters_by_name.py | 122 ++++++++-- tests/test_parameters_children.py | 220 +++++++++++++++--- tests/test_variables.py | 129 ++++++++++ tests/test_variables_by_name.py | 152 +++++++++--- tests/test_version_filter_service.py | 134 +++++++++++ 11 files changed, 980 insertions(+), 142 deletions(-) create mode 100644 src/policyengine_api/services/tax_benefit_models.py create mode 100644 tests/test_version_filter_service.py diff --git a/src/policyengine_api/api/parameter_values.py b/src/policyengine_api/api/parameter_values.py index 4668ab8..9dd6819 100644 --- a/src/policyengine_api/api/parameter_values.py +++ b/src/policyengine_api/api/parameter_values.py @@ -12,8 +12,9 @@ from fastapi import APIRouter, Depends, HTTPException from sqlmodel import Session, or_, select -from policyengine_api.models import ParameterValue, ParameterValueRead +from policyengine_api.models import Parameter, ParameterValue, ParameterValueRead from policyengine_api.services.database import get_session +from policyengine_api.services.tax_benefit_models import resolve_model_version_id router = APIRouter(prefix="/parameter-values", tags=["parameter-values"]) @@ -23,6 +24,8 @@ def list_parameter_values( parameter_id: UUID | None = None, policy_id: UUID | None = None, current: bool = False, + tax_benefit_model_name: str | None = None, + tax_benefit_model_version_id: UUID | None = None, skip: int = 0, limit: int = 100, session: Session = Depends(get_session), @@ -37,6 +40,10 @@ def list_parameter_values( policy_id: Filter by a specific policy reform. current: If true, only return values that are currently in effect (start_date <= now and (end_date is null or end_date > now)). + tax_benefit_model_name: Filter to values belonging to parameters from + this model. Defaults to the latest version. + tax_benefit_model_version_id: Filter to values belonging to parameters + from this specific model version. Takes precedence over model name. """ query = select(ParameterValue) @@ -46,6 +53,14 @@ def list_parameter_values( if policy_id: query = query.where(ParameterValue.policy_id == policy_id) + version_id = resolve_model_version_id( + tax_benefit_model_name, tax_benefit_model_version_id, session + ) + if version_id: + query = query.join(Parameter).where( + Parameter.tax_benefit_model_version_id == version_id + ) + if current: now = datetime.now(timezone.utc) query = query.where( diff --git a/src/policyengine_api/api/parameters.py b/src/policyengine_api/api/parameters.py index 72b64ef..43cf163 100644 --- a/src/policyengine_api/api/parameters.py +++ b/src/policyengine_api/api/parameters.py @@ -14,14 +14,12 @@ from pydantic import BaseModel from sqlmodel import Session, select -from policyengine_api.config.constants import COUNTRY_MODEL_NAMES, CountryId from policyengine_api.models import ( Parameter, ParameterRead, - TaxBenefitModel, - TaxBenefitModelVersion, ) from policyengine_api.services.database import get_session +from policyengine_api.services.tax_benefit_models import resolve_model_version_id router = APIRouter(prefix="/parameters", tags=["parameters"]) @@ -32,6 +30,7 @@ def list_parameters( limit: int = 100, search: str | None = None, tax_benefit_model_name: str | None = None, + tax_benefit_model_version_id: UUID | None = None, session: Session = Depends(get_session), ): """List available parameters with pagination and search. @@ -44,19 +43,19 @@ def list_parameters( tax_benefit_model_name: Filter by country model. Use "policyengine-uk" for UK parameters. Use "policyengine-us" for US parameters. + Defaults to the latest model version when no version ID is given. + tax_benefit_model_version_id: Filter by a specific model version. + Takes precedence over tax_benefit_model_name. """ query = select(Parameter) - # Filter by tax benefit model name (country) - if tax_benefit_model_name: - query = ( - query.join(TaxBenefitModelVersion) - .join(TaxBenefitModel) - .where(TaxBenefitModel.name == tax_benefit_model_name) - ) + version_id = resolve_model_version_id( + tax_benefit_model_name, tax_benefit_model_version_id, session + ) + if version_id: + query = query.where(Parameter.tax_benefit_model_version_id == version_id) if search: - # Case-insensitive search using ILIKE search_pattern = f"%{search}%" search_filter = ( Parameter.name.ilike(search_pattern) @@ -75,7 +74,8 @@ class ParameterByNameRequest(BaseModel): """Request body for looking up parameters by name.""" names: list[str] - country_id: CountryId + tax_benefit_model_name: str + tax_benefit_model_version_id: UUID | None = None @router.post("/by-name", response_model=List[ParameterRead]) @@ -95,18 +95,17 @@ def get_parameters_by_name( if not request.names: return [] - model_name = COUNTRY_MODEL_NAMES[request.country_id] - - query = ( - select(Parameter) - .join(TaxBenefitModelVersion) - .join(TaxBenefitModel) - .where(TaxBenefitModel.name == model_name) - .where(Parameter.name.in_(request.names)) - .order_by(Parameter.name) + version_id = resolve_model_version_id( + request.tax_benefit_model_name, + request.tax_benefit_model_version_id, + session, ) - return session.exec(query).all() + query = select(Parameter).where(Parameter.name.in_(request.names)) + if version_id: + query = query.where(Parameter.tax_benefit_model_version_id == version_id) + + return session.exec(query.order_by(Parameter.name)).all() class ParameterChild(BaseModel): @@ -128,10 +127,15 @@ class ParameterChildrenResponse(BaseModel): @router.get("/children", response_model=ParameterChildrenResponse) def get_parameter_children( - country_id: CountryId = Query(description='Country ID ("us" or "uk")'), + tax_benefit_model_name: str = Query( + description='Model name (e.g. "policyengine-us" or "policyengine-uk")' + ), parent_path: str = Query( default="", description="Parent parameter path (e.g. 'gov' or 'gov.hmrc')" ), + tax_benefit_model_version_id: UUID | None = Query( + default=None, description="Optional specific model version ID" + ), session: Session = Depends(get_session), ) -> ParameterChildrenResponse: """Get direct children of a parameter path for tree navigation. @@ -140,17 +144,16 @@ def get_parameter_children( parameters (with full metadata). Use this to lazily load the parameter tree one level at a time. """ - model_name = COUNTRY_MODEL_NAMES[country_id] + version_id = resolve_model_version_id( + tax_benefit_model_name, tax_benefit_model_version_id, session + ) + prefix = f"{parent_path}." if parent_path else "" - # Fetch all parameters under this path - query = ( - select(Parameter) - .join(TaxBenefitModelVersion) - .join(TaxBenefitModel) - .where(TaxBenefitModel.name == model_name) - .where(Parameter.name.startswith(prefix)) - ) + query = select(Parameter).where(Parameter.name.startswith(prefix)) + if version_id: + query = query.where(Parameter.tax_benefit_model_version_id == version_id) + descendants = session.exec(query).all() # Group by direct child path diff --git a/src/policyengine_api/api/variables.py b/src/policyengine_api/api/variables.py index 04aa512..edf0eea 100644 --- a/src/policyengine_api/api/variables.py +++ b/src/policyengine_api/api/variables.py @@ -12,14 +12,12 @@ from pydantic import BaseModel from sqlmodel import Session, select -from policyengine_api.config.constants import COUNTRY_MODEL_NAMES, CountryId from policyengine_api.models import ( - TaxBenefitModel, - TaxBenefitModelVersion, Variable, VariableRead, ) from policyengine_api.services.database import get_session +from policyengine_api.services.tax_benefit_models import resolve_model_version_id router = APIRouter(prefix="/variables", tags=["variables"]) @@ -30,6 +28,7 @@ def list_variables( limit: int = 100, search: str | None = None, tax_benefit_model_name: str | None = None, + tax_benefit_model_version_id: UUID | None = None, session: Session = Depends(get_session), ): """List available variables with pagination and search. @@ -43,20 +42,19 @@ def list_variables( tax_benefit_model_name: Filter by country model. Use "policyengine-uk" for UK variables. Use "policyengine-us" for US variables. + Defaults to the latest model version when no version ID is given. + tax_benefit_model_version_id: Filter by a specific model version. + Takes precedence over tax_benefit_model_name. """ query = select(Variable) - # Filter by tax benefit model name (country) - if tax_benefit_model_name: - query = ( - query.join(TaxBenefitModelVersion) - .join(TaxBenefitModel) - .where(TaxBenefitModel.name == tax_benefit_model_name) - ) + version_id = resolve_model_version_id( + tax_benefit_model_name, tax_benefit_model_version_id, session + ) + if version_id: + query = query.where(Variable.tax_benefit_model_version_id == version_id) if search: - # Case-insensitive search using ILIKE - # Note: Variables don't have a label field, only name and description search_pattern = f"%{search}%" search_filter = Variable.name.ilike( search_pattern @@ -73,7 +71,8 @@ class VariableByNameRequest(BaseModel): """Request body for looking up variables by name.""" names: list[str] - country_id: CountryId + tax_benefit_model_name: str + tax_benefit_model_version_id: UUID | None = None @router.post("/by-name", response_model=List[VariableRead]) @@ -94,17 +93,17 @@ def get_variables_by_name( if not request.names: return [] - model_name = COUNTRY_MODEL_NAMES[request.country_id] - query = ( - select(Variable) - .join(TaxBenefitModelVersion) - .join(TaxBenefitModel) - .where(TaxBenefitModel.name == model_name) - .where(Variable.name.in_(request.names)) - .order_by(Variable.name) + version_id = resolve_model_version_id( + request.tax_benefit_model_name, + request.tax_benefit_model_version_id, + session, ) - return session.exec(query).all() + query = select(Variable).where(Variable.name.in_(request.names)) + if version_id: + query = query.where(Variable.tax_benefit_model_version_id == version_id) + + return session.exec(query.order_by(Variable.name)).all() @router.get("/{variable_id}", response_model=VariableRead) diff --git a/src/policyengine_api/services/__init__.py b/src/policyengine_api/services/__init__.py index a314727..1fb2a46 100644 --- a/src/policyengine_api/services/__init__.py +++ b/src/policyengine_api/services/__init__.py @@ -1,5 +1,16 @@ """Services for database and external integrations.""" from .database import get_session, init_db +from .tax_benefit_models import ( + get_latest_model_version, + get_model_version_by_id, + resolve_model_version_id, +) -__all__ = ["get_session", "init_db"] +__all__ = [ + "get_session", + "init_db", + "get_latest_model_version", + "get_model_version_by_id", + "resolve_model_version_id", +] diff --git a/src/policyengine_api/services/tax_benefit_models.py b/src/policyengine_api/services/tax_benefit_models.py new file mode 100644 index 0000000..f88d056 --- /dev/null +++ b/src/policyengine_api/services/tax_benefit_models.py @@ -0,0 +1,109 @@ +"""Tax benefit model utilities. + +Shared utilities for resolving tax benefit model versions. +""" + +from uuid import UUID + +from fastapi import HTTPException +from sqlmodel import Session, select + +from policyengine_api.models import TaxBenefitModel, TaxBenefitModelVersion + + +def get_latest_model_version( + tax_benefit_model_name: str, session: Session +) -> TaxBenefitModelVersion: + """Get the latest tax benefit model version for a given model name. + + Args: + tax_benefit_model_name: The model name (e.g., "policyengine-us"). + Underscores are normalized to hyphens. + session: Database session. + + Returns: + The latest TaxBenefitModelVersion for the model. + + Raises: + HTTPException: If model or version not found. + """ + model_name = tax_benefit_model_name.replace("_", "-") + + model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == model_name) + ).first() + if not model: + raise HTTPException( + status_code=404, + detail=f"Tax benefit model '{model_name}' not found", + ) + + version = session.exec( + select(TaxBenefitModelVersion) + .where(TaxBenefitModelVersion.model_id == model.id) + .order_by(TaxBenefitModelVersion.created_at.desc()) + ).first() + if not version: + raise HTTPException( + status_code=404, + detail=f"No version found for model '{model_name}'", + ) + + return version + + +def get_model_version_by_id( + version_id: UUID, session: Session +) -> TaxBenefitModelVersion: + """Get a specific tax benefit model version by ID. + + Args: + version_id: The UUID of the model version. + session: Database session. + + Returns: + The TaxBenefitModelVersion with the given ID. + + Raises: + HTTPException: If version not found. + """ + version = session.get(TaxBenefitModelVersion, version_id) + if not version: + raise HTTPException( + status_code=404, + detail=f"Tax benefit model version '{version_id}' not found", + ) + return version + + +def resolve_model_version_id( + tax_benefit_model_name: str | None, + tax_benefit_model_version_id: UUID | None, + session: Session, +) -> UUID | None: + """Resolve the model version ID from either explicit ID or model name. + + Priority: + 1. If tax_benefit_model_version_id provided, validate and return it. + 2. If tax_benefit_model_name provided, return the latest version's ID. + 3. If neither provided, return None (no filtering). + + Args: + tax_benefit_model_name: Optional model name to resolve latest version for. + tax_benefit_model_version_id: Optional explicit version ID. + session: Database session. + + Returns: + The resolved version ID, or None if no filtering requested. + + Raises: + HTTPException: If specified version or model not found. + """ + if tax_benefit_model_version_id: + version = get_model_version_by_id(tax_benefit_model_version_id, session) + return version.id + elif tax_benefit_model_name: + version = get_latest_model_version(tax_benefit_model_name, session) + return version.id + else: + return None diff --git a/tests/test_parameters.py b/tests/test_parameters.py index 50bb213..c8f39d0 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -1,9 +1,11 @@ """Tests for parameter and parameter-value endpoints.""" +from datetime import datetime, timezone from uuid import uuid4 import pytest +from policyengine_api.models import TaxBenefitModel, TaxBenefitModelVersion from test_fixtures.fixtures_parameters import ( create_parameter, create_parameter_value, @@ -200,5 +202,119 @@ def test__given_skip_parameter__then_skips_specified_results( assert len(response.json()) == 2 # 5 total - 3 skipped = 2 remaining +# ----------------------------------------------------------------------------- +# Version Filtering Tests +# ----------------------------------------------------------------------------- + + +def test__given_model_name__then_returns_only_latest_version_parameters( + client, session +): + """GET /parameters?tax_benefit_model_name=X returns only latest version's params.""" + model = TaxBenefitModel(name="policyengine-us", description="US") + session.add(model) + session.commit() + session.refresh(model) + + v1 = TaxBenefitModelVersion( + model_id=model.id, + version="1.0", + created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + ) + v2 = TaxBenefitModelVersion( + model_id=model.id, + version="2.0", + created_at=datetime(2025, 6, 1, tzinfo=timezone.utc), + ) + session.add(v1) + session.add(v2) + session.commit() + session.refresh(v1) + session.refresh(v2) + + create_parameter(session, v1, "gov.old_param", "Old") + create_parameter(session, v2, "gov.new_param", "New") + + response = client.get("/parameters?tax_benefit_model_name=policyengine-us") + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["name"] == "gov.new_param" + + +def test__given_explicit_version_id__then_returns_that_versions_parameters( + client, session +): + """GET /parameters?tax_benefit_model_version_id=X returns that version's params.""" + model = TaxBenefitModel(name="policyengine-us", description="US") + session.add(model) + session.commit() + session.refresh(model) + + v1 = TaxBenefitModelVersion( + model_id=model.id, + version="1.0", + created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + ) + v2 = TaxBenefitModelVersion( + model_id=model.id, + version="2.0", + created_at=datetime(2025, 6, 1, tzinfo=timezone.utc), + ) + session.add(v1) + session.add(v2) + session.commit() + session.refresh(v1) + session.refresh(v2) + + create_parameter(session, v1, "gov.old_param", "Old") + create_parameter(session, v2, "gov.new_param", "New") + + response = client.get(f"/parameters?tax_benefit_model_version_id={v1.id}") + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["name"] == "gov.old_param" + + +def test__given_version_id_overrides_model_name(client, session): + """Version ID takes precedence over model name when both are provided.""" + model = TaxBenefitModel(name="policyengine-us", description="US") + session.add(model) + session.commit() + session.refresh(model) + + v1 = TaxBenefitModelVersion( + model_id=model.id, + version="1.0", + created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + ) + v2 = TaxBenefitModelVersion( + model_id=model.id, + version="2.0", + created_at=datetime(2025, 6, 1, tzinfo=timezone.utc), + ) + session.add(v1) + session.add(v2) + session.commit() + session.refresh(v1) + session.refresh(v2) + + create_parameter(session, v1, "gov.old_param", "Old") + create_parameter(session, v2, "gov.new_param", "New") + + # Provide model name (would resolve to v2) but also provide explicit v1 ID + response = client.get( + f"/parameters?tax_benefit_model_name=policyengine-us&tax_benefit_model_version_id={v1.id}" + ) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["name"] == "gov.old_param" + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/test_parameters_by_name.py b/tests/test_parameters_by_name.py index 81bd360..686e881 100644 --- a/tests/test_parameters_by_name.py +++ b/tests/test_parameters_by_name.py @@ -1,5 +1,7 @@ """Tests for POST /parameters/by-name endpoint.""" +from datetime import datetime, timezone + import pytest from policyengine_api.models import ( @@ -55,7 +57,7 @@ def test_returns_matching_parameters(self, client, session, us_version): "/parameters/by-name", json={ "names": ["gov.tax.rate", "gov.tax.threshold"], - "country_id": "us", + "tax_benefit_model_name": "policyengine-us", }, ) @@ -65,13 +67,13 @@ def test_returns_matching_parameters(self, client, session, us_version): returned_names = {p["name"] for p in data} assert returned_names == {"gov.tax.rate", "gov.tax.threshold"} - def test_returns_empty_list_for_empty_names(self, client): + def test_returns_empty_list_for_empty_names(self, client, session, us_version): """Given an empty names list, returns an empty list.""" response = client.post( "/parameters/by-name", json={ "names": [], - "country_id": "us", + "tax_benefit_model_name": "policyengine-us", }, ) @@ -86,7 +88,7 @@ def test_returns_empty_list_for_unknown_names(self, client, session, us_version) "/parameters/by-name", json={ "names": ["gov.does_not_exist", "gov.also_missing"], - "country_id": "us", + "tax_benefit_model_name": "policyengine-us", }, ) @@ -103,7 +105,7 @@ def test_returns_only_matching_when_mix_of_known_and_unknown( "/parameters/by-name", json={ "names": ["gov.real", "gov.fake"], - "country_id": "us", + "tax_benefit_model_name": "policyengine-us", }, ) @@ -112,8 +114,8 @@ def test_returns_only_matching_when_mix_of_known_and_unknown( assert len(data) == 1 assert data[0]["name"] == "gov.real" - def test_filters_by_country(self, client, session): - """Parameters from a different country are excluded.""" + def test_filters_by_model_name(self, client, session): + """Parameters from a different model are excluded.""" # Create two models model_uk = TaxBenefitModel(name="policyengine-uk", description="UK") model_us = TaxBenefitModel(name="policyengine-us", description="US") @@ -144,7 +146,7 @@ def test_filters_by_country(self, client, session): "/parameters/by-name", json={ "names": ["gov.shared_name"], - "country_id": "uk", + "tax_benefit_model_name": "policyengine-uk", }, ) @@ -161,7 +163,7 @@ def test_response_shape_matches_parameter_read(self, client, session, us_version "/parameters/by-name", json={ "names": ["gov.shape_test"], - "country_id": "us", + "tax_benefit_model_name": "policyengine-us", }, ) @@ -185,7 +187,7 @@ def test_results_ordered_by_name(self, client, session, us_version): "/parameters/by-name", json={ "names": ["gov.zzz", "gov.aaa", "gov.mmm"], - "country_id": "us", + "tax_benefit_model_name": "policyengine-us", }, ) @@ -193,8 +195,8 @@ def test_results_ordered_by_name(self, client, session, us_version): names = [p["name"] for p in response.json()] assert names == ["gov.aaa", "gov.mmm", "gov.zzz"] - def test_missing_country_id_returns_422(self, client): - """Request without country_id is rejected.""" + def test_missing_model_name_returns_422(self, client): + """Request without tax_benefit_model_name is rejected.""" response = client.post( "/parameters/by-name", json={"names": ["gov.something"]}, @@ -202,20 +204,11 @@ def test_missing_country_id_returns_422(self, client): assert response.status_code == 422 - def test_invalid_country_id_returns_422(self, client): - """Request with invalid country_id is rejected.""" - response = client.post( - "/parameters/by-name", - json={"names": ["gov.something"], "country_id": "invalid"}, - ) - - assert response.status_code == 422 - def test_missing_names_field_returns_422(self, client): """Request without names field is rejected.""" response = client.post( "/parameters/by-name", - json={"country_id": "us"}, + json={"tax_benefit_model_name": "policyengine-us"}, ) assert response.status_code == 422 @@ -228,7 +221,7 @@ def test_single_name_lookup(self, client, session, us_version): "/parameters/by-name", json={ "names": ["gov.single"], - "country_id": "us", + "tax_benefit_model_name": "policyengine-us", }, ) @@ -236,3 +229,86 @@ def test_single_name_lookup(self, client, session, us_version): data = response.json() assert len(data) == 1 assert data[0]["name"] == "gov.single" + + +class TestParametersByNameVersionFilter: + """Tests for version filtering on the by-name endpoint.""" + + def test_defaults_to_latest_version(self, client, session): + """When only model name is given, returns parameters from latest version.""" + model = TaxBenefitModel(name="policyengine-us", description="US") + session.add(model) + session.commit() + session.refresh(model) + + v1 = TaxBenefitModelVersion( + model_id=model.id, + version="1.0", + created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + ) + v2 = TaxBenefitModelVersion( + model_id=model.id, + version="2.0", + created_at=datetime(2025, 6, 1, tzinfo=timezone.utc), + ) + session.add(v1) + session.add(v2) + session.commit() + session.refresh(v1) + session.refresh(v2) + + create_parameter(session, v1, "gov.old_param", "Old") + create_parameter(session, v2, "gov.new_param", "New") + + response = client.post( + "/parameters/by-name", + json={ + "names": ["gov.old_param", "gov.new_param"], + "tax_benefit_model_name": "policyengine-us", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["name"] == "gov.new_param" + + def test_explicit_version_id_returns_that_version(self, client, session): + """When version ID is given, returns parameters from that specific version.""" + model = TaxBenefitModel(name="policyengine-us", description="US") + session.add(model) + session.commit() + session.refresh(model) + + v1 = TaxBenefitModelVersion( + model_id=model.id, + version="1.0", + created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + ) + v2 = TaxBenefitModelVersion( + model_id=model.id, + version="2.0", + created_at=datetime(2025, 6, 1, tzinfo=timezone.utc), + ) + session.add(v1) + session.add(v2) + session.commit() + session.refresh(v1) + session.refresh(v2) + + create_parameter(session, v1, "gov.old_param", "Old") + create_parameter(session, v2, "gov.new_param", "New") + + response = client.post( + "/parameters/by-name", + json={ + "names": ["gov.old_param", "gov.new_param"], + "tax_benefit_model_name": "policyengine-us", + "tax_benefit_model_version_id": str(v1.id), + }, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["name"] == "gov.old_param" diff --git a/tests/test_parameters_children.py b/tests/test_parameters_children.py index d788179..5fab96d 100644 --- a/tests/test_parameters_children.py +++ b/tests/test_parameters_children.py @@ -1,5 +1,7 @@ """Tests for GET /parameters/children endpoint.""" +from datetime import datetime, timezone + import pytest from policyengine_api.models import ( @@ -81,7 +83,11 @@ def test_returns_nodes_for_intermediate_paths(self, client, session, uk_version) ) response = client.get( - "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + "/parameters/children", + params={ + "tax_benefit_model_name": "policyengine-uk", + "parent_path": "gov", + }, ) assert response.status_code == 200 @@ -107,7 +113,11 @@ def test_returns_leaf_parameters(self, client, session, uk_version): ) response = client.get( - "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + "/parameters/children", + params={ + "tax_benefit_model_name": "policyengine-uk", + "parent_path": "gov", + }, ) assert response.status_code == 200 @@ -136,7 +146,11 @@ def test_mixed_nodes_and_leaves(self, client, session, uk_version): ) response = client.get( - "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + "/parameters/children", + params={ + "tax_benefit_model_name": "policyengine-uk", + "parent_path": "gov", + }, ) children = response.json()["children"] @@ -162,7 +176,11 @@ def test_child_count_reflects_total_descendants(self, client, session, uk_versio ) response = client.get( - "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + "/parameters/children", + params={ + "tax_benefit_model_name": "policyengine-uk", + "parent_path": "gov", + }, ) children = response.json()["children"] @@ -184,7 +202,10 @@ def test_nested_child_count(self, client, session, uk_version): response = client.get( "/parameters/children", - params={"country_id": "uk", "parent_path": "gov.hmrc"}, + params={ + "tax_benefit_model_name": "policyengine-uk", + "parent_path": "gov.hmrc", + }, ) children = response.json()["children"] @@ -199,7 +220,11 @@ def test_leaf_has_no_child_count(self, client, session, uk_version): _add_params(session, uk_version, [("gov.rate", "Rate")]) response = client.get( - "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + "/parameters/children", + params={ + "tax_benefit_model_name": "policyengine-uk", + "parent_path": "gov", + }, ) children = response.json()["children"] @@ -207,41 +232,57 @@ def test_leaf_has_no_child_count(self, client, session, uk_version): assert children[0]["child_count"] is None -class TestCountryFiltering: - """Tests for country_id filtering.""" +class TestModelNameFiltering: + """Tests for tax_benefit_model_name filtering.""" - def test_uk_country_id(self, client, session, uk_version): - """country_id=uk returns UK parameters.""" + def test_uk_model(self, client, session, uk_version): + """policyengine-uk returns UK parameters.""" _add_params(session, uk_version, [("gov.hmrc.rate", "Rate")]) response = client.get( - "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + "/parameters/children", + params={ + "tax_benefit_model_name": "policyengine-uk", + "parent_path": "gov", + }, ) assert response.status_code == 200 assert len(response.json()["children"]) == 1 - def test_us_country_id(self, client, session, us_version): - """country_id=us returns US parameters.""" + def test_us_model(self, client, session, us_version): + """policyengine-us returns US parameters.""" _add_params(session, us_version, [("gov.irs.rate", "Rate")]) response = client.get( - "/parameters/children", params={"country_id": "us", "parent_path": "gov"} + "/parameters/children", + params={ + "tax_benefit_model_name": "policyengine-us", + "parent_path": "gov", + }, ) assert response.status_code == 200 assert len(response.json()["children"]) == 1 - def test_country_isolation(self, client, session, uk_version, us_version): - """Parameters from a different country are excluded.""" + def test_model_isolation(self, client, session, uk_version, us_version): + """Parameters from a different model are excluded.""" _add_params(session, uk_version, [("gov.hmrc.rate", "UK rate")]) _add_params(session, us_version, [("gov.irs.rate", "US rate")]) uk_response = client.get( - "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + "/parameters/children", + params={ + "tax_benefit_model_name": "policyengine-uk", + "parent_path": "gov", + }, ) us_response = client.get( - "/parameters/children", params={"country_id": "us", "parent_path": "gov"} + "/parameters/children", + params={ + "tax_benefit_model_name": "policyengine-us", + "parent_path": "gov", + }, ) uk_paths = [c["path"] for c in uk_response.json()["children"]] @@ -249,11 +290,10 @@ def test_country_isolation(self, client, session, uk_version, us_version): assert uk_paths == ["gov.hmrc"] assert us_paths == ["gov.irs"] - def test_invalid_country_id_returns_422(self, client): - """An invalid country_id is rejected by Literal validation.""" + def test_missing_model_name_returns_422(self, client): + """Request without tax_benefit_model_name returns 422.""" response = client.get( - "/parameters/children", - params={"country_id": "fr", "parent_path": "gov"}, + "/parameters/children", params={"parent_path": "gov"} ) assert response.status_code == 422 @@ -267,7 +307,11 @@ def test_empty_parent_path(self, client, session, uk_version): _add_params(session, uk_version, [("gov.hmrc.rate", "Rate")]) response = client.get( - "/parameters/children", params={"country_id": "uk", "parent_path": ""} + "/parameters/children", + params={ + "tax_benefit_model_name": "policyengine-uk", + "parent_path": "", + }, ) assert response.status_code == 200 @@ -282,7 +326,10 @@ def test_nonexistent_parent_returns_empty(self, client, session, uk_version): response = client.get( "/parameters/children", - params={"country_id": "uk", "parent_path": "gov.dwp"}, + params={ + "tax_benefit_model_name": "policyengine-uk", + "parent_path": "gov.dwp", + }, ) assert response.status_code == 200 @@ -301,7 +348,11 @@ def test_children_sorted_by_path(self, client, session, uk_version): ) response = client.get( - "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + "/parameters/children", + params={ + "tax_benefit_model_name": "policyengine-uk", + "parent_path": "gov", + }, ) paths = [c["path"] for c in response.json()["children"]] @@ -312,23 +363,24 @@ def test_node_label_from_path_segment(self, client, session, uk_version): _add_params(session, uk_version, [("gov.hmrc.income_tax.rate", "Rate")]) response = client.get( - "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + "/parameters/children", + params={ + "tax_benefit_model_name": "policyengine-uk", + "parent_path": "gov", + }, ) children = response.json()["children"] assert children[0]["label"] == "hmrc" - def test_missing_country_id_returns_422(self, client): - """Request without country_id returns 422.""" - response = client.get("/parameters/children", params={"parent_path": "gov"}) - - assert response.status_code == 422 - def test_default_parent_path_is_empty(self, client, session, uk_version): """Omitting parent_path defaults to empty string (root level).""" _add_params(session, uk_version, [("gov.hmrc.rate", "Rate")]) - response = client.get("/parameters/children", params={"country_id": "uk"}) + response = client.get( + "/parameters/children", + params={"tax_benefit_model_name": "policyengine-uk"}, + ) assert response.status_code == 200 assert response.json()["parent_path"] == "" @@ -339,7 +391,11 @@ def test_leaf_parameter_includes_full_metadata(self, client, session, uk_version _add_params(session, uk_version, [("gov.rate", "The rate")]) response = client.get( - "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + "/parameters/children", + params={ + "tax_benefit_model_name": "policyengine-uk", + "parent_path": "gov", + }, ) param = response.json()["children"][0]["parameter"] @@ -354,7 +410,11 @@ def test_node_has_no_parameter_field(self, client, session, uk_version): _add_params(session, uk_version, [("gov.hmrc.rate", "Rate")]) response = client.get( - "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + "/parameters/children", + params={ + "tax_benefit_model_name": "policyengine-uk", + "parent_path": "gov", + }, ) node = response.json()["children"][0] @@ -378,7 +438,10 @@ def test_deep_nesting(self, client, session, uk_version): ]: resp = client.get( "/parameters/children", - params={"country_id": "uk", "parent_path": parent}, + params={ + "tax_benefit_model_name": "policyengine-uk", + "parent_path": parent, + }, ) children = resp.json()["children"] assert len(children) == 1 @@ -389,7 +452,7 @@ def test_deep_nesting(self, client, session, uk_version): resp = client.get( "/parameters/children", params={ - "country_id": "uk", + "tax_benefit_model_name": "policyengine-uk", "parent_path": "gov.hmrc.income_tax.rates.uk[0]", }, ) @@ -397,3 +460,86 @@ def test_deep_nesting(self, client, session, uk_version): assert len(children) == 1 assert children[0]["type"] == "parameter" assert children[0]["path"] == "gov.hmrc.income_tax.rates.uk[0].rate" + + +class TestVersionFiltering: + """Tests for version filtering on the children endpoint.""" + + def test_defaults_to_latest_version(self, client, session): + """When only model name is given, returns children from latest version.""" + model = TaxBenefitModel(name="policyengine-uk", description="UK") + session.add(model) + session.commit() + session.refresh(model) + + v1 = TaxBenefitModelVersion( + model_id=model.id, + version="1.0", + created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + ) + v2 = TaxBenefitModelVersion( + model_id=model.id, + version="2.0", + created_at=datetime(2025, 6, 1, tzinfo=timezone.utc), + ) + session.add(v1) + session.add(v2) + session.commit() + session.refresh(v1) + session.refresh(v2) + + _add_params(session, v1, [("gov.old_param", "Old")]) + _add_params(session, v2, [("gov.new_param", "New")]) + + response = client.get( + "/parameters/children", + params={ + "tax_benefit_model_name": "policyengine-uk", + "parent_path": "gov", + }, + ) + + assert response.status_code == 200 + children = response.json()["children"] + assert len(children) == 1 + assert children[0]["path"] == "gov.new_param" + + def test_explicit_version_id_returns_that_version(self, client, session): + """When version ID is given, returns children from that specific version.""" + model = TaxBenefitModel(name="policyengine-uk", description="UK") + session.add(model) + session.commit() + session.refresh(model) + + v1 = TaxBenefitModelVersion( + model_id=model.id, + version="1.0", + created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + ) + v2 = TaxBenefitModelVersion( + model_id=model.id, + version="2.0", + created_at=datetime(2025, 6, 1, tzinfo=timezone.utc), + ) + session.add(v1) + session.add(v2) + session.commit() + session.refresh(v1) + session.refresh(v2) + + _add_params(session, v1, [("gov.old_param", "Old")]) + _add_params(session, v2, [("gov.new_param", "New")]) + + response = client.get( + "/parameters/children", + params={ + "tax_benefit_model_name": "policyengine-uk", + "parent_path": "gov", + "tax_benefit_model_version_id": str(v1.id), + }, + ) + + assert response.status_code == 200 + children = response.json()["children"] + assert len(children) == 1 + assert children[0]["path"] == "gov.old_param" diff --git a/tests/test_variables.py b/tests/test_variables.py index 76cb47a..d64e6aa 100644 --- a/tests/test_variables.py +++ b/tests/test_variables.py @@ -1,9 +1,12 @@ """Tests for variable endpoints.""" +from datetime import datetime, timezone from uuid import uuid4 import pytest +from policyengine_api.models import TaxBenefitModel, TaxBenefitModelVersion, Variable + def test_list_variables(client): """List variables returns a list.""" @@ -19,5 +22,131 @@ def test_get_variable_not_found(client): assert response.status_code == 404 +# ----------------------------------------------------------------------------- +# Version Filtering Tests +# ----------------------------------------------------------------------------- + + +def _create_var(session, version, name): + """Create and persist a Variable.""" + var = Variable( + name=name, + entity="person", + tax_benefit_model_version_id=version.id, + ) + session.add(var) + session.commit() + session.refresh(var) + return var + + +def test__given_model_name__then_returns_only_latest_version_variables( + client, session +): + """GET /variables?tax_benefit_model_name=X returns only latest version's vars.""" + model = TaxBenefitModel(name="policyengine-us", description="US") + session.add(model) + session.commit() + session.refresh(model) + + v1 = TaxBenefitModelVersion( + model_id=model.id, + version="1.0", + created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + ) + v2 = TaxBenefitModelVersion( + model_id=model.id, + version="2.0", + created_at=datetime(2025, 6, 1, tzinfo=timezone.utc), + ) + session.add(v1) + session.add(v2) + session.commit() + session.refresh(v1) + session.refresh(v2) + + _create_var(session, v1, "old_variable") + _create_var(session, v2, "new_variable") + + response = client.get("/variables?tax_benefit_model_name=policyengine-us") + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["name"] == "new_variable" + + +def test__given_explicit_version_id__then_returns_that_versions_variables( + client, session +): + """GET /variables?tax_benefit_model_version_id=X returns that version's vars.""" + model = TaxBenefitModel(name="policyengine-us", description="US") + session.add(model) + session.commit() + session.refresh(model) + + v1 = TaxBenefitModelVersion( + model_id=model.id, + version="1.0", + created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + ) + v2 = TaxBenefitModelVersion( + model_id=model.id, + version="2.0", + created_at=datetime(2025, 6, 1, tzinfo=timezone.utc), + ) + session.add(v1) + session.add(v2) + session.commit() + session.refresh(v1) + session.refresh(v2) + + _create_var(session, v1, "old_variable") + _create_var(session, v2, "new_variable") + + response = client.get(f"/variables?tax_benefit_model_version_id={v1.id}") + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["name"] == "old_variable" + + +def test__given_version_id_overrides_model_name(client, session): + """Version ID takes precedence over model name when both are provided.""" + model = TaxBenefitModel(name="policyengine-us", description="US") + session.add(model) + session.commit() + session.refresh(model) + + v1 = TaxBenefitModelVersion( + model_id=model.id, + version="1.0", + created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + ) + v2 = TaxBenefitModelVersion( + model_id=model.id, + version="2.0", + created_at=datetime(2025, 6, 1, tzinfo=timezone.utc), + ) + session.add(v1) + session.add(v2) + session.commit() + session.refresh(v1) + session.refresh(v2) + + _create_var(session, v1, "old_variable") + _create_var(session, v2, "new_variable") + + response = client.get( + f"/variables?tax_benefit_model_name=policyengine-us&tax_benefit_model_version_id={v1.id}" + ) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["name"] == "old_variable" + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/test_variables_by_name.py b/tests/test_variables_by_name.py index 3639fea..7727d22 100644 --- a/tests/test_variables_by_name.py +++ b/tests/test_variables_by_name.py @@ -1,5 +1,7 @@ """Tests for POST /variables/by-name endpoint.""" +from datetime import datetime, timezone + import pytest from policyengine_api.models import ( @@ -76,7 +78,10 @@ def test_returns_matching_variables(self, client, session, uk_version): response = client.post( "/variables/by-name", - json={"names": ["employment_income", "income_tax"], "country_id": "uk"}, + json={ + "names": ["employment_income", "income_tax"], + "tax_benefit_model_name": "policyengine-uk", + }, ) assert response.status_code == 200 @@ -85,11 +90,14 @@ def test_returns_matching_variables(self, client, session, uk_version): returned_names = {v["name"] for v in data} assert returned_names == {"employment_income", "income_tax"} - def test_returns_empty_list_for_empty_names(self, client): + def test_returns_empty_list_for_empty_names(self, client, session, uk_version): """Given an empty names list, returns an empty list.""" response = client.post( "/variables/by-name", - json={"names": [], "country_id": "uk"}, + json={ + "names": [], + "tax_benefit_model_name": "policyengine-uk", + }, ) assert response.status_code == 200 @@ -101,7 +109,10 @@ def test_returns_empty_list_for_unknown_names(self, client, session, uk_version) response = client.post( "/variables/by-name", - json={"names": ["nonexistent_var", "also_missing"], "country_id": "uk"}, + json={ + "names": ["nonexistent_var", "also_missing"], + "tax_benefit_model_name": "policyengine-uk", + }, ) assert response.status_code == 200 @@ -115,7 +126,10 @@ def test_returns_only_matching_when_mix_of_known_and_unknown( response = client.post( "/variables/by-name", - json={"names": ["income_tax", "fake_var"], "country_id": "uk"}, + json={ + "names": ["income_tax", "fake_var"], + "tax_benefit_model_name": "policyengine-uk", + }, ) assert response.status_code == 200 @@ -129,7 +143,10 @@ def test_single_name_lookup(self, client, session, uk_version): response = client.post( "/variables/by-name", - json={"names": ["age"], "country_id": "uk"}, + json={ + "names": ["age"], + "tax_benefit_model_name": "policyengine-uk", + }, ) assert response.status_code == 200 @@ -147,7 +164,7 @@ def test_results_ordered_by_name(self, client, session, uk_version): "/variables/by-name", json={ "names": ["zzz_var", "aaa_var", "mmm_var"], - "country_id": "uk", + "tax_benefit_model_name": "policyengine-uk", }, ) @@ -161,7 +178,10 @@ def test_response_shape_matches_variable_read(self, client, session, uk_version) response = client.post( "/variables/by-name", - json={"names": ["income_tax"], "country_id": "uk"}, + json={ + "names": ["income_tax"], + "tax_benefit_model_name": "policyengine-uk", + }, ) assert response.status_code == 200 @@ -174,21 +194,27 @@ def test_response_shape_matches_variable_read(self, client, session, uk_version) assert "tax_benefit_model_version_id" in var -class TestVariablesByNameCountryFiltering: - """Tests for country_id filtering.""" +class TestVariablesByNameModelFiltering: + """Tests for tax_benefit_model_name filtering.""" - def test_country_isolation(self, client, session, uk_version, us_version): - """Variables from a different country are excluded.""" + def test_model_isolation(self, client, session, uk_version, us_version): + """Variables from a different model are excluded.""" _add_var(session, uk_version, "council_tax") _add_var(session, us_version, "state_income_tax") uk_response = client.post( "/variables/by-name", - json={"names": ["council_tax", "state_income_tax"], "country_id": "uk"}, + json={ + "names": ["council_tax", "state_income_tax"], + "tax_benefit_model_name": "policyengine-uk", + }, ) us_response = client.post( "/variables/by-name", - json={"names": ["council_tax", "state_income_tax"], "country_id": "us"}, + json={ + "names": ["council_tax", "state_income_tax"], + "tax_benefit_model_name": "policyengine-us", + }, ) assert len(uk_response.json()) == 1 @@ -196,21 +222,12 @@ def test_country_isolation(self, client, session, uk_version, us_version): assert len(us_response.json()) == 1 assert us_response.json()[0]["name"] == "state_income_tax" - def test_invalid_country_id_returns_422(self, client): - """An invalid country_id is rejected.""" - response = client.post( - "/variables/by-name", - json={"names": ["income_tax"], "country_id": "fr"}, - ) - - assert response.status_code == 422 - class TestVariablesByNameValidation: """Tests for request validation.""" - def test_missing_country_id_returns_422(self, client): - """Request without country_id is rejected.""" + def test_missing_model_name_returns_422(self, client): + """Request without tax_benefit_model_name is rejected.""" response = client.post( "/variables/by-name", json={"names": ["income_tax"]}, @@ -222,7 +239,90 @@ def test_missing_names_field_returns_422(self, client): """Request without names field is rejected.""" response = client.post( "/variables/by-name", - json={"country_id": "uk"}, + json={"tax_benefit_model_name": "policyengine-uk"}, ) assert response.status_code == 422 + + +class TestVariablesByNameVersionFilter: + """Tests for version filtering on the by-name endpoint.""" + + def test_defaults_to_latest_version(self, client, session): + """When only model name is given, returns variables from latest version.""" + model = TaxBenefitModel(name="policyengine-uk", description="UK") + session.add(model) + session.commit() + session.refresh(model) + + v1 = TaxBenefitModelVersion( + model_id=model.id, + version="1.0", + created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + ) + v2 = TaxBenefitModelVersion( + model_id=model.id, + version="2.0", + created_at=datetime(2025, 6, 1, tzinfo=timezone.utc), + ) + session.add(v1) + session.add(v2) + session.commit() + session.refresh(v1) + session.refresh(v2) + + _add_var(session, v1, "old_var") + _add_var(session, v2, "new_var") + + response = client.post( + "/variables/by-name", + json={ + "names": ["old_var", "new_var"], + "tax_benefit_model_name": "policyengine-uk", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["name"] == "new_var" + + def test_explicit_version_id_returns_that_version(self, client, session): + """When version ID is given, returns variables from that specific version.""" + model = TaxBenefitModel(name="policyengine-uk", description="UK") + session.add(model) + session.commit() + session.refresh(model) + + v1 = TaxBenefitModelVersion( + model_id=model.id, + version="1.0", + created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + ) + v2 = TaxBenefitModelVersion( + model_id=model.id, + version="2.0", + created_at=datetime(2025, 6, 1, tzinfo=timezone.utc), + ) + session.add(v1) + session.add(v2) + session.commit() + session.refresh(v1) + session.refresh(v2) + + _add_var(session, v1, "old_var") + _add_var(session, v2, "new_var") + + response = client.post( + "/variables/by-name", + json={ + "names": ["old_var", "new_var"], + "tax_benefit_model_name": "policyengine-uk", + "tax_benefit_model_version_id": str(v1.id), + }, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["name"] == "old_var" diff --git a/tests/test_version_filter_service.py b/tests/test_version_filter_service.py new file mode 100644 index 0000000..232b905 --- /dev/null +++ b/tests/test_version_filter_service.py @@ -0,0 +1,134 @@ +"""Tests for the tax benefit model version resolution service.""" + +from datetime import datetime, timedelta, timezone +from uuid import uuid4 + +import pytest +from fastapi import HTTPException + +from policyengine_api.models import TaxBenefitModel, TaxBenefitModelVersion +from policyengine_api.services.tax_benefit_models import ( + get_latest_model_version, + get_model_version_by_id, + resolve_model_version_id, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def us_model(session): + """Create a policyengine-us model.""" + model = TaxBenefitModel(name="policyengine-us", description="US model") + session.add(model) + session.commit() + session.refresh(model) + return model + + +@pytest.fixture +def us_versions(session, us_model): + """Create two versions for the US model, v1 older than v2.""" + v1 = TaxBenefitModelVersion( + model_id=us_model.id, + version="1.0.0", + description="First version", + created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + ) + v2 = TaxBenefitModelVersion( + model_id=us_model.id, + version="2.0.0", + description="Second version", + created_at=datetime(2025, 6, 1, tzinfo=timezone.utc), + ) + session.add(v1) + session.add(v2) + session.commit() + session.refresh(v1) + session.refresh(v2) + return v1, v2 + + +# --------------------------------------------------------------------------- +# get_latest_model_version +# --------------------------------------------------------------------------- + + +class TestGetLatestModelVersion: + def test_returns_latest_version(self, session, us_model, us_versions): + """Given multiple versions, returns the one with the newest created_at.""" + v1, v2 = us_versions + result = get_latest_model_version("policyengine-us", session) + assert result.id == v2.id + assert result.version == "2.0.0" + + def test_normalizes_underscores_to_hyphens(self, session, us_model, us_versions): + """Underscore names like 'policyengine_us' are normalized.""" + v1, v2 = us_versions + result = get_latest_model_version("policyengine_us", session) + assert result.id == v2.id + + def test_nonexistent_model_raises_404(self, session): + """A model name that doesn't exist raises HTTPException 404.""" + with pytest.raises(HTTPException) as exc_info: + get_latest_model_version("nonexistent-model", session) + assert exc_info.value.status_code == 404 + + def test_model_with_no_versions_raises_404(self, session, us_model): + """A model that exists but has no versions raises HTTPException 404.""" + with pytest.raises(HTTPException) as exc_info: + get_latest_model_version("policyengine-us", session) + assert exc_info.value.status_code == 404 + + +# --------------------------------------------------------------------------- +# get_model_version_by_id +# --------------------------------------------------------------------------- + + +class TestGetModelVersionById: + def test_returns_version(self, session, us_versions): + """Given a valid version UUID, returns that version.""" + v1, v2 = us_versions + result = get_model_version_by_id(v1.id, session) + assert result.id == v1.id + assert result.version == "1.0.0" + + def test_nonexistent_id_raises_404(self, session): + """A UUID that doesn't match any version raises HTTPException 404.""" + with pytest.raises(HTTPException) as exc_info: + get_model_version_by_id(uuid4(), session) + assert exc_info.value.status_code == 404 + + +# --------------------------------------------------------------------------- +# resolve_model_version_id +# --------------------------------------------------------------------------- + + +class TestResolveModelVersionId: + def test_version_id_takes_precedence(self, session, us_versions): + """When both model name and version ID are given, version ID wins.""" + v1, v2 = us_versions + result = resolve_model_version_id("policyengine-us", v1.id, session) + assert result == v1.id + + def test_model_name_resolves_to_latest(self, session, us_model, us_versions): + """When only model name is given, resolves to the latest version.""" + v1, v2 = us_versions + result = resolve_model_version_id("policyengine-us", None, session) + assert result == v2.id + + def test_neither_returns_none(self, session): + """When neither model name nor version ID is given, returns None.""" + result = resolve_model_version_id(None, None, session) + assert result is None + + def test_invalid_version_id_raises_404(self, session): + """An explicit version ID that doesn't exist raises 404.""" + with pytest.raises(HTTPException) as exc_info: + resolve_model_version_id(None, uuid4(), session) + assert exc_info.value.status_code == 404 From fdc098a5adac5806c8e36cf668cabb5bcef864a8 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 4 Mar 2026 21:19:55 +0100 Subject: [PATCH 2/3] test: Refactor version-filter tests to use shared fixtures Move all inline model/version/parameter/variable factories into test_fixtures/fixtures_version_filter.py and rewrite test files to import from there. Add new test_parameter_values.py covering CRUD, filtering, pagination, and version filtering for parameter values. Co-Authored-By: Claude Opus 4.6 --- test_fixtures/fixtures_version_filter.py | 207 ++++++++++ tests/test_parameter_values.py | 258 ++++++++++++ tests/test_parameters.py | 487 +++++++++-------------- tests/test_parameters_by_name.py | 381 +++++++----------- tests/test_parameters_children.py | 450 ++++++++++----------- tests/test_variables.py | 322 ++++++++------- tests/test_variables_by_name.py | 383 +++++++----------- tests/test_version_filter_service.py | 134 +++---- 8 files changed, 1402 insertions(+), 1220 deletions(-) create mode 100644 test_fixtures/fixtures_version_filter.py create mode 100644 tests/test_parameter_values.py diff --git a/test_fixtures/fixtures_version_filter.py b/test_fixtures/fixtures_version_filter.py new file mode 100644 index 0000000..c7304ad --- /dev/null +++ b/test_fixtures/fixtures_version_filter.py @@ -0,0 +1,207 @@ +"""Fixtures and helpers for version-filtering tests. + +Provides reusable model/version/parameter/variable factories and composite +fixtures for testing version-filter behaviour across endpoints. +""" + +from datetime import datetime, timezone + +import pytest + +from policyengine_api.models import ( + Parameter, + ParameterValue, + Policy, + TaxBenefitModel, + TaxBenefitModelVersion, + Variable, +) + +# ----------------------------------------------------------------------------- +# Constants +# ----------------------------------------------------------------------------- + +MODEL_NAMES = { + "US": "policyengine-us", + "UK": "policyengine-uk", +} + +VERSION_TIMESTAMPS = { + "V1": datetime(2025, 1, 1, tzinfo=timezone.utc), + "V2": datetime(2025, 6, 1, tzinfo=timezone.utc), +} + + +# ----------------------------------------------------------------------------- +# Factory Functions +# ----------------------------------------------------------------------------- + + +def create_model( + session, name: str = MODEL_NAMES["US"], description: str = "Test model" +) -> TaxBenefitModel: + """Create and persist a TaxBenefitModel.""" + model = TaxBenefitModel(name=name, description=description) + session.add(model) + session.commit() + session.refresh(model) + return model + + +def create_version( + session, + model: TaxBenefitModel, + version: str = "1.0.0", + created_at: datetime | None = None, +) -> TaxBenefitModelVersion: + """Create and persist a TaxBenefitModelVersion.""" + ver = TaxBenefitModelVersion( + model_id=model.id, + version=version, + description=f"Version {version}", + **({"created_at": created_at} if created_at else {}), + ) + session.add(ver) + session.commit() + session.refresh(ver) + return ver + + +def create_parameter( + session, + model_version: TaxBenefitModelVersion, + name: str, + label: str = "", + description: str | None = None, +) -> Parameter: + """Create and persist a Parameter.""" + param = Parameter( + name=name, + label=label or name.rsplit(".", 1)[-1], + description=description, + tax_benefit_model_version_id=model_version.id, + ) + session.add(param) + session.commit() + session.refresh(param) + return param + + +def create_variable( + session, + model_version: TaxBenefitModelVersion, + name: str, + entity: str = "person", + description: str | None = None, +) -> Variable: + """Create and persist a Variable.""" + var = Variable( + name=name, + entity=entity, + description=description, + tax_benefit_model_version_id=model_version.id, + ) + session.add(var) + session.commit() + session.refresh(var) + return var + + +def create_parameter_value( + session, + parameter_id, + value: int | float | dict, + policy_id=None, + start_date: datetime | None = None, +) -> ParameterValue: + """Create and persist a ParameterValue.""" + pv = ParameterValue( + parameter_id=parameter_id, + value_json=value, + start_date=start_date or datetime.now(timezone.utc), + policy_id=policy_id, + ) + session.add(pv) + session.commit() + session.refresh(pv) + return pv + + +def create_policy( + session, + model: TaxBenefitModel, + name: str = "Test Policy", +) -> Policy: + """Create and persist a Policy.""" + policy = Policy( + name=name, + description=f"Policy: {name}", + tax_benefit_model_id=model.id, + ) + session.add(policy) + session.commit() + session.refresh(policy) + return policy + + +def add_params_bulk(session, version, names_and_labels): + """Bulk-add parameters. names_and_labels is [(name, label), ...].""" + for name, label in names_and_labels: + session.add( + Parameter( + name=name, + label=label, + tax_benefit_model_version_id=version.id, + ) + ) + session.commit() + + +# ----------------------------------------------------------------------------- +# Composite Fixtures — single model + single version +# ----------------------------------------------------------------------------- + + +@pytest.fixture +def us_model(session): + """Create a policyengine-us model.""" + return create_model(session, MODEL_NAMES["US"], "US model") + + +@pytest.fixture +def uk_model(session): + """Create a policyengine-uk model.""" + return create_model(session, MODEL_NAMES["UK"], "UK model") + + +@pytest.fixture +def us_version(session, us_model): + """Create a single US model version.""" + return create_version(session, us_model, "1.0") + + +@pytest.fixture +def uk_version(session, uk_model): + """Create a single UK model version.""" + return create_version(session, uk_model, "1.0") + + +# ----------------------------------------------------------------------------- +# Composite Fixtures — single model + TWO versions (for version-filter tests) +# ----------------------------------------------------------------------------- + + +@pytest.fixture +def us_two_versions(session, us_model): + """Create two US versions: v1 (old) and v2 (latest).""" + v1 = create_version(session, us_model, "1.0", VERSION_TIMESTAMPS["V1"]) + v2 = create_version(session, us_model, "2.0", VERSION_TIMESTAMPS["V2"]) + return v1, v2 + + +@pytest.fixture +def uk_two_versions(session, uk_model): + """Create two UK versions: v1 (old) and v2 (latest).""" + v1 = create_version(session, uk_model, "1.0", VERSION_TIMESTAMPS["V1"]) + v2 = create_version(session, uk_model, "2.0", VERSION_TIMESTAMPS["V2"]) + return v1, v2 diff --git a/tests/test_parameter_values.py b/tests/test_parameter_values.py new file mode 100644 index 0000000..1792f53 --- /dev/null +++ b/tests/test_parameter_values.py @@ -0,0 +1,258 @@ +"""Tests for GET /parameter-values/ and GET /parameter-values/{id} endpoints.""" + +from datetime import datetime, timezone +from uuid import uuid4 + +import pytest + +from test_fixtures.fixtures_version_filter import ( + MODEL_NAMES, + create_parameter, + create_parameter_value, + create_policy, + us_model, # noqa: F401 + us_two_versions, # noqa: F401 + us_version, # noqa: F401 +) + + +# ----------------------------------------------------------------------------- +# GET /parameter-values/ — basic +# ----------------------------------------------------------------------------- + + +class TestListParameterValues: + def test_given_no_values_then_returns_empty_list(self, client): + """Empty database returns an empty list.""" + response = client.get("/parameter-values") + assert response.status_code == 200 + assert response.json() == [] + + def test_given_values_exist_then_returns_list( + self, client, session, us_version # noqa: F811 + ): + """Returns parameter values that exist.""" + param = create_parameter(session, us_version, "gov.rate", "Rate") + create_parameter_value(session, param.id, 0.2) + + data = client.get("/parameter-values").json() + assert len(data) == 1 + + def test_given_parameter_id_then_filters_by_parameter( + self, client, session, us_version # noqa: F811 + ): + """Filters values to a specific parameter.""" + p1 = create_parameter(session, us_version, "gov.rate", "Rate") + p2 = create_parameter(session, us_version, "gov.threshold", "Threshold") + create_parameter_value(session, p1.id, 0.2) + create_parameter_value(session, p2.id, 12570) + + data = client.get(f"/parameter-values?parameter_id={p1.id}").json() + assert len(data) == 1 + assert data[0]["parameter_id"] == str(p1.id) + + def test_given_policy_id_then_filters_by_policy( + self, client, session, us_version, us_model # noqa: F811 + ): + """Filters values to a specific policy.""" + param = create_parameter(session, us_version, "gov.rate", "Rate") + policy = create_policy(session, us_model, "Reform A") + create_parameter_value(session, param.id, 0.2) + create_parameter_value(session, param.id, 0.25, policy_id=policy.id) + + data = client.get(f"/parameter-values?policy_id={policy.id}").json() + assert len(data) == 1 + assert data[0]["policy_id"] == str(policy.id) + + def test_given_combined_parameter_and_policy_filters( + self, client, session, us_version, us_model # noqa: F811 + ): + """parameter_id + policy_id work together.""" + p1 = create_parameter(session, us_version, "gov.rate", "Rate") + p2 = create_parameter(session, us_version, "gov.threshold", "Threshold") + policy = create_policy(session, us_model, "Reform A") + create_parameter_value(session, p1.id, 0.2, policy_id=policy.id) + create_parameter_value(session, p2.id, 12570, policy_id=policy.id) + create_parameter_value(session, p1.id, 0.15) + + data = client.get( + f"/parameter-values?parameter_id={p1.id}&policy_id={policy.id}" + ).json() + assert len(data) == 1 + + def test_given_limit_then_returns_at_most_n( + self, client, session, us_version # noqa: F811 + ): + """Limit caps the number of results.""" + param = create_parameter(session, us_version, "gov.rate", "Rate") + for i in range(5): + create_parameter_value( + session, + param.id, + i * 0.1, + start_date=datetime(2020 + i, 1, 1, tzinfo=timezone.utc), + ) + + assert len(client.get("/parameter-values?limit=2").json()) == 2 + + def test_given_skip_then_skips_first_n( + self, client, session, us_version # noqa: F811 + ): + """Skip omits the first N results.""" + param = create_parameter(session, us_version, "gov.rate", "Rate") + for i in range(5): + create_parameter_value( + session, + param.id, + i * 0.1, + start_date=datetime(2020 + i, 1, 1, tzinfo=timezone.utc), + ) + + assert len(client.get("/parameter-values?skip=3&limit=10").json()) == 2 + + def test_results_ordered_by_start_date_desc( + self, client, session, us_version # noqa: F811 + ): + """Parameter values come back sorted by start_date descending.""" + param = create_parameter(session, us_version, "gov.rate", "Rate") + create_parameter_value( + session, param.id, 0.1, + start_date=datetime(2020, 1, 1, tzinfo=timezone.utc), + ) + create_parameter_value( + session, param.id, 0.2, + start_date=datetime(2025, 1, 1, tzinfo=timezone.utc), + ) + + data = client.get("/parameter-values").json() + assert len(data) == 2 + # Most recent first + dates = [d["start_date"] for d in data] + assert dates[0] > dates[1] + + +# ----------------------------------------------------------------------------- +# GET /parameter-values/ — version filtering +# ----------------------------------------------------------------------------- + + +class TestListParameterValuesVersionFilter: + def test_given_model_name_then_returns_only_latest_version( + self, client, session, us_two_versions # noqa: F811 + ): + """Model name resolves to latest version; old-version param values excluded.""" + v1, v2 = us_two_versions + p_old = create_parameter(session, v1, "gov.old", "Old") + p_new = create_parameter(session, v2, "gov.new", "New") + create_parameter_value(session, p_old.id, 0.1) + create_parameter_value(session, p_new.id, 0.2) + + data = client.get( + f"/parameter-values?tax_benefit_model_name={MODEL_NAMES['US']}" + ).json() + assert len(data) == 1 + assert data[0]["parameter_id"] == str(p_new.id) + + def test_given_explicit_version_id_then_returns_that_version( + self, client, session, us_two_versions # noqa: F811 + ): + """Explicit version_id pins to a specific version.""" + v1, v2 = us_two_versions + p_old = create_parameter(session, v1, "gov.old", "Old") + p_new = create_parameter(session, v2, "gov.new", "New") + create_parameter_value(session, p_old.id, 0.1) + create_parameter_value(session, p_new.id, 0.2) + + data = client.get( + f"/parameter-values?tax_benefit_model_version_id={v1.id}" + ).json() + assert len(data) == 1 + assert data[0]["parameter_id"] == str(p_old.id) + + def test_given_both_then_version_id_takes_precedence( + self, client, session, us_two_versions # noqa: F811 + ): + """version_id overrides model_name.""" + v1, v2 = us_two_versions + p_old = create_parameter(session, v1, "gov.old", "Old") + p_new = create_parameter(session, v2, "gov.new", "New") + create_parameter_value(session, p_old.id, 0.1) + create_parameter_value(session, p_new.id, 0.2) + + data = client.get( + f"/parameter-values?tax_benefit_model_name={MODEL_NAMES['US']}" + f"&tax_benefit_model_version_id={v1.id}" + ).json() + assert len(data) == 1 + assert data[0]["parameter_id"] == str(p_old.id) + + def test_given_no_filters_then_returns_all_versions( + self, client, session, us_two_versions # noqa: F811 + ): + """Without model/version filter, values from all versions are returned.""" + v1, v2 = us_two_versions + p_old = create_parameter(session, v1, "gov.old", "Old") + p_new = create_parameter(session, v2, "gov.new", "New") + create_parameter_value(session, p_old.id, 0.1) + create_parameter_value(session, p_new.id, 0.2) + + data = client.get("/parameter-values").json() + assert len(data) == 2 + + def test_given_nonexistent_model_name_then_returns_404(self, client): + """Unknown model name returns 404.""" + response = client.get( + "/parameter-values?tax_benefit_model_name=nonexistent-model" + ) + assert response.status_code == 404 + + def test_given_version_filter_combined_with_parameter_id( + self, client, session, us_two_versions # noqa: F811 + ): + """Version filter + parameter_id work together.""" + v1, v2 = us_two_versions + p1 = create_parameter(session, v2, "gov.rate", "Rate") + p2 = create_parameter(session, v2, "gov.threshold", "Threshold") + create_parameter_value(session, p1.id, 0.2) + create_parameter_value(session, p2.id, 12570) + + data = client.get( + f"/parameter-values?tax_benefit_model_name={MODEL_NAMES['US']}" + f"¶meter_id={p1.id}" + ).json() + assert len(data) == 1 + assert data[0]["parameter_id"] == str(p1.id) + + +# ----------------------------------------------------------------------------- +# GET /parameter-values/{id} +# ----------------------------------------------------------------------------- + + +class TestGetParameterValue: + def test_given_valid_id_then_returns_value( + self, client, session, us_version # noqa: F811 + ): + """Returns the full parameter value data.""" + param = create_parameter(session, us_version, "gov.rate", "Rate") + pv = create_parameter_value(session, param.id, 0.2) + + response = client.get(f"/parameter-values/{pv.id}") + assert response.status_code == 200 + assert response.json()["parameter_id"] == str(param.id) + + def test_given_nonexistent_id_then_returns_404(self, client): + """Unknown UUID returns 404.""" + response = client.get(f"/parameter-values/{uuid4()}") + assert response.status_code == 404 + + def test_response_shape_matches_parameter_value_read( + self, client, session, us_version # noqa: F811 + ): + """Response contains all ParameterValueRead fields.""" + param = create_parameter(session, us_version, "gov.rate", "Rate") + pv = create_parameter_value(session, param.id, 0.2) + + data = client.get(f"/parameter-values/{pv.id}").json() + for field in ("id", "parameter_id", "value_json", "start_date", "created_at"): + assert field in data diff --git a/tests/test_parameters.py b/tests/test_parameters.py index c8f39d0..eeb5673 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -1,320 +1,215 @@ -"""Tests for parameter and parameter-value endpoints.""" +"""Tests for GET /parameters/ and GET /parameters/{id} endpoints.""" -from datetime import datetime, timezone from uuid import uuid4 import pytest -from policyengine_api.models import TaxBenefitModel, TaxBenefitModelVersion -from test_fixtures.fixtures_parameters import ( +from test_fixtures.fixtures_version_filter import ( + MODEL_NAMES, create_parameter, - create_parameter_value, - create_parameter_values_batch, - create_policy, - model_version, # noqa: F401 - pytest fixture + create_version, + us_model, # noqa: F401 + us_two_versions, # noqa: F401 + us_version, # noqa: F401 ) # ----------------------------------------------------------------------------- -# Parameter Endpoint Tests +# GET /parameters/ — basic # ----------------------------------------------------------------------------- -def test__given_parameters_endpoint_called__then_returns_list(client): - """GET /parameters returns a list.""" - # Given - endpoint = "/parameters" +class TestListParameters: + def test_given_no_params_then_returns_empty_list(self, client): + """Empty database returns an empty list.""" + response = client.get("/parameters") + assert response.status_code == 200 + assert response.json() == [] + + def test_given_parameters_exist_then_returns_list( + self, client, session, us_version # noqa: F811 + ): + """Returns parameters that exist.""" + create_parameter(session, us_version, "gov.rate", "Rate") + response = client.get("/parameters") + assert response.status_code == 200 + assert len(response.json()) == 1 + + def test_given_search_by_name_then_returns_matching( + self, client, session, us_version # noqa: F811 + ): + """Search filter matches parameter name.""" + create_parameter(session, us_version, "gov.tax.rate", "Rate") + create_parameter(session, us_version, "gov.benefit.amount", "Amount") + + response = client.get("/parameters?search=tax") + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["name"] == "gov.tax.rate" + + def test_given_search_by_label_then_returns_matching( + self, client, session, us_version # noqa: F811 + ): + """Search filter matches parameter label (case-insensitive).""" + create_parameter(session, us_version, "gov.x", "Basic Rate") + create_parameter(session, us_version, "gov.y", "Amount") + + response = client.get("/parameters?search=basic") + data = response.json() + assert len(data) == 1 + assert data[0]["label"] == "Basic Rate" + + def test_given_search_by_description_then_returns_matching( + self, client, session, us_version # noqa: F811 + ): + """Search filter matches parameter description.""" + create_parameter( + session, us_version, "gov.x", "X", description="The income tax rate" + ) + create_parameter(session, us_version, "gov.y", "Y", description="A benefit") + + response = client.get("/parameters?search=income") + data = response.json() + assert len(data) == 1 + assert data[0]["name"] == "gov.x" + + def test_given_limit_then_returns_at_most_n( + self, client, session, us_version # noqa: F811 + ): + """Limit caps the number of results.""" + for i in range(5): + create_parameter(session, us_version, f"gov.p{i}", f"P{i}") + + response = client.get("/parameters?limit=2") + assert len(response.json()) == 2 + + def test_given_skip_then_skips_first_n( + self, client, session, us_version # noqa: F811 + ): + """Skip omits the first N results.""" + for i in range(5): + create_parameter(session, us_version, f"gov.p{i}", f"P{i}") + + response = client.get("/parameters?skip=3&limit=10") + assert len(response.json()) == 2 + + def test_results_ordered_by_name( + self, client, session, us_version # noqa: F811 + ): + """Parameters come back sorted alphabetically by name.""" + create_parameter(session, us_version, "gov.zzz", "Z") + create_parameter(session, us_version, "gov.aaa", "A") + names = [p["name"] for p in client.get("/parameters").json()] + assert names == ["gov.aaa", "gov.zzz"] - # When - response = client.get(endpoint) - # Then - assert response.status_code == 200 - assert isinstance(response.json(), list) - - -def test__given_nonexistent_parameter_id__then_returns_404(client): - """GET /parameters/{id} returns 404 for non-existent parameter.""" - # Given - fake_id = uuid4() - - # When - response = client.get(f"/parameters/{fake_id}") - - # Then - assert response.status_code == 404 - - -# ----------------------------------------------------------------------------- -# Parameter Value Endpoint Tests -# ----------------------------------------------------------------------------- - - -def test__given_parameter_values_endpoint_called__then_returns_list(client): - """GET /parameter-values returns a list.""" - # Given - endpoint = "/parameter-values" - - # When - response = client.get(endpoint) - - # Then - assert response.status_code == 200 - assert isinstance(response.json(), list) - - -def test__given_nonexistent_parameter_value_id__then_returns_404(client): - """GET /parameter-values/{id} returns 404 for non-existent parameter value.""" - # Given - fake_id = uuid4() - - # When - response = client.get(f"/parameter-values/{fake_id}") - - # Then - assert response.status_code == 404 - - -# ----------------------------------------------------------------------------- -# Parameter Value Filtering Tests # ----------------------------------------------------------------------------- - - -def test__given_parameter_id_filter__then_returns_only_matching_values( - client, - session, - model_version, # noqa: F811 -): - """GET /parameter-values?parameter_id=X returns only values for that parameter.""" - # Given - param1 = create_parameter(session, model_version, "test.param1", "Test Param 1") - param2 = create_parameter(session, model_version, "test.param2", "Test Param 2") - create_parameter_value(session, param1.id, 100) - create_parameter_value(session, param2.id, 200) - - # When - response = client.get(f"/parameter-values?parameter_id={param1.id}") - - # Then - assert response.status_code == 200 - data = response.json() - assert len(data) == 1 - assert data[0]["parameter_id"] == str(param1.id) - - -def test__given_policy_id_filter__then_returns_only_matching_values( - client, - session, - model_version, # noqa: F811 -): - """GET /parameter-values?policy_id=X returns only values for that policy.""" - # Given - param = create_parameter(session, model_version, "test.param", "Test Param") - policy = create_policy(session, "Test Policy", model_version) - create_parameter_value(session, param.id, 100, policy_id=None) # baseline - create_parameter_value(session, param.id, 150, policy_id=policy.id) # reform - - # When - response = client.get(f"/parameter-values?policy_id={policy.id}") - - # Then - assert response.status_code == 200 - data = response.json() - assert len(data) == 1 - assert data[0]["policy_id"] == str(policy.id) - assert data[0]["value_json"] == 150 - - -def test__given_both_parameter_and_policy_filters__then_returns_matching_intersection( - client, - session, - model_version, # noqa: F811 -): - """GET /parameter-values?parameter_id=X&policy_id=Y returns intersection.""" - # Given - param1 = create_parameter( - session, model_version, "test.both.param1", "Test Both Param 1" - ) - param2 = create_parameter( - session, model_version, "test.both.param2", "Test Both Param 2" - ) - policy = create_policy(session, "Test Both Policy", model_version) - - create_parameter_value(session, param1.id, 100, policy_id=None) # baseline - create_parameter_value(session, param1.id, 150, policy_id=policy.id) # target - create_parameter_value(session, param2.id, 200, policy_id=policy.id) # other - - # When - response = client.get( - f"/parameter-values?parameter_id={param1.id}&policy_id={policy.id}" - ) - - # Then - assert response.status_code == 200 - data = response.json() - assert len(data) == 1 - assert data[0]["parameter_id"] == str(param1.id) - assert data[0]["policy_id"] == str(policy.id) - assert data[0]["value_json"] == 150 - - +# GET /parameters/ — version filtering # ----------------------------------------------------------------------------- -# Parameter Value Pagination Tests -# ----------------------------------------------------------------------------- - - -def test__given_limit_parameter__then_returns_limited_results( - client, - session, - model_version, # noqa: F811 -): - """GET /parameter-values?limit=N returns at most N results.""" - # Given - param = create_parameter( - session, model_version, "test.pagination.param", "Test Pagination Param" - ) - create_parameter_values_batch(session, param.id, count=5) - - # When - response = client.get(f"/parameter-values?parameter_id={param.id}&limit=2") - - # Then - assert response.status_code == 200 - assert len(response.json()) == 2 - - -def test__given_skip_parameter__then_skips_specified_results( - client, - session, - model_version, # noqa: F811 -): - """GET /parameter-values?skip=N skips first N results.""" - # Given - param = create_parameter( - session, model_version, "test.skip.param", "Test Skip Param" - ) - create_parameter_values_batch(session, param.id, count=5) - # When - response = client.get(f"/parameter-values?parameter_id={param.id}&skip=3&limit=10") - # Then - assert response.status_code == 200 - assert len(response.json()) == 2 # 5 total - 3 skipped = 2 remaining +class TestListParametersVersionFilter: + def test_given_model_name_then_returns_only_latest_version( + self, client, session, us_two_versions # noqa: F811 + ): + """Model name resolves to latest version; old-version params excluded.""" + v1, v2 = us_two_versions + create_parameter(session, v1, "gov.old", "Old") + create_parameter(session, v2, "gov.new", "New") + + data = client.get( + f"/parameters?tax_benefit_model_name={MODEL_NAMES['US']}" + ).json() + assert len(data) == 1 + assert data[0]["name"] == "gov.new" + + def test_given_explicit_version_id_then_returns_that_version( + self, client, session, us_two_versions # noqa: F811 + ): + """Explicit version_id pins to a specific version.""" + v1, v2 = us_two_versions + create_parameter(session, v1, "gov.old", "Old") + create_parameter(session, v2, "gov.new", "New") + + data = client.get( + f"/parameters?tax_benefit_model_version_id={v1.id}" + ).json() + assert len(data) == 1 + assert data[0]["name"] == "gov.old" + + def test_given_both_then_version_id_takes_precedence( + self, client, session, us_two_versions # noqa: F811 + ): + """version_id overrides model_name.""" + v1, v2 = us_two_versions + create_parameter(session, v1, "gov.old", "Old") + create_parameter(session, v2, "gov.new", "New") + + data = client.get( + f"/parameters?tax_benefit_model_name={MODEL_NAMES['US']}" + f"&tax_benefit_model_version_id={v1.id}" + ).json() + assert len(data) == 1 + assert data[0]["name"] == "gov.old" + + def test_given_no_filters_then_returns_all_versions( + self, client, session, us_two_versions # noqa: F811 + ): + """Without model/version filter, params from all versions are returned.""" + v1, v2 = us_two_versions + create_parameter(session, v1, "gov.old", "Old") + create_parameter(session, v2, "gov.new", "New") + + data = client.get("/parameters").json() + assert len(data) == 2 + + def test_given_nonexistent_model_name_then_returns_404(self, client): + """Unknown model name → 404.""" + response = client.get( + "/parameters?tax_benefit_model_name=nonexistent-model" + ) + assert response.status_code == 404 + + def test_given_search_combined_with_version_filter( + self, client, session, us_two_versions # noqa: F811 + ): + """Search + version filter work together.""" + v1, v2 = us_two_versions + create_parameter(session, v2, "gov.tax.rate", "Rate") + create_parameter(session, v2, "gov.benefit.amount", "Amount") + + data = client.get( + f"/parameters?tax_benefit_model_name={MODEL_NAMES['US']}&search=tax" + ).json() + assert len(data) == 1 + assert data[0]["name"] == "gov.tax.rate" # ----------------------------------------------------------------------------- -# Version Filtering Tests +# GET /parameters/{id} # ----------------------------------------------------------------------------- -def test__given_model_name__then_returns_only_latest_version_parameters( - client, session -): - """GET /parameters?tax_benefit_model_name=X returns only latest version's params.""" - model = TaxBenefitModel(name="policyengine-us", description="US") - session.add(model) - session.commit() - session.refresh(model) - - v1 = TaxBenefitModelVersion( - model_id=model.id, - version="1.0", - created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), - ) - v2 = TaxBenefitModelVersion( - model_id=model.id, - version="2.0", - created_at=datetime(2025, 6, 1, tzinfo=timezone.utc), - ) - session.add(v1) - session.add(v2) - session.commit() - session.refresh(v1) - session.refresh(v2) - - create_parameter(session, v1, "gov.old_param", "Old") - create_parameter(session, v2, "gov.new_param", "New") - - response = client.get("/parameters?tax_benefit_model_name=policyengine-us") - - assert response.status_code == 200 - data = response.json() - assert len(data) == 1 - assert data[0]["name"] == "gov.new_param" - - -def test__given_explicit_version_id__then_returns_that_versions_parameters( - client, session -): - """GET /parameters?tax_benefit_model_version_id=X returns that version's params.""" - model = TaxBenefitModel(name="policyengine-us", description="US") - session.add(model) - session.commit() - session.refresh(model) - - v1 = TaxBenefitModelVersion( - model_id=model.id, - version="1.0", - created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), - ) - v2 = TaxBenefitModelVersion( - model_id=model.id, - version="2.0", - created_at=datetime(2025, 6, 1, tzinfo=timezone.utc), - ) - session.add(v1) - session.add(v2) - session.commit() - session.refresh(v1) - session.refresh(v2) - - create_parameter(session, v1, "gov.old_param", "Old") - create_parameter(session, v2, "gov.new_param", "New") - - response = client.get(f"/parameters?tax_benefit_model_version_id={v1.id}") - - assert response.status_code == 200 - data = response.json() - assert len(data) == 1 - assert data[0]["name"] == "gov.old_param" - - -def test__given_version_id_overrides_model_name(client, session): - """Version ID takes precedence over model name when both are provided.""" - model = TaxBenefitModel(name="policyengine-us", description="US") - session.add(model) - session.commit() - session.refresh(model) - - v1 = TaxBenefitModelVersion( - model_id=model.id, - version="1.0", - created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), - ) - v2 = TaxBenefitModelVersion( - model_id=model.id, - version="2.0", - created_at=datetime(2025, 6, 1, tzinfo=timezone.utc), - ) - session.add(v1) - session.add(v2) - session.commit() - session.refresh(v1) - session.refresh(v2) - - create_parameter(session, v1, "gov.old_param", "Old") - create_parameter(session, v2, "gov.new_param", "New") - - # Provide model name (would resolve to v2) but also provide explicit v1 ID - response = client.get( - f"/parameters?tax_benefit_model_name=policyengine-us&tax_benefit_model_version_id={v1.id}" - ) - - assert response.status_code == 200 - data = response.json() - assert len(data) == 1 - assert data[0]["name"] == "gov.old_param" - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) +class TestGetParameter: + def test_given_valid_id_then_returns_parameter( + self, client, session, us_version # noqa: F811 + ): + """Returns the full parameter data.""" + param = create_parameter(session, us_version, "gov.rate", "Rate") + response = client.get(f"/parameters/{param.id}") + assert response.status_code == 200 + assert response.json()["name"] == "gov.rate" + + def test_given_nonexistent_id_then_returns_404(self, client): + """Unknown UUID → 404.""" + response = client.get(f"/parameters/{uuid4()}") + assert response.status_code == 404 + + def test_response_shape_matches_parameter_read( + self, client, session, us_version # noqa: F811 + ): + """Response contains all ParameterRead fields.""" + param = create_parameter(session, us_version, "gov.rate", "Rate") + data = client.get(f"/parameters/{param.id}").json() + for field in ("id", "name", "label", "created_at", "tax_benefit_model_version_id"): + assert field in data diff --git a/tests/test_parameters_by_name.py b/tests/test_parameters_by_name.py index 686e881..d08a8fb 100644 --- a/tests/test_parameters_by_name.py +++ b/tests/test_parameters_by_name.py @@ -1,314 +1,229 @@ """Tests for POST /parameters/by-name endpoint.""" -from datetime import datetime, timezone - import pytest -from policyengine_api.models import ( - Parameter, - TaxBenefitModel, - TaxBenefitModelVersion, +from test_fixtures.fixtures_version_filter import ( + MODEL_NAMES, + create_parameter, + uk_model, # noqa: F401 + uk_version, # noqa: F401 + us_model, # noqa: F401 + us_two_versions, # noqa: F401 + us_version, # noqa: F401 ) -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - -@pytest.fixture -def us_version(session): - """Create a policyengine-us model and version.""" - model = TaxBenefitModel(name="policyengine-us", description="US model") - session.add(model) - session.commit() - session.refresh(model) - - version = TaxBenefitModelVersion( - model_id=model.id, version="1.0", description="US v1" - ) - session.add(version) - session.commit() - session.refresh(version) - return version - - -def create_parameter(session, model_version, name: str, label: str) -> Parameter: - """Create and persist a Parameter.""" - param = Parameter( - name=name, - label=label, - tax_benefit_model_version_id=model_version.id, - ) - session.add(param) - session.commit() - session.refresh(param) - return param +# ----------------------------------------------------------------------------- +# Happy-path lookups +# ----------------------------------------------------------------------------- -class TestParametersByName: - """Tests for looking up parameters by their exact names.""" - def test_returns_matching_parameters(self, client, session, us_version): - """Given known parameter names, returns their full metadata.""" - create_parameter(session, us_version, "gov.tax.rate", "Tax rate") +class TestParametersByName: + def test_given_known_names_then_returns_matching( + self, client, session, us_version # noqa: F811 + ): + """Returns full metadata for each matching name.""" + create_parameter(session, us_version, "gov.tax.rate", "Rate") create_parameter(session, us_version, "gov.tax.threshold", "Threshold") - response = client.post( + data = client.post( "/parameters/by-name", json={ "names": ["gov.tax.rate", "gov.tax.threshold"], - "tax_benefit_model_name": "policyengine-us", + "tax_benefit_model_name": MODEL_NAMES["US"], }, - ) + ).json() - assert response.status_code == 200 - data = response.json() assert len(data) == 2 - returned_names = {p["name"] for p in data} - assert returned_names == {"gov.tax.rate", "gov.tax.threshold"} + assert {p["name"] for p in data} == {"gov.tax.rate", "gov.tax.threshold"} - def test_returns_empty_list_for_empty_names(self, client, session, us_version): - """Given an empty names list, returns an empty list.""" - response = client.post( + def test_given_empty_names_then_returns_empty_list( + self, client, session, us_version # noqa: F811 + ): + """Empty names list → empty response (no DB query).""" + data = client.post( "/parameters/by-name", - json={ - "names": [], - "tax_benefit_model_name": "policyengine-us", - }, - ) - - assert response.status_code == 200 - assert response.json() == [] + json={"names": [], "tax_benefit_model_name": MODEL_NAMES["US"]}, + ).json() + assert data == [] - def test_returns_empty_list_for_unknown_names(self, client, session, us_version): - """Given names that don't match any parameter, returns an empty list.""" + def test_given_unknown_names_then_returns_empty_list( + self, client, session, us_version # noqa: F811 + ): + """Names that don't match anything → empty list.""" create_parameter(session, us_version, "gov.exists", "Exists") - response = client.post( + data = client.post( "/parameters/by-name", json={ - "names": ["gov.does_not_exist", "gov.also_missing"], - "tax_benefit_model_name": "policyengine-us", + "names": ["gov.nope", "gov.also_missing"], + "tax_benefit_model_name": MODEL_NAMES["US"], }, - ) - - assert response.status_code == 200 - assert response.json() == [] + ).json() + assert data == [] - def test_returns_only_matching_when_mix_of_known_and_unknown( - self, client, session, us_version + def test_given_mixed_names_then_returns_only_known( + self, client, session, us_version # noqa: F811 ): - """Given a mix of known and unknown names, returns only the known ones.""" - create_parameter(session, us_version, "gov.real", "Real param") + """Only matching names are returned; unknowns silently omitted.""" + create_parameter(session, us_version, "gov.real", "Real") - response = client.post( + data = client.post( "/parameters/by-name", json={ "names": ["gov.real", "gov.fake"], - "tax_benefit_model_name": "policyengine-us", + "tax_benefit_model_name": MODEL_NAMES["US"], }, - ) - - assert response.status_code == 200 - data = response.json() + ).json() assert len(data) == 1 assert data[0]["name"] == "gov.real" - def test_filters_by_model_name(self, client, session): - """Parameters from a different model are excluded.""" - # Create two models - model_uk = TaxBenefitModel(name="policyengine-uk", description="UK") - model_us = TaxBenefitModel(name="policyengine-us", description="US") - session.add(model_uk) - session.add(model_us) - session.commit() - session.refresh(model_uk) - session.refresh(model_us) - - ver_uk = TaxBenefitModelVersion( - model_id=model_uk.id, version="1.0", description="UK v1" - ) - ver_us = TaxBenefitModelVersion( - model_id=model_us.id, version="1.0", description="US v1" - ) - session.add(ver_uk) - session.add(ver_us) - session.commit() - session.refresh(ver_uk) - session.refresh(ver_us) - - # Same parameter name in both models - create_parameter(session, ver_uk, "gov.shared_name", "UK version") - create_parameter(session, ver_us, "gov.shared_name", "US version") - - # Request only UK - response = client.post( + def test_given_single_name_then_returns_one( + self, client, session, us_version # noqa: F811 + ): + """Single-element lookup works.""" + create_parameter(session, us_version, "gov.single", "Single") + data = client.post( "/parameters/by-name", json={ - "names": ["gov.shared_name"], - "tax_benefit_model_name": "policyengine-uk", + "names": ["gov.single"], + "tax_benefit_model_name": MODEL_NAMES["US"], }, - ) - - assert response.status_code == 200 - data = response.json() + ).json() assert len(data) == 1 - assert data[0]["label"] == "UK version" - def test_response_shape_matches_parameter_read(self, client, session, us_version): - """Returned objects have the same shape as ParameterRead.""" - create_parameter(session, us_version, "gov.shape_test", "Shape test") + def test_results_ordered_by_name( + self, client, session, us_version # noqa: F811 + ): + """Response is sorted alphabetically by name.""" + create_parameter(session, us_version, "gov.zzz", "Z") + create_parameter(session, us_version, "gov.aaa", "A") + create_parameter(session, us_version, "gov.mmm", "M") + + names = [ + p["name"] + for p in client.post( + "/parameters/by-name", + json={ + "names": ["gov.zzz", "gov.aaa", "gov.mmm"], + "tax_benefit_model_name": MODEL_NAMES["US"], + }, + ).json() + ] + assert names == ["gov.aaa", "gov.mmm", "gov.zzz"] - response = client.post( + def test_response_shape(self, client, session, us_version): # noqa: F811 + """Each returned object has the ParameterRead fields.""" + create_parameter(session, us_version, "gov.shape", "Shape") + param = client.post( "/parameters/by-name", json={ - "names": ["gov.shape_test"], - "tax_benefit_model_name": "policyengine-us", + "names": ["gov.shape"], + "tax_benefit_model_name": MODEL_NAMES["US"], }, - ) + ).json()[0] + for field in ("id", "name", "label", "created_at", "tax_benefit_model_version_id"): + assert field in param - assert response.status_code == 200 - data = response.json() - assert len(data) == 1 - param = data[0] - assert "id" in param - assert "name" in param - assert "label" in param - assert "created_at" in param - assert "tax_benefit_model_version_id" in param - - def test_results_ordered_by_name(self, client, session, us_version): - """Returned parameters are sorted alphabetically by name.""" - create_parameter(session, us_version, "gov.zzz", "Last") - create_parameter(session, us_version, "gov.aaa", "First") - create_parameter(session, us_version, "gov.mmm", "Middle") - response = client.post( +# ----------------------------------------------------------------------------- +# Model isolation +# ----------------------------------------------------------------------------- + + +class TestParametersByNameModelIsolation: + def test_given_two_models_then_returns_only_requested( + self, client, session, us_version, uk_version # noqa: F811 + ): + """Parameters from the other model are excluded.""" + create_parameter(session, us_version, "gov.shared", "US") + create_parameter(session, uk_version, "gov.shared", "UK") + + data = client.post( "/parameters/by-name", json={ - "names": ["gov.zzz", "gov.aaa", "gov.mmm"], - "tax_benefit_model_name": "policyengine-us", + "names": ["gov.shared"], + "tax_benefit_model_name": MODEL_NAMES["UK"], }, - ) + ).json() + assert len(data) == 1 + assert data[0]["label"] == "UK" - assert response.status_code == 200 - names = [p["name"] for p in response.json()] - assert names == ["gov.aaa", "gov.mmm", "gov.zzz"] - def test_missing_model_name_returns_422(self, client): - """Request without tax_benefit_model_name is rejected.""" +# ----------------------------------------------------------------------------- +# Validation +# ----------------------------------------------------------------------------- + + +class TestParametersByNameValidation: + def test_given_missing_model_name_then_422(self, client): + """Omitting tax_benefit_model_name → 422.""" response = client.post( - "/parameters/by-name", - json={"names": ["gov.something"]}, + "/parameters/by-name", json={"names": ["gov.x"]} ) - assert response.status_code == 422 - def test_missing_names_field_returns_422(self, client): - """Request without names field is rejected.""" + def test_given_missing_names_then_422(self, client): + """Omitting names → 422.""" response = client.post( "/parameters/by-name", - json={"tax_benefit_model_name": "policyengine-us"}, + json={"tax_benefit_model_name": MODEL_NAMES["US"]}, ) - assert response.status_code == 422 - def test_single_name_lookup(self, client, session, us_version): - """Looking up a single parameter name works.""" - create_parameter(session, us_version, "gov.single", "Single param") - + def test_given_nonexistent_model_name_then_404( + self, client, session + ): + """Model that doesn't exist → 404 from resolve_model_version_id.""" response = client.post( "/parameters/by-name", json={ - "names": ["gov.single"], - "tax_benefit_model_name": "policyengine-us", + "names": ["gov.x"], + "tax_benefit_model_name": "nonexistent-model", }, ) + assert response.status_code == 404 - assert response.status_code == 200 - data = response.json() - assert len(data) == 1 - assert data[0]["name"] == "gov.single" +# ----------------------------------------------------------------------------- +# Version filtering +# ----------------------------------------------------------------------------- -class TestParametersByNameVersionFilter: - """Tests for version filtering on the by-name endpoint.""" - - def test_defaults_to_latest_version(self, client, session): - """When only model name is given, returns parameters from latest version.""" - model = TaxBenefitModel(name="policyengine-us", description="US") - session.add(model) - session.commit() - session.refresh(model) - - v1 = TaxBenefitModelVersion( - model_id=model.id, - version="1.0", - created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), - ) - v2 = TaxBenefitModelVersion( - model_id=model.id, - version="2.0", - created_at=datetime(2025, 6, 1, tzinfo=timezone.utc), - ) - session.add(v1) - session.add(v2) - session.commit() - session.refresh(v1) - session.refresh(v2) - create_parameter(session, v1, "gov.old_param", "Old") - create_parameter(session, v2, "gov.new_param", "New") +class TestParametersByNameVersionFilter: + def test_given_model_name_only_then_defaults_to_latest( + self, client, session, us_two_versions # noqa: F811 + ): + """Model name resolves to latest version.""" + v1, v2 = us_two_versions + create_parameter(session, v1, "gov.old", "Old") + create_parameter(session, v2, "gov.new", "New") - response = client.post( + data = client.post( "/parameters/by-name", json={ - "names": ["gov.old_param", "gov.new_param"], - "tax_benefit_model_name": "policyengine-us", + "names": ["gov.old", "gov.new"], + "tax_benefit_model_name": MODEL_NAMES["US"], }, - ) - - assert response.status_code == 200 - data = response.json() + ).json() assert len(data) == 1 - assert data[0]["name"] == "gov.new_param" - - def test_explicit_version_id_returns_that_version(self, client, session): - """When version ID is given, returns parameters from that specific version.""" - model = TaxBenefitModel(name="policyengine-us", description="US") - session.add(model) - session.commit() - session.refresh(model) - - v1 = TaxBenefitModelVersion( - model_id=model.id, - version="1.0", - created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), - ) - v2 = TaxBenefitModelVersion( - model_id=model.id, - version="2.0", - created_at=datetime(2025, 6, 1, tzinfo=timezone.utc), - ) - session.add(v1) - session.add(v2) - session.commit() - session.refresh(v1) - session.refresh(v2) + assert data[0]["name"] == "gov.new" - create_parameter(session, v1, "gov.old_param", "Old") - create_parameter(session, v2, "gov.new_param", "New") + def test_given_explicit_version_id_then_returns_that_version( + self, client, session, us_two_versions # noqa: F811 + ): + """Explicit version_id overrides latest-version default.""" + v1, v2 = us_two_versions + create_parameter(session, v1, "gov.old", "Old") + create_parameter(session, v2, "gov.new", "New") - response = client.post( + data = client.post( "/parameters/by-name", json={ - "names": ["gov.old_param", "gov.new_param"], - "tax_benefit_model_name": "policyengine-us", + "names": ["gov.old", "gov.new"], + "tax_benefit_model_name": MODEL_NAMES["US"], "tax_benefit_model_version_id": str(v1.id), }, - ) - - assert response.status_code == 200 - data = response.json() + ).json() assert len(data) == 1 - assert data[0]["name"] == "gov.old_param" + assert data[0]["name"] == "gov.old" diff --git a/tests/test_parameters_children.py b/tests/test_parameters_children.py index 5fab96d..a3da03d 100644 --- a/tests/test_parameters_children.py +++ b/tests/test_parameters_children.py @@ -1,78 +1,30 @@ """Tests for GET /parameters/children endpoint.""" -from datetime import datetime, timezone - import pytest -from policyengine_api.models import ( - Parameter, - TaxBenefitModel, - TaxBenefitModelVersion, +from test_fixtures.fixtures_version_filter import ( + MODEL_NAMES, + add_params_bulk, + create_parameter, + uk_model, # noqa: F401 + uk_two_versions, # noqa: F401 + uk_version, # noqa: F401 + us_model, # noqa: F401 + us_version, # noqa: F401 ) -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - -@pytest.fixture -def uk_version(session): - """Create a policyengine-uk model and version.""" - model = TaxBenefitModel(name="policyengine-uk", description="UK model") - session.add(model) - session.commit() - session.refresh(model) - - version = TaxBenefitModelVersion( - model_id=model.id, version="1.0", description="UK v1" - ) - session.add(version) - session.commit() - session.refresh(version) - return version - - -@pytest.fixture -def us_version(session): - """Create a policyengine-us model and version.""" - model = TaxBenefitModel(name="policyengine-us", description="US model") - session.add(model) - session.commit() - session.refresh(model) - - version = TaxBenefitModelVersion( - model_id=model.id, version="1.0", description="US v1" - ) - session.add(version) - session.commit() - session.refresh(version) - return version - - -def _add_params(session, version, names_and_labels): - """Bulk-add parameters. names_and_labels is [(name, label), ...].""" - for name, label in names_and_labels: - session.add( - Parameter( - name=name, - label=label, - tax_benefit_model_version_id=version.id, - ) - ) - session.commit() - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- +# Tree structure +# ----------------------------------------------------------------------------- class TestParameterChildrenBasic: - """Basic tree structure tests.""" - - def test_returns_nodes_for_intermediate_paths(self, client, session, uk_version): + def test_returns_nodes_for_intermediate_paths( + self, client, session, uk_version # noqa: F811 + ): """Parameters at gov.hmrc.x and gov.dwp.x produce nodes for hmrc and dwp.""" - _add_params( + add_params_bulk( session, uk_version, [ @@ -85,7 +37,7 @@ def test_returns_nodes_for_intermediate_paths(self, client, session, uk_version) response = client.get( "/parameters/children", params={ - "tax_benefit_model_name": "policyengine-uk", + "tax_benefit_model_name": MODEL_NAMES["UK"], "parent_path": "gov", }, ) @@ -101,9 +53,11 @@ def test_returns_nodes_for_intermediate_paths(self, client, session, uk_version) assert child["type"] == "node" assert child["child_count"] > 0 - def test_returns_leaf_parameters(self, client, session, uk_version): + def test_returns_leaf_parameters( + self, client, session, uk_version # noqa: F811 + ): """Direct child parameters are returned with type='parameter'.""" - _add_params( + add_params_bulk( session, uk_version, [ @@ -115,7 +69,7 @@ def test_returns_leaf_parameters(self, client, session, uk_version): response = client.get( "/parameters/children", params={ - "tax_benefit_model_name": "policyengine-uk", + "tax_benefit_model_name": MODEL_NAMES["UK"], "parent_path": "gov", }, ) @@ -133,9 +87,11 @@ def test_returns_leaf_parameters(self, client, session, uk_version): node = next(c for c in children if c["type"] == "node") assert node["path"] == "gov.hmrc" - def test_mixed_nodes_and_leaves(self, client, session, uk_version): + def test_mixed_nodes_and_leaves( + self, client, session, uk_version # noqa: F811 + ): """Both nodes and leaf parameters can appear at the same level.""" - _add_params( + add_params_bulk( session, uk_version, [ @@ -145,27 +101,31 @@ def test_mixed_nodes_and_leaves(self, client, session, uk_version): ], ) - response = client.get( + children = client.get( "/parameters/children", params={ - "tax_benefit_model_name": "policyengine-uk", + "tax_benefit_model_name": MODEL_NAMES["UK"], "parent_path": "gov", }, - ) + ).json()["children"] - children = response.json()["children"] types = {c["path"]: c["type"] for c in children} assert types["gov.hmrc"] == "node" assert types["gov.flat_rate"] == "parameter" assert types["gov.threshold"] == "parameter" -class TestChildCount: - """Tests for child_count accuracy.""" +# ----------------------------------------------------------------------------- +# Child counts +# ----------------------------------------------------------------------------- + - def test_child_count_reflects_total_descendants(self, client, session, uk_version): +class TestChildCount: + def test_child_count_reflects_total_descendants( + self, client, session, uk_version # noqa: F811 + ): """child_count counts all leaf parameters under the node.""" - _add_params( + add_params_bulk( session, uk_version, [ @@ -175,22 +135,23 @@ def test_child_count_reflects_total_descendants(self, client, session, uk_versio ], ) - response = client.get( + children = client.get( "/parameters/children", params={ - "tax_benefit_model_name": "policyengine-uk", + "tax_benefit_model_name": MODEL_NAMES["UK"], "parent_path": "gov", }, - ) + ).json()["children"] - children = response.json()["children"] hmrc = children[0] assert hmrc["path"] == "gov.hmrc" assert hmrc["child_count"] == 3 - def test_nested_child_count(self, client, session, uk_version): + def test_nested_child_count( + self, client, session, uk_version # noqa: F811 + ): """Querying a deeper level gives accurate child counts.""" - _add_params( + add_params_bulk( session, uk_version, [ @@ -200,49 +161,54 @@ def test_nested_child_count(self, client, session, uk_version): ], ) - response = client.get( + children = client.get( "/parameters/children", params={ - "tax_benefit_model_name": "policyengine-uk", + "tax_benefit_model_name": MODEL_NAMES["UK"], "parent_path": "gov.hmrc", }, - ) + ).json()["children"] - children = response.json()["children"] assert len(children) == 2 income_tax = next(c for c in children if c["path"] == "gov.hmrc.income_tax") ni = next(c for c in children if c["path"] == "gov.hmrc.ni") assert income_tax["child_count"] == 2 assert ni["child_count"] == 1 - def test_leaf_has_no_child_count(self, client, session, uk_version): + def test_leaf_has_no_child_count( + self, client, session, uk_version # noqa: F811 + ): """Leaf parameters have child_count=None.""" - _add_params(session, uk_version, [("gov.rate", "Rate")]) + add_params_bulk(session, uk_version, [("gov.rate", "Rate")]) - response = client.get( + children = client.get( "/parameters/children", params={ - "tax_benefit_model_name": "policyengine-uk", + "tax_benefit_model_name": MODEL_NAMES["UK"], "parent_path": "gov", }, - ) + ).json()["children"] - children = response.json()["children"] assert len(children) == 1 assert children[0]["child_count"] is None -class TestModelNameFiltering: - """Tests for tax_benefit_model_name filtering.""" +# ----------------------------------------------------------------------------- +# Model isolation +# ----------------------------------------------------------------------------- - def test_uk_model(self, client, session, uk_version): + +class TestParameterChildrenModelIsolation: + def test_given_uk_model_then_returns_uk_params( + self, client, session, uk_version # noqa: F811 + ): """policyengine-uk returns UK parameters.""" - _add_params(session, uk_version, [("gov.hmrc.rate", "Rate")]) + add_params_bulk(session, uk_version, [("gov.hmrc.rate", "Rate")]) response = client.get( "/parameters/children", params={ - "tax_benefit_model_name": "policyengine-uk", + "tax_benefit_model_name": MODEL_NAMES["UK"], "parent_path": "gov", }, ) @@ -250,14 +216,16 @@ def test_uk_model(self, client, session, uk_version): assert response.status_code == 200 assert len(response.json()["children"]) == 1 - def test_us_model(self, client, session, us_version): + def test_given_us_model_then_returns_us_params( + self, client, session, us_version # noqa: F811 + ): """policyengine-us returns US parameters.""" - _add_params(session, us_version, [("gov.irs.rate", "Rate")]) + add_params_bulk(session, us_version, [("gov.irs.rate", "Rate")]) response = client.get( "/parameters/children", params={ - "tax_benefit_model_name": "policyengine-us", + "tax_benefit_model_name": MODEL_NAMES["US"], "parent_path": "gov", }, ) @@ -265,51 +233,68 @@ def test_us_model(self, client, session, us_version): assert response.status_code == 200 assert len(response.json()["children"]) == 1 - def test_model_isolation(self, client, session, uk_version, us_version): + def test_given_two_models_then_returns_only_requested( + self, client, session, uk_version, us_version # noqa: F811 + ): """Parameters from a different model are excluded.""" - _add_params(session, uk_version, [("gov.hmrc.rate", "UK rate")]) - _add_params(session, us_version, [("gov.irs.rate", "US rate")]) + add_params_bulk(session, uk_version, [("gov.hmrc.rate", "UK rate")]) + add_params_bulk(session, us_version, [("gov.irs.rate", "US rate")]) - uk_response = client.get( - "/parameters/children", - params={ - "tax_benefit_model_name": "policyengine-uk", - "parent_path": "gov", - }, - ) - us_response = client.get( - "/parameters/children", - params={ - "tax_benefit_model_name": "policyengine-us", - "parent_path": "gov", - }, - ) + uk_paths = [ + c["path"] + for c in client.get( + "/parameters/children", + params={ + "tax_benefit_model_name": MODEL_NAMES["UK"], + "parent_path": "gov", + }, + ).json()["children"] + ] + us_paths = [ + c["path"] + for c in client.get( + "/parameters/children", + params={ + "tax_benefit_model_name": MODEL_NAMES["US"], + "parent_path": "gov", + }, + ).json()["children"] + ] - uk_paths = [c["path"] for c in uk_response.json()["children"]] - us_paths = [c["path"] for c in us_response.json()["children"]] assert uk_paths == ["gov.hmrc"] assert us_paths == ["gov.irs"] - def test_missing_model_name_returns_422(self, client): + +# ----------------------------------------------------------------------------- +# Validation +# ----------------------------------------------------------------------------- + + +class TestParameterChildrenValidation: + def test_given_missing_model_name_then_422(self, client): """Request without tax_benefit_model_name returns 422.""" response = client.get( "/parameters/children", params={"parent_path": "gov"} ) - assert response.status_code == 422 -class TestEdgeCases: - """Tests for edge cases and special inputs.""" +# ----------------------------------------------------------------------------- +# Edge cases +# ----------------------------------------------------------------------------- + - def test_empty_parent_path(self, client, session, uk_version): +class TestParameterChildrenEdgeCases: + def test_empty_parent_path( + self, client, session, uk_version # noqa: F811 + ): """Empty parent_path returns top-level children.""" - _add_params(session, uk_version, [("gov.hmrc.rate", "Rate")]) + add_params_bulk(session, uk_version, [("gov.hmrc.rate", "Rate")]) response = client.get( "/parameters/children", params={ - "tax_benefit_model_name": "policyengine-uk", + "tax_benefit_model_name": MODEL_NAMES["UK"], "parent_path": "", }, ) @@ -320,24 +305,27 @@ def test_empty_parent_path(self, client, session, uk_version): assert children[0]["path"] == "gov" assert children[0]["type"] == "node" - def test_nonexistent_parent_returns_empty(self, client, session, uk_version): + def test_nonexistent_parent_returns_empty( + self, client, session, uk_version # noqa: F811 + ): """A parent path with no descendants returns empty children list.""" - _add_params(session, uk_version, [("gov.hmrc.rate", "Rate")]) + add_params_bulk(session, uk_version, [("gov.hmrc.rate", "Rate")]) - response = client.get( + children = client.get( "/parameters/children", params={ - "tax_benefit_model_name": "policyengine-uk", + "tax_benefit_model_name": MODEL_NAMES["UK"], "parent_path": "gov.dwp", }, - ) + ).json()["children"] - assert response.status_code == 200 - assert response.json()["children"] == [] + assert children == [] - def test_children_sorted_by_path(self, client, session, uk_version): + def test_children_sorted_by_path( + self, client, session, uk_version # noqa: F811 + ): """Children are returned sorted alphabetically by path.""" - _add_params( + add_params_bulk( session, uk_version, [ @@ -347,199 +335,167 @@ def test_children_sorted_by_path(self, client, session, uk_version): ], ) - response = client.get( - "/parameters/children", - params={ - "tax_benefit_model_name": "policyengine-uk", - "parent_path": "gov", - }, - ) - - paths = [c["path"] for c in response.json()["children"]] + paths = [ + c["path"] + for c in client.get( + "/parameters/children", + params={ + "tax_benefit_model_name": MODEL_NAMES["UK"], + "parent_path": "gov", + }, + ).json()["children"] + ] assert paths == ["gov.aaa", "gov.mmm", "gov.zzz"] - def test_node_label_from_path_segment(self, client, session, uk_version): - """Node labels default to the last path segment when no parameter exists.""" - _add_params(session, uk_version, [("gov.hmrc.income_tax.rate", "Rate")]) + def test_node_label_from_path_segment( + self, client, session, uk_version # noqa: F811 + ): + """Node labels default to the last path segment.""" + add_params_bulk( + session, uk_version, [("gov.hmrc.income_tax.rate", "Rate")] + ) - response = client.get( + children = client.get( "/parameters/children", params={ - "tax_benefit_model_name": "policyengine-uk", + "tax_benefit_model_name": MODEL_NAMES["UK"], "parent_path": "gov", }, - ) + ).json()["children"] - children = response.json()["children"] assert children[0]["label"] == "hmrc" - def test_default_parent_path_is_empty(self, client, session, uk_version): + def test_default_parent_path_is_empty( + self, client, session, uk_version # noqa: F811 + ): """Omitting parent_path defaults to empty string (root level).""" - _add_params(session, uk_version, [("gov.hmrc.rate", "Rate")]) + add_params_bulk(session, uk_version, [("gov.hmrc.rate", "Rate")]) - response = client.get( + data = client.get( "/parameters/children", - params={"tax_benefit_model_name": "policyengine-uk"}, - ) + params={"tax_benefit_model_name": MODEL_NAMES["UK"]}, + ).json() - assert response.status_code == 200 - assert response.json()["parent_path"] == "" - assert len(response.json()["children"]) == 1 + assert data["parent_path"] == "" + assert len(data["children"]) == 1 - def test_leaf_parameter_includes_full_metadata(self, client, session, uk_version): + def test_leaf_parameter_includes_full_metadata( + self, client, session, uk_version # noqa: F811 + ): """Leaf parameters include the full ParameterRead shape.""" - _add_params(session, uk_version, [("gov.rate", "The rate")]) + add_params_bulk(session, uk_version, [("gov.rate", "The rate")]) - response = client.get( + param = client.get( "/parameters/children", params={ - "tax_benefit_model_name": "policyengine-uk", + "tax_benefit_model_name": MODEL_NAMES["UK"], "parent_path": "gov", }, - ) + ).json()["children"][0]["parameter"] - param = response.json()["children"][0]["parameter"] assert param["name"] == "gov.rate" assert param["label"] == "The rate" - assert "id" in param - assert "created_at" in param - assert "tax_benefit_model_version_id" in param + for field in ("id", "created_at", "tax_benefit_model_version_id"): + assert field in param - def test_node_has_no_parameter_field(self, client, session, uk_version): + def test_node_has_no_parameter_field( + self, client, session, uk_version # noqa: F811 + ): """Nodes do not include the parameter field.""" - _add_params(session, uk_version, [("gov.hmrc.rate", "Rate")]) + add_params_bulk(session, uk_version, [("gov.hmrc.rate", "Rate")]) - response = client.get( + node = client.get( "/parameters/children", params={ - "tax_benefit_model_name": "policyengine-uk", + "tax_benefit_model_name": MODEL_NAMES["UK"], "parent_path": "gov", }, - ) + ).json()["children"][0] - node = response.json()["children"][0] assert node["type"] == "node" assert node["parameter"] is None - def test_deep_nesting(self, client, session, uk_version): + def test_deep_nesting( + self, client, session, uk_version # noqa: F811 + ): """Works correctly with deeply nested parameter paths.""" - _add_params( + add_params_bulk( session, uk_version, [("gov.hmrc.income_tax.rates.uk[0].rate", "Basic rate")], ) - # Each level should show the correct child for parent, expected_child in [ ("gov", "gov.hmrc"), ("gov.hmrc", "gov.hmrc.income_tax"), ("gov.hmrc.income_tax", "gov.hmrc.income_tax.rates"), ("gov.hmrc.income_tax.rates", "gov.hmrc.income_tax.rates.uk[0]"), ]: - resp = client.get( + children = client.get( "/parameters/children", params={ - "tax_benefit_model_name": "policyengine-uk", + "tax_benefit_model_name": MODEL_NAMES["UK"], "parent_path": parent, }, - ) - children = resp.json()["children"] + ).json()["children"] assert len(children) == 1 assert children[0]["path"] == expected_child assert children[0]["type"] == "node" # Final level should be a leaf - resp = client.get( + children = client.get( "/parameters/children", params={ - "tax_benefit_model_name": "policyengine-uk", + "tax_benefit_model_name": MODEL_NAMES["UK"], "parent_path": "gov.hmrc.income_tax.rates.uk[0]", }, - ) - children = resp.json()["children"] + ).json()["children"] assert len(children) == 1 assert children[0]["type"] == "parameter" assert children[0]["path"] == "gov.hmrc.income_tax.rates.uk[0].rate" -class TestVersionFiltering: - """Tests for version filtering on the children endpoint.""" +# ----------------------------------------------------------------------------- +# Version filtering +# ----------------------------------------------------------------------------- - def test_defaults_to_latest_version(self, client, session): - """When only model name is given, returns children from latest version.""" - model = TaxBenefitModel(name="policyengine-uk", description="UK") - session.add(model) - session.commit() - session.refresh(model) - - v1 = TaxBenefitModelVersion( - model_id=model.id, - version="1.0", - created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), - ) - v2 = TaxBenefitModelVersion( - model_id=model.id, - version="2.0", - created_at=datetime(2025, 6, 1, tzinfo=timezone.utc), - ) - session.add(v1) - session.add(v2) - session.commit() - session.refresh(v1) - session.refresh(v2) - _add_params(session, v1, [("gov.old_param", "Old")]) - _add_params(session, v2, [("gov.new_param", "New")]) +class TestParameterChildrenVersionFilter: + def test_given_model_name_only_then_defaults_to_latest( + self, client, session, uk_two_versions # noqa: F811 + ): + """When only model name is given, returns children from latest version.""" + v1, v2 = uk_two_versions + add_params_bulk(session, v1, [("gov.old_param", "Old")]) + add_params_bulk(session, v2, [("gov.new_param", "New")]) - response = client.get( + children = client.get( "/parameters/children", params={ - "tax_benefit_model_name": "policyengine-uk", + "tax_benefit_model_name": MODEL_NAMES["UK"], "parent_path": "gov", }, - ) + ).json()["children"] - assert response.status_code == 200 - children = response.json()["children"] assert len(children) == 1 assert children[0]["path"] == "gov.new_param" - def test_explicit_version_id_returns_that_version(self, client, session): + def test_given_explicit_version_id_then_returns_that_version( + self, client, session, uk_two_versions # noqa: F811 + ): """When version ID is given, returns children from that specific version.""" - model = TaxBenefitModel(name="policyengine-uk", description="UK") - session.add(model) - session.commit() - session.refresh(model) - - v1 = TaxBenefitModelVersion( - model_id=model.id, - version="1.0", - created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), - ) - v2 = TaxBenefitModelVersion( - model_id=model.id, - version="2.0", - created_at=datetime(2025, 6, 1, tzinfo=timezone.utc), - ) - session.add(v1) - session.add(v2) - session.commit() - session.refresh(v1) - session.refresh(v2) + v1, v2 = uk_two_versions + add_params_bulk(session, v1, [("gov.old_param", "Old")]) + add_params_bulk(session, v2, [("gov.new_param", "New")]) - _add_params(session, v1, [("gov.old_param", "Old")]) - _add_params(session, v2, [("gov.new_param", "New")]) - - response = client.get( + children = client.get( "/parameters/children", params={ - "tax_benefit_model_name": "policyengine-uk", + "tax_benefit_model_name": MODEL_NAMES["UK"], "parent_path": "gov", "tax_benefit_model_version_id": str(v1.id), }, - ) + ).json()["children"] - assert response.status_code == 200 - children = response.json()["children"] assert len(children) == 1 assert children[0]["path"] == "gov.old_param" diff --git a/tests/test_variables.py b/tests/test_variables.py index d64e6aa..c70d8bc 100644 --- a/tests/test_variables.py +++ b/tests/test_variables.py @@ -1,152 +1,200 @@ -"""Tests for variable endpoints.""" +"""Tests for GET /variables/ and GET /variables/{id} endpoints.""" -from datetime import datetime, timezone from uuid import uuid4 import pytest -from policyengine_api.models import TaxBenefitModel, TaxBenefitModelVersion, Variable +from test_fixtures.fixtures_version_filter import ( + MODEL_NAMES, + create_variable, + us_model, # noqa: F401 + us_two_versions, # noqa: F401 + us_version, # noqa: F401 +) -def test_list_variables(client): - """List variables returns a list.""" - response = client.get("/variables") - assert response.status_code == 200 - assert isinstance(response.json(), list) +# ----------------------------------------------------------------------------- +# GET /variables/ — basic +# ----------------------------------------------------------------------------- + + +class TestListVariables: + def test_given_no_variables_then_returns_empty_list(self, client): + """Empty database returns an empty list.""" + response = client.get("/variables") + assert response.status_code == 200 + assert response.json() == [] + + def test_given_variables_exist_then_returns_list( + self, client, session, us_version # noqa: F811 + ): + """Returns variables that exist.""" + create_variable(session, us_version, "employment_income") + response = client.get("/variables") + assert response.status_code == 200 + assert len(response.json()) == 1 + + def test_given_search_by_name_then_returns_matching( + self, client, session, us_version # noqa: F811 + ): + """Search filter matches variable name.""" + create_variable(session, us_version, "employment_income") + create_variable(session, us_version, "income_tax") + + data = client.get("/variables?search=employment").json() + assert len(data) == 1 + assert data[0]["name"] == "employment_income" + + def test_given_search_by_description_then_returns_matching( + self, client, session, us_version # noqa: F811 + ): + """Search filter matches variable description.""" + create_variable( + session, us_version, "var_x", description="Total household income" + ) + create_variable( + session, us_version, "var_y", description="Tax liability" + ) + + data = client.get("/variables?search=household").json() + assert len(data) == 1 + assert data[0]["name"] == "var_x" + + def test_given_limit_then_returns_at_most_n( + self, client, session, us_version # noqa: F811 + ): + """Limit caps the number of results.""" + for i in range(5): + create_variable(session, us_version, f"var_{i}") + + assert len(client.get("/variables?limit=2").json()) == 2 + + def test_given_skip_then_skips_first_n( + self, client, session, us_version # noqa: F811 + ): + """Skip omits the first N results.""" + for i in range(5): + create_variable(session, us_version, f"var_{i}") + + assert len(client.get("/variables?skip=3&limit=10").json()) == 2 + + def test_results_ordered_by_name( + self, client, session, us_version # noqa: F811 + ): + """Variables come back sorted alphabetically by name.""" + create_variable(session, us_version, "zzz_var") + create_variable(session, us_version, "aaa_var") + names = [v["name"] for v in client.get("/variables").json()] + assert names == ["aaa_var", "zzz_var"] + + +# ----------------------------------------------------------------------------- +# GET /variables/ — version filtering +# ----------------------------------------------------------------------------- -def test_get_variable_not_found(client): - """Get non-existent variable returns 404.""" - fake_id = uuid4() - response = client.get(f"/variables/{fake_id}") - assert response.status_code == 404 +class TestListVariablesVersionFilter: + def test_given_model_name_then_returns_only_latest_version( + self, client, session, us_two_versions # noqa: F811 + ): + """Model name resolves to latest version; old-version vars excluded.""" + v1, v2 = us_two_versions + create_variable(session, v1, "old_variable") + create_variable(session, v2, "new_variable") + + data = client.get( + f"/variables?tax_benefit_model_name={MODEL_NAMES['US']}" + ).json() + assert len(data) == 1 + assert data[0]["name"] == "new_variable" + + def test_given_explicit_version_id_then_returns_that_version( + self, client, session, us_two_versions # noqa: F811 + ): + """Explicit version_id pins to a specific version.""" + v1, v2 = us_two_versions + create_variable(session, v1, "old_variable") + create_variable(session, v2, "new_variable") + + data = client.get( + f"/variables?tax_benefit_model_version_id={v1.id}" + ).json() + assert len(data) == 1 + assert data[0]["name"] == "old_variable" + + def test_given_both_then_version_id_takes_precedence( + self, client, session, us_two_versions # noqa: F811 + ): + """version_id overrides model_name.""" + v1, v2 = us_two_versions + create_variable(session, v1, "old_variable") + create_variable(session, v2, "new_variable") + + data = client.get( + f"/variables?tax_benefit_model_name={MODEL_NAMES['US']}" + f"&tax_benefit_model_version_id={v1.id}" + ).json() + assert len(data) == 1 + assert data[0]["name"] == "old_variable" + + def test_given_no_filters_then_returns_all_versions( + self, client, session, us_two_versions # noqa: F811 + ): + """Without model/version filter, vars from all versions are returned.""" + v1, v2 = us_two_versions + create_variable(session, v1, "old_variable") + create_variable(session, v2, "new_variable") + + data = client.get("/variables").json() + assert len(data) == 2 + + def test_given_nonexistent_model_name_then_returns_404(self, client): + """Unknown model name returns 404.""" + response = client.get( + "/variables?tax_benefit_model_name=nonexistent-model" + ) + assert response.status_code == 404 + + def test_given_search_combined_with_version_filter( + self, client, session, us_two_versions # noqa: F811 + ): + """Search + version filter work together.""" + v1, v2 = us_two_versions + create_variable(session, v2, "employment_income") + create_variable(session, v2, "income_tax") + + data = client.get( + f"/variables?tax_benefit_model_name={MODEL_NAMES['US']}&search=employment" + ).json() + assert len(data) == 1 + assert data[0]["name"] == "employment_income" # ----------------------------------------------------------------------------- -# Version Filtering Tests +# GET /variables/{id} # ----------------------------------------------------------------------------- -def _create_var(session, version, name): - """Create and persist a Variable.""" - var = Variable( - name=name, - entity="person", - tax_benefit_model_version_id=version.id, - ) - session.add(var) - session.commit() - session.refresh(var) - return var - - -def test__given_model_name__then_returns_only_latest_version_variables( - client, session -): - """GET /variables?tax_benefit_model_name=X returns only latest version's vars.""" - model = TaxBenefitModel(name="policyengine-us", description="US") - session.add(model) - session.commit() - session.refresh(model) - - v1 = TaxBenefitModelVersion( - model_id=model.id, - version="1.0", - created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), - ) - v2 = TaxBenefitModelVersion( - model_id=model.id, - version="2.0", - created_at=datetime(2025, 6, 1, tzinfo=timezone.utc), - ) - session.add(v1) - session.add(v2) - session.commit() - session.refresh(v1) - session.refresh(v2) - - _create_var(session, v1, "old_variable") - _create_var(session, v2, "new_variable") - - response = client.get("/variables?tax_benefit_model_name=policyengine-us") - - assert response.status_code == 200 - data = response.json() - assert len(data) == 1 - assert data[0]["name"] == "new_variable" - - -def test__given_explicit_version_id__then_returns_that_versions_variables( - client, session -): - """GET /variables?tax_benefit_model_version_id=X returns that version's vars.""" - model = TaxBenefitModel(name="policyengine-us", description="US") - session.add(model) - session.commit() - session.refresh(model) - - v1 = TaxBenefitModelVersion( - model_id=model.id, - version="1.0", - created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), - ) - v2 = TaxBenefitModelVersion( - model_id=model.id, - version="2.0", - created_at=datetime(2025, 6, 1, tzinfo=timezone.utc), - ) - session.add(v1) - session.add(v2) - session.commit() - session.refresh(v1) - session.refresh(v2) - - _create_var(session, v1, "old_variable") - _create_var(session, v2, "new_variable") - - response = client.get(f"/variables?tax_benefit_model_version_id={v1.id}") - - assert response.status_code == 200 - data = response.json() - assert len(data) == 1 - assert data[0]["name"] == "old_variable" - - -def test__given_version_id_overrides_model_name(client, session): - """Version ID takes precedence over model name when both are provided.""" - model = TaxBenefitModel(name="policyengine-us", description="US") - session.add(model) - session.commit() - session.refresh(model) - - v1 = TaxBenefitModelVersion( - model_id=model.id, - version="1.0", - created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), - ) - v2 = TaxBenefitModelVersion( - model_id=model.id, - version="2.0", - created_at=datetime(2025, 6, 1, tzinfo=timezone.utc), - ) - session.add(v1) - session.add(v2) - session.commit() - session.refresh(v1) - session.refresh(v2) - - _create_var(session, v1, "old_variable") - _create_var(session, v2, "new_variable") - - response = client.get( - f"/variables?tax_benefit_model_name=policyengine-us&tax_benefit_model_version_id={v1.id}" - ) - - assert response.status_code == 200 - data = response.json() - assert len(data) == 1 - assert data[0]["name"] == "old_variable" - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) +class TestGetVariable: + def test_given_valid_id_then_returns_variable( + self, client, session, us_version # noqa: F811 + ): + """Returns the full variable data.""" + var = create_variable(session, us_version, "employment_income") + response = client.get(f"/variables/{var.id}") + assert response.status_code == 200 + assert response.json()["name"] == "employment_income" + + def test_given_nonexistent_id_then_returns_404(self, client): + """Unknown UUID returns 404.""" + response = client.get(f"/variables/{uuid4()}") + assert response.status_code == 404 + + def test_response_shape_matches_variable_read( + self, client, session, us_version # noqa: F811 + ): + """Response contains all VariableRead fields.""" + var = create_variable(session, us_version, "employment_income") + data = client.get(f"/variables/{var.id}").json() + for field in ("id", "name", "entity", "created_at", "tax_benefit_model_version_id"): + assert field in data diff --git a/tests/test_variables_by_name.py b/tests/test_variables_by_name.py index 7727d22..1aad4d2 100644 --- a/tests/test_variables_by_name.py +++ b/tests/test_variables_by_name.py @@ -1,328 +1,239 @@ """Tests for POST /variables/by-name endpoint.""" -from datetime import datetime, timezone - import pytest -from policyengine_api.models import ( - TaxBenefitModel, - TaxBenefitModelVersion, - Variable, +from test_fixtures.fixtures_version_filter import ( + MODEL_NAMES, + create_variable, + uk_model, # noqa: F401 + uk_two_versions, # noqa: F401 + uk_version, # noqa: F401 + us_model, # noqa: F401 + us_version, # noqa: F401 ) -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - -@pytest.fixture -def uk_version(session): - """Create a policyengine-uk model and version.""" - model = TaxBenefitModel(name="policyengine-uk", description="UK model") - session.add(model) - session.commit() - session.refresh(model) - - version = TaxBenefitModelVersion( - model_id=model.id, version="1.0", description="UK v1" - ) - session.add(version) - session.commit() - session.refresh(version) - return version - - -@pytest.fixture -def us_version(session): - """Create a policyengine-us model and version.""" - model = TaxBenefitModel(name="policyengine-us", description="US model") - session.add(model) - session.commit() - session.refresh(model) - - version = TaxBenefitModelVersion( - model_id=model.id, version="1.0", description="US v1" - ) - session.add(version) - session.commit() - session.refresh(version) - return version - - -def _add_var(session, version, name, entity="person", description=None): - """Create and persist a Variable.""" - var = Variable( - name=name, - entity=entity, - description=description, - tax_benefit_model_version_id=version.id, - ) - session.add(var) - session.commit() - session.refresh(var) - return var - - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- +# Happy-path lookups +# ----------------------------------------------------------------------------- -class TestVariablesByName: - """Tests for looking up variables by their exact names.""" - def test_returns_matching_variables(self, client, session, uk_version): - """Given known variable names, returns their full metadata.""" - _add_var(session, uk_version, "employment_income") - _add_var(session, uk_version, "income_tax") +class TestVariablesByName: + def test_given_known_names_then_returns_matching( + self, client, session, uk_version # noqa: F811 + ): + """Returns full metadata for each matching name.""" + create_variable(session, uk_version, "employment_income") + create_variable(session, uk_version, "income_tax") - response = client.post( + data = client.post( "/variables/by-name", json={ "names": ["employment_income", "income_tax"], - "tax_benefit_model_name": "policyengine-uk", + "tax_benefit_model_name": MODEL_NAMES["UK"], }, - ) + ).json() - assert response.status_code == 200 - data = response.json() assert len(data) == 2 - returned_names = {v["name"] for v in data} - assert returned_names == {"employment_income", "income_tax"} + assert {v["name"] for v in data} == {"employment_income", "income_tax"} - def test_returns_empty_list_for_empty_names(self, client, session, uk_version): - """Given an empty names list, returns an empty list.""" - response = client.post( + def test_given_empty_names_then_returns_empty_list( + self, client, session, uk_version # noqa: F811 + ): + """Empty names list returns empty response (no DB query).""" + data = client.post( "/variables/by-name", - json={ - "names": [], - "tax_benefit_model_name": "policyengine-uk", - }, - ) - - assert response.status_code == 200 - assert response.json() == [] + json={"names": [], "tax_benefit_model_name": MODEL_NAMES["UK"]}, + ).json() + assert data == [] - def test_returns_empty_list_for_unknown_names(self, client, session, uk_version): - """Given names that don't match any variable, returns an empty list.""" - _add_var(session, uk_version, "employment_income") + def test_given_unknown_names_then_returns_empty_list( + self, client, session, uk_version # noqa: F811 + ): + """Names that don't match anything return empty list.""" + create_variable(session, uk_version, "employment_income") - response = client.post( + data = client.post( "/variables/by-name", json={ "names": ["nonexistent_var", "also_missing"], - "tax_benefit_model_name": "policyengine-uk", + "tax_benefit_model_name": MODEL_NAMES["UK"], }, - ) - - assert response.status_code == 200 - assert response.json() == [] + ).json() + assert data == [] - def test_returns_only_matching_when_mix_of_known_and_unknown( - self, client, session, uk_version + def test_given_mixed_names_then_returns_only_known( + self, client, session, uk_version # noqa: F811 ): - """Given a mix of known and unknown names, returns only the known ones.""" - _add_var(session, uk_version, "income_tax") + """Only matching names are returned; unknowns silently omitted.""" + create_variable(session, uk_version, "income_tax") - response = client.post( + data = client.post( "/variables/by-name", json={ "names": ["income_tax", "fake_var"], - "tax_benefit_model_name": "policyengine-uk", + "tax_benefit_model_name": MODEL_NAMES["UK"], }, - ) - - assert response.status_code == 200 - data = response.json() + ).json() assert len(data) == 1 assert data[0]["name"] == "income_tax" - def test_single_name_lookup(self, client, session, uk_version): - """Looking up a single variable name works.""" - _add_var(session, uk_version, "age") - - response = client.post( + def test_given_single_name_then_returns_one( + self, client, session, uk_version # noqa: F811 + ): + """Single-element lookup works.""" + create_variable(session, uk_version, "age") + data = client.post( "/variables/by-name", json={ "names": ["age"], - "tax_benefit_model_name": "policyengine-uk", + "tax_benefit_model_name": MODEL_NAMES["UK"], }, - ) - - assert response.status_code == 200 - data = response.json() + ).json() assert len(data) == 1 - assert data[0]["name"] == "age" - def test_results_ordered_by_name(self, client, session, uk_version): - """Returned variables are sorted alphabetically by name.""" - _add_var(session, uk_version, "zzz_var") - _add_var(session, uk_version, "aaa_var") - _add_var(session, uk_version, "mmm_var") - - response = client.post( - "/variables/by-name", - json={ - "names": ["zzz_var", "aaa_var", "mmm_var"], - "tax_benefit_model_name": "policyengine-uk", - }, - ) - - assert response.status_code == 200 - names = [v["name"] for v in response.json()] + def test_results_ordered_by_name( + self, client, session, uk_version # noqa: F811 + ): + """Response is sorted alphabetically by name.""" + create_variable(session, uk_version, "zzz_var") + create_variable(session, uk_version, "aaa_var") + create_variable(session, uk_version, "mmm_var") + + names = [ + v["name"] + for v in client.post( + "/variables/by-name", + json={ + "names": ["zzz_var", "aaa_var", "mmm_var"], + "tax_benefit_model_name": MODEL_NAMES["UK"], + }, + ).json() + ] assert names == ["aaa_var", "mmm_var", "zzz_var"] - def test_response_shape_matches_variable_read(self, client, session, uk_version): - """Returned objects have the same shape as VariableRead.""" - _add_var(session, uk_version, "income_tax", entity="person", description="Tax") - - response = client.post( + def test_response_shape(self, client, session, uk_version): # noqa: F811 + """Each returned object has the VariableRead fields.""" + create_variable( + session, uk_version, "income_tax", entity="person", description="Tax" + ) + var = client.post( "/variables/by-name", json={ "names": ["income_tax"], - "tax_benefit_model_name": "policyengine-uk", + "tax_benefit_model_name": MODEL_NAMES["UK"], }, - ) + ).json()[0] + for field in ("id", "name", "entity", "description", "created_at", "tax_benefit_model_version_id"): + assert field in var - assert response.status_code == 200 - var = response.json()[0] - assert "id" in var - assert "name" in var - assert "entity" in var - assert "description" in var - assert "created_at" in var - assert "tax_benefit_model_version_id" in var +# ----------------------------------------------------------------------------- +# Model isolation +# ----------------------------------------------------------------------------- -class TestVariablesByNameModelFiltering: - """Tests for tax_benefit_model_name filtering.""" - def test_model_isolation(self, client, session, uk_version, us_version): - """Variables from a different model are excluded.""" - _add_var(session, uk_version, "council_tax") - _add_var(session, us_version, "state_income_tax") +class TestVariablesByNameModelIsolation: + def test_given_two_models_then_returns_only_requested( + self, client, session, uk_version, us_version # noqa: F811 + ): + """Variables from the other model are excluded.""" + create_variable(session, uk_version, "council_tax") + create_variable(session, us_version, "state_income_tax") - uk_response = client.post( + uk_data = client.post( "/variables/by-name", json={ "names": ["council_tax", "state_income_tax"], - "tax_benefit_model_name": "policyengine-uk", + "tax_benefit_model_name": MODEL_NAMES["UK"], }, - ) - us_response = client.post( + ).json() + us_data = client.post( "/variables/by-name", json={ "names": ["council_tax", "state_income_tax"], - "tax_benefit_model_name": "policyengine-us", + "tax_benefit_model_name": MODEL_NAMES["US"], }, - ) + ).json() + + assert len(uk_data) == 1 + assert uk_data[0]["name"] == "council_tax" + assert len(us_data) == 1 + assert us_data[0]["name"] == "state_income_tax" - assert len(uk_response.json()) == 1 - assert uk_response.json()[0]["name"] == "council_tax" - assert len(us_response.json()) == 1 - assert us_response.json()[0]["name"] == "state_income_tax" + +# ----------------------------------------------------------------------------- +# Validation +# ----------------------------------------------------------------------------- class TestVariablesByNameValidation: - """Tests for request validation.""" + def test_given_missing_model_name_then_422(self, client): + """Omitting tax_benefit_model_name returns 422.""" + response = client.post( + "/variables/by-name", json={"names": ["income_tax"]} + ) + assert response.status_code == 422 - def test_missing_model_name_returns_422(self, client): - """Request without tax_benefit_model_name is rejected.""" + def test_given_missing_names_then_422(self, client): + """Omitting names returns 422.""" response = client.post( "/variables/by-name", - json={"names": ["income_tax"]}, + json={"tax_benefit_model_name": MODEL_NAMES["UK"]}, ) - assert response.status_code == 422 - def test_missing_names_field_returns_422(self, client): - """Request without names field is rejected.""" + def test_given_nonexistent_model_name_then_404(self, client, session): + """Model that doesn't exist returns 404.""" response = client.post( "/variables/by-name", - json={"tax_benefit_model_name": "policyengine-uk"}, + json={ + "names": ["income_tax"], + "tax_benefit_model_name": "nonexistent-model", + }, ) + assert response.status_code == 404 - assert response.status_code == 422 +# ----------------------------------------------------------------------------- +# Version filtering +# ----------------------------------------------------------------------------- -class TestVariablesByNameVersionFilter: - """Tests for version filtering on the by-name endpoint.""" - - def test_defaults_to_latest_version(self, client, session): - """When only model name is given, returns variables from latest version.""" - model = TaxBenefitModel(name="policyengine-uk", description="UK") - session.add(model) - session.commit() - session.refresh(model) - - v1 = TaxBenefitModelVersion( - model_id=model.id, - version="1.0", - created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), - ) - v2 = TaxBenefitModelVersion( - model_id=model.id, - version="2.0", - created_at=datetime(2025, 6, 1, tzinfo=timezone.utc), - ) - session.add(v1) - session.add(v2) - session.commit() - session.refresh(v1) - session.refresh(v2) - _add_var(session, v1, "old_var") - _add_var(session, v2, "new_var") +class TestVariablesByNameVersionFilter: + def test_given_model_name_only_then_defaults_to_latest( + self, client, session, uk_two_versions # noqa: F811 + ): + """Model name resolves to latest version.""" + v1, v2 = uk_two_versions + create_variable(session, v1, "old_var") + create_variable(session, v2, "new_var") - response = client.post( + data = client.post( "/variables/by-name", json={ "names": ["old_var", "new_var"], - "tax_benefit_model_name": "policyengine-uk", + "tax_benefit_model_name": MODEL_NAMES["UK"], }, - ) - - assert response.status_code == 200 - data = response.json() + ).json() assert len(data) == 1 assert data[0]["name"] == "new_var" - def test_explicit_version_id_returns_that_version(self, client, session): - """When version ID is given, returns variables from that specific version.""" - model = TaxBenefitModel(name="policyengine-uk", description="UK") - session.add(model) - session.commit() - session.refresh(model) - - v1 = TaxBenefitModelVersion( - model_id=model.id, - version="1.0", - created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), - ) - v2 = TaxBenefitModelVersion( - model_id=model.id, - version="2.0", - created_at=datetime(2025, 6, 1, tzinfo=timezone.utc), - ) - session.add(v1) - session.add(v2) - session.commit() - session.refresh(v1) - session.refresh(v2) - - _add_var(session, v1, "old_var") - _add_var(session, v2, "new_var") + def test_given_explicit_version_id_then_returns_that_version( + self, client, session, uk_two_versions # noqa: F811 + ): + """Explicit version_id overrides latest-version default.""" + v1, v2 = uk_two_versions + create_variable(session, v1, "old_var") + create_variable(session, v2, "new_var") - response = client.post( + data = client.post( "/variables/by-name", json={ "names": ["old_var", "new_var"], - "tax_benefit_model_name": "policyengine-uk", + "tax_benefit_model_name": MODEL_NAMES["UK"], "tax_benefit_model_version_id": str(v1.id), }, - ) - - assert response.status_code == 200 - data = response.json() + ).json() assert len(data) == 1 assert data[0]["name"] == "old_var" diff --git a/tests/test_version_filter_service.py b/tests/test_version_filter_service.py index 232b905..08c9820 100644 --- a/tests/test_version_filter_service.py +++ b/tests/test_version_filter_service.py @@ -1,55 +1,22 @@ """Tests for the tax benefit model version resolution service.""" -from datetime import datetime, timedelta, timezone from uuid import uuid4 import pytest from fastapi import HTTPException -from policyengine_api.models import TaxBenefitModel, TaxBenefitModelVersion from policyengine_api.services.tax_benefit_models import ( get_latest_model_version, get_model_version_by_id, resolve_model_version_id, ) - - -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - -@pytest.fixture -def us_model(session): - """Create a policyengine-us model.""" - model = TaxBenefitModel(name="policyengine-us", description="US model") - session.add(model) - session.commit() - session.refresh(model) - return model - - -@pytest.fixture -def us_versions(session, us_model): - """Create two versions for the US model, v1 older than v2.""" - v1 = TaxBenefitModelVersion( - model_id=us_model.id, - version="1.0.0", - description="First version", - created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), - ) - v2 = TaxBenefitModelVersion( - model_id=us_model.id, - version="2.0.0", - description="Second version", - created_at=datetime(2025, 6, 1, tzinfo=timezone.utc), - ) - session.add(v1) - session.add(v2) - session.commit() - session.refresh(v1) - session.refresh(v2) - return v1, v2 +from test_fixtures.fixtures_version_filter import ( + MODEL_NAMES, + create_model, + create_version, + us_model, # noqa: F401 + us_two_versions, # noqa: F401 +) # --------------------------------------------------------------------------- @@ -58,31 +25,44 @@ def us_versions(session, us_model): class TestGetLatestModelVersion: - def test_returns_latest_version(self, session, us_model, us_versions): - """Given multiple versions, returns the one with the newest created_at.""" - v1, v2 = us_versions - result = get_latest_model_version("policyengine-us", session) + def test_given_multiple_versions_then_returns_newest( + self, session, us_model, us_two_versions # noqa: F811 + ): + """Returns the version with the most recent created_at.""" + _v1, v2 = us_two_versions + result = get_latest_model_version(MODEL_NAMES["US"], session) assert result.id == v2.id - assert result.version == "2.0.0" + assert result.version == "2.0" - def test_normalizes_underscores_to_hyphens(self, session, us_model, us_versions): - """Underscore names like 'policyengine_us' are normalized.""" - v1, v2 = us_versions + def test_given_underscore_name_then_normalizes_to_hyphens( + self, session, us_model, us_two_versions # noqa: F811 + ): + """'policyengine_us' is normalised to 'policyengine-us'.""" + _v1, v2 = us_two_versions result = get_latest_model_version("policyengine_us", session) assert result.id == v2.id - def test_nonexistent_model_raises_404(self, session): - """A model name that doesn't exist raises HTTPException 404.""" + def test_given_nonexistent_model_then_raises_404(self, session): + """Unknown model name raises HTTPException 404.""" with pytest.raises(HTTPException) as exc_info: get_latest_model_version("nonexistent-model", session) assert exc_info.value.status_code == 404 - def test_model_with_no_versions_raises_404(self, session, us_model): - """A model that exists but has no versions raises HTTPException 404.""" + def test_given_model_without_versions_then_raises_404( + self, session, us_model # noqa: F811 + ): + """Model that exists but has zero versions raises 404.""" with pytest.raises(HTTPException) as exc_info: - get_latest_model_version("policyengine-us", session) + get_latest_model_version(MODEL_NAMES["US"], session) assert exc_info.value.status_code == 404 + def test_given_single_version_then_returns_it(self, session): + """With only one version, that version is returned.""" + model = create_model(session) + only = create_version(session, model, "0.1") + result = get_latest_model_version(MODEL_NAMES["US"], session) + assert result.id == only.id + # --------------------------------------------------------------------------- # get_model_version_by_id @@ -90,15 +70,17 @@ def test_model_with_no_versions_raises_404(self, session, us_model): class TestGetModelVersionById: - def test_returns_version(self, session, us_versions): - """Given a valid version UUID, returns that version.""" - v1, v2 = us_versions + def test_given_valid_id_then_returns_version( + self, session, us_two_versions # noqa: F811 + ): + """Returns the matching version.""" + v1, _v2 = us_two_versions result = get_model_version_by_id(v1.id, session) assert result.id == v1.id - assert result.version == "1.0.0" + assert result.version == "1.0" - def test_nonexistent_id_raises_404(self, session): - """A UUID that doesn't match any version raises HTTPException 404.""" + def test_given_nonexistent_id_then_raises_404(self, session): + """Unknown UUID raises HTTPException 404.""" with pytest.raises(HTTPException) as exc_info: get_model_version_by_id(uuid4(), session) assert exc_info.value.status_code == 404 @@ -110,25 +92,35 @@ def test_nonexistent_id_raises_404(self, session): class TestResolveModelVersionId: - def test_version_id_takes_precedence(self, session, us_versions): - """When both model name and version ID are given, version ID wins.""" - v1, v2 = us_versions - result = resolve_model_version_id("policyengine-us", v1.id, session) + def test_given_version_id_then_takes_precedence_over_model_name( + self, session, us_two_versions # noqa: F811 + ): + """Explicit version_id wins over model_name.""" + v1, _v2 = us_two_versions + result = resolve_model_version_id(MODEL_NAMES["US"], v1.id, session) assert result == v1.id - def test_model_name_resolves_to_latest(self, session, us_model, us_versions): - """When only model name is given, resolves to the latest version.""" - v1, v2 = us_versions - result = resolve_model_version_id("policyengine-us", None, session) + def test_given_only_model_name_then_resolves_to_latest( + self, session, us_model, us_two_versions # noqa: F811 + ): + """Model name alone returns the latest version's ID.""" + _v1, v2 = us_two_versions + result = resolve_model_version_id(MODEL_NAMES["US"], None, session) assert result == v2.id - def test_neither_returns_none(self, session): - """When neither model name nor version ID is given, returns None.""" + def test_given_neither_then_returns_none(self, session): + """No model name and no version ID → None (no filtering).""" result = resolve_model_version_id(None, None, session) assert result is None - def test_invalid_version_id_raises_404(self, session): - """An explicit version ID that doesn't exist raises 404.""" + def test_given_invalid_version_id_then_raises_404(self, session): + """Non-existent explicit version_id raises 404.""" with pytest.raises(HTTPException) as exc_info: resolve_model_version_id(None, uuid4(), session) assert exc_info.value.status_code == 404 + + def test_given_invalid_model_name_then_raises_404(self, session): + """Non-existent model name raises 404.""" + with pytest.raises(HTTPException) as exc_info: + resolve_model_version_id("does-not-exist", None, session) + assert exc_info.value.status_code == 404 From 52af05d6fc819678934640f6b61aa7d44bc4fce6 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 4 Mar 2026 21:50:44 +0100 Subject: [PATCH 3/3] style: Fix ruff formatting and add missing noqa comments Co-Authored-By: Claude Opus 4.6 --- .../fixtures_economic_impact_response.py | 1 - .../fixtures_simulations_standalone.py | 1 - tests/test_economic_impact_response.py | 1 - tests/test_parameter_values.py | 83 ++++++++++---- tests/test_parameters.py | 89 +++++++++++---- tests/test_parameters_by_name.py | 64 ++++++++--- tests/test_parameters_children.py | 107 +++++++++++++----- tests/test_variables.py | 88 +++++++++----- tests/test_variables_by_name.py | 61 +++++++--- tests/test_version_filter_service.py | 28 +++-- 10 files changed, 383 insertions(+), 140 deletions(-) diff --git a/test_fixtures/fixtures_economic_impact_response.py b/test_fixtures/fixtures_economic_impact_response.py index 51da689..5ae9b29 100644 --- a/test_fixtures/fixtures_economic_impact_response.py +++ b/test_fixtures/fixtures_economic_impact_response.py @@ -5,7 +5,6 @@ program_statistics, decile_impacts) for testing _build_response(). """ - from sqlmodel import Session from policyengine_api.models import ( diff --git a/test_fixtures/fixtures_simulations_standalone.py b/test_fixtures/fixtures_simulations_standalone.py index c2e397f..2b7ce39 100644 --- a/test_fixtures/fixtures_simulations_standalone.py +++ b/test_fixtures/fixtures_simulations_standalone.py @@ -1,6 +1,5 @@ """Fixtures and helpers for standalone simulation endpoint tests.""" - from policyengine_api.models import ( Dataset, Household, diff --git a/tests/test_economic_impact_response.py b/tests/test_economic_impact_response.py index 18036b8..79b1ef4 100644 --- a/tests/test_economic_impact_response.py +++ b/tests/test_economic_impact_response.py @@ -4,7 +4,6 @@ intra_decile, program_statistics, detailed_budget, and decile_impacts. """ - from policyengine_api.api.analysis import _build_response, _safe_float from policyengine_api.models import ReportStatus from test_fixtures.fixtures_economic_impact_response import ( diff --git a/tests/test_parameter_values.py b/tests/test_parameter_values.py index 1792f53..5cc17c4 100644 --- a/tests/test_parameter_values.py +++ b/tests/test_parameter_values.py @@ -3,8 +3,6 @@ from datetime import datetime, timezone from uuid import uuid4 -import pytest - from test_fixtures.fixtures_version_filter import ( MODEL_NAMES, create_parameter, @@ -15,7 +13,6 @@ us_version, # noqa: F401 ) - # ----------------------------------------------------------------------------- # GET /parameter-values/ — basic # ----------------------------------------------------------------------------- @@ -29,7 +26,10 @@ def test_given_no_values_then_returns_empty_list(self, client): assert response.json() == [] def test_given_values_exist_then_returns_list( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Returns parameter values that exist.""" param = create_parameter(session, us_version, "gov.rate", "Rate") @@ -39,7 +39,10 @@ def test_given_values_exist_then_returns_list( assert len(data) == 1 def test_given_parameter_id_then_filters_by_parameter( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Filters values to a specific parameter.""" p1 = create_parameter(session, us_version, "gov.rate", "Rate") @@ -52,7 +55,11 @@ def test_given_parameter_id_then_filters_by_parameter( assert data[0]["parameter_id"] == str(p1.id) def test_given_policy_id_then_filters_by_policy( - self, client, session, us_version, us_model # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 + us_model, # noqa: F811 ): """Filters values to a specific policy.""" param = create_parameter(session, us_version, "gov.rate", "Rate") @@ -65,7 +72,11 @@ def test_given_policy_id_then_filters_by_policy( assert data[0]["policy_id"] == str(policy.id) def test_given_combined_parameter_and_policy_filters( - self, client, session, us_version, us_model # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 + us_model, # noqa: F811 ): """parameter_id + policy_id work together.""" p1 = create_parameter(session, us_version, "gov.rate", "Rate") @@ -81,7 +92,10 @@ def test_given_combined_parameter_and_policy_filters( assert len(data) == 1 def test_given_limit_then_returns_at_most_n( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Limit caps the number of results.""" param = create_parameter(session, us_version, "gov.rate", "Rate") @@ -96,7 +110,10 @@ def test_given_limit_then_returns_at_most_n( assert len(client.get("/parameter-values?limit=2").json()) == 2 def test_given_skip_then_skips_first_n( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Skip omits the first N results.""" param = create_parameter(session, us_version, "gov.rate", "Rate") @@ -111,16 +128,23 @@ def test_given_skip_then_skips_first_n( assert len(client.get("/parameter-values?skip=3&limit=10").json()) == 2 def test_results_ordered_by_start_date_desc( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Parameter values come back sorted by start_date descending.""" param = create_parameter(session, us_version, "gov.rate", "Rate") create_parameter_value( - session, param.id, 0.1, + session, + param.id, + 0.1, start_date=datetime(2020, 1, 1, tzinfo=timezone.utc), ) create_parameter_value( - session, param.id, 0.2, + session, + param.id, + 0.2, start_date=datetime(2025, 1, 1, tzinfo=timezone.utc), ) @@ -138,7 +162,10 @@ def test_results_ordered_by_start_date_desc( class TestListParameterValuesVersionFilter: def test_given_model_name_then_returns_only_latest_version( - self, client, session, us_two_versions # noqa: F811 + self, + client, + session, + us_two_versions, # noqa: F811 ): """Model name resolves to latest version; old-version param values excluded.""" v1, v2 = us_two_versions @@ -154,7 +181,10 @@ def test_given_model_name_then_returns_only_latest_version( assert data[0]["parameter_id"] == str(p_new.id) def test_given_explicit_version_id_then_returns_that_version( - self, client, session, us_two_versions # noqa: F811 + self, + client, + session, + us_two_versions, # noqa: F811 ): """Explicit version_id pins to a specific version.""" v1, v2 = us_two_versions @@ -170,7 +200,10 @@ def test_given_explicit_version_id_then_returns_that_version( assert data[0]["parameter_id"] == str(p_old.id) def test_given_both_then_version_id_takes_precedence( - self, client, session, us_two_versions # noqa: F811 + self, + client, + session, + us_two_versions, # noqa: F811 ): """version_id overrides model_name.""" v1, v2 = us_two_versions @@ -187,7 +220,10 @@ def test_given_both_then_version_id_takes_precedence( assert data[0]["parameter_id"] == str(p_old.id) def test_given_no_filters_then_returns_all_versions( - self, client, session, us_two_versions # noqa: F811 + self, + client, + session, + us_two_versions, # noqa: F811 ): """Without model/version filter, values from all versions are returned.""" v1, v2 = us_two_versions @@ -207,7 +243,10 @@ def test_given_nonexistent_model_name_then_returns_404(self, client): assert response.status_code == 404 def test_given_version_filter_combined_with_parameter_id( - self, client, session, us_two_versions # noqa: F811 + self, + client, + session, + us_two_versions, # noqa: F811 ): """Version filter + parameter_id work together.""" v1, v2 = us_two_versions @@ -231,7 +270,10 @@ def test_given_version_filter_combined_with_parameter_id( class TestGetParameterValue: def test_given_valid_id_then_returns_value( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Returns the full parameter value data.""" param = create_parameter(session, us_version, "gov.rate", "Rate") @@ -247,7 +289,10 @@ def test_given_nonexistent_id_then_returns_404(self, client): assert response.status_code == 404 def test_response_shape_matches_parameter_value_read( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Response contains all ParameterValueRead fields.""" param = create_parameter(session, us_version, "gov.rate", "Rate") diff --git a/tests/test_parameters.py b/tests/test_parameters.py index eeb5673..4ee41ba 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -2,12 +2,9 @@ from uuid import uuid4 -import pytest - from test_fixtures.fixtures_version_filter import ( MODEL_NAMES, create_parameter, - create_version, us_model, # noqa: F401 us_two_versions, # noqa: F401 us_version, # noqa: F401 @@ -26,7 +23,10 @@ def test_given_no_params_then_returns_empty_list(self, client): assert response.json() == [] def test_given_parameters_exist_then_returns_list( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Returns parameters that exist.""" create_parameter(session, us_version, "gov.rate", "Rate") @@ -35,7 +35,10 @@ def test_given_parameters_exist_then_returns_list( assert len(response.json()) == 1 def test_given_search_by_name_then_returns_matching( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Search filter matches parameter name.""" create_parameter(session, us_version, "gov.tax.rate", "Rate") @@ -48,7 +51,10 @@ def test_given_search_by_name_then_returns_matching( assert data[0]["name"] == "gov.tax.rate" def test_given_search_by_label_then_returns_matching( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Search filter matches parameter label (case-insensitive).""" create_parameter(session, us_version, "gov.x", "Basic Rate") @@ -60,7 +66,10 @@ def test_given_search_by_label_then_returns_matching( assert data[0]["label"] == "Basic Rate" def test_given_search_by_description_then_returns_matching( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Search filter matches parameter description.""" create_parameter( @@ -74,7 +83,10 @@ def test_given_search_by_description_then_returns_matching( assert data[0]["name"] == "gov.x" def test_given_limit_then_returns_at_most_n( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Limit caps the number of results.""" for i in range(5): @@ -84,7 +96,10 @@ def test_given_limit_then_returns_at_most_n( assert len(response.json()) == 2 def test_given_skip_then_skips_first_n( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Skip omits the first N results.""" for i in range(5): @@ -94,7 +109,10 @@ def test_given_skip_then_skips_first_n( assert len(response.json()) == 2 def test_results_ordered_by_name( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Parameters come back sorted alphabetically by name.""" create_parameter(session, us_version, "gov.zzz", "Z") @@ -110,7 +128,10 @@ def test_results_ordered_by_name( class TestListParametersVersionFilter: def test_given_model_name_then_returns_only_latest_version( - self, client, session, us_two_versions # noqa: F811 + self, + client, + session, + us_two_versions, # noqa: F811 ): """Model name resolves to latest version; old-version params excluded.""" v1, v2 = us_two_versions @@ -124,21 +145,25 @@ def test_given_model_name_then_returns_only_latest_version( assert data[0]["name"] == "gov.new" def test_given_explicit_version_id_then_returns_that_version( - self, client, session, us_two_versions # noqa: F811 + self, + client, + session, + us_two_versions, # noqa: F811 ): """Explicit version_id pins to a specific version.""" v1, v2 = us_two_versions create_parameter(session, v1, "gov.old", "Old") create_parameter(session, v2, "gov.new", "New") - data = client.get( - f"/parameters?tax_benefit_model_version_id={v1.id}" - ).json() + data = client.get(f"/parameters?tax_benefit_model_version_id={v1.id}").json() assert len(data) == 1 assert data[0]["name"] == "gov.old" def test_given_both_then_version_id_takes_precedence( - self, client, session, us_two_versions # noqa: F811 + self, + client, + session, + us_two_versions, # noqa: F811 ): """version_id overrides model_name.""" v1, v2 = us_two_versions @@ -153,7 +178,10 @@ def test_given_both_then_version_id_takes_precedence( assert data[0]["name"] == "gov.old" def test_given_no_filters_then_returns_all_versions( - self, client, session, us_two_versions # noqa: F811 + self, + client, + session, + us_two_versions, # noqa: F811 ): """Without model/version filter, params from all versions are returned.""" v1, v2 = us_two_versions @@ -165,13 +193,14 @@ def test_given_no_filters_then_returns_all_versions( def test_given_nonexistent_model_name_then_returns_404(self, client): """Unknown model name → 404.""" - response = client.get( - "/parameters?tax_benefit_model_name=nonexistent-model" - ) + response = client.get("/parameters?tax_benefit_model_name=nonexistent-model") assert response.status_code == 404 def test_given_search_combined_with_version_filter( - self, client, session, us_two_versions # noqa: F811 + self, + client, + session, + us_two_versions, # noqa: F811 ): """Search + version filter work together.""" v1, v2 = us_two_versions @@ -192,7 +221,10 @@ def test_given_search_combined_with_version_filter( class TestGetParameter: def test_given_valid_id_then_returns_parameter( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Returns the full parameter data.""" param = create_parameter(session, us_version, "gov.rate", "Rate") @@ -206,10 +238,19 @@ def test_given_nonexistent_id_then_returns_404(self, client): assert response.status_code == 404 def test_response_shape_matches_parameter_read( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Response contains all ParameterRead fields.""" param = create_parameter(session, us_version, "gov.rate", "Rate") data = client.get(f"/parameters/{param.id}").json() - for field in ("id", "name", "label", "created_at", "tax_benefit_model_version_id"): + for field in ( + "id", + "name", + "label", + "created_at", + "tax_benefit_model_version_id", + ): assert field in data diff --git a/tests/test_parameters_by_name.py b/tests/test_parameters_by_name.py index d08a8fb..5d089e4 100644 --- a/tests/test_parameters_by_name.py +++ b/tests/test_parameters_by_name.py @@ -1,6 +1,5 @@ """Tests for POST /parameters/by-name endpoint.""" -import pytest from test_fixtures.fixtures_version_filter import ( MODEL_NAMES, @@ -12,7 +11,6 @@ us_version, # noqa: F401 ) - # ----------------------------------------------------------------------------- # Happy-path lookups # ----------------------------------------------------------------------------- @@ -20,7 +18,10 @@ class TestParametersByName: def test_given_known_names_then_returns_matching( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Returns full metadata for each matching name.""" create_parameter(session, us_version, "gov.tax.rate", "Rate") @@ -38,7 +39,10 @@ def test_given_known_names_then_returns_matching( assert {p["name"] for p in data} == {"gov.tax.rate", "gov.tax.threshold"} def test_given_empty_names_then_returns_empty_list( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Empty names list → empty response (no DB query).""" data = client.post( @@ -48,7 +52,10 @@ def test_given_empty_names_then_returns_empty_list( assert data == [] def test_given_unknown_names_then_returns_empty_list( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Names that don't match anything → empty list.""" create_parameter(session, us_version, "gov.exists", "Exists") @@ -63,7 +70,10 @@ def test_given_unknown_names_then_returns_empty_list( assert data == [] def test_given_mixed_names_then_returns_only_known( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Only matching names are returned; unknowns silently omitted.""" create_parameter(session, us_version, "gov.real", "Real") @@ -79,7 +89,10 @@ def test_given_mixed_names_then_returns_only_known( assert data[0]["name"] == "gov.real" def test_given_single_name_then_returns_one( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Single-element lookup works.""" create_parameter(session, us_version, "gov.single", "Single") @@ -93,7 +106,10 @@ def test_given_single_name_then_returns_one( assert len(data) == 1 def test_results_ordered_by_name( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Response is sorted alphabetically by name.""" create_parameter(session, us_version, "gov.zzz", "Z") @@ -122,7 +138,13 @@ def test_response_shape(self, client, session, us_version): # noqa: F811 "tax_benefit_model_name": MODEL_NAMES["US"], }, ).json()[0] - for field in ("id", "name", "label", "created_at", "tax_benefit_model_version_id"): + for field in ( + "id", + "name", + "label", + "created_at", + "tax_benefit_model_version_id", + ): assert field in param @@ -133,7 +155,11 @@ def test_response_shape(self, client, session, us_version): # noqa: F811 class TestParametersByNameModelIsolation: def test_given_two_models_then_returns_only_requested( - self, client, session, us_version, uk_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 + uk_version, # noqa: F811 ): """Parameters from the other model are excluded.""" create_parameter(session, us_version, "gov.shared", "US") @@ -158,9 +184,7 @@ def test_given_two_models_then_returns_only_requested( class TestParametersByNameValidation: def test_given_missing_model_name_then_422(self, client): """Omitting tax_benefit_model_name → 422.""" - response = client.post( - "/parameters/by-name", json={"names": ["gov.x"]} - ) + response = client.post("/parameters/by-name", json={"names": ["gov.x"]}) assert response.status_code == 422 def test_given_missing_names_then_422(self, client): @@ -171,9 +195,7 @@ def test_given_missing_names_then_422(self, client): ) assert response.status_code == 422 - def test_given_nonexistent_model_name_then_404( - self, client, session - ): + def test_given_nonexistent_model_name_then_404(self, client, session): """Model that doesn't exist → 404 from resolve_model_version_id.""" response = client.post( "/parameters/by-name", @@ -192,7 +214,10 @@ def test_given_nonexistent_model_name_then_404( class TestParametersByNameVersionFilter: def test_given_model_name_only_then_defaults_to_latest( - self, client, session, us_two_versions # noqa: F811 + self, + client, + session, + us_two_versions, # noqa: F811 ): """Model name resolves to latest version.""" v1, v2 = us_two_versions @@ -210,7 +235,10 @@ def test_given_model_name_only_then_defaults_to_latest( assert data[0]["name"] == "gov.new" def test_given_explicit_version_id_then_returns_that_version( - self, client, session, us_two_versions # noqa: F811 + self, + client, + session, + us_two_versions, # noqa: F811 ): """Explicit version_id overrides latest-version default.""" v1, v2 = us_two_versions diff --git a/tests/test_parameters_children.py b/tests/test_parameters_children.py index a3da03d..07a4af1 100644 --- a/tests/test_parameters_children.py +++ b/tests/test_parameters_children.py @@ -1,11 +1,9 @@ """Tests for GET /parameters/children endpoint.""" -import pytest from test_fixtures.fixtures_version_filter import ( MODEL_NAMES, add_params_bulk, - create_parameter, uk_model, # noqa: F401 uk_two_versions, # noqa: F401 uk_version, # noqa: F401 @@ -13,7 +11,6 @@ us_version, # noqa: F401 ) - # ----------------------------------------------------------------------------- # Tree structure # ----------------------------------------------------------------------------- @@ -21,7 +18,10 @@ class TestParameterChildrenBasic: def test_returns_nodes_for_intermediate_paths( - self, client, session, uk_version # noqa: F811 + self, + client, + session, + uk_version, # noqa: F811 ): """Parameters at gov.hmrc.x and gov.dwp.x produce nodes for hmrc and dwp.""" add_params_bulk( @@ -54,7 +54,10 @@ def test_returns_nodes_for_intermediate_paths( assert child["child_count"] > 0 def test_returns_leaf_parameters( - self, client, session, uk_version # noqa: F811 + self, + client, + session, + uk_version, # noqa: F811 ): """Direct child parameters are returned with type='parameter'.""" add_params_bulk( @@ -88,7 +91,10 @@ def test_returns_leaf_parameters( assert node["path"] == "gov.hmrc" def test_mixed_nodes_and_leaves( - self, client, session, uk_version # noqa: F811 + self, + client, + session, + uk_version, # noqa: F811 ): """Both nodes and leaf parameters can appear at the same level.""" add_params_bulk( @@ -122,7 +128,10 @@ def test_mixed_nodes_and_leaves( class TestChildCount: def test_child_count_reflects_total_descendants( - self, client, session, uk_version # noqa: F811 + self, + client, + session, + uk_version, # noqa: F811 ): """child_count counts all leaf parameters under the node.""" add_params_bulk( @@ -148,7 +157,10 @@ def test_child_count_reflects_total_descendants( assert hmrc["child_count"] == 3 def test_nested_child_count( - self, client, session, uk_version # noqa: F811 + self, + client, + session, + uk_version, # noqa: F811 ): """Querying a deeper level gives accurate child counts.""" add_params_bulk( @@ -176,7 +188,10 @@ def test_nested_child_count( assert ni["child_count"] == 1 def test_leaf_has_no_child_count( - self, client, session, uk_version # noqa: F811 + self, + client, + session, + uk_version, # noqa: F811 ): """Leaf parameters have child_count=None.""" add_params_bulk(session, uk_version, [("gov.rate", "Rate")]) @@ -200,7 +215,10 @@ def test_leaf_has_no_child_count( class TestParameterChildrenModelIsolation: def test_given_uk_model_then_returns_uk_params( - self, client, session, uk_version # noqa: F811 + self, + client, + session, + uk_version, # noqa: F811 ): """policyengine-uk returns UK parameters.""" add_params_bulk(session, uk_version, [("gov.hmrc.rate", "Rate")]) @@ -217,7 +235,10 @@ def test_given_uk_model_then_returns_uk_params( assert len(response.json()["children"]) == 1 def test_given_us_model_then_returns_us_params( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """policyengine-us returns US parameters.""" add_params_bulk(session, us_version, [("gov.irs.rate", "Rate")]) @@ -234,7 +255,11 @@ def test_given_us_model_then_returns_us_params( assert len(response.json()["children"]) == 1 def test_given_two_models_then_returns_only_requested( - self, client, session, uk_version, us_version # noqa: F811 + self, + client, + session, + uk_version, # noqa: F811 + us_version, # noqa: F811 ): """Parameters from a different model are excluded.""" add_params_bulk(session, uk_version, [("gov.hmrc.rate", "UK rate")]) @@ -273,9 +298,7 @@ def test_given_two_models_then_returns_only_requested( class TestParameterChildrenValidation: def test_given_missing_model_name_then_422(self, client): """Request without tax_benefit_model_name returns 422.""" - response = client.get( - "/parameters/children", params={"parent_path": "gov"} - ) + response = client.get("/parameters/children", params={"parent_path": "gov"}) assert response.status_code == 422 @@ -286,7 +309,10 @@ def test_given_missing_model_name_then_422(self, client): class TestParameterChildrenEdgeCases: def test_empty_parent_path( - self, client, session, uk_version # noqa: F811 + self, + client, + session, + uk_version, # noqa: F811 ): """Empty parent_path returns top-level children.""" add_params_bulk(session, uk_version, [("gov.hmrc.rate", "Rate")]) @@ -306,7 +332,10 @@ def test_empty_parent_path( assert children[0]["type"] == "node" def test_nonexistent_parent_returns_empty( - self, client, session, uk_version # noqa: F811 + self, + client, + session, + uk_version, # noqa: F811 ): """A parent path with no descendants returns empty children list.""" add_params_bulk(session, uk_version, [("gov.hmrc.rate", "Rate")]) @@ -322,7 +351,10 @@ def test_nonexistent_parent_returns_empty( assert children == [] def test_children_sorted_by_path( - self, client, session, uk_version # noqa: F811 + self, + client, + session, + uk_version, # noqa: F811 ): """Children are returned sorted alphabetically by path.""" add_params_bulk( @@ -348,12 +380,13 @@ def test_children_sorted_by_path( assert paths == ["gov.aaa", "gov.mmm", "gov.zzz"] def test_node_label_from_path_segment( - self, client, session, uk_version # noqa: F811 + self, + client, + session, + uk_version, # noqa: F811 ): """Node labels default to the last path segment.""" - add_params_bulk( - session, uk_version, [("gov.hmrc.income_tax.rate", "Rate")] - ) + add_params_bulk(session, uk_version, [("gov.hmrc.income_tax.rate", "Rate")]) children = client.get( "/parameters/children", @@ -366,7 +399,10 @@ def test_node_label_from_path_segment( assert children[0]["label"] == "hmrc" def test_default_parent_path_is_empty( - self, client, session, uk_version # noqa: F811 + self, + client, + session, + uk_version, # noqa: F811 ): """Omitting parent_path defaults to empty string (root level).""" add_params_bulk(session, uk_version, [("gov.hmrc.rate", "Rate")]) @@ -380,7 +416,10 @@ def test_default_parent_path_is_empty( assert len(data["children"]) == 1 def test_leaf_parameter_includes_full_metadata( - self, client, session, uk_version # noqa: F811 + self, + client, + session, + uk_version, # noqa: F811 ): """Leaf parameters include the full ParameterRead shape.""" add_params_bulk(session, uk_version, [("gov.rate", "The rate")]) @@ -399,7 +438,10 @@ def test_leaf_parameter_includes_full_metadata( assert field in param def test_node_has_no_parameter_field( - self, client, session, uk_version # noqa: F811 + self, + client, + session, + uk_version, # noqa: F811 ): """Nodes do not include the parameter field.""" add_params_bulk(session, uk_version, [("gov.hmrc.rate", "Rate")]) @@ -416,7 +458,10 @@ def test_node_has_no_parameter_field( assert node["parameter"] is None def test_deep_nesting( - self, client, session, uk_version # noqa: F811 + self, + client, + session, + uk_version, # noqa: F811 ): """Works correctly with deeply nested parameter paths.""" add_params_bulk( @@ -462,7 +507,10 @@ def test_deep_nesting( class TestParameterChildrenVersionFilter: def test_given_model_name_only_then_defaults_to_latest( - self, client, session, uk_two_versions # noqa: F811 + self, + client, + session, + uk_two_versions, # noqa: F811 ): """When only model name is given, returns children from latest version.""" v1, v2 = uk_two_versions @@ -481,7 +529,10 @@ def test_given_model_name_only_then_defaults_to_latest( assert children[0]["path"] == "gov.new_param" def test_given_explicit_version_id_then_returns_that_version( - self, client, session, uk_two_versions # noqa: F811 + self, + client, + session, + uk_two_versions, # noqa: F811 ): """When version ID is given, returns children from that specific version.""" v1, v2 = uk_two_versions diff --git a/tests/test_variables.py b/tests/test_variables.py index c70d8bc..11e1314 100644 --- a/tests/test_variables.py +++ b/tests/test_variables.py @@ -2,8 +2,6 @@ from uuid import uuid4 -import pytest - from test_fixtures.fixtures_version_filter import ( MODEL_NAMES, create_variable, @@ -12,7 +10,6 @@ us_version, # noqa: F401 ) - # ----------------------------------------------------------------------------- # GET /variables/ — basic # ----------------------------------------------------------------------------- @@ -26,7 +23,10 @@ def test_given_no_variables_then_returns_empty_list(self, client): assert response.json() == [] def test_given_variables_exist_then_returns_list( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Returns variables that exist.""" create_variable(session, us_version, "employment_income") @@ -35,7 +35,10 @@ def test_given_variables_exist_then_returns_list( assert len(response.json()) == 1 def test_given_search_by_name_then_returns_matching( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Search filter matches variable name.""" create_variable(session, us_version, "employment_income") @@ -46,22 +49,26 @@ def test_given_search_by_name_then_returns_matching( assert data[0]["name"] == "employment_income" def test_given_search_by_description_then_returns_matching( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Search filter matches variable description.""" create_variable( session, us_version, "var_x", description="Total household income" ) - create_variable( - session, us_version, "var_y", description="Tax liability" - ) + create_variable(session, us_version, "var_y", description="Tax liability") data = client.get("/variables?search=household").json() assert len(data) == 1 assert data[0]["name"] == "var_x" def test_given_limit_then_returns_at_most_n( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Limit caps the number of results.""" for i in range(5): @@ -70,7 +77,10 @@ def test_given_limit_then_returns_at_most_n( assert len(client.get("/variables?limit=2").json()) == 2 def test_given_skip_then_skips_first_n( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Skip omits the first N results.""" for i in range(5): @@ -79,7 +89,10 @@ def test_given_skip_then_skips_first_n( assert len(client.get("/variables?skip=3&limit=10").json()) == 2 def test_results_ordered_by_name( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Variables come back sorted alphabetically by name.""" create_variable(session, us_version, "zzz_var") @@ -95,7 +108,10 @@ def test_results_ordered_by_name( class TestListVariablesVersionFilter: def test_given_model_name_then_returns_only_latest_version( - self, client, session, us_two_versions # noqa: F811 + self, + client, + session, + us_two_versions, # noqa: F811 ): """Model name resolves to latest version; old-version vars excluded.""" v1, v2 = us_two_versions @@ -109,21 +125,25 @@ def test_given_model_name_then_returns_only_latest_version( assert data[0]["name"] == "new_variable" def test_given_explicit_version_id_then_returns_that_version( - self, client, session, us_two_versions # noqa: F811 + self, + client, + session, + us_two_versions, # noqa: F811 ): """Explicit version_id pins to a specific version.""" v1, v2 = us_two_versions create_variable(session, v1, "old_variable") create_variable(session, v2, "new_variable") - data = client.get( - f"/variables?tax_benefit_model_version_id={v1.id}" - ).json() + data = client.get(f"/variables?tax_benefit_model_version_id={v1.id}").json() assert len(data) == 1 assert data[0]["name"] == "old_variable" def test_given_both_then_version_id_takes_precedence( - self, client, session, us_two_versions # noqa: F811 + self, + client, + session, + us_two_versions, # noqa: F811 ): """version_id overrides model_name.""" v1, v2 = us_two_versions @@ -138,7 +158,10 @@ def test_given_both_then_version_id_takes_precedence( assert data[0]["name"] == "old_variable" def test_given_no_filters_then_returns_all_versions( - self, client, session, us_two_versions # noqa: F811 + self, + client, + session, + us_two_versions, # noqa: F811 ): """Without model/version filter, vars from all versions are returned.""" v1, v2 = us_two_versions @@ -150,13 +173,14 @@ def test_given_no_filters_then_returns_all_versions( def test_given_nonexistent_model_name_then_returns_404(self, client): """Unknown model name returns 404.""" - response = client.get( - "/variables?tax_benefit_model_name=nonexistent-model" - ) + response = client.get("/variables?tax_benefit_model_name=nonexistent-model") assert response.status_code == 404 def test_given_search_combined_with_version_filter( - self, client, session, us_two_versions # noqa: F811 + self, + client, + session, + us_two_versions, # noqa: F811 ): """Search + version filter work together.""" v1, v2 = us_two_versions @@ -177,7 +201,10 @@ def test_given_search_combined_with_version_filter( class TestGetVariable: def test_given_valid_id_then_returns_variable( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Returns the full variable data.""" var = create_variable(session, us_version, "employment_income") @@ -191,10 +218,19 @@ def test_given_nonexistent_id_then_returns_404(self, client): assert response.status_code == 404 def test_response_shape_matches_variable_read( - self, client, session, us_version # noqa: F811 + self, + client, + session, + us_version, # noqa: F811 ): """Response contains all VariableRead fields.""" var = create_variable(session, us_version, "employment_income") data = client.get(f"/variables/{var.id}").json() - for field in ("id", "name", "entity", "created_at", "tax_benefit_model_version_id"): + for field in ( + "id", + "name", + "entity", + "created_at", + "tax_benefit_model_version_id", + ): assert field in data diff --git a/tests/test_variables_by_name.py b/tests/test_variables_by_name.py index 1aad4d2..0f713bc 100644 --- a/tests/test_variables_by_name.py +++ b/tests/test_variables_by_name.py @@ -1,6 +1,5 @@ """Tests for POST /variables/by-name endpoint.""" -import pytest from test_fixtures.fixtures_version_filter import ( MODEL_NAMES, @@ -12,7 +11,6 @@ us_version, # noqa: F401 ) - # ----------------------------------------------------------------------------- # Happy-path lookups # ----------------------------------------------------------------------------- @@ -20,7 +18,10 @@ class TestVariablesByName: def test_given_known_names_then_returns_matching( - self, client, session, uk_version # noqa: F811 + self, + client, + session, + uk_version, # noqa: F811 ): """Returns full metadata for each matching name.""" create_variable(session, uk_version, "employment_income") @@ -38,7 +39,10 @@ def test_given_known_names_then_returns_matching( assert {v["name"] for v in data} == {"employment_income", "income_tax"} def test_given_empty_names_then_returns_empty_list( - self, client, session, uk_version # noqa: F811 + self, + client, + session, + uk_version, # noqa: F811 ): """Empty names list returns empty response (no DB query).""" data = client.post( @@ -48,7 +52,10 @@ def test_given_empty_names_then_returns_empty_list( assert data == [] def test_given_unknown_names_then_returns_empty_list( - self, client, session, uk_version # noqa: F811 + self, + client, + session, + uk_version, # noqa: F811 ): """Names that don't match anything return empty list.""" create_variable(session, uk_version, "employment_income") @@ -63,7 +70,10 @@ def test_given_unknown_names_then_returns_empty_list( assert data == [] def test_given_mixed_names_then_returns_only_known( - self, client, session, uk_version # noqa: F811 + self, + client, + session, + uk_version, # noqa: F811 ): """Only matching names are returned; unknowns silently omitted.""" create_variable(session, uk_version, "income_tax") @@ -79,7 +89,10 @@ def test_given_mixed_names_then_returns_only_known( assert data[0]["name"] == "income_tax" def test_given_single_name_then_returns_one( - self, client, session, uk_version # noqa: F811 + self, + client, + session, + uk_version, # noqa: F811 ): """Single-element lookup works.""" create_variable(session, uk_version, "age") @@ -93,7 +106,10 @@ def test_given_single_name_then_returns_one( assert len(data) == 1 def test_results_ordered_by_name( - self, client, session, uk_version # noqa: F811 + self, + client, + session, + uk_version, # noqa: F811 ): """Response is sorted alphabetically by name.""" create_variable(session, uk_version, "zzz_var") @@ -124,7 +140,14 @@ def test_response_shape(self, client, session, uk_version): # noqa: F811 "tax_benefit_model_name": MODEL_NAMES["UK"], }, ).json()[0] - for field in ("id", "name", "entity", "description", "created_at", "tax_benefit_model_version_id"): + for field in ( + "id", + "name", + "entity", + "description", + "created_at", + "tax_benefit_model_version_id", + ): assert field in var @@ -135,7 +158,11 @@ def test_response_shape(self, client, session, uk_version): # noqa: F811 class TestVariablesByNameModelIsolation: def test_given_two_models_then_returns_only_requested( - self, client, session, uk_version, us_version # noqa: F811 + self, + client, + session, + uk_version, # noqa: F811 + us_version, # noqa: F811 ): """Variables from the other model are excluded.""" create_variable(session, uk_version, "council_tax") @@ -170,9 +197,7 @@ def test_given_two_models_then_returns_only_requested( class TestVariablesByNameValidation: def test_given_missing_model_name_then_422(self, client): """Omitting tax_benefit_model_name returns 422.""" - response = client.post( - "/variables/by-name", json={"names": ["income_tax"]} - ) + response = client.post("/variables/by-name", json={"names": ["income_tax"]}) assert response.status_code == 422 def test_given_missing_names_then_422(self, client): @@ -202,7 +227,10 @@ def test_given_nonexistent_model_name_then_404(self, client, session): class TestVariablesByNameVersionFilter: def test_given_model_name_only_then_defaults_to_latest( - self, client, session, uk_two_versions # noqa: F811 + self, + client, + session, + uk_two_versions, # noqa: F811 ): """Model name resolves to latest version.""" v1, v2 = uk_two_versions @@ -220,7 +248,10 @@ def test_given_model_name_only_then_defaults_to_latest( assert data[0]["name"] == "new_var" def test_given_explicit_version_id_then_returns_that_version( - self, client, session, uk_two_versions # noqa: F811 + self, + client, + session, + uk_two_versions, # noqa: F811 ): """Explicit version_id overrides latest-version default.""" v1, v2 = uk_two_versions diff --git a/tests/test_version_filter_service.py b/tests/test_version_filter_service.py index 08c9820..b734775 100644 --- a/tests/test_version_filter_service.py +++ b/tests/test_version_filter_service.py @@ -18,7 +18,6 @@ us_two_versions, # noqa: F401 ) - # --------------------------------------------------------------------------- # get_latest_model_version # --------------------------------------------------------------------------- @@ -26,7 +25,10 @@ class TestGetLatestModelVersion: def test_given_multiple_versions_then_returns_newest( - self, session, us_model, us_two_versions # noqa: F811 + self, + session, + us_model, # noqa: F811 + us_two_versions, # noqa: F811 ): """Returns the version with the most recent created_at.""" _v1, v2 = us_two_versions @@ -35,7 +37,10 @@ def test_given_multiple_versions_then_returns_newest( assert result.version == "2.0" def test_given_underscore_name_then_normalizes_to_hyphens( - self, session, us_model, us_two_versions # noqa: F811 + self, + session, + us_model, # noqa: F811 + us_two_versions, # noqa: F811 ): """'policyengine_us' is normalised to 'policyengine-us'.""" _v1, v2 = us_two_versions @@ -49,7 +54,9 @@ def test_given_nonexistent_model_then_raises_404(self, session): assert exc_info.value.status_code == 404 def test_given_model_without_versions_then_raises_404( - self, session, us_model # noqa: F811 + self, + session, + us_model, # noqa: F811 ): """Model that exists but has zero versions raises 404.""" with pytest.raises(HTTPException) as exc_info: @@ -71,7 +78,9 @@ def test_given_single_version_then_returns_it(self, session): class TestGetModelVersionById: def test_given_valid_id_then_returns_version( - self, session, us_two_versions # noqa: F811 + self, + session, + us_two_versions, # noqa: F811 ): """Returns the matching version.""" v1, _v2 = us_two_versions @@ -93,7 +102,9 @@ def test_given_nonexistent_id_then_raises_404(self, session): class TestResolveModelVersionId: def test_given_version_id_then_takes_precedence_over_model_name( - self, session, us_two_versions # noqa: F811 + self, + session, + us_two_versions, # noqa: F811 ): """Explicit version_id wins over model_name.""" v1, _v2 = us_two_versions @@ -101,7 +112,10 @@ def test_given_version_id_then_takes_precedence_over_model_name( assert result == v1.id def test_given_only_model_name_then_resolves_to_latest( - self, session, us_model, us_two_versions # noqa: F811 + self, + session, + us_model, # noqa: F811 + us_two_versions, # noqa: F811 ): """Model name alone returns the latest version's ID.""" _v1, v2 = us_two_versions