Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 204 additions & 3 deletions fireflyframework_agentic/pipeline/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import random
import time
import uuid
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any, Protocol, runtime_checkable

Expand Down Expand Up @@ -130,6 +131,59 @@ async def on_pipeline_complete(
) -> None: ...


@dataclass
class Pause:
"""Human-in-the-loop sentinel returned by a node to halt the pipeline.

A node returns ``Pause(reason="...")`` when external approval is required
before the pipeline may continue. The engine then:

1. Writes a checkpoint with ``paused=True`` and the reason set.
2. Emits ``on_node_pause`` on the configured event handler.
3. Returns a :class:`PipelineResult` with ``paused=True`` and
``success=False`` — the run is not finished, but it did not fail
either.

Resume after approval::

result = await engine.run(run_id=paused_run_id, approve_pause=True)

The successor of the paused node runs next — the pause node itself is
not re-executed. Without ``approve_pause=True``, resuming a paused run
raises :class:`PipelineError`.
"""

reason: str


@dataclass
class Send:
"""Runtime fan-out dispatch: run ``target`` with ``payload`` merged into state.

A node may return a single ``Send`` or ``list[Send]`` to dispatch one or
more targets concurrently. Each Send's payload is applied to a *copy* of
the current state before its target runs; the target's return is then
merged back into shared state via reducers.

Replaces the legacy ``FanOutStep`` pattern with a first-class primitive.
"""

target: str
payload: dict[str, Any]


def _is_send_payload(value: Any) -> bool:
"""True when a node's return value is a single :class:`Send` or a
non-empty ``list[Send]``. Drives the runtime fan-out branch in
:meth:`PipelineEngine.run`.
"""
if isinstance(value, Send):
return True
if isinstance(value, list) and value and all(isinstance(s, Send) for s in value):
return True
return False


def _serialize_value(value: Any) -> Any:
"""Best-effort conversion of arbitrary values into JSON-safe form.

Expand Down Expand Up @@ -245,6 +299,7 @@ async def run(
inputs: Any = None,
state: BaseModel | None = None,
run_id: str | None = None,
approve_pause: bool = False,
) -> PipelineResult:
"""Execute the pipeline.

Expand All @@ -265,7 +320,7 @@ async def run(
"""
if run_id is not None and context is None and inputs is None and state is None:
resume_run_id: str = run_id
context, pre_completed, sequence_start = self._load_for_resume(resume_run_id)
context, pre_completed, sequence_start = self._load_for_resume(resume_run_id, approve_pause=approve_pause)
all_results: dict[str, NodeResult] = {
nid: nr
for nid in pre_completed
Expand Down Expand Up @@ -338,6 +393,7 @@ async def run(
inputs_by_node: dict[str, dict[str, Any]] = {}
sequence = sequence_start
abort = False
pending_pause: tuple[str, str] | None = None # (node_id, reason) if Pause

def _edge_alive(edge: Any) -> bool:
"""An edge is alive if it has no condition, or its condition returns True.
Expand Down Expand Up @@ -465,6 +521,48 @@ async def _record_skip(nid: str) -> None:
inputs_snapshot=inputs_by_node.get(node_id, {}),
trace_entries=trace_entries,
)
# HITL: a node returned Pause(reason=...). Halt cleanly, save
# a paused checkpoint, and surface the pause in the result.
if nr.success and isinstance(nr.output, Pause):
pause_reason = nr.output.reason
await self._dispatch(
"on_node_pause",
pipeline_name=self._dag.name,
run_id=run_id,
node_id=node_id,
reason=pause_reason,
)
self._save_checkpoint(
run_id=run_id,
node_id=node_id,
sequence=sequence,
context=context,
all_results=all_results,
paused=True,
pause_reason=pause_reason,
)
pending_pause = (node_id, pause_reason)
abort = True
continue

# Runtime fan-out: a node returned Send / list[Send].
if nr.success and _is_send_payload(nr.output):
sends = nr.output if isinstance(nr.output, list) else [nr.output]
ok = await self._run_sends(
sends=sends,
context=context,
run_id=run_id,
all_results=all_results,
trace_entries=trace_entries,
completed=completed,
pending=pending,
)
if not ok:
abort = True
# Successors of the worker targets are picked up by the
# normal readiness sweep on the next loop iteration.
continue

if nr.success and not nr.skipped:
# State overlay: a dict return from the node is a state
# update; non-dict returns flow through edges as ports.
Expand Down Expand Up @@ -529,6 +627,12 @@ async def _record_skip(nid: str) -> None:
if _pipeline_span is not None:
_pipeline_span.end()

paused_node = pending_pause[0] if pending_pause else None
pause_reason_final = pending_pause[1] if pending_pause else None
if pending_pause is not None:
# A paused run is not "successful" — it didn't finish.
success = False

return PipelineResult(
pipeline_name=self._dag.name,
outputs=all_results,
Expand All @@ -539,8 +643,91 @@ async def _record_skip(nid: str) -> None:
usage=usage_summary,
run_id=run_id,
final_state=context.state,
paused=pending_pause is not None,
paused_node=paused_node,
pause_reason=pause_reason_final,
)

async def _run_sends(
self,
*,
sends: list[Send],
context: PipelineContext,
run_id: str,
all_results: dict[str, NodeResult],
trace_entries: list[ExecutionTraceEntry],
completed: set[str],
pending: set[str],
) -> bool:
"""Dispatch a list of :class:`Send` workers concurrently.

Each Send's payload is applied to a copy of the current state before
its target runs. Results merge back into shared state via reducers.
Targets are added to ``completed`` and removed from ``pending`` so
the main scheduler does not re-execute them.

Returns ``True`` on success, ``False`` if any worker failed (the
caller treats this as an abort signal).
"""
# Validate targets up front so unknown ones fail loud, not after gather().
for send in sends:
if send.target not in self._dag.nodes:
nr = NodeResult(
node_id=send.target,
success=False,
error=f"Send dispatches to unknown target '{send.target}'",
)
all_results[send.target] = nr
return False

async def _run_one(send: Send) -> tuple[Send, NodeResult]:
await self._dispatch(
"on_node_start",
pipeline_name=self._dag.name,
run_id=run_id,
node_id=send.target,
visit=1,
)
# Per-worker context: own state copy with payload applied so
# workers don't race on the shared state object.
worker_context = PipelineContext(inputs=context.inputs)
if self._state_schema is not None and context.state is not None:
worker_context.state = apply_update(context.state, send.payload, self._reducers)
for nid, prev in context.results.items():
worker_context.set_node_result(nid, prev)
nr = await self._execute_node(
send.target,
worker_context,
trace_entries,
None,
inputs={"input": send.payload},
run_id=run_id,
)
return send, nr

try:
results = await asyncio.gather(*(_run_one(s) for s in sends))
except Exception as exc:
logger.exception("Fan-out worker crashed")
for send in sends:
if send.target not in all_results:
all_results[send.target] = NodeResult(node_id=send.target, success=False, error=str(exc))
return False

all_ok = True
for send, nr in results:
all_results[send.target] = nr
context.set_node_result(send.target, nr)
completed.add(send.target)
pending.discard(send.target)
await self._emit_node_result(nr, run_id)
if not nr.success:
all_ok = False
continue
if self._state_schema is not None and context.state is not None and isinstance(nr.output, dict):
context.state = apply_update(context.state, nr.output, self._reducers)
return all_ok

async def _run_cyclic(
self,
*,
Expand Down Expand Up @@ -858,13 +1045,23 @@ async def _emit_node_result(self, nr: NodeResult, run_id: str) -> None:
else:
await self._dispatch("on_node_error", error=nr.error or "unknown", **common)

def _load_for_resume(self, run_id: str) -> tuple[PipelineContext, set[str], int]:
"""Rebuild context + completed-set from the latest checkpoint."""
def _load_for_resume(self, run_id: str, *, approve_pause: bool = False) -> tuple[PipelineContext, set[str], int]:
"""Rebuild context + completed-set from the latest checkpoint.

Resuming a paused run (checkpoint.paused=True) requires
``approve_pause=True``; otherwise a :class:`PipelineError` halts the
attempt and surfaces the pause reason.
"""
if self._checkpointer is None:
raise PipelineError("Cannot resume: pipeline has no checkpointer configured")
record = self._checkpointer.load_latest(self._dag.name, run_id)
if record is None:
raise PipelineError(f"No checkpoint found for run_id='{run_id}'")
if record.paused and not approve_pause:
raise PipelineError(
f"Run '{run_id}' is paused at node '{record.node_id}' "
f"(reason: {record.pause_reason!r}). Pass approve_pause=True to resume."
)
context = PipelineContext(inputs=record.state.get("inputs"))
for nid, nr_dict in record.state.get("results", {}).items():
try:
Expand All @@ -888,6 +1085,8 @@ def _save_checkpoint(
sequence: int,
context: PipelineContext,
all_results: dict[str, NodeResult],
paused: bool = False,
pause_reason: str | None = None,
) -> None:
"""Persist state after a successful node. No-op if no checkpointer.

Expand All @@ -911,6 +1110,8 @@ def _save_checkpoint(
sequence=sequence,
state=state,
completed_nodes=completed_successful,
paused=paused,
pause_reason=pause_reason,
)
)
except Exception:
Expand Down
5 changes: 5 additions & 0 deletions fireflyframework_agentic/pipeline/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ class PipelineResult(BaseModel):
# Final shared state for pipelines configured with state_schema. None
# when the engine had no state overlay.
final_state: Any = None
# HITL: a node returned :class:`Pause` and the run halted cleanly.
# Resume via ``engine.run(run_id=..., approve_pause=True)``.
paused: bool = False
paused_node: str | None = None
pause_reason: str | None = None

@property
def failed_nodes(self) -> list[str]:
Expand Down
43 changes: 1 addition & 42 deletions fireflyframework_agentic/pipeline/state_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from fireflyframework_agentic.pipeline.audit import AuditEntry, AuditLog, AuditStatus
from fireflyframework_agentic.pipeline.checkpoint import Checkpointer, CheckpointRecord
from fireflyframework_agentic.pipeline.dag import DAG, _mermaid_id
from fireflyframework_agentic.pipeline.engine import start_otel_span
from fireflyframework_agentic.pipeline.engine import Pause, Send, start_otel_span
from fireflyframework_agentic.pipeline.reducers import apply_update, discover_reducers

if TYPE_CHECKING:
Expand All @@ -53,51 +53,10 @@
RouterFn = Callable[[Any], "str | Send | list[Send]"]


@dataclass
class Send:
"""Runtime fan-out dispatch: run ``target`` with ``payload`` merged into state.

Routers can return a single ``Send`` or a list of ``Send`` to dispatch multiple
target invocations concurrently. Each Send's payload is applied to a *copy*
of the current state before its target runs; the target's return is then
merged back into shared state via reducers.

Replaces the legacy ``FanOutStep`` pattern with a first-class primitive.
"""

target: str
payload: dict[str, Any]


class RecursionLimitError(Exception):
"""Raised when a node is visited more times than ``recursion_limit`` permits."""


@dataclass
class Pause:
"""Human-in-the-loop sentinel returned by a node to halt the pipeline.

A node returns ``Pause(reason="...")`` when external approval (a human,
another system, a wall-clock event) is required before the pipeline may
continue. The pipeline then:

1. Writes a checkpoint with ``paused=True`` and the reason set.
2. Emits ``on_node_pause`` on the configured event handler.
3. Returns a :class:`StatePipelineResult` with ``paused=True`` and
``success=False`` — the run is not finished, but it did not fail either.

To resume after approval::

result = await pipeline.invoke(run_id=paused_run_id, approve_pause=True)

Without ``approve_pause=True``, resuming a paused run raises
:class:`PipelineError`. The successor of the paused node runs next —
the pause node itself is not re-executed.
"""

reason: str


@dataclass
class BranchSpec:
"""Internal: registered branch from one source node."""
Expand Down
Loading