@dataclass
class _InflightInvocation:
    """Shared outcome slot for one inflight invocation registered with an idempotency token.

    The primary invocation stores its outcome here and then sets `done`. A duplicate call
    that arrives with the same token while the primary is still running blocks on `done`
    and afterwards reads `result` or `error`, so it observes the same outcome as the primary.
    """

    # Set once the outcome fields below are final; duplicates block on this event.
    done: threading.Event = field(default_factory=threading.Event)
    # Result of the primary invocation; stays None if it failed or was aborted.
    result: AgentResult | None = field(default=None)
    # Exception handed to duplicates when the primary failed or was aborted.
    error: BaseException | None = field(default=None)
+ # In THROW mode only one invocation can be inflight at a time, so a single + # variable suffices. Uses threading primitives (not asyncio) because run_async() + # creates separate threads with separate event loops. + self._inflight_idempotency_token: Any = None + self._inflight_invocation: _InflightInvocation | None = None + self._inflight_invocations_lock = threading.Lock() + # In the future, we'll have a RetryStrategy base class but until # that API is determined we only allow ModelRetryStrategy if ( @@ -422,6 +447,7 @@ def __call__( invocation_state: dict[str, Any] | None = None, structured_output_model: type[BaseModel] | None = None, structured_output_prompt: str | None = None, + idempotency_token: Any = None, **kwargs: Any, ) -> AgentResult: """Process a natural language prompt through the agent's event loop. @@ -441,6 +467,11 @@ def __call__( invocation_state: Additional parameters to pass through the event loop. structured_output_model: Pydantic model type(s) for structured output (overrides agent default). structured_output_prompt: Custom prompt for forcing structured output (overrides agent default). + idempotency_token: Optional token for duplicate request detection. If provided in THROW mode + and another invocation with the same token is already inflight, the caller waits for the + original to complete and receives the same result. Duplicate callers receive only the + final AgentResult; intermediate streaming events are not replayed. Can be any hashable + object (string, UUID, or even the prompt itself). Ignored in UNSAFE_REENTRANT mode. 
**kwargs: Additional parameters to pass through the event loop.[Deprecating] Returns: @@ -458,6 +489,7 @@ def __call__( invocation_state=invocation_state, structured_output_model=structured_output_model, structured_output_prompt=structured_output_prompt, + idempotency_token=idempotency_token, **kwargs, ) ) @@ -469,6 +501,7 @@ async def invoke_async( invocation_state: dict[str, Any] | None = None, structured_output_model: type[BaseModel] | None = None, structured_output_prompt: str | None = None, + idempotency_token: Any = None, **kwargs: Any, ) -> AgentResult: """Process a natural language prompt through the agent's event loop. @@ -488,6 +521,11 @@ async def invoke_async( invocation_state: Additional parameters to pass through the event loop. structured_output_model: Pydantic model type(s) for structured output (overrides agent default). structured_output_prompt: Custom prompt for forcing structured output (overrides agent default). + idempotency_token: Optional token for duplicate request detection. If provided in THROW mode + and another invocation with the same token is already inflight, the caller waits for the + original to complete and receives the same result. Duplicate callers receive only the + final AgentResult; intermediate streaming events are not replayed. Can be any hashable + object (string, UUID, or even the prompt itself). Ignored in UNSAFE_REENTRANT mode. 
def _check_idempotency(self, idempotency_token: Any) -> tuple[_InflightInvocation | None, Any]:
    """Classify an invocation as a duplicate of the inflight one, or register it as the primary.

    Only active in THROW mode. In any other mode, or when no token is provided, this is a
    no-op that returns (None, None).

    Args:
        idempotency_token: Caller-provided token for duplicate detection.

    Returns:
        A tuple of (waiting_on, registration):
        - Duplicate (same token already inflight): (inflight_invocation_to_wait_on, None)
        - New request: (None, registration). The registration is an opaque handle that must be
          passed unchanged to `_complete_idempotent_invocation`.
        - No token, wrong mode, or a different token already inflight: (None, None)
    """
    if idempotency_token is None or self._concurrent_invocation_mode != ConcurrentInvocationMode.THROW:
        return None, None

    with self._inflight_invocations_lock:
        if self._inflight_idempotency_token == idempotency_token:
            return self._inflight_invocation, None
        if self._inflight_idempotency_token is not None:
            # A different token is already inflight; don't overwrite it.
            # The caller falls through to the _invocation_lock check, which raises ConcurrencyException.
            return None, None

        registration = _InflightInvocation()
        self._inflight_invocation = registration
        self._inflight_idempotency_token = idempotency_token
        # Hand back the invocation object itself rather than the token: completion can then
        # verify ownership by identity, which a later registration reusing the same token
        # value can never satisfy.
        return None, registration


def _complete_idempotent_invocation(
    self,
    registered_token: Any,
    result: AgentResult | None = None,
    error: BaseException | None = None,
) -> None:
    """Signal waiting duplicates and clean up idempotency state.

    Safe to call when registered_token is None, and safe to call more than once for the same
    registration: only the first call finds the slot still owned, later calls are no-ops.
    If both result and error are None (e.g. the primary lost a lock race or was cancelled),
    duplicates receive an IdempotencyAbortedError.

    Args:
        registered_token: The registration returned by `_check_idempotency`, or None.
            A raw token value is still accepted and matched by equality.
        result: The AgentResult to pass to waiting duplicates (success path).
        error: The exception to pass to waiting duplicates (error path).
    """
    if registered_token is None:
        return

    with self._inflight_invocations_lock:
        inflight = self._inflight_invocation
        if isinstance(registered_token, _InflightInvocation):
            # Identity check: a stale second completion (stream_async completes in both its
            # except and finally blocks) must not release a newer registration that happens
            # to reuse the same token value.
            owns_slot = inflight is registered_token
        else:
            # Backward-compatible path for callers that pass the raw token.
            owns_slot = self._inflight_idempotency_token == registered_token
        if not owns_slot:
            return  # Another invocation owns the slot; don't touch it.
        self._inflight_idempotency_token = None
        self._inflight_invocation = None

    if inflight is None:
        return

    # Publish the outcome before setting the event so woken duplicates always see it.
    if error is not None:
        inflight.error = error
    elif result is not None:
        inflight.result = result
    else:
        inflight.error = IdempotencyAbortedError("Primary invocation was aborted before producing a result.")
    inflight.done.set()
Duplicate callers receive only the + final AgentResult; intermediate streaming events are not replayed. Can be any hashable + object (string, UUID, or even the prompt itself). Ignored in UNSAFE_REENTRANT mode. **kwargs: Additional parameters to pass to the event loop.[Deprecating] Yields: @@ -733,15 +846,30 @@ async def stream_async( yield event["data"] ``` """ + waiting_on, registered_token = self._check_idempotency(idempotency_token) + + if waiting_on is not None: + logger.debug("idempotency_token=<%s> | duplicate request detected, waiting for original", idempotency_token) + await asyncio.to_thread(waiting_on.done.wait) + if waiting_on.error is not None: + raise waiting_on.error + if waiting_on.result is not None: + yield AgentResultEvent(result=waiting_on.result).as_dict() + return + # Conditionally acquire lock based on concurrent_invocation_mode # Using threading.Lock instead of asyncio.Lock because run_async() creates # separate event loops in different threads if self._concurrent_invocation_mode == ConcurrentInvocationMode.THROW: lock_acquired = self._invocation_lock.acquire(blocking=False) if not lock_acquired: - raise ConcurrencyException( + exc = ConcurrencyException( "Agent is already processing a request. Concurrent invocations are not supported." 
class IdempotencyAbortedError(Exception):
    """Raised to duplicate invocations whose primary invocation ended without an outcome.

    A call that supplies an idempotency_token while another call with the same token is
    in-flight waits for that primary call. When the primary finishes with neither a result
    nor an error of its own (e.g. it lost a lock race or was cancelled), every waiting
    duplicate receives this exception instead.
    """
class IdempotencyTestAgent(Agent):
    """Agent subclass exposing an event that fires when a call is classified as a duplicate.

    Intended for deterministic two-thread tests together with SyncEventMockedModel: the
    model holds the first thread inside stream(), and `duplicate_detected` tells the test
    that the second thread has gone through _check_idempotency and was given an inflight
    invocation to wait on.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.duplicate_detected = threading.Event()

    def _check_idempotency(self, idempotency_token):
        outcome = super()._check_idempotency(idempotency_token)
        waiting_on = outcome[0]
        # A non-None first element means the base class found the same token inflight.
        if waiting_on is not None:
            self.duplicate_detected.set()
        return outcome
def test_idempotency_original_fails_duplicate_gets_same_error():
    """Test that when the original invocation fails, the duplicate receives the same exception.

    Every wait is bounded and the worker threads are daemons, so a regression shows up as an
    assertion failure rather than a hung test run.
    """
    model = SyncEventFailingModel()
    agent = IdempotencyTestAgent(model=model, concurrent_invocation_mode="throw")

    errors = []
    lock = threading.Lock()

    def invoke():
        try:
            agent("test", idempotency_token="abc-123")
        except Exception as e:
            with lock:
                errors.append(e)

    t1 = threading.Thread(target=invoke, daemon=True)
    t2 = threading.Thread(target=invoke, daemon=True)
    try:
        # T1 becomes the primary and pauses inside the model.
        t1.start()
        assert model.started_event.wait(timeout=5.0), "T1 never reached the model"

        # T2 must be classified as a duplicate before the primary is allowed to fail.
        t2.start()
        assert agent.duplicate_detected.wait(timeout=5.0), "T2 was never detected as a duplicate"
    finally:
        # Always release the model so a failed assertion above does not strand T1.
        model.proceed_event.set()

    t1.join(timeout=5.0)
    t2.join(timeout=5.0)
    assert not t1.is_alive(), "T1 hung after the model was released"
    assert not t2.is_alive(), "T2 hung - waiting duplicate never woke up"

    assert len(errors) == 2, f"Expected 2 errors, got {len(errors)}"
    assert all(isinstance(e, RuntimeError) for e in errors)
    assert all("Simulated model failure" in str(e) for e in errors)
def test_idempotency_no_token_falls_back_to_throw():
    """Test that a call without idempotency_token still gets ConcurrencyException in THROW mode.

    Every wait is bounded and the worker threads are daemons, so a regression shows up as an
    assertion failure rather than a hung test run.
    """
    model = SyncEventMockedModel(
        [
            {"role": "assistant", "content": [{"text": "hello"}]},
            {"role": "assistant", "content": [{"text": "world"}]},
        ]
    )
    agent = Agent(model=model, concurrent_invocation_mode="throw")

    results = []
    errors = []
    lock = threading.Lock()

    def invoke_with_token():
        try:
            result = agent("test", idempotency_token="abc")
            with lock:
                results.append(result)
        except Exception as e:
            with lock:
                errors.append(e)

    def invoke_without_token():
        try:
            result = agent("test")
            with lock:
                results.append(result)
        except Exception as e:
            with lock:
                errors.append(e)

    t1 = threading.Thread(target=invoke_with_token, daemon=True)
    t2 = threading.Thread(target=invoke_without_token, daemon=True)
    try:
        # T1 holds the invocation lock while paused inside the model.
        t1.start()
        assert model.started_event.wait(timeout=5.0), "T1 never reached the model"

        # T2 has no token, so it must be rejected while T1 is still inflight.
        t2.start()
        t2.join(timeout=5.0)
        assert not t2.is_alive(), "T2 should have returned quickly with ConcurrencyException"
    finally:
        # Always release the model so a failed assertion above does not strand T1.
        model.proceed_event.set()

    t1.join(timeout=5.0)
    assert not t1.is_alive(), "T1 hung after the model was released"

    assert len(results) == 1, f"Expected 1 result, got {len(results)}"
    assert len(errors) == 1, f"Expected 1 error, got {len(errors)}"
    assert isinstance(errors[0], ConcurrencyException)
def test_idempotency_cleanup_after_completion():
    """Test that a token is free for reuse once its invocation has completed.

    Two sequential calls share the token "abc"; the second must run as a fresh request and
    return the model's second response.
    """
    canned_responses = [
        {"role": "assistant", "content": [{"text": "response1"}]},
        {"role": "assistant", "content": [{"text": "response2"}]},
    ]
    agent = Agent(model=MockedModelProvider(canned_responses), concurrent_invocation_mode="throw")

    first = agent("test", idempotency_token="abc")
    assert str(first).strip() == "response1"

    second = agent("test", idempotency_token="abc")
    assert str(second).strip() == "response2"

    assert str(first) != str(second)
def test_idempotency_no_deadlock_on_competing_token():
    """A 3rd thread with a different token must not prevent a waiting duplicate from waking up.

    T1 runs with token "abc" → T2 (same token) waits as duplicate → T3 arrives with token "def"
    and gets ConcurrencyException. T1 then completes and T2 must receive the result, not hang.

    Every wait is bounded and the worker threads are daemons, so a regression shows up as an
    assertion failure rather than a hung test run.
    """
    model = SyncEventMockedModel(
        [
            {"role": "assistant", "content": [{"text": "hello"}]},
            {"role": "assistant", "content": [{"text": "world"}]},
        ]
    )
    agent = IdempotencyTestAgent(model=model, concurrent_invocation_mode="throw")

    results = []
    errors = []
    lock = threading.Lock()

    def invoke_abc():
        try:
            result = agent("test", idempotency_token="abc")
            with lock:
                results.append(("abc", result))
        except Exception as e:
            with lock:
                errors.append(e)

    def invoke_def():
        try:
            result = agent("test", idempotency_token="def")
            with lock:
                results.append(("def", result))
        except Exception as e:
            with lock:
                errors.append(e)

    t1 = threading.Thread(target=invoke_abc, daemon=True)
    t2 = threading.Thread(target=invoke_abc, daemon=True)
    t3 = threading.Thread(target=invoke_def, daemon=True)
    try:
        # T1 starts and pauses inside the model
        t1.start()
        assert model.started_event.wait(timeout=5.0), "T1 never reached the model"

        # T2 detects duplicate and waits
        t2.start()
        assert agent.duplicate_detected.wait(timeout=5.0), "T2 was never detected as a duplicate"

        # T3 arrives with a different token - must get ConcurrencyException, not corrupt T1's state
        t3.start()
        t3.join(timeout=2.0)
        assert not t3.is_alive(), "T3 should have returned quickly with ConcurrencyException"
    finally:
        # Unblock T1 even when an assertion above failed, so no worker stays parked in the model.
        model.proceed_event.set()

    # T2 must wake up (not hang) once T1 finishes
    t1.join(timeout=5.0)
    t2.join(timeout=5.0)

    assert not t1.is_alive(), "T1 hung - possible deadlock"
    assert not t2.is_alive(), "T2 hung - deadlock: waiting duplicate never woke up"

    abc_results = [r for name, r in results if name == "abc"]
    assert len(abc_results) == 2, f"Expected T1 and T2 both to succeed, got results={results} errors={errors}"
    assert str(abc_results[0]) == str(abc_results[1])

    concurrency_errors = [e for e in errors if isinstance(e, ConcurrencyException)]
    assert len(concurrency_errors) == 1, f"Expected exactly 1 ConcurrencyException for T3, got {errors}"
def test_idempotency_cleanup_after_failure():
    """Test that a token is free for reuse once its invocation has failed."""
    failing_model = SyncEventFailingModel()
    agent = Agent(model=failing_model, concurrent_invocation_mode="throw")

    # Let the failing model raise immediately instead of pausing inside stream().
    failing_model.proceed_event.set()
    with pytest.raises(RuntimeError, match="Simulated model failure"):
        agent("test", idempotency_token="abc")

    # Swap in a working model; the same token must now run as a fresh request
    # rather than being treated as a duplicate of the failed call.
    agent.model = MockedModelProvider(
        [
            {"role": "assistant", "content": [{"text": "recovered"}]},
        ]
    )
    recovered = agent("test", idempotency_token="abc")
    assert str(recovered).strip() == "recovered"