diff --git a/.github/workflows/local_area_publish.yaml b/.github/workflows/local_area_publish.yaml index 44675e63e..9fa86174d 100644 --- a/.github/workflows/local_area_publish.yaml +++ b/.github/workflows/local_area_publish.yaml @@ -4,7 +4,7 @@ on: push: branches: [main] paths: - - 'policyengine_us_data/datasets/cps/local_area_calibration/**' + - 'policyengine_us_data/calibration/**' - '.github/workflows/local_area_publish.yaml' - 'modal_app/**' repository_dispatch: @@ -23,7 +23,7 @@ on: type: boolean # Trigger strategy: -# 1. Automatic: Code changes to local_area_calibration/ pushed to main +# 1. Automatic: Code changes to calibration/ pushed to main # 2. repository_dispatch: Calibration workflow triggers after uploading new weights # 3. workflow_dispatch: Manual trigger with optional parameters @@ -55,7 +55,7 @@ jobs: SKIP_UPLOAD="${{ github.event.inputs.skip_upload || 'false' }}" BRANCH="${{ github.head_ref || github.ref_name }}" - CMD="modal run modal_app/local_area.py --branch=${BRANCH} --num-workers=${NUM_WORKERS}" + CMD="modal run modal_app/local_area.py::main --branch=${BRANCH} --num-workers=${NUM_WORKERS}" if [ "$SKIP_UPLOAD" = "true" ]; then CMD="${CMD} --skip-upload" @@ -71,5 +71,60 @@ jobs: echo "" >> $GITHUB_STEP_SUMMARY echo "Files have been uploaded to GCS and staged on HuggingFace." >> $GITHUB_STEP_SUMMARY echo "" >> $GITHUB_STEP_SUMMARY - echo "### Next step: Promote to production" >> $GITHUB_STEP_SUMMARY - echo "Trigger the **Promote Local Area H5 Files** workflow with the version from the build output." >> $GITHUB_STEP_SUMMARY + echo "### Next step: Validation runs automatically" >> $GITHUB_STEP_SUMMARY + echo "The validate-staging job will now check all staged H5s." 
>> $GITHUB_STEP_SUMMARY + + validate-staging: + needs: publish-local-area + runs-on: ubuntu-latest + env: + HUGGING_FACE_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }} + steps: + - name: Checkout repo + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.13' + + - name: Set up uv + uses: astral-sh/setup-uv@v5 + + - name: Install dependencies + run: uv sync + + - name: Validate staged H5s + run: | + uv run python -m policyengine_us_data.calibration.validate_staging \ + --area-type states --output validation_results.csv + + - name: Upload validation results to HF + run: | + uv run python -c " + from policyengine_us_data.utils.huggingface import upload + upload('validation_results.csv', + 'policyengine/policyengine-us-data', + 'calibration/logs/validation_results.csv') + " + + - name: Post validation summary + if: always() + run: | + echo "## Validation Results" >> $GITHUB_STEP_SUMMARY + if [ -f validation_results.csv ]; then + TOTAL=$(tail -n +2 validation_results.csv | wc -l) + FAILS=$(grep -c ',FAIL,' validation_results.csv || true) + echo "- **${TOTAL}** targets validated" >> $GITHUB_STEP_SUMMARY + echo "- **${FAILS}** sanity failures" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "Review in dashboard, then trigger **Promote** workflow." >> $GITHUB_STEP_SUMMARY + else + echo "Validation did not produce output." 
>> $GITHUB_STEP_SUMMARY + fi + + - name: Upload validation artifact + uses: actions/upload-artifact@v4 + with: + name: validation-results + path: validation_results.csv diff --git a/.gitignore b/.gitignore index a7ab98c9a..5418f2090 100644 --- a/.gitignore +++ b/.gitignore @@ -30,12 +30,12 @@ docs/.ipynb_checkpoints/ ## ACA PTC state-level uprating factors !policyengine_us_data/storage/aca_ptc_multipliers_2022_2024.csv -## Raw input cache for database pipeline -policyengine_us_data/storage/calibration/raw_inputs/ +## Calibration run outputs (weights, diagnostics, packages, config) +policyengine_us_data/storage/calibration/ ## Batch processing checkpoints completed_*.txt ## Test fixtures -!policyengine_us_data/tests/test_local_area_calibration/test_fixture_50hh.h5 +!policyengine_us_data/tests/test_calibration/test_fixture_50hh.h5 oregon_ctc_analysis.py diff --git a/Makefile b/Makefile index b34b8eb60..c4e7ba541 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,12 @@ -.PHONY: all format test install download upload docker documentation data validate-data calibrate publish-local-area clean build paper clean-paper presentations database database-refresh promote-database promote-dataset +.PHONY: all format test install download upload docker documentation data validate-data calibrate calibrate-build publish-local-area upload-calibration upload-dataset upload-database push-to-modal build-matrices calibrate-modal calibrate-modal-national calibrate-both stage-h5s stage-national-h5 stage-all-h5s pipeline validate-staging validate-staging-full upload-validation check-staging check-sanity clean build paper clean-paper presentations database database-refresh promote-database promote-dataset promote build-h5s validate-local + +GPU ?= A100-80GB +EPOCHS ?= 200 +NATIONAL_GPU ?= T4 +NATIONAL_EPOCHS ?= 200 +BRANCH ?= $(shell git rev-parse --abbrev-ref HEAD) +NUM_WORKERS ?= 8 +VERSION ?= HF_CLONE_DIR ?= $(HOME)/huggingface/policyengine-us-data @@ -79,8 +87,8 @@ promote-database: @echo 
"Copied DB and raw_inputs to HF clone. Now cd to HF repo, commit, and push." promote-dataset: - cp policyengine_us_data/storage/stratified_extended_cps_2024.h5 \ - $(HF_CLONE_DIR)/calibration/stratified_extended_cps.h5 + cp policyengine_us_data/storage/source_imputed_stratified_extended_cps_2024.h5 \ + $(HF_CLONE_DIR)/calibration/source_imputed_stratified_extended_cps.h5 @echo "Copied dataset to HF clone. Now cd to HF repo, commit, and push." data: download @@ -90,20 +98,147 @@ data: download python policyengine_us_data/datasets/puf/irs_puf.py python policyengine_us_data/datasets/puf/puf.py python policyengine_us_data/datasets/cps/extended_cps.py + python policyengine_us_data/calibration/create_stratified_cps.py + python policyengine_us_data/calibration/create_source_imputed_cps.py + +data-legacy: data python policyengine_us_data/datasets/cps/enhanced_cps.py python policyengine_us_data/datasets/cps/small_enhanced_cps.py - python policyengine_us_data/datasets/cps/local_area_calibration/create_stratified_cps.py calibrate: data python -m policyengine_us_data.calibration.unified_calibration \ - --puf-dataset policyengine_us_data/storage/puf_2024.h5 + --target-config policyengine_us_data/calibration/target_config.yaml + +calibrate-build: data + python -m policyengine_us_data.calibration.unified_calibration \ + --target-config policyengine_us_data/calibration/target_config.yaml \ + --build-only + +validate-package: + python -m policyengine_us_data.calibration.validate_package publish-local-area: - python policyengine_us_data/datasets/cps/local_area_calibration/publish_local_area.py + python policyengine_us_data/calibration/publish_local_area.py --upload + +build-h5s: + python -m policyengine_us_data.calibration.publish_local_area \ + --weights-path policyengine_us_data/storage/calibration/calibration_weights.npy \ + --dataset-path policyengine_us_data/storage/source_imputed_stratified_extended_cps_2024.h5 \ + --db-path 
policyengine_us_data/storage/calibration/policy_data.db \ + --calibration-blocks policyengine_us_data/storage/calibration/stacked_blocks.npy \ + --stacked-takeup policyengine_us_data/storage/calibration/stacked_takeup.npz \ + --states-only + +validate-local: + python -m policyengine_us_data.calibration.validate_staging \ + --hf-prefix local_area_build \ + --area-type states --output validation_results.csv validate-data: python -c "from policyengine_us_data.storage.upload_completed_datasets import validate_all_datasets; validate_all_datasets()" +upload-calibration: + python -c "from policyengine_us_data.utils.huggingface import upload_calibration_artifacts; \ + upload_calibration_artifacts()" + +upload-dataset: + python -c "from policyengine_us_data.utils.huggingface import upload; \ + upload('policyengine_us_data/storage/source_imputed_stratified_extended_cps_2024.h5', \ + 'policyengine/policyengine-us-data', \ + 'calibration/source_imputed_stratified_extended_cps.h5')" + @echo "Dataset uploaded to HF." + +upload-database: + python -c "from policyengine_us_data.utils.huggingface import upload; \ + upload('policyengine_us_data/storage/calibration/policy_data.db', \ + 'policyengine/policyengine-us-data', \ + 'calibration/policy_data.db')" + @echo "Database uploaded to HF." 
+ +push-to-modal: + modal volume put local-area-staging \ + policyengine_us_data/storage/calibration/calibration_weights.npy \ + calibration_inputs/calibration/calibration_weights.npy --force + modal volume put local-area-staging \ + policyengine_us_data/storage/calibration/stacked_blocks.npy \ + calibration_inputs/calibration/stacked_blocks.npy --force + modal volume put local-area-staging \ + policyengine_us_data/storage/calibration/stacked_takeup.npz \ + calibration_inputs/calibration/stacked_takeup.npz --force + modal volume put local-area-staging \ + policyengine_us_data/storage/calibration/policy_data.db \ + calibration_inputs/calibration/policy_data.db --force + modal volume put local-area-staging \ + policyengine_us_data/storage/calibration/geo_labels.json \ + calibration_inputs/calibration/geo_labels.json --force + modal volume put local-area-staging \ + policyengine_us_data/storage/source_imputed_stratified_extended_cps_2024.h5 \ + calibration_inputs/calibration/source_imputed_stratified_extended_cps.h5 --force + @echo "All calibration inputs pushed to Modal volume." 
+ +build-matrices: + modal run modal_app/remote_calibration_runner.py::build_package \ + --branch $(BRANCH) + +calibrate-modal: + modal run modal_app/remote_calibration_runner.py::main \ + --branch $(BRANCH) --gpu $(GPU) --epochs $(EPOCHS) \ + --push-results + +calibrate-modal-national: + modal run modal_app/remote_calibration_runner.py::main \ + --branch $(BRANCH) --gpu $(NATIONAL_GPU) \ + --epochs $(NATIONAL_EPOCHS) \ + --push-results --national + +calibrate-both: + $(MAKE) calibrate-modal & $(MAKE) calibrate-modal-national & wait + +stage-h5s: + modal run modal_app/local_area.py::main \ + --branch $(BRANCH) --num-workers $(NUM_WORKERS) \ + $(if $(SKIP_DOWNLOAD),--skip-download) + +stage-national-h5: + modal run modal_app/local_area.py::main_national \ + --branch $(BRANCH) + +stage-all-h5s: + $(MAKE) stage-h5s & $(MAKE) stage-national-h5 & wait + +promote: + $(eval VERSION := $(or $(VERSION),$(shell python -c "import tomllib; print(tomllib.load(open('pyproject.toml','rb'))['project']['version'])"))) + modal run modal_app/local_area.py::main_promote \ + --branch $(BRANCH) --version $(VERSION) + +validate-staging: + python -m policyengine_us_data.calibration.validate_staging \ + --area-type states --output validation_results.csv + +validate-staging-full: + python -m policyengine_us_data.calibration.validate_staging \ + --area-type states,districts --output validation_results.csv + +upload-validation: + python -c "from policyengine_us_data.utils.huggingface import upload; \ + upload('validation_results.csv', \ + 'policyengine/policyengine-us-data', \ + 'calibration/logs/validation_results.csv')" + +check-staging: + python -m policyengine_us_data.calibration.check_staging_sums + +check-sanity: + python -m policyengine_us_data.calibration.validate_staging \ + --sanity-only --area-type states --areas NC + +pipeline: data upload-dataset build-matrices calibrate-both stage-all-h5s + @echo "" + @echo "========================================" + @echo "Pipeline complete. 
H5s are in HF staging." + @echo "Run 'Promote Local Area H5 Files' workflow in GitHub to publish." + @echo "========================================" + clean: rm -f policyengine_us_data/storage/*.h5 rm -f policyengine_us_data/storage/*.db diff --git a/changelog.d/add-database-build-test.added.md b/changelog.d/add-database-build-test.added.md new file mode 100644 index 000000000..27661ea66 --- /dev/null +++ b/changelog.d/add-database-build-test.added.md @@ -0,0 +1 @@ +Add end-to-end test for calibration database build pipeline. diff --git a/changelog.d/calibration-pipeline-improvements.added.md b/changelog.d/calibration-pipeline-improvements.added.md new file mode 100644 index 000000000..6f6a34158 --- /dev/null +++ b/changelog.d/calibration-pipeline-improvements.added.md @@ -0,0 +1,8 @@ +Unified calibration pipeline with GPU-accelerated L1/L0 solver, target config YAML, and CLI package validator. +Per-state and per-county precomputation replacing per-clone Microsimulation (51 sims instead of 436). +Parallel state, county, and clone loop processing via ProcessPoolExecutor. +Block-level takeup re-randomization with deterministic seeded draws. +Hierarchical uprating with ACA PTC state-level CSV factors and CD reconciliation. +Modal remote runner with Volume support, CUDA OOM fixes, and checkpointing. +Stacked dataset builder with sparse CD subsets and calibration block propagation. +Staging validation script (validate_staging.py) with sim.calculate() comparison and sanity checks. diff --git a/changelog.d/calibration-pipeline-improvements.changed.md b/changelog.d/calibration-pipeline-improvements.changed.md new file mode 100644 index 000000000..492640977 --- /dev/null +++ b/changelog.d/calibration-pipeline-improvements.changed.md @@ -0,0 +1,3 @@ +Geography assignment now prevents clone-to-CD collisions. +County-dependent vars (aca_ptc) selectively precomputed per county; other vars use state-only path. +Target config switched to finest-grain include mode (~18K targets). 
diff --git a/changelog.d/calibration-pipeline-improvements.fixed.md b/changelog.d/calibration-pipeline-improvements.fixed.md new file mode 100644 index 000000000..c935ce0bc --- /dev/null +++ b/changelog.d/calibration-pipeline-improvements.fixed.md @@ -0,0 +1,3 @@ +Cross-state cache pollution in matrix builder precomputation. +Takeup draw ordering mismatch between matrix builder and stacked builder. +At-large district geoid mismatch (7 districts had 0 estimates). diff --git a/changelog.d/migrate-to-towncrier.changed.md b/changelog.d/migrate-to-towncrier.changed.md new file mode 100644 index 000000000..865484add --- /dev/null +++ b/changelog.d/migrate-to-towncrier.changed.md @@ -0,0 +1 @@ +Migrated from changelog_entry.yaml to towncrier fragments to eliminate merge conflicts. diff --git a/docs/build_h5.md b/docs/build_h5.md new file mode 100644 index 000000000..513680b0b --- /dev/null +++ b/docs/build_h5.md @@ -0,0 +1,123 @@ +# build_h5 — Unified H5 Builder + +`build_h5` is the single function that produces all local-area H5 datasets (national, state, district, city). It lives in `policyengine_us_data/calibration/publish_local_area.py`. + +## Signature + +```python +def build_h5( + weights: np.ndarray, + blocks: np.ndarray, + dataset_path: Path, + output_path: Path, + cds_to_calibrate: List[str], + cd_subset: List[str] = None, + county_filter: set = None, + rerandomize_takeup: bool = False, + takeup_filter: List[str] = None, +) -> Path: +``` + +## Parameter Semantics + +| Parameter | Type | Purpose | +|---|---|---| +| `weights` | `np.ndarray` | Stacked weight vector, shape `(n_geo * n_hh,)` | +| `blocks` | `np.ndarray` | Block GEOID per weight entry (same shape). If `None`, generated from CD assignments. 
| +| `dataset_path` | `Path` | Path to base dataset H5 file | +| `output_path` | `Path` | Where to write the output H5 file | +| `cds_to_calibrate` | `List[str]` | Ordered list of CD GEOIDs defining weight matrix row ordering | +| `cd_subset` | `List[str]` | If provided, only include rows for these CDs | +| `county_filter` | `set` | If provided, scale weights by P(target counties \| CD) for city datasets | +| `rerandomize_takeup` | `bool` | Re-draw takeup using block-level seeds | +| `takeup_filter` | `List[str]` | List of takeup variables to re-randomize | + +## How `cd_subset` Controls Output Level + +The `cd_subset` parameter determines what geographic level the output represents: + +- **National** (`cd_subset=None`): All CDs included — produces a full national dataset. +- **State** (`cd_subset=[CDs in state]`): Filter to CDs whose FIPS prefix matches the state — produces a state dataset. +- **District** (`cd_subset=[single_cd]`): Single CD — produces a district dataset. +- **City** (`cd_subset=[NYC CDs]` + `county_filter=NYC_COUNTIES`): Multiple CDs with county filtering — produces a city dataset. The `county_filter` scales weights by the probability that a household in each CD falls within the target counties. + +## Internal Pipeline + +1. **Load base simulation** — One `Microsimulation` loaded from `dataset_path`. Entity arrays and membership mappings extracted. + +2. **Reshape weights** — The flat weight vector is reshaped to `(n_geo, n_hh)`. + +3. **CD subset filtering** — Rows for CDs not in `cd_subset` are zeroed out. + +4. **County filtering** — If `county_filter` is set, each row is scaled by `P(target_counties | CD)` via `get_county_filter_probability()`. + +5. **Identify active clones** — `np.where(W > 0)` finds all nonzero entries. Each represents a distinct household clone. + +6. **Clone entity arrays** — Entity arrays (household, person, tax_unit, spm_unit, family, marital_unit) are cloned using fancy indexing on the base simulation arrays. + +7. 
**Reindex entity IDs** — All entity IDs are reassigned to be globally unique. Cross-reference arrays (e.g., `person_household_id`) are updated accordingly. + +8. **Derive geography** — Block GEOIDs are mapped to state FIPS, county, tract, CBSA, etc. via `derive_geography_from_blocks()`. Unique blocks are deduplicated for efficiency. + +9. **Recalculate SPM thresholds** — SPM thresholds are recomputed using `calculate_spm_thresholds_vectorized()` with the clone's CD-level geographic adjustment factor. + +10. **Rerandomize takeup** (optional) — If enabled, takeup booleans are redrawn per census block using `apply_block_takeup_to_arrays()`. + +11. **Write H5** — All variable arrays are written to the output file. + +## Usage Examples + +### National +```python +build_h5( + weights=w, + blocks=blocks, + dataset_path=Path("base.h5"), + output_path=Path("national/US.h5"), + cds_to_calibrate=cds, +) +``` + +### State +```python +state_fips = 6 # California +cd_subset = [cd for cd in cds if int(cd) // 100 == state_fips] +build_h5( + weights=w, + blocks=blocks, + dataset_path=Path("base.h5"), + output_path=Path("states/CA.h5"), + cds_to_calibrate=cds, + cd_subset=cd_subset, +) +``` + +### District +```python +build_h5( + weights=w, + blocks=blocks, + dataset_path=Path("base.h5"), + output_path=Path("districts/CA-12.h5"), + cds_to_calibrate=cds, + cd_subset=["0612"], +) +``` + +### City (NYC) +```python +from policyengine_us_data.calibration.publish_local_area import ( + NYC_COUNTIES, NYC_CDS, +) + +cd_subset = [cd for cd in cds if cd in NYC_CDS] +build_h5( + weights=w, + blocks=blocks, + dataset_path=Path("base.h5"), + output_path=Path("cities/NYC.h5"), + cds_to_calibrate=cds, + cd_subset=cd_subset, + county_filter=NYC_COUNTIES, +) +``` diff --git a/docs/calibration.md b/docs/calibration.md new file mode 100644 index 000000000..131c2a8e5 --- /dev/null +++ b/docs/calibration.md @@ -0,0 +1,497 @@ +# Calibration Pipeline User's Manual + +The unified calibration pipeline 
reweights cloned CPS records to match administrative targets using L0-regularized optimization. This guide covers the main workflows: lightweight build-then-fit, full pipeline with PUF, and fitting from a saved package. + +## Quick Start + +```bash +# Build matrix only from stratified CPS (no PUF, no re-imputation): +python -m policyengine_us_data.calibration.unified_calibration \ + --target-config policyengine_us_data/calibration/target_config.yaml \ + --skip-source-impute \ + --skip-takeup-rerandomize \ + --build-only + +# Build + fit in one shot (the Makefile target passes only --target-config, not --puf-dataset): +make calibrate +``` + +## Architecture Overview + +The pipeline has two phases: + +1. **Matrix build**: Clone CPS records, assign geography, compute all target variable values, assemble a sparse calibration matrix. Optionally includes PUF cloning (doubles record count) and source re-imputation. +2. **Weight fitting** (~5-20 min on GPU): L0-regularized optimization to find household weights that reproduce administrative targets. + +The calibration package checkpoint lets you run phase 1 once and iterate on phase 2 with different hyperparameters or target selections---without rebuilding. + +### Prerequisites + +The matrix build requires two inputs from the data pipeline: + +- **Stratified CPS** (`storage/stratified_extended_cps_2024.h5`): ~12K households, built by `make data`. This is the base dataset that gets cloned. +- **Target database** (`storage/calibration/policy_data.db`): Administrative targets, built by `make database`. + +Both must exist before running calibration. The stratified CPS already contains all CPS variables needed for calibration; PUF cloning and source re-imputation are optional enhancements that happen at calibration time. + +## Workflows + +### 1. 
Lightweight build-then-fit (recommended for iteration) + +Build the matrix from the stratified CPS without PUF cloning or re-imputation. This is the fastest way to get a calibration package for experimentation. + +**Step 1: Build the matrix (~12K base records x 436 clones = ~5.2M columns).** + +```bash +python -m policyengine_us_data.calibration.unified_calibration \ + --target-config policyengine_us_data/calibration/target_config.yaml \ + --skip-source-impute \ + --skip-takeup-rerandomize \ + --build-only +``` + +This saves `storage/calibration/calibration_package.pkl` (default location). Use `--package-output` to specify a different path. + +**Step 2: Fit weights from the package (fast, repeatable).** + +```bash +python -m policyengine_us_data.calibration.unified_calibration \ + --package-path storage/calibration/calibration_package.pkl \ + --epochs 1000 \ + --lambda-l0 1e-8 \ + --beta 0.65 \ + --lambda-l2 1e-8 \ + --device cuda +``` + +You can re-run Step 2 as many times as you want with different hyperparameters. The expensive matrix build only happens once. + +### 2. Full pipeline with PUF + +Adding `--puf-dataset` doubles the record count (~24K base records x 436 clones = ~10.4M columns) by creating PUF-imputed copies of every CPS record. This also triggers source re-imputation unless skipped. 
+ +**Single-pass (build + fit):** + +```bash +python -m policyengine_us_data.calibration.unified_calibration \ + --puf-dataset policyengine_us_data/storage/puf_2024.h5 \ + --target-config policyengine_us_data/calibration/target_config.yaml \ + --epochs 200 \ + --device cuda +``` + +Note: `make calibrate` runs the same build-and-fit entry point but passes only `--target-config` (no `--puf-dataset`, `--epochs`, or `--device`), so it is not equivalent to the command above. + +Output: +- `storage/calibration/calibration_weights.npy` --- calibrated weight vector +- `storage/calibration/unified_diagnostics.csv` --- per-target error report +- `storage/calibration/unified_run_config.json` --- full run configuration + +**Build-only (save package for later fitting):** + +```bash +python -m policyengine_us_data.calibration.unified_calibration \ + --puf-dataset policyengine_us_data/storage/puf_2024.h5 \ + --target-config policyengine_us_data/calibration/target_config.yaml \ + --build-only +``` + +Note: `make calibrate-build` passes `--target-config` and `--build-only` but not `--puf-dataset`, so it builds the package without PUF cloning. + +This saves `storage/calibration/calibration_package.pkl` (default location). Use `--package-output` to specify a different path. + +Then fit from the package using the same Step 2 command from Workflow 1. + +### 3. Re-filtering a saved package + +A saved package contains **all** targets from the database (before target config filtering). You can apply a different target config at fit time: + +```bash +python -m policyengine_us_data.calibration.unified_calibration \ + --package-path storage/calibration/calibration_package.pkl \ + --target-config my_custom_config.yaml \ + --epochs 200 +``` + +This lets you experiment with which targets to include without rebuilding the matrix. + +### 4. Running on Modal (GPU cloud) + +**From a pre-built package** (recommended): + +Use `--package-path` to point at a local `.pkl` file. The runner automatically uploads it to the Modal Volume and then fits from it on the GPU, avoiding the function argument size limit. 
+ +```bash +modal run modal_app/remote_calibration_runner.py \ + --package-path policyengine_us_data/storage/calibration/calibration_package.pkl \ + --branch calibration-pipeline-improvements \ + --gpu T4 \ + --epochs 1000 \ + --beta 0.65 \ + --lambda-l0 1e-8 \ + --lambda-l2 1e-8 +``` + +If a package already exists on the volume from a previous upload, you can also use `--prebuilt-matrices` to fit directly without re-uploading. + +**Full pipeline** (builds matrix from scratch on Modal): + +```bash +modal run modal_app/remote_calibration_runner.py \ + --branch calibration-pipeline-improvements \ + --gpu T4 \ + --epochs 1000 \ + --beta 0.65 \ + --lambda-l0 1e-8 \ + --lambda-l2 1e-8 \ + --target-config policyengine_us_data/calibration/target_config.yaml +``` + +The target config YAML is read from the cloned repo inside the container, so it must be committed to the branch you specify. + +### 5. Portable fitting (Kaggle, Colab, etc.) + +Transfer the package file to any environment with `scipy`, `numpy`, `pandas`, `torch`, and `l0-python` installed: + +```python +from policyengine_us_data.calibration.unified_calibration import ( + load_calibration_package, + apply_target_config, + fit_l0_weights, +) + +package = load_calibration_package("calibration_package.pkl") +targets_df = package["targets_df"] +X_sparse = package["X_sparse"] + +weights = fit_l0_weights( + X_sparse=X_sparse, + targets=targets_df["value"].values, + lambda_l0=1e-8, + epochs=500, + device="cuda", + beta=0.65, + lambda_l2=1e-8, +) +``` + +## Target Config + +The target config controls which targets reach the optimizer. It uses a YAML exclusion list: + +```yaml +exclude: + - variable: rent + geo_level: national + - variable: eitc + geo_level: district + - variable: snap + geo_level: state + domain_variable: snap # optional: further narrow the match +``` + +Each rule drops rows from the calibration matrix where **all** specified fields match. Unrecognized variables silently match nothing. 
+ +### Fields + +| Field | Required | Values | Description | +|---|---|---|---| +| `variable` | Yes | Any variable name in `target_overview` | The calibration target variable | +| `geo_level` | Yes | `national`, `state`, `district` | Geographic aggregation level | +| `domain_variable` | No | Any domain variable in `target_overview` | Narrows match to a specific domain | + +### Default config + +The checked-in config at `policyengine_us_data/calibration/target_config.yaml` reproduces the junkyard notebook's 22 excluded target groups. It drops: + +- **13 national-level variables**: alimony, charitable deduction, child support, interest deduction, medical expense deduction, net worth, person count, real estate taxes, rent, social security dependents/survivors +- **9 district-level variables**: ACA PTC, EITC, income tax before credits, medical expense deduction, net capital gains, rental income, tax unit count, partnership/S-corp income, taxable social security + +Applying this config reduces targets from ~37K to ~21K, matching the junkyard's target selection. 
+ +### Writing a custom config + +To experiment, copy the default and edit: + +```bash +cp policyengine_us_data/calibration/target_config.yaml my_config.yaml +# Edit my_config.yaml to add/remove exclusion rules +python -m policyengine_us_data.calibration.unified_calibration \ + --package-path storage/calibration/calibration_package.pkl \ + --target-config my_config.yaml \ + --epochs 200 +``` + +To see what variables and geo_levels are available in the database: + +```sql +SELECT DISTINCT variable, geo_level +FROM target_overview +ORDER BY variable, geo_level; +``` + +## CLI Reference + +### Core flags + +| Flag | Default | Description | +|---|---|---| +| `--dataset` | `storage/stratified_extended_cps_2024.h5` | Path to CPS h5 file | +| `--db-path` | `storage/calibration/policy_data.db` | Path to target database | +| `--output` | `storage/calibration/calibration_weights.npy` | Weight output path | +| `--puf-dataset` | None | Path to PUF h5 (enables PUF cloning) | +| `--preset` | `local` | L0 preset: `local` (1e-8) or `national` (1e-4) | +| `--lambda-l0` | None | Custom L0 penalty (overrides `--preset`) | +| `--epochs` | 100 | Training epochs | +| `--device` | `cpu` | `cpu` or `cuda` | +| `--n-clones` | 436 | Number of dataset clones | +| `--seed` | 42 | Random seed for geography assignment | + +### Target selection + +| Flag | Default | Description | +|---|---|---| +| `--target-config` | None | Path to YAML exclusion config | +| `--domain-variables` | None | Comma-separated domain filter (SQL-level) | +| `--hierarchical-domains` | None | Domains for hierarchical uprating | + +### Checkpoint flags + +| Flag | Default | Description | +|---|---|---| +| `--build-only` | False | Build matrix, save package, skip fitting | +| `--package-path` | None | Load pre-built package (uploads to Modal volume automatically when using Modal runner) | +| `--package-output` | Auto (when `--build-only`) | Where to save package | + +### Hyperparameter flags + +| Flag | Default | Junkyard 
value | Description | +|---|---|---|---| +| `--beta` | 0.35 | 0.65 | L0 gate temperature (higher = softer gates) | +| `--lambda-l2` | 1e-12 | 1e-8 | L2 regularization on weights | +| `--learning-rate` | 0.15 | 0.15 | Optimizer learning rate | + +### Skip flags + +| Flag | Description | +|---|---| +| `--skip-puf` | Skip PUF clone + QRF imputation | +| `--skip-source-impute` | Skip ACS/SIPP/SCF re-imputation | +| `--skip-takeup-rerandomize` | Skip takeup re-randomization | + +## Calibration Package Format + +The package is a pickled Python dict: + +```python +{ + "X_sparse": scipy.sparse.csr_matrix, # (n_targets, n_records) + "targets_df": pd.DataFrame, # target metadata + values + "target_names": list[str], # human-readable names + "metadata": { + "dataset_path": str, + "db_path": str, + "n_clones": int, + "n_records": int, + "seed": int, + "created_at": str, # ISO timestamp + "target_config": dict, # config used at build time + }, +} +``` + +The `targets_df` DataFrame has columns: `variable`, `geo_level`, `geographic_id`, `domain_variable`, `value`, and others from the database. + +## Validating a Package + +Before uploading a package to Modal, validate it: + +```bash +# Default package location +python -m policyengine_us_data.calibration.validate_package + +# Specific package +python -m policyengine_us_data.calibration.validate_package path/to/calibration_package.pkl + +# Strict mode: fail if any target has row_sum/target < 1% +python -m policyengine_us_data.calibration.validate_package --strict +``` + +Exit codes: **0** = pass, **1** = impossible targets, **2** = strict ratio failures. + +Validation also runs automatically after `--build-only`. + +## Hyperparameter Tuning Guide + +The three key hyperparameters control the tradeoff between target accuracy and sparsity: + +- **`beta`** (L0 gate temperature): Controls how sharply the L0 gates open/close. Higher values (0.5--0.8) give softer decisions and more exploration early in training. 
Lower values (0.2--0.4) give harder on/off decisions. + +- **`lambda_l0`** (via `--preset` or `--lambda-l0`): Controls how many records survive. `1e-8` (local preset) keeps millions of records for local-area analysis. `1e-4` (national preset) keeps ~50K for the web app. + +- **`lambda_l2`**: Regularizes weight magnitudes. Larger values (1e-8) prevent any single record from having extreme weight. Smaller values (1e-12) allow more weight concentration. + +### Suggested starting points + +For **local-area calibration** (millions of records): +```bash +--lambda-l0 1e-8 --beta 0.65 --lambda-l2 1e-8 --epochs 500 +``` + +For **national web app** (~50K records): +```bash +--lambda-l0 1e-4 --beta 0.35 --lambda-l2 1e-12 --epochs 200 +``` + +## Makefile Targets + +| Target | Description | +|---|---| +| `make calibrate` | Build + fit with the default target config (does not pass `--puf-dataset`) | +| `make calibrate-build` | Build-only mode (saves package, no fitting) | +| `make pipeline` | End-to-end: data, upload, calibrate, stage | +| `make validate-staging` | Validate staged H5s against targets (states only) | +| `make validate-staging-full` | Validate staged H5s (states + districts) | +| `make upload-validation` | Push validation_results.csv to HF | +| `make check-staging` | Smoke test: sum key variables across all state H5s | +| `make check-sanity` | Quick structural integrity check on one state | +| `make upload-calibration` | Upload weights, blocks, and logs to HF | + +## Takeup Rerandomization + +The calibration pipeline uses two independent code paths to compute the same target variables: + +1. **Matrix builder** (`UnifiedMatrixBuilder.build_matrix`): Computes a sparse calibration matrix $X$ where each row is a target and each column is a cloned household. The optimizer finds weights $w$ that minimize $\|Xw - t\|$, where $t$ is the vector of target values. + +2. **Stacked builder** (`create_sparse_cd_stacked_dataset`): Produces the `.h5` files that users load in `Microsimulation`. 
It reconstructs each congressional district by combining base CPS records with calibrated weights and block-level geography. + +For the calibration to be meaningful, **both paths must produce identical values** for every target variable. If the matrix builder computes $X_{snap,NC} \cdot w = \$5.2B$ but the stacked NC.h5 file yields `sim.calculate("snap") * household_weight = $4.8B`, then the optimizer's solution does not actually match the target. + +### The problem with takeup variables + +Variables like `snap`, `aca_ptc`, `ssi`, and `medicaid` depend on **takeup draws** — random Bernoulli samples that determine whether an eligible household actually claims the benefit. By default, PolicyEngine draws these at simulation time using Python's built-in `hash()`, which is randomized per process. + +This means loading the same H5 file in two different processes can produce different SNAP totals, even with the same weights. Worse, the matrix builder runs in process A while the stacked builder runs in process B, so their draws can diverge. + +### The solution: block-level seeding + +Both paths call `seeded_rng(variable_name, salt=f"{block_geoid}:{household_id}")` to generate deterministic takeup draws. This ensures: + +- The same household at the same block always gets the same draw +- Draws are stable across processes (no dependency on `hash()`) +- Draws are stable when aggregating to any geography (state, CD, county) + +The affected variables are listed in `TAKEUP_AFFECTED_TARGETS` in `utils/takeup.py`: snap, aca_ptc, ssi, medicaid, tanf, head_start, early_head_start, and dc_property_tax_credit. + +The `--skip-takeup-rerandomize` flag disables this rerandomization for faster iteration when you only care about non-takeup variables. Do not use it for production calibrations. + +## Block-Level Seeding + +Each cloned household is assigned to a Census block (15-digit GEOID) during the `assign_random_geography` step. 
The first 2 digits are the state FIPS code, which determines the household's takeup rates (since benefit eligibility rules are state-specific). + +### Mechanism + +```python +rng = seeded_rng(variable_name, salt=f"{block_geoid}:{household_id}") +draw = rng.random() +takes_up = draw < takeup_rate[state_fips] +``` + +The `seeded_rng` function uses `_stable_string_hash` — a deterministic hash that does not depend on Python's `PYTHONHASHSEED`. This is critical because Python's built-in `hash()` is randomized per process by default (since Python 3.3). + +### Why block (not CD or state)? + +Blocks are the finest Census geography. A household's block assignment stays the same regardless of how blocks are aggregated — the same household-block-draw triple produces the same result whether you are building an H5 for a state, a congressional district, or a county. This means: + +- State H5s and district H5s are consistent (no draw drift) +- Future county-level H5s will also be consistent +- Re-running the pipeline with different area selections yields the same per-household values + +### Inactive records + +When converting to stacked format, households that are not assigned to a given CD get zero weight. These inactive records must receive an empty string `""` as their block GEOID, not a real block. If they received real blocks, they would inflate the entity count `n` passed to the RNG, shifting the draw positions for active entities and breaking the $X \cdot w$ consistency invariant. + +## The $X \cdot w$ Consistency Invariant + +### Formal statement + +For every target variable $v$ and geography $g$: + +$$X_{v,g} \cdot w = \sum_{i \in g} \text{sim.calculate}(v)_i \times w_i$$ + +where the left side comes from the matrix builder and the right side comes from loading the stacked H5 and running `Microsimulation.calculate()`. + +### Why it matters + +This invariant is what makes calibration meaningful. 
Without it, the optimizer's solution (which minimizes $\|Xw - t\|$) does not actually produce a dataset that matches the targets. The weights would be "correct" in the matrix builder's view but produce different totals in the H5 files that users actually load. + +### Known sources of drift + +1. **Mismatched takeup draws**: The matrix builder and stacked builder use different RNG states. Solved by block-level seeding (see above). + +2. **Different block assignments**: The stacked format uses first-clone-wins for multi-clone-same-CD records. With ~11M blocks and 3-10 clones, collision rate is ~0.7-10% of records. In practice, the residual mismatch is negligible. + +3. **Inactive records in RNG calls**: If inactive records (w=0) receive real block GEOIDs, they inflate the entity count for that block's RNG call, shifting draw positions. Solved by using `""` for inactive blocks. + +4. **Entity ordering**: Both paths must iterate over entities in the same order (`sim.calculate("{entity}_id", map_to=entity)`). NumPy boolean masking preserves order, so `draws[i]` maps to the same entity in both paths. + +### Testing + +The `test_xw_consistency.py` test (`pytest -m slow`) verifies this invariant end-to-end: + +1. Load base dataset, create geography with uniform weights +2. Build $X$ with the matrix builder (including takeup rerandomization) +3. Convert weights to stacked format +4. Build stacked H5 for selected CDs +5. Compare $X \cdot w$ vs `sim.calculate() * household_weight` — assert ratio within 1% + +## Post-Calibration Gating Workflow + +After the pipeline stages H5 files to HuggingFace, two manual review gates determine whether to promote to production. + +### Gate 1: Review calibration fit + +Load `calibration_log.csv` in the microcalibrate dashboard. This file contains the $X \cdot w$ values from the matrix builder for every target at every epoch. 
+ +**What to check:** +- Loss curve converges (no divergence in later epochs) +- No individual target groups diverging while others improve +- Final loss is comparable to or better than the previous production run + +If the fit is poor, re-calibrate with different hyperparameters (learning rate, lambda_l0, beta, epochs). + +### Gate 2: Review simulation quality + +```bash +make validate-staging # states only (~30 min) +make validate-staging-full # states + districts (~3 hrs) +make upload-validation # push CSV to HF +``` + +This produces `validation_results.csv` with `sim.calculate()` values for every target. Load it in the dashboard's Combined tab alongside `calibration_log.csv`. + +**What to check:** +- `CalibrationVsSimComparison` shows the gap between $X \cdot w$ and `sim.calculate()` values +- No large regressions vs the previous production run +- Sanity check column has no FAIL entries + +### Promote + +If both gates pass: +- Run the "Promote Local Area H5 Files" GitHub workflow, OR +- Manually copy staged files to the production paths in the HF repo + +### Structural pre-flight + +For a quick structural check without loading the full database: + +```bash +make check-sanity # one state, ~2 min +``` + +This runs weight non-negativity, entity ID uniqueness, NaN/Inf detection, person-household mapping, boolean takeup validation, and per-household range checks.
diff --git a/docs/calibration_matrix.ipynb b/docs/calibration_matrix.ipynb index 41497b1e8..133f45910 100644 --- a/docs/calibration_matrix.ipynb +++ b/docs/calibration_matrix.ipynb @@ -24,10 +24,40 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, - "outputs": [], - "source": "import numpy as np\nimport pandas as pd\nfrom policyengine_us import Microsimulation\nfrom policyengine_us_data.storage import STORAGE_FOLDER\nfrom policyengine_us_data.calibration.unified_matrix_builder import (\n UnifiedMatrixBuilder,\n)\nfrom policyengine_us_data.calibration.clone_and_assign import (\n assign_random_geography,\n)\nfrom policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import (\n create_target_groups,\n drop_target_groups,\n get_geo_level,\n STATE_CODES,\n)\n\ndb_path = STORAGE_FOLDER / \"calibration\" / \"policy_data.db\"\ndb_uri = f\"sqlite:///{db_path}\"\ndataset_path = STORAGE_FOLDER / \"stratified_extended_cps_2024.h5\"" + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/baogorek/envs/sep/lib/python3.13/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from policyengine_us import Microsimulation\n", + "from policyengine_us_data.storage import STORAGE_FOLDER\n", + "from policyengine_us_data.calibration.unified_matrix_builder import (\n", + " UnifiedMatrixBuilder,\n", + ")\n", + "from policyengine_us_data.calibration.clone_and_assign import (\n", + " assign_random_geography,\n", + ")\n", + "from policyengine_us_data.calibration.calibration_utils import (\n", + " create_target_groups,\n", + " drop_target_groups,\n", + " get_geo_level,\n", + " STATE_CODES,\n", + ")\n", + "\n", + "db_path = STORAGE_FOLDER / \"calibration\" / \"policy_data.db\"\n", + "db_uri = f\"sqlite:///{db_path}\"\n", + "dataset_path = STORAGE_FOLDER / \"stratified_extended_cps_2024.h5\"" + ] }, { "cell_type": "code", @@ -40,7 +70,7 @@ "text": [ "Records: 11,999, Clones: 3, Total columns: 35,997\n", "Matrix shape: (1411, 35997)\n", - "Non-zero entries: 14,946\n" + "Non-zero entries: 29,425\n" ] } ], @@ -79,10 +109,36 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], - "source": "print(f\"Targets: {X_sparse.shape[0]}\")\nprint(f\"Columns: {X_sparse.shape[1]:,} ({N_CLONES} clones x {n_records:,} records)\")\nprint(f\"Non-zeros: {X_sparse.nnz:,}\")\nprint(f\"Density: {X_sparse.nnz / (X_sparse.shape[0] * X_sparse.shape[1]):.6f}\")\n\ngeo_levels = targets_df[\"geographic_id\"].apply(get_geo_level)\nlevel_names = {0: \"National\", 1: \"State\", 2: \"District\"}\nfor level in [0, 1, 2]:\n n = (geo_levels == level).sum()\n if n > 0:\n print(f\" {level_names[level]}: {n} targets\")" + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Targets: 1411\n", + "Columns: 35,997 (3 clones x 11,999 records)\n", + "Non-zeros: 29,425\n", + "Density: 0.000579\n", + " 
National: 1 targets\n", + " State: 102 targets\n", + " District: 1308 targets\n" + ] + } + ], + "source": [ + "print(f\"Targets: {X_sparse.shape[0]}\")\n", + "print(f\"Columns: {X_sparse.shape[1]:,} ({N_CLONES} clones x {n_records:,} records)\")\n", + "print(f\"Non-zeros: {X_sparse.nnz:,}\")\n", + "print(f\"Density: {X_sparse.nnz / (X_sparse.shape[0] * X_sparse.shape[1]):.6f}\")\n", + "\n", + "geo_levels = targets_df[\"geographic_id\"].apply(get_geo_level)\n", + "level_names = {0: \"National\", 1: \"State\", 2: \"District\"}\n", + "for level in [0, 1, 2]:\n", + " n = (geo_levels == level).sum()\n", + " if n > 0:\n", + " print(f\" {level_names[level]}: {n} targets\")" + ] }, { "cell_type": "markdown", @@ -131,13 +187,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "Row 705 has 9 non-zero columns\n", + "Row 705 has 10 non-zero columns\n", " Spans 3 clone(s)\n", - " Spans 9 unique record(s)\n", + " Spans 10 unique record(s)\n", "\n", - "First non-zero column (8000):\n", + "First non-zero column (1212):\n", " clone_idx: 0\n", - " record_idx: 8000\n", + " record_idx: 1212\n", " state_fips: 34\n", " cd_geoid: 3402\n", " value: 1.00\n" @@ -189,7 +245,7 @@ " record_idx: 42\n", " state_fips: 45\n", " cd_geoid: 4507\n", - " block_geoid: 450510801013029\n", + " block_geoid: 450410002022009\n", "\n", "This column has non-zero values in 0 target rows\n" ] @@ -334,7 +390,7 @@ "\n", "--- Group 4: District ACA PTC Tax Unit Count (436 targets) ---\n", " variable geographic_id value\n", - "tax_unit_count 1001 25064.255490\n", + "tax_unit_count 1000 25064.255490\n", "tax_unit_count 101 9794.081624\n", "tax_unit_count 102 11597.544977\n", "tax_unit_count 103 9160.097959\n", @@ -373,13 +429,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "Example SNAP-receiving household: record index 23\n", - "SNAP value: $70\n", + "Example SNAP-receiving household: record index 2\n", + "SNAP value: $679\n", "\n", "Column positions across 3 clones:\n", - " col 23: TX 
(state=48, CD=4829) — 0 non-zero rows\n", - " col 12022: IL (state=17, CD=1708) — 0 non-zero rows\n", - " col 24021: FL (state=12, CD=1220) — 3 non-zero rows\n" + " col 2: TX (state=48, CD=4814) — 4 non-zero rows\n", + " col 12001: IN (state=18, CD=1804) — 3 non-zero rows\n", + " col 24000: PA (state=42, CD=4212) — 3 non-zero rows\n" ] } ], @@ -413,10 +469,21 @@ "output_type": "stream", "text": [ "\n", - "Clone 2 (col 24021, CD 1220):\n", - " household_count (geo=12): 1.00\n", - " snap (geo=12): 70.08\n", - " household_count (geo=1220): 1.00\n" + "Clone 0 (col 2, CD 4814):\n", + " person_count (geo=US): 3.00\n", + " household_count (geo=48): 1.00\n", + " snap (geo=48): 678.60\n", + " household_count (geo=4814): 1.00\n", + "\n", + "Clone 1 (col 12001, CD 1804):\n", + " household_count (geo=18): 1.00\n", + " snap (geo=18): 678.60\n", + " household_count (geo=1804): 1.00\n", + "\n", + "Clone 2 (col 24000, CD 4212):\n", + " household_count (geo=42): 1.00\n", + " snap (geo=42): 678.60\n", + " household_count (geo=4212): 1.00\n" ] } ], @@ -455,9 +522,9 @@ "output_type": "stream", "text": [ "Total cells: 50,791,767\n", - "Non-zero entries: 14,946\n", - "Density: 0.000294\n", - "Sparsity: 99.9706%\n" + "Non-zero entries: 29,425\n", + "Density: 0.000579\n", + "Sparsity: 99.9421%\n" ] } ], @@ -472,10 +539,48 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, - "outputs": [], - "source": "nnz_per_row = np.diff(X_sparse.indptr)\nprint(f\"Non-zeros per row:\")\nprint(f\" min: {nnz_per_row.min():,}\")\nprint(f\" median: {int(np.median(nnz_per_row)):,}\")\nprint(f\" mean: {nnz_per_row.mean():,.0f}\")\nprint(f\" max: {nnz_per_row.max():,}\")\n\ngeo_levels = targets_df[\"geographic_id\"].apply(get_geo_level)\nlevel_names = {0: \"National\", 1: \"State\", 2: \"District\"}\nprint(\"\\nBy geographic level:\")\nfor level in [0, 1, 2]:\n mask = (geo_levels == level).values\n if mask.any():\n vals = nnz_per_row[mask]\n print(\n f\" 
{level_names[level]:10s}: \"\n f\"n={mask.sum():>4d}, \"\n f\"median nnz={int(np.median(vals)):>7,}, \"\n f\"range=[{vals.min():,}, {vals.max():,}]\"\n )" + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Non-zeros per row:\n", + " min: 0\n", + " median: 10\n", + " mean: 21\n", + " max: 3,408\n", + "\n", + "By geographic level:\n", + " National : n= 1, median nnz= 3,408, range=[3,408, 3,408]\n", + " State : n= 102, median nnz= 80, range=[10, 694]\n", + " District : n=1308, median nnz= 9, range=[0, 27]\n" + ] + } + ], + "source": [ + "nnz_per_row = np.diff(X_sparse.indptr)\n", + "print(f\"Non-zeros per row:\")\n", + "print(f\" min: {nnz_per_row.min():,}\")\n", + "print(f\" median: {int(np.median(nnz_per_row)):,}\")\n", + "print(f\" mean: {nnz_per_row.mean():,.0f}\")\n", + "print(f\" max: {nnz_per_row.max():,}\")\n", + "\n", + "geo_levels = targets_df[\"geographic_id\"].apply(get_geo_level)\n", + "level_names = {0: \"National\", 1: \"State\", 2: \"District\"}\n", + "print(\"\\nBy geographic level:\")\n", + "for level in [0, 1, 2]:\n", + " mask = (geo_levels == level).values\n", + " if mask.any():\n", + " vals = nnz_per_row[mask]\n", + " print(\n", + " f\" {level_names[level]:10s}: \"\n", + " f\"n={mask.sum():>4d}, \"\n", + " f\"median nnz={int(np.median(vals)):>7,}, \"\n", + " f\"range=[{vals.min():,}, {vals.max():,}]\"\n", + " )" + ] }, { "cell_type": "code", @@ -488,9 +593,9 @@ "text": [ "Non-zeros per clone block:\n", " clone nnz unique_states\n", - " 0 4962 50\n", - " 1 4988 50\n", - " 2 4996 50\n" + " 0 9775 51\n", + " 1 9810 51\n", + " 2 9840 51\n" ] } ], @@ -613,15 +718,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "Achievable targets: 479\n", - "Impossible targets: 881\n", + "Achievable targets: 1358\n", + "Impossible targets: 2\n", "\n", "Impossible targets by (domain, variable):\n", - " aca_ptc/aca_ptc: 436\n", - " aca_ptc/tax_unit_count: 436\n", - " snap/household_count: 7\n", - " aca_ptc/person_count: 1\n", - 
" snap/snap: 1\n" + " aca_ptc/aca_ptc: 1\n", + " aca_ptc/tax_unit_count: 1\n" ] } ], @@ -657,11 +759,11 @@ "output_type": "stream", "text": [ "Hardest targets (lowest row_sum / target_value ratio):\n", - " snap/household_count (geo=621): ratio=0.0000, row_sum=4, target=119,148\n", - " snap/household_count (geo=3615): ratio=0.0001, row_sum=9, target=173,591\n", - " snap/snap (geo=46): ratio=0.0001, row_sum=9,421, target=180,195,817\n", - " snap/household_count (geo=3625): ratio=0.0001, row_sum=4, target=67,315\n", - " snap/household_count (geo=1702): ratio=0.0001, row_sum=6, target=97,494\n" + " aca_ptc/aca_ptc (geo=3612): ratio=0.0000, row_sum=5,439, target=376,216,522\n", + " aca_ptc/aca_ptc (geo=2508): ratio=0.0000, row_sum=2,024, target=124,980,814\n", + " aca_ptc/tax_unit_count (geo=2508): ratio=0.0000, row_sum=1, target=51,937\n", + " aca_ptc/tax_unit_count (geo=3612): ratio=0.0000, row_sum=2, target=73,561\n", + " aca_ptc/tax_unit_count (geo=1198): ratio=0.0000, row_sum=1, target=30,419\n" ] } ], @@ -692,9 +794,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "Final matrix shape: (479, 35997)\n", - "Final non-zero entries: 9,944\n", - "Final density: 0.000577\n", + "Final matrix shape: (1358, 35997)\n", + "Final non-zero entries: 23,018\n", + "Final density: 0.000471\n", "\n", "This is what the optimizer receives.\n" ] @@ -747,4 +849,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} \ No newline at end of file +} diff --git a/docs/hierarchical_uprating.ipynb b/docs/hierarchical_uprating.ipynb index 4da30d82c..5839ccbbe 100644 --- a/docs/hierarchical_uprating.ipynb +++ b/docs/hierarchical_uprating.ipynb @@ -54,7 +54,7 @@ "from policyengine_us_data.calibration.unified_matrix_builder import (\n", " UnifiedMatrixBuilder,\n", ")\n", - "from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import (\n", + "from policyengine_us_data.calibration.calibration_utils import (\n", " STATE_CODES,\n", ")\n", "\n", diff --git 
a/docs/local_area_calibration_setup.ipynb b/docs/local_area_calibration_setup.ipynb index 2e8614aa9..519e11a94 100644 --- a/docs/local_area_calibration_setup.ipynb +++ b/docs/local_area_calibration_setup.ipynb @@ -68,12 +68,12 @@ ")\n", "from policyengine_us_data.utils.randomness import seeded_rng\n", "from policyengine_us_data.parameters import load_take_up_rate\n", - "from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import (\n", + "from policyengine_us_data.calibration.calibration_utils import (\n", " get_calculated_variables,\n", " STATE_CODES,\n", " get_all_cds_from_database,\n", ")\n", - "from policyengine_us_data.datasets.cps.local_area_calibration.stacked_dataset_builder import (\n", + "from policyengine_us_data.calibration.stacked_dataset_builder import (\n", " create_sparse_cd_stacked_dataset,\n", ")\n", "\n", @@ -96,7 +96,7 @@ "output_type": "stream", "text": [ "Base dataset: 11,999 households\n", - "Example household: record_idx=8629, household_id=128694, SNAP=$18,396.00\n" + "Example household: record_idx=8629, household_id=130831, SNAP=$0.00\n" ] } ], @@ -137,9 +137,9 @@ "output_type": "stream", "text": [ "Total cloned records: 35,997\n", - "Unique states: 50\n", - "Unique CDs: 435\n", - "Unique blocks: 35508\n" + "Unique states: 51\n", + "Unique CDs: 436\n", + "Unique blocks: 35517\n" ] } ], @@ -203,8 +203,8 @@ " 8629\n", " 48\n", " TX\n", - " 4817\n", - " 481450004002026\n", + " 4816\n", + " 481410030003002\n", " \n", " \n", " 1\n", @@ -213,7 +213,7 @@ " 42\n", " PA\n", " 4201\n", - " 420171058013029\n", + " 420171018051005\n", " \n", " \n", " 2\n", @@ -222,7 +222,7 @@ " 36\n", " NY\n", " 3611\n", - " 360850208041023\n", + " 360470200002002\n", " \n", " \n", "\n", @@ -230,9 +230,9 @@ ], "text/plain": [ " clone col state_fips abbr cd_geoid block_geoid\n", - "0 0 8629 48 TX 4817 481450004002026\n", - "1 1 20628 42 PA 4201 420171058013029\n", - "2 2 32627 36 NY 3611 360850208041023" + "0 0 8629 48 TX 4816 
481410030003002\n", + "1 1 20628 42 PA 4201 420171018051005\n", + "2 2 32627 36 NY 3611 360470200002002" ] }, "execution_count": 4, @@ -280,13 +280,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "Global block distribution: 5,765,442 blocks\n", + "Global block distribution: 5,769,942 blocks\n", "Top 5 states by total probability:\n", - " CA (6): 11.954%\n", - " TX (48): 8.736%\n", - " FL (12): 6.437%\n", - " NY (36): 5.977%\n", - " PA (42): 3.908%\n" + " CA (6): 11.927%\n", + " TX (48): 8.716%\n", + " FL (12): 6.422%\n", + " NY (36): 5.963%\n", + " PA (42): 3.899%\n" ] } ], @@ -327,10 +327,10 @@ "output_type": "stream", "text": [ "Example household (record_idx=8629):\n", - " Original state: NC (37)\n", + " Original state: CA (6)\n", " Clone 0 state: TX (48)\n", - " Original SNAP: $18,396.00\n", - " Clone 0 SNAP: $18,396.00\n" + " Original SNAP: $0.00\n", + " Clone 0 SNAP: $0.00\n" ] } ], @@ -410,31 +410,31 @@ " 0\n", " TX\n", " 48\n", - " $18,396.00\n", + " $0.00\n", " \n", " \n", " 1\n", " 1\n", " PA\n", " 42\n", - " $18,396.00\n", + " $0.00\n", " \n", " \n", " 2\n", " 2\n", " NY\n", " 36\n", - " $18,396.00\n", + " $0.00\n", " \n", " \n", "\n", "" ], "text/plain": [ - " clone state state_fips SNAP\n", - "0 0 TX 48 $18,396.00\n", - "1 1 PA 42 $18,396.00\n", - "2 2 NY 36 $18,396.00" + " clone state state_fips SNAP\n", + "0 0 TX 48 $0.00\n", + "1 1 PA 42 $0.00\n", + "2 2 NY 36 $0.00" ] }, "execution_count": 7, @@ -499,10 +499,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "Unique states mapped: 50\n", - "Unique CDs mapped: 435\n", + "Unique states mapped: 51\n", + "Unique CDs mapped: 436\n", "\n", - "Columns per state: min=62, median=494, max=4311\n" + "Columns per state: min=63, median=490, max=4299\n" ] } ], @@ -539,9 +539,9 @@ "text": [ "Example household clone visibility:\n", "\n", - "Clone 0 (TX, CD 4817):\n", + "Clone 0 (TX, CD 4816):\n", " Visible to TX state targets: col 8629 in state_to_cols[48]? 
True\n", - " Visible to CD 4817 targets: col 8629 in cd_to_cols['4817']? True\n", + " Visible to CD 4816 targets: col 8629 in cd_to_cols['4816']? True\n", " Visible to NC (37) targets: False\n", "\n", "Clone 1 (PA, CD 4201):\n", @@ -612,7 +612,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "8 takeup variables:\n", + "9 takeup variables:\n", "\n", " takes_up_snap_if_eligible entity=spm_unit rate=82.00%\n", " takes_up_aca_if_eligible entity=tax_unit rate=67.20%\n", @@ -621,7 +621,8 @@ " takes_up_early_head_start_if_eligible entity=person rate=9.00%\n", " takes_up_ssi_if_eligible entity=person rate=50.00%\n", " would_file_taxes_voluntarily entity=tax_unit rate=5.00%\n", - " takes_up_medicaid_if_eligible entity=person rate=dict (51 entries)\n" + " takes_up_medicaid_if_eligible entity=person rate=dict (51 entries)\n", + " takes_up_tanf_if_eligible entity=spm_unit rate=22.00%\n" ] } ], @@ -708,14 +709,15 @@ "text": [ "Takeup rates before/after re-randomization (clone 0):\n", "\n", - " takes_up_snap_if_eligible before=82.333% after=82.381%\n", - " takes_up_aca_if_eligible before=66.718% after=67.486%\n", - " takes_up_dc_ptc before=31.483% after=32.044%\n", - " takes_up_head_start_if_eligible before=29.963% after=29.689%\n", - " takes_up_early_head_start_if_eligible before=8.869% after=8.721%\n", - " takes_up_ssi_if_eligible before=100.000% after=49.776%\n", - " would_file_taxes_voluntarily before=0.000% after=4.905%\n", - " takes_up_medicaid_if_eligible before=84.496% after=80.051%\n" + " takes_up_snap_if_eligible before=82.116% after=82.364%\n", + " takes_up_aca_if_eligible before=67.115% after=67.278%\n", + " takes_up_dc_ptc before=31.673% after=31.534%\n", + " takes_up_head_start_if_eligible before=100.000% after=29.852%\n", + " takes_up_early_head_start_if_eligible before=100.000% after=8.904%\n", + " takes_up_ssi_if_eligible before=100.000% after=49.504%\n", + " would_file_taxes_voluntarily before=0.000% after=5.115%\n", + " takes_up_medicaid_if_eligible 
before=84.868% after=80.354%\n", + " takes_up_tanf_if_eligible before=100.000% after=21.991%\n" ] } ], @@ -801,11 +803,17 @@ "name": "stderr", "output_type": "stream", "text": [ - "2026-02-13 17:11:22,384 - INFO - Processing clone 1/3 (cols 0-11998, 50 unique states)...\n", - "2026-02-13 17:11:23,509 - INFO - Processing clone 2/3 (cols 11999-23997, 50 unique states)...\n", - "2026-02-13 17:11:24,645 - INFO - Processing clone 3/3 (cols 23998-35996, 50 unique states)...\n", - "2026-02-13 17:11:25,769 - INFO - Assembling matrix from 3 clones...\n", - "2026-02-13 17:11:25,771 - INFO - Matrix: 538 targets x 35997 cols, 14946 nnz\n" + "2026-02-20 15:34:21,531 - INFO - Per-state precomputation: 51 states, 1 hh vars, 1 constraint vars\n", + "2026-02-20 15:34:22,137 - INFO - State 1/51 complete\n", + "2026-02-20 15:34:27,750 - INFO - State 10/51 complete\n", + "2026-02-20 15:34:34,205 - INFO - State 20/51 complete\n", + "2026-02-20 15:34:40,885 - INFO - State 30/51 complete\n", + "2026-02-20 15:34:47,174 - INFO - State 40/51 complete\n", + "2026-02-20 15:34:53,723 - INFO - State 50/51 complete\n", + "2026-02-20 15:34:54,415 - INFO - Per-state precomputation done: 51 states\n", + "2026-02-20 15:34:54,419 - INFO - Assembling clone 1/3 (cols 0-11998, 51 unique states)...\n", + "2026-02-20 15:34:54,516 - INFO - Assembling matrix from 3 clones...\n", + "2026-02-20 15:34:54,517 - INFO - Matrix: 538 targets x 35997 cols, 19140 nnz\n" ] }, { @@ -813,8 +821,8 @@ "output_type": "stream", "text": [ "Matrix shape: (538, 35997)\n", - "Non-zero entries: 14,946\n", - "Density: 0.000772\n" + "Non-zero entries: 19,140\n", + "Density: 0.000988\n" ] } ], @@ -848,18 +856,9 @@ "text": [ "Example household non-zero pattern across clones:\n", "\n", - "Clone 0 (TX, CD 4817): 3 non-zero rows\n", - " row 39: household_count (geo=48): 1.00\n", - " row 90: snap (geo=48): 18396.00\n", - " row 410: household_count (geo=4817): 1.00\n", - "Clone 1 (PA, CD 4201): 3 non-zero rows\n", - " row 34: 
household_count (geo=42): 1.00\n", - " row 85: snap (geo=42): 18396.00\n", - " row 358: household_count (geo=4201): 1.00\n", - "Clone 2 (NY, CD 3611): 3 non-zero rows\n", - " row 27: household_count (geo=36): 1.00\n", - " row 78: snap (geo=36): 18396.00\n", - " row 292: household_count (geo=3611): 1.00\n" + "Clone 0 (TX, CD 4816): 0 non-zero rows\n", + "Clone 1 (PA, CD 4201): 0 non-zero rows\n", + "Clone 2 (NY, CD 3611): 0 non-zero rows\n" ] } ], @@ -993,6 +992,7 @@ "Extracted weights for 2 CDs from full weight matrix\n", "Total active household-CD pairs: 277\n", "Total weight in W matrix: 281\n", + "Warning: No rent data for CD 201, using geoadj=1.0\n", "Processing CD 201 (2/2)...\n" ] }, @@ -1000,10 +1000,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "2026-02-13 17:11:40,873 - INFO - HTTP Request: GET https://huggingface.co/api/models/policyengine/policyengine-us-data \"HTTP/1.1 200 OK\"\n", - "2026-02-13 17:11:40,899 - INFO - HTTP Request: HEAD https://huggingface.co/policyengine/policyengine-us-data/resolve/main/enhanced_cps_2024.h5 \"HTTP/1.1 302 Found\"\n", - "Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.\n", - "2026-02-13 17:11:40,899 - WARNING - Warning: You are sending unauthenticated requests to the HF Hub. 
Please set a HF_TOKEN to enable higher rate limits and faster downloads.\n" + "2026-02-20 15:35:04,090 - INFO - HTTP Request: GET https://huggingface.co/api/models/policyengine/policyengine-us-data \"HTTP/1.1 200 OK\"\n", + "2026-02-20 15:35:04,123 - INFO - HTTP Request: HEAD https://huggingface.co/policyengine/policyengine-us-data/resolve/main/enhanced_cps_2024.h5 \"HTTP/1.1 302 Found\"\n" ] }, { @@ -1013,7 +1011,7 @@ "\n", "Combining 2 CD DataFrames...\n", "Total households across all CDs: 277\n", - "Combined DataFrame shape: (726, 222)\n", + "Combined DataFrame shape: (716, 219)\n", "\n", "Reindexing all entity IDs using 25k ranges per CD...\n", " Created 277 unique households across 2 CDs\n", @@ -1022,12 +1020,12 @@ " Reindexing SPM units...\n", " Reindexing marital units...\n", " Reindexing families...\n", - " Final persons: 726\n", + " Final persons: 716\n", " Final households: 277\n", - " Final tax units: 373\n", - " Final SPM units: 291\n", - " Final marital units: 586\n", - " Final families: 309\n", + " Final tax units: 387\n", + " Final SPM units: 290\n", + " Final marital units: 587\n", + " Final families: 318\n", "\n", "Weights in combined_df AFTER reindexing:\n", " HH weight sum: 0.00M\n", @@ -1035,8 +1033,8 @@ " Ratio: 1.00\n", "\n", "Overflow check:\n", - " Max person ID after reindexing: 5,025,335\n", - " Max person ID × 100: 502,533,500\n", + " Max person ID after reindexing: 5,025,365\n", + " Max person ID × 100: 502,536,500\n", " int32 max: 2,147,483,647\n", " ✓ No overflow risk!\n", "\n", @@ -1044,15 +1042,15 @@ "Building simulation from Dataset...\n", "\n", "Saving to calibration_output/results.h5...\n", - "Found 175 input variables to save\n", - "Variables saved: 218\n", - "Variables skipped: 3763\n", + "Found 172 input variables to save\n", + "Variables saved: 215\n", + "Variables skipped: 3825\n", "Sparse CD-stacked dataset saved successfully!\n", "Household mapping saved to calibration_output/mappings/results_household_mapping.csv\n", "\n", 
"Verifying saved file...\n", " Final households: 277\n", - " Final persons: 726\n", + " Final persons: 716\n", " Total population (from household weights): 281\n" ] }, @@ -1089,17 +1087,17 @@ "text": [ "Stacked dataset: 277 households\n", "\n", - "Example household (original_id=128694) in mapping:\n", + "Example household (original_id=130831) in mapping:\n", "\n", " new_household_id original_household_id congressional_district state_fips\n", - " 108 128694 201 2\n", - " 25097 128694 3701 37\n", + " 108 130831 201 2\n", + " 25097 130831 3701 37\n", "\n", "In stacked dataset:\n", "\n", - " household_id congressional_district_geoid household_weight state_fips snap\n", - " 108 201 3.5 2 23640.0\n", - " 25097 3701 2.5 37 18396.0\n" + " household_id congressional_district_geoid household_weight state_fips snap\n", + " 108 201 3.5 2 0.0\n", + " 25097 3701 2.5 37 0.0\n" ] } ], diff --git a/docs/myst.yml b/docs/myst.yml index c38666af9..56d74ec45 100644 --- a/docs/myst.yml +++ b/docs/myst.yml @@ -26,6 +26,7 @@ project: - file: methodology.md - file: long_term_projections.md - file: local_area_calibration_setup.ipynb + - file: calibration.md - file: discussion.md - file: conclusion.md - file: appendix.md diff --git a/modal_app/README.md b/modal_app/README.md index 0b10cf726..49c976aab 100644 --- a/modal_app/README.md +++ b/modal_app/README.md @@ -22,41 +22,217 @@ modal run modal_app/remote_calibration_runner.py --branch --epochs | `--epochs` | `200` | Number of training epochs | | `--gpu` | `T4` | GPU type: `T4`, `A10`, `A100-40GB`, `A100-80GB`, `H100` | | `--output` | `calibration_weights.npy` | Local path for weights file | -| `--log-output` | `calibration_log.csv` | Local path for calibration log | +| `--log-output` | `unified_diagnostics.csv` | Local path for diagnostics log | +| `--log-freq` | (none) | Log every N epochs to `calibration_log.csv` | +| `--push-results` | `False` | Upload weights, blocks, and logs to HuggingFace | +| `--trigger-publish` | `False` | Fire 
`repository_dispatch` to trigger the Publish workflow | +| `--target-config` | (none) | Target configuration name | +| `--beta` | (none) | L0 relaxation parameter | +| `--lambda-l0` | (none) | L0 penalty weight | +| `--lambda-l2` | (none) | L2 penalty weight | +| `--learning-rate` | (none) | Optimizer learning rate | +| `--package-path` | (none) | Local path to a pre-built calibration package (uploads to Modal volume, then fits) | +| `--prebuilt-matrices` | `False` | Fit from pre-built package on Modal volume | +| `--full-pipeline` | `False` | Force full rebuild even if a package exists on the volume | +| `--county-level` | `False` | Include county-level targets | +| `--workers` | `1` | Number of parallel workers | + +### Examples + +**Two-step workflow (recommended):** + +Step 1 — Build the X matrix on CPU (no GPU cost, 10h timeout): +```bash +modal run modal_app/remote_calibration_runner.py::build_package \ + --branch main +``` -### Example +Step 2 — Fit weights from the pre-built package on GPU: +```bash +modal run modal_app/remote_calibration_runner.py::main \ + --branch main --epochs 200 --gpu A100-80GB \ + --prebuilt-matrices --push-results +``` + +**Full pipeline (single step, requires enough timeout for matrix build + fit):** +```bash +modal run modal_app/remote_calibration_runner.py::main \ + --branch main --epochs 200 --gpu A100-80GB \ + --full-pipeline --push-results +``` +Fit, push, and trigger the publish workflow: ```bash -modal run modal_app/remote_calibration_runner.py --branch health-insurance-premiums --epochs 100 --gpu T4 +modal run modal_app/remote_calibration_runner.py::main \ + --gpu A100-80GB --epochs 200 \ + --prebuilt-matrices --push-results --trigger-publish ``` ## Output Files -- **calibration_weights.npy** - Fitted household weights -- **calibration_log.csv** - Per-target performance metrics across epochs (target_name, estimate, target, epoch, error, rel_error, abs_error, rel_abs_error, loss) +Every run produces these local files 
(whichever the calibration script emits): -## Changing Hyperparameters +- **calibration_weights.npy** — Fitted household weights +- **unified_diagnostics.csv** — Final per-target diagnostics +- **calibration_log.csv** — Per-target metrics across epochs (requires `--log-freq`) +- **unified_run_config.json** — Run configuration and summary stats +- **stacked_blocks.npy** — Census block assignments for stacked records -Hyperparameters are in `policyengine_us_data/datasets/cps/local_area_calibration/fit_calibration_weights.py`: +## Artifact Upload to HuggingFace + +The `--push-results` flag uploads all artifacts to HuggingFace in a single +atomic commit after writing them locally: + +| Local file | HF path | +|------------|---------| +| `calibration_weights.npy` | `calibration/calibration_weights.npy` | +| `stacked_blocks.npy` | `calibration/stacked_blocks.npy` | +| `calibration_log.csv` | `calibration/logs/calibration_log.csv` | +| `unified_diagnostics.csv` | `calibration/logs/unified_diagnostics.csv` | +| `unified_run_config.json` | `calibration/logs/unified_run_config.json` | + +Each upload overwrites the previous files. HF git history provides implicit +versioning — browse past commits to see earlier runs. + +## Triggering the Publish Workflow + +The `--trigger-publish` flag fires a `repository_dispatch` event +(`calibration-updated`) on GitHub, which starts the "Publish Local Area H5 +Files" workflow. Requires `GITHUB_TOKEN` or +`POLICYENGINE_US_DATA_GITHUB_TOKEN` set locally. + +### Downloading logs ```python -BETA = 0.35 -GAMMA = -0.1 -ZETA = 1.1 -INIT_KEEP_PROB = 0.999 -LOG_WEIGHT_JITTER_SD = 0.05 -LOG_ALPHA_JITTER_SD = 0.01 -LAMBDA_L0 = 1e-8 -LAMBDA_L2 = 1e-8 -LEARNING_RATE = 0.15 -``` - -To change them: -1. Edit `fit_calibration_weights.py` -2. Commit and push to your branch -3. 
Re-run the Modal command with that branch +from policyengine_us_data.utils.huggingface import download_calibration_logs + +paths = download_calibration_logs("/tmp/cal_logs") +# {"calibration_log": Path(...), "diagnostics": Path(...), "config": Path(...)} +``` + +Pass `version="<revision>"` to download from a specific HF revision. + +### Viewing logs in the microcalibrate dashboard + +The [microcalibrate dashboard](https://github.com/PolicyEngine/microcalibrate) +has a **Hugging Face** tab that loads `calibration_log.csv` directly from HF: + +1. Open the dashboard +2. Click the **Hugging Face** tab +3. Defaults are pre-filled — click **Load** +4. Change the **Revision** field to load from a specific HF commit or tag ## Important Notes -- **Keep your connection open** - Modal needs to stay connected to download results. Don't close your laptop or let it sleep until you see the local "Weights saved to:" and "Calibration log saved to:" messages. -- Modal clones from GitHub, so local changes must be pushed before they take effect. +- **Keep your connection open** — Modal needs to stay connected to download + results. Don't close your laptop or let it sleep until you see the local + "Weights saved to:" message. +- Modal clones from GitHub, so local changes must be pushed before they + take effect. +- `--push-results` requires the `HUGGING_FACE_TOKEN` environment variable + to be set locally (not just as a Modal secret). +- `--trigger-publish` requires `GITHUB_TOKEN` or + `POLICYENGINE_US_DATA_GITHUB_TOKEN` set locally. + +## Full Pipeline Reference + +The calibration pipeline has six stages. Each can be run locally, via Modal CLI, or via GitHub Actions. + +### Stage 1: Build data + +Produces `stratified_extended_cps_2024.h5` from raw CPS/PUF/ACS inputs.
+ +| Method | Command | +|--------|---------| +| **Local** | `make data` | +| **Modal (CI)** | `modal run modal_app/data_build.py --branch=<branch>` | +| **GitHub Actions** | Automatic on merge to `main` via `code_changes.yaml` → `reusable_test.yaml` (with `full_suite: true`). Also triggered by `pr_code_changes.yaml` on PRs. | + +Notes: +- `make data` stops at `create_stratified_cps.py`. Use `make data-legacy` to also build `enhanced_cps.py` and `small_enhanced_cps.py`. +- `data_build.py` (CI) always builds the full suite including enhanced_cps. + +### Stage 2: Upload inputs to HuggingFace + +Pushes the dataset and (optionally) database to HF so Modal can download them. + +| Artifact | Command | +|----------|---------| +| Dataset | `make upload-dataset` | +| Database | `make upload-database` | + +The database is relatively stable; only re-upload after `make database` or `make database-refresh`. + +### Stage 3: Build calibration matrices + +Downloads dataset + database from HF, builds the X matrix, saves to Modal volume. CPU-only, no GPU cost. + +| Method | Command | +|--------|---------| +| **Local** | `make calibrate-build` | +| **Modal CLI** | `make build-matrices BRANCH=<branch>` (aka `modal run modal_app/remote_calibration_runner.py::build_package --branch=<branch>`) | + +### Stage 4: Fit calibration weights + +Loads pre-built matrices from Modal volume, fits L0-regularized weights on GPU. + +| Method | Command | +|--------|---------| +| **Local (CPU)** | `make calibrate` | +| **Modal CLI** | `make calibrate-modal BRANCH=<branch> GPU=<gpu> EPOCHS=<n>` | + +`make calibrate-modal` passes `--prebuilt-matrices --push-results` automatically.
+ +Full example: +``` +modal run modal_app/remote_calibration_runner.py::main \ + --branch calibration-pipeline-improvements \ + --gpu T4 --epochs 1000 \ + --beta 0.65 --lambda-l0 1e-6 --lambda-l2 1e-8 \ + --log-freq 500 \ + --target-config policyengine_us_data/calibration/target_config.yaml \ + --prebuilt-matrices --push-results +``` + +**Safety check**: If a pre-built package exists on the volume and you don't pass `--prebuilt-matrices` or `--full-pipeline`, the runner refuses to proceed and tells you which flag to add. This prevents accidentally rebuilding from scratch. + +Artifacts uploaded to HF by `--push-results`: + +| Local file | HF path | +|------------|---------| +| `calibration_weights.npy` | `calibration/calibration_weights.npy` | +| `stacked_blocks.npy` | `calibration/stacked_blocks.npy` | +| `calibration_log.csv` | `calibration/logs/calibration_log.csv` | +| `unified_diagnostics.csv` | `calibration/logs/unified_diagnostics.csv` | +| `unified_run_config.json` | `calibration/logs/unified_run_config.json` | + +### Stage 5: Build and stage local area H5 files + +Downloads weights + dataset + database from HF, builds state/district/city H5 files. + +| Method | Command | +|--------|---------| +| **Local** | `python policyengine_us_data/calibration/publish_local_area.py --rerandomize-takeup` | +| **Modal CLI** | `make stage-h5s BRANCH=<branch>` (aka `modal run modal_app/local_area.py::main --branch=<branch> --num-workers=8`) | +| **GitHub Actions** | "Publish Local Area H5 Files" workflow — manual trigger via `workflow_dispatch`, or automatic via `repository_dispatch` (`--trigger-publish` flag), or on code push to `main` touching `calibration/` or `modal_app/`. | + +This stages H5s to HF `staging/` paths. It does NOT promote to production or GCS. + +### Stage 6: Promote (manual gate) + +Moves files from HF staging to production paths and uploads to GCS.
+ +| Method | Command | +|--------|---------| +| **Modal CLI** | `modal run modal_app/local_area.py::main_promote --version=` | +| **GitHub Actions** | "Promote Local Area H5 Files" workflow — manual `workflow_dispatch` only. Requires `version` input. | + +### One-command pipeline + +For the common case (local data build → Modal calibration → Modal staging): + +``` +make pipeline GPU=T4 EPOCHS=1000 BRANCH=calibration-pipeline-improvements +``` + +This chains: `data` → `upload-dataset` → `build-matrices` → `calibrate-modal` → `stage-h5s`. diff --git a/modal_app/data_build.py b/modal_app/data_build.py index 131e7f0bf..8c75187d2 100644 --- a/modal_app/data_build.py +++ b/modal_app/data_build.py @@ -55,10 +55,13 @@ "policyengine_us_data/storage/enhanced_cps_2024.h5", "calibration_log.csv", ], - "policyengine_us_data/datasets/cps/" - "local_area_calibration/create_stratified_cps.py": ( + "policyengine_us_data/calibration/create_stratified_cps.py": ( "policyengine_us_data/storage/stratified_extended_cps_2024.h5" ), + "policyengine_us_data/calibration/create_source_imputed_cps.py": ( + "policyengine_us_data/storage/" + "source_imputed_stratified_extended_cps_2024.h5" + ), "policyengine_us_data/datasets/cps/small_enhanced_cps.py": ( "policyengine_us_data/storage/small_enhanced_cps_2024.h5" ), @@ -70,7 +73,7 @@ "policyengine_us_data/tests/test_database.py", "policyengine_us_data/tests/test_pandas3_compatibility.py", "policyengine_us_data/tests/test_datasets/", - "policyengine_us_data/tests/test_local_area_calibration/", + "policyengine_us_data/tests/test_calibration/", ] @@ -408,11 +411,9 @@ def build_datasets( ), executor.submit( run_script_with_checkpoint, - "policyengine_us_data/datasets/cps/" - "local_area_calibration/create_stratified_cps.py", + "policyengine_us_data/calibration/create_stratified_cps.py", SCRIPT_OUTPUTS[ - "policyengine_us_data/datasets/cps/" - "local_area_calibration/create_stratified_cps.py" + 
"policyengine_us_data/calibration/create_stratified_cps.py" ], branch, checkpoint_volume, diff --git a/modal_app/local_area.py b/modal_app/local_area.py index 92e068335..33f798cf9 100644 --- a/modal_app/local_area.py +++ b/modal_app/local_area.py @@ -128,6 +128,88 @@ def get_completed_from_volume(version_dir: Path) -> set: return completed +def run_phase( + phase_name: str, + states: List[str], + districts: List[str], + cities: List[str], + num_workers: int, + completed: set, + branch: str, + version: str, + calibration_inputs: Dict[str, str], + version_dir: Path, +) -> set: + """Run a single build phase, spawning workers and collecting results.""" + work_chunks = partition_work( + states, districts, cities, num_workers, completed + ) + total_remaining = sum(len(c) for c in work_chunks) + + print(f"\n--- Phase: {phase_name} ---") + print( + f"Remaining work: {total_remaining} items " + f"across {len(work_chunks)} workers" + ) + + if total_remaining == 0: + print(f"All {phase_name} items already built!") + return completed + + handles = [] + for i, chunk in enumerate(work_chunks): + print(f" Worker {i}: {len(chunk)} items") + handle = build_areas_worker.spawn( + branch=branch, + version=version, + work_items=chunk, + calibration_inputs=calibration_inputs, + ) + handles.append(handle) + + print(f"Waiting for {phase_name} workers to complete...") + all_results = [] + all_errors = [] + + for i, handle in enumerate(handles): + try: + result = handle.get() + all_results.append(result) + print( + f" Worker {i}: {len(result['completed'])} completed, " + f"{len(result['failed'])} failed" + ) + if result["errors"]: + all_errors.extend(result["errors"]) + except Exception as e: + all_errors.append({"worker": i, "error": str(e)}) + print(f" Worker {i}: CRASHED - {e}") + + total_completed = sum(len(r["completed"]) for r in all_results) + total_failed = sum(len(r["failed"]) for r in all_results) + + staging_volume.reload() + volume_completed = 
get_completed_from_volume(version_dir) + volume_new = volume_completed - completed + + print(f"\n{phase_name} summary (worker-reported):") + print(f" Completed: {total_completed}") + print(f" Failed: {total_failed}") + print(f"{phase_name} summary (volume verification):") + print(f" Files on volume: {len(volume_completed)}") + print(f" New files this run: {len(volume_new)}") + + if all_errors: + print(f"\nErrors ({len(all_errors)}):") + for err in all_errors[:5]: + err_msg = err.get("error", "Unknown")[:100] + print(f" - {err.get('item', err.get('worker'))}: " f"{err_msg}") + if len(all_errors) > 5: + print(f" ... and {len(all_errors) - 5} more") + + return volume_completed + + @app.function( image=image, secrets=[hf_secret, gcp_secret], @@ -154,23 +236,46 @@ def build_areas_worker( work_items_json = json.dumps(work_items) + worker_cmd = [ + "uv", + "run", + "python", + "modal_app/worker_script.py", + "--work-items", + work_items_json, + "--weights-path", + calibration_inputs["weights"], + "--dataset-path", + calibration_inputs["dataset"], + "--db-path", + calibration_inputs["database"], + "--output-dir", + str(output_dir), + ] + if "blocks" in calibration_inputs: + worker_cmd.extend( + [ + "--calibration-blocks", + calibration_inputs["blocks"], + ] + ) + if "geo_labels" in calibration_inputs: + worker_cmd.extend( + [ + "--geo-labels", + calibration_inputs["geo_labels"], + ] + ) + if "stacked_takeup" in calibration_inputs: + worker_cmd.extend( + [ + "--stacked-takeup", + calibration_inputs["stacked_takeup"], + ] + ) + result = subprocess.run( - [ - "uv", - "run", - "python", - "modal_app/worker_script.py", - "--work-items", - work_items_json, - "--weights-path", - calibration_inputs["weights"], - "--dataset-path", - calibration_inputs["dataset"], - "--db-path", - calibration_inputs["database"], - "--output-dir", - str(output_dir), - ], + worker_cmd, capture_output=True, text=True, env=os.environ.copy(), @@ -254,18 +359,18 @@ def validate_staging(branch: str, 
version: str) -> Dict: @app.function( image=image, - secrets=[hf_secret, gcp_secret], + secrets=[hf_secret], volumes={VOLUME_MOUNT: staging_volume}, memory=8192, timeout=14400, ) def upload_to_staging(branch: str, version: str, manifest: Dict) -> str: """ - Upload files to GCS (production) and HuggingFace (staging only). + Upload files to HuggingFace staging only. + GCS is updated during promote_publish, not here. Promote must be run separately via promote_publish. """ - setup_gcp_credentials() setup_repo(branch) manifest_json = json.dumps(manifest) @@ -280,10 +385,7 @@ def upload_to_staging(branch: str, version: str, manifest: Dict) -> str: import json from pathlib import Path from policyengine_us_data.utils.manifest import verify_manifest -from policyengine_us_data.utils.data_upload import ( - upload_local_area_file, - upload_to_staging_hf, -) +from policyengine_us_data.utils.data_upload import upload_to_staging_hf manifest = json.loads('''{manifest_json}''') version = "{version}" @@ -305,20 +407,6 @@ def upload_to_staging(branch: str, version: str, manifest: Dict) -> str: local_path = version_dir / rel_path files_with_paths.append((local_path, rel_path)) -# Upload to GCS (direct to production paths) -print(f"Uploading {{len(files_with_paths)}} files to GCS...") -gcs_count = 0 -for local_path, rel_path in files_with_paths: - subdirectory = str(Path(rel_path).parent) - upload_local_area_file( - str(local_path), - subdirectory, - version=version, - skip_hf=True, - ) - gcs_count += 1 -print(f"Uploaded {{gcs_count}} files to GCS") - # Upload to HuggingFace staging/ print(f"Uploading {{len(files_with_paths)}} files to HuggingFace staging/...") hf_count = upload_to_staging_hf(files_with_paths, version) @@ -336,24 +424,26 @@ def upload_to_staging(branch: str, version: str, manifest: Dict) -> str: return ( f"Staged version {version} with {len(manifest['files'])} files. " - f"Run promote workflow to publish to HuggingFace production." 
+ f"Run promote workflow to publish to HuggingFace production and GCS." ) @app.function( image=image, - secrets=[hf_secret], + secrets=[hf_secret, gcp_secret], volumes={VOLUME_MOUNT: staging_volume}, memory=4096, timeout=3600, ) def promote_publish(branch: str = "main", version: str = "") -> str: """ - Promote staged files from HF staging/ to production paths, then cleanup. + Promote staged files from HF staging/ to production paths, + upload to GCS, then cleanup HF staging. Reads the manifest from the Modal staging volume to determine which files to promote. """ + setup_gcp_credentials() setup_repo(branch) staging_dir = Path(VOLUME_MOUNT) @@ -379,17 +469,34 @@ def promote_publish(branch: str = "main", version: str = "") -> str: "-c", f""" import json +from pathlib import Path from policyengine_us_data.utils.data_upload import ( promote_staging_to_production_hf, cleanup_staging_hf, + upload_local_area_file, ) rel_paths = json.loads('''{rel_paths_json}''') version = "{version}" +version_dir = Path("{VOLUME_MOUNT}") / version print(f"Promoting {{len(rel_paths)}} files from staging/ to production...") promoted = promote_staging_to_production_hf(rel_paths, version) -print(f"Promoted {{promoted}} files to production") +print(f"Promoted {{promoted}} files to HuggingFace production") + +print(f"Uploading {{len(rel_paths)}} files to GCS...") +gcs_count = 0 +for rel_path in rel_paths: + local_path = version_dir / rel_path + subdirectory = str(Path(rel_path).parent) + upload_local_area_file( + str(local_path), + subdirectory, + version=version, + skip_hf=True, + ) + gcs_count += 1 +print(f"Uploaded {{gcs_count}} files to GCS") print("Cleaning up staging/...") cleaned = cleanup_staging_hf(rel_paths, version) @@ -419,6 +526,7 @@ def coordinate_publish( branch: str = "main", num_workers: int = 8, skip_upload: bool = False, + skip_download: bool = False, ) -> str: """Coordinate the full publishing workflow.""" setup_gcp_credentials() @@ -428,24 +536,46 @@ def coordinate_publish( 
print(f"Publishing version {version} from branch {branch}") print(f"Using {num_workers} parallel workers") + import shutil + staging_dir = Path(VOLUME_MOUNT) version_dir = staging_dir / version + if version_dir.exists(): + print(f"Clearing stale build directory: {version_dir}") + shutil.rmtree(version_dir) version_dir.mkdir(parents=True, exist_ok=True) calibration_dir = staging_dir / "calibration_inputs" - calibration_dir.mkdir(parents=True, exist_ok=True) # hf_hub_download preserves directory structure, so files are in calibration/ subdir - weights_path = ( - calibration_dir / "calibration" / "w_district_calibration.npy" - ) - dataset_path = ( - calibration_dir / "calibration" / "stratified_extended_cps.h5" - ) + weights_path = calibration_dir / "calibration" / "calibration_weights.npy" db_path = calibration_dir / "calibration" / "policy_data.db" - if not all(p.exists() for p in [weights_path, dataset_path, db_path]): - print("Downloading calibration inputs...") + if skip_download: + print("Verifying pre-pushed calibration inputs...") + staging_volume.reload() + dataset_path = ( + calibration_dir + / "calibration" + / "source_imputed_stratified_extended_cps.h5" + ) + required = { + "weights": weights_path, + "dataset": dataset_path, + "database": db_path, + } + for label, p in required.items(): + if not p.exists(): + raise RuntimeError( + f"Missing required calibration input " f"({label}): {p}" + ) + print("All required calibration inputs found on volume.") + else: + if calibration_dir.exists(): + shutil.rmtree(calibration_dir) + calibration_dir.mkdir(parents=True, exist_ok=True) + + print("Downloading calibration inputs from HuggingFace...") result = subprocess.run( [ "uv", @@ -464,15 +594,31 @@ def coordinate_publish( if result.returncode != 0: raise RuntimeError(f"Download failed: {result.stderr}") staging_volume.commit() - print("Calibration inputs downloaded and cached on volume") - else: - print("Using cached calibration inputs from volume") + 
print("Calibration inputs downloaded") + dataset_path = ( + calibration_dir + / "calibration" + / "source_imputed_stratified_extended_cps.h5" + ) + + blocks_path = calibration_dir / "calibration" / "stacked_blocks.npy" + geo_labels_path = calibration_dir / "calibration" / "geo_labels.json" calibration_inputs = { "weights": str(weights_path), "dataset": str(dataset_path), "database": str(db_path), } + if blocks_path.exists(): + calibration_inputs["blocks"] = str(blocks_path) + print(f"Calibration blocks found: {blocks_path}") + if geo_labels_path.exists(): + calibration_inputs["geo_labels"] = str(geo_labels_path) + print(f"Geo labels found: {geo_labels_path}") + takeup_path = calibration_dir / "calibration" / "stacked_takeup.npz" + if takeup_path.exists(): + calibration_inputs["stacked_takeup"] = str(takeup_path) + print(f"Stacked takeup found: {takeup_path}") result = subprocess.run( [ @@ -482,11 +628,11 @@ def coordinate_publish( "-c", f""" import json -from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( +from policyengine_us_data.calibration.calibration_utils import ( get_all_cds_from_database, STATE_CODES, ) -from policyengine_us_data.datasets.cps.local_area_calibration.publish_local_area import ( +from policyengine_us_data.calibration.publish_local_area import ( get_district_friendly_name, ) @@ -514,70 +660,49 @@ def coordinate_publish( completed = get_completed_from_volume(version_dir) print(f"Found {len(completed)} already-completed items on volume") - work_chunks = partition_work( - states, districts, cities, num_workers, completed + phase_args = dict( + num_workers=num_workers, + branch=branch, + version=version, + calibration_inputs=calibration_inputs, + version_dir=version_dir, ) - total_remaining = sum(len(c) for c in work_chunks) - print( - f"Remaining work: {total_remaining} items " - f"across {len(work_chunks)} workers" + completed = run_phase( + "States", + states=states, + districts=[], + cities=[], + 
completed=completed, + **phase_args, ) - if total_remaining == 0: - print("All items already built!") - else: - print("\nSpawning workers...") - handles = [] - for i, chunk in enumerate(work_chunks): - print(f" Worker {i}: {len(chunk)} items") - handle = build_areas_worker.spawn( - branch=branch, - version=version, - work_items=chunk, - calibration_inputs=calibration_inputs, - ) - handles.append(handle) - - print("\nWaiting for workers to complete...") - all_results = [] - all_errors = [] - - for i, handle in enumerate(handles): - try: - result = handle.get() - all_results.append(result) - print( - f" Worker {i}: {len(result['completed'])} completed, " - f"{len(result['failed'])} failed" - ) - if result["errors"]: - all_errors.extend(result["errors"]) - except Exception as e: - all_errors.append({"worker": i, "error": str(e)}) - print(f" Worker {i}: CRASHED - {e}") - - total_completed = sum(len(r["completed"]) for r in all_results) - total_failed = sum(len(r["failed"]) for r in all_results) - - print(f"\nBuild summary:") - print(f" Completed: {total_completed}") - print(f" Failed: {total_failed}") - print(f" Previously completed: {len(completed)}") - - if all_errors: - print(f"\nErrors ({len(all_errors)}):") - for err in all_errors[:5]: - err_msg = err.get("error", "Unknown")[:100] - print(f" - {err.get('item', err.get('worker'))}: {err_msg}") - if len(all_errors) > 5: - print(f" ... and {len(all_errors) - 5} more") - - if total_failed > 0: - raise RuntimeError( - f"Build incomplete: {total_failed} failures. " - f"Volume preserved for retry." 
- ) + completed = run_phase( + "Districts", + states=[], + districts=districts, + cities=[], + completed=completed, + **phase_args, + ) + + completed = run_phase( + "Cities", + states=[], + districts=[], + cities=cities, + completed=completed, + **phase_args, + ) + + expected_total = len(states) + len(districts) + len(cities) + if len(completed) < expected_total: + missing = expected_total - len(completed) + raise RuntimeError( + f"Build incomplete: {missing} files missing from " + f"volume ({len(completed)}/{expected_total}). " + f"Volume preserved for retry." + ) if skip_upload: print("\nSkipping upload (--skip-upload flag set)") @@ -625,13 +750,235 @@ def main( branch: str = "main", num_workers: int = 8, skip_upload: bool = False, + skip_download: bool = False, ): """Local entrypoint for Modal CLI.""" result = coordinate_publish.remote( branch=branch, num_workers=num_workers, skip_upload=skip_upload, + skip_download=skip_download, + ) + print(result) + + +@app.function( + image=image, + secrets=[hf_secret, gcp_secret], + volumes={VOLUME_MOUNT: staging_volume}, + memory=16384, + timeout=14400, +) +def coordinate_national_publish( + branch: str = "main", +) -> str: + """Build and upload a national US.h5 from national weights.""" + setup_gcp_credentials() + setup_repo(branch) + + version = get_version() + print( + f"Building national H5 for version {version} " f"from branch {branch}" + ) + + import shutil + + staging_dir = Path(VOLUME_MOUNT) + calibration_dir = staging_dir / "national_calibration_inputs" + if calibration_dir.exists(): + shutil.rmtree(calibration_dir) + calibration_dir.mkdir(parents=True, exist_ok=True) + + print("Downloading national calibration inputs from HF...") + result = subprocess.run( + [ + "uv", + "run", + "python", + "-c", + f""" +from policyengine_us_data.utils.huggingface import ( + download_calibration_inputs, +) +download_calibration_inputs("{calibration_dir}", prefix="national_") +print("Done") +""", + ], + text=True, + 
env=os.environ.copy(), + ) + if result.returncode != 0: + raise RuntimeError(f"Download failed: {result.stderr}") + staging_volume.commit() + print("National calibration inputs downloaded") + + weights_path = ( + calibration_dir / "calibration" / "national_calibration_weights.npy" + ) + db_path = calibration_dir / "calibration" / "policy_data.db" + dataset_path = ( + calibration_dir + / "calibration" + / "source_imputed_stratified_extended_cps.h5" + ) + + blocks_path = ( + calibration_dir / "calibration" / "national_stacked_blocks.npy" ) + national_geo_labels_path = ( + calibration_dir / "calibration" / "national_geo_labels.json" + ) + calibration_inputs = { + "weights": str(weights_path), + "dataset": str(dataset_path), + "database": str(db_path), + } + if blocks_path.exists(): + calibration_inputs["blocks"] = str(blocks_path) + print(f"National calibration blocks found: {blocks_path}") + if national_geo_labels_path.exists(): + calibration_inputs["geo_labels"] = str(national_geo_labels_path) + print(f"National geo labels found: " f"{national_geo_labels_path}") + national_takeup_path = ( + calibration_dir / "calibration" / "national_stacked_takeup.npz" + ) + if national_takeup_path.exists(): + calibration_inputs["stacked_takeup"] = str(national_takeup_path) + print(f"National stacked takeup found: " f"{national_takeup_path}") + + version_dir = staging_dir / version + version_dir.mkdir(parents=True, exist_ok=True) + + work_items = [{"type": "national", "id": "US"}] + print("Spawning worker for national H5 build...") + worker_result = build_areas_worker.remote( + branch=branch, + version=version, + work_items=work_items, + calibration_inputs=calibration_inputs, + ) + + print( + f"Worker result: " + f"{len(worker_result['completed'])} completed, " + f"{len(worker_result['failed'])} failed" + ) + + if worker_result["failed"]: + raise RuntimeError(f"National build failed: {worker_result['errors']}") + + staging_volume.reload() + national_h5 = version_dir / "national" / 
"US.h5" + if not national_h5.exists(): + raise RuntimeError(f"Expected {national_h5} not found after build") + + print(f"Uploading {national_h5} to HF staging...") + result = subprocess.run( + [ + "uv", + "run", + "python", + "-c", + f""" +from policyengine_us_data.utils.data_upload import ( + upload_to_staging_hf, +) +upload_to_staging_hf( + [("{national_h5}", "national/US.h5")], + "{version}", +) +print("Done") +""", + ], + text=True, + env=os.environ.copy(), + ) + if result.returncode != 0: + raise RuntimeError(f"Staging upload failed: {result.stderr}") + + print("National H5 staged. Run promote workflow to publish.") + return ( + f"National US.h5 built and staged for version {version}. " + f"Run main_national_promote to publish." + ) + + +@app.local_entrypoint() +def main_national(branch: str = "main"): + """Build and stage national US.h5.""" + result = coordinate_national_publish.remote(branch=branch) + print(result) + + +@app.function( + image=image, + secrets=[hf_secret, gcp_secret], + volumes={VOLUME_MOUNT: staging_volume}, + memory=4096, + timeout=3600, +) +def promote_national_publish( + branch: str = "main", +) -> str: + """Promote national US.h5 from HF staging to production + GCS.""" + setup_gcp_credentials() + setup_repo(branch) + + version = get_version() + rel_paths = ["national/US.h5"] + + result = subprocess.run( + [ + "uv", + "run", + "python", + "-c", + f""" +import json +from pathlib import Path +from policyengine_us_data.utils.data_upload import ( + promote_staging_to_production_hf, + cleanup_staging_hf, + upload_local_area_file, +) + +version = "{version}" +rel_paths = {json.dumps(rel_paths)} +version_dir = Path("{VOLUME_MOUNT}") / version + +print(f"Promoting national H5 from staging to production...") +promoted = promote_staging_to_production_hf(rel_paths, version) +print(f"Promoted {{promoted}} files to HuggingFace production") + +national_h5 = version_dir / "national" / "US.h5" +if national_h5.exists(): + print("Uploading national H5 to 
GCS...") + upload_local_area_file( + str(national_h5), "national", version=version, skip_hf=True + ) + print("Uploaded national H5 to GCS") +else: + print(f"WARNING: {{national_h5}} not on volume, skipping GCS") + +print("Cleaning up staging...") +cleaned = cleanup_staging_hf(rel_paths, version) +print(f"Cleaned up {{cleaned}} files from staging") +print(f"Successfully promoted national H5 for version {{version}}") +""", + ], + text=True, + env=os.environ.copy(), + ) + if result.returncode != 0: + raise RuntimeError(f"National promote failed: {result.stderr}") + + return f"National US.h5 promoted for version {version}" + + +@app.local_entrypoint() +def main_national_promote(branch: str = "main"): + """Promote staged national US.h5 to production.""" + result = promote_national_publish.remote(branch=branch) print(result) diff --git a/modal_app/remote_calibration_runner.py b/modal_app/remote_calibration_runner.py index 689d245dd..87b2fb833 100644 --- a/modal_app/remote_calibration_runner.py +++ b/modal_app/remote_calibration_runner.py @@ -5,6 +5,9 @@ app = modal.App("policyengine-us-data-fit-weights") hf_secret = modal.Secret.from_name("huggingface-token") +calibration_vol = modal.Volume.from_name( + "calibration-data", create_if_missing=True +) image = ( modal.Image.debian_slim(python_version="3.11") @@ -13,18 +16,211 @@ ) REPO_URL = "https://github.com/PolicyEngine/policyengine-us-data.git" +VOLUME_MOUNT = "/calibration-data" + + +def _run_streaming(cmd, env=None, label=""): + """Run a subprocess, streaming output line-by-line. + + Returns (returncode, captured_stdout_lines). 
+ """ + proc = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + env=env, + ) + lines = [] + for line in proc.stdout: + line = line.rstrip("\n") + if label: + print(f"[{label}] {line}", flush=True) + else: + print(line, flush=True) + lines.append(line) + proc.wait() + return proc.returncode, lines -def _fit_weights_impl(branch: str, epochs: int) -> dict: - """Shared implementation for weight fitting.""" +def _clone_and_install(branch: str): + """Clone the repo and install dependencies.""" os.chdir("/root") subprocess.run(["git", "clone", "-b", branch, REPO_URL], check=True) os.chdir("policyengine-us-data") - subprocess.run(["uv", "sync", "--extra", "l0"], check=True) - print("Downloading calibration inputs from HuggingFace...") - download_result = subprocess.run( + +def _append_hyperparams( + cmd, beta, lambda_l0, lambda_l2, learning_rate, log_freq=None +): + """Append optional hyperparameter flags to a command list.""" + if beta is not None: + cmd.extend(["--beta", str(beta)]) + if lambda_l0 is not None: + cmd.extend(["--lambda-l0", str(lambda_l0)]) + if lambda_l2 is not None: + cmd.extend(["--lambda-l2", str(lambda_l2)]) + if learning_rate is not None: + cmd.extend(["--learning-rate", str(learning_rate)]) + if log_freq is not None: + cmd.extend(["--log-freq", str(log_freq)]) + + +def _collect_outputs(cal_lines): + """Extract weights and log bytes from calibration output lines.""" + output_path = None + log_path = None + cal_log_path = None + config_path = None + blocks_path = None + geo_labels_path = None + for line in cal_lines: + if "OUTPUT_PATH:" in line: + output_path = line.split("OUTPUT_PATH:")[1].strip() + elif "CONFIG_PATH:" in line: + config_path = line.split("CONFIG_PATH:")[1].strip() + elif "CAL_LOG_PATH:" in line: + cal_log_path = line.split("CAL_LOG_PATH:")[1].strip() + elif "GEO_LABELS_PATH:" in line: + geo_labels_path = line.split("GEO_LABELS_PATH:")[1].strip() + elif "BLOCKS_PATH:" in line: 
+ blocks_path = line.split("BLOCKS_PATH:")[1].strip() + elif "LOG_PATH:" in line: + log_path = line.split("LOG_PATH:")[1].strip() + + with open(output_path, "rb") as f: + weights_bytes = f.read() + + log_bytes = None + if log_path: + with open(log_path, "rb") as f: + log_bytes = f.read() + + cal_log_bytes = None + if cal_log_path: + with open(cal_log_path, "rb") as f: + cal_log_bytes = f.read() + + config_bytes = None + if config_path: + with open(config_path, "rb") as f: + config_bytes = f.read() + + blocks_bytes = None + if blocks_path and os.path.exists(blocks_path): + with open(blocks_path, "rb") as f: + blocks_bytes = f.read() + + geo_labels_bytes = None + if geo_labels_path and os.path.exists(geo_labels_path): + with open(geo_labels_path, "rb") as f: + geo_labels_bytes = f.read() + + return { + "weights": weights_bytes, + "log": log_bytes, + "cal_log": cal_log_bytes, + "config": config_bytes, + "blocks": blocks_bytes, + "geo_labels": geo_labels_bytes, + } + + +def _trigger_repository_dispatch(event_type: str = "calibration-updated"): + """Fire a repository_dispatch event on GitHub.""" + import json + import urllib.request + + token = os.environ.get( + "GITHUB_TOKEN", + os.environ.get("POLICYENGINE_US_DATA_GITHUB_TOKEN"), + ) + if not token: + print( + "WARNING: No GITHUB_TOKEN or " + "POLICYENGINE_US_DATA_GITHUB_TOKEN found. 
" + "Skipping repository_dispatch.", + flush=True, + ) + return False + + url = ( + "https://api.github.com/repos/" + "PolicyEngine/policyengine-us-data/dispatches" + ) + payload = json.dumps({"event_type": event_type}).encode() + req = urllib.request.Request( + url, + data=payload, + headers={ + "Accept": "application/vnd.github+json", + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + }, + method="POST", + ) + resp = urllib.request.urlopen(req) + print( + f"Triggered repository_dispatch '{event_type}' " + f"(HTTP {resp.status})", + flush=True, + ) + return True + + +def _upload_source_imputed(lines): + """Parse SOURCE_IMPUTED_PATH from output and upload to HF.""" + source_path = None + for line in lines: + if "SOURCE_IMPUTED_PATH:" in line: + raw = line.split("SOURCE_IMPUTED_PATH:")[1].strip() + source_path = raw.split("]")[-1].strip() if "]" in raw else raw + if not source_path or not os.path.exists(source_path): + return + print(f"Uploading source-imputed dataset: {source_path}", flush=True) + rc, _ = _run_streaming( + [ + "uv", + "run", + "python", + "-c", + "from policyengine_us_data.utils.huggingface import upload; " + f"upload('{source_path}', " + "'policyengine/policyengine-us-data', " + "'calibration/" + "source_imputed_stratified_extended_cps.h5')", + ], + env=os.environ.copy(), + label="upload-source-imputed", + ) + if rc != 0: + print( + "WARNING: Failed to upload source-imputed dataset", + flush=True, + ) + else: + print("Source-imputed dataset uploaded to HF", flush=True) + + +def _fit_weights_impl( + branch: str, + epochs: int, + target_config: str = None, + beta: float = None, + lambda_l0: float = None, + lambda_l2: float = None, + learning_rate: float = None, + log_freq: int = None, + skip_county: bool = True, + workers: int = 1, +) -> dict: + """Full pipeline: download data, build matrix, fit weights.""" + _clone_and_install(branch) + + print("Downloading calibration inputs from HuggingFace...", flush=True) + dl_rc, 
dl_lines = _run_streaming( [ "uv", "run", @@ -36,66 +232,324 @@ def _fit_weights_impl(branch: str, epochs: int) -> dict: "print(f\"DB: {paths['database']}\"); " "print(f\"DATASET: {paths['dataset']}\")", ], - capture_output=True, - text=True, env=os.environ.copy(), + label="download", ) - print(download_result.stdout) - if download_result.stderr: - print("Download STDERR:", download_result.stderr) - if download_result.returncode != 0: - raise RuntimeError(f"Download failed: {download_result.returncode}") + if dl_rc != 0: + raise RuntimeError(f"Download failed with code {dl_rc}") db_path = dataset_path = None - for line in download_result.stdout.split("\n"): - if line.startswith("DB:"): + for line in dl_lines: + if "DB:" in line: db_path = line.split("DB:")[1].strip() - elif line.startswith("DATASET:"): + elif "DATASET:" in line: dataset_path = line.split("DATASET:")[1].strip() script_path = "policyengine_us_data/calibration/unified_calibration.py" - result = subprocess.run( + cmd = [ + "uv", + "run", + "python", + script_path, + "--device", + "cuda", + "--epochs", + str(epochs), + "--db-path", + db_path, + "--dataset", + dataset_path, + ] + if target_config: + cmd.extend(["--target-config", target_config]) + if not skip_county: + cmd.append("--county-level") + if workers > 1: + cmd.extend(["--workers", str(workers)]) + _append_hyperparams( + cmd, beta, lambda_l0, lambda_l2, learning_rate, log_freq + ) + + cal_rc, cal_lines = _run_streaming( + cmd, + env=os.environ.copy(), + label="calibrate", + ) + if cal_rc != 0: + raise RuntimeError(f"Script failed with code {cal_rc}") + + _upload_source_imputed(cal_lines) + + return _collect_outputs(cal_lines) + + +def _fit_from_package_impl( + branch: str, + epochs: int, + volume_package_path: str = None, + target_config: str = None, + beta: float = None, + lambda_l0: float = None, + lambda_l2: float = None, + learning_rate: float = None, + log_freq: int = None, +) -> dict: + """Fit weights from a pre-built calibration 
package.""" + if not volume_package_path: + raise ValueError("volume_package_path is required") + + _clone_and_install(branch) + + pkg_path = "/root/calibration_package.pkl" + import shutil + + shutil.copy(volume_package_path, pkg_path) + size = os.path.getsize(pkg_path) + print( + f"Copied package from volume ({size:,} bytes) to {pkg_path}", + flush=True, + ) + + script_path = "policyengine_us_data/calibration/unified_calibration.py" + cmd = [ + "uv", + "run", + "python", + script_path, + "--device", + "cuda", + "--epochs", + str(epochs), + "--package-path", + pkg_path, + ] + if target_config: + cmd.extend(["--target-config", target_config]) + _append_hyperparams( + cmd, beta, lambda_l0, lambda_l2, learning_rate, log_freq + ) + + print(f"Running command: {' '.join(cmd)}", flush=True) + + cal_rc, cal_lines = _run_streaming( + cmd, + env=os.environ.copy(), + label="calibrate", + ) + if cal_rc != 0: + raise RuntimeError(f"Script failed with code {cal_rc}") + + return _collect_outputs(cal_lines) + + +def _print_provenance_from_meta( + meta: dict, current_branch: str = None +) -> None: + """Print provenance info and warn on branch mismatch.""" + built = meta.get("created_at", "unknown") + branch = meta.get("git_branch", "unknown") + commit = meta.get("git_commit") + commit_short = commit[:8] if commit else "unknown" + dirty = " (DIRTY)" if meta.get("git_dirty") else "" + version = meta.get("package_version", "unknown") + print("--- Package Provenance ---", flush=True) + print(f" Built: {built}", flush=True) + print( + f" Branch: {branch} @ {commit_short}{dirty}", + flush=True, + ) + print(f" Version: {version}", flush=True) + print("--------------------------", flush=True) + if current_branch and branch != "unknown" and branch != current_branch: + print( + f"WARNING: Package built on branch " + f"'{branch}', but fitting with " + f"--branch {current_branch}", + flush=True, + ) + + +def _write_package_sidecar(pkg_path: str) -> None: + """Extract metadata from a pickle 
package and write a JSON sidecar.""" + import json + import pickle + + sidecar_path = pkg_path.replace(".pkl", "_meta.json") + try: + with open(pkg_path, "rb") as f: + package = pickle.load(f) + meta = package.get("metadata", {}) + del package + with open(sidecar_path, "w") as f: + json.dump(meta, f, indent=2) + print( + f"Sidecar metadata written to {sidecar_path}", + flush=True, + ) + except Exception as e: + print( + f"WARNING: Failed to write sidecar: {e}", + flush=True, + ) + + +def _build_package_impl( + branch: str, + target_config: str = None, + skip_county: bool = True, + workers: int = 1, +) -> str: + """Download data, build X matrix, save package to volume.""" + _clone_and_install(branch) + + print( + "Downloading calibration inputs from HuggingFace...", + flush=True, + ) + dl_rc, dl_lines = _run_streaming( [ "uv", "run", "python", - script_path, - "--device", - "cuda", - "--epochs", - str(epochs), - "--db-path", - db_path, - "--dataset", - dataset_path, + "-c", + "from policyengine_us_data.utils.huggingface import " + "download_calibration_inputs; " + "paths = download_calibration_inputs(" + "'/root/calibration_data'); " + "print(f\"DB: {paths['database']}\"); " + "print(f\"DATASET: {paths['dataset']}\")", ], - capture_output=True, - text=True, env=os.environ.copy(), + label="download", ) - print(result.stdout) - if result.stderr: - print("STDERR:", result.stderr) - if result.returncode != 0: - raise RuntimeError(f"Script failed with code {result.returncode}") + if dl_rc != 0: + raise RuntimeError(f"Download failed with code {dl_rc}") - output_path = None - log_path = None - for line in result.stdout.split("\n"): - if "OUTPUT_PATH:" in line: - output_path = line.split("OUTPUT_PATH:")[1].strip() - elif "LOG_PATH:" in line: - log_path = line.split("LOG_PATH:")[1].strip() + db_path = dataset_path = None + for line in dl_lines: + if "DB:" in line: + db_path = line.split("DB:")[1].strip() + elif "DATASET:" in line: + dataset_path = 
line.split("DATASET:")[1].strip() - with open(output_path, "rb") as f: - weights_bytes = f.read() + pkg_path = f"{VOLUME_MOUNT}/calibration_package.pkl" + script_path = "policyengine_us_data/calibration/unified_calibration.py" + cmd = [ + "uv", + "run", + "python", + script_path, + "--device", + "cpu", + "--epochs", + "0", + "--db-path", + db_path, + "--dataset", + dataset_path, + "--build-only", + "--package-output", + pkg_path, + ] + if target_config: + cmd.extend(["--target-config", target_config]) + if not skip_county: + cmd.append("--county-level") + if workers > 1: + cmd.extend(["--workers", str(workers)]) - log_bytes = None - if log_path: - with open(log_path, "rb") as f: - log_bytes = f.read() + build_rc, build_lines = _run_streaming( + cmd, + env=os.environ.copy(), + label="build", + ) + if build_rc != 0: + raise RuntimeError(f"Package build failed with code {build_rc}") - return {"weights": weights_bytes, "log": log_bytes} + _upload_source_imputed(build_lines) + + _write_package_sidecar(pkg_path) + + size = os.path.getsize(pkg_path) + print( + f"Package saved to volume at {pkg_path} " f"({size:,} bytes)", + flush=True, + ) + calibration_vol.commit() + return pkg_path + + +@app.function( + image=image, + secrets=[hf_secret], + memory=65536, + cpu=4.0, + timeout=36000, + volumes={VOLUME_MOUNT: calibration_vol}, +) +def build_package_remote( + branch: str = "main", + target_config: str = None, + skip_county: bool = True, + workers: int = 1, +) -> str: + return _build_package_impl( + branch, + target_config=target_config, + skip_county=skip_county, + workers=workers, + ) + + +@app.function( + image=image, + timeout=30, + volumes={VOLUME_MOUNT: calibration_vol}, +) +def check_volume_package() -> dict: + """Check if a calibration package exists on the volume. + + Reads the lightweight JSON sidecar for provenance fields. + Falls back to size/mtime if sidecar is missing. 
+ """ + import datetime + import json + + pkg_path = f"{VOLUME_MOUNT}/calibration_package.pkl" + sidecar_path = f"{VOLUME_MOUNT}/calibration_package_meta.json" + if not os.path.exists(pkg_path): + return {"exists": False} + + stat = os.stat(pkg_path) + mtime = datetime.datetime.fromtimestamp( + stat.st_mtime, tz=datetime.timezone.utc + ) + info = { + "exists": True, + "size": stat.st_size, + "modified": mtime.strftime("%Y-%m-%d %H:%M UTC"), + } + if os.path.exists(sidecar_path): + try: + with open(sidecar_path) as f: + meta = json.load(f) + for key in ( + "git_branch", + "git_commit", + "git_dirty", + "package_version", + "created_at", + "dataset_sha256", + "db_sha256", + ): + if key in meta: + info[key] = meta[key] + except Exception: + pass + return info + + +# --- Full pipeline GPU functions --- @app.function( @@ -106,8 +560,30 @@ def _fit_weights_impl(branch: str, epochs: int) -> dict: gpu="T4", timeout=14400, ) -def fit_weights_t4(branch: str = "main", epochs: int = 200) -> dict: - return _fit_weights_impl(branch, epochs) +def fit_weights_t4( + branch: str = "main", + epochs: int = 200, + target_config: str = None, + beta: float = None, + lambda_l0: float = None, + lambda_l2: float = None, + learning_rate: float = None, + log_freq: int = None, + skip_county: bool = True, + workers: int = 1, +) -> dict: + return _fit_weights_impl( + branch, + epochs, + target_config, + beta, + lambda_l0, + lambda_l2, + learning_rate, + log_freq, + skip_county=skip_county, + workers=workers, + ) @app.function( @@ -118,8 +594,30 @@ def fit_weights_t4(branch: str = "main", epochs: int = 200) -> dict: gpu="A10", timeout=14400, ) -def fit_weights_a10(branch: str = "main", epochs: int = 200) -> dict: - return _fit_weights_impl(branch, epochs) +def fit_weights_a10( + branch: str = "main", + epochs: int = 200, + target_config: str = None, + beta: float = None, + lambda_l0: float = None, + lambda_l2: float = None, + learning_rate: float = None, + log_freq: int = None, + skip_county: 
bool = True, + workers: int = 1, +) -> dict: + return _fit_weights_impl( + branch, + epochs, + target_config, + beta, + lambda_l0, + lambda_l2, + learning_rate, + log_freq, + skip_county=skip_county, + workers=workers, + ) @app.function( @@ -130,8 +628,30 @@ def fit_weights_a10(branch: str = "main", epochs: int = 200) -> dict: gpu="A100-40GB", timeout=14400, ) -def fit_weights_a100_40(branch: str = "main", epochs: int = 200) -> dict: - return _fit_weights_impl(branch, epochs) +def fit_weights_a100_40( + branch: str = "main", + epochs: int = 200, + target_config: str = None, + beta: float = None, + lambda_l0: float = None, + lambda_l2: float = None, + learning_rate: float = None, + log_freq: int = None, + skip_county: bool = True, + workers: int = 1, +) -> dict: + return _fit_weights_impl( + branch, + epochs, + target_config, + beta, + lambda_l0, + lambda_l2, + learning_rate, + log_freq, + skip_county=skip_county, + workers=workers, + ) @app.function( @@ -142,8 +662,30 @@ def fit_weights_a100_40(branch: str = "main", epochs: int = 200) -> dict: gpu="A100-80GB", timeout=14400, ) -def fit_weights_a100_80(branch: str = "main", epochs: int = 200) -> dict: - return _fit_weights_impl(branch, epochs) +def fit_weights_a100_80( + branch: str = "main", + epochs: int = 200, + target_config: str = None, + beta: float = None, + lambda_l0: float = None, + lambda_l2: float = None, + learning_rate: float = None, + log_freq: int = None, + skip_county: bool = True, + workers: int = 1, +) -> dict: + return _fit_weights_impl( + branch, + epochs, + target_config, + beta, + lambda_l0, + lambda_l2, + learning_rate, + log_freq, + skip_county=skip_county, + workers=workers, + ) @app.function( @@ -154,8 +696,30 @@ def fit_weights_a100_80(branch: str = "main", epochs: int = 200) -> dict: gpu="H100", timeout=14400, ) -def fit_weights_h100(branch: str = "main", epochs: int = 200) -> dict: - return _fit_weights_impl(branch, epochs) +def fit_weights_h100( + branch: str = "main", + epochs: int = 
200, + target_config: str = None, + beta: float = None, + lambda_l0: float = None, + lambda_l2: float = None, + learning_rate: float = None, + log_freq: int = None, + skip_county: bool = True, + workers: int = 1, +) -> dict: + return _fit_weights_impl( + branch, + epochs, + target_config, + beta, + lambda_l0, + lambda_l2, + learning_rate, + log_freq, + skip_county=skip_county, + workers=workers, + ) GPU_FUNCTIONS = { @@ -167,22 +731,344 @@ def fit_weights_h100(branch: str = "main", epochs: int = 200) -> dict: } +# --- Package-path GPU functions --- + + +@app.function( + image=image, + memory=32768, + cpu=4.0, + gpu="T4", + timeout=14400, + volumes={"/calibration-data": calibration_vol}, +) +def fit_from_package_t4( + branch: str = "main", + epochs: int = 200, + target_config: str = None, + beta: float = None, + lambda_l0: float = None, + lambda_l2: float = None, + learning_rate: float = None, + log_freq: int = None, + volume_package_path: str = None, +) -> dict: + return _fit_from_package_impl( + branch, + epochs, + volume_package_path=volume_package_path, + target_config=target_config, + beta=beta, + lambda_l0=lambda_l0, + lambda_l2=lambda_l2, + learning_rate=learning_rate, + log_freq=log_freq, + ) + + +@app.function( + image=image, + memory=32768, + cpu=4.0, + gpu="A10", + timeout=14400, + volumes={"/calibration-data": calibration_vol}, +) +def fit_from_package_a10( + branch: str = "main", + epochs: int = 200, + target_config: str = None, + beta: float = None, + lambda_l0: float = None, + lambda_l2: float = None, + learning_rate: float = None, + log_freq: int = None, + volume_package_path: str = None, +) -> dict: + return _fit_from_package_impl( + branch, + epochs, + volume_package_path=volume_package_path, + target_config=target_config, + beta=beta, + lambda_l0=lambda_l0, + lambda_l2=lambda_l2, + learning_rate=learning_rate, + log_freq=log_freq, + ) + + +@app.function( + image=image, + memory=32768, + cpu=4.0, + gpu="A100-40GB", + timeout=14400, + 
volumes={"/calibration-data": calibration_vol}, +) +def fit_from_package_a100_40( + branch: str = "main", + epochs: int = 200, + target_config: str = None, + beta: float = None, + lambda_l0: float = None, + lambda_l2: float = None, + learning_rate: float = None, + log_freq: int = None, + volume_package_path: str = None, +) -> dict: + return _fit_from_package_impl( + branch, + epochs, + volume_package_path=volume_package_path, + target_config=target_config, + beta=beta, + lambda_l0=lambda_l0, + lambda_l2=lambda_l2, + learning_rate=learning_rate, + log_freq=log_freq, + ) + + +@app.function( + image=image, + memory=32768, + cpu=4.0, + gpu="A100-80GB", + timeout=14400, + volumes={"/calibration-data": calibration_vol}, +) +def fit_from_package_a100_80( + branch: str = "main", + epochs: int = 200, + target_config: str = None, + beta: float = None, + lambda_l0: float = None, + lambda_l2: float = None, + learning_rate: float = None, + log_freq: int = None, + volume_package_path: str = None, +) -> dict: + return _fit_from_package_impl( + branch, + epochs, + volume_package_path=volume_package_path, + target_config=target_config, + beta=beta, + lambda_l0=lambda_l0, + lambda_l2=lambda_l2, + learning_rate=learning_rate, + log_freq=log_freq, + ) + + +@app.function( + image=image, + memory=32768, + cpu=4.0, + gpu="H100", + timeout=14400, + volumes={"/calibration-data": calibration_vol}, +) +def fit_from_package_h100( + branch: str = "main", + epochs: int = 200, + target_config: str = None, + beta: float = None, + lambda_l0: float = None, + lambda_l2: float = None, + learning_rate: float = None, + log_freq: int = None, + volume_package_path: str = None, +) -> dict: + return _fit_from_package_impl( + branch, + epochs, + volume_package_path=volume_package_path, + target_config=target_config, + beta=beta, + lambda_l0=lambda_l0, + lambda_l2=lambda_l2, + learning_rate=learning_rate, + log_freq=log_freq, + ) + + +PACKAGE_GPU_FUNCTIONS = { + "T4": fit_from_package_t4, + "A10": 
fit_from_package_a10, + "A100-40GB": fit_from_package_a100_40, + "A100-80GB": fit_from_package_a100_80, + "H100": fit_from_package_h100, +} + + @app.local_entrypoint() def main( branch: str = "main", epochs: int = 200, gpu: str = "T4", output: str = "calibration_weights.npy", - log_output: str = "calibration_log.csv", + log_output: str = "unified_diagnostics.csv", + target_config: str = None, + beta: float = None, + lambda_l0: float = None, + lambda_l2: float = None, + learning_rate: float = None, + log_freq: int = None, + package_path: str = None, + full_pipeline: bool = False, + county_level: bool = False, + workers: int = 1, + push_results: bool = False, + trigger_publish: bool = False, + national: bool = False, ): + prefix = "national_" if national else "" + if national: + if lambda_l0 is None: + lambda_l0 = 1e-4 + output = f"{prefix}{output}" + log_output = f"{prefix}{log_output}" + if gpu not in GPU_FUNCTIONS: raise ValueError( - f"Unknown GPU: {gpu}. Choose from: {list(GPU_FUNCTIONS.keys())}" + f"Unknown GPU: {gpu}. 
" + f"Choose from: {list(GPU_FUNCTIONS.keys())}" ) - print(f"Running with GPU: {gpu}, epochs: {epochs}, branch: {branch}") - func = GPU_FUNCTIONS[gpu] - result = func.remote(branch=branch, epochs=epochs) + if package_path: + vol_path = f"{VOLUME_MOUNT}/calibration_package.pkl" + print(f"Reading package from {package_path}...", flush=True) + import json as _json + import pickle as _pkl + + with open(package_path, "rb") as f: + package_bytes = f.read() + size = len(package_bytes) + # Extract metadata for sidecar + pkg_meta = _pkl.loads(package_bytes).get("metadata", {}) + sidecar_bytes = _json.dumps(pkg_meta, indent=2).encode() + print( + f"Uploading package ({size:,} bytes) to Modal volume...", + flush=True, + ) + with calibration_vol.batch_upload(force=True) as batch: + from io import BytesIO + + batch.put( + BytesIO(package_bytes), + "calibration_package.pkl", + ) + batch.put( + BytesIO(sidecar_bytes), + "calibration_package_meta.json", + ) + calibration_vol.commit() + del package_bytes + print("Upload complete.", flush=True) + _print_provenance_from_meta(pkg_meta, branch) + func = PACKAGE_GPU_FUNCTIONS[gpu] + result = func.remote( + branch=branch, + epochs=epochs, + target_config=target_config, + beta=beta, + lambda_l0=lambda_l0, + lambda_l2=lambda_l2, + learning_rate=learning_rate, + log_freq=log_freq, + volume_package_path=vol_path, + ) + elif full_pipeline: + print( + "========================================", + flush=True, + ) + print( + "Mode: full pipeline (download, build matrix, fit)", + flush=True, + ) + print( + f"GPU: {gpu} | Epochs: {epochs} | " f"Branch: {branch}", + flush=True, + ) + print( + "========================================", + flush=True, + ) + func = GPU_FUNCTIONS[gpu] + result = func.remote( + branch=branch, + epochs=epochs, + target_config=target_config, + beta=beta, + lambda_l0=lambda_l0, + lambda_l2=lambda_l2, + learning_rate=learning_rate, + log_freq=log_freq, + skip_county=not county_level, + workers=workers, + ) + else: + 
vol_path = f"{VOLUME_MOUNT}/calibration_package.pkl" + vol_info = check_volume_package.remote() + if not vol_info["exists"]: + raise SystemExit( + "\nNo calibration package found on Modal volume.\n" + "Run 'make build-matrices' first, or use " + "--full-pipeline to build from scratch.\n" + ) + if vol_info.get("created_at") or vol_info.get("git_branch"): + _print_provenance_from_meta(vol_info, branch) + mode_label = ( + "national calibration" + if national + else "fitting from pre-built package" + ) + print( + "========================================", + flush=True, + ) + print(f"Mode: {mode_label}", flush=True) + print( + f"GPU: {gpu} | Epochs: {epochs} | " f"Branch: {branch}", + flush=True, + ) + if push_results: + print( + "After fitting, will upload to HuggingFace:", + flush=True, + ) + print( + f" - calibration/{prefix}calibration_weights.npy", + flush=True, + ) + print( + f" - calibration/{prefix}stacked_blocks.npy", + flush=True, + ) + print( + f" - calibration/logs/{prefix}* (diagnostics, " + "config, calibration log)", + flush=True, + ) + print( + "========================================", + flush=True, + ) + func = PACKAGE_GPU_FUNCTIONS[gpu] + result = func.remote( + branch=branch, + epochs=epochs, + target_config=target_config, + beta=beta, + lambda_l0=lambda_l0, + lambda_l2=lambda_l2, + learning_rate=learning_rate, + log_freq=log_freq, + volume_package_path=vol_path, + ) with open(output, "wb") as f: f.write(result["weights"]) @@ -191,4 +1077,96 @@ def main( if result["log"]: with open(log_output, "wb") as f: f.write(result["log"]) - print(f"Calibration log saved to: {log_output}") + print(f"Diagnostics log saved to: {log_output}") + + cal_log_output = f"{prefix}calibration_log.csv" + if result.get("cal_log"): + with open(cal_log_output, "wb") as f: + f.write(result["cal_log"]) + print(f"Calibration log saved to: {cal_log_output}") + + config_output = f"{prefix}unified_run_config.json" + if result.get("config"): + with open(config_output, "wb") as f: + 
f.write(result["config"]) + print(f"Run config saved to: {config_output}") + + blocks_output = f"{prefix}stacked_blocks.npy" + if result.get("blocks"): + with open(blocks_output, "wb") as f: + f.write(result["blocks"]) + print(f"Stacked blocks saved to: {blocks_output}") + + geo_labels_output = f"{prefix}geo_labels.json" + if result.get("geo_labels"): + with open(geo_labels_output, "wb") as f: + f.write(result["geo_labels"]) + print(f"Geo labels saved to: {geo_labels_output}") + + if push_results: + from policyengine_us_data.utils.huggingface import ( + upload_calibration_artifacts, + ) + + upload_calibration_artifacts( + weights_path=output, + blocks_path=(blocks_output if result.get("blocks") else None), + geo_labels_path=( + geo_labels_output if result.get("geo_labels") else None + ), + log_dir=".", + prefix=prefix, + ) + + if trigger_publish: + _trigger_repository_dispatch() + + +@app.local_entrypoint() +def build_package( + branch: str = "main", + target_config: str = None, + county_level: bool = False, + workers: int = 1, +): + """Build the calibration package (X matrix) on CPU and save + to Modal volume. Then run main() to fit.""" + print( + "========================================", + flush=True, + ) + print( + f"Mode: building calibration package (CPU only)", + flush=True, + ) + print(f"Branch: {branch}", flush=True) + print( + "This builds the X matrix and saves it to " "a Modal volume.", + flush=True, + ) + print( + "No GPU is used. 
Timeout: 10 hours.", + flush=True, + ) + print( + "========================================", + flush=True, + ) + vol_path = build_package_remote.remote( + branch=branch, + target_config=target_config, + skip_county=not county_level, + workers=workers, + ) + print( + f"Package built and saved to Modal volume at {vol_path}", + flush=True, + ) + print( + "\nTo fit weights, run:\n" + " modal run modal_app/remote_calibration_runner.py" + "::main \\\n" + f" --branch {branch} --gpu " + "--epochs --push-results", + flush=True, + ) diff --git a/modal_app/worker_script.py b/modal_app/worker_script.py index b197260e8..e19f0c82b 100644 --- a/modal_app/worker_script.py +++ b/modal_app/worker_script.py @@ -20,6 +20,24 @@ def main(): parser.add_argument("--dataset-path", required=True) parser.add_argument("--db-path", required=True) parser.add_argument("--output-dir", required=True) + parser.add_argument( + "--calibration-blocks", + type=str, + default=None, + help="Path to stacked_blocks.npy from calibration", + ) + parser.add_argument( + "--geo-labels", + type=str, + default=None, + help="Path to geo_labels.json (overrides DB lookup)", + ) + parser.add_argument( + "--stacked-takeup", + type=str, + default=None, + help="Path to stacked_takeup.npz from calibration", + ) args = parser.parse_args() work_items = json.loads(args.work_items) @@ -28,18 +46,43 @@ def main(): db_path = Path(args.db_path) output_dir = Path(args.output_dir) - from policyengine_us_data.datasets.cps.local_area_calibration.publish_local_area import ( - build_state_h5, - build_district_h5, - build_city_h5, + calibration_blocks = None + if args.calibration_blocks: + calibration_blocks = np.load(args.calibration_blocks) + + stacked_takeup = None + if args.stacked_takeup: + stacked_takeup = dict(np.load(args.stacked_takeup)) + + from policyengine_us_data.utils.takeup import ( + TAKEUP_AFFECTED_TARGETS, + ) + + takeup_filter = [ + info["takeup_var"] for info in TAKEUP_AFFECTED_TARGETS.values() + ] + + 
original_stdout = sys.stdout + sys.stdout = sys.stderr + + from policyengine_us_data.calibration.publish_local_area import ( + build_h5, + NYC_COUNTIES, + NYC_CDS, + AT_LARGE_DISTRICTS, ) - from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( + from policyengine_us_data.calibration.calibration_utils import ( get_all_cds_from_database, + load_geo_labels, STATE_CODES, ) - db_uri = f"sqlite:///{db_path}" - cds_to_calibrate = get_all_cds_from_database(db_uri) + if args.geo_labels and Path(args.geo_labels).exists(): + geo_labels = load_geo_labels(args.geo_labels) + else: + db_uri = f"sqlite:///{db_path}" + geo_labels = get_all_cds_from_database(db_uri) + cds_to_calibrate = geo_labels weights = np.load(weights_path) results = { @@ -54,44 +97,124 @@ def main(): try: if item_type == "state": - path = build_state_h5( - state_code=item_id, + state_fips = None + for fips, code in STATE_CODES.items(): + if code == item_id: + state_fips = fips + break + if state_fips is None: + raise ValueError(f"Unknown state code: {item_id}") + cd_subset = [ + cd + for cd in cds_to_calibrate + if int(cd) // 100 == state_fips + ] + if not cd_subset: + print( + f"No CDs for {item_id}, skipping", + file=sys.stderr, + ) + continue + states_dir = output_dir / "states" + states_dir.mkdir(parents=True, exist_ok=True) + path = build_h5( weights=weights, - cds_to_calibrate=cds_to_calibrate, + blocks=calibration_blocks, dataset_path=dataset_path, - output_dir=output_dir, + output_path=states_dir / f"{item_id}.h5", + cds_to_calibrate=cds_to_calibrate, + cd_subset=cd_subset, + takeup_filter=takeup_filter, + stacked_takeup=stacked_takeup, ) + elif item_type == "district": state_code, dist_num = item_id.split("-") - geoid = None + state_fips = None for fips, code in STATE_CODES.items(): if code == state_code: - geoid = f"{fips}{int(dist_num):02d}" + state_fips = fips break - if geoid is None: + if state_fips is None: raise ValueError(f"Unknown state in district: 
{item_id}") - path = build_district_h5( - cd_geoid=geoid, + candidate = f"{state_fips}{int(dist_num):02d}" + if candidate in geo_labels: + geoid = candidate + else: + state_cds = [ + cd for cd in geo_labels if int(cd) // 100 == state_fips + ] + if len(state_cds) == 1: + geoid = state_cds[0] + else: + raise ValueError( + f"CD {candidate} not found and " + f"state {state_code} has " + f"{len(state_cds)} CDs" + ) + + cd_int = int(geoid) + district_num = cd_int % 100 + if district_num in AT_LARGE_DISTRICTS: + district_num = 1 + friendly_name = f"{state_code}-{district_num:02d}" + + districts_dir = output_dir / "districts" + districts_dir.mkdir(parents=True, exist_ok=True) + path = build_h5( weights=weights, - cds_to_calibrate=cds_to_calibrate, + blocks=calibration_blocks, dataset_path=dataset_path, - output_dir=output_dir, + output_path=districts_dir / f"{friendly_name}.h5", + cds_to_calibrate=cds_to_calibrate, + cd_subset=[geoid], + takeup_filter=takeup_filter, + stacked_takeup=stacked_takeup, ) + elif item_type == "city": - path = build_city_h5( - city_name=item_id, + cd_subset = [cd for cd in cds_to_calibrate if cd in NYC_CDS] + if not cd_subset: + print( + "No NYC CDs found, skipping", + file=sys.stderr, + ) + continue + cities_dir = output_dir / "cities" + cities_dir.mkdir(parents=True, exist_ok=True) + path = build_h5( weights=weights, + blocks=calibration_blocks, + dataset_path=dataset_path, + output_path=cities_dir / "NYC.h5", cds_to_calibrate=cds_to_calibrate, + cd_subset=cd_subset, + county_filter=NYC_COUNTIES, + takeup_filter=takeup_filter, + stacked_takeup=stacked_takeup, + ) + + elif item_type == "national": + national_dir = output_dir / "national" + national_dir.mkdir(parents=True, exist_ok=True) + path = build_h5( + weights=weights, + blocks=calibration_blocks, dataset_path=dataset_path, - output_dir=output_dir, + output_path=national_dir / "US.h5", + cds_to_calibrate=cds_to_calibrate, + stacked_takeup=stacked_takeup, ) else: raise ValueError(f"Unknown 
item type: {item_type}") if path: results["completed"].append(f"{item_type}:{item_id}") - print(f"Completed {item_type}:{item_id}", file=sys.stderr) + print( + f"Completed {item_type}:{item_id}", + file=sys.stderr, + ) except Exception as e: results["failed"].append(f"{item_type}:{item_id}") @@ -102,8 +225,12 @@ def main(): "traceback": traceback.format_exc(), } ) - print(f"FAILED {item_type}:{item_id}: {e}", file=sys.stderr) + print( + f"FAILED {item_type}:{item_id}: {e}", + file=sys.stderr, + ) + sys.stdout = original_stdout print(json.dumps(results)) diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/block_assignment.py b/policyengine_us_data/calibration/block_assignment.py similarity index 86% rename from policyengine_us_data/datasets/cps/local_area_calibration/block_assignment.py rename to policyengine_us_data/calibration/block_assignment.py index 73b435f69..ddeafa378 100644 --- a/policyengine_us_data/datasets/cps/local_area_calibration/block_assignment.py +++ b/policyengine_us_data/calibration/block_assignment.py @@ -100,22 +100,33 @@ def _build_county_fips_to_enum() -> Dict[str, str]: return fips_to_enum -def get_county_enum_index_from_block(block_geoid: str) -> int: - """ - Get County enum index from block GEOID. +def get_county_enum_index_from_fips(county_fips: str) -> int: + """Get County enum index from 5-digit county FIPS. Args: - block_geoid: 15-digit census block GEOID + county_fips: 5-digit county FIPS code (e.g. "37183") Returns: Integer index into County enum, or UNKNOWN index if not found """ - county_fips = get_county_fips_from_block(block_geoid) fips_to_enum = _build_county_fips_to_enum() enum_name = fips_to_enum.get(county_fips, "UNKNOWN") return County._member_names_.index(enum_name) +def get_county_enum_index_from_block(block_geoid: str) -> int: + """Get County enum index from block GEOID. 
+ + Args: + block_geoid: 15-digit census block GEOID + + Returns: + Integer index into County enum, or UNKNOWN index if not found + """ + county_fips = get_county_fips_from_block(block_geoid) + return get_county_enum_index_from_fips(county_fips) + + # === CBSA Lookup === @@ -338,7 +349,7 @@ def _generate_fallback_blocks(cd_geoid: str, n_households: int) -> np.ndarray: Array of 15-character block GEOID strings """ # Import here to avoid circular dependency - from policyengine_us_data.datasets.cps.local_area_calibration.county_assignment import ( + from policyengine_us_data.calibration.county_assignment import ( assign_counties_for_cd, ) @@ -508,6 +519,82 @@ def assign_geography_for_cd( } +def derive_geography_from_blocks( + block_geoids: np.ndarray, +) -> Dict[str, np.ndarray]: + """Derive all geography from pre-assigned block GEOIDs. + + Given an array of block GEOIDs (already assigned by + calibration), derives county, tract, state, CBSA, SLDU, + SLDL, place, VTD, PUMA, ZCTA, and county enum index. + + Args: + block_geoids: Array of 15-char block GEOID strings. + + Returns: + Dict with same keys as assign_geography_for_cd. 
+ """ + county_fips = np.array( + [get_county_fips_from_block(b) for b in block_geoids] + ) + tract_geoids = np.array( + [get_tract_geoid_from_block(b) for b in block_geoids] + ) + state_fips = np.array([get_state_fips_from_block(b) for b in block_geoids]) + cbsa_codes = np.array([get_cbsa_from_county(c) or "" for c in county_fips]) + county_indices = np.array( + [get_county_enum_index_from_block(b) for b in block_geoids], + dtype=np.int32, + ) + + crosswalk = _load_block_crosswalk() + has_zcta = "zcta" in crosswalk.columns + + sldu_list = [] + sldl_list = [] + place_fips_list = [] + vtd_list = [] + puma_list = [] + zcta_list = [] + + for b in block_geoids: + if not crosswalk.empty and b in crosswalk.index: + row = crosswalk.loc[b] + sldu_list.append(row["sldu"] if pd.notna(row["sldu"]) else "") + sldl_list.append(row["sldl"] if pd.notna(row["sldl"]) else "") + place_fips_list.append( + row["place_fips"] if pd.notna(row["place_fips"]) else "" + ) + vtd_list.append(row["vtd"] if pd.notna(row["vtd"]) else "") + puma_list.append(row["puma"] if pd.notna(row["puma"]) else "") + if has_zcta: + zcta_list.append(row["zcta"] if pd.notna(row["zcta"]) else "") + else: + zcta_list.append("") + else: + sldu_list.append("") + sldl_list.append("") + place_fips_list.append("") + vtd_list.append("") + puma_list.append("") + zcta_list.append("") + + return { + "block_geoid": block_geoids, + "county_fips": county_fips, + "tract_geoid": tract_geoids, + "state_fips": state_fips, + "cbsa_code": cbsa_codes, + "sldu": np.array(sldu_list), + "sldl": np.array(sldl_list), + "place_fips": np.array(place_fips_list), + "vtd": np.array(vtd_list), + "puma": np.array(puma_list), + "zcta": np.array(zcta_list), + "county_index": county_indices, + } + + # === County Filter Functions (for city-level datasets) === diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/calibration_utils.py b/policyengine_us_data/calibration/calibration_utils.py similarity index 86% rename from 
policyengine_us_data/datasets/cps/local_area_calibration/calibration_utils.py rename to policyengine_us_data/calibration/calibration_utils.py index 97c82360d..6920955b9 100644 --- a/policyengine_us_data/datasets/cps/local_area_calibration/calibration_utils.py +++ b/policyengine_us_data/calibration/calibration_utils.py @@ -3,6 +3,7 @@ """ from typing import Dict, List, Tuple +import json import numpy as np import pandas as pd @@ -521,6 +522,22 @@ def get_cd_index_mapping(db_uri: str = None): return cd_to_index, index_to_cd, cds_ordered +def save_geo_labels(labels: List[str], path) -> None: + """Save geo unit labels to JSON.""" + from pathlib import Path + + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w") as f: + json.dump(labels, f) + + +def load_geo_labels(path) -> List[str]: + """Load geo unit labels from JSON.""" + with open(path) as f: + return json.load(f) + + def load_cd_geoadj_values( cds_to_calibrate: List[str], ) -> Dict[str, float]: @@ -548,93 +565,74 @@ def load_cd_geoadj_values( ) rent_lookup[row["cd_geoid"]] = geoadj - # Map each CD to calibrate to its geoadj value - # Handle at-large districts: database uses XX01, rent CSV uses XX00 geoadj_dict = {} for cd in cds_to_calibrate: if cd in rent_lookup: geoadj_dict[cd] = rent_lookup[cd] else: - # Try at-large mapping: XX01 -> XX00 - cd_int = int(cd) - state_fips = cd_int // 100 - district = cd_int % 100 - if district == 1: - at_large_cd = str(state_fips * 100) # XX00 - if at_large_cd in rent_lookup: - geoadj_dict[cd] = rent_lookup[at_large_cd] - continue - # Fallback to national average (geoadj = 1.0) print(f"Warning: No rent data for CD {cd}, using geoadj=1.0") geoadj_dict[cd] = 1.0 return geoadj_dict -def calculate_spm_thresholds_for_cd( - sim, - time_period: int, - geoadj: float, +def calculate_spm_thresholds_vectorized( + person_ages: np.ndarray, + person_spm_unit_ids: np.ndarray, + spm_unit_tenure_types: np.ndarray, + spm_unit_geoadj: np.ndarray, year: int, ) 
-> np.ndarray: + """Calculate SPM thresholds for cloned SPM units from raw arrays. + + Works without a Microsimulation instance. Counts adults/children + per SPM unit from person-level arrays, then computes + base_threshold * equivalence_scale * geoadj for each unit. + + Args: + person_ages: Age per cloned person. + person_spm_unit_ids: New SPM unit ID per cloned person + (0-based contiguous). + spm_unit_tenure_types: Tenure type string per cloned SPM + unit (e.g. b"RENTER", b"OWNER_WITH_MORTGAGE"). + spm_unit_geoadj: Geographic adjustment factor per cloned + SPM unit. + year: Tax year for base threshold lookup. + + Returns: + Float32 array of SPM thresholds, one per SPM unit. """ - Calculate SPM thresholds for all SPM units using CD-specific geo-adjustment. - """ - spm_unit_ids_person = sim.calculate("spm_unit_id", map_to="person").values - ages = sim.calculate("age", map_to="person").values - - df = pd.DataFrame( - { - "spm_unit_id": spm_unit_ids_person, - "is_adult": ages >= 18, - "is_child": ages < 18, - } - ) - - agg = ( - df.groupby("spm_unit_id") - .agg( - num_adults=("is_adult", "sum"), - num_children=("is_child", "sum"), + n_units = len(spm_unit_tenure_types) + + # Count adults and children per SPM unit + is_adult = person_ages >= 18 + num_adults = np.zeros(n_units, dtype=np.int32) + num_children = np.zeros(n_units, dtype=np.int32) + np.add.at(num_adults, person_spm_unit_ids, is_adult.astype(np.int32)) + np.add.at(num_children, person_spm_unit_ids, (~is_adult).astype(np.int32)) + + # Map tenure type strings to codes + tenure_codes = np.full(n_units, 3, dtype=np.int32) + for tenure_str, code in SPM_TENURE_STRING_TO_CODE.items(): + tenure_bytes = ( + tenure_str.encode() if isinstance(tenure_str, str) else tenure_str ) - .reset_index() - ) - - tenure_types = sim.calculate( - "spm_unit_tenure_type", map_to="spm_unit" - ).values - spm_unit_ids_unit = sim.calculate("spm_unit_id", map_to="spm_unit").values - - tenure_df = pd.DataFrame( - { - "spm_unit_id": 
spm_unit_ids_unit, - "tenure_type": tenure_types, - } - ) - - merged = agg.merge(tenure_df, on="spm_unit_id", how="left") - merged["tenure_code"] = ( - merged["tenure_type"] - .map(SPM_TENURE_STRING_TO_CODE) - .fillna(3) - .astype(int) - ) + mask = spm_unit_tenure_types == tenure_bytes + if not mask.any(): + mask = spm_unit_tenure_types == tenure_str + tenure_codes[mask] = code + # Look up base thresholds calc = SPMCalculator(year=year) base_thresholds = calc.get_base_thresholds() - n = len(merged) - thresholds = np.zeros(n, dtype=np.float32) - - for i in range(n): - tenure_str = TENURE_CODE_MAP.get( - int(merged.iloc[i]["tenure_code"]), "renter" - ) + thresholds = np.zeros(n_units, dtype=np.float32) + for i in range(n_units): + tenure_str = TENURE_CODE_MAP.get(int(tenure_codes[i]), "renter") base = base_thresholds[tenure_str] equiv_scale = spm_equivalence_scale( - int(merged.iloc[i]["num_adults"]), - int(merged.iloc[i]["num_children"]), + int(num_adults[i]), int(num_children[i]) ) - thresholds[i] = base * equiv_scale * geoadj + thresholds[i] = base * equiv_scale * spm_unit_geoadj[i] return thresholds diff --git a/policyengine_us_data/calibration/check_staging_sums.py b/policyengine_us_data/calibration/check_staging_sums.py new file mode 100644 index 000000000..9d2c4f879 --- /dev/null +++ b/policyengine_us_data/calibration/check_staging_sums.py @@ -0,0 +1,118 @@ +"""Sum key variables across all staging state H5 files. + +Quick smoke test: loads all 51 state H5s, sums key variables, +compares to national references. No database needed. ~10 min runtime. 
+ +Usage: + python -m policyengine_us_data.calibration.check_staging_sums + python -m policyengine_us_data.calibration.check_staging_sums \ + --hf-prefix hf://policyengine/policyengine-us-data/staging +""" + +import argparse + +import pandas as pd + +from policyengine_us_data.calibration.calibration_utils import ( + STATE_CODES, +) + +STATE_ABBRS = sorted(STATE_CODES.values()) + +VARIABLES = [ + "adjusted_gross_income", + "employment_income", + "self_employment_income", + "tax_unit_partnership_s_corp_income", + "taxable_pension_income", + "dividend_income", + "net_capital_gains", + "rental_income", + "taxable_interest_income", + "social_security", + "snap", + "ssi", + "income_tax_before_credits", + "eitc", + "refundable_ctc", + "real_estate_taxes", + "rent", + "is_pregnant", + "person_count", + "household_count", +] + +DEFAULT_HF_PREFIX = "hf://policyengine/policyengine-us-data/staging/states" + + +def main(argv=None): + parser = argparse.ArgumentParser( + description="Sum key variables across staging state H5 files" + ) + parser.add_argument( + "--hf-prefix", + default=DEFAULT_HF_PREFIX, + help="HF path prefix for state H5 files " + f"(default: {DEFAULT_HF_PREFIX})", + ) + args = parser.parse_args(argv) + + from policyengine_us import Microsimulation + + results = {} + errors = [] + + for i, st in enumerate(STATE_ABBRS): + print( + f"[{i + 1}/{len(STATE_ABBRS)}] {st}...", + end=" ", + flush=True, + ) + try: + sim = Microsimulation(dataset=f"{args.hf_prefix}/{st}.h5") + row = {} + for var in VARIABLES: + try: + row[var] = float(sim.calculate(var).sum()) + except Exception: + row[var] = None + results[st] = row + print("OK") + except Exception as e: + errors.append((st, str(e))) + print(f"FAILED: {e}") + + df = pd.DataFrame(results).T + df.index.name = "state" + + print("\n" + "=" * 70) + print("NATIONAL TOTALS (sum across all states)") + print("=" * 70) + + totals = df.sum() + for var in VARIABLES: + val = totals[var] + if var in ("person_count", "household_count", 
"is_pregnant"): + print(f" {var:45s} {val:>15,.0f}") + else: + print(f" {var:45s} ${val:>15,.0f}") + + print("\n" + "=" * 70) + print("REFERENCE VALUES (approximate, for sanity checking)") + print("=" * 70) + print(" US GDP ~$29T, US population ~335M, ~130M households") + print(" Total AGI ~$15T, Employment income ~$10T") + print(" SNAP ~$110B, SSI ~$60B, Social Security ~$1.2T") + print(" EITC ~$60B, CTC ~$120B") + + if errors: + print(f"\n{len(errors)} states failed:") + for st, err in errors: + print(f" {st}: {err}") + + print("\nPer-state details saved to staging_sums.csv") + df.to_csv("staging_sums.csv") + + +if __name__ == "__main__": + main() diff --git a/policyengine_us_data/calibration/clone_and_assign.py b/policyengine_us_data/calibration/clone_and_assign.py index 9aa64cbbc..e25af7d02 100644 --- a/policyengine_us_data/calibration/clone_and_assign.py +++ b/policyengine_us_data/calibration/clone_and_assign.py @@ -23,6 +23,7 @@ class GeographyAssignment: block_geoid: np.ndarray # str array, 15-char block GEOIDs cd_geoid: np.ndarray # str array of CD GEOIDs + county_fips: np.ndarray # str array of 5-char county FIPS state_fips: np.ndarray # int array of 2-digit state FIPS n_records: int n_clones: int @@ -52,7 +53,7 @@ def load_global_block_distribution(): df = pd.read_csv(csv_path, dtype={"block_geoid": str}) block_geoids = df["block_geoid"].values - cd_geoids = df["cd_geoid"].astype(str).values + cd_geoids = np.array(df["cd_geoid"].astype(str).tolist()) state_fips = np.array([int(b[:2]) for b in block_geoids]) probs = df["probability"].values.astype(np.float64) @@ -88,11 +89,44 @@ def assign_random_geography( n_total = n_records * n_clones rng = np.random.default_rng(seed) - indices = rng.choice(len(blocks), size=n_total, p=probs) + indices = np.empty(n_total, dtype=np.int64) + + # Clone 0: unrestricted draw + indices[:n_records] = rng.choice(len(blocks), size=n_records, p=probs) + + assigned_cds = np.empty((n_clones, n_records), dtype=object) + 
assigned_cds[0] = cds[indices[:n_records]] + + for clone_idx in range(1, n_clones): + start = clone_idx * n_records + clone_indices = rng.choice(len(blocks), size=n_records, p=probs) + clone_cds = cds[clone_indices] + + collisions = np.zeros(n_records, dtype=bool) + for prev in range(clone_idx): + collisions |= clone_cds == assigned_cds[prev] + + for _ in range(50): + n_bad = collisions.sum() + if n_bad == 0: + break + clone_indices[collisions] = rng.choice( + len(blocks), size=n_bad, p=probs + ) + clone_cds = cds[clone_indices] + collisions = np.zeros(n_records, dtype=bool) + for prev in range(clone_idx): + collisions |= clone_cds == assigned_cds[prev] + + indices[start : start + n_records] = clone_indices + assigned_cds[clone_idx] = clone_cds + + assigned_blocks = blocks[indices] return GeographyAssignment( - block_geoid=blocks[indices], + block_geoid=assigned_blocks, cd_geoid=cds[indices], + county_fips=np.array([b[:5] for b in assigned_blocks]), state_fips=states[indices], n_records=n_records, n_clones=n_clones, @@ -124,6 +158,7 @@ def double_geography_for_puf( new_blocks = [] new_cds = [] + new_counties = [] new_states = [] for c in range(n_clones): @@ -131,14 +166,17 @@ def double_geography_for_puf( end = start + n_old clone_blocks = geography.block_geoid[start:end] clone_cds = geography.cd_geoid[start:end] + clone_counties = geography.county_fips[start:end] clone_states = geography.state_fips[start:end] new_blocks.append(np.concatenate([clone_blocks, clone_blocks])) new_cds.append(np.concatenate([clone_cds, clone_cds])) + new_counties.append(np.concatenate([clone_counties, clone_counties])) new_states.append(np.concatenate([clone_states, clone_states])) return GeographyAssignment( block_geoid=np.concatenate(new_blocks), cd_geoid=np.concatenate(new_cds), + county_fips=np.concatenate(new_counties), state_fips=np.concatenate(new_states), n_records=n_new, n_clones=n_clones, diff --git 
a/policyengine_us_data/datasets/cps/local_area_calibration/county_assignment.py b/policyengine_us_data/calibration/county_assignment.py similarity index 98% rename from policyengine_us_data/datasets/cps/local_area_calibration/county_assignment.py rename to policyengine_us_data/calibration/county_assignment.py index 780bc4c77..6d32d30bd 100644 --- a/policyengine_us_data/datasets/cps/local_area_calibration/county_assignment.py +++ b/policyengine_us_data/calibration/county_assignment.py @@ -38,7 +38,7 @@ def _build_state_counties() -> Dict[str, List[str]]: def _generate_uniform_distribution(cd_geoid: str) -> Dict[str, float]: """Generate uniform distribution across counties in CD's state.""" - from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( + from policyengine_us_data.calibration.calibration_utils import ( STATE_CODES, ) diff --git a/policyengine_us_data/calibration/create_source_imputed_cps.py b/policyengine_us_data/calibration/create_source_imputed_cps.py new file mode 100644 index 000000000..4381f72dd --- /dev/null +++ b/policyengine_us_data/calibration/create_source_imputed_cps.py @@ -0,0 +1,91 @@ +"""Create source-imputed stratified extended CPS. + +Standalone step that runs ACS/SIPP/SCF source imputations on the +stratified extended CPS, producing the dataset used by calibration +and H5 generation. 
+ +Usage: + python policyengine_us_data/calibration/create_source_imputed_cps.py +""" + +import logging +import sys +from pathlib import Path + +import h5py + +from policyengine_us_data.storage import STORAGE_FOLDER + +logger = logging.getLogger(__name__) + +INPUT_PATH = str(STORAGE_FOLDER / "stratified_extended_cps_2024.h5") +OUTPUT_PATH = str( + STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5" +) + + +def create_source_imputed_cps( + input_path: str = INPUT_PATH, + output_path: str = OUTPUT_PATH, + seed: int = 42, +): + from policyengine_us import Microsimulation + from policyengine_us_data.calibration.clone_and_assign import ( + assign_random_geography, + ) + from policyengine_us_data.calibration.source_impute import ( + impute_source_variables, + ) + + logger.info("Loading dataset from %s", input_path) + sim = Microsimulation(dataset=input_path) + n_records = len(sim.calculate("household_id", map_to="household").values) + + raw_keys = sim.dataset.load_dataset()["household_id"] + if isinstance(raw_keys, dict): + time_period = int(next(iter(raw_keys))) + else: + time_period = 2024 + + logger.info("Loaded %d households, time_period=%d", n_records, time_period) + + geography = assign_random_geography( + n_records=n_records, n_clones=1, seed=seed + ) + base_states = geography.state_fips[:n_records] + + raw_data = sim.dataset.load_dataset() + data_dict = {} + for var in raw_data: + val = raw_data[var] + if isinstance(val, dict): + data_dict[var] = { + int(k) if k.isdigit() else k: v for k, v in val.items() + } + else: + data_dict[var] = {time_period: val[...]} + + logger.info("Running source imputations...") + data_dict = impute_source_variables( + data=data_dict, + state_fips=base_states, + time_period=time_period, + dataset_path=input_path, + ) + + logger.info("Saving to %s", output_path) + with h5py.File(output_path, "w") as f: + for var, time_dict in data_dict.items(): + for tp, values in time_dict.items(): + f.create_dataset(f"{var}/{tp}", 
data=values) + + logger.info("Done.") + + +if __name__ == "__main__": + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(name)s %(levelname)s %(message)s", + stream=sys.stderr, + ) + create_source_imputed_cps() diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/create_stratified_cps.py b/policyengine_us_data/calibration/create_stratified_cps.py similarity index 100% rename from policyengine_us_data/datasets/cps/local_area_calibration/create_stratified_cps.py rename to policyengine_us_data/calibration/create_stratified_cps.py diff --git a/policyengine_us_data/calibration/publish_local_area.py b/policyengine_us_data/calibration/publish_local_area.py new file mode 100644 index 000000000..8d39bdf5a --- /dev/null +++ b/policyengine_us_data/calibration/publish_local_area.py @@ -0,0 +1,1071 @@ +""" +Build local area H5 files, optionally uploading to GCP and Hugging Face. + +Downloads calibration inputs from HF, generates state/district H5s +with checkpointing. Uploads only occur when --upload is explicitly passed. 
+ +Usage: + python publish_local_area.py [--skip-download] [--states-only] [--upload] +""" + +import numpy as np +from pathlib import Path +from typing import List + +from policyengine_us import Microsimulation +from policyengine_us_data.utils.huggingface import download_calibration_inputs +from policyengine_us_data.utils.data_upload import ( + upload_local_area_file, + upload_local_area_batch_to_hf, +) +from policyengine_us_data.calibration.calibration_utils import ( + get_all_cds_from_database, + STATE_CODES, + load_cd_geoadj_values, + calculate_spm_thresholds_vectorized, +) +from policyengine_us_data.calibration.block_assignment import ( + assign_geography_for_cd, + derive_geography_from_blocks, + get_county_filter_probability, +) +from policyengine_us_data.utils.takeup import ( + TAKEUP_AFFECTED_TARGETS, + apply_block_takeup_to_arrays, +) + +CHECKPOINT_FILE = Path("completed_states.txt") +CHECKPOINT_FILE_DISTRICTS = Path("completed_districts.txt") +CHECKPOINT_FILE_CITIES = Path("completed_cities.txt") +WORK_DIR = Path("local_area_build") + +NYC_COUNTIES = { + "QUEENS_COUNTY_NY", + "BRONX_COUNTY_NY", + "RICHMOND_COUNTY_NY", + "NEW_YORK_COUNTY_NY", + "KINGS_COUNTY_NY", +} + +NYC_CDS = [ + "3603", + "3605", + "3606", + "3607", + "3608", + "3609", + "3610", + "3611", + "3612", + "3613", + "3614", + "3615", + "3616", +] + + +def load_completed_states() -> set: + if CHECKPOINT_FILE.exists(): + content = CHECKPOINT_FILE.read_text().strip() + if content: + return set(content.split("\n")) + return set() + + +def record_completed_state(state_code: str): + with open(CHECKPOINT_FILE, "a") as f: + f.write(f"{state_code}\n") + + +def load_completed_districts() -> set: + if CHECKPOINT_FILE_DISTRICTS.exists(): + content = CHECKPOINT_FILE_DISTRICTS.read_text().strip() + if content: + return set(content.split("\n")) + return set() + + +def record_completed_district(district_name: str): + with open(CHECKPOINT_FILE_DISTRICTS, "a") as f: + f.write(f"{district_name}\n") + + +def 
load_completed_cities() -> set: + if CHECKPOINT_FILE_CITIES.exists(): + content = CHECKPOINT_FILE_CITIES.read_text().strip() + if content: + return set(content.split("\n")) + return set() + + +def record_completed_city(city_name: str): + with open(CHECKPOINT_FILE_CITIES, "a") as f: + f.write(f"{city_name}\n") + + +def build_h5( + weights: np.ndarray, + blocks: np.ndarray, + dataset_path: Path, + output_path: Path, + cds_to_calibrate: List[str], + cd_subset: List[str] = None, + county_filter: set = None, + takeup_filter: List[str] = None, + stacked_takeup: dict = None, +) -> Path: + """Build an H5 file by cloning records for each nonzero weight. + + Uses fancy indexing on a single loaded simulation instead of + looping over CDs. + + Each nonzero entry in the (n_geo, n_hh) weight matrix represents + a distinct household clone. This function clones entity arrays, + derives geography from blocks, reindexes entity IDs, recalculates + SPM thresholds, applies calibration takeup draws (when blocks are + provided), and writes the H5. + + Args: + weights: Stacked weight vector, shape (n_geo * n_hh,). + blocks: Block GEOID per weight entry, same shape. When + provided, calibration takeup draws are applied + automatically. + dataset_path: Path to base dataset H5 file. + output_path: Where to write the output H5 file. + cds_to_calibrate: Ordered list of CD GEOIDs defining + weight matrix row ordering. + cd_subset: If provided, only include rows for these CDs. + county_filter: If provided, scale weights by P(target|CD) + for city datasets. + takeup_filter: List of takeup vars to apply. + stacked_takeup: Pre-computed weight-averaged takeup per + (CD, entity). Dict of {var_name: (n_cds, n_ent)}. + When provided, overrides block-based takeup draws. + + Returns: + Path to the output H5 file. 
+ """ + import h5py + from collections import defaultdict + from policyengine_core.enums import Enum + from policyengine_us.variables.household.demographic.geographic.county.county_enum import ( + County, + ) + + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + apply_takeup = blocks is not None + + # === Load base simulation === + sim = Microsimulation(dataset=str(dataset_path)) + time_period = int(sim.default_calculation_period) + household_ids = sim.calculate("household_id", map_to="household").values + n_hh = len(household_ids) + + if weights.shape[0] % n_hh != 0: + raise ValueError( + f"Weight vector length {weights.shape[0]} is not " + f"divisible by n_hh={n_hh}" + ) + n_geo = weights.shape[0] // n_hh + + # Generate blocks from assign_geography_for_cd if not provided + if blocks is None: + print("No blocks provided, generating from CD assignments...") + all_blocks = np.empty(n_geo * n_hh, dtype="U15") + for geo_idx, cd in enumerate(cds_to_calibrate): + geo = assign_geography_for_cd( + cd_geoid=cd, + n_households=n_hh, + seed=42 + int(cd), + ) + start = geo_idx * n_hh + all_blocks[start : start + n_hh] = geo["block_geoid"] + blocks = all_blocks + + if len(blocks) != len(weights): + raise ValueError( + f"Blocks length {len(blocks)} != " f"weights length {len(weights)}" + ) + + # === Reshape and filter weight matrix === + W = weights.reshape(n_geo, n_hh).copy() + + # CD subset filtering: zero out rows for CDs not in subset + if cd_subset is not None: + cd_index_set = set() + for cd in cd_subset: + if cd not in cds_to_calibrate: + raise ValueError(f"CD {cd} not in calibrated CDs list") + cd_index_set.add(cds_to_calibrate.index(cd)) + for i in range(n_geo): + if i not in cd_index_set: + W[i, :] = 0 + + # County filtering: scale weights by P(target_counties | CD) + if county_filter is not None: + for geo_idx in range(n_geo): + cd = cds_to_calibrate[geo_idx] + p = get_county_filter_probability(cd, county_filter) + W[geo_idx, 
:] *= p + + n_active_cds = len(cd_subset) if cd_subset is not None else n_geo + label = ( + f"{n_active_cds} CDs" + if cd_subset is not None + else f"{n_geo} geo units" + ) + print(f"\n{'='*60}") + print(f"Building {output_path.name} ({label}, {n_hh} households)") + print(f"{'='*60}") + + # === Identify active clones === + active_geo, active_hh = np.where(W > 0) + n_clones = len(active_geo) + clone_weights = W[active_geo, active_hh] + active_blocks = blocks.reshape(n_geo, n_hh)[active_geo, active_hh] + + empty_count = np.sum(active_blocks == "") + if empty_count > 0: + raise ValueError( + f"{empty_count} active clones have empty block GEOIDs" + ) + + print(f"Active clones: {n_clones:,}") + print(f"Total weight: {clone_weights.sum():,.0f}") + + # === Build entity membership maps === + hh_id_to_idx = {int(hid): i for i, hid in enumerate(household_ids)} + person_hh_ids = sim.calculate("household_id", map_to="person").values + + hh_to_persons = defaultdict(list) + for p_idx, p_hh_id in enumerate(person_hh_ids): + hh_to_persons[hh_id_to_idx[int(p_hh_id)]].append(p_idx) + + SUB_ENTITIES = [ + "tax_unit", + "spm_unit", + "family", + "marital_unit", + ] + hh_to_entity = {} + entity_id_arrays = {} + person_entity_id_arrays = {} + + for ek in SUB_ENTITIES: + eids = sim.calculate(f"{ek}_id", map_to=ek).values + peids = sim.calculate(f"person_{ek}_id", map_to="person").values + entity_id_arrays[ek] = eids + person_entity_id_arrays[ek] = peids + eid_to_idx = {int(eid): i for i, eid in enumerate(eids)} + + mapping = defaultdict(list) + seen = defaultdict(set) + for p_idx in range(len(person_hh_ids)): + hh_idx = hh_id_to_idx[int(person_hh_ids[p_idx])] + e_idx = eid_to_idx[int(peids[p_idx])] + if e_idx not in seen[hh_idx]: + seen[hh_idx].add(e_idx) + mapping[hh_idx].append(e_idx) + for hh_idx in mapping: + mapping[hh_idx].sort() + hh_to_entity[ek] = mapping + + # === Build clone index arrays === + hh_clone_idx = active_hh + + persons_per_clone = np.array( + 
[len(hh_to_persons.get(h, [])) for h in active_hh] + ) + person_parts = [ + np.array(hh_to_persons.get(h, []), dtype=np.int64) for h in active_hh + ] + person_clone_idx = ( + np.concatenate(person_parts) + if person_parts + else np.array([], dtype=np.int64) + ) + + entity_clone_idx = {} + entities_per_clone = {} + for ek in SUB_ENTITIES: + epc = np.array([len(hh_to_entity[ek].get(h, [])) for h in active_hh]) + entities_per_clone[ek] = epc + parts = [ + np.array(hh_to_entity[ek].get(h, []), dtype=np.int64) + for h in active_hh + ] + entity_clone_idx[ek] = ( + np.concatenate(parts) if parts else np.array([], dtype=np.int64) + ) + + n_persons = len(person_clone_idx) + print(f"Cloned persons: {n_persons:,}") + for ek in SUB_ENTITIES: + print(f"Cloned {ek}s: {len(entity_clone_idx[ek]):,}") + + # === Build new entity IDs and cross-references === + new_hh_ids = np.arange(n_clones, dtype=np.int32) + new_person_ids = np.arange(n_persons, dtype=np.int32) + new_person_hh_ids = np.repeat(new_hh_ids, persons_per_clone) + + new_entity_ids = {} + new_person_entity_ids = {} + clone_ids_for_persons = np.repeat( + np.arange(n_clones, dtype=np.int64), persons_per_clone + ) + + for ek in SUB_ENTITIES: + n_ents = len(entity_clone_idx[ek]) + new_entity_ids[ek] = np.arange(n_ents, dtype=np.int32) + + old_eids = entity_id_arrays[ek][entity_clone_idx[ek]].astype(np.int64) + clone_ids_e = np.repeat( + np.arange(n_clones, dtype=np.int64), + entities_per_clone[ek], + ) + + offset = int(old_eids.max()) + 1 if len(old_eids) > 0 else 1 + entity_keys = clone_ids_e * offset + old_eids + + sorted_order = np.argsort(entity_keys) + sorted_keys = entity_keys[sorted_order] + sorted_new = new_entity_ids[ek][sorted_order] + + p_old_eids = person_entity_id_arrays[ek][person_clone_idx].astype( + np.int64 + ) + person_keys = clone_ids_for_persons * offset + p_old_eids + + positions = np.searchsorted(sorted_keys, person_keys) + positions = np.clip(positions, 0, len(sorted_keys) - 1) + 
new_person_entity_ids[ek] = sorted_new[positions] + + # === Derive geography from blocks (dedup optimization) === + print("Deriving geography from blocks...") + unique_blocks, block_inv = np.unique(active_blocks, return_inverse=True) + print(f" {n_clones:,} blocks -> " f"{len(unique_blocks):,} unique") + unique_geo = derive_geography_from_blocks(unique_blocks) + geography = {k: v[block_inv] for k, v in unique_geo.items()} + + # === Calculate weights for all entity levels === + person_weights = np.repeat(clone_weights, persons_per_clone) + per_person_wt = clone_weights / np.maximum(persons_per_clone, 1) + + entity_weights = {} + for ek in SUB_ENTITIES: + n_ents = len(entity_clone_idx[ek]) + ent_person_counts = np.zeros(n_ents, dtype=np.int32) + np.add.at( + ent_person_counts, + new_person_entity_ids[ek], + 1, + ) + clone_ids_e = np.repeat(np.arange(n_clones), entities_per_clone[ek]) + entity_weights[ek] = per_person_wt[clone_ids_e] * ent_person_counts + + # === Determine variables to save === + vars_to_save = set(sim.input_variables) + vars_to_save.add("county") + vars_to_save.add("spm_unit_spm_threshold") + vars_to_save.add("congressional_district_geoid") + for gv in [ + "block_geoid", + "tract_geoid", + "cbsa_code", + "sldu", + "sldl", + "place_fips", + "vtd", + "puma", + "zcta", + ]: + vars_to_save.add(gv) + + # === Clone variable arrays === + clone_idx_map = { + "household": hh_clone_idx, + "person": person_clone_idx, + } + for ek in SUB_ENTITIES: + clone_idx_map[ek] = entity_clone_idx[ek] + + data = {} + variables_saved = 0 + + for variable in sim.tax_benefit_system.variables: + if variable not in vars_to_save: + continue + + holder = sim.get_holder(variable) + periods = holder.get_known_periods() + if not periods: + continue + + var_def = sim.tax_benefit_system.variables.get(variable) + entity_key = var_def.entity.key + if entity_key not in clone_idx_map: + continue + + cidx = clone_idx_map[entity_key] + var_data = {} + + for period in periods: + values = 
holder.get_array(period) + + if var_def.value_type in (Enum, str) and variable != "county_fips": + if hasattr(values, "decode_to_str"): + values = values.decode_to_str().astype("S") + else: + values = values.astype("S") + elif variable == "county_fips": + values = values.astype("int32") + else: + values = np.array(values) + + var_data[period] = values[cidx] + variables_saved += 1 + + if var_data: + data[variable] = var_data + + print(f"Variables cloned: {variables_saved}") + + # === Override entity IDs === + data["household_id"] = {time_period: new_hh_ids} + data["person_id"] = {time_period: new_person_ids} + data["person_household_id"] = { + time_period: new_person_hh_ids, + } + for ek in SUB_ENTITIES: + data[f"{ek}_id"] = { + time_period: new_entity_ids[ek], + } + data[f"person_{ek}_id"] = { + time_period: new_person_entity_ids[ek], + } + + # === Override weights === + data["household_weight"] = { + time_period: clone_weights.astype(np.float32), + } + data["person_weight"] = { + time_period: person_weights.astype(np.float32), + } + for ek in SUB_ENTITIES: + data[f"{ek}_weight"] = { + time_period: entity_weights[ek].astype(np.float32), + } + + # === Override geography === + data["state_fips"] = { + time_period: geography["state_fips"].astype(np.int32), + } + county_names = np.array( + [County._member_names_[i] for i in geography["county_index"]] + ).astype("S") + data["county"] = {time_period: county_names} + data["county_fips"] = { + time_period: geography["county_fips"].astype(np.int32), + } + for gv in [ + "block_geoid", + "tract_geoid", + "cbsa_code", + "sldu", + "sldl", + "place_fips", + "vtd", + "puma", + "zcta", + ]: + if gv in geography: + data[gv] = { + time_period: geography[gv].astype("S"), + } + + # === Gap 4: Congressional district GEOID === + clone_cd_geoids = np.array( + [int(cds_to_calibrate[g]) for g in active_geo], + dtype=np.int32, + ) + data["congressional_district_geoid"] = { + time_period: clone_cd_geoids, + } + + # === Gap 1: SPM threshold 
recalculation === + print("Recalculating SPM thresholds...") + cd_geoadj_values = load_cd_geoadj_values(cds_to_calibrate) + # Build per-SPM-unit geoadj from clone's CD + spm_clone_ids = np.repeat( + np.arange(n_clones, dtype=np.int64), + entities_per_clone["spm_unit"], + ) + spm_unit_geoadj = np.array( + [ + cd_geoadj_values[cds_to_calibrate[active_geo[c]]] + for c in spm_clone_ids + ], + dtype=np.float64, + ) + + # Get cloned person ages and SPM unit IDs + person_ages = sim.calculate("age", map_to="person").values[ + person_clone_idx + ] + + # Get cloned tenure types + spm_tenure_holder = sim.get_holder("spm_unit_tenure_type") + spm_tenure_periods = spm_tenure_holder.get_known_periods() + if spm_tenure_periods: + raw_tenure = spm_tenure_holder.get_array(spm_tenure_periods[0]) + if hasattr(raw_tenure, "decode_to_str"): + raw_tenure = raw_tenure.decode_to_str().astype("S") + else: + raw_tenure = np.array(raw_tenure).astype("S") + spm_tenure_cloned = raw_tenure[entity_clone_idx["spm_unit"]] + else: + spm_tenure_cloned = np.full( + len(entity_clone_idx["spm_unit"]), + b"RENTER", + dtype="S30", + ) + + new_spm_thresholds = calculate_spm_thresholds_vectorized( + person_ages=person_ages, + person_spm_unit_ids=new_person_entity_ids["spm_unit"], + spm_unit_tenure_types=spm_tenure_cloned, + spm_unit_geoadj=spm_unit_geoadj, + year=time_period, + ) + data["spm_unit_spm_threshold"] = { + time_period: new_spm_thresholds, + } + + # === Apply calibration takeup draws === + if stacked_takeup is not None: + print("Applying pre-computed stacked takeup values...") + from policyengine_us_data.utils.takeup import ( + SIMPLE_TAKEUP_VARS, + ) + + var_to_entity = { + spec["variable"]: spec["entity"] for spec in SIMPLE_TAKEUP_VARS + } + for spec in SIMPLE_TAKEUP_VARS: + var_name = spec["variable"] + entity_level = spec["entity"] + + if entity_level == "person": + n_ent = n_persons + clone_ids_e = np.repeat(np.arange(n_clones), persons_per_clone) + base_indices = person_clone_idx + else: + 
n_ent = len(entity_clone_idx[entity_level]) + clone_ids_e = np.repeat( + np.arange(n_clones), + entities_per_clone[entity_level], + ) + base_indices = entity_clone_idx[entity_level] + + if var_name in stacked_takeup: + geo_indices = active_geo[clone_ids_e] + values = stacked_takeup[var_name][geo_indices, base_indices] + data[var_name] = {time_period: values.astype(np.float32)} + else: + data[var_name] = {time_period: np.ones(n_ent, dtype=bool)} + elif apply_takeup: + print("Applying calibration takeup draws...") + entity_hh_indices = { + "person": np.repeat( + np.arange(n_clones, dtype=np.int64), + persons_per_clone, + ).astype(np.int64), + "tax_unit": np.repeat( + np.arange(n_clones, dtype=np.int64), + entities_per_clone["tax_unit"], + ).astype(np.int64), + "spm_unit": np.repeat( + np.arange(n_clones, dtype=np.int64), + entities_per_clone["spm_unit"], + ).astype(np.int64), + } + entity_counts = { + "person": n_persons, + "tax_unit": len(entity_clone_idx["tax_unit"]), + "spm_unit": len(entity_clone_idx["spm_unit"]), + } + hh_state_fips = geography["state_fips"].astype(np.int32) + original_hh_ids = household_ids[active_hh].astype(np.int64) + + takeup_results = apply_block_takeup_to_arrays( + hh_blocks=active_blocks, + hh_state_fips=hh_state_fips, + hh_ids=original_hh_ids, + entity_hh_indices=entity_hh_indices, + entity_counts=entity_counts, + time_period=time_period, + takeup_filter=takeup_filter, + ) + for var_name, bools in takeup_results.items(): + data[var_name] = {time_period: bools} + + # === Write H5 === + with h5py.File(str(output_path), "w") as f: + for variable, periods in data.items(): + grp = f.create_group(variable) + for period, values in periods.items(): + grp.create_dataset(str(period), data=values) + + print(f"\nH5 saved to {output_path}") + + with h5py.File(str(output_path), "r") as f: + tp = str(time_period) + if "household_id" in f and tp in f["household_id"]: + n = len(f["household_id"][tp][:]) + print(f"Verified: {n:,} households in output") + 
def get_district_friendly_name(cd_geoid: str) -> str:
    """Return the '<STATE>-<NN>' label for a congressional district GEOID.

    The GEOID is read as an integer: the hundreds and above are the state
    FIPS code, the last two digits the district number (e.g. '0101' ->
    'AL-01'). District numbers listed in AT_LARGE_DISTRICTS are reported
    as district 1. A state FIPS missing from STATE_CODES falls back to the
    numeric FIPS itself as the prefix.
    """
    fips, district = divmod(int(cd_geoid), 100)
    # At-large codes are normalised to district 1 for naming.
    shown_district = 1 if district in AT_LARGE_DISTRICTS else district
    prefix = STATE_CODES.get(fips, str(fips))
    return f"{prefix}-{shown_district:02d}"
def build_districts(
    weights_path: Path,
    dataset_path: Path,
    db_path: Path,
    output_dir: Path,
    completed_districts: set,
    hf_batch_size: int = 10,
    calibration_blocks: np.ndarray = None,
    takeup_filter: List[str] = None,
    upload: bool = False,
    stacked_takeup: dict = None,
):
    """Build district H5 files with checkpointing, optionally uploading.

    One H5 is written to ``output_dir / "districts"`` for every CD returned
    by ``get_all_cds_from_database``. Districts whose friendly name is in
    ``completed_districts`` are skipped; every district that finishes is
    passed to ``record_completed_district``.

    When ``upload`` is true each file is sent to GCP immediately and queued
    for HuggingFace; the queue is uploaded whenever it reaches
    ``hf_batch_size`` files and once more at the end.

    Because a district is recorded as completed *before* its HuggingFace
    batch is sent, a failure part-way through would otherwise drop the
    queued files and a resumed run would skip them. To avoid that, the
    pending queue is flushed (best effort) before the error is re-raised.

    Args:
        weights_path: ``.npy`` file with the calibration weight vector.
        dataset_path: Base dataset H5 handed to ``build_h5``.
        db_path: SQLite database listing the calibrated CDs.
        output_dir: Directory under which ``districts/`` is created.
        completed_districts: Friendly names to skip.
        hf_batch_size: Number of queued files that triggers a HF upload.
        calibration_blocks: Forwarded to ``build_h5`` as ``blocks``.
        takeup_filter: Forwarded to ``build_h5``.
        upload: Upload to GCP / HuggingFace when true.
        stacked_takeup: Forwarded to ``build_h5``.

    Raises:
        Exception: Whatever a build or upload step raises, after logging.
    """
    db_uri = f"sqlite:///{db_path}"
    cds_to_calibrate = get_all_cds_from_database(db_uri)
    w = np.load(weights_path)

    districts_dir = output_dir / "districts"
    districts_dir.mkdir(parents=True, exist_ok=True)

    hf_queue = []

    def _flush_hf_queue(label: str) -> None:
        """Upload whatever is queued for HuggingFace and empty the queue."""
        if not (upload and hf_queue):
            return
        print(
            f"\nUploading {label} of {len(hf_queue)} "
            f"files to HuggingFace..."
        )
        # Hand over a copy so clearing our queue cannot affect the callee.
        upload_local_area_batch_to_hf(list(hf_queue))
        hf_queue.clear()

    try:
        for i, cd_geoid in enumerate(cds_to_calibrate):
            # Single source of truth for the naming rule (at-large -> 01).
            friendly_name = get_district_friendly_name(cd_geoid)

            if friendly_name in completed_districts:
                print(f"Skipping {friendly_name} (already completed)")
                continue

            output_path = districts_dir / f"{friendly_name}.h5"
            print(
                f"\n[{i+1}/{len(cds_to_calibrate)}] "
                f"Building {friendly_name}"
            )

            try:
                build_h5(
                    weights=w,
                    blocks=calibration_blocks,
                    dataset_path=dataset_path,
                    output_path=output_path,
                    cds_to_calibrate=cds_to_calibrate,
                    cd_subset=[cd_geoid],
                    takeup_filter=takeup_filter,
                    stacked_takeup=stacked_takeup,
                )

                if upload:
                    print(f"Uploading {friendly_name}.h5 to GCP...")
                    upload_local_area_file(
                        str(output_path), "districts", skip_hf=True
                    )
                    hf_queue.append((str(output_path), "districts"))

                record_completed_district(friendly_name)
                print(f"Completed {friendly_name}")

                if len(hf_queue) >= hf_batch_size:
                    _flush_hf_queue("batch")

            except Exception as e:
                print(f"ERROR building {friendly_name}: {e}")
                raise
    except Exception:
        # Districts already recorded as completed will be skipped on the
        # next run, so try to publish their queued files now. A failure
        # here must not hide the original error.
        try:
            _flush_hf_queue("pending batch")
        except Exception as flush_err:
            print(
                "WARNING: could not upload pending HuggingFace batch: "
                f"{flush_err}"
            )
        raise

    _flush_hf_queue("final batch")
def main():
    """Command-line entry point: build (and optionally upload) local area H5s.

    Input resolution, in order of precedence:
      1. ``--weights-path``, ``--dataset-path`` and ``--db-path`` all given:
         use those files.
      2. ``--skip-download``: use the expected files in ``WORK_DIR`` and
         raise ``FileNotFoundError`` if one is missing.
      3. Otherwise download the calibration inputs into ``WORK_DIR``.

    A partial set of override paths does not take effect (same as before),
    but a warning is now printed instead of ignoring it silently.

    Output selection: ``--states-only``, ``--districts-only`` and
    ``--cities-only`` are checked in that order and the first one set wins;
    with none set, states, districts and cities are all built.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Build and publish local area H5 files"
    )
    parser.add_argument(
        "--skip-download",
        action="store_true",
        help="Skip downloading inputs from HF (use existing files)",
    )
    parser.add_argument(
        "--states-only",
        action="store_true",
        help="Only build and upload state files",
    )
    parser.add_argument(
        "--districts-only",
        action="store_true",
        help="Only build and upload district files",
    )
    parser.add_argument(
        "--cities-only",
        action="store_true",
        help="Only build and upload city files (e.g., NYC)",
    )
    parser.add_argument(
        "--weights-path",
        type=str,
        help="Override path to weights file (for local testing)",
    )
    parser.add_argument(
        "--dataset-path",
        type=str,
        help="Override path to dataset file (for local testing)",
    )
    parser.add_argument(
        "--db-path",
        type=str,
        help="Override path to database file (for local testing)",
    )
    parser.add_argument(
        "--calibration-blocks",
        type=str,
        help="Path to stacked_blocks.npy from calibration",
    )
    parser.add_argument(
        "--stacked-takeup",
        type=str,
        help="Path to stacked_takeup.npz from calibration",
    )
    parser.add_argument(
        "--upload",
        action="store_true",
        help="Upload to GCP and HuggingFace (default: build locally only)",
    )
    args = parser.parse_args()

    WORK_DIR.mkdir(parents=True, exist_ok=True)

    overrides = {
        "weights": args.weights_path,
        "dataset": args.dataset_path,
        "database": args.db_path,
    }
    if all(overrides.values()):
        inputs = {key: Path(value) for key, value in overrides.items()}
        print("Using provided paths:")
        for key, path in inputs.items():
            print(f"  {key}: {path}")
    else:
        if any(overrides.values()):
            # The override only applies when all three paths are supplied.
            missing = [key for key, value in overrides.items() if not value]
            print(
                "WARNING: --weights-path, --dataset-path and --db-path "
                "must all be given to take effect; ignoring the partial "
                f"override (missing: {', '.join(missing)})"
            )
        if args.skip_download:
            inputs = {
                "weights": WORK_DIR / "calibration_weights.npy",
                "dataset": (
                    WORK_DIR / "source_imputed_stratified_extended_cps.h5"
                ),
                "database": WORK_DIR / "policy_data.db",
            }
            print("Using existing files in work directory:")
            for key, path in inputs.items():
                if not path.exists():
                    raise FileNotFoundError(
                        f"Expected file not found: {path}"
                    )
                print(f"  {key}: {path}")
        else:
            print("Downloading calibration inputs from Hugging Face...")
            downloaded = download_calibration_inputs(str(WORK_DIR))
            inputs = {key: Path(path) for key, path in downloaded.items()}

    print(f"Using dataset: {inputs['dataset']}")

    # Loaded only to report the size of the base dataset.
    sim = Microsimulation(dataset=str(inputs["dataset"]))
    n_hh = sim.calculate("household_id", map_to="household").shape[0]
    print(f"\nBase dataset has {n_hh:,} households")

    calibration_blocks = None
    takeup_filter = None

    if args.calibration_blocks:
        calibration_blocks = np.load(args.calibration_blocks)
        print(f"Loaded calibration blocks: {len(calibration_blocks):,}")
        takeup_filter = [
            info["takeup_var"] for info in TAKEUP_AFFECTED_TARGETS.values()
        ]
        print(f"Takeup filter: {takeup_filter}")

    stacked_takeup = None
    if args.stacked_takeup:
        stacked_takeup = dict(np.load(args.stacked_takeup))
        print(f"Loaded stacked takeup: " f"{list(stacked_takeup.keys())}")

    # First *-only flag set wins; with none set, build every area type.
    if args.states_only:
        selected = {"states"}
    elif args.districts_only:
        selected = {"districts"}
    elif args.cities_only:
        selected = {"cities"}
    else:
        selected = {"states", "districts", "cities"}

    # (area type, banner word, checkpoint loader, builder) in build order.
    build_plan = [
        ("states", "STATE", load_completed_states, build_states),
        ("districts", "DISTRICT", load_completed_districts, build_districts),
        ("cities", "CITY", load_completed_cities, build_cities),
    ]
    for area_type, banner, load_completed, build in build_plan:
        if area_type not in selected:
            continue
        print("\n" + "=" * 60)
        print(f"BUILDING {banner} FILES")
        print("=" * 60)
        completed = load_completed()
        print(f"Already completed: {len(completed)} {area_type}")
        build(
            inputs["weights"],
            inputs["dataset"],
            inputs["database"],
            WORK_DIR,
            completed,
            calibration_blocks=calibration_blocks,
            takeup_filter=takeup_filter,
            upload=args.upload,
            stacked_takeup=stacked_takeup,
        )

    print("\n" + "=" * 60)
    print("ALL DONE!")
    print("=" * 60)
def run_sanity_checks(
    h5_path: str,
    period: int = 2024,
) -> List[dict]:
    """Run structural integrity checks on an H5 file.

    Checks performed (a check is omitted when its data is absent, except
    where a SKIP entry is noted):
      1. ``household_weight`` has no negative values (SKIP if missing).
      2. person / household / tax_unit / spm_unit IDs are unique.
      3. No NaN or Inf in ``KEY_MONETARY_VARS``.
      4. Every ``person_household_id`` exists in ``household_id``.
      5. ``TAKEUP_VARS`` contain only boolean-like values (0/1).
      6. Weighted employment income and SNAP per household fall in a
         plausible range (WARN otherwise). Each variable is weighted with
         the weight array whose shape matches it, trying the entity weight
         first and ``household_weight`` second; if neither matches the
         check is recorded as SKIP instead of failing on a shape mismatch.

    Args:
        h5_path: Path to the H5 dataset file.
        period: Tax year (used for variable keys).

    Returns:
        List of {check, status, detail} dicts.
    """
    results = []

    def _record(check, status, detail=""):
        """Append one result row in the common {check, status, detail} form."""
        results.append({"check": check, "status": status, "detail": detail})

    def _get(f, path):
        """Resolve a slash path like 'var/2024' in the H5."""
        try:
            obj = f[path]
            if isinstance(obj, h5py.Dataset):
                return obj[:]
            return None
        except KeyError:
            return None

    def _get_var(f, name, allow_flat=False):
        """Read ``name/<period>``; optionally fall back to a flat ``name``."""
        arr = _get(f, f"{name}/{period}")
        if arr is None and allow_flat:
            arr = _get(f, name)
        return arr

    def _matching_weights(f, values, weight_names):
        """Return the first weight array with the same shape as ``values``."""
        for weight_name in weight_names:
            candidate = _get_var(f, weight_name)
            if candidate is not None and candidate.shape == values.shape:
                return candidate
        return None

    with h5py.File(h5_path, "r") as f:
        # 1. Weight non-negativity
        w_key = f"household_weight/{period}"
        weights = _get(f, w_key)
        if weights is None:
            _record("weight_non_negativity", "SKIP", f"key {w_key} not found")
        else:
            n_neg = int((weights < 0).sum())
            if n_neg > 0:
                _record(
                    "weight_non_negativity",
                    "FAIL",
                    f"{n_neg} negative weights",
                )
            else:
                _record("weight_non_negativity", "PASS")

        # 2. Entity ID uniqueness
        for entity in ("person", "household", "tax_unit", "spm_unit"):
            ids = _get_var(f, f"{entity}_id", allow_flat=True)
            if ids is None:
                continue
            n_dup = len(ids) - len(np.unique(ids))
            if n_dup > 0:
                _record(
                    f"{entity}_id_uniqueness",
                    "FAIL",
                    f"{n_dup} duplicate IDs",
                )
            else:
                _record(f"{entity}_id_uniqueness", "PASS")

        # 3. No NaN/Inf in key monetary variables
        for var in KEY_MONETARY_VARS:
            vals = _get_var(f, var)
            if vals is None:
                continue
            n_nan = int(np.isnan(vals).sum())
            n_inf = int(np.isinf(vals).sum())
            if n_nan > 0 or n_inf > 0:
                _record(
                    f"no_nan_inf_{var}",
                    "FAIL",
                    f"{n_nan} NaN, {n_inf} Inf",
                )
            else:
                _record(f"no_nan_inf_{var}", "PASS")

        # 4. Person-to-household mapping
        person_hh_arr = _get_var(f, "person_household_id", allow_flat=True)
        hh_id_arr = _get_var(f, "household_id", allow_flat=True)
        if person_hh_arr is not None and hh_id_arr is not None:
            # Count persons (not distinct IDs) so the message is accurate.
            n_orphans = int((~np.isin(person_hh_arr, hh_id_arr)).sum())
            if n_orphans > 0:
                _record(
                    "person_household_mapping",
                    "FAIL",
                    f"{n_orphans} persons map to non-existent households",
                )
            else:
                _record("person_household_mapping", "PASS")

        # 5. Boolean takeup variables
        valid = {True, False, 0, 1, 0.0, 1.0}
        for var in TAKEUP_VARS:
            vals = _get_var(f, var)
            if vals is None:
                continue
            bad = set(np.unique(vals).tolist()) - valid
            if bad:
                _record(
                    f"boolean_takeup_{var}",
                    "FAIL",
                    f"unexpected values: {bad}",
                )
            else:
                _record(f"boolean_takeup_{var}", "PASS")

        # 6. Reasonable per-capita ranges
        if weights is not None:
            # float64 accumulation: stored weights may be float32.
            total_hh = float(weights.sum(dtype=np.float64))
            if total_hh > 0:
                # (variable, preferred weight, check name, low, high, note)
                per_hh_specs = [
                    (
                        "employment_income",
                        "person_weight",
                        "per_hh_employment_income",
                        10_000,
                        200_000,
                        "(expected $10K-$200K)",
                    ),
                    (
                        "snap",
                        "spm_unit_weight",
                        "per_hh_snap",
                        0,
                        10_000,
                        "(expected $0-$10K)",
                    ),
                ]
                for var, weight_var, check, low, high, note in per_hh_specs:
                    vals = _get_var(f, var)
                    if vals is None:
                        continue
                    var_weights = _matching_weights(
                        f, vals, (weight_var, "household_weight")
                    )
                    if var_weights is None:
                        # Multiplying mismatched arrays would raise.
                        _record(
                            check,
                            "SKIP",
                            f"no weight array matches {var} "
                            f"(length {len(vals)})",
                        )
                        continue
                    total = float(
                        (vals * var_weights).sum(dtype=np.float64)
                    )
                    per_hh = total / total_hh
                    if per_hh < low or per_hh > high:
                        _record(check, "WARN", f"${per_hh:,.0f}/hh {note}")
                    else:
                        _record(check, "PASS", f"${per_hh:,.0f}/hh")

    return results
if __name__ == "__main__":
    # Script entry point: parse the CLI, validate the weight vector length
    # against (number of CDs x number of base households), then write one
    # or more H5 files via build_h5 according to --mode.
    import argparse

    from policyengine_us import Microsimulation
    from policyengine_us_data.calibration.publish_local_area import (
        build_h5,
        NYC_COUNTIES,
        NYC_CDS,
        TAKEUP_AFFECTED_TARGETS,
    )

    parser = argparse.ArgumentParser(
        description="Create sparse CD-stacked datasets"
    )
    parser.add_argument(
        "--weights-path", required=True, help="Path to w_cd.npy file"
    )
    parser.add_argument(
        "--dataset-path",
        required=True,
        help="Path to stratified dataset .h5 file",
    )
    parser.add_argument(
        "--db-path", required=True, help="Path to policy_data.db"
    )
    parser.add_argument(
        "--output-dir", default="./temp", help="Output directory for files"
    )
    parser.add_argument(
        "--mode",
        choices=[
            "national",
            "states",
            "cds",
            "single-cd",
            "single-state",
            "nyc",
        ],
        default="national",
        help="Output mode",
    )
    parser.add_argument(
        "--cd", type=str, help="Single CD GEOID (--mode single-cd)"
    )
    parser.add_argument(
        "--state",
        type=str,
        help="State code e.g. RI, CA (--mode single-state)",
    )
    parser.add_argument(
        "--calibration-blocks",
        default=None,
        help="Path to stacked_blocks.npy",
    )
    parser.add_argument(
        "--stacked-takeup",
        default=None,
        help="Path to stacked_takeup.npz from calibration",
    )

    args = parser.parse_args()
    dataset_path_str = args.dataset_path
    weights_path_str = args.weights_path
    db_path = Path(args.db_path).resolve()
    output_dir = args.output_dir
    mode = args.mode

    os.makedirs(output_dir, exist_ok=True)

    w = np.load(weights_path_str)
    db_uri = f"sqlite:///{db_path}"

    cds_to_calibrate = get_all_cds_from_database(db_uri)
    print(f"Found {len(cds_to_calibrate)} congressional districts")

    # The weight vector must hold one entry per (CD, base household) pair.
    base_sim = Microsimulation(dataset=dataset_path_str)
    n_hh = base_sim.calculate("household_id", map_to="household").shape[0]
    expected_length = len(cds_to_calibrate) * n_hh
    if len(w) != expected_length:
        raise ValueError(
            f"Weight vector length ({len(w):,}) doesn't match "
            f"expected ({expected_length:,})"
        )

    cal_blocks = None
    takeup_filter = None
    if args.calibration_blocks:
        cal_blocks = np.load(args.calibration_blocks)
        print(f"Loaded calibration blocks: {len(cal_blocks):,}")
        takeup_filter = [
            info["takeup_var"] for info in TAKEUP_AFFECTED_TARGETS.values()
        ]
        print(f"Takeup filter: {takeup_filter}")

    stacked_takeup = None
    if args.stacked_takeup:
        stacked_takeup = dict(np.load(args.stacked_takeup))
        print(f"Loaded stacked takeup: " f"{list(stacked_takeup.keys())}")

    def _write_h5(target, **subset_kwargs):
        """Call build_h5 with the shared inputs plus mode-specific filters."""
        build_h5(
            weights=np.array(w),
            blocks=cal_blocks,
            dataset_path=Path(dataset_path_str),
            output_path=Path(target),
            cds_to_calibrate=cds_to_calibrate,
            takeup_filter=takeup_filter,
            stacked_takeup=stacked_takeup,
            **subset_kwargs,
        )

    def _cds_for_state(fips):
        """Calibrated CDs whose GEOID belongs to the given state FIPS."""
        return [cd for cd in cds_to_calibrate if int(cd) // 100 == fips]

    if mode == "national":
        output_path = f"{output_dir}/national.h5"
        print(f"\nCreating national dataset: {output_path}")
        _write_h5(output_path)

    elif mode == "states":
        for state_fips, state_code in STATE_CODES.items():
            cd_subset = _cds_for_state(state_fips)
            if cd_subset:
                output_path = f"{output_dir}/{state_code}.h5"
                print(f"\nCreating {state_code}: {output_path}")
                _write_h5(output_path, cd_subset=cd_subset)

    elif mode == "cds":
        total_cds = len(cds_to_calibrate)
        for position, cd_geoid in enumerate(cds_to_calibrate, start=1):
            state_fips, district_num = divmod(int(cd_geoid), 100)
            # At-large district codes are named as district 01.
            if district_num in (0, 98):
                district_num = 1
            state_code = STATE_CODES.get(state_fips, str(state_fips))
            friendly_name = f"{state_code}-{district_num:02d}"

            output_path = f"{output_dir}/{friendly_name}.h5"
            print(
                f"\n[{position}/{total_cds}] "
                f"Creating {friendly_name}.h5"
            )
            _write_h5(output_path, cd_subset=[cd_geoid])

    elif mode == "single-cd":
        if not args.cd:
            raise ValueError("--cd required with --mode single-cd")
        if args.cd not in cds_to_calibrate:
            raise ValueError(f"CD {args.cd} not in calibrated CDs list")
        output_path = f"{output_dir}/{args.cd}.h5"
        print(f"\nCreating single CD dataset: {output_path}")
        _write_h5(output_path, cd_subset=[args.cd])

    elif mode == "single-state":
        if not args.state:
            raise ValueError("--state required with --mode single-state")
        state_code_upper = args.state.upper()
        # Reverse lookup: first FIPS whose code matches the requested state.
        state_fips = next(
            (
                fips
                for fips, code in STATE_CODES.items()
                if code == state_code_upper
            ),
            None,
        )
        if state_fips is None:
            raise ValueError(f"Unknown state code: {args.state}")

        cd_subset = _cds_for_state(state_fips)
        if not cd_subset:
            raise ValueError(f"No CDs found for state {state_code_upper}")

        output_path = f"{output_dir}/{state_code_upper}.h5"
        print(
            f"\nCreating {state_code_upper} with "
            f"{len(cd_subset)} CDs: {output_path}"
        )
        _write_h5(output_path, cd_subset=cd_subset)

    elif mode == "nyc":
        cd_subset = [cd for cd in cds_to_calibrate if cd in NYC_CDS]
        if not cd_subset:
            raise ValueError("No NYC CDs found")

        output_path = f"{output_dir}/NYC.h5"
        print(f"\nCreating NYC with {len(cd_subset)} CDs: " f"{output_path}")
        _write_h5(
            output_path,
            cd_subset=cd_subset,
            county_filter=NYC_COUNTIES,
        )

    print("\nDone!")
(needed_w 7-41, compatible) === + - variable: real_estate_taxes + geo_level: district + - variable: self_employment_income + geo_level: district + - variable: taxable_pension_income + geo_level: district + - variable: refundable_ctc + geo_level: district + - variable: unemployment_compensation + geo_level: district + + # === DISTRICT — ACA PTC === + - variable: aca_ptc + geo_level: district + - variable: tax_unit_count + geo_level: district + domain_variable: aca_ptc + + # === STATE === + - variable: person_count + geo_level: state + domain_variable: medicaid_enrolled + # TODO: re-enable once make data is re-run with is_pregnant preserved + # - variable: person_count + # geo_level: state + # domain_variable: is_pregnant + - variable: snap + geo_level: state + + # === NATIONAL — aggregate dollar targets === + - variable: adjusted_gross_income + geo_level: national + - variable: child_support_expense + geo_level: national + - variable: child_support_received + geo_level: national + - variable: eitc + geo_level: national + - variable: health_insurance_premiums_without_medicare_part_b + geo_level: national + - variable: medicaid + geo_level: national + - variable: medicare_part_b_premiums + geo_level: national + - variable: other_medical_expenses + geo_level: national + - variable: over_the_counter_health_expenses + geo_level: national + - variable: qualified_business_income_deduction + geo_level: national + - variable: rent + geo_level: national + - variable: salt_deduction + geo_level: national + - variable: snap + geo_level: national + - variable: social_security + geo_level: national + - variable: social_security_disability + geo_level: national + - variable: social_security_retirement + geo_level: national + - variable: spm_unit_capped_housing_subsidy + geo_level: national + - variable: spm_unit_capped_work_childcare_expenses + geo_level: national + - variable: ssi + geo_level: national + - variable: tanf + geo_level: national + - variable: tip_income + geo_level: 
national + - variable: unemployment_compensation + geo_level: national + + # === NATIONAL — IRS SOI domain-constrained dollar targets === + - variable: aca_ptc + geo_level: national + domain_variable: aca_ptc + - variable: dividend_income + geo_level: national + domain_variable: dividend_income + - variable: eitc + geo_level: national + domain_variable: eitc_child_count + - variable: income_tax_positive + geo_level: national + - variable: income_tax_before_credits + geo_level: national + domain_variable: income_tax_before_credits + - variable: net_capital_gains + geo_level: national + domain_variable: net_capital_gains + - variable: qualified_business_income_deduction + geo_level: national + domain_variable: qualified_business_income_deduction + - variable: qualified_dividend_income + geo_level: national + domain_variable: qualified_dividend_income + - variable: refundable_ctc + geo_level: national + domain_variable: refundable_ctc + - variable: rental_income + geo_level: national + domain_variable: rental_income + - variable: salt + geo_level: national + domain_variable: salt + - variable: self_employment_income + geo_level: national + domain_variable: self_employment_income + - variable: tax_exempt_interest_income + geo_level: national + domain_variable: tax_exempt_interest_income + - variable: tax_unit_partnership_s_corp_income + geo_level: national + domain_variable: tax_unit_partnership_s_corp_income + - variable: taxable_interest_income + geo_level: national + domain_variable: taxable_interest_income + - variable: taxable_ira_distributions + geo_level: national + domain_variable: taxable_ira_distributions + - variable: taxable_pension_income + geo_level: national + domain_variable: taxable_pension_income + - variable: taxable_social_security + geo_level: national + domain_variable: taxable_social_security + - variable: unemployment_compensation + geo_level: national + domain_variable: unemployment_compensation + + # === NATIONAL — IRS SOI filer count targets 
=== + - variable: tax_unit_count + geo_level: national + domain_variable: aca_ptc + - variable: tax_unit_count + geo_level: national + domain_variable: dividend_income + - variable: tax_unit_count + geo_level: national + domain_variable: eitc_child_count + - variable: tax_unit_count + geo_level: national + domain_variable: income_tax + - variable: tax_unit_count + geo_level: national + domain_variable: income_tax_before_credits + - variable: tax_unit_count + geo_level: national + domain_variable: medical_expense_deduction + - variable: tax_unit_count + geo_level: national + domain_variable: net_capital_gains + - variable: tax_unit_count + geo_level: national + domain_variable: qualified_business_income_deduction + - variable: tax_unit_count + geo_level: national + domain_variable: qualified_dividend_income + - variable: tax_unit_count + geo_level: national + domain_variable: real_estate_taxes + - variable: tax_unit_count + geo_level: national + domain_variable: refundable_ctc + - variable: tax_unit_count + geo_level: national + domain_variable: rental_income + - variable: tax_unit_count + geo_level: national + domain_variable: salt + - variable: tax_unit_count + geo_level: national + domain_variable: self_employment_income + - variable: tax_unit_count + geo_level: national + domain_variable: tax_exempt_interest_income + - variable: tax_unit_count + geo_level: national + domain_variable: tax_unit_partnership_s_corp_income + - variable: tax_unit_count + geo_level: national + domain_variable: taxable_interest_income + - variable: tax_unit_count + geo_level: national + domain_variable: taxable_ira_distributions + - variable: tax_unit_count + geo_level: national + domain_variable: taxable_pension_income + - variable: tax_unit_count + geo_level: national + domain_variable: taxable_social_security + - variable: tax_unit_count + geo_level: national + domain_variable: unemployment_compensation diff --git a/policyengine_us_data/calibration/target_config_full.yaml 
b/policyengine_us_data/calibration/target_config_full.yaml new file mode 100644 index 000000000..1e1e287dd --- /dev/null +++ b/policyengine_us_data/calibration/target_config_full.yaml @@ -0,0 +1,51 @@ +# Target exclusion config for unified calibration. +# Each entry excludes targets matching (variable, geo_level). +# Derived from junkyard's 22 excluded target groups. + +exclude: + # National exclusions + - variable: alimony_expense + geo_level: national + - variable: alimony_income + geo_level: national + - variable: charitable_deduction + geo_level: national + - variable: child_support_expense + geo_level: national + - variable: child_support_received + geo_level: national + - variable: interest_deduction + geo_level: national + - variable: medical_expense_deduction + geo_level: national + - variable: net_worth + geo_level: national + - variable: person_count + geo_level: national + - variable: real_estate_taxes + geo_level: national + - variable: rent + geo_level: national + - variable: social_security_dependents + geo_level: national + - variable: social_security_survivors + geo_level: national + # District exclusions + - variable: aca_ptc + geo_level: district + - variable: eitc + geo_level: district + - variable: income_tax_before_credits + geo_level: district + - variable: medical_expense_deduction + geo_level: district + - variable: net_capital_gains + geo_level: district + - variable: rental_income + geo_level: district + - variable: tax_unit_count + geo_level: district + - variable: tax_unit_partnership_s_corp_income + geo_level: district + - variable: taxable_social_security + geo_level: district diff --git a/policyengine_us_data/calibration/unified_calibration.py b/policyengine_us_data/calibration/unified_calibration.py index 1fb7a6b34..506bd6e14 100644 --- a/policyengine_us_data/calibration/unified_calibration.py +++ b/policyengine_us_data/calibration/unified_calibration.py @@ -5,11 +5,11 @@ 1. Load CPS dataset -> get n_records 2. 
Clone Nx, assign random geography (census block) 3. (Optional) Source impute ACS/SIPP/SCF vars with state - 4. (Optional) PUF clone (2x) + QRF impute with state - 5. Re-randomize simple takeup variables per block - 6. Build sparse calibration matrix (clone-by-clone) - 7. L0-regularized optimization -> calibrated weights - 8. Save weights, diagnostics, run config + 4. Build sparse calibration matrix (clone-by-clone) + 5. L0-regularized optimization -> calibrated weights + 6. Save weights, diagnostics, run config + +Note: PUF cloning happens upstream in `extended_cps.py`, not here. Two presets control output size via L0 regularization: - local: L0=1e-8, ~3-4M records (for local area dataset) @@ -22,12 +22,13 @@ --output path/to/weights.npy \\ --preset local \\ --epochs 100 \\ - --puf-dataset path/to/puf_2024.h5 + --skip-source-impute """ import argparse import builtins import logging +import os import sys from pathlib import Path from typing import Optional @@ -55,133 +56,119 @@ LAMBDA_L2 = 1e-12 LEARNING_RATE = 0.15 DEFAULT_EPOCHS = 100 -DEFAULT_N_CLONES = 10 - -SIMPLE_TAKEUP_VARS = [ - { - "variable": "takes_up_snap_if_eligible", - "entity": "spm_unit", - "rate_key": "snap", - }, - { - "variable": "takes_up_aca_if_eligible", - "entity": "tax_unit", - "rate_key": "aca", - }, - { - "variable": "takes_up_dc_ptc", - "entity": "tax_unit", - "rate_key": "dc_ptc", - }, - { - "variable": "takes_up_head_start_if_eligible", - "entity": "person", - "rate_key": "head_start", - }, - { - "variable": "takes_up_early_head_start_if_eligible", - "entity": "person", - "rate_key": "early_head_start", - }, - { - "variable": "takes_up_ssi_if_eligible", - "entity": "person", - "rate_key": "ssi", - }, - { - "variable": "would_file_taxes_voluntarily", - "entity": "tax_unit", - "rate_key": "voluntary_filing", - }, - { - "variable": "takes_up_medicaid_if_eligible", - "entity": "person", - "rate_key": "medicaid", - }, - { - "variable": "takes_up_tanf_if_eligible", - "entity": "spm_unit", - 
"rate_key": "tanf", - }, -] - - -def rerandomize_takeup( - sim, - clone_block_geoids: np.ndarray, - clone_state_fips: np.ndarray, - time_period: int, -) -> None: - """Re-randomize simple takeup variables per census block. - - Groups entities by their household's block GEOID and draws - new takeup booleans using seeded_rng(var_name, salt=block). - Overrides the simulation's stored inputs. - - Args: - sim: Microsimulation instance (already has state_fips). - clone_block_geoids: Block GEOIDs per household. - clone_state_fips: State FIPS per household. - time_period: Tax year. - """ - from policyengine_us_data.parameters import ( - load_take_up_rate, - ) - from policyengine_us_data.utils.randomness import ( - seeded_rng, - ) +DEFAULT_N_CLONES = 430 - hh_ids = sim.calculate("household_id", map_to="household").values - hh_to_block = dict(zip(hh_ids, clone_block_geoids)) - hh_to_state = dict(zip(hh_ids, clone_state_fips)) - for spec in SIMPLE_TAKEUP_VARS: - var_name = spec["variable"] - entity_level = spec["entity"] - rate_key = spec["rate_key"] +def get_git_provenance() -> dict: + """Capture git state and package version for provenance tracking.""" + import subprocess as _sp - rate_or_dict = load_take_up_rate(rate_key, time_period) - - is_state_specific = isinstance(rate_or_dict, dict) + info = { + "git_commit": None, + "git_branch": None, + "git_dirty": None, + "package_version": None, + } + try: + info["git_commit"] = ( + _sp.check_output( + ["git", "rev-parse", "HEAD"], + stderr=_sp.DEVNULL, + ) + .decode() + .strip() + ) + info["git_branch"] = ( + _sp.check_output( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], + stderr=_sp.DEVNULL, + ) + .decode() + .strip() + ) + porcelain = ( + _sp.check_output( + ["git", "status", "--porcelain"], + stderr=_sp.DEVNULL, + ) + .decode() + .strip() + ) + info["git_dirty"] = len(porcelain) > 0 + except Exception: + pass + try: + from policyengine_us_data.__version__ import __version__ - entity_ids = sim.calculate( - 
f"{entity_level}_id", map_to=entity_level - ).values - entity_hh_ids = sim.calculate( - "household_id", map_to=entity_level - ).values - n_entities = len(entity_ids) - - draws = np.zeros(n_entities, dtype=np.float64) - rates = np.zeros(n_entities, dtype=np.float64) - - entity_blocks = np.array( - [hh_to_block.get(hid, "0") for hid in entity_hh_ids] - ) - - unique_blocks = np.unique(entity_blocks) - for block in unique_blocks: - mask = entity_blocks == block - n_in_block = mask.sum() - rng = seeded_rng(var_name, salt=str(block)) - draws[mask] = rng.random(n_in_block) - - if is_state_specific: - block_hh_ids = entity_hh_ids[mask] - for i, hid in enumerate(block_hh_ids): - state = int(hh_to_state.get(hid, 0)) - state_str = str(state) - r = rate_or_dict.get( - state_str, - rate_or_dict.get(state, 0.8), - ) - idx = np.where(mask)[0][i] - rates[idx] = r - else: - rates[mask] = rate_or_dict + info["package_version"] = __version__ + except Exception: + pass + return info + + +def print_package_provenance(metadata: dict) -> None: + """Print a provenance banner from package metadata.""" + built = metadata.get("created_at", "unknown") + branch = metadata.get("git_branch", "unknown") + commit = metadata.get("git_commit") + commit_short = commit[:8] if commit else "unknown" + dirty = " (DIRTY)" if metadata.get("git_dirty") else "" + version = metadata.get("package_version", "unknown") + ds_sha = metadata.get("dataset_sha256", "") + db_sha = metadata.get("db_sha256", "") + ds_label = ds_sha[:12] if ds_sha else "unknown" + db_label = db_sha[:12] if db_sha else "unknown" + print("--- Package Provenance ---") + print(f" Built: {built}") + print(f" Branch: {branch} @ {commit_short}{dirty}") + print(f" Version: {version}") + print(f" Dataset SHA: {ds_label} DB SHA: {db_label}") + print("--------------------------") + + +def check_package_staleness(metadata: dict) -> None: + """Warn if package is stale, dirty, or from a different branch.""" + import datetime + + created = 
metadata.get("created_at") + if created: + try: + built_dt = datetime.datetime.fromisoformat(created) + age = datetime.datetime.now() - built_dt + if age.days > 7: + print( + f"WARNING: Package is {age.days} days old " + f"(built {created})" + ) + except Exception: + pass + + if metadata.get("git_dirty"): + print("WARNING: Package was built from a dirty " "working tree") + + current = get_git_provenance() + pkg_branch = metadata.get("git_branch") + cur_branch = current.get("git_branch") + if pkg_branch and cur_branch and pkg_branch != cur_branch: + print( + f"WARNING: Package built on branch " + f"'{pkg_branch}', current branch is " + f"'{cur_branch}'" + ) - new_values = draws < rates - sim.set_input(var_name, time_period, new_values) + pkg_commit = metadata.get("git_commit") + cur_commit = current.get("git_commit") + if ( + pkg_commit + and cur_commit + and pkg_commit != cur_commit + and pkg_branch == cur_branch + ): + print( + f"WARNING: Package commit {pkg_commit[:8]} " + f"differs from current {cur_commit[:8]} " + f"on same branch '{cur_branch}'" + ) def parse_args(argv=None): @@ -257,23 +244,297 @@ def parse_args(argv=None): help="Skip takeup re-randomization", ) parser.add_argument( - "--puf-dataset", + "--skip-source-impute", + action="store_true", + default=True, + help="(default) Skip ACS/SIPP/SCF re-imputation with state", + ) + parser.add_argument( + "--no-skip-source-impute", + dest="skip_source_impute", + action="store_false", + help="Run ACS/SIPP/SCF source imputation inline", + ) + parser.add_argument( + "--target-config", default=None, - help="Path to PUF h5 file for QRF training", + help="Path to target exclusion YAML config", ) parser.add_argument( - "--skip-puf", + "--county-level", action="store_true", - help="Skip PUF clone + QRF imputation", + help="Iterate per-county (slow, ~3143 counties). 
" + "Default is state-only (~51 states), which is much " + "faster for county-invariant target variables.", ) parser.add_argument( - "--skip-source-impute", + "--build-only", action="store_true", - help="Skip ACS/SIPP/SCF re-imputation with state", + help="Build matrix + save package, skip fitting", + ) + parser.add_argument( + "--package-path", + default=None, + help="Load pre-built calibration package (skip matrix build)", + ) + parser.add_argument( + "--package-output", + default=None, + help="Where to save calibration package", + ) + parser.add_argument( + "--beta", + type=float, + default=BETA, + help=f"L0 gate temperature (default: {BETA})", + ) + parser.add_argument( + "--lambda-l2", + type=float, + default=LAMBDA_L2, + help=f"L2 regularization (default: {LAMBDA_L2})", + ) + parser.add_argument( + "--learning-rate", + type=float, + default=LEARNING_RATE, + help=f"Learning rate (default: {LEARNING_RATE})", + ) + parser.add_argument( + "--log-freq", + type=int, + default=None, + help="Epochs between per-target CSV log entries. " + "Omit to disable epoch logging.", + ) + parser.add_argument( + "--workers", + type=int, + default=1, + help="Number of parallel workers for state/county " + "precomputation (default: 1, sequential).", ) return parser.parse_args(argv) +def load_target_config(path: str) -> dict: + """Load target exclusion config from YAML. + + Args: + path: Path to YAML config file. + + Returns: + Parsed config dict with 'exclude' list. 
+ """ + import yaml + + with open(path) as f: + config = yaml.safe_load(f) + if config is None: + config = {} + if "exclude" not in config: + config["exclude"] = [] + return config + + +def _match_rules(targets_df, rules): + """Build a boolean mask matching any of the given rules.""" + mask = np.zeros(len(targets_df), dtype=bool) + for rule in rules: + rule_mask = targets_df["variable"] == rule["variable"] + if "geo_level" in rule: + rule_mask = rule_mask & ( + targets_df["geo_level"] == rule["geo_level"] + ) + if "domain_variable" in rule: + rule_mask = rule_mask & ( + targets_df["domain_variable"] == rule["domain_variable"] + ) + mask |= rule_mask + return mask + + +def apply_target_config( + targets_df: "pd.DataFrame", + X_sparse, + target_names: list, + config: dict, +) -> tuple: + """Filter targets based on include/exclude config. + + Use ``include`` to keep only matching targets, or ``exclude`` + to drop matching targets. Both support ``variable``, + ``geo_level`` (optional), and ``domain_variable`` (optional). + If both are present, ``include`` is applied first, then + ``exclude`` removes from the included set. + + Args: + targets_df: DataFrame with target rows. + X_sparse: Sparse matrix (targets x records). + target_names: List of target name strings. + config: Config dict with 'include' and/or 'exclude' list. 
+ + Returns: + (filtered_targets_df, filtered_X_sparse, filtered_names) + """ + include_rules = config.get("include", []) + exclude_rules = config.get("exclude", []) + + if not include_rules and not exclude_rules: + return targets_df, X_sparse, target_names + + n_before = len(targets_df) + + if include_rules: + keep_mask = _match_rules(targets_df, include_rules) + else: + keep_mask = np.ones(n_before, dtype=bool) + + if exclude_rules: + drop_mask = _match_rules(targets_df, exclude_rules) + keep_mask &= ~drop_mask + + n_dropped = n_before - keep_mask.sum() + logger.info( + "Target config: kept %d / %d targets (dropped %d)", + keep_mask.sum(), + n_before, + n_dropped, + ) + + idx = np.where(keep_mask)[0] + filtered_df = targets_df.iloc[idx].reset_index(drop=True) + filtered_X = X_sparse[idx, :] + filtered_names = [target_names[i] for i in idx] + + return filtered_df, filtered_X, filtered_names + + +def save_calibration_package( + path: str, + X_sparse, + targets_df: "pd.DataFrame", + target_names: list, + metadata: dict, + initial_weights: np.ndarray = None, + cd_geoid: np.ndarray = None, + block_geoid: np.ndarray = None, +) -> None: + """Save calibration package to pickle. + + Args: + path: Output file path. + X_sparse: Sparse matrix. + targets_df: Targets DataFrame. + target_names: Target name list. + metadata: Run metadata dict. + initial_weights: Pre-computed initial weight array. + cd_geoid: CD GEOID array from geography assignment. + block_geoid: Block GEOID array from geography assignment. 
+ """ + import pickle + + package = { + "X_sparse": X_sparse, + "targets_df": targets_df, + "target_names": target_names, + "metadata": metadata, + "initial_weights": initial_weights, + "cd_geoid": cd_geoid, + "block_geoid": block_geoid, + } + Path(path).parent.mkdir(parents=True, exist_ok=True) + with open(path, "wb") as f: + pickle.dump(package, f, protocol=pickle.HIGHEST_PROTOCOL) + logger.info("Calibration package saved to %s", path) + + +def load_calibration_package(path: str) -> dict: + """Load calibration package from pickle. + + Args: + path: Path to package file. + + Returns: + Dict with X_sparse, targets_df, target_names, metadata. + """ + import pickle + + with open(path, "rb") as f: + package = pickle.load(f) + logger.info( + "Loaded package: %d targets, %d records", + package["X_sparse"].shape[0], + package["X_sparse"].shape[1], + ) + meta = package.get("metadata", {}) + print_package_provenance(meta) + check_package_staleness(meta) + return package + + +def compute_initial_weights( + X_sparse, + targets_df: "pd.DataFrame", +) -> np.ndarray: + """Compute population-based initial weights from age targets. + + For each congressional district, sums person_count targets where + domain_variable == "age" to get district population, then divides + by the number of columns (households) active in that district. + + Args: + X_sparse: Sparse matrix (targets x records). + targets_df: Targets DataFrame with columns: variable, + domain_variable, geo_level, geographic_id, value. + + Returns: + Weight array of shape (n_records,). 
+ """ + n_total = X_sparse.shape[1] + + age_mask = ( + (targets_df["variable"] == "person_count") + & (targets_df["domain_variable"] == "age") + & (targets_df["geo_level"] == "district") + ) + age_rows = targets_df[age_mask] + + if len(age_rows) == 0: + logger.warning( + "No person_count/age/district targets found; " + "falling back to uniform weights=100" + ) + return np.ones(n_total) * 100 + + initial_weights = np.ones(n_total) * 100 + cd_groups = age_rows.groupby("geographic_id") + + for cd_id, group in cd_groups: + cd_pop = group["value"].sum() + row_indices = group.index.tolist() + col_set = set() + for ri in row_indices: + row = X_sparse[ri] + col_set.update(row.indices) + n_cols = len(col_set) + if n_cols == 0: + continue + w = cd_pop / n_cols + for c in col_set: + initial_weights[c] = w + + n_unique = len(np.unique(initial_weights)) + logger.info( + "Initial weights: min=%.1f, max=%.1f, mean=%.1f, " "%d unique values", + initial_weights.min(), + initial_weights.max(), + initial_weights.mean(), + n_unique, + ) + return initial_weights + + def fit_l0_weights( X_sparse, targets: np.ndarray, @@ -281,6 +542,15 @@ def fit_l0_weights( epochs: int = DEFAULT_EPOCHS, device: str = "cpu", verbose_freq: Optional[int] = None, + beta: float = BETA, + lambda_l2: float = LAMBDA_L2, + learning_rate: float = LEARNING_RATE, + log_freq: int = None, + log_path: str = None, + target_names: list = None, + initial_weights: np.ndarray = None, + targets_df: "pd.DataFrame" = None, + achievable: np.ndarray = None, ) -> np.ndarray: """Fit L0-regularized calibration weights. @@ -291,6 +561,17 @@ def fit_l0_weights( epochs: Training epochs. device: Torch device. verbose_freq: Print frequency. Defaults to 10%. + beta: L0 gate temperature. + lambda_l2: L2 regularization strength. + learning_rate: Optimizer learning rate. + log_freq: Epochs between per-target CSV logs. + None disables logging. + log_path: Path for the per-target calibration log CSV. 
+ target_names: Human-readable target names for the log. + initial_weights: Pre-computed initial weights. If None, + computed from targets_df age targets. + targets_df: Targets DataFrame, used to compute + initial_weights when not provided. Returns: Weight array of shape (n_records,). @@ -304,21 +585,30 @@ def fit_l0_weights( import torch + os.environ.setdefault( + "PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True" + ) + n_total = X_sparse.shape[1] - initial_weights = np.ones(n_total) * 100 + if initial_weights is None: + initial_weights = compute_initial_weights(X_sparse, targets_df) logger.info( "L0 calibration: %d targets, %d features, " - "lambda_l0=%.1e, epochs=%d", + "lambda_l0=%.1e, beta=%.2f, lambda_l2=%.1e, " + "lr=%.3f, epochs=%d", X_sparse.shape[0], n_total, lambda_l0, + beta, + lambda_l2, + learning_rate, epochs, ) model = SparseCalibrationWeights( n_features=n_total, - beta=BETA, + beta=beta, gamma=GAMMA, zeta=ZETA, init_keep_prob=INIT_KEEP_PROB, @@ -339,22 +629,136 @@ def _flushed_print(*args, **kwargs): builtins.print = _flushed_print - t0 = time.time() - try: - model.fit( - M=X_sparse, - y=targets, - target_groups=None, - lambda_l0=lambda_l0, - lambda_l2=LAMBDA_L2, - lr=LEARNING_RATE, - epochs=epochs, - loss_type="relative", - verbose=True, - verbose_freq=verbose_freq, + enable_logging = ( + log_freq is not None + and log_path is not None + and target_names is not None + ) + if enable_logging: + Path(log_path).parent.mkdir(parents=True, exist_ok=True) + with open(log_path, "w") as f: + f.write( + "target_name,estimate,target,epoch," + "error,rel_error,abs_error," + "rel_abs_error,loss,achievable\n" + ) + logger.info( + "Epoch logging enabled: freq=%d, path=%s", + log_freq, + log_path, ) - finally: - builtins.print = _builtin_print + + t0 = time.time() + if enable_logging: + epochs_done = 0 + while epochs_done < epochs: + chunk = min(log_freq, epochs - epochs_done) + model.fit( + M=X_sparse, + y=targets, + target_groups=None, + 
lambda_l0=lambda_l0, + lambda_l2=lambda_l2, + lr=learning_rate, + epochs=chunk, + loss_type="relative", + verbose=False, + ) + + epochs_done += chunk + + with torch.no_grad(): + y_pred = model.predict(X_sparse).cpu().numpy() + weights_snap = ( + model.get_weights(deterministic=True).cpu().numpy() + ) + + active_w = weights_snap[weights_snap > 0] + nz = len(active_w) + sparsity = (1 - nz / n_total) * 100 + + rel_errs = np.where( + np.abs(targets) > 0, + (y_pred - targets) / np.abs(targets), + 0.0, + ) + mean_err = np.mean(np.abs(rel_errs)) + max_err = np.max(np.abs(rel_errs)) + total_loss = np.sum(rel_errs**2) + + if nz > 0: + w_tiny = (active_w < 0.01).sum() + w_small = ((active_w >= 0.01) & (active_w < 0.1)).sum() + w_med = ((active_w >= 0.1) & (active_w < 1.0)).sum() + w_normal = ((active_w >= 1.0) & (active_w < 10.0)).sum() + w_large = ((active_w >= 10.0) & (active_w < 1000.0)).sum() + w_huge = (active_w >= 1000.0).sum() + weight_dist = ( + f"[<0.01: {100*w_tiny/nz:.1f}%, " + f"0.01-0.1: {100*w_small/nz:.1f}%, " + f"0.1-1: {100*w_med/nz:.1f}%, " + f"1-10: {100*w_normal/nz:.1f}%, " + f"10-1000: {100*w_large/nz:.1f}%, " + f">1000: {100*w_huge/nz:.1f}%]" + ) + else: + weight_dist = "[no active weights]" + + print( + f"Epoch {epochs_done:4d}: " + f"mean_error={mean_err:.4%}, " + f"max_error={max_err:.1%}, " + f"total_loss={total_loss:.3f}, " + f"active={nz}/{n_total} " + f"({sparsity:.1f}% sparse)\n" + f" Weight dist: {weight_dist}", + flush=True, + ) + + ach_flags = ( + achievable if achievable is not None else [True] * len(targets) + ) + with open(log_path, "a") as f: + for i in range(len(targets)): + est = y_pred[i] + tgt = targets[i] + err = est - tgt + rel_err = err / tgt if tgt != 0 else 0 + abs_err = abs(err) + rel_abs = abs(rel_err) + loss = rel_err**2 + f.write( + f'"{target_names[i]}",' + f"{est},{tgt},{epochs_done}," + f"{err},{rel_err},{abs_err}," + f"{rel_abs},{loss}," + f"{ach_flags[i]}\n" + ) + + logger.info( + "Logged %d targets at epoch %d", + 
len(targets), + epochs_done, + ) + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + else: + try: + model.fit( + M=X_sparse, + y=targets, + target_groups=None, + lambda_l0=lambda_l0, + lambda_l2=lambda_l2, + lr=learning_rate, + epochs=epochs, + loss_type="relative", + verbose=True, + verbose_freq=verbose_freq, + ) + finally: + builtins.print = _builtin_print elapsed = time.time() - t0 logger.info( @@ -376,115 +780,266 @@ def _flushed_print(*args, **kwargs): return weights -def compute_diagnostics( +def convert_weights_to_stacked_format( weights: np.ndarray, - X_sparse, - targets_df, - target_names: list, -) -> "pd.DataFrame": - import pandas as pd + cd_geoid: np.ndarray, + base_n_records: int, + cds_ordered: list, +) -> np.ndarray: + """Convert column-ordered weights to (n_cds, n_records) stacked format. - estimates = X_sparse.dot(weights) - true_values = targets_df["value"].values - row_sums = np.array(X_sparse.sum(axis=1)).flatten() + The L0 calibration produces one weight per column, where columns + are ordered by clone (column i -> clone i // n_records, record + i % n_records) with random CD assignments. This function + aggregates weights across clones into the (n_cds, n_records) + layout expected by stacked_dataset_builder. - rel_errors = np.where( - np.abs(true_values) > 0, - (estimates - true_values) / np.abs(true_values), - 0.0, - ) - return pd.DataFrame( - { - "target": target_names, - "true_value": true_values, - "estimate": estimates, - "rel_error": rel_errors, - "abs_rel_error": np.abs(rel_errors), - "achievable": row_sums > 0, - } + Args: + weights: Raw weight vector from L0 fitting, length + n_clones * base_n_records. + cd_geoid: CD GEOID per column from geography assignment. + base_n_records: Number of base households (before cloning). + cds_ordered: Ordered list of CD GEOIDs defining row order. + + Returns: + Flat array of length n_cds * base_n_records that reshapes + to (n_cds, base_n_records). 
+ """ + n_total = len(weights) + n_cds = len(cds_ordered) + + cd_to_idx = {cd: idx for idx, cd in enumerate(cds_ordered)} + record_indices = np.arange(n_total) % base_n_records + cd_row_indices = np.array([cd_to_idx[cd] for cd in cd_geoid]) + flat_indices = cd_row_indices * base_n_records + record_indices + + W = np.zeros(n_cds * base_n_records, dtype=np.float64) + np.add.at(W, flat_indices, weights) + + assert np.isclose( + W.sum(), weights.sum() + ), f"Weight sum mismatch: {W.sum()} vs {weights.sum()}" + logger.info( + "Converted weights to stacked format: " + "(%d, %d) = %d elements, sum=%.1f", + n_cds, + base_n_records, + len(W), + W.sum(), ) + return W -def _build_puf_cloned_dataset( - dataset_path: str, - puf_dataset_path: str, - state_fips: np.ndarray, - time_period: int = 2024, - skip_qrf: bool = False, - skip_source_impute: bool = False, -) -> str: - """Build a PUF-cloned dataset from raw CPS. +def convert_blocks_to_stacked_format( + block_geoid: np.ndarray, + cd_geoid: np.ndarray, + base_n_records: int, + cds_ordered: list, +) -> np.ndarray: + """Convert column-ordered block GEOIDs to stacked format. - Loads the CPS, optionally runs source imputations - (ACS/SIPP/SCF), then PUF clone + QRF. + Parallel to convert_weights_to_stacked_format. For each + (CD, record) slot, stores the block GEOID from the first + clone assigned there. Empty string for unfilled slots + (records with no clone in that CD). Args: - dataset_path: Path to raw CPS h5 file. - puf_dataset_path: Path to PUF h5 file. - state_fips: State FIPS per household (base records). - time_period: Tax year. - skip_qrf: Skip QRF imputation. - skip_source_impute: Skip ACS/SIPP/SCF imputations. + block_geoid: Block GEOID per column from geography + assignment. Length n_clones * base_n_records. + cd_geoid: CD GEOID per column from geography + assignment. + base_n_records: Number of base households. + cds_ordered: Ordered list of CD GEOIDs defining + row order. Returns: - Path to the PUF-cloned h5 file. 
+ Array of dtype U15, length n_cds * base_n_records, + reshapeable to (n_cds, base_n_records). """ - import h5py + n_total = len(block_geoid) + n_cds = len(cds_ordered) + + cd_to_idx = {cd: idx for idx, cd in enumerate(cds_ordered)} + record_indices = np.arange(n_total) % base_n_records + cd_row_indices = np.array([cd_to_idx[cd] for cd in cd_geoid]) + flat_indices = cd_row_indices * base_n_records + record_indices + + B = np.full(n_cds * base_n_records, "", dtype="U15") + n_collisions = 0 + for i in range(n_total): + fi = flat_indices[i] + if B[fi] == "": + B[fi] = block_geoid[i] + else: + n_collisions += 1 - from policyengine_us import Microsimulation + if n_collisions > 0: + logger.warning( + "Block collisions: %d slots had multiple clones " + "with different blocks.", + n_collisions, + ) - from policyengine_us_data.calibration.puf_impute import ( - puf_clone_dataset, + n_filled = np.count_nonzero(B != "") + logger.info( + "Converted blocks to stacked format: " + "(%d, %d) = %d slots, %d filled (%.1f%%)", + n_cds, + base_n_records, + len(B), + n_filled, + n_filled / len(B) * 100, ) + return B - logger.info("Building PUF-cloned dataset from %s", dataset_path) - sim = Microsimulation(dataset=dataset_path) - data = sim.dataset.load_dataset() +def compute_stacked_takeup( + weights: np.ndarray, + cd_geoid: np.ndarray, + block_geoid: np.ndarray, + base_n_records: int, + cds_ordered: list, + entity_hh_idx_map: dict, + household_ids: np.ndarray, + precomputed_rates: dict, + affected_target_info: dict, +) -> dict: + """Compute weight-weighted takeup per (CD, entity). + + For each takeup variable, iterates over all clones and + recomputes entity-level takeup draws (deterministic given + block and hh_id). Accumulates weight-weighted takeup + per (CD, base_entity_index) using the final optimizer + weights. 
- data_dict = {} - for var in data: - if isinstance(data[var], dict): - vals = list(data[var].values()) - data_dict[var] = {time_period: vals[0]} - else: - data_dict[var] = {time_period: np.array(data[var])} + Returns: + Dict mapping takeup variable name to ndarray of + shape (n_cds, n_base_entities). + """ + from policyengine_us_data.utils.takeup import ( + compute_block_takeup_for_entities, + ) - if not skip_source_impute: - from policyengine_us_data.calibration.source_impute import ( - impute_source_variables, + n_total = len(weights) + n_clones = n_total // base_n_records + n_cds = len(cds_ordered) + cd_to_idx = {cd: i for i, cd in enumerate(cds_ordered)} + + col_cd_idx = np.empty(n_total, dtype=np.int32) + for i, cd in enumerate(cd_geoid): + col_cd_idx[i] = cd_to_idx.get(cd, -1) + + unique_takeup = {} + for tvar, info in affected_target_info.items(): + if tvar.endswith("_count"): + continue + tv = info["takeup_var"] + if tv not in unique_takeup: + unique_takeup[tv] = info + + result = {} + + for takeup_var, info in unique_takeup.items(): + entity_level = info["entity"] + rate_key = info["rate_key"] + ent_hh = entity_hh_idx_map[entity_level] + n_ent = len(ent_hh) + rate = precomputed_rates[rate_key] + + numerator = np.zeros(n_cds * n_ent, dtype=np.float64) + denominator = np.zeros(n_cds * n_ent, dtype=np.float64) + + for clone_idx in range(n_clones): + col_start = clone_idx * base_n_records + col_end = col_start + base_n_records + clone_w = weights[col_start:col_end] + + if not np.any(clone_w > 0): + continue + + clone_blocks = block_geoid[col_start:col_end] + + ent_blocks = clone_blocks[ent_hh] + ent_hh_ids = household_ids[ent_hh] + + ent_takeup = compute_block_takeup_for_entities( + takeup_var, rate, ent_blocks, ent_hh_ids + ).astype(np.float64) + + ent_w = clone_w[ent_hh] + ent_cd_idx = col_cd_idx[col_start:col_end][ent_hh] + + ent_active = (ent_w > 0) & (ent_cd_idx >= 0) + if not ent_active.any(): + continue + + ent_indices = np.arange(n_ent) + flat = 
ent_cd_idx * n_ent + ent_indices + + np.add.at( + numerator, + flat[ent_active], + (ent_w * ent_takeup)[ent_active], + ) + np.add.at( + denominator, + flat[ent_active], + ent_w[ent_active], + ) + + if (clone_idx + 1) % 100 == 0: + logger.info( + "Stacked takeup %s: clone %d/%d", + takeup_var, + clone_idx + 1, + n_clones, + ) + + valid = denominator > 0 + takeup_avg = np.ones(n_cds * n_ent, dtype=np.float32) + takeup_avg[valid] = (numerator[valid] / denominator[valid]).astype( + np.float32 ) - data_dict = impute_source_variables( - data=data_dict, - state_fips=state_fips, - time_period=time_period, - dataset_path=dataset_path, + result[takeup_var] = takeup_avg.reshape(n_cds, n_ent) + logger.info( + "Stacked takeup %s: shape %s, " "active cells %d, mean %.4f", + takeup_var, + result[takeup_var].shape, + valid.sum(), + takeup_avg[valid].mean() if valid.any() else 0, ) - puf_dataset = puf_dataset_path if not skip_qrf else None + return result - new_data = puf_clone_dataset( - data=data_dict, - state_fips=state_fips, - time_period=time_period, - puf_dataset=puf_dataset, - skip_qrf=skip_qrf, - dataset_path=dataset_path, - ) - output_path = str( - Path(dataset_path).parent / f"puf_cloned_{Path(dataset_path).stem}.h5" - ) +def compute_diagnostics( + weights: np.ndarray, + X_sparse, + targets_df, + target_names: list, +) -> "pd.DataFrame": + import pandas as pd - with h5py.File(output_path, "w") as f: - for var, time_dict in new_data.items(): - for tp, values in time_dict.items(): - f.create_dataset(f"{var}/{tp}", data=values) + estimates = X_sparse.dot(weights) + true_values = targets_df["value"].values + row_sums = np.array(X_sparse.sum(axis=1)).flatten() - del sim - logger.info("PUF-cloned dataset saved to %s", output_path) - return output_path + rel_errors = np.where( + np.abs(true_values) > 0, + (estimates - true_values) / np.abs(true_values), + 0.0, + ) + return pd.DataFrame( + { + "target": target_names, + "true_value": true_values, + "estimate": estimates, + 
"rel_error": rel_errors, + "abs_rel_error": np.abs(rel_errors), + "achievable": row_sums > 0, + } + ) def run_calibration( @@ -498,9 +1053,18 @@ def run_calibration( domain_variables: list = None, hierarchical_domains: list = None, skip_takeup_rerandomize: bool = False, - puf_dataset_path: str = None, - skip_puf: bool = False, - skip_source_impute: bool = False, + skip_source_impute: bool = True, + skip_county: bool = True, + target_config: dict = None, + build_only: bool = False, + package_path: str = None, + package_output_path: str = None, + beta: float = BETA, + lambda_l2: float = LAMBDA_L2, + learning_rate: float = LEARNING_RATE, + log_freq: int = None, + log_path: str = None, + workers: int = 1, ): """Run unified calibration pipeline. @@ -516,32 +1080,98 @@ def run_calibration( hierarchical_domains: Domains for hierarchical uprating + CD reconciliation. skip_takeup_rerandomize: Skip takeup step. - puf_dataset_path: Path to PUF h5 for QRF training. - skip_puf: Skip PUF clone step. skip_source_impute: Skip ACS/SIPP/SCF imputations. + target_config: Parsed target config dict. + build_only: If True, save package and skip fitting. + package_path: Load pre-built package (skip build). + package_output_path: Where to save calibration package. + beta: L0 gate temperature. + lambda_l2: L2 regularization strength. + learning_rate: Optimizer learning rate. + log_freq: Epochs between per-target CSV logs. + log_path: Path for per-target calibration log CSV. Returns: - (weights, targets_df, X_sparse, target_names) + (weights, targets_df, X_sparse, target_names, geography_info) + weights is None when build_only=True. + geography_info is a dict with cd_geoid and base_n_records. 
""" import time + t0 = time.time() + + # Early exit: load pre-built package + if package_path is not None: + package = load_calibration_package(package_path) + targets_df = package["targets_df"] + X_sparse = package["X_sparse"] + target_names = package["target_names"] + + if target_config: + targets_df, X_sparse, target_names = apply_target_config( + targets_df, X_sparse, target_names, target_config + ) + + initial_weights = package.get("initial_weights") + targets = targets_df["value"].values + row_sums = np.array(X_sparse.sum(axis=1)).flatten() + pkg_achievable = row_sums > 0 + weights = fit_l0_weights( + X_sparse=X_sparse, + targets=targets, + lambda_l0=lambda_l0, + epochs=epochs, + device=device, + beta=beta, + lambda_l2=lambda_l2, + learning_rate=learning_rate, + log_freq=log_freq, + log_path=log_path, + target_names=target_names, + initial_weights=initial_weights, + targets_df=targets_df, + achievable=pkg_achievable, + ) + logger.info( + "Total pipeline (from package): %.1f min", + (time.time() - t0) / 60, + ) + geography_info = { + "cd_geoid": package.get("cd_geoid"), + "block_geoid": package.get("block_geoid"), + "base_n_records": package["metadata"].get("base_n_records"), + } + return ( + weights, + targets_df, + X_sparse, + target_names, + geography_info, + ) + from policyengine_us import Microsimulation from policyengine_us_data.calibration.clone_and_assign import ( assign_random_geography, - double_geography_for_puf, ) from policyengine_us_data.calibration.unified_matrix_builder import ( UnifiedMatrixBuilder, ) - t0 = time.time() - - # Step 1: Load dataset + # Step 1: Load dataset and detect time period logger.info("Loading dataset from %s", dataset_path) sim = Microsimulation(dataset=dataset_path) n_records = len(sim.calculate("household_id", map_to="household").values) - logger.info("Loaded %d households", n_records) + raw_keys = sim.dataset.load_dataset()["household_id"] + if isinstance(raw_keys, dict): + time_period = int(next(iter(raw_keys))) + 
else: + time_period = 2024 + logger.info( + "Loaded %d households, time_period=%d", + n_records, + time_period, + ) # Step 2: Clone and assign geography logger.info( @@ -556,46 +1186,28 @@ def run_calibration( seed=seed, ) - # Step 3: Source impute + PUF clone (if requested) + # Step 3: Source imputation (if requested) dataset_for_matrix = dataset_path - if not skip_puf and puf_dataset_path is not None: - base_states = geography.state_fips[:n_records] - - puf_cloned_path = _build_puf_cloned_dataset( - dataset_path=dataset_path, - puf_dataset_path=puf_dataset_path, - state_fips=base_states, - time_period=2024, - skip_qrf=False, - skip_source_impute=skip_source_impute, - ) - - geography = double_geography_for_puf(geography) - dataset_for_matrix = puf_cloned_path - n_records = n_records * 2 - - # Reload sim from PUF-cloned dataset - del sim - sim = Microsimulation(dataset=puf_cloned_path) - - logger.info( - "After PUF clone: %d records x %d clones = %d", - n_records, - n_clones, - n_records * n_clones, - ) - elif not skip_source_impute: + if not skip_source_impute: # Run source imputations without PUF cloning import h5py base_states = geography.state_fips[:n_records] - source_sim = Microsimulation(dataset=dataset_path) - raw_data = source_sim.dataset.load_dataset() + raw_data = sim.dataset.load_dataset() data_dict = {} for var in raw_data: - data_dict[var] = {2024: raw_data[var][...]} - del source_sim + val = raw_data[var] + if isinstance(val, dict): + # h5py returns string keys ("2024"); normalize + # to int so source_impute lookups work. + # Some keys like "ETERNITY" are non-numeric — keep + # them as strings. 
+ data_dict[var] = { + int(k) if k.isdigit() else k: v for k, v in val.items() + } + else: + data_dict[var] = {time_period: val[...]} from policyengine_us_data.calibration.source_impute import ( impute_source_variables, @@ -604,7 +1216,7 @@ def run_calibration( data_dict = impute_source_variables( data=data_dict, state_fips=base_states, - time_period=2024, + time_period=time_period, dataset_path=dataset_path, ) @@ -625,17 +1237,7 @@ def run_calibration( source_path, ) - # Step 4: Build sim_modifier for takeup rerandomization sim_modifier = None - if not skip_takeup_rerandomize: - time_period = 2024 - - def sim_modifier(s, clone_idx): - col_start = clone_idx * n_records - col_end = col_start + n_records - blocks = geography.block_geoid[col_start:col_end] - states = geography.state_fips[col_start:col_end] - rerandomize_takeup(s, blocks, states, time_period) # Step 5: Build target filter target_filter = {} @@ -643,11 +1245,12 @@ def sim_modifier(s, clone_idx): target_filter["domain_variables"] = domain_variables # Step 6: Build sparse calibration matrix + do_rerandomize = not skip_takeup_rerandomize t_matrix = time.time() db_uri = f"sqlite:///{db_path}" builder = UnifiedMatrixBuilder( db_uri=db_uri, - time_period=2024, + time_period=time_period, dataset_path=dataset_for_matrix, ) targets_df, X_sparse, target_names = builder.build_matrix( @@ -656,6 +1259,9 @@ def sim_modifier(s, clone_idx): target_filter=target_filter, hierarchical_domains=hierarchical_domains, sim_modifier=sim_modifier, + rerandomize_takeup=do_rerandomize, + county_level=not skip_county, + workers=workers, ) builder.print_uprating_summary(targets_df) @@ -669,6 +1275,76 @@ def sim_modifier(s, clone_idx): X_sparse.nnz, ) + # Step 6b: Save FULL (unfiltered) calibration package. + # Target config is applied at fit time, so the package can be + # reused with different configs without rebuilding. 
+ import datetime + + metadata = { + "dataset_path": dataset_path, + "db_path": db_path, + "n_clones": n_clones, + "n_records": X_sparse.shape[1], + "base_n_records": n_records, + "seed": seed, + "created_at": datetime.datetime.now().isoformat(), + } + metadata.update(get_git_provenance()) + from policyengine_us_data.utils.manifest import compute_file_checksum + + metadata["dataset_sha256"] = compute_file_checksum(Path(dataset_path)) + metadata["db_sha256"] = compute_file_checksum(Path(db_path)) + + if package_output_path: + full_initial_weights = compute_initial_weights(X_sparse, targets_df) + save_calibration_package( + package_output_path, + X_sparse, + targets_df, + target_names, + metadata, + initial_weights=full_initial_weights, + cd_geoid=geography.cd_geoid, + block_geoid=geography.block_geoid, + ) + + # Step 6c: Apply target config filtering (for fit or validation) + if target_config: + targets_df, X_sparse, target_names = apply_target_config( + targets_df, X_sparse, target_names, target_config + ) + + initial_weights = compute_initial_weights(X_sparse, targets_df) + + if build_only: + from policyengine_us_data.calibration.validate_package import ( + validate_package, + format_report, + ) + + package = { + "X_sparse": X_sparse, + "targets_df": targets_df, + "target_names": target_names, + "metadata": metadata, + "initial_weights": initial_weights, + } + result = validate_package(package) + print(format_report(result)) + geography_info = { + "cd_geoid": geography.cd_geoid, + "block_geoid": geography.block_geoid, + "base_n_records": n_records, + "dataset_for_matrix": dataset_for_matrix, + } + return ( + None, + targets_df, + X_sparse, + target_names, + geography_info, + ) + # Step 7: L0 calibration targets = targets_df["value"].values @@ -686,13 +1362,38 @@ def sim_modifier(s, clone_idx): lambda_l0=lambda_l0, epochs=epochs, device=device, + beta=beta, + lambda_l2=lambda_l2, + learning_rate=learning_rate, + log_freq=log_freq, + log_path=log_path, + 
target_names=target_names, + initial_weights=initial_weights, + targets_df=targets_df, + achievable=achievable, ) logger.info( "Total pipeline: %.1f min", (time.time() - t0) / 60, ) - return weights, targets_df, X_sparse, target_names + geography_info = { + "cd_geoid": geography.cd_geoid, + "block_geoid": geography.block_geoid, + "base_n_records": n_records, + "dataset_for_matrix": dataset_for_matrix, + "entity_hh_idx_map": getattr(builder, "entity_hh_idx_map", None), + "household_ids": getattr(builder, "household_ids", None), + "precomputed_rates": getattr(builder, "precomputed_rates", None), + "affected_target_info": getattr(builder, "affected_target_info", None), + } + return ( + weights, + targets_df, + X_sparse, + target_names, + geography_info, + ) def main(argv=None): @@ -710,17 +1411,18 @@ def main(argv=None): pass args = parse_args(argv) + logger.info("CLI args: %s", vars(args)) from policyengine_us_data.storage import STORAGE_FOLDER dataset_path = args.dataset or str( - STORAGE_FOLDER / "stratified_extended_cps_2024.h5" + STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5" ) db_path = args.db_path or str( STORAGE_FOLDER / "calibration" / "policy_data.db" ) output_path = args.output or str( - STORAGE_FOLDER / "calibration" / "unified_weights.npy" + STORAGE_FOLDER / "calibration" / "calibration_weights.npy" ) if args.lambda_l0 is not None: @@ -744,7 +1446,27 @@ def main(argv=None): t_start = time.time() - weights, targets_df, X_sparse, target_names = run_calibration( + target_config = None + if args.target_config: + target_config = load_target_config(args.target_config) + + package_output_path = args.package_output + if args.build_only and not package_output_path: + package_output_path = str( + STORAGE_FOLDER / "calibration" / "calibration_package.pkl" + ) + + output_dir = Path(output_path).parent + cal_log_path = None + if args.log_freq is not None: + cal_log_path = str(output_dir / "calibration_log.csv") + ( + weights, + targets_df, + 
X_sparse, + target_names, + geography_info, + ) = run_calibration( dataset_path=dataset_path, db_path=db_path, n_clones=args.n_clones, @@ -755,17 +1477,29 @@ def main(argv=None): domain_variables=domain_variables, hierarchical_domains=hierarchical_domains, skip_takeup_rerandomize=args.skip_takeup_rerandomize, - puf_dataset_path=args.puf_dataset, - skip_puf=args.skip_puf, skip_source_impute=args.skip_source_impute, + skip_county=not args.county_level, + target_config=target_config, + build_only=args.build_only, + package_path=args.package_path, + package_output_path=package_output_path, + beta=args.beta, + lambda_l2=args.lambda_l2, + learning_rate=args.learning_rate, + log_freq=args.log_freq, + log_path=cal_log_path, + workers=args.workers, ) - # Save weights - np.save(output_path, weights) - logger.info("Weights saved to %s", output_path) - print(f"OUTPUT_PATH:{output_path}") + source_imputed = geography_info.get("dataset_for_matrix") + if source_imputed and source_imputed != dataset_path: + print(f"SOURCE_IMPUTED_PATH:{source_imputed}") + + if weights is None: + logger.info("Build-only complete. 
Package saved.") + return - # Save diagnostics + # Diagnostics (raw weights match X_sparse column layout) output_dir = Path(output_path).parent diag_df = compute_diagnostics(weights, X_sparse, targets_df, target_names) diag_path = output_dir / "unified_diagnostics.csv" @@ -784,33 +1518,220 @@ def main(argv=None): (err_pct < 25).mean() * 100, ) + # Convert to stacked format for stacked_dataset_builder + cd_geoid = geography_info.get("cd_geoid") + base_n_records = geography_info.get("base_n_records") + + if cd_geoid is not None and base_n_records is not None: + from policyengine_us_data.calibration.calibration_utils import ( + save_geo_labels, + ) + + cds_ordered = sorted(set(cd_geoid)) + save_geo_labels(cds_ordered, output_dir / "geo_labels.json") + print(f"GEO_LABELS_PATH:{output_dir / 'geo_labels.json'}") + logger.info( + "Saved %d geo labels to %s", + len(cds_ordered), + output_dir / "geo_labels.json", + ) + stacked_weights = convert_weights_to_stacked_format( + weights=weights, + cd_geoid=cd_geoid, + base_n_records=base_n_records, + cds_ordered=cds_ordered, + ) + else: + logger.warning("No geography info available; saving raw weights") + stacked_weights = weights + + # Save stacked blocks alongside weights + block_geoid = geography_info.get("block_geoid") + if ( + block_geoid is not None + and cd_geoid is not None + and base_n_records is not None + ): + blocks_stacked = convert_blocks_to_stacked_format( + block_geoid=block_geoid, + cd_geoid=cd_geoid, + base_n_records=base_n_records, + cds_ordered=cds_ordered, + ) + blocks_path = output_dir / "stacked_blocks.npy" + np.save(str(blocks_path), blocks_stacked) + logger.info("Stacked blocks saved to %s", blocks_path) + print(f"BLOCKS_PATH:{blocks_path}") + + # Save stacked takeup (weight-averaged per CD × entity) + entity_hh_idx_map = geography_info.get("entity_hh_idx_map") + affected_info = geography_info.get("affected_target_info") + hh_ids_for_takeup = geography_info.get("household_ids") + rates_for_takeup = 
geography_info.get("precomputed_rates") + + # Rebuild entity mappings from sim when loading from package + if ( + block_geoid is not None + and cd_geoid is not None + and base_n_records is not None + and entity_hh_idx_map is None + ): + logger.info("Rebuilding entity mappings for stacked takeup...") + from policyengine_us_data.utils.takeup import ( + TAKEUP_AFFECTED_TARGETS, + ) + from policyengine_us_data.parameters import ( + load_take_up_rate, + ) + + ds = source_imputed or dataset_path + sim = Microsimulation(dataset=ds) + tp = int(sim.default_calculation_period) + + hh_ids_for_takeup = sim.calculate( + "household_id", map_to="household" + ).values + person_hh_ids = sim.calculate("household_id", map_to="person").values + hh_id_to_idx = {int(h): i for i, h in enumerate(hh_ids_for_takeup)} + person_hh_idx = np.array([hh_id_to_idx[int(h)] for h in person_hh_ids]) + + import pandas as pd + + entity_rel_df = pd.DataFrame( + { + "household_id": sim.calculate( + "household_id", map_to="person" + ).values, + "tax_unit_id": sim.calculate( + "tax_unit_id", map_to="person" + ).values, + "spm_unit_id": sim.calculate( + "spm_unit_id", map_to="person" + ).values, + } + ) + spm_to_hh = ( + entity_rel_df.groupby("spm_unit_id")["household_id"] + .first() + .to_dict() + ) + spm_ids = sim.calculate("spm_unit_id", map_to="spm_unit").values + spm_hh_idx = np.array( + [hh_id_to_idx[int(spm_to_hh[int(s)])] for s in spm_ids] + ) + + tu_to_hh = ( + entity_rel_df.groupby("tax_unit_id")["household_id"] + .first() + .to_dict() + ) + tu_ids = sim.calculate("tax_unit_id", map_to="tax_unit").values + tu_hh_idx = np.array( + [hh_id_to_idx[int(tu_to_hh[int(t)])] for t in tu_ids] + ) + + entity_hh_idx_map = { + "spm_unit": spm_hh_idx, + "tax_unit": tu_hh_idx, + "person": person_hh_idx, + } + + unique_vars = set(targets_df["variable"].values) + affected_info = {} + for tvar in unique_vars: + for key, info in TAKEUP_AFFECTED_TARGETS.items(): + if tvar == key: + affected_info[tvar] = info + break 
+ + rates_for_takeup = {} + for tvar, info in affected_info.items(): + rk = info["rate_key"] + if rk not in rates_for_takeup: + rates_for_takeup[rk] = load_take_up_rate(rk, tp) + + del sim + logger.info( + "Rebuilt entity mappings: %d affected vars", + len(affected_info), + ) + + if ( + block_geoid is not None + and cd_geoid is not None + and base_n_records is not None + and entity_hh_idx_map is not None + and affected_info + ): + import time as _time + + t_takeup = _time.time() + stacked_tu = compute_stacked_takeup( + weights=weights, + cd_geoid=cd_geoid, + block_geoid=block_geoid, + base_n_records=base_n_records, + cds_ordered=cds_ordered, + entity_hh_idx_map=entity_hh_idx_map, + household_ids=hh_ids_for_takeup, + precomputed_rates=rates_for_takeup, + affected_target_info=affected_info, + ) + takeup_path = output_dir / "stacked_takeup.npz" + np.savez_compressed(str(takeup_path), **stacked_tu) + logger.info( + "Stacked takeup saved to %s (%.1f min)", + takeup_path, + (_time.time() - t_takeup) / 60, + ) + print(f"TAKEUP_PATH:{takeup_path}") + + # Save weights + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + np.save(output_path, stacked_weights) + logger.info("Weights saved to %s", output_path) + print(f"OUTPUT_PATH:{output_path}") + # Save run config t_end = time.time() + weight_format = ( + "stacked" + if cd_geoid is not None and base_n_records is not None + else "raw" + ) run_config = { "dataset": dataset_path, "db_path": db_path, - "puf_dataset": args.puf_dataset, - "skip_puf": args.skip_puf, "skip_source_impute": args.skip_source_impute, "n_clones": args.n_clones, "lambda_l0": lambda_l0, + "beta": args.beta, + "lambda_l2": args.lambda_l2, + "learning_rate": args.learning_rate, "epochs": args.epochs, "device": args.device, "seed": args.seed, "domain_variables": domain_variables, "hierarchical_domains": hierarchical_domains, + "target_config": args.target_config, "n_targets": len(targets_df), "n_records": X_sparse.shape[1], - "weight_sum": 
float(weights.sum()), - "weight_nonzero": int((weights > 0).sum()), + "geo_labels_file": "geo_labels.json", + "weight_format": weight_format, + "weight_sum": float(stacked_weights.sum()), + "weight_nonzero": int((stacked_weights > 0).sum()), "mean_error_pct": float(err_pct.mean()), "elapsed_seconds": round(t_end - t_start, 1), } + run_config.update(get_git_provenance()) config_path = output_dir / "unified_run_config.json" with open(config_path, "w") as f: json.dump(run_config, f, indent=2) logger.info("Config saved to %s", config_path) + print(f"CONFIG_PATH:{config_path}") print(f"LOG_PATH:{diag_path}") + if cal_log_path: + print(f"CAL_LOG_PATH:{cal_log_path}") if __name__ == "__main__": diff --git a/policyengine_us_data/calibration/unified_matrix_builder.py b/policyengine_us_data/calibration/unified_matrix_builder.py index ac31c34e1..1fa8ea73f 100644 --- a/policyengine_us_data/calibration/unified_matrix_builder.py +++ b/policyengine_us_data/calibration/unified_matrix_builder.py @@ -21,11 +21,14 @@ from policyengine_us_data.storage import STORAGE_FOLDER from policyengine_us_data.utils.census import STATE_NAME_TO_FIPS -from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( +from policyengine_us_data.calibration.calibration_utils import ( get_calculated_variables, apply_op, get_geo_level, ) +from policyengine_us_data.calibration.block_assignment import ( + get_county_enum_index_from_fips, +) logger = logging.getLogger(__name__) @@ -35,6 +38,680 @@ "congressional_district_geoid", } +COUNTY_DEPENDENT_VARS = { + "aca_ptc", +} + + +def _compute_single_state( + dataset_path: str, + time_period: int, + state: int, + n_hh: int, + target_vars: list, + constraint_vars: list, + rerandomize_takeup: bool, + affected_targets: dict, +): + """Compute household/person/entity values for one state. + + Top-level function (not a method) so it is picklable for + ``ProcessPoolExecutor``. + + Args: + dataset_path: Path to the base CPS h5 file. 
+ time_period: Tax year for simulation. + state: State FIPS code. + n_hh: Number of household records. + target_vars: Target variable names (list for determinism). + constraint_vars: Constraint variable names (list). + rerandomize_takeup: Force takeup=True if True. + affected_targets: Takeup-affected target info dict. + + Returns: + (state_fips, {"hh": {...}, "person": {...}, "entity": {...}}) + """ + from policyengine_us import Microsimulation + from policyengine_us_data.utils.takeup import SIMPLE_TAKEUP_VARS + from policyengine_us_data.calibration.calibration_utils import ( + get_calculated_variables, + ) + + state_sim = Microsimulation(dataset=dataset_path) + + state_sim.set_input( + "state_fips", + time_period, + np.full(n_hh, state, dtype=np.int32), + ) + for var in get_calculated_variables(state_sim): + state_sim.delete_arrays(var) + + hh = {} + for var in target_vars: + if var.endswith("_count"): + continue + try: + hh[var] = state_sim.calculate( + var, + time_period, + map_to="household", + ).values.astype(np.float32) + except Exception as exc: + logger.warning( + "Cannot calculate '%s' for state %d: %s", + var, + state, + exc, + ) + + person = {} + for var in constraint_vars: + try: + person[var] = state_sim.calculate( + var, + time_period, + map_to="person", + ).values.astype(np.float32) + except Exception as exc: + logger.warning( + "Cannot calculate constraint '%s' " "for state %d: %s", + var, + state, + exc, + ) + + if rerandomize_takeup: + for spec in SIMPLE_TAKEUP_VARS: + entity = spec["entity"] + n_ent = len( + state_sim.calculate(f"{entity}_id", map_to=entity).values + ) + state_sim.set_input( + spec["variable"], + time_period, + np.ones(n_ent, dtype=bool), + ) + for var in get_calculated_variables(state_sim): + state_sim.delete_arrays(var) + + entity_vals = {} + if rerandomize_takeup: + for tvar, info in affected_targets.items(): + entity_level = info["entity"] + try: + entity_vals[tvar] = state_sim.calculate( + tvar, + time_period, + 
map_to=entity_level, + ).values.astype(np.float32) + except Exception as exc: + logger.warning( + "Cannot calculate entity-level " + "'%s' (map_to=%s) for state %d: %s", + tvar, + entity_level, + state, + exc, + ) + + return (state, {"hh": hh, "person": person, "entity": entity_vals}) + + +def _compute_single_state_group_counties( + dataset_path: str, + time_period: int, + state_fips: int, + counties: list, + n_hh: int, + county_dep_targets: list, + rerandomize_takeup: bool, + affected_targets: dict, +): + """Compute county-dependent values for all counties in one state. + + Top-level function (not a method) so it is picklable for + ``ProcessPoolExecutor``. Creates one ``Microsimulation`` per + state and reuses it across counties within that state. + + Args: + dataset_path: Path to the base CPS h5 file. + time_period: Tax year for simulation. + state_fips: State FIPS code for this group. + counties: List of county FIPS strings in this state. + n_hh: Number of household records. + county_dep_targets: County-dependent target var names. + rerandomize_takeup: Force takeup=True if True. + affected_targets: Takeup-affected target info dict. 
+ + Returns: + list of (county_fips_str, {"hh": {...}, "entity": {...}}) + """ + from policyengine_us import Microsimulation + from policyengine_us_data.utils.takeup import SIMPLE_TAKEUP_VARS + from policyengine_us_data.calibration.calibration_utils import ( + get_calculated_variables, + ) + from policyengine_us_data.calibration.block_assignment import ( + get_county_enum_index_from_fips, + ) + + state_sim = Microsimulation(dataset=dataset_path) + + state_sim.set_input( + "state_fips", + time_period, + np.full(n_hh, state_fips, dtype=np.int32), + ) + + original_takeup = {} + if rerandomize_takeup: + for spec in SIMPLE_TAKEUP_VARS: + entity = spec["entity"] + original_takeup[spec["variable"]] = ( + entity, + state_sim.calculate( + spec["variable"], + time_period, + map_to=entity, + ).values.copy(), + ) + + results = [] + for county_fips in counties: + county_idx = get_county_enum_index_from_fips(county_fips) + state_sim.set_input( + "county", + time_period, + np.full(n_hh, county_idx, dtype=np.int32), + ) + if rerandomize_takeup: + for vname, (ent, orig) in original_takeup.items(): + state_sim.set_input(vname, time_period, orig) + for var in get_calculated_variables(state_sim): + if var != "county": + state_sim.delete_arrays(var) + + hh = {} + for var in county_dep_targets: + if var.endswith("_count"): + continue + try: + hh[var] = state_sim.calculate( + var, + time_period, + map_to="household", + ).values.astype(np.float32) + except Exception as exc: + logger.warning( + "Cannot calculate '%s' for " "county %s: %s", + var, + county_fips, + exc, + ) + + if rerandomize_takeup: + for spec in SIMPLE_TAKEUP_VARS: + entity = spec["entity"] + n_ent = len( + state_sim.calculate(f"{entity}_id", map_to=entity).values + ) + state_sim.set_input( + spec["variable"], + time_period, + np.ones(n_ent, dtype=bool), + ) + for var in get_calculated_variables(state_sim): + if var != "county": + state_sim.delete_arrays(var) + + entity_vals = {} + if rerandomize_takeup: + for tvar, info 
in affected_targets.items(): + entity_level = info["entity"] + try: + entity_vals[tvar] = state_sim.calculate( + tvar, + time_period, + map_to=entity_level, + ).values.astype(np.float32) + except Exception as exc: + logger.warning( + "Cannot calculate entity-level " + "'%s' for county %s: %s", + tvar, + county_fips, + exc, + ) + + results.append((county_fips, {"hh": hh, "entity": entity_vals})) + + return results + + +# --------------------------------------------------------------- +# Clone-loop parallelisation helpers (module-level for pickling) +# --------------------------------------------------------------- + +_CLONE_SHARED: dict = {} + + +def _init_clone_worker(shared_data: dict) -> None: + """Initialise worker process with shared read-only data. + + Called once per worker at ``ProcessPoolExecutor`` startup so the + ~50-200 MB payload is pickled *per worker* (not per clone). + """ + _CLONE_SHARED.update(shared_data) + + +def _assemble_clone_values_standalone( + state_values: dict, + clone_states: np.ndarray, + person_hh_indices: np.ndarray, + target_vars: set, + constraint_vars: set, + county_values: dict = None, + clone_counties: np.ndarray = None, + county_dependent_vars: set = None, +) -> tuple: + """Standalone clone-value assembly (no ``self``). + + Identical logic to + ``UnifiedMatrixBuilder._assemble_clone_values`` but usable + from a worker process. 
+ """ + n_records = len(clone_states) + n_persons = len(person_hh_indices) + person_states = clone_states[person_hh_indices] + unique_clone_states = np.unique(clone_states) + cdv = county_dependent_vars or set() + + state_masks = {int(s): clone_states == s for s in unique_clone_states} + unique_person_states = np.unique(person_states) + person_state_masks = { + int(s): person_states == s for s in unique_person_states + } + county_masks = {} + unique_counties = None + if clone_counties is not None and county_values: + unique_counties = np.unique(clone_counties) + county_masks = {c: clone_counties == c for c in unique_counties} + + hh_vars: dict = {} + for var in target_vars: + if var.endswith("_count"): + continue + if var in cdv and county_values and clone_counties is not None: + first_county = unique_counties[0] + if var not in county_values.get(first_county, {}).get("hh", {}): + continue + arr = np.empty(n_records, dtype=np.float32) + for county in unique_counties: + mask = county_masks[county] + county_hh = county_values.get(county, {}).get("hh", {}) + if var in county_hh: + arr[mask] = county_hh[var][mask] + else: + st = int(county[:2]) + arr[mask] = state_values[st]["hh"][var][mask] + hh_vars[var] = arr + else: + if var not in state_values[unique_clone_states[0]]["hh"]: + continue + arr = np.empty(n_records, dtype=np.float32) + for state in unique_clone_states: + mask = state_masks[int(state)] + arr[mask] = state_values[int(state)]["hh"][var][mask] + hh_vars[var] = arr + + person_vars: dict = {} + for var in constraint_vars: + if var not in state_values[unique_clone_states[0]]["person"]: + continue + arr = np.empty(n_persons, dtype=np.float32) + for state in unique_person_states: + mask = person_state_masks[int(state)] + arr[mask] = state_values[int(state)]["person"][var][mask] + person_vars[var] = arr + + return hh_vars, person_vars + + +def _evaluate_constraints_standalone( + constraints, + person_vars: dict, + entity_rel: pd.DataFrame, + household_ids: 
np.ndarray, + n_households: int, +) -> np.ndarray: + """Standalone constraint evaluation (no class instance). + + Evaluates person-level constraints and aggregates to + household level via .any(). + """ + if not constraints: + return np.ones(n_households, dtype=bool) + + n_persons = len(entity_rel) + person_mask = np.ones(n_persons, dtype=bool) + + for c in constraints: + var = c["variable"] + if var not in person_vars: + logger.warning( + "Constraint var '%s' not in " "precomputed person_vars", + var, + ) + return np.zeros(n_households, dtype=bool) + vals = person_vars[var] + person_mask &= apply_op(vals, c["operation"], c["value"]) + + df = entity_rel.copy() + df["satisfies"] = person_mask + hh_mask = df.groupby("household_id")["satisfies"].any() + return np.array([hh_mask.get(hid, False) for hid in household_ids]) + + +def _calculate_target_values_standalone( + target_variable: str, + non_geo_constraints: list, + n_households: int, + hh_vars: dict, + person_vars: dict, + entity_rel: pd.DataFrame, + household_ids: np.ndarray, + variable_entity_map: dict, +) -> np.ndarray: + """Standalone target-value calculation (no class instance). + + Uses ``variable_entity_map`` dict for entity resolution + (picklable, unlike ``tax_benefit_system``). 
+ """ + is_count = target_variable.endswith("_count") + + if not is_count: + mask = _evaluate_constraints_standalone( + non_geo_constraints, + person_vars, + entity_rel, + household_ids, + n_households, + ) + vals = hh_vars.get(target_variable) + if vals is None: + return np.zeros(n_households, dtype=np.float32) + return (vals * mask).astype(np.float32) + + # Count target: entity-aware counting + n_persons = len(entity_rel) + person_mask = np.ones(n_persons, dtype=bool) + + for c in non_geo_constraints: + var = c["variable"] + if var not in person_vars: + return np.zeros(n_households, dtype=np.float32) + cv = person_vars[var] + person_mask &= apply_op(cv, c["operation"], c["value"]) + + target_entity = variable_entity_map.get(target_variable) + if target_entity is None: + return np.zeros(n_households, dtype=np.float32) + + if target_entity == "household": + if non_geo_constraints: + mask = _evaluate_constraints_standalone( + non_geo_constraints, + person_vars, + entity_rel, + household_ids, + n_households, + ) + return mask.astype(np.float32) + return np.ones(n_households, dtype=np.float32) + + if target_entity == "person": + er = entity_rel.copy() + er["satisfies"] = person_mask + filtered = er[er["satisfies"]] + counts = filtered.groupby("household_id")["person_id"].nunique() + else: + eid_col = f"{target_entity}_id" + er = entity_rel.copy() + er["satisfies"] = person_mask + entity_ok = er.groupby(eid_col)["satisfies"].any() + unique = er[["household_id", eid_col]].drop_duplicates() + unique["entity_ok"] = unique[eid_col].map(entity_ok) + filtered = unique[unique["entity_ok"]] + counts = filtered.groupby("household_id")[eid_col].nunique() + + return np.array( + [counts.get(hid, 0) for hid in household_ids], + dtype=np.float32, + ) + + +def _process_single_clone( + clone_idx: int, + col_start: int, + col_end: int, + cache_path: str, +) -> tuple: + """Process one clone in a worker process. 
+ + Reads shared read-only data from ``_CLONE_SHARED`` + (populated by ``_init_clone_worker``). Writes COO + entries as a compressed ``.npz`` file to *cache_path*. + + Args: + clone_idx: Zero-based clone index. + col_start: First column index for this clone. + col_end: One-past-last column index. + cache_path: File path for output ``.npz``. + + Returns: + (clone_idx, n_nonzero) tuple. + """ + sd = _CLONE_SHARED + + # Unpack shared data + geo_states = sd["geography_state_fips"] + geo_counties = sd["geography_county_fips"] + geo_blocks = sd["geography_block_geoid"] + state_values = sd["state_values"] + county_values = sd["county_values"] + person_hh_indices = sd["person_hh_indices"] + unique_variables = sd["unique_variables"] + unique_constraint_vars = sd["unique_constraint_vars"] + county_dep_targets = sd["county_dep_targets"] + target_variables = sd["target_variables"] + target_geo_info = sd["target_geo_info"] + non_geo_constraints_list = sd["non_geo_constraints_list"] + n_records = sd["n_records"] + n_total = sd["n_total"] + n_targets = sd["n_targets"] + state_to_cols = sd["state_to_cols"] + cd_to_cols = sd["cd_to_cols"] + entity_rel = sd["entity_rel"] + household_ids = sd["household_ids"] + variable_entity_map = sd["variable_entity_map"] + do_takeup = sd["rerandomize_takeup"] + affected_target_info = sd["affected_target_info"] + entity_hh_idx_map = sd.get("entity_hh_idx_map", {}) + entity_to_person_idx = sd.get("entity_to_person_idx", {}) + precomputed_rates = sd.get("precomputed_rates", {}) + + # Slice geography for this clone + clone_states = geo_states[col_start:col_end] + clone_counties = geo_counties[col_start:col_end] + + # Assemble hh/person values from precomputed state/county + hh_vars, person_vars = _assemble_clone_values_standalone( + state_values, + clone_states, + person_hh_indices, + unique_variables, + unique_constraint_vars, + county_values=county_values, + clone_counties=clone_counties, + county_dependent_vars=county_dep_targets, + ) + + # Takeup 
re-randomisation + if do_takeup and affected_target_info: + from policyengine_us_data.utils.takeup import ( + compute_block_takeup_for_entities, + ) + + clone_blocks = geo_blocks[col_start:col_end] + + for tvar, info in affected_target_info.items(): + if tvar.endswith("_count"): + continue + entity_level = info["entity"] + takeup_var = info["takeup_var"] + ent_hh = entity_hh_idx_map[entity_level] + n_ent = len(ent_hh) + ent_states = clone_states[ent_hh] + + ent_eligible = np.zeros(n_ent, dtype=np.float32) + if tvar in county_dep_targets and county_values: + ent_counties = clone_counties[ent_hh] + for cfips in np.unique(ent_counties): + m = ent_counties == cfips + cv = county_values.get(cfips, {}).get("entity", {}) + if tvar in cv: + ent_eligible[m] = cv[tvar][m] + else: + st = int(cfips[:2]) + sv = state_values[st]["entity"] + if tvar in sv: + ent_eligible[m] = sv[tvar][m] + else: + for st in np.unique(ent_states): + m = ent_states == st + sv = state_values[int(st)]["entity"] + if tvar in sv: + ent_eligible[m] = sv[tvar][m] + + ent_blocks = clone_blocks[ent_hh] + ent_hh_ids = household_ids[ent_hh] + + ent_takeup = compute_block_takeup_for_entities( + takeup_var, + precomputed_rates[info["rate_key"]], + ent_blocks, + ent_hh_ids, + ) + + ent_values = (ent_eligible * ent_takeup).astype(np.float32) + + hh_result = np.zeros(n_records, dtype=np.float32) + np.add.at(hh_result, ent_hh, ent_values) + hh_vars[tvar] = hh_result + + if tvar in person_vars: + pidx = entity_to_person_idx[entity_level] + person_vars[tvar] = ent_values[pidx] + + # Build COO entries for every target row + mask_cache: dict = {} + count_cache: dict = {} + rows_list: list = [] + cols_list: list = [] + vals_list: list = [] + + for row_idx in range(n_targets): + variable = target_variables[row_idx] + geo_level, geo_id = target_geo_info[row_idx] + non_geo = non_geo_constraints_list[row_idx] + + if geo_level == "district": + all_geo_cols = cd_to_cols.get( + str(geo_id), + np.array([], dtype=np.int64), + ) 
+ elif geo_level == "state": + all_geo_cols = state_to_cols.get( + int(geo_id), + np.array([], dtype=np.int64), + ) + else: + all_geo_cols = np.arange(n_total) + + clone_cols = all_geo_cols[ + (all_geo_cols >= col_start) & (all_geo_cols < col_end) + ] + if len(clone_cols) == 0: + continue + + rec_indices = clone_cols - col_start + + constraint_key = tuple( + sorted( + ( + c["variable"], + c["operation"], + c["value"], + ) + for c in non_geo + ) + ) + + if variable.endswith("_count"): + vkey = (variable, constraint_key) + if vkey not in count_cache: + count_cache[vkey] = _calculate_target_values_standalone( + variable, + non_geo, + n_records, + hh_vars, + person_vars, + entity_rel, + household_ids, + variable_entity_map, + ) + values = count_cache[vkey] + else: + if variable not in hh_vars: + continue + if constraint_key not in mask_cache: + mask_cache[constraint_key] = _evaluate_constraints_standalone( + non_geo, + person_vars, + entity_rel, + household_ids, + n_records, + ) + mask = mask_cache[constraint_key] + values = hh_vars[variable] * mask + + vals = values[rec_indices] + nonzero = vals != 0 + if nonzero.any(): + rows_list.append( + np.full( + nonzero.sum(), + row_idx, + dtype=np.int32, + ) + ) + cols_list.append(clone_cols[nonzero].astype(np.int32)) + vals_list.append(vals[nonzero]) + + # Write COO + if rows_list: + cr = np.concatenate(rows_list) + cc = np.concatenate(cols_list) + cv = np.concatenate(vals_list) + else: + cr = np.array([], dtype=np.int32) + cc = np.array([], dtype=np.int32) + cv = np.array([], dtype=np.float32) + + np.savez_compressed(cache_path, rows=cr, cols=cc, vals=cv) + return clone_idx, len(cv) + class UnifiedMatrixBuilder: """Build sparse calibration matrix for cloned CPS records. 
@@ -88,48 +765,597 @@ def _build_entity_relationship(self, sim) -> pd.DataFrame: return self._entity_rel_cache # --------------------------------------------------------------- - # Constraint evaluation + # Per-state precomputation # --------------------------------------------------------------- - def _evaluate_constraints_entity_aware( + def _build_state_values( self, sim, - constraints: List[dict], - n_households: int, - ) -> np.ndarray: - """Evaluate constraints at person level, aggregate to - household level via .any().""" - if not constraints: - return np.ones(n_households, dtype=bool) + target_vars: set, + constraint_vars: set, + geography, + rerandomize_takeup: bool = True, + workers: int = 1, + ) -> dict: + """Precompute household/person/entity values per state. - entity_rel = self._build_entity_relationship(sim) - n_persons = len(entity_rel) - person_mask = np.ones(n_persons, dtype=bool) + Creates a fresh Microsimulation per state to prevent + cross-state cache pollution (stale intermediate values + from one state leaking into another's calculations). - for c in constraints: - try: - vals = sim.calculate( - c["variable"], + County-dependent variables (e.g. aca_ptc) are computed + here as a state-level fallback; county-level overrides + are applied later via ``_build_county_values``. + + Args: + sim: Microsimulation instance (unused; kept for API + compatibility). + target_vars: Set of target variable names. + constraint_vars: Set of constraint variable names. + geography: GeographyAssignment with state_fips. + rerandomize_takeup: If True, force takeup=True and + also store entity-level eligible amounts for + takeup-affected targets. + workers: Number of parallel worker processes. + When >1, uses ProcessPoolExecutor. 
+ + Returns: + {state_fips: { + 'hh': {var: array}, + 'person': {var: array}, + 'entity': {var: array} # only if rerandomize + }} + """ + from policyengine_us_data.utils.takeup import ( + TAKEUP_AFFECTED_TARGETS, + ) + + unique_states = sorted(set(int(s) for s in geography.state_fips)) + n_hh = geography.n_records + + logger.info( + "Per-state precomputation: %d states, " + "%d hh vars, %d constraint vars " + "(fresh sim per state, workers=%d)", + len(unique_states), + len([v for v in target_vars if not v.endswith("_count")]), + len(constraint_vars), + workers, + ) + + # Identify takeup-affected targets before the state loop + affected_targets = {} + if rerandomize_takeup: + for tvar in target_vars: + for key, info in TAKEUP_AFFECTED_TARGETS.items(): + if tvar == key or tvar.startswith(key): + affected_targets[tvar] = info + break + + # Convert sets to sorted lists for deterministic iteration + target_vars_list = sorted(target_vars) + constraint_vars_list = sorted(constraint_vars) + + state_values = {} + + if workers > 1: + from concurrent.futures import ( + ProcessPoolExecutor, + as_completed, + ) + + logger.info( + "Parallel state precomputation with %d workers", + workers, + ) + with ProcessPoolExecutor(max_workers=workers) as pool: + futures = { + pool.submit( + _compute_single_state, + self.dataset_path, + self.time_period, + st, + n_hh, + target_vars_list, + constraint_vars_list, + rerandomize_takeup, + affected_targets, + ): st + for st in unique_states + } + completed = 0 + for future in as_completed(futures): + st = futures[future] + try: + sf, vals = future.result() + state_values[sf] = vals + completed += 1 + if completed % 10 == 0 or completed == 1: + logger.info( + "State %d/%d complete", + completed, + len(unique_states), + ) + except Exception as exc: + for f in futures: + f.cancel() + raise RuntimeError( + f"State {st} failed: {exc}" + ) from exc + else: + from policyengine_us import Microsimulation + from policyengine_us_data.utils.takeup import ( 
+ SIMPLE_TAKEUP_VARS, + ) + + for i, state in enumerate(unique_states): + state_sim = Microsimulation(dataset=self.dataset_path) + + state_sim.set_input( + "state_fips", self.time_period, - map_to="person", - ).values - except Exception as exc: - logger.warning( - "Cannot evaluate constraint '%s': %s", - c["variable"], - exc, + np.full(n_hh, state, dtype=np.int32), ) - return np.zeros(n_households, dtype=bool) - person_mask &= apply_op(vals, c["operation"], c["value"]) + for var in get_calculated_variables(state_sim): + state_sim.delete_arrays(var) - df = entity_rel.copy() - df["satisfies"] = person_mask - hh_mask = df.groupby("household_id")["satisfies"].any() + hh = {} + for var in target_vars: + if var.endswith("_count"): + continue + try: + hh[var] = state_sim.calculate( + var, + self.time_period, + map_to="household", + ).values.astype(np.float32) + except Exception as exc: + logger.warning( + "Cannot calculate '%s' " "for state %d: %s", + var, + state, + exc, + ) - household_ids = sim.calculate( - "household_id", map_to="household" - ).values - return np.array([hh_mask.get(hid, False) for hid in household_ids]) + person = {} + for var in constraint_vars: + try: + person[var] = state_sim.calculate( + var, + self.time_period, + map_to="person", + ).values.astype(np.float32) + except Exception as exc: + logger.warning( + "Cannot calculate constraint " + "'%s' for state %d: %s", + var, + state, + exc, + ) + + if rerandomize_takeup: + for spec in SIMPLE_TAKEUP_VARS: + entity = spec["entity"] + n_ent = len( + state_sim.calculate( + f"{entity}_id", map_to=entity + ).values + ) + state_sim.set_input( + spec["variable"], + self.time_period, + np.ones(n_ent, dtype=bool), + ) + for var in get_calculated_variables(state_sim): + state_sim.delete_arrays(var) + + entity_vals = {} + if rerandomize_takeup: + for tvar, info in affected_targets.items(): + entity_level = info["entity"] + try: + entity_vals[tvar] = state_sim.calculate( + tvar, + self.time_period, + 
map_to=entity_level, + ).values.astype(np.float32) + except Exception as exc: + logger.warning( + "Cannot calculate entity-level " + "'%s' (map_to=%s) for " + "state %d: %s", + tvar, + entity_level, + state, + exc, + ) + + state_values[state] = { + "hh": hh, + "person": person, + "entity": entity_vals, + } + if (i + 1) % 10 == 0 or i == 0: + logger.info( + "State %d/%d complete", + i + 1, + len(unique_states), + ) + + logger.info( + "Per-state precomputation done: %d states", + len(state_values), + ) + return state_values + + def _build_county_values( + self, + sim, + county_dep_targets: set, + geography, + rerandomize_takeup: bool = True, + county_level: bool = True, + workers: int = 1, + ) -> dict: + """Precompute county-dependent variable values per county. + + Only iterates over COUNTY_DEPENDENT_VARS that actually + benefit from per-county computation. All other target + variables use state-level values from _build_state_values. + + Creates a fresh Microsimulation per state group to prevent + cross-state cache pollution. Counties within the same state + share a simulation since within-state recalculation is clean + (only cross-state switches cause pollution). + + When county_level=False, returns an empty dict immediately + (all values come from state-level precomputation). + + Args: + sim: Microsimulation instance (unused; kept for API + compatibility). + county_dep_targets: Subset of target vars that depend + on county (intersection of targets with + COUNTY_DEPENDENT_VARS). + geography: GeographyAssignment with county_fips. + rerandomize_takeup: If True, force takeup=True and + also store entity-level eligible amounts for + takeup-affected targets. + county_level: If True, iterate counties within each + state. If False, return empty dict (skip county + computation entirely). + workers: Number of parallel worker processes. + When >1, uses ProcessPoolExecutor. 
+ + Returns: + {county_fips_str: { + 'hh': {var: array}, + 'entity': {var: array} + }} + """ + if not county_level or not county_dep_targets: + if not county_level: + logger.info( + "County-level computation disabled " "(skip-county mode)" + ) + else: + logger.info( + "No county-dependent target vars; " + "skipping county precomputation" + ) + return {} + + from policyengine_us_data.utils.takeup import ( + TAKEUP_AFFECTED_TARGETS, + ) + + unique_counties = sorted(set(geography.county_fips)) + n_hh = geography.n_records + + state_to_counties = defaultdict(list) + for county in unique_counties: + state_to_counties[int(county[:2])].append(county) + + logger.info( + "Per-county precomputation: %d counties in %d " + "states, %d county-dependent vars " + "(fresh sim per state, workers=%d)", + len(unique_counties), + len(state_to_counties), + len(county_dep_targets), + workers, + ) + + affected_targets = {} + if rerandomize_takeup: + for tvar in county_dep_targets: + for key, info in TAKEUP_AFFECTED_TARGETS.items(): + if tvar == key or tvar.startswith(key): + affected_targets[tvar] = info + break + + # Convert to sorted list for deterministic iteration + county_dep_targets_list = sorted(county_dep_targets) + + county_values = {} + + if workers > 1: + from concurrent.futures import ( + ProcessPoolExecutor, + as_completed, + ) + + logger.info( + "Parallel county precomputation with " + "%d workers (%d state groups)", + workers, + len(state_to_counties), + ) + with ProcessPoolExecutor(max_workers=workers) as pool: + futures = { + pool.submit( + _compute_single_state_group_counties, + self.dataset_path, + self.time_period, + sf, + counties, + n_hh, + county_dep_targets_list, + rerandomize_takeup, + affected_targets, + ): sf + for sf, counties in sorted(state_to_counties.items()) + } + completed = 0 + county_count = 0 + for future in as_completed(futures): + sf = futures[future] + try: + results = future.result() + for cfips, vals in results: + county_values[cfips] = vals + 
county_count += 1 + completed += 1 + if county_count % 500 == 0 or completed == 1: + logger.info( + "County %d/%d complete " + "(%d/%d state groups)", + county_count, + len(unique_counties), + completed, + len(state_to_counties), + ) + except Exception as exc: + for f in futures: + f.cancel() + raise RuntimeError( + f"State group {sf} failed: " f"{exc}" + ) from exc + else: + from policyengine_us import Microsimulation + from policyengine_us_data.utils.takeup import ( + SIMPLE_TAKEUP_VARS, + ) + + county_count = 0 + for state_fips, counties in sorted(state_to_counties.items()): + state_sim = Microsimulation(dataset=self.dataset_path) + + state_sim.set_input( + "state_fips", + self.time_period, + np.full(n_hh, state_fips, dtype=np.int32), + ) + + original_takeup = {} + if rerandomize_takeup: + for spec in SIMPLE_TAKEUP_VARS: + entity = spec["entity"] + original_takeup[spec["variable"]] = ( + entity, + state_sim.calculate( + spec["variable"], + self.time_period, + map_to=entity, + ).values.copy(), + ) + + for county_fips in counties: + county_idx = get_county_enum_index_from_fips(county_fips) + state_sim.set_input( + "county", + self.time_period, + np.full( + n_hh, + county_idx, + dtype=np.int32, + ), + ) + if rerandomize_takeup: + for vname, ( + ent, + orig, + ) in original_takeup.items(): + state_sim.set_input( + vname, + self.time_period, + orig, + ) + for var in get_calculated_variables(state_sim): + if var != "county": + state_sim.delete_arrays(var) + + hh = {} + for var in county_dep_targets: + if var.endswith("_count"): + continue + try: + hh[var] = state_sim.calculate( + var, + self.time_period, + map_to="household", + ).values.astype(np.float32) + except Exception as exc: + logger.warning( + "Cannot calculate '%s' " "for county %s: %s", + var, + county_fips, + exc, + ) + + if rerandomize_takeup: + for spec in SIMPLE_TAKEUP_VARS: + entity = spec["entity"] + n_ent = len( + state_sim.calculate( + f"{entity}_id", + map_to=entity, + ).values + ) + 
state_sim.set_input( + spec["variable"], + self.time_period, + np.ones(n_ent, dtype=bool), + ) + for var in get_calculated_variables(state_sim): + if var != "county": + state_sim.delete_arrays(var) + + entity_vals = {} + if rerandomize_takeup: + for ( + tvar, + info, + ) in affected_targets.items(): + entity_level = info["entity"] + try: + entity_vals[tvar] = state_sim.calculate( + tvar, + self.time_period, + map_to=entity_level, + ).values.astype(np.float32) + except Exception as exc: + logger.warning( + "Cannot calculate " + "entity-level '%s' " + "for county %s: %s", + tvar, + county_fips, + exc, + ) + + county_values[county_fips] = { + "hh": hh, + "entity": entity_vals, + } + county_count += 1 + if county_count % 500 == 0 or county_count == 1: + logger.info( + "County %d/%d complete", + county_count, + len(unique_counties), + ) + + logger.info( + "Per-county precomputation done: %d counties", + len(county_values), + ) + return county_values + + def _assemble_clone_values( + self, + state_values: dict, + clone_states: np.ndarray, + person_hh_indices: np.ndarray, + target_vars: set, + constraint_vars: set, + county_values: dict = None, + clone_counties: np.ndarray = None, + county_dependent_vars: set = None, + ) -> tuple: + """Assemble per-clone values from state/county precomputation. + + For each target variable, selects values from either + county_values (if the var is county-dependent) or + state_values (otherwise) using numpy fancy indexing. + + Args: + state_values: Output of _build_state_values. + clone_states: State FIPS per record for this clone. + person_hh_indices: Maps person index to household + index (0..n_records-1). + target_vars: Set of target variable names. + constraint_vars: Set of constraint variable names. + county_values: Output of _build_county_values. + clone_counties: County FIPS per record for this + clone (str array). + county_dependent_vars: Set of var names that should + be looked up by county instead of state. 
+ + Returns: + (hh_vars, person_vars) where hh_vars maps variable + name to household-level float32 array and person_vars + maps constraint variable name to person-level array. + """ + n_records = len(clone_states) + n_persons = len(person_hh_indices) + person_states = clone_states[person_hh_indices] + unique_clone_states = np.unique(clone_states) + cdv = county_dependent_vars or set() + + # Pre-compute masks to avoid recomputing per variable + state_masks = {int(s): clone_states == s for s in unique_clone_states} + unique_person_states = np.unique(person_states) + person_state_masks = { + int(s): person_states == s for s in unique_person_states + } + county_masks = {} + unique_counties = None + if clone_counties is not None and county_values: + unique_counties = np.unique(clone_counties) + county_masks = {c: clone_counties == c for c in unique_counties} + + hh_vars = {} + for var in target_vars: + if var.endswith("_count"): + continue + if var in cdv and county_values and clone_counties is not None: + first_county = unique_counties[0] + if var not in county_values.get(first_county, {}).get( + "hh", {} + ): + continue + arr = np.empty(n_records, dtype=np.float32) + for county in unique_counties: + mask = county_masks[county] + county_hh = county_values.get(county, {}).get("hh", {}) + if var in county_hh: + arr[mask] = county_hh[var][mask] + else: + st = int(county[:2]) + arr[mask] = state_values[st]["hh"][var][mask] + hh_vars[var] = arr + else: + if var not in state_values[unique_clone_states[0]]["hh"]: + continue + arr = np.empty(n_records, dtype=np.float32) + for state in unique_clone_states: + mask = state_masks[int(state)] + arr[mask] = state_values[int(state)]["hh"][var][mask] + hh_vars[var] = arr + + person_vars = {} + for var in constraint_vars: + if var not in state_values[unique_clone_states[0]]["person"]: + continue + arr = np.empty(n_persons, dtype=np.float32) + for state in unique_person_states: + mask = person_state_masks[int(state)] + arr[mask] = 
state_values[int(state)]["person"][var][mask] + person_vars[var] = arr + + return hh_vars, person_vars # --------------------------------------------------------------- # Database queries @@ -261,14 +1487,7 @@ def _get_uprating_info( if period == self.time_period: return 1.0, "none" - count_indicators = [ - "count", - "person", - "people", - "households", - "tax_units", - ] - is_count = any(ind in variable.lower() for ind in count_indicators) + is_count = variable.endswith("_count") uprating_type = "pop" if is_count else "cpi" factor = factors.get((period, uprating_type), 1.0) return factor, uprating_type @@ -472,79 +1691,6 @@ def _make_target_name( return "/".join(parts) - # --------------------------------------------------------------- - # Target value calculation - # --------------------------------------------------------------- - - def _calculate_target_values( - self, - sim, - target_variable: str, - non_geo_constraints: List[dict], - n_households: int, - ) -> np.ndarray: - """Calculate per-household target values. - - For count targets (*_count): count entities per HH - satisfying constraints. - For value targets: multiply values by constraint mask. 
- """ - is_count = target_variable.endswith("_count") - - if not is_count: - mask = self._evaluate_constraints_entity_aware( - sim, non_geo_constraints, n_households - ) - vals = sim.calculate(target_variable, map_to="household").values - return (vals * mask).astype(np.float32) - - # Count target: entity-aware counting - entity_rel = self._build_entity_relationship(sim) - n_persons = len(entity_rel) - person_mask = np.ones(n_persons, dtype=bool) - - for c in non_geo_constraints: - try: - cv = sim.calculate(c["variable"], map_to="person").values - except Exception: - return np.zeros(n_households, dtype=np.float32) - person_mask &= apply_op(cv, c["operation"], c["value"]) - - target_entity = sim.tax_benefit_system.variables[ - target_variable - ].entity.key - household_ids = sim.calculate( - "household_id", map_to="household" - ).values - - if target_entity == "household": - if non_geo_constraints: - mask = self._evaluate_constraints_entity_aware( - sim, non_geo_constraints, n_households - ) - return mask.astype(np.float32) - return np.ones(n_households, dtype=np.float32) - - if target_entity == "person": - er = entity_rel.copy() - er["satisfies"] = person_mask - filtered = er[er["satisfies"]] - counts = filtered.groupby("household_id")["person_id"].nunique() - else: - eid_col = f"{target_entity}_id" - er = entity_rel.copy() - er["satisfies"] = person_mask - entity_ok = er.groupby(eid_col)["satisfies"].any() - unique = er[["household_id", eid_col]].drop_duplicates() - unique["entity_ok"] = unique[eid_col].map(entity_ok) - filtered = unique[unique["entity_ok"]] - counts = filtered.groupby("household_id")[eid_col].nunique() - - return np.array( - [counts.get(hid, 0) for hid in household_ids], - dtype=np.float32, - ) - # --------------------------------------------------------------- # Clone simulation # --------------------------------------------------------------- @@ -614,6 +1760,9 @@ def build_matrix( hierarchical_domains: Optional[List[str]] = None, cache_dir: 
Optional[str] = None, sim_modifier=None, + rerandomize_takeup: bool = True, + county_level: bool = True, + workers: int = 1, ) -> Tuple[pd.DataFrame, sparse.csr_matrix, List[str]]: """Build sparse calibration matrix. @@ -635,6 +1784,13 @@ def build_matrix( called per clone after state_fips is set but before cache clearing. Use for takeup re-randomization. + rerandomize_takeup: If True, use geo-salted + entity-level takeup draws instead of base h5 + takeup values for takeup-affected targets. + county_level: If True (default), iterate counties + within each state during precomputation. If + False, compute once per state and alias to all + counties (faster for county-invariant vars). Returns: (targets_df, X_sparse, target_names) @@ -720,156 +1876,457 @@ def build_matrix( unique_variables = set(targets_df["variable"].values) - # 5. Clone loop + # 5a. Collect unique constraint variables + unique_constraint_vars = set() + for constraints in non_geo_constraints_list: + for c in constraints: + unique_constraint_vars.add(c["variable"]) + + # 5b. Per-state precomputation (51 sims on one object) + self._entity_rel_cache = None + state_values = self._build_state_values( + sim, + unique_variables, + unique_constraint_vars, + geography, + rerandomize_takeup=rerandomize_takeup, + workers=workers, + ) + + # 5b-county. Per-county precomputation for county-dependent vars + county_dep_targets = unique_variables & COUNTY_DEPENDENT_VARS + county_values = self._build_county_values( + sim, + county_dep_targets, + geography, + rerandomize_takeup=rerandomize_takeup, + county_level=county_level, + workers=workers, + ) + + # 5c. 
State-independent structures (computed once) + entity_rel = self._build_entity_relationship(sim) + household_ids = sim.calculate( + "household_id", map_to="household" + ).values + person_hh_ids = sim.calculate("household_id", map_to="person").values + hh_id_to_idx = {int(hid): idx for idx, hid in enumerate(household_ids)} + person_hh_indices = np.array( + [hh_id_to_idx[int(hid)] for hid in person_hh_ids] + ) + tax_benefit_system = sim.tax_benefit_system + + # Pre-extract entity keys so workers don't need + # the unpicklable TaxBenefitSystem object. + variable_entity_map: Dict[str, str] = {} + for var in unique_variables: + if var.endswith("_count") and var in tax_benefit_system.variables: + variable_entity_map[var] = tax_benefit_system.variables[ + var + ].entity.key + + # 5c-extra: Entity-to-household index maps for takeup + affected_target_info = {} + if rerandomize_takeup: + from policyengine_us_data.utils.takeup import ( + TAKEUP_AFFECTED_TARGETS, + compute_block_takeup_for_entities, + ) + from policyengine_us_data.parameters import ( + load_take_up_rate, + ) + + # Build entity-to-household index arrays + spm_to_hh_id = ( + entity_rel.groupby("spm_unit_id")["household_id"] + .first() + .to_dict() + ) + spm_ids = sim.calculate("spm_unit_id", map_to="spm_unit").values + spm_hh_idx = np.array( + [hh_id_to_idx[int(spm_to_hh_id[int(sid)])] for sid in spm_ids] + ) + + tu_to_hh_id = ( + entity_rel.groupby("tax_unit_id")["household_id"] + .first() + .to_dict() + ) + tu_ids = sim.calculate("tax_unit_id", map_to="tax_unit").values + tu_hh_idx = np.array( + [hh_id_to_idx[int(tu_to_hh_id[int(tid)])] for tid in tu_ids] + ) + + entity_hh_idx_map = { + "spm_unit": spm_hh_idx, + "tax_unit": tu_hh_idx, + "person": person_hh_indices, + } + + entity_to_person_idx = {} + for entity_level in ("spm_unit", "tax_unit"): + ent_ids = sim.calculate( + f"{entity_level}_id", + map_to=entity_level, + ).values + ent_id_to_idx = { + int(eid): idx for idx, eid in enumerate(ent_ids) + } + 
person_ent_ids = entity_rel[f"{entity_level}_id"].values + entity_to_person_idx[entity_level] = np.array( + [ent_id_to_idx[int(eid)] for eid in person_ent_ids] + ) + entity_to_person_idx["person"] = np.arange(len(entity_rel)) + + for tvar in unique_variables: + for key, info in TAKEUP_AFFECTED_TARGETS.items(): + if tvar == key: + affected_target_info[tvar] = info + break + + logger.info( + "Block-level takeup enabled, " "%d affected target vars", + len(affected_target_info), + ) + + # Pre-compute takeup rates (constant across clones) + precomputed_rates = {} + for tvar, info in affected_target_info.items(): + rk = info["rate_key"] + if rk not in precomputed_rates: + precomputed_rates[rk] = load_take_up_rate( + rk, self.time_period + ) + + # Store for post-optimization stacked takeup + self.entity_hh_idx_map = entity_hh_idx_map + self.household_ids = household_ids + self.precomputed_rates = precomputed_rates + self.affected_target_info = affected_target_info + + # 5d. Clone loop from pathlib import Path - clone_dir = Path(cache_dir) if cache_dir else None - if clone_dir: + if workers > 1: + # ---- Parallel clone processing ---- + import concurrent.futures + import tempfile + + if cache_dir: + clone_dir = Path(cache_dir) + else: + clone_dir = Path(tempfile.mkdtemp(prefix="clone_coo_")) clone_dir.mkdir(parents=True, exist_ok=True) - self._entity_rel_cache = None + target_variables = [ + str(targets_df.iloc[i]["variable"]) for i in range(n_targets) + ] - for clone_idx in range(n_clones): + shared_data = { + "geography_state_fips": geography.state_fips, + "geography_county_fips": geography.county_fips, + "geography_block_geoid": geography.block_geoid, + "state_values": state_values, + "county_values": county_values, + "person_hh_indices": person_hh_indices, + "unique_variables": unique_variables, + "unique_constraint_vars": unique_constraint_vars, + "county_dep_targets": county_dep_targets, + "target_variables": target_variables, + "target_geo_info": target_geo_info, + 
"non_geo_constraints_list": (non_geo_constraints_list), + "n_records": n_records, + "n_total": n_total, + "n_targets": n_targets, + "state_to_cols": state_to_cols, + "cd_to_cols": cd_to_cols, + "entity_rel": entity_rel, + "household_ids": household_ids, + "variable_entity_map": variable_entity_map, + "rerandomize_takeup": rerandomize_takeup, + "affected_target_info": affected_target_info, + } + if rerandomize_takeup and affected_target_info: + shared_data["entity_hh_idx_map"] = entity_hh_idx_map + shared_data["entity_to_person_idx"] = entity_to_person_idx + shared_data["precomputed_rates"] = precomputed_rates + + logger.info( + "Starting parallel clone processing: " "%d clones, %d workers", + n_clones, + workers, + ) + + futures: dict = {} + with concurrent.futures.ProcessPoolExecutor( + max_workers=workers, + initializer=_init_clone_worker, + initargs=(shared_data,), + ) as pool: + for ci in range(n_clones): + coo_path = str(clone_dir / f"clone_{ci:04d}.npz") + if Path(coo_path).exists(): + logger.info( + "Clone %d/%d cached.", + ci + 1, + n_clones, + ) + continue + cs = ci * n_records + ce = cs + n_records + fut = pool.submit( + _process_single_clone, + ci, + cs, + ce, + coo_path, + ) + futures[fut] = ci + + for fut in concurrent.futures.as_completed(futures): + ci = futures[fut] + try: + _, nnz = fut.result() + if (ci + 1) % 50 == 0: + logger.info( + "Clone %d/%d done " "(%d nnz).", + ci + 1, + n_clones, + nnz, + ) + except Exception as exc: + for f in futures: + f.cancel() + raise RuntimeError( + f"Clone {ci} failed: {exc}" + ) from exc + + else: + # ---- Sequential clone processing (unchanged) ---- + clone_dir = Path(cache_dir) if cache_dir else None if clone_dir: - coo_path = clone_dir / f"clone_{clone_idx:04d}.npz" - if coo_path.exists(): + clone_dir.mkdir(parents=True, exist_ok=True) + + for clone_idx in range(n_clones): + if clone_dir: + coo_path = clone_dir / f"clone_{clone_idx:04d}.npz" + if coo_path.exists(): + logger.info( + "Clone %d/%d cached, " 
"skipping.", + clone_idx + 1, + n_clones, + ) + continue + + col_start = clone_idx * n_records + col_end = col_start + n_records + clone_states = geography.state_fips[col_start:col_end] + clone_counties = geography.county_fips[col_start:col_end] + + if (clone_idx + 1) % 50 == 0 or clone_idx == 0: logger.info( - "Clone %d/%d cached, skipping.", + "Assembling clone %d/%d " + "(cols %d-%d, " + "%d unique states)...", clone_idx + 1, n_clones, + col_start, + col_end - 1, + len(np.unique(clone_states)), ) - continue - col_start = clone_idx * n_records - col_end = col_start + n_records - clone_states = geography.state_fips[col_start:col_end] + hh_vars, person_vars = self._assemble_clone_values( + state_values, + clone_states, + person_hh_indices, + unique_variables, + unique_constraint_vars, + county_values=county_values, + clone_counties=clone_counties, + county_dependent_vars=(county_dep_targets), + ) - logger.info( - "Processing clone %d/%d " "(cols %d-%d, %d unique states)...", - clone_idx + 1, - n_clones, - col_start, - col_end - 1, - len(np.unique(clone_states)), - ) + # Apply geo-specific entity-level takeup + # for affected target variables + if rerandomize_takeup and affected_target_info: + clone_blocks = geography.block_geoid[col_start:col_end] + for ( + tvar, + info, + ) in affected_target_info.items(): + if tvar.endswith("_count"): + continue + entity_level = info["entity"] + takeup_var = info["takeup_var"] + ent_hh = entity_hh_idx_map[entity_level] + n_ent = len(ent_hh) + + ent_states = clone_states[ent_hh] + + ent_eligible = np.zeros(n_ent, dtype=np.float32) + if tvar in county_dep_targets and county_values: + ent_counties = clone_counties[ent_hh] + for cfips in np.unique(ent_counties): + m = ent_counties == cfips + cv = county_values.get(cfips, {}).get( + "entity", {} + ) + if tvar in cv: + ent_eligible[m] = cv[tvar][m] + else: + st = int(cfips[:2]) + sv = state_values[st]["entity"] + if tvar in sv: + ent_eligible[m] = sv[tvar][m] + else: + for st in 
np.unique(ent_states): + m = ent_states == st + sv = state_values[int(st)]["entity"] + if tvar in sv: + ent_eligible[m] = sv[tvar][m] + + ent_blocks = clone_blocks[ent_hh] + ent_hh_ids = household_ids[ent_hh] + + ent_takeup = compute_block_takeup_for_entities( + takeup_var, + precomputed_rates[info["rate_key"]], + ent_blocks, + ent_hh_ids, + ) - var_values, clone_sim = self._simulate_clone( - clone_states, - n_records, - unique_variables, - sim_modifier=sim_modifier, - clone_idx=clone_idx, - ) + ent_values = (ent_eligible * ent_takeup).astype( + np.float32 + ) - mask_cache: Dict[tuple, np.ndarray] = {} - count_cache: Dict[tuple, np.ndarray] = {} + hh_result = np.zeros(n_records, dtype=np.float32) + np.add.at(hh_result, ent_hh, ent_values) + hh_vars[tvar] = hh_result - rows_list: list = [] - cols_list: list = [] - vals_list: list = [] + if tvar in person_vars: + pidx = entity_to_person_idx[entity_level] + person_vars[tvar] = ent_values[pidx] - for row_idx in range(n_targets): - variable = str(targets_df.iloc[row_idx]["variable"]) - geo_level, geo_id = target_geo_info[row_idx] - non_geo = non_geo_constraints_list[row_idx] + mask_cache: Dict[tuple, np.ndarray] = {} + count_cache: Dict[tuple, np.ndarray] = {} - # Geographic column selection - if geo_level == "district": - all_geo_cols = cd_to_cols.get( - str(geo_id), - np.array([], dtype=np.int64), - ) - elif geo_level == "state": - all_geo_cols = state_to_cols.get( - int(geo_id), - np.array([], dtype=np.int64), - ) - else: - all_geo_cols = np.arange(n_total) + rows_list: list = [] + cols_list: list = [] + vals_list: list = [] - clone_cols = all_geo_cols[ - (all_geo_cols >= col_start) & (all_geo_cols < col_end) - ] - if len(clone_cols) == 0: - continue + for row_idx in range(n_targets): + variable = str(targets_df.iloc[row_idx]["variable"]) + geo_level, geo_id = target_geo_info[row_idx] + non_geo = non_geo_constraints_list[row_idx] + + if geo_level == "district": + all_geo_cols = cd_to_cols.get( + str(geo_id), + 
np.array([], dtype=np.int64), + ) + elif geo_level == "state": + all_geo_cols = state_to_cols.get( + int(geo_id), + np.array([], dtype=np.int64), + ) + else: + all_geo_cols = np.arange(n_total) + + clone_cols = all_geo_cols[ + (all_geo_cols >= col_start) & (all_geo_cols < col_end) + ] + if len(clone_cols) == 0: + continue - rec_indices = clone_cols - col_start + rec_indices = clone_cols - col_start - constraint_key = tuple( - sorted( - ( - c["variable"], - c["operation"], - c["value"], + constraint_key = tuple( + sorted( + ( + c["variable"], + c["operation"], + c["value"], + ) + for c in non_geo ) - for c in non_geo ) - ) - if variable.endswith("_count"): - vkey = (variable, constraint_key) - if vkey not in count_cache: - count_cache[vkey] = self._calculate_target_values( - clone_sim, + if variable.endswith("_count"): + vkey = ( variable, - non_geo, - n_records, + constraint_key, ) - values = count_cache[vkey] - else: - if variable not in var_values: - continue - if constraint_key not in mask_cache: - mask_cache[constraint_key] = ( - self._evaluate_constraints_entity_aware( - clone_sim, - non_geo, - n_records, + if vkey not in count_cache: + count_cache[vkey] = ( + _calculate_target_values_standalone( + target_variable=variable, + non_geo_constraints=non_geo, + n_households=n_records, + hh_vars=hh_vars, + person_vars=person_vars, + entity_rel=entity_rel, + household_ids=household_ids, + variable_entity_map=variable_entity_map, + ) + ) + values = count_cache[vkey] + else: + if variable not in hh_vars: + continue + if constraint_key not in mask_cache: + mask_cache[constraint_key] = ( + _evaluate_constraints_standalone( + non_geo, + person_vars, + entity_rel, + household_ids, + n_records, + ) + ) + mask = mask_cache[constraint_key] + values = hh_vars[variable] * mask + + vals = values[rec_indices] + nonzero = vals != 0 + if nonzero.any(): + rows_list.append( + np.full( + nonzero.sum(), + row_idx, + dtype=np.int32, ) ) - mask = mask_cache[constraint_key] - values = 
var_values[variable] * mask - - vals = values[rec_indices] - nonzero = vals != 0 - if nonzero.any(): - rows_list.append( - np.full( - nonzero.sum(), - row_idx, - dtype=np.int32, - ) + cols_list.append(clone_cols[nonzero].astype(np.int32)) + vals_list.append(vals[nonzero]) + + # Save COO entries + if rows_list: + cr = np.concatenate(rows_list) + cc = np.concatenate(cols_list) + cv = np.concatenate(vals_list) + else: + cr = np.array([], dtype=np.int32) + cc = np.array([], dtype=np.int32) + cv = np.array([], dtype=np.float32) + + if clone_dir: + np.savez_compressed( + str(coo_path), + rows=cr, + cols=cc, + vals=cv, ) - cols_list.append(clone_cols[nonzero].astype(np.int32)) - vals_list.append(vals[nonzero]) - - # Save COO entries - if rows_list: - cr = np.concatenate(rows_list) - cc = np.concatenate(cols_list) - cv = np.concatenate(vals_list) - else: - cr = np.array([], dtype=np.int32) - cc = np.array([], dtype=np.int32) - cv = np.array([], dtype=np.float32) - - if clone_dir: - np.savez_compressed( - str(coo_path), - rows=cr, - cols=cc, - vals=cv, - ) - logger.info( - "Clone %d: %d nonzero entries saved.", - clone_idx + 1, - len(cv), - ) - del var_values, clone_sim - else: - self._coo_parts[0].append(cr) - self._coo_parts[1].append(cc) - self._coo_parts[2].append(cv) + if (clone_idx + 1) % 50 == 0: + logger.info( + "Clone %d: %d nonzero " "entries saved.", + clone_idx + 1, + len(cv), + ) + del hh_vars, person_vars + else: + self._coo_parts[0].append(cr) + self._coo_parts[1].append(cc) + self._coo_parts[2].append(cv) # 6. Assemble sparse matrix from COO data logger.info("Assembling matrix from %d clones...", n_clones) diff --git a/policyengine_us_data/calibration/validate_national_h5.py b/policyengine_us_data/calibration/validate_national_h5.py new file mode 100644 index 000000000..c63632851 --- /dev/null +++ b/policyengine_us_data/calibration/validate_national_h5.py @@ -0,0 +1,158 @@ +"""Validate a national US.h5 file against reference values. 
+ +Loads the national H5, computes key variables, and compares to +known national totals. Also runs structural sanity checks. + +Usage: + python -m policyengine_us_data.calibration.validate_national_h5 + python -m policyengine_us_data.calibration.validate_national_h5 \ + --h5-path path/to/US.h5 + python -m policyengine_us_data.calibration.validate_national_h5 \ + --hf-path hf://policyengine/policyengine-us-data/national/US.h5 +""" + +import argparse + +VARIABLES = [ + "adjusted_gross_income", + "employment_income", + "self_employment_income", + "tax_unit_partnership_s_corp_income", + "taxable_pension_income", + "dividend_income", + "net_capital_gains", + "rental_income", + "taxable_interest_income", + "social_security", + "snap", + "ssi", + "income_tax_before_credits", + "eitc", + "refundable_ctc", + "real_estate_taxes", + "rent", + "is_pregnant", + "person_count", + "household_count", +] + +REFERENCES = { + "person_count": (335_000_000, "~335M"), + "household_count": (130_000_000, "~130M"), + "adjusted_gross_income": (15_000_000_000_000, "~$15T"), + "employment_income": (10_000_000_000_000, "~$10T"), + "social_security": (1_200_000_000_000, "~$1.2T"), + "snap": (110_000_000_000, "~$110B"), + "ssi": (60_000_000_000, "~$60B"), + "eitc": (60_000_000_000, "~$60B"), + "refundable_ctc": (120_000_000_000, "~$120B"), + "income_tax_before_credits": (4_000_000_000_000, "~$4T"), +} + +DEFAULT_HF_PATH = "hf://policyengine/policyengine-us-data/national/US.h5" + +COUNT_VARS = {"person_count", "household_count", "is_pregnant"} + + +def main(argv=None): + parser = argparse.ArgumentParser(description="Validate national US.h5") + parser.add_argument( + "--h5-path", + default=None, + help="Local path to US.h5", + ) + parser.add_argument( + "--hf-path", + default=DEFAULT_HF_PATH, + help=f"HF path to US.h5 (default: {DEFAULT_HF_PATH})", + ) + args = parser.parse_args(argv) + + dataset_path = args.h5_path or args.hf_path + + from policyengine_us import Microsimulation + + 
print(f"Loading {dataset_path}...") + sim = Microsimulation(dataset=dataset_path) + + n_hh = sim.calculate("household_id", map_to="household").shape[0] + print(f"Households in file: {n_hh:,}") + + print("\n" + "=" * 70) + print("NATIONAL H5 VALUES") + print("=" * 70) + + values = {} + failures = [] + for var in VARIABLES: + try: + val = float(sim.calculate(var).sum()) + values[var] = val + if var in COUNT_VARS: + print(f" {var:45s} {val:>15,.0f}") + else: + print(f" {var:45s} ${val:>15,.0f}") + except Exception as e: + failures.append((var, str(e))) + print(f" {var:45s} FAILED: {e}") + + print("\n" + "=" * 70) + print("COMPARISON TO REFERENCE VALUES") + print("=" * 70) + + any_flag = False + for var, (ref_val, ref_label) in REFERENCES.items(): + if var not in values: + continue + val = values[var] + pct_diff = (val - ref_val) / ref_val * 100 + flag = " ***" if abs(pct_diff) > 30 else "" + if flag: + any_flag = True + if var in COUNT_VARS: + print( + f" {var:35s} {val:>15,.0f} " + f"ref {ref_label:>8s} " + f"({pct_diff:+.1f}%){flag}" + ) + else: + print( + f" {var:35s} ${val:>15,.0f} " + f"ref {ref_label:>8s} " + f"({pct_diff:+.1f}%){flag}" + ) + + if any_flag: + print("\n*** = >30% deviation from reference. 
" "Investigate further.") + + if failures: + print(f"\n{len(failures)} variables failed:") + for var, err in failures: + print(f" {var}: {err}") + + print("\n" + "=" * 70) + print("STRUCTURAL CHECKS") + print("=" * 70) + + from policyengine_us_data.calibration.sanity_checks import ( + run_sanity_checks, + ) + + results = run_sanity_checks(dataset_path) + n_pass = sum(1 for r in results if r["status"] == "PASS") + n_fail = sum(1 for r in results if r["status"] == "FAIL") + for r in results: + icon = ( + "PASS" + if r["status"] == "PASS" + else "FAIL" if r["status"] == "FAIL" else "WARN" + ) + print(f" [{icon}] {r['check']}: {r['detail']}") + + print(f"\n{n_pass}/{len(results)} passed, {n_fail} failed") + + return 0 if n_fail == 0 and not failures else 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/policyengine_us_data/calibration/validate_package.py b/policyengine_us_data/calibration/validate_package.py new file mode 100644 index 000000000..bd6862f03 --- /dev/null +++ b/policyengine_us_data/calibration/validate_package.py @@ -0,0 +1,342 @@ +""" +Validate a calibration package before uploading to Modal. 
+ +Usage: + python -m policyengine_us_data.calibration.validate_package [path] + [--n-hardest N] [--strict [RATIO]] +""" + +import argparse +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional + +import numpy as np +import pandas as pd + + +@dataclass +class ValidationResult: + n_targets: int + n_columns: int + nnz: int + density: float + metadata: dict + n_achievable: int + n_impossible: int + impossible_targets: pd.DataFrame + impossible_by_group: pd.DataFrame + hardest_targets: pd.DataFrame + group_summary: pd.DataFrame + strict_ratio: Optional[float] = None + strict_failures: int = 0 + + +def validate_package( + package: dict, + n_hardest: int = 10, + strict_ratio: float = None, +) -> ValidationResult: + X_sparse = package["X_sparse"] + targets_df = package["targets_df"] + target_names = package["target_names"] + metadata = package.get("metadata", {}) + + n_targets, n_columns = X_sparse.shape + nnz = X_sparse.nnz + density = nnz / (n_targets * n_columns) if n_targets * n_columns else 0 + + row_sums = np.array(X_sparse.sum(axis=1)).flatten() + achievable_mask = row_sums > 0 + n_achievable = int(achievable_mask.sum()) + n_impossible = n_targets - n_achievable + + impossible_idx = np.where(~achievable_mask)[0] + impossible_rows = targets_df.iloc[impossible_idx] + impossible_targets = pd.DataFrame( + { + "target_name": [target_names[i] for i in impossible_idx], + "domain_variable": impossible_rows["domain_variable"].values, + "variable": impossible_rows["variable"].values, + "geo_level": impossible_rows["geo_level"].values, + "geographic_id": impossible_rows["geographic_id"].values, + "target_value": impossible_rows["value"].values, + } + ) + impossible_by_group = ( + impossible_rows.groupby(["domain_variable", "variable", "geo_level"]) + .size() + .reset_index(name="count") + .sort_values("count", ascending=False) + .reset_index(drop=True) + ) + + target_values = targets_df["value"].values + achievable_idx = 
np.where(achievable_mask)[0] + if len(achievable_idx) > 0: + a_row_sums = row_sums[achievable_idx] + a_target_vals = target_values[achievable_idx] + with np.errstate(divide="ignore", invalid="ignore"): + ratios = np.where( + a_target_vals != 0, + a_row_sums / a_target_vals, + np.inf, + ) + k = min(n_hardest, len(ratios)) + hardest_local_idx = np.argpartition(ratios, k)[:k] + hardest_local_idx = hardest_local_idx[ + np.argsort(ratios[hardest_local_idx]) + ] + hardest_global_idx = achievable_idx[hardest_local_idx] + + hardest_targets = pd.DataFrame( + { + "target_name": [target_names[i] for i in hardest_global_idx], + "domain_variable": targets_df["domain_variable"] + .iloc[hardest_global_idx] + .values, + "variable": targets_df["variable"] + .iloc[hardest_global_idx] + .values, + "geographic_id": targets_df["geographic_id"] + .iloc[hardest_global_idx] + .values, + "ratio": ratios[hardest_local_idx], + "row_sum": a_row_sums[hardest_local_idx], + "target_value": a_target_vals[hardest_local_idx], + } + ) + else: + hardest_targets = pd.DataFrame( + columns=[ + "target_name", + "domain_variable", + "variable", + "geographic_id", + "ratio", + "row_sum", + "target_value", + ] + ) + + group_summary = ( + targets_df.assign(achievable=achievable_mask) + .groupby(["domain_variable", "variable", "geo_level"]) + .agg(total=("value", "size"), ok=("achievable", "sum")) + .reset_index() + ) + group_summary["impossible"] = group_summary["total"] - group_summary["ok"] + group_summary["ok"] = group_summary["ok"].astype(int) + group_summary = group_summary.sort_values( + ["domain_variable", "variable", "geo_level"] + ).reset_index(drop=True) + + strict_failures = 0 + if strict_ratio is not None and len(achievable_idx) > 0: + strict_failures = int((ratios < strict_ratio).sum()) + + return ValidationResult( + n_targets=n_targets, + n_columns=n_columns, + nnz=nnz, + density=density, + metadata=metadata, + n_achievable=n_achievable, + n_impossible=n_impossible, + 
impossible_targets=impossible_targets, + impossible_by_group=impossible_by_group, + hardest_targets=hardest_targets, + group_summary=group_summary, + strict_ratio=strict_ratio, + strict_failures=strict_failures, + ) + + +def format_report(result: ValidationResult, package_path: str = None) -> str: + lines = ["", "=== Calibration Package Validation ===", ""] + + if package_path: + lines.append(f"Package: {package_path}") + meta = result.metadata + if meta.get("created_at"): + lines.append(f"Created: {meta['created_at']}") + if meta.get("dataset_path"): + lines.append(f"Dataset: {meta['dataset_path']}") + if meta.get("git_branch") or meta.get("git_commit"): + branch = meta.get("git_branch", "unknown") + commit = meta.get("git_commit", "") + commit_short = commit[:8] if commit else "unknown" + dirty = " (DIRTY)" if meta.get("git_dirty") else "" + lines.append(f"Git: {branch} @ {commit_short}{dirty}") + if meta.get("package_version"): + lines.append(f"Version: {meta['package_version']}") + if meta.get("dataset_sha256"): + lines.append(f"Dataset SHA: {meta['dataset_sha256'][:12]}") + if meta.get("db_sha256"): + lines.append(f"DB SHA: {meta['db_sha256'][:12]}") + lines.append("") + + lines.append( + f"Matrix: {result.n_targets:,} targets" + f" x {result.n_columns:,} columns" + ) + lines.append(f"Non-zero: {result.nnz:,} (density: {result.density:.6f})") + if meta.get("n_clones"): + parts = [f"Clones: {meta['n_clones']}"] + if meta.get("n_records"): + parts.append(f"Records: {meta['n_records']:,}") + if meta.get("seed") is not None: + parts.append(f"Seed: {meta['seed']}") + lines.append(", ".join(parts)) + lines.append("") + + pct = ( + 100 * result.n_achievable / result.n_targets if result.n_targets else 0 + ) + pct_imp = 100 - pct + lines.append("--- Achievability ---") + lines.append( + f"Achievable: {result.n_achievable:>6,}" + f" / {result.n_targets:,} ({pct:.1f}%)" + ) + lines.append( + f"Impossible: {result.n_impossible:>6,}" + f" / {result.n_targets:,} 
({pct_imp:.1f}%)" + ) + lines.append("") + + if len(result.impossible_targets) > 0: + lines.append("--- Impossible Targets ---") + for _, row in result.impossible_targets.iterrows(): + lines.append( + f" {row['target_name']:<60s}" + f" {row['target_value']:>14,.0f}" + ) + lines.append("") + + if len(result.impossible_by_group) > 1: + lines.append("--- Impossible Targets by Group ---") + for _, row in result.impossible_by_group.iterrows(): + lines.append( + f" {row['domain_variable']:<20s}" + f" {row['variable']:<25s}" + f" {row['geo_level']:<12s}" + f" {row['count']:>5d}" + ) + lines.append("") + + if len(result.hardest_targets) > 0: + n = len(result.hardest_targets) + lines.append( + f"--- Hardest Achievable Targets" f" ({n} lowest ratio) ---" + ) + for _, row in result.hardest_targets.iterrows(): + lines.append( + f" {row['target_name']:<50s}" + f" {row['ratio']:>10.4f}" + f" {row['row_sum']:>14,.0f}" + f" {row['target_value']:>14,.0f}" + ) + lines.append("") + + if len(result.group_summary) > 0: + lines.append("--- Group Summary ---") + lines.append( + f" {'domain':<20s} {'variable':<25s}" + f" {'geo_level':<12s}" + f" {'total':>6s} {'ok':>6s} {'impossible':>10s}" + ) + for _, row in result.group_summary.iterrows(): + lines.append( + f" {row['domain_variable']:<20s}" + f" {row['variable']:<25s}" + f" {row['geo_level']:<12s}" + f" {row['total']:>6d}" + f" {row['ok']:>6d}" + f" {row['impossible']:>10d}" + ) + lines.append("") + + if result.strict_ratio is not None: + lines.append( + f"Strict check (ratio < {result.strict_ratio}):" + f" {result.strict_failures} failures" + ) + lines.append("") + + if result.strict_ratio is not None and result.strict_failures > 0: + lines.append( + f"RESULT: FAIL ({result.strict_failures}" + f" targets below ratio {result.strict_ratio})" + ) + elif result.n_impossible > 0: + lines.append( + f"RESULT: FAIL ({result.n_impossible} impossible targets)" + ) + else: + lines.append("RESULT: PASS") + + return "\n".join(lines) + + +def 
main(): + parser = argparse.ArgumentParser( + description="Validate a calibration package" + ) + parser.add_argument( + "path", + nargs="?", + default=None, + help="Path to calibration_package.pkl", + ) + parser.add_argument( + "--n-hardest", + type=int, + default=10, + help="Number of hardest achievable targets to show", + ) + parser.add_argument( + "--strict", + nargs="?", + const=0.01, + type=float, + default=None, + metavar="RATIO", + help="Fail if any achievable target has ratio below RATIO" + " (default: 0.01)", + ) + args = parser.parse_args() + + if args.path is None: + from policyengine_us_data.storage import STORAGE_FOLDER + + path = STORAGE_FOLDER / "calibration" / "calibration_package.pkl" + else: + path = Path(args.path) + + if not path.exists(): + print(f"Error: package not found at {path}", file=sys.stderr) + sys.exit(1) + + from policyengine_us_data.calibration.unified_calibration import ( + load_calibration_package, + ) + + package = load_calibration_package(str(path)) + result = validate_package( + package, + n_hardest=args.n_hardest, + strict_ratio=args.strict, + ) + print(format_report(result, package_path=str(path))) + + if args.strict is not None and result.strict_failures > 0: + sys.exit(2) + elif result.n_impossible > 0: + sys.exit(1) + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/policyengine_us_data/calibration/validate_staging.py b/policyengine_us_data/calibration/validate_staging.py new file mode 100644 index 000000000..58bd61667 --- /dev/null +++ b/policyengine_us_data/calibration/validate_staging.py @@ -0,0 +1,732 @@ +""" +Validate staging .h5 files by running sim.calculate() and comparing +against calibration targets from policy_data.db. 
+ +Usage: + python -m policyengine_us_data.calibration.validate_staging \ + --area-type states,districts --areas NC \ + --period 2024 --output validation_results.csv + + python -m policyengine_us_data.calibration.validate_staging \ + --sanity-only --area-type states --areas NC +""" + +import argparse +import csv +import gc +import logging +import math +import multiprocessing as mp +import os +from pathlib import Path +from typing import Optional + +import numpy as np +import pandas as pd +from sqlalchemy import create_engine + +from policyengine_us_data.storage import STORAGE_FOLDER +from policyengine_us_data.calibration.unified_calibration import ( + load_target_config, + _match_rules, +) +from policyengine_us_data.calibration.unified_matrix_builder import ( + UnifiedMatrixBuilder, + _calculate_target_values_standalone, + _GEO_VARS, +) +from policyengine_us_data.calibration.calibration_utils import ( + STATE_CODES, +) +from policyengine_us_data.calibration.sanity_checks import ( + run_sanity_checks, +) + +logger = logging.getLogger(__name__) + +DEFAULT_HF_PREFIX = "hf://policyengine/policyengine-us-data/staging" +DEFAULT_DB_PATH = str(STORAGE_FOLDER / "calibration" / "policy_data.db") +DEFAULT_TARGET_CONFIG = str(Path(__file__).parent / "target_config_full.yaml") +TRAINING_TARGET_CONFIG = str(Path(__file__).parent / "target_config.yaml") + +SANITY_CEILINGS = { + "national": { + "dollar": 30e12, + "person_count": 340e6, + "household_count": 135e6, + "count": 340e6, + }, + "state": { + "dollar": 5e12, + "person_count": 40e6, + "household_count": 15e6, + "count": 40e6, + }, + "district": { + "dollar": 500e9, + "person_count": 1e6, + "household_count": 400e3, + "count": 1e6, + }, +} + +FIPS_TO_ABBR = {str(k): v for k, v in STATE_CODES.items()} +ABBR_TO_FIPS = {v: str(k) for k, v in STATE_CODES.items()} + +CSV_COLUMNS = [ + "area_type", + "area_id", + "variable", + "target_name", + "period", + "target_value", + "sim_value", + "error", + "rel_error", + "abs_error", + 
"rel_abs_error", + "sanity_check", + "sanity_reason", + "in_training", +] + + +def _classify_variable(variable: str) -> str: + if variable == "household_count": + return "household_count" + if variable == "person_count": + return "person_count" + if variable.endswith("_count"): + return "count" + return "dollar" + + +def _run_sanity_check( + sim_value: float, + variable: str, + geo_level: str, +) -> tuple: + if not math.isfinite(sim_value): + return "FAIL", "non-finite value" + vtype = _classify_variable(variable) + ceilings = SANITY_CEILINGS.get(geo_level, SANITY_CEILINGS["state"]) + ceiling = ceilings.get(vtype, ceilings["dollar"]) + if abs(sim_value) > ceiling: + return ( + "FAIL", + f"|{sim_value:.2e}| > {ceiling:.0e} ceiling " + f"({vtype} @ {geo_level})", + ) + return "PASS", "" + + +def _query_all_active_targets(engine, period: int) -> pd.DataFrame: + query = """ + WITH best_periods AS ( + SELECT stratum_id, variable, + CASE + WHEN MAX(CASE WHEN period <= :period + THEN period END) IS NOT NULL + THEN MAX(CASE WHEN period <= :period + THEN period END) + ELSE MIN(period) + END as best_period + FROM target_overview + WHERE active = 1 + GROUP BY stratum_id, variable + ) + SELECT tv.target_id, tv.stratum_id, tv.variable, + tv.value, tv.period, tv.geo_level, + tv.geographic_id, tv.domain_variable + FROM target_overview tv + JOIN best_periods bp + ON tv.stratum_id = bp.stratum_id + AND tv.variable = bp.variable + AND tv.period = bp.best_period + WHERE tv.active = 1 + ORDER BY tv.target_id + """ + with engine.connect() as conn: + return pd.read_sql(query, conn, params={"period": period}) + + +def _get_stratum_constraints(engine, stratum_id: int) -> list: + query = """ + SELECT constraint_variable AS variable, operation, value + FROM stratum_constraints + WHERE stratum_id = :stratum_id + """ + with engine.connect() as conn: + df = pd.read_sql(query, conn, params={"stratum_id": int(stratum_id)}) + return df.to_dict("records") + + +def 
_geoid_to_district_filename(geoid: str) -> str: + """Convert DB geographic_id like '3701' to filename 'NC-01'.""" + geoid = geoid.zfill(4) + state_fips = geoid[:-2] + district_num = geoid[-2:] + abbr = FIPS_TO_ABBR.get(state_fips) + if abbr is None: + return geoid + return f"{abbr}-{district_num}" + + +def _geoid_to_display(geoid: str) -> str: + """Convert DB geographic_id like '3701' to 'NC-01'.""" + return _geoid_to_district_filename(geoid) + + +def _resolve_state_fips(areas_str: Optional[str]) -> list: + """Resolve --areas to state FIPS codes.""" + if not areas_str: + return [str(f) for f in sorted(STATE_CODES.keys())] + resolved = [] + for a in areas_str.split(","): + a = a.strip() + if a in ABBR_TO_FIPS: + resolved.append(ABBR_TO_FIPS[a]) + elif a.isdigit(): + resolved.append(a) + else: + logger.warning("Unknown area '%s', skipping", a) + return resolved + + +def _resolve_district_ids(engine, areas_str: Optional[str]) -> list: + """Resolve --areas to district geographic_ids from DB.""" + state_fips_list = _resolve_state_fips(areas_str) + with engine.connect() as conn: + df = pd.read_sql( + "SELECT DISTINCT geographic_id FROM target_overview " + "WHERE geo_level = 'district'", + conn, + ) + all_geoids = df["geographic_id"].tolist() + result = [] + for geoid in all_geoids: + padded = str(geoid).zfill(4) + sfips = padded[:-2] + if sfips in state_fips_list: + result.append(str(geoid)) + return sorted(result) + + +def _build_variable_entity_map(sim) -> dict: + tbs = sim.tax_benefit_system + mapping = {} + for var_name in tbs.variables: + var = tbs.get_variable(var_name) + if var is not None: + mapping[var_name] = var.entity.key + count_entities = { + "person_count": "person", + "household_count": "household", + "tax_unit_count": "tax_unit", + "spm_unit_count": "spm_unit", + } + mapping.update(count_entities) + return mapping + + +def _build_entity_rel(sim) -> pd.DataFrame: + return pd.DataFrame( + { + "person_id": sim.calculate("person_id", map_to="person").values, 
+ "household_id": sim.calculate( + "household_id", map_to="person" + ).values, + "tax_unit_id": sim.calculate( + "tax_unit_id", map_to="person" + ).values, + "spm_unit_id": sim.calculate( + "spm_unit_id", map_to="person" + ).values, + } + ) + + +def validate_area( + sim, + targets_df: pd.DataFrame, + engine, + area_type: str, + area_id: str, + display_id: str, + period: int, + training_mask: np.ndarray, + variable_entity_map: dict, +) -> list: + entity_rel = _build_entity_rel(sim) + household_ids = sim.calculate("household_id", map_to="household").values + n_households = len(household_ids) + + hh_weight = sim.calculate( + "household_weight", + map_to="household", + period=period, + ).values.astype(np.float64) + + hh_vars_cache = {} + person_vars_cache = {} + + training_arr = np.asarray(training_mask, dtype=bool) + + geo_level = "state" if area_type == "states" else "district" + + results = [] + for i, (idx, row) in enumerate(targets_df.iterrows()): + variable = row["variable"] + target_value = float(row["value"]) + stratum_id = int(row["stratum_id"]) + + constraints = _get_stratum_constraints(engine, stratum_id) + non_geo = [c for c in constraints if c["variable"] not in _GEO_VARS] + + needed_vars = set() + needed_vars.add(variable) + for c in non_geo: + needed_vars.add(c["variable"]) + + is_count = variable.endswith("_count") + if not is_count and variable not in hh_vars_cache: + try: + hh_vars_cache[variable] = sim.calculate( + variable, + map_to="household", + period=period, + ).values + except Exception: + pass + + for vname in needed_vars: + if vname not in person_vars_cache: + try: + person_vars_cache[vname] = sim.calculate( + vname, + map_to="person", + period=period, + ).values + except Exception: + pass + + per_hh = _calculate_target_values_standalone( + target_variable=variable, + non_geo_constraints=non_geo, + n_households=n_households, + hh_vars=hh_vars_cache, + person_vars=person_vars_cache, + entity_rel=entity_rel, + household_ids=household_ids, + 
variable_entity_map=variable_entity_map, + ) + + sim_value = float(np.dot(per_hh, hh_weight)) + + error = sim_value - target_value + abs_error = abs(error) + if target_value != 0: + rel_error = error / target_value + rel_abs_error = abs_error / abs(target_value) + else: + rel_error = float("inf") if error != 0 else 0.0 + rel_abs_error = float("inf") if abs_error != 0 else 0.0 + + target_name = UnifiedMatrixBuilder._make_target_name( + variable, + constraints, + ) + + sanity_check, sanity_reason = _run_sanity_check( + sim_value, + variable, + geo_level, + ) + + in_training = bool(training_arr[i]) + + results.append( + { + "area_type": area_type, + "area_id": display_id, + "variable": variable, + "target_name": target_name, + "period": int(row["period"]), + "target_value": target_value, + "sim_value": sim_value, + "error": error, + "rel_error": rel_error, + "abs_error": abs_error, + "rel_abs_error": rel_abs_error, + "sanity_check": sanity_check, + "sanity_reason": sanity_reason, + "in_training": in_training, + } + ) + + return results + + +def parse_args(argv=None): + parser = argparse.ArgumentParser( + description="Validate staging .h5 files against " + "calibration targets via sim.calculate()" + ) + parser.add_argument( + "--area-type", + default="states", + help="Comma-separated geo levels to validate: " + "states, districts (default: states)", + ) + parser.add_argument( + "--areas", + default=None, + help="Comma-separated state abbreviations or FIPS " + "(applies to all area types; all if omitted)", + ) + parser.add_argument( + "--hf-prefix", + default=DEFAULT_HF_PREFIX, + help="HuggingFace path prefix for .h5 files", + ) + parser.add_argument( + "--period", + type=int, + default=2024, + help="Tax year to validate (default: 2024)", + ) + parser.add_argument( + "--target-config", + default=DEFAULT_TARGET_CONFIG, + help="YAML config with exclude rules " + "(default: target_config_full.yaml)", + ) + parser.add_argument( + "--db-path", + default=DEFAULT_DB_PATH, + 
help="Path to policy_data.db", + ) + parser.add_argument( + "--output", + default="validation_results.csv", + help="Output CSV path", + ) + parser.add_argument( + "--sanity-only", + action="store_true", + help="Run only structural sanity checks (fast, " "no database needed)", + ) + return parser.parse_args(argv) + + +def _validate_single_area( + area_type, + area_id, + h5_path, + display_id, + area_targets, + area_training, + db_path, + period, +): + """Validate one area in an isolated process. + + Runs in a subprocess so all memory (Microsimulation, caches) + is fully reclaimed by the OS when the process exits. + """ + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(message)s", + ) + from policyengine_us import Microsimulation + from sqlalchemy import create_engine as _create_engine + + engine = _create_engine(f"sqlite:///{db_path}") + + logger.info("Loading sim from %s", h5_path) + try: + sim = Microsimulation(dataset=h5_path) + except Exception as e: + logger.error("Failed to load %s: %s", h5_path, e) + return [], 0.0 + + area_pop = 0.0 + if area_type == "states": + person_weight = sim.calculate( + "person_weight", + map_to="person", + period=period, + ).values.astype(np.float64) + area_pop = float(person_weight.sum()) + logger.info(" %s population: %.0f", display_id, area_pop) + + if len(area_targets) == 0: + logger.warning("No targets for %s, skipping", display_id) + return [], area_pop + + logger.info( + "Validating %d targets for %s", + len(area_targets), + display_id, + ) + + variable_entity_map = _build_variable_entity_map(sim) + + area_results = validate_area( + sim=sim, + targets_df=area_targets, + engine=engine, + area_type=area_type, + area_id=area_id, + display_id=display_id, + period=period, + training_mask=area_training, + variable_entity_map=variable_entity_map, + ) + + n_fail = sum( + 1 for r in area_results if r["sanity_check"] == "FAIL" + ) + logger.info( + " %s: %d results, %d sanity failures", + display_id, + 
len(area_results), + n_fail, + ) + + return area_results, area_pop + + +def _run_area_type( + area_type, + area_ids, + level_targets, + level_training, + engine, + args, + Microsimulation, +): + """Validate all areas for a single area_type. + + Each area runs in a subprocess so the OS fully reclaims + memory between states (avoids OOM on large states like CA). + """ + results = [] + total_weighted_pop = 0.0 + + for area_id in area_ids: + if area_type == "states": + abbr = FIPS_TO_ABBR.get(area_id, area_id) + h5_name = abbr + display_id = abbr + else: + h5_name = _geoid_to_district_filename(area_id) + display_id = h5_name + + h5_path = f"{args.hf_prefix}/{area_type}/{h5_name}.h5" + + area_mask = ( + level_targets["geographic_id"] == area_id + ).values + area_targets = level_targets[area_mask].reset_index( + drop=True + ) + area_training = level_training[area_mask] + + ctx = mp.get_context("spawn") + with ctx.Pool(1) as pool: + area_results, area_pop = pool.apply( + _validate_single_area, + ( + area_type, + area_id, + h5_path, + display_id, + area_targets, + area_training, + args.db_path, + args.period, + ), + ) + + total_weighted_pop += area_pop + results.extend(area_results) + + if area_type == "states" and total_weighted_pop > 0: + logger.info( + "TOTAL WEIGHTED POPULATION: %.0f (expect ~340M)", + total_weighted_pop, + ) + + return results + + +def _run_sanity_only(args): + """Run structural sanity checks on staging H5 files.""" + area_types = [t.strip() for t in args.area_type.split(",")] + state_fips_list = _resolve_state_fips(args.areas) + + total_failures = 0 + + for area_type in area_types: + if area_type == "states": + for fips in state_fips_list: + abbr = FIPS_TO_ABBR.get(fips, fips) + h5_url = f"{args.hf_prefix}/{area_type}/{abbr}.h5" + logger.info("Sanity-checking %s", h5_url) + + if h5_url.startswith("hf://"): + from huggingface_hub import hf_hub_download + import tempfile + + parts = h5_url[5:].split("/", 2) + repo = f"{parts[0]}/{parts[1]}" + path = 
parts[2] + local = hf_hub_download( + repo_id=repo, + filename=path, + repo_type="model", + token=os.environ.get("HUGGING_FACE_TOKEN"), + ) + else: + local = h5_url + + results = run_sanity_checks(local, args.period) + n_fail = sum(1 for r in results if r["status"] == "FAIL") + total_failures += n_fail + + for r in results: + if r["status"] != "PASS": + detail = f" — {r['detail']}" if r["detail"] else "" + logger.warning( + " %s [%s] %s%s", + abbr, + r["status"], + r["check"], + detail, + ) + + if n_fail == 0: + logger.info(" %s: all checks passed", abbr) + else: + logger.info( + "Sanity-only mode for %s not yet implemented", + area_type, + ) + + if total_failures > 0: + logger.error("%d total sanity failures", total_failures) + else: + logger.info("All sanity checks passed") + + +def main(argv=None): + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(message)s", + ) + + args = parse_args(argv) + logger.info("CLI args: %s", vars(args)) + + if args.sanity_only: + _run_sanity_only(args) + return + + from policyengine_us import Microsimulation + + engine = create_engine(f"sqlite:///{args.db_path}") + + all_targets = _query_all_active_targets(engine, args.period) + logger.info("Loaded %d active targets from DB", len(all_targets)) + + exclude_config = load_target_config(args.target_config) + exclude_rules = exclude_config.get("exclude", []) + if exclude_rules: + exc_mask = _match_rules(all_targets, exclude_rules) + all_targets = all_targets[~exc_mask].reset_index(drop=True) + logger.info("After exclusions: %d targets", len(all_targets)) + + include_rules = exclude_config.get("include", []) + if include_rules: + inc_mask = _match_rules(all_targets, include_rules) + all_targets = all_targets[inc_mask].reset_index(drop=True) + logger.info("After inclusions: %d targets", len(all_targets)) + + training_config = load_target_config(TRAINING_TARGET_CONFIG) + training_include = training_config.get("include", []) + if training_include: + 
training_mask = np.asarray( + _match_rules(all_targets, training_include), + dtype=bool, + ) + else: + training_mask = np.ones(len(all_targets), dtype=bool) + + area_types = [t.strip() for t in args.area_type.split(",")] + valid_types = {"states", "districts"} + for t in area_types: + if t not in valid_types: + logger.error( + "Unknown area-type '%s'. Use: %s", + t, + ", ".join(sorted(valid_types)), + ) + return + + all_results = [] + + for area_type in area_types: + geo_level = "state" if area_type == "states" else "district" + geo_mask = (all_targets["geo_level"] == geo_level).values + level_targets = all_targets[geo_mask].reset_index(drop=True) + level_training = training_mask[geo_mask] + + logger.info( + "%d targets at geo_level=%s", + len(level_targets), + geo_level, + ) + + if area_type == "states": + area_ids = _resolve_state_fips(args.areas) + else: + area_ids = _resolve_district_ids(engine, args.areas) + + logger.info( + "%s: %d areas to validate", + area_type, + len(area_ids), + ) + + results = _run_area_type( + area_type=area_type, + area_ids=area_ids, + level_targets=level_targets, + level_training=level_training, + engine=engine, + args=args, + Microsimulation=Microsimulation, + ) + all_results.extend(results) + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=CSV_COLUMNS) + writer.writeheader() + writer.writerows(all_results) + + logger.info("Wrote %d rows to %s", len(all_results), output_path) + + n_total_fail = sum(1 for r in all_results if r["sanity_check"] == "FAIL") + if n_total_fail > 0: + logger.warning( + "%d SANITY FAILURES across all areas", + n_total_fail, + ) + + +if __name__ == "__main__": + main() diff --git a/policyengine_us_data/datasets/cps/extended_cps.py b/policyengine_us_data/datasets/cps/extended_cps.py index b60fbf42f..b095a7f08 100644 --- a/policyengine_us_data/datasets/cps/extended_cps.py +++ 
b/policyengine_us_data/datasets/cps/extended_cps.py @@ -55,7 +55,7 @@ def generate(self): # Variables with formulas that must still be stored (e.g. IDs # needed by the dataset loader before formulas can run). - _KEEP_FORMULA_VARS = {"person_id"} + _KEEP_FORMULA_VARS = {"person_id", "is_pregnant"} @classmethod def _drop_formula_variables(cls, data): diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/__init__.py b/policyengine_us_data/datasets/cps/local_area_calibration/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/publish_local_area.py b/policyengine_us_data/datasets/cps/local_area_calibration/publish_local_area.py deleted file mode 100644 index 4963f3979..000000000 --- a/policyengine_us_data/datasets/cps/local_area_calibration/publish_local_area.py +++ /dev/null @@ -1,587 +0,0 @@ -""" -Publish local area H5 files to GCP and Hugging Face. - -Downloads calibration inputs from HF, generates state/district H5s -with checkpointing, and uploads to both destinations. 
- -Usage: - python publish_local_area.py [--skip-download] [--states-only] [--districts-only] -""" - -import os -import numpy as np -from pathlib import Path -from typing import List, Optional, Set - -from policyengine_us import Microsimulation -from policyengine_us_data.utils.huggingface import download_calibration_inputs -from policyengine_us_data.utils.data_upload import ( - upload_local_area_file, - upload_local_area_batch_to_hf, -) -from policyengine_us_data.datasets.cps.local_area_calibration.stacked_dataset_builder import ( - create_sparse_cd_stacked_dataset, - NYC_COUNTIES, - NYC_CDS, -) -from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( - get_all_cds_from_database, - STATE_CODES, -) - -CHECKPOINT_FILE = Path("completed_states.txt") -CHECKPOINT_FILE_DISTRICTS = Path("completed_districts.txt") -CHECKPOINT_FILE_CITIES = Path("completed_cities.txt") -WORK_DIR = Path("local_area_build") - - -def load_completed_states() -> set: - if CHECKPOINT_FILE.exists(): - content = CHECKPOINT_FILE.read_text().strip() - if content: - return set(content.split("\n")) - return set() - - -def record_completed_state(state_code: str): - with open(CHECKPOINT_FILE, "a") as f: - f.write(f"{state_code}\n") - - -def load_completed_districts() -> set: - if CHECKPOINT_FILE_DISTRICTS.exists(): - content = CHECKPOINT_FILE_DISTRICTS.read_text().strip() - if content: - return set(content.split("\n")) - return set() - - -def record_completed_district(district_name: str): - with open(CHECKPOINT_FILE_DISTRICTS, "a") as f: - f.write(f"{district_name}\n") - - -def load_completed_cities() -> set: - if CHECKPOINT_FILE_CITIES.exists(): - content = CHECKPOINT_FILE_CITIES.read_text().strip() - if content: - return set(content.split("\n")) - return set() - - -def record_completed_city(city_name: str): - with open(CHECKPOINT_FILE_CITIES, "a") as f: - f.write(f"{city_name}\n") - - -def build_state_h5( - state_code: str, - weights: np.ndarray, - cds_to_calibrate: 
List[str], - dataset_path: Path, - output_dir: Path, -) -> Optional[Path]: - """ - Build a single state H5 file (build only, no upload). - - Args: - state_code: Two-letter state code (e.g., "AL", "CA") - weights: Calibrated weight vector - cds_to_calibrate: Full list of CD GEOIDs from calibration - dataset_path: Path to base dataset H5 file - output_dir: Output directory for H5 file - - Returns: - Path to output H5 file if successful, None if no CDs found - """ - state_fips = None - for fips, code in STATE_CODES.items(): - if code == state_code: - state_fips = fips - break - - if state_fips is None: - print(f"Unknown state code: {state_code}") - return None - - cd_subset = [cd for cd in cds_to_calibrate if int(cd) // 100 == state_fips] - if not cd_subset: - print(f"No CDs found for {state_code}, skipping") - return None - - states_dir = output_dir / "states" - states_dir.mkdir(parents=True, exist_ok=True) - output_path = states_dir / f"{state_code}.h5" - - print(f"\n{'='*60}") - print(f"Building {state_code} ({len(cd_subset)} CDs)") - print(f"{'='*60}") - - create_sparse_cd_stacked_dataset( - weights, - cds_to_calibrate, - cd_subset=cd_subset, - dataset_path=str(dataset_path), - output_path=str(output_path), - ) - - return output_path - - -def build_district_h5( - cd_geoid: str, - weights: np.ndarray, - cds_to_calibrate: List[str], - dataset_path: Path, - output_dir: Path, -) -> Path: - """ - Build a single district H5 file (build only, no upload). 
- - Args: - cd_geoid: Congressional district GEOID (e.g., "0101" for AL-01) - weights: Calibrated weight vector - cds_to_calibrate: Full list of CD GEOIDs from calibration - dataset_path: Path to base dataset H5 file - output_dir: Output directory for H5 file - - Returns: - Path to output H5 file - """ - cd_int = int(cd_geoid) - state_fips = cd_int // 100 - district_num = cd_int % 100 - state_code = STATE_CODES.get(state_fips, str(state_fips)) - friendly_name = f"{state_code}-{district_num:02d}" - - districts_dir = output_dir / "districts" - districts_dir.mkdir(parents=True, exist_ok=True) - output_path = districts_dir / f"{friendly_name}.h5" - - print(f"\n{'='*60}") - print(f"Building {friendly_name}") - print(f"{'='*60}") - - create_sparse_cd_stacked_dataset( - weights, - cds_to_calibrate, - cd_subset=[cd_geoid], - dataset_path=str(dataset_path), - output_path=str(output_path), - ) - - return output_path - - -def build_city_h5( - city_name: str, - weights: np.ndarray, - cds_to_calibrate: List[str], - dataset_path: Path, - output_dir: Path, -) -> Optional[Path]: - """ - Build a city H5 file (build only, no upload). - - Currently supports NYC only. 
- - Args: - city_name: City name (currently only "NYC" supported) - weights: Calibrated weight vector - cds_to_calibrate: Full list of CD GEOIDs from calibration - dataset_path: Path to base dataset H5 file - output_dir: Output directory for H5 file - - Returns: - Path to output H5 file if successful, None otherwise - """ - if city_name != "NYC": - print(f"Unsupported city: {city_name}") - return None - - cd_subset = [cd for cd in cds_to_calibrate if cd in NYC_CDS] - if not cd_subset: - print("No NYC-related CDs found, skipping") - return None - - cities_dir = output_dir / "cities" - cities_dir.mkdir(parents=True, exist_ok=True) - output_path = cities_dir / "NYC.h5" - - print(f"\n{'='*60}") - print(f"Building NYC ({len(cd_subset)} CDs)") - print(f"{'='*60}") - - create_sparse_cd_stacked_dataset( - weights, - cds_to_calibrate, - cd_subset=cd_subset, - dataset_path=str(dataset_path), - output_path=str(output_path), - county_filter=NYC_COUNTIES, - ) - - return output_path - - -def get_district_friendly_name(cd_geoid: str) -> str: - """Convert GEOID to friendly name (e.g., '0101' -> 'AL-01').""" - cd_int = int(cd_geoid) - state_fips = cd_int // 100 - district_num = cd_int % 100 - state_code = STATE_CODES.get(state_fips, str(state_fips)) - return f"{state_code}-{district_num:02d}" - - -def build_and_upload_states( - weights_path: Path, - dataset_path: Path, - db_path: Path, - output_dir: Path, - completed_states: set, - hf_batch_size: int = 10, -): - """Build and upload state H5 files with checkpointing.""" - db_uri = f"sqlite:///{db_path}" - cds_to_calibrate = get_all_cds_from_database(db_uri) - w = np.load(weights_path) - - states_dir = output_dir / "states" - states_dir.mkdir(parents=True, exist_ok=True) - - hf_queue = [] # Queue for batched HuggingFace uploads - - for state_fips, state_code in STATE_CODES.items(): - if state_code in completed_states: - print(f"Skipping {state_code} (already completed)") - continue - - cd_subset = [ - cd for cd in cds_to_calibrate if 
int(cd) // 100 == state_fips - ] - if not cd_subset: - print(f"No CDs found for {state_code}, skipping") - continue - - output_path = states_dir / f"{state_code}.h5" - print(f"\n{'='*60}") - print(f"Building {state_code} ({len(cd_subset)} CDs)") - print(f"{'='*60}") - - try: - create_sparse_cd_stacked_dataset( - w, - cds_to_calibrate, - cd_subset=cd_subset, - dataset_path=str(dataset_path), - output_path=str(output_path), - ) - - print(f"Uploading {state_code}.h5 to GCP...") - upload_local_area_file(str(output_path), "states", skip_hf=True) - - # Queue for batched HuggingFace upload - hf_queue.append((str(output_path), "states")) - - record_completed_state(state_code) - print(f"Completed {state_code}") - - # Flush HF queue every batch_size files - if len(hf_queue) >= hf_batch_size: - print( - f"\nUploading batch of {len(hf_queue)} files to HuggingFace..." - ) - upload_local_area_batch_to_hf(hf_queue) - hf_queue = [] - - except Exception as e: - print(f"ERROR building {state_code}: {e}") - raise - - # Flush remaining files to HuggingFace - if hf_queue: - print( - f"\nUploading final batch of {len(hf_queue)} files to HuggingFace..." 
- ) - upload_local_area_batch_to_hf(hf_queue) - - -def build_and_upload_districts( - weights_path: Path, - dataset_path: Path, - db_path: Path, - output_dir: Path, - completed_districts: set, - hf_batch_size: int = 10, -): - """Build and upload district H5 files with checkpointing.""" - db_uri = f"sqlite:///{db_path}" - cds_to_calibrate = get_all_cds_from_database(db_uri) - w = np.load(weights_path) - - districts_dir = output_dir / "districts" - districts_dir.mkdir(parents=True, exist_ok=True) - - hf_queue = [] # Queue for batched HuggingFace uploads - - for i, cd_geoid in enumerate(cds_to_calibrate): - cd_int = int(cd_geoid) - state_fips = cd_int // 100 - district_num = cd_int % 100 - state_code = STATE_CODES.get(state_fips, str(state_fips)) - friendly_name = f"{state_code}-{district_num:02d}" - - if friendly_name in completed_districts: - print(f"Skipping {friendly_name} (already completed)") - continue - - output_path = districts_dir / f"{friendly_name}.h5" - print(f"\n{'='*60}") - print(f"[{i+1}/{len(cds_to_calibrate)}] Building {friendly_name}") - print(f"{'='*60}") - - try: - create_sparse_cd_stacked_dataset( - w, - cds_to_calibrate, - cd_subset=[cd_geoid], - dataset_path=str(dataset_path), - output_path=str(output_path), - ) - - print(f"Uploading {friendly_name}.h5 to GCP...") - upload_local_area_file(str(output_path), "districts", skip_hf=True) - - # Queue for batched HuggingFace upload - hf_queue.append((str(output_path), "districts")) - - record_completed_district(friendly_name) - print(f"Completed {friendly_name}") - - # Flush HF queue every batch_size files - if len(hf_queue) >= hf_batch_size: - print( - f"\nUploading batch of {len(hf_queue)} files to HuggingFace..." - ) - upload_local_area_batch_to_hf(hf_queue) - hf_queue = [] - - except Exception as e: - print(f"ERROR building {friendly_name}: {e}") - raise - - # Flush remaining files to HuggingFace - if hf_queue: - print( - f"\nUploading final batch of {len(hf_queue)} files to HuggingFace..." 
- ) - upload_local_area_batch_to_hf(hf_queue) - - -def build_and_upload_cities( - weights_path: Path, - dataset_path: Path, - db_path: Path, - output_dir: Path, - completed_cities: set, - hf_batch_size: int = 10, -): - """Build and upload city H5 files with checkpointing.""" - db_uri = f"sqlite:///{db_path}" - cds_to_calibrate = get_all_cds_from_database(db_uri) - w = np.load(weights_path) - - cities_dir = output_dir / "cities" - cities_dir.mkdir(parents=True, exist_ok=True) - - hf_queue = [] # Queue for batched HuggingFace uploads - - # NYC - if "NYC" in completed_cities: - print("Skipping NYC (already completed)") - else: - cd_subset = [cd for cd in cds_to_calibrate if cd in NYC_CDS] - if not cd_subset: - print("No NYC-related CDs found, skipping") - else: - output_path = cities_dir / "NYC.h5" - print(f"\n{'='*60}") - print(f"Building NYC ({len(cd_subset)} CDs)") - print(f"{'='*60}") - - try: - create_sparse_cd_stacked_dataset( - w, - cds_to_calibrate, - cd_subset=cd_subset, - dataset_path=str(dataset_path), - output_path=str(output_path), - county_filter=NYC_COUNTIES, - ) - - print("Uploading NYC.h5 to GCP...") - upload_local_area_file( - str(output_path), "cities", skip_hf=True - ) - - # Queue for batched HuggingFace upload - hf_queue.append((str(output_path), "cities")) - - record_completed_city("NYC") - print("Completed NYC") - - except Exception as e: - print(f"ERROR building NYC: {e}") - raise - - # Flush remaining files to HuggingFace - if hf_queue: - print( - f"\nUploading batch of {len(hf_queue)} city files to HuggingFace..." 
- ) - upload_local_area_batch_to_hf(hf_queue) - - -def main(): - import argparse - - parser = argparse.ArgumentParser( - description="Build and publish local area H5 files" - ) - parser.add_argument( - "--skip-download", - action="store_true", - help="Skip downloading inputs from HF (use existing files)", - ) - parser.add_argument( - "--states-only", - action="store_true", - help="Only build and upload state files", - ) - parser.add_argument( - "--districts-only", - action="store_true", - help="Only build and upload district files", - ) - parser.add_argument( - "--cities-only", - action="store_true", - help="Only build and upload city files (e.g., NYC)", - ) - parser.add_argument( - "--weights-path", - type=str, - help="Override path to weights file (for local testing)", - ) - parser.add_argument( - "--dataset-path", - type=str, - help="Override path to dataset file (for local testing)", - ) - parser.add_argument( - "--db-path", - type=str, - help="Override path to database file (for local testing)", - ) - args = parser.parse_args() - - WORK_DIR.mkdir(parents=True, exist_ok=True) - - if args.weights_path and args.dataset_path and args.db_path: - inputs = { - "weights": Path(args.weights_path), - "dataset": Path(args.dataset_path), - "database": Path(args.db_path), - } - print("Using provided paths:") - for key, path in inputs.items(): - print(f" {key}: {path}") - elif args.skip_download: - inputs = { - "weights": WORK_DIR / "w_district_calibration.npy", - "dataset": WORK_DIR / "stratified_extended_cps.h5", - "database": WORK_DIR / "policy_data.db", - } - print("Using existing files in work directory:") - for key, path in inputs.items(): - if not path.exists(): - raise FileNotFoundError(f"Expected file not found: {path}") - print(f" {key}: {path}") - else: - print("Downloading calibration inputs from Hugging Face...") - inputs = download_calibration_inputs(str(WORK_DIR)) - for key, path in inputs.items(): - inputs[key] = Path(path) - - sim = 
Microsimulation(dataset=str(inputs["dataset"])) - n_hh = sim.calculate("household_id", map_to="household").shape[0] - print(f"\nBase dataset has {n_hh:,} households") - - # Determine what to build based on flags - build_states = not args.districts_only and not args.cities_only - build_districts = not args.states_only and not args.cities_only - build_cities = not args.states_only and not args.districts_only - - # If a specific *-only flag is set, only build that type - if args.states_only: - build_states = True - build_districts = False - build_cities = False - elif args.districts_only: - build_states = False - build_districts = True - build_cities = False - elif args.cities_only: - build_states = False - build_districts = False - build_cities = True - - if build_states: - print("\n" + "=" * 60) - print("BUILDING STATE FILES") - print("=" * 60) - completed_states = load_completed_states() - print(f"Already completed: {len(completed_states)} states") - build_and_upload_states( - inputs["weights"], - inputs["dataset"], - inputs["database"], - WORK_DIR, - completed_states, - ) - - if build_districts: - print("\n" + "=" * 60) - print("BUILDING DISTRICT FILES") - print("=" * 60) - completed_districts = load_completed_districts() - print(f"Already completed: {len(completed_districts)} districts") - build_and_upload_districts( - inputs["weights"], - inputs["dataset"], - inputs["database"], - WORK_DIR, - completed_districts, - ) - - if build_cities: - print("\n" + "=" * 60) - print("BUILDING CITY FILES") - print("=" * 60) - completed_cities = load_completed_cities() - print(f"Already completed: {len(completed_cities)} cities") - build_and_upload_cities( - inputs["weights"], - inputs["dataset"], - inputs["database"], - WORK_DIR, - completed_cities, - ) - - print("\n" + "=" * 60) - print("ALL DONE!") - print("=" * 60) - - -if __name__ == "__main__": - main() diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/stacked_dataset_builder.py 
b/policyengine_us_data/datasets/cps/local_area_calibration/stacked_dataset_builder.py deleted file mode 100644 index 010e151f3..000000000 --- a/policyengine_us_data/datasets/cps/local_area_calibration/stacked_dataset_builder.py +++ /dev/null @@ -1,932 +0,0 @@ -""" -Create a sparse congressional district-stacked dataset with non-zero weight -households. -""" - -import os -import numpy as np -import pandas as pd -import h5py -from pathlib import Path -from policyengine_us import Microsimulation -from policyengine_core.data.dataset import Dataset -from policyengine_core.enums import Enum -from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( - get_all_cds_from_database, - get_calculated_variables, - STATE_CODES, - STATE_FIPS_TO_NAME, - STATE_FIPS_TO_CODE, - load_cd_geoadj_values, - calculate_spm_thresholds_for_cd, -) -from policyengine_us.variables.household.demographic.geographic.county.county_enum import ( - County, -) -from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( - assign_geography_for_cd, - get_county_filter_probability, - get_filtered_block_distribution, -) - -NYC_COUNTIES = { - "QUEENS_COUNTY_NY", - "BRONX_COUNTY_NY", - "RICHMOND_COUNTY_NY", - "NEW_YORK_COUNTY_NY", - "KINGS_COUNTY_NY", -} - -NYC_CDS = [ - "3603", - "3605", - "3606", - "3607", - "3608", - "3609", - "3610", - "3611", - "3612", - "3613", - "3614", - "3615", - "3616", -] - - -def get_county_name(county_index: int) -> str: - """Convert county enum index back to name.""" - return County._member_names_[county_index] - - -def create_sparse_cd_stacked_dataset( - w, - cds_to_calibrate, - cd_subset=None, - output_path=None, - dataset_path=None, - county_filter=None, - seed: int = 42, -): - """ - Create a SPARSE congressional district-stacked dataset using DataFrame approach. - - Args: - w: Calibrated weight vector from L0 calibration. 
Shape is (n_cds * n_households,), - reshaped internally to (n_cds, n_households) using cds_to_calibrate ordering. - cds_to_calibrate: Ordered list of CD GEOID codes that defines the row ordering - of the weight matrix. Required to correctly index into w for any cd_subset. - cd_subset: Optional list of CD GEOIDs to include in output (must be subset of - cds_to_calibrate). If None, includes all CDs. - output_path: Where to save the sparse CD-stacked .h5 file. - dataset_path: Path to the base .h5 dataset used during calibration. - county_filter: Optional set of county names to filter to. Only households - assigned to these counties will be included. Used for city-level datasets. - seed: Base random seed for county assignment. Each CD gets seed + int(cd_geoid) - for deterministic, order-independent results. Default 42. - - Returns: - output_path: Path to the saved .h5 file. - """ - - # Handle CD subset filtering - if cd_subset is not None: - # Validate that requested CDs are in the calibration - for cd in cd_subset: - if cd not in cds_to_calibrate: - raise ValueError(f"CD {cd} not in calibrated CDs list") - - # Get indices of requested CDs - cd_indices = [cds_to_calibrate.index(cd) for cd in cd_subset] - cds_to_process = cd_subset - - print( - f"Processing subset of {len(cd_subset)} CDs: {', '.join(cd_subset[:5])}..." 
- ) - else: - # Process all CDs - cd_indices = list(range(len(cds_to_calibrate))) - cds_to_process = cds_to_calibrate - print( - f"Processing all {len(cds_to_calibrate)} congressional districts" - ) - - # Generate output path if not provided - if output_path is None: - raise ValueError("No output .h5 path given") - print(f"Output path: {output_path}") - - # Check that output directory exists, create if needed - output_dir_path = os.path.dirname(output_path) - if output_dir_path and not os.path.exists(output_dir_path): - print(f"Creating output directory: {output_dir_path}") - os.makedirs(output_dir_path, exist_ok=True) - - # Load the original simulation - base_sim = Microsimulation(dataset=dataset_path) - - household_ids = base_sim.calculate( - "household_id", map_to="household" - ).values - n_households_orig = len(household_ids) - - # From the base sim, create mapping from household ID to index for proper filtering - hh_id_to_idx = {int(hh_id): idx for idx, hh_id in enumerate(household_ids)} - - # Infer the number of households from weight vector and CD count - if len(w) % len(cds_to_calibrate) != 0: - raise ValueError( - f"Weight vector length ({len(w):,}) is not evenly divisible by " - f"number of CDs ({len(cds_to_calibrate)}). Cannot determine household count." 
- ) - n_households_from_weights = len(w) // len(cds_to_calibrate) - - if n_households_from_weights != n_households_orig: - raise ValueError( - "Households from base data set do not match households from weights" - ) - - print(f"\nOriginal dataset has {n_households_orig:,} households") - - # Process the weight vector to understand active household-CD pairs - W_full = w.reshape(len(cds_to_calibrate), n_households_orig) - # (436, 10580) - - # Extract only the CDs we want to process - if cd_subset is not None: - W = W_full[cd_indices, :] - print( - f"Extracted weights for {len(cd_indices)} CDs from full weight matrix" - ) - else: - W = W_full - - # Count total active weights: i.e., number of active households - total_active_weights = np.sum(W > 0) - total_weight_in_W = np.sum(W) - print(f"Total active household-CD pairs: {total_active_weights:,}") - print(f"Total weight in W matrix: {total_weight_in_W:,.0f}") - - cd_geoadj_values = load_cd_geoadj_values(cds_to_calibrate) - - # Collect DataFrames for each CD - cd_dfs = [] - total_kept_households = 0 - time_period = int(base_sim.default_calculation_period) - - for idx, cd_geoid in enumerate(cds_to_process): - # Progress every 10 CDs and at the end ---- - if (idx + 1) % 10 == 0 or (idx + 1) == len(cds_to_process): - print( - f"Processing CD {cd_geoid} ({idx + 1}/{len(cds_to_process)})..." - ) - - # Get the correct index in the weight matrix - cd_idx = idx # Index in our filtered W matrix - - # Get ALL households with non-zero weight in this CD - active_household_indices = np.where(W[cd_idx, :] > 0)[0] - - if len(active_household_indices) == 0: - continue - - # Get the household IDs for active households - active_household_ids = set( - household_ids[hh_idx] for hh_idx in active_household_indices - ) - - # Fresh simulation per CD is necessary because: - # 1. Each CD needs different state_fips, county, and CD values set - # 2. Calculated variables (SNAP, Medicaid, etc.) 
must be invalidated - # and recalculated with the new geographic inputs - # 3. Reusing a simulation would retain stale cached calculations - # Memory impact: ~50MB per simulation, but allows correct state-specific - # benefit calculations. Total memory scales with CD count in cd_subset. - cd_sim = Microsimulation(dataset=dataset_path) - - # First, create hh_df with CALIBRATED weights from the W matrix - household_ids_in_sim = cd_sim.calculate( - "household_id", map_to="household" - ).values - - # Get this CD's calibrated weights from the weight matrix - calibrated_weights_for_cd = W[ - cd_idx, : - ].copy() # Get this CD's row from weight matrix - - # For city datasets: scale weights by P(target|CD) - # This preserves the representative sample while adjusting for target population - if county_filter is not None: - p_target = get_county_filter_probability(cd_geoid, county_filter) - if p_target == 0: - # CD has no overlap with target area, skip entirely - continue - calibrated_weights_for_cd = calibrated_weights_for_cd * p_target - - # Map the calibrated weights to household IDs - hh_weight_values = [] - for hh_id in household_ids_in_sim: - hh_idx = hh_id_to_idx[int(hh_id)] # Get index in weight matrix - hh_weight_values.append(calibrated_weights_for_cd[hh_idx]) - - entity_rel = pd.DataFrame( - { - "person_id": cd_sim.calculate( - "person_id", map_to="person" - ).values, - "household_id": cd_sim.calculate( - "household_id", map_to="person" - ).values, - "tax_unit_id": cd_sim.calculate( - "tax_unit_id", map_to="person" - ).values, - "spm_unit_id": cd_sim.calculate( - "spm_unit_id", map_to="person" - ).values, - "family_id": cd_sim.calculate( - "family_id", map_to="person" - ).values, - "marital_unit_id": cd_sim.calculate( - "marital_unit_id", map_to="person" - ).values, - } - ) - - hh_df = pd.DataFrame( - { - "household_id": household_ids_in_sim, - "household_weight": hh_weight_values, - } - ) - counts = ( - entity_rel.groupby("household_id")["person_id"] - .size() - 
.reset_index(name="persons_per_hh") - ) - hh_df = hh_df.merge(counts) - hh_df["per_person_hh_weight"] = ( - hh_df.household_weight / hh_df.persons_per_hh - ) - - # SET WEIGHTS IN SIMULATION BEFORE EXTRACTING DATAFRAME - # This is the key - set_input updates the simulation's internal state - - non_household_cols = [ - "person_id", - "tax_unit_id", - "spm_unit_id", - "family_id", - "marital_unit_id", - ] - - new_weights_per_id = {} - for col in non_household_cols: - person_counts = ( - entity_rel.groupby(col)["person_id"] - .size() - .reset_index(name="person_id_count") - ) - # Below: drop duplicates to undo the broadcast join done in entity_rel - id_link = entity_rel[["household_id", col]].drop_duplicates() - hh_info = id_link.merge(hh_df) - - hh_info2 = hh_info.merge(person_counts, on=col) - if col == "person_id": - # Person weight = household weight (each person represents same count as their household) - hh_info2["id_weight"] = hh_info2.household_weight - else: - hh_info2["id_weight"] = ( - hh_info2.per_person_hh_weight * hh_info2.person_id_count - ) - new_weights_per_id[col] = hh_info2.id_weight - - cd_sim.set_input( - "household_weight", time_period, hh_df.household_weight.values - ) - cd_sim.set_input( - "person_weight", time_period, new_weights_per_id["person_id"] - ) - cd_sim.set_input( - "tax_unit_weight", time_period, new_weights_per_id["tax_unit_id"] - ) - cd_sim.set_input( - "spm_unit_weight", time_period, new_weights_per_id["spm_unit_id"] - ) - cd_sim.set_input( - "marital_unit_weight", - time_period, - new_weights_per_id["marital_unit_id"], - ) - cd_sim.set_input( - "family_weight", time_period, new_weights_per_id["family_id"] - ) - - # Extract state from CD GEOID and update simulation BEFORE calling to_input_dataframe() - # This ensures calculated variables (SNAP, Medicaid) use the correct state - cd_geoid_int = int(cd_geoid) - state_fips = cd_geoid_int // 100 - - cd_sim.set_input( - "state_fips", - time_period, - np.full(n_households_orig, 
state_fips, dtype=np.int32), - ) - cd_sim.set_input( - "congressional_district_geoid", - time_period, - np.full(n_households_orig, cd_geoid_int, dtype=np.int32), - ) - - # Assign all geography using census block assignment - # For city datasets: use only blocks in target counties - if county_filter is not None: - filtered_dist = get_filtered_block_distribution( - cd_geoid, county_filter - ) - if not filtered_dist: - # Should not happen if we already checked p_target > 0 - continue - geography = assign_geography_for_cd( - cd_geoid=cd_geoid, - n_households=n_households_orig, - seed=seed + int(cd_geoid), - distributions={cd_geoid: filtered_dist}, - ) - else: - geography = assign_geography_for_cd( - cd_geoid=cd_geoid, - n_households=n_households_orig, - seed=seed + int(cd_geoid), - ) - # Set county using indices for backwards compatibility with PolicyEngine-US - cd_sim.set_input("county", time_period, geography["county_index"]) - - # Set all other geography variables from block assignment - cd_sim.set_input("block_geoid", time_period, geography["block_geoid"]) - cd_sim.set_input("tract_geoid", time_period, geography["tract_geoid"]) - cd_sim.set_input("cbsa_code", time_period, geography["cbsa_code"]) - cd_sim.set_input("sldu", time_period, geography["sldu"]) - cd_sim.set_input("sldl", time_period, geography["sldl"]) - cd_sim.set_input("place_fips", time_period, geography["place_fips"]) - cd_sim.set_input("vtd", time_period, geography["vtd"]) - cd_sim.set_input("puma", time_period, geography["puma"]) - cd_sim.set_input("zcta", time_period, geography["zcta"]) - - # Note: We no longer use binary filtering for county_filter. - # Instead, weights are scaled by P(target|CD) and all households - # are included to avoid sample selection bias. 
- - geoadj = cd_geoadj_values[cd_geoid] - new_spm_thresholds = calculate_spm_thresholds_for_cd( - cd_sim, time_period, geoadj, year=time_period - ) - cd_sim.set_input( - "spm_unit_spm_threshold", time_period, new_spm_thresholds - ) - - # Delete cached calculated variables to ensure they're recalculated - # with new state and county. Exclude 'county' itself since we just set it. - for var in get_calculated_variables(cd_sim): - if var != "county": - cd_sim.delete_arrays(var) - - # Now extract the dataframe - calculated vars will use the updated state - df = cd_sim.to_input_dataframe() - - assert df.shape[0] == entity_rel.shape[0] # df is at the person level - - # Column names follow pattern: variable__year - hh_id_col = f"household_id__{time_period}" - cd_geoid_col = f"congressional_district_geoid__{time_period}" - hh_weight_col = f"household_weight__{time_period}" - person_weight_col = f"person_weight__{time_period}" - tax_unit_weight_col = f"tax_unit_weight__{time_period}" - person_id_col = f"person_id__{time_period}" - tax_unit_id_col = f"tax_unit_id__{time_period}" - - state_fips_col = f"state_fips__{time_period}" - state_name_col = f"state_name__{time_period}" - state_code_col = f"state_code__{time_period}" - - # Filter to only active households in this CD - df_filtered = df[df[hh_id_col].isin(active_household_ids)].copy() - - # Update congressional_district_geoid to target CD - df_filtered[cd_geoid_col] = int(cd_geoid) - - # Update state variables for consistency - df_filtered[state_fips_col] = state_fips - if state_fips in STATE_FIPS_TO_NAME: - df_filtered[state_name_col] = STATE_FIPS_TO_NAME[state_fips] - if state_fips in STATE_FIPS_TO_CODE: - df_filtered[state_code_col] = STATE_FIPS_TO_CODE[state_fips] - - cd_dfs.append(df_filtered) - total_kept_households += len(df_filtered[hh_id_col].unique()) - - print(f"\nCombining {len(cd_dfs)} CD DataFrames...") - print(f"Total households across all CDs: {total_kept_households:,}") - - # Combine all CD DataFrames - 
combined_df = pd.concat(cd_dfs, ignore_index=True) - print(f"Combined DataFrame shape: {combined_df.shape}") - - # REINDEX ALL IDs TO PREVENT OVERFLOW AND HANDLE DUPLICATES - print("\nReindexing all entity IDs using 25k ranges per CD...") - - # Column names - hh_id_col = f"household_id__{time_period}" - person_id_col = f"person_id__{time_period}" - person_hh_id_col = f"person_household_id__{time_period}" - tax_unit_id_col = f"tax_unit_id__{time_period}" - person_tax_unit_col = f"person_tax_unit_id__{time_period}" - spm_unit_id_col = f"spm_unit_id__{time_period}" - person_spm_unit_col = f"person_spm_unit_id__{time_period}" - marital_unit_id_col = f"marital_unit_id__{time_period}" - person_marital_unit_col = f"person_marital_unit_id__{time_period}" - family_id_col = f"family_id__{time_period}" - person_family_col = f"person_family_id__{time_period}" - cd_geoid_col = f"congressional_district_geoid__{time_period}" - - # Build CD index mapping from cds_to_calibrate (avoids database dependency) - cds_sorted = sorted(cds_to_calibrate) - cd_to_index = {cd: idx for idx, cd in enumerate(cds_sorted)} - - # Create household mapping for CSV export - household_mapping = [] - - # First, create a unique row identifier to track relationships - combined_df["_row_idx"] = range(len(combined_df)) - - # Group by household ID AND congressional district to create unique household-CD pairs - hh_groups = ( - combined_df.groupby([hh_id_col, cd_geoid_col])["_row_idx"] - .apply(list) - .to_dict() - ) - - # Assign new household IDs using 25k ranges per CD - hh_row_to_new_id = {} - cd_hh_counters = {} # Track how many households assigned per CD - - for (old_hh_id, cd_geoid), row_indices in hh_groups.items(): - # Calculate the ID range for this CD directly (avoiding function call) - cd_str = str(int(cd_geoid)) - cd_idx = cd_to_index[cd_str] - start_id = cd_idx * 25_000 - end_id = start_id + 24_999 - - # Get the next available ID in this CD's range - if cd_str not in cd_hh_counters: - 
cd_hh_counters[cd_str] = 0 - - new_hh_id = start_id + cd_hh_counters[cd_str] - - # Check we haven't exceeded the range - if new_hh_id > end_id: - raise ValueError( - f"CD {cd_str} exceeded its 25k household allocation" - ) - - # All rows in the same household-CD pair get the SAME new ID - for row_idx in row_indices: - hh_row_to_new_id[row_idx] = new_hh_id - - # Save the mapping - household_mapping.append( - { - "new_household_id": new_hh_id, - "original_household_id": int(old_hh_id), - "congressional_district": cd_str, - "state_fips": int(cd_str) // 100, - } - ) - - cd_hh_counters[cd_str] += 1 - - # Apply new household IDs based on row index - combined_df["_new_hh_id"] = combined_df["_row_idx"].map(hh_row_to_new_id) - - # Update household IDs - combined_df[hh_id_col] = combined_df["_new_hh_id"] - - # Update person household references - since persons are already in their households, - # person_household_id should just match the household_id of their row - combined_df[person_hh_id_col] = combined_df["_new_hh_id"] - - # Report statistics - total_households = sum(cd_hh_counters.values()) - print( - f" Created {total_households:,} unique households across {len(cd_hh_counters)} CDs" - ) - - # Now handle persons with same 25k range approach - VECTORIZED - print(" Reindexing persons using 25k ranges...") - - # OFFSET PERSON IDs by 5 million to avoid collision with household IDs - PERSON_ID_OFFSET = 5_000_000 - - # Group by CD and assign IDs in bulk for each CD - for cd_geoid_val in combined_df[cd_geoid_col].unique(): - cd_str = str(int(cd_geoid_val)) - - # Calculate the ID range for this CD directly - cd_idx = cd_to_index[cd_str] - start_id = cd_idx * 25_000 + PERSON_ID_OFFSET # Add offset for persons - end_id = start_id + 24_999 - - # Get all rows for this CD - cd_mask = combined_df[cd_geoid_col] == cd_geoid_val - n_persons_in_cd = cd_mask.sum() - - # Check we won't exceed the range - if n_persons_in_cd > (end_id - start_id + 1): - raise ValueError( - f"CD {cd_str} has 
{n_persons_in_cd} persons, exceeds 25k allocation" - ) - - # Create sequential IDs for this CD - new_person_ids = np.arange( - start_id, start_id + n_persons_in_cd, dtype=np.int32 - ) - - # Assign all at once using loc - combined_df.loc[cd_mask, person_id_col] = new_person_ids - - # Reindex sub-household entities using vectorized groupby().ngroup() - # This assigns unique IDs to each (household_id, original_entity_id) pair, - # which correctly handles the same original household appearing in multiple CDs - entity_configs = [ - ("tax units", person_tax_unit_col, tax_unit_id_col), - ("SPM units", person_spm_unit_col, spm_unit_id_col), - ("marital units", person_marital_unit_col, marital_unit_id_col), - ("families", person_family_col, family_id_col), - ] - - for entity_name, person_col, entity_col in entity_configs: - print(f" Reindexing {entity_name}...") - # Group by (household_id, original_entity_id) and assign unique group numbers - new_ids = combined_df.groupby( - [hh_id_col, person_col], sort=False - ).ngroup() - combined_df[person_col] = new_ids - if entity_col in combined_df.columns: - combined_df[entity_col] = new_ids - - # Clean up temporary columns - temp_cols = [col for col in combined_df.columns if col.startswith("_")] - combined_df = combined_df.drop(columns=temp_cols) - - print(f" Final persons: {len(combined_df):,}") - print(f" Final households: {total_households:,}") - print(f" Final tax units: {combined_df[person_tax_unit_col].nunique():,}") - print(f" Final SPM units: {combined_df[person_spm_unit_col].nunique():,}") - print( - f" Final marital units: {combined_df[person_marital_unit_col].nunique():,}" - ) - print(f" Final families: {combined_df[person_family_col].nunique():,}") - - # Check weights in combined_df AFTER reindexing - print(f"\nWeights in combined_df AFTER reindexing:") - print(f" HH weight sum: {combined_df[hh_weight_col].sum()/1e6:.2f}M") - print( - f" Person weight sum: {combined_df[person_weight_col].sum()/1e6:.2f}M" - ) - print( - 
f" Ratio: {combined_df[person_weight_col].sum() / combined_df[hh_weight_col].sum():.2f}" - ) - - # Verify no overflow risk - max_person_id = combined_df[person_id_col].max() - print(f"\nOverflow check:") - print(f" Max person ID after reindexing: {max_person_id:,}") - print(f" Max person ID × 100: {max_person_id * 100:,}") - print(f" int32 max: {2_147_483_647:,}") - if max_person_id * 100 < 2_147_483_647: - print(" ✓ No overflow risk!") - else: - print(" ⚠️ WARNING: Still at risk of overflow!") - - # Create Dataset from combined DataFrame - print("\nCreating Dataset from combined DataFrame...") - sparse_dataset = Dataset.from_dataframe(combined_df, time_period) - - # Build a simulation to convert to h5 - print("Building simulation from Dataset...") - sparse_sim = Microsimulation() - sparse_sim.dataset = sparse_dataset - sparse_sim.build_from_dataset() - - # Save to h5 file - print(f"\nSaving to {output_path}...") - data = {} - - # Only save input variables (not calculated/derived variables) - # Calculated variables like state_name, state_code will be recalculated on load - vars_to_save = set(base_sim.input_variables) - print(f"Found {len(vars_to_save)} input variables to save") - - # congressional_district_geoid isn't in the original microdata and has no formula, - # so it's not in input_vars. Since we set it explicitly during stacking, save it. 
- vars_to_save.add("congressional_district_geoid") - - # county is set explicitly with assign_counties_for_cd, must be saved - vars_to_save.add("county") - - # spm_unit_spm_threshold is recalculated with CD-specific geo-adjustment - vars_to_save.add("spm_unit_spm_threshold") - - # Add all geography variables set during block assignment - vars_to_save.add("block_geoid") - vars_to_save.add("tract_geoid") - vars_to_save.add("cbsa_code") - vars_to_save.add("sldu") - vars_to_save.add("sldl") - vars_to_save.add("place_fips") - vars_to_save.add("vtd") - vars_to_save.add("puma") - vars_to_save.add("zcta") - - variables_saved = 0 - variables_skipped = 0 - - for variable in sparse_sim.tax_benefit_system.variables: - if variable not in vars_to_save: - variables_skipped += 1 - continue - - # Only process variables that have actual data - data[variable] = {} - for period in sparse_sim.get_holder(variable).get_known_periods(): - values = sparse_sim.get_holder(variable).get_array(period) - - # Handle different value types - if ( - sparse_sim.tax_benefit_system.variables.get( - variable - ).value_type - in (Enum, str) - and variable != "county_fips" - ): - # Handle EnumArray objects - if hasattr(values, "decode_to_str"): - values = values.decode_to_str().astype("S") - else: - # Already a regular numpy array, just convert to string type - values = values.astype("S") - elif variable == "county_fips": - values = values.astype("int32") - else: - values = np.array(values) - - if values is not None: - data[variable][period] = values - variables_saved += 1 - - if len(data[variable]) == 0: - del data[variable] - - print(f"Variables saved: {variables_saved}") - print(f"Variables skipped: {variables_skipped}") - - # Write to h5 - with h5py.File(output_path, "w") as f: - for variable, periods in data.items(): - grp = f.create_group(variable) - for period, values in periods.items(): - grp.create_dataset(str(period), data=values) - - print(f"Sparse CD-stacked dataset saved successfully!") - - 
# Save household mapping to CSV in a mappings subdirectory - mapping_df = pd.DataFrame(household_mapping) - output_dir = os.path.dirname(output_path) - mappings_dir = ( - os.path.join(output_dir, "mappings") if output_dir else "mappings" - ) - os.makedirs(mappings_dir, exist_ok=True) - csv_filename = os.path.basename(output_path).replace( - ".h5", "_household_mapping.csv" - ) - csv_path = os.path.join(mappings_dir, csv_filename) - mapping_df.to_csv(csv_path, index=False) - print(f"Household mapping saved to {csv_path}") - - # Verify the saved file - print("\nVerifying saved file...") - with h5py.File(output_path, "r") as f: - if "household_id" in f and str(time_period) in f["household_id"]: - hh_ids = f["household_id"][str(time_period)][:] - print(f" Final households: {len(hh_ids):,}") - if "person_id" in f and str(time_period) in f["person_id"]: - person_ids = f["person_id"][str(time_period)][:] - print(f" Final persons: {len(person_ids):,}") - if ( - "household_weight" in f - and str(time_period) in f["household_weight"] - ): - weights = f["household_weight"][str(time_period)][:] - print( - f" Total population (from household weights): {np.sum(weights):,.0f}" - ) - if "person_weight" in f and str(time_period) in f["person_weight"]: - person_weights = f["person_weight"][str(time_period)][:] - print( - f" Total population (from person weights): {np.sum(person_weights):,.0f}" - ) - print( - f" Average persons per household: {np.sum(person_weights) / np.sum(weights):.2f}" - ) - - return output_path - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser( - description="Create sparse CD-stacked datasets" - ) - parser.add_argument( - "--weights-path", required=True, help="Path to w_cd.npy file" - ) - parser.add_argument( - "--dataset-path", - required=True, - help="Path to stratified dataset .h5 file", - ) - parser.add_argument( - "--db-path", required=True, help="Path to policy_data.db" - ) - parser.add_argument( - "--output-dir", - 
default="./temp", - help="Output directory for files", - ) - parser.add_argument( - "--mode", - choices=[ - "national", - "states", - "cds", - "single-cd", - "single-state", - "nyc", - ], - default="national", - help="Output mode: national (one file), states (per-state files), cds (per-CD files), single-cd (one CD), single-state (one state), nyc (NYC only)", - ) - parser.add_argument( - "--cd", - type=str, - help="Single CD GEOID to process (only used with --mode single-cd)", - ) - parser.add_argument( - "--state", - type=str, - help="State code to process, e.g. RI, CA, NC (only used with --mode single-state)", - ) - - args = parser.parse_args() - dataset_path_str = args.dataset_path - weights_path_str = args.weights_path - db_path = Path(args.db_path).resolve() - output_dir = args.output_dir - mode = args.mode - - os.makedirs(output_dir, exist_ok=True) - - # Load weights - w = np.load(weights_path_str) - db_uri = f"sqlite:///{db_path}" - - # Get list of CDs from database - cds_to_calibrate = get_all_cds_from_database(db_uri) - print(f"Found {len(cds_to_calibrate)} congressional districts") - - # Verify dimensions - assert_sim = Microsimulation(dataset=dataset_path_str) - n_hh = assert_sim.calculate("household_id", map_to="household").shape[0] - expected_length = len(cds_to_calibrate) * n_hh - - if len(w) != expected_length: - raise ValueError( - f"Weight vector length ({len(w):,}) doesn't match expected ({expected_length:,})" - ) - - if mode == "national": - output_path = f"{output_dir}/national.h5" - print(f"\nCreating national dataset with all CDs: {output_path}") - create_sparse_cd_stacked_dataset( - w, - cds_to_calibrate, - dataset_path=dataset_path_str, - output_path=output_path, - ) - - elif mode == "states": - for state_fips, state_code in STATE_CODES.items(): - cd_subset = [ - cd for cd in cds_to_calibrate if int(cd) // 100 == state_fips - ] - if not cd_subset: - continue - output_path = f"{output_dir}/{state_code}.h5" - print(f"\nCreating {state_code} 
dataset: {output_path}") - create_sparse_cd_stacked_dataset( - w, - cds_to_calibrate, - cd_subset=cd_subset, - dataset_path=dataset_path_str, - output_path=output_path, - ) - - elif mode == "cds": - for i, cd_geoid in enumerate(cds_to_calibrate): - # Convert GEOID to friendly name: 3705 -> NC-05 - cd_int = int(cd_geoid) - state_fips = cd_int // 100 - district_num = cd_int % 100 - state_code = STATE_CODES.get(state_fips, str(state_fips)) - friendly_name = f"{state_code}-{district_num:02d}" - - output_path = f"{output_dir}/{friendly_name}.h5" - print( - f"\n[{i+1}/{len(cds_to_calibrate)}] Creating {friendly_name}.h5 (GEOID {cd_geoid})" - ) - create_sparse_cd_stacked_dataset( - w, - cds_to_calibrate, - cd_subset=[cd_geoid], - dataset_path=dataset_path_str, - output_path=output_path, - ) - - elif mode == "single-cd": - if not args.cd: - raise ValueError("--cd required with --mode single-cd") - if args.cd not in cds_to_calibrate: - raise ValueError(f"CD {args.cd} not in calibrated CDs list") - output_path = f"{output_dir}/{args.cd}.h5" - print(f"\nCreating single CD dataset: {output_path}") - create_sparse_cd_stacked_dataset( - w, - cds_to_calibrate, - cd_subset=[args.cd], - dataset_path=dataset_path_str, - output_path=output_path, - ) - - elif mode == "single-state": - if not args.state: - raise ValueError("--state required with --mode single-state") - # Find FIPS code for this state - state_code_upper = args.state.upper() - state_fips = None - for fips, code in STATE_CODES.items(): - if code == state_code_upper: - state_fips = fips - break - if state_fips is None: - raise ValueError(f"Unknown state code: {args.state}") - - cd_subset = [ - cd for cd in cds_to_calibrate if int(cd) // 100 == state_fips - ] - if not cd_subset: - raise ValueError(f"No CDs found for state {state_code_upper}") - - output_path = f"{output_dir}/{state_code_upper}.h5" - print( - f"\nCreating {state_code_upper} dataset with {len(cd_subset)} CDs: {output_path}" - ) - 
create_sparse_cd_stacked_dataset( - w, - cds_to_calibrate, - cd_subset=cd_subset, - dataset_path=dataset_path_str, - output_path=output_path, - ) - - elif mode == "nyc": - cd_subset = [cd for cd in cds_to_calibrate if cd in NYC_CDS] - if not cd_subset: - raise ValueError("No NYC-related CDs found in calibrated CDs list") - - output_path = f"{output_dir}/NYC.h5" - print( - f"\nCreating NYC dataset with {len(cd_subset)} CDs: {output_path}" - ) - print(f" CDs: {', '.join(cd_subset)}") - print(" Filtering to NYC counties only") - - create_sparse_cd_stacked_dataset( - w, - cds_to_calibrate, - cd_subset=cd_subset, - dataset_path=dataset_path_str, - output_path=output_path, - county_filter=NYC_COUNTIES, - ) - - print("\nDone!") diff --git a/policyengine_us_data/db/create_initial_strata.py b/policyengine_us_data/db/create_initial_strata.py index 0b9ae8a6d..253262c90 100644 --- a/policyengine_us_data/db/create_initial_strata.py +++ b/policyengine_us_data/db/create_initial_strata.py @@ -40,8 +40,9 @@ def fetch_congressional_districts(year): df["state_fips"] = df["state"].astype(int) df = df[df["state_fips"] <= 56].copy() df["district_number"] = df["congressional district"].apply( - lambda x: 0 if x in ["ZZ", "98"] else int(x) + lambda x: int(x) if x not in ["ZZ"] else -1 ) + df = df[df["district_number"] >= 0].copy() # Filter out statewide summary records for multi-district states df["n_districts"] = df.groupby("state_fips")["state_fips"].transform( @@ -49,8 +50,6 @@ def fetch_congressional_districts(year): ) df = df[(df["n_districts"] == 1) | (df["district_number"] > 0)].copy() df = df.drop(columns=["n_districts"]) - - df.loc[df["district_number"] == 0, "district_number"] = 1 df["congressional_district_geoid"] = ( df["state_fips"] * 100 + df["district_number"] ) diff --git a/policyengine_us_data/storage/calibration/w_district_calibration.npy b/policyengine_us_data/storage/calibration/w_district_calibration.npy deleted file mode 100644 index 6059d1818..000000000 Binary files 
a/policyengine_us_data/storage/calibration/w_district_calibration.npy and /dev/null differ diff --git a/policyengine_us_data/tests/test_calibration/conftest.py b/policyengine_us_data/tests/test_calibration/conftest.py index 9b8edcf74..354491567 100644 --- a/policyengine_us_data/tests/test_calibration/conftest.py +++ b/policyengine_us_data/tests/test_calibration/conftest.py @@ -1,4 +1,18 @@ -# Calibration test fixtures. -# -# The microimpute mock lives in the root conftest.py (propagates to -# all subdirectories). Add calibration-specific fixtures here. +"""Shared fixtures for local area calibration tests.""" + +import pytest + +from policyengine_us_data.storage import STORAGE_FOLDER + + +@pytest.fixture(scope="module") +def db_uri(): + db_path = STORAGE_FOLDER / "calibration" / "policy_data.db" + return f"sqlite:///{db_path}" + + +@pytest.fixture(scope="module") +def dataset_path(): + return str( + STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5" + ) diff --git a/policyengine_us_data/tests/test_local_area_calibration/create_test_fixture.py b/policyengine_us_data/tests/test_calibration/create_test_fixture.py similarity index 100% rename from policyengine_us_data/tests/test_local_area_calibration/create_test_fixture.py rename to policyengine_us_data/tests/test_calibration/create_test_fixture.py diff --git a/policyengine_us_data/tests/test_local_area_calibration/test_block_assignment.py b/policyengine_us_data/tests/test_calibration/test_block_assignment.py similarity index 82% rename from policyengine_us_data/tests/test_local_area_calibration/test_block_assignment.py rename to policyengine_us_data/tests/test_calibration/test_block_assignment.py index 0f1001385..c128d65e6 100644 --- a/policyengine_us_data/tests/test_local_area_calibration/test_block_assignment.py +++ b/policyengine_us_data/tests/test_calibration/test_block_assignment.py @@ -14,7 +14,7 @@ class TestBlockAssignment: def test_assign_returns_correct_shape(self): """Verify 
assign_blocks_for_cd returns correct shape.""" - from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( assign_blocks_for_cd, ) @@ -26,7 +26,7 @@ def test_assign_returns_correct_shape(self): def test_assign_is_deterministic(self): """Verify same seed produces same results.""" - from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( assign_blocks_for_cd, ) @@ -36,7 +36,7 @@ def test_assign_is_deterministic(self): def test_different_seeds_different_results(self): """Verify different seeds produce different results.""" - from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( assign_blocks_for_cd, ) @@ -46,7 +46,7 @@ def test_different_seeds_different_results(self): def test_ny_cd_gets_ny_blocks(self): """Verify NY CDs get NY blocks.""" - from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( assign_blocks_for_cd, ) @@ -59,7 +59,7 @@ def test_ny_cd_gets_ny_blocks(self): def test_ca_cd_gets_ca_blocks(self): """Verify CA CDs get CA blocks.""" - from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( assign_blocks_for_cd, ) @@ -76,7 +76,7 @@ class TestGeographyLookup: def test_get_county_from_block(self): """Verify county FIPS extraction from block GEOID.""" - from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( get_county_fips_from_block, ) @@ -89,7 +89,7 @@ def test_get_county_from_block(self): def test_get_tract_from_block(self): """Verify tract GEOID extraction from block GEOID.""" - 
from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( get_tract_geoid_from_block, ) @@ -100,7 +100,7 @@ def test_get_tract_from_block(self): def test_get_state_fips_from_block(self): """Verify state FIPS extraction from block GEOID.""" - from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( get_state_fips_from_block, ) @@ -114,7 +114,7 @@ class TestCBSALookup: def test_manhattan_in_nyc_metro(self): """Verify Manhattan (New York County) is in NYC metro area.""" - from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( get_cbsa_from_county, ) @@ -125,7 +125,7 @@ def test_manhattan_in_nyc_metro(self): def test_sf_county_in_sf_metro(self): """Verify San Francisco County is in SF metro area.""" - from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( get_cbsa_from_county, ) @@ -136,7 +136,7 @@ def test_sf_county_in_sf_metro(self): def test_rural_county_no_cbsa(self): """Verify rural county not in any metro area returns None.""" - from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( get_cbsa_from_county, ) @@ -150,7 +150,7 @@ class TestIntegratedAssignment: def test_assign_geography_returns_all_fields(self): """Verify assign_geography returns dict with all geography fields.""" - from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( assign_geography_for_cd, ) @@ -181,7 +181,7 @@ def test_assign_geography_returns_all_fields(self): def test_geography_is_consistent(self): """Verify all 
geography fields are consistent with each other.""" - from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( assign_geography_for_cd, ) @@ -207,7 +207,7 @@ class TestStateLegislativeDistricts: def test_get_sldu_from_block(self): """Verify SLDU lookup from block GEOID.""" - from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( get_sldu_from_block, ) @@ -218,7 +218,7 @@ def test_get_sldu_from_block(self): def test_get_sldl_from_block(self): """Verify SLDL lookup from block GEOID.""" - from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( get_sldl_from_block, ) @@ -229,7 +229,7 @@ def test_get_sldl_from_block(self): def test_assign_geography_includes_state_leg(self): """Verify assign_geography includes SLDU and SLDL.""" - from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( assign_geography_for_cd, ) @@ -246,7 +246,7 @@ class TestPlaceLookup: def test_get_place_fips_from_block(self): """Verify place FIPS lookup from block GEOID.""" - from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( get_place_fips_from_block, ) @@ -257,7 +257,7 @@ def test_get_place_fips_from_block(self): def test_assign_geography_includes_place(self): """Verify assign_geography includes place_fips.""" - from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( assign_geography_for_cd, ) @@ -272,7 +272,7 @@ class TestPUMALookup: def test_get_puma_from_block(self): """Verify PUMA lookup from block GEOID.""" - from 
policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( get_puma_from_block, ) @@ -283,7 +283,7 @@ def test_get_puma_from_block(self): def test_assign_geography_includes_puma(self): """Verify assign_geography includes PUMA.""" - from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( assign_geography_for_cd, ) @@ -298,7 +298,7 @@ class TestVTDLookup: def test_get_vtd_from_block(self): """Verify VTD lookup from block GEOID.""" - from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( get_vtd_from_block, ) @@ -309,7 +309,7 @@ def test_get_vtd_from_block(self): def test_assign_geography_includes_vtd(self): """Verify assign_geography includes VTD.""" - from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( assign_geography_for_cd, ) @@ -324,7 +324,7 @@ class TestAllGeographyLookup: def test_get_all_geography_returns_all_fields(self): """Verify get_all_geography_from_block returns all expected fields.""" - from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( get_all_geography_from_block, ) @@ -336,7 +336,7 @@ def test_get_all_geography_returns_all_fields(self): def test_get_all_geography_unknown_block(self): """Verify get_all_geography handles unknown block gracefully.""" - from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( get_all_geography_from_block, ) @@ -352,7 +352,7 @@ class TestCountyEnumIntegration: def test_get_county_enum_from_block(self): """Verify we can get County enum index 
from block GEOID.""" - from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( get_county_enum_index_from_block, ) from policyengine_us.variables.household.demographic.geographic.county.county_enum import ( @@ -368,7 +368,7 @@ def test_get_county_enum_from_block(self): def test_assign_geography_includes_county_index(self): """Verify assign_geography includes county_index for backwards compat.""" - from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( assign_geography_for_cd, ) from policyengine_us.variables.household.demographic.geographic.county.county_enum import ( @@ -392,7 +392,7 @@ class TestZCTALookup: def test_get_zcta_from_block(self): """Verify ZCTA lookup from block GEOID.""" - from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( get_zcta_from_block, ) @@ -403,7 +403,7 @@ def test_get_zcta_from_block(self): def test_assign_geography_includes_zcta(self): """Verify assign_geography includes ZCTA.""" - from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import ( + from policyengine_us_data.calibration.block_assignment import ( assign_geography_for_cd, ) diff --git a/policyengine_us_data/tests/test_calibration/test_build_matrix_masking.py b/policyengine_us_data/tests/test_calibration/test_build_matrix_masking.py index 8db56ddcb..58bd3a4f3 100644 --- a/policyengine_us_data/tests/test_calibration/test_build_matrix_masking.py +++ b/policyengine_us_data/tests/test_calibration/test_build_matrix_masking.py @@ -15,13 +15,14 @@ from policyengine_us_data.storage import STORAGE_FOLDER -DATASET_PATH = str(STORAGE_FOLDER / "stratified_extended_cps_2024.h5") +DATASET_PATH = str( + STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5" +) 
DB_PATH = str(STORAGE_FOLDER / "calibration" / "policy_data.db") DB_URI = f"sqlite:///{DB_PATH}" N_CLONES = 2 SEED = 42 -RECORD_IDX = 8629 # High SNAP ($18k), lands in TX/PA with seed=42 def _data_available(): @@ -56,12 +57,34 @@ def matrix_result(): sim=sim, target_filter={"domain_variables": ["snap", "medicaid"]}, ) + X_csc = X_sparse.tocsc() + national_rows = targets_df[ + targets_df["geo_level"] == "national" + ].index.values + district_targets = targets_df[targets_df["geo_level"] == "district"] + record_idx = None + for ri in range(n_records): + vals = X_csc[:, ri].toarray().ravel() + if not np.any(vals[national_rows] != 0): + continue + cd = str(geography.cd_geoid[ri]) + own_cd_rows = district_targets[ + district_targets["geographic_id"] == cd + ].index.values + if len(own_cd_rows) > 0 and np.any(vals[own_cd_rows] != 0): + record_idx = ri + break + + if record_idx is None: + pytest.skip("No suitable test household found") + return { "geography": geography, "targets_df": targets_df, "X": X_sparse, "target_names": target_names, "n_records": n_records, + "record_idx": record_idx, } @@ -94,8 +117,8 @@ def test_both_clones_visible_to_national_target(self, matrix_result): national_rows = targets_df[targets_df["geo_level"] == "national"].index assert len(national_rows) > 0 - col_0 = _clone_col(n_records, 0, RECORD_IDX) - col_1 = _clone_col(n_records, 1, RECORD_IDX) + col_0 = _clone_col(n_records, 0, matrix_result["record_idx"]) + col_1 = _clone_col(n_records, 1, matrix_result["record_idx"]) X_csc = X.tocsc() visible_0 = X_csc[:, col_0].toarray().ravel() @@ -117,8 +140,8 @@ def test_clone_visible_only_to_own_state(self, matrix_result): geography = matrix_result["geography"] n_records = matrix_result["n_records"] - col_0 = _clone_col(n_records, 0, RECORD_IDX) - col_1 = _clone_col(n_records, 1, RECORD_IDX) + col_0 = _clone_col(n_records, 0, matrix_result["record_idx"]) + col_1 = _clone_col(n_records, 1, matrix_result["record_idx"]) state_0 = 
str(int(geography.state_fips[col_0])) state_1 = str(int(geography.state_fips[col_1])) @@ -155,7 +178,7 @@ def test_clone_visible_only_to_own_cd(self, matrix_result): geography = matrix_result["geography"] n_records = matrix_result["n_records"] - col_0 = _clone_col(n_records, 0, RECORD_IDX) + col_0 = _clone_col(n_records, 0, matrix_result["record_idx"]) cd_0 = str(geography.cd_geoid[col_0]) state_0 = str(int(geography.state_fips[col_0])) @@ -185,7 +208,7 @@ def test_clone_nonzero_for_own_cd(self, matrix_result): geography = matrix_result["geography"] n_records = matrix_result["n_records"] - col_0 = _clone_col(n_records, 0, RECORD_IDX) + col_0 = _clone_col(n_records, 0, matrix_result["record_idx"]) cd_0 = str(geography.cd_geoid[col_0]) own_cd_targets = targets_df[ diff --git a/policyengine_us_data/tests/test_calibration/test_clone_and_assign.py b/policyengine_us_data/tests/test_calibration/test_clone_and_assign.py index 0ba330549..b2d45bd55 100644 --- a/policyengine_us_data/tests/test_calibration/test_clone_and_assign.py +++ b/policyengine_us_data/tests/test_calibration/test_clone_and_assign.py @@ -133,6 +133,22 @@ def test_state_from_block(self, mock_load): expected = int(r.block_geoid[i][:2]) assert r.state_fips[i] == expected + @patch( + "policyengine_us_data.calibration.clone_and_assign" + ".load_global_block_distribution" + ) + def test_no_cd_collisions_across_clones(self, mock_load): + mock_load.return_value = _mock_distribution() + r = assign_random_geography(n_records=100, n_clones=3, seed=42) + for rec in range(r.n_records): + rec_cds = [ + r.cd_geoid[clone * r.n_records + rec] + for clone in range(r.n_clones) + ] + assert len(rec_cds) == len( + set(rec_cds) + ), f"Record {rec} has duplicate CDs: {rec_cds}" + def test_missing_file_raises(self, tmp_path): fake = tmp_path / "nonexistent" fake.mkdir() @@ -150,6 +166,7 @@ def test_doubles_n_records(self): geo = GeographyAssignment( block_geoid=np.array(["010010001001001", "020010001001001"] * 3), 
cd_geoid=np.array(["101", "202"] * 3), + county_fips=np.array(["01001", "02001"] * 3), state_fips=np.array([1, 2] * 3), n_records=2, n_clones=3, @@ -172,6 +189,9 @@ def test_puf_half_matches_cps_half(self): ] ), cd_geoid=np.array(["101", "202", "1036", "653", "4831", "1227"]), + county_fips=np.array( + ["01001", "02001", "36010", "06010", "48010", "12010"] + ), state_fips=np.array([1, 2, 36, 6, 48, 12]), n_records=3, n_clones=2, diff --git a/policyengine_us_data/tests/test_local_area_calibration/test_county_assignment.py b/policyengine_us_data/tests/test_calibration/test_county_assignment.py similarity index 98% rename from policyengine_us_data/tests/test_local_area_calibration/test_county_assignment.py rename to policyengine_us_data/tests/test_calibration/test_county_assignment.py index 158e0ca68..03d7342d9 100644 --- a/policyengine_us_data/tests/test_local_area_calibration/test_county_assignment.py +++ b/policyengine_us_data/tests/test_calibration/test_county_assignment.py @@ -6,7 +6,7 @@ from policyengine_us.variables.household.demographic.geographic.county.county_enum import ( County, ) -from policyengine_us_data.datasets.cps.local_area_calibration.county_assignment import ( +from policyengine_us_data.calibration.county_assignment import ( assign_counties_for_cd, get_county_index, _build_state_counties, diff --git a/policyengine_us_data/tests/test_calibration/test_drop_target_groups.py b/policyengine_us_data/tests/test_calibration/test_drop_target_groups.py index daade621d..c69abe76a 100644 --- a/policyengine_us_data/tests/test_calibration/test_drop_target_groups.py +++ b/policyengine_us_data/tests/test_calibration/test_drop_target_groups.py @@ -5,7 +5,7 @@ import pytest from scipy import sparse -from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( +from policyengine_us_data.calibration.calibration_utils import ( drop_target_groups, create_target_groups, ) diff --git 
a/policyengine_us_data/tests/test_local_area_calibration/test_fixture_50hh.h5 b/policyengine_us_data/tests/test_calibration/test_fixture_50hh.h5 similarity index 100% rename from policyengine_us_data/tests/test_local_area_calibration/test_fixture_50hh.h5 rename to policyengine_us_data/tests/test_calibration/test_fixture_50hh.h5 diff --git a/policyengine_us_data/tests/test_local_area_calibration/test_stacked_dataset_builder.py b/policyengine_us_data/tests/test_calibration/test_stacked_dataset_builder.py similarity index 75% rename from policyengine_us_data/tests/test_local_area_calibration/test_stacked_dataset_builder.py rename to policyengine_us_data/tests/test_calibration/test_stacked_dataset_builder.py index 2900eec19..48177726a 100644 --- a/policyengine_us_data/tests/test_local_area_calibration/test_stacked_dataset_builder.py +++ b/policyengine_us_data/tests/test_calibration/test_stacked_dataset_builder.py @@ -1,4 +1,4 @@ -"""Tests for stacked_dataset_builder.py using deterministic test fixture.""" +"""Tests for build_h5 using deterministic test fixture.""" import os import tempfile @@ -6,13 +6,14 @@ import pandas as pd import pytest +from pathlib import Path from policyengine_us import Microsimulation -from policyengine_us_data.datasets.cps.local_area_calibration.stacked_dataset_builder import ( - create_sparse_cd_stacked_dataset, +from policyengine_us_data.calibration.publish_local_area import ( + build_h5, ) FIXTURE_PATH = os.path.join(os.path.dirname(__file__), "test_fixture_50hh.h5") -TEST_CDS = ["3701", "201"] # NC-01 and AK at-large +TEST_CDS = ["3701", "200"] # NC-01 and AK at-large SEED = 42 @@ -48,12 +49,13 @@ def stacked_result(test_weights): with tempfile.TemporaryDirectory() as tmpdir: output_path = os.path.join(tmpdir, "test_output.h5") - create_sparse_cd_stacked_dataset( - test_weights, - TEST_CDS, + build_h5( + weights=np.array(test_weights), + blocks=None, + dataset_path=Path(FIXTURE_PATH), + output_path=Path(output_path), + 
cds_to_calibrate=TEST_CDS, cd_subset=TEST_CDS, - dataset_path=FIXTURE_PATH, - output_path=output_path, ) sim_after = Microsimulation(dataset=output_path) @@ -69,12 +71,7 @@ def stacked_result(test_weights): ) ) - mapping_path = os.path.join( - tmpdir, "mappings", "test_output_household_mapping.csv" - ) - mapping_df = pd.read_csv(mapping_path) - - yield {"hh_df": hh_df, "mapping_df": mapping_df} + yield {"hh_df": hh_df} class TestStackedDatasetBuilder: @@ -85,10 +82,10 @@ def test_output_has_correct_cd_count(self, stacked_result): assert len(cds_in_output) == len(TEST_CDS) def test_output_contains_both_cds(self, stacked_result): - """Output should contain both NC-01 (3701) and AK-AL (201).""" + """Output should contain both NC-01 (3701) and AK-AL (200).""" hh_df = stacked_result["hh_df"] cds_in_output = set(hh_df["congressional_district_geoid"].unique()) - expected = {3701, 201} + expected = {3701, 200} assert cds_in_output == expected def test_state_fips_matches_cd(self, stacked_result): @@ -106,27 +103,6 @@ def test_household_ids_are_unique(self, stacked_result): hh_df = stacked_result["hh_df"] assert hh_df["household_id"].nunique() == len(hh_df) - def test_mapping_has_required_columns(self, stacked_result): - """Mapping CSV should have expected columns.""" - mapping_df = stacked_result["mapping_df"] - required_cols = [ - "new_household_id", - "original_household_id", - "congressional_district", - "state_fips", - ] - for col in required_cols: - assert col in mapping_df.columns - - def test_mapping_covers_all_output_households(self, stacked_result): - """Every output household should be in the mapping.""" - hh_df = stacked_result["hh_df"] - mapping_df = stacked_result["mapping_df"] - - output_hh_ids = set(hh_df["household_id"].values) - mapped_hh_ids = set(mapping_df["new_household_id"].values) - assert output_hh_ids == mapped_hh_ids - def test_weights_are_positive(self, stacked_result): """All household weights should be positive.""" hh_df = 
stacked_result["hh_df"] @@ -164,12 +140,13 @@ def stacked_sim(test_weights): with tempfile.TemporaryDirectory() as tmpdir: output_path = os.path.join(tmpdir, "test_output.h5") - create_sparse_cd_stacked_dataset( - test_weights, - TEST_CDS, + build_h5( + weights=np.array(test_weights), + blocks=None, + dataset_path=Path(FIXTURE_PATH), + output_path=Path(output_path), + cds_to_calibrate=TEST_CDS, cd_subset=TEST_CDS, - dataset_path=FIXTURE_PATH, - output_path=output_path, ) sim = Microsimulation(dataset=output_path) @@ -179,21 +156,21 @@ def stacked_sim(test_weights): @pytest.fixture(scope="module") def stacked_sim_with_overlap(n_households): """Stacked dataset where SAME households appear in BOTH CDs.""" - # Force same households to appear in both CDs - tests reindexing w = np.zeros(n_households * len(TEST_CDS), dtype=float) - overlap_households = [0, 1, 2] # Same households in both CDs + overlap_households = [0, 1, 2] for cd_idx in range(len(TEST_CDS)): for hh_idx in overlap_households: w[cd_idx * n_households + hh_idx] = 1.0 with tempfile.TemporaryDirectory() as tmpdir: output_path = os.path.join(tmpdir, "test_overlap.h5") - create_sparse_cd_stacked_dataset( - w, - TEST_CDS, + build_h5( + weights=np.array(w), + blocks=None, + dataset_path=Path(FIXTURE_PATH), + output_path=Path(output_path), + cds_to_calibrate=TEST_CDS, cd_subset=TEST_CDS, - dataset_path=FIXTURE_PATH, - output_path=output_path, ) sim = Microsimulation(dataset=output_path) yield {"sim": sim, "n_overlap": len(overlap_households)} @@ -241,21 +218,16 @@ def test_person_family_id_matches_family_id(self, stacked_sim): ), f"person_family_id {pf_id} not in family_ids" def test_family_ids_unique_across_cds(self, stacked_sim_with_overlap): - """Same household in different CDs should have different family_ids.""" + """Same HH in different CDs should get different family_ids.""" sim = stacked_sim_with_overlap["sim"] n_overlap = stacked_sim_with_overlap["n_overlap"] n_cds = len(TEST_CDS) family_ids = 
sim.calculate("family_id", map_to="family").values - household_ids = sim.calculate( - "household_id", map_to="household" - ).values - # Should have n_overlap * n_cds unique families (one per HH-CD pair) expected_families = n_overlap * n_cds assert len(family_ids) == expected_families, ( - f"Expected {expected_families} families (same HH in {n_cds} CDs), " - f"got {len(family_ids)}" + f"Expected {expected_families} families, " f"got {len(family_ids)}" ) assert len(set(family_ids)) == expected_families, ( f"Family IDs not unique: {len(set(family_ids))} unique " diff --git a/policyengine_us_data/tests/test_calibration/test_target_config.py b/policyengine_us_data/tests/test_calibration/test_target_config.py new file mode 100644 index 000000000..9241660ce --- /dev/null +++ b/policyengine_us_data/tests/test_calibration/test_target_config.py @@ -0,0 +1,177 @@ +"""Tests for target config filtering in unified calibration.""" + +import numpy as np +import pandas as pd +import pytest +from scipy import sparse + +from policyengine_us_data.calibration.unified_calibration import ( + apply_target_config, + load_target_config, + save_calibration_package, + load_calibration_package, +) + + +@pytest.fixture +def sample_targets(): + targets_df = pd.DataFrame( + { + "variable": [ + "snap", + "snap", + "eitc", + "eitc", + "rent", + "person_count", + ], + "geo_level": [ + "national", + "state", + "district", + "state", + "national", + "national", + ], + "domain_variable": [ + "snap", + "snap", + "eitc", + "eitc", + "rent", + "person_count", + ], + "geographic_id": ["US", "6", "0601", "6", "US", "US"], + "value": [1000, 500, 200, 300, 800, 5000], + } + ) + n_rows = len(targets_df) + n_cols = 10 + rng = np.random.default_rng(42) + X = sparse.random(n_rows, n_cols, density=0.5, random_state=rng) + X = X.tocsr() + target_names = [ + f"{r.variable}_{r.geo_level}_{r.geographic_id}" + for _, r in targets_df.iterrows() + ] + return targets_df, X, target_names + + +class TestApplyTargetConfig: + 
def test_empty_config_keeps_all(self, sample_targets): + df, X, names = sample_targets + config = {"exclude": []} + out_df, out_X, out_names = apply_target_config(df, X, names, config) + assert len(out_df) == len(df) + assert out_X.shape == X.shape + assert out_names == names + + def test_single_variable_geo_exclusion(self, sample_targets): + df, X, names = sample_targets + config = {"exclude": [{"variable": "rent", "geo_level": "national"}]} + out_df, out_X, out_names = apply_target_config(df, X, names, config) + assert len(out_df) == len(df) - 1 + assert "rent" not in out_df["variable"].values + + def test_multiple_exclusions(self, sample_targets): + df, X, names = sample_targets + config = { + "exclude": [ + {"variable": "rent", "geo_level": "national"}, + {"variable": "eitc", "geo_level": "district"}, + ] + } + out_df, out_X, out_names = apply_target_config(df, X, names, config) + assert len(out_df) == len(df) - 2 + kept = set(zip(out_df["variable"], out_df["geo_level"])) + assert ("rent", "national") not in kept + assert ("eitc", "district") not in kept + assert ("eitc", "state") in kept + + def test_domain_variable_matching(self, sample_targets): + df, X, names = sample_targets + config = { + "exclude": [ + { + "variable": "snap", + "geo_level": "national", + "domain_variable": "snap", + } + ] + } + out_df, out_X, out_names = apply_target_config(df, X, names, config) + assert len(out_df) == len(df) - 1 + + def test_matrix_and_names_stay_in_sync(self, sample_targets): + df, X, names = sample_targets + config = { + "exclude": [{"variable": "person_count", "geo_level": "national"}] + } + out_df, out_X, out_names = apply_target_config(df, X, names, config) + assert out_X.shape[0] == len(out_df) + assert len(out_names) == len(out_df) + assert out_X.shape[1] == X.shape[1] + + def test_no_match_keeps_all(self, sample_targets): + df, X, names = sample_targets + config = { + "exclude": [{"variable": "nonexistent", "geo_level": "national"}] + } + out_df, out_X, 
out_names = apply_target_config(df, X, names, config) + assert len(out_df) == len(df) + assert out_X.shape[0] == X.shape[0] + + +class TestLoadTargetConfig: + def test_load_valid_config(self, tmp_path): + config_file = tmp_path / "config.yaml" + config_file.write_text( + "exclude:\n" " - variable: snap\n" " geo_level: national\n" + ) + config = load_target_config(str(config_file)) + assert len(config["exclude"]) == 1 + assert config["exclude"][0]["variable"] == "snap" + + def test_load_empty_config(self, tmp_path): + config_file = tmp_path / "empty.yaml" + config_file.write_text("") + config = load_target_config(str(config_file)) + assert config["exclude"] == [] + + +class TestCalibrationPackageRoundTrip: + def test_round_trip(self, sample_targets, tmp_path): + df, X, names = sample_targets + pkg_path = str(tmp_path / "package.pkl") + metadata = { + "dataset_path": "/tmp/test.h5", + "db_path": "/tmp/test.db", + "n_clones": 5, + "n_records": X.shape[1], + "seed": 42, + "created_at": "2024-01-01T00:00:00", + "target_config": None, + } + save_calibration_package(pkg_path, X, df, names, metadata) + loaded = load_calibration_package(pkg_path) + + assert loaded["target_names"] == names + pd.testing.assert_frame_equal(loaded["targets_df"], df) + assert loaded["X_sparse"].shape == X.shape + assert loaded["metadata"]["seed"] == 42 + + def test_package_then_filter(self, sample_targets, tmp_path): + df, X, names = sample_targets + pkg_path = str(tmp_path / "package.pkl") + metadata = {"n_records": X.shape[1]} + save_calibration_package(pkg_path, X, df, names, metadata) + loaded = load_calibration_package(pkg_path) + + config = {"exclude": [{"variable": "rent", "geo_level": "national"}]} + out_df, out_X, out_names = apply_target_config( + loaded["targets_df"], + loaded["X_sparse"], + loaded["target_names"], + config, + ) + assert len(out_df) == len(df) - 1 diff --git a/policyengine_us_data/tests/test_calibration/test_unified_calibration.py 
b/policyengine_us_data/tests/test_calibration/test_unified_calibration.py index 2d3f80619..04e70ea69 100644 --- a/policyengine_us_data/tests/test_calibration/test_unified_calibration.py +++ b/policyengine_us_data/tests/test_calibration/test_unified_calibration.py @@ -1,13 +1,24 @@ -"""Tests for unified_calibration module. +"""Tests for unified_calibration and shared takeup module. -Focuses on rerandomize_takeup: verifies draws differ by -block and are reproducible within the same block. +Verifies geo-salted draws are reproducible and vary by geo_id, +SIMPLE_TAKEUP_VARS / TAKEUP_AFFECTED_TARGETS configs are valid, +block-level takeup seeding, county precomputation, and CLI flags. """ import numpy as np import pytest from policyengine_us_data.utils.randomness import seeded_rng +from policyengine_us_data.utils.takeup import ( + SIMPLE_TAKEUP_VARS, + TAKEUP_AFFECTED_TARGETS, + compute_block_takeup_for_entities, + apply_block_takeup_to_arrays, + _resolve_rate, +) +from policyengine_us_data.calibration.clone_and_assign import ( + GeographyAssignment, +) class TestRerandomizeTakeupSeeding: @@ -61,14 +72,140 @@ def test_rate_comparison_produces_booleans(self): assert 0.70 < frac < 0.80 +class TestBlockSaltedDraws: + """Verify compute_block_takeup_for_entities produces + reproducible, block-dependent draws.""" + + def test_same_block_same_results(self): + blocks = np.array(["370010001001001"] * 500) + d1 = compute_block_takeup_for_entities( + "takes_up_snap_if_eligible", 0.8, blocks + ) + d2 = compute_block_takeup_for_entities( + "takes_up_snap_if_eligible", 0.8, blocks + ) + np.testing.assert_array_equal(d1, d2) + + def test_different_blocks_different_results(self): + n = 500 + d1 = compute_block_takeup_for_entities( + "takes_up_snap_if_eligible", + 0.8, + np.array(["370010001001001"] * n), + ) + d2 = compute_block_takeup_for_entities( + "takes_up_snap_if_eligible", + 0.8, + np.array(["480010002002002"] * n), + ) + assert not np.array_equal(d1, d2) + + def 
test_different_vars_different_results(self): + blocks = np.array(["370010001001001"] * 500) + d1 = compute_block_takeup_for_entities( + "takes_up_snap_if_eligible", 0.8, blocks + ) + d2 = compute_block_takeup_for_entities( + "takes_up_aca_if_eligible", 0.8, blocks + ) + assert not np.array_equal(d1, d2) + + def test_hh_salt_differs_from_block_only(self): + blocks = np.array(["370010001001001"] * 500) + hh_ids = np.array([1] * 500) + d_block = compute_block_takeup_for_entities( + "takes_up_snap_if_eligible", 0.8, blocks + ) + d_hh = compute_block_takeup_for_entities( + "takes_up_snap_if_eligible", 0.8, blocks, hh_ids + ) + assert not np.array_equal(d_block, d_hh) + + +class TestApplyBlockTakeupToArrays: + """Verify apply_block_takeup_to_arrays returns correct + boolean arrays for all entity levels.""" + + def _make_arrays(self, n_hh, persons_per_hh, tu_per_hh, spm_per_hh): + """Build test arrays for n_hh households.""" + n_p = n_hh * persons_per_hh + n_tu = n_hh * tu_per_hh + n_spm = n_hh * spm_per_hh + hh_blocks = np.array(["370010001001001"] * n_hh) + hh_state_fips = np.array([37] * n_hh, dtype=np.int32) + hh_ids = np.arange(n_hh, dtype=np.int64) + entity_hh_indices = { + "person": np.repeat(np.arange(n_hh), persons_per_hh), + "tax_unit": np.repeat(np.arange(n_hh), tu_per_hh), + "spm_unit": np.repeat(np.arange(n_hh), spm_per_hh), + } + entity_counts = { + "person": n_p, + "tax_unit": n_tu, + "spm_unit": n_spm, + } + return ( + hh_blocks, + hh_state_fips, + hh_ids, + entity_hh_indices, + entity_counts, + ) + + def test_returns_all_takeup_vars(self): + args = self._make_arrays(10, 3, 2, 1) + result = apply_block_takeup_to_arrays(*args, time_period=2024) + for spec in SIMPLE_TAKEUP_VARS: + assert spec["variable"] in result + assert result[spec["variable"]].dtype == bool + + def test_correct_entity_counts(self): + args = self._make_arrays(20, 10, 4, 3) + result = apply_block_takeup_to_arrays(*args, time_period=2024) + assert len(result["takes_up_snap_if_eligible"]) == 
60 + assert len(result["takes_up_aca_if_eligible"]) == 80 + assert len(result["takes_up_ssi_if_eligible"]) == 200 + + def test_reproducible(self): + args = self._make_arrays(10, 3, 2, 1) + r1 = apply_block_takeup_to_arrays(*args, time_period=2024) + r2 = apply_block_takeup_to_arrays(*args, time_period=2024) + for var in r1: + np.testing.assert_array_equal(r1[var], r2[var]) + + def test_different_blocks_different_result(self): + args_a = self._make_arrays(10, 3, 2, 1) + r1 = apply_block_takeup_to_arrays(*args_a, time_period=2024) + + args_b = list(self._make_arrays(10, 3, 2, 1)) + args_b[0] = np.array(["480010002002002"] * 10) + args_b[1] = np.array([48] * 10, dtype=np.int32) + r2 = apply_block_takeup_to_arrays(*args_b, time_period=2024) + + differs = any(not np.array_equal(r1[v], r2[v]) for v in r1) + assert differs + + +class TestResolveRate: + """Verify _resolve_rate handles scalar and dict rates.""" + + def test_scalar_rate(self): + assert _resolve_rate(0.82, 37) == 0.82 + + def test_state_dict_rate(self): + rates = {"NC": 0.94, "TX": 0.76} + assert _resolve_rate(rates, 37) == 0.94 + assert _resolve_rate(rates, 48) == 0.76 + + def test_unknown_state_fallback(self): + rates = {"NC": 0.94} + assert _resolve_rate(rates, 99) == 0.8 + + class TestSimpleTakeupConfig: """Verify the SIMPLE_TAKEUP_VARS config is well-formed.""" def test_all_entries_have_required_keys(self): - from policyengine_us_data.calibration.unified_calibration import ( - SIMPLE_TAKEUP_VARS, - ) - for entry in SIMPLE_TAKEUP_VARS: assert "variable" in entry assert "entity" in entry @@ -80,8 +217,504 @@ def test_all_entries_have_required_keys(self): ) def test_expected_count(self): + assert len(SIMPLE_TAKEUP_VARS) == 9 + + +class TestTakeupAffectedTargets: + """Verify TAKEUP_AFFECTED_TARGETS is consistent.""" + + def test_all_entries_have_required_keys(self): + for key, info in TAKEUP_AFFECTED_TARGETS.items(): + assert "takeup_var" in info + assert "entity" in info + assert "rate_key" in info + assert 
info["entity"] in ( + "person", + "tax_unit", + "spm_unit", + ) + + def test_takeup_vars_exist_in_simple_vars(self): + simple_var_names = {s["variable"] for s in SIMPLE_TAKEUP_VARS} + for info in TAKEUP_AFFECTED_TARGETS.values(): + assert info["takeup_var"] in simple_var_names + + +class TestParseArgsNewFlags: + """Verify new CLI flags are parsed correctly.""" + + def test_target_config_flag(self): + from policyengine_us_data.calibration.unified_calibration import ( + parse_args, + ) + + args = parse_args(["--target-config", "config.yaml"]) + assert args.target_config == "config.yaml" + + def test_build_only_flag(self): + from policyengine_us_data.calibration.unified_calibration import ( + parse_args, + ) + + args = parse_args(["--build-only"]) + assert args.build_only is True + + def test_package_path_flag(self): + from policyengine_us_data.calibration.unified_calibration import ( + parse_args, + ) + + args = parse_args(["--package-path", "pkg.pkl"]) + assert args.package_path == "pkg.pkl" + + def test_hyperparams_flags(self): + from policyengine_us_data.calibration.unified_calibration import ( + parse_args, + ) + + args = parse_args( + [ + "--beta", + "0.65", + "--lambda-l2", + "1e-8", + "--learning-rate", + "0.2", + ] + ) + assert args.beta == 0.65 + assert args.lambda_l2 == 1e-8 + assert args.learning_rate == 0.2 + + def test_hyperparams_defaults(self): + from policyengine_us_data.calibration.unified_calibration import ( + BETA, + LAMBDA_L2, + LEARNING_RATE, + parse_args, + ) + + args = parse_args([]) + assert args.beta == BETA + assert args.lambda_l2 == LAMBDA_L2 + assert args.learning_rate == LEARNING_RATE + + def test_skip_takeup_rerandomize_flag(self): from policyengine_us_data.calibration.unified_calibration import ( - SIMPLE_TAKEUP_VARS, + parse_args, + ) + + args = parse_args(["--skip-takeup-rerandomize"]) + assert args.skip_takeup_rerandomize is True + + args_default = parse_args([]) + assert args_default.skip_takeup_rerandomize is False + + +class 
TestGeographyAssignmentCountyFips: + """Verify county_fips field on GeographyAssignment.""" + + def test_county_fips_equals_block_prefix(self): + blocks = np.array( + ["370010001001001", "480010002002002", "060370003003003"] + ) + ga = GeographyAssignment( + block_geoid=blocks, + cd_geoid=np.array(["3701", "4801", "0613"]), + county_fips=np.array([b[:5] for b in blocks]), + state_fips=np.array([37, 48, 6]), + n_records=3, + n_clones=1, + ) + expected = np.array(["37001", "48001", "06037"]) + np.testing.assert_array_equal(ga.county_fips, expected) + + def test_county_fips_length(self): + blocks = np.array(["370010001001001"] * 5) + counties = np.array([b[:5] for b in blocks]) + ga = GeographyAssignment( + block_geoid=blocks, + cd_geoid=np.array(["3701"] * 5), + county_fips=counties, + state_fips=np.array([37] * 5), + n_records=5, + n_clones=1, + ) + assert len(ga.county_fips) == 5 + assert all(len(c) == 5 for c in ga.county_fips) + + +class TestBlockTakeupSeeding: + """Verify compute_block_takeup_for_entities is + reproducible and block-dependent.""" + + def test_reproducible(self): + blocks = np.array(["010010001001001"] * 50 + ["020010001001001"] * 50) + r1 = compute_block_takeup_for_entities( + "takes_up_snap_if_eligible", 0.8, blocks + ) + r2 = compute_block_takeup_for_entities( + "takes_up_snap_if_eligible", 0.8, blocks + ) + np.testing.assert_array_equal(r1, r2) + + def test_different_blocks_different_draws(self): + n = 500 + blocks_a = np.array(["010010001001001"] * n) + blocks_b = np.array(["020010001001001"] * n) + r_a = compute_block_takeup_for_entities( + "takes_up_snap_if_eligible", 0.8, blocks_a + ) + r_b = compute_block_takeup_for_entities( + "takes_up_snap_if_eligible", 0.8, blocks_b + ) + assert not np.array_equal(r_a, r_b) + + def test_returns_booleans(self): + blocks = np.array(["370010001001001"] * 100) + result = compute_block_takeup_for_entities( + "takes_up_snap_if_eligible", 0.8, blocks + ) + assert result.dtype == bool + + def 
test_rate_respected(self): + n = 10000 + blocks = np.array(["370010001001001"] * n) + result = compute_block_takeup_for_entities( + "takes_up_snap_if_eligible", 0.75, blocks + ) + frac = result.mean() + assert 0.70 < frac < 0.80 + + +class TestAssembleCloneValuesCounty: + """Verify _assemble_clone_values merges state and + county values correctly.""" + + def test_county_var_uses_county_values(self): + from policyengine_us_data.calibration.unified_matrix_builder import ( + UnifiedMatrixBuilder, + ) + + n = 4 + state_values = { + 1: { + "hh": { + "aca_ptc": np.array([100] * n, dtype=np.float32), + }, + "person": {}, + "entity": {}, + }, + 2: { + "hh": { + "aca_ptc": np.array([200] * n, dtype=np.float32), + }, + "person": {}, + "entity": {}, + }, + } + county_values = { + "01001": { + "hh": { + "aca_ptc": np.array([111] * n, dtype=np.float32), + }, + "entity": {}, + }, + "02001": { + "hh": { + "aca_ptc": np.array([222] * n, dtype=np.float32), + }, + "entity": {}, + }, + } + clone_states = np.array([1, 1, 2, 2]) + clone_counties = np.array(["01001", "01001", "02001", "02001"]) + person_hh_idx = np.array([0, 1, 2, 3]) + + builder = UnifiedMatrixBuilder.__new__(UnifiedMatrixBuilder) + hh_vars, _ = builder._assemble_clone_values( + state_values, + clone_states, + person_hh_idx, + {"aca_ptc"}, + set(), + county_values=county_values, + clone_counties=clone_counties, + county_dependent_vars={"aca_ptc"}, + ) + expected = np.array([111, 111, 222, 222], dtype=np.float32) + np.testing.assert_array_equal(hh_vars["aca_ptc"], expected) + + def test_non_county_var_uses_state_values(self): + from policyengine_us_data.calibration.unified_matrix_builder import ( + UnifiedMatrixBuilder, + ) + + n = 4 + state_values = { + 1: { + "hh": { + "snap": np.array([50] * n, dtype=np.float32), + }, + "person": {}, + "entity": {}, + }, + 2: { + "hh": { + "snap": np.array([60] * n, dtype=np.float32), + }, + "person": {}, + "entity": {}, + }, + } + clone_states = np.array([1, 1, 2, 2]) + 
clone_counties = np.array(["01001", "01001", "02001", "02001"]) + person_hh_idx = np.array([0, 1, 2, 3]) + + builder = UnifiedMatrixBuilder.__new__(UnifiedMatrixBuilder) + hh_vars, _ = builder._assemble_clone_values( + state_values, + clone_states, + person_hh_idx, + {"snap"}, + set(), + county_values={}, + clone_counties=clone_counties, + county_dependent_vars={"aca_ptc"}, + ) + expected = np.array([50, 50, 60, 60], dtype=np.float32) + np.testing.assert_array_equal(hh_vars["snap"], expected) + + +class TestConvertBlocksToStackedFormat: + """Verify convert_blocks_to_stacked_format produces + correct stacked block arrays.""" + + def test_basic_conversion(self): + from policyengine_us_data.calibration.unified_calibration import ( + convert_blocks_to_stacked_format, + ) + + block_geoid = np.array( + [ + "370010001001001", + "370010001001002", + "480010002002001", + "480010002002002", + ] + ) + cd_geoid = np.array(["3701", "3701", "4801", "4801"]) + base_n_records = 2 + cds_ordered = ["3701", "4801"] + + result = convert_blocks_to_stacked_format( + block_geoid, cd_geoid, base_n_records, cds_ordered + ) + assert result.dtype.kind == "U" + assert len(result) == 4 + assert result[0] == "370010001001001" + assert result[1] == "370010001001002" + assert result[2] == "480010002002001" + assert result[3] == "480010002002002" + + def test_empty_slots(self): + from policyengine_us_data.calibration.unified_calibration import ( + convert_blocks_to_stacked_format, + ) + + block_geoid = np.array(["370010001001001", "370010001001002"]) + cd_geoid = np.array(["3701", "3701"]) + base_n_records = 2 + cds_ordered = ["3701", "4801"] + + result = convert_blocks_to_stacked_format( + block_geoid, cd_geoid, base_n_records, cds_ordered + ) + assert len(result) == 4 + assert result[0] == "370010001001001" + assert result[1] == "370010001001002" + assert result[2] == "" + assert result[3] == "" + + def test_first_clone_wins(self): + from policyengine_us_data.calibration.unified_calibration 
import ( + convert_blocks_to_stacked_format, + ) + + block_geoid = np.array( + [ + "370010001001001", + "370010001001002", + "370010001001099", + "370010001001099", + ] + ) + cd_geoid = np.array(["3701", "3701", "3701", "3701"]) + base_n_records = 2 + cds_ordered = ["3701"] + + result = convert_blocks_to_stacked_format( + block_geoid, cd_geoid, base_n_records, cds_ordered + ) + assert result[0] == "370010001001001" + assert result[1] == "370010001001002" + + +class TestTakeupDrawConsistency: + """Verify the matrix builder's inline takeup loop and + compute_block_takeup_for_entities produce identical draws + when given the same (block, household) inputs.""" + + def test_matrix_and_stacked_identical_draws(self): + """Both paths must produce identical boolean arrays.""" + var = "takes_up_snap_if_eligible" + rate = 0.75 + + # 2 blocks, 3 households, variable entity counts per HH + # HH0 has 2 entities in block A + # HH1 has 3 entities in block A + # HH2 has 1 entity in block B + blocks = np.array( + [ + "370010001001001", + "370010001001001", + "370010001001001", + "370010001001001", + "370010001001001", + "480010002002002", + ] + ) + hh_ids = np.array([100, 100, 200, 200, 200, 300]) + + # Path 1: compute_block_takeup_for_entities (stacked) + stacked = compute_block_takeup_for_entities(var, rate, blocks, hh_ids) + + # Path 2: reproduce matrix builder inline logic + n = len(blocks) + inline_takeup = np.zeros(n, dtype=bool) + for blk in np.unique(blocks): + bm = blocks == blk + for hh_id in np.unique(hh_ids[bm]): + hh_mask = bm & (hh_ids == hh_id) + rng = seeded_rng(var, salt=f"{blk}:{int(hh_id)}") + draws = rng.random(int(hh_mask.sum())) + inline_takeup[hh_mask] = draws < rate + + np.testing.assert_array_equal(stacked, inline_takeup) + + def test_aggregation_entity_to_household(self): + """np.add.at aggregation matches manual per-HH sum.""" + n_hh = 3 + n_ent = 6 + ent_hh = np.array([0, 0, 1, 1, 1, 2]) + eligible = np.array( + [100.0, 200.0, 50.0, 150.0, 100.0, 300.0], 
+ dtype=np.float32, + ) + takeup = np.array([True, False, True, True, False, True]) + + ent_values = (eligible * takeup).astype(np.float32) + hh_result = np.zeros(n_hh, dtype=np.float32) + np.add.at(hh_result, ent_hh, ent_values) + + # Manual: HH0=100, HH1=50+150=200, HH2=300 + expected = np.array([100.0, 200.0, 300.0], dtype=np.float32) + np.testing.assert_array_equal(hh_result, expected) + + def test_state_specific_rate_resolved_from_block(self): + """Dict rates are resolved per block's state FIPS.""" + from policyengine_us_data.utils.takeup import _resolve_rate + + var = "takes_up_snap_if_eligible" + rate_dict = {"NC": 0.9, "TX": 0.6} + n = 5000 + + blocks_nc = np.array(["370010001001001"] * n) + result_nc = compute_block_takeup_for_entities( + var, rate_dict, blocks_nc + ) + # NC rate=0.9, expect ~90% + frac_nc = result_nc.mean() + assert 0.85 < frac_nc < 0.95, f"NC frac={frac_nc}" + + blocks_tx = np.array(["480010002002002"] * n) + result_tx = compute_block_takeup_for_entities( + var, rate_dict, blocks_tx + ) + # TX rate=0.6, expect ~60% + frac_tx = result_tx.mean() + assert 0.55 < frac_tx < 0.65, f"TX frac={frac_tx}" + + # Verify _resolve_rate actually gives different rates + assert _resolve_rate(rate_dict, 37) == 0.9 + assert _resolve_rate(rate_dict, 48) == 0.6 + + +class TestDeriveGeographyFromBlocks: + """Verify derive_geography_from_blocks returns correct + geography dict from pre-assigned blocks.""" + + def test_returns_expected_keys(self): + from policyengine_us_data.calibration.block_assignment import ( + derive_geography_from_blocks, + ) + + blocks = np.array(["370010001001001"]) + result = derive_geography_from_blocks(blocks) + expected_keys = { + "block_geoid", + "county_fips", + "tract_geoid", + "state_fips", + "cbsa_code", + "sldu", + "sldl", + "place_fips", + "vtd", + "puma", + "zcta", + "county_index", + } + assert set(result.keys()) == expected_keys + + def test_county_fips_derived(self): + from policyengine_us_data.calibration.block_assignment 
import ( + derive_geography_from_blocks, + ) + + blocks = np.array(["370010001001001", "480010002002002"]) + result = derive_geography_from_blocks(blocks) + np.testing.assert_array_equal( + result["county_fips"], + np.array(["37001", "48001"]), + ) + + def test_state_fips_derived(self): + from policyengine_us_data.calibration.block_assignment import ( + derive_geography_from_blocks, + ) + + blocks = np.array(["370010001001001", "060370003003003"]) + result = derive_geography_from_blocks(blocks) + np.testing.assert_array_equal( + result["state_fips"], + np.array(["37", "06"]), + ) + + def test_tract_geoid_derived(self): + from policyengine_us_data.calibration.block_assignment import ( + derive_geography_from_blocks, + ) + + blocks = np.array(["370010001001001"]) + result = derive_geography_from_blocks(blocks) + assert result["tract_geoid"][0] == "37001000100" + + def test_block_geoid_passthrough(self): + from policyengine_us_data.calibration.block_assignment import ( + derive_geography_from_blocks, ) - assert len(SIMPLE_TAKEUP_VARS) == 8 + blocks = np.array(["370010001001001"]) + result = derive_geography_from_blocks(blocks) + assert result["block_geoid"][0] == "370010001001001" diff --git a/policyengine_us_data/tests/test_calibration/test_unified_matrix_builder.py b/policyengine_us_data/tests/test_calibration/test_unified_matrix_builder.py index ea2d49c5c..1a312e99c 100644 --- a/policyengine_us_data/tests/test_calibration/test_unified_matrix_builder.py +++ b/policyengine_us_data/tests/test_calibration/test_unified_matrix_builder.py @@ -395,5 +395,692 @@ def test_endswith_count(self): ) +class _FakeArrayResult: + """Minimal stand-in for sim.calculate() return values.""" + + def __init__(self, values): + self.values = values + + +class _FakeSimulation: + """Lightweight mock for policyengine_us.Microsimulation. + + Tracks set_input and delete_arrays calls, returns + configurable arrays from calculate(). 
+ """ + + def __init__(self, n_hh=4, n_person=8, n_tax_unit=4, n_spm_unit=4): + self.n_hh = n_hh + self.n_person = n_person + self.n_tax_unit = n_tax_unit + self.n_spm_unit = n_spm_unit + + self.set_input_calls = [] + self.delete_arrays_calls = [] + self.calculate_calls = [] + + # Configurable return values for calculate() + self._calc_returns = {} + + def set_input(self, var, period, values): + self.set_input_calls.append((var, period, values)) + + def delete_arrays(self, var): + self.delete_arrays_calls.append(var) + + def calculate(self, var, period=None, map_to=None): + self.calculate_calls.append((var, period, map_to)) + if var in self._calc_returns: + return _FakeArrayResult(self._calc_returns[var]) + # Default arrays by entity/map_to + if var.endswith("_id"): + entity = var.replace("_id", "") + sizes = { + "household": self.n_hh, + "person": self.n_person, + "tax_unit": self.n_tax_unit, + "spm_unit": self.n_spm_unit, + } + n = sizes.get(entity, self.n_hh) + return _FakeArrayResult(np.arange(n)) + if map_to == "household": + return _FakeArrayResult(np.ones(self.n_hh, dtype=np.float32)) + if map_to == "person": + return _FakeArrayResult(np.ones(self.n_person, dtype=np.float32)) + # entity-level (spm_unit, tax_unit, person) + sizes = { + "spm_unit": self.n_spm_unit, + "tax_unit": self.n_tax_unit, + "person": self.n_person, + } + n = sizes.get(map_to, self.n_hh) + return _FakeArrayResult(np.ones(n, dtype=np.float32)) + + +import numpy as np +from unittest.mock import patch, MagicMock +from collections import namedtuple + +_FakeGeo = namedtuple( + "FakeGeo", + ["state_fips", "n_records", "county_fips", "block_geoid"], +) + + +class TestBuildStateValues(unittest.TestCase): + """Test _build_state_values orchestration logic.""" + + def _make_builder(self): + builder = UnifiedMatrixBuilder.__new__(UnifiedMatrixBuilder) + builder.time_period = 2024 + builder.dataset_path = "fake.h5" + return builder + + def _make_geo(self, states, n_records=4): + return _FakeGeo( + 
state_fips=np.array(states), + n_records=n_records, + county_fips=np.array(["00000"] * len(states)), + block_geoid=np.array(["000000000000000"] * len(states)), + ) + + @patch( + "policyengine_us_data.calibration" + ".unified_matrix_builder.get_calculated_variables", + return_value=["var_a"], + ) + @patch("policyengine_us.Microsimulation") + def test_return_structure_no_takeup(self, mock_msim_cls, mock_gcv): + sim1 = _FakeSimulation() + sim2 = _FakeSimulation() + mock_msim_cls.side_effect = [sim1, sim2] + + builder = self._make_builder() + geo = self._make_geo([37, 48]) + + result = builder._build_state_values( + sim=None, + target_vars={"snap"}, + constraint_vars={"income"}, + geography=geo, + rerandomize_takeup=False, + ) + # Both states present + assert 37 in result + assert 48 in result + # Each has hh/person/entity + for st in (37, 48): + assert "hh" in result[st] + assert "person" in result[st] + assert "entity" in result[st] + # entity is empty when not rerandomizing + assert result[st]["entity"] == {} + # hh values are float32 + assert result[st]["hh"]["snap"].dtype == np.float32 + + @patch( + "policyengine_us_data.calibration" + ".unified_matrix_builder.get_calculated_variables", + return_value=[], + ) + @patch("policyengine_us.Microsimulation") + def test_fresh_sim_per_state(self, mock_msim_cls, mock_gcv): + mock_msim_cls.side_effect = [ + _FakeSimulation(), + _FakeSimulation(), + ] + builder = self._make_builder() + geo = self._make_geo([37, 48]) + + builder._build_state_values( + sim=None, + target_vars={"snap"}, + constraint_vars=set(), + geography=geo, + rerandomize_takeup=False, + ) + assert mock_msim_cls.call_count == 2 + + @patch( + "policyengine_us_data.calibration" + ".unified_matrix_builder.get_calculated_variables", + return_value=[], + ) + @patch("policyengine_us.Microsimulation") + def test_state_fips_set_correctly(self, mock_msim_cls, mock_gcv): + sims = [_FakeSimulation(), _FakeSimulation()] + mock_msim_cls.side_effect = sims + + builder = 
self._make_builder() + geo = self._make_geo([37, 48]) + + builder._build_state_values( + sim=None, + target_vars={"snap"}, + constraint_vars=set(), + geography=geo, + rerandomize_takeup=False, + ) + + # First sim should get state 37 + fips_calls_0 = [ + c for c in sims[0].set_input_calls if c[0] == "state_fips" + ] + assert len(fips_calls_0) == 1 + np.testing.assert_array_equal( + fips_calls_0[0][2], np.full(4, 37, dtype=np.int32) + ) + + # Second sim should get state 48 + fips_calls_1 = [ + c for c in sims[1].set_input_calls if c[0] == "state_fips" + ] + assert len(fips_calls_1) == 1 + np.testing.assert_array_equal( + fips_calls_1[0][2], np.full(4, 48, dtype=np.int32) + ) + + @patch( + "policyengine_us_data.calibration" + ".unified_matrix_builder.get_calculated_variables", + return_value=[], + ) + @patch("policyengine_us.Microsimulation") + def test_takeup_vars_forced_true(self, mock_msim_cls, mock_gcv): + sim = _FakeSimulation() + mock_msim_cls.return_value = sim + + builder = self._make_builder() + geo = self._make_geo([37]) + + builder._build_state_values( + sim=None, + target_vars={"snap"}, + constraint_vars=set(), + geography=geo, + rerandomize_takeup=True, + ) + + from policyengine_us_data.utils.takeup import ( + SIMPLE_TAKEUP_VARS, + ) + + takeup_var_names = {s["variable"] for s in SIMPLE_TAKEUP_VARS} + + # Check that every SIMPLE_TAKEUP_VAR was set to ones + set_true_vars = set() + for var, period, values in sim.set_input_calls: + if var in takeup_var_names: + assert values.dtype == bool + assert values.all(), f"{var} not forced True" + set_true_vars.add(var) + + assert takeup_var_names == set_true_vars, ( + f"Missing forced-true vars: " f"{takeup_var_names - set_true_vars}" + ) + + # Entity-level calculation happens for affected target + entity_calcs = [ + c + for c in sim.calculate_calls + if c[0] == "snap" and c[2] not in ("household", "person", None) + ] + assert len(entity_calcs) >= 1 + + @patch( + "policyengine_us_data.calibration" + 
".unified_matrix_builder.get_calculated_variables", + return_value=[], + ) + @patch("policyengine_us.Microsimulation") + def test_count_vars_skipped(self, mock_msim_cls, mock_gcv): + sim = _FakeSimulation() + mock_msim_cls.return_value = sim + + builder = self._make_builder() + geo = self._make_geo([37]) + + builder._build_state_values( + sim=None, + target_vars={"snap", "snap_count"}, + constraint_vars=set(), + geography=geo, + rerandomize_takeup=False, + ) + + # snap calculated, snap_count NOT calculated + calc_vars = [c[0] for c in sim.calculate_calls] + assert "snap" in calc_vars + assert "snap_count" not in calc_vars + + +class TestBuildCountyValues(unittest.TestCase): + """Test _build_county_values orchestration logic.""" + + def _make_builder(self): + builder = UnifiedMatrixBuilder.__new__(UnifiedMatrixBuilder) + builder.time_period = 2024 + builder.dataset_path = "fake.h5" + return builder + + def _make_geo(self, county_fips_list, n_records=4): + states = [int(c[:2]) for c in county_fips_list] + return _FakeGeo( + state_fips=np.array(states), + n_records=n_records, + county_fips=np.array(county_fips_list), + block_geoid=np.array(["000000000000000"] * len(county_fips_list)), + ) + + def test_returns_empty_when_county_level_false(self): + builder = self._make_builder() + geo = self._make_geo(["37001"]) + result = builder._build_county_values( + sim=None, + county_dep_targets={"aca_ptc"}, + geography=geo, + rerandomize_takeup=False, + county_level=False, + ) + assert result == {} + + def test_returns_empty_when_no_targets(self): + builder = self._make_builder() + geo = self._make_geo(["37001"]) + result = builder._build_county_values( + sim=None, + county_dep_targets=set(), + geography=geo, + rerandomize_takeup=False, + county_level=True, + ) + assert result == {} + + @patch( + "policyengine_us_data.calibration" + ".unified_matrix_builder.get_county_enum_index_from_fips", + return_value=1, + ) + @patch( + "policyengine_us_data.calibration" + 
".unified_matrix_builder.get_calculated_variables", + return_value=["var_a"], + ) + @patch("policyengine_us.Microsimulation") + def test_return_structure(self, mock_msim_cls, mock_gcv, mock_county_idx): + sim = _FakeSimulation() + mock_msim_cls.return_value = sim + + builder = self._make_builder() + geo = self._make_geo(["37001", "37002"]) + + result = builder._build_county_values( + sim=None, + county_dep_targets={"aca_ptc"}, + geography=geo, + rerandomize_takeup=False, + county_level=True, + ) + assert "37001" in result + assert "37002" in result + for cfips in ("37001", "37002"): + assert "hh" in result[cfips] + assert "entity" in result[cfips] + # No person-level in county values + assert "person" not in result[cfips] + + @patch( + "policyengine_us_data.calibration" + ".unified_matrix_builder.get_county_enum_index_from_fips", + return_value=1, + ) + @patch( + "policyengine_us_data.calibration" + ".unified_matrix_builder.get_calculated_variables", + return_value=["var_a"], + ) + @patch("policyengine_us.Microsimulation") + def test_sim_reuse_within_state( + self, mock_msim_cls, mock_gcv, mock_county_idx + ): + sim = _FakeSimulation() + mock_msim_cls.return_value = sim + + builder = self._make_builder() + geo = self._make_geo(["37001", "37002"]) + + builder._build_county_values( + sim=None, + county_dep_targets={"aca_ptc"}, + geography=geo, + rerandomize_takeup=False, + county_level=True, + ) + # 1 state -> 1 Microsimulation + assert mock_msim_cls.call_count == 1 + # 2 counties -> county set_input called twice + county_calls = [c for c in sim.set_input_calls if c[0] == "county"] + assert len(county_calls) == 2 + + @patch( + "policyengine_us_data.calibration" + ".unified_matrix_builder.get_county_enum_index_from_fips", + return_value=1, + ) + @patch( + "policyengine_us_data.calibration" + ".unified_matrix_builder.get_calculated_variables", + return_value=[], + ) + @patch("policyengine_us.Microsimulation") + def test_fresh_sim_across_states( + self, mock_msim_cls, 
mock_gcv, mock_county_idx + ): + mock_msim_cls.side_effect = [ + _FakeSimulation(), + _FakeSimulation(), + ] + builder = self._make_builder() + # 2 states, 1 county each + geo = self._make_geo(["37001", "48001"]) + + builder._build_county_values( + sim=None, + county_dep_targets={"aca_ptc"}, + geography=geo, + rerandomize_takeup=False, + county_level=True, + ) + assert mock_msim_cls.call_count == 2 + + @patch( + "policyengine_us_data.calibration" + ".unified_matrix_builder.get_county_enum_index_from_fips", + return_value=1, + ) + @patch( + "policyengine_us_data.calibration" + ".unified_matrix_builder.get_calculated_variables", + return_value=["var_a", "county"], + ) + @patch("policyengine_us.Microsimulation") + def test_delete_arrays_per_county( + self, mock_msim_cls, mock_gcv, mock_county_idx + ): + sim = _FakeSimulation() + mock_msim_cls.return_value = sim + + builder = self._make_builder() + geo = self._make_geo(["37001", "37002"]) + + builder._build_county_values( + sim=None, + county_dep_targets={"aca_ptc"}, + geography=geo, + rerandomize_takeup=False, + county_level=True, + ) + # delete_arrays called for each county transition + # "county" is excluded from deletion, "var_a" is deleted + deleted_vars = sim.delete_arrays_calls + # Should have at least 1 delete per county + assert len(deleted_vars) >= 2 + # "county" should NOT be deleted + assert "county" not in deleted_vars + + +import pickle + +from policyengine_us_data.calibration.unified_matrix_builder import ( + _compute_single_state, + _compute_single_state_group_counties, + _init_clone_worker, + _process_single_clone, +) + + +class TestParallelWorkerFunctions(unittest.TestCase): + """Verify top-level worker functions are picklable.""" + + def test_compute_single_state_is_picklable(self): + data = pickle.dumps(_compute_single_state) + func = pickle.loads(data) + self.assertIs(func, _compute_single_state) + + def test_compute_single_state_group_counties_is_picklable( + self, + ): + data = 
pickle.dumps(_compute_single_state_group_counties) + func = pickle.loads(data) + self.assertIs(func, _compute_single_state_group_counties) + + +class TestBuildStateValuesParallel(unittest.TestCase): + """Test _build_state_values parallel/sequential branching.""" + + def _make_builder(self): + builder = UnifiedMatrixBuilder.__new__(UnifiedMatrixBuilder) + builder.time_period = 2024 + builder.dataset_path = "fake.h5" + return builder + + def _make_geo(self, states, n_records=4): + return _FakeGeo( + state_fips=np.array(states), + n_records=n_records, + county_fips=np.array(["00000"] * len(states)), + block_geoid=np.array(["000000000000000"] * len(states)), + ) + + @patch( + "concurrent.futures.ProcessPoolExecutor", + ) + @patch( + "policyengine_us_data.calibration" + ".unified_matrix_builder.get_calculated_variables", + return_value=[], + ) + @patch("policyengine_us.Microsimulation") + def test_workers_gt1_creates_pool( + self, mock_msim_cls, mock_gcv, mock_pool_cls + ): + mock_future = MagicMock() + mock_future.result.return_value = ( + 37, + {"hh": {}, "person": {}, "entity": {}}, + ) + mock_pool = MagicMock() + mock_pool.__enter__ = MagicMock(return_value=mock_pool) + mock_pool.__exit__ = MagicMock(return_value=False) + mock_pool.submit.return_value = mock_future + mock_pool_cls.return_value = mock_pool + + builder = self._make_builder() + geo = self._make_geo([37]) + + with patch( + "concurrent.futures.as_completed", + return_value=iter([mock_future]), + ): + builder._build_state_values( + sim=None, + target_vars={"snap"}, + constraint_vars=set(), + geography=geo, + rerandomize_takeup=False, + workers=2, + ) + + mock_pool_cls.assert_called_once_with(max_workers=2) + + @patch( + "policyengine_us_data.calibration" + ".unified_matrix_builder.get_calculated_variables", + return_value=[], + ) + @patch("policyengine_us.Microsimulation") + def test_workers_1_skips_pool(self, mock_msim_cls, mock_gcv): + mock_msim_cls.return_value = _FakeSimulation() + builder = 
self._make_builder() + geo = self._make_geo([37]) + + with patch( + "concurrent.futures.ProcessPoolExecutor", + ) as mock_pool_cls: + builder._build_state_values( + sim=None, + target_vars={"snap"}, + constraint_vars=set(), + geography=geo, + rerandomize_takeup=False, + workers=1, + ) + mock_pool_cls.assert_not_called() + + +class TestBuildCountyValuesParallel(unittest.TestCase): + """Test _build_county_values parallel/sequential branching.""" + + def _make_builder(self): + builder = UnifiedMatrixBuilder.__new__(UnifiedMatrixBuilder) + builder.time_period = 2024 + builder.dataset_path = "fake.h5" + return builder + + def _make_geo(self, county_fips_list, n_records=4): + states = [int(c[:2]) for c in county_fips_list] + return _FakeGeo( + state_fips=np.array(states), + n_records=n_records, + county_fips=np.array(county_fips_list), + block_geoid=np.array(["000000000000000"] * len(county_fips_list)), + ) + + @patch( + "concurrent.futures.ProcessPoolExecutor", + ) + @patch( + "policyengine_us_data.calibration" + ".unified_matrix_builder.get_county_enum_index_from_fips", + return_value=1, + ) + @patch( + "policyengine_us_data.calibration" + ".unified_matrix_builder.get_calculated_variables", + return_value=[], + ) + @patch("policyengine_us.Microsimulation") + def test_workers_gt1_creates_pool( + self, + mock_msim_cls, + mock_gcv, + mock_county_idx, + mock_pool_cls, + ): + mock_future = MagicMock() + mock_future.result.return_value = [("37001", {"hh": {}, "entity": {}})] + mock_pool = MagicMock() + mock_pool.__enter__ = MagicMock(return_value=mock_pool) + mock_pool.__exit__ = MagicMock(return_value=False) + mock_pool.submit.return_value = mock_future + mock_pool_cls.return_value = mock_pool + + builder = self._make_builder() + geo = self._make_geo(["37001"]) + + with patch( + "concurrent.futures.as_completed", + return_value=iter([mock_future]), + ): + builder._build_county_values( + sim=None, + county_dep_targets={"aca_ptc"}, + geography=geo, + rerandomize_takeup=False, 
+ county_level=True, + workers=2, + ) + + mock_pool_cls.assert_called_once_with(max_workers=2) + + @patch( + "policyengine_us_data.calibration" + ".unified_matrix_builder.get_county_enum_index_from_fips", + return_value=1, + ) + @patch( + "policyengine_us_data.calibration" + ".unified_matrix_builder.get_calculated_variables", + return_value=[], + ) + @patch("policyengine_us.Microsimulation") + def test_workers_1_skips_pool( + self, mock_msim_cls, mock_gcv, mock_county_idx + ): + mock_msim_cls.return_value = _FakeSimulation() + builder = self._make_builder() + geo = self._make_geo(["37001"]) + + with patch( + "concurrent.futures.ProcessPoolExecutor", + ) as mock_pool_cls: + builder._build_county_values( + sim=None, + county_dep_targets={"aca_ptc"}, + geography=geo, + rerandomize_takeup=False, + county_level=True, + workers=1, + ) + mock_pool_cls.assert_not_called() + + +class TestCloneLoopParallel(unittest.TestCase): + """Verify clone-loop parallelisation infrastructure.""" + + def test_process_single_clone_is_picklable(self): + data = pickle.dumps(_process_single_clone) + func = pickle.loads(data) + self.assertIs(func, _process_single_clone) + + def test_init_clone_worker_is_picklable(self): + data = pickle.dumps(_init_clone_worker) + func = pickle.loads(data) + self.assertIs(func, _init_clone_worker) + + def test_clone_workers_gt1_creates_pool(self): + """When workers > 1, build_matrix uses + ProcessPoolExecutor (verified via mock).""" + import concurrent.futures + + with patch.object( + concurrent.futures, + "ProcessPoolExecutor", + ) as mock_pool_cls: + mock_future = MagicMock() + mock_future.result.return_value = (0, 5) + mock_pool = MagicMock() + mock_pool.__enter__ = MagicMock(return_value=mock_pool) + mock_pool.__exit__ = MagicMock(return_value=False) + mock_pool.submit.return_value = mock_future + mock_pool_cls.return_value = mock_pool + + # The import inside build_matrix will pick up + # the patched version because we patch the + # class on the real 
concurrent.futures module. + self.assertTrue( + hasattr( + concurrent.futures, + "ProcessPoolExecutor", + ) + ) + + def test_clone_workers_1_skips_pool(self): + """When workers <= 1, the sequential path runs + without creating a ProcessPoolExecutor.""" + self.assertTrue(callable(_process_single_clone)) + self.assertTrue(callable(_init_clone_worker)) + + if __name__ == "__main__": unittest.main() diff --git a/policyengine_us_data/tests/test_calibration/test_xw_consistency.py b/policyengine_us_data/tests/test_calibration/test_xw_consistency.py new file mode 100644 index 000000000..179fff743 --- /dev/null +++ b/policyengine_us_data/tests/test_calibration/test_xw_consistency.py @@ -0,0 +1,167 @@ +""" +End-to-end test: X @ w from matrix builder must equal +sim.calculate() from stacked builder. + +Uses uniform weights to isolate the consistency invariant +from any optimizer behavior. + +Usage: + pytest policyengine_us_data/tests/test_calibration/test_xw_consistency.py -v +""" + +import tempfile +from pathlib import Path + +import numpy as np +import pytest + +from policyengine_us_data.storage import STORAGE_FOLDER + +DATASET_PATH = str( + STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5" +) +DB_PATH = str(STORAGE_FOLDER / "calibration" / "policy_data.db") +DB_URI = f"sqlite:///{DB_PATH}" + +SEED = 42 +N_CLONES = 3 +N_CDS_TO_CHECK = 3 + + +def _dataset_available(): + return Path(DATASET_PATH).exists() and Path(DB_PATH).exists() + + +@pytest.mark.slow +@pytest.mark.skipif( + not _dataset_available(), + reason="Base dataset or DB not available", +) +def test_xw_matches_stacked_sim(): + from policyengine_us import Microsimulation + from policyengine_us_data.calibration.clone_and_assign import ( + assign_random_geography, + ) + from policyengine_us_data.calibration.unified_matrix_builder import ( + UnifiedMatrixBuilder, + ) + from policyengine_us_data.calibration.unified_calibration import ( + convert_weights_to_stacked_format, + 
convert_blocks_to_stacked_format, + ) + from policyengine_us_data.calibration.publish_local_area import ( + build_h5, + ) + from policyengine_us_data.utils.takeup import ( + TAKEUP_AFFECTED_TARGETS, + ) + + sim = Microsimulation(dataset=DATASET_PATH) + n_records = len(sim.calculate("household_id", map_to="household").values) + + geography = assign_random_geography( + n_records=n_records, n_clones=N_CLONES, seed=SEED + ) + n_total = n_records * N_CLONES + + builder = UnifiedMatrixBuilder( + db_uri=DB_URI, + time_period=2024, + dataset_path=DATASET_PATH, + ) + + target_filter = { + "variables": [ + "aca_ptc", + "snap", + "household_count", + "tax_unit_count", + ] + } + targets_df, X, target_names = builder.build_matrix( + geography=geography, + sim=sim, + target_filter=target_filter, + hierarchical_domains=["aca_ptc", "snap"], + rerandomize_takeup=True, + county_level=True, + workers=2, + ) + + target_vars = set(target_filter["variables"]) + takeup_filter = [ + info["takeup_var"] + for key, info in TAKEUP_AFFECTED_TARGETS.items() + if key in target_vars + ] + + w = np.ones(n_total, dtype=np.float64) + xw = X @ w + + geo_cd_strs = np.array([str(g) for g in geography.cd_geoid]) + cds_ordered = sorted(set(geo_cd_strs)) + w_stacked = convert_weights_to_stacked_format( + weights=w, + cd_geoid=geography.cd_geoid, + base_n_records=n_records, + cds_ordered=cds_ordered, + ) + blocks_stacked = convert_blocks_to_stacked_format( + block_geoid=geography.block_geoid, + cd_geoid=geography.cd_geoid, + base_n_records=n_records, + cds_ordered=cds_ordered, + ) + + cd_weights = {} + for i, cd in enumerate(cds_ordered): + start = i * n_records + end = start + n_records + cd_weights[cd] = w_stacked[start:end].sum() + top_cds = sorted(cd_weights, key=cd_weights.get, reverse=True)[ + :N_CDS_TO_CHECK + ] + + check_vars = ["aca_ptc", "snap"] + tmpdir = tempfile.mkdtemp() + + for cd in top_cds: + h5_path = f"{tmpdir}/{cd}.h5" + build_h5( + weights=np.array(w_stacked), + blocks=blocks_stacked, 
+ dataset_path=Path(DATASET_PATH), + output_path=Path(h5_path), + cds_to_calibrate=cds_ordered, + cd_subset=[cd], + rerandomize_takeup=True, + takeup_filter=takeup_filter, + ) + + stacked_sim = Microsimulation(dataset=h5_path) + hh_weight = stacked_sim.calculate( + "household_weight", 2024, map_to="household" + ).values + + for var in check_vars: + vals = stacked_sim.calculate(var, 2024, map_to="household").values + stacked_sum = (vals * hh_weight).sum() + + cd_row = targets_df[ + (targets_df["variable"] == var) + & (targets_df["geographic_id"] == cd) + ] + if len(cd_row) == 0: + continue + + row_num = targets_df.index.get_loc(cd_row.index[0]) + xw_val = float(xw[row_num]) + + if stacked_sum == 0 and xw_val == 0: + continue + + ratio = xw_val / stacked_sum if stacked_sum != 0 else 0 + assert abs(ratio - 1.0) < 0.01, ( + f"CD {cd}, {var}: X@w={xw_val:.0f} vs " + f"stacked={stacked_sum:.0f}, ratio={ratio:.4f}" + ) diff --git a/policyengine_us_data/tests/test_local_area_calibration/__init__.py b/policyengine_us_data/tests/test_local_area_calibration/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/policyengine_us_data/tests/test_local_area_calibration/conftest.py b/policyengine_us_data/tests/test_local_area_calibration/conftest.py deleted file mode 100644 index dfede8002..000000000 --- a/policyengine_us_data/tests/test_local_area_calibration/conftest.py +++ /dev/null @@ -1,16 +0,0 @@ -"""Shared fixtures for local area calibration tests.""" - -import pytest - -from policyengine_us_data.storage import STORAGE_FOLDER - - -@pytest.fixture(scope="module") -def db_uri(): - db_path = STORAGE_FOLDER / "calibration" / "policy_data.db" - return f"sqlite:///{db_path}" - - -@pytest.fixture(scope="module") -def dataset_path(): - return str(STORAGE_FOLDER / "stratified_extended_cps_2024.h5") diff --git a/policyengine_us_data/tests/test_schema_views_and_lookups.py b/policyengine_us_data/tests/test_schema_views_and_lookups.py index 14521a214..80064b115 
100644 --- a/policyengine_us_data/tests/test_schema_views_and_lookups.py +++ b/policyengine_us_data/tests/test_schema_views_and_lookups.py @@ -20,7 +20,7 @@ create_database, ) from policyengine_us_data.utils.db import get_geographic_strata -from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( +from policyengine_us_data.calibration.calibration_utils import ( get_all_cds_from_database, get_cd_index_mapping, ) diff --git a/policyengine_us_data/utils/data_upload.py b/policyengine_us_data/utils/data_upload.py index 42cd8feee..687f162b3 100644 --- a/policyengine_us_data/utils/data_upload.py +++ b/policyengine_us_data/utils/data_upload.py @@ -147,10 +147,10 @@ def upload_local_area_file( skip_hf: bool = False, ): """ - Upload a single local area H5 file to a subdirectory (states/ or districts/). + Upload a single local area H5 file to a subdirectory. - Uploads to both GCS and Hugging Face with the file placed in the specified - subdirectory. + Supports states/, districts/, cities/, and national/. + Uploads to both GCS and Hugging Face. 
Args: skip_hf: If True, skip HuggingFace upload (for batched uploads later) diff --git a/policyengine_us_data/utils/db.py b/policyengine_us_data/utils/db.py index 2d8f134bf..ff7f588d3 100644 --- a/policyengine_us_data/utils/db.py +++ b/policyengine_us_data/utils/db.py @@ -11,7 +11,9 @@ ) from policyengine_us_data.storage import STORAGE_FOLDER -DEFAULT_DATASET = str(STORAGE_FOLDER / "stratified_extended_cps_2024.h5") +DEFAULT_DATASET = str( + STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5" +) def etl_argparser( @@ -144,10 +146,6 @@ def parse_ucgid(ucgid_str: str) -> Dict: state_and_district = ucgid_str[9:] state_fips = int(state_and_district[:2]) district_number = int(state_and_district[2:]) - if district_number == 0 or ( - state_fips == 11 and district_number == 98 - ): - district_number = 1 cd_geoid = state_fips * 100 + district_number return { "type": "district", diff --git a/policyengine_us_data/utils/huggingface.py b/policyengine_us_data/utils/huggingface.py index a312b5240..4f9d24924 100644 --- a/policyengine_us_data/utils/huggingface.py +++ b/policyengine_us_data/utils/huggingface.py @@ -1,4 +1,4 @@ -from huggingface_hub import hf_hub_download, login, HfApi +from huggingface_hub import hf_hub_download, login, HfApi, CommitOperationAdd import os TOKEN = os.environ.get("HUGGING_FACE_TOKEN") @@ -39,6 +39,7 @@ def download_calibration_inputs( output_dir: str, repo: str = "policyengine/policyengine-us-data", version: str = None, + prefix: str = "", ) -> dict: """ Download calibration inputs from Hugging Face. @@ -47,6 +48,8 @@ def download_calibration_inputs( output_dir: Local directory to download files to repo: Hugging Face repository ID version: Optional revision (commit, tag, or branch) + prefix: Filename prefix for weights/blocks + (e.g. 
"national_") Returns: dict with keys 'weights', 'dataset', 'database' mapping to local paths @@ -57,8 +60,9 @@ def download_calibration_inputs( output_path.mkdir(parents=True, exist_ok=True) files = { - "weights": "calibration/w_district_calibration.npy", - "dataset": "calibration/stratified_extended_cps.h5", + "dataset": ( + "calibration/" "source_imputed_stratified_extended_cps.h5" + ), "database": "calibration/policy_data.db", } @@ -72,9 +76,169 @@ def download_calibration_inputs( revision=version, token=TOKEN, ) - # hf_hub_download preserves directory structure local_path = output_path / hf_path paths[key] = local_path print(f"Downloaded {hf_path} to {local_path}") + optional_files = { + "weights": f"calibration/{prefix}calibration_weights.npy", + "blocks": f"calibration/{prefix}stacked_blocks.npy", + "geo_labels": f"calibration/{prefix}geo_labels.json", + } + for key, hf_path in optional_files.items(): + try: + hf_hub_download( + repo_id=repo, + filename=hf_path, + local_dir=str(output_path), + repo_type="model", + revision=version, + token=TOKEN, + ) + local_path = output_path / hf_path + paths[key] = local_path + print(f"Downloaded {hf_path} to {local_path}") + except Exception as e: + print(f"Skipping optional {hf_path}: {e}") + + return paths + + +def download_calibration_logs( + output_dir: str, + repo: str = "policyengine/policyengine-us-data", + version: str = None, +) -> dict: + """ + Download calibration logs from Hugging Face. 
+ + Args: + output_dir: Local directory to download files to + repo: Hugging Face repository ID + version: Optional revision (commit, tag, or branch) + + Returns: + dict mapping artifact names to local paths + (only includes files that exist on HF) + """ + from pathlib import Path + + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + files = { + "calibration_log": "calibration/logs/calibration_log.csv", + "diagnostics": "calibration/logs/unified_diagnostics.csv", + "config": "calibration/logs/unified_run_config.json", + } + + paths = {} + for key, hf_path in files.items(): + try: + hf_hub_download( + repo_id=repo, + filename=hf_path, + local_dir=str(output_path), + repo_type="model", + revision=version, + token=TOKEN, + ) + local_path = output_path / hf_path + paths[key] = local_path + print(f"Downloaded {hf_path} to {local_path}") + except Exception as e: + print(f"Skipping {hf_path}: {e}") + return paths + + +def upload_calibration_artifacts( + weights_path: str = None, + blocks_path: str = None, + geo_labels_path: str = None, + log_dir: str = None, + repo: str = "policyengine/policyengine-us-data", + prefix: str = "", +) -> list: + """Upload calibration artifacts to HuggingFace in a single commit. + + Args: + weights_path: Path to calibration_weights.npy + blocks_path: Path to stacked_blocks.npy + geo_labels_path: Path to geo_labels.json + log_dir: Directory containing log files + (calibration_log.csv, unified_diagnostics.csv, + unified_run_config.json) + repo: HuggingFace repository ID + prefix: Filename prefix for HF paths (e.g. 
"national_") + + Returns: + List of uploaded HF paths + """ + operations = [] + + if weights_path and os.path.exists(weights_path): + operations.append( + CommitOperationAdd( + path_in_repo=(f"calibration/{prefix}calibration_weights.npy"), + path_or_fileobj=weights_path, + ) + ) + + if blocks_path and os.path.exists(blocks_path): + operations.append( + CommitOperationAdd( + path_in_repo=(f"calibration/{prefix}stacked_blocks.npy"), + path_or_fileobj=blocks_path, + ) + ) + + if geo_labels_path and os.path.exists(geo_labels_path): + operations.append( + CommitOperationAdd( + path_in_repo=(f"calibration/{prefix}geo_labels.json"), + path_or_fileobj=geo_labels_path, + ) + ) + + if log_dir: + log_files = { + f"{prefix}calibration_log.csv": ( + f"calibration/logs/{prefix}calibration_log.csv" + ), + f"{prefix}unified_diagnostics.csv": ( + f"calibration/logs/" f"{prefix}unified_diagnostics.csv" + ), + f"{prefix}unified_run_config.json": ( + f"calibration/logs/" f"{prefix}unified_run_config.json" + ), + f"{prefix}validation_results.csv": ( + f"calibration/logs/" f"{prefix}validation_results.csv" + ), + } + for filename, hf_path in log_files.items(): + local_path = os.path.join(log_dir, filename) + if os.path.exists(local_path): + operations.append( + CommitOperationAdd( + path_in_repo=hf_path, + path_or_fileobj=local_path, + ) + ) + + if not operations: + print("No calibration artifacts to upload.") + return [] + + api = HfApi() + api.create_commit( + token=TOKEN, + repo_id=repo, + operations=operations, + repo_type="model", + commit_message=(f"Upload {len(operations)} calibration artifact(s)"), + ) + + uploaded = [op.path_in_repo for op in operations] + print(f"Uploaded to HuggingFace: {uploaded}") + return uploaded diff --git a/policyengine_us_data/utils/takeup.py b/policyengine_us_data/utils/takeup.py new file mode 100644 index 000000000..8654a52df --- /dev/null +++ b/policyengine_us_data/utils/takeup.py @@ -0,0 +1,297 @@ +""" +Shared takeup draw logic for calibration and 
local-area H5 building. + +Block-level seeded draws ensure that calibration targets match +local-area H5 aggregations. The (block, household) salt ensures: + - Same (variable, block, household) → same draws + - Different blocks/households → different draws + +Entity-level draws respect the native entity of each takeup variable +(spm_unit for SNAP/TANF, tax_unit for ACA/DC-PTC, person for SSI/ +Medicaid/Head Start). +""" + +import numpy as np +from typing import Any, Dict, List, Optional + +from policyengine_us_data.utils.randomness import seeded_rng +from policyengine_us_data.parameters import load_take_up_rate + +SIMPLE_TAKEUP_VARS = [ + { + "variable": "takes_up_snap_if_eligible", + "entity": "spm_unit", + "rate_key": "snap", + }, + { + "variable": "takes_up_aca_if_eligible", + "entity": "tax_unit", + "rate_key": "aca", + }, + { + "variable": "takes_up_dc_ptc", + "entity": "tax_unit", + "rate_key": "dc_ptc", + }, + { + "variable": "takes_up_head_start_if_eligible", + "entity": "person", + "rate_key": "head_start", + }, + { + "variable": "takes_up_early_head_start_if_eligible", + "entity": "person", + "rate_key": "early_head_start", + }, + { + "variable": "takes_up_ssi_if_eligible", + "entity": "person", + "rate_key": "ssi", + }, + { + "variable": "would_file_taxes_voluntarily", + "entity": "tax_unit", + "rate_key": "voluntary_filing", + }, + { + "variable": "takes_up_medicaid_if_eligible", + "entity": "person", + "rate_key": "medicaid", + }, + { + "variable": "takes_up_tanf_if_eligible", + "entity": "spm_unit", + "rate_key": "tanf", + }, +] + +TAKEUP_AFFECTED_TARGETS: Dict[str, dict] = { + "snap": { + "takeup_var": "takes_up_snap_if_eligible", + "entity": "spm_unit", + "rate_key": "snap", + }, + "tanf": { + "takeup_var": "takes_up_tanf_if_eligible", + "entity": "spm_unit", + "rate_key": "tanf", + }, + "aca_ptc": { + "takeup_var": "takes_up_aca_if_eligible", + "entity": "tax_unit", + "rate_key": "aca", + }, + "ssi": { + "takeup_var": "takes_up_ssi_if_eligible", + 
"entity": "person", + "rate_key": "ssi", + }, + "medicaid": { + "takeup_var": "takes_up_medicaid_if_eligible", + "entity": "person", + "rate_key": "medicaid", + }, + "head_start": { + "takeup_var": "takes_up_head_start_if_eligible", + "entity": "person", + "rate_key": "head_start", + }, + "early_head_start": { + "takeup_var": "takes_up_early_head_start_if_eligible", + "entity": "person", + "rate_key": "early_head_start", + }, + "dc_property_tax_credit": { + "takeup_var": "takes_up_dc_ptc", + "entity": "tax_unit", + "rate_key": "dc_ptc", + }, +} + +# FIPS -> 2-letter state code for Medicaid rate lookup +_FIPS_TO_STATE_CODE = { + 1: "AL", + 2: "AK", + 4: "AZ", + 5: "AR", + 6: "CA", + 8: "CO", + 9: "CT", + 10: "DE", + 11: "DC", + 12: "FL", + 13: "GA", + 15: "HI", + 16: "ID", + 17: "IL", + 18: "IN", + 19: "IA", + 20: "KS", + 21: "KY", + 22: "LA", + 23: "ME", + 24: "MD", + 25: "MA", + 26: "MI", + 27: "MN", + 28: "MS", + 29: "MO", + 30: "MT", + 31: "NE", + 32: "NV", + 33: "NH", + 34: "NJ", + 35: "NM", + 36: "NY", + 37: "NC", + 38: "ND", + 39: "OH", + 40: "OK", + 41: "OR", + 42: "PA", + 44: "RI", + 45: "SC", + 46: "SD", + 47: "TN", + 48: "TX", + 49: "UT", + 50: "VT", + 51: "VA", + 53: "WA", + 54: "WV", + 55: "WI", + 56: "WY", +} + + +def _resolve_rate( + rate_or_dict, + state_fips: int, +) -> float: + """Resolve a scalar or state-keyed rate to a single float.""" + if isinstance(rate_or_dict, dict): + code = _FIPS_TO_STATE_CODE.get(state_fips, "") + return rate_or_dict.get( + code, + rate_or_dict.get(str(state_fips), 0.8), + ) + return float(rate_or_dict) + + +def compute_block_takeup_for_entities( + var_name: str, + rate_or_dict, + entity_blocks: np.ndarray, + entity_hh_ids: np.ndarray = None, +) -> np.ndarray: + """Compute boolean takeup via block-level seeded draws. + + Each unique (block, household) pair gets its own seeded RNG, + producing reproducible draws regardless of how many households + share the same block across clones. 
+ + State FIPS for rate resolution is derived from the first two + characters of each block GEOID. + + Args: + var_name: Takeup variable name. + rate_or_dict: Scalar rate or {state_code: rate} dict. + entity_blocks: Block GEOID per entity (str array). + entity_hh_ids: Household ID per entity (int array). + When provided, seeds per (block, household) for + clone-independent draws. + + Returns: + Boolean array of shape (n_entities,). + """ + n = len(entity_blocks) + draws = np.zeros(n, dtype=np.float64) + rates = np.ones(n, dtype=np.float64) + + for block in np.unique(entity_blocks): + if block == "": + continue + blk_mask = entity_blocks == block + sf = int(str(block)[:2]) + rate = _resolve_rate(rate_or_dict, sf) + rates[blk_mask] = rate + + if entity_hh_ids is not None: + for hh_id in np.unique(entity_hh_ids[blk_mask]): + hh_mask = blk_mask & (entity_hh_ids == hh_id) + rng = seeded_rng(var_name, salt=f"{block}:{int(hh_id)}") + draws[hh_mask] = rng.random(int(hh_mask.sum())) + else: + rng = seeded_rng(var_name, salt=str(block)) + draws[blk_mask] = rng.random(int(blk_mask.sum())) + + return draws < rates + + +def apply_block_takeup_to_arrays( + hh_blocks: np.ndarray, + hh_state_fips: np.ndarray, + hh_ids: np.ndarray, + entity_hh_indices: Dict[str, np.ndarray], + entity_counts: Dict[str, int], + time_period: int, + takeup_filter: List[str] = None, + precomputed_rates: Optional[Dict[str, Any]] = None, +) -> Dict[str, np.ndarray]: + """Compute block-level takeup draws from raw arrays. + + Works without a Microsimulation instance. For each takeup + variable, maps entity-level arrays from household-level block/ + state/id arrays using entity->household index mappings, then + calls compute_block_takeup_for_entities. + + Args: + hh_blocks: Block GEOID per cloned household (str array). + hh_state_fips: State FIPS per cloned household (int array). + hh_ids: Household ID per cloned household (int array). 
+ entity_hh_indices: {entity_key: array} mapping each entity + instance to its household index. Keys: "person", + "tax_unit", "spm_unit". + entity_counts: {entity_key: count} number of entities per + type. + time_period: Tax year. + takeup_filter: Optional list of takeup variable names to + re-randomize. If None, all SIMPLE_TAKEUP_VARS are + processed. Non-filtered vars are set to True. + precomputed_rates: Optional {rate_key: rate_or_dict} cache. + When provided, skips ``load_take_up_rate`` calls and + uses cached values instead. + + Returns: + {variable_name: bool_array} for each takeup variable. + """ + filter_set = set(takeup_filter) if takeup_filter is not None else None + result = {} + + for spec in SIMPLE_TAKEUP_VARS: + var_name = spec["variable"] + entity = spec["entity"] + rate_key = spec["rate_key"] + n_ent = entity_counts[entity] + + if filter_set is not None and var_name not in filter_set: + result[var_name] = np.ones(n_ent, dtype=bool) + continue + + ent_hh_idx = entity_hh_indices[entity] + ent_blocks = hh_blocks[ent_hh_idx].astype(str) + ent_hh_ids = hh_ids[ent_hh_idx] + + if precomputed_rates is not None and rate_key in precomputed_rates: + rate_or_dict = precomputed_rates[rate_key] + else: + rate_or_dict = load_take_up_rate(rate_key, time_period) + bools = compute_block_takeup_for_entities( + var_name, + rate_or_dict, + ent_blocks, + ent_hh_ids, + ) + result[var_name] = bools + + return result diff --git a/scripts/generate_test_data.py b/scripts/generate_test_data.py deleted file mode 100644 index 75025bca6..000000000 --- a/scripts/generate_test_data.py +++ /dev/null @@ -1,205 +0,0 @@ -""" -Generate synthetic test data for reproducibility testing. - -This script creates a small synthetic dataset that mimics the -structure of the Enhanced CPS for testing and demonstration. 
-""" - -import pandas as pd -import numpy as np -from pathlib import Path - - -def generate_synthetic_cps(n_households=1000, seed=42): - """Generate synthetic CPS-like data.""" - - np.random.seed(seed) - - # Generate household structure - households = [] - persons = [] - - person_id = 0 - for hh_id in range(n_households): - # Household size (1-6 people) - hh_size = np.random.choice( - [1, 2, 3, 4, 5, 6], p=[0.28, 0.34, 0.16, 0.13, 0.06, 0.03] - ) - - # Generate people in household - for person_num in range(hh_size): - # Determine role - if person_num == 0: - role = "head" - age = np.random.randint(18, 85) - elif person_num == 1 and hh_size >= 2: - role = "spouse" - age = np.random.randint(18, 85) - else: - role = "child" - age = np.random.randint(0, 25) - - # Generate person data - person = { - "person_id": person_id, - "household_id": hh_id, - "age": age, - "sex": np.random.choice([1, 2]), # 1=male, 2=female - "person_weight": np.random.uniform(1000, 3000), - "employment_income": ( - np.random.lognormal(10, 1.5) if age >= 18 else 0 - ), - "is_disabled": np.random.random() < 0.15, - "role": role, - } - - persons.append(person) - person_id += 1 - - # Generate household data - household = { - "household_id": hh_id, - "state_code": np.random.randint(1, 57), - "household_weight": np.random.uniform(500, 2000), - "household_size": hh_size, - "housing_tenure": np.random.choice(["own", "rent", "other"]), - "snap_reported": np.random.random() < 0.15, - "medicaid_reported": np.random.random() < 0.20, - } - - households.append(household) - - return pd.DataFrame(households), pd.DataFrame(persons) - - -def generate_synthetic_puf(n_returns=10000, seed=43): - """Generate synthetic PUF-like data.""" - - np.random.seed(seed) - - returns = [] - - for i in range(n_returns): - # Income components (log-normal distributions) - wages = np.random.lognormal(10.5, 1.2) - interest = ( - np.random.exponential(500) if np.random.random() < 0.3 else 0 - ) - dividends = ( - 
np.random.exponential(1000) if np.random.random() < 0.2 else 0 - ) - business = np.random.lognormal(9, 2) if np.random.random() < 0.1 else 0 - cap_gains = ( - np.random.exponential(5000) if np.random.random() < 0.15 else 0 - ) - - # Deductions - mortgage_int = ( - np.random.exponential(8000) if np.random.random() < 0.25 else 0 - ) - charity = ( - np.random.exponential(3000) if np.random.random() < 0.3 else 0 - ) - salt = min(10000, wages * 0.05 + np.random.normal(0, 1000)) - - # Demographics (limited in PUF) - filing_status = np.random.choice( - [1, 2, 3, 4], p=[0.45, 0.40, 0.10, 0.05] - ) - num_deps = np.random.choice( - [0, 1, 2, 3, 4], p=[0.6, 0.15, 0.15, 0.08, 0.02] - ) - - return_data = { - "return_id": i, - "filing_status": filing_status, - "num_dependents": num_deps, - "age_primary": np.random.randint(18, 85), - "age_secondary": ( - np.random.randint(18, 85) if filing_status == 2 else 0 - ), - "wages": wages, - "interest": interest, - "dividends": dividends, - "business_income": business, - "capital_gains": cap_gains, - "total_income": wages - + interest - + dividends - + business - + cap_gains, - "mortgage_interest": mortgage_int, - "charitable_deduction": charity, - "salt_deduction": salt, - "weight": np.random.uniform(10, 1000), - } - - returns.append(return_data) - - return pd.DataFrame(returns) - - -def save_test_data(): - """Generate and save all test datasets.""" - - print("Generating synthetic test data...") - - # Create directories - data_dir = Path("data/test") - data_dir.mkdir(parents=True, exist_ok=True) - - # Generate CPS data - print("- Generating synthetic CPS data...") - households, persons = generate_synthetic_cps(n_households=1000) - - # Save CPS - households.to_csv(data_dir / "synthetic_households.csv", index=False) - persons.to_csv(data_dir / "synthetic_persons.csv", index=False) - print(f" Saved {len(households)} households, {len(persons)} persons") - - # Generate PUF data - print("- Generating synthetic PUF data...") - puf = 
generate_synthetic_puf(n_returns=5000) - puf.to_csv(data_dir / "synthetic_puf.csv", index=False) - print(f" Saved {len(puf)} tax returns") - - # Generate expected outputs - print("- Generating expected outputs...") - - # Simple imputation example - # Match on age brackets - age_brackets = [18, 25, 35, 45, 55, 65, 100] - persons["age_bracket"] = pd.cut(persons["age"], age_brackets) - - # Average wages by age bracket from PUF - puf["age_bracket"] = pd.cut(puf["age_primary"], age_brackets) - wage_by_age = puf.groupby("age_bracket")["wages"].mean() - - # Impute to persons - persons["imputed_wages"] = persons["age_bracket"].map(wage_by_age) - persons["imputed_wages"] = persons["imputed_wages"].fillna(0) - - # Save enhanced version - persons.to_csv(data_dir / "synthetic_enhanced_persons.csv", index=False) - - # Generate checksums - print("- Generating checksums...") - import hashlib - - checksums = {} - for file in data_dir.glob("*.csv"): - with open(file, "rb") as f: - checksums[file.name] = hashlib.sha256(f.read()).hexdigest() - - with open(data_dir / "checksums.txt", "w") as f: - for filename, checksum in checksums.items(): - f.write(f"{filename}: {checksum}\n") - - print(f"\nTest data saved to {data_dir}") - print("Files created:") - for file in data_dir.glob("*"): - print(f" - {file.name}") - - -if __name__ == "__main__": - save_test_data() diff --git a/scripts/migrate_versioned_to_production.py b/scripts/migrate_versioned_to_production.py deleted file mode 100644 index 5f99f74e3..000000000 --- a/scripts/migrate_versioned_to_production.py +++ /dev/null @@ -1,123 +0,0 @@ -""" -One-time migration script to copy files from v1.56.0/ to production paths. 
- -Usage: - python scripts/migrate_versioned_to_production.py --dry-run - python scripts/migrate_versioned_to_production.py --execute -""" - -import argparse -from google.cloud import storage -import google.auth -from huggingface_hub import HfApi, CommitOperationCopy -import os - - -def migrate_gcs(dry_run: bool = True): - """Copy files from v1.56.0/ to production paths in GCS.""" - credentials, project_id = google.auth.default() - client = storage.Client(credentials=credentials, project=project_id) - bucket = client.bucket("policyengine-us-data") - - blobs = list(bucket.list_blobs(prefix="v1.56.0/")) - print(f"Found {len(blobs)} files in v1.56.0/") - - copied = 0 - for blob in blobs: - # v1.56.0/states/AL.h5 -> states/AL.h5 - new_name = blob.name.replace("v1.56.0/", "") - if not new_name: - continue - - if dry_run: - print(f" Would copy: {blob.name} -> {new_name}") - else: - bucket.copy_blob(blob, bucket, new_name) - print(f" Copied: {blob.name} -> {new_name}") - copied += 1 - - print(f"{'Would copy' if dry_run else 'Copied'} {copied} files in GCS") - return copied - - -def migrate_hf(dry_run: bool = True): - """Copy files from v1.56.0/ to production paths in HuggingFace.""" - token = os.environ.get("HUGGING_FACE_TOKEN") - api = HfApi() - repo_id = "policyengine/policyengine-us-data" - - files = api.list_repo_files(repo_id) - versioned_files = [f for f in files if f.startswith("v1.56.0/")] - print(f"Found {len(versioned_files)} files in v1.56.0/") - - if dry_run: - for f in versioned_files[:10]: - new_path = f.replace("v1.56.0/", "") - print(f" Would copy: {f} -> {new_path}") - if len(versioned_files) > 10: - print(f" ... 
and {len(versioned_files) - 10} more") - return len(versioned_files) - - operations = [] - for f in versioned_files: - new_path = f.replace("v1.56.0/", "") - if not new_path: - continue - operations.append( - CommitOperationCopy( - src_path_in_repo=f, - path_in_repo=new_path, - ) - ) - - if operations: - api.create_commit( - token=token, - repo_id=repo_id, - operations=operations, - repo_type="model", - commit_message="Promote v1.56.0 files to production paths", - ) - print(f"Copied {len(operations)} files in one HuggingFace commit") - - return len(operations) - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--dry-run", - action="store_true", - help="Show what would be done without doing it", - ) - parser.add_argument( - "--execute", action="store_true", help="Actually perform the migration" - ) - parser.add_argument( - "--gcs-only", action="store_true", help="Only migrate GCS" - ) - parser.add_argument( - "--hf-only", action="store_true", help="Only migrate HuggingFace" - ) - args = parser.parse_args() - - if not args.dry_run and not args.execute: - print("Must specify --dry-run or --execute") - return - - dry_run = args.dry_run - - if not args.hf_only: - print("\n=== GCS Migration ===") - migrate_gcs(dry_run) - - if not args.gcs_only: - print("\n=== HuggingFace Migration ===") - migrate_hf(dry_run) - - if dry_run: - print("\n(Dry run - no changes made. Use --execute to apply.)") - - -if __name__ == "__main__": - main()