diff --git a/.github/workflows/local_area_publish.yaml b/.github/workflows/local_area_publish.yaml
index 44675e63e..9fa86174d 100644
--- a/.github/workflows/local_area_publish.yaml
+++ b/.github/workflows/local_area_publish.yaml
@@ -4,7 +4,7 @@ on:
push:
branches: [main]
paths:
- - 'policyengine_us_data/datasets/cps/local_area_calibration/**'
+ - 'policyengine_us_data/calibration/**'
- '.github/workflows/local_area_publish.yaml'
- 'modal_app/**'
repository_dispatch:
@@ -23,7 +23,7 @@ on:
type: boolean
# Trigger strategy:
-# 1. Automatic: Code changes to local_area_calibration/ pushed to main
+# 1. Automatic: Code changes to calibration/ pushed to main
# 2. repository_dispatch: Calibration workflow triggers after uploading new weights
# 3. workflow_dispatch: Manual trigger with optional parameters
@@ -55,7 +55,7 @@ jobs:
SKIP_UPLOAD="${{ github.event.inputs.skip_upload || 'false' }}"
BRANCH="${{ github.head_ref || github.ref_name }}"
- CMD="modal run modal_app/local_area.py --branch=${BRANCH} --num-workers=${NUM_WORKERS}"
+ CMD="modal run modal_app/local_area.py::main --branch=${BRANCH} --num-workers=${NUM_WORKERS}"
if [ "$SKIP_UPLOAD" = "true" ]; then
CMD="${CMD} --skip-upload"
@@ -71,5 +71,63 @@ jobs:
           echo "" >> $GITHUB_STEP_SUMMARY
           echo "Files have been uploaded to GCS and staged on HuggingFace." >> $GITHUB_STEP_SUMMARY
           echo "" >> $GITHUB_STEP_SUMMARY
-          echo "### Next step: Promote to production" >> $GITHUB_STEP_SUMMARY
-          echo "Trigger the **Promote Local Area H5 Files** workflow with the version from the build output." >> $GITHUB_STEP_SUMMARY
+          echo "### Next step: Validation runs automatically" >> $GITHUB_STEP_SUMMARY
+          echo "The validate-staging job will now check all staged H5s." >> $GITHUB_STEP_SUMMARY
+
+  # Runs after publish-local-area: validates staged H5s, then reports the CSV to HF, the run summary, and an artifact.
+  validate-staging:
+    needs: publish-local-area
+    runs-on: ubuntu-latest
+    env:
+      HUGGING_FACE_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }}
+    steps:
+      - name: Checkout repo
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.13'
+
+      - name: Set up uv
+        uses: astral-sh/setup-uv@v5
+
+      - name: Install dependencies
+        run: uv sync
+
+      - name: Validate staged H5s
+        run: |
+          uv run python -m policyengine_us_data.calibration.validate_staging \
+            --area-type states --output validation_results.csv
+
+      - name: Upload validation results to HF
+        run: |
+          uv run python -c "
+          from policyengine_us_data.utils.huggingface import upload
+          upload('validation_results.csv',
+                 'policyengine/policyengine-us-data',
+                 'calibration/logs/validation_results.csv')
+          "
+
+      - name: Post validation summary
+        if: always()
+        run: |
+          echo "## Validation Results" >> $GITHUB_STEP_SUMMARY
+          if [ -f validation_results.csv ]; then
+            TOTAL=$(tail -n +2 validation_results.csv | wc -l)
+            FAILS=$(grep -c ',FAIL,' validation_results.csv || true)
+            echo "- **${TOTAL}** targets validated" >> $GITHUB_STEP_SUMMARY
+            echo "- **${FAILS}** sanity failures" >> $GITHUB_STEP_SUMMARY
+            echo "" >> $GITHUB_STEP_SUMMARY
+            echo "Review in dashboard, then trigger **Promote** workflow." >> $GITHUB_STEP_SUMMARY
+          else
+            echo "Validation did not produce output." >> $GITHUB_STEP_SUMMARY
+          fi
+
+      # always(): keep the CSV as an artifact even when an earlier step in this job fails.
+      - name: Upload validation artifact
+        if: always()
+        uses: actions/upload-artifact@v4
+        with:
+          name: validation-results
+          path: validation_results.csv
diff --git a/.gitignore b/.gitignore
index a7ab98c9a..5418f2090 100644
--- a/.gitignore
+++ b/.gitignore
@@ -30,12 +30,12 @@ docs/.ipynb_checkpoints/
## ACA PTC state-level uprating factors
!policyengine_us_data/storage/aca_ptc_multipliers_2022_2024.csv
-## Raw input cache for database pipeline
-policyengine_us_data/storage/calibration/raw_inputs/
+## Calibration run outputs (weights, diagnostics, packages, config)
+policyengine_us_data/storage/calibration/
## Batch processing checkpoints
completed_*.txt
## Test fixtures
-!policyengine_us_data/tests/test_local_area_calibration/test_fixture_50hh.h5
+!policyengine_us_data/tests/test_calibration/test_fixture_50hh.h5
oregon_ctc_analysis.py
diff --git a/Makefile b/Makefile
index b34b8eb60..c4e7ba541 100644
--- a/Makefile
+++ b/Makefile
@@ -1,4 +1,12 @@
-.PHONY: all format test install download upload docker documentation data validate-data calibrate publish-local-area clean build paper clean-paper presentations database database-refresh promote-database promote-dataset
+.PHONY: all format test install download upload docker documentation data data-legacy validate-data calibrate calibrate-build validate-package publish-local-area upload-calibration upload-dataset upload-database push-to-modal build-matrices calibrate-modal calibrate-modal-national calibrate-both stage-h5s stage-national-h5 stage-all-h5s pipeline validate-staging validate-staging-full upload-validation check-staging check-sanity clean build paper clean-paper presentations database database-refresh promote-database promote-dataset promote build-h5s validate-local
+
+GPU ?= A100-80GB
+EPOCHS ?= 200
+NATIONAL_GPU ?= T4
+NATIONAL_EPOCHS ?= 200
+BRANCH ?= $(shell git rev-parse --abbrev-ref HEAD)
+NUM_WORKERS ?= 8
+VERSION ?=
HF_CLONE_DIR ?= $(HOME)/huggingface/policyengine-us-data
@@ -79,8 +87,8 @@ promote-database:
@echo "Copied DB and raw_inputs to HF clone. Now cd to HF repo, commit, and push."
promote-dataset:
- cp policyengine_us_data/storage/stratified_extended_cps_2024.h5 \
- $(HF_CLONE_DIR)/calibration/stratified_extended_cps.h5
+ cp policyengine_us_data/storage/source_imputed_stratified_extended_cps_2024.h5 \
+ $(HF_CLONE_DIR)/calibration/source_imputed_stratified_extended_cps.h5
@echo "Copied dataset to HF clone. Now cd to HF repo, commit, and push."
data: download
@@ -90,20 +98,147 @@ data: download
python policyengine_us_data/datasets/puf/irs_puf.py
python policyengine_us_data/datasets/puf/puf.py
python policyengine_us_data/datasets/cps/extended_cps.py
+ python policyengine_us_data/calibration/create_stratified_cps.py
+ python policyengine_us_data/calibration/create_source_imputed_cps.py
+
+data-legacy: data
python policyengine_us_data/datasets/cps/enhanced_cps.py
python policyengine_us_data/datasets/cps/small_enhanced_cps.py
- python policyengine_us_data/datasets/cps/local_area_calibration/create_stratified_cps.py
calibrate: data
python -m policyengine_us_data.calibration.unified_calibration \
- --puf-dataset policyengine_us_data/storage/puf_2024.h5
+ --target-config policyengine_us_data/calibration/target_config.yaml
+
+calibrate-build: data
+ python -m policyengine_us_data.calibration.unified_calibration \
+ --target-config policyengine_us_data/calibration/target_config.yaml \
+ --build-only
+
+validate-package:
+ python -m policyengine_us_data.calibration.validate_package
publish-local-area:
- python policyengine_us_data/datasets/cps/local_area_calibration/publish_local_area.py
+ python policyengine_us_data/calibration/publish_local_area.py --upload
+
+build-h5s:
+ python -m policyengine_us_data.calibration.publish_local_area \
+ --weights-path policyengine_us_data/storage/calibration/calibration_weights.npy \
+ --dataset-path policyengine_us_data/storage/source_imputed_stratified_extended_cps_2024.h5 \
+ --db-path policyengine_us_data/storage/calibration/policy_data.db \
+ --calibration-blocks policyengine_us_data/storage/calibration/stacked_blocks.npy \
+ --stacked-takeup policyengine_us_data/storage/calibration/stacked_takeup.npz \
+ --states-only
+
+validate-local:
+ python -m policyengine_us_data.calibration.validate_staging \
+ --hf-prefix local_area_build \
+ --area-type states --output validation_results.csv
validate-data:
python -c "from policyengine_us_data.storage.upload_completed_datasets import validate_all_datasets; validate_all_datasets()"
+upload-calibration:
+ python -c "from policyengine_us_data.utils.huggingface import upload_calibration_artifacts; \
+ upload_calibration_artifacts()"
+
+upload-dataset:
+ python -c "from policyengine_us_data.utils.huggingface import upload; \
+ upload('policyengine_us_data/storage/source_imputed_stratified_extended_cps_2024.h5', \
+ 'policyengine/policyengine-us-data', \
+ 'calibration/source_imputed_stratified_extended_cps.h5')"
+ @echo "Dataset uploaded to HF."
+
+upload-database:
+ python -c "from policyengine_us_data.utils.huggingface import upload; \
+ upload('policyengine_us_data/storage/calibration/policy_data.db', \
+ 'policyengine/policyengine-us-data', \
+ 'calibration/policy_data.db')"
+ @echo "Database uploaded to HF."
+
+push-to-modal:
+ modal volume put local-area-staging \
+ policyengine_us_data/storage/calibration/calibration_weights.npy \
+ calibration_inputs/calibration/calibration_weights.npy --force
+ modal volume put local-area-staging \
+ policyengine_us_data/storage/calibration/stacked_blocks.npy \
+ calibration_inputs/calibration/stacked_blocks.npy --force
+ modal volume put local-area-staging \
+ policyengine_us_data/storage/calibration/stacked_takeup.npz \
+ calibration_inputs/calibration/stacked_takeup.npz --force
+ modal volume put local-area-staging \
+ policyengine_us_data/storage/calibration/policy_data.db \
+ calibration_inputs/calibration/policy_data.db --force
+ modal volume put local-area-staging \
+ policyengine_us_data/storage/calibration/geo_labels.json \
+ calibration_inputs/calibration/geo_labels.json --force
+ modal volume put local-area-staging \
+ policyengine_us_data/storage/source_imputed_stratified_extended_cps_2024.h5 \
+ calibration_inputs/calibration/source_imputed_stratified_extended_cps.h5 --force
+ @echo "All calibration inputs pushed to Modal volume."
+
+build-matrices:
+ modal run modal_app/remote_calibration_runner.py::build_package \
+ --branch $(BRANCH)
+
+calibrate-modal:
+ modal run modal_app/remote_calibration_runner.py::main \
+ --branch $(BRANCH) --gpu $(GPU) --epochs $(EPOCHS) \
+ --push-results
+
+calibrate-modal-national:
+ modal run modal_app/remote_calibration_runner.py::main \
+ --branch $(BRANCH) --gpu $(NATIONAL_GPU) \
+ --epochs $(NATIONAL_EPOCHS) \
+ --push-results --national
+
+calibrate-both: # Run local and national calibration in parallel; fail if either one fails.
+	$(MAKE) calibrate-modal & pid_local=$$!; $(MAKE) calibrate-modal-national & pid_national=$$!; wait $$pid_local; rc_local=$$?; wait $$pid_national && exit $$rc_local
+
+stage-h5s:
+ modal run modal_app/local_area.py::main \
+ --branch $(BRANCH) --num-workers $(NUM_WORKERS) \
+ $(if $(SKIP_DOWNLOAD),--skip-download)
+
+stage-national-h5:
+ modal run modal_app/local_area.py::main_national \
+ --branch $(BRANCH)
+
+stage-all-h5s: # Stage local-area and national H5s in parallel; fail if either one fails.
+	$(MAKE) stage-h5s & pid_local=$$!; $(MAKE) stage-national-h5 & pid_national=$$!; wait $$pid_local; rc_local=$$?; wait $$pid_national && exit $$rc_local
+
+promote:
+ $(eval VERSION := $(or $(VERSION),$(shell python -c "import tomllib; print(tomllib.load(open('pyproject.toml','rb'))['project']['version'])")))
+ modal run modal_app/local_area.py::main_promote \
+ --branch $(BRANCH) --version $(VERSION)
+
+validate-staging:
+ python -m policyengine_us_data.calibration.validate_staging \
+ --area-type states --output validation_results.csv
+
+validate-staging-full:
+ python -m policyengine_us_data.calibration.validate_staging \
+ --area-type states,districts --output validation_results.csv
+
+upload-validation:
+ python -c "from policyengine_us_data.utils.huggingface import upload; \
+ upload('validation_results.csv', \
+ 'policyengine/policyengine-us-data', \
+ 'calibration/logs/validation_results.csv')"
+
+check-staging:
+ python -m policyengine_us_data.calibration.check_staging_sums
+
+check-sanity:
+ python -m policyengine_us_data.calibration.validate_staging \
+ --sanity-only --area-type states --areas NC
+
+pipeline: data upload-dataset build-matrices calibrate-both stage-all-h5s
+ @echo ""
+ @echo "========================================"
+ @echo "Pipeline complete. H5s are in HF staging."
+ @echo "Run 'Promote Local Area H5 Files' workflow in GitHub to publish."
+ @echo "========================================"
+
clean:
rm -f policyengine_us_data/storage/*.h5
rm -f policyengine_us_data/storage/*.db
diff --git a/changelog.d/add-database-build-test.added.md b/changelog.d/add-database-build-test.added.md
new file mode 100644
index 000000000..27661ea66
--- /dev/null
+++ b/changelog.d/add-database-build-test.added.md
@@ -0,0 +1 @@
+Add end-to-end test for calibration database build pipeline.
diff --git a/changelog.d/calibration-pipeline-improvements.added.md b/changelog.d/calibration-pipeline-improvements.added.md
new file mode 100644
index 000000000..6f6a34158
--- /dev/null
+++ b/changelog.d/calibration-pipeline-improvements.added.md
@@ -0,0 +1,8 @@
+Unified calibration pipeline with GPU-accelerated L1/L0 solver, target config YAML, and CLI package validator.
+Per-state and per-county precomputation replacing per-clone Microsimulation (51 sims instead of 436).
+Parallel state, county, and clone loop processing via ProcessPoolExecutor.
+Block-level takeup re-randomization with deterministic seeded draws.
+Hierarchical uprating with ACA PTC state-level CSV factors and CD reconciliation.
+Modal remote runner with Volume support, CUDA OOM fixes, and checkpointing.
+Stacked dataset builder with sparse CD subsets and calibration block propagation.
+Staging validation script (validate_staging.py) with sim.calculate() comparison and sanity checks.
diff --git a/changelog.d/calibration-pipeline-improvements.changed.md b/changelog.d/calibration-pipeline-improvements.changed.md
new file mode 100644
index 000000000..492640977
--- /dev/null
+++ b/changelog.d/calibration-pipeline-improvements.changed.md
@@ -0,0 +1,3 @@
+Geography assignment now prevents clone-to-CD collisions.
+County-dependent vars (aca_ptc) selectively precomputed per county; other vars use state-only path.
+Target config switched to finest-grain include mode (~18K targets).
diff --git a/changelog.d/calibration-pipeline-improvements.fixed.md b/changelog.d/calibration-pipeline-improvements.fixed.md
new file mode 100644
index 000000000..c935ce0bc
--- /dev/null
+++ b/changelog.d/calibration-pipeline-improvements.fixed.md
@@ -0,0 +1,3 @@
+Cross-state cache pollution in matrix builder precomputation.
+Takeup draw ordering mismatch between matrix builder and stacked builder.
+At-large district geoid mismatch (7 districts had 0 estimates).
diff --git a/changelog.d/migrate-to-towncrier.changed.md b/changelog.d/migrate-to-towncrier.changed.md
new file mode 100644
index 000000000..865484add
--- /dev/null
+++ b/changelog.d/migrate-to-towncrier.changed.md
@@ -0,0 +1 @@
+Migrated from changelog_entry.yaml to towncrier fragments to eliminate merge conflicts.
diff --git a/docs/build_h5.md b/docs/build_h5.md
new file mode 100644
index 000000000..513680b0b
--- /dev/null
+++ b/docs/build_h5.md
@@ -0,0 +1,123 @@
+# build_h5 — Unified H5 Builder
+
+`build_h5` is the single function that produces all local-area H5 datasets (national, state, district, city). It lives in `policyengine_us_data/calibration/publish_local_area.py`.
+
+## Signature
+
+```python
+def build_h5(
+ weights: np.ndarray,
+ blocks: np.ndarray,
+ dataset_path: Path,
+ output_path: Path,
+ cds_to_calibrate: List[str],
+ cd_subset: List[str] = None,
+ county_filter: set = None,
+ rerandomize_takeup: bool = False,
+ takeup_filter: List[str] = None,
+) -> Path:
+```
+
+## Parameter Semantics
+
+| Parameter | Type | Purpose |
+|---|---|---|
+| `weights` | `np.ndarray` | Stacked weight vector, shape `(n_geo * n_hh,)` |
+| `blocks` | `np.ndarray` | Block GEOID per weight entry (same shape). If `None`, generated from CD assignments. |
+| `dataset_path` | `Path` | Path to base dataset H5 file |
+| `output_path` | `Path` | Where to write the output H5 file |
+| `cds_to_calibrate` | `List[str]` | Ordered list of CD GEOIDs defining weight matrix row ordering |
+| `cd_subset` | `List[str]` | If provided, only include rows for these CDs |
+| `county_filter` | `set` | If provided, scale weights by P(target counties \| CD) for city datasets |
+| `rerandomize_takeup` | `bool` | Re-draw takeup using block-level seeds |
+| `takeup_filter` | `List[str]` | List of takeup variables to re-randomize |
+
+## How `cd_subset` Controls Output Level
+
+The `cd_subset` parameter determines what geographic level the output represents:
+
+- **National** (`cd_subset=None`): All CDs included — produces a full national dataset.
+- **State** (`cd_subset=[CDs in state]`): Filter to CDs whose FIPS prefix matches the state — produces a state dataset.
+- **District** (`cd_subset=[single_cd]`): Single CD — produces a district dataset.
+- **City** (`cd_subset=[NYC CDs]` + `county_filter=NYC_COUNTIES`): Multiple CDs with county filtering — produces a city dataset. The `county_filter` scales weights by the probability that a household in each CD falls within the target counties.
+
+## Internal Pipeline
+
+1. **Load base simulation** — One `Microsimulation` loaded from `dataset_path`. Entity arrays and membership mappings extracted.
+
+2. **Reshape weights** — The flat weight vector is reshaped to `(n_geo, n_hh)`.
+
+3. **CD subset filtering** — Rows for CDs not in `cd_subset` are zeroed out.
+
+4. **County filtering** — If `county_filter` is set, each row is scaled by `P(target_counties | CD)` via `get_county_filter_probability()`.
+
+5. **Identify active clones** — `np.where(W > 0)` finds all nonzero entries. Each represents a distinct household clone.
+
+6. **Clone entity arrays** — Entity arrays (household, person, tax_unit, spm_unit, family, marital_unit) are cloned using fancy indexing on the base simulation arrays.
+
+7. **Reindex entity IDs** — All entity IDs are reassigned to be globally unique. Cross-reference arrays (e.g., `person_household_id`) are updated accordingly.
+
+8. **Derive geography** — Block GEOIDs are mapped to state FIPS, county, tract, CBSA, etc. via `derive_geography_from_blocks()`. Unique blocks are deduplicated for efficiency.
+
+9. **Recalculate SPM thresholds** — SPM thresholds are recomputed using `calculate_spm_thresholds_vectorized()` with the clone's CD-level geographic adjustment factor.
+
+10. **Rerandomize takeup** (optional) — If enabled, takeup booleans are redrawn per census block using `apply_block_takeup_to_arrays()`.
+
+11. **Write H5** — All variable arrays are written to the output file.
+
+## Usage Examples
+
+### National
+```python
+build_h5(
+ weights=w,
+ blocks=blocks,
+ dataset_path=Path("base.h5"),
+ output_path=Path("national/US.h5"),
+ cds_to_calibrate=cds,
+)
+```
+
+### State
+```python
+state_fips = 6 # California
+cd_subset = [cd for cd in cds if int(cd) // 100 == state_fips]
+build_h5(
+ weights=w,
+ blocks=blocks,
+ dataset_path=Path("base.h5"),
+ output_path=Path("states/CA.h5"),
+ cds_to_calibrate=cds,
+ cd_subset=cd_subset,
+)
+```
+
+### District
+```python
+build_h5(
+ weights=w,
+ blocks=blocks,
+ dataset_path=Path("base.h5"),
+ output_path=Path("districts/CA-12.h5"),
+ cds_to_calibrate=cds,
+ cd_subset=["0612"],
+)
+```
+
+### City (NYC)
+```python
+from policyengine_us_data.calibration.publish_local_area import (
+ NYC_COUNTIES, NYC_CDS,
+)
+
+cd_subset = [cd for cd in cds if cd in NYC_CDS]
+build_h5(
+ weights=w,
+ blocks=blocks,
+ dataset_path=Path("base.h5"),
+ output_path=Path("cities/NYC.h5"),
+ cds_to_calibrate=cds,
+ cd_subset=cd_subset,
+ county_filter=NYC_COUNTIES,
+)
+```
diff --git a/docs/calibration.md b/docs/calibration.md
new file mode 100644
index 000000000..131c2a8e5
--- /dev/null
+++ b/docs/calibration.md
@@ -0,0 +1,497 @@
+# Calibration Pipeline User's Manual
+
+The unified calibration pipeline reweights cloned CPS records to match administrative targets using L0-regularized optimization. This guide covers the main workflows: lightweight build-then-fit, full pipeline with PUF, and fitting from a saved package.
+
+## Quick Start
+
+```bash
+# Build matrix only from stratified CPS (no PUF, no re-imputation):
+python -m policyengine_us_data.calibration.unified_calibration \
+ --target-config policyengine_us_data/calibration/target_config.yaml \
+ --skip-source-impute \
+ --skip-takeup-rerandomize \
+ --build-only
+
+# Fit weights from a saved package:
+python -m policyengine_us_data.calibration.unified_calibration \
+ --package-path storage/calibration/calibration_package.pkl \
+ --epochs 500 --device cuda
+
+# Build + fit in one shot (runs `make data` first, then calibrates with the checked-in target config):
+make calibrate
+```
+
+## Architecture Overview
+
+The pipeline has two phases:
+
+1. **Matrix build**: Clone CPS records, assign geography, compute all target variable values, assemble a sparse calibration matrix. Optionally includes PUF cloning (doubles record count) and source re-imputation.
+2. **Weight fitting** (~5-20 min on GPU): L0-regularized optimization to find household weights that reproduce administrative targets.
+
+The calibration package checkpoint lets you run phase 1 once and iterate on phase 2 with different hyperparameters or target selections---without rebuilding.
+
+### Prerequisites
+
+The matrix build requires two inputs from the data pipeline:
+
+- **Stratified CPS** (`storage/stratified_extended_cps_2024.h5`): ~12K households, built by `make data`. This is the base dataset that gets cloned.
+- **Target database** (`storage/calibration/policy_data.db`): Administrative targets, built by `make database`.
+
+Both must exist before running calibration. The stratified CPS already contains all CPS variables needed for calibration; PUF cloning and source re-imputation are optional enhancements that happen at calibration time.
+
+## Workflows
+
+### 1. Lightweight build-then-fit (recommended for iteration)
+
+Build the matrix from the stratified CPS without PUF cloning or re-imputation. This is the fastest way to get a calibration package for experimentation.
+
+**Step 1: Build the matrix (~12K base records x 436 clones = ~5.2M columns).**
+
+```bash
+python -m policyengine_us_data.calibration.unified_calibration \
+ --target-config policyengine_us_data/calibration/target_config.yaml \
+ --skip-source-impute \
+ --skip-takeup-rerandomize \
+ --build-only
+```
+
+This saves `storage/calibration/calibration_package.pkl` (default location). Use `--package-output` to specify a different path.
+
+**Step 2: Fit weights from the package (fast, repeatable).**
+
+```bash
+python -m policyengine_us_data.calibration.unified_calibration \
+ --package-path storage/calibration/calibration_package.pkl \
+ --epochs 1000 \
+ --lambda-l0 1e-8 \
+ --beta 0.65 \
+ --lambda-l2 1e-8 \
+ --device cuda
+```
+
+You can re-run Step 2 as many times as you want with different hyperparameters. The expensive matrix build only happens once.
+
+### 2. Full pipeline with PUF
+
+Adding `--puf-dataset` doubles the record count (~24K base records x 436 clones = ~10.4M columns) by creating PUF-imputed copies of every CPS record. This also triggers source re-imputation unless skipped.
+
+**Single-pass (build + fit):**
+
+```bash
+python -m policyengine_us_data.calibration.unified_calibration \
+ --puf-dataset policyengine_us_data/storage/puf_2024.h5 \
+ --target-config policyengine_us_data/calibration/target_config.yaml \
+ --epochs 200 \
+ --device cuda
+```
+
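+Note: `make calibrate` runs the same module but passes only `--target-config`, so PUF cloning (`--puf-dataset`) is not enabled and `--epochs`/`--device` fall back to their CLI defaults.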
+
+Output:
+- `storage/calibration/calibration_weights.npy` --- calibrated weight vector
+- `storage/calibration/unified_diagnostics.csv` --- per-target error report
+- `storage/calibration/unified_run_config.json` --- full run configuration
+
+**Build-only (save package for later fitting):**
+
+```bash
+python -m policyengine_us_data.calibration.unified_calibration \
+ --puf-dataset policyengine_us_data/storage/puf_2024.h5 \
+ --target-config policyengine_us_data/calibration/target_config.yaml \
+ --build-only
+```
+
+A similar shortcut is `make calibrate-build`, which passes only `--target-config` and `--build-only` (no `--puf-dataset`).
+
+This saves `storage/calibration/calibration_package.pkl` (default location). Use `--package-output` to specify a different path.
+
+Then fit from the package using the same Step 2 command from Workflow 1.
+
+### 3. Re-filtering a saved package
+
+A saved package contains **all** targets from the database (before target config filtering). You can apply a different target config at fit time:
+
+```bash
+python -m policyengine_us_data.calibration.unified_calibration \
+ --package-path storage/calibration/calibration_package.pkl \
+ --target-config my_custom_config.yaml \
+ --epochs 200
+```
+
+This lets you experiment with which targets to include without rebuilding the matrix.
+
+### 4. Running on Modal (GPU cloud)
+
+**From a pre-built package** (recommended):
+
+Use `--package-path` to point at a local `.pkl` file. The runner automatically uploads it to the Modal Volume and then fits from it on the GPU, avoiding the function argument size limit.
+
+```bash
+modal run modal_app/remote_calibration_runner.py \
+ --package-path policyengine_us_data/storage/calibration/calibration_package.pkl \
+ --branch calibration-pipeline-improvements \
+ --gpu T4 \
+ --epochs 1000 \
+ --beta 0.65 \
+ --lambda-l0 1e-8 \
+ --lambda-l2 1e-8
+```
+
+If a package already exists on the volume from a previous upload, you can also use `--prebuilt-matrices` to fit directly without re-uploading.
+
+**Full pipeline** (builds matrix from scratch on Modal):
+
+```bash
+modal run modal_app/remote_calibration_runner.py \
+ --branch calibration-pipeline-improvements \
+ --gpu T4 \
+ --epochs 1000 \
+ --beta 0.65 \
+ --lambda-l0 1e-8 \
+ --lambda-l2 1e-8 \
+ --target-config policyengine_us_data/calibration/target_config.yaml
+```
+
+The target config YAML is read from the cloned repo inside the container, so it must be committed to the branch you specify.
+
+### 5. Portable fitting (Kaggle, Colab, etc.)
+
+Transfer the package file to any environment with `scipy`, `numpy`, `pandas`, `torch`, and `l0-python` installed:
+
+```python
+from policyengine_us_data.calibration.unified_calibration import (
+ load_calibration_package,
+ apply_target_config,
+ fit_l0_weights,
+)
+
+package = load_calibration_package("calibration_package.pkl")
+targets_df = package["targets_df"]
+X_sparse = package["X_sparse"]
+
+weights = fit_l0_weights(
+ X_sparse=X_sparse,
+ targets=targets_df["value"].values,
+ lambda_l0=1e-8,
+ epochs=500,
+ device="cuda",
+ beta=0.65,
+ lambda_l2=1e-8,
+)
+```
+
+## Target Config
+
+The target config controls which targets reach the optimizer. It uses a YAML exclusion list:
+
+```yaml
+exclude:
+ - variable: rent
+ geo_level: national
+ - variable: eitc
+ geo_level: district
+ - variable: snap
+ geo_level: state
+ domain_variable: snap # optional: further narrow the match
+```
+
+Each rule drops rows from the calibration matrix where **all** specified fields match. Unrecognized variables silently match nothing.
+
+### Fields
+
+| Field | Required | Values | Description |
+|---|---|---|---|
+| `variable` | Yes | Any variable name in `target_overview` | The calibration target variable |
+| `geo_level` | Yes | `national`, `state`, `district` | Geographic aggregation level |
+| `domain_variable` | No | Any domain variable in `target_overview` | Narrows match to a specific domain |
+
+### Default config
+
+The checked-in config at `policyengine_us_data/calibration/target_config.yaml` reproduces the junkyard notebook's 22 excluded target groups. It drops:
+
+- **13 national-level variables**: alimony, charitable deduction, child support, interest deduction, medical expense deduction, net worth, person count, real estate taxes, rent, social security dependents/survivors
+- **9 district-level variables**: ACA PTC, EITC, income tax before credits, medical expense deduction, net capital gains, rental income, tax unit count, partnership/S-corp income, taxable social security
+
+Applying this config reduces targets from ~37K to ~21K, matching the junkyard's target selection.
+
+### Writing a custom config
+
+To experiment, copy the default and edit:
+
+```bash
+cp policyengine_us_data/calibration/target_config.yaml my_config.yaml
+# Edit my_config.yaml to add/remove exclusion rules
+python -m policyengine_us_data.calibration.unified_calibration \
+ --package-path storage/calibration/calibration_package.pkl \
+ --target-config my_config.yaml \
+ --epochs 200
+```
+
+To see what variables and geo_levels are available in the database:
+
+```sql
+SELECT DISTINCT variable, geo_level
+FROM target_overview
+ORDER BY variable, geo_level;
+```
+
+## CLI Reference
+
+### Core flags
+
+| Flag | Default | Description |
+|---|---|---|
+| `--dataset` | `storage/stratified_extended_cps_2024.h5` | Path to CPS h5 file |
+| `--db-path` | `storage/calibration/policy_data.db` | Path to target database |
+| `--output` | `storage/calibration/calibration_weights.npy` | Weight output path |
+| `--puf-dataset` | None | Path to PUF h5 (enables PUF cloning) |
+| `--preset` | `local` | L0 preset: `local` (1e-8) or `national` (1e-4) |
+| `--lambda-l0` | None | Custom L0 penalty (overrides `--preset`) |
+| `--epochs` | 100 | Training epochs |
+| `--device` | `cpu` | `cpu` or `cuda` |
+| `--n-clones` | 436 | Number of dataset clones |
+| `--seed` | 42 | Random seed for geography assignment |
+
+### Target selection
+
+| Flag | Default | Description |
+|---|---|---|
+| `--target-config` | None | Path to YAML exclusion config |
+| `--domain-variables` | None | Comma-separated domain filter (SQL-level) |
+| `--hierarchical-domains` | None | Domains for hierarchical uprating |
+
+### Checkpoint flags
+
+| Flag | Default | Description |
+|---|---|---|
+| `--build-only` | False | Build matrix, save package, skip fitting |
+| `--package-path` | None | Load pre-built package (uploads to Modal volume automatically when using Modal runner) |
+| `--package-output` | Auto (when `--build-only`) | Where to save package |
+
+### Hyperparameter flags
+
+| Flag | Default | Junkyard value | Description |
+|---|---|---|---|
+| `--beta` | 0.35 | 0.65 | L0 gate temperature (higher = softer gates) |
+| `--lambda-l2` | 1e-12 | 1e-8 | L2 regularization on weights |
+| `--learning-rate` | 0.15 | 0.15 | Optimizer learning rate |
+
+### Skip flags
+
+| Flag | Description |
+|---|---|
+| `--skip-puf` | Skip PUF clone + QRF imputation |
+| `--skip-source-impute` | Skip ACS/SIPP/SCF re-imputation |
+| `--skip-takeup-rerandomize` | Skip takeup re-randomization |
+
+## Calibration Package Format
+
+The package is a pickled Python dict:
+
+```python
+{
+ "X_sparse": scipy.sparse.csr_matrix, # (n_targets, n_records)
+ "targets_df": pd.DataFrame, # target metadata + values
+ "target_names": list[str], # human-readable names
+ "metadata": {
+ "dataset_path": str,
+ "db_path": str,
+ "n_clones": int,
+ "n_records": int,
+ "seed": int,
+ "created_at": str, # ISO timestamp
+ "target_config": dict, # config used at build time
+ },
+}
+```
+
+The `targets_df` DataFrame has columns: `variable`, `geo_level`, `geographic_id`, `domain_variable`, `value`, and others from the database.
+
+## Validating a Package
+
+Before uploading a package to Modal, validate it:
+
+```bash
+# Default package location
+python -m policyengine_us_data.calibration.validate_package
+
+# Specific package
+python -m policyengine_us_data.calibration.validate_package path/to/calibration_package.pkl
+
+# Strict mode: fail if any target has row_sum/target < 1%
+python -m policyengine_us_data.calibration.validate_package --strict
+```
+
+Exit codes: **0** = pass, **1** = impossible targets, **2** = strict ratio failures.
+
+Validation also runs automatically after `--build-only`.
+
+## Hyperparameter Tuning Guide
+
+The three key hyperparameters control the tradeoff between target accuracy and sparsity:
+
+- **`beta`** (L0 gate temperature): Controls how sharply the L0 gates open/close. Higher values (0.5--0.8) give softer decisions and more exploration early in training. Lower values (0.2--0.4) give harder on/off decisions.
+
+- **`lambda_l0`** (via `--preset` or `--lambda-l0`): Controls how many records survive. `1e-8` (local preset) keeps millions of records for local-area analysis. `1e-4` (national preset) keeps ~50K for the web app.
+
+- **`lambda_l2`**: Regularizes weight magnitudes. Larger values (1e-8) prevent any single record from having extreme weight. Smaller values (1e-12) allow more weight concentration.
+
+### Suggested starting points
+
+For **local-area calibration** (millions of records):
+```bash
+--lambda-l0 1e-8 --beta 0.65 --lambda-l2 1e-8 --epochs 500
+```
+
+For **national web app** (~50K records):
+```bash
+--lambda-l0 1e-4 --beta 0.35 --lambda-l2 1e-12 --epochs 200
+```
+
+## Makefile Targets
+
+| Target | Description |
+|---|---|
+| `make calibrate` | Full pipeline with PUF and target config |
+| `make calibrate-build` | Build-only mode (saves package, no fitting) |
+| `make pipeline` | End-to-end: data, upload, calibrate, stage |
+| `make validate-staging` | Validate staged H5s against targets (states only) |
+| `make validate-staging-full` | Validate staged H5s (states + districts) |
+| `make upload-validation` | Push validation_results.csv to HF |
+| `make check-staging` | Smoke test: sum key variables across all state H5s |
+| `make check-sanity` | Quick structural integrity check on one state |
+| `make upload-calibration` | Upload weights, blocks, and logs to HF |
+
+## Takeup Rerandomization
+
+The calibration pipeline uses two independent code paths to compute the same target variables:
+
+1. **Matrix builder** (`UnifiedMatrixBuilder.build_matrix`): Computes a sparse calibration matrix $X$ where each row is a target and each column is a cloned household. The optimizer finds weights $w$ that minimize $\|Xw - t\|$, where $t$ is the vector of target values.
+
+2. **Stacked builder** (`create_sparse_cd_stacked_dataset`): Produces the `.h5` files that users load in `Microsimulation`. It reconstructs each congressional district by combining base CPS records with calibrated weights and block-level geography.
+
+For the calibration to be meaningful, **both paths must produce identical values** for every target variable. If the matrix builder computes $X_{snap,NC} \cdot w = \$5.2B$ but the stacked NC.h5 file yields `sim.calculate("snap") * household_weight = $4.8B`, then the optimizer's solution does not actually match the target.
+
+### The problem with takeup variables
+
+Variables like `snap`, `aca_ptc`, `ssi`, and `medicaid` depend on **takeup draws** — random Bernoulli samples that determine whether an eligible household actually claims the benefit. By default, PolicyEngine draws these at simulation time using Python's built-in `hash()`, which is randomized per process.
+
+This means loading the same H5 file in two different processes can produce different SNAP totals, even with the same weights. Worse, the matrix builder runs in process A while the stacked builder runs in process B, so their draws can diverge.
+
+### The solution: block-level seeding
+
+Both paths call `seeded_rng(variable_name, salt=f"{block_geoid}:{household_id}")` to generate deterministic takeup draws. This ensures:
+
+- The same household at the same block always gets the same draw
+- Draws are stable across processes (no dependency on `hash()`)
+- Draws are stable when aggregating to any geography (state, CD, county)
+
+The affected variables are listed in `TAKEUP_AFFECTED_TARGETS` in `utils/takeup.py`: snap, aca_ptc, ssi, medicaid, tanf, head_start, early_head_start, and dc_property_tax_credit.
+
+The `--skip-takeup-rerandomize` flag disables this rerandomization for faster iteration when you only care about non-takeup variables. Do not use it for production calibrations.
+
+## Block-Level Seeding
+
+Each cloned household is assigned to a Census block (15-digit GEOID) during the `assign_random_geography` step. The first 2 digits are the state FIPS code, which determines the household's takeup rates (since benefit eligibility rules are state-specific).
+
+### Mechanism
+
+```python
+rng = seeded_rng(variable_name, salt=f"{block_geoid}:{household_id}")
+draw = rng.random()
+takes_up = draw < takeup_rate[state_fips]
+```
+
+The `seeded_rng` function uses `_stable_string_hash` — a deterministic hash that does not depend on the `PYTHONHASHSEED` environment variable. This is critical because Python's built-in `hash()` of `str` and `bytes` values is salted with a per-process random value by default (since Python 3.3).
+
+### Why block (not CD or state)?
+
+Blocks are the finest Census geography. A household's block assignment stays the same regardless of how blocks are aggregated — the same (variable, block, household) combination seeds the same draw whether you are building an H5 for a state, a congressional district, or a county. This means:
+
+- State H5s and district H5s are consistent (no draw drift)
+- Future county-level H5s will also be consistent
+- Re-running the pipeline with different area selections yields the same per-household values
+
+### Inactive records
+
+When converting to stacked format, households that are not assigned to a given CD get zero weight. These inactive records must receive an empty string `""` as their block GEOID, not a real block. If they received real blocks, they would inflate the entity count `n` passed to the RNG, shifting the draw positions for active entities and breaking the $X \cdot w$ consistency invariant.
+
+## The $X \cdot w$ Consistency Invariant
+
+### Formal statement
+
+For every target variable $v$ and geography $g$:
+
+$$X_{v,g} \cdot w = \sum_{i \in g} \text{sim.calculate}(v)_i \times w_i$$
+
+where the left side comes from the matrix builder and the right side comes from loading the stacked H5 and running `Microsimulation.calculate()`.
+
+### Why it matters
+
+This invariant is what makes calibration meaningful. Without it, the optimizer's solution (which minimizes $\|Xw - t\|$) does not actually produce a dataset that matches the targets. The weights would be "correct" in the matrix builder's view but produce different totals in the H5 files that users actually load.
+
+### Known sources of drift
+
+1. **Mismatched takeup draws**: The matrix builder and stacked builder use different RNG states. Solved by block-level seeding (see above).
+
+2. **Different block assignments**: The stacked format uses first-clone-wins for multi-clone-same-CD records. With ~436 CDs and 3-10 clones per record, roughly 0.7-10% of records have two or more clones in the same CD. In practice, the residual mismatch is negligible.
+
+3. **Inactive records in RNG calls**: If inactive records (w=0) receive real block GEOIDs, they inflate the entity count for that block's RNG call, shifting draw positions. Solved by using `""` for inactive blocks.
+
+4. **Entity ordering**: Both paths must iterate over entities in the same order (`sim.calculate("{entity}_id", map_to=entity)`). NumPy boolean masking preserves order, so `draws[i]` maps to the same entity in both paths.
+
+### Testing
+
+The `test_xw_consistency.py` test (`pytest -m slow`) verifies this invariant end-to-end:
+
+1. Load base dataset, create geography with uniform weights
+2. Build $X$ with the matrix builder (including takeup rerandomization)
+3. Convert weights to stacked format
+4. Build stacked H5 for selected CDs
+5. Compare $X \cdot w$ vs `sim.calculate() * household_weight` — assert ratio within 1%
+
+## Post-Calibration Gating Workflow
+
+After the pipeline stages H5 files to HuggingFace, two manual review gates determine whether to promote to production.
+
+### Gate 1: Review calibration fit
+
+Load `calibration_log.csv` in the microcalibrate dashboard. This file contains the $X \cdot w$ values from the matrix builder for every target at every epoch.
+
+**What to check:**
+- Loss curve converges (no divergence in later epochs)
+- No individual target groups diverging while others improve
+- Final loss is comparable to or better than the previous production run
+
+If fit is poor, re-calibrate with different hyperparameters (learning rate, lambda_l0, beta, epochs).
+
+### Gate 2: Review simulation quality
+
+```bash
+make validate-staging # states only (~30 min)
+make validate-staging-full # states + districts (~3 hrs)
+make upload-validation # push CSV to HF
+```
+
+This produces `validation_results.csv` with `sim.calculate()` values for every target. Load it in the dashboard's Combined tab alongside `calibration_log.csv`.
+
+**What to check:**
+- `CalibrationVsSimComparison` shows the gap between $X \cdot w$ and `sim.calculate()` values
+- No large regressions vs the previous production run
+- Sanity check column has no FAIL entries
+
+### Promote
+
+If both gates pass:
+- Run the "Promote Local Area H5 Files" GitHub workflow, OR
+- Manually copy staged files to the production paths in the HF repo
+
+### Structural pre-flight
+
+For a quick structural check without loading the full database:
+
+```bash
+make check-sanity # one state, ~2 min
+```
+
+This runs checks for weight non-negativity, entity ID uniqueness, NaN/Inf values, person-household mapping, boolean takeup values, and per-household value ranges.
diff --git a/docs/calibration_matrix.ipynb b/docs/calibration_matrix.ipynb
index 41497b1e8..133f45910 100644
--- a/docs/calibration_matrix.ipynb
+++ b/docs/calibration_matrix.ipynb
@@ -24,10 +24,40 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
"metadata": {},
- "outputs": [],
- "source": "import numpy as np\nimport pandas as pd\nfrom policyengine_us import Microsimulation\nfrom policyengine_us_data.storage import STORAGE_FOLDER\nfrom policyengine_us_data.calibration.unified_matrix_builder import (\n UnifiedMatrixBuilder,\n)\nfrom policyengine_us_data.calibration.clone_and_assign import (\n assign_random_geography,\n)\nfrom policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import (\n create_target_groups,\n drop_target_groups,\n get_geo_level,\n STATE_CODES,\n)\n\ndb_path = STORAGE_FOLDER / \"calibration\" / \"policy_data.db\"\ndb_uri = f\"sqlite:///{db_path}\"\ndataset_path = STORAGE_FOLDER / \"stratified_extended_cps_2024.h5\""
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/baogorek/envs/sep/lib/python3.13/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ }
+ ],
+ "source": [
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "from policyengine_us import Microsimulation\n",
+ "from policyengine_us_data.storage import STORAGE_FOLDER\n",
+ "from policyengine_us_data.calibration.unified_matrix_builder import (\n",
+ " UnifiedMatrixBuilder,\n",
+ ")\n",
+ "from policyengine_us_data.calibration.clone_and_assign import (\n",
+ " assign_random_geography,\n",
+ ")\n",
+ "from policyengine_us_data.calibration.calibration_utils import (\n",
+ " create_target_groups,\n",
+ " drop_target_groups,\n",
+ " get_geo_level,\n",
+ " STATE_CODES,\n",
+ ")\n",
+ "\n",
+ "db_path = STORAGE_FOLDER / \"calibration\" / \"policy_data.db\"\n",
+ "db_uri = f\"sqlite:///{db_path}\"\n",
+ "dataset_path = STORAGE_FOLDER / \"stratified_extended_cps_2024.h5\""
+ ]
},
{
"cell_type": "code",
@@ -40,7 +70,7 @@
"text": [
"Records: 11,999, Clones: 3, Total columns: 35,997\n",
"Matrix shape: (1411, 35997)\n",
- "Non-zero entries: 14,946\n"
+ "Non-zero entries: 29,425\n"
]
}
],
@@ -79,10 +109,36 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 3,
"metadata": {},
- "outputs": [],
- "source": "print(f\"Targets: {X_sparse.shape[0]}\")\nprint(f\"Columns: {X_sparse.shape[1]:,} ({N_CLONES} clones x {n_records:,} records)\")\nprint(f\"Non-zeros: {X_sparse.nnz:,}\")\nprint(f\"Density: {X_sparse.nnz / (X_sparse.shape[0] * X_sparse.shape[1]):.6f}\")\n\ngeo_levels = targets_df[\"geographic_id\"].apply(get_geo_level)\nlevel_names = {0: \"National\", 1: \"State\", 2: \"District\"}\nfor level in [0, 1, 2]:\n n = (geo_levels == level).sum()\n if n > 0:\n print(f\" {level_names[level]}: {n} targets\")"
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Targets: 1411\n",
+ "Columns: 35,997 (3 clones x 11,999 records)\n",
+ "Non-zeros: 29,425\n",
+ "Density: 0.000579\n",
+ " National: 1 targets\n",
+ " State: 102 targets\n",
+ " District: 1308 targets\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(f\"Targets: {X_sparse.shape[0]}\")\n",
+ "print(f\"Columns: {X_sparse.shape[1]:,} ({N_CLONES} clones x {n_records:,} records)\")\n",
+ "print(f\"Non-zeros: {X_sparse.nnz:,}\")\n",
+ "print(f\"Density: {X_sparse.nnz / (X_sparse.shape[0] * X_sparse.shape[1]):.6f}\")\n",
+ "\n",
+ "geo_levels = targets_df[\"geographic_id\"].apply(get_geo_level)\n",
+ "level_names = {0: \"National\", 1: \"State\", 2: \"District\"}\n",
+ "for level in [0, 1, 2]:\n",
+ " n = (geo_levels == level).sum()\n",
+ " if n > 0:\n",
+ " print(f\" {level_names[level]}: {n} targets\")"
+ ]
},
{
"cell_type": "markdown",
@@ -131,13 +187,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Row 705 has 9 non-zero columns\n",
+ "Row 705 has 10 non-zero columns\n",
" Spans 3 clone(s)\n",
- " Spans 9 unique record(s)\n",
+ " Spans 10 unique record(s)\n",
"\n",
- "First non-zero column (8000):\n",
+ "First non-zero column (1212):\n",
" clone_idx: 0\n",
- " record_idx: 8000\n",
+ " record_idx: 1212\n",
" state_fips: 34\n",
" cd_geoid: 3402\n",
" value: 1.00\n"
@@ -189,7 +245,7 @@
" record_idx: 42\n",
" state_fips: 45\n",
" cd_geoid: 4507\n",
- " block_geoid: 450510801013029\n",
+ " block_geoid: 450410002022009\n",
"\n",
"This column has non-zero values in 0 target rows\n"
]
@@ -334,7 +390,7 @@
"\n",
"--- Group 4: District ACA PTC Tax Unit Count (436 targets) ---\n",
" variable geographic_id value\n",
- "tax_unit_count 1001 25064.255490\n",
+ "tax_unit_count 1000 25064.255490\n",
"tax_unit_count 101 9794.081624\n",
"tax_unit_count 102 11597.544977\n",
"tax_unit_count 103 9160.097959\n",
@@ -373,13 +429,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Example SNAP-receiving household: record index 23\n",
- "SNAP value: $70\n",
+ "Example SNAP-receiving household: record index 2\n",
+ "SNAP value: $679\n",
"\n",
"Column positions across 3 clones:\n",
- " col 23: TX (state=48, CD=4829) — 0 non-zero rows\n",
- " col 12022: IL (state=17, CD=1708) — 0 non-zero rows\n",
- " col 24021: FL (state=12, CD=1220) — 3 non-zero rows\n"
+ " col 2: TX (state=48, CD=4814) — 4 non-zero rows\n",
+ " col 12001: IN (state=18, CD=1804) — 3 non-zero rows\n",
+ " col 24000: PA (state=42, CD=4212) — 3 non-zero rows\n"
]
}
],
@@ -413,10 +469,21 @@
"output_type": "stream",
"text": [
"\n",
- "Clone 2 (col 24021, CD 1220):\n",
- " household_count (geo=12): 1.00\n",
- " snap (geo=12): 70.08\n",
- " household_count (geo=1220): 1.00\n"
+ "Clone 0 (col 2, CD 4814):\n",
+ " person_count (geo=US): 3.00\n",
+ " household_count (geo=48): 1.00\n",
+ " snap (geo=48): 678.60\n",
+ " household_count (geo=4814): 1.00\n",
+ "\n",
+ "Clone 1 (col 12001, CD 1804):\n",
+ " household_count (geo=18): 1.00\n",
+ " snap (geo=18): 678.60\n",
+ " household_count (geo=1804): 1.00\n",
+ "\n",
+ "Clone 2 (col 24000, CD 4212):\n",
+ " household_count (geo=42): 1.00\n",
+ " snap (geo=42): 678.60\n",
+ " household_count (geo=4212): 1.00\n"
]
}
],
@@ -455,9 +522,9 @@
"output_type": "stream",
"text": [
"Total cells: 50,791,767\n",
- "Non-zero entries: 14,946\n",
- "Density: 0.000294\n",
- "Sparsity: 99.9706%\n"
+ "Non-zero entries: 29,425\n",
+ "Density: 0.000579\n",
+ "Sparsity: 99.9421%\n"
]
}
],
@@ -472,10 +539,48 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 13,
"metadata": {},
- "outputs": [],
- "source": "nnz_per_row = np.diff(X_sparse.indptr)\nprint(f\"Non-zeros per row:\")\nprint(f\" min: {nnz_per_row.min():,}\")\nprint(f\" median: {int(np.median(nnz_per_row)):,}\")\nprint(f\" mean: {nnz_per_row.mean():,.0f}\")\nprint(f\" max: {nnz_per_row.max():,}\")\n\ngeo_levels = targets_df[\"geographic_id\"].apply(get_geo_level)\nlevel_names = {0: \"National\", 1: \"State\", 2: \"District\"}\nprint(\"\\nBy geographic level:\")\nfor level in [0, 1, 2]:\n mask = (geo_levels == level).values\n if mask.any():\n vals = nnz_per_row[mask]\n print(\n f\" {level_names[level]:10s}: \"\n f\"n={mask.sum():>4d}, \"\n f\"median nnz={int(np.median(vals)):>7,}, \"\n f\"range=[{vals.min():,}, {vals.max():,}]\"\n )"
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Non-zeros per row:\n",
+ " min: 0\n",
+ " median: 10\n",
+ " mean: 21\n",
+ " max: 3,408\n",
+ "\n",
+ "By geographic level:\n",
+ " National : n= 1, median nnz= 3,408, range=[3,408, 3,408]\n",
+ " State : n= 102, median nnz= 80, range=[10, 694]\n",
+ " District : n=1308, median nnz= 9, range=[0, 27]\n"
+ ]
+ }
+ ],
+ "source": [
+ "nnz_per_row = np.diff(X_sparse.indptr)\n",
+ "print(f\"Non-zeros per row:\")\n",
+ "print(f\" min: {nnz_per_row.min():,}\")\n",
+ "print(f\" median: {int(np.median(nnz_per_row)):,}\")\n",
+ "print(f\" mean: {nnz_per_row.mean():,.0f}\")\n",
+ "print(f\" max: {nnz_per_row.max():,}\")\n",
+ "\n",
+ "geo_levels = targets_df[\"geographic_id\"].apply(get_geo_level)\n",
+ "level_names = {0: \"National\", 1: \"State\", 2: \"District\"}\n",
+ "print(\"\\nBy geographic level:\")\n",
+ "for level in [0, 1, 2]:\n",
+ " mask = (geo_levels == level).values\n",
+ " if mask.any():\n",
+ " vals = nnz_per_row[mask]\n",
+ " print(\n",
+ " f\" {level_names[level]:10s}: \"\n",
+ " f\"n={mask.sum():>4d}, \"\n",
+ " f\"median nnz={int(np.median(vals)):>7,}, \"\n",
+ " f\"range=[{vals.min():,}, {vals.max():,}]\"\n",
+ " )"
+ ]
},
{
"cell_type": "code",
@@ -488,9 +593,9 @@
"text": [
"Non-zeros per clone block:\n",
" clone nnz unique_states\n",
- " 0 4962 50\n",
- " 1 4988 50\n",
- " 2 4996 50\n"
+ " 0 9775 51\n",
+ " 1 9810 51\n",
+ " 2 9840 51\n"
]
}
],
@@ -613,15 +718,12 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Achievable targets: 479\n",
- "Impossible targets: 881\n",
+ "Achievable targets: 1358\n",
+ "Impossible targets: 2\n",
"\n",
"Impossible targets by (domain, variable):\n",
- " aca_ptc/aca_ptc: 436\n",
- " aca_ptc/tax_unit_count: 436\n",
- " snap/household_count: 7\n",
- " aca_ptc/person_count: 1\n",
- " snap/snap: 1\n"
+ " aca_ptc/aca_ptc: 1\n",
+ " aca_ptc/tax_unit_count: 1\n"
]
}
],
@@ -657,11 +759,11 @@
"output_type": "stream",
"text": [
"Hardest targets (lowest row_sum / target_value ratio):\n",
- " snap/household_count (geo=621): ratio=0.0000, row_sum=4, target=119,148\n",
- " snap/household_count (geo=3615): ratio=0.0001, row_sum=9, target=173,591\n",
- " snap/snap (geo=46): ratio=0.0001, row_sum=9,421, target=180,195,817\n",
- " snap/household_count (geo=3625): ratio=0.0001, row_sum=4, target=67,315\n",
- " snap/household_count (geo=1702): ratio=0.0001, row_sum=6, target=97,494\n"
+ " aca_ptc/aca_ptc (geo=3612): ratio=0.0000, row_sum=5,439, target=376,216,522\n",
+ " aca_ptc/aca_ptc (geo=2508): ratio=0.0000, row_sum=2,024, target=124,980,814\n",
+ " aca_ptc/tax_unit_count (geo=2508): ratio=0.0000, row_sum=1, target=51,937\n",
+ " aca_ptc/tax_unit_count (geo=3612): ratio=0.0000, row_sum=2, target=73,561\n",
+ " aca_ptc/tax_unit_count (geo=1198): ratio=0.0000, row_sum=1, target=30,419\n"
]
}
],
@@ -692,9 +794,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Final matrix shape: (479, 35997)\n",
- "Final non-zero entries: 9,944\n",
- "Final density: 0.000577\n",
+ "Final matrix shape: (1358, 35997)\n",
+ "Final non-zero entries: 23,018\n",
+ "Final density: 0.000471\n",
"\n",
"This is what the optimizer receives.\n"
]
@@ -747,4 +849,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
-}
\ No newline at end of file
+}
diff --git a/docs/hierarchical_uprating.ipynb b/docs/hierarchical_uprating.ipynb
index 4da30d82c..5839ccbbe 100644
--- a/docs/hierarchical_uprating.ipynb
+++ b/docs/hierarchical_uprating.ipynb
@@ -54,7 +54,7 @@
"from policyengine_us_data.calibration.unified_matrix_builder import (\n",
" UnifiedMatrixBuilder,\n",
")\n",
- "from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import (\n",
+ "from policyengine_us_data.calibration.calibration_utils import (\n",
" STATE_CODES,\n",
")\n",
"\n",
diff --git a/docs/local_area_calibration_setup.ipynb b/docs/local_area_calibration_setup.ipynb
index 2e8614aa9..519e11a94 100644
--- a/docs/local_area_calibration_setup.ipynb
+++ b/docs/local_area_calibration_setup.ipynb
@@ -68,12 +68,12 @@
")\n",
"from policyengine_us_data.utils.randomness import seeded_rng\n",
"from policyengine_us_data.parameters import load_take_up_rate\n",
- "from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import (\n",
+ "from policyengine_us_data.calibration.calibration_utils import (\n",
" get_calculated_variables,\n",
" STATE_CODES,\n",
" get_all_cds_from_database,\n",
")\n",
- "from policyengine_us_data.datasets.cps.local_area_calibration.stacked_dataset_builder import (\n",
+ "from policyengine_us_data.calibration.stacked_dataset_builder import (\n",
" create_sparse_cd_stacked_dataset,\n",
")\n",
"\n",
@@ -96,7 +96,7 @@
"output_type": "stream",
"text": [
"Base dataset: 11,999 households\n",
- "Example household: record_idx=8629, household_id=128694, SNAP=$18,396.00\n"
+ "Example household: record_idx=8629, household_id=130831, SNAP=$0.00\n"
]
}
],
@@ -137,9 +137,9 @@
"output_type": "stream",
"text": [
"Total cloned records: 35,997\n",
- "Unique states: 50\n",
- "Unique CDs: 435\n",
- "Unique blocks: 35508\n"
+ "Unique states: 51\n",
+ "Unique CDs: 436\n",
+ "Unique blocks: 35517\n"
]
}
],
@@ -203,8 +203,8 @@
"
8629 | \n",
" 48 | \n",
" TX | \n",
- " 4817 | \n",
- " 481450004002026 | \n",
+ " 4816 | \n",
+ " 481410030003002 | \n",
" \n",
" \n",
" | 1 | \n",
@@ -213,7 +213,7 @@
" 42 | \n",
" PA | \n",
" 4201 | \n",
- " 420171058013029 | \n",
+ " 420171018051005 | \n",
"
\n",
" \n",
" | 2 | \n",
@@ -222,7 +222,7 @@
" 36 | \n",
" NY | \n",
" 3611 | \n",
- " 360850208041023 | \n",
+ " 360470200002002 | \n",
"
\n",
" \n",
"\n",
@@ -230,9 +230,9 @@
],
"text/plain": [
" clone col state_fips abbr cd_geoid block_geoid\n",
- "0 0 8629 48 TX 4817 481450004002026\n",
- "1 1 20628 42 PA 4201 420171058013029\n",
- "2 2 32627 36 NY 3611 360850208041023"
+ "0 0 8629 48 TX 4816 481410030003002\n",
+ "1 1 20628 42 PA 4201 420171018051005\n",
+ "2 2 32627 36 NY 3611 360470200002002"
]
},
"execution_count": 4,
@@ -280,13 +280,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Global block distribution: 5,765,442 blocks\n",
+ "Global block distribution: 5,769,942 blocks\n",
"Top 5 states by total probability:\n",
- " CA (6): 11.954%\n",
- " TX (48): 8.736%\n",
- " FL (12): 6.437%\n",
- " NY (36): 5.977%\n",
- " PA (42): 3.908%\n"
+ " CA (6): 11.927%\n",
+ " TX (48): 8.716%\n",
+ " FL (12): 6.422%\n",
+ " NY (36): 5.963%\n",
+ " PA (42): 3.899%\n"
]
}
],
@@ -327,10 +327,10 @@
"output_type": "stream",
"text": [
"Example household (record_idx=8629):\n",
- " Original state: NC (37)\n",
+ " Original state: CA (6)\n",
" Clone 0 state: TX (48)\n",
- " Original SNAP: $18,396.00\n",
- " Clone 0 SNAP: $18,396.00\n"
+ " Original SNAP: $0.00\n",
+ " Clone 0 SNAP: $0.00\n"
]
}
],
@@ -410,31 +410,31 @@
" 0 | \n",
" TX | \n",
" 48 | \n",
- " $18,396.00 | \n",
+ " $0.00 | \n",
" \n",
" \n",
" | 1 | \n",
" 1 | \n",
" PA | \n",
" 42 | \n",
- " $18,396.00 | \n",
+ " $0.00 | \n",
"
\n",
" \n",
" | 2 | \n",
" 2 | \n",
" NY | \n",
" 36 | \n",
- " $18,396.00 | \n",
+ " $0.00 | \n",
"
\n",
" \n",
"\n",
""
],
"text/plain": [
- " clone state state_fips SNAP\n",
- "0 0 TX 48 $18,396.00\n",
- "1 1 PA 42 $18,396.00\n",
- "2 2 NY 36 $18,396.00"
+ " clone state state_fips SNAP\n",
+ "0 0 TX 48 $0.00\n",
+ "1 1 PA 42 $0.00\n",
+ "2 2 NY 36 $0.00"
]
},
"execution_count": 7,
@@ -499,10 +499,10 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Unique states mapped: 50\n",
- "Unique CDs mapped: 435\n",
+ "Unique states mapped: 51\n",
+ "Unique CDs mapped: 436\n",
"\n",
- "Columns per state: min=62, median=494, max=4311\n"
+ "Columns per state: min=63, median=490, max=4299\n"
]
}
],
@@ -539,9 +539,9 @@
"text": [
"Example household clone visibility:\n",
"\n",
- "Clone 0 (TX, CD 4817):\n",
+ "Clone 0 (TX, CD 4816):\n",
" Visible to TX state targets: col 8629 in state_to_cols[48]? True\n",
- " Visible to CD 4817 targets: col 8629 in cd_to_cols['4817']? True\n",
+ " Visible to CD 4816 targets: col 8629 in cd_to_cols['4816']? True\n",
" Visible to NC (37) targets: False\n",
"\n",
"Clone 1 (PA, CD 4201):\n",
@@ -612,7 +612,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "8 takeup variables:\n",
+ "9 takeup variables:\n",
"\n",
" takes_up_snap_if_eligible entity=spm_unit rate=82.00%\n",
" takes_up_aca_if_eligible entity=tax_unit rate=67.20%\n",
@@ -621,7 +621,8 @@
" takes_up_early_head_start_if_eligible entity=person rate=9.00%\n",
" takes_up_ssi_if_eligible entity=person rate=50.00%\n",
" would_file_taxes_voluntarily entity=tax_unit rate=5.00%\n",
- " takes_up_medicaid_if_eligible entity=person rate=dict (51 entries)\n"
+ " takes_up_medicaid_if_eligible entity=person rate=dict (51 entries)\n",
+ " takes_up_tanf_if_eligible entity=spm_unit rate=22.00%\n"
]
}
],
@@ -708,14 +709,15 @@
"text": [
"Takeup rates before/after re-randomization (clone 0):\n",
"\n",
- " takes_up_snap_if_eligible before=82.333% after=82.381%\n",
- " takes_up_aca_if_eligible before=66.718% after=67.486%\n",
- " takes_up_dc_ptc before=31.483% after=32.044%\n",
- " takes_up_head_start_if_eligible before=29.963% after=29.689%\n",
- " takes_up_early_head_start_if_eligible before=8.869% after=8.721%\n",
- " takes_up_ssi_if_eligible before=100.000% after=49.776%\n",
- " would_file_taxes_voluntarily before=0.000% after=4.905%\n",
- " takes_up_medicaid_if_eligible before=84.496% after=80.051%\n"
+ " takes_up_snap_if_eligible before=82.116% after=82.364%\n",
+ " takes_up_aca_if_eligible before=67.115% after=67.278%\n",
+ " takes_up_dc_ptc before=31.673% after=31.534%\n",
+ " takes_up_head_start_if_eligible before=100.000% after=29.852%\n",
+ " takes_up_early_head_start_if_eligible before=100.000% after=8.904%\n",
+ " takes_up_ssi_if_eligible before=100.000% after=49.504%\n",
+ " would_file_taxes_voluntarily before=0.000% after=5.115%\n",
+ " takes_up_medicaid_if_eligible before=84.868% after=80.354%\n",
+ " takes_up_tanf_if_eligible before=100.000% after=21.991%\n"
]
}
],
@@ -801,11 +803,17 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "2026-02-13 17:11:22,384 - INFO - Processing clone 1/3 (cols 0-11998, 50 unique states)...\n",
- "2026-02-13 17:11:23,509 - INFO - Processing clone 2/3 (cols 11999-23997, 50 unique states)...\n",
- "2026-02-13 17:11:24,645 - INFO - Processing clone 3/3 (cols 23998-35996, 50 unique states)...\n",
- "2026-02-13 17:11:25,769 - INFO - Assembling matrix from 3 clones...\n",
- "2026-02-13 17:11:25,771 - INFO - Matrix: 538 targets x 35997 cols, 14946 nnz\n"
+ "2026-02-20 15:34:21,531 - INFO - Per-state precomputation: 51 states, 1 hh vars, 1 constraint vars\n",
+ "2026-02-20 15:34:22,137 - INFO - State 1/51 complete\n",
+ "2026-02-20 15:34:27,750 - INFO - State 10/51 complete\n",
+ "2026-02-20 15:34:34,205 - INFO - State 20/51 complete\n",
+ "2026-02-20 15:34:40,885 - INFO - State 30/51 complete\n",
+ "2026-02-20 15:34:47,174 - INFO - State 40/51 complete\n",
+ "2026-02-20 15:34:53,723 - INFO - State 50/51 complete\n",
+ "2026-02-20 15:34:54,415 - INFO - Per-state precomputation done: 51 states\n",
+ "2026-02-20 15:34:54,419 - INFO - Assembling clone 1/3 (cols 0-11998, 51 unique states)...\n",
+ "2026-02-20 15:34:54,516 - INFO - Assembling matrix from 3 clones...\n",
+ "2026-02-20 15:34:54,517 - INFO - Matrix: 538 targets x 35997 cols, 19140 nnz\n"
]
},
{
@@ -813,8 +821,8 @@
"output_type": "stream",
"text": [
"Matrix shape: (538, 35997)\n",
- "Non-zero entries: 14,946\n",
- "Density: 0.000772\n"
+ "Non-zero entries: 19,140\n",
+ "Density: 0.000988\n"
]
}
],
@@ -848,18 +856,9 @@
"text": [
"Example household non-zero pattern across clones:\n",
"\n",
- "Clone 0 (TX, CD 4817): 3 non-zero rows\n",
- " row 39: household_count (geo=48): 1.00\n",
- " row 90: snap (geo=48): 18396.00\n",
- " row 410: household_count (geo=4817): 1.00\n",
- "Clone 1 (PA, CD 4201): 3 non-zero rows\n",
- " row 34: household_count (geo=42): 1.00\n",
- " row 85: snap (geo=42): 18396.00\n",
- " row 358: household_count (geo=4201): 1.00\n",
- "Clone 2 (NY, CD 3611): 3 non-zero rows\n",
- " row 27: household_count (geo=36): 1.00\n",
- " row 78: snap (geo=36): 18396.00\n",
- " row 292: household_count (geo=3611): 1.00\n"
+ "Clone 0 (TX, CD 4816): 0 non-zero rows\n",
+ "Clone 1 (PA, CD 4201): 0 non-zero rows\n",
+ "Clone 2 (NY, CD 3611): 0 non-zero rows\n"
]
}
],
@@ -993,6 +992,7 @@
"Extracted weights for 2 CDs from full weight matrix\n",
"Total active household-CD pairs: 277\n",
"Total weight in W matrix: 281\n",
+ "Warning: No rent data for CD 201, using geoadj=1.0\n",
"Processing CD 201 (2/2)...\n"
]
},
@@ -1000,10 +1000,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "2026-02-13 17:11:40,873 - INFO - HTTP Request: GET https://huggingface.co/api/models/policyengine/policyengine-us-data \"HTTP/1.1 200 OK\"\n",
- "2026-02-13 17:11:40,899 - INFO - HTTP Request: HEAD https://huggingface.co/policyengine/policyengine-us-data/resolve/main/enhanced_cps_2024.h5 \"HTTP/1.1 302 Found\"\n",
- "Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.\n",
- "2026-02-13 17:11:40,899 - WARNING - Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.\n"
+ "2026-02-20 15:35:04,090 - INFO - HTTP Request: GET https://huggingface.co/api/models/policyengine/policyengine-us-data \"HTTP/1.1 200 OK\"\n",
+ "2026-02-20 15:35:04,123 - INFO - HTTP Request: HEAD https://huggingface.co/policyengine/policyengine-us-data/resolve/main/enhanced_cps_2024.h5 \"HTTP/1.1 302 Found\"\n"
]
},
{
@@ -1013,7 +1011,7 @@
"\n",
"Combining 2 CD DataFrames...\n",
"Total households across all CDs: 277\n",
- "Combined DataFrame shape: (726, 222)\n",
+ "Combined DataFrame shape: (716, 219)\n",
"\n",
"Reindexing all entity IDs using 25k ranges per CD...\n",
" Created 277 unique households across 2 CDs\n",
@@ -1022,12 +1020,12 @@
" Reindexing SPM units...\n",
" Reindexing marital units...\n",
" Reindexing families...\n",
- " Final persons: 726\n",
+ " Final persons: 716\n",
" Final households: 277\n",
- " Final tax units: 373\n",
- " Final SPM units: 291\n",
- " Final marital units: 586\n",
- " Final families: 309\n",
+ " Final tax units: 387\n",
+ " Final SPM units: 290\n",
+ " Final marital units: 587\n",
+ " Final families: 318\n",
"\n",
"Weights in combined_df AFTER reindexing:\n",
" HH weight sum: 0.00M\n",
@@ -1035,8 +1033,8 @@
" Ratio: 1.00\n",
"\n",
"Overflow check:\n",
- " Max person ID after reindexing: 5,025,335\n",
- " Max person ID × 100: 502,533,500\n",
+ " Max person ID after reindexing: 5,025,365\n",
+ " Max person ID × 100: 502,536,500\n",
" int32 max: 2,147,483,647\n",
" ✓ No overflow risk!\n",
"\n",
@@ -1044,15 +1042,15 @@
"Building simulation from Dataset...\n",
"\n",
"Saving to calibration_output/results.h5...\n",
- "Found 175 input variables to save\n",
- "Variables saved: 218\n",
- "Variables skipped: 3763\n",
+ "Found 172 input variables to save\n",
+ "Variables saved: 215\n",
+ "Variables skipped: 3825\n",
"Sparse CD-stacked dataset saved successfully!\n",
"Household mapping saved to calibration_output/mappings/results_household_mapping.csv\n",
"\n",
"Verifying saved file...\n",
" Final households: 277\n",
- " Final persons: 726\n",
+ " Final persons: 716\n",
" Total population (from household weights): 281\n"
]
},
@@ -1089,17 +1087,17 @@
"text": [
"Stacked dataset: 277 households\n",
"\n",
- "Example household (original_id=128694) in mapping:\n",
+ "Example household (original_id=130831) in mapping:\n",
"\n",
" new_household_id original_household_id congressional_district state_fips\n",
- " 108 128694 201 2\n",
- " 25097 128694 3701 37\n",
+ " 108 130831 201 2\n",
+ " 25097 130831 3701 37\n",
"\n",
"In stacked dataset:\n",
"\n",
- " household_id congressional_district_geoid household_weight state_fips snap\n",
- " 108 201 3.5 2 23640.0\n",
- " 25097 3701 2.5 37 18396.0\n"
+ " household_id congressional_district_geoid household_weight state_fips snap\n",
+ " 108 201 3.5 2 0.0\n",
+ " 25097 3701 2.5 37 0.0\n"
]
}
],
diff --git a/docs/myst.yml b/docs/myst.yml
index c38666af9..56d74ec45 100644
--- a/docs/myst.yml
+++ b/docs/myst.yml
@@ -26,6 +26,7 @@ project:
- file: methodology.md
- file: long_term_projections.md
- file: local_area_calibration_setup.ipynb
+ - file: calibration.md
- file: discussion.md
- file: conclusion.md
- file: appendix.md
diff --git a/modal_app/README.md b/modal_app/README.md
index 0b10cf726..49c976aab 100644
--- a/modal_app/README.md
+++ b/modal_app/README.md
@@ -22,41 +22,217 @@ modal run modal_app/remote_calibration_runner.py --branch --epochs
| `--epochs` | `200` | Number of training epochs |
| `--gpu` | `T4` | GPU type: `T4`, `A10`, `A100-40GB`, `A100-80GB`, `H100` |
| `--output` | `calibration_weights.npy` | Local path for weights file |
-| `--log-output` | `calibration_log.csv` | Local path for calibration log |
+| `--log-output` | `unified_diagnostics.csv` | Local path for diagnostics log |
+| `--log-freq` | (none) | Log every N epochs to `calibration_log.csv` |
+| `--push-results` | `False` | Upload weights, blocks, and logs to HuggingFace |
+| `--trigger-publish` | `False` | Fire `repository_dispatch` to trigger the Publish workflow |
+| `--target-config` | (none) | Target configuration name |
+| `--beta` | (none) | L0 relaxation parameter |
+| `--lambda-l0` | (none) | L0 penalty weight |
+| `--lambda-l2` | (none) | L2 penalty weight |
+| `--learning-rate` | (none) | Optimizer learning rate |
+| `--package-path` | (none) | Local path to a pre-built calibration package (uploads to Modal volume, then fits) |
+| `--prebuilt-matrices` | `False` | Fit from pre-built package on Modal volume |
+| `--full-pipeline` | `False` | Force full rebuild even if a package exists on the volume |
+| `--county-level` | `False` | Include county-level targets |
+| `--workers` | `1` | Number of parallel workers |
+
+### Examples
+
+**Two-step workflow (recommended):**
+
+Step 1 — Build the X matrix on CPU (no GPU cost, 10h timeout):
+```bash
+modal run modal_app/remote_calibration_runner.py::build_package \
+ --branch main
+```
-### Example
+Step 2 — Fit weights from the pre-built package on GPU:
+```bash
+modal run modal_app/remote_calibration_runner.py::main \
+ --branch main --epochs 200 --gpu A100-80GB \
+ --prebuilt-matrices --push-results
+```
+
+**Full pipeline (single step, requires enough timeout for matrix build + fit):**
+```bash
+modal run modal_app/remote_calibration_runner.py::main \
+ --branch main --epochs 200 --gpu A100-80GB \
+ --full-pipeline --push-results
+```
+Fit, push, and trigger the publish workflow:
```bash
-modal run modal_app/remote_calibration_runner.py --branch health-insurance-premiums --epochs 100 --gpu T4
+modal run modal_app/remote_calibration_runner.py::main \
+ --gpu A100-80GB --epochs 200 \
+ --prebuilt-matrices --push-results --trigger-publish
```
## Output Files
-- **calibration_weights.npy** - Fitted household weights
-- **calibration_log.csv** - Per-target performance metrics across epochs (target_name, estimate, target, epoch, error, rel_error, abs_error, rel_abs_error, loss)
+Each run writes whichever of these local files the calibration script emits:
-## Changing Hyperparameters
+- **calibration_weights.npy** — Fitted household weights
+- **unified_diagnostics.csv** — Final per-target diagnostics
+- **calibration_log.csv** — Per-target metrics across epochs (requires `--log-freq`)
+- **unified_run_config.json** — Run configuration and summary stats
+- **stacked_blocks.npy** — Census block assignments for stacked records
-Hyperparameters are in `policyengine_us_data/datasets/cps/local_area_calibration/fit_calibration_weights.py`:
+## Artifact Upload to HuggingFace
+
+The `--push-results` flag uploads all artifacts to HuggingFace in a single
+atomic commit after writing them locally:
+
+| Local file | HF path |
+|------------|---------|
+| `calibration_weights.npy` | `calibration/calibration_weights.npy` |
+| `stacked_blocks.npy` | `calibration/stacked_blocks.npy` |
+| `calibration_log.csv` | `calibration/logs/calibration_log.csv` |
+| `unified_diagnostics.csv` | `calibration/logs/unified_diagnostics.csv` |
+| `unified_run_config.json` | `calibration/logs/unified_run_config.json` |
+
+Each upload overwrites the previous files. HF git history provides implicit
+versioning — browse past commits to see earlier runs.
+
+## Triggering the Publish Workflow
+
+The `--trigger-publish` flag fires a `repository_dispatch` event
+(`calibration-updated`) on GitHub, which starts the "Publish Local Area H5
+Files" workflow. Requires `GITHUB_TOKEN` or
+`POLICYENGINE_US_DATA_GITHUB_TOKEN` set locally.
+
+### Downloading logs
```python
-BETA = 0.35
-GAMMA = -0.1
-ZETA = 1.1
-INIT_KEEP_PROB = 0.999
-LOG_WEIGHT_JITTER_SD = 0.05
-LOG_ALPHA_JITTER_SD = 0.01
-LAMBDA_L0 = 1e-8
-LAMBDA_L2 = 1e-8
-LEARNING_RATE = 0.15
-```
-
-To change them:
-1. Edit `fit_calibration_weights.py`
-2. Commit and push to your branch
-3. Re-run the Modal command with that branch
+from policyengine_us_data.utils.huggingface import download_calibration_logs
+
+paths = download_calibration_logs("/tmp/cal_logs")
+# {"calibration_log": Path(...), "diagnostics": Path(...), "config": Path(...)}
+```
+
+Pass `version="<revision>"` to download from a specific HF revision.
+
+### Viewing logs in the microcalibrate dashboard
+
+The [microcalibrate dashboard](https://github.com/PolicyEngine/microcalibrate)
+has a **Hugging Face** tab that loads `calibration_log.csv` directly from HF:
+
+1. Open the dashboard
+2. Click the **Hugging Face** tab
+3. Defaults are pre-filled — click **Load**
+4. Change the **Revision** field to load from a specific HF commit or tag
## Important Notes
-- **Keep your connection open** - Modal needs to stay connected to download results. Don't close your laptop or let it sleep until you see the local "Weights saved to:" and "Calibration log saved to:" messages.
-- Modal clones from GitHub, so local changes must be pushed before they take effect.
+- **Keep your connection open** — Modal needs to stay connected to download
+ results. Don't close your laptop or let it sleep until you see the local
+ "Weights saved to:" message.
+- Modal clones from GitHub, so local changes must be pushed before they
+ take effect.
+- `--push-results` requires the `HUGGING_FACE_TOKEN` environment variable
+ to be set locally (not just as a Modal secret).
+- `--trigger-publish` requires `GITHUB_TOKEN` or
+ `POLICYENGINE_US_DATA_GITHUB_TOKEN` set locally.
+
+## Full Pipeline Reference
+
+The calibration pipeline has six stages. Each can be run locally, via Modal CLI, or via GitHub Actions.
+
+### Stage 1: Build data
+
+Produces `stratified_extended_cps_2024.h5` from raw CPS/PUF/ACS inputs.
+
+| Method | Command |
+|--------|---------|
+| **Local** | `make data` |
+| **Modal (CI)** | `modal run modal_app/data_build.py --branch=<branch>` |
+| **GitHub Actions** | Automatic on merge to `main` via `code_changes.yaml` → `reusable_test.yaml` (with `full_suite: true`). Also triggered by `pr_code_changes.yaml` on PRs. |
+
+Notes:
+- `make data` stops at `create_stratified_cps.py`. Use `make data-legacy` to also build `enhanced_cps.py` and `small_enhanced_cps.py`.
+- `data_build.py` (CI) always builds the full suite including enhanced_cps.
+
+### Stage 2: Upload inputs to HuggingFace
+
+Pushes the dataset and (optionally) database to HF so Modal can download them.
+
+| Artifact | Command |
+|----------|---------|
+| Dataset | `make upload-dataset` |
+| Database | `make upload-database` |
+
+The database is relatively stable; only re-upload after `make database` or `make database-refresh`.
+
+### Stage 3: Build calibration matrices
+
+Downloads dataset + database from HF, builds the X matrix, saves to Modal volume. CPU-only, no GPU cost.
+
+| Method | Command |
+|--------|---------|
+| **Local** | `make calibrate-build` |
+| **Modal CLI** | `make build-matrices BRANCH=<branch>` (aka `modal run modal_app/remote_calibration_runner.py::build_package --branch=<branch>`) |
+
+### Stage 4: Fit calibration weights
+
+Loads pre-built matrices from Modal volume, fits L0-regularized weights on GPU.
+
+| Method | Command |
+|--------|---------|
+| **Local (CPU)** | `make calibrate` |
+| **Modal CLI** | `make calibrate-modal BRANCH=<branch> GPU=<gpu> EPOCHS=<epochs>` |
+
+`make calibrate-modal` passes `--prebuilt-matrices --push-results` automatically.
+
+Full example:
+```
+modal run modal_app/remote_calibration_runner.py::main \
+ --branch calibration-pipeline-improvements \
+ --gpu T4 --epochs 1000 \
+ --beta 0.65 --lambda-l0 1e-6 --lambda-l2 1e-8 \
+ --log-freq 500 \
+ --target-config policyengine_us_data/calibration/target_config.yaml \
+ --prebuilt-matrices --push-results
+```
+
+**Safety check**: If a pre-built package exists on the volume and you don't pass `--prebuilt-matrices` or `--full-pipeline`, the runner refuses to proceed and tells you which flag to add. This prevents accidentally rebuilding from scratch.
+
+Artifacts uploaded to HF by `--push-results`:
+
+| Local file | HF path |
+|------------|---------|
+| `calibration_weights.npy` | `calibration/calibration_weights.npy` |
+| `stacked_blocks.npy` | `calibration/stacked_blocks.npy` |
+| `calibration_log.csv` | `calibration/logs/calibration_log.csv` |
+| `unified_diagnostics.csv` | `calibration/logs/unified_diagnostics.csv` |
+| `unified_run_config.json` | `calibration/logs/unified_run_config.json` |
+
+### Stage 5: Build and stage local area H5 files
+
+Downloads weights + dataset + database from HF, builds state/district/city H5 files.
+
+| Method | Command |
+|--------|---------|
+| **Local** | `python policyengine_us_data/calibration/publish_local_area.py --rerandomize-takeup` |
+| **Modal CLI** | `make stage-h5s BRANCH=<branch>` (aka `modal run modal_app/local_area.py::main --branch=<branch> --num-workers=8`) |
+| **GitHub Actions** | "Publish Local Area H5 Files" workflow — manual trigger via `workflow_dispatch`, or automatic via `repository_dispatch` (`--trigger-publish` flag), or on code push to `main` touching `calibration/` or `modal_app/`. |
+
+This stages H5s to HF `staging/` paths. It does NOT promote to production or GCS.
+
+### Stage 6: Promote (manual gate)
+
+Moves files from HF staging to production paths and uploads to GCS.
+
+| Method | Command |
+|--------|---------|
+| **Modal CLI** | `modal run modal_app/local_area.py::main_promote --version=<version>` |
+| **GitHub Actions** | "Promote Local Area H5 Files" workflow — manual `workflow_dispatch` only. Requires `version` input. |
+
+### One-command pipeline
+
+For the common case (local data build → Modal calibration → Modal staging):
+
+```
+make pipeline GPU=T4 EPOCHS=1000 BRANCH=calibration-pipeline-improvements
+```
+
+This chains: `data` → `upload-dataset` → `build-matrices` → `calibrate-modal` → `stage-h5s`.
diff --git a/modal_app/data_build.py b/modal_app/data_build.py
index 131e7f0bf..8c75187d2 100644
--- a/modal_app/data_build.py
+++ b/modal_app/data_build.py
@@ -55,10 +55,13 @@
"policyengine_us_data/storage/enhanced_cps_2024.h5",
"calibration_log.csv",
],
- "policyengine_us_data/datasets/cps/"
- "local_area_calibration/create_stratified_cps.py": (
+ "policyengine_us_data/calibration/create_stratified_cps.py": (
"policyengine_us_data/storage/stratified_extended_cps_2024.h5"
),
+ "policyengine_us_data/calibration/create_source_imputed_cps.py": (
+ "policyengine_us_data/storage/"
+ "source_imputed_stratified_extended_cps_2024.h5"
+ ),
"policyengine_us_data/datasets/cps/small_enhanced_cps.py": (
"policyengine_us_data/storage/small_enhanced_cps_2024.h5"
),
@@ -70,7 +73,7 @@
"policyengine_us_data/tests/test_database.py",
"policyengine_us_data/tests/test_pandas3_compatibility.py",
"policyengine_us_data/tests/test_datasets/",
- "policyengine_us_data/tests/test_local_area_calibration/",
+ "policyengine_us_data/tests/test_calibration/",
]
@@ -408,11 +411,9 @@ def build_datasets(
),
executor.submit(
run_script_with_checkpoint,
- "policyengine_us_data/datasets/cps/"
- "local_area_calibration/create_stratified_cps.py",
+ "policyengine_us_data/calibration/create_stratified_cps.py",
SCRIPT_OUTPUTS[
- "policyengine_us_data/datasets/cps/"
- "local_area_calibration/create_stratified_cps.py"
+ "policyengine_us_data/calibration/create_stratified_cps.py"
],
branch,
checkpoint_volume,
diff --git a/modal_app/local_area.py b/modal_app/local_area.py
index 92e068335..33f798cf9 100644
--- a/modal_app/local_area.py
+++ b/modal_app/local_area.py
@@ -128,6 +128,88 @@ def get_completed_from_volume(version_dir: Path) -> set:
return completed
+def run_phase(
+    phase_name: str,
+    states: List[str],
+    districts: List[str],
+    cities: List[str],
+    num_workers: int,
+    completed: set,
+    branch: str,
+    version: str,
+    calibration_inputs: Dict[str, str],
+    version_dir: Path,
+) -> set:
+    """Run a single build phase, spawning workers and collecting results.
+
+    The work is split with ``partition_work``, one ``build_areas_worker``
+    is spawned per chunk, and the staging volume is reloaded afterwards so
+    the returned set reflects what ``get_completed_from_volume`` finds in
+    ``version_dir`` rather than what the workers reported.
+
+    Worker failures and crashes are printed but never raised here; the
+    caller is expected to compare the returned set against what it
+    expected to be built.
+
+    Args:
+        phase_name: Label used in the log output for this phase.
+        states: State items forwarded to ``partition_work``.
+        districts: District items forwarded to ``partition_work``.
+        cities: City items forwarded to ``partition_work``.
+        num_workers: Worker count forwarded to ``partition_work``.
+        completed: Items considered already built before this phase.
+        branch: Branch name forwarded to each worker.
+        version: Version string forwarded to each worker.
+        calibration_inputs: Input paths forwarded to each worker.
+        version_dir: Directory scanned for completed items after the run.
+
+    Returns:
+        ``completed`` unchanged when no work remains; otherwise the set
+        returned by ``get_completed_from_volume(version_dir)``.
+    """
+    work_chunks = partition_work(
+        states, districts, cities, num_workers, completed
+    )
+    total_remaining = sum(len(c) for c in work_chunks)
+
+    print(f"\n--- Phase: {phase_name} ---")
+    print(
+        f"Remaining work: {total_remaining} items "
+        f"across {len(work_chunks)} workers"
+    )
+
+    # Nothing to build: skip spawning and leave the completed set as-is.
+    if total_remaining == 0:
+        print(f"All {phase_name} items already built!")
+        return completed
+
+    # Spawn every worker first so the chunks run concurrently; results are
+    # collected in a second pass below.
+    handles = []
+    for i, chunk in enumerate(work_chunks):
+        print(f" Worker {i}: {len(chunk)} items")
+        handle = build_areas_worker.spawn(
+            branch=branch,
+            version=version,
+            work_items=chunk,
+            calibration_inputs=calibration_inputs,
+        )
+        handles.append(handle)
+
+    print(f"Waiting for {phase_name} workers to complete...")
+    all_results = []
+    all_errors = []
+
+    for i, handle in enumerate(handles):
+        try:
+            result = handle.get()
+            all_results.append(result)
+            print(
+                f" Worker {i}: {len(result['completed'])} completed, "
+                f"{len(result['failed'])} failed"
+            )
+            if result["errors"]:
+                all_errors.extend(result["errors"])
+        except Exception as e:
+            # A crashed worker is recorded as an error but contributes
+            # nothing to all_results, so it is not part of the
+            # worker-reported "Failed" count printed below.
+            all_errors.append({"worker": i, "error": str(e)})
+            print(f" Worker {i}: CRASHED - {e}")
+
+    total_completed = sum(len(r["completed"]) for r in all_results)
+    total_failed = sum(len(r["failed"]) for r in all_results)
+
+    # Reload before scanning so files committed by the workers are visible.
+    staging_volume.reload()
+    volume_completed = get_completed_from_volume(version_dir)
+    volume_new = volume_completed - completed
+
+    print(f"\n{phase_name} summary (worker-reported):")
+    print(f" Completed: {total_completed}")
+    print(f" Failed: {total_failed}")
+    print(f"{phase_name} summary (volume verification):")
+    print(f" Files on volume: {len(volume_completed)}")
+    print(f" New files this run: {len(volume_new)}")
+
+    if all_errors:
+        print(f"\nErrors ({len(all_errors)}):")
+        # Show at most five errors, each truncated to 100 characters.
+        for err in all_errors[:5]:
+            err_msg = err.get("error", "Unknown")[:100]
+            print(f" - {err.get('item', err.get('worker'))}: " f"{err_msg}")
+        if len(all_errors) > 5:
+            print(f" ... and {len(all_errors) - 5} more")
+
+    return volume_completed
+
+
@app.function(
image=image,
secrets=[hf_secret, gcp_secret],
@@ -154,23 +236,46 @@ def build_areas_worker(
work_items_json = json.dumps(work_items)
+ worker_cmd = [
+ "uv",
+ "run",
+ "python",
+ "modal_app/worker_script.py",
+ "--work-items",
+ work_items_json,
+ "--weights-path",
+ calibration_inputs["weights"],
+ "--dataset-path",
+ calibration_inputs["dataset"],
+ "--db-path",
+ calibration_inputs["database"],
+ "--output-dir",
+ str(output_dir),
+ ]
+ if "blocks" in calibration_inputs:
+ worker_cmd.extend(
+ [
+ "--calibration-blocks",
+ calibration_inputs["blocks"],
+ ]
+ )
+ if "geo_labels" in calibration_inputs:
+ worker_cmd.extend(
+ [
+ "--geo-labels",
+ calibration_inputs["geo_labels"],
+ ]
+ )
+ if "stacked_takeup" in calibration_inputs:
+ worker_cmd.extend(
+ [
+ "--stacked-takeup",
+ calibration_inputs["stacked_takeup"],
+ ]
+ )
+
result = subprocess.run(
- [
- "uv",
- "run",
- "python",
- "modal_app/worker_script.py",
- "--work-items",
- work_items_json,
- "--weights-path",
- calibration_inputs["weights"],
- "--dataset-path",
- calibration_inputs["dataset"],
- "--db-path",
- calibration_inputs["database"],
- "--output-dir",
- str(output_dir),
- ],
+ worker_cmd,
capture_output=True,
text=True,
env=os.environ.copy(),
@@ -254,18 +359,18 @@ def validate_staging(branch: str, version: str) -> Dict:
@app.function(
image=image,
- secrets=[hf_secret, gcp_secret],
+ secrets=[hf_secret],
volumes={VOLUME_MOUNT: staging_volume},
memory=8192,
timeout=14400,
)
def upload_to_staging(branch: str, version: str, manifest: Dict) -> str:
"""
- Upload files to GCS (production) and HuggingFace (staging only).
+ Upload files to HuggingFace staging only.
+ GCS is updated during promote_publish, not here.
Promote must be run separately via promote_publish.
"""
- setup_gcp_credentials()
setup_repo(branch)
manifest_json = json.dumps(manifest)
@@ -280,10 +385,7 @@ def upload_to_staging(branch: str, version: str, manifest: Dict) -> str:
import json
from pathlib import Path
from policyengine_us_data.utils.manifest import verify_manifest
-from policyengine_us_data.utils.data_upload import (
- upload_local_area_file,
- upload_to_staging_hf,
-)
+from policyengine_us_data.utils.data_upload import upload_to_staging_hf
manifest = json.loads('''{manifest_json}''')
version = "{version}"
@@ -305,20 +407,6 @@ def upload_to_staging(branch: str, version: str, manifest: Dict) -> str:
local_path = version_dir / rel_path
files_with_paths.append((local_path, rel_path))
-# Upload to GCS (direct to production paths)
-print(f"Uploading {{len(files_with_paths)}} files to GCS...")
-gcs_count = 0
-for local_path, rel_path in files_with_paths:
- subdirectory = str(Path(rel_path).parent)
- upload_local_area_file(
- str(local_path),
- subdirectory,
- version=version,
- skip_hf=True,
- )
- gcs_count += 1
-print(f"Uploaded {{gcs_count}} files to GCS")
-
# Upload to HuggingFace staging/
print(f"Uploading {{len(files_with_paths)}} files to HuggingFace staging/...")
hf_count = upload_to_staging_hf(files_with_paths, version)
@@ -336,24 +424,26 @@ def upload_to_staging(branch: str, version: str, manifest: Dict) -> str:
return (
f"Staged version {version} with {len(manifest['files'])} files. "
- f"Run promote workflow to publish to HuggingFace production."
+ f"Run promote workflow to publish to HuggingFace production and GCS."
)
@app.function(
image=image,
- secrets=[hf_secret],
+ secrets=[hf_secret, gcp_secret],
volumes={VOLUME_MOUNT: staging_volume},
memory=4096,
timeout=3600,
)
def promote_publish(branch: str = "main", version: str = "") -> str:
"""
- Promote staged files from HF staging/ to production paths, then cleanup.
+ Promote staged files from HF staging/ to production paths,
+ upload to GCS, then cleanup HF staging.
Reads the manifest from the Modal staging volume to determine which
files to promote.
"""
+ setup_gcp_credentials()
setup_repo(branch)
staging_dir = Path(VOLUME_MOUNT)
@@ -379,17 +469,34 @@ def promote_publish(branch: str = "main", version: str = "") -> str:
"-c",
f"""
import json
+from pathlib import Path
from policyengine_us_data.utils.data_upload import (
promote_staging_to_production_hf,
cleanup_staging_hf,
+ upload_local_area_file,
)
rel_paths = json.loads('''{rel_paths_json}''')
version = "{version}"
+version_dir = Path("{VOLUME_MOUNT}") / version
print(f"Promoting {{len(rel_paths)}} files from staging/ to production...")
promoted = promote_staging_to_production_hf(rel_paths, version)
-print(f"Promoted {{promoted}} files to production")
+print(f"Promoted {{promoted}} files to HuggingFace production")
+
+print(f"Uploading {{len(rel_paths)}} files to GCS...")
+gcs_count = 0
+for rel_path in rel_paths:
+ local_path = version_dir / rel_path
+ subdirectory = str(Path(rel_path).parent)
+ upload_local_area_file(
+ str(local_path),
+ subdirectory,
+ version=version,
+ skip_hf=True,
+ )
+ gcs_count += 1
+print(f"Uploaded {{gcs_count}} files to GCS")
print("Cleaning up staging/...")
cleaned = cleanup_staging_hf(rel_paths, version)
@@ -419,6 +526,7 @@ def coordinate_publish(
branch: str = "main",
num_workers: int = 8,
skip_upload: bool = False,
+ skip_download: bool = False,
) -> str:
"""Coordinate the full publishing workflow."""
setup_gcp_credentials()
@@ -428,24 +536,46 @@ def coordinate_publish(
print(f"Publishing version {version} from branch {branch}")
print(f"Using {num_workers} parallel workers")
+ import shutil
+
staging_dir = Path(VOLUME_MOUNT)
version_dir = staging_dir / version
+ if version_dir.exists():
+ print(f"Clearing stale build directory: {version_dir}")
+ shutil.rmtree(version_dir)
version_dir.mkdir(parents=True, exist_ok=True)
calibration_dir = staging_dir / "calibration_inputs"
- calibration_dir.mkdir(parents=True, exist_ok=True)
# hf_hub_download preserves directory structure, so files are in calibration/ subdir
- weights_path = (
- calibration_dir / "calibration" / "w_district_calibration.npy"
- )
- dataset_path = (
- calibration_dir / "calibration" / "stratified_extended_cps.h5"
- )
+ weights_path = calibration_dir / "calibration" / "calibration_weights.npy"
db_path = calibration_dir / "calibration" / "policy_data.db"
- if not all(p.exists() for p in [weights_path, dataset_path, db_path]):
- print("Downloading calibration inputs...")
+ if skip_download:
+ print("Verifying pre-pushed calibration inputs...")
+ staging_volume.reload()
+ dataset_path = (
+ calibration_dir
+ / "calibration"
+ / "source_imputed_stratified_extended_cps.h5"
+ )
+ required = {
+ "weights": weights_path,
+ "dataset": dataset_path,
+ "database": db_path,
+ }
+ for label, p in required.items():
+ if not p.exists():
+ raise RuntimeError(
+ f"Missing required calibration input " f"({label}): {p}"
+ )
+ print("All required calibration inputs found on volume.")
+ else:
+ if calibration_dir.exists():
+ shutil.rmtree(calibration_dir)
+ calibration_dir.mkdir(parents=True, exist_ok=True)
+
+ print("Downloading calibration inputs from HuggingFace...")
result = subprocess.run(
[
"uv",
@@ -464,15 +594,31 @@ def coordinate_publish(
if result.returncode != 0:
raise RuntimeError(f"Download failed: {result.stderr}")
staging_volume.commit()
- print("Calibration inputs downloaded and cached on volume")
- else:
- print("Using cached calibration inputs from volume")
+ print("Calibration inputs downloaded")
+ dataset_path = (
+ calibration_dir
+ / "calibration"
+ / "source_imputed_stratified_extended_cps.h5"
+ )
+
+ blocks_path = calibration_dir / "calibration" / "stacked_blocks.npy"
+ geo_labels_path = calibration_dir / "calibration" / "geo_labels.json"
calibration_inputs = {
"weights": str(weights_path),
"dataset": str(dataset_path),
"database": str(db_path),
}
+ if blocks_path.exists():
+ calibration_inputs["blocks"] = str(blocks_path)
+ print(f"Calibration blocks found: {blocks_path}")
+ if geo_labels_path.exists():
+ calibration_inputs["geo_labels"] = str(geo_labels_path)
+ print(f"Geo labels found: {geo_labels_path}")
+ takeup_path = calibration_dir / "calibration" / "stacked_takeup.npz"
+ if takeup_path.exists():
+ calibration_inputs["stacked_takeup"] = str(takeup_path)
+ print(f"Stacked takeup found: {takeup_path}")
result = subprocess.run(
[
@@ -482,11 +628,11 @@ def coordinate_publish(
"-c",
f"""
import json
-from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import (
+from policyengine_us_data.calibration.calibration_utils import (
get_all_cds_from_database,
STATE_CODES,
)
-from policyengine_us_data.datasets.cps.local_area_calibration.publish_local_area import (
+from policyengine_us_data.calibration.publish_local_area import (
get_district_friendly_name,
)
@@ -514,70 +660,49 @@ def coordinate_publish(
completed = get_completed_from_volume(version_dir)
print(f"Found {len(completed)} already-completed items on volume")
- work_chunks = partition_work(
- states, districts, cities, num_workers, completed
+ phase_args = dict(
+ num_workers=num_workers,
+ branch=branch,
+ version=version,
+ calibration_inputs=calibration_inputs,
+ version_dir=version_dir,
)
- total_remaining = sum(len(c) for c in work_chunks)
- print(
- f"Remaining work: {total_remaining} items "
- f"across {len(work_chunks)} workers"
+ completed = run_phase(
+ "States",
+ states=states,
+ districts=[],
+ cities=[],
+ completed=completed,
+ **phase_args,
)
- if total_remaining == 0:
- print("All items already built!")
- else:
- print("\nSpawning workers...")
- handles = []
- for i, chunk in enumerate(work_chunks):
- print(f" Worker {i}: {len(chunk)} items")
- handle = build_areas_worker.spawn(
- branch=branch,
- version=version,
- work_items=chunk,
- calibration_inputs=calibration_inputs,
- )
- handles.append(handle)
-
- print("\nWaiting for workers to complete...")
- all_results = []
- all_errors = []
-
- for i, handle in enumerate(handles):
- try:
- result = handle.get()
- all_results.append(result)
- print(
- f" Worker {i}: {len(result['completed'])} completed, "
- f"{len(result['failed'])} failed"
- )
- if result["errors"]:
- all_errors.extend(result["errors"])
- except Exception as e:
- all_errors.append({"worker": i, "error": str(e)})
- print(f" Worker {i}: CRASHED - {e}")
-
- total_completed = sum(len(r["completed"]) for r in all_results)
- total_failed = sum(len(r["failed"]) for r in all_results)
-
- print(f"\nBuild summary:")
- print(f" Completed: {total_completed}")
- print(f" Failed: {total_failed}")
- print(f" Previously completed: {len(completed)}")
-
- if all_errors:
- print(f"\nErrors ({len(all_errors)}):")
- for err in all_errors[:5]:
- err_msg = err.get("error", "Unknown")[:100]
- print(f" - {err.get('item', err.get('worker'))}: {err_msg}")
- if len(all_errors) > 5:
- print(f" ... and {len(all_errors) - 5} more")
-
- if total_failed > 0:
- raise RuntimeError(
- f"Build incomplete: {total_failed} failures. "
- f"Volume preserved for retry."
- )
+ completed = run_phase(
+ "Districts",
+ states=[],
+ districts=districts,
+ cities=[],
+ completed=completed,
+ **phase_args,
+ )
+
+ completed = run_phase(
+ "Cities",
+ states=[],
+ districts=[],
+ cities=cities,
+ completed=completed,
+ **phase_args,
+ )
+
+ expected_total = len(states) + len(districts) + len(cities)
+ if len(completed) < expected_total:
+ missing = expected_total - len(completed)
+ raise RuntimeError(
+ f"Build incomplete: {missing} files missing from "
+ f"volume ({len(completed)}/{expected_total}). "
+ f"Volume preserved for retry."
+ )
if skip_upload:
print("\nSkipping upload (--skip-upload flag set)")
@@ -625,13 +750,235 @@ def main(
branch: str = "main",
num_workers: int = 8,
skip_upload: bool = False,
+ skip_download: bool = False,
):
"""Local entrypoint for Modal CLI."""
result = coordinate_publish.remote(
branch=branch,
num_workers=num_workers,
skip_upload=skip_upload,
+ skip_download=skip_download,
+ )
+ print(result)
+
+
+@app.function(
+    image=image,
+    secrets=[hf_secret, gcp_secret],
+    volumes={VOLUME_MOUNT: staging_volume},
+    memory=16384,
+    timeout=14400,
+)
+def coordinate_national_publish(
+    branch: str = "main",
+) -> str:
+    """Build a national US.h5 from national weights and stage it on HF.
+
+    Steps: wipe and re-download the ``national_``-prefixed calibration
+    inputs into ``national_calibration_inputs`` on the staging volume, run
+    one ``build_areas_worker`` for the single ``national``/``US`` work
+    item, verify that ``<version>/national/US.h5`` exists on the volume,
+    then upload it with ``upload_to_staging_hf``.
+
+    Args:
+        branch: Branch passed to ``setup_repo`` and to the worker.
+
+    Returns:
+        A status message naming the staged version.
+
+    Raises:
+        RuntimeError: If the download subprocess fails, the worker reports
+            failures, the H5 is missing after the build, or the staging
+            upload subprocess fails.
+    """
+    setup_gcp_credentials()
+    setup_repo(branch)
+
+    version = get_version()
+    print(
+        f"Building national H5 for version {version} " f"from branch {branch}"
+    )
+
+    import shutil
+
+    staging_dir = Path(VOLUME_MOUNT)
+    calibration_dir = staging_dir / "national_calibration_inputs"
+    # Always start from a clean directory so stale inputs are never reused.
+    if calibration_dir.exists():
+        shutil.rmtree(calibration_dir)
+    calibration_dir.mkdir(parents=True, exist_ok=True)
+
+    print("Downloading national calibration inputs from HF...")
+    result = subprocess.run(
+        [
+            "uv",
+            "run",
+            "python",
+            "-c",
+            f"""
+from policyengine_us_data.utils.huggingface import (
+    download_calibration_inputs,
+)
+download_calibration_inputs("{calibration_dir}", prefix="national_")
+print("Done")
+""",
+        ],
+        # Capture stderr only: stdout keeps streaming to the logs, and the
+        # error message below gets real text instead of None.
+        stderr=subprocess.PIPE,
+        text=True,
+        env=os.environ.copy(),
+    )
+    if result.returncode != 0:
+        raise RuntimeError(f"Download failed: {result.stderr}")
+    staging_volume.commit()
+    print("National calibration inputs downloaded")
+
+    weights_path = (
+        calibration_dir / "calibration" / "national_calibration_weights.npy"
+    )
+    db_path = calibration_dir / "calibration" / "policy_data.db"
+    dataset_path = (
+        calibration_dir
+        / "calibration"
+        / "source_imputed_stratified_extended_cps.h5"
+    )
+
+    blocks_path = (
+        calibration_dir / "calibration" / "national_stacked_blocks.npy"
     )
+    national_geo_labels_path = (
+        calibration_dir / "calibration" / "national_geo_labels.json"
+    )
+    calibration_inputs = {
+        "weights": str(weights_path),
+        "dataset": str(dataset_path),
+        "database": str(db_path),
+    }
+    # Blocks, geo labels and stacked takeup are optional: each is passed to
+    # the worker only when the file exists on the volume.
+    if blocks_path.exists():
+        calibration_inputs["blocks"] = str(blocks_path)
+        print(f"National calibration blocks found: {blocks_path}")
+    if national_geo_labels_path.exists():
+        calibration_inputs["geo_labels"] = str(national_geo_labels_path)
+        print(f"National geo labels found: " f"{national_geo_labels_path}")
+    national_takeup_path = (
+        calibration_dir / "calibration" / "national_stacked_takeup.npz"
+    )
+    if national_takeup_path.exists():
+        calibration_inputs["stacked_takeup"] = str(national_takeup_path)
+        print(f"National stacked takeup found: " f"{national_takeup_path}")
+
+    version_dir = staging_dir / version
+    version_dir.mkdir(parents=True, exist_ok=True)
+
+    work_items = [{"type": "national", "id": "US"}]
+    print("Spawning worker for national H5 build...")
+    worker_result = build_areas_worker.remote(
+        branch=branch,
+        version=version,
+        work_items=work_items,
+        calibration_inputs=calibration_inputs,
+    )
+
+    print(
+        f"Worker result: "
+        f"{len(worker_result['completed'])} completed, "
+        f"{len(worker_result['failed'])} failed"
+    )
+
+    if worker_result["failed"]:
+        raise RuntimeError(f"National build failed: {worker_result['errors']}")
+
+    # Reload so the file written by the worker is visible before checking.
+    staging_volume.reload()
+    national_h5 = version_dir / "national" / "US.h5"
+    if not national_h5.exists():
+        raise RuntimeError(f"Expected {national_h5} not found after build")
+
+    print(f"Uploading {national_h5} to HF staging...")
+    result = subprocess.run(
+        [
+            "uv",
+            "run",
+            "python",
+            "-c",
+            f"""
+from policyengine_us_data.utils.data_upload import (
+    upload_to_staging_hf,
+)
+upload_to_staging_hf(
+    [("{national_h5}", "national/US.h5")],
+    "{version}",
+)
+print("Done")
+""",
+        ],
+        # See the download call above: stderr must be captured for the
+        # failure message to be meaningful.
+        stderr=subprocess.PIPE,
+        text=True,
+        env=os.environ.copy(),
+    )
+    if result.returncode != 0:
+        raise RuntimeError(f"Staging upload failed: {result.stderr}")
+
+    print("National H5 staged. Run promote workflow to publish.")
+    return (
+        f"National US.h5 built and staged for version {version}. "
+        f"Run main_national_promote to publish."
+    )
+
+
+@app.local_entrypoint()
+def main_national(branch: str = "main"):
+    """Local entrypoint: run the national build remotely, print its status."""
+    status_message = coordinate_national_publish.remote(branch=branch)
+    print(status_message)
+
+
+@app.function(
+    image=image,
+    secrets=[hf_secret, gcp_secret],
+    volumes={VOLUME_MOUNT: staging_volume},
+    memory=4096,
+    timeout=3600,
+)
+def promote_national_publish(
+    branch: str = "main",
+) -> str:
+    """Promote national US.h5 from HF staging to production + GCS.
+
+    The version is taken from ``get_version()`` after ``setup_repo``; it is
+    not a parameter. A subprocess promotes ``national/US.h5`` on
+    HuggingFace, uploads the copy on the staging volume to GCS when it
+    exists, and then removes the file from HF staging.
+
+    Args:
+        branch: Branch passed to ``setup_repo``.
+
+    Returns:
+        A status message naming the promoted version.
+
+    Raises:
+        RuntimeError: If the promote subprocess exits with a non-zero code.
+    """
+    setup_gcp_credentials()
+    setup_repo(branch)
+
+    version = get_version()
+    rel_paths = ["national/US.h5"]
+
+    # NOTE(review): when US.h5 is missing from the volume the script below
+    # only prints a warning and still cleans up HF staging, so the GCS
+    # upload is skipped silently -- confirm this is intended.
+    result = subprocess.run(
+        [
+            "uv",
+            "run",
+            "python",
+            "-c",
+            f"""
+import json
+from pathlib import Path
+from policyengine_us_data.utils.data_upload import (
+    promote_staging_to_production_hf,
+    cleanup_staging_hf,
+    upload_local_area_file,
+)
+
+version = "{version}"
+rel_paths = {json.dumps(rel_paths)}
+version_dir = Path("{VOLUME_MOUNT}") / version
+
+print(f"Promoting national H5 from staging to production...")
+promoted = promote_staging_to_production_hf(rel_paths, version)
+print(f"Promoted {{promoted}} files to HuggingFace production")
+
+national_h5 = version_dir / "national" / "US.h5"
+if national_h5.exists():
+    print("Uploading national H5 to GCS...")
+    upload_local_area_file(
+        str(national_h5), "national", version=version, skip_hf=True
+    )
+    print("Uploaded national H5 to GCS")
+else:
+    print(f"WARNING: {{national_h5}} not on volume, skipping GCS")
+
+print("Cleaning up staging...")
+cleaned = cleanup_staging_hf(rel_paths, version)
+print(f"Cleaned up {{cleaned}} files from staging")
+print(f"Successfully promoted national H5 for version {{version}}")
+""",
+        ],
+        # Capture stderr only: stdout keeps streaming to the logs, and the
+        # error message below gets real text instead of None.
+        stderr=subprocess.PIPE,
+        text=True,
+        env=os.environ.copy(),
+    )
+    if result.returncode != 0:
+        raise RuntimeError(f"National promote failed: {result.stderr}")
+
+    return f"National US.h5 promoted for version {version}"
+
+
+@app.local_entrypoint()
+def main_national_promote(branch: str = "main"):
+    """Local entrypoint: promote the staged national US.h5, print status."""
+    status_message = promote_national_publish.remote(branch=branch)
+    print(status_message)
diff --git a/modal_app/remote_calibration_runner.py b/modal_app/remote_calibration_runner.py
index 689d245dd..87b2fb833 100644
--- a/modal_app/remote_calibration_runner.py
+++ b/modal_app/remote_calibration_runner.py
@@ -5,6 +5,9 @@
app = modal.App("policyengine-us-data-fit-weights")
hf_secret = modal.Secret.from_name("huggingface-token")
+calibration_vol = modal.Volume.from_name(
+ "calibration-data", create_if_missing=True
+)
image = (
modal.Image.debian_slim(python_version="3.11")
@@ -13,18 +16,211 @@
)
REPO_URL = "https://github.com/PolicyEngine/policyengine-us-data.git"
+VOLUME_MOUNT = "/calibration-data"
+
+
+def _run_streaming(cmd, env=None, label=""):
+    """Run a subprocess, streaming output line-by-line.
+
+    stderr is merged into stdout, and each line is echoed (flushed)
+    as soon as it is read, so progress is visible while the child is
+    still running.
+
+    Args:
+        cmd: Argument list handed to ``subprocess.Popen``.
+        env: Environment for the child; ``None`` inherits the parent's.
+        label: If non-empty, every echoed line is prefixed ``[label] ``.
+            The captured lines never carry the prefix.
+
+    Returns (returncode, captured_stdout_lines).
+    """
+    prefix = f"[{label}] " if label else ""
+    lines = []
+    # Using Popen as a context manager closes the pipe and waits for
+    # the child even when reading or printing raises part-way through.
+    with subprocess.Popen(
+        cmd,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        text=True,
+        bufsize=1,  # line-buffered in text mode
+        env=env,
+    ) as proc:
+        for raw in proc.stdout:
+            line = raw.rstrip("\n")
+            print(f"{prefix}{line}", flush=True)
+            lines.append(line)
+    return proc.returncode, lines
-def _fit_weights_impl(branch: str, epochs: int) -> dict:
- """Shared implementation for weight fitting."""
+def _clone_and_install(branch: str):
+ """Clone the repo and install dependencies."""
os.chdir("/root")
subprocess.run(["git", "clone", "-b", branch, REPO_URL], check=True)
os.chdir("policyengine-us-data")
-
subprocess.run(["uv", "sync", "--extra", "l0"], check=True)
- print("Downloading calibration inputs from HuggingFace...")
- download_result = subprocess.run(
+
+def _append_hyperparams(
+    cmd, beta, lambda_l0, lambda_l2, learning_rate, log_freq=None
+):
+    """Extend ``cmd`` in place with a flag for each hyperparameter set.
+
+    A value of ``None`` means "not specified" and adds nothing; any
+    other value is appended as ``<flag> str(value)``.
+    """
+    flag_values = (
+        ("--beta", beta),
+        ("--lambda-l0", lambda_l0),
+        ("--lambda-l2", lambda_l2),
+        ("--learning-rate", learning_rate),
+        ("--log-freq", log_freq),
+    )
+    for flag, value in flag_values:
+        if value is None:
+            continue
+        cmd.extend([flag, str(value)])
+
+
+def _collect_outputs(cal_lines):
+    """Extract weights and log bytes from calibration output lines.
+
+    Scans ``cal_lines`` for ``<MARKER>: <path>`` announcements and
+    reads each announced file into memory. If a marker appears more
+    than once, the last occurrence wins.
+
+    Args:
+        cal_lines: Captured output lines of the calibration script.
+
+    Returns:
+        Dict with keys ``weights``, ``log``, ``cal_log``, ``config``,
+        ``blocks`` and ``geo_labels``. ``weights`` is always bytes;
+        the others are bytes or ``None`` when not announced. For
+        ``blocks`` and ``geo_labels`` a path that does not exist on
+        disk also yields ``None``.
+
+    Raises:
+        RuntimeError: If no ``OUTPUT_PATH:`` line was found.
+    """
+    # Order matters: "CAL_LOG_PATH:" contains "LOG_PATH:", so the more
+    # specific marker must be tested first. First match wins per line.
+    markers = (
+        ("OUTPUT_PATH:", "weights"),
+        ("CONFIG_PATH:", "config"),
+        ("CAL_LOG_PATH:", "cal_log"),
+        ("GEO_LABELS_PATH:", "geo_labels"),
+        ("BLOCKS_PATH:", "blocks"),
+        ("LOG_PATH:", "log"),
+    )
+    paths = {}
+    for line in cal_lines:
+        for marker, key in markers:
+            if marker in line:
+                paths[key] = line.split(marker)[1].strip()
+                break
+
+    if not paths.get("weights"):
+        # Without this guard the failure surfaces as an opaque
+        # TypeError from open(None).
+        raise RuntimeError(
+            "Calibration output did not report OUTPUT_PATH; "
+            "cannot locate the fitted weights"
+        )
+
+    def _read(key, tolerate_missing=False):
+        """Return the bytes of the file announced for ``key``, or None."""
+        path = paths.get(key)
+        if not path:
+            return None
+        if tolerate_missing and not os.path.exists(path):
+            return None
+        with open(path, "rb") as f:
+            return f.read()
+
+    return {
+        "weights": _read("weights"),
+        "log": _read("log"),
+        "cal_log": _read("cal_log"),
+        "config": _read("config"),
+        "blocks": _read("blocks", tolerate_missing=True),
+        "geo_labels": _read("geo_labels", tolerate_missing=True),
+    }
+
+
+def _trigger_repository_dispatch(event_type: str = "calibration-updated"):
+    """Fire a repository_dispatch event on GitHub.
+
+    The token is taken from ``GITHUB_TOKEN`` or, when that is unset
+    or empty, from ``POLICYENGINE_US_DATA_GITHUB_TOKEN``.
+
+    Args:
+        event_type: Value sent as ``event_type`` in the request body.
+
+    Returns:
+        ``True`` once the request has been sent, ``False`` if no token
+        was available and the dispatch was skipped.
+
+    Raises:
+        urllib.error.URLError: On HTTP errors or network failures
+            (including the 30 s timeout); deliberately not swallowed
+            so a failed trigger is visible to the caller.
+    """
+    import json
+    import urllib.request
+
+    # "or" (rather than a .get() default) also falls through when
+    # GITHUB_TOKEN is present but empty.
+    token = os.environ.get("GITHUB_TOKEN") or os.environ.get(
+        "POLICYENGINE_US_DATA_GITHUB_TOKEN"
+    )
+    if not token:
+        print(
+            "WARNING: No GITHUB_TOKEN or "
+            "POLICYENGINE_US_DATA_GITHUB_TOKEN found. "
+            "Skipping repository_dispatch.",
+            flush=True,
+        )
+        return False
+
+    url = (
+        "https://api.github.com/repos/"
+        "PolicyEngine/policyengine-us-data/dispatches"
+    )
+    payload = json.dumps({"event_type": event_type}).encode()
+    req = urllib.request.Request(
+        url,
+        data=payload,
+        headers={
+            "Accept": "application/vnd.github+json",
+            "Authorization": f"Bearer {token}",
+            "Content-Type": "application/json",
+        },
+        method="POST",
+    )
+    # Bound the wait so a stalled connection cannot hang the run, and
+    # close the response deterministically.
+    with urllib.request.urlopen(req, timeout=30) as resp:
+        status = resp.status
+    print(
+        f"Triggered repository_dispatch '{event_type}' "
+        f"(HTTP {status})",
+        flush=True,
+    )
+    return True
+
+
+def _upload_source_imputed(lines):
+    """Parse SOURCE_IMPUTED_PATH from output and upload to HF.
+
+    Best-effort: returns silently when no path was announced or the
+    announced file does not exist, and only prints a warning when the
+    upload subprocess fails.
+
+    Args:
+        lines: Captured output lines of the calibration/build script.
+    """
+    source_path = None
+    for line in lines:
+        if "SOURCE_IMPUTED_PATH:" in line:
+            raw = line.split("SOURCE_IMPUTED_PATH:")[1].strip()
+            # Keep only what follows the last "]" if one is present
+            # (presumably a bracketed log prefix - TODO confirm).
+            source_path = raw.split("]")[-1].strip() if "]" in raw else raw
+    if not source_path or not os.path.exists(source_path):
+        return
+    print(f"Uploading source-imputed dataset: {source_path}", flush=True)
+    # !r emits a valid Python string literal, so a path containing
+    # quotes or backslashes cannot break (or inject into) the snippet.
+    upload_code = (
+        "from policyengine_us_data.utils.huggingface import upload; "
+        f"upload({source_path!r}, "
+        "'policyengine/policyengine-us-data', "
+        "'calibration/"
+        "source_imputed_stratified_extended_cps.h5')"
+    )
+    rc, _ = _run_streaming(
+        ["uv", "run", "python", "-c", upload_code],
+        env=os.environ.copy(),
+        label="upload-source-imputed",
+    )
+    if rc != 0:
+        print(
+            "WARNING: Failed to upload source-imputed dataset",
+            flush=True,
+        )
+    else:
+        print("Source-imputed dataset uploaded to HF", flush=True)
+
+
+def _fit_weights_impl(
+ branch: str,
+ epochs: int,
+ target_config: str = None,
+ beta: float = None,
+ lambda_l0: float = None,
+ lambda_l2: float = None,
+ learning_rate: float = None,
+ log_freq: int = None,
+ skip_county: bool = True,
+ workers: int = 1,
+) -> dict:
+ """Full pipeline: download data, build matrix, fit weights."""
+ _clone_and_install(branch)
+
+ print("Downloading calibration inputs from HuggingFace...", flush=True)
+ dl_rc, dl_lines = _run_streaming(
[
"uv",
"run",
@@ -36,66 +232,324 @@ def _fit_weights_impl(branch: str, epochs: int) -> dict:
"print(f\"DB: {paths['database']}\"); "
"print(f\"DATASET: {paths['dataset']}\")",
],
- capture_output=True,
- text=True,
env=os.environ.copy(),
+ label="download",
)
- print(download_result.stdout)
- if download_result.stderr:
- print("Download STDERR:", download_result.stderr)
- if download_result.returncode != 0:
- raise RuntimeError(f"Download failed: {download_result.returncode}")
+ if dl_rc != 0:
+ raise RuntimeError(f"Download failed with code {dl_rc}")
db_path = dataset_path = None
- for line in download_result.stdout.split("\n"):
- if line.startswith("DB:"):
+ for line in dl_lines:
+ if "DB:" in line:
db_path = line.split("DB:")[1].strip()
- elif line.startswith("DATASET:"):
+ elif "DATASET:" in line:
dataset_path = line.split("DATASET:")[1].strip()
script_path = "policyengine_us_data/calibration/unified_calibration.py"
- result = subprocess.run(
+ cmd = [
+ "uv",
+ "run",
+ "python",
+ script_path,
+ "--device",
+ "cuda",
+ "--epochs",
+ str(epochs),
+ "--db-path",
+ db_path,
+ "--dataset",
+ dataset_path,
+ ]
+ if target_config:
+ cmd.extend(["--target-config", target_config])
+ if not skip_county:
+ cmd.append("--county-level")
+ if workers > 1:
+ cmd.extend(["--workers", str(workers)])
+ _append_hyperparams(
+ cmd, beta, lambda_l0, lambda_l2, learning_rate, log_freq
+ )
+
+ cal_rc, cal_lines = _run_streaming(
+ cmd,
+ env=os.environ.copy(),
+ label="calibrate",
+ )
+ if cal_rc != 0:
+ raise RuntimeError(f"Script failed with code {cal_rc}")
+
+ _upload_source_imputed(cal_lines)
+
+ return _collect_outputs(cal_lines)
+
+
+def _fit_from_package_impl(
+    branch: str,
+    epochs: int,
+    volume_package_path: str = None,
+    target_config: str = None,
+    beta: float = None,
+    lambda_l0: float = None,
+    lambda_l2: float = None,
+    learning_rate: float = None,
+    log_freq: int = None,
+) -> dict:
+    """Run the calibration fit against an already-built package.
+
+    Clones/installs the repo at ``branch``, copies the package from
+    ``volume_package_path`` to ``/root/calibration_package.pkl``, runs
+    ``unified_calibration.py`` on ``cuda`` with ``--package-path``, and
+    returns whatever ``_collect_outputs`` gathers from its output.
+
+    Raises:
+        ValueError: If ``volume_package_path`` is empty or ``None``.
+        RuntimeError: If the calibration script exits non-zero.
+    """
+    import shutil
+
+    if not volume_package_path:
+        raise ValueError("volume_package_path is required")
+
+    _clone_and_install(branch)
+
+    local_pkg = "/root/calibration_package.pkl"
+    shutil.copy(volume_package_path, local_pkg)
+    copied_bytes = os.path.getsize(local_pkg)
+    print(
+        f"Copied package from volume ({copied_bytes:,} bytes) to {local_pkg}",
+        flush=True,
+    )
+
+    fit_cmd = [
+        "uv",
+        "run",
+        "python",
+        "policyengine_us_data/calibration/unified_calibration.py",
+        "--device",
+        "cuda",
+        "--epochs",
+        str(epochs),
+        "--package-path",
+        local_pkg,
+    ]
+    if target_config:
+        fit_cmd += ["--target-config", target_config]
+    _append_hyperparams(
+        fit_cmd, beta, lambda_l0, lambda_l2, learning_rate, log_freq
+    )
+
+    print("Running command: " + " ".join(fit_cmd), flush=True)
+
+    exit_code, output_lines = _run_streaming(
+        fit_cmd,
+        env=os.environ.copy(),
+        label="calibrate",
+    )
+    if exit_code != 0:
+        raise RuntimeError(f"Script failed with code {exit_code}")
+
+    return _collect_outputs(output_lines)
+
+
+def _print_provenance_from_meta(
+    meta: dict, current_branch: str = None
+) -> None:
+    """Print a provenance summary for a calibration package.
+
+    Missing fields are shown as ``unknown``; the commit is shortened
+    to its first 8 characters. When ``current_branch`` is given and
+    differs from a known build branch, a mismatch warning follows.
+    """
+    commit = meta.get("git_commit")
+    short_commit = commit[:8] if commit else "unknown"
+    built_on = meta.get("git_branch", "unknown")
+    dirty_tag = " (DIRTY)" if meta.get("git_dirty") else ""
+    created = meta.get("created_at", "unknown")
+    pkg_version = meta.get("package_version", "unknown")
+
+    report = (
+        "--- Package Provenance ---",
+        f" Built: {created}",
+        f" Branch: {built_on} @ {short_commit}{dirty_tag}",
+        f" Version: {pkg_version}",
+        "--------------------------",
+    )
+    for entry in report:
+        print(entry, flush=True)
+
+    if not current_branch:
+        return
+    if built_on == "unknown" or built_on == current_branch:
+        return
+    print(
+        f"WARNING: Package built on branch "
+        f"'{built_on}', but fitting with "
+        f"--branch {current_branch}",
+        flush=True,
+    )
+
+
+def _write_package_sidecar(pkg_path: str) -> None:
+    """Extract metadata from a pickle package and write a JSON sidecar.
+
+    The sidecar is written next to the package as
+    ``<name-without-extension>_meta.json`` and holds the package's
+    ``"metadata"`` entry (``{}`` if absent). Best-effort: any failure
+    is reported as a warning and never raised.
+
+    Args:
+        pkg_path: Path to the pickled calibration package.
+    """
+    import json
+    import os
+    import pickle
+
+    # Swap only the final extension. A plain str.replace(".pkl", ...)
+    # would rewrite every occurrence in the path and, for a path
+    # without ".pkl", would point the sidecar at the package itself.
+    root, _ext = os.path.splitext(pkg_path)
+    sidecar_path = f"{root}_meta.json"
+    try:
+        # NOTE: pickle.load executes arbitrary code from the file; only
+        # use this on packages produced by our own build step.
+        with open(pkg_path, "rb") as f:
+            package = pickle.load(f)
+        meta = package.get("metadata", {})
+        del package  # release the (potentially large) package early
+        # Serialise before opening the target so a non-serialisable
+        # value cannot leave a truncated sidecar behind.
+        payload = json.dumps(meta, indent=2)
+        with open(sidecar_path, "w") as f:
+            f.write(payload)
+        print(
+            f"Sidecar metadata written to {sidecar_path}",
+            flush=True,
+        )
+    except Exception as e:
+        print(
+            f"WARNING: Failed to write sidecar: {e}",
+            flush=True,
+        )
+
+
+def _build_package_impl(
+ branch: str,
+ target_config: str = None,
+ skip_county: bool = True,
+ workers: int = 1,
+) -> str:
+ """Download data, build X matrix, save package to volume."""
+ _clone_and_install(branch)
+
+ print(
+ "Downloading calibration inputs from HuggingFace...",
+ flush=True,
+ )
+ dl_rc, dl_lines = _run_streaming(
[
"uv",
"run",
"python",
- script_path,
- "--device",
- "cuda",
- "--epochs",
- str(epochs),
- "--db-path",
- db_path,
- "--dataset",
- dataset_path,
+ "-c",
+ "from policyengine_us_data.utils.huggingface import "
+ "download_calibration_inputs; "
+ "paths = download_calibration_inputs("
+ "'/root/calibration_data'); "
+ "print(f\"DB: {paths['database']}\"); "
+ "print(f\"DATASET: {paths['dataset']}\")",
],
- capture_output=True,
- text=True,
env=os.environ.copy(),
+ label="download",
)
- print(result.stdout)
- if result.stderr:
- print("STDERR:", result.stderr)
- if result.returncode != 0:
- raise RuntimeError(f"Script failed with code {result.returncode}")
+ if dl_rc != 0:
+ raise RuntimeError(f"Download failed with code {dl_rc}")
- output_path = None
- log_path = None
- for line in result.stdout.split("\n"):
- if "OUTPUT_PATH:" in line:
- output_path = line.split("OUTPUT_PATH:")[1].strip()
- elif "LOG_PATH:" in line:
- log_path = line.split("LOG_PATH:")[1].strip()
+ db_path = dataset_path = None
+ for line in dl_lines:
+ if "DB:" in line:
+ db_path = line.split("DB:")[1].strip()
+ elif "DATASET:" in line:
+ dataset_path = line.split("DATASET:")[1].strip()
- with open(output_path, "rb") as f:
- weights_bytes = f.read()
+ pkg_path = f"{VOLUME_MOUNT}/calibration_package.pkl"
+ script_path = "policyengine_us_data/calibration/unified_calibration.py"
+ cmd = [
+ "uv",
+ "run",
+ "python",
+ script_path,
+ "--device",
+ "cpu",
+ "--epochs",
+ "0",
+ "--db-path",
+ db_path,
+ "--dataset",
+ dataset_path,
+ "--build-only",
+ "--package-output",
+ pkg_path,
+ ]
+ if target_config:
+ cmd.extend(["--target-config", target_config])
+ if not skip_county:
+ cmd.append("--county-level")
+ if workers > 1:
+ cmd.extend(["--workers", str(workers)])
- log_bytes = None
- if log_path:
- with open(log_path, "rb") as f:
- log_bytes = f.read()
+ build_rc, build_lines = _run_streaming(
+ cmd,
+ env=os.environ.copy(),
+ label="build",
+ )
+ if build_rc != 0:
+ raise RuntimeError(f"Package build failed with code {build_rc}")
- return {"weights": weights_bytes, "log": log_bytes}
+ _upload_source_imputed(build_lines)
+
+ _write_package_sidecar(pkg_path)
+
+ size = os.path.getsize(pkg_path)
+ print(
+ f"Package saved to volume at {pkg_path} " f"({size:,} bytes)",
+ flush=True,
+ )
+ calibration_vol.commit()
+ return pkg_path
+
+
+# Remote build: mounts the calibration volume at VOLUME_MOUNT and
+# allows up to 36000 s (10 h); no GPU is requested.
+@app.function(
+ image=image,
+ secrets=[hf_secret],
+ memory=65536,
+ cpu=4.0,
+ timeout=36000,
+ volumes={VOLUME_MOUNT: calibration_vol},
+)
+def build_package_remote(
+ branch: str = "main",
+ target_config: str = None,
+ skip_county: bool = True,
+ workers: int = 1,
+) -> str:
+ """Forward all arguments to ``_build_package_impl``.
+
+ Returns whatever ``_build_package_impl`` returns.
+ """
+ return _build_package_impl(
+ branch,
+ target_config=target_config,
+ skip_county=skip_county,
+ workers=workers,
+ )
+
+
+@app.function(
+    image=image,
+    timeout=30,
+    volumes={VOLUME_MOUNT: calibration_vol},
+)
+def check_volume_package() -> dict:
+    """Check if a calibration package exists on the volume.
+
+    Reads the lightweight JSON sidecar for provenance fields.
+    Falls back to size/mtime if the sidecar is missing or unreadable.
+
+    Returns:
+        ``{"exists": False}`` when the package file is absent.
+        Otherwise a dict with ``exists`` (True), ``size`` (bytes),
+        ``modified`` (UTC, ``YYYY-MM-DD HH:MM UTC``) and any of the
+        provenance keys present in the sidecar.
+    """
+    import datetime
+    import json
+
+    pkg_path = f"{VOLUME_MOUNT}/calibration_package.pkl"
+    sidecar_path = f"{VOLUME_MOUNT}/calibration_package_meta.json"
+    if not os.path.exists(pkg_path):
+        return {"exists": False}
+
+    stat = os.stat(pkg_path)
+    mtime = datetime.datetime.fromtimestamp(
+        stat.st_mtime, tz=datetime.timezone.utc
+    )
+    info = {
+        "exists": True,
+        "size": stat.st_size,
+        "modified": mtime.strftime("%Y-%m-%d %H:%M UTC"),
+    }
+    provenance_keys = (
+        "git_branch",
+        "git_commit",
+        "git_dirty",
+        "package_version",
+        "created_at",
+        "dataset_sha256",
+        "db_sha256",
+    )
+    if os.path.exists(sidecar_path):
+        try:
+            with open(sidecar_path) as f:
+                meta = json.load(f)
+            info.update({k: meta[k] for k in provenance_keys if k in meta})
+        except Exception as e:
+            # Still best-effort (size/mtime are returned regardless),
+            # but say why provenance is missing instead of hiding it.
+            print(
+                f"WARNING: Could not read sidecar {sidecar_path}: {e}",
+                flush=True,
+            )
+    return info
+
+
+# --- Full pipeline GPU functions ---
@app.function(
@@ -106,8 +560,30 @@ def _fit_weights_impl(branch: str, epochs: int) -> dict:
gpu="T4",
timeout=14400,
)
-def fit_weights_t4(branch: str = "main", epochs: int = 200) -> dict:
- return _fit_weights_impl(branch, epochs)
+def fit_weights_t4(
+ branch: str = "main",
+ epochs: int = 200,
+ target_config: str = None,
+ beta: float = None,
+ lambda_l0: float = None,
+ lambda_l2: float = None,
+ learning_rate: float = None,
+ log_freq: int = None,
+ skip_county: bool = True,
+ workers: int = 1,
+) -> dict:
+ return _fit_weights_impl(
+ branch,
+ epochs,
+ target_config,
+ beta,
+ lambda_l0,
+ lambda_l2,
+ learning_rate,
+ log_freq,
+ skip_county=skip_county,
+ workers=workers,
+ )
@app.function(
@@ -118,8 +594,30 @@ def fit_weights_t4(branch: str = "main", epochs: int = 200) -> dict:
gpu="A10",
timeout=14400,
)
-def fit_weights_a10(branch: str = "main", epochs: int = 200) -> dict:
- return _fit_weights_impl(branch, epochs)
+def fit_weights_a10(
+ branch: str = "main",
+ epochs: int = 200,
+ target_config: str = None,
+ beta: float = None,
+ lambda_l0: float = None,
+ lambda_l2: float = None,
+ learning_rate: float = None,
+ log_freq: int = None,
+ skip_county: bool = True,
+ workers: int = 1,
+) -> dict:
+ return _fit_weights_impl(
+ branch,
+ epochs,
+ target_config,
+ beta,
+ lambda_l0,
+ lambda_l2,
+ learning_rate,
+ log_freq,
+ skip_county=skip_county,
+ workers=workers,
+ )
@app.function(
@@ -130,8 +628,30 @@ def fit_weights_a10(branch: str = "main", epochs: int = 200) -> dict:
gpu="A100-40GB",
timeout=14400,
)
-def fit_weights_a100_40(branch: str = "main", epochs: int = 200) -> dict:
- return _fit_weights_impl(branch, epochs)
+def fit_weights_a100_40(
+ branch: str = "main",
+ epochs: int = 200,
+ target_config: str = None,
+ beta: float = None,
+ lambda_l0: float = None,
+ lambda_l2: float = None,
+ learning_rate: float = None,
+ log_freq: int = None,
+ skip_county: bool = True,
+ workers: int = 1,
+) -> dict:
+ return _fit_weights_impl(
+ branch,
+ epochs,
+ target_config,
+ beta,
+ lambda_l0,
+ lambda_l2,
+ learning_rate,
+ log_freq,
+ skip_county=skip_county,
+ workers=workers,
+ )
@app.function(
@@ -142,8 +662,30 @@ def fit_weights_a100_40(branch: str = "main", epochs: int = 200) -> dict:
gpu="A100-80GB",
timeout=14400,
)
-def fit_weights_a100_80(branch: str = "main", epochs: int = 200) -> dict:
- return _fit_weights_impl(branch, epochs)
+def fit_weights_a100_80(
+ branch: str = "main",
+ epochs: int = 200,
+ target_config: str = None,
+ beta: float = None,
+ lambda_l0: float = None,
+ lambda_l2: float = None,
+ learning_rate: float = None,
+ log_freq: int = None,
+ skip_county: bool = True,
+ workers: int = 1,
+) -> dict:
+ return _fit_weights_impl(
+ branch,
+ epochs,
+ target_config,
+ beta,
+ lambda_l0,
+ lambda_l2,
+ learning_rate,
+ log_freq,
+ skip_county=skip_county,
+ workers=workers,
+ )
@app.function(
@@ -154,8 +696,30 @@ def fit_weights_a100_80(branch: str = "main", epochs: int = 200) -> dict:
gpu="H100",
timeout=14400,
)
-def fit_weights_h100(branch: str = "main", epochs: int = 200) -> dict:
- return _fit_weights_impl(branch, epochs)
+def fit_weights_h100(
+ branch: str = "main",
+ epochs: int = 200,
+ target_config: str = None,
+ beta: float = None,
+ lambda_l0: float = None,
+ lambda_l2: float = None,
+ learning_rate: float = None,
+ log_freq: int = None,
+ skip_county: bool = True,
+ workers: int = 1,
+) -> dict:
+ return _fit_weights_impl(
+ branch,
+ epochs,
+ target_config,
+ beta,
+ lambda_l0,
+ lambda_l2,
+ learning_rate,
+ log_freq,
+ skip_county=skip_county,
+ workers=workers,
+ )
GPU_FUNCTIONS = {
@@ -167,22 +731,344 @@ def fit_weights_h100(branch: str = "main", epochs: int = 200) -> dict:
}
+# --- Package-path GPU functions ---
+
+
+@app.function(
+    image=image,
+    memory=32768,
+    cpu=4.0,
+    gpu="T4",
+    timeout=14400,
+    # Mount at the shared constant: callers build the package path as
+    # f"{VOLUME_MOUNT}/calibration_package.pkl", so the two must agree.
+    volumes={VOLUME_MOUNT: calibration_vol},
+)
+def fit_from_package_t4(
+    branch: str = "main",
+    epochs: int = 200,
+    target_config: str = None,
+    beta: float = None,
+    lambda_l0: float = None,
+    lambda_l2: float = None,
+    learning_rate: float = None,
+    log_freq: int = None,
+    volume_package_path: str = None,
+) -> dict:
+    """Fit weights from a pre-built package on a T4 GPU.
+
+    Thin wrapper: forwards every argument to
+    ``_fit_from_package_impl`` and returns its result.
+    """
+    return _fit_from_package_impl(
+        branch,
+        epochs,
+        volume_package_path=volume_package_path,
+        target_config=target_config,
+        beta=beta,
+        lambda_l0=lambda_l0,
+        lambda_l2=lambda_l2,
+        learning_rate=learning_rate,
+        log_freq=log_freq,
+    )
+
+
+@app.function(
+    image=image,
+    memory=32768,
+    cpu=4.0,
+    gpu="A10",
+    timeout=14400,
+    # Mount at the shared constant: callers build the package path as
+    # f"{VOLUME_MOUNT}/calibration_package.pkl", so the two must agree.
+    volumes={VOLUME_MOUNT: calibration_vol},
+)
+def fit_from_package_a10(
+    branch: str = "main",
+    epochs: int = 200,
+    target_config: str = None,
+    beta: float = None,
+    lambda_l0: float = None,
+    lambda_l2: float = None,
+    learning_rate: float = None,
+    log_freq: int = None,
+    volume_package_path: str = None,
+) -> dict:
+    """Fit weights from a pre-built package on an A10 GPU.
+
+    Thin wrapper: forwards every argument to
+    ``_fit_from_package_impl`` and returns its result.
+    """
+    return _fit_from_package_impl(
+        branch,
+        epochs,
+        volume_package_path=volume_package_path,
+        target_config=target_config,
+        beta=beta,
+        lambda_l0=lambda_l0,
+        lambda_l2=lambda_l2,
+        learning_rate=learning_rate,
+        log_freq=log_freq,
+    )
+
+
+@app.function(
+    image=image,
+    memory=32768,
+    cpu=4.0,
+    gpu="A100-40GB",
+    timeout=14400,
+    # Mount at the shared constant: callers build the package path as
+    # f"{VOLUME_MOUNT}/calibration_package.pkl", so the two must agree.
+    volumes={VOLUME_MOUNT: calibration_vol},
+)
+def fit_from_package_a100_40(
+    branch: str = "main",
+    epochs: int = 200,
+    target_config: str = None,
+    beta: float = None,
+    lambda_l0: float = None,
+    lambda_l2: float = None,
+    learning_rate: float = None,
+    log_freq: int = None,
+    volume_package_path: str = None,
+) -> dict:
+    """Fit weights from a pre-built package on an A100-40GB GPU.
+
+    Thin wrapper: forwards every argument to
+    ``_fit_from_package_impl`` and returns its result.
+    """
+    return _fit_from_package_impl(
+        branch,
+        epochs,
+        volume_package_path=volume_package_path,
+        target_config=target_config,
+        beta=beta,
+        lambda_l0=lambda_l0,
+        lambda_l2=lambda_l2,
+        learning_rate=learning_rate,
+        log_freq=log_freq,
+    )
+
+
+@app.function(
+    image=image,
+    memory=32768,
+    cpu=4.0,
+    gpu="A100-80GB",
+    timeout=14400,
+    # Mount at the shared constant: callers build the package path as
+    # f"{VOLUME_MOUNT}/calibration_package.pkl", so the two must agree.
+    volumes={VOLUME_MOUNT: calibration_vol},
+)
+def fit_from_package_a100_80(
+    branch: str = "main",
+    epochs: int = 200,
+    target_config: str = None,
+    beta: float = None,
+    lambda_l0: float = None,
+    lambda_l2: float = None,
+    learning_rate: float = None,
+    log_freq: int = None,
+    volume_package_path: str = None,
+) -> dict:
+    """Fit weights from a pre-built package on an A100-80GB GPU.
+
+    Thin wrapper: forwards every argument to
+    ``_fit_from_package_impl`` and returns its result.
+    """
+    return _fit_from_package_impl(
+        branch,
+        epochs,
+        volume_package_path=volume_package_path,
+        target_config=target_config,
+        beta=beta,
+        lambda_l0=lambda_l0,
+        lambda_l2=lambda_l2,
+        learning_rate=learning_rate,
+        log_freq=log_freq,
+    )
+
+
+@app.function(
+    image=image,
+    memory=32768,
+    cpu=4.0,
+    gpu="H100",
+    timeout=14400,
+    # Mount at the shared constant: callers build the package path as
+    # f"{VOLUME_MOUNT}/calibration_package.pkl", so the two must agree.
+    volumes={VOLUME_MOUNT: calibration_vol},
+)
+def fit_from_package_h100(
+    branch: str = "main",
+    epochs: int = 200,
+    target_config: str = None,
+    beta: float = None,
+    lambda_l0: float = None,
+    lambda_l2: float = None,
+    learning_rate: float = None,
+    log_freq: int = None,
+    volume_package_path: str = None,
+) -> dict:
+    """Fit weights from a pre-built package on an H100 GPU.
+
+    Thin wrapper: forwards every argument to
+    ``_fit_from_package_impl`` and returns its result.
+    """
+    return _fit_from_package_impl(
+        branch,
+        epochs,
+        volume_package_path=volume_package_path,
+        target_config=target_config,
+        beta=beta,
+        lambda_l0=lambda_l0,
+        lambda_l2=lambda_l2,
+        learning_rate=learning_rate,
+        log_freq=log_freq,
+    )
+
+
+# Lookup from GPU name to the package-based fit function for it;
+# the local entrypoint indexes this with its ``gpu`` argument.
+PACKAGE_GPU_FUNCTIONS = {
+ "T4": fit_from_package_t4,
+ "A10": fit_from_package_a10,
+ "A100-40GB": fit_from_package_a100_40,
+ "A100-80GB": fit_from_package_a100_80,
+ "H100": fit_from_package_h100,
+}
+
+
@app.local_entrypoint()
def main(
branch: str = "main",
epochs: int = 200,
gpu: str = "T4",
output: str = "calibration_weights.npy",
- log_output: str = "calibration_log.csv",
+ log_output: str = "unified_diagnostics.csv",
+ target_config: str = None,
+ beta: float = None,
+ lambda_l0: float = None,
+ lambda_l2: float = None,
+ learning_rate: float = None,
+ log_freq: int = None,
+ package_path: str = None,
+ full_pipeline: bool = False,
+ county_level: bool = False,
+ workers: int = 1,
+ push_results: bool = False,
+ trigger_publish: bool = False,
+ national: bool = False,
):
+ prefix = "national_" if national else ""
+ if national:
+ if lambda_l0 is None:
+ lambda_l0 = 1e-4
+ output = f"{prefix}{output}"
+ log_output = f"{prefix}{log_output}"
+
if gpu not in GPU_FUNCTIONS:
raise ValueError(
- f"Unknown GPU: {gpu}. Choose from: {list(GPU_FUNCTIONS.keys())}"
+ f"Unknown GPU: {gpu}. "
+ f"Choose from: {list(GPU_FUNCTIONS.keys())}"
)
- print(f"Running with GPU: {gpu}, epochs: {epochs}, branch: {branch}")
- func = GPU_FUNCTIONS[gpu]
- result = func.remote(branch=branch, epochs=epochs)
+ if package_path:
+ vol_path = f"{VOLUME_MOUNT}/calibration_package.pkl"
+ print(f"Reading package from {package_path}...", flush=True)
+ import json as _json
+ import pickle as _pkl
+
+ with open(package_path, "rb") as f:
+ package_bytes = f.read()
+ size = len(package_bytes)
+ # Extract metadata for sidecar
+ pkg_meta = _pkl.loads(package_bytes).get("metadata", {})
+ sidecar_bytes = _json.dumps(pkg_meta, indent=2).encode()
+ print(
+ f"Uploading package ({size:,} bytes) to Modal volume...",
+ flush=True,
+ )
+ with calibration_vol.batch_upload(force=True) as batch:
+ from io import BytesIO
+
+ batch.put(
+ BytesIO(package_bytes),
+ "calibration_package.pkl",
+ )
+ batch.put(
+ BytesIO(sidecar_bytes),
+ "calibration_package_meta.json",
+ )
+ calibration_vol.commit()
+ del package_bytes
+ print("Upload complete.", flush=True)
+ _print_provenance_from_meta(pkg_meta, branch)
+ func = PACKAGE_GPU_FUNCTIONS[gpu]
+ result = func.remote(
+ branch=branch,
+ epochs=epochs,
+ target_config=target_config,
+ beta=beta,
+ lambda_l0=lambda_l0,
+ lambda_l2=lambda_l2,
+ learning_rate=learning_rate,
+ log_freq=log_freq,
+ volume_package_path=vol_path,
+ )
+ elif full_pipeline:
+ print(
+ "========================================",
+ flush=True,
+ )
+ print(
+ "Mode: full pipeline (download, build matrix, fit)",
+ flush=True,
+ )
+ print(
+ f"GPU: {gpu} | Epochs: {epochs} | " f"Branch: {branch}",
+ flush=True,
+ )
+ print(
+ "========================================",
+ flush=True,
+ )
+ func = GPU_FUNCTIONS[gpu]
+ result = func.remote(
+ branch=branch,
+ epochs=epochs,
+ target_config=target_config,
+ beta=beta,
+ lambda_l0=lambda_l0,
+ lambda_l2=lambda_l2,
+ learning_rate=learning_rate,
+ log_freq=log_freq,
+ skip_county=not county_level,
+ workers=workers,
+ )
+ else:
+ vol_path = f"{VOLUME_MOUNT}/calibration_package.pkl"
+ vol_info = check_volume_package.remote()
+ if not vol_info["exists"]:
+ raise SystemExit(
+ "\nNo calibration package found on Modal volume.\n"
+ "Run 'make build-matrices' first, or use "
+ "--full-pipeline to build from scratch.\n"
+ )
+ if vol_info.get("created_at") or vol_info.get("git_branch"):
+ _print_provenance_from_meta(vol_info, branch)
+ mode_label = (
+ "national calibration"
+ if national
+ else "fitting from pre-built package"
+ )
+ print(
+ "========================================",
+ flush=True,
+ )
+ print(f"Mode: {mode_label}", flush=True)
+ print(
+ f"GPU: {gpu} | Epochs: {epochs} | " f"Branch: {branch}",
+ flush=True,
+ )
+ if push_results:
+ print(
+ "After fitting, will upload to HuggingFace:",
+ flush=True,
+ )
+ print(
+ f" - calibration/{prefix}calibration_weights.npy",
+ flush=True,
+ )
+ print(
+ f" - calibration/{prefix}stacked_blocks.npy",
+ flush=True,
+ )
+ print(
+ f" - calibration/logs/{prefix}* (diagnostics, "
+ "config, calibration log)",
+ flush=True,
+ )
+ print(
+ "========================================",
+ flush=True,
+ )
+ func = PACKAGE_GPU_FUNCTIONS[gpu]
+ result = func.remote(
+ branch=branch,
+ epochs=epochs,
+ target_config=target_config,
+ beta=beta,
+ lambda_l0=lambda_l0,
+ lambda_l2=lambda_l2,
+ learning_rate=learning_rate,
+ log_freq=log_freq,
+ volume_package_path=vol_path,
+ )
with open(output, "wb") as f:
f.write(result["weights"])
@@ -191,4 +1077,96 @@ def main(
if result["log"]:
with open(log_output, "wb") as f:
f.write(result["log"])
- print(f"Calibration log saved to: {log_output}")
+ print(f"Diagnostics log saved to: {log_output}")
+
+ cal_log_output = f"{prefix}calibration_log.csv"
+ if result.get("cal_log"):
+ with open(cal_log_output, "wb") as f:
+ f.write(result["cal_log"])
+ print(f"Calibration log saved to: {cal_log_output}")
+
+ config_output = f"{prefix}unified_run_config.json"
+ if result.get("config"):
+ with open(config_output, "wb") as f:
+ f.write(result["config"])
+ print(f"Run config saved to: {config_output}")
+
+ blocks_output = f"{prefix}stacked_blocks.npy"
+ if result.get("blocks"):
+ with open(blocks_output, "wb") as f:
+ f.write(result["blocks"])
+ print(f"Stacked blocks saved to: {blocks_output}")
+
+ geo_labels_output = f"{prefix}geo_labels.json"
+ if result.get("geo_labels"):
+ with open(geo_labels_output, "wb") as f:
+ f.write(result["geo_labels"])
+ print(f"Geo labels saved to: {geo_labels_output}")
+
+ if push_results:
+ from policyengine_us_data.utils.huggingface import (
+ upload_calibration_artifacts,
+ )
+
+ upload_calibration_artifacts(
+ weights_path=output,
+ blocks_path=(blocks_output if result.get("blocks") else None),
+ geo_labels_path=(
+ geo_labels_output if result.get("geo_labels") else None
+ ),
+ log_dir=".",
+ prefix=prefix,
+ )
+
+ if trigger_publish:
+ _trigger_repository_dispatch()
+
+
+@app.local_entrypoint()
+def build_package(
+    branch: str = "main",
+    target_config: str = None,
+    county_level: bool = False,
+    workers: int = 1,
+):
+    """Build the calibration package (X matrix) on CPU and save
+    to Modal volume. Then run main() to fit.
+
+    Args:
+        branch: Git branch forwarded to the remote build.
+        target_config: Optional target config forwarded to the build.
+        county_level: Forwarded inverted, as ``skip_county``.
+        workers: Worker count forwarded to the remote build.
+    """
+    banner = "========================================"
+    print(banner, flush=True)
+    print(
+        "Mode: building calibration package (CPU only)",
+        flush=True,
+    )
+    print(f"Branch: {branch}", flush=True)
+    print(
+        "This builds the X matrix and saves it to a Modal volume.",
+        flush=True,
+    )
+    print(
+        "No GPU is used. Timeout: 10 hours.",
+        flush=True,
+    )
+    print(banner, flush=True)
+    vol_path = build_package_remote.remote(
+        branch=branch,
+        target_config=target_config,
+        skip_county=not county_level,
+        workers=workers,
+    )
+    print(
+        f"Package built and saved to Modal volume at {vol_path}",
+        flush=True,
+    )
+    # --gpu and --epochs both take a value; show main()'s defaults so
+    # the suggested command can be pasted and run as-is.
+    print(
+        "\nTo fit weights, run:\n"
+        " modal run modal_app/remote_calibration_runner.py"
+        "::main \\\n"
+        f" --branch {branch} --gpu T4 "
+        "--epochs 200 --push-results",
+        flush=True,
+    )
diff --git a/modal_app/worker_script.py b/modal_app/worker_script.py
index b197260e8..e19f0c82b 100644
--- a/modal_app/worker_script.py
+++ b/modal_app/worker_script.py
@@ -20,6 +20,24 @@ def main():
parser.add_argument("--dataset-path", required=True)
parser.add_argument("--db-path", required=True)
parser.add_argument("--output-dir", required=True)
+ parser.add_argument(
+ "--calibration-blocks",
+ type=str,
+ default=None,
+ help="Path to stacked_blocks.npy from calibration",
+ )
+ parser.add_argument(
+ "--geo-labels",
+ type=str,
+ default=None,
+ help="Path to geo_labels.json (overrides DB lookup)",
+ )
+ parser.add_argument(
+ "--stacked-takeup",
+ type=str,
+ default=None,
+ help="Path to stacked_takeup.npz from calibration",
+ )
args = parser.parse_args()
work_items = json.loads(args.work_items)
@@ -28,18 +46,43 @@ def main():
db_path = Path(args.db_path)
output_dir = Path(args.output_dir)
- from policyengine_us_data.datasets.cps.local_area_calibration.publish_local_area import (
- build_state_h5,
- build_district_h5,
- build_city_h5,
+ calibration_blocks = None
+ if args.calibration_blocks:
+ calibration_blocks = np.load(args.calibration_blocks)
+
+ stacked_takeup = None
+ if args.stacked_takeup:
+ stacked_takeup = dict(np.load(args.stacked_takeup))
+
+ from policyengine_us_data.utils.takeup import (
+ TAKEUP_AFFECTED_TARGETS,
+ )
+
+ takeup_filter = [
+ info["takeup_var"] for info in TAKEUP_AFFECTED_TARGETS.values()
+ ]
+
+ original_stdout = sys.stdout
+ sys.stdout = sys.stderr
+
+ from policyengine_us_data.calibration.publish_local_area import (
+ build_h5,
+ NYC_COUNTIES,
+ NYC_CDS,
+ AT_LARGE_DISTRICTS,
)
- from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import (
+ from policyengine_us_data.calibration.calibration_utils import (
get_all_cds_from_database,
+ load_geo_labels,
STATE_CODES,
)
- db_uri = f"sqlite:///{db_path}"
- cds_to_calibrate = get_all_cds_from_database(db_uri)
+ if args.geo_labels and Path(args.geo_labels).exists():
+ geo_labels = load_geo_labels(args.geo_labels)
+ else:
+ db_uri = f"sqlite:///{db_path}"
+ geo_labels = get_all_cds_from_database(db_uri)
+ cds_to_calibrate = geo_labels
weights = np.load(weights_path)
results = {
@@ -54,44 +97,124 @@ def main():
try:
if item_type == "state":
- path = build_state_h5(
- state_code=item_id,
+ state_fips = None
+ for fips, code in STATE_CODES.items():
+ if code == item_id:
+ state_fips = fips
+ break
+ if state_fips is None:
+ raise ValueError(f"Unknown state code: {item_id}")
+ cd_subset = [
+ cd
+ for cd in cds_to_calibrate
+ if int(cd) // 100 == state_fips
+ ]
+ if not cd_subset:
+ print(
+ f"No CDs for {item_id}, skipping",
+ file=sys.stderr,
+ )
+ continue
+ states_dir = output_dir / "states"
+ states_dir.mkdir(parents=True, exist_ok=True)
+ path = build_h5(
weights=weights,
- cds_to_calibrate=cds_to_calibrate,
+ blocks=calibration_blocks,
dataset_path=dataset_path,
- output_dir=output_dir,
+ output_path=states_dir / f"{item_id}.h5",
+ cds_to_calibrate=cds_to_calibrate,
+ cd_subset=cd_subset,
+ takeup_filter=takeup_filter,
+ stacked_takeup=stacked_takeup,
)
+
elif item_type == "district":
state_code, dist_num = item_id.split("-")
- geoid = None
+ state_fips = None
for fips, code in STATE_CODES.items():
if code == state_code:
- geoid = f"{fips}{int(dist_num):02d}"
+ state_fips = fips
break
- if geoid is None:
+ if state_fips is None:
raise ValueError(f"Unknown state in district: {item_id}")
- path = build_district_h5(
- cd_geoid=geoid,
+ candidate = f"{state_fips}{int(dist_num):02d}"
+ if candidate in geo_labels:
+ geoid = candidate
+ else:
+ state_cds = [
+ cd for cd in geo_labels if int(cd) // 100 == state_fips
+ ]
+ if len(state_cds) == 1:
+ geoid = state_cds[0]
+ else:
+ raise ValueError(
+ f"CD {candidate} not found and "
+ f"state {state_code} has "
+ f"{len(state_cds)} CDs"
+ )
+
+ cd_int = int(geoid)
+ district_num = cd_int % 100
+ if district_num in AT_LARGE_DISTRICTS:
+ district_num = 1
+ friendly_name = f"{state_code}-{district_num:02d}"
+
+ districts_dir = output_dir / "districts"
+ districts_dir.mkdir(parents=True, exist_ok=True)
+ path = build_h5(
weights=weights,
- cds_to_calibrate=cds_to_calibrate,
+ blocks=calibration_blocks,
dataset_path=dataset_path,
- output_dir=output_dir,
+ output_path=districts_dir / f"{friendly_name}.h5",
+ cds_to_calibrate=cds_to_calibrate,
+ cd_subset=[geoid],
+ takeup_filter=takeup_filter,
+ stacked_takeup=stacked_takeup,
)
+
elif item_type == "city":
- path = build_city_h5(
- city_name=item_id,
+ cd_subset = [cd for cd in cds_to_calibrate if cd in NYC_CDS]
+ if not cd_subset:
+ print(
+ "No NYC CDs found, skipping",
+ file=sys.stderr,
+ )
+ continue
+ cities_dir = output_dir / "cities"
+ cities_dir.mkdir(parents=True, exist_ok=True)
+ path = build_h5(
weights=weights,
+ blocks=calibration_blocks,
+ dataset_path=dataset_path,
+ output_path=cities_dir / "NYC.h5",
cds_to_calibrate=cds_to_calibrate,
+ cd_subset=cd_subset,
+ county_filter=NYC_COUNTIES,
+ takeup_filter=takeup_filter,
+ stacked_takeup=stacked_takeup,
+ )
+
+ elif item_type == "national":
+ national_dir = output_dir / "national"
+ national_dir.mkdir(parents=True, exist_ok=True)
+ path = build_h5(
+ weights=weights,
+ blocks=calibration_blocks,
dataset_path=dataset_path,
- output_dir=output_dir,
+ output_path=national_dir / "US.h5",
+ cds_to_calibrate=cds_to_calibrate,
+ stacked_takeup=stacked_takeup,
)
else:
raise ValueError(f"Unknown item type: {item_type}")
if path:
results["completed"].append(f"{item_type}:{item_id}")
- print(f"Completed {item_type}:{item_id}", file=sys.stderr)
+ print(
+ f"Completed {item_type}:{item_id}",
+ file=sys.stderr,
+ )
except Exception as e:
results["failed"].append(f"{item_type}:{item_id}")
@@ -102,8 +225,12 @@ def main():
"traceback": traceback.format_exc(),
}
)
- print(f"FAILED {item_type}:{item_id}: {e}", file=sys.stderr)
+ print(
+ f"FAILED {item_type}:{item_id}: {e}",
+ file=sys.stderr,
+ )
+ sys.stdout = original_stdout
print(json.dumps(results))
diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/block_assignment.py b/policyengine_us_data/calibration/block_assignment.py
similarity index 86%
rename from policyengine_us_data/datasets/cps/local_area_calibration/block_assignment.py
rename to policyengine_us_data/calibration/block_assignment.py
index 73b435f69..ddeafa378 100644
--- a/policyengine_us_data/datasets/cps/local_area_calibration/block_assignment.py
+++ b/policyengine_us_data/calibration/block_assignment.py
@@ -100,22 +100,33 @@ def _build_county_fips_to_enum() -> Dict[str, str]:
return fips_to_enum
-def get_county_enum_index_from_block(block_geoid: str) -> int:
- """
- Get County enum index from block GEOID.
+def get_county_enum_index_from_fips(county_fips: str) -> int:
+ """Get County enum index from 5-digit county FIPS.
Args:
- block_geoid: 15-digit census block GEOID
+ county_fips: 5-digit county FIPS code (e.g. "37183")
Returns:
Integer index into County enum, or UNKNOWN index if not found
"""
- county_fips = get_county_fips_from_block(block_geoid)
fips_to_enum = _build_county_fips_to_enum()
enum_name = fips_to_enum.get(county_fips, "UNKNOWN")
return County._member_names_.index(enum_name)
+def get_county_enum_index_from_block(block_geoid: str) -> int:
+    """Map a census block GEOID to its County enum index.
+
+    Args:
+        block_geoid: 15-digit census block GEOID
+
+    Returns:
+        Integer index into County enum, or UNKNOWN index if not found
+    """
+    return get_county_enum_index_from_fips(
+        get_county_fips_from_block(block_geoid)
+    )
+
+
# === CBSA Lookup ===
@@ -338,7 +349,7 @@ def _generate_fallback_blocks(cd_geoid: str, n_households: int) -> np.ndarray:
Array of 15-character block GEOID strings
"""
# Import here to avoid circular dependency
- from policyengine_us_data.datasets.cps.local_area_calibration.county_assignment import (
+ from policyengine_us_data.calibration.county_assignment import (
assign_counties_for_cd,
)
@@ -508,6 +519,82 @@ def assign_geography_for_cd(
}
+def derive_geography_from_blocks(
+    block_geoids: np.ndarray,
+) -> Dict[str, np.ndarray]:
+    """Derive all geography from pre-assigned block GEOIDs.
+
+    Given an array of block GEOIDs (already assigned by
+    calibration), derives county, tract, state, CBSA, SLDU,
+    SLDL, place, VTD, PUMA, ZCTA, and county enum index.
+
+    County, tract and state come from the GEOID helper functions.
+    SLDU, SLDL, place, VTD, PUMA and ZCTA are read from the block
+    crosswalk and fall back to "" when the block is not in the
+    crosswalk, the value is missing, or (for ZCTA only) the
+    crosswalk has no "zcta" column.
+
+    Args:
+        block_geoids: Array of 15-char block GEOID strings.
+
+    Returns:
+        Dict with same keys as assign_geography_for_cd.
+    """
+    county_list = [get_county_fips_from_block(b) for b in block_geoids]
+    county_fips = np.array(county_list)
+    tract_geoids = np.array(
+        [get_tract_geoid_from_block(b) for b in block_geoids]
+    )
+    state_fips = np.array([get_state_fips_from_block(b) for b in block_geoids])
+    cbsa_codes = np.array([get_cbsa_from_county(c) or "" for c in county_fips])
+    # Reuse the county FIPS computed above rather than re-deriving it
+    # from every block GEOID a second time.
+    county_indices = np.array(
+        [get_county_enum_index_from_fips(c) for c in county_list],
+        dtype=np.int32,
+    )
+
+    crosswalk = _load_block_crosswalk()
+    use_crosswalk = not crosswalk.empty
+    has_zcta = "zcta" in crosswalk.columns
+
+    # One output list per crosswalk-derived field, filled in block order.
+    crosswalk_fields = ["sldu", "sldl", "place_fips", "vtd", "puma", "zcta"]
+    looked_up = {field: [] for field in crosswalk_fields}
+
+    for b in block_geoids:
+        row = None
+        if use_crosswalk and b in crosswalk.index:
+            row = crosswalk.loc[b]
+        for field, values in looked_up.items():
+            if row is None or (field == "zcta" and not has_zcta):
+                values.append("")
+            else:
+                values.append(row[field] if pd.notna(row[field]) else "")
+
+    return {
+        "block_geoid": block_geoids,
+        "county_fips": county_fips,
+        "tract_geoid": tract_geoids,
+        "state_fips": state_fips,
+        "cbsa_code": cbsa_codes,
+        "sldu": np.array(looked_up["sldu"]),
+        "sldl": np.array(looked_up["sldl"]),
+        "place_fips": np.array(looked_up["place_fips"]),
+        "vtd": np.array(looked_up["vtd"]),
+        "puma": np.array(looked_up["puma"]),
+        "zcta": np.array(looked_up["zcta"]),
+        "county_index": county_indices,
+    }
+
+
# === County Filter Functions (for city-level datasets) ===
diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/calibration_utils.py b/policyengine_us_data/calibration/calibration_utils.py
similarity index 86%
rename from policyengine_us_data/datasets/cps/local_area_calibration/calibration_utils.py
rename to policyengine_us_data/calibration/calibration_utils.py
index 97c82360d..6920955b9 100644
--- a/policyengine_us_data/datasets/cps/local_area_calibration/calibration_utils.py
+++ b/policyengine_us_data/calibration/calibration_utils.py
@@ -3,6 +3,7 @@
"""
from typing import Dict, List, Tuple
+import json
import numpy as np
import pandas as pd
@@ -521,6 +522,22 @@ def get_cd_index_mapping(db_uri: str = None):
return cd_to_index, index_to_cd, cds_ordered
+def save_geo_labels(labels: List[str], path) -> None:
+    """Write geo unit labels to ``path`` as JSON, creating parent dirs."""
+    from pathlib import Path
+
+    target = Path(path)
+    target.parent.mkdir(parents=True, exist_ok=True)
+    # Serialise first, then write the text in one call.
+    target.write_text(json.dumps(labels))
+
+
+def load_geo_labels(path) -> List[str]:
+    """Read the JSON document of geo unit labels stored at ``path``."""
+    with open(path) as handle:
+        return json.loads(handle.read())
+
+
def load_cd_geoadj_values(
cds_to_calibrate: List[str],
) -> Dict[str, float]:
@@ -548,93 +565,74 @@ def load_cd_geoadj_values(
)
rent_lookup[row["cd_geoid"]] = geoadj
- # Map each CD to calibrate to its geoadj value
- # Handle at-large districts: database uses XX01, rent CSV uses XX00
geoadj_dict = {}
for cd in cds_to_calibrate:
if cd in rent_lookup:
geoadj_dict[cd] = rent_lookup[cd]
else:
- # Try at-large mapping: XX01 -> XX00
- cd_int = int(cd)
- state_fips = cd_int // 100
- district = cd_int % 100
- if district == 1:
- at_large_cd = str(state_fips * 100) # XX00
- if at_large_cd in rent_lookup:
- geoadj_dict[cd] = rent_lookup[at_large_cd]
- continue
- # Fallback to national average (geoadj = 1.0)
print(f"Warning: No rent data for CD {cd}, using geoadj=1.0")
geoadj_dict[cd] = 1.0
return geoadj_dict
-def calculate_spm_thresholds_for_cd(
- sim,
- time_period: int,
- geoadj: float,
+def calculate_spm_thresholds_vectorized(
+ person_ages: np.ndarray,
+ person_spm_unit_ids: np.ndarray,
+ spm_unit_tenure_types: np.ndarray,
+ spm_unit_geoadj: np.ndarray,
year: int,
) -> np.ndarray:
+ """Calculate SPM thresholds for cloned SPM units from raw arrays.
+
+ Works without a Microsimulation instance. Counts adults/children
+ per SPM unit from person-level arrays, then computes
+ base_threshold * equivalence_scale * geoadj for each unit.
+
+ Args:
+ person_ages: Age per cloned person.
+ person_spm_unit_ids: New SPM unit ID per cloned person
+ (0-based contiguous).
+ spm_unit_tenure_types: Tenure type string per cloned SPM
+ unit (e.g. b"RENTER", b"OWNER_WITH_MORTGAGE").
+ spm_unit_geoadj: Geographic adjustment factor per cloned
+ SPM unit.
+ year: Tax year for base threshold lookup.
+
+ Returns:
+ Float32 array of SPM thresholds, one per SPM unit.
"""
- Calculate SPM thresholds for all SPM units using CD-specific geo-adjustment.
- """
- spm_unit_ids_person = sim.calculate("spm_unit_id", map_to="person").values
- ages = sim.calculate("age", map_to="person").values
-
- df = pd.DataFrame(
- {
- "spm_unit_id": spm_unit_ids_person,
- "is_adult": ages >= 18,
- "is_child": ages < 18,
- }
- )
-
- agg = (
- df.groupby("spm_unit_id")
- .agg(
- num_adults=("is_adult", "sum"),
- num_children=("is_child", "sum"),
+ n_units = len(spm_unit_tenure_types)
+
+ # Count adults and children per SPM unit
+ is_adult = person_ages >= 18
+ num_adults = np.zeros(n_units, dtype=np.int32)
+ num_children = np.zeros(n_units, dtype=np.int32)
+ np.add.at(num_adults, person_spm_unit_ids, is_adult.astype(np.int32))
+ np.add.at(num_children, person_spm_unit_ids, (~is_adult).astype(np.int32))
+
+ # Map tenure type strings to codes
+ tenure_codes = np.full(n_units, 3, dtype=np.int32)
+ for tenure_str, code in SPM_TENURE_STRING_TO_CODE.items():
+ tenure_bytes = (
+ tenure_str.encode() if isinstance(tenure_str, str) else tenure_str
)
- .reset_index()
- )
-
- tenure_types = sim.calculate(
- "spm_unit_tenure_type", map_to="spm_unit"
- ).values
- spm_unit_ids_unit = sim.calculate("spm_unit_id", map_to="spm_unit").values
-
- tenure_df = pd.DataFrame(
- {
- "spm_unit_id": spm_unit_ids_unit,
- "tenure_type": tenure_types,
- }
- )
-
- merged = agg.merge(tenure_df, on="spm_unit_id", how="left")
- merged["tenure_code"] = (
- merged["tenure_type"]
- .map(SPM_TENURE_STRING_TO_CODE)
- .fillna(3)
- .astype(int)
- )
+ mask = spm_unit_tenure_types == tenure_bytes
+ if not mask.any():
+ mask = spm_unit_tenure_types == tenure_str
+ tenure_codes[mask] = code
+ # Look up base thresholds
calc = SPMCalculator(year=year)
base_thresholds = calc.get_base_thresholds()
- n = len(merged)
- thresholds = np.zeros(n, dtype=np.float32)
-
- for i in range(n):
- tenure_str = TENURE_CODE_MAP.get(
- int(merged.iloc[i]["tenure_code"]), "renter"
- )
+ thresholds = np.zeros(n_units, dtype=np.float32)
+ for i in range(n_units):
+ tenure_str = TENURE_CODE_MAP.get(int(tenure_codes[i]), "renter")
base = base_thresholds[tenure_str]
equiv_scale = spm_equivalence_scale(
- int(merged.iloc[i]["num_adults"]),
- int(merged.iloc[i]["num_children"]),
+ int(num_adults[i]), int(num_children[i])
)
- thresholds[i] = base * equiv_scale * geoadj
+ thresholds[i] = base * equiv_scale * spm_unit_geoadj[i]
return thresholds
diff --git a/policyengine_us_data/calibration/check_staging_sums.py b/policyengine_us_data/calibration/check_staging_sums.py
new file mode 100644
index 000000000..9d2c4f879
--- /dev/null
+++ b/policyengine_us_data/calibration/check_staging_sums.py
@@ -0,0 +1,118 @@
+"""Sum key variables across all staging state H5 files.
+
+Quick smoke test: loads all 51 state H5s, sums key variables, and
+prints approximate national references for manual comparison.
+No database needed. ~10 min runtime.
+
+Usage:
+ python -m policyengine_us_data.calibration.check_staging_sums
+ python -m policyengine_us_data.calibration.check_staging_sums \
+        --hf-prefix hf://policyengine/policyengine-us-data/staging/states
+"""
+
+import argparse
+
+import pandas as pd
+
+from policyengine_us_data.calibration.calibration_utils import (
+ STATE_CODES,
+)
+
+STATE_ABBRS = sorted(STATE_CODES.values())
+
+VARIABLES = [
+ "adjusted_gross_income",
+ "employment_income",
+ "self_employment_income",
+ "tax_unit_partnership_s_corp_income",
+ "taxable_pension_income",
+ "dividend_income",
+ "net_capital_gains",
+ "rental_income",
+ "taxable_interest_income",
+ "social_security",
+ "snap",
+ "ssi",
+ "income_tax_before_credits",
+ "eitc",
+ "refundable_ctc",
+ "real_estate_taxes",
+ "rent",
+ "is_pregnant",
+ "person_count",
+ "household_count",
+]
+
+DEFAULT_HF_PREFIX = "hf://policyengine/policyengine-us-data/staging/states"
+
+
+def main(argv=None):
+    """Sum VARIABLES over every state H5 and print national totals.
+
+    Each state file is loaded from ``<hf-prefix>/<STATE>.h5``. A state
+    that fails to load is recorded and reported at the end instead of
+    aborting the run; a variable that fails to calculate is stored as
+    None for that state. Per-state sums are written to
+    staging_sums.csv in the current working directory.
+
+    Args:
+        argv: CLI arguments; None makes argparse read sys.argv.
+    """
+    parser = argparse.ArgumentParser(
+        description="Sum key variables across staging state H5 files"
+    )
+    parser.add_argument(
+        "--hf-prefix",
+        default=DEFAULT_HF_PREFIX,
+        help="HF path prefix for state H5 files "
+        f"(default: {DEFAULT_HF_PREFIX})",
+    )
+    args = parser.parse_args(argv)
+
+    from policyengine_us import Microsimulation
+
+    results, errors = {}, []
+
+    for i, st in enumerate(STATE_ABBRS):
+        print(f"[{i + 1}/{len(STATE_ABBRS)}] {st}...", end=" ", flush=True)
+        try:
+            sim = Microsimulation(dataset=f"{args.hf_prefix}/{st}.h5")
+            row = {}
+            for var in VARIABLES:
+                try:
+                    row[var] = float(sim.calculate(var).sum())
+                except Exception:
+                    row[var] = None
+            results[st] = row
+            print("OK")
+        except Exception as e:
+            errors.append((st, str(e)))
+            print(f"FAILED: {e}")
+
+    # reindex keeps every VARIABLES column present even when no state
+    # loaded, so the totals lookup below cannot raise KeyError and the
+    # failure list further down is still printed.
+    df = pd.DataFrame(results).T.reindex(columns=VARIABLES)
+    df.index.name = "state"
+
+    print("\n" + "=" * 70)
+    print("NATIONAL TOTALS (sum across all states)")
+    print("=" * 70)
+
+    totals = df.sum()
+    count_vars = ("person_count", "household_count", "is_pregnant")
+    for var in VARIABLES:
+        # Counts are printed bare; everything else is a dollar amount.
+        prefix = "" if var in count_vars else "$"
+        print(f"  {var:45s} {prefix}{totals[var]:>15,.0f}")
+
+    print("\n" + "=" * 70)
+    print("REFERENCE VALUES (approximate, for sanity checking)")
+    print("=" * 70)
+    print("  US GDP ~$29T, US population ~335M, ~130M households")
+    print("  Total AGI ~$15T, Employment income ~$10T")
+    print("  SNAP ~$110B, SSI ~$60B, Social Security ~$1.2T")
+    print("  EITC ~$60B, CTC ~$120B")
+
+    if errors:
+        print(f"\n{len(errors)} states failed:")
+        for st, err in errors:
+            print(f"  {st}: {err}")
+
+    # Write first, then announce, so the message is never a false claim.
+    df.to_csv("staging_sums.csv")
+    print("\nPer-state details saved to staging_sums.csv")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/policyengine_us_data/calibration/clone_and_assign.py b/policyengine_us_data/calibration/clone_and_assign.py
index 9aa64cbbc..e25af7d02 100644
--- a/policyengine_us_data/calibration/clone_and_assign.py
+++ b/policyengine_us_data/calibration/clone_and_assign.py
@@ -23,6 +23,7 @@ class GeographyAssignment:
block_geoid: np.ndarray # str array, 15-char block GEOIDs
cd_geoid: np.ndarray # str array of CD GEOIDs
+ county_fips: np.ndarray # str array of 5-char county FIPS
state_fips: np.ndarray # int array of 2-digit state FIPS
n_records: int
n_clones: int
@@ -52,7 +53,7 @@ def load_global_block_distribution():
df = pd.read_csv(csv_path, dtype={"block_geoid": str})
block_geoids = df["block_geoid"].values
- cd_geoids = df["cd_geoid"].astype(str).values
+ cd_geoids = np.array(df["cd_geoid"].astype(str).tolist())
state_fips = np.array([int(b[:2]) for b in block_geoids])
probs = df["probability"].values.astype(np.float64)
@@ -88,11 +89,44 @@ def assign_random_geography(
n_total = n_records * n_clones
rng = np.random.default_rng(seed)
- indices = rng.choice(len(blocks), size=n_total, p=probs)
+ indices = np.empty(n_total, dtype=np.int64)
+
+ # Clone 0: unrestricted draw
+ indices[:n_records] = rng.choice(len(blocks), size=n_records, p=probs)
+
+ assigned_cds = np.empty((n_clones, n_records), dtype=object)
+ assigned_cds[0] = cds[indices[:n_records]]
+
+ for clone_idx in range(1, n_clones):
+ start = clone_idx * n_records
+ clone_indices = rng.choice(len(blocks), size=n_records, p=probs)
+ clone_cds = cds[clone_indices]
+
+ collisions = np.zeros(n_records, dtype=bool)
+ for prev in range(clone_idx):
+ collisions |= clone_cds == assigned_cds[prev]
+
+ for _ in range(50):
+ n_bad = collisions.sum()
+ if n_bad == 0:
+ break
+ clone_indices[collisions] = rng.choice(
+ len(blocks), size=n_bad, p=probs
+ )
+ clone_cds = cds[clone_indices]
+ collisions = np.zeros(n_records, dtype=bool)
+ for prev in range(clone_idx):
+ collisions |= clone_cds == assigned_cds[prev]
+
+ indices[start : start + n_records] = clone_indices
+ assigned_cds[clone_idx] = clone_cds
+
+ assigned_blocks = blocks[indices]
return GeographyAssignment(
- block_geoid=blocks[indices],
+ block_geoid=assigned_blocks,
cd_geoid=cds[indices],
+ county_fips=np.array([b[:5] for b in assigned_blocks]),
state_fips=states[indices],
n_records=n_records,
n_clones=n_clones,
@@ -124,6 +158,7 @@ def double_geography_for_puf(
new_blocks = []
new_cds = []
+ new_counties = []
new_states = []
for c in range(n_clones):
@@ -131,14 +166,17 @@ def double_geography_for_puf(
end = start + n_old
clone_blocks = geography.block_geoid[start:end]
clone_cds = geography.cd_geoid[start:end]
+ clone_counties = geography.county_fips[start:end]
clone_states = geography.state_fips[start:end]
new_blocks.append(np.concatenate([clone_blocks, clone_blocks]))
new_cds.append(np.concatenate([clone_cds, clone_cds]))
+ new_counties.append(np.concatenate([clone_counties, clone_counties]))
new_states.append(np.concatenate([clone_states, clone_states]))
return GeographyAssignment(
block_geoid=np.concatenate(new_blocks),
cd_geoid=np.concatenate(new_cds),
+ county_fips=np.concatenate(new_counties),
state_fips=np.concatenate(new_states),
n_records=n_new,
n_clones=n_clones,
diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/county_assignment.py b/policyengine_us_data/calibration/county_assignment.py
similarity index 98%
rename from policyengine_us_data/datasets/cps/local_area_calibration/county_assignment.py
rename to policyengine_us_data/calibration/county_assignment.py
index 780bc4c77..6d32d30bd 100644
--- a/policyengine_us_data/datasets/cps/local_area_calibration/county_assignment.py
+++ b/policyengine_us_data/calibration/county_assignment.py
@@ -38,7 +38,7 @@ def _build_state_counties() -> Dict[str, List[str]]:
def _generate_uniform_distribution(cd_geoid: str) -> Dict[str, float]:
"""Generate uniform distribution across counties in CD's state."""
- from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import (
+ from policyengine_us_data.calibration.calibration_utils import (
STATE_CODES,
)
diff --git a/policyengine_us_data/calibration/create_source_imputed_cps.py b/policyengine_us_data/calibration/create_source_imputed_cps.py
new file mode 100644
index 000000000..4381f72dd
--- /dev/null
+++ b/policyengine_us_data/calibration/create_source_imputed_cps.py
@@ -0,0 +1,91 @@
+"""Create source-imputed stratified extended CPS.
+
+Standalone step that runs ACS/SIPP/SCF source imputations on the
+stratified extended CPS, producing the dataset used by calibration
+and H5 generation.
+
+Usage:
+ python policyengine_us_data/calibration/create_source_imputed_cps.py
+"""
+
+import logging
+import sys
+from pathlib import Path
+
+import h5py
+
+from policyengine_us_data.storage import STORAGE_FOLDER
+
+logger = logging.getLogger(__name__)
+
+INPUT_PATH = str(STORAGE_FOLDER / "stratified_extended_cps_2024.h5")
+OUTPUT_PATH = str(
+ STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5"
+)
+
+
+def create_source_imputed_cps(
+    input_path: str = INPUT_PATH,
+    output_path: str = OUTPUT_PATH,
+    seed: int = 42,
+):
+    """Run source imputations on a dataset and write the result as H5.
+
+    Loads ``input_path`` into a Microsimulation, draws one random
+    geography assignment per household, passes the raw variables and
+    the assigned state FIPS to impute_source_variables, and writes
+    every returned ``variable/period`` array to ``output_path``.
+
+    Args:
+        input_path: Path of the H5 dataset to load.
+        output_path: Path of the H5 file to write (overwritten).
+        seed: Seed passed to assign_random_geography.
+    """
+    from policyengine_us import Microsimulation
+    from policyengine_us_data.calibration.clone_and_assign import (
+        assign_random_geography,
+    )
+    from policyengine_us_data.calibration.source_impute import (
+        impute_source_variables,
+    )
+
+    logger.info("Loading dataset from %s", input_path)
+    sim = Microsimulation(dataset=input_path)
+    n_records = len(sim.calculate("household_id", map_to="household").values)
+
+    # Load the raw arrays once; they are needed both to detect the
+    # time period and to build the imputation input below.
+    raw_data = sim.dataset.load_dataset()
+    raw_keys = raw_data["household_id"]
+    if isinstance(raw_keys, dict):
+        time_period = int(next(iter(raw_keys)))
+    else:
+        # Flat (non period-keyed) datasets carry no period; use 2024.
+        time_period = 2024
+
+    logger.info("Loaded %d households, time_period=%d", n_records, time_period)
+
+    geography = assign_random_geography(
+        n_records=n_records, n_clones=1, seed=seed
+    )
+    base_states = geography.state_fips[:n_records]
+
+    data_dict = {}
+    for var in raw_data:
+        val = raw_data[var]
+        if isinstance(val, dict):
+            # str(k) lets keys that are already ints pass through
+            # instead of raising AttributeError on .isdigit().
+            data_dict[var] = {
+                int(k) if str(k).isdigit() else k: v for k, v in val.items()
+            }
+        else:
+            data_dict[var] = {time_period: val[...]}
+
+    logger.info("Running source imputations...")
+    data_dict = impute_source_variables(
+        data=data_dict,
+        state_fips=base_states,
+        time_period=time_period,
+        dataset_path=input_path,
+    )
+
+    logger.info("Saving to %s", output_path)
+    with h5py.File(output_path, "w") as f:
+        for var, time_dict in data_dict.items():
+            for tp, values in time_dict.items():
+                f.create_dataset(f"{var}/{tp}", data=values)
+
+    logger.info("Done.")
+
+
+if __name__ == "__main__":
+ logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s %(name)s %(levelname)s %(message)s",
+ stream=sys.stderr,
+ )
+ create_source_imputed_cps()
diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/create_stratified_cps.py b/policyengine_us_data/calibration/create_stratified_cps.py
similarity index 100%
rename from policyengine_us_data/datasets/cps/local_area_calibration/create_stratified_cps.py
rename to policyengine_us_data/calibration/create_stratified_cps.py
diff --git a/policyengine_us_data/calibration/publish_local_area.py b/policyengine_us_data/calibration/publish_local_area.py
new file mode 100644
index 000000000..8d39bdf5a
--- /dev/null
+++ b/policyengine_us_data/calibration/publish_local_area.py
@@ -0,0 +1,1071 @@
+"""
+Build local area H5 files, optionally uploading to GCP and Hugging Face.
+
+Downloads calibration inputs from HF, generates state/district H5s
+with checkpointing. Uploads only occur when --upload is explicitly passed.
+
+Usage:
+ python publish_local_area.py [--skip-download] [--states-only] [--upload]
+"""
+
+import numpy as np
+from pathlib import Path
+from typing import List
+
+from policyengine_us import Microsimulation
+from policyengine_us_data.utils.huggingface import download_calibration_inputs
+from policyengine_us_data.utils.data_upload import (
+ upload_local_area_file,
+ upload_local_area_batch_to_hf,
+)
+from policyengine_us_data.calibration.calibration_utils import (
+ get_all_cds_from_database,
+ STATE_CODES,
+ load_cd_geoadj_values,
+ calculate_spm_thresholds_vectorized,
+)
+from policyengine_us_data.calibration.block_assignment import (
+ assign_geography_for_cd,
+ derive_geography_from_blocks,
+ get_county_filter_probability,
+)
+from policyengine_us_data.utils.takeup import (
+ TAKEUP_AFFECTED_TARGETS,
+ apply_block_takeup_to_arrays,
+)
+
+CHECKPOINT_FILE = Path("completed_states.txt")
+CHECKPOINT_FILE_DISTRICTS = Path("completed_districts.txt")
+CHECKPOINT_FILE_CITIES = Path("completed_cities.txt")
+WORK_DIR = Path("local_area_build")
+
+NYC_COUNTIES = {
+ "QUEENS_COUNTY_NY",
+ "BRONX_COUNTY_NY",
+ "RICHMOND_COUNTY_NY",
+ "NEW_YORK_COUNTY_NY",
+ "KINGS_COUNTY_NY",
+}
+
+NYC_CDS = [
+ "3603",
+ "3605",
+ "3606",
+ "3607",
+ "3608",
+ "3609",
+ "3610",
+ "3611",
+ "3612",
+ "3613",
+ "3614",
+ "3615",
+ "3616",
+]
+
+
+def _read_checkpoint(path: Path) -> set:
+    """Return the entries recorded in a checkpoint file.
+
+    Entries are the newline-separated pieces of the stripped file
+    content; a missing or blank file yields an empty set.
+    """
+    if path.exists():
+        content = path.read_text().strip()
+        if content:
+            return set(content.split("\n"))
+    return set()
+
+
+def _append_checkpoint(path: Path, entry: str) -> None:
+    """Append one newline-terminated entry to a checkpoint file."""
+    with open(path, "a") as f:
+        f.write(f"{entry}\n")
+
+
+def load_completed_states() -> set:
+    """Return the state codes already recorded in CHECKPOINT_FILE."""
+    return _read_checkpoint(CHECKPOINT_FILE)
+
+
+def record_completed_state(state_code: str):
+    """Append ``state_code`` to CHECKPOINT_FILE."""
+    _append_checkpoint(CHECKPOINT_FILE, state_code)
+
+
+def load_completed_districts() -> set:
+    """Return the district names recorded in CHECKPOINT_FILE_DISTRICTS."""
+    return _read_checkpoint(CHECKPOINT_FILE_DISTRICTS)
+
+
+def record_completed_district(district_name: str):
+    """Append ``district_name`` to CHECKPOINT_FILE_DISTRICTS."""
+    _append_checkpoint(CHECKPOINT_FILE_DISTRICTS, district_name)
+
+
+def load_completed_cities() -> set:
+    """Return the city names recorded in CHECKPOINT_FILE_CITIES."""
+    return _read_checkpoint(CHECKPOINT_FILE_CITIES)
+
+
+def record_completed_city(city_name: str):
+    """Append ``city_name`` to CHECKPOINT_FILE_CITIES."""
+    _append_checkpoint(CHECKPOINT_FILE_CITIES, city_name)
+
+
+def build_h5(
+ weights: np.ndarray,
+ blocks: np.ndarray,
+ dataset_path: Path,
+ output_path: Path,
+ cds_to_calibrate: List[str],
+ cd_subset: List[str] = None,
+ county_filter: set = None,
+ takeup_filter: List[str] = None,
+ stacked_takeup: dict = None,
+) -> Path:
+ """Build an H5 file by cloning records for each nonzero weight.
+
+ Uses fancy indexing on a single loaded simulation instead of
+ looping over CDs.
+
+ Each nonzero entry in the (n_geo, n_hh) weight matrix represents
+ a distinct household clone. This function clones entity arrays,
+ derives geography from blocks, reindexes entity IDs, recalculates
+ SPM thresholds, applies calibration takeup draws (when blocks are
+ provided), and writes the H5.
+
+ Args:
+ weights: Stacked weight vector, shape (n_geo * n_hh,).
+ blocks: Block GEOID per weight entry, same shape. When
+ provided, calibration takeup draws are applied
+ automatically.
+ dataset_path: Path to base dataset H5 file.
+ output_path: Where to write the output H5 file.
+ cds_to_calibrate: Ordered list of CD GEOIDs defining
+ weight matrix row ordering.
+ cd_subset: If provided, only include rows for these CDs.
+ county_filter: If provided, scale weights by P(target|CD)
+ for city datasets.
+ takeup_filter: List of takeup vars to apply.
+ stacked_takeup: Pre-computed weight-averaged takeup per
+ (CD, entity). Dict of {var_name: (n_cds, n_ent)}.
+ When provided, overrides block-based takeup draws.
+
+ Returns:
+ Path to the output H5 file.
+ """
+ import h5py
+ from collections import defaultdict
+ from policyengine_core.enums import Enum
+ from policyengine_us.variables.household.demographic.geographic.county.county_enum import (
+ County,
+ )
+
+ output_path = Path(output_path)
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+
+ apply_takeup = blocks is not None
+
+ # === Load base simulation ===
+ sim = Microsimulation(dataset=str(dataset_path))
+ time_period = int(sim.default_calculation_period)
+ household_ids = sim.calculate("household_id", map_to="household").values
+ n_hh = len(household_ids)
+
+ if weights.shape[0] % n_hh != 0:
+ raise ValueError(
+ f"Weight vector length {weights.shape[0]} is not "
+ f"divisible by n_hh={n_hh}"
+ )
+ n_geo = weights.shape[0] // n_hh
+
+ # Generate blocks from assign_geography_for_cd if not provided
+ if blocks is None:
+ print("No blocks provided, generating from CD assignments...")
+ all_blocks = np.empty(n_geo * n_hh, dtype="U15")
+ for geo_idx, cd in enumerate(cds_to_calibrate):
+ geo = assign_geography_for_cd(
+ cd_geoid=cd,
+ n_households=n_hh,
+ seed=42 + int(cd),
+ )
+ start = geo_idx * n_hh
+ all_blocks[start : start + n_hh] = geo["block_geoid"]
+ blocks = all_blocks
+
+ if len(blocks) != len(weights):
+ raise ValueError(
+ f"Blocks length {len(blocks)} != " f"weights length {len(weights)}"
+ )
+
+ # === Reshape and filter weight matrix ===
+ W = weights.reshape(n_geo, n_hh).copy()
+
+ # CD subset filtering: zero out rows for CDs not in subset
+ if cd_subset is not None:
+ cd_index_set = set()
+ for cd in cd_subset:
+ if cd not in cds_to_calibrate:
+ raise ValueError(f"CD {cd} not in calibrated CDs list")
+ cd_index_set.add(cds_to_calibrate.index(cd))
+ for i in range(n_geo):
+ if i not in cd_index_set:
+ W[i, :] = 0
+
+ # County filtering: scale weights by P(target_counties | CD)
+ if county_filter is not None:
+ for geo_idx in range(n_geo):
+ cd = cds_to_calibrate[geo_idx]
+ p = get_county_filter_probability(cd, county_filter)
+ W[geo_idx, :] *= p
+
+ n_active_cds = len(cd_subset) if cd_subset is not None else n_geo
+ label = (
+ f"{n_active_cds} CDs"
+ if cd_subset is not None
+ else f"{n_geo} geo units"
+ )
+ print(f"\n{'='*60}")
+ print(f"Building {output_path.name} ({label}, {n_hh} households)")
+ print(f"{'='*60}")
+
+ # === Identify active clones ===
+ active_geo, active_hh = np.where(W > 0)
+ n_clones = len(active_geo)
+ clone_weights = W[active_geo, active_hh]
+ active_blocks = blocks.reshape(n_geo, n_hh)[active_geo, active_hh]
+
+ empty_count = np.sum(active_blocks == "")
+ if empty_count > 0:
+ raise ValueError(
+ f"{empty_count} active clones have empty block GEOIDs"
+ )
+
+ print(f"Active clones: {n_clones:,}")
+ print(f"Total weight: {clone_weights.sum():,.0f}")
+
+ # === Build entity membership maps ===
+ hh_id_to_idx = {int(hid): i for i, hid in enumerate(household_ids)}
+ person_hh_ids = sim.calculate("household_id", map_to="person").values
+
+ hh_to_persons = defaultdict(list)
+ for p_idx, p_hh_id in enumerate(person_hh_ids):
+ hh_to_persons[hh_id_to_idx[int(p_hh_id)]].append(p_idx)
+
+ SUB_ENTITIES = [
+ "tax_unit",
+ "spm_unit",
+ "family",
+ "marital_unit",
+ ]
+ hh_to_entity = {}
+ entity_id_arrays = {}
+ person_entity_id_arrays = {}
+
+ for ek in SUB_ENTITIES:
+ eids = sim.calculate(f"{ek}_id", map_to=ek).values
+ peids = sim.calculate(f"person_{ek}_id", map_to="person").values
+ entity_id_arrays[ek] = eids
+ person_entity_id_arrays[ek] = peids
+ eid_to_idx = {int(eid): i for i, eid in enumerate(eids)}
+
+ mapping = defaultdict(list)
+ seen = defaultdict(set)
+ for p_idx in range(len(person_hh_ids)):
+ hh_idx = hh_id_to_idx[int(person_hh_ids[p_idx])]
+ e_idx = eid_to_idx[int(peids[p_idx])]
+ if e_idx not in seen[hh_idx]:
+ seen[hh_idx].add(e_idx)
+ mapping[hh_idx].append(e_idx)
+ for hh_idx in mapping:
+ mapping[hh_idx].sort()
+ hh_to_entity[ek] = mapping
+
+ # === Build clone index arrays ===
+ hh_clone_idx = active_hh
+
+ persons_per_clone = np.array(
+ [len(hh_to_persons.get(h, [])) for h in active_hh]
+ )
+ person_parts = [
+ np.array(hh_to_persons.get(h, []), dtype=np.int64) for h in active_hh
+ ]
+ person_clone_idx = (
+ np.concatenate(person_parts)
+ if person_parts
+ else np.array([], dtype=np.int64)
+ )
+
+ entity_clone_idx = {}
+ entities_per_clone = {}
+ for ek in SUB_ENTITIES:
+ epc = np.array([len(hh_to_entity[ek].get(h, [])) for h in active_hh])
+ entities_per_clone[ek] = epc
+ parts = [
+ np.array(hh_to_entity[ek].get(h, []), dtype=np.int64)
+ for h in active_hh
+ ]
+ entity_clone_idx[ek] = (
+ np.concatenate(parts) if parts else np.array([], dtype=np.int64)
+ )
+
+ n_persons = len(person_clone_idx)
+ print(f"Cloned persons: {n_persons:,}")
+ for ek in SUB_ENTITIES:
+ print(f"Cloned {ek}s: {len(entity_clone_idx[ek]):,}")
+
+ # === Build new entity IDs and cross-references ===
+ new_hh_ids = np.arange(n_clones, dtype=np.int32)
+ new_person_ids = np.arange(n_persons, dtype=np.int32)
+ new_person_hh_ids = np.repeat(new_hh_ids, persons_per_clone)
+
+ new_entity_ids = {}
+ new_person_entity_ids = {}
+ clone_ids_for_persons = np.repeat(
+ np.arange(n_clones, dtype=np.int64), persons_per_clone
+ )
+
+ for ek in SUB_ENTITIES:
+ n_ents = len(entity_clone_idx[ek])
+ new_entity_ids[ek] = np.arange(n_ents, dtype=np.int32)
+
+ old_eids = entity_id_arrays[ek][entity_clone_idx[ek]].astype(np.int64)
+ clone_ids_e = np.repeat(
+ np.arange(n_clones, dtype=np.int64),
+ entities_per_clone[ek],
+ )
+
+ offset = int(old_eids.max()) + 1 if len(old_eids) > 0 else 1
+ entity_keys = clone_ids_e * offset + old_eids
+
+ sorted_order = np.argsort(entity_keys)
+ sorted_keys = entity_keys[sorted_order]
+ sorted_new = new_entity_ids[ek][sorted_order]
+
+ p_old_eids = person_entity_id_arrays[ek][person_clone_idx].astype(
+ np.int64
+ )
+ person_keys = clone_ids_for_persons * offset + p_old_eids
+
+ positions = np.searchsorted(sorted_keys, person_keys)
+ positions = np.clip(positions, 0, len(sorted_keys) - 1)
+ new_person_entity_ids[ek] = sorted_new[positions]
+
+ # === Derive geography from blocks (dedup optimization) ===
+ print("Deriving geography from blocks...")
+ unique_blocks, block_inv = np.unique(active_blocks, return_inverse=True)
+ print(f" {n_clones:,} blocks -> " f"{len(unique_blocks):,} unique")
+ unique_geo = derive_geography_from_blocks(unique_blocks)
+ geography = {k: v[block_inv] for k, v in unique_geo.items()}
+
+ # === Calculate weights for all entity levels ===
+ person_weights = np.repeat(clone_weights, persons_per_clone)
+ per_person_wt = clone_weights / np.maximum(persons_per_clone, 1)
+
+ entity_weights = {}
+ for ek in SUB_ENTITIES:
+ n_ents = len(entity_clone_idx[ek])
+ ent_person_counts = np.zeros(n_ents, dtype=np.int32)
+ np.add.at(
+ ent_person_counts,
+ new_person_entity_ids[ek],
+ 1,
+ )
+ clone_ids_e = np.repeat(np.arange(n_clones), entities_per_clone[ek])
+ entity_weights[ek] = per_person_wt[clone_ids_e] * ent_person_counts
+
+ # === Determine variables to save ===
+ vars_to_save = set(sim.input_variables)
+ vars_to_save.add("county")
+ vars_to_save.add("spm_unit_spm_threshold")
+ vars_to_save.add("congressional_district_geoid")
+ for gv in [
+ "block_geoid",
+ "tract_geoid",
+ "cbsa_code",
+ "sldu",
+ "sldl",
+ "place_fips",
+ "vtd",
+ "puma",
+ "zcta",
+ ]:
+ vars_to_save.add(gv)
+
+ # === Clone variable arrays ===
+ clone_idx_map = {
+ "household": hh_clone_idx,
+ "person": person_clone_idx,
+ }
+ for ek in SUB_ENTITIES:
+ clone_idx_map[ek] = entity_clone_idx[ek]
+
+ data = {}
+ variables_saved = 0
+
+ for variable in sim.tax_benefit_system.variables:
+ if variable not in vars_to_save:
+ continue
+
+ holder = sim.get_holder(variable)
+ periods = holder.get_known_periods()
+ if not periods:
+ continue
+
+ var_def = sim.tax_benefit_system.variables.get(variable)
+ entity_key = var_def.entity.key
+ if entity_key not in clone_idx_map:
+ continue
+
+ cidx = clone_idx_map[entity_key]
+ var_data = {}
+
+ for period in periods:
+ values = holder.get_array(period)
+
+ if var_def.value_type in (Enum, str) and variable != "county_fips":
+ if hasattr(values, "decode_to_str"):
+ values = values.decode_to_str().astype("S")
+ else:
+ values = values.astype("S")
+ elif variable == "county_fips":
+ values = values.astype("int32")
+ else:
+ values = np.array(values)
+
+ var_data[period] = values[cidx]
+ variables_saved += 1
+
+ if var_data:
+ data[variable] = var_data
+
+ print(f"Variables cloned: {variables_saved}")
+
+ # === Override entity IDs ===
+ data["household_id"] = {time_period: new_hh_ids}
+ data["person_id"] = {time_period: new_person_ids}
+ data["person_household_id"] = {
+ time_period: new_person_hh_ids,
+ }
+ for ek in SUB_ENTITIES:
+ data[f"{ek}_id"] = {
+ time_period: new_entity_ids[ek],
+ }
+ data[f"person_{ek}_id"] = {
+ time_period: new_person_entity_ids[ek],
+ }
+
+ # === Override weights ===
+ data["household_weight"] = {
+ time_period: clone_weights.astype(np.float32),
+ }
+ data["person_weight"] = {
+ time_period: person_weights.astype(np.float32),
+ }
+ for ek in SUB_ENTITIES:
+ data[f"{ek}_weight"] = {
+ time_period: entity_weights[ek].astype(np.float32),
+ }
+
+ # === Override geography ===
+ data["state_fips"] = {
+ time_period: geography["state_fips"].astype(np.int32),
+ }
+ county_names = np.array(
+ [County._member_names_[i] for i in geography["county_index"]]
+ ).astype("S")
+ data["county"] = {time_period: county_names}
+ data["county_fips"] = {
+ time_period: geography["county_fips"].astype(np.int32),
+ }
+ for gv in [
+ "block_geoid",
+ "tract_geoid",
+ "cbsa_code",
+ "sldu",
+ "sldl",
+ "place_fips",
+ "vtd",
+ "puma",
+ "zcta",
+ ]:
+ if gv in geography:
+ data[gv] = {
+ time_period: geography[gv].astype("S"),
+ }
+
+ # === Gap 4: Congressional district GEOID ===
+ clone_cd_geoids = np.array(
+ [int(cds_to_calibrate[g]) for g in active_geo],
+ dtype=np.int32,
+ )
+ data["congressional_district_geoid"] = {
+ time_period: clone_cd_geoids,
+ }
+
+ # === Gap 1: SPM threshold recalculation ===
+ print("Recalculating SPM thresholds...")
+ cd_geoadj_values = load_cd_geoadj_values(cds_to_calibrate)
+ # Build per-SPM-unit geoadj from clone's CD
+ spm_clone_ids = np.repeat(
+ np.arange(n_clones, dtype=np.int64),
+ entities_per_clone["spm_unit"],
+ )
+ spm_unit_geoadj = np.array(
+ [
+ cd_geoadj_values[cds_to_calibrate[active_geo[c]]]
+ for c in spm_clone_ids
+ ],
+ dtype=np.float64,
+ )
+
+ # Get cloned person ages and SPM unit IDs
+ person_ages = sim.calculate("age", map_to="person").values[
+ person_clone_idx
+ ]
+
+ # Get cloned tenure types
+ spm_tenure_holder = sim.get_holder("spm_unit_tenure_type")
+ spm_tenure_periods = spm_tenure_holder.get_known_periods()
+ if spm_tenure_periods:
+ raw_tenure = spm_tenure_holder.get_array(spm_tenure_periods[0])
+ if hasattr(raw_tenure, "decode_to_str"):
+ raw_tenure = raw_tenure.decode_to_str().astype("S")
+ else:
+ raw_tenure = np.array(raw_tenure).astype("S")
+ spm_tenure_cloned = raw_tenure[entity_clone_idx["spm_unit"]]
+ else:
+ spm_tenure_cloned = np.full(
+ len(entity_clone_idx["spm_unit"]),
+ b"RENTER",
+ dtype="S30",
+ )
+
+ new_spm_thresholds = calculate_spm_thresholds_vectorized(
+ person_ages=person_ages,
+ person_spm_unit_ids=new_person_entity_ids["spm_unit"],
+ spm_unit_tenure_types=spm_tenure_cloned,
+ spm_unit_geoadj=spm_unit_geoadj,
+ year=time_period,
+ )
+ data["spm_unit_spm_threshold"] = {
+ time_period: new_spm_thresholds,
+ }
+
+ # === Apply calibration takeup draws ===
+ if stacked_takeup is not None:
+ print("Applying pre-computed stacked takeup values...")
+ from policyengine_us_data.utils.takeup import (
+ SIMPLE_TAKEUP_VARS,
+ )
+
+ var_to_entity = {
+ spec["variable"]: spec["entity"] for spec in SIMPLE_TAKEUP_VARS
+ }
+ for spec in SIMPLE_TAKEUP_VARS:
+ var_name = spec["variable"]
+ entity_level = spec["entity"]
+
+ if entity_level == "person":
+ n_ent = n_persons
+ clone_ids_e = np.repeat(np.arange(n_clones), persons_per_clone)
+ base_indices = person_clone_idx
+ else:
+ n_ent = len(entity_clone_idx[entity_level])
+ clone_ids_e = np.repeat(
+ np.arange(n_clones),
+ entities_per_clone[entity_level],
+ )
+ base_indices = entity_clone_idx[entity_level]
+
+ if var_name in stacked_takeup:
+ geo_indices = active_geo[clone_ids_e]
+ values = stacked_takeup[var_name][geo_indices, base_indices]
+ data[var_name] = {time_period: values.astype(np.float32)}
+ else:
+ data[var_name] = {time_period: np.ones(n_ent, dtype=bool)}
+ elif apply_takeup:
+ print("Applying calibration takeup draws...")
+ entity_hh_indices = {
+ "person": np.repeat(
+ np.arange(n_clones, dtype=np.int64),
+ persons_per_clone,
+ ).astype(np.int64),
+ "tax_unit": np.repeat(
+ np.arange(n_clones, dtype=np.int64),
+ entities_per_clone["tax_unit"],
+ ).astype(np.int64),
+ "spm_unit": np.repeat(
+ np.arange(n_clones, dtype=np.int64),
+ entities_per_clone["spm_unit"],
+ ).astype(np.int64),
+ }
+ entity_counts = {
+ "person": n_persons,
+ "tax_unit": len(entity_clone_idx["tax_unit"]),
+ "spm_unit": len(entity_clone_idx["spm_unit"]),
+ }
+ hh_state_fips = geography["state_fips"].astype(np.int32)
+ original_hh_ids = household_ids[active_hh].astype(np.int64)
+
+ takeup_results = apply_block_takeup_to_arrays(
+ hh_blocks=active_blocks,
+ hh_state_fips=hh_state_fips,
+ hh_ids=original_hh_ids,
+ entity_hh_indices=entity_hh_indices,
+ entity_counts=entity_counts,
+ time_period=time_period,
+ takeup_filter=takeup_filter,
+ )
+ for var_name, bools in takeup_results.items():
+ data[var_name] = {time_period: bools}
+
+ # === Write H5 ===
+ with h5py.File(str(output_path), "w") as f:
+ for variable, periods in data.items():
+ grp = f.create_group(variable)
+ for period, values in periods.items():
+ grp.create_dataset(str(period), data=values)
+
+ print(f"\nH5 saved to {output_path}")
+
+ with h5py.File(str(output_path), "r") as f:
+ tp = str(time_period)
+ if "household_id" in f and tp in f["household_id"]:
+ n = len(f["household_id"][tp][:])
+ print(f"Verified: {n:,} households in output")
+ if "person_id" in f and tp in f["person_id"]:
+ n = len(f["person_id"][tp][:])
+ print(f"Verified: {n:,} persons in output")
+ if "household_weight" in f and tp in f["household_weight"]:
+ hw = f["household_weight"][tp][:]
+ print(f"Total population (HH weights): " f"{hw.sum():,.0f}")
+ if "person_weight" in f and tp in f["person_weight"]:
+ pw = f["person_weight"][tp][:]
+ print(f"Total population (person weights): " f"{pw.sum():,.0f}")
+
+ return output_path
+
+
+# District numbers (the last two digits of a congressional-district GEOID)
+# that are renumbered to district 1 when a friendly name is built.
+AT_LARGE_DISTRICTS = {0, 98}
+
+
+def get_district_friendly_name(cd_geoid: str) -> str:
+    """Turn a congressional-district GEOID into an ``SS-DD`` label.
+
+    The GEOID is read as an integer: its last two digits are the district
+    number and the remaining leading digits are the state FIPS code
+    (e.g., '0101' -> 'AL-01'). District numbers listed in
+    ``AT_LARGE_DISTRICTS`` are reported as district 1. A FIPS code that is
+    missing from ``STATE_CODES`` is rendered as the number itself.
+    """
+    fips, district = divmod(int(cd_geoid), 100)
+    if district in AT_LARGE_DISTRICTS:
+        district = 1
+    prefix = STATE_CODES.get(fips, str(fips))
+    return f"{prefix}-{district:02d}"
+
+
+def build_states(
+    weights_path: Path,
+    dataset_path: Path,
+    db_path: Path,
+    output_dir: Path,
+    completed_states: set,
+    hf_batch_size: int = 10,
+    calibration_blocks: np.ndarray = None,
+    takeup_filter: List[str] = None,
+    upload: bool = False,
+    stacked_takeup: dict = None,
+):
+    """Write one H5 file per state into ``output_dir / "states"``.
+
+    States listed in ``completed_states`` and states that have no
+    congressional district in the database are skipped. Every state that
+    is built is recorded through ``record_completed_state``. When
+    ``upload`` is set, each file is passed to ``upload_local_area_file``
+    right away and queued for HuggingFace; the queue is flushed whenever
+    it reaches ``hf_batch_size`` entries and once more at the end.
+
+    Any exception raised while handling a state is logged and re-raised.
+    """
+    all_cds = get_all_cds_from_database(f"sqlite:///{db_path}")
+    weight_vector = np.load(weights_path)
+
+    target_dir = output_dir / "states"
+    target_dir.mkdir(parents=True, exist_ok=True)
+
+    pending_hf = []
+
+    for fips, abbrev in STATE_CODES.items():
+        if abbrev in completed_states:
+            print(f"Skipping {abbrev} (already completed)")
+            continue
+
+        # A CD belongs to the state whose FIPS code prefixes its GEOID.
+        state_cds = [cd for cd in all_cds if int(cd) // 100 == fips]
+        if len(state_cds) == 0:
+            print(f"No CDs found for {abbrev}, skipping")
+            continue
+
+        h5_file = target_dir / f"{abbrev}.h5"
+
+        try:
+            build_h5(
+                weights=weight_vector,
+                blocks=calibration_blocks,
+                dataset_path=dataset_path,
+                output_path=h5_file,
+                cds_to_calibrate=all_cds,
+                cd_subset=state_cds,
+                takeup_filter=takeup_filter,
+                stacked_takeup=stacked_takeup,
+            )
+
+            if upload:
+                print(f"Uploading {abbrev}.h5 to GCP...")
+                upload_local_area_file(str(h5_file), "states", skip_hf=True)
+                pending_hf.append((str(h5_file), "states"))
+
+            record_completed_state(abbrev)
+            print(f"Completed {abbrev}")
+
+            batch_ready = upload and len(pending_hf) >= hf_batch_size
+            if batch_ready:
+                print(
+                    f"\nUploading batch of {len(pending_hf)} "
+                    f"files to HuggingFace..."
+                )
+                upload_local_area_batch_to_hf(pending_hf)
+                pending_hf = []
+
+        except Exception as err:
+            print(f"ERROR building {abbrev}: {err}")
+            raise
+
+    # Flush whatever is left over from the last incomplete batch.
+    if upload and pending_hf:
+        print(
+            f"\nUploading final batch of {len(pending_hf)} "
+            f"files to HuggingFace..."
+        )
+        upload_local_area_batch_to_hf(pending_hf)
+
+
+def build_districts(
+    weights_path: Path,
+    dataset_path: Path,
+    db_path: Path,
+    output_dir: Path,
+    completed_districts: set,
+    hf_batch_size: int = 10,
+    calibration_blocks: np.ndarray = None,
+    takeup_filter: List[str] = None,
+    upload: bool = False,
+    stacked_takeup: dict = None,
+):
+    """Build district H5 files with checkpointing, optionally uploading.
+
+    One file named ``<friendly_name>.h5`` is written to
+    ``output_dir / "districts"`` for every congressional district found
+    in the database, except those whose friendly name is already in
+    ``completed_districts``. Every district that is built is recorded
+    through ``record_completed_district``. When ``upload`` is set, each
+    file is passed to ``upload_local_area_file`` right away and queued
+    for HuggingFace; the queue is flushed whenever it reaches
+    ``hf_batch_size`` entries and once more at the end.
+
+    Any exception raised while handling a district is logged and
+    re-raised.
+    """
+    db_uri = f"sqlite:///{db_path}"
+    cds_to_calibrate = get_all_cds_from_database(db_uri)
+    w = np.load(weights_path)
+
+    districts_dir = output_dir / "districts"
+    districts_dir.mkdir(parents=True, exist_ok=True)
+
+    hf_queue = []
+
+    for i, cd_geoid in enumerate(cds_to_calibrate):
+        # Single source of truth for the GEOID -> "SS-DD" naming rule, so
+        # file names and checkpoint keys cannot drift from the helper.
+        friendly_name = get_district_friendly_name(cd_geoid)
+
+        if friendly_name in completed_districts:
+            print(f"Skipping {friendly_name} (already completed)")
+            continue
+
+        output_path = districts_dir / f"{friendly_name}.h5"
+        print(f"\n[{i+1}/{len(cds_to_calibrate)}] Building {friendly_name}")
+
+        try:
+            build_h5(
+                weights=w,
+                blocks=calibration_blocks,
+                dataset_path=dataset_path,
+                output_path=output_path,
+                cds_to_calibrate=cds_to_calibrate,
+                cd_subset=[cd_geoid],
+                takeup_filter=takeup_filter,
+                stacked_takeup=stacked_takeup,
+            )
+
+            if upload:
+                print(f"Uploading {friendly_name}.h5 to GCP...")
+                upload_local_area_file(
+                    str(output_path), "districts", skip_hf=True
+                )
+                hf_queue.append((str(output_path), "districts"))
+
+            record_completed_district(friendly_name)
+            print(f"Completed {friendly_name}")
+
+            if upload and len(hf_queue) >= hf_batch_size:
+                print(
+                    f"\nUploading batch of {len(hf_queue)} "
+                    f"files to HuggingFace..."
+                )
+                upload_local_area_batch_to_hf(hf_queue)
+                hf_queue = []
+
+        except Exception as e:
+            print(f"ERROR building {friendly_name}: {e}")
+            raise
+
+    # Flush whatever is left over from the last incomplete batch.
+    if upload and hf_queue:
+        print(
+            f"\nUploading final batch of {len(hf_queue)} "
+            f"files to HuggingFace..."
+        )
+        upload_local_area_batch_to_hf(hf_queue)
+
+
+def build_cities(
+    weights_path: Path,
+    dataset_path: Path,
+    db_path: Path,
+    output_dir: Path,
+    completed_cities: set,
+    hf_batch_size: int = 10,
+    calibration_blocks: np.ndarray = None,
+    takeup_filter: List[str] = None,
+    upload: bool = False,
+    stacked_takeup: dict = None,
+):
+    """Write city-level H5 files into ``output_dir / "cities"``.
+
+    NYC is the only city handled here: it is built from the calibrated
+    districts that appear in ``NYC_CDS``, passing ``NYC_COUNTIES`` as the
+    county filter. The build is skipped when "NYC" is already in
+    ``completed_cities`` or when none of those districts is present.
+    When ``upload`` is set, the file is passed to
+    ``upload_local_area_file`` and afterwards sent to HuggingFace.
+    ``hf_batch_size`` is accepted for signature parity with the other
+    builders and is not used.
+
+    Any exception raised during the build is logged and re-raised.
+    """
+    all_cds = get_all_cds_from_database(f"sqlite:///{db_path}")
+    weight_vector = np.load(weights_path)
+
+    target_dir = output_dir / "cities"
+    target_dir.mkdir(parents=True, exist_ok=True)
+
+    pending_hf = []
+
+    # NYC
+    nyc_cds = []
+    if "NYC" in completed_cities:
+        print("Skipping NYC (already completed)")
+    else:
+        nyc_cds = [cd for cd in all_cds if cd in NYC_CDS]
+        if not nyc_cds:
+            print("No NYC-related CDs found, skipping")
+
+    if nyc_cds:
+        h5_file = target_dir / "NYC.h5"
+
+        try:
+            build_h5(
+                weights=weight_vector,
+                blocks=calibration_blocks,
+                dataset_path=dataset_path,
+                output_path=h5_file,
+                cds_to_calibrate=all_cds,
+                cd_subset=nyc_cds,
+                county_filter=NYC_COUNTIES,
+                takeup_filter=takeup_filter,
+                stacked_takeup=stacked_takeup,
+            )
+
+            if upload:
+                print("Uploading NYC.h5 to GCP...")
+                upload_local_area_file(str(h5_file), "cities", skip_hf=True)
+                pending_hf.append((str(h5_file), "cities"))
+
+            record_completed_city("NYC")
+            print("Completed NYC")
+
+        except Exception as err:
+            print(f"ERROR building NYC: {err}")
+            raise
+
+    if upload and pending_hf:
+        print(
+            f"\nUploading batch of {len(pending_hf)} "
+            f"city files to HuggingFace..."
+        )
+        upload_local_area_batch_to_hf(pending_hf)
+
+
+def main():
+    """Command-line entry point for building local-area H5 files.
+
+    Resolves the weights, dataset and database paths, loads the optional
+    calibration blocks and stacked takeup arrays, then runs the state,
+    district and city builders selected by the ``--*-only`` flags (all
+    three when no such flag is given; if several are given, states win
+    over districts, which win over cities).
+
+    Input paths are resolved as follows: when ``--weights-path``,
+    ``--dataset-path`` and ``--db-path`` are all supplied they are used
+    as-is and nothing is downloaded; otherwise the files come from
+    ``WORK_DIR`` (``--skip-download``) or from a fresh download, and any
+    individually supplied override replaces the matching entry.
+
+    Raises:
+        FileNotFoundError: with ``--skip-download`` when an expected
+            input file does not exist.
+    """
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description="Build and publish local area H5 files"
+    )
+    parser.add_argument(
+        "--skip-download",
+        action="store_true",
+        help="Skip downloading inputs from HF (use existing files)",
+    )
+    parser.add_argument(
+        "--states-only",
+        action="store_true",
+        help="Only build and upload state files",
+    )
+    parser.add_argument(
+        "--districts-only",
+        action="store_true",
+        help="Only build and upload district files",
+    )
+    parser.add_argument(
+        "--cities-only",
+        action="store_true",
+        help="Only build and upload city files (e.g., NYC)",
+    )
+    parser.add_argument(
+        "--weights-path",
+        type=str,
+        help="Override path to weights file (for local testing)",
+    )
+    parser.add_argument(
+        "--dataset-path",
+        type=str,
+        help="Override path to dataset file (for local testing)",
+    )
+    parser.add_argument(
+        "--db-path",
+        type=str,
+        help="Override path to database file (for local testing)",
+    )
+    parser.add_argument(
+        "--calibration-blocks",
+        type=str,
+        help="Path to stacked_blocks.npy from calibration",
+    )
+    parser.add_argument(
+        "--stacked-takeup",
+        type=str,
+        help="Path to stacked_takeup.npz from calibration",
+    )
+    parser.add_argument(
+        "--upload",
+        action="store_true",
+        help="Upload to GCP and HuggingFace (default: build locally only)",
+    )
+    args = parser.parse_args()
+
+    WORK_DIR.mkdir(parents=True, exist_ok=True)
+
+    # Explicit path overrides, in the same key order as the inputs dict.
+    overrides = {
+        key: Path(value)
+        for key, value in (
+            ("weights", args.weights_path),
+            ("dataset", args.dataset_path),
+            ("database", args.db_path),
+        )
+        if value
+    }
+
+    if len(overrides) == 3:
+        inputs = overrides
+        print("Using provided paths:")
+        for key, path in inputs.items():
+            print(f" {key}: {path}")
+    elif args.skip_download:
+        inputs = {
+            "weights": WORK_DIR / "calibration_weights.npy",
+            "dataset": (
+                WORK_DIR / "source_imputed_stratified_extended_cps.h5"
+            ),
+            "database": WORK_DIR / "policy_data.db",
+        }
+        # Honour a partial set of overrides instead of silently dropping
+        # them; applied before the existence check so an overridden file
+        # does not also have to exist in WORK_DIR.
+        inputs.update(overrides)
+        print("Using existing files in work directory:")
+        for key, path in inputs.items():
+            if not path.exists():
+                raise FileNotFoundError(f"Expected file not found: {path}")
+            print(f" {key}: {path}")
+    else:
+        print("Downloading calibration inputs from Hugging Face...")
+        downloaded = download_calibration_inputs(str(WORK_DIR))
+        inputs = {key: Path(path) for key, path in downloaded.items()}
+        for key, path in overrides.items():
+            print(f"Overriding {key} with provided path: {path}")
+        inputs.update(overrides)
+
+    print(f"Using dataset: {inputs['dataset']}")
+
+    sim = Microsimulation(dataset=str(inputs["dataset"]))
+    n_hh = sim.calculate("household_id", map_to="household").shape[0]
+    print(f"\nBase dataset has {n_hh:,} households")
+
+    calibration_blocks = None
+    takeup_filter = None
+
+    if args.calibration_blocks:
+        calibration_blocks = np.load(args.calibration_blocks)
+        print(f"Loaded calibration blocks: {len(calibration_blocks):,}")
+        takeup_filter = [
+            info["takeup_var"] for info in TAKEUP_AFFECTED_TARGETS.values()
+        ]
+        print(f"Takeup filter: {takeup_filter}")
+
+    stacked_takeup = None
+    if args.stacked_takeup:
+        # dict() reads every array into memory, so the archive handle can
+        # be closed as soon as the copy is made.
+        with np.load(args.stacked_takeup) as npz:
+            stacked_takeup = dict(npz)
+        print(f"Loaded stacked takeup: {list(stacked_takeup.keys())}")
+
+    # Determine what to build: a *-only flag selects exactly that type
+    # (first match wins); with no flag everything is built.
+    if args.states_only:
+        do_states, do_districts, do_cities = True, False, False
+    elif args.districts_only:
+        do_states, do_districts, do_cities = False, True, False
+    elif args.cities_only:
+        do_states, do_districts, do_cities = False, False, True
+    else:
+        do_states = do_districts = do_cities = True
+
+    if do_states:
+        print("\n" + "=" * 60)
+        print("BUILDING STATE FILES")
+        print("=" * 60)
+        completed_states = load_completed_states()
+        print(f"Already completed: {len(completed_states)} states")
+        build_states(
+            inputs["weights"],
+            inputs["dataset"],
+            inputs["database"],
+            WORK_DIR,
+            completed_states,
+            calibration_blocks=calibration_blocks,
+            takeup_filter=takeup_filter,
+            upload=args.upload,
+            stacked_takeup=stacked_takeup,
+        )
+
+    if do_districts:
+        print("\n" + "=" * 60)
+        print("BUILDING DISTRICT FILES")
+        print("=" * 60)
+        completed_districts = load_completed_districts()
+        print(f"Already completed: {len(completed_districts)} districts")
+        build_districts(
+            inputs["weights"],
+            inputs["dataset"],
+            inputs["database"],
+            WORK_DIR,
+            completed_districts,
+            calibration_blocks=calibration_blocks,
+            takeup_filter=takeup_filter,
+            upload=args.upload,
+            stacked_takeup=stacked_takeup,
+        )
+
+    if do_cities:
+        print("\n" + "=" * 60)
+        print("BUILDING CITY FILES")
+        print("=" * 60)
+        completed_cities = load_completed_cities()
+        print(f"Already completed: {len(completed_cities)} cities")
+        build_cities(
+            inputs["weights"],
+            inputs["dataset"],
+            inputs["database"],
+            WORK_DIR,
+            completed_cities,
+            calibration_blocks=calibration_blocks,
+            takeup_filter=takeup_filter,
+            upload=args.upload,
+            stacked_takeup=stacked_takeup,
+        )
+
+    print("\n" + "=" * 60)
+    print("ALL DONE!")
+    print("=" * 60)
+
+
+# Run the command-line interface when this module is executed as a script.
+if __name__ == "__main__":
+ main()
diff --git a/policyengine_us_data/calibration/sanity_checks.py b/policyengine_us_data/calibration/sanity_checks.py
new file mode 100644
index 000000000..91beb3b1b
--- /dev/null
+++ b/policyengine_us_data/calibration/sanity_checks.py
@@ -0,0 +1,291 @@
+"""Structural integrity checks for calibrated H5 files.
+
+Run standalone:
+ python -m policyengine_us_data.calibration.sanity_checks path/to/file.h5
+
+Or integrated via validate_staging.py --sanity-only.
+"""
+
+import logging
+from typing import List
+
+import h5py
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+# Variables that run_sanity_checks scans for NaN and Inf values; a variable
+# that is absent from the file is skipped.
+KEY_MONETARY_VARS = [
+ "employment_income",
+ "adjusted_gross_income",
+ "snap",
+ "ssi",
+ "eitc",
+ "social_security",
+ "income_tax_before_credits",
+]
+
+# Takeup flags that run_sanity_checks requires to hold only boolean-like
+# values (True/False, 0/1); a variable that is absent from the file is
+# skipped.
+TAKEUP_VARS = [
+ "takes_up_snap_if_eligible",
+ "takes_up_ssi_if_eligible",
+ "takes_up_aca_ptc_if_eligible",
+ "takes_up_medicaid_if_eligible",
+ "takes_up_tanf_if_eligible",
+ "takes_up_head_start_if_eligible",
+ "takes_up_early_head_start_if_eligible",
+ "takes_up_dc_property_tax_credit_if_eligible",
+]
+
+
+def run_sanity_checks(
+    h5_path: str,
+    period: int = 2024,
+) -> List[dict]:
+    """Run structural integrity checks on an H5 file.
+
+    Checks performed, in order:
+
+    1. ``household_weight`` has no negative entries.
+    2. person / household / tax_unit / spm_unit IDs are unique.
+    3. ``KEY_MONETARY_VARS`` contain no NaN or Inf.
+    4. every ``person_household_id`` exists in ``household_id``.
+    5. ``TAKEUP_VARS`` hold only boolean-like values.
+    6. weighted employment income and SNAP per household fall inside a
+       plausible range (reported as WARN, not FAIL, when outside).
+
+    Variables that are missing from the file are silently skipped, except
+    for the household weights, which produce a SKIP row.
+
+    Args:
+        h5_path: Path to the H5 dataset file.
+        period: Tax year (used for variable keys).
+
+    Returns:
+        List of {check, status, detail} dicts, where status is one of
+        PASS, FAIL, WARN or SKIP.
+    """
+    results = []
+
+    def _record(check, status, detail=""):
+        """Append one result row in the {check, status, detail} shape."""
+        results.append({"check": check, "status": status, "detail": detail})
+
+    def _get(f, path):
+        """Resolve a slash path like 'var/2024' in the H5."""
+        try:
+            obj = f[path]
+            if isinstance(obj, h5py.Dataset):
+                return obj[:]
+            return None
+        except KeyError:
+            return None
+
+    def _get_with_fallback(f, name):
+        """Read ``name/<period>``, falling back to a bare ``name`` dataset."""
+        arr = _get(f, f"{name}/{period}")
+        if arr is None:
+            arr = _get(f, name)
+        return arr
+
+    def _check_per_household(
+        f, var, check, low, high, expected, hh_weights, total_hh
+    ):
+        """Compare the weighted total of ``var`` per household to a range.
+
+        The variable may be stored at a different entity level than
+        ``household_weight`` (e.g. one value per person), in which case
+        the two arrays have different lengths and cannot be multiplied.
+        The weights whose length matches the variable are used instead;
+        when none match, the check is reported as SKIP rather than
+        raising and aborting every other check.
+        """
+        values = _get(f, f"{var}/{period}")
+        if values is None:
+            return
+
+        matched = None
+        if values.shape == hh_weights.shape:
+            matched = hh_weights
+        else:
+            for entity in (
+                "person",
+                "tax_unit",
+                "spm_unit",
+                "family",
+                "marital_unit",
+            ):
+                candidate = _get(f, f"{entity}_weight/{period}")
+                if candidate is not None and candidate.shape == values.shape:
+                    matched = candidate
+                    break
+
+        if matched is None:
+            _record(
+                check,
+                "SKIP",
+                f"no weight array matches {var} shape {values.shape}",
+            )
+            return
+
+        per_hh = (values * matched).sum() / total_hh
+        if per_hh < low or per_hh > high:
+            _record(check, "WARN", f"${per_hh:,.0f}/hh (expected {expected})")
+        else:
+            _record(check, "PASS", f"${per_hh:,.0f}/hh")
+
+    with h5py.File(h5_path, "r") as f:
+        # 1. Weight non-negativity
+        w_key = f"household_weight/{period}"
+        weights = _get(f, w_key)
+        if weights is None:
+            _record("weight_non_negativity", "SKIP", f"key {w_key} not found")
+        else:
+            n_neg = int((weights < 0).sum())
+            if n_neg > 0:
+                _record(
+                    "weight_non_negativity",
+                    "FAIL",
+                    f"{n_neg} negative weights",
+                )
+            else:
+                _record("weight_non_negativity", "PASS")
+
+        # 2. Entity ID uniqueness
+        for entity in [
+            "person",
+            "household",
+            "tax_unit",
+            "spm_unit",
+        ]:
+            ids = _get_with_fallback(f, f"{entity}_id")
+            if ids is None:
+                continue
+            n_dup = len(ids) - len(np.unique(ids))
+            if n_dup > 0:
+                _record(
+                    f"{entity}_id_uniqueness",
+                    "FAIL",
+                    f"{n_dup} duplicate IDs",
+                )
+            else:
+                _record(f"{entity}_id_uniqueness", "PASS")
+
+        # 3. No NaN/Inf in key monetary variables
+        for var in KEY_MONETARY_VARS:
+            vals = _get(f, f"{var}/{period}")
+            if vals is None:
+                continue
+            n_nan = int(np.isnan(vals).sum())
+            n_inf = int(np.isinf(vals).sum())
+            if n_nan > 0 or n_inf > 0:
+                _record(
+                    f"no_nan_inf_{var}",
+                    "FAIL",
+                    f"{n_nan} NaN, {n_inf} Inf",
+                )
+            else:
+                _record(f"no_nan_inf_{var}", "PASS")
+
+        # 4. Person-to-household mapping
+        person_hh_arr = _get_with_fallback(f, "person_household_id")
+        hh_id_arr = _get_with_fallback(f, "household_id")
+
+        if person_hh_arr is not None and hh_id_arr is not None:
+            # Count the person rows themselves: the message reports
+            # persons, not the number of distinct missing household IDs.
+            n_orphans = int((~np.isin(person_hh_arr, hh_id_arr)).sum())
+            if n_orphans > 0:
+                _record(
+                    "person_household_mapping",
+                    "FAIL",
+                    f"{n_orphans} persons map to non-existent households",
+                )
+            else:
+                _record("person_household_mapping", "PASS")
+
+        # 5. Boolean takeup variables
+        for var in TAKEUP_VARS:
+            vals = _get(f, f"{var}/{period}")
+            if vals is None:
+                continue
+            unique = set(np.unique(vals).tolist())
+            # 0/1 and 0.0/1.0 hash equal to False/True, so float and
+            # integer encodings of the flag are accepted as well.
+            valid = {True, False, 0, 1, 0.0, 1.0}
+            bad = unique - valid
+            if bad:
+                _record(
+                    f"boolean_takeup_{var}",
+                    "FAIL",
+                    f"unexpected values: {bad}",
+                )
+            else:
+                _record(f"boolean_takeup_{var}", "PASS")
+
+        # 6. Reasonable per-capita ranges
+        if weights is not None:
+            total_hh = weights.sum()
+            if total_hh > 0:
+                _check_per_household(
+                    f,
+                    "employment_income",
+                    "per_hh_employment_income",
+                    10_000,
+                    200_000,
+                    "$10K-$200K",
+                    weights,
+                    total_hh,
+                )
+                _check_per_household(
+                    f,
+                    "snap",
+                    "per_hh_snap",
+                    0,
+                    10_000,
+                    "$0-$10K",
+                    weights,
+                    total_hh,
+                )
+
+    return results
+
+
+def main():
+    """CLI wrapper: run the sanity checks on one H5 file and report them.
+
+    Prints one line per check followed by a summary line, and exits with
+    status 1 when at least one check has status FAIL.
+    """
+    import argparse
+
+    cli = argparse.ArgumentParser(
+        description="Run structural sanity checks on an H5 file"
+    )
+    cli.add_argument("h5_path", help="Path to the H5 file")
+    cli.add_argument(
+        "--period",
+        type=int,
+        default=2024,
+        help="Tax year (default: 2024)",
+    )
+    opts = cli.parse_args()
+
+    outcomes = run_sanity_checks(opts.h5_path, opts.period)
+
+    failures = 0
+    warnings = 0
+    for row in outcomes:
+        status = row["status"]
+        if status == "FAIL":
+            failures += 1
+        elif status == "WARN":
+            warnings += 1
+
+        suffix = ""
+        if row["detail"]:
+            suffix = f" — {row['detail']}"
+        print(f" [{status}] {row['check']}{suffix}")
+
+    print(f"\n{len(outcomes)} checks: {failures} failures, {warnings} warnings")
+    if failures > 0:
+        raise SystemExit(1)
+
+
+# Run the command-line interface when this module is executed as a script.
+if __name__ == "__main__":
+ main()
diff --git a/policyengine_us_data/calibration/stacked_dataset_builder.py b/policyengine_us_data/calibration/stacked_dataset_builder.py
new file mode 100644
index 000000000..ef14b1c5a
--- /dev/null
+++ b/policyengine_us_data/calibration/stacked_dataset_builder.py
@@ -0,0 +1,256 @@
+"""
+CLI for creating CD-stacked datasets via build_h5.
+
+All H5 building logic lives in build_h5() in publish_local_area.py.
+This module provides a CLI for common build modes (national, states,
+cds, single-cd, single-state, nyc).
+"""
+
+import os
+import numpy as np
+from pathlib import Path
+
+from policyengine_us_data.calibration.calibration_utils import (
+ get_all_cds_from_database,
+ STATE_CODES,
+)
+
+if __name__ == "__main__":
+    import argparse
+
+    from policyengine_us import Microsimulation
+    from policyengine_us_data.calibration.publish_local_area import (
+        build_h5,
+        NYC_COUNTIES,
+        NYC_CDS,
+        TAKEUP_AFFECTED_TARGETS,
+    )
+
+    # ---- Command line -------------------------------------------------
+    parser = argparse.ArgumentParser(
+        description="Create sparse CD-stacked datasets"
+    )
+    for flag, text in (
+        ("--weights-path", "Path to w_cd.npy file"),
+        ("--dataset-path", "Path to stratified dataset .h5 file"),
+        ("--db-path", "Path to policy_data.db"),
+    ):
+        parser.add_argument(flag, required=True, help=text)
+    parser.add_argument(
+        "--output-dir",
+        default="./temp",
+        help="Output directory for files",
+    )
+    parser.add_argument(
+        "--mode",
+        choices=[
+            "national",
+            "states",
+            "cds",
+            "single-cd",
+            "single-state",
+            "nyc",
+        ],
+        default="national",
+        help="Output mode",
+    )
+    parser.add_argument(
+        "--cd",
+        type=str,
+        help="Single CD GEOID (--mode single-cd)",
+    )
+    parser.add_argument(
+        "--state",
+        type=str,
+        help="State code e.g. RI, CA (--mode single-state)",
+    )
+    parser.add_argument(
+        "--calibration-blocks",
+        default=None,
+        help="Path to stacked_blocks.npy",
+    )
+    parser.add_argument(
+        "--stacked-takeup",
+        default=None,
+        help="Path to stacked_takeup.npz from calibration",
+    )
+
+    args = parser.parse_args()
+    dataset_path_str = args.dataset_path
+    weights_path_str = args.weights_path
+    db_path = Path(args.db_path).resolve()
+    output_dir = args.output_dir
+    mode = args.mode
+
+    os.makedirs(output_dir, exist_ok=True)
+
+    # ---- Inputs -------------------------------------------------------
+    w = np.load(weights_path_str)
+    db_uri = f"sqlite:///{db_path}"
+
+    cds_to_calibrate = get_all_cds_from_database(db_uri)
+    print(f"Found {len(cds_to_calibrate)} congressional districts")
+
+    # The weight vector must hold one entry per (district, household).
+    assert_sim = Microsimulation(dataset=dataset_path_str)
+    n_hh = assert_sim.calculate("household_id", map_to="household").shape[0]
+    expected_length = len(cds_to_calibrate) * n_hh
+
+    if len(w) != expected_length:
+        raise ValueError(
+            f"Weight vector length ({len(w):,}) doesn't match "
+            f"expected ({expected_length:,})"
+        )
+
+    cal_blocks = None
+    takeup_filter = None
+    if args.calibration_blocks:
+        cal_blocks = np.load(args.calibration_blocks)
+        print(f"Loaded calibration blocks: {len(cal_blocks):,}")
+        takeup_filter = [
+            info["takeup_var"] for info in TAKEUP_AFFECTED_TARGETS.values()
+        ]
+        print(f"Takeup filter: {takeup_filter}")
+
+    stacked_takeup = None
+    if args.stacked_takeup:
+        stacked_takeup = dict(np.load(args.stacked_takeup))
+        print(f"Loaded stacked takeup: {list(stacked_takeup.keys())}")
+
+    def _write_h5(target, **selection):
+        """Call build_h5 with the shared inputs plus mode-specific filters.
+
+        ``selection`` carries only the keywords a mode sets itself
+        (``cd_subset``, ``county_filter``), so anything a mode leaves out
+        keeps build_h5's own default. A fresh copy of the weights is
+        handed over on every call.
+        """
+        build_h5(
+            weights=np.array(w),
+            blocks=cal_blocks,
+            dataset_path=Path(dataset_path_str),
+            output_path=Path(target),
+            cds_to_calibrate=cds_to_calibrate,
+            takeup_filter=takeup_filter,
+            stacked_takeup=stacked_takeup,
+            **selection,
+        )
+
+    # ---- Build modes --------------------------------------------------
+    if mode == "national":
+        output_path = f"{output_dir}/national.h5"
+        print(f"\nCreating national dataset: {output_path}")
+        _write_h5(output_path)
+
+    elif mode == "states":
+        for state_fips, state_code in STATE_CODES.items():
+            cd_subset = [
+                cd for cd in cds_to_calibrate if int(cd) // 100 == state_fips
+            ]
+            if len(cd_subset) == 0:
+                continue
+            output_path = f"{output_dir}/{state_code}.h5"
+            print(f"\nCreating {state_code}: {output_path}")
+            _write_h5(output_path, cd_subset=cd_subset)
+
+    elif mode == "cds":
+        total_cds = len(cds_to_calibrate)
+        for position, cd_geoid in enumerate(cds_to_calibrate, start=1):
+            state_fips, district_num = divmod(int(cd_geoid), 100)
+            # District numbers 0 and 98 are written out as district 1.
+            if district_num in (0, 98):
+                district_num = 1
+            state_code = STATE_CODES.get(state_fips, str(state_fips))
+            friendly_name = f"{state_code}-{district_num:02d}"
+
+            output_path = f"{output_dir}/{friendly_name}.h5"
+            print(f"\n[{position}/{total_cds}] Creating {friendly_name}.h5")
+            _write_h5(output_path, cd_subset=[cd_geoid])
+
+    elif mode == "single-cd":
+        if not args.cd:
+            raise ValueError("--cd required with --mode single-cd")
+        if args.cd not in cds_to_calibrate:
+            raise ValueError(f"CD {args.cd} not in calibrated CDs list")
+        output_path = f"{output_dir}/{args.cd}.h5"
+        print(f"\nCreating single CD dataset: {output_path}")
+        _write_h5(output_path, cd_subset=[args.cd])
+
+    elif mode == "single-state":
+        if not args.state:
+            raise ValueError("--state required with --mode single-state")
+        state_code_upper = args.state.upper()
+        # Reverse lookup: first FIPS code whose abbreviation matches.
+        state_fips = next(
+            (
+                fips
+                for fips, code in STATE_CODES.items()
+                if code == state_code_upper
+            ),
+            None,
+        )
+        if state_fips is None:
+            raise ValueError(f"Unknown state code: {args.state}")
+
+        cd_subset = [
+            cd for cd in cds_to_calibrate if int(cd) // 100 == state_fips
+        ]
+        if len(cd_subset) == 0:
+            raise ValueError(f"No CDs found for state {state_code_upper}")
+
+        output_path = f"{output_dir}/{state_code_upper}.h5"
+        print(
+            f"\nCreating {state_code_upper} with "
+            f"{len(cd_subset)} CDs: {output_path}"
+        )
+        _write_h5(output_path, cd_subset=cd_subset)
+
+    elif mode == "nyc":
+        cd_subset = [cd for cd in cds_to_calibrate if cd in NYC_CDS]
+        if len(cd_subset) == 0:
+            raise ValueError("No NYC CDs found")
+
+        output_path = f"{output_dir}/NYC.h5"
+        print(f"\nCreating NYC with {len(cd_subset)} CDs: {output_path}")
+        _write_h5(
+            output_path,
+            cd_subset=cd_subset,
+            county_filter=NYC_COUNTIES,
+        )
+
+    print("\nDone!")
diff --git a/policyengine_us_data/calibration/target_config.yaml b/policyengine_us_data/calibration/target_config.yaml
new file mode 100644
index 000000000..dc867ef73
--- /dev/null
+++ b/policyengine_us_data/calibration/target_config.yaml
@@ -0,0 +1,212 @@
+
+# Target inclusion config for unified calibration.
+# Only targets matching an entry below are kept. Each entry matches on
+# `variable` and, when given, also on `geo_level` and `domain_variable`.
+include:
+ # === DISTRICT — age demographics ===
+ - variable: person_count
+ geo_level: district
+ domain_variable: age
+
+ # === DISTRICT — count targets ===
+ - variable: person_count
+ geo_level: district
+ domain_variable: adjusted_gross_income
+ - variable: household_count
+ geo_level: district
+
+ # === DISTRICT — dollar targets (needed_w 7-41, compatible) ===
+ - variable: real_estate_taxes
+ geo_level: district
+ - variable: self_employment_income
+ geo_level: district
+ - variable: taxable_pension_income
+ geo_level: district
+ - variable: refundable_ctc
+ geo_level: district
+ - variable: unemployment_compensation
+ geo_level: district
+
+ # === DISTRICT — ACA PTC ===
+ - variable: aca_ptc
+ geo_level: district
+ - variable: tax_unit_count
+ geo_level: district
+ domain_variable: aca_ptc
+
+ # === STATE ===
+ - variable: person_count
+ geo_level: state
+ domain_variable: medicaid_enrolled
+ # TODO: re-enable once make data is re-run with is_pregnant preserved
+ # - variable: person_count
+ # geo_level: state
+ # domain_variable: is_pregnant
+ - variable: snap
+ geo_level: state
+
+ # === NATIONAL — aggregate dollar targets ===
+ - variable: adjusted_gross_income
+ geo_level: national
+ - variable: child_support_expense
+ geo_level: national
+ - variable: child_support_received
+ geo_level: national
+ - variable: eitc
+ geo_level: national
+ - variable: health_insurance_premiums_without_medicare_part_b
+ geo_level: national
+ - variable: medicaid
+ geo_level: national
+ - variable: medicare_part_b_premiums
+ geo_level: national
+ - variable: other_medical_expenses
+ geo_level: national
+ - variable: over_the_counter_health_expenses
+ geo_level: national
+ - variable: qualified_business_income_deduction
+ geo_level: national
+ - variable: rent
+ geo_level: national
+ - variable: salt_deduction
+ geo_level: national
+ - variable: snap
+ geo_level: national
+ - variable: social_security
+ geo_level: national
+ - variable: social_security_disability
+ geo_level: national
+ - variable: social_security_retirement
+ geo_level: national
+ - variable: spm_unit_capped_housing_subsidy
+ geo_level: national
+ - variable: spm_unit_capped_work_childcare_expenses
+ geo_level: national
+ - variable: ssi
+ geo_level: national
+ - variable: tanf
+ geo_level: national
+ - variable: tip_income
+ geo_level: national
+ - variable: unemployment_compensation
+ geo_level: national
+
+ # === NATIONAL — IRS SOI domain-constrained dollar targets ===
+ - variable: aca_ptc
+ geo_level: national
+ domain_variable: aca_ptc
+ - variable: dividend_income
+ geo_level: national
+ domain_variable: dividend_income
+ - variable: eitc
+ geo_level: national
+ domain_variable: eitc_child_count
+ - variable: income_tax_positive
+ geo_level: national
+ - variable: income_tax_before_credits
+ geo_level: national
+ domain_variable: income_tax_before_credits
+ - variable: net_capital_gains
+ geo_level: national
+ domain_variable: net_capital_gains
+ - variable: qualified_business_income_deduction
+ geo_level: national
+ domain_variable: qualified_business_income_deduction
+ - variable: qualified_dividend_income
+ geo_level: national
+ domain_variable: qualified_dividend_income
+ - variable: refundable_ctc
+ geo_level: national
+ domain_variable: refundable_ctc
+ - variable: rental_income
+ geo_level: national
+ domain_variable: rental_income
+ - variable: salt
+ geo_level: national
+ domain_variable: salt
+ - variable: self_employment_income
+ geo_level: national
+ domain_variable: self_employment_income
+ - variable: tax_exempt_interest_income
+ geo_level: national
+ domain_variable: tax_exempt_interest_income
+ - variable: tax_unit_partnership_s_corp_income
+ geo_level: national
+ domain_variable: tax_unit_partnership_s_corp_income
+ - variable: taxable_interest_income
+ geo_level: national
+ domain_variable: taxable_interest_income
+ - variable: taxable_ira_distributions
+ geo_level: national
+ domain_variable: taxable_ira_distributions
+ - variable: taxable_pension_income
+ geo_level: national
+ domain_variable: taxable_pension_income
+ - variable: taxable_social_security
+ geo_level: national
+ domain_variable: taxable_social_security
+ - variable: unemployment_compensation
+ geo_level: national
+ domain_variable: unemployment_compensation
+
+ # === NATIONAL — IRS SOI filer count targets ===
+ - variable: tax_unit_count
+ geo_level: national
+ domain_variable: aca_ptc
+ - variable: tax_unit_count
+ geo_level: national
+ domain_variable: dividend_income
+ - variable: tax_unit_count
+ geo_level: national
+ domain_variable: eitc_child_count
+ - variable: tax_unit_count
+ geo_level: national
+ domain_variable: income_tax
+ - variable: tax_unit_count
+ geo_level: national
+ domain_variable: income_tax_before_credits
+ - variable: tax_unit_count
+ geo_level: national
+ domain_variable: medical_expense_deduction
+ - variable: tax_unit_count
+ geo_level: national
+ domain_variable: net_capital_gains
+ - variable: tax_unit_count
+ geo_level: national
+ domain_variable: qualified_business_income_deduction
+ - variable: tax_unit_count
+ geo_level: national
+ domain_variable: qualified_dividend_income
+ - variable: tax_unit_count
+ geo_level: national
+ domain_variable: real_estate_taxes
+ - variable: tax_unit_count
+ geo_level: national
+ domain_variable: refundable_ctc
+ - variable: tax_unit_count
+ geo_level: national
+ domain_variable: rental_income
+ - variable: tax_unit_count
+ geo_level: national
+ domain_variable: salt
+ - variable: tax_unit_count
+ geo_level: national
+ domain_variable: self_employment_income
+ - variable: tax_unit_count
+ geo_level: national
+ domain_variable: tax_exempt_interest_income
+ - variable: tax_unit_count
+ geo_level: national
+ domain_variable: tax_unit_partnership_s_corp_income
+ - variable: tax_unit_count
+ geo_level: national
+ domain_variable: taxable_interest_income
+ - variable: tax_unit_count
+ geo_level: national
+ domain_variable: taxable_ira_distributions
+ - variable: tax_unit_count
+ geo_level: national
+ domain_variable: taxable_pension_income
+ - variable: tax_unit_count
+ geo_level: national
+ domain_variable: taxable_social_security
+ - variable: tax_unit_count
+ geo_level: national
+ domain_variable: unemployment_compensation
diff --git a/policyengine_us_data/calibration/target_config_full.yaml b/policyengine_us_data/calibration/target_config_full.yaml
new file mode 100644
index 000000000..1e1e287dd
--- /dev/null
+++ b/policyengine_us_data/calibration/target_config_full.yaml
@@ -0,0 +1,51 @@
+# Target exclusion config for unified calibration.
+# Each entry excludes targets matching (variable, geo_level).
+# Derived from junkyard's 22 excluded target groups.
+
+exclude:
+ # National exclusions
+ - variable: alimony_expense
+ geo_level: national
+ - variable: alimony_income
+ geo_level: national
+ - variable: charitable_deduction
+ geo_level: national
+ - variable: child_support_expense
+ geo_level: national
+ - variable: child_support_received
+ geo_level: national
+ - variable: interest_deduction
+ geo_level: national
+ - variable: medical_expense_deduction
+ geo_level: national
+ - variable: net_worth
+ geo_level: national
+ - variable: person_count
+ geo_level: national
+ - variable: real_estate_taxes
+ geo_level: national
+ - variable: rent
+ geo_level: national
+ - variable: social_security_dependents
+ geo_level: national
+ - variable: social_security_survivors
+ geo_level: national
+ # District exclusions
+ - variable: aca_ptc
+ geo_level: district
+ - variable: eitc
+ geo_level: district
+ - variable: income_tax_before_credits
+ geo_level: district
+ - variable: medical_expense_deduction
+ geo_level: district
+ - variable: net_capital_gains
+ geo_level: district
+ - variable: rental_income
+ geo_level: district
+ - variable: tax_unit_count
+ geo_level: district
+ - variable: tax_unit_partnership_s_corp_income
+ geo_level: district
+ - variable: taxable_social_security
+ geo_level: district
diff --git a/policyengine_us_data/calibration/unified_calibration.py b/policyengine_us_data/calibration/unified_calibration.py
index 1fb7a6b34..506bd6e14 100644
--- a/policyengine_us_data/calibration/unified_calibration.py
+++ b/policyengine_us_data/calibration/unified_calibration.py
@@ -5,11 +5,11 @@
1. Load CPS dataset -> get n_records
2. Clone Nx, assign random geography (census block)
3. (Optional) Source impute ACS/SIPP/SCF vars with state
- 4. (Optional) PUF clone (2x) + QRF impute with state
- 5. Re-randomize simple takeup variables per block
- 6. Build sparse calibration matrix (clone-by-clone)
- 7. L0-regularized optimization -> calibrated weights
- 8. Save weights, diagnostics, run config
+ 4. Build sparse calibration matrix (clone-by-clone)
+ 5. L0-regularized optimization -> calibrated weights
+ 6. Save weights, diagnostics, run config
+
+Note: PUF cloning happens upstream in `extended_cps.py`, not here.
Two presets control output size via L0 regularization:
- local: L0=1e-8, ~3-4M records (for local area dataset)
@@ -22,12 +22,13 @@
--output path/to/weights.npy \\
--preset local \\
--epochs 100 \\
- --puf-dataset path/to/puf_2024.h5
+ --skip-source-impute
"""
import argparse
import builtins
import logging
+import os
import sys
from pathlib import Path
from typing import Optional
@@ -55,133 +56,119 @@
LAMBDA_L2 = 1e-12
LEARNING_RATE = 0.15
DEFAULT_EPOCHS = 100
-DEFAULT_N_CLONES = 10
-
-SIMPLE_TAKEUP_VARS = [
- {
- "variable": "takes_up_snap_if_eligible",
- "entity": "spm_unit",
- "rate_key": "snap",
- },
- {
- "variable": "takes_up_aca_if_eligible",
- "entity": "tax_unit",
- "rate_key": "aca",
- },
- {
- "variable": "takes_up_dc_ptc",
- "entity": "tax_unit",
- "rate_key": "dc_ptc",
- },
- {
- "variable": "takes_up_head_start_if_eligible",
- "entity": "person",
- "rate_key": "head_start",
- },
- {
- "variable": "takes_up_early_head_start_if_eligible",
- "entity": "person",
- "rate_key": "early_head_start",
- },
- {
- "variable": "takes_up_ssi_if_eligible",
- "entity": "person",
- "rate_key": "ssi",
- },
- {
- "variable": "would_file_taxes_voluntarily",
- "entity": "tax_unit",
- "rate_key": "voluntary_filing",
- },
- {
- "variable": "takes_up_medicaid_if_eligible",
- "entity": "person",
- "rate_key": "medicaid",
- },
- {
- "variable": "takes_up_tanf_if_eligible",
- "entity": "spm_unit",
- "rate_key": "tanf",
- },
-]
-
-
-def rerandomize_takeup(
- sim,
- clone_block_geoids: np.ndarray,
- clone_state_fips: np.ndarray,
- time_period: int,
-) -> None:
- """Re-randomize simple takeup variables per census block.
-
- Groups entities by their household's block GEOID and draws
- new takeup booleans using seeded_rng(var_name, salt=block).
- Overrides the simulation's stored inputs.
-
- Args:
- sim: Microsimulation instance (already has state_fips).
- clone_block_geoids: Block GEOIDs per household.
- clone_state_fips: State FIPS per household.
- time_period: Tax year.
- """
- from policyengine_us_data.parameters import (
- load_take_up_rate,
- )
- from policyengine_us_data.utils.randomness import (
- seeded_rng,
- )
+DEFAULT_N_CLONES = 430
- hh_ids = sim.calculate("household_id", map_to="household").values
- hh_to_block = dict(zip(hh_ids, clone_block_geoids))
- hh_to_state = dict(zip(hh_ids, clone_state_fips))
- for spec in SIMPLE_TAKEUP_VARS:
- var_name = spec["variable"]
- entity_level = spec["entity"]
- rate_key = spec["rate_key"]
+def get_git_provenance() -> dict:
+ """Capture git state and package version for provenance tracking.
+
+ Best-effort: any exception raised while querying git or importing
+ the package version is swallowed, and keys that were not filled
+ in before the failure keep their initial value of None.
+
+ Returns:
+ Dict with keys ``git_commit``, ``git_branch``, ``git_dirty``
+ and ``package_version``.
+ """
+ import subprocess as _sp
- rate_or_dict = load_take_up_rate(rate_key, time_period)
-
- is_state_specific = isinstance(rate_or_dict, dict)
+ info = {
+ "git_commit": None,
+ "git_branch": None,
+ "git_dirty": None,
+ "package_version": None,
+ }
+ # All three git queries share one try block, so the first failure
+ # (e.g. git missing or not a repository) skips the remaining ones.
+ try:
+ info["git_commit"] = (
+ _sp.check_output(
+ ["git", "rev-parse", "HEAD"],
+ stderr=_sp.DEVNULL,
+ )
+ .decode()
+ .strip()
+ )
+ info["git_branch"] = (
+ _sp.check_output(
+ ["git", "rev-parse", "--abbrev-ref", "HEAD"],
+ stderr=_sp.DEVNULL,
+ )
+ .decode()
+ .strip()
+ )
+ porcelain = (
+ _sp.check_output(
+ ["git", "status", "--porcelain"],
+ stderr=_sp.DEVNULL,
+ )
+ .decode()
+ .strip()
+ )
+ # Any porcelain output means the working tree has changes.
+ info["git_dirty"] = len(porcelain) > 0
+ except Exception:
+ pass
+ # The version lookup is independent of git and has its own guard.
+ try:
+ from policyengine_us_data.__version__ import __version__
- entity_ids = sim.calculate(
- f"{entity_level}_id", map_to=entity_level
- ).values
- entity_hh_ids = sim.calculate(
- "household_id", map_to=entity_level
- ).values
- n_entities = len(entity_ids)
-
- draws = np.zeros(n_entities, dtype=np.float64)
- rates = np.zeros(n_entities, dtype=np.float64)
-
- entity_blocks = np.array(
- [hh_to_block.get(hid, "0") for hid in entity_hh_ids]
- )
-
- unique_blocks = np.unique(entity_blocks)
- for block in unique_blocks:
- mask = entity_blocks == block
- n_in_block = mask.sum()
- rng = seeded_rng(var_name, salt=str(block))
- draws[mask] = rng.random(n_in_block)
-
- if is_state_specific:
- block_hh_ids = entity_hh_ids[mask]
- for i, hid in enumerate(block_hh_ids):
- state = int(hh_to_state.get(hid, 0))
- state_str = str(state)
- r = rate_or_dict.get(
- state_str,
- rate_or_dict.get(state, 0.8),
- )
- idx = np.where(mask)[0][i]
- rates[idx] = r
- else:
- rates[mask] = rate_or_dict
+ info["package_version"] = __version__
+ except Exception:
+ pass
+ return info
+
+
+def print_package_provenance(metadata: dict) -> None:
+    """Print a provenance banner from package metadata.
+
+    A field is shown as "unknown" both when its key is missing and
+    when the stored value is None or empty. The latter matters because
+    get_git_provenance() leaves keys set to None when git or the
+    version lookup fails.
+
+    Args:
+        metadata: Package metadata dict. Keys read: created_at,
+            git_branch, git_commit, git_dirty, package_version,
+            dataset_sha256, db_sha256.
+    """
+    # `dict.get(key, default)` only covers a missing key, so use `or`
+    # to also replace an explicit None.
+    built = metadata.get("created_at") or "unknown"
+    branch = metadata.get("git_branch") or "unknown"
+    commit = metadata.get("git_commit")
+    commit_short = commit[:8] if commit else "unknown"
+    dirty = " (DIRTY)" if metadata.get("git_dirty") else ""
+    version = metadata.get("package_version") or "unknown"
+    ds_sha = metadata.get("dataset_sha256") or ""
+    db_sha = metadata.get("db_sha256") or ""
+    ds_label = ds_sha[:12] if ds_sha else "unknown"
+    db_label = db_sha[:12] if db_sha else "unknown"
+    print("--- Package Provenance ---")
+    print(f" Built: {built}")
+    print(f" Branch: {branch} @ {commit_short}{dirty}")
+    print(f" Version: {version}")
+    print(f" Dataset SHA: {ds_label} DB SHA: {db_label}")
+    print("--------------------------")
+
+
+def check_package_staleness(metadata: dict) -> None:
+    """Warn if package is stale, dirty, or from a different branch.
+
+    Prints a warning (never raises) when:
+      * the package's ``created_at`` timestamp is more than 7 days old,
+      * ``created_at`` cannot be parsed as an ISO-8601 timestamp,
+      * the package was built from a dirty working tree,
+      * the package branch differs from the current branch, or
+      * the branch matches but the commit differs.
+
+    Args:
+        metadata: Package metadata dict. Keys read: created_at,
+            git_dirty, git_branch, git_commit.
+    """
+    import datetime
+
+    created = metadata.get("created_at")
+    if created:
+        try:
+            built_dt = datetime.datetime.fromisoformat(created)
+            # Subtracting an aware datetime from a naive one raises
+            # TypeError, so take "now" in the timestamp's own zone
+            # (tzinfo is None for a naive one, giving naive local time).
+            age = datetime.datetime.now(built_dt.tzinfo) - built_dt
+            if age.days > 7:
+                print(
+                    f"WARNING: Package is {age.days} days old "
+                    f"(built {created})"
+                )
+        except (TypeError, ValueError):
+            # Still best-effort, but say so instead of skipping silently.
+            print(
+                f"WARNING: Could not parse package build time {created!r}"
+            )
+
+    if metadata.get("git_dirty"):
+        print("WARNING: Package was built from a dirty working tree")
+
+    current = get_git_provenance()
+    pkg_branch = metadata.get("git_branch")
+    cur_branch = current.get("git_branch")
+    if pkg_branch and cur_branch and pkg_branch != cur_branch:
+        print(
+            f"WARNING: Package built on branch "
+            f"'{pkg_branch}', current branch is "
+            f"'{cur_branch}'"
+        )
- new_values = draws < rates
- sim.set_input(var_name, time_period, new_values)
+    pkg_commit = metadata.get("git_commit")
+    cur_commit = current.get("git_commit")
+    # Only report a commit mismatch on the same branch; a different
+    # branch has already been reported above.
+    if (
+        pkg_commit
+        and cur_commit
+        and pkg_commit != cur_commit
+        and pkg_branch == cur_branch
+    ):
+        print(
+            f"WARNING: Package commit {pkg_commit[:8]} "
+            f"differs from current {cur_commit[:8]} "
+            f"on same branch '{cur_branch}'"
+        )
def parse_args(argv=None):
@@ -257,23 +244,297 @@ def parse_args(argv=None):
help="Skip takeup re-randomization",
)
parser.add_argument(
- "--puf-dataset",
+ "--skip-source-impute",
+ action="store_true",
+ default=True,
+ help="(default) Skip ACS/SIPP/SCF re-imputation with state",
+ )
+ parser.add_argument(
+ "--no-skip-source-impute",
+ dest="skip_source_impute",
+ action="store_false",
+ help="Run ACS/SIPP/SCF source imputation inline",
+ )
+ parser.add_argument(
+ "--target-config",
default=None,
- help="Path to PUF h5 file for QRF training",
+ help="Path to target exclusion YAML config",
)
parser.add_argument(
- "--skip-puf",
+ "--county-level",
action="store_true",
- help="Skip PUF clone + QRF imputation",
+ help="Iterate per-county (slow, ~3143 counties). "
+ "Default is state-only (~51 states), which is much "
+ "faster for county-invariant target variables.",
)
parser.add_argument(
- "--skip-source-impute",
+ "--build-only",
action="store_true",
- help="Skip ACS/SIPP/SCF re-imputation with state",
+ help="Build matrix + save package, skip fitting",
+ )
+ parser.add_argument(
+ "--package-path",
+ default=None,
+ help="Load pre-built calibration package (skip matrix build)",
+ )
+ parser.add_argument(
+ "--package-output",
+ default=None,
+ help="Where to save calibration package",
+ )
+ parser.add_argument(
+ "--beta",
+ type=float,
+ default=BETA,
+ help=f"L0 gate temperature (default: {BETA})",
+ )
+ parser.add_argument(
+ "--lambda-l2",
+ type=float,
+ default=LAMBDA_L2,
+ help=f"L2 regularization (default: {LAMBDA_L2})",
+ )
+ parser.add_argument(
+ "--learning-rate",
+ type=float,
+ default=LEARNING_RATE,
+ help=f"Learning rate (default: {LEARNING_RATE})",
+ )
+ parser.add_argument(
+ "--log-freq",
+ type=int,
+ default=None,
+ help="Epochs between per-target CSV log entries. "
+ "Omit to disable epoch logging.",
+ )
+ parser.add_argument(
+ "--workers",
+ type=int,
+ default=1,
+ help="Number of parallel workers for state/county "
+ "precomputation (default: 1, sequential).",
)
return parser.parse_args(argv)
+def load_target_config(path: str) -> dict:
+    """Load a target include/exclude config from YAML.
+
+    Args:
+        path: Path to YAML config file.
+
+    Returns:
+        Parsed config dict. It always has an 'exclude' list. An
+        'include' key is kept when the file defines one, with a bare
+        ``include:`` normalized to an empty list.
+
+    Raises:
+        TypeError: If the YAML document is not a mapping.
+    """
+    import yaml
+
+    # YAML files are UTF-8 and the shipped configs contain non-ASCII
+    # characters, so do not depend on the platform's default encoding.
+    with open(path, encoding="utf-8") as f:
+        config = yaml.safe_load(f)
+    if config is None:
+        # Empty file.
+        config = {}
+    if not isinstance(config, dict):
+        raise TypeError(
+            f"Target config {path} must be a YAML mapping with "
+            f"'include' and/or 'exclude' keys, got "
+            f"{type(config).__name__}"
+        )
+    # A key written with no value parses to None; treat it as "no rules".
+    for key in ("include", "exclude"):
+        if key in config and config[key] is None:
+            config[key] = []
+    config.setdefault("exclude", [])
+    return config
+
+
+def _match_rules(targets_df, rules):
+    """Return a boolean array flagging rows matched by at least one rule.
+
+    A rule always constrains ``variable``; ``geo_level`` and
+    ``domain_variable`` constrain the match only when present in the
+    rule.
+    """
+    combined = np.zeros(len(targets_df), dtype=bool)
+    for rule in rules:
+        hit = targets_df["variable"] == rule["variable"]
+        # Optional constraints, applied in a fixed order.
+        for column in ("geo_level", "domain_variable"):
+            if column in rule:
+                hit = hit & (targets_df[column] == rule[column])
+        combined |= hit
+    return combined
+
+
+def apply_target_config(
+    targets_df: "pd.DataFrame",
+    X_sparse,
+    target_names: list,
+    config: dict,
+) -> tuple:
+    """Keep or drop targets according to an include/exclude config.
+
+    ``include`` rules select the targets to keep; ``exclude`` rules
+    remove targets. Each rule matches on ``variable`` and, optionally,
+    ``geo_level`` and ``domain_variable``. When both lists are given,
+    the include selection is made first and the exclude rules are then
+    removed from it. With neither list, the inputs are returned
+    unchanged.
+
+    Args:
+        targets_df: DataFrame with target rows.
+        X_sparse: Sparse matrix (targets x records).
+        target_names: List of target name strings.
+        config: Config dict with 'include' and/or 'exclude' list.
+
+    Returns:
+        (filtered_targets_df, filtered_X_sparse, filtered_names)
+    """
+    include_rules = config.get("include", [])
+    exclude_rules = config.get("exclude", [])
+
+    # Nothing configured: hand the inputs back untouched.
+    if not (include_rules or exclude_rules):
+        return targets_df, X_sparse, target_names
+
+    total = len(targets_df)
+
+    # Start from the include selection, or from everything if none given.
+    selected = (
+        _match_rules(targets_df, include_rules)
+        if include_rules
+        else np.ones(total, dtype=bool)
+    )
+    if exclude_rules:
+        selected &= ~_match_rules(targets_df, exclude_rules)
+
+    logger.info(
+        "Target config: kept %d / %d targets (dropped %d)",
+        selected.sum(),
+        total,
+        total - selected.sum(),
+    )
+
+    # Apply the same row selection to the frame, the matrix and the names.
+    kept_rows = np.where(selected)[0]
+    return (
+        targets_df.iloc[kept_rows].reset_index(drop=True),
+        X_sparse[kept_rows, :],
+        [target_names[i] for i in kept_rows],
+    )
+
+
+def save_calibration_package(
+    path: str,
+    X_sparse,
+    targets_df: "pd.DataFrame",
+    target_names: list,
+    metadata: dict,
+    initial_weights: np.ndarray = None,
+    cd_geoid: np.ndarray = None,
+    block_geoid: np.ndarray = None,
+) -> None:
+    """Pickle a calibration package to ``path``.
+
+    Missing parent directories of ``path`` are created first.
+
+    Args:
+        path: Output file path.
+        X_sparse: Sparse matrix.
+        targets_df: Targets DataFrame.
+        target_names: Target name list.
+        metadata: Run metadata dict.
+        initial_weights: Pre-computed initial weight array.
+        cd_geoid: CD GEOID array from geography assignment.
+        block_geoid: Block GEOID array from geography assignment.
+    """
+    import pickle
+
+    # Key order matches what load_calibration_package consumers expect.
+    payload = dict(
+        X_sparse=X_sparse,
+        targets_df=targets_df,
+        target_names=target_names,
+        metadata=metadata,
+        initial_weights=initial_weights,
+        cd_geoid=cd_geoid,
+        block_geoid=block_geoid,
+    )
+    destination = Path(path)
+    destination.parent.mkdir(parents=True, exist_ok=True)
+    with destination.open("wb") as out_file:
+        pickle.dump(payload, out_file, protocol=pickle.HIGHEST_PROTOCOL)
+    logger.info("Calibration package saved to %s", path)
+
+
+def load_calibration_package(path: str) -> dict:
+    """Read a pickled calibration package and report its provenance.
+
+    After loading, the matrix shape is logged, a provenance banner is
+    printed, and staleness warnings are emitted from the package's
+    ``metadata`` entry (an empty dict when that entry is absent).
+
+    Args:
+        path: Path to package file.
+
+    Returns:
+        The unpickled package dict. It must contain ``X_sparse``; the
+        other entries are whatever the writer stored.
+    """
+    import pickle
+
+    # NOTE(security): unpickling can execute arbitrary code. Only load
+    # packages that come from a trusted source.
+    with open(path, "rb") as in_file:
+        loaded = pickle.load(in_file)
+
+    matrix = loaded["X_sparse"]
+    logger.info(
+        "Loaded package: %d targets, %d records",
+        matrix.shape[0],
+        matrix.shape[1],
+    )
+
+    provenance = loaded.get("metadata", {})
+    print_package_provenance(provenance)
+    check_package_staleness(provenance)
+    return loaded
+
+
+def compute_initial_weights(
+    X_sparse,
+    targets_df: "pd.DataFrame",
+) -> np.ndarray:
+    """Compute population-based initial weights from age targets.
+
+    For each congressional district, sums person_count targets where
+    domain_variable == "age" to get district population, then divides
+    by the number of columns (households) with a stored entry in any
+    of that district's age-target rows. Columns not touched by any
+    district keep the default weight of 100. If a column is touched by
+    several districts, the district processed last wins.
+
+    Args:
+        X_sparse: Sparse matrix (targets x records). Row selections
+            must expose ``.indices`` as column indices, as CSR does.
+        targets_df: Targets DataFrame with columns: variable,
+            domain_variable, geo_level, geographic_id, value. Rows are
+            matched to X_sparse rows by position, so the DataFrame
+            index may be arbitrary. May be None, in which case uniform
+            weights are returned.
+
+    Returns:
+        Weight array of shape (n_records,).
+    """
+    n_total = X_sparse.shape[1]
+    initial_weights = np.full(n_total, 100.0)
+
+    # fit_l0_weights defaults targets_df to None; keep the historical
+    # uniform start in that case instead of failing on None["variable"].
+    if targets_df is None:
+        logger.warning(
+            "No targets_df provided; falling back to uniform weights=100"
+        )
+        return initial_weights
+
+    age_mask = (
+        (targets_df["variable"] == "person_count")
+        & (targets_df["domain_variable"] == "age")
+        & (targets_df["geo_level"] == "district")
+    ).to_numpy()
+
+    if not age_mask.any():
+        logger.warning(
+            "No person_count/age/district targets found; "
+            "falling back to uniform weights=100"
+        )
+        return initial_weights
+
+    # Carry each row's position in X_sparse explicitly, so the result
+    # does not depend on targets_df having a default 0..n-1 index.
+    age_rows = targets_df.loc[age_mask].assign(
+        _row_pos=np.flatnonzero(age_mask)
+    )
+
+    for _, group in age_rows.groupby("geographic_id"):
+        cd_pop = group["value"].sum()
+        # Union of the columns stored in this district's age rows.
+        cols = np.unique(X_sparse[group["_row_pos"].to_numpy()].indices)
+        if cols.size == 0:
+            continue
+        initial_weights[cols] = cd_pop / cols.size
+
+    n_unique = len(np.unique(initial_weights))
+    logger.info(
+        "Initial weights: min=%.1f, max=%.1f, mean=%.1f, " "%d unique values",
+        initial_weights.min(),
+        initial_weights.max(),
+        initial_weights.mean(),
+        n_unique,
+    )
+    return initial_weights
+
+
def fit_l0_weights(
X_sparse,
targets: np.ndarray,
@@ -281,6 +542,15 @@ def fit_l0_weights(
epochs: int = DEFAULT_EPOCHS,
device: str = "cpu",
verbose_freq: Optional[int] = None,
+ beta: float = BETA,
+ lambda_l2: float = LAMBDA_L2,
+ learning_rate: float = LEARNING_RATE,
+ log_freq: int = None,
+ log_path: str = None,
+ target_names: list = None,
+ initial_weights: np.ndarray = None,
+ targets_df: "pd.DataFrame" = None,
+ achievable: np.ndarray = None,
) -> np.ndarray:
"""Fit L0-regularized calibration weights.
@@ -291,6 +561,17 @@ def fit_l0_weights(
epochs: Training epochs.
device: Torch device.
verbose_freq: Print frequency. Defaults to 10%.
+ beta: L0 gate temperature.
+ lambda_l2: L2 regularization strength.
+ learning_rate: Optimizer learning rate.
+ log_freq: Epochs between per-target CSV logs.
+ None disables logging.
+ log_path: Path for the per-target calibration log CSV.
+ target_names: Human-readable target names for the log.
+ initial_weights: Pre-computed initial weights. If None,
+ computed from targets_df age targets.
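+ targets_df: Targets DataFrame, used to compute
+ initial_weights when not provided.
+ achievable: Optional per-target flags written to the
+ "achievable" column of the per-target CSV log. When
+ None, every target is logged as True.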
Returns:
Weight array of shape (n_records,).
@@ -304,21 +585,30 @@ def fit_l0_weights(
import torch
+ os.environ.setdefault(
+ "PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True"
+ )
+
n_total = X_sparse.shape[1]
- initial_weights = np.ones(n_total) * 100
+ if initial_weights is None:
+ initial_weights = compute_initial_weights(X_sparse, targets_df)
logger.info(
"L0 calibration: %d targets, %d features, "
- "lambda_l0=%.1e, epochs=%d",
+ "lambda_l0=%.1e, beta=%.2f, lambda_l2=%.1e, "
+ "lr=%.3f, epochs=%d",
X_sparse.shape[0],
n_total,
lambda_l0,
+ beta,
+ lambda_l2,
+ learning_rate,
epochs,
)
model = SparseCalibrationWeights(
n_features=n_total,
- beta=BETA,
+ beta=beta,
gamma=GAMMA,
zeta=ZETA,
init_keep_prob=INIT_KEEP_PROB,
@@ -339,22 +629,136 @@ def _flushed_print(*args, **kwargs):
builtins.print = _flushed_print
- t0 = time.time()
- try:
- model.fit(
- M=X_sparse,
- y=targets,
- target_groups=None,
- lambda_l0=lambda_l0,
- lambda_l2=LAMBDA_L2,
- lr=LEARNING_RATE,
- epochs=epochs,
- loss_type="relative",
- verbose=True,
- verbose_freq=verbose_freq,
+ enable_logging = (
+ log_freq is not None
+ and log_path is not None
+ and target_names is not None
+ )
+ if enable_logging:
+ Path(log_path).parent.mkdir(parents=True, exist_ok=True)
+ with open(log_path, "w") as f:
+ f.write(
+ "target_name,estimate,target,epoch,"
+ "error,rel_error,abs_error,"
+ "rel_abs_error,loss,achievable\n"
+ )
+ logger.info(
+ "Epoch logging enabled: freq=%d, path=%s",
+ log_freq,
+ log_path,
)
- finally:
- builtins.print = _builtin_print
+
+    t0 = time.time()
+    # One try/finally guards BOTH training paths so the patched
+    # builtins.print is restored when epoch logging is enabled and when
+    # model.fit raises part-way through.
+    try:
+        if enable_logging:
+            # Per-target "achievable" flags do not change between
+            # chunks; default to True for every target.
+            ach_flags = (
+                achievable if achievable is not None else [True] * len(targets)
+            )
+            # A log_freq <= 0 would give a zero-length chunk and never
+            # advance epochs_done; fall back to logging every epoch.
+            chunk_size = max(1, log_freq)
+            epochs_done = 0
+            while epochs_done < epochs:
+                # Train in chunks of at most chunk_size epochs and
+                # record a snapshot after each chunk.
+                chunk = min(chunk_size, epochs - epochs_done)
+                model.fit(
+                    M=X_sparse,
+                    y=targets,
+                    target_groups=None,
+                    lambda_l0=lambda_l0,
+                    lambda_l2=lambda_l2,
+                    lr=learning_rate,
+                    epochs=chunk,
+                    loss_type="relative",
+                    verbose=False,
+                )
+
+                epochs_done += chunk
+
+                with torch.no_grad():
+                    y_pred = model.predict(X_sparse).cpu().numpy()
+                    weights_snap = (
+                        model.get_weights(deterministic=True).cpu().numpy()
+                    )
+
+                active_w = weights_snap[weights_snap > 0]
+                nz = len(active_w)
+                sparsity = (1 - nz / n_total) * 100
+
+                # Targets equal to zero contribute a relative error of 0.
+                rel_errs = np.where(
+                    np.abs(targets) > 0,
+                    (y_pred - targets) / np.abs(targets),
+                    0.0,
+                )
+                mean_err = np.mean(np.abs(rel_errs))
+                max_err = np.max(np.abs(rel_errs))
+                total_loss = np.sum(rel_errs**2)
+
+                if nz > 0:
+                    # Share of active weights per magnitude bucket.
+                    w_tiny = (active_w < 0.01).sum()
+                    w_small = ((active_w >= 0.01) & (active_w < 0.1)).sum()
+                    w_med = ((active_w >= 0.1) & (active_w < 1.0)).sum()
+                    w_normal = ((active_w >= 1.0) & (active_w < 10.0)).sum()
+                    w_large = (
+                        (active_w >= 10.0) & (active_w < 1000.0)
+                    ).sum()
+                    w_huge = (active_w >= 1000.0).sum()
+                    weight_dist = (
+                        f"[<0.01: {100*w_tiny/nz:.1f}%, "
+                        f"0.01-0.1: {100*w_small/nz:.1f}%, "
+                        f"0.1-1: {100*w_med/nz:.1f}%, "
+                        f"1-10: {100*w_normal/nz:.1f}%, "
+                        f"10-1000: {100*w_large/nz:.1f}%, "
+                        f">1000: {100*w_huge/nz:.1f}%]"
+                    )
+                else:
+                    weight_dist = "[no active weights]"
+
+                print(
+                    f"Epoch {epochs_done:4d}: "
+                    f"mean_error={mean_err:.4%}, "
+                    f"max_error={max_err:.1%}, "
+                    f"total_loss={total_loss:.3f}, "
+                    f"active={nz}/{n_total} "
+                    f"({sparsity:.1f}% sparse)\n"
+                    f" Weight dist: {weight_dist}",
+                    flush=True,
+                )
+
+                # Append one CSV row per target for this snapshot.
+                with open(log_path, "a") as f:
+                    for i in range(len(targets)):
+                        est = y_pred[i]
+                        tgt = targets[i]
+                        err = est - tgt
+                        rel_err = err / tgt if tgt != 0 else 0
+                        abs_err = abs(err)
+                        rel_abs = abs(rel_err)
+                        loss = rel_err**2
+                        f.write(
+                            f'"{target_names[i]}",'
+                            f"{est},{tgt},{epochs_done},"
+                            f"{err},{rel_err},{abs_err},"
+                            f"{rel_abs},{loss},"
+                            f"{ach_flags[i]}\n"
+                        )
+
+                logger.info(
+                    "Logged %d targets at epoch %d",
+                    len(targets),
+                    epochs_done,
+                )
+
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+        else:
+            model.fit(
+                M=X_sparse,
+                y=targets,
+                target_groups=None,
+                lambda_l0=lambda_l0,
+                lambda_l2=lambda_l2,
+                lr=learning_rate,
+                epochs=epochs,
+                loss_type="relative",
+                verbose=True,
+                verbose_freq=verbose_freq,
+            )
+    finally:
+        builtins.print = _builtin_print
elapsed = time.time() - t0
logger.info(
@@ -376,115 +780,266 @@ def _flushed_print(*args, **kwargs):
return weights
-def compute_diagnostics(
+def convert_weights_to_stacked_format(
     weights: np.ndarray,
-    X_sparse,
-    targets_df,
-    target_names: list,
-) -> "pd.DataFrame":
-    import pandas as pd
+    cd_geoid: np.ndarray,
+    base_n_records: int,
+    cds_ordered: list,
+) -> np.ndarray:
+    """Fold per-column L0 weights into a flat (n_cds, n_records) grid.
-    estimates = X_sparse.dot(weights)
-    true_values = targets_df["value"].values
-    row_sums = np.array(X_sparse.sum(axis=1)).flatten()
+
+    Column ``i`` is treated as base record ``i % base_n_records``
+    and belongs to the congressional district ``cd_geoid[i]``.
+    Weights of all columns that land on the same (CD, record)
+    slot are summed, which yields the layout expected by
+    stacked_dataset_builder.
-    rel_errors = np.where(
-        np.abs(true_values) > 0,
-        (estimates - true_values) / np.abs(true_values),
-        0.0,
-    )
-    return pd.DataFrame(
-        {
-            "target": target_names,
-            "true_value": true_values,
-            "estimate": estimates,
-            "rel_error": rel_errors,
-            "abs_rel_error": np.abs(rel_errors),
-            "achievable": row_sums > 0,
-        }
+
+    Args:
+        weights: Weight per matrix column from the L0 fit.
+        cd_geoid: CD GEOID of each column; every value must
+            appear in ``cds_ordered`` (KeyError otherwise).
+        base_n_records: Number of base households, i.e. the
+            width of one row of the stacked grid.
+        cds_ordered: CD GEOIDs in the order that defines the
+            rows of the stacked grid.
+
+    Returns:
+        Float64 array of length ``len(cds_ordered) *
+        base_n_records``; reshape to (n_cds, base_n_records).
+
+    Raises:
+        AssertionError: If the stacked total differs from the
+            total of ``weights``.
+    """
+    n_columns = len(weights)
+    n_rows = len(cds_ordered)
+
+    row_of_cd = {geoid: row for row, geoid in enumerate(cds_ordered)}
+    record_cols = np.arange(n_columns) % base_n_records
+    cd_rows = np.array([row_of_cd[geoid] for geoid in cd_geoid])
+    slots = cd_rows * base_n_records + record_cols
+
+    W = np.zeros(n_rows * base_n_records, dtype=np.float64)
+    # np.add.at accumulates repeated slots instead of overwriting.
+    np.add.at(W, slots, weights)
+
+    stacked_total = W.sum()
+    raw_total = weights.sum()
+    assert np.isclose(
+        stacked_total, raw_total
+    ), f"Weight sum mismatch: {stacked_total} vs {raw_total}"
+    logger.info(
+        "Converted weights to stacked format: "
+        "(%d, %d) = %d elements, sum=%.1f",
+        n_rows,
+        base_n_records,
+        len(W),
+        stacked_total,
     )
+    return W
-def _build_puf_cloned_dataset(
- dataset_path: str,
- puf_dataset_path: str,
- state_fips: np.ndarray,
- time_period: int = 2024,
- skip_qrf: bool = False,
- skip_source_impute: bool = False,
-) -> str:
- """Build a PUF-cloned dataset from raw CPS.
+def convert_blocks_to_stacked_format(
+    block_geoid: np.ndarray,
+    cd_geoid: np.ndarray,
+    base_n_records: int,
+    cds_ordered: list,
+) -> np.ndarray:
+    """Convert column-ordered block GEOIDs to stacked format.
-    Loads the CPS, optionally runs source imputations
-    (ACS/SIPP/SCF), then PUF clone + QRF.
+
+    Parallel to convert_weights_to_stacked_format. For each
+    (CD, record) slot, stores the block GEOID from the first
+    clone assigned there; blocks of later clones that land on
+    the same slot are dropped. A warning reports how many
+    slots received a block different from the one kept.
+    Empty string for unfilled slots (records with no clone
+    in that CD).
+
     Args:
-        dataset_path: Path to raw CPS h5 file.
-        puf_dataset_path: Path to PUF h5 file.
-        state_fips: State FIPS per household (base records).
-        time_period: Tax year.
-        skip_qrf: Skip QRF imputation.
-        skip_source_impute: Skip ACS/SIPP/SCF imputations.
+        block_geoid: Block GEOID per column from geography
+            assignment. Length n_clones * base_n_records.
+        cd_geoid: CD GEOID per column from geography
+            assignment. Every value must appear in
+            cds_ordered (KeyError otherwise).
+        base_n_records: Number of base households.
+        cds_ordered: Ordered list of CD GEOIDs defining
+            row order.
+
     Returns:
-        Path to the PUF-cloned h5 file.
+        Array of dtype U15, length n_cds * base_n_records,
+        reshapeable to (n_cds, base_n_records).
     """
-    import h5py
+    n_total = len(block_geoid)
+    n_cds = len(cds_ordered)
+
+    cd_to_idx = {cd: idx for idx, cd in enumerate(cds_ordered)}
+    record_indices = np.arange(n_total) % base_n_records
+    cd_row_indices = np.array([cd_to_idx[cd] for cd in cd_geoid])
+    flat_indices = cd_row_indices * base_n_records + record_indices
+
+    B = np.full(n_cds * base_n_records, "", dtype="U15")
+    # Slots where a later clone carried a block that differs from
+    # the one already stored. A set, so each slot is counted once
+    # and repeats of the same block are not reported as collisions.
+    conflicting_slots = set()
+    for i in range(n_total):
+        fi = int(flat_indices[i])
+        if B[fi] == "":
+            B[fi] = block_geoid[i]
+        elif B[fi] != block_geoid[i]:
+            conflicting_slots.add(fi)
-    from policyengine_us import Microsimulation
+    n_collisions = len(conflicting_slots)
+    if n_collisions > 0:
+        logger.warning(
+            "Block collisions: %d slots had multiple clones "
+            "with different blocks.",
+            n_collisions,
+        )
-    from policyengine_us_data.calibration.puf_impute import (
-        puf_clone_dataset,
+
+    n_filled = np.count_nonzero(B != "")
+    logger.info(
+        "Converted blocks to stacked format: "
+        "(%d, %d) = %d slots, %d filled (%.1f%%)",
+        n_cds,
+        base_n_records,
+        len(B),
+        n_filled,
+        n_filled / len(B) * 100,
     )
+    return B
- logger.info("Building PUF-cloned dataset from %s", dataset_path)
- sim = Microsimulation(dataset=dataset_path)
- data = sim.dataset.load_dataset()
+def compute_stacked_takeup(
+    weights: np.ndarray,
+    cd_geoid: np.ndarray,
+    block_geoid: np.ndarray,
+    base_n_records: int,
+    cds_ordered: list,
+    entity_hh_idx_map: dict,
+    household_ids: np.ndarray,
+    precomputed_rates: dict,
+    affected_target_info: dict,
+) -> dict:
+    """Compute weight-weighted takeup per (CD, entity).
+
+    For each takeup variable, iterates over all clones and
+    recomputes entity-level takeup draws through
+    compute_block_takeup_for_entities (expected to be
+    deterministic given block and hh_id). Accumulates
+    weight-weighted takeup per (CD, base_entity_index) using
+    the final optimizer weights.
+
+    Args:
+        weights: Final weight per column; clone ``k`` occupies
+            columns ``[k * base_n_records, (k + 1) * base_n_records)``.
+        cd_geoid: CD GEOID per column. Columns whose CD is not in
+            ``cds_ordered`` are ignored.
+        block_geoid: Block GEOID per column.
+        base_n_records: Number of base households per clone.
+        cds_ordered: CD GEOIDs defining the row order of the output.
+        entity_hh_idx_map: Entity level -> array giving, for each
+            entity, the index of its household in the base records.
+        household_ids: Household id per base record.
+        precomputed_rates: Rate key -> takeup rate passed to the
+            draw helper.
+        affected_target_info: Target variable -> dict with keys
+            "takeup_var", "entity" and "rate_key". Targets whose
+            name ends in "_count" are skipped; each takeup
+            variable is processed once.
-    data_dict = {}
-    for var in data:
-        if isinstance(data[var], dict):
-            vals = list(data[var].values())
-            data_dict[var] = {time_period: vals[0]}
-        else:
-            data_dict[var] = {time_period: np.array(data[var])}
+
+    Returns:
+        Dict mapping takeup variable name to float32 ndarray of
+        shape (n_cds, n_base_entities). Cells that received no
+        positive weight are set to 1.0.
+    """
+    from policyengine_us_data.utils.takeup import (
+        compute_block_takeup_for_entities,
+    )
-    if not skip_source_impute:
-        from policyengine_us_data.calibration.source_impute import (
-            impute_source_variables,
+
+    n_total = len(weights)
+    n_clones = n_total // base_n_records
+    n_cds = len(cds_ordered)
+    cd_to_idx = {cd: i for i, cd in enumerate(cds_ordered)}
+
+    # Built from the values themselves (no uninitialised tail when
+    # cd_geoid is shorter than weights) and kept as int64 so that
+    # cd_idx * n_ent below cannot wrap around. -1 marks unknown CDs.
+    col_cd_idx = np.array(
+        [cd_to_idx.get(cd, -1) for cd in cd_geoid], dtype=np.int64
+    )
+
+    # One entry per distinct takeup variable; the first target that
+    # mentions a variable supplies its entity level and rate key.
+    unique_takeup = {}
+    for tvar, info in affected_target_info.items():
+        if tvar.endswith("_count"):
+            continue
+        tv = info["takeup_var"]
+        if tv not in unique_takeup:
+            unique_takeup[tv] = info
+
+    result = {}
+
+    for takeup_var, info in unique_takeup.items():
+        entity_level = info["entity"]
+        rate_key = info["rate_key"]
+        ent_hh = entity_hh_idx_map[entity_level]
+        n_ent = len(ent_hh)
+        rate = precomputed_rates[rate_key]
+
+        # Position of each entity within a CD row; clone-invariant.
+        ent_indices = np.arange(n_ent, dtype=np.int64)
+
+        numerator = np.zeros(n_cds * n_ent, dtype=np.float64)
+        denominator = np.zeros(n_cds * n_ent, dtype=np.float64)
+
+        for clone_idx in range(n_clones):
+            col_start = clone_idx * base_n_records
+            col_end = col_start + base_n_records
+            clone_w = weights[col_start:col_end]
+
+            # Clones with no surviving weight contribute nothing.
+            if not np.any(clone_w > 0):
+                continue
+
+            clone_blocks = block_geoid[col_start:col_end]
+
+            # Broadcast household-level values to entity level.
+            ent_blocks = clone_blocks[ent_hh]
+            ent_hh_ids = household_ids[ent_hh]
+
+            ent_takeup = compute_block_takeup_for_entities(
+                takeup_var, rate, ent_blocks, ent_hh_ids
+            ).astype(np.float64)
+
+            ent_w = clone_w[ent_hh]
+            ent_cd_idx = col_cd_idx[col_start:col_end][ent_hh]
+
+            ent_active = (ent_w > 0) & (ent_cd_idx >= 0)
+            if not ent_active.any():
+                continue
+
+            # Negative entries (unknown CD) are masked out by
+            # ent_active before they are used as indices.
+            flat = ent_cd_idx * n_ent + ent_indices
+
+            np.add.at(
+                numerator,
+                flat[ent_active],
+                (ent_w * ent_takeup)[ent_active],
+            )
+            np.add.at(
+                denominator,
+                flat[ent_active],
+                ent_w[ent_active],
+            )
+
+            if (clone_idx + 1) % 100 == 0:
+                logger.info(
+                    "Stacked takeup %s: clone %d/%d",
+                    takeup_var,
+                    clone_idx + 1,
+                    n_clones,
+                )
+
+        valid = denominator > 0
+        takeup_avg = np.ones(n_cds * n_ent, dtype=np.float32)
+        takeup_avg[valid] = (numerator[valid] / denominator[valid]).astype(
+            np.float32
         )
-        data_dict = impute_source_variables(
-            data=data_dict,
-            state_fips=state_fips,
-            time_period=time_period,
-            dataset_path=dataset_path,
+        result[takeup_var] = takeup_avg.reshape(n_cds, n_ent)
+        logger.info(
+            "Stacked takeup %s: shape %s, " "active cells %d, mean %.4f",
+            takeup_var,
+            result[takeup_var].shape,
+            valid.sum(),
+            takeup_avg[valid].mean() if valid.any() else 0,
         )
-    puf_dataset = puf_dataset_path if not skip_qrf else None
+    return result
- new_data = puf_clone_dataset(
- data=data_dict,
- state_fips=state_fips,
- time_period=time_period,
- puf_dataset=puf_dataset,
- skip_qrf=skip_qrf,
- dataset_path=dataset_path,
- )
- output_path = str(
- Path(dataset_path).parent / f"puf_cloned_{Path(dataset_path).stem}.h5"
- )
+def compute_diagnostics(
+    weights: np.ndarray,
+    X_sparse,
+    targets_df,
+    target_names: list,
+) -> "pd.DataFrame":
+    """Compare calibrated estimates against their targets.
+
+    Args:
+        weights: Weight per column of ``X_sparse``.
+        X_sparse: Calibration matrix, one row per target; must
+            support ``.dot`` and ``.sum(axis=1)``.
+        targets_df: Table whose "value" column holds the target
+            values, in the same row order as ``X_sparse``.
+        target_names: Label per target, same order.
+
+    Returns:
+        DataFrame with one row per target and the columns
+        target, true_value, estimate, rel_error, abs_rel_error
+        and achievable. rel_error is (estimate - target) /
+        |target| and is 0 where the target is 0; achievable is
+        True where the matrix row has a positive sum.
+    """
+    import pandas as pd
-    with h5py.File(output_path, "w") as f:
-        for var, time_dict in new_data.items():
-            for tp, values in time_dict.items():
-                f.create_dataset(f"{var}/{tp}", data=values)
+
+    estimates = X_sparse.dot(weights)
+    true_values = targets_df["value"].values
+    row_sums = np.array(X_sparse.sum(axis=1)).flatten()
-    del sim
-    logger.info("PUF-cloned dataset saved to %s", output_path)
-    return output_path
+
+    # Masked division: zero targets keep the 0.0 from ``out`` and
+    # are never divided, so no divide-by-zero RuntimeWarning is
+    # emitted (np.where would evaluate the quotient everywhere).
+    magnitude = np.abs(true_values).astype(np.float64)
+    rel_errors = np.divide(
+        estimates - true_values,
+        magnitude,
+        out=np.zeros_like(magnitude),
+        where=magnitude > 0,
+    )
+    return pd.DataFrame(
+        {
+            "target": target_names,
+            "true_value": true_values,
+            "estimate": estimates,
+            "rel_error": rel_errors,
+            "abs_rel_error": np.abs(rel_errors),
+            "achievable": row_sums > 0,
+        }
+    )
def run_calibration(
@@ -498,9 +1053,18 @@ def run_calibration(
domain_variables: list = None,
hierarchical_domains: list = None,
skip_takeup_rerandomize: bool = False,
- puf_dataset_path: str = None,
- skip_puf: bool = False,
- skip_source_impute: bool = False,
+ skip_source_impute: bool = True,
+ skip_county: bool = True,
+ target_config: dict = None,
+ build_only: bool = False,
+ package_path: str = None,
+ package_output_path: str = None,
+ beta: float = BETA,
+ lambda_l2: float = LAMBDA_L2,
+ learning_rate: float = LEARNING_RATE,
+ log_freq: int = None,
+ log_path: str = None,
+ workers: int = 1,
):
"""Run unified calibration pipeline.
@@ -516,32 +1080,98 @@ def run_calibration(
hierarchical_domains: Domains for hierarchical
uprating + CD reconciliation.
skip_takeup_rerandomize: Skip takeup step.
- puf_dataset_path: Path to PUF h5 for QRF training.
- skip_puf: Skip PUF clone step.
skip_source_impute: Skip ACS/SIPP/SCF imputations.
+ target_config: Parsed target config dict.
+ build_only: If True, save package and skip fitting.
+ package_path: Load pre-built package (skip build).
+ package_output_path: Where to save calibration package.
+ beta: L0 gate temperature.
+ lambda_l2: L2 regularization strength.
+ learning_rate: Optimizer learning rate.
+ log_freq: Epochs between per-target CSV logs.
+ log_path: Path for per-target calibration log CSV.
Returns:
- (weights, targets_df, X_sparse, target_names)
+ (weights, targets_df, X_sparse, target_names, geography_info)
+ weights is None when build_only=True.
+ geography_info is a dict with cd_geoid and base_n_records.
"""
import time
+ t0 = time.time()
+
+ # Early exit: load pre-built package
+ if package_path is not None:
+ package = load_calibration_package(package_path)
+ targets_df = package["targets_df"]
+ X_sparse = package["X_sparse"]
+ target_names = package["target_names"]
+
+ if target_config:
+ targets_df, X_sparse, target_names = apply_target_config(
+ targets_df, X_sparse, target_names, target_config
+ )
+
+ initial_weights = package.get("initial_weights")
+ targets = targets_df["value"].values
+ row_sums = np.array(X_sparse.sum(axis=1)).flatten()
+ pkg_achievable = row_sums > 0
+ weights = fit_l0_weights(
+ X_sparse=X_sparse,
+ targets=targets,
+ lambda_l0=lambda_l0,
+ epochs=epochs,
+ device=device,
+ beta=beta,
+ lambda_l2=lambda_l2,
+ learning_rate=learning_rate,
+ log_freq=log_freq,
+ log_path=log_path,
+ target_names=target_names,
+ initial_weights=initial_weights,
+ targets_df=targets_df,
+ achievable=pkg_achievable,
+ )
+ logger.info(
+ "Total pipeline (from package): %.1f min",
+ (time.time() - t0) / 60,
+ )
+ geography_info = {
+ "cd_geoid": package.get("cd_geoid"),
+ "block_geoid": package.get("block_geoid"),
+ "base_n_records": package["metadata"].get("base_n_records"),
+ }
+ return (
+ weights,
+ targets_df,
+ X_sparse,
+ target_names,
+ geography_info,
+ )
+
from policyengine_us import Microsimulation
from policyengine_us_data.calibration.clone_and_assign import (
assign_random_geography,
- double_geography_for_puf,
)
from policyengine_us_data.calibration.unified_matrix_builder import (
UnifiedMatrixBuilder,
)
- t0 = time.time()
-
- # Step 1: Load dataset
+ # Step 1: Load dataset and detect time period
logger.info("Loading dataset from %s", dataset_path)
sim = Microsimulation(dataset=dataset_path)
n_records = len(sim.calculate("household_id", map_to="household").values)
- logger.info("Loaded %d households", n_records)
+ raw_keys = sim.dataset.load_dataset()["household_id"]
+ if isinstance(raw_keys, dict):
+ time_period = int(next(iter(raw_keys)))
+ else:
+ time_period = 2024
+ logger.info(
+ "Loaded %d households, time_period=%d",
+ n_records,
+ time_period,
+ )
# Step 2: Clone and assign geography
logger.info(
@@ -556,46 +1186,28 @@ def run_calibration(
seed=seed,
)
- # Step 3: Source impute + PUF clone (if requested)
+ # Step 3: Source imputation (if requested)
dataset_for_matrix = dataset_path
- if not skip_puf and puf_dataset_path is not None:
- base_states = geography.state_fips[:n_records]
-
- puf_cloned_path = _build_puf_cloned_dataset(
- dataset_path=dataset_path,
- puf_dataset_path=puf_dataset_path,
- state_fips=base_states,
- time_period=2024,
- skip_qrf=False,
- skip_source_impute=skip_source_impute,
- )
-
- geography = double_geography_for_puf(geography)
- dataset_for_matrix = puf_cloned_path
- n_records = n_records * 2
-
- # Reload sim from PUF-cloned dataset
- del sim
- sim = Microsimulation(dataset=puf_cloned_path)
-
- logger.info(
- "After PUF clone: %d records x %d clones = %d",
- n_records,
- n_clones,
- n_records * n_clones,
- )
- elif not skip_source_impute:
+ if not skip_source_impute:
# Run source imputations without PUF cloning
import h5py
base_states = geography.state_fips[:n_records]
- source_sim = Microsimulation(dataset=dataset_path)
- raw_data = source_sim.dataset.load_dataset()
+ raw_data = sim.dataset.load_dataset()
data_dict = {}
for var in raw_data:
- data_dict[var] = {2024: raw_data[var][...]}
- del source_sim
+ val = raw_data[var]
+ if isinstance(val, dict):
+ # h5py returns string keys ("2024"); normalize
+ # to int so source_impute lookups work.
+ # Some keys like "ETERNITY" are non-numeric — keep
+ # them as strings.
+ data_dict[var] = {
+ int(k) if k.isdigit() else k: v for k, v in val.items()
+ }
+ else:
+ data_dict[var] = {time_period: val[...]}
from policyengine_us_data.calibration.source_impute import (
impute_source_variables,
@@ -604,7 +1216,7 @@ def run_calibration(
data_dict = impute_source_variables(
data=data_dict,
state_fips=base_states,
- time_period=2024,
+ time_period=time_period,
dataset_path=dataset_path,
)
@@ -625,17 +1237,7 @@ def run_calibration(
source_path,
)
- # Step 4: Build sim_modifier for takeup rerandomization
sim_modifier = None
- if not skip_takeup_rerandomize:
- time_period = 2024
-
- def sim_modifier(s, clone_idx):
- col_start = clone_idx * n_records
- col_end = col_start + n_records
- blocks = geography.block_geoid[col_start:col_end]
- states = geography.state_fips[col_start:col_end]
- rerandomize_takeup(s, blocks, states, time_period)
# Step 5: Build target filter
target_filter = {}
@@ -643,11 +1245,12 @@ def sim_modifier(s, clone_idx):
target_filter["domain_variables"] = domain_variables
# Step 6: Build sparse calibration matrix
+ do_rerandomize = not skip_takeup_rerandomize
t_matrix = time.time()
db_uri = f"sqlite:///{db_path}"
builder = UnifiedMatrixBuilder(
db_uri=db_uri,
- time_period=2024,
+ time_period=time_period,
dataset_path=dataset_for_matrix,
)
targets_df, X_sparse, target_names = builder.build_matrix(
@@ -656,6 +1259,9 @@ def sim_modifier(s, clone_idx):
target_filter=target_filter,
hierarchical_domains=hierarchical_domains,
sim_modifier=sim_modifier,
+ rerandomize_takeup=do_rerandomize,
+ county_level=not skip_county,
+ workers=workers,
)
builder.print_uprating_summary(targets_df)
@@ -669,6 +1275,76 @@ def sim_modifier(s, clone_idx):
X_sparse.nnz,
)
+ # Step 6b: Save FULL (unfiltered) calibration package.
+ # Target config is applied at fit time, so the package can be
+ # reused with different configs without rebuilding.
+ import datetime
+
+ metadata = {
+ "dataset_path": dataset_path,
+ "db_path": db_path,
+ "n_clones": n_clones,
+ "n_records": X_sparse.shape[1],
+ "base_n_records": n_records,
+ "seed": seed,
+ "created_at": datetime.datetime.now().isoformat(),
+ }
+ metadata.update(get_git_provenance())
+ from policyengine_us_data.utils.manifest import compute_file_checksum
+
+ metadata["dataset_sha256"] = compute_file_checksum(Path(dataset_path))
+ metadata["db_sha256"] = compute_file_checksum(Path(db_path))
+
+ if package_output_path:
+ full_initial_weights = compute_initial_weights(X_sparse, targets_df)
+ save_calibration_package(
+ package_output_path,
+ X_sparse,
+ targets_df,
+ target_names,
+ metadata,
+ initial_weights=full_initial_weights,
+ cd_geoid=geography.cd_geoid,
+ block_geoid=geography.block_geoid,
+ )
+
+ # Step 6c: Apply target config filtering (for fit or validation)
+ if target_config:
+ targets_df, X_sparse, target_names = apply_target_config(
+ targets_df, X_sparse, target_names, target_config
+ )
+
+ initial_weights = compute_initial_weights(X_sparse, targets_df)
+
+ if build_only:
+ from policyengine_us_data.calibration.validate_package import (
+ validate_package,
+ format_report,
+ )
+
+ package = {
+ "X_sparse": X_sparse,
+ "targets_df": targets_df,
+ "target_names": target_names,
+ "metadata": metadata,
+ "initial_weights": initial_weights,
+ }
+ result = validate_package(package)
+ print(format_report(result))
+ geography_info = {
+ "cd_geoid": geography.cd_geoid,
+ "block_geoid": geography.block_geoid,
+ "base_n_records": n_records,
+ "dataset_for_matrix": dataset_for_matrix,
+ }
+ return (
+ None,
+ targets_df,
+ X_sparse,
+ target_names,
+ geography_info,
+ )
+
# Step 7: L0 calibration
targets = targets_df["value"].values
@@ -686,13 +1362,38 @@ def sim_modifier(s, clone_idx):
lambda_l0=lambda_l0,
epochs=epochs,
device=device,
+ beta=beta,
+ lambda_l2=lambda_l2,
+ learning_rate=learning_rate,
+ log_freq=log_freq,
+ log_path=log_path,
+ target_names=target_names,
+ initial_weights=initial_weights,
+ targets_df=targets_df,
+ achievable=achievable,
)
logger.info(
"Total pipeline: %.1f min",
(time.time() - t0) / 60,
)
- return weights, targets_df, X_sparse, target_names
+ geography_info = {
+ "cd_geoid": geography.cd_geoid,
+ "block_geoid": geography.block_geoid,
+ "base_n_records": n_records,
+ "dataset_for_matrix": dataset_for_matrix,
+ "entity_hh_idx_map": getattr(builder, "entity_hh_idx_map", None),
+ "household_ids": getattr(builder, "household_ids", None),
+ "precomputed_rates": getattr(builder, "precomputed_rates", None),
+ "affected_target_info": getattr(builder, "affected_target_info", None),
+ }
+ return (
+ weights,
+ targets_df,
+ X_sparse,
+ target_names,
+ geography_info,
+ )
def main(argv=None):
@@ -710,17 +1411,18 @@ def main(argv=None):
pass
args = parse_args(argv)
+ logger.info("CLI args: %s", vars(args))
from policyengine_us_data.storage import STORAGE_FOLDER
dataset_path = args.dataset or str(
- STORAGE_FOLDER / "stratified_extended_cps_2024.h5"
+ STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5"
)
db_path = args.db_path or str(
STORAGE_FOLDER / "calibration" / "policy_data.db"
)
output_path = args.output or str(
- STORAGE_FOLDER / "calibration" / "unified_weights.npy"
+ STORAGE_FOLDER / "calibration" / "calibration_weights.npy"
)
if args.lambda_l0 is not None:
@@ -744,7 +1446,27 @@ def main(argv=None):
t_start = time.time()
- weights, targets_df, X_sparse, target_names = run_calibration(
+ target_config = None
+ if args.target_config:
+ target_config = load_target_config(args.target_config)
+
+ package_output_path = args.package_output
+ if args.build_only and not package_output_path:
+ package_output_path = str(
+ STORAGE_FOLDER / "calibration" / "calibration_package.pkl"
+ )
+
+ output_dir = Path(output_path).parent
+ cal_log_path = None
+ if args.log_freq is not None:
+ cal_log_path = str(output_dir / "calibration_log.csv")
+ (
+ weights,
+ targets_df,
+ X_sparse,
+ target_names,
+ geography_info,
+ ) = run_calibration(
dataset_path=dataset_path,
db_path=db_path,
n_clones=args.n_clones,
@@ -755,17 +1477,29 @@ def main(argv=None):
domain_variables=domain_variables,
hierarchical_domains=hierarchical_domains,
skip_takeup_rerandomize=args.skip_takeup_rerandomize,
- puf_dataset_path=args.puf_dataset,
- skip_puf=args.skip_puf,
skip_source_impute=args.skip_source_impute,
+ skip_county=not args.county_level,
+ target_config=target_config,
+ build_only=args.build_only,
+ package_path=args.package_path,
+ package_output_path=package_output_path,
+ beta=args.beta,
+ lambda_l2=args.lambda_l2,
+ learning_rate=args.learning_rate,
+ log_freq=args.log_freq,
+ log_path=cal_log_path,
+ workers=args.workers,
)
- # Save weights
- np.save(output_path, weights)
- logger.info("Weights saved to %s", output_path)
- print(f"OUTPUT_PATH:{output_path}")
+ source_imputed = geography_info.get("dataset_for_matrix")
+ if source_imputed and source_imputed != dataset_path:
+ print(f"SOURCE_IMPUTED_PATH:{source_imputed}")
+
+ if weights is None:
+ logger.info("Build-only complete. Package saved.")
+ return
- # Save diagnostics
+ # Diagnostics (raw weights match X_sparse column layout)
output_dir = Path(output_path).parent
diag_df = compute_diagnostics(weights, X_sparse, targets_df, target_names)
diag_path = output_dir / "unified_diagnostics.csv"
@@ -784,33 +1518,220 @@ def main(argv=None):
(err_pct < 25).mean() * 100,
)
+ # Convert to stacked format for stacked_dataset_builder
+ cd_geoid = geography_info.get("cd_geoid")
+ base_n_records = geography_info.get("base_n_records")
+
+ if cd_geoid is not None and base_n_records is not None:
+ from policyengine_us_data.calibration.calibration_utils import (
+ save_geo_labels,
+ )
+
+ cds_ordered = sorted(set(cd_geoid))
+ save_geo_labels(cds_ordered, output_dir / "geo_labels.json")
+ print(f"GEO_LABELS_PATH:{output_dir / 'geo_labels.json'}")
+ logger.info(
+ "Saved %d geo labels to %s",
+ len(cds_ordered),
+ output_dir / "geo_labels.json",
+ )
+ stacked_weights = convert_weights_to_stacked_format(
+ weights=weights,
+ cd_geoid=cd_geoid,
+ base_n_records=base_n_records,
+ cds_ordered=cds_ordered,
+ )
+ else:
+ logger.warning("No geography info available; saving raw weights")
+ stacked_weights = weights
+
+ # Save stacked blocks alongside weights
+ block_geoid = geography_info.get("block_geoid")
+ if (
+ block_geoid is not None
+ and cd_geoid is not None
+ and base_n_records is not None
+ ):
+ blocks_stacked = convert_blocks_to_stacked_format(
+ block_geoid=block_geoid,
+ cd_geoid=cd_geoid,
+ base_n_records=base_n_records,
+ cds_ordered=cds_ordered,
+ )
+ blocks_path = output_dir / "stacked_blocks.npy"
+ np.save(str(blocks_path), blocks_stacked)
+ logger.info("Stacked blocks saved to %s", blocks_path)
+ print(f"BLOCKS_PATH:{blocks_path}")
+
+ # Save stacked takeup (weight-averaged per CD × entity)
+ entity_hh_idx_map = geography_info.get("entity_hh_idx_map")
+ affected_info = geography_info.get("affected_target_info")
+ hh_ids_for_takeup = geography_info.get("household_ids")
+ rates_for_takeup = geography_info.get("precomputed_rates")
+
+ # Rebuild entity mappings from sim when loading from package
+ if (
+ block_geoid is not None
+ and cd_geoid is not None
+ and base_n_records is not None
+ and entity_hh_idx_map is None
+ ):
+ logger.info("Rebuilding entity mappings for stacked takeup...")
+ from policyengine_us_data.utils.takeup import (
+ TAKEUP_AFFECTED_TARGETS,
+ )
+ from policyengine_us_data.parameters import (
+ load_take_up_rate,
+ )
+
+ ds = source_imputed or dataset_path
+ sim = Microsimulation(dataset=ds)
+ tp = int(sim.default_calculation_period)
+
+ hh_ids_for_takeup = sim.calculate(
+ "household_id", map_to="household"
+ ).values
+ person_hh_ids = sim.calculate("household_id", map_to="person").values
+ hh_id_to_idx = {int(h): i for i, h in enumerate(hh_ids_for_takeup)}
+ person_hh_idx = np.array([hh_id_to_idx[int(h)] for h in person_hh_ids])
+
+ import pandas as pd
+
+ entity_rel_df = pd.DataFrame(
+ {
+ "household_id": sim.calculate(
+ "household_id", map_to="person"
+ ).values,
+ "tax_unit_id": sim.calculate(
+ "tax_unit_id", map_to="person"
+ ).values,
+ "spm_unit_id": sim.calculate(
+ "spm_unit_id", map_to="person"
+ ).values,
+ }
+ )
+ spm_to_hh = (
+ entity_rel_df.groupby("spm_unit_id")["household_id"]
+ .first()
+ .to_dict()
+ )
+ spm_ids = sim.calculate("spm_unit_id", map_to="spm_unit").values
+ spm_hh_idx = np.array(
+ [hh_id_to_idx[int(spm_to_hh[int(s)])] for s in spm_ids]
+ )
+
+ tu_to_hh = (
+ entity_rel_df.groupby("tax_unit_id")["household_id"]
+ .first()
+ .to_dict()
+ )
+ tu_ids = sim.calculate("tax_unit_id", map_to="tax_unit").values
+ tu_hh_idx = np.array(
+ [hh_id_to_idx[int(tu_to_hh[int(t)])] for t in tu_ids]
+ )
+
+ entity_hh_idx_map = {
+ "spm_unit": spm_hh_idx,
+ "tax_unit": tu_hh_idx,
+ "person": person_hh_idx,
+ }
+
+ unique_vars = set(targets_df["variable"].values)
+ affected_info = {}
+ for tvar in unique_vars:
+ for key, info in TAKEUP_AFFECTED_TARGETS.items():
+ if tvar == key:
+ affected_info[tvar] = info
+ break
+
+ rates_for_takeup = {}
+ for tvar, info in affected_info.items():
+ rk = info["rate_key"]
+ if rk not in rates_for_takeup:
+ rates_for_takeup[rk] = load_take_up_rate(rk, tp)
+
+ del sim
+ logger.info(
+ "Rebuilt entity mappings: %d affected vars",
+ len(affected_info),
+ )
+
+ if (
+ block_geoid is not None
+ and cd_geoid is not None
+ and base_n_records is not None
+ and entity_hh_idx_map is not None
+ and affected_info
+ ):
+ import time as _time
+
+ t_takeup = _time.time()
+ stacked_tu = compute_stacked_takeup(
+ weights=weights,
+ cd_geoid=cd_geoid,
+ block_geoid=block_geoid,
+ base_n_records=base_n_records,
+ cds_ordered=cds_ordered,
+ entity_hh_idx_map=entity_hh_idx_map,
+ household_ids=hh_ids_for_takeup,
+ precomputed_rates=rates_for_takeup,
+ affected_target_info=affected_info,
+ )
+ takeup_path = output_dir / "stacked_takeup.npz"
+ np.savez_compressed(str(takeup_path), **stacked_tu)
+ logger.info(
+ "Stacked takeup saved to %s (%.1f min)",
+ takeup_path,
+ (_time.time() - t_takeup) / 60,
+ )
+ print(f"TAKEUP_PATH:{takeup_path}")
+
+ # Save weights
+ Path(output_path).parent.mkdir(parents=True, exist_ok=True)
+ np.save(output_path, stacked_weights)
+ logger.info("Weights saved to %s", output_path)
+ print(f"OUTPUT_PATH:{output_path}")
+
# Save run config
t_end = time.time()
+ weight_format = (
+ "stacked"
+ if cd_geoid is not None and base_n_records is not None
+ else "raw"
+ )
run_config = {
"dataset": dataset_path,
"db_path": db_path,
- "puf_dataset": args.puf_dataset,
- "skip_puf": args.skip_puf,
"skip_source_impute": args.skip_source_impute,
"n_clones": args.n_clones,
"lambda_l0": lambda_l0,
+ "beta": args.beta,
+ "lambda_l2": args.lambda_l2,
+ "learning_rate": args.learning_rate,
"epochs": args.epochs,
"device": args.device,
"seed": args.seed,
"domain_variables": domain_variables,
"hierarchical_domains": hierarchical_domains,
+ "target_config": args.target_config,
"n_targets": len(targets_df),
"n_records": X_sparse.shape[1],
- "weight_sum": float(weights.sum()),
- "weight_nonzero": int((weights > 0).sum()),
+ "geo_labels_file": "geo_labels.json",
+ "weight_format": weight_format,
+ "weight_sum": float(stacked_weights.sum()),
+ "weight_nonzero": int((stacked_weights > 0).sum()),
"mean_error_pct": float(err_pct.mean()),
"elapsed_seconds": round(t_end - t_start, 1),
}
+ run_config.update(get_git_provenance())
config_path = output_dir / "unified_run_config.json"
with open(config_path, "w") as f:
json.dump(run_config, f, indent=2)
logger.info("Config saved to %s", config_path)
+ print(f"CONFIG_PATH:{config_path}")
print(f"LOG_PATH:{diag_path}")
+ if cal_log_path:
+ print(f"CAL_LOG_PATH:{cal_log_path}")
if __name__ == "__main__":
diff --git a/policyengine_us_data/calibration/unified_matrix_builder.py b/policyengine_us_data/calibration/unified_matrix_builder.py
index ac31c34e1..1fa8ea73f 100644
--- a/policyengine_us_data/calibration/unified_matrix_builder.py
+++ b/policyengine_us_data/calibration/unified_matrix_builder.py
@@ -21,11 +21,14 @@
from policyengine_us_data.storage import STORAGE_FOLDER
from policyengine_us_data.utils.census import STATE_NAME_TO_FIPS
-from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import (
+from policyengine_us_data.calibration.calibration_utils import (
get_calculated_variables,
apply_op,
get_geo_level,
)
+from policyengine_us_data.calibration.block_assignment import (
+ get_county_enum_index_from_fips,
+)
logger = logging.getLogger(__name__)
@@ -35,6 +38,680 @@
"congressional_district_geoid",
}
# Target variables that are looked up per county rather than per state
# when clone values are assembled.
COUNTY_DEPENDENT_VARS = {"aca_ptc"}
+
+
def _compute_single_state(
    dataset_path: str,
    time_period: int,
    state: int,
    n_hh: int,
    target_vars: list,
    constraint_vars: list,
    rerandomize_takeup: bool,
    affected_targets: dict,
):
    """Simulate a single state and collect its value arrays.

    Kept at module level (not a method) so it can be pickled and
    submitted to a ``ProcessPoolExecutor``.

    A new ``Microsimulation`` is built from ``dataset_path``, every
    household is assigned ``state`` as its ``state_fips``, and the
    arrays of all calculated variables are deleted before anything is
    computed.  A variable whose calculation raises is logged as a
    warning and left out of the result; it does not abort the state.

    Args:
        dataset_path: Path to the base CPS h5 file.
        time_period: Tax year for simulation.
        state: State FIPS code.
        n_hh: Number of household records.
        target_vars: Target variable names; names ending in
            ``_count`` are skipped.
        constraint_vars: Constraint variable names, calculated at
            person level.
        rerandomize_takeup: If True, every variable in
            ``SIMPLE_TAKEUP_VARS`` is set to all-True after the
            household/person values are taken, and the
            takeup-affected targets are then calculated at their own
            entity level.
        affected_targets: Maps takeup-affected target name to an info
            dict whose ``"entity"`` key names the entity level.

    Returns:
        ``(state, {"hh": {...}, "person": {...}, "entity": {...}})``.
        ``"entity"`` is empty unless ``rerandomize_takeup`` is True.
    """
    from policyengine_us import Microsimulation
    from policyengine_us_data.utils.takeup import SIMPLE_TAKEUP_VARS
    from policyengine_us_data.calibration.calibration_utils import (
        get_calculated_variables,
    )

    sim = Microsimulation(dataset=dataset_path)
    sim.set_input(
        "state_fips",
        time_period,
        np.full(n_hh, state, dtype=np.int32),
    )

    def _invalidate_calculated():
        # Delete the stored arrays of every calculated variable.
        for name in get_calculated_variables(sim):
            sim.delete_arrays(name)

    def _float32_values(name, level):
        result = sim.calculate(name, time_period, map_to=level)
        return result.values.astype(np.float32)

    _invalidate_calculated()

    household_values = {}
    for name in target_vars:
        if name.endswith("_count"):
            continue
        try:
            household_values[name] = _float32_values(name, "household")
        except Exception as exc:
            logger.warning(
                "Cannot calculate '%s' for state %d: %s",
                name,
                state,
                exc,
            )

    person_values = {}
    for name in constraint_vars:
        try:
            person_values[name] = _float32_values(name, "person")
        except Exception as exc:
            logger.warning(
                "Cannot calculate constraint '%s' for state %d: %s",
                name,
                state,
                exc,
            )

    entity_values = {}
    if rerandomize_takeup:
        # Force every simple takeup variable to True, then invalidate
        # calculated arrays before computing entity-level amounts.
        for spec in SIMPLE_TAKEUP_VARS:
            level = spec["entity"]
            entity_ids = sim.calculate(f"{level}_id", map_to=level).values
            sim.set_input(
                spec["variable"],
                time_period,
                np.ones(len(entity_ids), dtype=bool),
            )
        _invalidate_calculated()

        for name, info in affected_targets.items():
            level = info["entity"]
            try:
                entity_values[name] = _float32_values(name, level)
            except Exception as exc:
                logger.warning(
                    "Cannot calculate entity-level "
                    "'%s' (map_to=%s) for state %d: %s",
                    name,
                    level,
                    state,
                    exc,
                )

    collected = {
        "hh": household_values,
        "person": person_values,
        "entity": entity_values,
    }
    return (state, collected)
+
+
def _compute_single_state_group_counties(
    dataset_path: str,
    time_period: int,
    state_fips: int,
    counties: list,
    n_hh: int,
    county_dep_targets: list,
    rerandomize_takeup: bool,
    affected_targets: dict,
):
    """Compute county-dependent values for every county of one state.

    Kept at module level (not a method) so it can be pickled and
    submitted to a ``ProcessPoolExecutor``.  One ``Microsimulation``
    is created for the state and reused for each of its counties.

    For each county, all households are assigned that county, the
    takeup inputs captured at the start are restored (when
    ``rerandomize_takeup`` is True), and the arrays of every calculated
    variable except ``county`` are deleted before values are computed.
    A variable whose calculation raises is logged as a warning and
    left out of that county's result.

    Args:
        dataset_path: Path to the base CPS h5 file.
        time_period: Tax year for simulation.
        state_fips: State FIPS code for this group.
        counties: County FIPS strings belonging to this state.
        n_hh: Number of household records.
        county_dep_targets: County-dependent target variable names;
            names ending in ``_count`` are skipped.
        rerandomize_takeup: If True, every variable in
            ``SIMPLE_TAKEUP_VARS`` is set to all-True after the
            household values are taken, and the takeup-affected
            targets are then calculated at their own entity level.
        affected_targets: Maps takeup-affected target name to an info
            dict whose ``"entity"`` key names the entity level.

    Returns:
        List of ``(county_fips_str, {"hh": {...}, "entity": {...}})``
        in the order of ``counties``.
    """
    from policyengine_us import Microsimulation
    from policyengine_us_data.utils.takeup import SIMPLE_TAKEUP_VARS
    from policyengine_us_data.calibration.calibration_utils import (
        get_calculated_variables,
    )
    from policyengine_us_data.calibration.block_assignment import (
        get_county_enum_index_from_fips,
    )

    sim = Microsimulation(dataset=dataset_path)
    sim.set_input(
        "state_fips",
        time_period,
        np.full(n_hh, state_fips, dtype=np.int32),
    )

    def _invalidate_calculated_except_county():
        # "county" is the input just set for the current iteration,
        # so its array is kept.
        for name in get_calculated_variables(sim):
            if name == "county":
                continue
            sim.delete_arrays(name)

    def _float32_values(name, level):
        result = sim.calculate(name, time_period, map_to=level)
        return result.values.astype(np.float32)

    # Snapshot of the takeup inputs before any of them is forced to
    # True; restored at the start of every county iteration.
    saved_takeup = {}
    if rerandomize_takeup:
        for spec in SIMPLE_TAKEUP_VARS:
            snapshot = sim.calculate(
                spec["variable"],
                time_period,
                map_to=spec["entity"],
            ).values.copy()
            saved_takeup[spec["variable"]] = snapshot

    per_county = []
    for county_fips in counties:
        county_index = get_county_enum_index_from_fips(county_fips)
        sim.set_input(
            "county",
            time_period,
            np.full(n_hh, county_index, dtype=np.int32),
        )
        for takeup_name, snapshot in saved_takeup.items():
            sim.set_input(takeup_name, time_period, snapshot)
        _invalidate_calculated_except_county()

        household_values = {}
        for name in county_dep_targets:
            if name.endswith("_count"):
                continue
            try:
                household_values[name] = _float32_values(name, "household")
            except Exception as exc:
                logger.warning(
                    "Cannot calculate '%s' for county %s: %s",
                    name,
                    county_fips,
                    exc,
                )

        entity_values = {}
        if rerandomize_takeup:
            for spec in SIMPLE_TAKEUP_VARS:
                level = spec["entity"]
                entity_ids = sim.calculate(
                    f"{level}_id", map_to=level
                ).values
                sim.set_input(
                    spec["variable"],
                    time_period,
                    np.ones(len(entity_ids), dtype=bool),
                )
            _invalidate_calculated_except_county()

            for name, info in affected_targets.items():
                level = info["entity"]
                try:
                    entity_values[name] = _float32_values(name, level)
                except Exception as exc:
                    logger.warning(
                        "Cannot calculate entity-level "
                        "'%s' for county %s: %s",
                        name,
                        county_fips,
                        exc,
                    )

        per_county.append(
            (county_fips, {"hh": household_values, "entity": entity_values})
        )

    return per_county
+
+
+# ---------------------------------------------------------------
+# Clone-loop parallelisation helpers (module-level for pickling)
+# ---------------------------------------------------------------
+
+_CLONE_SHARED: dict = {}
+
+
+def _init_clone_worker(shared_data: dict) -> None:
+ """Initialise worker process with shared read-only data.
+
+ Called once per worker at ``ProcessPoolExecutor`` startup so the
+ ~50-200 MB payload is pickled *per worker* (not per clone).
+ """
+ _CLONE_SHARED.update(shared_data)
+
+
+def _assemble_clone_values_standalone(
+ state_values: dict,
+ clone_states: np.ndarray,
+ person_hh_indices: np.ndarray,
+ target_vars: set,
+ constraint_vars: set,
+ county_values: dict = None,
+ clone_counties: np.ndarray = None,
+ county_dependent_vars: set = None,
+) -> tuple:
+ """Standalone clone-value assembly (no ``self``).
+
+ Identical logic to
+ ``UnifiedMatrixBuilder._assemble_clone_values`` but usable
+ from a worker process.
+ """
+ n_records = len(clone_states)
+ n_persons = len(person_hh_indices)
+ person_states = clone_states[person_hh_indices]
+ unique_clone_states = np.unique(clone_states)
+ cdv = county_dependent_vars or set()
+
+ state_masks = {int(s): clone_states == s for s in unique_clone_states}
+ unique_person_states = np.unique(person_states)
+ person_state_masks = {
+ int(s): person_states == s for s in unique_person_states
+ }
+ county_masks = {}
+ unique_counties = None
+ if clone_counties is not None and county_values:
+ unique_counties = np.unique(clone_counties)
+ county_masks = {c: clone_counties == c for c in unique_counties}
+
+ hh_vars: dict = {}
+ for var in target_vars:
+ if var.endswith("_count"):
+ continue
+ if var in cdv and county_values and clone_counties is not None:
+ first_county = unique_counties[0]
+ if var not in county_values.get(first_county, {}).get("hh", {}):
+ continue
+ arr = np.empty(n_records, dtype=np.float32)
+ for county in unique_counties:
+ mask = county_masks[county]
+ county_hh = county_values.get(county, {}).get("hh", {})
+ if var in county_hh:
+ arr[mask] = county_hh[var][mask]
+ else:
+ st = int(county[:2])
+ arr[mask] = state_values[st]["hh"][var][mask]
+ hh_vars[var] = arr
+ else:
+ if var not in state_values[unique_clone_states[0]]["hh"]:
+ continue
+ arr = np.empty(n_records, dtype=np.float32)
+ for state in unique_clone_states:
+ mask = state_masks[int(state)]
+ arr[mask] = state_values[int(state)]["hh"][var][mask]
+ hh_vars[var] = arr
+
+ person_vars: dict = {}
+ for var in constraint_vars:
+ if var not in state_values[unique_clone_states[0]]["person"]:
+ continue
+ arr = np.empty(n_persons, dtype=np.float32)
+ for state in unique_person_states:
+ mask = person_state_masks[int(state)]
+ arr[mask] = state_values[int(state)]["person"][var][mask]
+ person_vars[var] = arr
+
+ return hh_vars, person_vars
+
+
+def _evaluate_constraints_standalone(
+ constraints,
+ person_vars: dict,
+ entity_rel: pd.DataFrame,
+ household_ids: np.ndarray,
+ n_households: int,
+) -> np.ndarray:
+ """Standalone constraint evaluation (no class instance).
+
+ Evaluates person-level constraints and aggregates to
+ household level via .any().
+ """
+ if not constraints:
+ return np.ones(n_households, dtype=bool)
+
+ n_persons = len(entity_rel)
+ person_mask = np.ones(n_persons, dtype=bool)
+
+ for c in constraints:
+ var = c["variable"]
+ if var not in person_vars:
+ logger.warning(
+ "Constraint var '%s' not in " "precomputed person_vars",
+ var,
+ )
+ return np.zeros(n_households, dtype=bool)
+ vals = person_vars[var]
+ person_mask &= apply_op(vals, c["operation"], c["value"])
+
+ df = entity_rel.copy()
+ df["satisfies"] = person_mask
+ hh_mask = df.groupby("household_id")["satisfies"].any()
+ return np.array([hh_mask.get(hid, False) for hid in household_ids])
+
+
+def _calculate_target_values_standalone(
+ target_variable: str,
+ non_geo_constraints: list,
+ n_households: int,
+ hh_vars: dict,
+ person_vars: dict,
+ entity_rel: pd.DataFrame,
+ household_ids: np.ndarray,
+ variable_entity_map: dict,
+) -> np.ndarray:
+ """Standalone target-value calculation (no class instance).
+
+ Uses ``variable_entity_map`` dict for entity resolution
+ (picklable, unlike ``tax_benefit_system``).
+ """
+ is_count = target_variable.endswith("_count")
+
+ if not is_count:
+ mask = _evaluate_constraints_standalone(
+ non_geo_constraints,
+ person_vars,
+ entity_rel,
+ household_ids,
+ n_households,
+ )
+ vals = hh_vars.get(target_variable)
+ if vals is None:
+ return np.zeros(n_households, dtype=np.float32)
+ return (vals * mask).astype(np.float32)
+
+ # Count target: entity-aware counting
+ n_persons = len(entity_rel)
+ person_mask = np.ones(n_persons, dtype=bool)
+
+ for c in non_geo_constraints:
+ var = c["variable"]
+ if var not in person_vars:
+ return np.zeros(n_households, dtype=np.float32)
+ cv = person_vars[var]
+ person_mask &= apply_op(cv, c["operation"], c["value"])
+
+ target_entity = variable_entity_map.get(target_variable)
+ if target_entity is None:
+ return np.zeros(n_households, dtype=np.float32)
+
+ if target_entity == "household":
+ if non_geo_constraints:
+ mask = _evaluate_constraints_standalone(
+ non_geo_constraints,
+ person_vars,
+ entity_rel,
+ household_ids,
+ n_households,
+ )
+ return mask.astype(np.float32)
+ return np.ones(n_households, dtype=np.float32)
+
+ if target_entity == "person":
+ er = entity_rel.copy()
+ er["satisfies"] = person_mask
+ filtered = er[er["satisfies"]]
+ counts = filtered.groupby("household_id")["person_id"].nunique()
+ else:
+ eid_col = f"{target_entity}_id"
+ er = entity_rel.copy()
+ er["satisfies"] = person_mask
+ entity_ok = er.groupby(eid_col)["satisfies"].any()
+ unique = er[["household_id", eid_col]].drop_duplicates()
+ unique["entity_ok"] = unique[eid_col].map(entity_ok)
+ filtered = unique[unique["entity_ok"]]
+ counts = filtered.groupby("household_id")[eid_col].nunique()
+
+ return np.array(
+ [counts.get(hid, 0) for hid in household_ids],
+ dtype=np.float32,
+ )
+
+
+def _process_single_clone(
+ clone_idx: int,
+ col_start: int,
+ col_end: int,
+ cache_path: str,
+) -> tuple:
+ """Process one clone in a worker process.
+
+ Reads shared read-only data from ``_CLONE_SHARED``
+ (populated by ``_init_clone_worker``). Writes COO
+ entries as a compressed ``.npz`` file to *cache_path*.
+
+ Args:
+ clone_idx: Zero-based clone index.
+ col_start: First column index for this clone.
+ col_end: One-past-last column index.
+ cache_path: File path for output ``.npz``.
+
+ Returns:
+ (clone_idx, n_nonzero) tuple.
+ """
+ sd = _CLONE_SHARED
+
+ # Unpack shared data
+ geo_states = sd["geography_state_fips"]
+ geo_counties = sd["geography_county_fips"]
+ geo_blocks = sd["geography_block_geoid"]
+ state_values = sd["state_values"]
+ county_values = sd["county_values"]
+ person_hh_indices = sd["person_hh_indices"]
+ unique_variables = sd["unique_variables"]
+ unique_constraint_vars = sd["unique_constraint_vars"]
+ county_dep_targets = sd["county_dep_targets"]
+ target_variables = sd["target_variables"]
+ target_geo_info = sd["target_geo_info"]
+ non_geo_constraints_list = sd["non_geo_constraints_list"]
+ n_records = sd["n_records"]
+ n_total = sd["n_total"]
+ n_targets = sd["n_targets"]
+ state_to_cols = sd["state_to_cols"]
+ cd_to_cols = sd["cd_to_cols"]
+ entity_rel = sd["entity_rel"]
+ household_ids = sd["household_ids"]
+ variable_entity_map = sd["variable_entity_map"]
+ do_takeup = sd["rerandomize_takeup"]
+ affected_target_info = sd["affected_target_info"]
+ entity_hh_idx_map = sd.get("entity_hh_idx_map", {})
+ entity_to_person_idx = sd.get("entity_to_person_idx", {})
+ precomputed_rates = sd.get("precomputed_rates", {})
+
+ # Slice geography for this clone
+ clone_states = geo_states[col_start:col_end]
+ clone_counties = geo_counties[col_start:col_end]
+
+ # Assemble hh/person values from precomputed state/county
+ hh_vars, person_vars = _assemble_clone_values_standalone(
+ state_values,
+ clone_states,
+ person_hh_indices,
+ unique_variables,
+ unique_constraint_vars,
+ county_values=county_values,
+ clone_counties=clone_counties,
+ county_dependent_vars=county_dep_targets,
+ )
+
+ # Takeup re-randomisation
+ if do_takeup and affected_target_info:
+ from policyengine_us_data.utils.takeup import (
+ compute_block_takeup_for_entities,
+ )
+
+ clone_blocks = geo_blocks[col_start:col_end]
+
+ for tvar, info in affected_target_info.items():
+ if tvar.endswith("_count"):
+ continue
+ entity_level = info["entity"]
+ takeup_var = info["takeup_var"]
+ ent_hh = entity_hh_idx_map[entity_level]
+ n_ent = len(ent_hh)
+ ent_states = clone_states[ent_hh]
+
+ ent_eligible = np.zeros(n_ent, dtype=np.float32)
+ if tvar in county_dep_targets and county_values:
+ ent_counties = clone_counties[ent_hh]
+ for cfips in np.unique(ent_counties):
+ m = ent_counties == cfips
+ cv = county_values.get(cfips, {}).get("entity", {})
+ if tvar in cv:
+ ent_eligible[m] = cv[tvar][m]
+ else:
+ st = int(cfips[:2])
+ sv = state_values[st]["entity"]
+ if tvar in sv:
+ ent_eligible[m] = sv[tvar][m]
+ else:
+ for st in np.unique(ent_states):
+ m = ent_states == st
+ sv = state_values[int(st)]["entity"]
+ if tvar in sv:
+ ent_eligible[m] = sv[tvar][m]
+
+ ent_blocks = clone_blocks[ent_hh]
+ ent_hh_ids = household_ids[ent_hh]
+
+ ent_takeup = compute_block_takeup_for_entities(
+ takeup_var,
+ precomputed_rates[info["rate_key"]],
+ ent_blocks,
+ ent_hh_ids,
+ )
+
+ ent_values = (ent_eligible * ent_takeup).astype(np.float32)
+
+ hh_result = np.zeros(n_records, dtype=np.float32)
+ np.add.at(hh_result, ent_hh, ent_values)
+ hh_vars[tvar] = hh_result
+
+ if tvar in person_vars:
+ pidx = entity_to_person_idx[entity_level]
+ person_vars[tvar] = ent_values[pidx]
+
+ # Build COO entries for every target row
+ mask_cache: dict = {}
+ count_cache: dict = {}
+ rows_list: list = []
+ cols_list: list = []
+ vals_list: list = []
+
+ for row_idx in range(n_targets):
+ variable = target_variables[row_idx]
+ geo_level, geo_id = target_geo_info[row_idx]
+ non_geo = non_geo_constraints_list[row_idx]
+
+ if geo_level == "district":
+ all_geo_cols = cd_to_cols.get(
+ str(geo_id),
+ np.array([], dtype=np.int64),
+ )
+ elif geo_level == "state":
+ all_geo_cols = state_to_cols.get(
+ int(geo_id),
+ np.array([], dtype=np.int64),
+ )
+ else:
+ all_geo_cols = np.arange(n_total)
+
+ clone_cols = all_geo_cols[
+ (all_geo_cols >= col_start) & (all_geo_cols < col_end)
+ ]
+ if len(clone_cols) == 0:
+ continue
+
+ rec_indices = clone_cols - col_start
+
+ constraint_key = tuple(
+ sorted(
+ (
+ c["variable"],
+ c["operation"],
+ c["value"],
+ )
+ for c in non_geo
+ )
+ )
+
+ if variable.endswith("_count"):
+ vkey = (variable, constraint_key)
+ if vkey not in count_cache:
+ count_cache[vkey] = _calculate_target_values_standalone(
+ variable,
+ non_geo,
+ n_records,
+ hh_vars,
+ person_vars,
+ entity_rel,
+ household_ids,
+ variable_entity_map,
+ )
+ values = count_cache[vkey]
+ else:
+ if variable not in hh_vars:
+ continue
+ if constraint_key not in mask_cache:
+ mask_cache[constraint_key] = _evaluate_constraints_standalone(
+ non_geo,
+ person_vars,
+ entity_rel,
+ household_ids,
+ n_records,
+ )
+ mask = mask_cache[constraint_key]
+ values = hh_vars[variable] * mask
+
+ vals = values[rec_indices]
+ nonzero = vals != 0
+ if nonzero.any():
+ rows_list.append(
+ np.full(
+ nonzero.sum(),
+ row_idx,
+ dtype=np.int32,
+ )
+ )
+ cols_list.append(clone_cols[nonzero].astype(np.int32))
+ vals_list.append(vals[nonzero])
+
+ # Write COO
+ if rows_list:
+ cr = np.concatenate(rows_list)
+ cc = np.concatenate(cols_list)
+ cv = np.concatenate(vals_list)
+ else:
+ cr = np.array([], dtype=np.int32)
+ cc = np.array([], dtype=np.int32)
+ cv = np.array([], dtype=np.float32)
+
+ np.savez_compressed(cache_path, rows=cr, cols=cc, vals=cv)
+ return clone_idx, len(cv)
+
class UnifiedMatrixBuilder:
"""Build sparse calibration matrix for cloned CPS records.
@@ -88,48 +765,597 @@ def _build_entity_relationship(self, sim) -> pd.DataFrame:
return self._entity_rel_cache
# ---------------------------------------------------------------
- # Constraint evaluation
+ # Per-state precomputation
# ---------------------------------------------------------------
- def _evaluate_constraints_entity_aware(
+ def _build_state_values(
self,
sim,
- constraints: List[dict],
- n_households: int,
- ) -> np.ndarray:
- """Evaluate constraints at person level, aggregate to
- household level via .any()."""
- if not constraints:
- return np.ones(n_households, dtype=bool)
+ target_vars: set,
+ constraint_vars: set,
+ geography,
+ rerandomize_takeup: bool = True,
+ workers: int = 1,
+ ) -> dict:
+ """Precompute household/person/entity values per state.
- entity_rel = self._build_entity_relationship(sim)
- n_persons = len(entity_rel)
- person_mask = np.ones(n_persons, dtype=bool)
+ Creates a fresh Microsimulation per state to prevent
+ cross-state cache pollution (stale intermediate values
+ from one state leaking into another's calculations).
- for c in constraints:
- try:
- vals = sim.calculate(
- c["variable"],
+ County-dependent variables (e.g. aca_ptc) are computed
+ here as a state-level fallback; county-level overrides
+ are applied later via ``_build_county_values``.
+
+ Args:
+ sim: Microsimulation instance (unused; kept for API
+ compatibility).
+ target_vars: Set of target variable names.
+ constraint_vars: Set of constraint variable names.
+ geography: GeographyAssignment with state_fips.
+ rerandomize_takeup: If True, force takeup=True and
+ also store entity-level eligible amounts for
+ takeup-affected targets.
+ workers: Number of parallel worker processes.
+ When >1, uses ProcessPoolExecutor.
+
+ Returns:
+ {state_fips: {
+ 'hh': {var: array},
+ 'person': {var: array},
+ 'entity': {var: array} # only if rerandomize
+ }}
+ """
+ from policyengine_us_data.utils.takeup import (
+ TAKEUP_AFFECTED_TARGETS,
+ )
+
+ unique_states = sorted(set(int(s) for s in geography.state_fips))
+ n_hh = geography.n_records
+
+ logger.info(
+ "Per-state precomputation: %d states, "
+ "%d hh vars, %d constraint vars "
+ "(fresh sim per state, workers=%d)",
+ len(unique_states),
+ len([v for v in target_vars if not v.endswith("_count")]),
+ len(constraint_vars),
+ workers,
+ )
+
+ # Identify takeup-affected targets before the state loop
+ affected_targets = {}
+ if rerandomize_takeup:
+ for tvar in target_vars:
+ for key, info in TAKEUP_AFFECTED_TARGETS.items():
+ if tvar == key or tvar.startswith(key):
+ affected_targets[tvar] = info
+ break
+
+ # Convert sets to sorted lists for deterministic iteration
+ target_vars_list = sorted(target_vars)
+ constraint_vars_list = sorted(constraint_vars)
+
+ state_values = {}
+
+ if workers > 1:
+ from concurrent.futures import (
+ ProcessPoolExecutor,
+ as_completed,
+ )
+
+ logger.info(
+ "Parallel state precomputation with %d workers",
+ workers,
+ )
+ with ProcessPoolExecutor(max_workers=workers) as pool:
+ futures = {
+ pool.submit(
+ _compute_single_state,
+ self.dataset_path,
+ self.time_period,
+ st,
+ n_hh,
+ target_vars_list,
+ constraint_vars_list,
+ rerandomize_takeup,
+ affected_targets,
+ ): st
+ for st in unique_states
+ }
+ completed = 0
+ for future in as_completed(futures):
+ st = futures[future]
+ try:
+ sf, vals = future.result()
+ state_values[sf] = vals
+ completed += 1
+ if completed % 10 == 0 or completed == 1:
+ logger.info(
+ "State %d/%d complete",
+ completed,
+ len(unique_states),
+ )
+ except Exception as exc:
+ for f in futures:
+ f.cancel()
+ raise RuntimeError(
+ f"State {st} failed: {exc}"
+ ) from exc
+ else:
+ from policyengine_us import Microsimulation
+ from policyengine_us_data.utils.takeup import (
+ SIMPLE_TAKEUP_VARS,
+ )
+
+ for i, state in enumerate(unique_states):
+ state_sim = Microsimulation(dataset=self.dataset_path)
+
+ state_sim.set_input(
+ "state_fips",
self.time_period,
- map_to="person",
- ).values
- except Exception as exc:
- logger.warning(
- "Cannot evaluate constraint '%s': %s",
- c["variable"],
- exc,
+ np.full(n_hh, state, dtype=np.int32),
)
- return np.zeros(n_households, dtype=bool)
- person_mask &= apply_op(vals, c["operation"], c["value"])
+ for var in get_calculated_variables(state_sim):
+ state_sim.delete_arrays(var)
- df = entity_rel.copy()
- df["satisfies"] = person_mask
- hh_mask = df.groupby("household_id")["satisfies"].any()
+ hh = {}
+ for var in target_vars:
+ if var.endswith("_count"):
+ continue
+ try:
+ hh[var] = state_sim.calculate(
+ var,
+ self.time_period,
+ map_to="household",
+ ).values.astype(np.float32)
+ except Exception as exc:
+ logger.warning(
+ "Cannot calculate '%s' " "for state %d: %s",
+ var,
+ state,
+ exc,
+ )
- household_ids = sim.calculate(
- "household_id", map_to="household"
- ).values
- return np.array([hh_mask.get(hid, False) for hid in household_ids])
+ person = {}
+ for var in constraint_vars:
+ try:
+ person[var] = state_sim.calculate(
+ var,
+ self.time_period,
+ map_to="person",
+ ).values.astype(np.float32)
+ except Exception as exc:
+ logger.warning(
+ "Cannot calculate constraint "
+ "'%s' for state %d: %s",
+ var,
+ state,
+ exc,
+ )
+
+ if rerandomize_takeup:
+ for spec in SIMPLE_TAKEUP_VARS:
+ entity = spec["entity"]
+ n_ent = len(
+ state_sim.calculate(
+ f"{entity}_id", map_to=entity
+ ).values
+ )
+ state_sim.set_input(
+ spec["variable"],
+ self.time_period,
+ np.ones(n_ent, dtype=bool),
+ )
+ for var in get_calculated_variables(state_sim):
+ state_sim.delete_arrays(var)
+
+ entity_vals = {}
+ if rerandomize_takeup:
+ for tvar, info in affected_targets.items():
+ entity_level = info["entity"]
+ try:
+ entity_vals[tvar] = state_sim.calculate(
+ tvar,
+ self.time_period,
+ map_to=entity_level,
+ ).values.astype(np.float32)
+ except Exception as exc:
+ logger.warning(
+ "Cannot calculate entity-level "
+ "'%s' (map_to=%s) for "
+ "state %d: %s",
+ tvar,
+ entity_level,
+ state,
+ exc,
+ )
+
+ state_values[state] = {
+ "hh": hh,
+ "person": person,
+ "entity": entity_vals,
+ }
+ if (i + 1) % 10 == 0 or i == 0:
+ logger.info(
+ "State %d/%d complete",
+ i + 1,
+ len(unique_states),
+ )
+
+ logger.info(
+ "Per-state precomputation done: %d states",
+ len(state_values),
+ )
+ return state_values
+
+ def _build_county_values(
+ self,
+ sim,
+ county_dep_targets: set,
+ geography,
+ rerandomize_takeup: bool = True,
+ county_level: bool = True,
+ workers: int = 1,
+ ) -> dict:
+ """Precompute county-dependent variable values per county.
+
+ Only iterates over COUNTY_DEPENDENT_VARS that actually
+ benefit from per-county computation. All other target
+ variables use state-level values from _build_state_values.
+
+ Creates a fresh Microsimulation per state group to prevent
+ cross-state cache pollution. Counties within the same state
+ share a simulation since within-state recalculation is clean
+ (only cross-state switches cause pollution).
+
+ When county_level=False, returns an empty dict immediately
+ (all values come from state-level precomputation).
+
+ Args:
+ sim: Microsimulation instance (unused; kept for API
+ compatibility).
+ county_dep_targets: Subset of target vars that depend
+ on county (intersection of targets with
+ COUNTY_DEPENDENT_VARS).
+ geography: GeographyAssignment with county_fips.
+ rerandomize_takeup: If True, force takeup=True and
+ also store entity-level eligible amounts for
+ takeup-affected targets.
+ county_level: If True, iterate counties within each
+ state. If False, return empty dict (skip county
+ computation entirely).
+ workers: Number of parallel worker processes.
+ When >1, uses ProcessPoolExecutor.
+
+ Returns:
+ {county_fips_str: {
+ 'hh': {var: array},
+ 'entity': {var: array}
+ }}
+ """
+ if not county_level or not county_dep_targets:
+ if not county_level:
+ logger.info(
+ "County-level computation disabled " "(skip-county mode)"
+ )
+ else:
+ logger.info(
+ "No county-dependent target vars; "
+ "skipping county precomputation"
+ )
+ return {}
+
+ from policyengine_us_data.utils.takeup import (
+ TAKEUP_AFFECTED_TARGETS,
+ )
+
+ unique_counties = sorted(set(geography.county_fips))
+ n_hh = geography.n_records
+
+ state_to_counties = defaultdict(list)
+ for county in unique_counties:
+ state_to_counties[int(county[:2])].append(county)
+
+ logger.info(
+ "Per-county precomputation: %d counties in %d "
+ "states, %d county-dependent vars "
+ "(fresh sim per state, workers=%d)",
+ len(unique_counties),
+ len(state_to_counties),
+ len(county_dep_targets),
+ workers,
+ )
+
+ affected_targets = {}
+ if rerandomize_takeup:
+ for tvar in county_dep_targets:
+ for key, info in TAKEUP_AFFECTED_TARGETS.items():
+ if tvar == key or tvar.startswith(key):
+ affected_targets[tvar] = info
+ break
+
+ # Convert to sorted list for deterministic iteration
+ county_dep_targets_list = sorted(county_dep_targets)
+
+ county_values = {}
+
+ if workers > 1:
+ from concurrent.futures import (
+ ProcessPoolExecutor,
+ as_completed,
+ )
+
+ logger.info(
+ "Parallel county precomputation with "
+ "%d workers (%d state groups)",
+ workers,
+ len(state_to_counties),
+ )
+ with ProcessPoolExecutor(max_workers=workers) as pool:
+ futures = {
+ pool.submit(
+ _compute_single_state_group_counties,
+ self.dataset_path,
+ self.time_period,
+ sf,
+ counties,
+ n_hh,
+ county_dep_targets_list,
+ rerandomize_takeup,
+ affected_targets,
+ ): sf
+ for sf, counties in sorted(state_to_counties.items())
+ }
+ completed = 0
+ county_count = 0
+ for future in as_completed(futures):
+ sf = futures[future]
+ try:
+ results = future.result()
+ for cfips, vals in results:
+ county_values[cfips] = vals
+ county_count += 1
+ completed += 1
+ if county_count % 500 == 0 or completed == 1:
+ logger.info(
+ "County %d/%d complete "
+ "(%d/%d state groups)",
+ county_count,
+ len(unique_counties),
+ completed,
+ len(state_to_counties),
+ )
+ except Exception as exc:
+ for f in futures:
+ f.cancel()
+ raise RuntimeError(
+ f"State group {sf} failed: " f"{exc}"
+ ) from exc
+ else:
+ from policyengine_us import Microsimulation
+ from policyengine_us_data.utils.takeup import (
+ SIMPLE_TAKEUP_VARS,
+ )
+
+ county_count = 0
+ for state_fips, counties in sorted(state_to_counties.items()):
+ state_sim = Microsimulation(dataset=self.dataset_path)
+
+ state_sim.set_input(
+ "state_fips",
+ self.time_period,
+ np.full(n_hh, state_fips, dtype=np.int32),
+ )
+
+ original_takeup = {}
+ if rerandomize_takeup:
+ for spec in SIMPLE_TAKEUP_VARS:
+ entity = spec["entity"]
+ original_takeup[spec["variable"]] = (
+ entity,
+ state_sim.calculate(
+ spec["variable"],
+ self.time_period,
+ map_to=entity,
+ ).values.copy(),
+ )
+
+ for county_fips in counties:
+ county_idx = get_county_enum_index_from_fips(county_fips)
+ state_sim.set_input(
+ "county",
+ self.time_period,
+ np.full(
+ n_hh,
+ county_idx,
+ dtype=np.int32,
+ ),
+ )
+ if rerandomize_takeup:
+ for vname, (
+ ent,
+ orig,
+ ) in original_takeup.items():
+ state_sim.set_input(
+ vname,
+ self.time_period,
+ orig,
+ )
+ for var in get_calculated_variables(state_sim):
+ if var != "county":
+ state_sim.delete_arrays(var)
+
+ hh = {}
+ for var in county_dep_targets:
+ if var.endswith("_count"):
+ continue
+ try:
+ hh[var] = state_sim.calculate(
+ var,
+ self.time_period,
+ map_to="household",
+ ).values.astype(np.float32)
+ except Exception as exc:
+ logger.warning(
+ "Cannot calculate '%s' " "for county %s: %s",
+ var,
+ county_fips,
+ exc,
+ )
+
+ if rerandomize_takeup:
+ for spec in SIMPLE_TAKEUP_VARS:
+ entity = spec["entity"]
+ n_ent = len(
+ state_sim.calculate(
+ f"{entity}_id",
+ map_to=entity,
+ ).values
+ )
+ state_sim.set_input(
+ spec["variable"],
+ self.time_period,
+ np.ones(n_ent, dtype=bool),
+ )
+ for var in get_calculated_variables(state_sim):
+ if var != "county":
+ state_sim.delete_arrays(var)
+
+ entity_vals = {}
+ if rerandomize_takeup:
+ for (
+ tvar,
+ info,
+ ) in affected_targets.items():
+ entity_level = info["entity"]
+ try:
+ entity_vals[tvar] = state_sim.calculate(
+ tvar,
+ self.time_period,
+ map_to=entity_level,
+ ).values.astype(np.float32)
+ except Exception as exc:
+ logger.warning(
+ "Cannot calculate "
+ "entity-level '%s' "
+ "for county %s: %s",
+ tvar,
+ county_fips,
+ exc,
+ )
+
+ county_values[county_fips] = {
+ "hh": hh,
+ "entity": entity_vals,
+ }
+ county_count += 1
+ if county_count % 500 == 0 or county_count == 1:
+ logger.info(
+ "County %d/%d complete",
+ county_count,
+ len(unique_counties),
+ )
+
+ logger.info(
+ "Per-county precomputation done: %d counties",
+ len(county_values),
+ )
+ return county_values
+
+ def _assemble_clone_values(
+ self,
+ state_values: dict,
+ clone_states: np.ndarray,
+ person_hh_indices: np.ndarray,
+ target_vars: set,
+ constraint_vars: set,
+ county_values: dict = None,
+ clone_counties: np.ndarray = None,
+ county_dependent_vars: set = None,
+ ) -> tuple:
+ """Assemble per-clone values from state/county precomputation.
+
+ For each target variable, selects values from either
+ county_values (if the var is county-dependent) or
+ state_values (otherwise) using numpy fancy indexing.
+
+ Args:
+ state_values: Output of _build_state_values.
+ clone_states: State FIPS per record for this clone.
+ person_hh_indices: Maps person index to household
+ index (0..n_records-1).
+ target_vars: Set of target variable names.
+ constraint_vars: Set of constraint variable names.
+ county_values: Output of _build_county_values.
+ clone_counties: County FIPS per record for this
+ clone (str array).
+ county_dependent_vars: Set of var names that should
+ be looked up by county instead of state.
+
+ Returns:
+ (hh_vars, person_vars) where hh_vars maps variable
+ name to household-level float32 array and person_vars
+ maps constraint variable name to person-level array.
+ """
+ n_records = len(clone_states)
+ n_persons = len(person_hh_indices)
+ person_states = clone_states[person_hh_indices]
+ unique_clone_states = np.unique(clone_states)
+ cdv = county_dependent_vars or set()
+
+ # Pre-compute masks to avoid recomputing per variable
+ state_masks = {int(s): clone_states == s for s in unique_clone_states}
+ unique_person_states = np.unique(person_states)
+ person_state_masks = {
+ int(s): person_states == s for s in unique_person_states
+ }
+ county_masks = {}
+ unique_counties = None
+ if clone_counties is not None and county_values:
+ unique_counties = np.unique(clone_counties)
+ county_masks = {c: clone_counties == c for c in unique_counties}
+
+ hh_vars = {}
+ for var in target_vars:
+ if var.endswith("_count"):
+ continue
+ if var in cdv and county_values and clone_counties is not None:
+ first_county = unique_counties[0]
+ if var not in county_values.get(first_county, {}).get(
+ "hh", {}
+ ):
+ continue
+ arr = np.empty(n_records, dtype=np.float32)
+ for county in unique_counties:
+ mask = county_masks[county]
+ county_hh = county_values.get(county, {}).get("hh", {})
+ if var in county_hh:
+ arr[mask] = county_hh[var][mask]
+ else:
+ st = int(county[:2])
+ arr[mask] = state_values[st]["hh"][var][mask]
+ hh_vars[var] = arr
+ else:
+ if var not in state_values[unique_clone_states[0]]["hh"]:
+ continue
+ arr = np.empty(n_records, dtype=np.float32)
+ for state in unique_clone_states:
+ mask = state_masks[int(state)]
+ arr[mask] = state_values[int(state)]["hh"][var][mask]
+ hh_vars[var] = arr
+
+ person_vars = {}
+ for var in constraint_vars:
+ if var not in state_values[unique_clone_states[0]]["person"]:
+ continue
+ arr = np.empty(n_persons, dtype=np.float32)
+ for state in unique_person_states:
+ mask = person_state_masks[int(state)]
+ arr[mask] = state_values[int(state)]["person"][var][mask]
+ person_vars[var] = arr
+
+ return hh_vars, person_vars
# ---------------------------------------------------------------
# Database queries
@@ -261,14 +1487,7 @@ def _get_uprating_info(
if period == self.time_period:
return 1.0, "none"
- count_indicators = [
- "count",
- "person",
- "people",
- "households",
- "tax_units",
- ]
- is_count = any(ind in variable.lower() for ind in count_indicators)
+ is_count = variable.endswith("_count")
uprating_type = "pop" if is_count else "cpi"
factor = factors.get((period, uprating_type), 1.0)
return factor, uprating_type
@@ -472,79 +1691,6 @@ def _make_target_name(
return "/".join(parts)
- # ---------------------------------------------------------------
- # Target value calculation
- # ---------------------------------------------------------------
-
- def _calculate_target_values(
- self,
- sim,
- target_variable: str,
- non_geo_constraints: List[dict],
- n_households: int,
- ) -> np.ndarray:
- """Calculate per-household target values.
-
- For count targets (*_count): count entities per HH
- satisfying constraints.
- For value targets: multiply values by constraint mask.
- """
- is_count = target_variable.endswith("_count")
-
- if not is_count:
- mask = self._evaluate_constraints_entity_aware(
- sim, non_geo_constraints, n_households
- )
- vals = sim.calculate(target_variable, map_to="household").values
- return (vals * mask).astype(np.float32)
-
- # Count target: entity-aware counting
- entity_rel = self._build_entity_relationship(sim)
- n_persons = len(entity_rel)
- person_mask = np.ones(n_persons, dtype=bool)
-
- for c in non_geo_constraints:
- try:
- cv = sim.calculate(c["variable"], map_to="person").values
- except Exception:
- return np.zeros(n_households, dtype=np.float32)
- person_mask &= apply_op(cv, c["operation"], c["value"])
-
- target_entity = sim.tax_benefit_system.variables[
- target_variable
- ].entity.key
- household_ids = sim.calculate(
- "household_id", map_to="household"
- ).values
-
- if target_entity == "household":
- if non_geo_constraints:
- mask = self._evaluate_constraints_entity_aware(
- sim, non_geo_constraints, n_households
- )
- return mask.astype(np.float32)
- return np.ones(n_households, dtype=np.float32)
-
- if target_entity == "person":
- er = entity_rel.copy()
- er["satisfies"] = person_mask
- filtered = er[er["satisfies"]]
- counts = filtered.groupby("household_id")["person_id"].nunique()
- else:
- eid_col = f"{target_entity}_id"
- er = entity_rel.copy()
- er["satisfies"] = person_mask
- entity_ok = er.groupby(eid_col)["satisfies"].any()
- unique = er[["household_id", eid_col]].drop_duplicates()
- unique["entity_ok"] = unique[eid_col].map(entity_ok)
- filtered = unique[unique["entity_ok"]]
- counts = filtered.groupby("household_id")[eid_col].nunique()
-
- return np.array(
- [counts.get(hid, 0) for hid in household_ids],
- dtype=np.float32,
- )
-
# ---------------------------------------------------------------
# Clone simulation
# ---------------------------------------------------------------
@@ -614,6 +1760,9 @@ def build_matrix(
hierarchical_domains: Optional[List[str]] = None,
cache_dir: Optional[str] = None,
sim_modifier=None,
+ rerandomize_takeup: bool = True,
+ county_level: bool = True,
+ workers: int = 1,
) -> Tuple[pd.DataFrame, sparse.csr_matrix, List[str]]:
"""Build sparse calibration matrix.
@@ -635,6 +1784,13 @@ def build_matrix(
called per clone after state_fips is set but
before cache clearing. Use for takeup
re-randomization.
+ rerandomize_takeup: If True, use geo-salted
+ entity-level takeup draws instead of base h5
+ takeup values for takeup-affected targets.
+ county_level: If True (default), iterate counties
+ within each state during precomputation. If
+ False, compute once per state and alias to all
+ counties (faster for county-invariant vars).
Returns:
(targets_df, X_sparse, target_names)
@@ -720,156 +1876,457 @@ def build_matrix(
unique_variables = set(targets_df["variable"].values)
- # 5. Clone loop
+ # 5a. Collect unique constraint variables
+ unique_constraint_vars = set()
+ for constraints in non_geo_constraints_list:
+ for c in constraints:
+ unique_constraint_vars.add(c["variable"])
+
+ # 5b. Per-state precomputation (51 sims on one object)
+ self._entity_rel_cache = None
+ state_values = self._build_state_values(
+ sim,
+ unique_variables,
+ unique_constraint_vars,
+ geography,
+ rerandomize_takeup=rerandomize_takeup,
+ workers=workers,
+ )
+
+ # 5b-county. Per-county precomputation for county-dependent vars
+ county_dep_targets = unique_variables & COUNTY_DEPENDENT_VARS
+ county_values = self._build_county_values(
+ sim,
+ county_dep_targets,
+ geography,
+ rerandomize_takeup=rerandomize_takeup,
+ county_level=county_level,
+ workers=workers,
+ )
+
+ # 5c. State-independent structures (computed once)
+ entity_rel = self._build_entity_relationship(sim)
+ household_ids = sim.calculate(
+ "household_id", map_to="household"
+ ).values
+ person_hh_ids = sim.calculate("household_id", map_to="person").values
+ hh_id_to_idx = {int(hid): idx for idx, hid in enumerate(household_ids)}
+ person_hh_indices = np.array(
+ [hh_id_to_idx[int(hid)] for hid in person_hh_ids]
+ )
+ tax_benefit_system = sim.tax_benefit_system
+
+ # Pre-extract entity keys so workers don't need
+ # the unpicklable TaxBenefitSystem object.
+ variable_entity_map: Dict[str, str] = {}
+ for var in unique_variables:
+ if var.endswith("_count") and var in tax_benefit_system.variables:
+ variable_entity_map[var] = tax_benefit_system.variables[
+ var
+ ].entity.key
+
+ # 5c-extra: Entity-to-household index maps for takeup
+ affected_target_info = {}
+ if rerandomize_takeup:
+ from policyengine_us_data.utils.takeup import (
+ TAKEUP_AFFECTED_TARGETS,
+ compute_block_takeup_for_entities,
+ )
+ from policyengine_us_data.parameters import (
+ load_take_up_rate,
+ )
+
+ # Build entity-to-household index arrays
+ spm_to_hh_id = (
+ entity_rel.groupby("spm_unit_id")["household_id"]
+ .first()
+ .to_dict()
+ )
+ spm_ids = sim.calculate("spm_unit_id", map_to="spm_unit").values
+ spm_hh_idx = np.array(
+ [hh_id_to_idx[int(spm_to_hh_id[int(sid)])] for sid in spm_ids]
+ )
+
+ tu_to_hh_id = (
+ entity_rel.groupby("tax_unit_id")["household_id"]
+ .first()
+ .to_dict()
+ )
+ tu_ids = sim.calculate("tax_unit_id", map_to="tax_unit").values
+ tu_hh_idx = np.array(
+ [hh_id_to_idx[int(tu_to_hh_id[int(tid)])] for tid in tu_ids]
+ )
+
+ entity_hh_idx_map = {
+ "spm_unit": spm_hh_idx,
+ "tax_unit": tu_hh_idx,
+ "person": person_hh_indices,
+ }
+
+ entity_to_person_idx = {}
+ for entity_level in ("spm_unit", "tax_unit"):
+ ent_ids = sim.calculate(
+ f"{entity_level}_id",
+ map_to=entity_level,
+ ).values
+ ent_id_to_idx = {
+ int(eid): idx for idx, eid in enumerate(ent_ids)
+ }
+ person_ent_ids = entity_rel[f"{entity_level}_id"].values
+ entity_to_person_idx[entity_level] = np.array(
+ [ent_id_to_idx[int(eid)] for eid in person_ent_ids]
+ )
+ entity_to_person_idx["person"] = np.arange(len(entity_rel))
+
+ for tvar in unique_variables:
+ for key, info in TAKEUP_AFFECTED_TARGETS.items():
+ if tvar == key:
+ affected_target_info[tvar] = info
+ break
+
+ logger.info(
+ "Block-level takeup enabled, " "%d affected target vars",
+ len(affected_target_info),
+ )
+
+ # Pre-compute takeup rates (constant across clones)
+ precomputed_rates = {}
+ for tvar, info in affected_target_info.items():
+ rk = info["rate_key"]
+ if rk not in precomputed_rates:
+ precomputed_rates[rk] = load_take_up_rate(
+ rk, self.time_period
+ )
+
+ # Store for post-optimization stacked takeup
+ self.entity_hh_idx_map = entity_hh_idx_map
+ self.household_ids = household_ids
+ self.precomputed_rates = precomputed_rates
+ self.affected_target_info = affected_target_info
+
+ # 5d. Clone loop
from pathlib import Path
- clone_dir = Path(cache_dir) if cache_dir else None
- if clone_dir:
+ if workers > 1:
+ # ---- Parallel clone processing ----
+ import concurrent.futures
+ import tempfile
+
+ if cache_dir:
+ clone_dir = Path(cache_dir)
+ else:
+ clone_dir = Path(tempfile.mkdtemp(prefix="clone_coo_"))
clone_dir.mkdir(parents=True, exist_ok=True)
- self._entity_rel_cache = None
+ target_variables = [
+ str(targets_df.iloc[i]["variable"]) for i in range(n_targets)
+ ]
- for clone_idx in range(n_clones):
+ shared_data = {
+ "geography_state_fips": geography.state_fips,
+ "geography_county_fips": geography.county_fips,
+ "geography_block_geoid": geography.block_geoid,
+ "state_values": state_values,
+ "county_values": county_values,
+ "person_hh_indices": person_hh_indices,
+ "unique_variables": unique_variables,
+ "unique_constraint_vars": unique_constraint_vars,
+ "county_dep_targets": county_dep_targets,
+ "target_variables": target_variables,
+ "target_geo_info": target_geo_info,
+ "non_geo_constraints_list": (non_geo_constraints_list),
+ "n_records": n_records,
+ "n_total": n_total,
+ "n_targets": n_targets,
+ "state_to_cols": state_to_cols,
+ "cd_to_cols": cd_to_cols,
+ "entity_rel": entity_rel,
+ "household_ids": household_ids,
+ "variable_entity_map": variable_entity_map,
+ "rerandomize_takeup": rerandomize_takeup,
+ "affected_target_info": affected_target_info,
+ }
+ if rerandomize_takeup and affected_target_info:
+ shared_data["entity_hh_idx_map"] = entity_hh_idx_map
+ shared_data["entity_to_person_idx"] = entity_to_person_idx
+ shared_data["precomputed_rates"] = precomputed_rates
+
+ logger.info(
+ "Starting parallel clone processing: " "%d clones, %d workers",
+ n_clones,
+ workers,
+ )
+
+ futures: dict = {}
+ with concurrent.futures.ProcessPoolExecutor(
+ max_workers=workers,
+ initializer=_init_clone_worker,
+ initargs=(shared_data,),
+ ) as pool:
+ for ci in range(n_clones):
+ coo_path = str(clone_dir / f"clone_{ci:04d}.npz")
+ if Path(coo_path).exists():
+ logger.info(
+ "Clone %d/%d cached.",
+ ci + 1,
+ n_clones,
+ )
+ continue
+ cs = ci * n_records
+ ce = cs + n_records
+ fut = pool.submit(
+ _process_single_clone,
+ ci,
+ cs,
+ ce,
+ coo_path,
+ )
+ futures[fut] = ci
+
+ for fut in concurrent.futures.as_completed(futures):
+ ci = futures[fut]
+ try:
+ _, nnz = fut.result()
+ if (ci + 1) % 50 == 0:
+ logger.info(
+ "Clone %d/%d done " "(%d nnz).",
+ ci + 1,
+ n_clones,
+ nnz,
+ )
+ except Exception as exc:
+ for f in futures:
+ f.cancel()
+ raise RuntimeError(
+ f"Clone {ci} failed: {exc}"
+ ) from exc
+
+ else:
+ # ---- Sequential clone processing (unchanged) ----
+ clone_dir = Path(cache_dir) if cache_dir else None
if clone_dir:
- coo_path = clone_dir / f"clone_{clone_idx:04d}.npz"
- if coo_path.exists():
+ clone_dir.mkdir(parents=True, exist_ok=True)
+
+ for clone_idx in range(n_clones):
+ if clone_dir:
+ coo_path = clone_dir / f"clone_{clone_idx:04d}.npz"
+ if coo_path.exists():
+ logger.info(
+ "Clone %d/%d cached, " "skipping.",
+ clone_idx + 1,
+ n_clones,
+ )
+ continue
+
+ col_start = clone_idx * n_records
+ col_end = col_start + n_records
+ clone_states = geography.state_fips[col_start:col_end]
+ clone_counties = geography.county_fips[col_start:col_end]
+
+ if (clone_idx + 1) % 50 == 0 or clone_idx == 0:
logger.info(
- "Clone %d/%d cached, skipping.",
+ "Assembling clone %d/%d "
+ "(cols %d-%d, "
+ "%d unique states)...",
clone_idx + 1,
n_clones,
+ col_start,
+ col_end - 1,
+ len(np.unique(clone_states)),
)
- continue
- col_start = clone_idx * n_records
- col_end = col_start + n_records
- clone_states = geography.state_fips[col_start:col_end]
+ hh_vars, person_vars = self._assemble_clone_values(
+ state_values,
+ clone_states,
+ person_hh_indices,
+ unique_variables,
+ unique_constraint_vars,
+ county_values=county_values,
+ clone_counties=clone_counties,
+ county_dependent_vars=(county_dep_targets),
+ )
- logger.info(
- "Processing clone %d/%d " "(cols %d-%d, %d unique states)...",
- clone_idx + 1,
- n_clones,
- col_start,
- col_end - 1,
- len(np.unique(clone_states)),
- )
+ # Apply geo-specific entity-level takeup
+ # for affected target variables
+ if rerandomize_takeup and affected_target_info:
+ clone_blocks = geography.block_geoid[col_start:col_end]
+ for (
+ tvar,
+ info,
+ ) in affected_target_info.items():
+ if tvar.endswith("_count"):
+ continue
+ entity_level = info["entity"]
+ takeup_var = info["takeup_var"]
+ ent_hh = entity_hh_idx_map[entity_level]
+ n_ent = len(ent_hh)
+
+ ent_states = clone_states[ent_hh]
+
+ ent_eligible = np.zeros(n_ent, dtype=np.float32)
+ if tvar in county_dep_targets and county_values:
+ ent_counties = clone_counties[ent_hh]
+ for cfips in np.unique(ent_counties):
+ m = ent_counties == cfips
+ cv = county_values.get(cfips, {}).get(
+ "entity", {}
+ )
+ if tvar in cv:
+ ent_eligible[m] = cv[tvar][m]
+ else:
+ st = int(cfips[:2])
+ sv = state_values[st]["entity"]
+ if tvar in sv:
+ ent_eligible[m] = sv[tvar][m]
+ else:
+ for st in np.unique(ent_states):
+ m = ent_states == st
+ sv = state_values[int(st)]["entity"]
+ if tvar in sv:
+ ent_eligible[m] = sv[tvar][m]
+
+ ent_blocks = clone_blocks[ent_hh]
+ ent_hh_ids = household_ids[ent_hh]
+
+ ent_takeup = compute_block_takeup_for_entities(
+ takeup_var,
+ precomputed_rates[info["rate_key"]],
+ ent_blocks,
+ ent_hh_ids,
+ )
- var_values, clone_sim = self._simulate_clone(
- clone_states,
- n_records,
- unique_variables,
- sim_modifier=sim_modifier,
- clone_idx=clone_idx,
- )
+ ent_values = (ent_eligible * ent_takeup).astype(
+ np.float32
+ )
- mask_cache: Dict[tuple, np.ndarray] = {}
- count_cache: Dict[tuple, np.ndarray] = {}
+ hh_result = np.zeros(n_records, dtype=np.float32)
+ np.add.at(hh_result, ent_hh, ent_values)
+ hh_vars[tvar] = hh_result
- rows_list: list = []
- cols_list: list = []
- vals_list: list = []
+ if tvar in person_vars:
+ pidx = entity_to_person_idx[entity_level]
+ person_vars[tvar] = ent_values[pidx]
- for row_idx in range(n_targets):
- variable = str(targets_df.iloc[row_idx]["variable"])
- geo_level, geo_id = target_geo_info[row_idx]
- non_geo = non_geo_constraints_list[row_idx]
+ mask_cache: Dict[tuple, np.ndarray] = {}
+ count_cache: Dict[tuple, np.ndarray] = {}
- # Geographic column selection
- if geo_level == "district":
- all_geo_cols = cd_to_cols.get(
- str(geo_id),
- np.array([], dtype=np.int64),
- )
- elif geo_level == "state":
- all_geo_cols = state_to_cols.get(
- int(geo_id),
- np.array([], dtype=np.int64),
- )
- else:
- all_geo_cols = np.arange(n_total)
+ rows_list: list = []
+ cols_list: list = []
+ vals_list: list = []
- clone_cols = all_geo_cols[
- (all_geo_cols >= col_start) & (all_geo_cols < col_end)
- ]
- if len(clone_cols) == 0:
- continue
+ for row_idx in range(n_targets):
+ variable = str(targets_df.iloc[row_idx]["variable"])
+ geo_level, geo_id = target_geo_info[row_idx]
+ non_geo = non_geo_constraints_list[row_idx]
+
+ if geo_level == "district":
+ all_geo_cols = cd_to_cols.get(
+ str(geo_id),
+ np.array([], dtype=np.int64),
+ )
+ elif geo_level == "state":
+ all_geo_cols = state_to_cols.get(
+ int(geo_id),
+ np.array([], dtype=np.int64),
+ )
+ else:
+ all_geo_cols = np.arange(n_total)
+
+ clone_cols = all_geo_cols[
+ (all_geo_cols >= col_start) & (all_geo_cols < col_end)
+ ]
+ if len(clone_cols) == 0:
+ continue
- rec_indices = clone_cols - col_start
+ rec_indices = clone_cols - col_start
- constraint_key = tuple(
- sorted(
- (
- c["variable"],
- c["operation"],
- c["value"],
+ constraint_key = tuple(
+ sorted(
+ (
+ c["variable"],
+ c["operation"],
+ c["value"],
+ )
+ for c in non_geo
)
- for c in non_geo
)
- )
- if variable.endswith("_count"):
- vkey = (variable, constraint_key)
- if vkey not in count_cache:
- count_cache[vkey] = self._calculate_target_values(
- clone_sim,
+ if variable.endswith("_count"):
+ vkey = (
variable,
- non_geo,
- n_records,
+ constraint_key,
)
- values = count_cache[vkey]
- else:
- if variable not in var_values:
- continue
- if constraint_key not in mask_cache:
- mask_cache[constraint_key] = (
- self._evaluate_constraints_entity_aware(
- clone_sim,
- non_geo,
- n_records,
+ if vkey not in count_cache:
+ count_cache[vkey] = (
+ _calculate_target_values_standalone(
+ target_variable=variable,
+ non_geo_constraints=non_geo,
+ n_households=n_records,
+ hh_vars=hh_vars,
+ person_vars=person_vars,
+ entity_rel=entity_rel,
+ household_ids=household_ids,
+ variable_entity_map=variable_entity_map,
+ )
+ )
+ values = count_cache[vkey]
+ else:
+ if variable not in hh_vars:
+ continue
+ if constraint_key not in mask_cache:
+ mask_cache[constraint_key] = (
+ _evaluate_constraints_standalone(
+ non_geo,
+ person_vars,
+ entity_rel,
+ household_ids,
+ n_records,
+ )
+ )
+ mask = mask_cache[constraint_key]
+ values = hh_vars[variable] * mask
+
+ vals = values[rec_indices]
+ nonzero = vals != 0
+ if nonzero.any():
+ rows_list.append(
+ np.full(
+ nonzero.sum(),
+ row_idx,
+ dtype=np.int32,
)
)
- mask = mask_cache[constraint_key]
- values = var_values[variable] * mask
-
- vals = values[rec_indices]
- nonzero = vals != 0
- if nonzero.any():
- rows_list.append(
- np.full(
- nonzero.sum(),
- row_idx,
- dtype=np.int32,
- )
+ cols_list.append(clone_cols[nonzero].astype(np.int32))
+ vals_list.append(vals[nonzero])
+
+ # Save COO entries
+ if rows_list:
+ cr = np.concatenate(rows_list)
+ cc = np.concatenate(cols_list)
+ cv = np.concatenate(vals_list)
+ else:
+ cr = np.array([], dtype=np.int32)
+ cc = np.array([], dtype=np.int32)
+ cv = np.array([], dtype=np.float32)
+
+ if clone_dir:
+ np.savez_compressed(
+ str(coo_path),
+ rows=cr,
+ cols=cc,
+ vals=cv,
)
- cols_list.append(clone_cols[nonzero].astype(np.int32))
- vals_list.append(vals[nonzero])
-
- # Save COO entries
- if rows_list:
- cr = np.concatenate(rows_list)
- cc = np.concatenate(cols_list)
- cv = np.concatenate(vals_list)
- else:
- cr = np.array([], dtype=np.int32)
- cc = np.array([], dtype=np.int32)
- cv = np.array([], dtype=np.float32)
-
- if clone_dir:
- np.savez_compressed(
- str(coo_path),
- rows=cr,
- cols=cc,
- vals=cv,
- )
- logger.info(
- "Clone %d: %d nonzero entries saved.",
- clone_idx + 1,
- len(cv),
- )
- del var_values, clone_sim
- else:
- self._coo_parts[0].append(cr)
- self._coo_parts[1].append(cc)
- self._coo_parts[2].append(cv)
+ if (clone_idx + 1) % 50 == 0:
+ logger.info(
+ "Clone %d: %d nonzero " "entries saved.",
+ clone_idx + 1,
+ len(cv),
+ )
+ del hh_vars, person_vars
+ else:
+ self._coo_parts[0].append(cr)
+ self._coo_parts[1].append(cc)
+ self._coo_parts[2].append(cv)
# 6. Assemble sparse matrix from COO data
logger.info("Assembling matrix from %d clones...", n_clones)
diff --git a/policyengine_us_data/calibration/validate_national_h5.py b/policyengine_us_data/calibration/validate_national_h5.py
new file mode 100644
index 000000000..c63632851
--- /dev/null
+++ b/policyengine_us_data/calibration/validate_national_h5.py
@@ -0,0 +1,158 @@
+"""Validate a national US.h5 file against reference values.
+
+Loads the national H5, computes key variables, and compares to
+known national totals. Also runs structural sanity checks.
+
+Usage:
+ python -m policyengine_us_data.calibration.validate_national_h5
+ python -m policyengine_us_data.calibration.validate_national_h5 \
+ --h5-path path/to/US.h5
+ python -m policyengine_us_data.calibration.validate_national_h5 \
+ --hf-path hf://policyengine/policyengine-us-data/national/US.h5
+"""
+
+import argparse
+
+# Variables that main() totals over the whole file and prints, in this order.
+VARIABLES = [
+ "adjusted_gross_income",
+ "employment_income",
+ "self_employment_income",
+ "tax_unit_partnership_s_corp_income",
+ "taxable_pension_income",
+ "dividend_income",
+ "net_capital_gains",
+ "rental_income",
+ "taxable_interest_income",
+ "social_security",
+ "snap",
+ "ssi",
+ "income_tax_before_credits",
+ "eitc",
+ "refundable_ctc",
+ "real_estate_taxes",
+ "rent",
+ "is_pregnant",
+ "person_count",
+ "household_count",
+]
+
+# Reference totals used by main(): variable -> (reference value, label shown
+# in the report). Only variables listed here are compared; main() marks a
+# total whose deviation from the reference exceeds 30%.
+REFERENCES = {
+ "person_count": (335_000_000, "~335M"),
+ "household_count": (130_000_000, "~130M"),
+ "adjusted_gross_income": (15_000_000_000_000, "~$15T"),
+ "employment_income": (10_000_000_000_000, "~$10T"),
+ "social_security": (1_200_000_000_000, "~$1.2T"),
+ "snap": (110_000_000_000, "~$110B"),
+ "ssi": (60_000_000_000, "~$60B"),
+ "eitc": (60_000_000_000, "~$60B"),
+ "refundable_ctc": (120_000_000_000, "~$120B"),
+ "income_tax_before_credits": (4_000_000_000_000, "~$4T"),
+}
+
+# Dataset location used by main() when --h5-path is not supplied.
+DEFAULT_HF_PATH = "hf://policyengine/policyengine-us-data/national/US.h5"
+
+# Variables that main() prints without a leading "$".
+COUNT_VARS = {"person_count", "household_count", "is_pregnant"}
+
+
+def _print_section(title):
+    """Print ``title`` framed by two 70-character rules, after a blank line."""
+    print("\n" + "=" * 70)
+    print(title)
+    print("=" * 70)
+
+
+def main(argv=None):
+    """Validate a national H5 file and return a process exit code.
+
+    Loads the dataset, prints the total of every entry in VARIABLES,
+    compares the totals that have an entry in REFERENCES, and finally
+    prints the results of run_sanity_checks().
+
+    Args:
+        argv: Argument list handed to argparse; ``None`` lets argparse
+            read the process arguments.
+
+    Returns:
+        0 when no sanity check reports "FAIL" and every variable could be
+        calculated, otherwise 1. Deviations from the reference values are
+        only marked in the output; they do not change the return value.
+    """
+    cli = argparse.ArgumentParser(description="Validate national US.h5")
+    cli.add_argument("--h5-path", default=None, help="Local path to US.h5")
+    cli.add_argument(
+        "--hf-path",
+        default=DEFAULT_HF_PATH,
+        help=f"HF path to US.h5 (default: {DEFAULT_HF_PATH})",
+    )
+    options = cli.parse_args(argv)
+
+    # A local file, when supplied, wins over the Hugging Face location.
+    dataset_path = options.h5_path if options.h5_path else options.hf_path
+
+    # Imported here rather than at module level, so arguments are parsed
+    # before the simulation package is loaded.
+    from policyengine_us import Microsimulation
+
+    print(f"Loading {dataset_path}...")
+    sim = Microsimulation(dataset=dataset_path)
+
+    n_hh = sim.calculate("household_id", map_to="household").shape[0]
+    print(f"Households in file: {n_hh:,}")
+
+    _print_section("NATIONAL H5 VALUES")
+
+    values = {}
+    failures = []
+    for var in VARIABLES:
+        try:
+            val = float(sim.calculate(var).sum())
+            values[var] = val
+            if var not in COUNT_VARS:
+                print(f" {var:45s} ${val:>15,.0f}")
+            else:
+                print(f" {var:45s} {val:>15,.0f}")
+        except Exception as e:
+            # Best effort: remember the failure and keep going so one bad
+            # variable does not hide the rest of the report.
+            failures.append((var, str(e)))
+            print(f" {var:45s} FAILED: {e}")
+
+    _print_section("COMPARISON TO REFERENCE VALUES")
+
+    any_flag = False
+    for var, (ref_val, ref_label) in REFERENCES.items():
+        val = values.get(var)
+        if val is None:
+            # The variable failed above; nothing to compare.
+            continue
+        pct_diff = (val - ref_val) / ref_val * 100
+        exceeds = abs(pct_diff) > 30
+        any_flag = any_flag or exceeds
+        flag = " ***" if exceeds else ""
+        if var not in COUNT_VARS:
+            print(
+                f" {var:35s} ${val:>15,.0f} "
+                f"ref {ref_label:>8s} "
+                f"({pct_diff:+.1f}%){flag}"
+            )
+        else:
+            print(
+                f" {var:35s} {val:>15,.0f} "
+                f"ref {ref_label:>8s} "
+                f"({pct_diff:+.1f}%){flag}"
+            )
+
+    if any_flag:
+        print("\n*** = >30% deviation from reference. " "Investigate further.")
+
+    if failures:
+        print(f"\n{len(failures)} variables failed:")
+        for var, err in failures:
+            print(f" {var}: {err}")
+
+    _print_section("STRUCTURAL CHECKS")
+
+    from policyengine_us_data.calibration.sanity_checks import (
+        run_sanity_checks,
+    )
+
+    results = run_sanity_checks(dataset_path)
+    # Tally first, print afterwards.
+    statuses = [r["status"] for r in results]
+    n_pass = statuses.count("PASS")
+    n_fail = statuses.count("FAIL")
+    icons = {"PASS": "PASS", "FAIL": "FAIL"}
+    for r, status in zip(results, statuses):
+        # Any status other than PASS / FAIL is shown as a warning.
+        icon = icons.get(status, "WARN")
+        print(f" [{icon}] {r['check']}: {r['detail']}")
+
+    print(f"\n{n_pass}/{len(results)} passed, {n_fail} failed")
+
+    if n_fail != 0 or failures:
+        return 1
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/policyengine_us_data/calibration/validate_package.py b/policyengine_us_data/calibration/validate_package.py
new file mode 100644
index 000000000..bd6862f03
--- /dev/null
+++ b/policyengine_us_data/calibration/validate_package.py
@@ -0,0 +1,342 @@
+"""
+Validate a calibration package before uploading to Modal.
+
+Usage:
+ python -m policyengine_us_data.calibration.validate_package [path]
+ [--n-hardest N] [--strict [RATIO]]
+"""
+
+import argparse
+import sys
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import pandas as pd
+
+
+@dataclass
+class ValidationResult:
+    """Outcome of validate_package(): matrix statistics and target reports."""
+
+    # Rows and columns of the calibration matrix (X_sparse.shape).
+    n_targets: int
+    n_columns: int
+    # Stored entries of the matrix (X_sparse.nnz) and their share of all
+    # matrix cells (0 when the matrix has no cells).
+    nnz: int
+    density: float
+    # The package's "metadata" mapping ({} when the package has none).
+    metadata: dict
+    # Number of targets whose matrix row sums to > 0, and the remainder.
+    n_achievable: int
+    n_impossible: int
+    # One row per impossible target (name, identifying columns, value).
+    impossible_targets: pd.DataFrame
+    # Impossible-target counts per (domain_variable, variable, geo_level),
+    # largest count first.
+    impossible_by_group: pd.DataFrame
+    # Achievable targets with the lowest row_sum / target_value ratio,
+    # in ascending ratio order.
+    hardest_targets: pd.DataFrame
+    # total / ok / impossible target counts per
+    # (domain_variable, variable, geo_level).
+    group_summary: pd.DataFrame
+    # Threshold of the strict check (None when it was not requested) and
+    # the number of achievable targets whose ratio fell below it.
+    strict_ratio: Optional[float] = None
+    strict_failures: int = 0
+
+
+def validate_package(
+    package: dict,
+    n_hardest: int = 10,
+    strict_ratio: Optional[float] = None,
+) -> ValidationResult:
+    """Check which calibration targets the matrix can reach at all.
+
+    A target is "achievable" when its matrix row has a positive sum and
+    "impossible" otherwise. For achievable targets the ratio
+    ``row_sum / target_value`` is computed (``inf`` when the target value
+    is zero) and the lowest ratios are reported as the hardest targets.
+
+    Args:
+        package: Mapping with ``X_sparse`` (targets x columns matrix
+            exposing ``shape``, ``nnz`` and ``sum(axis=1)``),
+            ``targets_df`` (one row per target with ``domain_variable``,
+            ``variable``, ``geo_level``, ``geographic_id`` and ``value``
+            columns), ``target_names`` and, optionally, ``metadata``.
+        n_hardest: Maximum number of lowest-ratio achievable targets to
+            list. Values larger than the number of achievable targets
+            list all of them; values below zero are treated as zero.
+        strict_ratio: When given, count the achievable targets whose
+            ratio is below this threshold.
+
+    Returns:
+        A ValidationResult with matrix statistics, the impossible and
+        hardest targets, and a per-group achievability summary.
+    """
+    X_sparse = package["X_sparse"]
+    targets_df = package["targets_df"]
+    target_names = package["target_names"]
+    metadata = package.get("metadata", {})
+
+    n_targets, n_columns = X_sparse.shape
+    nnz = X_sparse.nnz
+    n_cells = n_targets * n_columns
+    density = nnz / n_cells if n_cells else 0
+
+    row_sums = np.array(X_sparse.sum(axis=1)).flatten()
+    achievable_mask = row_sums > 0
+    n_achievable = int(achievable_mask.sum())
+    n_impossible = n_targets - n_achievable
+
+    impossible_idx = np.where(~achievable_mask)[0]
+    impossible_rows = targets_df.iloc[impossible_idx]
+    impossible_targets = pd.DataFrame(
+        {
+            "target_name": [target_names[i] for i in impossible_idx],
+            "domain_variable": impossible_rows["domain_variable"].values,
+            "variable": impossible_rows["variable"].values,
+            "geo_level": impossible_rows["geo_level"].values,
+            "geographic_id": impossible_rows["geographic_id"].values,
+            "target_value": impossible_rows["value"].values,
+        }
+    )
+    impossible_by_group = (
+        impossible_rows.groupby(["domain_variable", "variable", "geo_level"])
+        .size()
+        .reset_index(name="count")
+        .sort_values("count", ascending=False)
+        .reset_index(drop=True)
+    )
+
+    target_values = targets_df["value"].values
+    achievable_idx = np.where(achievable_mask)[0]
+    if len(achievable_idx) > 0:
+        a_row_sums = row_sums[achievable_idx]
+        a_target_vals = target_values[achievable_idx]
+        # The division is evaluated for every element before np.where
+        # selects, so silence the warnings for zero-valued targets.
+        with np.errstate(divide="ignore", invalid="ignore"):
+            ratios = np.where(
+                a_target_vals != 0,
+                a_row_sums / a_target_vals,
+                np.inf,
+            )
+        # Clamp k to [0, len(ratios)] and take the k smallest ratios with a
+        # stable sort. np.argpartition(ratios, k) is not usable here: it
+        # raises ValueError when k == len(ratios), i.e. whenever n_hardest
+        # is at least the number of achievable targets.
+        k = max(0, min(n_hardest, len(ratios)))
+        hardest_local_idx = np.argsort(ratios, kind="stable")[:k]
+        hardest_global_idx = achievable_idx[hardest_local_idx]
+        hardest_rows = targets_df.iloc[hardest_global_idx]
+
+        hardest_targets = pd.DataFrame(
+            {
+                "target_name": [target_names[i] for i in hardest_global_idx],
+                "domain_variable": hardest_rows["domain_variable"].values,
+                "variable": hardest_rows["variable"].values,
+                "geographic_id": hardest_rows["geographic_id"].values,
+                "ratio": ratios[hardest_local_idx],
+                "row_sum": a_row_sums[hardest_local_idx],
+                "target_value": a_target_vals[hardest_local_idx],
+            }
+        )
+    else:
+        # No achievable target: keep the column layout, with no rows.
+        hardest_targets = pd.DataFrame(
+            columns=[
+                "target_name",
+                "domain_variable",
+                "variable",
+                "geographic_id",
+                "ratio",
+                "row_sum",
+                "target_value",
+            ]
+        )
+
+    group_summary = (
+        targets_df.assign(achievable=achievable_mask)
+        .groupby(["domain_variable", "variable", "geo_level"])
+        .agg(total=("value", "size"), ok=("achievable", "sum"))
+        .reset_index()
+    )
+    group_summary["impossible"] = group_summary["total"] - group_summary["ok"]
+    group_summary["ok"] = group_summary["ok"].astype(int)
+    group_summary = group_summary.sort_values(
+        ["domain_variable", "variable", "geo_level"]
+    ).reset_index(drop=True)
+
+    strict_failures = 0
+    # `ratios` only exists when there is at least one achievable target.
+    if strict_ratio is not None and len(achievable_idx) > 0:
+        strict_failures = int((ratios < strict_ratio).sum())
+
+    return ValidationResult(
+        n_targets=n_targets,
+        n_columns=n_columns,
+        nnz=nnz,
+        density=density,
+        metadata=metadata,
+        n_achievable=n_achievable,
+        n_impossible=n_impossible,
+        impossible_targets=impossible_targets,
+        impossible_by_group=impossible_by_group,
+        hardest_targets=hardest_targets,
+        group_summary=group_summary,
+        strict_ratio=strict_ratio,
+        strict_failures=strict_failures,
+    )
+
+
+def format_report(result: ValidationResult, package_path: str = None) -> str:
+    """Build the human-readable validation report for ``result``.
+
+    Sections, in order: provenance header (package path plus whichever
+    metadata keys are set), matrix size, achievability percentages,
+    impossible targets, impossible targets by group (only when there is
+    more than one group), hardest achievable targets, the group summary,
+    the optional strict check, and a final RESULT line.
+
+    Args:
+        result: Validation outcome to render.
+        package_path: Path shown on the "Package:" line; the line is
+            omitted when this is falsy.
+
+    Returns:
+        The report as a single newline-joined string.
+    """
+    meta = result.metadata
+    report = ["", "=== Calibration Package Validation ===", ""]
+    add = report.append
+
+    # Provenance header: every line is optional.
+    if package_path:
+        add(f"Package: {package_path}")
+    if meta.get("created_at"):
+        add(f"Created: {meta['created_at']}")
+    if meta.get("dataset_path"):
+        add(f"Dataset: {meta['dataset_path']}")
+    if meta.get("git_branch") or meta.get("git_commit"):
+        commit = meta.get("git_commit", "")
+        branch = meta.get("git_branch", "unknown")
+        commit_short = "unknown" if not commit else commit[:8]
+        dirty = "" if not meta.get("git_dirty") else " (DIRTY)"
+        add(f"Git: {branch} @ {commit_short}{dirty}")
+    if meta.get("package_version"):
+        add(f"Version: {meta['package_version']}")
+    if meta.get("dataset_sha256"):
+        add(f"Dataset SHA: {meta['dataset_sha256'][:12]}")
+    if meta.get("db_sha256"):
+        add(f"DB SHA: {meta['db_sha256'][:12]}")
+    add("")
+
+    # Matrix size and, when recorded, how the matrix was built.
+    add(
+        f"Matrix: {result.n_targets:,} targets"
+        f" x {result.n_columns:,} columns"
+    )
+    add(f"Non-zero: {result.nnz:,} (density: {result.density:.6f})")
+    if meta.get("n_clones"):
+        parts = [f"Clones: {meta['n_clones']}"]
+        if meta.get("n_records"):
+            parts.append(f"Records: {meta['n_records']:,}")
+        # A seed of 0 is a real seed, hence the explicit None test.
+        if meta.get("seed") is not None:
+            parts.append(f"Seed: {meta['seed']}")
+        add(", ".join(parts))
+    add("")
+
+    if result.n_targets:
+        pct = 100 * result.n_achievable / result.n_targets
+    else:
+        pct = 0
+    pct_imp = 100 - pct
+    report.extend(
+        [
+            "--- Achievability ---",
+            f"Achievable: {result.n_achievable:>6,}"
+            f" / {result.n_targets:,} ({pct:.1f}%)",
+            f"Impossible: {result.n_impossible:>6,}"
+            f" / {result.n_targets:,} ({pct_imp:.1f}%)",
+            "",
+        ]
+    )
+
+    if len(result.impossible_targets) > 0:
+        add("--- Impossible Targets ---")
+        report.extend(
+            f" {row['target_name']:<60s}" f" {row['target_value']:>14,.0f}"
+            for _, row in result.impossible_targets.iterrows()
+        )
+        add("")
+
+    # The per-group breakdown is printed only for two or more groups.
+    if len(result.impossible_by_group) > 1:
+        add("--- Impossible Targets by Group ---")
+        report.extend(
+            f" {row['domain_variable']:<20s}"
+            f" {row['variable']:<25s}"
+            f" {row['geo_level']:<12s}"
+            f" {row['count']:>5d}"
+            for _, row in result.impossible_by_group.iterrows()
+        )
+        add("")
+
+    if len(result.hardest_targets) > 0:
+        n = len(result.hardest_targets)
+        add(f"--- Hardest Achievable Targets" f" ({n} lowest ratio) ---")
+        report.extend(
+            f" {row['target_name']:<50s}"
+            f" {row['ratio']:>10.4f}"
+            f" {row['row_sum']:>14,.0f}"
+            f" {row['target_value']:>14,.0f}"
+            for _, row in result.hardest_targets.iterrows()
+        )
+        add("")
+
+    if len(result.group_summary) > 0:
+        add("--- Group Summary ---")
+        add(
+            f" {'domain':<20s} {'variable':<25s}"
+            f" {'geo_level':<12s}"
+            f" {'total':>6s} {'ok':>6s} {'impossible':>10s}"
+        )
+        report.extend(
+            f" {row['domain_variable']:<20s}"
+            f" {row['variable']:<25s}"
+            f" {row['geo_level']:<12s}"
+            f" {row['total']:>6d}"
+            f" {row['ok']:>6d}"
+            f" {row['impossible']:>10d}"
+            for _, row in result.group_summary.iterrows()
+        )
+        add("")
+
+    if result.strict_ratio is not None:
+        add(
+            f"Strict check (ratio < {result.strict_ratio}):"
+            f" {result.strict_failures} failures"
+        )
+        add("")
+
+    # Strict failures take precedence over impossible targets.
+    if result.strict_ratio is not None and result.strict_failures > 0:
+        verdict = (
+            f"RESULT: FAIL ({result.strict_failures}"
+            f" targets below ratio {result.strict_ratio})"
+        )
+    elif result.n_impossible > 0:
+        verdict = f"RESULT: FAIL ({result.n_impossible} impossible targets)"
+    else:
+        verdict = "RESULT: PASS"
+    add(verdict)
+
+    return "\n".join(report)
+
+
+def main(argv=None):
+    """Command-line entry point: validate a package and exit with a status.
+
+    Args:
+        argv: Argument list to parse. ``None`` (the default) makes argparse
+            read the process arguments, so existing ``main()`` callers are
+            unaffected; passing a list allows programmatic use, matching
+            ``main(argv=None)`` in validate_national_h5.py.
+
+    This function always terminates through ``sys.exit``:
+        0: no impossible targets and no strict failures.
+        1: the package file does not exist, or there are impossible
+           targets.
+        2: ``--strict`` was given and at least one achievable target is
+           below the ratio.
+    """
+    parser = argparse.ArgumentParser(
+        description="Validate a calibration package"
+    )
+    parser.add_argument(
+        "path",
+        nargs="?",
+        default=None,
+        help="Path to calibration_package.pkl",
+    )
+    parser.add_argument(
+        "--n-hardest",
+        type=int,
+        default=10,
+        help="Number of hardest achievable targets to show",
+    )
+    # A bare --strict uses `const`; without the flag the check is off.
+    parser.add_argument(
+        "--strict",
+        nargs="?",
+        const=0.01,
+        type=float,
+        default=None,
+        metavar="RATIO",
+        help="Fail if any achievable target has ratio below RATIO"
+        " (default: 0.01)",
+    )
+    args = parser.parse_args(argv)
+
+    if args.path is None:
+        # Imported only when the default location is actually needed.
+        from policyengine_us_data.storage import STORAGE_FOLDER
+
+        path = STORAGE_FOLDER / "calibration" / "calibration_package.pkl"
+    else:
+        path = Path(args.path)
+
+    if not path.exists():
+        print(f"Error: package not found at {path}", file=sys.stderr)
+        sys.exit(1)
+
+    from policyengine_us_data.calibration.unified_calibration import (
+        load_calibration_package,
+    )
+
+    package = load_calibration_package(str(path))
+    result = validate_package(
+        package,
+        n_hardest=args.n_hardest,
+        strict_ratio=args.strict,
+    )
+    print(format_report(result, package_path=str(path)))
+
+    # Strict failures take precedence over impossible targets.
+    if args.strict is not None and result.strict_failures > 0:
+        sys.exit(2)
+    elif result.n_impossible > 0:
+        sys.exit(1)
+    sys.exit(0)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/policyengine_us_data/calibration/validate_staging.py b/policyengine_us_data/calibration/validate_staging.py
new file mode 100644
index 000000000..58bd61667
--- /dev/null
+++ b/policyengine_us_data/calibration/validate_staging.py
@@ -0,0 +1,732 @@
+"""
+Validate staging .h5 files by running sim.calculate() and comparing
+against calibration targets from policy_data.db.
+
+Usage:
+ python -m policyengine_us_data.calibration.validate_staging \
+ --area-type states,districts --areas NC \
+ --period 2024 --output validation_results.csv
+
+ python -m policyengine_us_data.calibration.validate_staging \
+ --sanity-only --area-type states --areas NC
+"""
+
+import argparse
+import csv
+import gc
+import logging
+import math
+import multiprocessing as mp
+import os
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import pandas as pd
+from sqlalchemy import create_engine
+
+from policyengine_us_data.storage import STORAGE_FOLDER
+from policyengine_us_data.calibration.unified_calibration import (
+ load_target_config,
+ _match_rules,
+)
+from policyengine_us_data.calibration.unified_matrix_builder import (
+ UnifiedMatrixBuilder,
+ _calculate_target_values_standalone,
+ _GEO_VARS,
+)
+from policyengine_us_data.calibration.calibration_utils import (
+ STATE_CODES,
+)
+from policyengine_us_data.calibration.sanity_checks import (
+ run_sanity_checks,
+)
+
+# Module-level logger; the CLI entry points call logging.basicConfig().
+logger = logging.getLogger(__name__)
+
+# Defaults for the command-line options declared in parse_args().
+DEFAULT_HF_PREFIX = "hf://policyengine/policyengine-us-data/staging"
+DEFAULT_DB_PATH = str(STORAGE_FOLDER / "calibration" / "policy_data.db")
+DEFAULT_TARGET_CONFIG = str(Path(__file__).parent / "target_config_full.yaml")
+# Config whose "include" rules decide the in_training flag in main().
+TRAINING_TARGET_CONFIG = str(Path(__file__).parent / "target_config.yaml")
+
+# Upper bounds on |simulated value|, keyed by geography level and then by
+# the variable class returned by _classify_variable(). _run_sanity_check()
+# reports FAIL for values above the matching bound.
+# NOTE(review): "dollar" bounds are presumably in dollars and the count
+# bounds in persons/households - confirm the units with the data owners.
+SANITY_CEILINGS = {
+ "national": {
+ "dollar": 30e12,
+ "person_count": 340e6,
+ "household_count": 135e6,
+ "count": 340e6,
+ },
+ "state": {
+ "dollar": 5e12,
+ "person_count": 40e6,
+ "household_count": 15e6,
+ "count": 40e6,
+ },
+ "district": {
+ "dollar": 500e9,
+ "person_count": 1e6,
+ "household_count": 400e3,
+ "count": 1e6,
+ },
+}
+
+# Two-way lookup between state FIPS codes and postal abbreviations.
+# The FIPS side is str(key) of STATE_CODES, i.e. whatever str() gives
+# for the key type used there (no zero padding is added here).
+FIPS_TO_ABBR = {str(k): v for k, v in STATE_CODES.items()}
+ABBR_TO_FIPS = {v: str(k) for k, v in STATE_CODES.items()}
+
+# Column order of the output CSV written by main(); the names match the
+# keys of the per-target dicts built in validate_area().
+CSV_COLUMNS = [
+ "area_type",
+ "area_id",
+ "variable",
+ "target_name",
+ "period",
+ "target_value",
+ "sim_value",
+ "error",
+ "rel_error",
+ "abs_error",
+ "rel_abs_error",
+ "sanity_check",
+ "sanity_reason",
+ "in_training",
+]
+
+
+def _classify_variable(variable: str) -> str:
+    """Map a target variable name to its sanity-ceiling class.
+
+    ``household_count`` and ``person_count`` are their own class, any
+    other name ending in ``_count`` is ``"count"``, and everything else
+    is treated as ``"dollar"``.
+    """
+    if variable in ("household_count", "person_count"):
+        return variable
+    return "count" if variable.endswith("_count") else "dollar"
+
+
+def _run_sanity_check(
+    sim_value: float,
+    variable: str,
+    geo_level: str,
+) -> tuple:
+    """Check one simulated aggregate against its plausibility ceiling.
+
+    Returns a ``(status, reason)`` pair: ``("PASS", "")`` when the value
+    is finite and within the ceiling, otherwise ``("FAIL", <reason>)``.
+    Unknown geography levels fall back to the ``"state"`` ceilings and
+    unknown variable classes to the ``"dollar"`` ceiling.
+    """
+    if not math.isfinite(sim_value):
+        return "FAIL", "non-finite value"
+
+    kind = _classify_variable(variable)
+    level_limits = SANITY_CEILINGS.get(geo_level, SANITY_CEILINGS["state"])
+    limit = level_limits.get(kind, level_limits["dollar"])
+
+    if abs(sim_value) <= limit:
+        return "PASS", ""
+
+    reason = (
+        f"|{sim_value:.2e}| > {limit:.0e} ceiling "
+        f"({kind} @ {geo_level})"
+    )
+    return "FAIL", reason
+
+
+def _query_all_active_targets(engine, period: int) -> pd.DataFrame:
+ """Load active targets, keeping one period per (stratum, variable).
+
+ For each (stratum_id, variable) pair among the active rows of
+ ``target_overview`` the query picks the latest period that is not
+ after ``period``; if no such period exists it falls back to the
+ earliest available one. Only active rows of that chosen period are
+ returned, ordered by ``target_id``.
+
+ Args:
+ engine: Object whose ``connect()`` yields a connection usable by
+ ``pd.read_sql``.
+ period: Reference period bound to the ``:period`` parameter.
+
+ Returns:
+ DataFrame with columns target_id, stratum_id, variable, value,
+ period, geo_level, geographic_id and domain_variable.
+ """
+ query = """
+ WITH best_periods AS (
+ SELECT stratum_id, variable,
+ CASE
+ WHEN MAX(CASE WHEN period <= :period
+ THEN period END) IS NOT NULL
+ THEN MAX(CASE WHEN period <= :period
+ THEN period END)
+ ELSE MIN(period)
+ END as best_period
+ FROM target_overview
+ WHERE active = 1
+ GROUP BY stratum_id, variable
+ )
+ SELECT tv.target_id, tv.stratum_id, tv.variable,
+ tv.value, tv.period, tv.geo_level,
+ tv.geographic_id, tv.domain_variable
+ FROM target_overview tv
+ JOIN best_periods bp
+ ON tv.stratum_id = bp.stratum_id
+ AND tv.variable = bp.variable
+ AND tv.period = bp.best_period
+ WHERE tv.active = 1
+ ORDER BY tv.target_id
+ """
+ # The period is passed as a bound parameter, not formatted into the SQL.
+ with engine.connect() as conn:
+ return pd.read_sql(query, conn, params={"period": period})
+
+
+def _get_stratum_constraints(engine, stratum_id: int) -> list:
+ """Return the constraints of one stratum as a list of dicts.
+
+ Each dict has the keys ``variable`` (the ``constraint_variable``
+ column), ``operation`` and ``value``, one dict per row of
+ ``stratum_constraints`` for ``stratum_id``.
+ """
+ query = """
+ SELECT constraint_variable AS variable, operation, value
+ FROM stratum_constraints
+ WHERE stratum_id = :stratum_id
+ """
+ with engine.connect() as conn:
+ # int() converts numpy integer ids to a plain int before binding.
+ df = pd.read_sql(query, conn, params={"stratum_id": int(stratum_id)})
+ return df.to_dict("records")
+
+
+def _geoid_to_district_filename(geoid: str) -> str:
+    """Convert DB geographic_id like '3701' to filename 'NC-01'.
+
+    The id is left-padded to four characters; the last two are the
+    district number and the rest is the state FIPS code.
+
+    Args:
+        geoid: District geographic id (e.g. ``'3701'`` or ``'601'``).
+            Non-string ids are converted with ``str()``.
+
+    Returns:
+        ``'<ABBR>-<DD>'`` when the state part is a known FIPS code,
+        otherwise the zero-padded id unchanged.
+    """
+    geoid = str(geoid).zfill(4)
+    state_fips, district_num = geoid[:-2], geoid[-2:]
+
+    # FIPS_TO_ABBR is keyed by str(STATE_CODES key). If those keys are
+    # ints, single-digit states are stored as '1', '6', ... while the
+    # padded id gives '01', '06', ... - so try the padded form first and
+    # then the form without leading zeros.
+    abbr = FIPS_TO_ABBR.get(state_fips)
+    if abbr is None:
+        abbr = FIPS_TO_ABBR.get(state_fips.lstrip("0"))
+    if abbr is None:
+        return geoid
+    return f"{abbr}-{district_num}"
+
+
+def _geoid_to_display(geoid: str) -> str:
+ """Return the display label for a district geographic_id.
+
+ Currently identical to the H5 filename stem, e.g. '3701' -> 'NC-01'.
+ """
+ return _geoid_to_district_filename(geoid)
+
+
+def _resolve_state_fips(areas_str: Optional[str]) -> list:
+    """Resolve --areas to state FIPS codes.
+
+    Args:
+        areas_str: Comma-separated state abbreviations and/or numeric
+            FIPS codes. ``None`` or an empty string selects every state
+            in ``STATE_CODES``.
+
+    Returns:
+        FIPS codes as strings, in the spelling used by ``FIPS_TO_ABBR``
+        whenever the code is known. Abbreviations are matched
+        case-insensitively and zero-padded codes such as ``'06'`` are
+        mapped to the known spelling. Unknown non-numeric entries are
+        skipped with a warning; unknown numeric entries are passed
+        through unchanged.
+    """
+    if not areas_str:
+        return [str(f) for f in sorted(STATE_CODES.keys())]
+    resolved = []
+    for token in areas_str.split(","):
+        token = token.strip()
+        if token in ABBR_TO_FIPS:
+            resolved.append(ABBR_TO_FIPS[token])
+        elif token.upper() in ABBR_TO_FIPS:
+            resolved.append(ABBR_TO_FIPS[token.upper()])
+        elif token.isdigit():
+            # '06' and '6' must resolve to the same code; prefer the
+            # spelling that FIPS_TO_ABBR actually contains so later
+            # lookups and DataFrame comparisons match.
+            unpadded = token.lstrip("0")
+            if token in FIPS_TO_ABBR or unpadded not in FIPS_TO_ABBR:
+                resolved.append(token)
+            else:
+                resolved.append(unpadded)
+        else:
+            logger.warning("Unknown area '%s', skipping", token)
+    return resolved
+
+
+def _resolve_district_ids(engine, areas_str: Optional[str]) -> list:
+    """Resolve --areas to district geographic_ids from DB.
+
+    Reads every distinct district ``geographic_id`` from
+    ``target_overview`` and keeps those whose state part (the id padded
+    to four characters, minus the last two) belongs to a state selected
+    by ``areas_str``.
+
+    Returns:
+        Sorted list of matching ids as strings (not zero-padded).
+    """
+    # Compare state codes without leading zeros: the padded id yields
+    # '06' while the resolved FIPS list may spell the same state '6'.
+    wanted_states = {
+        fips.lstrip("0") for fips in _resolve_state_fips(areas_str)
+    }
+    with engine.connect() as conn:
+        df = pd.read_sql(
+            "SELECT DISTINCT geographic_id FROM target_overview "
+            "WHERE geo_level = 'district'",
+            conn,
+        )
+    result = []
+    for geoid in df["geographic_id"].tolist():
+        state_part = str(geoid).zfill(4)[:-2]
+        if state_part.lstrip("0") in wanted_states:
+            result.append(str(geoid))
+    return sorted(result)
+
+
+def _build_variable_entity_map(sim) -> dict:
+    """Map every variable name of the simulation to its entity key.
+
+    Variables for which ``get_variable`` returns ``None`` are left out.
+    The four synthetic ``*_count`` names are then added (overriding any
+    existing entry) with the entity they count.
+    """
+    system = sim.tax_benefit_system
+    looked_up = (
+        (name, system.get_variable(name)) for name in system.variables
+    )
+    entity_by_variable = {
+        name: definition.entity.key
+        for name, definition in looked_up
+        if definition is not None
+    }
+    entity_by_variable.update(
+        person_count="person",
+        household_count="household",
+        tax_unit_count="tax_unit",
+        spm_unit_count="spm_unit",
+    )
+    return entity_by_variable
+
+
+def _build_entity_rel(sim) -> pd.DataFrame:
+    """Build the person-level table of entity ids.
+
+    Each column holds the result of ``sim.calculate(<id>,
+    map_to="person")`` for one id variable, so every row links a person
+    to the household, tax unit and SPM unit ids calculated for them.
+    """
+    id_variables = (
+        "person_id",
+        "household_id",
+        "tax_unit_id",
+        "spm_unit_id",
+    )
+    return pd.DataFrame(
+        {
+            id_name: sim.calculate(id_name, map_to="person").values
+            for id_name in id_variables
+        }
+    )
+
+
+def validate_area(
+    sim,
+    targets_df: pd.DataFrame,
+    engine,
+    area_type: str,
+    area_id: str,
+    display_id: str,
+    period: int,
+    training_mask: np.ndarray,
+    variable_entity_map: dict,
+) -> list:
+    """Compare simulated aggregates for one area with its targets.
+
+    For every row of ``targets_df`` the target variable is evaluated per
+    household (restricted by the stratum's non-geographic constraints),
+    weighted with ``household_weight`` and compared to the target value.
+
+    Args:
+        sim: Simulation for the area; only ``sim.calculate`` is used.
+        targets_df: Targets of this area with at least the columns
+            ``variable``, ``value``, ``stratum_id`` and ``period``.
+        engine: Database engine used to read ``stratum_constraints``.
+        area_type: ``"states"``; any other value is treated as district
+            level for the sanity ceilings.
+        area_id: Not used by this function; kept so existing callers
+            keep working.
+        display_id: Label written to the ``area_id`` output column.
+        period: Period passed to ``sim.calculate``.
+        training_mask: Boolean flags aligned by position with the rows
+            of ``targets_df``; copied to the ``in_training`` column.
+        variable_entity_map: Variable name -> entity key, forwarded to
+            ``_calculate_target_values_standalone``.
+
+    Returns:
+        One dict per target row, keyed like ``CSV_COLUMNS``. When the
+        target is 0 and the simulated value is not, ``rel_error`` is an
+        infinity carrying the sign of the error and ``rel_abs_error``
+        is ``+inf``.
+    """
+    entity_rel = _build_entity_rel(sim)
+    household_ids = sim.calculate("household_id", map_to="household").values
+    n_households = len(household_ids)
+
+    hh_weight = sim.calculate(
+        "household_weight",
+        map_to="household",
+        period=period,
+    ).values.astype(np.float64)
+
+    hh_vars_cache = {}
+    person_vars_cache = {}
+    # Constraints depend only on the stratum, so each stratum is queried
+    # once instead of once per target row.
+    constraints_cache = {}
+    # (variable, map_to) pairs whose calculation already failed; they are
+    # not retried for later rows.
+    failed_calcs = set()
+
+    def _ensure_cached(cache, vname, map_to):
+        """Best-effort fill of ``cache[vname]``; failures are logged."""
+        if vname in cache or (vname, map_to) in failed_calcs:
+            return
+        try:
+            cache[vname] = sim.calculate(
+                vname,
+                map_to=map_to,
+                period=period,
+            ).values
+        except Exception as exc:
+            # Deliberately non-fatal, as before. Logged at debug level
+            # because some names (presumably the synthetic ``*_count``
+            # targets - TODO confirm) are expected not to be calculable.
+            failed_calcs.add((vname, map_to))
+            logger.debug(
+                "Could not calculate %s (map_to=%s): %s",
+                vname,
+                map_to,
+                exc,
+            )
+
+    training_arr = np.asarray(training_mask, dtype=bool)
+
+    geo_level = "state" if area_type == "states" else "district"
+
+    results = []
+    for i, (_, row) in enumerate(targets_df.iterrows()):
+        variable = row["variable"]
+        target_value = float(row["value"])
+        stratum_id = int(row["stratum_id"])
+
+        if stratum_id not in constraints_cache:
+            constraints_cache[stratum_id] = _get_stratum_constraints(
+                engine, stratum_id
+            )
+        # Hand out copies so a callee cannot alter the cached records.
+        constraints = [dict(c) for c in constraints_cache[stratum_id]]
+        non_geo = [c for c in constraints if c["variable"] not in _GEO_VARS]
+
+        needed_vars = {variable, *(c["variable"] for c in non_geo)}
+
+        if not variable.endswith("_count"):
+            _ensure_cached(hh_vars_cache, variable, "household")
+        for vname in needed_vars:
+            _ensure_cached(person_vars_cache, vname, "person")
+
+        per_hh = _calculate_target_values_standalone(
+            target_variable=variable,
+            non_geo_constraints=non_geo,
+            n_households=n_households,
+            hh_vars=hh_vars_cache,
+            person_vars=person_vars_cache,
+            entity_rel=entity_rel,
+            household_ids=household_ids,
+            variable_entity_map=variable_entity_map,
+        )
+
+        sim_value = float(np.dot(per_hh, hh_weight))
+
+        error = sim_value - target_value
+        abs_error = abs(error)
+        if target_value != 0:
+            rel_error = error / target_value
+            rel_abs_error = abs_error / abs(target_value)
+        elif error != 0:
+            # rel_error is the signed column: keep the direction of the
+            # miss instead of always reporting +inf.
+            rel_error = math.copysign(math.inf, error)
+            rel_abs_error = math.inf
+        else:
+            rel_error = 0.0
+            rel_abs_error = 0.0
+
+        target_name = UnifiedMatrixBuilder._make_target_name(
+            variable,
+            constraints,
+        )
+
+        sanity_check, sanity_reason = _run_sanity_check(
+            sim_value,
+            variable,
+            geo_level,
+        )
+
+        results.append(
+            {
+                "area_type": area_type,
+                "area_id": display_id,
+                "variable": variable,
+                "target_name": target_name,
+                "period": int(row["period"]),
+                "target_value": target_value,
+                "sim_value": sim_value,
+                "error": error,
+                "rel_error": rel_error,
+                "abs_error": abs_error,
+                "rel_abs_error": rel_abs_error,
+                "sanity_check": sanity_check,
+                "sanity_reason": sanity_reason,
+                "in_training": bool(training_arr[i]),
+            }
+        )
+
+    return results
+
+
+def parse_args(argv=None):
+    """Build the staging-validation CLI and parse ``argv``.
+
+    Args:
+        argv: Argument list to parse; ``None`` lets argparse read the
+            process command line.
+
+    Returns:
+        The ``argparse.Namespace`` with ``area_type``, ``areas``,
+        ``hf_prefix``, ``period``, ``target_config``, ``db_path``,
+        ``output`` and ``sanity_only``.
+    """
+    cli = argparse.ArgumentParser(
+        description=(
+            "Validate staging .h5 files against "
+            "calibration targets via sim.calculate()"
+        )
+    )
+    # (flag, add_argument keyword arguments), in help-output order.
+    option_table = [
+        (
+            "--area-type",
+            {
+                "default": "states",
+                "help": (
+                    "Comma-separated geo levels to validate: "
+                    "states, districts (default: states)"
+                ),
+            },
+        ),
+        (
+            "--areas",
+            {
+                "default": None,
+                "help": (
+                    "Comma-separated state abbreviations or FIPS "
+                    "(applies to all area types; all if omitted)"
+                ),
+            },
+        ),
+        (
+            "--hf-prefix",
+            {
+                "default": DEFAULT_HF_PREFIX,
+                "help": "HuggingFace path prefix for .h5 files",
+            },
+        ),
+        (
+            "--period",
+            {
+                "type": int,
+                "default": 2024,
+                "help": "Tax year to validate (default: 2024)",
+            },
+        ),
+        (
+            "--target-config",
+            {
+                "default": DEFAULT_TARGET_CONFIG,
+                "help": (
+                    "YAML config with exclude rules "
+                    "(default: target_config_full.yaml)"
+                ),
+            },
+        ),
+        (
+            "--db-path",
+            {
+                "default": DEFAULT_DB_PATH,
+                "help": "Path to policy_data.db",
+            },
+        ),
+        (
+            "--output",
+            {
+                "default": "validation_results.csv",
+                "help": "Output CSV path",
+            },
+        ),
+        (
+            "--sanity-only",
+            {
+                "action": "store_true",
+                "help": (
+                    "Run only structural sanity checks (fast, "
+                    "no database needed)"
+                ),
+            },
+        ),
+    ]
+    for flag, settings in option_table:
+        cli.add_argument(flag, **settings)
+    return cli.parse_args(argv)
+
+
+def _validate_single_area(
+    area_type,
+    area_id,
+    h5_path,
+    display_id,
+    area_targets,
+    area_training,
+    db_path,
+    period,
+):
+    """Worker entry point: validate a single area in a child process.
+
+    Meant to run in its own subprocess so that the memory held by the
+    Microsimulation and its caches is returned to the OS when that
+    process ends.
+
+    Returns:
+        ``(rows, population)``: the result dicts from ``validate_area``
+        and the summed ``person_weight``. The population is only
+        computed for ``area_type == "states"`` and is 0.0 otherwise.
+        A dataset that cannot be loaded yields ``([], 0.0)``.
+    """
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s %(levelname)s %(message)s",
+    )
+    from policyengine_us import Microsimulation
+    from sqlalchemy import create_engine as _create_engine
+
+    engine = _create_engine(f"sqlite:///{db_path}")
+
+    logger.info("Loading sim from %s", h5_path)
+    try:
+        sim = Microsimulation(dataset=h5_path)
+    except Exception as load_error:
+        logger.error("Failed to load %s: %s", h5_path, load_error)
+        return [], 0.0
+
+    population = 0.0
+    if area_type == "states":
+        weights = sim.calculate(
+            "person_weight",
+            map_to="person",
+            period=period,
+        )
+        population = float(weights.values.astype(np.float64).sum())
+        logger.info(" %s population: %.0f", display_id, population)
+
+    n_targets = len(area_targets)
+    if n_targets == 0:
+        logger.warning("No targets for %s, skipping", display_id)
+        return [], population
+
+    logger.info("Validating %d targets for %s", n_targets, display_id)
+
+    rows = validate_area(
+        sim=sim,
+        targets_df=area_targets,
+        engine=engine,
+        area_type=area_type,
+        area_id=area_id,
+        display_id=display_id,
+        period=period,
+        training_mask=area_training,
+        variable_entity_map=_build_variable_entity_map(sim),
+    )
+
+    failed_rows = [row for row in rows if row["sanity_check"] == "FAIL"]
+    logger.info(
+        " %s: %d results, %d sanity failures",
+        display_id,
+        len(rows),
+        len(failed_rows),
+    )
+
+    return rows, population
+
+
+def _run_area_type(
+    area_type,
+    area_ids,
+    level_targets,
+    level_training,
+    engine,
+    args,
+    Microsimulation,
+):
+    """Validate all areas for a single area_type.
+
+    Each area runs in a fresh single-worker pool created from the
+    ``spawn`` context, so the OS fully reclaims memory between areas
+    (avoids OOM on large states like CA).
+
+    Args:
+        area_type: ``"states"`` or ``"districts"``; also the sub-folder
+            of ``args.hf_prefix`` that holds the H5 files.
+        area_ids: State FIPS codes or district geographic ids.
+        level_targets: Targets of this geography level; rows are matched
+            to an area through the ``geographic_id`` column.
+        level_training: Boolean array aligned with ``level_targets``.
+        engine: Not used here (workers open their own engine from
+            ``args.db_path``); kept for interface compatibility.
+        args: Parsed CLI namespace (``hf_prefix``, ``db_path``,
+            ``period`` are read).
+        Microsimulation: Not used here (workers import it themselves);
+            kept for interface compatibility.
+
+    Returns:
+        Concatenated result rows of all areas.
+    """
+    results = []
+    total_weighted_pop = 0.0
+    is_states = area_type == "states"
+    # The start-method context is the same for every area.
+    ctx = mp.get_context("spawn")
+
+    for area_id in area_ids:
+        if is_states:
+            h5_name = FIPS_TO_ABBR.get(area_id, area_id)
+        else:
+            h5_name = _geoid_to_district_filename(area_id)
+        display_id = h5_name
+
+        h5_path = f"{args.hf_prefix}/{area_type}/{h5_name}.h5"
+
+        area_mask = (level_targets["geographic_id"] == area_id).values
+        area_targets = level_targets[area_mask].reset_index(drop=True)
+        area_training = level_training[area_mask]
+
+        if not is_states and len(area_targets) == 0:
+            # The worker computes nothing for a non-state area without
+            # targets (population is only summed for states), so skip
+            # the subprocess and the H5 load it would perform.
+            logger.warning("No targets for %s, skipping", display_id)
+            continue
+
+        # One pool per area on purpose: the worker process must exit
+        # after each area for its memory to be released.
+        with ctx.Pool(1) as pool:
+            area_results, area_pop = pool.apply(
+                _validate_single_area,
+                (
+                    area_type,
+                    area_id,
+                    h5_path,
+                    display_id,
+                    area_targets,
+                    area_training,
+                    args.db_path,
+                    args.period,
+                ),
+            )
+
+        total_weighted_pop += area_pop
+        results.extend(area_results)
+
+    if is_states and total_weighted_pop > 0:
+        logger.info(
+            "TOTAL WEIGHTED POPULATION: %.0f (expect ~340M)",
+            total_weighted_pop,
+        )
+
+    return results
+
+
+def _run_sanity_only(args):
+    """Run structural sanity checks on staging H5 files.
+
+    Only the ``states`` area type is implemented; other area types are
+    logged and skipped. Every non-PASS check is logged as a warning.
+
+    Args:
+        args: Parsed CLI namespace (``area_type``, ``areas``,
+            ``hf_prefix`` and ``period`` are read).
+
+    Returns:
+        Total number of checks with status ``"FAIL"`` across all files.
+    """
+    area_types = [t.strip() for t in args.area_type.split(",")]
+    state_fips_list = _resolve_state_fips(args.areas)
+
+    total_failures = 0
+
+    for area_type in area_types:
+        if area_type != "states":
+            logger.info(
+                "Sanity-only mode for %s not yet implemented",
+                area_type,
+            )
+            continue
+
+        for fips in state_fips_list:
+            abbr = FIPS_TO_ABBR.get(fips, fips)
+            h5_url = f"{args.hf_prefix}/{area_type}/{abbr}.h5"
+            logger.info("Sanity-checking %s", h5_url)
+
+            local = _fetch_h5_for_sanity(h5_url)
+
+            results = run_sanity_checks(local, args.period)
+            n_fail = sum(1 for r in results if r["status"] == "FAIL")
+            total_failures += n_fail
+
+            for r in results:
+                if r["status"] != "PASS":
+                    detail = f" — {r['detail']}" if r["detail"] else ""
+                    logger.warning(
+                        " %s [%s] %s%s",
+                        abbr,
+                        r["status"],
+                        r["check"],
+                        detail,
+                    )
+
+            if n_fail == 0:
+                logger.info(" %s: all checks passed", abbr)
+
+    if total_failures > 0:
+        logger.error("%d total sanity failures", total_failures)
+    else:
+        logger.info("All sanity checks passed")
+
+    return total_failures
+
+
+def _fetch_h5_for_sanity(h5_url):
+    """Return a local path for ``h5_url``, downloading ``hf://`` URLs.
+
+    ``hf://<owner>/<repo>/<path>`` is fetched with ``hf_hub_download``
+    using the ``HUGGING_FACE_TOKEN`` environment variable; any other
+    value is returned unchanged.
+
+    Raises:
+        ValueError: If an ``hf://`` URL has no path inside the repo.
+    """
+    if not h5_url.startswith("hf://"):
+        return h5_url
+
+    # Imported lazily so local-file runs do not need huggingface_hub.
+    from huggingface_hub import hf_hub_download
+
+    parts = h5_url[5:].split("/", 2)
+    if len(parts) < 3:
+        raise ValueError(
+            f"Expected hf://<owner>/<repo>/<path>, got '{h5_url}'"
+        )
+    return hf_hub_download(
+        repo_id=f"{parts[0]}/{parts[1]}",
+        filename=parts[2],
+        repo_type="model",
+        token=os.environ.get("HUGGING_FACE_TOKEN"),
+    )
+
+
+def main(argv=None):
+    """CLI entry point: validate staging H5 files and write a CSV.
+
+    With ``--sanity-only`` only the structural checks run. Otherwise
+    the active targets are loaded from the database, filtered by the
+    target config's ``exclude``/``include`` rules, flagged with
+    ``in_training`` according to the training config's ``include``
+    rules, validated per area and written to ``--output``.
+
+    Args:
+        argv: Argument list; ``None`` uses the process command line.
+    """
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s %(levelname)s %(message)s",
+    )
+
+    args = parse_args(argv)
+    logger.info("CLI args: %s", vars(args))
+
+    if args.sanity_only:
+        _run_sanity_only(args)
+        return
+
+    # Reject unknown area types up front, before importing the model
+    # package, querying the database or parsing the YAML configs.
+    area_types = [t.strip() for t in args.area_type.split(",")]
+    valid_types = {"states", "districts"}
+    for t in area_types:
+        if t not in valid_types:
+            logger.error(
+                "Unknown area-type '%s'. Use: %s",
+                t,
+                ", ".join(sorted(valid_types)),
+            )
+            return
+
+    from policyengine_us import Microsimulation
+
+    engine = create_engine(f"sqlite:///{args.db_path}")
+
+    all_targets = _query_all_active_targets(engine, args.period)
+    logger.info("Loaded %d active targets from DB", len(all_targets))
+
+    exclude_config = load_target_config(args.target_config)
+    exclude_rules = exclude_config.get("exclude", [])
+    if exclude_rules:
+        exc_mask = _match_rules(all_targets, exclude_rules)
+        all_targets = all_targets[~exc_mask].reset_index(drop=True)
+        logger.info("After exclusions: %d targets", len(all_targets))
+
+    include_rules = exclude_config.get("include", [])
+    if include_rules:
+        inc_mask = _match_rules(all_targets, include_rules)
+        all_targets = all_targets[inc_mask].reset_index(drop=True)
+        logger.info("After inclusions: %d targets", len(all_targets))
+
+    # in_training: does the target match the training config's include
+    # rules? Without such rules every target counts as training.
+    training_config = load_target_config(TRAINING_TARGET_CONFIG)
+    training_include = training_config.get("include", [])
+    if training_include:
+        training_mask = np.asarray(
+            _match_rules(all_targets, training_include),
+            dtype=bool,
+        )
+    else:
+        training_mask = np.ones(len(all_targets), dtype=bool)
+
+    all_results = []
+
+    for area_type in area_types:
+        geo_level = "state" if area_type == "states" else "district"
+        geo_mask = (all_targets["geo_level"] == geo_level).values
+        level_targets = all_targets[geo_mask].reset_index(drop=True)
+        # Same mask on both, so rows and flags stay aligned by position.
+        level_training = training_mask[geo_mask]
+
+        logger.info(
+            "%d targets at geo_level=%s",
+            len(level_targets),
+            geo_level,
+        )
+
+        if area_type == "states":
+            area_ids = _resolve_state_fips(args.areas)
+        else:
+            area_ids = _resolve_district_ids(engine, args.areas)
+
+        logger.info(
+            "%s: %d areas to validate",
+            area_type,
+            len(area_ids),
+        )
+
+        results = _run_area_type(
+            area_type=area_type,
+            area_ids=area_ids,
+            level_targets=level_targets,
+            level_training=level_training,
+            engine=engine,
+            args=args,
+            Microsimulation=Microsimulation,
+        )
+        all_results.extend(results)
+
+    output_path = Path(args.output)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    with open(output_path, "w", newline="") as f:
+        writer = csv.DictWriter(f, fieldnames=CSV_COLUMNS)
+        writer.writeheader()
+        writer.writerows(all_results)
+
+    logger.info("Wrote %d rows to %s", len(all_results), output_path)
+
+    n_total_fail = sum(1 for r in all_results if r["sanity_check"] == "FAIL")
+    if n_total_fail > 0:
+        logger.warning(
+            "%d SANITY FAILURES across all areas",
+            n_total_fail,
+        )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/policyengine_us_data/datasets/cps/extended_cps.py b/policyengine_us_data/datasets/cps/extended_cps.py
index b60fbf42f..b095a7f08 100644
--- a/policyengine_us_data/datasets/cps/extended_cps.py
+++ b/policyengine_us_data/datasets/cps/extended_cps.py
@@ -55,7 +55,7 @@ def generate(self):
# Variables with formulas that must still be stored (e.g. IDs
# needed by the dataset loader before formulas can run).
- _KEEP_FORMULA_VARS = {"person_id"}
+ _KEEP_FORMULA_VARS = {"person_id", "is_pregnant"}
@classmethod
def _drop_formula_variables(cls, data):
diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/__init__.py b/policyengine_us_data/datasets/cps/local_area_calibration/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/publish_local_area.py b/policyengine_us_data/datasets/cps/local_area_calibration/publish_local_area.py
deleted file mode 100644
index 4963f3979..000000000
--- a/policyengine_us_data/datasets/cps/local_area_calibration/publish_local_area.py
+++ /dev/null
@@ -1,587 +0,0 @@
-"""
-Publish local area H5 files to GCP and Hugging Face.
-
-Downloads calibration inputs from HF, generates state/district H5s
-with checkpointing, and uploads to both destinations.
-
-Usage:
- python publish_local_area.py [--skip-download] [--states-only] [--districts-only]
-"""
-
-import os
-import numpy as np
-from pathlib import Path
-from typing import List, Optional, Set
-
-from policyengine_us import Microsimulation
-from policyengine_us_data.utils.huggingface import download_calibration_inputs
-from policyengine_us_data.utils.data_upload import (
- upload_local_area_file,
- upload_local_area_batch_to_hf,
-)
-from policyengine_us_data.datasets.cps.local_area_calibration.stacked_dataset_builder import (
- create_sparse_cd_stacked_dataset,
- NYC_COUNTIES,
- NYC_CDS,
-)
-from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import (
- get_all_cds_from_database,
- STATE_CODES,
-)
-
-CHECKPOINT_FILE = Path("completed_states.txt")
-CHECKPOINT_FILE_DISTRICTS = Path("completed_districts.txt")
-CHECKPOINT_FILE_CITIES = Path("completed_cities.txt")
-WORK_DIR = Path("local_area_build")
-
-
-def load_completed_states() -> set:
- if CHECKPOINT_FILE.exists():
- content = CHECKPOINT_FILE.read_text().strip()
- if content:
- return set(content.split("\n"))
- return set()
-
-
-def record_completed_state(state_code: str):
- with open(CHECKPOINT_FILE, "a") as f:
- f.write(f"{state_code}\n")
-
-
-def load_completed_districts() -> set:
- if CHECKPOINT_FILE_DISTRICTS.exists():
- content = CHECKPOINT_FILE_DISTRICTS.read_text().strip()
- if content:
- return set(content.split("\n"))
- return set()
-
-
-def record_completed_district(district_name: str):
- with open(CHECKPOINT_FILE_DISTRICTS, "a") as f:
- f.write(f"{district_name}\n")
-
-
-def load_completed_cities() -> set:
- if CHECKPOINT_FILE_CITIES.exists():
- content = CHECKPOINT_FILE_CITIES.read_text().strip()
- if content:
- return set(content.split("\n"))
- return set()
-
-
-def record_completed_city(city_name: str):
- with open(CHECKPOINT_FILE_CITIES, "a") as f:
- f.write(f"{city_name}\n")
-
-
-def build_state_h5(
- state_code: str,
- weights: np.ndarray,
- cds_to_calibrate: List[str],
- dataset_path: Path,
- output_dir: Path,
-) -> Optional[Path]:
- """
- Build a single state H5 file (build only, no upload).
-
- Args:
- state_code: Two-letter state code (e.g., "AL", "CA")
- weights: Calibrated weight vector
- cds_to_calibrate: Full list of CD GEOIDs from calibration
- dataset_path: Path to base dataset H5 file
- output_dir: Output directory for H5 file
-
- Returns:
- Path to output H5 file if successful, None if no CDs found
- """
- state_fips = None
- for fips, code in STATE_CODES.items():
- if code == state_code:
- state_fips = fips
- break
-
- if state_fips is None:
- print(f"Unknown state code: {state_code}")
- return None
-
- cd_subset = [cd for cd in cds_to_calibrate if int(cd) // 100 == state_fips]
- if not cd_subset:
- print(f"No CDs found for {state_code}, skipping")
- return None
-
- states_dir = output_dir / "states"
- states_dir.mkdir(parents=True, exist_ok=True)
- output_path = states_dir / f"{state_code}.h5"
-
- print(f"\n{'='*60}")
- print(f"Building {state_code} ({len(cd_subset)} CDs)")
- print(f"{'='*60}")
-
- create_sparse_cd_stacked_dataset(
- weights,
- cds_to_calibrate,
- cd_subset=cd_subset,
- dataset_path=str(dataset_path),
- output_path=str(output_path),
- )
-
- return output_path
-
-
-def build_district_h5(
- cd_geoid: str,
- weights: np.ndarray,
- cds_to_calibrate: List[str],
- dataset_path: Path,
- output_dir: Path,
-) -> Path:
- """
- Build a single district H5 file (build only, no upload).
-
- Args:
- cd_geoid: Congressional district GEOID (e.g., "0101" for AL-01)
- weights: Calibrated weight vector
- cds_to_calibrate: Full list of CD GEOIDs from calibration
- dataset_path: Path to base dataset H5 file
- output_dir: Output directory for H5 file
-
- Returns:
- Path to output H5 file
- """
- cd_int = int(cd_geoid)
- state_fips = cd_int // 100
- district_num = cd_int % 100
- state_code = STATE_CODES.get(state_fips, str(state_fips))
- friendly_name = f"{state_code}-{district_num:02d}"
-
- districts_dir = output_dir / "districts"
- districts_dir.mkdir(parents=True, exist_ok=True)
- output_path = districts_dir / f"{friendly_name}.h5"
-
- print(f"\n{'='*60}")
- print(f"Building {friendly_name}")
- print(f"{'='*60}")
-
- create_sparse_cd_stacked_dataset(
- weights,
- cds_to_calibrate,
- cd_subset=[cd_geoid],
- dataset_path=str(dataset_path),
- output_path=str(output_path),
- )
-
- return output_path
-
-
-def build_city_h5(
- city_name: str,
- weights: np.ndarray,
- cds_to_calibrate: List[str],
- dataset_path: Path,
- output_dir: Path,
-) -> Optional[Path]:
- """
- Build a city H5 file (build only, no upload).
-
- Currently supports NYC only.
-
- Args:
- city_name: City name (currently only "NYC" supported)
- weights: Calibrated weight vector
- cds_to_calibrate: Full list of CD GEOIDs from calibration
- dataset_path: Path to base dataset H5 file
- output_dir: Output directory for H5 file
-
- Returns:
- Path to output H5 file if successful, None otherwise
- """
- if city_name != "NYC":
- print(f"Unsupported city: {city_name}")
- return None
-
- cd_subset = [cd for cd in cds_to_calibrate if cd in NYC_CDS]
- if not cd_subset:
- print("No NYC-related CDs found, skipping")
- return None
-
- cities_dir = output_dir / "cities"
- cities_dir.mkdir(parents=True, exist_ok=True)
- output_path = cities_dir / "NYC.h5"
-
- print(f"\n{'='*60}")
- print(f"Building NYC ({len(cd_subset)} CDs)")
- print(f"{'='*60}")
-
- create_sparse_cd_stacked_dataset(
- weights,
- cds_to_calibrate,
- cd_subset=cd_subset,
- dataset_path=str(dataset_path),
- output_path=str(output_path),
- county_filter=NYC_COUNTIES,
- )
-
- return output_path
-
-
-def get_district_friendly_name(cd_geoid: str) -> str:
- """Convert GEOID to friendly name (e.g., '0101' -> 'AL-01')."""
- cd_int = int(cd_geoid)
- state_fips = cd_int // 100
- district_num = cd_int % 100
- state_code = STATE_CODES.get(state_fips, str(state_fips))
- return f"{state_code}-{district_num:02d}"
-
-
-def build_and_upload_states(
- weights_path: Path,
- dataset_path: Path,
- db_path: Path,
- output_dir: Path,
- completed_states: set,
- hf_batch_size: int = 10,
-):
- """Build and upload state H5 files with checkpointing."""
- db_uri = f"sqlite:///{db_path}"
- cds_to_calibrate = get_all_cds_from_database(db_uri)
- w = np.load(weights_path)
-
- states_dir = output_dir / "states"
- states_dir.mkdir(parents=True, exist_ok=True)
-
- hf_queue = [] # Queue for batched HuggingFace uploads
-
- for state_fips, state_code in STATE_CODES.items():
- if state_code in completed_states:
- print(f"Skipping {state_code} (already completed)")
- continue
-
- cd_subset = [
- cd for cd in cds_to_calibrate if int(cd) // 100 == state_fips
- ]
- if not cd_subset:
- print(f"No CDs found for {state_code}, skipping")
- continue
-
- output_path = states_dir / f"{state_code}.h5"
- print(f"\n{'='*60}")
- print(f"Building {state_code} ({len(cd_subset)} CDs)")
- print(f"{'='*60}")
-
- try:
- create_sparse_cd_stacked_dataset(
- w,
- cds_to_calibrate,
- cd_subset=cd_subset,
- dataset_path=str(dataset_path),
- output_path=str(output_path),
- )
-
- print(f"Uploading {state_code}.h5 to GCP...")
- upload_local_area_file(str(output_path), "states", skip_hf=True)
-
- # Queue for batched HuggingFace upload
- hf_queue.append((str(output_path), "states"))
-
- record_completed_state(state_code)
- print(f"Completed {state_code}")
-
- # Flush HF queue every batch_size files
- if len(hf_queue) >= hf_batch_size:
- print(
- f"\nUploading batch of {len(hf_queue)} files to HuggingFace..."
- )
- upload_local_area_batch_to_hf(hf_queue)
- hf_queue = []
-
- except Exception as e:
- print(f"ERROR building {state_code}: {e}")
- raise
-
- # Flush remaining files to HuggingFace
- if hf_queue:
- print(
- f"\nUploading final batch of {len(hf_queue)} files to HuggingFace..."
- )
- upload_local_area_batch_to_hf(hf_queue)
-
-
-def build_and_upload_districts(
- weights_path: Path,
- dataset_path: Path,
- db_path: Path,
- output_dir: Path,
- completed_districts: set,
- hf_batch_size: int = 10,
-):
- """Build and upload district H5 files with checkpointing."""
- db_uri = f"sqlite:///{db_path}"
- cds_to_calibrate = get_all_cds_from_database(db_uri)
- w = np.load(weights_path)
-
- districts_dir = output_dir / "districts"
- districts_dir.mkdir(parents=True, exist_ok=True)
-
- hf_queue = [] # Queue for batched HuggingFace uploads
-
- for i, cd_geoid in enumerate(cds_to_calibrate):
- cd_int = int(cd_geoid)
- state_fips = cd_int // 100
- district_num = cd_int % 100
- state_code = STATE_CODES.get(state_fips, str(state_fips))
- friendly_name = f"{state_code}-{district_num:02d}"
-
- if friendly_name in completed_districts:
- print(f"Skipping {friendly_name} (already completed)")
- continue
-
- output_path = districts_dir / f"{friendly_name}.h5"
- print(f"\n{'='*60}")
- print(f"[{i+1}/{len(cds_to_calibrate)}] Building {friendly_name}")
- print(f"{'='*60}")
-
- try:
- create_sparse_cd_stacked_dataset(
- w,
- cds_to_calibrate,
- cd_subset=[cd_geoid],
- dataset_path=str(dataset_path),
- output_path=str(output_path),
- )
-
- print(f"Uploading {friendly_name}.h5 to GCP...")
- upload_local_area_file(str(output_path), "districts", skip_hf=True)
-
- # Queue for batched HuggingFace upload
- hf_queue.append((str(output_path), "districts"))
-
- record_completed_district(friendly_name)
- print(f"Completed {friendly_name}")
-
- # Flush HF queue every batch_size files
- if len(hf_queue) >= hf_batch_size:
- print(
- f"\nUploading batch of {len(hf_queue)} files to HuggingFace..."
- )
- upload_local_area_batch_to_hf(hf_queue)
- hf_queue = []
-
- except Exception as e:
- print(f"ERROR building {friendly_name}: {e}")
- raise
-
- # Flush remaining files to HuggingFace
- if hf_queue:
- print(
- f"\nUploading final batch of {len(hf_queue)} files to HuggingFace..."
- )
- upload_local_area_batch_to_hf(hf_queue)
-
-
-def build_and_upload_cities(
- weights_path: Path,
- dataset_path: Path,
- db_path: Path,
- output_dir: Path,
- completed_cities: set,
- hf_batch_size: int = 10,
-):
- """Build and upload city H5 files with checkpointing."""
- db_uri = f"sqlite:///{db_path}"
- cds_to_calibrate = get_all_cds_from_database(db_uri)
- w = np.load(weights_path)
-
- cities_dir = output_dir / "cities"
- cities_dir.mkdir(parents=True, exist_ok=True)
-
- hf_queue = [] # Queue for batched HuggingFace uploads
-
- # NYC
- if "NYC" in completed_cities:
- print("Skipping NYC (already completed)")
- else:
- cd_subset = [cd for cd in cds_to_calibrate if cd in NYC_CDS]
- if not cd_subset:
- print("No NYC-related CDs found, skipping")
- else:
- output_path = cities_dir / "NYC.h5"
- print(f"\n{'='*60}")
- print(f"Building NYC ({len(cd_subset)} CDs)")
- print(f"{'='*60}")
-
- try:
- create_sparse_cd_stacked_dataset(
- w,
- cds_to_calibrate,
- cd_subset=cd_subset,
- dataset_path=str(dataset_path),
- output_path=str(output_path),
- county_filter=NYC_COUNTIES,
- )
-
- print("Uploading NYC.h5 to GCP...")
- upload_local_area_file(
- str(output_path), "cities", skip_hf=True
- )
-
- # Queue for batched HuggingFace upload
- hf_queue.append((str(output_path), "cities"))
-
- record_completed_city("NYC")
- print("Completed NYC")
-
- except Exception as e:
- print(f"ERROR building NYC: {e}")
- raise
-
- # Flush remaining files to HuggingFace
- if hf_queue:
- print(
- f"\nUploading batch of {len(hf_queue)} city files to HuggingFace..."
- )
- upload_local_area_batch_to_hf(hf_queue)
-
-
-def main():
- import argparse
-
- parser = argparse.ArgumentParser(
- description="Build and publish local area H5 files"
- )
- parser.add_argument(
- "--skip-download",
- action="store_true",
- help="Skip downloading inputs from HF (use existing files)",
- )
- parser.add_argument(
- "--states-only",
- action="store_true",
- help="Only build and upload state files",
- )
- parser.add_argument(
- "--districts-only",
- action="store_true",
- help="Only build and upload district files",
- )
- parser.add_argument(
- "--cities-only",
- action="store_true",
- help="Only build and upload city files (e.g., NYC)",
- )
- parser.add_argument(
- "--weights-path",
- type=str,
- help="Override path to weights file (for local testing)",
- )
- parser.add_argument(
- "--dataset-path",
- type=str,
- help="Override path to dataset file (for local testing)",
- )
- parser.add_argument(
- "--db-path",
- type=str,
- help="Override path to database file (for local testing)",
- )
- args = parser.parse_args()
-
- WORK_DIR.mkdir(parents=True, exist_ok=True)
-
- if args.weights_path and args.dataset_path and args.db_path:
- inputs = {
- "weights": Path(args.weights_path),
- "dataset": Path(args.dataset_path),
- "database": Path(args.db_path),
- }
- print("Using provided paths:")
- for key, path in inputs.items():
- print(f" {key}: {path}")
- elif args.skip_download:
- inputs = {
- "weights": WORK_DIR / "w_district_calibration.npy",
- "dataset": WORK_DIR / "stratified_extended_cps.h5",
- "database": WORK_DIR / "policy_data.db",
- }
- print("Using existing files in work directory:")
- for key, path in inputs.items():
- if not path.exists():
- raise FileNotFoundError(f"Expected file not found: {path}")
- print(f" {key}: {path}")
- else:
- print("Downloading calibration inputs from Hugging Face...")
- inputs = download_calibration_inputs(str(WORK_DIR))
- for key, path in inputs.items():
- inputs[key] = Path(path)
-
- sim = Microsimulation(dataset=str(inputs["dataset"]))
- n_hh = sim.calculate("household_id", map_to="household").shape[0]
- print(f"\nBase dataset has {n_hh:,} households")
-
- # Determine what to build based on flags
- build_states = not args.districts_only and not args.cities_only
- build_districts = not args.states_only and not args.cities_only
- build_cities = not args.states_only and not args.districts_only
-
- # If a specific *-only flag is set, only build that type
- if args.states_only:
- build_states = True
- build_districts = False
- build_cities = False
- elif args.districts_only:
- build_states = False
- build_districts = True
- build_cities = False
- elif args.cities_only:
- build_states = False
- build_districts = False
- build_cities = True
-
- if build_states:
- print("\n" + "=" * 60)
- print("BUILDING STATE FILES")
- print("=" * 60)
- completed_states = load_completed_states()
- print(f"Already completed: {len(completed_states)} states")
- build_and_upload_states(
- inputs["weights"],
- inputs["dataset"],
- inputs["database"],
- WORK_DIR,
- completed_states,
- )
-
- if build_districts:
- print("\n" + "=" * 60)
- print("BUILDING DISTRICT FILES")
- print("=" * 60)
- completed_districts = load_completed_districts()
- print(f"Already completed: {len(completed_districts)} districts")
- build_and_upload_districts(
- inputs["weights"],
- inputs["dataset"],
- inputs["database"],
- WORK_DIR,
- completed_districts,
- )
-
- if build_cities:
- print("\n" + "=" * 60)
- print("BUILDING CITY FILES")
- print("=" * 60)
- completed_cities = load_completed_cities()
- print(f"Already completed: {len(completed_cities)} cities")
- build_and_upload_cities(
- inputs["weights"],
- inputs["dataset"],
- inputs["database"],
- WORK_DIR,
- completed_cities,
- )
-
- print("\n" + "=" * 60)
- print("ALL DONE!")
- print("=" * 60)
-
-
-if __name__ == "__main__":
- main()
diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/stacked_dataset_builder.py b/policyengine_us_data/datasets/cps/local_area_calibration/stacked_dataset_builder.py
deleted file mode 100644
index 010e151f3..000000000
--- a/policyengine_us_data/datasets/cps/local_area_calibration/stacked_dataset_builder.py
+++ /dev/null
@@ -1,932 +0,0 @@
-"""
-Create a sparse congressional district-stacked dataset with non-zero weight
-households.
-"""
-
-import os
-import numpy as np
-import pandas as pd
-import h5py
-from pathlib import Path
-from policyengine_us import Microsimulation
-from policyengine_core.data.dataset import Dataset
-from policyengine_core.enums import Enum
-from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import (
- get_all_cds_from_database,
- get_calculated_variables,
- STATE_CODES,
- STATE_FIPS_TO_NAME,
- STATE_FIPS_TO_CODE,
- load_cd_geoadj_values,
- calculate_spm_thresholds_for_cd,
-)
-from policyengine_us.variables.household.demographic.geographic.county.county_enum import (
- County,
-)
-from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
- assign_geography_for_cd,
- get_county_filter_probability,
- get_filtered_block_distribution,
-)
-
-NYC_COUNTIES = {
- "QUEENS_COUNTY_NY",
- "BRONX_COUNTY_NY",
- "RICHMOND_COUNTY_NY",
- "NEW_YORK_COUNTY_NY",
- "KINGS_COUNTY_NY",
-}
-
-NYC_CDS = [
- "3603",
- "3605",
- "3606",
- "3607",
- "3608",
- "3609",
- "3610",
- "3611",
- "3612",
- "3613",
- "3614",
- "3615",
- "3616",
-]
-
-
-def get_county_name(county_index: int) -> str:
- """Convert county enum index back to name."""
- return County._member_names_[county_index]
-
-
-def create_sparse_cd_stacked_dataset(
- w,
- cds_to_calibrate,
- cd_subset=None,
- output_path=None,
- dataset_path=None,
- county_filter=None,
- seed: int = 42,
-):
- """
- Create a SPARSE congressional district-stacked dataset using DataFrame approach.
-
- Args:
- w: Calibrated weight vector from L0 calibration. Shape is (n_cds * n_households,),
- reshaped internally to (n_cds, n_households) using cds_to_calibrate ordering.
- cds_to_calibrate: Ordered list of CD GEOID codes that defines the row ordering
- of the weight matrix. Required to correctly index into w for any cd_subset.
- cd_subset: Optional list of CD GEOIDs to include in output (must be subset of
- cds_to_calibrate). If None, includes all CDs.
- output_path: Where to save the sparse CD-stacked .h5 file.
- dataset_path: Path to the base .h5 dataset used during calibration.
- county_filter: Optional set of county names to filter to. Only households
- assigned to these counties will be included. Used for city-level datasets.
- seed: Base random seed for county assignment. Each CD gets seed + int(cd_geoid)
- for deterministic, order-independent results. Default 42.
-
- Returns:
- output_path: Path to the saved .h5 file.
- """
-
- # Handle CD subset filtering
- if cd_subset is not None:
- # Validate that requested CDs are in the calibration
- for cd in cd_subset:
- if cd not in cds_to_calibrate:
- raise ValueError(f"CD {cd} not in calibrated CDs list")
-
- # Get indices of requested CDs
- cd_indices = [cds_to_calibrate.index(cd) for cd in cd_subset]
- cds_to_process = cd_subset
-
- print(
- f"Processing subset of {len(cd_subset)} CDs: {', '.join(cd_subset[:5])}..."
- )
- else:
- # Process all CDs
- cd_indices = list(range(len(cds_to_calibrate)))
- cds_to_process = cds_to_calibrate
- print(
- f"Processing all {len(cds_to_calibrate)} congressional districts"
- )
-
- # Generate output path if not provided
- if output_path is None:
- raise ValueError("No output .h5 path given")
- print(f"Output path: {output_path}")
-
- # Check that output directory exists, create if needed
- output_dir_path = os.path.dirname(output_path)
- if output_dir_path and not os.path.exists(output_dir_path):
- print(f"Creating output directory: {output_dir_path}")
- os.makedirs(output_dir_path, exist_ok=True)
-
- # Load the original simulation
- base_sim = Microsimulation(dataset=dataset_path)
-
- household_ids = base_sim.calculate(
- "household_id", map_to="household"
- ).values
- n_households_orig = len(household_ids)
-
- # From the base sim, create mapping from household ID to index for proper filtering
- hh_id_to_idx = {int(hh_id): idx for idx, hh_id in enumerate(household_ids)}
-
- # Infer the number of households from weight vector and CD count
- if len(w) % len(cds_to_calibrate) != 0:
- raise ValueError(
- f"Weight vector length ({len(w):,}) is not evenly divisible by "
- f"number of CDs ({len(cds_to_calibrate)}). Cannot determine household count."
- )
- n_households_from_weights = len(w) // len(cds_to_calibrate)
-
- if n_households_from_weights != n_households_orig:
- raise ValueError(
- "Households from base data set do not match households from weights"
- )
-
- print(f"\nOriginal dataset has {n_households_orig:,} households")
-
- # Process the weight vector to understand active household-CD pairs
- W_full = w.reshape(len(cds_to_calibrate), n_households_orig)
- # (436, 10580)
-
- # Extract only the CDs we want to process
- if cd_subset is not None:
- W = W_full[cd_indices, :]
- print(
- f"Extracted weights for {len(cd_indices)} CDs from full weight matrix"
- )
- else:
- W = W_full
-
- # Count total active weights: i.e., number of active households
- total_active_weights = np.sum(W > 0)
- total_weight_in_W = np.sum(W)
- print(f"Total active household-CD pairs: {total_active_weights:,}")
- print(f"Total weight in W matrix: {total_weight_in_W:,.0f}")
-
- cd_geoadj_values = load_cd_geoadj_values(cds_to_calibrate)
-
- # Collect DataFrames for each CD
- cd_dfs = []
- total_kept_households = 0
- time_period = int(base_sim.default_calculation_period)
-
- for idx, cd_geoid in enumerate(cds_to_process):
- # Progress every 10 CDs and at the end ----
- if (idx + 1) % 10 == 0 or (idx + 1) == len(cds_to_process):
- print(
- f"Processing CD {cd_geoid} ({idx + 1}/{len(cds_to_process)})..."
- )
-
- # Get the correct index in the weight matrix
- cd_idx = idx # Index in our filtered W matrix
-
- # Get ALL households with non-zero weight in this CD
- active_household_indices = np.where(W[cd_idx, :] > 0)[0]
-
- if len(active_household_indices) == 0:
- continue
-
- # Get the household IDs for active households
- active_household_ids = set(
- household_ids[hh_idx] for hh_idx in active_household_indices
- )
-
- # Fresh simulation per CD is necessary because:
- # 1. Each CD needs different state_fips, county, and CD values set
- # 2. Calculated variables (SNAP, Medicaid, etc.) must be invalidated
- # and recalculated with the new geographic inputs
- # 3. Reusing a simulation would retain stale cached calculations
- # Memory impact: ~50MB per simulation, but allows correct state-specific
- # benefit calculations. Total memory scales with CD count in cd_subset.
- cd_sim = Microsimulation(dataset=dataset_path)
-
- # First, create hh_df with CALIBRATED weights from the W matrix
- household_ids_in_sim = cd_sim.calculate(
- "household_id", map_to="household"
- ).values
-
- # Get this CD's calibrated weights from the weight matrix
- calibrated_weights_for_cd = W[
- cd_idx, :
- ].copy() # Get this CD's row from weight matrix
-
- # For city datasets: scale weights by P(target|CD)
- # This preserves the representative sample while adjusting for target population
- if county_filter is not None:
- p_target = get_county_filter_probability(cd_geoid, county_filter)
- if p_target == 0:
- # CD has no overlap with target area, skip entirely
- continue
- calibrated_weights_for_cd = calibrated_weights_for_cd * p_target
-
- # Map the calibrated weights to household IDs
- hh_weight_values = []
- for hh_id in household_ids_in_sim:
- hh_idx = hh_id_to_idx[int(hh_id)] # Get index in weight matrix
- hh_weight_values.append(calibrated_weights_for_cd[hh_idx])
-
- entity_rel = pd.DataFrame(
- {
- "person_id": cd_sim.calculate(
- "person_id", map_to="person"
- ).values,
- "household_id": cd_sim.calculate(
- "household_id", map_to="person"
- ).values,
- "tax_unit_id": cd_sim.calculate(
- "tax_unit_id", map_to="person"
- ).values,
- "spm_unit_id": cd_sim.calculate(
- "spm_unit_id", map_to="person"
- ).values,
- "family_id": cd_sim.calculate(
- "family_id", map_to="person"
- ).values,
- "marital_unit_id": cd_sim.calculate(
- "marital_unit_id", map_to="person"
- ).values,
- }
- )
-
- hh_df = pd.DataFrame(
- {
- "household_id": household_ids_in_sim,
- "household_weight": hh_weight_values,
- }
- )
- counts = (
- entity_rel.groupby("household_id")["person_id"]
- .size()
- .reset_index(name="persons_per_hh")
- )
- hh_df = hh_df.merge(counts)
- hh_df["per_person_hh_weight"] = (
- hh_df.household_weight / hh_df.persons_per_hh
- )
-
- # SET WEIGHTS IN SIMULATION BEFORE EXTRACTING DATAFRAME
- # This is the key - set_input updates the simulation's internal state
-
- non_household_cols = [
- "person_id",
- "tax_unit_id",
- "spm_unit_id",
- "family_id",
- "marital_unit_id",
- ]
-
- new_weights_per_id = {}
- for col in non_household_cols:
- person_counts = (
- entity_rel.groupby(col)["person_id"]
- .size()
- .reset_index(name="person_id_count")
- )
- # Below: drop duplicates to undo the broadcast join done in entity_rel
- id_link = entity_rel[["household_id", col]].drop_duplicates()
- hh_info = id_link.merge(hh_df)
-
- hh_info2 = hh_info.merge(person_counts, on=col)
- if col == "person_id":
- # Person weight = household weight (each person represents same count as their household)
- hh_info2["id_weight"] = hh_info2.household_weight
- else:
- hh_info2["id_weight"] = (
- hh_info2.per_person_hh_weight * hh_info2.person_id_count
- )
- new_weights_per_id[col] = hh_info2.id_weight
-
- cd_sim.set_input(
- "household_weight", time_period, hh_df.household_weight.values
- )
- cd_sim.set_input(
- "person_weight", time_period, new_weights_per_id["person_id"]
- )
- cd_sim.set_input(
- "tax_unit_weight", time_period, new_weights_per_id["tax_unit_id"]
- )
- cd_sim.set_input(
- "spm_unit_weight", time_period, new_weights_per_id["spm_unit_id"]
- )
- cd_sim.set_input(
- "marital_unit_weight",
- time_period,
- new_weights_per_id["marital_unit_id"],
- )
- cd_sim.set_input(
- "family_weight", time_period, new_weights_per_id["family_id"]
- )
-
- # Extract state from CD GEOID and update simulation BEFORE calling to_input_dataframe()
- # This ensures calculated variables (SNAP, Medicaid) use the correct state
- cd_geoid_int = int(cd_geoid)
- state_fips = cd_geoid_int // 100
-
- cd_sim.set_input(
- "state_fips",
- time_period,
- np.full(n_households_orig, state_fips, dtype=np.int32),
- )
- cd_sim.set_input(
- "congressional_district_geoid",
- time_period,
- np.full(n_households_orig, cd_geoid_int, dtype=np.int32),
- )
-
- # Assign all geography using census block assignment
- # For city datasets: use only blocks in target counties
- if county_filter is not None:
- filtered_dist = get_filtered_block_distribution(
- cd_geoid, county_filter
- )
- if not filtered_dist:
- # Should not happen if we already checked p_target > 0
- continue
- geography = assign_geography_for_cd(
- cd_geoid=cd_geoid,
- n_households=n_households_orig,
- seed=seed + int(cd_geoid),
- distributions={cd_geoid: filtered_dist},
- )
- else:
- geography = assign_geography_for_cd(
- cd_geoid=cd_geoid,
- n_households=n_households_orig,
- seed=seed + int(cd_geoid),
- )
- # Set county using indices for backwards compatibility with PolicyEngine-US
- cd_sim.set_input("county", time_period, geography["county_index"])
-
- # Set all other geography variables from block assignment
- cd_sim.set_input("block_geoid", time_period, geography["block_geoid"])
- cd_sim.set_input("tract_geoid", time_period, geography["tract_geoid"])
- cd_sim.set_input("cbsa_code", time_period, geography["cbsa_code"])
- cd_sim.set_input("sldu", time_period, geography["sldu"])
- cd_sim.set_input("sldl", time_period, geography["sldl"])
- cd_sim.set_input("place_fips", time_period, geography["place_fips"])
- cd_sim.set_input("vtd", time_period, geography["vtd"])
- cd_sim.set_input("puma", time_period, geography["puma"])
- cd_sim.set_input("zcta", time_period, geography["zcta"])
-
- # Note: We no longer use binary filtering for county_filter.
- # Instead, weights are scaled by P(target|CD) and all households
- # are included to avoid sample selection bias.
-
- geoadj = cd_geoadj_values[cd_geoid]
- new_spm_thresholds = calculate_spm_thresholds_for_cd(
- cd_sim, time_period, geoadj, year=time_period
- )
- cd_sim.set_input(
- "spm_unit_spm_threshold", time_period, new_spm_thresholds
- )
-
- # Delete cached calculated variables to ensure they're recalculated
- # with new state and county. Exclude 'county' itself since we just set it.
- for var in get_calculated_variables(cd_sim):
- if var != "county":
- cd_sim.delete_arrays(var)
-
- # Now extract the dataframe - calculated vars will use the updated state
- df = cd_sim.to_input_dataframe()
-
- assert df.shape[0] == entity_rel.shape[0] # df is at the person level
-
- # Column names follow pattern: variable__year
- hh_id_col = f"household_id__{time_period}"
- cd_geoid_col = f"congressional_district_geoid__{time_period}"
- hh_weight_col = f"household_weight__{time_period}"
- person_weight_col = f"person_weight__{time_period}"
- tax_unit_weight_col = f"tax_unit_weight__{time_period}"
- person_id_col = f"person_id__{time_period}"
- tax_unit_id_col = f"tax_unit_id__{time_period}"
-
- state_fips_col = f"state_fips__{time_period}"
- state_name_col = f"state_name__{time_period}"
- state_code_col = f"state_code__{time_period}"
-
- # Filter to only active households in this CD
- df_filtered = df[df[hh_id_col].isin(active_household_ids)].copy()
-
- # Update congressional_district_geoid to target CD
- df_filtered[cd_geoid_col] = int(cd_geoid)
-
- # Update state variables for consistency
- df_filtered[state_fips_col] = state_fips
- if state_fips in STATE_FIPS_TO_NAME:
- df_filtered[state_name_col] = STATE_FIPS_TO_NAME[state_fips]
- if state_fips in STATE_FIPS_TO_CODE:
- df_filtered[state_code_col] = STATE_FIPS_TO_CODE[state_fips]
-
- cd_dfs.append(df_filtered)
- total_kept_households += len(df_filtered[hh_id_col].unique())
-
- print(f"\nCombining {len(cd_dfs)} CD DataFrames...")
- print(f"Total households across all CDs: {total_kept_households:,}")
-
- # Combine all CD DataFrames
- combined_df = pd.concat(cd_dfs, ignore_index=True)
- print(f"Combined DataFrame shape: {combined_df.shape}")
-
- # REINDEX ALL IDs TO PREVENT OVERFLOW AND HANDLE DUPLICATES
- print("\nReindexing all entity IDs using 25k ranges per CD...")
-
- # Column names
- hh_id_col = f"household_id__{time_period}"
- person_id_col = f"person_id__{time_period}"
- person_hh_id_col = f"person_household_id__{time_period}"
- tax_unit_id_col = f"tax_unit_id__{time_period}"
- person_tax_unit_col = f"person_tax_unit_id__{time_period}"
- spm_unit_id_col = f"spm_unit_id__{time_period}"
- person_spm_unit_col = f"person_spm_unit_id__{time_period}"
- marital_unit_id_col = f"marital_unit_id__{time_period}"
- person_marital_unit_col = f"person_marital_unit_id__{time_period}"
- family_id_col = f"family_id__{time_period}"
- person_family_col = f"person_family_id__{time_period}"
- cd_geoid_col = f"congressional_district_geoid__{time_period}"
-
- # Build CD index mapping from cds_to_calibrate (avoids database dependency)
- cds_sorted = sorted(cds_to_calibrate)
- cd_to_index = {cd: idx for idx, cd in enumerate(cds_sorted)}
-
- # Create household mapping for CSV export
- household_mapping = []
-
- # First, create a unique row identifier to track relationships
- combined_df["_row_idx"] = range(len(combined_df))
-
- # Group by household ID AND congressional district to create unique household-CD pairs
- hh_groups = (
- combined_df.groupby([hh_id_col, cd_geoid_col])["_row_idx"]
- .apply(list)
- .to_dict()
- )
-
- # Assign new household IDs using 25k ranges per CD
- hh_row_to_new_id = {}
- cd_hh_counters = {} # Track how many households assigned per CD
-
- for (old_hh_id, cd_geoid), row_indices in hh_groups.items():
- # Calculate the ID range for this CD directly (avoiding function call)
- cd_str = str(int(cd_geoid))
- cd_idx = cd_to_index[cd_str]
- start_id = cd_idx * 25_000
- end_id = start_id + 24_999
-
- # Get the next available ID in this CD's range
- if cd_str not in cd_hh_counters:
- cd_hh_counters[cd_str] = 0
-
- new_hh_id = start_id + cd_hh_counters[cd_str]
-
- # Check we haven't exceeded the range
- if new_hh_id > end_id:
- raise ValueError(
- f"CD {cd_str} exceeded its 25k household allocation"
- )
-
- # All rows in the same household-CD pair get the SAME new ID
- for row_idx in row_indices:
- hh_row_to_new_id[row_idx] = new_hh_id
-
- # Save the mapping
- household_mapping.append(
- {
- "new_household_id": new_hh_id,
- "original_household_id": int(old_hh_id),
- "congressional_district": cd_str,
- "state_fips": int(cd_str) // 100,
- }
- )
-
- cd_hh_counters[cd_str] += 1
-
- # Apply new household IDs based on row index
- combined_df["_new_hh_id"] = combined_df["_row_idx"].map(hh_row_to_new_id)
-
- # Update household IDs
- combined_df[hh_id_col] = combined_df["_new_hh_id"]
-
- # Update person household references - since persons are already in their households,
- # person_household_id should just match the household_id of their row
- combined_df[person_hh_id_col] = combined_df["_new_hh_id"]
-
- # Report statistics
- total_households = sum(cd_hh_counters.values())
- print(
- f" Created {total_households:,} unique households across {len(cd_hh_counters)} CDs"
- )
-
- # Now handle persons with same 25k range approach - VECTORIZED
- print(" Reindexing persons using 25k ranges...")
-
- # OFFSET PERSON IDs by 5 million to avoid collision with household IDs
- PERSON_ID_OFFSET = 5_000_000
-
- # Group by CD and assign IDs in bulk for each CD
- for cd_geoid_val in combined_df[cd_geoid_col].unique():
- cd_str = str(int(cd_geoid_val))
-
- # Calculate the ID range for this CD directly
- cd_idx = cd_to_index[cd_str]
- start_id = cd_idx * 25_000 + PERSON_ID_OFFSET # Add offset for persons
- end_id = start_id + 24_999
-
- # Get all rows for this CD
- cd_mask = combined_df[cd_geoid_col] == cd_geoid_val
- n_persons_in_cd = cd_mask.sum()
-
- # Check we won't exceed the range
- if n_persons_in_cd > (end_id - start_id + 1):
- raise ValueError(
- f"CD {cd_str} has {n_persons_in_cd} persons, exceeds 25k allocation"
- )
-
- # Create sequential IDs for this CD
- new_person_ids = np.arange(
- start_id, start_id + n_persons_in_cd, dtype=np.int32
- )
-
- # Assign all at once using loc
- combined_df.loc[cd_mask, person_id_col] = new_person_ids
-
- # Reindex sub-household entities using vectorized groupby().ngroup()
- # This assigns unique IDs to each (household_id, original_entity_id) pair,
- # which correctly handles the same original household appearing in multiple CDs
- entity_configs = [
- ("tax units", person_tax_unit_col, tax_unit_id_col),
- ("SPM units", person_spm_unit_col, spm_unit_id_col),
- ("marital units", person_marital_unit_col, marital_unit_id_col),
- ("families", person_family_col, family_id_col),
- ]
-
- for entity_name, person_col, entity_col in entity_configs:
- print(f" Reindexing {entity_name}...")
- # Group by (household_id, original_entity_id) and assign unique group numbers
- new_ids = combined_df.groupby(
- [hh_id_col, person_col], sort=False
- ).ngroup()
- combined_df[person_col] = new_ids
- if entity_col in combined_df.columns:
- combined_df[entity_col] = new_ids
-
- # Clean up temporary columns
- temp_cols = [col for col in combined_df.columns if col.startswith("_")]
- combined_df = combined_df.drop(columns=temp_cols)
-
- print(f" Final persons: {len(combined_df):,}")
- print(f" Final households: {total_households:,}")
- print(f" Final tax units: {combined_df[person_tax_unit_col].nunique():,}")
- print(f" Final SPM units: {combined_df[person_spm_unit_col].nunique():,}")
- print(
- f" Final marital units: {combined_df[person_marital_unit_col].nunique():,}"
- )
- print(f" Final families: {combined_df[person_family_col].nunique():,}")
-
- # Check weights in combined_df AFTER reindexing
- print(f"\nWeights in combined_df AFTER reindexing:")
- print(f" HH weight sum: {combined_df[hh_weight_col].sum()/1e6:.2f}M")
- print(
- f" Person weight sum: {combined_df[person_weight_col].sum()/1e6:.2f}M"
- )
- print(
- f" Ratio: {combined_df[person_weight_col].sum() / combined_df[hh_weight_col].sum():.2f}"
- )
-
- # Verify no overflow risk
- max_person_id = combined_df[person_id_col].max()
- print(f"\nOverflow check:")
- print(f" Max person ID after reindexing: {max_person_id:,}")
- print(f" Max person ID × 100: {max_person_id * 100:,}")
- print(f" int32 max: {2_147_483_647:,}")
- if max_person_id * 100 < 2_147_483_647:
- print(" ✓ No overflow risk!")
- else:
- print(" ⚠️ WARNING: Still at risk of overflow!")
-
- # Create Dataset from combined DataFrame
- print("\nCreating Dataset from combined DataFrame...")
- sparse_dataset = Dataset.from_dataframe(combined_df, time_period)
-
- # Build a simulation to convert to h5
- print("Building simulation from Dataset...")
- sparse_sim = Microsimulation()
- sparse_sim.dataset = sparse_dataset
- sparse_sim.build_from_dataset()
-
- # Save to h5 file
- print(f"\nSaving to {output_path}...")
- data = {}
-
- # Only save input variables (not calculated/derived variables)
- # Calculated variables like state_name, state_code will be recalculated on load
- vars_to_save = set(base_sim.input_variables)
- print(f"Found {len(vars_to_save)} input variables to save")
-
- # congressional_district_geoid isn't in the original microdata and has no formula,
- # so it's not in input_vars. Since we set it explicitly during stacking, save it.
- vars_to_save.add("congressional_district_geoid")
-
- # county is set explicitly with assign_counties_for_cd, must be saved
- vars_to_save.add("county")
-
- # spm_unit_spm_threshold is recalculated with CD-specific geo-adjustment
- vars_to_save.add("spm_unit_spm_threshold")
-
- # Add all geography variables set during block assignment
- vars_to_save.add("block_geoid")
- vars_to_save.add("tract_geoid")
- vars_to_save.add("cbsa_code")
- vars_to_save.add("sldu")
- vars_to_save.add("sldl")
- vars_to_save.add("place_fips")
- vars_to_save.add("vtd")
- vars_to_save.add("puma")
- vars_to_save.add("zcta")
-
- variables_saved = 0
- variables_skipped = 0
-
- for variable in sparse_sim.tax_benefit_system.variables:
- if variable not in vars_to_save:
- variables_skipped += 1
- continue
-
- # Only process variables that have actual data
- data[variable] = {}
- for period in sparse_sim.get_holder(variable).get_known_periods():
- values = sparse_sim.get_holder(variable).get_array(period)
-
- # Handle different value types
- if (
- sparse_sim.tax_benefit_system.variables.get(
- variable
- ).value_type
- in (Enum, str)
- and variable != "county_fips"
- ):
- # Handle EnumArray objects
- if hasattr(values, "decode_to_str"):
- values = values.decode_to_str().astype("S")
- else:
- # Already a regular numpy array, just convert to string type
- values = values.astype("S")
- elif variable == "county_fips":
- values = values.astype("int32")
- else:
- values = np.array(values)
-
- if values is not None:
- data[variable][period] = values
- variables_saved += 1
-
- if len(data[variable]) == 0:
- del data[variable]
-
- print(f"Variables saved: {variables_saved}")
- print(f"Variables skipped: {variables_skipped}")
-
- # Write to h5
- with h5py.File(output_path, "w") as f:
- for variable, periods in data.items():
- grp = f.create_group(variable)
- for period, values in periods.items():
- grp.create_dataset(str(period), data=values)
-
- print(f"Sparse CD-stacked dataset saved successfully!")
-
- # Save household mapping to CSV in a mappings subdirectory
- mapping_df = pd.DataFrame(household_mapping)
- output_dir = os.path.dirname(output_path)
- mappings_dir = (
- os.path.join(output_dir, "mappings") if output_dir else "mappings"
- )
- os.makedirs(mappings_dir, exist_ok=True)
- csv_filename = os.path.basename(output_path).replace(
- ".h5", "_household_mapping.csv"
- )
- csv_path = os.path.join(mappings_dir, csv_filename)
- mapping_df.to_csv(csv_path, index=False)
- print(f"Household mapping saved to {csv_path}")
-
- # Verify the saved file
- print("\nVerifying saved file...")
- with h5py.File(output_path, "r") as f:
- if "household_id" in f and str(time_period) in f["household_id"]:
- hh_ids = f["household_id"][str(time_period)][:]
- print(f" Final households: {len(hh_ids):,}")
- if "person_id" in f and str(time_period) in f["person_id"]:
- person_ids = f["person_id"][str(time_period)][:]
- print(f" Final persons: {len(person_ids):,}")
- if (
- "household_weight" in f
- and str(time_period) in f["household_weight"]
- ):
- weights = f["household_weight"][str(time_period)][:]
- print(
- f" Total population (from household weights): {np.sum(weights):,.0f}"
- )
- if "person_weight" in f and str(time_period) in f["person_weight"]:
- person_weights = f["person_weight"][str(time_period)][:]
- print(
- f" Total population (from person weights): {np.sum(person_weights):,.0f}"
- )
- print(
- f" Average persons per household: {np.sum(person_weights) / np.sum(weights):.2f}"
- )
-
- return output_path
-
-
-if __name__ == "__main__":
- import argparse
-
- parser = argparse.ArgumentParser(
- description="Create sparse CD-stacked datasets"
- )
- parser.add_argument(
- "--weights-path", required=True, help="Path to w_cd.npy file"
- )
- parser.add_argument(
- "--dataset-path",
- required=True,
- help="Path to stratified dataset .h5 file",
- )
- parser.add_argument(
- "--db-path", required=True, help="Path to policy_data.db"
- )
- parser.add_argument(
- "--output-dir",
- default="./temp",
- help="Output directory for files",
- )
- parser.add_argument(
- "--mode",
- choices=[
- "national",
- "states",
- "cds",
- "single-cd",
- "single-state",
- "nyc",
- ],
- default="national",
- help="Output mode: national (one file), states (per-state files), cds (per-CD files), single-cd (one CD), single-state (one state), nyc (NYC only)",
- )
- parser.add_argument(
- "--cd",
- type=str,
- help="Single CD GEOID to process (only used with --mode single-cd)",
- )
- parser.add_argument(
- "--state",
- type=str,
- help="State code to process, e.g. RI, CA, NC (only used with --mode single-state)",
- )
-
- args = parser.parse_args()
- dataset_path_str = args.dataset_path
- weights_path_str = args.weights_path
- db_path = Path(args.db_path).resolve()
- output_dir = args.output_dir
- mode = args.mode
-
- os.makedirs(output_dir, exist_ok=True)
-
- # Load weights
- w = np.load(weights_path_str)
- db_uri = f"sqlite:///{db_path}"
-
- # Get list of CDs from database
- cds_to_calibrate = get_all_cds_from_database(db_uri)
- print(f"Found {len(cds_to_calibrate)} congressional districts")
-
- # Verify dimensions
- assert_sim = Microsimulation(dataset=dataset_path_str)
- n_hh = assert_sim.calculate("household_id", map_to="household").shape[0]
- expected_length = len(cds_to_calibrate) * n_hh
-
- if len(w) != expected_length:
- raise ValueError(
- f"Weight vector length ({len(w):,}) doesn't match expected ({expected_length:,})"
- )
-
- if mode == "national":
- output_path = f"{output_dir}/national.h5"
- print(f"\nCreating national dataset with all CDs: {output_path}")
- create_sparse_cd_stacked_dataset(
- w,
- cds_to_calibrate,
- dataset_path=dataset_path_str,
- output_path=output_path,
- )
-
- elif mode == "states":
- for state_fips, state_code in STATE_CODES.items():
- cd_subset = [
- cd for cd in cds_to_calibrate if int(cd) // 100 == state_fips
- ]
- if not cd_subset:
- continue
- output_path = f"{output_dir}/{state_code}.h5"
- print(f"\nCreating {state_code} dataset: {output_path}")
- create_sparse_cd_stacked_dataset(
- w,
- cds_to_calibrate,
- cd_subset=cd_subset,
- dataset_path=dataset_path_str,
- output_path=output_path,
- )
-
- elif mode == "cds":
- for i, cd_geoid in enumerate(cds_to_calibrate):
- # Convert GEOID to friendly name: 3705 -> NC-05
- cd_int = int(cd_geoid)
- state_fips = cd_int // 100
- district_num = cd_int % 100
- state_code = STATE_CODES.get(state_fips, str(state_fips))
- friendly_name = f"{state_code}-{district_num:02d}"
-
- output_path = f"{output_dir}/{friendly_name}.h5"
- print(
- f"\n[{i+1}/{len(cds_to_calibrate)}] Creating {friendly_name}.h5 (GEOID {cd_geoid})"
- )
- create_sparse_cd_stacked_dataset(
- w,
- cds_to_calibrate,
- cd_subset=[cd_geoid],
- dataset_path=dataset_path_str,
- output_path=output_path,
- )
-
- elif mode == "single-cd":
- if not args.cd:
- raise ValueError("--cd required with --mode single-cd")
- if args.cd not in cds_to_calibrate:
- raise ValueError(f"CD {args.cd} not in calibrated CDs list")
- output_path = f"{output_dir}/{args.cd}.h5"
- print(f"\nCreating single CD dataset: {output_path}")
- create_sparse_cd_stacked_dataset(
- w,
- cds_to_calibrate,
- cd_subset=[args.cd],
- dataset_path=dataset_path_str,
- output_path=output_path,
- )
-
- elif mode == "single-state":
- if not args.state:
- raise ValueError("--state required with --mode single-state")
- # Find FIPS code for this state
- state_code_upper = args.state.upper()
- state_fips = None
- for fips, code in STATE_CODES.items():
- if code == state_code_upper:
- state_fips = fips
- break
- if state_fips is None:
- raise ValueError(f"Unknown state code: {args.state}")
-
- cd_subset = [
- cd for cd in cds_to_calibrate if int(cd) // 100 == state_fips
- ]
- if not cd_subset:
- raise ValueError(f"No CDs found for state {state_code_upper}")
-
- output_path = f"{output_dir}/{state_code_upper}.h5"
- print(
- f"\nCreating {state_code_upper} dataset with {len(cd_subset)} CDs: {output_path}"
- )
- create_sparse_cd_stacked_dataset(
- w,
- cds_to_calibrate,
- cd_subset=cd_subset,
- dataset_path=dataset_path_str,
- output_path=output_path,
- )
-
- elif mode == "nyc":
- cd_subset = [cd for cd in cds_to_calibrate if cd in NYC_CDS]
- if not cd_subset:
- raise ValueError("No NYC-related CDs found in calibrated CDs list")
-
- output_path = f"{output_dir}/NYC.h5"
- print(
- f"\nCreating NYC dataset with {len(cd_subset)} CDs: {output_path}"
- )
- print(f" CDs: {', '.join(cd_subset)}")
- print(" Filtering to NYC counties only")
-
- create_sparse_cd_stacked_dataset(
- w,
- cds_to_calibrate,
- cd_subset=cd_subset,
- dataset_path=dataset_path_str,
- output_path=output_path,
- county_filter=NYC_COUNTIES,
- )
-
- print("\nDone!")
diff --git a/policyengine_us_data/db/create_initial_strata.py b/policyengine_us_data/db/create_initial_strata.py
index 0b9ae8a6d..253262c90 100644
--- a/policyengine_us_data/db/create_initial_strata.py
+++ b/policyengine_us_data/db/create_initial_strata.py
@@ -40,8 +40,9 @@ def fetch_congressional_districts(year):
df["state_fips"] = df["state"].astype(int)
df = df[df["state_fips"] <= 56].copy()
df["district_number"] = df["congressional district"].apply(
- lambda x: 0 if x in ["ZZ", "98"] else int(x)
+ lambda x: int(x) if x not in ["ZZ"] else -1
)
+ df = df[df["district_number"] >= 0].copy()
# Filter out statewide summary records for multi-district states
df["n_districts"] = df.groupby("state_fips")["state_fips"].transform(
@@ -49,8 +50,6 @@ def fetch_congressional_districts(year):
)
df = df[(df["n_districts"] == 1) | (df["district_number"] > 0)].copy()
df = df.drop(columns=["n_districts"])
-
- df.loc[df["district_number"] == 0, "district_number"] = 1
df["congressional_district_geoid"] = (
df["state_fips"] * 100 + df["district_number"]
)
diff --git a/policyengine_us_data/storage/calibration/w_district_calibration.npy b/policyengine_us_data/storage/calibration/w_district_calibration.npy
deleted file mode 100644
index 6059d1818..000000000
Binary files a/policyengine_us_data/storage/calibration/w_district_calibration.npy and /dev/null differ
diff --git a/policyengine_us_data/tests/test_calibration/conftest.py b/policyengine_us_data/tests/test_calibration/conftest.py
index 9b8edcf74..354491567 100644
--- a/policyengine_us_data/tests/test_calibration/conftest.py
+++ b/policyengine_us_data/tests/test_calibration/conftest.py
@@ -1,4 +1,18 @@
-# Calibration test fixtures.
-#
-# The microimpute mock lives in the root conftest.py (propagates to
-# all subdirectories). Add calibration-specific fixtures here.
+"""Shared fixtures for local area calibration tests."""
+
+import pytest
+
+from policyengine_us_data.storage import STORAGE_FOLDER
+
+
+@pytest.fixture(scope="module")
+def db_uri():
+ """Return a ``sqlite:///`` URI for ``calibration/policy_data.db``.
+
+ The path is built under ``STORAGE_FOLDER``; the fixture does not check
+ that the database file exists.
+ """
+ db_path = STORAGE_FOLDER / "calibration" / "policy_data.db"
+ return f"sqlite:///{db_path}"
+
+
+@pytest.fixture(scope="module")
+def dataset_path():
+ """Return, as a string, the path to the H5 dataset used by the tests.
+
+ Points at ``source_imputed_stratified_extended_cps_2024.h5`` under
+ ``STORAGE_FOLDER``; the fixture does not check that the file exists.
+ """
+ return str(
+ STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5"
+ )
diff --git a/policyengine_us_data/tests/test_local_area_calibration/create_test_fixture.py b/policyengine_us_data/tests/test_calibration/create_test_fixture.py
similarity index 100%
rename from policyengine_us_data/tests/test_local_area_calibration/create_test_fixture.py
rename to policyengine_us_data/tests/test_calibration/create_test_fixture.py
diff --git a/policyengine_us_data/tests/test_local_area_calibration/test_block_assignment.py b/policyengine_us_data/tests/test_calibration/test_block_assignment.py
similarity index 82%
rename from policyengine_us_data/tests/test_local_area_calibration/test_block_assignment.py
rename to policyengine_us_data/tests/test_calibration/test_block_assignment.py
index 0f1001385..c128d65e6 100644
--- a/policyengine_us_data/tests/test_local_area_calibration/test_block_assignment.py
+++ b/policyengine_us_data/tests/test_calibration/test_block_assignment.py
@@ -14,7 +14,7 @@ class TestBlockAssignment:
def test_assign_returns_correct_shape(self):
"""Verify assign_blocks_for_cd returns correct shape."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
assign_blocks_for_cd,
)
@@ -26,7 +26,7 @@ def test_assign_returns_correct_shape(self):
def test_assign_is_deterministic(self):
"""Verify same seed produces same results."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
assign_blocks_for_cd,
)
@@ -36,7 +36,7 @@ def test_assign_is_deterministic(self):
def test_different_seeds_different_results(self):
"""Verify different seeds produce different results."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
assign_blocks_for_cd,
)
@@ -46,7 +46,7 @@ def test_different_seeds_different_results(self):
def test_ny_cd_gets_ny_blocks(self):
"""Verify NY CDs get NY blocks."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
assign_blocks_for_cd,
)
@@ -59,7 +59,7 @@ def test_ny_cd_gets_ny_blocks(self):
def test_ca_cd_gets_ca_blocks(self):
"""Verify CA CDs get CA blocks."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
assign_blocks_for_cd,
)
@@ -76,7 +76,7 @@ class TestGeographyLookup:
def test_get_county_from_block(self):
"""Verify county FIPS extraction from block GEOID."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
get_county_fips_from_block,
)
@@ -89,7 +89,7 @@ def test_get_county_from_block(self):
def test_get_tract_from_block(self):
"""Verify tract GEOID extraction from block GEOID."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
get_tract_geoid_from_block,
)
@@ -100,7 +100,7 @@ def test_get_tract_from_block(self):
def test_get_state_fips_from_block(self):
"""Verify state FIPS extraction from block GEOID."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
get_state_fips_from_block,
)
@@ -114,7 +114,7 @@ class TestCBSALookup:
def test_manhattan_in_nyc_metro(self):
"""Verify Manhattan (New York County) is in NYC metro area."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
get_cbsa_from_county,
)
@@ -125,7 +125,7 @@ def test_manhattan_in_nyc_metro(self):
def test_sf_county_in_sf_metro(self):
"""Verify San Francisco County is in SF metro area."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
get_cbsa_from_county,
)
@@ -136,7 +136,7 @@ def test_sf_county_in_sf_metro(self):
def test_rural_county_no_cbsa(self):
"""Verify rural county not in any metro area returns None."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
get_cbsa_from_county,
)
@@ -150,7 +150,7 @@ class TestIntegratedAssignment:
def test_assign_geography_returns_all_fields(self):
"""Verify assign_geography returns dict with all geography fields."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
assign_geography_for_cd,
)
@@ -181,7 +181,7 @@ def test_assign_geography_returns_all_fields(self):
def test_geography_is_consistent(self):
"""Verify all geography fields are consistent with each other."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
assign_geography_for_cd,
)
@@ -207,7 +207,7 @@ class TestStateLegislativeDistricts:
def test_get_sldu_from_block(self):
"""Verify SLDU lookup from block GEOID."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
get_sldu_from_block,
)
@@ -218,7 +218,7 @@ def test_get_sldu_from_block(self):
def test_get_sldl_from_block(self):
"""Verify SLDL lookup from block GEOID."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
get_sldl_from_block,
)
@@ -229,7 +229,7 @@ def test_get_sldl_from_block(self):
def test_assign_geography_includes_state_leg(self):
"""Verify assign_geography includes SLDU and SLDL."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
assign_geography_for_cd,
)
@@ -246,7 +246,7 @@ class TestPlaceLookup:
def test_get_place_fips_from_block(self):
"""Verify place FIPS lookup from block GEOID."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
get_place_fips_from_block,
)
@@ -257,7 +257,7 @@ def test_get_place_fips_from_block(self):
def test_assign_geography_includes_place(self):
"""Verify assign_geography includes place_fips."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
assign_geography_for_cd,
)
@@ -272,7 +272,7 @@ class TestPUMALookup:
def test_get_puma_from_block(self):
"""Verify PUMA lookup from block GEOID."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
get_puma_from_block,
)
@@ -283,7 +283,7 @@ def test_get_puma_from_block(self):
def test_assign_geography_includes_puma(self):
"""Verify assign_geography includes PUMA."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
assign_geography_for_cd,
)
@@ -298,7 +298,7 @@ class TestVTDLookup:
def test_get_vtd_from_block(self):
"""Verify VTD lookup from block GEOID."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
get_vtd_from_block,
)
@@ -309,7 +309,7 @@ def test_get_vtd_from_block(self):
def test_assign_geography_includes_vtd(self):
"""Verify assign_geography includes VTD."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
assign_geography_for_cd,
)
@@ -324,7 +324,7 @@ class TestAllGeographyLookup:
def test_get_all_geography_returns_all_fields(self):
"""Verify get_all_geography_from_block returns all expected fields."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
get_all_geography_from_block,
)
@@ -336,7 +336,7 @@ def test_get_all_geography_returns_all_fields(self):
def test_get_all_geography_unknown_block(self):
"""Verify get_all_geography handles unknown block gracefully."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
get_all_geography_from_block,
)
@@ -352,7 +352,7 @@ class TestCountyEnumIntegration:
def test_get_county_enum_from_block(self):
"""Verify we can get County enum index from block GEOID."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
get_county_enum_index_from_block,
)
from policyengine_us.variables.household.demographic.geographic.county.county_enum import (
@@ -368,7 +368,7 @@ def test_get_county_enum_from_block(self):
def test_assign_geography_includes_county_index(self):
"""Verify assign_geography includes county_index for backwards compat."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
assign_geography_for_cd,
)
from policyengine_us.variables.household.demographic.geographic.county.county_enum import (
@@ -392,7 +392,7 @@ class TestZCTALookup:
def test_get_zcta_from_block(self):
"""Verify ZCTA lookup from block GEOID."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
get_zcta_from_block,
)
@@ -403,7 +403,7 @@ def test_get_zcta_from_block(self):
def test_assign_geography_includes_zcta(self):
"""Verify assign_geography includes ZCTA."""
- from policyengine_us_data.datasets.cps.local_area_calibration.block_assignment import (
+ from policyengine_us_data.calibration.block_assignment import (
assign_geography_for_cd,
)
diff --git a/policyengine_us_data/tests/test_calibration/test_build_matrix_masking.py b/policyengine_us_data/tests/test_calibration/test_build_matrix_masking.py
index 8db56ddcb..58bd3a4f3 100644
--- a/policyengine_us_data/tests/test_calibration/test_build_matrix_masking.py
+++ b/policyengine_us_data/tests/test_calibration/test_build_matrix_masking.py
@@ -15,13 +15,14 @@
from policyengine_us_data.storage import STORAGE_FOLDER
-DATASET_PATH = str(STORAGE_FOLDER / "stratified_extended_cps_2024.h5")
+DATASET_PATH = str(
+ STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5"
+)
DB_PATH = str(STORAGE_FOLDER / "calibration" / "policy_data.db")
DB_URI = f"sqlite:///{DB_PATH}"
N_CLONES = 2
SEED = 42
-RECORD_IDX = 8629 # High SNAP ($18k), lands in TX/PA with seed=42
def _data_available():
@@ -56,12 +57,34 @@ def matrix_result():
sim=sim,
target_filter={"domain_variables": ["snap", "medicaid"]},
)
+ X_csc = X_sparse.tocsc()
+ national_rows = targets_df[
+ targets_df["geo_level"] == "national"
+ ].index.values
+ district_targets = targets_df[targets_df["geo_level"] == "district"]
+ record_idx = None
+ for ri in range(n_records):
+ vals = X_csc[:, ri].toarray().ravel()
+ if not np.any(vals[national_rows] != 0):
+ continue
+ cd = str(geography.cd_geoid[ri])
+ own_cd_rows = district_targets[
+ district_targets["geographic_id"] == cd
+ ].index.values
+ if len(own_cd_rows) > 0 and np.any(vals[own_cd_rows] != 0):
+ record_idx = ri
+ break
+
+ if record_idx is None:
+ pytest.skip("No suitable test household found")
+
return {
"geography": geography,
"targets_df": targets_df,
"X": X_sparse,
"target_names": target_names,
"n_records": n_records,
+ "record_idx": record_idx,
}
@@ -94,8 +117,8 @@ def test_both_clones_visible_to_national_target(self, matrix_result):
national_rows = targets_df[targets_df["geo_level"] == "national"].index
assert len(national_rows) > 0
- col_0 = _clone_col(n_records, 0, RECORD_IDX)
- col_1 = _clone_col(n_records, 1, RECORD_IDX)
+ col_0 = _clone_col(n_records, 0, matrix_result["record_idx"])
+ col_1 = _clone_col(n_records, 1, matrix_result["record_idx"])
X_csc = X.tocsc()
visible_0 = X_csc[:, col_0].toarray().ravel()
@@ -117,8 +140,8 @@ def test_clone_visible_only_to_own_state(self, matrix_result):
geography = matrix_result["geography"]
n_records = matrix_result["n_records"]
- col_0 = _clone_col(n_records, 0, RECORD_IDX)
- col_1 = _clone_col(n_records, 1, RECORD_IDX)
+ col_0 = _clone_col(n_records, 0, matrix_result["record_idx"])
+ col_1 = _clone_col(n_records, 1, matrix_result["record_idx"])
state_0 = str(int(geography.state_fips[col_0]))
state_1 = str(int(geography.state_fips[col_1]))
@@ -155,7 +178,7 @@ def test_clone_visible_only_to_own_cd(self, matrix_result):
geography = matrix_result["geography"]
n_records = matrix_result["n_records"]
- col_0 = _clone_col(n_records, 0, RECORD_IDX)
+ col_0 = _clone_col(n_records, 0, matrix_result["record_idx"])
cd_0 = str(geography.cd_geoid[col_0])
state_0 = str(int(geography.state_fips[col_0]))
@@ -185,7 +208,7 @@ def test_clone_nonzero_for_own_cd(self, matrix_result):
geography = matrix_result["geography"]
n_records = matrix_result["n_records"]
- col_0 = _clone_col(n_records, 0, RECORD_IDX)
+ col_0 = _clone_col(n_records, 0, matrix_result["record_idx"])
cd_0 = str(geography.cd_geoid[col_0])
own_cd_targets = targets_df[
diff --git a/policyengine_us_data/tests/test_calibration/test_clone_and_assign.py b/policyengine_us_data/tests/test_calibration/test_clone_and_assign.py
index 0ba330549..b2d45bd55 100644
--- a/policyengine_us_data/tests/test_calibration/test_clone_and_assign.py
+++ b/policyengine_us_data/tests/test_calibration/test_clone_and_assign.py
@@ -133,6 +133,22 @@ def test_state_from_block(self, mock_load):
expected = int(r.block_geoid[i][:2])
assert r.state_fips[i] == expected
+ @patch(
+ "policyengine_us_data.calibration.clone_and_assign"
+ ".load_global_block_distribution"
+ )
+ def test_no_cd_collisions_across_clones(self, mock_load):
+ """No record may receive the same ``cd_geoid`` in two of its clones.
+
+ Uses the mocked block distribution with 100 records, 3 clones and
+ seed 42, then checks every record's per-clone ``cd_geoid`` values
+ are pairwise distinct.
+ """
+ mock_load.return_value = _mock_distribution()
+ r = assign_random_geography(n_records=100, n_clones=3, seed=42)
+ for rec in range(r.n_records):
+ # Clone-major indexing: clone c of record rec is read from
+ # position c * n_records + rec.
+ rec_cds = [
+ r.cd_geoid[clone * r.n_records + rec]
+ for clone in range(r.n_clones)
+ ]
+ # A set drops duplicates, so equal lengths mean all values differ.
+ assert len(rec_cds) == len(
+ set(rec_cds)
+ ), f"Record {rec} has duplicate CDs: {rec_cds}"
+
def test_missing_file_raises(self, tmp_path):
fake = tmp_path / "nonexistent"
fake.mkdir()
@@ -150,6 +166,7 @@ def test_doubles_n_records(self):
geo = GeographyAssignment(
block_geoid=np.array(["010010001001001", "020010001001001"] * 3),
cd_geoid=np.array(["101", "202"] * 3),
+ county_fips=np.array(["01001", "02001"] * 3),
state_fips=np.array([1, 2] * 3),
n_records=2,
n_clones=3,
@@ -172,6 +189,9 @@ def test_puf_half_matches_cps_half(self):
]
),
cd_geoid=np.array(["101", "202", "1036", "653", "4831", "1227"]),
+ county_fips=np.array(
+ ["01001", "02001", "36010", "06010", "48010", "12010"]
+ ),
state_fips=np.array([1, 2, 36, 6, 48, 12]),
n_records=3,
n_clones=2,
diff --git a/policyengine_us_data/tests/test_local_area_calibration/test_county_assignment.py b/policyengine_us_data/tests/test_calibration/test_county_assignment.py
similarity index 98%
rename from policyengine_us_data/tests/test_local_area_calibration/test_county_assignment.py
rename to policyengine_us_data/tests/test_calibration/test_county_assignment.py
index 158e0ca68..03d7342d9 100644
--- a/policyengine_us_data/tests/test_local_area_calibration/test_county_assignment.py
+++ b/policyengine_us_data/tests/test_calibration/test_county_assignment.py
@@ -6,7 +6,7 @@
from policyengine_us.variables.household.demographic.geographic.county.county_enum import (
County,
)
-from policyengine_us_data.datasets.cps.local_area_calibration.county_assignment import (
+from policyengine_us_data.calibration.county_assignment import (
assign_counties_for_cd,
get_county_index,
_build_state_counties,
diff --git a/policyengine_us_data/tests/test_calibration/test_drop_target_groups.py b/policyengine_us_data/tests/test_calibration/test_drop_target_groups.py
index daade621d..c69abe76a 100644
--- a/policyengine_us_data/tests/test_calibration/test_drop_target_groups.py
+++ b/policyengine_us_data/tests/test_calibration/test_drop_target_groups.py
@@ -5,7 +5,7 @@
import pytest
from scipy import sparse
-from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import (
+from policyengine_us_data.calibration.calibration_utils import (
drop_target_groups,
create_target_groups,
)
diff --git a/policyengine_us_data/tests/test_local_area_calibration/test_fixture_50hh.h5 b/policyengine_us_data/tests/test_calibration/test_fixture_50hh.h5
similarity index 100%
rename from policyengine_us_data/tests/test_local_area_calibration/test_fixture_50hh.h5
rename to policyengine_us_data/tests/test_calibration/test_fixture_50hh.h5
diff --git a/policyengine_us_data/tests/test_local_area_calibration/test_stacked_dataset_builder.py b/policyengine_us_data/tests/test_calibration/test_stacked_dataset_builder.py
similarity index 75%
rename from policyengine_us_data/tests/test_local_area_calibration/test_stacked_dataset_builder.py
rename to policyengine_us_data/tests/test_calibration/test_stacked_dataset_builder.py
index 2900eec19..48177726a 100644
--- a/policyengine_us_data/tests/test_local_area_calibration/test_stacked_dataset_builder.py
+++ b/policyengine_us_data/tests/test_calibration/test_stacked_dataset_builder.py
@@ -1,4 +1,4 @@
-"""Tests for stacked_dataset_builder.py using deterministic test fixture."""
+"""Tests for build_h5 using deterministic test fixture."""
import os
import tempfile
@@ -6,13 +6,14 @@
import pandas as pd
import pytest
+from pathlib import Path
from policyengine_us import Microsimulation
-from policyengine_us_data.datasets.cps.local_area_calibration.stacked_dataset_builder import (
- create_sparse_cd_stacked_dataset,
+from policyengine_us_data.calibration.publish_local_area import (
+ build_h5,
)
FIXTURE_PATH = os.path.join(os.path.dirname(__file__), "test_fixture_50hh.h5")
-TEST_CDS = ["3701", "201"] # NC-01 and AK at-large
+TEST_CDS = ["3701", "200"] # NC-01 and AK at-large
SEED = 42
@@ -48,12 +49,13 @@ def stacked_result(test_weights):
with tempfile.TemporaryDirectory() as tmpdir:
output_path = os.path.join(tmpdir, "test_output.h5")
- create_sparse_cd_stacked_dataset(
- test_weights,
- TEST_CDS,
+ build_h5(
+ weights=np.array(test_weights),
+ blocks=None,
+ dataset_path=Path(FIXTURE_PATH),
+ output_path=Path(output_path),
+ cds_to_calibrate=TEST_CDS,
cd_subset=TEST_CDS,
- dataset_path=FIXTURE_PATH,
- output_path=output_path,
)
sim_after = Microsimulation(dataset=output_path)
@@ -69,12 +71,7 @@ def stacked_result(test_weights):
)
)
- mapping_path = os.path.join(
- tmpdir, "mappings", "test_output_household_mapping.csv"
- )
- mapping_df = pd.read_csv(mapping_path)
-
- yield {"hh_df": hh_df, "mapping_df": mapping_df}
+ yield {"hh_df": hh_df}
class TestStackedDatasetBuilder:
@@ -85,10 +82,10 @@ def test_output_has_correct_cd_count(self, stacked_result):
assert len(cds_in_output) == len(TEST_CDS)
def test_output_contains_both_cds(self, stacked_result):
- """Output should contain both NC-01 (3701) and AK-AL (201)."""
+ """Output should contain both NC-01 (3701) and AK-AL (200)."""
hh_df = stacked_result["hh_df"]
cds_in_output = set(hh_df["congressional_district_geoid"].unique())
- expected = {3701, 201}
+ expected = {3701, 200}
assert cds_in_output == expected
def test_state_fips_matches_cd(self, stacked_result):
@@ -106,27 +103,6 @@ def test_household_ids_are_unique(self, stacked_result):
hh_df = stacked_result["hh_df"]
assert hh_df["household_id"].nunique() == len(hh_df)
- def test_mapping_has_required_columns(self, stacked_result):
- """Mapping CSV should have expected columns."""
- mapping_df = stacked_result["mapping_df"]
- required_cols = [
- "new_household_id",
- "original_household_id",
- "congressional_district",
- "state_fips",
- ]
- for col in required_cols:
- assert col in mapping_df.columns
-
- def test_mapping_covers_all_output_households(self, stacked_result):
- """Every output household should be in the mapping."""
- hh_df = stacked_result["hh_df"]
- mapping_df = stacked_result["mapping_df"]
-
- output_hh_ids = set(hh_df["household_id"].values)
- mapped_hh_ids = set(mapping_df["new_household_id"].values)
- assert output_hh_ids == mapped_hh_ids
-
def test_weights_are_positive(self, stacked_result):
"""All household weights should be positive."""
hh_df = stacked_result["hh_df"]
@@ -164,12 +140,13 @@ def stacked_sim(test_weights):
with tempfile.TemporaryDirectory() as tmpdir:
output_path = os.path.join(tmpdir, "test_output.h5")
- create_sparse_cd_stacked_dataset(
- test_weights,
- TEST_CDS,
+ build_h5(
+ weights=np.array(test_weights),
+ blocks=None,
+ dataset_path=Path(FIXTURE_PATH),
+ output_path=Path(output_path),
+ cds_to_calibrate=TEST_CDS,
cd_subset=TEST_CDS,
- dataset_path=FIXTURE_PATH,
- output_path=output_path,
)
sim = Microsimulation(dataset=output_path)
@@ -179,21 +156,21 @@ def stacked_sim(test_weights):
@pytest.fixture(scope="module")
def stacked_sim_with_overlap(n_households):
"""Stacked dataset where SAME households appear in BOTH CDs."""
- # Force same households to appear in both CDs - tests reindexing
w = np.zeros(n_households * len(TEST_CDS), dtype=float)
- overlap_households = [0, 1, 2] # Same households in both CDs
+ overlap_households = [0, 1, 2]
for cd_idx in range(len(TEST_CDS)):
for hh_idx in overlap_households:
w[cd_idx * n_households + hh_idx] = 1.0
with tempfile.TemporaryDirectory() as tmpdir:
output_path = os.path.join(tmpdir, "test_overlap.h5")
- create_sparse_cd_stacked_dataset(
- w,
- TEST_CDS,
+ build_h5(
+ weights=np.array(w),
+ blocks=None,
+ dataset_path=Path(FIXTURE_PATH),
+ output_path=Path(output_path),
+ cds_to_calibrate=TEST_CDS,
cd_subset=TEST_CDS,
- dataset_path=FIXTURE_PATH,
- output_path=output_path,
)
sim = Microsimulation(dataset=output_path)
yield {"sim": sim, "n_overlap": len(overlap_households)}
@@ -241,21 +218,16 @@ def test_person_family_id_matches_family_id(self, stacked_sim):
), f"person_family_id {pf_id} not in family_ids"
def test_family_ids_unique_across_cds(self, stacked_sim_with_overlap):
- """Same household in different CDs should have different family_ids."""
+ """Same HH in different CDs should get different family_ids."""
sim = stacked_sim_with_overlap["sim"]
n_overlap = stacked_sim_with_overlap["n_overlap"]
n_cds = len(TEST_CDS)
family_ids = sim.calculate("family_id", map_to="family").values
- household_ids = sim.calculate(
- "household_id", map_to="household"
- ).values
- # Should have n_overlap * n_cds unique families (one per HH-CD pair)
expected_families = n_overlap * n_cds
assert len(family_ids) == expected_families, (
- f"Expected {expected_families} families (same HH in {n_cds} CDs), "
- f"got {len(family_ids)}"
+ f"Expected {expected_families} families, " f"got {len(family_ids)}"
)
assert len(set(family_ids)) == expected_families, (
f"Family IDs not unique: {len(set(family_ids))} unique "
diff --git a/policyengine_us_data/tests/test_calibration/test_target_config.py b/policyengine_us_data/tests/test_calibration/test_target_config.py
new file mode 100644
index 000000000..9241660ce
--- /dev/null
+++ b/policyengine_us_data/tests/test_calibration/test_target_config.py
@@ -0,0 +1,177 @@
+"""Tests for target config filtering in unified calibration."""
+
+import numpy as np
+import pandas as pd
+import pytest
+from scipy import sparse
+
+from policyengine_us_data.calibration.unified_calibration import (
+ apply_target_config,
+ load_target_config,
+ save_calibration_package,
+ load_calibration_package,
+)
+
+
+@pytest.fixture
+def sample_targets():
+ """Build a small synthetic calibration problem for the filter tests.
+
+ Returns a ``(targets_df, X, target_names)`` tuple: a six-row targets
+ frame, a random 6 x 10 CSR matrix with one row per target, and one
+ ``{variable}_{geo_level}_{geographic_id}`` name per target row.
+ """
+ targets_df = pd.DataFrame(
+ {
+ "variable": [
+ "snap",
+ "snap",
+ "eitc",
+ "eitc",
+ "rent",
+ "person_count",
+ ],
+ "geo_level": [
+ "national",
+ "state",
+ "district",
+ "state",
+ "national",
+ "national",
+ ],
+ "domain_variable": [
+ "snap",
+ "snap",
+ "eitc",
+ "eitc",
+ "rent",
+ "person_count",
+ ],
+ "geographic_id": ["US", "6", "0601", "6", "US", "US"],
+ "value": [1000, 500, 200, 300, 800, 5000],
+ }
+ )
+ n_rows = len(targets_df)
+ n_cols = 10
+ # Seeded generator so the random sparse matrix is the same on every run.
+ rng = np.random.default_rng(42)
+ X = sparse.random(n_rows, n_cols, density=0.5, random_state=rng)
+ X = X.tocsr()
+ # One name per target row, in the same order as targets_df.
+ target_names = [
+ f"{r.variable}_{r.geo_level}_{r.geographic_id}"
+ for _, r in targets_df.iterrows()
+ ]
+ return targets_df, X, target_names
+
+
+class TestApplyTargetConfig:
+ """Checks how ``apply_target_config`` filters targets via ``exclude`` rules.
+
+ Every test feeds the ``sample_targets`` fixture (6 targets) plus a config
+ dict and inspects the returned ``(targets_df, matrix, names)`` triple.
+ """
+
+ def test_empty_config_keeps_all(self, sample_targets):
+ """An empty ``exclude`` list leaves frame, matrix shape and names as-is."""
+ df, X, names = sample_targets
+ config = {"exclude": []}
+ out_df, out_X, out_names = apply_target_config(df, X, names, config)
+ assert len(out_df) == len(df)
+ assert out_X.shape == X.shape
+ assert out_names == names
+
+ def test_single_variable_geo_exclusion(self, sample_targets):
+ """Excluding (rent, national) drops the only ``rent`` target."""
+ df, X, names = sample_targets
+ config = {"exclude": [{"variable": "rent", "geo_level": "national"}]}
+ out_df, out_X, out_names = apply_target_config(df, X, names, config)
+ assert len(out_df) == len(df) - 1
+ assert "rent" not in out_df["variable"].values
+
+ def test_multiple_exclusions(self, sample_targets):
+ """Two rules drop two targets; (eitc, state) is matched by neither."""
+ df, X, names = sample_targets
+ config = {
+ "exclude": [
+ {"variable": "rent", "geo_level": "national"},
+ {"variable": "eitc", "geo_level": "district"},
+ ]
+ }
+ out_df, out_X, out_names = apply_target_config(df, X, names, config)
+ assert len(out_df) == len(df) - 2
+ kept = set(zip(out_df["variable"], out_df["geo_level"]))
+ assert ("rent", "national") not in kept
+ assert ("eitc", "district") not in kept
+ assert ("eitc", "state") in kept
+
+ def test_domain_variable_matching(self, sample_targets):
+ """A rule that also names ``domain_variable`` removes exactly one target.
+
+ Only the resulting row count is asserted here.
+ """
+ df, X, names = sample_targets
+ config = {
+ "exclude": [
+ {
+ "variable": "snap",
+ "geo_level": "national",
+ "domain_variable": "snap",
+ }
+ ]
+ }
+ out_df, out_X, out_names = apply_target_config(df, X, names, config)
+ assert len(out_df) == len(df) - 1
+
+ def test_matrix_and_names_stay_in_sync(self, sample_targets):
+ """After filtering, matrix rows and names match the frame length.
+
+ The column count of the matrix must be unchanged.
+ """
+ df, X, names = sample_targets
+ config = {
+ "exclude": [{"variable": "person_count", "geo_level": "national"}]
+ }
+ out_df, out_X, out_names = apply_target_config(df, X, names, config)
+ assert out_X.shape[0] == len(out_df)
+ assert len(out_names) == len(out_df)
+ assert out_X.shape[1] == X.shape[1]
+
+ def test_no_match_keeps_all(self, sample_targets):
+ """A rule for a variable absent from the fixture removes nothing."""
+ df, X, names = sample_targets
+ config = {
+ "exclude": [{"variable": "nonexistent", "geo_level": "national"}]
+ }
+ out_df, out_X, out_names = apply_target_config(df, X, names, config)
+ assert len(out_df) == len(df)
+ assert out_X.shape[0] == X.shape[0]
+
+
+class TestLoadTargetConfig:
+ """Checks ``load_target_config`` against YAML files written to ``tmp_path``."""
+
+ def test_load_valid_config(self, tmp_path):
+ """A file with one exclusion rule loads as a one-item ``exclude`` list."""
+ config_file = tmp_path / "config.yaml"
+ config_file.write_text(
+ "exclude:\n" " - variable: snap\n" " geo_level: national\n"
+ )
+ config = load_target_config(str(config_file))
+ assert len(config["exclude"]) == 1
+ assert config["exclude"][0]["variable"] == "snap"
+
+ def test_load_empty_config(self, tmp_path):
+ """An empty file must still yield a config whose ``exclude`` is ``[]``."""
+ config_file = tmp_path / "empty.yaml"
+ config_file.write_text("")
+ config = load_target_config(str(config_file))
+ assert config["exclude"] == []
+
+
+class TestCalibrationPackageRoundTrip:
+ """Checks ``save_calibration_package`` / ``load_calibration_package``.
+
+ Packages are written under ``tmp_path`` and read back in the same test.
+ """
+
+ def test_round_trip(self, sample_targets, tmp_path):
+ """Saved names, targets frame and metadata seed survive a reload.
+
+ NOTE(review): the matrix is compared by shape only, not by value.
+ """
+ df, X, names = sample_targets
+ pkg_path = str(tmp_path / "package.pkl")
+ metadata = {
+ "dataset_path": "/tmp/test.h5",
+ "db_path": "/tmp/test.db",
+ "n_clones": 5,
+ "n_records": X.shape[1],
+ "seed": 42,
+ "created_at": "2024-01-01T00:00:00",
+ "target_config": None,
+ }
+ save_calibration_package(pkg_path, X, df, names, metadata)
+ loaded = load_calibration_package(pkg_path)
+
+ assert loaded["target_names"] == names
+ pd.testing.assert_frame_equal(loaded["targets_df"], df)
+ assert loaded["X_sparse"].shape == X.shape
+ assert loaded["metadata"]["seed"] == 42
+
+ def test_package_then_filter(self, sample_targets, tmp_path):
+ """A reloaded package can be passed straight to ``apply_target_config``.
+
+ Excluding (rent, national) from the reloaded data drops one target.
+ """
+ df, X, names = sample_targets
+ pkg_path = str(tmp_path / "package.pkl")
+ metadata = {"n_records": X.shape[1]}
+ save_calibration_package(pkg_path, X, df, names, metadata)
+ loaded = load_calibration_package(pkg_path)
+
+ config = {"exclude": [{"variable": "rent", "geo_level": "national"}]}
+ out_df, out_X, out_names = apply_target_config(
+ loaded["targets_df"],
+ loaded["X_sparse"],
+ loaded["target_names"],
+ config,
+ )
+ assert len(out_df) == len(df) - 1
diff --git a/policyengine_us_data/tests/test_calibration/test_unified_calibration.py b/policyengine_us_data/tests/test_calibration/test_unified_calibration.py
index 2d3f80619..04e70ea69 100644
--- a/policyengine_us_data/tests/test_calibration/test_unified_calibration.py
+++ b/policyengine_us_data/tests/test_calibration/test_unified_calibration.py
@@ -1,13 +1,24 @@
-"""Tests for unified_calibration module.
+"""Tests for unified_calibration and shared takeup module.
-Focuses on rerandomize_takeup: verifies draws differ by
-block and are reproducible within the same block.
+Verifies that geo-salted draws are reproducible and vary by geo_id,
+that the SIMPLE_TAKEUP_VARS / TAKEUP_AFFECTED_TARGETS configs are valid,
+and covers block-level takeup seeding, county precomputation, and CLI flags.
"""
import numpy as np
import pytest
from policyengine_us_data.utils.randomness import seeded_rng
+from policyengine_us_data.utils.takeup import (
+ SIMPLE_TAKEUP_VARS,
+ TAKEUP_AFFECTED_TARGETS,
+ compute_block_takeup_for_entities,
+ apply_block_takeup_to_arrays,
+ _resolve_rate,
+)
+from policyengine_us_data.calibration.clone_and_assign import (
+ GeographyAssignment,
+)
class TestRerandomizeTakeupSeeding:
@@ -61,14 +72,140 @@ def test_rate_comparison_produces_booleans(self):
assert 0.70 < frac < 0.80
+class TestBlockSaltedDraws:
+    """Takeup draws must repeat for identical inputs and change when
+    the block, the variable, or the household salt changes."""
+
+    NC_BLOCK = "370010001001001"
+    TX_BLOCK = "480010002002002"
+    SNAP_VAR = "takes_up_snap_if_eligible"
+
+    def test_same_block_same_results(self):
+        """Identical arguments on two calls yield identical arrays."""
+        block_ids = np.array([self.NC_BLOCK] * 500)
+        first = compute_block_takeup_for_entities(
+            self.SNAP_VAR, 0.8, block_ids
+        )
+        second = compute_block_takeup_for_entities(
+            self.SNAP_VAR, 0.8, block_ids
+        )
+        np.testing.assert_array_equal(first, second)
+
+    def test_different_blocks_different_results(self):
+        """Changing only the block GEOID changes the draws."""
+        size = 500
+        nc_blocks = np.array([self.NC_BLOCK] * size)
+        tx_blocks = np.array([self.TX_BLOCK] * size)
+        nc_draws = compute_block_takeup_for_entities(
+            self.SNAP_VAR, 0.8, nc_blocks
+        )
+        tx_draws = compute_block_takeup_for_entities(
+            self.SNAP_VAR, 0.8, tx_blocks
+        )
+        assert not np.array_equal(nc_draws, tx_draws)
+
+    def test_different_vars_different_results(self):
+        """Changing only the takeup variable name changes the draws."""
+        block_ids = np.array([self.NC_BLOCK] * 500)
+        snap_draws = compute_block_takeup_for_entities(
+            self.SNAP_VAR, 0.8, block_ids
+        )
+        aca_draws = compute_block_takeup_for_entities(
+            "takes_up_aca_if_eligible", 0.8, block_ids
+        )
+        assert not np.array_equal(snap_draws, aca_draws)
+
+    def test_hh_salt_differs_from_block_only(self):
+        """Passing household ids as a fourth argument changes the draws
+        relative to the three-argument call."""
+        block_ids = np.array([self.NC_BLOCK] * 500)
+        household_ids = np.array([1] * 500)
+        block_only = compute_block_takeup_for_entities(
+            self.SNAP_VAR, 0.8, block_ids
+        )
+        with_household = compute_block_takeup_for_entities(
+            self.SNAP_VAR, 0.8, block_ids, household_ids
+        )
+        assert not np.array_equal(block_only, with_household)
+
+
+class TestApplyBlockTakeupToArrays:
+    """apply_block_takeup_to_arrays must return one boolean array per
+    takeup variable, sized to that variable's entity level."""
+
+    def _make_arrays(self, n_hh, persons_per_hh, tu_per_hh, spm_per_hh):
+        """Return the five leading positional inputs for n_hh households
+        that share one block and have identical entity counts."""
+        per_household = {
+            "person": persons_per_hh,
+            "tax_unit": tu_per_hh,
+            "spm_unit": spm_per_hh,
+        }
+        entity_hh_indices = {
+            entity: np.repeat(np.arange(n_hh), size)
+            for entity, size in per_household.items()
+        }
+        entity_counts = {
+            entity: n_hh * size for entity, size in per_household.items()
+        }
+        hh_blocks = np.array(["370010001001001"] * n_hh)
+        hh_state_fips = np.array([37] * n_hh, dtype=np.int32)
+        hh_ids = np.arange(n_hh, dtype=np.int64)
+        return (
+            hh_blocks,
+            hh_state_fips,
+            hh_ids,
+            entity_hh_indices,
+            entity_counts,
+        )
+
+    def test_returns_all_takeup_vars(self):
+        """Every SIMPLE_TAKEUP_VARS variable is present with bool dtype."""
+        inputs = self._make_arrays(10, 3, 2, 1)
+        takeup = apply_block_takeup_to_arrays(*inputs, time_period=2024)
+        for name in (spec["variable"] for spec in SIMPLE_TAKEUP_VARS):
+            assert name in takeup
+            assert takeup[name].dtype == bool
+
+    def test_correct_entity_counts(self):
+        """20 households with 10 persons, 4 tax units and 3 SPM units
+        each give arrays of length 200, 80 and 60 respectively."""
+        inputs = self._make_arrays(20, 10, 4, 3)
+        takeup = apply_block_takeup_to_arrays(*inputs, time_period=2024)
+        expected_lengths = {
+            "takes_up_snap_if_eligible": 60,
+            "takes_up_aca_if_eligible": 80,
+            "takes_up_ssi_if_eligible": 200,
+        }
+        for name, length in expected_lengths.items():
+            assert len(takeup[name]) == length
+
+    def test_reproducible(self):
+        """Two calls on the same inputs agree for every variable."""
+        inputs = self._make_arrays(10, 3, 2, 1)
+        first = apply_block_takeup_to_arrays(*inputs, time_period=2024)
+        second = apply_block_takeup_to_arrays(*inputs, time_period=2024)
+        for name in first:
+            np.testing.assert_array_equal(first[name], second[name])
+
+    def test_different_blocks_different_result(self):
+        """Moving all households to another block/state changes at
+        least one variable's draws."""
+        nc_inputs = self._make_arrays(10, 3, 2, 1)
+        nc_takeup = apply_block_takeup_to_arrays(
+            *nc_inputs, time_period=2024
+        )
+
+        # Reuse everything except the block and state arrays.
+        _, _, *shared = self._make_arrays(10, 3, 2, 1)
+        tx_inputs = [
+            np.array(["480010002002002"] * 10),
+            np.array([48] * 10, dtype=np.int32),
+            *shared,
+        ]
+        tx_takeup = apply_block_takeup_to_arrays(
+            *tx_inputs, time_period=2024
+        )
+
+        assert any(
+            not np.array_equal(nc_takeup[name], tx_takeup[name])
+            for name in nc_takeup
+        )
+
+
+class TestResolveRate:
+    """_resolve_rate accepts either a single scalar rate or a dict of
+    per-state rates, resolved from the state FIPS argument."""
+
+    def test_scalar_rate(self):
+        """A plain float is returned unchanged."""
+        resolved = _resolve_rate(0.82, 37)
+        assert resolved == 0.82
+
+    def test_state_dict_rate(self):
+        """FIPS 37 and 48 select the "NC" and "TX" entries."""
+        by_state = {"NC": 0.94, "TX": 0.76}
+        for fips, want in ((37, 0.94), (48, 0.76)):
+            assert _resolve_rate(by_state, fips) == want
+
+    def test_unknown_state_fallback(self):
+        """FIPS 99 matches no entry here; the result must be 0.8."""
+        assert _resolve_rate({"NC": 0.94}, 99) == 0.8
+
+
class TestSimpleTakeupConfig:
"""Verify the SIMPLE_TAKEUP_VARS config is well-formed."""
def test_all_entries_have_required_keys(self):
- from policyengine_us_data.calibration.unified_calibration import (
- SIMPLE_TAKEUP_VARS,
- )
-
for entry in SIMPLE_TAKEUP_VARS:
assert "variable" in entry
assert "entity" in entry
@@ -80,8 +217,504 @@ def test_all_entries_have_required_keys(self):
)
def test_expected_count(self):
+ assert len(SIMPLE_TAKEUP_VARS) == 9
+
+
+class TestTakeupAffectedTargets:
+    """Structural checks on the TAKEUP_AFFECTED_TARGETS mapping."""
+
+    REQUIRED_KEYS = ("takeup_var", "entity", "rate_key")
+    VALID_ENTITIES = ("person", "tax_unit", "spm_unit")
+
+    def test_all_entries_have_required_keys(self):
+        """Each entry carries the required keys and a known entity."""
+        for info in TAKEUP_AFFECTED_TARGETS.values():
+            for required in self.REQUIRED_KEYS:
+                assert required in info
+            assert info["entity"] in self.VALID_ENTITIES
+
+    def test_takeup_vars_exist_in_simple_vars(self):
+        """Each takeup_var is also a SIMPLE_TAKEUP_VARS variable."""
+        known = {spec["variable"] for spec in SIMPLE_TAKEUP_VARS}
+        for info in TAKEUP_AFFECTED_TARGETS.values():
+            assert info["takeup_var"] in known
+
+
+class TestParseArgsNewFlags:
+ """Verify new CLI flags are parsed correctly."""
+
+    def test_target_config_flag(self):
+        """--target-config stores its value on args.target_config."""
+        from policyengine_us_data.calibration.unified_calibration import (
+            parse_args,
+        )
+
+        argv = ["--target-config", "config.yaml"]
+        parsed = parse_args(argv)
+        assert parsed.target_config == "config.yaml"
+
+    def test_build_only_flag(self):
+        """--build-only sets args.build_only to True."""
+        from policyengine_us_data.calibration.unified_calibration import (
+            parse_args,
+        )
+
+        parsed = parse_args(["--build-only"])
+        assert parsed.build_only is True
+
+    def test_package_path_flag(self):
+        """--package-path stores its value on args.package_path."""
+        from policyengine_us_data.calibration.unified_calibration import (
+            parse_args,
+        )
+
+        argv = ["--package-path", "pkg.pkl"]
+        parsed = parse_args(argv)
+        assert parsed.package_path == "pkg.pkl"
+
+    def test_hyperparams_flags(self):
+        """--beta, --lambda-l2 and --learning-rate parse to values that
+        compare equal to the corresponding floats."""
+        from policyengine_us_data.calibration.unified_calibration import (
+            parse_args,
+        )
+
+        argv = ["--beta", "0.65"]
+        argv += ["--lambda-l2", "1e-8"]
+        argv += ["--learning-rate", "0.2"]
+        parsed = parse_args(argv)
+
+        assert parsed.beta == 0.65
+        assert parsed.lambda_l2 == 1e-8
+        assert parsed.learning_rate == 0.2
+
+    def test_hyperparams_defaults(self):
+        """With no flags, the three hyperparameters equal the module
+        constants BETA, LAMBDA_L2 and LEARNING_RATE."""
+        from policyengine_us_data.calibration.unified_calibration import (
+            BETA,
+            LAMBDA_L2,
+            LEARNING_RATE,
+            parse_args,
+        )
+
+        parsed = parse_args([])
+        assert parsed.beta == BETA
+        assert parsed.lambda_l2 == LAMBDA_L2
+        assert parsed.learning_rate == LEARNING_RATE
+
+ def test_skip_takeup_rerandomize_flag(self):
from policyengine_us_data.calibration.unified_calibration import (
- SIMPLE_TAKEUP_VARS,
+ parse_args,
+ )
+
+ args = parse_args(["--skip-takeup-rerandomize"])
+ assert args.skip_takeup_rerandomize is True
+
+ args_default = parse_args([])
+ assert args_default.skip_takeup_rerandomize is False
+
+
+class TestGeographyAssignmentCountyFips:
+    """GeographyAssignment must accept and expose a county_fips array."""
+
+    def test_county_fips_equals_block_prefix(self):
+        """county_fips built from the first five block characters is
+        stored and read back unchanged."""
+        block_ids = np.array(
+            ["370010001001001", "480010002002002", "060370003003003"]
+        )
+        county_prefixes = np.array([blk[:5] for blk in block_ids])
+        assignment = GeographyAssignment(
+            block_geoid=block_ids,
+            cd_geoid=np.array(["3701", "4801", "0613"]),
+            county_fips=county_prefixes,
+            state_fips=np.array([37, 48, 6]),
+            n_records=3,
+            n_clones=1,
+        )
+        np.testing.assert_array_equal(
+            assignment.county_fips,
+            np.array(["37001", "48001", "06037"]),
+        )
+
+    def test_county_fips_length(self):
+        """Five records give five county codes of five characters."""
+        block_ids = np.array(["370010001001001"] * 5)
+        assignment = GeographyAssignment(
+            block_geoid=block_ids,
+            cd_geoid=np.array(["3701"] * 5),
+            county_fips=np.array([blk[:5] for blk in block_ids]),
+            state_fips=np.array([37] * 5),
+            n_records=5,
+            n_clones=1,
+        )
+        assert len(assignment.county_fips) == 5
+        assert all(len(code) == 5 for code in assignment.county_fips)
+
+
+class TestBlockTakeupSeeding:
+    """compute_block_takeup_for_entities must be deterministic, vary
+    with the block, return booleans, and honour the requested rate."""
+
+    SNAP_VAR = "takes_up_snap_if_eligible"
+
+    def test_reproducible(self):
+        """A mixed two-block input gives the same draws on both calls."""
+        block_ids = np.array(
+            ["010010001001001"] * 50 + ["020010001001001"] * 50
+        )
+        first = compute_block_takeup_for_entities(
+            self.SNAP_VAR, 0.8, block_ids
+        )
+        second = compute_block_takeup_for_entities(
+            self.SNAP_VAR, 0.8, block_ids
+        )
+        np.testing.assert_array_equal(first, second)
+
+    def test_different_blocks_different_draws(self):
+        """Two single-block inputs with different GEOIDs disagree."""
+        size = 500
+        draws = [
+            compute_block_takeup_for_entities(
+                self.SNAP_VAR, 0.8, np.array([geoid] * size)
+            )
+            for geoid in ("010010001001001", "020010001001001")
+        ]
+        assert not np.array_equal(draws[0], draws[1])
+
+    def test_returns_booleans(self):
+        """The returned array has bool dtype."""
+        block_ids = np.array(["370010001001001"] * 100)
+        draws = compute_block_takeup_for_entities(
+            self.SNAP_VAR, 0.8, block_ids
+        )
+        assert draws.dtype == bool
+
+    def test_rate_respected(self):
+        """With rate 0.75 and 10000 draws, the share of True values
+        falls strictly between 0.70 and 0.80."""
+        size = 10000
+        block_ids = np.array(["370010001001001"] * size)
+        draws = compute_block_takeup_for_entities(
+            self.SNAP_VAR, 0.75, block_ids
+        )
+        share = draws.mean()
+        assert 0.70 < share < 0.80
+
+
+class TestAssembleCloneValuesCounty:
+    """_assemble_clone_values must take county-dependent variables
+    from the county table and all others from the state table."""
+
+    def test_county_var_uses_county_values(self):
+        """aca_ptc is listed as county-dependent, so each clone gets
+        its county's value (111 / 222), not its state's (100 / 200)."""
+        from policyengine_us_data.calibration.unified_matrix_builder import (
+            UnifiedMatrixBuilder,
+        )
+
+        size = 4
+        per_state = {
+            fips: {
+                "hh": {"aca_ptc": np.full(size, amount, dtype=np.float32)},
+                "person": {},
+                "entity": {},
+            }
+            for fips, amount in ((1, 100), (2, 200))
+        }
+        per_county = {
+            county: {
+                "hh": {"aca_ptc": np.full(size, amount, dtype=np.float32)},
+                "entity": {},
+            }
+            for county, amount in (("01001", 111), ("02001", 222))
+        }
+        states_by_clone = np.array([1, 1, 2, 2])
+        counties_by_clone = np.array(["01001", "01001", "02001", "02001"])
+        person_to_hh = np.array([0, 1, 2, 3])
+
+        # __new__ skips __init__; the method is called on a bare object.
+        bare_builder = UnifiedMatrixBuilder.__new__(UnifiedMatrixBuilder)
+        household_values, _ = bare_builder._assemble_clone_values(
+            per_state,
+            states_by_clone,
+            person_to_hh,
+            {"aca_ptc"},
+            set(),
+            county_values=per_county,
+            clone_counties=counties_by_clone,
+            county_dependent_vars={"aca_ptc"},
+        )
+        np.testing.assert_array_equal(
+            household_values["aca_ptc"],
+            np.array([111, 111, 222, 222], dtype=np.float32),
+        )
+
+    def test_non_county_var_uses_state_values(self):
+        """snap is not county-dependent, so each clone gets its
+        state's value even though the county table is empty."""
+        from policyengine_us_data.calibration.unified_matrix_builder import (
+            UnifiedMatrixBuilder,
+        )
+
+        size = 4
+        per_state = {
+            fips: {
+                "hh": {"snap": np.full(size, amount, dtype=np.float32)},
+                "person": {},
+                "entity": {},
+            }
+            for fips, amount in ((1, 50), (2, 60))
+        }
+        states_by_clone = np.array([1, 1, 2, 2])
+        counties_by_clone = np.array(["01001", "01001", "02001", "02001"])
+        person_to_hh = np.array([0, 1, 2, 3])
+
+        bare_builder = UnifiedMatrixBuilder.__new__(UnifiedMatrixBuilder)
+        household_values, _ = bare_builder._assemble_clone_values(
+            per_state,
+            states_by_clone,
+            person_to_hh,
+            {"snap"},
+            set(),
+            county_values={},
+            clone_counties=counties_by_clone,
+            county_dependent_vars={"aca_ptc"},
+        )
+        np.testing.assert_array_equal(
+            household_values["snap"],
+            np.array([50, 50, 60, 60], dtype=np.float32),
+        )
+
+
+class TestConvertBlocksToStackedFormat:
+    """convert_blocks_to_stacked_format must lay block GEOIDs out in
+    the stacked order given by cds_ordered."""
+
+    def test_basic_conversion(self):
+        """Two CDs with two records each give four filled string slots."""
+        from policyengine_us_data.calibration.unified_calibration import (
+            convert_blocks_to_stacked_format,
+        )
+
+        expected = [
+            "370010001001001",
+            "370010001001002",
+            "480010002002001",
+            "480010002002002",
+        ]
+        stacked = convert_blocks_to_stacked_format(
+            np.array(expected),
+            np.array(["3701", "3701", "4801", "4801"]),
+            2,
+            ["3701", "4801"],
+        )
+        assert stacked.dtype.kind == "U"
+        assert len(stacked) == 4
+        for slot, want in enumerate(expected):
+            assert stacked[slot] == want
+
+    def test_empty_slots(self):
+        """A CD in cds_ordered with no clones leaves "" in its slots."""
+        from policyengine_us_data.calibration.unified_calibration import (
+            convert_blocks_to_stacked_format,
+        )
+
+        stacked = convert_blocks_to_stacked_format(
+            np.array(["370010001001001", "370010001001002"]),
+            np.array(["3701", "3701"]),
+            2,
+            ["3701", "4801"],
+        )
+        assert len(stacked) == 4
+        expected = ["370010001001001", "370010001001002", "", ""]
+        for slot, want in enumerate(expected):
+            assert stacked[slot] == want
+
+    def test_first_clone_wins(self):
+        """When two clones land in the same CD, the first clone's
+        blocks occupy that CD's slots."""
+        from policyengine_us_data.calibration.unified_calibration import (
+            convert_blocks_to_stacked_format,
+        )
+
+        # Clone 0 is rows 0-1, clone 1 is rows 2-3; all share CD 3701.
+        all_blocks = np.array(
+            [
+                "370010001001001",
+                "370010001001002",
+                "370010001001099",
+                "370010001001099",
+            ]
+        )
+        stacked = convert_blocks_to_stacked_format(
+            all_blocks,
+            np.array(["3701", "3701", "3701", "3701"]),
+            2,
+            ["3701"],
+        )
+        assert stacked[0] == "370010001001001"
+        assert stacked[1] == "370010001001002"
+
+
+class TestTakeupDrawConsistency:
+    """Verify the matrix builder's inline takeup loop and
+    compute_block_takeup_for_entities produce identical draws
+    when given the same (block, household) inputs."""
+
+    def test_matrix_and_stacked_identical_draws(self):
+        """Both paths must produce identical boolean arrays."""
+        var = "takes_up_snap_if_eligible"
+        rate = 0.75
+
+        # 2 blocks, 3 households, variable entity counts per HH
+        # HH0 has 2 entities in block A
+        # HH1 has 3 entities in block A
+        # HH2 has 1 entity in block B
+        blocks = np.array(
+            [
+                "370010001001001",
+                "370010001001001",
+                "370010001001001",
+                "370010001001001",
+                "370010001001001",
+                "480010002002002",
+            ]
+        )
+        hh_ids = np.array([100, 100, 200, 200, 200, 300])
+
+        # Path 1: compute_block_takeup_for_entities (stacked)
+        stacked = compute_block_takeup_for_entities(var, rate, blocks, hh_ids)
+
+        # Path 2: reproduce matrix builder inline logic
+        # One RNG per (block, household) pair, salted "<block>:<hh_id>",
+        # drawing as many uniforms as that pair has entities.
+        n = len(blocks)
+        inline_takeup = np.zeros(n, dtype=bool)
+        for blk in np.unique(blocks):
+            bm = blocks == blk
+            for hh_id in np.unique(hh_ids[bm]):
+                hh_mask = bm & (hh_ids == hh_id)
+                rng = seeded_rng(var, salt=f"{blk}:{int(hh_id)}")
+                draws = rng.random(int(hh_mask.sum()))
+                inline_takeup[hh_mask] = draws < rate
+
+        np.testing.assert_array_equal(stacked, inline_takeup)
+
+    def test_aggregation_entity_to_household(self):
+        """np.add.at aggregation matches manual per-HH sum."""
+        n_hh = 3
+        # Entity -> household index; HH0 has 2, HH1 has 3, HH2 has 1.
+        ent_hh = np.array([0, 0, 1, 1, 1, 2])
+        eligible = np.array(
+            [100.0, 200.0, 50.0, 150.0, 100.0, 300.0],
+            dtype=np.float32,
+        )
+        takeup = np.array([True, False, True, True, False, True])
+
+        # Entities that do not take up contribute zero.
+        ent_values = (eligible * takeup).astype(np.float32)
+        hh_result = np.zeros(n_hh, dtype=np.float32)
+        np.add.at(hh_result, ent_hh, ent_values)
+
+        # Manual: HH0=100, HH1=50+150=200, HH2=300
+        expected = np.array([100.0, 200.0, 300.0], dtype=np.float32)
+        np.testing.assert_array_equal(hh_result, expected)
+
+    def test_state_specific_rate_resolved_from_block(self):
+        """Dict rates are resolved per block's state FIPS."""
+        var = "takes_up_snap_if_eligible"
+        rate_dict = {"NC": 0.9, "TX": 0.6}
+        n = 5000
+
+        blocks_nc = np.array(["370010001001001"] * n)
+        result_nc = compute_block_takeup_for_entities(
+            var, rate_dict, blocks_nc
+        )
+        # NC rate=0.9, expect ~90%
+        frac_nc = result_nc.mean()
+        assert 0.85 < frac_nc < 0.95, f"NC frac={frac_nc}"
+
+        blocks_tx = np.array(["480010002002002"] * n)
+        result_tx = compute_block_takeup_for_entities(
+            var, rate_dict, blocks_tx
+        )
+        # TX rate=0.6, expect ~60%
+        frac_tx = result_tx.mean()
+        assert 0.55 < frac_tx < 0.65, f"TX frac={frac_tx}"
+
+        # Verify _resolve_rate actually gives different rates
+        # (uses the _resolve_rate already imported at module level).
+        assert _resolve_rate(rate_dict, 37) == 0.9
+        assert _resolve_rate(rate_dict, 48) == 0.6
+
+
+class TestDeriveGeographyFromBlocks:
+ """Verify derive_geography_from_blocks returns correct
+ geography dict from pre-assigned blocks."""
+
+    def test_returns_expected_keys(self):
+        """The result's key set is exactly the twelve names below."""
+        from policyengine_us_data.calibration.block_assignment import (
+            derive_geography_from_blocks,
+        )
+
+        geography = derive_geography_from_blocks(
+            np.array(["370010001001001"])
+        )
+        wanted = set(
+            [
+                "block_geoid",
+                "county_fips",
+                "tract_geoid",
+                "state_fips",
+                "cbsa_code",
+                "sldu",
+                "sldl",
+                "place_fips",
+                "vtd",
+                "puma",
+                "zcta",
+                "county_index",
+            ]
+        )
+        assert set(geography.keys()) == wanted
+
+    def test_county_fips_derived(self):
+        """county_fips holds the first five characters of each block."""
+        from policyengine_us_data.calibration.block_assignment import (
+            derive_geography_from_blocks,
+        )
+
+        block_ids = np.array(["370010001001001", "480010002002002"])
+        geography = derive_geography_from_blocks(block_ids)
+        want = np.array(["37001", "48001"])
+        np.testing.assert_array_equal(geography["county_fips"], want)
+
+    def test_state_fips_derived(self):
+        """state_fips holds the two-character prefix as a string, so
+        the leading zero in "06" is kept."""
+        from policyengine_us_data.calibration.block_assignment import (
+            derive_geography_from_blocks,
+        )
+
+        block_ids = np.array(["370010001001001", "060370003003003"])
+        geography = derive_geography_from_blocks(block_ids)
+        want = np.array(["37", "06"])
+        np.testing.assert_array_equal(geography["state_fips"], want)
+
+    def test_tract_geoid_derived(self):
+        """tract_geoid holds the first eleven characters of the block."""
+        from policyengine_us_data.calibration.block_assignment import (
+            derive_geography_from_blocks,
+        )
+
+        geography = derive_geography_from_blocks(
+            np.array(["370010001001001"])
+        )
+        first_tract = geography["tract_geoid"][0]
+        assert first_tract == "37001000100"
+
+ def test_block_geoid_passthrough(self):
+ from policyengine_us_data.calibration.block_assignment import (
+ derive_geography_from_blocks,
)
- assert len(SIMPLE_TAKEUP_VARS) == 8
+ blocks = np.array(["370010001001001"])
+ result = derive_geography_from_blocks(blocks)
+ assert result["block_geoid"][0] == "370010001001001"
diff --git a/policyengine_us_data/tests/test_calibration/test_unified_matrix_builder.py b/policyengine_us_data/tests/test_calibration/test_unified_matrix_builder.py
index ea2d49c5c..1a312e99c 100644
--- a/policyengine_us_data/tests/test_calibration/test_unified_matrix_builder.py
+++ b/policyengine_us_data/tests/test_calibration/test_unified_matrix_builder.py
@@ -395,5 +395,692 @@ def test_endswith_count(self):
)
+class _FakeArrayResult:
+    """Stand-in for what sim.calculate() returns; exposes only the
+    ``values`` attribute."""
+
+    def __init__(self, values):
+        # Stored as given: no copy and no dtype coercion.
+        self.values = values
+
+
+class _FakeSimulation:
+    """Lightweight mock for policyengine_us.Microsimulation.
+
+    Records every set_input, delete_arrays and calculate call in the
+    matching ``*_calls`` list and returns configurable arrays from
+    calculate().
+    """
+
+    _ID_SUFFIX = "_id"
+
+    def __init__(self, n_hh=4, n_person=8, n_tax_unit=4, n_spm_unit=4):
+        """Store the per-entity row counts and start empty call logs."""
+        self.n_hh = n_hh
+        self.n_person = n_person
+        self.n_tax_unit = n_tax_unit
+        self.n_spm_unit = n_spm_unit
+
+        self.set_input_calls = []
+        self.delete_arrays_calls = []
+        self.calculate_calls = []
+
+        # Configurable return values for calculate(), keyed by variable
+        self._calc_returns = {}
+
+    def _entity_sizes(self):
+        """Map entity name to the number of rows this fake holds."""
+        return {
+            "household": self.n_hh,
+            "person": self.n_person,
+            "tax_unit": self.n_tax_unit,
+            "spm_unit": self.n_spm_unit,
+        }
+
+    def set_input(self, var, period, values):
+        """Log the call as a (var, period, values) tuple."""
+        self.set_input_calls.append((var, period, values))
+
+    def delete_arrays(self, var):
+        """Log the variable name whose arrays were dropped."""
+        self.delete_arrays_calls.append(var)
+
+    def calculate(self, var, period=None, map_to=None):
+        """Log the call and return a _FakeArrayResult.
+
+        Resolution order:
+          1. an explicit override in ``_calc_returns``;
+          2. for ``<entity>_id`` variables, ``np.arange`` over that
+             entity's size (household size for unknown entities);
+          3. otherwise float32 ones sized by ``map_to`` (household
+             size when ``map_to`` is None or unknown).
+        """
+        self.calculate_calls.append((var, period, map_to))
+        if var in self._calc_returns:
+            return _FakeArrayResult(self._calc_returns[var])
+
+        sizes = self._entity_sizes()
+        if var.endswith(self._ID_SUFFIX):
+            # Strip only the trailing "_id"; str.replace would also
+            # remove an "_id" that appears earlier in the name.
+            entity = var[: -len(self._ID_SUFFIX)]
+            return _FakeArrayResult(np.arange(sizes.get(entity, self.n_hh)))
+
+        n = sizes.get(map_to, self.n_hh)
+        return _FakeArrayResult(np.ones(n, dtype=np.float32))
+
+
+import numpy as np
+from unittest.mock import patch, MagicMock
+from collections import namedtuple
+
+# Geography stub carrying only the four fields these tests populate.
+_FakeGeo = namedtuple(
+    "FakeGeo",
+    "state_fips n_records county_fips block_geoid",
+)
+
+
+class TestBuildStateValues(unittest.TestCase):
+    """Test _build_state_values orchestration logic."""
+
+    def _make_builder(self):
+        """Return a builder created via __new__ (so __init__ is not
+        run) with only time_period and dataset_path assigned."""
+        builder = UnifiedMatrixBuilder.__new__(UnifiedMatrixBuilder)
+        builder.time_period = 2024
+        builder.dataset_path = "fake.h5"
+        return builder
+
+    def _make_geo(self, states, n_records=4):
+        """Return a _FakeGeo for the given state FIPS list; county and
+        block fields are all-zero placeholder strings, one per state."""
+        return _FakeGeo(
+            state_fips=np.array(states),
+            n_records=n_records,
+            county_fips=np.array(["00000"] * len(states)),
+            block_geoid=np.array(["000000000000000"] * len(states)),
+        )
+
+    # NOTE: stacked @patch decorators inject mocks bottom-up, so the
+    # Microsimulation mock is the first extra argument of each test.
+    @patch(
+        "policyengine_us_data.calibration"
+        ".unified_matrix_builder.get_calculated_variables",
+        return_value=["var_a"],
+    )
+    @patch("policyengine_us.Microsimulation")
+    def test_return_structure_no_takeup(self, mock_msim_cls, mock_gcv):
+        """Without takeup rerandomization, each state FIPS maps to a
+        dict with "hh", "person" and an empty "entity", and the "snap"
+        household array is float32."""
+        sim1 = _FakeSimulation()
+        sim2 = _FakeSimulation()
+        # One fake simulation per Microsimulation() construction.
+        mock_msim_cls.side_effect = [sim1, sim2]
+
+        builder = self._make_builder()
+        geo = self._make_geo([37, 48])
+
+        result = builder._build_state_values(
+            sim=None,
+            target_vars={"snap"},
+            constraint_vars={"income"},
+            geography=geo,
+            rerandomize_takeup=False,
+        )
+        # Both states present
+        assert 37 in result
+        assert 48 in result
+        # Each has hh/person/entity
+        for st in (37, 48):
+            assert "hh" in result[st]
+            assert "person" in result[st]
+            assert "entity" in result[st]
+            # entity is empty when not rerandomizing
+            assert result[st]["entity"] == {}
+            # hh values are float32
+            assert result[st]["hh"]["snap"].dtype == np.float32
+
+    @patch(
+        "policyengine_us_data.calibration"
+        ".unified_matrix_builder.get_calculated_variables",
+        return_value=[],
+    )
+    @patch("policyengine_us.Microsimulation")
+    def test_fresh_sim_per_state(self, mock_msim_cls, mock_gcv):
+        """Two states in the geography cause exactly two
+        Microsimulation constructions."""
+        mock_msim_cls.side_effect = [
+            _FakeSimulation(),
+            _FakeSimulation(),
+        ]
+        builder = self._make_builder()
+        geo = self._make_geo([37, 48])
+
+        builder._build_state_values(
+            sim=None,
+            target_vars={"snap"},
+            constraint_vars=set(),
+            geography=geo,
+            rerandomize_takeup=False,
+        )
+        assert mock_msim_cls.call_count == 2
+
+    @patch(
+        "policyengine_us_data.calibration"
+        ".unified_matrix_builder.get_calculated_variables",
+        return_value=[],
+    )
+    @patch("policyengine_us.Microsimulation")
+    def test_state_fips_set_correctly(self, mock_msim_cls, mock_gcv):
+        """Each simulation gets exactly one state_fips set_input whose
+        values equal four copies of that state's FIPS code."""
+        sims = [_FakeSimulation(), _FakeSimulation()]
+        mock_msim_cls.side_effect = sims
+
+        builder = self._make_builder()
+        geo = self._make_geo([37, 48])
+
+        builder._build_state_values(
+            sim=None,
+            target_vars={"snap"},
+            constraint_vars=set(),
+            geography=geo,
+            rerandomize_takeup=False,
+        )
+
+        # First sim should get state 37
+        # (each logged call is a (var, period, values) tuple)
+        fips_calls_0 = [
+            c for c in sims[0].set_input_calls if c[0] == "state_fips"
+        ]
+        assert len(fips_calls_0) == 1
+        np.testing.assert_array_equal(
+            fips_calls_0[0][2], np.full(4, 37, dtype=np.int32)
+        )
+
+        # Second sim should get state 48
+        fips_calls_1 = [
+            c for c in sims[1].set_input_calls if c[0] == "state_fips"
+        ]
+        assert len(fips_calls_1) == 1
+        np.testing.assert_array_equal(
+            fips_calls_1[0][2], np.full(4, 48, dtype=np.int32)
+        )
+
+    @patch(
+        "policyengine_us_data.calibration"
+        ".unified_matrix_builder.get_calculated_variables",
+        return_value=[],
+    )
+    @patch("policyengine_us.Microsimulation")
+    def test_takeup_vars_forced_true(self, mock_msim_cls, mock_gcv):
+        """With rerandomize_takeup=True, every SIMPLE_TAKEUP_VARS
+        variable is set to an all-True bool array and "snap" is
+        calculated at least once at a non-household, non-person level."""
+        sim = _FakeSimulation()
+        mock_msim_cls.return_value = sim
+
+        builder = self._make_builder()
+        geo = self._make_geo([37])
+
+        builder._build_state_values(
+            sim=None,
+            target_vars={"snap"},
+            constraint_vars=set(),
+            geography=geo,
+            rerandomize_takeup=True,
+        )
+
+        from policyengine_us_data.utils.takeup import (
+            SIMPLE_TAKEUP_VARS,
+        )
+
+        takeup_var_names = {s["variable"] for s in SIMPLE_TAKEUP_VARS}
+
+        # Check that every SIMPLE_TAKEUP_VAR was set to ones
+        set_true_vars = set()
+        for var, period, values in sim.set_input_calls:
+            if var in takeup_var_names:
+                assert values.dtype == bool
+                assert values.all(), f"{var} not forced True"
+                set_true_vars.add(var)
+
+        assert takeup_var_names == set_true_vars, (
+            f"Missing forced-true vars: " f"{takeup_var_names - set_true_vars}"
+        )
+
+        # Entity-level calculation happens for affected target
+        # (c[0] is the variable name, c[2] is the map_to argument)
+        entity_calcs = [
+            c
+            for c in sim.calculate_calls
+            if c[0] == "snap" and c[2] not in ("household", "person", None)
+        ]
+        assert len(entity_calcs) >= 1
+
+    @patch(
+        "policyengine_us_data.calibration"
+        ".unified_matrix_builder.get_calculated_variables",
+        return_value=[],
+    )
+    @patch("policyengine_us.Microsimulation")
+    def test_count_vars_skipped(self, mock_msim_cls, mock_gcv):
+        """A target named "snap_count" must never reach
+        sim.calculate, while "snap" must."""
+        sim = _FakeSimulation()
+        mock_msim_cls.return_value = sim
+
+        builder = self._make_builder()
+        geo = self._make_geo([37])
+
+        builder._build_state_values(
+            sim=None,
+            target_vars={"snap", "snap_count"},
+            constraint_vars=set(),
+            geography=geo,
+            rerandomize_takeup=False,
+        )
+
+        # snap calculated, snap_count NOT calculated
+        calc_vars = [c[0] for c in sim.calculate_calls]
+        assert "snap" in calc_vars
+        assert "snap_count" not in calc_vars
+
+
+class TestBuildCountyValues(unittest.TestCase):
+    """Test _build_county_values orchestration logic."""
+
+    def _make_builder(self):
+        """Return a builder created via __new__ (so __init__ is not
+        run) with only time_period and dataset_path assigned."""
+        builder = UnifiedMatrixBuilder.__new__(UnifiedMatrixBuilder)
+        builder.time_period = 2024
+        builder.dataset_path = "fake.h5"
+        return builder
+
+    def _make_geo(self, county_fips_list, n_records=4):
+        """Return a _FakeGeo for the given county FIPS strings; state
+        FIPS is the integer value of each string's first two
+        characters and block_geoid is an all-zero placeholder."""
+        states = [int(c[:2]) for c in county_fips_list]
+        return _FakeGeo(
+            state_fips=np.array(states),
+            n_records=n_records,
+            county_fips=np.array(county_fips_list),
+            block_geoid=np.array(["000000000000000"] * len(county_fips_list)),
+        )
+
+    def test_returns_empty_when_county_level_false(self):
+        """county_level=False yields {} even with targets present."""
+        builder = self._make_builder()
+        geo = self._make_geo(["37001"])
+        result = builder._build_county_values(
+            sim=None,
+            county_dep_targets={"aca_ptc"},
+            geography=geo,
+            rerandomize_takeup=False,
+            county_level=False,
+        )
+        assert result == {}
+
+    def test_returns_empty_when_no_targets(self):
+        """An empty county_dep_targets set yields {} even with
+        county_level=True."""
+        builder = self._make_builder()
+        geo = self._make_geo(["37001"])
+        result = builder._build_county_values(
+            sim=None,
+            county_dep_targets=set(),
+            geography=geo,
+            rerandomize_takeup=False,
+            county_level=True,
+        )
+        assert result == {}
+
+    # NOTE: stacked @patch decorators inject mocks bottom-up, so the
+    # argument order is Microsimulation, get_calculated_variables,
+    # get_county_enum_index_from_fips.
+    @patch(
+        "policyengine_us_data.calibration"
+        ".unified_matrix_builder.get_county_enum_index_from_fips",
+        return_value=1,
+    )
+    @patch(
+        "policyengine_us_data.calibration"
+        ".unified_matrix_builder.get_calculated_variables",
+        return_value=["var_a"],
+    )
+    @patch("policyengine_us.Microsimulation")
+    def test_return_structure(self, mock_msim_cls, mock_gcv, mock_county_idx):
+        """Each county FIPS maps to a dict with "hh" and "entity" but
+        no "person" key."""
+        sim = _FakeSimulation()
+        mock_msim_cls.return_value = sim
+
+        builder = self._make_builder()
+        geo = self._make_geo(["37001", "37002"])
+
+        result = builder._build_county_values(
+            sim=None,
+            county_dep_targets={"aca_ptc"},
+            geography=geo,
+            rerandomize_takeup=False,
+            county_level=True,
+        )
+        assert "37001" in result
+        assert "37002" in result
+        for cfips in ("37001", "37002"):
+            assert "hh" in result[cfips]
+            assert "entity" in result[cfips]
+            # No person-level in county values
+            assert "person" not in result[cfips]
+
+    @patch(
+        "policyengine_us_data.calibration"
+        ".unified_matrix_builder.get_county_enum_index_from_fips",
+        return_value=1,
+    )
+    @patch(
+        "policyengine_us_data.calibration"
+        ".unified_matrix_builder.get_calculated_variables",
+        return_value=["var_a"],
+    )
+    @patch("policyengine_us.Microsimulation")
+    def test_sim_reuse_within_state(
+        self, mock_msim_cls, mock_gcv, mock_county_idx
+    ):
+        """Two counties in one state share a single Microsimulation
+        and produce two "county" set_input calls."""
+        sim = _FakeSimulation()
+        mock_msim_cls.return_value = sim
+
+        builder = self._make_builder()
+        geo = self._make_geo(["37001", "37002"])
+
+        builder._build_county_values(
+            sim=None,
+            county_dep_targets={"aca_ptc"},
+            geography=geo,
+            rerandomize_takeup=False,
+            county_level=True,
+        )
+        # 1 state -> 1 Microsimulation
+        assert mock_msim_cls.call_count == 1
+        # 2 counties -> county set_input called twice
+        county_calls = [c for c in sim.set_input_calls if c[0] == "county"]
+        assert len(county_calls) == 2
+
+    @patch(
+        "policyengine_us_data.calibration"
+        ".unified_matrix_builder.get_county_enum_index_from_fips",
+        return_value=1,
+    )
+    @patch(
+        "policyengine_us_data.calibration"
+        ".unified_matrix_builder.get_calculated_variables",
+        return_value=[],
+    )
+    @patch("policyengine_us.Microsimulation")
+    def test_fresh_sim_across_states(
+        self, mock_msim_cls, mock_gcv, mock_county_idx
+    ):
+        """Counties in two different states cause two Microsimulation
+        constructions."""
+        mock_msim_cls.side_effect = [
+            _FakeSimulation(),
+            _FakeSimulation(),
+        ]
+        builder = self._make_builder()
+        # 2 states, 1 county each
+        geo = self._make_geo(["37001", "48001"])
+
+        builder._build_county_values(
+            sim=None,
+            county_dep_targets={"aca_ptc"},
+            geography=geo,
+            rerandomize_takeup=False,
+            county_level=True,
+        )
+        assert mock_msim_cls.call_count == 2
+
+    @patch(
+        "policyengine_us_data.calibration"
+        ".unified_matrix_builder.get_county_enum_index_from_fips",
+        return_value=1,
+    )
+    @patch(
+        "policyengine_us_data.calibration"
+        ".unified_matrix_builder.get_calculated_variables",
+        return_value=["var_a", "county"],
+    )
+    @patch("policyengine_us.Microsimulation")
+    def test_delete_arrays_per_county(
+        self, mock_msim_cls, mock_gcv, mock_county_idx
+    ):
+        """With calculated variables ["var_a", "county"], at least two
+        delete_arrays calls are logged and none of them is "county"."""
+        sim = _FakeSimulation()
+        mock_msim_cls.return_value = sim
+
+        builder = self._make_builder()
+        geo = self._make_geo(["37001", "37002"])
+
+        builder._build_county_values(
+            sim=None,
+            county_dep_targets={"aca_ptc"},
+            geography=geo,
+            rerandomize_takeup=False,
+            county_level=True,
+        )
+        # delete_arrays called for each county transition
+        # "county" is excluded from deletion, "var_a" is deleted
+        deleted_vars = sim.delete_arrays_calls
+        # Should have at least 1 delete per county
+        assert len(deleted_vars) >= 2
+        # "county" should NOT be deleted
+        assert "county" not in deleted_vars
+
+
+import pickle
+
+from policyengine_us_data.calibration.unified_matrix_builder import (
+ _compute_single_state,
+ _compute_single_state_group_counties,
+ _init_clone_worker,
+ _process_single_clone,
+)
+
+
+class TestParallelWorkerFunctions(unittest.TestCase):
+    """Module-level worker functions must survive a pickle round
+    trip and come back as the very same function object."""
+
+    def test_compute_single_state_is_picklable(self):
+        restored = pickle.loads(pickle.dumps(_compute_single_state))
+        self.assertIs(restored, _compute_single_state)
+
+    def test_compute_single_state_group_counties_is_picklable(
+        self,
+    ):
+        restored = pickle.loads(
+            pickle.dumps(_compute_single_state_group_counties)
+        )
+        self.assertIs(restored, _compute_single_state_group_counties)
+
+
+class TestBuildStateValuesParallel(unittest.TestCase):
+    """Test _build_state_values parallel/sequential branching."""
+
+    def _make_builder(self):
+        """Return a builder created via __new__ (so __init__ is not
+        run) with only time_period and dataset_path assigned."""
+        builder = UnifiedMatrixBuilder.__new__(UnifiedMatrixBuilder)
+        builder.time_period = 2024
+        builder.dataset_path = "fake.h5"
+        return builder
+
+    def _make_geo(self, states, n_records=4):
+        """Return a _FakeGeo for the given state FIPS list; county and
+        block fields are all-zero placeholder strings, one per state."""
+        return _FakeGeo(
+            state_fips=np.array(states),
+            n_records=n_records,
+            county_fips=np.array(["00000"] * len(states)),
+            block_geoid=np.array(["000000000000000"] * len(states)),
+        )
+
+    # NOTE: stacked @patch decorators inject mocks bottom-up, so the
+    # argument order is Microsimulation, get_calculated_variables,
+    # ProcessPoolExecutor.
+    @patch(
+        "concurrent.futures.ProcessPoolExecutor",
+    )
+    @patch(
+        "policyengine_us_data.calibration"
+        ".unified_matrix_builder.get_calculated_variables",
+        return_value=[],
+    )
+    @patch("policyengine_us.Microsimulation")
+    def test_workers_gt1_creates_pool(
+        self, mock_msim_cls, mock_gcv, mock_pool_cls
+    ):
+        """With workers=2, ProcessPoolExecutor is constructed exactly
+        once with max_workers=2. The pool and as_completed are mocks,
+        so no worker process is started."""
+        # The single future yields a (state_fips, values) pair.
+        mock_future = MagicMock()
+        mock_future.result.return_value = (
+            37,
+            {"hh": {}, "person": {}, "entity": {}},
+        )
+        # Make the mock pool usable as a context manager.
+        mock_pool = MagicMock()
+        mock_pool.__enter__ = MagicMock(return_value=mock_pool)
+        mock_pool.__exit__ = MagicMock(return_value=False)
+        mock_pool.submit.return_value = mock_future
+        mock_pool_cls.return_value = mock_pool
+
+        builder = self._make_builder()
+        geo = self._make_geo([37])
+
+        with patch(
+            "concurrent.futures.as_completed",
+            return_value=iter([mock_future]),
+        ):
+            builder._build_state_values(
+                sim=None,
+                target_vars={"snap"},
+                constraint_vars=set(),
+                geography=geo,
+                rerandomize_takeup=False,
+                workers=2,
+            )
+
+        mock_pool_cls.assert_called_once_with(max_workers=2)
+
+    @patch(
+        "policyengine_us_data.calibration"
+        ".unified_matrix_builder.get_calculated_variables",
+        return_value=[],
+    )
+    @patch("policyengine_us.Microsimulation")
+    def test_workers_1_skips_pool(self, mock_msim_cls, mock_gcv):
+        """With workers=1, ProcessPoolExecutor is never constructed."""
+        mock_msim_cls.return_value = _FakeSimulation()
+        builder = self._make_builder()
+        geo = self._make_geo([37])
+
+        with patch(
+            "concurrent.futures.ProcessPoolExecutor",
+        ) as mock_pool_cls:
+            builder._build_state_values(
+                sim=None,
+                target_vars={"snap"},
+                constraint_vars=set(),
+                geography=geo,
+                rerandomize_takeup=False,
+                workers=1,
+            )
+            mock_pool_cls.assert_not_called()
+
+
+class TestBuildCountyValuesParallel(unittest.TestCase):
+    """Test _build_county_values parallel/sequential branching."""
+
+    def _make_builder(self):
+        # __new__ creates the instance without running __init__, so
+        # only the two attributes assigned here exist on the builder.
+        builder = UnifiedMatrixBuilder.__new__(UnifiedMatrixBuilder)
+        builder.time_period = 2024
+        builder.dataset_path = "fake.h5"
+        return builder
+
+    def _make_geo(self, county_fips_list, n_records=4):
+        # State FIPS is taken from the first two characters of each
+        # county FIPS string; block GEOIDs are all-zero placeholders.
+        states = [int(c[:2]) for c in county_fips_list]
+        return _FakeGeo(
+            state_fips=np.array(states),
+            n_records=n_records,
+            county_fips=np.array(county_fips_list),
+            block_geoid=np.array(["000000000000000"] * len(county_fips_list)),
+        )
+
+    # Stacked @patch decorators inject their mocks bottom-up, so the
+    # test receives (Microsimulation, get_calculated_variables,
+    # get_county_enum_index_from_fips, ProcessPoolExecutor).
+    @patch(
+        "concurrent.futures.ProcessPoolExecutor",
+    )
+    @patch(
+        "policyengine_us_data.calibration"
+        ".unified_matrix_builder.get_county_enum_index_from_fips",
+        return_value=1,
+    )
+    @patch(
+        "policyengine_us_data.calibration"
+        ".unified_matrix_builder.get_calculated_variables",
+        return_value=[],
+    )
+    @patch("policyengine_us.Microsimulation")
+    def test_workers_gt1_creates_pool(
+        self,
+        mock_msim_cls,
+        mock_gcv,
+        mock_county_idx,
+        mock_pool_cls,
+    ):
+        # Fake future yielding a list of (county FIPS, values dict)
+        # pairs. NOTE(review): presumably mirrors what
+        # _compute_single_state_group_counties returns -- confirm.
+        mock_future = MagicMock()
+        mock_future.result.return_value = [("37001", {"hh": {}, "entity": {}})]
+        # The mocked pool works as a context manager that returns
+        # itself, and every submit() hands back the fake future.
+        mock_pool = MagicMock()
+        mock_pool.__enter__ = MagicMock(return_value=mock_pool)
+        mock_pool.__exit__ = MagicMock(return_value=False)
+        mock_pool.submit.return_value = mock_future
+        mock_pool_cls.return_value = mock_pool
+
+        builder = self._make_builder()
+        geo = self._make_geo(["37001"])
+
+        # as_completed is patched as well so that it simply yields the
+        # fake future.
+        with patch(
+            "concurrent.futures.as_completed",
+            return_value=iter([mock_future]),
+        ):
+            builder._build_county_values(
+                sim=None,
+                county_dep_targets={"aca_ptc"},
+                geography=geo,
+                rerandomize_takeup=False,
+                county_level=True,
+                workers=2,
+            )
+
+        # The pool class must have been instantiated exactly once,
+        # with the requested worker count.
+        mock_pool_cls.assert_called_once_with(max_workers=2)
+
+    @patch(
+        "policyengine_us_data.calibration"
+        ".unified_matrix_builder.get_county_enum_index_from_fips",
+        return_value=1,
+    )
+    @patch(
+        "policyengine_us_data.calibration"
+        ".unified_matrix_builder.get_calculated_variables",
+        return_value=[],
+    )
+    @patch("policyengine_us.Microsimulation")
+    def test_workers_1_skips_pool(
+        self, mock_msim_cls, mock_gcv, mock_county_idx
+    ):
+        # Any Microsimulation the code under test constructs is
+        # replaced by the in-file fake.
+        mock_msim_cls.return_value = _FakeSimulation()
+        builder = self._make_builder()
+        geo = self._make_geo(["37001"])
+
+        with patch(
+            "concurrent.futures.ProcessPoolExecutor",
+        ) as mock_pool_cls:
+            builder._build_county_values(
+                sim=None,
+                county_dep_targets={"aca_ptc"},
+                geography=geo,
+                rerandomize_takeup=False,
+                county_level=True,
+                workers=1,
+            )
+            # With workers=1 the pool class must never be instantiated.
+            mock_pool_cls.assert_not_called()
+
+
+class TestCloneLoopParallel(unittest.TestCase):
+    """Smoke checks for the clone-loop parallelisation helpers.
+
+    NOTE(review): none of these tests call ``build_matrix``, so the
+    pool-vs-sequential branching of the clone loop itself is not
+    exercised here. The two ``test_clone_workers_*`` tests only pin
+    the preconditions such a test would rely on.
+    """
+
+    def test_process_single_clone_is_picklable(self):
+        data = pickle.dumps(_process_single_clone)
+        func = pickle.loads(data)
+        self.assertIs(func, _process_single_clone)
+
+    def test_init_clone_worker_is_picklable(self):
+        data = pickle.dumps(_init_clone_worker)
+        func = pickle.loads(data)
+        self.assertIs(func, _init_clone_worker)
+
+    def test_clone_workers_gt1_creates_pool(self):
+        """Patching ProcessPoolExecutor on ``concurrent.futures`` is
+        visible to any code that looks the class up on that module.
+
+        TODO: call ``build_matrix`` with ``workers > 1`` inside the
+        patch and assert on ``mock_pool_cls``; until then this test
+        does not prove that a pool is actually created.
+        """
+        import concurrent.futures
+
+        with patch.object(
+            concurrent.futures,
+            "ProcessPoolExecutor",
+        ) as mock_pool_cls:
+            # Stronger than a bare hasattr(): the module attribute
+            # must be the mock itself while the patch is active.
+            self.assertIs(
+                concurrent.futures.ProcessPoolExecutor,
+                mock_pool_cls,
+            )
+
+    def test_clone_workers_1_skips_pool(self):
+        """The sequential-path helpers are importable callables.
+
+        TODO: call ``build_matrix`` with ``workers <= 1`` and assert
+        that no ProcessPoolExecutor is instantiated.
+        """
+        self.assertTrue(callable(_process_single_clone))
+        self.assertTrue(callable(_init_clone_worker))
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/policyengine_us_data/tests/test_calibration/test_xw_consistency.py b/policyengine_us_data/tests/test_calibration/test_xw_consistency.py
new file mode 100644
index 000000000..179fff743
--- /dev/null
+++ b/policyengine_us_data/tests/test_calibration/test_xw_consistency.py
@@ -0,0 +1,167 @@
+"""
+End-to-end test: X @ w from matrix builder must equal
+sim.calculate() from stacked builder.
+
+Uses uniform weights to isolate the consistency invariant
+from any optimizer behavior.
+
+Usage:
+ pytest policyengine_us_data/tests/test_calibration/test_xw_consistency.py -v
+"""
+
+import tempfile
+from pathlib import Path
+
+import numpy as np
+import pytest
+
+from policyengine_us_data.storage import STORAGE_FOLDER
+
+# Locations of the base dataset and the targets database under
+# STORAGE_FOLDER. The test in this module is skipped when either file
+# is missing (see _dataset_available).
+DATASET_PATH = str(
+    STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5"
+)
+DB_PATH = str(STORAGE_FOLDER / "calibration" / "policy_data.db")
+DB_URI = f"sqlite:///{DB_PATH}"
+
+# Seed passed to assign_random_geography.
+SEED = 42
+# Number of clones passed to assign_random_geography.
+N_CLONES = 3
+# How many of the highest-weight congressional districts get an H5
+# built and compared against X @ w.
+N_CDS_TO_CHECK = 3
+
+
+def _dataset_available():
+    """Return True only if both the dataset file and the DB file exist."""
+    return all(Path(p).exists() for p in (DATASET_PATH, DB_PATH))
+
+
+@pytest.mark.slow
+@pytest.mark.skipif(
+    not _dataset_available(),
+    reason="Base dataset or DB not available",
+)
+def test_xw_matches_stacked_sim():
+    """Check X @ w against weighted sums from per-CD stacked H5 files.
+
+    Builds the calibration matrix ``X`` for a randomly assigned
+    geography, multiplies it by uniform weights, and compares the
+    ``aca_ptc`` and ``snap`` rows of the ``N_CDS_TO_CHECK``
+    highest-weight congressional districts with the household-weighted
+    totals computed from an H5 written by ``build_h5`` for the same
+    district. The two values must agree to within 1%.
+
+    The scratch directory holding the H5 files is removed when the
+    test finishes. If no target row can be matched to any checked
+    district, the test is reported as skipped instead of silently
+    passing without a single comparison.
+    """
+    import shutil
+
+    from policyengine_us import Microsimulation
+    from policyengine_us_data.calibration.clone_and_assign import (
+        assign_random_geography,
+    )
+    from policyengine_us_data.calibration.unified_matrix_builder import (
+        UnifiedMatrixBuilder,
+    )
+    from policyengine_us_data.calibration.unified_calibration import (
+        convert_weights_to_stacked_format,
+        convert_blocks_to_stacked_format,
+    )
+    from policyengine_us_data.calibration.publish_local_area import (
+        build_h5,
+    )
+    from policyengine_us_data.utils.takeup import (
+        TAKEUP_AFFECTED_TARGETS,
+    )
+
+    sim = Microsimulation(dataset=DATASET_PATH)
+    n_records = len(sim.calculate("household_id", map_to="household").values)
+
+    geography = assign_random_geography(
+        n_records=n_records, n_clones=N_CLONES, seed=SEED
+    )
+    n_total = n_records * N_CLONES
+
+    builder = UnifiedMatrixBuilder(
+        db_uri=DB_URI,
+        time_period=2024,
+        dataset_path=DATASET_PATH,
+    )
+
+    target_filter = {
+        "variables": [
+            "aca_ptc",
+            "snap",
+            "household_count",
+            "tax_unit_count",
+        ]
+    }
+    targets_df, X, target_names = builder.build_matrix(
+        geography=geography,
+        sim=sim,
+        target_filter=target_filter,
+        hierarchical_domains=["aca_ptc", "snap"],
+        rerandomize_takeup=True,
+        county_level=True,
+        workers=2,
+    )
+
+    # Only the takeup variables behind the requested targets are
+    # re-randomised when the H5 files are built.
+    target_vars = set(target_filter["variables"])
+    takeup_filter = [
+        info["takeup_var"]
+        for key, info in TAKEUP_AFFECTED_TARGETS.items()
+        if key in target_vars
+    ]
+
+    # Uniform weights isolate the consistency check from any
+    # optimizer behaviour.
+    w = np.ones(n_total, dtype=np.float64)
+    xw = X @ w
+
+    geo_cd_strs = np.array([str(g) for g in geography.cd_geoid])
+    cds_ordered = sorted(set(geo_cd_strs))
+    w_stacked = convert_weights_to_stacked_format(
+        weights=w,
+        cd_geoid=geography.cd_geoid,
+        base_n_records=n_records,
+        cds_ordered=cds_ordered,
+    )
+    blocks_stacked = convert_blocks_to_stacked_format(
+        block_geoid=geography.block_geoid,
+        cd_geoid=geography.cd_geoid,
+        base_n_records=n_records,
+        cds_ordered=cds_ordered,
+    )
+
+    # Total stacked weight per CD: the stacked vector is sliced into
+    # consecutive chunks of n_records, one per entry of cds_ordered.
+    cd_weights = {}
+    for i, cd in enumerate(cds_ordered):
+        start = i * n_records
+        end = start + n_records
+        cd_weights[cd] = w_stacked[start:end].sum()
+    top_cds = sorted(cd_weights, key=cd_weights.get, reverse=True)[
+        :N_CDS_TO_CHECK
+    ]
+
+    check_vars = ["aca_ptc", "snap"]
+    tmpdir = tempfile.mkdtemp()
+    n_rows_matched = 0
+
+    try:
+        for cd in top_cds:
+            h5_path = f"{tmpdir}/{cd}.h5"
+            build_h5(
+                weights=np.array(w_stacked),
+                blocks=blocks_stacked,
+                dataset_path=Path(DATASET_PATH),
+                output_path=Path(h5_path),
+                cds_to_calibrate=cds_ordered,
+                cd_subset=[cd],
+                rerandomize_takeup=True,
+                takeup_filter=takeup_filter,
+            )
+
+            stacked_sim = Microsimulation(dataset=h5_path)
+            hh_weight = stacked_sim.calculate(
+                "household_weight", 2024, map_to="household"
+            ).values
+
+            for var in check_vars:
+                vals = stacked_sim.calculate(
+                    var, 2024, map_to="household"
+                ).values
+                stacked_sum = (vals * hh_weight).sum()
+
+                cd_row = targets_df[
+                    (targets_df["variable"] == var)
+                    & (targets_df["geographic_id"] == cd)
+                ]
+                if len(cd_row) == 0:
+                    continue
+                n_rows_matched += 1
+
+                row_num = targets_df.index.get_loc(cd_row.index[0])
+                xw_val = float(xw[row_num])
+
+                if stacked_sum == 0 and xw_val == 0:
+                    continue
+
+                # A zero stacked sum with a non-zero X @ w gives
+                # ratio 0 and therefore fails the assertion below.
+                ratio = xw_val / stacked_sum if stacked_sum != 0 else 0
+                assert abs(ratio - 1.0) < 0.01, (
+                    f"CD {cd}, {var}: X@w={xw_val:.0f} vs "
+                    f"stacked={stacked_sum:.0f}, ratio={ratio:.4f}"
+                )
+    finally:
+        # The scratch directory holds the H5 files built above; always
+        # remove it, even when an assertion fails.
+        shutil.rmtree(tmpdir, ignore_errors=True)
+
+    if n_rows_matched == 0:
+        pytest.skip(
+            "No target rows matched the checked CDs; nothing was compared"
+        )
diff --git a/policyengine_us_data/tests/test_local_area_calibration/__init__.py b/policyengine_us_data/tests/test_local_area_calibration/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/policyengine_us_data/tests/test_local_area_calibration/conftest.py b/policyengine_us_data/tests/test_local_area_calibration/conftest.py
deleted file mode 100644
index dfede8002..000000000
--- a/policyengine_us_data/tests/test_local_area_calibration/conftest.py
+++ /dev/null
@@ -1,16 +0,0 @@
-"""Shared fixtures for local area calibration tests."""
-
-import pytest
-
-from policyengine_us_data.storage import STORAGE_FOLDER
-
-
-@pytest.fixture(scope="module")
-def db_uri():
- db_path = STORAGE_FOLDER / "calibration" / "policy_data.db"
- return f"sqlite:///{db_path}"
-
-
-@pytest.fixture(scope="module")
-def dataset_path():
- return str(STORAGE_FOLDER / "stratified_extended_cps_2024.h5")
diff --git a/policyengine_us_data/tests/test_schema_views_and_lookups.py b/policyengine_us_data/tests/test_schema_views_and_lookups.py
index 14521a214..80064b115 100644
--- a/policyengine_us_data/tests/test_schema_views_and_lookups.py
+++ b/policyengine_us_data/tests/test_schema_views_and_lookups.py
@@ -20,7 +20,7 @@
create_database,
)
from policyengine_us_data.utils.db import get_geographic_strata
-from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import (
+from policyengine_us_data.calibration.calibration_utils import (
get_all_cds_from_database,
get_cd_index_mapping,
)
diff --git a/policyengine_us_data/utils/data_upload.py b/policyengine_us_data/utils/data_upload.py
index 42cd8feee..687f162b3 100644
--- a/policyengine_us_data/utils/data_upload.py
+++ b/policyengine_us_data/utils/data_upload.py
@@ -147,10 +147,10 @@ def upload_local_area_file(
skip_hf: bool = False,
):
"""
- Upload a single local area H5 file to a subdirectory (states/ or districts/).
+ Upload a single local area H5 file to a subdirectory.
- Uploads to both GCS and Hugging Face with the file placed in the specified
- subdirectory.
+ Supports states/, districts/, cities/, and national/.
+ Uploads to both GCS and Hugging Face.
Args:
skip_hf: If True, skip HuggingFace upload (for batched uploads later)
diff --git a/policyengine_us_data/utils/db.py b/policyengine_us_data/utils/db.py
index 2d8f134bf..ff7f588d3 100644
--- a/policyengine_us_data/utils/db.py
+++ b/policyengine_us_data/utils/db.py
@@ -11,7 +11,9 @@
)
from policyengine_us_data.storage import STORAGE_FOLDER
-DEFAULT_DATASET = str(STORAGE_FOLDER / "stratified_extended_cps_2024.h5")
+DEFAULT_DATASET = str(
+ STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5"
+)
def etl_argparser(
@@ -144,10 +146,6 @@ def parse_ucgid(ucgid_str: str) -> Dict:
state_and_district = ucgid_str[9:]
state_fips = int(state_and_district[:2])
district_number = int(state_and_district[2:])
- if district_number == 0 or (
- state_fips == 11 and district_number == 98
- ):
- district_number = 1
cd_geoid = state_fips * 100 + district_number
return {
"type": "district",
diff --git a/policyengine_us_data/utils/huggingface.py b/policyengine_us_data/utils/huggingface.py
index a312b5240..4f9d24924 100644
--- a/policyengine_us_data/utils/huggingface.py
+++ b/policyengine_us_data/utils/huggingface.py
@@ -1,4 +1,4 @@
-from huggingface_hub import hf_hub_download, login, HfApi
+from huggingface_hub import hf_hub_download, login, HfApi, CommitOperationAdd
import os
TOKEN = os.environ.get("HUGGING_FACE_TOKEN")
@@ -39,6 +39,7 @@ def download_calibration_inputs(
output_dir: str,
repo: str = "policyengine/policyengine-us-data",
version: str = None,
+ prefix: str = "",
) -> dict:
"""
Download calibration inputs from Hugging Face.
@@ -47,6 +48,8 @@ def download_calibration_inputs(
output_dir: Local directory to download files to
repo: Hugging Face repository ID
version: Optional revision (commit, tag, or branch)
+ prefix: Filename prefix for weights/blocks
+ (e.g. "national_")
Returns:
dict with keys 'weights', 'dataset', 'database' mapping to local paths
@@ -57,8 +60,9 @@ def download_calibration_inputs(
output_path.mkdir(parents=True, exist_ok=True)
files = {
- "weights": "calibration/w_district_calibration.npy",
- "dataset": "calibration/stratified_extended_cps.h5",
+ "dataset": (
+ "calibration/" "source_imputed_stratified_extended_cps.h5"
+ ),
"database": "calibration/policy_data.db",
}
@@ -72,9 +76,169 @@ def download_calibration_inputs(
revision=version,
token=TOKEN,
)
- # hf_hub_download preserves directory structure
local_path = output_path / hf_path
paths[key] = local_path
print(f"Downloaded {hf_path} to {local_path}")
+ optional_files = {
+ "weights": f"calibration/{prefix}calibration_weights.npy",
+ "blocks": f"calibration/{prefix}stacked_blocks.npy",
+ "geo_labels": f"calibration/{prefix}geo_labels.json",
+ }
+ for key, hf_path in optional_files.items():
+ try:
+ hf_hub_download(
+ repo_id=repo,
+ filename=hf_path,
+ local_dir=str(output_path),
+ repo_type="model",
+ revision=version,
+ token=TOKEN,
+ )
+ local_path = output_path / hf_path
+ paths[key] = local_path
+ print(f"Downloaded {hf_path} to {local_path}")
+ except Exception as e:
+ print(f"Skipping optional {hf_path}: {e}")
+
+ return paths
+
+
+def download_calibration_logs(
+    output_dir: str,
+    repo: str = "policyengine/policyengine-us-data",
+    version: str = None,
+    prefix: str = "",
+) -> dict:
+    """
+    Download calibration logs from Hugging Face.
+
+    Each file is optional: a failed download is reported on stdout and
+    the corresponding key is left out of the result.
+
+    Args:
+        output_dir: Local directory to download files to
+        repo: Hugging Face repository ID
+        version: Optional revision (commit, tag, or branch)
+        prefix: Filename prefix of the log files (e.g. "national_"),
+            matching the ``prefix`` used by
+            ``upload_calibration_artifacts``. Defaults to no prefix.
+
+    Returns:
+        dict mapping artifact names ("calibration_log", "diagnostics",
+        "config") to local paths (only includes files that were
+        downloaded successfully)
+    """
+    from pathlib import Path
+
+    output_path = Path(output_dir)
+    output_path.mkdir(parents=True, exist_ok=True)
+
+    files = {
+        "calibration_log": f"calibration/logs/{prefix}calibration_log.csv",
+        "diagnostics": f"calibration/logs/{prefix}unified_diagnostics.csv",
+        "config": f"calibration/logs/{prefix}unified_run_config.json",
+    }
+
+    paths = {}
+    for key, hf_path in files.items():
+        try:
+            hf_hub_download(
+                repo_id=repo,
+                filename=hf_path,
+                local_dir=str(output_path),
+                repo_type="model",
+                revision=version,
+                token=TOKEN,
+            )
+            # hf_hub_download keeps the repo-relative path below
+            # local_dir, so the file ends up at output_path / hf_path.
+            local_path = output_path / hf_path
+            paths[key] = local_path
+            print(f"Downloaded {hf_path} to {local_path}")
+        except Exception as e:
+            # Deliberate best-effort: a missing log is not an error.
+            print(f"Skipping {hf_path}: {e}")
+
return paths
+
+
+def upload_calibration_artifacts(
+    weights_path: str = None,
+    blocks_path: str = None,
+    geo_labels_path: str = None,
+    log_dir: str = None,
+    repo: str = "policyengine/policyengine-us-data",
+    prefix: str = "",
+) -> list:
+    """Push calibration artifacts to HuggingFace as one commit.
+
+    Only files that are both requested and present on disk are
+    uploaded; anything else is silently left out. When nothing is
+    left to upload, no commit is created.
+
+    Args:
+        weights_path: Path to calibration_weights.npy
+        blocks_path: Path to stacked_blocks.npy
+        geo_labels_path: Path to geo_labels.json
+        log_dir: Directory containing the (prefixed) log files
+            calibration_log.csv, unified_diagnostics.csv,
+            unified_run_config.json and validation_results.csv
+        repo: HuggingFace repository ID
+        prefix: Filename prefix for HF paths (e.g. "national_")
+
+    Returns:
+        List of uploaded HF paths, in upload order
+    """
+    # (local path, path in repo) candidates, in upload order.
+    candidates = [
+        (weights_path, f"calibration/{prefix}calibration_weights.npy"),
+        (blocks_path, f"calibration/{prefix}stacked_blocks.npy"),
+        (geo_labels_path, f"calibration/{prefix}geo_labels.json"),
+    ]
+
+    if log_dir:
+        log_names = (
+            "calibration_log.csv",
+            "unified_diagnostics.csv",
+            "unified_run_config.json",
+            "validation_results.csv",
+        )
+        for log_name in log_names:
+            filename = f"{prefix}{log_name}"
+            candidates.append(
+                (
+                    os.path.join(log_dir, filename),
+                    f"calibration/logs/{filename}",
+                )
+            )
+
+    # Keep only the candidates that were given and exist on disk.
+    operations = [
+        CommitOperationAdd(
+            path_in_repo=remote,
+            path_or_fileobj=local,
+        )
+        for local, remote in candidates
+        if local and os.path.exists(local)
+    ]
+
+    if not operations:
+        print("No calibration artifacts to upload.")
+        return []
+
+    HfApi().create_commit(
+        token=TOKEN,
+        repo_id=repo,
+        operations=operations,
+        repo_type="model",
+        commit_message=(f"Upload {len(operations)} calibration artifact(s)"),
+    )
+
+    uploaded = [op.path_in_repo for op in operations]
+    print(f"Uploaded to HuggingFace: {uploaded}")
+    return uploaded
diff --git a/policyengine_us_data/utils/takeup.py b/policyengine_us_data/utils/takeup.py
new file mode 100644
index 000000000..8654a52df
--- /dev/null
+++ b/policyengine_us_data/utils/takeup.py
@@ -0,0 +1,297 @@
+"""
+Shared takeup draw logic for calibration and local-area H5 building.
+
+Block-level seeded draws ensure that calibration targets match
+local-area H5 aggregations. The (block, household) salt ensures:
+ - Same (variable, block, household) → same draws
+ - Different blocks/households → different draws
+
+Entity-level draws respect the native entity of each takeup variable
+(spm_unit for SNAP/TANF, tax_unit for ACA/DC-PTC, person for SSI/
+Medicaid/Head Start).
+"""
+
+import numpy as np
+from typing import Any, Dict, List, Optional
+
+from policyengine_us_data.utils.randomness import seeded_rng
+from policyengine_us_data.parameters import load_take_up_rate
+
+# Takeup variables processed by apply_block_takeup_to_arrays, one
+# spec per variable:
+#   "variable": name of the boolean takeup variable that is drawn
+#   "entity":   entity key used to look up entity_hh_indices and
+#               entity_counts ("person", "tax_unit" or "spm_unit")
+#   "rate_key": key passed to load_take_up_rate (or looked up in the
+#               precomputed_rates cache)
+SIMPLE_TAKEUP_VARS = [
+    {
+        "variable": "takes_up_snap_if_eligible",
+        "entity": "spm_unit",
+        "rate_key": "snap",
+    },
+    {
+        "variable": "takes_up_aca_if_eligible",
+        "entity": "tax_unit",
+        "rate_key": "aca",
+    },
+    {
+        "variable": "takes_up_dc_ptc",
+        "entity": "tax_unit",
+        "rate_key": "dc_ptc",
+    },
+    {
+        "variable": "takes_up_head_start_if_eligible",
+        "entity": "person",
+        "rate_key": "head_start",
+    },
+    {
+        "variable": "takes_up_early_head_start_if_eligible",
+        "entity": "person",
+        "rate_key": "early_head_start",
+    },
+    {
+        "variable": "takes_up_ssi_if_eligible",
+        "entity": "person",
+        "rate_key": "ssi",
+    },
+    {
+        "variable": "would_file_taxes_voluntarily",
+        "entity": "tax_unit",
+        "rate_key": "voluntary_filing",
+    },
+    {
+        "variable": "takes_up_medicaid_if_eligible",
+        "entity": "person",
+        "rate_key": "medicaid",
+    },
+    {
+        "variable": "takes_up_tanf_if_eligible",
+        "entity": "spm_unit",
+        "rate_key": "tanf",
+    },
+]
+
+# Maps a target variable name to the takeup variable behind it
+# ("takeup_var"), together with that variable's entity and rate key.
+# The entity and rate_key values repeat those of the matching
+# SIMPLE_TAKEUP_VARS entry; "would_file_taxes_voluntarily" has no
+# entry in this mapping.
+TAKEUP_AFFECTED_TARGETS: Dict[str, dict] = {
+    "snap": {
+        "takeup_var": "takes_up_snap_if_eligible",
+        "entity": "spm_unit",
+        "rate_key": "snap",
+    },
+    "tanf": {
+        "takeup_var": "takes_up_tanf_if_eligible",
+        "entity": "spm_unit",
+        "rate_key": "tanf",
+    },
+    "aca_ptc": {
+        "takeup_var": "takes_up_aca_if_eligible",
+        "entity": "tax_unit",
+        "rate_key": "aca",
+    },
+    "ssi": {
+        "takeup_var": "takes_up_ssi_if_eligible",
+        "entity": "person",
+        "rate_key": "ssi",
+    },
+    "medicaid": {
+        "takeup_var": "takes_up_medicaid_if_eligible",
+        "entity": "person",
+        "rate_key": "medicaid",
+    },
+    "head_start": {
+        "takeup_var": "takes_up_head_start_if_eligible",
+        "entity": "person",
+        "rate_key": "head_start",
+    },
+    "early_head_start": {
+        "takeup_var": "takes_up_early_head_start_if_eligible",
+        "entity": "person",
+        "rate_key": "early_head_start",
+    },
+    "dc_property_tax_credit": {
+        "takeup_var": "takes_up_dc_ptc",
+        "entity": "tax_unit",
+        "rate_key": "dc_ptc",
+    },
+}
+
+# FIPS -> 2-letter state code for Medicaid rate lookup
+# Read by _resolve_rate to translate an integer state FIPS into the
+# key of a state-keyed rate dict. Covers the 50 states and DC; FIPS
+# codes not listed here have no entry.
+_FIPS_TO_STATE_CODE = {
+    1: "AL",
+    2: "AK",
+    4: "AZ",
+    5: "AR",
+    6: "CA",
+    8: "CO",
+    9: "CT",
+    10: "DE",
+    11: "DC",
+    12: "FL",
+    13: "GA",
+    15: "HI",
+    16: "ID",
+    17: "IL",
+    18: "IN",
+    19: "IA",
+    20: "KS",
+    21: "KY",
+    22: "LA",
+    23: "ME",
+    24: "MD",
+    25: "MA",
+    26: "MI",
+    27: "MN",
+    28: "MS",
+    29: "MO",
+    30: "MT",
+    31: "NE",
+    32: "NV",
+    33: "NH",
+    34: "NJ",
+    35: "NM",
+    36: "NY",
+    37: "NC",
+    38: "ND",
+    39: "OH",
+    40: "OK",
+    41: "OR",
+    42: "PA",
+    44: "RI",
+    45: "SC",
+    46: "SD",
+    47: "TN",
+    48: "TX",
+    49: "UT",
+    50: "VT",
+    51: "VA",
+    53: "WA",
+    54: "WV",
+    55: "WI",
+    56: "WY",
+}
+
+
+def _resolve_rate(
+    rate_or_dict,
+    state_fips: int,
+    default_rate: float = 0.8,
+) -> float:
+    """Resolve a scalar or state-keyed rate to a single float.
+
+    Args:
+        rate_or_dict: Either a scalar rate or a dict of rates keyed
+            by 2-letter state code or by the state FIPS as a string.
+        state_fips: Integer state FIPS code used for the dict lookup.
+        default_rate: Rate returned when the dict has neither the
+            state code nor the FIPS string as a key.
+
+    Returns:
+        The resolved rate, always as a Python float.
+    """
+    if not isinstance(rate_or_dict, dict):
+        return float(rate_or_dict)
+
+    # Prefer the 2-letter code; a FIPS without a known code maps to ""
+    # and falls through to the FIPS-string lookup.
+    code = _FIPS_TO_STATE_CODE.get(state_fips, "")
+    if code in rate_or_dict:
+        return float(rate_or_dict[code])
+    return float(rate_or_dict.get(str(state_fips), default_rate))
+
+
+def compute_block_takeup_for_entities(
+    var_name: str,
+    rate_or_dict,
+    entity_blocks: np.ndarray,
+    entity_hh_ids: Optional[np.ndarray] = None,
+) -> np.ndarray:
+    """Compute boolean takeup via block-level seeded draws.
+
+    Each unique (block, household) pair gets its own seeded RNG,
+    producing reproducible draws regardless of how many households
+    share the same block across clones.
+
+    State FIPS for rate resolution is derived from the first two
+    characters of each block GEOID.
+
+    Entities whose block is the empty string are skipped: they keep a
+    draw of 0.0 and a rate of 1.0 and therefore always come out True.
+
+    Args:
+        var_name: Takeup variable name.
+        rate_or_dict: Scalar rate or {state_code: rate} dict.
+        entity_blocks: Block GEOID per entity (1-D str array).
+        entity_hh_ids: Household ID per entity (int array).
+            When provided, seeds per (block, household) for
+            clone-independent draws. When None, one RNG seeded by
+            the block alone serves all entities of that block.
+
+    Returns:
+        Boolean array of shape (n_entities,).
+    """
+    n = len(entity_blocks)
+    draws = np.zeros(n, dtype=np.float64)
+    rates = np.ones(n, dtype=np.float64)
+
+    # Group entity positions by block in one pass instead of building
+    # a full-length boolean mask per block and per household. The
+    # stable sort keeps positions ascending inside each group, so
+    # draws land on the same entities as with mask assignment.
+    unique_blocks, inverse = np.unique(entity_blocks, return_inverse=True)
+    order = np.argsort(inverse, kind="stable")
+    bounds = np.searchsorted(
+        inverse[order], np.arange(len(unique_blocks) + 1)
+    )
+
+    for b, block in enumerate(unique_blocks):
+        if block == "":
+            continue
+        blk_idx = order[bounds[b] : bounds[b + 1]]
+        sf = int(str(block)[:2])
+        rates[blk_idx] = _resolve_rate(rate_or_dict, sf)
+
+        if entity_hh_ids is not None:
+            blk_hh_ids = entity_hh_ids[blk_idx]
+            for hh_id in np.unique(blk_hh_ids):
+                hh_idx = blk_idx[blk_hh_ids == hh_id]
+                rng = seeded_rng(var_name, salt=f"{block}:{int(hh_id)}")
+                draws[hh_idx] = rng.random(len(hh_idx))
+        else:
+            rng = seeded_rng(var_name, salt=str(block))
+            draws[blk_idx] = rng.random(len(blk_idx))
+
+    return draws < rates
+
+
+def apply_block_takeup_to_arrays(
+    hh_blocks: np.ndarray,
+    hh_state_fips: np.ndarray,
+    hh_ids: np.ndarray,
+    entity_hh_indices: Dict[str, np.ndarray],
+    entity_counts: Dict[str, int],
+    time_period: int,
+    takeup_filter: List[str] = None,
+    precomputed_rates: Optional[Dict[str, Any]] = None,
+) -> Dict[str, np.ndarray]:
+    """Draw block-seeded takeup booleans directly from raw arrays.
+
+    No Microsimulation instance is needed. For every entry of
+    SIMPLE_TAKEUP_VARS the household-level block and id arrays are
+    expanded to the variable's entity via the entity->household index
+    mapping and handed to compute_block_takeup_for_entities.
+
+    Args:
+        hh_blocks: Block GEOID per cloned household (str array).
+        hh_state_fips: State FIPS per cloned household (int array).
+            Not read by this function; the state is derived from the
+            block GEOID inside compute_block_takeup_for_entities.
+        hh_ids: Household ID per cloned household (int array).
+        entity_hh_indices: {entity_key: array} giving, for each
+            entity instance, the index of its household. Keys:
+            "person", "tax_unit", "spm_unit".
+        entity_counts: {entity_key: count} number of entities per
+            type.
+        time_period: Tax year.
+        takeup_filter: Optional list of takeup variable names to
+            re-randomize. If None, all SIMPLE_TAKEUP_VARS are
+            processed. Variables left out of the filter are set to
+            True for every entity.
+        precomputed_rates: Optional {rate_key: rate_or_dict} cache.
+            Rates found here are used as-is; any other rate is
+            loaded with ``load_take_up_rate``.
+
+    Returns:
+        {variable_name: bool_array} for each takeup variable.
+    """
+    selected = None if takeup_filter is None else set(takeup_filter)
+    cached_rates = precomputed_rates if precomputed_rates is not None else {}
+    takeup = {}
+
+    for spec in SIMPLE_TAKEUP_VARS:
+        name = spec["variable"]
+        entity_key = spec["entity"]
+        rate_key = spec["rate_key"]
+        n_entities = entity_counts[entity_key]
+
+        # Variables outside the filter are not drawn: everyone takes up.
+        if selected is not None and name not in selected:
+            takeup[name] = np.ones(n_entities, dtype=bool)
+            continue
+
+        # Expand household-level arrays to one value per entity.
+        hh_index = entity_hh_indices[entity_key]
+        blocks_per_entity = hh_blocks[hh_index].astype(str)
+        hh_ids_per_entity = hh_ids[hh_index]
+
+        if rate_key in cached_rates:
+            rate = cached_rates[rate_key]
+        else:
+            rate = load_take_up_rate(rate_key, time_period)
+
+        takeup[name] = compute_block_takeup_for_entities(
+            name,
+            rate,
+            blocks_per_entity,
+            hh_ids_per_entity,
+        )
+
+    return takeup
diff --git a/scripts/generate_test_data.py b/scripts/generate_test_data.py
deleted file mode 100644
index 75025bca6..000000000
--- a/scripts/generate_test_data.py
+++ /dev/null
@@ -1,205 +0,0 @@
-"""
-Generate synthetic test data for reproducibility testing.
-
-This script creates a small synthetic dataset that mimics the
-structure of the Enhanced CPS for testing and demonstration.
-"""
-
-import pandas as pd
-import numpy as np
-from pathlib import Path
-
-
-def generate_synthetic_cps(n_households=1000, seed=42):
- """Generate synthetic CPS-like data."""
-
- np.random.seed(seed)
-
- # Generate household structure
- households = []
- persons = []
-
- person_id = 0
- for hh_id in range(n_households):
- # Household size (1-6 people)
- hh_size = np.random.choice(
- [1, 2, 3, 4, 5, 6], p=[0.28, 0.34, 0.16, 0.13, 0.06, 0.03]
- )
-
- # Generate people in household
- for person_num in range(hh_size):
- # Determine role
- if person_num == 0:
- role = "head"
- age = np.random.randint(18, 85)
- elif person_num == 1 and hh_size >= 2:
- role = "spouse"
- age = np.random.randint(18, 85)
- else:
- role = "child"
- age = np.random.randint(0, 25)
-
- # Generate person data
- person = {
- "person_id": person_id,
- "household_id": hh_id,
- "age": age,
- "sex": np.random.choice([1, 2]), # 1=male, 2=female
- "person_weight": np.random.uniform(1000, 3000),
- "employment_income": (
- np.random.lognormal(10, 1.5) if age >= 18 else 0
- ),
- "is_disabled": np.random.random() < 0.15,
- "role": role,
- }
-
- persons.append(person)
- person_id += 1
-
- # Generate household data
- household = {
- "household_id": hh_id,
- "state_code": np.random.randint(1, 57),
- "household_weight": np.random.uniform(500, 2000),
- "household_size": hh_size,
- "housing_tenure": np.random.choice(["own", "rent", "other"]),
- "snap_reported": np.random.random() < 0.15,
- "medicaid_reported": np.random.random() < 0.20,
- }
-
- households.append(household)
-
- return pd.DataFrame(households), pd.DataFrame(persons)
-
-
-def generate_synthetic_puf(n_returns=10000, seed=43):
- """Generate synthetic PUF-like data."""
-
- np.random.seed(seed)
-
- returns = []
-
- for i in range(n_returns):
- # Income components (log-normal distributions)
- wages = np.random.lognormal(10.5, 1.2)
- interest = (
- np.random.exponential(500) if np.random.random() < 0.3 else 0
- )
- dividends = (
- np.random.exponential(1000) if np.random.random() < 0.2 else 0
- )
- business = np.random.lognormal(9, 2) if np.random.random() < 0.1 else 0
- cap_gains = (
- np.random.exponential(5000) if np.random.random() < 0.15 else 0
- )
-
- # Deductions
- mortgage_int = (
- np.random.exponential(8000) if np.random.random() < 0.25 else 0
- )
- charity = (
- np.random.exponential(3000) if np.random.random() < 0.3 else 0
- )
- salt = min(10000, wages * 0.05 + np.random.normal(0, 1000))
-
- # Demographics (limited in PUF)
- filing_status = np.random.choice(
- [1, 2, 3, 4], p=[0.45, 0.40, 0.10, 0.05]
- )
- num_deps = np.random.choice(
- [0, 1, 2, 3, 4], p=[0.6, 0.15, 0.15, 0.08, 0.02]
- )
-
- return_data = {
- "return_id": i,
- "filing_status": filing_status,
- "num_dependents": num_deps,
- "age_primary": np.random.randint(18, 85),
- "age_secondary": (
- np.random.randint(18, 85) if filing_status == 2 else 0
- ),
- "wages": wages,
- "interest": interest,
- "dividends": dividends,
- "business_income": business,
- "capital_gains": cap_gains,
- "total_income": wages
- + interest
- + dividends
- + business
- + cap_gains,
- "mortgage_interest": mortgage_int,
- "charitable_deduction": charity,
- "salt_deduction": salt,
- "weight": np.random.uniform(10, 1000),
- }
-
- returns.append(return_data)
-
- return pd.DataFrame(returns)
-
-
-def save_test_data():
- """Generate and save all test datasets."""
-
- print("Generating synthetic test data...")
-
- # Create directories
- data_dir = Path("data/test")
- data_dir.mkdir(parents=True, exist_ok=True)
-
- # Generate CPS data
- print("- Generating synthetic CPS data...")
- households, persons = generate_synthetic_cps(n_households=1000)
-
- # Save CPS
- households.to_csv(data_dir / "synthetic_households.csv", index=False)
- persons.to_csv(data_dir / "synthetic_persons.csv", index=False)
- print(f" Saved {len(households)} households, {len(persons)} persons")
-
- # Generate PUF data
- print("- Generating synthetic PUF data...")
- puf = generate_synthetic_puf(n_returns=5000)
- puf.to_csv(data_dir / "synthetic_puf.csv", index=False)
- print(f" Saved {len(puf)} tax returns")
-
- # Generate expected outputs
- print("- Generating expected outputs...")
-
- # Simple imputation example
- # Match on age brackets
- age_brackets = [18, 25, 35, 45, 55, 65, 100]
- persons["age_bracket"] = pd.cut(persons["age"], age_brackets)
-
- # Average wages by age bracket from PUF
- puf["age_bracket"] = pd.cut(puf["age_primary"], age_brackets)
- wage_by_age = puf.groupby("age_bracket")["wages"].mean()
-
- # Impute to persons
- persons["imputed_wages"] = persons["age_bracket"].map(wage_by_age)
- persons["imputed_wages"] = persons["imputed_wages"].fillna(0)
-
- # Save enhanced version
- persons.to_csv(data_dir / "synthetic_enhanced_persons.csv", index=False)
-
- # Generate checksums
- print("- Generating checksums...")
- import hashlib
-
- checksums = {}
- for file in data_dir.glob("*.csv"):
- with open(file, "rb") as f:
- checksums[file.name] = hashlib.sha256(f.read()).hexdigest()
-
- with open(data_dir / "checksums.txt", "w") as f:
- for filename, checksum in checksums.items():
- f.write(f"{filename}: {checksum}\n")
-
- print(f"\nTest data saved to {data_dir}")
- print("Files created:")
- for file in data_dir.glob("*"):
- print(f" - {file.name}")
-
-
-if __name__ == "__main__":
- save_test_data()
diff --git a/scripts/migrate_versioned_to_production.py b/scripts/migrate_versioned_to_production.py
deleted file mode 100644
index 5f99f74e3..000000000
--- a/scripts/migrate_versioned_to_production.py
+++ /dev/null
@@ -1,123 +0,0 @@
-"""
-One-time migration script to copy files from v1.56.0/ to production paths.
-
-Usage:
- python scripts/migrate_versioned_to_production.py --dry-run
- python scripts/migrate_versioned_to_production.py --execute
-"""
-
-import argparse
-from google.cloud import storage
-import google.auth
-from huggingface_hub import HfApi, CommitOperationCopy
-import os
-
-
-def migrate_gcs(dry_run: bool = True):
- """Copy files from v1.56.0/ to production paths in GCS."""
- credentials, project_id = google.auth.default()
- client = storage.Client(credentials=credentials, project=project_id)
- bucket = client.bucket("policyengine-us-data")
-
- blobs = list(bucket.list_blobs(prefix="v1.56.0/"))
- print(f"Found {len(blobs)} files in v1.56.0/")
-
- copied = 0
- for blob in blobs:
- # v1.56.0/states/AL.h5 -> states/AL.h5
- new_name = blob.name.replace("v1.56.0/", "")
- if not new_name:
- continue
-
- if dry_run:
- print(f" Would copy: {blob.name} -> {new_name}")
- else:
- bucket.copy_blob(blob, bucket, new_name)
- print(f" Copied: {blob.name} -> {new_name}")
- copied += 1
-
- print(f"{'Would copy' if dry_run else 'Copied'} {copied} files in GCS")
- return copied
-
-
-def migrate_hf(dry_run: bool = True):
- """Copy files from v1.56.0/ to production paths in HuggingFace."""
- token = os.environ.get("HUGGING_FACE_TOKEN")
- api = HfApi()
- repo_id = "policyengine/policyengine-us-data"
-
- files = api.list_repo_files(repo_id)
- versioned_files = [f for f in files if f.startswith("v1.56.0/")]
- print(f"Found {len(versioned_files)} files in v1.56.0/")
-
- if dry_run:
- for f in versioned_files[:10]:
- new_path = f.replace("v1.56.0/", "")
- print(f" Would copy: {f} -> {new_path}")
- if len(versioned_files) > 10:
- print(f" ... and {len(versioned_files) - 10} more")
- return len(versioned_files)
-
- operations = []
- for f in versioned_files:
- new_path = f.replace("v1.56.0/", "")
- if not new_path:
- continue
- operations.append(
- CommitOperationCopy(
- src_path_in_repo=f,
- path_in_repo=new_path,
- )
- )
-
- if operations:
- api.create_commit(
- token=token,
- repo_id=repo_id,
- operations=operations,
- repo_type="model",
- commit_message="Promote v1.56.0 files to production paths",
- )
- print(f"Copied {len(operations)} files in one HuggingFace commit")
-
- return len(operations)
-
-
-def main():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--dry-run",
- action="store_true",
- help="Show what would be done without doing it",
- )
- parser.add_argument(
- "--execute", action="store_true", help="Actually perform the migration"
- )
- parser.add_argument(
- "--gcs-only", action="store_true", help="Only migrate GCS"
- )
- parser.add_argument(
- "--hf-only", action="store_true", help="Only migrate HuggingFace"
- )
- args = parser.parse_args()
-
- if not args.dry_run and not args.execute:
- print("Must specify --dry-run or --execute")
- return
-
- dry_run = args.dry_run
-
- if not args.hf_only:
- print("\n=== GCS Migration ===")
- migrate_gcs(dry_run)
-
- if not args.gcs_only:
- print("\n=== HuggingFace Migration ===")
- migrate_hf(dry_run)
-
- if dry_run:
- print("\n(Dry run - no changes made. Use --execute to apply.)")
-
-
-if __name__ == "__main__":
- main()