diff --git a/alembic/versions/20260305_add_variable_label.py b/alembic/versions/20260305_add_variable_label.py
new file mode 100644
index 0000000..c6ea9b6
--- /dev/null
+++ b/alembic/versions/20260305_add_variable_label.py
@@ -0,0 +1,36 @@
+"""Add label column to variables table
+
+Revision ID: add_variable_label
+Revises: 886921687770
+Create Date: 2026-03-05
+
+Variables now carry a human-readable label sourced from OpenFisca's
+Variable.label class attribute (e.g. "Employment income"). Previously
+labels were auto-generated on the frontend from the snake_case name.
+"""
+
+from typing import Sequence, Union
+
+import sqlalchemy as sa
+import sqlmodel.sql.sqltypes
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "add_variable_label"
+down_revision: Union[str, Sequence[str], None] = "886921687770"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    """Add label column to variables table."""
+    op.add_column(
+        "variables",
+        sa.Column("label", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
+    )
+
+
+def downgrade() -> None:
+    """Remove label column from variables table."""
+    op.drop_column("variables", "label")
diff --git a/scripts/seed_models.py b/scripts/seed_models.py
index 049558e..74f78f2 100644
--- a/scripts/seed_models.py
+++ b/scripts/seed_models.py
@@ -125,6 +125,7 @@ def seed_model(
                     {
                         "id": uuid4(),
                         "name": var.name,
+                        "label": getattr(var, "label", None) or "",
                         "entity": var.entity,
                         "description": var.description or "",
                         "data_type": var.data_type.__name__
@@ -144,6 +145,7 @@ def seed_model(
                 [
                     "id",
                     "name",
+                    "label",
                     "entity",
                     "description",
                     "data_type",
diff --git a/src/policyengine_api/api/variables.py b/src/policyengine_api/api/variables.py
index 04aa512..e592820 100644
--- a/src/policyengine_api/api/variables.py
+++ b/src/policyengine_api/api/variables.py
@@ -56,11 +56,12 @@ def list_variables(
 
     if search:
         # Case-insensitive search using ILIKE
-        # Note: Variables don't have a label field, only name and description
         search_pattern = f"%{search}%"
-        search_filter = Variable.name.ilike(
-            search_pattern
-        ) | Variable.description.ilike(search_pattern)
+        search_filter = (
+            Variable.name.ilike(search_pattern)
+            | Variable.label.ilike(search_pattern)
+            | Variable.description.ilike(search_pattern)
+        )
         query = query.where(search_filter)
 
     variables = session.exec(
diff --git a/src/policyengine_api/models/variable.py b/src/policyengine_api/models/variable.py
index eeebddc..8ee6b94 100644
--- a/src/policyengine_api/models/variable.py
+++ b/src/policyengine_api/models/variable.py
@@ -12,6 +12,7 @@ class VariableBase(SQLModel):
     """Base variable fields."""
 
     name: str
+    label: str | None = None
     entity: str
     description: str | None = None
     data_type: str | None = None  # Store as string representation
diff --git a/test_fixtures/fixtures_variables.py b/test_fixtures/fixtures_variables.py
new file mode 100644
index 0000000..9a0b052
--- /dev/null
+++ b/test_fixtures/fixtures_variables.py
@@ -0,0 +1,82 @@
+"""Fixtures and helpers for variable-related tests."""
+
+import pytest
+
+from policyengine_api.models import (
+    TaxBenefitModel,
+    TaxBenefitModelVersion,
+    Variable,
+)
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+
+def _create_model_version(session, country: str) -> TaxBenefitModelVersion:
+    """Create and persist a ``policyengine-<country>`` model with one version.
+
+    The model is committed and refreshed first because its ``id`` is used
+    as the version's ``model_id``.
+    """
+    model = TaxBenefitModel(
+        name=f"policyengine-{country}",
+        description=f"{country.upper()} model",
+    )
+    session.add(model)
+    session.commit()
+    session.refresh(model)
+
+    version = TaxBenefitModelVersion(
+        model_id=model.id,
+        version="1.0.0",
+        description=f"Test {country.upper()} version",
+    )
+    session.add(version)
+    session.commit()
+    session.refresh(version)
+    return version
+
+
+@pytest.fixture
+def us_model_version(session):
+    """Create a policyengine-us model and version for testing."""
+    return _create_model_version(session, "us")
+
+
+@pytest.fixture
+def uk_model_version(session):
+    """Create a policyengine-uk model and version for testing."""
+    return _create_model_version(session, "uk")
+
+
+# ---------------------------------------------------------------------------
+# Factory Functions
+# ---------------------------------------------------------------------------
+
+
+def create_variable(
+    session,
+    model_version,
+    name: str,
+    label: str | None = None,
+    entity: str = "person",
+    description: str | None = None,
+    data_type: str | None = "float",
+) -> Variable:
+    """Create and persist a Variable attached to ``model_version``.
+
+    The row is committed and refreshed before being returned.
+    """
+    var = Variable(
+        name=name,
+        label=label,
+        entity=entity,
+        description=description,
+        data_type=data_type,
+        tax_benefit_model_version_id=model_version.id,
+    )
+    session.add(var)
+    session.commit()
+    session.refresh(var)
+    return var
diff --git a/tests/test_variable_labels.py b/tests/test_variable_labels.py
new file mode 100644
index 0000000..a99382b
--- /dev/null
+++ b/tests/test_variable_labels.py
@@ -0,0 +1,310 @@
+"""Tests for variable label field across all variable endpoints."""
+
+import pytest
+
+from test_fixtures.fixtures_variables import (  # noqa: F811
+    create_variable,
+    uk_model_version,
+    us_model_version,
+)
+
+
+# ---------------------------------------------------------------------------
+# GET /variables - label in list responses
+# ---------------------------------------------------------------------------
+
+
+class TestListVariablesLabel:
+    """Tests that label is returned when listing variables."""
+
+    def test_label_returned_in_response(
+        self, client, session, us_model_version  # noqa: F811
+    ):
+        """Variable with a label should include it in the list response."""
+        create_variable(
+            session,
+            us_model_version,
+            name="employment_income",
+            label="Employment income",
+        )
+
+        response = client.get("/variables")
+        assert response.status_code == 200
+        data = response.json()
+        assert len(data) == 1
+        assert data[0]["label"] == "Employment income"
+
+    def test_null_label_returned_when_absent(
+        self, client, session, us_model_version  # noqa: F811
+    ):
+        """Variable without a label should return null."""
+        create_variable(
+            session,
+            us_model_version,
+            name="age",
+            label=None,
+        )
+
+        response = client.get("/variables")
+        assert response.status_code == 200
+        data = response.json()
+        assert len(data) == 1
+        assert data[0]["label"] is None
+
+    def test_empty_label_returned(
+        self, client, session, us_model_version  # noqa: F811
+    ):
+        """Variable with an empty string label should return it as-is."""
+        create_variable(
+            session,
+            us_model_version,
+            name="household_weight",
+            label="",
+        )
+
+        response = client.get("/variables")
+        assert response.status_code == 200
+        data = response.json()
+        assert len(data) == 1
+        assert data[0]["label"] == ""
+
+
+# ---------------------------------------------------------------------------
+# GET /variables?search= - search by label
+# ---------------------------------------------------------------------------
+
+
+class TestSearchVariablesByLabel:
+    """Tests that the search parameter matches against labels.
+
+    Search terms contain a space so they can only match a label: the
+    snake_case names use underscores and no description is set, so a hit
+    cannot come from the name or description columns.
+    """
+
+    def test_search_matches_label(
+        self, client, session, us_model_version  # noqa: F811
+    ):
+        """Searching for a term in the label should return the variable."""
+        create_variable(
+            session,
+            us_model_version,
+            name="employment_income",
+            label="Employment income",
+        )
+        create_variable(
+            session,
+            us_model_version,
+            name="age",
+            label="Age of person",
+        )
+
+        response = client.get(
+            "/variables",
+            params={
+                "search": "Employment income",
+                "tax_benefit_model_name": "policyengine-us",
+            },
+        )
+        assert response.status_code == 200
+        data = response.json()
+        assert len(data) == 1
+        assert data[0]["name"] == "employment_income"
+
+    def test_search_label_case_insensitive(
+        self, client, session, us_model_version  # noqa: F811
+    ):
+        """Label search should be case-insensitive."""
+        create_variable(
+            session,
+            us_model_version,
+            name="income_tax",
+            label="Income tax",
+        )
+
+        response = client.get(
+            "/variables",
+            params={
+                "search": "INCOME TAX",
+                "tax_benefit_model_name": "policyengine-us",
+            },
+        )
+        assert response.status_code == 200
+        assert len(response.json()) == 1
+
+    def test_search_partial_label_match(
+        self, client, session, us_model_version  # noqa: F811
+    ):
+        """Partial label matches should be returned."""
+        create_variable(
+            session,
+            us_model_version,
+            name="state_income_tax",
+            label="State income tax",
+        )
+
+        response = client.get(
+            "/variables",
+            params={
+                "search": "income tax",
+                "tax_benefit_model_name": "policyengine-us",
+            },
+        )
+        assert response.status_code == 200
+        assert len(response.json()) == 1
+
+
+# ---------------------------------------------------------------------------
+# GET /variables/{id} - label in single variable response
+# ---------------------------------------------------------------------------
+
+
+class TestGetVariableLabel:
+    """Tests that label is returned when fetching a single variable."""
+
+    def test_label_in_get_response(
+        self, client, session, us_model_version  # noqa: F811
+    ):
+        """GET /variables/{id} should include the label field."""
+        var = create_variable(
+            session,
+            us_model_version,
+            name="employment_income",
+            label="Employment income",
+        )
+
+        response = client.get(f"/variables/{var.id}")
+        assert response.status_code == 200
+        assert response.json()["label"] == "Employment income"
+
+    def test_null_label_in_get_response(
+        self, client, session, us_model_version  # noqa: F811
+    ):
+        """GET /variables/{id} should return null for missing label."""
+        var = create_variable(
+            session,
+            us_model_version,
+            name="age",
+            label=None,
+        )
+
+        response = client.get(f"/variables/{var.id}")
+        assert response.status_code == 200
+        assert response.json()["label"] is None
+
+
+# ---------------------------------------------------------------------------
+# POST /variables/by-name - label in batch lookup
+# ---------------------------------------------------------------------------
+
+
+class TestVariablesByNameLabel:
+    """Tests that label is included in by-name lookup responses."""
+
+    def test_label_in_by_name_response(
+        self, client, session, us_model_version  # noqa: F811
+    ):
+        """POST /variables/by-name should include the label field."""
+        create_variable(
+            session,
+            us_model_version,
+            name="employment_income",
+            label="Employment income",
+        )
+
+        response = client.post(
+            "/variables/by-name",
+            json={"names": ["employment_income"], "country_id": "us"},
+        )
+        assert response.status_code == 200
+        data = response.json()
+        assert len(data) == 1
+        assert data[0]["label"] == "Employment income"
+
+    def test_mixed_labels_in_by_name_response(
+        self, client, session, us_model_version  # noqa: F811
+    ):
+        """Variables with and without labels should both be returned correctly."""
+        create_variable(
+            session,
+            us_model_version,
+            name="employment_income",
+            label="Employment income",
+        )
+        create_variable(
+            session,
+            us_model_version,
+            name="age",
+            label=None,
+        )
+
+        response = client.post(
+            "/variables/by-name",
+            json={"names": ["employment_income", "age"], "country_id": "us"},
+        )
+        assert response.status_code == 200
+        data = response.json()
+        assert len(data) == 2
+
+        by_name = {v["name"]: v for v in data}
+        assert by_name["employment_income"]["label"] == "Employment income"
+        assert by_name["age"]["label"] is None
+
+
+# ---------------------------------------------------------------------------
+# Country isolation for labels
+# ---------------------------------------------------------------------------
+
+
+class TestVariableLabelCountryIsolation:
+    """Tests that label search respects country boundaries."""
+
+    def test_search_by_label_isolated_by_country(
+        self,
+        client,
+        session,
+        us_model_version,  # noqa: F811
+        uk_model_version,  # noqa: F811
+    ):
+        """Searching by label should only return variables from the specified country.
+
+        Both labels share the word "liability", which appears in neither
+        variable name, so every hit below comes from the label column.
+        """
+        create_variable(
+            session,
+            us_model_version,
+            name="state_income_tax",
+            label="State income tax liability",
+        )
+        create_variable(
+            session,
+            uk_model_version,
+            name="council_tax",
+            label="Council tax liability",
+        )
+
+        us_response = client.get(
+            "/variables",
+            params={
+                "search": "liability",
+                "tax_benefit_model_name": "policyengine-us",
+            },
+        )
+        uk_response = client.get(
+            "/variables",
+            params={
+                "search": "liability",
+                "tax_benefit_model_name": "policyengine-uk",
+            },
+        )
+        assert us_response.status_code == 200
+        assert uk_response.status_code == 200
+
+        us_names = {v["name"] for v in us_response.json()}
+        uk_names = {v["name"] for v in uk_response.json()}
+
+        assert "state_income_tax" in us_names
+        assert "council_tax" not in us_names
+        assert "council_tax" in uk_names
+        assert "state_income_tax" not in uk_names