diff --git a/examples/software_factory/progress.py b/examples/software_factory/progress.py index 4f9492c9..f43892e6 100644 --- a/examples/software_factory/progress.py +++ b/examples/software_factory/progress.py @@ -5,7 +5,7 @@ """Console progress handler. -Implements (structurally) the framework's :class:`StatePipelineEventHandler` +Implements (structurally) the framework's :class:`EventHandler` Protocol. Prints one line per pipeline / node event so the QA loop and checkpoint+resume flow are visible when running the example by hand. """ diff --git a/fireflyframework_agentic/pipeline/__init__.py b/fireflyframework_agentic/pipeline/__init__.py index 4fa01bd3..c4f69fd0 100644 --- a/fireflyframework_agentic/pipeline/__init__.py +++ b/fireflyframework_agentic/pipeline/__init__.py @@ -46,19 +46,13 @@ from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode, FailureStrategy from fireflyframework_agentic.pipeline.engine import ( EventHandler, + Pause, PipelineEngine, PipelineEventHandler, - StatePipelineEventHandler, + Send, ) from fireflyframework_agentic.pipeline.reducers import append, extend, merge_dict, replace from fireflyframework_agentic.pipeline.result import ExecutionTraceEntry, NodeResult, PipelineResult -from fireflyframework_agentic.pipeline.state_pipeline import ( - Pause, - RecursionLimitError, - Send, - StatePipeline, - StatePipelineResult, -) from fireflyframework_agentic.pipeline.steps import ( AgentStep, BatchLLMStep, @@ -103,12 +97,8 @@ "PipelineResult", "QueryableAuditLog", "ReasoningStep", - "RecursionLimitError", "RetrievalStep", "Send", - "StatePipeline", - "StatePipelineEventHandler", - "StatePipelineResult", "StepExecutor", "append", "extend", diff --git a/fireflyframework_agentic/pipeline/builder.py b/fireflyframework_agentic/pipeline/builder.py index 512bd06d..d6a6805b 100644 --- a/fireflyframework_agentic/pipeline/builder.py +++ b/fireflyframework_agentic/pipeline/builder.py @@ -14,11 +14,10 @@ """Fluent builder API for constructing pipeline DAGs. -Two modes: +Two modes, both backed by the same :class:`PipelineEngine`: -1. **Port-based** (legacy, parallel-friendly): nodes are added by string id, - data flows over edge ports, executed by :class:`PipelineEngine`. Use this - for ETL-shaped DAGs with independent parallel steps:: +1. **Port-based** (parallel-friendly): nodes are added by string id, data + flows over edge ports:: pipeline = ( PipelineBuilder("idp") @@ -29,10 +28,10 @@ ) 2. **State-based**: configure ``state=SomeModel`` and nodes become - ``async (state) -> dict`` functions over a typed shared state. Branching - is one ``.branch(source, router)`` call. Function references can be used - as node ids. Optional checkpointing supports resume after failure and - mid-pipeline start. Produces a :class:`StatePipeline`:: + ``async (state) -> dict | None | Pause | Send | list[Send]``. Branching + is one ``.branch(source, router)`` call; function references work as + node ids. Optional checkpointing supports resume after failure and + mid-pipeline start:: pipeline = ( PipelineBuilder("agent", state=AgentState, checkpointer=FileCheckpointer("./ckpt")) @@ -46,8 +45,9 @@ from __future__ import annotations +import asyncio import inspect -from collections.abc import Callable +from collections.abc import Awaitable, Callable from typing import Any from pydantic import BaseModel @@ -55,18 +55,69 @@ from fireflyframework_agentic.exceptions import PipelineError from fireflyframework_agentic.pipeline.audit import AuditLog from fireflyframework_agentic.pipeline.checkpoint import Checkpointer +from fireflyframework_agentic.pipeline.context import PipelineContext from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode, FailureStrategy -from fireflyframework_agentic.pipeline.engine import PipelineEngine, StatePipelineEventHandler -from fireflyframework_agentic.pipeline.state_pipeline import ( - BranchSpec, +from fireflyframework_agentic.pipeline.engine import ( + EventHandler, + PipelineEngine, + PipelineEventHandler, RouterFn, Send, # noqa: F401 re-exported via pipeline/__init__.py - StateNodeFn, - StatePipeline, - coerce_state_node_fn, ) from fireflyframework_agentic.pipeline.steps import AgentStep, CallableStep, StepExecutor +StateNodeFn = Callable[[Any], Awaitable[Any]] +"""Signature for a state-mode node: ``async (state) -> dict | None | Pause | Send | list[Send]``.""" + + +class _StateStepAdapter: + """Adapts a state-mode node fn into the :class:`StepExecutor` shape so it + can ride through :meth:`PipelineEngine._execute_node`. + + State-mode functions take ``(state)`` and return a state update (or one + of the control sentinels :class:`Pause` / :class:`Send`). The engine + calls ``step.execute(context, inputs)``; this adapter forwards + ``context.state`` to the wrapped fn and returns its value verbatim so + PipelineEngine's existing dict/Pause/Send handling fires. + """ + + def __init__(self, fn: Callable[..., Any]) -> None: + self._fn = _coerce_state_node_fn(fn) + + async def execute(self, context: PipelineContext, inputs: dict[str, Any]) -> Any: # noqa: ARG002 + return await self._fn(context.state) + + +def _coerce_state_node_fn(fn: Callable[..., Any]) -> StateNodeFn: + """Turn user-supplied state-mode callables into the standard ``async (state) -> Any`` shape. + + Accepts: + * ``async def f(state)`` — used as-is. + * ``def f(state)`` — wrapped to run on a worker thread. + * Object with ``async run(state)`` (e.g. a FireflyAgent) — adapter calls ``.run(state)``. + """ + if inspect.iscoroutinefunction(fn): + return fn # type: ignore[return-value] + + run = getattr(fn, "run", None) + if not callable(fn) and run is not None and callable(run): + + async def _agent_wrap(state: Any) -> Any: + if inspect.iscoroutinefunction(run): + return await run(state) + return await asyncio.get_running_loop().run_in_executor(None, run, state) + + return _agent_wrap + + if callable(fn): + + async def _async_wrap(state: Any) -> Any: + return await asyncio.get_running_loop().run_in_executor(None, fn, state) + + return _async_wrap + + raise PipelineError(f"Cannot adapt {fn!r} as a state node function") + class PipelineBuilder: """Fluent builder for pipelines. @@ -74,10 +125,14 @@ class PipelineBuilder: Parameters: name: Human-readable name for the pipeline. state: Optional Pydantic model class for typed shared state. - When set, the builder produces a :class:`StatePipeline` and nodes - are expected to be ``async (state) -> dict | None``. - checkpointer: Optional :class:`Checkpointer` for state-based pipelines. - Ignored when ``state`` is not set. + When set, the builder produces a state-aware + :class:`PipelineEngine` and nodes are expected to be + ``async (state) -> dict | None | Pause | Send | list[Send]``. + checkpointer: Optional :class:`Checkpointer` for resume. + recursion_limit: Max visits per node in cycle-aware runs. + event_handler: Optional :class:`EventHandler` (or legacy + :class:`PipelineEventHandler`). + audit_log: Optional :class:`AuditLog`. """ def __init__( @@ -87,11 +142,10 @@ def __init__( state: type[BaseModel] | None = None, checkpointer: Checkpointer | None = None, recursion_limit: int = 25, - event_handler: StatePipelineEventHandler | None = None, + event_handler: EventHandler | PipelineEventHandler | None = None, audit_log: AuditLog | None = None, ) -> None: - # State pipelines may use cyclic graphs (ReAct loops, retry-with-critique). - # The legacy port-based path keeps acyclicity as an invariant. + # State-aware pipelines may have cycles (ReAct loops, retry-with-critique). self._dag = DAG(name=name, allow_cycles=state is not None) self._name = name self._state_schema = state @@ -101,9 +155,9 @@ def __init__( self._audit_log = audit_log self._pending_nodes: list[DAGNode] = [] self._pending_edges: list[DAGEdge] = [] - # State-based mode bookkeeping. Keyed by node id. - self._state_node_fns: dict[str, StateNodeFn] = {} - self._branches: dict[str, BranchSpec] = {} + # Routers + mappings drive the cyclic scheduler's next-step pick. + self._routers: dict[str, RouterFn] = {} + self._router_mappings: dict[str, dict[str, str]] = {} def add_node( self, @@ -122,11 +176,10 @@ def add_node( * ``add_node(fn)`` — state-based mode. ``fn`` is a callable; the node id is taken from ``fn.__name__``. Requires the builder was constructed with ``state=...``. - * ``add_node(node_id, step)`` — legacy port-based mode. ``step`` is a + * ``add_node(node_id, step)`` — port-based mode. ``step`` is a :class:`StepExecutor`, an agent-like, or an async callable. """ if step is None and callable(node_id_or_fn) and not isinstance(node_id_or_fn, str): - # State-based: derive id from function name. if self._state_schema is None: raise PipelineError( "Function-reference add_node(fn) requires PipelineBuilder(state=...). " @@ -134,11 +187,10 @@ def add_node( ) fn = node_id_or_fn node_id = getattr(fn, "__name__", None) or repr(fn) - self._state_node_fns[node_id] = coerce_state_node_fn(fn) self._pending_nodes.append( DAGNode( node_id=node_id, - step=_StateNodePlaceholder(), # never executed; engine path is unused for state pipelines + step=_StateStepAdapter(fn), condition=condition, retry_max=retry_max, timeout_seconds=timeout_seconds, @@ -152,19 +204,16 @@ def add_node( node_id = node_id_or_fn if self._state_schema is not None and step is not None: - # State-based pipeline: accept a callable, or an agent-like object - # exposing async ``run(state)``. ``coerce_state_node_fn`` handles both. run_method = getattr(step, "run", None) if not callable(step) and not callable(run_method): raise PipelineError( f"State pipeline node '{node_id}' must be a callable or expose async run(state); " f"got {type(step).__name__}" ) - self._state_node_fns[node_id] = coerce_state_node_fn(step) self._pending_nodes.append( DAGNode( node_id=node_id, - step=_StateNodePlaceholder(), + step=_StateStepAdapter(step), condition=condition, retry_max=retry_max, timeout_seconds=timeout_seconds, @@ -225,61 +274,50 @@ def branch( router: RouterFn, mapping: dict[str, str | Callable[..., Any]] | None = None, ) -> PipelineBuilder: - """Register a router on ``source``. + """Register a runtime router on ``source``. - ``router`` is a synchronous ``(state) -> str`` callable. Behaviour: + ``router`` is a synchronous ``(state) -> str | Send | list[Send]`` + callable. Behaviour: * If ``mapping`` is None, the router must return the **id of an existing node** that will run next. * If ``mapping`` is provided, the router returns an abstract label that is looked up in ``mapping`` to find the target node id. - State-based pipelines only. + State-aware pipelines only. """ if self._state_schema is None: raise PipelineError(".branch(...) requires PipelineBuilder(state=...)") source_id = _id(source) - resolved_mapping: dict[str, str] | None = None if mapping is not None: resolved_mapping = {label: _id(target) for label, target in mapping.items()} - # Materialize each label's edge into the DAG so topology is inspectable. + self._router_mappings[source_id] = resolved_mapping + # Materialize each label's edge so topology stays inspectable. for target_id in resolved_mapping.values(): self._pending_edges.append(DAGEdge(source=source_id, target=target_id)) - else: - # No mapping: we don't know targets at build time; edges will - # be missing from the DAG. That's fine for the StatePipeline - # executor (it consults the router), but visualisation will be - # incomplete. Materialize edges lazily when the router fires. - pass - self._branches[source_id] = BranchSpec(source=source_id, router=router, mapping=resolved_mapping) + self._routers[source_id] = router return self - def build(self) -> PipelineEngine | StatePipeline: - """Build the DAG and return either a :class:`PipelineEngine` - (legacy port-based) or :class:`StatePipeline` (when ``state=`` is set). - """ + def build(self) -> PipelineEngine: + """Build the DAG and return a :class:`PipelineEngine`.""" for node in self._pending_nodes: self._dag.add_node(node) for edge in self._pending_edges: self._dag.add_edge(edge) - if self._state_schema is not None: - return StatePipeline( - name=self._name, - dag=self._dag, - state_schema=self._state_schema, - node_fns=self._state_node_fns, - branches=self._branches, - checkpointer=self._checkpointer, - recursion_limit=self._recursion_limit, - event_handler=self._event_handler, - audit_log=self._audit_log, - ) - - return PipelineEngine(self._dag) + return PipelineEngine( + self._dag, + event_handler=self._event_handler, + checkpointer=self._checkpointer, + audit_log=self._audit_log, + state_schema=self._state_schema, + recursion_limit=self._recursion_limit, + routers=self._routers, + router_mappings=self._router_mappings, + ) def build_dag(self) -> DAG: - """Build and return just the :class:`DAG` (for inspection or custom engines).""" + """Build and return just the :class:`DAG` (for inspection).""" for node in self._pending_nodes: self._dag.add_node(node) for edge in self._pending_edges: @@ -309,11 +347,3 @@ def _id(ref: str | Callable[..., Any]) -> str: if not name: raise PipelineError(f"Cannot derive node id from {ref!r}") return name - - -class _StateNodePlaceholder: - """Sentinel step kept in the DAG so topology is intact. Never executed — - state pipelines bypass :class:`PipelineEngine` entirely.""" - - async def execute(self, *_args: Any, **_kwargs: Any) -> Any: - raise PipelineError("_StateNodePlaceholder.execute called — state pipelines should not use PipelineEngine.") diff --git a/fireflyframework_agentic/pipeline/checkpoint.py b/fireflyframework_agentic/pipeline/checkpoint.py index 765b8355..b3cf6b94 100644 --- a/fireflyframework_agentic/pipeline/checkpoint.py +++ b/fireflyframework_agentic/pipeline/checkpoint.py @@ -38,7 +38,7 @@ class CheckpointRecord(BaseModel): """One saved checkpoint. ``paused`` and ``pause_reason`` are set when a node returns - :class:`fireflyframework_agentic.pipeline.state_pipeline.Pause`. Default + :class:`fireflyframework_agentic.pipeline.engine.Pause`. Default to ``False`` / ``None`` so existing records from earlier phases load cleanly under the new schema. """ diff --git a/fireflyframework_agentic/pipeline/engine.py b/fireflyframework_agentic/pipeline/engine.py index 8317e5b1..de644ad8 100644 --- a/fireflyframework_agentic/pipeline/engine.py +++ b/fireflyframework_agentic/pipeline/engine.py @@ -23,9 +23,10 @@ import random import time import uuid +from collections.abc import Callable from dataclasses import dataclass from datetime import UTC, datetime -from typing import Any, Protocol, runtime_checkable +from typing import Any, Protocol, cast, runtime_checkable try: from opentelemetry import trace as otel_trace @@ -40,7 +41,7 @@ from fireflyframework_agentic.pipeline.audit import AuditEntry, AuditLog, AuditStatus from fireflyframework_agentic.pipeline.checkpoint import Checkpointer, CheckpointRecord from fireflyframework_agentic.pipeline.context import PipelineContext -from fireflyframework_agentic.pipeline.dag import DAG, FailureStrategy +from fireflyframework_agentic.pipeline.dag import DAG, FailureStrategy, _mermaid_id from fireflyframework_agentic.pipeline.reducers import Reducer, apply_update, discover_reducers from fireflyframework_agentic.pipeline.result import ( ExecutionTraceEntry, @@ -53,8 +54,7 @@ @runtime_checkable class EventHandler(Protocol): - """Unified pipeline event handler. Used by :class:`PipelineEngine` and - :class:`fireflyframework_agentic.pipeline.state_pipeline.StatePipeline`. + """Pipeline event handler used by :class:`PipelineEngine`. Implement any subset of these methods; missing ones are no-ops. Exceptions raised in callbacks are swallowed by the engine so observability never @@ -63,9 +63,8 @@ class EventHandler(Protocol): The engine dispatches events by parameter name. If your method signature omits a parameter — e.g. legacy implementations that don't accept ``run_id`` or ``visit`` — the engine simply drops it from the call. - That keeps legacy :class:`PipelineEventHandler` / - :class:`StatePipelineEventHandler` implementations working during the - transition to this unified shape. + That keeps the legacy :class:`PipelineEventHandler` shape working + transparently alongside this unified one. Parameter conventions: @@ -113,22 +112,11 @@ async def on_node_skip(self, node_id: str, pipeline_name: str, reason: str) -> N async def on_pipeline_complete(self, pipeline_name: str, success: bool, duration_ms: float) -> None: ... -@runtime_checkable -class StatePipelineEventHandler(Protocol): - """Legacy state-pipeline event handler protocol. Use :class:`EventHandler`. - - Same shape as :class:`EventHandler` minus ``on_node_skip``. Kept for - backward compatibility; new code should implement :class:`EventHandler`. - """ - - async def on_pipeline_start(self, pipeline_name: str, run_id: str) -> None: ... - async def on_node_start(self, pipeline_name: str, run_id: str, node_id: str, visit: int) -> None: ... - async def on_node_complete(self, pipeline_name: str, run_id: str, node_id: str, latency_ms: float) -> None: ... - async def on_node_error(self, pipeline_name: str, run_id: str, node_id: str, error: str) -> None: ... - async def on_node_pause(self, pipeline_name: str, run_id: str, node_id: str, reason: str) -> None: ... - async def on_pipeline_complete( - self, pipeline_name: str, run_id: str, success: bool, duration_ms: float - ) -> None: ... +RouterFn = Callable[[Any], "str | Send | list[Send]"] +"""Signature for a runtime branch router: receives the current state, returns +the next-step instruction — either a target node id, a single Send, or a +list of Sends for fan-out. +""" @dataclass @@ -222,7 +210,7 @@ def start_otel_span(name: str, **attributes: Any) -> Any: """Start an OTel span if observability is enabled, else return ``None``. Module-level helper shared by :class:`PipelineEngine` and - :class:`fireflyframework_agentic.pipeline.state_pipeline.StatePipeline`. + :class:`PipelineEngine`. """ try: if not get_config().observability_enabled: @@ -254,6 +242,8 @@ def __init__( audit_log: AuditLog | None = None, state_schema: type[BaseModel] | None = None, recursion_limit: int = 25, + routers: dict[str, RouterFn] | None = None, + router_mappings: dict[str, dict[str, str]] | None = None, ) -> None: self._dag = dag self._event_handler = event_handler @@ -266,9 +256,62 @@ def __init__( self._reducers: dict[str, Reducer] = discover_reducers(state_schema) if state_schema is not None else {} # Max visits per node for cycle-aware runs. Matches StatePipeline's default. self._recursion_limit = recursion_limit + # Optional runtime routers: source_id -> router(state) -> str | Send | list[Send]. + # When a source node has a router, the cyclic scheduler consults it + # instead of (or in addition to) the source's outgoing edges. With + # an accompanying mapping the router returns an abstract label that + # is looked up in the mapping. + self._routers: dict[str, RouterFn] = dict(routers or {}) + self._router_mappings: dict[str, dict[str, str]] = dict(router_mappings or {}) # Per-method signature cache for legacy-vs-unified dispatch. self._handler_params: dict[str, set[str]] = {} + def to_mermaid(self) -> str: + """Render the pipeline as a Mermaid flowchart, labelling branch edges. + + When the builder called ``.branch(source, router, mapping={...})`` + the resulting edges carry abstract labels (``yes``/``no``/etc). + This view threads those labels back into the diagram so the routing + is visible alongside the topology. + """ + lines = ["flowchart TD"] + for node_id in self._dag.nodes: + lines.append(f" {_mermaid_id(node_id)}[{node_id}]") + for edge in self._dag.edges: + label: str | None = None + mapping = self._router_mappings.get(edge.source) + if mapping: + for lbl, tgt in mapping.items(): + if tgt == edge.target: + label = lbl + break + if label is None and edge.condition is not None: + label = "if?" + arrow = f"-->|{label}|" if label else "-->" + lines.append(f" {_mermaid_id(edge.source)} {arrow} {_mermaid_id(edge.target)}") + return "\n".join(lines) + + async def invoke( + self, + state: Any = None, + *, + run_id: str | None = None, + start_at: Any = None, + approve_pause: bool = False, + ) -> PipelineResult: + """Shorthand for state-aware runs: ``await pipeline.invoke(state)``. + + Mirrors the legacy ``StatePipeline.invoke`` signature so callers that + treat the first positional as the state object keep working. New code + should call :meth:`run` directly with explicit kwargs. + """ + return await self.run( + state=state, + run_id=run_id, + start_at=start_at, + approve_pause=approve_pause, + ) + async def _dispatch(self, method_name: str, /, **kwargs: Any) -> None: """Invoke ``event_handler.method_name`` with the subset of ``kwargs`` the method's signature actually declares. @@ -333,12 +376,24 @@ async def run( """ if run_id is not None and context is None and inputs is None and state is None: resume_run_id: str = run_id - context, pre_completed, sequence_start = self._load_for_resume(resume_run_id, approve_pause=approve_pause) - all_results: dict[str, NodeResult] = { - nid: nr - for nid in pre_completed - if (nr := context.get_node_result(nid)) is not None and isinstance(nr, NodeResult) - } + context, pre_completed_list, sequence_start = self._load_for_resume( + resume_run_id, approve_pause=approve_pause + ) + pre_completed = set(pre_completed_list) + # Preserve original completion order so PipelineResult.completed_nodes + # reflects the run's actual sequence after resume. + all_results: dict[str, NodeResult] = {} + for nid in pre_completed_list: + nr = context.get_node_result(nid) + if isinstance(nr, NodeResult): + all_results[nid] = nr + # Synthesize trace entries for the pre-completed nodes so the + # resumed result's completed_nodes reflects the full history. + now = datetime.now(UTC) + resume_trace_seed: list[ExecutionTraceEntry] = [ + ExecutionTraceEntry(node_id=nid, started_at=now, completed_at=now, status="success") + for nid in pre_completed_list + ] else: if context is None: context = PipelineContext(inputs=inputs) @@ -358,6 +413,7 @@ async def run( pre_completed = set() sequence_start = 0 all_results = {} + resume_trace_seed: list[ExecutionTraceEntry] = [] # Mid-pipeline start: pretend everything not reachable from # `start_at` already ran. The scheduler then starts at start_at # because its upstream nodes appear "completed". @@ -372,17 +428,23 @@ async def run( run_id = uuid.uuid4().hex[:12] # Observability: pipeline-level span + start event + # State-aware runs use the "pipeline.state.*" span prefix to match + # the legacy StatePipeline taxonomy that observability dashboards + # already key on. + span_prefix = "pipeline.state" if self._state_schema is not None else "pipeline" _pipeline_span = self._start_otel_span( - f"pipeline.{self._dag.name}", + f"{span_prefix}.{self._dag.name}", pipeline=self._dag.name, + run_id=run_id, ) await self._dispatch("on_pipeline_start", pipeline_name=self._dag.name, run_id=run_id) # Cycle-aware mode: a separate sequential frontier-following scheduler # that respects ``recursion_limit``. The topological scheduler below # cannot run cyclic graphs because execution_levels()/topological_sort() - # are undefined on them. - if self._dag.is_cyclic(): + # are undefined on them. Runtime routers also force this mode because + # they make routing a function of state, not topology. + if self._dag.is_cyclic() or self._routers: return await self._run_cyclic( context=context, run_id=run_id, @@ -390,13 +452,14 @@ async def run( pre_completed=pre_completed, sequence_start=sequence_start, pipeline_span=_pipeline_span, + resume_trace_seed=resume_trace_seed, ) # Topological levels ensure that all upstream dependencies of a node # complete before the node itself executes. Nodes within the same # level are independent and run concurrently via asyncio.gather. levels = self._dag.execution_levels() - trace_entries: list[ExecutionTraceEntry] = [] + trace_entries: list[ExecutionTraceEntry] = list(resume_trace_seed) pipeline_start = time.perf_counter() failed_nodes: set[str] = set() @@ -489,6 +552,15 @@ async def _record_skip(nid: str) -> None: # them for the audit snapshot. gathered = self._gather_inputs(nid, context) inputs_by_node[nid] = gathered + # Emit start event here (visit=1 in the acyclic scheduler; + # the cyclic scheduler and Send fan-out emit their own). + await self._dispatch( + "on_node_start", + pipeline_name=self._dag.name, + run_id=run_id, + node_id=nid, + visit=1, + ) task = asyncio.create_task( self._execute_node( nid, @@ -535,6 +607,7 @@ async def _record_skip(nid: str) -> None: # Persist lifecycle: audit every executed visit; checkpoint only # successful completions (failed nodes must re-run on resume). sequence += 1 + paused_now = nr.success and isinstance(nr.output, Pause) self._record_audit( run_id=run_id, node_id=node_id, @@ -542,10 +615,12 @@ async def _record_skip(nid: str) -> None: nr=nr, inputs_snapshot=inputs_by_node.get(node_id, {}), trace_entries=trace_entries, + status_override="paused" if paused_now else None, + pause_reason=nr.output.reason if paused_now else None, ) # HITL: a node returned Pause(reason=...). Halt cleanly, save # a paused checkpoint, and surface the pause in the result. - if nr.success and isinstance(nr.output, Pause): + if paused_now: pause_reason = nr.output.reason await self._dispatch( "on_node_pause", @@ -569,7 +644,10 @@ async def _record_skip(nid: str) -> None: # Runtime fan-out: a node returned Send / list[Send]. if nr.success and _is_send_payload(nr.output): - sends = nr.output if isinstance(nr.output, list) else [nr.output] + sends: list[Send] = cast( + "list[Send]", + list(nr.output) if isinstance(nr.output, list) else [nr.output], + ) ok = await self._run_sends( sends=sends, context=context, @@ -680,6 +758,7 @@ async def _run_sends( trace_entries: list[ExecutionTraceEntry], completed: set[str], pending: set[str], + visit_counts: dict[str, int] | None = None, ) -> bool: """Dispatch a list of :class:`Send` workers concurrently. @@ -702,13 +781,13 @@ async def _run_sends( all_results[send.target] = nr return False - async def _run_one(send: Send) -> tuple[Send, NodeResult]: + async def _run_one(send: Send, visit_n: int) -> tuple[Send, NodeResult]: await self._dispatch( "on_node_start", pipeline_name=self._dag.name, run_id=run_id, node_id=send.target, - visit=1, + visit=visit_n, ) # Per-worker context: own state copy with payload applied so # workers don't race on the shared state object. @@ -727,8 +806,19 @@ async def _run_one(send: Send) -> tuple[Send, NodeResult]: ) return send, nr + # Per-Send visit numbers: increment per dispatched target. The + # caller may seed counts (cyclic scheduler tracks them globally); + # otherwise each fan-out batch starts at 1. + send_visits: list[int] = [] + running_counts = dict(visit_counts) if visit_counts else {} + for send in sends: + running_counts[send.target] = running_counts.get(send.target, 0) + 1 + send_visits.append(running_counts[send.target]) + if visit_counts is not None: + visit_counts.update(running_counts) + try: - results = await asyncio.gather(*(_run_one(s) for s in sends)) + results = await asyncio.gather(*(_run_one(s, v) for s, v in zip(sends, send_visits, strict=True))) except Exception as exc: logger.exception("Fan-out worker crashed") for send in sends: @@ -759,6 +849,7 @@ async def _run_cyclic( pre_completed: set[str], sequence_start: int, pipeline_span: Any, + resume_trace_seed: list[ExecutionTraceEntry] | None = None, ) -> PipelineResult: """Sequential frontier-following scheduler for cyclic DAGs. @@ -768,7 +859,7 @@ async def _run_cyclic( having multiple alive outgoing edges is currently an error — parallel cyclic fan-out is the job of :class:`Send` in a later layer. """ - trace_entries: list[ExecutionTraceEntry] = [] + trace_entries: list[ExecutionTraceEntry] = list(resume_trace_seed or []) pipeline_start = time.perf_counter() visit_counts: dict[str, int] = dict.fromkeys(pre_completed, 1) sequence = sequence_start @@ -777,14 +868,53 @@ async def _run_cyclic( nodes_in_order = list(self._dag.nodes) if not nodes_in_order: raise PipelineError("Pipeline has no nodes") - current: str | None = nodes_in_order[0] + next_step: str | list[Send] | None = nodes_in_order[0] # Skip past anything already completed during this resumed run. - while current is not None and current in pre_completed: - current = self._cyclic_next(current, context) + while isinstance(next_step, str) and next_step in pre_completed: + next_step = self._cyclic_next(next_step, context) - result: PipelineResult | None = None + pending_pause: tuple[str, str] | None = None try: - while current is not None: + while next_step is not None: + # --- Fan-out (list[Send]) --------------------------------- + if isinstance(next_step, list): + sends = next_step + # Preview the per-target visit numbers to enforce the + # recursion limit; _run_sends does the real increment. + over_limit: str | None = None + preview = dict(visit_counts) + for send in sends: + preview[send.target] = preview.get(send.target, 0) + 1 + if preview[send.target] > self._recursion_limit: + over_limit = send.target + break + if over_limit is not None: + msg = ( + f"Recursion limit ({self._recursion_limit}) exceeded at node '{over_limit}' during fan-out." + ) + logger.error(msg) + all_results[over_limit] = NodeResult(node_id=over_limit, success=False, error=msg) + break + completed_set: set[str] = set(all_results) + pending_set: set[str] = set() + ok = await self._run_sends( + sends=sends, + context=context, + run_id=run_id, + all_results=all_results, + trace_entries=trace_entries, + completed=completed_set, + pending=pending_set, + visit_counts=visit_counts, + ) + if not ok: + break + # Continue from the common successor of all workers, if any. + next_step = self._common_successor([s.target for s in sends]) + continue + + # --- Single-node step ------------------------------------- + current = next_step visit_counts[current] = visit_counts.get(current, 0) + 1 visit_n = visit_counts[current] if visit_n > self._recursion_limit: @@ -793,16 +923,17 @@ async def _run_cyclic( f"'{current}'. Raise recursion_limit= or fix the routing logic." ) logger.error(msg) - nr = NodeResult(node_id=current, success=False, error=msg) - all_results[current] = nr + nr_over = NodeResult(node_id=current, success=False, error=msg) + all_results[current] = nr_over sequence += 1 self._record_audit( run_id=run_id, node_id=current, sequence=sequence, - nr=nr, + nr=nr_over, inputs_snapshot={}, trace_entries=trace_entries, + visit=visit_n, ) break @@ -827,6 +958,7 @@ async def _run_cyclic( await self._emit_node_result(nr, run_id) sequence += 1 + paused_now = nr.success and isinstance(nr.output, Pause) self._record_audit( run_id=run_id, node_id=current, @@ -835,11 +967,43 @@ async def _run_cyclic( inputs_snapshot=gathered, trace_entries=trace_entries, visit=visit_n, + status_override="paused" if paused_now else None, + pause_reason=nr.output.reason if paused_now else None, ) if not nr.success and not nr.skipped: break + # HITL: node returned Pause — checkpoint paused and halt. + if paused_now: + pause_reason = nr.output.reason + await self._dispatch( + "on_node_pause", + pipeline_name=self._dag.name, + run_id=run_id, + node_id=current, + reason=pause_reason, + ) + self._save_checkpoint( + run_id=run_id, + node_id=current, + sequence=sequence, + context=context, + all_results=all_results, + paused=True, + pause_reason=pause_reason, + ) + pending_pause = (current, pause_reason) + break + + # Fan-out: node returned Send / list[Send]. + if nr.success and _is_send_payload(nr.output): + next_step = cast( + "list[Send]", + list(nr.output) if isinstance(nr.output, list) else [nr.output], + ) + continue + if not nr.skipped: if self._state_schema is not None and context.state is not None and isinstance(nr.output, dict): context.state = apply_update(context.state, nr.output, self._reducers) @@ -851,12 +1015,14 @@ async def _run_cyclic( all_results=all_results, ) - current = self._cyclic_next(current, context) + try: + next_step = self._cyclic_next(current, context) + except PipelineError as exc: + all_results[current] = NodeResult(node_id=current, success=False, error=str(exc), output=nr.output) + break finally: elapsed = (time.perf_counter() - pipeline_start) * 1000 - success = ( - result.success if result is not None else all(r.success or r.skipped for r in all_results.values()) - ) + success = False if pending_pause is not None else all(r.success or r.skipped for r in all_results.values()) await self._dispatch( "on_pipeline_complete", pipeline_name=self._dag.name, @@ -868,26 +1034,43 @@ async def _run_cyclic( with contextlib.suppress(Exception): pipeline_span.end() - final_output = None + paused_node = pending_pause[0] if pending_pause else None + pause_reason_final = pending_pause[1] if pending_pause else None return PipelineResult( pipeline_name=self._dag.name, outputs=all_results, - final_output=final_output, + final_output=None, execution_trace=trace_entries, total_duration_ms=elapsed, success=success, usage=None, run_id=run_id, final_state=context.state, + paused=pending_pause is not None, + paused_node=paused_node, + pause_reason=pause_reason_final, ) - def _cyclic_next(self, current: str, context: PipelineContext) -> str | None: - """Pick the next node by following the unique alive outgoing edge. + def _common_successor(self, node_ids: list[str]) -> str | None: + """Return the node all ``node_ids`` share as their unique successor, or None.""" + successors = [self._dag.successors(nid) for nid in node_ids] + if not successors or any(len(s) != 1 for s in successors): + return None + first = successors[0][0] + return first if all(s[0] == first for s in successors[1:]) else None + + def _cyclic_next(self, current: str, context: PipelineContext) -> str | list[Send] | None: + """Pick what runs next from ``current``. - Returns ``None`` when no outgoing edge is alive (terminus). Raises - if more than one is alive — parallel cyclic fan-out lands with - :class:`Send` in a later layer. + Priority: a registered router (.branch(...)) wins. Its return value + (str, Send, list[Send], or None) is resolved to a concrete target. + Otherwise fall back to the unique alive outgoing edge; multiple + alive edges with no router raise. """ + # Runtime router takes precedence. + if current in self._routers: + decision = self._routers[current](context.state) + return self._resolve_router_decision(current, decision) def _alive(edge: Any) -> bool: if edge.condition is None: @@ -904,12 +1087,57 @@ def _alive(edge: Any) -> bool: if len(alive) > 1: raise PipelineError( f"Cyclic node '{current}' has multiple alive outgoing edges " - f"({[e.target for e in alive]}). Parallel cyclic fan-out arrives " - f"with Send in a later layer; for now, use mutually exclusive " - f"edge conditions." + f"({[e.target for e in alive]}). Register a .branch(...) " + f"router or make the edge conditions mutually exclusive." ) return alive[0].target + def _resolve_router_decision(self, source: str, decision: Any) -> str | list[Send] | None: + """Translate a router's return value into a concrete next-step. + + Accepts: + * a string node id (looked up in ``router_mappings`` if registered), + * a single :class:`Send` (wrapped into a one-element list), + * a ``list[Send]`` (returned as-is after validation), + * ``None`` or an empty list (terminus). + """ + if decision is None: + return None + if isinstance(decision, list): + if not decision: + return None + for s in decision: + if not isinstance(s, Send): + raise PipelineError( + f"Router for '{source}' returned a list containing non-Send element {s!r}; expected list[Send]." + ) + if s.target not in self._dag.nodes: + raise PipelineError(f"Router for '{source}' fans out to unknown target '{s.target}'") + return decision + if isinstance(decision, Send): + if decision.target not in self._dag.nodes: + raise PipelineError(f"Router for '{source}' dispatched to unknown target '{decision.target}'") + return [decision] + if isinstance(decision, str): + mapping = self._router_mappings.get(source) + if mapping is not None: + if decision not in mapping: + raise PipelineError( + f"Router for '{source}' returned label '{decision}' not in mapping {list(mapping)}" + ) + return mapping[decision] + if decision not in self._dag.nodes: + raise PipelineError( + f"Router for '{source}' returned '{decision}' " + f"which is not a registered node id; pass an explicit " + f"mapping if you want abstract labels." + ) + return decision + raise PipelineError( + f"Router for '{source}' returned unsupported type {type(decision).__name__}; " + f"expected str, Send, list[Send], or None." + ) + async def _execute_node( self, node_id: str, @@ -947,20 +1175,17 @@ async def _execute_node( if inputs is None: inputs = self._gather_inputs(node_id, context) + node_prefix = "pipeline.state.node" if self._state_schema is not None else "pipeline.node" _node_span = self._start_otel_span( - f"pipeline.node.{node_id}", + f"{node_prefix}.{node_id}", node=node_id, - ) - - # Emit node start event (visit=1; cycles arrive in a later layer) - await self._dispatch( - "on_node_start", - pipeline_name=self._dag.name, - run_id=run_id, - node_id=node_id, visit=1, ) + # Note: on_node_start is now emitted by the caller (run / _run_cyclic / + # _run_sends) so the cyclic and fan-out paths can supply the right + # ``visit`` number. Emitting here would duplicate the event. + max_retries = node.retry_max backoff_factor = node.backoff_factor retries = 0 @@ -1067,7 +1292,7 @@ async def _emit_node_result(self, nr: NodeResult, run_id: str) -> None: else: await self._dispatch("on_node_error", error=nr.error or "unknown", **common) - def _load_for_resume(self, run_id: str, *, approve_pause: bool = False) -> tuple[PipelineContext, set[str], int]: + def _load_for_resume(self, run_id: str, *, approve_pause: bool = False) -> tuple[PipelineContext, list[str], int]: """Rebuild context + completed-set from the latest checkpoint. Resuming a paused run (checkpoint.paused=True) requires @@ -1097,7 +1322,7 @@ def _load_for_resume(self, run_id: str, *, approve_pause: bool = False) -> tuple context.state = self._state_schema.model_validate(saved_state) except Exception: logger.warning("Could not restore shared state on resume for run '%s'", run_id) - return context, set(record.completed_nodes), record.sequence + return context, list(record.completed_nodes), record.sequence def _save_checkpoint( self, @@ -1149,12 +1374,15 @@ def _record_audit( inputs_snapshot: dict[str, Any], trace_entries: list[ExecutionTraceEntry], visit: int = 1, + status_override: AuditStatus | None = None, + pause_reason: str | None = None, ) -> None: """Write an audit entry for a node visit. No-op if no audit log. Skipped nodes are not recorded — they represent work that did NOT - happen and would clutter the trail. ``visit`` defaults to 1 for the - acyclic scheduler and is supplied by the cyclic scheduler. + happen and would clutter the trail. ``status_override`` lets the + cyclic scheduler tag a Pause-returning node with ``"paused"`` + instead of the default ``"success"`` derived from ``nr.success``. """ if self._audit_log is None or nr.skipped: return @@ -1165,7 +1393,10 @@ def _record_audit( started_at = te.started_at completed_at = te.completed_at break - status: AuditStatus = "success" if nr.success else "error" + if status_override is not None: + status: AuditStatus = status_override + else: + status = "success" if nr.success else "error" outputs: dict[str, Any] = {"output": _serialize_value(nr.output)} if nr.success else {} entry = AuditEntry( pipeline_name=self._dag.name, @@ -1180,6 +1411,7 @@ def _record_audit( inputs_snapshot={k: _serialize_value(v) for k, v in inputs_snapshot.items()}, outputs_snapshot=outputs, error_message=nr.error if not nr.success else None, + pause_reason=pause_reason, ) try: self._audit_log.record(entry) diff --git a/fireflyframework_agentic/pipeline/result.py b/fireflyframework_agentic/pipeline/result.py index 1f6bc260..ade1ccf8 100644 --- a/fireflyframework_agentic/pipeline/result.py +++ b/fireflyframework_agentic/pipeline/result.py @@ -91,3 +91,36 @@ class PipelineResult(BaseModel): @property def failed_nodes(self) -> list[str]: return [nid for nid, r in self.outputs.items() if not r.success and not r.skipped] + + # -- State-mode convenience aliases --------------------------------- + + @property + def state(self) -> Any: + """Final shared state. Alias of :attr:`final_state` for state-aware + pipelines built via ``PipelineBuilder(state=...)``.""" + return self.final_state + + @property + def completed_nodes(self) -> list[str]: + """IDs of every successful node visit, in completion order. + + Derived from :attr:`execution_trace` so each cyclic re-entry of a + node appears as its own entry (matches StatePipeline's semantics). + """ + return [e.node_id for e in self.execution_trace if e.status == "success"] + + @property + def failed_node(self) -> str | None: + """First node that failed, if any. ``None`` when the run succeeded.""" + for nid, r in self.outputs.items(): + if not r.success and not r.skipped: + return nid + return None + + @property + def error(self) -> str | None: + """Error message from the first failed node, if any.""" + for r in self.outputs.values(): + if not r.success and not r.skipped and r.error: + return r.error + return None diff --git a/fireflyframework_agentic/pipeline/state_pipeline.py b/fireflyframework_agentic/pipeline/state_pipeline.py deleted file mode 100644 index 07a3b1ea..00000000 --- a/fireflyframework_agentic/pipeline/state_pipeline.py +++ /dev/null @@ -1,751 +0,0 @@ -# Copyright 2026 Firefly Software Foundation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""State-based pipeline: a sequential executor over a typed shared-state object. - -Layered on top of :class:`DAG` for topology, but uses its own simple executor -rather than :class:`PipelineEngine`. The trade-off: no within-level parallelism, -but in exchange we get clean semantics for typed state, reducers, branching, -checkpointing, and mid-pipeline resume — which are the things this API exists -to provide. Port-based parallel DAGs continue to use :class:`PipelineEngine`. -""" - -from __future__ import annotations - -import asyncio -import contextlib -import inspect -import logging -import time -import uuid -import warnings -from collections.abc import Awaitable, Callable -from dataclasses import dataclass -from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any - -from pydantic import BaseModel - -from fireflyframework_agentic.exceptions import PipelineError -from fireflyframework_agentic.pipeline.audit import AuditEntry, AuditLog, AuditStatus -from fireflyframework_agentic.pipeline.checkpoint import Checkpointer, CheckpointRecord -from fireflyframework_agentic.pipeline.dag import DAG, _mermaid_id -from fireflyframework_agentic.pipeline.engine import Pause, Send, start_otel_span -from fireflyframework_agentic.pipeline.reducers import apply_update, discover_reducers - -if TYPE_CHECKING: - from fireflyframework_agentic.pipeline.engine import StatePipelineEventHandler - -logger = logging.getLogger(__name__) - -StateNodeFn = Callable[[Any], Awaitable[dict[str, Any] | None]] -# A router may return: a node id (str), a Send, or a list[Send] for fan-out. -RouterFn = Callable[[Any], "str | Send | list[Send]"] - - -class RecursionLimitError(Exception): - """Raised when a node is visited more times than ``recursion_limit`` permits.""" - - -@dataclass -class BranchSpec: - """Internal: registered branch from one source node.""" - - source: str - router: RouterFn - mapping: dict[str, str] | None # label -> target node_id. None = router returns target directly. - - -@dataclass -class StatePipelineResult: - """Outcome of a single ``invoke`` call. - - Attributes: - state: Final state object. - run_id: ID of this run (use to resume later via ``invoke(run_id=...)``). - completed_nodes: Node IDs that ran successfully this invocation, in order. - success: True iff all attempted nodes completed without error. - error: Last error message if ``success`` is False. - failed_node: Node ID that failed, if any. - paused: True if the run halted on a :class:`Pause` sentinel; resume - via ``invoke(run_id=..., approve_pause=True)``. - paused_node: Node that returned ``Pause`` if ``paused`` is True. - pause_reason: Reason string the paused node passed to ``Pause(...)``. - """ - - state: Any - run_id: str - completed_nodes: list[str] - success: bool - error: str | None = None - failed_node: str | None = None - paused: bool = False - paused_node: str | None = None - pause_reason: str | None = None - - -class StatePipeline: - """Compiled state-based pipeline. Returned by ``PipelineBuilder.build()`` - when a ``state=`` schema is configured. - - .. deprecated:: - :class:`StatePipeline` is being subsumed by - :class:`fireflyframework_agentic.pipeline.engine.PipelineEngine` - configured with ``state_schema=``. The unified engine supports - the same features (state overlay, reducers, Pause, Send, cycles, - recursion_limit, checkpointing, audit, resume, start_at) and adds - true parallelism for state-aware pipelines via the topological - scheduler. New code should prefer ``PipelineEngine`` directly: - - .. code-block:: python - - engine = PipelineEngine( - dag, - state_schema=MyState, - checkpointer=cp, - audit_log=al, - recursion_limit=10, - ) - result = await engine.run(state=MyState(...)) - - See issue #245 for the full migration plan. The next layer of - unification removes :class:`StatePipeline` after a deprecation - cycle. - """ - - def __init__( - self, - *, - name: str, - dag: DAG, - state_schema: type[BaseModel], - node_fns: dict[str, StateNodeFn], - branches: dict[str, BranchSpec], - checkpointer: Checkpointer | None = None, - recursion_limit: int = 25, - event_handler: StatePipelineEventHandler | None = None, - audit_log: AuditLog | None = None, - ) -> None: - warnings.warn( - "StatePipeline is deprecated; use PipelineEngine(state_schema=...) " - "for the unified API. The unified engine supports the same features " - "and adds parallel state-aware execution. See issue #245.", - DeprecationWarning, - stacklevel=2, - ) - self._name = name - self._dag = dag - self._state_schema = state_schema - self._node_fns = node_fns - self._branches = branches - self._checkpointer = checkpointer - self._recursion_limit = recursion_limit - self._event_handler = event_handler - self._audit_log = audit_log - self._reducers = discover_reducers(state_schema) - self._validate() - - def _audit( - self, - *, - run_id: str, - node_id: str, - sequence: int, - visit: int, - started_at: datetime, - completed_at: datetime, - latency_ms: float, - status: AuditStatus, - inputs_snapshot: dict[str, Any], - outputs_snapshot: dict[str, Any], - error_message: str | None = None, - pause_reason: str | None = None, - ) -> None: - """Construct and write an :class:`AuditEntry`. No-op if no audit log is configured. - - Audit-write failures are non-fatal — logged and swallowed. - """ - if self._audit_log is None: - return - entry = AuditEntry( - pipeline_name=self._name, - run_id=run_id, - node_id=node_id, - sequence=sequence, - visit=visit, - started_at=started_at, - completed_at=completed_at, - latency_ms=latency_ms, - status=status, - inputs_snapshot=inputs_snapshot, - outputs_snapshot=outputs_snapshot, - error_message=error_message, - pause_reason=pause_reason, - ) - try: - self._audit_log.record(entry) - except Exception: - logger.exception("Audit log write failed for run '%s' at '%s'", run_id, node_id) - - async def _emit(self, method: str, *args: Any) -> None: - """Invoke ``method`` on the configured event handler if it exists. - - Missing methods are no-ops; raised exceptions are swallowed so - observability never breaks business logic. - """ - if self._event_handler is None: - return - fn = getattr(self._event_handler, method, None) - if fn is None: - return - with contextlib.suppress(Exception): - await fn(*args) - - @property - def name(self) -> str: - return self._name - - @property - def dag(self) -> DAG: - return self._dag - - def to_mermaid(self) -> str: - """Render the pipeline as a Mermaid flowchart, including branch edges. - - Branches that omit an explicit mapping are rendered as a dashed edge - labelled ``router`` because the targets are decided at runtime. - """ - lines = ["flowchart TD"] - for node_id in self._dag.nodes: - lines.append(f" {_mermaid_id(node_id)}[{node_id}]") - # Explicit edges (including branch mappings, which were materialized). - for edge in self._dag.edges: - label = None - spec = self._branches.get(edge.source) - if spec and spec.mapping: - for lbl, tgt in spec.mapping.items(): - if tgt == edge.target: - label = lbl - break - arrow = f"-->|{label}|" if label else "-->" - lines.append(f" {_mermaid_id(edge.source)} {arrow} {_mermaid_id(edge.target)}") - # Dynamic branches (no mapping): show as a dashed self-edge stub. - for source, spec in self._branches.items(): - if spec.mapping is None and not self._dag.successors(source): - lines.append(f" {_mermaid_id(source)} -.->|router| {_mermaid_id(source)}_router((dynamic))") - return "\n".join(lines) - - def _validate(self) -> None: - # Every node must have a registered fn. - for node_id in self._dag.nodes: - if node_id not in self._node_fns: - raise PipelineError(f"Node '{node_id}' has no registered function") - # Every branch source/target must exist. - for source, spec in self._branches.items(): - if source not in self._dag.nodes: - raise PipelineError(f"Branch source '{source}' not in DAG") - if spec.mapping: - for label, target in spec.mapping.items(): - if target not in self._dag.nodes: - raise PipelineError(f"Branch target '{target}' (label '{label}') not in DAG") - - def _entry_node(self) -> str: - """Default entry: the first node added. - - Override with ``invoke(state, start_at=...)``. Picking insertion-order - rather than the topological root keeps things predictable in the - common case where a ``.branch(...)`` without an explicit mapping - leaves multiple nodes with no inbound edges. - """ - order = list(self._dag.nodes) - if not order: - raise PipelineError("Pipeline has no nodes") - return order[0] - - def _next_step(self, current: str, state: BaseModel) -> str | list[Send] | None: - """Decide what runs next given the current state. - - Returns: - * A node id (str) for a single deterministic step. - * A list of :class:`Send` for runtime fan-out — workers run concurrently. - * ``None`` when the pipeline reaches a terminus. - """ - if current in self._branches: - decision = self._branches[current].router(state) - return self._resolve_router_decision(current, decision) - - successors = self._dag.successors(current) - if not successors: - return None - if len(successors) > 1: - raise PipelineError( - f"Node '{current}' has multiple successors {successors} but no .branch(...) registered. " - f"Register a branch router or remove the extra edges." - ) - return successors[0] - - def _resolve_router_decision(self, current: str, decision: str | Send | list[Send]) -> str | list[Send] | None: - """Translate a router's return value into a concrete next-step instruction.""" - # Fan-out: list of Send dispatches. - if isinstance(decision, list): - if not decision: - return None - for s in decision: - if not isinstance(s, Send): - raise PipelineError( - f"Router for '{current}' returned a list containing non-Send " - f"element {s!r}; expected list[Send]." - ) - if s.target not in self._dag.nodes: - raise PipelineError(f"Router for '{current}' fans out to unknown target '{s.target}'") - return decision - - if isinstance(decision, Send): - if decision.target not in self._dag.nodes: - raise PipelineError(f"Router for '{current}' dispatched to unknown target '{decision.target}'") - return [decision] - - # String label. - spec = self._branches[current] - if spec.mapping is not None: - if decision not in spec.mapping: - raise PipelineError( - f"Router for '{current}' returned label '{decision}' not in mapping {list(spec.mapping)}" - ) - return spec.mapping[decision] - if decision not in self._dag.nodes: - raise PipelineError( - f"Router for '{current}' returned '{decision}' " - f"which is not a registered node id; pass an explicit mapping if you want labels." - ) - return decision - - def _common_successor(self, node_ids: list[str]) -> str | None: - """Return the node all ``node_ids`` share as their unique successor, or None.""" - successors = [self._dag.successors(nid) for nid in node_ids] - if not successors or any(len(s) != 1 for s in successors): - return None - first = successors[0][0] - return first if all(s[0] == first for s in successors[1:]) else None - - async def invoke( - self, - state: BaseModel | None = None, - *, - run_id: str | None = None, - start_at: str | Callable[..., Any] | None = None, - approve_pause: bool = False, - ) -> StatePipelineResult: - """Run the pipeline. - - Modes: - * Fresh run: ``invoke(state)`` — generates a new ``run_id``. - * Resume: ``invoke(run_id="abc")`` — loads latest checkpoint and continues. - * Mid-pipeline start: ``invoke(state=..., start_at=node)`` — - starts execution at ``node`` with the provided state. - """ - completed: list[str] = [] - - # Resume mode: load checkpoint, derive starting node from it. - if run_id is not None and state is None and start_at is None: - if self._checkpointer is None: - raise PipelineError("Cannot resume: pipeline has no checkpointer") - record = self._checkpointer.load_latest(self._name, run_id) - if record is None: - raise PipelineError(f"No checkpoint found for run_id='{run_id}'") - # A paused run requires explicit approval before continuing. - if record.paused and not approve_pause: - raise PipelineError( - f"Run '{run_id}' is paused at node '{record.node_id}' " - f"(reason: {record.pause_reason!r}). Pass approve_pause=True to resume." - ) - state = self._state_schema.model_validate(record.state) - completed = list(record.completed_nodes) - # Resume at the successor of the last completed (or paused) node. - next_node = self._next_step(record.node_id, state) - # Resume can't seamlessly continue mid-fan-out yet; treat fan-out as terminal here. - if isinstance(next_node, list): - raise PipelineError( - "Resume across a fan-out (Send) is not supported in Phase 2; " - "the run finished by reaching a fan-out node." - ) - if next_node is None: - return StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=True, - ) - current_node: str | None = next_node - else: - if state is None: - raise PipelineError("invoke() requires a state argument (or a run_id to resume)") - if not isinstance(state, self._state_schema): - # Be helpful if caller passed a dict or a different model. - try: - state = self._state_schema.model_validate(state) - except Exception as exc: - raise PipelineError(f"state argument is not a {self._state_schema.__name__}: {exc}") from exc - if start_at is not None: - current_node = _resolve_node_id(start_at) - if current_node not in self._dag.nodes: - raise PipelineError(f"start_at='{current_node}' not in DAG") - else: - current_node = self._entry_node() - - if run_id is None: - run_id = uuid.uuid4().hex[:12] - - assert state is not None # narrowed by the branches above - visit_counts: dict[str, int] = {} - - next_step: str | list[Send] | None = current_node - - # ---- pipeline-level observability boundary ------------------------- - pipeline_start_time = time.perf_counter() - pipeline_span = start_otel_span(f"pipeline.state.{self._name}", pipeline=self._name, run_id=run_id) - await self._emit("on_pipeline_start", self._name, run_id) - - result: StatePipelineResult | None = None - try: - while next_step is not None: - # --- fan-out branch (list[Send]) --------------------------------- - if isinstance(next_step, list): - try: - state = await self._run_fanout( - sends=next_step, - state=state, - completed=completed, - run_id=run_id, - visit_counts=visit_counts, - ) - except _NodeFailureError as fail: - result = StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=False, - error=fail.message, - failed_node=fail.node_id, - ) - break - # After fan-out, continue from the workers' shared successor (if any). - next_step = self._common_successor([s.target for s in next_step]) - continue - - # --- single-node step -------------------------------------------- - node_id = next_step - visit_counts[node_id] = visit_counts.get(node_id, 0) + 1 - visit_n = visit_counts[node_id] - if visit_n > self._recursion_limit: - msg = ( - f"Recursion limit ({self._recursion_limit}) exceeded at node '{node_id}'. " - f"Raise recursion_limit= or fix the routing logic." - ) - logger.error(msg) - result = StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=False, - error=msg, - failed_node=node_id, - ) - break - - fn = self._node_fns[node_id] - node_span = start_otel_span(f"pipeline.state.node.{node_id}", node=node_id, visit=visit_n) - await self._emit("on_node_start", self._name, run_id, node_id, visit_n) - inputs_snapshot = state.model_dump(mode="json") - started_at = datetime.now(UTC) - t0 = time.perf_counter() - try: - update = await fn(state) - except Exception as exc: - logger.exception( - "State pipeline '%s' run '%s' failed at node '%s'", - self._name, - run_id, - node_id, - ) - await self._emit("on_node_error", self._name, run_id, node_id, str(exc)) - if node_span is not None: - with contextlib.suppress(Exception): - node_span.end() - self._audit( - run_id=run_id, - node_id=node_id, - sequence=len(completed) + 1, - visit=visit_n, - started_at=started_at, - completed_at=datetime.now(UTC), - latency_ms=(time.perf_counter() - t0) * 1000, - status="error", - inputs_snapshot=inputs_snapshot, - outputs_snapshot={}, - error_message=str(exc), - ) - result = StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=False, - error=str(exc), - failed_node=node_id, - ) - break - elapsed = (time.perf_counter() - t0) * 1000 - completed_at = datetime.now(UTC) - if node_span is not None: - with contextlib.suppress(Exception): - node_span.end() - - # HITL: a node returning Pause halts the pipeline and writes a - # paused checkpoint. Approval comes via invoke(approve_pause=True). - if isinstance(update, Pause): - pause_reason = update.reason - await self._emit("on_node_pause", self._name, run_id, node_id, pause_reason) - completed.append(node_id) - self._save_checkpoint( - run_id, - node_id, - len(completed), - state, - completed, - paused=True, - pause_reason=pause_reason, - ) - self._audit( - run_id=run_id, - node_id=node_id, - sequence=len(completed), - visit=visit_n, - started_at=started_at, - completed_at=completed_at, - latency_ms=elapsed, - status="paused", - inputs_snapshot=inputs_snapshot, - outputs_snapshot=state.model_dump(mode="json"), - pause_reason=pause_reason, - ) - logger.info("Pipeline '%s' paused at node '%s': %s", self._name, node_id, pause_reason) - result = StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=False, - paused=True, - paused_node=node_id, - pause_reason=pause_reason, - ) - break - - await self._emit("on_node_complete", self._name, run_id, node_id, elapsed) - logger.debug("Pipeline '%s' node '%s' completed in %.1fms", self._name, node_id, elapsed) - - if update: - state = apply_update(state, update, self._reducers) - - completed.append(node_id) - self._save_checkpoint(run_id, node_id, len(completed), state, completed) - self._audit( - run_id=run_id, - node_id=node_id, - sequence=len(completed), - visit=visit_n, - started_at=started_at, - completed_at=completed_at, - latency_ms=elapsed, - status="success", - inputs_snapshot=inputs_snapshot, - outputs_snapshot=state.model_dump(mode="json"), - ) - - try: - next_step = self._next_step(node_id, state) - except PipelineError as exc: - result = StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=False, - error=str(exc), - failed_node=node_id, - ) - break - - if result is None: - result = StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=True, - ) - finally: - if pipeline_span is not None: - with contextlib.suppress(Exception): - pipeline_span.end() - duration_ms = (time.perf_counter() - pipeline_start_time) * 1000 - success = result.success if result is not None else False - await self._emit("on_pipeline_complete", self._name, run_id, success, duration_ms) - - assert result is not None # set in try-block before reaching here - return result - - async def _run_fanout( - self, - *, - sends: list[Send], - state: BaseModel, - completed: list[str], - run_id: str, - visit_counts: dict[str, int], - ) -> BaseModel: - """Run all ``Send`` dispatches concurrently. Each task gets its own state - copy with the Send's payload merged in; results are reduced into shared state. - """ - # Snapshot each Send's visit number BEFORE dispatch so the worker's - # closure captures its own visit, not the final post-increment value. - sends_with_visits: list[tuple[Send, int]] = [] - for send in sends: - visit_counts[send.target] = visit_counts.get(send.target, 0) + 1 - sends_with_visits.append((send, visit_counts[send.target])) - if visit_counts[send.target] > self._recursion_limit: - raise _NodeFailureError( - node_id=send.target, - message=( - f"Recursion limit ({self._recursion_limit}) exceeded at node '{send.target}' during fan-out." - ), - ) - - async def _run_one(send: Send, visit_n: int) -> tuple[Send, dict[str, Any] | None]: - await self._emit("on_node_start", self._name, run_id, send.target, visit_n) - node_span = start_otel_span(f"pipeline.state.node.{send.target}", node=send.target, visit=visit_n) - task_state = apply_update(state, send.payload, self._reducers) - fn = self._node_fns[send.target] - t0 = time.perf_counter() - try: - update = await fn(task_state) - except Exception as exc: - await self._emit("on_node_error", self._name, run_id, send.target, str(exc)) - if node_span is not None: - with contextlib.suppress(Exception): - node_span.end() - raise - elapsed = (time.perf_counter() - t0) * 1000 - if node_span is not None: - with contextlib.suppress(Exception): - node_span.end() - await self._emit("on_node_complete", self._name, run_id, send.target, elapsed) - return send, update - - try: - results = await asyncio.gather(*(_run_one(s, v) for s, v in sends_with_visits)) - except Exception as exc: - # Best-effort: report the first failing target as the failure point. - raise _NodeFailureError( - node_id=sends[0].target, - message=f"Fan-out failure: {exc}", - ) from exc - - new_state = state - for send, update in results: - if update: - new_state = apply_update(new_state, update, self._reducers) - completed.append(send.target) - self._save_checkpoint(run_id, send.target, len(completed), new_state, completed) - - return new_state - - def _save_checkpoint( - self, - run_id: str, - node_id: str, - sequence: int, - state: BaseModel, - completed: list[str], - *, - paused: bool = False, - pause_reason: str | None = None, - ) -> None: - """Persist state via the configured checkpointer (no-op if absent).""" - if self._checkpointer is None: - return - try: - self._checkpointer.save( - CheckpointRecord( - pipeline_name=self._name, - run_id=run_id, - node_id=node_id, - sequence=sequence, - state=state.model_dump(), - completed_nodes=list(completed), - paused=paused, - pause_reason=pause_reason, - ) - ) - except Exception: - logger.exception("Checkpoint save failed for run '%s' at '%s'", run_id, node_id) - - -class _NodeFailureError(Exception): - """Internal sentinel used to bubble fan-out failures out to the main loop.""" - - def __init__(self, node_id: str, message: str) -> None: - super().__init__(message) - self.node_id = node_id - self.message = message - - -def _resolve_node_id(ref: str | Callable[..., Any]) -> str: - """Turn either a string id or a function reference into a node id.""" - if isinstance(ref, str): - return ref - name = getattr(ref, "__name__", None) - if not name: - raise PipelineError(f"Cannot derive node id from {ref!r}") - return name - - -def coerce_state_node_fn(fn: Callable[..., Any]) -> StateNodeFn: - """Adapt a user-supplied callable into the ``async (state) -> dict | None`` shape. - - Accepted forms: - * ``async def f(state) -> dict | None`` — used as-is. - * ``def f(state) -> dict | None`` — wrapped to run in a thread. - * Object with ``async run(state)`` (e.g. a FireflyAgent-like) — adapter calls ``.run(state)``. - """ - if inspect.iscoroutinefunction(fn): - return fn # type: ignore[return-value] - - # Object with .run(state) — e.g. a FireflyAgent. Check before the generic - # callable branch so agent-shaped objects don't get treated as plain callables. - run = getattr(fn, "run", None) - if not callable(fn) and run is not None and callable(run): - - async def _agent_wrap(state: Any) -> Any: - if inspect.iscoroutinefunction(run): - return await run(state) - return await asyncio.get_running_loop().run_in_executor(None, run, state) - - return _agent_wrap - - if callable(fn): - - async def _async_wrap(state: Any) -> Any: - return await asyncio.get_running_loop().run_in_executor(None, fn, state) - - return _async_wrap - - raise PipelineError(f"Cannot adapt {fn!r} as a state node function") diff --git a/tests/unit/pipeline/test_checkpoint_backends.py b/tests/unit/pipeline/test_checkpoint_backends.py index 419faca3..51e42906 100644 --- a/tests/unit/pipeline/test_checkpoint_backends.py +++ b/tests/unit/pipeline/test_checkpoint_backends.py @@ -20,7 +20,6 @@ CheckpointRecord, FileCheckpointer, PipelineBuilder, - StatePipeline, ) # ============================================================================= @@ -113,7 +112,7 @@ class FactoryState(BaseModel): evaluation: str | None = None -def _build_factory(checkpointer) -> StatePipeline: +def _build_factory(checkpointer): """Construct the canonical 4-step agent pipeline that fails on first deploy.""" state_flag = {"failed_once": False} @@ -141,7 +140,6 @@ async def evaluator(state: FactoryState) -> dict: .chain(architect, python_dev, deployer, evaluator) .build() ) - assert isinstance(pipeline, StatePipeline) return pipeline diff --git a/tests/unit/pipeline/test_state_pipeline.py b/tests/unit/pipeline/test_state_pipeline.py index 0853c106..f0d1095d 100644 --- a/tests/unit/pipeline/test_state_pipeline.py +++ b/tests/unit/pipeline/test_state_pipeline.py @@ -27,7 +27,6 @@ from fireflyframework_agentic.pipeline import ( FileCheckpointer, PipelineBuilder, - StatePipeline, append, ) @@ -62,7 +61,6 @@ async def step_c(state: AgentState) -> dict: .chain(step_a, step_b, step_c) .build() ) - assert isinstance(pipeline, StatePipeline) result = await pipeline.invoke(AgentState(messages=["start"])) assert result.success assert result.completed_nodes == ["step_a", "step_b", "step_c"] diff --git a/tests/unit/pipeline/test_state_pipeline_deprecation.py b/tests/unit/pipeline/test_state_pipeline_deprecation.py deleted file mode 100644 index ec65da49..00000000 --- a/tests/unit/pipeline/test_state_pipeline_deprecation.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Layer 7 of the unification (#245): StatePipeline deprecation. - -Constructing :class:`StatePipeline` now emits a :class:`DeprecationWarning` -pointing at :class:`PipelineEngine` configured with ``state_schema=`` as -the supported replacement. -""" - -from __future__ import annotations - -import warnings - -from pydantic import BaseModel - -from fireflyframework_agentic.pipeline.builder import PipelineBuilder - - -class _S(BaseModel): - x: int = 0 - - -async def _noop(state): - return None - - -def test_state_pipeline_emits_deprecation_warning(): - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - PipelineBuilder("p", state=_S).add_node(_noop).build() - deprec = [w for w in caught if issubclass(w.category, DeprecationWarning)] - assert deprec, "expected a DeprecationWarning when constructing StatePipeline" - assert "PipelineEngine" in str(deprec[0].message) - assert "#245" in str(deprec[0].message) diff --git a/tests/unit/pipeline/test_state_pipeline_hitl.py b/tests/unit/pipeline/test_state_pipeline_hitl.py index aff1c687..f880c459 100644 --- a/tests/unit/pipeline/test_state_pipeline_hitl.py +++ b/tests/unit/pipeline/test_state_pipeline_hitl.py @@ -19,7 +19,6 @@ FileCheckpointer, Pause, PipelineBuilder, - StatePipeline, extend, ) @@ -62,7 +61,6 @@ async def deploy(state: DeployState) -> dict: .chain(architect, gate, deploy) .build() ) - assert isinstance(pipeline, StatePipeline) result = await pipeline.invoke(DeployState(requirements="user-mgmt")) assert result.paused is True diff --git a/tests/unit/pipeline/test_state_pipeline_phase2.py b/tests/unit/pipeline/test_state_pipeline_phase2.py index f2040f80..c3862fa4 100644 --- a/tests/unit/pipeline/test_state_pipeline_phase2.py +++ b/tests/unit/pipeline/test_state_pipeline_phase2.py @@ -25,7 +25,6 @@ FanOutStep, PipelineBuilder, Send, - StatePipeline, extend, ) from fireflyframework_agentic.pipeline.steps import CallableStep @@ -53,7 +52,6 @@ def route(state: LoopState) -> str: return "done" if state.counter >= 3 else "step" pipeline = PipelineBuilder("loop", state=LoopState).add_node(step).add_node(done).branch(step, route).build() - assert isinstance(pipeline, StatePipeline) result = await pipeline.invoke(LoopState()) assert result.success assert result.state.counter == 3 @@ -199,7 +197,6 @@ def route(state: LoopState) -> str: .branch(start, route, {"left_path": left, "right_path": right}) .build() ) - assert isinstance(pipeline, StatePipeline) mermaid = pipeline.to_mermaid() assert "start -->|left_path| left" in mermaid assert "start -->|right_path| right" in mermaid