From 9b6a68728571edb59600f71e5805d441adb3d649 Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Thu, 28 May 2026 08:03:06 +0200 Subject: [PATCH] refactor(pipeline): collapse invoke() to single return path Replace 7 _finalize_run call sites in StatePipeline.invoke with a single try/finally that ends the pipeline span and emits on_pipeline_complete once. Remove the _finalize_run helper. Behavior unchanged; net -32 LOC. --- .../pipeline/state_pipeline.py | 320 ++++++++---------- 1 file changed, 144 insertions(+), 176 deletions(-) diff --git a/fireflyframework_agentic/pipeline/state_pipeline.py b/fireflyframework_agentic/pipeline/state_pipeline.py index 823e4928..88d4379f 100644 --- a/fireflyframework_agentic/pipeline/state_pipeline.py +++ b/fireflyframework_agentic/pipeline/state_pipeline.py @@ -247,25 +247,6 @@ def _audit( except Exception: logger.exception("Audit log write failed for run '%s' at '%s'", run_id, node_id) - async def _finalize_run( - self, - result: StatePipelineResult, - span: Any, - start_time: float, - run_id: str, - ) -> StatePipelineResult: - """Close the pipeline-level span and emit ``on_pipeline_complete``. - - Every return path in :meth:`invoke` after the observability boundary - flows through this helper. - """ - if span is not None: - with contextlib.suppress(Exception): - span.end() - duration_ms = (time.perf_counter() - start_time) * 1000 - await self._emit("on_pipeline_complete", self._name, run_id, result.success, duration_ms) - return result - async def _emit(self, method: str, *args: Any) -> None: """Invoke ``method`` on the configured event handler if it exists. @@ -488,127 +469,151 @@ async def invoke( pipeline_span = start_otel_span(f"pipeline.state.{self._name}", pipeline=self._name, run_id=run_id) await self._emit("on_pipeline_start", self._name, run_id) - while next_step is not None: - # --- fan-out branch (list[Send]) --------------------------------- - if isinstance(next_step, list): - try: - state, sequence = await self._run_fanout( - sends=next_step, - state=state, - completed=completed, - run_id=run_id, - sequence=sequence, - visit_counts=visit_counts, - ) - except _NodeFailureError as fail: - return await self._finalize_run( - StatePipelineResult( + result: StatePipelineResult | None = None + try: + while next_step is not None: + # --- fan-out branch (list[Send]) --------------------------------- + if isinstance(next_step, list): + try: + state, sequence = await self._run_fanout( + sends=next_step, + state=state, + completed=completed, + run_id=run_id, + sequence=sequence, + visit_counts=visit_counts, + ) + except _NodeFailureError as fail: + result = StatePipelineResult( state=state, run_id=run_id, completed_nodes=completed, success=False, error=fail.message, failed_node=fail.node_id, - ), - pipeline_span, - pipeline_start_time, - run_id, + ) + break + # After fan-out, continue from the workers' shared successor (if any). + next_step = self._common_successor([s.target for s in next_step]) + continue + + # --- single-node step -------------------------------------------- + node_id = next_step + visit_counts[node_id] = visit_counts.get(node_id, 0) + 1 + visit_n = visit_counts[node_id] + if visit_n > self._recursion_limit: + msg = ( + f"Recursion limit ({self._recursion_limit}) exceeded at node '{node_id}'. " + f"Raise recursion_limit= or fix the routing logic." ) - # After fan-out, continue from the workers' shared successor (if any). - next_step = self._common_successor([s.target for s in next_step]) - continue - - # --- single-node step -------------------------------------------- - node_id = next_step - visit_counts[node_id] = visit_counts.get(node_id, 0) + 1 - visit_n = visit_counts[node_id] - if visit_n > self._recursion_limit: - msg = ( - f"Recursion limit ({self._recursion_limit}) exceeded at node '{node_id}'. " - f"Raise recursion_limit= or fix the routing logic." - ) - logger.error(msg) - return await self._finalize_run( - StatePipelineResult( + logger.error(msg) + result = StatePipelineResult( state=state, run_id=run_id, completed_nodes=completed, success=False, error=msg, failed_node=node_id, - ), - pipeline_span, - pipeline_start_time, - run_id, - ) - - fn = self._node_fns[node_id] - node_span = start_otel_span(f"pipeline.state.node.{node_id}", node=node_id, visit=visit_n) - await self._emit("on_node_start", self._name, run_id, node_id, visit_n) - inputs_snapshot = state.model_dump(mode="json") - started_at = datetime.now(UTC) - t0 = time.perf_counter() - try: - update = await fn(state) - except Exception as exc: - logger.exception( - "State pipeline '%s' run '%s' failed at node '%s'", - self._name, - run_id, - node_id, - ) - await self._emit("on_node_error", self._name, run_id, node_id, str(exc)) + ) + break + + fn = self._node_fns[node_id] + node_span = start_otel_span(f"pipeline.state.node.{node_id}", node=node_id, visit=visit_n) + await self._emit("on_node_start", self._name, run_id, node_id, visit_n) + inputs_snapshot = state.model_dump(mode="json") + started_at = datetime.now(UTC) + t0 = time.perf_counter() + try: + update = await fn(state) + except Exception as exc: + logger.exception( + "State pipeline '%s' run '%s' failed at node '%s'", + self._name, + run_id, + node_id, + ) + await self._emit("on_node_error", self._name, run_id, node_id, str(exc)) + if node_span is not None: + with contextlib.suppress(Exception): + node_span.end() + self._audit( + run_id=run_id, + node_id=node_id, + sequence=sequence + 1, + visit=visit_n, + started_at=started_at, + completed_at=datetime.now(UTC), + latency_ms=(time.perf_counter() - t0) * 1000, + status="error", + inputs_snapshot=inputs_snapshot, + outputs_snapshot={}, + error_message=str(exc), + ) + result = StatePipelineResult( + state=state, + run_id=run_id, + completed_nodes=completed, + success=False, + error=str(exc), + failed_node=node_id, + ) + break + elapsed = (time.perf_counter() - t0) * 1000 + completed_at = datetime.now(UTC) if node_span is not None: with contextlib.suppress(Exception): node_span.end() - self._audit( - run_id=run_id, - node_id=node_id, - sequence=sequence + 1, - visit=visit_n, - started_at=started_at, - completed_at=datetime.now(UTC), - latency_ms=(time.perf_counter() - t0) * 1000, - status="error", - inputs_snapshot=inputs_snapshot, - outputs_snapshot={}, - error_message=str(exc), - ) - return await self._finalize_run( - StatePipelineResult( + + # HITL: a node returning Pause halts the pipeline and writes a + # paused checkpoint. Approval comes via invoke(approve_pause=True). + if isinstance(update, Pause): + pause_reason = update.reason + await self._emit("on_node_pause", self._name, run_id, node_id, pause_reason) + completed.append(node_id) + sequence += 1 + self._save_checkpoint( + run_id, + node_id, + sequence, + state, + completed, + paused=True, + pause_reason=pause_reason, + ) + self._audit( + run_id=run_id, + node_id=node_id, + sequence=sequence, + visit=visit_n, + started_at=started_at, + completed_at=completed_at, + latency_ms=elapsed, + status="paused", + inputs_snapshot=inputs_snapshot, + outputs_snapshot=state.model_dump(mode="json"), + pause_reason=pause_reason, + ) + logger.info("Pipeline '%s' paused at node '%s': %s", self._name, node_id, pause_reason) + result = StatePipelineResult( state=state, run_id=run_id, completed_nodes=completed, success=False, - error=str(exc), - failed_node=node_id, - ), - pipeline_span, - pipeline_start_time, - run_id, - ) - elapsed = (time.perf_counter() - t0) * 1000 - completed_at = datetime.now(UTC) - if node_span is not None: - with contextlib.suppress(Exception): - node_span.end() + paused=True, + paused_node=node_id, + pause_reason=pause_reason, + ) + break + + await self._emit("on_node_complete", self._name, run_id, node_id, elapsed) + logger.debug("Pipeline '%s' node '%s' completed in %.1fms", self._name, node_id, elapsed) + + if update: + state = apply_update(state, update, self._reducers) - # HITL: a node returning Pause halts the pipeline and writes a - # paused checkpoint. Approval comes via invoke(approve_pause=True). - if isinstance(update, Pause): - pause_reason = update.reason - await self._emit("on_node_pause", self._name, run_id, node_id, pause_reason) completed.append(node_id) sequence += 1 - self._save_checkpoint( - run_id, - node_id, - sequence, - state, - completed, - paused=True, - pause_reason=pause_reason, - ) + self._save_checkpoint(run_id, node_id, sequence, state, completed) self._audit( run_id=run_id, node_id=node_id, @@ -617,77 +622,40 @@ async def invoke( started_at=started_at, completed_at=completed_at, latency_ms=elapsed, - status="paused", + status="success", inputs_snapshot=inputs_snapshot, outputs_snapshot=state.model_dump(mode="json"), - pause_reason=pause_reason, ) - logger.info("Pipeline '%s' paused at node '%s': %s", self._name, node_id, pause_reason) - return await self._finalize_run( - StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=False, - paused=True, - paused_node=node_id, - pause_reason=pause_reason, - ), - pipeline_span, - pipeline_start_time, - run_id, - ) - - await self._emit("on_node_complete", self._name, run_id, node_id, elapsed) - logger.debug("Pipeline '%s' node '%s' completed in %.1fms", self._name, node_id, elapsed) - - if update: - state = apply_update(state, update, self._reducers) - - completed.append(node_id) - sequence += 1 - self._save_checkpoint(run_id, node_id, sequence, state, completed) - self._audit( - run_id=run_id, - node_id=node_id, - sequence=sequence, - visit=visit_n, - started_at=started_at, - completed_at=completed_at, - latency_ms=elapsed, - status="success", - inputs_snapshot=inputs_snapshot, - outputs_snapshot=state.model_dump(mode="json"), - ) - try: - next_step = self._next_step(node_id, state) - except PipelineError as exc: - return await self._finalize_run( - StatePipelineResult( + try: + next_step = self._next_step(node_id, state) + except PipelineError as exc: + result = StatePipelineResult( state=state, run_id=run_id, completed_nodes=completed, success=False, error=str(exc), failed_node=node_id, - ), - pipeline_span, - pipeline_start_time, - run_id, + ) + break + + if result is None: + result = StatePipelineResult( + state=state, + run_id=run_id, + completed_nodes=completed, + success=True, ) + finally: + if pipeline_span is not None: + with contextlib.suppress(Exception): + pipeline_span.end() + duration_ms = (time.perf_counter() - pipeline_start_time) * 1000 + success = result.success if result is not None else False + await self._emit("on_pipeline_complete", self._name, run_id, success, duration_ms) - return await self._finalize_run( - StatePipelineResult( - state=state, - run_id=run_id, - completed_nodes=completed, - success=True, - ), - pipeline_span, - pipeline_start_time, - run_id, - ) + return result async def _run_fanout( self,