Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
320 changes: 144 additions & 176 deletions fireflyframework_agentic/pipeline/state_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,25 +247,6 @@ def _audit(
except Exception:
logger.exception("Audit log write failed for run '%s' at '%s'", run_id, node_id)

async def _finalize_run(
self,
result: StatePipelineResult,
span: Any,
start_time: float,
run_id: str,
) -> StatePipelineResult:
"""Close the pipeline-level span and emit ``on_pipeline_complete``.

Every return path in :meth:`invoke` after the observability boundary
flows through this helper.
"""
if span is not None:
with contextlib.suppress(Exception):
span.end()
duration_ms = (time.perf_counter() - start_time) * 1000
await self._emit("on_pipeline_complete", self._name, run_id, result.success, duration_ms)
return result

async def _emit(self, method: str, *args: Any) -> None:
"""Invoke ``method`` on the configured event handler if it exists.

Expand Down Expand Up @@ -488,127 +469,151 @@ async def invoke(
pipeline_span = start_otel_span(f"pipeline.state.{self._name}", pipeline=self._name, run_id=run_id)
await self._emit("on_pipeline_start", self._name, run_id)

while next_step is not None:
# --- fan-out branch (list[Send]) ---------------------------------
if isinstance(next_step, list):
try:
state, sequence = await self._run_fanout(
sends=next_step,
state=state,
completed=completed,
run_id=run_id,
sequence=sequence,
visit_counts=visit_counts,
)
except _NodeFailureError as fail:
return await self._finalize_run(
StatePipelineResult(
result: StatePipelineResult | None = None
try:
while next_step is not None:
# --- fan-out branch (list[Send]) ---------------------------------
if isinstance(next_step, list):
try:
state, sequence = await self._run_fanout(
sends=next_step,
state=state,
completed=completed,
run_id=run_id,
sequence=sequence,
visit_counts=visit_counts,
)
except _NodeFailureError as fail:
result = StatePipelineResult(
state=state,
run_id=run_id,
completed_nodes=completed,
success=False,
error=fail.message,
failed_node=fail.node_id,
),
pipeline_span,
pipeline_start_time,
run_id,
)
break
# After fan-out, continue from the workers' shared successor (if any).
next_step = self._common_successor([s.target for s in next_step])
continue

# --- single-node step --------------------------------------------
node_id = next_step
visit_counts[node_id] = visit_counts.get(node_id, 0) + 1
visit_n = visit_counts[node_id]
if visit_n > self._recursion_limit:
msg = (
f"Recursion limit ({self._recursion_limit}) exceeded at node '{node_id}'. "
f"Raise recursion_limit= or fix the routing logic."
)
# After fan-out, continue from the workers' shared successor (if any).
next_step = self._common_successor([s.target for s in next_step])
continue

# --- single-node step --------------------------------------------
node_id = next_step
visit_counts[node_id] = visit_counts.get(node_id, 0) + 1
visit_n = visit_counts[node_id]
if visit_n > self._recursion_limit:
msg = (
f"Recursion limit ({self._recursion_limit}) exceeded at node '{node_id}'. "
f"Raise recursion_limit= or fix the routing logic."
)
logger.error(msg)
return await self._finalize_run(
StatePipelineResult(
logger.error(msg)
result = StatePipelineResult(
state=state,
run_id=run_id,
completed_nodes=completed,
success=False,
error=msg,
failed_node=node_id,
),
pipeline_span,
pipeline_start_time,
run_id,
)

fn = self._node_fns[node_id]
node_span = start_otel_span(f"pipeline.state.node.{node_id}", node=node_id, visit=visit_n)
await self._emit("on_node_start", self._name, run_id, node_id, visit_n)
inputs_snapshot = state.model_dump(mode="json")
started_at = datetime.now(UTC)
t0 = time.perf_counter()
try:
update = await fn(state)
except Exception as exc:
logger.exception(
"State pipeline '%s' run '%s' failed at node '%s'",
self._name,
run_id,
node_id,
)
await self._emit("on_node_error", self._name, run_id, node_id, str(exc))
)
break

fn = self._node_fns[node_id]
node_span = start_otel_span(f"pipeline.state.node.{node_id}", node=node_id, visit=visit_n)
await self._emit("on_node_start", self._name, run_id, node_id, visit_n)
inputs_snapshot = state.model_dump(mode="json")
started_at = datetime.now(UTC)
t0 = time.perf_counter()
try:
update = await fn(state)
except Exception as exc:
logger.exception(
"State pipeline '%s' run '%s' failed at node '%s'",
self._name,
run_id,
node_id,
)
await self._emit("on_node_error", self._name, run_id, node_id, str(exc))
if node_span is not None:
with contextlib.suppress(Exception):
node_span.end()
self._audit(
run_id=run_id,
node_id=node_id,
sequence=sequence + 1,
visit=visit_n,
started_at=started_at,
completed_at=datetime.now(UTC),
latency_ms=(time.perf_counter() - t0) * 1000,
status="error",
inputs_snapshot=inputs_snapshot,
outputs_snapshot={},
error_message=str(exc),
)
result = StatePipelineResult(
state=state,
run_id=run_id,
completed_nodes=completed,
success=False,
error=str(exc),
failed_node=node_id,
)
break
elapsed = (time.perf_counter() - t0) * 1000
completed_at = datetime.now(UTC)
if node_span is not None:
with contextlib.suppress(Exception):
node_span.end()
self._audit(
run_id=run_id,
node_id=node_id,
sequence=sequence + 1,
visit=visit_n,
started_at=started_at,
completed_at=datetime.now(UTC),
latency_ms=(time.perf_counter() - t0) * 1000,
status="error",
inputs_snapshot=inputs_snapshot,
outputs_snapshot={},
error_message=str(exc),
)
return await self._finalize_run(
StatePipelineResult(

# HITL: a node returning Pause halts the pipeline and writes a
# paused checkpoint. Approval comes via invoke(approve_pause=True).
if isinstance(update, Pause):
pause_reason = update.reason
await self._emit("on_node_pause", self._name, run_id, node_id, pause_reason)
completed.append(node_id)
sequence += 1
self._save_checkpoint(
run_id,
node_id,
sequence,
state,
completed,
paused=True,
pause_reason=pause_reason,
)
self._audit(
run_id=run_id,
node_id=node_id,
sequence=sequence,
visit=visit_n,
started_at=started_at,
completed_at=completed_at,
latency_ms=elapsed,
status="paused",
inputs_snapshot=inputs_snapshot,
outputs_snapshot=state.model_dump(mode="json"),
pause_reason=pause_reason,
)
logger.info("Pipeline '%s' paused at node '%s': %s", self._name, node_id, pause_reason)
result = StatePipelineResult(
state=state,
run_id=run_id,
completed_nodes=completed,
success=False,
error=str(exc),
failed_node=node_id,
),
pipeline_span,
pipeline_start_time,
run_id,
)
elapsed = (time.perf_counter() - t0) * 1000
completed_at = datetime.now(UTC)
if node_span is not None:
with contextlib.suppress(Exception):
node_span.end()
paused=True,
paused_node=node_id,
pause_reason=pause_reason,
)
break

await self._emit("on_node_complete", self._name, run_id, node_id, elapsed)
logger.debug("Pipeline '%s' node '%s' completed in %.1fms", self._name, node_id, elapsed)

if update:
state = apply_update(state, update, self._reducers)

# HITL: a node returning Pause halts the pipeline and writes a
# paused checkpoint. Approval comes via invoke(approve_pause=True).
if isinstance(update, Pause):
pause_reason = update.reason
await self._emit("on_node_pause", self._name, run_id, node_id, pause_reason)
completed.append(node_id)
sequence += 1
self._save_checkpoint(
run_id,
node_id,
sequence,
state,
completed,
paused=True,
pause_reason=pause_reason,
)
self._save_checkpoint(run_id, node_id, sequence, state, completed)
self._audit(
run_id=run_id,
node_id=node_id,
Expand All @@ -617,77 +622,40 @@ async def invoke(
started_at=started_at,
completed_at=completed_at,
latency_ms=elapsed,
status="paused",
status="success",
inputs_snapshot=inputs_snapshot,
outputs_snapshot=state.model_dump(mode="json"),
pause_reason=pause_reason,
)
logger.info("Pipeline '%s' paused at node '%s': %s", self._name, node_id, pause_reason)
return await self._finalize_run(
StatePipelineResult(
state=state,
run_id=run_id,
completed_nodes=completed,
success=False,
paused=True,
paused_node=node_id,
pause_reason=pause_reason,
),
pipeline_span,
pipeline_start_time,
run_id,
)

await self._emit("on_node_complete", self._name, run_id, node_id, elapsed)
logger.debug("Pipeline '%s' node '%s' completed in %.1fms", self._name, node_id, elapsed)

if update:
state = apply_update(state, update, self._reducers)

completed.append(node_id)
sequence += 1
self._save_checkpoint(run_id, node_id, sequence, state, completed)
self._audit(
run_id=run_id,
node_id=node_id,
sequence=sequence,
visit=visit_n,
started_at=started_at,
completed_at=completed_at,
latency_ms=elapsed,
status="success",
inputs_snapshot=inputs_snapshot,
outputs_snapshot=state.model_dump(mode="json"),
)

try:
next_step = self._next_step(node_id, state)
except PipelineError as exc:
return await self._finalize_run(
StatePipelineResult(
try:
next_step = self._next_step(node_id, state)
except PipelineError as exc:
result = StatePipelineResult(
state=state,
run_id=run_id,
completed_nodes=completed,
success=False,
error=str(exc),
failed_node=node_id,
),
pipeline_span,
pipeline_start_time,
run_id,
)
break

if result is None:
result = StatePipelineResult(
state=state,
run_id=run_id,
completed_nodes=completed,
success=True,
)
finally:
if pipeline_span is not None:
with contextlib.suppress(Exception):
pipeline_span.end()
duration_ms = (time.perf_counter() - pipeline_start_time) * 1000
success = result.success if result is not None else False
await self._emit("on_pipeline_complete", self._name, run_id, success, duration_ms)

return await self._finalize_run(
StatePipelineResult(
state=state,
run_id=run_id,
completed_nodes=completed,
success=True,
),
pipeline_span,
pipeline_start_time,
run_id,
)
return result

async def _run_fanout(
self,
Expand Down