Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 13 additions & 21 deletions fireflyframework_agentic/pipeline/state_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ async def invoke(
* Mid-pipeline start: ``invoke(state=..., start_at=node)`` —
starts execution at ``node`` with the provided state.
"""
resumed_completed: list[str] = []
completed: list[str] = []

# Resume mode: load checkpoint, derive starting node from it.
if run_id is not None and state is None and start_at is None:
Expand All @@ -420,10 +420,9 @@ async def invoke(
f"(reason: {record.pause_reason!r}). Pass approve_pause=True to resume."
)
state = self._state_schema.model_validate(record.state)
resumed_completed = list(record.completed_nodes)
completed = list(record.completed_nodes)
# Resume at the successor of the last completed (or paused) node.
last = record.node_id
next_node = self._next_step(last, state)
next_node = self._next_step(record.node_id, state)
# Resume can't seamlessly continue mid-fan-out yet; treat fan-out as terminal here.
if isinstance(next_node, list):
raise PipelineError(
Expand All @@ -434,7 +433,7 @@ async def invoke(
return StatePipelineResult(
state=state,
run_id=run_id,
completed_nodes=resumed_completed,
completed_nodes=completed,
success=True,
)
current_node: str | None = next_node
Expand All @@ -458,8 +457,6 @@ async def invoke(
run_id = uuid.uuid4().hex[:12]

assert state is not None # narrowed by the branches above
completed: list[str] = list(resumed_completed)
sequence = len(completed)
visit_counts: dict[str, int] = {}

next_step: str | list[Send] | None = current_node
Expand All @@ -475,12 +472,11 @@ async def invoke(
# --- fan-out branch (list[Send]) ---------------------------------
if isinstance(next_step, list):
try:
state, sequence = await self._run_fanout(
state = await self._run_fanout(
sends=next_step,
state=state,
completed=completed,
run_id=run_id,
sequence=sequence,
visit_counts=visit_counts,
)
except _NodeFailureError as fail:
Expand Down Expand Up @@ -539,7 +535,7 @@ async def invoke(
self._audit(
run_id=run_id,
node_id=node_id,
sequence=sequence + 1,
sequence=len(completed) + 1,
visit=visit_n,
started_at=started_at,
completed_at=datetime.now(UTC),
Expand Down Expand Up @@ -570,11 +566,10 @@ async def invoke(
pause_reason = update.reason
await self._emit("on_node_pause", self._name, run_id, node_id, pause_reason)
completed.append(node_id)
sequence += 1
self._save_checkpoint(
run_id,
node_id,
sequence,
len(completed),
state,
completed,
paused=True,
Expand All @@ -583,7 +578,7 @@ async def invoke(
self._audit(
run_id=run_id,
node_id=node_id,
sequence=sequence,
sequence=len(completed),
visit=visit_n,
started_at=started_at,
completed_at=completed_at,
Expand Down Expand Up @@ -612,12 +607,11 @@ async def invoke(
state = apply_update(state, update, self._reducers)

completed.append(node_id)
sequence += 1
self._save_checkpoint(run_id, node_id, sequence, state, completed)
self._save_checkpoint(run_id, node_id, len(completed), state, completed)
self._audit(
run_id=run_id,
node_id=node_id,
sequence=sequence,
sequence=len(completed),
visit=visit_n,
started_at=started_at,
completed_at=completed_at,
Expand Down Expand Up @@ -664,9 +658,8 @@ async def _run_fanout(
state: BaseModel,
completed: list[str],
run_id: str,
sequence: int,
visit_counts: dict[str, int],
) -> tuple[BaseModel, int]:
) -> BaseModel:
"""Run all ``Send`` dispatches concurrently. Each task gets its own state
copy with the Send's payload merged in; results are reduced into shared state.
"""
Expand Down Expand Up @@ -719,10 +712,9 @@ async def _run_one(send: Send, visit_n: int) -> tuple[Send, dict[str, Any] | Non
if update:
new_state = apply_update(new_state, update, self._reducers)
completed.append(send.target)
sequence += 1
self._save_checkpoint(run_id, send.target, sequence, new_state, completed)
self._save_checkpoint(run_id, send.target, len(completed), new_state, completed)

return new_state, sequence
return new_state

def _save_checkpoint(
self,
Expand Down