Skip to content
135 changes: 133 additions & 2 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")`
"""

import asyncio
import logging
import threading
import warnings
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Mapping
from dataclasses import dataclass, field
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -60,7 +62,7 @@
from ..types._events import AgentResultEvent, EventLoopStopEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent
from ..types.agent import AgentInput, ConcurrentInvocationMode
from ..types.content import ContentBlock, Message, Messages, SystemContentBlock
from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException
from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException, IdempotencyAbortedError
from ..types.traces import AttributeValue
from .agent_result import AgentResult
from .base import AgentBase
Expand Down Expand Up @@ -95,6 +97,21 @@ class _DefaultRetryStrategySentinel:
_DEFAULT_AGENT_ID = "default"


@dataclass
class _InflightInvocation:
"""Tracks an inflight invocation for idempotency deduplication.

When a caller provides an `idempotency_token`, the agent registers this invocation
(in THROW mode only one can be inflight at a time). If a duplicate call arrives
with the same token while the original is still running, the duplicate waits on
the `done` event and receives the same result or error.
"""

done: threading.Event = field(default_factory=threading.Event)
result: AgentResult | None = None
error: BaseException | None = None


class Agent(AgentBase):
"""Core Agent implementation.

Expand Down Expand Up @@ -285,6 +302,14 @@ def __init__(
self._invocation_lock = threading.Lock()
self._concurrent_invocation_mode = concurrent_invocation_mode

# Tracks the single inflight invocation for idempotency duplicate detection.
# In THROW mode only one invocation can be inflight at a time, so a single
# variable suffices. Uses threading primitives (not asyncio) because run_async()
# creates separate threads with separate event loops.
self._inflight_idempotency_token: Any = None
self._inflight_invocation: _InflightInvocation | None = None
self._inflight_invocations_lock = threading.Lock()

# In the future, we'll have a RetryStrategy base class but until
# that API is determined we only allow ModelRetryStrategy
if (
Expand Down Expand Up @@ -422,6 +447,7 @@ def __call__(
invocation_state: dict[str, Any] | None = None,
structured_output_model: type[BaseModel] | None = None,
structured_output_prompt: str | None = None,
idempotency_token: Any = None,
**kwargs: Any,
) -> AgentResult:
"""Process a natural language prompt through the agent's event loop.
Expand All @@ -441,6 +467,11 @@ def __call__(
invocation_state: Additional parameters to pass through the event loop.
structured_output_model: Pydantic model type(s) for structured output (overrides agent default).
structured_output_prompt: Custom prompt for forcing structured output (overrides agent default).
idempotency_token: Optional token for duplicate request detection. If provided in THROW mode
and another invocation with the same token is already inflight, the caller waits for the
original to complete and receives the same result. Duplicate callers receive only the
final AgentResult; intermediate streaming events are not replayed. Can be any hashable
object (string, UUID, or even the prompt itself). Ignored in UNSAFE_REENTRANT mode.
**kwargs: Additional parameters to pass through the event loop.[Deprecating]

Returns:
Expand All @@ -458,6 +489,7 @@ def __call__(
invocation_state=invocation_state,
structured_output_model=structured_output_model,
structured_output_prompt=structured_output_prompt,
idempotency_token=idempotency_token,
**kwargs,
)
)
Expand All @@ -469,6 +501,7 @@ async def invoke_async(
invocation_state: dict[str, Any] | None = None,
structured_output_model: type[BaseModel] | None = None,
structured_output_prompt: str | None = None,
idempotency_token: Any = None,
**kwargs: Any,
) -> AgentResult:
"""Process a natural language prompt through the agent's event loop.
Expand All @@ -488,6 +521,11 @@ async def invoke_async(
invocation_state: Additional parameters to pass through the event loop.
structured_output_model: Pydantic model type(s) for structured output (overrides agent default).
structured_output_prompt: Custom prompt for forcing structured output (overrides agent default).
idempotency_token: Optional token for duplicate request detection. If provided in THROW mode
and another invocation with the same token is already inflight, the caller waits for the
original to complete and receives the same result. Duplicate callers receive only the
final AgentResult; intermediate streaming events are not replayed. Can be any hashable
object (string, UUID, or even the prompt itself). Ignored in UNSAFE_REENTRANT mode.
**kwargs: Additional parameters to pass through the event loop.[Deprecating]

Returns:
Expand All @@ -503,6 +541,7 @@ async def invoke_async(
invocation_state=invocation_state,
structured_output_model=structured_output_model,
structured_output_prompt=structured_output_prompt,
idempotency_token=idempotency_token,
**kwargs,
)
async for event in events:
Expand Down Expand Up @@ -685,13 +724,82 @@ def __del__(self) -> None:
if hasattr(self, "tool_registry"):
self.tool_registry.cleanup()

def _check_idempotency(self, idempotency_token: Any) -> tuple[_InflightInvocation | None, Any]:
"""Check if this invocation is a duplicate of an inflight one, or register it as new.

Only active in THROW mode. In UNSAFE_REENTRANT mode or when no token is provided,
this is a no-op that returns (None, None).

Args:
idempotency_token: Caller-provided token for duplicate detection.

Returns:
A tuple of (waiting_on, registered_token):
- If duplicate: (inflight_invocation_to_wait_on, None)
- If new request: (None, the_registered_token)
- If no token or wrong mode: (None, None)
"""
if idempotency_token is None or self._concurrent_invocation_mode != ConcurrentInvocationMode.THROW:
return None, None

with self._inflight_invocations_lock:
if self._inflight_idempotency_token == idempotency_token:
return self._inflight_invocation, None
elif self._inflight_idempotency_token is not None:
# A different token is already inflight; don't overwrite it.
# Fall through to the _invocation_lock check which will raise ConcurrencyException.
return None, None
else:
self._inflight_invocation = _InflightInvocation()
self._inflight_idempotency_token = idempotency_token
return None, idempotency_token

def _complete_idempotent_invocation(
self,
registered_token: Any,
result: AgentResult | None = None,
error: BaseException | None = None,
) -> None:
"""Signal waiting duplicates and clean up idempotency state.

Safe to call even when registered_token is None (no-op in that case).
If both result and error are None (e.g. primary lost a lock race or was cancelled),
sets IdempotencyAbortedError so duplicates receive a clear error.

Args:
registered_token: The token that was registered by _check_idempotency, or None.
result: The AgentResult to pass to waiting duplicates (success path).
error: The exception to pass to waiting duplicates (error path).
"""
if registered_token is None:
return

with self._inflight_invocations_lock:
if self._inflight_idempotency_token != registered_token:
return # Another invocation owns the slot; don't touch it.
inflight = self._inflight_invocation
self._inflight_idempotency_token = None
self._inflight_invocation = None

if inflight is None:
return

if error is not None:
inflight.error = error
elif result is not None:
inflight.result = result
else:
inflight.error = IdempotencyAbortedError("Primary invocation was aborted before producing a result.")
inflight.done.set()

async def stream_async(
self,
prompt: AgentInput = None,
*,
invocation_state: dict[str, Any] | None = None,
structured_output_model: type[BaseModel] | None = None,
structured_output_prompt: str | None = None,
idempotency_token: Any = None,
**kwargs: Any,
) -> AsyncIterator[Any]:
"""Process a natural language prompt and yield events as an async iterator.
Expand All @@ -711,6 +819,11 @@ async def stream_async(
invocation_state: Additional parameters to pass through the event loop.
structured_output_model: Pydantic model type(s) for structured output (overrides agent default).
structured_output_prompt: Custom prompt for forcing structured output (overrides agent default).
idempotency_token: Optional token for duplicate request detection. If provided in THROW mode
and another invocation with the same token is already inflight, the caller waits for the
original to complete and receives the same result. Duplicate callers receive only the
final AgentResult; intermediate streaming events are not replayed. Can be any hashable
object (string, UUID, or even the prompt itself). Ignored in UNSAFE_REENTRANT mode.
**kwargs: Additional parameters to pass to the event loop.[Deprecating]

Yields:
Expand All @@ -733,15 +846,30 @@ async def stream_async(
yield event["data"]
```
"""
waiting_on, registered_token = self._check_idempotency(idempotency_token)

if waiting_on is not None:
logger.debug("idempotency_token=<%s> | duplicate request detected, waiting for original", idempotency_token)
await asyncio.to_thread(waiting_on.done.wait)
if waiting_on.error is not None:
raise waiting_on.error
if waiting_on.result is not None:
yield AgentResultEvent(result=waiting_on.result).as_dict()
return

# Conditionally acquire lock based on concurrent_invocation_mode
# Using threading.Lock instead of asyncio.Lock because run_async() creates
# separate event loops in different threads
if self._concurrent_invocation_mode == ConcurrentInvocationMode.THROW:
lock_acquired = self._invocation_lock.acquire(blocking=False)
if not lock_acquired:
raise ConcurrencyException(
exc = ConcurrencyException(
"Agent is already processing a request. Concurrent invocations are not supported."
)
self._complete_idempotent_invocation(registered_token, error=exc)
raise exc

result: AgentResult | None = None

try:
self._interrupt_state.resume(prompt)
Expand Down Expand Up @@ -787,12 +915,15 @@ async def stream_async(

except Exception as e:
self._end_agent_trace_span(error=e)
self._complete_idempotent_invocation(registered_token, error=e)
raise

finally:
# Clear cancel signal to allow agent reuse after cancellation
self._cancel_signal.clear()

self._complete_idempotent_invocation(registered_token, result=result)

if self._invocation_lock.locked():
self._invocation_lock.release()

Expand Down
12 changes: 12 additions & 0 deletions src/strands/types/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,15 @@ class ConcurrencyException(Exception):
"""

pass


class IdempotencyAbortedError(Exception):
"""Exception raised to duplicate invocations when the primary invocation was aborted.

When a caller provides an idempotency_token and another invocation with the same token
is already in-flight, the duplicate waits for the primary to complete. If the primary
is aborted before producing a result (e.g. it lost a lock race or was cancelled),
this exception is raised to all waiting duplicates.
"""

pass
Loading