Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions alembic/versions/20260305_add_variable_label.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Add label column to variables table

Revision ID: add_variable_label
Revises: 886921687770
Create Date: 2026-03-05

Variables now carry a human-readable label sourced from OpenFisca's
Variable.label class attribute (e.g. "Employment income"). Previously
labels were auto-generated on the frontend from the snake_case name.
"""

from typing import Sequence, Union

import sqlalchemy as sa
import sqlmodel.sql.sqltypes

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "add_variable_label"
down_revision: Union[str, Sequence[str], None] = "886921687770"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
"""Add label column to variables table."""
op.add_column(
"variables",
sa.Column("label", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
)


def downgrade() -> None:
"""Remove label column from variables table."""
op.drop_column("variables", "label")
2 changes: 2 additions & 0 deletions scripts/seed_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def seed_model(
{
"id": uuid4(),
"name": var.name,
"label": getattr(var, "label", None) or "",
"entity": var.entity,
"description": var.description or "",
"data_type": var.data_type.__name__
Expand All @@ -144,6 +145,7 @@ def seed_model(
[
"id",
"name",
"label",
"entity",
"description",
"data_type",
Expand Down
9 changes: 5 additions & 4 deletions src/policyengine_api/api/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,12 @@ def list_variables(

if search:
# Case-insensitive search using ILIKE
# Note: Variables don't have a label field, only name and description
search_pattern = f"%{search}%"
search_filter = Variable.name.ilike(
search_pattern
) | Variable.description.ilike(search_pattern)
search_filter = (
Variable.name.ilike(search_pattern)
| Variable.label.ilike(search_pattern)
| Variable.description.ilike(search_pattern)
)
query = query.where(search_filter)

variables = session.exec(
Expand Down
1 change: 1 addition & 0 deletions src/policyengine_api/models/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class VariableBase(SQLModel):
"""Base variable fields."""

name: str
label: str | None = None
entity: str
description: str | None = None
data_type: str | None = None # Store as string representation
Expand Down
80 changes: 80 additions & 0 deletions test_fixtures/fixtures_variables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Fixtures and helpers for variable-related tests."""

import pytest

from policyengine_api.models import (
TaxBenefitModel,
TaxBenefitModelVersion,
Variable,
)

# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture
def us_model_version(session):
"""Create a policyengine-us model and version for testing."""
model = TaxBenefitModel(name="policyengine-us", description="US model")
session.add(model)
session.commit()
session.refresh(model)

version = TaxBenefitModelVersion(
model_id=model.id,
version="1.0.0",
description="Test US version",
)
session.add(version)
session.commit()
session.refresh(version)
return version


@pytest.fixture
def uk_model_version(session):
"""Create a policyengine-uk model and version for testing."""
model = TaxBenefitModel(name="policyengine-uk", description="UK model")
session.add(model)
session.commit()
session.refresh(model)

version = TaxBenefitModelVersion(
model_id=model.id,
version="1.0.0",
description="Test UK version",
)
session.add(version)
session.commit()
session.refresh(version)
return version


# ---------------------------------------------------------------------------
# Factory Functions
# ---------------------------------------------------------------------------


def create_variable(
session,
model_version,
name: str,
label: str | None = None,
entity: str = "person",
description: str | None = None,
data_type: str | None = "float",
) -> Variable:
"""Create and persist a Variable."""
var = Variable(
name=name,
label=label,
entity=entity,
description=description,
data_type=data_type,
tax_benefit_model_version_id=model_version.id,
)
session.add(var)
session.commit()
session.refresh(var)
return var
Loading