Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/microplex_us/pipelines/us.py
Original file line number Diff line number Diff line change
Expand Up @@ -5309,7 +5309,11 @@ def _build_donor_imputer(
{
variable
for variable, support_family in support_families.items()
if support_family is VariableSupportFamily.ZERO_INFLATED_POSITIVE
if support_family
in {
VariableSupportFamily.ZERO_INFLATED_POSITIVE,
VariableSupportFamily.ZERO_INFLATED_SIGNED,
}
}
if backend == "zi_qrf"
else set()
Expand Down
45 changes: 24 additions & 21 deletions src/microplex_us/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class VariableSupportFamily(Enum):

CONTINUOUS = "continuous"
ZERO_INFLATED_POSITIVE = "zero_inflated_positive"
ZERO_INFLATED_SIGNED = "zero_inflated_signed"
BOUNDED_SHARE = "bounded_share"


Expand Down Expand Up @@ -155,7 +156,10 @@ def is_redundant_given(self, variable_names: Iterable[str]) -> bool:

@property
def condition_score_mode(self) -> ConditionScoreMode:
if self.support_family is VariableSupportFamily.ZERO_INFLATED_POSITIVE:
if self.support_family in {
VariableSupportFamily.ZERO_INFLATED_POSITIVE,
VariableSupportFamily.ZERO_INFLATED_SIGNED,
}:
return ConditionScoreMode.VALUE_AND_SUPPORT
return ConditionScoreMode.VALUE_ONLY

Expand All @@ -165,9 +169,7 @@ def allowed_condition_entities(self) -> tuple[EntityType, ...]:
return self.condition_entities
if self.native_entity is EntityType.PERSON:
record_entity = getattr(EntityType, "RECORD", None)
return tuple(
entity for entity in EntityType if entity is not record_entity
)
return tuple(entity for entity in EntityType if entity is not record_entity)
return (EntityType.HOUSEHOLD, self.native_entity)


Expand Down Expand Up @@ -457,7 +459,7 @@ def minor_positive_employment_income_mask(frame: pd.DataFrame) -> pd.Series:
EntityType.HOUSEHOLD,
EntityType.TAX_UNIT,
),
support_family=VariableSupportFamily.ZERO_INFLATED_POSITIVE,
support_family=VariableSupportFamily.ZERO_INFLATED_SIGNED,
donor_match_strategy=DonorMatchStrategy.ZERO_INFLATED_POSITIVE,
preferred_condition_vars=PUF_IRS_TAX_PREFERRED_CONDITION_VARS,
supplemental_shared_condition_vars=PUF_IRS_TAX_SUPPLEMENTAL_SHARED_CONDITION_VARS,
Expand Down Expand Up @@ -541,7 +543,7 @@ def minor_positive_employment_income_mask(frame: pd.DataFrame) -> pd.Series:
EntityType.HOUSEHOLD,
EntityType.TAX_UNIT,
),
support_family=VariableSupportFamily.ZERO_INFLATED_POSITIVE,
support_family=VariableSupportFamily.ZERO_INFLATED_SIGNED,
donor_match_strategy=DonorMatchStrategy.ZERO_INFLATED_POSITIVE,
preferred_condition_vars=PUF_IRS_TAX_PREFERRED_CONDITION_VARS,
supplemental_shared_condition_vars=PUF_IRS_TAX_SUPPLEMENTAL_SHARED_CONDITION_VARS,
Expand Down Expand Up @@ -683,7 +685,9 @@ def normalize_dividend_columns(frame: pd.DataFrame) -> pd.DataFrame:
component_total = qualified + non_qualified
normalized_total = component_total.where(component_total.ne(0.0), total)
elif has_qualified:
normalized_total = np.maximum(total.to_numpy(dtype=float), qualified.to_numpy(dtype=float))
normalized_total = np.maximum(
total.to_numpy(dtype=float), qualified.to_numpy(dtype=float)
)
non_qualified = pd.Series(
normalized_total - qualified.to_numpy(dtype=float),
index=result.index,
Expand Down Expand Up @@ -724,8 +728,12 @@ def normalize_social_security_columns(frame: pd.DataFrame) -> pd.DataFrame:
column: _nonnegative_series(result, column)
for column in SOCIAL_SECURITY_COMPONENT_COLUMNS
}
component_sum = sum(component_series.values(), start=pd.Series(0.0, index=result.index))
existing_unclassified = _nonnegative_series(result, SOCIAL_SECURITY_UNCLASSIFIED_COLUMN)
component_sum = sum(
component_series.values(), start=pd.Series(0.0, index=result.index)
)
existing_unclassified = _nonnegative_series(
result, SOCIAL_SECURITY_UNCLASSIFIED_COLUMN
)

if "social_security" in result.columns:
observed_total = _nonnegative_series(result, "social_security")
Expand All @@ -741,7 +749,8 @@ def normalize_social_security_columns(frame: pd.DataFrame) -> pd.DataFrame:
)
unclassified = pd.Series(
np.maximum(
normalized_total.to_numpy(dtype=float) - component_sum.to_numpy(dtype=float),
normalized_total.to_numpy(dtype=float)
- component_sum.to_numpy(dtype=float),
0.0,
),
index=result.index,
Expand Down Expand Up @@ -832,6 +841,7 @@ def restore_dividend_components_from_composition(frame: pd.DataFrame) -> pd.Data
restore_frame=restore_dividend_components_from_composition,
)


def variable_semantic_spec_for(variable_name: str) -> VariableSemanticSpec:
"""Return semantic metadata for one variable."""
return VARIABLE_SEMANTIC_SPECS.get(variable_name, VariableSemanticSpec())
Expand Down Expand Up @@ -915,9 +925,7 @@ def resolve_condition_entities_for_targets(
shared &= set(allowed_entities)
if not shared:
return (EntityType.HOUSEHOLD,)
return tuple(
entity for entity in allowed_by_target[0] if entity in shared
)
return tuple(entity for entity in allowed_by_target[0] if entity in shared)


def is_condition_var_compatible_with_targets(
Expand All @@ -942,15 +950,12 @@ def is_projected_condition_var_compatible(
condition_entity = variable_semantic_spec_for(condition_variable).native_entity
record_entity = getattr(EntityType, "RECORD", None)
allowed_entities = {
entity
for entity in allowed_condition_entities
if entity is not record_entity
entity for entity in allowed_condition_entities if entity is not record_entity
}
if condition_entity in allowed_entities:
return True
return (
condition_entity is EntityType.PERSON
and projected_entity in allowed_entities
condition_entity is EntityType.PERSON and projected_entity in allowed_entities
)


Expand All @@ -971,9 +976,7 @@ def donor_imputation_block_specs(
condition_entities=resolve_condition_entities_for_targets((variable,)),
model_variables=(variable,),
restored_variables=(variable,),
match_strategies={
variable: spec.donor_match_strategy
},
match_strategies={variable: spec.donor_match_strategy},
)
)
return tuple(block_specs)
Expand Down
51 changes: 51 additions & 0 deletions tests/pipelines/test_us.py
Original file line number Diff line number Diff line change
Expand Up @@ -3386,6 +3386,57 @@ def generate(self, frame, seed=None):
200.0,
}

def test_signed_zero_inflated_donor_vars_are_not_clamped(self, monkeypatch):
captured: dict[str, dict[str, object]] = {}

class FakeRegimeAwareDonorImputer:
def __init__(self, **kwargs):
captured["regime_aware"] = kwargs

class FakeQRFImputer:
def __init__(self, **kwargs):
captured["zi_qrf"] = kwargs

monkeypatch.setattr(
"microplex_us.pipelines.us.RegimeAwareDonorImputer",
FakeRegimeAwareDonorImputer,
)
monkeypatch.setattr(
"microplex_us.pipelines.us.ColumnwiseQRFDonorImputer",
FakeQRFImputer,
)

target_vars = ("partnership_s_corp_income", "public_assistance")

regime_pipeline = USMicroplexPipeline(
USMicroplexBuildConfig(
n_synthetic=4,
donor_imputer_backend="regime_aware",
)
)
regime_pipeline._build_donor_imputer(
condition_vars=["age"],
target_vars=target_vars,
)

qrf_pipeline = USMicroplexPipeline(
USMicroplexBuildConfig(
n_synthetic=4,
donor_imputer_backend="zi_qrf",
)
)
qrf_pipeline._build_donor_imputer(
condition_vars=["age"],
target_vars=target_vars,
)

assert captured["regime_aware"]["nonnegative_vars"] == {"public_assistance"}
assert captured["zi_qrf"]["nonnegative_vars"] == {"public_assistance"}
assert captured["zi_qrf"]["zero_inflated_vars"] == {
"partnership_s_corp_income",
"public_assistance",
}

def test_integrate_donor_sources_preserves_informative_scaffold_values(
self, monkeypatch
):
Expand Down
41 changes: 28 additions & 13 deletions tests/test_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,10 @@ def test_resolve_variable_semantic_capabilities_marks_redundant_dividend_totals(
def test_variable_semantics_define_projection_aggregation_for_person_controls():
from microplex_us.variables import variable_semantic_spec_for

assert EntityType.RECORD not in variable_semantic_spec_for("age").allowed_condition_entities
assert (
EntityType.RECORD
not in variable_semantic_spec_for("age").allowed_condition_entities
)
assert (
variable_semantic_spec_for("age").projection_aggregation
is ProjectionAggregation.MAX
Expand All @@ -310,7 +313,9 @@ def test_state_program_proxy_semantics_are_registered():

has_medicaid = variable_semantic_spec_for("has_medicaid")
assert has_medicaid.support_family is VariableSupportFamily.ZERO_INFLATED_POSITIVE
assert has_medicaid.donor_match_strategy is DonorMatchStrategy.ZERO_INFLATED_POSITIVE
assert (
has_medicaid.donor_match_strategy is DonorMatchStrategy.ZERO_INFLATED_POSITIVE
)
assert has_medicaid.condition_score_mode is ConditionScoreMode.VALUE_AND_SUPPORT
assert has_medicaid.projection_aggregation is ProjectionAggregation.MAX

Expand Down Expand Up @@ -340,7 +345,7 @@ def test_partnership_income_semantics_remain_person_native():

spec = variable_semantic_spec_for("partnership_s_corp_income")
assert spec.native_entity is EntityType.PERSON
assert spec.support_family is VariableSupportFamily.ZERO_INFLATED_POSITIVE
assert spec.support_family is VariableSupportFamily.ZERO_INFLATED_SIGNED
assert spec.donor_match_strategy is DonorMatchStrategy.ZERO_INFLATED_POSITIVE
assert spec.condition_score_mode is ConditionScoreMode.VALUE_AND_SUPPORT

Expand Down Expand Up @@ -381,23 +386,27 @@ def test_sparse_irs_tax_variables_use_puf_irs_predictors():

assert PUF_IRS_TAX_SUPPLEMENTAL_SHARED_CONDITION_VARS == ()
assert (
variable_semantic_spec_for("taxable_interest_income")
.challenger_shared_condition_vars
variable_semantic_spec_for(
"taxable_interest_income"
).challenger_shared_condition_vars
== PUF_DIVIDEND_INTEREST_CHALLENGER_SHARED_CONDITION_VARS
)
assert (
variable_semantic_spec_for("qualified_dividend_income")
.challenger_shared_condition_vars
variable_semantic_spec_for(
"qualified_dividend_income"
).challenger_shared_condition_vars
== PUF_DIVIDEND_INTEREST_CHALLENGER_SHARED_CONDITION_VARS
)
assert (
variable_semantic_spec_for("taxable_pension_income")
.challenger_shared_condition_vars
variable_semantic_spec_for(
"taxable_pension_income"
).challenger_shared_condition_vars
== PUF_PENSION_CHALLENGER_SHARED_CONDITION_VARS
)
assert (
variable_semantic_spec_for("partnership_s_corp_income")
.challenger_shared_condition_vars
variable_semantic_spec_for(
"partnership_s_corp_income"
).challenger_shared_condition_vars
== PUF_PARTNERSHIP_CHALLENGER_SHARED_CONDITION_VARS
)

Expand Down Expand Up @@ -434,7 +443,9 @@ def test_person_native_irs_semantics_match_current_policyengine_entities():
"student_loan_interest",
"self_employment_income",
):
assert variable_semantic_spec_for(variable_name).native_entity is EntityType.PERSON
assert (
variable_semantic_spec_for(variable_name).native_entity is EntityType.PERSON
)


def test_self_employment_income_semantics_preserve_signed_support():
Expand Down Expand Up @@ -487,7 +498,11 @@ def test_employment_income_donor_semantics_uses_unclassified_social_security_com
adjusted = apply_donor_variable_semantics(frame, ("employment_income",))

assert adjusted["social_security_retirement"].tolist() == [0.0, 0.0, 0.0]
assert adjusted["social_security_unclassified"].tolist() == [18_000.0, 18_000.0, 0.0]
assert adjusted["social_security_unclassified"].tolist() == [
18_000.0,
18_000.0,
0.0,
]
assert adjusted["employment_income"].tolist() == [0.0, 80_000.0, 80_000.0]


Expand Down
Loading