diff --git a/docs/pipeline.md b/docs/pipeline.md index 57b0d152..2ece529e 100644 --- a/docs/pipeline.md +++ b/docs/pipeline.md @@ -7,6 +7,17 @@ workflows. It supports parallel execution, conditional branching, retries, timeo and fan-out/fan-in patterns -- everything needed to model real-world enterprise processing pipelines. +`PipelineBuilder` has two modes: + +* **Port-based** (legacy, parallel) — nodes communicate via `output_key` / + `input_key` edge ports and run concurrently within each topological level. + Best for ETL-shaped DAGs. Documented in the bulk of this guide. +* **State-based** — opt-in via `PipelineBuilder("name", state=SomeModel)`. + Nodes become `async (state) -> dict` over a typed shared state. One + `.branch(source, router)` call covers conditional routing; `Send(target, payload)` + covers runtime fan-out; a `Checkpointer` enables resume after failure. Best for + agentic workflows and ReAct-style loops. See [State-Based Pipelines](#state-based-pipelines). + --- ## Concepts @@ -88,15 +99,181 @@ The framework provides these built-in executors: - **CallableStep** -- Wraps any `async` function `(context, inputs) -> output`. - **BatchLLMStep** -- Processes multiple prompts concurrently through an agent for cost optimization. See [Batch Processing](#batch-processing-batchllmstep) below. -- **BranchStep** -- Routes execution to one of several downstream paths based on - a predicate (see [Conditional Branching](#conditional-branching-branchstep) below). -- **FanOutStep** -- Splits input into a list for parallel downstream processing. +- **BranchStep** _(deprecated)_ -- Routes execution to one of several downstream paths based on + a predicate. Use `.branch(...)` in [State-Based Pipelines](#state-based-pipelines) instead. +- **FanOutStep** _(deprecated)_ -- Splits input into a list for parallel downstream processing. + Use `Send` in [State-Based Pipelines](#runtime-fan-out-via-send) instead. - **FanInStep** -- Merges outputs from multiple upstream nodes. --- +## State-Based Pipelines + +Set `state=` on `PipelineBuilder` to switch to a declarative API designed for +agentic workflows. Nodes become `async (state) -> dict | None` functions over +a typed shared-state object; the engine reduces each node's partial-update +dict back into the state. + +```python +from typing import Annotated +from pydantic import BaseModel +from fireflyframework_agentic.pipeline import PipelineBuilder, append + + +class AgentState(BaseModel): + messages: Annotated[list[str], append] = [] # reducer: append + intent: str | None = None # default reducer: replace + answer: str | None = None + + +async def classify(state: AgentState) -> dict: + return {"intent": "complaint" if "refund" in state.messages[-1] else "general"} + + +async def answer(state: AgentState) -> dict: + return {"answer": "Here is your answer."} + + +async def escalate(state: AgentState) -> dict: + return {"answer": "Escalated to human."} + + +def route(state: AgentState) -> str: + return "escalate" if state.intent == "complaint" else "answer" + + +pipeline = ( + PipelineBuilder("support-agent", state=AgentState) + .add_node(classify) # node id derived from fn.__name__ + .add_node(answer) + .add_node(escalate) + .branch(classify, route) # router returns target node id + .build() +) +result = await pipeline.invoke(AgentState(messages=["I want a refund"])) +print(result.state.answer) +``` + +### Reducers + +Reducers are declared as `Annotated[T, reducer_fn]` on the state schema. The +built-ins live in `fireflyframework_agentic.pipeline.reducers`: + +| Reducer | Semantics | +|---------------|-------------------------------------------------| +| `replace` | Last-write-wins (the default for any field). | +| `append` | Append a single item to a list. | +| `extend` | Concatenate two iterables. | +| `merge_dict` | Shallow-merge two dicts; update wins on conflict. | + +Custom reducers are any callable `(current, update) -> merged`. + +### Branching + +`.branch(source, router, mapping=None)` registers a synchronous +`(state) -> str | Send | list[Send]` router on `source`: + +* Returning a node id (string) routes to that node directly. +* Passing `mapping={"label": target_node, ...}` lets the router return an + abstract label instead of a node id. +* Returning a `Send` or `list[Send]` triggers runtime fan-out (see below). + +### Checkpoint + Resume + +Pass a `Checkpointer` to persist state after each successful node. The +filesystem implementation ships in the package; Redis/Postgres backends are +straightforward to plug in via the `Checkpointer` Protocol. + +```python +from fireflyframework_agentic.pipeline import FileCheckpointer + +pipeline = ( + PipelineBuilder("software-factory", state=BuildState, + checkpointer=FileCheckpointer("./checkpoints")) + .add_node(architect) + .add_node(python_dev) + .add_node(deployer) + .add_node(evaluator) + .chain(architect, python_dev, deployer, evaluator) + .build() +) + +# Fresh run +result = await pipeline.invoke(BuildState(requirements="user-mgmt service")) + +# Resume after crash — picks up at the failed node, skips completed ones +result = await pipeline.invoke(run_id=result.run_id) + +# Or jump into a specific node with explicit state +result = await pipeline.invoke(state=loaded_state, start_at=deployer) +``` + +### Cycles and `recursion_limit` + +State pipelines permit cycles for ReAct loops and retry-with-critique patterns. +The builder accepts `recursion_limit` (default 25) as a safety net — a runaway +loop surfaces as `result.success=False` with a clean error, not an infinite hang. + +```python +def route(state): + return "done" if state.counter >= 3 else "step" + +PipelineBuilder("loop", state=LoopState, recursion_limit=25) + .add_node(step).add_node(done).branch(step, route).build() +``` + +### Runtime Fan-Out via `Send` + +A router may return `list[Send(target, payload)]` to dispatch multiple +invocations of the same (or different) workers concurrently. Each Send's +payload is applied to a copy of the current state before its target runs; +results reduce back into shared state. Replaces the legacy `FanOutStep`. + +```python +from fireflyframework_agentic.pipeline import Send + +def dispatch(state): + return [Send("worker", {"item": x}) for x in state.items] + +PipelineBuilder("mapreduce", state=MapReduceState) + .add_node(planner).add_node(worker).add_node(collect) + .add_edge(worker, collect) + .branch(planner, dispatch) + .build() +``` + +When all worker targets share a common successor, the engine continues there +once the fan-out completes; the aggregator runs once with all results in +shared state. + +### Mermaid Export + +`StatePipeline.to_mermaid()` and `DAG.to_mermaid()` render the topology as a +Mermaid flowchart. Branch edges declared with an explicit mapping show their +label; dynamic routers are noted as such. + +### When to use which mode + +| Use port-based when… | Use state-based when… | +|----------------------|------------------------| +| Pure ETL: parallel, fan-out/fan-in, no shared state | Agentic workflow: classify → branch → respond / loop / retry | +| Each step's input is a single value from the previous step | Multiple agents reading/writing different fields of a shared object | +| You want the engine to run independent nodes concurrently | You want resume-after-failure and start-from-middle semantics | +| You're happy with `BranchStep` + per-node `condition` lambdas | You want one `.branch(...)` call and inspectable routing | + +See [`examples/pipeline_state.py`](../examples/pipeline_state.py) for a +runnable demo covering branching, software-factory checkpoint/resume, and +map-reduce fan-out. + +--- + ## Parallel Execution (Fan-Out / Fan-In) +> **`FanOutStep` is deprecated.** For runtime fan-out (one dispatch per item, +> arbitrary count), prefer `Send` from [State-Based Pipelines](#runtime-fan-out-via-send). +> `FanOutStep` still works for now (it emits a `DeprecationWarning` on +> construction); `FanInStep` is not deprecated. + ```mermaid graph TD SPLIT[Fan-Out] --> W1[Worker 1] @@ -248,6 +425,12 @@ dag.add_node(DAGNode( ### Conditional Branching (BranchStep) +> **Deprecated.** Prefer [State-Based Pipelines](#state-based-pipelines) with +> `.branch(source, router)` — one call instead of `BranchStep` + per-node +> `condition` lambdas, and the topology becomes inspectable as data. +> `BranchStep` still works (it emits a `DeprecationWarning` on construction); +> removal will be tracked in a follow-up issue once internal callers migrate. + `BranchStep` provides router-based conditional branching. The router callable receives the node's input and returns a string key. Downstream nodes use condition gates to check the branch key and execute only the matching path. diff --git a/examples/pipeline_state.py b/examples/pipeline_state.py new file mode 100644 index 00000000..9f0fc463 --- /dev/null +++ b/examples/pipeline_state.py @@ -0,0 +1,242 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""State-based PipelineBuilder: branching, checkpoint/resume, and Send fan-out. + +Three scenarios: + +1. **Branching** — same sentiment-classification workflow as + ``examples/pipeline_branching.py``, but written with the state-mode API + (one shared ``State`` model, ``async (state) -> dict`` nodes, one + ``.branch(source, router)`` call instead of ``BranchStep`` + per-node + ``condition`` lambdas). + +2. **Software factory with checkpoint/resume** — a four-agent pipeline + (architect → python_dev → deployer → evaluator) where the deployer fails + on its first attempt. The pipeline checkpoints after each successful node, + and a second ``invoke(run_id=...)`` resumes from the failed node instead + of re-running the earlier agents. + +3. **Map-reduce via ``Send``** — a planner dispatches one ``Send`` per work + item to the same worker node, the workers run concurrently, and an + aggregator runs once with all results in shared state. + +Usage:: + + uv run python examples/pipeline_state.py + +.. note:: No OpenAI API key required — all "agents" are plain Python stubs. +""" + +from __future__ import annotations + +import asyncio +import logging +import tempfile +from pathlib import Path +from typing import Annotated + +from pydantic import BaseModel + +from fireflyframework_agentic.pipeline import ( + FileCheckpointer, + PipelineBuilder, + Send, + extend, +) + +# Quiet the pipeline's own logger.exception() when we deliberately fail +# the deployer in scenario 2 — the failure is the demo, not a bug. +logging.getLogger("fireflyframework_agentic.pipeline").setLevel(logging.CRITICAL) + + +# ============================================================================= +# Scenario 1 — Branching +# ============================================================================= + + +class SentimentState(BaseModel): + text: str + sentiment: str | None = None + response: str | None = None + + +async def classify_sentiment(state: SentimentState) -> dict: + text = state.text.lower() + positive = {"good", "great", "love", "amazing", "wonderful", "happy", "excellent"} + negative = {"bad", "terrible", "hate", "awful", "horrible", "sad", "poor"} + pos = sum(1 for w in text.split() if w in positive) + neg = sum(1 for w in text.split() if w in negative) + return {"sentiment": "positive" if pos >= neg else "negative"} + + +async def positive_reply(state: SentimentState) -> dict: + return {"response": "😊 Thank you for your kind words!"} + + +async def negative_reply(state: SentimentState) -> dict: + return {"response": "😟 We're sorry to hear that. We'll improve!"} + + +def route_by_sentiment(state: SentimentState) -> str: + # The router returns the node id directly — no mapping needed. + return "positive_reply" if state.sentiment == "positive" else "negative_reply" + + +async def run_branching() -> None: + print("=== 1. Branching (state mode) ===\n") + + pipeline = ( + PipelineBuilder("sentiment", state=SentimentState) + .add_node(classify_sentiment) + .add_node(positive_reply) + .add_node(negative_reply) + .branch(classify_sentiment, route_by_sentiment) + .build() + ) + + for text in ["This product is great and amazing!", "The service was terrible and awful."]: + result = await pipeline.invoke(SentimentState(text=text)) + print(f" input: {text!r}") + print(f" output: {result.state.response}\n") + + +# ============================================================================= +# Scenario 2 — Software factory with checkpoint/resume +# ============================================================================= + + +class BuildState(BaseModel): + """State threaded through a four-agent software-factory pipeline.""" + + requirements: str + spec: str | None = None + code: str | None = None + deploy_url: str | None = None + evaluation: str | None = None + + +# A flag so the deployer fails the first time and succeeds the second. +_deployer_failed_once = {"flag": False} + + +async def architect(state: BuildState) -> dict: + return {"spec": f"Architecture for: {state.requirements}"} + + +async def python_dev(state: BuildState) -> dict: + return {"code": f"# code implementing\n# {state.spec}"} + + +async def deployer(state: BuildState) -> dict: + if not _deployer_failed_once["flag"]: + _deployer_failed_once["flag"] = True + raise RuntimeError("network blip — try again") + return {"deploy_url": "https://factory-app.example.com"} + + +async def evaluator(state: BuildState) -> dict: + return {"evaluation": f"PASS — deployed at {state.deploy_url}"} + + +async def run_software_factory() -> None: + print("=== 2. Software factory with checkpoint/resume ===\n") + + with tempfile.TemporaryDirectory() as tmp: + ckpt = FileCheckpointer(Path(tmp)) + pipeline = ( + PipelineBuilder("software-factory", state=BuildState, checkpointer=ckpt) + .add_node(architect) + .add_node(python_dev) + .add_node(deployer) + .add_node(evaluator) + .chain(architect, python_dev, deployer, evaluator) + .build() + ) + + # First run — deployer fails after architect + python_dev complete. + first = await pipeline.invoke(BuildState(requirements="User-management service")) + print(f" first run: success={first.success}, failed_node={first.failed_node}") + print(f" completed: {first.completed_nodes}") + print(f" run_id: {first.run_id}\n") + + # Resume — picks up at deployer, skips architect + python_dev. + second = await pipeline.invoke(run_id=first.run_id) + print(f" resumed: success={second.success}") + print(f" completed: {second.completed_nodes}") + print(f" eval: {second.state.evaluation}\n") + + +# ============================================================================= +# Scenario 3 — Map-reduce via Send +# ============================================================================= + + +class MapReduceState(BaseModel): + items: list[str] = [] + processed: Annotated[list[str], extend] = [] + summary: str | None = None + # Per-Send payload field — each worker receives its own item here. + item: str | None = None + + +async def plan(state: MapReduceState) -> dict: + # No state mutation; the dispatch router below decides what runs next. + return {} + + +async def process_item(state: MapReduceState) -> dict: + assert state.item is not None + return {"processed": [f"processed:{state.item}"]} + + +async def aggregate(state: MapReduceState) -> dict: + return {"summary": f"Processed {len(state.processed)} items: {state.processed}"} + + +def dispatch(state: MapReduceState) -> list[Send]: + # One Send per item — workers run concurrently. The ``extend`` reducer on + # ``processed`` merges all worker outputs into one list. + return [Send("process_item", {"item": x}) for x in state.items] + + +async def run_map_reduce() -> None: + print("=== 3. Map-reduce via Send ===\n") + + pipeline = ( + PipelineBuilder("mapreduce", state=MapReduceState) + .add_node(plan) + .add_node(process_item) + .add_node(aggregate) + .add_edge(process_item, aggregate) + .branch(plan, dispatch) + .build() + ) + result = await pipeline.invoke(MapReduceState(items=["alpha", "beta", "gamma", "delta"])) + print(f" summary: {result.state.summary}") + + +# ============================================================================= +# Entrypoint +# ============================================================================= + + +async def main() -> None: + await run_branching() + await run_software_factory() + await run_map_reduce() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/fireflyframework_agentic/pipeline/__init__.py b/fireflyframework_agentic/pipeline/__init__.py index f49f8db6..ed844dc0 100644 --- a/fireflyframework_agentic/pipeline/__init__.py +++ b/fireflyframework_agentic/pipeline/__init__.py @@ -39,7 +39,12 @@ from fireflyframework_agentic.pipeline.engine import PipelineEngine, PipelineEventHandler from fireflyframework_agentic.pipeline.reducers import append, extend, merge_dict, replace from fireflyframework_agentic.pipeline.result import ExecutionTraceEntry, NodeResult, PipelineResult -from fireflyframework_agentic.pipeline.state_pipeline import StatePipeline, StatePipelineResult +from fireflyframework_agentic.pipeline.state_pipeline import ( + RecursionLimitError, + Send, + StatePipeline, + StatePipelineResult, +) from fireflyframework_agentic.pipeline.steps import ( AgentStep, BatchLLMStep, @@ -76,6 +81,8 @@ "PipelineEventHandler", "PipelineResult", "ReasoningStep", + "RecursionLimitError", + "Send", "RetrievalStep", "StatePipeline", "StatePipelineResult", diff --git a/fireflyframework_agentic/pipeline/builder.py b/fireflyframework_agentic/pipeline/builder.py index 1424d2c9..bb5f7681 100644 --- a/fireflyframework_agentic/pipeline/builder.py +++ b/fireflyframework_agentic/pipeline/builder.py @@ -59,6 +59,7 @@ from fireflyframework_agentic.pipeline.state_pipeline import ( BranchSpec, RouterFn, + Send, # noqa: F401 re-exported via pipeline/__init__.py StateNodeFn, StatePipeline, coerce_state_node_fn, @@ -84,11 +85,15 @@ def __init__( *, state: type[BaseModel] | None = None, checkpointer: Checkpointer | None = None, + recursion_limit: int = 25, ) -> None: - self._dag = DAG(name=name) + # State pipelines may use cyclic graphs (ReAct loops, retry-with-critique). + # The legacy port-based path keeps acyclicity as an invariant. + self._dag = DAG(name=name, allow_cycles=state is not None) self._name = name self._state_schema = state self._checkpointer = checkpointer + self._recursion_limit = recursion_limit self._pending_nodes: list[DAGNode] = [] self._pending_edges: list[DAGEdge] = [] # State-based mode bookkeeping. Keyed by node id. @@ -261,6 +266,7 @@ def build(self) -> PipelineEngine | StatePipeline: node_fns=self._state_node_fns, branches=self._branches, checkpointer=self._checkpointer, + recursion_limit=self._recursion_limit, ) return PipelineEngine(self._dag) diff --git a/fireflyframework_agentic/pipeline/dag.py b/fireflyframework_agentic/pipeline/dag.py index 2e4fef9b..80d78b76 100644 --- a/fireflyframework_agentic/pipeline/dag.py +++ b/fireflyframework_agentic/pipeline/dag.py @@ -15,12 +15,14 @@ """Directed Acyclic Graph (DAG) model for pipeline topology. :class:`DAG` holds :class:`DAGNode` and :class:`DAGEdge` objects, validates -acyclicity, computes topological sort, and identifies independent execution -levels for parallel scheduling. +acyclicity (unless ``allow_cycles=True``), computes topological sort, and +identifies independent execution levels for parallel scheduling. Also renders +itself as Mermaid or JSON for inspection / docs / Studio. """ from __future__ import annotations +import json from collections import defaultdict, deque from collections.abc import Callable from enum import StrEnum @@ -98,8 +100,9 @@ class DAG: name: A human-readable name for the pipeline. """ - def __init__(self, name: str = "pipeline") -> None: + def __init__(self, name: str = "pipeline", *, allow_cycles: bool = False) -> None: self._name = name + self._allow_cycles = allow_cycles self._nodes: dict[str, DAGNode] = {} self._edges: list[DAGEdge] = [] # Adjacency and reverse adjacency for topo-sort @@ -136,9 +139,9 @@ def add_edge(self, edge: DAGEdge) -> None: self._edges.append(edge) self._adj[edge.source].append(edge.target) self._in_degree[edge.target] = self._in_degree.get(edge.target, 0) + 1 - # Incremental cycle check - if self._has_cycle(): - # Rollback + # Cycle check is skipped when the DAG was constructed with allow_cycles=True + # (state-based pipelines opt into this for ReAct-style loops). + if not self._allow_cycles and self._has_cycle(): self._edges.pop() self._adj[edge.source].pop() self._in_degree[edge.target] -= 1 @@ -236,5 +239,62 @@ def _has_cycle(self) -> bool: queue.append(neighbour) return count != len(self._nodes) + def is_cyclic(self) -> bool: + """True if the graph contains at least one cycle.""" + return self._has_cycle() + + # -- Export ------------------------------------------------------------ + + def to_mermaid(self) -> str: + """Render the topology as a Mermaid flowchart. + + Edges with ``input_key`` other than the default ``"input"`` are + labelled with that key so port wiring is visible. + """ + lines = ["flowchart TD"] + for node_id in self._nodes: + lines.append(f" {_mermaid_id(node_id)}[{node_id}]") + for edge in self._edges: + label = edge.input_key if edge.input_key and edge.input_key != "input" else None + arrow = f"-->|{label}|" if label else "-->" + lines.append(f" {_mermaid_id(edge.source)} {arrow} {_mermaid_id(edge.target)}") + return "\n".join(lines) + + def to_json(self) -> str: + """Render the topology as a JSON document. + + Schema:: + + {"name": str, "nodes": [str], "edges": [{"source", "target", "output_key", "input_key"}]} + """ + doc = { + "name": self._name, + "nodes": list(self._nodes.keys()), + "edges": [ + { + "source": e.source, + "target": e.target, + "output_key": e.output_key, + "input_key": e.input_key, + } + for e in self._edges + ], + } + return json.dumps(doc, indent=2) + def __repr__(self) -> str: return f"DAG(name={self._name!r}, nodes={len(self._nodes)}, edges={len(self._edges)})" + + +def _mermaid_id(node_id: str) -> str: + """Sanitize a node id for use as a Mermaid identifier.""" + out = [] + for ch in node_id: + if ch.isalnum() or ch == "_": + out.append(ch) + else: + out.append("_") + sanitized = "".join(out) + if sanitized and sanitized[0].isdigit(): + sanitized = "n_" + sanitized + return sanitized or "anon" diff --git a/fireflyframework_agentic/pipeline/state_pipeline.py b/fireflyframework_agentic/pipeline/state_pipeline.py index 1b09e185..6c818bf7 100644 --- a/fireflyframework_agentic/pipeline/state_pipeline.py +++ b/fireflyframework_agentic/pipeline/state_pipeline.py @@ -36,13 +36,34 @@ from fireflyframework_agentic.exceptions import PipelineError from fireflyframework_agentic.pipeline.checkpoint import Checkpointer, CheckpointRecord -from fireflyframework_agentic.pipeline.dag import DAG +from fireflyframework_agentic.pipeline.dag import DAG, _mermaid_id from fireflyframework_agentic.pipeline.reducers import Reducer, replace logger = logging.getLogger(__name__) StateNodeFn = Callable[[Any], Awaitable[dict[str, Any] | None]] -RouterFn = Callable[[Any], str] +# A router may return: a node id (str), a Send, or a list[Send] for fan-out. +RouterFn = Callable[[Any], "str | Send | list[Send]"] + + +@dataclass +class Send: + """Runtime fan-out dispatch: run ``target`` with ``payload`` merged into state. + + Routers can return a single ``Send`` or a list of ``Send`` to dispatch multiple + target invocations concurrently. Each Send's payload is applied to a *copy* + of the current state before its target runs; the target's return is then + merged back into shared state via reducers. + + Replaces the legacy ``FanOutStep`` pattern with a first-class primitive. + """ + + target: str + payload: dict[str, Any] + + +class RecursionLimitError(Exception): + """Raised when a node is visited more times than ``recursion_limit`` permits.""" @dataclass @@ -129,6 +150,7 @@ def __init__( node_fns: dict[str, StateNodeFn], branches: dict[str, BranchSpec], checkpointer: Checkpointer | None = None, + recursion_limit: int = 25, ) -> None: self._name = name self._dag = dag @@ -136,6 +158,7 @@ def __init__( self._node_fns = node_fns self._branches = branches self._checkpointer = checkpointer + self._recursion_limit = recursion_limit self._reducers = discover_reducers(state_schema) self._validate() @@ -147,6 +170,32 @@ def name(self) -> str: def dag(self) -> DAG: return self._dag + def to_mermaid(self) -> str: + """Render the pipeline as a Mermaid flowchart, including branch edges. + + Branches that omit an explicit mapping are rendered as a dashed edge + labelled ``router`` because the targets are decided at runtime. + """ + lines = ["flowchart TD"] + for node_id in self._dag.nodes: + lines.append(f" {_mermaid_id(node_id)}[{node_id}]") + # Explicit edges (including branch mappings, which were materialized). + for edge in self._dag.edges: + label = None + spec = self._branches.get(edge.source) + if spec and spec.mapping: + for lbl, tgt in spec.mapping.items(): + if tgt == edge.target: + label = lbl + break + arrow = f"-->|{label}|" if label else "-->" + lines.append(f" {_mermaid_id(edge.source)} {arrow} {_mermaid_id(edge.target)}") + # Dynamic branches (no mapping): show as a dashed self-edge stub. + for source, spec in self._branches.items(): + if spec.mapping is None and not self._dag.successors(source): + lines.append(f" {_mermaid_id(source)} -.->|router| {_mermaid_id(source)}_router((dynamic))") + return "\n".join(lines) + def _validate(self) -> None: # Every node must have a registered fn. for node_id in self._dag.nodes: @@ -174,28 +223,17 @@ def _entry_node(self) -> str: raise PipelineError("Pipeline has no nodes") return order[0] - def _next_node(self, current: str, state: BaseModel) -> str | None: - """Decide which successor runs next given the current state. + def _next_step(self, current: str, state: BaseModel) -> str | list[Send] | None: + """Decide what runs next given the current state. - For non-branching nodes: pick the unique successor (or None at terminus). - For branching nodes: run the router and resolve via mapping if present. + Returns: + * A node id (str) for a single deterministic step. + * A list of :class:`Send` for runtime fan-out — workers run concurrently. + * ``None`` when the pipeline reaches a terminus. """ if current in self._branches: - spec = self._branches[current] - label = spec.router(state) - if spec.mapping is not None: - if label not in spec.mapping: - raise PipelineError( - f"Router for '{current}' returned label '{label}' not in mapping {list(spec.mapping)}" - ) - return spec.mapping[label] - # Mapping omitted: router returns target node id directly. - if label not in self._dag.nodes: - raise PipelineError( - f"Router for '{current}' returned '{label}' " - f"which is not a registered node id; pass an explicit mapping if you want labels." - ) - return label + decision = self._branches[current].router(state) + return self._resolve_router_decision(current, decision) successors = self._dag.successors(current) if not successors: @@ -207,6 +245,50 @@ def _next_node(self, current: str, state: BaseModel) -> str | None: ) return successors[0] + def _resolve_router_decision(self, current: str, decision: str | Send | list[Send]) -> str | list[Send] | None: + """Translate a router's return value into a concrete next-step instruction.""" + # Fan-out: list of Send dispatches. + if isinstance(decision, list): + if not decision: + return None + for s in decision: + if not isinstance(s, Send): + raise PipelineError( + f"Router for '{current}' returned a list containing non-Send " + f"element {s!r}; expected list[Send]." + ) + if s.target not in self._dag.nodes: + raise PipelineError(f"Router for '{current}' fans out to unknown target '{s.target}'") + return decision + + if isinstance(decision, Send): + if decision.target not in self._dag.nodes: + raise PipelineError(f"Router for '{current}' dispatched to unknown target '{decision.target}'") + return [decision] + + # String label. + spec = self._branches[current] + if spec.mapping is not None: + if decision not in spec.mapping: + raise PipelineError( + f"Router for '{current}' returned label '{decision}' not in mapping {list(spec.mapping)}" + ) + return spec.mapping[decision] + if decision not in self._dag.nodes: + raise PipelineError( + f"Router for '{current}' returned '{decision}' " + f"which is not a registered node id; pass an explicit mapping if you want labels." + ) + return decision + + def _common_successor(self, node_ids: list[str]) -> str | None: + """Return the node all ``node_ids`` share as their unique successor, or None.""" + successors = [self._dag.successors(nid) for nid in node_ids] + if not successors or any(len(s) != 1 for s in successors): + return None + first = successors[0][0] + return first if all(s[0] == first for s in successors[1:]) else None + async def invoke( self, state: BaseModel | None = None, @@ -235,7 +317,13 @@ async def invoke( resumed_completed = list(record.completed_nodes) # Resume at the successor of the last completed node. last = record.node_id - next_node = self._next_node(last, state) + next_node = self._next_step(last, state) + # Resume can't seamlessly continue mid-fan-out yet; treat fan-out as terminal here. + if isinstance(next_node, list): + raise PipelineError( + "Resume across a fan-out (Send) is not supported in Phase 2; " + "the run finished by reaching a fan-out node." + ) if next_node is None: return StatePipelineResult( state=state, @@ -266,9 +354,54 @@ async def invoke( assert state is not None # narrowed by the branches above completed: list[str] = list(resumed_completed) sequence = len(completed) + visit_counts: dict[str, int] = {} + + next_step: str | list[Send] | None = current_node + + while next_step is not None: + # --- fan-out branch (list[Send]) --------------------------------- + if isinstance(next_step, list): + try: + state, sequence = await self._run_fanout( + sends=next_step, + state=state, + completed=completed, + run_id=run_id, + sequence=sequence, + visit_counts=visit_counts, + ) + except _NodeFailureError as fail: + return StatePipelineResult( + state=state, + run_id=run_id, + completed_nodes=completed, + success=False, + error=fail.message, + failed_node=fail.node_id, + ) + # After fan-out, continue from the workers' shared successor (if any). + next_step = self._common_successor([s.target for s in next_step]) + continue + + # --- single-node step -------------------------------------------- + node_id = next_step + visit_counts[node_id] = visit_counts.get(node_id, 0) + 1 + if visit_counts[node_id] > self._recursion_limit: + msg = ( + f"Recursion limit ({self._recursion_limit}) exceeded at node '{node_id}'. " + f"Raise recursion_limit= or fix the routing logic." + ) + logger.error(msg) + return StatePipelineResult( + state=state, + run_id=run_id, + completed_nodes=completed, + success=False, + error=msg, + failed_node=node_id, + ) - while current_node is not None: - fn = self._node_fns[current_node] + fn = self._node_fns[node_id] t0 = time.perf_counter() try: update = await fn(state) @@ -277,7 +410,7 @@ async def invoke( "State pipeline '%s' run '%s' failed at node '%s'", self._name, run_id, - current_node, + node_id, ) return StatePipelineResult( state=state, @@ -285,35 +418,20 @@ async def invoke( completed_nodes=completed, success=False, error=str(exc), - failed_node=current_node, + failed_node=node_id, ) elapsed = (time.perf_counter() - t0) * 1000 - logger.debug("Pipeline '%s' node '%s' completed in %.1fms", self._name, current_node, elapsed) + logger.debug("Pipeline '%s' node '%s' completed in %.1fms", self._name, node_id, elapsed) if update: state = apply_update(state, update, self._reducers) - completed.append(current_node) + completed.append(node_id) sequence += 1 - - if self._checkpointer is not None: - try: - self._checkpointer.save( - CheckpointRecord( - pipeline_name=self._name, - run_id=run_id, - node_id=current_node, - sequence=sequence, - state=state.model_dump(), - completed_nodes=list(completed), - ) - ) - except Exception: - # Checkpoint failure is non-fatal — log and continue. - logger.exception("Checkpoint save failed for run '%s' at '%s'", run_id, current_node) + self._save_checkpoint(run_id, node_id, sequence, state, completed) try: - current_node = self._next_node(current_node, state) + next_step = self._next_step(node_id, state) except PipelineError as exc: return StatePipelineResult( state=state, @@ -321,7 +439,7 @@ async def invoke( completed_nodes=completed, success=False, error=str(exc), - failed_node=completed[-1] if completed else None, + failed_node=node_id, ) return StatePipelineResult( @@ -331,6 +449,87 @@ async def invoke( success=True, ) + async def _run_fanout( + self, + *, + sends: list[Send], + state: BaseModel, + completed: list[str], + run_id: str, + sequence: int, + visit_counts: dict[str, int], + ) -> tuple[BaseModel, int]: + """Run all ``Send`` dispatches concurrently. Each task gets its own state + copy with the Send's payload merged in; results are reduced into shared state. + """ + for send in sends: + visit_counts[send.target] = visit_counts.get(send.target, 0) + 1 + if visit_counts[send.target] > self._recursion_limit: + raise _NodeFailureError( + node_id=send.target, + message=( + f"Recursion limit ({self._recursion_limit}) exceeded at node '{send.target}' during fan-out." + ), + ) + + async def _run_one(send: Send) -> tuple[Send, dict[str, Any] | None]: + task_state = apply_update(state, send.payload, self._reducers) + fn = self._node_fns[send.target] + return send, await fn(task_state) + + try: + results = await asyncio.gather(*(_run_one(s) for s in sends)) + except Exception as exc: + # Best-effort: report the first failing target as the failure point. + raise _NodeFailureError( + node_id=sends[0].target, + message=f"Fan-out failure: {exc}", + ) from exc + + new_state = state + for send, update in results: + if update: + new_state = apply_update(new_state, update, self._reducers) + completed.append(send.target) + sequence += 1 + self._save_checkpoint(run_id, send.target, sequence, new_state, completed) + + return new_state, sequence + + def _save_checkpoint( + self, + run_id: str, + node_id: str, + sequence: int, + state: BaseModel, + completed: list[str], + ) -> None: + """Persist state via the configured checkpointer (no-op if absent).""" + if self._checkpointer is None: + return + try: + self._checkpointer.save( + CheckpointRecord( + pipeline_name=self._name, + run_id=run_id, + node_id=node_id, + sequence=sequence, + state=state.model_dump(), + completed_nodes=list(completed), + ) + ) + except Exception: + logger.exception("Checkpoint save failed for run '%s' at '%s'", run_id, node_id) + + +class _NodeFailureError(Exception): + """Internal sentinel used to bubble fan-out failures out to the main loop.""" + + def __init__(self, node_id: str, message: str) -> None: + super().__init__(message) + self.node_id = node_id + self.message = message + def _resolve_node_id(ref: str | Callable[..., Any]) -> str: """Turn either a string id or a function reference into a node id.""" diff --git a/fireflyframework_agentic/pipeline/steps.py b/fireflyframework_agentic/pipeline/steps.py index 47d12220..68e2bb29 100644 --- a/fireflyframework_agentic/pipeline/steps.py +++ b/fireflyframework_agentic/pipeline/steps.py @@ -20,6 +20,7 @@ import asyncio import logging +import warnings from collections.abc import Callable, Coroutine from typing import Any, Protocol, runtime_checkable @@ -148,6 +149,12 @@ def classify(inputs): """ def __init__(self, router: Callable[[dict[str, Any]], str]) -> None: + warnings.warn( + "BranchStep is deprecated; use PipelineBuilder(state=...).branch(source, router) " + "for first-class declarative branching.", + DeprecationWarning, + stacklevel=2, + ) self._router = router async def execute(self, context: PipelineContext, inputs: dict[str, Any]) -> Any: @@ -162,6 +169,12 @@ class FanOutStep: """ def __init__(self, split_fn: Callable[[Any], list[Any]]) -> None: + warnings.warn( + "FanOutStep is deprecated; use PipelineBuilder(state=...) with a router returning " + "list[Send(target, payload)] for first-class runtime fan-out.", + DeprecationWarning, + stacklevel=2, + ) self._split_fn = split_fn async def execute(self, context: PipelineContext, inputs: dict[str, Any]) -> Any: diff --git a/tests/unit/pipeline/test_state_pipeline_phase2.py b/tests/unit/pipeline/test_state_pipeline_phase2.py new file mode 100644 index 00000000..f2040f80 --- /dev/null +++ b/tests/unit/pipeline/test_state_pipeline_phase2.py @@ -0,0 +1,231 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""Phase-2 tests: cycles + recursion_limit, Send fan-out, Mermaid export, +soft-deprecation of BranchStep / FanOutStep. +""" + +from __future__ import annotations + +import json +import warnings +from typing import Annotated + +import pytest +from pydantic import BaseModel + +from fireflyframework_agentic.pipeline import ( + DAG, + BranchStep, + DAGEdge, + DAGNode, + FailureStrategy, + FanOutStep, + PipelineBuilder, + Send, + StatePipeline, + extend, +) +from fireflyframework_agentic.pipeline.steps import CallableStep + + +class LoopState(BaseModel): + counter: int = 0 + log: Annotated[list[str], extend] = [] + + +# --- cycles + recursion_limit ---------------------------------------------- + + +@pytest.mark.asyncio +async def test_simple_cycle_with_exit_router(): + """A node loops back to itself N times, then a router exits to END.""" + + async def step(state: LoopState) -> dict: + return {"counter": state.counter + 1, "log": [f"step#{state.counter + 1}"]} + + async def done(state: LoopState) -> dict: + return {"log": ["done"]} + + def route(state: LoopState) -> str: + return "done" if state.counter >= 3 else "step" + + pipeline = PipelineBuilder("loop", state=LoopState).add_node(step).add_node(done).branch(step, route).build() + assert isinstance(pipeline, StatePipeline) + result = await pipeline.invoke(LoopState()) + assert result.success + assert result.state.counter == 3 + assert "done" in result.state.log + # 3 step visits + 1 done = 4 entries before done's own log entry. + assert result.completed_nodes == ["step", "step", "step", "done"] + + +@pytest.mark.asyncio +async def test_recursion_limit_aborts_infinite_loop(): + """A router that never exits triggers the recursion_limit safety net.""" + + async def step(state: LoopState) -> dict: + return {"counter": state.counter + 1} + + def never_exits(state: LoopState) -> str: + return "step" + + pipeline = ( + PipelineBuilder("inf", state=LoopState, recursion_limit=5).add_node(step).branch(step, never_exits).build() + ) + result = await pipeline.invoke(LoopState()) + assert not result.success + assert "Recursion limit" in (result.error or "") + assert result.failed_node == "step" + # The node ran exactly recursion_limit times before the guard fired. + assert result.state.counter == 5 + + +# --- Send fan-out ----------------------------------------------------------- + + +class FanOutState(BaseModel): + items: list[str] = [] + results: Annotated[list[str], extend] = [] + item: str | None = None # filled per-Send via payload + + +@pytest.mark.asyncio +async def test_send_fans_out_to_multiple_workers_and_merges_results(): + """Router returns list[Send]; workers run concurrently; reducer merges.""" + + async def planner(state: FanOutState) -> dict: + return {} # passthrough; could populate items if not preset + + async def worker(state: FanOutState) -> dict: + # Each worker sees its own copy of state with the Send payload applied. + assert state.item is not None + return {"results": [f"processed:{state.item}"]} + + async def collect(state: FanOutState) -> dict: + return {"results": ["collected"]} + + def dispatch(state: FanOutState) -> list[Send]: + return [Send("worker", {"item": x}) for x in state.items] + + pipeline = ( + PipelineBuilder("mapreduce", state=FanOutState) + .add_node(planner) + .add_node(worker) + .add_node(collect) + .add_edge(worker, collect) + .branch(planner, dispatch) + .build() + ) + result = await pipeline.invoke(FanOutState(items=["a", "b", "c"])) + assert result.success + processed = sorted(r for r in result.state.results if r.startswith("processed:")) + assert processed == ["processed:a", "processed:b", "processed:c"] + assert "collected" in result.state.results + # Each worker counts as a completed node visit; planner once, three workers, then collect. + assert result.completed_nodes.count("worker") == 3 + assert result.completed_nodes[-1] == "collect" + + +@pytest.mark.asyncio +async def test_send_to_unknown_target_fails_cleanly(): + async def planner(state: FanOutState) -> dict: + return {} + + async def worker(state: FanOutState) -> dict: + return {} + + def bad_dispatch(state: FanOutState) -> list[Send]: + return [Send("ghost", {})] + + pipeline = ( + PipelineBuilder("bad", state=FanOutState) + .add_node(planner) + .add_node(worker) + .branch(planner, bad_dispatch) + .build() + ) + result = await pipeline.invoke(FanOutState()) + assert not result.success + assert "ghost" in (result.error or "") + + +# --- Mermaid + JSON export -------------------------------------------------- + + +def test_dag_to_mermaid_renders_topology(): + dag = DAG(name="example") + dag.add_node(DAGNode(node_id="a", step=CallableStep(_noop_async))) + dag.add_node(DAGNode(node_id="b", step=CallableStep(_noop_async))) + dag.add_edge(DAGEdge(source="a", target="b")) + out = dag.to_mermaid() + assert out.startswith("flowchart TD") + assert "a[a]" in out + assert "b[b]" in out + assert "a --> b" in out + + +def test_dag_to_json_round_trips_via_pydantic(): + dag = DAG(name="example") + dag.add_node(DAGNode(node_id="a", step=CallableStep(_noop_async))) + dag.add_node(DAGNode(node_id="b", step=CallableStep(_noop_async), failure_strategy=FailureStrategy.FAIL_PIPELINE)) + dag.add_edge(DAGEdge(source="a", target="b", input_key="payload")) + doc = json.loads(dag.to_json()) + assert doc["name"] == "example" + assert doc["nodes"] == ["a", "b"] + assert doc["edges"] == [{"source": "a", "target": "b", "output_key": "output", "input_key": "payload"}] + + +def test_state_pipeline_to_mermaid_labels_branch_edges(): + async def start(state: LoopState) -> dict: + return {} + + async def left(state: LoopState) -> dict: + return {} + + async def right(state: LoopState) -> dict: + return {} + + def route(state: LoopState) -> str: + return "left_path" + + pipeline = ( + PipelineBuilder("branched", state=LoopState) + .add_node(start) + .add_node(left) + .add_node(right) + .branch(start, route, {"left_path": left, "right_path": right}) + .build() + ) + assert isinstance(pipeline, StatePipeline) + mermaid = pipeline.to_mermaid() + assert "start -->|left_path| left" in mermaid + assert "start -->|right_path| right" in mermaid + + +# --- soft-deprecation ------------------------------------------------------ + + +def test_branch_step_emits_deprecation_warning(): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + BranchStep(router=lambda _: "x") + assert any(issubclass(w.category, DeprecationWarning) for w in caught) + assert any("branch(" in str(w.message) for w in caught) + + +def test_fan_out_step_emits_deprecation_warning(): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + FanOutStep(split_fn=lambda x: [x]) + assert any(issubclass(w.category, DeprecationWarning) for w in caught) + assert any("Send" in str(w.message) for w in caught) + + +# --- helpers --------------------------------------------------------------- + + +async def _noop_async(ctx, inputs): + return None