Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion src/policyengine_api/api/parameter_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import Session, or_, select

from policyengine_api.models import ParameterValue, ParameterValueRead
from policyengine_api.models import Parameter, ParameterValue, ParameterValueRead
from policyengine_api.services.database import get_session
from policyengine_api.services.tax_benefit_models import resolve_model_version_id

router = APIRouter(prefix="/parameter-values", tags=["parameter-values"])

Expand All @@ -23,6 +24,8 @@ def list_parameter_values(
parameter_id: UUID | None = None,
policy_id: UUID | None = None,
current: bool = False,
tax_benefit_model_name: str | None = None,
tax_benefit_model_version_id: UUID | None = None,
skip: int = 0,
limit: int = 100,
session: Session = Depends(get_session),
Expand All @@ -37,6 +40,10 @@ def list_parameter_values(
policy_id: Filter by a specific policy reform.
current: If true, only return values that are currently in effect
(start_date <= now and (end_date is null or end_date > now)).
tax_benefit_model_name: Filter to values belonging to parameters from
this model. Defaults to the latest version.
tax_benefit_model_version_id: Filter to values belonging to parameters
from this specific model version. Takes precedence over model name.
"""
query = select(ParameterValue)

Expand All @@ -46,6 +53,14 @@ def list_parameter_values(
if policy_id:
query = query.where(ParameterValue.policy_id == policy_id)

version_id = resolve_model_version_id(
tax_benefit_model_name, tax_benefit_model_version_id, session
)
if version_id:
query = query.join(Parameter).where(
Parameter.tax_benefit_model_version_id == version_id
)

if current:
now = datetime.now(timezone.utc)
query = query.where(
Expand Down
67 changes: 35 additions & 32 deletions src/policyengine_api/api/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,12 @@
from pydantic import BaseModel
from sqlmodel import Session, select

from policyengine_api.config.constants import COUNTRY_MODEL_NAMES, CountryId
from policyengine_api.models import (
Parameter,
ParameterRead,
TaxBenefitModel,
TaxBenefitModelVersion,
)
from policyengine_api.services.database import get_session
from policyengine_api.services.tax_benefit_models import resolve_model_version_id

router = APIRouter(prefix="/parameters", tags=["parameters"])

Expand All @@ -32,6 +30,7 @@ def list_parameters(
limit: int = 100,
search: str | None = None,
tax_benefit_model_name: str | None = None,
tax_benefit_model_version_id: UUID | None = None,
session: Session = Depends(get_session),
):
"""List available parameters with pagination and search.
Expand All @@ -44,19 +43,19 @@ def list_parameters(
tax_benefit_model_name: Filter by country model.
Use "policyengine-uk" for UK parameters.
Use "policyengine-us" for US parameters.
Defaults to the latest model version when no version ID is given.
tax_benefit_model_version_id: Filter by a specific model version.
Takes precedence over tax_benefit_model_name.
"""
query = select(Parameter)

# Filter by tax benefit model name (country)
if tax_benefit_model_name:
query = (
query.join(TaxBenefitModelVersion)
.join(TaxBenefitModel)
.where(TaxBenefitModel.name == tax_benefit_model_name)
)
version_id = resolve_model_version_id(
tax_benefit_model_name, tax_benefit_model_version_id, session
)
if version_id:
query = query.where(Parameter.tax_benefit_model_version_id == version_id)

if search:
# Case-insensitive search using ILIKE
search_pattern = f"%{search}%"
search_filter = (
Parameter.name.ilike(search_pattern)
Expand All @@ -75,7 +74,8 @@ class ParameterByNameRequest(BaseModel):
"""Request body for looking up parameters by name."""

names: list[str]
country_id: CountryId
tax_benefit_model_name: str
tax_benefit_model_version_id: UUID | None = None


@router.post("/by-name", response_model=List[ParameterRead])
Expand All @@ -95,18 +95,17 @@ def get_parameters_by_name(
if not request.names:
return []

model_name = COUNTRY_MODEL_NAMES[request.country_id]

query = (
select(Parameter)
.join(TaxBenefitModelVersion)
.join(TaxBenefitModel)
.where(TaxBenefitModel.name == model_name)
.where(Parameter.name.in_(request.names))
.order_by(Parameter.name)
version_id = resolve_model_version_id(
request.tax_benefit_model_name,
request.tax_benefit_model_version_id,
session,
)

return session.exec(query).all()
query = select(Parameter).where(Parameter.name.in_(request.names))
if version_id:
query = query.where(Parameter.tax_benefit_model_version_id == version_id)

return session.exec(query.order_by(Parameter.name)).all()


class ParameterChild(BaseModel):
Expand All @@ -128,10 +127,15 @@ class ParameterChildrenResponse(BaseModel):

@router.get("/children", response_model=ParameterChildrenResponse)
def get_parameter_children(
country_id: CountryId = Query(description='Country ID ("us" or "uk")'),
tax_benefit_model_name: str = Query(
description='Model name (e.g. "policyengine-us" or "policyengine-uk")'
),
parent_path: str = Query(
default="", description="Parent parameter path (e.g. 'gov' or 'gov.hmrc')"
),
tax_benefit_model_version_id: UUID | None = Query(
default=None, description="Optional specific model version ID"
),
session: Session = Depends(get_session),
) -> ParameterChildrenResponse:
"""Get direct children of a parameter path for tree navigation.
Expand All @@ -140,17 +144,16 @@ def get_parameter_children(
parameters (with full metadata). Use this to lazily load the parameter
tree one level at a time.
"""
model_name = COUNTRY_MODEL_NAMES[country_id]
version_id = resolve_model_version_id(
tax_benefit_model_name, tax_benefit_model_version_id, session
)

prefix = f"{parent_path}." if parent_path else ""

# Fetch all parameters under this path
query = (
select(Parameter)
.join(TaxBenefitModelVersion)
.join(TaxBenefitModel)
.where(TaxBenefitModel.name == model_name)
.where(Parameter.name.startswith(prefix))
)
query = select(Parameter).where(Parameter.name.startswith(prefix))
if version_id:
query = query.where(Parameter.tax_benefit_model_version_id == version_id)

descendants = session.exec(query).all()

# Group by direct child path
Expand Down
43 changes: 21 additions & 22 deletions src/policyengine_api/api/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,12 @@
from pydantic import BaseModel
from sqlmodel import Session, select

from policyengine_api.config.constants import COUNTRY_MODEL_NAMES, CountryId
from policyengine_api.models import (
TaxBenefitModel,
TaxBenefitModelVersion,
Variable,
VariableRead,
)
from policyengine_api.services.database import get_session
from policyengine_api.services.tax_benefit_models import resolve_model_version_id

router = APIRouter(prefix="/variables", tags=["variables"])

Expand All @@ -30,6 +28,7 @@ def list_variables(
limit: int = 100,
search: str | None = None,
tax_benefit_model_name: str | None = None,
tax_benefit_model_version_id: UUID | None = None,
session: Session = Depends(get_session),
):
"""List available variables with pagination and search.
Expand All @@ -43,20 +42,19 @@ def list_variables(
tax_benefit_model_name: Filter by country model.
Use "policyengine-uk" for UK variables.
Use "policyengine-us" for US variables.
Defaults to the latest model version when no version ID is given.
tax_benefit_model_version_id: Filter by a specific model version.
Takes precedence over tax_benefit_model_name.
"""
query = select(Variable)

# Filter by tax benefit model name (country)
if tax_benefit_model_name:
query = (
query.join(TaxBenefitModelVersion)
.join(TaxBenefitModel)
.where(TaxBenefitModel.name == tax_benefit_model_name)
)
version_id = resolve_model_version_id(
tax_benefit_model_name, tax_benefit_model_version_id, session
)
if version_id:
query = query.where(Variable.tax_benefit_model_version_id == version_id)

if search:
# Case-insensitive search using ILIKE
# Note: Variables don't have a label field, only name and description
search_pattern = f"%{search}%"
search_filter = Variable.name.ilike(
search_pattern
Expand All @@ -73,7 +71,8 @@ class VariableByNameRequest(BaseModel):
"""Request body for looking up variables by name."""

names: list[str]
country_id: CountryId
tax_benefit_model_name: str
tax_benefit_model_version_id: UUID | None = None


@router.post("/by-name", response_model=List[VariableRead])
Expand All @@ -94,17 +93,17 @@ def get_variables_by_name(
if not request.names:
return []

model_name = COUNTRY_MODEL_NAMES[request.country_id]
query = (
select(Variable)
.join(TaxBenefitModelVersion)
.join(TaxBenefitModel)
.where(TaxBenefitModel.name == model_name)
.where(Variable.name.in_(request.names))
.order_by(Variable.name)
version_id = resolve_model_version_id(
request.tax_benefit_model_name,
request.tax_benefit_model_version_id,
session,
)

return session.exec(query).all()
query = select(Variable).where(Variable.name.in_(request.names))
if version_id:
query = query.where(Variable.tax_benefit_model_version_id == version_id)

return session.exec(query.order_by(Variable.name)).all()


@router.get("/{variable_id}", response_model=VariableRead)
Expand Down
13 changes: 12 additions & 1 deletion src/policyengine_api/services/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
"""Services for database and external integrations."""

from .database import get_session, init_db
from .tax_benefit_models import (
get_latest_model_version,
get_model_version_by_id,
resolve_model_version_id,
)

__all__ = ["get_session", "init_db"]
__all__ = [
"get_session",
"init_db",
"get_latest_model_version",
"get_model_version_by_id",
"resolve_model_version_id",
]
109 changes: 109 additions & 0 deletions src/policyengine_api/services/tax_benefit_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""Tax benefit model utilities.

Shared utilities for resolving tax benefit model versions.
"""

from uuid import UUID

from fastapi import HTTPException
from sqlmodel import Session, select

from policyengine_api.models import TaxBenefitModel, TaxBenefitModelVersion


def get_latest_model_version(
tax_benefit_model_name: str, session: Session
) -> TaxBenefitModelVersion:
"""Get the latest tax benefit model version for a given model name.

Args:
tax_benefit_model_name: The model name (e.g., "policyengine-us").
Underscores are normalized to hyphens.
session: Database session.

Returns:
The latest TaxBenefitModelVersion for the model.

Raises:
HTTPException: If model or version not found.
"""
model_name = tax_benefit_model_name.replace("_", "-")

model = session.exec(
select(TaxBenefitModel).where(TaxBenefitModel.name == model_name)
).first()
if not model:
raise HTTPException(
status_code=404,
detail=f"Tax benefit model '{model_name}' not found",
)

version = session.exec(
select(TaxBenefitModelVersion)
.where(TaxBenefitModelVersion.model_id == model.id)
.order_by(TaxBenefitModelVersion.created_at.desc())
).first()
if not version:
raise HTTPException(
status_code=404,
detail=f"No version found for model '{model_name}'",
)

return version


def get_model_version_by_id(
version_id: UUID, session: Session
) -> TaxBenefitModelVersion:
"""Get a specific tax benefit model version by ID.

Args:
version_id: The UUID of the model version.
session: Database session.

Returns:
The TaxBenefitModelVersion with the given ID.

Raises:
HTTPException: If version not found.
"""
version = session.get(TaxBenefitModelVersion, version_id)
if not version:
raise HTTPException(
status_code=404,
detail=f"Tax benefit model version '{version_id}' not found",
)
return version


def resolve_model_version_id(
tax_benefit_model_name: str | None,
tax_benefit_model_version_id: UUID | None,
session: Session,
) -> UUID | None:
"""Resolve the model version ID from either explicit ID or model name.

Priority:
1. If tax_benefit_model_version_id provided, validate and return it.
2. If tax_benefit_model_name provided, return the latest version's ID.
3. If neither provided, return None (no filtering).

Args:
tax_benefit_model_name: Optional model name to resolve latest version for.
tax_benefit_model_version_id: Optional explicit version ID.
session: Database session.

Returns:
The resolved version ID, or None if no filtering requested.

Raises:
HTTPException: If specified version or model not found.
"""
if tax_benefit_model_version_id:
version = get_model_version_by_id(tax_benefit_model_version_id, session)
return version.id
elif tax_benefit_model_name:
version = get_latest_model_version(tax_benefit_model_name, session)
return version.id
else:
return None
Loading