Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .github/workflows/pr_code_changes.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
name: Code changes
on:
pull_request:
branches:
- main

paths:
- src/**
- tests/**
Expand Down
2 changes: 2 additions & 0 deletions src/policyengine/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@ class Variable(BaseModel):
description: str | None = None
data_type: type = None
possible_values: list[Any] | None = None
default_value: Any = None
value_type: type | None = None
4 changes: 2 additions & 2 deletions src/policyengine/outputs/aggregate.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from enum import Enum
from enum import StrEnum
from typing import Any

from policyengine.core import Output, Simulation


class AggregateType(str, Enum):
class AggregateType(StrEnum):
SUM = "sum"
MEAN = "mean"
COUNT = "count"
Expand Down
4 changes: 2 additions & 2 deletions src/policyengine/outputs/change_aggregate.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from enum import Enum
from enum import StrEnum
from typing import Any

from policyengine.core import Output, Simulation


class ChangeAggregateType(str, Enum):
class ChangeAggregateType(StrEnum):
COUNT = "count"
SUM = "sum"
MEAN = "mean"
Expand Down
6 changes: 3 additions & 3 deletions src/policyengine/outputs/poverty.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Poverty analysis output types."""

from enum import Enum
from enum import StrEnum
from typing import Any

import pandas as pd
Expand All @@ -9,7 +9,7 @@
from policyengine.core import Output, OutputCollection, Simulation


class UKPovertyType(str, Enum):
class UKPovertyType(StrEnum):
"""UK poverty measure types."""

ABSOLUTE_BHC = "absolute_bhc"
Expand All @@ -18,7 +18,7 @@ class UKPovertyType(str, Enum):
RELATIVE_AHC = "relative_ahc"


class USPovertyType(str, Enum):
class USPovertyType(StrEnum):
"""US poverty measure types."""

SPM = "spm"
Expand Down
9 changes: 9 additions & 0 deletions src/policyengine/tax_benefit_models/uk/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,13 @@ def __init__(self, **kwargs: dict):
self.id = f"{self.model.id}@{self.version}"

for var_obj in system.variables.values():
# Serialize default_value for JSON compatibility
default_val = var_obj.default_value
if var_obj.value_type is Enum:
default_val = default_val.name
elif var_obj.value_type is datetime.date:
default_val = default_val.isoformat()

variable = Variable(
id=self.id + "-" + var_obj.name,
name=var_obj.name,
Expand All @@ -135,6 +142,8 @@ def __init__(self, **kwargs: dict):
data_type=var_obj.value_type
if var_obj.value_type is not Enum
else str,
default_value=default_val,
value_type=var_obj.value_type,
)
if (
hasattr(var_obj, "possible_values")
Expand Down
9 changes: 9 additions & 0 deletions src/policyengine/tax_benefit_models/us/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,13 @@ def __init__(self, **kwargs: dict):
self.id = f"{self.model.id}@{self.version}"

for var_obj in system.variables.values():
# Serialize default_value for JSON compatibility
default_val = var_obj.default_value
if var_obj.value_type is Enum:
default_val = default_val.name
elif var_obj.value_type is datetime.date:
default_val = default_val.isoformat()

variable = Variable(
id=self.id + "-" + var_obj.name,
name=var_obj.name,
Expand All @@ -128,6 +135,8 @@ def __init__(self, **kwargs: dict):
data_type=var_obj.value_type
if var_obj.value_type is not Enum
else str,
default_value=default_val,
value_type=var_obj.value_type,
)
if (
hasattr(var_obj, "possible_values")
Expand Down
38 changes: 37 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Tests for UK and US tax-benefit model versions."""
"""Tests for UK and US tax-benefit model versions and core models."""

import re

Expand Down Expand Up @@ -146,3 +146,39 @@ def test__given_bracket_label__then_follows_expected_format(self):
f"Label '{p.label}' doesn't match expected bracket format"
)
break


class TestVariableDefaultValue:
"""Tests for Variable default_value and value_type fields."""

def test_us_age_variable_has_default_value_40(self):
"""US age variable should have default_value of 40."""
age_var = next((v for v in us_latest.variables if v.name == "age"), None)
assert age_var is not None, "age variable not found in US model"
assert age_var.default_value == 40, (
f"Expected age default_value to be 40, got {age_var.default_value}"
)

def test_us_enum_variable_has_string_default_value(self):
"""Enum variables should have string default_value (not enum object)."""
# age_group is an enum with default WORKING_AGE
age_group_var = next(
(v for v in us_latest.variables if v.name == "age_group"), None
)
assert age_group_var is not None, "age_group variable not found in US model"
assert age_group_var.default_value == "WORKING_AGE", (
f"Expected age_group default_value to be 'WORKING_AGE', "
f"got {age_group_var.default_value}"
)

def test_us_variables_have_value_type(self):
"""US variables should have value_type set."""
age_var = next((v for v in us_latest.variables if v.name == "age"), None)
assert age_var is not None, "age variable not found in US model"
assert age_var.value_type is not None, "age variable should have value_type"

def test_uk_age_variable_has_default_value(self):
"""UK age variable should have default_value set."""
age_var = next((v for v in uk_latest.variables if v.name == "age"), None)
assert age_var is not None, "age variable not found in UK model"
assert age_var.default_value is not None, "UK age should have default_value"