diff --git a/fireflyframework_agentic/pipeline/dag.py b/fireflyframework_agentic/pipeline/dag.py index 6272068d..1656cc42 100644 --- a/fireflyframework_agentic/pipeline/dag.py +++ b/fireflyframework_agentic/pipeline/dag.py @@ -166,7 +166,16 @@ def add_edge(self, edge: DAGEdge) -> None: # -- Query ------------------------------------------------------------- def topological_sort(self) -> list[str]: - """Return node IDs in topological order (Kahn's algorithm).""" + """Return node IDs in topological order (Kahn's algorithm). + + Raises :class:`PipelineError` if the DAG contains a cycle. Cyclic + graphs have no topological order; the caller should branch on + :meth:`is_cyclic` first (or use the engine's cycle-aware scheduler). + """ + if self._has_cycle(): + raise PipelineError( + "topological_sort() is not defined on cyclic graphs; use is_cyclic() to branch before calling." + ) in_deg = dict(self._in_degree) for nid in self._nodes: in_deg.setdefault(nid, 0) @@ -181,16 +190,19 @@ def topological_sort(self) -> list[str]: if in_deg[neighbour] == 0: queue.append(neighbour) - if len(order) != len(self._nodes): - raise PipelineError("DAG contains a cycle (should not reach here)") return order def execution_levels(self) -> list[list[str]]: """Group nodes into levels for parallel execution. Nodes at the same level have no inter-dependencies and can be - executed concurrently. + executed concurrently. Raises :class:`PipelineError` on cyclic + DAGs — levels are undefined when cycles exist. """ + if self._has_cycle(): + raise PipelineError( + "execution_levels() is not defined on cyclic graphs; use is_cyclic() to branch before calling." + ) in_deg = dict(self._in_degree) for nid in self._nodes: in_deg.setdefault(nid, 0) diff --git a/fireflyframework_agentic/pipeline/engine.py b/fireflyframework_agentic/pipeline/engine.py index 448f1ac7..556fc071 100644 --- a/fireflyframework_agentic/pipeline/engine.py +++ b/fireflyframework_agentic/pipeline/engine.py @@ -187,6 +187,7 @@ def __init__( checkpointer: Checkpointer | None = None, audit_log: AuditLog | None = None, state_schema: type[BaseModel] | None = None, + recursion_limit: int = 25, ) -> None: self._dag = dag self._event_handler = event_handler @@ -197,6 +198,8 @@ def __init__( # continue to flow through edges as port outputs. Both can coexist. self._state_schema = state_schema self._reducers: dict[str, Reducer] = discover_reducers(state_schema) if state_schema is not None else {} + # Max visits per node for cycle-aware runs. Matches StatePipeline's default. + self._recursion_limit = recursion_limit # Per-method signature cache for legacy-vs-unified dispatch. self._handler_params: dict[str, set[str]] = {} @@ -298,6 +301,20 @@ async def run( ) await self._dispatch("on_pipeline_start", pipeline_name=self._dag.name, run_id=run_id) + # Cycle-aware mode: a separate sequential frontier-following scheduler + # that respects ``recursion_limit``. The topological scheduler below + # cannot run cyclic graphs because execution_levels()/topological_sort() + # are undefined on them. + if self._dag.is_cyclic(): + return await self._run_cyclic( + context=context, + run_id=run_id, + all_results=all_results, + pre_completed=pre_completed, + sequence_start=sequence_start, + pipeline_span=_pipeline_span, + ) + # Topological levels ensure that all upstream dependencies of a node # complete before the node itself executes. Nodes within the same # level are independent and run concurrently via asyncio.gather. @@ -524,6 +541,166 @@ async def _record_skip(nid: str) -> None: final_state=context.state, ) + async def _run_cyclic( + self, + *, + context: PipelineContext, + run_id: str, + all_results: dict[str, NodeResult], + pre_completed: set[str], + sequence_start: int, + pipeline_span: Any, + ) -> PipelineResult: + """Sequential frontier-following scheduler for cyclic DAGs. + + Walks the graph one node at a time, picking the next node from each + completed node's alive outgoing edges. Visit counts are tracked per + node and bounded by ``self._recursion_limit``. Within this mode, + having multiple alive outgoing edges is currently an error — parallel + cyclic fan-out is the job of :class:`Send` in a later layer. + """ + trace_entries: list[ExecutionTraceEntry] = [] + pipeline_start = time.perf_counter() + visit_counts: dict[str, int] = dict.fromkeys(pre_completed, 1) + sequence = sequence_start + + # Entry node: insertion order, matching StatePipeline. + nodes_in_order = list(self._dag.nodes) + if not nodes_in_order: + raise PipelineError("Pipeline has no nodes") + current: str | None = nodes_in_order[0] + # Skip past anything already completed during this resumed run. + while current is not None and current in pre_completed: + current = self._cyclic_next(current, context) + + result: PipelineResult | None = None + try: + while current is not None: + visit_counts[current] = visit_counts.get(current, 0) + 1 + visit_n = visit_counts[current] + if visit_n > self._recursion_limit: + msg = ( + f"Recursion limit ({self._recursion_limit}) exceeded at node " + f"'{current}'. Raise recursion_limit= or fix the routing logic." + ) + logger.error(msg) + nr = NodeResult(node_id=current, success=False, error=msg) + all_results[current] = nr + sequence += 1 + self._record_audit( + run_id=run_id, + node_id=current, + sequence=sequence, + nr=nr, + inputs_snapshot={}, + trace_entries=trace_entries, + ) + break + + gathered = self._gather_inputs(current, context) + await self._dispatch( + "on_node_start", + pipeline_name=self._dag.name, + run_id=run_id, + node_id=current, + visit=visit_n, + ) + nr = await self._execute_node( + current, + context, + trace_entries, + None, + inputs=gathered, + run_id=run_id, + ) + all_results[current] = nr + context.set_node_result(current, nr) + await self._emit_node_result(nr, run_id) + + sequence += 1 + self._record_audit( + run_id=run_id, + node_id=current, + sequence=sequence, + nr=nr, + inputs_snapshot=gathered, + trace_entries=trace_entries, + visit=visit_n, + ) + + if not nr.success and not nr.skipped: + break + + if not nr.skipped: + if self._state_schema is not None and context.state is not None and isinstance(nr.output, dict): + context.state = apply_update(context.state, nr.output, self._reducers) + self._save_checkpoint( + run_id=run_id, + node_id=current, + sequence=sequence, + context=context, + all_results=all_results, + ) + + current = self._cyclic_next(current, context) + finally: + elapsed = (time.perf_counter() - pipeline_start) * 1000 + success = ( + result.success if result is not None else all(r.success or r.skipped for r in all_results.values()) + ) + await self._dispatch( + "on_pipeline_complete", + pipeline_name=self._dag.name, + run_id=run_id, + success=success, + duration_ms=elapsed, + ) + if pipeline_span is not None: + with contextlib.suppress(Exception): + pipeline_span.end() + + final_output = None + return PipelineResult( + pipeline_name=self._dag.name, + outputs=all_results, + final_output=final_output, + execution_trace=trace_entries, + total_duration_ms=elapsed, + success=success, + usage=None, + run_id=run_id, + final_state=context.state, + ) + + def _cyclic_next(self, current: str, context: PipelineContext) -> str | None: + """Pick the next node by following the unique alive outgoing edge. + + Returns ``None`` when no outgoing edge is alive (terminus). Raises + if more than one is alive — parallel cyclic fan-out lands with + :class:`Send` in a later layer. + """ + + def _alive(edge: Any) -> bool: + if edge.condition is None: + return True + try: + return bool(edge.condition(context)) + except Exception: + return False + + outgoing = [e for e in self._dag.edges if e.source == current] + alive = [e for e in outgoing if _alive(e)] + if not alive: + return None + if len(alive) > 1: + raise PipelineError( + f"Cyclic node '{current}' has multiple alive outgoing edges " + f"({[e.target for e in alive]}). Parallel cyclic fan-out arrives " + f"with Send in a later layer; for now, use mutually exclusive " + f"edge conditions." + ) + return alive[0].target + async def _execute_node( self, node_id: str, @@ -748,11 +925,13 @@ def _record_audit( nr: NodeResult, inputs_snapshot: dict[str, Any], trace_entries: list[ExecutionTraceEntry], + visit: int = 1, ) -> None: """Write an audit entry for a node visit. No-op if no audit log. Skipped nodes are not recorded — they represent work that did NOT - happen and would clutter the trail. + happen and would clutter the trail. ``visit`` defaults to 1 for the + acyclic scheduler and is supplied by the cyclic scheduler. """ if self._audit_log is None or nr.skipped: return @@ -770,7 +949,7 @@ def _record_audit( run_id=run_id, node_id=node_id, sequence=sequence, - visit=1, + visit=visit, started_at=started_at, completed_at=completed_at, latency_ms=nr.latency_ms or 0.0, diff --git a/tests/unit/pipeline/test_pipeline_engine_cycles.py b/tests/unit/pipeline/test_pipeline_engine_cycles.py new file mode 100644 index 00000000..d45b1c16 --- /dev/null +++ b/tests/unit/pipeline/test_pipeline_engine_cycles.py @@ -0,0 +1,167 @@ +"""Layer 4 of the unification (#245): cycle-aware scheduler. + +PipelineEngine accepts ``recursion_limit=`` and, when the DAG is cyclic +(allow_cycles=True and a cycle is actually present), switches to a +sequential frontier-following scheduler. Each node visit increments a +per-node counter; exceeding ``recursion_limit`` halts the run with an +explanatory failure. + +This also patches the silent-corruption hazard in :meth:`DAG.topological_sort` +and :meth:`DAG.execution_levels` — both now raise on cyclic DAGs instead +of producing partial / wrong output. +""" + +from __future__ import annotations + +from typing import Annotated + +import pytest +from pydantic import BaseModel + +from fireflyframework_agentic.exceptions import PipelineError +from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode +from fireflyframework_agentic.pipeline.engine import PipelineEngine +from fireflyframework_agentic.pipeline.reducers import append + +# ---- topology-API safety --------------------------------------------------- + + +def test_topological_sort_raises_on_cyclic_dag(): + dag = DAG("cyclic", allow_cycles=True) + dag.add_node(DAGNode(node_id="a", step=None)) + dag.add_node(DAGNode(node_id="b", step=None)) + dag.add_edge(DAGEdge(source="a", target="b")) + dag.add_edge(DAGEdge(source="b", target="a")) + with pytest.raises(PipelineError, match="cyclic"): + dag.topological_sort() + + +def test_execution_levels_raises_on_cyclic_dag(): + dag = DAG("cyclic-lev", allow_cycles=True) + dag.add_node(DAGNode(node_id="a", step=None)) + dag.add_node(DAGNode(node_id="b", step=None)) + dag.add_edge(DAGEdge(source="a", target="b")) + dag.add_edge(DAGEdge(source="b", target="a")) + with pytest.raises(PipelineError, match="cyclic"): + dag.execution_levels() + + +# ---- cyclic execution ------------------------------------------------------ + + +class _CounterState(BaseModel): + counter: int = 0 + log: Annotated[list[str], append] = [] + + +def _bump(label: str, by: int = 1): + """Return a step that records its label and bumps counter by `by`.""" + + class _Step: + def __init__(self): + self.calls = 0 + + async def execute(self, ctx, inputs): + self.calls += 1 + return {"counter": ctx.state.counter + by, "log": label} + + return _Step() + + +async def test_cyclic_dag_loops_until_condition_fails(): + """Loop: incrementer -> guard. Guard's outgoing edge back to incrementer + is alive while counter < 3. Loop exits when guard's continue edge dies.""" + inc = _bump("inc", by=1) + # guard is a no-op pass-through. + + class _Pass: + calls = 0 + + async def execute(self, ctx, inputs): + self.calls += 1 + return None + + guard = _Pass() + dag = DAG("loop", allow_cycles=True) + dag.add_node(DAGNode(node_id="inc", step=inc)) + dag.add_node(DAGNode(node_id="guard", step=guard)) + dag.add_edge(DAGEdge(source="inc", target="guard")) + # Continue edge: re-enter inc while counter < 3. + dag.add_edge(DAGEdge(source="guard", target="inc", condition=lambda ctx: ctx.state.counter < 3)) + engine = PipelineEngine(dag, state_schema=_CounterState, recursion_limit=10) + result = await engine.run(inputs="") + assert result.success + assert result.final_state.counter == 3 + assert inc.calls == 3 + # guard runs after each inc. + assert guard.calls == 3 + + +async def test_recursion_limit_halts_runaway_cycle(): + inc = _bump("inc") + + class _Pass: + async def execute(self, ctx, inputs): + return None + + dag = DAG("infinite", allow_cycles=True) + dag.add_node(DAGNode(node_id="inc", step=inc)) + dag.add_node(DAGNode(node_id="guard", step=_Pass())) + dag.add_edge(DAGEdge(source="inc", target="guard")) + dag.add_edge(DAGEdge(source="guard", target="inc")) # always alive — runaway + engine = PipelineEngine(dag, state_schema=_CounterState, recursion_limit=5) + result = await engine.run(inputs="") + assert not result.success + assert ( + "recursion" in (result.outputs.get("inc") and result.outputs["inc"].error or "").lower() + or "recursion" in (result.outputs.get("guard") and result.outputs["guard"].error or "").lower() + ) + + +async def test_recursion_limit_default_is_25(): + """The engine's default recursion_limit matches StatePipeline's (25).""" + engine = PipelineEngine(DAG("x")) + assert engine._recursion_limit == 25 # noqa: SLF001 + + +async def test_audit_records_visit_per_iteration(tmp_path): + """Each iteration of a cycle gets its own audit entry with incrementing visit.""" + from fireflyframework_agentic.pipeline.audit import FileAuditLog + + inc = _bump("inc") + + class _Pass: + async def execute(self, ctx, inputs): + return None + + dag = DAG("audited-loop", allow_cycles=True) + dag.add_node(DAGNode(node_id="inc", step=inc)) + dag.add_node(DAGNode(node_id="guard", step=_Pass())) + dag.add_edge(DAGEdge(source="inc", target="guard")) + dag.add_edge(DAGEdge(source="guard", target="inc", condition=lambda ctx: ctx.state.counter < 2)) + al = FileAuditLog(tmp_path) + engine = PipelineEngine(dag, state_schema=_CounterState, audit_log=al, recursion_limit=10) + result = await engine.run(inputs="") + assert result.success + entries = al.list_entries("audited-loop", result.run_id) + inc_visits = sorted([e.visit for e in entries if e.node_id == "inc"]) + assert inc_visits == [1, 2] + + +# ---- acyclic still works --------------------------------------------------- + + +async def test_acyclic_dag_with_allow_cycles_true_runs_normally(): + """allow_cycles=True doesn't force cyclic mode if there are no cycles.""" + a = _bump("a") + b = _bump("b") + dag = DAG("ac", allow_cycles=True) + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_edge(DAGEdge(source="a", target="b")) + engine = PipelineEngine(dag, state_schema=_CounterState) + result = await engine.run(inputs="") + assert result.success + assert result.final_state.counter == 2 + assert a.calls == 1 + assert b.calls == 1