From 0baab84b81f02efef692880e2f1decb94124db7a Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Wed, 27 May 2026 18:15:15 +0200 Subject: [PATCH] feat(pipeline): HITL Pause + AuditLog with 4 backends (#147 phase 3c) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two enterprise features for state pipelines: human-in-the-loop pause/approve gates, and a structured audit trail of every node visit. HITL via Pause: - New Pause(reason=...) sentinel returned by a node halts the pipeline cleanly. - StatePipeline writes a paused checkpoint (new optional CheckpointRecord fields: paused=False, pause_reason=None — backward compatible). - StatePipelineResult gains paused/paused_node/pause_reason. - StatePipelineEventHandler gains an optional on_node_pause callback. - invoke(run_id=..., approve_pause=True) resumes from the SUCCESSOR of the paused node. Without approve_pause=True, resuming a paused run raises PipelineError — pauses are sticky until explicitly released. Audit log via four backends in new pipeline/audit.py: - AuditEntry pydantic model + split Protocol: AuditLog (write-only) + QueryableAuditLog (adds list_entries). - FileAuditLog — JSONL per (pipeline, run_id). Implements QueryableAuditLog. - PostgresAuditLog — single firefly_audit table, idempotent DDL, reuses the psycopg connection from Phase 3a's postgres extra. Implements QueryableAuditLog. - LoggingAuditLog — stdlib logging.Logger; pairs with any log-aggregation stack (Splunk-HEC, Loki, Datadog, OTel-LoggingHandler-bridge). Write-only. No new dep. - OtelAuditLog — OTel logs API directly; emits LogRecord with trace_id / span_id correlation. Requires opentelemetry-sdk. Write-only. - StatePipeline records one AuditEntry per node visit (success/error/pause) with inputs_snapshot, outputs_snapshot, latency_ms, started_at/completed_at, status, plus error_message or pause_reason as appropriate. - Audit-write failures are non-fatal — logged and swallowed. 
API additions exported from fireflyframework_agentic.pipeline: Pause, AuditEntry, AuditLog, QueryableAuditLog, FileAuditLog, PostgresAuditLog, LoggingAuditLog, OtelAuditLog No new required deps. Postgres extra (already added in 3a) covers PostgresAuditLog. opentelemetry-sdk is the same optional dep already used by Phase 3b for OTel spans. Tests: 20 new across two files. - test_state_pipeline_hitl.py (6): pause halts pipeline; resume without approval raises; resume with approve_pause continues from successor; on_node_pause fires; backward-compat for old checkpoints; pause→fail→retry. - test_audit_log.py (14): per-backend (File JSONL, Postgres mocked, Logging via caplog, OTel via mocked logger); pipeline writes one entry per visit; status reflects success/error/paused; audit write failures don't abort the pipeline. Full pipeline suite 142 passed (122 baseline + 20 new). Lints clean, pyright clean on touched modules. Example pipeline_state.py gains a 5th scenario showing HITL + audit end-to-end. 
--- docs/pipeline.md | 63 ++++ examples/pipeline_state.py | 69 ++++ fireflyframework_agentic/pipeline/__init__.py | 20 +- fireflyframework_agentic/pipeline/audit.py | 349 ++++++++++++++++++ fireflyframework_agentic/pipeline/builder.py | 4 + .../pipeline/checkpoint.py | 10 +- fireflyframework_agentic/pipeline/engine.py | 5 + .../pipeline/state_pipeline.py | 166 ++++++++- tests/unit/pipeline/test_audit_log.py | 325 ++++++++++++++++ .../unit/pipeline/test_state_pipeline_hitl.py | 208 +++++++++++ 10 files changed, 1216 insertions(+), 3 deletions(-) create mode 100644 fireflyframework_agentic/pipeline/audit.py create mode 100644 tests/unit/pipeline/test_audit_log.py create mode 100644 tests/unit/pipeline/test_state_pipeline_hitl.py diff --git a/docs/pipeline.md b/docs/pipeline.md index 60bd2097..88570307 100644 --- a/docs/pipeline.md +++ b/docs/pipeline.md @@ -320,6 +320,69 @@ In parallel, the pipeline emits OTel spans automatically when Handler exceptions are swallowed — observability never breaks business logic. +### Human-in-the-loop (Pause) + +Any node may return ``Pause(reason="...")`` instead of a state update to halt +the pipeline cleanly. The current state is checkpointed with a paused marker; +``invoke`` returns with ``result.paused=True`` and ``result.success=False``. + +```python +from fireflyframework_agentic.pipeline import Pause + +async def await_deploy_approval(state: DeployState) -> Pause: + return Pause(reason="awaiting human approval to deploy to production") +``` + +To resume after the external approval comes in, call ``invoke`` with the same +``run_id`` and ``approve_pause=True``. Without ``approve_pause=True``, the +resume raises a ``PipelineError`` — the pause is sticky until explicitly +released. The successor of the paused node runs next; the pause node itself +is not re-executed. + +```python +first = await pipeline.invoke(DeployState(...)) +assert first.paused +# ...later, after approval... 
+done = await pipeline.invoke(run_id=first.run_id, approve_pause=True) +assert done.success +``` + +The configured ``StatePipelineEventHandler`` receives an ``on_node_pause`` +callback when this happens (the callback is optional — partial handlers +without it continue to work). + +### Audit Log + +Distinct from the ``Checkpointer`` (which stores the *latest* state for +crash recovery), an ``AuditLog`` is an append-only record of *every* node +visit for compliance, debugging, and replay. Wire one in via the +``audit_log`` kwarg: + +```python +from fireflyframework_agentic.pipeline import ( + PipelineBuilder, FileAuditLog, PostgresAuditLog, LoggingAuditLog, OtelAuditLog, +) + +PipelineBuilder("agent", state=AgentState, audit_log=FileAuditLog("./audit")) +``` + +Four backends ship, each conforming to the ``AuditLog`` Protocol: + +| Backend | Use when | Read API | Trace-correlated | Install | +|---|---|---|---|---| +| ``FileAuditLog`` | Dev / single-host | yes | no | (default) | +| ``PostgresAuditLog`` | Compliance, retention, cross-run queries | yes | no | ``[postgres]`` | +| ``LoggingAuditLog`` | Generic log stacks (Splunk-HEC, Loki, JSON-logging) | no (write-only) | no | (default — stdlib) | +| ``OtelAuditLog`` | OTel-native stacks (Application Insights, Datadog APM, OTel Collector) | no (write-only) | **yes** | ``opentelemetry-sdk`` | + +``FileAuditLog`` and ``PostgresAuditLog`` also implement +``QueryableAuditLog`` with ``list_entries(pipeline_name, run_id)``. The +write-only backends delegate query/search to the user's existing +observability stack. + +Audit-log write failures are non-fatal — logged but never abort the +pipeline. 
+ ### Mermaid Export `StatePipeline.to_mermaid()` and `DAG.to_mermaid()` render the topology as a diff --git a/examples/pipeline_state.py b/examples/pipeline_state.py index b5fba799..ee9bb449 100644 --- a/examples/pipeline_state.py +++ b/examples/pipeline_state.py @@ -51,7 +51,9 @@ from pydantic import BaseModel from fireflyframework_agentic.pipeline import ( + FileAuditLog, FileCheckpointer, + Pause, PipelineBuilder, PostgresCheckpointer, Send, @@ -172,6 +174,9 @@ async def on_node_complete(self, pipeline_name: str, run_id: str, node_id: str, async def on_node_error(self, pipeline_name: str, run_id: str, node_id: str, error: str) -> None: print(f" ✗ {node_id}: {error}") + async def on_node_pause(self, pipeline_name: str, run_id: str, node_id: str, reason: str) -> None: + print(f" ⏸ {node_id} paused: {reason}") + async def on_pipeline_complete(self, pipeline_name: str, run_id: str, success: bool, duration_ms: float) -> None: status = "OK" if success else "FAILED" print(f" ═ [{pipeline_name}] {status} in {duration_ms:.0f}ms") @@ -299,11 +304,75 @@ async def run_software_factory_postgres() -> None: print(f" eval: {second.state.evaluation}\n") +class HitlState(BaseModel): + """State threaded through a deploy pipeline gated by human approval.""" + + target_env: str + artifact: str | None = None + deployed_to: str | None = None + + +async def build_artifact(state: HitlState) -> dict: + return {"artifact": f"build-{state.target_env}.tar.gz"} + + +async def await_approval(state: HitlState) -> Pause: + return Pause(reason=f"awaiting human approval to deploy {state.artifact} to {state.target_env}") + + +async def deploy_artifact(state: HitlState) -> dict: + return {"deployed_to": f"https://{state.target_env}.example.com"} + + +async def run_hitl_with_audit() -> None: + print("=== 5. 
Human-in-the-loop deploy gate with audit log ===\n") + + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + ckpt = FileCheckpointer(root / "ckpt") + audit = FileAuditLog(root / "audit") + pipeline = ( + PipelineBuilder( + "hitl-deploy", + state=HitlState, + checkpointer=ckpt, + audit_log=audit, + ) + .add_node(build_artifact) + .add_node(await_approval) + .add_node(deploy_artifact) + .chain(build_artifact, await_approval, deploy_artifact) + .build() + ) + + # First run halts at the approval gate. + first = await pipeline.invoke(HitlState(target_env="prod")) + print(f" first run: paused={first.paused}, paused_node={first.paused_node}") + print(f" reason: {first.pause_reason}") + print(f" run_id: {first.run_id}\n") + + # ...time passes; a human reviews and approves... + print(" (human reviews and approves)\n") + + # Resume with explicit approval. + done = await pipeline.invoke(run_id=first.run_id, approve_pause=True) + print(f" resumed: success={done.success}, deployed_to={done.state.deployed_to}") + print(f" completed: {done.completed_nodes}\n") + + # Audit log captures every node visit with its status. + entries = audit.list_entries("hitl-deploy", first.run_id) + print(" audit trail:") + for e in entries: + extra = f" reason={e.pause_reason!r}" if e.pause_reason else "" + print(f" seq={e.sequence} node={e.node_id} status={e.status}{extra}") + + async def main() -> None: await run_branching() await run_software_factory() await run_map_reduce() await run_software_factory_postgres() + await run_hitl_with_audit() if __name__ == "__main__": diff --git a/fireflyframework_agentic/pipeline/__init__.py b/fireflyframework_agentic/pipeline/__init__.py index af4778df..6f4dbe6c 100644 --- a/fireflyframework_agentic/pipeline/__init__.py +++ b/fireflyframework_agentic/pipeline/__init__.py @@ -28,6 +28,15 @@ via :class:`Checkpointer` enables resume after failure and mid-pipeline start. 
""" +from fireflyframework_agentic.pipeline.audit import ( + AuditEntry, + AuditLog, + FileAuditLog, + LoggingAuditLog, + OtelAuditLog, + PostgresAuditLog, + QueryableAuditLog, +) from fireflyframework_agentic.pipeline.builder import PipelineBuilder from fireflyframework_agentic.pipeline.checkpoint import ( Checkpointer, @@ -46,6 +55,7 @@ from fireflyframework_agentic.pipeline.reducers import append, extend, merge_dict, replace from fireflyframework_agentic.pipeline.result import ExecutionTraceEntry, NodeResult, PipelineResult from fireflyframework_agentic.pipeline.state_pipeline import ( + Pause, RecursionLimitError, Send, StatePipeline, @@ -67,6 +77,8 @@ __all__ = [ "DAG", "AgentStep", + "AuditEntry", + "AuditLog", "BatchLLMStep", "BranchStep", "CallableStep", @@ -79,14 +91,20 @@ "FailureStrategy", "FanInStep", "FanOutStep", + "FileAuditLog", "FileCheckpointer", + "LoggingAuditLog", "NodeResult", - "PostgresCheckpointer", + "OtelAuditLog", + "Pause", "PipelineBuilder", "PipelineContext", "PipelineEngine", "PipelineEventHandler", "PipelineResult", + "PostgresAuditLog", + "PostgresCheckpointer", + "QueryableAuditLog", "ReasoningStep", "RecursionLimitError", "RedisCheckpointer", diff --git a/fireflyframework_agentic/pipeline/audit.py b/fireflyframework_agentic/pipeline/audit.py new file mode 100644 index 00000000..481ea08d --- /dev/null +++ b/fireflyframework_agentic/pipeline/audit.py @@ -0,0 +1,349 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Append-only audit logs for state-pipeline node visits. + +Distinct from :mod:`fireflyframework_agentic.pipeline.checkpoint` — the +checkpointer stores the *latest* state for crash recovery; the audit log +stores *every* node visit for compliance, debugging, and replay. + +Four backends ship: + +* :class:`FileAuditLog` — one JSONL file per ``(pipeline_name, run_id)``. + Best for dev / single-host audit trails. +* :class:`PostgresAuditLog` — single ``firefly_audit`` table; reuses the + ``psycopg`` connection from the ``postgres`` optional extra (same dep + added in Phase 3a for ``PostgresCheckpointer``). +* :class:`LoggingAuditLog` — stdlib ``logging``; pairs with whatever log + aggregation pipeline (Splunk-HEC, Loki, Datadog, OTel-LoggingHandler-bridge) + the host application already runs. No new optional dep. +* :class:`OtelAuditLog` — direct OTel logs API; attaches trace correlation + (``trace_id``/``span_id``) automatically. Best for OTel-native stacks + (Application Insights, Datadog APM, OTel-Collector). + +File and Postgres also implement :class:`QueryableAuditLog` (``list_entries``); +Logging and OTel are write-only — query your observability stack instead. 
+""" + +from __future__ import annotations + +import json +import logging +from datetime import datetime +from pathlib import Path +from typing import Any, Literal, Protocol, runtime_checkable + +from pydantic import BaseModel + +try: + import psycopg as _psycopg # type: ignore[import-not-found] +except ImportError: # pragma: no cover - optional dep + _psycopg = None # type: ignore[assignment] + +try: + from opentelemetry._logs import LogRecord as _OtelLogRecord # type: ignore[import-not-found] + from opentelemetry._logs import SeverityNumber as _OtelSeverityNumber # type: ignore[import-not-found] + from opentelemetry._logs import get_logger as _otel_get_logger # type: ignore[import-not-found] +except ImportError: # pragma: no cover - optional dep + _otel_get_logger = None # type: ignore[assignment] + _OtelLogRecord = None # type: ignore[assignment,misc] + _OtelSeverityNumber = None # type: ignore[assignment,misc] + +logger = logging.getLogger(__name__) + + +AuditStatus = Literal["success", "error", "paused"] + + +class AuditEntry(BaseModel): + """A single audit record — one per node visit (success, error, or pause).""" + + pipeline_name: str + run_id: str + node_id: str + sequence: int + visit: int + started_at: datetime + completed_at: datetime + latency_ms: float + status: AuditStatus + inputs_snapshot: dict[str, Any] + outputs_snapshot: dict[str, Any] + error_message: str | None = None + pause_reason: str | None = None + + +@runtime_checkable +class AuditLog(Protocol): + """Write-only audit log. Every backend implements this method. + + Implementations must be safe to call from async code (called inside the + state pipeline's executor) but the method itself is sync. + """ + + def record(self, entry: AuditEntry) -> None: ... + + +@runtime_checkable +class QueryableAuditLog(AuditLog, Protocol): + """Audit log that also supports reading back recorded entries. + + File and Postgres backends implement this. 
Logging and OTel backends do + not — query your observability stack (Splunk / Datadog / Loki / etc.) + instead. + """ + + def list_entries(self, pipeline_name: str, run_id: str) -> list[AuditEntry]: ... + + +class FileAuditLog: + """Filesystem-backed audit log. Layout:: + + <root>/<pipeline_name>/<run_id>.jsonl + + Each line is a JSON-serialized :class:`AuditEntry`. Appends are atomic + at the line level (single ``write`` call per entry); concurrent writers + to the same run_id may interleave at line boundaries but never within a + single entry. + """ + + def __init__(self, root: str | Path) -> None: + self._root = Path(root) + self._root.mkdir(parents=True, exist_ok=True) + + def _path(self, pipeline_name: str, run_id: str) -> Path: + return self._root / pipeline_name / f"{run_id}.jsonl" + + def record(self, entry: AuditEntry) -> None: + path = self._path(entry.pipeline_name, entry.run_id) + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("a", encoding="utf-8") as f: + f.write(entry.model_dump_json() + "\n") + + def list_entries(self, pipeline_name: str, run_id: str) -> list[AuditEntry]: + path = self._path(pipeline_name, run_id) + if not path.exists(): + return [] + entries: list[AuditEntry] = [] + with path.open(encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + entries.append(AuditEntry.model_validate(json.loads(line))) + return entries + + +class PostgresAuditLog: + """Postgres-backed audit log. Single table created lazily on the first ``record`` or ``list_entries`` call. + + Reuses ``psycopg`` from the ``postgres`` optional extra. ``dsn`` or a + pre-built ``connection`` is required, mirroring :class:`PostgresCheckpointer`'s + API for consistency. 
+ """ + + _DDL = """ + CREATE TABLE IF NOT EXISTS {table} ( + pipeline_name TEXT NOT NULL, + run_id TEXT NOT NULL, + sequence INT NOT NULL, + visit INT NOT NULL, + node_id TEXT NOT NULL, + started_at TIMESTAMPTZ NOT NULL, + completed_at TIMESTAMPTZ NOT NULL, + latency_ms DOUBLE PRECISION NOT NULL, + status TEXT NOT NULL, + inputs_snapshot JSONB NOT NULL, + outputs_snapshot JSONB NOT NULL, + error_message TEXT, + pause_reason TEXT, + PRIMARY KEY (pipeline_name, run_id, sequence) + ); + CREATE INDEX IF NOT EXISTS {table}_run_idx + ON {table} (pipeline_name, run_id); + """ + + def __init__( + self, + *, + dsn: str | None = None, + connection: Any = None, + table_name: str = "firefly_audit", + ) -> None: + if _psycopg is None: + raise ImportError( + "PostgresAuditLog requires the 'postgres' extra. " + "Install with: pip install fireflyframework-agentic[postgres]" + ) + if (dsn is None) == (connection is None): + raise ValueError("PostgresAuditLog needs exactly one of `dsn` or `connection`.") + if not table_name.replace("_", "").isalnum(): + raise ValueError(f"Invalid table_name {table_name!r}: must be alphanumeric/underscore only.") + self._conn = connection if connection is not None else _psycopg.connect(dsn, autocommit=True) + self._table = table_name + self._ddl_applied = False + + def _ensure_table(self) -> None: + if self._ddl_applied: + return + with self._conn.cursor() as cur: + cur.execute(self._DDL.format(table=self._table)) + self._ddl_applied = True + + def record(self, entry: AuditEntry) -> None: + self._ensure_table() + sql = ( + f"INSERT INTO {self._table} " + "(pipeline_name, run_id, sequence, visit, node_id, started_at, completed_at, " + "latency_ms, status, inputs_snapshot, outputs_snapshot, error_message, pause_reason) " + "VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) " + "ON CONFLICT (pipeline_name, run_id, sequence) DO NOTHING" + ) + with self._conn.cursor() as cur: + cur.execute( + sql, + ( + entry.pipeline_name, + entry.run_id, + 
entry.sequence, + entry.visit, + entry.node_id, + entry.started_at, + entry.completed_at, + entry.latency_ms, + entry.status, + json.dumps(entry.inputs_snapshot), + json.dumps(entry.outputs_snapshot), + entry.error_message, + entry.pause_reason, + ), + ) + + def list_entries(self, pipeline_name: str, run_id: str) -> list[AuditEntry]: + self._ensure_table() + sql = ( + f"SELECT pipeline_name, run_id, sequence, visit, node_id, started_at, completed_at, " + f"latency_ms, status, inputs_snapshot, outputs_snapshot, error_message, pause_reason " + f"FROM {self._table} WHERE pipeline_name = %s AND run_id = %s ORDER BY sequence" + ) + with self._conn.cursor() as cur: + cur.execute(sql, (pipeline_name, run_id)) + rows = cur.fetchall() + entries: list[AuditEntry] = [] + for row in rows: + entries.append( + AuditEntry( + pipeline_name=row[0], + run_id=row[1], + sequence=row[2], + visit=row[3], + node_id=row[4], + started_at=row[5], + completed_at=row[6], + latency_ms=row[7], + status=row[8], + inputs_snapshot=json.loads(row[9]) if isinstance(row[9], str) else row[9], + outputs_snapshot=json.loads(row[10]) if isinstance(row[10], str) else row[10], + error_message=row[11], + pause_reason=row[12], + ) + ) + return entries + + +class LoggingAuditLog: + """Audit log backed by Python's stdlib ``logging`` module. + + Each entry is emitted as a structured log record with the full + :class:`AuditEntry` available under ``record.firefly_audit``. Pairs + naturally with any log-aggregation pipeline the host already configures: + + * OTel collector via ``opentelemetry.sdk._logs.LoggingHandler`` + * Splunk via the official Splunk handler + * Loki via promtail tailing stdout + * Datadog via ``ddtrace`` log injection + * Plain JSON-logging via ``python-json-logger`` + + No new dependency. The user wires their log handler exactly once for the + whole application and audit entries flow there automatically. 
+ """ + + def __init__(self, logger_name: str = "firefly.audit", level: int = logging.INFO) -> None: + self._logger = logging.getLogger(logger_name) + self._level = level + + def record(self, entry: AuditEntry) -> None: + self._logger.log( + self._level, + "firefly_audit: pipeline=%s run=%s node=%s status=%s latency_ms=%.1f", + entry.pipeline_name, + entry.run_id, + entry.node_id, + entry.status, + entry.latency_ms, + extra={"firefly_audit": entry.model_dump(mode="json")}, + ) + + +class OtelAuditLog: + """Audit log backed by the OTel logs API. + + Emits each entry as a structured OTel log record with attributes + matching :class:`AuditEntry` fields. When an OTel trace is active the + record is automatically correlated with the current ``trace_id`` and + ``span_id`` — useful for tying audit history to the spans Phase 3b + emits. + + Requires ``opentelemetry-sdk`` to be installed and a ``LoggerProvider`` + configured by the host application (the framework does not install a + provider). + """ + + def __init__(self, logger_name: str = "fireflyframework_agentic.audit") -> None: + if _otel_get_logger is None: + raise ImportError( + "OtelAuditLog requires the 'opentelemetry-sdk' package " + "(with the logs API). Install: pip install opentelemetry-sdk" + ) + self._otel_logger = _otel_get_logger(logger_name) + + def record(self, entry: AuditEntry) -> None: + # The constructor's guard on _otel_get_logger guarantees the OTel + # logs imports succeeded, so the LogRecord and SeverityNumber are + # not None here. 
+ assert _OtelLogRecord is not None and _OtelSeverityNumber is not None + attrs: dict[str, Any] = { + "firefly.pipeline": entry.pipeline_name, + "firefly.run_id": entry.run_id, + "firefly.node": entry.node_id, + "firefly.sequence": entry.sequence, + "firefly.visit": entry.visit, + "firefly.latency_ms": entry.latency_ms, + "firefly.status": entry.status, + } + if entry.error_message: + attrs["firefly.error"] = entry.error_message + if entry.pause_reason: + attrs["firefly.pause_reason"] = entry.pause_reason + + body = f"[{entry.pipeline_name}] {entry.node_id} {entry.status} ({entry.latency_ms:.0f}ms)" + severity = _OtelSeverityNumber.ERROR if entry.status == "error" else _OtelSeverityNumber.INFO + log_record = _OtelLogRecord( + timestamp=int(entry.completed_at.timestamp() * 1_000_000_000), + severity_number=severity, + severity_text=severity.name, + body=body, + attributes=attrs, + ) + self._otel_logger.emit(log_record) diff --git a/fireflyframework_agentic/pipeline/builder.py b/fireflyframework_agentic/pipeline/builder.py index 2975d59f..512bd06d 100644 --- a/fireflyframework_agentic/pipeline/builder.py +++ b/fireflyframework_agentic/pipeline/builder.py @@ -53,6 +53,7 @@ from pydantic import BaseModel from fireflyframework_agentic.exceptions import PipelineError +from fireflyframework_agentic.pipeline.audit import AuditLog from fireflyframework_agentic.pipeline.checkpoint import Checkpointer from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode, FailureStrategy from fireflyframework_agentic.pipeline.engine import PipelineEngine, StatePipelineEventHandler @@ -87,6 +88,7 @@ def __init__( checkpointer: Checkpointer | None = None, recursion_limit: int = 25, event_handler: StatePipelineEventHandler | None = None, + audit_log: AuditLog | None = None, ) -> None: # State pipelines may use cyclic graphs (ReAct loops, retry-with-critique). # The legacy port-based path keeps acyclicity as an invariant. 
@@ -96,6 +98,7 @@ def __init__( self._checkpointer = checkpointer self._recursion_limit = recursion_limit self._event_handler = event_handler + self._audit_log = audit_log self._pending_nodes: list[DAGNode] = [] self._pending_edges: list[DAGEdge] = [] # State-based mode bookkeeping. Keyed by node id. @@ -270,6 +273,7 @@ def build(self) -> PipelineEngine | StatePipeline: checkpointer=self._checkpointer, recursion_limit=self._recursion_limit, event_handler=self._event_handler, + audit_log=self._audit_log, ) return PipelineEngine(self._dag) diff --git a/fireflyframework_agentic/pipeline/checkpoint.py b/fireflyframework_agentic/pipeline/checkpoint.py index 3a63d8c1..ecedb16c 100644 --- a/fireflyframework_agentic/pipeline/checkpoint.py +++ b/fireflyframework_agentic/pipeline/checkpoint.py @@ -51,7 +51,13 @@ class CheckpointRecord(BaseModel): - """One saved checkpoint.""" + """One saved checkpoint. + + ``paused`` and ``pause_reason`` are set when a node returns + :class:`fireflyframework_agentic.pipeline.state_pipeline.Pause`. Default + to ``False`` / ``None`` so existing records from earlier phases load + cleanly under the new schema. + """ pipeline_name: str run_id: str @@ -59,6 +65,8 @@ class CheckpointRecord(BaseModel): sequence: int state: dict[str, Any] completed_nodes: list[str] + paused: bool = False + pause_reason: str | None = None @runtime_checkable diff --git a/fireflyframework_agentic/pipeline/engine.py b/fireflyframework_agentic/pipeline/engine.py index 9f7e3d0f..4f1e46e6 100644 --- a/fireflyframework_agentic/pipeline/engine.py +++ b/fireflyframework_agentic/pipeline/engine.py @@ -101,6 +101,11 @@ async def on_node_error(self, pipeline_name: str, run_id: str, node_id: str, err """Called when a node raises an exception.""" ... 
+ async def on_node_pause(self, pipeline_name: str, run_id: str, node_id: str, reason: str) -> None: + """Called when a node returns :class:`Pause`, halting the pipeline + until an external ``invoke(run_id=..., approve_pause=True)`` resumes it.""" + ... + async def on_pipeline_complete(self, pipeline_name: str, run_id: str, success: bool, duration_ms: float) -> None: """Called once when ``invoke`` returns.""" ... diff --git a/fireflyframework_agentic/pipeline/state_pipeline.py b/fireflyframework_agentic/pipeline/state_pipeline.py index d33c37b7..823e4928 100644 --- a/fireflyframework_agentic/pipeline/state_pipeline.py +++ b/fireflyframework_agentic/pipeline/state_pipeline.py @@ -31,11 +31,13 @@ import uuid from collections.abc import Awaitable, Callable from dataclasses import dataclass +from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, get_type_hints from pydantic import BaseModel from fireflyframework_agentic.exceptions import PipelineError +from fireflyframework_agentic.pipeline.audit import AuditEntry, AuditLog, AuditStatus from fireflyframework_agentic.pipeline.checkpoint import Checkpointer, CheckpointRecord from fireflyframework_agentic.pipeline.dag import DAG, _mermaid_id from fireflyframework_agentic.pipeline.engine import start_otel_span @@ -71,6 +73,31 @@ class RecursionLimitError(Exception): """Raised when a node is visited more times than ``recursion_limit`` permits.""" +@dataclass +class Pause: + """Human-in-the-loop sentinel returned by a node to halt the pipeline. + + A node returns ``Pause(reason="...")`` when external approval (a human, + another system, a wall-clock event) is required before the pipeline may + continue. The pipeline then: + + 1. Writes a checkpoint with ``paused=True`` and the reason set. + 2. Emits ``on_node_pause`` on the configured event handler. + 3. Returns a :class:`StatePipelineResult` with ``paused=True`` and + ``success=False`` — the run is not finished, but it did not fail either. 
+ + To resume after approval:: + + result = await pipeline.invoke(run_id=paused_run_id, approve_pause=True) + + Without ``approve_pause=True``, resuming a paused run raises + :class:`PipelineError`. The successor of the paused node runs next — + the pause node itself is not re-executed. + """ + + reason: str + + @dataclass class BranchSpec: """Internal: registered branch from one source node.""" @@ -91,6 +118,10 @@ class StatePipelineResult: success: True iff all attempted nodes completed without error. error: Last error message if ``success`` is False. failed_node: Node ID that failed, if any. + paused: True if the run halted on a :class:`Pause` sentinel; resume + via ``invoke(run_id=..., approve_pause=True)``. + paused_node: Node that returned ``Pause`` if ``paused`` is True. + pause_reason: Reason string the paused node passed to ``Pause(...)``. """ state: Any @@ -99,6 +130,9 @@ class StatePipelineResult: success: bool error: str | None = None failed_node: str | None = None + paused: bool = False + paused_node: str | None = None + pause_reason: str | None = None def discover_reducers(state_schema: type) -> dict[str, Reducer]: @@ -157,6 +191,7 @@ def __init__( checkpointer: Checkpointer | None = None, recursion_limit: int = 25, event_handler: StatePipelineEventHandler | None = None, + audit_log: AuditLog | None = None, ) -> None: self._name = name self._dag = dag @@ -166,9 +201,52 @@ def __init__( self._checkpointer = checkpointer self._recursion_limit = recursion_limit self._event_handler = event_handler + self._audit_log = audit_log self._reducers = discover_reducers(state_schema) self._validate() + def _audit( + self, + *, + run_id: str, + node_id: str, + sequence: int, + visit: int, + started_at: datetime, + completed_at: datetime, + latency_ms: float, + status: AuditStatus, + inputs_snapshot: dict[str, Any], + outputs_snapshot: dict[str, Any], + error_message: str | None = None, + pause_reason: str | None = None, + ) -> None: + """Construct and write an 
:class:`AuditEntry`. No-op if no audit log is configured. + + Audit-write failures are non-fatal — logged and swallowed. + """ + if self._audit_log is None: + return + entry = AuditEntry( + pipeline_name=self._name, + run_id=run_id, + node_id=node_id, + sequence=sequence, + visit=visit, + started_at=started_at, + completed_at=completed_at, + latency_ms=latency_ms, + status=status, + inputs_snapshot=inputs_snapshot, + outputs_snapshot=outputs_snapshot, + error_message=error_message, + pause_reason=pause_reason, + ) + try: + self._audit_log.record(entry) + except Exception: + logger.exception("Audit log write failed for run '%s' at '%s'", run_id, node_id) + async def _finalize_run( self, result: StatePipelineResult, @@ -335,6 +413,7 @@ async def invoke( *, run_id: str | None = None, start_at: str | Callable[..., Any] | None = None, + approve_pause: bool = False, ) -> StatePipelineResult: """Run the pipeline. @@ -353,9 +432,15 @@ async def invoke( record = self._checkpointer.load_latest(self._name, run_id) if record is None: raise PipelineError(f"No checkpoint found for run_id='{run_id}'") + # A paused run requires explicit approval before continuing. + if record.paused and not approve_pause: + raise PipelineError( + f"Run '{run_id}' is paused at node '{record.node_id}' " + f"(reason: {record.pause_reason!r}). Pass approve_pause=True to resume." + ) state = self._state_schema.model_validate(record.state) resumed_completed = list(record.completed_nodes) - # Resume at the successor of the last completed node. + # Resume at the successor of the last completed (or paused) node. last = record.node_id next_node = self._next_step(last, state) # Resume can't seamlessly continue mid-fan-out yet; treat fan-out as terminal here. 
@@ -460,6 +545,8 @@ async def invoke( fn = self._node_fns[node_id] node_span = start_otel_span(f"pipeline.state.node.{node_id}", node=node_id, visit=visit_n) await self._emit("on_node_start", self._name, run_id, node_id, visit_n) + inputs_snapshot = state.model_dump(mode="json") + started_at = datetime.now(UTC) t0 = time.perf_counter() try: update = await fn(state) @@ -474,6 +561,19 @@ async def invoke( if node_span is not None: with contextlib.suppress(Exception): node_span.end() + self._audit( + run_id=run_id, + node_id=node_id, + sequence=sequence + 1, + visit=visit_n, + started_at=started_at, + completed_at=datetime.now(UTC), + latency_ms=(time.perf_counter() - t0) * 1000, + status="error", + inputs_snapshot=inputs_snapshot, + outputs_snapshot={}, + error_message=str(exc), + ) return await self._finalize_run( StatePipelineResult( state=state, @@ -488,9 +588,56 @@ async def invoke( run_id, ) elapsed = (time.perf_counter() - t0) * 1000 + completed_at = datetime.now(UTC) if node_span is not None: with contextlib.suppress(Exception): node_span.end() + + # HITL: a node returning Pause halts the pipeline and writes a + # paused checkpoint. Approval comes via invoke(approve_pause=True). 
+ if isinstance(update, Pause): + pause_reason = update.reason + await self._emit("on_node_pause", self._name, run_id, node_id, pause_reason) + completed.append(node_id) + sequence += 1 + self._save_checkpoint( + run_id, + node_id, + sequence, + state, + completed, + paused=True, + pause_reason=pause_reason, + ) + self._audit( + run_id=run_id, + node_id=node_id, + sequence=sequence, + visit=visit_n, + started_at=started_at, + completed_at=completed_at, + latency_ms=elapsed, + status="paused", + inputs_snapshot=inputs_snapshot, + outputs_snapshot=state.model_dump(mode="json"), + pause_reason=pause_reason, + ) + logger.info("Pipeline '%s' paused at node '%s': %s", self._name, node_id, pause_reason) + return await self._finalize_run( + StatePipelineResult( + state=state, + run_id=run_id, + completed_nodes=completed, + success=False, + paused=True, + paused_node=node_id, + pause_reason=pause_reason, + ), + pipeline_span, + pipeline_start_time, + run_id, + ) + await self._emit("on_node_complete", self._name, run_id, node_id, elapsed) logger.debug("Pipeline '%s' node '%s' completed in %.1fms", self._name, node_id, elapsed) @@ -500,6 +647,18 @@ async def invoke( completed.append(node_id) sequence += 1 self._save_checkpoint(run_id, node_id, sequence, state, completed) + self._audit( + run_id=run_id, + node_id=node_id, + sequence=sequence, + visit=visit_n, + started_at=started_at, + completed_at=completed_at, + latency_ms=elapsed, + status="success", + inputs_snapshot=inputs_snapshot, + outputs_snapshot=state.model_dump(mode="json"), + ) try: next_step = self._next_step(node_id, state) @@ -604,6 +763,9 @@ def _save_checkpoint( sequence: int, state: BaseModel, completed: list[str], + *, + paused: bool = False, + pause_reason: str | None = None, ) -> None: """Persist state via the configured checkpointer (no-op if absent).""" if self._checkpointer is None: @@ -617,6 +779,8 @@ def _save_checkpoint( sequence=sequence, state=state.model_dump(), completed_nodes=list(completed), + 
# tests/unit/pipeline/test_audit_log.py
# Copyright 2026 Firefly Software Foundation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

"""Phase-3c audit-log tests — File / Postgres / Logging / OTel backends + pipeline wiring."""

from __future__ import annotations

import json
import logging
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock

import pytest
from pydantic import BaseModel

import fireflyframework_agentic.pipeline.audit as audit_module
from fireflyframework_agentic.pipeline import (
    AuditEntry,
    FileAuditLog,
    FileCheckpointer,
    LoggingAuditLog,
    OtelAuditLog,
    Pause,
    PipelineBuilder,
    PostgresAuditLog,
)

# Column order shared by the fake INSERT parameters and the fake SELECT rows.
# Keeping it in one place guarantees the two directions cannot drift apart.
_AUDIT_COLUMNS: tuple[str, ...] = (
    "pipeline_name",
    "run_id",
    "sequence",
    "visit",
    "node_id",
    "started_at",
    "completed_at",
    "latency_ms",
    "status",
    "inputs_snapshot",
    "outputs_snapshot",
    "error_message",
    "pause_reason",
)

# Columns whose INSERT parameter may arrive as a JSON string and is decoded on store.
_JSON_COLUMNS: tuple[str, ...] = ("inputs_snapshot", "outputs_snapshot")


def _entry(**overrides: Any) -> AuditEntry:
    """Return an ``AuditEntry`` built from fixed defaults, with ``overrides`` applied on top."""
    defaults: dict[str, Any] = {
        "pipeline_name": "p",
        "run_id": "r",
        "node_id": "n",
        "sequence": 1,
        "visit": 1,
        "started_at": datetime(2026, 5, 27, tzinfo=UTC),
        "completed_at": datetime(2026, 5, 27, 0, 0, 1, tzinfo=UTC),
        "latency_ms": 100.0,
        "status": "success",
        "inputs_snapshot": {"x": 1},
        "outputs_snapshot": {"y": 2},
    }
    return AuditEntry(**{**defaults, **overrides})


# =============================================================================
# Shared fixtures
# =============================================================================


@pytest.fixture(autouse=True)
def _stub_optional_deps(monkeypatch: pytest.MonkeyPatch) -> None:
    """Stub _psycopg and OTel symbols so backends can be constructed with mocks.

    Each stub is installed only when the corresponding module-level symbol in
    ``audit_module`` is ``None``.
    """
    if audit_module._psycopg is None:
        monkeypatch.setattr(audit_module, "_psycopg", MagicMock(name="psycopg_stub"))
    if audit_module._otel_get_logger is None:
        monkeypatch.setattr(audit_module, "_otel_get_logger", MagicMock(name="otel_logger_factory"))
        monkeypatch.setattr(audit_module, "_OtelLogRecord", MagicMock(name="LogRecord"))
        sev = MagicMock(name="SeverityNumber")
        # MagicMock(name=...) only sets the repr name, so the ``.name``
        # attribute is assigned explicitly afterwards.
        sev.ERROR = MagicMock(name="ERROR")
        sev.ERROR.name = "ERROR"
        sev.INFO = MagicMock(name="INFO")
        sev.INFO.name = "INFO"
        monkeypatch.setattr(audit_module, "_OtelSeverityNumber", sev)


# =============================================================================
# FileAuditLog
# =============================================================================


def test_file_audit_log_writes_jsonl_per_run(tmp_path: Path) -> None:
    """Two records for pipeline ``p`` / run ``r`` become two JSON lines in ``p/r.jsonl``."""
    log = FileAuditLog(tmp_path)
    log.record(_entry(sequence=1, node_id="a"))
    log.record(_entry(sequence=2, node_id="b"))

    path = tmp_path / "p" / "r.jsonl"
    assert path.exists()
    # Explicit encoding so the read does not depend on the locale default.
    lines = path.read_text(encoding="utf-8").strip().splitlines()
    assert len(lines) == 2
    assert json.loads(lines[0])["node_id"] == "a"
    assert json.loads(lines[1])["node_id"] == "b"


def test_file_audit_log_list_entries_round_trips(tmp_path: Path) -> None:
    """Entries written via ``record`` come back from ``list_entries`` in write order."""
    log = FileAuditLog(tmp_path)
    for seq, node in [(1, "a"), (2, "b"), (3, "c")]:
        log.record(_entry(sequence=seq, node_id=node))
    entries = log.list_entries("p", "r")
    assert [e.node_id for e in entries] == ["a", "b", "c"]


def test_file_audit_log_unknown_run_returns_empty(tmp_path: Path) -> None:
    """Listing a run that was never written yields an empty list."""
    assert FileAuditLog(tmp_path).list_entries("p", "missing") == []


# =============================================================================
# PostgresAuditLog
# =============================================================================


def test_postgres_audit_missing_dep_raises(monkeypatch: pytest.MonkeyPatch) -> None:
    """Constructing the backend with ``_psycopg`` unset raises ``ImportError`` naming the extra."""
    monkeypatch.setattr(audit_module, "_psycopg", None)
    with pytest.raises(ImportError, match=r"\[postgres\]"):
        PostgresAuditLog(dsn="postgresql://x")


def _pg_conn_mock() -> tuple[MagicMock, dict[tuple[str, str, int], dict[str, Any]]]:
    """MagicMock connection backed by an in-memory dict keyed by (pipeline,run,seq).

    Returns:
        ``(conn, store)``. ``conn.cursor()`` yields context-manager cursors whose
        ``execute`` understands three statement prefixes (case-insensitive):
        ``create table`` (recorded in ``conn._ddl_calls``), ``insert into``
        (stored in ``store``) and ``select`` (rows for the given pipeline/run,
        sorted by sequence, made available through ``fetchall``). Any other
        statement raises ``AssertionError``.
    """
    store: dict[tuple[str, str, int], dict[str, Any]] = {}
    ddl_calls: list[str] = []
    conn = MagicMock(name="psycopg.Connection")

    def make_cursor() -> MagicMock:
        cur = MagicMock()
        cur.__enter__ = MagicMock(return_value=cur)
        cur.__exit__ = MagicMock(return_value=None)
        cur._last_one = None
        cur._last_all = []

        def fake_execute(sql: str, params: tuple | None = None) -> None:
            s = sql.strip().lower()
            if s.startswith("create table"):
                ddl_calls.append(sql)
                return
            if s.startswith("insert into"):
                assert params is not None
                # Positional decode: a short parameter tuple still fails with
                # IndexError, extra trailing parameters are ignored.
                row = {name: params[i] for i, name in enumerate(_AUDIT_COLUMNS)}
                for name in _JSON_COLUMNS:
                    if isinstance(row[name], str):
                        row[name] = json.loads(row[name])
                store[(row["pipeline_name"], row["run_id"], row["sequence"])] = row
                return
            if s.startswith("select"):
                assert params is not None
                rows = [v for k, v in store.items() if k[0] == params[0] and k[1] == params[1]]
                rows.sort(key=lambda r: r["sequence"])
                cur._last_all = [tuple(r[name] for name in _AUDIT_COLUMNS) for r in rows]
                return
            raise AssertionError(f"unexpected SQL: {sql}")

        cur.execute.side_effect = fake_execute
        cur.fetchone.side_effect = lambda: cur._last_one
        cur.fetchall.side_effect = lambda: cur._last_all
        return cur

    conn.cursor.side_effect = make_cursor
    conn._ddl_calls = ddl_calls
    return conn, store


def test_postgres_audit_ddl_once_then_inserts() -> None:
    """Three records issue exactly one ``CREATE TABLE`` and store three rows."""
    conn, store = _pg_conn_mock()
    log = PostgresAuditLog(connection=conn)
    for seq in (1, 2, 3):
        log.record(_entry(sequence=seq, node_id=f"n{seq}"))
    assert len(conn._ddl_calls) == 1
    assert len(store) == 3


def test_postgres_audit_list_entries_orders_by_sequence() -> None:
    """Entries recorded out of order are listed in ascending sequence."""
    conn, _ = _pg_conn_mock()
    log = PostgresAuditLog(connection=conn)
    for seq in (3, 1, 2):
        log.record(_entry(sequence=seq, node_id=f"n{seq}"))
    entries = log.list_entries("p", "r")
    assert [e.sequence for e in entries] == [1, 2, 3]


def test_postgres_audit_rejects_bad_table_name() -> None:
    """A table name containing SQL punctuation is rejected with ``ValueError``."""
    with pytest.raises(ValueError, match="Invalid table_name"):
        PostgresAuditLog(connection=MagicMock(), table_name="bad; DROP TABLE")


# =============================================================================
# LoggingAuditLog
# =============================================================================


def test_logging_audit_emits_record_with_firefly_audit_extra(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """One ``record`` call emits one log record carrying a ``firefly_audit`` extra."""
    log = LoggingAuditLog(logger_name="firefly.test_audit")
    with caplog.at_level(logging.INFO, logger="firefly.test_audit"):
        log.record(_entry(node_id="z", status="success"))
    assert len(caplog.records) == 1
    rec = caplog.records[0]
    assert "firefly_audit" in rec.__dict__
    assert rec.__dict__["firefly_audit"]["node_id"] == "z"
    assert rec.__dict__["firefly_audit"]["status"] == "success"


# =============================================================================
# OtelAuditLog
# =============================================================================


def test_otel_audit_missing_dep_raises(monkeypatch: pytest.MonkeyPatch) -> None:
    """Constructing the backend with ``_otel_get_logger`` unset raises ``ImportError``."""
    monkeypatch.setattr(audit_module, "_otel_get_logger", None)
    with pytest.raises(ImportError, match="opentelemetry-sdk"):
        OtelAuditLog()


def test_otel_audit_emits_log_record_via_otel_logger(monkeypatch: pytest.MonkeyPatch) -> None:
    """The backend obtains its logger from the factory once and calls ``emit`` on it."""
    mock_logger = MagicMock(name="otel_logger")
    factory = MagicMock(name="get_logger", return_value=mock_logger)
    monkeypatch.setattr(audit_module, "_otel_get_logger", factory)

    log = OtelAuditLog()
    log.record(_entry(node_id="a", status="success"))

    factory.assert_called_once()
    assert mock_logger.emit.called, "OtelAuditLog should call logger.emit() with a LogRecord"


# =============================================================================
# Pipeline wiring — audit fires for every node visit
# =============================================================================


class S(BaseModel):
    """Minimal pipeline state used by the wiring tests."""

    log: str = ""


@pytest.mark.asyncio
async def test_pipeline_writes_one_audit_entry_per_node_visit(tmp_path: Path) -> None:
    """A two-node chain produces two ``success`` entries, in execution order."""

    async def a(state: S) -> dict:
        return {"log": "a"}

    async def b(state: S) -> dict:
        return {"log": "b"}

    audit = FileAuditLog(tmp_path)
    pipeline = (
        PipelineBuilder("audit-test", state=S, audit_log=audit)
        .add_node(a)
        .add_node(b)
        .chain(a, b)
        .build()
    )
    result = await pipeline.invoke(S())
    entries = audit.list_entries("audit-test", result.run_id)
    assert [e.node_id for e in entries] == ["a", "b"]
    assert all(e.status == "success" for e in entries)


@pytest.mark.asyncio
async def test_pipeline_audit_captures_error_status(tmp_path: Path) -> None:
    """A raising node yields a single ``error`` entry carrying the exception text."""

    async def boom(state: S) -> dict:
        raise RuntimeError("nope")

    audit = FileAuditLog(tmp_path)
    pipeline = PipelineBuilder("audit-err", state=S, audit_log=audit).add_node(boom).build()
    result = await pipeline.invoke(S())
    entries = audit.list_entries("audit-err", result.run_id)
    assert len(entries) == 1
    assert entries[0].status == "error"
    assert "nope" in (entries[0].error_message or "")


@pytest.mark.asyncio
async def test_pipeline_audit_captures_paused_status(tmp_path: Path) -> None:
    """A node returning ``Pause`` yields a single ``paused`` entry with its reason."""

    async def gate(state: S) -> Pause:
        return Pause(reason="approval please")

    # Audit files and checkpoints live in separate sub-directories of tmp_path.
    audit = FileAuditLog(tmp_path / "audit")
    pipeline = (
        PipelineBuilder(
            "audit-pause",
            state=S,
            audit_log=audit,
            checkpointer=FileCheckpointer(tmp_path / "ckpt"),
        )
        .add_node(gate)
        .build()
    )
    result = await pipeline.invoke(S())
    entries = audit.list_entries("audit-pause", result.run_id)
    assert len(entries) == 1
    assert entries[0].status == "paused"
    assert entries[0].pause_reason == "approval please"


@pytest.mark.asyncio
async def test_audit_write_failure_does_not_abort_pipeline(tmp_path: Path) -> None:
    """A broken audit log shouldn't kill business logic."""

    class CrashyAudit:
        def record(self, entry: AuditEntry) -> None:
            raise RuntimeError("audit storage offline")

    async def step(state: S) -> dict:
        return {"log": "ran"}

    pipeline = (
        PipelineBuilder("crashy", state=S, audit_log=CrashyAudit())  # type: ignore[arg-type]
        .add_node(step)
        .build()
    )
    result = await pipeline.invoke(S())
    assert result.success is True
    assert result.state.log == "ran"


# Header of the next file in this patch (tests/unit/pipeline/test_state_pipeline_hitl.py):
# Copyright 2026 Firefly Software Foundation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
"""Phase-3c HITL tests: Pause + approve_pause resume + on_node_pause event."""

from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import Annotated

import pytest
from pydantic import BaseModel

from fireflyframework_agentic.exceptions import PipelineError
from fireflyframework_agentic.pipeline import (
    FileCheckpointer,
    Pause,
    PipelineBuilder,
    StatePipeline,
    extend,
)
from fireflyframework_agentic.pipeline.checkpoint import CheckpointRecord


class DeployState(BaseModel):
    """Pipeline state shared by the HITL tests."""

    requirements: str = ""
    spec: str | None = None
    approved: Annotated[list[str], extend] = []
    deployed: bool = False


@dataclass
class PauseRecorder:
    """Event handler that records every ``on_node_pause`` callback it receives."""

    events: list[tuple] = field(default_factory=list)

    async def on_node_pause(self, pipeline_name: str, run_id: str, node_id: str, reason: str) -> None:
        """Append ``("pause", node_id, reason)`` to ``events``."""
        self.events.append(("pause", node_id, reason))


# --- core pause/resume ------------------------------------------------------


@pytest.mark.asyncio
async def test_node_returning_pause_halts_pipeline(tmp_path: Path) -> None:
    """A ``Pause`` from the middle node stops the run before the next node executes."""

    async def architect(state: DeployState) -> dict:
        return {"spec": "v1"}

    async def gate(state: DeployState) -> Pause:
        return Pause(reason="awaiting deploy approval")

    async def deploy(state: DeployState) -> dict:
        return {"deployed": True}

    ckpt = FileCheckpointer(tmp_path)
    pipeline = (
        PipelineBuilder("hitl", state=DeployState, checkpointer=ckpt)
        .add_node(architect)
        .add_node(gate)
        .add_node(deploy)
        .chain(architect, gate, deploy)
        .build()
    )
    assert isinstance(pipeline, StatePipeline)

    result = await pipeline.invoke(DeployState(requirements="user-mgmt"))
    assert result.paused is True
    assert result.paused_node == "gate"
    assert result.pause_reason == "awaiting deploy approval"
    assert result.success is False  # paused != success
    assert result.state.deployed is False  # deploy did NOT run
    assert result.completed_nodes == ["architect", "gate"]


@pytest.mark.asyncio
async def test_resume_without_approve_pause_raises(tmp_path: Path) -> None:
    """Resuming a paused run without ``approve_pause=True`` raises ``PipelineError``."""

    async def gate(state: DeployState) -> Pause:
        return Pause(reason="block here")

    pipeline = (
        PipelineBuilder("hitl", state=DeployState, checkpointer=FileCheckpointer(tmp_path))
        .add_node(gate)
        .build()
    )
    first = await pipeline.invoke(DeployState())
    assert first.paused is True

    with pytest.raises(PipelineError, match="approve_pause=True"):
        await pipeline.invoke(run_id=first.run_id)


@pytest.mark.asyncio
async def test_resume_with_approve_pause_continues_from_successor(tmp_path: Path) -> None:
    """An approved resume runs only the successor of the paused node."""
    # Flipped to True after the first invoke; from then on, any re-execution
    # of architect or gate fails the test.
    resumed = {"active": False}

    async def architect(state: DeployState) -> dict:
        if resumed["active"]:
            raise AssertionError("architect should NOT re-run on resume")
        return {"spec": "v1"}

    async def gate(state: DeployState) -> Pause:
        if resumed["active"]:
            raise AssertionError("gate should NOT re-run on resume")
        return Pause(reason="approve please")

    async def deploy(state: DeployState) -> dict:
        return {"deployed": True}

    pipeline = (
        PipelineBuilder("hitl", state=DeployState, checkpointer=FileCheckpointer(tmp_path))
        .add_node(architect)
        .add_node(gate)
        .add_node(deploy)
        .chain(architect, gate, deploy)
        .build()
    )
    first = await pipeline.invoke(DeployState(requirements="x"))
    assert first.paused is True
    resumed["active"] = True  # ensure neither architect nor gate re-runs

    second = await pipeline.invoke(run_id=first.run_id, approve_pause=True)
    assert second.success is True
    assert second.state.deployed is True
    assert second.completed_nodes == ["architect", "gate", "deploy"]


@pytest.mark.asyncio
async def test_on_node_pause_event_fires(tmp_path: Path) -> None:
    """The event handler's ``on_node_pause`` is called once with the node id and reason."""

    async def gate(state: DeployState) -> Pause:
        return Pause(reason="hold")

    handler = PauseRecorder()
    pipeline = (
        PipelineBuilder(
            "hitl",
            state=DeployState,
            checkpointer=FileCheckpointer(tmp_path),
            event_handler=handler,  # type: ignore[arg-type]
        )
        .add_node(gate)
        .build()
    )
    await pipeline.invoke(DeployState())
    assert handler.events == [("pause", "gate", "hold")]


# --- backward compat -------------------------------------------------------


def test_paused_checkpoint_loads_when_pause_fields_missing() -> None:
    """An existing checkpoint without paused/pause_reason fields still loads."""
    # Validate a dict that omits the new fields — mirrors what an existing
    # on-disk checkpoint from a pre-3c version looks like. Pure model
    # validation: no I/O and nothing to await, so this is a plain sync test.
    raw = {
        "pipeline_name": "old",
        "run_id": "legacy",
        "node_id": "n",
        "sequence": 1,
        "state": {"requirements": "x"},
        "completed_nodes": ["n"],
    }
    record = CheckpointRecord.model_validate(raw)
    assert record.paused is False
    assert record.pause_reason is None


# --- a paused pipeline can still resume after non-pause failures -----------


@pytest.mark.asyncio
async def test_pause_then_resume_then_error_then_resume(tmp_path: Path) -> None:
    """End-to-end: pause, approve, then a subsequent failure, then retry."""
    # deploy raises on its first call only; the flag records that it already failed.
    deploy_attempts = {"failed_once": False}

    async def gate(state: DeployState) -> Pause:
        return Pause(reason="approve")

    async def deploy(state: DeployState) -> dict:
        if not deploy_attempts["failed_once"]:
            deploy_attempts["failed_once"] = True
            raise RuntimeError("flaky")
        return {"deployed": True}

    pipeline = (
        PipelineBuilder("hitl", state=DeployState, checkpointer=FileCheckpointer(tmp_path))
        .add_node(gate)
        .add_node(deploy)
        .chain(gate, deploy)
        .build()
    )
    paused = await pipeline.invoke(DeployState())
    assert paused.paused

    failed = await pipeline.invoke(run_id=paused.run_id, approve_pause=True)
    assert not failed.success
    assert failed.failed_node == "deploy"

    # approve_pause=True is passed defensively on the retry. invoke() only
    # consults it when the latest checkpoint for the run is marked paused, so
    # the call is valid whether that latest record is still the paused gate
    # checkpoint (approval required) or a later non-paused one (flag ignored).
    succeeded = await pipeline.invoke(run_id=paused.run_id, approve_pause=True)
    assert succeeded.success is True
    assert succeeded.state.deployed is True