diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..047421db 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,8 @@ +- bump: minor + changes: + added: + - Entity relationship approach for simulation filtering that preserves household integrity + - Reusable variable validation functions (`get_variable`, `validate_variable_entity`, `validate_household_variable`) + changed: + - Refactored `_filter_simulation_by_household_variable` to use explicit entity relationship mapping + - Place-level filtering now builds entity_rel DataFrame for cleaner filtering logic diff --git a/policyengine/simulation.py b/policyengine/simulation.py index 864aa90a..8ad432ca 100644 --- a/policyengine/simulation.py +++ b/policyengine/simulation.py @@ -47,6 +47,79 @@ SubsampleType = Optional[int] +# ============================================================================= +# Variable Validation Functions +# ============================================================================= + + +def get_variable(tax_benefit_system: Any, variable_name: str) -> Any: + """Get a variable from the tax-benefit system, raising if not found. + + Args: + tax_benefit_system: The tax-benefit system to search. + variable_name: The name of the variable to find. + + Returns: + The variable object from the tax-benefit system. + + Raises: + ValueError: If the variable is not found. + """ + if variable_name not in tax_benefit_system.variables: + raise ValueError( + f"Variable '{variable_name}' not found in tax-benefit system" + ) + return tax_benefit_system.variables[variable_name] + + +def validate_variable_entity( + tax_benefit_system: Any, + variable_name: str, + expected_entity: str, +) -> None: + """Validate that a variable belongs to the expected entity type. + + Args: + tax_benefit_system: The tax-benefit system containing the variable. + variable_name: The name of the variable to validate. + expected_entity: The expected entity key (e.g., "household", "person"). + + Raises: + ValueError: If the variable is not found or belongs to a different entity. + """ + variable = get_variable(tax_benefit_system, variable_name) + actual_entity = variable.entity.key + + if actual_entity != expected_entity: + raise ValueError( + f"Variable '{variable_name}' is a {actual_entity}-level variable, " + f"not a {expected_entity}-level variable." + ) + + +def validate_household_variable( + tax_benefit_system: Any, + variable_name: str, +) -> None: + """Validate that a variable is a household-level variable. + + Args: + tax_benefit_system: The tax-benefit system containing the variable. + variable_name: The name of the variable to validate. + + Raises: + ValueError: If the variable is not found or is not household-level. + """ + variable = get_variable(tax_benefit_system, variable_name) + + if variable.entity.key != "household": + raise ValueError( + f"Variable '{variable_name}' is a {variable.entity.key}-level variable, " + f"not a household-level variable. Only household-level variables can be " + f"used for filtering to preserve household integrity." + ) + + class SimulationOptions(BaseModel): country: CountryType = Field(..., description="The country to simulate.") scope: ScopeType = Field(..., description="The scope of the simulation.") @@ -388,6 +461,108 @@ def _apply_us_region_to_simulation( ) return simulation + def _build_entity_relationships( + self, + simulation: CountryMicrosimulation, + ) -> pd.DataFrame: + """Build a DataFrame mapping each person to their containing entities. + + Creates an explicit relationship map between persons and all entity + types (household, tax_unit, etc.). This enables filtering at any + entity level while preserving the integrity of all related entities. + + Args: + simulation: The microsimulation to extract relationships from. + + Returns: + A DataFrame indexed by person with columns for each entity ID. + """ + entity_rel = pd.DataFrame( + {"person_id": simulation.calculate("person_id").values} + ) + + # Add household relationship (required for all countries) + entity_rel["household_id"] = simulation.calculate( + "household_id", map_to="person" + ).values + + # Add country-specific entity relationships + tbs = simulation.tax_benefit_system + optional_entities = [ + "tax_unit_id", + "spm_unit_id", + "family_id", + "marital_unit_id", + ] + + for entity_id in optional_entities: + if entity_id in tbs.variables: + entity_rel[entity_id] = simulation.calculate( + entity_id, map_to="person" + ).values + + return entity_rel + + def _filter_simulation_by_household_variable( + self, + simulation: CountryMicrosimulation, + simulation_type: type, + variable_name: str, + variable_value: Any, + reform: ReformType | None, + ) -> CountrySimulation: + """Filter a simulation to only include households where a variable matches a value. + + Uses the entity relationship approach: builds an explicit map of all + entity relationships, filters at the household level, and keeps all + persons in matching households to preserve entity integrity. + + Args: + simulation: The microsimulation to filter. + simulation_type: The type of simulation to create (e.g., Microsimulation). + variable_name: The name of the variable to filter on. Must be a + household-level variable. + variable_value: The value to match. For string variables that may be + stored as bytes in HDF5, both str and bytes versions are checked. + reform: Optional reform to apply to the filtered simulation. + + Returns: + A new simulation containing only households where the variable matches. + + Raises: + ValueError: If the variable is not a household-level variable. + """ + validate_household_variable( + simulation.tax_benefit_system, variable_name + ) + + # Build entity relationships + entity_rel = self._build_entity_relationships(simulation) + + # Get household-level variable values + hh_values = simulation.calculate(variable_name).values + hh_ids = simulation.calculate("household_id").values + + # Create mask for matching households, handling bytes encoding + if isinstance(variable_value, str): + hh_mask = (hh_values == variable_value) | ( + hh_values == variable_value.encode() + ) + else: + hh_mask = hh_values == variable_value + + matching_hh_ids = set(hh_ids[hh_mask]) + + # Filter entity_rel to persons in matching households + person_mask = entity_rel["household_id"].isin(matching_hh_ids) + filtered_entity_rel = entity_rel[person_mask] + + # Filter the input DataFrame using the filtered person indices + df = simulation.to_input_dataframe() + subset_df = df.iloc[filtered_entity_rel.index] + + return simulation_type(dataset=subset_df, reform=reform) + def _filter_us_simulation_by_place( self, simulation: CountryMicrosimulation, @@ -409,16 +584,14 @@ def _filter_us_simulation_by_place( from policyengine.utils.data.datasets import parse_us_place_region _, place_fips_code = parse_us_place_region(region) - df = simulation.to_input_dataframe() - # Get place_fips at person level since to_input_dataframe() is person-level - person_place_fips = simulation.calculate( - "place_fips", map_to="person" - ).values - # place_fips may be stored as bytes in HDF5; handle both str and bytes - mask = (person_place_fips == place_fips_code) | ( - person_place_fips == place_fips_code.encode() + + return self._filter_simulation_by_household_variable( + simulation=simulation, + simulation_type=simulation_type, + variable_name="place_fips", + variable_value=place_fips_code, + reform=reform, ) - return simulation_type(dataset=df[mask], reform=reform) def check_model_version(self) -> None: """ diff --git a/tests/country/test_us_places.py b/tests/country/test_us_places.py index bfdda841..bd6101a1 100644 --- a/tests/country/test_us_places.py +++ b/tests/country/test_us_places.py @@ -1,7 +1,14 @@ -"""Tests for US place-level (city) filtering functionality.""" +"""Tests for US place-level (city) filtering functionality. + +Tests the entity_rel filtering approach which: +1. Builds explicit entity relationships (person -> household, tax_unit, etc.) +2. Filters at household level to preserve entity integrity +3. Creates new simulations from filtered DataFrames +""" import pytest import pandas as pd +import numpy as np from unittest.mock import Mock, patch from tests.fixtures.country.us_places import ( @@ -42,11 +49,50 @@ create_mock_simulation_with_bytes_place_fips, create_mock_simulation_type, create_simulation_instance, + create_mock_tax_benefit_system, ) +class TestBuildEntityRelationships: + """Tests for the _build_entity_relationships method.""" + + def test__given__simulation__then__returns_dataframe_with_person_and_household_ids( + self, + ): + # Given + mock_sim = create_mock_simulation_with_place_fips( + MIXED_PLACES_WITH_PATERSON, persons_per_household=2 + ) + + # When + sim_instance = create_simulation_instance() + entity_rel = sim_instance._build_entity_relationships(mock_sim) + + # Then + assert "person_id" in entity_rel.columns + assert "household_id" in entity_rel.columns + # 5 households * 2 persons each = 10 persons + assert len(entity_rel) == 10 + + def test__given__simulation__then__includes_optional_entity_ids_when_available( + self, + ): + # Given + mock_sim = create_mock_simulation_with_place_fips( + MIXED_PLACES_WITH_PATERSON + ) + + # When + sim_instance = create_simulation_instance() + entity_rel = sim_instance._build_entity_relationships(mock_sim) + + # Then: Optional US entity IDs should be present + assert "tax_unit_id" in entity_rel.columns + assert "spm_unit_id" in entity_rel.columns + + class TestFilterUsSimulationByPlace: - """Tests for the _filter_us_simulation_by_place method.""" + """Tests for the _filter_us_simulation_by_place method using entity_rel approach.""" def test__given__households_in_target_place__then__filters_to_matching_households( self, @@ -70,8 +116,10 @@ def test__given__households_in_target_place__then__filters_to_matching_household call_args = mock_simulation_type.call_args filtered_df = call_args.kwargs["dataset"] + # With entity_rel, DataFrame is person-level assert len(filtered_df) == EXPECTED_PATERSON_COUNT_IN_MIXED - assert all(filtered_df["place_fips"] == NJ_PATERSON_FIPS) + # Verify all records belong to Paterson households + assert all(filtered_df["place_fips__2024"] == NJ_PATERSON_FIPS) def test__given__no_households_in_target_place__then__returns_empty_dataset( self, @@ -187,7 +235,109 @@ def test__given__different_place_in_same_state__then__filters_correctly( filtered_df = call_args.kwargs["dataset"] assert len(filtered_df) == EXPECTED_NEWARK_COUNT_IN_MULTIPLE_NJ - assert all(filtered_df["place_fips"] == NJ_NEWARK_FIPS) + assert all(filtered_df["place_fips__2024"] == NJ_NEWARK_FIPS) + + def test__given__multi_person_households__then__preserves_all_persons( + self, + ): + # Given: 3 households with 2 persons each, only first household in Paterson + mock_sim = create_mock_simulation_with_place_fips( + [NJ_PATERSON_FIPS, NJ_NEWARK_FIPS, NJ_JERSEY_CITY_FIPS], + persons_per_household=2, + ) + mock_simulation_type = create_mock_simulation_type() + + # When + sim_instance = create_simulation_instance() + result = sim_instance._filter_us_simulation_by_place( + simulation=mock_sim, + simulation_type=mock_simulation_type, + region=NJ_PATERSON_REGION, + reform=None, + ) + + # Then: Should have 2 persons (both from the Paterson household) + call_args = mock_simulation_type.call_args + filtered_df = call_args.kwargs["dataset"] + + assert len(filtered_df) == 2 # 1 household * 2 persons + # All persons should be from household 0 + assert all(filtered_df["household_id__2024"] == 0) + + +class TestFilterSimulationByHouseholdVariable: + """Tests for _filter_simulation_by_household_variable validation and behavior.""" + + def test__given__non_household_variable__then__raises_value_error(self): + # Given: A mock with a person-level variable + mock_sim = Mock() + mock_tbs = Mock() + mock_var = Mock() + mock_var.entity = Mock() + mock_var.entity.key = "person" # Not household-level + mock_tbs.variables = {"age": mock_var} + mock_sim.tax_benefit_system = mock_tbs + + mock_simulation_type = create_mock_simulation_type() + + # When / Then + sim_instance = create_simulation_instance() + with pytest.raises(ValueError) as exc_info: + sim_instance._filter_simulation_by_household_variable( + simulation=mock_sim, + simulation_type=mock_simulation_type, + variable_name="age", + variable_value=30, + reform=None, + ) + + assert "person-level variable" in str(exc_info.value) + assert "household-level variable" in str(exc_info.value) + + def test__given__nonexistent_variable__then__raises_value_error(self): + # Given + mock_sim = Mock() + mock_tbs = Mock() + mock_tbs.variables = {} # Empty - no variables + mock_sim.tax_benefit_system = mock_tbs + + mock_simulation_type = create_mock_simulation_type() + + # When / Then + sim_instance = create_simulation_instance() + with pytest.raises(ValueError) as exc_info: + sim_instance._filter_simulation_by_household_variable( + simulation=mock_sim, + simulation_type=mock_simulation_type, + variable_name="nonexistent_var", + variable_value="test", + reform=None, + ) + + assert "not found" in str(exc_info.value) + + def test__given__household_variable__then__filters_successfully(self): + # Given + mock_sim = create_mock_simulation_with_place_fips( + MIXED_PLACES_WITH_PATERSON + ) + mock_simulation_type = create_mock_simulation_type() + + # When + sim_instance = create_simulation_instance() + result = sim_instance._filter_simulation_by_household_variable( + simulation=mock_sim, + simulation_type=mock_simulation_type, + variable_name="place_fips", + variable_value=NJ_PATERSON_FIPS, + reform=None, + ) + + # Then: Should create simulation with filtered data + assert mock_simulation_type.called + call_args = mock_simulation_type.call_args + filtered_df = call_args.kwargs["dataset"] + assert len(filtered_df) == EXPECTED_PATERSON_COUNT_IN_MIXED class TestApplyUsRegionToSimulationWithPlace: diff --git a/tests/fixtures/country/us_places.py b/tests/fixtures/country/us_places.py index ff65af60..3dc7e12a 100644 --- a/tests/fixtures/country/us_places.py +++ b/tests/fixtures/country/us_places.py @@ -1,9 +1,15 @@ -"""Test fixtures for US place-level filtering tests.""" +"""Test fixtures for US place-level filtering tests. + +These fixtures support testing the entity_rel filtering approach which: +1. Builds explicit entity relationships (person -> household, tax_unit, etc.) +2. Filters at household level to preserve entity integrity +3. Creates new simulations from filtered DataFrames +""" import pytest import numpy as np import pandas as pd -from unittest.mock import Mock +from unittest.mock import Mock, MagicMock # ============================================================================= # Place FIPS Constants @@ -123,34 +129,129 @@ # ============================================================================= +def create_mock_tax_benefit_system( + household_variables: list[str] | None = None, +) -> Mock: + """Create a mock tax benefit system with variable entity information. + + Args: + household_variables: List of variable names that are household-level. + Defaults to ["place_fips"]. + + Returns: + Mock TaxBenefitSystem with variables dict containing entity info. + """ + if household_variables is None: + household_variables = ["place_fips"] + + mock_tbs = Mock() + mock_tbs.variables = {} + + for var_name in household_variables: + mock_var = Mock() + mock_var.entity = Mock() + mock_var.entity.key = "household" + mock_tbs.variables[var_name] = mock_var + + # Add standard entity ID variables + for entity_id in [ + "person_id", + "household_id", + "tax_unit_id", + "spm_unit_id", + "family_id", + "marital_unit_id", + ]: + mock_var = Mock() + mock_var.entity = Mock() + # Entity IDs belong to their respective entities + entity_name = entity_id.replace("_id", "") + mock_var.entity.key = ( + entity_name if entity_name != "person" else "person" + ) + mock_tbs.variables[entity_id] = mock_var + + return mock_tbs + + def create_mock_simulation_with_place_fips( place_fips_values: list[str], household_ids: list[int] | None = None, + persons_per_household: int = 1, ) -> Mock: - """Create a mock simulation with place_fips data. + """Create a mock simulation with place_fips data for entity_rel filtering. + + Supports the entity_rel approach by mocking: + - calculate() with variable-specific return values + - tax_benefit_system.variables for entity validation + - to_input_dataframe() returning person-level DataFrame Args: place_fips_values: List of place FIPS codes for each household. household_ids: Optional list of household IDs. + persons_per_household: Number of persons per household (default 1). Returns: - Mock simulation object with calculate() and to_input_dataframe() configured. + Mock simulation object configured for entity_rel filtering. """ if household_ids is None: household_ids = list(range(len(place_fips_values))) - mock_sim = Mock() + num_households = len(place_fips_values) + num_persons = num_households * persons_per_household + + # Create person-level data by repeating household data + person_ids = list(range(num_persons)) + person_household_ids = [] + person_place_fips = [] + for i, (hh_id, place) in enumerate(zip(household_ids, place_fips_values)): + for _ in range(persons_per_household): + person_household_ids.append(hh_id) + person_place_fips.append(place) - # Mock calculate to return place_fips values - mock_calculate_result = Mock() - mock_calculate_result.values = np.array(place_fips_values) - mock_sim.calculate.return_value = mock_calculate_result + mock_sim = Mock() - # Mock to_input_dataframe to return a DataFrame + # Mock tax_benefit_system + mock_sim.tax_benefit_system = create_mock_tax_benefit_system() + + # Mock calculate to return different values based on variable and map_to + def mock_calculate(variable_name, map_to=None, period=None): + result = Mock() + if variable_name == "place_fips": + if map_to == "person": + result.values = np.array(person_place_fips) + else: + result.values = np.array(place_fips_values) + elif variable_name == "person_id": + result.values = np.array(person_ids) + elif variable_name == "household_id": + if map_to == "person": + result.values = np.array(person_household_ids) + else: + result.values = np.array(household_ids) + elif variable_name in [ + "tax_unit_id", + "spm_unit_id", + "family_id", + "marital_unit_id", + ]: + # For simplicity, use household_id as proxy for other entity IDs + if map_to == "person": + result.values = np.array(person_household_ids) + else: + result.values = np.array(household_ids) + else: + result.values = np.array([]) + return result + + mock_sim.calculate = mock_calculate + + # Mock to_input_dataframe to return person-level DataFrame df = pd.DataFrame( { - "household_id": household_ids, - "place_fips": place_fips_values, + "person_id__2024": person_ids, + "household_id__2024": person_household_ids, + "place_fips__2024": person_place_fips, } ) mock_sim.to_input_dataframe.return_value = df @@ -161,29 +262,70 @@ def create_mock_simulation_with_place_fips( def create_mock_simulation_with_bytes_place_fips( place_fips_values: list[bytes], household_ids: list[int] | None = None, + persons_per_household: int = 1, ) -> Mock: """Create a mock simulation with bytes place_fips data (as from HDF5). Args: place_fips_values: List of place FIPS codes as bytes. household_ids: Optional list of household IDs. + persons_per_household: Number of persons per household (default 1). Returns: - Mock simulation object with calculate() and to_input_dataframe() configured. + Mock simulation object configured for entity_rel filtering. """ if household_ids is None: household_ids = list(range(len(place_fips_values))) - mock_sim = Mock() + num_households = len(place_fips_values) + num_persons = num_households * persons_per_household - mock_calculate_result = Mock() - mock_calculate_result.values = np.array(place_fips_values) - mock_sim.calculate.return_value = mock_calculate_result + person_ids = list(range(num_persons)) + person_household_ids = [] + person_place_fips = [] + for i, (hh_id, place) in enumerate(zip(household_ids, place_fips_values)): + for _ in range(persons_per_household): + person_household_ids.append(hh_id) + person_place_fips.append(place) + + mock_sim = Mock() + mock_sim.tax_benefit_system = create_mock_tax_benefit_system() + + def mock_calculate(variable_name, map_to=None, period=None): + result = Mock() + if variable_name == "place_fips": + if map_to == "person": + result.values = np.array(person_place_fips) + else: + result.values = np.array(place_fips_values) + elif variable_name == "person_id": + result.values = np.array(person_ids) + elif variable_name == "household_id": + if map_to == "person": + result.values = np.array(person_household_ids) + else: + result.values = np.array(household_ids) + elif variable_name in [ + "tax_unit_id", + "spm_unit_id", + "family_id", + "marital_unit_id", + ]: + if map_to == "person": + result.values = np.array(person_household_ids) + else: + result.values = np.array(household_ids) + else: + result.values = np.array([]) + return result + + mock_sim.calculate = mock_calculate df = pd.DataFrame( { - "household_id": household_ids, - "place_fips": place_fips_values, + "person_id__2024": person_ids, + "household_id__2024": person_household_ids, + "place_fips__2024": person_place_fips, } ) mock_sim.to_input_dataframe.return_value = df diff --git a/tests/test_simulation.py b/tests/test_simulation.py index f761588e..b4f9c752 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -8,7 +8,9 @@ mock_simulation_with_cliff_vars, ) import sys +import pytest from copy import deepcopy +from unittest.mock import Mock class TestSimulation: @@ -83,3 +85,119 @@ def test__calculates_correct_cliff_metrics( assert cliff_result.cliff_gap == 100.0 assert cliff_result.cliff_share == 0.5 + + +class TestVariableValidation: + """Tests for variable validation functions.""" + + @staticmethod + def _create_mock_tbs(variables: dict[str, str]) -> Mock: + """Create a mock tax-benefit system with specified variables. + + Args: + variables: Dict mapping variable names to entity keys. + + Returns: + Mock TBS with variables configured. + """ + mock_tbs = Mock() + mock_tbs.variables = {} + for var_name, entity_key in variables.items(): + mock_var = Mock() + mock_var.entity = Mock() + mock_var.entity.key = entity_key + mock_tbs.variables[var_name] = mock_var + return mock_tbs + + class TestGetVariable: + def test__given__existing_variable__then__returns_variable(self): + from policyengine.simulation import get_variable + + mock_tbs = TestVariableValidation._create_mock_tbs( + {"place_fips": "household"} + ) + + result = get_variable(mock_tbs, "place_fips") + + assert result is mock_tbs.variables["place_fips"] + + def test__given__nonexistent_variable__then__raises_value_error(self): + from policyengine.simulation import get_variable + + mock_tbs = TestVariableValidation._create_mock_tbs({}) + + with pytest.raises(ValueError) as exc_info: + get_variable(mock_tbs, "nonexistent") + + assert "not found" in str(exc_info.value) + assert "nonexistent" in str(exc_info.value) + + class TestValidateVariableEntity: + def test__given__matching_entity__then__passes(self): + from policyengine.simulation import validate_variable_entity + + mock_tbs = TestVariableValidation._create_mock_tbs( + {"place_fips": "household"} + ) + + # Should not raise + validate_variable_entity(mock_tbs, "place_fips", "household") + + def test__given__mismatched_entity__then__raises_value_error(self): + from policyengine.simulation import validate_variable_entity + + mock_tbs = TestVariableValidation._create_mock_tbs( + {"age": "person"} + ) + + with pytest.raises(ValueError) as exc_info: + validate_variable_entity(mock_tbs, "age", "household") + + assert "person-level" in str(exc_info.value) + assert "household-level" in str(exc_info.value) + + def test__given__nonexistent_variable__then__raises_value_error(self): + from policyengine.simulation import validate_variable_entity + + mock_tbs = TestVariableValidation._create_mock_tbs({}) + + with pytest.raises(ValueError) as exc_info: + validate_variable_entity(mock_tbs, "nonexistent", "household") + + assert "not found" in str(exc_info.value) + + class TestValidateHouseholdVariable: + def test__given__household_variable__then__passes(self): + from policyengine.simulation import validate_household_variable + + mock_tbs = TestVariableValidation._create_mock_tbs( + {"place_fips": "household"} + ) + + # Should not raise + validate_household_variable(mock_tbs, "place_fips") + + def test__given__person_variable__then__raises_value_error(self): + from policyengine.simulation import validate_household_variable + + mock_tbs = TestVariableValidation._create_mock_tbs( + {"age": "person"} + ) + + with pytest.raises(ValueError) as exc_info: + validate_household_variable(mock_tbs, "age") + + assert "person-level" in str(exc_info.value) + assert "household-level" in str(exc_info.value) + + def test__given__tax_unit_variable__then__raises_value_error(self): + from policyengine.simulation import validate_household_variable + + mock_tbs = TestVariableValidation._create_mock_tbs( + {"filing_status": "tax_unit"} + ) + + with pytest.raises(ValueError) as exc_info: + validate_household_variable(mock_tbs, "filing_status") + + assert "tax_unit-level" in str(exc_info.value)