Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 115 additions & 40 deletions src/microplex_us/pipelines/ecps_replacement_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,14 +192,12 @@ def build_sound_ecps_replacement_comparison(
and candidate_score_error <= score_consistency_tol
and baseline_score_error <= score_consistency_tol
)
ecps_refit_recovery_passed = (
baseline_refit["optimized_full_loss"]
ecps_refit_recovery_passed = baseline_refit[
"optimized_full_loss"
] <= baseline_refit["initial_full_loss"] + score_consistency_tol and (
baseline_score_loss is None
or baseline_score_loss
<= baseline_refit["initial_full_loss"] + score_consistency_tol
and (
baseline_score_loss is None
or baseline_score_loss
<= baseline_refit["initial_full_loss"] + score_consistency_tol
)
)

protected_family_losses = _protected_family_losses(
Expand Down Expand Up @@ -232,6 +230,10 @@ def build_sound_ecps_replacement_comparison(
support_audit_summary = (
_support_audit_summary(support_audit) if support_audit is not None else None
)
candidate_optimizer_summary = dict(candidate_refit["optimizer_summary"])
baseline_optimizer_summary = dict(baseline_refit["optimizer_summary"])
candidate_refit_progress = _refit_progress_summary(candidate_refit["loss_curve"])
baseline_refit_progress = _refit_progress_summary(baseline_refit["loss_curve"])

score_summary.update(
{
Expand All @@ -252,6 +254,30 @@ def build_sound_ecps_replacement_comparison(
"baseline_score_abs_error": baseline_score_error,
"candidate_refit_config": refit_config,
"baseline_refit_config": refit_config,
"candidate_refit_converged": bool(
candidate_optimizer_summary.get("converged", False)
),
"baseline_refit_converged": bool(
baseline_optimizer_summary.get("converged", False)
),
"candidate_refit_iterations": int(
candidate_optimizer_summary.get("iterations", 0)
),
"baseline_refit_iterations": int(
baseline_optimizer_summary.get("iterations", 0)
),
"candidate_refit_train_loss_improvement_last_step": (
candidate_refit_progress["train_loss_improvement_last_step"]
),
"baseline_refit_train_loss_improvement_last_step": (
baseline_refit_progress["train_loss_improvement_last_step"]
),
"candidate_refit_train_loss_improvement_last_20_steps": (
candidate_refit_progress["train_loss_improvement_last_20_steps"]
),
"baseline_refit_train_loss_improvement_last_20_steps": (
baseline_refit_progress["train_loss_improvement_last_20_steps"]
),
"symmetric_refit": True,
"score_candidate_only": False,
"refit_objective_matches_scoring": objective_identity_passed,
Expand Down Expand Up @@ -282,6 +308,12 @@ def build_sound_ecps_replacement_comparison(
"score_candidate_only": False,
"refit_objective_matches_scoring": objective_identity_passed,
"ecps_refit_recovery_passed": ecps_refit_recovery_passed,
"candidate_refit_converged": bool(
candidate_optimizer_summary.get("converged", False)
),
"baseline_refit_converged": bool(
baseline_optimizer_summary.get("converged", False)
),
"holdout_target_fraction": float(holdout_target_fraction),
"holdout_targets": int(holdout_mask.sum()),
"protected_family_losses": protected_family_losses,
Expand Down Expand Up @@ -324,7 +356,9 @@ def build_sound_ecps_replacement_comparison(
"train_targets": int((~holdout_mask).sum()),
"holdout_targets": int(holdout_mask.sum()),
"holdout_target_names": [
name for name, holdout in zip(target_names, holdout_mask, strict=True) if holdout
name
for name, holdout in zip(target_names, holdout_mask, strict=True)
if holdout
],
},
"refit_config": refit_config,
Expand Down Expand Up @@ -386,7 +420,9 @@ def _write_matched_dataset(
force: bool,
) -> None:
if output_path.exists() and not force:
raise FileExistsError(f"{output_path} already exists; pass --force to replace it")
raise FileExistsError(
f"{output_path} already exists; pass --force to replace it"
)
_write_matched_policyengine_us_baseline_dataset(
input_path,
output_path,
Expand Down Expand Up @@ -438,9 +474,7 @@ def _entity_structure_summary(
period_key,
)
if person_ids.shape[0] != person_household_ids.shape[0]:
raise ValueError(
f"{path} person_id and person_household_id lengths differ"
)
raise ValueError(f"{path} person_id and person_household_id lengths differ")

household_count = int(household_ids.shape[0])
summary: dict[str, Any] = {
Expand Down Expand Up @@ -500,8 +534,7 @@ def _entity_membership_summary(
)
if person_entity_ids.shape[0] != person_household_ids.shape[0]:
raise ValueError(
f"{dataset_path} person_{entity}_id and person_household_id "
"lengths differ"
f"{dataset_path} person_{entity}_id and person_household_id lengths differ"
)
unique_entity_ids = np.unique(entity_ids)
duplicate_unit_id_count = int(entity_ids.shape[0] - unique_entity_ids.shape[0])
Expand Down Expand Up @@ -602,8 +635,10 @@ def _extract_pe_native_loss_inputs(
check=False,
)
if completed.returncode != 0:
detail = completed.stderr.strip() or completed.stdout.strip() or str(
completed.returncode
detail = (
completed.stderr.strip()
or completed.stdout.strip()
or str(completed.returncode)
)
raise RuntimeError(f"PE-native loss-matrix extraction failed: {detail}")
return {
Expand Down Expand Up @@ -761,6 +796,24 @@ def _objective(matrix: np.ndarray, target: np.ndarray, weights: np.ndarray) -> f
return float(np.dot(residual, residual))


def _refit_progress_summary(
loss_curve: list[dict[str, Any]],
) -> dict[str, float | None]:
if len(loss_curve) < 2:
return {
"train_loss_improvement_last_step": None,
"train_loss_improvement_last_20_steps": None,
}
last_train_loss = float(loss_curve[-1]["train_loss"])
previous_train_loss = float(loss_curve[-2]["train_loss"])
lookback_index = max(0, len(loss_curve) - 21)
lookback_train_loss = float(loss_curve[lookback_index]["train_loss"])
return {
"train_loss_improvement_last_step": previous_train_loss - last_train_loss,
"train_loss_improvement_last_20_steps": lookback_train_loss - last_train_loss,
}


def _protected_family_losses(
*,
target_names: list[str],
Expand All @@ -771,7 +824,6 @@ def _protected_family_losses(
) -> dict[str, dict[str, float | int]]:
candidate_terms = _loss_terms(candidate_inputs, candidate_weights)
baseline_terms = _loss_terms(baseline_inputs, baseline_weights)
n_targets = float(len(target_names))
rows: dict[str, dict[str, float | int]] = {}
for family, patterns in _PROTECTED_TARGET_PATTERNS.items():
indices = [
Expand All @@ -781,8 +833,8 @@ def _protected_family_losses(
]
if not indices:
continue
candidate_loss = float(candidate_terms[indices].sum() / n_targets)
baseline_loss = float(baseline_terms[indices].sum() / n_targets)
candidate_loss = float(candidate_terms[indices].sum())
baseline_loss = float(baseline_terms[indices].sum())
rows[family] = {
"n_targets": int(len(indices)),
"candidate_loss": candidate_loss,
Expand Down Expand Up @@ -861,6 +913,12 @@ def _target_loss_diagnostics(
"baseline_relative_error": float(
baseline_values["relative_error"][index]
),
"candidate_shifted_residual_ratio": float(
candidate_values["relative_error"][index]
),
"baseline_shifted_residual_ratio": float(
baseline_values["relative_error"][index]
),
"candidate_loss_term": candidate_loss,
"baseline_loss_term": baseline_loss,
"loss_delta": float(loss_delta),
Expand Down Expand Up @@ -921,9 +979,9 @@ def _target_value_diagnostics(
target[native_mask] = np.asarray(unscaled_target, dtype=np.float64)[
native_mask
]
estimate[native_mask] = scaled_estimate[native_mask] / scaling_array[
native_mask
]
estimate[native_mask] = (
scaled_estimate[native_mask] / scaling_array[native_mask]
)
value_scale[native_mask] = "native"
if target.shape != estimate.shape:
raise ValueError("target and estimate shapes differ")
Expand All @@ -945,7 +1003,6 @@ def _target_family_breakdown(
families: dict[str, list[dict[str, Any]]] = {}
for row in target_rows:
families.setdefault(str(row["family"]), []).append(row)
denominator = float(total_targets) if total_targets else 1.0
breakdown = []
for family, rows in sorted(families.items()):
candidate_loss = sum(float(row["candidate_loss_term"]) for row in rows)
Expand All @@ -954,19 +1011,13 @@ def _target_family_breakdown(
{
"family": family,
"n_targets": int(len(rows)),
"train_targets": int(
sum(1 for row in rows if row["split"] == "train")
),
"train_targets": int(sum(1 for row in rows if row["split"] == "train")),
"holdout_targets": int(
sum(1 for row in rows if row["split"] == "holdout")
),
"candidate_loss_contribution": float(
candidate_loss / denominator
),
"baseline_loss_contribution": float(baseline_loss / denominator),
"loss_delta": float(
(candidate_loss - baseline_loss) / denominator
),
"candidate_loss_contribution": float(candidate_loss),
"baseline_loss_contribution": float(baseline_loss),
"loss_delta": float(candidate_loss - baseline_loss),
"candidate_wins": int(
sum(1 for row in rows if row["winner"] == "candidate")
),
Expand All @@ -976,7 +1027,9 @@ def _target_family_breakdown(
"ties": int(sum(1 for row in rows if row["winner"] == "tie")),
}
)
return sorted(breakdown, key=lambda row: abs(float(row["loss_delta"])), reverse=True)
return sorted(
breakdown, key=lambda row: abs(float(row["loss_delta"])), reverse=True
)


def _support_audit_summary(support_audit: dict[str, Any]) -> dict[str, Any]:
Expand Down Expand Up @@ -1008,10 +1061,27 @@ def _support_audit_summary(support_audit: dict[str, Any]) -> dict[str, Any]:
"top_medicare_part_b_by_age_gaps": _sort_rows_by_abs_delta(
list(comparisons.get("medicare_part_b_premiums_by_age_delta") or ()),
"weighted_positive_delta",
drop_zero=True,
),
"top_aca_ptc_spending_gaps": _sort_rows_by_abs_delta(
list(comparisons.get("state_aca_ptc_spending_top_gaps") or ()),
"weighted_aca_ptc_delta",
drop_zero=True,
),
"top_state_marketplace_enrollment_gaps": _sort_rows_by_abs_delta(
list(comparisons.get("state_marketplace_enrollment_top_gaps") or ()),
"weighted_marketplace_enrollment_delta",
drop_zero=True,
),
"top_state_age_bucket_gaps": _sort_rows_by_abs_delta(
list(comparisons.get("state_age_bucket_top_gaps") or ()),
"weighted_count_delta",
drop_zero=True,
),
"top_mfs_high_agi_gaps": _sort_rows_by_abs_delta(
list(comparisons.get("mfs_high_agi_delta") or ()),
"weighted_count_delta",
drop_zero=True,
),
}

Expand All @@ -1021,12 +1091,20 @@ def _sort_rows_by_abs_delta(
delta_key: str,
*,
limit: int = 10,
drop_zero: bool = False,
) -> list[dict[str, Any]]:
return sorted(
sorted_rows = sorted(
rows,
key=lambda row: abs(float(row.get(delta_key, 0.0))),
reverse=True,
)[:limit]
)
if drop_zero:
sorted_rows = [
row
for row in sorted_rows
if not np.isclose(float(row.get(delta_key, 0.0)), 0.0)
]
return sorted_rows[:limit]


def _loss_terms(loss_inputs: dict[str, Any], weights: np.ndarray) -> np.ndarray:
Expand All @@ -1042,10 +1120,7 @@ def _target_matches_protected_family(
patterns: tuple[str, ...],
) -> bool:
normalized = (
target_name.lower()
.replace("-", "_")
.replace(" ", "_")
.replace("/", "_")
target_name.lower().replace("-", "_").replace(" ", "_").replace("/", "_")
)
if family == "wages" and (
"self_employment" in normalized or "business_income" in normalized
Expand Down
21 changes: 15 additions & 6 deletions src/microplex_us/pipelines/pe_native_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def _project_to_simplex(values: np.ndarray, total: float) -> np.ndarray:
return values.copy()
clipped = np.maximum(values.astype(np.float64, copy=False), 0.0)
current_sum = float(clipped.sum())
if np.isclose(current_sum, total):
if np.isclose(current_sum, total, rtol=0.0, atol=1e-6):
return clipped
if total <= 0.0:
return np.zeros_like(clipped)
Expand Down Expand Up @@ -246,7 +246,9 @@ def optimize_pe_native_loss_weights(

initial_weight_sum = float(weights0.sum())
total_weight = (
float(target_total_weight) if target_total_weight is not None else initial_weight_sum
float(target_total_weight)
if target_total_weight is not None
else initial_weight_sum
)
weights = _project_to_budget_simplex(weights0, total_weight, budget)
initial_reference = weights.copy()
Expand Down Expand Up @@ -358,7 +360,9 @@ def rewrite_policyengine_us_dataset_weights(
with h5py.File(output, "r+") as handle:
household_ids = handle["household_id"][period_key][:]
if len(household_ids) != len(weights):
raise ValueError("household_weights length does not match household_id array")
raise ValueError(
"household_weights length does not match household_id array"
)
household_map = {
int(household_id): float(weight)
for household_id, weight in zip(household_ids, weights, strict=True)
Expand All @@ -368,7 +372,10 @@ def rewrite_policyengine_us_dataset_weights(
if "person_weight" in handle and "person_household_id" in handle:
person_households = handle["person_household_id"][period_key][:]
person_weights = np.array(
[household_map[int(household_id)] for household_id in person_households],
[
household_map[int(household_id)]
for household_id in person_households
],
dtype=np.float32,
)
handle["person_weight"][period_key][...] = person_weights
Expand Down Expand Up @@ -448,8 +455,10 @@ def optimize_policyengine_us_native_loss_dataset(
check=False,
)
if completed.returncode != 0:
detail = completed.stderr.strip() or completed.stdout.strip() or str(
completed.returncode
detail = (
completed.stderr.strip()
or completed.stdout.strip()
or str(completed.returncode)
)
raise RuntimeError(f"PE-native loss-matrix extraction failed: {detail}")

Expand Down
Loading
Loading