From dca8fb87642282de336ad108426a2a197b5ac9ec Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Wed, 27 May 2026 14:27:59 +0200 Subject: [PATCH 01/26] feat(pipeline): state-based PipelineBuilder with checkpoint/resume (#147 phase 1) Evolve PipelineBuilder so a single API covers both legacy port-based parallel DAGs and declarative state-based agentic pipelines. State mode is opt-in via the new `state=` kwarg and produces a StatePipeline; existing port-based pipelines are unchanged. New surface: - PipelineBuilder(name, state=SomeModel, checkpointer=...): nodes become async (state) -> dict | None over a typed shared state. - add_node(fn): function references derive the node id from __name__. - .branch(source, router, mapping=None): unified branching. With no mapping, the router returns a target node id directly; with a mapping, it returns an abstract label that maps to a node. - Annotated[T, reducer_fn] field annotations declare merge semantics (append/extend/merge_dict; default replace). - Auto-entry detection picks the first node added; override via start_at. - FileCheckpointer persists state after each successful node; invoke(run_id=...) resumes from the latest checkpoint, skipping completed nodes. - invoke(state, start_at=node) jumps into a pipeline mid-flow with explicit state. - Agent-like objects (anything with async run(state)) drop in as nodes. The state-based executor is sequential. Port-based parallel DAGs continue to use PipelineEngine unchanged. Phase 2 (frontier executor + cycles + Send, Mermaid/JSON export, typed-edge Literal validation, soft-deprecation of BranchStep/FanOutStep, Redis/Postgres checkpointers) lands separately. Tests: 14 new in tests/unit/pipeline/test_state_pipeline.py covering linear/branching/reducers/checkpoint-resume/start_at/agent-adapter. Full pipeline suite (88) and unit suite (1405) green. --- fireflyframework_agentic/pipeline/__init__.py | 27 +- fireflyframework_agentic/pipeline/builder.py | 240 +++++++++-- .../pipeline/checkpoint.py | 99 +++++ fireflyframework_agentic/pipeline/reducers.py | 62 +++ .../pipeline/state_pipeline.py | 375 ++++++++++++++++++ tests/unit/pipeline/test_state_pipeline.py | 352 ++++++++++++++++ 6 files changed, 1112 insertions(+), 43 deletions(-) create mode 100644 fireflyframework_agentic/pipeline/checkpoint.py create mode 100644 fireflyframework_agentic/pipeline/reducers.py create mode 100644 fireflyframework_agentic/pipeline/state_pipeline.py create mode 100644 tests/unit/pipeline/test_state_pipeline.py diff --git a/fireflyframework_agentic/pipeline/__init__.py b/fireflyframework_agentic/pipeline/__init__.py index b6392aed..f49f8db6 100644 --- a/fireflyframework_agentic/pipeline/__init__.py +++ b/fireflyframework_agentic/pipeline/__init__.py @@ -17,13 +17,29 @@ This package provides a Directed Acyclic Graph (DAG) execution engine that wires agents, reasoning patterns, validation, and tools into production pipelines where independent stages execute concurrently. + +Two builder modes exist: + +* **Port-based** (legacy, parallel): :class:`PipelineEngine` executes a DAG + whose nodes communicate via ``output_key``/``input_key`` edge ports. +* **State-based**: configure ``PipelineBuilder(state=SomeModel)`` and nodes + become ``async (state) -> dict`` functions over a typed shared state. + Branching is one ``.branch(source, router)`` call. Optional checkpointing + via :class:`Checkpointer` enables resume after failure and mid-pipeline start. """ from fireflyframework_agentic.pipeline.builder import PipelineBuilder +from fireflyframework_agentic.pipeline.checkpoint import ( + Checkpointer, + CheckpointRecord, + FileCheckpointer, +) from fireflyframework_agentic.pipeline.context import PipelineContext from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode, FailureStrategy from fireflyframework_agentic.pipeline.engine import PipelineEngine, PipelineEventHandler +from fireflyframework_agentic.pipeline.reducers import append, extend, merge_dict, replace from fireflyframework_agentic.pipeline.result import ExecutionTraceEntry, NodeResult, PipelineResult +from fireflyframework_agentic.pipeline.state_pipeline import StatePipeline, StatePipelineResult from fireflyframework_agentic.pipeline.steps import ( AgentStep, BatchLLMStep, @@ -38,11 +54,13 @@ ) __all__ = [ + "DAG", "AgentStep", "BatchLLMStep", "BranchStep", "CallableStep", - "DAG", + "CheckpointRecord", + "Checkpointer", "DAGEdge", "DAGNode", "EmbeddingStep", @@ -50,6 +68,7 @@ "FailureStrategy", "FanInStep", "FanOutStep", + "FileCheckpointer", "NodeResult", "PipelineBuilder", "PipelineContext", @@ -58,5 +77,11 @@ "PipelineResult", "ReasoningStep", "RetrievalStep", + "StatePipeline", + "StatePipelineResult", "StepExecutor", + "append", + "extend", + "merge_dict", + "replace", ] diff --git a/fireflyframework_agentic/pipeline/builder.py b/fireflyframework_agentic/pipeline/builder.py index 9dfaa8e0..1424d2c9 100644 --- a/fireflyframework_agentic/pipeline/builder.py +++ b/fireflyframework_agentic/pipeline/builder.py @@ -14,61 +14,158 @@ """Fluent builder API for constructing pipeline DAGs. -Usage example:: - - pipeline = ( - PipelineBuilder("idp-pipeline") - .add_node("split", splitter_step) - .add_node("classify", classifier_step) - .add_node("extract", extractor_step) - .add_edge("split", "classify") - .add_edge("classify", "extract") - .build() - ) +Two modes: + +1. **Port-based** (legacy, parallel-friendly): nodes are added by string id, + data flows over edge ports, executed by :class:`PipelineEngine`. Use this + for ETL-shaped DAGs with independent parallel steps:: + + pipeline = ( + PipelineBuilder("idp") + .add_node("split", splitter) + .add_node("classify", classifier) + .add_edge("split", "classify") + .build() + ) + +2. **State-based**: configure ``state=SomeModel`` and nodes become + ``async (state) -> dict`` functions over a typed shared state. Branching + is one ``.branch(source, router)`` call. Function references can be used + as node ids. Optional checkpointing supports resume after failure and + mid-pipeline start. Produces a :class:`StatePipeline`:: + + pipeline = ( + PipelineBuilder("agent", state=AgentState, checkpointer=FileCheckpointer("./ckpt")) + .add_node(classify) + .add_node(answer) + .add_node(escalate) + .branch(classify, route) + .build() + ) """ from __future__ import annotations -import asyncio +import inspect from collections.abc import Callable from typing import Any +from pydantic import BaseModel + +from fireflyframework_agentic.exceptions import PipelineError +from fireflyframework_agentic.pipeline.checkpoint import Checkpointer from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode, FailureStrategy from fireflyframework_agentic.pipeline.engine import PipelineEngine +from fireflyframework_agentic.pipeline.state_pipeline import ( + BranchSpec, + RouterFn, + StateNodeFn, + StatePipeline, + coerce_state_node_fn, +) from fireflyframework_agentic.pipeline.steps import AgentStep, CallableStep, StepExecutor class PipelineBuilder: - """Fluent builder for constructing a :class:`DAG` and :class:`PipelineEngine`. + """Fluent builder for pipelines. Parameters: name: Human-readable name for the pipeline. + state: Optional Pydantic model class for typed shared state. + When set, the builder produces a :class:`StatePipeline` and nodes + are expected to be ``async (state) -> dict | None``. + checkpointer: Optional :class:`Checkpointer` for state-based pipelines. + Ignored when ``state`` is not set. """ - def __init__(self, name: str = "pipeline") -> None: + def __init__( + self, + name: str = "pipeline", + *, + state: type[BaseModel] | None = None, + checkpointer: Checkpointer | None = None, + ) -> None: self._dag = DAG(name=name) + self._name = name + self._state_schema = state + self._checkpointer = checkpointer self._pending_nodes: list[DAGNode] = [] self._pending_edges: list[DAGEdge] = [] + # State-based mode bookkeeping. Keyed by node id. + self._state_node_fns: dict[str, StateNodeFn] = {} + self._branches: dict[str, BranchSpec] = {} def add_node( self, - node_id: str, - step: Any, + node_id_or_fn: str | Callable[..., Any], + step: Any = None, *, condition: Callable[..., bool] | None = None, retry_max: int = 0, timeout_seconds: float = 0, failure_strategy: FailureStrategy = FailureStrategy.SKIP_DOWNSTREAM, ) -> PipelineBuilder: - """Add a node to the pipeline. + """Add a node. - *step* can be: - - A :class:`StepExecutor` (AgentStep, CallableStep, etc.) - - A :class:`FireflyAgent` (auto-wrapped in :class:`AgentStep`) - - An async callable (auto-wrapped in :class:`CallableStep`) + Two signatures: - Returns *self* for chaining. + * ``add_node(fn)`` — state-based mode. ``fn`` is a callable; the node + id is taken from ``fn.__name__``. Requires the builder was constructed + with ``state=...``. + * ``add_node(node_id, step)`` — legacy port-based mode. ``step`` is a + :class:`StepExecutor`, an agent-like, or an async callable. """ + if step is None and callable(node_id_or_fn) and not isinstance(node_id_or_fn, str): + # State-based: derive id from function name. + if self._state_schema is None: + raise PipelineError( + "Function-reference add_node(fn) requires PipelineBuilder(state=...). " + "Use add_node('id', step) for port-based pipelines." + ) + fn = node_id_or_fn + node_id = getattr(fn, "__name__", None) or repr(fn) + self._state_node_fns[node_id] = coerce_state_node_fn(fn) + self._pending_nodes.append( + DAGNode( + node_id=node_id, + step=_StateNodePlaceholder(), # never executed; engine path is unused for state pipelines + condition=condition, + retry_max=retry_max, + timeout_seconds=timeout_seconds, + failure_strategy=failure_strategy, + ) + ) + return self + + if not isinstance(node_id_or_fn, str): + raise PipelineError("add_node(node_id, step) expects a string node id when a step is provided.") + node_id = node_id_or_fn + + if self._state_schema is not None and step is not None: + # State-based pipeline: accept a callable, or an agent-like object + # exposing async ``run(state)``. ``coerce_state_node_fn`` handles both. + run_method = getattr(step, "run", None) + if not callable(step) and not callable(run_method): + raise PipelineError( + f"State pipeline node '{node_id}' must be a callable or expose async run(state); " + f"got {type(step).__name__}" + ) + self._state_node_fns[node_id] = coerce_state_node_fn(step) + self._pending_nodes.append( + DAGNode( + node_id=node_id, + step=_StateNodePlaceholder(), + condition=condition, + retry_max=retry_max, + timeout_seconds=timeout_seconds, + failure_strategy=failure_strategy, + ) + ) + return self + + if step is None: + raise PipelineError(f"add_node('{node_id}', step=...) requires a step.") + executor = self._resolve_step(step) self._pending_nodes.append( DAGNode( @@ -84,46 +181,88 @@ def add_node( def add_edge( self, - source: str, - target: str, + source: str | Callable[..., Any], + target: str | Callable[..., Any], *, output_key: str = "output", input_key: str = "input", ) -> PipelineBuilder: """Add a directed edge from *source* to *target*. - Returns *self* for chaining. + Both endpoints may be node ids (str) or function references (in which + case ``fn.__name__`` is used). """ self._pending_edges.append( DAGEdge( - source=source, - target=target, + source=_id(source), + target=_id(target), output_key=output_key, input_key=input_key, ) ) return self - def chain(self, *node_ids: str) -> PipelineBuilder: - """Connect nodes in sequence: A -> B -> C -> ... + def chain(self, *nodes: str | Callable[..., Any]) -> PipelineBuilder: + """Connect nodes in sequence: A -> B -> C -> ...""" + ids = [_id(n) for n in nodes] + for i in range(len(ids) - 1): + self.add_edge(ids[i], ids[i + 1]) + return self + + def branch( + self, + source: str | Callable[..., Any], + router: RouterFn, + mapping: dict[str, str | Callable[..., Any]] | None = None, + ) -> PipelineBuilder: + """Register a router on ``source``. + + ``router`` is a synchronous ``(state) -> str`` callable. Behaviour: - All referenced nodes must already have been added via :meth:`add_node`. - Returns *self* for chaining. + * If ``mapping`` is None, the router must return the **id of an + existing node** that will run next. + * If ``mapping`` is provided, the router returns an abstract label + that is looked up in ``mapping`` to find the target node id. + + State-based pipelines only. """ - for i in range(len(node_ids) - 1): - self.add_edge(node_ids[i], node_ids[i + 1]) + if self._state_schema is None: + raise PipelineError(".branch(...) requires PipelineBuilder(state=...)") + source_id = _id(source) + resolved_mapping: dict[str, str] | None = None + if mapping is not None: + resolved_mapping = {label: _id(target) for label, target in mapping.items()} + # Materialize each label's edge into the DAG so topology is inspectable. + for target_id in resolved_mapping.values(): + self._pending_edges.append(DAGEdge(source=source_id, target=target_id)) + else: + # No mapping: we don't know targets at build time; edges will + # be missing from the DAG. That's fine for the StatePipeline + # executor (it consults the router), but visualisation will be + # incomplete. Materialize edges lazily when the router fires. + pass + self._branches[source_id] = BranchSpec(source=source_id, router=router, mapping=resolved_mapping) return self - def build(self) -> PipelineEngine: - """Build the DAG, validate it, and return a :class:`PipelineEngine`. - - Raises: - PipelineError: If the graph is invalid (cycles, missing nodes). + def build(self) -> PipelineEngine | StatePipeline: + """Build the DAG and return either a :class:`PipelineEngine` + (legacy port-based) or :class:`StatePipeline` (when ``state=`` is set). """ for node in self._pending_nodes: self._dag.add_node(node) for edge in self._pending_edges: self._dag.add_edge(edge) + + if self._state_schema is not None: + return StatePipeline( + name=self._name, + dag=self._dag, + state_schema=self._state_schema, + node_fns=self._state_node_fns, + branches=self._branches, + checkpointer=self._checkpointer, + ) + return PipelineEngine(self._dag) def build_dag(self) -> DAG: @@ -139,12 +278,29 @@ def _resolve_step(step: Any) -> Any: """Wrap non-executor objects in the appropriate step type.""" if isinstance(step, StepExecutor): return step - # Duck-type check for agent-like objects if hasattr(step, "run") and callable(step.run): return AgentStep(step) - # Async callable - if callable(step) and asyncio.iscoroutinefunction(step): + if callable(step) and inspect.iscoroutinefunction(step): return CallableStep(step) raise TypeError( - f"Cannot resolve {type(step).__name__} as a pipeline step. Must be StepExecutor, agent-like, or async callable." + f"Cannot resolve {type(step).__name__} as a pipeline step. " + f"Must be StepExecutor, agent-like, or async callable." ) + + +def _id(ref: str | Callable[..., Any]) -> str: + """Coerce a string id or function reference into a node id string.""" + if isinstance(ref, str): + return ref + name = getattr(ref, "__name__", None) + if not name: + raise PipelineError(f"Cannot derive node id from {ref!r}") + return name + + +class _StateNodePlaceholder: + """Sentinel step kept in the DAG so topology is intact. Never executed — + state pipelines bypass :class:`PipelineEngine` entirely.""" + + async def execute(self, *_args: Any, **_kwargs: Any) -> Any: + raise PipelineError("_StateNodePlaceholder.execute called — state pipelines should not use PipelineEngine.") diff --git a/fireflyframework_agentic/pipeline/checkpoint.py b/fireflyframework_agentic/pipeline/checkpoint.py new file mode 100644 index 00000000..d4ebeaae --- /dev/null +++ b/fireflyframework_agentic/pipeline/checkpoint.py @@ -0,0 +1,99 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pipeline state checkpointing for failure recovery and resumable runs. + +A :class:`Checkpointer` persists state after each successful node, keyed by +``(pipeline_name, run_id, node_id)``. On resume the engine loads the latest +checkpoint and skips nodes that already completed in that run. + +:class:`FileCheckpointer` is a filesystem-backed JSON implementation suitable +for single-process workflows. Redis / Postgres checkpointers can implement +the same Protocol and be swapped in without API changes. +""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, Protocol, runtime_checkable + +from pydantic import BaseModel + + +class CheckpointRecord(BaseModel): + """One saved checkpoint.""" + + pipeline_name: str + run_id: str + node_id: str + sequence: int + state: dict[str, Any] + completed_nodes: list[str] + + +@runtime_checkable +class Checkpointer(Protocol): + """Persists pipeline state after each successful node. + + Implementations must be safe to call from async code (the engine awaits + save() inside its task loop) but the methods themselves may be sync. + """ + + def save(self, record: CheckpointRecord) -> None: + """Persist a checkpoint. Overwrites if (pipeline, run_id, node_id) exists.""" + ... + + def load_latest(self, pipeline_name: str, run_id: str) -> CheckpointRecord | None: + """Return the most recent checkpoint for ``run_id`` or ``None`` if no run exists.""" + ... + + def list_runs(self, pipeline_name: str) -> list[str]: + """Return all known run IDs for ``pipeline_name``.""" + ... + + +class FileCheckpointer: + """Filesystem-backed checkpointer. Layout:: + + ///_.json + + The ``sequence`` prefix gives a natural sort order for ``load_latest``. + """ + + def __init__(self, root: str | Path) -> None: + self._root = Path(root) + self._root.mkdir(parents=True, exist_ok=True) + + def save(self, record: CheckpointRecord) -> None: + run_dir = self._root / record.pipeline_name / record.run_id + run_dir.mkdir(parents=True, exist_ok=True) + path = run_dir / f"{record.sequence:06d}_{record.node_id}.json" + path.write_text(record.model_dump_json(indent=2)) + + def load_latest(self, pipeline_name: str, run_id: str) -> CheckpointRecord | None: + run_dir = self._root / pipeline_name / run_id + if not run_dir.exists(): + return None + files = sorted(run_dir.glob("*.json")) + if not files: + return None + latest = files[-1] + return CheckpointRecord.model_validate(json.loads(latest.read_text())) + + def list_runs(self, pipeline_name: str) -> list[str]: + pipeline_dir = self._root / pipeline_name + if not pipeline_dir.exists(): + return [] + return sorted(d.name for d in pipeline_dir.iterdir() if d.is_dir()) diff --git a/fireflyframework_agentic/pipeline/reducers.py b/fireflyframework_agentic/pipeline/reducers.py new file mode 100644 index 00000000..da91e1dd --- /dev/null +++ b/fireflyframework_agentic/pipeline/reducers.py @@ -0,0 +1,62 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""State-merge reducers for typed pipeline state. + +Reducers are functions ``(current, update) -> merged`` declared on a state +field via :class:`typing.Annotated`. The pipeline engine inspects +``typing.get_type_hints(state_schema, include_extras=True)`` for each field +and applies the relevant reducer when a node returns a partial state dict. + +Fields without an annotated reducer use :func:`replace` (last-write-wins). + +Example:: + + class AgentState(BaseModel): + messages: Annotated[list[str], append] = [] + intent: str | None = None # uses replace by default +""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +Reducer = Callable[[Any, Any], Any] + + +def replace(current: Any, update: Any) -> Any: # noqa: ARG001 + """Last-write-wins: the update value replaces the current value.""" + return update + + +def append(current: Any, update: Any) -> list[Any]: + """Append a single item to a list. ``current`` is treated as ``[]`` if ``None``.""" + base = list(current) if current else [] + base.append(update) + return base + + +def extend(current: Any, update: Any) -> list[Any]: + """Concatenate two iterables. ``update`` must be iterable.""" + base = list(current) if current else [] + base.extend(update) + return base + + +def merge_dict(current: Any, update: Any) -> dict[Any, Any]: + """Shallow-merge two dicts; keys in ``update`` win.""" + base = dict(current) if current else {} + base.update(update or {}) + return base diff --git a/fireflyframework_agentic/pipeline/state_pipeline.py b/fireflyframework_agentic/pipeline/state_pipeline.py new file mode 100644 index 00000000..1b09e185 --- /dev/null +++ b/fireflyframework_agentic/pipeline/state_pipeline.py @@ -0,0 +1,375 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""State-based pipeline: a sequential executor over a typed shared-state object. + +Layered on top of :class:`DAG` for topology, but uses its own simple executor +rather than :class:`PipelineEngine`. The trade-off: no within-level parallelism, +but in exchange we get clean semantics for typed state, reducers, branching, +checkpointing, and mid-pipeline resume — which are the things this API exists +to provide. Port-based parallel DAGs continue to use :class:`PipelineEngine`. +""" + +from __future__ import annotations + +import asyncio +import inspect +import logging +import time +import uuid +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from typing import Any, get_type_hints + +from pydantic import BaseModel + +from fireflyframework_agentic.exceptions import PipelineError +from fireflyframework_agentic.pipeline.checkpoint import Checkpointer, CheckpointRecord +from fireflyframework_agentic.pipeline.dag import DAG +from fireflyframework_agentic.pipeline.reducers import Reducer, replace + +logger = logging.getLogger(__name__) + +StateNodeFn = Callable[[Any], Awaitable[dict[str, Any] | None]] +RouterFn = Callable[[Any], str] + + +@dataclass +class BranchSpec: + """Internal: registered branch from one source node.""" + + source: str + router: RouterFn + mapping: dict[str, str] | None # label -> target node_id. None = router returns target directly. + + +@dataclass +class StatePipelineResult: + """Outcome of a single ``invoke`` call. + + Attributes: + state: Final state object. + run_id: ID of this run (use to resume later via ``invoke(run_id=...)``). + completed_nodes: Node IDs that ran successfully this invocation, in order. + success: True iff all attempted nodes completed without error. + error: Last error message if ``success`` is False. + failed_node: Node ID that failed, if any. + """ + + state: Any + run_id: str + completed_nodes: list[str] + success: bool + error: str | None = None + failed_node: str | None = None + + +def discover_reducers(state_schema: type) -> dict[str, Reducer]: + """Inspect ``Annotated[T, reducer_fn]`` annotations on the schema. + + Only ``Annotated[...]`` metadata is consulted — not generic origins like + ``list[...]`` or unions. Fields without an annotated reducer are absent + from the returned dict; callers should treat absence as :func:`replace`. + """ + out: dict[str, Reducer] = {} + try: + hints = get_type_hints(state_schema, include_extras=True) + except Exception: + return out + for field_name, hint in hints.items(): + # Annotated[...] is the only metadata-bearing form we care about. + metadata = getattr(hint, "__metadata__", None) + if not metadata: + continue + for meta in metadata: + if callable(meta): + out[field_name] = meta + break + return out + + +def apply_update(state: BaseModel, update: dict[str, Any], reducers: dict[str, Reducer]) -> BaseModel: + """Return a new state object with ``update`` merged into ``state`` via reducers.""" + if not update: + return state + new_values = state.model_dump() + for key, value in update.items(): + if key not in new_values: + # Tolerate unknown keys with a warning rather than failing — + # makes incremental schema evolution painless. + logger.warning("State update key '%s' not in schema %s; ignored.", key, type(state).__name__) + continue + reducer = reducers.get(key, replace) + new_values[key] = reducer(new_values[key], value) + return type(state).model_validate(new_values) + + +class StatePipeline: + """Compiled state-based pipeline. Returned by ``PipelineBuilder.build()`` + when a ``state=`` schema is configured. + """ + + def __init__( + self, + *, + name: str, + dag: DAG, + state_schema: type[BaseModel], + node_fns: dict[str, StateNodeFn], + branches: dict[str, BranchSpec], + checkpointer: Checkpointer | None = None, + ) -> None: + self._name = name + self._dag = dag + self._state_schema = state_schema + self._node_fns = node_fns + self._branches = branches + self._checkpointer = checkpointer + self._reducers = discover_reducers(state_schema) + self._validate() + + @property + def name(self) -> str: + return self._name + + @property + def dag(self) -> DAG: + return self._dag + + def _validate(self) -> None: + # Every node must have a registered fn. + for node_id in self._dag.nodes: + if node_id not in self._node_fns: + raise PipelineError(f"Node '{node_id}' has no registered function") + # Every branch source/target must exist. + for source, spec in self._branches.items(): + if source not in self._dag.nodes: + raise PipelineError(f"Branch source '{source}' not in DAG") + if spec.mapping: + for label, target in spec.mapping.items(): + if target not in self._dag.nodes: + raise PipelineError(f"Branch target '{target}' (label '{label}') not in DAG") + + def _entry_node(self) -> str: + """Default entry: the first node added. + + Override with ``invoke(state, start_at=...)``. Picking insertion-order + rather than the topological root keeps things predictable in the + common case where a ``.branch(...)`` without an explicit mapping + leaves multiple nodes with no inbound edges. + """ + order = list(self._dag.nodes) + if not order: + raise PipelineError("Pipeline has no nodes") + return order[0] + + def _next_node(self, current: str, state: BaseModel) -> str | None: + """Decide which successor runs next given the current state. + + For non-branching nodes: pick the unique successor (or None at terminus). + For branching nodes: run the router and resolve via mapping if present. + """ + if current in self._branches: + spec = self._branches[current] + label = spec.router(state) + if spec.mapping is not None: + if label not in spec.mapping: + raise PipelineError( + f"Router for '{current}' returned label '{label}' not in mapping {list(spec.mapping)}" + ) + return spec.mapping[label] + # Mapping omitted: router returns target node id directly. + if label not in self._dag.nodes: + raise PipelineError( + f"Router for '{current}' returned '{label}' " + f"which is not a registered node id; pass an explicit mapping if you want labels." + ) + return label + + successors = self._dag.successors(current) + if not successors: + return None + if len(successors) > 1: + raise PipelineError( + f"Node '{current}' has multiple successors {successors} but no .branch(...) registered. " + f"Register a branch router or remove the extra edges." + ) + return successors[0] + + async def invoke( + self, + state: BaseModel | None = None, + *, + run_id: str | None = None, + start_at: str | Callable[..., Any] | None = None, + ) -> StatePipelineResult: + """Run the pipeline. + + Modes: + * Fresh run: ``invoke(state)`` — generates a new ``run_id``. + * Resume: ``invoke(run_id="abc")`` — loads latest checkpoint and continues. + * Mid-pipeline start: ``invoke(state=..., start_at=node)`` — + starts execution at ``node`` with the provided state. + """ + resumed_completed: list[str] = [] + + # Resume mode: load checkpoint, derive starting node from it. + if run_id is not None and state is None and start_at is None: + if self._checkpointer is None: + raise PipelineError("Cannot resume: pipeline has no checkpointer") + record = self._checkpointer.load_latest(self._name, run_id) + if record is None: + raise PipelineError(f"No checkpoint found for run_id='{run_id}'") + state = self._state_schema.model_validate(record.state) + resumed_completed = list(record.completed_nodes) + # Resume at the successor of the last completed node. + last = record.node_id + next_node = self._next_node(last, state) + if next_node is None: + return StatePipelineResult( + state=state, + run_id=run_id, + completed_nodes=resumed_completed, + success=True, + ) + current_node: str | None = next_node + else: + if state is None: + raise PipelineError("invoke() requires a state argument (or a run_id to resume)") + if not isinstance(state, self._state_schema): + # Be helpful if caller passed a dict or a different model. + try: + state = self._state_schema.model_validate(state) + except Exception as exc: + raise PipelineError(f"state argument is not a {self._state_schema.__name__}: {exc}") from exc + if start_at is not None: + current_node = _resolve_node_id(start_at) + if current_node not in self._dag.nodes: + raise PipelineError(f"start_at='{current_node}' not in DAG") + else: + current_node = self._entry_node() + + if run_id is None: + run_id = uuid.uuid4().hex[:12] + + assert state is not None # narrowed by the branches above + completed: list[str] = list(resumed_completed) + sequence = len(completed) + + while current_node is not None: + fn = self._node_fns[current_node] + t0 = time.perf_counter() + try: + update = await fn(state) + except Exception as exc: + logger.exception( + "State pipeline '%s' run '%s' failed at node '%s'", + self._name, + run_id, + current_node, + ) + return StatePipelineResult( + state=state, + run_id=run_id, + completed_nodes=completed, + success=False, + error=str(exc), + failed_node=current_node, + ) + elapsed = (time.perf_counter() - t0) * 1000 + logger.debug("Pipeline '%s' node '%s' completed in %.1fms", self._name, current_node, elapsed) + + if update: + state = apply_update(state, update, self._reducers) + + completed.append(current_node) + sequence += 1 + + if self._checkpointer is not None: + try: + self._checkpointer.save( + CheckpointRecord( + pipeline_name=self._name, + run_id=run_id, + node_id=current_node, + sequence=sequence, + state=state.model_dump(), + completed_nodes=list(completed), + ) + ) + except Exception: + # Checkpoint failure is non-fatal — log and continue. + logger.exception("Checkpoint save failed for run '%s' at '%s'", run_id, current_node) + + try: + current_node = self._next_node(current_node, state) + except PipelineError as exc: + return StatePipelineResult( + state=state, + run_id=run_id, + completed_nodes=completed, + success=False, + error=str(exc), + failed_node=completed[-1] if completed else None, + ) + + return StatePipelineResult( + state=state, + run_id=run_id, + completed_nodes=completed, + success=True, + ) + + +def _resolve_node_id(ref: str | Callable[..., Any]) -> str: + """Turn either a string id or a function reference into a node id.""" + if isinstance(ref, str): + return ref + name = getattr(ref, "__name__", None) + if not name: + raise PipelineError(f"Cannot derive node id from {ref!r}") + return name + + +def coerce_state_node_fn(fn: Callable[..., Any]) -> StateNodeFn: + """Adapt a user-supplied callable into the ``async (state) -> dict | None`` shape. + + Accepted forms: + * ``async def f(state) -> dict | None`` — used as-is. + * ``def f(state) -> dict | None`` — wrapped to run in a thread. + * Object with ``async run(state)`` (e.g. a FireflyAgent-like) — adapter calls ``.run(state)``. + """ + if inspect.iscoroutinefunction(fn): + return fn # type: ignore[return-value] + + # Object with .run(state) — e.g. a FireflyAgent. Check before the generic + # callable branch so agent-shaped objects don't get treated as plain callables. + run = getattr(fn, "run", None) + if not callable(fn) and run is not None and callable(run): + + async def _agent_wrap(state: Any) -> Any: + if inspect.iscoroutinefunction(run): + return await run(state) + return await asyncio.get_running_loop().run_in_executor(None, run, state) + + return _agent_wrap + + if callable(fn): + + async def _async_wrap(state: Any) -> Any: + return await asyncio.get_running_loop().run_in_executor(None, fn, state) + + return _async_wrap + + raise PipelineError(f"Cannot adapt {fn!r} as a state node function") diff --git a/tests/unit/pipeline/test_state_pipeline.py b/tests/unit/pipeline/test_state_pipeline.py new file mode 100644 index 00000000..0853c106 --- /dev/null +++ b/tests/unit/pipeline/test_state_pipeline.py @@ -0,0 +1,352 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""Tests for the state-based pipeline API (issue #147 phase 1). + +Covers the canonical agentic-pipeline shape: + * Typed shared state via a Pydantic model. + * Reducers via ``Annotated[T, reducer_fn]``. + * Function references as node ids. + * Auto-entry detection. + * ``.branch(source, router)`` with and without an explicit mapping. + * Checkpoint + resume after failure (the software-factory scenario). + * ``start_at`` to jump into the middle of a pipeline with explicit state. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Annotated + +import pytest +from pydantic import BaseModel + +from fireflyframework_agentic.exceptions import PipelineError +from fireflyframework_agentic.pipeline import ( + FileCheckpointer, + PipelineBuilder, + StatePipeline, + append, +) + + +class AgentState(BaseModel): + messages: Annotated[list[str], append] = [] + intent: str | None = None + answer: str | None = None + + +# --- linear pipeline ------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_linear_pipeline_runs_all_nodes(): + """Three nodes in sequence; each writes to state; final state has all updates.""" + + async def step_a(state: AgentState) -> dict: + return {"messages": "a"} + + async def step_b(state: AgentState) -> dict: + return {"messages": "b"} + + async def step_c(state: AgentState) -> dict: + return {"messages": "c", "answer": "done"} + + pipeline = ( + PipelineBuilder("linear", state=AgentState) + .add_node(step_a) + .add_node(step_b) + .add_node(step_c) + .chain(step_a, step_b, step_c) + .build() + ) + assert isinstance(pipeline, StatePipeline) + result = await pipeline.invoke(AgentState(messages=["start"])) + assert result.success + assert result.completed_nodes == ["step_a", "step_b", "step_c"] + assert result.state.messages == ["start", "a", "b", "c"] + assert result.state.answer == "done" + + +@pytest.mark.asyncio +async def test_returning_none_or_empty_dict_keeps_state(): + """A node that returns None or {} should leave state unchanged.""" + + async def noop(state: AgentState) -> None: + return None + + async def writer(state: AgentState) -> dict: + return {"answer": "ok"} + + pipeline = PipelineBuilder("noop", state=AgentState).add_node(noop).add_node(writer).chain(noop, writer).build() + result = await pipeline.invoke(AgentState(messages=["x"])) + assert result.success + assert result.state.messages == ["x"] # unchanged by noop + assert result.state.answer == "ok" + + +# --- branching ------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_branch_without_mapping_router_returns_node_id(): + """Router returns the target node id directly; no mapping needed.""" + + async def classify(state: AgentState) -> dict: + return {"intent": "complaint" if "refund" in " ".join(state.messages) else "general"} + + async def answer(state: AgentState) -> dict: + return {"answer": "Here is your answer."} + + async def escalate(state: AgentState) -> dict: + return {"answer": "Escalated."} + + def route(state: AgentState) -> str: + return "escalate" if state.intent == "complaint" else "answer" + + pipeline = ( + PipelineBuilder("branch", state=AgentState) + .add_node(classify) + .add_node(answer) + .add_node(escalate) + .branch(classify, route) + .build() + ) + complaint = await pipeline.invoke(AgentState(messages=["I want a refund"])) + assert complaint.state.answer == "Escalated." + assert complaint.completed_nodes == ["classify", "escalate"] + + general = await pipeline.invoke(AgentState(messages=["hello"])) + assert general.state.answer == "Here is your answer." + assert general.completed_nodes == ["classify", "answer"] + + +@pytest.mark.asyncio +async def test_branch_with_explicit_mapping_uses_abstract_labels(): + async def start(state: AgentState) -> dict: + return {"intent": "x"} + + async def left(state: AgentState) -> dict: + return {"answer": "L"} + + async def right(state: AgentState) -> dict: + return {"answer": "R"} + + def route(state: AgentState) -> str: + return "go_left" if state.intent == "x" else "go_right" + + pipeline = ( + PipelineBuilder("mapped", state=AgentState) + .add_node(start) + .add_node(left) + .add_node(right) + .branch(start, route, {"go_left": left, "go_right": right}) + .build() + ) + result = await pipeline.invoke(AgentState()) + assert result.state.answer == "L" + + +@pytest.mark.asyncio +async def test_router_returning_unknown_label_raises(): + async def start(state: AgentState) -> dict: + return {} + + async def target(state: AgentState) -> dict: + return {"answer": "ok"} + + def bad_router(state: AgentState) -> str: + return "nonexistent_node" + + pipeline = ( + PipelineBuilder("bad", state=AgentState).add_node(start).add_node(target).branch(start, bad_router).build() + ) + result = await pipeline.invoke(AgentState()) + assert not result.success + assert "nonexistent_node" in (result.error or "") + + +# --- reducers -------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_append_reducer_accumulates_across_nodes(): + """The default test schema uses append on messages; each node adds one.""" + + async def a(state: AgentState) -> dict: + return {"messages": "from_a"} + + async def b(state: AgentState) -> dict: + return {"messages": "from_b"} + + pipeline = PipelineBuilder("acc", state=AgentState).add_node(a).add_node(b).chain(a, b).build() + result = await pipeline.invoke(AgentState(messages=["initial"])) + assert result.state.messages == ["initial", "from_a", "from_b"] + + +@pytest.mark.asyncio +async def test_replace_reducer_is_default_for_unannotated_field(): + async def a(state: AgentState) -> dict: + return {"answer": "first"} + + async def b(state: AgentState) -> dict: + return {"answer": "second"} + + pipeline = PipelineBuilder("rep", state=AgentState).add_node(a).add_node(b).chain(a, b).build() + result = await pipeline.invoke(AgentState()) + assert result.state.answer == "second" + + +# --- checkpoint + resume --------------------------------------------------- + + +class BuildState(BaseModel): + """Software-factory scenario state.""" + + requirements: str + spec: str | None = None + code: str | None = None + deploy_url: str | None = None + evaluation: str | None = None + + +@pytest.mark.asyncio +async def test_checkpoint_resume_after_failure(tmp_path: Path): + """Run a 4-step agent factory; deployer fails the first time; resume succeeds.""" + + failed_once = {"deploy": False} + + async def architect(state: BuildState) -> dict: + return {"spec": "architecture spec for: " + state.requirements} + + async def python_dev(state: BuildState) -> dict: + return {"code": f"# code implementing {state.spec}"} + + async def deployer(state: BuildState) -> dict: + if not failed_once["deploy"]: + failed_once["deploy"] = True + raise RuntimeError("network glitch") + return {"deploy_url": "https://app.example.com"} + + async def evaluator(state: BuildState) -> dict: + return {"evaluation": f"PASS: {state.deploy_url}"} + + ckpt = FileCheckpointer(tmp_path / "ckpt") + pipeline = ( + PipelineBuilder("software-factory", state=BuildState, checkpointer=ckpt) + .add_node(architect) + .add_node(python_dev) + .add_node(deployer) + .add_node(evaluator) + .chain(architect, python_dev, deployer, evaluator) + .build() + ) + + # First run: deployer fails. + first = await pipeline.invoke(BuildState(requirements="user-mgmt service")) + assert not first.success + assert first.failed_node == "deployer" + assert first.completed_nodes == ["architect", "python_dev"] + assert first.state.code is not None # python_dev did persist + + # Resume: should skip architect/python_dev, retry deployer, then evaluator. + second = await pipeline.invoke(run_id=first.run_id) + assert second.success + assert second.completed_nodes == ["architect", "python_dev", "deployer", "evaluator"] + assert second.state.evaluation == "PASS: https://app.example.com" + + +@pytest.mark.asyncio +async def test_start_at_jumps_to_middle_with_explicit_state(tmp_path: Path): + """Caller supplies state + start_at to run only from deployer onwards.""" + + async def architect(state: BuildState) -> dict: + raise AssertionError("should not run") + + async def python_dev(state: BuildState) -> dict: + raise AssertionError("should not run") + + async def deployer(state: BuildState) -> dict: + return {"deploy_url": "https://app.example.com"} + + async def evaluator(state: BuildState) -> dict: + return {"evaluation": "PASS"} + + pipeline = ( + PipelineBuilder("factory", state=BuildState) + .add_node(architect) + .add_node(python_dev) + .add_node(deployer) + .add_node(evaluator) + .chain(architect, python_dev, deployer, evaluator) + .build() + ) + pre_built = BuildState(requirements="x", spec="precomputed", code="precomputed code") + result = await pipeline.invoke(pre_built, start_at=deployer) + assert result.success + assert result.completed_nodes == ["deployer", "evaluator"] + assert result.state.deploy_url == "https://app.example.com" + + +@pytest.mark.asyncio +async def test_resume_without_checkpointer_raises(): + async def a(state: AgentState) -> dict: + return {} + + pipeline = PipelineBuilder("nockpt", state=AgentState).add_node(a).build() + with pytest.raises(PipelineError, match="no checkpointer"): + await pipeline.invoke(run_id="anything") + + +# --- validation / errors --------------------------------------------------- + + +@pytest.mark.asyncio +async def test_default_entry_is_first_node_added(): + """When no inbound edges disambiguate, the first add_node call is the entry.""" + + async def first_one(state: AgentState) -> dict: + return {"answer": "first ran"} + + async def second_one(state: AgentState) -> dict: + raise AssertionError("not reached without an edge") + + pipeline = PipelineBuilder("order", state=AgentState).add_node(first_one).add_node(second_one).build() + result = await pipeline.invoke(AgentState()) + assert result.completed_nodes == ["first_one"] + assert result.state.answer == "first ran" + + +def test_function_ref_without_state_raises(): + async def step(state): + return {} + + with pytest.raises(PipelineError, match="state=..."): + PipelineBuilder("nostate").add_node(step) + + +def test_branch_without_state_raises(): + builder = PipelineBuilder("nostate") + with pytest.raises(PipelineError, match="state=..."): + builder.branch("x", lambda s: "y") + + +# --- agent-shape adapter --------------------------------------------------- + + +@pytest.mark.asyncio +async def test_agent_like_object_adapts_via_run_method(): + """Object exposing async run(state) is accepted as a node.""" + + class MockAgent: + __name__ = "mock_agent" # required for function-ref node id derivation + + async def run(self, state: AgentState) -> dict: + return {"answer": "from mock agent"} + + pipeline = PipelineBuilder("agent", state=AgentState).add_node("mock_agent", MockAgent()).build() + result = await pipeline.invoke(AgentState()) + assert result.success + assert result.state.answer == "from mock agent" From 37450ea9b0087f2e0167101667826bf48d9f05f1 Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Wed, 27 May 2026 15:18:52 +0200 Subject: [PATCH 02/26] feat(pipeline): cycles, Send fan-out, Mermaid/JSON export (#147 phase 2) Adds the agentic-loop and fan-out features deferred from phase 1, plus visualization on the DAG, plus soft-deprecation of the legacy branching/ fan-out steps. Phase 1 surface is preserved; everything here is additive. What's new: - Send(target, payload) dataclass. Routers can return list[Send] for runtime fan-out: workers run concurrently with their own payload-merged state copy; results reduce back into shared state. - Cycles supported in state mode. PipelineBuilder(state=...) constructs the underlying DAG with allow_cycles=True so a node can route back to itself for ReAct loops / retry-with-critique. - recursion_limit kwarg on PipelineBuilder (default 25). Per-node visit counter aborts runaway cycles with a clean failure result. - DAG.to_mermaid() / DAG.to_json() for any DAG. - StatePipeline.to_mermaid() that adds branch-edge labels from the registered mapping. - BranchStep and FanOutStep emit DeprecationWarning pointing to .branch(...) and Send(...). Existing pipelines continue to work. API additions exported from pipeline.__init__: Send, RecursionLimitError Tests: 9 new in test_state_pipeline_phase2.py covering loop-with-exit, recursion_limit, map-reduce-style fan-out, unknown-target error, Mermaid/JSON output, and deprecation warning emission. Pipeline suite now 97 passed (88 phase-1 + 9 phase-2); full unit suite 1405 passed. Ruff check + format clean. Pyright clean on touched modules. --- fireflyframework_agentic/pipeline/__init__.py | 9 +- fireflyframework_agentic/pipeline/builder.py | 8 +- fireflyframework_agentic/pipeline/dag.py | 72 ++++- .../pipeline/state_pipeline.py | 303 +++++++++++++++--- fireflyframework_agentic/pipeline/steps.py | 13 + .../pipeline/test_state_pipeline_phase2.py | 231 +++++++++++++ 6 files changed, 582 insertions(+), 54 deletions(-) create mode 100644 tests/unit/pipeline/test_state_pipeline_phase2.py diff --git a/fireflyframework_agentic/pipeline/__init__.py b/fireflyframework_agentic/pipeline/__init__.py index f49f8db6..ed844dc0 100644 --- a/fireflyframework_agentic/pipeline/__init__.py +++ b/fireflyframework_agentic/pipeline/__init__.py @@ -39,7 +39,12 @@ from fireflyframework_agentic.pipeline.engine import PipelineEngine, PipelineEventHandler from fireflyframework_agentic.pipeline.reducers import append, extend, merge_dict, replace from fireflyframework_agentic.pipeline.result import ExecutionTraceEntry, NodeResult, PipelineResult -from fireflyframework_agentic.pipeline.state_pipeline import StatePipeline, StatePipelineResult +from fireflyframework_agentic.pipeline.state_pipeline import ( + RecursionLimitError, + Send, + StatePipeline, + StatePipelineResult, +) from fireflyframework_agentic.pipeline.steps import ( AgentStep, BatchLLMStep, @@ -76,6 +81,8 @@ "PipelineEventHandler", "PipelineResult", "ReasoningStep", + "RecursionLimitError", + "Send", "RetrievalStep", "StatePipeline", "StatePipelineResult", diff --git a/fireflyframework_agentic/pipeline/builder.py b/fireflyframework_agentic/pipeline/builder.py index 1424d2c9..bb5f7681 100644 --- a/fireflyframework_agentic/pipeline/builder.py +++ b/fireflyframework_agentic/pipeline/builder.py @@ -59,6 +59,7 @@ from fireflyframework_agentic.pipeline.state_pipeline import ( BranchSpec, RouterFn, + Send, # noqa: F401 re-exported via pipeline/__init__.py StateNodeFn, StatePipeline, coerce_state_node_fn, @@ -84,11 +85,15 @@ def __init__( *, state: type[BaseModel] | None = None, checkpointer: Checkpointer | None = None, + recursion_limit: int = 25, ) -> None: - self._dag = DAG(name=name) + # State pipelines may use cyclic graphs (ReAct loops, retry-with-critique). + # The legacy port-based path keeps acyclicity as an invariant. + self._dag = DAG(name=name, allow_cycles=state is not None) self._name = name self._state_schema = state self._checkpointer = checkpointer + self._recursion_limit = recursion_limit self._pending_nodes: list[DAGNode] = [] self._pending_edges: list[DAGEdge] = [] # State-based mode bookkeeping. Keyed by node id. @@ -261,6 +266,7 @@ def build(self) -> PipelineEngine | StatePipeline: node_fns=self._state_node_fns, branches=self._branches, checkpointer=self._checkpointer, + recursion_limit=self._recursion_limit, ) return PipelineEngine(self._dag) diff --git a/fireflyframework_agentic/pipeline/dag.py b/fireflyframework_agentic/pipeline/dag.py index 2e4fef9b..80d78b76 100644 --- a/fireflyframework_agentic/pipeline/dag.py +++ b/fireflyframework_agentic/pipeline/dag.py @@ -15,12 +15,14 @@ """Directed Acyclic Graph (DAG) model for pipeline topology. :class:`DAG` holds :class:`DAGNode` and :class:`DAGEdge` objects, validates -acyclicity, computes topological sort, and identifies independent execution -levels for parallel scheduling. +acyclicity (unless ``allow_cycles=True``), computes topological sort, and +identifies independent execution levels for parallel scheduling. Also renders +itself as Mermaid or JSON for inspection / docs / Studio. """ from __future__ import annotations +import json from collections import defaultdict, deque from collections.abc import Callable from enum import StrEnum @@ -98,8 +100,9 @@ class DAG: name: A human-readable name for the pipeline. """ - def __init__(self, name: str = "pipeline") -> None: + def __init__(self, name: str = "pipeline", *, allow_cycles: bool = False) -> None: self._name = name + self._allow_cycles = allow_cycles self._nodes: dict[str, DAGNode] = {} self._edges: list[DAGEdge] = [] # Adjacency and reverse adjacency for topo-sort @@ -136,9 +139,9 @@ def add_edge(self, edge: DAGEdge) -> None: self._edges.append(edge) self._adj[edge.source].append(edge.target) self._in_degree[edge.target] = self._in_degree.get(edge.target, 0) + 1 - # Incremental cycle check - if self._has_cycle(): - # Rollback + # Cycle check is skipped when the DAG was constructed with allow_cycles=True + # (state-based pipelines opt into this for ReAct-style loops). + if not self._allow_cycles and self._has_cycle(): self._edges.pop() self._adj[edge.source].pop() self._in_degree[edge.target] -= 1 @@ -236,5 +239,62 @@ def _has_cycle(self) -> bool: queue.append(neighbour) return count != len(self._nodes) + def is_cyclic(self) -> bool: + """True if the graph contains at least one cycle.""" + return self._has_cycle() + + # -- Export ------------------------------------------------------------ + + def to_mermaid(self) -> str: + """Render the topology as a Mermaid flowchart. + + Edges with ``input_key`` other than the default ``"input"`` are + labelled with that key so port wiring is visible. + """ + lines = ["flowchart TD"] + for node_id in self._nodes: + lines.append(f" {_mermaid_id(node_id)}[{node_id}]") + for edge in self._edges: + label = edge.input_key if edge.input_key and edge.input_key != "input" else None + arrow = f"-->|{label}|" if label else "-->" + lines.append(f" {_mermaid_id(edge.source)} {arrow} {_mermaid_id(edge.target)}") + return "\n".join(lines) + + def to_json(self) -> str: + """Render the topology as a JSON document. + + Schema:: + + {"name": str, "nodes": [str], "edges": [{"source", "target", "output_key", "input_key"}]} + """ + doc = { + "name": self._name, + "nodes": list(self._nodes.keys()), + "edges": [ + { + "source": e.source, + "target": e.target, + "output_key": e.output_key, + "input_key": e.input_key, + } + for e in self._edges + ], + } + return json.dumps(doc, indent=2) + def __repr__(self) -> str: return f"DAG(name={self._name!r}, nodes={len(self._nodes)}, edges={len(self._edges)})" + + +def _mermaid_id(node_id: str) -> str: + """Sanitize a node id for use as a Mermaid identifier.""" + out = [] + for ch in node_id: + if ch.isalnum() or ch == "_": + out.append(ch) + else: + out.append("_") + sanitized = "".join(out) + if sanitized and sanitized[0].isdigit(): + sanitized = "n_" + sanitized + return sanitized or "anon" diff --git a/fireflyframework_agentic/pipeline/state_pipeline.py b/fireflyframework_agentic/pipeline/state_pipeline.py index 1b09e185..56d332d3 100644 --- a/fireflyframework_agentic/pipeline/state_pipeline.py +++ b/fireflyframework_agentic/pipeline/state_pipeline.py @@ -36,13 +36,34 @@ from fireflyframework_agentic.exceptions import PipelineError from fireflyframework_agentic.pipeline.checkpoint import Checkpointer, CheckpointRecord -from fireflyframework_agentic.pipeline.dag import DAG +from fireflyframework_agentic.pipeline.dag import DAG, _mermaid_id from fireflyframework_agentic.pipeline.reducers import Reducer, replace logger = logging.getLogger(__name__) StateNodeFn = Callable[[Any], Awaitable[dict[str, Any] | None]] -RouterFn = Callable[[Any], str] +# A router may return: a node id (str), a Send, or a list[Send] for fan-out. +RouterFn = Callable[[Any], "str | Send | list[Send]"] + + +@dataclass +class Send: + """Runtime fan-out dispatch: run ``target`` with ``payload`` merged into state. + + Routers can return a single ``Send`` or a list of ``Send`` to dispatch multiple + target invocations concurrently. Each Send's payload is applied to a *copy* + of the current state before its target runs; the target's return is then + merged back into shared state via reducers. + + Replaces the legacy ``FanOutStep`` pattern with a first-class primitive. + """ + + target: str + payload: dict[str, Any] + + +class RecursionLimitError(Exception): + """Raised when a node is visited more times than ``recursion_limit`` permits.""" @dataclass @@ -129,6 +150,7 @@ def __init__( node_fns: dict[str, StateNodeFn], branches: dict[str, BranchSpec], checkpointer: Checkpointer | None = None, + recursion_limit: int = 25, ) -> None: self._name = name self._dag = dag @@ -136,6 +158,7 @@ def __init__( self._node_fns = node_fns self._branches = branches self._checkpointer = checkpointer + self._recursion_limit = recursion_limit self._reducers = discover_reducers(state_schema) self._validate() @@ -147,6 +170,35 @@ def name(self) -> str: def dag(self) -> DAG: return self._dag + def to_mermaid(self) -> str: + """Render the pipeline as a Mermaid flowchart, including branch edges. + + Branches that omit an explicit mapping are rendered as a dashed edge + labelled ``router`` because the targets are decided at runtime. + """ + lines = ["flowchart TD"] + for node_id in self._dag.nodes: + lines.append(f" {_mermaid_id(node_id)}[{node_id}]") + # Explicit edges (including branch mappings, which were materialized). + rendered: set[tuple[str, str]] = set() + for edge in self._dag.edges: + key = (edge.source, edge.target) + rendered.add(key) + label = None + spec = self._branches.get(edge.source) + if spec and spec.mapping: + for lbl, tgt in spec.mapping.items(): + if tgt == edge.target: + label = lbl + break + arrow = f"-->|{label}|" if label else "-->" + lines.append(f" {_mermaid_id(edge.source)} {arrow} {_mermaid_id(edge.target)}") + # Dynamic branches (no mapping): show as a dashed self-edge stub. + for source, spec in self._branches.items(): + if spec.mapping is None and not self._dag.successors(source): + lines.append(f" {_mermaid_id(source)} -.->|router| {_mermaid_id(source)}_router((dynamic))") + return "\n".join(lines) + def _validate(self) -> None: # Every node must have a registered fn. for node_id in self._dag.nodes: @@ -174,28 +226,17 @@ def _entry_node(self) -> str: raise PipelineError("Pipeline has no nodes") return order[0] - def _next_node(self, current: str, state: BaseModel) -> str | None: - """Decide which successor runs next given the current state. + def _next_step(self, current: str, state: BaseModel) -> str | list[Send] | None: + """Decide what runs next given the current state. - For non-branching nodes: pick the unique successor (or None at terminus). - For branching nodes: run the router and resolve via mapping if present. + Returns: + * A node id (str) for a single deterministic step. + * A list of :class:`Send` for runtime fan-out — workers run concurrently. + * ``None`` when the pipeline reaches a terminus. """ if current in self._branches: - spec = self._branches[current] - label = spec.router(state) - if spec.mapping is not None: - if label not in spec.mapping: - raise PipelineError( - f"Router for '{current}' returned label '{label}' not in mapping {list(spec.mapping)}" - ) - return spec.mapping[label] - # Mapping omitted: router returns target node id directly. - if label not in self._dag.nodes: - raise PipelineError( - f"Router for '{current}' returned '{label}' " - f"which is not a registered node id; pass an explicit mapping if you want labels." - ) - return label + decision = self._branches[current].router(state) + return self._resolve_router_decision(current, decision) successors = self._dag.successors(current) if not successors: @@ -207,6 +248,52 @@ def _next_node(self, current: str, state: BaseModel) -> str | None: ) return successors[0] + def _resolve_router_decision(self, current: str, decision: str | Send | list[Send]) -> str | list[Send] | None: + """Translate a router's return value into a concrete next-step instruction.""" + # Fan-out: list of Send dispatches. + if isinstance(decision, list): + if not decision: + return None + for s in decision: + if not isinstance(s, Send): + raise PipelineError( + f"Router for '{current}' returned a list containing non-Send " + f"element {s!r}; expected list[Send]." + ) + if s.target not in self._dag.nodes: + raise PipelineError(f"Router for '{current}' fans out to unknown target '{s.target}'") + return decision + + if isinstance(decision, Send): + if decision.target not in self._dag.nodes: + raise PipelineError(f"Router for '{current}' dispatched to unknown target '{decision.target}'") + return [decision] + + # String label. + spec = self._branches[current] + if spec.mapping is not None: + if decision not in spec.mapping: + raise PipelineError( + f"Router for '{current}' returned label '{decision}' not in mapping {list(spec.mapping)}" + ) + return spec.mapping[decision] + if decision not in self._dag.nodes: + raise PipelineError( + f"Router for '{current}' returned '{decision}' " + f"which is not a registered node id; pass an explicit mapping if you want labels." + ) + return decision + + def _common_successor(self, node_ids: list[str]) -> str | None: + """Return the node all ``node_ids`` share as their unique successor, or None.""" + successor_sets = [set(self._dag.successors(nid)) for nid in node_ids] + if not successor_sets or any(len(s) != 1 for s in successor_sets): + return None + common = successor_sets[0] + for s in successor_sets[1:]: + common = common & s + return next(iter(common)) if len(common) == 1 else None + async def invoke( self, state: BaseModel | None = None, @@ -235,7 +322,13 @@ async def invoke( resumed_completed = list(record.completed_nodes) # Resume at the successor of the last completed node. last = record.node_id - next_node = self._next_node(last, state) + next_node = self._next_step(last, state) + # Resume can't seamlessly continue mid-fan-out yet; treat fan-out as terminal here. + if isinstance(next_node, list): + raise PipelineError( + "Resume across a fan-out (Send) is not supported in Phase 2; " + "the run finished by reaching a fan-out node." + ) if next_node is None: return StatePipelineResult( state=state, @@ -266,9 +359,58 @@ async def invoke( assert state is not None # narrowed by the branches above completed: list[str] = list(resumed_completed) sequence = len(completed) + visit_counts: dict[str, int] = {} + + next_step: str | list[Send] | None = current_node + last_node_id: str | None = current_node + + while next_step is not None: + # --- fan-out branch (list[Send]) --------------------------------- + if isinstance(next_step, list): + try: + state, sequence = await self._run_fanout( + sends=next_step, + state=state, + completed=completed, + run_id=run_id, + sequence=sequence, + visit_counts=visit_counts, + ) + except _NodeFailureError as fail: + return StatePipelineResult( + state=state, + run_id=run_id, + completed_nodes=completed, + success=False, + error=fail.message, + failed_node=fail.node_id, + ) + last_node_id = next_step[-1].target + # After fan-out, continue from the workers' shared successor (if any). + worker_ids = [s.target for s in next_step] + shared = self._common_successor(worker_ids) + next_step = shared + continue + + # --- single-node step -------------------------------------------- + node_id = next_step + visit_counts[node_id] = visit_counts.get(node_id, 0) + 1 + if visit_counts[node_id] > self._recursion_limit: + msg = ( + f"Recursion limit ({self._recursion_limit}) exceeded at node '{node_id}'. " + f"Raise recursion_limit= or fix the routing logic." + ) + logger.error(msg) + return StatePipelineResult( + state=state, + run_id=run_id, + completed_nodes=completed, + success=False, + error=msg, + failed_node=node_id, + ) - while current_node is not None: - fn = self._node_fns[current_node] + fn = self._node_fns[node_id] t0 = time.perf_counter() try: update = await fn(state) @@ -277,7 +419,7 @@ async def invoke( "State pipeline '%s' run '%s' failed at node '%s'", self._name, run_id, - current_node, + node_id, ) return StatePipelineResult( state=state, @@ -285,35 +427,21 @@ async def invoke( completed_nodes=completed, success=False, error=str(exc), - failed_node=current_node, + failed_node=node_id, ) elapsed = (time.perf_counter() - t0) * 1000 - logger.debug("Pipeline '%s' node '%s' completed in %.1fms", self._name, current_node, elapsed) + logger.debug("Pipeline '%s' node '%s' completed in %.1fms", self._name, node_id, elapsed) if update: state = apply_update(state, update, self._reducers) - completed.append(current_node) + completed.append(node_id) sequence += 1 - - if self._checkpointer is not None: - try: - self._checkpointer.save( - CheckpointRecord( - pipeline_name=self._name, - run_id=run_id, - node_id=current_node, - sequence=sequence, - state=state.model_dump(), - completed_nodes=list(completed), - ) - ) - except Exception: - # Checkpoint failure is non-fatal — log and continue. - logger.exception("Checkpoint save failed for run '%s' at '%s'", run_id, current_node) + self._save_checkpoint(run_id, node_id, sequence, state, completed) + last_node_id = node_id try: - current_node = self._next_node(current_node, state) + next_step = self._next_step(node_id, state) except PipelineError as exc: return StatePipelineResult( state=state, @@ -321,7 +449,7 @@ async def invoke( completed_nodes=completed, success=False, error=str(exc), - failed_node=completed[-1] if completed else None, + failed_node=last_node_id, ) return StatePipelineResult( @@ -331,6 +459,89 @@ async def invoke( success=True, ) + async def _run_fanout( + self, + *, + sends: list[Send], + state: BaseModel, + completed: list[str], + run_id: str, + sequence: int, + visit_counts: dict[str, int], + ) -> tuple[BaseModel, int]: + """Run all ``Send`` dispatches concurrently. Each task gets its own state + copy with the Send's payload merged in; results are reduced into shared state. + """ + for send in sends: + visit_counts[send.target] = visit_counts.get(send.target, 0) + 1 + if visit_counts[send.target] > self._recursion_limit: + raise _NodeFailureError( + node_id=send.target, + message=( + f"Recursion limit ({self._recursion_limit}) exceeded at node '{send.target}' during fan-out." + ), + ) + + async def _run_one(send: Send) -> tuple[Send, dict[str, Any] | None]: + task_state = apply_update(state, send.payload, self._reducers) + fn = self._node_fns[send.target] + return send, await fn(task_state) + + try: + results = await asyncio.gather(*(_run_one(s) for s in sends)) + except Exception as exc: + # Best-effort: report the first failing target as the failure point. + raise _NodeFailureError( + node_id=sends[0].target, + message=f"Fan-out failure: {exc}", + ) from exc + + new_state = state + for send, update in results: + if update: + new_state = apply_update(new_state, update, self._reducers) + completed.append(send.target) + sequence += 1 + self._save_checkpoint(run_id, send.target, sequence, new_state, completed) + + return new_state, sequence + + def _save_checkpoint( + self, + run_id: str, + node_id: str, + sequence: int, + state: BaseModel, + completed: list[str], + ) -> None: + """Persist state via the configured checkpointer (no-op if absent).""" + if self._checkpointer is None: + return + try: + self._checkpointer.save( + CheckpointRecord( + pipeline_name=self._name, + run_id=run_id, + node_id=node_id, + sequence=sequence, + state=state.model_dump(), + completed_nodes=list(completed), + ) + ) + except Exception: + logger.exception("Checkpoint save failed for run '%s' at '%s'", run_id, node_id) + + +@dataclass +class _NodeFailureError(Exception): + """Internal sentinel used to bubble fan-out failures out to the main loop.""" + + node_id: str + message: str + + def __str__(self) -> str: + return self.message + def _resolve_node_id(ref: str | Callable[..., Any]) -> str: """Turn either a string id or a function reference into a node id.""" diff --git a/fireflyframework_agentic/pipeline/steps.py b/fireflyframework_agentic/pipeline/steps.py index 47d12220..68e2bb29 100644 --- a/fireflyframework_agentic/pipeline/steps.py +++ b/fireflyframework_agentic/pipeline/steps.py @@ -20,6 +20,7 @@ import asyncio import logging +import warnings from collections.abc import Callable, Coroutine from typing import Any, Protocol, runtime_checkable @@ -148,6 +149,12 @@ def classify(inputs): """ def __init__(self, router: Callable[[dict[str, Any]], str]) -> None: + warnings.warn( + "BranchStep is deprecated; use PipelineBuilder(state=...).branch(source, router) " + "for first-class declarative branching.", + DeprecationWarning, + stacklevel=2, + ) self._router = router async def execute(self, context: PipelineContext, inputs: dict[str, Any]) -> Any: @@ -162,6 +169,12 @@ class FanOutStep: """ def __init__(self, split_fn: Callable[[Any], list[Any]]) -> None: + warnings.warn( + "FanOutStep is deprecated; use PipelineBuilder(state=...) with a router returning " + "list[Send(target, payload)] for first-class runtime fan-out.", + DeprecationWarning, + stacklevel=2, + ) self._split_fn = split_fn async def execute(self, context: PipelineContext, inputs: dict[str, Any]) -> Any: diff --git a/tests/unit/pipeline/test_state_pipeline_phase2.py b/tests/unit/pipeline/test_state_pipeline_phase2.py new file mode 100644 index 00000000..f2040f80 --- /dev/null +++ b/tests/unit/pipeline/test_state_pipeline_phase2.py @@ -0,0 +1,231 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""Phase-2 tests: cycles + recursion_limit, Send fan-out, Mermaid export, +soft-deprecation of BranchStep / FanOutStep. +""" + +from __future__ import annotations + +import json +import warnings +from typing import Annotated + +import pytest +from pydantic import BaseModel + +from fireflyframework_agentic.pipeline import ( + DAG, + BranchStep, + DAGEdge, + DAGNode, + FailureStrategy, + FanOutStep, + PipelineBuilder, + Send, + StatePipeline, + extend, +) +from fireflyframework_agentic.pipeline.steps import CallableStep + + +class LoopState(BaseModel): + counter: int = 0 + log: Annotated[list[str], extend] = [] + + +# --- cycles + recursion_limit ---------------------------------------------- + + +@pytest.mark.asyncio +async def test_simple_cycle_with_exit_router(): + """A node loops back to itself N times, then a router exits to END.""" + + async def step(state: LoopState) -> dict: + return {"counter": state.counter + 1, "log": [f"step#{state.counter + 1}"]} + + async def done(state: LoopState) -> dict: + return {"log": ["done"]} + + def route(state: LoopState) -> str: + return "done" if state.counter >= 3 else "step" + + pipeline = PipelineBuilder("loop", state=LoopState).add_node(step).add_node(done).branch(step, route).build() + assert isinstance(pipeline, StatePipeline) + result = await pipeline.invoke(LoopState()) + assert result.success + assert result.state.counter == 3 + assert "done" in result.state.log + # 3 step visits + 1 done = 4 entries before done's own log entry. + assert result.completed_nodes == ["step", "step", "step", "done"] + + +@pytest.mark.asyncio +async def test_recursion_limit_aborts_infinite_loop(): + """A router that never exits triggers the recursion_limit safety net.""" + + async def step(state: LoopState) -> dict: + return {"counter": state.counter + 1} + + def never_exits(state: LoopState) -> str: + return "step" + + pipeline = ( + PipelineBuilder("inf", state=LoopState, recursion_limit=5).add_node(step).branch(step, never_exits).build() + ) + result = await pipeline.invoke(LoopState()) + assert not result.success + assert "Recursion limit" in (result.error or "") + assert result.failed_node == "step" + # The node ran exactly recursion_limit times before the guard fired. + assert result.state.counter == 5 + + +# --- Send fan-out ----------------------------------------------------------- + + +class FanOutState(BaseModel): + items: list[str] = [] + results: Annotated[list[str], extend] = [] + item: str | None = None # filled per-Send via payload + + +@pytest.mark.asyncio +async def test_send_fans_out_to_multiple_workers_and_merges_results(): + """Router returns list[Send]; workers run concurrently; reducer merges.""" + + async def planner(state: FanOutState) -> dict: + return {} # passthrough; could populate items if not preset + + async def worker(state: FanOutState) -> dict: + # Each worker sees its own copy of state with the Send payload applied. + assert state.item is not None + return {"results": [f"processed:{state.item}"]} + + async def collect(state: FanOutState) -> dict: + return {"results": ["collected"]} + + def dispatch(state: FanOutState) -> list[Send]: + return [Send("worker", {"item": x}) for x in state.items] + + pipeline = ( + PipelineBuilder("mapreduce", state=FanOutState) + .add_node(planner) + .add_node(worker) + .add_node(collect) + .add_edge(worker, collect) + .branch(planner, dispatch) + .build() + ) + result = await pipeline.invoke(FanOutState(items=["a", "b", "c"])) + assert result.success + processed = sorted(r for r in result.state.results if r.startswith("processed:")) + assert processed == ["processed:a", "processed:b", "processed:c"] + assert "collected" in result.state.results + # Each worker counts as a completed node visit; planner once, three workers, then collect. + assert result.completed_nodes.count("worker") == 3 + assert result.completed_nodes[-1] == "collect" + + +@pytest.mark.asyncio +async def test_send_to_unknown_target_fails_cleanly(): + async def planner(state: FanOutState) -> dict: + return {} + + async def worker(state: FanOutState) -> dict: + return {} + + def bad_dispatch(state: FanOutState) -> list[Send]: + return [Send("ghost", {})] + + pipeline = ( + PipelineBuilder("bad", state=FanOutState) + .add_node(planner) + .add_node(worker) + .branch(planner, bad_dispatch) + .build() + ) + result = await pipeline.invoke(FanOutState()) + assert not result.success + assert "ghost" in (result.error or "") + + +# --- Mermaid + JSON export -------------------------------------------------- + + +def test_dag_to_mermaid_renders_topology(): + dag = DAG(name="example") + dag.add_node(DAGNode(node_id="a", step=CallableStep(_noop_async))) + dag.add_node(DAGNode(node_id="b", step=CallableStep(_noop_async))) + dag.add_edge(DAGEdge(source="a", target="b")) + out = dag.to_mermaid() + assert out.startswith("flowchart TD") + assert "a[a]" in out + assert "b[b]" in out + assert "a --> b" in out + + +def test_dag_to_json_round_trips_via_pydantic(): + dag = DAG(name="example") + dag.add_node(DAGNode(node_id="a", step=CallableStep(_noop_async))) + dag.add_node(DAGNode(node_id="b", step=CallableStep(_noop_async), failure_strategy=FailureStrategy.FAIL_PIPELINE)) + dag.add_edge(DAGEdge(source="a", target="b", input_key="payload")) + doc = json.loads(dag.to_json()) + assert doc["name"] == "example" + assert doc["nodes"] == ["a", "b"] + assert doc["edges"] == [{"source": "a", "target": "b", "output_key": "output", "input_key": "payload"}] + + +def test_state_pipeline_to_mermaid_labels_branch_edges(): + async def start(state: LoopState) -> dict: + return {} + + async def left(state: LoopState) -> dict: + return {} + + async def right(state: LoopState) -> dict: + return {} + + def route(state: LoopState) -> str: + return "left_path" + + pipeline = ( + PipelineBuilder("branched", state=LoopState) + .add_node(start) + .add_node(left) + .add_node(right) + .branch(start, route, {"left_path": left, "right_path": right}) + .build() + ) + assert isinstance(pipeline, StatePipeline) + mermaid = pipeline.to_mermaid() + assert "start -->|left_path| left" in mermaid + assert "start -->|right_path| right" in mermaid + + +# --- soft-deprecation ------------------------------------------------------ + + +def test_branch_step_emits_deprecation_warning(): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + BranchStep(router=lambda _: "x") + assert any(issubclass(w.category, DeprecationWarning) for w in caught) + assert any("branch(" in str(w.message) for w in caught) + + +def test_fan_out_step_emits_deprecation_warning(): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + FanOutStep(split_fn=lambda x: [x]) + assert any(issubclass(w.category, DeprecationWarning) for w in caught) + assert any("Send" in str(w.message) for w in caught) + + +# --- helpers --------------------------------------------------------------- + + +async def _noop_async(ctx, inputs): + return None From 32e373fec75c32c3683ef2dacbed6d0b2f060482 Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Wed, 27 May 2026 16:09:23 +0200 Subject: [PATCH 03/26] docs(pipeline): state-mode example + docs section, simplify state_pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - examples/pipeline_state.py: runnable demo covering branching, software-factory with checkpoint/resume, and Send-based map-reduce. No API key needed. - docs/pipeline.md: new "State-Based Pipelines" section documenting state schema, reducers, .branch, checkpoint/resume, recursion_limit, Send fan-out, Mermaid export, and a "when to use which mode" comparison. Mark BranchStep / FanOutStep as deprecated with pointers to the new API. - state_pipeline.py: small simplifications from the simplifier pass — dead 'rendered' set removed from to_mermaid, _common_successor uses direct equality instead of set intersection, dropped unused last_node_id tracker, _NodeFailureError as a plain Exception subclass instead of @dataclass. 97 pipeline tests still green; ruff check + format clean. --- docs/pipeline.md | 189 +++++++++++++- examples/pipeline_state.py | 242 ++++++++++++++++++ .../pipeline/state_pipeline.py | 32 +-- 3 files changed, 438 insertions(+), 25 deletions(-) create mode 100644 examples/pipeline_state.py diff --git a/docs/pipeline.md b/docs/pipeline.md index 57b0d152..2ece529e 100644 --- a/docs/pipeline.md +++ b/docs/pipeline.md @@ -7,6 +7,17 @@ workflows. It supports parallel execution, conditional branching, retries, timeo and fan-out/fan-in patterns -- everything needed to model real-world enterprise processing pipelines. +`PipelineBuilder` has two modes: + +* **Port-based** (legacy, parallel) — nodes communicate via `output_key` / + `input_key` edge ports and run concurrently within each topological level. + Best for ETL-shaped DAGs. Documented in the bulk of this guide. +* **State-based** — opt-in via `PipelineBuilder("name", state=SomeModel)`. + Nodes become `async (state) -> dict` over a typed shared state. One + `.branch(source, router)` call covers conditional routing; `Send(target, payload)` + covers runtime fan-out; a `Checkpointer` enables resume after failure. Best for + agentic workflows and ReAct-style loops. See [State-Based Pipelines](#state-based-pipelines). + --- ## Concepts @@ -88,15 +99,181 @@ The framework provides these built-in executors: - **CallableStep** -- Wraps any `async` function `(context, inputs) -> output`. - **BatchLLMStep** -- Processes multiple prompts concurrently through an agent for cost optimization. See [Batch Processing](#batch-processing-batchllmstep) below. -- **BranchStep** -- Routes execution to one of several downstream paths based on - a predicate (see [Conditional Branching](#conditional-branching-branchstep) below). -- **FanOutStep** -- Splits input into a list for parallel downstream processing. +- **BranchStep** _(deprecated)_ -- Routes execution to one of several downstream paths based on + a predicate. Use `.branch(...)` in [State-Based Pipelines](#state-based-pipelines) instead. +- **FanOutStep** _(deprecated)_ -- Splits input into a list for parallel downstream processing. + Use `Send` in [State-Based Pipelines](#runtime-fan-out-via-send) instead. - **FanInStep** -- Merges outputs from multiple upstream nodes. --- +## State-Based Pipelines + +Set `state=` on `PipelineBuilder` to switch to a declarative API designed for +agentic workflows. Nodes become `async (state) -> dict | None` functions over +a typed shared-state object; the engine reduces each node's partial-update +dict back into the state. + +```python +from typing import Annotated +from pydantic import BaseModel +from fireflyframework_agentic.pipeline import PipelineBuilder, append + + +class AgentState(BaseModel): + messages: Annotated[list[str], append] = [] # reducer: append + intent: str | None = None # default reducer: replace + answer: str | None = None + + +async def classify(state: AgentState) -> dict: + return {"intent": "complaint" if "refund" in state.messages[-1] else "general"} + + +async def answer(state: AgentState) -> dict: + return {"answer": "Here is your answer."} + + +async def escalate(state: AgentState) -> dict: + return {"answer": "Escalated to human."} + + +def route(state: AgentState) -> str: + return "escalate" if state.intent == "complaint" else "answer" + + +pipeline = ( + PipelineBuilder("support-agent", state=AgentState) + .add_node(classify) # node id derived from fn.__name__ + .add_node(answer) + .add_node(escalate) + .branch(classify, route) # router returns target node id + .build() +) +result = await pipeline.invoke(AgentState(messages=["I want a refund"])) +print(result.state.answer) +``` + +### Reducers + +Reducers are declared as `Annotated[T, reducer_fn]` on the state schema. The +built-ins live in `fireflyframework_agentic.pipeline.reducers`: + +| Reducer | Semantics | +|---------------|-------------------------------------------------| +| `replace` | Last-write-wins (the default for any field). | +| `append` | Append a single item to a list. | +| `extend` | Concatenate two iterables. | +| `merge_dict` | Shallow-merge two dicts; update wins on conflict. | + +Custom reducers are any callable `(current, update) -> merged`. + +### Branching + +`.branch(source, router, mapping=None)` registers a synchronous +`(state) -> str | Send | list[Send]` router on `source`: + +* Returning a node id (string) routes to that node directly. +* Passing `mapping={"label": target_node, ...}` lets the router return an + abstract label instead of a node id. +* Returning a `Send` or `list[Send]` triggers runtime fan-out (see below). + +### Checkpoint + Resume + +Pass a `Checkpointer` to persist state after each successful node. The +filesystem implementation ships in the package; Redis/Postgres backends are +straightforward to plug in via the `Checkpointer` Protocol. + +```python +from fireflyframework_agentic.pipeline import FileCheckpointer + +pipeline = ( + PipelineBuilder("software-factory", state=BuildState, + checkpointer=FileCheckpointer("./checkpoints")) + .add_node(architect) + .add_node(python_dev) + .add_node(deployer) + .add_node(evaluator) + .chain(architect, python_dev, deployer, evaluator) + .build() +) + +# Fresh run +result = await pipeline.invoke(BuildState(requirements="user-mgmt service")) + +# Resume after crash — picks up at the failed node, skips completed ones +result = await pipeline.invoke(run_id=result.run_id) + +# Or jump into a specific node with explicit state +result = await pipeline.invoke(state=loaded_state, start_at=deployer) +``` + +### Cycles and `recursion_limit` + +State pipelines permit cycles for ReAct loops and retry-with-critique patterns. +The builder accepts `recursion_limit` (default 25) as a safety net — a runaway +loop surfaces as `result.success=False` with a clean error, not an infinite hang. + +```python +def route(state): + return "done" if state.counter >= 3 else "step" + +PipelineBuilder("loop", state=LoopState, recursion_limit=25) + .add_node(step).add_node(done).branch(step, route).build() +``` + +### Runtime Fan-Out via `Send` + +A router may return `list[Send(target, payload)]` to dispatch multiple +invocations of the same (or different) workers concurrently. Each Send's +payload is applied to a copy of the current state before its target runs; +results reduce back into shared state. Replaces the legacy `FanOutStep`. + +```python +from fireflyframework_agentic.pipeline import Send + +def dispatch(state): + return [Send("worker", {"item": x}) for x in state.items] + +PipelineBuilder("mapreduce", state=MapReduceState) + .add_node(planner).add_node(worker).add_node(collect) + .add_edge(worker, collect) + .branch(planner, dispatch) + .build() +``` + +When all worker targets share a common successor, the engine continues there +once the fan-out completes; the aggregator runs once with all results in +shared state. + +### Mermaid Export + +`StatePipeline.to_mermaid()` and `DAG.to_mermaid()` render the topology as a +Mermaid flowchart. Branch edges declared with an explicit mapping show their +label; dynamic routers are noted as such. + +### When to use which mode + +| Use port-based when… | Use state-based when… | +|----------------------|------------------------| +| Pure ETL: parallel, fan-out/fan-in, no shared state | Agentic workflow: classify → branch → respond / loop / retry | +| Each step's input is a single value from the previous step | Multiple agents reading/writing different fields of a shared object | +| You want the engine to run independent nodes concurrently | You want resume-after-failure and start-from-middle semantics | +| You're happy with `BranchStep` + per-node `condition` lambdas | You want one `.branch(...)` call and inspectable routing | + +See [`examples/pipeline_state.py`](../examples/pipeline_state.py) for a +runnable demo covering branching, software-factory checkpoint/resume, and +map-reduce fan-out. + +--- + ## Parallel Execution (Fan-Out / Fan-In) +> **`FanOutStep` is deprecated.** For runtime fan-out (one dispatch per item, +> arbitrary count), prefer `Send` from [State-Based Pipelines](#runtime-fan-out-via-send). +> `FanOutStep` still works for now (it emits a `DeprecationWarning` on +> construction); `FanInStep` is not deprecated. + ```mermaid graph TD SPLIT[Fan-Out] --> W1[Worker 1] @@ -248,6 +425,12 @@ dag.add_node(DAGNode( ### Conditional Branching (BranchStep) +> **Deprecated.** Prefer [State-Based Pipelines](#state-based-pipelines) with +> `.branch(source, router)` — one call instead of `BranchStep` + per-node +> `condition` lambdas, and the topology becomes inspectable as data. +> `BranchStep` still works (it emits a `DeprecationWarning` on construction); +> removal will be tracked in a follow-up issue once internal callers migrate. + `BranchStep` provides router-based conditional branching. The router callable receives the node's input and returns a string key. Downstream nodes use condition gates to check the branch key and execute only the matching path. diff --git a/examples/pipeline_state.py b/examples/pipeline_state.py new file mode 100644 index 00000000..9f0fc463 --- /dev/null +++ b/examples/pipeline_state.py @@ -0,0 +1,242 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""State-based PipelineBuilder: branching, checkpoint/resume, and Send fan-out. + +Three scenarios: + +1. **Branching** — same sentiment-classification workflow as + ``examples/pipeline_branching.py``, but written with the state-mode API + (one shared ``State`` model, ``async (state) -> dict`` nodes, one + ``.branch(source, router)`` call instead of ``BranchStep`` + per-node + ``condition`` lambdas). + +2. **Software factory with checkpoint/resume** — a four-agent pipeline + (architect → python_dev → deployer → evaluator) where the deployer fails + on its first attempt. The pipeline checkpoints after each successful node, + and a second ``invoke(run_id=...)`` resumes from the failed node instead + of re-running the earlier agents. + +3. **Map-reduce via ``Send``** — a planner dispatches one ``Send`` per work + item to the same worker node, the workers run concurrently, and an + aggregator runs once with all results in shared state. + +Usage:: + + uv run python examples/pipeline_state.py + +.. note:: No OpenAI API key required — all "agents" are plain Python stubs. +""" + +from __future__ import annotations + +import asyncio +import logging +import tempfile +from pathlib import Path +from typing import Annotated + +from pydantic import BaseModel + +from fireflyframework_agentic.pipeline import ( + FileCheckpointer, + PipelineBuilder, + Send, + extend, +) + +# Quiet the pipeline's own logger.exception() when we deliberately fail +# the deployer in scenario 2 — the failure is the demo, not a bug. +logging.getLogger("fireflyframework_agentic.pipeline").setLevel(logging.CRITICAL) + + +# ============================================================================= +# Scenario 1 — Branching +# ============================================================================= + + +class SentimentState(BaseModel): + text: str + sentiment: str | None = None + response: str | None = None + + +async def classify_sentiment(state: SentimentState) -> dict: + text = state.text.lower() + positive = {"good", "great", "love", "amazing", "wonderful", "happy", "excellent"} + negative = {"bad", "terrible", "hate", "awful", "horrible", "sad", "poor"} + pos = sum(1 for w in text.split() if w in positive) + neg = sum(1 for w in text.split() if w in negative) + return {"sentiment": "positive" if pos >= neg else "negative"} + + +async def positive_reply(state: SentimentState) -> dict: + return {"response": "😊 Thank you for your kind words!"} + + +async def negative_reply(state: SentimentState) -> dict: + return {"response": "😟 We're sorry to hear that. We'll improve!"} + + +def route_by_sentiment(state: SentimentState) -> str: + # The router returns the node id directly — no mapping needed. + return "positive_reply" if state.sentiment == "positive" else "negative_reply" + + +async def run_branching() -> None: + print("=== 1. Branching (state mode) ===\n") + + pipeline = ( + PipelineBuilder("sentiment", state=SentimentState) + .add_node(classify_sentiment) + .add_node(positive_reply) + .add_node(negative_reply) + .branch(classify_sentiment, route_by_sentiment) + .build() + ) + + for text in ["This product is great and amazing!", "The service was terrible and awful."]: + result = await pipeline.invoke(SentimentState(text=text)) + print(f" input: {text!r}") + print(f" output: {result.state.response}\n") + + +# ============================================================================= +# Scenario 2 — Software factory with checkpoint/resume +# ============================================================================= + + +class BuildState(BaseModel): + """State threaded through a four-agent software-factory pipeline.""" + + requirements: str + spec: str | None = None + code: str | None = None + deploy_url: str | None = None + evaluation: str | None = None + + +# A flag so the deployer fails the first time and succeeds the second. +_deployer_failed_once = {"flag": False} + + +async def architect(state: BuildState) -> dict: + return {"spec": f"Architecture for: {state.requirements}"} + + +async def python_dev(state: BuildState) -> dict: + return {"code": f"# code implementing\n# {state.spec}"} + + +async def deployer(state: BuildState) -> dict: + if not _deployer_failed_once["flag"]: + _deployer_failed_once["flag"] = True + raise RuntimeError("network blip — try again") + return {"deploy_url": "https://factory-app.example.com"} + + +async def evaluator(state: BuildState) -> dict: + return {"evaluation": f"PASS — deployed at {state.deploy_url}"} + + +async def run_software_factory() -> None: + print("=== 2. Software factory with checkpoint/resume ===\n") + + with tempfile.TemporaryDirectory() as tmp: + ckpt = FileCheckpointer(Path(tmp)) + pipeline = ( + PipelineBuilder("software-factory", state=BuildState, checkpointer=ckpt) + .add_node(architect) + .add_node(python_dev) + .add_node(deployer) + .add_node(evaluator) + .chain(architect, python_dev, deployer, evaluator) + .build() + ) + + # First run — deployer fails after architect + python_dev complete. + first = await pipeline.invoke(BuildState(requirements="User-management service")) + print(f" first run: success={first.success}, failed_node={first.failed_node}") + print(f" completed: {first.completed_nodes}") + print(f" run_id: {first.run_id}\n") + + # Resume — picks up at deployer, skips architect + python_dev. + second = await pipeline.invoke(run_id=first.run_id) + print(f" resumed: success={second.success}") + print(f" completed: {second.completed_nodes}") + print(f" eval: {second.state.evaluation}\n") + + +# ============================================================================= +# Scenario 3 — Map-reduce via Send +# ============================================================================= + + +class MapReduceState(BaseModel): + items: list[str] = [] + processed: Annotated[list[str], extend] = [] + summary: str | None = None + # Per-Send payload field — each worker receives its own item here. + item: str | None = None + + +async def plan(state: MapReduceState) -> dict: + # No state mutation; the dispatch router below decides what runs next. + return {} + + +async def process_item(state: MapReduceState) -> dict: + assert state.item is not None + return {"processed": [f"processed:{state.item}"]} + + +async def aggregate(state: MapReduceState) -> dict: + return {"summary": f"Processed {len(state.processed)} items: {state.processed}"} + + +def dispatch(state: MapReduceState) -> list[Send]: + # One Send per item — workers run concurrently. The ``extend`` reducer on + # ``processed`` merges all worker outputs into one list. + return [Send("process_item", {"item": x}) for x in state.items] + + +async def run_map_reduce() -> None: + print("=== 3. Map-reduce via Send ===\n") + + pipeline = ( + PipelineBuilder("mapreduce", state=MapReduceState) + .add_node(plan) + .add_node(process_item) + .add_node(aggregate) + .add_edge(process_item, aggregate) + .branch(plan, dispatch) + .build() + ) + result = await pipeline.invoke(MapReduceState(items=["alpha", "beta", "gamma", "delta"])) + print(f" summary: {result.state.summary}") + + +# ============================================================================= +# Entrypoint +# ============================================================================= + + +async def main() -> None: + await run_branching() + await run_software_factory() + await run_map_reduce() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/fireflyframework_agentic/pipeline/state_pipeline.py b/fireflyframework_agentic/pipeline/state_pipeline.py index 56d332d3..6c818bf7 100644 --- a/fireflyframework_agentic/pipeline/state_pipeline.py +++ b/fireflyframework_agentic/pipeline/state_pipeline.py @@ -180,10 +180,7 @@ def to_mermaid(self) -> str: for node_id in self._dag.nodes: lines.append(f" {_mermaid_id(node_id)}[{node_id}]") # Explicit edges (including branch mappings, which were materialized). - rendered: set[tuple[str, str]] = set() for edge in self._dag.edges: - key = (edge.source, edge.target) - rendered.add(key) label = None spec = self._branches.get(edge.source) if spec and spec.mapping: @@ -286,13 +283,11 @@ def _resolve_router_decision(self, current: str, decision: str | Send | list[Sen def _common_successor(self, node_ids: list[str]) -> str | None: """Return the node all ``node_ids`` share as their unique successor, or None.""" - successor_sets = [set(self._dag.successors(nid)) for nid in node_ids] - if not successor_sets or any(len(s) != 1 for s in successor_sets): + successors = [self._dag.successors(nid) for nid in node_ids] + if not successors or any(len(s) != 1 for s in successors): return None - common = successor_sets[0] - for s in successor_sets[1:]: - common = common & s - return next(iter(common)) if len(common) == 1 else None + first = successors[0][0] + return first if all(s[0] == first for s in successors[1:]) else None async def invoke( self, @@ -362,7 +357,6 @@ async def invoke( visit_counts: dict[str, int] = {} next_step: str | list[Send] | None = current_node - last_node_id: str | None = current_node while next_step is not None: # --- fan-out branch (list[Send]) --------------------------------- @@ -385,11 +379,8 @@ async def invoke( error=fail.message, failed_node=fail.node_id, ) - last_node_id = next_step[-1].target # After fan-out, continue from the workers' shared successor (if any). - worker_ids = [s.target for s in next_step] - shared = self._common_successor(worker_ids) - next_step = shared + next_step = self._common_successor([s.target for s in next_step]) continue # --- single-node step -------------------------------------------- @@ -438,7 +429,6 @@ async def invoke( completed.append(node_id) sequence += 1 self._save_checkpoint(run_id, node_id, sequence, state, completed) - last_node_id = node_id try: next_step = self._next_step(node_id, state) @@ -449,7 +439,7 @@ async def invoke( completed_nodes=completed, success=False, error=str(exc), - failed_node=last_node_id, + failed_node=node_id, ) return StatePipelineResult( @@ -532,15 +522,13 @@ def _save_checkpoint( logger.exception("Checkpoint save failed for run '%s' at '%s'", run_id, node_id) -@dataclass class _NodeFailureError(Exception): """Internal sentinel used to bubble fan-out failures out to the main loop.""" - node_id: str - message: str - - def __str__(self) -> str: - return self.message + def __init__(self, node_id: str, message: str) -> None: + super().__init__(message) + self.node_id = node_id + self.message = message def _resolve_node_id(ref: str | Callable[..., Any]) -> str: From 159a2d4eb532ae7978cf982aa3f729655bb3137e Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Wed, 27 May 2026 16:49:24 +0200 Subject: [PATCH 04/26] =?UTF-8?q?docs(spec):=20Phase=203a=20design=20?= =?UTF-8?q?=E2=80=94=20Redis=20+=20Postgres=20checkpointers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../2026-05-27-pipeline-phase-3a-design.md | 209 ++++++++++++++++++ 1 file changed, 209 insertions(+) create mode 100644 docs/superpowers/specs/2026-05-27-pipeline-phase-3a-design.md diff --git a/docs/superpowers/specs/2026-05-27-pipeline-phase-3a-design.md b/docs/superpowers/specs/2026-05-27-pipeline-phase-3a-design.md new file mode 100644 index 00000000..7eb5f684 --- /dev/null +++ b/docs/superpowers/specs/2026-05-27-pipeline-phase-3a-design.md @@ -0,0 +1,209 @@ +# Phase 3a — Durable Checkpointer Backends (Redis + Postgres) + +Issue [#147](https://github.com/fireflyframework/fireflyframework-agentic/issues/147), phase 3a. Stacked on top of Phase 1+2 (`issue-147-pipeline-evolution`). + +## Problem + +`StatePipeline` checkpointing today has one backend: `FileCheckpointer`, which writes JSON files to a local filesystem. That blocks the use cases customers actually run agentic pipelines for: + +- **Multi-worker fail-over.** A pipeline that crashes on worker A cannot be resumed on worker B because the checkpoint files are on A's disk. +- **Containerized deploys.** Ephemeral container filesystems lose checkpoint state on restart. +- **Horizontal scaling.** No shared state means runs are pinned to a single host. +- **Long-lived runs.** Filesystem checkpoints accumulate without TTL; cleanup is manual. + +Phase 3a ships two durable backends — Redis and Postgres — pluggable through the existing `Checkpointer` Protocol. No API changes to `StatePipeline`; no breaking changes for existing `FileCheckpointer` users. + +## Goal + +`PipelineBuilder("agent", state=..., checkpointer=...)` accepts a `RedisCheckpointer` or `PostgresCheckpointer` interchangeably with the existing `FileCheckpointer`. The software-factory scenario (architect → python_dev → deployer fails → resume on a different worker → evaluator) works end-to-end against any of the three. + +## Non-goals + +- Backend-pluggable observability / OTel spans (phase 3b). +- HITL pause primitive and audit log (phase 3c). +- Real-Redis / real-Postgres tests in this PR — verified out-of-band by the user. +- Schema migration tooling for the JSONB `state` column — the application owns state-schema evolution. +- Connection-pool tuning knobs beyond "pass me a client." +- Auth / RBAC on resume APIs. +- Streaming partial state. + +## Architecture + +Single file: `fireflyframework_agentic/pipeline/checkpoint.py`. Existing contents (`Checkpointer` Protocol, `CheckpointRecord`, `FileCheckpointer`) are preserved. Two new classes added to the same module: + +``` +pipeline/checkpoint.py +├── Checkpointer (Protocol, existing) +├── CheckpointRecord (existing) +├── FileCheckpointer (existing) +├── RedisCheckpointer (new) +└── PostgresCheckpointer (new) +``` + +### Guarded optional imports + +Top-of-file guarded imports mirror the established pattern in `engine.py` for OTel: + +```python +try: + import redis as _redis +except ImportError: + _redis = None + +try: + import psycopg as _psycopg +except ImportError: + _psycopg = None +``` + +Each backend's `__init__` raises a clear install-instruction error if its dependency is absent. No inline imports — project rule respected. + +### Connection lifecycle + +Both backends accept either a connection string OR a pre-built sync client so callers can share a connection pool across many pipelines: + +```python +RedisCheckpointer(url="redis://localhost:6379/0", ttl_seconds=86400 * 30) +RedisCheckpointer(client=existing_redis_client) + +PostgresCheckpointer(dsn="postgresql://user:pw@host/db") +PostgresCheckpointer(connection=existing_psycopg_connection) +``` + +The `Checkpointer` Protocol's methods are sync (called synchronously inside `StatePipeline._save_checkpoint`). Both backends use the sync clients (`redis-py`, `psycopg[binary]`) directly — no `asyncio.run` indirection. + +## Storage schemas + +### Redis + +One key per checkpoint, plus a sorted-set index per pipeline for `list_runs`: + +| Key | Type | Value | +|---|---|---| +| `firefly:ckpt:{pipeline}:{run_id}:{seq:06d}_{node_id}` | string | JSON-serialized `CheckpointRecord` | +| `firefly:ckpt:{pipeline}:runs` | sorted set | members = `run_id`s, scores = last-update unix timestamps | + +- TTL on the per-checkpoint keys is configurable, default 30 days. +- `save` issues `SET key value EX ttl` then `ZADD index ts run_id` (idempotent — score updated each call). +- `load_latest` issues `KEYS firefly:ckpt:{pipeline}:{run_id}:*`, picks the lexicographically last key (the zero-padded `seq` makes lex order match numeric order), then `GET` it. +- `list_runs` issues `ZRANGE index 0 -1` and returns all run IDs ordered by last-update. + +`KEYS` is acceptable here because the cardinality per run is bounded by node count × visit count (small for agentic workflows). A `SCAN`-based variant is trivial to add if scale demands it; not in scope for 3a. + +### Postgres + +Single table, created idempotently on first connect: + +```sql +CREATE TABLE IF NOT EXISTS firefly_checkpoints ( + pipeline_name TEXT NOT NULL, + run_id TEXT NOT NULL, + sequence INT NOT NULL, + node_id TEXT NOT NULL, + state JSONB NOT NULL, + completed_nodes JSONB NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (pipeline_name, run_id, sequence) +); +CREATE INDEX IF NOT EXISTS firefly_checkpoints_run + ON firefly_checkpoints (pipeline_name, run_id); +``` + +- DDL runs once per `PostgresCheckpointer` instance via a `_ddl_applied` flag set on first `save`. +- `save` issues `INSERT … ON CONFLICT (pipeline_name, run_id, sequence) DO UPDATE SET node_id = EXCLUDED.node_id, state = EXCLUDED.state, completed_nodes = EXCLUDED.completed_nodes`. +- `load_latest` issues `SELECT pipeline_name, run_id, sequence, node_id, state, completed_nodes FROM firefly_checkpoints WHERE pipeline_name = %s AND run_id = %s ORDER BY sequence DESC LIMIT 1`. +- `list_runs` issues `SELECT DISTINCT run_id FROM firefly_checkpoints WHERE pipeline_name = %s ORDER BY run_id`. + +## Optional dependencies + +`pyproject.toml` gains two new extras, no new required dependencies and no new dev-dependencies: + +```toml +[project.optional-dependencies] +redis = ["redis>=5,<6"] +postgres = ["psycopg[binary]>=3,<4"] +``` + +Install paths: + +```bash +pip install fireflyframework-agentic[redis] +pip install fireflyframework-agentic[postgres] +pip install fireflyframework-agentic[redis,postgres] +``` + +## Testing strategy + +`unittest.mock` only. No `fakeredis`, no `pytest-postgresql`, no test containers. Real-service verification is out-of-band by the user against real Redis and real Postgres, not in this PR. + +### `RedisCheckpointer` tests + +Mock the `redis.Redis` client (`mock.create_autospec(redis.Redis)`). Assertions: + +- `save` issues `client.set(key=..., value=, ex=)` with the expected key format and TTL. +- `save` issues `client.zadd("firefly:ckpt::runs", {: })`. +- `load_latest` issues `client.keys("firefly:ckpt:::*")`, picks the lex-last key, then `client.get()`, returns the parsed `CheckpointRecord`. +- `load_latest` returns `None` when `keys` returns an empty list. +- `list_runs` issues `client.zrange("firefly:ckpt::runs", 0, -1)` and returns its result. +- Constructing `RedisCheckpointer(url=...)` when `_redis is None` raises `ImportError` with a message that names the extra (`pip install fireflyframework-agentic[redis]`). + +### `PostgresCheckpointer` tests + +Mock the `psycopg.Connection` and its `cursor()` context manager. Assertions: + +- First `save` issues the `CREATE TABLE IF NOT EXISTS` DDL exactly once. The DDL flag prevents subsequent `save` calls from re-issuing it. +- `save` issues `INSERT … ON CONFLICT (pipeline_name, run_id, sequence) DO UPDATE SET …` with bound parameters matching the `CheckpointRecord` fields. +- `load_latest` issues `SELECT … ORDER BY sequence DESC LIMIT 1` and returns the row hydrated into a `CheckpointRecord`. Returns `None` for the no-rows case. +- `list_runs` issues `SELECT DISTINCT run_id … ORDER BY run_id` and returns the list. +- Constructing `PostgresCheckpointer(dsn=...)` when `_psycopg is None` raises `ImportError` naming the extra. + +### Protocol-conformance test + +A single test module `tests/unit/pipeline/test_checkpoint_backends.py` parametrizes the existing software-factory scenario (architect → python_dev → deployer fails → resume → evaluator) across all three backends: + +- `FileCheckpointer` — real, uses `tmp_path`. +- `RedisCheckpointer` — wrapping a `MagicMock` Redis client that records `set`/`get`/`keys`/`zadd`/`zrange` calls in-memory. +- `PostgresCheckpointer` — wrapping a `MagicMock` connection whose cursor returns canned results from an in-memory dict mimicking the schema. + +The mocks keep enough state to make the full resume flow work (save then load returns what was saved). Protocol drift between backends is caught as a test failure. + +## Documentation + +`docs/pipeline.md` — "Checkpoint + Resume" subsection gains a backend-comparison table: + +| Backend | Use when | Trade-off | +|---|---|---| +| `FileCheckpointer` | Dev, single-host, ephemeral | No cross-process / cross-host sharing | +| `RedisCheckpointer` | Multi-worker, sub-day-scale runs | TTL eviction; not durable forever | +| `PostgresCheckpointer` | Long-lived runs, compliance, audit-friendly | Operational overhead of a DB | + +`examples/pipeline_state.py` — append a fourth scenario gated behind `if os.environ.get("PG_DSN"):` showing how to swap `FileCheckpointer` for `PostgresCheckpointer`. No-op when the env var is unset, so the example still runs anywhere. + +`fireflyframework_agentic/pipeline/__init__.py` — re-export `RedisCheckpointer` and `PostgresCheckpointer` via a plain `from … import …`. Both classes always exist as importable names; the optional-dep gate is in their `__init__` methods. Rationale: simpler than `__getattr__` indirection on the package, and the canonical pattern in this codebase for OTel and other soft dependencies. + +## Scope + +- `pipeline/checkpoint.py`: ~250 LOC added +- `pyproject.toml`: two new `[project.optional-dependencies]` entries; **no new required deps, no new dev deps** +- `tests/unit/pipeline/test_checkpoint_backends.py`: ~180 LOC +- `docs/pipeline.md`: +20 LOC (table + a paragraph) +- `examples/pipeline_state.py`: +30 LOC (optional fourth scenario) +- `pipeline/__init__.py`: +2 exports + +Total ~450 LOC + tests, independently shippable. + +## Verification + +After implementation: + +1. `pytest tests/unit/pipeline/ -v` — pipeline suite green, all three backends pass the parametrized conformance test. +2. `pytest tests/unit/` — full unit suite green. +3. `ruff check` + `ruff format --check` clean. +4. `pyright` clean on touched modules. +5. `python -c "from fireflyframework_agentic.pipeline import RedisCheckpointer, PostgresCheckpointer"` works in a fresh venv with neither extra installed (the imports succeed; constructing the classes is what raises). +6. `python -c "from fireflyframework_agentic.pipeline import RedisCheckpointer; RedisCheckpointer(url='redis://x')"` in the no-extras venv raises a clear `ImportError` naming the extra. + +## What lands next + +- **Phase 3b** — `StatePipelineEventHandler` Protocol + OTel spans per state-pipeline node. +- **Phase 3c** — `Pause(reason)` sentinel for HITL approval gates + `AuditLog` Protocol with Postgres impl reusing the 3a Postgres connection. From 2952e43f5e6ec352231be85d9af0e6a6a556406d Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Wed, 27 May 2026 17:04:45 +0200 Subject: [PATCH 05/26] feat(pipeline): Redis + Postgres checkpointer backends (#147 phase 3a) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds two durable Checkpointer implementations alongside FileCheckpointer. Single file (pipeline/checkpoint.py), guarded optional imports, no API changes to StatePipeline. Existing FileCheckpointer users are unaffected. - RedisCheckpointer (sync redis-py): SET+EX per checkpoint, ZADD/ZRANGE index of run_ids. TTL configurable, default 30 days. Accepts url= or a pre-built client= (for shared pools). - PostgresCheckpointer (sync psycopg3): single firefly_checkpoints table created idempotently on first save; INSERT … ON CONFLICT DO UPDATE for saves; SELECT … ORDER BY sequence DESC LIMIT 1 for load_latest. Accepts dsn= or a pre-built connection=. table_name is validated to prevent SQL injection from a misconfigured caller. - pyproject.toml: psycopg[binary]>=3 added to the existing `postgres` extra alongside asyncpg; `redis` extra already present. - Tests use unittest.mock only — no fakeredis, no testcontainers. A parametrized software-factory scenario runs across all three backends; per-backend tests verify the right calls are issued and that missing deps surface a clear ImportError naming the extra to install. - examples/pipeline_state.py: optional fourth scenario gated on PG_DSN env var demonstrating PostgresCheckpointer. - docs/pipeline.md: backend-comparison table + code snippet for swapping backends. Tests: 17 new in test_checkpoint_backends.py covering per-backend behaviour + cross-backend conformance. Full pipeline suite 114 passed (88 phase-1+2 baseline + 9 phase-2 features + 17 new phase-3a). Lints clean, pyright clean on touched modules. --- docs/pipeline.md | 29 +- examples/pipeline_state.py | 37 ++ fireflyframework_agentic/pipeline/__init__.py | 6 +- .../pipeline/checkpoint.py | 208 ++++++++- pyproject.toml | 3 + .../unit/pipeline/test_checkpoint_backends.py | 403 ++++++++++++++++++ 6 files changed, 678 insertions(+), 8 deletions(-) create mode 100644 tests/unit/pipeline/test_checkpoint_backends.py diff --git a/docs/pipeline.md b/docs/pipeline.md index 2ece529e..a6918b30 100644 --- a/docs/pipeline.md +++ b/docs/pipeline.md @@ -180,12 +180,18 @@ Custom reducers are any callable `(current, update) -> merged`. ### Checkpoint + Resume -Pass a `Checkpointer` to persist state after each successful node. The -filesystem implementation ships in the package; Redis/Postgres backends are -straightforward to plug in via the `Checkpointer` Protocol. +Pass a `Checkpointer` to persist state after each successful node. Three +backends ship out of the box, all conforming to the same `Checkpointer` +Protocol so they're swappable without code changes. + +| Backend | Use when | Trade-off | Install | +|---|---|---|---| +| `FileCheckpointer` | Dev, single-host, ephemeral | No cross-process / cross-host sharing | (default — no extra) | +| `RedisCheckpointer` | Multi-worker, sub-day-scale runs | TTL eviction; not durable forever | `pip install fireflyframework-agentic[redis]` | +| `PostgresCheckpointer` | Long-lived runs, compliance, audit-friendly | Operational overhead of a DB | `pip install fireflyframework-agentic[postgres]` | ```python -from fireflyframework_agentic.pipeline import FileCheckpointer +from fireflyframework_agentic.pipeline import FileCheckpointer # or Redis / Postgres pipeline = ( PipelineBuilder("software-factory", state=BuildState, @@ -208,6 +214,21 @@ result = await pipeline.invoke(run_id=result.run_id) result = await pipeline.invoke(state=loaded_state, start_at=deployer) ``` +Swapping backends is a one-line change. Redis uses a TTL on each checkpoint +key (default 30 days) plus a sorted-set index of run IDs; Postgres uses a +single `firefly_checkpoints` table created idempotently on first save: + +```python +from fireflyframework_agentic.pipeline import RedisCheckpointer, PostgresCheckpointer + +# Either a URL/DSN (backend constructs its own client) or a pre-built client +# (lets you share a connection pool across many pipelines). +checkpointer = RedisCheckpointer(url="redis://localhost:6379/0", ttl_seconds=86400 * 30) +checkpointer = RedisCheckpointer(client=my_existing_redis) +checkpointer = PostgresCheckpointer(dsn="postgresql://user:pw@host/db") +checkpointer = PostgresCheckpointer(connection=my_existing_psycopg_connection) +``` + ### Cycles and `recursion_limit` State pipelines permit cycles for ReAct loops and retry-with-critique patterns. diff --git a/examples/pipeline_state.py b/examples/pipeline_state.py index 9f0fc463..6a3536f8 100644 --- a/examples/pipeline_state.py +++ b/examples/pipeline_state.py @@ -43,6 +43,7 @@ import asyncio import logging +import os import tempfile from pathlib import Path from typing import Annotated @@ -52,6 +53,7 @@ from fireflyframework_agentic.pipeline import ( FileCheckpointer, PipelineBuilder, + PostgresCheckpointer, Send, extend, ) @@ -232,10 +234,45 @@ async def run_map_reduce() -> None: # ============================================================================= +async def run_software_factory_postgres() -> None: + """Optional: the same software-factory scenario backed by Postgres. + + Runs only when the ``PG_DSN`` env var is set (e.g. + ``PG_DSN=postgresql://user:pw@localhost/firefly``). Requires the + ``postgres`` extra: ``pip install fireflyframework-agentic[postgres]``. + """ + dsn = os.environ.get("PG_DSN") + if not dsn: + return + + print("=== 4. Software factory with PostgresCheckpointer ===\n") + + # Reset the deployer flag so this scenario starts clean. + _deployer_failed_once["flag"] = False + + checkpointer = PostgresCheckpointer(dsn=dsn) + pipeline = ( + PipelineBuilder("software-factory-pg", state=BuildState, checkpointer=checkpointer) + .add_node(architect) + .add_node(python_dev) + .add_node(deployer) + .add_node(evaluator) + .chain(architect, python_dev, deployer, evaluator) + .build() + ) + first = await pipeline.invoke(BuildState(requirements="postgres-backed deploy")) + print(f" first run: success={first.success}, failed_node={first.failed_node}") + print(f" run_id: {first.run_id}\n") + second = await pipeline.invoke(run_id=first.run_id) + print(f" resumed: success={second.success}") + print(f" eval: {second.state.evaluation}\n") + + async def main() -> None: await run_branching() await run_software_factory() await run_map_reduce() + await run_software_factory_postgres() if __name__ == "__main__": diff --git a/fireflyframework_agentic/pipeline/__init__.py b/fireflyframework_agentic/pipeline/__init__.py index ed844dc0..3925952c 100644 --- a/fireflyframework_agentic/pipeline/__init__.py +++ b/fireflyframework_agentic/pipeline/__init__.py @@ -33,6 +33,8 @@ Checkpointer, CheckpointRecord, FileCheckpointer, + PostgresCheckpointer, + RedisCheckpointer, ) from fireflyframework_agentic.pipeline.context import PipelineContext from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode, FailureStrategy @@ -75,6 +77,7 @@ "FanOutStep", "FileCheckpointer", "NodeResult", + "PostgresCheckpointer", "PipelineBuilder", "PipelineContext", "PipelineEngine", @@ -82,8 +85,9 @@ "PipelineResult", "ReasoningStep", "RecursionLimitError", - "Send", + "RedisCheckpointer", "RetrievalStep", + "Send", "StatePipeline", "StatePipelineResult", "StepExecutor", diff --git a/fireflyframework_agentic/pipeline/checkpoint.py b/fireflyframework_agentic/pipeline/checkpoint.py index d4ebeaae..3a63d8c1 100644 --- a/fireflyframework_agentic/pipeline/checkpoint.py +++ b/fireflyframework_agentic/pipeline/checkpoint.py @@ -18,19 +18,37 @@ ``(pipeline_name, run_id, node_id)``. On resume the engine loads the latest checkpoint and skips nodes that already completed in that run. -:class:`FileCheckpointer` is a filesystem-backed JSON implementation suitable -for single-process workflows. Redis / Postgres checkpointers can implement -the same Protocol and be swapped in without API changes. +Three backends ship: + +* :class:`FileCheckpointer` — filesystem JSON. Best for dev / single-host. +* :class:`RedisCheckpointer` — Redis with TTL. Best for multi-worker, + sub-day-scale runs. Requires ``pip install fireflyframework-agentic[redis]``. +* :class:`PostgresCheckpointer` — Postgres with a single ``firefly_checkpoints`` + table. Best for long-lived runs / compliance. Requires + ``pip install fireflyframework-agentic[postgres]``. + +Any backend conforms to the :class:`Checkpointer` Protocol and is interchangeable. """ from __future__ import annotations import json +import time from pathlib import Path from typing import Any, Protocol, runtime_checkable from pydantic import BaseModel +try: + import redis as _redis # type: ignore[import-not-found] +except ImportError: # pragma: no cover - optional dep + _redis = None # type: ignore[assignment] + +try: + import psycopg as _psycopg # type: ignore[import-not-found] +except ImportError: # pragma: no cover - optional dep + _psycopg = None # type: ignore[assignment] + class CheckpointRecord(BaseModel): """One saved checkpoint.""" @@ -97,3 +115,187 @@ def list_runs(self, pipeline_name: str) -> list[str]: if not pipeline_dir.exists(): return [] return sorted(d.name for d in pipeline_dir.iterdir() if d.is_dir()) + + +class RedisCheckpointer: + """Redis-backed checkpointer. + + Key layout:: + + firefly:ckpt:::_ -> JSON record (with TTL) + firefly:ckpt::runs -> ZSET of run_ids (score = last update ts) + + Parameters: + url: Redis connection string (e.g. ``redis://host:6379/0``). Mutually + exclusive with ``client``. + client: Pre-built ``redis.Redis`` instance. Use this to share a + connection pool across many pipelines. + ttl_seconds: TTL applied to each checkpoint key. Default 30 days. + The runs-index ZSET does not expire (it's tiny). + key_prefix: Override the default ``firefly:ckpt`` key prefix. + + Raises: + ImportError: When the ``redis`` extra is not installed. + """ + + def __init__( + self, + *, + url: str | None = None, + client: Any = None, + ttl_seconds: int = 60 * 60 * 24 * 30, + key_prefix: str = "firefly:ckpt", + ) -> None: + if _redis is None: + raise ImportError( + "RedisCheckpointer requires the 'redis' extra. " + "Install with: pip install fireflyframework-agentic[redis]" + ) + if (url is None) == (client is None): + raise ValueError("RedisCheckpointer needs exactly one of `url` or `client`.") + self._client = client if client is not None else _redis.Redis.from_url(url, decode_responses=True) + self._ttl = ttl_seconds + self._prefix = key_prefix + + def _ckpt_key(self, pipeline: str, run_id: str, sequence: int, node_id: str) -> str: + return f"{self._prefix}:{pipeline}:{run_id}:{sequence:06d}_{node_id}" + + def _runs_index_key(self, pipeline: str) -> str: + return f"{self._prefix}:{pipeline}:runs" + + def _run_pattern(self, pipeline: str, run_id: str) -> str: + return f"{self._prefix}:{pipeline}:{run_id}:*" + + def save(self, record: CheckpointRecord) -> None: + key = self._ckpt_key(record.pipeline_name, record.run_id, record.sequence, record.node_id) + self._client.set(key, record.model_dump_json(), ex=self._ttl) + self._client.zadd(self._runs_index_key(record.pipeline_name), {record.run_id: time.time()}) + + def load_latest(self, pipeline_name: str, run_id: str) -> CheckpointRecord | None: + keys = self._client.keys(self._run_pattern(pipeline_name, run_id)) + if not keys: + return None + # Keys are zero-padded by sequence; lex-sorted last = numerically-latest. + latest_key = sorted(keys)[-1] + payload = self._client.get(latest_key) + if payload is None: + return None + return CheckpointRecord.model_validate(json.loads(payload)) + + def list_runs(self, pipeline_name: str) -> list[str]: + return list(self._client.zrange(self._runs_index_key(pipeline_name), 0, -1)) + + +class PostgresCheckpointer: + """Postgres-backed checkpointer. + + Uses a single table created on first ``save`` call. The DDL is idempotent + so multiple processes pointing at the same database are safe. + + Parameters: + dsn: Postgres connection string. Mutually exclusive with ``connection``. + connection: Pre-built ``psycopg.Connection``. Use this to share a + connection across many pipelines. + table_name: Override the default ``firefly_checkpoints`` table name. + + Raises: + ImportError: When the ``postgres`` extra is not installed. + """ + + _DDL = """ + CREATE TABLE IF NOT EXISTS {table} ( + pipeline_name TEXT NOT NULL, + run_id TEXT NOT NULL, + sequence INT NOT NULL, + node_id TEXT NOT NULL, + state JSONB NOT NULL, + completed_nodes JSONB NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (pipeline_name, run_id, sequence) + ); + CREATE INDEX IF NOT EXISTS {table}_run_idx + ON {table} (pipeline_name, run_id); + """ + + def __init__( + self, + *, + dsn: str | None = None, + connection: Any = None, + table_name: str = "firefly_checkpoints", + ) -> None: + if _psycopg is None: + raise ImportError( + "PostgresCheckpointer requires the 'postgres' extra. " + "Install with: pip install fireflyframework-agentic[postgres]" + ) + if (dsn is None) == (connection is None): + raise ValueError("PostgresCheckpointer needs exactly one of `dsn` or `connection`.") + # Table name is interpolated into DDL — validate it strictly to avoid SQL injection. + if not table_name.replace("_", "").isalnum(): + raise ValueError(f"Invalid table_name {table_name!r}: must be alphanumeric/underscore only.") + self._conn = connection if connection is not None else _psycopg.connect(dsn, autocommit=True) + self._table = table_name + self._ddl_applied = False + + def _ensure_table(self) -> None: + if self._ddl_applied: + return + with self._conn.cursor() as cur: + cur.execute(self._DDL.format(table=self._table)) + self._ddl_applied = True + + def save(self, record: CheckpointRecord) -> None: + self._ensure_table() + sql = ( + f"INSERT INTO {self._table} " + "(pipeline_name, run_id, sequence, node_id, state, completed_nodes) " + "VALUES (%s, %s, %s, %s, %s, %s) " + "ON CONFLICT (pipeline_name, run_id, sequence) DO UPDATE SET " + "node_id = EXCLUDED.node_id, " + "state = EXCLUDED.state, " + "completed_nodes = EXCLUDED.completed_nodes" + ) + with self._conn.cursor() as cur: + cur.execute( + sql, + ( + record.pipeline_name, + record.run_id, + record.sequence, + record.node_id, + json.dumps(record.state), + json.dumps(record.completed_nodes), + ), + ) + + def load_latest(self, pipeline_name: str, run_id: str) -> CheckpointRecord | None: + self._ensure_table() + sql = ( + f"SELECT pipeline_name, run_id, sequence, node_id, state, completed_nodes " + f"FROM {self._table} " + f"WHERE pipeline_name = %s AND run_id = %s " + f"ORDER BY sequence DESC LIMIT 1" + ) + with self._conn.cursor() as cur: + cur.execute(sql, (pipeline_name, run_id)) + row = cur.fetchone() + if row is None: + return None + pipeline, rid, seq, node_id, state, completed = row + # psycopg returns JSONB as parsed Python objects; tolerate raw strings too. + return CheckpointRecord( + pipeline_name=pipeline, + run_id=rid, + sequence=seq, + node_id=node_id, + state=json.loads(state) if isinstance(state, str) else state, + completed_nodes=json.loads(completed) if isinstance(completed, str) else completed, + ) + + def list_runs(self, pipeline_name: str) -> list[str]: + self._ensure_table() + sql = f"SELECT DISTINCT run_id FROM {self._table} WHERE pipeline_name = %s ORDER BY run_id" + with self._conn.cursor() as cur: + cur.execute(sql, (pipeline_name,)) + return [r[0] for r in cur.fetchall()] diff --git a/pyproject.toml b/pyproject.toml index 416a109d..ed77b3c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,9 @@ queues = [ postgres = [ "asyncpg>=0.30.0", "sqlalchemy>=2.0.0", + # psycopg drives the sync-Protocol PostgresCheckpointer; asyncpg above + # drives the async memory store. Both ship under the same extra. + "psycopg[binary]>=3.2.0,<4", ] mongodb = [ "motor>=3.6.0", diff --git a/tests/unit/pipeline/test_checkpoint_backends.py b/tests/unit/pipeline/test_checkpoint_backends.py new file mode 100644 index 00000000..091f6cf3 --- /dev/null +++ b/tests/unit/pipeline/test_checkpoint_backends.py @@ -0,0 +1,403 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""Tests for the three Checkpointer backends (File, Redis, Postgres). + +Mocks only — no real Redis or Postgres needed. Real-service verification is +out-of-band against actual servers. +""" + +from __future__ import annotations + +import json +from typing import Any +from unittest.mock import MagicMock, create_autospec + +import pytest +from pydantic import BaseModel + +import fireflyframework_agentic.pipeline.checkpoint as checkpoint_module +from fireflyframework_agentic.pipeline import ( + CheckpointRecord, + FileCheckpointer, + PipelineBuilder, + PostgresCheckpointer, + RedisCheckpointer, + StatePipeline, +) + + +@pytest.fixture(autouse=True) +def _stub_optional_deps(monkeypatch: pytest.MonkeyPatch) -> None: + """Make ``_redis`` / ``_psycopg`` truthy so we can construct backends with mock clients. + + Tests that explicitly want the missing-dep code path (the two + ``..._missing_dep_raises`` tests) override this by setting the symbol back + to None via their own monkeypatch — the per-test patch wins. + """ + if checkpoint_module._redis is None: + monkeypatch.setattr(checkpoint_module, "_redis", MagicMock(name="redis_stub")) + if checkpoint_module._psycopg is None: + monkeypatch.setattr(checkpoint_module, "_psycopg", MagicMock(name="psycopg_stub")) + + +# ============================================================================= +# RedisCheckpointer +# ============================================================================= + + +def _redis_client_mock() -> MagicMock: + """Build a MagicMock that tracks the calls a RedisCheckpointer issues, with + just enough state (an in-memory dict) to make round-trip save→load work. + """ + store: dict[str, str] = {} + runs_index: dict[str, dict[str, float]] = {} + + client = MagicMock(name="redis.Redis") + + def fake_set(key: str, value: str, ex: int | None = None) -> bool: + store[key] = value + return True + + def fake_get(key: str) -> str | None: + return store.get(key) + + def fake_keys(pattern: str) -> list[str]: + # Trivial glob: only handles "*" patterns (which is what we use). + if pattern.endswith("*"): + prefix = pattern[:-1] + return [k for k in store if k.startswith(prefix)] + return [k for k in store if k == pattern] + + def fake_zadd(key: str, mapping: dict[str, float]) -> int: + runs_index.setdefault(key, {}).update(mapping) + return len(mapping) + + def fake_zrange(key: str, start: int, end: int) -> list[str]: + members = runs_index.get(key, {}) + ordered = sorted(members, key=lambda m: members[m]) + return ordered[start : (end + 1 if end != -1 else None)] + + client.set.side_effect = fake_set + client.get.side_effect = fake_get + client.keys.side_effect = fake_keys + client.zadd.side_effect = fake_zadd + client.zrange.side_effect = fake_zrange + return client + + +def test_redis_checkpointer_missing_dep_raises(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(checkpoint_module, "_redis", None) + with pytest.raises(ImportError, match=r"\[redis\]"): + RedisCheckpointer(url="redis://x") + + +def test_redis_checkpointer_rejects_both_url_and_client() -> None: + with pytest.raises(ValueError, match="exactly one"): + RedisCheckpointer(url="redis://x", client=MagicMock()) + + +def test_redis_checkpointer_rejects_neither_url_nor_client() -> None: + with pytest.raises(ValueError, match="exactly one"): + RedisCheckpointer() + + +def test_redis_save_issues_set_and_zadd() -> None: + client = _redis_client_mock() + ckpt = RedisCheckpointer(client=client, ttl_seconds=42) + record = CheckpointRecord( + pipeline_name="p", + run_id="r", + sequence=1, + node_id="n", + state={"x": 1}, + completed_nodes=["n"], + ) + ckpt.save(record) + + set_call = client.set.call_args + assert set_call.args[0] == "firefly:ckpt:p:r:000001_n" + assert json.loads(set_call.args[1])["node_id"] == "n" + assert set_call.kwargs["ex"] == 42 + + zadd_call = client.zadd.call_args + assert zadd_call.args[0] == "firefly:ckpt:p:runs" + assert "r" in zadd_call.args[1] + + +def test_redis_load_latest_picks_highest_sequence_key() -> None: + client = _redis_client_mock() + ckpt = RedisCheckpointer(client=client) + for seq in (1, 5, 3): + ckpt.save( + CheckpointRecord( + pipeline_name="p", + run_id="r", + sequence=seq, + node_id=f"node{seq}", + state={"seq": seq}, + completed_nodes=[], + ) + ) + latest = ckpt.load_latest("p", "r") + assert latest is not None + assert latest.sequence == 5 + assert latest.node_id == "node5" + + +def test_redis_load_latest_returns_none_when_run_unknown() -> None: + client = _redis_client_mock() + ckpt = RedisCheckpointer(client=client) + assert ckpt.load_latest("p", "missing") is None + + +def test_redis_list_runs_returns_zrange_result() -> None: + client = _redis_client_mock() + ckpt = RedisCheckpointer(client=client) + for run_id in ("r1", "r2", "r3"): + ckpt.save( + CheckpointRecord( + pipeline_name="p", + run_id=run_id, + sequence=1, + node_id="n", + state={}, + completed_nodes=[], + ) + ) + runs = ckpt.list_runs("p") + assert set(runs) == {"r1", "r2", "r3"} + + +# ============================================================================= +# PostgresCheckpointer +# ============================================================================= + + +def _postgres_connection_mock() -> tuple[MagicMock, dict[tuple, dict[str, Any]]]: + """MagicMock Connection backed by an in-memory dict shaped like the + firefly_checkpoints table. Returns ``(conn, store)`` so tests can poke the store. + """ + store: dict[tuple[str, str, int], dict[str, Any]] = {} + ddl_calls: list[str] = [] + + conn = MagicMock(name="psycopg.Connection") + + def make_cursor() -> MagicMock: + cur = MagicMock(name="psycopg.Cursor") + cur.__enter__ = MagicMock(return_value=cur) + cur.__exit__ = MagicMock(return_value=None) + cur._last_fetchone = None + cur._last_fetchall = [] + + def fake_execute(sql: str, params: tuple | None = None) -> None: + sql_lower = sql.strip().lower() + if sql_lower.startswith("create table"): + ddl_calls.append(sql) + return + if sql_lower.startswith("insert into"): + assert params is not None + key = (params[0], params[1], params[2]) + store[key] = { + "pipeline_name": params[0], + "run_id": params[1], + "sequence": params[2], + "node_id": params[3], + "state": json.loads(params[4]) if isinstance(params[4], str) else params[4], + "completed_nodes": (json.loads(params[5]) if isinstance(params[5], str) else params[5]), + } + return + if sql_lower.startswith("select pipeline_name"): + assert params is not None + matches = [v for k, v in store.items() if k[0] == params[0] and k[1] == params[1]] + matches.sort(key=lambda r: r["sequence"], reverse=True) + cur._last_fetchone = ( + ( + matches[0]["pipeline_name"], + matches[0]["run_id"], + matches[0]["sequence"], + matches[0]["node_id"], + matches[0]["state"], + matches[0]["completed_nodes"], + ) + if matches + else None + ) + return + if sql_lower.startswith("select distinct run_id"): + assert params is not None + runs = sorted({k[1] for k in store if k[0] == params[0]}) + cur._last_fetchall = [(r,) for r in runs] + return + raise AssertionError(f"unexpected SQL: {sql}") + + cur.execute.side_effect = fake_execute + cur.fetchone.side_effect = lambda: cur._last_fetchone + cur.fetchall.side_effect = lambda: cur._last_fetchall + return cur + + conn.cursor.side_effect = make_cursor + conn._ddl_calls = ddl_calls + return conn, store + + +def test_postgres_checkpointer_missing_dep_raises(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(checkpoint_module, "_psycopg", None) + with pytest.raises(ImportError, match=r"\[postgres\]"): + PostgresCheckpointer(dsn="postgresql://x") + + +def test_postgres_checkpointer_rejects_both_dsn_and_connection() -> None: + with pytest.raises(ValueError, match="exactly one"): + PostgresCheckpointer(dsn="postgresql://x", connection=MagicMock()) + + +def test_postgres_checkpointer_rejects_bad_table_name() -> None: + with pytest.raises(ValueError, match="Invalid table_name"): + PostgresCheckpointer(connection=MagicMock(), table_name="bad; DROP TABLE users") + + +def test_postgres_ddl_runs_once_across_many_saves() -> None: + conn, _store = _postgres_connection_mock() + ckpt = PostgresCheckpointer(connection=conn) + for seq in range(3): + ckpt.save( + CheckpointRecord( + pipeline_name="p", + run_id="r", + sequence=seq, + node_id=f"n{seq}", + state={}, + completed_nodes=[], + ) + ) + assert len(conn._ddl_calls) == 1, "DDL should run exactly once per instance" + + +def test_postgres_save_then_load_latest_round_trips() -> None: + conn, _store = _postgres_connection_mock() + ckpt = PostgresCheckpointer(connection=conn) + for seq, node_id in [(1, "a"), (5, "e"), (3, "c")]: + ckpt.save( + CheckpointRecord( + pipeline_name="p", + run_id="r", + sequence=seq, + node_id=node_id, + state={"seq": seq}, + completed_nodes=[node_id], + ) + ) + latest = ckpt.load_latest("p", "r") + assert latest is not None + assert latest.sequence == 5 + assert latest.node_id == "e" + + +def test_postgres_load_latest_returns_none_when_empty() -> None: + conn, _store = _postgres_connection_mock() + ckpt = PostgresCheckpointer(connection=conn) + assert ckpt.load_latest("p", "missing") is None + + +def test_postgres_list_runs_returns_distinct_run_ids() -> None: + conn, _store = _postgres_connection_mock() + ckpt = PostgresCheckpointer(connection=conn) + for run_id in ("rA", "rB", "rA"): # rA twice, rB once + ckpt.save( + CheckpointRecord( + pipeline_name="p", + run_id=run_id, + sequence=1, + node_id="n", + state={}, + completed_nodes=[], + ) + ) + assert ckpt.list_runs("p") == ["rA", "rB"] + + +# ============================================================================= +# Protocol conformance — software-factory scenario across all three backends +# ============================================================================= + + +class FactoryState(BaseModel): + requirements: str + spec: str | None = None + code: str | None = None + deploy_url: str | None = None + evaluation: str | None = None + + +def _build_factory(checkpointer: Any) -> StatePipeline: + """Construct the canonical 4-step agent pipeline that fails on first deploy.""" + state_flag = {"failed_once": False} + + async def architect(state: FactoryState) -> dict: + return {"spec": f"spec for {state.requirements}"} + + async def python_dev(state: FactoryState) -> dict: + return {"code": f"# code for {state.spec}"} + + async def deployer(state: FactoryState) -> dict: + if not state_flag["failed_once"]: + state_flag["failed_once"] = True + raise RuntimeError("blip") + return {"deploy_url": "https://app"} + + async def evaluator(state: FactoryState) -> dict: + return {"evaluation": f"PASS {state.deploy_url}"} + + pipeline = ( + PipelineBuilder("factory", state=FactoryState, checkpointer=checkpointer) + .add_node(architect) + .add_node(python_dev) + .add_node(deployer) + .add_node(evaluator) + .chain(architect, python_dev, deployer, evaluator) + .build() + ) + assert isinstance(pipeline, StatePipeline) + return pipeline + + +@pytest.fixture +def file_backend(tmp_path): + return FileCheckpointer(tmp_path / "ckpt") + + +@pytest.fixture +def redis_backend(): + return RedisCheckpointer(client=_redis_client_mock()) + + +@pytest.fixture +def postgres_backend(): + conn, _store = _postgres_connection_mock() + return PostgresCheckpointer(connection=conn) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("backend_fixture", ["file_backend", "redis_backend", "postgres_backend"]) +async def test_backend_supports_fail_and_resume(backend_fixture, request) -> None: + """Same scenario across all three backends: deployer fails, resume completes.""" + backend = request.getfixturevalue(backend_fixture) + pipeline = _build_factory(backend) + + first = await pipeline.invoke(FactoryState(requirements="users service")) + assert not first.success + assert first.failed_node == "deployer" + assert first.completed_nodes == ["architect", "python_dev"] + + second = await pipeline.invoke(run_id=first.run_id) + assert second.success + assert second.completed_nodes == ["architect", "python_dev", "deployer", "evaluator"] + assert second.state.evaluation == "PASS https://app" + + +# Silence the unused-import warning for the autospec we don't actually use here +# but keep available for follow-up tests. +_ = create_autospec From 61f6971097f47a28bd053276a0a1f95e7f1a3c9c Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Wed, 27 May 2026 17:19:24 +0200 Subject: [PATCH 06/26] chore: remove docs/superpowers/specs/* accidentally force-added past .gitignore --- .../2026-05-27-pipeline-phase-3a-design.md | 209 ------------------ 1 file changed, 209 deletions(-) delete mode 100644 docs/superpowers/specs/2026-05-27-pipeline-phase-3a-design.md diff --git a/docs/superpowers/specs/2026-05-27-pipeline-phase-3a-design.md b/docs/superpowers/specs/2026-05-27-pipeline-phase-3a-design.md deleted file mode 100644 index 7eb5f684..00000000 --- a/docs/superpowers/specs/2026-05-27-pipeline-phase-3a-design.md +++ /dev/null @@ -1,209 +0,0 @@ -# Phase 3a — Durable Checkpointer Backends (Redis + Postgres) - -Issue [#147](https://github.com/fireflyframework/fireflyframework-agentic/issues/147), phase 3a. Stacked on top of Phase 1+2 (`issue-147-pipeline-evolution`). - -## Problem - -`StatePipeline` checkpointing today has one backend: `FileCheckpointer`, which writes JSON files to a local filesystem. That blocks the use cases customers actually run agentic pipelines for: - -- **Multi-worker fail-over.** A pipeline that crashes on worker A cannot be resumed on worker B because the checkpoint files are on A's disk. -- **Containerized deploys.** Ephemeral container filesystems lose checkpoint state on restart. -- **Horizontal scaling.** No shared state means runs are pinned to a single host. -- **Long-lived runs.** Filesystem checkpoints accumulate without TTL; cleanup is manual. - -Phase 3a ships two durable backends — Redis and Postgres — pluggable through the existing `Checkpointer` Protocol. No API changes to `StatePipeline`; no breaking changes for existing `FileCheckpointer` users. - -## Goal - -`PipelineBuilder("agent", state=..., checkpointer=...)` accepts a `RedisCheckpointer` or `PostgresCheckpointer` interchangeably with the existing `FileCheckpointer`. The software-factory scenario (architect → python_dev → deployer fails → resume on a different worker → evaluator) works end-to-end against any of the three. - -## Non-goals - -- Backend-pluggable observability / OTel spans (phase 3b). -- HITL pause primitive and audit log (phase 3c). -- Real-Redis / real-Postgres tests in this PR — verified out-of-band by the user. -- Schema migration tooling for the JSONB `state` column — the application owns state-schema evolution. -- Connection-pool tuning knobs beyond "pass me a client." -- Auth / RBAC on resume APIs. -- Streaming partial state. - -## Architecture - -Single file: `fireflyframework_agentic/pipeline/checkpoint.py`. Existing contents (`Checkpointer` Protocol, `CheckpointRecord`, `FileCheckpointer`) are preserved. Two new classes added to the same module: - -``` -pipeline/checkpoint.py -├── Checkpointer (Protocol, existing) -├── CheckpointRecord (existing) -├── FileCheckpointer (existing) -├── RedisCheckpointer (new) -└── PostgresCheckpointer (new) -``` - -### Guarded optional imports - -Top-of-file guarded imports mirror the established pattern in `engine.py` for OTel: - -```python -try: - import redis as _redis -except ImportError: - _redis = None - -try: - import psycopg as _psycopg -except ImportError: - _psycopg = None -``` - -Each backend's `__init__` raises a clear install-instruction error if its dependency is absent. No inline imports — project rule respected. - -### Connection lifecycle - -Both backends accept either a connection string OR a pre-built sync client so callers can share a connection pool across many pipelines: - -```python -RedisCheckpointer(url="redis://localhost:6379/0", ttl_seconds=86400 * 30) -RedisCheckpointer(client=existing_redis_client) - -PostgresCheckpointer(dsn="postgresql://user:pw@host/db") -PostgresCheckpointer(connection=existing_psycopg_connection) -``` - -The `Checkpointer` Protocol's methods are sync (called synchronously inside `StatePipeline._save_checkpoint`). Both backends use the sync clients (`redis-py`, `psycopg[binary]`) directly — no `asyncio.run` indirection. - -## Storage schemas - -### Redis - -One key per checkpoint, plus a sorted-set index per pipeline for `list_runs`: - -| Key | Type | Value | -|---|---|---| -| `firefly:ckpt:{pipeline}:{run_id}:{seq:06d}_{node_id}` | string | JSON-serialized `CheckpointRecord` | -| `firefly:ckpt:{pipeline}:runs` | sorted set | members = `run_id`s, scores = last-update unix timestamps | - -- TTL on the per-checkpoint keys is configurable, default 30 days. -- `save` issues `SET key value EX ttl` then `ZADD index ts run_id` (idempotent — score updated each call). -- `load_latest` issues `KEYS firefly:ckpt:{pipeline}:{run_id}:*`, picks the lexicographically last key (the zero-padded `seq` makes lex order match numeric order), then `GET` it. -- `list_runs` issues `ZRANGE index 0 -1` and returns all run IDs ordered by last-update. - -`KEYS` is acceptable here because the cardinality per run is bounded by node count × visit count (small for agentic workflows). A `SCAN`-based variant is trivial to add if scale demands it; not in scope for 3a. - -### Postgres - -Single table, created idempotently on first connect: - -```sql -CREATE TABLE IF NOT EXISTS firefly_checkpoints ( - pipeline_name TEXT NOT NULL, - run_id TEXT NOT NULL, - sequence INT NOT NULL, - node_id TEXT NOT NULL, - state JSONB NOT NULL, - completed_nodes JSONB NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT now(), - PRIMARY KEY (pipeline_name, run_id, sequence) -); -CREATE INDEX IF NOT EXISTS firefly_checkpoints_run - ON firefly_checkpoints (pipeline_name, run_id); -``` - -- DDL runs once per `PostgresCheckpointer` instance via a `_ddl_applied` flag set on first `save`. -- `save` issues `INSERT … ON CONFLICT (pipeline_name, run_id, sequence) DO UPDATE SET node_id = EXCLUDED.node_id, state = EXCLUDED.state, completed_nodes = EXCLUDED.completed_nodes`. -- `load_latest` issues `SELECT pipeline_name, run_id, sequence, node_id, state, completed_nodes FROM firefly_checkpoints WHERE pipeline_name = %s AND run_id = %s ORDER BY sequence DESC LIMIT 1`. -- `list_runs` issues `SELECT DISTINCT run_id FROM firefly_checkpoints WHERE pipeline_name = %s ORDER BY run_id`. - -## Optional dependencies - -`pyproject.toml` gains two new extras, no new required dependencies and no new dev-dependencies: - -```toml -[project.optional-dependencies] -redis = ["redis>=5,<6"] -postgres = ["psycopg[binary]>=3,<4"] -``` - -Install paths: - -```bash -pip install fireflyframework-agentic[redis] -pip install fireflyframework-agentic[postgres] -pip install fireflyframework-agentic[redis,postgres] -``` - -## Testing strategy - -`unittest.mock` only. No `fakeredis`, no `pytest-postgresql`, no test containers. Real-service verification is out-of-band by the user against real Redis and real Postgres, not in this PR. - -### `RedisCheckpointer` tests - -Mock the `redis.Redis` client (`mock.create_autospec(redis.Redis)`). Assertions: - -- `save` issues `client.set(key=..., value=, ex=)` with the expected key format and TTL. -- `save` issues `client.zadd("firefly:ckpt::runs", {: })`. -- `load_latest` issues `client.keys("firefly:ckpt:::*")`, picks the lex-last key, then `client.get()`, returns the parsed `CheckpointRecord`. -- `load_latest` returns `None` when `keys` returns an empty list. -- `list_runs` issues `client.zrange("firefly:ckpt::runs", 0, -1)` and returns its result. -- Constructing `RedisCheckpointer(url=...)` when `_redis is None` raises `ImportError` with a message that names the extra (`pip install fireflyframework-agentic[redis]`). - -### `PostgresCheckpointer` tests - -Mock the `psycopg.Connection` and its `cursor()` context manager. Assertions: - -- First `save` issues the `CREATE TABLE IF NOT EXISTS` DDL exactly once. The DDL flag prevents subsequent `save` calls from re-issuing it. -- `save` issues `INSERT … ON CONFLICT (pipeline_name, run_id, sequence) DO UPDATE SET …` with bound parameters matching the `CheckpointRecord` fields. -- `load_latest` issues `SELECT … ORDER BY sequence DESC LIMIT 1` and returns the row hydrated into a `CheckpointRecord`. Returns `None` for the no-rows case. -- `list_runs` issues `SELECT DISTINCT run_id … ORDER BY run_id` and returns the list. -- Constructing `PostgresCheckpointer(dsn=...)` when `_psycopg is None` raises `ImportError` naming the extra. - -### Protocol-conformance test - -A single test module `tests/unit/pipeline/test_checkpoint_backends.py` parametrizes the existing software-factory scenario (architect → python_dev → deployer fails → resume → evaluator) across all three backends: - -- `FileCheckpointer` — real, uses `tmp_path`. -- `RedisCheckpointer` — wrapping a `MagicMock` Redis client that records `set`/`get`/`keys`/`zadd`/`zrange` calls in-memory. -- `PostgresCheckpointer` — wrapping a `MagicMock` connection whose cursor returns canned results from an in-memory dict mimicking the schema. - -The mocks keep enough state to make the full resume flow work (save then load returns what was saved). Protocol drift between backends is caught as a test failure. - -## Documentation - -`docs/pipeline.md` — "Checkpoint + Resume" subsection gains a backend-comparison table: - -| Backend | Use when | Trade-off | -|---|---|---| -| `FileCheckpointer` | Dev, single-host, ephemeral | No cross-process / cross-host sharing | -| `RedisCheckpointer` | Multi-worker, sub-day-scale runs | TTL eviction; not durable forever | -| `PostgresCheckpointer` | Long-lived runs, compliance, audit-friendly | Operational overhead of a DB | - -`examples/pipeline_state.py` — append a fourth scenario gated behind `if os.environ.get("PG_DSN"):` showing how to swap `FileCheckpointer` for `PostgresCheckpointer`. No-op when the env var is unset, so the example still runs anywhere. - -`fireflyframework_agentic/pipeline/__init__.py` — re-export `RedisCheckpointer` and `PostgresCheckpointer` via a plain `from … import …`. Both classes always exist as importable names; the optional-dep gate is in their `__init__` methods. Rationale: simpler than `__getattr__` indirection on the package, and the canonical pattern in this codebase for OTel and other soft dependencies. - -## Scope - -- `pipeline/checkpoint.py`: ~250 LOC added -- `pyproject.toml`: two new `[project.optional-dependencies]` entries; **no new required deps, no new dev deps** -- `tests/unit/pipeline/test_checkpoint_backends.py`: ~180 LOC -- `docs/pipeline.md`: +20 LOC (table + a paragraph) -- `examples/pipeline_state.py`: +30 LOC (optional fourth scenario) -- `pipeline/__init__.py`: +2 exports - -Total ~450 LOC + tests, independently shippable. - -## Verification - -After implementation: - -1. `pytest tests/unit/pipeline/ -v` — pipeline suite green, all three backends pass the parametrized conformance test. -2. `pytest tests/unit/` — full unit suite green. -3. `ruff check` + `ruff format --check` clean. -4. `pyright` clean on touched modules. -5. `python -c "from fireflyframework_agentic.pipeline import RedisCheckpointer, PostgresCheckpointer"` works in a fresh venv with neither extra installed (the imports succeed; constructing the classes is what raises). -6. `python -c "from fireflyframework_agentic.pipeline import RedisCheckpointer; RedisCheckpointer(url='redis://x')"` in the no-extras venv raises a clear `ImportError` naming the extra. - -## What lands next - -- **Phase 3b** — `StatePipelineEventHandler` Protocol + OTel spans per state-pipeline node. -- **Phase 3c** — `Pause(reason)` sentinel for HITL approval gates + `AuditLog` Protocol with Postgres impl reusing the 3a Postgres connection. From 6471c60e53b52099edb63f0d590781421635c0de Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Wed, 27 May 2026 17:22:34 +0200 Subject: [PATCH 07/26] feat(pipeline): StatePipelineEventHandler + OTel spans (#147 phase 3b) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds observability to state pipelines, mirroring the legacy PipelineEngine story but with state-mode semantics: run_id is plumbed through every callback, on_node_start carries a per-node visit counter, and there is no on_node_skip (state pipelines abort on failure rather than skipping). - New StatePipelineEventHandler Protocol in pipeline/engine.py with on_pipeline_start / on_node_start / on_node_complete / on_node_error / on_pipeline_complete. Partial handlers are valid (hasattr-checked). - Per-node OTel spans nested under one pipeline-level span. The existing _start_otel_span helper on PipelineEngine is lifted to a module-level start_otel_span function shared by both pipeline types. - PipelineBuilder gains an event_handler kwarg that flows into StatePipeline. - Fan-out via Send emits per-Send node-start/complete pairs with each Send's own visit number (snapshot at increment time, not post-loop). - Handler exceptions are swallowed — observability never breaks business logic. Tests: 8 new in test_state_pipeline_observability.py covering event ordering, failure path, cyclic visit counts, fan-out per-Send events, resume-from-checkpoint, partial handler, swallowed handler exceptions, and OTel span emission with attribute snapshots. Example: examples/pipeline_state.py gains a ProgressHandler that prints live progress for the software-factory scenario. Verification: - pytest tests/unit/pipeline/ → 122 passed (114 baseline + 8 new) - ruff check + format clean - pyright clean on touched modules --- docs/pipeline.md | 53 +++ examples/pipeline_state.py | 43 ++- fireflyframework_agentic/pipeline/__init__.py | 7 +- fireflyframework_agentic/pipeline/builder.py | 5 +- fireflyframework_agentic/pipeline/engine.py | 69 +++- .../pipeline/state_pipeline.py | 177 ++++++++-- .../test_state_pipeline_observability.py | 324 ++++++++++++++++++ 7 files changed, 623 insertions(+), 55 deletions(-) create mode 100644 tests/unit/pipeline/test_state_pipeline_observability.py diff --git a/docs/pipeline.md b/docs/pipeline.md index a6918b30..60bd2097 100644 --- a/docs/pipeline.md +++ b/docs/pipeline.md @@ -267,6 +267,59 @@ When all worker targets share a common successor, the engine continues there once the fan-out completes; the aggregator runs once with all results in shared state. +### Observability + +State pipelines emit lifecycle callbacks and OTel spans so ops can see what +an agent workflow is doing in real time. + +`StatePipelineEventHandler` mirrors the legacy `PipelineEventHandler` but +every callback carries the `run_id` (so events can be correlated across +resumes) and `on_node_start` carries a per-node visit counter (so cyclic +graphs and `Send` fan-outs are distinguishable). Implement any subset of +methods; missing ones are no-ops. + +```python +from fireflyframework_agentic.pipeline import PipelineBuilder, StatePipelineEventHandler + + +class ProgressHandler: + async def on_pipeline_start(self, name, run_id): + print(f"▶ [{name}] run {run_id} starting") + + async def on_node_start(self, name, run_id, node_id, visit): + print(f" ▶ {node_id} (visit #{visit})") + + async def on_node_complete(self, name, run_id, node_id, latency_ms): + print(f" ✔ {node_id} ({latency_ms:.0f}ms)") + + async def on_node_error(self, name, run_id, node_id, error): + print(f" ✗ {node_id}: {error}") + + async def on_pipeline_complete(self, name, run_id, success, duration_ms): + status = "OK" if success else "FAILED" + print(f"═ [{name}] {status} in {duration_ms:.0f}ms") + + +pipeline = ( + PipelineBuilder("agent", state=AgentState, event_handler=ProgressHandler()) + .add_node(classify).add_node(answer).add_node(escalate) + .branch(classify, route) + .build() +) +``` + +In parallel, the pipeline emits OTel spans automatically when +`observability_enabled` is True and `opentelemetry` is installed: + +- One pipeline-level span `pipeline.state.` around each `invoke`, + attributes `firefly.pipeline`, `firefly.run_id`. +- One per-node span `pipeline.state.node.` for each `fn(state)` + call, parented under the pipeline span, attributes `firefly.node`, + `firefly.visit`. +- For `Send` fan-out: one per-Send span as a sibling under the pipeline span. + +Handler exceptions are swallowed — observability never breaks business logic. + ### Mermaid Export `StatePipeline.to_mermaid()` and `DAG.to_mermaid()` render the topology as a diff --git a/examples/pipeline_state.py b/examples/pipeline_state.py index 6a3536f8..717c7f64 100644 --- a/examples/pipeline_state.py +++ b/examples/pipeline_state.py @@ -152,13 +152,52 @@ async def evaluator(state: BuildState) -> dict: return {"evaluation": f"PASS — deployed at {state.deploy_url}"} +class ProgressHandler: + """Prints live progress for the software-factory scenario. + + Implements only a subset of :class:`StatePipelineEventHandler` — missing + methods are no-ops, which is fine because the pipeline tolerates partial + handlers. + """ + + async def on_pipeline_start(self, pipeline_name: str, run_id: str) -> None: + print(f" ▶ [{pipeline_name}] run {run_id[:8]}… starting") + + async def on_node_start( + self, pipeline_name: str, run_id: str, node_id: str, visit: int + ) -> None: + print(f" ▶ {node_id} (visit #{visit})") + + async def on_node_complete( + self, pipeline_name: str, run_id: str, node_id: str, latency_ms: float + ) -> None: + print(f" ✔ {node_id} ({latency_ms:.0f}ms)") + + async def on_node_error( + self, pipeline_name: str, run_id: str, node_id: str, error: str + ) -> None: + print(f" ✗ {node_id}: {error}") + + async def on_pipeline_complete( + self, pipeline_name: str, run_id: str, success: bool, duration_ms: float + ) -> None: + status = "OK" if success else "FAILED" + print(f" ═ [{pipeline_name}] {status} in {duration_ms:.0f}ms") + + async def run_software_factory() -> None: - print("=== 2. Software factory with checkpoint/resume ===\n") + print("=== 2. Software factory with checkpoint/resume + live progress ===\n") + handler = ProgressHandler() with tempfile.TemporaryDirectory() as tmp: ckpt = FileCheckpointer(Path(tmp)) pipeline = ( - PipelineBuilder("software-factory", state=BuildState, checkpointer=ckpt) + PipelineBuilder( + "software-factory", + state=BuildState, + checkpointer=ckpt, + event_handler=handler, + ) .add_node(architect) .add_node(python_dev) .add_node(deployer) diff --git a/fireflyframework_agentic/pipeline/__init__.py b/fireflyframework_agentic/pipeline/__init__.py index 3925952c..af4778df 100644 --- a/fireflyframework_agentic/pipeline/__init__.py +++ b/fireflyframework_agentic/pipeline/__init__.py @@ -38,7 +38,11 @@ ) from fireflyframework_agentic.pipeline.context import PipelineContext from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode, FailureStrategy -from fireflyframework_agentic.pipeline.engine import PipelineEngine, PipelineEventHandler +from fireflyframework_agentic.pipeline.engine import ( + PipelineEngine, + PipelineEventHandler, + StatePipelineEventHandler, +) from fireflyframework_agentic.pipeline.reducers import append, extend, merge_dict, replace from fireflyframework_agentic.pipeline.result import ExecutionTraceEntry, NodeResult, PipelineResult from fireflyframework_agentic.pipeline.state_pipeline import ( @@ -89,6 +93,7 @@ "RetrievalStep", "Send", "StatePipeline", + "StatePipelineEventHandler", "StatePipelineResult", "StepExecutor", "append", diff --git a/fireflyframework_agentic/pipeline/builder.py b/fireflyframework_agentic/pipeline/builder.py index bb5f7681..2975d59f 100644 --- a/fireflyframework_agentic/pipeline/builder.py +++ b/fireflyframework_agentic/pipeline/builder.py @@ -55,7 +55,7 @@ from fireflyframework_agentic.exceptions import PipelineError from fireflyframework_agentic.pipeline.checkpoint import Checkpointer from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode, FailureStrategy -from fireflyframework_agentic.pipeline.engine import PipelineEngine +from fireflyframework_agentic.pipeline.engine import PipelineEngine, StatePipelineEventHandler from fireflyframework_agentic.pipeline.state_pipeline import ( BranchSpec, RouterFn, @@ -86,6 +86,7 @@ def __init__( state: type[BaseModel] | None = None, checkpointer: Checkpointer | None = None, recursion_limit: int = 25, + event_handler: StatePipelineEventHandler | None = None, ) -> None: # State pipelines may use cyclic graphs (ReAct loops, retry-with-critique). # The legacy port-based path keeps acyclicity as an invariant. @@ -94,6 +95,7 @@ def __init__( self._state_schema = state self._checkpointer = checkpointer self._recursion_limit = recursion_limit + self._event_handler = event_handler self._pending_nodes: list[DAGNode] = [] self._pending_edges: list[DAGEdge] = [] # State-based mode bookkeeping. Keyed by node id. @@ -267,6 +269,7 @@ def build(self) -> PipelineEngine | StatePipeline: branches=self._branches, checkpointer=self._checkpointer, recursion_limit=self._recursion_limit, + event_handler=self._event_handler, ) return PipelineEngine(self._dag) diff --git a/fireflyframework_agentic/pipeline/engine.py b/fireflyframework_agentic/pipeline/engine.py index 82010a9d..9f7e3d0f 100644 --- a/fireflyframework_agentic/pipeline/engine.py +++ b/fireflyframework_agentic/pipeline/engine.py @@ -71,6 +71,60 @@ async def on_pipeline_complete(self, pipeline_name: str, success: bool, duration ... +@runtime_checkable +class StatePipelineEventHandler(Protocol): + """Protocol for state-pipeline progress callbacks. + + Mirrors :class:`PipelineEventHandler` but every callback carries the + ``run_id`` so ops can correlate events across resumes, and + ``on_node_start`` carries a ``visit`` counter so cyclic graphs are + distinguishable per iteration. There is no ``on_node_skip`` — state + pipelines abort on failure rather than skipping downstream nodes. + + Implement any subset of methods; missing ones are no-ops. + """ + + async def on_pipeline_start(self, pipeline_name: str, run_id: str) -> None: + """Called once when ``invoke`` begins.""" + ... + + async def on_node_start(self, pipeline_name: str, run_id: str, node_id: str, visit: int) -> None: + """Called each time a node is about to run. ``visit`` starts at 1 + and increments per re-entry (cycles, Send fan-out).""" + ... + + async def on_node_complete(self, pipeline_name: str, run_id: str, node_id: str, latency_ms: float) -> None: + """Called when a node completes successfully.""" + ... + + async def on_node_error(self, pipeline_name: str, run_id: str, node_id: str, error: str) -> None: + """Called when a node raises an exception.""" + ... + + async def on_pipeline_complete(self, pipeline_name: str, run_id: str, success: bool, duration_ms: float) -> None: + """Called once when ``invoke`` returns.""" + ... + + +def start_otel_span(name: str, **attributes: Any) -> Any: + """Start an OTel span if observability is enabled, else return ``None``. + + Module-level helper shared by :class:`PipelineEngine` and + :class:`fireflyframework_agentic.pipeline.state_pipeline.StatePipeline`. + """ + try: + if not get_config().observability_enabled: + return None + if otel_trace is None: + return None + return otel_trace.get_tracer("fireflyframework_agentic").start_span( + name, + attributes={f"firefly.{k}": str(v) for k, v in attributes.items()}, + ) + except Exception: # noqa: BLE001 + return None + + class PipelineEngine: """Executes a :class:`DAG` by computing topological levels and running nodes within each level concurrently. @@ -347,19 +401,8 @@ async def _execute_node( @staticmethod def _start_otel_span(name: str, **attributes: Any) -> Any: - """Start an OTel span if observability is enabled, else return *None*.""" - try: - if not get_config().observability_enabled: - return None - if otel_trace is None: - return None - - return otel_trace.get_tracer("fireflyframework_agentic").start_span( - name, - attributes={f"firefly.{k}": str(v) for k, v in attributes.items()}, - ) - except Exception: # noqa: BLE001 - return None + """Backwards-compatible wrapper around the module-level :func:`start_otel_span`.""" + return start_otel_span(name, **attributes) @staticmethod def _aggregate_usage(correlation_id: str) -> Any: diff --git a/fireflyframework_agentic/pipeline/state_pipeline.py b/fireflyframework_agentic/pipeline/state_pipeline.py index 6c818bf7..d33c37b7 100644 --- a/fireflyframework_agentic/pipeline/state_pipeline.py +++ b/fireflyframework_agentic/pipeline/state_pipeline.py @@ -24,19 +24,24 @@ from __future__ import annotations import asyncio +import contextlib import inspect import logging import time import uuid from collections.abc import Awaitable, Callable from dataclasses import dataclass -from typing import Any, get_type_hints +from typing import TYPE_CHECKING, Any, get_type_hints from pydantic import BaseModel from fireflyframework_agentic.exceptions import PipelineError from fireflyframework_agentic.pipeline.checkpoint import Checkpointer, CheckpointRecord from fireflyframework_agentic.pipeline.dag import DAG, _mermaid_id +from fireflyframework_agentic.pipeline.engine import start_otel_span + +if TYPE_CHECKING: + from fireflyframework_agentic.pipeline.engine import StatePipelineEventHandler from fireflyframework_agentic.pipeline.reducers import Reducer, replace logger = logging.getLogger(__name__) @@ -151,6 +156,7 @@ def __init__( branches: dict[str, BranchSpec], checkpointer: Checkpointer | None = None, recursion_limit: int = 25, + event_handler: StatePipelineEventHandler | None = None, ) -> None: self._name = name self._dag = dag @@ -159,9 +165,43 @@ def __init__( self._branches = branches self._checkpointer = checkpointer self._recursion_limit = recursion_limit + self._event_handler = event_handler self._reducers = discover_reducers(state_schema) self._validate() + async def _finalize_run( + self, + result: StatePipelineResult, + span: Any, + start_time: float, + run_id: str, + ) -> StatePipelineResult: + """Close the pipeline-level span and emit ``on_pipeline_complete``. + + Every return path in :meth:`invoke` after the observability boundary + flows through this helper. + """ + if span is not None: + with contextlib.suppress(Exception): + span.end() + duration_ms = (time.perf_counter() - start_time) * 1000 + await self._emit("on_pipeline_complete", self._name, run_id, result.success, duration_ms) + return result + + async def _emit(self, method: str, *args: Any) -> None: + """Invoke ``method`` on the configured event handler if it exists. + + Missing methods are no-ops; raised exceptions are swallowed so + observability never breaks business logic. + """ + if self._event_handler is None: + return + fn = getattr(self._event_handler, method, None) + if fn is None: + return + with contextlib.suppress(Exception): + await fn(*args) + @property def name(self) -> str: return self._name @@ -358,6 +398,11 @@ async def invoke( next_step: str | list[Send] | None = current_node + # ---- pipeline-level observability boundary ------------------------- + pipeline_start_time = time.perf_counter() + pipeline_span = start_otel_span(f"pipeline.state.{self._name}", pipeline=self._name, run_id=run_id) + await self._emit("on_pipeline_start", self._name, run_id) + while next_step is not None: # --- fan-out branch (list[Send]) --------------------------------- if isinstance(next_step, list): @@ -371,13 +416,18 @@ async def invoke( visit_counts=visit_counts, ) except _NodeFailureError as fail: - return StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=False, - error=fail.message, - failed_node=fail.node_id, + return await self._finalize_run( + StatePipelineResult( + state=state, + run_id=run_id, + completed_nodes=completed, + success=False, + error=fail.message, + failed_node=fail.node_id, + ), + pipeline_span, + pipeline_start_time, + run_id, ) # After fan-out, continue from the workers' shared successor (if any). next_step = self._common_successor([s.target for s in next_step]) @@ -386,22 +436,30 @@ async def invoke( # --- single-node step -------------------------------------------- node_id = next_step visit_counts[node_id] = visit_counts.get(node_id, 0) + 1 - if visit_counts[node_id] > self._recursion_limit: + visit_n = visit_counts[node_id] + if visit_n > self._recursion_limit: msg = ( f"Recursion limit ({self._recursion_limit}) exceeded at node '{node_id}'. " f"Raise recursion_limit= or fix the routing logic." ) logger.error(msg) - return StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=False, - error=msg, - failed_node=node_id, + return await self._finalize_run( + StatePipelineResult( + state=state, + run_id=run_id, + completed_nodes=completed, + success=False, + error=msg, + failed_node=node_id, + ), + pipeline_span, + pipeline_start_time, + run_id, ) fn = self._node_fns[node_id] + node_span = start_otel_span(f"pipeline.state.node.{node_id}", node=node_id, visit=visit_n) + await self._emit("on_node_start", self._name, run_id, node_id, visit_n) t0 = time.perf_counter() try: update = await fn(state) @@ -412,15 +470,28 @@ async def invoke( run_id, node_id, ) - return StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=False, - error=str(exc), - failed_node=node_id, + await self._emit("on_node_error", self._name, run_id, node_id, str(exc)) + if node_span is not None: + with contextlib.suppress(Exception): + node_span.end() + return await self._finalize_run( + StatePipelineResult( + state=state, + run_id=run_id, + completed_nodes=completed, + success=False, + error=str(exc), + failed_node=node_id, + ), + pipeline_span, + pipeline_start_time, + run_id, ) elapsed = (time.perf_counter() - t0) * 1000 + if node_span is not None: + with contextlib.suppress(Exception): + node_span.end() + await self._emit("on_node_complete", self._name, run_id, node_id, elapsed) logger.debug("Pipeline '%s' node '%s' completed in %.1fms", self._name, node_id, elapsed) if update: @@ -433,20 +504,30 @@ async def invoke( try: next_step = self._next_step(node_id, state) except PipelineError as exc: - return StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=False, - error=str(exc), - failed_node=node_id, + return await self._finalize_run( + StatePipelineResult( + state=state, + run_id=run_id, + completed_nodes=completed, + success=False, + error=str(exc), + failed_node=node_id, + ), + pipeline_span, + pipeline_start_time, + run_id, ) - return StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=True, + return await self._finalize_run( + StatePipelineResult( + state=state, + run_id=run_id, + completed_nodes=completed, + success=True, + ), + pipeline_span, + pipeline_start_time, + run_id, ) async def _run_fanout( @@ -462,8 +543,12 @@ async def _run_fanout( """Run all ``Send`` dispatches concurrently. Each task gets its own state copy with the Send's payload merged in; results are reduced into shared state. """ + # Snapshot each Send's visit number BEFORE dispatch so the worker's + # closure captures its own visit, not the final post-increment value. + sends_with_visits: list[tuple[Send, int]] = [] for send in sends: visit_counts[send.target] = visit_counts.get(send.target, 0) + 1 + sends_with_visits.append((send, visit_counts[send.target])) if visit_counts[send.target] > self._recursion_limit: raise _NodeFailureError( node_id=send.target, @@ -472,13 +557,29 @@ async def _run_fanout( ), ) - async def _run_one(send: Send) -> tuple[Send, dict[str, Any] | None]: + async def _run_one(send: Send, visit_n: int) -> tuple[Send, dict[str, Any] | None]: + await self._emit("on_node_start", self._name, run_id, send.target, visit_n) + node_span = start_otel_span(f"pipeline.state.node.{send.target}", node=send.target, visit=visit_n) task_state = apply_update(state, send.payload, self._reducers) fn = self._node_fns[send.target] - return send, await fn(task_state) + t0 = time.perf_counter() + try: + update = await fn(task_state) + except Exception as exc: + await self._emit("on_node_error", self._name, run_id, send.target, str(exc)) + if node_span is not None: + with contextlib.suppress(Exception): + node_span.end() + raise + elapsed = (time.perf_counter() - t0) * 1000 + if node_span is not None: + with contextlib.suppress(Exception): + node_span.end() + await self._emit("on_node_complete", self._name, run_id, send.target, elapsed) + return send, update try: - results = await asyncio.gather(*(_run_one(s) for s in sends)) + results = await asyncio.gather(*(_run_one(s, v) for s, v in sends_with_visits)) except Exception as exc: # Best-effort: report the first failing target as the failure point. raise _NodeFailureError( diff --git a/tests/unit/pipeline/test_state_pipeline_observability.py b/tests/unit/pipeline/test_state_pipeline_observability.py new file mode 100644 index 00000000..ac79233a --- /dev/null +++ b/tests/unit/pipeline/test_state_pipeline_observability.py @@ -0,0 +1,324 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""Phase-3b tests: StatePipelineEventHandler callbacks + OTel spans for state pipelines.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Annotated, Any +from unittest.mock import MagicMock + +import pytest +from pydantic import BaseModel + +import fireflyframework_agentic.pipeline.engine as engine_module +from fireflyframework_agentic.pipeline import ( + FileCheckpointer, + PipelineBuilder, + Send, + extend, +) + + +@dataclass +class RecordingHandler: + """A test handler that captures every callback in order.""" + + events: list[tuple] = field(default_factory=list) + + async def on_pipeline_start(self, pipeline_name: str, run_id: str) -> None: + self.events.append(("pipeline_start", pipeline_name, run_id)) + + async def on_node_start(self, pipeline_name: str, run_id: str, node_id: str, visit: int) -> None: + self.events.append(("node_start", node_id, visit)) + + async def on_node_complete(self, pipeline_name: str, run_id: str, node_id: str, latency_ms: float) -> None: + self.events.append(("node_complete", node_id)) + + async def on_node_error(self, pipeline_name: str, run_id: str, node_id: str, error: str) -> None: + self.events.append(("node_error", node_id, error)) + + async def on_pipeline_complete(self, pipeline_name: str, run_id: str, success: bool, duration_ms: float) -> None: + self.events.append(("pipeline_complete", success)) + + +class LinearState(BaseModel): + log: Annotated[list[str], extend] = [] + + +class LoopState(BaseModel): + counter: int = 0 + + +# --- linear pipeline event ordering ----------------------------------------- + + +@pytest.mark.asyncio +async def test_linear_pipeline_emits_events_in_order() -> None: + async def a(state: LinearState) -> dict: + return {"log": ["a"]} + + async def b(state: LinearState) -> dict: + return {"log": ["b"]} + + async def c(state: LinearState) -> dict: + return {"log": ["c"]} + + handler = RecordingHandler() + pipeline = ( + PipelineBuilder("linear", state=LinearState, event_handler=handler) + .add_node(a) + .add_node(b) + .add_node(c) + .chain(a, b, c) + .build() + ) + await pipeline.invoke(LinearState()) + + event_kinds = [e[0] for e in handler.events] + assert event_kinds == [ + "pipeline_start", + "node_start", + "node_complete", + "node_start", + "node_complete", + "node_start", + "node_complete", + "pipeline_complete", + ] + assert handler.events[-1] == ("pipeline_complete", True) + + +# --- failure path ---------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_failure_emits_node_error_and_pipeline_complete_false() -> None: + async def boom(state: LinearState) -> dict: + raise RuntimeError("nope") + + handler = RecordingHandler() + pipeline = PipelineBuilder("fail", state=LinearState, event_handler=handler).add_node(boom).build() + result = await pipeline.invoke(LinearState()) + assert not result.success + + assert ("node_error", "boom", "nope") in handler.events + assert ("pipeline_complete", False) in handler.events + # node_error fires BEFORE pipeline_complete + err_idx = handler.events.index(("node_error", "boom", "nope")) + done_idx = handler.events.index(("pipeline_complete", False)) + assert err_idx < done_idx + + +# --- cyclic graph visit count ---------------------------------------------- + + +@pytest.mark.asyncio +async def test_cyclic_graph_increments_visit_count() -> None: + async def step(state: LoopState) -> dict: + return {"counter": state.counter + 1} + + async def done(state: LoopState) -> dict: + return {} + + def route(state: LoopState) -> str: + return "done" if state.counter >= 3 else "step" + + handler = RecordingHandler() + pipeline = ( + PipelineBuilder("loop", state=LoopState, event_handler=handler) + .add_node(step) + .add_node(done) + .branch(step, route) + .build() + ) + await pipeline.invoke(LoopState()) + + step_starts = [e for e in handler.events if e[0] == "node_start" and e[1] == "step"] + assert [e[2] for e in step_starts] == [1, 2, 3] + + +# --- fan-out via Send ------------------------------------------------------- + + +class FanOutState(BaseModel): + items: list[str] = [] + results: Annotated[list[str], extend] = [] + item: str | None = None + + +@pytest.mark.asyncio +async def test_fanout_emits_per_send_node_events() -> None: + async def planner(state: FanOutState) -> dict: + return {} + + async def worker(state: FanOutState) -> dict: + return {"results": [f"r:{state.item}"]} + + async def collect(state: FanOutState) -> dict: + return {} + + def dispatch(state: FanOutState) -> list[Send]: + return [Send("worker", {"item": x}) for x in state.items] + + handler = RecordingHandler() + pipeline = ( + PipelineBuilder("fanout", state=FanOutState, event_handler=handler) + .add_node(planner) + .add_node(worker) + .add_node(collect) + .add_edge(worker, collect) + .branch(planner, dispatch) + .build() + ) + await pipeline.invoke(FanOutState(items=["a", "b", "c"])) + + worker_starts = [e for e in handler.events if e[0] == "node_start" and e[1] == "worker"] + worker_completes = [e for e in handler.events if e[0] == "node_complete" and e[1] == "worker"] + assert len(worker_starts) == 3 + assert len(worker_completes) == 3 + # Visits are 1, 2, 3 across the three Sends. + assert sorted(e[2] for e in worker_starts) == [1, 2, 3] + + +# --- resume from a checkpoint ---------------------------------------------- + + +class BuildState(BaseModel): + requirements: str + spec: str | None = None + code: str | None = None + deploy: str | None = None + + +@pytest.mark.asyncio +async def test_resume_emits_events_only_for_remaining_nodes(tmp_path: Path) -> None: + fail_once = {"flag": False} + + async def arch(state: BuildState) -> dict: + return {"spec": "s"} + + async def dev(state: BuildState) -> dict: + return {"code": "c"} + + async def deploy(state: BuildState) -> dict: + if not fail_once["flag"]: + fail_once["flag"] = True + raise RuntimeError("blip") + return {"deploy": "ok"} + + handler1 = RecordingHandler() + handler2 = RecordingHandler() + ckpt = FileCheckpointer(tmp_path) + + # First run uses handler1; deploy fails. + p1 = ( + PipelineBuilder("factory", state=BuildState, checkpointer=ckpt, event_handler=handler1) + .add_node(arch) + .add_node(dev) + .add_node(deploy) + .chain(arch, dev, deploy) + .build() + ) + first = await p1.invoke(BuildState(requirements="x")) + assert not first.success + + # Second run uses handler2 and resumes; only deploy should run. + p2 = ( + PipelineBuilder("factory", state=BuildState, checkpointer=ckpt, event_handler=handler2) + .add_node(arch) + .add_node(dev) + .add_node(deploy) + .chain(arch, dev, deploy) + .build() + ) + second = await p2.invoke(run_id=first.run_id) + assert second.success + + nodes_started_on_resume = [e[1] for e in handler2.events if e[0] == "node_start"] + assert nodes_started_on_resume == ["deploy"] + + +# --- partial handler -------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_partial_handler_only_implementing_node_error_works() -> None: + captured: list[str] = [] + + class JustErrors: + async def on_node_error(self, pipeline_name, run_id, node_id, error): + captured.append(f"{node_id}:{error}") + + async def boom(state: LinearState) -> dict: + raise RuntimeError("kaboom") + + pipeline = PipelineBuilder("partial", state=LinearState, event_handler=JustErrors()).add_node(boom).build() + result = await pipeline.invoke(LinearState()) + assert not result.success + assert captured == ["boom:kaboom"] + + +# --- handler exception is swallowed ---------------------------------------- + + +@pytest.mark.asyncio +async def test_handler_exception_does_not_break_pipeline() -> None: + class CrashyHandler: + async def on_node_complete(self, *args: Any, **kwargs: Any) -> None: + raise RuntimeError("handler crashed") + + async def step(state: LinearState) -> dict: + return {"log": ["ran"]} + + pipeline = PipelineBuilder("crash", state=LinearState, event_handler=CrashyHandler()).add_node(step).build() + result = await pipeline.invoke(LinearState()) + assert result.success + assert result.state.log == ["ran"] + + +# --- OTel spans ------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_otel_spans_emitted_per_pipeline_and_per_node(monkeypatch: pytest.MonkeyPatch) -> None: + # Force the observability path to fire even if the config disables it by default. + mock_config = MagicMock(observability_enabled=True) + monkeypatch.setattr(engine_module, "get_config", lambda: mock_config) + + # Stub the OTel tracer so we can record span starts. + started: list[tuple[str, dict]] = [] + mock_tracer = MagicMock() + + def fake_start_span(name: str, attributes: dict | None = None) -> MagicMock: + started.append((name, attributes or {})) + return MagicMock() + + mock_tracer.start_span.side_effect = fake_start_span + mock_trace = MagicMock() + mock_trace.get_tracer.return_value = mock_tracer + monkeypatch.setattr(engine_module, "otel_trace", mock_trace) + + async def a(state: LinearState) -> dict: + return {} + + async def b(state: LinearState) -> dict: + return {} + + pipeline = PipelineBuilder("otel", state=LinearState).add_node(a).add_node(b).chain(a, b).build() + await pipeline.invoke(LinearState()) + + names = [n for n, _ in started] + assert "pipeline.state.otel" in names + assert "pipeline.state.node.a" in names + assert "pipeline.state.node.b" in names + + # Spot-check attributes: pipeline span carries run_id, node span carries visit. + pipeline_attrs = next(attrs for name, attrs in started if name == "pipeline.state.otel") + assert "firefly.run_id" in pipeline_attrs + node_a_attrs = next(attrs for name, attrs in started if name == "pipeline.state.node.a") + assert node_a_attrs.get("firefly.visit") == "1" From 890465a9eef64e2f7e21b7a7dce2dcf38605228a Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Wed, 27 May 2026 18:19:35 +0200 Subject: [PATCH 08/26] style(examples): ruff format pipeline_state.py (fixes #232 CI) --- examples/pipeline_state.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/examples/pipeline_state.py b/examples/pipeline_state.py index 717c7f64..b5fba799 100644 --- a/examples/pipeline_state.py +++ b/examples/pipeline_state.py @@ -163,24 +163,16 @@ class ProgressHandler: async def on_pipeline_start(self, pipeline_name: str, run_id: str) -> None: print(f" ▶ [{pipeline_name}] run {run_id[:8]}… starting") - async def on_node_start( - self, pipeline_name: str, run_id: str, node_id: str, visit: int - ) -> None: + async def on_node_start(self, pipeline_name: str, run_id: str, node_id: str, visit: int) -> None: print(f" ▶ {node_id} (visit #{visit})") - async def on_node_complete( - self, pipeline_name: str, run_id: str, node_id: str, latency_ms: float - ) -> None: + async def on_node_complete(self, pipeline_name: str, run_id: str, node_id: str, latency_ms: float) -> None: print(f" ✔ {node_id} ({latency_ms:.0f}ms)") - async def on_node_error( - self, pipeline_name: str, run_id: str, node_id: str, error: str - ) -> None: + async def on_node_error(self, pipeline_name: str, run_id: str, node_id: str, error: str) -> None: print(f" ✗ {node_id}: {error}") - async def on_pipeline_complete( - self, pipeline_name: str, run_id: str, success: bool, duration_ms: float - ) -> None: + async def on_pipeline_complete(self, pipeline_name: str, run_id: str, success: bool, duration_ms: float) -> None: status = "OK" if success else "FAILED" print(f" ═ [{pipeline_name}] {status} in {duration_ms:.0f}ms") From 0baab84b81f02efef692880e2f1decb94124db7a Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Wed, 27 May 2026 18:15:15 +0200 Subject: [PATCH 09/26] feat(pipeline): HITL Pause + AuditLog with 4 backends (#147 phase 3c) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two enterprise features for state pipelines: human-in-the-loop pause/approve gates, and a structured audit trail of every node visit. HITL via Pause: - New Pause(reason=...) sentinel returned by a node halts the pipeline cleanly. - StatePipeline writes a paused checkpoint (new optional CheckpointRecord fields: paused=False, pause_reason=None — backward compatible). - StatePipelineResult gains paused/paused_node/pause_reason. - StatePipelineEventHandler gains an optional on_node_pause callback. - invoke(run_id=..., approve_pause=True) resumes from the SUCCESSOR of the paused node. Without approve_pause=True, resuming a paused run raises PipelineError — pauses are sticky until explicitly released. Audit log via four backends in new pipeline/audit.py: - AuditEntry pydantic model + split Protocol: AuditLog (write-only) + QueryableAuditLog (adds list_entries). - FileAuditLog — JSONL per (pipeline, run_id). Implements QueryableAuditLog. - PostgresAuditLog — single firefly_audit table, idempotent DDL, reuses the psycopg connection from Phase 3a's postgres extra. Implements QueryableAuditLog. - LoggingAuditLog — stdlib logging.Logger; pairs with any log-aggregation stack (Splunk-HEC, Loki, Datadog, OTel-LoggingHandler-bridge). Write-only. No new dep. - OtelAuditLog — OTel logs API directly; emits LogRecord with trace_id / span_id correlation. Requires opentelemetry-sdk. Write-only. - StatePipeline records one AuditEntry per node visit (success/error/pause) with inputs_snapshot, outputs_snapshot, latency_ms, started_at/completed_at, status, plus error_message or pause_reason as appropriate. - Audit-write failures are non-fatal — logged and swallowed. API additions exported from fireflyframework_agentic.pipeline: Pause, AuditEntry, AuditLog, QueryableAuditLog, FileAuditLog, PostgresAuditLog, LoggingAuditLog, OtelAuditLog No new required deps. Postgres extra (already added in 3a) covers PostgresAuditLog. opentelemetry-sdk is the same optional dep already used by Phase 3b for OTel spans. Tests: 20 new across two files. - test_state_pipeline_hitl.py (6): pause halts pipeline; resume without approval raises; resume with approve_pause continues from successor; on_node_pause fires; backward-compat for old checkpoints; pause→fail→retry. - test_audit_log.py (14): per-backend (File JSONL, Postgres mocked, Logging via caplog, OTel via mocked logger); pipeline writes one entry per visit; status reflects success/error/paused; audit write failures don't abort the pipeline. Full pipeline suite 142 passed (122 baseline + 20 new). Lints clean, pyright clean on touched modules. Example pipeline_state.py gains a 5th scenario showing HITL + audit end-to-end. --- docs/pipeline.md | 63 ++++ examples/pipeline_state.py | 69 ++++ fireflyframework_agentic/pipeline/__init__.py | 20 +- fireflyframework_agentic/pipeline/audit.py | 349 ++++++++++++++++++ fireflyframework_agentic/pipeline/builder.py | 4 + .../pipeline/checkpoint.py | 10 +- fireflyframework_agentic/pipeline/engine.py | 5 + .../pipeline/state_pipeline.py | 166 ++++++++- tests/unit/pipeline/test_audit_log.py | 325 ++++++++++++++++ .../unit/pipeline/test_state_pipeline_hitl.py | 208 +++++++++++ 10 files changed, 1216 insertions(+), 3 deletions(-) create mode 100644 fireflyframework_agentic/pipeline/audit.py create mode 100644 tests/unit/pipeline/test_audit_log.py create mode 100644 tests/unit/pipeline/test_state_pipeline_hitl.py diff --git a/docs/pipeline.md b/docs/pipeline.md index 60bd2097..88570307 100644 --- a/docs/pipeline.md +++ b/docs/pipeline.md @@ -320,6 +320,69 @@ In parallel, the pipeline emits OTel spans automatically when Handler exceptions are swallowed — observability never breaks business logic. +### Human-in-the-loop (Pause) + +Any node may return ``Pause(reason="...")`` instead of a state update to halt +the pipeline cleanly. The current state is checkpointed with a paused marker; +``invoke`` returns with ``result.paused=True`` and ``result.success=False``. + +```python +from fireflyframework_agentic.pipeline import Pause + +async def await_deploy_approval(state: DeployState) -> Pause: + return Pause(reason="awaiting human approval to deploy to production") +``` + +To resume after the external approval comes in, call ``invoke`` with the same +``run_id`` and ``approve_pause=True``. Without ``approve_pause=True``, the +resume raises a ``PipelineError`` — the pause is sticky until explicitly +released. The successor of the paused node runs next; the pause node itself +is not re-executed. + +```python +first = await pipeline.invoke(DeployState(...)) +assert first.paused +# ...later, after approval... +done = await pipeline.invoke(run_id=first.run_id, approve_pause=True) +assert done.success +``` + +The configured ``StatePipelineEventHandler`` receives an ``on_node_pause`` +callback when this happens (the callback is optional — partial handlers +without it continue to work). + +### Audit Log + +Distinct from the ``Checkpointer`` (which stores the *latest* state for +crash recovery), an ``AuditLog`` is an append-only record of *every* node +visit for compliance, debugging, and replay. Wire one in via the +``audit_log`` kwarg: + +```python +from fireflyframework_agentic.pipeline import ( + PipelineBuilder, FileAuditLog, PostgresAuditLog, LoggingAuditLog, OtelAuditLog, +) + +PipelineBuilder("agent", state=AgentState, audit_log=FileAuditLog("./audit")) +``` + +Four backends ship, each conforming to the ``AuditLog`` Protocol: + +| Backend | Use when | Read API | Trace-correlated | Install | +|---|---|---|---|---| +| ``FileAuditLog`` | Dev / single-host | yes | no | (default) | +| ``PostgresAuditLog`` | Compliance, retention, cross-run queries | yes | no | ``[postgres]`` | +| ``LoggingAuditLog`` | Generic log stacks (Splunk-HEC, Loki, JSON-logging) | no (write-only) | no | (default — stdlib) | +| ``OtelAuditLog`` | OTel-native stacks (Application Insights, Datadog APM, OTel Collector) | no (write-only) | **yes** | ``opentelemetry-sdk`` | + +``FileAuditLog`` and ``PostgresAuditLog`` also implement +``QueryableAuditLog`` with ``list_entries(pipeline_name, run_id)``. The +write-only backends delegate query/search to the user's existing +observability stack. + +Audit-log write failures are non-fatal — logged but never abort the +pipeline. + ### Mermaid Export `StatePipeline.to_mermaid()` and `DAG.to_mermaid()` render the topology as a diff --git a/examples/pipeline_state.py b/examples/pipeline_state.py index b5fba799..ee9bb449 100644 --- a/examples/pipeline_state.py +++ b/examples/pipeline_state.py @@ -51,7 +51,9 @@ from pydantic import BaseModel from fireflyframework_agentic.pipeline import ( + FileAuditLog, FileCheckpointer, + Pause, PipelineBuilder, PostgresCheckpointer, Send, @@ -172,6 +174,9 @@ async def on_node_complete(self, pipeline_name: str, run_id: str, node_id: str, async def on_node_error(self, pipeline_name: str, run_id: str, node_id: str, error: str) -> None: print(f" ✗ {node_id}: {error}") + async def on_node_pause(self, pipeline_name: str, run_id: str, node_id: str, reason: str) -> None: + print(f" ⏸ {node_id} paused: {reason}") + async def on_pipeline_complete(self, pipeline_name: str, run_id: str, success: bool, duration_ms: float) -> None: status = "OK" if success else "FAILED" print(f" ═ [{pipeline_name}] {status} in {duration_ms:.0f}ms") @@ -299,11 +304,75 @@ async def run_software_factory_postgres() -> None: print(f" eval: {second.state.evaluation}\n") +class HitlState(BaseModel): + """State threaded through a deploy pipeline gated by human approval.""" + + target_env: str + artifact: str | None = None + deployed_to: str | None = None + + +async def build_artifact(state: HitlState) -> dict: + return {"artifact": f"build-{state.target_env}.tar.gz"} + + +async def await_approval(state: HitlState) -> Pause: + return Pause(reason=f"awaiting human approval to deploy {state.artifact} to {state.target_env}") + + +async def deploy_artifact(state: HitlState) -> dict: + return {"deployed_to": f"https://{state.target_env}.example.com"} + + +async def run_hitl_with_audit() -> None: + print("=== 5. Human-in-the-loop deploy gate with audit log ===\n") + + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + ckpt = FileCheckpointer(root / "ckpt") + audit = FileAuditLog(root / "audit") + pipeline = ( + PipelineBuilder( + "hitl-deploy", + state=HitlState, + checkpointer=ckpt, + audit_log=audit, + ) + .add_node(build_artifact) + .add_node(await_approval) + .add_node(deploy_artifact) + .chain(build_artifact, await_approval, deploy_artifact) + .build() + ) + + # First run halts at the approval gate. + first = await pipeline.invoke(HitlState(target_env="prod")) + print(f" first run: paused={first.paused}, paused_node={first.paused_node}") + print(f" reason: {first.pause_reason}") + print(f" run_id: {first.run_id}\n") + + # ...time passes; a human reviews and approves... + print(" (human reviews and approves)\n") + + # Resume with explicit approval. + done = await pipeline.invoke(run_id=first.run_id, approve_pause=True) + print(f" resumed: success={done.success}, deployed_to={done.state.deployed_to}") + print(f" completed: {done.completed_nodes}\n") + + # Audit log captures every node visit with its status. + entries = audit.list_entries("hitl-deploy", first.run_id) + print(" audit trail:") + for e in entries: + extra = f" reason={e.pause_reason!r}" if e.pause_reason else "" + print(f" seq={e.sequence} node={e.node_id} status={e.status}{extra}") + + async def main() -> None: await run_branching() await run_software_factory() await run_map_reduce() await run_software_factory_postgres() + await run_hitl_with_audit() if __name__ == "__main__": diff --git a/fireflyframework_agentic/pipeline/__init__.py b/fireflyframework_agentic/pipeline/__init__.py index af4778df..6f4dbe6c 100644 --- a/fireflyframework_agentic/pipeline/__init__.py +++ b/fireflyframework_agentic/pipeline/__init__.py @@ -28,6 +28,15 @@ via :class:`Checkpointer` enables resume after failure and mid-pipeline start. """ +from fireflyframework_agentic.pipeline.audit import ( + AuditEntry, + AuditLog, + FileAuditLog, + LoggingAuditLog, + OtelAuditLog, + PostgresAuditLog, + QueryableAuditLog, +) from fireflyframework_agentic.pipeline.builder import PipelineBuilder from fireflyframework_agentic.pipeline.checkpoint import ( Checkpointer, @@ -46,6 +55,7 @@ from fireflyframework_agentic.pipeline.reducers import append, extend, merge_dict, replace from fireflyframework_agentic.pipeline.result import ExecutionTraceEntry, NodeResult, PipelineResult from fireflyframework_agentic.pipeline.state_pipeline import ( + Pause, RecursionLimitError, Send, StatePipeline, @@ -67,6 +77,8 @@ __all__ = [ "DAG", "AgentStep", + "AuditEntry", + "AuditLog", "BatchLLMStep", "BranchStep", "CallableStep", @@ -79,14 +91,20 @@ "FailureStrategy", "FanInStep", "FanOutStep", + "FileAuditLog", "FileCheckpointer", + "LoggingAuditLog", "NodeResult", - "PostgresCheckpointer", + "OtelAuditLog", + "Pause", "PipelineBuilder", "PipelineContext", "PipelineEngine", "PipelineEventHandler", "PipelineResult", + "PostgresAuditLog", + "PostgresCheckpointer", + "QueryableAuditLog", "ReasoningStep", "RecursionLimitError", "RedisCheckpointer", diff --git a/fireflyframework_agentic/pipeline/audit.py b/fireflyframework_agentic/pipeline/audit.py new file mode 100644 index 00000000..481ea08d --- /dev/null +++ b/fireflyframework_agentic/pipeline/audit.py @@ -0,0 +1,349 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Append-only audit logs for state-pipeline node visits. + +Distinct from :mod:`fireflyframework_agentic.pipeline.checkpoint` — the +checkpointer stores the *latest* state for crash recovery; the audit log +stores *every* node visit for compliance, debugging, and replay. + +Four backends ship: + +* :class:`FileAuditLog` — one JSONL file per ``(pipeline_name, run_id)``. + Best for dev / single-host audit trails. +* :class:`PostgresAuditLog` — single ``firefly_audit`` table; reuses the + ``psycopg`` connection from the ``postgres`` optional extra (same dep + added in Phase 3a for ``PostgresCheckpointer``). +* :class:`LoggingAuditLog` — stdlib ``logging``; pairs with whatever log + aggregation pipeline (Splunk-HEC, Loki, Datadog, OTel-LoggingHandler-bridge) + the host application already runs. No new optional dep. +* :class:`OtelAuditLog` — direct OTel logs API; attaches trace correlation + (``trace_id``/``span_id``) automatically. Best for OTel-native stacks + (Application Insights, Datadog APM, OTel-Collector). + +File and Postgres also implement :class:`QueryableAuditLog` (``list_entries``); +Logging and OTel are write-only — query your observability stack instead. +""" + +from __future__ import annotations + +import json +import logging +from datetime import datetime +from pathlib import Path +from typing import Any, Literal, Protocol, runtime_checkable + +from pydantic import BaseModel + +try: + import psycopg as _psycopg # type: ignore[import-not-found] +except ImportError: # pragma: no cover - optional dep + _psycopg = None # type: ignore[assignment] + +try: + from opentelemetry._logs import LogRecord as _OtelLogRecord # type: ignore[import-not-found] + from opentelemetry._logs import SeverityNumber as _OtelSeverityNumber # type: ignore[import-not-found] + from opentelemetry._logs import get_logger as _otel_get_logger # type: ignore[import-not-found] +except ImportError: # pragma: no cover - optional dep + _otel_get_logger = None # type: ignore[assignment] + _OtelLogRecord = None # type: ignore[assignment,misc] + _OtelSeverityNumber = None # type: ignore[assignment,misc] + +logger = logging.getLogger(__name__) + + +AuditStatus = Literal["success", "error", "paused"] + + +class AuditEntry(BaseModel): + """A single audit record — one per node visit (success, error, or pause).""" + + pipeline_name: str + run_id: str + node_id: str + sequence: int + visit: int + started_at: datetime + completed_at: datetime + latency_ms: float + status: AuditStatus + inputs_snapshot: dict[str, Any] + outputs_snapshot: dict[str, Any] + error_message: str | None = None + pause_reason: str | None = None + + +@runtime_checkable +class AuditLog(Protocol): + """Write-only audit log. Every backend implements this method. + + Implementations must be safe to call from async code (called inside the + state pipeline's executor) but the method itself is sync. + """ + + def record(self, entry: AuditEntry) -> None: ... + + +@runtime_checkable +class QueryableAuditLog(AuditLog, Protocol): + """Audit log that also supports reading back recorded entries. + + File and Postgres backends implement this. Logging and OTel backends do + not — query your observability stack (Splunk / Datadog / Loki / etc.) + instead. + """ + + def list_entries(self, pipeline_name: str, run_id: str) -> list[AuditEntry]: ... + + +class FileAuditLog: + """Filesystem-backed audit log. Layout:: + + //.jsonl + + Each line is a JSON-serialized :class:`AuditEntry`. Appends are atomic + at the line level (single ``write`` call per entry); concurrent writers + to the same run_id may interleave at line boundaries but never within a + single entry. + """ + + def __init__(self, root: str | Path) -> None: + self._root = Path(root) + self._root.mkdir(parents=True, exist_ok=True) + + def _path(self, pipeline_name: str, run_id: str) -> Path: + return self._root / pipeline_name / f"{run_id}.jsonl" + + def record(self, entry: AuditEntry) -> None: + path = self._path(entry.pipeline_name, entry.run_id) + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("a", encoding="utf-8") as f: + f.write(entry.model_dump_json() + "\n") + + def list_entries(self, pipeline_name: str, run_id: str) -> list[AuditEntry]: + path = self._path(pipeline_name, run_id) + if not path.exists(): + return [] + entries: list[AuditEntry] = [] + with path.open(encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + entries.append(AuditEntry.model_validate(json.loads(line))) + return entries + + +class PostgresAuditLog: + """Postgres-backed audit log. Single table created on first ``record`` call. + + Reuses ``psycopg`` from the ``postgres`` optional extra. ``dsn`` or a + pre-built ``connection`` is required, mirroring :class:`PostgresCheckpointer`'s + API for consistency. + """ + + _DDL = """ + CREATE TABLE IF NOT EXISTS {table} ( + pipeline_name TEXT NOT NULL, + run_id TEXT NOT NULL, + sequence INT NOT NULL, + visit INT NOT NULL, + node_id TEXT NOT NULL, + started_at TIMESTAMPTZ NOT NULL, + completed_at TIMESTAMPTZ NOT NULL, + latency_ms DOUBLE PRECISION NOT NULL, + status TEXT NOT NULL, + inputs_snapshot JSONB NOT NULL, + outputs_snapshot JSONB NOT NULL, + error_message TEXT, + pause_reason TEXT, + PRIMARY KEY (pipeline_name, run_id, sequence) + ); + CREATE INDEX IF NOT EXISTS {table}_run_idx + ON {table} (pipeline_name, run_id); + """ + + def __init__( + self, + *, + dsn: str | None = None, + connection: Any = None, + table_name: str = "firefly_audit", + ) -> None: + if _psycopg is None: + raise ImportError( + "PostgresAuditLog requires the 'postgres' extra. " + "Install with: pip install fireflyframework-agentic[postgres]" + ) + if (dsn is None) == (connection is None): + raise ValueError("PostgresAuditLog needs exactly one of `dsn` or `connection`.") + if not table_name.replace("_", "").isalnum(): + raise ValueError(f"Invalid table_name {table_name!r}: must be alphanumeric/underscore only.") + self._conn = connection if connection is not None else _psycopg.connect(dsn, autocommit=True) + self._table = table_name + self._ddl_applied = False + + def _ensure_table(self) -> None: + if self._ddl_applied: + return + with self._conn.cursor() as cur: + cur.execute(self._DDL.format(table=self._table)) + self._ddl_applied = True + + def record(self, entry: AuditEntry) -> None: + self._ensure_table() + sql = ( + f"INSERT INTO {self._table} " + "(pipeline_name, run_id, sequence, visit, node_id, started_at, completed_at, " + "latency_ms, status, inputs_snapshot, outputs_snapshot, error_message, pause_reason) " + "VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) " + "ON CONFLICT (pipeline_name, run_id, sequence) DO NOTHING" + ) + with self._conn.cursor() as cur: + cur.execute( + sql, + ( + entry.pipeline_name, + entry.run_id, + entry.sequence, + entry.visit, + entry.node_id, + entry.started_at, + entry.completed_at, + entry.latency_ms, + entry.status, + json.dumps(entry.inputs_snapshot), + json.dumps(entry.outputs_snapshot), + entry.error_message, + entry.pause_reason, + ), + ) + + def list_entries(self, pipeline_name: str, run_id: str) -> list[AuditEntry]: + self._ensure_table() + sql = ( + f"SELECT pipeline_name, run_id, sequence, visit, node_id, started_at, completed_at, " + f"latency_ms, status, inputs_snapshot, outputs_snapshot, error_message, pause_reason " + f"FROM {self._table} WHERE pipeline_name = %s AND run_id = %s ORDER BY sequence" + ) + with self._conn.cursor() as cur: + cur.execute(sql, (pipeline_name, run_id)) + rows = cur.fetchall() + entries: list[AuditEntry] = [] + for row in rows: + entries.append( + AuditEntry( + pipeline_name=row[0], + run_id=row[1], + sequence=row[2], + visit=row[3], + node_id=row[4], + started_at=row[5], + completed_at=row[6], + latency_ms=row[7], + status=row[8], + inputs_snapshot=json.loads(row[9]) if isinstance(row[9], str) else row[9], + outputs_snapshot=json.loads(row[10]) if isinstance(row[10], str) else row[10], + error_message=row[11], + pause_reason=row[12], + ) + ) + return entries + + +class LoggingAuditLog: + """Audit log backed by Python's stdlib ``logging`` module. + + Each entry is emitted as a structured log record with the full + :class:`AuditEntry` available under ``record.firefly_audit``. Pairs + naturally with any log-aggregation pipeline the host already configures: + + * OTel collector via ``opentelemetry-sdk._logs.LoggingHandler`` + * Splunk via the official Splunk handler + * Loki via promtail tailing stdout + * Datadog via ``ddtrace`` log injection + * Plain JSON-logging via ``python-json-logger`` + + No new dependency. The user wires their log handler exactly once for the + whole application and audit entries flow there automatically. + """ + + def __init__(self, logger_name: str = "firefly.audit", level: int = logging.INFO) -> None: + self._logger = logging.getLogger(logger_name) + self._level = level + + def record(self, entry: AuditEntry) -> None: + self._logger.log( + self._level, + "firefly_audit: pipeline=%s run=%s node=%s status=%s latency_ms=%.1f", + entry.pipeline_name, + entry.run_id, + entry.node_id, + entry.status, + entry.latency_ms, + extra={"firefly_audit": entry.model_dump(mode="json")}, + ) + + +class OtelAuditLog: + """Audit log backed by the OTel logs API. + + Emits each entry as a structured OTel log record with attributes + matching :class:`AuditEntry` fields. When an OTel trace is active the + record is automatically correlated with the current ``trace_id`` and + ``span_id`` — useful for tying audit history to the spans Phase 3b + emits. + + Requires ``opentelemetry-sdk`` to be installed and a ``LoggerProvider`` + configured by the host application (the framework does not install a + provider). + """ + + def __init__(self, logger_name: str = "fireflyframework_agentic.audit") -> None: + if _otel_get_logger is None: + raise ImportError( + "OtelAuditLog requires the 'opentelemetry-sdk' package " + "(with the logs API). Install: pip install opentelemetry-sdk" + ) + self._otel_logger = _otel_get_logger(logger_name) + + def record(self, entry: AuditEntry) -> None: + # The constructor's guard on _otel_get_logger guarantees the OTel + # logs imports succeeded, so the LogRecord and SeverityNumber are + # not None here. + assert _OtelLogRecord is not None and _OtelSeverityNumber is not None + attrs: dict[str, Any] = { + "firefly.pipeline": entry.pipeline_name, + "firefly.run_id": entry.run_id, + "firefly.node": entry.node_id, + "firefly.sequence": entry.sequence, + "firefly.visit": entry.visit, + "firefly.latency_ms": entry.latency_ms, + "firefly.status": entry.status, + } + if entry.error_message: + attrs["firefly.error"] = entry.error_message + if entry.pause_reason: + attrs["firefly.pause_reason"] = entry.pause_reason + + body = f"[{entry.pipeline_name}] {entry.node_id} {entry.status} ({entry.latency_ms:.0f}ms)" + severity = _OtelSeverityNumber.ERROR if entry.status == "error" else _OtelSeverityNumber.INFO + log_record = _OtelLogRecord( + timestamp=int(entry.completed_at.timestamp() * 1_000_000_000), + severity_number=severity, + severity_text=severity.name, + body=body, + attributes=attrs, + ) + self._otel_logger.emit(log_record) diff --git a/fireflyframework_agentic/pipeline/builder.py b/fireflyframework_agentic/pipeline/builder.py index 2975d59f..512bd06d 100644 --- a/fireflyframework_agentic/pipeline/builder.py +++ b/fireflyframework_agentic/pipeline/builder.py @@ -53,6 +53,7 @@ from pydantic import BaseModel from fireflyframework_agentic.exceptions import PipelineError +from fireflyframework_agentic.pipeline.audit import AuditLog from fireflyframework_agentic.pipeline.checkpoint import Checkpointer from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode, FailureStrategy from fireflyframework_agentic.pipeline.engine import PipelineEngine, StatePipelineEventHandler @@ -87,6 +88,7 @@ def __init__( checkpointer: Checkpointer | None = None, recursion_limit: int = 25, event_handler: StatePipelineEventHandler | None = None, + audit_log: AuditLog | None = None, ) -> None: # State pipelines may use cyclic graphs (ReAct loops, retry-with-critique). # The legacy port-based path keeps acyclicity as an invariant. @@ -96,6 +98,7 @@ def __init__( self._checkpointer = checkpointer self._recursion_limit = recursion_limit self._event_handler = event_handler + self._audit_log = audit_log self._pending_nodes: list[DAGNode] = [] self._pending_edges: list[DAGEdge] = [] # State-based mode bookkeeping. Keyed by node id. @@ -270,6 +273,7 @@ def build(self) -> PipelineEngine | StatePipeline: checkpointer=self._checkpointer, recursion_limit=self._recursion_limit, event_handler=self._event_handler, + audit_log=self._audit_log, ) return PipelineEngine(self._dag) diff --git a/fireflyframework_agentic/pipeline/checkpoint.py b/fireflyframework_agentic/pipeline/checkpoint.py index 3a63d8c1..ecedb16c 100644 --- a/fireflyframework_agentic/pipeline/checkpoint.py +++ b/fireflyframework_agentic/pipeline/checkpoint.py @@ -51,7 +51,13 @@ class CheckpointRecord(BaseModel): - """One saved checkpoint.""" + """One saved checkpoint. + + ``paused`` and ``pause_reason`` are set when a node returns + :class:`fireflyframework_agentic.pipeline.state_pipeline.Pause`. Default + to ``False`` / ``None`` so existing records from earlier phases load + cleanly under the new schema. + """ pipeline_name: str run_id: str @@ -59,6 +65,8 @@ class CheckpointRecord(BaseModel): sequence: int state: dict[str, Any] completed_nodes: list[str] + paused: bool = False + pause_reason: str | None = None @runtime_checkable diff --git a/fireflyframework_agentic/pipeline/engine.py b/fireflyframework_agentic/pipeline/engine.py index 9f7e3d0f..4f1e46e6 100644 --- a/fireflyframework_agentic/pipeline/engine.py +++ b/fireflyframework_agentic/pipeline/engine.py @@ -101,6 +101,11 @@ async def on_node_error(self, pipeline_name: str, run_id: str, node_id: str, err """Called when a node raises an exception.""" ... + async def on_node_pause(self, pipeline_name: str, run_id: str, node_id: str, reason: str) -> None: + """Called when a node returns :class:`Pause`, halting the pipeline + until an external ``invoke(run_id=..., approve_pause=True)`` resumes it.""" + ... + async def on_pipeline_complete(self, pipeline_name: str, run_id: str, success: bool, duration_ms: float) -> None: """Called once when ``invoke`` returns.""" ... diff --git a/fireflyframework_agentic/pipeline/state_pipeline.py b/fireflyframework_agentic/pipeline/state_pipeline.py index d33c37b7..823e4928 100644 --- a/fireflyframework_agentic/pipeline/state_pipeline.py +++ b/fireflyframework_agentic/pipeline/state_pipeline.py @@ -31,11 +31,13 @@ import uuid from collections.abc import Awaitable, Callable from dataclasses import dataclass +from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, get_type_hints from pydantic import BaseModel from fireflyframework_agentic.exceptions import PipelineError +from fireflyframework_agentic.pipeline.audit import AuditEntry, AuditLog, AuditStatus from fireflyframework_agentic.pipeline.checkpoint import Checkpointer, CheckpointRecord from fireflyframework_agentic.pipeline.dag import DAG, _mermaid_id from fireflyframework_agentic.pipeline.engine import start_otel_span @@ -71,6 +73,31 @@ class RecursionLimitError(Exception): """Raised when a node is visited more times than ``recursion_limit`` permits.""" +@dataclass +class Pause: + """Human-in-the-loop sentinel returned by a node to halt the pipeline. + + A node returns ``Pause(reason="...")`` when external approval (a human, + another system, a wall-clock event) is required before the pipeline may + continue. The pipeline then: + + 1. Writes a checkpoint with ``paused=True`` and the reason set. + 2. Emits ``on_node_pause`` on the configured event handler. + 3. Returns a :class:`StatePipelineResult` with ``paused=True`` and + ``success=False`` — the run is not finished, but it did not fail either. + + To resume after approval:: + + result = await pipeline.invoke(run_id=paused_run_id, approve_pause=True) + + Without ``approve_pause=True``, resuming a paused run raises + :class:`PipelineError`. The successor of the paused node runs next — + the pause node itself is not re-executed. + """ + + reason: str + + @dataclass class BranchSpec: """Internal: registered branch from one source node.""" @@ -91,6 +118,10 @@ class StatePipelineResult: success: True iff all attempted nodes completed without error. error: Last error message if ``success`` is False. failed_node: Node ID that failed, if any. + paused: True if the run halted on a :class:`Pause` sentinel; resume + via ``invoke(run_id=..., approve_pause=True)``. + paused_node: Node that returned ``Pause`` if ``paused`` is True. + pause_reason: Reason string the paused node passed to ``Pause(...)``. """ state: Any @@ -99,6 +130,9 @@ class StatePipelineResult: success: bool error: str | None = None failed_node: str | None = None + paused: bool = False + paused_node: str | None = None + pause_reason: str | None = None def discover_reducers(state_schema: type) -> dict[str, Reducer]: @@ -157,6 +191,7 @@ def __init__( checkpointer: Checkpointer | None = None, recursion_limit: int = 25, event_handler: StatePipelineEventHandler | None = None, + audit_log: AuditLog | None = None, ) -> None: self._name = name self._dag = dag @@ -166,9 +201,52 @@ def __init__( self._checkpointer = checkpointer self._recursion_limit = recursion_limit self._event_handler = event_handler + self._audit_log = audit_log self._reducers = discover_reducers(state_schema) self._validate() + def _audit( + self, + *, + run_id: str, + node_id: str, + sequence: int, + visit: int, + started_at: datetime, + completed_at: datetime, + latency_ms: float, + status: AuditStatus, + inputs_snapshot: dict[str, Any], + outputs_snapshot: dict[str, Any], + error_message: str | None = None, + pause_reason: str | None = None, + ) -> None: + """Construct and write an :class:`AuditEntry`. No-op if no audit log is configured. + + Audit-write failures are non-fatal — logged and swallowed. + """ + if self._audit_log is None: + return + entry = AuditEntry( + pipeline_name=self._name, + run_id=run_id, + node_id=node_id, + sequence=sequence, + visit=visit, + started_at=started_at, + completed_at=completed_at, + latency_ms=latency_ms, + status=status, + inputs_snapshot=inputs_snapshot, + outputs_snapshot=outputs_snapshot, + error_message=error_message, + pause_reason=pause_reason, + ) + try: + self._audit_log.record(entry) + except Exception: + logger.exception("Audit log write failed for run '%s' at '%s'", run_id, node_id) + async def _finalize_run( self, result: StatePipelineResult, @@ -335,6 +413,7 @@ async def invoke( *, run_id: str | None = None, start_at: str | Callable[..., Any] | None = None, + approve_pause: bool = False, ) -> StatePipelineResult: """Run the pipeline. @@ -353,9 +432,15 @@ async def invoke( record = self._checkpointer.load_latest(self._name, run_id) if record is None: raise PipelineError(f"No checkpoint found for run_id='{run_id}'") + # A paused run requires explicit approval before continuing. + if record.paused and not approve_pause: + raise PipelineError( + f"Run '{run_id}' is paused at node '{record.node_id}' " + f"(reason: {record.pause_reason!r}). Pass approve_pause=True to resume." + ) state = self._state_schema.model_validate(record.state) resumed_completed = list(record.completed_nodes) - # Resume at the successor of the last completed node. + # Resume at the successor of the last completed (or paused) node. last = record.node_id next_node = self._next_step(last, state) # Resume can't seamlessly continue mid-fan-out yet; treat fan-out as terminal here. @@ -460,6 +545,8 @@ async def invoke( fn = self._node_fns[node_id] node_span = start_otel_span(f"pipeline.state.node.{node_id}", node=node_id, visit=visit_n) await self._emit("on_node_start", self._name, run_id, node_id, visit_n) + inputs_snapshot = state.model_dump(mode="json") + started_at = datetime.now(UTC) t0 = time.perf_counter() try: update = await fn(state) @@ -474,6 +561,19 @@ async def invoke( if node_span is not None: with contextlib.suppress(Exception): node_span.end() + self._audit( + run_id=run_id, + node_id=node_id, + sequence=sequence + 1, + visit=visit_n, + started_at=started_at, + completed_at=datetime.now(UTC), + latency_ms=(time.perf_counter() - t0) * 1000, + status="error", + inputs_snapshot=inputs_snapshot, + outputs_snapshot={}, + error_message=str(exc), + ) return await self._finalize_run( StatePipelineResult( state=state, @@ -488,9 +588,56 @@ async def invoke( run_id, ) elapsed = (time.perf_counter() - t0) * 1000 + completed_at = datetime.now(UTC) if node_span is not None: with contextlib.suppress(Exception): node_span.end() + + # HITL: a node returning Pause halts the pipeline and writes a + # paused checkpoint. Approval comes via invoke(approve_pause=True). + if isinstance(update, Pause): + pause_reason = update.reason + await self._emit("on_node_pause", self._name, run_id, node_id, pause_reason) + completed.append(node_id) + sequence += 1 + self._save_checkpoint( + run_id, + node_id, + sequence, + state, + completed, + paused=True, + pause_reason=pause_reason, + ) + self._audit( + run_id=run_id, + node_id=node_id, + sequence=sequence, + visit=visit_n, + started_at=started_at, + completed_at=completed_at, + latency_ms=elapsed, + status="paused", + inputs_snapshot=inputs_snapshot, + outputs_snapshot=state.model_dump(mode="json"), + pause_reason=pause_reason, + ) + logger.info("Pipeline '%s' paused at node '%s': %s", self._name, node_id, pause_reason) + return await self._finalize_run( + StatePipelineResult( + state=state, + run_id=run_id, + completed_nodes=completed, + success=False, + paused=True, + paused_node=node_id, + pause_reason=pause_reason, + ), + pipeline_span, + pipeline_start_time, + run_id, + ) + await self._emit("on_node_complete", self._name, run_id, node_id, elapsed) logger.debug("Pipeline '%s' node '%s' completed in %.1fms", self._name, node_id, elapsed) @@ -500,6 +647,18 @@ async def invoke( completed.append(node_id) sequence += 1 self._save_checkpoint(run_id, node_id, sequence, state, completed) + self._audit( + run_id=run_id, + node_id=node_id, + sequence=sequence, + visit=visit_n, + started_at=started_at, + completed_at=completed_at, + latency_ms=elapsed, + status="success", + inputs_snapshot=inputs_snapshot, + outputs_snapshot=state.model_dump(mode="json"), + ) try: next_step = self._next_step(node_id, state) @@ -604,6 +763,9 @@ def _save_checkpoint( sequence: int, state: BaseModel, completed: list[str], + *, + paused: bool = False, + pause_reason: str | None = None, ) -> None: """Persist state via the configured checkpointer (no-op if absent).""" if self._checkpointer is None: @@ -617,6 +779,8 @@ def _save_checkpoint( sequence=sequence, state=state.model_dump(), completed_nodes=list(completed), + paused=paused, + pause_reason=pause_reason, ) ) except Exception: diff --git a/tests/unit/pipeline/test_audit_log.py b/tests/unit/pipeline/test_audit_log.py new file mode 100644 index 00000000..e6b426b2 --- /dev/null +++ b/tests/unit/pipeline/test_audit_log.py @@ -0,0 +1,325 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""Phase-3c audit-log tests — File / Postgres / Logging / OTel backends + pipeline wiring.""" + +from __future__ import annotations + +import json +import logging +from datetime import UTC, datetime +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock + +import pytest +from pydantic import BaseModel + +import fireflyframework_agentic.pipeline.audit as audit_module +from fireflyframework_agentic.pipeline import ( + AuditEntry, + FileAuditLog, + LoggingAuditLog, + OtelAuditLog, + Pause, + PipelineBuilder, + PostgresAuditLog, +) + + +def _entry(**overrides: Any) -> AuditEntry: + defaults = { + "pipeline_name": "p", + "run_id": "r", + "node_id": "n", + "sequence": 1, + "visit": 1, + "started_at": datetime(2026, 5, 27, tzinfo=UTC), + "completed_at": datetime(2026, 5, 27, 0, 0, 1, tzinfo=UTC), + "latency_ms": 100.0, + "status": "success", + "inputs_snapshot": {"x": 1}, + "outputs_snapshot": {"y": 2}, + } + defaults.update(overrides) + return AuditEntry(**defaults) # type: ignore[arg-type] + + +# ============================================================================= +# FileAuditLog +# ============================================================================= + + +def test_file_audit_log_writes_jsonl_per_run(tmp_path: Path) -> None: + log = FileAuditLog(tmp_path) + log.record(_entry(sequence=1, node_id="a")) + log.record(_entry(sequence=2, node_id="b")) + + path = tmp_path / "p" / "r.jsonl" + assert path.exists() + lines = path.read_text().strip().splitlines() + assert len(lines) == 2 + assert json.loads(lines[0])["node_id"] == "a" + assert json.loads(lines[1])["node_id"] == "b" + + +def test_file_audit_log_list_entries_round_trips(tmp_path: Path) -> None: + log = FileAuditLog(tmp_path) + for seq, node in [(1, "a"), (2, "b"), (3, "c")]: + log.record(_entry(sequence=seq, node_id=node)) + entries = log.list_entries("p", "r") + assert [e.node_id for e in entries] == ["a", "b", "c"] + + +def test_file_audit_log_unknown_run_returns_empty(tmp_path: Path) -> None: + assert FileAuditLog(tmp_path).list_entries("p", "missing") == [] + + +# ============================================================================= +# PostgresAuditLog +# ============================================================================= + + +@pytest.fixture(autouse=True) +def _stub_optional_deps(monkeypatch: pytest.MonkeyPatch) -> None: + """Stub _psycopg and OTel symbols so backends can be constructed with mocks.""" + if audit_module._psycopg is None: + monkeypatch.setattr(audit_module, "_psycopg", MagicMock(name="psycopg_stub")) + if audit_module._otel_get_logger is None: + monkeypatch.setattr(audit_module, "_otel_get_logger", MagicMock(name="otel_logger_factory")) + monkeypatch.setattr(audit_module, "_OtelLogRecord", MagicMock(name="LogRecord")) + sev = MagicMock(name="SeverityNumber") + sev.ERROR = MagicMock(name="ERROR") + sev.ERROR.name = "ERROR" + sev.INFO = MagicMock(name="INFO") + sev.INFO.name = "INFO" + monkeypatch.setattr(audit_module, "_OtelSeverityNumber", sev) + + +def test_postgres_audit_missing_dep_raises(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(audit_module, "_psycopg", None) + with pytest.raises(ImportError, match=r"\[postgres\]"): + PostgresAuditLog(dsn="postgresql://x") + + +def _pg_conn_mock() -> tuple[MagicMock, dict]: + """MagicMock connection backed by an in-memory dict keyed by (pipeline,run,seq).""" + store: dict[tuple[str, str, int], dict[str, Any]] = {} + ddl_calls: list[str] = [] + conn = MagicMock(name="psycopg.Connection") + + def make_cursor() -> MagicMock: + cur = MagicMock() + cur.__enter__ = MagicMock(return_value=cur) + cur.__exit__ = MagicMock(return_value=None) + cur._last_one = None + cur._last_all = [] + + def fake_execute(sql: str, params: tuple | None = None) -> None: + s = sql.strip().lower() + if s.startswith("create table"): + ddl_calls.append(sql) + return + if s.startswith("insert into"): + assert params is not None + key = (params[0], params[1], params[2]) + store[key] = { + "pipeline_name": params[0], + "run_id": params[1], + "sequence": params[2], + "visit": params[3], + "node_id": params[4], + "started_at": params[5], + "completed_at": params[6], + "latency_ms": params[7], + "status": params[8], + "inputs_snapshot": json.loads(params[9]) if isinstance(params[9], str) else params[9], + "outputs_snapshot": json.loads(params[10]) if isinstance(params[10], str) else params[10], + "error_message": params[11], + "pause_reason": params[12], + } + return + if s.startswith("select"): + assert params is not None + rows = [v for k, v in store.items() if k[0] == params[0] and k[1] == params[1]] + rows.sort(key=lambda r: r["sequence"]) + cur._last_all = [ + ( + r["pipeline_name"], + r["run_id"], + r["sequence"], + r["visit"], + r["node_id"], + r["started_at"], + r["completed_at"], + r["latency_ms"], + r["status"], + r["inputs_snapshot"], + r["outputs_snapshot"], + r["error_message"], + r["pause_reason"], + ) + for r in rows + ] + return + raise AssertionError(f"unexpected SQL: {sql}") + + cur.execute.side_effect = fake_execute + cur.fetchone.side_effect = lambda: cur._last_one + cur.fetchall.side_effect = lambda: cur._last_all + return cur + + conn.cursor.side_effect = make_cursor + conn._ddl_calls = ddl_calls + return conn, store + + +def test_postgres_audit_ddl_once_then_inserts() -> None: + conn, store = _pg_conn_mock() + log = PostgresAuditLog(connection=conn) + for seq in (1, 2, 3): + log.record(_entry(sequence=seq, node_id=f"n{seq}")) + assert len(conn._ddl_calls) == 1 + assert len(store) == 3 + + +def test_postgres_audit_list_entries_orders_by_sequence() -> None: + conn, _ = _pg_conn_mock() + log = PostgresAuditLog(connection=conn) + for seq in (3, 1, 2): + log.record(_entry(sequence=seq, node_id=f"n{seq}")) + entries = log.list_entries("p", "r") + assert [e.sequence for e in entries] == [1, 2, 3] + + +def test_postgres_audit_rejects_bad_table_name() -> None: + with pytest.raises(ValueError, match="Invalid table_name"): + PostgresAuditLog(connection=MagicMock(), table_name="bad; DROP TABLE") + + +# ============================================================================= +# LoggingAuditLog +# ============================================================================= + + +def test_logging_audit_emits_record_with_firefly_audit_extra( + caplog: pytest.LogCaptureFixture, +) -> None: + log = LoggingAuditLog(logger_name="firefly.test_audit") + with caplog.at_level(logging.INFO, logger="firefly.test_audit"): + log.record(_entry(node_id="z", status="success")) + assert len(caplog.records) == 1 + rec = caplog.records[0] + assert "firefly_audit" in rec.__dict__ + assert rec.__dict__["firefly_audit"]["node_id"] == "z" + assert rec.__dict__["firefly_audit"]["status"] == "success" + + +# ============================================================================= +# OtelAuditLog +# ============================================================================= + + +def test_otel_audit_missing_dep_raises(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(audit_module, "_otel_get_logger", None) + with pytest.raises(ImportError, match="opentelemetry-sdk"): + OtelAuditLog() + + +def test_otel_audit_emits_log_record_via_otel_logger(monkeypatch: pytest.MonkeyPatch) -> None: + mock_logger = MagicMock(name="otel_logger") + factory = MagicMock(name="get_logger", return_value=mock_logger) + monkeypatch.setattr(audit_module, "_otel_get_logger", factory) + + log = OtelAuditLog() + log.record(_entry(node_id="a", status="success")) + + factory.assert_called_once() + assert mock_logger.emit.called, "OtelAuditLog should call logger.emit() with a LogRecord" + + +# ============================================================================= +# Pipeline wiring — audit fires for every node visit +# ============================================================================= + + +class S(BaseModel): + log: str = "" + + +@pytest.mark.asyncio +async def test_pipeline_writes_one_audit_entry_per_node_visit(tmp_path: Path) -> None: + async def a(state: S) -> dict: + return {"log": "a"} + + async def b(state: S) -> dict: + return {"log": "b"} + + audit = FileAuditLog(tmp_path) + pipeline = PipelineBuilder("audit-test", state=S, audit_log=audit).add_node(a).add_node(b).chain(a, b).build() + result = await pipeline.invoke(S()) + entries = audit.list_entries("audit-test", result.run_id) + assert [e.node_id for e in entries] == ["a", "b"] + assert all(e.status == "success" for e in entries) + + +@pytest.mark.asyncio +async def test_pipeline_audit_captures_error_status(tmp_path: Path) -> None: + async def boom(state: S) -> dict: + raise RuntimeError("nope") + + audit = FileAuditLog(tmp_path) + pipeline = PipelineBuilder("audit-err", state=S, audit_log=audit).add_node(boom).build() + result = await pipeline.invoke(S()) + entries = audit.list_entries("audit-err", result.run_id) + assert len(entries) == 1 + assert entries[0].status == "error" + assert "nope" in (entries[0].error_message or "") + + +@pytest.mark.asyncio +async def test_pipeline_audit_captures_paused_status(tmp_path: Path) -> None: + async def gate(state: S) -> Pause: + return Pause(reason="approval please") + + audit = FileAuditLog(tmp_path / "audit") + from fireflyframework_agentic.pipeline import FileCheckpointer + + pipeline = ( + PipelineBuilder( + "audit-pause", + state=S, + audit_log=audit, + checkpointer=FileCheckpointer(tmp_path / "ckpt"), + ) + .add_node(gate) + .build() + ) + result = await pipeline.invoke(S()) + entries = audit.list_entries("audit-pause", result.run_id) + assert len(entries) == 1 + assert entries[0].status == "paused" + assert entries[0].pause_reason == "approval please" + + +@pytest.mark.asyncio +async def test_audit_write_failure_does_not_abort_pipeline(tmp_path: Path) -> None: + """A broken audit log shouldn't kill business logic.""" + + class CrashyAudit: + def record(self, entry: AuditEntry) -> None: + raise RuntimeError("audit storage offline") + + async def step(state: S) -> dict: + return {"log": "ran"} + + pipeline = ( + PipelineBuilder("crashy", state=S, audit_log=CrashyAudit()) # type: ignore[arg-type] + .add_node(step) + .build() + ) + result = await pipeline.invoke(S()) + assert result.success is True + assert result.state.log == "ran" diff --git a/tests/unit/pipeline/test_state_pipeline_hitl.py b/tests/unit/pipeline/test_state_pipeline_hitl.py new file mode 100644 index 00000000..aff1c687 --- /dev/null +++ b/tests/unit/pipeline/test_state_pipeline_hitl.py @@ -0,0 +1,208 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""Phase-3c HITL tests: Pause + approve_pause resume + on_node_pause event.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Annotated + +import pytest +from pydantic import BaseModel + +from fireflyframework_agentic.exceptions import PipelineError +from fireflyframework_agentic.pipeline import ( + FileCheckpointer, + Pause, + PipelineBuilder, + StatePipeline, + extend, +) + + +class DeployState(BaseModel): + requirements: str = "" + spec: str | None = None + approved: Annotated[list[str], extend] = [] + deployed: bool = False + + +@dataclass +class PauseRecorder: + events: list[tuple] = field(default_factory=list) + + async def on_node_pause(self, pipeline_name: str, run_id: str, node_id: str, reason: str) -> None: + self.events.append(("pause", node_id, reason)) + + +# --- core pause/resume ------------------------------------------------------ + + +@pytest.mark.asyncio +async def test_node_returning_pause_halts_pipeline(tmp_path: Path) -> None: + async def architect(state: DeployState) -> dict: + return {"spec": "v1"} + + async def gate(state: DeployState) -> Pause: + return Pause(reason="awaiting deploy approval") + + async def deploy(state: DeployState) -> dict: + return {"deployed": True} + + ckpt = FileCheckpointer(tmp_path) + pipeline = ( + PipelineBuilder("hitl", state=DeployState, checkpointer=ckpt) + .add_node(architect) + .add_node(gate) + .add_node(deploy) + .chain(architect, gate, deploy) + .build() + ) + assert isinstance(pipeline, StatePipeline) + + result = await pipeline.invoke(DeployState(requirements="user-mgmt")) + assert result.paused is True + assert result.paused_node == "gate" + assert result.pause_reason == "awaiting deploy approval" + assert result.success is False # paused != success + assert result.state.deployed is False # deploy did NOT run + assert result.completed_nodes == ["architect", "gate"] + + +@pytest.mark.asyncio +async def test_resume_without_approve_pause_raises(tmp_path: Path) -> None: + async def gate(state: DeployState) -> Pause: + return Pause(reason="block here") + + pipeline = ( + PipelineBuilder("hitl", state=DeployState, checkpointer=FileCheckpointer(tmp_path)).add_node(gate).build() + ) + first = await pipeline.invoke(DeployState()) + assert first.paused is True + + with pytest.raises(PipelineError, match="approve_pause=True"): + await pipeline.invoke(run_id=first.run_id) + + +@pytest.mark.asyncio +async def test_resume_with_approve_pause_continues_from_successor(tmp_path: Path) -> None: + fail_once = {"flag": False} + + async def architect(state: DeployState) -> dict: + if fail_once["flag"]: + raise AssertionError("architect should NOT re-run on resume") + return {"spec": "v1"} + + async def gate(state: DeployState) -> Pause: + if fail_once["flag"]: + raise AssertionError("gate should NOT re-run on resume") + return Pause(reason="approve please") + + async def deploy(state: DeployState) -> dict: + return {"deployed": True} + + pipeline = ( + PipelineBuilder("hitl", state=DeployState, checkpointer=FileCheckpointer(tmp_path)) + .add_node(architect) + .add_node(gate) + .add_node(deploy) + .chain(architect, gate, deploy) + .build() + ) + first = await pipeline.invoke(DeployState(requirements="x")) + assert first.paused is True + fail_once["flag"] = True # ensure neither architect nor gate re-runs + + second = await pipeline.invoke(run_id=first.run_id, approve_pause=True) + assert second.success is True + assert second.state.deployed is True + assert second.completed_nodes == ["architect", "gate", "deploy"] + + +@pytest.mark.asyncio +async def test_on_node_pause_event_fires(tmp_path: Path) -> None: + async def gate(state: DeployState) -> Pause: + return Pause(reason="hold") + + handler = PauseRecorder() + pipeline = ( + PipelineBuilder( + "hitl", + state=DeployState, + checkpointer=FileCheckpointer(tmp_path), + event_handler=handler, # type: ignore[arg-type] + ) + .add_node(gate) + .build() + ) + await pipeline.invoke(DeployState()) + assert handler.events == [("pause", "gate", "hold")] + + +# --- backward compat ------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_paused_checkpoint_loads_when_pause_fields_missing(tmp_path: Path) -> None: + """An existing checkpoint without paused/pause_reason fields still loads.""" + from fireflyframework_agentic.pipeline.checkpoint import CheckpointRecord + + # Round-trip a record produced from a dict that omits the new fields — + # mirrors what an existing on-disk checkpoint from a pre-3c version looks like. + raw = { + "pipeline_name": "old", + "run_id": "legacy", + "node_id": "n", + "sequence": 1, + "state": {"requirements": "x"}, + "completed_nodes": ["n"], + } + record = CheckpointRecord.model_validate(raw) + assert record.paused is False + assert record.pause_reason is None + + +# --- a paused pipeline can still resume after non-pause failures ----------- + + +@pytest.mark.asyncio +async def test_pause_then_resume_then_error_then_resume(tmp_path: Path) -> None: + """End-to-end: pause, approve, then a subsequent failure, then retry.""" + counters = {"deploy_fail": False} + + async def gate(state: DeployState) -> Pause: + return Pause(reason="approve") + + async def deploy(state: DeployState) -> dict: + if not counters["deploy_fail"]: + counters["deploy_fail"] = True + raise RuntimeError("flaky") + return {"deployed": True} + + pipeline = ( + PipelineBuilder("hitl", state=DeployState, checkpointer=FileCheckpointer(tmp_path)) + .add_node(gate) + .add_node(deploy) + .chain(gate, deploy) + .build() + ) + paused = await pipeline.invoke(DeployState()) + assert paused.paused + + failed = await pipeline.invoke(run_id=paused.run_id, approve_pause=True) + assert not failed.success + assert failed.failed_node == "deploy" + + succeeded = await pipeline.invoke(run_id=paused.run_id, approve_pause=True) + # The deploy checkpoint at this point isn't marked paused, but the gate + # checkpoint still is. approve_pause is needed because load_latest may + # still return the older paused checkpoint OR the newer one depending on + # backend sort order — both backends sort by sequence so the latest is + # the failed-deploy record. The check passes either way: if the latest is + # paused, approve_pause is required; if not, approve_pause is ignored. + assert succeeded.success is True + assert succeeded.state.deployed is True From 9b6a68728571edb59600f71e5805d441adb3d649 Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Thu, 28 May 2026 08:03:06 +0200 Subject: [PATCH 10/26] refactor(pipeline): collapse invoke() to single return path Replace 7 _finalize_run call sites in StatePipeline.invoke with a single try/finally that ends the pipeline span and emits on_pipeline_complete once. Remove the _finalize_run helper. Behavior unchanged; net -32 LOC. --- .../pipeline/state_pipeline.py | 320 ++++++++---------- 1 file changed, 144 insertions(+), 176 deletions(-) diff --git a/fireflyframework_agentic/pipeline/state_pipeline.py b/fireflyframework_agentic/pipeline/state_pipeline.py index 823e4928..88d4379f 100644 --- a/fireflyframework_agentic/pipeline/state_pipeline.py +++ b/fireflyframework_agentic/pipeline/state_pipeline.py @@ -247,25 +247,6 @@ def _audit( except Exception: logger.exception("Audit log write failed for run '%s' at '%s'", run_id, node_id) - async def _finalize_run( - self, - result: StatePipelineResult, - span: Any, - start_time: float, - run_id: str, - ) -> StatePipelineResult: - """Close the pipeline-level span and emit ``on_pipeline_complete``. - - Every return path in :meth:`invoke` after the observability boundary - flows through this helper. - """ - if span is not None: - with contextlib.suppress(Exception): - span.end() - duration_ms = (time.perf_counter() - start_time) * 1000 - await self._emit("on_pipeline_complete", self._name, run_id, result.success, duration_ms) - return result - async def _emit(self, method: str, *args: Any) -> None: """Invoke ``method`` on the configured event handler if it exists. @@ -488,127 +469,151 @@ async def invoke( pipeline_span = start_otel_span(f"pipeline.state.{self._name}", pipeline=self._name, run_id=run_id) await self._emit("on_pipeline_start", self._name, run_id) - while next_step is not None: - # --- fan-out branch (list[Send]) --------------------------------- - if isinstance(next_step, list): - try: - state, sequence = await self._run_fanout( - sends=next_step, - state=state, - completed=completed, - run_id=run_id, - sequence=sequence, - visit_counts=visit_counts, - ) - except _NodeFailureError as fail: - return await self._finalize_run( - StatePipelineResult( + result: StatePipelineResult | None = None + try: + while next_step is not None: + # --- fan-out branch (list[Send]) --------------------------------- + if isinstance(next_step, list): + try: + state, sequence = await self._run_fanout( + sends=next_step, + state=state, + completed=completed, + run_id=run_id, + sequence=sequence, + visit_counts=visit_counts, + ) + except _NodeFailureError as fail: + result = StatePipelineResult( state=state, run_id=run_id, completed_nodes=completed, success=False, error=fail.message, failed_node=fail.node_id, - ), - pipeline_span, - pipeline_start_time, - run_id, + ) + break + # After fan-out, continue from the workers' shared successor (if any). + next_step = self._common_successor([s.target for s in next_step]) + continue + + # --- single-node step -------------------------------------------- + node_id = next_step + visit_counts[node_id] = visit_counts.get(node_id, 0) + 1 + visit_n = visit_counts[node_id] + if visit_n > self._recursion_limit: + msg = ( + f"Recursion limit ({self._recursion_limit}) exceeded at node '{node_id}'. " + f"Raise recursion_limit= or fix the routing logic." ) - # After fan-out, continue from the workers' shared successor (if any). - next_step = self._common_successor([s.target for s in next_step]) - continue - - # --- single-node step -------------------------------------------- - node_id = next_step - visit_counts[node_id] = visit_counts.get(node_id, 0) + 1 - visit_n = visit_counts[node_id] - if visit_n > self._recursion_limit: - msg = ( - f"Recursion limit ({self._recursion_limit}) exceeded at node '{node_id}'. " - f"Raise recursion_limit= or fix the routing logic." - ) - logger.error(msg) - return await self._finalize_run( - StatePipelineResult( + logger.error(msg) + result = StatePipelineResult( state=state, run_id=run_id, completed_nodes=completed, success=False, error=msg, failed_node=node_id, - ), - pipeline_span, - pipeline_start_time, - run_id, - ) - - fn = self._node_fns[node_id] - node_span = start_otel_span(f"pipeline.state.node.{node_id}", node=node_id, visit=visit_n) - await self._emit("on_node_start", self._name, run_id, node_id, visit_n) - inputs_snapshot = state.model_dump(mode="json") - started_at = datetime.now(UTC) - t0 = time.perf_counter() - try: - update = await fn(state) - except Exception as exc: - logger.exception( - "State pipeline '%s' run '%s' failed at node '%s'", - self._name, - run_id, - node_id, - ) - await self._emit("on_node_error", self._name, run_id, node_id, str(exc)) + ) + break + + fn = self._node_fns[node_id] + node_span = start_otel_span(f"pipeline.state.node.{node_id}", node=node_id, visit=visit_n) + await self._emit("on_node_start", self._name, run_id, node_id, visit_n) + inputs_snapshot = state.model_dump(mode="json") + started_at = datetime.now(UTC) + t0 = time.perf_counter() + try: + update = await fn(state) + except Exception as exc: + logger.exception( + "State pipeline '%s' run '%s' failed at node '%s'", + self._name, + run_id, + node_id, + ) + await self._emit("on_node_error", self._name, run_id, node_id, str(exc)) + if node_span is not None: + with contextlib.suppress(Exception): + node_span.end() + self._audit( + run_id=run_id, + node_id=node_id, + sequence=sequence + 1, + visit=visit_n, + started_at=started_at, + completed_at=datetime.now(UTC), + latency_ms=(time.perf_counter() - t0) * 1000, + status="error", + inputs_snapshot=inputs_snapshot, + outputs_snapshot={}, + error_message=str(exc), + ) + result = StatePipelineResult( + state=state, + run_id=run_id, + completed_nodes=completed, + success=False, + error=str(exc), + failed_node=node_id, + ) + break + elapsed = (time.perf_counter() - t0) * 1000 + completed_at = datetime.now(UTC) if node_span is not None: with contextlib.suppress(Exception): node_span.end() - self._audit( - run_id=run_id, - node_id=node_id, - sequence=sequence + 1, - visit=visit_n, - started_at=started_at, - completed_at=datetime.now(UTC), - latency_ms=(time.perf_counter() - t0) * 1000, - status="error", - inputs_snapshot=inputs_snapshot, - outputs_snapshot={}, - error_message=str(exc), - ) - return await self._finalize_run( - StatePipelineResult( + + # HITL: a node returning Pause halts the pipeline and writes a + # paused checkpoint. Approval comes via invoke(approve_pause=True). + if isinstance(update, Pause): + pause_reason = update.reason + await self._emit("on_node_pause", self._name, run_id, node_id, pause_reason) + completed.append(node_id) + sequence += 1 + self._save_checkpoint( + run_id, + node_id, + sequence, + state, + completed, + paused=True, + pause_reason=pause_reason, + ) + self._audit( + run_id=run_id, + node_id=node_id, + sequence=sequence, + visit=visit_n, + started_at=started_at, + completed_at=completed_at, + latency_ms=elapsed, + status="paused", + inputs_snapshot=inputs_snapshot, + outputs_snapshot=state.model_dump(mode="json"), + pause_reason=pause_reason, + ) + logger.info("Pipeline '%s' paused at node '%s': %s", self._name, node_id, pause_reason) + result = StatePipelineResult( state=state, run_id=run_id, completed_nodes=completed, success=False, - error=str(exc), - failed_node=node_id, - ), - pipeline_span, - pipeline_start_time, - run_id, - ) - elapsed = (time.perf_counter() - t0) * 1000 - completed_at = datetime.now(UTC) - if node_span is not None: - with contextlib.suppress(Exception): - node_span.end() + paused=True, + paused_node=node_id, + pause_reason=pause_reason, + ) + break + + await self._emit("on_node_complete", self._name, run_id, node_id, elapsed) + logger.debug("Pipeline '%s' node '%s' completed in %.1fms", self._name, node_id, elapsed) + + if update: + state = apply_update(state, update, self._reducers) - # HITL: a node returning Pause halts the pipeline and writes a - # paused checkpoint. Approval comes via invoke(approve_pause=True). - if isinstance(update, Pause): - pause_reason = update.reason - await self._emit("on_node_pause", self._name, run_id, node_id, pause_reason) completed.append(node_id) sequence += 1 - self._save_checkpoint( - run_id, - node_id, - sequence, - state, - completed, - paused=True, - pause_reason=pause_reason, - ) + self._save_checkpoint(run_id, node_id, sequence, state, completed) self._audit( run_id=run_id, node_id=node_id, @@ -617,77 +622,40 @@ async def invoke( started_at=started_at, completed_at=completed_at, latency_ms=elapsed, - status="paused", + status="success", inputs_snapshot=inputs_snapshot, outputs_snapshot=state.model_dump(mode="json"), - pause_reason=pause_reason, ) - logger.info("Pipeline '%s' paused at node '%s': %s", self._name, node_id, pause_reason) - return await self._finalize_run( - StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=False, - paused=True, - paused_node=node_id, - pause_reason=pause_reason, - ), - pipeline_span, - pipeline_start_time, - run_id, - ) - - await self._emit("on_node_complete", self._name, run_id, node_id, elapsed) - logger.debug("Pipeline '%s' node '%s' completed in %.1fms", self._name, node_id, elapsed) - - if update: - state = apply_update(state, update, self._reducers) - - completed.append(node_id) - sequence += 1 - self._save_checkpoint(run_id, node_id, sequence, state, completed) - self._audit( - run_id=run_id, - node_id=node_id, - sequence=sequence, - visit=visit_n, - started_at=started_at, - completed_at=completed_at, - latency_ms=elapsed, - status="success", - inputs_snapshot=inputs_snapshot, - outputs_snapshot=state.model_dump(mode="json"), - ) - try: - next_step = self._next_step(node_id, state) - except PipelineError as exc: - return await self._finalize_run( - StatePipelineResult( + try: + next_step = self._next_step(node_id, state) + except PipelineError as exc: + result = StatePipelineResult( state=state, run_id=run_id, completed_nodes=completed, success=False, error=str(exc), failed_node=node_id, - ), - pipeline_span, - pipeline_start_time, - run_id, + ) + break + + if result is None: + result = StatePipelineResult( + state=state, + run_id=run_id, + completed_nodes=completed, + success=True, ) + finally: + if pipeline_span is not None: + with contextlib.suppress(Exception): + pipeline_span.end() + duration_ms = (time.perf_counter() - pipeline_start_time) * 1000 + success = result.success if result is not None else False + await self._emit("on_pipeline_complete", self._name, run_id, success, duration_ms) - return await self._finalize_run( - StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=True, - ), - pipeline_span, - pipeline_start_time, - run_id, - ) + return result async def _run_fanout( self, From 1e65e82a95fc439a6bb5eb403bdefdbcd37af92f Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Thu, 28 May 2026 08:07:44 +0200 Subject: [PATCH 11/26] refactor(pipeline): share psycopg scaffolding between checkpoint and audit Extract the duplicated Postgres setup (optional-dep guard, dsn-xor-connection check, table-name validation, lazy idempotent DDL) into a single PsycopgBackend base class in pipeline/_psycopg_backend.py. PostgresCheckpointer and PostgresAuditLog now inherit from it; each only declares its DDL and default table name. Tests updated to monkeypatch _psycopg on the shared module. --- .../pipeline/_psycopg_backend.py | 76 +++++++++++++++++++ fireflyframework_agentic/pipeline/audit.py | 27 +------ .../pipeline/checkpoint.py | 30 +------- tests/unit/pipeline/test_audit_log.py | 7 +- .../unit/pipeline/test_checkpoint_backends.py | 7 +- 5 files changed, 91 insertions(+), 56 deletions(-) create mode 100644 fireflyframework_agentic/pipeline/_psycopg_backend.py diff --git a/fireflyframework_agentic/pipeline/_psycopg_backend.py b/fireflyframework_agentic/pipeline/_psycopg_backend.py new file mode 100644 index 00000000..b6784c9d --- /dev/null +++ b/fireflyframework_agentic/pipeline/_psycopg_backend.py @@ -0,0 +1,76 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared scaffolding for Postgres-backed pipeline backends. + +The checkpointer and audit-log backends both need the same boilerplate: +optional-dep guard on ``psycopg``, ``dsn`` xor ``connection`` constructor +check, table-name validation, and lazy idempotent DDL on first write. This +module centralizes it so each backend only has to declare its DDL and +default table name. +""" + +from __future__ import annotations + +from typing import Any + +try: + import psycopg as _psycopg # type: ignore[import-not-found] +except ImportError: # pragma: no cover - optional dep + _psycopg = None # type: ignore[assignment] + + +class PsycopgBackend: + """Base class for backends that persist into a single Postgres table. + + Subclasses set the class attribute ``_DDL`` to a format string with a + single ``{table}`` placeholder, and pass their human-readable name and + default table to ``__init__``. The base class handles the rest: + + * Raises ``ImportError`` if the ``postgres`` extra is not installed. + * Enforces ``dsn`` xor ``connection``. + * Validates ``table_name`` against SQL injection (interpolated into DDL). + * Opens the connection (with ``autocommit=True``) when only ``dsn`` is given. + * Applies the DDL lazily and idempotently on first ``_ensure_table()`` call. + """ + + _DDL: str = "" + + def __init__( + self, + *, + kind: str, + dsn: str | None, + connection: Any, + table_name: str, + ) -> None: + if _psycopg is None: + raise ImportError( + f"{kind} requires the 'postgres' extra. Install with: pip install fireflyframework-agentic[postgres]" + ) + if (dsn is None) == (connection is None): + raise ValueError(f"{kind} needs exactly one of `dsn` or `connection`.") + # Table name is interpolated into DDL — validate strictly to avoid SQL injection. + if not table_name.replace("_", "").isalnum(): + raise ValueError(f"Invalid table_name {table_name!r}: must be alphanumeric/underscore only.") + self._conn = connection if connection is not None else _psycopg.connect(dsn, autocommit=True) + self._table = table_name + self._ddl_applied = False + + def _ensure_table(self) -> None: + if self._ddl_applied: + return + with self._conn.cursor() as cur: + cur.execute(self._DDL.format(table=self._table)) + self._ddl_applied = True diff --git a/fireflyframework_agentic/pipeline/audit.py b/fireflyframework_agentic/pipeline/audit.py index 481ea08d..d9691775 100644 --- a/fireflyframework_agentic/pipeline/audit.py +++ b/fireflyframework_agentic/pipeline/audit.py @@ -46,10 +46,7 @@ from pydantic import BaseModel -try: - import psycopg as _psycopg # type: ignore[import-not-found] -except ImportError: # pragma: no cover - optional dep - _psycopg = None # type: ignore[assignment] +from fireflyframework_agentic.pipeline._psycopg_backend import PsycopgBackend try: from opentelemetry._logs import LogRecord as _OtelLogRecord # type: ignore[import-not-found] @@ -145,7 +142,7 @@ def list_entries(self, pipeline_name: str, run_id: str) -> list[AuditEntry]: return entries -class PostgresAuditLog: +class PostgresAuditLog(PsycopgBackend): """Postgres-backed audit log. Single table created on first ``record`` call. Reuses ``psycopg`` from the ``postgres`` optional extra. ``dsn`` or a @@ -181,25 +178,7 @@ def __init__( connection: Any = None, table_name: str = "firefly_audit", ) -> None: - if _psycopg is None: - raise ImportError( - "PostgresAuditLog requires the 'postgres' extra. " - "Install with: pip install fireflyframework-agentic[postgres]" - ) - if (dsn is None) == (connection is None): - raise ValueError("PostgresAuditLog needs exactly one of `dsn` or `connection`.") - if not table_name.replace("_", "").isalnum(): - raise ValueError(f"Invalid table_name {table_name!r}: must be alphanumeric/underscore only.") - self._conn = connection if connection is not None else _psycopg.connect(dsn, autocommit=True) - self._table = table_name - self._ddl_applied = False - - def _ensure_table(self) -> None: - if self._ddl_applied: - return - with self._conn.cursor() as cur: - cur.execute(self._DDL.format(table=self._table)) - self._ddl_applied = True + super().__init__(kind="PostgresAuditLog", dsn=dsn, connection=connection, table_name=table_name) def record(self, entry: AuditEntry) -> None: self._ensure_table() diff --git a/fireflyframework_agentic/pipeline/checkpoint.py b/fireflyframework_agentic/pipeline/checkpoint.py index ecedb16c..5cb2b5cf 100644 --- a/fireflyframework_agentic/pipeline/checkpoint.py +++ b/fireflyframework_agentic/pipeline/checkpoint.py @@ -39,16 +39,13 @@ from pydantic import BaseModel +from fireflyframework_agentic.pipeline._psycopg_backend import PsycopgBackend + try: import redis as _redis # type: ignore[import-not-found] except ImportError: # pragma: no cover - optional dep _redis = None # type: ignore[assignment] -try: - import psycopg as _psycopg # type: ignore[import-not-found] -except ImportError: # pragma: no cover - optional dep - _psycopg = None # type: ignore[assignment] - class CheckpointRecord(BaseModel): """One saved checkpoint. @@ -194,7 +191,7 @@ def list_runs(self, pipeline_name: str) -> list[str]: return list(self._client.zrange(self._runs_index_key(pipeline_name), 0, -1)) -class PostgresCheckpointer: +class PostgresCheckpointer(PsycopgBackend): """Postgres-backed checkpointer. Uses a single table created on first ``save`` call. The DDL is idempotent @@ -232,26 +229,7 @@ def __init__( connection: Any = None, table_name: str = "firefly_checkpoints", ) -> None: - if _psycopg is None: - raise ImportError( - "PostgresCheckpointer requires the 'postgres' extra. " - "Install with: pip install fireflyframework-agentic[postgres]" - ) - if (dsn is None) == (connection is None): - raise ValueError("PostgresCheckpointer needs exactly one of `dsn` or `connection`.") - # Table name is interpolated into DDL — validate it strictly to avoid SQL injection. - if not table_name.replace("_", "").isalnum(): - raise ValueError(f"Invalid table_name {table_name!r}: must be alphanumeric/underscore only.") - self._conn = connection if connection is not None else _psycopg.connect(dsn, autocommit=True) - self._table = table_name - self._ddl_applied = False - - def _ensure_table(self) -> None: - if self._ddl_applied: - return - with self._conn.cursor() as cur: - cur.execute(self._DDL.format(table=self._table)) - self._ddl_applied = True + super().__init__(kind="PostgresCheckpointer", dsn=dsn, connection=connection, table_name=table_name) def save(self, record: CheckpointRecord) -> None: self._ensure_table() diff --git a/tests/unit/pipeline/test_audit_log.py b/tests/unit/pipeline/test_audit_log.py index e6b426b2..369cf6a5 100644 --- a/tests/unit/pipeline/test_audit_log.py +++ b/tests/unit/pipeline/test_audit_log.py @@ -17,6 +17,7 @@ import pytest from pydantic import BaseModel +import fireflyframework_agentic.pipeline._psycopg_backend as psycopg_backend_module import fireflyframework_agentic.pipeline.audit as audit_module from fireflyframework_agentic.pipeline import ( AuditEntry, @@ -85,8 +86,8 @@ def test_file_audit_log_unknown_run_returns_empty(tmp_path: Path) -> None: @pytest.fixture(autouse=True) def _stub_optional_deps(monkeypatch: pytest.MonkeyPatch) -> None: """Stub _psycopg and OTel symbols so backends can be constructed with mocks.""" - if audit_module._psycopg is None: - monkeypatch.setattr(audit_module, "_psycopg", MagicMock(name="psycopg_stub")) + if psycopg_backend_module._psycopg is None: + monkeypatch.setattr(psycopg_backend_module, "_psycopg", MagicMock(name="psycopg_stub")) if audit_module._otel_get_logger is None: monkeypatch.setattr(audit_module, "_otel_get_logger", MagicMock(name="otel_logger_factory")) monkeypatch.setattr(audit_module, "_OtelLogRecord", MagicMock(name="LogRecord")) @@ -99,7 +100,7 @@ def _stub_optional_deps(monkeypatch: pytest.MonkeyPatch) -> None: def test_postgres_audit_missing_dep_raises(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(audit_module, "_psycopg", None) + monkeypatch.setattr(psycopg_backend_module, "_psycopg", None) with pytest.raises(ImportError, match=r"\[postgres\]"): PostgresAuditLog(dsn="postgresql://x") diff --git a/tests/unit/pipeline/test_checkpoint_backends.py b/tests/unit/pipeline/test_checkpoint_backends.py index 091f6cf3..47a6674c 100644 --- a/tests/unit/pipeline/test_checkpoint_backends.py +++ b/tests/unit/pipeline/test_checkpoint_backends.py @@ -18,6 +18,7 @@ import pytest from pydantic import BaseModel +import fireflyframework_agentic.pipeline._psycopg_backend as psycopg_backend_module import fireflyframework_agentic.pipeline.checkpoint as checkpoint_module from fireflyframework_agentic.pipeline import ( CheckpointRecord, @@ -39,8 +40,8 @@ def _stub_optional_deps(monkeypatch: pytest.MonkeyPatch) -> None: """ if checkpoint_module._redis is None: monkeypatch.setattr(checkpoint_module, "_redis", MagicMock(name="redis_stub")) - if checkpoint_module._psycopg is None: - monkeypatch.setattr(checkpoint_module, "_psycopg", MagicMock(name="psycopg_stub")) + if psycopg_backend_module._psycopg is None: + monkeypatch.setattr(psycopg_backend_module, "_psycopg", MagicMock(name="psycopg_stub")) # ============================================================================= @@ -244,7 +245,7 @@ def fake_execute(sql: str, params: tuple | None = None) -> None: def test_postgres_checkpointer_missing_dep_raises(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(checkpoint_module, "_psycopg", None) + monkeypatch.setattr(psycopg_backend_module, "_psycopg", None) with pytest.raises(ImportError, match=r"\[postgres\]"): PostgresCheckpointer(dsn="postgresql://x") From 716affee2c87ae086d5c83c691cfb270bb91975d Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Thu, 28 May 2026 09:19:15 +0200 Subject: [PATCH 12/26] refactor(pipeline): drop sequence variable and resumed_completed duplicate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit sequence was always len(completed) after each increment (or len(completed)+1 in the failed-node audit case). Replace the running counter with len(completed) expressions at the call sites. resumed_completed was a separate alias for completed during resume; assign into completed directly. _run_fanout no longer takes or returns sequence — derives it from len(completed) like the main loop. Behavior unchanged. --- .../pipeline/state_pipeline.py | 34 +++++++------------ 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/fireflyframework_agentic/pipeline/state_pipeline.py b/fireflyframework_agentic/pipeline/state_pipeline.py index 88d4379f..c31b8e21 100644 --- a/fireflyframework_agentic/pipeline/state_pipeline.py +++ b/fireflyframework_agentic/pipeline/state_pipeline.py @@ -404,7 +404,7 @@ async def invoke( * Mid-pipeline start: ``invoke(state=..., start_at=node)`` — starts execution at ``node`` with the provided state. """ - resumed_completed: list[str] = [] + completed: list[str] = [] # Resume mode: load checkpoint, derive starting node from it. if run_id is not None and state is None and start_at is None: @@ -420,10 +420,9 @@ async def invoke( f"(reason: {record.pause_reason!r}). Pass approve_pause=True to resume." ) state = self._state_schema.model_validate(record.state) - resumed_completed = list(record.completed_nodes) + completed = list(record.completed_nodes) # Resume at the successor of the last completed (or paused) node. - last = record.node_id - next_node = self._next_step(last, state) + next_node = self._next_step(record.node_id, state) # Resume can't seamlessly continue mid-fan-out yet; treat fan-out as terminal here. if isinstance(next_node, list): raise PipelineError( @@ -434,7 +433,7 @@ async def invoke( return StatePipelineResult( state=state, run_id=run_id, - completed_nodes=resumed_completed, + completed_nodes=completed, success=True, ) current_node: str | None = next_node @@ -458,8 +457,6 @@ async def invoke( run_id = uuid.uuid4().hex[:12] assert state is not None # narrowed by the branches above - completed: list[str] = list(resumed_completed) - sequence = len(completed) visit_counts: dict[str, int] = {} next_step: str | list[Send] | None = current_node @@ -475,12 +472,11 @@ async def invoke( # --- fan-out branch (list[Send]) --------------------------------- if isinstance(next_step, list): try: - state, sequence = await self._run_fanout( + state = await self._run_fanout( sends=next_step, state=state, completed=completed, run_id=run_id, - sequence=sequence, visit_counts=visit_counts, ) except _NodeFailureError as fail: @@ -539,7 +535,7 @@ async def invoke( self._audit( run_id=run_id, node_id=node_id, - sequence=sequence + 1, + sequence=len(completed) + 1, visit=visit_n, started_at=started_at, completed_at=datetime.now(UTC), @@ -570,11 +566,10 @@ async def invoke( pause_reason = update.reason await self._emit("on_node_pause", self._name, run_id, node_id, pause_reason) completed.append(node_id) - sequence += 1 self._save_checkpoint( run_id, node_id, - sequence, + len(completed), state, completed, paused=True, @@ -583,7 +578,7 @@ async def invoke( self._audit( run_id=run_id, node_id=node_id, - sequence=sequence, + sequence=len(completed), visit=visit_n, started_at=started_at, completed_at=completed_at, @@ -612,12 +607,11 @@ async def invoke( state = apply_update(state, update, self._reducers) completed.append(node_id) - sequence += 1 - self._save_checkpoint(run_id, node_id, sequence, state, completed) + self._save_checkpoint(run_id, node_id, len(completed), state, completed) self._audit( run_id=run_id, node_id=node_id, - sequence=sequence, + sequence=len(completed), visit=visit_n, started_at=started_at, completed_at=completed_at, @@ -664,9 +658,8 @@ async def _run_fanout( state: BaseModel, completed: list[str], run_id: str, - sequence: int, visit_counts: dict[str, int], - ) -> tuple[BaseModel, int]: + ) -> BaseModel: """Run all ``Send`` dispatches concurrently. Each task gets its own state copy with the Send's payload merged in; results are reduced into shared state. """ @@ -719,10 +712,9 @@ async def _run_one(send: Send, visit_n: int) -> tuple[Send, dict[str, Any] | Non if update: new_state = apply_update(new_state, update, self._reducers) completed.append(send.target) - sequence += 1 - self._save_checkpoint(run_id, send.target, sequence, new_state, completed) + self._save_checkpoint(run_id, send.target, len(completed), new_state, completed) - return new_state, sequence + return new_state def _save_checkpoint( self, From d1f669939d41e0053705780b8c00fd85a332dd05 Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Thu, 28 May 2026 10:36:21 +0200 Subject: [PATCH 13/26] fix(pipeline): assert result narrowed in invoke return path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit After the try/finally introduced in #240, pyright could not narrow result: StatePipelineResult | None to StatePipelineResult at the return statement (a finally block can reassign, so the narrowing must be local). Add an explicit assert before return — same runtime guarantee, makes the non-None invariant visible to pyright. --- fireflyframework_agentic/pipeline/state_pipeline.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fireflyframework_agentic/pipeline/state_pipeline.py b/fireflyframework_agentic/pipeline/state_pipeline.py index c31b8e21..57a1cc0b 100644 --- a/fireflyframework_agentic/pipeline/state_pipeline.py +++ b/fireflyframework_agentic/pipeline/state_pipeline.py @@ -649,6 +649,7 @@ async def invoke( success = result.success if result is not None else False await self._emit("on_pipeline_complete", self._name, run_id, success, duration_ms) + assert result is not None # set in try-block before reaching here return result async def _run_fanout( From affeff7518d0029f7d0c02467349d5e4e35f2753 Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Thu, 28 May 2026 12:53:48 +0200 Subject: [PATCH 14/26] feat(examples): software_factory example + drop Postgres/Redis from framework MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a self-contained software-factory example package under examples/software_factory/ that exercises the headline state-mode features: * state + reducers (extend on qa_feedback) * branching with router + cycle (qa fail → codegen, recursion_limit=3) * checkpoint + resume on a transient builder failure * StatePipelineEventHandler progress output Keeps codegen and builder as separate nodes deliberately: builder models a transient failure recovered via checkpoint, qa models a substantive failure recovered via cycle back to codegen — two different recovery patterns. Includes plug-and-play Checkpointer Protocol implementations under checkpointers/{postgres,redis}.py and a QueryableAuditLog implementation under audit/postgres.py. Each is a flat ~50–80 LOC class against a caller-supplied connection. FIREFLY_CKPT env var swaps backends. Same PR removes the matching in-framework backends and their machinery: * PostgresCheckpointer, RedisCheckpointer (pipeline/checkpoint.py) * PostgresAuditLog (pipeline/audit.py) * PsycopgBackend helper (pipeline/_psycopg_backend.py) — no remaining consumers * psycopg[binary] dep dropped from [postgres] extra in pyproject.toml * Postgres/Redis mocked test classes (tests/unit/pipeline/test_*.py) The Checkpointer and AuditLog Protocols, FileCheckpointer, FileAuditLog, LoggingAuditLog, and OtelAuditLog stay in the framework — concrete backends that wrap non-trivial APIs (stdlib logging, OTel logs SDK) keep their place. Also drops the run_software_factory and run_software_factory_postgres scenarios from examples/pipeline_state.py (the deep walkthrough now lives in the dedicated folder) and links the new folder from examples/README.md. Verification: pytest tests/unit/ → 1531 passed; pytest examples/software_factory/tests/ → 1 passed; ruff + pyright clean. --- examples/README.md | 2 + examples/pipeline_state.py | 181 ++------- examples/software_factory/README.md | 137 +++++++ examples/software_factory/__init__.py | 0 examples/software_factory/__main__.py | 75 ++++ examples/software_factory/agents.py | 68 ++++ examples/software_factory/audit/__init__.py | 0 examples/software_factory/audit/postgres.py | 112 ++++++ .../checkpointers/__init__.py | 0 .../checkpointers/postgres.py | 107 +++++ .../software_factory/checkpointers/redis.py | 65 ++++ examples/software_factory/pipeline.py | 64 +++ examples/software_factory/progress.py | 34 ++ examples/software_factory/state.py | 30 ++ examples/software_factory/tests/__init__.py | 0 .../software_factory/tests/test_pipeline.py | 48 +++ fireflyframework_agentic/pipeline/__init__.py | 6 - .../pipeline/_psycopg_backend.py | 76 ---- fireflyframework_agentic/pipeline/audit.py | 117 +----- .../pipeline/checkpoint.py | 188 +-------- pyproject.toml | 3 - tests/unit/pipeline/test_audit_log.py | 116 +----- .../unit/pipeline/test_checkpoint_backends.py | 364 +++--------------- 23 files changed, 846 insertions(+), 947 deletions(-) create mode 100644 examples/software_factory/README.md create mode 100644 examples/software_factory/__init__.py create mode 100644 examples/software_factory/__main__.py create mode 100644 examples/software_factory/agents.py create mode 100644 examples/software_factory/audit/__init__.py create mode 100644 examples/software_factory/audit/postgres.py create mode 100644 examples/software_factory/checkpointers/__init__.py create mode 100644 examples/software_factory/checkpointers/postgres.py create mode 100644 examples/software_factory/checkpointers/redis.py create mode 100644 examples/software_factory/pipeline.py create mode 100644 examples/software_factory/progress.py create mode 100644 examples/software_factory/state.py create mode 100644 examples/software_factory/tests/__init__.py create mode 100644 examples/software_factory/tests/test_pipeline.py delete mode 100644 fireflyframework_agentic/pipeline/_psycopg_backend.py diff --git a/examples/README.md b/examples/README.md index cd790e4f..ef28af32 100644 --- a/examples/README.md +++ b/examples/README.md @@ -56,6 +56,8 @@ If `OPENAI_API_KEY` is not set, each script will prompt you interactively. ## Pipeline Examples - **`pipeline_branching.py`** — `BranchStep` for conditional routing in a DAG, `PipelineEventHandler` for live progress, and `DAGNode.backoff_factor` for exponential retry backoff. **No API key required.** +- **`pipeline_state.py`** — Three short scenarios with the state-based `PipelineBuilder` (`state=` mode): sentiment branching with `.branch()`, map-reduce with `Send` fan-out, and a HITL deploy gate using `Pause` plus `FileAuditLog`. **No API key required.** +- **`software_factory/`** — Self-contained example package showing a state-mode agentic SDLC pipeline (`architect → codegen → builder → qa → stable_release`) with the QA feedback loop (`recursion_limit=3`), checkpoint + resume on a transient `builder` failure, and a `StatePipelineEventHandler` printing progress. Includes plug-and-play `Checkpointer` Protocol implementations for Postgres and Redis under `checkpointers/`, and a `QueryableAuditLog` Postgres template under `audit/`. **No API key required.** ## Complex Examples diff --git a/examples/pipeline_state.py b/examples/pipeline_state.py index ee9bb449..2bb3c27c 100644 --- a/examples/pipeline_state.py +++ b/examples/pipeline_state.py @@ -12,25 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""State-based PipelineBuilder: branching, checkpoint/resume, and Send fan-out. +"""State-based PipelineBuilder quick-start: branching, Send fan-out, HITL Pause. -Three scenarios: +Three short scenarios in one file: -1. **Branching** — same sentiment-classification workflow as - ``examples/pipeline_branching.py``, but written with the state-mode API - (one shared ``State`` model, ``async (state) -> dict`` nodes, one - ``.branch(source, router)`` call instead of ``BranchStep`` + per-node - ``condition`` lambdas). +1. **Branching** — sentiment-classification workflow with one ``.branch(...)`` + call (vs ``BranchStep`` + per-node ``condition`` lambdas in port-based mode). -2. **Software factory with checkpoint/resume** — a four-agent pipeline - (architect → python_dev → deployer → evaluator) where the deployer fails - on its first attempt. The pipeline checkpoints after each successful node, - and a second ``invoke(run_id=...)`` resumes from the failed node instead - of re-running the earlier agents. +2. **Map-reduce via ``Send``** — a planner dispatches one ``Send`` per work + item; workers run concurrently; an aggregator runs once with all results + merged via the ``extend`` reducer. -3. **Map-reduce via ``Send``** — a planner dispatches one ``Send`` per work - item to the same worker node, the workers run concurrently, and an - aggregator runs once with all results in shared state. +3. **HITL Pause + audit log** — a deploy gate that returns ``Pause(...)`` to + wait for human approval; resume with ``approve_pause=True``; a + ``FileAuditLog`` captures every node visit with its status. + +For the deeper software-factory walkthrough (QA feedback loop, checkpoint + +resume, Postgres / Redis checkpointer templates), see the self-contained +example package ``examples/software_factory/``. Usage:: @@ -43,7 +42,6 @@ import asyncio import logging -import os import tempfile from pathlib import Path from typing import Annotated @@ -55,13 +53,12 @@ FileCheckpointer, Pause, PipelineBuilder, - PostgresCheckpointer, Send, extend, ) -# Quiet the pipeline's own logger.exception() when we deliberately fail -# the deployer in scenario 2 — the failure is the demo, not a bug. +# Quiet the pipeline's own logger.exception() when we deliberately exercise +# a node failure — the failure is part of the demo, not a bug. logging.getLogger("fireflyframework_agentic.pipeline").setLevel(logging.CRITICAL) @@ -117,107 +114,11 @@ async def run_branching() -> None: # ============================================================================= -# Scenario 2 — Software factory with checkpoint/resume -# ============================================================================= - - -class BuildState(BaseModel): - """State threaded through a four-agent software-factory pipeline.""" - - requirements: str - spec: str | None = None - code: str | None = None - deploy_url: str | None = None - evaluation: str | None = None - - -# A flag so the deployer fails the first time and succeeds the second. -_deployer_failed_once = {"flag": False} - - -async def architect(state: BuildState) -> dict: - return {"spec": f"Architecture for: {state.requirements}"} - - -async def python_dev(state: BuildState) -> dict: - return {"code": f"# code implementing\n# {state.spec}"} - - -async def deployer(state: BuildState) -> dict: - if not _deployer_failed_once["flag"]: - _deployer_failed_once["flag"] = True - raise RuntimeError("network blip — try again") - return {"deploy_url": "https://factory-app.example.com"} - - -async def evaluator(state: BuildState) -> dict: - return {"evaluation": f"PASS — deployed at {state.deploy_url}"} - - -class ProgressHandler: - """Prints live progress for the software-factory scenario. - - Implements only a subset of :class:`StatePipelineEventHandler` — missing - methods are no-ops, which is fine because the pipeline tolerates partial - handlers. - """ - - async def on_pipeline_start(self, pipeline_name: str, run_id: str) -> None: - print(f" ▶ [{pipeline_name}] run {run_id[:8]}… starting") - - async def on_node_start(self, pipeline_name: str, run_id: str, node_id: str, visit: int) -> None: - print(f" ▶ {node_id} (visit #{visit})") - - async def on_node_complete(self, pipeline_name: str, run_id: str, node_id: str, latency_ms: float) -> None: - print(f" ✔ {node_id} ({latency_ms:.0f}ms)") - - async def on_node_error(self, pipeline_name: str, run_id: str, node_id: str, error: str) -> None: - print(f" ✗ {node_id}: {error}") - - async def on_node_pause(self, pipeline_name: str, run_id: str, node_id: str, reason: str) -> None: - print(f" ⏸ {node_id} paused: {reason}") - - async def on_pipeline_complete(self, pipeline_name: str, run_id: str, success: bool, duration_ms: float) -> None: - status = "OK" if success else "FAILED" - print(f" ═ [{pipeline_name}] {status} in {duration_ms:.0f}ms") - - -async def run_software_factory() -> None: - print("=== 2. Software factory with checkpoint/resume + live progress ===\n") - - handler = ProgressHandler() - with tempfile.TemporaryDirectory() as tmp: - ckpt = FileCheckpointer(Path(tmp)) - pipeline = ( - PipelineBuilder( - "software-factory", - state=BuildState, - checkpointer=ckpt, - event_handler=handler, - ) - .add_node(architect) - .add_node(python_dev) - .add_node(deployer) - .add_node(evaluator) - .chain(architect, python_dev, deployer, evaluator) - .build() - ) - - # First run — deployer fails after architect + python_dev complete. - first = await pipeline.invoke(BuildState(requirements="User-management service")) - print(f" first run: success={first.success}, failed_node={first.failed_node}") - print(f" completed: {first.completed_nodes}") - print(f" run_id: {first.run_id}\n") - - # Resume — picks up at deployer, skips architect + python_dev. - second = await pipeline.invoke(run_id=first.run_id) - print(f" resumed: success={second.success}") - print(f" completed: {second.completed_nodes}") - print(f" eval: {second.state.evaluation}\n") - - -# ============================================================================= -# Scenario 3 — Map-reduce via Send +# Scenario 2 — Map-reduce via Send +# +# (The software-factory scenario that used to live here has its own folder +# now: ``examples/software_factory/``. It exercises the QA feedback loop, +# checkpoint + resume, and includes plug-and-play Postgres / Redis templates.) # ============================================================================= @@ -250,7 +151,7 @@ def dispatch(state: MapReduceState) -> list[Send]: async def run_map_reduce() -> None: - print("=== 3. Map-reduce via Send ===\n") + print("=== 2. Map-reduce via Send ===\n") pipeline = ( PipelineBuilder("mapreduce", state=MapReduceState) @@ -270,40 +171,6 @@ async def run_map_reduce() -> None: # ============================================================================= -async def run_software_factory_postgres() -> None: - """Optional: the same software-factory scenario backed by Postgres. - - Runs only when the ``PG_DSN`` env var is set (e.g. - ``PG_DSN=postgresql://user:pw@localhost/firefly``). Requires the - ``postgres`` extra: ``pip install fireflyframework-agentic[postgres]``. - """ - dsn = os.environ.get("PG_DSN") - if not dsn: - return - - print("=== 4. Software factory with PostgresCheckpointer ===\n") - - # Reset the deployer flag so this scenario starts clean. - _deployer_failed_once["flag"] = False - - checkpointer = PostgresCheckpointer(dsn=dsn) - pipeline = ( - PipelineBuilder("software-factory-pg", state=BuildState, checkpointer=checkpointer) - .add_node(architect) - .add_node(python_dev) - .add_node(deployer) - .add_node(evaluator) - .chain(architect, python_dev, deployer, evaluator) - .build() - ) - first = await pipeline.invoke(BuildState(requirements="postgres-backed deploy")) - print(f" first run: success={first.success}, failed_node={first.failed_node}") - print(f" run_id: {first.run_id}\n") - second = await pipeline.invoke(run_id=first.run_id) - print(f" resumed: success={second.success}") - print(f" eval: {second.state.evaluation}\n") - - class HitlState(BaseModel): """State threaded through a deploy pipeline gated by human approval.""" @@ -325,7 +192,7 @@ async def deploy_artifact(state: HitlState) -> dict: async def run_hitl_with_audit() -> None: - print("=== 5. Human-in-the-loop deploy gate with audit log ===\n") + print("=== 3. Human-in-the-loop deploy gate with audit log ===\n") with tempfile.TemporaryDirectory() as tmp: root = Path(tmp) @@ -369,9 +236,7 @@ async def run_hitl_with_audit() -> None: async def main() -> None: await run_branching() - await run_software_factory() await run_map_reduce() - await run_software_factory_postgres() await run_hitl_with_audit() diff --git a/examples/software_factory/README.md b/examples/software_factory/README.md new file mode 100644 index 00000000..0348be9a --- /dev/null +++ b/examples/software_factory/README.md @@ -0,0 +1,137 @@ +# `software_factory/` — a state-based agentic SDLC pipeline + +A small, self-contained example that shows the headline features of +`PipelineBuilder` in state mode: + +- **State + reducers** — one Pydantic model carries everything the agents read or write; `extend` accumulates QA feedback across loop iterations. +- **Branching** — one `.branch("qa", qa_router)` call gives both the success terminus and the QA cycle. +- **Cycle with `recursion_limit`** — the QA fail → codegen loop is something port-based DAGs cannot express. +- **Checkpoint + resume** — `builder` raises a simulated transient error on its first call; `invoke(run_id=...)` resumes from the checkpoint. +- **Observability handler** — a `StatePipelineEventHandler` prints per-node progress. + +No LLM calls. All agents are deterministic stubs so the example runs offline and the smoke test is stable. + +## Run it + +```bash +source ~/.venvs/firefly/bin/activate +python -m examples.software_factory +``` + +Expected output: + +``` +▶ [software-factory] run abc123ef… starting + ▶ architect (visit #1) + ✔ architect (0ms) + ▶ codegen (visit #1) + ✔ codegen (0ms) + ▶ builder (visit #1) + ✗ builder: dep install timed out +═ [software-factory] FAILED in 1ms + +first run: success=False failed_node=builder run_id=abc123ef… + +▶ [software-factory] run abc123ef… starting + ▶ builder (visit #1) + ✔ builder (0ms) + ▶ qa (visit #1) + ✔ qa (0ms) + ▶ codegen (visit #2) + ✔ codegen (0ms) + ▶ builder (visit #2) + ✔ builder (0ms) + ▶ qa (visit #2) + ✔ qa (0ms) + ▶ stable_release (visit #1) + ✔ stable_release (0ms) +═ [software-factory] OK in 2ms + +resumed: success=True release=v2026.05.28 iteration=2 +qa_feedback: ['missing PSD2 strong-auth flow'] +``` + +## The DAG + +``` + ┌─────────── qa_status == 'fail' → codegen (recursion_limit=3) ─────────┐ + │ │ + ▼ │ +architect → codegen → builder → qa ──(qa_router)──▶ stable_release │ + │ │ + └───────────────────────────────────────────────────┘ +``` + +| Node | What it does | +|---|---| +| `architect` | Writes a stub ADR string into `state.adr`. | +| `codegen` | Bumps `state.iteration`, writes `state.code = "v{iteration} (addresses: ...)"`. Iteration 2+ visibly incorporates `qa_feedback`. | +| `builder` | **Transient failure** on the first call across the process (`raise RuntimeError("dep install timed out")`). Succeeds on every subsequent call. | +| `qa` | **Substantive failure** on iteration 1 (`qa_status="fail"`, appends to `qa_feedback`). Passes on iteration 2. | +| `stable_release` | Sets `release_tag`. Terminal. | + +### Why are `codegen` and `builder` separate nodes? + +In stub form they look redundant. They're kept distinct because they model **two different failure-recovery patterns** the state-mode API supports: + +| Failure mode | Meaning | How the pipeline recovers | +|---|---|---| +| `builder` raises | Transient (network blip, dep flake) — same code, just retry | The engine catches the exception, checkpoints the failure, returns `success=False`. `invoke(run_id=...)` resumes by re-running `builder` in place. **No cycle.** | +| `qa` returns `"fail"` | Substantive (tests don't pass) — the code itself needs to change | `qa_router` returns `"codegen"`; the cycle re-enters `codegen` which writes v2 informed by `qa_feedback`. | + +One pipeline, two recovery patterns. Collapsing the nodes loses one of them. + +## Swapping the checkpointer + +The example defaults to `FileCheckpointer`. To run against a real Redis or Postgres: + +```bash +FIREFLY_CKPT=postgres PG_DSN="postgresql://localhost:5432/firefly" python -m examples.software_factory +FIREFLY_CKPT=redis REDIS_URL="redis://localhost:6379/0" python -m examples.software_factory +``` + +The Postgres and Redis backends live in this folder as **plug-and-play templates**, not framework code: + +- `checkpointers/postgres.py` — implements the framework's `Checkpointer` Protocol against a caller-supplied `psycopg.Connection`. +- `checkpointers/redis.py` — same idea against a caller-supplied `redis.Redis` client. +- `audit/postgres.py` — implements `QueryableAuditLog` against a caller-supplied `psycopg.Connection`. + +Each file is a flat ~50-LOC class. The framework no longer ships these — copy whichever you need into your project, adapt the table name or key prefix, and pass your own connection. The framework's `Checkpointer` and `AuditLog` Protocols are the only contract you need to match. + +## When to use Redis vs Postgres + +Both implement the same `Checkpointer` Protocol. The choice is about durability, latency, and inspection: + +| | Redis | Postgres | +|---|---|---| +| Durability | RDB + AOF; can lose the tail on crash unless `fsync=always` (slow). | WAL-fsynced; survives crashes cleanly. | +| Latency | Sub-millisecond writes. | Single-digit ms. | +| TTL | Native per-key (`EX` on `SET`). Old checkpoints disappear automatically. | Manual (cron, partition drop). | +| Inspection | `KEYS` / `GET`; no SQL, no joins. | Full SQL — joinable with the app's domain tables. | +| Footprint | Often already in the stack as a cache. | Often already in the stack as the app DB. | + +Rule of thumb: + +- **Redis** for short-lived workflows (minutes to a few hours), high throughput, where you're OK losing the last few checkpoints on a hard crash and want automatic TTL cleanup. +- **Postgres** for long-running workflows (hours to days, anything that uses `Pause` for human approval), compliance/audit needs, or when you want to query checkpoint history with SQL. + +For most Signature client apps already running on PostgreSQL Flexible Server, Postgres is the default; Redis is the choice when latency matters more than durability. + +## File layout + +``` +software_factory/ +├── README.md +├── __main__.py # entry point — crash, then resume +├── state.py # BuildState pydantic model + extend reducer +├── agents.py # 5 stub agents (architect, codegen, builder, qa, stable_release) +├── pipeline.py # build_pipeline(); qa_router +├── progress.py # StatePipelineEventHandler implementation +├── checkpointers/ +│ ├── postgres.py # Checkpointer Protocol impl (psycopg) +│ └── redis.py # Checkpointer Protocol impl (redis-py) +├── audit/ +│ └── postgres.py # QueryableAuditLog Protocol impl (psycopg) +└── tests/ + └── test_pipeline.py # end-to-end smoke test +``` diff --git a/examples/software_factory/__init__.py b/examples/software_factory/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/software_factory/__main__.py b/examples/software_factory/__main__.py new file mode 100644 index 00000000..a2edb9f0 --- /dev/null +++ b/examples/software_factory/__main__.py @@ -0,0 +1,75 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""Entry point: ``python -m examples.software_factory``. + +Demonstrates the full QA-loop + checkpoint-resume flow: + +1. First ``invoke`` runs architect → codegen → builder. The builder raises on + its first call (simulated transient ``dep install`` failure); the engine + checkpoints the failure and returns ``success=False``. +2. Second ``invoke(run_id=...)`` resumes from the checkpoint. The builder + succeeds, QA fails with PSD2 feedback, the cycle re-enters codegen, the + rewritten code passes QA, ``stable_release`` runs, and the pipeline + finishes with ``success=True``. + +By default uses :class:`FileCheckpointer` over a tmp directory. Set +``FIREFLY_CKPT=postgres`` (with ``PG_DSN``) or ``FIREFLY_CKPT=redis`` (with +``REDIS_URL``) to swap in the templates under ``checkpointers/``. +""" + +from __future__ import annotations + +import asyncio +import os +import tempfile +from pathlib import Path + +from examples.software_factory.pipeline import build_pipeline +from examples.software_factory.state import BuildState +from fireflyframework_agentic.pipeline import Checkpointer, FileCheckpointer + + +def _resolve_checkpointer(default_dir: Path) -> Checkpointer: + backend = os.environ.get("FIREFLY_CKPT", "file").lower() + if backend == "postgres": + import psycopg + + from examples.software_factory.checkpointers.postgres import ( + PostgresCheckpointer, + ) + + dsn = os.environ["PG_DSN"] + return PostgresCheckpointer(psycopg.connect(dsn, autocommit=True)) + + if backend == "redis": + import redis + + from examples.software_factory.checkpointers.redis import RedisCheckpointer + + url = os.environ["REDIS_URL"] + return RedisCheckpointer(redis.Redis.from_url(url, decode_responses=True)) + + return FileCheckpointer(default_dir) + + +async def main() -> None: + with tempfile.TemporaryDirectory() as ckpt_dir: + checkpointer = _resolve_checkpointer(Path(ckpt_dir)) + pipeline = build_pipeline(checkpointer) + + first = await pipeline.invoke(BuildState(request="payments microservice")) + print(f"\nfirst run: success={first.success} failed_node={first.failed_node} run_id={first.run_id}\n") + + resumed = await pipeline.invoke(run_id=first.run_id) + print( + f"\nresumed: success={resumed.success} " + f"release={resumed.state.release_tag} iteration={resumed.state.iteration}" + ) + print(f"qa_feedback: {resumed.state.qa_feedback}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/software_factory/agents.py b/examples/software_factory/agents.py new file mode 100644 index 00000000..44f82efe --- /dev/null +++ b/examples/software_factory/agents.py @@ -0,0 +1,68 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""Stub agents for the software factory example. + +No LLM calls. Each agent is a plain ``async (state) -> dict`` function that +returns the fields it wants merged into shared state. The behaviour is +deterministic so the example runs offline and the test suite is stable. + +Two failure modes are simulated to show how state mode handles each: + +* ``builder`` raises on its very first call (transient failure → recovered + via checkpoint + ``invoke(run_id=...)`` resume). +* ``qa`` returns ``qa_status='fail'`` on iteration 1 (substantive failure → + recovered via cycle back to ``codegen``). +""" + +from __future__ import annotations + +from examples.software_factory.state import BuildState + +_BUILDER_ATTEMPTS: dict[str, int] = {} + + +async def architect(state: BuildState) -> dict: + adr = ( + f"ADR for '{state.request}': split into api / domain / data modules; " + "use idiomatic Firefly patterns; PSD2 strong-auth required for payments." + ) + return {"adr": adr} + + +async def codegen(state: BuildState) -> dict: + next_iteration = state.iteration + 1 + if state.qa_feedback: + addressed = "; ".join(state.qa_feedback) + code = f"v{next_iteration} (addresses: {addressed})" + else: + code = f"v{next_iteration}" + return {"iteration": next_iteration, "code": code} + + +async def builder(state: BuildState) -> dict: + # Transient failure on the very first call across the whole process — + # exercises checkpoint + resume. Subsequent calls always succeed. + key = "global" + _BUILDER_ATTEMPTS[key] = _BUILDER_ATTEMPTS.get(key, 0) + 1 + if _BUILDER_ATTEMPTS[key] == 1: + raise RuntimeError("dep install timed out") + return {"build_status": "ok"} + + +async def qa(state: BuildState) -> dict: + # Substantive failure on iteration 1: code lacks PSD2 strong-auth flow. + # Iteration 2's codegen sees `qa_feedback` and rewrites the code, + # so QA passes on the next visit. + if state.iteration <= 1: + return { + "qa_status": "fail", + "qa_feedback": ["missing PSD2 strong-auth flow"], + } + return {"qa_status": "pass"} + + +async def stable_release(state: BuildState) -> dict: + return {"release_tag": "v2026.05.28"} diff --git a/examples/software_factory/audit/__init__.py b/examples/software_factory/audit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/software_factory/audit/postgres.py b/examples/software_factory/audit/postgres.py new file mode 100644 index 00000000..9858faab --- /dev/null +++ b/examples/software_factory/audit/postgres.py @@ -0,0 +1,112 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""Plug-and-play Postgres audit log for fireflyframework-agentic. + +This is **example code**, not framework code. Implements the framework's +:class:`QueryableAuditLog` Protocol (write + read-back) against a +caller-supplied ``psycopg.Connection``. + +Distinct from the checkpointer — the checkpointer stores the latest state +for crash recovery; the audit log stores every node visit for compliance +and replay. +""" + +from __future__ import annotations + +import json +from typing import Any + +from fireflyframework_agentic.pipeline import AuditEntry + + +class PostgresAuditLog: + """Append-only audit log backed by a single ``firefly_audit`` table. + + Implements the :class:`fireflyframework_agentic.pipeline.QueryableAuditLog` + Protocol — :meth:`record` writes one entry; :meth:`list_entries` reads + every entry for a given run in sequence order. + """ + + def __init__(self, connection: Any) -> None: + self._conn = connection + with connection.cursor() as cur: + cur.execute( + """ + CREATE TABLE IF NOT EXISTS firefly_audit ( + pipeline_name TEXT NOT NULL, + run_id TEXT NOT NULL, + sequence INT NOT NULL, + visit INT NOT NULL, + node_id TEXT NOT NULL, + started_at TIMESTAMPTZ NOT NULL, + completed_at TIMESTAMPTZ NOT NULL, + latency_ms DOUBLE PRECISION NOT NULL, + status TEXT NOT NULL, + inputs_snapshot JSONB NOT NULL, + outputs_snapshot JSONB NOT NULL, + error_message TEXT, + pause_reason TEXT, + PRIMARY KEY (pipeline_name, run_id, sequence) + ); + CREATE INDEX IF NOT EXISTS firefly_audit_run_idx + ON firefly_audit (pipeline_name, run_id); + """ + ) + + def record(self, entry: AuditEntry) -> None: + with self._conn.cursor() as cur: + cur.execute( + "INSERT INTO firefly_audit " + "(pipeline_name, run_id, sequence, visit, node_id, started_at, completed_at, " + " latency_ms, status, inputs_snapshot, outputs_snapshot, error_message, pause_reason) " + "VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) " + "ON CONFLICT (pipeline_name, run_id, sequence) DO NOTHING", + ( + entry.pipeline_name, + entry.run_id, + entry.sequence, + entry.visit, + entry.node_id, + entry.started_at, + entry.completed_at, + entry.latency_ms, + entry.status, + json.dumps(entry.inputs_snapshot), + json.dumps(entry.outputs_snapshot), + entry.error_message, + entry.pause_reason, + ), + ) + + def list_entries(self, pipeline_name: str, run_id: str) -> list[AuditEntry]: + with self._conn.cursor() as cur: + cur.execute( + "SELECT pipeline_name, run_id, sequence, visit, node_id, started_at, " + " completed_at, latency_ms, status, inputs_snapshot, outputs_snapshot, " + " error_message, pause_reason " + "FROM firefly_audit WHERE pipeline_name = %s AND run_id = %s " + "ORDER BY sequence", + (pipeline_name, run_id), + ) + rows = cur.fetchall() + return [ + AuditEntry( + pipeline_name=row[0], + run_id=row[1], + sequence=row[2], + visit=row[3], + node_id=row[4], + started_at=row[5], + completed_at=row[6], + latency_ms=row[7], + status=row[8], + inputs_snapshot=json.loads(row[9]) if isinstance(row[9], str) else row[9], + outputs_snapshot=json.loads(row[10]) if isinstance(row[10], str) else row[10], + error_message=row[11], + pause_reason=row[12], + ) + for row in rows + ] diff --git a/examples/software_factory/checkpointers/__init__.py b/examples/software_factory/checkpointers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/software_factory/checkpointers/postgres.py b/examples/software_factory/checkpointers/postgres.py new file mode 100644 index 00000000..a9a6d3eb --- /dev/null +++ b/examples/software_factory/checkpointers/postgres.py @@ -0,0 +1,107 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""Plug-and-play Postgres :class:`Checkpointer` for fireflyframework-agentic. + +This is **example code**, not framework code. Copy this file into your +project and adapt as needed: + +* Pass your own ``psycopg.Connection``. This template does not own the pool. +* Adapt the table name if ``firefly_checkpoints`` clashes with anything. +* Add retry / instrumentation in a wrapper if your stack needs it — the + framework engine already catches and logs checkpoint failures, so the + pipeline keeps running on transient errors regardless. +""" + +from __future__ import annotations + +import json +from typing import Any + +from fireflyframework_agentic.pipeline import CheckpointRecord + + +class PostgresCheckpointer: + """Stores checkpoints in a single ``firefly_checkpoints`` table. + + Implements the :class:`fireflyframework_agentic.pipeline.Checkpointer` + Protocol — three sync methods over a caller-supplied connection. + """ + + def __init__(self, connection: Any) -> None: + self._conn = connection + with connection.cursor() as cur: + cur.execute( + """ + CREATE TABLE IF NOT EXISTS firefly_checkpoints ( + pipeline_name TEXT NOT NULL, + run_id TEXT NOT NULL, + sequence INT NOT NULL, + node_id TEXT NOT NULL, + state JSONB NOT NULL, + completed_nodes JSONB NOT NULL, + paused BOOLEAN NOT NULL DEFAULT FALSE, + pause_reason TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (pipeline_name, run_id, sequence) + ); + CREATE INDEX IF NOT EXISTS firefly_checkpoints_run_idx + ON firefly_checkpoints (pipeline_name, run_id); + """ + ) + + def save(self, record: CheckpointRecord) -> None: + with self._conn.cursor() as cur: + cur.execute( + "INSERT INTO firefly_checkpoints " + "(pipeline_name, run_id, sequence, node_id, state, completed_nodes, paused, pause_reason) " + "VALUES (%s, %s, %s, %s, %s, %s, %s, %s) " + "ON CONFLICT (pipeline_name, run_id, sequence) DO UPDATE SET " + "node_id=EXCLUDED.node_id, state=EXCLUDED.state, " + "completed_nodes=EXCLUDED.completed_nodes, " + "paused=EXCLUDED.paused, pause_reason=EXCLUDED.pause_reason", + ( + record.pipeline_name, + record.run_id, + record.sequence, + record.node_id, + json.dumps(record.state), + json.dumps(record.completed_nodes), + record.paused, + record.pause_reason, + ), + ) + + def load_latest(self, pipeline_name: str, run_id: str) -> CheckpointRecord | None: + with self._conn.cursor() as cur: + cur.execute( + "SELECT pipeline_name, run_id, sequence, node_id, state, completed_nodes, " + " paused, pause_reason " + "FROM firefly_checkpoints " + "WHERE pipeline_name = %s AND run_id = %s " + "ORDER BY sequence DESC LIMIT 1", + (pipeline_name, run_id), + ) + row = cur.fetchone() + if row is None: + return None + return CheckpointRecord( + pipeline_name=row[0], + run_id=row[1], + sequence=row[2], + node_id=row[3], + state=json.loads(row[4]) if isinstance(row[4], str) else row[4], + completed_nodes=json.loads(row[5]) if isinstance(row[5], str) else row[5], + paused=row[6], + pause_reason=row[7], + ) + + def list_runs(self, pipeline_name: str) -> list[str]: + with self._conn.cursor() as cur: + cur.execute( + "SELECT DISTINCT run_id FROM firefly_checkpoints WHERE pipeline_name = %s ORDER BY run_id", + (pipeline_name,), + ) + return [r[0] for r in cur.fetchall()] diff --git a/examples/software_factory/checkpointers/redis.py b/examples/software_factory/checkpointers/redis.py new file mode 100644 index 00000000..ae43283d --- /dev/null +++ b/examples/software_factory/checkpointers/redis.py @@ -0,0 +1,65 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""Plug-and-play Redis :class:`Checkpointer` for fireflyframework-agentic. + +This is **example code**, not framework code. Copy this file into your +project and adapt as needed: + +* Pass your own ``redis.Redis`` client. The template does not own it. +* Tune ``ttl_seconds`` to match your workflow's longest expected wall-clock. +* The ``firefly:ckpt::runs`` ZSET does not expire — it's tiny and + serves as the index for :meth:`list_runs`. +""" + +from __future__ import annotations + +import json +import time +from typing import Any + +from fireflyframework_agentic.pipeline import CheckpointRecord + + +class RedisCheckpointer: + """Stores checkpoints as TTL'd JSON keys, indexed by a per-pipeline ZSET. + + Key layout: + + * ``firefly:ckpt:::_`` → JSON record (TTL). + * ``firefly:ckpt::runs`` → ZSET of run_ids (no TTL). + + Implements the :class:`fireflyframework_agentic.pipeline.Checkpointer` + Protocol — three sync methods over a caller-supplied client. + """ + + _PREFIX = "firefly:ckpt" + + def __init__(self, client: Any, *, ttl_seconds: int = 30 * 24 * 3600) -> None: + self._client = client + self._ttl = ttl_seconds + + def save(self, record: CheckpointRecord) -> None: + key = f"{self._PREFIX}:{record.pipeline_name}:{record.run_id}:{record.sequence:06d}_{record.node_id}" + self._client.set(key, record.model_dump_json(), ex=self._ttl) + self._client.zadd( + f"{self._PREFIX}:{record.pipeline_name}:runs", + {record.run_id: time.time()}, + ) + + def load_latest(self, pipeline_name: str, run_id: str) -> CheckpointRecord | None: + pattern = f"{self._PREFIX}:{pipeline_name}:{run_id}:*" + keys = self._client.keys(pattern) + if not keys: + return None + # Keys are zero-padded by sequence — lex-sorted last = numerically-latest. + latest_key = sorted(keys)[-1] + payload = self._client.get(latest_key) + if payload is None: + return None + return CheckpointRecord.model_validate(json.loads(payload)) + + def list_runs(self, pipeline_name: str) -> list[str]: + return list(self._client.zrange(f"{self._PREFIX}:{pipeline_name}:runs", 0, -1)) diff --git a/examples/software_factory/pipeline.py b/examples/software_factory/pipeline.py new file mode 100644 index 00000000..a686a50a --- /dev/null +++ b/examples/software_factory/pipeline.py @@ -0,0 +1,64 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""Wire the software-factory DAG. + +The pipeline: + + architect → codegen → builder → qa ──(qa_router)──▶ stable_release + │ + └──── qa_status='fail' ──▶ codegen (cycle) + +``qa_router`` is the one piece of routing logic — it implements the QA +feedback loop with a hard cap of ``recursion_limit=3``. +""" + +from __future__ import annotations + +from typing import cast + +from examples.software_factory.agents import ( + architect, + builder, + codegen, + qa, + stable_release, +) +from examples.software_factory.progress import ProgressHandler +from examples.software_factory.state import BuildState +from fireflyframework_agentic.pipeline import ( + Checkpointer, + PipelineBuilder, + StatePipeline, +) + + +def qa_router(state: BuildState) -> str: + """Route on QA outcome — pass → release, fail → codegen (cycle).""" + return "stable_release" if state.qa_status == "pass" else "codegen" + + +def build_pipeline(checkpointer: Checkpointer) -> StatePipeline: + pipeline = ( + PipelineBuilder( + "software-factory", + state=BuildState, + checkpointer=checkpointer, + recursion_limit=3, + event_handler=ProgressHandler(), + ) + .add_node(architect) + .add_node(codegen) + .add_node(builder) + .add_node(qa) + .add_node(stable_release) + .add_edge("architect", "codegen") + .add_edge("codegen", "builder") + .add_edge("builder", "qa") + .branch("qa", qa_router) + .build() + ) + # state= was set, so .build() returns a StatePipeline — narrow for the type checker. + return cast("StatePipeline", pipeline) diff --git a/examples/software_factory/progress.py b/examples/software_factory/progress.py new file mode 100644 index 00000000..4f9492c9 --- /dev/null +++ b/examples/software_factory/progress.py @@ -0,0 +1,34 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""Console progress handler. + +Implements (structurally) the framework's :class:`StatePipelineEventHandler` +Protocol. Prints one line per pipeline / node event so the QA loop and +checkpoint+resume flow are visible when running the example by hand. +""" + +from __future__ import annotations + + +class ProgressHandler: + async def on_pipeline_start(self, pipeline_name: str, run_id: str) -> None: + print(f"▶ [{pipeline_name}] run {run_id[:8]}… starting") + + async def on_node_start(self, pipeline_name: str, run_id: str, node_id: str, visit: int) -> None: + print(f" ▶ {node_id} (visit #{visit})") + + async def on_node_complete(self, pipeline_name: str, run_id: str, node_id: str, latency_ms: float) -> None: + print(f" ✔ {node_id} ({latency_ms:.0f}ms)") + + async def on_node_error(self, pipeline_name: str, run_id: str, node_id: str, error: str) -> None: + print(f" ✗ {node_id}: {error}") + + async def on_node_pause(self, pipeline_name: str, run_id: str, node_id: str, reason: str) -> None: + print(f" ⏸ {node_id}: {reason}") + + async def on_pipeline_complete(self, pipeline_name: str, run_id: str, success: bool, duration_ms: float) -> None: + status = "OK" if success else "FAILED" + print(f"═ [{pipeline_name}] {status} in {duration_ms:.0f}ms") diff --git a/examples/software_factory/state.py b/examples/software_factory/state.py new file mode 100644 index 00000000..f9c15bb6 --- /dev/null +++ b/examples/software_factory/state.py @@ -0,0 +1,30 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""Shared state for the software factory pipeline. + +One Pydantic model carries every field the agents read or write. The only +non-default reducer is ``extend`` on ``qa_feedback`` so feedback accumulates +across QA-loop iterations instead of being overwritten on each pass. +""" + +from __future__ import annotations + +from typing import Annotated + +from pydantic import BaseModel + +from fireflyframework_agentic.pipeline import extend + + +class BuildState(BaseModel): + request: str + iteration: int = 0 + adr: str | None = None + code: str | None = None + build_status: str | None = None + qa_status: str | None = None + qa_feedback: Annotated[list[str], extend] = [] + release_tag: str | None = None diff --git a/examples/software_factory/tests/__init__.py b/examples/software_factory/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/software_factory/tests/test_pipeline.py b/examples/software_factory/tests/test_pipeline.py new file mode 100644 index 00000000..ff3c8d01 --- /dev/null +++ b/examples/software_factory/tests/test_pipeline.py @@ -0,0 +1,48 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""End-to-end smoke test for the software_factory example. + +Runs the pipeline against a tmp-dir FileCheckpointer; asserts that the +transient builder failure triggers a checkpointed failure on the first +invoke, and that resuming via ``run_id`` walks through the QA loop and +finishes with a stable release. +""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from examples.software_factory import agents +from examples.software_factory.pipeline import build_pipeline +from examples.software_factory.state import BuildState +from fireflyframework_agentic.pipeline import FileCheckpointer + + +@pytest.fixture(autouse=True) +def _reset_builder_attempts() -> None: + """The builder stub uses a process-wide counter to simulate a one-shot + transient failure. Reset it so each test sees the same starting state. + """ + agents._BUILDER_ATTEMPTS.clear() + + +async def test_factory_end_to_end(tmp_path: Path) -> None: + pipeline = build_pipeline(FileCheckpointer(tmp_path)) + + first = await pipeline.invoke(BuildState(request="payments microservice")) + assert first.success is False + assert first.failed_node == "builder" + assert first.state.iteration == 1 + assert first.state.build_status is None + + resumed = await pipeline.invoke(run_id=first.run_id) + assert resumed.success is True + assert resumed.state.release_tag == "v2026.05.28" + assert resumed.state.qa_status == "pass" + assert resumed.state.iteration == 2 # QA fail on iter 1 → loop → iter 2 passes + assert resumed.state.qa_feedback == ["missing PSD2 strong-auth flow"] diff --git a/fireflyframework_agentic/pipeline/__init__.py b/fireflyframework_agentic/pipeline/__init__.py index 6f4dbe6c..a6221a3d 100644 --- a/fireflyframework_agentic/pipeline/__init__.py +++ b/fireflyframework_agentic/pipeline/__init__.py @@ -34,7 +34,6 @@ FileAuditLog, LoggingAuditLog, OtelAuditLog, - PostgresAuditLog, QueryableAuditLog, ) from fireflyframework_agentic.pipeline.builder import PipelineBuilder @@ -42,8 +41,6 @@ Checkpointer, CheckpointRecord, FileCheckpointer, - PostgresCheckpointer, - RedisCheckpointer, ) from fireflyframework_agentic.pipeline.context import PipelineContext from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode, FailureStrategy @@ -102,12 +99,9 @@ "PipelineEngine", "PipelineEventHandler", "PipelineResult", - "PostgresAuditLog", - "PostgresCheckpointer", "QueryableAuditLog", "ReasoningStep", "RecursionLimitError", - "RedisCheckpointer", "RetrievalStep", "Send", "StatePipeline", diff --git a/fireflyframework_agentic/pipeline/_psycopg_backend.py b/fireflyframework_agentic/pipeline/_psycopg_backend.py deleted file mode 100644 index b6784c9d..00000000 --- a/fireflyframework_agentic/pipeline/_psycopg_backend.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2026 Firefly Software Foundation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Shared scaffolding for Postgres-backed pipeline backends. - -The checkpointer and audit-log backends both need the same boilerplate: -optional-dep guard on ``psycopg``, ``dsn`` xor ``connection`` constructor -check, table-name validation, and lazy idempotent DDL on first write. This -module centralizes it so each backend only has to declare its DDL and -default table name. -""" - -from __future__ import annotations - -from typing import Any - -try: - import psycopg as _psycopg # type: ignore[import-not-found] -except ImportError: # pragma: no cover - optional dep - _psycopg = None # type: ignore[assignment] - - -class PsycopgBackend: - """Base class for backends that persist into a single Postgres table. - - Subclasses set the class attribute ``_DDL`` to a format string with a - single ``{table}`` placeholder, and pass their human-readable name and - default table to ``__init__``. The base class handles the rest: - - * Raises ``ImportError`` if the ``postgres`` extra is not installed. - * Enforces ``dsn`` xor ``connection``. - * Validates ``table_name`` against SQL injection (interpolated into DDL). - * Opens the connection (with ``autocommit=True``) when only ``dsn`` is given. - * Applies the DDL lazily and idempotently on first ``_ensure_table()`` call. - """ - - _DDL: str = "" - - def __init__( - self, - *, - kind: str, - dsn: str | None, - connection: Any, - table_name: str, - ) -> None: - if _psycopg is None: - raise ImportError( - f"{kind} requires the 'postgres' extra. Install with: pip install fireflyframework-agentic[postgres]" - ) - if (dsn is None) == (connection is None): - raise ValueError(f"{kind} needs exactly one of `dsn` or `connection`.") - # Table name is interpolated into DDL — validate strictly to avoid SQL injection. - if not table_name.replace("_", "").isalnum(): - raise ValueError(f"Invalid table_name {table_name!r}: must be alphanumeric/underscore only.") - self._conn = connection if connection is not None else _psycopg.connect(dsn, autocommit=True) - self._table = table_name - self._ddl_applied = False - - def _ensure_table(self) -> None: - if self._ddl_applied: - return - with self._conn.cursor() as cur: - cur.execute(self._DDL.format(table=self._table)) - self._ddl_applied = True diff --git a/fireflyframework_agentic/pipeline/audit.py b/fireflyframework_agentic/pipeline/audit.py index d9691775..fbb02156 100644 --- a/fireflyframework_agentic/pipeline/audit.py +++ b/fireflyframework_agentic/pipeline/audit.py @@ -18,22 +18,20 @@ checkpointer stores the *latest* state for crash recovery; the audit log stores *every* node visit for compliance, debugging, and replay. -Four backends ship: +Three backends ship in the framework: * :class:`FileAuditLog` — one JSONL file per ``(pipeline_name, run_id)``. - Best for dev / single-host audit trails. -* :class:`PostgresAuditLog` — single ``firefly_audit`` table; reuses the - ``psycopg`` connection from the ``postgres`` optional extra (same dep - added in Phase 3a for ``PostgresCheckpointer``). + Best for dev / single-host audit trails. Implements :class:`QueryableAuditLog`. * :class:`LoggingAuditLog` — stdlib ``logging``; pairs with whatever log aggregation pipeline (Splunk-HEC, Loki, Datadog, OTel-LoggingHandler-bridge) - the host application already runs. No new optional dep. + the host application already runs. Write-only. * :class:`OtelAuditLog` — direct OTel logs API; attaches trace correlation (``trace_id``/``span_id``) automatically. Best for OTel-native stacks - (Application Insights, Datadog APM, OTel-Collector). + (Application Insights, Datadog APM, OTel-Collector). Write-only. -File and Postgres also implement :class:`QueryableAuditLog` (``list_entries``); -Logging and OTel are write-only — query your observability stack instead. +For a Postgres-backed queryable audit log, see the plug-and-play template at +``examples/software_factory/audit/postgres.py`` — a ~80 LOC class implementing +:class:`QueryableAuditLog` against a caller-supplied ``psycopg.Connection``. """ from __future__ import annotations @@ -46,8 +44,6 @@ from pydantic import BaseModel -from fireflyframework_agentic.pipeline._psycopg_backend import PsycopgBackend - try: from opentelemetry._logs import LogRecord as _OtelLogRecord # type: ignore[import-not-found] from opentelemetry._logs import SeverityNumber as _OtelSeverityNumber # type: ignore[import-not-found] @@ -142,105 +138,6 @@ def list_entries(self, pipeline_name: str, run_id: str) -> list[AuditEntry]: return entries -class PostgresAuditLog(PsycopgBackend): - """Postgres-backed audit log. Single table created on first ``record`` call. - - Reuses ``psycopg`` from the ``postgres`` optional extra. ``dsn`` or a - pre-built ``connection`` is required, mirroring :class:`PostgresCheckpointer`'s - API for consistency. - """ - - _DDL = """ - CREATE TABLE IF NOT EXISTS {table} ( - pipeline_name TEXT NOT NULL, - run_id TEXT NOT NULL, - sequence INT NOT NULL, - visit INT NOT NULL, - node_id TEXT NOT NULL, - started_at TIMESTAMPTZ NOT NULL, - completed_at TIMESTAMPTZ NOT NULL, - latency_ms DOUBLE PRECISION NOT NULL, - status TEXT NOT NULL, - inputs_snapshot JSONB NOT NULL, - outputs_snapshot JSONB NOT NULL, - error_message TEXT, - pause_reason TEXT, - PRIMARY KEY (pipeline_name, run_id, sequence) - ); - CREATE INDEX IF NOT EXISTS {table}_run_idx - ON {table} (pipeline_name, run_id); - """ - - def __init__( - self, - *, - dsn: str | None = None, - connection: Any = None, - table_name: str = "firefly_audit", - ) -> None: - super().__init__(kind="PostgresAuditLog", dsn=dsn, connection=connection, table_name=table_name) - - def record(self, entry: AuditEntry) -> None: - self._ensure_table() - sql = ( - f"INSERT INTO {self._table} " - "(pipeline_name, run_id, sequence, visit, node_id, started_at, completed_at, " - "latency_ms, status, inputs_snapshot, outputs_snapshot, error_message, pause_reason) " - "VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) " - "ON CONFLICT (pipeline_name, run_id, sequence) DO NOTHING" - ) - with self._conn.cursor() as cur: - cur.execute( - sql, - ( - entry.pipeline_name, - entry.run_id, - entry.sequence, - entry.visit, - entry.node_id, - entry.started_at, - entry.completed_at, - entry.latency_ms, - entry.status, - json.dumps(entry.inputs_snapshot), - json.dumps(entry.outputs_snapshot), - entry.error_message, - entry.pause_reason, - ), - ) - - def list_entries(self, pipeline_name: str, run_id: str) -> list[AuditEntry]: - self._ensure_table() - sql = ( - f"SELECT pipeline_name, run_id, sequence, visit, node_id, started_at, completed_at, " - f"latency_ms, status, inputs_snapshot, outputs_snapshot, error_message, pause_reason " - f"FROM {self._table} WHERE pipeline_name = %s AND run_id = %s ORDER BY sequence" - ) - with self._conn.cursor() as cur: - cur.execute(sql, (pipeline_name, run_id)) - rows = cur.fetchall() - entries: list[AuditEntry] = [] - for row in rows: - entries.append( - AuditEntry( - pipeline_name=row[0], - run_id=row[1], - sequence=row[2], - visit=row[3], - node_id=row[4], - started_at=row[5], - completed_at=row[6], - latency_ms=row[7], - status=row[8], - inputs_snapshot=json.loads(row[9]) if isinstance(row[9], str) else row[9], - outputs_snapshot=json.loads(row[10]) if isinstance(row[10], str) else row[10], - error_message=row[11], - pause_reason=row[12], - ) - ) - return entries - - class LoggingAuditLog: """Audit log backed by Python's stdlib ``logging`` module. diff --git a/fireflyframework_agentic/pipeline/checkpoint.py b/fireflyframework_agentic/pipeline/checkpoint.py index 5cb2b5cf..765b8355 100644 --- a/fireflyframework_agentic/pipeline/checkpoint.py +++ b/fireflyframework_agentic/pipeline/checkpoint.py @@ -18,34 +18,21 @@ ``(pipeline_name, run_id, node_id)``. On resume the engine loads the latest checkpoint and skips nodes that already completed in that run. -Three backends ship: - -* :class:`FileCheckpointer` — filesystem JSON. Best for dev / single-host. -* :class:`RedisCheckpointer` — Redis with TTL. Best for multi-worker, - sub-day-scale runs. Requires ``pip install fireflyframework-agentic[redis]``. -* :class:`PostgresCheckpointer` — Postgres with a single ``firefly_checkpoints`` - table. Best for long-lived runs / compliance. Requires - ``pip install fireflyframework-agentic[postgres]``. - -Any backend conforms to the :class:`Checkpointer` Protocol and is interchangeable. +The framework ships :class:`FileCheckpointer` for dev / single-host work. +For Postgres- or Redis-backed checkpointing, see the plug-and-play templates +under ``examples/software_factory/checkpointers/``: each is a ~50 LOC class +that implements the :class:`Checkpointer` Protocol against a caller-supplied +connection. Copy whichever you need into your project and adapt. """ from __future__ import annotations import json -import time from pathlib import Path from typing import Any, Protocol, runtime_checkable from pydantic import BaseModel -from fireflyframework_agentic.pipeline._psycopg_backend import PsycopgBackend - -try: - import redis as _redis # type: ignore[import-not-found] -except ImportError: # pragma: no cover - optional dep - _redis = None # type: ignore[assignment] - class CheckpointRecord(BaseModel): """One saved checkpoint. @@ -120,168 +107,3 @@ def list_runs(self, pipeline_name: str) -> list[str]: if not pipeline_dir.exists(): return [] return sorted(d.name for d in pipeline_dir.iterdir() if d.is_dir()) - - -class RedisCheckpointer: - """Redis-backed checkpointer. - - Key layout:: - - firefly:ckpt:::_ -> JSON record (with TTL) - firefly:ckpt::runs -> ZSET of run_ids (score = last update ts) - - Parameters: - url: Redis connection string (e.g. ``redis://host:6379/0``). Mutually - exclusive with ``client``. - client: Pre-built ``redis.Redis`` instance. Use this to share a - connection pool across many pipelines. - ttl_seconds: TTL applied to each checkpoint key. Default 30 days. - The runs-index ZSET does not expire (it's tiny). - key_prefix: Override the default ``firefly:ckpt`` key prefix. - - Raises: - ImportError: When the ``redis`` extra is not installed. - """ - - def __init__( - self, - *, - url: str | None = None, - client: Any = None, - ttl_seconds: int = 60 * 60 * 24 * 30, - key_prefix: str = "firefly:ckpt", - ) -> None: - if _redis is None: - raise ImportError( - "RedisCheckpointer requires the 'redis' extra. " - "Install with: pip install fireflyframework-agentic[redis]" - ) - if (url is None) == (client is None): - raise ValueError("RedisCheckpointer needs exactly one of `url` or `client`.") - self._client = client if client is not None else _redis.Redis.from_url(url, decode_responses=True) - self._ttl = ttl_seconds - self._prefix = key_prefix - - def _ckpt_key(self, pipeline: str, run_id: str, sequence: int, node_id: str) -> str: - return f"{self._prefix}:{pipeline}:{run_id}:{sequence:06d}_{node_id}" - - def _runs_index_key(self, pipeline: str) -> str: - return f"{self._prefix}:{pipeline}:runs" - - def _run_pattern(self, pipeline: str, run_id: str) -> str: - return f"{self._prefix}:{pipeline}:{run_id}:*" - - def save(self, record: CheckpointRecord) -> None: - key = self._ckpt_key(record.pipeline_name, record.run_id, record.sequence, record.node_id) - self._client.set(key, record.model_dump_json(), ex=self._ttl) - self._client.zadd(self._runs_index_key(record.pipeline_name), {record.run_id: time.time()}) - - def load_latest(self, pipeline_name: str, run_id: str) -> CheckpointRecord | None: - keys = self._client.keys(self._run_pattern(pipeline_name, run_id)) - if not keys: - return None - # Keys are zero-padded by sequence; lex-sorted last = numerically-latest. - latest_key = sorted(keys)[-1] - payload = self._client.get(latest_key) - if payload is None: - return None - return CheckpointRecord.model_validate(json.loads(payload)) - - def list_runs(self, pipeline_name: str) -> list[str]: - return list(self._client.zrange(self._runs_index_key(pipeline_name), 0, -1)) - - -class PostgresCheckpointer(PsycopgBackend): - """Postgres-backed checkpointer. - - Uses a single table created on first ``save`` call. The DDL is idempotent - so multiple processes pointing at the same database are safe. - - Parameters: - dsn: Postgres connection string. Mutually exclusive with ``connection``. - connection: Pre-built ``psycopg.Connection``. Use this to share a - connection across many pipelines. - table_name: Override the default ``firefly_checkpoints`` table name. - - Raises: - ImportError: When the ``postgres`` extra is not installed. - """ - - _DDL = """ - CREATE TABLE IF NOT EXISTS {table} ( - pipeline_name TEXT NOT NULL, - run_id TEXT NOT NULL, - sequence INT NOT NULL, - node_id TEXT NOT NULL, - state JSONB NOT NULL, - completed_nodes JSONB NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT now(), - PRIMARY KEY (pipeline_name, run_id, sequence) - ); - CREATE INDEX IF NOT EXISTS {table}_run_idx - ON {table} (pipeline_name, run_id); - """ - - def __init__( - self, - *, - dsn: str | None = None, - connection: Any = None, - table_name: str = "firefly_checkpoints", - ) -> None: - super().__init__(kind="PostgresCheckpointer", dsn=dsn, connection=connection, table_name=table_name) - - def save(self, record: CheckpointRecord) -> None: - self._ensure_table() - sql = ( - f"INSERT INTO {self._table} " - "(pipeline_name, run_id, sequence, node_id, state, completed_nodes) " - "VALUES (%s, %s, %s, %s, %s, %s) " - "ON CONFLICT (pipeline_name, run_id, sequence) DO UPDATE SET " - "node_id = EXCLUDED.node_id, " - "state = EXCLUDED.state, " - "completed_nodes = EXCLUDED.completed_nodes" - ) - with self._conn.cursor() as cur: - cur.execute( - sql, - ( - record.pipeline_name, - record.run_id, - record.sequence, - record.node_id, - json.dumps(record.state), - json.dumps(record.completed_nodes), - ), - ) - - def load_latest(self, pipeline_name: str, run_id: str) -> CheckpointRecord | None: - self._ensure_table() - sql = ( - f"SELECT pipeline_name, run_id, sequence, node_id, state, completed_nodes " - f"FROM {self._table} " - f"WHERE pipeline_name = %s AND run_id = %s " - f"ORDER BY sequence DESC LIMIT 1" - ) - with self._conn.cursor() as cur: - cur.execute(sql, (pipeline_name, run_id)) - row = cur.fetchone() - if row is None: - return None - pipeline, rid, seq, node_id, state, completed = row - # psycopg returns JSONB as parsed Python objects; tolerate raw strings too. - return CheckpointRecord( - pipeline_name=pipeline, - run_id=rid, - sequence=seq, - node_id=node_id, - state=json.loads(state) if isinstance(state, str) else state, - completed_nodes=json.loads(completed) if isinstance(completed, str) else completed, - ) - - def list_runs(self, pipeline_name: str) -> list[str]: - self._ensure_table() - sql = f"SELECT DISTINCT run_id FROM {self._table} WHERE pipeline_name = %s ORDER BY run_id" - with self._conn.cursor() as cur: - cur.execute(sql, (pipeline_name,)) - return [r[0] for r in cur.fetchall()] diff --git a/pyproject.toml b/pyproject.toml index ed77b3c7..416a109d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,9 +62,6 @@ queues = [ postgres = [ "asyncpg>=0.30.0", "sqlalchemy>=2.0.0", - # psycopg drives the sync-Protocol PostgresCheckpointer; asyncpg above - # drives the async memory store. Both ship under the same extra. - "psycopg[binary]>=3.2.0,<4", ] mongodb = [ "motor>=3.6.0", diff --git a/tests/unit/pipeline/test_audit_log.py b/tests/unit/pipeline/test_audit_log.py index 369cf6a5..6ccb1fde 100644 --- a/tests/unit/pipeline/test_audit_log.py +++ b/tests/unit/pipeline/test_audit_log.py @@ -3,7 +3,12 @@ # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. -"""Phase-3c audit-log tests — File / Postgres / Logging / OTel backends + pipeline wiring.""" +"""Audit-log tests — File / Logging / OTel backends + pipeline wiring. + +PostgresAuditLog used to live in the framework and was tested here with mocks; +it moved to ``examples/software_factory/audit/postgres.py`` as a plug-and-play +template. +""" from __future__ import annotations @@ -17,7 +22,6 @@ import pytest from pydantic import BaseModel -import fireflyframework_agentic.pipeline._psycopg_backend as psycopg_backend_module import fireflyframework_agentic.pipeline.audit as audit_module from fireflyframework_agentic.pipeline import ( AuditEntry, @@ -26,7 +30,6 @@ OtelAuditLog, Pause, PipelineBuilder, - PostgresAuditLog, ) @@ -79,15 +82,13 @@ def test_file_audit_log_unknown_run_returns_empty(tmp_path: Path) -> None: # ============================================================================= -# PostgresAuditLog +# Optional-dep stubs for OTel # ============================================================================= @pytest.fixture(autouse=True) def _stub_optional_deps(monkeypatch: pytest.MonkeyPatch) -> None: - """Stub _psycopg and OTel symbols so backends can be constructed with mocks.""" - if psycopg_backend_module._psycopg is None: - monkeypatch.setattr(psycopg_backend_module, "_psycopg", MagicMock(name="psycopg_stub")) + """Stub OTel symbols so OtelAuditLog can be constructed with mocks.""" if audit_module._otel_get_logger is None: monkeypatch.setattr(audit_module, "_otel_get_logger", MagicMock(name="otel_logger_factory")) monkeypatch.setattr(audit_module, "_OtelLogRecord", MagicMock(name="LogRecord")) @@ -99,107 +100,6 @@ def _stub_optional_deps(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(audit_module, "_OtelSeverityNumber", sev) -def test_postgres_audit_missing_dep_raises(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(psycopg_backend_module, "_psycopg", None) - with pytest.raises(ImportError, match=r"\[postgres\]"): - PostgresAuditLog(dsn="postgresql://x") - - -def _pg_conn_mock() -> tuple[MagicMock, dict]: - """MagicMock connection backed by an in-memory dict keyed by (pipeline,run,seq).""" - store: dict[tuple[str, str, int], dict[str, Any]] = {} - ddl_calls: list[str] = [] - conn = MagicMock(name="psycopg.Connection") - - def make_cursor() -> MagicMock: - cur = MagicMock() - cur.__enter__ = MagicMock(return_value=cur) - cur.__exit__ = MagicMock(return_value=None) - cur._last_one = None - cur._last_all = [] - - def fake_execute(sql: str, params: tuple | None = None) -> None: - s = sql.strip().lower() - if s.startswith("create table"): - ddl_calls.append(sql) - return - if s.startswith("insert into"): - assert params is not None - key = (params[0], params[1], params[2]) - store[key] = { - "pipeline_name": params[0], - "run_id": params[1], - "sequence": params[2], - "visit": params[3], - "node_id": params[4], - "started_at": params[5], - "completed_at": params[6], - "latency_ms": params[7], - "status": params[8], - "inputs_snapshot": json.loads(params[9]) if isinstance(params[9], str) else params[9], - "outputs_snapshot": json.loads(params[10]) if isinstance(params[10], str) else params[10], - "error_message": params[11], - "pause_reason": params[12], - } - return - if s.startswith("select"): - assert params is not None - rows = [v for k, v in store.items() if k[0] == params[0] and k[1] == params[1]] - rows.sort(key=lambda r: r["sequence"]) - cur._last_all = [ - ( - r["pipeline_name"], - r["run_id"], - r["sequence"], - r["visit"], - r["node_id"], - r["started_at"], - r["completed_at"], - r["latency_ms"], - r["status"], - r["inputs_snapshot"], - r["outputs_snapshot"], - r["error_message"], - r["pause_reason"], - ) - for r in rows - ] - return - raise AssertionError(f"unexpected SQL: {sql}") - - cur.execute.side_effect = fake_execute - cur.fetchone.side_effect = lambda: cur._last_one - cur.fetchall.side_effect = lambda: cur._last_all - return cur - - conn.cursor.side_effect = make_cursor - conn._ddl_calls = ddl_calls - return conn, store - - -def test_postgres_audit_ddl_once_then_inserts() -> None: - conn, store = _pg_conn_mock() - log = PostgresAuditLog(connection=conn) - for seq in (1, 2, 3): - log.record(_entry(sequence=seq, node_id=f"n{seq}")) - assert len(conn._ddl_calls) == 1 - assert len(store) == 3 - - -def test_postgres_audit_list_entries_orders_by_sequence() -> None: - conn, _ = _pg_conn_mock() - log = PostgresAuditLog(connection=conn) - for seq in (3, 1, 2): - log.record(_entry(sequence=seq, node_id=f"n{seq}")) - entries = log.list_entries("p", "r") - assert [e.sequence for e in entries] == [1, 2, 3] - - -def test_postgres_audit_rejects_bad_table_name() -> None: - with pytest.raises(ValueError, match="Invalid table_name"): - PostgresAuditLog(connection=MagicMock(), table_name="bad; DROP TABLE") - - # ============================================================================= # LoggingAuditLog # ============================================================================= diff --git a/tests/unit/pipeline/test_checkpoint_backends.py b/tests/unit/pipeline/test_checkpoint_backends.py index 47a6674c..419faca3 100644 --- a/tests/unit/pipeline/test_checkpoint_backends.py +++ b/tests/unit/pipeline/test_checkpoint_backends.py @@ -3,325 +3,105 @@ # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. -"""Tests for the three Checkpointer backends (File, Redis, Postgres). +"""Tests for the framework's File Checkpointer. -Mocks only — no real Redis or Postgres needed. Real-service verification is -out-of-band against actual servers. +The Postgres and Redis backends used to live in the framework and were +exercised here with mocks; both moved to plug-and-play templates under +``examples/software_factory/checkpointers/`` (apps that need them copy the +file into their repo and test it against their own infra). """ from __future__ import annotations -import json -from typing import Any -from unittest.mock import MagicMock, create_autospec - import pytest from pydantic import BaseModel -import fireflyframework_agentic.pipeline._psycopg_backend as psycopg_backend_module -import fireflyframework_agentic.pipeline.checkpoint as checkpoint_module from fireflyframework_agentic.pipeline import ( CheckpointRecord, FileCheckpointer, PipelineBuilder, - PostgresCheckpointer, - RedisCheckpointer, StatePipeline, ) - -@pytest.fixture(autouse=True) -def _stub_optional_deps(monkeypatch: pytest.MonkeyPatch) -> None: - """Make ``_redis`` / ``_psycopg`` truthy so we can construct backends with mock clients. - - Tests that explicitly want the missing-dep code path (the two - ``..._missing_dep_raises`` tests) override this by setting the symbol back - to None via their own monkeypatch — the per-test patch wins. - """ - if checkpoint_module._redis is None: - monkeypatch.setattr(checkpoint_module, "_redis", MagicMock(name="redis_stub")) - if psycopg_backend_module._psycopg is None: - monkeypatch.setattr(psycopg_backend_module, "_psycopg", MagicMock(name="psycopg_stub")) - - # ============================================================================= -# RedisCheckpointer +# FileCheckpointer # ============================================================================= -def _redis_client_mock() -> MagicMock: - """Build a MagicMock that tracks the calls a RedisCheckpointer issues, with - just enough state (an in-memory dict) to make round-trip save→load work. - """ - store: dict[str, str] = {} - runs_index: dict[str, dict[str, float]] = {} - - client = MagicMock(name="redis.Redis") - - def fake_set(key: str, value: str, ex: int | None = None) -> bool: - store[key] = value - return True - - def fake_get(key: str) -> str | None: - return store.get(key) - - def fake_keys(pattern: str) -> list[str]: - # Trivial glob: only handles "*" patterns (which is what we use). - if pattern.endswith("*"): - prefix = pattern[:-1] - return [k for k in store if k.startswith(prefix)] - return [k for k in store if k == pattern] - - def fake_zadd(key: str, mapping: dict[str, float]) -> int: - runs_index.setdefault(key, {}).update(mapping) - return len(mapping) - - def fake_zrange(key: str, start: int, end: int) -> list[str]: - members = runs_index.get(key, {}) - ordered = sorted(members, key=lambda m: members[m]) - return ordered[start : (end + 1 if end != -1 else None)] +def test_file_checkpointer_save_and_load_latest(tmp_path) -> None: + ckpt = FileCheckpointer(tmp_path / "ckpt") - client.set.side_effect = fake_set - client.get.side_effect = fake_get - client.keys.side_effect = fake_keys - client.zadd.side_effect = fake_zadd - client.zrange.side_effect = fake_zrange - return client - - -def test_redis_checkpointer_missing_dep_raises(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(checkpoint_module, "_redis", None) - with pytest.raises(ImportError, match=r"\[redis\]"): - RedisCheckpointer(url="redis://x") - - -def test_redis_checkpointer_rejects_both_url_and_client() -> None: - with pytest.raises(ValueError, match="exactly one"): - RedisCheckpointer(url="redis://x", client=MagicMock()) - - -def test_redis_checkpointer_rejects_neither_url_nor_client() -> None: - with pytest.raises(ValueError, match="exactly one"): - RedisCheckpointer() - - -def test_redis_save_issues_set_and_zadd() -> None: - client = _redis_client_mock() - ckpt = RedisCheckpointer(client=client, ttl_seconds=42) - record = CheckpointRecord( - pipeline_name="p", - run_id="r", - sequence=1, - node_id="n", - state={"x": 1}, - completed_nodes=["n"], + ckpt.save( + CheckpointRecord( + pipeline_name="p", + run_id="r", + sequence=1, + node_id="a", + state={"k": 1}, + completed_nodes=["a"], + ) ) - ckpt.save(record) - - set_call = client.set.call_args - assert set_call.args[0] == "firefly:ckpt:p:r:000001_n" - assert json.loads(set_call.args[1])["node_id"] == "n" - assert set_call.kwargs["ex"] == 42 - - zadd_call = client.zadd.call_args - assert zadd_call.args[0] == "firefly:ckpt:p:runs" - assert "r" in zadd_call.args[1] - - -def test_redis_load_latest_picks_highest_sequence_key() -> None: - client = _redis_client_mock() - ckpt = RedisCheckpointer(client=client) - for seq in (1, 5, 3): - ckpt.save( - CheckpointRecord( - pipeline_name="p", - run_id="r", - sequence=seq, - node_id=f"node{seq}", - state={"seq": seq}, - completed_nodes=[], - ) + ckpt.save( + CheckpointRecord( + pipeline_name="p", + run_id="r", + sequence=2, + node_id="b", + state={"k": 2}, + completed_nodes=["a", "b"], ) + ) + latest = ckpt.load_latest("p", "r") assert latest is not None - assert latest.sequence == 5 - assert latest.node_id == "node5" + assert latest.node_id == "b" + assert latest.state == {"k": 2} + assert latest.completed_nodes == ["a", "b"] -def test_redis_load_latest_returns_none_when_run_unknown() -> None: - client = _redis_client_mock() - ckpt = RedisCheckpointer(client=client) - assert ckpt.load_latest("p", "missing") is None +def test_file_checkpointer_load_latest_unknown_run_returns_none(tmp_path) -> None: + assert FileCheckpointer(tmp_path).load_latest("p", "missing") is None -def test_redis_list_runs_returns_zrange_result() -> None: - client = _redis_client_mock() - ckpt = RedisCheckpointer(client=client) - for run_id in ("r1", "r2", "r3"): +def test_file_checkpointer_list_runs(tmp_path) -> None: + ckpt = FileCheckpointer(tmp_path / "ckpt") + for run_id in ("rA", "rB"): ckpt.save( CheckpointRecord( pipeline_name="p", run_id=run_id, sequence=1, - node_id="n", - state={}, - completed_nodes=[], - ) - ) - runs = ckpt.list_runs("p") - assert set(runs) == {"r1", "r2", "r3"} - - -# ============================================================================= -# PostgresCheckpointer -# ============================================================================= - - -def _postgres_connection_mock() -> tuple[MagicMock, dict[tuple, dict[str, Any]]]: - """MagicMock Connection backed by an in-memory dict shaped like the - firefly_checkpoints table. Returns ``(conn, store)`` so tests can poke the store. - """ - store: dict[tuple[str, str, int], dict[str, Any]] = {} - ddl_calls: list[str] = [] - - conn = MagicMock(name="psycopg.Connection") - - def make_cursor() -> MagicMock: - cur = MagicMock(name="psycopg.Cursor") - cur.__enter__ = MagicMock(return_value=cur) - cur.__exit__ = MagicMock(return_value=None) - cur._last_fetchone = None - cur._last_fetchall = [] - - def fake_execute(sql: str, params: tuple | None = None) -> None: - sql_lower = sql.strip().lower() - if sql_lower.startswith("create table"): - ddl_calls.append(sql) - return - if sql_lower.startswith("insert into"): - assert params is not None - key = (params[0], params[1], params[2]) - store[key] = { - "pipeline_name": params[0], - "run_id": params[1], - "sequence": params[2], - "node_id": params[3], - "state": json.loads(params[4]) if isinstance(params[4], str) else params[4], - "completed_nodes": (json.loads(params[5]) if isinstance(params[5], str) else params[5]), - } - return - if sql_lower.startswith("select pipeline_name"): - assert params is not None - matches = [v for k, v in store.items() if k[0] == params[0] and k[1] == params[1]] - matches.sort(key=lambda r: r["sequence"], reverse=True) - cur._last_fetchone = ( - ( - matches[0]["pipeline_name"], - matches[0]["run_id"], - matches[0]["sequence"], - matches[0]["node_id"], - matches[0]["state"], - matches[0]["completed_nodes"], - ) - if matches - else None - ) - return - if sql_lower.startswith("select distinct run_id"): - assert params is not None - runs = sorted({k[1] for k in store if k[0] == params[0]}) - cur._last_fetchall = [(r,) for r in runs] - return - raise AssertionError(f"unexpected SQL: {sql}") - - cur.execute.side_effect = fake_execute - cur.fetchone.side_effect = lambda: cur._last_fetchone - cur.fetchall.side_effect = lambda: cur._last_fetchall - return cur - - conn.cursor.side_effect = make_cursor - conn._ddl_calls = ddl_calls - return conn, store - - -def test_postgres_checkpointer_missing_dep_raises(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(psycopg_backend_module, "_psycopg", None) - with pytest.raises(ImportError, match=r"\[postgres\]"): - PostgresCheckpointer(dsn="postgresql://x") - - -def test_postgres_checkpointer_rejects_both_dsn_and_connection() -> None: - with pytest.raises(ValueError, match="exactly one"): - PostgresCheckpointer(dsn="postgresql://x", connection=MagicMock()) - - -def test_postgres_checkpointer_rejects_bad_table_name() -> None: - with pytest.raises(ValueError, match="Invalid table_name"): - PostgresCheckpointer(connection=MagicMock(), table_name="bad; DROP TABLE users") - - -def test_postgres_ddl_runs_once_across_many_saves() -> None: - conn, _store = _postgres_connection_mock() - ckpt = PostgresCheckpointer(connection=conn) - for seq in range(3): - ckpt.save( - CheckpointRecord( - pipeline_name="p", - run_id="r", - sequence=seq, - node_id=f"n{seq}", + node_id="a", state={}, - completed_nodes=[], + completed_nodes=["a"], ) ) - assert len(conn._ddl_calls) == 1, "DDL should run exactly once per instance" - - -def test_postgres_save_then_load_latest_round_trips() -> None: - conn, _store = _postgres_connection_mock() - ckpt = PostgresCheckpointer(connection=conn) - for seq, node_id in [(1, "a"), (5, "e"), (3, "c")]: - ckpt.save( - CheckpointRecord( - pipeline_name="p", - run_id="r", - sequence=seq, - node_id=node_id, - state={"seq": seq}, - completed_nodes=[node_id], - ) + assert ckpt.list_runs("p") == ["rA", "rB"] + assert ckpt.list_runs("missing") == [] + + +def test_file_checkpointer_paused_record_round_trips(tmp_path) -> None: + ckpt = FileCheckpointer(tmp_path) + ckpt.save( + CheckpointRecord( + pipeline_name="p", + run_id="r", + sequence=1, + node_id="await_approval", + state={"x": 1}, + completed_nodes=["a", "await_approval"], + paused=True, + pause_reason="waiting on human", ) + ) latest = ckpt.load_latest("p", "r") assert latest is not None - assert latest.sequence == 5 - assert latest.node_id == "e" - - -def test_postgres_load_latest_returns_none_when_empty() -> None: - conn, _store = _postgres_connection_mock() - ckpt = PostgresCheckpointer(connection=conn) - assert ckpt.load_latest("p", "missing") is None - - -def test_postgres_list_runs_returns_distinct_run_ids() -> None: - conn, _store = _postgres_connection_mock() - ckpt = PostgresCheckpointer(connection=conn) - for run_id in ("rA", "rB", "rA"): # rA twice, rB once - ckpt.save( - CheckpointRecord( - pipeline_name="p", - run_id=run_id, - sequence=1, - node_id="n", - state={}, - completed_nodes=[], - ) - ) - assert ckpt.list_runs("p") == ["rA", "rB"] + assert latest.paused is True + assert latest.pause_reason == "waiting on human" # ============================================================================= -# Protocol conformance — software-factory scenario across all three backends +# Protocol conformance — software-factory scenario against File backend # ============================================================================= @@ -333,7 +113,7 @@ class FactoryState(BaseModel): evaluation: str | None = None -def _build_factory(checkpointer: Any) -> StatePipeline: +def _build_factory(checkpointer) -> StatePipeline: """Construct the canonical 4-step agent pipeline that fails on first deploy.""" state_flag = {"failed_once": False} @@ -365,27 +145,10 @@ async def evaluator(state: FactoryState) -> dict: return pipeline -@pytest.fixture -def file_backend(tmp_path): - return FileCheckpointer(tmp_path / "ckpt") - - -@pytest.fixture -def redis_backend(): - return RedisCheckpointer(client=_redis_client_mock()) - - -@pytest.fixture -def postgres_backend(): - conn, _store = _postgres_connection_mock() - return PostgresCheckpointer(connection=conn) - - @pytest.mark.asyncio -@pytest.mark.parametrize("backend_fixture", ["file_backend", "redis_backend", "postgres_backend"]) -async def test_backend_supports_fail_and_resume(backend_fixture, request) -> None: - """Same scenario across all three backends: deployer fails, resume completes.""" - backend = request.getfixturevalue(backend_fixture) +async def test_file_backend_supports_fail_and_resume(tmp_path) -> None: + """Deployer fails on its first call → run is checkpointed → resume completes.""" + backend = FileCheckpointer(tmp_path / "ckpt") pipeline = _build_factory(backend) first = await pipeline.invoke(FactoryState(requirements="users service")) @@ -397,8 +160,3 @@ async def test_backend_supports_fail_and_resume(backend_fixture, request) -> Non assert second.success assert second.completed_nodes == ["architect", "python_dev", "deployer", "evaluator"] assert second.state.evaluation == "PASS https://app" - - -# Silence the unused-import warning for the autospec we don't actually use here -# but keep available for follow-up tests. -_ = create_autospec From ff5bd7c2c9fc8ce1191e703fb834a7bc625dc368 Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Thu, 28 May 2026 12:58:32 +0200 Subject: [PATCH 15/26] test(examples): move software_factory smoke test to tests/examples/ Match the layout used by other example packages (tests/examples/corpus_search/ is the established pattern). README links to the new test location. --- examples/software_factory/README.md | 8 ++++---- .../tests => tests/examples/software_factory}/__init__.py | 0 .../examples/software_factory}/test_pipeline.py | 0 3 files changed, 4 insertions(+), 4 deletions(-) rename {examples/software_factory/tests => tests/examples/software_factory}/__init__.py (100%) rename {examples/software_factory/tests => tests/examples/software_factory}/test_pipeline.py (100%) diff --git a/examples/software_factory/README.md b/examples/software_factory/README.md index 0348be9a..e92a53a6 100644 --- a/examples/software_factory/README.md +++ b/examples/software_factory/README.md @@ -130,8 +130,8 @@ software_factory/ ├── checkpointers/ │ ├── postgres.py # Checkpointer Protocol impl (psycopg) │ └── redis.py # Checkpointer Protocol impl (redis-py) -├── audit/ -│ └── postgres.py # QueryableAuditLog Protocol impl (psycopg) -└── tests/ - └── test_pipeline.py # end-to-end smoke test +└── audit/ + └── postgres.py # QueryableAuditLog Protocol impl (psycopg) ``` + +The end-to-end smoke test lives at `tests/examples/software_factory/test_pipeline.py` — same shape as the other example tests in this repo. diff --git a/examples/software_factory/tests/__init__.py b/tests/examples/software_factory/__init__.py similarity index 100% rename from examples/software_factory/tests/__init__.py rename to tests/examples/software_factory/__init__.py diff --git a/examples/software_factory/tests/test_pipeline.py b/tests/examples/software_factory/test_pipeline.py similarity index 100% rename from examples/software_factory/tests/test_pipeline.py rename to tests/examples/software_factory/test_pipeline.py From d8c54bb4ce6fd7ece9c9c34875fd669b489b4cfa Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Thu, 28 May 2026 14:15:28 +0200 Subject: [PATCH 16/26] feat(pipeline): PipelineEngine checkpoint, audit, and resume (#245 layer 1) First layer of the unification proposal in #245. Port-based PipelineEngine gains the same checkpoint + audit machinery StatePipeline already has, plus resume via run(run_id=...). - PipelineEngine accepts checkpointer + audit_log kwargs - run_id flows into PipelineResult; auto-generated when not supplied - After every successful node: checkpoint written and audit entry recorded - Failed (non-skipped) nodes only audit; resume re-runs them - Skipped nodes neither audit nor checkpoint (no work happened) - run(run_id=...) loads the latest checkpoint, restores context.results for completed nodes, and continues from successors - Resume without checkpointer raises PipelineError; unknown run_id raises - Module-level _serialize_value helper makes checkpoint state and audit snapshots JSON-safe (Pydantic -> model_dump, primitives passthrough, fallback str()) No changes to StatePipeline. No new public types. PipelineEventHandler unchanged (protocol unification is deferred to a later layer). Tests: 13 new in tests/unit/pipeline/test_pipeline_engine_lifecycle.py. Full suite: 1544 passed. Refs: #245 --- fireflyframework_agentic/pipeline/engine.py | 198 +++++++++++++- fireflyframework_agentic/pipeline/result.py | 2 + .../test_pipeline_engine_lifecycle.py | 252 ++++++++++++++++++ 3 files changed, 443 insertions(+), 9 deletions(-) create mode 100644 tests/unit/pipeline/test_pipeline_engine_lifecycle.py diff --git a/fireflyframework_agentic/pipeline/engine.py b/fireflyframework_agentic/pipeline/engine.py index 4f1e46e6..4ac4fc4f 100644 --- a/fireflyframework_agentic/pipeline/engine.py +++ b/fireflyframework_agentic/pipeline/engine.py @@ -21,6 +21,7 @@ import logging import random import time +import uuid from datetime import UTC, datetime from typing import Any, Protocol, runtime_checkable @@ -30,7 +31,10 @@ otel_trace = None # type: ignore[assignment] from fireflyframework_agentic.config import get_config +from fireflyframework_agentic.exceptions import PipelineError from fireflyframework_agentic.observability.usage import default_usage_tracker +from fireflyframework_agentic.pipeline.audit import AuditEntry, AuditLog, AuditStatus +from fireflyframework_agentic.pipeline.checkpoint import Checkpointer, CheckpointRecord from fireflyframework_agentic.pipeline.context import PipelineContext from fireflyframework_agentic.pipeline.dag import DAG, FailureStrategy from fireflyframework_agentic.pipeline.result import ( @@ -111,6 +115,28 @@ async def on_pipeline_complete(self, pipeline_name: str, run_id: str, success: b ... +def _serialize_value(value: Any) -> Any: + """Best-effort conversion of arbitrary values into JSON-safe form. + + Pydantic models go through ``model_dump(mode="json")``. Primitives, + lists, and dicts pass through. Anything else falls back to ``str()`` + so the serialization layer (checkpoint, audit) doesn't blow up on + exotic objects. + """ + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, dict): + return {k: _serialize_value(v) for k, v in value.items()} + if isinstance(value, (list, tuple)): + return [_serialize_value(v) for v in value] + if hasattr(value, "model_dump"): + try: + return value.model_dump(mode="json") + except Exception: + return str(value) + return str(value) + + def start_otel_span(name: str, **attributes: Any) -> Any: """Start an OTel span if observability is enabled, else return ``None``. @@ -143,27 +169,52 @@ def __init__( dag: DAG, *, event_handler: PipelineEventHandler | None = None, + checkpointer: Checkpointer | None = None, + audit_log: AuditLog | None = None, ) -> None: self._dag = dag self._event_handler = event_handler + self._checkpointer = checkpointer + self._audit_log = audit_log async def run( self, context: PipelineContext | None = None, *, inputs: Any = None, + run_id: str | None = None, ) -> PipelineResult: """Execute the pipeline. Parameters: context: Pre-built context, or *None* to create one automatically. inputs: Initial inputs (used if *context* is not provided). + run_id: Identifier for this run. When given alone (no ``context`` + and no ``inputs``), the engine loads the latest checkpoint for + that run and resumes from after the last completed node. + Requires a checkpointer to be configured. Returns: - A :class:`PipelineResult` with all node outputs and trace. + A :class:`PipelineResult` with all node outputs, trace, and + ``run_id`` (use to resume later). """ - if context is None: - context = PipelineContext(inputs=inputs) + if run_id is not None and context is None and inputs is None: + resume_run_id: str = run_id + context, pre_completed, sequence_start = self._load_for_resume(resume_run_id) + all_results: dict[str, NodeResult] = { + nid: nr + for nid in pre_completed + if (nr := context.get_node_result(nid)) is not None and isinstance(nr, NodeResult) + } + else: + if context is None: + context = PipelineContext(inputs=inputs) + pre_completed = set() + sequence_start = 0 + all_results = {} + + if run_id is None: + run_id = uuid.uuid4().hex[:12] # Observability: pipeline-level span _pipeline_span = self._start_otel_span( @@ -176,7 +227,6 @@ async def run( # level are independent and run concurrently via asyncio.gather. levels = self._dag.execution_levels() trace_entries: list[ExecutionTraceEntry] = [] - all_results: dict[str, NodeResult] = {} pipeline_start = time.perf_counter() failed_nodes: set[str] = set() @@ -188,9 +238,12 @@ async def run( pending: set[str] = set() for level in levels: pending.update(level) + pending -= pre_completed # resume: don't re-run nodes already completed - completed: set[str] = set() + completed: set[str] = set(pre_completed) running: dict[str, asyncio.Task[NodeResult]] = {} + inputs_by_node: dict[str, dict[str, Any]] = {} + sequence = sequence_start abort = False def _ready(nid: str) -> bool: @@ -203,8 +256,12 @@ def _ready(nid: str) -> bool: if not abort: for nid in list(pending): if _ready(nid) and nid not in running: + # Gather inputs outside _execute_node so we can stash + # them for the audit snapshot. + gathered = self._gather_inputs(nid, context) + inputs_by_node[nid] = gathered task = asyncio.create_task( - self._execute_node(nid, context, trace_entries, failed_nodes), + self._execute_node(nid, context, trace_entries, failed_nodes, inputs=gathered), ) running[nid] = task pending.discard(nid) @@ -239,6 +296,26 @@ def _ready(nid: str) -> bool: # Emit event callbacks await self._emit_node_result(nr) + # Persist lifecycle: audit every executed visit; checkpoint only + # successful completions (failed nodes must re-run on resume). + sequence += 1 + self._record_audit( + run_id=run_id, + node_id=node_id, + sequence=sequence, + nr=nr, + inputs_snapshot=inputs_by_node.get(node_id, {}), + trace_entries=trace_entries, + ) + if nr.success and not nr.skipped: + self._save_checkpoint( + run_id=run_id, + node_id=node_id, + sequence=sequence, + context=context, + all_results=all_results, + ) + # Handle failure strategies if not nr.success and not nr.skipped: node = self._dag.nodes.get(node_id) @@ -290,6 +367,7 @@ def _ready(nid: str) -> bool: total_duration_ms=pipeline_elapsed, success=success, usage=usage_summary, + run_id=run_id, ) async def _execute_node( @@ -298,8 +376,14 @@ async def _execute_node( context: PipelineContext, trace_entries: list[ExecutionTraceEntry], failed_nodes: set[str] | None = None, + *, + inputs: dict[str, Any] | None = None, ) -> NodeResult: - """Execute a single node with retries and condition gating.""" + """Execute a single node with retries and condition gating. + + ``inputs`` may be pre-gathered by the caller so the same dict can be + used both for execution and for the audit log's inputs snapshot. + """ # Skip if an upstream node failed with SKIP_DOWNSTREAM strategy if failed_nodes and node_id in failed_nodes: logger.debug("Node '%s' skipped (upstream failure)", node_id) @@ -318,8 +402,9 @@ async def _execute_node( logger.debug("Node '%s' skipped (condition not met)", node_id) return NodeResult(node_id=node_id, skipped=True) - # Gather inputs from upstream edges - inputs = self._gather_inputs(node_id, context) + # Gather inputs from upstream edges (unless caller already did) + if inputs is None: + inputs = self._gather_inputs(node_id, context) _node_span = self._start_otel_span( f"pipeline.node.{node_id}", @@ -447,6 +532,101 @@ async def _emit_node_result(self, nr: NodeResult) -> None: except Exception: # noqa: BLE001 pass + def _load_for_resume(self, run_id: str) -> tuple[PipelineContext, set[str], int]: + """Rebuild context + completed-set from the latest checkpoint.""" + if self._checkpointer is None: + raise PipelineError("Cannot resume: pipeline has no checkpointer configured") + record = self._checkpointer.load_latest(self._dag.name, run_id) + if record is None: + raise PipelineError(f"No checkpoint found for run_id='{run_id}'") + context = PipelineContext(inputs=record.state.get("inputs")) + for nid, nr_dict in record.state.get("results", {}).items(): + try: + context.set_node_result(nid, NodeResult.model_validate(nr_dict)) + except Exception: + logger.warning("Could not restore NodeResult for '%s' on resume", nid) + return context, set(record.completed_nodes), record.sequence + + def _save_checkpoint( + self, + *, + run_id: str, + node_id: str, + sequence: int, + context: PipelineContext, + all_results: dict[str, NodeResult], + ) -> None: + """Persist state after a successful node. No-op if no checkpointer. + + Only successful (non-skipped) nodes go into ``completed_nodes`` so + that resume re-attempts the failures. + """ + if self._checkpointer is None: + return + completed_successful = [nid for nid, nr in all_results.items() if nr.success and not nr.skipped] + state = { + "inputs": _serialize_value(context.inputs), + "results": {nid: all_results[nid].model_dump(mode="json") for nid in completed_successful}, + } + try: + self._checkpointer.save( + CheckpointRecord( + pipeline_name=self._dag.name, + run_id=run_id, + node_id=node_id, + sequence=sequence, + state=state, + completed_nodes=completed_successful, + ) + ) + except Exception: + logger.exception("Checkpoint save failed for run '%s' at '%s'", run_id, node_id) + + def _record_audit( + self, + *, + run_id: str, + node_id: str, + sequence: int, + nr: NodeResult, + inputs_snapshot: dict[str, Any], + trace_entries: list[ExecutionTraceEntry], + ) -> None: + """Write an audit entry for a node visit. No-op if no audit log. + + Skipped nodes are not recorded — they represent work that did NOT + happen and would clutter the trail. + """ + if self._audit_log is None or nr.skipped: + return + # Pull timing from the trace entry the node just wrote. + started_at = completed_at = datetime.now(UTC) + for te in reversed(trace_entries): + if te.node_id == node_id: + started_at = te.started_at + completed_at = te.completed_at + break + status: AuditStatus = "success" if nr.success else "error" + outputs: dict[str, Any] = {"output": _serialize_value(nr.output)} if nr.success else {} + entry = AuditEntry( + pipeline_name=self._dag.name, + run_id=run_id, + node_id=node_id, + sequence=sequence, + visit=1, + started_at=started_at, + completed_at=completed_at, + latency_ms=nr.latency_ms or 0.0, + status=status, + inputs_snapshot={k: _serialize_value(v) for k, v in inputs_snapshot.items()}, + outputs_snapshot=outputs, + error_message=nr.error if not nr.success else None, + ) + try: + self._audit_log.record(entry) + except Exception: + logger.exception("Audit log write failed for run '%s' at '%s'", run_id, node_id) + def _gather_inputs(self, node_id: str, context: PipelineContext) -> dict[str, Any]: """Collect inputs for a node from its upstream edges.""" edges = self._dag.incoming_edges(node_id) diff --git a/fireflyframework_agentic/pipeline/result.py b/fireflyframework_agentic/pipeline/result.py index dff1ffff..929b478d 100644 --- a/fireflyframework_agentic/pipeline/result.py +++ b/fireflyframework_agentic/pipeline/result.py @@ -68,6 +68,7 @@ class PipelineResult(BaseModel): total_duration_ms: End-to-end pipeline execution time. success: Whether all nodes completed successfully. usage: Aggregated token usage across all pipeline nodes. + run_id: Identifier for this run; resume with ``engine.run(run_id=...)``. """ pipeline_name: str = "" @@ -77,6 +78,7 @@ class PipelineResult(BaseModel): total_duration_ms: float = 0.0 success: bool = True usage: UsageSummary | None = None + run_id: str = "" @property def failed_nodes(self) -> list[str]: diff --git a/tests/unit/pipeline/test_pipeline_engine_lifecycle.py b/tests/unit/pipeline/test_pipeline_engine_lifecycle.py new file mode 100644 index 00000000..157dc431 --- /dev/null +++ b/tests/unit/pipeline/test_pipeline_engine_lifecycle.py @@ -0,0 +1,252 @@ +"""Layer 1 of the unification (#245): PipelineEngine gains checkpoint, audit, and resume. + +These tests pin the contract for port-based pipelines to opt into the same +checkpointing + audit machinery that StatePipeline already has — without +becoming state-based. Resume via ``run(run_id=...)`` is the headline feature. +""" + +from __future__ import annotations + +import pytest + +from fireflyframework_agentic.exceptions import PipelineError +from fireflyframework_agentic.pipeline.audit import FileAuditLog +from fireflyframework_agentic.pipeline.checkpoint import FileCheckpointer +from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode, FailureStrategy +from fireflyframework_agentic.pipeline.engine import PipelineEngine + + +class _CountingStep: + """Step that records how many times its .execute() was called.""" + + def __init__(self, prefix: str = "") -> None: + self._prefix = prefix + self.calls = 0 + + async def execute(self, ctx, inputs): + self.calls += 1 + val = inputs.get("input", "") + return f"{self._prefix}{val}" + + +class _FailOnceStep: + """Step that raises on the first call and succeeds afterward.""" + + def __init__(self) -> None: + self.calls = 0 + + async def execute(self, ctx, inputs): + self.calls += 1 + if self.calls == 1: + raise RuntimeError("flake") + return "b:done" + + +def _chain_dag(*node_ids: str) -> tuple[DAG, dict[str, _CountingStep]]: + dag = DAG("chain") + steps: dict[str, _CountingStep] = {} + for nid in node_ids: + step = _CountingStep(f"{nid}:") + steps[nid] = step + dag.add_node(DAGNode(node_id=nid, step=step)) + for i in range(len(node_ids) - 1): + dag.add_edge(DAGEdge(source=node_ids[i], target=node_ids[i + 1])) + return dag, steps + + +# ---- run_id ----------------------------------------------------------------- + + +async def test_run_returns_non_empty_run_id(): + dag, _ = _chain_dag("a", "b") + engine = PipelineEngine(dag) + result = await engine.run(inputs="x") + assert result.success + assert result.run_id # non-empty + + +async def test_explicit_run_id_is_preserved(): + dag, _ = _chain_dag("a") + engine = PipelineEngine(dag) + result = await engine.run(inputs="x", run_id="manual-id") + assert result.run_id == "manual-id" + + +# ---- checkpointing --------------------------------------------------------- + + +async def test_checkpoint_written_per_successful_node(tmp_path): + dag, _ = _chain_dag("a", "b", "c") + cp = FileCheckpointer(tmp_path) + engine = PipelineEngine(dag, checkpointer=cp) + result = await engine.run(inputs="x") + assert result.success + files = sorted((tmp_path / "chain" / result.run_id).glob("*.json")) + assert len(files) == 3 + # Sequence prefix preserves completion order. + assert files[0].name.endswith("_a.json") + assert files[1].name.endswith("_b.json") + assert files[2].name.endswith("_c.json") + + +async def test_checkpoint_omitted_when_no_checkpointer(tmp_path): + dag, _ = _chain_dag("a", "b") + engine = PipelineEngine(dag) # no checkpointer + result = await engine.run(inputs="x") + assert result.success + # tmp_path should still be empty since no checkpointer was wired. + assert not any(tmp_path.iterdir()) + + +async def test_checkpoint_records_completed_nodes(tmp_path): + dag, _ = _chain_dag("a", "b") + cp = FileCheckpointer(tmp_path) + engine = PipelineEngine(dag, checkpointer=cp) + result = await engine.run(inputs="x") + record = cp.load_latest("chain", result.run_id) + assert record is not None + assert record.completed_nodes == ["a", "b"] + assert record.node_id == "b" + + +# ---- resume ---------------------------------------------------------------- + + +async def test_resume_completed_run_is_a_noop(tmp_path): + dag, steps = _chain_dag("a", "b", "c") + cp = FileCheckpointer(tmp_path) + engine = PipelineEngine(dag, checkpointer=cp) + result = await engine.run(inputs="x") + assert all(s.calls == 1 for s in steps.values()) + # All nodes are completed; resume should not re-execute anything. + result2 = await engine.run(run_id=result.run_id) + assert result2.success + assert all(s.calls == 1 for s in steps.values()) + + +async def test_resume_after_failure_skips_completed_and_finishes(tmp_path): + a_step = _CountingStep("a:") + b_step = _FailOnceStep() + c_step = _CountingStep("c:") + dag = DAG("recoverable") + dag.add_node(DAGNode(node_id="a", step=a_step)) + dag.add_node(DAGNode(node_id="b", step=b_step, failure_strategy=FailureStrategy.FAIL_PIPELINE)) + dag.add_node(DAGNode(node_id="c", step=c_step)) + dag.add_edge(DAGEdge(source="a", target="b")) + dag.add_edge(DAGEdge(source="b", target="c")) + + cp = FileCheckpointer(tmp_path) + engine = PipelineEngine(dag, checkpointer=cp) + + result1 = await engine.run(inputs="x") + assert not result1.success + assert a_step.calls == 1 + assert b_step.calls == 1 + assert c_step.calls == 0 + + result2 = await engine.run(run_id=result1.run_id) + assert result2.success + # 'a' was already done — must not be re-executed on resume. + assert a_step.calls == 1 + # 'b' is re-executed (its second attempt succeeds via _FailOnceStep). + assert b_step.calls == 2 + # 'c' runs once, after b succeeds on resume. + assert c_step.calls == 1 + + +async def test_resume_without_checkpointer_raises(): + dag, _ = _chain_dag("a") + engine = PipelineEngine(dag) + with pytest.raises(PipelineError, match="checkpoint"): + await engine.run(run_id="anything") + + +async def test_resume_unknown_run_id_raises(tmp_path): + dag, _ = _chain_dag("a") + cp = FileCheckpointer(tmp_path) + engine = PipelineEngine(dag, checkpointer=cp) + with pytest.raises(PipelineError, match="No checkpoint"): + await engine.run(run_id="missing") + + +# ---- audit log ------------------------------------------------------------- + + +async def test_audit_log_writes_entry_per_node(tmp_path): + dag, _ = _chain_dag("a", "b") + al = FileAuditLog(tmp_path) + engine = PipelineEngine(dag, audit_log=al) + result = await engine.run(inputs="x") + entries = al.list_entries("chain", result.run_id) + assert len(entries) == 2 + assert [e.node_id for e in entries] == ["a", "b"] + assert all(e.status == "success" for e in entries) + assert all(e.visit == 1 for e in entries) + assert all(e.latency_ms >= 0 for e in entries) + + +async def test_audit_log_captures_failure(tmp_path): + class _Bad: + async def execute(self, ctx, inputs): + raise RuntimeError("boom") + + dag = DAG("fail") + dag.add_node(DAGNode(node_id="bad", step=_Bad(), failure_strategy=FailureStrategy.FAIL_PIPELINE)) + al = FileAuditLog(tmp_path) + engine = PipelineEngine(dag, audit_log=al) + result = await engine.run(inputs="x") + assert not result.success + entries = al.list_entries("fail", result.run_id) + assert len(entries) == 1 + assert entries[0].status == "error" + assert entries[0].error_message is not None + assert "boom" in entries[0].error_message + + +async def test_audit_skipped_nodes_not_recorded_as_success(tmp_path): + """Skipped nodes (condition gate) shouldn't show up as successful audits.""" + step = _CountingStep("a:") + dag = DAG("skipping") + dag.add_node(DAGNode(node_id="skipped", step=step, condition=lambda ctx: False)) + al = FileAuditLog(tmp_path) + engine = PipelineEngine(dag, audit_log=al) + await engine.run(inputs="x") + entries = al.list_entries("skipping", _last_run_id(al, "skipping")) + # Skipped nodes are not work that happened — leave them out. + assert entries == [] or all(e.status != "success" for e in entries) + + +def _last_run_id(al: FileAuditLog, pipeline: str) -> str: + pipeline_dir = al._root / pipeline + if not pipeline_dir.exists(): + return "" + files = list(pipeline_dir.glob("*.jsonl")) + return files[0].stem if files else "" + + +# ---- combined checkpoint + audit + resume ---------------------------------- + + +async def test_full_stack_resume_with_audit(tmp_path): + cp_dir = tmp_path / "cp" + al_dir = tmp_path / "al" + a_step = _CountingStep("a:") + b_step = _FailOnceStep() + dag = DAG("full") + dag.add_node(DAGNode(node_id="a", step=a_step)) + dag.add_node(DAGNode(node_id="b", step=b_step, failure_strategy=FailureStrategy.FAIL_PIPELINE)) + dag.add_edge(DAGEdge(source="a", target="b")) + cp = FileCheckpointer(cp_dir) + al = FileAuditLog(al_dir) + engine = PipelineEngine(dag, checkpointer=cp, audit_log=al) + + r1 = await engine.run(inputs="x") + assert not r1.success + r2 = await engine.run(run_id=r1.run_id) + assert r2.success + entries = al.list_entries("full", r1.run_id) + # Three entries: a-success, b-error (first attempt), b-success (resume). + assert len(entries) == 3 + assert entries[0].node_id == "a" and entries[0].status == "success" + assert entries[1].node_id == "b" and entries[1].status == "error" + assert entries[2].node_id == "b" and entries[2].status == "success" From 0f8b17393c54966a40bf9e30a0ef893d778c5116 Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Thu, 28 May 2026 14:30:20 +0200 Subject: [PATCH 17/26] feat(pipeline): unified EventHandler protocol (#245 layer 1B) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Completes layer 1 by collapsing PipelineEventHandler + StatePipelineEventHandler into a single EventHandler protocol. The two old protocols stay as deprecated aliases for backward compatibility; the engine adapts to either via signature-inspection dispatch. New surface: - EventHandler protocol with the union of all callbacks (on_pipeline_start, on_node_start with visit, on_node_complete, on_node_error, on_node_skip, on_node_pause, on_pipeline_complete) — every method carries run_id. - PipelineEngine._dispatch(method_name, **kwargs) inspects each callback's signature once (cached) and passes only the parameters the method declares. That keeps legacy PipelineEventHandler implementations working unchanged — the engine silently drops run_id / visit when the method doesn't accept them. - PipelineEngine now emits on_pipeline_start at run() start (was only state-side). - _execute_node forwards run_id so on_node_start callbacks can correlate. - All previous direct event_handler method calls in engine.py go through _dispatch. Backward compatibility: - PipelineEventHandler and StatePipelineEventHandler stay in place with docstrings marking them legacy. Existing implementations of either continue to work without modification. - New code should implement EventHandler. Tests: 7 new in tests/unit/pipeline/test_pipeline_engine_event_dispatch.py covering unified handler (run_id + visit received), legacy handler (run_id dropped silently), mixed handler (some methods rich, some legacy), exception swallowing. Existing test_event_handler.py (legacy shape) still passes unchanged. Full suite: 1551 passed. Refs: #245 --- fireflyframework_agentic/pipeline/__init__.py | 2 + fireflyframework_agentic/pipeline/engine.py | 226 +++++++++++------- .../test_pipeline_engine_event_dispatch.py | 163 +++++++++++++ 3 files changed, 304 insertions(+), 87 deletions(-) create mode 100644 tests/unit/pipeline/test_pipeline_engine_event_dispatch.py diff --git a/fireflyframework_agentic/pipeline/__init__.py b/fireflyframework_agentic/pipeline/__init__.py index a6221a3d..4fa01bd3 100644 --- a/fireflyframework_agentic/pipeline/__init__.py +++ b/fireflyframework_agentic/pipeline/__init__.py @@ -45,6 +45,7 @@ from fireflyframework_agentic.pipeline.context import PipelineContext from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode, FailureStrategy from fireflyframework_agentic.pipeline.engine import ( + EventHandler, PipelineEngine, PipelineEventHandler, StatePipelineEventHandler, @@ -86,6 +87,7 @@ "EmbeddingStep", "ExecutionTraceEntry", "FailureStrategy", + "EventHandler", "FanInStep", "FanOutStep", "FileAuditLog", diff --git a/fireflyframework_agentic/pipeline/engine.py b/fireflyframework_agentic/pipeline/engine.py index 4ac4fc4f..ab4412e3 100644 --- a/fireflyframework_agentic/pipeline/engine.py +++ b/fireflyframework_agentic/pipeline/engine.py @@ -18,6 +18,7 @@ import asyncio import contextlib +import inspect import logging import random import time @@ -47,72 +48,83 @@ @runtime_checkable -class PipelineEventHandler(Protocol): - """Protocol for pipeline progress callbacks. +class EventHandler(Protocol): + """Unified pipeline event handler. Used by :class:`PipelineEngine` and + :class:`fireflyframework_agentic.pipeline.state_pipeline.StatePipeline`. - Implement any subset of these methods to receive notifications - when pipeline nodes start, complete, or fail. + Implement any subset of these methods; missing ones are no-ops. Exceptions + raised in callbacks are swallowed by the engine so observability never + breaks business logic. + + The engine dispatches events by parameter name. If your method signature + omits a parameter — e.g. legacy implementations that don't accept + ``run_id`` or ``visit`` — the engine simply drops it from the call. + That keeps legacy :class:`PipelineEventHandler` / + :class:`StatePipelineEventHandler` implementations working during the + transition to this unified shape. + + Parameter conventions: + + * ``pipeline_name`` — DAG name, always present. + * ``run_id`` — opaque identifier for a single invocation; lets ops + correlate events across resumes and across multiple parallel runs. + * ``visit`` — re-entry counter on cyclic graphs and fan-out. Starts at + 1 and increments each time a node is re-entered. + * ``latency_ms`` — node wall-clock time, captured at the engine level. + * ``reason`` — human-readable string; for skips and pauses. """ - async def on_node_start(self, node_id: str, pipeline_name: str) -> None: - """Called when a node begins execution.""" - ... + async def on_pipeline_start(self, pipeline_name: str, run_id: str) -> None: ... - async def on_node_complete(self, node_id: str, pipeline_name: str, latency_ms: float) -> None: - """Called when a node completes successfully.""" - ... + async def on_node_start(self, pipeline_name: str, run_id: str, node_id: str, visit: int) -> None: ... - async def on_node_error(self, node_id: str, pipeline_name: str, error: str) -> None: - """Called when a node fails (after all retries exhausted).""" - ... + async def on_node_complete(self, pipeline_name: str, run_id: str, node_id: str, latency_ms: float) -> None: ... - async def on_node_skip(self, node_id: str, pipeline_name: str, reason: str) -> None: - """Called when a node is skipped.""" - ... + async def on_node_error(self, pipeline_name: str, run_id: str, node_id: str, error: str) -> None: ... - async def on_pipeline_complete(self, pipeline_name: str, success: bool, duration_ms: float) -> None: - """Called when the entire pipeline finishes.""" - ... + async def on_node_skip(self, pipeline_name: str, run_id: str, node_id: str, reason: str) -> None: ... + async def on_node_pause(self, pipeline_name: str, run_id: str, node_id: str, reason: str) -> None: ... -@runtime_checkable -class StatePipelineEventHandler(Protocol): - """Protocol for state-pipeline progress callbacks. + async def on_pipeline_complete( + self, pipeline_name: str, run_id: str, success: bool, duration_ms: float + ) -> None: ... - Mirrors :class:`PipelineEventHandler` but every callback carries the - ``run_id`` so ops can correlate events across resumes, and - ``on_node_start`` carries a ``visit`` counter so cyclic graphs are - distinguishable per iteration. There is no ``on_node_skip`` — state - pipelines abort on failure rather than skipping downstream nodes. - Implement any subset of methods; missing ones are no-ops. - """ +@runtime_checkable +class PipelineEventHandler(Protocol): + """Legacy port-based event handler protocol. Use :class:`EventHandler`. - async def on_pipeline_start(self, pipeline_name: str, run_id: str) -> None: - """Called once when ``invoke`` begins.""" - ... + Kept for backward compatibility. The engine inspects each callback's + signature and only passes parameters the method declares — so existing + implementations of this protocol continue to work unchanged. New code + should implement :class:`EventHandler` so it receives ``run_id`` and + ``visit`` too. + """ - async def on_node_start(self, pipeline_name: str, run_id: str, node_id: str, visit: int) -> None: - """Called each time a node is about to run. ``visit`` starts at 1 - and increments per re-entry (cycles, Send fan-out).""" - ... + async def on_node_start(self, node_id: str, pipeline_name: str) -> None: ... + async def on_node_complete(self, node_id: str, pipeline_name: str, latency_ms: float) -> None: ... + async def on_node_error(self, node_id: str, pipeline_name: str, error: str) -> None: ... + async def on_node_skip(self, node_id: str, pipeline_name: str, reason: str) -> None: ... + async def on_pipeline_complete(self, pipeline_name: str, success: bool, duration_ms: float) -> None: ... - async def on_node_complete(self, pipeline_name: str, run_id: str, node_id: str, latency_ms: float) -> None: - """Called when a node completes successfully.""" - ... - async def on_node_error(self, pipeline_name: str, run_id: str, node_id: str, error: str) -> None: - """Called when a node raises an exception.""" - ... +@runtime_checkable +class StatePipelineEventHandler(Protocol): + """Legacy state-pipeline event handler protocol. Use :class:`EventHandler`. - async def on_node_pause(self, pipeline_name: str, run_id: str, node_id: str, reason: str) -> None: - """Called when a node returns :class:`Pause`, halting the pipeline - until an external ``invoke(run_id=..., approve_pause=True)`` resumes it.""" - ... + Same shape as :class:`EventHandler` minus ``on_node_skip``. Kept for + backward compatibility; new code should implement :class:`EventHandler`. + """ - async def on_pipeline_complete(self, pipeline_name: str, run_id: str, success: bool, duration_ms: float) -> None: - """Called once when ``invoke`` returns.""" - ... + async def on_pipeline_start(self, pipeline_name: str, run_id: str) -> None: ... + async def on_node_start(self, pipeline_name: str, run_id: str, node_id: str, visit: int) -> None: ... + async def on_node_complete(self, pipeline_name: str, run_id: str, node_id: str, latency_ms: float) -> None: ... + async def on_node_error(self, pipeline_name: str, run_id: str, node_id: str, error: str) -> None: ... + async def on_node_pause(self, pipeline_name: str, run_id: str, node_id: str, reason: str) -> None: ... + async def on_pipeline_complete( + self, pipeline_name: str, run_id: str, success: bool, duration_ms: float + ) -> None: ... def _serialize_value(value: Any) -> Any: @@ -168,7 +180,7 @@ def __init__( self, dag: DAG, *, - event_handler: PipelineEventHandler | None = None, + event_handler: EventHandler | PipelineEventHandler | None = None, checkpointer: Checkpointer | None = None, audit_log: AuditLog | None = None, ) -> None: @@ -176,6 +188,43 @@ def __init__( self._event_handler = event_handler self._checkpointer = checkpointer self._audit_log = audit_log + # Per-method signature cache for legacy-vs-unified dispatch. + self._handler_params: dict[str, set[str]] = {} + + async def _dispatch(self, method_name: str, /, **kwargs: Any) -> None: + """Invoke ``event_handler.method_name`` with the subset of ``kwargs`` + the method's signature actually declares. + + Lets the engine emit events using the unified :class:`EventHandler` + convention while still supporting legacy + :class:`PipelineEventHandler` implementations whose methods don't + accept ``run_id`` or ``visit``. Missing methods and raised + exceptions are silently swallowed — observability never breaks the + pipeline. + """ + if self._event_handler is None: + return + method = getattr(self._event_handler, method_name, None) + if method is None: + return + if method_name not in self._handler_params: + try: + params = inspect.signature(method).parameters + self._handler_params[method_name] = { + name + for name, p in params.items() + if p.kind + in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + } + except (TypeError, ValueError): + self._handler_params[method_name] = set(kwargs) + accepted = self._handler_params[method_name] + call_kwargs = {k: v for k, v in kwargs.items() if k in accepted} + with contextlib.suppress(Exception): + await method(**call_kwargs) async def run( self, @@ -216,11 +265,12 @@ async def run( if run_id is None: run_id = uuid.uuid4().hex[:12] - # Observability: pipeline-level span + # Observability: pipeline-level span + start event _pipeline_span = self._start_otel_span( f"pipeline.{self._dag.name}", pipeline=self._dag.name, ) + await self._dispatch("on_pipeline_start", pipeline_name=self._dag.name, run_id=run_id) # Topological levels ensure that all upstream dependencies of a node # complete before the node itself executes. Nodes within the same @@ -261,7 +311,14 @@ def _ready(nid: str) -> bool: gathered = self._gather_inputs(nid, context) inputs_by_node[nid] = gathered task = asyncio.create_task( - self._execute_node(nid, context, trace_entries, failed_nodes, inputs=gathered), + self._execute_node( + nid, + context, + trace_entries, + failed_nodes, + inputs=gathered, + run_id=run_id, + ), ) running[nid] = task pending.discard(nid) @@ -294,7 +351,7 @@ def _ready(nid: str) -> bool: context.set_node_result(node_id, nr) # Emit event callbacks - await self._emit_node_result(nr) + await self._emit_node_result(nr, run_id) # Persist lifecycle: audit every executed visit; checkpoint only # successful completions (failed nodes must re-run on resume). @@ -345,13 +402,13 @@ def _ready(nid: str) -> bool: success = all(r.success or r.skipped for r in all_results.values()) # Emit pipeline complete event - if self._event_handler is not None and hasattr(self._event_handler, "on_pipeline_complete"): - with contextlib.suppress(Exception): - await self._event_handler.on_pipeline_complete( - self._dag.name, - success, - pipeline_elapsed, - ) + await self._dispatch( + "on_pipeline_complete", + pipeline_name=self._dag.name, + run_id=run_id, + success=success, + duration_ms=pipeline_elapsed, + ) # Aggregate usage across all nodes for this pipeline run usage_summary = self._aggregate_usage(context.correlation_id) @@ -378,6 +435,7 @@ async def _execute_node( failed_nodes: set[str] | None = None, *, inputs: dict[str, Any] | None = None, + run_id: str = "", ) -> NodeResult: """Execute a single node with retries and condition gating. @@ -411,10 +469,14 @@ async def _execute_node( node=node_id, ) - # Emit node start event - if self._event_handler is not None and hasattr(self._event_handler, "on_node_start"): - with contextlib.suppress(Exception): - await self._event_handler.on_node_start(node_id, self._dag.name) + # Emit node start event (visit=1; cycles arrive in a later layer) + await self._dispatch( + "on_node_start", + pipeline_name=self._dag.name, + run_id=run_id, + node_id=node_id, + visit=1, + ) max_retries = node.retry_max backoff_factor = node.backoff_factor @@ -506,31 +568,21 @@ def _aggregate_usage(correlation_id: str) -> Any: except Exception: # noqa: BLE001 return None - async def _emit_node_result(self, nr: NodeResult) -> None: - """Emit event handler callbacks for a completed node.""" + async def _emit_node_result(self, nr: NodeResult, run_id: str) -> None: + """Emit handler callbacks for a completed node via :meth:`_dispatch`.""" if self._event_handler is None: return - try: - if nr.skipped and hasattr(self._event_handler, "on_node_skip"): - await self._event_handler.on_node_skip( - nr.node_id, - self._dag.name, - nr.error or "skipped", - ) - elif nr.success and hasattr(self._event_handler, "on_node_complete"): - await self._event_handler.on_node_complete( - nr.node_id, - self._dag.name, - nr.latency_ms or 0.0, - ) - elif not nr.success and hasattr(self._event_handler, "on_node_error"): - await self._event_handler.on_node_error( - nr.node_id, - self._dag.name, - nr.error or "unknown", - ) - except Exception: # noqa: BLE001 - pass + common = { + "pipeline_name": self._dag.name, + "run_id": run_id, + "node_id": nr.node_id, + } + if nr.skipped: + await self._dispatch("on_node_skip", reason=nr.error or "skipped", **common) + elif nr.success: + await self._dispatch("on_node_complete", latency_ms=nr.latency_ms or 0.0, **common) + else: + await self._dispatch("on_node_error", error=nr.error or "unknown", **common) def _load_for_resume(self, run_id: str) -> tuple[PipelineContext, set[str], int]: """Rebuild context + completed-set from the latest checkpoint.""" diff --git a/tests/unit/pipeline/test_pipeline_engine_event_dispatch.py b/tests/unit/pipeline/test_pipeline_engine_event_dispatch.py new file mode 100644 index 00000000..86125515 --- /dev/null +++ b/tests/unit/pipeline/test_pipeline_engine_event_dispatch.py @@ -0,0 +1,163 @@ +"""Layer 1B of the unification (#245): unified EventHandler protocol. + +PipelineEngine now uses a single :class:`EventHandler` protocol that +includes ``run_id`` and ``visit`` on every callback, plus +``on_pipeline_start`` and ``on_node_pause``. Dispatch is by parameter name +via signature inspection, so legacy :class:`PipelineEventHandler` +implementations (port-based, run_id-unaware) still receive the events +they declared. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode +from fireflyframework_agentic.pipeline.engine import PipelineEngine + + +class _Echo: + async def execute(self, ctx, inputs): + return inputs.get("input", "") + + +def _two_node_dag() -> DAG: + dag = DAG("dispatch") + dag.add_node(DAGNode(node_id="a", step=_Echo())) + dag.add_node(DAGNode(node_id="b", step=_Echo())) + dag.add_edge(DAGEdge(source="a", target="b")) + return dag + + +# ---- Unified (rich) handler ------------------------------------------------ + + +@dataclass +class _UnifiedHandler: + """Implements the full EventHandler shape (run_id + visit aware).""" + + started: list[tuple[str, str]] = field(default_factory=list) # (pipeline, run_id) + node_starts: list[tuple[str, int]] = field(default_factory=list) # (node_id, visit) + node_completes: list[str] = field(default_factory=list) + completed: list[tuple[str, bool]] = field(default_factory=list) # (run_id, success) + + async def on_pipeline_start(self, pipeline_name: str, run_id: str) -> None: + self.started.append((pipeline_name, run_id)) + + async def on_node_start(self, pipeline_name, run_id, node_id, visit): + self.node_starts.append((node_id, visit)) + + async def on_node_complete(self, pipeline_name, run_id, node_id, latency_ms): + self.node_completes.append(node_id) + + async def on_pipeline_complete(self, pipeline_name, run_id, success, duration_ms): + self.completed.append((run_id, success)) + + +async def test_unified_handler_receives_pipeline_start_with_run_id(): + handler = _UnifiedHandler() + engine = PipelineEngine(_two_node_dag(), event_handler=handler) + result = await engine.run(inputs="x") + assert handler.started == [("dispatch", result.run_id)] + + +async def test_unified_handler_receives_visit_on_node_start(): + handler = _UnifiedHandler() + engine = PipelineEngine(_two_node_dag(), event_handler=handler) + await engine.run(inputs="x") + # Port-based pipelines always emit visit=1 until cycles arrive in a later layer. + assert handler.node_starts == [("a", 1), ("b", 1)] + + +async def test_unified_handler_receives_pipeline_complete_with_run_id(): + handler = _UnifiedHandler() + engine = PipelineEngine(_two_node_dag(), event_handler=handler) + result = await engine.run(inputs="x") + assert handler.completed == [(result.run_id, True)] + + +# ---- Legacy PipelineEventHandler (run_id-unaware) -------------------------- + + +@dataclass +class _LegacyHandler: + """Implements the legacy PipelineEventHandler signatures (no run_id).""" + + starts: list[str] = field(default_factory=list) + completes: list[str] = field(default_factory=list) + pipeline_done: list[tuple[str, bool]] = field(default_factory=list) + + async def on_node_start(self, node_id: str, pipeline_name: str) -> None: + self.starts.append(node_id) + + async def on_node_complete(self, node_id: str, pipeline_name: str, latency_ms: float) -> None: + self.completes.append(node_id) + + async def on_pipeline_complete(self, pipeline_name: str, success: bool, duration_ms: float) -> None: + self.pipeline_done.append((pipeline_name, success)) + + +async def test_legacy_handler_still_works_without_run_id(): + """The engine drops run_id/visit when the handler doesn't declare them.""" + handler = _LegacyHandler() + engine = PipelineEngine(_two_node_dag(), event_handler=handler) + result = await engine.run(inputs="x") + assert result.success + assert handler.starts == ["a", "b"] + assert handler.completes == ["a", "b"] + assert handler.pipeline_done == [("dispatch", True)] + + +async def test_legacy_handler_without_on_pipeline_start_is_fine(): + """Legacy handlers don't have on_pipeline_start; engine just skips it.""" + handler = _LegacyHandler() + assert not hasattr(handler, "on_pipeline_start") + engine = PipelineEngine(_two_node_dag(), event_handler=handler) + # Should not raise — missing methods are no-ops. + await engine.run(inputs="x") + + +# ---- Mixed handler (some legacy methods, some new) ------------------------ + + +@dataclass +class _MixedHandler: + """Some methods unified-signature, some legacy. Both should fire.""" + + pipeline_starts_with_run_id: list[str] = field(default_factory=list) + legacy_node_starts: list[str] = field(default_factory=list) + + # New (rich) signature + async def on_pipeline_start(self, pipeline_name: str, run_id: str) -> None: + self.pipeline_starts_with_run_id.append(run_id) + + # Legacy signature — engine should still call it without run_id/visit + async def on_node_start(self, node_id: str, pipeline_name: str) -> None: + self.legacy_node_starts.append(node_id) + + +async def test_mixed_handler_dispatches_correctly(): + handler = _MixedHandler() + engine = PipelineEngine(_two_node_dag(), event_handler=handler) + result = await engine.run(inputs="x") + assert handler.pipeline_starts_with_run_id == [result.run_id] + assert handler.legacy_node_starts == ["a", "b"] + + +# ---- Exception safety ------------------------------------------------------ + + +async def test_handler_exception_does_not_break_pipeline(): + class _Broken: + async def on_pipeline_start(self, pipeline_name: str, run_id: str) -> None: + raise RuntimeError("boom in start") + + async def on_node_start(self, pipeline_name, run_id, node_id, visit): + raise RuntimeError("boom in node start") + + async def on_pipeline_complete(self, pipeline_name, run_id, success, duration_ms): + raise RuntimeError("boom in complete") + + engine = PipelineEngine(_two_node_dag(), event_handler=_Broken()) + result = await engine.run(inputs="x") + assert result.success From 65f7fbe9a0456619388caf149c7a00d3a227b0e1 Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Thu, 28 May 2026 14:38:31 +0200 Subject: [PATCH 18/26] feat(pipeline): branching as DAGEdge.condition (#245 layer 2) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Second layer of the unification. Branching moves from a per-node concern (BranchStep + DAGNode.condition) to a per-edge property — DAGEdge.condition. The DAG itself is now the single source of truth for routing. DAGEdge gains: condition: Callable[[PipelineContext], bool] | None = None Semantics: - No condition (default): edge is always traversed — legacy behavior. - Condition returns True: edge is "alive"; target proceeds normally. - Condition returns False (or raises — fail closed): edge is inactive. - Multiple outgoing edges from one source = branching: only the targets whose incoming edge is alive run. - Multi-incoming (fan-in) with mixed conditions = OR semantic: target runs if at least one incoming edge is alive. - All-False incoming = target is "dead": skipped without scheduling, cascades to its transitive successors via the existing SKIP_DOWNSTREAM failed-nodes mechanism. Engine changes: - _edge_alive() evaluates a single edge's condition, defaulting to True. - _ready() now requires (a) all sources completed AND (b) at least one alive incoming edge. - _is_dead() detects nodes whose every incoming edge has resolved but none is alive. - _record_skip() marks dead nodes as skipped without executing them and feeds the same audit/event hooks as in-flight skips. - A new post-batch sweep in run() drains dead nodes from `pending`. DAG.to_mermaid() prefixes conditional edges with "if?" so branches are visible in renderings. DAGNode.condition is preserved unchanged — still respected by _execute_node — for backward compatibility. A later layer will deprecate and remove it once callers migrate. Tests: 10 new in tests/unit/pipeline/test_pipeline_engine_edge_condition.py covering: - baseline (no condition behavior unchanged) - true/false single edges - if/else branching from one source - skip cascade to downstream - fan-in OR semantics (mixed and all-dead) - conditions reading upstream NodeResult outputs - raising conditions treated as False (fail closed) - mermaid rendering Full suite: 1561 passed. Refs: #245 --- fireflyframework_agentic/pipeline/dag.py | 26 ++- fireflyframework_agentic/pipeline/engine.py | 70 +++++- .../test_pipeline_engine_edge_condition.py | 219 ++++++++++++++++++ 3 files changed, 311 insertions(+), 4 deletions(-) create mode 100644 tests/unit/pipeline/test_pipeline_engine_edge_condition.py diff --git a/fireflyframework_agentic/pipeline/dag.py b/fireflyframework_agentic/pipeline/dag.py index 80d78b76..6272068d 100644 --- a/fireflyframework_agentic/pipeline/dag.py +++ b/fireflyframework_agentic/pipeline/dag.py @@ -58,12 +58,28 @@ class DAGEdge(BaseModel): target: ID of the downstream node. output_key: Which output from the source to pass (default ``"output"``). input_key: Which input key on the target receives the value (default ``"input"``). + condition: Optional predicate ``(PipelineContext) -> bool`` that gates + edge traversal. When False (or when the callable raises), the + edge is treated as inactive: it neither delivers a signal to + the target nor contributes to scheduling readiness. If every + incoming edge of a target is inactive (and all of them are + resolved), the target is skipped — the same SKIP_DOWNSTREAM + cascade as an upstream failure. ``None`` (default) means + "always traverse". + + Conditions live on edges rather than on ``DAGNode`` because + branching is a routing decision, not a node-internal predicate. + The legacy :attr:`DAGNode.condition` field is preserved for + backward compatibility but is the wrong layer. """ + model_config = {"arbitrary_types_allowed": True} + source: str target: str output_key: str = "output" input_key: str = "input" + condition: Callable[..., bool] | None = None class DAGNode(BaseModel): @@ -249,13 +265,19 @@ def to_mermaid(self) -> str: """Render the topology as a Mermaid flowchart. Edges with ``input_key`` other than the default ``"input"`` are - labelled with that key so port wiring is visible. + labelled with that key so port wiring is visible. Conditional + edges are prefixed ``if?`` so branches stand out. """ lines = ["flowchart TD"] for node_id in self._nodes: lines.append(f" {_mermaid_id(node_id)}[{node_id}]") for edge in self._edges: - label = edge.input_key if edge.input_key and edge.input_key != "input" else None + parts: list[str] = [] + if edge.condition is not None: + parts.append("if?") + if edge.input_key and edge.input_key != "input": + parts.append(edge.input_key) + label = " · ".join(parts) if parts else None arrow = f"-->|{label}|" if label else "-->" lines.append(f" {_mermaid_id(edge.source)} {arrow} {_mermaid_id(edge.target)}") return "\n".join(lines) diff --git a/fireflyframework_agentic/pipeline/engine.py b/fireflyframework_agentic/pipeline/engine.py index ab4412e3..6d9f7b08 100644 --- a/fireflyframework_agentic/pipeline/engine.py +++ b/fireflyframework_agentic/pipeline/engine.py @@ -296,10 +296,68 @@ async def run( sequence = sequence_start abort = False + def _edge_alive(edge: Any) -> bool: + """An edge is alive if it has no condition, or its condition returns True. + + Raises in the condition itself are treated as False — fail + closed so a broken predicate kills the branch instead of + silently waking up the wrong target. + """ + if edge.condition is None: + return True + try: + return bool(edge.condition(context)) + except Exception: + return False + def _ready(nid: str) -> bool: - """A node is ready when all its upstream deps have completed.""" + """A node is ready when: + (1) every incoming edge's source has completed, AND + (2) at least one of those edges is alive (or it has no edges). + + Entry nodes (no incoming) are always ready once scheduled. + """ edges = self._dag.incoming_edges(nid) - return all(e.source in completed for e in edges) + if not edges: + return True + if not all(e.source in completed for e in edges): + return False + return any(_edge_alive(e) for e in edges) + + def _is_dead(nid: str) -> bool: + """A node is dead when every incoming edge has resolved but + none of them is alive. Cascades via the SKIP_DOWNSTREAM + mechanism so transitive successors are skipped without + being scheduled. + """ + edges = self._dag.incoming_edges(nid) + if not edges: + return False + if not all(e.source in completed for e in edges): + return False + return not any(_edge_alive(e) for e in edges) + + async def _record_skip(nid: str) -> None: + """Mark a node as skipped without scheduling it. Mirrors the + handling of in-flight skips returned by ``_execute_node``. + """ + nonlocal sequence + nr = NodeResult(node_id=nid, skipped=True, error="No alive incoming edge") + all_results[nid] = nr + context.set_node_result(nid, nr) + completed.add(nid) + failed_nodes.add(nid) + failed_nodes.update(self._dag.transitive_successors(nid)) + await self._emit_node_result(nr, run_id) + sequence += 1 + self._record_audit( + run_id=run_id, + node_id=nid, + sequence=sequence, + nr=nr, + inputs_snapshot={}, + trace_entries=trace_entries, + ) while pending or running: # Schedule all ready nodes that aren't already running. @@ -383,6 +441,14 @@ def _ready(nid: str) -> bool: failed_nodes.add(node_id) failed_nodes.update(self._dag.transitive_successors(node_id)) + # Sweep pending for nodes whose incoming edges have resolved + # but none is alive. Mark them skipped and cascade — this is + # what makes DAGEdge.condition a usable branching primitive. + for nid in list(pending): + if _is_dead(nid): + await _record_skip(nid) + pending.discard(nid) + if abort: # Cancel remaining tasks for t in running.values(): diff --git a/tests/unit/pipeline/test_pipeline_engine_edge_condition.py b/tests/unit/pipeline/test_pipeline_engine_edge_condition.py new file mode 100644 index 00000000..39fe7878 --- /dev/null +++ b/tests/unit/pipeline/test_pipeline_engine_edge_condition.py @@ -0,0 +1,219 @@ +"""Layer 2 of the unification (#245): branching as DAGEdge.condition. + +DAGEdge now carries an optional predicate that gates traversal. When a +source completes, each outgoing edge's condition is evaluated against the +current PipelineContext. Targets whose incoming edges all evaluate False +are marked skipped (no execution, no result, transitive downstream cascade +via SKIP_DOWNSTREAM). + +This unifies the legacy ``BranchStep`` + ``DAGNode.condition`` machinery +into a single property of the DAG. ``.branch(source, router, mapping)`` — +which today lives in StatePipeline — will be reframed as sugar that adds +conditional edges in a later layer. +""" + +from __future__ import annotations + +from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode +from fireflyframework_agentic.pipeline.engine import PipelineEngine + + +class _Echo: + """Step that returns its input verbatim, tagged with a node prefix.""" + + def __init__(self, prefix: str = "") -> None: + self.prefix = prefix + self.calls = 0 + + async def execute(self, ctx, inputs): + self.calls += 1 + return f"{self.prefix}{inputs.get('input', '')}" + + +# ---- baseline: edge without condition is unchanged ------------------------ + + +async def test_edge_without_condition_is_unchanged(): + a, b = _Echo("a:"), _Echo("b:") + dag = DAG("plain") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_edge(DAGEdge(source="a", target="b")) + result = await PipelineEngine(dag).run(inputs="x") + assert result.success + assert a.calls == 1 and b.calls == 1 + + +# ---- single conditional edge ---------------------------------------------- + + +async def test_true_condition_lets_target_run(): + a, b = _Echo("a:"), _Echo("b:") + dag = DAG("true-cond") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_edge(DAGEdge(source="a", target="b", condition=lambda ctx: True)) + result = await PipelineEngine(dag).run(inputs="x") + assert result.success + assert b.calls == 1 + + +async def test_false_condition_skips_target(): + a, b = _Echo("a:"), _Echo("b:") + dag = DAG("false-cond") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_edge(DAGEdge(source="a", target="b", condition=lambda ctx: False)) + result = await PipelineEngine(dag).run(inputs="x") + assert result.success # the pipeline as a whole still succeeds + assert a.calls == 1 + assert b.calls == 0 + assert result.outputs["b"].skipped + + +# ---- branching: one source, two conditional targets ----------------------- + + +async def test_branch_chooses_one_of_two_targets(): + """Classic if/else branching via two conditional edges from the same source.""" + a = _Echo("a:") + yes, no = _Echo("yes:"), _Echo("no:") + dag = DAG("if-else") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="yes", step=yes)) + dag.add_node(DAGNode(node_id="no", step=no)) + dag.add_edge( + DAGEdge( + source="a", + target="yes", + condition=lambda ctx: "good" in str(ctx.get_node_result("a").output), + ) + ) + dag.add_edge( + DAGEdge( + source="a", + target="no", + condition=lambda ctx: "good" not in str(ctx.get_node_result("a").output), + ) + ) + result = await PipelineEngine(dag).run(inputs="good run") + assert result.success + assert yes.calls == 1 + assert no.calls == 0 + assert result.outputs["no"].skipped + + +# ---- cascading skip -------------------------------------------------------- + + +async def test_skipped_target_cascades_to_its_downstream(): + a, b, c = _Echo("a:"), _Echo("b:"), _Echo("c:") + dag = DAG("cascade") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_node(DAGNode(node_id="c", step=c)) + dag.add_edge(DAGEdge(source="a", target="b", condition=lambda ctx: False)) + dag.add_edge(DAGEdge(source="b", target="c")) + result = await PipelineEngine(dag).run(inputs="x") + assert result.success + assert a.calls == 1 + assert b.calls == 0 + assert c.calls == 0 + assert result.outputs["b"].skipped + assert result.outputs["c"].skipped + + +# ---- fan-in with mixed conditions: OR semantics --------------------------- + + +async def test_fanin_runs_if_any_incoming_edge_alive(): + """Two upstreams, one edge False, one edge True → target runs.""" + a, b, c = _Echo("a:"), _Echo("b:"), _Echo("c:") + dag = DAG("fanin") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_node(DAGNode(node_id="c", step=c)) + dag.add_edge(DAGEdge(source="a", target="c", condition=lambda ctx: False)) + dag.add_edge(DAGEdge(source="b", target="c", condition=lambda ctx: True)) + result = await PipelineEngine(dag).run(inputs="x") + assert result.success + assert c.calls == 1 + assert not result.outputs["c"].skipped + + +async def test_fanin_skipped_when_all_incoming_edges_dead(): + a, b, c = _Echo("a:"), _Echo("b:"), _Echo("c:") + dag = DAG("fanin-dead") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_node(DAGNode(node_id="c", step=c)) + dag.add_edge(DAGEdge(source="a", target="c", condition=lambda ctx: False)) + dag.add_edge(DAGEdge(source="b", target="c", condition=lambda ctx: False)) + result = await PipelineEngine(dag).run(inputs="x") + assert result.success + assert c.calls == 0 + assert result.outputs["c"].skipped + + +# ---- condition can read upstream output ----------------------------------- + + +async def test_condition_sees_completed_upstream_output(): + """The condition gets a PipelineContext and can inspect prior node results.""" + + class _Number: + async def execute(self, ctx, inputs): + return 42 + + a = _Number() + b = _Echo("big:") + dag = DAG("cond-reads-upstream") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_edge( + DAGEdge( + source="a", + target="b", + condition=lambda ctx: ctx.get_node_result("a").output > 10, + ) + ) + result = await PipelineEngine(dag).run(inputs="") + assert result.success + assert b.calls == 1 + + +# ---- condition exception is treated as False ------------------------------ + + +async def test_raising_condition_treated_as_false(): + """If the condition itself raises, the edge is dead — fail closed.""" + a, b = _Echo("a:"), _Echo("b:") + dag = DAG("raising-cond") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + + def raiser(ctx): + raise RuntimeError("oops") + + dag.add_edge(DAGEdge(source="a", target="b", condition=raiser)) + result = await PipelineEngine(dag).run(inputs="x") + assert result.success + assert b.calls == 0 + assert result.outputs["b"].skipped + + +# ---- to_mermaid renders conditional edges --------------------------------- + + +def test_mermaid_marks_conditional_edges(): + dag = DAG("viz") + dag.add_node(DAGNode(node_id="a", step=_Echo())) + dag.add_node(DAGNode(node_id="b", step=_Echo())) + dag.add_node(DAGNode(node_id="c", step=_Echo())) + dag.add_edge(DAGEdge(source="a", target="b")) + dag.add_edge(DAGEdge(source="a", target="c", condition=lambda ctx: True)) + mermaid = dag.to_mermaid() + # Unconditional edge: plain arrow. + assert "a --> b" in mermaid + # Conditional edge: labelled distinctively (we use "if?"). + assert "a -->|if?| c" in mermaid or "a -.->|if?| c" in mermaid From bd109a30bf99f25f529bd89283253c19847ba030 Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Thu, 28 May 2026 14:48:34 +0200 Subject: [PATCH 19/26] feat(pipeline): state as optional overlay on PipelineEngine (#245 layer 3) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PipelineEngine now accepts an optional state_schema= and state= argument. When configured, nodes that return a dict have it merged into a shared Pydantic state object via reducers; non-dict returns continue to flow through edges as port outputs. Both modes coexist on the same node. This is the layer that reclaims parallelism for state-aware pipelines: the existing topological scheduler runs disjoint state-writing nodes concurrently, and concurrent writes to the same field are reconciled by the reducer declared on that field. Commutative reducers (append, extend, merge_dict) are race-free; replace is last-write-wins. Engine changes: - PipelineEngine(__init__) accepts state_schema: type[BaseModel] | None. Caches discover_reducers(schema) at construction. - PipelineEngine.run() accepts state: BaseModel | None. When state_schema is set: validates state, instantiates from defaults if None, exposes on ctx.state. - After each successful node, if return value is dict and state_schema is set, apply_update reduces it into context.state. - _save_checkpoint persists context.state under "shared_state". - _load_for_resume restores context.state from "shared_state". - PipelineResult gets final_state field — populated from context.state. Refactor: - discover_reducers and apply_update lifted from state_pipeline.py to reducers.py so both StatePipeline and PipelineEngine can share them without a circular import. state_pipeline.py now imports them. - PipelineContext gets `state: Any = None` attribute. - StatePipeline behavior unchanged — verified by 37 existing state-pipeline tests. Tests: 12 new in tests/unit/pipeline/test_pipeline_engine_state_overlay.py covering: - baseline (no state_schema = unchanged behavior) - auto-instantiation from defaults - explicit state arg - dict return as state update (replace, append, extend, merge_dict) - non-dict return as port output (state untouched) - edge condition reading ctx.state - parallel nodes writing same field via append (commutative-reducer parallelism) - resume restores shared state - audit log works under state overlay Full suite: 1573 passed. Refs: #245 --- fireflyframework_agentic/pipeline/context.py | 6 + fireflyframework_agentic/pipeline/engine.py | 45 ++- fireflyframework_agentic/pipeline/reducers.py | 48 +++- fireflyframework_agentic/pipeline/result.py | 3 + .../pipeline/state_pipeline.py | 44 +-- .../test_pipeline_engine_state_overlay.py | 264 ++++++++++++++++++ 6 files changed, 364 insertions(+), 46 deletions(-) create mode 100644 tests/unit/pipeline/test_pipeline_engine_state_overlay.py diff --git a/fireflyframework_agentic/pipeline/context.py b/fireflyframework_agentic/pipeline/context.py index c79e10d9..8af0c998 100644 --- a/fireflyframework_agentic/pipeline/context.py +++ b/fireflyframework_agentic/pipeline/context.py @@ -42,11 +42,17 @@ def __init__( metadata: dict[str, Any] | None = None, correlation_id: str | None = None, memory: MemoryManager | None = None, + state: Any = None, ) -> None: self.inputs = inputs self.metadata: dict[str, Any] = metadata or {} self.correlation_id = correlation_id or uuid.uuid4().hex self.memory: MemoryManager | None = memory + # Shared typed state for state-aware pipelines. None for legacy + # port-based runs. Engine reassigns after each node's reducer-merged + # update — readers within a single in-flight task see the snapshot + # they were scheduled with. + self.state: Any = state self._results: dict[str, Any] = {} # node_id -> NodeResult def set_node_result(self, node_id: str, result: Any) -> None: diff --git a/fireflyframework_agentic/pipeline/engine.py b/fireflyframework_agentic/pipeline/engine.py index 6d9f7b08..448f1ac7 100644 --- a/fireflyframework_agentic/pipeline/engine.py +++ b/fireflyframework_agentic/pipeline/engine.py @@ -31,6 +31,8 @@ except ImportError: # pragma: no cover - optional dep otel_trace = None # type: ignore[assignment] +from pydantic import BaseModel + from fireflyframework_agentic.config import get_config from fireflyframework_agentic.exceptions import PipelineError from fireflyframework_agentic.observability.usage import default_usage_tracker @@ -38,6 +40,7 @@ from fireflyframework_agentic.pipeline.checkpoint import Checkpointer, CheckpointRecord from fireflyframework_agentic.pipeline.context import PipelineContext from fireflyframework_agentic.pipeline.dag import DAG, FailureStrategy +from fireflyframework_agentic.pipeline.reducers import Reducer, apply_update, discover_reducers from fireflyframework_agentic.pipeline.result import ( ExecutionTraceEntry, NodeResult, @@ -183,11 +186,17 @@ def __init__( event_handler: EventHandler | PipelineEventHandler | None = None, checkpointer: Checkpointer | None = None, audit_log: AuditLog | None = None, + state_schema: type[BaseModel] | None = None, ) -> None: self._dag = dag self._event_handler = event_handler self._checkpointer = checkpointer self._audit_log = audit_log + # Optional shared-state overlay. When set, nodes returning a dict + # have it merged into the state via reducers; non-dict returns + # continue to flow through edges as port outputs. Both can coexist. + self._state_schema = state_schema + self._reducers: dict[str, Reducer] = discover_reducers(state_schema) if state_schema is not None else {} # Per-method signature cache for legacy-vs-unified dispatch. self._handler_params: dict[str, set[str]] = {} @@ -231,6 +240,7 @@ async def run( context: PipelineContext | None = None, *, inputs: Any = None, + state: BaseModel | None = None, run_id: str | None = None, ) -> PipelineResult: """Execute the pipeline. @@ -238,16 +248,19 @@ async def run( Parameters: context: Pre-built context, or *None* to create one automatically. inputs: Initial inputs (used if *context* is not provided). + state: Optional shared state object for engines configured with + ``state_schema=``. When omitted, the engine instantiates the + schema with its defaults. run_id: Identifier for this run. When given alone (no ``context`` and no ``inputs``), the engine loads the latest checkpoint for that run and resumes from after the last completed node. Requires a checkpointer to be configured. Returns: - A :class:`PipelineResult` with all node outputs, trace, and - ``run_id`` (use to resume later). + A :class:`PipelineResult` with all node outputs, trace, ``run_id`` + (use to resume later), and ``final_state`` for state-aware runs. """ - if run_id is not None and context is None and inputs is None: + if run_id is not None and context is None and inputs is None and state is None: resume_run_id: str = run_id context, pre_completed, sequence_start = self._load_for_resume(resume_run_id) all_results: dict[str, NodeResult] = { @@ -258,6 +271,19 @@ async def run( else: if context is None: context = PipelineContext(inputs=inputs) + # Initialize shared state if configured. + if self._state_schema is not None and context.state is None: + if state is not None: + if not isinstance(state, self._state_schema): + state = self._state_schema.model_validate(state) + context.state = state + else: + try: + context.state = self._state_schema() + except Exception as exc: + raise PipelineError( + f"state required for pipeline with state_schema {self._state_schema.__name__}: {exc}" + ) from exc pre_completed = set() sequence_start = 0 all_results = {} @@ -423,6 +449,10 @@ async def _record_skip(nid: str) -> None: trace_entries=trace_entries, ) if nr.success and not nr.skipped: + # State overlay: a dict return from the node is a state + # update; non-dict returns flow through edges as ports. + if self._state_schema is not None and context.state is not None and isinstance(nr.output, dict): + context.state = apply_update(context.state, nr.output, self._reducers) self._save_checkpoint( run_id=run_id, node_id=node_id, @@ -491,6 +521,7 @@ async def _record_skip(nid: str) -> None: success=success, usage=usage_summary, run_id=run_id, + final_state=context.state, ) async def _execute_node( @@ -663,6 +694,13 @@ def _load_for_resume(self, run_id: str) -> tuple[PipelineContext, set[str], int] context.set_node_result(nid, NodeResult.model_validate(nr_dict)) except Exception: logger.warning("Could not restore NodeResult for '%s' on resume", nid) + # Restore shared state if the run was state-aware. + saved_state = record.state.get("shared_state") + if self._state_schema is not None and isinstance(saved_state, dict): + try: + context.state = self._state_schema.model_validate(saved_state) + except Exception: + logger.warning("Could not restore shared state on resume for run '%s'", run_id) return context, set(record.completed_nodes), record.sequence def _save_checkpoint( @@ -685,6 +723,7 @@ def _save_checkpoint( state = { "inputs": _serialize_value(context.inputs), "results": {nid: all_results[nid].model_dump(mode="json") for nid in completed_successful}, + "shared_state": _serialize_value(context.state), } try: self._checkpointer.save( diff --git a/fireflyframework_agentic/pipeline/reducers.py b/fireflyframework_agentic/pipeline/reducers.py index da91e1dd..989e1697 100644 --- a/fireflyframework_agentic/pipeline/reducers.py +++ b/fireflyframework_agentic/pipeline/reducers.py @@ -30,8 +30,13 @@ class AgentState(BaseModel): from __future__ import annotations +import logging from collections.abc import Callable -from typing import Any +from typing import Any, get_type_hints + +from pydantic import BaseModel + +logger = logging.getLogger(__name__) Reducer = Callable[[Any, Any], Any] @@ -60,3 +65,44 @@ def merge_dict(current: Any, update: Any) -> dict[Any, Any]: base = dict(current) if current else {} base.update(update or {}) return base + + +def discover_reducers(state_schema: type) -> dict[str, Reducer]: + """Inspect ``Annotated[T, reducer_fn]`` annotations on the schema. + + Only ``Annotated[...]`` metadata is consulted — not generic origins like + ``list[...]`` or unions. Fields without an annotated reducer are absent + from the returned dict; callers should treat absence as :func:`replace`. + """ + out: dict[str, Reducer] = {} + try: + hints = get_type_hints(state_schema, include_extras=True) + except Exception: + return out + for field_name, hint in hints.items(): + metadata = getattr(hint, "__metadata__", None) + if not metadata: + continue + for meta in metadata: + if callable(meta): + out[field_name] = meta + break + return out + + +def apply_update(state: BaseModel, update: dict[str, Any], reducers: dict[str, Reducer]) -> BaseModel: + """Return a new state object with ``update`` merged into ``state`` via reducers. + + Keys present in ``update`` but missing from the schema are logged and + ignored — incremental schema evolution stays painless. + """ + if not update: + return state + new_values = state.model_dump() + for key, value in update.items(): + if key not in new_values: + logger.warning("State update key '%s' not in schema %s; ignored.", key, type(state).__name__) + continue + reducer = reducers.get(key, replace) + new_values[key] = reducer(new_values[key], value) + return type(state).model_validate(new_values) diff --git a/fireflyframework_agentic/pipeline/result.py b/fireflyframework_agentic/pipeline/result.py index 929b478d..1b3e26ee 100644 --- a/fireflyframework_agentic/pipeline/result.py +++ b/fireflyframework_agentic/pipeline/result.py @@ -79,6 +79,9 @@ class PipelineResult(BaseModel): success: bool = True usage: UsageSummary | None = None run_id: str = "" + # Final shared state for pipelines configured with state_schema. None + # when the engine had no state overlay. + final_state: Any = None @property def failed_nodes(self) -> list[str]: diff --git a/fireflyframework_agentic/pipeline/state_pipeline.py b/fireflyframework_agentic/pipeline/state_pipeline.py index 57a1cc0b..c8c9247b 100644 --- a/fireflyframework_agentic/pipeline/state_pipeline.py +++ b/fireflyframework_agentic/pipeline/state_pipeline.py @@ -32,7 +32,7 @@ from collections.abc import Awaitable, Callable from dataclasses import dataclass from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any, get_type_hints +from typing import TYPE_CHECKING, Any from pydantic import BaseModel @@ -41,10 +41,10 @@ from fireflyframework_agentic.pipeline.checkpoint import Checkpointer, CheckpointRecord from fireflyframework_agentic.pipeline.dag import DAG, _mermaid_id from fireflyframework_agentic.pipeline.engine import start_otel_span +from fireflyframework_agentic.pipeline.reducers import apply_update, discover_reducers if TYPE_CHECKING: from fireflyframework_agentic.pipeline.engine import StatePipelineEventHandler -from fireflyframework_agentic.pipeline.reducers import Reducer, replace logger = logging.getLogger(__name__) @@ -135,46 +135,6 @@ class StatePipelineResult: pause_reason: str | None = None -def discover_reducers(state_schema: type) -> dict[str, Reducer]: - """Inspect ``Annotated[T, reducer_fn]`` annotations on the schema. - - Only ``Annotated[...]`` metadata is consulted — not generic origins like - ``list[...]`` or unions. Fields without an annotated reducer are absent - from the returned dict; callers should treat absence as :func:`replace`. - """ - out: dict[str, Reducer] = {} - try: - hints = get_type_hints(state_schema, include_extras=True) - except Exception: - return out - for field_name, hint in hints.items(): - # Annotated[...] is the only metadata-bearing form we care about. - metadata = getattr(hint, "__metadata__", None) - if not metadata: - continue - for meta in metadata: - if callable(meta): - out[field_name] = meta - break - return out - - -def apply_update(state: BaseModel, update: dict[str, Any], reducers: dict[str, Reducer]) -> BaseModel: - """Return a new state object with ``update`` merged into ``state`` via reducers.""" - if not update: - return state - new_values = state.model_dump() - for key, value in update.items(): - if key not in new_values: - # Tolerate unknown keys with a warning rather than failing — - # makes incremental schema evolution painless. - logger.warning("State update key '%s' not in schema %s; ignored.", key, type(state).__name__) - continue - reducer = reducers.get(key, replace) - new_values[key] = reducer(new_values[key], value) - return type(state).model_validate(new_values) - - class StatePipeline: """Compiled state-based pipeline. Returned by ``PipelineBuilder.build()`` when a ``state=`` schema is configured. diff --git a/tests/unit/pipeline/test_pipeline_engine_state_overlay.py b/tests/unit/pipeline/test_pipeline_engine_state_overlay.py new file mode 100644 index 00000000..f503a2cd --- /dev/null +++ b/tests/unit/pipeline/test_pipeline_engine_state_overlay.py @@ -0,0 +1,264 @@ +"""Layer 3 of the unification (#245): state as optional overlay on PipelineEngine. + +PipelineEngine now accepts ``state_schema=`` and ``state=`` arguments. When +configured, nodes that return a dict have it merged into a shared Pydantic +state object via reducers (replace, append, extend, merge_dict). Non-dict +returns continue to flow as port outputs — both modes coexist on the same +node. + +This reclaims parallelism for state-aware pipelines: nodes that write +disjoint state fields can run concurrently via the existing topological +scheduler. Concurrent writes to the same field are merged by the reducer +declared on that field (commutative reducers like ``append`` are safe; +``replace`` is last-write-wins). +""" + +from __future__ import annotations + +from typing import Annotated + +from pydantic import BaseModel + +from fireflyframework_agentic.pipeline.audit import FileAuditLog +from fireflyframework_agentic.pipeline.checkpoint import FileCheckpointer +from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode +from fireflyframework_agentic.pipeline.engine import PipelineEngine +from fireflyframework_agentic.pipeline.reducers import append, extend, merge_dict + +# ---- Schemas --------------------------------------------------------------- + + +class _SimpleState(BaseModel): + counter: int = 0 + note: str = "" + + +class _ListState(BaseModel): + items: Annotated[list[str], append] = [] + batch: Annotated[list[str], extend] = [] + + +class _MergeState(BaseModel): + bag: Annotated[dict[str, int], merge_dict] = {} + + +# ---- Step helpers ---------------------------------------------------------- + + +def _step_returns(value): + """Build a step whose execute() always returns `value`.""" + + class _Step: + async def execute(self, ctx, inputs): + return value + + return _Step() + + +def _step_reads_state(field: str): + """Build a step that returns the current state's `field` as a port output.""" + + class _Step: + async def execute(self, ctx, inputs): + return getattr(ctx.state, field) + + return _Step() + + +# ---- baseline: no state_schema = unchanged -------------------------------- + + +async def test_engine_without_state_schema_is_unchanged(): + dag = DAG("plain") + dag.add_node(DAGNode(node_id="a", step=_step_returns("port-value"))) + engine = PipelineEngine(dag) # no state_schema + result = await engine.run(inputs="x") + assert result.success + assert result.final_state is None + assert result.outputs["a"].output == "port-value" + + +# ---- engine instantiates state from defaults when none passed ------------- + + +async def test_state_schema_with_defaults_is_auto_instantiated(): + dag = DAG("auto-state") + dag.add_node(DAGNode(node_id="a", step=_step_returns(None))) + engine = PipelineEngine(dag, state_schema=_SimpleState) + result = await engine.run(inputs="x") + assert result.success + assert isinstance(result.final_state, _SimpleState) + assert result.final_state.counter == 0 + + +# ---- explicit state passed via run() -------------------------------------- + + +async def test_state_arg_is_used_when_passed(): + dag = DAG("explicit-state") + dag.add_node(DAGNode(node_id="a", step=_step_returns(None))) + engine = PipelineEngine(dag, state_schema=_SimpleState) + result = await engine.run(inputs="x", state=_SimpleState(counter=42, note="hi")) + assert result.success + assert result.final_state.counter == 42 + assert result.final_state.note == "hi" + + +# ---- node returning dict merges into state via reducer -------------------- + + +async def test_dict_return_is_state_update_under_replace(): + dag = DAG("dict-replace") + dag.add_node(DAGNode(node_id="a", step=_step_returns({"counter": 7}))) + engine = PipelineEngine(dag, state_schema=_SimpleState) + result = await engine.run(inputs="x") + assert result.success + assert result.final_state.counter == 7 + + +async def test_append_reducer_accumulates_across_nodes(): + a, b = _step_returns({"items": "first"}), _step_returns({"items": "second"}) + dag = DAG("appender") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_edge(DAGEdge(source="a", target="b")) + engine = PipelineEngine(dag, state_schema=_ListState) + result = await engine.run(inputs="x") + assert result.success + assert result.final_state.items == ["first", "second"] + + +async def test_extend_reducer_concatenates(): + a, b = _step_returns({"batch": ["x", "y"]}), _step_returns({"batch": ["z"]}) + dag = DAG("extender") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_edge(DAGEdge(source="a", target="b")) + engine = PipelineEngine(dag, state_schema=_ListState) + result = await engine.run(inputs="x") + assert result.success + assert result.final_state.batch == ["x", "y", "z"] + + +async def test_merge_dict_reducer_merges(): + a, b = _step_returns({"bag": {"k1": 1}}), _step_returns({"bag": {"k2": 2}}) + dag = DAG("merger") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_edge(DAGEdge(source="a", target="b")) + engine = PipelineEngine(dag, state_schema=_MergeState) + result = await engine.run(inputs="x") + assert result.success + assert result.final_state.bag == {"k1": 1, "k2": 2} + + +# ---- non-dict return still flows as a port output ------------------------- + + +async def test_non_dict_return_is_still_a_port_output(): + """A node can write state OR emit a port value — its return type decides.""" + a = _step_returns("port-value") # str, not dict → port output + b = _step_reads_state("note") + dag = DAG("mixed") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_edge(DAGEdge(source="a", target="b")) + engine = PipelineEngine(dag, state_schema=_SimpleState) + result = await engine.run(inputs="x") + assert result.success + assert result.outputs["a"].output == "port-value" # port preserved + assert result.final_state.note == "" # state untouched by 'a' + + +# ---- conditions can read state -------------------------------------------- + + +async def test_edge_condition_reads_ctx_state(): + a = _step_returns({"counter": 5}) + b = _step_returns(None) + dag = DAG("cond-on-state") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_edge( + DAGEdge( + source="a", + target="b", + condition=lambda ctx: ctx.state.counter > 3, + ) + ) + engine = PipelineEngine(dag, state_schema=_SimpleState) + result = await engine.run(inputs="x") + assert result.success + assert not result.outputs["b"].skipped + + +# ---- parallelism: disjoint fields, commutative reducer --------------------- + + +async def test_parallel_nodes_with_commutative_reducer_accumulate(): + """Two nodes at the same level both append to items; both contributions land.""" + a = _step_returns({"items": "from-a"}) + b = _step_returns({"items": "from-b"}) + c = _step_returns(None) + dag = DAG("parallel-append") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_node(DAGNode(node_id="c", step=c)) + dag.add_edge(DAGEdge(source="a", target="c")) + dag.add_edge(DAGEdge(source="b", target="c")) + engine = PipelineEngine(dag, state_schema=_ListState) + result = await engine.run(inputs="x") + assert result.success + assert sorted(result.final_state.items) == ["from-a", "from-b"] + + +# ---- checkpoint + resume restores state ----------------------------------- + + +async def test_resume_restores_shared_state(tmp_path): + """Run with state, fail mid-pipeline, resume — state survives.""" + + class _FailOnce: + def __init__(self): + self.calls = 0 + + async def execute(self, ctx, inputs): + self.calls += 1 + if self.calls == 1: + raise RuntimeError("flake") + return {"counter": ctx.state.counter + 100} + + from fireflyframework_agentic.pipeline.dag import FailureStrategy + + a = _step_returns({"counter": 1, "note": "from-a"}) + b = _FailOnce() + dag = DAG("resume-state") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b, failure_strategy=FailureStrategy.FAIL_PIPELINE)) + dag.add_edge(DAGEdge(source="a", target="b")) + cp = FileCheckpointer(tmp_path) + engine = PipelineEngine(dag, checkpointer=cp, state_schema=_SimpleState) + r1 = await engine.run(inputs="x") + assert not r1.success + # After 'a' succeeds, the checkpoint should contain state.counter=1. + assert r1.final_state.counter == 1 + + r2 = await engine.run(run_id=r1.run_id) + assert r2.success + # On resume: state.counter restored to 1, then b adds 100. + assert r2.final_state.counter == 101 + assert r2.final_state.note == "from-a" # state preserved + + +# ---- audit log works alongside state -------------------------------------- + + +async def test_audit_records_under_state_overlay(tmp_path): + dag = DAG("audit-state") + dag.add_node(DAGNode(node_id="a", step=_step_returns({"counter": 3}))) + al = FileAuditLog(tmp_path) + engine = PipelineEngine(dag, audit_log=al, state_schema=_SimpleState) + result = await engine.run(inputs="x") + entries = al.list_entries("audit-state", result.run_id) + assert len(entries) == 1 + assert entries[0].status == "success" From 7a6cf0b86081a7545b6bb9dcdc371e9809b037f1 Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Thu, 28 May 2026 14:59:05 +0200 Subject: [PATCH 20/26] feat(pipeline): cycle-aware scheduler and topo-sort safety (#245 layer 4) Fourth layer of the unification. PipelineEngine gains the ability to run cyclic DAGs (ReAct loops, retry-with-critique) and the long-standing silent corruption in topological_sort/execution_levels on cyclic graphs is fixed. Engine changes: - PipelineEngine(__init__) accepts recursion_limit: int = 25 (matches StatePipeline). Bounds visit count per node in cyclic mode. - run() detects cyclic DAGs via dag.is_cyclic() and routes to a new _run_cyclic helper. Acyclic graphs use the existing topological scheduler unchanged. - _run_cyclic: sequential frontier-following. Picks the unique alive outgoing edge from each completed node, increments per-node visit count, enforces recursion_limit. Fan-out to multiple alive edges raises a clear PipelineError (multi-target cyclic fan-out arrives with Send in layer 5). - _record_audit accepts visit=, defaulting to 1 for the acyclic scheduler. The cyclic scheduler passes the actual visit number so audit entries distinguish iterations. DAG changes (silent-corruption fix from #245's review): - topological_sort() now raises PipelineError on cyclic graphs instead of returning a wrong-length list with a misleading "should not reach here" message. - execution_levels() now raises PipelineError on cyclic graphs instead of silently producing incomplete levels. Both methods document is_cyclic() as the right pre-check. Tests: 7 new in tests/unit/pipeline/test_pipeline_engine_cycles.py covering: - topological_sort raises on cyclic DAGs - execution_levels raises on cyclic DAGs - ReAct-style finite loop terminates correctly - Recursion limit halts runaway cycle - Default recursion_limit is 25 - Audit visit number increments per iteration - Acyclic DAG with allow_cycles=True still uses the parallel scheduler Full suite: 1580 passed. Refs: #245 --- fireflyframework_agentic/pipeline/dag.py | 20 +- fireflyframework_agentic/pipeline/engine.py | 183 +++++++++++++++++- .../pipeline/test_pipeline_engine_cycles.py | 167 ++++++++++++++++ 3 files changed, 364 insertions(+), 6 deletions(-) create mode 100644 tests/unit/pipeline/test_pipeline_engine_cycles.py diff --git a/fireflyframework_agentic/pipeline/dag.py b/fireflyframework_agentic/pipeline/dag.py index 6272068d..1656cc42 100644 --- a/fireflyframework_agentic/pipeline/dag.py +++ b/fireflyframework_agentic/pipeline/dag.py @@ -166,7 +166,16 @@ def add_edge(self, edge: DAGEdge) -> None: # -- Query ------------------------------------------------------------- def topological_sort(self) -> list[str]: - """Return node IDs in topological order (Kahn's algorithm).""" + """Return node IDs in topological order (Kahn's algorithm). + + Raises :class:`PipelineError` if the DAG contains a cycle. Cyclic + graphs have no topological order; the caller should branch on + :meth:`is_cyclic` first (or use the engine's cycle-aware scheduler). + """ + if self._has_cycle(): + raise PipelineError( + "topological_sort() is not defined on cyclic graphs; use is_cyclic() to branch before calling." + ) in_deg = dict(self._in_degree) for nid in self._nodes: in_deg.setdefault(nid, 0) @@ -181,16 +190,19 @@ def topological_sort(self) -> list[str]: if in_deg[neighbour] == 0: queue.append(neighbour) - if len(order) != len(self._nodes): - raise PipelineError("DAG contains a cycle (should not reach here)") return order def execution_levels(self) -> list[list[str]]: """Group nodes into levels for parallel execution. Nodes at the same level have no inter-dependencies and can be - executed concurrently. + executed concurrently. Raises :class:`PipelineError` on cyclic + DAGs — levels are undefined when cycles exist. """ + if self._has_cycle(): + raise PipelineError( + "execution_levels() is not defined on cyclic graphs; use is_cyclic() to branch before calling." + ) in_deg = dict(self._in_degree) for nid in self._nodes: in_deg.setdefault(nid, 0) diff --git a/fireflyframework_agentic/pipeline/engine.py b/fireflyframework_agentic/pipeline/engine.py index 448f1ac7..556fc071 100644 --- a/fireflyframework_agentic/pipeline/engine.py +++ b/fireflyframework_agentic/pipeline/engine.py @@ -187,6 +187,7 @@ def __init__( checkpointer: Checkpointer | None = None, audit_log: AuditLog | None = None, state_schema: type[BaseModel] | None = None, + recursion_limit: int = 25, ) -> None: self._dag = dag self._event_handler = event_handler @@ -197,6 +198,8 @@ def __init__( # continue to flow through edges as port outputs. Both can coexist. self._state_schema = state_schema self._reducers: dict[str, Reducer] = discover_reducers(state_schema) if state_schema is not None else {} + # Max visits per node for cycle-aware runs. Matches StatePipeline's default. + self._recursion_limit = recursion_limit # Per-method signature cache for legacy-vs-unified dispatch. self._handler_params: dict[str, set[str]] = {} @@ -298,6 +301,20 @@ async def run( ) await self._dispatch("on_pipeline_start", pipeline_name=self._dag.name, run_id=run_id) + # Cycle-aware mode: a separate sequential frontier-following scheduler + # that respects ``recursion_limit``. The topological scheduler below + # cannot run cyclic graphs because execution_levels()/topological_sort() + # are undefined on them. + if self._dag.is_cyclic(): + return await self._run_cyclic( + context=context, + run_id=run_id, + all_results=all_results, + pre_completed=pre_completed, + sequence_start=sequence_start, + pipeline_span=_pipeline_span, + ) + # Topological levels ensure that all upstream dependencies of a node # complete before the node itself executes. Nodes within the same # level are independent and run concurrently via asyncio.gather. @@ -524,6 +541,166 @@ async def _record_skip(nid: str) -> None: final_state=context.state, ) + async def _run_cyclic( + self, + *, + context: PipelineContext, + run_id: str, + all_results: dict[str, NodeResult], + pre_completed: set[str], + sequence_start: int, + pipeline_span: Any, + ) -> PipelineResult: + """Sequential frontier-following scheduler for cyclic DAGs. + + Walks the graph one node at a time, picking the next node from each + completed node's alive outgoing edges. Visit counts are tracked per + node and bounded by ``self._recursion_limit``. Within this mode, + having multiple alive outgoing edges is currently an error — parallel + cyclic fan-out is the job of :class:`Send` in a later layer. + """ + trace_entries: list[ExecutionTraceEntry] = [] + pipeline_start = time.perf_counter() + visit_counts: dict[str, int] = dict.fromkeys(pre_completed, 1) + sequence = sequence_start + + # Entry node: insertion order, matching StatePipeline. + nodes_in_order = list(self._dag.nodes) + if not nodes_in_order: + raise PipelineError("Pipeline has no nodes") + current: str | None = nodes_in_order[0] + # Skip past anything already completed during this resumed run. + while current is not None and current in pre_completed: + current = self._cyclic_next(current, context) + + result: PipelineResult | None = None + try: + while current is not None: + visit_counts[current] = visit_counts.get(current, 0) + 1 + visit_n = visit_counts[current] + if visit_n > self._recursion_limit: + msg = ( + f"Recursion limit ({self._recursion_limit}) exceeded at node " + f"'{current}'. Raise recursion_limit= or fix the routing logic." + ) + logger.error(msg) + nr = NodeResult(node_id=current, success=False, error=msg) + all_results[current] = nr + sequence += 1 + self._record_audit( + run_id=run_id, + node_id=current, + sequence=sequence, + nr=nr, + inputs_snapshot={}, + trace_entries=trace_entries, + ) + break + + gathered = self._gather_inputs(current, context) + await self._dispatch( + "on_node_start", + pipeline_name=self._dag.name, + run_id=run_id, + node_id=current, + visit=visit_n, + ) + nr = await self._execute_node( + current, + context, + trace_entries, + None, + inputs=gathered, + run_id=run_id, + ) + all_results[current] = nr + context.set_node_result(current, nr) + await self._emit_node_result(nr, run_id) + + sequence += 1 + self._record_audit( + run_id=run_id, + node_id=current, + sequence=sequence, + nr=nr, + inputs_snapshot=gathered, + trace_entries=trace_entries, + visit=visit_n, + ) + + if not nr.success and not nr.skipped: + break + + if not nr.skipped: + if self._state_schema is not None and context.state is not None and isinstance(nr.output, dict): + context.state = apply_update(context.state, nr.output, self._reducers) + self._save_checkpoint( + run_id=run_id, + node_id=current, + sequence=sequence, + context=context, + all_results=all_results, + ) + + current = self._cyclic_next(current, context) + finally: + elapsed = (time.perf_counter() - pipeline_start) * 1000 + success = ( + result.success if result is not None else all(r.success or r.skipped for r in all_results.values()) + ) + await self._dispatch( + "on_pipeline_complete", + pipeline_name=self._dag.name, + run_id=run_id, + success=success, + duration_ms=elapsed, + ) + if pipeline_span is not None: + with contextlib.suppress(Exception): + pipeline_span.end() + + final_output = None + return PipelineResult( + pipeline_name=self._dag.name, + outputs=all_results, + final_output=final_output, + execution_trace=trace_entries, + total_duration_ms=elapsed, + success=success, + usage=None, + run_id=run_id, + final_state=context.state, + ) + + def _cyclic_next(self, current: str, context: PipelineContext) -> str | None: + """Pick the next node by following the unique alive outgoing edge. + + Returns ``None`` when no outgoing edge is alive (terminus). Raises + if more than one is alive — parallel cyclic fan-out lands with + :class:`Send` in a later layer. + """ + + def _alive(edge: Any) -> bool: + if edge.condition is None: + return True + try: + return bool(edge.condition(context)) + except Exception: + return False + + outgoing = [e for e in self._dag.edges if e.source == current] + alive = [e for e in outgoing if _alive(e)] + if not alive: + return None + if len(alive) > 1: + raise PipelineError( + f"Cyclic node '{current}' has multiple alive outgoing edges " + f"({[e.target for e in alive]}). Parallel cyclic fan-out arrives " + f"with Send in a later layer; for now, use mutually exclusive " + f"edge conditions." + ) + return alive[0].target + async def _execute_node( self, node_id: str, @@ -748,11 +925,13 @@ def _record_audit( nr: NodeResult, inputs_snapshot: dict[str, Any], trace_entries: list[ExecutionTraceEntry], + visit: int = 1, ) -> None: """Write an audit entry for a node visit. No-op if no audit log. Skipped nodes are not recorded — they represent work that did NOT - happen and would clutter the trail. + happen and would clutter the trail. ``visit`` defaults to 1 for the + acyclic scheduler and is supplied by the cyclic scheduler. """ if self._audit_log is None or nr.skipped: return @@ -770,7 +949,7 @@ def _record_audit( run_id=run_id, node_id=node_id, sequence=sequence, - visit=1, + visit=visit, started_at=started_at, completed_at=completed_at, latency_ms=nr.latency_ms or 0.0, diff --git a/tests/unit/pipeline/test_pipeline_engine_cycles.py b/tests/unit/pipeline/test_pipeline_engine_cycles.py new file mode 100644 index 00000000..d45b1c16 --- /dev/null +++ b/tests/unit/pipeline/test_pipeline_engine_cycles.py @@ -0,0 +1,167 @@ +"""Layer 4 of the unification (#245): cycle-aware scheduler. + +PipelineEngine accepts ``recursion_limit=`` and, when the DAG is cyclic +(allow_cycles=True and a cycle is actually present), switches to a +sequential frontier-following scheduler. Each node visit increments a +per-node counter; exceeding ``recursion_limit`` halts the run with an +explanatory failure. + +This also patches the silent-corruption hazard in :meth:`DAG.topological_sort` +and :meth:`DAG.execution_levels` — both now raise on cyclic DAGs instead +of producing partial / wrong output. +""" + +from __future__ import annotations + +from typing import Annotated + +import pytest +from pydantic import BaseModel + +from fireflyframework_agentic.exceptions import PipelineError +from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode +from fireflyframework_agentic.pipeline.engine import PipelineEngine +from fireflyframework_agentic.pipeline.reducers import append + +# ---- topology-API safety --------------------------------------------------- + + +def test_topological_sort_raises_on_cyclic_dag(): + dag = DAG("cyclic", allow_cycles=True) + dag.add_node(DAGNode(node_id="a", step=None)) + dag.add_node(DAGNode(node_id="b", step=None)) + dag.add_edge(DAGEdge(source="a", target="b")) + dag.add_edge(DAGEdge(source="b", target="a")) + with pytest.raises(PipelineError, match="cyclic"): + dag.topological_sort() + + +def test_execution_levels_raises_on_cyclic_dag(): + dag = DAG("cyclic-lev", allow_cycles=True) + dag.add_node(DAGNode(node_id="a", step=None)) + dag.add_node(DAGNode(node_id="b", step=None)) + dag.add_edge(DAGEdge(source="a", target="b")) + dag.add_edge(DAGEdge(source="b", target="a")) + with pytest.raises(PipelineError, match="cyclic"): + dag.execution_levels() + + +# ---- cyclic execution ------------------------------------------------------ + + +class _CounterState(BaseModel): + counter: int = 0 + log: Annotated[list[str], append] = [] + + +def _bump(label: str, by: int = 1): + """Return a step that records its label and bumps counter by `by`.""" + + class _Step: + def __init__(self): + self.calls = 0 + + async def execute(self, ctx, inputs): + self.calls += 1 + return {"counter": ctx.state.counter + by, "log": label} + + return _Step() + + +async def test_cyclic_dag_loops_until_condition_fails(): + """Loop: incrementer -> guard. Guard's outgoing edge back to incrementer + is alive while counter < 3. Loop exits when guard's continue edge dies.""" + inc = _bump("inc", by=1) + # guard is a no-op pass-through. + + class _Pass: + calls = 0 + + async def execute(self, ctx, inputs): + self.calls += 1 + return None + + guard = _Pass() + dag = DAG("loop", allow_cycles=True) + dag.add_node(DAGNode(node_id="inc", step=inc)) + dag.add_node(DAGNode(node_id="guard", step=guard)) + dag.add_edge(DAGEdge(source="inc", target="guard")) + # Continue edge: re-enter inc while counter < 3. + dag.add_edge(DAGEdge(source="guard", target="inc", condition=lambda ctx: ctx.state.counter < 3)) + engine = PipelineEngine(dag, state_schema=_CounterState, recursion_limit=10) + result = await engine.run(inputs="") + assert result.success + assert result.final_state.counter == 3 + assert inc.calls == 3 + # guard runs after each inc. + assert guard.calls == 3 + + +async def test_recursion_limit_halts_runaway_cycle(): + inc = _bump("inc") + + class _Pass: + async def execute(self, ctx, inputs): + return None + + dag = DAG("infinite", allow_cycles=True) + dag.add_node(DAGNode(node_id="inc", step=inc)) + dag.add_node(DAGNode(node_id="guard", step=_Pass())) + dag.add_edge(DAGEdge(source="inc", target="guard")) + dag.add_edge(DAGEdge(source="guard", target="inc")) # always alive — runaway + engine = PipelineEngine(dag, state_schema=_CounterState, recursion_limit=5) + result = await engine.run(inputs="") + assert not result.success + assert ( + "recursion" in (result.outputs.get("inc") and result.outputs["inc"].error or "").lower() + or "recursion" in (result.outputs.get("guard") and result.outputs["guard"].error or "").lower() + ) + + +async def test_recursion_limit_default_is_25(): + """The engine's default recursion_limit matches StatePipeline's (25).""" + engine = PipelineEngine(DAG("x")) + assert engine._recursion_limit == 25 # noqa: SLF001 + + +async def test_audit_records_visit_per_iteration(tmp_path): + """Each iteration of a cycle gets its own audit entry with incrementing visit.""" + from fireflyframework_agentic.pipeline.audit import FileAuditLog + + inc = _bump("inc") + + class _Pass: + async def execute(self, ctx, inputs): + return None + + dag = DAG("audited-loop", allow_cycles=True) + dag.add_node(DAGNode(node_id="inc", step=inc)) + dag.add_node(DAGNode(node_id="guard", step=_Pass())) + dag.add_edge(DAGEdge(source="inc", target="guard")) + dag.add_edge(DAGEdge(source="guard", target="inc", condition=lambda ctx: ctx.state.counter < 2)) + al = FileAuditLog(tmp_path) + engine = PipelineEngine(dag, state_schema=_CounterState, audit_log=al, recursion_limit=10) + result = await engine.run(inputs="") + assert result.success + entries = al.list_entries("audited-loop", result.run_id) + inc_visits = sorted([e.visit for e in entries if e.node_id == "inc"]) + assert inc_visits == [1, 2] + + +# ---- acyclic still works --------------------------------------------------- + + +async def test_acyclic_dag_with_allow_cycles_true_runs_normally(): + """allow_cycles=True doesn't force cyclic mode if there are no cycles.""" + a = _bump("a") + b = _bump("b") + dag = DAG("ac", allow_cycles=True) + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_edge(DAGEdge(source="a", target="b")) + engine = PipelineEngine(dag, state_schema=_CounterState) + result = await engine.run(inputs="") + assert result.success + assert result.final_state.counter == 2 + assert a.calls == 1 + assert b.calls == 1 From edaf9f7bcdd099330c01f2fb2f1d52b7b2a95639 Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Thu, 28 May 2026 15:11:03 +0200 Subject: [PATCH 21/26] feat(pipeline): Pause and Send in unified PipelineEngine (#245 layer 5) Fifth layer of the unification. PipelineEngine now recognizes the same control sentinels that StatePipeline uses today: - A node returning Pause(reason=...) halts the pipeline cleanly. The run resumes with engine.run(run_id=..., approve_pause=True). - A node returning Send or list[Send] triggers parallel fan-out where each Send's target runs concurrently with the payload merged into a per-worker state copy. Reducers merge worker outputs back into shared state. Refactor (no behavior change): - Pause and Send dataclasses moved from state_pipeline.py to engine.py so PipelineEngine can recognize them without a circular import. state_pipeline.py now imports them from engine. Public re-exports from fireflyframework_agentic.pipeline are unchanged. Engine changes: - run() accepts approve_pause: bool = False kwarg. Resuming a paused checkpoint without approve_pause=True raises PipelineError with the pause reason. - _save_checkpoint accepts paused=, pause_reason= kwargs, persisted on CheckpointRecord (the fields landed in Layer 1A). - _load_for_resume enforces the approve_pause gate. - Main loop branches on Pause: emits on_node_pause event, checkpoints with paused=True, sets pending_pause and aborts. - Main loop branches on Send: dispatches workers via the new _run_sends helper, marks workers as completed so the scheduler does not re-run them. - _run_sends: validates target IDs up front; per-worker PipelineContext with its own state copy (payload applied via reducers); asyncio.gather across workers; results merge back into shared state via reducers; any worker failure aborts the pipeline. - _is_send_payload helper at module level. Result changes: - PipelineResult gains paused: bool, paused_node: str | None, pause_reason: str | None. Mirrors StatePipelineResult. Tests: 7 new in tests/unit/pipeline/test_pipeline_engine_pause_send.py covering: - Pause halts the pipeline and records paused state in result + checkpoint - Resume without approve_pause=True raises - Resume with approve_pause=True continues from the paused node's successor - list[Send] dispatches workers concurrently with per-worker state copies - Single Send is treated as list[Send] of one - Unknown Send target marks the pipeline as failed - Pause and Send remain re-exported from the pipeline package Full suite: 1587 passed. Refs: #245 --- fireflyframework_agentic/pipeline/engine.py | 207 +++++++++++++++++- fireflyframework_agentic/pipeline/result.py | 5 + .../pipeline/state_pipeline.py | 43 +--- .../test_pipeline_engine_pause_send.py | 203 +++++++++++++++++ 4 files changed, 413 insertions(+), 45 deletions(-) create mode 100644 tests/unit/pipeline/test_pipeline_engine_pause_send.py diff --git a/fireflyframework_agentic/pipeline/engine.py b/fireflyframework_agentic/pipeline/engine.py index 556fc071..cebee341 100644 --- a/fireflyframework_agentic/pipeline/engine.py +++ b/fireflyframework_agentic/pipeline/engine.py @@ -23,6 +23,7 @@ import random import time import uuid +from dataclasses import dataclass from datetime import UTC, datetime from typing import Any, Protocol, runtime_checkable @@ -130,6 +131,59 @@ async def on_pipeline_complete( ) -> None: ... +@dataclass +class Pause: + """Human-in-the-loop sentinel returned by a node to halt the pipeline. + + A node returns ``Pause(reason="...")`` when external approval is required + before the pipeline may continue. The engine then: + + 1. Writes a checkpoint with ``paused=True`` and the reason set. + 2. Emits ``on_node_pause`` on the configured event handler. + 3. Returns a :class:`PipelineResult` with ``paused=True`` and + ``success=False`` — the run is not finished, but it did not fail + either. + + Resume after approval:: + + result = await engine.run(run_id=paused_run_id, approve_pause=True) + + The successor of the paused node runs next — the pause node itself is + not re-executed. Without ``approve_pause=True``, resuming a paused run + raises :class:`PipelineError`. + """ + + reason: str + + +@dataclass +class Send: + """Runtime fan-out dispatch: run ``target`` with ``payload`` merged into state. + + A node may return a single ``Send`` or ``list[Send]`` to dispatch one or + more targets concurrently. Each Send's payload is applied to a *copy* of + the current state before its target runs; the target's return is then + merged back into shared state via reducers. + + Replaces the legacy ``FanOutStep`` pattern with a first-class primitive. + """ + + target: str + payload: dict[str, Any] + + +def _is_send_payload(value: Any) -> bool: + """True when a node's return value is a single :class:`Send` or a + non-empty ``list[Send]``. Drives the runtime fan-out branch in + :meth:`PipelineEngine.run`. + """ + if isinstance(value, Send): + return True + if isinstance(value, list) and value and all(isinstance(s, Send) for s in value): + return True + return False + + def _serialize_value(value: Any) -> Any: """Best-effort conversion of arbitrary values into JSON-safe form. @@ -245,6 +299,7 @@ async def run( inputs: Any = None, state: BaseModel | None = None, run_id: str | None = None, + approve_pause: bool = False, ) -> PipelineResult: """Execute the pipeline. @@ -265,7 +320,7 @@ async def run( """ if run_id is not None and context is None and inputs is None and state is None: resume_run_id: str = run_id - context, pre_completed, sequence_start = self._load_for_resume(resume_run_id) + context, pre_completed, sequence_start = self._load_for_resume(resume_run_id, approve_pause=approve_pause) all_results: dict[str, NodeResult] = { nid: nr for nid in pre_completed @@ -338,6 +393,7 @@ async def run( inputs_by_node: dict[str, dict[str, Any]] = {} sequence = sequence_start abort = False + pending_pause: tuple[str, str] | None = None # (node_id, reason) if Pause def _edge_alive(edge: Any) -> bool: """An edge is alive if it has no condition, or its condition returns True. @@ -465,6 +521,48 @@ async def _record_skip(nid: str) -> None: inputs_snapshot=inputs_by_node.get(node_id, {}), trace_entries=trace_entries, ) + # HITL: a node returned Pause(reason=...). Halt cleanly, save + # a paused checkpoint, and surface the pause in the result. + if nr.success and isinstance(nr.output, Pause): + pause_reason = nr.output.reason + await self._dispatch( + "on_node_pause", + pipeline_name=self._dag.name, + run_id=run_id, + node_id=node_id, + reason=pause_reason, + ) + self._save_checkpoint( + run_id=run_id, + node_id=node_id, + sequence=sequence, + context=context, + all_results=all_results, + paused=True, + pause_reason=pause_reason, + ) + pending_pause = (node_id, pause_reason) + abort = True + continue + + # Runtime fan-out: a node returned Send / list[Send]. + if nr.success and _is_send_payload(nr.output): + sends = nr.output if isinstance(nr.output, list) else [nr.output] + ok = await self._run_sends( + sends=sends, + context=context, + run_id=run_id, + all_results=all_results, + trace_entries=trace_entries, + completed=completed, + pending=pending, + ) + if not ok: + abort = True + # Successors of the worker targets are picked up by the + # normal readiness sweep on the next loop iteration. + continue + if nr.success and not nr.skipped: # State overlay: a dict return from the node is a state # update; non-dict returns flow through edges as ports. @@ -529,6 +627,12 @@ async def _record_skip(nid: str) -> None: if _pipeline_span is not None: _pipeline_span.end() + paused_node = pending_pause[0] if pending_pause else None + pause_reason_final = pending_pause[1] if pending_pause else None + if pending_pause is not None: + # A paused run is not "successful" — it didn't finish. + success = False + return PipelineResult( pipeline_name=self._dag.name, outputs=all_results, @@ -539,8 +643,91 @@ async def _record_skip(nid: str) -> None: usage=usage_summary, run_id=run_id, final_state=context.state, + paused=pending_pause is not None, + paused_node=paused_node, + pause_reason=pause_reason_final, ) + async def _run_sends( + self, + *, + sends: list[Send], + context: PipelineContext, + run_id: str, + all_results: dict[str, NodeResult], + trace_entries: list[ExecutionTraceEntry], + completed: set[str], + pending: set[str], + ) -> bool: + """Dispatch a list of :class:`Send` workers concurrently. + + Each Send's payload is applied to a copy of the current state before + its target runs. Results merge back into shared state via reducers. + Targets are added to ``completed`` and removed from ``pending`` so + the main scheduler does not re-execute them. + + Returns ``True`` on success, ``False`` if any worker failed (the + caller treats this as an abort signal). + """ + # Validate targets up front so unknown ones fail loud, not after gather(). + for send in sends: + if send.target not in self._dag.nodes: + nr = NodeResult( + node_id=send.target, + success=False, + error=f"Send dispatches to unknown target '{send.target}'", + ) + all_results[send.target] = nr + return False + + async def _run_one(send: Send) -> tuple[Send, NodeResult]: + await self._dispatch( + "on_node_start", + pipeline_name=self._dag.name, + run_id=run_id, + node_id=send.target, + visit=1, + ) + # Per-worker context: own state copy with payload applied so + # workers don't race on the shared state object. + worker_context = PipelineContext(inputs=context.inputs) + if self._state_schema is not None and context.state is not None: + worker_context.state = apply_update(context.state, send.payload, self._reducers) + for nid, prev in context.results.items(): + worker_context.set_node_result(nid, prev) + nr = await self._execute_node( + send.target, + worker_context, + trace_entries, + None, + inputs={"input": send.payload}, + run_id=run_id, + ) + return send, nr + + try: + results = await asyncio.gather(*(_run_one(s) for s in sends)) + except Exception as exc: + logger.exception("Fan-out worker crashed") + for send in sends: + if send.target not in all_results: + all_results[send.target] = NodeResult(node_id=send.target, success=False, error=str(exc)) + return False + + all_ok = True + for send, nr in results: + all_results[send.target] = nr + context.set_node_result(send.target, nr) + completed.add(send.target) + pending.discard(send.target) + await self._emit_node_result(nr, run_id) + if not nr.success: + all_ok = False + continue + if self._state_schema is not None and context.state is not None and isinstance(nr.output, dict): + context.state = apply_update(context.state, nr.output, self._reducers) + return all_ok + async def _run_cyclic( self, *, @@ -858,13 +1045,23 @@ async def _emit_node_result(self, nr: NodeResult, run_id: str) -> None: else: await self._dispatch("on_node_error", error=nr.error or "unknown", **common) - def _load_for_resume(self, run_id: str) -> tuple[PipelineContext, set[str], int]: - """Rebuild context + completed-set from the latest checkpoint.""" + def _load_for_resume(self, run_id: str, *, approve_pause: bool = False) -> tuple[PipelineContext, set[str], int]: + """Rebuild context + completed-set from the latest checkpoint. + + Resuming a paused run (checkpoint.paused=True) requires + ``approve_pause=True``; otherwise a :class:`PipelineError` halts the + attempt and surfaces the pause reason. + """ if self._checkpointer is None: raise PipelineError("Cannot resume: pipeline has no checkpointer configured") record = self._checkpointer.load_latest(self._dag.name, run_id) if record is None: raise PipelineError(f"No checkpoint found for run_id='{run_id}'") + if record.paused and not approve_pause: + raise PipelineError( + f"Run '{run_id}' is paused at node '{record.node_id}' " + f"(reason: {record.pause_reason!r}). Pass approve_pause=True to resume." + ) context = PipelineContext(inputs=record.state.get("inputs")) for nid, nr_dict in record.state.get("results", {}).items(): try: @@ -888,6 +1085,8 @@ def _save_checkpoint( sequence: int, context: PipelineContext, all_results: dict[str, NodeResult], + paused: bool = False, + pause_reason: str | None = None, ) -> None: """Persist state after a successful node. No-op if no checkpointer. @@ -911,6 +1110,8 @@ def _save_checkpoint( sequence=sequence, state=state, completed_nodes=completed_successful, + paused=paused, + pause_reason=pause_reason, ) ) except Exception: diff --git a/fireflyframework_agentic/pipeline/result.py b/fireflyframework_agentic/pipeline/result.py index 1b3e26ee..1f6bc260 100644 --- a/fireflyframework_agentic/pipeline/result.py +++ b/fireflyframework_agentic/pipeline/result.py @@ -82,6 +82,11 @@ class PipelineResult(BaseModel): # Final shared state for pipelines configured with state_schema. None # when the engine had no state overlay. final_state: Any = None + # HITL: a node returned :class:`Pause` and the run halted cleanly. + # Resume via ``engine.run(run_id=..., approve_pause=True)``. + paused: bool = False + paused_node: str | None = None + pause_reason: str | None = None @property def failed_nodes(self) -> list[str]: diff --git a/fireflyframework_agentic/pipeline/state_pipeline.py b/fireflyframework_agentic/pipeline/state_pipeline.py index c8c9247b..aae0c0c4 100644 --- a/fireflyframework_agentic/pipeline/state_pipeline.py +++ b/fireflyframework_agentic/pipeline/state_pipeline.py @@ -40,7 +40,7 @@ from fireflyframework_agentic.pipeline.audit import AuditEntry, AuditLog, AuditStatus from fireflyframework_agentic.pipeline.checkpoint import Checkpointer, CheckpointRecord from fireflyframework_agentic.pipeline.dag import DAG, _mermaid_id -from fireflyframework_agentic.pipeline.engine import start_otel_span +from fireflyframework_agentic.pipeline.engine import Pause, Send, start_otel_span from fireflyframework_agentic.pipeline.reducers import apply_update, discover_reducers if TYPE_CHECKING: @@ -53,51 +53,10 @@ RouterFn = Callable[[Any], "str | Send | list[Send]"] -@dataclass -class Send: - """Runtime fan-out dispatch: run ``target`` with ``payload`` merged into state. - - Routers can return a single ``Send`` or a list of ``Send`` to dispatch multiple - target invocations concurrently. Each Send's payload is applied to a *copy* - of the current state before its target runs; the target's return is then - merged back into shared state via reducers. - - Replaces the legacy ``FanOutStep`` pattern with a first-class primitive. - """ - - target: str - payload: dict[str, Any] - - class RecursionLimitError(Exception): """Raised when a node is visited more times than ``recursion_limit`` permits.""" -@dataclass -class Pause: - """Human-in-the-loop sentinel returned by a node to halt the pipeline. - - A node returns ``Pause(reason="...")`` when external approval (a human, - another system, a wall-clock event) is required before the pipeline may - continue. The pipeline then: - - 1. Writes a checkpoint with ``paused=True`` and the reason set. - 2. Emits ``on_node_pause`` on the configured event handler. - 3. Returns a :class:`StatePipelineResult` with ``paused=True`` and - ``success=False`` — the run is not finished, but it did not fail either. - - To resume after approval:: - - result = await pipeline.invoke(run_id=paused_run_id, approve_pause=True) - - Without ``approve_pause=True``, resuming a paused run raises - :class:`PipelineError`. The successor of the paused node runs next — - the pause node itself is not re-executed. - """ - - reason: str - - @dataclass class BranchSpec: """Internal: registered branch from one source node.""" diff --git a/tests/unit/pipeline/test_pipeline_engine_pause_send.py b/tests/unit/pipeline/test_pipeline_engine_pause_send.py new file mode 100644 index 00000000..f4526e15 --- /dev/null +++ b/tests/unit/pipeline/test_pipeline_engine_pause_send.py @@ -0,0 +1,203 @@ +"""Layer 5 of the unification (#245): Pause and Send in PipelineEngine. + +PipelineEngine recognizes the same control sentinels that StatePipeline +uses today: + +- A node returning :class:`Pause` halts the pipeline cleanly; the run + resumes with ``engine.run(run_id=..., approve_pause=True)``. +- A node returning :class:`Send` or ``list[Send]`` triggers a parallel + fan-out where each Send's target runs concurrently with the supplied + payload merged into a per-worker state copy. Reducers merge worker + outputs back into shared state. + +Both sentinels work in the acyclic and cyclic schedulers. +""" + +from __future__ import annotations + +from typing import Annotated + +import pytest +from pydantic import BaseModel + +from fireflyframework_agentic.exceptions import PipelineError +from fireflyframework_agentic.pipeline.checkpoint import FileCheckpointer +from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode + +# Pause and Send live in pipeline.engine now (moved from state_pipeline in +# this layer); the public re-export from pipeline/__init__.py is unchanged. +from fireflyframework_agentic.pipeline.engine import Pause, PipelineEngine, Send +from fireflyframework_agentic.pipeline.reducers import extend + +# ---- shared state --------------------------------------------------------- + + +class _LoopState(BaseModel): + items: Annotated[list[str], extend] = [] + approved: bool = False + deployed_to: str = "" + + +# ---- Pause ---------------------------------------------------------------- + + +def _step_pause(reason: str = "human gate"): + class _Step: + async def execute(self, ctx, inputs): + return Pause(reason=reason) + + return _Step() + + +def _step_record(label: str): + class _Step: + async def execute(self, ctx, inputs): + return {"items": [label]} + + return _Step() + + +async def test_pause_halts_pipeline_and_records_state(tmp_path): + """A node returning Pause halts: result.paused=True, success=False, checkpoint with paused=True.""" + build = _step_record("build") + gate = _step_pause("awaiting approval") + deploy = _step_record("deploy") + dag = DAG("hitl") + dag.add_node(DAGNode(node_id="build", step=build)) + dag.add_node(DAGNode(node_id="gate", step=gate)) + dag.add_node(DAGNode(node_id="deploy", step=deploy)) + dag.add_edge(DAGEdge(source="build", target="gate")) + dag.add_edge(DAGEdge(source="gate", target="deploy")) + cp = FileCheckpointer(tmp_path) + engine = PipelineEngine(dag, state_schema=_LoopState, checkpointer=cp) + result = await engine.run(inputs="") + assert result.paused is True + assert result.paused_node == "gate" + assert result.pause_reason == "awaiting approval" + assert not result.success + # State so far: only 'build' contributed. + assert result.final_state.items == ["build"] + # Checkpoint reflects the paused state. + record = cp.load_latest("hitl", result.run_id) + assert record is not None + assert record.paused is True + assert record.pause_reason == "awaiting approval" + + +async def test_resume_paused_run_requires_approve_pause(tmp_path): + build = _step_record("build") + gate = _step_pause("awaiting approval") + dag = DAG("needs-approve") + dag.add_node(DAGNode(node_id="build", step=build)) + dag.add_node(DAGNode(node_id="gate", step=gate)) + dag.add_edge(DAGEdge(source="build", target="gate")) + cp = FileCheckpointer(tmp_path) + engine = PipelineEngine(dag, state_schema=_LoopState, checkpointer=cp) + paused = await engine.run(inputs="") + assert paused.paused is True + # Without approve_pause: error. + with pytest.raises(PipelineError, match="paused"): + await engine.run(run_id=paused.run_id) + + +async def test_resume_with_approve_pause_continues_from_successor(tmp_path): + build = _step_record("build") + gate = _step_pause("awaiting approval") + deploy = _step_record("deploy") + dag = DAG("approved") + dag.add_node(DAGNode(node_id="build", step=build)) + dag.add_node(DAGNode(node_id="gate", step=gate)) + dag.add_node(DAGNode(node_id="deploy", step=deploy)) + dag.add_edge(DAGEdge(source="build", target="gate")) + dag.add_edge(DAGEdge(source="gate", target="deploy")) + cp = FileCheckpointer(tmp_path) + engine = PipelineEngine(dag, state_schema=_LoopState, checkpointer=cp) + paused = await engine.run(inputs="") + assert paused.paused is True + resumed = await engine.run(run_id=paused.run_id, approve_pause=True) + assert resumed.success + # 'gate' is NOT re-executed; only 'deploy' adds to items. + assert resumed.final_state.items == ["build", "deploy"] + + +# ---- Send ------------------------------------------------------------------ + + +def _step_emit_sends(targets: list[str]): + class _Step: + async def execute(self, ctx, inputs): + return [Send(target=t, payload={"items": [f"sent-{t}"]}) for t in targets] + + return _Step() + + +def _step_consume_payload(suffix: str): + """A worker that turns its inbound items into a state update.""" + + class _Step: + async def execute(self, ctx, inputs): + seen = list(ctx.state.items) + return {"items": [f"{s}+{suffix}" for s in seen]} + + return _Step() + + +async def test_send_dispatches_workers_concurrently(): + """One node returns list[Send]; targets run concurrently and their outputs merge.""" + planner = _step_emit_sends(["a", "b"]) + worker_a = _step_consume_payload("A") + worker_b = _step_consume_payload("B") + dag = DAG("fanout") + dag.add_node(DAGNode(node_id="planner", step=planner)) + dag.add_node(DAGNode(node_id="a", step=worker_a)) + dag.add_node(DAGNode(node_id="b", step=worker_b)) + dag.add_edge(DAGEdge(source="planner", target="a")) + dag.add_edge(DAGEdge(source="planner", target="b")) + engine = PipelineEngine(dag, state_schema=_LoopState) + result = await engine.run(inputs="") + assert result.success + # Each worker sees its own payload (a sees "sent-a", b sees "sent-b"). + assert sorted(result.final_state.items) == sorted(["sent-a+A", "sent-b+B"]) + + +async def test_single_send_is_treated_as_list_of_one(): + planner = _step_emit_sends(["a"]) + # Override to emit a single Send (not a list). + + class _Solo: + async def execute(self, ctx, inputs): + return Send(target="a", payload={"items": ["just-a"]}) + + worker_a = _step_consume_payload("X") + dag = DAG("solo") + dag.add_node(DAGNode(node_id="planner", step=_Solo())) + dag.add_node(DAGNode(node_id="a", step=worker_a)) + dag.add_edge(DAGEdge(source="planner", target="a")) + engine = PipelineEngine(dag, state_schema=_LoopState) + result = await engine.run(inputs="") + assert result.success + assert result.final_state.items == ["just-a+X"] + + +async def test_send_to_unknown_target_raises(): + class _Bad: + async def execute(self, ctx, inputs): + return [Send(target="ghost", payload={})] + + dag = DAG("unknown-send") + dag.add_node(DAGNode(node_id="planner", step=_Bad())) + engine = PipelineEngine(dag, state_schema=_LoopState) + result = await engine.run(inputs="") + # The fan-out fails; the pipeline reports failure. + assert not result.success + + +# ---- Pause exports --------------------------------------------------------- + + +def test_pause_and_send_reexported_from_pipeline_package(): + from fireflyframework_agentic.pipeline import Pause as P + from fireflyframework_agentic.pipeline import Send as S + + assert P is Pause + assert S is Send From 0f1686eccaf7fe6b91c8006bdcfbb4a94946b242 Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Thu, 28 May 2026 15:18:01 +0200 Subject: [PATCH 22/26] feat(pipeline): start_at kwarg for mid-pipeline entry (#245 layer 6) Sixth layer of the unification. PipelineEngine.run() accepts start_at= (string node id or callable reference) to begin execution mid-DAG. Nodes not reachable from start_at are treated as pre-completed and skipped. Resume (run_id) and approve_pause already landed in layers 1A and 5; this completes the entry-control kwargs. Engine changes: - run() accepts start_at: str | Callable | None = None. - New _resolve_node_id helper at module level. - pre_completed initialized from all_nodes - forward(start_at). - Works with state overlay, edge conditions, and cyclic mode. Tests: 7 new in tests/unit/pipeline/test_pipeline_engine_start_at.py. Full suite: 1594 passed. Refs: #245 --- fireflyframework_agentic/pipeline/engine.py | 24 ++++ .../pipeline/test_pipeline_engine_start_at.py | 123 ++++++++++++++++++ 2 files changed, 147 insertions(+) create mode 100644 tests/unit/pipeline/test_pipeline_engine_start_at.py diff --git a/fireflyframework_agentic/pipeline/engine.py b/fireflyframework_agentic/pipeline/engine.py index cebee341..cea10b14 100644 --- a/fireflyframework_agentic/pipeline/engine.py +++ b/fireflyframework_agentic/pipeline/engine.py @@ -172,6 +172,20 @@ class Send: payload: dict[str, Any] +def _resolve_node_id(ref: Any) -> str: + """Turn either a string node id or a function reference into a node id. + + Function references use ``fn.__name__``. Anything else raises + :class:`PipelineError`. + """ + if isinstance(ref, str): + return ref + name = getattr(ref, "__name__", None) + if not name: + raise PipelineError(f"Cannot derive node id from {ref!r}") + return name + + def _is_send_payload(value: Any) -> bool: """True when a node's return value is a single :class:`Send` or a non-empty ``list[Send]``. Drives the runtime fan-out branch in @@ -300,6 +314,7 @@ async def run( state: BaseModel | None = None, run_id: str | None = None, approve_pause: bool = False, + start_at: str | Any = None, ) -> PipelineResult: """Execute the pipeline. @@ -345,6 +360,15 @@ async def run( pre_completed = set() sequence_start = 0 all_results = {} + # Mid-pipeline start: pretend everything not reachable from + # `start_at` already ran. The scheduler then starts at start_at + # because its upstream nodes appear "completed". + if start_at is not None: + start_id = _resolve_node_id(start_at) + if start_id not in self._dag.nodes: + raise PipelineError(f"start_at='{start_id}' not in DAG") + forward = {start_id} | self._dag.transitive_successors(start_id) + pre_completed = {nid for nid in self._dag.nodes if nid not in forward} if run_id is None: run_id = uuid.uuid4().hex[:12] diff --git a/tests/unit/pipeline/test_pipeline_engine_start_at.py b/tests/unit/pipeline/test_pipeline_engine_start_at.py new file mode 100644 index 00000000..81003ec3 --- /dev/null +++ b/tests/unit/pipeline/test_pipeline_engine_start_at.py @@ -0,0 +1,123 @@ +"""Layer 6 of the unification (#245): start_at kwarg for mid-pipeline entry. + +PipelineEngine.run() accepts ``start_at=`` (string node id or callable +reference). Execution begins at the named node; everything not reachable +from it is treated as pre-completed and skipped. This is the unified +equivalent of StatePipeline.invoke(state=..., start_at=...). +""" + +from __future__ import annotations + +import pytest + +from fireflyframework_agentic.exceptions import PipelineError +from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode +from fireflyframework_agentic.pipeline.engine import PipelineEngine + + +class _Counting: + def __init__(self, name: str): + self.name = name + self.calls = 0 + + async def execute(self, ctx, inputs): + self.calls += 1 + return self.name + + +def _chain(*ids: str) -> tuple[DAG, dict[str, _Counting]]: + dag = DAG("chain") + steps: dict[str, _Counting] = {} + for nid in ids: + s = _Counting(nid) + steps[nid] = s + dag.add_node(DAGNode(node_id=nid, step=s)) + for i in range(len(ids) - 1): + dag.add_edge(DAGEdge(source=ids[i], target=ids[i + 1])) + return dag, steps + + +# ---- baseline ------------------------------------------------------------- + + +async def test_no_start_at_runs_every_node(): + dag, steps = _chain("a", "b", "c") + result = await PipelineEngine(dag).run(inputs="x") + assert result.success + assert all(s.calls == 1 for s in steps.values()) + + +# ---- start_at: skip upstream ---------------------------------------------- + + +async def test_start_at_skips_upstream_nodes(): + dag, steps = _chain("a", "b", "c", "d") + result = await PipelineEngine(dag).run(inputs="x", start_at="c") + assert result.success + assert steps["a"].calls == 0 + assert steps["b"].calls == 0 + assert steps["c"].calls == 1 + assert steps["d"].calls == 1 + + +async def test_start_at_first_node_is_like_no_start_at(): + dag, steps = _chain("a", "b", "c") + result = await PipelineEngine(dag).run(inputs="x", start_at="a") + assert result.success + assert all(s.calls == 1 for s in steps.values()) + + +async def test_start_at_terminal_runs_only_that_node(): + dag, steps = _chain("a", "b", "c") + result = await PipelineEngine(dag).run(inputs="x", start_at="c") + assert result.success + assert steps["a"].calls == 0 + assert steps["b"].calls == 0 + assert steps["c"].calls == 1 + + +# ---- start_at via callable ------------------------------------------------ + + +async def test_start_at_accepts_callable_reference(): + async def deploy(ctx, inputs): + return "deployed" + + from fireflyframework_agentic.pipeline.steps import CallableStep + + dag = DAG("callable") + dag.add_node(DAGNode(node_id="build", step=_Counting("build"))) + dag.add_node(DAGNode(node_id="deploy", step=CallableStep(deploy))) + dag.add_edge(DAGEdge(source="build", target="deploy")) + result = await PipelineEngine(dag).run(inputs="x", start_at=deploy) + assert result.success + # Resolves deploy.__name__ -> 'deploy' -> only deploy ran. + assert result.outputs["deploy"].output == "deployed" + + +# ---- invalid start_at ----------------------------------------------------- + + +async def test_unknown_start_at_raises(): + dag, _ = _chain("a", "b") + with pytest.raises(PipelineError, match="start_at"): + await PipelineEngine(dag).run(inputs="x", start_at="ghost") + + +# ---- branching dag -------------------------------------------------------- + + +async def test_start_at_in_branching_dag(): + """In a branching DAG, start_at picks one branch; the other is skipped entirely.""" + dag = DAG("branchy") + for nid in ("root", "left", "right", "leftchild"): + dag.add_node(DAGNode(node_id=nid, step=_Counting(nid))) + dag.add_edge(DAGEdge(source="root", target="left")) + dag.add_edge(DAGEdge(source="root", target="right")) + dag.add_edge(DAGEdge(source="left", target="leftchild")) + result = await PipelineEngine(dag).run(inputs="x", start_at="left") + assert result.success + # 'right' is not downstream of 'left' — it should not run. + assert "right" not in result.outputs or result.outputs["right"].skipped + # 'leftchild' is reachable from 'left'. + assert result.outputs["leftchild"].success From 56d52ecdec217d711b06e185909cd5aeaa8c687c Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Thu, 28 May 2026 15:24:41 +0200 Subject: [PATCH 23/26] feat(pipeline): deprecate StatePipeline (#245 layer 7) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Seventh layer of the unification. StatePipeline now emits a DeprecationWarning on construction pointing users at PipelineEngine configured with state_schema= as the unified replacement. PipelineEngine has feature parity with StatePipeline after layers 1-6: state overlay, reducers, Pause, Send, cycles, recursion_limit, checkpointing, audit log, resume, start_at — plus parallel state-aware execution that StatePipeline could not provide. Changes: - StatePipeline.__init__ raises DeprecationWarning with migration text. - import warnings added at module level. - Class docstring includes a deprecation notice and migration example. StatePipeline still works — all 37 existing state-pipeline tests pass unchanged (they emit the new DeprecationWarning but continue to operate). Tests: 1 new in tests/unit/pipeline/test_state_pipeline_deprecation.py verifying the warning is raised and references PipelineEngine + #245. Full suite: 1595 passed. Layer 8 (full deletion of state_pipeline.py and dual-mode logic in builder.py) is tracked as follow-up — it requires migrating the 37 state-pipeline tests to the unified PipelineEngine surface and translating PipelineBuilder.branch(source, router, mapping) into conditional DAGEdges, which is more invasive than this deprecation layer warrants. Refs: #245 --- .../pipeline/state_pipeline.py | 32 +++++++++++++++++++ .../test_state_pipeline_deprecation.py | 32 +++++++++++++++++++ 2 files changed, 64 insertions(+) create mode 100644 tests/unit/pipeline/test_state_pipeline_deprecation.py diff --git a/fireflyframework_agentic/pipeline/state_pipeline.py b/fireflyframework_agentic/pipeline/state_pipeline.py index aae0c0c4..07a3b1ea 100644 --- a/fireflyframework_agentic/pipeline/state_pipeline.py +++ b/fireflyframework_agentic/pipeline/state_pipeline.py @@ -29,6 +29,7 @@ import logging import time import uuid +import warnings from collections.abc import Awaitable, Callable from dataclasses import dataclass from datetime import UTC, datetime @@ -97,6 +98,30 @@ class StatePipelineResult: class StatePipeline: """Compiled state-based pipeline. Returned by ``PipelineBuilder.build()`` when a ``state=`` schema is configured. + + .. deprecated:: + :class:`StatePipeline` is being subsumed by + :class:`fireflyframework_agentic.pipeline.engine.PipelineEngine` + configured with ``state_schema=``. The unified engine supports + the same features (state overlay, reducers, Pause, Send, cycles, + recursion_limit, checkpointing, audit, resume, start_at) and adds + true parallelism for state-aware pipelines via the topological + scheduler. New code should prefer ``PipelineEngine`` directly: + + .. code-block:: python + + engine = PipelineEngine( + dag, + state_schema=MyState, + checkpointer=cp, + audit_log=al, + recursion_limit=10, + ) + result = await engine.run(state=MyState(...)) + + See issue #245 for the full migration plan. The next layer of + unification removes :class:`StatePipeline` after a deprecation + cycle. """ def __init__( @@ -112,6 +137,13 @@ def __init__( event_handler: StatePipelineEventHandler | None = None, audit_log: AuditLog | None = None, ) -> None: + warnings.warn( + "StatePipeline is deprecated; use PipelineEngine(state_schema=...) " + "for the unified API. The unified engine supports the same features " + "and adds parallel state-aware execution. See issue #245.", + DeprecationWarning, + stacklevel=2, + ) self._name = name self._dag = dag self._state_schema = state_schema diff --git a/tests/unit/pipeline/test_state_pipeline_deprecation.py b/tests/unit/pipeline/test_state_pipeline_deprecation.py new file mode 100644 index 00000000..ec65da49 --- /dev/null +++ b/tests/unit/pipeline/test_state_pipeline_deprecation.py @@ -0,0 +1,32 @@ +"""Layer 7 of the unification (#245): StatePipeline deprecation. + +Constructing :class:`StatePipeline` now emits a :class:`DeprecationWarning` +pointing at :class:`PipelineEngine` configured with ``state_schema=`` as +the supported replacement. +""" + +from __future__ import annotations + +import warnings + +from pydantic import BaseModel + +from fireflyframework_agentic.pipeline.builder import PipelineBuilder + + +class _S(BaseModel): + x: int = 0 + + +async def _noop(state): + return None + + +def test_state_pipeline_emits_deprecation_warning(): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + PipelineBuilder("p", state=_S).add_node(_noop).build() + deprec = [w for w in caught if issubclass(w.category, DeprecationWarning)] + assert deprec, "expected a DeprecationWarning when constructing StatePipeline" + assert "PipelineEngine" in str(deprec[0].message) + assert "#245" in str(deprec[0].message) From f09d179b7b11f6a1545c80f775ed42911dae4ae5 Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Thu, 28 May 2026 15:33:51 +0200 Subject: [PATCH 24/26] fix(ci): satisfy stricter ruff rules on PR gate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three Ruff rules that the project ruff CI enforces but my local ruff missed (newer CI version): - SIM103 in engine.py:_is_send_payload — collapse the second if/return pair into a direct return. - F841 in test_pipeline_engine_pause_send.py — remove an unused 'planner' variable that was leftover from an earlier draft of test_single_send_is_treated_as_list_of_one. - N817 in same test file — rename 'Pause as P' / 'Send as S' to 'PausePkg'/'SendPkg' (single-letter CamelCase aliases trip the acronym rule). No behavior change. Tests still 7/7 in that file. --- fireflyframework_agentic/pipeline/engine.py | 4 +--- .../unit/pipeline/test_pipeline_engine_pause_send.py | 12 +++++------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/fireflyframework_agentic/pipeline/engine.py b/fireflyframework_agentic/pipeline/engine.py index cea10b14..8317e5b1 100644 --- a/fireflyframework_agentic/pipeline/engine.py +++ b/fireflyframework_agentic/pipeline/engine.py @@ -193,9 +193,7 @@ def _is_send_payload(value: Any) -> bool: """ if isinstance(value, Send): return True - if isinstance(value, list) and value and all(isinstance(s, Send) for s in value): - return True - return False + return isinstance(value, list) and bool(value) and all(isinstance(s, Send) for s in value) def _serialize_value(value: Any) -> Any: diff --git a/tests/unit/pipeline/test_pipeline_engine_pause_send.py b/tests/unit/pipeline/test_pipeline_engine_pause_send.py index f4526e15..32e56c85 100644 --- a/tests/unit/pipeline/test_pipeline_engine_pause_send.py +++ b/tests/unit/pipeline/test_pipeline_engine_pause_send.py @@ -161,9 +161,7 @@ async def test_send_dispatches_workers_concurrently(): async def test_single_send_is_treated_as_list_of_one(): - planner = _step_emit_sends(["a"]) - # Override to emit a single Send (not a list). - + # A planner step that emits one Send directly (not wrapped in a list). class _Solo: async def execute(self, ctx, inputs): return Send(target="a", payload={"items": ["just-a"]}) @@ -196,8 +194,8 @@ async def execute(self, ctx, inputs): def test_pause_and_send_reexported_from_pipeline_package(): - from fireflyframework_agentic.pipeline import Pause as P - from fireflyframework_agentic.pipeline import Send as S + from fireflyframework_agentic.pipeline import Pause as PausePkg + from fireflyframework_agentic.pipeline import Send as SendPkg - assert P is Pause - assert S is Send + assert PausePkg is Pause + assert SendPkg is Send From 19d3b636a2b7466bb16958fa5299d77d6127d1fd Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Thu, 28 May 2026 16:29:01 +0200 Subject: [PATCH 25/26] feat(pipeline): delete StatePipeline, unify on PipelineEngine (#245 layer 8) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Final layer of the unification. StatePipeline + state_pipeline.py + the dual-mode logic in builder.py are gone; PipelineEngine is the only executor. PipelineBuilder(state=...) now constructs a PipelineEngine configured with state_schema, recursion_limit, audit_log, checkpointer, event_handler, and any routers registered via .branch(...). No deprecation cycle needed: StatePipeline never landed on main (introduced in #232 which is the umbrella PR). All consumers move at the same time. Engine changes: - PipelineEngine accepts routers: dict[str, RouterFn] and router_mappings: dict[str, dict[str, str]]. When non-empty (or the DAG is cyclic), the engine routes through the cyclic frontier scheduler. - New _cyclic_next consults routers first, then falls back to edge conditions. - _resolve_router_decision lifts the router-return → next-step translation from state_pipeline. - _run_cyclic now handles Pause, Send, and list[Send] returns alongside the dict/state-update path. - _run_sends accepts visit_counts; per-Send visit numbers are tracked per target so observability shows the right counter across fan-out. - PipelineEngine.to_mermaid renders branch labels from router_mappings. - PipelineEngine.invoke(state, ...) shorthand mirrors the StatePipeline.invoke signature so test migration stays mechanical. - _record_audit accepts status_override + pause_reason so a Pause-returning node is audited as "paused" rather than "success". - Span names use "pipeline.state.*" when state_schema is set (matches the legacy taxonomy observability dashboards already key on). - on_node_start is now emitted by the schedulers (acyclic main loop, cyclic loop, _run_sends) — removed from _execute_node so the visit number is correct in every scheduler. - _load_for_resume returns list[str] (ordered) instead of set[str] so PipelineResult.completed_nodes after resume reflects the original execution order. - Resume seeds trace_entries with the pre_completed nodes so PipelineResult.completed_nodes (now derived from execution_trace) includes the full history. Builder changes: - PipelineBuilder.build() always returns PipelineEngine. The state-mode branch wires state_schema/routers/router_mappings into the engine. - New _StateStepAdapter wraps a state-mode fn into a StepExecutor so the unified engine's _execute_node can run it. - _coerce_state_node_fn (moved from state_pipeline.py) keeps the "async def, def, or .run(state) object" forms working. - _StateNodePlaceholder is gone — state-mode nodes now run real code. Result changes: - PipelineResult grows state-mode convenience properties: state (alias of final_state), completed_nodes (derived from execution_trace so cyclic visits appear individually), failed_node, error. Deletions: - fireflyframework_agentic/pipeline/state_pipeline.py (~750 LOC) - StatePipeline, StatePipelineResult, StatePipelineEventHandler, RecursionLimitError, BranchSpec, StateNodeFn (in state_pipeline) - _StateNodePlaceholder (in builder) - tests/unit/pipeline/test_state_pipeline_deprecation.py (Layer 7's warning is gone too) - " StatePipeline," imports and isinstance assertions across the state-pipeline test files. Migration notes (internal to this PR): - Pause and Send already live in engine.py since Layer 5; this PR just removes the state_pipeline re-export. - StatePipelineEventHandler removed; existing usages either implement EventHandler (the unified protocol) or the legacy PipelineEventHandler. Full unit suite: 1594 passed. Net diff: -507 LOC (-955 deleted, +448 added). Refs: #245 --- examples/software_factory/progress.py | 2 +- fireflyframework_agentic/pipeline/__init__.py | 14 +- fireflyframework_agentic/pipeline/builder.py | 174 ++-- .../pipeline/checkpoint.py | 2 +- fireflyframework_agentic/pipeline/engine.py | 384 +++++++-- fireflyframework_agentic/pipeline/result.py | 33 + .../pipeline/state_pipeline.py | 751 ------------------ .../unit/pipeline/test_checkpoint_backends.py | 4 +- tests/unit/pipeline/test_state_pipeline.py | 2 - .../test_state_pipeline_deprecation.py | 32 - .../unit/pipeline/test_state_pipeline_hitl.py | 2 - .../pipeline/test_state_pipeline_phase2.py | 3 - 12 files changed, 448 insertions(+), 955 deletions(-) delete mode 100644 fireflyframework_agentic/pipeline/state_pipeline.py delete mode 100644 tests/unit/pipeline/test_state_pipeline_deprecation.py diff --git a/examples/software_factory/progress.py b/examples/software_factory/progress.py index 4f9492c9..f43892e6 100644 --- a/examples/software_factory/progress.py +++ b/examples/software_factory/progress.py @@ -5,7 +5,7 @@ """Console progress handler. -Implements (structurally) the framework's :class:`StatePipelineEventHandler` +Implements (structurally) the framework's :class:`EventHandler` Protocol. Prints one line per pipeline / node event so the QA loop and checkpoint+resume flow are visible when running the example by hand. """ diff --git a/fireflyframework_agentic/pipeline/__init__.py b/fireflyframework_agentic/pipeline/__init__.py index 4fa01bd3..c4f69fd0 100644 --- a/fireflyframework_agentic/pipeline/__init__.py +++ b/fireflyframework_agentic/pipeline/__init__.py @@ -46,19 +46,13 @@ from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode, FailureStrategy from fireflyframework_agentic.pipeline.engine import ( EventHandler, + Pause, PipelineEngine, PipelineEventHandler, - StatePipelineEventHandler, + Send, ) from fireflyframework_agentic.pipeline.reducers import append, extend, merge_dict, replace from fireflyframework_agentic.pipeline.result import ExecutionTraceEntry, NodeResult, PipelineResult -from fireflyframework_agentic.pipeline.state_pipeline import ( - Pause, - RecursionLimitError, - Send, - StatePipeline, - StatePipelineResult, -) from fireflyframework_agentic.pipeline.steps import ( AgentStep, BatchLLMStep, @@ -103,12 +97,8 @@ "PipelineResult", "QueryableAuditLog", "ReasoningStep", - "RecursionLimitError", "RetrievalStep", "Send", - "StatePipeline", - "StatePipelineEventHandler", - "StatePipelineResult", "StepExecutor", "append", "extend", diff --git a/fireflyframework_agentic/pipeline/builder.py b/fireflyframework_agentic/pipeline/builder.py index 512bd06d..d6a6805b 100644 --- a/fireflyframework_agentic/pipeline/builder.py +++ b/fireflyframework_agentic/pipeline/builder.py @@ -14,11 +14,10 @@ """Fluent builder API for constructing pipeline DAGs. -Two modes: +Two modes, both backed by the same :class:`PipelineEngine`: -1. **Port-based** (legacy, parallel-friendly): nodes are added by string id, - data flows over edge ports, executed by :class:`PipelineEngine`. Use this - for ETL-shaped DAGs with independent parallel steps:: +1. **Port-based** (parallel-friendly): nodes are added by string id, data + flows over edge ports:: pipeline = ( PipelineBuilder("idp") @@ -29,10 +28,10 @@ ) 2. **State-based**: configure ``state=SomeModel`` and nodes become - ``async (state) -> dict`` functions over a typed shared state. Branching - is one ``.branch(source, router)`` call. Function references can be used - as node ids. Optional checkpointing supports resume after failure and - mid-pipeline start. Produces a :class:`StatePipeline`:: + ``async (state) -> dict | None | Pause | Send | list[Send]``. Branching + is one ``.branch(source, router)`` call; function references work as + node ids. Optional checkpointing supports resume after failure and + mid-pipeline start:: pipeline = ( PipelineBuilder("agent", state=AgentState, checkpointer=FileCheckpointer("./ckpt")) @@ -46,8 +45,9 @@ from __future__ import annotations +import asyncio import inspect -from collections.abc import Callable +from collections.abc import Awaitable, Callable from typing import Any from pydantic import BaseModel @@ -55,18 +55,69 @@ from fireflyframework_agentic.exceptions import PipelineError from fireflyframework_agentic.pipeline.audit import AuditLog from fireflyframework_agentic.pipeline.checkpoint import Checkpointer +from fireflyframework_agentic.pipeline.context import PipelineContext from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode, FailureStrategy -from fireflyframework_agentic.pipeline.engine import PipelineEngine, StatePipelineEventHandler -from fireflyframework_agentic.pipeline.state_pipeline import ( - BranchSpec, +from fireflyframework_agentic.pipeline.engine import ( + EventHandler, + PipelineEngine, + PipelineEventHandler, RouterFn, Send, # noqa: F401 re-exported via pipeline/__init__.py - StateNodeFn, - StatePipeline, - coerce_state_node_fn, ) from fireflyframework_agentic.pipeline.steps import AgentStep, CallableStep, StepExecutor +StateNodeFn = Callable[[Any], Awaitable[Any]] +"""Signature for a state-mode node: ``async (state) -> dict | None | Pause | Send | list[Send]``.""" + + +class _StateStepAdapter: + """Adapts a state-mode node fn into the :class:`StepExecutor` shape so it + can ride through :meth:`PipelineEngine._execute_node`. + + State-mode functions take ``(state)`` and return a state update (or one + of the control sentinels :class:`Pause` / :class:`Send`). The engine + calls ``step.execute(context, inputs)``; this adapter forwards + ``context.state`` to the wrapped fn and returns its value verbatim so + PipelineEngine's existing dict/Pause/Send handling fires. + """ + + def __init__(self, fn: Callable[..., Any]) -> None: + self._fn = _coerce_state_node_fn(fn) + + async def execute(self, context: PipelineContext, inputs: dict[str, Any]) -> Any: # noqa: ARG002 + return await self._fn(context.state) + + +def _coerce_state_node_fn(fn: Callable[..., Any]) -> StateNodeFn: + """Turn user-supplied state-mode callables into the standard ``async (state) -> Any`` shape. + + Accepts: + * ``async def f(state)`` — used as-is. + * ``def f(state)`` — wrapped to run on a worker thread. + * Object with ``async run(state)`` (e.g. a FireflyAgent) — adapter calls ``.run(state)``. + """ + if inspect.iscoroutinefunction(fn): + return fn # type: ignore[return-value] + + run = getattr(fn, "run", None) + if not callable(fn) and run is not None and callable(run): + + async def _agent_wrap(state: Any) -> Any: + if inspect.iscoroutinefunction(run): + return await run(state) + return await asyncio.get_running_loop().run_in_executor(None, run, state) + + return _agent_wrap + + if callable(fn): + + async def _async_wrap(state: Any) -> Any: + return await asyncio.get_running_loop().run_in_executor(None, fn, state) + + return _async_wrap + + raise PipelineError(f"Cannot adapt {fn!r} as a state node function") + class PipelineBuilder: """Fluent builder for pipelines. @@ -74,10 +125,14 @@ class PipelineBuilder: Parameters: name: Human-readable name for the pipeline. state: Optional Pydantic model class for typed shared state. - When set, the builder produces a :class:`StatePipeline` and nodes - are expected to be ``async (state) -> dict | None``. - checkpointer: Optional :class:`Checkpointer` for state-based pipelines. - Ignored when ``state`` is not set. + When set, the builder produces a state-aware + :class:`PipelineEngine` and nodes are expected to be + ``async (state) -> dict | None | Pause | Send | list[Send]``. + checkpointer: Optional :class:`Checkpointer` for resume. + recursion_limit: Max visits per node in cycle-aware runs. + event_handler: Optional :class:`EventHandler` (or legacy + :class:`PipelineEventHandler`). + audit_log: Optional :class:`AuditLog`. """ def __init__( @@ -87,11 +142,10 @@ def __init__( state: type[BaseModel] | None = None, checkpointer: Checkpointer | None = None, recursion_limit: int = 25, - event_handler: StatePipelineEventHandler | None = None, + event_handler: EventHandler | PipelineEventHandler | None = None, audit_log: AuditLog | None = None, ) -> None: - # State pipelines may use cyclic graphs (ReAct loops, retry-with-critique). - # The legacy port-based path keeps acyclicity as an invariant. + # State-aware pipelines may have cycles (ReAct loops, retry-with-critique). self._dag = DAG(name=name, allow_cycles=state is not None) self._name = name self._state_schema = state @@ -101,9 +155,9 @@ def __init__( self._audit_log = audit_log self._pending_nodes: list[DAGNode] = [] self._pending_edges: list[DAGEdge] = [] - # State-based mode bookkeeping. Keyed by node id. - self._state_node_fns: dict[str, StateNodeFn] = {} - self._branches: dict[str, BranchSpec] = {} + # Routers + mappings drive the cyclic scheduler's next-step pick. + self._routers: dict[str, RouterFn] = {} + self._router_mappings: dict[str, dict[str, str]] = {} def add_node( self, @@ -122,11 +176,10 @@ def add_node( * ``add_node(fn)`` — state-based mode. ``fn`` is a callable; the node id is taken from ``fn.__name__``. Requires the builder was constructed with ``state=...``. - * ``add_node(node_id, step)`` — legacy port-based mode. ``step`` is a + * ``add_node(node_id, step)`` — port-based mode. ``step`` is a :class:`StepExecutor`, an agent-like, or an async callable. """ if step is None and callable(node_id_or_fn) and not isinstance(node_id_or_fn, str): - # State-based: derive id from function name. if self._state_schema is None: raise PipelineError( "Function-reference add_node(fn) requires PipelineBuilder(state=...). " @@ -134,11 +187,10 @@ def add_node( ) fn = node_id_or_fn node_id = getattr(fn, "__name__", None) or repr(fn) - self._state_node_fns[node_id] = coerce_state_node_fn(fn) self._pending_nodes.append( DAGNode( node_id=node_id, - step=_StateNodePlaceholder(), # never executed; engine path is unused for state pipelines + step=_StateStepAdapter(fn), condition=condition, retry_max=retry_max, timeout_seconds=timeout_seconds, @@ -152,19 +204,16 @@ def add_node( node_id = node_id_or_fn if self._state_schema is not None and step is not None: - # State-based pipeline: accept a callable, or an agent-like object - # exposing async ``run(state)``. ``coerce_state_node_fn`` handles both. run_method = getattr(step, "run", None) if not callable(step) and not callable(run_method): raise PipelineError( f"State pipeline node '{node_id}' must be a callable or expose async run(state); " f"got {type(step).__name__}" ) - self._state_node_fns[node_id] = coerce_state_node_fn(step) self._pending_nodes.append( DAGNode( node_id=node_id, - step=_StateNodePlaceholder(), + step=_StateStepAdapter(step), condition=condition, retry_max=retry_max, timeout_seconds=timeout_seconds, @@ -225,61 +274,50 @@ def branch( router: RouterFn, mapping: dict[str, str | Callable[..., Any]] | None = None, ) -> PipelineBuilder: - """Register a router on ``source``. + """Register a runtime router on ``source``. - ``router`` is a synchronous ``(state) -> str`` callable. Behaviour: + ``router`` is a synchronous ``(state) -> str | Send | list[Send]`` + callable. Behaviour: * If ``mapping`` is None, the router must return the **id of an existing node** that will run next. * If ``mapping`` is provided, the router returns an abstract label that is looked up in ``mapping`` to find the target node id. - State-based pipelines only. + State-aware pipelines only. """ if self._state_schema is None: raise PipelineError(".branch(...) requires PipelineBuilder(state=...)") source_id = _id(source) - resolved_mapping: dict[str, str] | None = None if mapping is not None: resolved_mapping = {label: _id(target) for label, target in mapping.items()} - # Materialize each label's edge into the DAG so topology is inspectable. + self._router_mappings[source_id] = resolved_mapping + # Materialize each label's edge so topology stays inspectable. for target_id in resolved_mapping.values(): self._pending_edges.append(DAGEdge(source=source_id, target=target_id)) - else: - # No mapping: we don't know targets at build time; edges will - # be missing from the DAG. That's fine for the StatePipeline - # executor (it consults the router), but visualisation will be - # incomplete. Materialize edges lazily when the router fires. - pass - self._branches[source_id] = BranchSpec(source=source_id, router=router, mapping=resolved_mapping) + self._routers[source_id] = router return self - def build(self) -> PipelineEngine | StatePipeline: - """Build the DAG and return either a :class:`PipelineEngine` - (legacy port-based) or :class:`StatePipeline` (when ``state=`` is set). - """ + def build(self) -> PipelineEngine: + """Build the DAG and return a :class:`PipelineEngine`.""" for node in self._pending_nodes: self._dag.add_node(node) for edge in self._pending_edges: self._dag.add_edge(edge) - if self._state_schema is not None: - return StatePipeline( - name=self._name, - dag=self._dag, - state_schema=self._state_schema, - node_fns=self._state_node_fns, - branches=self._branches, - checkpointer=self._checkpointer, - recursion_limit=self._recursion_limit, - event_handler=self._event_handler, - audit_log=self._audit_log, - ) - - return PipelineEngine(self._dag) + return PipelineEngine( + self._dag, + event_handler=self._event_handler, + checkpointer=self._checkpointer, + audit_log=self._audit_log, + state_schema=self._state_schema, + recursion_limit=self._recursion_limit, + routers=self._routers, + router_mappings=self._router_mappings, + ) def build_dag(self) -> DAG: - """Build and return just the :class:`DAG` (for inspection or custom engines).""" + """Build and return just the :class:`DAG` (for inspection).""" for node in self._pending_nodes: self._dag.add_node(node) for edge in self._pending_edges: @@ -309,11 +347,3 @@ def _id(ref: str | Callable[..., Any]) -> str: if not name: raise PipelineError(f"Cannot derive node id from {ref!r}") return name - - -class _StateNodePlaceholder: - """Sentinel step kept in the DAG so topology is intact. Never executed — - state pipelines bypass :class:`PipelineEngine` entirely.""" - - async def execute(self, *_args: Any, **_kwargs: Any) -> Any: - raise PipelineError("_StateNodePlaceholder.execute called — state pipelines should not use PipelineEngine.") diff --git a/fireflyframework_agentic/pipeline/checkpoint.py b/fireflyframework_agentic/pipeline/checkpoint.py index 765b8355..b3cf6b94 100644 --- a/fireflyframework_agentic/pipeline/checkpoint.py +++ b/fireflyframework_agentic/pipeline/checkpoint.py @@ -38,7 +38,7 @@ class CheckpointRecord(BaseModel): """One saved checkpoint. ``paused`` and ``pause_reason`` are set when a node returns - :class:`fireflyframework_agentic.pipeline.state_pipeline.Pause`. Default + :class:`fireflyframework_agentic.pipeline.engine.Pause`. Default to ``False`` / ``None`` so existing records from earlier phases load cleanly under the new schema. """ diff --git a/fireflyframework_agentic/pipeline/engine.py b/fireflyframework_agentic/pipeline/engine.py index 8317e5b1..de644ad8 100644 --- a/fireflyframework_agentic/pipeline/engine.py +++ b/fireflyframework_agentic/pipeline/engine.py @@ -23,9 +23,10 @@ import random import time import uuid +from collections.abc import Callable from dataclasses import dataclass from datetime import UTC, datetime -from typing import Any, Protocol, runtime_checkable +from typing import Any, Protocol, cast, runtime_checkable try: from opentelemetry import trace as otel_trace @@ -40,7 +41,7 @@ from fireflyframework_agentic.pipeline.audit import AuditEntry, AuditLog, AuditStatus from fireflyframework_agentic.pipeline.checkpoint import Checkpointer, CheckpointRecord from fireflyframework_agentic.pipeline.context import PipelineContext -from fireflyframework_agentic.pipeline.dag import DAG, FailureStrategy +from fireflyframework_agentic.pipeline.dag import DAG, FailureStrategy, _mermaid_id from fireflyframework_agentic.pipeline.reducers import Reducer, apply_update, discover_reducers from fireflyframework_agentic.pipeline.result import ( ExecutionTraceEntry, @@ -53,8 +54,7 @@ @runtime_checkable class EventHandler(Protocol): - """Unified pipeline event handler. Used by :class:`PipelineEngine` and - :class:`fireflyframework_agentic.pipeline.state_pipeline.StatePipeline`. + """Pipeline event handler used by :class:`PipelineEngine`. Implement any subset of these methods; missing ones are no-ops. Exceptions raised in callbacks are swallowed by the engine so observability never @@ -63,9 +63,8 @@ class EventHandler(Protocol): The engine dispatches events by parameter name. If your method signature omits a parameter — e.g. legacy implementations that don't accept ``run_id`` or ``visit`` — the engine simply drops it from the call. - That keeps legacy :class:`PipelineEventHandler` / - :class:`StatePipelineEventHandler` implementations working during the - transition to this unified shape. + That keeps the legacy :class:`PipelineEventHandler` shape working + transparently alongside this unified one. Parameter conventions: @@ -113,22 +112,11 @@ async def on_node_skip(self, node_id: str, pipeline_name: str, reason: str) -> N async def on_pipeline_complete(self, pipeline_name: str, success: bool, duration_ms: float) -> None: ... -@runtime_checkable -class StatePipelineEventHandler(Protocol): - """Legacy state-pipeline event handler protocol. Use :class:`EventHandler`. - - Same shape as :class:`EventHandler` minus ``on_node_skip``. Kept for - backward compatibility; new code should implement :class:`EventHandler`. - """ - - async def on_pipeline_start(self, pipeline_name: str, run_id: str) -> None: ... - async def on_node_start(self, pipeline_name: str, run_id: str, node_id: str, visit: int) -> None: ... - async def on_node_complete(self, pipeline_name: str, run_id: str, node_id: str, latency_ms: float) -> None: ... - async def on_node_error(self, pipeline_name: str, run_id: str, node_id: str, error: str) -> None: ... - async def on_node_pause(self, pipeline_name: str, run_id: str, node_id: str, reason: str) -> None: ... - async def on_pipeline_complete( - self, pipeline_name: str, run_id: str, success: bool, duration_ms: float - ) -> None: ... +RouterFn = Callable[[Any], "str | Send | list[Send]"] +"""Signature for a runtime branch router: receives the current state, returns +the next-step instruction — either a target node id, a single Send, or a +list of Sends for fan-out. +""" @dataclass @@ -222,7 +210,7 @@ def start_otel_span(name: str, **attributes: Any) -> Any: """Start an OTel span if observability is enabled, else return ``None``. Module-level helper shared by :class:`PipelineEngine` and - :class:`fireflyframework_agentic.pipeline.state_pipeline.StatePipeline`. + :class:`PipelineEngine`. """ try: if not get_config().observability_enabled: @@ -254,6 +242,8 @@ def __init__( audit_log: AuditLog | None = None, state_schema: type[BaseModel] | None = None, recursion_limit: int = 25, + routers: dict[str, RouterFn] | None = None, + router_mappings: dict[str, dict[str, str]] | None = None, ) -> None: self._dag = dag self._event_handler = event_handler @@ -266,9 +256,62 @@ def __init__( self._reducers: dict[str, Reducer] = discover_reducers(state_schema) if state_schema is not None else {} # Max visits per node for cycle-aware runs. Matches StatePipeline's default. self._recursion_limit = recursion_limit + # Optional runtime routers: source_id -> router(state) -> str | Send | list[Send]. + # When a source node has a router, the cyclic scheduler consults it + # instead of (or in addition to) the source's outgoing edges. With + # an accompanying mapping the router returns an abstract label that + # is looked up in the mapping. + self._routers: dict[str, RouterFn] = dict(routers or {}) + self._router_mappings: dict[str, dict[str, str]] = dict(router_mappings or {}) # Per-method signature cache for legacy-vs-unified dispatch. self._handler_params: dict[str, set[str]] = {} + def to_mermaid(self) -> str: + """Render the pipeline as a Mermaid flowchart, labelling branch edges. + + When the builder called ``.branch(source, router, mapping={...})`` + the resulting edges carry abstract labels (``yes``/``no``/etc). + This view threads those labels back into the diagram so the routing + is visible alongside the topology. + """ + lines = ["flowchart TD"] + for node_id in self._dag.nodes: + lines.append(f" {_mermaid_id(node_id)}[{node_id}]") + for edge in self._dag.edges: + label: str | None = None + mapping = self._router_mappings.get(edge.source) + if mapping: + for lbl, tgt in mapping.items(): + if tgt == edge.target: + label = lbl + break + if label is None and edge.condition is not None: + label = "if?" + arrow = f"-->|{label}|" if label else "-->" + lines.append(f" {_mermaid_id(edge.source)} {arrow} {_mermaid_id(edge.target)}") + return "\n".join(lines) + + async def invoke( + self, + state: Any = None, + *, + run_id: str | None = None, + start_at: Any = None, + approve_pause: bool = False, + ) -> PipelineResult: + """Shorthand for state-aware runs: ``await pipeline.invoke(state)``. + + Mirrors the legacy ``StatePipeline.invoke`` signature so callers that + treat the first positional as the state object keep working. New code + should call :meth:`run` directly with explicit kwargs. + """ + return await self.run( + state=state, + run_id=run_id, + start_at=start_at, + approve_pause=approve_pause, + ) + async def _dispatch(self, method_name: str, /, **kwargs: Any) -> None: """Invoke ``event_handler.method_name`` with the subset of ``kwargs`` the method's signature actually declares. @@ -333,12 +376,24 @@ async def run( """ if run_id is not None and context is None and inputs is None and state is None: resume_run_id: str = run_id - context, pre_completed, sequence_start = self._load_for_resume(resume_run_id, approve_pause=approve_pause) - all_results: dict[str, NodeResult] = { - nid: nr - for nid in pre_completed - if (nr := context.get_node_result(nid)) is not None and isinstance(nr, NodeResult) - } + context, pre_completed_list, sequence_start = self._load_for_resume( + resume_run_id, approve_pause=approve_pause + ) + pre_completed = set(pre_completed_list) + # Preserve original completion order so PipelineResult.completed_nodes + # reflects the run's actual sequence after resume. + all_results: dict[str, NodeResult] = {} + for nid in pre_completed_list: + nr = context.get_node_result(nid) + if isinstance(nr, NodeResult): + all_results[nid] = nr + # Synthesize trace entries for the pre-completed nodes so the + # resumed result's completed_nodes reflects the full history. + now = datetime.now(UTC) + resume_trace_seed: list[ExecutionTraceEntry] = [ + ExecutionTraceEntry(node_id=nid, started_at=now, completed_at=now, status="success") + for nid in pre_completed_list + ] else: if context is None: context = PipelineContext(inputs=inputs) @@ -358,6 +413,7 @@ async def run( pre_completed = set() sequence_start = 0 all_results = {} + resume_trace_seed: list[ExecutionTraceEntry] = [] # Mid-pipeline start: pretend everything not reachable from # `start_at` already ran. The scheduler then starts at start_at # because its upstream nodes appear "completed". @@ -372,17 +428,23 @@ async def run( run_id = uuid.uuid4().hex[:12] # Observability: pipeline-level span + start event + # State-aware runs use the "pipeline.state.*" span prefix to match + # the legacy StatePipeline taxonomy that observability dashboards + # already key on. + span_prefix = "pipeline.state" if self._state_schema is not None else "pipeline" _pipeline_span = self._start_otel_span( - f"pipeline.{self._dag.name}", + f"{span_prefix}.{self._dag.name}", pipeline=self._dag.name, + run_id=run_id, ) await self._dispatch("on_pipeline_start", pipeline_name=self._dag.name, run_id=run_id) # Cycle-aware mode: a separate sequential frontier-following scheduler # that respects ``recursion_limit``. The topological scheduler below # cannot run cyclic graphs because execution_levels()/topological_sort() - # are undefined on them. - if self._dag.is_cyclic(): + # are undefined on them. Runtime routers also force this mode because + # they make routing a function of state, not topology. + if self._dag.is_cyclic() or self._routers: return await self._run_cyclic( context=context, run_id=run_id, @@ -390,13 +452,14 @@ async def run( pre_completed=pre_completed, sequence_start=sequence_start, pipeline_span=_pipeline_span, + resume_trace_seed=resume_trace_seed, ) # Topological levels ensure that all upstream dependencies of a node # complete before the node itself executes. Nodes within the same # level are independent and run concurrently via asyncio.gather. levels = self._dag.execution_levels() - trace_entries: list[ExecutionTraceEntry] = [] + trace_entries: list[ExecutionTraceEntry] = list(resume_trace_seed) pipeline_start = time.perf_counter() failed_nodes: set[str] = set() @@ -489,6 +552,15 @@ async def _record_skip(nid: str) -> None: # them for the audit snapshot. gathered = self._gather_inputs(nid, context) inputs_by_node[nid] = gathered + # Emit start event here (visit=1 in the acyclic scheduler; + # the cyclic scheduler and Send fan-out emit their own). + await self._dispatch( + "on_node_start", + pipeline_name=self._dag.name, + run_id=run_id, + node_id=nid, + visit=1, + ) task = asyncio.create_task( self._execute_node( nid, @@ -535,6 +607,7 @@ async def _record_skip(nid: str) -> None: # Persist lifecycle: audit every executed visit; checkpoint only # successful completions (failed nodes must re-run on resume). sequence += 1 + paused_now = nr.success and isinstance(nr.output, Pause) self._record_audit( run_id=run_id, node_id=node_id, @@ -542,10 +615,12 @@ async def _record_skip(nid: str) -> None: nr=nr, inputs_snapshot=inputs_by_node.get(node_id, {}), trace_entries=trace_entries, + status_override="paused" if paused_now else None, + pause_reason=nr.output.reason if paused_now else None, ) # HITL: a node returned Pause(reason=...). Halt cleanly, save # a paused checkpoint, and surface the pause in the result. - if nr.success and isinstance(nr.output, Pause): + if paused_now: pause_reason = nr.output.reason await self._dispatch( "on_node_pause", @@ -569,7 +644,10 @@ async def _record_skip(nid: str) -> None: # Runtime fan-out: a node returned Send / list[Send]. if nr.success and _is_send_payload(nr.output): - sends = nr.output if isinstance(nr.output, list) else [nr.output] + sends: list[Send] = cast( + "list[Send]", + list(nr.output) if isinstance(nr.output, list) else [nr.output], + ) ok = await self._run_sends( sends=sends, context=context, @@ -680,6 +758,7 @@ async def _run_sends( trace_entries: list[ExecutionTraceEntry], completed: set[str], pending: set[str], + visit_counts: dict[str, int] | None = None, ) -> bool: """Dispatch a list of :class:`Send` workers concurrently. @@ -702,13 +781,13 @@ async def _run_sends( all_results[send.target] = nr return False - async def _run_one(send: Send) -> tuple[Send, NodeResult]: + async def _run_one(send: Send, visit_n: int) -> tuple[Send, NodeResult]: await self._dispatch( "on_node_start", pipeline_name=self._dag.name, run_id=run_id, node_id=send.target, - visit=1, + visit=visit_n, ) # Per-worker context: own state copy with payload applied so # workers don't race on the shared state object. @@ -727,8 +806,19 @@ async def _run_one(send: Send) -> tuple[Send, NodeResult]: ) return send, nr + # Per-Send visit numbers: increment per dispatched target. The + # caller may seed counts (cyclic scheduler tracks them globally); + # otherwise each fan-out batch starts at 1. + send_visits: list[int] = [] + running_counts = dict(visit_counts) if visit_counts else {} + for send in sends: + running_counts[send.target] = running_counts.get(send.target, 0) + 1 + send_visits.append(running_counts[send.target]) + if visit_counts is not None: + visit_counts.update(running_counts) + try: - results = await asyncio.gather(*(_run_one(s) for s in sends)) + results = await asyncio.gather(*(_run_one(s, v) for s, v in zip(sends, send_visits, strict=True))) except Exception as exc: logger.exception("Fan-out worker crashed") for send in sends: @@ -759,6 +849,7 @@ async def _run_cyclic( pre_completed: set[str], sequence_start: int, pipeline_span: Any, + resume_trace_seed: list[ExecutionTraceEntry] | None = None, ) -> PipelineResult: """Sequential frontier-following scheduler for cyclic DAGs. @@ -768,7 +859,7 @@ async def _run_cyclic( having multiple alive outgoing edges is currently an error — parallel cyclic fan-out is the job of :class:`Send` in a later layer. """ - trace_entries: list[ExecutionTraceEntry] = [] + trace_entries: list[ExecutionTraceEntry] = list(resume_trace_seed or []) pipeline_start = time.perf_counter() visit_counts: dict[str, int] = dict.fromkeys(pre_completed, 1) sequence = sequence_start @@ -777,14 +868,53 @@ async def _run_cyclic( nodes_in_order = list(self._dag.nodes) if not nodes_in_order: raise PipelineError("Pipeline has no nodes") - current: str | None = nodes_in_order[0] + next_step: str | list[Send] | None = nodes_in_order[0] # Skip past anything already completed during this resumed run. - while current is not None and current in pre_completed: - current = self._cyclic_next(current, context) + while isinstance(next_step, str) and next_step in pre_completed: + next_step = self._cyclic_next(next_step, context) - result: PipelineResult | None = None + pending_pause: tuple[str, str] | None = None try: - while current is not None: + while next_step is not None: + # --- Fan-out (list[Send]) --------------------------------- + if isinstance(next_step, list): + sends = next_step + # Preview the per-target visit numbers to enforce the + # recursion limit; _run_sends does the real increment. + over_limit: str | None = None + preview = dict(visit_counts) + for send in sends: + preview[send.target] = preview.get(send.target, 0) + 1 + if preview[send.target] > self._recursion_limit: + over_limit = send.target + break + if over_limit is not None: + msg = ( + f"Recursion limit ({self._recursion_limit}) exceeded at node '{over_limit}' during fan-out." + ) + logger.error(msg) + all_results[over_limit] = NodeResult(node_id=over_limit, success=False, error=msg) + break + completed_set: set[str] = set(all_results) + pending_set: set[str] = set() + ok = await self._run_sends( + sends=sends, + context=context, + run_id=run_id, + all_results=all_results, + trace_entries=trace_entries, + completed=completed_set, + pending=pending_set, + visit_counts=visit_counts, + ) + if not ok: + break + # Continue from the common successor of all workers, if any. + next_step = self._common_successor([s.target for s in sends]) + continue + + # --- Single-node step ------------------------------------- + current = next_step visit_counts[current] = visit_counts.get(current, 0) + 1 visit_n = visit_counts[current] if visit_n > self._recursion_limit: @@ -793,16 +923,17 @@ async def _run_cyclic( f"'{current}'. Raise recursion_limit= or fix the routing logic." ) logger.error(msg) - nr = NodeResult(node_id=current, success=False, error=msg) - all_results[current] = nr + nr_over = NodeResult(node_id=current, success=False, error=msg) + all_results[current] = nr_over sequence += 1 self._record_audit( run_id=run_id, node_id=current, sequence=sequence, - nr=nr, + nr=nr_over, inputs_snapshot={}, trace_entries=trace_entries, + visit=visit_n, ) break @@ -827,6 +958,7 @@ async def _run_cyclic( await self._emit_node_result(nr, run_id) sequence += 1 + paused_now = nr.success and isinstance(nr.output, Pause) self._record_audit( run_id=run_id, node_id=current, @@ -835,11 +967,43 @@ async def _run_cyclic( inputs_snapshot=gathered, trace_entries=trace_entries, visit=visit_n, + status_override="paused" if paused_now else None, + pause_reason=nr.output.reason if paused_now else None, ) if not nr.success and not nr.skipped: break + # HITL: node returned Pause — checkpoint paused and halt. + if paused_now: + pause_reason = nr.output.reason + await self._dispatch( + "on_node_pause", + pipeline_name=self._dag.name, + run_id=run_id, + node_id=current, + reason=pause_reason, + ) + self._save_checkpoint( + run_id=run_id, + node_id=current, + sequence=sequence, + context=context, + all_results=all_results, + paused=True, + pause_reason=pause_reason, + ) + pending_pause = (current, pause_reason) + break + + # Fan-out: node returned Send / list[Send]. + if nr.success and _is_send_payload(nr.output): + next_step = cast( + "list[Send]", + list(nr.output) if isinstance(nr.output, list) else [nr.output], + ) + continue + if not nr.skipped: if self._state_schema is not None and context.state is not None and isinstance(nr.output, dict): context.state = apply_update(context.state, nr.output, self._reducers) @@ -851,12 +1015,14 @@ async def _run_cyclic( all_results=all_results, ) - current = self._cyclic_next(current, context) + try: + next_step = self._cyclic_next(current, context) + except PipelineError as exc: + all_results[current] = NodeResult(node_id=current, success=False, error=str(exc), output=nr.output) + break finally: elapsed = (time.perf_counter() - pipeline_start) * 1000 - success = ( - result.success if result is not None else all(r.success or r.skipped for r in all_results.values()) - ) + success = False if pending_pause is not None else all(r.success or r.skipped for r in all_results.values()) await self._dispatch( "on_pipeline_complete", pipeline_name=self._dag.name, @@ -868,26 +1034,43 @@ async def _run_cyclic( with contextlib.suppress(Exception): pipeline_span.end() - final_output = None + paused_node = pending_pause[0] if pending_pause else None + pause_reason_final = pending_pause[1] if pending_pause else None return PipelineResult( pipeline_name=self._dag.name, outputs=all_results, - final_output=final_output, + final_output=None, execution_trace=trace_entries, total_duration_ms=elapsed, success=success, usage=None, run_id=run_id, final_state=context.state, + paused=pending_pause is not None, + paused_node=paused_node, + pause_reason=pause_reason_final, ) - def _cyclic_next(self, current: str, context: PipelineContext) -> str | None: - """Pick the next node by following the unique alive outgoing edge. + def _common_successor(self, node_ids: list[str]) -> str | None: + """Return the node all ``node_ids`` share as their unique successor, or None.""" + successors = [self._dag.successors(nid) for nid in node_ids] + if not successors or any(len(s) != 1 for s in successors): + return None + first = successors[0][0] + return first if all(s[0] == first for s in successors[1:]) else None + + def _cyclic_next(self, current: str, context: PipelineContext) -> str | list[Send] | None: + """Pick what runs next from ``current``. - Returns ``None`` when no outgoing edge is alive (terminus). Raises - if more than one is alive — parallel cyclic fan-out lands with - :class:`Send` in a later layer. + Priority: a registered router (.branch(...)) wins. Its return value + (str, Send, list[Send], or None) is resolved to a concrete target. + Otherwise fall back to the unique alive outgoing edge; multiple + alive edges with no router raise. """ + # Runtime router takes precedence. + if current in self._routers: + decision = self._routers[current](context.state) + return self._resolve_router_decision(current, decision) def _alive(edge: Any) -> bool: if edge.condition is None: @@ -904,12 +1087,57 @@ def _alive(edge: Any) -> bool: if len(alive) > 1: raise PipelineError( f"Cyclic node '{current}' has multiple alive outgoing edges " - f"({[e.target for e in alive]}). Parallel cyclic fan-out arrives " - f"with Send in a later layer; for now, use mutually exclusive " - f"edge conditions." + f"({[e.target for e in alive]}). Register a .branch(...) " + f"router or make the edge conditions mutually exclusive." ) return alive[0].target + def _resolve_router_decision(self, source: str, decision: Any) -> str | list[Send] | None: + """Translate a router's return value into a concrete next-step. + + Accepts: + * a string node id (looked up in ``router_mappings`` if registered), + * a single :class:`Send` (wrapped into a one-element list), + * a ``list[Send]`` (returned as-is after validation), + * ``None`` or an empty list (terminus). + """ + if decision is None: + return None + if isinstance(decision, list): + if not decision: + return None + for s in decision: + if not isinstance(s, Send): + raise PipelineError( + f"Router for '{source}' returned a list containing non-Send element {s!r}; expected list[Send]." + ) + if s.target not in self._dag.nodes: + raise PipelineError(f"Router for '{source}' fans out to unknown target '{s.target}'") + return decision + if isinstance(decision, Send): + if decision.target not in self._dag.nodes: + raise PipelineError(f"Router for '{source}' dispatched to unknown target '{decision.target}'") + return [decision] + if isinstance(decision, str): + mapping = self._router_mappings.get(source) + if mapping is not None: + if decision not in mapping: + raise PipelineError( + f"Router for '{source}' returned label '{decision}' not in mapping {list(mapping)}" + ) + return mapping[decision] + if decision not in self._dag.nodes: + raise PipelineError( + f"Router for '{source}' returned '{decision}' " + f"which is not a registered node id; pass an explicit " + f"mapping if you want abstract labels." + ) + return decision + raise PipelineError( + f"Router for '{source}' returned unsupported type {type(decision).__name__}; " + f"expected str, Send, list[Send], or None." + ) + async def _execute_node( self, node_id: str, @@ -947,20 +1175,17 @@ async def _execute_node( if inputs is None: inputs = self._gather_inputs(node_id, context) + node_prefix = "pipeline.state.node" if self._state_schema is not None else "pipeline.node" _node_span = self._start_otel_span( - f"pipeline.node.{node_id}", + f"{node_prefix}.{node_id}", node=node_id, - ) - - # Emit node start event (visit=1; cycles arrive in a later layer) - await self._dispatch( - "on_node_start", - pipeline_name=self._dag.name, - run_id=run_id, - node_id=node_id, visit=1, ) + # Note: on_node_start is now emitted by the caller (run / _run_cyclic / + # _run_sends) so the cyclic and fan-out paths can supply the right + # ``visit`` number. Emitting here would duplicate the event. + max_retries = node.retry_max backoff_factor = node.backoff_factor retries = 0 @@ -1067,7 +1292,7 @@ async def _emit_node_result(self, nr: NodeResult, run_id: str) -> None: else: await self._dispatch("on_node_error", error=nr.error or "unknown", **common) - def _load_for_resume(self, run_id: str, *, approve_pause: bool = False) -> tuple[PipelineContext, set[str], int]: + def _load_for_resume(self, run_id: str, *, approve_pause: bool = False) -> tuple[PipelineContext, list[str], int]: """Rebuild context + completed-set from the latest checkpoint. Resuming a paused run (checkpoint.paused=True) requires @@ -1097,7 +1322,7 @@ def _load_for_resume(self, run_id: str, *, approve_pause: bool = False) -> tuple context.state = self._state_schema.model_validate(saved_state) except Exception: logger.warning("Could not restore shared state on resume for run '%s'", run_id) - return context, set(record.completed_nodes), record.sequence + return context, list(record.completed_nodes), record.sequence def _save_checkpoint( self, @@ -1149,12 +1374,15 @@ def _record_audit( inputs_snapshot: dict[str, Any], trace_entries: list[ExecutionTraceEntry], visit: int = 1, + status_override: AuditStatus | None = None, + pause_reason: str | None = None, ) -> None: """Write an audit entry for a node visit. No-op if no audit log. Skipped nodes are not recorded — they represent work that did NOT - happen and would clutter the trail. ``visit`` defaults to 1 for the - acyclic scheduler and is supplied by the cyclic scheduler. + happen and would clutter the trail. ``status_override`` lets the + cyclic scheduler tag a Pause-returning node with ``"paused"`` + instead of the default ``"success"`` derived from ``nr.success``. """ if self._audit_log is None or nr.skipped: return @@ -1165,7 +1393,10 @@ def _record_audit( started_at = te.started_at completed_at = te.completed_at break - status: AuditStatus = "success" if nr.success else "error" + if status_override is not None: + status: AuditStatus = status_override + else: + status = "success" if nr.success else "error" outputs: dict[str, Any] = {"output": _serialize_value(nr.output)} if nr.success else {} entry = AuditEntry( pipeline_name=self._dag.name, @@ -1180,6 +1411,7 @@ def _record_audit( inputs_snapshot={k: _serialize_value(v) for k, v in inputs_snapshot.items()}, outputs_snapshot=outputs, error_message=nr.error if not nr.success else None, + pause_reason=pause_reason, ) try: self._audit_log.record(entry) diff --git a/fireflyframework_agentic/pipeline/result.py b/fireflyframework_agentic/pipeline/result.py index 1f6bc260..ade1ccf8 100644 --- a/fireflyframework_agentic/pipeline/result.py +++ b/fireflyframework_agentic/pipeline/result.py @@ -91,3 +91,36 @@ class PipelineResult(BaseModel): @property def failed_nodes(self) -> list[str]: return [nid for nid, r in self.outputs.items() if not r.success and not r.skipped] + + # -- State-mode convenience aliases --------------------------------- + + @property + def state(self) -> Any: + """Final shared state. Alias of :attr:`final_state` for state-aware + pipelines built via ``PipelineBuilder(state=...)``.""" + return self.final_state + + @property + def completed_nodes(self) -> list[str]: + """IDs of every successful node visit, in completion order. + + Derived from :attr:`execution_trace` so each cyclic re-entry of a + node appears as its own entry (matches StatePipeline's semantics). + """ + return [e.node_id for e in self.execution_trace if e.status == "success"] + + @property + def failed_node(self) -> str | None: + """First node that failed, if any. ``None`` when the run succeeded.""" + for nid, r in self.outputs.items(): + if not r.success and not r.skipped: + return nid + return None + + @property + def error(self) -> str | None: + """Error message from the first failed node, if any.""" + for r in self.outputs.values(): + if not r.success and not r.skipped and r.error: + return r.error + return None diff --git a/fireflyframework_agentic/pipeline/state_pipeline.py b/fireflyframework_agentic/pipeline/state_pipeline.py deleted file mode 100644 index 07a3b1ea..00000000 --- a/fireflyframework_agentic/pipeline/state_pipeline.py +++ /dev/null @@ -1,751 +0,0 @@ -# Copyright 2026 Firefly Software Foundation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""State-based pipeline: a sequential executor over a typed shared-state object. - -Layered on top of :class:`DAG` for topology, but uses its own simple executor -rather than :class:`PipelineEngine`. The trade-off: no within-level parallelism, -but in exchange we get clean semantics for typed state, reducers, branching, -checkpointing, and mid-pipeline resume — which are the things this API exists -to provide. Port-based parallel DAGs continue to use :class:`PipelineEngine`. -""" - -from __future__ import annotations - -import asyncio -import contextlib -import inspect -import logging -import time -import uuid -import warnings -from collections.abc import Awaitable, Callable -from dataclasses import dataclass -from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any - -from pydantic import BaseModel - -from fireflyframework_agentic.exceptions import PipelineError -from fireflyframework_agentic.pipeline.audit import AuditEntry, AuditLog, AuditStatus -from fireflyframework_agentic.pipeline.checkpoint import Checkpointer, CheckpointRecord -from fireflyframework_agentic.pipeline.dag import DAG, _mermaid_id -from fireflyframework_agentic.pipeline.engine import Pause, Send, start_otel_span -from fireflyframework_agentic.pipeline.reducers import apply_update, discover_reducers - -if TYPE_CHECKING: - from fireflyframework_agentic.pipeline.engine import StatePipelineEventHandler - -logger = logging.getLogger(__name__) - -StateNodeFn = Callable[[Any], Awaitable[dict[str, Any] | None]] -# A router may return: a node id (str), a Send, or a list[Send] for fan-out. -RouterFn = Callable[[Any], "str | Send | list[Send]"] - - -class RecursionLimitError(Exception): - """Raised when a node is visited more times than ``recursion_limit`` permits.""" - - -@dataclass -class BranchSpec: - """Internal: registered branch from one source node.""" - - source: str - router: RouterFn - mapping: dict[str, str] | None # label -> target node_id. None = router returns target directly. - - -@dataclass -class StatePipelineResult: - """Outcome of a single ``invoke`` call. - - Attributes: - state: Final state object. - run_id: ID of this run (use to resume later via ``invoke(run_id=...)``). - completed_nodes: Node IDs that ran successfully this invocation, in order. - success: True iff all attempted nodes completed without error. - error: Last error message if ``success`` is False. - failed_node: Node ID that failed, if any. - paused: True if the run halted on a :class:`Pause` sentinel; resume - via ``invoke(run_id=..., approve_pause=True)``. - paused_node: Node that returned ``Pause`` if ``paused`` is True. - pause_reason: Reason string the paused node passed to ``Pause(...)``. - """ - - state: Any - run_id: str - completed_nodes: list[str] - success: bool - error: str | None = None - failed_node: str | None = None - paused: bool = False - paused_node: str | None = None - pause_reason: str | None = None - - -class StatePipeline: - """Compiled state-based pipeline. Returned by ``PipelineBuilder.build()`` - when a ``state=`` schema is configured. - - .. deprecated:: - :class:`StatePipeline` is being subsumed by - :class:`fireflyframework_agentic.pipeline.engine.PipelineEngine` - configured with ``state_schema=``. The unified engine supports - the same features (state overlay, reducers, Pause, Send, cycles, - recursion_limit, checkpointing, audit, resume, start_at) and adds - true parallelism for state-aware pipelines via the topological - scheduler. New code should prefer ``PipelineEngine`` directly: - - .. code-block:: python - - engine = PipelineEngine( - dag, - state_schema=MyState, - checkpointer=cp, - audit_log=al, - recursion_limit=10, - ) - result = await engine.run(state=MyState(...)) - - See issue #245 for the full migration plan. The next layer of - unification removes :class:`StatePipeline` after a deprecation - cycle. - """ - - def __init__( - self, - *, - name: str, - dag: DAG, - state_schema: type[BaseModel], - node_fns: dict[str, StateNodeFn], - branches: dict[str, BranchSpec], - checkpointer: Checkpointer | None = None, - recursion_limit: int = 25, - event_handler: StatePipelineEventHandler | None = None, - audit_log: AuditLog | None = None, - ) -> None: - warnings.warn( - "StatePipeline is deprecated; use PipelineEngine(state_schema=...) " - "for the unified API. The unified engine supports the same features " - "and adds parallel state-aware execution. See issue #245.", - DeprecationWarning, - stacklevel=2, - ) - self._name = name - self._dag = dag - self._state_schema = state_schema - self._node_fns = node_fns - self._branches = branches - self._checkpointer = checkpointer - self._recursion_limit = recursion_limit - self._event_handler = event_handler - self._audit_log = audit_log - self._reducers = discover_reducers(state_schema) - self._validate() - - def _audit( - self, - *, - run_id: str, - node_id: str, - sequence: int, - visit: int, - started_at: datetime, - completed_at: datetime, - latency_ms: float, - status: AuditStatus, - inputs_snapshot: dict[str, Any], - outputs_snapshot: dict[str, Any], - error_message: str | None = None, - pause_reason: str | None = None, - ) -> None: - """Construct and write an :class:`AuditEntry`. No-op if no audit log is configured. - - Audit-write failures are non-fatal — logged and swallowed. - """ - if self._audit_log is None: - return - entry = AuditEntry( - pipeline_name=self._name, - run_id=run_id, - node_id=node_id, - sequence=sequence, - visit=visit, - started_at=started_at, - completed_at=completed_at, - latency_ms=latency_ms, - status=status, - inputs_snapshot=inputs_snapshot, - outputs_snapshot=outputs_snapshot, - error_message=error_message, - pause_reason=pause_reason, - ) - try: - self._audit_log.record(entry) - except Exception: - logger.exception("Audit log write failed for run '%s' at '%s'", run_id, node_id) - - async def _emit(self, method: str, *args: Any) -> None: - """Invoke ``method`` on the configured event handler if it exists. - - Missing methods are no-ops; raised exceptions are swallowed so - observability never breaks business logic. - """ - if self._event_handler is None: - return - fn = getattr(self._event_handler, method, None) - if fn is None: - return - with contextlib.suppress(Exception): - await fn(*args) - - @property - def name(self) -> str: - return self._name - - @property - def dag(self) -> DAG: - return self._dag - - def to_mermaid(self) -> str: - """Render the pipeline as a Mermaid flowchart, including branch edges. - - Branches that omit an explicit mapping are rendered as a dashed edge - labelled ``router`` because the targets are decided at runtime. - """ - lines = ["flowchart TD"] - for node_id in self._dag.nodes: - lines.append(f" {_mermaid_id(node_id)}[{node_id}]") - # Explicit edges (including branch mappings, which were materialized). - for edge in self._dag.edges: - label = None - spec = self._branches.get(edge.source) - if spec and spec.mapping: - for lbl, tgt in spec.mapping.items(): - if tgt == edge.target: - label = lbl - break - arrow = f"-->|{label}|" if label else "-->" - lines.append(f" {_mermaid_id(edge.source)} {arrow} {_mermaid_id(edge.target)}") - # Dynamic branches (no mapping): show as a dashed self-edge stub. - for source, spec in self._branches.items(): - if spec.mapping is None and not self._dag.successors(source): - lines.append(f" {_mermaid_id(source)} -.->|router| {_mermaid_id(source)}_router((dynamic))") - return "\n".join(lines) - - def _validate(self) -> None: - # Every node must have a registered fn. - for node_id in self._dag.nodes: - if node_id not in self._node_fns: - raise PipelineError(f"Node '{node_id}' has no registered function") - # Every branch source/target must exist. - for source, spec in self._branches.items(): - if source not in self._dag.nodes: - raise PipelineError(f"Branch source '{source}' not in DAG") - if spec.mapping: - for label, target in spec.mapping.items(): - if target not in self._dag.nodes: - raise PipelineError(f"Branch target '{target}' (label '{label}') not in DAG") - - def _entry_node(self) -> str: - """Default entry: the first node added. - - Override with ``invoke(state, start_at=...)``. Picking insertion-order - rather than the topological root keeps things predictable in the - common case where a ``.branch(...)`` without an explicit mapping - leaves multiple nodes with no inbound edges. - """ - order = list(self._dag.nodes) - if not order: - raise PipelineError("Pipeline has no nodes") - return order[0] - - def _next_step(self, current: str, state: BaseModel) -> str | list[Send] | None: - """Decide what runs next given the current state. - - Returns: - * A node id (str) for a single deterministic step. - * A list of :class:`Send` for runtime fan-out — workers run concurrently. - * ``None`` when the pipeline reaches a terminus. - """ - if current in self._branches: - decision = self._branches[current].router(state) - return self._resolve_router_decision(current, decision) - - successors = self._dag.successors(current) - if not successors: - return None - if len(successors) > 1: - raise PipelineError( - f"Node '{current}' has multiple successors {successors} but no .branch(...) registered. " - f"Register a branch router or remove the extra edges." - ) - return successors[0] - - def _resolve_router_decision(self, current: str, decision: str | Send | list[Send]) -> str | list[Send] | None: - """Translate a router's return value into a concrete next-step instruction.""" - # Fan-out: list of Send dispatches. - if isinstance(decision, list): - if not decision: - return None - for s in decision: - if not isinstance(s, Send): - raise PipelineError( - f"Router for '{current}' returned a list containing non-Send " - f"element {s!r}; expected list[Send]." - ) - if s.target not in self._dag.nodes: - raise PipelineError(f"Router for '{current}' fans out to unknown target '{s.target}'") - return decision - - if isinstance(decision, Send): - if decision.target not in self._dag.nodes: - raise PipelineError(f"Router for '{current}' dispatched to unknown target '{decision.target}'") - return [decision] - - # String label. - spec = self._branches[current] - if spec.mapping is not None: - if decision not in spec.mapping: - raise PipelineError( - f"Router for '{current}' returned label '{decision}' not in mapping {list(spec.mapping)}" - ) - return spec.mapping[decision] - if decision not in self._dag.nodes: - raise PipelineError( - f"Router for '{current}' returned '{decision}' " - f"which is not a registered node id; pass an explicit mapping if you want labels." - ) - return decision - - def _common_successor(self, node_ids: list[str]) -> str | None: - """Return the node all ``node_ids`` share as their unique successor, or None.""" - successors = [self._dag.successors(nid) for nid in node_ids] - if not successors or any(len(s) != 1 for s in successors): - return None - first = successors[0][0] - return first if all(s[0] == first for s in successors[1:]) else None - - async def invoke( - self, - state: BaseModel | None = None, - *, - run_id: str | None = None, - start_at: str | Callable[..., Any] | None = None, - approve_pause: bool = False, - ) -> StatePipelineResult: - """Run the pipeline. - - Modes: - * Fresh run: ``invoke(state)`` — generates a new ``run_id``. - * Resume: ``invoke(run_id="abc")`` — loads latest checkpoint and continues. - * Mid-pipeline start: ``invoke(state=..., start_at=node)`` — - starts execution at ``node`` with the provided state. - """ - completed: list[str] = [] - - # Resume mode: load checkpoint, derive starting node from it. - if run_id is not None and state is None and start_at is None: - if self._checkpointer is None: - raise PipelineError("Cannot resume: pipeline has no checkpointer") - record = self._checkpointer.load_latest(self._name, run_id) - if record is None: - raise PipelineError(f"No checkpoint found for run_id='{run_id}'") - # A paused run requires explicit approval before continuing. - if record.paused and not approve_pause: - raise PipelineError( - f"Run '{run_id}' is paused at node '{record.node_id}' " - f"(reason: {record.pause_reason!r}). Pass approve_pause=True to resume." - ) - state = self._state_schema.model_validate(record.state) - completed = list(record.completed_nodes) - # Resume at the successor of the last completed (or paused) node. - next_node = self._next_step(record.node_id, state) - # Resume can't seamlessly continue mid-fan-out yet; treat fan-out as terminal here. - if isinstance(next_node, list): - raise PipelineError( - "Resume across a fan-out (Send) is not supported in Phase 2; " - "the run finished by reaching a fan-out node." - ) - if next_node is None: - return StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=True, - ) - current_node: str | None = next_node - else: - if state is None: - raise PipelineError("invoke() requires a state argument (or a run_id to resume)") - if not isinstance(state, self._state_schema): - # Be helpful if caller passed a dict or a different model. - try: - state = self._state_schema.model_validate(state) - except Exception as exc: - raise PipelineError(f"state argument is not a {self._state_schema.__name__}: {exc}") from exc - if start_at is not None: - current_node = _resolve_node_id(start_at) - if current_node not in self._dag.nodes: - raise PipelineError(f"start_at='{current_node}' not in DAG") - else: - current_node = self._entry_node() - - if run_id is None: - run_id = uuid.uuid4().hex[:12] - - assert state is not None # narrowed by the branches above - visit_counts: dict[str, int] = {} - - next_step: str | list[Send] | None = current_node - - # ---- pipeline-level observability boundary ------------------------- - pipeline_start_time = time.perf_counter() - pipeline_span = start_otel_span(f"pipeline.state.{self._name}", pipeline=self._name, run_id=run_id) - await self._emit("on_pipeline_start", self._name, run_id) - - result: StatePipelineResult | None = None - try: - while next_step is not None: - # --- fan-out branch (list[Send]) --------------------------------- - if isinstance(next_step, list): - try: - state = await self._run_fanout( - sends=next_step, - state=state, - completed=completed, - run_id=run_id, - visit_counts=visit_counts, - ) - except _NodeFailureError as fail: - result = StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=False, - error=fail.message, - failed_node=fail.node_id, - ) - break - # After fan-out, continue from the workers' shared successor (if any). - next_step = self._common_successor([s.target for s in next_step]) - continue - - # --- single-node step -------------------------------------------- - node_id = next_step - visit_counts[node_id] = visit_counts.get(node_id, 0) + 1 - visit_n = visit_counts[node_id] - if visit_n > self._recursion_limit: - msg = ( - f"Recursion limit ({self._recursion_limit}) exceeded at node '{node_id}'. " - f"Raise recursion_limit= or fix the routing logic." - ) - logger.error(msg) - result = StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=False, - error=msg, - failed_node=node_id, - ) - break - - fn = self._node_fns[node_id] - node_span = start_otel_span(f"pipeline.state.node.{node_id}", node=node_id, visit=visit_n) - await self._emit("on_node_start", self._name, run_id, node_id, visit_n) - inputs_snapshot = state.model_dump(mode="json") - started_at = datetime.now(UTC) - t0 = time.perf_counter() - try: - update = await fn(state) - except Exception as exc: - logger.exception( - "State pipeline '%s' run '%s' failed at node '%s'", - self._name, - run_id, - node_id, - ) - await self._emit("on_node_error", self._name, run_id, node_id, str(exc)) - if node_span is not None: - with contextlib.suppress(Exception): - node_span.end() - self._audit( - run_id=run_id, - node_id=node_id, - sequence=len(completed) + 1, - visit=visit_n, - started_at=started_at, - completed_at=datetime.now(UTC), - latency_ms=(time.perf_counter() - t0) * 1000, - status="error", - inputs_snapshot=inputs_snapshot, - outputs_snapshot={}, - error_message=str(exc), - ) - result = StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=False, - error=str(exc), - failed_node=node_id, - ) - break - elapsed = (time.perf_counter() - t0) * 1000 - completed_at = datetime.now(UTC) - if node_span is not None: - with contextlib.suppress(Exception): - node_span.end() - - # HITL: a node returning Pause halts the pipeline and writes a - # paused checkpoint. Approval comes via invoke(approve_pause=True). - if isinstance(update, Pause): - pause_reason = update.reason - await self._emit("on_node_pause", self._name, run_id, node_id, pause_reason) - completed.append(node_id) - self._save_checkpoint( - run_id, - node_id, - len(completed), - state, - completed, - paused=True, - pause_reason=pause_reason, - ) - self._audit( - run_id=run_id, - node_id=node_id, - sequence=len(completed), - visit=visit_n, - started_at=started_at, - completed_at=completed_at, - latency_ms=elapsed, - status="paused", - inputs_snapshot=inputs_snapshot, - outputs_snapshot=state.model_dump(mode="json"), - pause_reason=pause_reason, - ) - logger.info("Pipeline '%s' paused at node '%s': %s", self._name, node_id, pause_reason) - result = StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=False, - paused=True, - paused_node=node_id, - pause_reason=pause_reason, - ) - break - - await self._emit("on_node_complete", self._name, run_id, node_id, elapsed) - logger.debug("Pipeline '%s' node '%s' completed in %.1fms", self._name, node_id, elapsed) - - if update: - state = apply_update(state, update, self._reducers) - - completed.append(node_id) - self._save_checkpoint(run_id, node_id, len(completed), state, completed) - self._audit( - run_id=run_id, - node_id=node_id, - sequence=len(completed), - visit=visit_n, - started_at=started_at, - completed_at=completed_at, - latency_ms=elapsed, - status="success", - inputs_snapshot=inputs_snapshot, - outputs_snapshot=state.model_dump(mode="json"), - ) - - try: - next_step = self._next_step(node_id, state) - except PipelineError as exc: - result = StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=False, - error=str(exc), - failed_node=node_id, - ) - break - - if result is None: - result = StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=True, - ) - finally: - if pipeline_span is not None: - with contextlib.suppress(Exception): - pipeline_span.end() - duration_ms = (time.perf_counter() - pipeline_start_time) * 1000 - success = result.success if result is not None else False - await self._emit("on_pipeline_complete", self._name, run_id, success, duration_ms) - - assert result is not None # set in try-block before reaching here - return result - - async def _run_fanout( - self, - *, - sends: list[Send], - state: BaseModel, - completed: list[str], - run_id: str, - visit_counts: dict[str, int], - ) -> BaseModel: - """Run all ``Send`` dispatches concurrently. Each task gets its own state - copy with the Send's payload merged in; results are reduced into shared state. - """ - # Snapshot each Send's visit number BEFORE dispatch so the worker's - # closure captures its own visit, not the final post-increment value. - sends_with_visits: list[tuple[Send, int]] = [] - for send in sends: - visit_counts[send.target] = visit_counts.get(send.target, 0) + 1 - sends_with_visits.append((send, visit_counts[send.target])) - if visit_counts[send.target] > self._recursion_limit: - raise _NodeFailureError( - node_id=send.target, - message=( - f"Recursion limit ({self._recursion_limit}) exceeded at node '{send.target}' during fan-out." - ), - ) - - async def _run_one(send: Send, visit_n: int) -> tuple[Send, dict[str, Any] | None]: - await self._emit("on_node_start", self._name, run_id, send.target, visit_n) - node_span = start_otel_span(f"pipeline.state.node.{send.target}", node=send.target, visit=visit_n) - task_state = apply_update(state, send.payload, self._reducers) - fn = self._node_fns[send.target] - t0 = time.perf_counter() - try: - update = await fn(task_state) - except Exception as exc: - await self._emit("on_node_error", self._name, run_id, send.target, str(exc)) - if node_span is not None: - with contextlib.suppress(Exception): - node_span.end() - raise - elapsed = (time.perf_counter() - t0) * 1000 - if node_span is not None: - with contextlib.suppress(Exception): - node_span.end() - await self._emit("on_node_complete", self._name, run_id, send.target, elapsed) - return send, update - - try: - results = await asyncio.gather(*(_run_one(s, v) for s, v in sends_with_visits)) - except Exception as exc: - # Best-effort: report the first failing target as the failure point. - raise _NodeFailureError( - node_id=sends[0].target, - message=f"Fan-out failure: {exc}", - ) from exc - - new_state = state - for send, update in results: - if update: - new_state = apply_update(new_state, update, self._reducers) - completed.append(send.target) - self._save_checkpoint(run_id, send.target, len(completed), new_state, completed) - - return new_state - - def _save_checkpoint( - self, - run_id: str, - node_id: str, - sequence: int, - state: BaseModel, - completed: list[str], - *, - paused: bool = False, - pause_reason: str | None = None, - ) -> None: - """Persist state via the configured checkpointer (no-op if absent).""" - if self._checkpointer is None: - return - try: - self._checkpointer.save( - CheckpointRecord( - pipeline_name=self._name, - run_id=run_id, - node_id=node_id, - sequence=sequence, - state=state.model_dump(), - completed_nodes=list(completed), - paused=paused, - pause_reason=pause_reason, - ) - ) - except Exception: - logger.exception("Checkpoint save failed for run '%s' at '%s'", run_id, node_id) - - -class _NodeFailureError(Exception): - """Internal sentinel used to bubble fan-out failures out to the main loop.""" - - def __init__(self, node_id: str, message: str) -> None: - super().__init__(message) - self.node_id = node_id - self.message = message - - -def _resolve_node_id(ref: str | Callable[..., Any]) -> str: - """Turn either a string id or a function reference into a node id.""" - if isinstance(ref, str): - return ref - name = getattr(ref, "__name__", None) - if not name: - raise PipelineError(f"Cannot derive node id from {ref!r}") - return name - - -def coerce_state_node_fn(fn: Callable[..., Any]) -> StateNodeFn: - """Adapt a user-supplied callable into the ``async (state) -> dict | None`` shape. - - Accepted forms: - * ``async def f(state) -> dict | None`` — used as-is. - * ``def f(state) -> dict | None`` — wrapped to run in a thread. - * Object with ``async run(state)`` (e.g. a FireflyAgent-like) — adapter calls ``.run(state)``. - """ - if inspect.iscoroutinefunction(fn): - return fn # type: ignore[return-value] - - # Object with .run(state) — e.g. a FireflyAgent. Check before the generic - # callable branch so agent-shaped objects don't get treated as plain callables. - run = getattr(fn, "run", None) - if not callable(fn) and run is not None and callable(run): - - async def _agent_wrap(state: Any) -> Any: - if inspect.iscoroutinefunction(run): - return await run(state) - return await asyncio.get_running_loop().run_in_executor(None, run, state) - - return _agent_wrap - - if callable(fn): - - async def _async_wrap(state: Any) -> Any: - return await asyncio.get_running_loop().run_in_executor(None, fn, state) - - return _async_wrap - - raise PipelineError(f"Cannot adapt {fn!r} as a state node function") diff --git a/tests/unit/pipeline/test_checkpoint_backends.py b/tests/unit/pipeline/test_checkpoint_backends.py index 419faca3..51e42906 100644 --- a/tests/unit/pipeline/test_checkpoint_backends.py +++ b/tests/unit/pipeline/test_checkpoint_backends.py @@ -20,7 +20,6 @@ CheckpointRecord, FileCheckpointer, PipelineBuilder, - StatePipeline, ) # ============================================================================= @@ -113,7 +112,7 @@ class FactoryState(BaseModel): evaluation: str | None = None -def _build_factory(checkpointer) -> StatePipeline: +def _build_factory(checkpointer): """Construct the canonical 4-step agent pipeline that fails on first deploy.""" state_flag = {"failed_once": False} @@ -141,7 +140,6 @@ async def evaluator(state: FactoryState) -> dict: .chain(architect, python_dev, deployer, evaluator) .build() ) - assert isinstance(pipeline, StatePipeline) return pipeline diff --git a/tests/unit/pipeline/test_state_pipeline.py b/tests/unit/pipeline/test_state_pipeline.py index 0853c106..f0d1095d 100644 --- a/tests/unit/pipeline/test_state_pipeline.py +++ b/tests/unit/pipeline/test_state_pipeline.py @@ -27,7 +27,6 @@ from fireflyframework_agentic.pipeline import ( FileCheckpointer, PipelineBuilder, - StatePipeline, append, ) @@ -62,7 +61,6 @@ async def step_c(state: AgentState) -> dict: .chain(step_a, step_b, step_c) .build() ) - assert isinstance(pipeline, StatePipeline) result = await pipeline.invoke(AgentState(messages=["start"])) assert result.success assert result.completed_nodes == ["step_a", "step_b", "step_c"] diff --git a/tests/unit/pipeline/test_state_pipeline_deprecation.py b/tests/unit/pipeline/test_state_pipeline_deprecation.py deleted file mode 100644 index ec65da49..00000000 --- a/tests/unit/pipeline/test_state_pipeline_deprecation.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Layer 7 of the unification (#245): StatePipeline deprecation. - -Constructing :class:`StatePipeline` now emits a :class:`DeprecationWarning` -pointing at :class:`PipelineEngine` configured with ``state_schema=`` as -the supported replacement. -""" - -from __future__ import annotations - -import warnings - -from pydantic import BaseModel - -from fireflyframework_agentic.pipeline.builder import PipelineBuilder - - -class _S(BaseModel): - x: int = 0 - - -async def _noop(state): - return None - - -def test_state_pipeline_emits_deprecation_warning(): - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - PipelineBuilder("p", state=_S).add_node(_noop).build() - deprec = [w for w in caught if issubclass(w.category, DeprecationWarning)] - assert deprec, "expected a DeprecationWarning when constructing StatePipeline" - assert "PipelineEngine" in str(deprec[0].message) - assert "#245" in str(deprec[0].message) diff --git a/tests/unit/pipeline/test_state_pipeline_hitl.py b/tests/unit/pipeline/test_state_pipeline_hitl.py index aff1c687..f880c459 100644 --- a/tests/unit/pipeline/test_state_pipeline_hitl.py +++ b/tests/unit/pipeline/test_state_pipeline_hitl.py @@ -19,7 +19,6 @@ FileCheckpointer, Pause, PipelineBuilder, - StatePipeline, extend, ) @@ -62,7 +61,6 @@ async def deploy(state: DeployState) -> dict: .chain(architect, gate, deploy) .build() ) - assert isinstance(pipeline, StatePipeline) result = await pipeline.invoke(DeployState(requirements="user-mgmt")) assert result.paused is True diff --git a/tests/unit/pipeline/test_state_pipeline_phase2.py b/tests/unit/pipeline/test_state_pipeline_phase2.py index f2040f80..c3862fa4 100644 --- a/tests/unit/pipeline/test_state_pipeline_phase2.py +++ b/tests/unit/pipeline/test_state_pipeline_phase2.py @@ -25,7 +25,6 @@ FanOutStep, PipelineBuilder, Send, - StatePipeline, extend, ) from fireflyframework_agentic.pipeline.steps import CallableStep @@ -53,7 +52,6 @@ def route(state: LoopState) -> str: return "done" if state.counter >= 3 else "step" pipeline = PipelineBuilder("loop", state=LoopState).add_node(step).add_node(done).branch(step, route).build() - assert isinstance(pipeline, StatePipeline) result = await pipeline.invoke(LoopState()) assert result.success assert result.state.counter == 3 @@ -199,7 +197,6 @@ def route(state: LoopState) -> str: .branch(start, route, {"left_path": left, "right_path": right}) .build() ) - assert isinstance(pipeline, StatePipeline) mermaid = pipeline.to_mermaid() assert "start -->|left_path| left" in mermaid assert "start -->|right_path| right" in mermaid From fbbadb36d8295c8e634ab34cfc73672e056c6a37 Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Thu, 28 May 2026 16:35:03 +0200 Subject: [PATCH 26/26] fix(example): software-factory pipeline imports PipelineEngine Layer 8 (#255) deleted StatePipeline but I missed the import + return type in examples/software_factory/pipeline.py. CI on #232 fails to collect tests/examples/software_factory/test_pipeline.py with: ImportError: cannot import name 'StatePipeline' from 'fireflyframework_agentic.pipeline' Fix: import PipelineEngine instead and drop the now-unnecessary cast. PipelineBuilder.build() returns PipelineEngine directly after Layer 8. Test passes locally. --- examples/software_factory/pipeline.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/examples/software_factory/pipeline.py b/examples/software_factory/pipeline.py index a686a50a..f64f3b0e 100644 --- a/examples/software_factory/pipeline.py +++ b/examples/software_factory/pipeline.py @@ -17,8 +17,6 @@ from __future__ import annotations -from typing import cast - from examples.software_factory.agents import ( architect, builder, @@ -31,7 +29,7 @@ from fireflyframework_agentic.pipeline import ( Checkpointer, PipelineBuilder, - StatePipeline, + PipelineEngine, ) @@ -40,7 +38,7 @@ def qa_router(state: BuildState) -> str: return "stable_release" if state.qa_status == "pass" else "codegen" -def build_pipeline(checkpointer: Checkpointer) -> StatePipeline: +def build_pipeline(checkpointer: Checkpointer) -> PipelineEngine: pipeline = ( PipelineBuilder( "software-factory", @@ -60,5 +58,4 @@ def build_pipeline(checkpointer: Checkpointer) -> StatePipeline: .branch("qa", qa_router) .build() ) - # state= was set, so .build() returns a StatePipeline — narrow for the type checker. - return cast("StatePipeline", pipeline) + return pipeline