From b8bd92555e8ce1e42411ba873d7271af712e9fa6 Mon Sep 17 00:00:00 2001 From: Di-Is Date: Fri, 20 Mar 2026 10:59:10 +0900 Subject: [PATCH 1/3] fix: restore explicit span.end() to fix span end_time regression PR #1293 wrapped event_loop_cycle() in use_span(end_on_exit=True) and removed explicit span.end() calls. Because event_loop_cycle is an async generator, yield keeps the context manager open across recursive cycles, causing all execute_event_loop_cycle spans to share the same OTel end_time. Switch to end_on_exit=False and explicitly call span.end() via _end_span() in end_event_loop_cycle_span() and end_model_invoke_span(), restoring end_span_with_error() in all exception paths. --- src/strands/event_loop/event_loop.py | 125 ++++++++++-------- src/strands/telemetry/tracer.py | 34 +++-- tests/strands/event_loop/test_event_loop.py | 60 +++++++++ .../test_event_loop_structured_output.py | 85 ++++++++++++ tests/strands/telemetry/test_tracer.py | 34 +++++ 5 files changed, 262 insertions(+), 76 deletions(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 3b1e2d76a..fa15e1739 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -139,25 +139,29 @@ async def event_loop_cycle( ) invocation_state["event_loop_cycle_span"] = cycle_span - with trace_api.use_span(cycle_span, end_on_exit=True): - # Skipping model invocation if in interrupt state as interrupts are currently only supported for tool calls. - if agent._interrupt_state.activated: - stop_reason: StopReason = "tool_use" - message = agent._interrupt_state.context["tool_use_message"] - # Skip model invocation if the latest message contains ToolUse - elif _has_tool_use_in_latest_message(agent.messages): - stop_reason = "tool_use" - message = agent.messages[-1] - else: - model_events = _handle_model_execution( - agent, cycle_span, cycle_trace, invocation_state, tracer, structured_output_context - ) - async for model_event in model_events: - if not isinstance(model_event, ModelStopReason): - yield model_event + with trace_api.use_span(cycle_span, end_on_exit=False): + try: + # Skipping model invocation if in interrupt state as interrupts are currently only supported for tool calls. + if agent._interrupt_state.activated: + stop_reason: StopReason = "tool_use" + message = agent._interrupt_state.context["tool_use_message"] + # Skip model invocation if the latest message contains ToolUse + elif _has_tool_use_in_latest_message(agent.messages): + stop_reason = "tool_use" + message = agent.messages[-1] + else: + model_events = _handle_model_execution( + agent, cycle_span, cycle_trace, invocation_state, tracer, structured_output_context + ) + async for model_event in model_events: + if not isinstance(model_event, ModelStopReason): + yield model_event - stop_reason, message, *_ = model_event["stop"] - yield ModelMessageEvent(message=message) + stop_reason, message, *_ = model_event["stop"] + yield ModelMessageEvent(message=message) + except Exception as e: + tracer.end_span_with_error(cycle_span, str(e), e) + raise try: if stop_reason == "max_tokens": @@ -196,42 +200,48 @@ async def event_loop_cycle( # End the cycle and return results agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes) - # Set attributes before span auto-closes + + # Force structured output tool call if LLM didn't use it automatically + if structured_output_context.is_enabled and stop_reason == "end_turn": + if structured_output_context.force_attempted: + raise StructuredOutputException( + "The model failed to invoke the structured output tool even after it was forced." + ) + structured_output_context.set_forced_mode() + logger.debug("Forcing structured output tool") + await agent._append_messages( + {"role": "user", "content": [{"text": structured_output_context.structured_output_prompt}]} + ) + + tracer.end_event_loop_cycle_span(cycle_span, message) + events = recurse_event_loop( + agent=agent, invocation_state=invocation_state, structured_output_context=structured_output_context + ) + async for typed_event in events: + yield typed_event + return + tracer.end_event_loop_cycle_span(cycle_span, message) - except EventLoopException: + yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) + except StructuredOutputException as e: + tracer.end_span_with_error(cycle_span, str(e), e) + raise + except EventLoopException as e: + tracer.end_span_with_error(cycle_span, str(e), e) # Don't yield or log the exception - we already did it when we # raised the exception and we don't need that duplication. raise except (ContextWindowOverflowException, MaxTokensReachedException) as e: # Special cased exceptions which we want to bubble up rather than get wrapped in an EventLoopException + tracer.end_span_with_error(cycle_span, str(e), e) raise e except Exception as e: + tracer.end_span_with_error(cycle_span, str(e), e) # Handle any other exceptions yield ForceStopEvent(reason=e) logger.exception("cycle failed") raise EventLoopException(e, invocation_state["request_state"]) from e - # Force structured output tool call if LLM didn't use it automatically - if structured_output_context.is_enabled and stop_reason == "end_turn": - if structured_output_context.force_attempted: - raise StructuredOutputException( - "The model failed to invoke the structured output tool even after it was forced." - ) - structured_output_context.set_forced_mode() - logger.debug("Forcing structured output tool") - await agent._append_messages( - {"role": "user", "content": [{"text": structured_output_context.structured_output_prompt}]} - ) - - events = recurse_event_loop( - agent=agent, invocation_state=invocation_state, structured_output_context=structured_output_context - ) - async for typed_event in events: - yield typed_event - return - - yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) - async def recurse_event_loop( agent: "Agent", @@ -314,20 +324,21 @@ async def _handle_model_execution( model_id=model_id, custom_trace_attributes=agent.trace_attributes, ) - with trace_api.use_span(model_invoke_span, end_on_exit=True): - await agent.hooks.invoke_callbacks_async( - BeforeModelCallEvent( - agent=agent, - invocation_state=invocation_state, + with trace_api.use_span(model_invoke_span, end_on_exit=False): + try: + await agent.hooks.invoke_callbacks_async( + BeforeModelCallEvent( + agent=agent, + invocation_state=invocation_state, + ) ) - ) - if structured_output_context.forced_mode: - tool_spec = structured_output_context.get_tool_spec() - tool_specs = [tool_spec] if tool_spec else [] - else: - tool_specs = agent.tool_registry.get_all_tool_specs() - try: + if structured_output_context.forced_mode: + tool_spec = structured_output_context.get_tool_spec() + tool_specs = [tool_spec] if tool_spec else [] + else: + tool_specs = agent.tool_registry.get_all_tool_specs() + async for event in stream_messages( agent.model, agent.system_prompt, @@ -360,17 +371,17 @@ async def _handle_model_execution( "stop_reason=<%s>, retry_requested= | hook requested model retry", stop_reason, ) + tracer.end_model_invoke_span(model_invoke_span, message, usage, metrics, stop_reason) continue # Retry the model call if stop_reason == "max_tokens": message = recover_message_on_max_tokens_reached(message) - # Set attributes before span auto-closes tracer.end_model_invoke_span(model_invoke_span, message, usage, metrics, stop_reason) break # Success! Break out of retry loop except Exception as e: - # Exception is automatically recorded by use_span with end_on_exit=True + tracer.end_span_with_error(model_invoke_span, str(e), e) after_model_call_event = AfterModelCallEvent( agent=agent, invocation_state=invocation_state, @@ -538,7 +549,7 @@ async def _handle_tool_execution( interrupts, structured_output=structured_output_result, ) - # Set attributes before span auto-closes (span is managed by use_span in event_loop_cycle) + # End the cycle span before yielding the recursive cycle. if cycle_span: tracer.end_event_loop_cycle_span(span=cycle_span, message=message) @@ -556,7 +567,7 @@ async def _handle_tool_execution( yield ToolResultMessageEvent(message=tool_result_message) - # Set attributes before span auto-closes (span is managed by use_span in event_loop_cycle) + # End the cycle span before yielding the recursive cycle. if cycle_span: tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message) diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 0471a7fcc..4068d84a0 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -185,6 +185,7 @@ def _end_span( span: Span, attributes: dict[str, AttributeValue] | None = None, error: Exception | None = None, + error_message: str | None = None, ) -> None: """Generic helper method to end a span. @@ -192,8 +193,9 @@ def _end_span( span: The span to end attributes: Optional attributes to set before ending the span error: Optional exception if an error occurred + error_message: Optional error message to set in the span status """ - if not span: + if not span or not span.is_recording(): return try: @@ -206,7 +208,8 @@ def _end_span( # Handle error if present if error: - span.set_status(StatusCode.ERROR, str(error)) + status_description = error_message or str(error) or type(error).__name__ + span.set_status(StatusCode.ERROR, status_description) span.record_exception(error) else: span.set_status(StatusCode.OK) @@ -229,11 +232,11 @@ def end_span_with_error(self, span: Span, error_message: str, exception: Excepti error_message: Error message to set in the span status. exception: Optional exception to record in the span. """ - if not span: + if not span or not span.is_recording(): return error = exception or Exception(error_message) - self._end_span(span, error=error) + self._end_span(span, error=error, error_message=error_message) def _add_event( self, span: Span | None, event_name: str, event_attributes: Attributes, to_span_attributes: bool = False @@ -325,18 +328,15 @@ def end_model_invoke_span( ) -> None: """End a model invocation span with results and metrics. - Note: The span is automatically closed and exceptions recorded. This method just sets the necessary attributes. - Status in the span is automatically set to UNSET (OK) on success or ERROR on exception. - Args: - span: The span to set attributes on. + span: The span to end. message: The message response from the model. usage: Token usage information from the model call. metrics: Metrics from the model call. stop_reason: The reason the model stopped generating. """ - # Set end time attribute - span.set_attribute("gen_ai.event.end_time", datetime.now(timezone.utc).isoformat()) + if not span or not span.is_recording(): + return attributes: dict[str, AttributeValue] = { "gen_ai.usage.prompt_tokens": usage["inputTokens"], @@ -373,7 +373,7 @@ def end_model_invoke_span( event_attributes={"finish_reason": str(stop_reason), "message": serialize(message["content"])}, ) - span.set_attributes(attributes) + self._end_span(span, attributes) def start_tool_call_span( self, @@ -548,20 +548,14 @@ def end_event_loop_cycle_span( ) -> None: """End an event loop cycle span with results. - Note: The span is automatically closed and exceptions recorded. This method just sets the necessary attributes. - Status in the span is automatically set to UNSET (OK) on success or ERROR on exception. - Args: - span: The span to set attributes on. + span: The span to end. message: The message response from this cycle. tool_result_message: Optional tool result message if a tool was called. """ - if not span: + if not span or not span.is_recording(): return - # Set end time attribute - span.set_attribute("gen_ai.event.end_time", datetime.now(timezone.utc).isoformat()) - event_attributes: dict[str, AttributeValue] = {"message": serialize(message["content"])} if tool_result_message: @@ -586,6 +580,8 @@ def end_event_loop_cycle_span( else: self._add_event(span, "gen_ai.choice", event_attributes=event_attributes) + self._end_span(span) + def start_agent_span( self, messages: Messages, diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 0cabeaeee..cd3ebe2fd 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -5,6 +5,9 @@ from unittest.mock import ANY, AsyncMock, MagicMock, call, patch import pytest +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter import strands import strands.telemetry @@ -19,6 +22,7 @@ ) from strands.interrupt import Interrupt, _InterruptState from strands.telemetry.metrics import EventLoopMetrics +from strands.telemetry.tracer import Tracer from strands.tools.executors import SequentialToolExecutor from strands.tools.registry import ToolRegistry from strands.types._events import EventLoopStopEvent @@ -578,6 +582,14 @@ async def test_event_loop_tracing_with_model_error( ) await alist(stream) + assert mock_tracer.end_span_with_error.call_count == 2 + mock_tracer.end_span_with_error.assert_has_calls( + [ + call(model_span, "Input too long", model.stream.side_effect), + call(cycle_span, "Input too long", model.stream.side_effect), + ] + ) + @pytest.mark.asyncio async def test_event_loop_cycle_max_tokens_exception( @@ -668,6 +680,53 @@ async def test_event_loop_tracing_with_tool_execution( assert mock_tracer.end_model_invoke_span.call_count == 2 +@pytest.mark.asyncio +async def test_event_loop_cycle_closes_cycle_span_before_recursive_cycle( + agent, + model, + tool_stream, + agenerator, + alist, +): + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + + tracer = Tracer() + tracer.tracer_provider = provider + tracer.tracer = provider.get_tracer(tracer.service_name) + + async def delayed_text_stream(): + yield {"contentBlockDelta": {"delta": {"text": "test text"}}} + await asyncio.sleep(0.05) + yield {"contentBlockStop": {}} + + agent.trace_span = None + agent._system_prompt_content = None + model.config = {"model_id": "test-model"} + model.stream.side_effect = [ + agenerator(tool_stream), + delayed_text_stream(), + ] + + with patch("strands.event_loop.event_loop.get_tracer", return_value=tracer): + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + await alist(stream) + + provider.force_flush() + cycle_spans = sorted( + [span for span in exporter.get_finished_spans() if span.name == "execute_event_loop_cycle"], + key=lambda span: span.start_time, + ) + + assert len(cycle_spans) == 2 + assert cycle_spans[0].end_time <= cycle_spans[1].start_time + assert cycle_spans[0].end_time < cycle_spans[1].end_time + + @patch("strands.event_loop.event_loop.get_tracer") @pytest.mark.asyncio async def test_event_loop_tracing_with_throttling_exception( @@ -704,6 +763,7 @@ async def test_event_loop_tracing_with_throttling_exception( ) await alist(stream) + assert mock_tracer.end_span_with_error.call_count == 1 # Verify span was created for the successful retry assert mock_tracer.start_model_invoke_span.call_count == 2 assert mock_tracer.end_model_invoke_span.call_count == 1 diff --git a/tests/strands/event_loop/test_event_loop_structured_output.py b/tests/strands/event_loop/test_event_loop_structured_output.py index ad792f52c..2d1150712 100644 --- a/tests/strands/event_loop/test_event_loop_structured_output.py +++ b/tests/strands/event_loop/test_event_loop_structured_output.py @@ -4,16 +4,22 @@ from unittest.mock import AsyncMock, Mock, patch import pytest +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from opentelemetry.trace import StatusCode from pydantic import BaseModel from strands.event_loop.event_loop import event_loop_cycle, recurse_event_loop from strands.telemetry.metrics import EventLoopMetrics +from strands.telemetry.tracer import Tracer from strands.tools.registry import ToolRegistry from strands.tools.structured_output._structured_output_context import ( DEFAULT_STRUCTURED_OUTPUT_PROMPT, StructuredOutputContext, ) from strands.types._events import EventLoopStopEvent, StructuredOutputEvent +from strands.types.exceptions import EventLoopException, StructuredOutputException class UserModel(BaseModel): @@ -253,6 +259,85 @@ async def test_event_loop_forces_structured_output_with_custom_prompt(mock_agent assert args["content"][0]["text"] == custom_prompt +@patch("strands.event_loop.event_loop.get_tracer") +@pytest.mark.asyncio +async def test_event_loop_structured_output_failure_closes_cycle_span_with_error( + mock_get_tracer, + mock_agent, + structured_output_context, + agenerator, + alist, +): + mock_tracer = Mock() + cycle_span = Mock() + model_span = Mock() + mock_tracer.start_event_loop_cycle_span.return_value = cycle_span + mock_tracer.start_model_invoke_span.return_value = model_span + mock_get_tracer.return_value = mock_tracer + + structured_output_context.set_forced_mode() + mock_agent.model.stream.return_value = agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "Still not structured"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + ) + + expected_message = "The model failed to invoke the structured output tool even after it was forced." + with pytest.raises(StructuredOutputException, match=expected_message): + stream = event_loop_cycle( + agent=mock_agent, + invocation_state={}, + structured_output_context=structured_output_context, + ) + await alist(stream) + + mock_tracer.end_model_invoke_span.assert_called_once() + mock_tracer.end_event_loop_cycle_span.assert_not_called() + mock_tracer.end_span_with_error.assert_called_once() + assert mock_tracer.end_span_with_error.call_args.args[0] == cycle_span + assert mock_tracer.end_span_with_error.call_args.args[1] == expected_message + assert isinstance(mock_tracer.end_span_with_error.call_args.args[2], StructuredOutputException) + + +@pytest.mark.asyncio +async def test_event_loop_forced_structured_output_append_failure_records_error_span( + mock_agent, structured_output_context, agenerator, alist +): + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + + tracer = Tracer() + tracer.tracer_provider = provider + tracer.tracer = provider.get_tracer(tracer.service_name) + + mock_agent.model.stream.return_value = agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "Here is the user info"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + ) + + mock_agent._append_messages = AsyncMock(side_effect=RuntimeError("append failed")) + + with patch("strands.event_loop.event_loop.get_tracer", return_value=tracer): + with pytest.raises(EventLoopException, match="append failed"): + stream = event_loop_cycle( + agent=mock_agent, + invocation_state={}, + structured_output_context=structured_output_context, + ) + await alist(stream) + + finished_cycle_spans = [span for span in exporter.get_finished_spans() if span.name == "execute_event_loop_cycle"] + + assert len(finished_cycle_spans) == 1 + assert finished_cycle_spans[0].status.status_code == StatusCode.ERROR + + @pytest.mark.asyncio async def test_structured_output_tool_execution_extracts_result( mock_agent, structured_output_context, agenerator, alist diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 410db0c0c..57f7aeca9 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -128,6 +128,30 @@ def test_end_span_with_error_message(mock_span): mock_span.end.assert_called_once() +def test_end_span_with_empty_exception_message_uses_exception_name(mock_span): + """Test that empty exception messages fall back to the exception type name.""" + tracer = Tracer() + error = Exception() + + tracer.end_span_with_error(mock_span, "", error) + + mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "Exception") + mock_span.record_exception.assert_called_once_with(error) + mock_span.end.assert_called_once() + + +def test_end_span_with_error_prefers_explicit_message(mock_span): + """Test that an explicit error message takes precedence over the exception text.""" + tracer = Tracer() + error = Exception() + + tracer.end_span_with_error(mock_span, "Explicit error message", error) + + mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "Explicit error message") + mock_span.record_exception.assert_called_once_with(error) + mock_span.end.assert_called_once() + + def test_start_model_invoke_span(mock_tracer): """Test starting a model invoke span.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): @@ -251,6 +275,8 @@ def test_end_model_invoke_span(mock_span): "gen_ai.choice", attributes={"message": json.dumps(message["content"]), "finish_reason": "end_turn"}, ) + mock_span.set_status.assert_called_once_with(StatusCode.OK) + mock_span.end.assert_called_once() def test_end_model_invoke_span_latest_conventions(mock_span, monkeypatch): @@ -290,6 +316,8 @@ def test_end_model_invoke_span_latest_conventions(mock_span, monkeypatch): ), }, ) + mock_span.set_status.assert_called_once_with(StatusCode.OK) + mock_span.end.assert_called_once() def test_start_tool_call_span(mock_tracer): @@ -690,6 +718,8 @@ def test_end_event_loop_cycle_span(mock_span): "tool.result": json.dumps(tool_result_message["content"]), }, ) + mock_span.set_status.assert_called_once_with(StatusCode.OK) + mock_span.end.assert_called_once() def test_end_event_loop_cycle_span_latest_conventions(mock_span, monkeypatch): @@ -725,6 +755,8 @@ def test_end_event_loop_cycle_span_latest_conventions(mock_span, monkeypatch): ) }, ) + mock_span.set_status.assert_called_once_with(StatusCode.OK) + mock_span.end.assert_called_once() def test_start_agent_span(mock_tracer): @@ -958,6 +990,8 @@ def test_end_model_invoke_span_with_cache_metrics(mock_span): "gen_ai.server.time_to_first_token": 5, } ) + mock_span.set_status.assert_called_once_with(StatusCode.OK) + mock_span.end.assert_called_once() def test_end_agent_span_with_cache_metrics(mock_span): From c5342cc7443e1f6c7ad120aea6e0ddcfb3266d58 Mon Sep 17 00:00:00 2001 From: Di-Is Date: Fri, 20 Mar 2026 11:01:55 +0900 Subject: [PATCH 2/3] fix: handle BaseException in trace spans to prevent span leaks on KeyboardInterrupt Trace spans were not properly closed when BaseException (e.g. KeyboardInterrupt, asyncio.CancelledError) was raised. Add explicit BaseException handlers to close spans and aclose() calls to ensure async generators are cleaned up. --- src/strands/agent/agent.py | 4 +- src/strands/event_loop/event_loop.py | 30 ++++- src/strands/telemetry/tracer.py | 10 +- tests/strands/agent/test_agent.py | 21 ++++ tests/strands/event_loop/test_event_loop.py | 115 ++++++++++++++++++++ tests/strands/telemetry/test_tracer.py | 36 ++++++ 6 files changed, 204 insertions(+), 12 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index f378a886a..02e23784f 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -785,7 +785,7 @@ async def stream_async( self._end_agent_trace_span(response=result) - except Exception as e: + except BaseException as e: self._end_agent_trace_span(error=e) raise @@ -988,7 +988,7 @@ def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span: def _end_agent_trace_span( self, response: AgentResult | None = None, - error: Exception | None = None, + error: BaseException | None = None, ) -> None: """Ends a trace span for the agent. diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index fa15e1739..358fb55cd 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -138,6 +138,7 @@ async def event_loop_cycle( custom_trace_attributes=agent.trace_attributes, ) invocation_state["event_loop_cycle_span"] = cycle_span + model_events: AsyncGenerator[TypedEvent, None] | None = None with trace_api.use_span(cycle_span, end_on_exit=False): try: @@ -153,15 +154,21 @@ async def event_loop_cycle( model_events = _handle_model_execution( agent, cycle_span, cycle_trace, invocation_state, tracer, structured_output_context ) - async for model_event in model_events: - if not isinstance(model_event, ModelStopReason): - yield model_event + try: + async for model_event in model_events: + if not isinstance(model_event, ModelStopReason): + yield model_event + finally: + await model_events.aclose() stop_reason, message, *_ = model_event["stop"] yield ModelMessageEvent(message=message) except Exception as e: tracer.end_span_with_error(cycle_span, str(e), e) raise + except BaseException as e: + tracer.end_span_with_error(cycle_span, str(e), e) + raise try: if stop_reason == "max_tokens": @@ -241,6 +248,9 @@ async def event_loop_cycle( yield ForceStopEvent(reason=e) logger.exception("cycle failed") raise EventLoopException(e, invocation_state["request_state"]) from e + except BaseException as e: + tracer.end_span_with_error(cycle_span, str(e), e) + raise async def recurse_event_loop( @@ -324,6 +334,7 @@ async def _handle_model_execution( model_id=model_id, custom_trace_attributes=agent.trace_attributes, ) + streamed_events: AsyncGenerator[TypedEvent, None] | None = None with trace_api.use_span(model_invoke_span, end_on_exit=False): try: await agent.hooks.invoke_callbacks_async( @@ -339,7 +350,7 @@ async def _handle_model_execution( else: tool_specs = agent.tool_registry.get_all_tool_specs() - async for event in stream_messages( + streamed_events = stream_messages( agent.model, agent.system_prompt, agent.messages, @@ -348,8 +359,12 @@ async def _handle_model_execution( tool_choice=structured_output_context.tool_choice, invocation_state=invocation_state, cancel_signal=agent._cancel_signal, - ): - yield event + ) + try: + async for event in streamed_events: + yield event + finally: + await streamed_events.aclose() stop_reason, message, usage, metrics = event["stop"] invocation_state.setdefault("request_state", {}) @@ -410,6 +425,9 @@ async def _handle_model_execution( # No retry requested, raise the exception yield ForceStopEvent(reason=e) raise e + except BaseException as e: + tracer.end_span_with_error(model_invoke_span, str(e), e) + raise try: # Add message in trace and mark the end of the stream messages trace diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 4068d84a0..37098f104 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -184,7 +184,7 @@ def _end_span( self, span: Span, attributes: dict[str, AttributeValue] | None = None, - error: Exception | None = None, + error: BaseException | None = None, error_message: str | None = None, ) -> None: """Generic helper method to end a span. @@ -224,7 +224,7 @@ def _end_span( except Exception as e: logger.warning("error=<%s> | failed to force flush tracer provider", e) - def end_span_with_error(self, span: Span, error_message: str, exception: Exception | None = None) -> None: + def end_span_with_error(self, span: Span, error_message: str, exception: BaseException | None = None) -> None: """End a span with error status. Args: @@ -445,7 +445,9 @@ def start_tool_call_span( return span - def end_tool_call_span(self, span: Span, tool_result: ToolResult | None, error: Exception | None = None) -> None: + def end_tool_call_span( + self, span: Span, tool_result: ToolResult | None, error: BaseException | None = None + ) -> None: """End a tool call span with results. Args: @@ -645,7 +647,7 @@ def end_agent_span( self, span: Span, response: AgentResult | None = None, - error: Exception | None = None, + error: BaseException | None = None, ) -> None: """End an agent span with results and metrics. diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 967a0dafb..b5113411a 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1415,6 +1415,27 @@ async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tr mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception) +@pytest.mark.asyncio +@unittest.mock.patch("strands.agent.agent.get_tracer") +async def test_agent_stream_async_creates_and_ends_span_on_base_exception(mock_get_tracer, mock_model, alist): + """Test that stream_async ends the agent span when a BaseException occurs.""" + mock_tracer = unittest.mock.MagicMock() + mock_span = unittest.mock.MagicMock() + mock_tracer.start_agent_span.return_value = mock_span + mock_get_tracer.return_value = mock_tracer + + test_exception = KeyboardInterrupt("stop now") + mock_model.mock_stream.side_effect = test_exception + + agent = Agent(model=mock_model) + + with pytest.raises(KeyboardInterrupt, match="stop now"): + stream = agent.stream_async("test prompt") + await alist(stream) + + mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception) + + def test_agent_init_with_state_object(): agent = Agent(state=AgentState({"foo": "bar"})) assert agent.state.get("foo") == "bar" diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index cd3ebe2fd..b56a5c146 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -680,6 +680,121 @@ async def test_event_loop_tracing_with_tool_execution( assert mock_tracer.end_model_invoke_span.call_count == 2 +@patch("strands.event_loop.event_loop.get_tracer") +@pytest.mark.asyncio +async def test_event_loop_cycle_closes_spans_on_stream_aclose( + mock_get_tracer, + agent, + model, + mock_tracer, +): + mock_get_tracer.return_value = mock_tracer + cycle_span = MagicMock() + mock_tracer.start_event_loop_cycle_span.return_value = cycle_span + model_span = MagicMock() + mock_tracer.start_model_invoke_span.return_value = model_span + + async def interrupted_stream(): + yield {"contentBlockDelta": {"delta": {"text": "test text"}}} + await asyncio.sleep(10) + yield {"contentBlockStop": {}} + + model.stream.return_value = interrupted_stream() + + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + await anext(stream) + await anext(stream) + await anext(stream) + await stream.aclose() + + assert [call.args[0] for call in mock_tracer.end_span_with_error.call_args_list] == [model_span, cycle_span] + assert [call.args[1] for call in mock_tracer.end_span_with_error.call_args_list] == [ + "", + "", + ] + + +@patch("strands.event_loop.event_loop.get_tracer") +@pytest.mark.asyncio +async def test_event_loop_cycle_closes_spans_on_task_cancellation( + mock_get_tracer, + agent, + model, + mock_tracer, +): + mock_get_tracer.return_value = mock_tracer + cycle_span = MagicMock() + mock_tracer.start_event_loop_cycle_span.return_value = cycle_span + model_span = MagicMock() + mock_tracer.start_model_invoke_span.return_value = model_span + + blocked_on_stream = asyncio.Event() + release_stream = asyncio.Event() + + async def interrupted_stream(): + yield {"contentBlockDelta": {"delta": {"text": "test text"}}} + blocked_on_stream.set() + await release_stream.wait() + yield {"contentBlockStop": {}} + + model.stream.return_value = interrupted_stream() + + async def consume() -> None: + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + async for _ in stream: + pass + + task = asyncio.create_task(consume()) + await blocked_on_stream.wait() + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + assert [call.args[0] for call in mock_tracer.end_span_with_error.call_args_list] == [model_span, cycle_span] + assert [call.args[1] for call in mock_tracer.end_span_with_error.call_args_list] == [ + "", + "", + ] + + +@patch("strands.event_loop.event_loop.get_tracer") +@pytest.mark.asyncio +async def test_event_loop_cycle_closes_spans_on_keyboard_interrupt( + mock_get_tracer, + agent, + model, + mock_tracer, + alist, +): + mock_get_tracer.return_value = mock_tracer + cycle_span = MagicMock() + mock_tracer.start_event_loop_cycle_span.return_value = cycle_span + model_span = MagicMock() + mock_tracer.start_model_invoke_span.return_value = model_span + + test_exception = KeyboardInterrupt("stop now") + model.stream.side_effect = test_exception + + with pytest.raises(KeyboardInterrupt, match="stop now"): + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + await alist(stream) + + assert mock_tracer.end_span_with_error.call_args_list == [ + call(model_span, "stop now", test_exception), + call(cycle_span, "stop now", test_exception), + ] + + @pytest.mark.asyncio async def test_event_loop_cycle_closes_cycle_span_before_recursive_cycle( agent, diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 57f7aeca9..4ac1e8b41 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -140,6 +140,18 @@ def test_end_span_with_empty_exception_message_uses_exception_name(mock_span): mock_span.end.assert_called_once() +def test_end_span_with_empty_base_exception_message_uses_exception_name(mock_span): + """Test that empty BaseException messages fall back to the exception type name.""" + tracer = Tracer() + error = KeyboardInterrupt() + + tracer.end_span_with_error(mock_span, "", error) + + mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "KeyboardInterrupt") + mock_span.record_exception.assert_called_once_with(error) + mock_span.end.assert_called_once() + + def test_end_span_with_error_prefers_explicit_message(mock_span): """Test that an explicit error message takes precedence over the exception text.""" tracer = Tracer() @@ -1092,6 +1104,30 @@ def test_force_flush_with_error(mock_span, mock_get_tracer_provider): mock_tracer_provider.force_flush.assert_called_once() +def test_end_agent_span_with_empty_error_message_uses_exception_name(mock_span): + """Test that agent spans fall back to the exception type name for empty errors.""" + tracer = Tracer() + error = Exception() + + tracer.end_agent_span(mock_span, error=error) + + mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "Exception") + mock_span.record_exception.assert_called_once_with(error) + mock_span.end.assert_called_once() + + +def test_end_tool_call_span_with_empty_error_message_uses_exception_name(mock_span): + """Test that tool call spans fall back to the exception type name for empty errors.""" + tracer = Tracer() + error = Exception() + + tracer.end_tool_call_span(mock_span, None, error=error) + + mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "Exception") + mock_span.record_exception.assert_called_once_with(error) + mock_span.end.assert_called_once() + + def test_end_tool_call_span_with_none(mock_span): """Test ending a tool call span with None result.""" tracer = Tracer() From 1b7b6f244520cb593117d2a12e452c9ea07370ca Mon Sep 17 00:00:00 2001 From: Di-Is Date: Fri, 20 Mar 2026 11:17:47 +0900 Subject: [PATCH 3/3] perf: only force flush tracer provider when ending agent spans Reduce overhead by limiting force_flush calls to agent span completion instead of every span end. Add flush parameter to _end_span() with default False, passing True only from end_agent_span(). --- src/strands/telemetry/tracer.py | 7 ++++--- tests/strands/telemetry/test_tracer.py | 23 +++++++++++++++++++++-- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 37098f104..3786ec48f 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -186,6 +186,7 @@ def _end_span( attributes: dict[str, AttributeValue] | None = None, error: BaseException | None = None, error_message: str | None = None, + flush: bool = False, ) -> None: """Generic helper method to end a span. @@ -194,6 +195,7 @@ def _end_span( attributes: Optional attributes to set before ending the span error: Optional exception if an error occurred error_message: Optional error message to set in the span status + flush: Force the tracer provider to flush after ending the span """ if not span or not span.is_recording(): return @@ -217,8 +219,7 @@ def _end_span( logger.warning("error=<%s> | error while ending span", e, exc_info=True) finally: span.end() - # Force flush to ensure spans are exported - if self.tracer_provider and hasattr(self.tracer_provider, "force_flush"): + if flush and self.tracer_provider and hasattr(self.tracer_provider, "force_flush"): try: self.tracer_provider.force_flush() except Exception as e: @@ -699,7 +700,7 @@ def end_agent_span( } ) - self._end_span(span, attributes, error) + self._end_span(span, attributes, error, flush=True) def _construct_tool_definitions(self, tools_config: dict) -> list[dict[str, Any]]: """Constructs a list of tool definitions from the provided tools_config.""" diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 4ac1e8b41..f35a2798d 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -1089,21 +1089,40 @@ def test_end_span_with_exception_handling(mock_span): pytest.fail("_end_span should not raise exceptions") +def test_end_span_does_not_force_flush_by_default(mock_span, mock_get_tracer_provider): + """Test that ending a regular span does not force flush by default.""" + tracer = Tracer() + mock_tracer_provider = mock_get_tracer_provider.return_value + + tracer._end_span(mock_span) + + mock_tracer_provider.force_flush.assert_not_called() + + def test_force_flush_with_error(mock_span, mock_get_tracer_provider): """Test force flush with error handling.""" - # Setup the tracer with a provider that raises an exception on force_flush tracer = Tracer() mock_tracer_provider = mock_get_tracer_provider.return_value mock_tracer_provider.force_flush.side_effect = Exception("Force flush error") # Should not raise an exception - tracer._end_span(mock_span) + tracer._end_span(mock_span, flush=True) # Verify force_flush was called mock_tracer_provider.force_flush.assert_called_once() +def test_end_agent_span_force_flushes(mock_span, mock_get_tracer_provider): + """Test that ending an agent span forces a flush.""" + tracer = Tracer() + mock_tracer_provider = mock_get_tracer_provider.return_value + + tracer.end_agent_span(mock_span) + + mock_tracer_provider.force_flush.assert_called_once() + + def test_end_agent_span_with_empty_error_message_uses_exception_name(mock_span): """Test that agent spans fall back to the exception type name for empty errors.""" tracer = Tracer()