Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions fireflyframework_agentic/pipeline/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,16 @@ def add_edge(self, edge: DAGEdge) -> None:
# -- Query -------------------------------------------------------------

def topological_sort(self) -> list[str]:
"""Return node IDs in topological order (Kahn's algorithm)."""
"""Return node IDs in topological order (Kahn's algorithm).

Raises :class:`PipelineError` if the DAG contains a cycle. Cyclic
graphs have no topological order; the caller should branch on
:meth:`is_cyclic` first (or use the engine's cycle-aware scheduler).
"""
if self._has_cycle():
raise PipelineError(
"topological_sort() is not defined on cyclic graphs; use is_cyclic() to branch before calling."
)
in_deg = dict(self._in_degree)
for nid in self._nodes:
in_deg.setdefault(nid, 0)
Expand All @@ -181,16 +190,19 @@ def topological_sort(self) -> list[str]:
if in_deg[neighbour] == 0:
queue.append(neighbour)

if len(order) != len(self._nodes):
raise PipelineError("DAG contains a cycle (should not reach here)")
return order

def execution_levels(self) -> list[list[str]]:
"""Group nodes into levels for parallel execution.

Nodes at the same level have no inter-dependencies and can be
executed concurrently.
executed concurrently. Raises :class:`PipelineError` on cyclic
DAGs — levels are undefined when cycles exist.
"""
if self._has_cycle():
raise PipelineError(
"execution_levels() is not defined on cyclic graphs; use is_cyclic() to branch before calling."
)
in_deg = dict(self._in_degree)
for nid in self._nodes:
in_deg.setdefault(nid, 0)
Expand Down
183 changes: 181 additions & 2 deletions fireflyframework_agentic/pipeline/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def __init__(
checkpointer: Checkpointer | None = None,
audit_log: AuditLog | None = None,
state_schema: type[BaseModel] | None = None,
recursion_limit: int = 25,
) -> None:
self._dag = dag
self._event_handler = event_handler
Expand All @@ -197,6 +198,8 @@ def __init__(
# continue to flow through edges as port outputs. Both can coexist.
self._state_schema = state_schema
self._reducers: dict[str, Reducer] = discover_reducers(state_schema) if state_schema is not None else {}
# Max visits per node for cycle-aware runs. Matches StatePipeline's default.
self._recursion_limit = recursion_limit
# Per-method signature cache for legacy-vs-unified dispatch.
self._handler_params: dict[str, set[str]] = {}

Expand Down Expand Up @@ -298,6 +301,20 @@ async def run(
)
await self._dispatch("on_pipeline_start", pipeline_name=self._dag.name, run_id=run_id)

# Cycle-aware mode: a separate sequential frontier-following scheduler
# that respects ``recursion_limit``. The topological scheduler below
# cannot run cyclic graphs because execution_levels()/topological_sort()
# are undefined on them.
if self._dag.is_cyclic():
return await self._run_cyclic(
context=context,
run_id=run_id,
all_results=all_results,
pre_completed=pre_completed,
sequence_start=sequence_start,
pipeline_span=_pipeline_span,
)

# Topological levels ensure that all upstream dependencies of a node
# complete before the node itself executes. Nodes within the same
# level are independent and run concurrently via asyncio.gather.
Expand Down Expand Up @@ -524,6 +541,166 @@ async def _record_skip(nid: str) -> None:
final_state=context.state,
)

async def _run_cyclic(
    self,
    *,
    context: PipelineContext,
    run_id: str,
    all_results: dict[str, NodeResult],
    pre_completed: set[str],
    sequence_start: int,
    pipeline_span: Any,
) -> PipelineResult:
    """Sequential frontier-following scheduler for cyclic DAGs.

    Walks the graph one node at a time, choosing the next node via
    :meth:`_cyclic_next` (the unique alive outgoing edge of the node that
    just finished). Visit counts are tracked per node and bounded by
    ``self._recursion_limit``. Having multiple alive outgoing edges is an
    error in this mode (raised by :meth:`_cyclic_next`).

    Args:
        context: Pipeline context handed to nodes and edge conditions;
            ``context.state`` is updated in place when a state schema is
            configured and a node returns a ``dict``.
        run_id: Identifier forwarded to events, audit entries and
            checkpoints.
        all_results: Mapping of node ID to latest result; mutated in place
            and returned as ``PipelineResult.outputs``.
        pre_completed: Node IDs already completed before this run started
            (resume). Each is seeded with a visit count of 1.
        sequence_start: Last sequence number already used; the first node
            recorded here gets ``sequence_start + 1``.
        pipeline_span: Optional tracing span; ``end()`` is called
            best-effort when the walk finishes.

    Returns:
        A ``PipelineResult`` whose ``success`` is true only when every
        entry in ``all_results`` succeeded or was skipped. ``final_output``
        and ``usage`` are always ``None`` in this mode.

    Raises:
        PipelineError: If the DAG has no nodes, or (via
            :meth:`_cyclic_next`) a node has several alive outgoing edges.
    """
    trace_entries: list[ExecutionTraceEntry] = []
    pipeline_start = time.perf_counter()
    visit_counts: dict[str, int] = dict.fromkeys(pre_completed, 1)
    sequence = sequence_start

    # Entry node: insertion order, matching StatePipeline.
    nodes_in_order = list(self._dag.nodes)
    if not nodes_in_order:
        raise PipelineError("Pipeline has no nodes")
    current: str | None = nodes_in_order[0]

    # Skip past anything already completed during this resumed run. The
    # walk is bounded with a seen-set: when the pre-completed nodes form a
    # closed loop the frontier comes back to a node it already skipped, and
    # without the bound this loop would never terminate. In that case the
    # revisited node is executed again as its next visit.
    skipped_on_resume: set[str] = set()
    while current is not None and current in pre_completed and current not in skipped_on_resume:
        skipped_on_resume.add(current)
        current = self._cyclic_next(current, context)

    try:
        while current is not None:
            visit_counts[current] = visit_counts.get(current, 0) + 1
            visit_n = visit_counts[current]
            if visit_n > self._recursion_limit:
                msg = (
                    f"Recursion limit ({self._recursion_limit}) exceeded at node "
                    f"'{current}'. Raise recursion_limit= or fix the routing logic."
                )
                logger.error(msg)
                # The node is NOT executed; a synthetic failure replaces its
                # entry in all_results so the run is reported as failed.
                nr = NodeResult(node_id=current, success=False, error=msg)
                all_results[current] = nr
                sequence += 1
                self._record_audit(
                    run_id=run_id,
                    node_id=current,
                    sequence=sequence,
                    nr=nr,
                    inputs_snapshot={},
                    trace_entries=trace_entries,
                    # Record the real visit number, as the normal path does,
                    # instead of falling back to the default of 1.
                    visit=visit_n,
                )
                break

            gathered = self._gather_inputs(current, context)
            await self._dispatch(
                "on_node_start",
                pipeline_name=self._dag.name,
                run_id=run_id,
                node_id=current,
                visit=visit_n,
            )
            nr = await self._execute_node(
                current,
                context,
                trace_entries,
                None,
                inputs=gathered,
                run_id=run_id,
            )
            all_results[current] = nr
            context.set_node_result(current, nr)
            await self._emit_node_result(nr, run_id)

            sequence += 1
            self._record_audit(
                run_id=run_id,
                node_id=current,
                sequence=sequence,
                nr=nr,
                inputs_snapshot=gathered,
                trace_entries=trace_entries,
                visit=visit_n,
            )

            # A genuine failure (not a skip) stops the walk immediately.
            if not nr.success and not nr.skipped:
                break

            if not nr.skipped:
                # Fold dict outputs into the typed state before checkpointing
                # so the checkpoint reflects the post-node state.
                if self._state_schema is not None and context.state is not None and isinstance(nr.output, dict):
                    context.state = apply_update(context.state, nr.output, self._reducers)
                self._save_checkpoint(
                    run_id=run_id,
                    node_id=current,
                    sequence=sequence,
                    context=context,
                    all_results=all_results,
                )

            current = self._cyclic_next(current, context)
    finally:
        # Runs on normal completion, on break, and while an exception is
        # propagating, so completion is always signalled and the span closed.
        elapsed = (time.perf_counter() - pipeline_start) * 1000
        success = all(r.success or r.skipped for r in all_results.values())
        await self._dispatch(
            "on_pipeline_complete",
            pipeline_name=self._dag.name,
            run_id=run_id,
            success=success,
            duration_ms=elapsed,
        )
        if pipeline_span is not None:
            # Ending the span is best-effort; tracing must not fail the run.
            with contextlib.suppress(Exception):
                pipeline_span.end()

    return PipelineResult(
        pipeline_name=self._dag.name,
        outputs=all_results,
        final_output=None,
        execution_trace=trace_entries,
        total_duration_ms=elapsed,
        success=success,
        usage=None,
        run_id=run_id,
        final_state=context.state,
    )

def _cyclic_next(self, current: str, context: PipelineContext) -> str | None:
    """Pick the next node by following the unique alive outgoing edge.

    An edge is alive when it has no condition, or when its condition called
    with ``context`` returns a truthy value. A condition that raises is
    treated as not alive (best-effort routing), but the failure is logged
    so a broken condition is distinguishable from a legitimately dead edge.

    Args:
        current: ID of the node whose outgoing edges are inspected.
        context: Pipeline context passed to each edge condition.

    Returns:
        The target node ID of the single alive edge, or ``None`` when no
        outgoing edge is alive (terminus).

    Raises:
        PipelineError: If more than one outgoing edge is alive.
    """

    def _alive(edge: Any) -> bool:
        if edge.condition is None:
            return True
        try:
            return bool(edge.condition(context))
        except Exception:
            # Deliberately best-effort: a failing condition must not crash
            # the scheduler, but it should not vanish silently either.
            logger.warning(
                "Edge condition %s -> %s raised; treating edge as not alive",
                edge.source,
                edge.target,
                exc_info=True,
            )
            return False

    alive = [e for e in self._dag.edges if e.source == current and _alive(e)]
    if not alive:
        return None
    if len(alive) > 1:
        raise PipelineError(
            f"Cyclic node '{current}' has multiple alive outgoing edges "
            f"({[e.target for e in alive]}). Parallel cyclic fan-out arrives "
            f"with Send in a later layer; for now, use mutually exclusive "
            f"edge conditions."
        )
    return alive[0].target

async def _execute_node(
self,
node_id: str,
Expand Down Expand Up @@ -748,11 +925,13 @@ def _record_audit(
nr: NodeResult,
inputs_snapshot: dict[str, Any],
trace_entries: list[ExecutionTraceEntry],
visit: int = 1,
) -> None:
"""Write an audit entry for a node visit. No-op if no audit log.

Skipped nodes are not recorded — they represent work that did NOT
happen and would clutter the trail.
happen and would clutter the trail. ``visit`` defaults to 1 for the
acyclic scheduler and is supplied by the cyclic scheduler.
"""
if self._audit_log is None or nr.skipped:
return
Expand All @@ -770,7 +949,7 @@ def _record_audit(
run_id=run_id,
node_id=node_id,
sequence=sequence,
visit=1,
visit=visit,
started_at=started_at,
completed_at=completed_at,
latency_ms=nr.latency_ms or 0.0,
Expand Down
Loading