diff --git a/docs/api.md b/docs/api.md index 17153fd..0fb6cfa 100644 --- a/docs/api.md +++ b/docs/api.md @@ -16,6 +16,66 @@ :undoc-members: ``` +### Stage Manifest Internals + +```{eval-rst} +.. automodule:: microplex_us.pipelines.stage_manifest_types + :members: + :undoc-members: + +.. automodule:: microplex_us.pipelines.stage_manifest_builder + :members: + :undoc-members: + +.. automodule:: microplex_us.pipelines.stage_manifest_io + :members: + :undoc-members: + +.. automodule:: microplex_us.pipelines.stage_status + :members: + :undoc-members: + +.. automodule:: microplex_us.pipelines.stage_metrics + :members: + :undoc-members: + +.. automodule:: microplex_us.pipelines.stage_data_flow + :members: + :undoc-members: + +.. automodule:: microplex_us.pipelines.stage_policyengine_artifacts + :members: + :undoc-members: + +.. automodule:: microplex_us.pipelines.stage_validation_evidence + :members: + :undoc-members: +``` + +## Stage artifacts + +```{eval-rst} +.. automodule:: microplex_us.pipelines.stage_artifacts + :members: + :undoc-members: +``` + +## Conditional readiness + +```{eval-rst} +.. automodule:: microplex_us.pipelines.stage_readiness + :members: + :undoc-members: +``` + +## Stage run writer + +```{eval-rst} +.. automodule:: microplex_us.pipelines.stage_run + :members: + :undoc-members: +``` + ## Artifact helpers ```{eval-rst} diff --git a/docs/stage-contracts.md b/docs/stage-contracts.md index 3dea438..33a6e31 100644 --- a/docs/stage-contracts.md +++ b/docs/stage-contracts.md @@ -5,11 +5,31 @@ The canonical stage registry lives in expected inputs, outputs, artifacts, diagnostics, validation placeholders, and resume mode. -Saved artifact bundles now include a `stage_manifest.json` sidecar. This file is -the machine-readable saved-run overlay for the stage taxonomy. It records the +Saved artifact bundles now include a `stage_manifest.json` derived artifact. This +file is the machine-readable saved-run overlay for the stage taxonomy. 
It records the canonical stages, status for the current run, artifact paths, diagnostics owned by each stage, and the current resume posture. +Each saved bundle also includes typed per-stage output manifests at +`stage_artifacts/manifests/{stage_id}.json`. These manifests are written through +`USStageRunWriter`, which validates each stage as a whole instead of updating +individual manifest keys directly. The manifest files live outside each stage's +payload directory so they do not change the content hash of reloadable stage +artifacts. + +The registry exposes two seam layers: + +- `inputs` and `outputs` are structured stage resources. They identify artifact, + config, manifest, runtime, and external-data dependencies with explicit keys. +- `consumes` and `produces` remain short human-readable summaries for diagrams + and documentation. + +Artifact `required` means required for a complete canonical saved bundle. It is +separate from `resume_role`, which says whether an existing artifact is useful +for diagnostics, manual replay, manual resume, or post-artifact validation. +Partial bundles can therefore still expose a valid replay boundary while the +manifest honestly reports that the complete publication bundle is incomplete. + ## Legacy run-contract IDs Older run-contract summaries and dashboard payloads used operational labels @@ -34,7 +54,7 @@ boundary artifacts where the pipeline already has stable outputs: - Stage 7: `calibrated_data.parquet`, `targets.json`, and `stage_artifacts/07_calibration/calibration_summary.json` - Stage 8: `policyengine_us.h5` -- Stage 9: validation and benchmark evidence sidecars +- Stage 9: validation and benchmark evidence artifacts The Stage 4 artifact is the scaffold-projected seed before donor integration. It is a diagnostic and manual replay boundary, not an automatic conditional resume @@ -44,6 +64,21 @@ Conditional execution is intentionally not implemented yet. 
The stage manifest and artifacts are designed to make that possible later without changing the saved-run contract again. +## Artifact inventory and readiness + +Saved bundles also expose two Stage 8 diagnostic artifacts: + +- `stage_artifacts/artifact_inventory.json` lists canonical stage artifacts, + whether each path exists, whether it was referenced by the run manifest, its + resume role, size/file counts, and content hashes. +- `stage_artifacts/conditional_readiness.json` summarizes which stage outputs + are available for manual replay, manual resume, post-artifact evidence, or + diagnostics only. + +These reports are advisory. They do not skip or rerun stages, and they do not +silently accept stale artifacts. If a requested config is supplied to the +readiness builder, config mismatches are reported as `must_rerun`. + ## Validation hooks Each stage contract includes concise validation descriptors. These describe the diff --git a/src/microplex_us/pipelines/__init__.py b/src/microplex_us/pipelines/__init__.py index d9b64ba..353faa9 100644 --- a/src/microplex_us/pipelines/__init__.py +++ b/src/microplex_us/pipelines/__init__.py @@ -257,12 +257,38 @@ def _exports(module: str, names: tuple[str, ...]) -> dict[str, str]: ( "USPipelineStageContract", "USStageArtifactContract", + "USStageResourceContract", "USStageValidationContract", + "config_keys_for_us_pipeline_stage", "default_us_pipeline_stage_contracts", + "get_us_stage_artifact_contract", "get_us_pipeline_stage_contract", + "resolve_us_stage_artifact_contract_path", "serialize_us_pipeline_stage_contracts", ), ), + **_exports( + "microplex_us.pipelines.stage_artifacts", + ( + "USCalibratedStageArtifacts", + "USCandidateCalibrationReplayArtifacts", + "USCandidateStageArtifacts", + "USDatasetAssemblyArtifacts", + "USPolicyEngineEntityStageArtifacts", + "USSeedScaffoldStageArtifacts", + "USStageArtifactInventory", + "build_us_stage_artifact_inventory", + "load_us_calibrated_stage_artifacts", + 
"load_us_candidate_calibration_replay_artifacts", + "load_us_candidate_stage_artifacts", + "load_us_dataset_assembly_artifacts", + "load_us_policyengine_entity_stage_artifacts", + "load_us_seed_scaffold_stage_artifacts", + "load_us_stage_artifact_inventory", + "resolve_us_stage_artifact_path_checked", + "write_us_stage_artifact_inventory", + ), + ), **_exports( "microplex_us.pipelines.stage_manifest", ( @@ -285,6 +311,41 @@ def _exports(module: str, names: tuple[str, ...]) -> dict[str, str]: "write_us_validation_evidence_manifest", ), ), + **_exports( + "microplex_us.pipelines.stage_readiness", + ( + "USConditionalReadinessReport", + "USConditionalReadinessStageRecord", + "build_us_conditional_readiness_report", + "build_us_stage_reuse_key", + "load_us_conditional_readiness_report", + "write_us_conditional_readiness_report", + ), + ), + **_exports( + "microplex_us.pipelines.stage_run", + ( + "USAuxiliaryArtifact", + "USArtifactRef", + "USCalibrationOutputs", + "USDatasetAssemblyOutputs", + "USDiagnosticOutput", + "USDonorSynthesisOutputs", + "USPolicyEngineEntityOutputs", + "USRunProfileOutputs", + "USSeedScaffoldOutputs", + "USSourceLoadingOutputs", + "USSourcePlanningOutputs", + "USStageInputOverride", + "USStageOutputManifest", + "USStageRunWriter", + "USValidationBenchmarkingOutputs", + "build_us_stage_output_manifests_from_artifact_manifest", + "parse_us_stage_input_override", + "resolve_us_manifest_or_contract_artifact_path", + "write_us_stage_run_manifests_from_artifact_manifest", + ), + ), **_exports( "microplex_us.pipelines.summarize_pe_native_family_drilldown", ( diff --git a/src/microplex_us/pipelines/artifacts.py b/src/microplex_us/pipelines/artifacts.py index 11f3fb8..39c2871 100644 --- a/src/microplex_us/pipelines/artifacts.py +++ b/src/microplex_us/pipelines/artifacts.py @@ -3,7 +3,7 @@ from __future__ import annotations import json -from dataclasses import dataclass +from dataclasses import dataclass, replace from datetime import UTC, datetime from 
importlib.metadata import PackageNotFoundError, version from pathlib import Path @@ -24,9 +24,6 @@ write_capital_gains_lots_sqlite, ) from microplex_us.data_sources.forbes import ForbesFixedSpineConfig -from microplex_us.pipelines.data_flow_snapshot import ( - write_us_microplex_data_flow_snapshot, -) from microplex_us.pipelines.index_db import ( append_us_microplex_run_index_entry, ) @@ -40,11 +37,15 @@ load_us_microplex_run_registry, select_us_microplex_frontier_entry, ) +from microplex_us.pipelines.stage_contracts import ( + resolve_us_stage_artifact_contract_path, +) from microplex_us.pipelines.stage_manifest import ( - US_STAGE_ARTIFACT_ROOT, write_us_policyengine_entity_stage_artifact, - write_us_stage_manifest, - write_us_validation_evidence_manifest, +) +from microplex_us.pipelines.stage_run import ( + USStageInputOverride, + write_us_stage_run_manifests_from_artifact_manifest, ) from microplex_us.pipelines.summarize_child_tax_unit_agi_drift import ( DEFAULT_VARIABLES as DEFAULT_CHILD_TAX_UNIT_AGI_DRIFT_VARIABLES, @@ -88,6 +89,8 @@ class USMicroplexArtifactPaths: policyengine_dataset: Path | None = None data_flow_snapshot: Path | None = None stage_manifest: Path | None = None + artifact_inventory: Path | None = None + conditional_readiness: Path | None = None source_plan: Path | None = None policyengine_entity_tables: Path | None = None calibration_summary: Path | None = None @@ -266,7 +269,9 @@ def _resolve_saved_artifact_file( artifacts = dict(manifest.get("artifacts", {})) filename = artifacts.get(artifact_key) if not filename: - filename = "targets.json" if artifact_key == "targets" else f"{artifact_key}.parquet" + filename = ( + "targets.json" if artifact_key == "targets" else f"{artifact_key}.parquet" + ) path = Path(filename) if not path.is_absolute(): path = artifact_root / path @@ -320,9 +325,7 @@ def _write_us_source_plan_artifact( "donorAuthoritativeOverrideVariables": list( synthesis.get("donor_authoritative_override_variables", ()) ), - 
"donorExcludedVariables": list( - synthesis.get("donor_excluded_variables", ()) - ), + "donorExcludedVariables": list(synthesis.get("donor_excluded_variables", ())), } if result.fusion_plan is not None: payload["fusionPlan"] = { @@ -357,8 +360,7 @@ def _build_source_weight_diagnostics( entity: { "count": fixed_spine_entry.get(f"{prefix}_count", 0), "weight_sum": fixed_spine_entry.get(f"{prefix}_weight_sum", 0.0), - "available": fixed_spine_entry.get(f"{prefix}_weight_sum") - is not None, + "available": fixed_spine_entry.get(f"{prefix}_weight_sum") is not None, } for entity, prefix in _SOURCE_DIAGNOSTIC_ENTITY_PREFIXES.items() } @@ -772,7 +774,11 @@ def _maybe_write_capital_gains_lot_artifact( "lot_rows": int(len(lots)), } ) - path = output_dir / "capital_gains_lots.sqlite" + path = resolve_us_stage_artifact_contract_path( + output_dir, + "08_dataset_assembly", + "capital_gains_lots", + ) write_capital_gains_lots_sqlite(lots, path, metadata=metadata) return path, { "enabled": True, @@ -808,40 +814,115 @@ def save_us_microplex_artifacts( run_registry_metadata: dict[str, Any] | None = None, enable_child_tax_unit_agi_drift: bool = False, child_tax_unit_agi_drift_variables: tuple[str, ...] | None = None, + allow_stage_input_overrides: bool = False, + stage_input_overrides: tuple[USStageInputOverride, ...] 
= (), ) -> USMicroplexArtifactPaths: """Persist a build result as a reproducible artifact bundle.""" output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) - seed_data_path = output_dir / "seed_data.parquet" - synthetic_data_path = output_dir / "synthetic_data.parquet" - calibrated_data_path = output_dir / "calibrated_data.parquet" - targets_path = output_dir / "targets.json" - manifest_path = output_dir / "manifest.json" - source_weight_diagnostics_path = output_dir / "source_weight_diagnostics.json" - synthesizer_path = output_dir / "synthesizer.pt" if result.synthesizer else None + seed_data_path = resolve_us_stage_artifact_contract_path( + output_dir, + "05_donor_integration_synthesis", + "seed_data", + ) + synthetic_data_path = resolve_us_stage_artifact_contract_path( + output_dir, + "05_donor_integration_synthesis", + "synthetic_data", + ) + calibrated_data_path = resolve_us_stage_artifact_contract_path( + output_dir, + "07_calibration", + "calibrated_data", + ) + targets_path = resolve_us_stage_artifact_contract_path( + output_dir, + "07_calibration", + "targets", + ) + manifest_path = resolve_us_stage_artifact_contract_path( + output_dir, + "01_run_profile", + "manifest", + ) + source_weight_diagnostics_path = resolve_us_stage_artifact_contract_path( + output_dir, + "05_donor_integration_synthesis", + "source_weight_diagnostics", + ) + synthesizer_path = ( + resolve_us_stage_artifact_contract_path( + output_dir, + "05_donor_integration_synthesis", + "synthesizer", + ) + if result.synthesizer + else None + ) policyengine_dataset_path = ( - output_dir / "policyengine_us.h5" if result.policyengine_tables is not None else None + resolve_us_stage_artifact_contract_path( + output_dir, + "08_dataset_assembly", + "policyengine_dataset", + ) + if result.policyengine_tables is not None + else None + ) + data_flow_snapshot_path = resolve_us_stage_artifact_contract_path( + output_dir, + "08_dataset_assembly", + "data_flow_snapshot", + ) + 
stage_manifest_path = resolve_us_stage_artifact_contract_path( + output_dir, + "08_dataset_assembly", + "stage_manifest", + ) + artifact_inventory_path = resolve_us_stage_artifact_contract_path( + output_dir, + "08_dataset_assembly", + "artifact_inventory", + ) + conditional_readiness_path = resolve_us_stage_artifact_contract_path( + output_dir, + "08_dataset_assembly", + "conditional_readiness", + ) + source_plan_path = resolve_us_stage_artifact_contract_path( + output_dir, + "03_source_planning", + "source_plan", ) - data_flow_snapshot_path = output_dir / "data_flow_snapshot.json" - stage_manifest_path = output_dir / "stage_manifest.json" - stage_artifact_root = output_dir / US_STAGE_ARTIFACT_ROOT - source_plan_path = stage_artifact_root / "03_source_planning" / "source_plan.json" scaffold_seed_data_path = ( - stage_artifact_root / "04_seed_scaffold" / "scaffold_seed_data.parquet" + resolve_us_stage_artifact_contract_path( + output_dir, + "04_seed_scaffold", + "scaffold_seed_data", + ) if result.scaffold_seed_data is not None else None ) policyengine_entity_tables_path = ( - stage_artifact_root / "06_policyengine_entities" / "metadata.json" + resolve_us_stage_artifact_contract_path( + output_dir, + "06_policyengine_entities", + "policyengine_entity_tables", + ) if result.policyengine_tables is not None else None ) - calibration_summary_path = ( - stage_artifact_root / "07_calibration" / "calibration_summary.json" + calibration_summary_path = resolve_us_stage_artifact_contract_path( + output_dir, + "07_calibration", + "calibration_summary", ) validation_evidence_path = ( - stage_artifact_root / "09_validation_benchmarking" / "evidence_manifest.json" + resolve_us_stage_artifact_contract_path( + output_dir, + "09_validation_benchmarking", + "validation_evidence", + ) if result.policyengine_tables is not None else None ) @@ -917,7 +998,11 @@ def save_us_microplex_artifacts( ) if precomputed_policyengine_harness_payload is not None: harness_payload = 
dict(precomputed_policyengine_harness_payload) - policyengine_harness_path = output_dir / "policyengine_harness.json" + policyengine_harness_path = resolve_us_stage_artifact_contract_path( + output_dir, + "09_validation_benchmarking", + "policyengine_harness", + ) policyengine_harness_path.write_text( json.dumps(harness_payload, indent=2, sort_keys=True) ) @@ -942,13 +1027,21 @@ def save_us_microplex_artifacts( metadata=resolved_harness_metadata, cache=policyengine_comparison_cache, ) - policyengine_harness_path = output_dir / "policyengine_harness.json" + policyengine_harness_path = resolve_us_stage_artifact_contract_path( + output_dir, + "09_validation_benchmarking", + "policyengine_harness", + ) harness_run.save(policyengine_harness_path) harness_payload = harness_run.to_dict() harness_summary = harness_payload["summary"] if native_scores_payload is not None: - policyengine_native_scores_path = output_dir / "policyengine_native_scores.json" + policyengine_native_scores_path = resolve_us_stage_artifact_contract_path( + output_dir, + "09_validation_benchmarking", + "policyengine_native_scores", + ) policyengine_native_scores_path.write_text( json.dumps(native_scores_payload, indent=2, sort_keys=True) ) @@ -964,8 +1057,10 @@ def save_us_microplex_artifacts( period=result.config.policyengine_dataset_year or 2024, policyengine_us_data_repo=policyengine_us_data_repo, ) - policyengine_native_scores_path = ( - output_dir / "policyengine_native_scores.json" + policyengine_native_scores_path = resolve_us_stage_artifact_contract_path( + output_dir, + "09_validation_benchmarking", + "policyengine_native_scores", ) policyengine_native_scores_path.write_text( json.dumps(native_scores_payload, indent=2, sort_keys=True) @@ -978,7 +1073,11 @@ def save_us_microplex_artifacts( child_tax_unit_agi_drift_summary: dict[str, Any] | None = None if enable_child_tax_unit_agi_drift: try: - drift_path = output_dir / "child_tax_unit_agi_drift.json" + drift_path = 
resolve_us_stage_artifact_contract_path( + output_dir, + "09_validation_benchmarking", + "child_tax_unit_agi_drift", + ) variables = ( child_tax_unit_agi_drift_variables or DEFAULT_CHILD_TAX_UNIT_AGI_DRIFT_VARIABLES @@ -987,14 +1086,14 @@ def save_us_microplex_artifacts( output_dir, variables=variables, ) - drift_path.write_text( - json.dumps(payload, indent=2, sort_keys=True) - ) + drift_path.write_text(json.dumps(payload, indent=2, sort_keys=True)) child_tax_unit_agi_drift_path = drift_path - child_tax_unit_agi_drift_summary = _summarize_child_tax_unit_agi_drift_ratios( - payload, - stage="calibrated", - variables=variables, + child_tax_unit_agi_drift_summary = ( + _summarize_child_tax_unit_agi_drift_ratios( + payload, + stage="calibrated", + variables=variables, + ) ) except Exception as exc: # pragma: no cover - diagnostic best-effort child_tax_unit_agi_drift_summary = { @@ -1045,6 +1144,10 @@ def save_us_microplex_artifacts( ), "data_flow_snapshot": data_flow_snapshot_path.name, "stage_manifest": stage_manifest_path.name, + "artifact_inventory": str(artifact_inventory_path.relative_to(output_dir)), + "conditional_readiness": str( + conditional_readiness_path.relative_to(output_dir) + ), "validation_evidence": ( str(validation_evidence_path.relative_to(output_dir)) if validation_evidence_path is not None @@ -1076,18 +1179,20 @@ def save_us_microplex_artifacts( child_tax_unit_agi_drift_path.name ) if child_tax_unit_agi_drift_summary is not None: - manifest.setdefault("diagnostics", {})[ - "child_tax_unit_agi_drift" - ] = child_tax_unit_agi_drift_summary + manifest.setdefault("diagnostics", {})["child_tax_unit_agi_drift"] = ( + child_tax_unit_agi_drift_summary + ) if capital_gains_lots_summary is not None: - manifest.setdefault("diagnostics", {})[ - "capital_gains_lots" - ] = capital_gains_lots_summary + manifest.setdefault("diagnostics", {})["capital_gains_lots"] = ( + capital_gains_lots_summary + ) manifest.setdefault("diagnostics", 
{})["source_weight_diagnostics"] = dict( source_weight_diagnostics_payload.get("summary", {}) ) if harness_summary is not None or native_scores_payload is not None: - resolved_run_registry_path = Path(run_registry_path or output_dir.parent / "run_registry.jsonl") + resolved_run_registry_path = Path( + run_registry_path or output_dir.parent / "run_registry.jsonl" + ) run_entry = build_us_microplex_run_registry_entry( artifact_dir=output_dir, manifest_path=manifest_path, @@ -1122,22 +1227,11 @@ def save_us_microplex_artifacts( "path": str(resolved_run_index_path), "artifact_id": recorded_entry.artifact_id, } - if validation_evidence_path is not None: - write_us_validation_evidence_manifest( - output_dir, - validation_evidence_path, - manifest_payload=manifest, - ) - write_us_microplex_data_flow_snapshot( - output_dir, - data_flow_snapshot_path, - manifest_payload=manifest, - assume_existing_stage_artifact_keys=("stage_manifest",), - ) - write_us_stage_manifest( + manifest = write_us_stage_run_manifests_from_artifact_manifest( output_dir, - stage_manifest_path, - manifest_payload=manifest, + manifest, + allow_stage_input_overrides=allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, ) assert_valid_benchmark_artifact_manifest( manifest, @@ -1169,7 +1263,6 @@ def save_us_microplex_artifacts( else () ), ) - _write_json_atomically(manifest_path, manifest) return USMicroplexArtifactPaths( output_dir=output_dir, @@ -1184,6 +1277,8 @@ def save_us_microplex_artifacts( policyengine_dataset=policyengine_dataset_path, data_flow_snapshot=data_flow_snapshot_path, stage_manifest=stage_manifest_path, + artifact_inventory=artifact_inventory_path, + conditional_readiness=conditional_readiness_path, source_plan=source_plan_path, policyengine_entity_tables=policyengine_entity_tables_path, calibration_summary=calibration_summary_path, @@ -1222,6 +1317,8 @@ def save_versioned_us_microplex_artifacts( run_registry_metadata: dict[str, Any] | None = None, 
enable_child_tax_unit_agi_drift: bool = False, child_tax_unit_agi_drift_variables: tuple[str, ...] | None = None, + allow_stage_input_overrides: bool = False, + stage_input_overrides: tuple[USStageInputOverride, ...] = (), ) -> USMicroplexArtifactPaths: """Persist a build under a stable versioned directory beneath one output root.""" output_root = Path(output_root) @@ -1250,27 +1347,10 @@ def save_versioned_us_microplex_artifacts( run_registry_metadata=run_registry_metadata, enable_child_tax_unit_agi_drift=enable_child_tax_unit_agi_drift, child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables, + allow_stage_input_overrides=allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, ) - return USMicroplexArtifactPaths( - output_dir=paths.output_dir, - version_id=resolved_version_id, - seed_data=paths.seed_data, - synthetic_data=paths.synthetic_data, - calibrated_data=paths.calibrated_data, - targets=paths.targets, - manifest=paths.manifest, - scaffold_seed_data=paths.scaffold_seed_data, - synthesizer=paths.synthesizer, - policyengine_dataset=paths.policyengine_dataset, - data_flow_snapshot=paths.data_flow_snapshot, - policyengine_harness=paths.policyengine_harness, - policyengine_native_scores=paths.policyengine_native_scores, - policyengine_native_audit=paths.policyengine_native_audit, - child_tax_unit_agi_drift=paths.child_tax_unit_agi_drift, - capital_gains_lots=paths.capital_gains_lots, - run_registry=paths.run_registry, - run_index_db=paths.run_index_db, - ) + return replace(paths, version_id=resolved_version_id) def build_and_save_versioned_us_microplex( @@ -1299,6 +1379,8 @@ def build_and_save_versioned_us_microplex( run_registry_metadata: dict[str, Any] | None = None, enable_child_tax_unit_agi_drift: bool = False, child_tax_unit_agi_drift_variables: tuple[str, ...] | None = None, + allow_stage_input_overrides: bool = False, + stage_input_overrides: tuple[USStageInputOverride, ...] 
= (), ) -> USMicroplexVersionedBuildArtifacts: """Build a US microplex dataset, save a versioned bundle, and report frontier gap.""" build_result = build_us_microplex(persons, households, config=config) @@ -1323,6 +1405,8 @@ def build_and_save_versioned_us_microplex( run_registry_metadata=run_registry_metadata, enable_child_tax_unit_agi_drift=enable_child_tax_unit_agi_drift, child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables, + allow_stage_input_overrides=allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, ) @@ -1350,6 +1434,8 @@ def save_versioned_us_microplex_build_result( run_registry_metadata: dict[str, Any] | None = None, enable_child_tax_unit_agi_drift: bool = False, child_tax_unit_agi_drift_variables: tuple[str, ...] | None = None, + allow_stage_input_overrides: bool = False, + stage_input_overrides: tuple[USStageInputOverride, ...] = (), ) -> USMicroplexVersionedBuildArtifacts: """Save an already-built result as a versioned bundle and report frontier gap.""" return _finalize_versioned_build_artifacts( @@ -1373,6 +1459,8 @@ def save_versioned_us_microplex_build_result( run_registry_metadata=run_registry_metadata, enable_child_tax_unit_agi_drift=enable_child_tax_unit_agi_drift, child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables, + allow_stage_input_overrides=allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, ) @@ -1402,6 +1490,8 @@ def build_and_save_versioned_us_microplex_from_source_provider( run_registry_metadata: dict[str, Any] | None = None, enable_child_tax_unit_agi_drift: bool = False, child_tax_unit_agi_drift_variables: tuple[str, ...] | None = None, + allow_stage_input_overrides: bool = False, + stage_input_overrides: tuple[USStageInputOverride, ...] 
= (), ) -> USMicroplexVersionedBuildArtifacts: """Build from one source provider, save a versioned bundle, and report frontier gap.""" pipeline = USMicroplexPipeline(config) @@ -1427,6 +1517,8 @@ def build_and_save_versioned_us_microplex_from_source_provider( run_registry_metadata=run_registry_metadata, enable_child_tax_unit_agi_drift=enable_child_tax_unit_agi_drift, child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables, + allow_stage_input_overrides=allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, ) @@ -1456,6 +1548,8 @@ def build_and_save_versioned_us_microplex_from_source_providers( run_registry_metadata: dict[str, Any] | None = None, enable_child_tax_unit_agi_drift: bool = False, child_tax_unit_agi_drift_variables: tuple[str, ...] | None = None, + allow_stage_input_overrides: bool = False, + stage_input_overrides: tuple[USStageInputOverride, ...] = (), ) -> USMicroplexVersionedBuildArtifacts: """Build from multiple source providers, save a versioned bundle, and report frontier gap.""" pipeline = USMicroplexPipeline(config) @@ -1481,6 +1575,8 @@ def build_and_save_versioned_us_microplex_from_source_providers( run_registry_metadata=run_registry_metadata, enable_child_tax_unit_agi_drift=enable_child_tax_unit_agi_drift, child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables, + allow_stage_input_overrides=allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, ) @@ -1509,6 +1605,8 @@ def build_and_save_versioned_us_microplex_from_data_dir( run_registry_metadata: dict[str, Any] | None = None, enable_child_tax_unit_agi_drift: bool = False, child_tax_unit_agi_drift_variables: tuple[str, ...] | None = None, + allow_stage_input_overrides: bool = False, + stage_input_overrides: tuple[USStageInputOverride, ...] 
= (), ) -> USMicroplexVersionedBuildArtifacts: """Build from a CPS-style parquet directory, save a versioned bundle, and report frontier gap.""" pipeline = USMicroplexPipeline(config) @@ -1534,6 +1632,8 @@ def build_and_save_versioned_us_microplex_from_data_dir( run_registry_metadata=run_registry_metadata, enable_child_tax_unit_agi_drift=enable_child_tax_unit_agi_drift, child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables, + allow_stage_input_overrides=allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, ) @@ -1561,6 +1661,8 @@ def _finalize_versioned_build_artifacts( run_registry_metadata: dict[str, Any] | None, enable_child_tax_unit_agi_drift: bool = False, child_tax_unit_agi_drift_variables: tuple[str, ...] | None = None, + allow_stage_input_overrides: bool = False, + stage_input_overrides: tuple[USStageInputOverride, ...] = (), ) -> USMicroplexVersionedBuildArtifacts: artifact_paths = save_versioned_us_microplex_artifacts( build_result, @@ -1582,11 +1684,16 @@ def _finalize_versioned_build_artifacts( run_registry_metadata=run_registry_metadata, enable_child_tax_unit_agi_drift=enable_child_tax_unit_agi_drift, child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables, + allow_stage_input_overrides=allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, ) current_entry = None frontier_entry = None frontier_delta = None - if artifact_paths.run_registry is not None and artifact_paths.version_id is not None: + if ( + artifact_paths.run_registry is not None + and artifact_paths.version_id is not None + ): registry_entries = load_us_microplex_run_registry(artifact_paths.run_registry) current_entry = next( ( @@ -1637,7 +1744,10 @@ def _resolve_policyengine_harness_context( dict[str, Any], ]: resolved_target_provider = policyengine_target_provider - if resolved_target_provider is None and result.config.policyengine_targets_db is not None: + if ( + resolved_target_provider is None + and 
result.config.policyengine_targets_db is not None + ): resolved_target_provider = PolicyEngineUSDBTargetProvider( result.config.policyengine_targets_db ) @@ -1685,7 +1795,9 @@ def _resolve_policyengine_harness_context( result.config.policyengine_calibration_target_profile ), "target_reform_id": result.config.policyengine_target_reform_id, - "harness_slice_names": [slice_spec.name for slice_spec in resolved_harness_slices], + "harness_slice_names": [ + slice_spec.name for slice_spec in resolved_harness_slices + ], "policyengine_us_runtime_version": _resolve_policyengine_us_runtime_version(), "harness_suite": ( "policyengine_us_all_targets" @@ -1719,7 +1831,9 @@ def _allocate_versioned_output_dir( if version_id is not None: output_dir = output_root / version_id if output_dir.exists(): - raise FileExistsError(f"Versioned artifact directory already exists: {output_dir}") + raise FileExistsError( + f"Versioned artifact directory already exists: {output_dir}" + ) return version_id, output_dir config_hash = _short_config_hash(result.config.to_dict()) diff --git a/src/microplex_us/pipelines/backfill_pe_native_audit.py b/src/microplex_us/pipelines/backfill_pe_native_audit.py index a2961a8..ba549c8 100644 --- a/src/microplex_us/pipelines/backfill_pe_native_audit.py +++ b/src/microplex_us/pipelines/backfill_pe_native_audit.py @@ -22,6 +22,10 @@ from microplex_us.pipelines.pe_us_data_rebuild_checkpoint import ( _refresh_checkpoint_data_flow_snapshot, ) +from microplex_us.pipelines.stage_contracts import ( + get_us_stage_artifact_contract, + resolve_us_stage_artifact_contract_path, +) def backfill_us_pe_native_audit_bundle( @@ -43,7 +47,11 @@ def backfill_us_pe_native_audit_bundle( native_scores_path = _resolve_required_native_scores_path(bundle_dir, artifacts) native_scores_payload = json.loads(native_scores_path.read_text()) - native_audit_path = bundle_dir / "pe_us_data_rebuild_native_audit.json" + native_audit_path = resolve_us_stage_artifact_contract_path( + bundle_dir, + 
"09_validation_benchmarking", + "policyengine_native_audit", + ) if native_audit_path.exists() and not force: payload = json.loads(native_audit_path.read_text()) else: @@ -96,7 +104,11 @@ def backfill_us_pe_native_audit_bundles( continue manifest_paths.append(manifest_path) native_scores_payload = json.loads(native_scores_path.read_text()) - native_audit_path = bundle_dir / "pe_us_data_rebuild_native_audit.json" + native_audit_path = resolve_us_stage_artifact_contract_path( + bundle_dir, + "09_validation_benchmarking", + "policyengine_native_audit", + ) if native_audit_path.exists() and not force: _write_native_audit_payload_to_bundle( bundle_dir=bundle_dir, @@ -222,11 +234,17 @@ def _write_native_audit_payload_to_bundle( manifest: dict, payload: dict, ) -> Path: - native_audit_path = bundle_dir / "pe_us_data_rebuild_native_audit.json" + native_audit_path = resolve_us_stage_artifact_contract_path( + bundle_dir, + "09_validation_benchmarking", + "policyengine_native_audit", + ) native_audit_path.write_text(json.dumps(payload, indent=2, sort_keys=True)) artifacts = dict(manifest.get("artifacts", {})) - artifacts["policyengine_native_audit"] = native_audit_path.name + artifacts["policyengine_native_audit"] = str( + native_audit_path.relative_to(bundle_dir) + ) manifest["artifacts"] = artifacts manifest["policyengine_native_audit"] = dict(payload.get("verdictHints", {})) @@ -276,8 +294,14 @@ def _resolve_optional_native_scores_path( artifacts: dict, ) -> Path | None: artifact_name = ( - artifacts.get("policyengine_native_scores") or "policyengine_native_scores.json" + artifacts.get("policyengine_native_scores") + or get_us_stage_artifact_contract( + "09_validation_benchmarking", + "policyengine_native_scores", + ).path_hint ) + if artifact_name is None: + return None path = bundle_dir / str(artifact_name) if path.exists(): return path diff --git a/src/microplex_us/pipelines/backfill_pe_native_scores.py b/src/microplex_us/pipelines/backfill_pe_native_scores.py index 
bb8ed12..0e93030 100644 --- a/src/microplex_us/pipelines/backfill_pe_native_scores.py +++ b/src/microplex_us/pipelines/backfill_pe_native_scores.py @@ -13,20 +13,33 @@ compute_batch_us_pe_native_scores, compute_us_pe_native_scores, ) +from microplex_us.pipelines.pe_us_data_rebuild_checkpoint import ( + _refresh_checkpoint_data_flow_snapshot, +) from microplex_us.pipelines.registry import ( append_us_microplex_run_registry_entry, build_us_microplex_run_registry_entry, ) +from microplex_us.pipelines.stage_contracts import ( + get_us_stage_artifact_contract, + resolve_us_stage_artifact_contract_path, +) def discover_us_candidate_artifact_dirs(artifact_root: str | Path) -> tuple[Path, ...]: """Return saved US artifact bundle directories with a PE dataset and manifest.""" root = Path(artifact_root) + dataset_hint = get_us_stage_artifact_contract( + "08_dataset_assembly", + "policyengine_dataset", + ).path_hint + if dataset_hint is None: + raise RuntimeError("Stage 8 policyengine_dataset artifact has no path hint") return tuple( sorted( path.parent - for path in root.rglob("policyengine_us.h5") + for path in root.rglob(dataset_hint) if (path.parent / "manifest.json").exists() ) ) @@ -50,7 +63,11 @@ def backfill_us_pe_native_scores_bundle( if not dataset_name: raise ValueError(f"{bundle_dir} does not declare a policyengine_dataset artifact") - native_sidecar_path = bundle_dir / "policyengine_native_scores.json" + native_sidecar_path = resolve_us_stage_artifact_contract_path( + bundle_dir, + "09_validation_benchmarking", + "policyengine_native_scores", + ) if native_sidecar_path.exists() and not force: payload = json.loads(native_sidecar_path.read_text()) else: @@ -96,7 +113,11 @@ def backfill_us_pe_native_scores_bundles( manifest = json.loads(manifest_path.read_text()) manifest_paths.append(manifest_path) - native_sidecar_path = bundle_dir / "policyengine_native_scores.json" + native_sidecar_path = resolve_us_stage_artifact_contract_path( + bundle_dir, + 
"09_validation_benchmarking", + "policyengine_native_scores", + ) if native_sidecar_path.exists() and not force: _write_native_scores_payload_to_bundle( bundle_dir=bundle_dir, @@ -189,11 +210,17 @@ def _write_native_scores_payload_to_bundle( manifest: dict, payload: dict, ) -> Path: - native_sidecar_path = bundle_dir / "policyengine_native_scores.json" + native_sidecar_path = resolve_us_stage_artifact_contract_path( + bundle_dir, + "09_validation_benchmarking", + "policyengine_native_scores", + ) native_sidecar_path.write_text(json.dumps(payload, indent=2, sort_keys=True)) artifacts = dict(manifest.get("artifacts", {})) - artifacts["policyengine_native_scores"] = native_sidecar_path.name + artifacts["policyengine_native_scores"] = str( + native_sidecar_path.relative_to(bundle_dir) + ) manifest["artifacts"] = artifacts manifest["policyengine_native_scores"] = dict(payload.get("summary", {})) if "run_registry" in manifest: @@ -201,6 +228,7 @@ def _write_native_scores_payload_to_bundle( "enhanced_cps_native_loss_delta" ) + _refresh_checkpoint_data_flow_snapshot(bundle_dir, manifest) assert_valid_benchmark_artifact_manifest( manifest, artifact_dir=bundle_dir, diff --git a/src/microplex_us/pipelines/check_site_snapshot.py b/src/microplex_us/pipelines/check_site_snapshot.py index 4df6a01..03fada8 100644 --- a/src/microplex_us/pipelines/check_site_snapshot.py +++ b/src/microplex_us/pipelines/check_site_snapshot.py @@ -11,6 +11,9 @@ build_us_microplex_data_flow_snapshot, ) from microplex_us.pipelines.site_snapshot import build_us_microplex_site_snapshot +from microplex_us.pipelines.stage_contracts import ( + resolve_us_stage_artifact_contract_path, +) def check_us_microplex_site_snapshot( @@ -84,7 +87,11 @@ def _resolve_artifact_dir(snapshot_file: Path, source_artifact: dict) -> Path: def _check_data_flow_snapshot_current(artifact_dir: Path) -> None: - snapshot_path = artifact_dir / "data_flow_snapshot.json" + snapshot_path = resolve_us_stage_artifact_contract_path( + 
artifact_dir, + "08_dataset_assembly", + "data_flow_snapshot", + ) if not snapshot_path.exists(): raise SystemExit("US data-flow snapshot is missing from the artifact bundle.") frozen_snapshot = json.loads(snapshot_path.read_text()) diff --git a/src/microplex_us/pipelines/dashboard.py b/src/microplex_us/pipelines/dashboard.py index ed74384..4948272 100644 --- a/src/microplex_us/pipelines/dashboard.py +++ b/src/microplex_us/pipelines/dashboard.py @@ -14,6 +14,7 @@ from microplex_us.pipelines.stage_contracts import ( canonicalize_us_pipeline_stage_id, + get_us_stage_artifact_contract, ) _ROOT = Path(__file__).resolve().parents[3] @@ -1720,8 +1721,13 @@ def write_dashboard_payload( def _iter_score_paths(artifact_root: Path) -> list[Path]: + native_scores_hint = get_us_stage_artifact_contract( + "09_validation_benchmarking", + "policyengine_native_scores", + ).path_hint paths = list(artifact_root.rglob("scores.json")) - paths.extend(artifact_root.rglob("policyengine_native_scores.json")) + if native_scores_hint is not None: + paths.extend(artifact_root.rglob(native_scores_hint)) paths.extend(artifact_root.rglob("*_score.json")) return [path for path in paths if path.is_file()] @@ -1940,7 +1946,13 @@ def _score_label(path: Path, candidate_dataset: Any, index: int) -> str: artifact = path.parent.name if isinstance(candidate_dataset, str): dataset_name = Path(candidate_dataset).name - if dataset_name != "policyengine_us.h5": + policyengine_dataset_hint = get_us_stage_artifact_contract( + "08_dataset_assembly", + "policyengine_dataset", + ).path_hint + if policyengine_dataset_hint is None or dataset_name != Path( + policyengine_dataset_hint + ).name: return f"{artifact} / {dataset_name}" if index: return f"{artifact} / candidate {index + 1}" diff --git a/src/microplex_us/pipelines/data_flow_snapshot.py b/src/microplex_us/pipelines/data_flow_snapshot.py index 1de98f7..ca78374 100644 --- a/src/microplex_us/pipelines/data_flow_snapshot.py +++ 
b/src/microplex_us/pipelines/data_flow_snapshot.py @@ -8,6 +8,9 @@ from pathlib import Path from typing import Any +from microplex_us.pipelines.stage_contracts import ( + resolve_us_stage_artifact_contract_path, +) from microplex_us.pipelines.stage_manifest import ( build_us_stage_manifest, load_us_stage_manifest, @@ -45,7 +48,11 @@ def require_saved_us_microplex_data_flow_snapshot( ) -> dict[str, Any]: """Load the saved canonical US data-flow snapshot or raise.""" artifact_root = Path(artifact_dir) - snapshot_path = artifact_root / "data_flow_snapshot.json" + snapshot_path = resolve_us_stage_artifact_contract_path( + artifact_root, + "08_dataset_assembly", + "data_flow_snapshot", + ) if not snapshot_path.exists(): raise FileNotFoundError( f"US artifact bundle is missing data_flow_snapshot.json: {snapshot_path}" @@ -202,7 +209,11 @@ def _materialize_us_microplex_data_flow_snapshot( def _load_saved_data_flow_snapshot(artifact_root: Path) -> dict[str, Any] | None: - snapshot_path = artifact_root / "data_flow_snapshot.json" + snapshot_path = resolve_us_stage_artifact_contract_path( + artifact_root, + "08_dataset_assembly", + "data_flow_snapshot", + ) if not snapshot_path.exists(): return None snapshot = json.loads(snapshot_path.read_text()) diff --git a/src/microplex_us/pipelines/experiments.py b/src/microplex_us/pipelines/experiments.py index 7bf9cdd..4436475 100644 --- a/src/microplex_us/pipelines/experiments.py +++ b/src/microplex_us/pipelines/experiments.py @@ -31,6 +31,9 @@ load_us_microplex_run_registry, select_us_microplex_frontier_entry, ) +from microplex_us.pipelines.stage_contracts import ( + resolve_us_stage_artifact_contract_path, +) from microplex_us.pipelines.us import USMicroplexBuildConfig from microplex_us.policyengine.harness import ( PolicyEngineUSComparisonCache, @@ -264,6 +267,16 @@ def to_dict(self) -> dict[str, Any]: if self.artifact_paths.data_flow_snapshot is not None else None ), + "artifact_inventory": ( + 
str(self.artifact_paths.artifact_inventory) + if self.artifact_paths.artifact_inventory is not None + else None + ), + "conditional_readiness": ( + str(self.artifact_paths.conditional_readiness) + if self.artifact_paths.conditional_readiness is not None + else None + ), "policyengine_harness": ( str(self.artifact_paths.policyengine_harness) if self.artifact_paths.policyengine_harness is not None @@ -342,6 +355,16 @@ def from_dict(cls, payload: dict[str, Any]) -> USMicroplexExperimentResult: if artifact_paths.get("data_flow_snapshot") is not None else None ), + artifact_inventory=( + Path(artifact_paths["artifact_inventory"]) + if artifact_paths.get("artifact_inventory") is not None + else None + ), + conditional_readiness=( + Path(artifact_paths["conditional_readiness"]) + if artifact_paths.get("conditional_readiness") is not None + else None + ), policyengine_harness=( Path(artifact_paths["policyengine_harness"]) if artifact_paths.get("policyengine_harness") is not None @@ -768,7 +791,21 @@ def _refresh_experiment_artifact_paths( data_flow_snapshot=_resolve_optional_result_artifact_path( artifact_root, artifacts.get("data_flow_snapshot"), - fallback="data_flow_snapshot.json", + fallback=str( + resolve_us_stage_artifact_contract_path( + artifact_root, + "08_dataset_assembly", + "data_flow_snapshot", + ).relative_to(artifact_root) + ), + ), + artifact_inventory=_resolve_optional_result_artifact_path( + artifact_root, + artifacts.get("artifact_inventory"), + ), + conditional_readiness=_resolve_optional_result_artifact_path( + artifact_root, + artifacts.get("conditional_readiness"), ), policyengine_harness=_resolve_optional_result_artifact_path( artifact_root, diff --git a/src/microplex_us/pipelines/pe_us_data_rebuild_audit.py b/src/microplex_us/pipelines/pe_us_data_rebuild_audit.py index e94a678..9b09463 100644 --- a/src/microplex_us/pipelines/pe_us_data_rebuild_audit.py +++ b/src/microplex_us/pipelines/pe_us_data_rebuild_audit.py @@ -11,6 +11,12 @@ 
compare_us_pe_native_target_deltas, compute_us_pe_native_support_audit, ) +from microplex_us.pipelines.stage_contracts import ( + resolve_us_stage_artifact_contract_path, +) +from microplex_us.pipelines.stage_run import ( + resolve_us_manifest_or_contract_artifact_path, +) def build_policyengine_us_data_rebuild_native_audit( @@ -33,18 +39,32 @@ def build_policyengine_us_data_rebuild_native_audit( if manifest_payload is not None else json.loads((artifact_root / "manifest.json").read_text()) ) + artifacts = dict(manifest.get("artifacts", {})) native_scores = ( dict(native_scores_payload) if native_scores_payload is not None - else json.loads((artifact_root / "policyengine_native_scores.json").read_text()) + else json.loads( + _resolve_stage_artifact_path( + artifact_root, + manifest, + "policyengine_native_scores", + stage_id="09_validation_benchmarking", + ).read_text() + ) ) imputation_ablation = ( dict(imputation_ablation_payload) if imputation_ablation_payload is not None - else _load_optional_json(artifact_root / "imputation_ablation.json") + else _load_optional_json( + _resolve_stage_artifact_path( + artifact_root, + manifest, + "imputation_ablation", + stage_id="09_validation_benchmarking", + ) + ) ) config = dict(manifest.get("config", {})) - artifacts = dict(manifest.get("artifacts", {})) candidate_dataset_path = _resolve_candidate_dataset_path(artifact_root, artifacts) baseline_dataset_path = _resolve_baseline_dataset_path(config) period = int( @@ -176,7 +196,11 @@ def write_policyengine_us_data_rebuild_native_audit( destination = ( Path(output_path) if output_path is not None - else artifact_root / "pe_us_data_rebuild_native_audit.json" + else resolve_us_stage_artifact_contract_path( + artifact_root, + "09_validation_benchmarking", + "policyengine_native_audit", + ) ) payload = build_policyengine_us_data_rebuild_native_audit( artifact_root, @@ -193,6 +217,21 @@ def write_policyengine_us_data_rebuild_native_audit( return destination +def 
_resolve_stage_artifact_path( + artifact_root: Path, + manifest: dict[str, Any], + artifact_key: str, + *, + stage_id: str, +) -> Path: + return resolve_us_manifest_or_contract_artifact_path( + artifact_root, + manifest, + artifact_key, + stage_id=stage_id, + ) + + def _resolve_candidate_dataset_path( artifact_root: Path, artifacts: dict[str, Any], @@ -202,7 +241,9 @@ def _resolve_candidate_dataset_path( raise FileNotFoundError( "Artifact bundle is missing artifacts.policyengine_dataset in manifest.json" ) - dataset_path = artifact_root / dataset_name + dataset_path = Path(dataset_name) + if not dataset_path.is_absolute(): + dataset_path = artifact_root / dataset_path if not dataset_path.exists(): raise FileNotFoundError( f"Artifact bundle is missing saved policyengine dataset: {dataset_path}" diff --git a/src/microplex_us/pipelines/pe_us_data_rebuild_checkpoint.py b/src/microplex_us/pipelines/pe_us_data_rebuild_checkpoint.py index f8a345b..8c7e681 100644 --- a/src/microplex_us/pipelines/pe_us_data_rebuild_checkpoint.py +++ b/src/microplex_us/pipelines/pe_us_data_rebuild_checkpoint.py @@ -53,9 +53,13 @@ load_us_microplex_run_registry, select_us_microplex_frontier_entry, ) -from microplex_us.pipelines.stage_manifest import ( - write_us_stage_manifest, - write_us_validation_evidence_manifest, +from microplex_us.pipelines.stage_contracts import ( + resolve_us_stage_artifact_contract_path, +) +from microplex_us.pipelines.stage_run import ( + USStageInputOverride, + parse_us_stage_input_override, + write_us_stage_run_manifests_from_artifact_manifest, ) from microplex_us.variables import prune_redundant_variables @@ -237,6 +241,37 @@ def _resolve_saved_artifact_path( return candidate +def _resolve_required_saved_artifact_path( + artifact_root: Path, + artifacts: dict[str, Any], + artifact_key: str, +) -> Path: + path = _resolve_saved_artifact_path(artifact_root, artifacts.get(artifact_key)) + if path is None: + raise KeyError(f"Saved artifact manifest does not declare 
{artifact_key!r}") + return path + + +def _resolve_saved_stage_artifact_path( + artifact_root: Path, + artifacts: dict[str, Any], + artifact_key: str, + *, + stage_id: str, +) -> Path | None: + declared_path = _resolve_saved_artifact_path( + artifact_root, artifacts.get(artifact_key) + ) + if declared_path is not None: + return declared_path + contract_path = resolve_us_stage_artifact_contract_path( + artifact_root, + stage_id, + artifact_key, + ) + return contract_path if contract_path.exists() else None + + def _infer_policyengine_baseline_household_weight_sum( baseline_dataset: str | Path, *, @@ -970,134 +1005,29 @@ def _build_checkpoint_imputation_ablation_payload( } -def _build_checkpoint_benchmark_stage( - manifest: dict[str, Any], - *, - extra_outputs: tuple[str, ...] = (), -) -> dict[str, Any]: - artifacts = dict(manifest.get("artifacts", {})) - calibration_summary = dict(manifest.get("calibration", {})) - harness_summary = dict(manifest.get("policyengine_harness", {})) - native_scores_summary = dict(manifest.get("policyengine_native_scores", {})) - imputation_ablation_summary = dict(manifest.get("imputation_ablation", {})) - outputs = [ - value - for value in ( - artifacts.get("policyengine_harness"), - artifacts.get("policyengine_native_scores"), - artifacts.get("imputation_ablation"), - artifacts.get("policyengine_native_audit"), - *extra_outputs, - ) - if value - ] - return { - "id": "09_validation_benchmarking", - "legacyId": "benchmark", - "step": "09", - "title": "Validation and benchmarking", - "summary": ( - "Harness, native-loss, and donor-imputation diagnostics stay attached " - "to the same artifact bundle." 
- ), - "status": ( - "ready" - if harness_summary or native_scores_summary or imputation_ablation_summary - else "missing" - ), - "metrics": [ - { - "label": "Capped full oracle loss", - "value": calibration_summary.get( - "full_oracle_capped_mean_abs_relative_error" - ), - }, - { - "label": "Full oracle loss", - "value": calibration_summary.get("full_oracle_mean_abs_relative_error"), - }, - { - "label": "Harness delta", - "value": harness_summary.get("mean_abs_relative_error_delta"), - }, - { - "label": "Native delta", - "value": native_scores_summary.get("enhanced_cps_native_loss_delta"), - }, - { - "label": "Win rate", - "value": harness_summary.get("target_win_rate"), - }, - { - "label": "Imputation MAE", - "value": imputation_ablation_summary.get( - "production_mean_weighted_mae" - ), - }, - { - "label": "Imputation F1", - "value": imputation_ablation_summary.get("production_mean_support_f1"), - }, - ], - "outputs": list(dict.fromkeys(outputs)), - } - - def _refresh_checkpoint_data_flow_snapshot( artifact_root: Path, manifest: dict[str, Any], *, extra_outputs: tuple[str, ...] 
= (), ) -> Path | None: - snapshot_path = artifact_root / "data_flow_snapshot.json" - stage_manifest_path = artifact_root / "stage_manifest.json" - validation_evidence_path = ( - artifact_root - / "stage_artifacts" - / "09_validation_benchmarking" - / "evidence_manifest.json" - ) - artifacts = dict(manifest.get("artifacts", {})) - artifacts.setdefault("stage_manifest", stage_manifest_path.name) - artifacts.setdefault( - "validation_evidence", - str(validation_evidence_path.relative_to(artifact_root)), - ) - manifest["artifacts"] = artifacts - write_us_validation_evidence_manifest( + if extra_outputs: + manifest.setdefault("diagnostics", {}).setdefault( + "checkpoint_extra_outputs", + list(extra_outputs), + ) + updated_manifest = write_us_stage_run_manifests_from_artifact_manifest( artifact_root, - validation_evidence_path, - manifest_payload=manifest, + manifest, ) - write_us_stage_manifest( + manifest.clear() + manifest.update(updated_manifest) + snapshot_path = resolve_us_stage_artifact_contract_path( artifact_root, - stage_manifest_path, - manifest_payload=manifest, + "08_dataset_assembly", + "data_flow_snapshot", ) - if not snapshot_path.exists(): - return None - snapshot = json.loads(snapshot_path.read_text()) - if snapshot.get("schemaVersion") != 1: - return snapshot_path - stages = list(snapshot.get("stages", [])) - benchmark_stage = _build_checkpoint_benchmark_stage( - manifest, - extra_outputs=extra_outputs, - ) - replaced = False - for index, stage in enumerate(stages): - if isinstance(stage, dict) and stage.get("id") in { - "benchmark", - "09_validation_benchmarking", - }: - stages[index] = benchmark_stage - replaced = True - break - if not replaced: - stages.append(benchmark_stage) - snapshot["stages"] = stages - _write_json_atomically(snapshot_path, snapshot) - return snapshot_path + return snapshot_path if snapshot_path.exists() else None def _attach_checkpoint_registry_and_index( @@ -1215,35 +1145,128 @@ def _load_checkpoint_versioned_artifacts( 
artifact_paths = USMicroplexArtifactPaths( output_dir=artifact_root, version_id=artifact_root.name, - seed_data=artifact_root / str(artifacts["seed_data"]), - synthetic_data=artifact_root / str(artifacts["synthetic_data"]), - calibrated_data=artifact_root / str(artifacts["calibrated_data"]), - targets=artifact_root / str(artifacts["targets"]), + seed_data=_resolve_required_saved_artifact_path( + artifact_root, + artifacts, + "seed_data", + ), + synthetic_data=_resolve_required_saved_artifact_path( + artifact_root, + artifacts, + "synthetic_data", + ), + calibrated_data=_resolve_required_saved_artifact_path( + artifact_root, + artifacts, + "calibrated_data", + ), + targets=_resolve_required_saved_artifact_path( + artifact_root, + artifacts, + "targets", + ), manifest=manifest_path, - synthesizer=_resolve_saved_artifact_path( + scaffold_seed_data=_resolve_saved_stage_artifact_path( artifact_root, - artifacts.get("synthesizer"), + artifacts, + "scaffold_seed_data", + stage_id="04_seed_scaffold", ), - policyengine_dataset=_resolve_saved_artifact_path( + synthesizer=_resolve_saved_stage_artifact_path( artifact_root, - artifacts.get("policyengine_dataset"), + artifacts, + "synthesizer", + stage_id="05_donor_integration_synthesis", ), - data_flow_snapshot=( - artifact_root / "data_flow_snapshot.json" - if (artifact_root / "data_flow_snapshot.json").exists() - else None + policyengine_dataset=_resolve_saved_stage_artifact_path( + artifact_root, + artifacts, + "policyengine_dataset", + stage_id="08_dataset_assembly", + ), + data_flow_snapshot=_resolve_saved_stage_artifact_path( + artifact_root, + artifacts, + "data_flow_snapshot", + stage_id="08_dataset_assembly", + ), + stage_manifest=_resolve_saved_stage_artifact_path( + artifact_root, + artifacts, + "stage_manifest", + stage_id="08_dataset_assembly", + ), + artifact_inventory=_resolve_saved_stage_artifact_path( + artifact_root, + artifacts, + "artifact_inventory", + stage_id="08_dataset_assembly", + ), + 
conditional_readiness=_resolve_saved_stage_artifact_path( + artifact_root, + artifacts, + "conditional_readiness", + stage_id="08_dataset_assembly", + ), + source_plan=_resolve_saved_stage_artifact_path( + artifact_root, + artifacts, + "source_plan", + stage_id="03_source_planning", + ), + policyengine_entity_tables=_resolve_saved_stage_artifact_path( + artifact_root, + artifacts, + "policyengine_entity_tables", + stage_id="06_policyengine_entities", + ), + calibration_summary=_resolve_saved_stage_artifact_path( + artifact_root, + artifacts, + "calibration_summary", + stage_id="07_calibration", + ), + validation_evidence=_resolve_saved_stage_artifact_path( + artifact_root, + artifacts, + "validation_evidence", + stage_id="09_validation_benchmarking", + ), + policyengine_harness=_resolve_saved_stage_artifact_path( + artifact_root, + artifacts, + "policyengine_harness", + stage_id="09_validation_benchmarking", + ), + policyengine_native_scores=_resolve_saved_stage_artifact_path( + artifact_root, + artifacts, + "policyengine_native_scores", + stage_id="09_validation_benchmarking", + ), + policyengine_native_audit=_resolve_saved_stage_artifact_path( + artifact_root, + artifacts, + "policyengine_native_audit", + stage_id="09_validation_benchmarking", ), - policyengine_harness=_resolve_saved_artifact_path( + child_tax_unit_agi_drift=_resolve_saved_stage_artifact_path( artifact_root, - artifacts.get("policyengine_harness"), + artifacts, + "child_tax_unit_agi_drift", + stage_id="09_validation_benchmarking", ), - policyengine_native_scores=_resolve_saved_artifact_path( + capital_gains_lots=_resolve_saved_stage_artifact_path( artifact_root, - artifacts.get("policyengine_native_scores"), + artifacts, + "capital_gains_lots", + stage_id="08_dataset_assembly", ), - policyengine_native_audit=_resolve_saved_artifact_path( + source_weight_diagnostics=_resolve_saved_stage_artifact_path( artifact_root, - artifacts.get("policyengine_native_audit"), + artifacts, + 
"source_weight_diagnostics", + stage_id="05_donor_integration_synthesis", ), run_registry=_resolve_saved_artifact_path( artifact_root, @@ -1481,7 +1504,11 @@ def attach_policyengine_us_data_rebuild_checkpoint_evidence( ) harness_payload = harness_run.to_dict() if harness_payload is not None: - harness_path = artifact_root / "policyengine_harness.json" + harness_path = resolve_us_stage_artifact_contract_path( + artifact_root, + "09_validation_benchmarking", + "policyengine_harness", + ) _write_json_atomically(harness_path, harness_payload) artifacts["policyengine_harness"] = harness_path.name manifest["policyengine_harness"] = dict(harness_payload.get("summary", {})) @@ -1512,7 +1539,11 @@ def attach_policyengine_us_data_rebuild_checkpoint_evidence( policyengine_us_data_python=policyengine_us_data_python, ) if native_scores_payload is not None: - native_scores_path = artifact_root / "policyengine_native_scores.json" + native_scores_path = resolve_us_stage_artifact_contract_path( + artifact_root, + "09_validation_benchmarking", + "policyengine_native_scores", + ) _write_json_atomically(native_scores_path, native_scores_payload) artifacts["policyengine_native_scores"] = native_scores_path.name manifest["policyengine_native_scores"] = dict( @@ -1540,7 +1571,11 @@ def attach_policyengine_us_data_rebuild_checkpoint_evidence( manifest=manifest, ) if imputation_ablation_payload is not None: - imputation_ablation_path = artifact_root / "imputation_ablation.json" + imputation_ablation_path = resolve_us_stage_artifact_contract_path( + artifact_root, + "09_validation_benchmarking", + "imputation_ablation", + ) _write_json_atomically(imputation_ablation_path, imputation_ablation_payload) artifacts["imputation_ablation"] = imputation_ablation_path.name manifest["imputation_ablation"] = dict( @@ -1610,35 +1645,17 @@ def attach_policyengine_us_data_rebuild_checkpoint_evidence( policyengine_us_data_repo=policyengine_us_data_repo, 
policyengine_us_data_python=policyengine_us_data_python, ) - native_audit_path = artifact_root / "pe_us_data_rebuild_native_audit.json" + native_audit_path = resolve_us_stage_artifact_contract_path( + artifact_root, + "09_validation_benchmarking", + "policyengine_native_audit", + ) _write_json_atomically(native_audit_path, native_audit_payload) artifacts["policyengine_native_audit"] = native_audit_path.name manifest["policyengine_native_audit"] = dict( native_audit_payload.get("verdictHints", {}) ) - stage_manifest_path = artifact_root / "stage_manifest.json" - validation_evidence_path = ( - artifact_root - / "stage_artifacts" - / "09_validation_benchmarking" - / "evidence_manifest.json" - ) - artifacts.setdefault("stage_manifest", stage_manifest_path.name) - artifacts.setdefault( - "validation_evidence", - str(validation_evidence_path.relative_to(artifact_root)), - ) manifest["artifacts"] = artifacts - write_us_validation_evidence_manifest( - artifact_root, - validation_evidence_path, - manifest_payload=manifest, - ) - write_us_stage_manifest( - artifact_root, - stage_manifest_path, - manifest_payload=manifest, - ) _refresh_checkpoint_data_flow_snapshot( artifact_root, manifest, @@ -1845,6 +1862,8 @@ def run_policyengine_us_data_rebuild_checkpoint( run_registry_path: str | Path | None = None, run_index_path: str | Path | None = None, run_registry_metadata: dict[str, Any] | None = None, + allow_stage_input_overrides: bool = False, + stage_input_overrides: tuple[USStageInputOverride, ...] 
= (), ) -> PEUSDataRebuildCheckpointResult: """Run one saved rebuild checkpoint and write its PE comparison sidecars.""" @@ -1983,6 +2002,8 @@ def run_policyengine_us_data_rebuild_checkpoint( run_index_path=run_index_path, run_registry_metadata=resolved_registry_metadata, enable_child_tax_unit_agi_drift=True, + allow_stage_input_overrides=allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, ) _emit_checkpoint_progress( "PE-US-data rebuild checkpoint: build complete", @@ -2216,7 +2237,27 @@ def main(argv: list[str] | None = None) -> None: ) parser.add_argument("--capital-gains-lots-max-lots-per-person", type=int) parser.add_argument("--capital-gains-lots-random-seed", type=int) + parser.add_argument( + "--allow-stage-input-overrides", + action="store_true", + help=( + "Allow typed stage manifests to consume explicit CLI input overrides " + "instead of the immediately previous stage manifest." + ), + ) + parser.add_argument( + "--stage-input-override", + action="append", + default=[], + metavar="STAGE_ID.KEY=PATH", + help=("Explicit stage input override. 
Requires --allow-stage-input-overrides."), + ) args = parser.parse_args(argv) + stage_input_overrides = tuple( + parse_us_stage_input_override(value) for value in args.stage_input_override + ) + if stage_input_overrides and not args.allow_stage_input_overrides: + parser.error("--stage-input-override requires --allow-stage-input-overrides") config_overrides = { "n_synthetic": int(args.n_synthetic), @@ -2302,6 +2343,8 @@ def main(argv: list[str] | None = None) -> None: defer_policyengine_native_score=args.defer_policyengine_native_score, defer_native_audit=args.defer_native_audit, defer_imputation_ablation=args.defer_imputation_ablation, + allow_stage_input_overrides=args.allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, ) print(result.artifacts.artifact_paths.output_dir) diff --git a/src/microplex_us/pipelines/pe_us_data_rebuild_parity.py b/src/microplex_us/pipelines/pe_us_data_rebuild_parity.py index 13fb9f0..5d8aea1 100644 --- a/src/microplex_us/pipelines/pe_us_data_rebuild_parity.py +++ b/src/microplex_us/pipelines/pe_us_data_rebuild_parity.py @@ -11,6 +11,9 @@ default_policyengine_us_data_rebuild_config, default_policyengine_us_data_rebuild_program, ) +from microplex_us.pipelines.stage_run import ( + resolve_us_manifest_or_contract_artifact_path, +) _HARNESS_SUMMARY_KEYS = ( "candidate_mean_abs_relative_error", @@ -85,32 +88,50 @@ def build_policyengine_us_data_rebuild_parity_artifact( if manifest_payload is not None else json.loads((artifact_root / "manifest.json").read_text()) ) + harness_path = _resolve_stage_artifact_path( + artifact_root, + manifest, + "policyengine_harness", + stage_id="09_validation_benchmarking", + ) harness_source = _resolve_payload_source( - artifact_root / "policyengine_harness.json", + harness_path, override_supplied=harness_payload is not None, ) harness = ( dict(harness_payload) if harness_payload is not None - else _load_optional_json(artifact_root / "policyengine_harness.json") + else 
_load_optional_json(harness_path) + ) + native_scores_path = _resolve_stage_artifact_path( + artifact_root, + manifest, + "policyengine_native_scores", + stage_id="09_validation_benchmarking", ) native_scores_source = _resolve_payload_source( - artifact_root / "policyengine_native_scores.json", + native_scores_path, override_supplied=native_scores_payload is not None, ) native_scores = ( dict(native_scores_payload) if native_scores_payload is not None - else _load_optional_json(artifact_root / "policyengine_native_scores.json") + else _load_optional_json(native_scores_path) + ) + imputation_ablation_path = _resolve_stage_artifact_path( + artifact_root, + manifest, + "imputation_ablation", + stage_id="09_validation_benchmarking", ) imputation_ablation_source = _resolve_payload_source( - artifact_root / "imputation_ablation.json", + imputation_ablation_path, override_supplied=imputation_ablation_payload is not None, ) imputation_ablation = ( dict(imputation_ablation_payload) if imputation_ablation_payload is not None - else _load_optional_json(artifact_root / "imputation_ablation.json") + else _load_optional_json(imputation_ablation_path) ) resolved_program = program or default_policyengine_us_data_rebuild_program() @@ -280,6 +301,21 @@ def _load_optional_json(path: Path) -> dict[str, Any] | None: return json.loads(path.read_text()) +def _resolve_stage_artifact_path( + artifact_root: Path, + manifest: dict[str, Any], + artifact_key: str, + *, + stage_id: str, +) -> Path: + return resolve_us_manifest_or_contract_artifact_path( + artifact_root, + manifest, + artifact_key, + stage_id=stage_id, + ) + + def _build_profile_conformance( *, observed_config: dict[str, Any], diff --git a/src/microplex_us/pipelines/pe_us_dataset_readiness.py b/src/microplex_us/pipelines/pe_us_dataset_readiness.py index 6f87445..6811846 100644 --- a/src/microplex_us/pipelines/pe_us_dataset_readiness.py +++ b/src/microplex_us/pipelines/pe_us_dataset_readiness.py @@ -10,6 +10,10 @@ import h5py 
import numpy as np +from microplex_us.pipelines.stage_contracts import ( + resolve_us_stage_artifact_contract_path, +) + DEFAULT_PERIOD = 2024 DEFAULT_REQUIRED_VARIABLES: dict[str, str] = { "household_id": "household", @@ -161,10 +165,16 @@ def _resolve_dataset_path(path: Path) -> Path: manifest = json.loads(manifest_path.read_text()) dataset_name = dict(manifest.get("artifacts", {})).get("policyengine_dataset") if isinstance(dataset_name, str) and dataset_name: - dataset_path = path / dataset_name + dataset_path = Path(dataset_name) + if not dataset_path.is_absolute(): + dataset_path = path / dataset_path if dataset_path.exists(): return dataset_path.resolve() - dataset_path = path / "policyengine_us.h5" + dataset_path = resolve_us_stage_artifact_contract_path( + path, + "08_dataset_assembly", + "policyengine_dataset", + ) if dataset_path.exists(): return dataset_path.resolve() raise FileNotFoundError(f"No policyengine_us.h5 export found under {path}") diff --git a/src/microplex_us/pipelines/site_snapshot.py b/src/microplex_us/pipelines/site_snapshot.py index 03495f1..5141053 100644 --- a/src/microplex_us/pipelines/site_snapshot.py +++ b/src/microplex_us/pipelines/site_snapshot.py @@ -13,6 +13,12 @@ require_saved_us_microplex_data_flow_snapshot, write_us_microplex_data_flow_snapshot, ) +from microplex_us.pipelines.stage_contracts import ( + resolve_us_stage_artifact_contract_path, +) +from microplex_us.pipelines.stage_run import ( + resolve_us_manifest_or_contract_artifact_path, +) FOCUS_TAG_PRIORITY: tuple[str, ...] 
= ( "state", @@ -51,7 +57,13 @@ def build_us_microplex_site_snapshot( "mean_abs_relative_error_delta", ), ) - harness = json.loads((artifact_root / "policyengine_harness.json").read_text()) + harness_path = resolve_us_manifest_or_contract_artifact_path( + artifact_root, + manifest, + "policyengine_harness", + stage_id="09_validation_benchmarking", + ) + harness = json.loads(harness_path.read_text()) summary = dict(harness.get("summary", {})) tag_summaries = { key: dict(value) @@ -62,14 +74,19 @@ def build_us_microplex_site_snapshot( synthesis = dict(manifest.get("synthesis", {})) calibration = dict(manifest.get("calibration", {})) config = dict(manifest.get("config", {})) - data_flow_path = artifact_root / "data_flow_snapshot.json" + data_flow_path = resolve_us_manifest_or_contract_artifact_path( + artifact_root, + manifest, + "data_flow_snapshot", + stage_id="08_dataset_assembly", + ) data_flow_snapshot = require_saved_us_microplex_data_flow_snapshot(artifact_root) source_artifact = { "artifactRef": _artifact_ref(artifact_root), "manifestFile": "manifest.json", - "harnessFile": "policyengine_harness.json", - "dataFlowFile": data_flow_path.name, + "harnessFile": _artifact_path_for_manifest(artifact_root, harness_path), + "dataFlowFile": _artifact_path_for_manifest(artifact_root, data_flow_path), "versionId": artifact_root.name, } if snapshot_path is not None: @@ -162,7 +179,11 @@ def write_us_microplex_site_snapshot( artifact_root = Path(artifact_dir) write_us_microplex_data_flow_snapshot( artifact_root, - artifact_root / "data_flow_snapshot.json", + resolve_us_stage_artifact_contract_path( + artifact_root, + "08_dataset_assembly", + "data_flow_snapshot", + ), ) snapshot = build_us_microplex_site_snapshot( artifact_root, @@ -181,6 +202,13 @@ def _artifact_ref(artifact_root: Path) -> str: return artifact_root.name +def _artifact_path_for_manifest(artifact_root: Path, path: Path) -> str: + try: + return str(path.relative_to(artifact_root)) + except ValueError: + 
return str(path) + + def _artifact_path_from_snapshot(artifact_root: Path, snapshot_path: Path) -> str: return os.path.relpath(artifact_root, snapshot_path.parent) diff --git a/src/microplex_us/pipelines/stage_artifacts.py b/src/microplex_us/pipelines/stage_artifacts.py new file mode 100644 index 0000000..4480a54 --- /dev/null +++ b/src/microplex_us/pipelines/stage_artifacts.py @@ -0,0 +1,829 @@ +"""Artifact inventory helpers for US Microplex saved runs.""" + +from __future__ import annotations + +import hashlib +import json +from collections.abc import Iterable, Mapping +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast + +import pandas as pd + +from microplex_us.pipelines.stage_contracts import ( + US_STAGE_CONTRACT_VERSION, + StageArtifactFormat, + StageArtifactHashMode, + StageArtifactResumeRole, +) +from microplex_us.pipelines.stage_manifest import ( + USStageManifest, + build_us_stage_manifest, + load_us_policyengine_entity_stage_artifact, +) + +if TYPE_CHECKING: + from microplex_us.pipelines.us import USMicroplexTargets + from microplex_us.policyengine import PolicyEngineUSEntityTableBundle + +US_STAGE_ARTIFACT_INVENTORY_SCHEMA_VERSION = 1 +DEFAULT_US_STAGE_ARTIFACT_HASH_MAX_BYTES: int | None = None + +USStageArtifactClassification = Literal[ + "contract_only", + "diagnostic_only", + "manual_replay", + "manual_resume", + "post_artifact_evidence", + "missing_required", + "missing_optional", + "metadata_only", +] + +USStageArtifactHashStatus = Literal[ + "hashed", + "not_requested", + "missing", + "too_large", + "unsupported", + "error", +] + + +class USStageArtifactInventoryRecord(TypedDict): + """Inventory view of one canonical stage artifact.""" + + stageId: str + stageStep: str + stageTitle: str + key: str + description: str + path: str | None + exists: bool + referenced: bool + required: bool + resumeRole: StageArtifactResumeRole | None + format: StageArtifactFormat + 
hashMode: StageArtifactHashMode + classification: USStageArtifactClassification + sizeBytes: int | None + fileCount: int | None + contentHash: str | None + hashStatus: USStageArtifactHashStatus + + +class USStageArtifactInventory(TypedDict): + """Machine-readable artifact inventory for one saved run.""" + + schemaVersion: int + contractVersion: str + generatedAt: str | None + pipeline: str + artifactRoot: str + manifest: str + stageManifest: str | None + artifacts: list[USStageArtifactInventoryRecord] + + +@dataclass(frozen=True) +class USSeedScaffoldStageArtifacts: + """Reloaded Stage 4 seed/scaffold artifact.""" + + scaffold_seed_data: pd.DataFrame + artifact_paths: Mapping[str, Path] = field(default_factory=dict) + + +@dataclass(frozen=True) +class USCandidateStageArtifacts: + """Reloaded Stage 5 candidate artifacts for manual downstream replay.""" + + seed_data: pd.DataFrame + synthetic_data: pd.DataFrame + artifact_paths: Mapping[str, Path] = field(default_factory=dict) + + +@dataclass(frozen=True) +class USCandidateCalibrationReplayArtifacts: + """Cross-stage artifacts for manually replaying candidate calibration.""" + + candidate: USCandidateStageArtifacts + targets: USMicroplexTargets + seed_scaffold: USSeedScaffoldStageArtifacts | None = None + artifact_paths: Mapping[str, Path] = field(default_factory=dict) + + +@dataclass(frozen=True) +class USPolicyEngineEntityStageArtifacts: + """Reloaded Stage 6 PolicyEngine entity-table checkpoint.""" + + bundle: PolicyEngineUSEntityTableBundle + metadata: dict[str, Any] + metadata_path: Path + + +@dataclass(frozen=True) +class USCalibratedStageArtifacts: + """Reloaded Stage 7 calibrated data and target metadata.""" + + calibrated_data: pd.DataFrame + targets: USMicroplexTargets + calibration_summary: dict[str, Any] + artifact_paths: Mapping[str, Path] = field(default_factory=dict) + + +@dataclass(frozen=True) +class USDatasetAssemblyArtifacts: + """Resolved Stage 8 dataset assembly artifacts.""" + + 
policyengine_dataset: Path + manifest: Path + stage_manifest: Path + data_flow_snapshot: Path + artifact_inventory: Path + conditional_readiness: Path + + +def build_us_stage_artifact_inventory( + artifact_dir: str | Path, + *, + manifest_payload: dict[str, Any] | None = None, + stage_manifest: USStageManifest | dict[str, Any] | None = None, + assume_existing_artifact_keys: Iterable[str] = (), + max_hash_bytes: int | None = DEFAULT_US_STAGE_ARTIFACT_HASH_MAX_BYTES, +) -> USStageArtifactInventory: + """Build an artifact inventory for one US Microplex saved-run directory.""" + + artifact_root = Path(artifact_dir) + manifest = ( + dict(manifest_payload) + if manifest_payload is not None + else json.loads((artifact_root / "manifest.json").read_text()) + ) + stages = ( + dict(stage_manifest) + if stage_manifest is not None + else build_us_stage_manifest( + artifact_root, + manifest_payload=manifest, + assume_existing_artifact_keys=assume_existing_artifact_keys, + ) + ) + artifacts: list[USStageArtifactInventoryRecord] = [] + for stage in stages.get("stages", ()): + if not isinstance(stage, dict): + continue + stage_id = str(stage.get("id", "")) + stage_step = str(stage.get("step", "")) + stage_title = str(stage.get("title", "")) + for artifact in stage.get("artifacts", ()): + if isinstance(artifact, dict): + artifacts.append( + _inventory_record( + artifact, + stage_id=stage_id, + stage_step=stage_step, + stage_title=stage_title, + artifact_root=artifact_root, + max_hash_bytes=max_hash_bytes, + ) + ) + + manifest_artifacts = dict(manifest.get("artifacts", {})) + return { + "schemaVersion": US_STAGE_ARTIFACT_INVENTORY_SCHEMA_VERSION, + "contractVersion": US_STAGE_CONTRACT_VERSION, + "generatedAt": _optional_str(manifest.get("created_at")), + "pipeline": "us_microplex", + "artifactRoot": ".", + "manifest": str(manifest_artifacts.get("manifest", "manifest.json")), + "stageManifest": _optional_str(manifest_artifacts.get("stage_manifest")), + "artifacts": artifacts, + } + + 
+def write_us_stage_artifact_inventory( + artifact_dir: str | Path, + output_path: str | Path, + *, + manifest_payload: dict[str, Any] | None = None, + stage_manifest: USStageManifest | dict[str, Any] | None = None, + assume_existing_artifact_keys: Iterable[str] = (), + max_hash_bytes: int | None = DEFAULT_US_STAGE_ARTIFACT_HASH_MAX_BYTES, +) -> Path: + """Write an artifact inventory sidecar for one saved run.""" + + destination = Path(output_path) + destination.parent.mkdir(parents=True, exist_ok=True) + _write_json_atomically( + destination, + build_us_stage_artifact_inventory( + artifact_dir, + manifest_payload=manifest_payload, + stage_manifest=stage_manifest, + assume_existing_artifact_keys=assume_existing_artifact_keys, + max_hash_bytes=max_hash_bytes, + ), + ) + return destination + + +def load_us_stage_artifact_inventory(path: str | Path) -> USStageArtifactInventory: + """Load a saved artifact inventory and validate its schema version.""" + + inventory_path = Path(path) + payload = json.loads(inventory_path.read_text()) + if payload.get("schemaVersion") != US_STAGE_ARTIFACT_INVENTORY_SCHEMA_VERSION: + raise RuntimeError( + "Unsupported US stage artifact inventory schema: " + f"{payload.get('schemaVersion')!r}" + ) + return cast(USStageArtifactInventory, payload) + + +def resolve_us_stage_artifact_from_inventory( + artifact_dir: str | Path, + inventory: USStageArtifactInventory | dict[str, Any], + stage_id: str, + artifact_key: str, +) -> Path: + """Resolve one artifact path from a stage artifact inventory.""" + + for artifact in inventory.get("artifacts", ()): + if not isinstance(artifact, dict): + continue + if artifact.get("stageId") != stage_id or artifact.get("key") != artifact_key: + continue + path_text = artifact.get("path") + if not path_text: + raise KeyError(f"Stage artifact has no path: {stage_id}.{artifact_key}") + path = Path(str(path_text)) + if not path.is_absolute(): + path = Path(artifact_dir) / path + return path + raise KeyError(f"Stage 
artifact not found: {stage_id}.{artifact_key}") + + +def resolve_us_stage_artifact_path_checked( + artifact_dir: str | Path, + stage_id: str, + artifact_key: str, + *, + manifest_payload: dict[str, Any] | None = None, + stage_manifest: USStageManifest | dict[str, Any] | None = None, + expected_format: StageArtifactFormat | None = None, + require_exists: bool = True, +) -> Path: + """Resolve one stage artifact path and enforce format/existence checks.""" + + artifact_root = Path(artifact_dir) + record = _stage_artifact_record( + artifact_root, + stage_id, + artifact_key, + manifest_payload=manifest_payload, + stage_manifest=stage_manifest, + ) + actual_format = cast(StageArtifactFormat, record.get("format") or "unknown") + if expected_format is not None and actual_format != expected_format: + raise ValueError( + f"Stage artifact {stage_id}.{artifact_key} has format " + f"{actual_format!r}, expected {expected_format!r}" + ) + path_text = record.get("path") + if not path_text: + raise KeyError(f"Stage artifact has no path: {stage_id}.{artifact_key}") + path = Path(str(path_text)) + if not path.is_absolute(): + path = artifact_root / path + if require_exists and not path.exists(): + raise FileNotFoundError(f"Stage artifact not found: {path}") + return path + + +def load_us_stage_parquet_artifact( + artifact_dir: str | Path, + stage_id: str, + artifact_key: str, + *, + manifest_payload: dict[str, Any] | None = None, + stage_manifest: USStageManifest | dict[str, Any] | None = None, +) -> pd.DataFrame: + """Load one stage-owned parquet dataframe artifact.""" + + path = resolve_us_stage_artifact_path_checked( + artifact_dir, + stage_id, + artifact_key, + manifest_payload=manifest_payload, + stage_manifest=stage_manifest, + expected_format="parquet_dataframe", + ) + return pd.read_parquet(path) + + +def load_us_stage_json_artifact( + artifact_dir: str | Path, + stage_id: str, + artifact_key: str, + *, + manifest_payload: dict[str, Any] | None = None, + stage_manifest: 
USStageManifest | dict[str, Any] | None = None, +) -> dict[str, Any]: + """Load one stage-owned JSON artifact.""" + + path = resolve_us_stage_artifact_path_checked( + artifact_dir, + stage_id, + artifact_key, + manifest_payload=manifest_payload, + stage_manifest=stage_manifest, + expected_format="json", + ) + payload = json.loads(path.read_text()) + if not isinstance(payload, dict): + raise ValueError(f"Expected JSON object in stage artifact: {path}") + return dict(payload) + + +def load_us_candidate_stage_artifacts( + artifact_dir: str | Path, + *, + manifest_payload: dict[str, Any] | None = None, + stage_manifest: USStageManifest | dict[str, Any] | None = None, +) -> USCandidateStageArtifacts: + """Load the saved Stage 5 candidate population artifacts.""" + + seed_path = resolve_us_stage_artifact_path_checked( + artifact_dir, + "05_donor_integration_synthesis", + "seed_data", + manifest_payload=manifest_payload, + stage_manifest=stage_manifest, + expected_format="parquet_dataframe", + ) + synthetic_path = resolve_us_stage_artifact_path_checked( + artifact_dir, + "05_donor_integration_synthesis", + "synthetic_data", + manifest_payload=manifest_payload, + stage_manifest=stage_manifest, + expected_format="parquet_dataframe", + ) + return USCandidateStageArtifacts( + seed_data=pd.read_parquet(seed_path), + synthetic_data=pd.read_parquet(synthetic_path), + artifact_paths={ + "seed_data": seed_path, + "synthetic_data": synthetic_path, + }, + ) + + +def load_us_seed_scaffold_stage_artifacts( + artifact_dir: str | Path, + *, + manifest_payload: dict[str, Any] | None = None, + stage_manifest: USStageManifest | dict[str, Any] | None = None, +) -> USSeedScaffoldStageArtifacts: + """Load the saved Stage 4 seed/scaffold artifact.""" + + scaffold_seed_path = resolve_us_stage_artifact_path_checked( + artifact_dir, + "04_seed_scaffold", + "scaffold_seed_data", + manifest_payload=manifest_payload, + stage_manifest=stage_manifest, + expected_format="parquet_dataframe", + ) + 
return USSeedScaffoldStageArtifacts( + scaffold_seed_data=pd.read_parquet(scaffold_seed_path), + artifact_paths={"scaffold_seed_data": scaffold_seed_path}, + ) + + +def load_us_candidate_calibration_replay_artifacts( + artifact_dir: str | Path, + *, + manifest_payload: dict[str, Any] | None = None, + stage_manifest: USStageManifest | dict[str, Any] | None = None, + include_seed_scaffold: bool = True, +) -> USCandidateCalibrationReplayArtifacts: + """Load the cross-stage artifacts needed to manually replay calibration.""" + + from microplex_us.pipelines.us import USMicroplexTargets + + candidate = load_us_candidate_stage_artifacts( + artifact_dir, + manifest_payload=manifest_payload, + stage_manifest=stage_manifest, + ) + targets_path = resolve_us_stage_artifact_path_checked( + artifact_dir, + "07_calibration", + "targets", + manifest_payload=manifest_payload, + stage_manifest=stage_manifest, + expected_format="json", + ) + seed_scaffold = None + if include_seed_scaffold: + try: + seed_scaffold = load_us_seed_scaffold_stage_artifacts( + artifact_dir, + manifest_payload=manifest_payload, + stage_manifest=stage_manifest, + ) + except (KeyError, FileNotFoundError): + seed_scaffold = None + targets_payload = json.loads(targets_path.read_text()) + artifact_paths = { + **dict(candidate.artifact_paths), + "targets": targets_path, + } + if seed_scaffold is not None: + artifact_paths.update(seed_scaffold.artifact_paths) + return USCandidateCalibrationReplayArtifacts( + candidate=candidate, + targets=USMicroplexTargets( + marginal=dict(targets_payload.get("marginal", {})), + continuous=dict(targets_payload.get("continuous", {})), + ), + seed_scaffold=seed_scaffold, + artifact_paths=artifact_paths, + ) + + +def load_us_policyengine_entity_stage_artifacts( + artifact_dir: str | Path, + *, + manifest_payload: dict[str, Any] | None = None, + stage_manifest: USStageManifest | dict[str, Any] | None = None, +) -> USPolicyEngineEntityStageArtifacts: + """Load the saved Stage 6 
PolicyEngine entity-table bundle.""" + + metadata_path = resolve_us_stage_artifact_path_checked( + artifact_dir, + "06_policyengine_entities", + "policyengine_entity_tables", + manifest_payload=manifest_payload, + stage_manifest=stage_manifest, + expected_format="policyengine_entity_bundle", + ) + bundle, metadata = load_us_policyengine_entity_stage_artifact(metadata_path) + return USPolicyEngineEntityStageArtifacts( + bundle=bundle, + metadata=metadata, + metadata_path=metadata_path, + ) + + +def load_us_calibrated_stage_artifacts( + artifact_dir: str | Path, + *, + manifest_payload: dict[str, Any] | None = None, + stage_manifest: USStageManifest | dict[str, Any] | None = None, +) -> USCalibratedStageArtifacts: + """Load saved Stage 7 calibrated outputs and calibration metadata.""" + + from microplex_us.pipelines.us import USMicroplexTargets + + calibrated_path = resolve_us_stage_artifact_path_checked( + artifact_dir, + "07_calibration", + "calibrated_data", + manifest_payload=manifest_payload, + stage_manifest=stage_manifest, + expected_format="parquet_dataframe", + ) + targets_path = resolve_us_stage_artifact_path_checked( + artifact_dir, + "07_calibration", + "targets", + manifest_payload=manifest_payload, + stage_manifest=stage_manifest, + expected_format="json", + ) + calibration_summary_path = resolve_us_stage_artifact_path_checked( + artifact_dir, + "07_calibration", + "calibration_summary", + manifest_payload=manifest_payload, + stage_manifest=stage_manifest, + expected_format="json", + ) + targets_payload = json.loads(targets_path.read_text()) + return USCalibratedStageArtifacts( + calibrated_data=pd.read_parquet(calibrated_path), + targets=USMicroplexTargets( + marginal=dict(targets_payload.get("marginal", {})), + continuous=dict(targets_payload.get("continuous", {})), + ), + calibration_summary=json.loads(calibration_summary_path.read_text()), + artifact_paths={ + "calibrated_data": calibrated_path, + "targets": targets_path, + "calibration_summary": 
calibration_summary_path, + }, + ) + + +def load_us_dataset_assembly_artifacts( + artifact_dir: str | Path, + *, + manifest_payload: dict[str, Any] | None = None, + stage_manifest: USStageManifest | dict[str, Any] | None = None, +) -> USDatasetAssemblyArtifacts: + """Resolve saved Stage 8 dataset assembly artifacts.""" + + artifact_root = Path(artifact_dir) + return USDatasetAssemblyArtifacts( + policyengine_dataset=resolve_us_stage_artifact_path_checked( + artifact_root, + "08_dataset_assembly", + "policyengine_dataset", + manifest_payload=manifest_payload, + stage_manifest=stage_manifest, + expected_format="h5_dataset", + ), + manifest=artifact_root / "manifest.json", + stage_manifest=resolve_us_stage_artifact_path_checked( + artifact_root, + "08_dataset_assembly", + "stage_manifest", + manifest_payload=manifest_payload, + stage_manifest=stage_manifest, + expected_format="json", + ), + data_flow_snapshot=resolve_us_stage_artifact_path_checked( + artifact_root, + "08_dataset_assembly", + "data_flow_snapshot", + manifest_payload=manifest_payload, + stage_manifest=stage_manifest, + expected_format="json", + ), + artifact_inventory=resolve_us_stage_artifact_path_checked( + artifact_root, + "08_dataset_assembly", + "artifact_inventory", + manifest_payload=manifest_payload, + stage_manifest=stage_manifest, + expected_format="json", + ), + conditional_readiness=resolve_us_stage_artifact_path_checked( + artifact_root, + "08_dataset_assembly", + "conditional_readiness", + manifest_payload=manifest_payload, + stage_manifest=stage_manifest, + expected_format="json", + ), + ) + + +def _stage_artifact_record( + artifact_root: Path, + stage_id: str, + artifact_key: str, + *, + manifest_payload: dict[str, Any] | None, + stage_manifest: USStageManifest | dict[str, Any] | None, +) -> dict[str, Any]: + manifest = ( + dict(manifest_payload) + if manifest_payload is not None + else json.loads((artifact_root / "manifest.json").read_text()) + ) + stages = ( + dict(stage_manifest) + if 
stage_manifest is not None + else build_us_stage_manifest(artifact_root, manifest_payload=manifest) + ) + for stage in stages.get("stages", ()): + if not isinstance(stage, dict) or stage.get("id") != stage_id: + continue + for artifact in stage.get("artifacts", ()): + if isinstance(artifact, dict) and artifact.get("key") == artifact_key: + return dict(artifact) + raise KeyError(f"Stage artifact not found: {stage_id}.{artifact_key}") + + +def _resolve_optional_stage_artifact_path( + artifact_dir: str | Path, + stage_id: str, + artifact_key: str, + *, + manifest_payload: dict[str, Any] | None, + stage_manifest: USStageManifest | dict[str, Any] | None, + expected_format: StageArtifactFormat, +) -> Path | None: + try: + return resolve_us_stage_artifact_path_checked( + artifact_dir, + stage_id, + artifact_key, + manifest_payload=manifest_payload, + stage_manifest=stage_manifest, + expected_format=expected_format, + ) + except (KeyError, FileNotFoundError): + return None + + +def _inventory_record( + artifact: dict[str, Any], + *, + stage_id: str, + stage_step: str, + stage_title: str, + artifact_root: Path, + max_hash_bytes: int | None, +) -> USStageArtifactInventoryRecord: + path_text = _optional_str(artifact.get("path")) + resolved_path = _resolve_artifact_path(artifact_root, path_text) + artifact_format = cast( + StageArtifactFormat, + artifact.get("format") or "unknown", + ) + hash_mode = cast( + StageArtifactHashMode, + artifact.get("hash_mode") or "none", + ) + hash_target = _hash_target_path(resolved_path, artifact_format, hash_mode) + size_bytes, file_count = _artifact_size(hash_target) + content_hash, hash_status = _artifact_hash( + hash_target, + hash_mode=hash_mode, + max_hash_bytes=max_hash_bytes, + ) + return { + "stageId": stage_id, + "stageStep": stage_step, + "stageTitle": stage_title, + "key": str(artifact.get("key", "")), + "description": str(artifact.get("description", "")), + "path": path_text, + "exists": bool(artifact.get("exists")), + 
"referenced": bool(artifact.get("referenced")), + "required": bool(artifact.get("required")), + "resumeRole": cast(StageArtifactResumeRole | None, artifact.get("resume_role")), + "format": artifact_format, + "hashMode": hash_mode, + "classification": _artifact_classification(artifact), + "sizeBytes": size_bytes, + "fileCount": file_count, + "contentHash": content_hash, + "hashStatus": hash_status, + } + + +def _artifact_classification( + artifact: Mapping[str, Any], +) -> USStageArtifactClassification: + if not bool(artifact.get("exists")): + if bool(artifact.get("required")): + return "missing_required" + if bool(artifact.get("referenced")): + return "missing_optional" + return "contract_only" + resume_role = artifact.get("resume_role") + if resume_role == "diagnostic": + return "diagnostic_only" + if resume_role in {"manual_replay", "manual_resume", "post_artifact_evidence"}: + return cast(USStageArtifactClassification, resume_role) + return "metadata_only" + + +def _resolve_artifact_path(artifact_root: Path, path_text: str | None) -> Path | None: + if path_text is None: + return None + path = Path(path_text) + if not path.is_absolute(): + path = artifact_root / path + return path + + +def _hash_target_path( + path: Path | None, + artifact_format: StageArtifactFormat, + hash_mode: StageArtifactHashMode, +) -> Path | None: + if path is None or hash_mode != "directory_sha256": + return path + if artifact_format == "policyengine_entity_bundle" and path.name == "metadata.json": + return path.parent + return path + + +def _artifact_size(path: Path | None) -> tuple[int | None, int | None]: + if path is None or not path.exists(): + return None, None + if path.is_file(): + return path.stat().st_size, 1 + if path.is_dir(): + total = 0 + count = 0 + for child in _iter_directory_files(path): + total += child.stat().st_size + count += 1 + return total, count + return None, None + + +def _artifact_hash( + path: Path | None, + *, + hash_mode: StageArtifactHashMode, + 
max_hash_bytes: int | None, +) -> tuple[str | None, USStageArtifactHashStatus]: + if hash_mode == "none": + return None, "not_requested" + if path is None or not path.exists(): + return None, "missing" + try: + if hash_mode == "file_sha256": + if not path.is_file(): + return None, "unsupported" + size = path.stat().st_size + if max_hash_bytes is not None and size > max_hash_bytes: + return None, "too_large" + return _hash_file(path), "hashed" + if hash_mode == "directory_sha256": + if not path.is_dir(): + return None, "unsupported" + size, _ = _artifact_size(path) + if ( + max_hash_bytes is not None + and size is not None + and size > max_hash_bytes + ): + return None, "too_large" + return _hash_directory(path), "hashed" + except OSError: + return None, "error" + return None, "unsupported" + + +def _hash_file(path: Path) -> str: + hasher = hashlib.sha256() + with path.open("rb") as handle: + for chunk in iter(lambda: handle.read(1024 * 1024), b""): + hasher.update(chunk) + return hasher.hexdigest() + + +def _hash_directory(path: Path) -> str: + hasher = hashlib.sha256() + for child in _iter_directory_files(path): + relative = child.relative_to(path).as_posix() + hasher.update(relative.encode("utf-8")) + hasher.update(b"\0") + hasher.update(_hash_file(child).encode("ascii")) + hasher.update(b"\0") + return hasher.hexdigest() + + +def _iter_directory_files(path: Path) -> list[Path]: + return sorted(child for child in path.rglob("*") if child.is_file()) + + +def _optional_str(value: Any) -> str | None: + if value is None: + return None + return str(value) + + +def _write_json_atomically(path: Path, payload: Mapping[str, Any]) -> None: + temporary = path.with_suffix(path.suffix + ".tmp") + temporary.write_text(json.dumps(payload, indent=2, sort_keys=True)) + temporary.replace(path) + + +__all__ = [ + "DEFAULT_US_STAGE_ARTIFACT_HASH_MAX_BYTES", + "US_STAGE_ARTIFACT_INVENTORY_SCHEMA_VERSION", + "USCalibratedStageArtifacts", + "USCandidateStageArtifacts", + 
"USCandidateCalibrationReplayArtifacts", + "USDatasetAssemblyArtifacts", + "USPolicyEngineEntityStageArtifacts", + "USSeedScaffoldStageArtifacts", + "USStageArtifactClassification", + "USStageArtifactHashStatus", + "USStageArtifactInventory", + "USStageArtifactInventoryRecord", + "build_us_stage_artifact_inventory", + "load_us_calibrated_stage_artifacts", + "load_us_candidate_calibration_replay_artifacts", + "load_us_candidate_stage_artifacts", + "load_us_dataset_assembly_artifacts", + "load_us_policyengine_entity_stage_artifacts", + "load_us_seed_scaffold_stage_artifacts", + "load_us_stage_json_artifact", + "load_us_stage_parquet_artifact", + "load_us_stage_artifact_inventory", + "resolve_us_stage_artifact_path_checked", + "resolve_us_stage_artifact_from_inventory", + "write_us_stage_artifact_inventory", +] diff --git a/src/microplex_us/pipelines/stage_contracts.py b/src/microplex_us/pipelines/stage_contracts.py index cffa4fa..d9fdc5d 100644 --- a/src/microplex_us/pipelines/stage_contracts.py +++ b/src/microplex_us/pipelines/stage_contracts.py @@ -3,9 +3,10 @@ from __future__ import annotations from dataclasses import asdict, dataclass +from pathlib import Path from typing import Literal -US_STAGE_CONTRACT_VERSION = "us-runtime-stages-v1" +US_STAGE_CONTRACT_VERSION = "us-runtime-stages-v2" StageResumeMode = Literal[ "none", @@ -15,6 +16,38 @@ "post_artifact_evidence", ] +StageArtifactResumeRole = Literal[ + "diagnostic", + "manual_replay", + "manual_resume", + "post_artifact_evidence", +] + +StageArtifactFormat = Literal[ + "json", + "parquet_dataframe", + "policyengine_entity_bundle", + "h5_dataset", + "model_file", + "sqlite", + "unknown", +] + +StageArtifactHashMode = Literal[ + "none", + "file_sha256", + "directory_sha256", +] + +StageResourceKind = Literal[ + "artifact", + "config", + "external_data", + "manifest", + "runtime_object", + "stage_output", +] + US_CANONICAL_STAGE_IDS = ( "01_run_profile", "02_source_loading", @@ -74,7 +107,9 @@ class 
USStageArtifactContract: description: str path_hint: str | None = None required: bool = False - resume_role: str | None = None + resume_role: StageArtifactResumeRole | None = None + format: StageArtifactFormat = "unknown" + hash_mode: StageArtifactHashMode = "none" def to_dict(self) -> dict[str, object]: return asdict(self) @@ -92,6 +127,23 @@ def to_dict(self) -> dict[str, object]: return asdict(self) +@dataclass(frozen=True) +class USStageResourceContract: + """Structured input or output dependency for one canonical build stage.""" + + key: str + description: str + kind: StageResourceKind + required: bool = True + stage_id: str | None = None + artifact_key: str | None = None + config_key: str | None = None + manifest_key: str | None = None + + def to_dict(self) -> dict[str, object]: + return asdict(self) + + @dataclass(frozen=True) class USPipelineStageContract: """Stable contract for one canonical US Microplex runtime stage.""" @@ -102,6 +154,8 @@ class USPipelineStageContract: purpose: str consumes: tuple[str, ...] produces: tuple[str, ...] + inputs: tuple[USStageResourceContract, ...] + outputs: tuple[USStageResourceContract, ...] artifacts: tuple[USStageArtifactContract, ...] diagnostics: tuple[str, ...] validations: tuple[USStageValidationContract, ...] 
@@ -110,6 +164,8 @@ class USPipelineStageContract: def to_dict(self) -> dict[str, object]: payload = asdict(self) + payload["inputs"] = [resource.to_dict() for resource in self.inputs] + payload["outputs"] = [resource.to_dict() for resource in self.outputs] payload["artifacts"] = [artifact.to_dict() for artifact in self.artifacts] payload["validations"] = [ validation.to_dict() for validation in self.validations @@ -117,6 +173,100 @@ def to_dict(self) -> dict[str, object]: return payload +def _artifact_resource( + key: str, + description: str, + *, + stage_id: str, + artifact_key: str | None = None, + required: bool = True, +) -> USStageResourceContract: + return USStageResourceContract( + key=key, + description=description, + kind="artifact", + required=required, + stage_id=stage_id, + artifact_key=artifact_key or key, + ) + + +def _config_resource( + key: str, + description: str, + *, + config_key: str | None = None, + required: bool = True, +) -> USStageResourceContract: + return USStageResourceContract( + key=key, + description=description, + kind="config", + required=required, + config_key=config_key or key, + ) + + +def _external_resource( + key: str, + description: str, + *, + required: bool = True, +) -> USStageResourceContract: + return USStageResourceContract( + key=key, + description=description, + kind="external_data", + required=required, + ) + + +def _manifest_resource( + key: str, + description: str, + *, + manifest_key: str | None = None, + required: bool = True, +) -> USStageResourceContract: + return USStageResourceContract( + key=key, + description=description, + kind="manifest", + required=required, + manifest_key=manifest_key or key, + ) + + +def _runtime_resource( + key: str, + description: str, + *, + required: bool = True, +) -> USStageResourceContract: + return USStageResourceContract( + key=key, + description=description, + kind="runtime_object", + required=required, + ) + + +def _stage_output_resource( + key: str, + description: str, + *, 
+ stage_id: str, + required: bool = True, +) -> USStageResourceContract: + return USStageResourceContract( + key=key, + description=description, + kind="stage_output", + required=required, + stage_id=stage_id, + ) + + def default_us_pipeline_stage_contracts() -> tuple[USPipelineStageContract, ...]: """Return the canonical 9-stage US Microplex runtime taxonomy.""" @@ -128,12 +278,52 @@ def default_us_pipeline_stage_contracts() -> tuple[USPipelineStageContract, ...] purpose="Resolve the build profile, runtime config, providers, queries, and run-level options.", consumes=("user configuration", "provider defaults", "runtime overrides"), produces=("resolved build config", "provider/query plan"), + inputs=( + _config_resource( + "build_profile", + "Selected build profile and runtime overrides.", + config_key="profile", + required=False, + ), + _config_resource( + "policyengine_target_period", + "Target period used by downstream PolicyEngine export and validation.", + ), + _config_resource( + "calibration_backend", + "Calibration backend selected for this run.", + ), + _config_resource( + "source_names", + "Requested source names or provider defaults.", + required=False, + ), + ), + outputs=( + _artifact_resource( + "manifest", + "Top-level manifest containing resolved configuration and artifact map.", + stage_id="01_run_profile", + ), + _stage_output_resource( + "resolved_config", + "Resolved build configuration recorded for downstream stages.", + stage_id="01_run_profile", + ), + _stage_output_resource( + "provider_query_plan", + "Resolved provider and source-query plan for source loading.", + stage_id="01_run_profile", + ), + ), artifacts=( USStageArtifactContract( key="manifest", description="Top-level artifact manifest with resolved config.", path_hint="manifest.json", required=True, + format="json", + hash_mode="file_sha256", ), ), diagnostics=( @@ -166,6 +356,34 @@ def default_us_pipeline_stage_contracts() -> tuple[USPipelineStageContract, ...] 
"source descriptors", "entity relationships", ), + inputs=( + _stage_output_resource( + "provider_query_plan", + "Resolved provider and source-query plan from Stage 1.", + stage_id="01_run_profile", + ), + _external_resource( + "source_datasets", + "External source datasets requested by the provider/query plan.", + ), + ), + outputs=( + _stage_output_resource( + "observation_frame_summary", + "Saved summary of loaded Microplex observation frames with source metadata.", + stage_id="02_source_loading", + ), + _stage_output_resource( + "source_descriptors", + "Source descriptors attached to the loaded observation frames.", + stage_id="02_source_loading", + ), + _stage_output_resource( + "source_relationships", + "Validated entity relationships in loaded source frames.", + stage_id="02_source_loading", + ), + ), artifacts=(), diagnostics=( "source row counts", @@ -189,12 +407,37 @@ def default_us_pipeline_stage_contracts() -> tuple[USPipelineStageContract, ...] purpose="Choose the scaffold source and map donor/source coverage before seed construction.", consumes=("observation frames", "source descriptors"), produces=("fusion plan", "scaffold selection", "donor/source plan"), + inputs=( + _runtime_resource( + "observation_frames", + "Loaded observation frames from Stage 2.", + ), + _runtime_resource( + "source_descriptors", + "Source descriptors attached to the loaded frames.", + ), + ), + outputs=( + _artifact_resource( + "source_plan", + "Saved scaffold and donor/source planning summary.", + stage_id="03_source_planning", + ), + _stage_output_resource( + "scaffold_selection", + "Selected scaffold/backbone source and donor plan.", + stage_id="03_source_planning", + ), + ), artifacts=( USStageArtifactContract( key="source_plan", description="Compact JSON summary of source names, scaffold, and donor variable plan.", path_hint="stage_artifacts/03_source_planning/source_plan.json", + required=True, resume_role="diagnostic", + format="json", + hash_mode="file_sha256", ), ), 
diagnostics=( @@ -219,6 +462,34 @@ def default_us_pipeline_stage_contracts() -> tuple[USPipelineStageContract, ...] purpose="Project the selected scaffold source into the canonical seed structure.", consumes=("source plan", "scaffold frame", "identifier rules"), produces=("scaffold-derived seed frame", "seed schema metadata"), + inputs=( + _artifact_resource( + "source_plan", + "Saved scaffold and donor/source planning summary from Stage 3.", + stage_id="03_source_planning", + ), + _stage_output_resource( + "scaffold_selection", + "Selected scaffold/backbone source from Stage 3.", + stage_id="03_source_planning", + ), + _runtime_resource( + "scaffold_frame", + "Loaded source frame selected as the population scaffold.", + ), + ), + outputs=( + _artifact_resource( + "scaffold_seed_data", + "Scaffold-projected seed population before donor integration.", + stage_id="04_seed_scaffold", + ), + _stage_output_resource( + "seed_schema_metadata", + "Canonical identifier and required-column metadata for the seed.", + stage_id="04_seed_scaffold", + ), + ), artifacts=( USStageArtifactContract( key="scaffold_seed_data", @@ -226,6 +497,8 @@ def default_us_pipeline_stage_contracts() -> tuple[USPipelineStageContract, ...] path_hint="stage_artifacts/04_seed_scaffold/scaffold_seed_data.parquet", required=True, resume_role="manual_replay", + format="parquet_dataframe", + hash_mode="file_sha256", ), ), diagnostics=( @@ -259,6 +532,112 @@ def default_us_pipeline_stage_contracts() -> tuple[USPipelineStageContract, ...] 
"synthetic/candidate frame", "synthesis metadata", ), + inputs=( + _artifact_resource( + "scaffold_seed_data", + "Scaffold-projected seed population from Stage 4.", + stage_id="04_seed_scaffold", + ), + _runtime_resource( + "donor_frames", + "Loaded donor source frames used for variable integration.", + ), + _config_resource( + "synthesis_backend", + "Configured synthesis backend.", + ), + _config_resource( + "n_synthetic", + "Requested synthetic population size.", + required=False, + ), + _config_resource( + "random_seed", + "Random seed used by donor integration and synthesis.", + ), + _config_resource( + "synthesizer_condition_vars", + "Configured synthesis conditioning variables.", + required=False, + ), + _config_resource( + "synthesizer_target_vars", + "Configured synthesis target variables.", + required=False, + ), + _config_resource( + "synthesizer_epochs", + "Configured synthesizer training epochs.", + required=False, + ), + _config_resource( + "synthesizer_batch_size", + "Configured synthesizer batch size.", + required=False, + ), + _config_resource( + "synthesizer_learning_rate", + "Configured synthesizer learning rate.", + required=False, + ), + _config_resource( + "synthesizer_n_layers", + "Configured synthesizer network depth.", + required=False, + ), + _config_resource( + "synthesizer_hidden_dim", + "Configured synthesizer hidden dimension.", + required=False, + ), + _config_resource( + "donor_imputer_backend", + "Configured donor imputer backend.", + required=False, + ), + _config_resource( + "donor_imputer_condition_selection", + "Configured donor imputer condition selection strategy.", + required=False, + ), + _config_resource( + "donor_imputer_max_condition_vars", + "Configured donor imputer condition-variable cap.", + required=False, + ), + _config_resource( + "donor_imputer_excluded_variables", + "Variables excluded from donor imputation.", + required=False, + ), + _config_resource( + "donor_imputer_authoritative_override_variables", + 
"Variables treated as authoritative donor overrides.", + required=False, + ), + _config_resource( + "bootstrap_strata_columns", + "Bootstrap strata columns used by seed/bootstrap synthesis.", + required=False, + ), + ), + outputs=( + _artifact_resource( + "seed_data", + "Seed population after donor integration and semantic guards.", + stage_id="05_donor_integration_synthesis", + ), + _artifact_resource( + "synthetic_data", + "Candidate population before final calibration.", + stage_id="05_donor_integration_synthesis", + ), + _manifest_resource( + "synthesis_metadata", + "Synthesis metadata recorded in the saved manifest.", + manifest_key="synthesis", + ), + ), artifacts=( USStageArtifactContract( key="seed_data", @@ -266,6 +645,8 @@ def default_us_pipeline_stage_contracts() -> tuple[USPipelineStageContract, ...] path_hint="seed_data.parquet", required=True, resume_role="diagnostic", + format="parquet_dataframe", + hash_mode="file_sha256", ), USStageArtifactContract( key="synthetic_data", @@ -273,12 +654,24 @@ def default_us_pipeline_stage_contracts() -> tuple[USPipelineStageContract, ...] path_hint="synthetic_data.parquet", required=True, resume_role="manual_replay", + format="parquet_dataframe", + hash_mode="file_sha256", ), USStageArtifactContract( key="synthesizer", description="Optional fitted synthesis model.", path_hint="synthesizer.pt", resume_role="diagnostic", + format="model_file", + hash_mode="file_sha256", + ), + USStageArtifactContract( + key="source_weight_diagnostics", + description="Diagnostic summary of source-level contribution weights.", + path_hint="source_weight_diagnostics.json", + resume_role="diagnostic", + format="json", + hash_mode="file_sha256", ), ), diagnostics=( @@ -306,12 +699,38 @@ def default_us_pipeline_stage_contracts() -> tuple[USPipelineStageContract, ...] 
purpose="Convert candidate rows into PE entity tables and materialize PE-facing inputs.", consumes=("synthetic/candidate frame", "PE input mapping rules"), produces=("PolicyEngine entity table bundle", "materialized PE variables"), + inputs=( + _artifact_resource( + "synthetic_data", + "Candidate population from Stage 5.", + stage_id="05_donor_integration_synthesis", + ), + _runtime_resource( + "policyengine_mapping_rules", + "Rules mapping Microplex candidate rows into PolicyEngine entities.", + ), + ), + outputs=( + _artifact_resource( + "policyengine_entity_tables", + "Reloadable PolicyEngine entity-table checkpoint.", + stage_id="06_policyengine_entities", + ), + _stage_output_resource( + "materialized_policyengine_inputs", + "PolicyEngine-facing variables materialized for calibration/export.", + stage_id="06_policyengine_entities", + ), + ), artifacts=( USStageArtifactContract( key="policyengine_entity_tables", description="Reloadable PE entity-table bundle saved as parquet files plus metadata.", path_hint="stage_artifacts/06_policyengine_entities/metadata.json", + required=True, resume_role="manual_resume", + format="policyengine_entity_bundle", + hash_mode="directory_sha256", ), ), diagnostics=( @@ -334,8 +753,84 @@ def default_us_pipeline_stage_contracts() -> tuple[USPipelineStageContract, ...] 
step="07", title="Target resolution, selection, and calibration", purpose="Resolve target constraints, solve weights, and summarize fit quality.", - consumes=("PE entity table bundle", "target provider/query", "calibration config"), + consumes=( + "PE entity table bundle", + "target provider/query", + "calibration config", + ), produces=("calibrated tables", "calibration summary", "target ledger"), + inputs=( + _artifact_resource( + "policyengine_entity_tables", + "PolicyEngine entity-table checkpoint from Stage 6.", + stage_id="06_policyengine_entities", + ), + _external_resource( + "target_provider", + "Target provider or target database queried for calibration.", + ), + _config_resource( + "calibration_backend", + "Configured calibration backend.", + ), + _config_resource( + "calibration_tol", + "Configured calibration tolerance.", + required=False, + ), + _config_resource( + "calibration_max_iter", + "Configured maximum calibration iterations or epochs.", + required=False, + ), + _config_resource( + "target_sparsity", + "Configured sparse-target selection pressure.", + required=False, + ), + _config_resource( + "policyengine_quantity_targets", + "Configured PolicyEngine quantity targets.", + required=False, + ), + _config_resource( + "policyengine_targets_db", + "PolicyEngine target database used for calibration.", + required=False, + ), + _config_resource( + "policyengine_calibration_target_variables", + "Configured calibration target variables.", + required=False, + ), + _config_resource( + "policyengine_calibration_target_domains", + "Configured calibration target domains.", + required=False, + ), + _config_resource( + "policyengine_calibration_geo_levels", + "Configured calibration geography levels.", + required=False, + ), + ), + outputs=( + _artifact_resource( + "calibrated_data", + "Calibrated output frame.", + stage_id="07_calibration", + ), + _artifact_resource( + "targets", + "Target payload used by the build.", + stage_id="07_calibration", + ), + 
_artifact_resource( + "calibration_summary", + "Stage-local calibration summary.", + stage_id="07_calibration", + ), + ), artifacts=( USStageArtifactContract( key="calibrated_data", @@ -343,6 +838,8 @@ def default_us_pipeline_stage_contracts() -> tuple[USPipelineStageContract, ...] path_hint="calibrated_data.parquet", required=True, resume_role="manual_replay", + format="parquet_dataframe", + hash_mode="file_sha256", ), USStageArtifactContract( key="targets", @@ -350,12 +847,17 @@ def default_us_pipeline_stage_contracts() -> tuple[USPipelineStageContract, ...] path_hint="targets.json", required=True, resume_role="manual_replay", + format="json", + hash_mode="file_sha256", ), USStageArtifactContract( key="calibration_summary", description="Stage-local calibration summary JSON.", path_hint="stage_artifacts/07_calibration/calibration_summary.json", + required=True, resume_role="diagnostic", + format="json", + hash_mode="file_sha256", ), ), diagnostics=( @@ -380,26 +882,111 @@ def default_us_pipeline_stage_contracts() -> tuple[USPipelineStageContract, ...] 
step="08", title="Dataset assembly and publication", purpose="Assemble the calibrated output into the distributable PE dataset artifact.", - consumes=("calibrated entity tables", "export variable maps", "period config"), - produces=("PolicyEngine H5 dataset", "artifact manifest", "data-flow snapshot"), + consumes=( + "calibrated entity tables", + "export variable maps", + "period config", + ), + produces=( + "PolicyEngine H5 dataset", + "artifact manifest", + "data-flow snapshot", + ), + inputs=( + _artifact_resource( + "calibrated_data", + "Calibrated output frame from Stage 7.", + stage_id="07_calibration", + ), + _artifact_resource( + "policyengine_entity_tables", + "PolicyEngine entity-table checkpoint from Stage 6.", + stage_id="06_policyengine_entities", + ), + _config_resource( + "policyengine_dataset_year", + "PolicyEngine dataset period used during H5 export.", + required=False, + ), + ), + outputs=( + _artifact_resource( + "policyengine_dataset", + "PolicyEngine-readable H5 dataset.", + stage_id="08_dataset_assembly", + ), + _artifact_resource( + "stage_manifest", + "Canonical saved-run stage manifest.", + stage_id="08_dataset_assembly", + ), + _artifact_resource( + "data_flow_snapshot", + "Site-facing saved-run pipeline snapshot.", + stage_id="08_dataset_assembly", + ), + _artifact_resource( + "artifact_inventory", + "Stage-owned artifact inventory.", + stage_id="08_dataset_assembly", + ), + _artifact_resource( + "conditional_readiness", + "Conditional-readiness report.", + stage_id="08_dataset_assembly", + ), + ), artifacts=( USStageArtifactContract( key="policyengine_dataset", description="PolicyEngine-readable H5 dataset.", path_hint="policyengine_us.h5", + required=True, resume_role="post_artifact_evidence", + format="h5_dataset", + hash_mode="file_sha256", + ), + USStageArtifactContract( + key="capital_gains_lots", + description="Optional synthetic capital-gains lot sidecar database.", + path_hint="capital_gains_lots.sqlite", + 
resume_role="diagnostic", + format="sqlite", + hash_mode="file_sha256", ), USStageArtifactContract( key="stage_manifest", description="Canonical stage manifest for the saved run.", path_hint="stage_manifest.json", required=True, + format="json", + hash_mode="file_sha256", ), USStageArtifactContract( key="data_flow_snapshot", description="Site-facing saved-run pipeline snapshot.", path_hint="data_flow_snapshot.json", required=True, + format="json", + hash_mode="file_sha256", + ), + USStageArtifactContract( + key="artifact_inventory", + description="Stage-owned artifact inventory with existence, role, and hash metadata.", + path_hint="stage_artifacts/artifact_inventory.json", + required=True, + resume_role="diagnostic", + format="json", + hash_mode="none", + ), + USStageArtifactContract( + key="conditional_readiness", + description="Conditional-readiness report for manual reuse decisions.", + path_hint="stage_artifacts/conditional_readiness.json", + required=True, + resume_role="diagnostic", + format="json", + hash_mode="none", ), ), diagnostics=( @@ -422,26 +1009,118 @@ def default_us_pipeline_stage_contracts() -> tuple[USPipelineStageContract, ...] 
step="09", title="Validation and benchmarking", purpose="Evaluate the assembled dataset and attach benchmark evidence.", - consumes=("PolicyEngine H5 dataset", "baseline dataset", "target provider/query"), - produces=("harness evidence", "native scores", "audits", "run registry/index evidence"), + consumes=( + "PolicyEngine H5 dataset", + "baseline dataset", + "target provider/query", + ), + produces=( + "harness evidence", + "native scores", + "audits", + "run registry/index evidence", + ), + inputs=( + _artifact_resource( + "policyengine_dataset", + "PolicyEngine-readable H5 dataset from Stage 8.", + stage_id="08_dataset_assembly", + ), + _external_resource( + "baseline_dataset", + "Baseline dataset used by validation or comparison harnesses.", + required=False, + ), + _external_resource( + "target_provider", + "Target provider or target database used for benchmark evidence.", + required=False, + ), + _config_resource( + "policyengine_dataset_year", + "PolicyEngine dataset period used during validation.", + required=False, + ), + ), + outputs=( + _artifact_resource( + "validation_evidence", + "Stage-local evidence manifest for validation sidecars.", + stage_id="09_validation_benchmarking", + ), + _stage_output_resource( + "benchmark_summary", + "Saved summary of validation and benchmark evidence attached to the run.", + stage_id="09_validation_benchmarking", + ), + _artifact_resource( + "policyengine_harness", + "PolicyEngine harness comparison payload.", + stage_id="09_validation_benchmarking", + required=False, + ), + _artifact_resource( + "policyengine_native_scores", + "PE-US-data native score comparison payload.", + stage_id="09_validation_benchmarking", + required=False, + ), + _artifact_resource( + "policyengine_native_audit", + "PE-US-data native score audit payload.", + stage_id="09_validation_benchmarking", + required=False, + ), + ), artifacts=( USStageArtifactContract( key="policyengine_harness", description="PolicyEngine harness comparison payload.", 
path_hint="policyengine_harness.json", resume_role="diagnostic", + format="json", + hash_mode="file_sha256", ), USStageArtifactContract( key="policyengine_native_scores", description="PE-US-data native score comparison payload.", path_hint="policyengine_native_scores.json", resume_role="diagnostic", + format="json", + hash_mode="file_sha256", + ), + USStageArtifactContract( + key="policyengine_native_audit", + description="PE-US-data native score audit payload.", + path_hint="pe_us_data_rebuild_native_audit.json", + resume_role="diagnostic", + format="json", + hash_mode="file_sha256", + ), + USStageArtifactContract( + key="imputation_ablation", + description="Imputation ablation benchmark payload.", + path_hint="imputation_ablation.json", + resume_role="diagnostic", + format="json", + hash_mode="file_sha256", + ), + USStageArtifactContract( + key="child_tax_unit_agi_drift", + description="Child tax-unit AGI drift diagnostic payload.", + path_hint="child_tax_unit_agi_drift.json", + resume_role="diagnostic", + format="json", + hash_mode="file_sha256", ), USStageArtifactContract( key="validation_evidence", description="Stage-local evidence manifest for validation sidecars.", path_hint="stage_artifacts/09_validation_benchmarking/evidence_manifest.json", + required=True, resume_role="diagnostic", + format="json", + hash_mode="file_sha256", ), ), diagnostics=( @@ -472,6 +1151,45 @@ def get_us_pipeline_stage_contract(stage_id: str) -> USPipelineStageContract: raise KeyError(f"Unknown US pipeline stage contract: {stage_id}") +def get_us_stage_artifact_contract( + stage_id: str, + artifact_key: str, +) -> USStageArtifactContract: + """Return one artifact contract from a canonical stage.""" + + contract = get_us_pipeline_stage_contract(stage_id) + for artifact in contract.artifacts: + if artifact.key == artifact_key: + return artifact + raise KeyError(f"Unknown US stage artifact contract: {stage_id}.{artifact_key}") + + +def resolve_us_stage_artifact_contract_path( + 
artifact_dir: str | Path, + stage_id: str, + artifact_key: str, +) -> Path: + """Resolve a stage artifact's canonical path from its contract path hint.""" + + artifact = get_us_stage_artifact_contract(stage_id, artifact_key) + if artifact.path_hint is None: + raise KeyError(f"US stage artifact has no path hint: {stage_id}.{artifact_key}") + return Path(artifact_dir) / artifact.path_hint + + +def config_keys_for_us_pipeline_stage(stage_id: str) -> tuple[str, ...]: + """Return config keys that affect one canonical stage's reuse checks.""" + + contract = get_us_pipeline_stage_contract(stage_id) + return tuple( + dict.fromkeys( + resource.config_key + for resource in contract.inputs + if resource.kind == "config" and resource.config_key is not None + ) + ) + + def serialize_us_pipeline_stage_contracts() -> dict[str, object]: """Serialize the canonical US stage contract registry.""" @@ -485,14 +1203,23 @@ def serialize_us_pipeline_stage_contracts() -> dict[str, object]: __all__ = [ + "StageArtifactFormat", + "StageArtifactHashMode", + "StageArtifactResumeRole", + "StageResourceKind", + "StageResumeMode", "US_CANONICAL_STAGE_IDS", "US_LEGACY_STAGE_ID_ALIASES", "US_STAGE_CONTRACT_VERSION", "USPipelineStageContract", "USStageArtifactContract", + "USStageResourceContract", "USStageValidationContract", "canonicalize_us_pipeline_stage_id", + "config_keys_for_us_pipeline_stage", "default_us_pipeline_stage_contracts", + "get_us_stage_artifact_contract", "get_us_pipeline_stage_contract", + "resolve_us_stage_artifact_contract_path", "serialize_us_pipeline_stage_contracts", ] diff --git a/src/microplex_us/pipelines/stage_data_flow.py b/src/microplex_us/pipelines/stage_data_flow.py new file mode 100644 index 0000000..fa91098 --- /dev/null +++ b/src/microplex_us/pipelines/stage_data_flow.py @@ -0,0 +1,59 @@ +"""Data-flow snapshot adapters for saved US stage manifests.""" + +from __future__ import annotations + +from typing import Any, cast + +from 
microplex_us.pipelines.stage_contracts import StageResumeMode +from microplex_us.pipelines.stage_manifest_types import ( + USDataFlowStageSummary, + USStageManifest, + USStageMetric, + USStageStatus, +) + + +def stage_summary_for_data_flow_snapshot( + stage_manifest: USStageManifest | dict[str, Any], +) -> list[USDataFlowStageSummary]: + """Return site-facing stage summaries from a canonical stage manifest.""" + + summaries: list[USDataFlowStageSummary] = [] + for stage in stage_manifest.get("stages", ()): + if not isinstance(stage, dict): + continue + resume = stage.get("resume", {}) + summaries.append( + { + "id": str(stage.get("id", "")), + "step": str(stage.get("step", "")), + "title": str(stage.get("title", "")), + "summary": str(stage.get("purpose", "")), + "status": cast(USStageStatus, stage.get("status", "missing")), + "metrics": cast(list[USStageMetric], list(stage.get("metrics", ()))), + "outputs": _stage_output_paths_for_data_flow(stage), + "resumeMode": cast( + StageResumeMode, + resume.get("mode", "none") if isinstance(resume, dict) else "none", + ), + } + ) + return summaries + + +def _stage_output_paths_for_data_flow(stage: dict[str, Any]) -> list[str]: + """Return artifact paths that a saved run actually referenced or produced.""" + + outputs: list[str] = [] + for artifact in stage.get("artifacts", ()): + if not isinstance(artifact, dict): + continue + path = artifact.get("path") + if not path: + continue + if bool(artifact.get("exists")) or bool(artifact.get("referenced")): + outputs.append(str(path)) + return outputs + + +__all__ = ["stage_summary_for_data_flow_snapshot"] diff --git a/src/microplex_us/pipelines/stage_manifest.py b/src/microplex_us/pipelines/stage_manifest.py index 5d5b043..a7d4c4f 100644 --- a/src/microplex_us/pipelines/stage_manifest.py +++ b/src/microplex_us/pipelines/stage_manifest.py @@ -1,740 +1,47 @@ -"""Stage manifest and reusable stage artifact helpers for US builds.""" +"""Compatibility facade for US saved-run stage 
manifest helpers.""" from __future__ import annotations -import json -from collections.abc import Iterable, Mapping -from pathlib import Path -from typing import Any, Literal, TypedDict, cast - -from microplex_us.pipelines.stage_contracts import ( - US_STAGE_CONTRACT_VERSION, - StageResumeMode, - USPipelineStageContract, - USStageArtifactContract, - default_us_pipeline_stage_contracts, +from microplex_us.pipelines.stage_data_flow import stage_summary_for_data_flow_snapshot +from microplex_us.pipelines.stage_manifest_builder import ( + build_us_stage_manifest, + resolve_us_stage_artifact_path, ) -from microplex_us.policyengine.us import ( - PolicyEngineUSEntityTableBundle, - load_us_pipeline_checkpoint, - save_us_pipeline_checkpoint, +from microplex_us.pipelines.stage_manifest_io import ( + load_us_stage_manifest, + write_us_stage_manifest, +) +from microplex_us.pipelines.stage_manifest_types import ( + SUPPORTED_US_STAGE_MANIFEST_SCHEMA_VERSIONS, + US_POLICYENGINE_ENTITY_STAGE_ID, + US_STAGE_ARTIFACT_ROOT, + US_STAGE_MANIFEST_SCHEMA_VERSION, + US_VALIDATION_STAGE_ID, + USDataFlowStageSummary, + USStageArtifactRecord, + USStageManifest, + USStageMetric, + USStageMetricValue, + USStageRecord, + USStageResourceRecord, + USStageResumeRecord, + USStageStatus, + USStageValidationRecord, + USStageValidationStatus, + USValidationEvidenceManifest, + USValidationEvidenceRecord, +) +from microplex_us.pipelines.stage_policyengine_artifacts import ( + load_us_policyengine_entity_stage_artifact, + write_us_policyengine_entity_stage_artifact, +) +from microplex_us.pipelines.stage_validation_evidence import ( + build_us_validation_evidence_manifest, + write_us_validation_evidence_manifest, ) - -US_STAGE_MANIFEST_SCHEMA_VERSION = 1 -US_STAGE_ARTIFACT_ROOT = "stage_artifacts" -US_POLICYENGINE_ENTITY_STAGE_ID = "06_policyengine_entities" -US_VALIDATION_STAGE_ID = "09_validation_benchmarking" - - -USStageMetricValue = str | int | float | bool | None - -USStageStatus = Literal[ - 
"ready", - "metadata_only", - "deferred", - "incomplete", - "missing", -] - -USStageValidationStatus = Literal["planned", "manual", "implemented"] - - -class USStageMetric(TypedDict): - """One compact metric shown for a saved stage.""" - - label: str - value: USStageMetricValue - - -class USStageArtifactRecord(TypedDict): - """Saved-run view of one stage artifact contract.""" - - key: str - description: str - path_hint: str | None - required: bool - resume_role: str | None - path: str | None - exists: bool - referenced: bool - - -class USStageResumeRecord(TypedDict): - """Saved-run resume metadata for one stage.""" - - mode: StageResumeMode - notes: str - - -class USStageValidationRecord(TypedDict): - """Saved-run view of one planned or implemented validation.""" - - key: str - description: str - status: USStageValidationStatus - - -class USStageRecord(TypedDict): - """One stage entry in a US stage manifest.""" - - id: str - step: str - title: str - purpose: str - status: USStageStatus - consumes: list[str] - produces: list[str] - artifacts: list[USStageArtifactRecord] - diagnostics: list[str] - validations: list[USStageValidationRecord] - resume: USStageResumeRecord - metrics: list[USStageMetric] - - -class USStageManifest(TypedDict): - """Canonical saved-run stage manifest.""" - - schemaVersion: int - contractVersion: str - generatedAt: str | None - pipeline: str - artifactRoot: str - manifest: str - stages: list[USStageRecord] - - -class USDataFlowStageSummary(TypedDict): - """Stage summary embedded in the site-facing data-flow snapshot.""" - - id: str - step: str - title: str - summary: str - status: USStageStatus - metrics: list[USStageMetric] - outputs: list[str] - resumeMode: StageResumeMode - - -class USValidationEvidenceRecord(TypedDict): - """One validation or benchmarking evidence sidecar.""" - - key: str - path: str - exists: bool - - -class USValidationEvidenceManifest(TypedDict): - """Stage 9 evidence index.""" - - formatVersion: int - stageId: str - 
evidence: list[USValidationEvidenceRecord] - summaries: dict[str, Any] - - -def write_us_stage_manifest( - artifact_dir: str | Path, - output_path: str | Path, - *, - manifest_payload: dict[str, Any], - assume_existing_artifact_keys: Iterable[str] = (), -) -> Path: - """Write the canonical stage manifest for a saved US artifact bundle.""" - - destination = Path(output_path) - destination.parent.mkdir(parents=True, exist_ok=True) - _write_json_atomically( - destination, - build_us_stage_manifest( - artifact_dir, - manifest_payload=manifest_payload, - assume_existing_artifact_keys=( - *tuple(assume_existing_artifact_keys), - "stage_manifest", - ), - ), - ) - return destination - - -def load_us_stage_manifest(path: str | Path) -> USStageManifest: - """Load a saved stage manifest and validate its schema version.""" - - manifest_path = Path(path) - payload = json.loads(manifest_path.read_text()) - if payload.get("schemaVersion") != US_STAGE_MANIFEST_SCHEMA_VERSION: - raise RuntimeError( - "Unsupported US stage manifest schema: " - f"{payload.get('schemaVersion')!r}" - ) - return cast(USStageManifest, payload) - - -def build_us_stage_manifest( - artifact_dir: str | Path, - *, - manifest_payload: dict[str, Any], - assume_existing_artifact_keys: Iterable[str] = (), -) -> USStageManifest: - """Build the canonical stage manifest from a saved artifact manifest.""" - - artifact_root = Path(artifact_dir) - manifest = dict(manifest_payload) - artifact_map = dict(manifest.get("artifacts", {})) - assumed_existing = set(assume_existing_artifact_keys) - stages = [ - _stage_record( - contract, - artifact_root=artifact_root, - manifest=manifest, - assume_existing_artifact_keys=assumed_existing, - ) - for contract in default_us_pipeline_stage_contracts() - ] - return { - "schemaVersion": US_STAGE_MANIFEST_SCHEMA_VERSION, - "contractVersion": US_STAGE_CONTRACT_VERSION, - "generatedAt": _optional_str(manifest.get("created_at")), - "pipeline": "us_microplex", - "artifactRoot": ".", - 
"manifest": str(artifact_map.get("manifest", "manifest.json")), - "stages": stages, - } - - -def stage_summary_for_data_flow_snapshot( - stage_manifest: USStageManifest | dict[str, Any], -) -> list[USDataFlowStageSummary]: - """Return site-facing stage summaries from a canonical stage manifest.""" - - summaries: list[USDataFlowStageSummary] = [] - for stage in stage_manifest.get("stages", ()): - if not isinstance(stage, dict): - continue - resume = stage.get("resume", {}) - summaries.append( - { - "id": str(stage.get("id", "")), - "step": str(stage.get("step", "")), - "title": str(stage.get("title", "")), - "summary": str(stage.get("purpose", "")), - "status": cast(USStageStatus, stage.get("status", "missing")), - "metrics": cast(list[USStageMetric], list(stage.get("metrics", ()))), - "outputs": _stage_output_paths_for_data_flow(stage), - "resumeMode": cast( - StageResumeMode, - resume.get("mode", "none") if isinstance(resume, dict) else "none", - ), - } - ) - return summaries - - -def _stage_output_paths_for_data_flow(stage: dict[str, Any]) -> list[str]: - """Return artifact paths that a saved run actually referenced or produced.""" - - outputs: list[str] = [] - for artifact in stage.get("artifacts", ()): - if not isinstance(artifact, dict): - continue - path = artifact.get("path") - if not path: - continue - if bool(artifact.get("exists")) or bool(artifact.get("referenced")): - outputs.append(str(path)) - return outputs - - -def write_us_policyengine_entity_stage_artifact( - bundle: PolicyEngineUSEntityTableBundle, - artifact_root: str | Path, -) -> Path: - """Persist a Stage 6 PE entity-table checkpoint under a saved-run root.""" - - stage_dir = save_us_pipeline_checkpoint( - bundle, - Path(artifact_root) / US_STAGE_ARTIFACT_ROOT / US_POLICYENGINE_ENTITY_STAGE_ID, - stage="post_microsim", - ) - metadata_path = stage_dir / "metadata.json" - metadata = json.loads(metadata_path.read_text()) - metadata["stageId"] = US_POLICYENGINE_ENTITY_STAGE_ID - 
_write_json_atomically(metadata_path, metadata) - return metadata_path - - -def load_us_policyengine_entity_stage_artifact( - path: str | Path, -) -> tuple[PolicyEngineUSEntityTableBundle, dict[str, Any]]: - """Load a Stage 6 PE entity-table bundle artifact.""" - - input_path = Path(path) - checkpoint_dir = input_path if input_path.is_dir() else input_path.parent - bundle, metadata = load_us_pipeline_checkpoint( - checkpoint_dir, - expected_stage="post_microsim", - ) - return bundle, metadata - - -def build_us_validation_evidence_manifest( - artifact_dir: str | Path, - *, - manifest_payload: dict[str, Any], -) -> USValidationEvidenceManifest: - """Build a compact Stage 9 evidence index from a saved artifact manifest.""" - - artifact_root = Path(artifact_dir) - artifacts = dict(manifest_payload.get("artifacts", {})) - evidence_keys = ( - "policyengine_harness", - "policyengine_native_scores", - "policyengine_native_audit", - "imputation_ablation", - "child_tax_unit_agi_drift", - ) - evidence: list[USValidationEvidenceRecord] = [] - for key in evidence_keys: - filename = artifacts.get(key) - if not filename: - continue - path_text = str(filename) - path = Path(path_text) - if not path.is_absolute(): - path = artifact_root / path - evidence.append( - { - "key": key, - "path": path_text, - "exists": path.exists(), - } - ) - return { - "formatVersion": 1, - "stageId": US_VALIDATION_STAGE_ID, - "evidence": evidence, - "summaries": { - key: manifest_payload[key] - for key in ( - "policyengine_harness", - "policyengine_native_scores", - "policyengine_native_audit", - "imputation_ablation", - ) - if isinstance(manifest_payload.get(key), dict) - }, - } - - -def write_us_validation_evidence_manifest( - artifact_dir: str | Path, - output_path: str | Path, - *, - manifest_payload: dict[str, Any], -) -> Path: - """Write a Stage 9 evidence manifest for validation/benchmark sidecars.""" - - destination = Path(output_path) - destination.parent.mkdir(parents=True, exist_ok=True) - 
_write_json_atomically( - destination, - build_us_validation_evidence_manifest( - artifact_dir, - manifest_payload=manifest_payload, - ), - ) - return destination - - -def resolve_us_stage_artifact_path( - artifact_dir: str | Path, - stage_manifest: dict[str, Any], - stage_id: str, - artifact_key: str, -) -> Path: - """Resolve one artifact path from a stage manifest.""" - - for stage in stage_manifest.get("stages", ()): - if not isinstance(stage, dict) or stage.get("id") != stage_id: - continue - for artifact in stage.get("artifacts", ()): - if ( - isinstance(artifact, dict) - and artifact.get("key") == artifact_key - and artifact.get("path") - ): - path = Path(str(artifact["path"])) - if not path.is_absolute(): - path = Path(artifact_dir) / path - return path - raise KeyError(f"Stage artifact not found: {stage_id}.{artifact_key}") - - -def _stage_record( - contract: USPipelineStageContract, - *, - artifact_root: Path, - manifest: dict[str, Any], - assume_existing_artifact_keys: set[str], -) -> USStageRecord: - artifacts = [ - _artifact_record( - artifact, - artifact_root=artifact_root, - manifest=manifest, - assume_existing_artifact_keys=assume_existing_artifact_keys, - ) - for artifact in contract.artifacts - ] - return { - "id": contract.id, - "step": contract.step, - "title": contract.title, - "purpose": contract.purpose, - "status": _stage_status( - contract.id, - artifact_root=artifact_root, - manifest=manifest, - artifacts=artifacts, - assume_existing_artifact_keys=assume_existing_artifact_keys, - ), - "consumes": list(contract.consumes), - "produces": list(contract.produces), - "artifacts": artifacts, - "diagnostics": list(contract.diagnostics), - "validations": cast( - list[USStageValidationRecord], - [validation.to_dict() for validation in contract.validations], - ), - "resume": { - "mode": contract.resume_mode, - "notes": contract.resume_notes, - }, - "metrics": _stage_metrics(contract.id, manifest=manifest), - } - - -def _artifact_record( - artifact: 
USStageArtifactContract, - *, - artifact_root: Path, - manifest: dict[str, Any], - assume_existing_artifact_keys: set[str], -) -> USStageArtifactRecord: - artifacts = dict(manifest.get("artifacts", {})) - manifest_path = artifacts.get(artifact.key) - path = str(manifest_path) if manifest_path else artifact.path_hint - exists = False - if path: - resolved = Path(str(path)) - if not resolved.is_absolute(): - resolved = artifact_root / resolved - exists = resolved.exists() or artifact.key in assume_existing_artifact_keys - return { - **artifact.to_dict(), - "path": path, - "exists": exists, - "referenced": manifest_path is not None, - } - - -def _stage_status( - stage_id: str, - *, - artifact_root: Path, - manifest: dict[str, Any], - artifacts: list[USStageArtifactRecord], - assume_existing_artifact_keys: set[str], -) -> USStageStatus: - artifact_map = dict(manifest.get("artifacts", {})) - synthesis = dict(manifest.get("synthesis", {})) - calibration = dict(manifest.get("calibration", {})) - rows = dict(manifest.get("rows", {})) - if stage_id == "01_run_profile": - if _referenced_artifact_missing(artifacts): - return "incomplete" - if _artifact_exists(artifacts, "manifest"): - return "ready" - return "metadata_only" if manifest.get("config") else "missing" - if stage_id == "02_source_loading": - return "metadata_only" if synthesis.get("source_names") else "missing" - if stage_id == "03_source_planning": - if _referenced_artifact_missing(artifacts): - return "incomplete" - if _artifact_exists(artifacts, "source_plan"): - return "ready" - return "metadata_only" if synthesis.get("scaffold_source") else "missing" - if stage_id == "04_seed_scaffold": - if _referenced_artifact_missing(artifacts, required_only=True): - return "incomplete" - if _required_artifacts_exist(artifacts): - return "ready" - return ( - "metadata_only" - if rows.get("seed") or synthesis.get("scaffold_source") - else "missing" - ) - if stage_id == "05_donor_integration_synthesis": - if 
_referenced_artifact_missing(artifacts, required_only=True): - return "incomplete" - if _required_artifacts_exist(artifacts): - return "ready" - return ( - "metadata_only" if rows.get("seed") or rows.get("synthetic") else "missing" - ) - if stage_id == "06_policyengine_entities": - if _referenced_artifact_missing(artifacts): - return "incomplete" - if _artifact_exists(artifacts, "policyengine_entity_tables"): - return "ready" - if _manifest_artifact_exists( - manifest, - artifact_root, - "policyengine_dataset", - assume_existing_artifact_keys=assume_existing_artifact_keys, - ): - return "metadata_only" - return "missing" - if stage_id == "07_calibration": - if _referenced_artifact_missing(artifacts, required_only=True): - return "incomplete" - if calibration and _required_artifacts_exist(artifacts): - return "ready" - return "metadata_only" if calibration and rows.get("calibrated") else "missing" - if stage_id == "08_dataset_assembly": - if _manifest_artifact_missing( - manifest, - artifact_root, - ("policyengine_dataset", "stage_manifest", "data_flow_snapshot"), - assume_existing_artifact_keys=assume_existing_artifact_keys, - ): - return "incomplete" - if _manifest_artifact_exists( - manifest, - artifact_root, - "policyengine_dataset", - assume_existing_artifact_keys=assume_existing_artifact_keys, - ): - return "ready" - return "metadata_only" if artifact_map.get("stage_manifest") else "missing" - if stage_id == "09_validation_benchmarking": - evidence_keys = ( - "policyengine_harness", - "policyengine_native_scores", - "policyengine_native_audit", - "imputation_ablation", - ) - evidence_index_keys = ("validation_evidence",) - if _manifest_artifact_missing( - manifest, - artifact_root, - (*evidence_keys, *evidence_index_keys), - assume_existing_artifact_keys=assume_existing_artifact_keys, - ): - return "incomplete" - has_evidence = any( - _manifest_artifact_exists( - manifest, - artifact_root, - key, - assume_existing_artifact_keys=assume_existing_artifact_keys, - 
) - for key in evidence_keys - ) - if not has_evidence: - has_evidence = _validation_evidence_index_has_existing_evidence( - manifest, - artifact_root, - assume_existing_artifact_keys=assume_existing_artifact_keys, - ) - if has_evidence: - return "ready" - if _manifest_artifact_exists( - manifest, - artifact_root, - "policyengine_dataset", - assume_existing_artifact_keys=assume_existing_artifact_keys, - ): - return "deferred" - return "missing" - if any(artifact.get("exists") for artifact in artifacts): - return "ready" - return "missing" - - -def _required_artifacts_exist(artifacts: list[USStageArtifactRecord]) -> bool: - required = [artifact for artifact in artifacts if bool(artifact.get("required"))] - return bool(required) and all(bool(artifact.get("exists")) for artifact in required) - - -def _artifact_exists(artifacts: list[USStageArtifactRecord], key: str) -> bool: - return any( - artifact.get("key") == key and bool(artifact.get("exists")) - for artifact in artifacts - ) - - -def _referenced_artifact_missing( - artifacts: list[USStageArtifactRecord], - *, - required_only: bool = False, -) -> bool: - return any( - bool(artifact.get("referenced")) - and not bool(artifact.get("exists")) - and (not required_only or bool(artifact.get("required"))) - for artifact in artifacts - ) - - -def _manifest_artifact_exists( - manifest: dict[str, Any], - artifact_root: Path, - artifact_key: str, - *, - assume_existing_artifact_keys: set[str], -) -> bool: - path = _manifest_artifact_path(manifest, artifact_root, artifact_key) - if path is None: - return False - if artifact_key in assume_existing_artifact_keys: - return True - return path.exists() - - -def _manifest_artifact_missing( - manifest: dict[str, Any], - artifact_root: Path, - artifact_keys: tuple[str, ...], - *, - assume_existing_artifact_keys: set[str], -) -> bool: - artifacts = dict(manifest.get("artifacts", {})) - return any( - bool(artifacts.get(key)) - and not _manifest_artifact_exists( - manifest, - 
artifact_root, - key, - assume_existing_artifact_keys=assume_existing_artifact_keys, - ) - for key in artifact_keys - ) - - -def _validation_evidence_index_has_existing_evidence( - manifest: dict[str, Any], - artifact_root: Path, - *, - assume_existing_artifact_keys: set[str], -) -> bool: - path = _manifest_artifact_path(manifest, artifact_root, "validation_evidence") - if path is None: - return False - if "validation_evidence" in assume_existing_artifact_keys and not path.exists(): - return False - if not path.exists(): - return False - try: - payload = json.loads(path.read_text()) - except (OSError, json.JSONDecodeError): - return False - evidence = payload.get("evidence") - if not isinstance(evidence, list): - return False - for record in evidence: - if not isinstance(record, dict) or not record.get("path"): - continue - evidence_path = Path(str(record["path"])) - if not evidence_path.is_absolute(): - evidence_path = artifact_root / evidence_path - if evidence_path.exists(): - return True - return False - - -def _manifest_artifact_path( - manifest: dict[str, Any], - artifact_root: Path, - artifact_key: str, -) -> Path | None: - artifacts = dict(manifest.get("artifacts", {})) - filename = artifacts.get(artifact_key) - if not filename: - return None - path = Path(str(filename)) - if not path.is_absolute(): - path = artifact_root / path - return path - - -def _optional_str(value: Any) -> str | None: - if value is None: - return None - return str(value) - - -def _stage_metrics(stage_id: str, *, manifest: dict[str, Any]) -> list[USStageMetric]: - synthesis = dict(manifest.get("synthesis", {})) - calibration = dict(manifest.get("calibration", {})) - artifacts = dict(manifest.get("artifacts", {})) - harness = dict(manifest.get("policyengine_harness", {})) - native_scores = dict(manifest.get("policyengine_native_scores", {})) - rows = dict(manifest.get("rows", {})) - config = dict(manifest.get("config", {})) - if stage_id == "01_run_profile": - return [ - {"label": 
"Target period", "value": config.get("policyengine_target_period")}, - {"label": "Backend", "value": config.get("calibration_backend")}, - ] - if stage_id == "02_source_loading": - return [ - {"label": "Sources", "value": len(synthesis.get("source_names", ()))}, - ] - if stage_id == "03_source_planning": - return [{"label": "Scaffold", "value": synthesis.get("scaffold_source")}] - if stage_id == "04_seed_scaffold": - return [ - {"label": "Seed rows", "value": rows.get("seed")}, - {"label": "Scaffold", "value": synthesis.get("scaffold_source")}, - ] - if stage_id == "05_donor_integration_synthesis": - return [ - {"label": "Seed rows", "value": rows.get("seed")}, - { - "label": "Integrated vars", - "value": len(synthesis.get("donor_integrated_variables", ())), - }, - {"label": "Backend", "value": synthesis.get("backend")}, - {"label": "Synthetic rows", "value": rows.get("synthetic")}, - ] - if stage_id == "06_policyengine_entities": - return [{"label": "Entity bundle", "value": artifacts.get("policyengine_entity_tables")}] - if stage_id == "07_calibration": - return [ - {"label": "Backend", "value": calibration.get("backend")}, - {"label": "Supported", "value": calibration.get("n_supported_targets")}, - {"label": "Converged", "value": calibration.get("converged")}, - ] - if stage_id == "08_dataset_assembly": - return [{"label": "Dataset", "value": artifacts.get("policyengine_dataset")}] - if stage_id == "09_validation_benchmarking": - return [ - {"label": "Harness delta", "value": harness.get("mean_abs_relative_error_delta")}, - {"label": "Native delta", "value": native_scores.get("enhanced_cps_native_loss_delta")}, - {"label": "Win rate", "value": harness.get("target_win_rate")}, - ] - return [] - - -def _write_json_atomically(path: Path, payload: Mapping[str, Any]) -> None: - temporary = path.with_suffix(path.suffix + ".tmp") - temporary.write_text(json.dumps(payload, indent=2, sort_keys=True)) - temporary.replace(path) - __all__ = [ + 
"SUPPORTED_US_STAGE_MANIFEST_SCHEMA_VERSIONS", "USDataFlowStageSummary", "US_POLICYENGINE_ENTITY_STAGE_ID", "US_STAGE_ARTIFACT_ROOT", @@ -744,6 +51,7 @@ def _write_json_atomically(path: Path, payload: Mapping[str, Any]) -> None: "USStageMetric", "USStageMetricValue", "USStageRecord", + "USStageResourceRecord", "USStageResumeRecord", "USStageStatus", "USStageValidationRecord", diff --git a/src/microplex_us/pipelines/stage_manifest_builder.py b/src/microplex_us/pipelines/stage_manifest_builder.py new file mode 100644 index 0000000..5befd74 --- /dev/null +++ b/src/microplex_us/pipelines/stage_manifest_builder.py @@ -0,0 +1,172 @@ +"""Build aggregate saved-run stage manifests for US pipeline artifacts.""" + +from __future__ import annotations + +from collections.abc import Iterable +from pathlib import Path +from typing import Any, cast + +from microplex_us.pipelines.stage_contracts import ( + US_STAGE_CONTRACT_VERSION, + USPipelineStageContract, + USStageArtifactContract, + USStageResourceContract, + default_us_pipeline_stage_contracts, +) +from microplex_us.pipelines.stage_manifest_types import ( + US_STAGE_MANIFEST_SCHEMA_VERSION, + USStageArtifactRecord, + USStageManifest, + USStageRecord, + USStageResourceRecord, + USStageValidationRecord, +) +from microplex_us.pipelines.stage_metrics import stage_metrics +from microplex_us.pipelines.stage_status import stage_status + + +def build_us_stage_manifest( + artifact_dir: str | Path, + *, + manifest_payload: dict[str, Any], + assume_existing_artifact_keys: Iterable[str] = (), +) -> USStageManifest: + """Build the canonical stage manifest from a saved artifact manifest.""" + + artifact_root = Path(artifact_dir) + manifest = dict(manifest_payload) + artifact_map = dict(manifest.get("artifacts", {})) + assumed_existing = set(assume_existing_artifact_keys) + stages = [ + _stage_record( + contract, + artifact_root=artifact_root, + manifest=manifest, + assume_existing_artifact_keys=assumed_existing, + ) + for contract in 
default_us_pipeline_stage_contracts() + ] + return { + "schemaVersion": US_STAGE_MANIFEST_SCHEMA_VERSION, + "contractVersion": US_STAGE_CONTRACT_VERSION, + "generatedAt": _optional_str(manifest.get("created_at")), + "pipeline": "us_microplex", + "artifactRoot": ".", + "manifest": str(artifact_map.get("manifest", "manifest.json")), + "stages": stages, + } + + +def resolve_us_stage_artifact_path( + artifact_dir: str | Path, + stage_manifest: dict[str, Any], + stage_id: str, + artifact_key: str, +) -> Path: + """Resolve one artifact path from a stage manifest.""" + + for stage in stage_manifest.get("stages", ()): + if not isinstance(stage, dict) or stage.get("id") != stage_id: + continue + for artifact in stage.get("artifacts", ()): + if ( + isinstance(artifact, dict) + and artifact.get("key") == artifact_key + and artifact.get("path") + ): + path = Path(str(artifact["path"])) + if not path.is_absolute(): + path = Path(artifact_dir) / path + return path + raise KeyError(f"Stage artifact not found: {stage_id}.{artifact_key}") + + +def _stage_record( + contract: USPipelineStageContract, + *, + artifact_root: Path, + manifest: dict[str, Any], + assume_existing_artifact_keys: set[str], +) -> USStageRecord: + artifacts = [ + _artifact_record( + artifact, + artifact_root=artifact_root, + manifest=manifest, + assume_existing_artifact_keys=assume_existing_artifact_keys, + ) + for artifact in contract.artifacts + ] + return { + "id": contract.id, + "step": contract.step, + "title": contract.title, + "purpose": contract.purpose, + "status": stage_status( + contract.id, + artifact_root=artifact_root, + manifest=manifest, + artifacts=artifacts, + assume_existing_artifact_keys=assume_existing_artifact_keys, + ), + "consumes": list(contract.consumes), + "produces": list(contract.produces), + "inputs": _resource_records(contract.inputs), + "outputs": _resource_records(contract.outputs), + "artifacts": artifacts, + "diagnostics": list(contract.diagnostics), + "validations": cast( + 
list[USStageValidationRecord], + [validation.to_dict() for validation in contract.validations], + ), + "resume": { + "mode": contract.resume_mode, + "notes": contract.resume_notes, + }, + "metrics": stage_metrics(contract.id, manifest=manifest), + } + + +def _artifact_record( + artifact: USStageArtifactContract, + *, + artifact_root: Path, + manifest: dict[str, Any], + assume_existing_artifact_keys: set[str], +) -> USStageArtifactRecord: + artifacts = dict(manifest.get("artifacts", {})) + manifest_path = artifacts.get(artifact.key) + path = str(manifest_path) if manifest_path else artifact.path_hint + exists = False + if path: + resolved = Path(str(path)) + if not resolved.is_absolute(): + resolved = artifact_root / resolved + exists = resolved.exists() or artifact.key in assume_existing_artifact_keys + return { + **artifact.to_dict(), + "path": path, + "exists": exists, + "referenced": manifest_path is not None, + } + + +def _resource_records( + resources: tuple[USStageResourceContract, ...], +) -> list[USStageResourceRecord]: + return cast( + list[USStageResourceRecord], + [resource.to_dict() for resource in resources], + ) + + +def _optional_str(value: Any) -> str | None: + if value is None: + return None + return str(value) + + +__all__ = [ + "build_us_stage_manifest", + "resolve_us_stage_artifact_path", +] diff --git a/src/microplex_us/pipelines/stage_manifest_io.py b/src/microplex_us/pipelines/stage_manifest_io.py new file mode 100644 index 0000000..0fcf891 --- /dev/null +++ b/src/microplex_us/pipelines/stage_manifest_io.py @@ -0,0 +1,66 @@ +"""I/O helpers for saved-run US stage manifests.""" + +from __future__ import annotations + +import json +from collections.abc import Iterable, Mapping +from pathlib import Path +from typing import Any, cast + +from microplex_us.pipelines.stage_manifest_builder import build_us_stage_manifest +from microplex_us.pipelines.stage_manifest_types import ( + SUPPORTED_US_STAGE_MANIFEST_SCHEMA_VERSIONS, + USStageManifest, +) + + 
+def write_us_stage_manifest( + artifact_dir: str | Path, + output_path: str | Path, + *, + manifest_payload: dict[str, Any], + assume_existing_artifact_keys: Iterable[str] = (), +) -> Path: + """Write the canonical stage manifest for a saved US artifact bundle.""" + + destination = Path(output_path) + destination.parent.mkdir(parents=True, exist_ok=True) + write_json_atomically( + destination, + build_us_stage_manifest( + artifact_dir, + manifest_payload=manifest_payload, + assume_existing_artifact_keys=( + *tuple(assume_existing_artifact_keys), + "stage_manifest", + ), + ), + ) + return destination + + +def load_us_stage_manifest(path: str | Path) -> USStageManifest: + """Load a saved stage manifest and validate its schema version.""" + + manifest_path = Path(path) + payload = json.loads(manifest_path.read_text()) + if payload.get("schemaVersion") not in SUPPORTED_US_STAGE_MANIFEST_SCHEMA_VERSIONS: + raise RuntimeError( + f"Unsupported US stage manifest schema: {payload.get('schemaVersion')!r}" + ) + return cast(USStageManifest, payload) + + +def write_json_atomically(path: Path, payload: Mapping[str, Any]) -> None: + """Write JSON atomically through a sibling temporary file.""" + + temporary = path.with_suffix(path.suffix + ".tmp") + temporary.write_text(json.dumps(payload, indent=2, sort_keys=True)) + temporary.replace(path) + + +__all__ = [ + "load_us_stage_manifest", + "write_json_atomically", + "write_us_stage_manifest", +] diff --git a/src/microplex_us/pipelines/stage_manifest_types.py b/src/microplex_us/pipelines/stage_manifest_types.py new file mode 100644 index 0000000..7b5a417 --- /dev/null +++ b/src/microplex_us/pipelines/stage_manifest_types.py @@ -0,0 +1,163 @@ +"""Shared saved-run stage manifest schemas for US pipeline artifacts.""" + +from __future__ import annotations + +from typing import Any, Literal, TypedDict + +from microplex_us.pipelines.stage_contracts import ( + StageArtifactFormat, + StageArtifactHashMode, + StageResumeMode, +) + 
+US_STAGE_MANIFEST_SCHEMA_VERSION = 2 +SUPPORTED_US_STAGE_MANIFEST_SCHEMA_VERSIONS = frozenset({1, 2}) +US_STAGE_ARTIFACT_ROOT = "stage_artifacts" +US_POLICYENGINE_ENTITY_STAGE_ID = "06_policyengine_entities" +US_VALIDATION_STAGE_ID = "09_validation_benchmarking" + + +USStageMetricValue = str | int | float | bool | None + +USStageStatus = Literal[ + "ready", + "metadata_only", + "deferred", + "incomplete", + "missing", +] + +USStageValidationStatus = Literal["planned", "manual", "implemented"] + + +class USStageMetric(TypedDict): + """One compact metric shown for a saved stage.""" + + label: str + value: USStageMetricValue + + +class USStageArtifactRecord(TypedDict): + """Saved-run view of one stage artifact contract.""" + + key: str + description: str + path_hint: str | None + required: bool + resume_role: str | None + format: StageArtifactFormat + hash_mode: StageArtifactHashMode + path: str | None + exists: bool + referenced: bool + + +class USStageResumeRecord(TypedDict): + """Saved-run resume metadata for one stage.""" + + mode: StageResumeMode + notes: str + + +class USStageValidationRecord(TypedDict): + """Saved-run view of one planned or implemented validation.""" + + key: str + description: str + status: USStageValidationStatus + + +class USStageResourceRecord(TypedDict): + """Saved-run view of one structured stage input or output.""" + + key: str + description: str + kind: str + required: bool + stage_id: str | None + artifact_key: str | None + config_key: str | None + manifest_key: str | None + + +class USStageRecord(TypedDict): + """One stage entry in a US stage manifest.""" + + id: str + step: str + title: str + purpose: str + status: USStageStatus + consumes: list[str] + produces: list[str] + inputs: list[USStageResourceRecord] + outputs: list[USStageResourceRecord] + artifacts: list[USStageArtifactRecord] + diagnostics: list[str] + validations: list[USStageValidationRecord] + resume: USStageResumeRecord + metrics: list[USStageMetric] + + +class 
USStageManifest(TypedDict): + """Canonical saved-run stage manifest.""" + + schemaVersion: int + contractVersion: str + generatedAt: str | None + pipeline: str + artifactRoot: str + manifest: str + stages: list[USStageRecord] + + +class USDataFlowStageSummary(TypedDict): + """Stage summary embedded in the site-facing data-flow snapshot.""" + + id: str + step: str + title: str + summary: str + status: USStageStatus + metrics: list[USStageMetric] + outputs: list[str] + resumeMode: StageResumeMode + + +class USValidationEvidenceRecord(TypedDict): + """One validation or benchmarking evidence sidecar.""" + + key: str + path: str + exists: bool + + +class USValidationEvidenceManifest(TypedDict): + """Stage 9 evidence index.""" + + formatVersion: int + stageId: str + evidence: list[USValidationEvidenceRecord] + summaries: dict[str, Any] + + +__all__ = [ + "SUPPORTED_US_STAGE_MANIFEST_SCHEMA_VERSIONS", + "USDataFlowStageSummary", + "US_POLICYENGINE_ENTITY_STAGE_ID", + "US_STAGE_ARTIFACT_ROOT", + "US_STAGE_MANIFEST_SCHEMA_VERSION", + "US_VALIDATION_STAGE_ID", + "USStageArtifactRecord", + "USStageManifest", + "USStageMetric", + "USStageMetricValue", + "USStageRecord", + "USStageResourceRecord", + "USStageResumeRecord", + "USStageStatus", + "USStageValidationRecord", + "USStageValidationStatus", + "USValidationEvidenceManifest", + "USValidationEvidenceRecord", +] diff --git a/src/microplex_us/pipelines/stage_metrics.py b/src/microplex_us/pipelines/stage_metrics.py new file mode 100644 index 0000000..f09c4e8 --- /dev/null +++ b/src/microplex_us/pipelines/stage_metrics.py @@ -0,0 +1,96 @@ +"""Display metrics for saved US pipeline stage manifests.""" + +from __future__ import annotations + +from typing import Any + +from microplex_us.pipelines.stage_manifest_types import USStageMetric + + +def stage_metrics(stage_id: str, *, manifest: dict[str, Any]) -> list[USStageMetric]: + """Return compact display metrics for one saved stage.""" + + synthesis = dict(manifest.get("synthesis", 
{})) + calibration = dict(manifest.get("calibration", {})) + artifacts = dict(manifest.get("artifacts", {})) + harness = dict(manifest.get("policyengine_harness", {})) + native_scores = dict(manifest.get("policyengine_native_scores", {})) + rows = dict(manifest.get("rows", {})) + config = dict(manifest.get("config", {})) + if stage_id == "01_run_profile": + return [ + { + "label": "Target period", + "value": config.get("policyengine_target_period"), + }, + {"label": "Backend", "value": config.get("calibration_backend")}, + ] + if stage_id == "02_source_loading": + return [ + {"label": "Sources", "value": len(synthesis.get("source_names", ()))}, + ] + if stage_id == "03_source_planning": + return [{"label": "Scaffold", "value": synthesis.get("scaffold_source")}] + if stage_id == "04_seed_scaffold": + return [ + {"label": "Seed rows", "value": rows.get("seed")}, + {"label": "Scaffold", "value": synthesis.get("scaffold_source")}, + ] + if stage_id == "05_donor_integration_synthesis": + return [ + {"label": "Seed rows", "value": rows.get("seed")}, + { + "label": "Integrated vars", + "value": len(synthesis.get("donor_integrated_variables", ())), + }, + {"label": "Backend", "value": synthesis.get("backend")}, + {"label": "Synthetic rows", "value": rows.get("synthetic")}, + ] + if stage_id == "06_policyengine_entities": + return [ + { + "label": "Entity bundle", + "value": artifacts.get("policyengine_entity_tables"), + } + ] + if stage_id == "07_calibration": + return [ + {"label": "Backend", "value": calibration.get("backend")}, + {"label": "Supported", "value": calibration.get("n_supported_targets")}, + {"label": "Converged", "value": calibration.get("converged")}, + ] + if stage_id == "08_dataset_assembly": + return [{"label": "Dataset", "value": artifacts.get("policyengine_dataset")}] + if stage_id == "09_validation_benchmarking": + imputation_ablation = dict(manifest.get("imputation_ablation", {})) + return [ + { + "label": "Capped full oracle loss", + "value": 
calibration.get("full_oracle_capped_mean_abs_relative_error"), + }, + { + "label": "Full oracle loss", + "value": calibration.get("full_oracle_mean_abs_relative_error"), + }, + { + "label": "Harness delta", + "value": harness.get("mean_abs_relative_error_delta"), + }, + { + "label": "Native delta", + "value": native_scores.get("enhanced_cps_native_loss_delta"), + }, + {"label": "Win rate", "value": harness.get("target_win_rate")}, + { + "label": "Imputation MAE", + "value": imputation_ablation.get("production_mean_weighted_mae"), + }, + { + "label": "Imputation F1", + "value": imputation_ablation.get("production_mean_support_f1"), + }, + ] + return [] + + +__all__ = ["stage_metrics"] diff --git a/src/microplex_us/pipelines/stage_policyengine_artifacts.py b/src/microplex_us/pipelines/stage_policyengine_artifacts.py new file mode 100644 index 0000000..5e6345c --- /dev/null +++ b/src/microplex_us/pipelines/stage_policyengine_artifacts.py @@ -0,0 +1,56 @@ +"""PolicyEngine entity stage artifact I/O for US saved runs.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +from microplex_us.pipelines.stage_manifest_io import write_json_atomically +from microplex_us.pipelines.stage_manifest_types import ( + US_POLICYENGINE_ENTITY_STAGE_ID, + US_STAGE_ARTIFACT_ROOT, +) +from microplex_us.policyengine.us import ( + PolicyEngineUSEntityTableBundle, + load_us_pipeline_checkpoint, + save_us_pipeline_checkpoint, +) + + +def write_us_policyengine_entity_stage_artifact( + bundle: PolicyEngineUSEntityTableBundle, + artifact_root: str | Path, +) -> Path: + """Persist a Stage 6 PE entity-table checkpoint under a saved-run root.""" + + stage_dir = save_us_pipeline_checkpoint( + bundle, + Path(artifact_root) / US_STAGE_ARTIFACT_ROOT / US_POLICYENGINE_ENTITY_STAGE_ID, + stage="post_microsim", + ) + metadata_path = stage_dir / "metadata.json" + metadata = json.loads(metadata_path.read_text()) + metadata["stageId"] = 
US_POLICYENGINE_ENTITY_STAGE_ID + write_json_atomically(metadata_path, metadata) + return metadata_path + + +def load_us_policyengine_entity_stage_artifact( + path: str | Path, +) -> tuple[PolicyEngineUSEntityTableBundle, dict[str, Any]]: + """Load a Stage 6 PE entity-table bundle artifact.""" + + input_path = Path(path) + checkpoint_dir = input_path if input_path.is_dir() else input_path.parent + bundle, metadata = load_us_pipeline_checkpoint( + checkpoint_dir, + expected_stage="post_microsim", + ) + return bundle, metadata + + +__all__ = [ + "load_us_policyengine_entity_stage_artifact", + "write_us_policyengine_entity_stage_artifact", +] diff --git a/src/microplex_us/pipelines/stage_readiness.py b/src/microplex_us/pipelines/stage_readiness.py new file mode 100644 index 0000000..f24f5ea --- /dev/null +++ b/src/microplex_us/pipelines/stage_readiness.py @@ -0,0 +1,457 @@ +"""Conditional-readiness reports for US Microplex saved runs.""" + +from __future__ import annotations + +import hashlib +import json +from collections.abc import Mapping +from pathlib import Path +from typing import Any, Literal, TypedDict, cast + +from microplex_us.pipelines.stage_artifacts import ( + USStageArtifactInventory, + USStageArtifactInventoryRecord, + build_us_stage_artifact_inventory, + load_us_stage_artifact_inventory, +) +from microplex_us.pipelines.stage_contracts import ( + US_STAGE_CONTRACT_VERSION, + config_keys_for_us_pipeline_stage, +) +from microplex_us.pipelines.stage_manifest import ( + USStageManifest, + USStageStatus, + build_us_stage_manifest, +) + +US_CONDITIONAL_READINESS_SCHEMA_VERSION = 1 +US_CONFIG_REUSE_IGNORED_KEYS = frozenset( + { + "pipeline_checkpoint_save_post_imputation_path", + "pipeline_checkpoint_save_post_microsim_path", + } +) + +USStageReadiness = Literal[ + "manual_replay", + "manual_resume", + "post_artifact_evidence", + "diagnostic_only", + "metadata_only", + "must_rerun", + "not_applicable", +] + +USStageCompatibility = Literal[ + "match", + 
"mismatch", + "missing_saved_config", + "not_evaluated", +] + + +class USConditionalReadinessStageRecord(TypedDict): + """Conditional-readiness view of one canonical stage.""" + + stageId: str + stageStep: str + stageTitle: str + status: USStageStatus + readiness: USStageReadiness + reason: str + compatibility: USStageCompatibility + reuseKey: str | None + savedConfigHash: str | None + requestedConfigHash: str | None + availableArtifacts: list[str] + missingArtifacts: list[str] + diagnosticArtifacts: list[str] + reloadableArtifacts: list[str] + + +class USConditionalReadinessReport(TypedDict): + """Saved-run conditional-readiness report.""" + + schemaVersion: int + contractVersion: str + generatedAt: str | None + pipeline: str + artifactRoot: str + manifest: str + artifactInventory: str | None + savedConfigHash: str | None + requestedConfigHash: str | None + stages: list[USConditionalReadinessStageRecord] + + +def build_us_stage_reuse_key( + stage_id: str, + manifest_payload: Mapping[str, Any], + artifact_inventory: USStageArtifactInventory | Mapping[str, Any], +) -> str | None: + """Return a deterministic reuse key for one stage, if any evidence exists.""" + + stage_artifacts = [ + artifact + for artifact in artifact_inventory.get("artifacts", ()) + if isinstance(artifact, dict) and artifact.get("stageId") == stage_id + ] + if not stage_artifacts: + return None + evidence = [ + { + "key": str(artifact.get("key")), + "path": artifact.get("path"), + "classification": artifact.get("classification"), + "hashStatus": artifact.get("hashStatus"), + "contentHash": artifact.get("contentHash"), + "sizeBytes": artifact.get("sizeBytes"), + "fileCount": artifact.get("fileCount"), + } + for artifact in stage_artifacts + if artifact.get("exists") or artifact.get("referenced") + ] + if not evidence: + return None + payload = { + "stageId": stage_id, + "configHash": _stage_config_hash(stage_id, manifest_payload.get("config")), + "artifacts": sorted(evidence, key=lambda item: 
item["key"]), + } + return _hash_json(payload) + + +def build_us_conditional_readiness_report( + artifact_dir: str | Path, + *, + manifest_payload: dict[str, Any] | None = None, + stage_manifest: USStageManifest | dict[str, Any] | None = None, + artifact_inventory: USStageArtifactInventory | dict[str, Any] | None = None, + requested_config: Mapping[str, Any] | None = None, +) -> USConditionalReadinessReport: + """Build a report describing which stage outputs could be reused manually.""" + + artifact_root = Path(artifact_dir) + manifest = ( + dict(manifest_payload) + if manifest_payload is not None + else json.loads((artifact_root / "manifest.json").read_text()) + ) + stages = ( + dict(stage_manifest) + if stage_manifest is not None + else build_us_stage_manifest(artifact_root, manifest_payload=manifest) + ) + inventory = ( + dict(artifact_inventory) + if artifact_inventory is not None + else _load_or_build_inventory(artifact_root, manifest_payload=manifest) + ) + saved_config_hash = _config_hash(manifest.get("config")) + requested_config_hash = ( + _config_hash(requested_config) if requested_config is not None else None + ) + return { + "schemaVersion": US_CONDITIONAL_READINESS_SCHEMA_VERSION, + "contractVersion": US_STAGE_CONTRACT_VERSION, + "generatedAt": _optional_str(manifest.get("created_at")), + "pipeline": "us_microplex", + "artifactRoot": ".", + "manifest": str( + dict(manifest.get("artifacts", {})).get("manifest", "manifest.json") + ), + "artifactInventory": _optional_str( + dict(manifest.get("artifacts", {})).get("artifact_inventory") + ), + "savedConfigHash": saved_config_hash, + "requestedConfigHash": requested_config_hash, + "stages": [ + _readiness_stage_record( + stage, + manifest=manifest, + inventory=inventory, + requested_config=requested_config, + ) + for stage in stages.get("stages", ()) + if isinstance(stage, dict) + ], + } + + +def write_us_conditional_readiness_report( + artifact_dir: str | Path, + output_path: str | Path, + *, + 
def load_us_conditional_readiness_report(
    path: str | Path,
) -> USConditionalReadinessReport:
    """Load a saved conditional-readiness report.

    Args:
        path: Location of the JSON report file.

    Returns:
        The parsed report payload.

    Raises:
        RuntimeError: If the top-level JSON value is not an object, or its
            ``schemaVersion`` differs from
            ``US_CONDITIONAL_READINESS_SCHEMA_VERSION``.
    """

    report_path = Path(path)
    payload = json.loads(report_path.read_text())
    if not isinstance(payload, dict):
        # Without this guard a JSON array or scalar would fail with an
        # AttributeError on ``.get`` instead of the loader's RuntimeError.
        raise RuntimeError(
            "Unsupported US conditional-readiness report payload: expected a "
            f"JSON object, got {type(payload).__name__}"
        )
    if payload.get("schemaVersion") != US_CONDITIONAL_READINESS_SCHEMA_VERSION:
        raise RuntimeError(
            "Unsupported US conditional-readiness report schema: "
            f"{payload.get('schemaVersion')!r}"
        )
    return cast(USConditionalReadinessReport, payload)
[ + _artifact_label(artifact) + for artifact in artifacts + if artifact.get("classification") in {"missing_required", "missing_optional"} + ] + diagnostic = [ + _artifact_label(artifact) + for artifact in artifacts + if artifact.get("classification") == "diagnostic_only" + ] + reloadable = [ + _artifact_label(artifact) + for artifact in artifacts + if artifact.get("classification") + in {"manual_replay", "manual_resume", "post_artifact_evidence"} + ] + readiness, reason = _stage_readiness( + stage, + artifacts, + compatibility=compatibility, + stage8_dataset_available=_stage8_dataset_available(inventory), + ) + return { + "stageId": stage_id, + "stageStep": str(stage.get("step", "")), + "stageTitle": str(stage.get("title", "")), + "status": cast(USStageStatus, stage.get("status", "missing")), + "readiness": readiness, + "reason": reason, + "compatibility": compatibility, + "reuseKey": build_us_stage_reuse_key(stage_id, manifest, inventory), + "savedConfigHash": saved_stage_config_hash, + "requestedConfigHash": requested_stage_config_hash, + "availableArtifacts": available, + "missingArtifacts": missing, + "diagnosticArtifacts": diagnostic, + "reloadableArtifacts": reloadable, + } + + +def _stage_readiness( + stage: Mapping[str, Any], + artifacts: list[USStageArtifactInventoryRecord], + *, + compatibility: USStageCompatibility, + stage8_dataset_available: bool, +) -> tuple[USStageReadiness, str]: + stage_id = str(stage.get("id", "")) + status = stage.get("status") + if stage_id == "09_validation_benchmarking" and status == "deferred": + if stage8_dataset_available: + return ( + "post_artifact_evidence", + "Stage 8 dataset is available for validation or benchmark evidence.", + ) + return ( + "must_rerun", + "Validation is deferred and no Stage 8 dataset is available.", + ) + if compatibility == "mismatch": + return ( + "must_rerun", + "Requested configuration does not match this stage's saved run inputs.", + ) + classifications = { + 
str(artifact.get("classification")) + for artifact in artifacts + if bool(artifact.get("exists")) + } + for readiness in ("manual_resume", "manual_replay", "post_artifact_evidence"): + if readiness in classifications: + return cast(USStageReadiness, readiness), ( + f"Stage has existing {readiness.replace('_', ' ')} artifacts." + ) + if "diagnostic_only" in classifications: + return ( + "diagnostic_only", + "Stage has diagnostic artifacts but no replay boundary.", + ) + if status in {"missing", "incomplete"}: + return "must_rerun", f"Stage status is {status}." + if status == "metadata_only": + return "metadata_only", "Stage has metadata but no reloadable artifact." + return "not_applicable", "No reusable artifact boundary is available." + + +def _inventory_artifacts_for_stage( + inventory: Mapping[str, Any], + stage_id: str, +) -> list[USStageArtifactInventoryRecord]: + return [ + cast(USStageArtifactInventoryRecord, artifact) + for artifact in inventory.get("artifacts", ()) + if isinstance(artifact, dict) and artifact.get("stageId") == stage_id + ] + + +def _stage8_dataset_available(inventory: Mapping[str, Any]) -> bool: + return any( + isinstance(artifact, dict) + and artifact.get("stageId") == "08_dataset_assembly" + and artifact.get("key") == "policyengine_dataset" + and bool(artifact.get("exists")) + for artifact in inventory.get("artifacts", ()) + ) + + +def _load_or_build_inventory( + artifact_root: Path, + *, + manifest_payload: dict[str, Any], +) -> USStageArtifactInventory: + inventory_name = dict(manifest_payload.get("artifacts", {})).get( + "artifact_inventory" + ) + if isinstance(inventory_name, str): + inventory_path = Path(inventory_name) + if not inventory_path.is_absolute(): + inventory_path = artifact_root / inventory_path + if inventory_path.exists(): + return load_us_stage_artifact_inventory(inventory_path) + return build_us_stage_artifact_inventory( + artifact_root, + manifest_payload=manifest_payload, + ) + + +def _config_compatibility( + 
saved_config_hash: str | None, + requested_config_hash: str | None, + *, + requested_config_supplied: bool, +) -> USStageCompatibility: + if not requested_config_supplied: + return "not_evaluated" + if saved_config_hash is None: + return "missing_saved_config" + return "match" if saved_config_hash == requested_config_hash else "mismatch" + + +def _config_hash(config: Any) -> str | None: + if not isinstance(config, Mapping): + return None + return _hash_json(_canonical_config(config)) + + +def _stage_config_hash(stage_id: str, config: Any) -> str | None: + keys = config_keys_for_us_pipeline_stage(stage_id) + if not keys: + return _hash_json({}) + if not isinstance(config, Mapping): + return None + scoped = {key: config.get(key) for key in keys if key in config} + return _hash_json(_canonical_config(scoped)) + + +def _canonical_config(config: Mapping[str, Any]) -> dict[str, Any]: + return { + str(key): _normalize_config_value(value) + for key, value in sorted(config.items()) + if key not in US_CONFIG_REUSE_IGNORED_KEYS + } + + +def _normalize_config_value(value: Any) -> Any: + if isinstance(value, Mapping): + return { + str(key): _normalize_config_value(item) + for key, item in sorted(value.items()) + } + if isinstance(value, (list, tuple)): + return [_normalize_config_value(item) for item in value] + if isinstance(value, Path): + return str(value) + return value + + +def _hash_json(payload: Any) -> str: + return hashlib.sha256( + json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8") + ).hexdigest() + + +def _artifact_label(artifact: Mapping[str, Any]) -> str: + return f"{artifact.get('stageId')}.{artifact.get('key')}" + + +def _optional_str(value: Any) -> str | None: + if value is None: + return None + return str(value) + + +def _write_json_atomically(path: Path, payload: Mapping[str, Any]) -> None: + temporary = path.with_suffix(path.suffix + ".tmp") + temporary.write_text(json.dumps(payload, indent=2, sort_keys=True)) + temporary.replace(path) 
@dataclass(frozen=True)
class USArtifactRef:
    """Pointer to a single artifact owned by a stage output manifest.

    ``path`` may be absolute or relative; relative paths are interpreted
    against the artifact root passed to the helper methods.
    """

    key: str
    path: str | Path
    format: StageArtifactFormat = "unknown"
    required: bool = False
    category: USArtifactCategory = "required_output"
    resume_role: StageArtifactResumeRole | None = None
    assume_exists: bool = False
    exists: bool | None = None

    def resolved_path(self, artifact_root: str | Path) -> Path:
        """Return the artifact path, anchoring relative paths at ``artifact_root``."""

        candidate = Path(self.path)
        return candidate if candidate.is_absolute() else Path(artifact_root) / candidate

    def exists_under(self, artifact_root: str | Path) -> bool:
        """Report whether the artifact counts as present.

        ``assume_exists`` wins, then an explicit ``exists`` flag; the
        filesystem is consulted only when neither is set.
        """

        if not self.assume_exists and self.exists is None:
            return self.resolved_path(artifact_root).exists()
        return True if self.assume_exists else self.exists

    def relative_path(self, artifact_root: str | Path) -> str:
        """Return the path relative to ``artifact_root`` when it lies beneath it.

        Paths outside the root are returned in resolved form unchanged.
        """

        root = Path(artifact_root)
        target = self.resolved_path(artifact_root)
        try:
            relative = target.relative_to(root)
        except ValueError:
            # The artifact lives outside the root; keep the resolved path.
            return str(target)
        return str(relative)

    def to_dict(self, artifact_root: str | Path | None = None) -> dict[str, Any]:
        """Serialize the reference.

        Without ``artifact_root`` the path is stringified as given and
        ``exists`` keeps its stored value; with it, the path is relativized
        and ``exists`` is recomputed via :meth:`exists_under`.
        """

        payload = asdict(self)
        if artifact_root is None:
            payload["path"] = str(self.path)
            return payload
        payload["path"] = self.relative_path(artifact_root)
        payload["exists"] = self.exists_under(artifact_root)
        return payload
@dataclass(frozen=True)
class USStageInputOverride:
    """Explicit override for a stage input that is not provided by the prior stage."""

    stage_id: str
    key: str
    path: str | Path
    reason: str | None = None

    def to_dict(self, artifact_root: str | Path | None = None) -> dict[str, Any]:
        """Serialize the override for a stage output manifest.

        Args:
            artifact_root: Accepted so the signature matches the other
                ``to_dict`` serializers in this module. It does not affect
                the result: the override path is recorded as supplied
                (normalized through ``Path``), whether relative or absolute.

        Returns:
            A dict with ``stageId``, ``key``, ``path`` and ``reason``.
        """

        # The previous implementation had a branch for relative paths under
        # an artifact root that reassigned the same string; it never changed
        # the output and has been removed.
        return {
            "stageId": self.stage_id,
            "key": self.key,
            "path": str(Path(self.path)),
            "reason": self.reason,
        }
= ( + "artifact", + "manifest", + "stage_output", + ) + + +@dataclass(frozen=True) +class USStageOutputManifest: + """Base type for one typed stage output manifest.""" + + schema_version: int = US_STAGE_OUTPUT_MANIFEST_SCHEMA_VERSION + contract_version: str = US_STAGE_CONTRACT_VERSION + input_stage_manifest: str | Path | None = None + diagnostics: Mapping[str, USDiagnosticOutput] = field(default_factory=dict) + auxiliary_artifacts: Mapping[str, USAuxiliaryArtifact] = field(default_factory=dict) + metadata: Mapping[str, Any] = field(default_factory=dict) + complete: bool = True + stage_id: str = field(default="", init=False) + + def required_output_keys(self) -> tuple[str, ...]: + """Return required output keys from the canonical stage contract.""" + + contract = get_us_pipeline_stage_contract(self.stage_id) + return tuple(resource.key for resource in contract.outputs if resource.required) + + def artifact_refs(self) -> dict[str, USArtifactRef]: + """Return artifact references carried by this stage output manifest.""" + + refs: dict[str, USArtifactRef] = {} + for item in fields(self): + value = getattr(self, item.name) + if isinstance(value, USArtifactRef): + refs[value.key] = value + for artifact in self.auxiliary_artifacts.values(): + refs[artifact.key] = artifact.as_artifact_ref() + return refs + + def missing_required_outputs(self, artifact_root: str | Path) -> tuple[str, ...]: + """Return required output keys not provided or not present on disk.""" + + missing: list[str] = [] + for key in self.required_output_keys(): + value = getattr(self, key, None) + if _required_output_is_missing(value, artifact_root): + missing.append(key) + return tuple(missing) + + def to_dict( + self, + artifact_root: str | Path | None = None, + *, + input_stage_manifest: str | None = None, + input_overrides: tuple[USStageInputOverride, ...] 
= (), + ) -> dict[str, Any]: + """Serialize this typed output manifest.""" + + diagnostics = { + key: diagnostic.to_dict(artifact_root) + for key, diagnostic in self.diagnostics.items() + } + auxiliary = { + key: artifact.as_artifact_ref().to_dict(artifact_root) + for key, artifact in self.auxiliary_artifacts.items() + } + output_fields = { + item.name: _serialize_value(getattr(self, item.name), artifact_root) + for item in fields(self) + if item.name + not in { + "schema_version", + "contract_version", + "input_stage_manifest", + "diagnostics", + "auxiliary_artifacts", + "metadata", + "complete", + "stage_id", + } + } + return { + "schemaVersion": self.schema_version, + "contractVersion": self.contract_version, + "stageId": self.stage_id, + "complete": self.complete, + "inputStageManifest": input_stage_manifest + or _optional_str(self.input_stage_manifest), + "inputOverrides": [ + override.to_dict(artifact_root) for override in input_overrides + ], + "requiredOutputs": list(self.required_output_keys()), + "missingRequiredOutputs": ( + list(self.missing_required_outputs(artifact_root)) + if artifact_root is not None + else [] + ), + "outputs": output_fields, + "diagnostics": diagnostics, + "auxiliaryArtifacts": auxiliary, + "metadata": dict(self.metadata), + } + + +@dataclass(frozen=True) +class USRunProfileOutputs(USStageOutputManifest): + stage_id: str = field(default="01_run_profile", init=False) + manifest: USArtifactRef | None = None + resolved_config: Mapping[str, Any] = field(default_factory=dict) + provider_query_plan: Mapping[str, Any] = field(default_factory=dict) + + +@dataclass(frozen=True) +class USSourceLoadingOutputs(USStageOutputManifest): + stage_id: str = field(default="02_source_loading", init=False) + observation_frame_summary: Mapping[str, Any] = field(default_factory=dict) + source_descriptors: tuple[str, ...] 
@dataclass(frozen=True)
class USCalibrationOutputs(USStageOutputManifest):
    """Typed output manifest for the ``07_calibration`` stage.

    Extends ``USStageOutputManifest`` with the calibration stage's own output
    fields. Every artifact field is optional and defaults to ``None``.
    """

    # Fixed stage identifier; excluded from __init__ so callers cannot set it.
    stage_id: str = field(default="07_calibration", init=False)
    # Artifact references for this stage's outputs.
    # NOTE(review): presumably calibrated_data.parquet, targets.json and
    # calibration_summary.json as listed in docs/stage-contracts.md -- confirm
    # against the stage contract registry.
    calibrated_data: USArtifactRef | None = None
    targets: USArtifactRef | None = None
    calibration_summary: USArtifactRef | None = None
    # Inline (non-artifact) mapping; empty by default.
    # NOTE(review): expected contents are not visible in this module -- confirm.
    target_ledger: Mapping[str, Any] = field(default_factory=dict)
@dataclass(frozen=True)
class USValidationBenchmarkingOutputs(USStageOutputManifest):
    """Typed output manifest for the ``09_validation_benchmarking`` stage.

    Extends ``USStageOutputManifest`` with optional references to validation
    and benchmark evidence artifacts plus an inline benchmark summary.
    """

    # Fixed stage identifier; excluded from __init__ so callers cannot set it.
    stage_id: str = field(default="09_validation_benchmarking", init=False)
    # Reference to the validation evidence artifact, when one was written.
    validation_evidence: USArtifactRef | None = None
    # Inline (non-artifact) mapping; empty by default.
    # NOTE(review): expected contents are not visible in this module -- confirm.
    benchmark_summary: Mapping[str, Any] = field(default_factory=dict)
    # Optional evidence artifacts; each stays None when not produced.
    # NOTE(review): file formats and producers are defined elsewhere -- confirm
    # against the stage contract registry.
    policyengine_harness: USArtifactRef | None = None
    policyengine_native_scores: USArtifactRef | None = None
    policyengine_native_audit: USArtifactRef | None = None
    imputation_ablation: USArtifactRef | None = None
    child_tax_unit_agi_drift: USArtifactRef | None = None
def validate_stage(self, outputs: USStageOutputManifest) -> None:
    """Check one typed stage output manifest against its stage contract.

    Raises:
        KeyError: If the stage id has no registered manifest type, or an
            auxiliary artifact or artifact reference uses a key the stage
            contract does not declare.
        TypeError: If ``outputs`` is not an instance of the manifest type
            registered for its stage.
        ValueError: If no diagnostics are exposed, or the manifest is marked
            complete while required outputs are missing.
    """

    stage_id = outputs.stage_id
    manifest_type = US_STAGE_OUTPUT_MANIFEST_TYPES.get(stage_id)
    if manifest_type is None:
        raise KeyError(f"Unknown US stage output manifest: {stage_id}")
    if not isinstance(outputs, manifest_type):
        raise TypeError(
            f"{stage_id} must use {manifest_type.__name__}, "
            f"got {type(outputs).__name__}"
        )

    # Looked up here only so an unregistered contract fails at this point.
    get_us_pipeline_stage_contract(stage_id)

    if not outputs.diagnostics:
        raise ValueError(f"{stage_id} does not expose diagnostics")

    missing_outputs = outputs.missing_required_outputs(self.artifact_root)
    if outputs.complete and missing_outputs:
        raise ValueError(
            f"{stage_id} is marked complete but is missing required "
            f"outputs: {', '.join(missing_outputs)}"
        )

    declared_keys = set()
    for declared in get_us_pipeline_stage_contract(stage_id).artifacts:
        declared_keys.add(declared.key)

    undeclared_auxiliary = next(
        (
            auxiliary
            for auxiliary in outputs.auxiliary_artifacts.values()
            if auxiliary.key not in declared_keys
        ),
        None,
    )
    if undeclared_auxiliary is not None:
        raise KeyError(
            f"{stage_id} auxiliary artifact {undeclared_auxiliary.key!r} "
            "is not declared by the stage contract"
        )

    undeclared_ref = next(
        (
            ref
            for ref in outputs.artifact_refs().values()
            if ref.key not in declared_keys
        ),
        None,
    )
    if undeclared_ref is not None:
        raise KeyError(
            f"{stage_id} artifact {undeclared_ref.key!r} is not declared "
            "by the stage contract"
        )
def _materialize_manifest_payload(self) -> dict[str, Any]:
    """Merge recorded stage outputs into the run manifest and persist them.

    For every recorded stage, in canonical order, the per-stage manifest
    path is registered and the stage's artifact references are added to the
    manifest's ``artifacts`` mapping. Aggregate artifact paths are filled
    in, stage diagnostics are added for stages that do not already have an
    entry, and each per-stage manifest file is written. The merged manifest
    replaces ``self.manifest_payload`` and is returned.

    The ``artifacts`` and ``diagnostics`` mappings are copied before being
    modified, so nested dicts supplied by the caller are never mutated.
    """

    manifest = dict(self.manifest_payload)
    artifacts = dict(manifest.get("artifacts", {}))
    stage_manifest_paths: dict[str, str] = {}

    for stage_id in US_CANONICAL_STAGE_IDS:
        outputs = self._recorded.get(stage_id)
        if outputs is None:
            continue
        stage_manifest_path = self._stage_output_manifest_path(stage_id)
        stage_manifest_path.parent.mkdir(parents=True, exist_ok=True)
        stage_manifest_paths[stage_id] = str(
            stage_manifest_path.relative_to(self.artifact_root)
        )
        for artifact in outputs.artifact_refs().values():
            artifacts[artifact.key] = artifact.relative_path(self.artifact_root)

    self._ensure_aggregate_artifact_paths(artifacts)
    manifest["artifacts"] = artifacts
    manifest["stage_output_manifests"] = stage_manifest_paths

    # ``manifest`` is only a shallow copy: take a private copy of the
    # diagnostics mapping so adding stage entries cannot leak into the
    # payload object the caller handed to the writer.
    diagnostics = dict(manifest.get("diagnostics") or {})
    for stage_id, outputs in self._recorded.items():
        if stage_id not in diagnostics:
            diagnostics[stage_id] = {
                key: diagnostic.to_dict(self.artifact_root)
                for key, diagnostic in outputs.diagnostics.items()
            }
    manifest["diagnostics"] = diagnostics

    for stage_id, outputs in self._recorded.items():
        _write_json_atomically(
            self._stage_output_manifest_path(stage_id),
            outputs.to_dict(
                self.artifact_root,
                input_stage_manifest=self._previous_stage_manifest_ref(stage_id),
                input_overrides=self._overrides_for_stage(stage_id),
            ),
        )

    self.manifest_payload = manifest
    return manifest
0: + return None + previous_stage_id = US_CANONICAL_STAGE_IDS[stage_index - 1] + if previous_stage_id not in self._recorded: + return None + return str( + self._stage_output_manifest_path(previous_stage_id).relative_to( + self.artifact_root + ) + ) + + def _overrides_for_stage(self, stage_id: str) -> tuple[USStageInputOverride, ...]: + return tuple( + override + for override in self.stage_input_overrides + if override.stage_id == stage_id + ) + + def _resolve_path(self, value: Any) -> Path: + path = Path(str(value)) + if not path.is_absolute(): + path = self.artifact_root / path + return path + + +class USStageInputValidator: + """Validate stage input seams against typed stage manifests and overrides.""" + + def __init__( + self, + artifact_root: str | Path, + recorded_stages: Mapping[str, USStageOutputManifest], + *, + allow_stage_input_overrides: bool = False, + stage_input_overrides: tuple[USStageInputOverride, ...] = (), + settings_by_stage: Mapping[str, USStageInputValidationSettings] | None = None, + ) -> None: + self.artifact_root = Path(artifact_root) + self.recorded_stages = recorded_stages + self.allow_stage_input_overrides = allow_stage_input_overrides + self.stage_input_overrides = tuple(stage_input_overrides) + self.settings_by_stage = dict( + settings_by_stage or default_us_stage_input_validation_settings() + ) + + def validate(self, outputs: USStageOutputManifest) -> None: + """Validate one stage's required input boundary.""" + + stage_index = US_CANONICAL_STAGE_IDS.index(outputs.stage_id) + if stage_index == 0: + return + settings = self.settings_by_stage[outputs.stage_id] + previous_stage_id = US_CANONICAL_STAGE_IDS[stage_index - 1] + required_stage_inputs = tuple( + resource + for resource in get_us_pipeline_stage_contract(outputs.stage_id).inputs + if self._enforces_resource(resource, settings) + ) + missing_inputs = tuple( + self._resource_label(resource) + for resource in required_stage_inputs + if not self._resource_is_satisfied(resource, 
outputs) + ) + previous_inputs = tuple( + resource + for resource in required_stage_inputs + if resource.stage_id == previous_stage_id + ) + previous_stage_available = self._stage_manifest_available( + previous_stage_id, + outputs, + ) + previous_stage_overridden = bool(previous_inputs) and all( + self._override_satisfies(outputs.stage_id, resource) + for resource in previous_inputs + ) + if ( + settings.require_previous_stage_manifest + and not previous_stage_available + and not previous_stage_overridden + ): + detail = ( + f"; missing required inputs: {', '.join(missing_inputs)}" + if missing_inputs + else "" + ) + raise ValueError( + f"{outputs.stage_id} requires {previous_stage_id} output manifest " + "or explicit overrides for all required inputs from that stage" + f"{detail}" + ) + if ( + settings.enforce_required_stage_inputs + and missing_inputs + and (not settings.enforce_only_when_stage_complete or outputs.complete) + ): + raise ValueError( + f"{outputs.stage_id} is missing required stage input(s): " + f"{', '.join(missing_inputs)}" + ) + + def _enforces_resource( + self, + resource: USStageResourceContract, + settings: USStageInputValidationSettings, + ) -> bool: + return ( + resource.required + and resource.stage_id is not None + and resource.kind in settings.enforced_resource_kinds + ) + + def _resource_is_satisfied( + self, + resource: USStageResourceContract, + outputs: USStageOutputManifest, + ) -> bool: + if self._override_satisfies(outputs.stage_id, resource): + return True + source_stage_id = resource.stage_id + if source_stage_id is None: + return False + recorded_outputs = self.recorded_stages.get(source_stage_id) + if recorded_outputs is not None: + return not _required_output_is_missing( + getattr(recorded_outputs, resource.key, None), + self.artifact_root, + ) + serialized_stage = self._serialized_input_stage_manifest(outputs) + if ( + serialized_stage is not None + and serialized_stage.get("stageId") == source_stage_id + ): + return 
_serialized_output_key_is_available( + serialized_stage, + resource.key, + ) + return False + + def _stage_manifest_available( + self, + stage_id: str, + outputs: USStageOutputManifest, + ) -> bool: + if stage_id in self.recorded_stages: + return True + serialized_stage = self._serialized_input_stage_manifest(outputs) + return ( + serialized_stage is not None and serialized_stage.get("stageId") == stage_id + ) + + def _serialized_input_stage_manifest( + self, + outputs: USStageOutputManifest, + ) -> Mapping[str, Any] | None: + if outputs.input_stage_manifest is None: + return None + stage_index = US_CANONICAL_STAGE_IDS.index(outputs.stage_id) + if stage_index == 0: + return None + previous_stage_id = US_CANONICAL_STAGE_IDS[stage_index - 1] + path = Path(outputs.input_stage_manifest) + if not path.is_absolute(): + path = self.artifact_root / path + expected_path = ( + self.artifact_root + / "stage_artifacts" + / "manifests" + / f"{previous_stage_id}.json" + ) + if path != expected_path or not path.exists(): + return None + try: + payload = json.loads(path.read_text()) + except (OSError, json.JSONDecodeError): + return None + return payload if isinstance(payload, Mapping) else None + + def _override_satisfies( + self, + stage_id: str, + resource: USStageResourceContract, + ) -> bool: + if not self.allow_stage_input_overrides: + return False + return any( + override.stage_id == stage_id and override.key == resource.key + for override in self.stage_input_overrides + ) + + @staticmethod + def _resource_label(resource: USStageResourceContract) -> str: + return f"{resource.stage_id}.{resource.key}" + + +def build_us_stage_output_manifests_from_artifact_manifest( + artifact_root: str | Path, + manifest_payload: Mapping[str, Any], +) -> tuple[USStageOutputManifest, ...]: + """Build typed stage output manifests from an existing artifact manifest.""" + + root = Path(artifact_root) + manifest = dict(manifest_payload) + synthesis = dict(manifest.get("synthesis", {})) + rows = 
dict(manifest.get("rows", {})) + config = dict(manifest.get("config", {})) + artifacts = dict(manifest.get("artifacts", {})) + source_names = tuple( + str(source) + for source in synthesis.get("source_names", ()) + if isinstance(source, str) + ) + benchmark_summary, has_benchmark_evidence = _benchmark_summary(root, manifest) + has_benchmark = bool(benchmark_summary) and has_benchmark_evidence + has_dataset = _artifact_exists(root, artifacts, "policyengine_dataset") + return ( + USRunProfileOutputs( + manifest=_artifact_ref( + root, + {"manifest": artifacts.get("manifest", "manifest.json")}, + "manifest", + "01_run_profile", + assume_exists=True, + ), + resolved_config=config, + provider_query_plan={"source_names": list(source_names)}, + diagnostics=_diagnostics("01_run_profile", manifest), + complete=bool(config), + ), + USSourceLoadingOutputs( + observation_frame_summary={"source_count": len(source_names)}, + source_descriptors=source_names, + source_relationships={"status": "summarized"}, + diagnostics=_diagnostics("02_source_loading", manifest), + complete=bool(source_names), + ), + USSourcePlanningOutputs( + source_plan=_artifact_ref( + root, artifacts, "source_plan", "03_source_planning" + ), + scaffold_selection={"scaffold_source": synthesis.get("scaffold_source")}, + diagnostics=_diagnostics("03_source_planning", manifest), + complete=_artifact_exists(root, artifacts, "source_plan"), + ), + USSeedScaffoldOutputs( + scaffold_seed_data=_artifact_ref( + root, + artifacts, + "scaffold_seed_data", + "04_seed_scaffold", + ), + seed_schema_metadata={"seed_rows": rows.get("seed")}, + diagnostics=_diagnostics("04_seed_scaffold", manifest), + complete=_artifact_exists(root, artifacts, "scaffold_seed_data"), + ), + USDonorSynthesisOutputs( + seed_data=_artifact_ref( + root, + artifacts, + "seed_data", + "05_donor_integration_synthesis", + ), + synthetic_data=_artifact_ref( + root, + artifacts, + "synthetic_data", + "05_donor_integration_synthesis", + ), + 
synthesis_metadata=synthesis, + source_weight_diagnostics=_artifact_ref( + root, + artifacts, + "source_weight_diagnostics", + "05_donor_integration_synthesis", + category="diagnostic", + ), + diagnostics=_diagnostics("05_donor_integration_synthesis", manifest), + complete=all( + _artifact_exists(root, artifacts, key) + for key in ("seed_data", "synthetic_data") + ), + ), + USPolicyEngineEntityOutputs( + policyengine_entity_tables=_artifact_ref( + root, + artifacts, + "policyengine_entity_tables", + "06_policyengine_entities", + ), + materialized_policyengine_inputs=_policyengine_entity_metadata_summary( + root, + artifacts, + ), + diagnostics=_diagnostics("06_policyengine_entities", manifest), + complete=_artifact_exists(root, artifacts, "policyengine_entity_tables"), + ), + USCalibrationOutputs( + calibrated_data=_artifact_ref( + root, artifacts, "calibrated_data", "07_calibration" + ), + targets=_artifact_ref(root, artifacts, "targets", "07_calibration"), + calibration_summary=_artifact_ref( + root, + artifacts, + "calibration_summary", + "07_calibration", + category="diagnostic", + ), + target_ledger={"target_count": manifest.get("targets", {})}, + diagnostics=_diagnostics("07_calibration", manifest), + complete=all( + _artifact_exists(root, artifacts, key) + for key in ("calibrated_data", "targets", "calibration_summary") + ), + ), + USDatasetAssemblyOutputs( + policyengine_dataset=_artifact_ref( + root, + artifacts, + "policyengine_dataset", + "08_dataset_assembly", + ), + stage_manifest=_derived_artifact_ref( + root, "stage_manifest", "08_dataset_assembly" + ), + data_flow_snapshot=_derived_artifact_ref( + root, + "data_flow_snapshot", + "08_dataset_assembly", + ), + artifact_inventory=_derived_artifact_ref( + root, + "artifact_inventory", + "08_dataset_assembly", + ), + conditional_readiness=_derived_artifact_ref( + root, + "conditional_readiness", + "08_dataset_assembly", + ), + diagnostics=_diagnostics("08_dataset_assembly", manifest), + 
complete=bool(has_dataset), + ), + USValidationBenchmarkingOutputs( + validation_evidence=( + _derived_artifact_ref( + root, + "validation_evidence", + "09_validation_benchmarking", + ) + if has_dataset or has_benchmark + else None + ), + benchmark_summary=benchmark_summary, + policyengine_harness=_artifact_ref( + root, + artifacts, + "policyengine_harness", + "09_validation_benchmarking", + category="diagnostic", + ), + policyengine_native_scores=_artifact_ref( + root, + artifacts, + "policyengine_native_scores", + "09_validation_benchmarking", + category="diagnostic", + ), + policyengine_native_audit=_artifact_ref( + root, + artifacts, + "policyengine_native_audit", + "09_validation_benchmarking", + category="diagnostic", + ), + imputation_ablation=_artifact_ref( + root, + artifacts, + "imputation_ablation", + "09_validation_benchmarking", + category="diagnostic", + ), + child_tax_unit_agi_drift=_artifact_ref( + root, + artifacts, + "child_tax_unit_agi_drift", + "09_validation_benchmarking", + category="diagnostic", + ), + diagnostics=_diagnostics( + "09_validation_benchmarking", + manifest, + stage_summary=benchmark_summary, + ), + complete=bool(has_benchmark), + ), + ) + + +def write_us_stage_run_manifests_from_artifact_manifest( + artifact_root: str | Path, + manifest_payload: Mapping[str, Any], + *, + allow_stage_input_overrides: bool = False, + stage_input_overrides: tuple[USStageInputOverride, ...] 
= (), +) -> dict[str, Any]: + """Write typed stage manifests and aggregate outputs from an artifact manifest.""" + + writer = USStageRunWriter( + artifact_root, + manifest_payload=manifest_payload, + allow_stage_input_overrides=allow_stage_input_overrides, + stage_input_overrides=stage_input_overrides, + ) + for outputs in build_us_stage_output_manifests_from_artifact_manifest( + artifact_root, + manifest_payload, + ): + writer.record_stage(outputs) + return writer.write_manifest_files() + + +def resolve_us_manifest_or_contract_artifact_path( + artifact_root: str | Path, + manifest_payload: Mapping[str, Any], + artifact_key: str, + *, + stage_id: str, +) -> Path: + """Resolve an artifact from the manifest first, then the stage contract.""" + + artifacts = dict(manifest_payload.get("artifacts", {})) + declared = artifacts.get(artifact_key) + if declared is not None: + path = Path(str(declared)) + if not path.is_absolute(): + path = Path(artifact_root) / path + return path + return resolve_us_stage_artifact_contract_path( + artifact_root, stage_id, artifact_key + ) + + +def parse_us_stage_input_override(value: str) -> USStageInputOverride: + """Parse STAGE_ID.KEY=PATH into a stage input override.""" + + if "=" not in value: + raise ValueError("Stage input overrides must use STAGE_ID.KEY=PATH syntax") + left, path = value.split("=", 1) + if "." 
not in left: + raise ValueError("Stage input overrides must use STAGE_ID.KEY=PATH syntax") + stage_id, key = left.split(".", 1) + if not stage_id or not key or not path: + raise ValueError("Stage input overrides must use STAGE_ID.KEY=PATH syntax") + if stage_id not in US_CANONICAL_STAGE_IDS: + raise ValueError(f"Unknown US pipeline stage: {stage_id}") + override = USStageInputOverride(stage_id=stage_id, key=key, path=path) + _validate_us_stage_input_override(override) + return override + + +def default_us_stage_input_validation_settings() -> dict[ + str, USStageInputValidationSettings +]: + """Return stage-specific settings for typed input boundary validation.""" + + return { + stage_id: USStageInputValidationSettings( + stage_id=stage_id, + require_previous_stage_manifest=stage_id != "01_run_profile", + ) + for stage_id in US_CANONICAL_STAGE_IDS + } + + +def _validate_us_stage_input_override(override: USStageInputOverride) -> None: + if override.stage_id not in US_CANONICAL_STAGE_IDS: + raise ValueError(f"Unknown US pipeline stage: {override.stage_id}") + contract = get_us_pipeline_stage_contract(override.stage_id) + input_keys = {resource.key for resource in contract.inputs} + if override.key not in input_keys: + valid_keys = ", ".join(sorted(input_keys)) or "none" + raise ValueError( + f"Unknown input override key {override.stage_id}.{override.key}; " + f"valid keys: {valid_keys}" + ) + + +def _artifact_ref( + artifact_root: Path, + artifacts: Mapping[str, Any], + artifact_key: str, + stage_id: str, + *, + category: USArtifactCategory = "required_output", + assume_exists: bool = False, +) -> USArtifactRef | None: + declared = artifacts.get(artifact_key) + if declared is None: + return None + contract = get_us_stage_artifact_contract(stage_id, artifact_key) + return USArtifactRef( + key=artifact_key, + path=str(declared), + format=contract.format, + required=contract.required, + category=category, + resume_role=contract.resume_role, + assume_exists=assume_exists, 
+ exists=_artifact_path_exists(artifact_root, declared), + ) + + +def _derived_artifact_ref( + artifact_root: Path, + artifact_key: str, + stage_id: str, +) -> USArtifactRef: + contract = get_us_stage_artifact_contract(stage_id, artifact_key) + path = resolve_us_stage_artifact_contract_path( + artifact_root, stage_id, artifact_key + ) + return USArtifactRef( + key=artifact_key, + path=str(path.relative_to(artifact_root)), + format=contract.format, + required=contract.required, + category="derived", + resume_role=contract.resume_role, + assume_exists=True, + ) + + +def _artifact_exists( + artifact_root: Path, + artifacts: Mapping[str, Any], + artifact_key: str, +) -> bool: + declared = artifacts.get(artifact_key) + return declared is not None and _artifact_path_exists(artifact_root, declared) + + +def _artifact_path_exists(artifact_root: Path, value: Any) -> bool: + path = Path(str(value)) + if not path.is_absolute(): + path = artifact_root / path + return path.exists() + + +def _path_for_manifest(path: Path, artifact_root: Path) -> str: + try: + return str(path.relative_to(artifact_root)) + except ValueError: + return str(path) + + +def _policyengine_entity_metadata_summary( + artifact_root: Path, + artifacts: Mapping[str, Any], +) -> dict[str, Any]: + declared = artifacts.get("policyengine_entity_tables") + if declared is None: + return {} + path = Path(str(declared)) + if not path.is_absolute(): + path = artifact_root / path + summary: dict[str, Any] = { + "metadata_path": _path_for_manifest(path, artifact_root), + } + if not path.exists() or not path.is_file(): + return summary + try: + metadata = json.loads(path.read_text()) + except (OSError, json.JSONDecodeError): + return summary + if not isinstance(metadata, Mapping): + return summary + stage = metadata.get("stage") + if stage is not None: + summary["stage"] = stage + tables: dict[str, dict[str, Any]] = {} + for key in ( + "households", + "persons", + "tax_units", + "spm_units", + "families", + 
"marital_units", + ): + table_metadata = metadata.get(key) + if not isinstance(table_metadata, Mapping): + continue + columns = table_metadata.get("columns", ()) + column_names = ( + [str(column) for column in columns] + if isinstance(columns, (list, tuple)) + else [] + ) + tables[key] = { + "rows": table_metadata.get("rows"), + "columns": column_names, + } + if tables: + summary["tables"] = tables + return summary + + +def _diagnostics( + stage_id: str, + manifest: Mapping[str, Any], + *, + stage_summary: Mapping[str, Any] | None = None, +) -> dict[str, USDiagnosticOutput]: + diagnostics = dict(manifest.get("diagnostics", {})) + stage_diagnostics = diagnostics.get(stage_id) + summary = ( + dict(stage_diagnostics) + if isinstance(stage_diagnostics, Mapping) + else dict(stage_summary) + if stage_summary is not None + else _default_stage_diagnostic_summary(stage_id, manifest) + ) + return { + "stage_summary": USDiagnosticOutput( + key="stage_summary", + description=f"Saved-run diagnostic summary for {stage_id}.", + summary=summary, + ) + } + + +def _default_stage_diagnostic_summary( + stage_id: str, + manifest: Mapping[str, Any], +) -> dict[str, Any]: + rows = dict(manifest.get("rows", {})) + synthesis = dict(manifest.get("synthesis", {})) + calibration = dict(manifest.get("calibration", {})) + artifacts = dict(manifest.get("artifacts", {})) + if stage_id == "01_run_profile": + return {"has_config": isinstance(manifest.get("config"), Mapping)} + if stage_id == "02_source_loading": + return {"source_names": list(synthesis.get("source_names", ()))} + if stage_id == "03_source_planning": + return {"scaffold_source": synthesis.get("scaffold_source")} + if stage_id == "04_seed_scaffold": + return {"seed_rows": rows.get("seed")} + if stage_id == "05_donor_integration_synthesis": + return { + "seed_rows": rows.get("seed"), + "synthetic_rows": rows.get("synthetic"), + "backend": synthesis.get("backend"), + } + if stage_id == "06_policyengine_entities": + return 
{"entity_tables": artifacts.get("policyengine_entity_tables")} + if stage_id == "07_calibration": + return { + "calibrated_rows": rows.get("calibrated"), + "backend": calibration.get("backend"), + "converged": calibration.get("converged"), + } + if stage_id == "08_dataset_assembly": + return {"dataset": artifacts.get("policyengine_dataset")} + if stage_id == "09_validation_benchmarking": + return _manifest_benchmark_summary(manifest) + return {} + + +def _benchmark_summary( + artifact_root: Path, + manifest: Mapping[str, Any], +) -> tuple[dict[str, Any], bool]: + try: + evidence = build_us_validation_evidence_manifest( + artifact_root, + manifest_payload=dict(manifest), + ) + except (OSError, ValueError, TypeError): + summary = _manifest_benchmark_summary_for_existing_artifacts( + artifact_root, + manifest, + ) + return summary, bool(summary) + summary = _validation_evidence_summary_for_existing_evidence(evidence) + if summary: + return summary, True + summary = _manifest_benchmark_summary_for_existing_artifacts( + artifact_root, + manifest, + ) + return summary, bool(summary) + + +def _manifest_benchmark_summary(manifest: Mapping[str, Any]) -> dict[str, Any]: + summary: dict[str, Any] = {} + for key in ( + "policyengine_harness", + "policyengine_native_scores", + "policyengine_native_audit", + "imputation_ablation", + ): + value = manifest.get(key) + if isinstance(value, Mapping): + summary[key] = dict(value) + return summary + + +def _validation_evidence_summary_for_existing_evidence( + evidence: Mapping[str, Any], +) -> dict[str, Any]: + records = evidence.get("evidence") + if not isinstance(records, list): + return {} + existing_keys = { + str(record["key"]) + for record in records + if isinstance(record, Mapping) + and record.get("key") + and record.get("exists") is True + } + summaries = evidence.get("summaries") + if not isinstance(summaries, Mapping): + return {} + return { + str(key): item for key, item in summaries.items() if str(key) in existing_keys + } 
+ + +def _manifest_benchmark_summary_for_existing_artifacts( + artifact_root: Path, + manifest: Mapping[str, Any], +) -> dict[str, Any]: + artifacts = dict(manifest.get("artifacts", {})) + return { + key: value + for key, value in _manifest_benchmark_summary(manifest).items() + if _artifact_exists(artifact_root, artifacts, key) + } + + +def _serialize_value(value: Any, artifact_root: str | Path | None) -> Any: + if isinstance(value, USArtifactRef): + return value.to_dict(artifact_root) + if isinstance(value, USAuxiliaryArtifact): + return value.as_artifact_ref().to_dict(artifact_root) + if isinstance(value, USDiagnosticOutput): + return value.to_dict(artifact_root) + if isinstance(value, Path): + return str(value) + if isinstance(value, Mapping): + return { + str(key): _serialize_value(item, artifact_root) + for key, item in value.items() + } + if isinstance(value, tuple): + return [_serialize_value(item, artifact_root) for item in value] + if isinstance(value, list): + return [_serialize_value(item, artifact_root) for item in value] + if is_dataclass(value): + return { + str(key): _serialize_value(item, artifact_root) + for key, item in asdict(value).items() + } + return value + + +def _required_output_is_missing(value: Any, artifact_root: str | Path) -> bool: + if value is None: + return True + if isinstance(value, USArtifactRef): + return not value.exists_under(artifact_root) + if isinstance(value, Mapping): + return not bool(value) + if isinstance(value, (tuple, list, set, frozenset)): + return not bool(value) + if isinstance(value, str): + return not value + return False + + +def _serialized_output_key_is_available( + stage_manifest: Mapping[str, Any], + key: str, +) -> bool: + outputs = stage_manifest.get("outputs") + if not isinstance(outputs, Mapping) or key not in outputs: + return False + value = outputs[key] + if value is None: + return False + if isinstance(value, Mapping): + exists = value.get("exists") + if exists is not None: + return bool(exists) + 
return bool(value) + if isinstance(value, (tuple, list, set, frozenset)): + return bool(value) + if isinstance(value, str): + return bool(value) + return True + + +def _optional_str(value: Any) -> str | None: + if value is None: + return None + return str(value) + + +def _write_json_atomically(path: Path, payload: Mapping[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + temporary = path.with_suffix(path.suffix + ".tmp") + temporary.write_text(json.dumps(payload, indent=2, sort_keys=True)) + temporary.replace(path) + + +__all__ = [ + "USAuxiliaryArtifact", + "USArtifactCategory", + "USArtifactRef", + "USCalibrationOutputs", + "USDatasetAssemblyOutputs", + "USDiagnosticOutput", + "USDonorSynthesisOutputs", + "USPolicyEngineEntityOutputs", + "USRunProfileOutputs", + "USSeedScaffoldOutputs", + "USSourceLoadingOutputs", + "USSourcePlanningOutputs", + "USStageInputOverride", + "USStageInputValidationSettings", + "USStageInputValidator", + "USStageOutputManifest", + "USStageRunWriter", + "USValidationBenchmarkingOutputs", + "build_us_stage_output_manifests_from_artifact_manifest", + "default_us_stage_input_validation_settings", + "parse_us_stage_input_override", + "resolve_us_manifest_or_contract_artifact_path", + "write_us_stage_run_manifests_from_artifact_manifest", +] diff --git a/src/microplex_us/pipelines/stage_status.py b/src/microplex_us/pipelines/stage_status.py new file mode 100644 index 0000000..410f57d --- /dev/null +++ b/src/microplex_us/pipelines/stage_status.py @@ -0,0 +1,274 @@ +"""Saved-run status classification for US pipeline stage manifests.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +from microplex_us.pipelines.stage_manifest_types import ( + USStageArtifactRecord, + USStageStatus, +) + + +def stage_status( + stage_id: str, + *, + artifact_root: Path, + manifest: dict[str, Any], + artifacts: list[USStageArtifactRecord], + assume_existing_artifact_keys: set[str], +) 
-> USStageStatus: + """Return the saved-run status for one canonical stage.""" + + artifact_map = dict(manifest.get("artifacts", {})) + synthesis = dict(manifest.get("synthesis", {})) + calibration = dict(manifest.get("calibration", {})) + rows = dict(manifest.get("rows", {})) + if stage_id == "01_run_profile": + if artifact_missing(artifacts, required_only=True): + return "incomplete" + if artifact_exists(artifacts, "manifest"): + return "ready" + return "metadata_only" if manifest.get("config") else "missing" + if stage_id == "02_source_loading": + return "metadata_only" if synthesis.get("source_names") else "missing" + if stage_id == "03_source_planning": + if artifact_missing(artifacts): + return "incomplete" + if artifact_exists(artifacts, "source_plan"): + return "ready" + return "metadata_only" if synthesis.get("scaffold_source") else "missing" + if stage_id == "04_seed_scaffold": + if artifact_missing(artifacts, required_only=True): + return "incomplete" + if required_artifacts_exist(artifacts): + return "ready" + return ( + "metadata_only" + if rows.get("seed") or synthesis.get("scaffold_source") + else "missing" + ) + if stage_id == "05_donor_integration_synthesis": + if artifact_missing(artifacts, required_only=True): + return "incomplete" + if required_artifacts_exist(artifacts): + return "ready" + return ( + "metadata_only" if rows.get("seed") or rows.get("synthetic") else "missing" + ) + if stage_id == "06_policyengine_entities": + if artifact_missing(artifacts): + return "incomplete" + if artifact_exists(artifacts, "policyengine_entity_tables"): + return "ready" + if manifest_artifact_exists( + manifest, + artifact_root, + "policyengine_dataset", + assume_existing_artifact_keys=assume_existing_artifact_keys, + ): + return "metadata_only" + return "missing" + if stage_id == "07_calibration": + if artifact_missing(artifacts, required_only=True): + return "incomplete" + if calibration and required_artifacts_exist(artifacts): + return "ready" + return 
"metadata_only" if calibration and rows.get("calibrated") else "missing" + if stage_id == "08_dataset_assembly": + if artifact_missing(artifacts, required_only=True): + return "incomplete" + if manifest_artifact_exists( + manifest, + artifact_root, + "policyengine_dataset", + assume_existing_artifact_keys=assume_existing_artifact_keys, + ): + return "ready" + return "metadata_only" if artifact_map.get("stage_manifest") else "missing" + if stage_id == "09_validation_benchmarking": + evidence_keys = ( + "policyengine_harness", + "policyengine_native_scores", + "policyengine_native_audit", + "imputation_ablation", + ) + evidence_index_keys = ("validation_evidence",) + if manifest_artifact_missing( + manifest, + artifact_root, + (*evidence_keys, *evidence_index_keys), + assume_existing_artifact_keys=assume_existing_artifact_keys, + ): + return "incomplete" + has_evidence = any( + manifest_artifact_exists( + manifest, + artifact_root, + key, + assume_existing_artifact_keys=assume_existing_artifact_keys, + ) + for key in evidence_keys + ) + if not has_evidence: + has_evidence = validation_evidence_index_has_existing_evidence( + manifest, + artifact_root, + assume_existing_artifact_keys=assume_existing_artifact_keys, + ) + if has_evidence: + if not manifest_artifact_exists( + manifest, + artifact_root, + "validation_evidence", + assume_existing_artifact_keys=assume_existing_artifact_keys, + ): + return "incomplete" + return "ready" + if manifest_artifact_exists( + manifest, + artifact_root, + "policyengine_dataset", + assume_existing_artifact_keys=assume_existing_artifact_keys, + ): + return "deferred" + return "missing" + if any(artifact.get("exists") for artifact in artifacts): + return "ready" + return "missing" + + +def required_artifacts_exist(artifacts: list[USStageArtifactRecord]) -> bool: + """Return whether all required artifacts exist.""" + + required = [artifact for artifact in artifacts if bool(artifact.get("required"))] + return bool(required) and 
all(bool(artifact.get("exists")) for artifact in required) + + +def artifact_exists(artifacts: list[USStageArtifactRecord], key: str) -> bool: + """Return whether a stage artifact record exists.""" + + return any( + artifact.get("key") == key and bool(artifact.get("exists")) + for artifact in artifacts + ) + + +def artifact_missing( + artifacts: list[USStageArtifactRecord], + *, + required_only: bool = False, +) -> bool: + """Return whether required or referenced stage artifacts are missing.""" + + return any( + not bool(artifact.get("exists")) + and ( + bool(artifact.get("required")) + or (not required_only and bool(artifact.get("referenced"))) + ) + for artifact in artifacts + ) + + +def manifest_artifact_exists( + manifest: dict[str, Any], + artifact_root: Path, + artifact_key: str, + *, + assume_existing_artifact_keys: set[str], +) -> bool: + """Return whether a top-level manifest artifact exists.""" + + path = manifest_artifact_path(manifest, artifact_root, artifact_key) + if path is None: + return False + if artifact_key in assume_existing_artifact_keys: + return True + return path.exists() + + +def manifest_artifact_missing( + manifest: dict[str, Any], + artifact_root: Path, + artifact_keys: tuple[str, ...], + *, + assume_existing_artifact_keys: set[str], +) -> bool: + """Return whether any referenced top-level manifest artifact is missing.""" + + artifacts = dict(manifest.get("artifacts", {})) + return any( + bool(artifacts.get(key)) + and not manifest_artifact_exists( + manifest, + artifact_root, + key, + assume_existing_artifact_keys=assume_existing_artifact_keys, + ) + for key in artifact_keys + ) + + +def validation_evidence_index_has_existing_evidence( + manifest: dict[str, Any], + artifact_root: Path, + *, + assume_existing_artifact_keys: set[str], +) -> bool: + """Return whether a validation evidence index points to existing evidence.""" + + path = manifest_artifact_path(manifest, artifact_root, "validation_evidence") + if path is None: + return 
False + if "validation_evidence" in assume_existing_artifact_keys and not path.exists(): + return False + if not path.exists(): + return False + try: + payload = json.loads(path.read_text()) + except (OSError, json.JSONDecodeError): + return False + evidence = payload.get("evidence") + if not isinstance(evidence, list): + return False + for record in evidence: + if not isinstance(record, dict) or not record.get("path"): + continue + evidence_path = Path(str(record["path"])) + if not evidence_path.is_absolute(): + evidence_path = artifact_root / evidence_path + if evidence_path.exists(): + return True + return False + + +def manifest_artifact_path( + manifest: dict[str, Any], + artifact_root: Path, + artifact_key: str, +) -> Path | None: + """Return the resolved path for a top-level manifest artifact.""" + + artifacts = dict(manifest.get("artifacts", {})) + filename = artifacts.get(artifact_key) + if not filename: + return None + path = Path(str(filename)) + if not path.is_absolute(): + path = artifact_root / path + return path + + +__all__ = [ + "artifact_exists", + "artifact_missing", + "manifest_artifact_exists", + "manifest_artifact_missing", + "manifest_artifact_path", + "required_artifacts_exist", + "stage_status", + "validation_evidence_index_has_existing_evidence", +] diff --git a/src/microplex_us/pipelines/stage_validation_evidence.py b/src/microplex_us/pipelines/stage_validation_evidence.py new file mode 100644 index 0000000..4b90e7b --- /dev/null +++ b/src/microplex_us/pipelines/stage_validation_evidence.py @@ -0,0 +1,136 @@ +"""Validation and benchmarking evidence manifests for US saved runs.""" + +from __future__ import annotations + +import json +from collections.abc import Mapping +from pathlib import Path +from typing import Any + +from microplex_us.pipelines.stage_manifest_io import write_json_atomically +from microplex_us.pipelines.stage_manifest_types import ( + US_VALIDATION_STAGE_ID, + USValidationEvidenceManifest, + USValidationEvidenceRecord, 
+) + + +def build_us_validation_evidence_manifest( + artifact_dir: str | Path, + *, + manifest_payload: dict[str, Any], +) -> USValidationEvidenceManifest: + """Build a compact Stage 9 evidence index from a saved artifact manifest.""" + + artifact_root = Path(artifact_dir) + artifacts = dict(manifest_payload.get("artifacts", {})) + existing = _load_existing_validation_evidence_manifest(artifact_root, artifacts) + evidence_keys = ( + "policyengine_harness", + "policyengine_native_scores", + "policyengine_native_audit", + "imputation_ablation", + "child_tax_unit_agi_drift", + ) + evidence_by_key: dict[str, USValidationEvidenceRecord] = {} + if existing is not None: + for record in existing.get("evidence", ()): + if not isinstance(record, Mapping) or not record.get("key"): + continue + key = str(record["key"]) + evidence_by_key[key] = _validation_evidence_record( + artifact_root, + key, + record.get("path"), + ) + for key in evidence_keys: + filename = artifacts.get(key) + if not filename: + continue + evidence_by_key[key] = _validation_evidence_record( + artifact_root, + key, + filename, + ) + summaries: dict[str, Any] = {} + if existing is not None and isinstance(existing.get("summaries"), Mapping): + summaries.update(dict(existing["summaries"])) + summaries.update( + { + key: manifest_payload[key] + for key in ( + "policyengine_harness", + "policyengine_native_scores", + "policyengine_native_audit", + "imputation_ablation", + ) + if isinstance(manifest_payload.get(key), dict) + } + ) + return { + "formatVersion": 1, + "stageId": US_VALIDATION_STAGE_ID, + "evidence": list(evidence_by_key.values()), + "summaries": summaries, + } + + +def write_us_validation_evidence_manifest( + artifact_dir: str | Path, + output_path: str | Path, + *, + manifest_payload: dict[str, Any], +) -> Path: + """Write a Stage 9 evidence manifest for validation/benchmark sidecars.""" + + destination = Path(output_path) + destination.parent.mkdir(parents=True, exist_ok=True) + 
write_json_atomically( + destination, + build_us_validation_evidence_manifest( + artifact_dir, + manifest_payload=manifest_payload, + ), + ) + return destination + + +def _load_existing_validation_evidence_manifest( + artifact_root: Path, + artifacts: Mapping[str, Any], +) -> Mapping[str, Any] | None: + evidence_name = artifacts.get("validation_evidence") + if not evidence_name: + return None + path = Path(str(evidence_name)) + if not path.is_absolute(): + path = artifact_root / path + if not path.exists(): + return None + try: + payload = json.loads(path.read_text()) + except (OSError, json.JSONDecodeError): + return None + return payload if isinstance(payload, Mapping) else None + + +def _validation_evidence_record( + artifact_root: Path, + key: str, + path_value: Any, +) -> USValidationEvidenceRecord: + path_text = str(path_value) if path_value else "" + path = Path(path_text) + if path_text and not path.is_absolute(): + path = artifact_root / path + return { + "key": key, + "path": path_text, + "exists": bool(path_text) and path.exists(), + } + + +__all__ = [ + "build_us_validation_evidence_manifest", + "write_us_validation_evidence_manifest", +] diff --git a/src/microplex_us/pipelines/summarize_pe_native_family_drilldown.py b/src/microplex_us/pipelines/summarize_pe_native_family_drilldown.py index 906eae2..e398f9d 100644 --- a/src/microplex_us/pipelines/summarize_pe_native_family_drilldown.py +++ b/src/microplex_us/pipelines/summarize_pe_native_family_drilldown.py @@ -8,6 +8,11 @@ from pathlib import Path from typing import Any +from microplex_us.pipelines.stage_contracts import ( + get_us_stage_artifact_contract, + resolve_us_stage_artifact_contract_path, +) + def classify_pe_native_target_family(target_name: str) -> str: """Classify one PE target name into the broad-loss family buckets.""" @@ -77,7 +82,13 @@ def summarize_us_pe_native_family_drilldown( root_key = artifact_root.name for bundle_dir in _iter_native_audit_bundle_dirs(artifact_root): total_audits 
+= 1 - payload = json.loads((bundle_dir / "pe_us_data_rebuild_native_audit.json").read_text()) + payload = json.loads( + resolve_us_stage_artifact_contract_path( + bundle_dir, + "09_validation_benchmarking", + "policyengine_native_audit", + ).read_text() + ) verdict_hints = dict(payload.get("verdictHints", {})) support_summary = dict(payload.get("supportAuditSummary", {})) matching_targets = [ @@ -190,11 +201,21 @@ def summarize_us_pe_native_family_drilldown( def _iter_native_audit_bundle_dirs(artifact_root: Path) -> tuple[Path, ...]: + audit_hint = get_us_stage_artifact_contract( + "09_validation_benchmarking", + "policyengine_native_audit", + ).path_hint + dataset_hint = get_us_stage_artifact_contract( + "08_dataset_assembly", + "policyengine_dataset", + ).path_hint + if audit_hint is None or dataset_hint is None: + return () return tuple( sorted( path.parent - for path in artifact_root.rglob("pe_us_data_rebuild_native_audit.json") - if (path.parent / "policyengine_us.h5").exists() + for path in artifact_root.rglob(audit_hint) + if (path.parent / dataset_hint).exists() ) ) diff --git a/src/microplex_us/pipelines/summarize_pe_native_regressions.py b/src/microplex_us/pipelines/summarize_pe_native_regressions.py index 3577f40..9ce8117 100644 --- a/src/microplex_us/pipelines/summarize_pe_native_regressions.py +++ b/src/microplex_us/pipelines/summarize_pe_native_regressions.py @@ -8,6 +8,11 @@ from pathlib import Path from typing import Any +from microplex_us.pipelines.stage_contracts import ( + get_us_stage_artifact_contract, + resolve_us_stage_artifact_contract_path, +) + def _sorted_counter_items(counter: Counter[str]) -> list[tuple[str, int]]: return sorted(counter.items(), key=lambda item: (-int(item[1]), item[0])) @@ -32,7 +37,13 @@ def summarize_us_pe_native_regressions( for artifact_root in normalized_roots: root_key = artifact_root.name for bundle_dir in _iter_scored_bundle_dirs(artifact_root): - scores_payload = json.loads((bundle_dir / 
"policyengine_native_scores.json").read_text()) + scores_payload = json.loads( + resolve_us_stage_artifact_contract_path( + bundle_dir, + "09_validation_benchmarking", + "policyengine_native_scores", + ).read_text() + ) summary = dict(scores_payload.get("summary", {})) positive_families = [ row @@ -47,7 +58,11 @@ def summarize_us_pe_native_regressions( largest_family = positive_families[0] if positive_families else {} top3_families = [row.get("family") for row in positive_families[:3]] - audit_path = bundle_dir / "pe_us_data_rebuild_native_audit.json" + audit_path = resolve_us_stage_artifact_contract_path( + bundle_dir, + "09_validation_benchmarking", + "policyengine_native_audit", + ) audit_payload = json.loads(audit_path.read_text()) if audit_path.exists() else None verdict_hints = dict((audit_payload or {}).get("verdictHints", {})) support_summary = dict((audit_payload or {}).get("supportAuditSummary", {})) @@ -142,11 +157,21 @@ def summarize_us_pe_native_regressions( def _iter_scored_bundle_dirs(artifact_root: Path) -> tuple[Path, ...]: + scores_hint = get_us_stage_artifact_contract( + "09_validation_benchmarking", + "policyengine_native_scores", + ).path_hint + dataset_hint = get_us_stage_artifact_contract( + "08_dataset_assembly", + "policyengine_dataset", + ).path_hint + if scores_hint is None or dataset_hint is None: + return () return tuple( sorted( path.parent - for path in artifact_root.rglob("policyengine_native_scores.json") - if (path.parent / "policyengine_us.h5").exists() + for path in artifact_root.rglob(scores_hint) + if (path.parent / dataset_hint).exists() ) ) diff --git a/src/microplex_us/pipelines/summarize_policyengine_oracle_regressions.py b/src/microplex_us/pipelines/summarize_policyengine_oracle_regressions.py index 0a07ba2..9893be2 100644 --- a/src/microplex_us/pipelines/summarize_policyengine_oracle_regressions.py +++ b/src/microplex_us/pipelines/summarize_policyengine_oracle_regressions.py @@ -8,6 +8,8 @@ from pathlib import Path from 
typing import Any +from microplex_us.pipelines.stage_contracts import get_us_stage_artifact_contract + def _sorted_counter_items(counter: Counter[str]) -> list[tuple[str, int]]: return sorted(counter.items(), key=lambda item: (-int(item[1]), item[0])) @@ -173,11 +175,17 @@ def summarize_us_policyengine_oracle_regressions( def _iter_oracle_bundle_dirs(artifact_root: Path) -> tuple[Path, ...]: + dataset_hint = get_us_stage_artifact_contract( + "08_dataset_assembly", + "policyengine_dataset", + ).path_hint + if dataset_hint is None: + return () return tuple( sorted( path.parent for path in artifact_root.rglob("manifest.json") - if (path.parent / "policyengine_us.h5").exists() + if (path.parent / dataset_hint).exists() ) ) diff --git a/src/microplex_us/pipelines/summarize_policyengine_oracle_target_drilldown.py b/src/microplex_us/pipelines/summarize_policyengine_oracle_target_drilldown.py index bf271d9..e554329 100644 --- a/src/microplex_us/pipelines/summarize_policyengine_oracle_target_drilldown.py +++ b/src/microplex_us/pipelines/summarize_policyengine_oracle_target_drilldown.py @@ -8,6 +8,9 @@ from pathlib import Path from typing import Any +from microplex_us.pipelines.stage_contracts import ( + resolve_us_stage_artifact_contract_path, +) from microplex_us.pipelines.us import ( USMicroplexBuildConfig, USMicroplexPipeline, @@ -40,11 +43,16 @@ def summarize_us_policyengine_oracle_target_drilldown( if config.policyengine_targets_db is None: raise ValueError("Artifact config does not define policyengine_targets_db") - dataset_name = dict(manifest.get("artifacts", {})).get( - "policyengine_dataset", - "policyengine_us.h5", + dataset_name = dict(manifest.get("artifacts", {})).get("policyengine_dataset") + dataset_path = ( + _resolve_manifest_artifact_path(bundle_dir, str(dataset_name)) + if dataset_name is not None + else resolve_us_stage_artifact_contract_path( + bundle_dir, + "08_dataset_assembly", + "policyengine_dataset", + ) ) - dataset_path = (bundle_dir / 
dataset_name).resolve() if not dataset_path.exists(): raise FileNotFoundError(f"PolicyEngine dataset not found: {dataset_path}") @@ -206,6 +214,13 @@ def summarize_us_policyengine_oracle_target_drilldown( } +def _resolve_manifest_artifact_path(bundle_dir: Path, artifact_name: str) -> Path: + artifact_path = Path(artifact_name) + if artifact_path.is_absolute(): + return artifact_path + return (bundle_dir / artifact_path).resolve() + + def _oracle_target_row( *, target: Any, diff --git a/tests/pipelines/test_artifacts.py b/tests/pipelines/test_artifacts.py index 9bfe193..b5241ba 100644 --- a/tests/pipelines/test_artifacts.py +++ b/tests/pipelines/test_artifacts.py @@ -302,6 +302,10 @@ def test_writes_expected_files(self, tmp_path): assert paths.policyengine_dataset.exists() assert paths.stage_manifest is not None assert paths.stage_manifest.exists() + assert paths.artifact_inventory is not None + assert paths.artifact_inventory.exists() + assert paths.conditional_readiness is not None + assert paths.conditional_readiness.exists() assert paths.source_plan is not None assert paths.source_plan.exists() assert paths.policyengine_entity_tables is not None @@ -323,6 +327,14 @@ def test_writes_expected_files(self, tmp_path): ) assert manifest["artifacts"]["policyengine_dataset"] == "policyengine_us.h5" assert manifest["artifacts"]["stage_manifest"] == "stage_manifest.json" + assert ( + manifest["artifacts"]["artifact_inventory"] + == "stage_artifacts/artifact_inventory.json" + ) + assert ( + manifest["artifacts"]["conditional_readiness"] + == "stage_artifacts/conditional_readiness.json" + ) assert ( manifest["artifacts"]["policyengine_entity_tables"] == "stage_artifacts/06_policyengine_entities/metadata.json" @@ -332,6 +344,23 @@ def test_writes_expected_files(self, tmp_path): == "source_weight_diagnostics.json" ) source_diagnostics = json.loads(paths.source_weight_diagnostics.read_text()) + artifact_inventory = json.loads(paths.artifact_inventory.read_text()) + 
conditional_readiness = json.loads(paths.conditional_readiness.read_text()) + inventory_records = { + (record["stageId"], record["key"]): record + for record in artifact_inventory["artifacts"] + } + assert inventory_records[("01_run_profile", "manifest")]["exists"] is True + assert inventory_records[ + ("08_dataset_assembly", "policyengine_dataset") + ]["classification"] == "post_artifact_evidence" + readiness = { + stage["stageId"]: stage + for stage in conditional_readiness["stages"] + } + assert readiness["09_validation_benchmarking"]["readiness"] == ( + "post_artifact_evidence" + ) assert ( source_diagnostics["summary"]["diagnostic_scope"] == "saved_artifact_entity_weight_by_source_rows" diff --git a/tests/pipelines/test_backfill_pe_native_scores.py b/tests/pipelines/test_backfill_pe_native_scores.py index 58163f7..a81c0a3 100644 --- a/tests/pipelines/test_backfill_pe_native_scores.py +++ b/tests/pipelines/test_backfill_pe_native_scores.py @@ -75,15 +75,37 @@ def test_backfill_us_pe_native_scores_root_updates_manifest_and_registry( assert sidecar_path.exists() updated_manifest = json.loads(manifest_path.read_text()) + stage_manifest = json.loads((bundle_dir / "stage_manifest.json").read_text()) + validation_evidence = json.loads( + ( + bundle_dir + / "stage_artifacts" + / "09_validation_benchmarking" + / "evidence_manifest.json" + ).read_text() + ) assert ( updated_manifest["artifacts"]["policyengine_native_scores"] == "policyengine_native_scores.json" ) + assert updated_manifest["artifacts"]["stage_manifest"] == "stage_manifest.json" + assert ( + updated_manifest["artifacts"]["validation_evidence"] + == "stage_artifacts/09_validation_benchmarking/evidence_manifest.json" + ) assert updated_manifest["policyengine_native_scores"]["candidate_beats_baseline"] is True assert ( updated_manifest["run_registry"]["default_frontier_metric"] == "enhanced_cps_native_loss_delta" ) + stage9 = next( + stage + for stage in stage_manifest["stages"] + if stage["id"] == 
"09_validation_benchmarking" + ) + assert stage9["status"] == "ready" + assert validation_evidence["evidence"][0]["key"] == "policyengine_native_scores" + assert validation_evidence["evidence"][0]["exists"] is True registry_path = artifact_root / "run_registry.jsonl" assert registry_path.exists() diff --git a/tests/pipelines/test_data_flow_snapshot.py b/tests/pipelines/test_data_flow_snapshot.py index 9b07541..79f7275 100644 --- a/tests/pipelines/test_data_flow_snapshot.py +++ b/tests/pipelines/test_data_flow_snapshot.py @@ -13,6 +13,27 @@ def test_build_us_microplex_data_flow_snapshot_reads_manifest_runtime_mix(tmp_pa artifact_dir.mkdir() (artifact_dir / "policyengine_us.h5").write_text("dataset") (artifact_dir / "policyengine_harness.json").write_text("{}") + evidence_path = ( + artifact_dir + / "stage_artifacts" + / "09_validation_benchmarking" + / "evidence_manifest.json" + ) + evidence_path.parent.mkdir(parents=True) + evidence_path.write_text( + json.dumps( + { + "schemaVersion": 1, + "evidence": [ + { + "key": "policyengine_harness", + "path": "policyengine_harness.json", + "exists": True, + } + ], + } + ) + ) (artifact_dir / "manifest.json").write_text( json.dumps( { @@ -52,6 +73,10 @@ def test_build_us_microplex_data_flow_snapshot_reads_manifest_runtime_mix(tmp_pa "artifacts": { "policyengine_dataset": "policyengine_us.h5", "policyengine_harness": "policyengine_harness.json", + "validation_evidence": ( + "stage_artifacts/09_validation_benchmarking/" + "evidence_manifest.json" + ), }, "policyengine_harness": { "mean_abs_relative_error_delta": -0.2, @@ -220,7 +245,11 @@ def test_write_us_microplex_data_flow_snapshot_ignores_stale_stage_manifest( artifact_dir, artifact_dir / "data_flow_snapshot.json", manifest_payload=manifest, - assume_existing_stage_artifact_keys=("stage_manifest",), + assume_existing_stage_artifact_keys=( + "stage_manifest", + "artifact_inventory", + "conditional_readiness", + ), ) snapshot = json.loads((artifact_dir / 
"data_flow_snapshot.json").read_text()) diff --git a/tests/pipelines/test_experiments.py b/tests/pipelines/test_experiments.py index ffa45c9..40e4f47 100644 --- a/tests/pipelines/test_experiments.py +++ b/tests/pipelines/test_experiments.py @@ -39,6 +39,10 @@ def _artifact_paths(root: Path, name: str) -> USMicroplexArtifactPaths: synthesizer=None, policyengine_dataset=output_dir / "policyengine.h5", data_flow_snapshot=output_dir / "data_flow_snapshot.json", + artifact_inventory=output_dir / "stage_artifacts" / "artifact_inventory.json", + conditional_readiness=( + output_dir / "stage_artifacts" / "conditional_readiness.json" + ), policyengine_harness=output_dir / "policyengine_harness.json", policyengine_native_scores=output_dir / "policyengine_native_scores.json", policyengine_native_audit=output_dir / "pe_us_data_rebuild_native_audit.json", @@ -153,6 +157,8 @@ def fake_build_and_save( assert loaded.leaderboard[0].current_entry is not None assert loaded.leaderboard[0].current_entry.candidate_composite_parity_loss == 0.35 assert loaded.leaderboard[0].artifact_paths.data_flow_snapshot is not None + assert loaded.leaderboard[0].artifact_paths.artifact_inventory is not None + assert loaded.leaderboard[0].artifact_paths.conditional_readiness is not None assert loaded.leaderboard[0].artifact_paths.policyengine_native_scores is not None assert loaded.leaderboard[0].artifact_paths.policyengine_native_audit is not None assert loaded.leaderboard[0].artifact_paths.run_index_db is not None @@ -507,6 +513,10 @@ def test_refresh_experiment_results_from_registry_refreshes_backfilled_artifact_ { "artifacts": { "data_flow_snapshot": "data_flow_snapshot.json", + "artifact_inventory": "stage_artifacts/artifact_inventory.json", + "conditional_readiness": ( + "stage_artifacts/conditional_readiness.json" + ), "policyengine_native_scores": "policyengine_native_scores.json", "policyengine_native_audit": "pe_us_data_rebuild_native_audit.json", } @@ -519,6 +529,9 @@ def 
test_refresh_experiment_results_from_registry_refreshes_backfilled_artifact_ "pe_us_data_rebuild_native_audit.json", ): (output_dir / name).write_text("{}") + (output_dir / "stage_artifacts").mkdir() + for name in ("artifact_inventory.json", "conditional_readiness.json"): + (output_dir / "stage_artifacts" / name).write_text("{}") registry_path = tmp_path / "run_registry.jsonl" result = USMicroplexExperimentResult( name="cps-only", @@ -555,6 +568,12 @@ def test_refresh_experiment_results_from_registry_refreshes_backfilled_artifact_ ) assert loaded[0].artifact_paths.data_flow_snapshot == output_dir / "data_flow_snapshot.json" + assert loaded[0].artifact_paths.artifact_inventory == ( + output_dir / "stage_artifacts" / "artifact_inventory.json" + ) + assert loaded[0].artifact_paths.conditional_readiness == ( + output_dir / "stage_artifacts" / "conditional_readiness.json" + ) assert ( loaded[0].artifact_paths.policyengine_native_scores == output_dir / "policyengine_native_scores.json" diff --git a/tests/pipelines/test_pe_us_data_rebuild_checkpoint.py b/tests/pipelines/test_pe_us_data_rebuild_checkpoint.py index 1c6a321..6557317 100644 --- a/tests/pipelines/test_pe_us_data_rebuild_checkpoint.py +++ b/tests/pipelines/test_pe_us_data_rebuild_checkpoint.py @@ -315,6 +315,8 @@ def fake_build_and_save_versioned_us_microplex_from_source_providers( run_index_path, run_registry_metadata, enable_child_tax_unit_agi_drift, + allow_stage_input_overrides, + stage_input_overrides, ): captured.update( { @@ -330,6 +332,8 @@ def fake_build_and_save_versioned_us_microplex_from_source_providers( "defer_policyengine_harness": defer_policyengine_harness, "defer_policyengine_native_score": defer_policyengine_native_score, "enable_child_tax_unit_agi_drift": enable_child_tax_unit_agi_drift, + "allow_stage_input_overrides": allow_stage_input_overrides, + "stage_input_overrides": stage_input_overrides, } ) manifest = { @@ -564,6 +568,8 @@ def 
fake_attach_policyengine_us_data_rebuild_checkpoint_evidence( assert captured["defer_policyengine_harness"] is True assert captured["defer_policyengine_native_score"] is True assert captured["enable_child_tax_unit_agi_drift"] is True + assert captured["allow_stage_input_overrides"] is False + assert captured["stage_input_overrides"] == () assert captured["policyengine_harness_metadata"]["rebuild_checkpoint"] is True assert captured["policyengine_harness_metadata"]["rebuild_program_id"] == ( "pe-us-data-rebuild-v1" @@ -597,7 +603,10 @@ def fake_attach_policyengine_us_data_rebuild_checkpoint_evidence( assert result.artifacts.frontier_entry is not None assert result.artifacts.frontier_entry.artifact_id == "run-1" assert result.artifacts.frontier_delta == 0.0 - assert result.native_audit_path == artifact_dir / "pe_us_data_rebuild_native_audit.json" + assert ( + result.native_audit_path + == artifact_dir / "pe_us_data_rebuild_native_audit.json" + ) assert result.native_audit_payload == { "verdictHints": {"largestRegressingFamily": None} } @@ -659,8 +668,7 @@ def info(self, message: str) -> None: "[version_id=run-1, providers=fake_source]" ] assert ( - stderr - == "PE-US-data rebuild checkpoint: starting build " + stderr == "PE-US-data rebuild checkpoint: starting build " "[version_id=run-1, providers=fake_source]\n" ) @@ -992,7 +1000,10 @@ def test_attach_policyengine_us_data_rebuild_checkpoint_evidence_updates_manifes registry_entries = load_us_microplex_run_registry(tmp_path / "run_registry.jsonl") assert result.harness_path == artifact_dir / "policyengine_harness.json" assert result.native_scores_path == artifact_dir / "policyengine_native_scores.json" - assert result.native_audit_path == artifact_dir / "pe_us_data_rebuild_native_audit.json" + assert ( + result.native_audit_path + == artifact_dir / "pe_us_data_rebuild_native_audit.json" + ) assert result.native_audit_payload == native_audit_payload assert result.imputation_ablation_path == artifact_dir / 
"imputation_ablation.json" written_native_audit = json.loads( @@ -1034,7 +1045,10 @@ def test_attach_policyengine_us_data_rebuild_checkpoint_evidence_updates_manifes ] is True ) - assert written_native_audit["verdictHints"]["productionImputationVariantIsMaeWinner"] is True + assert ( + written_native_audit["verdictHints"]["productionImputationVariantIsMaeWinner"] + is True + ) assert written_manifest["run_registry"]["artifact_id"] == "artifact" assert written_manifest["run_index"]["artifact_id"] == "artifact" assert (tmp_path / "run_index.duckdb").exists() @@ -1145,6 +1159,96 @@ def test_attach_policyengine_us_data_rebuild_checkpoint_evidence_registers_calib assert registry_entries[0].full_oracle_mean_abs_relative_error == 0.12 +def test_load_checkpoint_versioned_artifacts_hydrates_stage_sidecar_paths( + tmp_path, +) -> None: + artifact_dir = tmp_path / "artifact" + artifact_dir.mkdir() + stage_artifacts = artifact_dir / "stage_artifacts" + for path in ( + artifact_dir / "seed_data.parquet", + artifact_dir / "synthetic_data.parquet", + artifact_dir / "calibrated_data.parquet", + artifact_dir / "targets.json", + artifact_dir / "policyengine_us.h5", + artifact_dir / "stage_manifest.json", + artifact_dir / "data_flow_snapshot.json", + stage_artifacts / "03_source_planning" / "source_plan.json", + stage_artifacts / "04_seed_scaffold" / "scaffold_seed_data.parquet", + stage_artifacts / "06_policyengine_entities" / "metadata.json", + stage_artifacts / "07_calibration" / "calibration_summary.json", + stage_artifacts / "09_validation_benchmarking" / "evidence_manifest.json", + stage_artifacts / "artifact_inventory.json", + stage_artifacts / "conditional_readiness.json", + artifact_dir / "policyengine_native_scores.json", + artifact_dir / "source_weight_diagnostics.json", + ): + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text("{}") + manifest = { + "artifacts": { + "seed_data": "seed_data.parquet", + "synthetic_data": "synthetic_data.parquet", + 
"calibrated_data": "calibrated_data.parquet", + "targets": "targets.json", + "policyengine_dataset": "policyengine_us.h5", + "stage_manifest": "stage_manifest.json", + "data_flow_snapshot": "data_flow_snapshot.json", + "source_plan": "stage_artifacts/03_source_planning/source_plan.json", + "scaffold_seed_data": ( + "stage_artifacts/04_seed_scaffold/scaffold_seed_data.parquet" + ), + "policyengine_entity_tables": ( + "stage_artifacts/06_policyengine_entities/metadata.json" + ), + "calibration_summary": ( + "stage_artifacts/07_calibration/calibration_summary.json" + ), + "validation_evidence": ( + "stage_artifacts/09_validation_benchmarking/evidence_manifest.json" + ), + "artifact_inventory": "stage_artifacts/artifact_inventory.json", + "conditional_readiness": "stage_artifacts/conditional_readiness.json", + "policyengine_native_scores": "policyengine_native_scores.json", + "source_weight_diagnostics": "source_weight_diagnostics.json", + } + } + (artifact_dir / "manifest.json").write_text(json.dumps(manifest)) + + loaded = checkpoint_module._load_checkpoint_versioned_artifacts( + build_result=SimpleNamespace(), + artifact_root=artifact_dir, + frontier_metric="full_oracle_mean_abs_relative_error", + ) + paths = loaded.artifact_paths + + assert paths.stage_manifest == artifact_dir / "stage_manifest.json" + assert paths.data_flow_snapshot == artifact_dir / "data_flow_snapshot.json" + assert paths.artifact_inventory == stage_artifacts / "artifact_inventory.json" + assert paths.conditional_readiness == stage_artifacts / "conditional_readiness.json" + assert ( + paths.source_plan == stage_artifacts / "03_source_planning" / "source_plan.json" + ) + assert paths.scaffold_seed_data == ( + stage_artifacts / "04_seed_scaffold" / "scaffold_seed_data.parquet" + ) + assert paths.policyengine_entity_tables == ( + stage_artifacts / "06_policyengine_entities" / "metadata.json" + ) + assert paths.calibration_summary == ( + stage_artifacts / "07_calibration" / 
"calibration_summary.json" + ) + assert paths.validation_evidence == ( + stage_artifacts / "09_validation_benchmarking" / "evidence_manifest.json" + ) + assert paths.policyengine_native_scores == ( + artifact_dir / "policyengine_native_scores.json" + ) + assert paths.source_weight_diagnostics == ( + artifact_dir / "source_weight_diagnostics.json" + ) + + def test_attach_policyengine_us_data_rebuild_checkpoint_evidence_computes_imputation_ablation_with_build_result( monkeypatch, tmp_path, @@ -1292,7 +1396,10 @@ def fake_build_checkpoint_imputation_ablation_payload( ) assert result.imputation_ablation_payload == imputation_ablation_payload assert result.native_audit_payload == native_audit_payload - assert result.native_audit_path == artifact_dir / "pe_us_data_rebuild_native_audit.json" + assert ( + result.native_audit_path + == artifact_dir / "pe_us_data_rebuild_native_audit.json" + ) assert result.imputation_ablation_path == artifact_dir / "imputation_ablation.json" assert ( written_manifest["artifacts"]["policyengine_native_audit"] diff --git a/tests/pipelines/test_stage_artifacts.py b/tests/pipelines/test_stage_artifacts.py new file mode 100644 index 0000000..567243b --- /dev/null +++ b/tests/pipelines/test_stage_artifacts.py @@ -0,0 +1,384 @@ +"""Tests for US stage artifact inventory helpers.""" + +import json + +import pandas as pd +import pytest + +from microplex_us.pipelines.stage_artifacts import ( + build_us_stage_artifact_inventory, + load_us_calibrated_stage_artifacts, + load_us_candidate_calibration_replay_artifacts, + load_us_candidate_stage_artifacts, + load_us_dataset_assembly_artifacts, + load_us_policyengine_entity_stage_artifacts, + load_us_seed_scaffold_stage_artifacts, + load_us_stage_artifact_inventory, + load_us_stage_json_artifact, + resolve_us_stage_artifact_from_inventory, + resolve_us_stage_artifact_path_checked, + write_us_stage_artifact_inventory, +) +from microplex_us.pipelines.stage_manifest import ( + 
write_us_policyengine_entity_stage_artifact, +) +from microplex_us.policyengine import PolicyEngineUSEntityTableBundle + + +def test_build_us_stage_artifact_inventory_hashes_files_and_directories(tmp_path): + (tmp_path / "seed_data.parquet").write_text("seed") + (tmp_path / "synthetic_data.parquet").write_text("synthetic") + source_plan = tmp_path / "stage_artifacts" / "03_source_planning" / "source_plan.json" + source_plan.parent.mkdir(parents=True) + source_plan.write_text("{}") + entity_dir = tmp_path / "stage_artifacts" / "06_policyengine_entities" + entity_dir.mkdir(parents=True) + (entity_dir / "metadata.json").write_text("{}") + (entity_dir / "households.parquet").write_text("households") + manifest = { + "config": {"calibration_backend": "none"}, + "rows": {"seed": 1, "synthetic": 1}, + "synthesis": {"source_names": ["source"], "scaffold_source": "source"}, + "calibration": {}, + "artifacts": { + "seed_data": "seed_data.parquet", + "synthetic_data": "synthetic_data.parquet", + "source_plan": "stage_artifacts/03_source_planning/source_plan.json", + "policyengine_entity_tables": ( + "stage_artifacts/06_policyengine_entities/metadata.json" + ), + }, + } + + inventory = build_us_stage_artifact_inventory( + tmp_path, + manifest_payload=manifest, + max_hash_bytes=None, + ) + + records = { + (record["stageId"], record["key"]): record + for record in inventory["artifacts"] + } + assert records[("05_donor_integration_synthesis", "synthetic_data")][ + "classification" + ] == "manual_replay" + assert records[("05_donor_integration_synthesis", "synthetic_data")][ + "hashStatus" + ] == "hashed" + assert records[("05_donor_integration_synthesis", "synthetic_data")][ + "contentHash" + ] + assert records[("03_source_planning", "source_plan")]["classification"] == ( + "diagnostic_only" + ) + entity_record = records[("06_policyengine_entities", "policyengine_entity_tables")] + assert entity_record["classification"] == "manual_resume" + assert entity_record["fileCount"] == 2 
+ assert entity_record["hashStatus"] == "hashed" + + +def test_build_us_stage_artifact_inventory_classifies_missing_and_contract_only( + tmp_path, +): + manifest = { + "config": {"calibration_backend": "none"}, + "rows": {"seed": 1, "synthetic": 1}, + "synthesis": {"source_names": ["source"], "scaffold_source": "source"}, + "calibration": {}, + "artifacts": { + "seed_data": "seed_data.parquet", + "synthetic_data": "synthetic_data.parquet", + }, + } + + inventory = build_us_stage_artifact_inventory(tmp_path, manifest_payload=manifest) + + records = { + (record["stageId"], record["key"]): record + for record in inventory["artifacts"] + } + assert records[("05_donor_integration_synthesis", "synthetic_data")][ + "classification" + ] == "missing_required" + assert records[("05_donor_integration_synthesis", "synthesizer")][ + "classification" + ] == "contract_only" + + +def test_build_us_stage_artifact_inventory_skips_large_file_hashes(tmp_path): + (tmp_path / "synthetic_data.parquet").write_text("synthetic") + manifest = { + "config": {"calibration_backend": "none"}, + "rows": {"synthetic": 1}, + "synthesis": {"source_names": ["source"], "scaffold_source": "source"}, + "calibration": {}, + "artifacts": {"synthetic_data": "synthetic_data.parquet"}, + } + + inventory = build_us_stage_artifact_inventory( + tmp_path, + manifest_payload=manifest, + max_hash_bytes=3, + ) + + record = next( + record + for record in inventory["artifacts"] + if record["key"] == "synthetic_data" + ) + assert record["hashStatus"] == "too_large" + assert record["contentHash"] is None + + +def test_write_load_and_resolve_us_stage_artifact_inventory(tmp_path): + (tmp_path / "policyengine_us.h5").write_text("dataset") + manifest = { + "config": {"calibration_backend": "none"}, + "synthesis": {"source_names": ["source"], "scaffold_source": "source"}, + "calibration": {}, + "artifacts": {"policyengine_dataset": "policyengine_us.h5"}, + } + + path = write_us_stage_artifact_inventory( + tmp_path, + 
tmp_path / "stage_artifacts" / "artifact_inventory.json", + manifest_payload=manifest, + ) + loaded = load_us_stage_artifact_inventory(path) + dataset_path = resolve_us_stage_artifact_from_inventory( + tmp_path, + loaded, + "08_dataset_assembly", + "policyengine_dataset", + ) + + assert loaded["schemaVersion"] == 1 + assert dataset_path == tmp_path / "policyengine_us.h5" + + +def test_load_us_stage_artifact_inventory_rejects_unknown_schema(tmp_path): + path = tmp_path / "artifact_inventory.json" + path.write_text(json.dumps({"schemaVersion": 99})) + + with pytest.raises(RuntimeError, match="Unsupported US stage artifact inventory"): + load_us_stage_artifact_inventory(path) + + +def test_load_us_candidate_stage_artifacts_reads_stage5_boundary(tmp_path): + pytest.importorskip("pyarrow") + seed = pd.DataFrame({"person_id": [1], "income": [20]}) + synthetic = pd.DataFrame({"person_id": [1, 2], "income": [20, 30]}) + seed.to_parquet(tmp_path / "seed_data.parquet", index=False) + synthetic.to_parquet(tmp_path / "synthetic_data.parquet", index=False) + manifest = { + "config": {"calibration_backend": "none"}, + "rows": {"seed": 1, "synthetic": 2}, + "synthesis": {"source_names": ["source"], "scaffold_source": "source"}, + "calibration": {}, + "artifacts": { + "seed_data": "seed_data.parquet", + "synthetic_data": "synthetic_data.parquet", + }, + } + + loaded = load_us_candidate_stage_artifacts(tmp_path, manifest_payload=manifest) + + pd.testing.assert_frame_equal(loaded.seed_data, seed) + pd.testing.assert_frame_equal(loaded.synthetic_data, synthetic) + assert loaded.artifact_paths["synthetic_data"] == tmp_path / "synthetic_data.parquet" + + +def test_load_us_seed_scaffold_stage_artifacts_reads_stage4_boundary(tmp_path): + pytest.importorskip("pyarrow") + scaffold = pd.DataFrame({"person_id": [1], "income": [10]}) + scaffold_path = ( + tmp_path / "stage_artifacts" / "04_seed_scaffold" / "scaffold_seed_data.parquet" + ) + scaffold_path.parent.mkdir(parents=True) + 
scaffold.to_parquet(scaffold_path, index=False) + manifest = { + "config": {"calibration_backend": "none"}, + "synthesis": {"source_names": ["source"], "scaffold_source": "source"}, + "calibration": {}, + "artifacts": { + "scaffold_seed_data": ( + "stage_artifacts/04_seed_scaffold/scaffold_seed_data.parquet" + ), + }, + } + + loaded = load_us_seed_scaffold_stage_artifacts(tmp_path, manifest_payload=manifest) + + pd.testing.assert_frame_equal(loaded.scaffold_seed_data, scaffold) + assert loaded.artifact_paths["scaffold_seed_data"] == scaffold_path + + +def test_load_us_candidate_calibration_replay_artifacts_combines_boundaries( + tmp_path, +): + pytest.importorskip("pyarrow") + scaffold = pd.DataFrame({"person_id": [1], "income": [10]}) + seed = pd.DataFrame({"person_id": [1], "income": [20]}) + synthetic = pd.DataFrame({"person_id": [1, 2], "income": [20, 30]}) + scaffold_path = ( + tmp_path / "stage_artifacts" / "04_seed_scaffold" / "scaffold_seed_data.parquet" + ) + scaffold_path.parent.mkdir(parents=True) + scaffold.to_parquet(scaffold_path, index=False) + seed.to_parquet(tmp_path / "seed_data.parquet", index=False) + synthetic.to_parquet(tmp_path / "synthetic_data.parquet", index=False) + (tmp_path / "targets.json").write_text( + json.dumps({"marginal": {"age": {"20": 1.0}}, "continuous": {"income": 1.0}}) + ) + manifest = { + "config": {"calibration_backend": "none"}, + "rows": {"seed": 1, "synthetic": 2}, + "synthesis": {"source_names": ["source"], "scaffold_source": "source"}, + "calibration": {}, + "artifacts": { + "scaffold_seed_data": ( + "stage_artifacts/04_seed_scaffold/scaffold_seed_data.parquet" + ), + "seed_data": "seed_data.parquet", + "synthetic_data": "synthetic_data.parquet", + "targets": "targets.json", + }, + } + + loaded = load_us_candidate_calibration_replay_artifacts( + tmp_path, + manifest_payload=manifest, + ) + + pd.testing.assert_frame_equal(loaded.candidate.synthetic_data, synthetic) + assert loaded.seed_scaffold is not None + 
pd.testing.assert_frame_equal(loaded.seed_scaffold.scaffold_seed_data, scaffold) + assert loaded.targets.continuous == {"income": 1.0} + assert loaded.artifact_paths["targets"] == tmp_path / "targets.json" + + +def test_load_us_policyengine_entity_stage_artifacts_reads_checkpoint(tmp_path): + pytest.importorskip("pyarrow") + bundle = PolicyEngineUSEntityTableBundle( + households=pd.DataFrame({"household_id": [1], "household_weight": [1.0]}), + persons=pd.DataFrame({"person_id": [10], "household_id": [1]}), + tax_units=None, + spm_units=None, + families=None, + marital_units=None, + ) + write_us_policyengine_entity_stage_artifact(bundle, tmp_path) + manifest = { + "config": {"calibration_backend": "none"}, + "synthesis": {"source_names": ["source"], "scaffold_source": "source"}, + "calibration": {}, + "artifacts": { + "policyengine_entity_tables": ( + "stage_artifacts/06_policyengine_entities/metadata.json" + ), + }, + } + + loaded = load_us_policyengine_entity_stage_artifacts( + tmp_path, + manifest_payload=manifest, + ) + + assert loaded.metadata["stageId"] == "06_policyengine_entities" + pd.testing.assert_frame_equal(loaded.bundle.households, bundle.households) + + +def test_load_us_calibrated_stage_artifacts_reads_stage7_outputs(tmp_path): + pytest.importorskip("pyarrow") + calibrated = pd.DataFrame({"person_id": [1], "weight": [2.0]}) + calibrated.to_parquet(tmp_path / "calibrated_data.parquet", index=False) + (tmp_path / "targets.json").write_text( + json.dumps({"marginal": {}, "continuous": {"income": 1.0}}) + ) + summary_path = tmp_path / "stage_artifacts" / "07_calibration" + summary_path.mkdir(parents=True) + (summary_path / "calibration_summary.json").write_text( + json.dumps({"backend": "none", "converged": True}) + ) + manifest = { + "config": {"calibration_backend": "none"}, + "rows": {"calibrated": 1}, + "synthesis": {"source_names": ["source"], "scaffold_source": "source"}, + "calibration": {"backend": "none"}, + "artifacts": { + "calibrated_data": 
"calibrated_data.parquet", + "targets": "targets.json", + "calibration_summary": ( + "stage_artifacts/07_calibration/calibration_summary.json" + ), + }, + } + + loaded = load_us_calibrated_stage_artifacts(tmp_path, manifest_payload=manifest) + + pd.testing.assert_frame_equal(loaded.calibrated_data, calibrated) + assert loaded.targets.continuous == {"income": 1.0} + assert loaded.calibration_summary["converged"] is True + + +def test_load_us_dataset_assembly_artifacts_resolves_stage8_paths(tmp_path): + (tmp_path / "manifest.json").write_text("{}") + (tmp_path / "stage_manifest.json").write_text("{}") + (tmp_path / "data_flow_snapshot.json").write_text("{}") + (tmp_path / "policyengine_us.h5").write_text("dataset") + stage_artifacts = tmp_path / "stage_artifacts" + stage_artifacts.mkdir() + (stage_artifacts / "artifact_inventory.json").write_text("{}") + (stage_artifacts / "conditional_readiness.json").write_text("{}") + manifest = { + "config": {"calibration_backend": "none"}, + "synthesis": {"source_names": ["source"], "scaffold_source": "source"}, + "calibration": {}, + "artifacts": { + "policyengine_dataset": "policyengine_us.h5", + "stage_manifest": "stage_manifest.json", + "data_flow_snapshot": "data_flow_snapshot.json", + "artifact_inventory": "stage_artifacts/artifact_inventory.json", + "conditional_readiness": "stage_artifacts/conditional_readiness.json", + }, + } + + loaded = load_us_dataset_assembly_artifacts(tmp_path, manifest_payload=manifest) + + assert loaded.policyengine_dataset == tmp_path / "policyengine_us.h5" + assert loaded.stage_manifest == tmp_path / "stage_manifest.json" + assert loaded.data_flow_snapshot == tmp_path / "data_flow_snapshot.json" + assert loaded.artifact_inventory == stage_artifacts / "artifact_inventory.json" + assert loaded.conditional_readiness == stage_artifacts / "conditional_readiness.json" + + +def test_stage_artifact_checked_resolver_enforces_format_and_existence(tmp_path): + (tmp_path / 
"synthetic_data.parquet").write_text("synthetic") + manifest = { + "config": {"calibration_backend": "none"}, + "rows": {"synthetic": 1}, + "synthesis": {"source_names": ["source"], "scaffold_source": "source"}, + "calibration": {}, + "artifacts": {"synthetic_data": "synthetic_data.parquet"}, + } + + with pytest.raises(ValueError, match="expected 'json'"): + resolve_us_stage_artifact_path_checked( + tmp_path, + "05_donor_integration_synthesis", + "synthetic_data", + manifest_payload=manifest, + expected_format="json", + ) + + with pytest.raises(FileNotFoundError, match="Stage artifact not found"): + load_us_stage_json_artifact( + tmp_path, + "03_source_planning", + "source_plan", + manifest_payload={ + **manifest, + "artifacts": {"source_plan": "missing.json"}, + }, + ) diff --git a/tests/pipelines/test_stage_contracts.py b/tests/pipelines/test_stage_contracts.py index 933e807..a530c52 100644 --- a/tests/pipelines/test_stage_contracts.py +++ b/tests/pipelines/test_stage_contracts.py @@ -4,8 +4,11 @@ from microplex_us.pipelines.stage_contracts import ( canonicalize_us_pipeline_stage_id, + config_keys_for_us_pipeline_stage, default_us_pipeline_stage_contracts, get_us_pipeline_stage_contract, + get_us_stage_artifact_contract, + resolve_us_stage_artifact_contract_path, serialize_us_pipeline_stage_contracts, ) @@ -33,9 +36,21 @@ def test_default_us_pipeline_stage_contracts_are_stable_and_complete(): assert contract.purpose assert contract.consumes assert contract.produces + assert contract.inputs + assert contract.outputs assert contract.diagnostics assert contract.validations assert contract.resume_mode + for artifact in contract.artifacts: + assert artifact.format + assert artifact.hash_mode + if artifact.resume_role is not None: + assert artifact.resume_role in { + "diagnostic", + "manual_replay", + "manual_resume", + "post_artifact_evidence", + } def test_get_us_pipeline_stage_contract_returns_one_stage(): @@ -54,9 +69,12 @@ def 
test_serialize_us_pipeline_stage_contracts_is_json_ready(): payload = serialize_us_pipeline_stage_contracts() assert payload["schemaVersion"] == 1 - assert payload["contractVersion"] == "us-runtime-stages-v1" + assert payload["contractVersion"] == "us-runtime-stages-v2" assert len(payload["stages"]) == 9 assert payload["stages"][5]["id"] == "06_policyengine_entities" + assert payload["stages"][5]["inputs"][0]["artifact_key"] == "synthetic_data" + assert payload["stages"][7]["artifacts"][-1]["key"] == "conditional_readiness" + assert payload["stages"][7]["artifacts"][-1]["format"] == "json" def test_canonicalize_us_pipeline_stage_id_maps_legacy_runtime_ids(): @@ -69,3 +87,73 @@ def test_canonicalize_us_pipeline_stage_id_maps_legacy_runtime_ids(): assert canonicalize_us_pipeline_stage_id("benchmark") == "09_validation_benchmarking" assert canonicalize_us_pipeline_stage_id("08_dataset_assembly") == "08_dataset_assembly" assert canonicalize_us_pipeline_stage_id("custom-stage") == "custom-stage" + + +def test_stage_contracts_expose_config_scope_and_canonical_paths(tmp_path): + assert "n_synthetic" in config_keys_for_us_pipeline_stage( + "05_donor_integration_synthesis" + ) + assert resolve_us_stage_artifact_contract_path( + tmp_path, + "08_dataset_assembly", + "artifact_inventory", + ) == (tmp_path / "stage_artifacts" / "artifact_inventory.json") + + +def test_required_stage_inputs_reference_prior_outputs_and_artifacts(): + contracts = default_us_pipeline_stage_contracts() + contracts_by_id = {contract.id: contract for contract in contracts} + + for contract in contracts: + for resource in contract.inputs: + if not resource.required: + continue + if resource.kind == "stage_output": + assert resource.stage_id is not None + upstream = contracts_by_id[resource.stage_id] + assert any( + output.key == resource.key + and output.kind == "stage_output" + and output.stage_id == resource.stage_id + for output in upstream.outputs + ) + if resource.kind == "artifact": + assert 
resource.stage_id is not None + artifact = get_us_stage_artifact_contract( + resource.stage_id, + resource.artifact_key or resource.key, + ) + assert artifact.required + + +def test_source_planning_seam_exposes_descriptors_for_stage3(): + stage2 = get_us_pipeline_stage_contract("02_source_loading") + stage3 = get_us_pipeline_stage_contract("03_source_planning") + + stage2_outputs = {resource.key for resource in stage2.outputs} + stage3_inputs = {resource.key for resource in stage3.inputs} + + assert "source_descriptors" in stage2_outputs + assert "source_descriptors" in stage3_inputs + + +def test_stage_config_scopes_use_real_build_config_keys(): + stage5_keys = set(config_keys_for_us_pipeline_stage("05_donor_integration_synthesis")) + stage7_keys = set(config_keys_for_us_pipeline_stage("07_calibration")) + + assert { + "n_synthetic", + "random_seed", + "synthesis_backend", + "donor_imputer_backend", + "donor_imputer_condition_selection", + } <= stage5_keys + assert "calibration_epochs" not in stage7_keys + assert "calibration_l0_lambda" not in stage7_keys + assert { + "calibration_backend", + "calibration_tol", + "calibration_max_iter", + "target_sparsity", + "policyengine_targets_db", + } <= stage7_keys diff --git a/tests/pipelines/test_stage_manifest.py b/tests/pipelines/test_stage_manifest.py index ac8e68f..954e9ac 100644 --- a/tests/pipelines/test_stage_manifest.py +++ b/tests/pipelines/test_stage_manifest.py @@ -8,6 +8,7 @@ from microplex_us.pipelines.stage_manifest import ( build_us_stage_manifest, load_us_policyengine_entity_stage_artifact, + load_us_stage_manifest, resolve_us_stage_artifact_path, stage_summary_for_data_flow_snapshot, write_us_policyengine_entity_stage_artifact, @@ -28,6 +29,19 @@ def test_build_us_stage_manifest_reports_nine_stage_statuses(tmp_path): (tmp_path / "calibrated_data.parquet").write_text("calibrated") (tmp_path / "targets.json").write_text("{}") (tmp_path / "policyengine_us.h5").write_text("dataset") + source_plan_path = 
tmp_path / "stage_artifacts" / "03_source_planning" + source_plan_path.mkdir(parents=True) + (source_plan_path / "source_plan.json").write_text("{}") + entity_path = tmp_path / "stage_artifacts" / "06_policyengine_entities" + entity_path.mkdir(parents=True) + (entity_path / "metadata.json").write_text("{}") + calibration_path = tmp_path / "stage_artifacts" / "07_calibration" + calibration_path.mkdir(parents=True) + (calibration_path / "calibration_summary.json").write_text("{}") + (tmp_path / "stage_manifest.json").write_text("{}") + (tmp_path / "data_flow_snapshot.json").write_text("{}") + (tmp_path / "stage_artifacts" / "artifact_inventory.json").write_text("{}") + (tmp_path / "stage_artifacts" / "conditional_readiness.json").write_text("{}") manifest = { "created_at": "2026-05-28T00:00:00+00:00", "config": {"calibration_backend": "entropy"}, @@ -47,13 +61,24 @@ def test_build_us_stage_manifest_reports_nine_stage_statuses(tmp_path): "synthetic_data": "synthetic_data.parquet", "calibrated_data": "calibrated_data.parquet", "targets": "targets.json", + "source_plan": "stage_artifacts/03_source_planning/source_plan.json", + "policyengine_entity_tables": ( + "stage_artifacts/06_policyengine_entities/metadata.json" + ), + "calibration_summary": ( + "stage_artifacts/07_calibration/calibration_summary.json" + ), "policyengine_dataset": "policyengine_us.h5", + "stage_manifest": "stage_manifest.json", + "data_flow_snapshot": "data_flow_snapshot.json", + "artifact_inventory": "stage_artifacts/artifact_inventory.json", + "conditional_readiness": "stage_artifacts/conditional_readiness.json", }, } payload = build_us_stage_manifest(tmp_path, manifest_payload=manifest) - assert payload["schemaVersion"] == 1 + assert payload["schemaVersion"] == 2 assert payload["generatedAt"] == "2026-05-28T00:00:00+00:00" assert [stage["id"] for stage in payload["stages"]] == [ "01_run_profile", @@ -69,17 +94,67 @@ def test_build_us_stage_manifest_reports_nine_stage_statuses(tmp_path): statuses 
= {stage["id"]: stage["status"] for stage in payload["stages"]} assert statuses["01_run_profile"] == "ready" assert statuses["02_source_loading"] == "metadata_only" - assert statuses["03_source_planning"] == "metadata_only" + assert statuses["03_source_planning"] == "ready" assert statuses["04_seed_scaffold"] == "ready" assert statuses["05_donor_integration_synthesis"] == "ready" - assert statuses["06_policyengine_entities"] == "metadata_only" + assert statuses["06_policyengine_entities"] == "ready" assert statuses["07_calibration"] == "ready" assert statuses["08_dataset_assembly"] == "ready" assert statuses["09_validation_benchmarking"] == "deferred" + stage5_artifacts = { + artifact["key"]: artifact + for stage in payload["stages"] + if stage["id"] == "05_donor_integration_synthesis" + for artifact in stage["artifacts"] + } + assert stage5_artifacts["synthetic_data"]["format"] == "parquet_dataframe" + assert stage5_artifacts["synthetic_data"]["hash_mode"] == "file_sha256" + + +def test_load_us_stage_manifest_accepts_v1_and_v2(tmp_path): + v1_path = tmp_path / "stage_manifest_v1.json" + v1_path.write_text( + json.dumps( + { + "schemaVersion": 1, + "contractVersion": "us-runtime-stages-v1", + "generatedAt": None, + "pipeline": "us_microplex", + "artifactRoot": ".", + "manifest": "manifest.json", + "stages": [], + } + ) + ) + v2_path = tmp_path / "stage_manifest_v2.json" + v2_path.write_text( + json.dumps( + { + "schemaVersion": 2, + "contractVersion": "us-runtime-stages-v2", + "generatedAt": None, + "pipeline": "us_microplex", + "artifactRoot": ".", + "manifest": "manifest.json", + "stages": [], + } + ) + ) + + assert load_us_stage_manifest(v1_path)["schemaVersion"] == 1 + assert load_us_stage_manifest(v2_path)["schemaVersion"] == 2 def test_build_us_stage_manifest_keeps_empty_validation_index_deferred(tmp_path): (tmp_path / "policyengine_us.h5").write_text("dataset") + (tmp_path / "stage_manifest.json").write_text("{}") + (tmp_path / 
"data_flow_snapshot.json").write_text("{}") + (tmp_path / "stage_artifacts" / "artifact_inventory.json").parent.mkdir( + parents=True, + exist_ok=True, + ) + (tmp_path / "stage_artifacts" / "artifact_inventory.json").write_text("{}") + (tmp_path / "stage_artifacts" / "conditional_readiness.json").write_text("{}") evidence_path = ( tmp_path / "stage_artifacts" @@ -103,6 +178,10 @@ def test_build_us_stage_manifest_keeps_empty_validation_index_deferred(tmp_path) "calibration": {}, "artifacts": { "policyengine_dataset": "policyengine_us.h5", + "stage_manifest": "stage_manifest.json", + "data_flow_snapshot": "data_flow_snapshot.json", + "artifact_inventory": "stage_artifacts/artifact_inventory.json", + "conditional_readiness": "stage_artifacts/conditional_readiness.json", "validation_evidence": ( "stage_artifacts/09_validation_benchmarking/evidence_manifest.json" ), @@ -115,6 +194,27 @@ def test_build_us_stage_manifest_keeps_empty_validation_index_deferred(tmp_path) assert statuses["09_validation_benchmarking"] == "deferred" +def test_build_us_stage_manifest_requires_validation_evidence_for_stage9_ready( + tmp_path, +): + (tmp_path / "policyengine_us.h5").write_text("dataset") + (tmp_path / "policyengine_native_scores.json").write_text("{}") + manifest = { + "config": {"calibration_backend": "entropy"}, + "synthesis": {"source_names": ["source"], "scaffold_source": "source"}, + "calibration": {}, + "artifacts": { + "policyengine_dataset": "policyengine_us.h5", + "policyengine_native_scores": "policyengine_native_scores.json", + }, + } + + payload = build_us_stage_manifest(tmp_path, manifest_payload=manifest) + + statuses = {stage["id"]: stage["status"] for stage in payload["stages"]} + assert statuses["09_validation_benchmarking"] == "incomplete" + + def test_stage_summary_omits_unreferenced_path_hints(tmp_path): manifest = { "config": {"calibration_backend": "entropy"}, diff --git a/tests/pipelines/test_stage_readiness.py b/tests/pipelines/test_stage_readiness.py new 
file mode 100644 index 0000000..2a72fa4 --- /dev/null +++ b/tests/pipelines/test_stage_readiness.py @@ -0,0 +1,222 @@ +"""Tests for US conditional-readiness reports.""" + +import json + +import pytest + +from microplex_us.pipelines.stage_artifacts import build_us_stage_artifact_inventory +from microplex_us.pipelines.stage_readiness import ( + build_us_conditional_readiness_report, + build_us_stage_reuse_key, + load_us_conditional_readiness_report, + write_us_conditional_readiness_report, +) + + +def test_build_us_stage_reuse_key_ignores_checkpoint_output_paths(tmp_path): + (tmp_path / "synthetic_data.parquet").write_text("synthetic") + base_manifest = { + "config": { + "n_synthetic": 10, + "calibration_backend": "none", + "pipeline_checkpoint_save_post_microsim_path": "/tmp/a", + }, + "rows": {"synthetic": 1}, + "synthesis": {"source_names": ["source"], "scaffold_source": "source"}, + "calibration": {}, + "artifacts": {"synthetic_data": "synthetic_data.parquet"}, + } + changed_output_path_manifest = { + **base_manifest, + "config": { + **base_manifest["config"], + "pipeline_checkpoint_save_post_microsim_path": "/tmp/b", + }, + } + + inventory = build_us_stage_artifact_inventory( + tmp_path, + manifest_payload=base_manifest, + max_hash_bytes=None, + ) + + assert build_us_stage_reuse_key( + "05_donor_integration_synthesis", + base_manifest, + inventory, + ) == build_us_stage_reuse_key( + "05_donor_integration_synthesis", + changed_output_path_manifest, + inventory, + ) + + +def test_build_us_stage_reuse_key_uses_stage_scoped_config(tmp_path): + (tmp_path / "synthetic_data.parquet").write_text("synthetic") + base_manifest = { + "config": { + "n_synthetic": 10, + "synthesis_backend": "bootstrap", + "policyengine_dataset_year": 2024, + }, + "rows": {"synthetic": 1}, + "synthesis": {"source_names": ["source"], "scaffold_source": "source"}, + "calibration": {}, + "artifacts": {"synthetic_data": "synthetic_data.parquet"}, + } + changed_stage8_config = { + **base_manifest, 
+ "config": { + **base_manifest["config"], + "policyengine_dataset_year": 2025, + }, + } + changed_stage5_config = { + **base_manifest, + "config": { + **base_manifest["config"], + "n_synthetic": 20, + }, + } + inventory = build_us_stage_artifact_inventory( + tmp_path, + manifest_payload=base_manifest, + max_hash_bytes=None, + ) + + base_key = build_us_stage_reuse_key( + "05_donor_integration_synthesis", + base_manifest, + inventory, + ) + assert base_key == build_us_stage_reuse_key( + "05_donor_integration_synthesis", + changed_stage8_config, + inventory, + ) + assert base_key != build_us_stage_reuse_key( + "05_donor_integration_synthesis", + changed_stage5_config, + inventory, + ) + + +def test_conditional_readiness_reports_config_mismatch_as_rerun(tmp_path): + (tmp_path / "synthetic_data.parquet").write_text("synthetic") + manifest = { + "config": {"n_synthetic": 10, "calibration_backend": "none"}, + "rows": {"synthetic": 1}, + "synthesis": {"source_names": ["source"], "scaffold_source": "source"}, + "calibration": {}, + "artifacts": {"synthetic_data": "synthetic_data.parquet"}, + } + + report = build_us_conditional_readiness_report( + tmp_path, + manifest_payload=manifest, + requested_config={"n_synthetic": 20, "calibration_backend": "none"}, + ) + + stages = {stage["stageId"]: stage for stage in report["stages"]} + assert stages["05_donor_integration_synthesis"]["compatibility"] == "mismatch" + assert stages["05_donor_integration_synthesis"]["readiness"] == "must_rerun" + assert stages["05_donor_integration_synthesis"]["reason"] == ( + "Requested configuration does not match this stage's saved run inputs." 
+ ) + assert stages["08_dataset_assembly"]["compatibility"] == "match" + + +def test_conditional_readiness_reports_manual_replay_without_requested_config(tmp_path): + (tmp_path / "synthetic_data.parquet").write_text("synthetic") + manifest = { + "config": {"n_synthetic": 10, "calibration_backend": "none"}, + "rows": {"synthetic": 1}, + "synthesis": {"source_names": ["source"], "scaffold_source": "source"}, + "calibration": {}, + "artifacts": {"synthetic_data": "synthetic_data.parquet"}, + } + + report = build_us_conditional_readiness_report( + tmp_path, + manifest_payload=manifest, + ) + + stages = {stage["stageId"]: stage for stage in report["stages"]} + assert stages["05_donor_integration_synthesis"]["compatibility"] == ( + "not_evaluated" + ) + assert stages["05_donor_integration_synthesis"]["readiness"] == "manual_replay" + assert stages["05_donor_integration_synthesis"]["reloadableArtifacts"] == [ + "05_donor_integration_synthesis.synthetic_data" + ] + + +def test_conditional_readiness_reports_missing_required_artifacts_as_rerun(tmp_path): + manifest = { + "config": {"n_synthetic": 10, "calibration_backend": "none"}, + "rows": {"synthetic": 1}, + "synthesis": {"source_names": ["source"], "scaffold_source": "source"}, + "calibration": {}, + "artifacts": {"synthetic_data": "synthetic_data.parquet"}, + } + + report = build_us_conditional_readiness_report( + tmp_path, + manifest_payload=manifest, + ) + + stages = {stage["stageId"]: stage for stage in report["stages"]} + assert stages["05_donor_integration_synthesis"]["readiness"] == "must_rerun" + assert "05_donor_integration_synthesis.synthetic_data" in stages[ + "05_donor_integration_synthesis" + ]["missingArtifacts"] + + +def test_conditional_readiness_reports_stage9_from_stage8_dataset(tmp_path): + (tmp_path / "policyengine_us.h5").write_text("dataset") + manifest = { + "config": {"calibration_backend": "none"}, + "synthesis": {"source_names": ["source"], "scaffold_source": "source"}, + "calibration": {}, + 
"artifacts": {"policyengine_dataset": "policyengine_us.h5"}, + } + + report = build_us_conditional_readiness_report( + tmp_path, + manifest_payload=manifest, + ) + + stages = {stage["stageId"]: stage for stage in report["stages"]} + assert stages["09_validation_benchmarking"]["status"] == "deferred" + assert stages["09_validation_benchmarking"]["readiness"] == ( + "post_artifact_evidence" + ) + + +def test_write_and_load_us_conditional_readiness_report(tmp_path): + (tmp_path / "policyengine_us.h5").write_text("dataset") + manifest = { + "config": {"calibration_backend": "none"}, + "synthesis": {"source_names": ["source"], "scaffold_source": "source"}, + "calibration": {}, + "artifacts": {"policyengine_dataset": "policyengine_us.h5"}, + } + + path = write_us_conditional_readiness_report( + tmp_path, + tmp_path / "stage_artifacts" / "conditional_readiness.json", + manifest_payload=manifest, + ) + loaded = load_us_conditional_readiness_report(path) + + assert loaded["schemaVersion"] == 1 + assert loaded["generatedAt"] is None + assert loaded["stages"][0]["stageId"] == "01_run_profile" + + +def test_load_us_conditional_readiness_report_rejects_unknown_schema(tmp_path): + path = tmp_path / "conditional_readiness.json" + path.write_text(json.dumps({"schemaVersion": 99})) + + with pytest.raises(RuntimeError, match="Unsupported US conditional-readiness"): + load_us_conditional_readiness_report(path) diff --git a/tests/pipelines/test_stage_run.py b/tests/pipelines/test_stage_run.py new file mode 100644 index 0000000..c4d0eba --- /dev/null +++ b/tests/pipelines/test_stage_run.py @@ -0,0 +1,563 @@ +"""Tests for typed US stage-run output manifests.""" + +import json +from dataclasses import fields + +import pytest + +from microplex_us.pipelines.stage_contracts import ( + US_CANONICAL_STAGE_IDS, + get_us_pipeline_stage_contract, +) +from microplex_us.pipelines.stage_run import ( + US_STAGE_OUTPUT_MANIFEST_TYPES, + USArtifactRef, + USAuxiliaryArtifact, + USDiagnosticOutput, + 
USRunProfileOutputs, + USSourceLoadingOutputs, + USSourcePlanningOutputs, + USStageInputOverride, + USStageRunWriter, + build_us_stage_output_manifests_from_artifact_manifest, + parse_us_stage_input_override, + write_us_stage_run_manifests_from_artifact_manifest, +) + +_BASE_STAGE_MANIFEST_FIELDS = { + "schema_version", + "contract_version", + "input_stage_manifest", + "diagnostics", + "auxiliary_artifacts", + "metadata", + "complete", + "stage_id", +} + + +def test_every_canonical_stage_has_typed_output_manifest(): + assert tuple(US_STAGE_OUTPUT_MANIFEST_TYPES) == US_CANONICAL_STAGE_IDS + + +def test_stage_output_manifests_use_contract_outputs_as_required_source(): + for stage_id, manifest_type in US_STAGE_OUTPUT_MANIFEST_TYPES.items(): + contract = get_us_pipeline_stage_contract(stage_id) + expected = tuple( + resource.key for resource in contract.outputs if resource.required + ) + output = manifest_type() + + assert output.required_output_keys() == expected + assert set(expected) <= {item.name for item in fields(manifest_type)} + + +def test_stage_output_manifest_fields_are_declared_by_contracts(): + for stage_id, manifest_type in US_STAGE_OUTPUT_MANIFEST_TYPES.items(): + contract = get_us_pipeline_stage_contract(stage_id) + contract_output_keys = {resource.key for resource in contract.outputs} + contract_artifact_keys = {artifact.key for artifact in contract.artifacts} + typed_output_fields = { + item.name + for item in fields(manifest_type) + if item.name not in _BASE_STAGE_MANIFEST_FIELDS + } + + assert contract_output_keys <= typed_output_fields + assert typed_output_fields <= contract_output_keys | contract_artifact_keys + + +def test_stage_run_writer_records_typed_stage_manifests(tmp_path): + _write_artifact_bundle_files(tmp_path) + manifest = _artifact_manifest() + + updated_manifest = write_us_stage_run_manifests_from_artifact_manifest( + tmp_path, + manifest, + ) + + assert (tmp_path / "manifest.json").exists() + assert ( + tmp_path + / 
"stage_artifacts" + / "manifests" + / "05_donor_integration_synthesis.json" + ).exists() + assert ( + tmp_path / "stage_artifacts" / "manifests" / "09_validation_benchmarking.json" + ).exists() + assert ( + updated_manifest["stage_output_manifests"]["07_calibration"] + == "stage_artifacts/manifests/07_calibration.json" + ) + stage5_manifest = json.loads( + ( + tmp_path + / "stage_artifacts" + / "manifests" + / "05_donor_integration_synthesis.json" + ).read_text() + ) + assert stage5_manifest["stageId"] == "05_donor_integration_synthesis" + assert stage5_manifest["diagnostics"] + assert stage5_manifest["inputStageManifest"] == ( + "stage_artifacts/manifests/04_seed_scaffold.json" + ) + + +def test_stage_run_writer_rejects_missing_diagnostics(tmp_path): + writer = USStageRunWriter(tmp_path) + output = USRunProfileOutputs( + manifest=USArtifactRef( + key="manifest", + path="manifest.json", + format="json", + required=True, + assume_exists=True, + ), + resolved_config={"n_synthetic": 10}, + provider_query_plan={"source_names": ["source"]}, + ) + + with pytest.raises(ValueError, match="does not expose diagnostics"): + writer.record_stage(output) + + +def test_stage_run_writer_requires_prior_stage_or_override(tmp_path): + output = USSourceLoadingOutputs( + observation_frame_summary={"source_count": 1}, + source_descriptors=("source",), + source_relationships={"status": "summarized"}, + diagnostics={ + "stage_summary": USDiagnosticOutput( + key="stage_summary", + summary={"source_names": ["source"]}, + ) + }, + ) + + with pytest.raises(ValueError, match="requires 01_run_profile"): + USStageRunWriter(tmp_path).record_stage(output) + + with pytest.raises(ValueError, match="require allow_stage_input_overrides"): + USStageRunWriter( + tmp_path, + stage_input_overrides=( + USStageInputOverride( + stage_id="02_source_loading", + key="provider_query_plan", + path="overrides/provider_query_plan.json", + ), + ), + ) + + writer = USStageRunWriter( + tmp_path, + 
allow_stage_input_overrides=True, + stage_input_overrides=( + USStageInputOverride( + stage_id="02_source_loading", + key="provider_query_plan", + path="overrides/provider_query_plan.json", + reason="test override", + ), + ), + ) + writer.record_stage(output) + assert writer.recorded_stages == (output,) + + +def test_stage_run_writer_requires_specific_input_override(tmp_path): + output = USSourceLoadingOutputs( + observation_frame_summary={"source_count": 1}, + source_descriptors=("source",), + source_relationships={"status": "summarized"}, + diagnostics={ + "stage_summary": USDiagnosticOutput( + key="stage_summary", + summary={"source_names": ["source"]}, + ) + }, + ) + + writer = USStageRunWriter( + tmp_path, + allow_stage_input_overrides=True, + stage_input_overrides=( + USStageInputOverride( + stage_id="02_source_loading", + key="source_datasets", + path="overrides/source_datasets.json", + ), + ), + ) + + with pytest.raises(ValueError, match="provider_query_plan"): + writer.record_stage(output) + + +def test_stage_run_writer_validates_required_inputs_from_prior_manifest(tmp_path): + writer = USStageRunWriter(tmp_path) + writer.record_stage( + USRunProfileOutputs( + manifest=USArtifactRef( + key="manifest", + path="manifest.json", + format="json", + required=True, + assume_exists=True, + ), + resolved_config={"n_synthetic": 10}, + provider_query_plan={}, + diagnostics={ + "stage_summary": USDiagnosticOutput( + key="stage_summary", + summary={"has_config": True}, + ) + }, + complete=False, + ) + ) + output = USSourceLoadingOutputs( + observation_frame_summary={"source_count": 1}, + source_descriptors=("source",), + source_relationships={"status": "summarized"}, + diagnostics={ + "stage_summary": USDiagnosticOutput( + key="stage_summary", + summary={"source_names": ["source"]}, + ) + }, + ) + + with pytest.raises(ValueError, match="01_run_profile.provider_query_plan"): + writer.record_stage(output) + + +def 
test_stage_run_writer_requires_prior_stage_even_without_stage_bound_inputs( + tmp_path, +): + output = USSourcePlanningOutputs( + scaffold_selection={"scaffold_source": "source"}, + diagnostics={ + "stage_summary": USDiagnosticOutput( + key="stage_summary", + summary={"scaffold_source": "source"}, + ) + }, + complete=False, + ) + + with pytest.raises(ValueError, match="requires 02_source_loading"): + USStageRunWriter(tmp_path).record_stage(output) + + +def test_stage_run_writer_rejects_arbitrary_input_manifest(tmp_path): + arbitrary_manifest = tmp_path / "arbitrary.json" + arbitrary_manifest.write_text("{}") + output = USSourceLoadingOutputs( + input_stage_manifest="arbitrary.json", + observation_frame_summary={"source_count": 1}, + source_descriptors=("source",), + source_relationships={"status": "summarized"}, + diagnostics={ + "stage_summary": USDiagnosticOutput( + key="stage_summary", + summary={"source_names": ["source"]}, + ) + }, + ) + + with pytest.raises(ValueError, match="requires 01_run_profile"): + USStageRunWriter(tmp_path).record_stage(output) + + +def test_stage_run_writer_rejects_empty_required_structured_outputs(tmp_path): + output = USRunProfileOutputs( + manifest=USArtifactRef( + key="manifest", + path="manifest.json", + format="json", + required=True, + assume_exists=True, + ), + resolved_config={}, + provider_query_plan={"source_names": ["source"]}, + diagnostics={ + "stage_summary": USDiagnosticOutput( + key="stage_summary", + summary={"has_config": False}, + ) + }, + ) + + with pytest.raises(ValueError, match="resolved_config"): + USStageRunWriter(tmp_path).record_stage(output) + + +def test_stage_run_writer_rejects_undeclared_auxiliary_artifact(tmp_path): + writer = USStageRunWriter(tmp_path) + output = USRunProfileOutputs( + manifest=USArtifactRef( + key="manifest", + path="manifest.json", + format="json", + required=True, + assume_exists=True, + ), + resolved_config={"n_synthetic": 10}, + provider_query_plan={"source_names": ["source"]}, + 
diagnostics={ + "stage_summary": USDiagnosticOutput( + key="stage_summary", + summary={"has_config": True}, + ) + }, + auxiliary_artifacts={ + "not_declared": USAuxiliaryArtifact( + key="not_declared", + path="not_declared.json", + format="json", + ) + }, + ) + + with pytest.raises(KeyError, match="not declared"): + writer.update(output) + + +def test_parse_us_stage_input_override(): + override = parse_us_stage_input_override( + "02_source_loading.provider_query_plan=overrides/provider_query_plan.json" + ) + + assert override == USStageInputOverride( + stage_id="02_source_loading", + key="provider_query_plan", + path="overrides/provider_query_plan.json", + ) + + with pytest.raises(ValueError, match="STAGE_ID.KEY=PATH"): + parse_us_stage_input_override("02_source_loading=missing-key") + + with pytest.raises(ValueError, match="Unknown US pipeline stage"): + parse_us_stage_input_override("unknown_stage.provider_query_plan=override.json") + + with pytest.raises(ValueError, match="Unknown input override key"): + parse_us_stage_input_override("02_source_loading.not_an_input=override.json") + + +def test_build_stage_outputs_from_manifest_exposes_diagnostics(tmp_path): + _write_artifact_bundle_files(tmp_path) + outputs = build_us_stage_output_manifests_from_artifact_manifest( + tmp_path, + _artifact_manifest(), + ) + + assert len(outputs) == 9 + assert all(output.diagnostics for output in outputs) + stage6 = outputs[5] + assert "policyengine_dataset" not in stage6.materialized_policyengine_inputs + assert stage6.materialized_policyengine_inputs["tables"]["households"]["rows"] == 1 + + +def test_build_stage_outputs_treats_missing_declared_dataset_as_incomplete( + tmp_path, +): + _write_artifact_bundle_files(tmp_path) + (tmp_path / "policyengine_us.h5").unlink() + + outputs = build_us_stage_output_manifests_from_artifact_manifest( + tmp_path, + _artifact_manifest(), + ) + + stage8 = outputs[7] + assert stage8.complete is False + assert 
stage8.missing_required_outputs(tmp_path) == ("policyengine_dataset",) + + +def test_build_stage_outputs_hydrates_stage9_summary_from_validation_evidence( + tmp_path, +): + _write_artifact_bundle_files(tmp_path) + evidence_path = _write_validation_evidence_manifest(tmp_path) + manifest = _artifact_manifest() + manifest.pop("policyengine_native_scores") + manifest["artifacts"]["validation_evidence"] = str( + evidence_path.relative_to(tmp_path) + ) + + outputs = build_us_stage_output_manifests_from_artifact_manifest( + tmp_path, + manifest, + ) + + stage9 = outputs[8] + assert stage9.complete is True + assert stage9.benchmark_summary == { + "policyengine_native_scores": { + "enhanced_cps_native_loss_delta": -0.1, + } + } + assert stage9.diagnostics["stage_summary"].summary == stage9.benchmark_summary + + +def test_build_stage_outputs_does_not_complete_stage9_from_stale_evidence_summary( + tmp_path, +): + _write_artifact_bundle_files(tmp_path) + evidence_path = _write_validation_evidence_manifest(tmp_path) + (tmp_path / "policyengine_native_scores.json").unlink() + manifest = _artifact_manifest() + manifest.pop("policyengine_native_scores") + manifest["artifacts"]["validation_evidence"] = str( + evidence_path.relative_to(tmp_path) + ) + + outputs = build_us_stage_output_manifests_from_artifact_manifest( + tmp_path, + manifest, + ) + + stage9 = outputs[8] + assert stage9.complete is False + assert stage9.benchmark_summary == {} + + +def test_stage_run_writer_preserves_existing_validation_evidence_summary( + tmp_path, +): + _write_artifact_bundle_files(tmp_path) + evidence_path = _write_validation_evidence_manifest(tmp_path) + manifest = _artifact_manifest() + manifest.pop("policyengine_native_scores") + manifest["artifacts"]["validation_evidence"] = str( + evidence_path.relative_to(tmp_path) + ) + + write_us_stage_run_manifests_from_artifact_manifest(tmp_path, manifest) + + stage9_manifest = json.loads( + ( + tmp_path + / "stage_artifacts" + / "manifests" + / 
"09_validation_benchmarking.json" + ).read_text() + ) + rewritten_evidence = json.loads(evidence_path.read_text()) + + assert stage9_manifest["complete"] is True + assert stage9_manifest["outputs"]["benchmark_summary"] == { + "policyengine_native_scores": { + "enhanced_cps_native_loss_delta": -0.1, + } + } + assert rewritten_evidence["summaries"] == { + "policyengine_native_scores": { + "enhanced_cps_native_loss_delta": -0.1, + } + } + assert any( + record["key"] == "policyengine_native_scores" + and record["path"] == "policyengine_native_scores.json" + and record["exists"] is True + for record in rewritten_evidence["evidence"] + ) + + +def _write_artifact_bundle_files(root): + for relative in ( + "seed_data.parquet", + "synthetic_data.parquet", + "calibrated_data.parquet", + "targets.json", + "policyengine_us.h5", + "policyengine_native_scores.json", + "source_weight_diagnostics.json", + "stage_artifacts/03_source_planning/source_plan.json", + "stage_artifacts/04_seed_scaffold/scaffold_seed_data.parquet", + "stage_artifacts/06_policyengine_entities/metadata.json", + "stage_artifacts/07_calibration/calibration_summary.json", + ): + path = root / relative + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text("{}") + ( + root / "stage_artifacts" / "06_policyengine_entities" / "metadata.json" + ).write_text( + json.dumps( + { + "format_version": 1, + "stage": "post_microsim", + "households": {"rows": 1, "columns": ["household_id"]}, + "persons": {"rows": 1, "columns": ["person_id"]}, + } + ) + ) + + +def _write_validation_evidence_manifest(root): + evidence_path = ( + root + / "stage_artifacts" + / "09_validation_benchmarking" + / "evidence_manifest.json" + ) + evidence_path.parent.mkdir(parents=True, exist_ok=True) + evidence_path.write_text( + json.dumps( + { + "formatVersion": 1, + "stageId": "09_validation_benchmarking", + "evidence": [ + { + "key": "policyengine_native_scores", + "path": "policyengine_native_scores.json", + "exists": True, + } + ], + 
"summaries": { + "policyengine_native_scores": { + "enhanced_cps_native_loss_delta": -0.1, + } + }, + } + ) + ) + return evidence_path + + +def _artifact_manifest(): + return { + "created_at": "2026-05-30T00:00:00+00:00", + "config": {"n_synthetic": 10, "calibration_backend": "entropy"}, + "rows": {"seed": 1, "synthetic": 1, "calibrated": 1}, + "synthesis": { + "source_names": ["source"], + "scaffold_source": "source", + "backend": "seed", + }, + "calibration": {"backend": "entropy", "converged": True}, + "policyengine_native_scores": {"enhanced_cps_native_loss_delta": -0.1}, + "artifacts": { + "seed_data": "seed_data.parquet", + "synthetic_data": "synthetic_data.parquet", + "calibrated_data": "calibrated_data.parquet", + "targets": "targets.json", + "policyengine_dataset": "policyengine_us.h5", + "policyengine_native_scores": "policyengine_native_scores.json", + "source_weight_diagnostics": "source_weight_diagnostics.json", + "source_plan": "stage_artifacts/03_source_planning/source_plan.json", + "scaffold_seed_data": ( + "stage_artifacts/04_seed_scaffold/scaffold_seed_data.parquet" + ), + "policyengine_entity_tables": ( + "stage_artifacts/06_policyengine_entities/metadata.json" + ), + "calibration_summary": ( + "stage_artifacts/07_calibration/calibration_summary.json" + ), + }, + } diff --git a/tests/pipelines/test_versioned_artifacts.py b/tests/pipelines/test_versioned_artifacts.py index 0961913..a107929 100644 --- a/tests/pipelines/test_versioned_artifacts.py +++ b/tests/pipelines/test_versioned_artifacts.py @@ -358,6 +358,34 @@ def test_save_versioned_us_microplex_artifacts_uses_explicit_version(tmp_path): assert paths.output_dir == tmp_path / "builds" / "run-1" assert paths.run_registry == tmp_path / "builds" / "run_registry.jsonl" assert paths.run_index_db == tmp_path / "builds" / "run_index.duckdb" + assert paths.stage_manifest == paths.output_dir / "stage_manifest.json" + assert paths.artifact_inventory == ( + paths.output_dir / "stage_artifacts" / 
"artifact_inventory.json" + ) + assert paths.conditional_readiness == ( + paths.output_dir / "stage_artifacts" / "conditional_readiness.json" + ) + assert paths.source_plan == ( + paths.output_dir / "stage_artifacts" / "03_source_planning" / "source_plan.json" + ) + assert paths.policyengine_entity_tables == ( + paths.output_dir / "stage_artifacts" / "06_policyengine_entities" / "metadata.json" + ) + assert paths.calibration_summary == ( + paths.output_dir + / "stage_artifacts" + / "07_calibration" + / "calibration_summary.json" + ) + assert paths.validation_evidence == ( + paths.output_dir + / "stage_artifacts" + / "09_validation_benchmarking" + / "evidence_manifest.json" + ) + assert paths.source_weight_diagnostics == ( + paths.output_dir / "source_weight_diagnostics.json" + ) manifest = json.loads(paths.manifest.read_text()) assert manifest["run_registry"]["artifact_id"] == "run-1" assert manifest["run_index"]["artifact_id"] == "run-1"