diff --git a/src/policyengine_api/api/parameter_values.py b/src/policyengine_api/api/parameter_values.py index 4668ab8..9dd6819 100644 --- a/src/policyengine_api/api/parameter_values.py +++ b/src/policyengine_api/api/parameter_values.py @@ -12,8 +12,9 @@ from fastapi import APIRouter, Depends, HTTPException from sqlmodel import Session, or_, select -from policyengine_api.models import ParameterValue, ParameterValueRead +from policyengine_api.models import Parameter, ParameterValue, ParameterValueRead from policyengine_api.services.database import get_session +from policyengine_api.services.tax_benefit_models import resolve_model_version_id router = APIRouter(prefix="/parameter-values", tags=["parameter-values"]) @@ -23,6 +24,8 @@ def list_parameter_values( parameter_id: UUID | None = None, policy_id: UUID | None = None, current: bool = False, + tax_benefit_model_name: str | None = None, + tax_benefit_model_version_id: UUID | None = None, skip: int = 0, limit: int = 100, session: Session = Depends(get_session), @@ -37,6 +40,10 @@ def list_parameter_values( policy_id: Filter by a specific policy reform. current: If true, only return values that are currently in effect (start_date <= now and (end_date is null or end_date > now)). + tax_benefit_model_name: Filter to values belonging to parameters from + this model. Defaults to the latest version. + tax_benefit_model_version_id: Filter to values belonging to parameters + from this specific model version. Takes precedence over model name. """ query = select(ParameterValue) @@ -46,6 +53,14 @@ def list_parameter_values( if policy_id: query = query.where(ParameterValue.policy_id == policy_id) + version_id = resolve_model_version_id( + tax_benefit_model_name, tax_benefit_model_version_id, session + ) + if version_id: + query = query.join(Parameter).where( + Parameter.tax_benefit_model_version_id == version_id + ) + if current: now = datetime.now(timezone.utc) query = query.where( diff --git a/src/policyengine_api/api/parameters.py b/src/policyengine_api/api/parameters.py index 72b64ef..43cf163 100644 --- a/src/policyengine_api/api/parameters.py +++ b/src/policyengine_api/api/parameters.py @@ -14,14 +14,12 @@ from pydantic import BaseModel from sqlmodel import Session, select -from policyengine_api.config.constants import COUNTRY_MODEL_NAMES, CountryId from policyengine_api.models import ( Parameter, ParameterRead, - TaxBenefitModel, - TaxBenefitModelVersion, ) from policyengine_api.services.database import get_session +from policyengine_api.services.tax_benefit_models import resolve_model_version_id router = APIRouter(prefix="/parameters", tags=["parameters"]) @@ -32,6 +30,7 @@ def list_parameters( limit: int = 100, search: str | None = None, tax_benefit_model_name: str | None = None, + tax_benefit_model_version_id: UUID | None = None, session: Session = Depends(get_session), ): """List available parameters with pagination and search. @@ -44,19 +43,19 @@ def list_parameters( tax_benefit_model_name: Filter by country model. Use "policyengine-uk" for UK parameters. Use "policyengine-us" for US parameters. + Defaults to the latest model version when no version ID is given. + tax_benefit_model_version_id: Filter by a specific model version. + Takes precedence over tax_benefit_model_name. """ query = select(Parameter) - # Filter by tax benefit model name (country) - if tax_benefit_model_name: - query = ( - query.join(TaxBenefitModelVersion) - .join(TaxBenefitModel) - .where(TaxBenefitModel.name == tax_benefit_model_name) - ) + version_id = resolve_model_version_id( + tax_benefit_model_name, tax_benefit_model_version_id, session + ) + if version_id: + query = query.where(Parameter.tax_benefit_model_version_id == version_id) if search: - # Case-insensitive search using ILIKE search_pattern = f"%{search}%" search_filter = ( Parameter.name.ilike(search_pattern) @@ -75,7 +74,8 @@ class ParameterByNameRequest(BaseModel): """Request body for looking up parameters by name.""" names: list[str] - country_id: CountryId + tax_benefit_model_name: str + tax_benefit_model_version_id: UUID | None = None @router.post("/by-name", response_model=List[ParameterRead]) @@ -95,18 +95,17 @@ def get_parameters_by_name( if not request.names: return [] - model_name = COUNTRY_MODEL_NAMES[request.country_id] - - query = ( - select(Parameter) - .join(TaxBenefitModelVersion) - .join(TaxBenefitModel) - .where(TaxBenefitModel.name == model_name) - .where(Parameter.name.in_(request.names)) - .order_by(Parameter.name) + version_id = resolve_model_version_id( + request.tax_benefit_model_name, + request.tax_benefit_model_version_id, + session, ) - return session.exec(query).all() + query = select(Parameter).where(Parameter.name.in_(request.names)) + if version_id: + query = query.where(Parameter.tax_benefit_model_version_id == version_id) + + return session.exec(query.order_by(Parameter.name)).all() class ParameterChild(BaseModel): @@ -128,10 +127,15 @@ class ParameterChildrenResponse(BaseModel): @router.get("/children", response_model=ParameterChildrenResponse) def get_parameter_children( - country_id: CountryId = Query(description='Country ID ("us" or "uk")'), + tax_benefit_model_name: str = Query( + description='Model name (e.g. "policyengine-us" or "policyengine-uk")' + ), parent_path: str = Query( default="", description="Parent parameter path (e.g. 'gov' or 'gov.hmrc')" ), + tax_benefit_model_version_id: UUID | None = Query( + default=None, description="Optional specific model version ID" + ), session: Session = Depends(get_session), ) -> ParameterChildrenResponse: """Get direct children of a parameter path for tree navigation. @@ -140,17 +144,16 @@ def get_parameter_children( parameters (with full metadata). Use this to lazily load the parameter tree one level at a time. """ - model_name = COUNTRY_MODEL_NAMES[country_id] + version_id = resolve_model_version_id( + tax_benefit_model_name, tax_benefit_model_version_id, session + ) + prefix = f"{parent_path}." if parent_path else "" - # Fetch all parameters under this path - query = ( - select(Parameter) - .join(TaxBenefitModelVersion) - .join(TaxBenefitModel) - .where(TaxBenefitModel.name == model_name) - .where(Parameter.name.startswith(prefix)) - ) + query = select(Parameter).where(Parameter.name.startswith(prefix)) + if version_id: + query = query.where(Parameter.tax_benefit_model_version_id == version_id) + descendants = session.exec(query).all() # Group by direct child path diff --git a/src/policyengine_api/api/variables.py b/src/policyengine_api/api/variables.py index 04aa512..edf0eea 100644 --- a/src/policyengine_api/api/variables.py +++ b/src/policyengine_api/api/variables.py @@ -12,14 +12,12 @@ from pydantic import BaseModel from sqlmodel import Session, select -from policyengine_api.config.constants import COUNTRY_MODEL_NAMES, CountryId from policyengine_api.models import ( - TaxBenefitModel, - TaxBenefitModelVersion, Variable, VariableRead, ) from policyengine_api.services.database import get_session +from policyengine_api.services.tax_benefit_models import resolve_model_version_id router = APIRouter(prefix="/variables", tags=["variables"]) @@ -30,6 +28,7 @@ def list_variables( limit: int = 100, search: str | None = None, tax_benefit_model_name: str | None = None, + tax_benefit_model_version_id: UUID | None = None, session: Session = Depends(get_session), ): """List available variables with pagination and search. @@ -43,20 +42,19 @@ def list_variables( tax_benefit_model_name: Filter by country model. Use "policyengine-uk" for UK variables. Use "policyengine-us" for US variables. + Defaults to the latest model version when no version ID is given. + tax_benefit_model_version_id: Filter by a specific model version. + Takes precedence over tax_benefit_model_name. """ query = select(Variable) - # Filter by tax benefit model name (country) - if tax_benefit_model_name: - query = ( - query.join(TaxBenefitModelVersion) - .join(TaxBenefitModel) - .where(TaxBenefitModel.name == tax_benefit_model_name) - ) + version_id = resolve_model_version_id( + tax_benefit_model_name, tax_benefit_model_version_id, session + ) + if version_id: + query = query.where(Variable.tax_benefit_model_version_id == version_id) if search: - # Case-insensitive search using ILIKE - # Note: Variables don't have a label field, only name and description search_pattern = f"%{search}%" search_filter = Variable.name.ilike( search_pattern @@ -73,7 +71,8 @@ class VariableByNameRequest(BaseModel): """Request body for looking up variables by name.""" names: list[str] - country_id: CountryId + tax_benefit_model_name: str + tax_benefit_model_version_id: UUID | None = None @router.post("/by-name", response_model=List[VariableRead]) @@ -94,17 +93,17 @@ def get_variables_by_name( if not request.names: return [] - model_name = COUNTRY_MODEL_NAMES[request.country_id] - query = ( - select(Variable) - .join(TaxBenefitModelVersion) - .join(TaxBenefitModel) - .where(TaxBenefitModel.name == model_name) - .where(Variable.name.in_(request.names)) - .order_by(Variable.name) + version_id = resolve_model_version_id( + request.tax_benefit_model_name, + request.tax_benefit_model_version_id, + session, ) - return session.exec(query).all() + query = select(Variable).where(Variable.name.in_(request.names)) + if version_id: + query = query.where(Variable.tax_benefit_model_version_id == version_id) + + return session.exec(query.order_by(Variable.name)).all() @router.get("/{variable_id}", response_model=VariableRead) diff --git a/src/policyengine_api/services/__init__.py b/src/policyengine_api/services/__init__.py index a314727..1fb2a46 100644 --- a/src/policyengine_api/services/__init__.py +++ b/src/policyengine_api/services/__init__.py @@ -1,5 +1,16 @@ """Services for database and external integrations.""" from .database import get_session, init_db +from .tax_benefit_models import ( + get_latest_model_version, + get_model_version_by_id, + resolve_model_version_id, +) -__all__ = ["get_session", "init_db"] +__all__ = [ + "get_session", + "init_db", + "get_latest_model_version", + "get_model_version_by_id", + "resolve_model_version_id", +] diff --git a/src/policyengine_api/services/tax_benefit_models.py b/src/policyengine_api/services/tax_benefit_models.py new file mode 100644 index 0000000..f88d056 --- /dev/null +++ b/src/policyengine_api/services/tax_benefit_models.py @@ -0,0 +1,109 @@ +"""Tax benefit model utilities. + +Shared utilities for resolving tax benefit model versions. +""" + +from uuid import UUID + +from fastapi import HTTPException +from sqlmodel import Session, select + +from policyengine_api.models import TaxBenefitModel, TaxBenefitModelVersion + + +def get_latest_model_version( + tax_benefit_model_name: str, session: Session +) -> TaxBenefitModelVersion: + """Get the latest tax benefit model version for a given model name. + + Args: + tax_benefit_model_name: The model name (e.g., "policyengine-us"). + Underscores are normalized to hyphens. + session: Database session. + + Returns: + The latest TaxBenefitModelVersion for the model. + + Raises: + HTTPException: If model or version not found. + """ + model_name = tax_benefit_model_name.replace("_", "-") + + model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == model_name) + ).first() + if not model: + raise HTTPException( + status_code=404, + detail=f"Tax benefit model '{model_name}' not found", + ) + + version = session.exec( + select(TaxBenefitModelVersion) + .where(TaxBenefitModelVersion.model_id == model.id) + .order_by(TaxBenefitModelVersion.created_at.desc()) + ).first() + if not version: + raise HTTPException( + status_code=404, + detail=f"No version found for model '{model_name}'", + ) + + return version + + +def get_model_version_by_id( + version_id: UUID, session: Session +) -> TaxBenefitModelVersion: + """Get a specific tax benefit model version by ID. + + Args: + version_id: The UUID of the model version. + session: Database session. + + Returns: + The TaxBenefitModelVersion with the given ID. + + Raises: + HTTPException: If version not found. + """ + version = session.get(TaxBenefitModelVersion, version_id) + if not version: + raise HTTPException( + status_code=404, + detail=f"Tax benefit model version '{version_id}' not found", + ) + return version + + +def resolve_model_version_id( + tax_benefit_model_name: str | None, + tax_benefit_model_version_id: UUID | None, + session: Session, +) -> UUID | None: + """Resolve the model version ID from either explicit ID or model name. + + Priority: + 1. If tax_benefit_model_version_id provided, validate and return it. + 2. If tax_benefit_model_name provided, return the latest version's ID. + 3. If neither provided, return None (no filtering). + + Args: + tax_benefit_model_name: Optional model name to resolve latest version for. + tax_benefit_model_version_id: Optional explicit version ID. + session: Database session. + + Returns: + The resolved version ID, or None if no filtering requested. + + Raises: + HTTPException: If specified version or model not found. + """ + if tax_benefit_model_version_id: + version = get_model_version_by_id(tax_benefit_model_version_id, session) + return version.id + elif tax_benefit_model_name: + version = get_latest_model_version(tax_benefit_model_name, session) + return version.id + else: + return None diff --git a/test_fixtures/fixtures_economic_impact_response.py b/test_fixtures/fixtures_economic_impact_response.py index 51da689..5ae9b29 100644 --- a/test_fixtures/fixtures_economic_impact_response.py +++ b/test_fixtures/fixtures_economic_impact_response.py @@ -5,7 +5,6 @@ program_statistics, decile_impacts) for testing _build_response(). """ - from sqlmodel import Session from policyengine_api.models import ( diff --git a/test_fixtures/fixtures_simulations_standalone.py b/test_fixtures/fixtures_simulations_standalone.py index c2e397f..2b7ce39 100644 --- a/test_fixtures/fixtures_simulations_standalone.py +++ b/test_fixtures/fixtures_simulations_standalone.py @@ -1,6 +1,5 @@ """Fixtures and helpers for standalone simulation endpoint tests.""" - from policyengine_api.models import ( Dataset, Household, diff --git a/test_fixtures/fixtures_version_filter.py b/test_fixtures/fixtures_version_filter.py new file mode 100644 index 0000000..c7304ad --- /dev/null +++ b/test_fixtures/fixtures_version_filter.py @@ -0,0 +1,207 @@ +"""Fixtures and helpers for version-filtering tests. + +Provides reusable model/version/parameter/variable factories and composite +fixtures for testing version-filter behaviour across endpoints. +""" + +from datetime import datetime, timezone + +import pytest + +from policyengine_api.models import ( + Parameter, + ParameterValue, + Policy, + TaxBenefitModel, + TaxBenefitModelVersion, + Variable, +) + +# ----------------------------------------------------------------------------- +# Constants +# ----------------------------------------------------------------------------- + +MODEL_NAMES = { + "US": "policyengine-us", + "UK": "policyengine-uk", +} + +VERSION_TIMESTAMPS = { + "V1": datetime(2025, 1, 1, tzinfo=timezone.utc), + "V2": datetime(2025, 6, 1, tzinfo=timezone.utc), +} + + +# ----------------------------------------------------------------------------- +# Factory Functions +# ----------------------------------------------------------------------------- + + +def create_model( + session, name: str = MODEL_NAMES["US"], description: str = "Test model" +) -> TaxBenefitModel: + """Create and persist a TaxBenefitModel.""" + model = TaxBenefitModel(name=name, description=description) + session.add(model) + session.commit() + session.refresh(model) + return model + + +def create_version( + session, + model: TaxBenefitModel, + version: str = "1.0.0", + created_at: datetime | None = None, +) -> TaxBenefitModelVersion: + """Create and persist a TaxBenefitModelVersion.""" + ver = TaxBenefitModelVersion( + model_id=model.id, + version=version, + description=f"Version {version}", + **({"created_at": created_at} if created_at else {}), + ) + session.add(ver) + session.commit() + session.refresh(ver) + return ver + + +def create_parameter( + session, + model_version: TaxBenefitModelVersion, + name: str, + label: str = "", + description: str | None = None, +) -> Parameter: + """Create and persist a Parameter.""" + param = Parameter( + name=name, + label=label or name.rsplit(".", 1)[-1], + description=description, + tax_benefit_model_version_id=model_version.id, + ) + session.add(param) + session.commit() + session.refresh(param) + return param + + +def create_variable( + session, + model_version: TaxBenefitModelVersion, + name: str, + entity: str = "person", + description: str | None = None, +) -> Variable: + """Create and persist a Variable.""" + var = Variable( + name=name, + entity=entity, + description=description, + tax_benefit_model_version_id=model_version.id, + ) + session.add(var) + session.commit() + session.refresh(var) + return var + + +def create_parameter_value( + session, + parameter_id, + value: int | float | dict, + policy_id=None, + start_date: datetime | None = None, +) -> ParameterValue: + """Create and persist a ParameterValue.""" + pv = ParameterValue( + parameter_id=parameter_id, + value_json=value, + start_date=start_date or datetime.now(timezone.utc), + policy_id=policy_id, + ) + session.add(pv) + session.commit() + session.refresh(pv) + return pv + + +def create_policy( + session, + model: TaxBenefitModel, + name: str = "Test Policy", +) -> Policy: + """Create and persist a Policy.""" + policy = Policy( + name=name, + description=f"Policy: {name}", + tax_benefit_model_id=model.id, + ) + session.add(policy) + session.commit() + session.refresh(policy) + return policy + + +def add_params_bulk(session, version, names_and_labels): + """Bulk-add parameters. names_and_labels is [(name, label), ...].""" + for name, label in names_and_labels: + session.add( + Parameter( + name=name, + label=label, + tax_benefit_model_version_id=version.id, + ) + ) + session.commit() + + +# ----------------------------------------------------------------------------- +# Composite Fixtures — single model + single version +# ----------------------------------------------------------------------------- + + +@pytest.fixture +def us_model(session): + """Create a policyengine-us model.""" + return create_model(session, MODEL_NAMES["US"], "US model") + + +@pytest.fixture +def uk_model(session): + """Create a policyengine-uk model.""" + return create_model(session, MODEL_NAMES["UK"], "UK model") + + +@pytest.fixture +def us_version(session, us_model): + """Create a single US model version.""" + return create_version(session, us_model, "1.0") + + +@pytest.fixture +def uk_version(session, uk_model): + """Create a single UK model version.""" + return create_version(session, uk_model, "1.0") + + +# ----------------------------------------------------------------------------- +# Composite Fixtures — single model + TWO versions (for version-filter tests) +# ----------------------------------------------------------------------------- + + +@pytest.fixture +def us_two_versions(session, us_model): + """Create two US versions: v1 (old) and v2 (latest).""" + v1 = create_version(session, us_model, "1.0", VERSION_TIMESTAMPS["V1"]) + v2 = create_version(session, us_model, "2.0", VERSION_TIMESTAMPS["V2"]) + return v1, v2 + + +@pytest.fixture +def uk_two_versions(session, uk_model): + """Create two UK versions: v1 (old) and v2 (latest).""" + v1 = create_version(session, uk_model, "1.0", VERSION_TIMESTAMPS["V1"]) + v2 = create_version(session, uk_model, "2.0", VERSION_TIMESTAMPS["V2"]) + return v1, v2 diff --git a/tests/test_economic_impact_response.py b/tests/test_economic_impact_response.py index 18036b8..79b1ef4 100644 --- a/tests/test_economic_impact_response.py +++ b/tests/test_economic_impact_response.py @@ -4,7 +4,6 @@ intra_decile, program_statistics, detailed_budget, and decile_impacts. """ - from policyengine_api.api.analysis import _build_response, _safe_float from policyengine_api.models import ReportStatus from test_fixtures.fixtures_economic_impact_response import ( diff --git a/tests/test_parameter_values.py b/tests/test_parameter_values.py new file mode 100644 index 0000000..5cc17c4 --- /dev/null +++ b/tests/test_parameter_values.py @@ -0,0 +1,303 @@ +"""Tests for GET /parameter-values/ and GET /parameter-values/{id} endpoints.""" + +from datetime import datetime, timezone +from uuid import uuid4 + +from test_fixtures.fixtures_version_filter import ( + MODEL_NAMES, + create_parameter, + create_parameter_value, + create_policy, + us_model, # noqa: F401 + us_two_versions, # noqa: F401 + us_version, # noqa: F401 +) + +# ----------------------------------------------------------------------------- +# GET /parameter-values/ — basic +# ----------------------------------------------------------------------------- + + +class TestListParameterValues: + def test_given_no_values_then_returns_empty_list(self, client): + """Empty database returns an empty list.""" + response = client.get("/parameter-values") + assert response.status_code == 200 + assert response.json() == [] + + def test_given_values_exist_then_returns_list( + self, + client, + session, + us_version, # noqa: F811 + ): + """Returns parameter values that exist.""" + param = create_parameter(session, us_version, "gov.rate", "Rate") + create_parameter_value(session, param.id, 0.2) + + data = client.get("/parameter-values").json() + assert len(data) == 1 + + def test_given_parameter_id_then_filters_by_parameter( + self, + client, + session, + us_version, # noqa: F811 + ): + """Filters values to a specific parameter.""" + p1 = create_parameter(session, us_version, "gov.rate", "Rate") + p2 = create_parameter(session, us_version, "gov.threshold", "Threshold") + create_parameter_value(session, p1.id, 0.2) + create_parameter_value(session, p2.id, 12570) + + data = client.get(f"/parameter-values?parameter_id={p1.id}").json() + assert len(data) == 1 + assert data[0]["parameter_id"] == str(p1.id) + + def test_given_policy_id_then_filters_by_policy( + self, + client, + session, + us_version, # noqa: F811 + us_model, # noqa: F811 + ): + """Filters values to a specific policy.""" + param = create_parameter(session, us_version, "gov.rate", "Rate") + policy = create_policy(session, us_model, "Reform A") + create_parameter_value(session, param.id, 0.2) + create_parameter_value(session, param.id, 0.25, policy_id=policy.id) + + data = client.get(f"/parameter-values?policy_id={policy.id}").json() + assert len(data) == 1 + assert data[0]["policy_id"] == str(policy.id) + + def test_given_combined_parameter_and_policy_filters( + self, + client, + session, + us_version, # noqa: F811 + us_model, # noqa: F811 + ): + """parameter_id + policy_id work together.""" + p1 = create_parameter(session, us_version, "gov.rate", "Rate") + p2 = create_parameter(session, us_version, "gov.threshold", "Threshold") + policy = create_policy(session, us_model, "Reform A") + create_parameter_value(session, p1.id, 0.2, policy_id=policy.id) + create_parameter_value(session, p2.id, 12570, policy_id=policy.id) + create_parameter_value(session, p1.id, 0.15) + + data = client.get( + f"/parameter-values?parameter_id={p1.id}&policy_id={policy.id}" + ).json() + assert len(data) == 1 + + def test_given_limit_then_returns_at_most_n( + self, + client, + session, + us_version, # noqa: F811 + ): + """Limit caps the number of results.""" + param = create_parameter(session, us_version, "gov.rate", "Rate") + for i in range(5): + create_parameter_value( + session, + param.id, + i * 0.1, + start_date=datetime(2020 + i, 1, 1, tzinfo=timezone.utc), + ) + + assert len(client.get("/parameter-values?limit=2").json()) == 2 + + def test_given_skip_then_skips_first_n( + self, + client, + session, + us_version, # noqa: F811 + ): + """Skip omits the first N results.""" + param = create_parameter(session, us_version, "gov.rate", "Rate") + for i in range(5): + create_parameter_value( + session, + param.id, + i * 0.1, + start_date=datetime(2020 + i, 1, 1, tzinfo=timezone.utc), + ) + + assert len(client.get("/parameter-values?skip=3&limit=10").json()) == 2 + + def test_results_ordered_by_start_date_desc( + self, + client, + session, + us_version, # noqa: F811 + ): + """Parameter values come back sorted by start_date descending.""" + param = create_parameter(session, us_version, "gov.rate", "Rate") + create_parameter_value( + session, + param.id, + 0.1, + start_date=datetime(2020, 1, 1, tzinfo=timezone.utc), + ) + create_parameter_value( + session, + param.id, + 0.2, + start_date=datetime(2025, 1, 1, tzinfo=timezone.utc), + ) + + data = client.get("/parameter-values").json() + assert len(data) == 2 + # Most recent first + dates = [d["start_date"] for d in data] + assert dates[0] > dates[1] + + +# ----------------------------------------------------------------------------- +# GET /parameter-values/ — version filtering +# ----------------------------------------------------------------------------- + + +class TestListParameterValuesVersionFilter: + def test_given_model_name_then_returns_only_latest_version( + self, + client, + session, + us_two_versions, # noqa: F811 + ): + """Model name resolves to latest version; old-version param values excluded.""" + v1, v2 = us_two_versions + p_old = create_parameter(session, v1, "gov.old", "Old") + p_new = create_parameter(session, v2, "gov.new", "New") + create_parameter_value(session, p_old.id, 0.1) + create_parameter_value(session, p_new.id, 0.2) + + data = client.get( + f"/parameter-values?tax_benefit_model_name={MODEL_NAMES['US']}" + ).json() + assert len(data) == 1 + assert data[0]["parameter_id"] == str(p_new.id) + + def test_given_explicit_version_id_then_returns_that_version( + self, + client, + session, + us_two_versions, # noqa: F811 + ): + """Explicit version_id pins to a specific version.""" + v1, v2 = us_two_versions + p_old = create_parameter(session, v1, "gov.old", "Old") + p_new = create_parameter(session, v2, "gov.new", "New") + create_parameter_value(session, p_old.id, 0.1) + create_parameter_value(session, p_new.id, 0.2) + + data = client.get( + f"/parameter-values?tax_benefit_model_version_id={v1.id}" + ).json() + assert len(data) == 1 + assert data[0]["parameter_id"] == str(p_old.id) + + def test_given_both_then_version_id_takes_precedence( + self, + client, + session, + us_two_versions, # noqa: F811 + ): + """version_id overrides model_name.""" + v1, v2 = us_two_versions + p_old = create_parameter(session, v1, "gov.old", "Old") + p_new = create_parameter(session, v2, "gov.new", "New") + create_parameter_value(session, p_old.id, 0.1) + create_parameter_value(session, p_new.id, 0.2) + + data = client.get( + f"/parameter-values?tax_benefit_model_name={MODEL_NAMES['US']}" + f"&tax_benefit_model_version_id={v1.id}" + ).json() + assert len(data) == 1 + assert data[0]["parameter_id"] == str(p_old.id) + + def test_given_no_filters_then_returns_all_versions( + self, + client, + session, + us_two_versions, # noqa: F811 + ): + """Without model/version filter, values from all versions are returned.""" + v1, v2 = us_two_versions + p_old = create_parameter(session, v1, "gov.old", "Old") + p_new = create_parameter(session, v2, "gov.new", "New") + create_parameter_value(session, p_old.id, 0.1) + create_parameter_value(session, p_new.id, 0.2) + + data = client.get("/parameter-values").json() + assert len(data) == 2 + + def test_given_nonexistent_model_name_then_returns_404(self, client): + """Unknown model name returns 404.""" + response = client.get( + "/parameter-values?tax_benefit_model_name=nonexistent-model" + ) + assert response.status_code == 404 + + def test_given_version_filter_combined_with_parameter_id( + self, + client, + session, + us_two_versions, # noqa: F811 + ): + """Version filter + parameter_id work together.""" + v1, v2 = us_two_versions + p1 = create_parameter(session, v2, "gov.rate", "Rate") + p2 = create_parameter(session, v2, "gov.threshold", "Threshold") + create_parameter_value(session, p1.id, 0.2) + create_parameter_value(session, p2.id, 12570) + + data = client.get( + f"/parameter-values?tax_benefit_model_name={MODEL_NAMES['US']}" + f"¶meter_id={p1.id}" + ).json() + assert len(data) == 1 + assert data[0]["parameter_id"] == str(p1.id) + + +# ----------------------------------------------------------------------------- +# GET /parameter-values/{id} +# ----------------------------------------------------------------------------- + + +class TestGetParameterValue: + def test_given_valid_id_then_returns_value( + self, + client, + session, + us_version, # noqa: F811 + ): + """Returns the full parameter value data.""" + param = create_parameter(session, us_version, "gov.rate", "Rate") + pv = create_parameter_value(session, param.id, 0.2) + + response = client.get(f"/parameter-values/{pv.id}") + assert response.status_code == 200 + assert response.json()["parameter_id"] == str(param.id) + + def test_given_nonexistent_id_then_returns_404(self, client): + """Unknown UUID returns 404.""" + response = client.get(f"/parameter-values/{uuid4()}") + assert response.status_code == 404 + + def test_response_shape_matches_parameter_value_read( + self, + client, + session, + us_version, # noqa: F811 + ): + """Response contains all ParameterValueRead fields.""" + param = create_parameter(session, us_version, "gov.rate", "Rate") + pv = create_parameter_value(session, param.id, 0.2) + + data = client.get(f"/parameter-values/{pv.id}").json() + for field in ("id", "parameter_id", "value_json", "start_date", "created_at"): + assert field in data diff --git a/tests/test_parameters.py b/tests/test_parameters.py index 50bb213..4ee41ba 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -1,204 +1,256 @@ -"""Tests for parameter and parameter-value endpoints.""" +"""Tests for GET /parameters/ and GET /parameters/{id} endpoints.""" from uuid import uuid4 -import pytest - -from test_fixtures.fixtures_parameters import ( +from test_fixtures.fixtures_version_filter import ( + MODEL_NAMES, create_parameter, - create_parameter_value, - create_parameter_values_batch, - create_policy, - model_version, # noqa: F401 - pytest fixture + us_model, # noqa: F401 + us_two_versions, # noqa: F401 + us_version, # noqa: F401 ) # ----------------------------------------------------------------------------- -# Parameter Endpoint Tests +# GET /parameters/ — basic # ----------------------------------------------------------------------------- -def test__given_parameters_endpoint_called__then_returns_list(client): - """GET /parameters returns a list.""" - # Given - endpoint = "/parameters" - - # When - response = client.get(endpoint) - - # Then - assert response.status_code == 200 - assert isinstance(response.json(), list) - - -def test__given_nonexistent_parameter_id__then_returns_404(client): - """GET /parameters/{id} returns 404 for non-existent parameter.""" - # Given - fake_id = uuid4() - - # When - response = client.get(f"/parameters/{fake_id}") - - # Then - assert response.status_code == 404 - - -# ----------------------------------------------------------------------------- -# Parameter Value Endpoint Tests -# ----------------------------------------------------------------------------- - - -def test__given_parameter_values_endpoint_called__then_returns_list(client): - """GET /parameter-values returns a list.""" - # Given - endpoint = "/parameter-values" - - # When - response = client.get(endpoint) - - # Then - assert response.status_code == 200 - assert isinstance(response.json(), list) - - -def test__given_nonexistent_parameter_value_id__then_returns_404(client): - """GET /parameter-values/{id} returns 404 for non-existent parameter value.""" - # Given - fake_id = uuid4() - - # When - response = client.get(f"/parameter-values/{fake_id}") - - # Then - assert response.status_code == 404 +class TestListParameters: + def test_given_no_params_then_returns_empty_list(self, client): + """Empty database returns an empty list.""" + response = client.get("/parameters") + assert response.status_code == 200 + assert response.json() == [] + + def test_given_parameters_exist_then_returns_list( + self, + client, + session, + us_version, # noqa: F811 + ): + """Returns parameters that exist.""" + create_parameter(session, us_version, "gov.rate", "Rate") + response = client.get("/parameters") + assert response.status_code == 200 + assert len(response.json()) == 1 + + def test_given_search_by_name_then_returns_matching( + self, + client, + session, + us_version, # noqa: F811 + ): + """Search filter matches parameter name.""" + create_parameter(session, us_version, "gov.tax.rate", "Rate") + create_parameter(session, us_version, "gov.benefit.amount", "Amount") + + response = client.get("/parameters?search=tax") + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["name"] == "gov.tax.rate" + + def test_given_search_by_label_then_returns_matching( + self, + client, + session, + us_version, # noqa: F811 + ): + """Search filter matches parameter label (case-insensitive).""" + create_parameter(session, us_version, "gov.x", "Basic Rate") + create_parameter(session, us_version, "gov.y", "Amount") + + response = client.get("/parameters?search=basic") + data = response.json() + assert len(data) == 1 + assert data[0]["label"] == "Basic Rate" + + def test_given_search_by_description_then_returns_matching( + self, + client, + session, + us_version, # noqa: F811 + ): + """Search filter matches parameter description.""" + create_parameter( + session, us_version, "gov.x", "X", description="The income tax rate" + ) + create_parameter(session, us_version, "gov.y", "Y", description="A benefit") + + response = client.get("/parameters?search=income") + data = response.json() + assert len(data) == 1 + assert data[0]["name"] == "gov.x" + + def test_given_limit_then_returns_at_most_n( + self, + client, + session, + us_version, # noqa: F811 + ): + """Limit caps the number of results.""" + for i in range(5): + create_parameter(session, us_version, f"gov.p{i}", f"P{i}") + + response = client.get("/parameters?limit=2") + assert len(response.json()) == 2 + + def test_given_skip_then_skips_first_n( + self, + client, + session, + us_version, # noqa: F811 + ): + """Skip omits the first N results.""" + for i in range(5): + create_parameter(session, us_version, f"gov.p{i}", f"P{i}") + + response = client.get("/parameters?skip=3&limit=10") + assert len(response.json()) == 2 + + def test_results_ordered_by_name( + self, + client, + session, + us_version, # noqa: F811 + ): + """Parameters come back sorted alphabetically by name.""" + create_parameter(session, us_version, "gov.zzz", "Z") + create_parameter(session, us_version, "gov.aaa", "A") + names = [p["name"] for p in client.get("/parameters").json()] + assert names == ["gov.aaa", "gov.zzz"] # ----------------------------------------------------------------------------- -# Parameter Value Filtering Tests +# GET /parameters/ — version filtering # ----------------------------------------------------------------------------- -def test__given_parameter_id_filter__then_returns_only_matching_values( - client, - session, - model_version, # noqa: F811 -): - """GET /parameter-values?parameter_id=X returns only values for that parameter.""" - # Given - param1 = create_parameter(session, model_version, "test.param1", "Test Param 1") - param2 = create_parameter(session, model_version, "test.param2", "Test Param 2") - create_parameter_value(session, param1.id, 100) - create_parameter_value(session, param2.id, 200) - - # When - response = client.get(f"/parameter-values?parameter_id={param1.id}") - - # Then - assert response.status_code == 200 - data = response.json() - assert len(data) == 1 - assert data[0]["parameter_id"] == str(param1.id) - - -def test__given_policy_id_filter__then_returns_only_matching_values( - client, - session, - model_version, # noqa: F811 -): - """GET /parameter-values?policy_id=X returns only values for that policy.""" - # Given - param = create_parameter(session, model_version, "test.param", "Test Param") - policy = create_policy(session, "Test Policy", model_version) - create_parameter_value(session, param.id, 100, policy_id=None) # baseline - create_parameter_value(session, param.id, 150, policy_id=policy.id) # reform - - # When - response = client.get(f"/parameter-values?policy_id={policy.id}") - - # Then - assert response.status_code == 200 - data = response.json() - assert len(data) == 1 - assert data[0]["policy_id"] == str(policy.id) - assert data[0]["value_json"] == 150 - - -def test__given_both_parameter_and_policy_filters__then_returns_matching_intersection( - client, - session, - model_version, # noqa: F811 -): - """GET /parameter-values?parameter_id=X&policy_id=Y returns intersection.""" - # Given - param1 = create_parameter( - session, model_version, "test.both.param1", "Test Both Param 1" - ) - param2 = create_parameter( - session, model_version, "test.both.param2", "Test Both Param 2" - ) - policy = create_policy(session, "Test Both Policy", model_version) - - create_parameter_value(session, param1.id, 100, policy_id=None) # baseline - create_parameter_value(session, param1.id, 150, policy_id=policy.id) # target - create_parameter_value(session, param2.id, 200, policy_id=policy.id) # other - - # When - response = client.get( - f"/parameter-values?parameter_id={param1.id}&policy_id={policy.id}" - ) - - # Then - assert response.status_code == 200 - data = response.json() - assert len(data) == 1 - assert data[0]["parameter_id"] == str(param1.id) - assert data[0]["policy_id"] == str(policy.id) - assert data[0]["value_json"] == 150 +class TestListParametersVersionFilter: + def test_given_model_name_then_returns_only_latest_version( + self, + client, + session, + us_two_versions, # noqa: F811 + ): + """Model name resolves to latest version; old-version params excluded.""" + v1, v2 = us_two_versions + create_parameter(session, v1, "gov.old", "Old") + create_parameter(session, v2, "gov.new", "New") + + data = client.get( + f"/parameters?tax_benefit_model_name={MODEL_NAMES['US']}" + ).json() + assert len(data) == 1 + assert data[0]["name"] == "gov.new" + + def test_given_explicit_version_id_then_returns_that_version( + self, + client, + session, + us_two_versions, # noqa: F811 + ): + """Explicit version_id pins to a specific version.""" + v1, v2 = us_two_versions + create_parameter(session, v1, "gov.old", "Old") + create_parameter(session, v2, "gov.new", "New") + + data = client.get(f"/parameters?tax_benefit_model_version_id={v1.id}").json() + assert len(data) == 1 + assert data[0]["name"] == "gov.old" + + def test_given_both_then_version_id_takes_precedence( + self, + client, + session, + us_two_versions, # noqa: F811 + ): + """version_id overrides model_name.""" + v1, v2 = us_two_versions + create_parameter(session, v1, "gov.old", "Old") + create_parameter(session, v2, "gov.new", "New") + + data = client.get( + f"/parameters?tax_benefit_model_name={MODEL_NAMES['US']}" + f"&tax_benefit_model_version_id={v1.id}" + ).json() + assert len(data) == 1 + assert data[0]["name"] == "gov.old" + + def test_given_no_filters_then_returns_all_versions( + self, + client, + session, + us_two_versions, # noqa: F811 + ): + """Without model/version filter, params from all versions are returned.""" + v1, v2 = us_two_versions + create_parameter(session, v1, "gov.old", "Old") + create_parameter(session, v2, "gov.new", "New") + + data = client.get("/parameters").json() + assert len(data) == 2 + + def test_given_nonexistent_model_name_then_returns_404(self, client): + """Unknown model name → 404.""" + response = client.get("/parameters?tax_benefit_model_name=nonexistent-model") + assert response.status_code == 404 + + def test_given_search_combined_with_version_filter( + self, + client, + session, + us_two_versions, # noqa: F811 + ): + """Search + version filter work together.""" + v1, v2 = us_two_versions + create_parameter(session, v2, "gov.tax.rate", "Rate") + create_parameter(session, v2, "gov.benefit.amount", "Amount") + + data = client.get( + f"/parameters?tax_benefit_model_name={MODEL_NAMES['US']}&search=tax" + ).json() + assert len(data) == 1 + assert data[0]["name"] == "gov.tax.rate" # ----------------------------------------------------------------------------- -# Parameter Value Pagination Tests +# GET /parameters/{id} # ----------------------------------------------------------------------------- -def test__given_limit_parameter__then_returns_limited_results( - client, - session, - model_version, # noqa: F811 -): - """GET /parameter-values?limit=N returns at most N results.""" - # Given - param = create_parameter( - session, model_version, "test.pagination.param", "Test Pagination Param" - ) - create_parameter_values_batch(session, param.id, count=5) - - # When - response = client.get(f"/parameter-values?parameter_id={param.id}&limit=2") - - # Then - assert response.status_code == 200 - assert len(response.json()) == 2 - - -def test__given_skip_parameter__then_skips_specified_results( - client, - session, - model_version, # noqa: F811 -): - """GET /parameter-values?skip=N skips first N results.""" - # Given - param = create_parameter( - session, model_version, "test.skip.param", "Test Skip Param" - ) - create_parameter_values_batch(session, param.id, count=5) - - # When - response = client.get(f"/parameter-values?parameter_id={param.id}&skip=3&limit=10") - - # Then - assert response.status_code == 200 - assert len(response.json()) == 2 # 5 total - 3 skipped = 2 remaining - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) +class TestGetParameter: + def test_given_valid_id_then_returns_parameter( + self, + client, + session, + us_version, # noqa: F811 + ): + """Returns the full parameter data.""" + param = create_parameter(session, us_version, "gov.rate", "Rate") + response = client.get(f"/parameters/{param.id}") + assert response.status_code == 200 + assert response.json()["name"] == "gov.rate" + + def test_given_nonexistent_id_then_returns_404(self, client): + """Unknown UUID → 404.""" + response = client.get(f"/parameters/{uuid4()}") + assert response.status_code == 404 + + def test_response_shape_matches_parameter_read( + self, + client, + session, + us_version, # noqa: F811 + ): + """Response contains all ParameterRead fields.""" + param = create_parameter(session, us_version, "gov.rate", "Rate") + data = client.get(f"/parameters/{param.id}").json() + for field in ( + "id", + "name", + "label", + "created_at", + "tax_benefit_model_version_id", + ): + assert field in data diff --git a/tests/test_parameters_by_name.py b/tests/test_parameters_by_name.py index 81bd360..5d089e4 100644 --- a/tests/test_parameters_by_name.py +++ b/tests/test_parameters_by_name.py @@ -1,238 +1,257 @@ """Tests for POST /parameters/by-name endpoint.""" -import pytest -from policyengine_api.models import ( - Parameter, - TaxBenefitModel, - TaxBenefitModelVersion, +from test_fixtures.fixtures_version_filter import ( + MODEL_NAMES, + create_parameter, + uk_model, # noqa: F401 + uk_version, # noqa: F401 + us_model, # noqa: F401 + us_two_versions, # noqa: F401 + us_version, # noqa: F401 ) -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - -@pytest.fixture -def us_version(session): - """Create a policyengine-us model and version.""" - model = TaxBenefitModel(name="policyengine-us", description="US model") - session.add(model) - session.commit() - session.refresh(model) - - version = TaxBenefitModelVersion( - model_id=model.id, version="1.0", description="US v1" - ) - session.add(version) - session.commit() - session.refresh(version) - return version - - -def create_parameter(session, model_version, name: str, label: str) -> Parameter: - """Create and persist a Parameter.""" - param = Parameter( - name=name, - label=label, - tax_benefit_model_version_id=model_version.id, - ) - session.add(param) - session.commit() - session.refresh(param) - return param +# ----------------------------------------------------------------------------- +# Happy-path lookups +# ----------------------------------------------------------------------------- class TestParametersByName: - """Tests for looking up parameters by their exact names.""" - - def test_returns_matching_parameters(self, client, session, us_version): - """Given known parameter names, returns their full metadata.""" - create_parameter(session, us_version, "gov.tax.rate", "Tax rate") + def test_given_known_names_then_returns_matching( + self, + client, + session, + us_version, # noqa: F811 + ): + """Returns full metadata for each matching name.""" + create_parameter(session, us_version, "gov.tax.rate", "Rate") create_parameter(session, us_version, "gov.tax.threshold", "Threshold") - response = client.post( + data = client.post( "/parameters/by-name", json={ "names": ["gov.tax.rate", "gov.tax.threshold"], - "country_id": "us", + "tax_benefit_model_name": MODEL_NAMES["US"], }, - ) + ).json() - assert response.status_code == 200 - data = response.json() assert len(data) == 2 - returned_names = {p["name"] for p in data} - assert returned_names == {"gov.tax.rate", "gov.tax.threshold"} + assert {p["name"] for p in data} == {"gov.tax.rate", "gov.tax.threshold"} - def test_returns_empty_list_for_empty_names(self, client): - """Given an empty names list, returns an empty list.""" - response = client.post( + def test_given_empty_names_then_returns_empty_list( + self, + client, + session, + us_version, # noqa: F811 + ): + """Empty names list → empty response (no DB query).""" + data = client.post( "/parameters/by-name", - json={ - "names": [], - "country_id": "us", - }, - ) - - assert response.status_code == 200 - assert response.json() == [] - - def test_returns_empty_list_for_unknown_names(self, client, session, us_version): - """Given names that don't match any parameter, returns an empty list.""" + json={"names": [], "tax_benefit_model_name": MODEL_NAMES["US"]}, + ).json() + assert data == [] + + def test_given_unknown_names_then_returns_empty_list( + self, + client, + session, + us_version, # noqa: F811 + ): + """Names that don't match anything → empty list.""" create_parameter(session, us_version, "gov.exists", "Exists") - response = client.post( + data = client.post( "/parameters/by-name", json={ - "names": ["gov.does_not_exist", "gov.also_missing"], - "country_id": "us", + "names": ["gov.nope", "gov.also_missing"], + "tax_benefit_model_name": MODEL_NAMES["US"], }, - ) - - assert response.status_code == 200 - assert response.json() == [] - - def test_returns_only_matching_when_mix_of_known_and_unknown( - self, client, session, us_version + ).json() + assert data == [] + + def test_given_mixed_names_then_returns_only_known( + self, + client, + session, + us_version, # noqa: F811 ): - """Given a mix of known and unknown names, returns only the known ones.""" - create_parameter(session, us_version, "gov.real", "Real param") + """Only matching names are returned; unknowns silently omitted.""" + create_parameter(session, us_version, "gov.real", "Real") - response = client.post( + data = client.post( "/parameters/by-name", json={ "names": ["gov.real", "gov.fake"], - "country_id": "us", + "tax_benefit_model_name": MODEL_NAMES["US"], }, - ) - - assert response.status_code == 200 - data = response.json() + ).json() assert len(data) == 1 assert data[0]["name"] == "gov.real" - def test_filters_by_country(self, client, session): - """Parameters from a different country are excluded.""" - # Create two models - model_uk = TaxBenefitModel(name="policyengine-uk", description="UK") - model_us = TaxBenefitModel(name="policyengine-us", description="US") - session.add(model_uk) - session.add(model_us) - session.commit() - session.refresh(model_uk) - session.refresh(model_us) - - ver_uk = TaxBenefitModelVersion( - model_id=model_uk.id, version="1.0", description="UK v1" - ) - ver_us = TaxBenefitModelVersion( - model_id=model_us.id, version="1.0", description="US v1" - ) - session.add(ver_uk) - session.add(ver_us) - session.commit() - session.refresh(ver_uk) - session.refresh(ver_us) - - # Same parameter name in both models - create_parameter(session, ver_uk, "gov.shared_name", "UK version") - create_parameter(session, ver_us, "gov.shared_name", "US version") - - # Request only UK - response = client.post( + def test_given_single_name_then_returns_one( + self, + client, + session, + us_version, # noqa: F811 + ): + """Single-element lookup works.""" + create_parameter(session, us_version, "gov.single", "Single") + data = client.post( "/parameters/by-name", json={ - "names": ["gov.shared_name"], - "country_id": "uk", + "names": ["gov.single"], + "tax_benefit_model_name": MODEL_NAMES["US"], }, - ) - - assert response.status_code == 200 - data = response.json() + ).json() assert len(data) == 1 - assert data[0]["label"] == "UK version" - def test_response_shape_matches_parameter_read(self, client, session, us_version): - """Returned objects have the same shape as ParameterRead.""" - create_parameter(session, us_version, "gov.shape_test", "Shape test") + def test_results_ordered_by_name( + self, + client, + session, + us_version, # noqa: F811 + ): + """Response is sorted alphabetically by name.""" + create_parameter(session, us_version, "gov.zzz", "Z") + create_parameter(session, us_version, "gov.aaa", "A") + create_parameter(session, us_version, "gov.mmm", "M") + + names = [ + p["name"] + for p in client.post( + "/parameters/by-name", + json={ + "names": ["gov.zzz", "gov.aaa", "gov.mmm"], + "tax_benefit_model_name": MODEL_NAMES["US"], + }, + ).json() + ] + assert names == ["gov.aaa", "gov.mmm", "gov.zzz"] - response = client.post( + def test_response_shape(self, client, session, us_version): # noqa: F811 + """Each returned object has the ParameterRead fields.""" + create_parameter(session, us_version, "gov.shape", "Shape") + param = client.post( "/parameters/by-name", json={ - "names": ["gov.shape_test"], - "country_id": "us", + "names": ["gov.shape"], + "tax_benefit_model_name": MODEL_NAMES["US"], }, - ) - - assert response.status_code == 200 - data = response.json() - assert len(data) == 1 - param = data[0] - assert "id" in param - assert "name" in param - assert "label" in param - assert "created_at" in param - assert "tax_benefit_model_version_id" in param - - def test_results_ordered_by_name(self, client, session, us_version): - """Returned parameters are sorted alphabetically by name.""" - create_parameter(session, us_version, "gov.zzz", "Last") - create_parameter(session, us_version, "gov.aaa", "First") - create_parameter(session, us_version, "gov.mmm", "Middle") + ).json()[0] + for field in ( + "id", + "name", + "label", + "created_at", + "tax_benefit_model_version_id", + ): + assert field in param + + +# ----------------------------------------------------------------------------- +# Model isolation +# ----------------------------------------------------------------------------- + + +class TestParametersByNameModelIsolation: + def test_given_two_models_then_returns_only_requested( + self, + client, + session, + us_version, # noqa: F811 + uk_version, # noqa: F811 + ): + """Parameters from the other model are excluded.""" + create_parameter(session, us_version, "gov.shared", "US") + create_parameter(session, uk_version, "gov.shared", "UK") - response = client.post( + data = client.post( "/parameters/by-name", json={ - "names": ["gov.zzz", "gov.aaa", "gov.mmm"], - "country_id": "us", + "names": ["gov.shared"], + "tax_benefit_model_name": MODEL_NAMES["UK"], }, - ) + ).json() + assert len(data) == 1 + assert data[0]["label"] == "UK" - assert response.status_code == 200 - names = [p["name"] for p in response.json()] - assert names == ["gov.aaa", "gov.mmm", "gov.zzz"] - def test_missing_country_id_returns_422(self, client): - """Request without country_id is rejected.""" - response = client.post( - "/parameters/by-name", - json={"names": ["gov.something"]}, - ) +# ----------------------------------------------------------------------------- +# Validation +# ----------------------------------------------------------------------------- + +class TestParametersByNameValidation: + def test_given_missing_model_name_then_422(self, client): + """Omitting tax_benefit_model_name → 422.""" + response = client.post("/parameters/by-name", json={"names": ["gov.x"]}) assert response.status_code == 422 - def test_invalid_country_id_returns_422(self, client): - """Request with invalid country_id is rejected.""" + def test_given_missing_names_then_422(self, client): + """Omitting names → 422.""" response = client.post( "/parameters/by-name", - json={"names": ["gov.something"], "country_id": "invalid"}, + json={"tax_benefit_model_name": MODEL_NAMES["US"]}, ) - assert response.status_code == 422 - def test_missing_names_field_returns_422(self, client): - """Request without names field is rejected.""" + def test_given_nonexistent_model_name_then_404(self, client, session): + """Model that doesn't exist → 404 from resolve_model_version_id.""" response = client.post( "/parameters/by-name", - json={"country_id": "us"}, + json={ + "names": ["gov.x"], + "tax_benefit_model_name": "nonexistent-model", + }, ) + assert response.status_code == 404 - assert response.status_code == 422 - def test_single_name_lookup(self, client, session, us_version): - """Looking up a single parameter name works.""" - create_parameter(session, us_version, "gov.single", "Single param") +# ----------------------------------------------------------------------------- +# Version filtering +# ----------------------------------------------------------------------------- - response = client.post( + +class TestParametersByNameVersionFilter: + def test_given_model_name_only_then_defaults_to_latest( + self, + client, + session, + us_two_versions, # noqa: F811 + ): + """Model name resolves to latest version.""" + v1, v2 = us_two_versions + create_parameter(session, v1, "gov.old", "Old") + create_parameter(session, v2, "gov.new", "New") + + data = client.post( "/parameters/by-name", json={ - "names": ["gov.single"], - "country_id": "us", + "names": ["gov.old", "gov.new"], + "tax_benefit_model_name": MODEL_NAMES["US"], }, - ) + ).json() + assert len(data) == 1 + assert data[0]["name"] == "gov.new" + + def test_given_explicit_version_id_then_returns_that_version( + self, + client, + session, + us_two_versions, # noqa: F811 + ): + """Explicit version_id overrides latest-version default.""" + v1, v2 = us_two_versions + create_parameter(session, v1, "gov.old", "Old") + create_parameter(session, v2, "gov.new", "New") - assert response.status_code == 200 - data = response.json() + data = client.post( + "/parameters/by-name", + json={ + "names": ["gov.old", "gov.new"], + "tax_benefit_model_name": MODEL_NAMES["US"], + "tax_benefit_model_version_id": str(v1.id), + }, + ).json() assert len(data) == 1 - assert data[0]["name"] == "gov.single" + assert data[0]["name"] == "gov.old" diff --git a/tests/test_parameters_children.py b/tests/test_parameters_children.py index d788179..07a4af1 100644 --- a/tests/test_parameters_children.py +++ b/tests/test_parameters_children.py @@ -1,76 +1,30 @@ """Tests for GET /parameters/children endpoint.""" -import pytest -from policyengine_api.models import ( - Parameter, - TaxBenefitModel, - TaxBenefitModelVersion, +from test_fixtures.fixtures_version_filter import ( + MODEL_NAMES, + add_params_bulk, + uk_model, # noqa: F401 + uk_two_versions, # noqa: F401 + uk_version, # noqa: F401 + us_model, # noqa: F401 + us_version, # noqa: F401 ) -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - -@pytest.fixture -def uk_version(session): - """Create a policyengine-uk model and version.""" - model = TaxBenefitModel(name="policyengine-uk", description="UK model") - session.add(model) - session.commit() - session.refresh(model) - - version = TaxBenefitModelVersion( - model_id=model.id, version="1.0", description="UK v1" - ) - session.add(version) - session.commit() - session.refresh(version) - return version - - -@pytest.fixture -def us_version(session): - """Create a policyengine-us model and version.""" - model = TaxBenefitModel(name="policyengine-us", description="US model") - session.add(model) - session.commit() - session.refresh(model) - - version = TaxBenefitModelVersion( - model_id=model.id, version="1.0", description="US v1" - ) - session.add(version) - session.commit() - session.refresh(version) - return version - - -def _add_params(session, version, names_and_labels): - """Bulk-add parameters. names_and_labels is [(name, label), ...].""" - for name, label in names_and_labels: - session.add( - Parameter( - name=name, - label=label, - tax_benefit_model_version_id=version.id, - ) - ) - session.commit() - - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- +# Tree structure +# ----------------------------------------------------------------------------- class TestParameterChildrenBasic: - """Basic tree structure tests.""" - - def test_returns_nodes_for_intermediate_paths(self, client, session, uk_version): + def test_returns_nodes_for_intermediate_paths( + self, + client, + session, + uk_version, # noqa: F811 + ): """Parameters at gov.hmrc.x and gov.dwp.x produce nodes for hmrc and dwp.""" - _add_params( + add_params_bulk( session, uk_version, [ @@ -81,7 +35,11 @@ def test_returns_nodes_for_intermediate_paths(self, client, session, uk_version) ) response = client.get( - "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + "/parameters/children", + params={ + "tax_benefit_model_name": MODEL_NAMES["UK"], + "parent_path": "gov", + }, ) assert response.status_code == 200 @@ -95,9 +53,14 @@ def test_returns_nodes_for_intermediate_paths(self, client, session, uk_version) assert child["type"] == "node" assert child["child_count"] > 0 - def test_returns_leaf_parameters(self, client, session, uk_version): + def test_returns_leaf_parameters( + self, + client, + session, + uk_version, # noqa: F811 + ): """Direct child parameters are returned with type='parameter'.""" - _add_params( + add_params_bulk( session, uk_version, [ @@ -107,7 +70,11 @@ def test_returns_leaf_parameters(self, client, session, uk_version): ) response = client.get( - "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + "/parameters/children", + params={ + "tax_benefit_model_name": MODEL_NAMES["UK"], + "parent_path": "gov", + }, ) assert response.status_code == 200 @@ -123,9 +90,14 @@ def test_returns_leaf_parameters(self, client, session, uk_version): node = next(c for c in children if c["type"] == "node") assert node["path"] == "gov.hmrc" - def test_mixed_nodes_and_leaves(self, client, session, uk_version): + def test_mixed_nodes_and_leaves( + self, + client, + session, + uk_version, # noqa: F811 + ): """Both nodes and leaf parameters can appear at the same level.""" - _add_params( + add_params_bulk( session, uk_version, [ @@ -135,23 +107,34 @@ def test_mixed_nodes_and_leaves(self, client, session, uk_version): ], ) - response = client.get( - "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} - ) + children = client.get( + "/parameters/children", + params={ + "tax_benefit_model_name": MODEL_NAMES["UK"], + "parent_path": "gov", + }, + ).json()["children"] - children = response.json()["children"] types = {c["path"]: c["type"] for c in children} assert types["gov.hmrc"] == "node" assert types["gov.flat_rate"] == "parameter" assert types["gov.threshold"] == "parameter" -class TestChildCount: - """Tests for child_count accuracy.""" +# ----------------------------------------------------------------------------- +# Child counts +# ----------------------------------------------------------------------------- + - def test_child_count_reflects_total_descendants(self, client, session, uk_version): +class TestChildCount: + def test_child_count_reflects_total_descendants( + self, + client, + session, + uk_version, # noqa: F811 + ): """child_count counts all leaf parameters under the node.""" - _add_params( + add_params_bulk( session, uk_version, [ @@ -161,18 +144,26 @@ def test_child_count_reflects_total_descendants(self, client, session, uk_versio ], ) - response = client.get( - "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} - ) + children = client.get( + "/parameters/children", + params={ + "tax_benefit_model_name": MODEL_NAMES["UK"], + "parent_path": "gov", + }, + ).json()["children"] - children = response.json()["children"] hmrc = children[0] assert hmrc["path"] == "gov.hmrc" assert hmrc["child_count"] == 3 - def test_nested_child_count(self, client, session, uk_version): + def test_nested_child_count( + self, + client, + session, + uk_version, # noqa: F811 + ): """Querying a deeper level gives accurate child counts.""" - _add_params( + add_params_bulk( session, uk_version, [ @@ -182,92 +173,156 @@ def test_nested_child_count(self, client, session, uk_version): ], ) - response = client.get( + children = client.get( "/parameters/children", - params={"country_id": "uk", "parent_path": "gov.hmrc"}, - ) + params={ + "tax_benefit_model_name": MODEL_NAMES["UK"], + "parent_path": "gov.hmrc", + }, + ).json()["children"] - children = response.json()["children"] assert len(children) == 2 income_tax = next(c for c in children if c["path"] == "gov.hmrc.income_tax") ni = next(c for c in children if c["path"] == "gov.hmrc.ni") assert income_tax["child_count"] == 2 assert ni["child_count"] == 1 - def test_leaf_has_no_child_count(self, client, session, uk_version): + def test_leaf_has_no_child_count( + self, + client, + session, + uk_version, # noqa: F811 + ): """Leaf parameters have child_count=None.""" - _add_params(session, uk_version, [("gov.rate", "Rate")]) + add_params_bulk(session, uk_version, [("gov.rate", "Rate")]) - response = client.get( - "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} - ) + children = client.get( + "/parameters/children", + params={ + "tax_benefit_model_name": MODEL_NAMES["UK"], + "parent_path": "gov", + }, + ).json()["children"] - children = response.json()["children"] assert len(children) == 1 assert children[0]["child_count"] is None -class TestCountryFiltering: - """Tests for country_id filtering.""" +# ----------------------------------------------------------------------------- +# Model isolation +# ----------------------------------------------------------------------------- + - def test_uk_country_id(self, client, session, uk_version): - """country_id=uk returns UK parameters.""" - _add_params(session, uk_version, [("gov.hmrc.rate", "Rate")]) +class TestParameterChildrenModelIsolation: + def test_given_uk_model_then_returns_uk_params( + self, + client, + session, + uk_version, # noqa: F811 + ): + """policyengine-uk returns UK parameters.""" + add_params_bulk(session, uk_version, [("gov.hmrc.rate", "Rate")]) response = client.get( - "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + "/parameters/children", + params={ + "tax_benefit_model_name": MODEL_NAMES["UK"], + "parent_path": "gov", + }, ) assert response.status_code == 200 assert len(response.json()["children"]) == 1 - def test_us_country_id(self, client, session, us_version): - """country_id=us returns US parameters.""" - _add_params(session, us_version, [("gov.irs.rate", "Rate")]) + def test_given_us_model_then_returns_us_params( + self, + client, + session, + us_version, # noqa: F811 + ): + """policyengine-us returns US parameters.""" + add_params_bulk(session, us_version, [("gov.irs.rate", "Rate")]) response = client.get( - "/parameters/children", params={"country_id": "us", "parent_path": "gov"} + "/parameters/children", + params={ + "tax_benefit_model_name": MODEL_NAMES["US"], + "parent_path": "gov", + }, ) assert response.status_code == 200 assert len(response.json()["children"]) == 1 - def test_country_isolation(self, client, session, uk_version, us_version): - """Parameters from a different country are excluded.""" - _add_params(session, uk_version, [("gov.hmrc.rate", "UK rate")]) - _add_params(session, us_version, [("gov.irs.rate", "US rate")]) - - uk_response = client.get( - "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} - ) - us_response = client.get( - "/parameters/children", params={"country_id": "us", "parent_path": "gov"} - ) + def test_given_two_models_then_returns_only_requested( + self, + client, + session, + uk_version, # noqa: F811 + us_version, # noqa: F811 + ): + """Parameters from a different model are excluded.""" + add_params_bulk(session, uk_version, [("gov.hmrc.rate", "UK rate")]) + add_params_bulk(session, us_version, [("gov.irs.rate", "US rate")]) + + uk_paths = [ + c["path"] + for c in client.get( + "/parameters/children", + params={ + "tax_benefit_model_name": MODEL_NAMES["UK"], + "parent_path": "gov", + }, + ).json()["children"] + ] + us_paths = [ + c["path"] + for c in client.get( + "/parameters/children", + params={ + "tax_benefit_model_name": MODEL_NAMES["US"], + "parent_path": "gov", + }, + ).json()["children"] + ] - uk_paths = [c["path"] for c in uk_response.json()["children"]] - us_paths = [c["path"] for c in us_response.json()["children"]] assert uk_paths == ["gov.hmrc"] assert us_paths == ["gov.irs"] - def test_invalid_country_id_returns_422(self, client): - """An invalid country_id is rejected by Literal validation.""" - response = client.get( - "/parameters/children", - params={"country_id": "fr", "parent_path": "gov"}, - ) +# ----------------------------------------------------------------------------- +# Validation +# ----------------------------------------------------------------------------- + + +class TestParameterChildrenValidation: + def test_given_missing_model_name_then_422(self, client): + """Request without tax_benefit_model_name returns 422.""" + response = client.get("/parameters/children", params={"parent_path": "gov"}) assert response.status_code == 422 -class TestEdgeCases: - """Tests for edge cases and special inputs.""" +# ----------------------------------------------------------------------------- +# Edge cases +# ----------------------------------------------------------------------------- + - def test_empty_parent_path(self, client, session, uk_version): +class TestParameterChildrenEdgeCases: + def test_empty_parent_path( + self, + client, + session, + uk_version, # noqa: F811 + ): """Empty parent_path returns top-level children.""" - _add_params(session, uk_version, [("gov.hmrc.rate", "Rate")]) + add_params_bulk(session, uk_version, [("gov.hmrc.rate", "Rate")]) response = client.get( - "/parameters/children", params={"country_id": "uk", "parent_path": ""} + "/parameters/children", + params={ + "tax_benefit_model_name": MODEL_NAMES["UK"], + "parent_path": "", + }, ) assert response.status_code == 200 @@ -276,21 +331,33 @@ def test_empty_parent_path(self, client, session, uk_version): assert children[0]["path"] == "gov" assert children[0]["type"] == "node" - def test_nonexistent_parent_returns_empty(self, client, session, uk_version): + def test_nonexistent_parent_returns_empty( + self, + client, + session, + uk_version, # noqa: F811 + ): """A parent path with no descendants returns empty children list.""" - _add_params(session, uk_version, [("gov.hmrc.rate", "Rate")]) + add_params_bulk(session, uk_version, [("gov.hmrc.rate", "Rate")]) - response = client.get( + children = client.get( "/parameters/children", - params={"country_id": "uk", "parent_path": "gov.dwp"}, - ) + params={ + "tax_benefit_model_name": MODEL_NAMES["UK"], + "parent_path": "gov.dwp", + }, + ).json()["children"] - assert response.status_code == 200 - assert response.json()["children"] == [] + assert children == [] - def test_children_sorted_by_path(self, client, session, uk_version): + def test_children_sorted_by_path( + self, + client, + session, + uk_version, # noqa: F811 + ): """Children are returned sorted alphabetically by path.""" - _add_params( + add_params_bulk( session, uk_version, [ @@ -300,100 +367,186 @@ def test_children_sorted_by_path(self, client, session, uk_version): ], ) - response = client.get( - "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} - ) - - paths = [c["path"] for c in response.json()["children"]] + paths = [ + c["path"] + for c in client.get( + "/parameters/children", + params={ + "tax_benefit_model_name": MODEL_NAMES["UK"], + "parent_path": "gov", + }, + ).json()["children"] + ] assert paths == ["gov.aaa", "gov.mmm", "gov.zzz"] - def test_node_label_from_path_segment(self, client, session, uk_version): - """Node labels default to the last path segment when no parameter exists.""" - _add_params(session, uk_version, [("gov.hmrc.income_tax.rate", "Rate")]) + def test_node_label_from_path_segment( + self, + client, + session, + uk_version, # noqa: F811 + ): + """Node labels default to the last path segment.""" + add_params_bulk(session, uk_version, [("gov.hmrc.income_tax.rate", "Rate")]) - response = client.get( - "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} - ) + children = client.get( + "/parameters/children", + params={ + "tax_benefit_model_name": MODEL_NAMES["UK"], + "parent_path": "gov", + }, + ).json()["children"] - children = response.json()["children"] assert children[0]["label"] == "hmrc" - def test_missing_country_id_returns_422(self, client): - """Request without country_id returns 422.""" - response = client.get("/parameters/children", params={"parent_path": "gov"}) - - assert response.status_code == 422 - - def test_default_parent_path_is_empty(self, client, session, uk_version): + def test_default_parent_path_is_empty( + self, + client, + session, + uk_version, # noqa: F811 + ): """Omitting parent_path defaults to empty string (root level).""" - _add_params(session, uk_version, [("gov.hmrc.rate", "Rate")]) - - response = client.get("/parameters/children", params={"country_id": "uk"}) - - assert response.status_code == 200 - assert response.json()["parent_path"] == "" - assert len(response.json()["children"]) == 1 + add_params_bulk(session, uk_version, [("gov.hmrc.rate", "Rate")]) - def test_leaf_parameter_includes_full_metadata(self, client, session, uk_version): + data = client.get( + "/parameters/children", + params={"tax_benefit_model_name": MODEL_NAMES["UK"]}, + ).json() + + assert data["parent_path"] == "" + assert len(data["children"]) == 1 + + def test_leaf_parameter_includes_full_metadata( + self, + client, + session, + uk_version, # noqa: F811 + ): """Leaf parameters include the full ParameterRead shape.""" - _add_params(session, uk_version, [("gov.rate", "The rate")]) + add_params_bulk(session, uk_version, [("gov.rate", "The rate")]) - response = client.get( - "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} - ) + param = client.get( + "/parameters/children", + params={ + "tax_benefit_model_name": MODEL_NAMES["UK"], + "parent_path": "gov", + }, + ).json()["children"][0]["parameter"] - param = response.json()["children"][0]["parameter"] assert param["name"] == "gov.rate" assert param["label"] == "The rate" - assert "id" in param - assert "created_at" in param - assert "tax_benefit_model_version_id" in param - - def test_node_has_no_parameter_field(self, client, session, uk_version): + for field in ("id", "created_at", "tax_benefit_model_version_id"): + assert field in param + + def test_node_has_no_parameter_field( + self, + client, + session, + uk_version, # noqa: F811 + ): """Nodes do not include the parameter field.""" - _add_params(session, uk_version, [("gov.hmrc.rate", "Rate")]) + add_params_bulk(session, uk_version, [("gov.hmrc.rate", "Rate")]) - response = client.get( - "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} - ) + node = client.get( + "/parameters/children", + params={ + "tax_benefit_model_name": MODEL_NAMES["UK"], + "parent_path": "gov", + }, + ).json()["children"][0] - node = response.json()["children"][0] assert node["type"] == "node" assert node["parameter"] is None - def test_deep_nesting(self, client, session, uk_version): + def test_deep_nesting( + self, + client, + session, + uk_version, # noqa: F811 + ): """Works correctly with deeply nested parameter paths.""" - _add_params( + add_params_bulk( session, uk_version, [("gov.hmrc.income_tax.rates.uk[0].rate", "Basic rate")], ) - # Each level should show the correct child for parent, expected_child in [ ("gov", "gov.hmrc"), ("gov.hmrc", "gov.hmrc.income_tax"), ("gov.hmrc.income_tax", "gov.hmrc.income_tax.rates"), ("gov.hmrc.income_tax.rates", "gov.hmrc.income_tax.rates.uk[0]"), ]: - resp = client.get( + children = client.get( "/parameters/children", - params={"country_id": "uk", "parent_path": parent}, - ) - children = resp.json()["children"] + params={ + "tax_benefit_model_name": MODEL_NAMES["UK"], + "parent_path": parent, + }, + ).json()["children"] assert len(children) == 1 assert children[0]["path"] == expected_child assert children[0]["type"] == "node" # Final level should be a leaf - resp = client.get( + children = client.get( "/parameters/children", params={ - "country_id": "uk", + "tax_benefit_model_name": MODEL_NAMES["UK"], "parent_path": "gov.hmrc.income_tax.rates.uk[0]", }, - ) - children = resp.json()["children"] + ).json()["children"] assert len(children) == 1 assert children[0]["type"] == "parameter" assert children[0]["path"] == "gov.hmrc.income_tax.rates.uk[0].rate" + + +# ----------------------------------------------------------------------------- +# Version filtering +# ----------------------------------------------------------------------------- + + +class TestParameterChildrenVersionFilter: + def test_given_model_name_only_then_defaults_to_latest( + self, + client, + session, + uk_two_versions, # noqa: F811 + ): + """When only model name is given, returns children from latest version.""" + v1, v2 = uk_two_versions + add_params_bulk(session, v1, [("gov.old_param", "Old")]) + add_params_bulk(session, v2, [("gov.new_param", "New")]) + + children = client.get( + "/parameters/children", + params={ + "tax_benefit_model_name": MODEL_NAMES["UK"], + "parent_path": "gov", + }, + ).json()["children"] + + assert len(children) == 1 + assert children[0]["path"] == "gov.new_param" + + def test_given_explicit_version_id_then_returns_that_version( + self, + client, + session, + uk_two_versions, # noqa: F811 + ): + """When version ID is given, returns children from that specific version.""" + v1, v2 = uk_two_versions + add_params_bulk(session, v1, [("gov.old_param", "Old")]) + add_params_bulk(session, v2, [("gov.new_param", "New")]) + + children = client.get( + "/parameters/children", + params={ + "tax_benefit_model_name": MODEL_NAMES["UK"], + "parent_path": "gov", + "tax_benefit_model_version_id": str(v1.id), + }, + ).json()["children"] + + assert len(children) == 1 + assert children[0]["path"] == "gov.old_param" diff --git a/tests/test_variables.py b/tests/test_variables.py index 76cb47a..11e1314 100644 --- a/tests/test_variables.py +++ b/tests/test_variables.py @@ -1,23 +1,236 @@ -"""Tests for variable endpoints.""" +"""Tests for GET /variables/ and GET /variables/{id} endpoints.""" from uuid import uuid4 -import pytest +from test_fixtures.fixtures_version_filter import ( + MODEL_NAMES, + create_variable, + us_model, # noqa: F401 + us_two_versions, # noqa: F401 + us_version, # noqa: F401 +) +# ----------------------------------------------------------------------------- +# GET /variables/ — basic +# ----------------------------------------------------------------------------- -def test_list_variables(client): - """List variables returns a list.""" - response = client.get("/variables") - assert response.status_code == 200 - assert isinstance(response.json(), list) +class TestListVariables: + def test_given_no_variables_then_returns_empty_list(self, client): + """Empty database returns an empty list.""" + response = client.get("/variables") + assert response.status_code == 200 + assert response.json() == [] -def test_get_variable_not_found(client): - """Get non-existent variable returns 404.""" - fake_id = uuid4() - response = client.get(f"/variables/{fake_id}") - assert response.status_code == 404 + def test_given_variables_exist_then_returns_list( + self, + client, + session, + us_version, # noqa: F811 + ): + """Returns variables that exist.""" + create_variable(session, us_version, "employment_income") + response = client.get("/variables") + assert response.status_code == 200 + assert len(response.json()) == 1 + def test_given_search_by_name_then_returns_matching( + self, + client, + session, + us_version, # noqa: F811 + ): + """Search filter matches variable name.""" + create_variable(session, us_version, "employment_income") + create_variable(session, us_version, "income_tax") -if __name__ == "__main__": - pytest.main([__file__, "-v"]) + data = client.get("/variables?search=employment").json() + assert len(data) == 1 + assert data[0]["name"] == "employment_income" + + def test_given_search_by_description_then_returns_matching( + self, + client, + session, + us_version, # noqa: F811 + ): + """Search filter matches variable description.""" + create_variable( + session, us_version, "var_x", description="Total household income" + ) + create_variable(session, us_version, "var_y", description="Tax liability") + + data = client.get("/variables?search=household").json() + assert len(data) == 1 + assert data[0]["name"] == "var_x" + + def test_given_limit_then_returns_at_most_n( + self, + client, + session, + us_version, # noqa: F811 + ): + """Limit caps the number of results.""" + for i in range(5): + create_variable(session, us_version, f"var_{i}") + + assert len(client.get("/variables?limit=2").json()) == 2 + + def test_given_skip_then_skips_first_n( + self, + client, + session, + us_version, # noqa: F811 + ): + """Skip omits the first N results.""" + for i in range(5): + create_variable(session, us_version, f"var_{i}") + + assert len(client.get("/variables?skip=3&limit=10").json()) == 2 + + def test_results_ordered_by_name( + self, + client, + session, + us_version, # noqa: F811 + ): + """Variables come back sorted alphabetically by name.""" + create_variable(session, us_version, "zzz_var") + create_variable(session, us_version, "aaa_var") + names = [v["name"] for v in client.get("/variables").json()] + assert names == ["aaa_var", "zzz_var"] + + +# ----------------------------------------------------------------------------- +# GET /variables/ — version filtering +# ----------------------------------------------------------------------------- + + +class TestListVariablesVersionFilter: + def test_given_model_name_then_returns_only_latest_version( + self, + client, + session, + us_two_versions, # noqa: F811 + ): + """Model name resolves to latest version; old-version vars excluded.""" + v1, v2 = us_two_versions + create_variable(session, v1, "old_variable") + create_variable(session, v2, "new_variable") + + data = client.get( + f"/variables?tax_benefit_model_name={MODEL_NAMES['US']}" + ).json() + assert len(data) == 1 + assert data[0]["name"] == "new_variable" + + def test_given_explicit_version_id_then_returns_that_version( + self, + client, + session, + us_two_versions, # noqa: F811 + ): + """Explicit version_id pins to a specific version.""" + v1, v2 = us_two_versions + create_variable(session, v1, "old_variable") + create_variable(session, v2, "new_variable") + + data = client.get(f"/variables?tax_benefit_model_version_id={v1.id}").json() + assert len(data) == 1 + assert data[0]["name"] == "old_variable" + + def test_given_both_then_version_id_takes_precedence( + self, + client, + session, + us_two_versions, # noqa: F811 + ): + """version_id overrides model_name.""" + v1, v2 = us_two_versions + create_variable(session, v1, "old_variable") + create_variable(session, v2, "new_variable") + + data = client.get( + f"/variables?tax_benefit_model_name={MODEL_NAMES['US']}" + f"&tax_benefit_model_version_id={v1.id}" + ).json() + assert len(data) == 1 + assert data[0]["name"] == "old_variable" + + def test_given_no_filters_then_returns_all_versions( + self, + client, + session, + us_two_versions, # noqa: F811 + ): + """Without model/version filter, vars from all versions are returned.""" + v1, v2 = us_two_versions + create_variable(session, v1, "old_variable") + create_variable(session, v2, "new_variable") + + data = client.get("/variables").json() + assert len(data) == 2 + + def test_given_nonexistent_model_name_then_returns_404(self, client): + """Unknown model name returns 404.""" + response = client.get("/variables?tax_benefit_model_name=nonexistent-model") + assert response.status_code == 404 + + def test_given_search_combined_with_version_filter( + self, + client, + session, + us_two_versions, # noqa: F811 + ): + """Search + version filter work together.""" + v1, v2 = us_two_versions + create_variable(session, v2, "employment_income") + create_variable(session, v2, "income_tax") + + data = client.get( + f"/variables?tax_benefit_model_name={MODEL_NAMES['US']}&search=employment" + ).json() + assert len(data) == 1 + assert data[0]["name"] == "employment_income" + + +# ----------------------------------------------------------------------------- +# GET /variables/{id} +# ----------------------------------------------------------------------------- + + +class TestGetVariable: + def test_given_valid_id_then_returns_variable( + self, + client, + session, + us_version, # noqa: F811 + ): + """Returns the full variable data.""" + var = create_variable(session, us_version, "employment_income") + response = client.get(f"/variables/{var.id}") + assert response.status_code == 200 + assert response.json()["name"] == "employment_income" + + def test_given_nonexistent_id_then_returns_404(self, client): + """Unknown UUID returns 404.""" + response = client.get(f"/variables/{uuid4()}") + assert response.status_code == 404 + + def test_response_shape_matches_variable_read( + self, + client, + session, + us_version, # noqa: F811 + ): + """Response contains all VariableRead fields.""" + var = create_variable(session, us_version, "employment_income") + data = client.get(f"/variables/{var.id}").json() + for field in ( + "id", + "name", + "entity", + "created_at", + "tax_benefit_model_version_id", + ): + assert field in data diff --git a/tests/test_variables_by_name.py b/tests/test_variables_by_name.py index 3639fea..0f713bc 100644 --- a/tests/test_variables_by_name.py +++ b/tests/test_variables_by_name.py @@ -1,228 +1,270 @@ """Tests for POST /variables/by-name endpoint.""" -import pytest -from policyengine_api.models import ( - TaxBenefitModel, - TaxBenefitModelVersion, - Variable, +from test_fixtures.fixtures_version_filter import ( + MODEL_NAMES, + create_variable, + uk_model, # noqa: F401 + uk_two_versions, # noqa: F401 + uk_version, # noqa: F401 + us_model, # noqa: F401 + us_version, # noqa: F401 ) -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - -@pytest.fixture -def uk_version(session): - """Create a policyengine-uk model and version.""" - model = TaxBenefitModel(name="policyengine-uk", description="UK model") - session.add(model) - session.commit() - session.refresh(model) - - version = TaxBenefitModelVersion( - model_id=model.id, version="1.0", description="UK v1" - ) - session.add(version) - session.commit() - session.refresh(version) - return version - - -@pytest.fixture -def us_version(session): - """Create a policyengine-us model and version.""" - model = TaxBenefitModel(name="policyengine-us", description="US model") - session.add(model) - session.commit() - session.refresh(model) - - version = TaxBenefitModelVersion( - model_id=model.id, version="1.0", description="US v1" - ) - session.add(version) - session.commit() - session.refresh(version) - return version - - -def _add_var(session, version, name, entity="person", description=None): - """Create and persist a Variable.""" - var = Variable( - name=name, - entity=entity, - description=description, - tax_benefit_model_version_id=version.id, - ) - session.add(var) - session.commit() - session.refresh(var) - return var - - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- +# Happy-path lookups +# ----------------------------------------------------------------------------- class TestVariablesByName: - """Tests for looking up variables by their exact names.""" - - def test_returns_matching_variables(self, client, session, uk_version): - """Given known variable names, returns their full metadata.""" - _add_var(session, uk_version, "employment_income") - _add_var(session, uk_version, "income_tax") + def test_given_known_names_then_returns_matching( + self, + client, + session, + uk_version, # noqa: F811 + ): + """Returns full metadata for each matching name.""" + create_variable(session, uk_version, "employment_income") + create_variable(session, uk_version, "income_tax") - response = client.post( + data = client.post( "/variables/by-name", - json={"names": ["employment_income", "income_tax"], "country_id": "uk"}, - ) + json={ + "names": ["employment_income", "income_tax"], + "tax_benefit_model_name": MODEL_NAMES["UK"], + }, + ).json() - assert response.status_code == 200 - data = response.json() assert len(data) == 2 - returned_names = {v["name"] for v in data} - assert returned_names == {"employment_income", "income_tax"} + assert {v["name"] for v in data} == {"employment_income", "income_tax"} - def test_returns_empty_list_for_empty_names(self, client): - """Given an empty names list, returns an empty list.""" - response = client.post( + def test_given_empty_names_then_returns_empty_list( + self, + client, + session, + uk_version, # noqa: F811 + ): + """Empty names list returns empty response (no DB query).""" + data = client.post( "/variables/by-name", - json={"names": [], "country_id": "uk"}, - ) - - assert response.status_code == 200 - assert response.json() == [] - - def test_returns_empty_list_for_unknown_names(self, client, session, uk_version): - """Given names that don't match any variable, returns an empty list.""" - _add_var(session, uk_version, "employment_income") + json={"names": [], "tax_benefit_model_name": MODEL_NAMES["UK"]}, + ).json() + assert data == [] + + def test_given_unknown_names_then_returns_empty_list( + self, + client, + session, + uk_version, # noqa: F811 + ): + """Names that don't match anything return empty list.""" + create_variable(session, uk_version, "employment_income") - response = client.post( + data = client.post( "/variables/by-name", - json={"names": ["nonexistent_var", "also_missing"], "country_id": "uk"}, - ) - - assert response.status_code == 200 - assert response.json() == [] - - def test_returns_only_matching_when_mix_of_known_and_unknown( - self, client, session, uk_version + json={ + "names": ["nonexistent_var", "also_missing"], + "tax_benefit_model_name": MODEL_NAMES["UK"], + }, + ).json() + assert data == [] + + def test_given_mixed_names_then_returns_only_known( + self, + client, + session, + uk_version, # noqa: F811 ): - """Given a mix of known and unknown names, returns only the known ones.""" - _add_var(session, uk_version, "income_tax") + """Only matching names are returned; unknowns silently omitted.""" + create_variable(session, uk_version, "income_tax") - response = client.post( + data = client.post( "/variables/by-name", - json={"names": ["income_tax", "fake_var"], "country_id": "uk"}, - ) - - assert response.status_code == 200 - data = response.json() + json={ + "names": ["income_tax", "fake_var"], + "tax_benefit_model_name": MODEL_NAMES["UK"], + }, + ).json() assert len(data) == 1 assert data[0]["name"] == "income_tax" - def test_single_name_lookup(self, client, session, uk_version): - """Looking up a single variable name works.""" - _add_var(session, uk_version, "age") - - response = client.post( + def test_given_single_name_then_returns_one( + self, + client, + session, + uk_version, # noqa: F811 + ): + """Single-element lookup works.""" + create_variable(session, uk_version, "age") + data = client.post( "/variables/by-name", - json={"names": ["age"], "country_id": "uk"}, - ) - - assert response.status_code == 200 - data = response.json() + json={ + "names": ["age"], + "tax_benefit_model_name": MODEL_NAMES["UK"], + }, + ).json() assert len(data) == 1 - assert data[0]["name"] == "age" - def test_results_ordered_by_name(self, client, session, uk_version): - """Returned variables are sorted alphabetically by name.""" - _add_var(session, uk_version, "zzz_var") - _add_var(session, uk_version, "aaa_var") - _add_var(session, uk_version, "mmm_var") + def test_results_ordered_by_name( + self, + client, + session, + uk_version, # noqa: F811 + ): + """Response is sorted alphabetically by name.""" + create_variable(session, uk_version, "zzz_var") + create_variable(session, uk_version, "aaa_var") + create_variable(session, uk_version, "mmm_var") + + names = [ + v["name"] + for v in client.post( + "/variables/by-name", + json={ + "names": ["zzz_var", "aaa_var", "mmm_var"], + "tax_benefit_model_name": MODEL_NAMES["UK"], + }, + ).json() + ] + assert names == ["aaa_var", "mmm_var", "zzz_var"] - response = client.post( + def test_response_shape(self, client, session, uk_version): # noqa: F811 + """Each returned object has the VariableRead fields.""" + create_variable( + session, uk_version, "income_tax", entity="person", description="Tax" + ) + var = client.post( "/variables/by-name", json={ - "names": ["zzz_var", "aaa_var", "mmm_var"], - "country_id": "uk", + "names": ["income_tax"], + "tax_benefit_model_name": MODEL_NAMES["UK"], }, - ) - - assert response.status_code == 200 - names = [v["name"] for v in response.json()] - assert names == ["aaa_var", "mmm_var", "zzz_var"] - - def test_response_shape_matches_variable_read(self, client, session, uk_version): - """Returned objects have the same shape as VariableRead.""" - _add_var(session, uk_version, "income_tax", entity="person", description="Tax") + ).json()[0] + for field in ( + "id", + "name", + "entity", + "description", + "created_at", + "tax_benefit_model_version_id", + ): + assert field in var + + +# ----------------------------------------------------------------------------- +# Model isolation +# ----------------------------------------------------------------------------- + + +class TestVariablesByNameModelIsolation: + def test_given_two_models_then_returns_only_requested( + self, + client, + session, + uk_version, # noqa: F811 + us_version, # noqa: F811 + ): + """Variables from the other model are excluded.""" + create_variable(session, uk_version, "council_tax") + create_variable(session, us_version, "state_income_tax") - response = client.post( + uk_data = client.post( "/variables/by-name", - json={"names": ["income_tax"], "country_id": "uk"}, - ) + json={ + "names": ["council_tax", "state_income_tax"], + "tax_benefit_model_name": MODEL_NAMES["UK"], + }, + ).json() + us_data = client.post( + "/variables/by-name", + json={ + "names": ["council_tax", "state_income_tax"], + "tax_benefit_model_name": MODEL_NAMES["US"], + }, + ).json() - assert response.status_code == 200 - var = response.json()[0] - assert "id" in var - assert "name" in var - assert "entity" in var - assert "description" in var - assert "created_at" in var - assert "tax_benefit_model_version_id" in var + assert len(uk_data) == 1 + assert uk_data[0]["name"] == "council_tax" + assert len(us_data) == 1 + assert us_data[0]["name"] == "state_income_tax" -class TestVariablesByNameCountryFiltering: - """Tests for country_id filtering.""" +# ----------------------------------------------------------------------------- +# Validation +# ----------------------------------------------------------------------------- - def test_country_isolation(self, client, session, uk_version, us_version): - """Variables from a different country are excluded.""" - _add_var(session, uk_version, "council_tax") - _add_var(session, us_version, "state_income_tax") - uk_response = client.post( - "/variables/by-name", - json={"names": ["council_tax", "state_income_tax"], "country_id": "uk"}, - ) - us_response = client.post( +class TestVariablesByNameValidation: + def test_given_missing_model_name_then_422(self, client): + """Omitting tax_benefit_model_name returns 422.""" + response = client.post("/variables/by-name", json={"names": ["income_tax"]}) + assert response.status_code == 422 + + def test_given_missing_names_then_422(self, client): + """Omitting names returns 422.""" + response = client.post( "/variables/by-name", - json={"names": ["council_tax", "state_income_tax"], "country_id": "us"}, + json={"tax_benefit_model_name": MODEL_NAMES["UK"]}, ) + assert response.status_code == 422 - assert len(uk_response.json()) == 1 - assert uk_response.json()[0]["name"] == "council_tax" - assert len(us_response.json()) == 1 - assert us_response.json()[0]["name"] == "state_income_tax" - - def test_invalid_country_id_returns_422(self, client): - """An invalid country_id is rejected.""" + def test_given_nonexistent_model_name_then_404(self, client, session): + """Model that doesn't exist returns 404.""" response = client.post( "/variables/by-name", - json={"names": ["income_tax"], "country_id": "fr"}, + json={ + "names": ["income_tax"], + "tax_benefit_model_name": "nonexistent-model", + }, ) + assert response.status_code == 404 - assert response.status_code == 422 +# ----------------------------------------------------------------------------- +# Version filtering +# ----------------------------------------------------------------------------- -class TestVariablesByNameValidation: - """Tests for request validation.""" - def test_missing_country_id_returns_422(self, client): - """Request without country_id is rejected.""" - response = client.post( +class TestVariablesByNameVersionFilter: + def test_given_model_name_only_then_defaults_to_latest( + self, + client, + session, + uk_two_versions, # noqa: F811 + ): + """Model name resolves to latest version.""" + v1, v2 = uk_two_versions + create_variable(session, v1, "old_var") + create_variable(session, v2, "new_var") + + data = client.post( "/variables/by-name", - json={"names": ["income_tax"]}, - ) + json={ + "names": ["old_var", "new_var"], + "tax_benefit_model_name": MODEL_NAMES["UK"], + }, + ).json() + assert len(data) == 1 + assert data[0]["name"] == "new_var" - assert response.status_code == 422 + def test_given_explicit_version_id_then_returns_that_version( + self, + client, + session, + uk_two_versions, # noqa: F811 + ): + """Explicit version_id overrides latest-version default.""" + v1, v2 = uk_two_versions + create_variable(session, v1, "old_var") + create_variable(session, v2, "new_var") - def test_missing_names_field_returns_422(self, client): - """Request without names field is rejected.""" - response = client.post( + data = client.post( "/variables/by-name", - json={"country_id": "uk"}, - ) - - assert response.status_code == 422 + json={ + "names": ["old_var", "new_var"], + "tax_benefit_model_name": MODEL_NAMES["UK"], + "tax_benefit_model_version_id": str(v1.id), + }, + ).json() + assert len(data) == 1 + assert data[0]["name"] == "old_var" diff --git a/tests/test_version_filter_service.py b/tests/test_version_filter_service.py new file mode 100644 index 0000000..b734775 --- /dev/null +++ b/tests/test_version_filter_service.py @@ -0,0 +1,140 @@ +"""Tests for the tax benefit model version resolution service.""" + +from uuid import uuid4 + +import pytest +from fastapi import HTTPException + +from policyengine_api.services.tax_benefit_models import ( + get_latest_model_version, + get_model_version_by_id, + resolve_model_version_id, +) +from test_fixtures.fixtures_version_filter import ( + MODEL_NAMES, + create_model, + create_version, + us_model, # noqa: F401 + us_two_versions, # noqa: F401 +) + +# --------------------------------------------------------------------------- +# get_latest_model_version +# --------------------------------------------------------------------------- + + +class TestGetLatestModelVersion: + def test_given_multiple_versions_then_returns_newest( + self, + session, + us_model, # noqa: F811 + us_two_versions, # noqa: F811 + ): + """Returns the version with the most recent created_at.""" + _v1, v2 = us_two_versions + result = get_latest_model_version(MODEL_NAMES["US"], session) + assert result.id == v2.id + assert result.version == "2.0" + + def test_given_underscore_name_then_normalizes_to_hyphens( + self, + session, + us_model, # noqa: F811 + us_two_versions, # noqa: F811 + ): + """'policyengine_us' is normalised to 'policyengine-us'.""" + _v1, v2 = us_two_versions + result = get_latest_model_version("policyengine_us", session) + assert result.id == v2.id + + def test_given_nonexistent_model_then_raises_404(self, session): + """Unknown model name raises HTTPException 404.""" + with pytest.raises(HTTPException) as exc_info: + get_latest_model_version("nonexistent-model", session) + assert exc_info.value.status_code == 404 + + def test_given_model_without_versions_then_raises_404( + self, + session, + us_model, # noqa: F811 + ): + """Model that exists but has zero versions raises 404.""" + with pytest.raises(HTTPException) as exc_info: + get_latest_model_version(MODEL_NAMES["US"], session) + assert exc_info.value.status_code == 404 + + def test_given_single_version_then_returns_it(self, session): + """With only one version, that version is returned.""" + model = create_model(session) + only = create_version(session, model, "0.1") + result = get_latest_model_version(MODEL_NAMES["US"], session) + assert result.id == only.id + + +# --------------------------------------------------------------------------- +# get_model_version_by_id +# --------------------------------------------------------------------------- + + +class TestGetModelVersionById: + def test_given_valid_id_then_returns_version( + self, + session, + us_two_versions, # noqa: F811 + ): + """Returns the matching version.""" + v1, _v2 = us_two_versions + result = get_model_version_by_id(v1.id, session) + assert result.id == v1.id + assert result.version == "1.0" + + def test_given_nonexistent_id_then_raises_404(self, session): + """Unknown UUID raises HTTPException 404.""" + with pytest.raises(HTTPException) as exc_info: + get_model_version_by_id(uuid4(), session) + assert exc_info.value.status_code == 404 + + +# --------------------------------------------------------------------------- +# resolve_model_version_id +# --------------------------------------------------------------------------- + + +class TestResolveModelVersionId: + def test_given_version_id_then_takes_precedence_over_model_name( + self, + session, + us_two_versions, # noqa: F811 + ): + """Explicit version_id wins over model_name.""" + v1, _v2 = us_two_versions + result = resolve_model_version_id(MODEL_NAMES["US"], v1.id, session) + assert result == v1.id + + def test_given_only_model_name_then_resolves_to_latest( + self, + session, + us_model, # noqa: F811 + us_two_versions, # noqa: F811 + ): + """Model name alone returns the latest version's ID.""" + _v1, v2 = us_two_versions + result = resolve_model_version_id(MODEL_NAMES["US"], None, session) + assert result == v2.id + + def test_given_neither_then_returns_none(self, session): + """No model name and no version ID → None (no filtering).""" + result = resolve_model_version_id(None, None, session) + assert result is None + + def test_given_invalid_version_id_then_raises_404(self, session): + """Non-existent explicit version_id raises 404.""" + with pytest.raises(HTTPException) as exc_info: + resolve_model_version_id(None, uuid4(), session) + assert exc_info.value.status_code == 404 + + def test_given_invalid_model_name_then_raises_404(self, session): + """Non-existent model name raises 404.""" + with pytest.raises(HTTPException) as exc_info: + resolve_model_version_id("does-not-exist", None, session) + assert exc_info.value.status_code == 404