From 6471c60e53b52099edb63f0d590781421635c0de Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Wed, 27 May 2026 17:22:34 +0200 Subject: [PATCH] feat(pipeline): StatePipelineEventHandler + OTel spans (#147 phase 3b) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds observability to state pipelines, mirroring the legacy PipelineEngine story but with state-mode semantics: run_id is plumbed through every callback, on_node_start carries a per-node visit counter, and there is no on_node_skip (state pipelines abort on failure rather than skipping). - New StatePipelineEventHandler Protocol in pipeline/engine.py with on_pipeline_start / on_node_start / on_node_complete / on_node_error / on_pipeline_complete. Partial handlers are valid (hasattr-checked). - Per-node OTel spans nested under one pipeline-level span. The existing _start_otel_span helper on PipelineEngine is lifted to a module-level start_otel_span function shared by both pipeline types. - PipelineBuilder gains an event_handler kwarg that flows into StatePipeline. - Fan-out via Send emits per-Send node-start/complete pairs with each Send's own visit number (snapshot at increment time, not post-loop). - Handler exceptions are swallowed — observability never breaks business logic. Tests: 8 new in test_state_pipeline_observability.py covering event ordering, failure path, cyclic visit counts, fan-out per-Send events, resume-from-checkpoint, partial handler, swallowed handler exceptions, and OTel span emission with attribute snapshots. Example: examples/pipeline_state.py gains a ProgressHandler that prints live progress for the software-factory scenario. Verification: - pytest tests/unit/pipeline/ → 122 passed (114 baseline + 8 new) - ruff check + format clean - pyright clean on touched modules --- docs/pipeline.md | 53 +++ examples/pipeline_state.py | 43 ++- fireflyframework_agentic/pipeline/__init__.py | 7 +- fireflyframework_agentic/pipeline/builder.py | 5 +- fireflyframework_agentic/pipeline/engine.py | 69 +++- .../pipeline/state_pipeline.py | 177 ++++++++-- .../test_state_pipeline_observability.py | 324 ++++++++++++++++++ 7 files changed, 623 insertions(+), 55 deletions(-) create mode 100644 tests/unit/pipeline/test_state_pipeline_observability.py diff --git a/docs/pipeline.md b/docs/pipeline.md index a6918b30..60bd2097 100644 --- a/docs/pipeline.md +++ b/docs/pipeline.md @@ -267,6 +267,59 @@ When all worker targets share a common successor, the engine continues there once the fan-out completes; the aggregator runs once with all results in shared state. +### Observability + +State pipelines emit lifecycle callbacks and OTel spans so ops can see what +an agent workflow is doing in real time. + +`StatePipelineEventHandler` mirrors the legacy `PipelineEventHandler` but +every callback carries the `run_id` (so events can be correlated across +resumes) and `on_node_start` carries a per-node visit counter (so cyclic +graphs and `Send` fan-outs are distinguishable). Implement any subset of +methods; missing ones are no-ops. + +```python +from fireflyframework_agentic.pipeline import PipelineBuilder, StatePipelineEventHandler + + +class ProgressHandler: + async def on_pipeline_start(self, name, run_id): + print(f"▶ [{name}] run {run_id} starting") + + async def on_node_start(self, name, run_id, node_id, visit): + print(f" ▶ {node_id} (visit #{visit})") + + async def on_node_complete(self, name, run_id, node_id, latency_ms): + print(f" ✔ {node_id} ({latency_ms:.0f}ms)") + + async def on_node_error(self, name, run_id, node_id, error): + print(f" ✗ {node_id}: {error}") + + async def on_pipeline_complete(self, name, run_id, success, duration_ms): + status = "OK" if success else "FAILED" + print(f"═ [{name}] {status} in {duration_ms:.0f}ms") + + +pipeline = ( + PipelineBuilder("agent", state=AgentState, event_handler=ProgressHandler()) + .add_node(classify).add_node(answer).add_node(escalate) + .branch(classify, route) + .build() +) +``` + +In parallel, the pipeline emits OTel spans automatically when +`observability_enabled` is True and `opentelemetry` is installed: + +- One pipeline-level span `pipeline.state.` around each `invoke`, + attributes `firefly.pipeline`, `firefly.run_id`. +- One per-node span `pipeline.state.node.` for each `fn(state)` + call, parented under the pipeline span, attributes `firefly.node`, + `firefly.visit`. +- For `Send` fan-out: one per-Send span as a sibling under the pipeline span. + +Handler exceptions are swallowed — observability never breaks business logic. + ### Mermaid Export `StatePipeline.to_mermaid()` and `DAG.to_mermaid()` render the topology as a diff --git a/examples/pipeline_state.py b/examples/pipeline_state.py index 6a3536f8..717c7f64 100644 --- a/examples/pipeline_state.py +++ b/examples/pipeline_state.py @@ -152,13 +152,52 @@ async def evaluator(state: BuildState) -> dict: return {"evaluation": f"PASS — deployed at {state.deploy_url}"} +class ProgressHandler: + """Prints live progress for the software-factory scenario. + + Implements only a subset of :class:`StatePipelineEventHandler` — missing + methods are no-ops, which is fine because the pipeline tolerates partial + handlers. + """ + + async def on_pipeline_start(self, pipeline_name: str, run_id: str) -> None: + print(f" ▶ [{pipeline_name}] run {run_id[:8]}… starting") + + async def on_node_start( + self, pipeline_name: str, run_id: str, node_id: str, visit: int + ) -> None: + print(f" ▶ {node_id} (visit #{visit})") + + async def on_node_complete( + self, pipeline_name: str, run_id: str, node_id: str, latency_ms: float + ) -> None: + print(f" ✔ {node_id} ({latency_ms:.0f}ms)") + + async def on_node_error( + self, pipeline_name: str, run_id: str, node_id: str, error: str + ) -> None: + print(f" ✗ {node_id}: {error}") + + async def on_pipeline_complete( + self, pipeline_name: str, run_id: str, success: bool, duration_ms: float + ) -> None: + status = "OK" if success else "FAILED" + print(f" ═ [{pipeline_name}] {status} in {duration_ms:.0f}ms") + + async def run_software_factory() -> None: - print("=== 2. Software factory with checkpoint/resume ===\n") + print("=== 2. Software factory with checkpoint/resume + live progress ===\n") + handler = ProgressHandler() with tempfile.TemporaryDirectory() as tmp: ckpt = FileCheckpointer(Path(tmp)) pipeline = ( - PipelineBuilder("software-factory", state=BuildState, checkpointer=ckpt) + PipelineBuilder( + "software-factory", + state=BuildState, + checkpointer=ckpt, + event_handler=handler, + ) .add_node(architect) .add_node(python_dev) .add_node(deployer) diff --git a/fireflyframework_agentic/pipeline/__init__.py b/fireflyframework_agentic/pipeline/__init__.py index 3925952c..af4778df 100644 --- a/fireflyframework_agentic/pipeline/__init__.py +++ b/fireflyframework_agentic/pipeline/__init__.py @@ -38,7 +38,11 @@ ) from fireflyframework_agentic.pipeline.context import PipelineContext from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode, FailureStrategy -from fireflyframework_agentic.pipeline.engine import PipelineEngine, PipelineEventHandler +from fireflyframework_agentic.pipeline.engine import ( + PipelineEngine, + PipelineEventHandler, + StatePipelineEventHandler, +) from fireflyframework_agentic.pipeline.reducers import append, extend, merge_dict, replace from fireflyframework_agentic.pipeline.result import ExecutionTraceEntry, NodeResult, PipelineResult from fireflyframework_agentic.pipeline.state_pipeline import ( @@ -89,6 +93,7 @@ "RetrievalStep", "Send", "StatePipeline", + "StatePipelineEventHandler", "StatePipelineResult", "StepExecutor", "append", diff --git a/fireflyframework_agentic/pipeline/builder.py b/fireflyframework_agentic/pipeline/builder.py index bb5f7681..2975d59f 100644 --- a/fireflyframework_agentic/pipeline/builder.py +++ b/fireflyframework_agentic/pipeline/builder.py @@ -55,7 +55,7 @@ from fireflyframework_agentic.exceptions import PipelineError from fireflyframework_agentic.pipeline.checkpoint import Checkpointer from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode, FailureStrategy -from fireflyframework_agentic.pipeline.engine import PipelineEngine +from fireflyframework_agentic.pipeline.engine import PipelineEngine, StatePipelineEventHandler from fireflyframework_agentic.pipeline.state_pipeline import ( BranchSpec, RouterFn, @@ -86,6 +86,7 @@ def __init__( state: type[BaseModel] | None = None, checkpointer: Checkpointer | None = None, recursion_limit: int = 25, + event_handler: StatePipelineEventHandler | None = None, ) -> None: # State pipelines may use cyclic graphs (ReAct loops, retry-with-critique). # The legacy port-based path keeps acyclicity as an invariant. @@ -94,6 +95,7 @@ def __init__( self._state_schema = state self._checkpointer = checkpointer self._recursion_limit = recursion_limit + self._event_handler = event_handler self._pending_nodes: list[DAGNode] = [] self._pending_edges: list[DAGEdge] = [] # State-based mode bookkeeping. Keyed by node id. @@ -267,6 +269,7 @@ def build(self) -> PipelineEngine | StatePipeline: branches=self._branches, checkpointer=self._checkpointer, recursion_limit=self._recursion_limit, + event_handler=self._event_handler, ) return PipelineEngine(self._dag) diff --git a/fireflyframework_agentic/pipeline/engine.py b/fireflyframework_agentic/pipeline/engine.py index 82010a9d..9f7e3d0f 100644 --- a/fireflyframework_agentic/pipeline/engine.py +++ b/fireflyframework_agentic/pipeline/engine.py @@ -71,6 +71,60 @@ async def on_pipeline_complete(self, pipeline_name: str, success: bool, duration ... +@runtime_checkable +class StatePipelineEventHandler(Protocol): + """Protocol for state-pipeline progress callbacks. + + Mirrors :class:`PipelineEventHandler` but every callback carries the + ``run_id`` so ops can correlate events across resumes, and + ``on_node_start`` carries a ``visit`` counter so cyclic graphs are + distinguishable per iteration. There is no ``on_node_skip`` — state + pipelines abort on failure rather than skipping downstream nodes. + + Implement any subset of methods; missing ones are no-ops. + """ + + async def on_pipeline_start(self, pipeline_name: str, run_id: str) -> None: + """Called once when ``invoke`` begins.""" + ... + + async def on_node_start(self, pipeline_name: str, run_id: str, node_id: str, visit: int) -> None: + """Called each time a node is about to run. ``visit`` starts at 1 + and increments per re-entry (cycles, Send fan-out).""" + ... + + async def on_node_complete(self, pipeline_name: str, run_id: str, node_id: str, latency_ms: float) -> None: + """Called when a node completes successfully.""" + ... + + async def on_node_error(self, pipeline_name: str, run_id: str, node_id: str, error: str) -> None: + """Called when a node raises an exception.""" + ... + + async def on_pipeline_complete(self, pipeline_name: str, run_id: str, success: bool, duration_ms: float) -> None: + """Called once when ``invoke`` returns.""" + ... + + +def start_otel_span(name: str, **attributes: Any) -> Any: + """Start an OTel span if observability is enabled, else return ``None``. + + Module-level helper shared by :class:`PipelineEngine` and + :class:`fireflyframework_agentic.pipeline.state_pipeline.StatePipeline`. + """ + try: + if not get_config().observability_enabled: + return None + if otel_trace is None: + return None + return otel_trace.get_tracer("fireflyframework_agentic").start_span( + name, + attributes={f"firefly.{k}": str(v) for k, v in attributes.items()}, + ) + except Exception: # noqa: BLE001 + return None + + class PipelineEngine: """Executes a :class:`DAG` by computing topological levels and running nodes within each level concurrently. @@ -347,19 +401,8 @@ async def _execute_node( @staticmethod def _start_otel_span(name: str, **attributes: Any) -> Any: - """Start an OTel span if observability is enabled, else return *None*.""" - try: - if not get_config().observability_enabled: - return None - if otel_trace is None: - return None - - return otel_trace.get_tracer("fireflyframework_agentic").start_span( - name, - attributes={f"firefly.{k}": str(v) for k, v in attributes.items()}, - ) - except Exception: # noqa: BLE001 - return None + """Backwards-compatible wrapper around the module-level :func:`start_otel_span`.""" + return start_otel_span(name, **attributes) @staticmethod def _aggregate_usage(correlation_id: str) -> Any: diff --git a/fireflyframework_agentic/pipeline/state_pipeline.py b/fireflyframework_agentic/pipeline/state_pipeline.py index 6c818bf7..d33c37b7 100644 --- a/fireflyframework_agentic/pipeline/state_pipeline.py +++ b/fireflyframework_agentic/pipeline/state_pipeline.py @@ -24,19 +24,24 @@ from __future__ import annotations import asyncio +import contextlib import inspect import logging import time import uuid from collections.abc import Awaitable, Callable from dataclasses import dataclass -from typing import Any, get_type_hints +from typing import TYPE_CHECKING, Any, get_type_hints from pydantic import BaseModel from fireflyframework_agentic.exceptions import PipelineError from fireflyframework_agentic.pipeline.checkpoint import Checkpointer, CheckpointRecord from fireflyframework_agentic.pipeline.dag import DAG, _mermaid_id +from fireflyframework_agentic.pipeline.engine import start_otel_span + +if TYPE_CHECKING: + from fireflyframework_agentic.pipeline.engine import StatePipelineEventHandler from fireflyframework_agentic.pipeline.reducers import Reducer, replace logger = logging.getLogger(__name__) @@ -151,6 +156,7 @@ def __init__( branches: dict[str, BranchSpec], checkpointer: Checkpointer | None = None, recursion_limit: int = 25, + event_handler: StatePipelineEventHandler | None = None, ) -> None: self._name = name self._dag = dag @@ -159,9 +165,43 @@ def __init__( self._branches = branches self._checkpointer = checkpointer self._recursion_limit = recursion_limit + self._event_handler = event_handler self._reducers = discover_reducers(state_schema) self._validate() + async def _finalize_run( + self, + result: StatePipelineResult, + span: Any, + start_time: float, + run_id: str, + ) -> StatePipelineResult: + """Close the pipeline-level span and emit ``on_pipeline_complete``. + + Every return path in :meth:`invoke` after the observability boundary + flows through this helper. + """ + if span is not None: + with contextlib.suppress(Exception): + span.end() + duration_ms = (time.perf_counter() - start_time) * 1000 + await self._emit("on_pipeline_complete", self._name, run_id, result.success, duration_ms) + return result + + async def _emit(self, method: str, *args: Any) -> None: + """Invoke ``method`` on the configured event handler if it exists. + + Missing methods are no-ops; raised exceptions are swallowed so + observability never breaks business logic. + """ + if self._event_handler is None: + return + fn = getattr(self._event_handler, method, None) + if fn is None: + return + with contextlib.suppress(Exception): + await fn(*args) + @property def name(self) -> str: return self._name @@ -358,6 +398,11 @@ async def invoke( next_step: str | list[Send] | None = current_node + # ---- pipeline-level observability boundary ------------------------- + pipeline_start_time = time.perf_counter() + pipeline_span = start_otel_span(f"pipeline.state.{self._name}", pipeline=self._name, run_id=run_id) + await self._emit("on_pipeline_start", self._name, run_id) + while next_step is not None: # --- fan-out branch (list[Send]) --------------------------------- if isinstance(next_step, list): @@ -371,13 +416,18 @@ async def invoke( visit_counts=visit_counts, ) except _NodeFailureError as fail: - return StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=False, - error=fail.message, - failed_node=fail.node_id, + return await self._finalize_run( + StatePipelineResult( + state=state, + run_id=run_id, + completed_nodes=completed, + success=False, + error=fail.message, + failed_node=fail.node_id, + ), + pipeline_span, + pipeline_start_time, + run_id, ) # After fan-out, continue from the workers' shared successor (if any). next_step = self._common_successor([s.target for s in next_step]) @@ -386,22 +436,30 @@ async def invoke( # --- single-node step -------------------------------------------- node_id = next_step visit_counts[node_id] = visit_counts.get(node_id, 0) + 1 - if visit_counts[node_id] > self._recursion_limit: + visit_n = visit_counts[node_id] + if visit_n > self._recursion_limit: msg = ( f"Recursion limit ({self._recursion_limit}) exceeded at node '{node_id}'. " f"Raise recursion_limit= or fix the routing logic." ) logger.error(msg) - return StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=False, - error=msg, - failed_node=node_id, + return await self._finalize_run( + StatePipelineResult( + state=state, + run_id=run_id, + completed_nodes=completed, + success=False, + error=msg, + failed_node=node_id, + ), + pipeline_span, + pipeline_start_time, + run_id, ) fn = self._node_fns[node_id] + node_span = start_otel_span(f"pipeline.state.node.{node_id}", node=node_id, visit=visit_n) + await self._emit("on_node_start", self._name, run_id, node_id, visit_n) t0 = time.perf_counter() try: update = await fn(state) @@ -412,15 +470,28 @@ async def invoke( run_id, node_id, ) - return StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=False, - error=str(exc), - failed_node=node_id, + await self._emit("on_node_error", self._name, run_id, node_id, str(exc)) + if node_span is not None: + with contextlib.suppress(Exception): + node_span.end() + return await self._finalize_run( + StatePipelineResult( + state=state, + run_id=run_id, + completed_nodes=completed, + success=False, + error=str(exc), + failed_node=node_id, + ), + pipeline_span, + pipeline_start_time, + run_id, ) elapsed = (time.perf_counter() - t0) * 1000 + if node_span is not None: + with contextlib.suppress(Exception): + node_span.end() + await self._emit("on_node_complete", self._name, run_id, node_id, elapsed) logger.debug("Pipeline '%s' node '%s' completed in %.1fms", self._name, node_id, elapsed) if update: @@ -433,20 +504,30 @@ async def invoke( try: next_step = self._next_step(node_id, state) except PipelineError as exc: - return StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=False, - error=str(exc), - failed_node=node_id, + return await self._finalize_run( + StatePipelineResult( + state=state, + run_id=run_id, + completed_nodes=completed, + success=False, + error=str(exc), + failed_node=node_id, + ), + pipeline_span, + pipeline_start_time, + run_id, ) - return StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=True, + return await self._finalize_run( + StatePipelineResult( + state=state, + run_id=run_id, + completed_nodes=completed, + success=True, + ), + pipeline_span, + pipeline_start_time, + run_id, ) async def _run_fanout( @@ -462,8 +543,12 @@ async def _run_fanout( """Run all ``Send`` dispatches concurrently. Each task gets its own state copy with the Send's payload merged in; results are reduced into shared state. """ + # Snapshot each Send's visit number BEFORE dispatch so the worker's + # closure captures its own visit, not the final post-increment value. + sends_with_visits: list[tuple[Send, int]] = [] for send in sends: visit_counts[send.target] = visit_counts.get(send.target, 0) + 1 + sends_with_visits.append((send, visit_counts[send.target])) if visit_counts[send.target] > self._recursion_limit: raise _NodeFailureError( node_id=send.target, @@ -472,13 +557,29 @@ async def _run_fanout( ), ) - async def _run_one(send: Send) -> tuple[Send, dict[str, Any] | None]: + async def _run_one(send: Send, visit_n: int) -> tuple[Send, dict[str, Any] | None]: + await self._emit("on_node_start", self._name, run_id, send.target, visit_n) + node_span = start_otel_span(f"pipeline.state.node.{send.target}", node=send.target, visit=visit_n) task_state = apply_update(state, send.payload, self._reducers) fn = self._node_fns[send.target] - return send, await fn(task_state) + t0 = time.perf_counter() + try: + update = await fn(task_state) + except Exception as exc: + await self._emit("on_node_error", self._name, run_id, send.target, str(exc)) + if node_span is not None: + with contextlib.suppress(Exception): + node_span.end() + raise + elapsed = (time.perf_counter() - t0) * 1000 + if node_span is not None: + with contextlib.suppress(Exception): + node_span.end() + await self._emit("on_node_complete", self._name, run_id, send.target, elapsed) + return send, update try: - results = await asyncio.gather(*(_run_one(s) for s in sends)) + results = await asyncio.gather(*(_run_one(s, v) for s, v in sends_with_visits)) except Exception as exc: # Best-effort: report the first failing target as the failure point. raise _NodeFailureError( diff --git a/tests/unit/pipeline/test_state_pipeline_observability.py b/tests/unit/pipeline/test_state_pipeline_observability.py new file mode 100644 index 00000000..ac79233a --- /dev/null +++ b/tests/unit/pipeline/test_state_pipeline_observability.py @@ -0,0 +1,324 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""Phase-3b tests: StatePipelineEventHandler callbacks + OTel spans for state pipelines.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Annotated, Any +from unittest.mock import MagicMock + +import pytest +from pydantic import BaseModel + +import fireflyframework_agentic.pipeline.engine as engine_module +from fireflyframework_agentic.pipeline import ( + FileCheckpointer, + PipelineBuilder, + Send, + extend, +) + + +@dataclass +class RecordingHandler: + """A test handler that captures every callback in order.""" + + events: list[tuple] = field(default_factory=list) + + async def on_pipeline_start(self, pipeline_name: str, run_id: str) -> None: + self.events.append(("pipeline_start", pipeline_name, run_id)) + + async def on_node_start(self, pipeline_name: str, run_id: str, node_id: str, visit: int) -> None: + self.events.append(("node_start", node_id, visit)) + + async def on_node_complete(self, pipeline_name: str, run_id: str, node_id: str, latency_ms: float) -> None: + self.events.append(("node_complete", node_id)) + + async def on_node_error(self, pipeline_name: str, run_id: str, node_id: str, error: str) -> None: + self.events.append(("node_error", node_id, error)) + + async def on_pipeline_complete(self, pipeline_name: str, run_id: str, success: bool, duration_ms: float) -> None: + self.events.append(("pipeline_complete", success)) + + +class LinearState(BaseModel): + log: Annotated[list[str], extend] = [] + + +class LoopState(BaseModel): + counter: int = 0 + + +# --- linear pipeline event ordering ----------------------------------------- + + +@pytest.mark.asyncio +async def test_linear_pipeline_emits_events_in_order() -> None: + async def a(state: LinearState) -> dict: + return {"log": ["a"]} + + async def b(state: LinearState) -> dict: + return {"log": ["b"]} + + async def c(state: LinearState) -> dict: + return {"log": ["c"]} + + handler = RecordingHandler() + pipeline = ( + PipelineBuilder("linear", state=LinearState, event_handler=handler) + .add_node(a) + .add_node(b) + .add_node(c) + .chain(a, b, c) + .build() + ) + await pipeline.invoke(LinearState()) + + event_kinds = [e[0] for e in handler.events] + assert event_kinds == [ + "pipeline_start", + "node_start", + "node_complete", + "node_start", + "node_complete", + "node_start", + "node_complete", + "pipeline_complete", + ] + assert handler.events[-1] == ("pipeline_complete", True) + + +# --- failure path ---------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_failure_emits_node_error_and_pipeline_complete_false() -> None: + async def boom(state: LinearState) -> dict: + raise RuntimeError("nope") + + handler = RecordingHandler() + pipeline = PipelineBuilder("fail", state=LinearState, event_handler=handler).add_node(boom).build() + result = await pipeline.invoke(LinearState()) + assert not result.success + + assert ("node_error", "boom", "nope") in handler.events + assert ("pipeline_complete", False) in handler.events + # node_error fires BEFORE pipeline_complete + err_idx = handler.events.index(("node_error", "boom", "nope")) + done_idx = handler.events.index(("pipeline_complete", False)) + assert err_idx < done_idx + + +# --- cyclic graph visit count ---------------------------------------------- + + +@pytest.mark.asyncio +async def test_cyclic_graph_increments_visit_count() -> None: + async def step(state: LoopState) -> dict: + return {"counter": state.counter + 1} + + async def done(state: LoopState) -> dict: + return {} + + def route(state: LoopState) -> str: + return "done" if state.counter >= 3 else "step" + + handler = RecordingHandler() + pipeline = ( + PipelineBuilder("loop", state=LoopState, event_handler=handler) + .add_node(step) + .add_node(done) + .branch(step, route) + .build() + ) + await pipeline.invoke(LoopState()) + + step_starts = [e for e in handler.events if e[0] == "node_start" and e[1] == "step"] + assert [e[2] for e in step_starts] == [1, 2, 3] + + +# --- fan-out via Send ------------------------------------------------------- + + +class FanOutState(BaseModel): + items: list[str] = [] + results: Annotated[list[str], extend] = [] + item: str | None = None + + +@pytest.mark.asyncio +async def test_fanout_emits_per_send_node_events() -> None: + async def planner(state: FanOutState) -> dict: + return {} + + async def worker(state: FanOutState) -> dict: + return {"results": [f"r:{state.item}"]} + + async def collect(state: FanOutState) -> dict: + return {} + + def dispatch(state: FanOutState) -> list[Send]: + return [Send("worker", {"item": x}) for x in state.items] + + handler = RecordingHandler() + pipeline = ( + PipelineBuilder("fanout", state=FanOutState, event_handler=handler) + .add_node(planner) + .add_node(worker) + .add_node(collect) + .add_edge(worker, collect) + .branch(planner, dispatch) + .build() + ) + await pipeline.invoke(FanOutState(items=["a", "b", "c"])) + + worker_starts = [e for e in handler.events if e[0] == "node_start" and e[1] == "worker"] + worker_completes = [e for e in handler.events if e[0] == "node_complete" and e[1] == "worker"] + assert len(worker_starts) == 3 + assert len(worker_completes) == 3 + # Visits are 1, 2, 3 across the three Sends. + assert sorted(e[2] for e in worker_starts) == [1, 2, 3] + + +# --- resume from a checkpoint ---------------------------------------------- + + +class BuildState(BaseModel): + requirements: str + spec: str | None = None + code: str | None = None + deploy: str | None = None + + +@pytest.mark.asyncio +async def test_resume_emits_events_only_for_remaining_nodes(tmp_path: Path) -> None: + fail_once = {"flag": False} + + async def arch(state: BuildState) -> dict: + return {"spec": "s"} + + async def dev(state: BuildState) -> dict: + return {"code": "c"} + + async def deploy(state: BuildState) -> dict: + if not fail_once["flag"]: + fail_once["flag"] = True + raise RuntimeError("blip") + return {"deploy": "ok"} + + handler1 = RecordingHandler() + handler2 = RecordingHandler() + ckpt = FileCheckpointer(tmp_path) + + # First run uses handler1; deploy fails. + p1 = ( + PipelineBuilder("factory", state=BuildState, checkpointer=ckpt, event_handler=handler1) + .add_node(arch) + .add_node(dev) + .add_node(deploy) + .chain(arch, dev, deploy) + .build() + ) + first = await p1.invoke(BuildState(requirements="x")) + assert not first.success + + # Second run uses handler2 and resumes; only deploy should run. + p2 = ( + PipelineBuilder("factory", state=BuildState, checkpointer=ckpt, event_handler=handler2) + .add_node(arch) + .add_node(dev) + .add_node(deploy) + .chain(arch, dev, deploy) + .build() + ) + second = await p2.invoke(run_id=first.run_id) + assert second.success + + nodes_started_on_resume = [e[1] for e in handler2.events if e[0] == "node_start"] + assert nodes_started_on_resume == ["deploy"] + + +# --- partial handler -------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_partial_handler_only_implementing_node_error_works() -> None: + captured: list[str] = [] + + class JustErrors: + async def on_node_error(self, pipeline_name, run_id, node_id, error): + captured.append(f"{node_id}:{error}") + + async def boom(state: LinearState) -> dict: + raise RuntimeError("kaboom") + + pipeline = PipelineBuilder("partial", state=LinearState, event_handler=JustErrors()).add_node(boom).build() + result = await pipeline.invoke(LinearState()) + assert not result.success + assert captured == ["boom:kaboom"] + + +# --- handler exception is swallowed ---------------------------------------- + + +@pytest.mark.asyncio +async def test_handler_exception_does_not_break_pipeline() -> None: + class CrashyHandler: + async def on_node_complete(self, *args: Any, **kwargs: Any) -> None: + raise RuntimeError("handler crashed") + + async def step(state: LinearState) -> dict: + return {"log": ["ran"]} + + pipeline = PipelineBuilder("crash", state=LinearState, event_handler=CrashyHandler()).add_node(step).build() + result = await pipeline.invoke(LinearState()) + assert result.success + assert result.state.log == ["ran"] + + +# --- OTel spans ------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_otel_spans_emitted_per_pipeline_and_per_node(monkeypatch: pytest.MonkeyPatch) -> None: + # Force the observability path to fire even if the config disables it by default. + mock_config = MagicMock(observability_enabled=True) + monkeypatch.setattr(engine_module, "get_config", lambda: mock_config) + + # Stub the OTel tracer so we can record span starts. + started: list[tuple[str, dict]] = [] + mock_tracer = MagicMock() + + def fake_start_span(name: str, attributes: dict | None = None) -> MagicMock: + started.append((name, attributes or {})) + return MagicMock() + + mock_tracer.start_span.side_effect = fake_start_span + mock_trace = MagicMock() + mock_trace.get_tracer.return_value = mock_tracer + monkeypatch.setattr(engine_module, "otel_trace", mock_trace) + + async def a(state: LinearState) -> dict: + return {} + + async def b(state: LinearState) -> dict: + return {} + + pipeline = PipelineBuilder("otel", state=LinearState).add_node(a).add_node(b).chain(a, b).build() + await pipeline.invoke(LinearState()) + + names = [n for n, _ in started] + assert "pipeline.state.otel" in names + assert "pipeline.state.node.a" in names + assert "pipeline.state.node.b" in names + + # Spot-check attributes: pipeline span carries run_id, node span carries visit. + pipeline_attrs = next(attrs for name, attrs in started if name == "pipeline.state.otel") + assert "firefly.run_id" in pipeline_attrs + node_a_attrs = next(attrs for name, attrs in started if name == "pipeline.state.node.a") + assert node_a_attrs.get("firefly.visit") == "1"