From dd75514341d7e181d6c36c33c73ef0b77ef8318d Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 1 Jun 2026 22:09:07 +0200 Subject: [PATCH 1/7] Add live runtime stage manifests --- docs/api.md | 8 + docs/stage-contracts.md | 17 + src/microplex_us/pipelines/__init__.py | 12 + src/microplex_us/pipelines/artifacts.py | 227 ++++++-- src/microplex_us/pipelines/stage_contracts.py | 5 + src/microplex_us/pipelines/stage_manifest.py | 6 + .../pipelines/stage_manifest_builder.py | 121 ++++- .../pipelines/stage_manifest_types.py | 39 +- src/microplex_us/pipelines/stage_run.py | 51 +- src/microplex_us/pipelines/stage_runtime.py | 512 ++++++++++++++++++ src/microplex_us/pipelines/us.py | 80 ++- tests/pipelines/test_artifacts.py | 158 +++++- tests/pipelines/test_stage_manifest.py | 23 +- tests/pipelines/test_stage_run.py | 8 + tests/pipelines/test_stage_runtime.py | 105 ++++ 15 files changed, 1296 insertions(+), 76 deletions(-) create mode 100644 src/microplex_us/pipelines/stage_runtime.py create mode 100644 tests/pipelines/test_stage_runtime.py diff --git a/docs/api.md b/docs/api.md index 0fb6cfa..6f0b66f 100644 --- a/docs/api.md +++ b/docs/api.md @@ -76,6 +76,14 @@ :undoc-members: ``` +## Stage runtime writer + +```{eval-rst} +.. automodule:: microplex_us.pipelines.stage_runtime + :members: + :undoc-members: +``` + ## Artifact helpers ```{eval-rst} diff --git a/docs/stage-contracts.md b/docs/stage-contracts.md index 33a6e31..0259fb5 100644 --- a/docs/stage-contracts.md +++ b/docs/stage-contracts.md @@ -10,6 +10,13 @@ file is the machine-readable saved-run overlay for the stage taxonomy. It record canonical stages, status for the current run, artifact paths, diagnostics owned by each stage, and the current resume posture. +`status` is the saved-artifact readiness view: it reports whether the artifacts +for that stage are ready, incomplete, missing, metadata-only, or deferred. +`lifecycleStatus` is the runtime view: it reports whether the stage is pending, +running, complete, failed, or deferred in the current run. Keeping these fields +separate lets a failed run say both "Stage 5 failed" and "Stage 4's saved +artifact is ready for manual replay." + Each saved bundle also includes typed per-stage output manifests at `stage_artifacts/manifests/.json`. These manifests are written through `USStageRunWriter`, which validates each stage as a whole instead of updating @@ -17,6 +24,16 @@ individual manifest keys directly. The manifest files live outside each stage's payload directory so they do not change the content hash of reloadable stage artifacts. +Live runs can use `USStageRuntimeWriter` to write those same per-stage manifests +incrementally. The writer exposes `start_stage`, `update`, `record_output`, +`record_diagnostic`, `complete_stage`, `fail_stage`, `defer_stage`, and +`finalize_from_artifact_manifest`. A stage can start only after the immediately +previous stage is complete unless explicit stage-input overrides are enabled. +The canonical multi-source versioned build path reserves the versioned artifact +directory before loading sources, writes Stage 1 immediately, writes Stage 2 as +source frames load, then finalizes all stage manifests against the completed +artifact manifest during save. + The registry exposes two seam layers: - `inputs` and `outputs` are structured stage resources. They identify artifact, diff --git a/src/microplex_us/pipelines/__init__.py b/src/microplex_us/pipelines/__init__.py index 353faa9..5ddd3c0 100644 --- a/src/microplex_us/pipelines/__init__.py +++ b/src/microplex_us/pipelines/__init__.py @@ -294,9 +294,12 @@ def _exports(module: str, names: tuple[str, ...]) -> dict[str, str]: ( "USDataFlowStageSummary", "USStageArtifactRecord", + "USStageFailureRecord", + "USStageLifecycleStatus", "USStageManifest", "USStageMetric", "USStageRecord", + "USStageRuntimeEventRecord", "USStageStatus", "USValidationEvidenceManifest", "USValidationEvidenceRecord", @@ -337,6 +340,8 @@ def _exports(module: str, names: tuple[str, ...]) -> dict[str, str]: "USSourceLoadingOutputs", "USSourcePlanningOutputs", "USStageInputOverride", + "USStageInputValidationSettings", + "USStageInputValidator", "USStageOutputManifest", "USStageRunWriter", "USValidationBenchmarkingOutputs", @@ -346,6 +351,13 @@ def _exports(module: str, names: tuple[str, ...]) -> dict[str, str]: "write_us_stage_run_manifests_from_artifact_manifest", ), ), + **_exports( + "microplex_us.pipelines.stage_runtime", + ( + "RuntimeUpdateSection", + "USStageRuntimeWriter", + ), + ), **_exports( "microplex_us.pipelines.summarize_pe_native_family_drilldown", ( diff --git a/src/microplex_us/pipelines/artifacts.py b/src/microplex_us/pipelines/artifacts.py index 39c2871..94c9ea8 100644 --- a/src/microplex_us/pipelines/artifacts.py +++ b/src/microplex_us/pipelines/artifacts.py @@ -3,7 +3,8 @@ from __future__ import annotations import json -from dataclasses import dataclass, replace +from collections.abc import Mapping +from dataclasses import asdict, dataclass, replace from datetime import UTC, datetime from importlib.metadata import PackageNotFoundError, version from pathlib import Path @@ -44,9 +45,13 @@ write_us_policyengine_entity_stage_artifact, ) from microplex_us.pipelines.stage_run import ( + USArtifactRef, + USDiagnosticOutput, + USRunProfileOutputs, USStageInputOverride, write_us_stage_run_manifests_from_artifact_manifest, ) +from microplex_us.pipelines.stage_runtime import USStageRuntimeWriter from microplex_us.pipelines.summarize_child_tax_unit_agi_drift import ( DEFAULT_VARIABLES as DEFAULT_CHILD_TAX_UNIT_AGI_DRIFT_VARIABLES, ) @@ -816,6 +821,7 @@ def save_us_microplex_artifacts( child_tax_unit_agi_drift_variables: tuple[str, ...] | None = None, allow_stage_input_overrides: bool = False, stage_input_overrides: tuple[USStageInputOverride, ...] = (), + stage_runtime_writer: USStageRuntimeWriter | None = None, ) -> USMicroplexArtifactPaths: """Persist a build result as a reproducible artifact bundle.""" output_dir = Path(output_dir) @@ -1227,12 +1233,15 @@ def save_us_microplex_artifacts( "path": str(resolved_run_index_path), "artifact_id": recorded_entry.artifact_id, } - manifest = write_us_stage_run_manifests_from_artifact_manifest( - output_dir, - manifest, - allow_stage_input_overrides=allow_stage_input_overrides, - stage_input_overrides=stage_input_overrides, - ) + if stage_runtime_writer is not None: + manifest = stage_runtime_writer.finalize_from_artifact_manifest(manifest) + else: + manifest = write_us_stage_run_manifests_from_artifact_manifest( + output_dir, + manifest, + allow_stage_input_overrides=allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, + ) assert_valid_benchmark_artifact_manifest( manifest, artifact_dir=output_dir, @@ -1552,12 +1561,28 @@ def build_and_save_versioned_us_microplex_from_source_providers( stage_input_overrides: tuple[USStageInputOverride, ...] = (), ) -> USMicroplexVersionedBuildArtifacts: """Build from multiple source providers, save a versioned bundle, and report frontier gap.""" - pipeline = USMicroplexPipeline(config) + resolved_config = config or USMicroplexBuildConfig() + _resolved_version_id, preallocated_output_dir, stage_runtime_writer = ( + _initialize_versioned_stage_runtime_writer( + output_root, + version_id=version_id, + config=resolved_config, + providers=providers, + queries=queries, + allow_stage_input_overrides=allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, + ) + ) + pipeline = USMicroplexPipeline( + resolved_config, + stage_runtime_writer=stage_runtime_writer, + ) build_result = pipeline.build_from_source_providers(providers, queries=queries) return _finalize_versioned_build_artifacts( build_result, output_root=output_root, version_id=version_id, + preallocated_output_dir=preallocated_output_dir, frontier_metric=frontier_metric, policyengine_comparison_cache=policyengine_comparison_cache, policyengine_target_provider=policyengine_target_provider, @@ -1577,6 +1602,7 @@ def build_and_save_versioned_us_microplex_from_source_providers( child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables, allow_stage_input_overrides=allow_stage_input_overrides, stage_input_overrides=stage_input_overrides, + stage_runtime_writer=stage_runtime_writer, ) @@ -1642,6 +1668,7 @@ def _finalize_versioned_build_artifacts( *, output_root: str | Path, version_id: str | None, + preallocated_output_dir: str | Path | None = None, frontier_metric: FrontierMetric, policyengine_comparison_cache: PolicyEngineUSComparisonCache | None, policyengine_target_provider: TargetProvider | None, @@ -1663,30 +1690,61 @@ def _finalize_versioned_build_artifacts( child_tax_unit_agi_drift_variables: tuple[str, ...] | None = None, allow_stage_input_overrides: bool = False, stage_input_overrides: tuple[USStageInputOverride, ...] = (), + stage_runtime_writer: USStageRuntimeWriter | None = None, ) -> USMicroplexVersionedBuildArtifacts: - artifact_paths = save_versioned_us_microplex_artifacts( - build_result, - output_root, - version_id=version_id, - policyengine_comparison_cache=policyengine_comparison_cache, - policyengine_target_provider=policyengine_target_provider, - policyengine_baseline_dataset=policyengine_baseline_dataset, - policyengine_harness_slices=policyengine_harness_slices, - policyengine_harness_metadata=policyengine_harness_metadata, - policyengine_us_data_repo=policyengine_us_data_repo, - defer_policyengine_harness=defer_policyengine_harness, - require_policyengine_native_score=require_policyengine_native_score, - defer_policyengine_native_score=defer_policyengine_native_score, - precomputed_policyengine_harness_payload=precomputed_policyengine_harness_payload, - precomputed_policyengine_native_scores=precomputed_policyengine_native_scores, - run_registry_path=run_registry_path, - run_index_path=run_index_path, - run_registry_metadata=run_registry_metadata, - enable_child_tax_unit_agi_drift=enable_child_tax_unit_agi_drift, - child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables, - allow_stage_input_overrides=allow_stage_input_overrides, - stage_input_overrides=stage_input_overrides, - ) + if preallocated_output_dir is not None: + output_root_path = Path(output_root) + output_dir = Path(preallocated_output_dir) + artifact_paths = save_us_microplex_artifacts( + build_result, + output_dir, + policyengine_comparison_cache=policyengine_comparison_cache, + policyengine_target_provider=policyengine_target_provider, + policyengine_baseline_dataset=policyengine_baseline_dataset, + policyengine_harness_slices=policyengine_harness_slices, + policyengine_harness_metadata=policyengine_harness_metadata, + policyengine_us_data_repo=policyengine_us_data_repo, + defer_policyengine_harness=defer_policyengine_harness, + require_policyengine_native_score=require_policyengine_native_score, + defer_policyengine_native_score=defer_policyengine_native_score, + precomputed_policyengine_harness_payload=precomputed_policyengine_harness_payload, + precomputed_policyengine_native_scores=precomputed_policyengine_native_scores, + run_registry_path=run_registry_path + or output_root_path / "run_registry.jsonl", + run_index_path=run_index_path or output_root_path, + run_registry_metadata=run_registry_metadata, + enable_child_tax_unit_agi_drift=enable_child_tax_unit_agi_drift, + child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables, + allow_stage_input_overrides=allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, + stage_runtime_writer=stage_runtime_writer, + ) + artifact_paths = replace(artifact_paths, version_id=output_dir.name) + else: + artifact_paths = save_versioned_us_microplex_artifacts( + build_result, + output_root, + version_id=version_id, + policyengine_comparison_cache=policyengine_comparison_cache, + policyengine_target_provider=policyengine_target_provider, + policyengine_baseline_dataset=policyengine_baseline_dataset, + policyengine_harness_slices=policyengine_harness_slices, + policyengine_harness_metadata=policyengine_harness_metadata, + policyengine_us_data_repo=policyengine_us_data_repo, + defer_policyengine_harness=defer_policyengine_harness, + require_policyengine_native_score=require_policyengine_native_score, + defer_policyengine_native_score=defer_policyengine_native_score, + precomputed_policyengine_harness_payload=precomputed_policyengine_harness_payload, + precomputed_policyengine_native_scores=precomputed_policyengine_native_scores, + run_registry_path=run_registry_path, + run_index_path=run_index_path, + run_registry_metadata=run_registry_metadata, + enable_child_tax_unit_agi_drift=enable_child_tax_unit_agi_drift, + child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables, + allow_stage_input_overrides=allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, + stage_runtime_writer=stage_runtime_writer, + ) current_entry = None frontier_entry = None frontier_delta = None @@ -1827,6 +1885,19 @@ def _allocate_versioned_output_dir( *, version_id: str | None, result: USMicroplexBuildResult, +) -> tuple[str, Path]: + return _allocate_versioned_output_dir_for_config( + output_root, + version_id=version_id, + config=result.config.to_dict(), + ) + + +def _allocate_versioned_output_dir_for_config( + output_root: Path, + *, + version_id: str | None, + config: dict[str, Any], ) -> tuple[str, Path]: if version_id is not None: output_dir = output_root / version_id @@ -1836,7 +1907,7 @@ def _allocate_versioned_output_dir( ) return version_id, output_dir - config_hash = _short_config_hash(result.config.to_dict()) + config_hash = _short_config_hash(config) timestamp = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ") base_version_id = f"{timestamp}-{config_hash}" candidate_version_id = base_version_id @@ -1857,6 +1928,98 @@ def _short_config_hash(config: dict[str, Any]) -> str: return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:8] +def _initialize_versioned_stage_runtime_writer( + output_root: str | Path, + *, + version_id: str | None, + config: USMicroplexBuildConfig, + providers: list[SourceProvider], + queries: dict[str, SourceQuery] | None, + allow_stage_input_overrides: bool, + stage_input_overrides: tuple[USStageInputOverride, ...], +) -> tuple[str, Path, USStageRuntimeWriter]: + root = Path(output_root) + root.mkdir(parents=True, exist_ok=True) + resolved_version_id, output_dir = _allocate_versioned_output_dir_for_config( + root, + version_id=version_id, + config=config.to_dict(), + ) + provider_query_plan = _provider_query_plan(providers, queries) + writer = USStageRuntimeWriter( + output_dir, + manifest_payload={ + "created_at": datetime.now(UTC).isoformat(), + "config": config.to_dict(), + "artifacts": {"manifest": "manifest.json"}, + }, + allow_stage_input_overrides=allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, + ) + writer.start_stage( + "01_run_profile", + metadata={"version_id": resolved_version_id}, + ) + writer.complete_stage( + USRunProfileOutputs( + manifest=USArtifactRef( + key="manifest", + path="manifest.json", + format="json", + required=True, + assume_exists=True, + ), + resolved_config=config.to_dict(), + provider_query_plan=provider_query_plan, + diagnostics={ + "stage_summary": USDiagnosticOutput( + key="stage_summary", + description="Runtime run-profile summary.", + summary={ + "provider_names": provider_query_plan["provider_names"], + "version_id": resolved_version_id, + }, + ) + }, + ) + ) + return resolved_version_id, output_dir, writer + + +def _provider_query_plan( + providers: list[SourceProvider], + queries: dict[str, SourceQuery] | None, +) -> dict[str, Any]: + return { + "provider_names": [provider.descriptor.name for provider in providers], + "queries": { + key: _json_ready_query(query) for key, query in dict(queries or {}).items() + }, + } + + +def _json_ready_query(query: SourceQuery) -> dict[str, Any]: + if hasattr(query, "to_dict"): + payload = query.to_dict() + if isinstance(payload, dict): + return payload + if hasattr(query, "__dataclass_fields__"): + return _json_ready(asdict(query)) + return _json_ready(vars(query)) + + +def _json_ready(value: Any) -> Any: + if isinstance(value, Mapping): + return {str(key): _json_ready(item) for key, item in value.items()} + if isinstance(value, (tuple, list, set, frozenset)): + return [_json_ready(item) for item in value] + if isinstance(value, Path): + return str(value) + if hasattr(value, "value"): + return value.value + return value + + def _registry_metric_value(entry: Any | None, metric: FrontierMetric) -> float | None: if entry is None: return None diff --git a/src/microplex_us/pipelines/stage_contracts.py b/src/microplex_us/pipelines/stage_contracts.py index d9fdc5d..e3a337c 100644 --- a/src/microplex_us/pipelines/stage_contracts.py +++ b/src/microplex_us/pipelines/stage_contracts.py @@ -830,6 +830,11 @@ def default_us_pipeline_stage_contracts() -> tuple[USPipelineStageContract, ...] "Stage-local calibration summary.", stage_id="07_calibration", ), + _stage_output_resource( + "target_ledger", + "Structured target-resolution and calibration target ledger.", + stage_id="07_calibration", + ), ), artifacts=( USStageArtifactContract( diff --git a/src/microplex_us/pipelines/stage_manifest.py b/src/microplex_us/pipelines/stage_manifest.py index a7d4c4f..14aa7a0 100644 --- a/src/microplex_us/pipelines/stage_manifest.py +++ b/src/microplex_us/pipelines/stage_manifest.py @@ -19,12 +19,15 @@ US_VALIDATION_STAGE_ID, USDataFlowStageSummary, USStageArtifactRecord, + USStageFailureRecord, + USStageLifecycleStatus, USStageManifest, USStageMetric, USStageMetricValue, USStageRecord, USStageResourceRecord, USStageResumeRecord, + USStageRuntimeEventRecord, USStageStatus, USStageValidationRecord, USStageValidationStatus, @@ -47,12 +50,15 @@ "US_STAGE_ARTIFACT_ROOT", "US_STAGE_MANIFEST_SCHEMA_VERSION", "USStageArtifactRecord", + "USStageFailureRecord", + "USStageLifecycleStatus", "USStageManifest", "USStageMetric", "USStageMetricValue", "USStageRecord", "USStageResourceRecord", "USStageResumeRecord", + "USStageRuntimeEventRecord", "USStageStatus", "USStageValidationRecord", "USStageValidationStatus", diff --git a/src/microplex_us/pipelines/stage_manifest_builder.py b/src/microplex_us/pipelines/stage_manifest_builder.py index 5befd74..eed79a8 100644 --- a/src/microplex_us/pipelines/stage_manifest_builder.py +++ b/src/microplex_us/pipelines/stage_manifest_builder.py @@ -2,6 +2,7 @@ from __future__ import annotations +import json from collections.abc import Iterable from pathlib import Path from typing import Any, cast @@ -16,6 +17,7 @@ from microplex_us.pipelines.stage_manifest_types import ( US_STAGE_MANIFEST_SCHEMA_VERSION, USStageArtifactRecord, + USStageLifecycleStatus, USStageManifest, USStageRecord, USStageResourceRecord, @@ -37,11 +39,16 @@ def build_us_stage_manifest( manifest = dict(manifest_payload) artifact_map = dict(manifest.get("artifacts", {})) assumed_existing = set(assume_existing_artifact_keys) + stage_output_manifests = _load_stage_output_manifests( + artifact_root, + manifest, + ) stages = [ _stage_record( contract, artifact_root=artifact_root, manifest=manifest, + stage_output_manifest=stage_output_manifests.get(contract.id), assume_existing_artifact_keys=assumed_existing, ) for contract in default_us_pipeline_stage_contracts() @@ -86,6 +93,7 @@ def _stage_record( *, artifact_root: Path, manifest: dict[str, Any], + stage_output_manifest: dict[str, Any] | None, assume_existing_artifact_keys: set[str], ) -> USStageRecord: artifacts = [ @@ -97,18 +105,34 @@ def _stage_record( ) for artifact in contract.artifacts ] + status = stage_status( + contract.id, + artifact_root=artifact_root, + manifest=manifest, + artifacts=artifacts, + assume_existing_artifact_keys=assume_existing_artifact_keys, + ) return { "id": contract.id, "step": contract.step, "title": contract.title, "purpose": contract.purpose, - "status": stage_status( - contract.id, - artifact_root=artifact_root, - manifest=manifest, - artifacts=artifacts, - assume_existing_artifact_keys=assume_existing_artifact_keys, + "status": status, + "lifecycleStatus": _stage_lifecycle_status( + stage_output_manifest, + saved_status=status, + ), + "outputManifest": _stage_output_manifest_ref(manifest, contract.id), + "startedAt": _runtime_optional_str(stage_output_manifest, "startedAt"), + "updatedAt": _runtime_optional_str(stage_output_manifest, "updatedAt"), + "completedAt": _runtime_optional_str(stage_output_manifest, "completedAt"), + "failedAt": _runtime_optional_str(stage_output_manifest, "failedAt"), + "deferredReason": _runtime_optional_str( + stage_output_manifest, + "deferredReason", ), + "failure": _runtime_mapping_or_none(stage_output_manifest, "failure"), + "events": _runtime_events(stage_output_manifest), "consumes": list(contract.consumes), "produces": list(contract.produces), "inputs": _resource_records(contract.inputs), @@ -127,6 +151,91 @@ def _stage_record( } +def _load_stage_output_manifests( + artifact_root: Path, + manifest: dict[str, Any], +) -> dict[str, dict[str, Any]]: + stage_manifest_paths = manifest.get("stage_output_manifests") + if not isinstance(stage_manifest_paths, dict): + return {} + payloads: dict[str, dict[str, Any]] = {} + for stage_id, value in stage_manifest_paths.items(): + if not isinstance(stage_id, str) or value is None: + continue + path = Path(str(value)) + if not path.is_absolute(): + path = artifact_root / path + try: + payload = json.loads(path.read_text()) + except (OSError, json.JSONDecodeError): + continue + if isinstance(payload, dict): + payloads[stage_id] = payload + return payloads + + +def _stage_output_manifest_ref( + manifest: dict[str, Any], + stage_id: str, +) -> str | None: + stage_manifest_paths = manifest.get("stage_output_manifests") + if not isinstance(stage_manifest_paths, dict): + return None + value = stage_manifest_paths.get(stage_id) + return str(value) if value is not None else None + + +def _stage_lifecycle_status( + stage_output_manifest: dict[str, Any] | None, + *, + saved_status: str, +) -> USStageLifecycleStatus: + if stage_output_manifest is not None: + value = stage_output_manifest.get("lifecycleStatus") + if value in {"pending", "running", "complete", "failed", "deferred"}: + return cast(USStageLifecycleStatus, value) + if stage_output_manifest.get("complete") is True: + return "complete" + if stage_output_manifest.get("complete") is False: + return "pending" + if saved_status == "ready": + return "complete" + if saved_status == "deferred": + return "deferred" + return "pending" + + +def _runtime_optional_str( + stage_output_manifest: dict[str, Any] | None, + key: str, +) -> str | None: + if stage_output_manifest is None: + return None + value = stage_output_manifest.get(key) + return str(value) if value is not None else None + + +def _runtime_mapping_or_none( + stage_output_manifest: dict[str, Any] | None, + key: str, +) -> dict[str, Any] | None: + if stage_output_manifest is None: + return None + value = stage_output_manifest.get(key) + return dict(value) if isinstance(value, dict) else None + + +def _runtime_events( + stage_output_manifest: dict[str, Any] | None, +) -> list[dict[str, Any]]: + if stage_output_manifest is None: + return [] + events = stage_output_manifest.get("events") + if not isinstance(events, list): + return [] + return [dict(event) for event in events if isinstance(event, dict)] + + def _artifact_record( artifact: USStageArtifactContract, *, diff --git a/src/microplex_us/pipelines/stage_manifest_types.py b/src/microplex_us/pipelines/stage_manifest_types.py index 7b5a417..12897a9 100644 --- a/src/microplex_us/pipelines/stage_manifest_types.py +++ b/src/microplex_us/pipelines/stage_manifest_types.py @@ -10,8 +10,8 @@ StageResumeMode, ) -US_STAGE_MANIFEST_SCHEMA_VERSION = 2 -SUPPORTED_US_STAGE_MANIFEST_SCHEMA_VERSIONS = frozenset({1, 2}) +US_STAGE_MANIFEST_SCHEMA_VERSION = 3 +SUPPORTED_US_STAGE_MANIFEST_SCHEMA_VERSIONS = frozenset({1, 2, 3}) US_STAGE_ARTIFACT_ROOT = "stage_artifacts" US_POLICYENGINE_ENTITY_STAGE_ID = "06_policyengine_entities" US_VALIDATION_STAGE_ID = "09_validation_benchmarking" @@ -28,6 +28,13 @@ ] USStageValidationStatus = Literal["planned", "manual", "implemented"] +USStageLifecycleStatus = Literal[ + "pending", + "running", + "complete", + "failed", + "deferred", +] class USStageMetric(TypedDict): @@ -67,6 +74,22 @@ class USStageValidationRecord(TypedDict): status: USStageValidationStatus +class USStageFailureRecord(TypedDict, total=False): + """Runtime failure details for one stage.""" + + errorType: str + message: str + traceback: str | None + + +class USStageRuntimeEventRecord(TypedDict, total=False): + """Compact runtime event included in a stage output manifest.""" + + event: str + timestamp: str + details: dict[str, Any] + + class USStageResourceRecord(TypedDict): """Saved-run view of one structured stage input or output.""" @@ -88,6 +111,15 @@ class USStageRecord(TypedDict): title: str purpose: str status: USStageStatus + lifecycleStatus: USStageLifecycleStatus + outputManifest: str | None + startedAt: str | None + updatedAt: str | None + completedAt: str | None + failedAt: str | None + deferredReason: str | None + failure: USStageFailureRecord | None + events: list[USStageRuntimeEventRecord] consumes: list[str] produces: list[str] inputs: list[USStageResourceRecord] @@ -149,12 +181,15 @@ class USValidationEvidenceManifest(TypedDict): "US_STAGE_MANIFEST_SCHEMA_VERSION", "US_VALIDATION_STAGE_ID", "USStageArtifactRecord", + "USStageFailureRecord", + "USStageLifecycleStatus", "USStageManifest", "USStageMetric", "USStageMetricValue", "USStageRecord", "USStageResourceRecord", "USStageResumeRecord", + "USStageRuntimeEventRecord", "USStageStatus", "USStageValidationRecord", "USStageValidationStatus", diff --git a/src/microplex_us/pipelines/stage_run.py b/src/microplex_us/pipelines/stage_run.py index 5f2ee44..240b2e8 100644 --- a/src/microplex_us/pipelines/stage_run.py +++ b/src/microplex_us/pipelines/stage_run.py @@ -31,11 +31,16 @@ write_us_stage_manifest, write_us_validation_evidence_manifest, ) +from microplex_us.pipelines.stage_manifest_types import ( + USStageFailureRecord, + USStageLifecycleStatus, + USStageRuntimeEventRecord, +) from microplex_us.pipelines.stage_readiness import ( write_us_conditional_readiness_report, ) -US_STAGE_OUTPUT_MANIFEST_SCHEMA_VERSION = 1 +US_STAGE_OUTPUT_MANIFEST_SCHEMA_VERSION = 2 USArtifactCategory = Literal[ "required_output", @@ -188,6 +193,14 @@ class USStageOutputManifest: auxiliary_artifacts: Mapping[str, USAuxiliaryArtifact] = field(default_factory=dict) metadata: Mapping[str, Any] = field(default_factory=dict) complete: bool = True + lifecycle_status: USStageLifecycleStatus | None = None + started_at: str | None = None + updated_at: str | None = None + completed_at: str | None = None + failed_at: str | None = None + deferred_reason: str | None = None + failure: USStageFailureRecord | None = None + events: tuple[USStageRuntimeEventRecord, ...] = () stage_id: str = field(default="", init=False) def required_output_keys(self) -> tuple[str, ...]: @@ -247,6 +260,14 @@ def to_dict( "auxiliary_artifacts", "metadata", "complete", + "lifecycle_status", + "started_at", + "updated_at", + "completed_at", + "failed_at", + "deferred_reason", + "failure", + "events", "stage_id", } } @@ -255,6 +276,14 @@ def to_dict( "contractVersion": self.contract_version, "stageId": self.stage_id, "complete": self.complete, + "lifecycleStatus": self.resolved_lifecycle_status(), + "startedAt": self.started_at, + "updatedAt": self.updated_at, + "completedAt": self.completed_at, + "failedAt": self.failed_at, + "deferredReason": self.deferred_reason, + "failure": self.failure, + "events": [_serialize_value(event, artifact_root) for event in self.events], "inputStageManifest": input_stage_manifest or _optional_str(self.input_stage_manifest), "inputOverrides": [ @@ -272,6 +301,13 @@ def to_dict( "metadata": dict(self.metadata), } + def resolved_lifecycle_status(self) -> USStageLifecycleStatus: + """Return explicit lifecycle state or the legacy completion default.""" + + if self.lifecycle_status is not None: + return self.lifecycle_status + return "complete" if self.complete else "pending" + @dataclass(frozen=True) class USRunProfileOutputs(USStageOutputManifest): @@ -1019,6 +1055,16 @@ def build_us_stage_output_manifests_from_artifact_manifest( stage_summary=benchmark_summary, ), complete=bool(has_benchmark), + lifecycle_status=( + "complete" if has_benchmark else "deferred" if has_dataset else None + ), + deferred_reason=( + None + if has_benchmark + else "Stage 8 dataset exists, but validation or benchmark evidence is not attached." + if has_dataset + else None + ), ), ) @@ -1455,7 +1501,10 @@ def _write_json_atomically(path: Path, payload: Mapping[str, Any]) -> None: "USStageInputOverride", "USStageInputValidationSettings", "USStageInputValidator", + "USStageFailureRecord", + "USStageLifecycleStatus", "USStageOutputManifest", + "USStageRuntimeEventRecord", "USStageRunWriter", "USValidationBenchmarkingOutputs", "build_us_stage_output_manifests_from_artifact_manifest", diff --git a/src/microplex_us/pipelines/stage_runtime.py b/src/microplex_us/pipelines/stage_runtime.py new file mode 100644 index 0000000..1620abd --- /dev/null +++ b/src/microplex_us/pipelines/stage_runtime.py @@ -0,0 +1,512 @@ +"""Live runtime writer for canonical US pipeline stage manifests.""" + +from __future__ import annotations + +import json +import traceback +from collections.abc import Mapping +from dataclasses import replace +from datetime import UTC, datetime +from pathlib import Path +from typing import Any, Literal + +from microplex_us.pipelines.stage_contracts import ( + US_CANONICAL_STAGE_IDS, + US_STAGE_CONTRACT_VERSION, + get_us_pipeline_stage_contract, + get_us_stage_artifact_contract, + resolve_us_stage_artifact_contract_path, +) +from microplex_us.pipelines.stage_manifest import write_us_stage_manifest +from microplex_us.pipelines.stage_manifest_types import ( + USStageFailureRecord, + USStageLifecycleStatus, + USStageRuntimeEventRecord, +) +from microplex_us.pipelines.stage_run import ( + USArtifactRef, + USDiagnosticOutput, + USStageInputOverride, + USStageOutputManifest, + USStageRunWriter, + _serialize_value, + build_us_stage_output_manifests_from_artifact_manifest, +) + +RuntimeUpdateSection = Literal["outputs", "diagnostics", "metadata"] + + +class USStageRuntimeWriter: + """Write stage manifests incrementally during a canonical US build.""" + + def __init__( + self, + artifact_root: str | Path, + *, + manifest_payload: Mapping[str, Any] | None = None, + allow_stage_input_overrides: bool = False, + stage_input_overrides: tuple[USStageInputOverride, ...] = (), + ) -> None: + self.artifact_root = Path(artifact_root) + self.manifest_payload: dict[str, Any] = dict(manifest_payload or {}) + self.allow_stage_input_overrides = allow_stage_input_overrides + self.stage_input_overrides = tuple(stage_input_overrides) + self._run_writer = USStageRunWriter( + self.artifact_root, + manifest_payload=self.manifest_payload, + allow_stage_input_overrides=allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, + ) + + @property + def recorded_stages(self) -> tuple[USStageOutputManifest, ...]: + """Return completed typed stage manifests recorded by this writer.""" + + return self._run_writer.recorded_stages + + def start_stage( + self, + stage_id: str, + *, + metadata: Mapping[str, Any] | None = None, + ) -> dict[str, Any]: + """Mark one stage as running after validating its previous stage seam.""" + + self._validate_stage_id(stage_id) + self._validate_start_transition(stage_id) + now = _now() + payload = self._stage_payload(stage_id) + payload["complete"] = False + payload["lifecycleStatus"] = "running" + payload["startedAt"] = payload.get("startedAt") or now + payload["updatedAt"] = now + payload["completedAt"] = None + payload["failedAt"] = None + payload["deferredReason"] = None + payload["failure"] = None + payload["metadata"] = { + **dict(payload.get("metadata", {})), + **dict(metadata or {}), + } + payload["events"] = [ + *list(payload.get("events", [])), + _event("stage_started", now, dict(metadata or {})), + ] + self._write_stage_payload(stage_id, payload) + self._refresh_aggregate() + return payload + + def update( + self, + stage_id: str, + key: str, + value: Any, + *, + section: RuntimeUpdateSection = "outputs", + path: str | Path | None = None, + ) -> dict[str, Any]: + """Update one manifest entry, optionally writing a JSON artifact first.""" + + self._validate_stage_id(stage_id) + if section == "outputs": + self._validate_output_key(stage_id, key) + payload = self._stage_payload(stage_id) + written_value = value + if path is not None: + written_value = self._write_update_artifact(stage_id, key, value, path) + bucket = dict(payload.get(section, {})) + bucket[key] = _runtime_serialize(written_value, self.artifact_root) + payload[section] = bucket + now = _now() + payload["updatedAt"] = now + payload["events"] = [ + *list(payload.get("events", [])), + _event("stage_updated", now, {"section": section, "key": key}), + ] + self._write_stage_payload(stage_id, payload) + self._refresh_aggregate() + return payload + + def record_output( + self, + stage_id: str, + key: str, + value: Any, + *, + path: str | Path | None = None, + ) -> dict[str, Any]: + """Record one stage output entry.""" + + return self.update(stage_id, key, value, section="outputs", path=path) + + def record_diagnostic( + self, + stage_id: str, + diagnostic: USDiagnosticOutput, + ) -> dict[str, Any]: + """Record one diagnostic output for a running stage.""" + + return self.update( + stage_id, + diagnostic.key, + diagnostic, + section="diagnostics", + ) + + def complete_stage(self, outputs: USStageOutputManifest) -> dict[str, Any]: + """Validate, record, and write a complete typed stage output manifest.""" + + self._validate_stage_id(outputs.stage_id) + now = _now() + existing = self._stage_payload(outputs.stage_id) + stage_started_at = _optional_str(existing.get("startedAt")) or now + existing_events = tuple( + dict(event) + for event in existing.get("events", ()) + if isinstance(event, dict) + ) + input_stage_manifest = outputs.input_stage_manifest + if input_stage_manifest is None: + input_stage_manifest = self._previous_stage_manifest_ref(outputs.stage_id) + lifecycle_outputs = replace( + outputs, + input_stage_manifest=input_stage_manifest, + lifecycle_status="complete", + started_at=stage_started_at, + updated_at=now, + completed_at=now, + failed_at=None, + deferred_reason=None, + failure=None, + events=( + *existing_events, + *tuple(outputs.events), + _event("stage_completed", now), + ), + ) + self._run_writer.manifest_payload = self.manifest_payload + self._run_writer.record_stage(lifecycle_outputs) + self.manifest_payload = self._run_writer.write_manifest_files() + return self._stage_payload(outputs.stage_id) + + def fail_stage( + self, + stage_id: str, + error: BaseException, + *, + metadata: Mapping[str, Any] | None = None, + ) -> dict[str, Any]: + """Mark one stage as failed and persist the failure details.""" + + self._validate_stage_id(stage_id) + now = _now() + payload = self._stage_payload(stage_id) + failure: USStageFailureRecord = { + "errorType": type(error).__name__, + "message": str(error), + "traceback": "".join( + traceback.format_exception(type(error), error, error.__traceback__) + ), + } + payload["complete"] = False + payload["lifecycleStatus"] = "failed" + payload["updatedAt"] = now + payload["failedAt"] = now + payload["failure"] = failure + payload["metadata"] = { + **dict(payload.get("metadata", {})), + **dict(metadata or {}), + } + payload["events"] = [ + *list(payload.get("events", [])), + _event("stage_failed", now, {"errorType": type(error).__name__}), + ] + self._write_stage_payload(stage_id, payload) + self._refresh_aggregate() + return payload + + def defer_stage( + self, + stage_id: str, + reason: str, + *, + metadata: Mapping[str, Any] | None = None, + ) -> dict[str, Any]: + """Mark one stage as intentionally deferred.""" + + self._validate_stage_id(stage_id) + now = _now() + payload = self._stage_payload(stage_id) + payload["complete"] = False + payload["lifecycleStatus"] = "deferred" + payload["updatedAt"] = now + payload["deferredReason"] = reason + payload["metadata"] = { + **dict(payload.get("metadata", {})), + **dict(metadata or {}), + } + payload["events"] = [ + *list(payload.get("events", [])), + _event("stage_deferred", now, {"reason": reason}), + ] + self._write_stage_payload(stage_id, payload) + self._refresh_aggregate() + return payload + + def finalize_from_artifact_manifest( + self, + manifest_payload: Mapping[str, Any], + ) -> dict[str, Any]: + """Finalize typed manifests from a completed saved artifact manifest.""" + + self.manifest_payload = dict(manifest_payload) + self._run_writer = USStageRunWriter( + self.artifact_root, + manifest_payload=self.manifest_payload, + allow_stage_input_overrides=self.allow_stage_input_overrides, + stage_input_overrides=self.stage_input_overrides, + ) + for outputs in build_us_stage_output_manifests_from_artifact_manifest( + self.artifact_root, + self.manifest_payload, + ): + existing = self._stage_payload(outputs.stage_id) + now = _now() + lifecycle_status = _final_lifecycle_status(outputs) + existing_events = tuple( + dict(event) + for event in existing.get("events", ()) + if isinstance(event, dict) + ) + lifecycle_outputs = replace( + outputs, + input_stage_manifest=outputs.input_stage_manifest + or self._previous_stage_manifest_ref(outputs.stage_id), + lifecycle_status=lifecycle_status, + started_at=_optional_str(existing.get("startedAt")) or now, + updated_at=now, + completed_at=now if lifecycle_status == "complete" else None, + deferred_reason=( + outputs.deferred_reason if lifecycle_status == "deferred" else None + ), + events=( + *existing_events, + *tuple(outputs.events), + _event(f"stage_{lifecycle_status}", now), + ), + ) + self._run_writer.record_stage(lifecycle_outputs) + self.manifest_payload = self._run_writer.write_manifest_files() + return self.manifest_payload + + def _stage_payload(self, stage_id: str) -> dict[str, Any]: + path = self._stage_output_manifest_path(stage_id) + if path.exists(): + try: + payload = json.loads(path.read_text()) + except (OSError, json.JSONDecodeError): + payload = {} + if isinstance(payload, dict): + return _ensure_stage_payload_defaults(stage_id, payload) + return _empty_stage_payload(stage_id) + + def _write_stage_payload(self, stage_id: str, payload: Mapping[str, Any]) -> None: + path = self._stage_output_manifest_path(stage_id) + _write_json_atomically(path, payload) + self._register_stage_output_manifest(stage_id, path) + + def _register_stage_output_manifest(self, stage_id: str, path: Path) -> None: + stage_paths = dict(self.manifest_payload.get("stage_output_manifests", {})) + stage_paths[stage_id] = str(path.relative_to(self.artifact_root)) + self.manifest_payload["stage_output_manifests"] = stage_paths + + def _refresh_aggregate(self) -> None: + stage_manifest_path = resolve_us_stage_artifact_contract_path( + self.artifact_root, + "08_dataset_assembly", + "stage_manifest", + ) + artifacts = dict(self.manifest_payload.get("artifacts", {})) + artifacts.setdefault("stage_manifest", stage_manifest_path.name) + artifacts.setdefault("manifest", "manifest.json") + self.manifest_payload["artifacts"] = artifacts + write_us_stage_manifest( + self.artifact_root, + stage_manifest_path, + manifest_payload=self.manifest_payload, + ) + + def _stage_output_manifest_path(self, stage_id: str) -> Path: + return self.artifact_root / "stage_artifacts" / "manifests" / f"{stage_id}.json" + + def _previous_stage_manifest_ref(self, stage_id: str) -> str | None: + stage_index = US_CANONICAL_STAGE_IDS.index(stage_id) + if stage_index == 0: + return None + previous_stage_id = US_CANONICAL_STAGE_IDS[stage_index - 1] + path = self._stage_output_manifest_path(previous_stage_id) + return str(path.relative_to(self.artifact_root)) if path.exists() else None + + def _validate_start_transition(self, stage_id: str) -> None: + stage_index = US_CANONICAL_STAGE_IDS.index(stage_id) + if stage_index == 0: + return + previous_stage_id = US_CANONICAL_STAGE_IDS[stage_index - 1] + previous_payload = self._stage_payload(previous_stage_id) + if previous_payload.get("lifecycleStatus") == "complete": + return + contract = get_us_pipeline_stage_contract(stage_id) + required_previous_inputs = tuple( + resource + for resource in contract.inputs + if resource.required and resource.stage_id == previous_stage_id + ) + if required_previous_inputs and all( + self._override_satisfies(stage_id, resource.key) + for resource in required_previous_inputs + ): + return + raise ValueError( + f"{stage_id} requires {previous_stage_id} to be complete before start, " + "unless explicit stage input overrides are enabled" + ) + + def _override_satisfies(self, stage_id: str, key: str) -> bool: + if not self.allow_stage_input_overrides: + return False + return any( + override.stage_id == stage_id and override.key == key + for override in self.stage_input_overrides + ) + + def _validate_output_key(self, stage_id: str, key: str) -> None: + contract = get_us_pipeline_stage_contract(stage_id) + valid_keys = {resource.key for resource in contract.outputs} + valid_keys.update(artifact.key for artifact in contract.artifacts) + if key not in valid_keys: + valid = ", ".join(sorted(valid_keys)) or "none" + raise KeyError(f"Unknown output key {stage_id}.{key}; valid keys: {valid}") + + def _write_update_artifact( + self, + stage_id: str, + key: str, + value: Any, + path: str | Path, + ) -> USArtifactRef: + resolved_path = Path(path) + if not resolved_path.is_absolute(): + resolved_path = self.artifact_root / resolved_path + _write_json_atomically( + resolved_path, _runtime_serialize(value, self.artifact_root) + ) + artifact_contract = get_us_stage_artifact_contract(stage_id, key) + return USArtifactRef( + key=key, + path=resolved_path, + format=artifact_contract.format, + required=artifact_contract.required, + resume_role=artifact_contract.resume_role, + exists=True, + ) + + @staticmethod + def _validate_stage_id(stage_id: str) -> None: + if stage_id not in US_CANONICAL_STAGE_IDS: + raise KeyError(f"Unknown US pipeline stage: {stage_id}") + + +def _empty_stage_payload(stage_id: str) -> dict[str, Any]: + contract = get_us_pipeline_stage_contract(stage_id) + return { + "schemaVersion": 2, + "contractVersion": US_STAGE_CONTRACT_VERSION, + "stageId": stage_id, + "complete": False, + "lifecycleStatus": "pending", + "startedAt": None, + "updatedAt": None, + "completedAt": None, + "failedAt": None, + "deferredReason": None, + "failure": None, + "inputStageManifest": None, + "inputOverrides": [], + "requiredOutputs": [ + resource.key for resource in contract.outputs if resource.required + ], + "missingRequiredOutputs": [ + resource.key for resource in contract.outputs if resource.required + ], + "outputs": {}, + "diagnostics": { + "stage_summary": USDiagnosticOutput( + key="stage_summary", + description=f"Runtime diagnostic summary for {stage_id}.", + ).to_dict(), + }, + "auxiliaryArtifacts": {}, + "metadata": {}, + "events": [], + } + + +def _ensure_stage_payload_defaults( + stage_id: str, + payload: dict[str, Any], +) -> dict[str, Any]: + defaults = _empty_stage_payload(stage_id) + merged = {**defaults, **payload} + for key in ("outputs", "diagnostics", "auxiliaryArtifacts", "metadata"): + if not isinstance(merged.get(key), dict): + merged[key] = {} + if not isinstance(merged.get("events"), list): + merged["events"] = [] + return merged + + +def _final_lifecycle_status( + outputs: USStageOutputManifest, +) -> USStageLifecycleStatus: + if outputs.resolved_lifecycle_status() == "deferred": + return "deferred" + return "complete" if outputs.complete else "pending" + + +def _runtime_serialize(value: Any, artifact_root: str | Path | None) -> Any: + if isinstance(value, USDiagnosticOutput): + return value.to_dict(artifact_root) + return _serialize_value(value, artifact_root) + + +def _event( + event: str, + timestamp: str, + details: Mapping[str, Any] | None = None, +) -> USStageRuntimeEventRecord: + return { + "event": event, + "timestamp": timestamp, + "details": dict(details or {}), + } + + +def _now() -> str: + return datetime.now(UTC).isoformat() + + +def _optional_str(value: Any) -> str | None: + return str(value) if value is not None else None + + +def _write_json_atomically(path: Path, payload: Any) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + temporary = path.with_suffix(path.suffix + ".tmp") + temporary.write_text(json.dumps(payload, indent=2, sort_keys=True)) + temporary.replace(path) + + +__all__ = [ + "RuntimeUpdateSection", + "USStageRuntimeWriter", +] diff --git a/src/microplex_us/pipelines/us.py b/src/microplex_us/pipelines/us.py index 2f965d9..08352f4 100644 --- a/src/microplex_us/pipelines/us.py +++ b/src/microplex_us/pipelines/us.py @@ -73,6 +73,11 @@ from microplex_us.pipelines.pe_native_optimization import ( optimize_policyengine_us_native_loss_dataset, ) +from microplex_us.pipelines.stage_run import ( + USDiagnosticOutput, + USSourceLoadingOutputs, +) +from microplex_us.pipelines.stage_runtime import USStageRuntimeWriter from microplex_us.policyengine.aotc import ( qualifying_expenses_from_american_opportunity_credit, ) @@ -1951,11 +1956,65 @@ def total_weighted_population(self) -> float: return float(self.calibrated_data["weight"].sum()) +def _source_loading_stage_outputs( + frames: list[ObservationFrame], +) -> USSourceLoadingOutputs: + frame_summaries: list[dict[str, Any]] = [] + relationship_summaries: dict[str, list[dict[str, Any]]] = {} + source_names: list[str] = [] + for frame in frames: + source_names.append(frame.source.name) + table_rows = { + entity.value: int(len(table)) for entity, table in frame.tables.items() + } + frame_summaries.append( + { + "source": frame.source.name, + "tables": table_rows, + "relationship_count": len(frame.relationships), + } + ) + relationship_summaries[frame.source.name] = [ + { + "parentEntity": relationship.parent_entity.value, + "childEntity": relationship.child_entity.value, + "parentKey": relationship.parent_key, + "childKey": relationship.child_key, + "cardinality": relationship.cardinality, + } + for relationship in frame.relationships + ] + return USSourceLoadingOutputs( + observation_frame_summary={ + "source_count": len(frames), + "frames": frame_summaries, + }, + source_descriptors=tuple(dict.fromkeys(source_names)), + source_relationships=relationship_summaries, + diagnostics={ + "stage_summary": USDiagnosticOutput( + key="stage_summary", + description="Runtime source-loading summary.", + summary={ + "source_names": source_names, + "source_count": len(frames), + }, + ) + }, + ) + + class USMicroplexPipeline: """End-to-end build orchestration for a US microplex dataset.""" - def __init__(self, config: USMicroplexBuildConfig | None = None): + def __init__( + self, + config: USMicroplexBuildConfig | None = None, + *, + stage_runtime_writer: USStageRuntimeWriter | None = None, + ): self.config = config or USMicroplexBuildConfig() + self.stage_runtime_writer = stage_runtime_writer def build_from_data_dir(self, data_dir: str | Path) -> USMicroplexBuildResult: from microplex_us.data_sources.cps import ( @@ -2005,12 +2064,23 @@ def build_from_source_providers( "USMicroplexPipeline requires at least one source provider" ) + if self.stage_runtime_writer is not None: + self.stage_runtime_writer.start_stage("02_source_loading") frames: list[ObservationFrame] = [] - for provider in providers: - frame = provider.load_frame( - self._resolve_source_query(provider, queries or {}) + try: + for provider in providers: + frame = provider.load_frame( + self._resolve_source_query(provider, queries or {}) + ) + frames.append(frame) + except Exception as exc: + if self.stage_runtime_writer is not None: + self.stage_runtime_writer.fail_stage("02_source_loading", exc) + raise + if self.stage_runtime_writer is not None: + self.stage_runtime_writer.complete_stage( + _source_loading_stage_outputs(frames) ) - frames.append(frame) return self.build_from_frames(frames) def build_from_frame(self, frame: ObservationFrame) -> USMicroplexBuildResult: diff --git a/tests/pipelines/test_artifacts.py b/tests/pipelines/test_artifacts.py index b5241ba..0bf1c63 100644 --- a/tests/pipelines/test_artifacts.py +++ b/tests/pipelines/test_artifacts.py @@ -3,6 +3,7 @@ import json import sqlite3 from pathlib import Path +from types import SimpleNamespace import h5py import pandas as pd @@ -10,6 +11,7 @@ from microplex.targets import StaticTargetProvider, TargetQuery, TargetSet, TargetSpec from microplex_us.pipelines.artifacts import ( + build_and_save_versioned_us_microplex_from_source_providers, replay_us_microplex_policyengine_stage_from_artifact, save_us_microplex_artifacts, ) @@ -28,6 +30,65 @@ ) +def test_source_provider_versioned_build_initializes_live_stage_writer( + tmp_path, + monkeypatch, +) -> None: + captured: dict[str, object] = {} + + class FakePipeline: + def __init__(self, config, *, stage_runtime_writer=None): + captured["config"] = config + captured["pipeline_stage_runtime_writer"] = stage_runtime_writer + + def build_from_source_providers(self, providers, queries=None): + captured["providers"] = providers + captured["queries"] = queries + return "build-result" + + def fake_finalize(build_result, **kwargs): + captured["build_result"] = build_result + captured["finalize_stage_runtime_writer"] = kwargs.get("stage_runtime_writer") + captured.update(kwargs) + return "finalized" + + monkeypatch.setattr( + "microplex_us.pipelines.artifacts.USMicroplexPipeline", + FakePipeline, + ) + monkeypatch.setattr( + "microplex_us.pipelines.artifacts._finalize_versioned_build_artifacts", + fake_finalize, + ) + provider = SimpleNamespace(descriptor=SimpleNamespace(name="unit_source")) + + result = build_and_save_versioned_us_microplex_from_source_providers( + [provider], + tmp_path, + config=USMicroplexBuildConfig(calibration_backend="none"), + version_id="runtime-test", + ) + + output_dir = tmp_path / "runtime-test" + stage1_manifest = json.loads( + ( + output_dir / "stage_artifacts" / "manifests" / "01_run_profile.json" + ).read_text() + ) + + assert result == "finalized" + assert captured["build_result"] == "build-result" + assert captured["preallocated_output_dir"] == output_dir + assert ( + captured["pipeline_stage_runtime_writer"] + is captured["finalize_stage_runtime_writer"] + ) + assert stage1_manifest["lifecycleStatus"] == "complete" + assert stage1_manifest["outputs"]["provider_query_plan"]["provider_names"] == [ + "unit_source" + ] + + def test_replay_policyengine_stage_from_artifact_uses_saved_synthetic( tmp_path, monkeypatch, @@ -115,7 +176,11 @@ def calibrate_policyengine_tables(self, tables): captured["tables"] = tables calibrated = captured["table_input"].copy() calibrated["weight"] = calibrated["weight"] * 2.0 - return "policyengine_tables", calibrated, {"backend": "policyengine_db_none"} + return ( + "policyengine_tables", + calibrated, + {"backend": "policyengine_db_none"}, + ) monkeypatch.setattr( "microplex_us.pipelines.artifacts.USMicroplexPipeline", @@ -244,7 +309,9 @@ def test_writes_expected_files(self, tmp_path): seed_data=pd.DataFrame({"income": [10.0], "hh_weight": [1.0]}), scaffold_seed_data=pd.DataFrame({"income": [10.0], "hh_weight": [1.0]}), synthetic_data=pd.DataFrame({"income": [10.0, 20.0], "weight": [1.0, 1.0]}), - calibrated_data=pd.DataFrame({"income": [10.0, 20.0], "weight": [0.5, 1.5]}), + calibrated_data=pd.DataFrame( + {"income": [10.0, 20.0], "weight": [0.5, 1.5]} + ), targets=USMicroplexTargets( marginal={"state": {"CA": 2.0}}, continuous={"income": 30.0}, @@ -351,12 +418,14 @@ def test_writes_expected_files(self, tmp_path): for record in artifact_inventory["artifacts"] } assert inventory_records[("01_run_profile", "manifest")]["exists"] is True - assert inventory_records[ - ("08_dataset_assembly", "policyengine_dataset") - ]["classification"] == "post_artifact_evidence" + assert ( + inventory_records[("08_dataset_assembly", "policyengine_dataset")][ + "classification" + ] + == "post_artifact_evidence" + ) readiness = { - stage["stageId"]: stage - for stage in conditional_readiness["stages"] + stage["stageId"]: stage for stage in conditional_readiness["stages"] } assert readiness["09_validation_benchmarking"]["readiness"] == ( "post_artifact_evidence" @@ -367,8 +436,7 @@ def test_writes_expected_files(self, tmp_path): ) assert source_diagnostics["summary"]["support_household_weight_share"] == 0.0 assert ( - source_diagnostics["summary"]["puf_support_household_weight_share"] - == 0.0 + source_diagnostics["summary"]["puf_support_household_weight_share"] == 0.0 ) assert source_diagnostics["summary"]["total_household_weight"] == 2.0 assert source_diagnostics["summary"]["total_person_weight"] == 2.0 @@ -539,7 +607,9 @@ def test_writes_child_tax_unit_agi_drift_summary(self, tmp_path): ), spm_units=pd.DataFrame({"spm_unit_id": [20], "household_id": [1]}), families=pd.DataFrame({"family_id": [30], "household_id": [1]}), - marital_units=pd.DataFrame({"marital_unit_id": [40], "household_id": [1]}), + marital_units=pd.DataFrame( + {"marital_unit_id": [40], "household_id": [1]} + ), ), ) baseline_dataset = _write_baseline_dataset( @@ -602,7 +672,9 @@ def test_writes_policyengine_harness_when_baseline_and_targets_are_provided( ), seed_data=pd.DataFrame({"income": [10.0], "hh_weight": [1.0]}), synthetic_data=pd.DataFrame({"income": [10.0, 20.0], "weight": [1.0, 1.0]}), - calibrated_data=pd.DataFrame({"income": [10.0, 20.0], "weight": [0.5, 1.5]}), + calibrated_data=pd.DataFrame( + {"income": [10.0, 20.0], "weight": [0.5, 1.5]} + ), targets=USMicroplexTargets( marginal={"state": {"CA": 2.0}}, continuous={"income": 30.0}, @@ -723,14 +795,21 @@ def test_writes_policyengine_harness_when_baseline_and_targets_are_provided( assert paths.policyengine_harness.exists() manifest = json.loads(paths.manifest.read_text()) - assert manifest["artifacts"]["policyengine_harness"] == "policyengine_harness.json" + assert ( + manifest["artifacts"]["policyengine_harness"] == "policyengine_harness.json" + ) assert manifest["policyengine_harness"]["slice_win_rate"] == 1.0 assert manifest["policyengine_harness"]["target_win_rate"] == 1.0 - assert manifest["policyengine_harness"]["candidate_composite_parity_loss"] is not None + assert ( + manifest["policyengine_harness"]["candidate_composite_parity_loss"] + is not None + ) harness_payload = json.loads(paths.policyengine_harness.read_text()) assert harness_payload["metadata"]["baseline_dataset"] == "baseline.h5" - assert harness_payload["metadata"]["policyengine_us_runtime_version"] is not None + assert ( + harness_payload["metadata"]["policyengine_us_runtime_version"] is not None + ) assert harness_payload["summary"]["slice_win_rate"] == 1.0 assert harness_payload["summary"]["candidate_composite_parity_loss"] is not None @@ -785,7 +864,9 @@ def test_can_defer_policyengine_harness_generation(self, monkeypatch, tmp_path): ), seed_data=pd.DataFrame({"income": [10.0], "hh_weight": [1.0]}), synthetic_data=pd.DataFrame({"income": [10.0, 20.0], "weight": [1.0, 1.0]}), - calibrated_data=pd.DataFrame({"income": [10.0, 20.0], "weight": [0.5, 1.5]}), + calibrated_data=pd.DataFrame( + {"income": [10.0, 20.0], "weight": [0.5, 1.5]} + ), targets=USMicroplexTargets( marginal={"state": {"CA": 2.0}}, continuous={"income": 30.0}, @@ -931,7 +1012,9 @@ def test_writes_policyengine_harness_from_build_config_defaults(self, tmp_path): ), seed_data=pd.DataFrame({"income": [10.0], "hh_weight": [1.0]}), synthetic_data=pd.DataFrame({"income": [10.0, 20.0], "weight": [1.0, 1.0]}), - calibrated_data=pd.DataFrame({"income": [10.0, 20.0], "weight": [0.5, 1.5]}), + calibrated_data=pd.DataFrame( + {"income": [10.0, 20.0], "weight": [0.5, 1.5]} + ), targets=USMicroplexTargets( marginal={"state": {"CA": 2.0}}, continuous={"income": 30.0}, @@ -993,10 +1076,16 @@ def test_writes_policyengine_harness_from_build_config_defaults(self, tmp_path): manifest = json.loads(paths.manifest.read_text()) assert manifest["policyengine_harness"]["slice_win_rate"] == 1.0 assert manifest["policyengine_harness"]["target_win_rate"] == 1.0 - assert manifest["policyengine_harness"]["candidate_composite_parity_loss"] is not None - assert manifest["policyengine_harness"]["parity_scorecard"]["overall"][ - "candidate_beats_baseline" - ] is True + assert ( + manifest["policyengine_harness"]["candidate_composite_parity_loss"] + is not None + ) + assert ( + manifest["policyengine_harness"]["parity_scorecard"]["overall"][ + "candidate_beats_baseline" + ] + is True + ) assert manifest["run_registry"]["artifact_id"] == "bundle" assert manifest["run_registry"]["improved_candidate_frontier"] is True assert manifest["run_registry"]["improved_composite_frontier"] is True @@ -1008,11 +1097,18 @@ def test_writes_policyengine_harness_from_build_config_defaults(self, tmp_path): harness_payload = json.loads(paths.policyengine_harness.read_text()) assert harness_payload["metadata"]["baseline_dataset"] == "baseline.h5" assert harness_payload["metadata"]["targets_db"] == "policyengine_targets.db" - assert harness_payload["metadata"]["harness_suite"] == "policyengine_us_all_targets" + assert ( + harness_payload["metadata"]["harness_suite"] + == "policyengine_us_all_targets" + ) assert harness_payload["metadata"]["harness_slice_names"] == ["all_targets"] assert harness_payload["metadata"]["target_variables"] == ["household_count"] - assert harness_payload["metadata"]["policyengine_us_runtime_version"] is not None - assert [slice_payload["name"] for slice_payload in harness_payload["slices"]] == [ + assert ( + harness_payload["metadata"]["policyengine_us_runtime_version"] is not None + ) + assert [ + slice_payload["name"] for slice_payload in harness_payload["slices"] + ] == [ "all_targets", ] registry_entries = load_us_microplex_run_registry(paths.run_registry) @@ -1021,7 +1117,9 @@ def test_writes_policyengine_harness_from_build_config_defaults(self, tmp_path): assert registry_entries[0].policyengine_us_runtime_version is not None assert registry_entries[0].supported_target_rate == 1.0 assert registry_entries[0].candidate_composite_parity_loss is not None - assert registry_entries[0].tag_summaries["all_targets"]["target_win_rate"] == 1.0 + assert ( + registry_entries[0].tag_summaries["all_targets"]["target_win_rate"] == 1.0 + ) def test_writes_policyengine_native_scores_when_available( self, monkeypatch, tmp_path @@ -1057,7 +1155,9 @@ def test_writes_policyengine_native_scores_when_available( ), seed_data=pd.DataFrame({"income": [10.0], "hh_weight": [1.0]}), synthetic_data=pd.DataFrame({"income": [10.0, 20.0], "weight": [1.0, 1.0]}), - calibrated_data=pd.DataFrame({"income": [10.0, 20.0], "weight": [0.5, 1.5]}), + calibrated_data=pd.DataFrame( + {"income": [10.0, 20.0], "weight": [0.5, 1.5]} + ), targets=USMicroplexTargets(marginal={}, continuous={"income": 30.0}), calibration_summary={"max_error": 0.01, "mean_error": 0.005}, synthesis_metadata={"backend": "bootstrap"}, @@ -1124,7 +1224,9 @@ def test_writes_policyengine_native_scores_when_available( manifest["policyengine_native_scores"]["candidate_enhanced_cps_native_loss"] == 0.25 ) - assert manifest["policyengine_native_scores"]["candidate_beats_baseline"] is True + assert ( + manifest["policyengine_native_scores"]["candidate_beats_baseline"] is True + ) assert ( manifest["run_registry"]["default_frontier_metric"] == "enhanced_cps_native_loss_delta" @@ -1154,7 +1256,9 @@ def _boom(**_kwargs): ), seed_data=pd.DataFrame({"income": [10.0], "hh_weight": [1.0]}), synthetic_data=pd.DataFrame({"income": [10.0, 20.0], "weight": [1.0, 1.0]}), - calibrated_data=pd.DataFrame({"income": [10.0, 20.0], "weight": [0.5, 1.5]}), + calibrated_data=pd.DataFrame( + {"income": [10.0, 20.0], "weight": [0.5, 1.5]} + ), targets=USMicroplexTargets(marginal={}, continuous={"income": 30.0}), calibration_summary={"max_error": 0.01, "mean_error": 0.005}, synthesis_metadata={"backend": "bootstrap"}, diff --git a/tests/pipelines/test_stage_manifest.py b/tests/pipelines/test_stage_manifest.py index 954e9ac..bf0968d 100644 --- a/tests/pipelines/test_stage_manifest.py +++ b/tests/pipelines/test_stage_manifest.py @@ -78,7 +78,7 @@ def test_build_us_stage_manifest_reports_nine_stage_statuses(tmp_path): payload = build_us_stage_manifest(tmp_path, manifest_payload=manifest) - assert payload["schemaVersion"] == 2 + assert payload["schemaVersion"] == 3 assert payload["generatedAt"] == "2026-05-28T00:00:00+00:00" assert [stage["id"] for stage in payload["stages"]] == [ "01_run_profile", @@ -111,7 +111,7 @@ def test_build_us_stage_manifest_reports_nine_stage_statuses(tmp_path): assert stage5_artifacts["synthetic_data"]["hash_mode"] == "file_sha256" -def test_load_us_stage_manifest_accepts_v1_and_v2(tmp_path): +def test_load_us_stage_manifest_accepts_v1_v2_and_v3(tmp_path): v1_path = tmp_path / "stage_manifest_v1.json" v1_path.write_text( json.dumps( @@ -140,9 +140,24 @@ def test_load_us_stage_manifest_accepts_v1_and_v2(tmp_path): } ) ) + v3_path = tmp_path / "stage_manifest_v3.json" + v3_path.write_text( + json.dumps( + { + "schemaVersion": 3, + "contractVersion": "us-runtime-stages-v2", + "generatedAt": None, + "pipeline": "us_microplex", + "artifactRoot": ".", + "manifest": "manifest.json", + "stages": [], + } + ) + ) assert load_us_stage_manifest(v1_path)["schemaVersion"] == 1 assert load_us_stage_manifest(v2_path)["schemaVersion"] == 2 + assert load_us_stage_manifest(v3_path)["schemaVersion"] == 3 def test_build_us_stage_manifest_keeps_empty_validation_index_deferred(tmp_path): @@ -283,7 +298,9 @@ def test_write_us_stage_manifest_and_resolve_artifact_path(tmp_path): ) assert dataset_path == tmp_path / "policyengine_us.h5" - assert stage_summary_for_data_flow_snapshot(loaded)[7]["id"] == "08_dataset_assembly" + assert ( + stage_summary_for_data_flow_snapshot(loaded)[7]["id"] == "08_dataset_assembly" + ) def test_policyengine_entity_stage_artifact_round_trips_partial_bundle(tmp_path): diff --git a/tests/pipelines/test_stage_run.py b/tests/pipelines/test_stage_run.py index c4d0eba..b2a6b58 100644 --- a/tests/pipelines/test_stage_run.py +++ b/tests/pipelines/test_stage_run.py @@ -32,6 +32,14 @@ "auxiliary_artifacts", "metadata", "complete", + "lifecycle_status", + "started_at", + "updated_at", + "completed_at", + "failed_at", + "deferred_reason", + "failure", + "events", "stage_id", } diff --git a/tests/pipelines/test_stage_runtime.py b/tests/pipelines/test_stage_runtime.py new file mode 100644 index 0000000..d54a0cf --- /dev/null +++ b/tests/pipelines/test_stage_runtime.py @@ -0,0 +1,105 @@ +"""Tests for live US stage runtime manifest updates.""" + +import json + +import pytest + +from microplex_us.pipelines.stage_run import ( + USArtifactRef, + USDiagnosticOutput, + USRunProfileOutputs, + USSourceLoadingOutputs, +) +from microplex_us.pipelines.stage_runtime import USStageRuntimeWriter + + +def _diagnostics(stage_id: str) -> dict[str, USDiagnosticOutput]: + return { + "stage_summary": USDiagnosticOutput( + key="stage_summary", + description=f"Summary for {stage_id}.", + summary={"stage": stage_id}, + ) + } + + +def test_runtime_writer_requires_previous_stage_completion_to_start(tmp_path): + writer = USStageRuntimeWriter(tmp_path) + + with pytest.raises(ValueError, match="01_run_profile to be complete"): + writer.start_stage("02_source_loading") + + writer.start_stage("01_run_profile") + + with pytest.raises(ValueError, match="01_run_profile to be complete"): + writer.start_stage("02_source_loading") + + +def test_runtime_writer_completes_stage_and_exposes_lifecycle(tmp_path): + writer = USStageRuntimeWriter(tmp_path) + writer.start_stage("01_run_profile", metadata={"profile": "test"}) + writer.complete_stage( + USRunProfileOutputs( + manifest=USArtifactRef( + key="manifest", + path="manifest.json", + format="json", + required=True, + assume_exists=True, + ), + resolved_config={"calibration_backend": "none"}, + provider_query_plan={"source_names": ["unit"]}, + diagnostics=_diagnostics("01_run_profile"), + ) + ) + + writer.start_stage("02_source_loading") + writer.complete_stage( + USSourceLoadingOutputs( + observation_frame_summary={"source_count": 1}, + source_descriptors=("unit",), + source_relationships={"household_person": "ok"}, + diagnostics=_diagnostics("02_source_loading"), + ) + ) + + stage2_path = tmp_path / "stage_artifacts" / "manifests" / "02_source_loading.json" + stage2 = json.loads(stage2_path.read_text()) + aggregate = json.loads((tmp_path / "stage_manifest.json").read_text()) + aggregate_stage2 = {stage["id"]: stage for stage in aggregate["stages"]}[ + "02_source_loading" + ] + + assert stage2["lifecycleStatus"] == "complete" + assert stage2["inputStageManifest"] == ( + "stage_artifacts/manifests/01_run_profile.json" + ) + assert aggregate_stage2["lifecycleStatus"] == "complete" + assert aggregate_stage2["outputManifest"] == ( + "stage_artifacts/manifests/02_source_loading.json" + ) + assert aggregate_stage2["completedAt"] is not None + assert [event["event"] for event in stage2["events"]] == [ + "stage_started", + "stage_completed", + ] + + +def test_runtime_writer_update_writes_json_artifact_reference(tmp_path): + writer = USStageRuntimeWriter(tmp_path) + payload = writer.record_output( + "03_source_planning", + "source_plan", + {"scaffoldSource": "cps"}, + path="stage_artifacts/03_source_planning/source_plan.json", + ) + + source_plan_path = ( + tmp_path / "stage_artifacts" / "03_source_planning" / "source_plan.json" + ) + + assert json.loads(source_plan_path.read_text()) == {"scaffoldSource": "cps"} + assert payload["outputs"]["source_plan"]["path"] == ( + "stage_artifacts/03_source_planning/source_plan.json" + ) + assert payload["outputs"]["source_plan"]["exists"] is True From 6be1cbb0731b532bbcb8ea4ce7967b09cabf7024 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 1 Jun 2026 23:01:54 +0200 Subject: [PATCH 2/7] Complete live runtime stage manifests --- pyproject.toml | 1 + src/microplex_us/pipelines/__init__.py | 19 +- src/microplex_us/pipelines/artifacts.py | 232 +++++- src/microplex_us/pipelines/stage9_replay.py | 264 +++++++ .../pipelines/stage_manifest_io.py | 1 + src/microplex_us/pipelines/stage_run.py | 3 + src/microplex_us/pipelines/stage_runtime.py | 73 +- src/microplex_us/pipelines/us.py | 660 ++++++++++++++---- tests/pipelines/test_stage9_replay.py | 101 +++ tests/pipelines/test_stage_runtime.py | 101 +++ tests/pipelines/test_versioned_artifacts.py | 45 +- 11 files changed, 1327 insertions(+), 173 deletions(-) create mode 100644 src/microplex_us/pipelines/stage9_replay.py create mode 100644 tests/pipelines/test_stage9_replay.py diff --git a/pyproject.toml b/pyproject.toml index 6eff96e..3ddaa4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,6 +67,7 @@ microplex-us-pe-native-target-diagnostics = "microplex_us.pipelines.pe_native_sc microplex-us-r2-archive-artifact = "microplex_us.pipelines.r2_artifacts:main" microplex-us-reweight-cd-age-targets = "microplex_us.pipelines.cd_age_reweighting:main" microplex-us-score-pe-native-loss = "microplex_us.pipelines.pe_native_scores:main" +microplex-us-stage9-replay = "microplex_us.pipelines.stage9_replay:main" microplex-us-write-transparency-sidecars = "microplex_us.pipelines.transparency_sidecars:main" microplex-us-version-bump-benchmark = "microplex_us.pipelines.version_benchmark:main" diff --git a/src/microplex_us/pipelines/__init__.py b/src/microplex_us/pipelines/__init__.py index 5ddd3c0..59546e3 100644 --- a/src/microplex_us/pipelines/__init__.py +++ b/src/microplex_us/pipelines/__init__.py @@ -358,6 +358,13 @@ def _exports(module: str, names: tuple[str, ...]) -> dict[str, str]: "USStageRuntimeWriter", ), ), + **_exports( + "microplex_us.pipelines.stage9_replay", + ( + "USStage9ReplayResult", + "replay_us_stage9_validation_benchmarking", + ), + ), **_exports( "microplex_us.pipelines.summarize_pe_native_family_drilldown", ( @@ -367,21 +374,15 @@ def _exports(module: str, names: tuple[str, ...]) -> dict[str, str]: ), **_exports( "microplex_us.pipelines.summarize_pe_native_regressions", - ( - "summarize_us_pe_native_regressions", - ), + ("summarize_us_pe_native_regressions",), ), **_exports( "microplex_us.pipelines.summarize_policyengine_oracle_regressions", - ( - "summarize_us_policyengine_oracle_regressions", - ), + ("summarize_us_policyengine_oracle_regressions",), ), **_exports( "microplex_us.pipelines.summarize_policyengine_oracle_target_drilldown", - ( - "summarize_us_policyengine_oracle_target_drilldown", - ), + ("summarize_us_policyengine_oracle_target_drilldown",), ), **_exports( "microplex_us.pipelines.source_stage_parity", diff --git a/src/microplex_us/pipelines/artifacts.py b/src/microplex_us/pipelines/artifacts.py index 94c9ea8..974182a 100644 --- a/src/microplex_us/pipelines/artifacts.py +++ b/src/microplex_us/pipelines/artifacts.py @@ -39,6 +39,7 @@ select_us_microplex_frontier_entry, ) from microplex_us.pipelines.stage_contracts import ( + get_us_stage_artifact_contract, resolve_us_stage_artifact_contract_path, ) from microplex_us.pipelines.stage_manifest import ( @@ -46,9 +47,11 @@ ) from microplex_us.pipelines.stage_run import ( USArtifactRef, + USDatasetAssemblyOutputs, USDiagnosticOutput, USRunProfileOutputs, USStageInputOverride, + USValidationBenchmarkingOutputs, write_us_stage_run_manifests_from_artifact_manifest, ) from microplex_us.pipelines.stage_runtime import USStageRuntimeWriter @@ -121,6 +124,85 @@ class USMicroplexVersionedBuildArtifacts: frontier_delta: float | None = None +def _stage_artifact_ref( + artifact_root: str | Path, + stage_id: str, + artifact_key: str, + *, + assume_exists: bool = False, +) -> USArtifactRef: + contract = get_us_stage_artifact_contract(stage_id, artifact_key) + return USArtifactRef( + key=artifact_key, + path=resolve_us_stage_artifact_contract_path( + artifact_root, + stage_id, + artifact_key, + ), + format=contract.format, + required=contract.required, + resume_role=contract.resume_role, + assume_exists=assume_exists, + ) + + +def _stage_diagnostics( + stage_id: str, + summary: Mapping[str, Any], +) -> dict[str, USDiagnosticOutput]: + return { + "stage_summary": USDiagnosticOutput( + key="stage_summary", + description=f"Runtime diagnostic summary for {stage_id}.", + summary=dict(summary), + ) + } + + +def _write_parquet_unless_live_artifact_exists( + path: Path, + frame: pd.DataFrame, + *, + live_artifact: bool, +) -> None: + if live_artifact and path.exists(): + return + path.parent.mkdir(parents=True, exist_ok=True) + frame.to_parquet(path, index=False) + + +def _write_json_unless_live_artifact_exists( + path: Path, + payload: Mapping[str, Any], + *, + live_artifact: bool, +) -> None: + if live_artifact and path.exists(): + return + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(payload, indent=2, sort_keys=True)) + + +def _stage9_benchmark_summary(manifest: Mapping[str, Any]) -> dict[str, Any]: + summary: dict[str, Any] = {} + for key in ( + "policyengine_harness", + "policyengine_native_scores", + "policyengine_native_audit", + "imputation_ablation", + ): + value = manifest.get(key) + if isinstance(value, Mapping): + summary[key] = dict(value) + diagnostics = manifest.get("diagnostics") + if isinstance(diagnostics, Mapping): + for key in ("child_tax_unit_agi_drift", "capital_gains_lots"): + value = diagnostics.get(key) + if isinstance(value, Mapping): + summary[key] = dict(value) + return summary + + def replay_us_microplex_policyengine_stage_from_artifact( artifact_dir: str | Path, *, @@ -937,29 +1019,48 @@ def save_us_microplex_artifacts( resolved_run_registry_path = None resolved_run_index_path = None harness_payload = None + live_artifacts = stage_runtime_writer is not None + + if stage_runtime_writer is not None: + stage_runtime_writer.start_stage("08_dataset_assembly") if result.scaffold_seed_data is not None and scaffold_seed_data_path is not None: - scaffold_seed_data_path.parent.mkdir(parents=True, exist_ok=True) - result.scaffold_seed_data.to_parquet(scaffold_seed_data_path, index=False) - result.seed_data.to_parquet(seed_data_path, index=False) - result.synthetic_data.to_parquet(synthetic_data_path, index=False) - result.calibrated_data.to_parquet(calibrated_data_path, index=False) - targets_path.write_text( - json.dumps( - { - "marginal": result.targets.marginal, - "continuous": result.targets.continuous, - }, - indent=2, - sort_keys=True, + _write_parquet_unless_live_artifact_exists( + scaffold_seed_data_path, + result.scaffold_seed_data, + live_artifact=live_artifacts, ) + _write_parquet_unless_live_artifact_exists( + seed_data_path, + result.seed_data, + live_artifact=live_artifacts, + ) + _write_parquet_unless_live_artifact_exists( + synthetic_data_path, + result.synthetic_data, + live_artifact=live_artifacts, + ) + _write_parquet_unless_live_artifact_exists( + calibrated_data_path, + result.calibrated_data, + live_artifact=live_artifacts, + ) + _write_json_unless_live_artifact_exists( + targets_path, + { + "marginal": result.targets.marginal, + "continuous": result.targets.continuous, + }, + live_artifact=live_artifacts, ) if result.synthesizer is not None and synthesizer_path is not None: result.synthesizer.save(synthesizer_path) - _write_us_source_plan_artifact(result, source_plan_path) - _write_json_atomically(calibration_summary_path, result.calibration_summary) + if not (live_artifacts and source_plan_path.exists()): + _write_us_source_plan_artifact(result, source_plan_path) + if not (live_artifacts and calibration_summary_path.exists()): + _write_json_atomically(calibration_summary_path, result.calibration_summary) source_weight_diagnostics_payload = _build_source_weight_diagnostics(result) _write_json_atomically( source_weight_diagnostics_path, @@ -968,10 +1069,11 @@ def save_us_microplex_artifacts( if result.policyengine_tables is not None and policyengine_dataset_path is not None: if policyengine_entity_tables_path is not None: - write_us_policyengine_entity_stage_artifact( - result.policyengine_tables, - output_dir, - ) + if not (live_artifacts and policyengine_entity_tables_path.exists()): + write_us_policyengine_entity_stage_artifact( + result.policyengine_tables, + output_dir, + ) period = result.config.policyengine_dataset_year or 2024 USMicroplexPipeline(result.config).export_policyengine_dataset( result, @@ -982,6 +1084,57 @@ def save_us_microplex_artifacts( _maybe_write_capital_gains_lot_artifact(result, output_dir) ) + if stage_runtime_writer is not None: + stage_runtime_writer.complete_stage( + USDatasetAssemblyOutputs( + policyengine_dataset=( + _stage_artifact_ref( + output_dir, + "08_dataset_assembly", + "policyengine_dataset", + ) + if policyengine_dataset_path is not None + else None + ), + stage_manifest=_stage_artifact_ref( + output_dir, + "08_dataset_assembly", + "stage_manifest", + assume_exists=True, + ), + data_flow_snapshot=_stage_artifact_ref( + output_dir, + "08_dataset_assembly", + "data_flow_snapshot", + assume_exists=True, + ), + artifact_inventory=_stage_artifact_ref( + output_dir, + "08_dataset_assembly", + "artifact_inventory", + assume_exists=True, + ), + conditional_readiness=_stage_artifact_ref( + output_dir, + "08_dataset_assembly", + "conditional_readiness", + assume_exists=True, + ), + diagnostics=_stage_diagnostics( + "08_dataset_assembly", + { + "policyengine_dataset": ( + str(policyengine_dataset_path.relative_to(output_dir)) + if policyengine_dataset_path is not None + else None + ), + "has_capital_gains_lots": capital_gains_lots_path is not None, + }, + ), + ) + ) + stage_runtime_writer.start_stage("09_validation_benchmarking") + ( resolved_target_provider, resolved_baseline_dataset, @@ -1234,6 +1387,47 @@ def save_us_microplex_artifacts( "artifact_id": recorded_entry.artifact_id, } if stage_runtime_writer is not None: + stage_runtime_writer.manifest_payload = manifest + stage9_summary = _stage9_benchmark_summary(manifest) + if stage9_summary: + stage_runtime_writer.complete_stage( + USValidationBenchmarkingOutputs( + validation_evidence=_stage_artifact_ref( + output_dir, + "09_validation_benchmarking", + "validation_evidence", + assume_exists=True, + ), + benchmark_summary=stage9_summary, + policyengine_harness=( + _stage_artifact_ref( + output_dir, + "09_validation_benchmarking", + "policyengine_harness", + ) + if policyengine_harness_path is not None + else None + ), + policyengine_native_scores=( + _stage_artifact_ref( + output_dir, + "09_validation_benchmarking", + "policyengine_native_scores", + ) + if policyengine_native_scores_path is not None + else None + ), + diagnostics=_stage_diagnostics( + "09_validation_benchmarking", + stage9_summary, + ), + ) + ) + else: + stage_runtime_writer.defer_stage( + "09_validation_benchmarking", + "No validation or benchmark evidence was configured for this run.", + ) manifest = stage_runtime_writer.finalize_from_artifact_manifest(manifest) else: manifest = write_us_stage_run_manifests_from_artifact_manifest( diff --git a/src/microplex_us/pipelines/stage9_replay.py b/src/microplex_us/pipelines/stage9_replay.py new file mode 100644 index 0000000..a45a71c --- /dev/null +++ b/src/microplex_us/pipelines/stage9_replay.py @@ -0,0 +1,264 @@ +"""Safe Stage 9 validation and benchmarking replay helpers.""" + +from __future__ import annotations + +import argparse +import json +from dataclasses import dataclass +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + +from microplex_us.pipelines.pe_native_scores import compute_us_pe_native_scores +from microplex_us.pipelines.stage_manifest_io import write_json_atomically +from microplex_us.pipelines.stage_validation_evidence import ( + build_us_validation_evidence_manifest, +) + + +@dataclass(frozen=True) +class USStage9ReplayResult: + """Artifacts written by a Stage 9 replay.""" + + output_dir: Path + replay_manifest: Path + validation_evidence: Path + policyengine_harness: Path | None = None + policyengine_native_scores: Path | None = None + + +def replay_us_stage9_validation_benchmarking( + artifact_dir: str | Path, + *, + output_dir: str | Path | None = None, + baseline_dataset: str | Path | None = None, + policyengine_us_data_repo: str | Path | None = None, + period: int | None = None, + precomputed_policyengine_harness: str | Path | dict[str, Any] | None = None, + precomputed_policyengine_native_scores: str | Path | dict[str, Any] | None = None, + run_id: str | None = None, + allow_overwrite: bool = False, +) -> USStage9ReplayResult: + """Rerun safe Stage 9 evidence against an existing Stage 8 dataset. + + The original artifact bundle is left untouched. New evidence is written under + a replay directory and indexed by a replay-local evidence manifest. + """ + + artifact_root = Path(artifact_dir).expanduser().resolve() + manifest_path = artifact_root / "manifest.json" + if not manifest_path.exists(): + raise FileNotFoundError(f"Saved artifact manifest not found: {manifest_path}") + manifest = json.loads(manifest_path.read_text()) + dataset_path = _validated_stage8_dataset_path(artifact_root, manifest) + + resolved_output_dir = _resolve_replay_output_dir( + artifact_root, + output_dir=output_dir, + run_id=run_id, + ) + if resolved_output_dir.exists() and any(resolved_output_dir.iterdir()): + if not allow_overwrite: + raise FileExistsError( + f"Stage 9 replay output directory already exists and is not empty: " + f"{resolved_output_dir}" + ) + resolved_output_dir.mkdir(parents=True, exist_ok=True) + + replay_manifest_payload = dict(manifest) + replay_artifacts = dict(manifest.get("artifacts", {})) + summaries: dict[str, Any] = {} + + harness_path = None + harness_payload = _load_optional_payload(precomputed_policyengine_harness) + if harness_payload is not None: + harness_path = resolved_output_dir / "policyengine_harness.json" + write_json_atomically(harness_path, harness_payload) + replay_artifacts["policyengine_harness"] = _relative_to_root( + harness_path, + artifact_root, + ) + if isinstance(harness_payload.get("summary"), dict): + summaries["policyengine_harness"] = dict(harness_payload["summary"]) + + native_scores_path = None + native_scores_payload = _load_optional_payload( + precomputed_policyengine_native_scores + ) + if native_scores_payload is None and baseline_dataset is not None: + native_scores_payload = compute_us_pe_native_scores( + candidate_dataset_path=dataset_path, + baseline_dataset_path=baseline_dataset, + period=period + or int( + dict(manifest.get("config", {})).get( + "policyengine_dataset_year", + 2024, + ) + ), + policyengine_us_data_repo=policyengine_us_data_repo, + ) + if native_scores_payload is not None: + native_scores_path = resolved_output_dir / "policyengine_native_scores.json" + write_json_atomically(native_scores_path, native_scores_payload) + replay_artifacts["policyengine_native_scores"] = _relative_to_root( + native_scores_path, + artifact_root, + ) + if isinstance(native_scores_payload.get("summary"), dict): + summaries["policyengine_native_scores"] = dict( + native_scores_payload["summary"] + ) + + if not summaries: + raise ValueError( + "Stage 9 replay did not produce evidence. Supply precomputed evidence " + "or a baseline dataset for native scoring." + ) + + evidence_path = resolved_output_dir / "evidence_manifest.json" + replay_artifacts["validation_evidence"] = _relative_to_root( + evidence_path, + artifact_root, + ) + replay_manifest_payload["artifacts"] = replay_artifacts + replay_manifest_payload.update(summaries) + replay_manifest_payload["stage9_replay"] = { + "created_at": datetime.now(UTC).isoformat(), + "source_artifact_dir": str(artifact_root), + "source_manifest": str(manifest_path), + "source_policyengine_dataset": _relative_to_root(dataset_path, artifact_root), + "output_dir": _relative_to_root(resolved_output_dir, artifact_root), + } + write_json_atomically( + evidence_path, + build_us_validation_evidence_manifest( + artifact_root, + manifest_payload=replay_manifest_payload, + ), + ) + replay_manifest_path = resolved_output_dir / "replay_manifest.json" + write_json_atomically(replay_manifest_path, replay_manifest_payload) + return USStage9ReplayResult( + output_dir=resolved_output_dir, + replay_manifest=replay_manifest_path, + validation_evidence=evidence_path, + policyengine_harness=harness_path, + policyengine_native_scores=native_scores_path, + ) + + +def _validated_stage8_dataset_path( + artifact_root: Path, + manifest: dict[str, Any], +) -> Path: + artifacts = dict(manifest.get("artifacts", {})) + dataset_value = artifacts.get("policyengine_dataset") + if not dataset_value: + raise ValueError("Stage 8 policyengine_dataset artifact is not declared") + dataset_path = Path(str(dataset_value)) + if not dataset_path.is_absolute(): + dataset_path = artifact_root / dataset_path + if not dataset_path.exists(): + raise FileNotFoundError(f"Stage 8 dataset artifact is missing: {dataset_path}") + + stage_manifest_paths = dict(manifest.get("stage_output_manifests", {})) + stage8_manifest_value = stage_manifest_paths.get("08_dataset_assembly") + if not stage8_manifest_value: + raise ValueError("Stage 8 output manifest is not declared") + stage8_manifest_path = Path(str(stage8_manifest_value)) + if not stage8_manifest_path.is_absolute(): + stage8_manifest_path = artifact_root / stage8_manifest_path + if not stage8_manifest_path.exists(): + raise FileNotFoundError( + f"Stage 8 output manifest is missing: {stage8_manifest_path}" + ) + stage8_manifest = json.loads(stage8_manifest_path.read_text()) + if stage8_manifest.get("lifecycleStatus") != "complete": + raise ValueError("Stage 8 must be complete before Stage 9 replay") + stage8_outputs = stage8_manifest.get("outputs") + if isinstance(stage8_outputs, dict): + serialized_dataset = stage8_outputs.get("policyengine_dataset") + if isinstance(serialized_dataset, dict): + output_path = serialized_dataset.get("path") + if output_path and Path(str(output_path)).name != dataset_path.name: + raise ValueError( + "Stage 8 dataset output does not match the root manifest " + "policyengine_dataset artifact" + ) + return dataset_path + + +def _resolve_replay_output_dir( + artifact_root: Path, + *, + output_dir: str | Path | None, + run_id: str | None, +) -> Path: + if output_dir is not None: + return Path(output_dir).expanduser().resolve() + resolved_run_id = run_id or datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ") + return ( + artifact_root + / "stage_artifacts" + / "09_validation_benchmarking" + / "replays" + / resolved_run_id + ) + + +def _load_optional_payload( + value: str | Path | dict[str, Any] | None, +) -> dict[str, Any] | None: + if value is None: + return None + if isinstance(value, dict): + return dict(value) + return json.loads(Path(value).expanduser().read_text()) + + +def _relative_to_root(path: Path, artifact_root: Path) -> str: + try: + return str(path.relative_to(artifact_root)) + except ValueError: + return str(path) + + +def main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser( + description="Rerun Stage 9 validation evidence against a saved Stage 8 dataset." + ) + parser.add_argument("artifact_dir") + parser.add_argument("--output-dir") + parser.add_argument("--run-id") + parser.add_argument("--baseline-dataset") + parser.add_argument("--policyengine-us-data-repo") + parser.add_argument("--period", type=int) + parser.add_argument("--precomputed-policyengine-harness") + parser.add_argument("--precomputed-policyengine-native-scores") + parser.add_argument("--allow-overwrite", action="store_true") + args = parser.parse_args(argv) + result = replay_us_stage9_validation_benchmarking( + args.artifact_dir, + output_dir=args.output_dir, + baseline_dataset=args.baseline_dataset, + policyengine_us_data_repo=args.policyengine_us_data_repo, + period=args.period, + precomputed_policyengine_harness=args.precomputed_policyengine_harness, + precomputed_policyengine_native_scores=args.precomputed_policyengine_native_scores, + run_id=args.run_id, + allow_overwrite=args.allow_overwrite, + ) + print(result.validation_evidence) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) + + +__all__ = [ + "USStage9ReplayResult", + "main", + "replay_us_stage9_validation_benchmarking", +] diff --git a/src/microplex_us/pipelines/stage_manifest_io.py b/src/microplex_us/pipelines/stage_manifest_io.py index 0fcf891..3ee870a 100644 --- a/src/microplex_us/pipelines/stage_manifest_io.py +++ b/src/microplex_us/pipelines/stage_manifest_io.py @@ -54,6 +54,7 @@ def load_us_stage_manifest(path: str | Path) -> USStageManifest: def write_json_atomically(path: Path, payload: Mapping[str, Any]) -> None: """Write JSON atomically through a sibling temporary file.""" + path.parent.mkdir(parents=True, exist_ok=True) temporary = path.with_suffix(path.suffix + ".tmp") temporary.write_text(json.dumps(payload, indent=2, sort_keys=True)) temporary.replace(path) diff --git a/src/microplex_us/pipelines/stage_run.py b/src/microplex_us/pipelines/stage_run.py index 240b2e8..508dd08 100644 --- a/src/microplex_us/pipelines/stage_run.py +++ b/src/microplex_us/pipelines/stage_run.py @@ -5,6 +5,7 @@ import json from collections.abc import Mapping from dataclasses import asdict, dataclass, field, fields, is_dataclass +from enum import Enum from pathlib import Path from typing import Any, Literal @@ -1419,6 +1420,8 @@ def _serialize_value(value: Any, artifact_root: str | Path | None) -> Any: return value.to_dict(artifact_root) if isinstance(value, Path): return str(value) + if isinstance(value, Enum): + return value.value if isinstance(value, Mapping): return { str(key): _serialize_value(item, artifact_root) diff --git a/src/microplex_us/pipelines/stage_runtime.py b/src/microplex_us/pipelines/stage_runtime.py index 1620abd..a1a20ab 100644 --- a/src/microplex_us/pipelines/stage_runtime.py +++ b/src/microplex_us/pipelines/stage_runtime.py @@ -84,6 +84,7 @@ def start_stage( payload["failedAt"] = None payload["deferredReason"] = None payload["failure"] = None + payload["inputOverrides"] = self._serialized_overrides_for_stage(stage_id) payload["metadata"] = { **dict(payload.get("metadata", {})), **dict(metadata or {}), @@ -186,8 +187,14 @@ def complete_stage(self, outputs: USStageOutputManifest) -> dict[str, Any]: ) self._run_writer.manifest_payload = self.manifest_payload self._run_writer.record_stage(lifecycle_outputs) - self.manifest_payload = self._run_writer.write_manifest_files() - return self._stage_payload(outputs.stage_id) + payload = lifecycle_outputs.to_dict( + self.artifact_root, + input_stage_manifest=input_stage_manifest, + input_overrides=self._input_overrides_for_stage(outputs.stage_id), + ) + self._write_stage_payload(outputs.stage_id, payload) + self._refresh_aggregate() + return payload def fail_stage( self, @@ -330,6 +337,9 @@ def _refresh_aggregate(self) -> None: artifacts.setdefault("stage_manifest", stage_manifest_path.name) artifacts.setdefault("manifest", "manifest.json") self.manifest_payload["artifacts"] = artifacts + _write_json_atomically( + self.artifact_root / "manifest.json", self.manifest_payload + ) write_us_stage_manifest( self.artifact_root, stage_manifest_path, @@ -354,6 +364,7 @@ def _validate_start_transition(self, stage_id: str) -> None: previous_stage_id = US_CANONICAL_STAGE_IDS[stage_index - 1] previous_payload = self._stage_payload(previous_stage_id) if previous_payload.get("lifecycleStatus") == "complete": + self._validate_completed_previous_stage(previous_stage_id, previous_payload) return contract = get_us_pipeline_stage_contract(stage_id) required_previous_inputs = tuple( @@ -371,6 +382,33 @@ def _validate_start_transition(self, stage_id: str) -> None: "unless explicit stage input overrides are enabled" ) + def _validate_completed_previous_stage( + self, + stage_id: str, + payload: Mapping[str, Any], + ) -> None: + if payload.get("contractVersion") != US_STAGE_CONTRACT_VERSION: + raise ValueError( + f"{stage_id} uses stale contract version " + f"{payload.get('contractVersion')!r}; expected " + f"{US_STAGE_CONTRACT_VERSION!r}" + ) + missing = tuple(payload.get("missingRequiredOutputs") or ()) + if missing: + raise ValueError( + f"{stage_id} is complete but missing required outputs: " + f"{', '.join(str(item) for item in missing)}" + ) + outputs = payload.get("outputs") + if not isinstance(outputs, Mapping): + raise ValueError(f"{stage_id} has no serialized outputs") + required_outputs = tuple(payload.get("requiredOutputs") or ()) + for key in required_outputs: + if not _serialized_output_is_available(outputs.get(str(key))): + raise ValueError( + f"{stage_id} is complete but required output {key!r} is unavailable" + ) + def _override_satisfies(self, stage_id: str, key: str) -> bool: if not self.allow_stage_input_overrides: return False @@ -379,6 +417,22 @@ def _override_satisfies(self, stage_id: str, key: str) -> bool: for override in self.stage_input_overrides ) + def _serialized_overrides_for_stage(self, stage_id: str) -> list[dict[str, Any]]: + return [ + override.to_dict(self.artifact_root) + for override in self._input_overrides_for_stage(stage_id) + ] + + def _input_overrides_for_stage( + self, + stage_id: str, + ) -> tuple[USStageInputOverride, ...]: + return tuple( + override + for override in self.stage_input_overrides + if override.stage_id == stage_id + ) + def _validate_output_key(self, stage_id: str, key: str) -> None: contract = get_us_pipeline_stage_contract(stage_id) valid_keys = {resource.key for resource in contract.outputs} @@ -479,6 +533,21 @@ def _runtime_serialize(value: Any, artifact_root: str | Path | None) -> Any: return _serialize_value(value, artifact_root) +def _serialized_output_is_available(value: Any) -> bool: + if value is None: + return False + if isinstance(value, Mapping): + exists = value.get("exists") + if exists is not None: + return bool(exists) + return bool(value) + if isinstance(value, (list, tuple, set, frozenset)): + return bool(value) + if isinstance(value, str): + return bool(value) + return True + + def _event( event: str, timestamp: str, diff --git a/src/microplex_us/pipelines/us.py b/src/microplex_us/pipelines/us.py index 08352f4..ac28036 100644 --- a/src/microplex_us/pipelines/us.py +++ b/src/microplex_us/pipelines/us.py @@ -7,7 +7,7 @@ import time import warnings from collections import Counter -from collections.abc import Iterable +from collections.abc import Iterable, Mapping from dataclasses import asdict, dataclass, field from functools import lru_cache from pathlib import Path @@ -73,9 +73,23 @@ from microplex_us.pipelines.pe_native_optimization import ( optimize_policyengine_us_native_loss_dataset, ) +from microplex_us.pipelines.stage_contracts import ( + get_us_stage_artifact_contract, + resolve_us_stage_artifact_contract_path, +) +from microplex_us.pipelines.stage_manifest_io import write_json_atomically +from microplex_us.pipelines.stage_policyengine_artifacts import ( + write_us_policyengine_entity_stage_artifact, +) from microplex_us.pipelines.stage_run import ( + USArtifactRef, + USCalibrationOutputs, USDiagnosticOutput, + USDonorSynthesisOutputs, + USPolicyEngineEntityOutputs, + USSeedScaffoldOutputs, USSourceLoadingOutputs, + USSourcePlanningOutputs, ) from microplex_us.pipelines.stage_runtime import USStageRuntimeWriter from microplex_us.policyengine.aotc import ( @@ -1980,7 +1994,7 @@ def _source_loading_stage_outputs( "childEntity": relationship.child_entity.value, "parentKey": relationship.parent_key, "childKey": relationship.child_key, - "cardinality": relationship.cardinality, + "cardinality": relationship.cardinality.value, } for relationship in frame.relationships ] @@ -2004,6 +2018,142 @@ def _source_loading_stage_outputs( ) +def _runtime_stage_artifact_path( + writer: USStageRuntimeWriter, + stage_id: str, + artifact_key: str, +) -> Path: + return resolve_us_stage_artifact_contract_path( + writer.artifact_root, + stage_id, + artifact_key, + ) + + +def _runtime_stage_artifact_ref( + writer: USStageRuntimeWriter, + stage_id: str, + artifact_key: str, + *, + assume_exists: bool = False, +) -> USArtifactRef: + contract = get_us_stage_artifact_contract(stage_id, artifact_key) + return USArtifactRef( + key=artifact_key, + path=_runtime_stage_artifact_path(writer, stage_id, artifact_key), + format=contract.format, + required=contract.required, + resume_role=contract.resume_role, + assume_exists=assume_exists, + ) + + +def _runtime_stage_diagnostics( + stage_id: str, + summary: Mapping[str, Any], +) -> dict[str, USDiagnosticOutput]: + return { + "stage_summary": USDiagnosticOutput( + key="stage_summary", + description=f"Runtime diagnostic summary for {stage_id}.", + summary=dict(summary), + ) + } + + +def _write_runtime_dataframe_artifact(path: Path, frame: pd.DataFrame) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + frame.to_parquet(path, index=False) + + +def _runtime_source_plan_payload( + source_inputs: list[USMicroplexSourceInput], + fusion_plan: FusionPlan, + scaffold_input: USMicroplexSourceInput, +) -> dict[str, Any]: + source_names = tuple(input.frame.source.name for input in source_inputs) + return { + "formatVersion": 1, + "stageId": "03_source_planning", + "sourceNames": list(source_names), + "scaffoldSource": scaffold_input.frame.source.name, + "donorSourceNames": [ + source_name + for source_name in source_names + if source_name != scaffold_input.frame.source.name + ], + "fusionPlan": { + "sourceNames": list(fusion_plan.source_names), + }, + "scaffoldSelection": _runtime_scaffold_selection_summary( + source_inputs, + scaffold_input, + ), + } + + +def _runtime_scaffold_selection_summary( + source_inputs: list[USMicroplexSourceInput], + scaffold_input: USMicroplexSourceInput, +) -> dict[str, Any]: + return { + "scaffold_source": scaffold_input.frame.source.name, + "candidate_sources": [ + source_input.frame.source.name for source_input in source_inputs + ], + "household_rows": int(len(scaffold_input.households)), + "person_rows": int(len(scaffold_input.persons)), + } + + +def _runtime_seed_schema_metadata(seed_data: pd.DataFrame) -> dict[str, Any]: + identifier_columns = ( + "household_id", + "person_id", + "tax_unit_id", + "spm_unit_id", + "family_id", + "marital_unit_id", + ) + return { + "rows": int(len(seed_data)), + "columns": int(len(seed_data.columns)), + "identifier_columns": { + column: column in seed_data.columns for column in identifier_columns + }, + "has_weight": "weight" in seed_data.columns, + } + + +def _runtime_targets_payload(targets: USMicroplexTargets) -> dict[str, Any]: + return { + "marginal": targets.marginal, + "continuous": targets.continuous, + } + + +def _runtime_target_ledger(targets: USMicroplexTargets) -> dict[str, Any]: + return { + "n_marginal_groups": len(targets.marginal), + "n_continuous": len(targets.continuous), + "marginal_keys": sorted(targets.marginal.keys()), + "continuous_keys": sorted(targets.continuous.keys()), + } + + +def _runtime_policyengine_table_summary( + tables: PolicyEngineUSEntityTableBundle, +) -> dict[str, Any]: + return { + "households": int(len(tables.households)), + "persons": int(len(tables.persons)), + "tax_units": int(len(tables.tax_units)), + "spm_units": int(len(tables.spm_units)), + "families": int(len(tables.families)), + "marital_units": int(len(tables.marital_units)), + } + + class USMicroplexPipeline: """End-to-end build orchestration for a US microplex dataset.""" @@ -2016,6 +2166,25 @@ def __init__( self.config = config or USMicroplexBuildConfig() self.stage_runtime_writer = stage_runtime_writer + def _runtime_start_stage( + self, + stage_id: str, + *, + metadata: Mapping[str, Any] | None = None, + ) -> None: + if self.stage_runtime_writer is not None: + self.stage_runtime_writer.start_stage(stage_id, metadata=metadata) + + def _runtime_fail_stage( + self, + stage_id: str, + error: BaseException, + *, + metadata: Mapping[str, Any] | None = None, + ) -> None: + if self.stage_runtime_writer is not None: + self.stage_runtime_writer.fail_stage(stage_id, error, metadata=metadata) + def build_from_data_dir(self, data_dir: str | Path) -> USMicroplexBuildResult: from microplex_us.data_sources.cps import ( DEFAULT_CACHE_DIR, @@ -2051,7 +2220,18 @@ def build_from_source_provider( provider: SourceProvider, query: SourceQuery | None = None, ) -> USMicroplexBuildResult: - frame = provider.load_frame(query) + if self.stage_runtime_writer is not None: + self.stage_runtime_writer.start_stage("02_source_loading") + try: + frame = provider.load_frame(query) + except Exception as exc: + if self.stage_runtime_writer is not None: + self.stage_runtime_writer.fail_stage("02_source_loading", exc) + raise + if self.stage_runtime_writer is not None: + self.stage_runtime_writer.complete_stage( + _source_loading_stage_outputs([frame]) + ) return self.build_from_frames([frame]) def build_from_source_providers( @@ -2084,6 +2264,11 @@ def build_from_source_providers( return self.build_from_frames(frames) def build_from_frame(self, frame: ObservationFrame) -> USMicroplexBuildResult: + if self.stage_runtime_writer is not None: + self.stage_runtime_writer.start_stage("02_source_loading") + self.stage_runtime_writer.complete_stage( + _source_loading_stage_outputs([frame]) + ) return self.build_from_frames([frame]) def build_from_frames( @@ -2095,100 +2280,222 @@ def build_from_frames( "USMicroplexPipeline requires at least one observation frame" ) - source_inputs = [self.prepare_source_input(frame) for frame in frames] - fusion_plan = FusionPlan.from_sources([frame.source for frame in frames]) - scaffold_input = self._select_scaffold_source(source_inputs) - seed_data = self.prepare_seed_data_from_source(scaffold_input) - seed_data = self._strip_generated_entity_ids( - seed_data, - scaffold_input=scaffold_input, - ) - scaffold_seed_data = seed_data.copy() - donor_integration = self._integrate_donor_sources( - seed_data, - scaffold_input=scaffold_input, - donor_inputs=[ - source for source in source_inputs if source is not scaffold_input - ], - ) - seed_data = donor_integration["seed_data"] - seed_data = self._apply_dependent_tax_leaf_soft_caps(seed_data) - _emit_us_pipeline_progress( - "US microplex build: seed ready", - scaffold_source=scaffold_input.frame.source.name, - sources=_format_progress_values(fusion_plan.source_names), - rows=int(len(seed_data)), - columns=int(len(seed_data.columns)), - donor_integrated_variables=int( - len(donor_integration["integrated_variables"]) - ), - ) - _emit_us_pipeline_progress( - "US microplex build: targets start", - rows=int(len(seed_data)), - ) - targets = self.build_targets(seed_data) - _emit_us_pipeline_progress( - "US microplex build: targets complete", - marginal_targets=int(len(targets.marginal)), - continuous_targets=int(len(targets.continuous)), - ) - synthesis_variables = self._resolve_synthesis_variables( - scaffold_input, - fusion_plan=fusion_plan, - include_all_observed_targets=len(source_inputs) > 1, - available_columns=set(seed_data.columns), - observed_frame=seed_data, - ) - _emit_us_pipeline_progress( - "US microplex build: synthesis variables ready", - condition_vars=int(len(synthesis_variables.condition_vars)), - target_vars=int(len(synthesis_variables.target_vars)), - ) - _emit_us_pipeline_progress( - "US microplex build: synthesis start", - rows=int(len(seed_data)), - ) - synthetic_data, synthesizer, synthesis_metadata = self.synthesize( - seed_data, - synthesis_variables=synthesis_variables, - ) - _emit_us_pipeline_progress( - "US microplex build: synthesis complete", - rows=int(len(synthetic_data)), - columns=int(len(synthetic_data.columns)), - ) - synthesis_metadata = { - **synthesis_metadata, - "source_names": fusion_plan.source_names, - "condition_vars": list(synthesis_variables.condition_vars), - "target_vars": list(synthesis_variables.target_vars), - "scaffold_source": scaffold_input.frame.source.name, - "donor_integrated_variables": donor_integration["integrated_variables"], - "donor_conditioning_diagnostics": donor_integration.get( - "conditioning_diagnostics", [] - ), - "donor_excluded_variables": list( - self.config.donor_imputer_excluded_variables - ), - "donor_authoritative_override_variables": list( - self.config.donor_imputer_authoritative_override_variables - ), - "state_program_support_proxies": _state_program_support_proxy_summary( - set(seed_data.columns) - ), - } - _emit_us_pipeline_progress( - "US microplex build: support enforcement start", - rows=int(len(synthetic_data)), - ) - synthetic_data = self.ensure_target_support(synthetic_data, seed_data, targets) - _emit_us_pipeline_progress( - "US microplex build: support enforcement complete", - rows=int(len(synthetic_data)), - columns=int(len(synthetic_data.columns)), - ) - if self._has_policyengine_calibration_targets(): + self._runtime_start_stage("03_source_planning") + try: + source_inputs = [self.prepare_source_input(frame) for frame in frames] + fusion_plan = FusionPlan.from_sources([frame.source for frame in frames]) + scaffold_input = self._select_scaffold_source(source_inputs) + if self.stage_runtime_writer is not None: + source_plan_path = _runtime_stage_artifact_path( + self.stage_runtime_writer, + "03_source_planning", + "source_plan", + ) + write_json_atomically( + source_plan_path, + _runtime_source_plan_payload( + source_inputs, + fusion_plan, + scaffold_input, + ), + ) + scaffold_selection = _runtime_scaffold_selection_summary( + source_inputs, + scaffold_input, + ) + self.stage_runtime_writer.complete_stage( + USSourcePlanningOutputs( + source_plan=_runtime_stage_artifact_ref( + self.stage_runtime_writer, + "03_source_planning", + "source_plan", + ), + scaffold_selection=scaffold_selection, + diagnostics=_runtime_stage_diagnostics( + "03_source_planning", + scaffold_selection, + ), + ) + ) + except Exception as exc: + self._runtime_fail_stage("03_source_planning", exc) + raise + + self._runtime_start_stage("04_seed_scaffold") + try: + seed_data = self.prepare_seed_data_from_source(scaffold_input) + seed_data = self._strip_generated_entity_ids( + seed_data, + scaffold_input=scaffold_input, + ) + scaffold_seed_data = seed_data.copy() + if self.stage_runtime_writer is not None: + scaffold_seed_path = _runtime_stage_artifact_path( + self.stage_runtime_writer, + "04_seed_scaffold", + "scaffold_seed_data", + ) + _write_runtime_dataframe_artifact( + scaffold_seed_path, scaffold_seed_data + ) + seed_schema_metadata = _runtime_seed_schema_metadata(scaffold_seed_data) + self.stage_runtime_writer.complete_stage( + USSeedScaffoldOutputs( + scaffold_seed_data=_runtime_stage_artifact_ref( + self.stage_runtime_writer, + "04_seed_scaffold", + "scaffold_seed_data", + ), + seed_schema_metadata=seed_schema_metadata, + diagnostics=_runtime_stage_diagnostics( + "04_seed_scaffold", + { + **seed_schema_metadata, + "scaffold_source": scaffold_input.frame.source.name, + }, + ), + ) + ) + except Exception as exc: + self._runtime_fail_stage("04_seed_scaffold", exc) + raise + + self._runtime_start_stage("05_donor_integration_synthesis") + try: + donor_integration = self._integrate_donor_sources( + seed_data, + scaffold_input=scaffold_input, + donor_inputs=[ + source for source in source_inputs if source is not scaffold_input + ], + ) + seed_data = donor_integration["seed_data"] + seed_data = self._apply_dependent_tax_leaf_soft_caps(seed_data) + _emit_us_pipeline_progress( + "US microplex build: seed ready", + scaffold_source=scaffold_input.frame.source.name, + sources=_format_progress_values(fusion_plan.source_names), + rows=int(len(seed_data)), + columns=int(len(seed_data.columns)), + donor_integrated_variables=int( + len(donor_integration["integrated_variables"]) + ), + ) + _emit_us_pipeline_progress( + "US microplex build: targets start", + rows=int(len(seed_data)), + ) + targets = self.build_targets(seed_data) + _emit_us_pipeline_progress( + "US microplex build: targets complete", + marginal_targets=int(len(targets.marginal)), + continuous_targets=int(len(targets.continuous)), + ) + synthesis_variables = self._resolve_synthesis_variables( + scaffold_input, + fusion_plan=fusion_plan, + include_all_observed_targets=len(source_inputs) > 1, + available_columns=set(seed_data.columns), + observed_frame=seed_data, + ) + _emit_us_pipeline_progress( + "US microplex build: synthesis variables ready", + condition_vars=int(len(synthesis_variables.condition_vars)), + target_vars=int(len(synthesis_variables.target_vars)), + ) + _emit_us_pipeline_progress( + "US microplex build: synthesis start", + rows=int(len(seed_data)), + ) + synthetic_data, synthesizer, synthesis_metadata = self.synthesize( + seed_data, + synthesis_variables=synthesis_variables, + ) + _emit_us_pipeline_progress( + "US microplex build: synthesis complete", + rows=int(len(synthetic_data)), + columns=int(len(synthetic_data.columns)), + ) + synthesis_metadata = { + **synthesis_metadata, + "source_names": fusion_plan.source_names, + "condition_vars": list(synthesis_variables.condition_vars), + "target_vars": list(synthesis_variables.target_vars), + "scaffold_source": scaffold_input.frame.source.name, + "donor_integrated_variables": donor_integration["integrated_variables"], + "donor_conditioning_diagnostics": donor_integration.get( + "conditioning_diagnostics", [] + ), + "donor_excluded_variables": list( + self.config.donor_imputer_excluded_variables + ), + "donor_authoritative_override_variables": list( + self.config.donor_imputer_authoritative_override_variables + ), + "state_program_support_proxies": _state_program_support_proxy_summary( + set(seed_data.columns) + ), + } + _emit_us_pipeline_progress( + "US microplex build: support enforcement start", + rows=int(len(synthetic_data)), + ) + synthetic_data = self.ensure_target_support( + synthetic_data, seed_data, targets + ) + _emit_us_pipeline_progress( + "US microplex build: support enforcement complete", + rows=int(len(synthetic_data)), + columns=int(len(synthetic_data.columns)), + ) + if self.stage_runtime_writer is not None: + seed_data_path = _runtime_stage_artifact_path( + self.stage_runtime_writer, + "05_donor_integration_synthesis", + "seed_data", + ) + synthetic_data_path = _runtime_stage_artifact_path( + self.stage_runtime_writer, + "05_donor_integration_synthesis", + "synthetic_data", + ) + _write_runtime_dataframe_artifact(seed_data_path, seed_data) + _write_runtime_dataframe_artifact(synthetic_data_path, synthetic_data) + self.stage_runtime_writer.complete_stage( + USDonorSynthesisOutputs( + seed_data=_runtime_stage_artifact_ref( + self.stage_runtime_writer, + "05_donor_integration_synthesis", + "seed_data", + ), + synthetic_data=_runtime_stage_artifact_ref( + self.stage_runtime_writer, + "05_donor_integration_synthesis", + "synthetic_data", + ), + synthesis_metadata=synthesis_metadata, + diagnostics=_runtime_stage_diagnostics( + "05_donor_integration_synthesis", + { + "seed_rows": int(len(seed_data)), + "synthetic_rows": int(len(synthetic_data)), + "donor_integrated_variables": len( + donor_integration["integrated_variables"] + ), + "condition_vars": len( + synthesis_variables.condition_vars + ), + "target_vars": len(synthesis_variables.target_vars), + }, + ), + ) + ) + except Exception as exc: + self._runtime_fail_stage("05_donor_integration_synthesis", exc) + raise + + self._runtime_start_stage("06_policyengine_entities") + try: _emit_us_pipeline_progress( "US microplex build: policyengine tables start", rows=int(len(synthetic_data)), @@ -2213,44 +2520,127 @@ def build_from_frames( synthetic_tables, stage="pre_calibration", ) - _emit_us_pipeline_progress( - "US microplex build: policyengine calibration start", - backend=self.config.calibration_backend, - ) - ( - policyengine_tables, - calibrated_data, - calibration_summary, - ) = self.calibrate_policyengine_tables(synthetic_tables) - _emit_us_pipeline_progress( - "US microplex build: policyengine calibration complete", - backend=self.config.calibration_backend, - calibrated_rows=int(len(calibrated_data)), - ) - else: - _emit_us_pipeline_progress( - "US microplex build: calibration start", - backend=self.config.calibration_backend, - rows=int(len(synthetic_data)), - ) - calibrated_data, calibration_summary = self.calibrate( - synthetic_data, targets - ) - _emit_us_pipeline_progress( - "US microplex build: calibration complete", - backend=self.config.calibration_backend, - calibrated_rows=int(len(calibrated_data)), - ) - _emit_us_pipeline_progress( - "US microplex build: policyengine tables start", - rows=int(len(calibrated_data)), - ) - policyengine_tables = self.build_policyengine_entity_tables(calibrated_data) - _emit_us_pipeline_progress( - "US microplex build: policyengine tables complete", - households=int(len(policyengine_tables.households)), - persons=int(len(policyengine_tables.persons)), - ) + if self.stage_runtime_writer is not None: + write_us_policyengine_entity_stage_artifact( + synthetic_tables, + self.stage_runtime_writer.artifact_root, + ) + entity_summary = _runtime_policyengine_table_summary(synthetic_tables) + self.stage_runtime_writer.complete_stage( + USPolicyEngineEntityOutputs( + policyengine_entity_tables=_runtime_stage_artifact_ref( + self.stage_runtime_writer, + "06_policyengine_entities", + "policyengine_entity_tables", + ), + materialized_policyengine_inputs=entity_summary, + diagnostics=_runtime_stage_diagnostics( + "06_policyengine_entities", + entity_summary, + ), + ) + ) + except Exception as exc: + self._runtime_fail_stage("06_policyengine_entities", exc) + raise + + self._runtime_start_stage("07_calibration") + try: + if self._has_policyengine_calibration_targets(): + _emit_us_pipeline_progress( + "US microplex build: policyengine calibration start", + backend=self.config.calibration_backend, + ) + ( + policyengine_tables, + calibrated_data, + calibration_summary, + ) = self.calibrate_policyengine_tables(synthetic_tables) + _emit_us_pipeline_progress( + "US microplex build: policyengine calibration complete", + backend=self.config.calibration_backend, + calibrated_rows=int(len(calibrated_data)), + ) + else: + _emit_us_pipeline_progress( + "US microplex build: calibration start", + backend=self.config.calibration_backend, + rows=int(len(synthetic_data)), + ) + calibrated_data, calibration_summary = self.calibrate( + synthetic_data, targets + ) + _emit_us_pipeline_progress( + "US microplex build: calibration complete", + backend=self.config.calibration_backend, + calibrated_rows=int(len(calibrated_data)), + ) + _emit_us_pipeline_progress( + "US microplex build: policyengine tables start", + rows=int(len(calibrated_data)), + ) + policyengine_tables = self.build_policyengine_entity_tables( + calibrated_data + ) + _emit_us_pipeline_progress( + "US microplex build: policyengine tables complete", + households=int(len(policyengine_tables.households)), + persons=int(len(policyengine_tables.persons)), + ) + if self.stage_runtime_writer is not None: + calibrated_data_path = _runtime_stage_artifact_path( + self.stage_runtime_writer, + "07_calibration", + "calibrated_data", + ) + targets_path = _runtime_stage_artifact_path( + self.stage_runtime_writer, + "07_calibration", + "targets", + ) + calibration_summary_path = _runtime_stage_artifact_path( + self.stage_runtime_writer, + "07_calibration", + "calibration_summary", + ) + _write_runtime_dataframe_artifact( + calibrated_data_path, + calibrated_data, + ) + write_json_atomically(targets_path, _runtime_targets_payload(targets)) + write_json_atomically(calibration_summary_path, calibration_summary) + target_ledger = _runtime_target_ledger(targets) + self.stage_runtime_writer.complete_stage( + USCalibrationOutputs( + calibrated_data=_runtime_stage_artifact_ref( + self.stage_runtime_writer, + "07_calibration", + "calibrated_data", + ), + targets=_runtime_stage_artifact_ref( + self.stage_runtime_writer, + "07_calibration", + "targets", + ), + calibration_summary=_runtime_stage_artifact_ref( + self.stage_runtime_writer, + "07_calibration", + "calibration_summary", + ), + target_ledger=target_ledger, + diagnostics=_runtime_stage_diagnostics( + "07_calibration", + { + "calibrated_rows": int(len(calibrated_data)), + "backend": self.config.calibration_backend, + **target_ledger, + }, + ), + ) + ) + except Exception as exc: + self._runtime_fail_stage("07_calibration", exc) + raise return USMicroplexBuildResult( config=self.config, diff --git a/tests/pipelines/test_stage9_replay.py b/tests/pipelines/test_stage9_replay.py new file mode 100644 index 0000000..2169937 --- /dev/null +++ b/tests/pipelines/test_stage9_replay.py @@ -0,0 +1,101 @@ +"""Tests for safe Stage 9 validation replay.""" + +import json + +import pytest + +from microplex_us.pipelines.stage9_replay import ( + main, + replay_us_stage9_validation_benchmarking, +) + + +def _write_stage8_bundle(tmp_path, *, stage8_status: str = "complete"): + artifact_dir = tmp_path / "bundle" + manifest_dir = artifact_dir / "stage_artifacts" / "manifests" + manifest_dir.mkdir(parents=True) + dataset_path = artifact_dir / "policyengine_us.h5" + dataset_path.write_bytes(b"h5-placeholder") + stage8_manifest = { + "stageId": "08_dataset_assembly", + "lifecycleStatus": stage8_status, + "outputs": { + "policyengine_dataset": { + "path": "policyengine_us.h5", + "exists": True, + } + }, + } + (manifest_dir / "08_dataset_assembly.json").write_text(json.dumps(stage8_manifest)) + manifest = { + "config": {"policyengine_dataset_year": 2024}, + "artifacts": {"policyengine_dataset": "policyengine_us.h5"}, + "stage_output_manifests": { + "08_dataset_assembly": ( + "stage_artifacts/manifests/08_dataset_assembly.json" + ) + }, + } + (artifact_dir / "manifest.json").write_text(json.dumps(manifest)) + return artifact_dir + + +def test_stage9_replay_writes_new_evidence_without_mutating_source_bundle(tmp_path): + artifact_dir = _write_stage8_bundle(tmp_path) + original_manifest = (artifact_dir / "manifest.json").read_text() + + result = replay_us_stage9_validation_benchmarking( + artifact_dir, + run_id="unit-replay", + precomputed_policyengine_native_scores={ + "summary": {"enhanced_cps_native_loss_delta": -0.1} + }, + ) + + assert result.output_dir == ( + artifact_dir + / "stage_artifacts" + / "09_validation_benchmarking" + / "replays" + / "unit-replay" + ) + assert result.validation_evidence.exists() + assert result.policyengine_native_scores is not None + assert result.policyengine_native_scores.exists() + assert (artifact_dir / "manifest.json").read_text() == original_manifest + + evidence = json.loads(result.validation_evidence.read_text()) + assert evidence["stageId"] == "09_validation_benchmarking" + assert evidence["evidence"][0]["key"] == "policyengine_native_scores" + + +def test_stage9_replay_rejects_incomplete_stage8(tmp_path): + artifact_dir = _write_stage8_bundle(tmp_path, stage8_status="running") + + with pytest.raises(ValueError, match="Stage 8 must be complete"): + replay_us_stage9_validation_benchmarking( + artifact_dir, + precomputed_policyengine_native_scores={"summary": {"loss": 1.0}}, + ) + + +def test_stage9_replay_cli_smoke(tmp_path, capsys): + artifact_dir = _write_stage8_bundle(tmp_path) + payload_path = tmp_path / "native_scores.json" + payload_path.write_text(json.dumps({"summary": {"loss": 1.0}})) + + assert ( + main( + [ + str(artifact_dir), + "--run-id", + "cli-replay", + "--precomputed-policyengine-native-scores", + str(payload_path), + ] + ) + == 0 + ) + + output = capsys.readouterr().out.strip() + assert output.endswith("evidence_manifest.json") diff --git a/tests/pipelines/test_stage_runtime.py b/tests/pipelines/test_stage_runtime.py index d54a0cf..6363850 100644 --- a/tests/pipelines/test_stage_runtime.py +++ b/tests/pipelines/test_stage_runtime.py @@ -3,12 +3,15 @@ import json import pytest +from microplex.core import RelationshipCardinality +from microplex_us.pipelines.stage_contracts import US_STAGE_CONTRACT_VERSION from microplex_us.pipelines.stage_run import ( USArtifactRef, USDiagnosticOutput, USRunProfileOutputs, USSourceLoadingOutputs, + USStageInputOverride, ) from microplex_us.pipelines.stage_runtime import USStageRuntimeWriter @@ -85,6 +88,104 @@ def test_runtime_writer_completes_stage_and_exposes_lifecycle(tmp_path): ] +def test_runtime_writer_serializes_enum_outputs(tmp_path): + writer = USStageRuntimeWriter( + tmp_path, + allow_stage_input_overrides=True, + stage_input_overrides=( + USStageInputOverride( + stage_id="02_source_loading", + key="provider_query_plan", + path="overrides/provider_query_plan.json", + ), + ), + ) + writer.start_stage("02_source_loading") + writer.complete_stage( + USSourceLoadingOutputs( + observation_frame_summary={"source_count": 1}, + source_descriptors=("unit",), + source_relationships={ + "unit": [{"cardinality": RelationshipCardinality.ONE_TO_MANY}] + }, + diagnostics=_diagnostics("02_source_loading"), + ) + ) + + stage2 = json.loads( + ( + tmp_path / "stage_artifacts" / "manifests" / "02_source_loading.json" + ).read_text() + ) + + assert stage2["outputs"]["source_relationships"]["unit"][0]["cardinality"] == ( + "one_to_many" + ) + + +def test_runtime_writer_records_overrides_in_running_manifest(tmp_path): + writer = USStageRuntimeWriter( + tmp_path, + allow_stage_input_overrides=True, + stage_input_overrides=( + USStageInputOverride( + stage_id="02_source_loading", + key="provider_query_plan", + path="overrides/provider_query_plan.json", + reason="unit test", + ), + ), + ) + + writer.start_stage("02_source_loading") + stage2 = json.loads( + ( + tmp_path / "stage_artifacts" / "manifests" / "02_source_loading.json" + ).read_text() + ) + + assert stage2["inputOverrides"] == [ + { + "stageId": "02_source_loading", + "key": "provider_query_plan", + "path": "overrides/provider_query_plan.json", + "reason": "unit test", + } + ] + + +def test_runtime_writer_refreshes_root_manifest_on_stage_start(tmp_path): + writer = USStageRuntimeWriter(tmp_path) + + writer.start_stage("01_run_profile") + manifest = json.loads((tmp_path / "manifest.json").read_text()) + + assert manifest["stage_output_manifests"]["01_run_profile"] == ( + "stage_artifacts/manifests/01_run_profile.json" + ) + + +def test_runtime_writer_rejects_stale_complete_previous_manifest(tmp_path): + writer = USStageRuntimeWriter(tmp_path) + stage1_path = tmp_path / "stage_artifacts" / "manifests" / "01_run_profile.json" + stage1_path.parent.mkdir(parents=True) + stage1_path.write_text( + json.dumps( + { + "stageId": "01_run_profile", + "contractVersion": US_STAGE_CONTRACT_VERSION, + "lifecycleStatus": "complete", + "requiredOutputs": ["manifest"], + "missingRequiredOutputs": ["manifest"], + "outputs": {}, + } + ) + ) + + with pytest.raises(ValueError, match="missing required outputs"): + writer.start_stage("02_source_loading") + + def test_runtime_writer_update_writes_json_artifact_reference(tmp_path): writer = USStageRuntimeWriter(tmp_path) payload = writer.record_output( diff --git a/tests/pipelines/test_versioned_artifacts.py b/tests/pipelines/test_versioned_artifacts.py index a107929..ea383d4 100644 --- a/tests/pipelines/test_versioned_artifacts.py +++ b/tests/pipelines/test_versioned_artifacts.py @@ -212,11 +212,14 @@ def test_save_versioned_us_microplex_artifacts_accepts_path_config_values(tmp_pa targets_db = tmp_path / "policy_data.db" baseline_dataset = tmp_path / "baseline.h5" _create_policyengine_targets_db(targets_db) - _write_baseline_dataset(baseline_dataset, _make_result( - targets_db=targets_db, - baseline_dataset=baseline_dataset, - snap_values=(100.0, 50.0), - ).policyengine_tables) + _write_baseline_dataset( + baseline_dataset, + _make_result( + targets_db=targets_db, + baseline_dataset=baseline_dataset, + snap_values=(100.0, 50.0), + ).policyengine_tables, + ) result = _make_result( targets_db=targets_db, @@ -233,7 +236,9 @@ def test_save_versioned_us_microplex_artifacts_accepts_path_config_values(tmp_pa policyengine_target_variables=("snap", "household_count"), ) - artifact_paths = save_versioned_us_microplex_artifacts(result, tmp_path / "artifacts") + artifact_paths = save_versioned_us_microplex_artifacts( + result, tmp_path / "artifacts" + ) manifest = json.loads(artifact_paths.manifest.read_text()) assert manifest["config"]["policyengine_targets_db"] == str(targets_db) @@ -369,7 +374,10 @@ def test_save_versioned_us_microplex_artifacts_uses_explicit_version(tmp_path): paths.output_dir / "stage_artifacts" / "03_source_planning" / "source_plan.json" ) assert paths.policyengine_entity_tables == ( - paths.output_dir / "stage_artifacts" / "06_policyengine_entities" / "metadata.json" + paths.output_dir + / "stage_artifacts" + / "06_policyengine_entities" + / "metadata.json" ) assert paths.calibration_summary == ( paths.output_dir @@ -611,7 +619,9 @@ def test_build_and_save_versioned_us_microplex_from_source_provider(tmp_path): "income": [55_000.0, 0.0, 72_000.0, 40_000.0, 18_000.0, 65_000.0], } ) - provider = _make_source_provider(name="test_cps", households=households, persons=persons) + provider = _make_source_provider( + name="test_cps", households=households, persons=persons + ) saved = build_and_save_versioned_us_microplex_from_source_provider( provider, @@ -721,6 +731,25 @@ def test_build_and_save_versioned_us_microplex_from_source_providers(tmp_path): assert saved.build_result.fusion_plan.source_names == ("test_cps", "test_puf") assert saved.current_entry is not None assert saved.frontier_delta == 0.0 + manifest = json.loads(saved.artifact_paths.manifest.read_text()) + stage_output_manifests = manifest["stage_output_manifests"] + assert tuple(stage_output_manifests) == ( + "01_run_profile", + "02_source_loading", + "03_source_planning", + "04_seed_scaffold", + "05_donor_integration_synthesis", + "06_policyengine_entities", + "07_calibration", + "08_dataset_assembly", + "09_validation_benchmarking", + ) + for stage_id, manifest_path in stage_output_manifests.items(): + stage_manifest = json.loads( + (saved.artifact_paths.output_dir / manifest_path).read_text() + ) + assert stage_manifest["lifecycleStatus"] in {"complete", "deferred"} + assert stage_manifest["events"] def test_build_and_save_versioned_us_microplex_from_data_dir(tmp_path): From 39086407b04d46d4942932bdfd1e035356811035 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 2 Jun 2026 16:01:20 +0200 Subject: [PATCH 3/7] Tighten live stage runtime manifests --- docs/stage-contracts.md | 20 +- src/microplex_us/pipelines/artifacts.py | 804 ++++++++++-------- .../pe_us_data_rebuild_checkpoint.py | 2 +- src/microplex_us/pipelines/stage9_replay.py | 17 +- src/microplex_us/pipelines/stage_artifacts.py | 29 +- src/microplex_us/pipelines/stage_contracts.py | 30 +- src/microplex_us/pipelines/stage_metrics.py | 2 +- .../pipelines/stage_policyengine_artifacts.py | 35 +- src/microplex_us/pipelines/stage_run.py | 35 +- src/microplex_us/pipelines/stage_runtime.py | 38 +- src/microplex_us/pipelines/stage_status.py | 2 +- src/microplex_us/pipelines/us.py | 19 +- src/microplex_us/policyengine/us.py | 13 +- tests/pipelines/test_artifacts.py | 4 + .../test_pe_us_data_rebuild_checkpoint.py | 14 +- tests/pipelines/test_stage9_replay.py | 18 + tests/pipelines/test_stage_artifacts.py | 62 +- tests/pipelines/test_stage_manifest.py | 8 +- tests/pipelines/test_stage_run.py | 6 +- tests/pipelines/test_versioned_artifacts.py | 3 +- 20 files changed, 702 insertions(+), 459 deletions(-) diff --git a/docs/stage-contracts.md b/docs/stage-contracts.md index 0259fb5..dee1685 100644 --- a/docs/stage-contracts.md +++ b/docs/stage-contracts.md @@ -34,6 +34,11 @@ directory before loading sources, writes Stage 1 immediately, writes Stage 2 as source frames load, then finalizes all stage manifests against the completed artifact manifest during save. +Other versioned convenience entry points still reconstruct their stage manifests +from the completed saved artifact manifest. They expose the same saved-run +contract files, but they do not yet produce live per-stage lifecycle updates +while the build is running. + The registry exposes two seam layers: - `inputs` and `outputs` are structured stage resources. They identify artifact, @@ -67,9 +72,11 @@ boundary artifacts where the pipeline already has stable outputs: - Stage 4: `stage_artifacts/04_seed_scaffold/scaffold_seed_data.parquet` - Stage 5: `seed_data.parquet` and `synthetic_data.parquet` -- Stage 6: `stage_artifacts/06_policyengine_entities/` +- Stage 6: `stage_artifacts/06_policyengine_entities/` for the pre-calibration + PolicyEngine entity-table checkpoint - Stage 7: `calibrated_data.parquet`, `targets.json`, and - `stage_artifacts/07_calibration/calibration_summary.json` + `stage_artifacts/07_calibration/calibration_summary.json`, plus the calibrated + PolicyEngine entity-table bundle used by dataset export - Stage 8: `policyengine_us.h5` - Stage 9: validation and benchmark evidence artifacts @@ -77,9 +84,12 @@ The Stage 4 artifact is the scaffold-projected seed before donor integration. It is a diagnostic and manual replay boundary, not an automatic conditional resume point yet. -Conditional execution is intentionally not implemented yet. The stage manifest -and artifacts are designed to make that possible later without changing the -saved-run contract again. +Conditional execution is intentionally narrow in this implementation. Stage 9 +validation and benchmarking can be replayed against an existing complete Stage 8 +dataset through `microplex-us-stage9-replay`. Earlier-stage conditional source +loading, donor integration, synthesis, calibration, and automatic graph +scheduling remain future work. The stage manifest and artifacts are designed to +make those routes possible later without changing the saved-run contract again. ## Artifact inventory and readiness diff --git a/src/microplex_us/pipelines/artifacts.py b/src/microplex_us/pipelines/artifacts.py index 974182a..30415a0 100644 --- a/src/microplex_us/pipelines/artifacts.py +++ b/src/microplex_us/pipelines/artifacts.py @@ -44,6 +44,7 @@ ) from microplex_us.pipelines.stage_manifest import ( write_us_policyengine_entity_stage_artifact, + write_us_validation_evidence_manifest, ) from microplex_us.pipelines.stage_run import ( USArtifactRef, @@ -994,7 +995,7 @@ def save_us_microplex_artifacts( policyengine_entity_tables_path = ( resolve_us_stage_artifact_contract_path( output_dir, - "06_policyengine_entities", + "07_calibration", "policyengine_entity_tables", ) if result.policyengine_tables is not None @@ -1005,6 +1006,15 @@ def save_us_microplex_artifacts( "07_calibration", "calibration_summary", ) + pre_calibration_policyengine_entity_tables_path = ( + resolve_us_stage_artifact_contract_path( + output_dir, + "06_policyengine_entities", + "pre_calibration_policyengine_entity_tables", + ) + if result.policyengine_tables is not None + else None + ) validation_evidence_path = ( resolve_us_stage_artifact_contract_path( output_dir, @@ -1024,198 +1034,195 @@ def save_us_microplex_artifacts( if stage_runtime_writer is not None: stage_runtime_writer.start_stage("08_dataset_assembly") - if result.scaffold_seed_data is not None and scaffold_seed_data_path is not None: + try: + if ( + result.scaffold_seed_data is not None + and scaffold_seed_data_path is not None + ): + _write_parquet_unless_live_artifact_exists( + scaffold_seed_data_path, + result.scaffold_seed_data, + live_artifact=live_artifacts, + ) _write_parquet_unless_live_artifact_exists( - scaffold_seed_data_path, - result.scaffold_seed_data, + seed_data_path, + result.seed_data, + live_artifact=live_artifacts, + ) + _write_parquet_unless_live_artifact_exists( + synthetic_data_path, + result.synthetic_data, + live_artifact=live_artifacts, + ) + _write_parquet_unless_live_artifact_exists( + calibrated_data_path, + result.calibrated_data, + live_artifact=live_artifacts, + ) + _write_json_unless_live_artifact_exists( + targets_path, + { + "marginal": result.targets.marginal, + "continuous": result.targets.continuous, + }, live_artifact=live_artifacts, ) - _write_parquet_unless_live_artifact_exists( - seed_data_path, - result.seed_data, - live_artifact=live_artifacts, - ) - _write_parquet_unless_live_artifact_exists( - synthetic_data_path, - result.synthetic_data, - live_artifact=live_artifacts, - ) - _write_parquet_unless_live_artifact_exists( - calibrated_data_path, - result.calibrated_data, - live_artifact=live_artifacts, - ) - _write_json_unless_live_artifact_exists( - targets_path, - { - "marginal": result.targets.marginal, - "continuous": result.targets.continuous, - }, - live_artifact=live_artifacts, - ) - - if result.synthesizer is not None and synthesizer_path is not None: - result.synthesizer.save(synthesizer_path) - - if not (live_artifacts and source_plan_path.exists()): - _write_us_source_plan_artifact(result, source_plan_path) - if not (live_artifacts and calibration_summary_path.exists()): - _write_json_atomically(calibration_summary_path, result.calibration_summary) - source_weight_diagnostics_payload = _build_source_weight_diagnostics(result) - _write_json_atomically( - source_weight_diagnostics_path, - source_weight_diagnostics_payload, - ) - if result.policyengine_tables is not None and policyengine_dataset_path is not None: - if policyengine_entity_tables_path is not None: - if not (live_artifacts and policyengine_entity_tables_path.exists()): - write_us_policyengine_entity_stage_artifact( - result.policyengine_tables, - output_dir, - ) - period = result.config.policyengine_dataset_year or 2024 - USMicroplexPipeline(result.config).export_policyengine_dataset( - result, - policyengine_dataset_path, - period=period, + if result.synthesizer is not None and synthesizer_path is not None: + result.synthesizer.save(synthesizer_path) + + if not (live_artifacts and source_plan_path.exists()): + _write_us_source_plan_artifact(result, source_plan_path) + if not (live_artifacts and calibration_summary_path.exists()): + _write_json_atomically(calibration_summary_path, result.calibration_summary) + source_weight_diagnostics_payload = _build_source_weight_diagnostics(result) + _write_json_atomically( + source_weight_diagnostics_path, + source_weight_diagnostics_payload, ) - capital_gains_lots_path, capital_gains_lots_summary = ( - _maybe_write_capital_gains_lot_artifact(result, output_dir) - ) - if stage_runtime_writer is not None: - stage_runtime_writer.complete_stage( - USDatasetAssemblyOutputs( - policyengine_dataset=( - _stage_artifact_ref( + if ( + result.policyengine_tables is not None + and policyengine_dataset_path is not None + ): + if policyengine_entity_tables_path is not None: + if not (live_artifacts and policyengine_entity_tables_path.exists()): + write_us_policyengine_entity_stage_artifact( + result.policyengine_tables, output_dir, - "08_dataset_assembly", - "policyengine_dataset", + stage_id="07_calibration", + artifact_key="policyengine_entity_tables", + checkpoint_stage="post_calibration", ) - if policyengine_dataset_path is not None - else None - ), - stage_manifest=_stage_artifact_ref( - output_dir, - "08_dataset_assembly", - "stage_manifest", - assume_exists=True, - ), - data_flow_snapshot=_stage_artifact_ref( - output_dir, - "08_dataset_assembly", - "data_flow_snapshot", - assume_exists=True, - ), - artifact_inventory=_stage_artifact_ref( - output_dir, - "08_dataset_assembly", - "artifact_inventory", - assume_exists=True, - ), - conditional_readiness=_stage_artifact_ref( - output_dir, - "08_dataset_assembly", - "conditional_readiness", - assume_exists=True, - ), - diagnostics=_stage_diagnostics( - "08_dataset_assembly", - { - "policyengine_dataset": ( - str(policyengine_dataset_path.relative_to(output_dir)) - if policyengine_dataset_path is not None - else None - ), - "has_capital_gains_lots": capital_gains_lots_path is not None, - }, - ), + period = result.config.policyengine_dataset_year or 2024 + USMicroplexPipeline(result.config).export_policyengine_dataset( + result, + policyengine_dataset_path, + period=period, ) + capital_gains_lots_path, capital_gains_lots_summary = ( + _maybe_write_capital_gains_lot_artifact(result, output_dir) ) - stage_runtime_writer.start_stage("09_validation_benchmarking") - ( - resolved_target_provider, - resolved_baseline_dataset, - resolved_harness_slices, - resolved_harness_metadata, - ) = _resolve_policyengine_harness_context( - result, - policyengine_comparison_cache=policyengine_comparison_cache, - policyengine_target_provider=policyengine_target_provider, - policyengine_baseline_dataset=policyengine_baseline_dataset, - policyengine_harness_slices=policyengine_harness_slices, - policyengine_harness_metadata=policyengine_harness_metadata, - ) + if stage_runtime_writer is not None: + stage_runtime_writer.complete_stage( + USDatasetAssemblyOutputs( + policyengine_dataset=( + _stage_artifact_ref( + output_dir, + "08_dataset_assembly", + "policyengine_dataset", + ) + if policyengine_dataset_path is not None + else None + ), + stage_manifest=_stage_artifact_ref( + output_dir, + "08_dataset_assembly", + "stage_manifest", + assume_exists=True, + ), + data_flow_snapshot=_stage_artifact_ref( + output_dir, + "08_dataset_assembly", + "data_flow_snapshot", + assume_exists=True, + ), + artifact_inventory=_stage_artifact_ref( + output_dir, + "08_dataset_assembly", + "artifact_inventory", + assume_exists=True, + ), + conditional_readiness=_stage_artifact_ref( + output_dir, + "08_dataset_assembly", + "conditional_readiness", + assume_exists=True, + ), + diagnostics=_stage_diagnostics( + "08_dataset_assembly", + { + "policyengine_dataset": ( + str(policyengine_dataset_path.relative_to(output_dir)) + if policyengine_dataset_path is not None + else None + ), + "has_capital_gains_lots": ( + capital_gains_lots_path is not None + ), + }, + ), + ) + ) + stage_runtime_writer.start_stage("09_validation_benchmarking") + except Exception as exc: + if stage_runtime_writer is not None: + stage_runtime_writer.fail_stage("08_dataset_assembly", exc) + raise - harness_summary = None - native_scores_payload = ( - dict(precomputed_policyengine_native_scores) - if precomputed_policyengine_native_scores is not None - else None - ) - if precomputed_policyengine_harness_payload is not None: - harness_payload = dict(precomputed_policyengine_harness_payload) - policyengine_harness_path = resolve_us_stage_artifact_contract_path( - output_dir, - "09_validation_benchmarking", - "policyengine_harness", - ) - policyengine_harness_path.write_text( - json.dumps(harness_payload, indent=2, sort_keys=True) - ) - harness_summary = harness_payload.get("summary") - elif ( - not defer_policyengine_harness - and result.policyengine_tables is not None - and resolved_target_provider is not None - and resolved_baseline_dataset is not None - and resolved_harness_slices - ): - harness_period = result.config.policyengine_dataset_year or 2024 - harness_run = evaluate_policyengine_us_harness( - result.policyengine_tables, + try: + ( resolved_target_provider, + resolved_baseline_dataset, resolved_harness_slices, - baseline_dataset=str(resolved_baseline_dataset), - dataset_year=harness_period, - simulation_cls=result.config.policyengine_simulation_cls, - candidate_label="microplex", - baseline_label="policyengine_us_data", - metadata=resolved_harness_metadata, - cache=policyengine_comparison_cache, - ) - policyengine_harness_path = resolve_us_stage_artifact_contract_path( - output_dir, - "09_validation_benchmarking", - "policyengine_harness", + resolved_harness_metadata, + ) = _resolve_policyengine_harness_context( + result, + policyengine_comparison_cache=policyengine_comparison_cache, + policyengine_target_provider=policyengine_target_provider, + policyengine_baseline_dataset=policyengine_baseline_dataset, + policyengine_harness_slices=policyengine_harness_slices, + policyengine_harness_metadata=policyengine_harness_metadata, ) - harness_run.save(policyengine_harness_path) - harness_payload = harness_run.to_dict() - harness_summary = harness_payload["summary"] - if native_scores_payload is not None: - policyengine_native_scores_path = resolve_us_stage_artifact_contract_path( - output_dir, - "09_validation_benchmarking", - "policyengine_native_scores", - ) - policyengine_native_scores_path.write_text( - json.dumps(native_scores_payload, indent=2, sort_keys=True) + harness_summary = None + native_scores_payload = ( + dict(precomputed_policyengine_native_scores) + if precomputed_policyengine_native_scores is not None + else None ) - elif ( - not defer_policyengine_native_score - and policyengine_dataset_path is not None - and resolved_baseline_dataset is not None - ): - try: - native_scores_payload = compute_us_pe_native_scores( - candidate_dataset_path=policyengine_dataset_path, - baseline_dataset_path=resolved_baseline_dataset, - period=result.config.policyengine_dataset_year or 2024, - policyengine_us_data_repo=policyengine_us_data_repo, + if precomputed_policyengine_harness_payload is not None: + harness_payload = dict(precomputed_policyengine_harness_payload) + policyengine_harness_path = resolve_us_stage_artifact_contract_path( + output_dir, + "09_validation_benchmarking", + "policyengine_harness", + ) + policyengine_harness_path.write_text( + json.dumps(harness_payload, indent=2, sort_keys=True) ) + harness_summary = harness_payload.get("summary") + elif ( + not defer_policyengine_harness + and result.policyengine_tables is not None + and resolved_target_provider is not None + and resolved_baseline_dataset is not None + and resolved_harness_slices + ): + harness_period = result.config.policyengine_dataset_year or 2024 + harness_run = evaluate_policyengine_us_harness( + result.policyengine_tables, + resolved_target_provider, + resolved_harness_slices, + baseline_dataset=str(resolved_baseline_dataset), + dataset_year=harness_period, + simulation_cls=result.config.policyengine_simulation_cls, + candidate_label="microplex", + baseline_label="policyengine_us_data", + metadata=resolved_harness_metadata, + cache=policyengine_comparison_cache, + ) + policyengine_harness_path = resolve_us_stage_artifact_contract_path( + output_dir, + "09_validation_benchmarking", + "policyengine_harness", + ) + harness_run.save(policyengine_harness_path) + harness_payload = harness_run.to_dict() + harness_summary = harness_payload["summary"] + + if native_scores_payload is not None: policyengine_native_scores_path = resolve_us_stage_artifact_contract_path( output_dir, "09_validation_benchmarking", @@ -1224,218 +1231,265 @@ def save_us_microplex_artifacts( policyengine_native_scores_path.write_text( json.dumps(native_scores_payload, indent=2, sort_keys=True) ) - except Exception: - if require_policyengine_native_score: - raise - - child_tax_unit_agi_drift_path = None - child_tax_unit_agi_drift_summary: dict[str, Any] | None = None - if enable_child_tax_unit_agi_drift: - try: - drift_path = resolve_us_stage_artifact_contract_path( - output_dir, - "09_validation_benchmarking", - "child_tax_unit_agi_drift", + elif ( + not defer_policyengine_native_score + and policyengine_dataset_path is not None + and resolved_baseline_dataset is not None + ): + try: + native_scores_payload = compute_us_pe_native_scores( + candidate_dataset_path=policyengine_dataset_path, + baseline_dataset_path=resolved_baseline_dataset, + period=result.config.policyengine_dataset_year or 2024, + policyengine_us_data_repo=policyengine_us_data_repo, + ) + policyengine_native_scores_path = ( + resolve_us_stage_artifact_contract_path( + output_dir, + "09_validation_benchmarking", + "policyengine_native_scores", + ) + ) + policyengine_native_scores_path.write_text( + json.dumps(native_scores_payload, indent=2, sort_keys=True) + ) + except Exception: + if require_policyengine_native_score: + raise + + child_tax_unit_agi_drift_path = None + child_tax_unit_agi_drift_summary: dict[str, Any] | None = None + if enable_child_tax_unit_agi_drift: + try: + drift_path = resolve_us_stage_artifact_contract_path( + output_dir, + "09_validation_benchmarking", + "child_tax_unit_agi_drift", + ) + variables = ( + child_tax_unit_agi_drift_variables + or DEFAULT_CHILD_TAX_UNIT_AGI_DRIFT_VARIABLES + ) + payload = summarize_child_tax_unit_agi_drift( + output_dir, + variables=variables, + ) + drift_path.write_text(json.dumps(payload, indent=2, sort_keys=True)) + child_tax_unit_agi_drift_path = drift_path + child_tax_unit_agi_drift_summary = ( + _summarize_child_tax_unit_agi_drift_ratios( + payload, + stage="calibrated", + variables=variables, + ) + ) + except Exception as exc: # pragma: no cover - diagnostic best-effort + child_tax_unit_agi_drift_summary = { + "error": f"{type(exc).__name__}: {exc}", + } + + manifest = { + "created_at": datetime.now(UTC).isoformat(), + "config": result.config.to_dict(), + "rows": { + "seed": int(len(result.seed_data)), + "synthetic": int(len(result.synthetic_data)), + "calibrated": int(len(result.calibrated_data)), + }, + "weights": { + "nonzero": result.n_nonzero_weights, + "total": result.total_weighted_population, + }, + "targets": { + "n_marginal_groups": len(result.targets.marginal), + "n_continuous": len(result.targets.continuous), + }, + "synthesis": result.synthesis_metadata, + "calibration": result.calibration_summary, + "artifacts": { + "seed_data": seed_data_path.name, + "scaffold_seed_data": ( + str(scaffold_seed_data_path.relative_to(output_dir)) + if scaffold_seed_data_path is not None + else None + ), + "synthetic_data": synthetic_data_path.name, + "calibrated_data": calibrated_data_path.name, + "targets": targets_path.name, + "synthesizer": synthesizer_path.name if synthesizer_path else None, + "source_plan": str(source_plan_path.relative_to(output_dir)), + "source_weight_diagnostics": source_weight_diagnostics_path.name, + "calibration_summary": str( + calibration_summary_path.relative_to(output_dir) + ), + "pre_calibration_policyengine_entity_tables": ( + str( + pre_calibration_policyengine_entity_tables_path.relative_to( + output_dir + ) + ) + if pre_calibration_policyengine_entity_tables_path is not None + and pre_calibration_policyengine_entity_tables_path.exists() + else None + ), + "policyengine_entity_tables": ( + str(policyengine_entity_tables_path.relative_to(output_dir)) + if policyengine_entity_tables_path is not None + else None + ), + "policyengine_dataset": ( + policyengine_dataset_path.name + if policyengine_dataset_path + else None + ), + "data_flow_snapshot": data_flow_snapshot_path.name, + "stage_manifest": stage_manifest_path.name, + "artifact_inventory": str( + artifact_inventory_path.relative_to(output_dir) + ), + "conditional_readiness": str( + conditional_readiness_path.relative_to(output_dir) + ), + "validation_evidence": ( + str(validation_evidence_path.relative_to(output_dir)) + if validation_evidence_path is not None + else None + ), + "policyengine_harness": ( + policyengine_harness_path.name + if policyengine_harness_path + else None + ), + "policyengine_native_scores": ( + policyengine_native_scores_path.name + if policyengine_native_scores_path is not None + else None + ), + "capital_gains_lots": ( + capital_gains_lots_path.name + if capital_gains_lots_path is not None + else None + ), + }, + } + if harness_summary is not None: + manifest["policyengine_harness"] = harness_summary + if native_scores_payload is not None: + manifest["policyengine_native_scores"] = dict( + native_scores_payload.get("summary", {}) ) - variables = ( - child_tax_unit_agi_drift_variables - or DEFAULT_CHILD_TAX_UNIT_AGI_DRIFT_VARIABLES + if child_tax_unit_agi_drift_path is not None: + manifest["artifacts"]["child_tax_unit_agi_drift"] = ( + child_tax_unit_agi_drift_path.name ) - payload = summarize_child_tax_unit_agi_drift( - output_dir, - variables=variables, + if child_tax_unit_agi_drift_summary is not None: + manifest.setdefault("diagnostics", {})["child_tax_unit_agi_drift"] = ( + child_tax_unit_agi_drift_summary ) - drift_path.write_text(json.dumps(payload, indent=2, sort_keys=True)) - child_tax_unit_agi_drift_path = drift_path - child_tax_unit_agi_drift_summary = ( - _summarize_child_tax_unit_agi_drift_ratios( - payload, - stage="calibrated", - variables=variables, - ) + if capital_gains_lots_summary is not None: + manifest.setdefault("diagnostics", {})["capital_gains_lots"] = ( + capital_gains_lots_summary ) - except Exception as exc: # pragma: no cover - diagnostic best-effort - child_tax_unit_agi_drift_summary = { - "error": f"{type(exc).__name__}: {exc}", - } - - manifest = { - "created_at": datetime.now(UTC).isoformat(), - "config": result.config.to_dict(), - "rows": { - "seed": int(len(result.seed_data)), - "synthetic": int(len(result.synthetic_data)), - "calibrated": int(len(result.calibrated_data)), - }, - "weights": { - "nonzero": result.n_nonzero_weights, - "total": result.total_weighted_population, - }, - "targets": { - "n_marginal_groups": len(result.targets.marginal), - "n_continuous": len(result.targets.continuous), - }, - "synthesis": result.synthesis_metadata, - "calibration": result.calibration_summary, - "artifacts": { - "seed_data": seed_data_path.name, - "scaffold_seed_data": ( - str(scaffold_seed_data_path.relative_to(output_dir)) - if scaffold_seed_data_path is not None - else None - ), - "synthetic_data": synthetic_data_path.name, - "calibrated_data": calibrated_data_path.name, - "targets": targets_path.name, - "synthesizer": synthesizer_path.name if synthesizer_path else None, - "source_plan": str(source_plan_path.relative_to(output_dir)), - "source_weight_diagnostics": source_weight_diagnostics_path.name, - "calibration_summary": str( - calibration_summary_path.relative_to(output_dir) - ), - "policyengine_entity_tables": ( - str(policyengine_entity_tables_path.relative_to(output_dir)) - if policyengine_entity_tables_path is not None - else None - ), - "policyengine_dataset": ( - policyengine_dataset_path.name if policyengine_dataset_path else None - ), - "data_flow_snapshot": data_flow_snapshot_path.name, - "stage_manifest": stage_manifest_path.name, - "artifact_inventory": str(artifact_inventory_path.relative_to(output_dir)), - "conditional_readiness": str( - conditional_readiness_path.relative_to(output_dir) - ), - "validation_evidence": ( - str(validation_evidence_path.relative_to(output_dir)) - if validation_evidence_path is not None - else None - ), - "policyengine_harness": ( - policyengine_harness_path.name if policyengine_harness_path else None - ), - "policyengine_native_scores": ( - policyengine_native_scores_path.name - if policyengine_native_scores_path is not None - else None - ), - "capital_gains_lots": ( - capital_gains_lots_path.name - if capital_gains_lots_path is not None - else None - ), - }, - } - if harness_summary is not None: - manifest["policyengine_harness"] = harness_summary - if native_scores_payload is not None: - manifest["policyengine_native_scores"] = dict( - native_scores_payload.get("summary", {}) - ) - if child_tax_unit_agi_drift_path is not None: - manifest["artifacts"]["child_tax_unit_agi_drift"] = ( - child_tax_unit_agi_drift_path.name - ) - if child_tax_unit_agi_drift_summary is not None: - manifest.setdefault("diagnostics", {})["child_tax_unit_agi_drift"] = ( - child_tax_unit_agi_drift_summary - ) - if capital_gains_lots_summary is not None: - manifest.setdefault("diagnostics", {})["capital_gains_lots"] = ( - capital_gains_lots_summary + manifest.setdefault("diagnostics", {})["source_weight_diagnostics"] = dict( + source_weight_diagnostics_payload.get("summary", {}) ) - manifest.setdefault("diagnostics", {})["source_weight_diagnostics"] = dict( - source_weight_diagnostics_payload.get("summary", {}) - ) - if harness_summary is not None or native_scores_payload is not None: - resolved_run_registry_path = Path( - run_registry_path or output_dir.parent / "run_registry.jsonl" - ) - run_entry = build_us_microplex_run_registry_entry( - artifact_dir=output_dir, - manifest_path=manifest_path, - manifest=manifest, - policyengine_harness_path=policyengine_harness_path, - policyengine_harness_payload=harness_payload, - metadata=dict(run_registry_metadata or {}), - ) - recorded_entry = append_us_microplex_run_registry_entry( - resolved_run_registry_path, - run_entry, - ) - resolved_run_index_path = append_us_microplex_run_index_entry( - run_index_path or output_dir.parent, - recorded_entry, - policyengine_harness_payload=harness_payload, - ) - manifest["run_registry"] = { - "path": str(resolved_run_registry_path), - "artifact_id": recorded_entry.artifact_id, - "improved_candidate_frontier": recorded_entry.improved_candidate_frontier, - "improved_delta_frontier": recorded_entry.improved_delta_frontier, - "improved_composite_frontier": recorded_entry.improved_composite_frontier, - "improved_native_frontier": recorded_entry.improved_native_frontier, - "default_frontier_metric": ( - "enhanced_cps_native_loss_delta" - if native_scores_payload is not None - else "candidate_composite_parity_loss" - ), - } - manifest["run_index"] = { - "path": str(resolved_run_index_path), - "artifact_id": recorded_entry.artifact_id, - } - if stage_runtime_writer is not None: - stage_runtime_writer.manifest_payload = manifest - stage9_summary = _stage9_benchmark_summary(manifest) - if stage9_summary: - stage_runtime_writer.complete_stage( - USValidationBenchmarkingOutputs( - validation_evidence=_stage_artifact_ref( + if harness_summary is not None or native_scores_payload is not None: + resolved_run_registry_path = Path( + run_registry_path or output_dir.parent / "run_registry.jsonl" + ) + run_entry = build_us_microplex_run_registry_entry( + artifact_dir=output_dir, + manifest_path=manifest_path, + manifest=manifest, + policyengine_harness_path=policyengine_harness_path, + policyengine_harness_payload=harness_payload, + metadata=dict(run_registry_metadata or {}), + ) + recorded_entry = append_us_microplex_run_registry_entry( + resolved_run_registry_path, + run_entry, + ) + resolved_run_index_path = append_us_microplex_run_index_entry( + run_index_path or output_dir.parent, + recorded_entry, + policyengine_harness_payload=harness_payload, + ) + manifest["run_registry"] = { + "path": str(resolved_run_registry_path), + "artifact_id": recorded_entry.artifact_id, + "improved_candidate_frontier": recorded_entry.improved_candidate_frontier, + "improved_delta_frontier": recorded_entry.improved_delta_frontier, + "improved_composite_frontier": recorded_entry.improved_composite_frontier, + "improved_native_frontier": recorded_entry.improved_native_frontier, + "default_frontier_metric": ( + "enhanced_cps_native_loss_delta" + if native_scores_payload is not None + else "candidate_composite_parity_loss" + ), + } + manifest["run_index"] = { + "path": str(resolved_run_index_path), + "artifact_id": recorded_entry.artifact_id, + } + if stage_runtime_writer is not None: + stage_runtime_writer.manifest_payload = manifest + stage9_summary = _stage9_benchmark_summary(manifest) + if stage9_summary: + if validation_evidence_path is not None: + write_us_validation_evidence_manifest( output_dir, - "09_validation_benchmarking", - "validation_evidence", - assume_exists=True, - ), - benchmark_summary=stage9_summary, - policyengine_harness=( - _stage_artifact_ref( + validation_evidence_path, + manifest_payload=manifest, + ) + stage_runtime_writer.complete_stage( + USValidationBenchmarkingOutputs( + validation_evidence=_stage_artifact_ref( output_dir, "09_validation_benchmarking", - "policyengine_harness", - ) - if policyengine_harness_path is not None - else None - ), - policyengine_native_scores=( - _stage_artifact_ref( - output_dir, + "validation_evidence", + ), + benchmark_summary=stage9_summary, + policyengine_harness=( + _stage_artifact_ref( + output_dir, + "09_validation_benchmarking", + "policyengine_harness", + ) + if policyengine_harness_path is not None + else None + ), + policyengine_native_scores=( + _stage_artifact_ref( + output_dir, + "09_validation_benchmarking", + "policyengine_native_scores", + ) + if policyengine_native_scores_path is not None + else None + ), + diagnostics=_stage_diagnostics( "09_validation_benchmarking", - "policyengine_native_scores", - ) - if policyengine_native_scores_path is not None - else None - ), - diagnostics=_stage_diagnostics( - "09_validation_benchmarking", - stage9_summary, - ), + stage9_summary, + ), + ) ) - ) + else: + stage_runtime_writer.defer_stage( + "09_validation_benchmarking", + "No validation or benchmark evidence was configured for this run.", + ) + manifest = stage_runtime_writer.finalize_from_artifact_manifest(manifest) else: - stage_runtime_writer.defer_stage( - "09_validation_benchmarking", - "No validation or benchmark evidence was configured for this run.", + manifest = write_us_stage_run_manifests_from_artifact_manifest( + output_dir, + manifest, + allow_stage_input_overrides=allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, ) - manifest = stage_runtime_writer.finalize_from_artifact_manifest(manifest) - else: - manifest = write_us_stage_run_manifests_from_artifact_manifest( - output_dir, - manifest, - allow_stage_input_overrides=allow_stage_input_overrides, - stage_input_overrides=stage_input_overrides, - ) + except Exception as exc: + if stage_runtime_writer is not None: + stage_runtime_writer.fail_stage("09_validation_benchmarking", exc) + raise assert_valid_benchmark_artifact_manifest( manifest, artifact_dir=output_dir, diff --git a/src/microplex_us/pipelines/pe_us_data_rebuild_checkpoint.py b/src/microplex_us/pipelines/pe_us_data_rebuild_checkpoint.py index 8c7e681..ddefa52 100644 --- a/src/microplex_us/pipelines/pe_us_data_rebuild_checkpoint.py +++ b/src/microplex_us/pipelines/pe_us_data_rebuild_checkpoint.py @@ -1218,7 +1218,7 @@ def _load_checkpoint_versioned_artifacts( artifact_root, artifacts, "policyengine_entity_tables", - stage_id="06_policyengine_entities", + stage_id="07_calibration", ), calibration_summary=_resolve_saved_stage_artifact_path( artifact_root, diff --git a/src/microplex_us/pipelines/stage9_replay.py b/src/microplex_us/pipelines/stage9_replay.py index a45a71c..65ac675 100644 --- a/src/microplex_us/pipelines/stage9_replay.py +++ b/src/microplex_us/pipelines/stage9_replay.py @@ -159,6 +159,7 @@ def _validated_stage8_dataset_path( dataset_path = Path(str(dataset_value)) if not dataset_path.is_absolute(): dataset_path = artifact_root / dataset_path + dataset_path = dataset_path.expanduser().resolve() if not dataset_path.exists(): raise FileNotFoundError(f"Stage 8 dataset artifact is missing: {dataset_path}") @@ -181,7 +182,14 @@ def _validated_stage8_dataset_path( serialized_dataset = stage8_outputs.get("policyengine_dataset") if isinstance(serialized_dataset, dict): output_path = serialized_dataset.get("path") - if output_path and Path(str(output_path)).name != dataset_path.name: + if ( + output_path + and _resolve_artifact_path( + artifact_root, + output_path, + ) + != dataset_path + ): raise ValueError( "Stage 8 dataset output does not match the root manifest " "policyengine_dataset artifact" @@ -224,6 +232,13 @@ def _relative_to_root(path: Path, artifact_root: Path) -> str: return str(path) +def _resolve_artifact_path(artifact_root: Path, value: object) -> Path: + path = Path(str(value)) + if not path.is_absolute(): + path = artifact_root / path + return path.expanduser().resolve() + + def main(argv: list[str] | None = None) -> int: parser = argparse.ArgumentParser( description="Rerun Stage 9 validation evidence against a saved Stage 8 dataset." diff --git a/src/microplex_us/pipelines/stage_artifacts.py b/src/microplex_us/pipelines/stage_artifacts.py index 4480a54..0779eac 100644 --- a/src/microplex_us/pipelines/stage_artifacts.py +++ b/src/microplex_us/pipelines/stage_artifacts.py @@ -462,15 +462,28 @@ def load_us_policyengine_entity_stage_artifacts( ) -> USPolicyEngineEntityStageArtifacts: """Load the saved Stage 6 PolicyEngine entity-table bundle.""" - metadata_path = resolve_us_stage_artifact_path_checked( - artifact_dir, - "06_policyengine_entities", - "policyengine_entity_tables", - manifest_payload=manifest_payload, - stage_manifest=stage_manifest, - expected_format="policyengine_entity_bundle", + try: + metadata_path = resolve_us_stage_artifact_path_checked( + artifact_dir, + "06_policyengine_entities", + "pre_calibration_policyengine_entity_tables", + manifest_payload=manifest_payload, + stage_manifest=stage_manifest, + expected_format="policyengine_entity_bundle", + ) + except (KeyError, FileNotFoundError): + metadata_path = resolve_us_stage_artifact_path_checked( + artifact_dir, + "06_policyengine_entities", + "policyengine_entity_tables", + manifest_payload=manifest_payload, + stage_manifest=stage_manifest, + expected_format="policyengine_entity_bundle", + ) + bundle, metadata = load_us_policyengine_entity_stage_artifact( + metadata_path, + expected_stage="post_microsim", ) - bundle, metadata = load_us_policyengine_entity_stage_artifact(metadata_path) return USPolicyEngineEntityStageArtifacts( bundle=bundle, metadata=metadata, diff --git a/src/microplex_us/pipelines/stage_contracts.py b/src/microplex_us/pipelines/stage_contracts.py index e3a337c..535b6b0 100644 --- a/src/microplex_us/pipelines/stage_contracts.py +++ b/src/microplex_us/pipelines/stage_contracts.py @@ -712,8 +712,8 @@ def default_us_pipeline_stage_contracts() -> tuple[USPipelineStageContract, ...] ), outputs=( _artifact_resource( - "policyengine_entity_tables", - "Reloadable PolicyEngine entity-table checkpoint.", + "pre_calibration_policyengine_entity_tables", + "Reloadable pre-calibration PolicyEngine entity-table checkpoint.", stage_id="06_policyengine_entities", ), _stage_output_resource( @@ -724,8 +724,8 @@ def default_us_pipeline_stage_contracts() -> tuple[USPipelineStageContract, ...] ), artifacts=( USStageArtifactContract( - key="policyengine_entity_tables", - description="Reloadable PE entity-table bundle saved as parquet files plus metadata.", + key="pre_calibration_policyengine_entity_tables", + description="Reloadable pre-calibration PE entity-table bundle saved as parquet files plus metadata.", path_hint="stage_artifacts/06_policyengine_entities/metadata.json", required=True, resume_role="manual_resume", @@ -761,8 +761,8 @@ def default_us_pipeline_stage_contracts() -> tuple[USPipelineStageContract, ...] produces=("calibrated tables", "calibration summary", "target ledger"), inputs=( _artifact_resource( - "policyengine_entity_tables", - "PolicyEngine entity-table checkpoint from Stage 6.", + "pre_calibration_policyengine_entity_tables", + "Pre-calibration PolicyEngine entity-table checkpoint from Stage 6.", stage_id="06_policyengine_entities", ), _external_resource( @@ -830,6 +830,11 @@ def default_us_pipeline_stage_contracts() -> tuple[USPipelineStageContract, ...] "Stage-local calibration summary.", stage_id="07_calibration", ), + _artifact_resource( + "policyengine_entity_tables", + "Calibrated PolicyEngine entity-table bundle used for dataset export.", + stage_id="07_calibration", + ), _stage_output_resource( "target_ledger", "Structured target-resolution and calibration target ledger.", @@ -864,6 +869,15 @@ def default_us_pipeline_stage_contracts() -> tuple[USPipelineStageContract, ...] format="json", hash_mode="file_sha256", ), + USStageArtifactContract( + key="policyengine_entity_tables", + description="Calibrated PE entity-table bundle used for dataset export.", + path_hint="stage_artifacts/07_calibration/policyengine_entity_tables/metadata.json", + required=True, + resume_role="post_artifact_evidence", + format="policyengine_entity_bundle", + hash_mode="directory_sha256", + ), ), diagnostics=( "supported and unsupported targets", @@ -905,8 +919,8 @@ def default_us_pipeline_stage_contracts() -> tuple[USPipelineStageContract, ...] ), _artifact_resource( "policyengine_entity_tables", - "PolicyEngine entity-table checkpoint from Stage 6.", - stage_id="06_policyengine_entities", + "Calibrated PolicyEngine entity-table checkpoint from Stage 7.", + stage_id="07_calibration", ), _config_resource( "policyengine_dataset_year", diff --git a/src/microplex_us/pipelines/stage_metrics.py b/src/microplex_us/pipelines/stage_metrics.py index f09c4e8..6668faf 100644 --- a/src/microplex_us/pipelines/stage_metrics.py +++ b/src/microplex_us/pipelines/stage_metrics.py @@ -50,7 +50,7 @@ def stage_metrics(stage_id: str, *, manifest: dict[str, Any]) -> list[USStageMet return [ { "label": "Entity bundle", - "value": artifacts.get("policyengine_entity_tables"), + "value": artifacts.get("pre_calibration_policyengine_entity_tables"), } ] if stage_id == "07_calibration": diff --git a/src/microplex_us/pipelines/stage_policyengine_artifacts.py b/src/microplex_us/pipelines/stage_policyengine_artifacts.py index 5e6345c..ac0a3fa 100644 --- a/src/microplex_us/pipelines/stage_policyengine_artifacts.py +++ b/src/microplex_us/pipelines/stage_policyengine_artifacts.py @@ -6,13 +6,14 @@ from pathlib import Path from typing import Any -from microplex_us.pipelines.stage_manifest_io import write_json_atomically -from microplex_us.pipelines.stage_manifest_types import ( - US_POLICYENGINE_ENTITY_STAGE_ID, - US_STAGE_ARTIFACT_ROOT, +from microplex_us.pipelines.stage_contracts import ( + resolve_us_stage_artifact_contract_path, ) +from microplex_us.pipelines.stage_manifest_io import write_json_atomically +from microplex_us.pipelines.stage_manifest_types import US_POLICYENGINE_ENTITY_STAGE_ID from microplex_us.policyengine.us import ( PolicyEngineUSEntityTableBundle, + USPipelineCheckpointStage, load_us_pipeline_checkpoint, save_us_pipeline_checkpoint, ) @@ -21,31 +22,43 @@ def write_us_policyengine_entity_stage_artifact( bundle: PolicyEngineUSEntityTableBundle, artifact_root: str | Path, + *, + stage_id: str = US_POLICYENGINE_ENTITY_STAGE_ID, + artifact_key: str = "pre_calibration_policyengine_entity_tables", + checkpoint_stage: USPipelineCheckpointStage = "post_microsim", ) -> Path: - """Persist a Stage 6 PE entity-table checkpoint under a saved-run root.""" + """Persist a PE entity-table checkpoint under a saved-run root.""" + metadata_path = resolve_us_stage_artifact_contract_path( + artifact_root, + stage_id, + artifact_key, + ) stage_dir = save_us_pipeline_checkpoint( bundle, - Path(artifact_root) / US_STAGE_ARTIFACT_ROOT / US_POLICYENGINE_ENTITY_STAGE_ID, - stage="post_microsim", + metadata_path.parent, + stage=checkpoint_stage, ) - metadata_path = stage_dir / "metadata.json" + metadata_path = stage_dir / metadata_path.name metadata = json.loads(metadata_path.read_text()) - metadata["stageId"] = US_POLICYENGINE_ENTITY_STAGE_ID + metadata["stageId"] = stage_id + metadata["artifactKey"] = artifact_key write_json_atomically(metadata_path, metadata) return metadata_path def load_us_policyengine_entity_stage_artifact( path: str | Path, + *, + expected_stage: USPipelineCheckpointStage | None = "post_microsim", ) -> tuple[PolicyEngineUSEntityTableBundle, dict[str, Any]]: - """Load a Stage 6 PE entity-table bundle artifact.""" + """Load a PE entity-table bundle artifact.""" input_path = Path(path) checkpoint_dir = input_path if input_path.is_dir() else input_path.parent bundle, metadata = load_us_pipeline_checkpoint( checkpoint_dir, - expected_stage="post_microsim", + expected_stage=expected_stage, ) return bundle, metadata diff --git a/src/microplex_us/pipelines/stage_run.py b/src/microplex_us/pipelines/stage_run.py index 508dd08..064b7f0 100644 --- a/src/microplex_us/pipelines/stage_run.py +++ b/src/microplex_us/pipelines/stage_run.py @@ -352,7 +352,7 @@ class USDonorSynthesisOutputs(USStageOutputManifest): @dataclass(frozen=True) class USPolicyEngineEntityOutputs(USStageOutputManifest): stage_id: str = field(default="06_policyengine_entities", init=False) - policyengine_entity_tables: USArtifactRef | None = None + pre_calibration_policyengine_entity_tables: USArtifactRef | None = None materialized_policyengine_inputs: Mapping[str, Any] = field(default_factory=dict) @@ -362,6 +362,7 @@ class USCalibrationOutputs(USStageOutputManifest): calibrated_data: USArtifactRef | None = None targets: USArtifactRef | None = None calibration_summary: USArtifactRef | None = None + policyengine_entity_tables: USArtifactRef | None = None target_ledger: Mapping[str, Any] = field(default_factory=dict) @@ -944,18 +945,23 @@ def build_us_stage_output_manifests_from_artifact_manifest( ), ), USPolicyEngineEntityOutputs( - policyengine_entity_tables=_artifact_ref( + pre_calibration_policyengine_entity_tables=_artifact_ref( root, artifacts, - "policyengine_entity_tables", + "pre_calibration_policyengine_entity_tables", "06_policyengine_entities", ), materialized_policyengine_inputs=_policyengine_entity_metadata_summary( root, artifacts, + artifact_key="pre_calibration_policyengine_entity_tables", ), diagnostics=_diagnostics("06_policyengine_entities", manifest), - complete=_artifact_exists(root, artifacts, "policyengine_entity_tables"), + complete=_artifact_exists( + root, + artifacts, + "pre_calibration_policyengine_entity_tables", + ), ), USCalibrationOutputs( calibrated_data=_artifact_ref( @@ -969,11 +975,22 @@ def build_us_stage_output_manifests_from_artifact_manifest( "07_calibration", category="diagnostic", ), + policyengine_entity_tables=_artifact_ref( + root, + artifacts, + "policyengine_entity_tables", + "07_calibration", + ), target_ledger={"target_count": manifest.get("targets", {})}, diagnostics=_diagnostics("07_calibration", manifest), complete=all( _artifact_exists(root, artifacts, key) - for key in ("calibrated_data", "targets", "calibration_summary") + for key in ( + "calibrated_data", + "targets", + "calibration_summary", + "policyengine_entity_tables", + ) ), ), USDatasetAssemblyOutputs( @@ -1230,8 +1247,10 @@ def _path_for_manifest(path: Path, artifact_root: Path) -> str: def _policyengine_entity_metadata_summary( artifact_root: Path, artifacts: Mapping[str, Any], + *, + artifact_key: str = "policyengine_entity_tables", ) -> dict[str, Any]: - declared = artifacts.get("policyengine_entity_tables") + declared = artifacts.get(artifact_key) if declared is None: return {} path = Path(str(declared)) @@ -1325,7 +1344,9 @@ def _default_stage_diagnostic_summary( "backend": synthesis.get("backend"), } if stage_id == "06_policyengine_entities": - return {"entity_tables": artifacts.get("policyengine_entity_tables")} + return { + "entity_tables": artifacts.get("pre_calibration_policyengine_entity_tables") + } if stage_id == "07_calibration": return { "calibrated_rows": rows.get("calibrated"), diff --git a/src/microplex_us/pipelines/stage_runtime.py b/src/microplex_us/pipelines/stage_runtime.py index a1a20ab..59392de 100644 --- a/src/microplex_us/pipelines/stage_runtime.py +++ b/src/microplex_us/pipelines/stage_runtime.py @@ -193,7 +193,10 @@ def complete_stage(self, outputs: USStageOutputManifest) -> dict[str, Any]: input_overrides=self._input_overrides_for_stage(outputs.stage_id), ) self._write_stage_payload(outputs.stage_id, payload) - self._refresh_aggregate() + if outputs.stage_id == "08_dataset_assembly": + self.manifest_payload = self._run_writer.write_manifest_files() + else: + self._refresh_aggregate() return payload def fail_stage( @@ -364,7 +367,8 @@ def _validate_start_transition(self, stage_id: str) -> None: previous_stage_id = US_CANONICAL_STAGE_IDS[stage_index - 1] previous_payload = self._stage_payload(previous_stage_id) if previous_payload.get("lifecycleStatus") == "complete": - self._validate_completed_previous_stage(previous_stage_id, previous_payload) + self._validate_completed_stage(previous_stage_id, previous_payload) + self._validate_required_start_inputs(stage_id) return contract = get_us_pipeline_stage_contract(stage_id) required_previous_inputs = tuple( @@ -376,13 +380,41 @@ def _validate_start_transition(self, stage_id: str) -> None: self._override_satisfies(stage_id, resource.key) for resource in required_previous_inputs ): + self._validate_required_start_inputs(stage_id) return raise ValueError( f"{stage_id} requires {previous_stage_id} to be complete before start, " "unless explicit stage input overrides are enabled" ) - def _validate_completed_previous_stage( + def _validate_required_start_inputs(self, stage_id: str) -> None: + contract = get_us_pipeline_stage_contract(stage_id) + missing_inputs: list[str] = [] + for resource in contract.inputs: + if ( + not resource.required + or resource.stage_id is None + or resource.kind not in {"artifact", "manifest", "stage_output"} + or self._override_satisfies(stage_id, resource.key) + ): + continue + payload = self._stage_payload(resource.stage_id) + if payload.get("lifecycleStatus") != "complete": + missing_inputs.append(f"{resource.stage_id}.{resource.key}") + continue + self._validate_completed_stage(resource.stage_id, payload) + outputs = payload.get("outputs") + if not isinstance(outputs, Mapping) or not _serialized_output_is_available( + outputs.get(resource.key) + ): + missing_inputs.append(f"{resource.stage_id}.{resource.key}") + if missing_inputs: + raise ValueError( + f"{stage_id} is missing required stage input(s) before start: " + f"{', '.join(missing_inputs)}" + ) + + def _validate_completed_stage( self, stage_id: str, payload: Mapping[str, Any], diff --git a/src/microplex_us/pipelines/stage_status.py b/src/microplex_us/pipelines/stage_status.py index 410f57d..2605ac4 100644 --- a/src/microplex_us/pipelines/stage_status.py +++ b/src/microplex_us/pipelines/stage_status.py @@ -61,7 +61,7 @@ def stage_status( if stage_id == "06_policyengine_entities": if artifact_missing(artifacts): return "incomplete" - if artifact_exists(artifacts, "policyengine_entity_tables"): + if artifact_exists(artifacts, "pre_calibration_policyengine_entity_tables"): return "ready" if manifest_artifact_exists( manifest, diff --git a/src/microplex_us/pipelines/us.py b/src/microplex_us/pipelines/us.py index ac28036..8a87cda 100644 --- a/src/microplex_us/pipelines/us.py +++ b/src/microplex_us/pipelines/us.py @@ -2524,14 +2524,17 @@ def build_from_frames( write_us_policyengine_entity_stage_artifact( synthetic_tables, self.stage_runtime_writer.artifact_root, + stage_id="06_policyengine_entities", + artifact_key="pre_calibration_policyengine_entity_tables", + checkpoint_stage="post_microsim", ) entity_summary = _runtime_policyengine_table_summary(synthetic_tables) self.stage_runtime_writer.complete_stage( USPolicyEngineEntityOutputs( - policyengine_entity_tables=_runtime_stage_artifact_ref( + pre_calibration_policyengine_entity_tables=_runtime_stage_artifact_ref( self.stage_runtime_writer, "06_policyengine_entities", - "policyengine_entity_tables", + "pre_calibration_policyengine_entity_tables", ), materialized_policyengine_inputs=entity_summary, diagnostics=_runtime_stage_diagnostics( @@ -2588,6 +2591,13 @@ def build_from_frames( persons=int(len(policyengine_tables.persons)), ) if self.stage_runtime_writer is not None: + write_us_policyengine_entity_stage_artifact( + policyengine_tables, + self.stage_runtime_writer.artifact_root, + stage_id="07_calibration", + artifact_key="policyengine_entity_tables", + checkpoint_stage="post_calibration", + ) calibrated_data_path = _runtime_stage_artifact_path( self.stage_runtime_writer, "07_calibration", @@ -2627,6 +2637,11 @@ def build_from_frames( "07_calibration", "calibration_summary", ), + policyengine_entity_tables=_runtime_stage_artifact_ref( + self.stage_runtime_writer, + "07_calibration", + "policyengine_entity_tables", + ), target_ledger=target_ledger, diagnostics=_runtime_stage_diagnostics( "07_calibration", diff --git a/src/microplex_us/policyengine/us.py b/src/microplex_us/policyengine/us.py index 92519b1..52961d0 100644 --- a/src/microplex_us/policyengine/us.py +++ b/src/microplex_us/policyengine/us.py @@ -161,8 +161,14 @@ def table_for(self, entity: EntityType) -> pd.DataFrame: "marital_units", ) +USPipelineCheckpointStage = Literal[ + "post_imputation", + "post_microsim", + "post_calibration", +] + _ALLOWED_CHECKPOINT_STAGES: frozenset[str] = frozenset( - {"post_imputation", "post_microsim"} + {"post_imputation", "post_microsim", "post_calibration"} ) @@ -170,7 +176,7 @@ def save_us_pipeline_checkpoint( bundle: PolicyEngineUSEntityTableBundle, path: str | Path, *, - stage: Literal["post_imputation", "post_microsim"], + stage: USPipelineCheckpointStage, ) -> Path: """Persist a pipeline-stage bundle to ``path`` as parquet + metadata. @@ -183,6 +189,7 @@ def save_us_pipeline_checkpoint( microsim + calibration. * ``"post_microsim"`` — after microsim materialization, before the calibration fit loop. Resuming from here reruns only calibration. + * ``"post_calibration"`` — after calibration, ready for dataset export. """ import json import shutil @@ -216,7 +223,7 @@ def save_us_pipeline_checkpoint( def load_us_pipeline_checkpoint( path: str | Path, *, - expected_stage: Literal["post_imputation", "post_microsim"] | None = None, + expected_stage: USPipelineCheckpointStage | None = None, ) -> tuple[PolicyEngineUSEntityTableBundle, dict[str, Any]]: """Load a pipeline-stage bundle previously saved by ``save_us_pipeline_checkpoint``. diff --git a/tests/pipelines/test_artifacts.py b/tests/pipelines/test_artifacts.py index 0bf1c63..74c881b 100644 --- a/tests/pipelines/test_artifacts.py +++ b/tests/pipelines/test_artifacts.py @@ -404,6 +404,10 @@ def test_writes_expected_files(self, tmp_path): ) assert ( manifest["artifacts"]["policyengine_entity_tables"] + == "stage_artifacts/07_calibration/policyengine_entity_tables/metadata.json" + ) + assert ( + manifest["artifacts"]["pre_calibration_policyengine_entity_tables"] == "stage_artifacts/06_policyengine_entities/metadata.json" ) assert ( diff --git a/tests/pipelines/test_pe_us_data_rebuild_checkpoint.py b/tests/pipelines/test_pe_us_data_rebuild_checkpoint.py index 6557317..43b8ad7 100644 --- a/tests/pipelines/test_pe_us_data_rebuild_checkpoint.py +++ b/tests/pipelines/test_pe_us_data_rebuild_checkpoint.py @@ -1177,6 +1177,10 @@ def test_load_checkpoint_versioned_artifacts_hydrates_stage_sidecar_paths( stage_artifacts / "04_seed_scaffold" / "scaffold_seed_data.parquet", stage_artifacts / "06_policyengine_entities" / "metadata.json", stage_artifacts / "07_calibration" / "calibration_summary.json", + stage_artifacts + / "07_calibration" + / "policyengine_entity_tables" + / "metadata.json", stage_artifacts / "09_validation_benchmarking" / "evidence_manifest.json", stage_artifacts / "artifact_inventory.json", stage_artifacts / "conditional_readiness.json", @@ -1198,9 +1202,12 @@ def test_load_checkpoint_versioned_artifacts_hydrates_stage_sidecar_paths( "scaffold_seed_data": ( "stage_artifacts/04_seed_scaffold/scaffold_seed_data.parquet" ), - "policyengine_entity_tables": ( + "pre_calibration_policyengine_entity_tables": ( "stage_artifacts/06_policyengine_entities/metadata.json" ), + "policyengine_entity_tables": ( + "stage_artifacts/07_calibration/policyengine_entity_tables/metadata.json" + ), "calibration_summary": ( "stage_artifacts/07_calibration/calibration_summary.json" ), @@ -1233,7 +1240,10 @@ def test_load_checkpoint_versioned_artifacts_hydrates_stage_sidecar_paths( stage_artifacts / "04_seed_scaffold" / "scaffold_seed_data.parquet" ) assert paths.policyengine_entity_tables == ( - stage_artifacts / "06_policyengine_entities" / "metadata.json" + stage_artifacts + / "07_calibration" + / "policyengine_entity_tables" + / "metadata.json" ) assert paths.calibration_summary == ( stage_artifacts / "07_calibration" / "calibration_summary.json" diff --git a/tests/pipelines/test_stage9_replay.py b/tests/pipelines/test_stage9_replay.py index 2169937..a82cf36 100644 --- a/tests/pipelines/test_stage9_replay.py +++ b/tests/pipelines/test_stage9_replay.py @@ -79,6 +79,24 @@ def test_stage9_replay_rejects_incomplete_stage8(tmp_path): ) +def test_stage9_replay_rejects_stage8_dataset_path_mismatch(tmp_path): + artifact_dir = _write_stage8_bundle(tmp_path) + stage8_manifest_path = ( + artifact_dir / "stage_artifacts" / "manifests" / "08_dataset_assembly.json" + ) + stage8_manifest = json.loads(stage8_manifest_path.read_text()) + stage8_manifest["outputs"]["policyengine_dataset"]["path"] = ( + "other/policyengine_us.h5" + ) + stage8_manifest_path.write_text(json.dumps(stage8_manifest)) + + with pytest.raises(ValueError, match="does not match"): + replay_us_stage9_validation_benchmarking( + artifact_dir, + precomputed_policyengine_native_scores={"summary": {"loss": 1.0}}, + ) + + def test_stage9_replay_cli_smoke(tmp_path, capsys): artifact_dir = _write_stage8_bundle(tmp_path) payload_path = tmp_path / "native_scores.json" diff --git a/tests/pipelines/test_stage_artifacts.py b/tests/pipelines/test_stage_artifacts.py index 567243b..fba96ed 100644 --- a/tests/pipelines/test_stage_artifacts.py +++ b/tests/pipelines/test_stage_artifacts.py @@ -28,7 +28,9 @@ def test_build_us_stage_artifact_inventory_hashes_files_and_directories(tmp_path): (tmp_path / "seed_data.parquet").write_text("seed") (tmp_path / "synthetic_data.parquet").write_text("synthetic") - source_plan = tmp_path / "stage_artifacts" / "03_source_planning" / "source_plan.json" + source_plan = ( + tmp_path / "stage_artifacts" / "03_source_planning" / "source_plan.json" + ) source_plan.parent.mkdir(parents=True) source_plan.write_text("{}") entity_dir = tmp_path / "stage_artifacts" / "06_policyengine_entities" @@ -44,7 +46,7 @@ def test_build_us_stage_artifact_inventory_hashes_files_and_directories(tmp_path "seed_data": "seed_data.parquet", "synthetic_data": "synthetic_data.parquet", "source_plan": "stage_artifacts/03_source_planning/source_plan.json", - "policyengine_entity_tables": ( + "pre_calibration_policyengine_entity_tables": ( "stage_artifacts/06_policyengine_entities/metadata.json" ), }, @@ -57,22 +59,23 @@ def test_build_us_stage_artifact_inventory_hashes_files_and_directories(tmp_path ) records = { - (record["stageId"], record["key"]): record - for record in inventory["artifacts"] + (record["stageId"], record["key"]): record for record in inventory["artifacts"] } - assert records[("05_donor_integration_synthesis", "synthetic_data")][ - "classification" - ] == "manual_replay" - assert records[("05_donor_integration_synthesis", "synthetic_data")][ - "hashStatus" - ] == "hashed" - assert records[("05_donor_integration_synthesis", "synthetic_data")][ - "contentHash" - ] + assert ( + records[("05_donor_integration_synthesis", "synthetic_data")]["classification"] + == "manual_replay" + ) + assert ( + records[("05_donor_integration_synthesis", "synthetic_data")]["hashStatus"] + == "hashed" + ) + assert records[("05_donor_integration_synthesis", "synthetic_data")]["contentHash"] assert records[("03_source_planning", "source_plan")]["classification"] == ( "diagnostic_only" ) - entity_record = records[("06_policyengine_entities", "policyengine_entity_tables")] + entity_record = records[ + ("06_policyengine_entities", "pre_calibration_policyengine_entity_tables") + ] assert entity_record["classification"] == "manual_resume" assert entity_record["fileCount"] == 2 assert entity_record["hashStatus"] == "hashed" @@ -95,15 +98,16 @@ def test_build_us_stage_artifact_inventory_classifies_missing_and_contract_only( inventory = build_us_stage_artifact_inventory(tmp_path, manifest_payload=manifest) records = { - (record["stageId"], record["key"]): record - for record in inventory["artifacts"] + (record["stageId"], record["key"]): record for record in inventory["artifacts"] } - assert records[("05_donor_integration_synthesis", "synthetic_data")][ - "classification" - ] == "missing_required" - assert records[("05_donor_integration_synthesis", "synthesizer")][ - "classification" - ] == "contract_only" + assert ( + records[("05_donor_integration_synthesis", "synthetic_data")]["classification"] + == "missing_required" + ) + assert ( + records[("05_donor_integration_synthesis", "synthesizer")]["classification"] + == "contract_only" + ) def test_build_us_stage_artifact_inventory_skips_large_file_hashes(tmp_path): @@ -123,9 +127,7 @@ def test_build_us_stage_artifact_inventory_skips_large_file_hashes(tmp_path): ) record = next( - record - for record in inventory["artifacts"] - if record["key"] == "synthetic_data" + record for record in inventory["artifacts"] if record["key"] == "synthetic_data" ) assert record["hashStatus"] == "too_large" assert record["contentHash"] is None @@ -186,7 +188,9 @@ def test_load_us_candidate_stage_artifacts_reads_stage5_boundary(tmp_path): pd.testing.assert_frame_equal(loaded.seed_data, seed) pd.testing.assert_frame_equal(loaded.synthetic_data, synthetic) - assert loaded.artifact_paths["synthetic_data"] == tmp_path / "synthetic_data.parquet" + assert ( + loaded.artifact_paths["synthetic_data"] == tmp_path / "synthetic_data.parquet" + ) def test_load_us_seed_scaffold_stage_artifacts_reads_stage4_boundary(tmp_path): @@ -274,7 +278,7 @@ def test_load_us_policyengine_entity_stage_artifacts_reads_checkpoint(tmp_path): "synthesis": {"source_names": ["source"], "scaffold_source": "source"}, "calibration": {}, "artifacts": { - "policyengine_entity_tables": ( + "pre_calibration_policyengine_entity_tables": ( "stage_artifacts/06_policyengine_entities/metadata.json" ), }, @@ -350,7 +354,9 @@ def test_load_us_dataset_assembly_artifacts_resolves_stage8_paths(tmp_path): assert loaded.stage_manifest == tmp_path / "stage_manifest.json" assert loaded.data_flow_snapshot == tmp_path / "data_flow_snapshot.json" assert loaded.artifact_inventory == stage_artifacts / "artifact_inventory.json" - assert loaded.conditional_readiness == stage_artifacts / "conditional_readiness.json" + assert ( + loaded.conditional_readiness == stage_artifacts / "conditional_readiness.json" + ) def test_stage_artifact_checked_resolver_enforces_format_and_existence(tmp_path): diff --git a/tests/pipelines/test_stage_manifest.py b/tests/pipelines/test_stage_manifest.py index bf0968d..a8beea3 100644 --- a/tests/pipelines/test_stage_manifest.py +++ b/tests/pipelines/test_stage_manifest.py @@ -38,6 +38,9 @@ def test_build_us_stage_manifest_reports_nine_stage_statuses(tmp_path): calibration_path = tmp_path / "stage_artifacts" / "07_calibration" calibration_path.mkdir(parents=True) (calibration_path / "calibration_summary.json").write_text("{}") + final_entity_path = calibration_path / "policyengine_entity_tables" + final_entity_path.mkdir(parents=True) + (final_entity_path / "metadata.json").write_text("{}") (tmp_path / "stage_manifest.json").write_text("{}") (tmp_path / "data_flow_snapshot.json").write_text("{}") (tmp_path / "stage_artifacts" / "artifact_inventory.json").write_text("{}") @@ -62,12 +65,15 @@ def test_build_us_stage_manifest_reports_nine_stage_statuses(tmp_path): "calibrated_data": "calibrated_data.parquet", "targets": "targets.json", "source_plan": "stage_artifacts/03_source_planning/source_plan.json", - "policyengine_entity_tables": ( + "pre_calibration_policyengine_entity_tables": ( "stage_artifacts/06_policyengine_entities/metadata.json" ), "calibration_summary": ( "stage_artifacts/07_calibration/calibration_summary.json" ), + "policyengine_entity_tables": ( + "stage_artifacts/07_calibration/policyengine_entity_tables/metadata.json" + ), "policyengine_dataset": "policyengine_us.h5", "stage_manifest": "stage_manifest.json", "data_flow_snapshot": "data_flow_snapshot.json", diff --git a/tests/pipelines/test_stage_run.py b/tests/pipelines/test_stage_run.py index b2a6b58..14c69be 100644 --- a/tests/pipelines/test_stage_run.py +++ b/tests/pipelines/test_stage_run.py @@ -488,6 +488,7 @@ def _write_artifact_bundle_files(root): "stage_artifacts/04_seed_scaffold/scaffold_seed_data.parquet", "stage_artifacts/06_policyengine_entities/metadata.json", "stage_artifacts/07_calibration/calibration_summary.json", + "stage_artifacts/07_calibration/policyengine_entity_tables/metadata.json", ): path = root / relative path.parent.mkdir(parents=True, exist_ok=True) @@ -561,9 +562,12 @@ def _artifact_manifest(): "scaffold_seed_data": ( "stage_artifacts/04_seed_scaffold/scaffold_seed_data.parquet" ), - "policyengine_entity_tables": ( + "pre_calibration_policyengine_entity_tables": ( "stage_artifacts/06_policyengine_entities/metadata.json" ), + "policyengine_entity_tables": ( + "stage_artifacts/07_calibration/policyengine_entity_tables/metadata.json" + ), "calibration_summary": ( "stage_artifacts/07_calibration/calibration_summary.json" ), diff --git a/tests/pipelines/test_versioned_artifacts.py b/tests/pipelines/test_versioned_artifacts.py index ea383d4..c89cc9c 100644 --- a/tests/pipelines/test_versioned_artifacts.py +++ b/tests/pipelines/test_versioned_artifacts.py @@ -376,7 +376,8 @@ def test_save_versioned_us_microplex_artifacts_uses_explicit_version(tmp_path): assert paths.policyengine_entity_tables == ( paths.output_dir / "stage_artifacts" - / "06_policyengine_entities" + / "07_calibration" + / "policyengine_entity_tables" / "metadata.json" ) assert paths.calibration_summary == ( From 9e601a529bcd27ac44c1486de74fb12789eb92fa Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 2 Jun 2026 16:26:13 +0200 Subject: [PATCH 4/7] Attribute Stage 9 start failures correctly --- src/microplex_us/pipelines/artifacts.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/microplex_us/pipelines/artifacts.py b/src/microplex_us/pipelines/artifacts.py index 30415a0..2b8fdf3 100644 --- a/src/microplex_us/pipelines/artifacts.py +++ b/src/microplex_us/pipelines/artifacts.py @@ -1155,13 +1155,15 @@ def save_us_microplex_artifacts( ), ) ) - stage_runtime_writer.start_stage("09_validation_benchmarking") except Exception as exc: if stage_runtime_writer is not None: stage_runtime_writer.fail_stage("08_dataset_assembly", exc) raise try: + if stage_runtime_writer is not None: + stage_runtime_writer.start_stage("09_validation_benchmarking") + ( resolved_target_provider, resolved_baseline_dataset, From 2ab88c14579b14c74c34ebfed543a167e5807508 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 2 Jun 2026 19:27:50 +0200 Subject: [PATCH 5/7] Add stage seam tests and publish entity artifacts --- src/microplex_us/pipelines/artifacts.py | 93 +++++---- src/microplex_us/pipelines/us.py | 2 + tests/pipelines/test_artifacts.py | 12 +- tests/pipelines/test_stage_run.py | 253 ++++++++++++++++++++++++ 4 files changed, 325 insertions(+), 35 deletions(-) diff --git a/src/microplex_us/pipelines/artifacts.py b/src/microplex_us/pipelines/artifacts.py index 2b8fdf3..5860af4 100644 --- a/src/microplex_us/pipelines/artifacts.py +++ b/src/microplex_us/pipelines/artifacts.py @@ -101,6 +101,7 @@ class USMicroplexArtifactPaths: artifact_inventory: Path | None = None conditional_readiness: Path | None = None source_plan: Path | None = None + pre_calibration_policyengine_entity_tables: Path | None = None policyengine_entity_tables: Path | None = None calibration_summary: Path | None = None validation_evidence: Path | None = None @@ -254,10 +255,12 @@ def replay_us_microplex_policyengine_stage_from_artifact( ) pipeline = USMicroplexPipeline(config) + pre_calibration_policyengine_tables = pipeline.build_policyengine_entity_tables( + synthetic_data + ) if config.policyengine_targets_db is not None: - synthetic_tables = pipeline.build_policyengine_entity_tables(synthetic_data) policyengine_tables, calibrated_data, calibration_summary = ( - pipeline.calibrate_policyengine_tables(synthetic_tables) + pipeline.calibrate_policyengine_tables(pre_calibration_policyengine_tables) ) else: calibrated_data, calibration_summary = pipeline.calibrate( @@ -282,6 +285,7 @@ def replay_us_microplex_policyengine_stage_from_artifact( calibration_summary=calibration_summary, synthesis_metadata=synthesis_metadata, policyengine_tables=policyengine_tables, + pre_calibration_policyengine_tables=pre_calibration_policyengine_tables, scaffold_seed_data=scaffold_seed_data, ) @@ -909,6 +913,16 @@ def save_us_microplex_artifacts( """Persist a build result as a reproducible artifact bundle.""" output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) + scaffold_seed_data = ( + result.scaffold_seed_data + if result.scaffold_seed_data is not None + else result.seed_data + ) + pre_calibration_policyengine_tables = ( + result.pre_calibration_policyengine_tables + if result.pre_calibration_policyengine_tables is not None + else result.policyengine_tables + ) seed_data_path = resolve_us_stage_artifact_contract_path( output_dir, @@ -983,14 +997,10 @@ def save_us_microplex_artifacts( "03_source_planning", "source_plan", ) - scaffold_seed_data_path = ( - resolve_us_stage_artifact_contract_path( - output_dir, - "04_seed_scaffold", - "scaffold_seed_data", - ) - if result.scaffold_seed_data is not None - else None + scaffold_seed_data_path = resolve_us_stage_artifact_contract_path( + output_dir, + "04_seed_scaffold", + "scaffold_seed_data", ) policyengine_entity_tables_path = ( resolve_us_stage_artifact_contract_path( @@ -1012,7 +1022,7 @@ def save_us_microplex_artifacts( "06_policyengine_entities", "pre_calibration_policyengine_entity_tables", ) - if result.policyengine_tables is not None + if pre_calibration_policyengine_tables is not None else None ) validation_evidence_path = ( @@ -1035,15 +1045,11 @@ def save_us_microplex_artifacts( stage_runtime_writer.start_stage("08_dataset_assembly") try: - if ( - result.scaffold_seed_data is not None - and scaffold_seed_data_path is not None - ): - _write_parquet_unless_live_artifact_exists( - scaffold_seed_data_path, - result.scaffold_seed_data, - live_artifact=live_artifacts, - ) + _write_parquet_unless_live_artifact_exists( + scaffold_seed_data_path, + scaffold_seed_data, + live_artifact=live_artifacts, + ) _write_parquet_unless_live_artifact_exists( seed_data_path, result.seed_data, @@ -1081,19 +1087,37 @@ def save_us_microplex_artifacts( source_weight_diagnostics_payload, ) + if ( + pre_calibration_policyengine_entity_tables_path is not None + and pre_calibration_policyengine_tables is not None + ): + if not ( + live_artifacts + and pre_calibration_policyengine_entity_tables_path.exists() + ): + write_us_policyengine_entity_stage_artifact( + pre_calibration_policyengine_tables, + output_dir, + stage_id="06_policyengine_entities", + artifact_key="pre_calibration_policyengine_entity_tables", + checkpoint_stage="post_microsim", + ) + if ( + policyengine_entity_tables_path is not None + and result.policyengine_tables is not None + ): + if not (live_artifacts and policyengine_entity_tables_path.exists()): + write_us_policyengine_entity_stage_artifact( + result.policyengine_tables, + output_dir, + stage_id="07_calibration", + artifact_key="policyengine_entity_tables", + checkpoint_stage="post_calibration", + ) if ( result.policyengine_tables is not None and policyengine_dataset_path is not None ): - if policyengine_entity_tables_path is not None: - if not (live_artifacts and policyengine_entity_tables_path.exists()): - write_us_policyengine_entity_stage_artifact( - result.policyengine_tables, - output_dir, - stage_id="07_calibration", - artifact_key="policyengine_entity_tables", - checkpoint_stage="post_calibration", - ) period = result.config.policyengine_dataset_year or 2024 USMicroplexPipeline(result.config).export_policyengine_dataset( result, @@ -1310,10 +1334,8 @@ def save_us_microplex_artifacts( "calibration": result.calibration_summary, "artifacts": { "seed_data": seed_data_path.name, - "scaffold_seed_data": ( - str(scaffold_seed_data_path.relative_to(output_dir)) - if scaffold_seed_data_path is not None - else None + "scaffold_seed_data": str( + scaffold_seed_data_path.relative_to(output_dir) ), "synthetic_data": synthetic_data_path.name, "calibrated_data": calibrated_data_path.name, @@ -1500,7 +1522,7 @@ def save_us_microplex_artifacts( "policyengine_harness" if harness_summary is not None else None ), required_artifact_keys=( - *(("scaffold_seed_data",) if scaffold_seed_data_path is not None else ()), + "scaffold_seed_data", "seed_data", "synthetic_data", "calibrated_data", @@ -1539,6 +1561,9 @@ def save_us_microplex_artifacts( artifact_inventory=artifact_inventory_path, conditional_readiness=conditional_readiness_path, source_plan=source_plan_path, + pre_calibration_policyengine_entity_tables=( + pre_calibration_policyengine_entity_tables_path + ), policyengine_entity_tables=policyengine_entity_tables_path, calibration_summary=calibration_summary_path, validation_evidence=validation_evidence_path, diff --git a/src/microplex_us/pipelines/us.py b/src/microplex_us/pipelines/us.py index 8a87cda..1dd87db 100644 --- a/src/microplex_us/pipelines/us.py +++ b/src/microplex_us/pipelines/us.py @@ -1952,6 +1952,7 @@ class USMicroplexBuildResult: synthesis_metadata: dict[str, Any] = field(default_factory=dict) synthesizer: Synthesizer | Any | None = None policyengine_tables: PolicyEngineUSEntityTableBundle | None = None + pre_calibration_policyengine_tables: PolicyEngineUSEntityTableBundle | None = None source_frame: ObservationFrame | None = None source_frames: tuple[ObservationFrame, ...] = () fusion_plan: FusionPlan | None = None @@ -2667,6 +2668,7 @@ def build_from_frames( synthesis_metadata=synthesis_metadata, synthesizer=synthesizer, policyengine_tables=policyengine_tables, + pre_calibration_policyengine_tables=synthetic_tables, source_frame=scaffold_input.frame, source_frames=tuple(frame for frame in frames), fusion_plan=fusion_plan, diff --git a/tests/pipelines/test_artifacts.py b/tests/pipelines/test_artifacts.py index 74c881b..1e299a0 100644 --- a/tests/pipelines/test_artifacts.py +++ b/tests/pipelines/test_artifacts.py @@ -16,6 +16,9 @@ save_us_microplex_artifacts, ) from microplex_us.pipelines.registry import load_us_microplex_run_registry +from microplex_us.pipelines.stage_policyengine_artifacts import ( + load_us_policyengine_entity_stage_artifact, +) from microplex_us.pipelines.us import ( USMicroplexBuildConfig, USMicroplexBuildResult, @@ -375,6 +378,8 @@ def test_writes_expected_files(self, tmp_path): assert paths.conditional_readiness.exists() assert paths.source_plan is not None assert paths.source_plan.exists() + assert paths.pre_calibration_policyengine_entity_tables is not None + assert paths.pre_calibration_policyengine_entity_tables.exists() assert paths.policyengine_entity_tables is not None assert paths.policyengine_entity_tables.exists() assert paths.calibration_summary is not None @@ -452,13 +457,18 @@ def test_writes_expected_files(self, tmp_path): assert source_diagnostics["sources"][0]["person_weight_share"] == 1.0 assert source_diagnostics["sources"][0]["tax_unit_count"] == 2 assert source_diagnostics["sources"][0]["tax_unit_weight_share"] == 1.0 + pre_calibration_tables, _ = load_us_policyengine_entity_stage_artifact( + paths.pre_calibration_policyengine_entity_tables + ) + assert pre_calibration_tables.tax_units is not None + assert "filing_status" in pre_calibration_tables.tax_units with h5py.File(paths.policyengine_dataset, "r") as handle: assert "household_id" in handle assert "person_household_id" in handle assert "tax_unit_id" in handle assert "taxable_interest_income" in handle - assert "filing_status" in handle + assert "filing_status" not in handle assert "source_weight_diagnostics" not in handle def test_writes_model_when_present(self, tmp_path): diff --git a/tests/pipelines/test_stage_run.py b/tests/pipelines/test_stage_run.py index 14c69be..5f8562c 100644 --- a/tests/pipelines/test_stage_run.py +++ b/tests/pipelines/test_stage_run.py @@ -8,17 +8,24 @@ from microplex_us.pipelines.stage_contracts import ( US_CANONICAL_STAGE_IDS, get_us_pipeline_stage_contract, + get_us_stage_artifact_contract, ) from microplex_us.pipelines.stage_run import ( US_STAGE_OUTPUT_MANIFEST_TYPES, USArtifactRef, USAuxiliaryArtifact, + USCalibrationOutputs, + USDatasetAssemblyOutputs, USDiagnosticOutput, + USDonorSynthesisOutputs, + USPolicyEngineEntityOutputs, USRunProfileOutputs, + USSeedScaffoldOutputs, USSourceLoadingOutputs, USSourcePlanningOutputs, USStageInputOverride, USStageRunWriter, + USValidationBenchmarkingOutputs, build_us_stage_output_manifests_from_artifact_manifest, parse_us_stage_input_override, write_us_stage_run_manifests_from_artifact_manifest, @@ -113,6 +120,58 @@ def test_stage_run_writer_records_typed_stage_manifests(tmp_path): ) +@pytest.mark.parametrize( + ("previous_stage_id", "stage_id"), + zip(US_CANONICAL_STAGE_IDS, US_CANONICAL_STAGE_IDS[1:]), +) +def test_adjacent_stage_serialized_outputs_satisfy_next_stage_inputs( + tmp_path, + previous_stage_id, + stage_id, +): + _write_mock_stage_prefix(tmp_path, previous_stage_id) + + current_output = _mock_stage_output( + stage_id, + input_stage_manifest=_stage_manifest_ref(previous_stage_id), + ) + current_writer = USStageRunWriter(tmp_path) + + current_writer.record_stage(current_output) + + assert current_writer.recorded_stages == (current_output,) + + +def test_adjacent_stage_serialized_output_schema_breaks_next_stage_input(tmp_path): + seams = [ + (previous_stage_id, stage_id, resource.key) + for previous_stage_id, stage_id in zip( + US_CANONICAL_STAGE_IDS, + US_CANONICAL_STAGE_IDS[1:], + ) + for resource in get_us_pipeline_stage_contract(stage_id).inputs + if resource.required and resource.stage_id == previous_stage_id + ] + assert seams + + for previous_stage_id, stage_id, missing_key in seams: + seam_root = tmp_path / f"{previous_stage_id}-to-{stage_id}-{missing_key}" + _write_mock_stage_prefix( + seam_root, + previous_stage_id, + missing_stage_id=previous_stage_id, + missing_output_key=missing_key, + ) + + current_output = _mock_stage_output( + stage_id, + input_stage_manifest=_stage_manifest_ref(previous_stage_id), + ) + + with pytest.raises(ValueError, match=missing_key): + USStageRunWriter(seam_root).record_stage(current_output) + + def test_stage_run_writer_rejects_missing_diagnostics(tmp_path): writer = USStageRunWriter(tmp_path) output = USRunProfileOutputs( @@ -475,6 +534,200 @@ def test_stage_run_writer_preserves_existing_validation_evidence_summary( ) +def _write_mock_stage_prefix( + root, + through_stage_id, + *, + missing_stage_id=None, + missing_output_key=None, +): + writer = USStageRunWriter(root) + for stage_id in US_CANONICAL_STAGE_IDS[ + : US_CANONICAL_STAGE_IDS.index(through_stage_id) + 1 + ]: + writer.record_stage( + _mock_stage_output( + stage_id, + missing_output_key=( + missing_output_key if stage_id == missing_stage_id else None + ), + complete=stage_id != missing_stage_id, + ) + ) + writer.write_manifest_files() + + +def _mock_stage_output( + stage_id, + *, + input_stage_manifest=None, + missing_output_key=None, + complete=True, +): + diagnostics = { + "stage_summary": USDiagnosticOutput( + key="stage_summary", + summary={"stage_id": stage_id}, + ) + } + common = { + "input_stage_manifest": input_stage_manifest, + "diagnostics": diagnostics, + "complete": complete, + } + values = _mock_stage_output_values(stage_id) + if missing_output_key is not None: + values[missing_output_key] = _missing_output_value(values[missing_output_key]) + return _mock_stage_output_type(stage_id)(**common, **values) + + +def _mock_stage_output_type(stage_id): + return { + "01_run_profile": USRunProfileOutputs, + "02_source_loading": USSourceLoadingOutputs, + "03_source_planning": USSourcePlanningOutputs, + "04_seed_scaffold": USSeedScaffoldOutputs, + "05_donor_integration_synthesis": USDonorSynthesisOutputs, + "06_policyengine_entities": USPolicyEngineEntityOutputs, + "07_calibration": USCalibrationOutputs, + "08_dataset_assembly": USDatasetAssemblyOutputs, + "09_validation_benchmarking": USValidationBenchmarkingOutputs, + }[stage_id] + + +def _mock_stage_output_values(stage_id): + if stage_id == "01_run_profile": + return { + "manifest": _mock_artifact_ref("01_run_profile", "manifest"), + "resolved_config": {"n_synthetic": 10}, + "provider_query_plan": {"source_names": ["source"]}, + } + if stage_id == "02_source_loading": + return { + "observation_frame_summary": {"source_count": 1}, + "source_descriptors": ("source",), + "source_relationships": {"status": "valid"}, + } + if stage_id == "03_source_planning": + return { + "source_plan": _mock_artifact_ref("03_source_planning", "source_plan"), + "scaffold_selection": {"scaffold_source": "source"}, + } + if stage_id == "04_seed_scaffold": + return { + "scaffold_seed_data": _mock_artifact_ref( + "04_seed_scaffold", + "scaffold_seed_data", + ), + "seed_schema_metadata": {"required_columns": ["person_id"]}, + } + if stage_id == "05_donor_integration_synthesis": + return { + "seed_data": _mock_artifact_ref( + "05_donor_integration_synthesis", + "seed_data", + ), + "synthetic_data": _mock_artifact_ref( + "05_donor_integration_synthesis", + "synthetic_data", + ), + "synthesis_metadata": {"backend": "mock"}, + "source_weight_diagnostics": _mock_artifact_ref( + "05_donor_integration_synthesis", + "source_weight_diagnostics", + category="diagnostic", + ), + } + if stage_id == "06_policyengine_entities": + return { + "pre_calibration_policyengine_entity_tables": _mock_artifact_ref( + "06_policyengine_entities", + "pre_calibration_policyengine_entity_tables", + ), + "materialized_policyengine_inputs": {"tables": {"households": {"rows": 1}}}, + } + if stage_id == "07_calibration": + return { + "calibrated_data": _mock_artifact_ref( + "07_calibration", + "calibrated_data", + ), + "targets": _mock_artifact_ref("07_calibration", "targets"), + "calibration_summary": _mock_artifact_ref( + "07_calibration", + "calibration_summary", + category="diagnostic", + ), + "policyengine_entity_tables": _mock_artifact_ref( + "07_calibration", + "policyengine_entity_tables", + ), + "target_ledger": {"target_count": 1}, + } + if stage_id == "08_dataset_assembly": + return { + "policyengine_dataset": _mock_artifact_ref( + "08_dataset_assembly", + "policyengine_dataset", + ), + "stage_manifest": _mock_artifact_ref( + "08_dataset_assembly", + "stage_manifest", + category="derived", + ), + "data_flow_snapshot": _mock_artifact_ref( + "08_dataset_assembly", + "data_flow_snapshot", + category="derived", + ), + "artifact_inventory": _mock_artifact_ref( + "08_dataset_assembly", + "artifact_inventory", + category="derived", + ), + "conditional_readiness": _mock_artifact_ref( + "08_dataset_assembly", + "conditional_readiness", + category="derived", + ), + } + if stage_id == "09_validation_benchmarking": + return { + "validation_evidence": _mock_artifact_ref( + "09_validation_benchmarking", + "validation_evidence", + ), + "benchmark_summary": {"loss_delta": -0.1}, + "policyengine_native_scores": _mock_artifact_ref( + "09_validation_benchmarking", + "policyengine_native_scores", + category="diagnostic", + ), + } + raise KeyError(stage_id) + + +def _mock_artifact_ref(stage_id, artifact_key, *, category="required_output"): + contract = get_us_stage_artifact_contract(stage_id, artifact_key) + return USArtifactRef( + key=artifact_key, + path=contract.path_hint or f"stage_artifacts/{stage_id}/{artifact_key}", + format=contract.format, + required=contract.required, + category=category, + resume_role=contract.resume_role, + assume_exists=True, + ) + + +def _missing_output_value(value): + return None if isinstance(value, USArtifactRef) else type(value)() + + +def _stage_manifest_ref(stage_id): + return f"stage_artifacts/manifests/{stage_id}.json" + + def _write_artifact_bundle_files(root): for relative in ( "seed_data.parquet", From c4114605dd707dcca5a43168387648f4b8463031 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 2 Jun 2026 22:00:10 +0200 Subject: [PATCH 6/7] Preserve runtime lifecycle and tighten stage six artifacts --- src/microplex_us/pipelines/artifacts.py | 6 +- src/microplex_us/pipelines/stage_run.py | 1 + src/microplex_us/pipelines/stage_runtime.py | 56 +++++++++++---- tests/pipelines/test_artifacts.py | 75 +++++++++++++++++++++ tests/pipelines/test_stage_runtime.py | 42 ++++++++++++ 5 files changed, 163 insertions(+), 17 deletions(-) diff --git a/src/microplex_us/pipelines/artifacts.py b/src/microplex_us/pipelines/artifacts.py index 5860af4..ed70fb3 100644 --- a/src/microplex_us/pipelines/artifacts.py +++ b/src/microplex_us/pipelines/artifacts.py @@ -918,11 +918,7 @@ def save_us_microplex_artifacts( if result.scaffold_seed_data is not None else result.seed_data ) - pre_calibration_policyengine_tables = ( - result.pre_calibration_policyengine_tables - if result.pre_calibration_policyengine_tables is not None - else result.policyengine_tables - ) + pre_calibration_policyengine_tables = result.pre_calibration_policyengine_tables seed_data_path = resolve_us_stage_artifact_contract_path( output_dir, diff --git a/src/microplex_us/pipelines/stage_run.py b/src/microplex_us/pipelines/stage_run.py index 064b7f0..e6ccedd 100644 --- a/src/microplex_us/pipelines/stage_run.py +++ b/src/microplex_us/pipelines/stage_run.py @@ -986,6 +986,7 @@ def build_us_stage_output_manifests_from_artifact_manifest( complete=all( _artifact_exists(root, artifacts, key) for key in ( + "pre_calibration_policyengine_entity_tables", "calibrated_data", "targets", "calibration_summary", diff --git a/src/microplex_us/pipelines/stage_runtime.py b/src/microplex_us/pipelines/stage_runtime.py index 59392de..bc18b5e 100644 --- a/src/microplex_us/pipelines/stage_runtime.py +++ b/src/microplex_us/pipelines/stage_runtime.py @@ -282,28 +282,51 @@ def finalize_from_artifact_manifest( ): existing = self._stage_payload(outputs.stage_id) now = _now() - lifecycle_status = _final_lifecycle_status(outputs) existing_events = tuple( dict(event) for event in existing.get("events", ()) if isinstance(event, dict) ) + existing_lifecycle = _terminal_lifecycle(existing) + if existing_lifecycle is not None: + lifecycle_status = existing_lifecycle + complete = bool(existing.get("complete")) + started_at = _optional_str(existing.get("startedAt")) + updated_at = _optional_str(existing.get("updatedAt")) + completed_at = _optional_str(existing.get("completedAt")) + failed_at = _optional_str(existing.get("failedAt")) + deferred_reason = _optional_str(existing.get("deferredReason")) + failure = existing.get("failure") + events = existing_events + else: + lifecycle_status = _final_lifecycle_status(outputs) + complete = outputs.complete + started_at = _optional_str(existing.get("startedAt")) or now + updated_at = now + completed_at = now if lifecycle_status == "complete" else None + failed_at = None + deferred_reason = ( + outputs.deferred_reason if lifecycle_status == "deferred" else None + ) + failure = None + events = ( + *existing_events, + *tuple(outputs.events), + _event(f"stage_{lifecycle_status}", now), + ) lifecycle_outputs = replace( outputs, + complete=complete, input_stage_manifest=outputs.input_stage_manifest or self._previous_stage_manifest_ref(outputs.stage_id), lifecycle_status=lifecycle_status, - started_at=_optional_str(existing.get("startedAt")) or now, - updated_at=now, - completed_at=now if lifecycle_status == "complete" else None, - deferred_reason=( - outputs.deferred_reason if lifecycle_status == "deferred" else None - ), - events=( - *existing_events, - *tuple(outputs.events), - _event(f"stage_{lifecycle_status}", now), - ), + started_at=started_at, + updated_at=updated_at, + completed_at=completed_at, + failed_at=failed_at, + deferred_reason=deferred_reason, + failure=failure, + events=events, ) self._run_writer.record_stage(lifecycle_outputs) self.manifest_payload = self._run_writer.write_manifest_files() @@ -559,6 +582,15 @@ def _final_lifecycle_status( return "complete" if outputs.complete else "pending" +def _terminal_lifecycle( + payload: Mapping[str, Any], +) -> USStageLifecycleStatus | None: + status = payload.get("lifecycleStatus") + if status in {"complete", "failed", "deferred"}: + return status + return None + + def _runtime_serialize(value: Any, artifact_root: str | Path | None) -> Any: if isinstance(value, USDiagnosticOutput): return value.to_dict(artifact_root) diff --git a/tests/pipelines/test_artifacts.py b/tests/pipelines/test_artifacts.py index 1e299a0..be64b67 100644 --- a/tests/pipelines/test_artifacts.py +++ b/tests/pipelines/test_artifacts.py @@ -322,6 +322,18 @@ def test_writes_expected_files(self, tmp_path): calibration_summary={"max_error": 0.01, "mean_error": 0.005}, synthesis_metadata={"backend": "bootstrap"}, synthesizer=None, + pre_calibration_policyengine_tables=PolicyEngineUSEntityTableBundle( + households=pd.DataFrame( + {"household_id": [1, 2], "household_weight": [1.0, 1.0]} + ), + tax_units=pd.DataFrame( + { + "tax_unit_id": [101, 102], + "household_id": [1, 2], + "filing_status": ["SINGLE", "JOINT"], + } + ), + ), policyengine_tables=PolicyEngineUSEntityTableBundle( households=pd.DataFrame( {"household_id": [1, 2], "household_weight": [0.5, 1.5]} @@ -471,6 +483,69 @@ def test_writes_expected_files(self, tmp_path): assert "filing_status" not in handle assert "source_weight_diagnostics" not in handle + def test_leaves_pre_calibration_entity_artifact_blank_when_tables_are_absent( + self, + tmp_path, + ): + result = USMicroplexBuildResult( + config=USMicroplexBuildConfig( + n_synthetic=1, + synthesis_backend="bootstrap", + calibration_backend="entropy", + ), + seed_data=pd.DataFrame({"income": [10.0], "hh_weight": [1.0]}), + scaffold_seed_data=pd.DataFrame({"income": [10.0], "hh_weight": [1.0]}), + synthetic_data=pd.DataFrame({"income": [10.0], "weight": [1.0]}), + calibrated_data=pd.DataFrame({"income": [10.0], "weight": [1.0]}), + targets=USMicroplexTargets(marginal={}, continuous={}), + calibration_summary={"max_error": 0.0}, + synthesis_metadata={ + "backend": "bootstrap", + "source_names": ["source"], + "scaffold_source": "source", + }, + policyengine_tables=PolicyEngineUSEntityTableBundle( + households=pd.DataFrame( + {"household_id": [1], "household_weight": [1.0]} + ), + persons=pd.DataFrame( + { + "person_id": [10], + "household_id": [1], + "tax_unit_id": [101], + "spm_unit_id": [201], + "family_id": [301], + "marital_unit_id": [401], + "age": [35], + } + ), + tax_units=pd.DataFrame({"tax_unit_id": [101], "household_id": [1]}), + spm_units=pd.DataFrame({"spm_unit_id": [201], "household_id": [1]}), + families=pd.DataFrame({"family_id": [301], "household_id": [1]}), + marital_units=pd.DataFrame( + {"marital_unit_id": [401], "household_id": [1]} + ), + ), + ) + + paths = save_us_microplex_artifacts(result, tmp_path) + + manifest = json.loads(paths.manifest.read_text()) + assert paths.pre_calibration_policyengine_entity_tables is None + assert ( + manifest["artifacts"]["pre_calibration_policyengine_entity_tables"] is None + ) + assert not (tmp_path / "stage_artifacts" / "06_policyengine_entities").exists() + stage7_manifest = json.loads( + ( + tmp_path / "stage_artifacts" / "manifests" / "07_calibration.json" + ).read_text() + ) + assert stage7_manifest["complete"] is False + assert ( + stage7_manifest["outputs"]["policyengine_entity_tables"]["exists"] is True + ) + def test_writes_model_when_present(self, tmp_path): class FakeSynthesizer: def __init__(self): diff --git a/tests/pipelines/test_stage_runtime.py b/tests/pipelines/test_stage_runtime.py index 6363850..e6858ed 100644 --- a/tests/pipelines/test_stage_runtime.py +++ b/tests/pipelines/test_stage_runtime.py @@ -88,6 +88,48 @@ def test_runtime_writer_completes_stage_and_exposes_lifecycle(tmp_path): ] +def test_runtime_writer_finalize_preserves_completed_stage_lifecycle(tmp_path): + writer = USStageRuntimeWriter( + tmp_path, + manifest_payload={ + "config": {"calibration_backend": "none"}, + "artifacts": {"manifest": "manifest.json"}, + }, + ) + writer.start_stage("01_run_profile", metadata={"profile": "test"}) + writer.complete_stage( + USRunProfileOutputs( + manifest=USArtifactRef( + key="manifest", + path="manifest.json", + format="json", + required=True, + assume_exists=True, + ), + resolved_config={"calibration_backend": "none"}, + provider_query_plan={"source_names": ["unit"]}, + diagnostics=_diagnostics("01_run_profile"), + ) + ) + stage1_path = tmp_path / "stage_artifacts" / "manifests" / "01_run_profile.json" + before = json.loads(stage1_path.read_text()) + + writer.finalize_from_artifact_manifest( + { + "config": {"calibration_backend": "none"}, + "artifacts": {"manifest": "manifest.json"}, + "synthesis": {"source_names": ["unit"]}, + } + ) + + after = json.loads(stage1_path.read_text()) + assert after["lifecycleStatus"] == "complete" + assert after["startedAt"] == before["startedAt"] + assert after["updatedAt"] == before["updatedAt"] + assert after["completedAt"] == before["completedAt"] + assert after["events"] == before["events"] + + def test_runtime_writer_serializes_enum_outputs(tmp_path): writer = USStageRuntimeWriter( tmp_path, From 07236fb90821c56112bd23229839d344fc09fe4a Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 3 Jun 2026 16:29:59 +0200 Subject: [PATCH 7/7] Split artifact persistence modules --- .../pipelines/artifact_dataset_assembly.py | 87 + src/microplex_us/pipelines/artifact_io.py | 118 ++ src/microplex_us/pipelines/artifact_replay.py | 184 ++ .../pipelines/artifact_source_diagnostics.py | 416 +++++ src/microplex_us/pipelines/artifact_types.py | 54 + .../pipelines/artifact_validation.py | 160 ++ src/microplex_us/pipelines/artifacts.py | 1619 +---------------- .../pipelines/versioned_artifacts.py | 684 +++++++ 8 files changed, 1769 insertions(+), 1553 deletions(-) create mode 100644 src/microplex_us/pipelines/artifact_dataset_assembly.py create mode 100644 src/microplex_us/pipelines/artifact_io.py create mode 100644 src/microplex_us/pipelines/artifact_replay.py create mode 100644 src/microplex_us/pipelines/artifact_source_diagnostics.py create mode 100644 src/microplex_us/pipelines/artifact_types.py create mode 100644 src/microplex_us/pipelines/artifact_validation.py create mode 100644 src/microplex_us/pipelines/versioned_artifacts.py diff --git a/src/microplex_us/pipelines/artifact_dataset_assembly.py b/src/microplex_us/pipelines/artifact_dataset_assembly.py new file mode 100644 index 0000000..c812f4a --- /dev/null +++ b/src/microplex_us/pipelines/artifact_dataset_assembly.py @@ -0,0 +1,87 @@ +"""Dataset-assembly artifact helpers for saved US Microplex bundles.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import pandas as pd + +from microplex_us.capital_gains_lots import ( + SyntheticCapitalGainsLotConfig, + generate_synthetic_capital_gains_lots, + synthetic_capital_gains_lot_metadata, + validate_capital_gains_lot_anchors, + write_capital_gains_lots_sqlite, +) +from microplex_us.pipelines.stage_contracts import ( + resolve_us_stage_artifact_contract_path, +) +from microplex_us.pipelines.us import USMicroplexBuildResult + + +def _maybe_write_capital_gains_lot_artifact( + result: USMicroplexBuildResult, + output_dir: Path, +) -> tuple[Path | None, dict[str, Any] | None]: + if ( + not result.config.capital_gains_lots_enabled + or result.policyengine_tables is None + ): + return None, None + persons = result.policyengine_tables.persons + gain_column = "long_term_capital_gains_before_response" + if gain_column not in persons.columns: + return None, { + "enabled": True, + "written": False, + "reason": f"missing {gain_column}", + } + + period = result.config.policyengine_dataset_year or 2024 + lot_config = SyntheticCapitalGainsLotConfig( + random_seed=( + result.config.capital_gains_lots_random_seed + if result.config.capital_gains_lots_random_seed is not None + else result.config.random_seed + ), + max_lots_per_person=int(result.config.capital_gains_lots_max_lots_per_person), + ) + lots = generate_synthetic_capital_gains_lots( + persons, + period=period, + config=lot_config, + gain_column=gain_column, + ) + validate_capital_gains_lot_anchors(persons, lots, gain_column=gain_column) + metadata = synthetic_capital_gains_lot_metadata( + lot_config, + period=period, + source_gain_column=gain_column, + ) + nonzero_people = int( + pd.to_numeric(persons[gain_column], errors="coerce").fillna(0.0).ne(0.0).sum() + ) + metadata.update( + { + "person_rows": int(len(persons)), + "nonzero_person_rows": nonzero_people, + "lot_rows": int(len(lots)), + } + ) + path = resolve_us_stage_artifact_contract_path( + output_dir, + "08_dataset_assembly", + "capital_gains_lots", + ) + write_capital_gains_lots_sqlite(lots, path, metadata=metadata) + return path, { + "enabled": True, + "written": True, + "path": path.name, + "person_rows": int(len(persons)), + "nonzero_person_rows": nonzero_people, + "lot_rows": int(len(lots)), + "source_gain_column": gain_column, + "max_lots_per_person": int(lot_config.max_lots_per_person), + } diff --git a/src/microplex_us/pipelines/artifact_io.py b/src/microplex_us/pipelines/artifact_io.py new file mode 100644 index 0000000..3f12543 --- /dev/null +++ b/src/microplex_us/pipelines/artifact_io.py @@ -0,0 +1,118 @@ +"""Low-level filesystem helpers for saved US Microplex artifacts.""" + +from __future__ import annotations + +import json +from collections.abc import Mapping +from pathlib import Path +from typing import Any + +import pandas as pd + +from microplex_us.pipelines.stage_contracts import ( + get_us_stage_artifact_contract, + resolve_us_stage_artifact_contract_path, +) +from microplex_us.pipelines.stage_run import USArtifactRef, USDiagnosticOutput + + +def _stage_artifact_ref( + artifact_root: str | Path, + stage_id: str, + artifact_key: str, + *, + assume_exists: bool = False, +) -> USArtifactRef: + contract = get_us_stage_artifact_contract(stage_id, artifact_key) + return USArtifactRef( + key=artifact_key, + path=resolve_us_stage_artifact_contract_path( + artifact_root, + stage_id, + artifact_key, + ), + format=contract.format, + required=contract.required, + resume_role=contract.resume_role, + assume_exists=assume_exists, + ) + + +def _stage_diagnostics( + stage_id: str, + summary: Mapping[str, Any], +) -> dict[str, USDiagnosticOutput]: + return { + "stage_summary": USDiagnosticOutput( + key="stage_summary", + description=f"Runtime diagnostic summary for {stage_id}.", + summary=dict(summary), + ) + } + + +def _write_parquet_unless_live_artifact_exists( + path: Path, + frame: pd.DataFrame, + *, + live_artifact: bool, +) -> None: + if live_artifact and path.exists(): + return + path.parent.mkdir(parents=True, exist_ok=True) + frame.to_parquet(path, index=False) + + +def _write_json_unless_live_artifact_exists( + path: Path, + payload: Mapping[str, Any], + *, + live_artifact: bool, +) -> None: + if live_artifact and path.exists(): + return + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(payload, indent=2, sort_keys=True)) + + +def _resolve_saved_artifact_file( + artifact_root: Path, + manifest: dict[str, Any], + artifact_key: str, +) -> Path: + artifacts = dict(manifest.get("artifacts", {})) + filename = artifacts.get(artifact_key) + if not filename: + filename = ( + "targets.json" if artifact_key == "targets" else f"{artifact_key}.parquet" + ) + path = Path(filename) + if not path.is_absolute(): + path = artifact_root / path + if not path.exists(): + raise FileNotFoundError(f"Saved artifact file not found: {path}") + return path + + +def _resolve_optional_saved_artifact_file( + artifact_root: Path, + manifest: dict[str, Any], + artifact_key: str, +) -> Path | None: + artifacts = dict(manifest.get("artifacts", {})) + filename = artifacts.get(artifact_key) + if not filename: + return None + path = Path(str(filename)) + if not path.is_absolute(): + path = artifact_root / path + if not path.exists(): + raise FileNotFoundError(f"Saved optional artifact file not found: {path}") + return path + + +def _write_json_atomically(path: Path, payload: dict[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + temp_path = path.with_name(f".{path.name}.tmp") + temp_path.write_text(json.dumps(payload, indent=2, sort_keys=True)) + temp_path.replace(path) diff --git a/src/microplex_us/pipelines/artifact_replay.py b/src/microplex_us/pipelines/artifact_replay.py new file mode 100644 index 0000000..d4882f7 --- /dev/null +++ b/src/microplex_us/pipelines/artifact_replay.py @@ -0,0 +1,184 @@ +"""Replay helpers for saved US Microplex artifact bundles.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +import pandas as pd +from microplex.targets import TargetProvider + +from microplex_us.pipelines.artifact_io import ( + _resolve_optional_saved_artifact_file, + _resolve_saved_artifact_file, +) +from microplex_us.pipelines.artifact_types import USMicroplexVersionedBuildArtifacts +from microplex_us.pipelines.registry import FrontierMetric +from microplex_us.pipelines.us import ( + USMicroplexBuildConfig, + USMicroplexBuildResult, + USMicroplexPipeline, + USMicroplexTargets, +) +from microplex_us.pipelines.versioned_artifacts import ( + _finalize_versioned_build_artifacts, +) +from microplex_us.policyengine.harness import ( + PolicyEngineUSComparisonCache, + PolicyEngineUSHarnessSlice, +) + + +def _facade_pipeline_cls() -> type[USMicroplexPipeline]: + from microplex_us.pipelines import artifacts + + return artifacts.USMicroplexPipeline + + +def replay_us_microplex_policyengine_stage_from_artifact( + artifact_dir: str | Path, + *, + config_overrides: dict[str, Any] | None = None, +) -> USMicroplexBuildResult: + """Replay calibration/export inputs from a saved artifact without raw ETL. + + This reloads saved seed and synthetic rows, applies optional runtime config + overrides, and reruns the downstream calibration stage from the saved + synthetic population. For PE-DB builds, this intentionally calls + ``calibrate_policyengine_tables`` even when ``calibration_backend="none"`` + so PE target materialization and export-only variables stay on the same + path as a full pipeline build. + """ + + artifact_root = Path(artifact_dir).expanduser().resolve() + manifest_path = artifact_root / "manifest.json" + if not manifest_path.exists(): + raise FileNotFoundError(f"Saved artifact manifest not found: {manifest_path}") + + manifest = json.loads(manifest_path.read_text()) + config_payload = dict(manifest.get("config", {})) + config_payload.update(dict(config_overrides or {})) + config = USMicroplexBuildConfig(**config_payload) + + seed_data = pd.read_parquet( + _resolve_saved_artifact_file(artifact_root, manifest, "seed_data") + ) + scaffold_seed_data_path = _resolve_optional_saved_artifact_file( + artifact_root, + manifest, + "scaffold_seed_data", + ) + scaffold_seed_data = ( + pd.read_parquet(scaffold_seed_data_path) + if scaffold_seed_data_path is not None + else None + ) + synthetic_data = pd.read_parquet( + _resolve_saved_artifact_file(artifact_root, manifest, "synthetic_data") + ) + targets_payload = json.loads( + _resolve_saved_artifact_file(artifact_root, manifest, "targets").read_text() + ) + targets = USMicroplexTargets( + marginal=dict(targets_payload.get("marginal", {})), + continuous=dict(targets_payload.get("continuous", {})), + ) + + pipeline = _facade_pipeline_cls()(config) + pre_calibration_policyengine_tables = pipeline.build_policyengine_entity_tables( + synthetic_data + ) + if config.policyengine_targets_db is not None: + policyengine_tables, calibrated_data, calibration_summary = ( + pipeline.calibrate_policyengine_tables(pre_calibration_policyengine_tables) + ) + else: + calibrated_data, calibration_summary = pipeline.calibrate( + synthetic_data, + targets, + ) + policyengine_tables = pipeline.build_policyengine_entity_tables(calibrated_data) + + synthesis_metadata = dict(manifest.get("synthesis", {})) + synthesis_metadata["policyengine_stage_replay"] = { + "source_artifact_dir": str(artifact_root), + "source_manifest": str(manifest_path), + "config_override_keys": sorted((config_overrides or {}).keys()), + } + + return USMicroplexBuildResult( + config=config, + seed_data=seed_data, + synthetic_data=synthetic_data, + calibrated_data=calibrated_data, + targets=targets, + calibration_summary=calibration_summary, + synthesis_metadata=synthesis_metadata, + policyengine_tables=policyengine_tables, + pre_calibration_policyengine_tables=pre_calibration_policyengine_tables, + scaffold_seed_data=scaffold_seed_data, + ) + + +def replay_and_save_versioned_us_microplex_policyengine_stage( + artifact_dir: str | Path, + output_root: str | Path | None = None, + *, + config_overrides: dict[str, Any] | None = None, + version_id: str | None = None, + frontier_metric: FrontierMetric = "candidate_composite_parity_loss", + policyengine_comparison_cache: PolicyEngineUSComparisonCache | None = None, + policyengine_target_provider: TargetProvider | None = None, + policyengine_baseline_dataset: str | Path | None = None, + policyengine_harness_slices: ( + tuple[PolicyEngineUSHarnessSlice, ...] | list[PolicyEngineUSHarnessSlice] | None + ) = None, + policyengine_harness_metadata: dict[str, Any] | None = None, + policyengine_us_data_repo: str | Path | None = None, + defer_policyengine_harness: bool = True, + require_policyengine_native_score: bool = False, + defer_policyengine_native_score: bool = False, + precomputed_policyengine_harness_payload: dict[str, Any] | None = None, + precomputed_policyengine_native_scores: dict[str, Any] | None = None, + run_registry_path: str | Path | None = None, + run_index_path: str | Path | None = None, + run_registry_metadata: dict[str, Any] | None = None, +) -> USMicroplexVersionedBuildArtifacts: + """Replay a saved artifact's policy stage and persist a new versioned bundle.""" + + artifact_root = Path(artifact_dir).expanduser().resolve() + build_result = replay_us_microplex_policyengine_stage_from_artifact( + artifact_root, + config_overrides=config_overrides, + ) + resolved_output_root = ( + Path(output_root).expanduser().resolve() + if output_root is not None + else artifact_root.parent + ) + replay_metadata = { + "policyengine_stage_replay": True, + "source_artifact_dir": str(artifact_root), + **dict(run_registry_metadata or {}), + } + return _finalize_versioned_build_artifacts( + build_result, + output_root=resolved_output_root, + version_id=version_id, + frontier_metric=frontier_metric, + policyengine_comparison_cache=policyengine_comparison_cache, + policyengine_target_provider=policyengine_target_provider, + policyengine_baseline_dataset=policyengine_baseline_dataset, + policyengine_harness_slices=policyengine_harness_slices, + policyengine_harness_metadata=policyengine_harness_metadata, + policyengine_us_data_repo=policyengine_us_data_repo, + defer_policyengine_harness=defer_policyengine_harness, + require_policyengine_native_score=require_policyengine_native_score, + defer_policyengine_native_score=defer_policyengine_native_score, + precomputed_policyengine_harness_payload=precomputed_policyengine_harness_payload, + precomputed_policyengine_native_scores=precomputed_policyengine_native_scores, + run_registry_path=run_registry_path, + run_index_path=run_index_path, + run_registry_metadata=replay_metadata, + ) diff --git a/src/microplex_us/pipelines/artifact_source_diagnostics.py b/src/microplex_us/pipelines/artifact_source_diagnostics.py new file mode 100644 index 0000000..0803980 --- /dev/null +++ b/src/microplex_us/pipelines/artifact_source_diagnostics.py @@ -0,0 +1,416 @@ +"""Source-plan and source-weight diagnostics for saved artifact bundles.""" + +from __future__ import annotations + +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + +import pandas as pd + +from microplex_us.data_sources.forbes import ForbesFixedSpineConfig +from microplex_us.pipelines.artifact_io import _write_json_atomically +from microplex_us.pipelines.us import USMicroplexBuildResult + + +def _write_us_source_plan_artifact( + result: USMicroplexBuildResult, + output_path: Path, +) -> None: + synthesis = dict(result.synthesis_metadata) + source_names = tuple( + dict.fromkeys( + value + for value in ( + *list(synthesis.get("source_names", ())), + synthesis.get("scaffold_source"), + ) + if isinstance(value, str) and value + ) + ) + payload = { + "formatVersion": 1, + "stageId": "03_source_planning", + "sourceNames": list(source_names), + "scaffoldSource": synthesis.get("scaffold_source"), + "donorIntegratedVariables": list( + synthesis.get("donor_integrated_variables", ()) + ), + "conditionVars": list(synthesis.get("condition_vars", ())), + "targetVars": list(synthesis.get("target_vars", ())), + "donorAuthoritativeOverrideVariables": list( + synthesis.get("donor_authoritative_override_variables", ()) + ), + "donorExcludedVariables": list(synthesis.get("donor_excluded_variables", ())), + } + if result.fusion_plan is not None: + payload["fusionPlan"] = { + "sourceNames": list(result.fusion_plan.source_names), + } + _write_json_atomically(output_path, payload) + + +def _build_source_weight_diagnostics( + result: USMicroplexBuildResult, +) -> dict[str, Any]: + """Summarize source-weight provenance without exporting diagnostics to H5.""" + + entity_summaries = _entity_weight_summaries(result) + household_summary = entity_summaries["households"] + total_household_weight = household_summary["weight_sum"] + source_names = _source_names_for_diagnostics(result) + scaffold_source = _scaffold_source_for_diagnostics(result) + donor_sources = [ + source_name + for source_name in source_names + if scaffold_source is None or source_name != scaffold_source + ] + sources: list[dict[str, Any]] = [] + + fixed_spine_entry = _fixed_spine_source_entry( + result, + total_entity_summaries=entity_summaries, + ) + fixed_entity_summaries = ( + { + entity: { + "count": fixed_spine_entry.get(f"{prefix}_count", 0), + "weight_sum": fixed_spine_entry.get(f"{prefix}_weight_sum", 0.0), + "available": fixed_spine_entry.get(f"{prefix}_weight_sum") is not None, + } + for entity, prefix in _SOURCE_DIAGNOSTIC_ENTITY_PREFIXES.items() + } + if fixed_spine_entry is not None + else {} + ) + ordinary_entity_summaries = _subtract_entity_summaries( + entity_summaries, + fixed_entity_summaries, + ) + + sources.append( + { + "source_name": scaffold_source or "microplex_synthetic_population", + "source_class": "synthetic_population", + "source_role": "scaffold", + "source_names": source_names, + **_source_entity_fields(ordinary_entity_summaries, entity_summaries), + } + ) + + donor_integrated_variables = list( + result.synthesis_metadata.get("donor_integrated_variables", ()) + ) + for source_name in donor_sources: + sources.append( + { + "source_name": source_name, + "source_class": "donor_imputation", + "source_role": "donor", + "integrated_variable_count": len(donor_integrated_variables), + "row_contribution": "variables_imputed_into_synthetic_rows", + **_source_entity_fields( + _zero_entity_summaries(), + entity_summaries, + ), + } + ) + + if fixed_spine_entry is not None: + sources.append(fixed_spine_entry) + + numeric_shares = [ + float(source["household_weight_share"]) + for source in sources + if isinstance(source.get("household_weight_share"), int | float) + ] + summary = { + "diagnostic_scope": "saved_artifact_entity_weight_by_source_rows", + "household_count": household_summary["count"], + "total_household_weight": total_household_weight, + "person_count": entity_summaries["persons"]["count"], + "total_person_weight": entity_summaries["persons"]["weight_sum"], + "tax_unit_count": entity_summaries["tax_units"]["count"], + "total_tax_unit_weight": entity_summaries["tax_units"]["weight_sum"], + "source_entry_count": len(sources), + "donor_source_count": len(donor_sources), + "donor_integrated_variable_count": len(donor_integrated_variables), + "support_rows_appended": False, + "donor_rows_appended": False, + "support_household_weight_sum": 0.0, + "support_household_weight_share": 0.0, + "puf_support_household_weight_sum": 0.0, + "puf_support_household_weight_share": 0.0, + "max_source_household_weight_share": ( + max(numeric_shares) if numeric_shares else None + ), + "fixed_spine_enabled": bool( + isinstance(result.calibration_summary.get("fixed_spine"), dict) + and result.calibration_summary.get("fixed_spine", {}).get("enabled") + ), + "h5_exported": False, + } + + return { + "formatVersion": 1, + "created_at": datetime.now(UTC).isoformat(), + "summary": summary, + "sources": sources, + "notes": [ + "Donor sources contribute imputed variables to synthetic rows; they are not appended as weighted source rows.", + "Source diagnostics are written as a sidecar and are intentionally not exported into PolicyEngine H5 variables.", + ], + } + + +_SOURCE_DIAGNOSTIC_ENTITY_PREFIXES = { + "households": "household", + "persons": "person", + "tax_units": "tax_unit", +} + + +def _entity_weight_summaries( + result: USMicroplexBuildResult, +) -> dict[str, dict[str, Any]]: + summaries = _zero_entity_summaries() + if result.policyengine_tables is not None: + for entity in _SOURCE_DIAGNOSTIC_ENTITY_PREFIXES: + frame, weights = _policyengine_entity_weights(result, entity) + if frame is None or weights is None: + continue + summaries[entity] = { + "count": int(len(frame)), + "weight_sum": float(weights.sum()), + "available": True, + } + return summaries + + frame = result.calibrated_data + if frame.empty: + return summaries + weight_column = ( + "household_weight" if "household_weight" in frame.columns else "weight" + ) + if weight_column not in frame.columns: + summaries["persons"] = { + "count": int(len(frame)), + "weight_sum": 0.0, + "available": False, + } + return summaries + + weights = pd.to_numeric(frame[weight_column], errors="coerce").fillna(0.0) + summaries["persons"] = { + "count": int(len(frame)), + "weight_sum": float(weights.sum()), + "available": True, + } + if "household_id" in frame.columns: + household_weights = weights.groupby(frame["household_id"], sort=False).first() + summaries["households"] = { + "count": int(len(household_weights)), + "weight_sum": float(household_weights.sum()), + "available": True, + } + return summaries + + +def _zero_entity_summaries() -> dict[str, dict[str, Any]]: + return { + entity: {"count": 0, "weight_sum": 0.0, "available": False} + for entity in _SOURCE_DIAGNOSTIC_ENTITY_PREFIXES + } + + +def _subtract_entity_summaries( + total: dict[str, dict[str, Any]], + subtract: dict[str, dict[str, Any]], +) -> dict[str, dict[str, Any]]: + result: dict[str, dict[str, Any]] = {} + for entity in _SOURCE_DIAGNOSTIC_ENTITY_PREFIXES: + total_summary = total.get(entity, {}) + subtract_summary = subtract.get(entity, {}) + total_count = int(total_summary.get("count", 0) or 0) + subtract_count = int(subtract_summary.get("count", 0) or 0) + total_weight = float(total_summary.get("weight_sum", 0.0) or 0.0) + subtract_weight = float(subtract_summary.get("weight_sum", 0.0) or 0.0) + result[entity] = { + "count": max(total_count - subtract_count, 0), + "weight_sum": max(total_weight - subtract_weight, 0.0), + "available": bool(total_summary.get("available", False)), + } + return result + + +def _source_entity_fields( + source: dict[str, dict[str, Any]], + total: dict[str, dict[str, Any]], +) -> dict[str, Any]: + fields: dict[str, Any] = {} + for entity, prefix in _SOURCE_DIAGNOSTIC_ENTITY_PREFIXES.items(): + source_summary = source.get(entity, {}) + total_summary = total.get(entity, {}) + source_weight = source_summary.get("weight_sum") + fields[f"{prefix}_count"] = int(source_summary.get("count", 0) or 0) + fields[f"{prefix}_weight_sum"] = ( + float(source_weight) if source_weight is not None else None + ) + fields[f"{prefix}_weight_share"] = _weight_share( + float(source_weight or 0.0), + float(total_summary.get("weight_sum", 0.0) or 0.0), + ) + return fields + + +def _policyengine_entity_weights( + result: USMicroplexBuildResult, + entity: str, +) -> tuple[pd.DataFrame | None, pd.Series | None]: + tables = result.policyengine_tables + if tables is None: + return None, None + households = tables.households + if households is None or "household_weight" not in households.columns: + household_weight_by_id = None + else: + household_weights = pd.to_numeric( + households["household_weight"], + errors="coerce", + ).fillna(0.0) + household_weight_by_id = pd.Series( + household_weights.to_numpy(dtype=float), + index=households["household_id"], + ) + if entity == "households": + if households is None or household_weight_by_id is None: + return None, None + return households, household_weights + if entity == "persons": + return _frame_and_entity_weights( + tables.persons, + direct_weight_columns=("weight", "person_weight", "household_weight"), + household_weight_by_id=household_weight_by_id, + ) + if entity == "tax_units": + return _frame_and_entity_weights( + tables.tax_units, + direct_weight_columns=("tax_unit_weight", "household_weight"), + household_weight_by_id=household_weight_by_id, + ) + return None, None + + +def _frame_and_entity_weights( + frame: pd.DataFrame | None, + *, + direct_weight_columns: tuple[str, ...], + household_weight_by_id: pd.Series | None, +) -> tuple[pd.DataFrame | None, pd.Series | None]: + if frame is None: + return None, None + for column in direct_weight_columns: + if column in frame.columns: + return ( + frame, + pd.to_numeric(frame[column], errors="coerce").fillna(0.0), + ) + if household_weight_by_id is not None and "household_id" in frame.columns: + return ( + frame, + frame["household_id"].map(household_weight_by_id).fillna(0.0), + ) + return frame, pd.Series(0.0, index=frame.index, dtype=float) + + +def _source_names_for_diagnostics(result: USMicroplexBuildResult) -> list[str]: + synthesis = dict(result.synthesis_metadata) + names: list[str] = [] + if result.fusion_plan is not None: + names.extend(str(name) for name in result.fusion_plan.source_names) + names.extend(str(name) for name in synthesis.get("source_names", ()) if name) + scaffold_source = synthesis.get("scaffold_source") + if scaffold_source: + names.append(str(scaffold_source)) + for frame in result.source_frames: + source = getattr(frame, "source", None) + source_name = getattr(source, "name", None) + if source_name: + names.append(str(source_name)) + return list(dict.fromkeys(names)) + + +def _scaffold_source_for_diagnostics(result: USMicroplexBuildResult) -> str | None: + scaffold_source = result.synthesis_metadata.get("scaffold_source") + if scaffold_source: + return str(scaffold_source) + source_names = _source_names_for_diagnostics(result) + return source_names[0] if source_names else None + + +def _fixed_spine_source_entry( + result: USMicroplexBuildResult, + *, + total_entity_summaries: dict[str, dict[str, Any]], +) -> dict[str, Any] | None: + fixed_spine = result.calibration_summary.get("fixed_spine") + if not isinstance(fixed_spine, dict) or not fixed_spine.get("enabled"): + return None + + source_metadata = dict(fixed_spine.get("source_metadata", {})) + entry: dict[str, Any] = { + "source_name": source_metadata.get("source", "forbes_fixed_spine"), + "source_class": "fixed_spine", + "source_role": "post_calibration_append", + "source_metadata": source_metadata, + } + fixed_spine_config = ForbesFixedSpineConfig() + fixed_entity_summaries = _fixed_spine_entity_summaries( + result, + fixed_spine_config=fixed_spine_config, + ) + entry.update( + { + **_source_entity_fields( + fixed_entity_summaries, + total_entity_summaries, + ), + "household_id_detection": { + "method": "forbes_default_household_id_floor", + "minimum_household_id": fixed_spine_config.household_id_start, + }, + } + ) + return entry + + +def _fixed_spine_entity_summaries( + result: USMicroplexBuildResult, + *, + fixed_spine_config: ForbesFixedSpineConfig, +) -> dict[str, dict[str, Any]]: + summaries = _zero_entity_summaries() + id_floors = { + "households": ("household_id", fixed_spine_config.household_id_start), + "persons": ("person_id", fixed_spine_config.person_id_start), + "tax_units": ("tax_unit_id", fixed_spine_config.tax_unit_id_start), + } + for entity, (id_column, id_floor) in id_floors.items(): + frame, weights = _policyengine_entity_weights(result, entity) + if frame is None or weights is None or id_column not in frame.columns: + continue + ids = pd.to_numeric(frame[id_column], errors="coerce") + fixed_mask = ids >= id_floor + fixed_weights = weights.loc[fixed_mask] + summaries[entity] = { + "count": int(fixed_mask.sum()), + "weight_sum": float(fixed_weights.sum()), + "available": True, + } + return summaries + + +def _weight_share(value: float, denominator: float) -> float | None: + if denominator <= 0: + return None + return float(value) / float(denominator) diff --git a/src/microplex_us/pipelines/artifact_types.py b/src/microplex_us/pipelines/artifact_types.py new file mode 100644 index 0000000..d28fcf6 --- /dev/null +++ b/src/microplex_us/pipelines/artifact_types.py @@ -0,0 +1,54 @@ +"""Shared result types for saved US Microplex artifact bundles.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from microplex_us.pipelines.us import USMicroplexBuildResult + + +@dataclass(frozen=True) +class USMicroplexArtifactPaths: + """Filesystem locations for persisted pipeline artifacts.""" + + output_dir: Path + seed_data: Path + synthetic_data: Path + calibrated_data: Path + targets: Path + manifest: Path + version_id: str | None = None + scaffold_seed_data: Path | None = None + synthesizer: Path | None = None + policyengine_dataset: Path | None = None + data_flow_snapshot: Path | None = None + stage_manifest: Path | None = None + artifact_inventory: Path | None = None + conditional_readiness: Path | None = None + source_plan: Path | None = None + pre_calibration_policyengine_entity_tables: Path | None = None + policyengine_entity_tables: Path | None = None + calibration_summary: Path | None = None + validation_evidence: Path | None = None + policyengine_harness: Path | None = None + policyengine_native_scores: Path | None = None + policyengine_native_audit: Path | None = None + policyengine_native_target_diagnostics: Path | None = None + child_tax_unit_agi_drift: Path | None = None + capital_gains_lots: Path | None = None + source_weight_diagnostics: Path | None = None + run_registry: Path | None = None + run_index_db: Path | None = None + + +@dataclass(frozen=True) +class USMicroplexVersionedBuildArtifacts: + """End-to-end build, save, and frontier-tracking result.""" + + build_result: USMicroplexBuildResult + artifact_paths: USMicroplexArtifactPaths + current_entry: Any | None = None + frontier_entry: Any | None = None + frontier_delta: float | None = None diff --git a/src/microplex_us/pipelines/artifact_validation.py b/src/microplex_us/pipelines/artifact_validation.py new file mode 100644 index 0000000..ced295d --- /dev/null +++ b/src/microplex_us/pipelines/artifact_validation.py @@ -0,0 +1,160 @@ +"""Validation and benchmark artifact helpers for saved US Microplex bundles.""" + +from __future__ import annotations + +from collections.abc import Mapping +from importlib.metadata import PackageNotFoundError, version +from pathlib import Path +from typing import Any + +from microplex.targets import TargetProvider + +from microplex_us.pipelines.us import USMicroplexBuildResult +from microplex_us.policyengine.harness import ( + PolicyEngineUSComparisonCache, + PolicyEngineUSHarnessSlice, + default_policyengine_us_db_all_target_slices, + default_policyengine_us_harness_slices, + filter_nonempty_policyengine_us_harness_slices, +) +from microplex_us.policyengine.us import PolicyEngineUSDBTargetProvider + + +def _stage9_benchmark_summary(manifest: Mapping[str, Any]) -> dict[str, Any]: + summary: dict[str, Any] = {} + for key in ( + "policyengine_harness", + "policyengine_native_scores", + "policyengine_native_audit", + "imputation_ablation", + ): + value = manifest.get(key) + if isinstance(value, Mapping): + summary[key] = dict(value) + diagnostics = manifest.get("diagnostics") + if isinstance(diagnostics, Mapping): + for key in ("child_tax_unit_agi_drift", "capital_gains_lots"): + value = diagnostics.get(key) + if isinstance(value, Mapping): + summary[key] = dict(value) + return summary + + +def _summarize_child_tax_unit_agi_drift_ratios( + payload: dict[str, Any], + *, + stage: str, + variables: tuple[str, ...], +) -> dict[str, Any]: + stages = dict(payload.get("stages", {})) + stage_payload = dict(stages.get(stage, {})) + subsets = dict(stage_payload.get("subsets", {})) + adults = dict(subsets.get("adults", {})) + dependents = dict(subsets.get("dependents_under_20", {})) + ratios: dict[str, float | None] = {} + for variable in variables: + adult_sum = adults.get(variable, {}).get("sum") + child_sum = dependents.get(variable, {}).get("sum") + if adult_sum in (None, 0): + ratios[variable] = None + else: + ratios[variable] = float(child_sum or 0.0) / float(adult_sum) + return { + "stage": stage, + "dependents_under_20_sum_share": ratios, + } + + +def _resolve_policyengine_harness_context( + result: USMicroplexBuildResult, + *, + policyengine_comparison_cache: PolicyEngineUSComparisonCache | None, + policyengine_target_provider: TargetProvider | None, + policyengine_baseline_dataset: str | Path | None, + policyengine_harness_slices: ( + tuple[PolicyEngineUSHarnessSlice, ...] | list[PolicyEngineUSHarnessSlice] | None + ), + policyengine_harness_metadata: dict[str, Any] | None, +) -> tuple[ + TargetProvider | None, + str | Path | None, + tuple[PolicyEngineUSHarnessSlice, ...], + dict[str, Any], +]: + resolved_target_provider = policyengine_target_provider + if ( + resolved_target_provider is None + and result.config.policyengine_targets_db is not None + ): + resolved_target_provider = PolicyEngineUSDBTargetProvider( + result.config.policyengine_targets_db + ) + + resolved_baseline_dataset = ( + policyengine_baseline_dataset or result.config.policyengine_baseline_dataset + ) + + harness_period = result.config.policyengine_dataset_year or 2024 + if policyengine_harness_slices is not None: + resolved_harness_slices = tuple(policyengine_harness_slices) + elif result.config.policyengine_targets_db is not None: + resolved_harness_slices = default_policyengine_us_db_all_target_slices( + period=harness_period, + reform_id=result.config.policyengine_target_reform_id, + ) + else: + resolved_harness_slices = default_policyengine_us_harness_slices( + period=harness_period + ) + if resolved_target_provider is not None and resolved_harness_slices: + resolved_harness_slices = filter_nonempty_policyengine_us_harness_slices( + resolved_target_provider, + resolved_harness_slices, + cache=policyengine_comparison_cache, + ) + + resolved_harness_metadata = { + "baseline_dataset": ( + Path(resolved_baseline_dataset).name + if resolved_baseline_dataset is not None + else None + ), + "targets_db": ( + Path(result.config.policyengine_targets_db).name + if result.config.policyengine_targets_db is not None + else None + ), + "target_period": result.config.policyengine_target_period, + "target_variables": list(result.config.policyengine_target_variables), + "target_domains": list(result.config.policyengine_target_domains), + "target_geo_levels": list(result.config.policyengine_target_geo_levels), + "target_profile": result.config.policyengine_target_profile, + "calibration_target_profile": ( + result.config.policyengine_calibration_target_profile + ), + "target_reform_id": result.config.policyengine_target_reform_id, + "harness_slice_names": [ + slice_spec.name for slice_spec in resolved_harness_slices + ], + "policyengine_us_runtime_version": _resolve_policyengine_us_runtime_version(), + "harness_suite": ( + "policyengine_us_all_targets" + if result.config.policyengine_targets_db is not None + and policyengine_harness_slices is None + else None + ), + **dict(policyengine_harness_metadata or {}), + } + return ( + resolved_target_provider, + resolved_baseline_dataset, + resolved_harness_slices, + resolved_harness_metadata, + ) + + +def _resolve_policyengine_us_runtime_version() -> str | None: + try: + return version("policyengine-us") + except PackageNotFoundError: + return None diff --git a/src/microplex_us/pipelines/artifacts.py b/src/microplex_us/pipelines/artifacts.py index baf9c14..3cb464e 100644 --- a/src/microplex_us/pipelines/artifacts.py +++ b/src/microplex_us/pipelines/artifacts.py @@ -3,28 +3,42 @@ from __future__ import annotations import json -from collections.abc import Mapping -from dataclasses import asdict, dataclass, replace from datetime import UTC, datetime -from importlib.metadata import PackageNotFoundError, version from pathlib import Path from typing import Any -import pandas as pd -from microplex.core import SourceProvider, SourceQuery from microplex.targets import ( TargetProvider, assert_valid_benchmark_artifact_manifest, ) -from microplex_us.capital_gains_lots import ( - SyntheticCapitalGainsLotConfig, - generate_synthetic_capital_gains_lots, - synthetic_capital_gains_lot_metadata, - validate_capital_gains_lot_anchors, - write_capital_gains_lots_sqlite, +from microplex_us.pipelines.artifact_dataset_assembly import ( + _maybe_write_capital_gains_lot_artifact, +) +from microplex_us.pipelines.artifact_io import ( + _stage_artifact_ref, + _stage_diagnostics, + _write_json_atomically, + _write_json_unless_live_artifact_exists, + _write_parquet_unless_live_artifact_exists, +) +from microplex_us.pipelines.artifact_replay import ( + replay_and_save_versioned_us_microplex_policyengine_stage, + replay_us_microplex_policyengine_stage_from_artifact, +) +from microplex_us.pipelines.artifact_source_diagnostics import ( + _build_source_weight_diagnostics, + _write_us_source_plan_artifact, +) +from microplex_us.pipelines.artifact_types import ( + USMicroplexArtifactPaths, + USMicroplexVersionedBuildArtifacts, +) +from microplex_us.pipelines.artifact_validation import ( + _resolve_policyengine_harness_context, + _stage9_benchmark_summary, + _summarize_child_tax_unit_agi_drift_ratios, ) -from microplex_us.data_sources.forbes import ForbesFixedSpineConfig from microplex_us.pipelines.index_db import ( append_us_microplex_run_index_entry, ) @@ -32,14 +46,10 @@ compute_us_pe_native_scores, ) from microplex_us.pipelines.registry import ( - FrontierMetric, append_us_microplex_run_registry_entry, build_us_microplex_run_registry_entry, - load_us_microplex_run_registry, - select_us_microplex_frontier_entry, ) from microplex_us.pipelines.stage_contracts import ( - get_us_stage_artifact_contract, resolve_us_stage_artifact_contract_path, ) from microplex_us.pipelines.stage_manifest import ( @@ -47,10 +57,7 @@ write_us_validation_evidence_manifest, ) from microplex_us.pipelines.stage_run import ( - USArtifactRef, USDatasetAssemblyOutputs, - USDiagnosticOutput, - USRunProfileOutputs, USStageInputOverride, USValidationBenchmarkingOutputs, write_us_stage_run_manifests_from_artifact_manifest, @@ -63,826 +70,55 @@ summarize_child_tax_unit_agi_drift, ) from microplex_us.pipelines.us import ( - USMicroplexBuildConfig, USMicroplexBuildResult, USMicroplexPipeline, - USMicroplexTargets, - build_us_microplex, +) +from microplex_us.pipelines.versioned_artifacts import ( + _allocate_versioned_output_dir, + _allocate_versioned_output_dir_for_config, + _finalize_versioned_build_artifacts, + _initialize_versioned_stage_runtime_writer, + _json_ready, + _json_ready_query, + _provider_query_plan, + _registry_metric_value, + _short_config_hash, + build_and_save_versioned_us_microplex, + build_and_save_versioned_us_microplex_from_data_dir, + build_and_save_versioned_us_microplex_from_source_provider, + build_and_save_versioned_us_microplex_from_source_providers, + save_versioned_us_microplex_artifacts, + save_versioned_us_microplex_build_result, ) from microplex_us.policyengine.harness import ( PolicyEngineUSComparisonCache, PolicyEngineUSHarnessSlice, - default_policyengine_us_db_all_target_slices, - default_policyengine_us_harness_slices, evaluate_policyengine_us_harness, - filter_nonempty_policyengine_us_harness_slices, -) -from microplex_us.policyengine.us import ( - PolicyEngineUSDBTargetProvider, ) - -@dataclass(frozen=True) -class USMicroplexArtifactPaths: - """Filesystem locations for persisted pipeline artifacts.""" - - output_dir: Path - seed_data: Path - synthetic_data: Path - calibrated_data: Path - targets: Path - manifest: Path - version_id: str | None = None - scaffold_seed_data: Path | None = None - synthesizer: Path | None = None - policyengine_dataset: Path | None = None - data_flow_snapshot: Path | None = None - stage_manifest: Path | None = None - artifact_inventory: Path | None = None - conditional_readiness: Path | None = None - source_plan: Path | None = None - pre_calibration_policyengine_entity_tables: Path | None = None - policyengine_entity_tables: Path | None = None - calibration_summary: Path | None = None - validation_evidence: Path | None = None - policyengine_harness: Path | None = None - policyengine_native_scores: Path | None = None - policyengine_native_audit: Path | None = None - policyengine_native_target_diagnostics: Path | None = None - child_tax_unit_agi_drift: Path | None = None - capital_gains_lots: Path | None = None - source_weight_diagnostics: Path | None = None - run_registry: Path | None = None - run_index_db: Path | None = None - - -@dataclass(frozen=True) -class USMicroplexVersionedBuildArtifacts: - """End-to-end build, save, and frontier-tracking result.""" - - build_result: USMicroplexBuildResult - artifact_paths: USMicroplexArtifactPaths - current_entry: Any | None = None - frontier_entry: Any | None = None - frontier_delta: float | None = None - - -def _stage_artifact_ref( - artifact_root: str | Path, - stage_id: str, - artifact_key: str, - *, - assume_exists: bool = False, -) -> USArtifactRef: - contract = get_us_stage_artifact_contract(stage_id, artifact_key) - return USArtifactRef( - key=artifact_key, - path=resolve_us_stage_artifact_contract_path( - artifact_root, - stage_id, - artifact_key, - ), - format=contract.format, - required=contract.required, - resume_role=contract.resume_role, - assume_exists=assume_exists, - ) - - -def _stage_diagnostics( - stage_id: str, - summary: Mapping[str, Any], -) -> dict[str, USDiagnosticOutput]: - return { - "stage_summary": USDiagnosticOutput( - key="stage_summary", - description=f"Runtime diagnostic summary for {stage_id}.", - summary=dict(summary), - ) - } - - -def _write_parquet_unless_live_artifact_exists( - path: Path, - frame: pd.DataFrame, - *, - live_artifact: bool, -) -> None: - if live_artifact and path.exists(): - return - path.parent.mkdir(parents=True, exist_ok=True) - frame.to_parquet(path, index=False) - - -def _write_json_unless_live_artifact_exists( - path: Path, - payload: Mapping[str, Any], - *, - live_artifact: bool, -) -> None: - if live_artifact and path.exists(): - return - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(json.dumps(payload, indent=2, sort_keys=True)) - - -def _stage9_benchmark_summary(manifest: Mapping[str, Any]) -> dict[str, Any]: - summary: dict[str, Any] = {} - for key in ( - "policyengine_harness", - "policyengine_native_scores", - "policyengine_native_audit", - "imputation_ablation", - ): - value = manifest.get(key) - if isinstance(value, Mapping): - summary[key] = dict(value) - diagnostics = manifest.get("diagnostics") - if isinstance(diagnostics, Mapping): - for key in ("child_tax_unit_agi_drift", "capital_gains_lots"): - value = diagnostics.get(key) - if isinstance(value, Mapping): - summary[key] = dict(value) - return summary - - -def replay_us_microplex_policyengine_stage_from_artifact( - artifact_dir: str | Path, - *, - config_overrides: dict[str, Any] | None = None, -) -> USMicroplexBuildResult: - """Replay calibration/export inputs from a saved artifact without raw ETL. - - This reloads saved seed and synthetic rows, applies optional runtime config - overrides, and reruns the downstream calibration stage from the saved - synthetic population. For PE-DB builds, this intentionally calls - ``calibrate_policyengine_tables`` even when ``calibration_backend="none"`` - so PE target materialization and export-only variables stay on the same - path as a full pipeline build. - """ - - artifact_root = Path(artifact_dir).expanduser().resolve() - manifest_path = artifact_root / "manifest.json" - if not manifest_path.exists(): - raise FileNotFoundError(f"Saved artifact manifest not found: {manifest_path}") - - manifest = json.loads(manifest_path.read_text()) - config_payload = dict(manifest.get("config", {})) - config_payload.update(dict(config_overrides or {})) - config = USMicroplexBuildConfig(**config_payload) - - seed_data = pd.read_parquet( - _resolve_saved_artifact_file(artifact_root, manifest, "seed_data") - ) - scaffold_seed_data_path = _resolve_optional_saved_artifact_file( - artifact_root, - manifest, - "scaffold_seed_data", - ) - scaffold_seed_data = ( - pd.read_parquet(scaffold_seed_data_path) - if scaffold_seed_data_path is not None - else None - ) - synthetic_data = pd.read_parquet( - _resolve_saved_artifact_file(artifact_root, manifest, "synthetic_data") - ) - targets_payload = json.loads( - _resolve_saved_artifact_file(artifact_root, manifest, "targets").read_text() - ) - targets = USMicroplexTargets( - marginal=dict(targets_payload.get("marginal", {})), - continuous=dict(targets_payload.get("continuous", {})), - ) - - pipeline = USMicroplexPipeline(config) - pre_calibration_policyengine_tables = pipeline.build_policyengine_entity_tables( - synthetic_data - ) - if config.policyengine_targets_db is not None: - policyengine_tables, calibrated_data, calibration_summary = ( - pipeline.calibrate_policyengine_tables(pre_calibration_policyengine_tables) - ) - else: - calibrated_data, calibration_summary = pipeline.calibrate( - synthetic_data, - targets, - ) - policyengine_tables = pipeline.build_policyengine_entity_tables(calibrated_data) - - synthesis_metadata = dict(manifest.get("synthesis", {})) - synthesis_metadata["policyengine_stage_replay"] = { - "source_artifact_dir": str(artifact_root), - "source_manifest": str(manifest_path), - "config_override_keys": sorted((config_overrides or {}).keys()), - } - - return USMicroplexBuildResult( - config=config, - seed_data=seed_data, - synthetic_data=synthetic_data, - calibrated_data=calibrated_data, - targets=targets, - calibration_summary=calibration_summary, - synthesis_metadata=synthesis_metadata, - policyengine_tables=policyengine_tables, - pre_calibration_policyengine_tables=pre_calibration_policyengine_tables, - scaffold_seed_data=scaffold_seed_data, - ) - - -def replay_and_save_versioned_us_microplex_policyengine_stage( - artifact_dir: str | Path, - output_root: str | Path | None = None, - *, - config_overrides: dict[str, Any] | None = None, - version_id: str | None = None, - frontier_metric: FrontierMetric = "candidate_composite_parity_loss", - policyengine_comparison_cache: PolicyEngineUSComparisonCache | None = None, - policyengine_target_provider: TargetProvider | None = None, - policyengine_baseline_dataset: str | Path | None = None, - policyengine_harness_slices: ( - tuple[PolicyEngineUSHarnessSlice, ...] | list[PolicyEngineUSHarnessSlice] | None - ) = None, - policyengine_harness_metadata: dict[str, Any] | None = None, - policyengine_us_data_repo: str | Path | None = None, - defer_policyengine_harness: bool = True, - require_policyengine_native_score: bool = False, - defer_policyengine_native_score: bool = False, - precomputed_policyengine_harness_payload: dict[str, Any] | None = None, - precomputed_policyengine_native_scores: dict[str, Any] | None = None, - run_registry_path: str | Path | None = None, - run_index_path: str | Path | None = None, - run_registry_metadata: dict[str, Any] | None = None, -) -> USMicroplexVersionedBuildArtifacts: - """Replay a saved artifact's policy stage and persist a new versioned bundle.""" - - artifact_root = Path(artifact_dir).expanduser().resolve() - build_result = replay_us_microplex_policyengine_stage_from_artifact( - artifact_root, - config_overrides=config_overrides, - ) - resolved_output_root = ( - Path(output_root).expanduser().resolve() - if output_root is not None - else artifact_root.parent - ) - replay_metadata = { - "policyengine_stage_replay": True, - "source_artifact_dir": str(artifact_root), - **dict(run_registry_metadata or {}), - } - return _finalize_versioned_build_artifacts( - build_result, - output_root=resolved_output_root, - version_id=version_id, - frontier_metric=frontier_metric, - policyengine_comparison_cache=policyengine_comparison_cache, - policyengine_target_provider=policyengine_target_provider, - policyengine_baseline_dataset=policyengine_baseline_dataset, - policyengine_harness_slices=policyengine_harness_slices, - policyengine_harness_metadata=policyengine_harness_metadata, - policyengine_us_data_repo=policyengine_us_data_repo, - defer_policyengine_harness=defer_policyengine_harness, - require_policyengine_native_score=require_policyengine_native_score, - defer_policyengine_native_score=defer_policyengine_native_score, - precomputed_policyengine_harness_payload=precomputed_policyengine_harness_payload, - precomputed_policyengine_native_scores=precomputed_policyengine_native_scores, - run_registry_path=run_registry_path, - run_index_path=run_index_path, - run_registry_metadata=replay_metadata, - ) - - -def _resolve_saved_artifact_file( - artifact_root: Path, - manifest: dict[str, Any], - artifact_key: str, -) -> Path: - artifacts = dict(manifest.get("artifacts", {})) - filename = artifacts.get(artifact_key) - if not filename: - filename = ( - "targets.json" if artifact_key == "targets" else f"{artifact_key}.parquet" - ) - path = Path(filename) - if not path.is_absolute(): - path = artifact_root / path - if not path.exists(): - raise FileNotFoundError(f"Saved artifact file not found: {path}") - return path - - -def _resolve_optional_saved_artifact_file( - artifact_root: Path, - manifest: dict[str, Any], - artifact_key: str, -) -> Path | None: - artifacts = dict(manifest.get("artifacts", {})) - filename = artifacts.get(artifact_key) - if not filename: - return None - path = Path(str(filename)) - if not path.is_absolute(): - path = artifact_root / path - if not path.exists(): - raise FileNotFoundError(f"Saved optional artifact file not found: {path}") - return path - - -def _write_us_source_plan_artifact( - result: USMicroplexBuildResult, - output_path: Path, -) -> None: - synthesis = dict(result.synthesis_metadata) - source_names = tuple( - dict.fromkeys( - value - for value in ( - *list(synthesis.get("source_names", ())), - synthesis.get("scaffold_source"), - ) - if isinstance(value, str) and value - ) - ) - payload = { - "formatVersion": 1, - "stageId": "03_source_planning", - "sourceNames": list(source_names), - "scaffoldSource": synthesis.get("scaffold_source"), - "donorIntegratedVariables": list( - synthesis.get("donor_integrated_variables", ()) - ), - "conditionVars": list(synthesis.get("condition_vars", ())), - "targetVars": list(synthesis.get("target_vars", ())), - "donorAuthoritativeOverrideVariables": list( - synthesis.get("donor_authoritative_override_variables", ()) - ), - "donorExcludedVariables": list(synthesis.get("donor_excluded_variables", ())), - } - if result.fusion_plan is not None: - payload["fusionPlan"] = { - "sourceNames": list(result.fusion_plan.source_names), - } - _write_json_atomically(output_path, payload) - - -def _build_source_weight_diagnostics( - result: USMicroplexBuildResult, -) -> dict[str, Any]: - """Summarize source-weight provenance without exporting diagnostics to H5.""" - - entity_summaries = _entity_weight_summaries(result) - household_summary = entity_summaries["households"] - total_household_weight = household_summary["weight_sum"] - source_names = _source_names_for_diagnostics(result) - scaffold_source = _scaffold_source_for_diagnostics(result) - donor_sources = [ - source_name - for source_name in source_names - if scaffold_source is None or source_name != scaffold_source - ] - sources: list[dict[str, Any]] = [] - - fixed_spine_entry = _fixed_spine_source_entry( - result, - total_entity_summaries=entity_summaries, - ) - fixed_entity_summaries = ( - { - entity: { - "count": fixed_spine_entry.get(f"{prefix}_count", 0), - "weight_sum": fixed_spine_entry.get(f"{prefix}_weight_sum", 0.0), - "available": fixed_spine_entry.get(f"{prefix}_weight_sum") is not None, - } - for entity, prefix in _SOURCE_DIAGNOSTIC_ENTITY_PREFIXES.items() - } - if fixed_spine_entry is not None - else {} - ) - ordinary_entity_summaries = _subtract_entity_summaries( - entity_summaries, - fixed_entity_summaries, - ) - - sources.append( - { - "source_name": scaffold_source or "microplex_synthetic_population", - "source_class": "synthetic_population", - "source_role": "scaffold", - "source_names": source_names, - **_source_entity_fields(ordinary_entity_summaries, entity_summaries), - } - ) - - donor_integrated_variables = list( - result.synthesis_metadata.get("donor_integrated_variables", ()) - ) - for source_name in donor_sources: - sources.append( - { - "source_name": source_name, - "source_class": "donor_imputation", - "source_role": "donor", - "integrated_variable_count": len(donor_integrated_variables), - "row_contribution": "variables_imputed_into_synthetic_rows", - **_source_entity_fields( - _zero_entity_summaries(), - entity_summaries, - ), - } - ) - - if fixed_spine_entry is not None: - sources.append(fixed_spine_entry) - - numeric_shares = [ - float(source["household_weight_share"]) - for source in sources - if isinstance(source.get("household_weight_share"), int | float) - ] - summary = { - "diagnostic_scope": "saved_artifact_entity_weight_by_source_rows", - "household_count": household_summary["count"], - "total_household_weight": total_household_weight, - "person_count": entity_summaries["persons"]["count"], - "total_person_weight": entity_summaries["persons"]["weight_sum"], - "tax_unit_count": entity_summaries["tax_units"]["count"], - "total_tax_unit_weight": entity_summaries["tax_units"]["weight_sum"], - "source_entry_count": len(sources), - "donor_source_count": len(donor_sources), - "donor_integrated_variable_count": len(donor_integrated_variables), - "support_rows_appended": False, - "donor_rows_appended": False, - "support_household_weight_sum": 0.0, - "support_household_weight_share": 0.0, - "puf_support_household_weight_sum": 0.0, - "puf_support_household_weight_share": 0.0, - "max_source_household_weight_share": ( - max(numeric_shares) if numeric_shares else None - ), - "fixed_spine_enabled": bool( - isinstance(result.calibration_summary.get("fixed_spine"), dict) - and result.calibration_summary.get("fixed_spine", {}).get("enabled") - ), - "h5_exported": False, - } - - return { - "formatVersion": 1, - "created_at": datetime.now(UTC).isoformat(), - "summary": summary, - "sources": sources, - "notes": [ - "Donor sources contribute imputed variables to synthetic rows; they are not appended as weighted source rows.", - "Source diagnostics are written as a sidecar and are intentionally not exported into PolicyEngine H5 variables.", - ], - } - - -_SOURCE_DIAGNOSTIC_ENTITY_PREFIXES = { - "households": "household", - "persons": "person", - "tax_units": "tax_unit", -} - - -def _entity_weight_summaries( - result: USMicroplexBuildResult, -) -> dict[str, dict[str, Any]]: - summaries = _zero_entity_summaries() - if result.policyengine_tables is not None: - for entity in _SOURCE_DIAGNOSTIC_ENTITY_PREFIXES: - frame, weights = _policyengine_entity_weights(result, entity) - if frame is None or weights is None: - continue - summaries[entity] = { - "count": int(len(frame)), - "weight_sum": float(weights.sum()), - "available": True, - } - return summaries - - frame = result.calibrated_data - if frame.empty: - return summaries - weight_column = ( - "household_weight" if "household_weight" in frame.columns else "weight" - ) - if weight_column not in frame.columns: - summaries["persons"] = { - "count": int(len(frame)), - "weight_sum": 0.0, - "available": False, - } - return summaries - - weights = pd.to_numeric(frame[weight_column], errors="coerce").fillna(0.0) - summaries["persons"] = { - "count": int(len(frame)), - "weight_sum": float(weights.sum()), - "available": True, - } - if "household_id" in frame.columns: - household_weights = weights.groupby(frame["household_id"], sort=False).first() - summaries["households"] = { - "count": int(len(household_weights)), - "weight_sum": float(household_weights.sum()), - "available": True, - } - return summaries - - -def _zero_entity_summaries() -> dict[str, dict[str, Any]]: - return { - entity: {"count": 0, "weight_sum": 0.0, "available": False} - for entity in _SOURCE_DIAGNOSTIC_ENTITY_PREFIXES - } - - -def _subtract_entity_summaries( - total: dict[str, dict[str, Any]], - subtract: dict[str, dict[str, Any]], -) -> dict[str, dict[str, Any]]: - result: dict[str, dict[str, Any]] = {} - for entity in _SOURCE_DIAGNOSTIC_ENTITY_PREFIXES: - total_summary = total.get(entity, {}) - subtract_summary = subtract.get(entity, {}) - total_count = int(total_summary.get("count", 0) or 0) - subtract_count = int(subtract_summary.get("count", 0) or 0) - total_weight = float(total_summary.get("weight_sum", 0.0) or 0.0) - subtract_weight = float(subtract_summary.get("weight_sum", 0.0) or 0.0) - result[entity] = { - "count": max(total_count - subtract_count, 0), - "weight_sum": max(total_weight - subtract_weight, 0.0), - "available": bool(total_summary.get("available", False)), - } - return result - - -def _source_entity_fields( - source: dict[str, dict[str, Any]], - total: dict[str, dict[str, Any]], -) -> dict[str, Any]: - fields: dict[str, Any] = {} - for entity, prefix in _SOURCE_DIAGNOSTIC_ENTITY_PREFIXES.items(): - source_summary = source.get(entity, {}) - total_summary = total.get(entity, {}) - source_weight = source_summary.get("weight_sum") - fields[f"{prefix}_count"] = int(source_summary.get("count", 0) or 0) - fields[f"{prefix}_weight_sum"] = ( - float(source_weight) if source_weight is not None else None - ) - fields[f"{prefix}_weight_share"] = _weight_share( - float(source_weight or 0.0), - float(total_summary.get("weight_sum", 0.0) or 0.0), - ) - return fields - - -def _policyengine_entity_weights( - result: USMicroplexBuildResult, - entity: str, -) -> tuple[pd.DataFrame | None, pd.Series | None]: - tables = result.policyengine_tables - if tables is None: - return None, None - households = tables.households - if households is None or "household_weight" not in households.columns: - household_weight_by_id = None - else: - household_weights = pd.to_numeric( - households["household_weight"], - errors="coerce", - ).fillna(0.0) - household_weight_by_id = pd.Series( - household_weights.to_numpy(dtype=float), - index=households["household_id"], - ) - if entity == "households": - if households is None or household_weight_by_id is None: - return None, None - return households, household_weights - if entity == "persons": - return _frame_and_entity_weights( - tables.persons, - direct_weight_columns=("weight", "person_weight", "household_weight"), - household_weight_by_id=household_weight_by_id, - ) - if entity == "tax_units": - return _frame_and_entity_weights( - tables.tax_units, - direct_weight_columns=("tax_unit_weight", "household_weight"), - household_weight_by_id=household_weight_by_id, - ) - return None, None - - -def _frame_and_entity_weights( - frame: pd.DataFrame | None, - *, - direct_weight_columns: tuple[str, ...], - household_weight_by_id: pd.Series | None, -) -> tuple[pd.DataFrame | None, pd.Series | None]: - if frame is None: - return None, None - for column in direct_weight_columns: - if column in frame.columns: - return ( - frame, - pd.to_numeric(frame[column], errors="coerce").fillna(0.0), - ) - if household_weight_by_id is not None and "household_id" in frame.columns: - return ( - frame, - frame["household_id"].map(household_weight_by_id).fillna(0.0), - ) - return frame, pd.Series(0.0, index=frame.index, dtype=float) - - -def _source_names_for_diagnostics(result: USMicroplexBuildResult) -> list[str]: - synthesis = dict(result.synthesis_metadata) - names: list[str] = [] - if result.fusion_plan is not None: - names.extend(str(name) for name in result.fusion_plan.source_names) - names.extend(str(name) for name in synthesis.get("source_names", ()) if name) - scaffold_source = synthesis.get("scaffold_source") - if scaffold_source: - names.append(str(scaffold_source)) - for frame in result.source_frames: - source = getattr(frame, "source", None) - source_name = getattr(source, "name", None) - if source_name: - names.append(str(source_name)) - return list(dict.fromkeys(names)) - - -def _scaffold_source_for_diagnostics(result: USMicroplexBuildResult) -> str | None: - scaffold_source = result.synthesis_metadata.get("scaffold_source") - if scaffold_source: - return str(scaffold_source) - source_names = _source_names_for_diagnostics(result) - return source_names[0] if source_names else None - - -def _fixed_spine_source_entry( - result: USMicroplexBuildResult, - *, - total_entity_summaries: dict[str, dict[str, Any]], -) -> dict[str, Any] | None: - fixed_spine = result.calibration_summary.get("fixed_spine") - if not isinstance(fixed_spine, dict) or not fixed_spine.get("enabled"): - return None - - source_metadata = dict(fixed_spine.get("source_metadata", {})) - entry: dict[str, Any] = { - "source_name": source_metadata.get("source", "forbes_fixed_spine"), - "source_class": "fixed_spine", - "source_role": "post_calibration_append", - "source_metadata": source_metadata, - } - fixed_spine_config = ForbesFixedSpineConfig() - fixed_entity_summaries = _fixed_spine_entity_summaries( - result, - fixed_spine_config=fixed_spine_config, - ) - entry.update( - { - **_source_entity_fields( - fixed_entity_summaries, - total_entity_summaries, - ), - "household_id_detection": { - "method": "forbes_default_household_id_floor", - "minimum_household_id": fixed_spine_config.household_id_start, - }, - } - ) - return entry - - -def _fixed_spine_entity_summaries( - result: USMicroplexBuildResult, - *, - fixed_spine_config: ForbesFixedSpineConfig, -) -> dict[str, dict[str, Any]]: - summaries = _zero_entity_summaries() - id_floors = { - "households": ("household_id", fixed_spine_config.household_id_start), - "persons": ("person_id", fixed_spine_config.person_id_start), - "tax_units": ("tax_unit_id", fixed_spine_config.tax_unit_id_start), - } - for entity, (id_column, id_floor) in id_floors.items(): - frame, weights = _policyengine_entity_weights(result, entity) - if frame is None or weights is None or id_column not in frame.columns: - continue - ids = pd.to_numeric(frame[id_column], errors="coerce") - fixed_mask = ids >= id_floor - fixed_weights = weights.loc[fixed_mask] - summaries[entity] = { - "count": int(fixed_mask.sum()), - "weight_sum": float(fixed_weights.sum()), - "available": True, - } - return summaries - - -def _weight_share(value: float, denominator: float) -> float | None: - if denominator <= 0: - return None - return float(value) / float(denominator) - - -def _summarize_child_tax_unit_agi_drift_ratios( - payload: dict[str, Any], - *, - stage: str, - variables: tuple[str, ...], -) -> dict[str, Any]: - stages = dict(payload.get("stages", {})) - stage_payload = dict(stages.get(stage, {})) - subsets = dict(stage_payload.get("subsets", {})) - adults = dict(subsets.get("adults", {})) - dependents = dict(subsets.get("dependents_under_20", {})) - ratios: dict[str, float | None] = {} - for variable in variables: - adult_sum = adults.get(variable, {}).get("sum") - child_sum = dependents.get(variable, {}).get("sum") - if adult_sum in (None, 0): - ratios[variable] = None - else: - ratios[variable] = float(child_sum or 0.0) / float(adult_sum) - return { - "stage": stage, - "dependents_under_20_sum_share": ratios, - } - - -def _maybe_write_capital_gains_lot_artifact( - result: USMicroplexBuildResult, - output_dir: Path, -) -> tuple[Path | None, dict[str, Any] | None]: - if ( - not result.config.capital_gains_lots_enabled - or result.policyengine_tables is None - ): - return None, None - persons = result.policyengine_tables.persons - gain_column = "long_term_capital_gains_before_response" - if gain_column not in persons.columns: - return None, { - "enabled": True, - "written": False, - "reason": f"missing {gain_column}", - } - - period = result.config.policyengine_dataset_year or 2024 - lot_config = SyntheticCapitalGainsLotConfig( - random_seed=( - result.config.capital_gains_lots_random_seed - if result.config.capital_gains_lots_random_seed is not None - else result.config.random_seed - ), - max_lots_per_person=int(result.config.capital_gains_lots_max_lots_per_person), - ) - lots = generate_synthetic_capital_gains_lots( - persons, - period=period, - config=lot_config, - gain_column=gain_column, - ) - validate_capital_gains_lot_anchors(persons, lots, gain_column=gain_column) - metadata = synthetic_capital_gains_lot_metadata( - lot_config, - period=period, - source_gain_column=gain_column, - ) - nonzero_people = int( - pd.to_numeric(persons[gain_column], errors="coerce").fillna(0.0).ne(0.0).sum() - ) - metadata.update( - { - "person_rows": int(len(persons)), - "nonzero_person_rows": nonzero_people, - "lot_rows": int(len(lots)), - } - ) - path = resolve_us_stage_artifact_contract_path( - output_dir, - "08_dataset_assembly", - "capital_gains_lots", - ) - write_capital_gains_lots_sqlite(lots, path, metadata=metadata) - return path, { - "enabled": True, - "written": True, - "path": path.name, - "person_rows": int(len(persons)), - "nonzero_person_rows": nonzero_people, - "lot_rows": int(len(lots)), - "source_gain_column": gain_column, - "max_lots_per_person": int(lot_config.max_lots_per_person), - } +__all__ = [ + "USMicroplexArtifactPaths", + "USMicroplexVersionedBuildArtifacts", + "_allocate_versioned_output_dir", + "_allocate_versioned_output_dir_for_config", + "_finalize_versioned_build_artifacts", + "_initialize_versioned_stage_runtime_writer", + "_json_ready", + "_json_ready_query", + "_maybe_write_capital_gains_lot_artifact", + "_provider_query_plan", + "_registry_metric_value", + "_short_config_hash", + "build_and_save_versioned_us_microplex", + "build_and_save_versioned_us_microplex_from_data_dir", + "build_and_save_versioned_us_microplex_from_source_provider", + "build_and_save_versioned_us_microplex_from_source_providers", + "replay_and_save_versioned_us_microplex_policyengine_stage", + "replay_us_microplex_policyengine_stage_from_artifact", + "save_us_microplex_artifacts", + "save_versioned_us_microplex_artifacts", + "save_versioned_us_microplex_build_result", +] def save_us_microplex_artifacts( @@ -1574,726 +810,3 @@ def save_us_microplex_artifacts( run_registry=resolved_run_registry_path, run_index_db=resolved_run_index_path, ) - - -def save_versioned_us_microplex_artifacts( - result: USMicroplexBuildResult, - output_root: str | Path, - *, - version_id: str | None = None, - policyengine_comparison_cache: PolicyEngineUSComparisonCache | None = None, - policyengine_target_provider: TargetProvider | None = None, - policyengine_baseline_dataset: str | Path | None = None, - policyengine_harness_slices: ( - tuple[PolicyEngineUSHarnessSlice, ...] | list[PolicyEngineUSHarnessSlice] | None - ) = None, - policyengine_harness_metadata: dict[str, Any] | None = None, - policyengine_us_data_repo: str | Path | None = None, - defer_policyengine_harness: bool = False, - require_policyengine_native_score: bool = False, - defer_policyengine_native_score: bool = False, - precomputed_policyengine_harness_payload: dict[str, Any] | None = None, - precomputed_policyengine_native_scores: dict[str, Any] | None = None, - run_registry_path: str | Path | None = None, - run_index_path: str | Path | None = None, - run_registry_metadata: dict[str, Any] | None = None, - enable_child_tax_unit_agi_drift: bool = False, - child_tax_unit_agi_drift_variables: tuple[str, ...] | None = None, - allow_stage_input_overrides: bool = False, - stage_input_overrides: tuple[USStageInputOverride, ...] = (), -) -> USMicroplexArtifactPaths: - """Persist a build under a stable versioned directory beneath one output root.""" - output_root = Path(output_root) - output_root.mkdir(parents=True, exist_ok=True) - resolved_version_id, output_dir = _allocate_versioned_output_dir( - output_root, - version_id=version_id, - result=result, - ) - paths = save_us_microplex_artifacts( - result, - output_dir, - policyengine_comparison_cache=policyengine_comparison_cache, - policyengine_target_provider=policyengine_target_provider, - policyengine_baseline_dataset=policyengine_baseline_dataset, - policyengine_harness_slices=policyengine_harness_slices, - policyengine_harness_metadata=policyengine_harness_metadata, - policyengine_us_data_repo=policyengine_us_data_repo, - defer_policyengine_harness=defer_policyengine_harness, - require_policyengine_native_score=require_policyengine_native_score, - defer_policyengine_native_score=defer_policyengine_native_score, - precomputed_policyengine_harness_payload=precomputed_policyengine_harness_payload, - precomputed_policyengine_native_scores=precomputed_policyengine_native_scores, - run_registry_path=run_registry_path or output_root / "run_registry.jsonl", - run_index_path=run_index_path or output_root, - run_registry_metadata=run_registry_metadata, - enable_child_tax_unit_agi_drift=enable_child_tax_unit_agi_drift, - child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables, - allow_stage_input_overrides=allow_stage_input_overrides, - stage_input_overrides=stage_input_overrides, - ) - return replace(paths, version_id=resolved_version_id) - - -def build_and_save_versioned_us_microplex( - persons: Any, - households: Any, - output_root: str | Path, - *, - config: USMicroplexBuildConfig | None = None, - version_id: str | None = None, - frontier_metric: FrontierMetric = "candidate_composite_parity_loss", - policyengine_comparison_cache: PolicyEngineUSComparisonCache | None = None, - policyengine_target_provider: TargetProvider | None = None, - policyengine_baseline_dataset: str | Path | None = None, - policyengine_harness_slices: ( - tuple[PolicyEngineUSHarnessSlice, ...] | list[PolicyEngineUSHarnessSlice] | None - ) = None, - policyengine_harness_metadata: dict[str, Any] | None = None, - policyengine_us_data_repo: str | Path | None = None, - defer_policyengine_harness: bool = False, - require_policyengine_native_score: bool = False, - defer_policyengine_native_score: bool = False, - precomputed_policyengine_harness_payload: dict[str, Any] | None = None, - precomputed_policyengine_native_scores: dict[str, Any] | None = None, - run_registry_path: str | Path | None = None, - run_index_path: str | Path | None = None, - run_registry_metadata: dict[str, Any] | None = None, - enable_child_tax_unit_agi_drift: bool = False, - child_tax_unit_agi_drift_variables: tuple[str, ...] | None = None, - allow_stage_input_overrides: bool = False, - stage_input_overrides: tuple[USStageInputOverride, ...] = (), -) -> USMicroplexVersionedBuildArtifacts: - """Build a US microplex dataset, save a versioned bundle, and report frontier gap.""" - build_result = build_us_microplex(persons, households, config=config) - return save_versioned_us_microplex_build_result( - build_result, - output_root, - version_id=version_id, - frontier_metric=frontier_metric, - policyengine_comparison_cache=policyengine_comparison_cache, - policyengine_target_provider=policyengine_target_provider, - policyengine_baseline_dataset=policyengine_baseline_dataset, - policyengine_harness_slices=policyengine_harness_slices, - policyengine_harness_metadata=policyengine_harness_metadata, - policyengine_us_data_repo=policyengine_us_data_repo, - defer_policyengine_harness=defer_policyengine_harness, - require_policyengine_native_score=require_policyengine_native_score, - defer_policyengine_native_score=defer_policyengine_native_score, - precomputed_policyengine_harness_payload=precomputed_policyengine_harness_payload, - precomputed_policyengine_native_scores=precomputed_policyengine_native_scores, - run_registry_path=run_registry_path, - run_index_path=run_index_path, - run_registry_metadata=run_registry_metadata, - enable_child_tax_unit_agi_drift=enable_child_tax_unit_agi_drift, - child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables, - allow_stage_input_overrides=allow_stage_input_overrides, - stage_input_overrides=stage_input_overrides, - ) - - -def save_versioned_us_microplex_build_result( - build_result: USMicroplexBuildResult, - output_root: str | Path, - *, - version_id: str | None = None, - frontier_metric: FrontierMetric = "candidate_composite_parity_loss", - policyengine_comparison_cache: PolicyEngineUSComparisonCache | None = None, - policyengine_target_provider: TargetProvider | None = None, - policyengine_baseline_dataset: str | Path | None = None, - policyengine_harness_slices: ( - tuple[PolicyEngineUSHarnessSlice, ...] | list[PolicyEngineUSHarnessSlice] | None - ) = None, - policyengine_harness_metadata: dict[str, Any] | None = None, - policyengine_us_data_repo: str | Path | None = None, - defer_policyengine_harness: bool = False, - require_policyengine_native_score: bool = False, - defer_policyengine_native_score: bool = False, - precomputed_policyengine_harness_payload: dict[str, Any] | None = None, - precomputed_policyengine_native_scores: dict[str, Any] | None = None, - run_registry_path: str | Path | None = None, - run_index_path: str | Path | None = None, - run_registry_metadata: dict[str, Any] | None = None, - enable_child_tax_unit_agi_drift: bool = False, - child_tax_unit_agi_drift_variables: tuple[str, ...] | None = None, - allow_stage_input_overrides: bool = False, - stage_input_overrides: tuple[USStageInputOverride, ...] = (), -) -> USMicroplexVersionedBuildArtifacts: - """Save an already-built result as a versioned bundle and report frontier gap.""" - return _finalize_versioned_build_artifacts( - build_result, - output_root=output_root, - version_id=version_id, - frontier_metric=frontier_metric, - policyengine_comparison_cache=policyengine_comparison_cache, - policyengine_target_provider=policyengine_target_provider, - policyengine_baseline_dataset=policyengine_baseline_dataset, - policyengine_harness_slices=policyengine_harness_slices, - policyengine_harness_metadata=policyengine_harness_metadata, - policyengine_us_data_repo=policyengine_us_data_repo, - defer_policyengine_harness=defer_policyengine_harness, - require_policyengine_native_score=require_policyengine_native_score, - defer_policyengine_native_score=defer_policyengine_native_score, - precomputed_policyengine_harness_payload=precomputed_policyengine_harness_payload, - precomputed_policyengine_native_scores=precomputed_policyengine_native_scores, - run_registry_path=run_registry_path, - run_index_path=run_index_path, - run_registry_metadata=run_registry_metadata, - enable_child_tax_unit_agi_drift=enable_child_tax_unit_agi_drift, - child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables, - allow_stage_input_overrides=allow_stage_input_overrides, - stage_input_overrides=stage_input_overrides, - ) - - -def build_and_save_versioned_us_microplex_from_source_provider( - provider: SourceProvider, - output_root: str | Path, - *, - config: USMicroplexBuildConfig | None = None, - query: SourceQuery | None = None, - version_id: str | None = None, - frontier_metric: FrontierMetric = "candidate_composite_parity_loss", - policyengine_comparison_cache: PolicyEngineUSComparisonCache | None = None, - policyengine_target_provider: TargetProvider | None = None, - policyengine_baseline_dataset: str | Path | None = None, - policyengine_harness_slices: ( - tuple[PolicyEngineUSHarnessSlice, ...] | list[PolicyEngineUSHarnessSlice] | None - ) = None, - policyengine_harness_metadata: dict[str, Any] | None = None, - policyengine_us_data_repo: str | Path | None = None, - defer_policyengine_harness: bool = False, - require_policyengine_native_score: bool = False, - defer_policyengine_native_score: bool = False, - precomputed_policyengine_harness_payload: dict[str, Any] | None = None, - precomputed_policyengine_native_scores: dict[str, Any] | None = None, - run_registry_path: str | Path | None = None, - run_index_path: str | Path | None = None, - run_registry_metadata: dict[str, Any] | None = None, - enable_child_tax_unit_agi_drift: bool = False, - child_tax_unit_agi_drift_variables: tuple[str, ...] | None = None, - allow_stage_input_overrides: bool = False, - stage_input_overrides: tuple[USStageInputOverride, ...] = (), -) -> USMicroplexVersionedBuildArtifacts: - """Build from one source provider, save a versioned bundle, and report frontier gap.""" - pipeline = USMicroplexPipeline(config) - build_result = pipeline.build_from_source_provider(provider, query=query) - return _finalize_versioned_build_artifacts( - build_result, - output_root=output_root, - version_id=version_id, - frontier_metric=frontier_metric, - policyengine_comparison_cache=policyengine_comparison_cache, - policyengine_target_provider=policyengine_target_provider, - policyengine_baseline_dataset=policyengine_baseline_dataset, - policyengine_harness_slices=policyengine_harness_slices, - policyengine_harness_metadata=policyengine_harness_metadata, - policyengine_us_data_repo=policyengine_us_data_repo, - defer_policyengine_harness=defer_policyengine_harness, - require_policyengine_native_score=require_policyengine_native_score, - defer_policyengine_native_score=defer_policyengine_native_score, - precomputed_policyengine_harness_payload=precomputed_policyengine_harness_payload, - precomputed_policyengine_native_scores=precomputed_policyengine_native_scores, - run_registry_path=run_registry_path, - run_index_path=run_index_path, - run_registry_metadata=run_registry_metadata, - enable_child_tax_unit_agi_drift=enable_child_tax_unit_agi_drift, - child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables, - allow_stage_input_overrides=allow_stage_input_overrides, - stage_input_overrides=stage_input_overrides, - ) - - -def build_and_save_versioned_us_microplex_from_source_providers( - providers: list[SourceProvider], - output_root: str | Path, - *, - config: USMicroplexBuildConfig | None = None, - queries: dict[str, SourceQuery] | None = None, - version_id: str | None = None, - frontier_metric: FrontierMetric = "candidate_composite_parity_loss", - policyengine_comparison_cache: PolicyEngineUSComparisonCache | None = None, - policyengine_target_provider: TargetProvider | None = None, - policyengine_baseline_dataset: str | Path | None = None, - policyengine_harness_slices: ( - tuple[PolicyEngineUSHarnessSlice, ...] | list[PolicyEngineUSHarnessSlice] | None - ) = None, - policyengine_harness_metadata: dict[str, Any] | None = None, - policyengine_us_data_repo: str | Path | None = None, - defer_policyengine_harness: bool = False, - require_policyengine_native_score: bool = False, - defer_policyengine_native_score: bool = False, - precomputed_policyengine_harness_payload: dict[str, Any] | None = None, - precomputed_policyengine_native_scores: dict[str, Any] | None = None, - run_registry_path: str | Path | None = None, - run_index_path: str | Path | None = None, - run_registry_metadata: dict[str, Any] | None = None, - enable_child_tax_unit_agi_drift: bool = False, - child_tax_unit_agi_drift_variables: tuple[str, ...] | None = None, - allow_stage_input_overrides: bool = False, - stage_input_overrides: tuple[USStageInputOverride, ...] = (), -) -> USMicroplexVersionedBuildArtifacts: - """Build from multiple source providers, save a versioned bundle, and report frontier gap.""" - resolved_config = config or USMicroplexBuildConfig() - _resolved_version_id, preallocated_output_dir, stage_runtime_writer = ( - _initialize_versioned_stage_runtime_writer( - output_root, - version_id=version_id, - config=resolved_config, - providers=providers, - queries=queries, - allow_stage_input_overrides=allow_stage_input_overrides, - stage_input_overrides=stage_input_overrides, - ) - ) - pipeline = USMicroplexPipeline( - resolved_config, - stage_runtime_writer=stage_runtime_writer, - ) - build_result = pipeline.build_from_source_providers(providers, queries=queries) - return _finalize_versioned_build_artifacts( - build_result, - output_root=output_root, - version_id=version_id, - preallocated_output_dir=preallocated_output_dir, - frontier_metric=frontier_metric, - policyengine_comparison_cache=policyengine_comparison_cache, - policyengine_target_provider=policyengine_target_provider, - policyengine_baseline_dataset=policyengine_baseline_dataset, - policyengine_harness_slices=policyengine_harness_slices, - policyengine_harness_metadata=policyengine_harness_metadata, - policyengine_us_data_repo=policyengine_us_data_repo, - defer_policyengine_harness=defer_policyengine_harness, - require_policyengine_native_score=require_policyengine_native_score, - defer_policyengine_native_score=defer_policyengine_native_score, - precomputed_policyengine_harness_payload=precomputed_policyengine_harness_payload, - precomputed_policyengine_native_scores=precomputed_policyengine_native_scores, - run_registry_path=run_registry_path, - run_index_path=run_index_path, - run_registry_metadata=run_registry_metadata, - enable_child_tax_unit_agi_drift=enable_child_tax_unit_agi_drift, - child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables, - allow_stage_input_overrides=allow_stage_input_overrides, - stage_input_overrides=stage_input_overrides, - stage_runtime_writer=stage_runtime_writer, - ) - - -def build_and_save_versioned_us_microplex_from_data_dir( - data_dir: str | Path, - output_root: str | Path, - *, - config: USMicroplexBuildConfig | None = None, - version_id: str | None = None, - frontier_metric: FrontierMetric = "candidate_composite_parity_loss", - policyengine_comparison_cache: PolicyEngineUSComparisonCache | None = None, - policyengine_target_provider: TargetProvider | None = None, - policyengine_baseline_dataset: str | Path | None = None, - policyengine_harness_slices: ( - tuple[PolicyEngineUSHarnessSlice, ...] | list[PolicyEngineUSHarnessSlice] | None - ) = None, - policyengine_harness_metadata: dict[str, Any] | None = None, - policyengine_us_data_repo: str | Path | None = None, - defer_policyengine_harness: bool = False, - require_policyengine_native_score: bool = False, - defer_policyengine_native_score: bool = False, - precomputed_policyengine_harness_payload: dict[str, Any] | None = None, - precomputed_policyengine_native_scores: dict[str, Any] | None = None, - run_registry_path: str | Path | None = None, - run_index_path: str | Path | None = None, - run_registry_metadata: dict[str, Any] | None = None, - enable_child_tax_unit_agi_drift: bool = False, - child_tax_unit_agi_drift_variables: tuple[str, ...] | None = None, - allow_stage_input_overrides: bool = False, - stage_input_overrides: tuple[USStageInputOverride, ...] = (), -) -> USMicroplexVersionedBuildArtifacts: - """Build from a CPS-style parquet directory, save a versioned bundle, and report frontier gap.""" - pipeline = USMicroplexPipeline(config) - build_result = pipeline.build_from_data_dir(data_dir) - return _finalize_versioned_build_artifacts( - build_result, - output_root=output_root, - version_id=version_id, - frontier_metric=frontier_metric, - policyengine_comparison_cache=policyengine_comparison_cache, - policyengine_target_provider=policyengine_target_provider, - policyengine_baseline_dataset=policyengine_baseline_dataset, - policyengine_harness_slices=policyengine_harness_slices, - policyengine_harness_metadata=policyengine_harness_metadata, - policyengine_us_data_repo=policyengine_us_data_repo, - defer_policyengine_harness=defer_policyengine_harness, - require_policyengine_native_score=require_policyengine_native_score, - defer_policyengine_native_score=defer_policyengine_native_score, - precomputed_policyengine_harness_payload=precomputed_policyengine_harness_payload, - precomputed_policyengine_native_scores=precomputed_policyengine_native_scores, - run_registry_path=run_registry_path, - run_index_path=run_index_path, - run_registry_metadata=run_registry_metadata, - enable_child_tax_unit_agi_drift=enable_child_tax_unit_agi_drift, - child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables, - allow_stage_input_overrides=allow_stage_input_overrides, - stage_input_overrides=stage_input_overrides, - ) - - -def _finalize_versioned_build_artifacts( - build_result: USMicroplexBuildResult, - *, - output_root: str | Path, - version_id: str | None, - preallocated_output_dir: str | Path | None = None, - frontier_metric: FrontierMetric, - policyengine_comparison_cache: PolicyEngineUSComparisonCache | None, - policyengine_target_provider: TargetProvider | None, - policyengine_baseline_dataset: str | Path | None, - policyengine_harness_slices: ( - tuple[PolicyEngineUSHarnessSlice, ...] | list[PolicyEngineUSHarnessSlice] | None - ), - policyengine_harness_metadata: dict[str, Any] | None, - policyengine_us_data_repo: str | Path | None, - defer_policyengine_harness: bool, - require_policyengine_native_score: bool, - defer_policyengine_native_score: bool, - precomputed_policyengine_harness_payload: dict[str, Any] | None, - precomputed_policyengine_native_scores: dict[str, Any] | None, - run_registry_path: str | Path | None, - run_index_path: str | Path | None, - run_registry_metadata: dict[str, Any] | None, - enable_child_tax_unit_agi_drift: bool = False, - child_tax_unit_agi_drift_variables: tuple[str, ...] | None = None, - allow_stage_input_overrides: bool = False, - stage_input_overrides: tuple[USStageInputOverride, ...] = (), - stage_runtime_writer: USStageRuntimeWriter | None = None, -) -> USMicroplexVersionedBuildArtifacts: - if preallocated_output_dir is not None: - output_root_path = Path(output_root) - output_dir = Path(preallocated_output_dir) - artifact_paths = save_us_microplex_artifacts( - build_result, - output_dir, - policyengine_comparison_cache=policyengine_comparison_cache, - policyengine_target_provider=policyengine_target_provider, - policyengine_baseline_dataset=policyengine_baseline_dataset, - policyengine_harness_slices=policyengine_harness_slices, - policyengine_harness_metadata=policyengine_harness_metadata, - policyengine_us_data_repo=policyengine_us_data_repo, - defer_policyengine_harness=defer_policyengine_harness, - require_policyengine_native_score=require_policyengine_native_score, - defer_policyengine_native_score=defer_policyengine_native_score, - precomputed_policyengine_harness_payload=precomputed_policyengine_harness_payload, - precomputed_policyengine_native_scores=precomputed_policyengine_native_scores, - run_registry_path=run_registry_path - or output_root_path / "run_registry.jsonl", - run_index_path=run_index_path or output_root_path, - run_registry_metadata=run_registry_metadata, - enable_child_tax_unit_agi_drift=enable_child_tax_unit_agi_drift, - child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables, - allow_stage_input_overrides=allow_stage_input_overrides, - stage_input_overrides=stage_input_overrides, - stage_runtime_writer=stage_runtime_writer, - ) - artifact_paths = replace(artifact_paths, version_id=output_dir.name) - else: - artifact_paths = save_versioned_us_microplex_artifacts( - build_result, - output_root, - version_id=version_id, - policyengine_comparison_cache=policyengine_comparison_cache, - policyengine_target_provider=policyengine_target_provider, - policyengine_baseline_dataset=policyengine_baseline_dataset, - policyengine_harness_slices=policyengine_harness_slices, - policyengine_harness_metadata=policyengine_harness_metadata, - policyengine_us_data_repo=policyengine_us_data_repo, - defer_policyengine_harness=defer_policyengine_harness, - require_policyengine_native_score=require_policyengine_native_score, - defer_policyengine_native_score=defer_policyengine_native_score, - precomputed_policyengine_harness_payload=precomputed_policyengine_harness_payload, - precomputed_policyengine_native_scores=precomputed_policyengine_native_scores, - run_registry_path=run_registry_path, - run_index_path=run_index_path, - run_registry_metadata=run_registry_metadata, - enable_child_tax_unit_agi_drift=enable_child_tax_unit_agi_drift, - child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables, - allow_stage_input_overrides=allow_stage_input_overrides, - stage_input_overrides=stage_input_overrides, - stage_runtime_writer=stage_runtime_writer, - ) - current_entry = None - frontier_entry = None - frontier_delta = None - if ( - artifact_paths.run_registry is not None - and artifact_paths.version_id is not None - ): - registry_entries = load_us_microplex_run_registry(artifact_paths.run_registry) - current_entry = next( - ( - entry - for entry in reversed(registry_entries) - if entry.artifact_id == artifact_paths.version_id - ), - None, - ) - frontier_entry = select_us_microplex_frontier_entry( - artifact_paths.run_registry, - metric=frontier_metric, - ) - current_value = _registry_metric_value(current_entry, frontier_metric) - frontier_value = _registry_metric_value(frontier_entry, frontier_metric) - if current_value is not None and frontier_value is not None: - frontier_delta = current_value - frontier_value - return USMicroplexVersionedBuildArtifacts( - build_result=build_result, - artifact_paths=artifact_paths, - current_entry=current_entry, - frontier_entry=frontier_entry, - frontier_delta=frontier_delta, - ) - - -def _write_json_atomically(path: Path, payload: dict[str, Any]) -> None: - path.parent.mkdir(parents=True, exist_ok=True) - temp_path = path.with_name(f".{path.name}.tmp") - temp_path.write_text(json.dumps(payload, indent=2, sort_keys=True)) - temp_path.replace(path) - - -def _resolve_policyengine_harness_context( - result: USMicroplexBuildResult, - *, - policyengine_comparison_cache: PolicyEngineUSComparisonCache | None, - policyengine_target_provider: TargetProvider | None, - policyengine_baseline_dataset: str | Path | None, - policyengine_harness_slices: ( - tuple[PolicyEngineUSHarnessSlice, ...] | list[PolicyEngineUSHarnessSlice] | None - ), - policyengine_harness_metadata: dict[str, Any] | None, -) -> tuple[ - TargetProvider | None, - str | Path | None, - tuple[PolicyEngineUSHarnessSlice, ...], - dict[str, Any], -]: - resolved_target_provider = policyengine_target_provider - if ( - resolved_target_provider is None - and result.config.policyengine_targets_db is not None - ): - resolved_target_provider = PolicyEngineUSDBTargetProvider( - result.config.policyengine_targets_db - ) - - resolved_baseline_dataset = ( - policyengine_baseline_dataset or result.config.policyengine_baseline_dataset - ) - - harness_period = result.config.policyengine_dataset_year or 2024 - if policyengine_harness_slices is not None: - resolved_harness_slices = tuple(policyengine_harness_slices) - elif result.config.policyengine_targets_db is not None: - resolved_harness_slices = default_policyengine_us_db_all_target_slices( - period=harness_period, - reform_id=result.config.policyengine_target_reform_id, - ) - else: - resolved_harness_slices = default_policyengine_us_harness_slices( - period=harness_period - ) - if resolved_target_provider is not None and resolved_harness_slices: - resolved_harness_slices = filter_nonempty_policyengine_us_harness_slices( - resolved_target_provider, - resolved_harness_slices, - cache=policyengine_comparison_cache, - ) - - resolved_harness_metadata = { - "baseline_dataset": ( - Path(resolved_baseline_dataset).name - if resolved_baseline_dataset is not None - else None - ), - "targets_db": ( - Path(result.config.policyengine_targets_db).name - if result.config.policyengine_targets_db is not None - else None - ), - "target_period": result.config.policyengine_target_period, - "target_variables": list(result.config.policyengine_target_variables), - "target_domains": list(result.config.policyengine_target_domains), - "target_geo_levels": list(result.config.policyengine_target_geo_levels), - "target_profile": result.config.policyengine_target_profile, - "calibration_target_profile": ( - result.config.policyengine_calibration_target_profile - ), - "target_reform_id": result.config.policyengine_target_reform_id, - "harness_slice_names": [ - slice_spec.name for slice_spec in resolved_harness_slices - ], - "policyengine_us_runtime_version": _resolve_policyengine_us_runtime_version(), - "harness_suite": ( - "policyengine_us_all_targets" - if result.config.policyengine_targets_db is not None - and policyengine_harness_slices is None - else None - ), - **dict(policyengine_harness_metadata or {}), - } - return ( - resolved_target_provider, - resolved_baseline_dataset, - resolved_harness_slices, - resolved_harness_metadata, - ) - - -def _resolve_policyengine_us_runtime_version() -> str | None: - try: - return version("policyengine-us") - except PackageNotFoundError: - return None - - -def _allocate_versioned_output_dir( - output_root: Path, - *, - version_id: str | None, - result: USMicroplexBuildResult, -) -> tuple[str, Path]: - return _allocate_versioned_output_dir_for_config( - output_root, - version_id=version_id, - config=result.config.to_dict(), - ) - - -def _allocate_versioned_output_dir_for_config( - output_root: Path, - *, - version_id: str | None, - config: dict[str, Any], -) -> tuple[str, Path]: - if version_id is not None: - output_dir = output_root / version_id - if output_dir.exists(): - raise FileExistsError( - f"Versioned artifact directory already exists: {output_dir}" - ) - return version_id, output_dir - - config_hash = _short_config_hash(config) - timestamp = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ") - base_version_id = f"{timestamp}-{config_hash}" - candidate_version_id = base_version_id - suffix = 2 - output_dir = output_root / candidate_version_id - while output_dir.exists(): - candidate_version_id = f"{base_version_id}-{suffix}" - output_dir = output_root / candidate_version_id - suffix += 1 - return candidate_version_id, output_dir - - -def _short_config_hash(config: dict[str, Any]) -> str: - import hashlib - import json - - payload = json.dumps(config, sort_keys=True, separators=(",", ":")) - return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:8] - - -def _initialize_versioned_stage_runtime_writer( - output_root: str | Path, - *, - version_id: str | None, - config: USMicroplexBuildConfig, - providers: list[SourceProvider], - queries: dict[str, SourceQuery] | None, - allow_stage_input_overrides: bool, - stage_input_overrides: tuple[USStageInputOverride, ...], -) -> tuple[str, Path, USStageRuntimeWriter]: - root = Path(output_root) - root.mkdir(parents=True, exist_ok=True) - resolved_version_id, output_dir = _allocate_versioned_output_dir_for_config( - root, - version_id=version_id, - config=config.to_dict(), - ) - provider_query_plan = _provider_query_plan(providers, queries) - writer = USStageRuntimeWriter( - output_dir, - manifest_payload={ - "created_at": datetime.now(UTC).isoformat(), - "config": config.to_dict(), - "artifacts": {"manifest": "manifest.json"}, - }, - allow_stage_input_overrides=allow_stage_input_overrides, - stage_input_overrides=stage_input_overrides, - ) - writer.start_stage( - "01_run_profile", - metadata={"version_id": resolved_version_id}, - ) - writer.complete_stage( - USRunProfileOutputs( - manifest=USArtifactRef( - key="manifest", - path="manifest.json", - format="json", - required=True, - assume_exists=True, - ), - resolved_config=config.to_dict(), - provider_query_plan=provider_query_plan, - diagnostics={ - "stage_summary": USDiagnosticOutput( - key="stage_summary", - description="Runtime run-profile summary.", - summary={ - "provider_names": provider_query_plan["provider_names"], - "version_id": resolved_version_id, - }, - ) - }, - ) - ) - return resolved_version_id, output_dir, writer - - -def _provider_query_plan( - providers: list[SourceProvider], - queries: dict[str, SourceQuery] | None, -) -> dict[str, Any]: - return { - "provider_names": [provider.descriptor.name for provider in providers], - "queries": { - key: _json_ready_query(query) for key, query in dict(queries or {}).items() - }, - } - - -def _json_ready_query(query: SourceQuery) -> dict[str, Any]: - if hasattr(query, "to_dict"): - payload = query.to_dict() - if isinstance(payload, dict): - return payload - if hasattr(query, "__dataclass_fields__"): - return _json_ready(asdict(query)) - return _json_ready(vars(query)) - - -def _json_ready(value: Any) -> Any: - if isinstance(value, Mapping): - return {str(key): _json_ready(item) for key, item in value.items()} - if isinstance(value, (tuple, list, set, frozenset)): - return [_json_ready(item) for item in value] - if isinstance(value, Path): - return str(value) - if hasattr(value, "value"): - return value.value - return value - - -def _registry_metric_value(entry: Any | None, metric: FrontierMetric) -> float | None: - if entry is None: - return None - return getattr(entry, metric, None) diff --git a/src/microplex_us/pipelines/versioned_artifacts.py b/src/microplex_us/pipelines/versioned_artifacts.py new file mode 100644 index 0000000..b1ea51b --- /dev/null +++ b/src/microplex_us/pipelines/versioned_artifacts.py @@ -0,0 +1,684 @@ +"""Versioned build-and-save entrypoints for US Microplex artifacts.""" + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import asdict, replace +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + +from microplex.core import SourceProvider, SourceQuery +from microplex.targets import TargetProvider + +from microplex_us.pipelines.artifact_types import ( + USMicroplexArtifactPaths, + USMicroplexVersionedBuildArtifacts, +) +from microplex_us.pipelines.registry import ( + FrontierMetric, + load_us_microplex_run_registry, + select_us_microplex_frontier_entry, +) +from microplex_us.pipelines.stage_run import ( + USArtifactRef, + USDiagnosticOutput, + USRunProfileOutputs, + USStageInputOverride, +) +from microplex_us.pipelines.stage_runtime import USStageRuntimeWriter +from microplex_us.pipelines.us import ( + USMicroplexBuildConfig, + USMicroplexBuildResult, + USMicroplexPipeline, + build_us_microplex, +) +from microplex_us.policyengine.harness import ( + PolicyEngineUSComparisonCache, + PolicyEngineUSHarnessSlice, +) + + +def _save_us_microplex_artifacts(*args: Any, **kwargs: Any) -> USMicroplexArtifactPaths: + from microplex_us.pipelines.artifacts import save_us_microplex_artifacts + + return save_us_microplex_artifacts(*args, **kwargs) + + +def _facade_pipeline_cls() -> type[USMicroplexPipeline]: + from microplex_us.pipelines import artifacts + + return artifacts.USMicroplexPipeline + + +def _finalize_via_facade( + build_result: USMicroplexBuildResult, + **kwargs: Any, +) -> USMicroplexVersionedBuildArtifacts: + from microplex_us.pipelines import artifacts + + finalize = artifacts._finalize_versioned_build_artifacts + if finalize is _finalize_versioned_build_artifacts: + return _finalize_versioned_build_artifacts(build_result, **kwargs) + return finalize(build_result, **kwargs) + + +def save_versioned_us_microplex_artifacts( + result: USMicroplexBuildResult, + output_root: str | Path, + *, + version_id: str | None = None, + policyengine_comparison_cache: PolicyEngineUSComparisonCache | None = None, + policyengine_target_provider: TargetProvider | None = None, + policyengine_baseline_dataset: str | Path | None = None, + policyengine_harness_slices: ( + tuple[PolicyEngineUSHarnessSlice, ...] | list[PolicyEngineUSHarnessSlice] | None + ) = None, + policyengine_harness_metadata: dict[str, Any] | None = None, + policyengine_us_data_repo: str | Path | None = None, + defer_policyengine_harness: bool = False, + require_policyengine_native_score: bool = False, + defer_policyengine_native_score: bool = False, + precomputed_policyengine_harness_payload: dict[str, Any] | None = None, + precomputed_policyengine_native_scores: dict[str, Any] | None = None, + run_registry_path: str | Path | None = None, + run_index_path: str | Path | None = None, + run_registry_metadata: dict[str, Any] | None = None, + enable_child_tax_unit_agi_drift: bool = False, + child_tax_unit_agi_drift_variables: tuple[str, ...] | None = None, + allow_stage_input_overrides: bool = False, + stage_input_overrides: tuple[USStageInputOverride, ...] = (), +) -> USMicroplexArtifactPaths: + """Persist a build under a stable versioned directory beneath one output root.""" + output_root = Path(output_root) + output_root.mkdir(parents=True, exist_ok=True) + resolved_version_id, output_dir = _allocate_versioned_output_dir( + output_root, + version_id=version_id, + result=result, + ) + paths = _save_us_microplex_artifacts( + result, + output_dir, + policyengine_comparison_cache=policyengine_comparison_cache, + policyengine_target_provider=policyengine_target_provider, + policyengine_baseline_dataset=policyengine_baseline_dataset, + policyengine_harness_slices=policyengine_harness_slices, + policyengine_harness_metadata=policyengine_harness_metadata, + policyengine_us_data_repo=policyengine_us_data_repo, + defer_policyengine_harness=defer_policyengine_harness, + require_policyengine_native_score=require_policyengine_native_score, + defer_policyengine_native_score=defer_policyengine_native_score, + precomputed_policyengine_harness_payload=precomputed_policyengine_harness_payload, + precomputed_policyengine_native_scores=precomputed_policyengine_native_scores, + run_registry_path=run_registry_path or output_root / "run_registry.jsonl", + run_index_path=run_index_path or output_root, + run_registry_metadata=run_registry_metadata, + enable_child_tax_unit_agi_drift=enable_child_tax_unit_agi_drift, + child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables, + allow_stage_input_overrides=allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, + ) + return replace(paths, version_id=resolved_version_id) + + +def build_and_save_versioned_us_microplex( + persons: Any, + households: Any, + output_root: str | Path, + *, + config: USMicroplexBuildConfig | None = None, + version_id: str | None = None, + frontier_metric: FrontierMetric = "candidate_composite_parity_loss", + policyengine_comparison_cache: PolicyEngineUSComparisonCache | None = None, + policyengine_target_provider: TargetProvider | None = None, + policyengine_baseline_dataset: str | Path | None = None, + policyengine_harness_slices: ( + tuple[PolicyEngineUSHarnessSlice, ...] | list[PolicyEngineUSHarnessSlice] | None + ) = None, + policyengine_harness_metadata: dict[str, Any] | None = None, + policyengine_us_data_repo: str | Path | None = None, + defer_policyengine_harness: bool = False, + require_policyengine_native_score: bool = False, + defer_policyengine_native_score: bool = False, + precomputed_policyengine_harness_payload: dict[str, Any] | None = None, + precomputed_policyengine_native_scores: dict[str, Any] | None = None, + run_registry_path: str | Path | None = None, + run_index_path: str | Path | None = None, + run_registry_metadata: dict[str, Any] | None = None, + enable_child_tax_unit_agi_drift: bool = False, + child_tax_unit_agi_drift_variables: tuple[str, ...] | None = None, + allow_stage_input_overrides: bool = False, + stage_input_overrides: tuple[USStageInputOverride, ...] = (), +) -> USMicroplexVersionedBuildArtifacts: + """Build a US microplex dataset, save a versioned bundle, and report frontier gap.""" + build_result = build_us_microplex(persons, households, config=config) + return save_versioned_us_microplex_build_result( + build_result, + output_root, + version_id=version_id, + frontier_metric=frontier_metric, + policyengine_comparison_cache=policyengine_comparison_cache, + policyengine_target_provider=policyengine_target_provider, + policyengine_baseline_dataset=policyengine_baseline_dataset, + policyengine_harness_slices=policyengine_harness_slices, + policyengine_harness_metadata=policyengine_harness_metadata, + policyengine_us_data_repo=policyengine_us_data_repo, + defer_policyengine_harness=defer_policyengine_harness, + require_policyengine_native_score=require_policyengine_native_score, + defer_policyengine_native_score=defer_policyengine_native_score, + precomputed_policyengine_harness_payload=precomputed_policyengine_harness_payload, + precomputed_policyengine_native_scores=precomputed_policyengine_native_scores, + run_registry_path=run_registry_path, + run_index_path=run_index_path, + run_registry_metadata=run_registry_metadata, + enable_child_tax_unit_agi_drift=enable_child_tax_unit_agi_drift, + child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables, + allow_stage_input_overrides=allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, + ) + + +def save_versioned_us_microplex_build_result( + build_result: USMicroplexBuildResult, + output_root: str | Path, + *, + version_id: str | None = None, + frontier_metric: FrontierMetric = "candidate_composite_parity_loss", + policyengine_comparison_cache: PolicyEngineUSComparisonCache | None = None, + policyengine_target_provider: TargetProvider | None = None, + policyengine_baseline_dataset: str | Path | None = None, + policyengine_harness_slices: ( + tuple[PolicyEngineUSHarnessSlice, ...] | list[PolicyEngineUSHarnessSlice] | None + ) = None, + policyengine_harness_metadata: dict[str, Any] | None = None, + policyengine_us_data_repo: str | Path | None = None, + defer_policyengine_harness: bool = False, + require_policyengine_native_score: bool = False, + defer_policyengine_native_score: bool = False, + precomputed_policyengine_harness_payload: dict[str, Any] | None = None, + precomputed_policyengine_native_scores: dict[str, Any] | None = None, + run_registry_path: str | Path | None = None, + run_index_path: str | Path | None = None, + run_registry_metadata: dict[str, Any] | None = None, + enable_child_tax_unit_agi_drift: bool = False, + child_tax_unit_agi_drift_variables: tuple[str, ...] | None = None, + allow_stage_input_overrides: bool = False, + stage_input_overrides: tuple[USStageInputOverride, ...] = (), +) -> USMicroplexVersionedBuildArtifacts: + """Save an already-built result as a versioned bundle and report frontier gap.""" + return _finalize_via_facade( + build_result, + output_root=output_root, + version_id=version_id, + frontier_metric=frontier_metric, + policyengine_comparison_cache=policyengine_comparison_cache, + policyengine_target_provider=policyengine_target_provider, + policyengine_baseline_dataset=policyengine_baseline_dataset, + policyengine_harness_slices=policyengine_harness_slices, + policyengine_harness_metadata=policyengine_harness_metadata, + policyengine_us_data_repo=policyengine_us_data_repo, + defer_policyengine_harness=defer_policyengine_harness, + require_policyengine_native_score=require_policyengine_native_score, + defer_policyengine_native_score=defer_policyengine_native_score, + precomputed_policyengine_harness_payload=precomputed_policyengine_harness_payload, + precomputed_policyengine_native_scores=precomputed_policyengine_native_scores, + run_registry_path=run_registry_path, + run_index_path=run_index_path, + run_registry_metadata=run_registry_metadata, + enable_child_tax_unit_agi_drift=enable_child_tax_unit_agi_drift, + child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables, + allow_stage_input_overrides=allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, + ) + + +def build_and_save_versioned_us_microplex_from_source_provider( + provider: SourceProvider, + output_root: str | Path, + *, + config: USMicroplexBuildConfig | None = None, + query: SourceQuery | None = None, + version_id: str | None = None, + frontier_metric: FrontierMetric = "candidate_composite_parity_loss", + policyengine_comparison_cache: PolicyEngineUSComparisonCache | None = None, + policyengine_target_provider: TargetProvider | None = None, + policyengine_baseline_dataset: str | Path | None = None, + policyengine_harness_slices: ( + tuple[PolicyEngineUSHarnessSlice, ...] | list[PolicyEngineUSHarnessSlice] | None + ) = None, + policyengine_harness_metadata: dict[str, Any] | None = None, + policyengine_us_data_repo: str | Path | None = None, + defer_policyengine_harness: bool = False, + require_policyengine_native_score: bool = False, + defer_policyengine_native_score: bool = False, + precomputed_policyengine_harness_payload: dict[str, Any] | None = None, + precomputed_policyengine_native_scores: dict[str, Any] | None = None, + run_registry_path: str | Path | None = None, + run_index_path: str | Path | None = None, + run_registry_metadata: dict[str, Any] | None = None, + enable_child_tax_unit_agi_drift: bool = False, + child_tax_unit_agi_drift_variables: tuple[str, ...] | None = None, + allow_stage_input_overrides: bool = False, + stage_input_overrides: tuple[USStageInputOverride, ...] = (), +) -> USMicroplexVersionedBuildArtifacts: + """Build from one source provider, save a versioned bundle, and report frontier gap.""" + pipeline = _facade_pipeline_cls()(config) + build_result = pipeline.build_from_source_provider(provider, query=query) + return _finalize_via_facade( + build_result, + output_root=output_root, + version_id=version_id, + frontier_metric=frontier_metric, + policyengine_comparison_cache=policyengine_comparison_cache, + policyengine_target_provider=policyengine_target_provider, + policyengine_baseline_dataset=policyengine_baseline_dataset, + policyengine_harness_slices=policyengine_harness_slices, + policyengine_harness_metadata=policyengine_harness_metadata, + policyengine_us_data_repo=policyengine_us_data_repo, + defer_policyengine_harness=defer_policyengine_harness, + require_policyengine_native_score=require_policyengine_native_score, + defer_policyengine_native_score=defer_policyengine_native_score, + precomputed_policyengine_harness_payload=precomputed_policyengine_harness_payload, + precomputed_policyengine_native_scores=precomputed_policyengine_native_scores, + run_registry_path=run_registry_path, + run_index_path=run_index_path, + run_registry_metadata=run_registry_metadata, + enable_child_tax_unit_agi_drift=enable_child_tax_unit_agi_drift, + child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables, + allow_stage_input_overrides=allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, + ) + + +def build_and_save_versioned_us_microplex_from_source_providers( + providers: list[SourceProvider], + output_root: str | Path, + *, + config: USMicroplexBuildConfig | None = None, + queries: dict[str, SourceQuery] | None = None, + version_id: str | None = None, + frontier_metric: FrontierMetric = "candidate_composite_parity_loss", + policyengine_comparison_cache: PolicyEngineUSComparisonCache | None = None, + policyengine_target_provider: TargetProvider | None = None, + policyengine_baseline_dataset: str | Path | None = None, + policyengine_harness_slices: ( + tuple[PolicyEngineUSHarnessSlice, ...] | list[PolicyEngineUSHarnessSlice] | None + ) = None, + policyengine_harness_metadata: dict[str, Any] | None = None, + policyengine_us_data_repo: str | Path | None = None, + defer_policyengine_harness: bool = False, + require_policyengine_native_score: bool = False, + defer_policyengine_native_score: bool = False, + precomputed_policyengine_harness_payload: dict[str, Any] | None = None, + precomputed_policyengine_native_scores: dict[str, Any] | None = None, + run_registry_path: str | Path | None = None, + run_index_path: str | Path | None = None, + run_registry_metadata: dict[str, Any] | None = None, + enable_child_tax_unit_agi_drift: bool = False, + child_tax_unit_agi_drift_variables: tuple[str, ...] | None = None, + allow_stage_input_overrides: bool = False, + stage_input_overrides: tuple[USStageInputOverride, ...] = (), +) -> USMicroplexVersionedBuildArtifacts: + """Build from multiple source providers, save a versioned bundle, and report frontier gap.""" + resolved_config = config or USMicroplexBuildConfig() + _resolved_version_id, preallocated_output_dir, stage_runtime_writer = ( + _initialize_versioned_stage_runtime_writer( + output_root, + version_id=version_id, + config=resolved_config, + providers=providers, + queries=queries, + allow_stage_input_overrides=allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, + ) + ) + pipeline = _facade_pipeline_cls()( + resolved_config, + stage_runtime_writer=stage_runtime_writer, + ) + build_result = pipeline.build_from_source_providers(providers, queries=queries) + return _finalize_via_facade( + build_result, + output_root=output_root, + version_id=version_id, + preallocated_output_dir=preallocated_output_dir, + frontier_metric=frontier_metric, + policyengine_comparison_cache=policyengine_comparison_cache, + policyengine_target_provider=policyengine_target_provider, + policyengine_baseline_dataset=policyengine_baseline_dataset, + policyengine_harness_slices=policyengine_harness_slices, + policyengine_harness_metadata=policyengine_harness_metadata, + policyengine_us_data_repo=policyengine_us_data_repo, + defer_policyengine_harness=defer_policyengine_harness, + require_policyengine_native_score=require_policyengine_native_score, + defer_policyengine_native_score=defer_policyengine_native_score, + precomputed_policyengine_harness_payload=precomputed_policyengine_harness_payload, + precomputed_policyengine_native_scores=precomputed_policyengine_native_scores, + run_registry_path=run_registry_path, + run_index_path=run_index_path, + run_registry_metadata=run_registry_metadata, + enable_child_tax_unit_agi_drift=enable_child_tax_unit_agi_drift, + child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables, + allow_stage_input_overrides=allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, + stage_runtime_writer=stage_runtime_writer, + ) + + +def build_and_save_versioned_us_microplex_from_data_dir( + data_dir: str | Path, + output_root: str | Path, + *, + config: USMicroplexBuildConfig | None = None, + version_id: str | None = None, + frontier_metric: FrontierMetric = "candidate_composite_parity_loss", + policyengine_comparison_cache: PolicyEngineUSComparisonCache | None = None, + policyengine_target_provider: TargetProvider | None = None, + policyengine_baseline_dataset: str | Path | None = None, + policyengine_harness_slices: ( + tuple[PolicyEngineUSHarnessSlice, ...] | list[PolicyEngineUSHarnessSlice] | None + ) = None, + policyengine_harness_metadata: dict[str, Any] | None = None, + policyengine_us_data_repo: str | Path | None = None, + defer_policyengine_harness: bool = False, + require_policyengine_native_score: bool = False, + defer_policyengine_native_score: bool = False, + precomputed_policyengine_harness_payload: dict[str, Any] | None = None, + precomputed_policyengine_native_scores: dict[str, Any] | None = None, + run_registry_path: str | Path | None = None, + run_index_path: str | Path | None = None, + run_registry_metadata: dict[str, Any] | None = None, + enable_child_tax_unit_agi_drift: bool = False, + child_tax_unit_agi_drift_variables: tuple[str, ...] | None = None, + allow_stage_input_overrides: bool = False, + stage_input_overrides: tuple[USStageInputOverride, ...] = (), +) -> USMicroplexVersionedBuildArtifacts: + """Build from a CPS-style parquet directory, save a versioned bundle, and report frontier gap.""" + pipeline = _facade_pipeline_cls()(config) + build_result = pipeline.build_from_data_dir(data_dir) + return _finalize_via_facade( + build_result, + output_root=output_root, + version_id=version_id, + frontier_metric=frontier_metric, + policyengine_comparison_cache=policyengine_comparison_cache, + policyengine_target_provider=policyengine_target_provider, + policyengine_baseline_dataset=policyengine_baseline_dataset, + policyengine_harness_slices=policyengine_harness_slices, + policyengine_harness_metadata=policyengine_harness_metadata, + policyengine_us_data_repo=policyengine_us_data_repo, + defer_policyengine_harness=defer_policyengine_harness, + require_policyengine_native_score=require_policyengine_native_score, + defer_policyengine_native_score=defer_policyengine_native_score, + precomputed_policyengine_harness_payload=precomputed_policyengine_harness_payload, + precomputed_policyengine_native_scores=precomputed_policyengine_native_scores, + run_registry_path=run_registry_path, + run_index_path=run_index_path, + run_registry_metadata=run_registry_metadata, + enable_child_tax_unit_agi_drift=enable_child_tax_unit_agi_drift, + child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables, + allow_stage_input_overrides=allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, + ) + + +def _finalize_versioned_build_artifacts( + build_result: USMicroplexBuildResult, + *, + output_root: str | Path, + version_id: str | None, + preallocated_output_dir: str | Path | None = None, + frontier_metric: FrontierMetric, + policyengine_comparison_cache: PolicyEngineUSComparisonCache | None, + policyengine_target_provider: TargetProvider | None, + policyengine_baseline_dataset: str | Path | None, + policyengine_harness_slices: ( + tuple[PolicyEngineUSHarnessSlice, ...] | list[PolicyEngineUSHarnessSlice] | None + ), + policyengine_harness_metadata: dict[str, Any] | None, + policyengine_us_data_repo: str | Path | None, + defer_policyengine_harness: bool, + require_policyengine_native_score: bool, + defer_policyengine_native_score: bool, + precomputed_policyengine_harness_payload: dict[str, Any] | None, + precomputed_policyengine_native_scores: dict[str, Any] | None, + run_registry_path: str | Path | None, + run_index_path: str | Path | None, + run_registry_metadata: dict[str, Any] | None, + enable_child_tax_unit_agi_drift: bool = False, + child_tax_unit_agi_drift_variables: tuple[str, ...] | None = None, + allow_stage_input_overrides: bool = False, + stage_input_overrides: tuple[USStageInputOverride, ...] = (), + stage_runtime_writer: USStageRuntimeWriter | None = None, +) -> USMicroplexVersionedBuildArtifacts: + if preallocated_output_dir is not None: + output_root_path = Path(output_root) + output_dir = Path(preallocated_output_dir) + artifact_paths = _save_us_microplex_artifacts( + build_result, + output_dir, + policyengine_comparison_cache=policyengine_comparison_cache, + policyengine_target_provider=policyengine_target_provider, + policyengine_baseline_dataset=policyengine_baseline_dataset, + policyengine_harness_slices=policyengine_harness_slices, + policyengine_harness_metadata=policyengine_harness_metadata, + policyengine_us_data_repo=policyengine_us_data_repo, + defer_policyengine_harness=defer_policyengine_harness, + require_policyengine_native_score=require_policyengine_native_score, + defer_policyengine_native_score=defer_policyengine_native_score, + precomputed_policyengine_harness_payload=precomputed_policyengine_harness_payload, + precomputed_policyengine_native_scores=precomputed_policyengine_native_scores, + run_registry_path=run_registry_path + or output_root_path / "run_registry.jsonl", + run_index_path=run_index_path or output_root_path, + run_registry_metadata=run_registry_metadata, + enable_child_tax_unit_agi_drift=enable_child_tax_unit_agi_drift, + child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables, + allow_stage_input_overrides=allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, + stage_runtime_writer=stage_runtime_writer, + ) + artifact_paths = replace(artifact_paths, version_id=output_dir.name) + else: + artifact_paths = save_versioned_us_microplex_artifacts( + build_result, + output_root, + version_id=version_id, + policyengine_comparison_cache=policyengine_comparison_cache, + policyengine_target_provider=policyengine_target_provider, + policyengine_baseline_dataset=policyengine_baseline_dataset, + policyengine_harness_slices=policyengine_harness_slices, + policyengine_harness_metadata=policyengine_harness_metadata, + policyengine_us_data_repo=policyengine_us_data_repo, + defer_policyengine_harness=defer_policyengine_harness, + require_policyengine_native_score=require_policyengine_native_score, + defer_policyengine_native_score=defer_policyengine_native_score, + precomputed_policyengine_harness_payload=precomputed_policyengine_harness_payload, + precomputed_policyengine_native_scores=precomputed_policyengine_native_scores, + run_registry_path=run_registry_path, + run_index_path=run_index_path, + run_registry_metadata=run_registry_metadata, + enable_child_tax_unit_agi_drift=enable_child_tax_unit_agi_drift, + child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables, + allow_stage_input_overrides=allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, + stage_runtime_writer=stage_runtime_writer, + ) + current_entry = None + frontier_entry = None + frontier_delta = None + if ( + artifact_paths.run_registry is not None + and artifact_paths.version_id is not None + ): + registry_entries = load_us_microplex_run_registry(artifact_paths.run_registry) + current_entry = next( + ( + entry + for entry in reversed(registry_entries) + if entry.artifact_id == artifact_paths.version_id + ), + None, + ) + frontier_entry = select_us_microplex_frontier_entry( + artifact_paths.run_registry, + metric=frontier_metric, + ) + current_value = _registry_metric_value(current_entry, frontier_metric) + frontier_value = _registry_metric_value(frontier_entry, frontier_metric) + if current_value is not None and frontier_value is not None: + frontier_delta = current_value - frontier_value + return USMicroplexVersionedBuildArtifacts( + build_result=build_result, + artifact_paths=artifact_paths, + current_entry=current_entry, + frontier_entry=frontier_entry, + frontier_delta=frontier_delta, + ) + + +def _allocate_versioned_output_dir( + output_root: Path, + *, + version_id: str | None, + result: USMicroplexBuildResult, +) -> tuple[str, Path]: + return _allocate_versioned_output_dir_for_config( + output_root, + version_id=version_id, + config=result.config.to_dict(), + ) + + +def _allocate_versioned_output_dir_for_config( + output_root: Path, + *, + version_id: str | None, + config: dict[str, Any], +) -> tuple[str, Path]: + if version_id is not None: + output_dir = output_root / version_id + if output_dir.exists(): + raise FileExistsError( + f"Versioned artifact directory already exists: {output_dir}" + ) + return version_id, output_dir + + config_hash = _short_config_hash(config) + timestamp = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ") + base_version_id = f"{timestamp}-{config_hash}" + candidate_version_id = base_version_id + suffix = 2 + output_dir = output_root / candidate_version_id + while output_dir.exists(): + candidate_version_id = f"{base_version_id}-{suffix}" + output_dir = output_root / candidate_version_id + suffix += 1 + return candidate_version_id, output_dir + + +def _short_config_hash(config: dict[str, Any]) -> str: + import hashlib + import json + + payload = json.dumps(config, sort_keys=True, separators=(",", ":")) + return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:8] + + +def _initialize_versioned_stage_runtime_writer( + output_root: str | Path, + *, + version_id: str | None, + config: USMicroplexBuildConfig, + providers: list[SourceProvider], + queries: dict[str, SourceQuery] | None, + allow_stage_input_overrides: bool, + stage_input_overrides: tuple[USStageInputOverride, ...], +) -> tuple[str, Path, USStageRuntimeWriter]: + root = Path(output_root) + root.mkdir(parents=True, exist_ok=True) + resolved_version_id, output_dir = _allocate_versioned_output_dir_for_config( + root, + version_id=version_id, + config=config.to_dict(), + ) + provider_query_plan = _provider_query_plan(providers, queries) + writer = USStageRuntimeWriter( + output_dir, + manifest_payload={ + "created_at": datetime.now(UTC).isoformat(), + "config": config.to_dict(), + "artifacts": {"manifest": "manifest.json"}, + }, + allow_stage_input_overrides=allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, + ) + writer.start_stage( + "01_run_profile", + metadata={"version_id": resolved_version_id}, + ) + writer.complete_stage( + USRunProfileOutputs( + manifest=USArtifactRef( + key="manifest", + path="manifest.json", + format="json", + required=True, + assume_exists=True, + ), + resolved_config=config.to_dict(), + provider_query_plan=provider_query_plan, + diagnostics={ + "stage_summary": USDiagnosticOutput( + key="stage_summary", + description="Runtime run-profile summary.", + summary={ + "provider_names": provider_query_plan["provider_names"], + "version_id": resolved_version_id, + }, + ) + }, + ) + ) + return resolved_version_id, output_dir, writer + + +def _provider_query_plan( + providers: list[SourceProvider], + queries: dict[str, SourceQuery] | None, +) -> dict[str, Any]: + return { + "provider_names": [provider.descriptor.name for provider in providers], + "queries": { + key: _json_ready_query(query) for key, query in dict(queries or {}).items() + }, + } + + +def _json_ready_query(query: SourceQuery) -> dict[str, Any]: + if hasattr(query, "to_dict"): + payload = query.to_dict() + if isinstance(payload, dict): + return payload + if hasattr(query, "__dataclass_fields__"): + return _json_ready(asdict(query)) + return _json_ready(vars(query)) + + +def _json_ready(value: Any) -> Any: + if isinstance(value, Mapping): + return {str(key): _json_ready(item) for key, item in value.items()} + if isinstance(value, (tuple, list, set, frozenset)): + return [_json_ready(item) for item in value] + if isinstance(value, Path): + return str(value) + if hasattr(value, "value"): + return value.value + return value + + +def _registry_metric_value(entry: Any | None, metric: FrontierMetric) -> float | None: + if entry is None: + return None + return getattr(entry, metric, None)