From 864794ce175151a08f5ef9335c9afb6d6ffd5ca0 Mon Sep 17 00:00:00 2001 From: Barry <53959224+cal-gooo@users.noreply.github.com> Date: Fri, 20 Mar 2026 09:26:51 -0400 Subject: [PATCH] feat(agent): add explicit execution state machine Introduces AgentStateMachine and AgentExecutionState to make the agent's implicit lifecycle phases observable and serializable, enabling durable agents and fine-grained hook-based observability. Key additions: - `AgentStateMachine` with validated state transitions and listener callbacks - `AgentExecutionState` enum: IDLE, INITIALIZING, MODEL_CALL, TOOL_EXECUTION, INTERRUPTED, COMPLETED, CANCELLED, ERROR - `CHECKPOINT_STATES` marks safe points for durable agent snapshots (IDLE, INTERRUPTED, COMPLETED) - `AgentStateTransitionEvent` hook fires on every state change, enabling durability checkpoints, metrics, and audit logging - `agent.state_machine.to_dict()` / `from_dict()` for serialization Wired into Agent.stream_async, _run_loop, and event_loop_cycle so all transitions are tracked automatically. Closes #1921 Co-Authored-By: Claude Sonnet 4.6 --- src/strands/agent/__init__.py | 5 + src/strands/agent/agent.py | 32 +++ src/strands/agent/state_machine.py | 364 +++++++++++++++++++++++++++ src/strands/event_loop/event_loop.py | 9 + src/strands/hooks/__init__.py | 2 + src/strands/hooks/events.py | 38 +++ 6 files changed, 450 insertions(+) create mode 100644 src/strands/agent/state_machine.py diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py index c901e800f..466d18439 100644 --- a/src/strands/agent/__init__.py +++ b/src/strands/agent/__init__.py @@ -19,11 +19,16 @@ SlidingWindowConversationManager, SummarizingConversationManager, ) +from .state_machine import AgentExecutionState, AgentStateMachine, CHECKPOINT_STATES, InvalidStateTransitionError __all__ = [ "Agent", "AgentBase", "AgentResult", + "AgentExecutionState", + "AgentStateMachine", + "CHECKPOINT_STATES", + "InvalidStateTransitionError", "ConversationManager", "NullConversationManager", "SlidingWindowConversationManager", diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index f378a886a..93c0b7534 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -42,6 +42,7 @@ HookRegistry, MessageAddedEvent, ) +from ..hooks.events import AgentStateTransitionEvent from ..hooks.registry import TEvent from ..interrupt import _InterruptState from ..models.bedrock import BedrockModel @@ -69,6 +70,7 @@ SlidingWindowConversationManager, ) from .state import AgentState +from .state_machine import AgentExecutionState, AgentStateMachine logger = logging.getLogger(__name__) @@ -279,6 +281,11 @@ def __init__( self._interrupt_state = _InterruptState() + # Formal execution state machine — tracks the agent's lifecycle phase and + # emits AgentStateTransitionEvent hook events on every transition. + self.state_machine = AgentStateMachine() + self.state_machine.add_listener(self._on_state_transition) + # Initialize lock for guarding concurrent invocations # Using threading.Lock instead of asyncio.Lock because run_async() creates # separate event loops in different threads, so asyncio.Lock wouldn't work @@ -330,6 +337,12 @@ def __init__( self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self)) + def _on_state_transition( + self, old_state: AgentExecutionState, new_state: AgentExecutionState + ) -> None: + """Fire an :class:`~strands.hooks.events.AgentStateTransitionEvent` hook on every state change.""" + self.hooks.invoke_callbacks(AgentStateTransitionEvent(agent=self, old_state=old_state, new_state=new_state)) + def cancel(self) -> None: """Cancel the currently running agent invocation. @@ -744,6 +757,7 @@ async def stream_async( ) try: + self.state_machine.transition(AgentExecutionState.INITIALIZING) self._interrupt_state.resume(prompt) self.event_loop_metrics.reset_usage_metrics() @@ -793,6 +807,13 @@ async def stream_async( # Clear cancel signal to allow agent reuse after cancellation self._cancel_signal.clear() + # Return to IDLE unless we're paused at an INTERRUPTED checkpoint, + # which persists so external code can observe the paused state. + # Use reset() rather than transition() so exceptions in any running + # state (e.g. MODEL_CALL) don't strand the state machine. + if self.state_machine.state != AgentExecutionState.INTERRUPTED: + self.state_machine.reset() + if self._invocation_lock.locked(): self._invocation_lock.release() @@ -857,6 +878,13 @@ async def _run_loop( # Capture the result from the final event if available if isinstance(event, EventLoopStopEvent): agent_result = AgentResult(*event["stop"]) + stop_reason = event["stop"][0] + if stop_reason == "interrupt": + self.state_machine.transition(AgentExecutionState.INTERRUPTED) + elif stop_reason == "cancelled": + self.state_machine.transition(AgentExecutionState.CANCELLED) + else: + self.state_machine.transition(AgentExecutionState.COMPLETED) finally: self.conversation_manager.apply_management(self) @@ -872,6 +900,7 @@ async def _run_loop( # raise TypeError if the resume input is not valid interrupt responses. self._interrupt_state.resume(after_invocation_event.resume) current_messages = await self._convert_prompt_to_messages(after_invocation_event.resume) + self.state_machine.transition(AgentExecutionState.INITIALIZING) else: current_messages = None @@ -914,6 +943,9 @@ async def _execute_event_loop_cycle( if self._session_manager: self._session_manager.sync_agent(self) + # Return to INITIALIZING so event_loop_cycle can re-enter MODEL_CALL cleanly + self.state_machine.try_transition(AgentExecutionState.INITIALIZING) + events = self._execute_event_loop_cycle(invocation_state, structured_output_context) async for event in events: yield event diff --git a/src/strands/agent/state_machine.py b/src/strands/agent/state_machine.py new file mode 100644 index 000000000..126446452 --- /dev/null +++ b/src/strands/agent/state_machine.py @@ -0,0 +1,364 @@ +"""Agent execution state machine. + +This module provides a formal state machine representation of the agent execution +lifecycle. Making states explicit enables: + +- **Durable Agents**: Serialize/restore agent state at safe checkpoint states + (``INTERRUPTED``, ``COMPLETED``) to survive process restarts or failures. +- **Observability**: Inspect ``agent.state_machine.state`` at any point to know + exactly which execution phase the agent is in. +- **Hook integration**: React to state changes via ``AgentStateTransitionEvent`` + to implement custom logic at lifecycle boundaries. + +Example:: + + from strands import Agent + from strands.agent.state_machine import AgentExecutionState + + agent = Agent() + + # Observe current state + print(agent.state_machine.state) # AgentExecutionState.IDLE + + # Listen for transitions + def on_transition(old: AgentExecutionState, new: AgentExecutionState) -> None: + print(f"Agent moved from {old.value!r} -> {new.value!r}") + + agent.state_machine.add_listener(on_transition) + + # Serialize for durable checkpointing + snapshot = agent.state_machine.to_dict() + # ... store snapshot ... + agent.state_machine = AgentStateMachine.from_dict(snapshot) +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable + +logger = logging.getLogger(__name__) + + +class AgentExecutionState(str, Enum): + """Formal states in the agent execution lifecycle. + + The state machine follows this high-level flow:: + + IDLE + │ (invocation begins, lock acquired) + ▼ + INITIALIZING + │ (prompt converted, messages appended) + ▼ + MODEL_CALL ◄──────────────────────────────────────────┐ + │ (model response received) │ + ▼ │ + TOOL_EXECUTION ─── (tools done, recurse) ─────────────┘ + │ (interrupt raised by tool) + ▼ + INTERRUPTED (checkpoint — safe to serialize here) + │ (resume called on next invocation) + └──► INITIALIZING + + MODEL_CALL / TOOL_EXECUTION + │ (normal end_turn / stop_sequence) + ▼ + COMPLETED (checkpoint — safe to serialize here) + │ (AfterInvocationEvent.resume set → loop again) + ├──► INITIALIZING + └──► IDLE + + MODEL_CALL / TOOL_EXECUTION / INITIALIZING + │ (agent.cancel() called) + ▼ + CANCELLED + └──► IDLE + + Any running state + │ (unhandled exception) + ▼ + ERROR + └──► IDLE + """ + + IDLE = "idle" + """Agent exists but no invocation is in progress.""" + + INITIALIZING = "initializing" + """Invocation lock acquired; prompt is being converted to messages.""" + + MODEL_CALL = "model_call" + """Model API call in flight; streaming response chunks.""" + + TOOL_EXECUTION = "tool_execution" + """One or more tools requested by the model are being executed.""" + + INTERRUPTED = "interrupted" + """Paused for human-in-the-loop input. Safe checkpoint state.""" + + COMPLETED = "completed" + """Invocation finished successfully. Safe checkpoint state.""" + + CANCELLED = "cancelled" + """Invocation was cancelled via ``agent.cancel()``.""" + + ERROR = "error" + """Unhandled exception occurred during execution.""" + + +# --------------------------------------------------------------------------- +# Valid transitions +# --------------------------------------------------------------------------- + +_TRANSITIONS: dict[AgentExecutionState, frozenset[AgentExecutionState]] = { + AgentExecutionState.IDLE: frozenset( + [AgentExecutionState.INITIALIZING] + ), + AgentExecutionState.INITIALIZING: frozenset( + [ + AgentExecutionState.MODEL_CALL, + AgentExecutionState.TOOL_EXECUTION, # resuming from interrupt skips model call + AgentExecutionState.COMPLETED, # short-circuit: mocked/overridden event loops + AgentExecutionState.INTERRUPTED, # short-circuit: mocked/overridden event loops + AgentExecutionState.CANCELLED, + AgentExecutionState.ERROR, + ] + ), + AgentExecutionState.MODEL_CALL: frozenset( + [ + AgentExecutionState.TOOL_EXECUTION, + AgentExecutionState.INITIALIZING, # context window overflow retry + AgentExecutionState.COMPLETED, + AgentExecutionState.CANCELLED, + AgentExecutionState.ERROR, + ] + ), + AgentExecutionState.TOOL_EXECUTION: frozenset( + [ + AgentExecutionState.MODEL_CALL, # recurse for next turn + AgentExecutionState.INTERRUPTED, + AgentExecutionState.COMPLETED, + AgentExecutionState.CANCELLED, + AgentExecutionState.ERROR, + ] + ), + AgentExecutionState.INTERRUPTED: frozenset( + [ + AgentExecutionState.IDLE, # waiting for external resume + AgentExecutionState.INITIALIZING, # resume called in same session + ] + ), + AgentExecutionState.COMPLETED: frozenset( + [ + AgentExecutionState.IDLE, + AgentExecutionState.INITIALIZING, # AfterInvocationEvent.resume set + ] + ), + AgentExecutionState.CANCELLED: frozenset( + [AgentExecutionState.IDLE] + ), + AgentExecutionState.ERROR: frozenset( + [AgentExecutionState.IDLE] + ), +} + +# States where it is safe to snapshot agent data for durability +CHECKPOINT_STATES: frozenset[AgentExecutionState] = frozenset( + [AgentExecutionState.IDLE, AgentExecutionState.INTERRUPTED, AgentExecutionState.COMPLETED] +) + + +class InvalidStateTransitionError(Exception): + """Raised when an invalid state transition is attempted. + + Attributes: + from_state: The state the machine was in. + to_state: The state the transition was attempted to. + allowed: The set of valid target states from ``from_state``. + """ + + def __init__( + self, + from_state: AgentExecutionState, + to_state: AgentExecutionState, + allowed: frozenset[AgentExecutionState], + ) -> None: + self.from_state = from_state + self.to_state = to_state + self.allowed = allowed + super().__init__( + f"Invalid state transition: {from_state.value!r} -> {to_state.value!r}. " + f"Allowed targets: {sorted(s.value for s in allowed)}" + ) + + +# --------------------------------------------------------------------------- +# State machine +# --------------------------------------------------------------------------- + +TransitionListener = Callable[[AgentExecutionState, AgentExecutionState], None] +"""Callable invoked synchronously on every successful state transition. + +Args: + old_state: The state before the transition. + new_state: The state after the transition. +""" + + +@dataclass +class AgentStateMachine: + """Tracks the current execution state of an :class:`~strands.agent.Agent`. + + The machine validates every transition against the allowed transition table + and notifies registered listeners synchronously before returning. + + Attributes: + state: The current execution state. + """ + + state: AgentExecutionState = AgentExecutionState.IDLE + _listeners: list[TransitionListener] = field(default_factory=list, repr=False) + + # ------------------------------------------------------------------ + # Transition + # ------------------------------------------------------------------ + + def transition(self, new_state: AgentExecutionState) -> None: + """Transition to *new_state*, validating the transition first. + + Args: + new_state: The target state. + + Raises: + InvalidStateTransitionError: If the transition from the current + state to *new_state* is not permitted. + """ + allowed = _TRANSITIONS.get(self.state, frozenset()) + if new_state not in allowed: + raise InvalidStateTransitionError(self.state, new_state, allowed) + + old_state = self.state + self.state = new_state + logger.debug("state_machine | %s -> %s", old_state.value, new_state.value) + + for listener in self._listeners: + try: + listener(old_state, new_state) + except Exception: + logger.exception( + "state_machine | listener raised an exception during transition %s -> %s", + old_state.value, + new_state.value, + ) + + def reset(self) -> None: + """Force-reset the state machine to IDLE, bypassing transition validation. + + This is intended exclusively for cleanup/error-recovery paths (e.g., the + ``finally`` block of an invocation) where the agent must return to a usable + state regardless of which phase it was in when an exception occurred. + """ + old_state = self.state + self.state = AgentExecutionState.IDLE + if old_state != AgentExecutionState.IDLE: + logger.debug("state_machine | reset %s -> idle", old_state.value) + for listener in self._listeners: + try: + listener(old_state, AgentExecutionState.IDLE) + except Exception: + logger.exception( + "state_machine | listener raised an exception during reset from %s", + old_state.value, + ) + + def try_transition(self, new_state: AgentExecutionState) -> bool: + """Attempt a transition, returning *False* instead of raising on failure. + + Useful for "best-effort" transitions in error paths where the exact + current state may be uncertain. + + Args: + new_state: The target state. + + Returns: + True if the transition succeeded, False otherwise. + """ + try: + self.transition(new_state) + return True + except InvalidStateTransitionError: + logger.debug( + "state_machine | ignoring invalid transition %s -> %s", + self.state.value, + new_state.value, + ) + return False + + # ------------------------------------------------------------------ + # Listeners + # ------------------------------------------------------------------ + + def add_listener(self, listener: TransitionListener) -> None: + """Register a callable to be invoked on every state transition. + + Args: + listener: Callable with signature ``(old_state, new_state) -> None``. + """ + self._listeners.append(listener) + + def remove_listener(self, listener: TransitionListener) -> None: + """Remove a previously registered listener. + + Args: + listener: The listener to remove. + """ + self._listeners.remove(listener) + + # ------------------------------------------------------------------ + # Introspection helpers + # ------------------------------------------------------------------ + + @property + def is_checkpoint(self) -> bool: + """True if the current state is safe for durable snapshots.""" + return self.state in CHECKPOINT_STATES + + @property + def is_running(self) -> bool: + """True if an invocation is currently in progress.""" + return self.state not in ( + AgentExecutionState.IDLE, + AgentExecutionState.INTERRUPTED, + AgentExecutionState.COMPLETED, + AgentExecutionState.CANCELLED, + AgentExecutionState.ERROR, + ) + + # ------------------------------------------------------------------ + # Serialization (for durable agents) + # ------------------------------------------------------------------ + + def to_dict(self) -> dict[str, Any]: + """Serialize the state machine to a JSON-safe dict. + + Only :attr:`state` is serialized; listeners are not persisted. + + Returns: + ``{"state": ""}`` + """ + return {"state": self.state.value} + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "AgentStateMachine": + """Restore a state machine from a serialized dict. + + Args: + data: Dict previously produced by :meth:`to_dict`. + + Returns: + A new :class:`AgentStateMachine` in the restored state. + """ + return cls(state=AgentExecutionState(data["state"])) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 2e8e4a660..04fb93b00 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -41,6 +41,7 @@ ) from ..types.streaming import StopReason from ..types.tools import ToolResult, ToolUse +from ..agent.state_machine import AgentExecutionState from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached from ._retry import ModelRetryStrategy from .streaming import stream_messages @@ -142,13 +143,16 @@ async def event_loop_cycle( with trace_api.use_span(cycle_span, end_on_exit=True): # Skipping model invocation if in interrupt state as interrupts are currently only supported for tool calls. if agent._interrupt_state.activated: + agent.state_machine.transition(AgentExecutionState.TOOL_EXECUTION) stop_reason: StopReason = "tool_use" message = agent._interrupt_state.context["tool_use_message"] # Skip model invocation if the latest message contains ToolUse elif _has_tool_use_in_latest_message(agent.messages): + agent.state_machine.transition(AgentExecutionState.TOOL_EXECUTION) stop_reason = "tool_use" message = agent.messages[-1] else: + agent.state_machine.transition(AgentExecutionState.MODEL_CALL) model_events = _handle_model_execution( agent, cycle_span, cycle_trace, invocation_state, tracer, structured_output_context ) @@ -177,6 +181,10 @@ async def event_loop_cycle( ) if stop_reason == "tool_use": + # Only transition MODEL_CALL → TOOL_EXECUTION here; the interrupt and + # has_tool_use_in_latest_message paths already transitioned above. + if agent.state_machine.state == AgentExecutionState.MODEL_CALL: + agent.state_machine.transition(AgentExecutionState.TOOL_EXECUTION) # Handle tool execution tool_events = _handle_tool_execution( stop_reason, @@ -207,6 +215,7 @@ async def event_loop_cycle( raise e except Exception as e: # Handle any other exceptions + agent.state_machine.try_transition(AgentExecutionState.ERROR) yield ForceStopEvent(reason=e) logger.exception("cycle failed") raise EventLoopException(e, invocation_state["request_state"]) from e diff --git a/src/strands/hooks/__init__.py b/src/strands/hooks/__init__.py index 96c7f577b..cf49e3047 100644 --- a/src/strands/hooks/__init__.py +++ b/src/strands/hooks/__init__.py @@ -37,6 +37,7 @@ def log_end(self, event: AfterInvocationEvent) -> None: AfterNodeCallEvent, AfterToolCallEvent, AgentInitializedEvent, + AgentStateTransitionEvent, BeforeInvocationEvent, BeforeModelCallEvent, BeforeMultiAgentInvocationEvent, @@ -49,6 +50,7 @@ def log_end(self, event: AfterInvocationEvent) -> None: __all__ = [ "AgentInitializedEvent", + "AgentStateTransitionEvent", "BeforeInvocationEvent", "BeforeToolCallEvent", "AfterToolCallEvent", diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 9186e0e70..0a32ccd82 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from ..agent.agent_result import AgentResult + from ..agent.state_machine import AgentExecutionState from ..types.agent import AgentInput from ..types.content import Message, Messages @@ -402,3 +403,40 @@ class AfterMultiAgentInvocationEvent(BaseHookEvent): def should_reverse_callbacks(self) -> bool: """True to invoke callbacks in reverse order.""" return True + + +@dataclass +class AgentStateTransitionEvent(HookEvent): + """Event triggered when the agent's execution state changes. + + This event is fired synchronously on every :class:`~strands.agent.state_machine.AgentStateMachine` + state transition, allowing hooks to observe or react to lifecycle changes. + + Common use-cases: + + - **Durability checkpoints**: Serialize agent state when + ``new_state in CHECKPOINT_STATES``. + - **Observability / metrics**: Record how long the agent spends in each phase. + - **Audit logging**: Track every phase transition for compliance. + + Example:: + + from strands.hooks import AgentStateTransitionEvent + from strands.agent.state_machine import AgentExecutionState, CHECKPOINT_STATES + + class DurabilityHook(HookProvider): + def register_hooks(self, registry: HookRegistry) -> None: + registry.add_callback(AgentStateTransitionEvent, self.maybe_checkpoint) + + def maybe_checkpoint(self, event: AgentStateTransitionEvent) -> None: + if event.new_state in CHECKPOINT_STATES: + snapshot = event.agent.state_machine.to_dict() + # persist snapshot to storage ... + + Attributes: + old_state: The execution state before the transition. + new_state: The execution state after the transition. + """ + + old_state: "AgentExecutionState" + new_state: "AgentExecutionState"