diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index f378a886a..02e23784f 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -785,7 +785,7 @@ async def stream_async( self._end_agent_trace_span(response=result) - except Exception as e: + except BaseException as e: self._end_agent_trace_span(error=e) raise @@ -988,7 +988,7 @@ def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span: def _end_agent_trace_span( self, response: AgentResult | None = None, - error: Exception | None = None, + error: BaseException | None = None, ) -> None: """Ends a trace span for the agent. diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 3b1e2d76a..358fb55cd 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -138,26 +138,37 @@ async def event_loop_cycle( custom_trace_attributes=agent.trace_attributes, ) invocation_state["event_loop_cycle_span"] = cycle_span + model_events: AsyncGenerator[TypedEvent, None] | None = None - with trace_api.use_span(cycle_span, end_on_exit=True): - # Skipping model invocation if in interrupt state as interrupts are currently only supported for tool calls. - if agent._interrupt_state.activated: - stop_reason: StopReason = "tool_use" - message = agent._interrupt_state.context["tool_use_message"] - # Skip model invocation if the latest message contains ToolUse - elif _has_tool_use_in_latest_message(agent.messages): - stop_reason = "tool_use" - message = agent.messages[-1] - else: - model_events = _handle_model_execution( - agent, cycle_span, cycle_trace, invocation_state, tracer, structured_output_context - ) - async for model_event in model_events: - if not isinstance(model_event, ModelStopReason): - yield model_event - - stop_reason, message, *_ = model_event["stop"] - yield ModelMessageEvent(message=message) + with trace_api.use_span(cycle_span, end_on_exit=False): + try: + # Skipping model invocation if in interrupt state as interrupts are currently only supported for tool calls. + if agent._interrupt_state.activated: + stop_reason: StopReason = "tool_use" + message = agent._interrupt_state.context["tool_use_message"] + # Skip model invocation if the latest message contains ToolUse + elif _has_tool_use_in_latest_message(agent.messages): + stop_reason = "tool_use" + message = agent.messages[-1] + else: + model_events = _handle_model_execution( + agent, cycle_span, cycle_trace, invocation_state, tracer, structured_output_context + ) + try: + async for model_event in model_events: + if not isinstance(model_event, ModelStopReason): + yield model_event + finally: + await model_events.aclose() + + stop_reason, message, *_ = model_event["stop"] + yield ModelMessageEvent(message=message) + except Exception as e: + tracer.end_span_with_error(cycle_span, str(e), e) + raise + except BaseException as e: + tracer.end_span_with_error(cycle_span, str(e), e) + raise try: if stop_reason == "max_tokens": @@ -196,41 +207,50 @@ async def event_loop_cycle( # End the cycle and return results agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes) - # Set attributes before span auto-closes + + # Force structured output tool call if LLM didn't use it automatically + if structured_output_context.is_enabled and stop_reason == "end_turn": + if structured_output_context.force_attempted: + raise StructuredOutputException( + "The model failed to invoke the structured output tool even after it was forced." + ) + structured_output_context.set_forced_mode() + logger.debug("Forcing structured output tool") + await agent._append_messages( + {"role": "user", "content": [{"text": structured_output_context.structured_output_prompt}]} + ) + + tracer.end_event_loop_cycle_span(cycle_span, message) + events = recurse_event_loop( + agent=agent, invocation_state=invocation_state, structured_output_context=structured_output_context + ) + async for typed_event in events: + yield typed_event + return + tracer.end_event_loop_cycle_span(cycle_span, message) - except EventLoopException: + yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) + except StructuredOutputException as e: + tracer.end_span_with_error(cycle_span, str(e), e) + raise + except EventLoopException as e: + tracer.end_span_with_error(cycle_span, str(e), e) # Don't yield or log the exception - we already did it when we # raised the exception and we don't need that duplication. raise except (ContextWindowOverflowException, MaxTokensReachedException) as e: # Special cased exceptions which we want to bubble up rather than get wrapped in an EventLoopException + tracer.end_span_with_error(cycle_span, str(e), e) raise e except Exception as e: + tracer.end_span_with_error(cycle_span, str(e), e) # Handle any other exceptions yield ForceStopEvent(reason=e) logger.exception("cycle failed") raise EventLoopException(e, invocation_state["request_state"]) from e - - # Force structured output tool call if LLM didn't use it automatically - if structured_output_context.is_enabled and stop_reason == "end_turn": - if structured_output_context.force_attempted: - raise StructuredOutputException( - "The model failed to invoke the structured output tool even after it was forced." - ) - structured_output_context.set_forced_mode() - logger.debug("Forcing structured output tool") - await agent._append_messages( - {"role": "user", "content": [{"text": structured_output_context.structured_output_prompt}]} - ) - - events = recurse_event_loop( - agent=agent, invocation_state=invocation_state, structured_output_context=structured_output_context - ) - async for typed_event in events: - yield typed_event - return - - yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) + except BaseException as e: + tracer.end_span_with_error(cycle_span, str(e), e) + raise async def recurse_event_loop( @@ -314,21 +334,23 @@ async def _handle_model_execution( model_id=model_id, custom_trace_attributes=agent.trace_attributes, ) - with trace_api.use_span(model_invoke_span, end_on_exit=True): - await agent.hooks.invoke_callbacks_async( - BeforeModelCallEvent( - agent=agent, - invocation_state=invocation_state, + streamed_events: AsyncGenerator[TypedEvent, None] | None = None + with trace_api.use_span(model_invoke_span, end_on_exit=False): + try: + await agent.hooks.invoke_callbacks_async( + BeforeModelCallEvent( + agent=agent, + invocation_state=invocation_state, + ) ) - ) - if structured_output_context.forced_mode: - tool_spec = structured_output_context.get_tool_spec() - tool_specs = [tool_spec] if tool_spec else [] - else: - tool_specs = agent.tool_registry.get_all_tool_specs() - try: - async for event in stream_messages( + if structured_output_context.forced_mode: + tool_spec = structured_output_context.get_tool_spec() + tool_specs = [tool_spec] if tool_spec else [] + else: + tool_specs = agent.tool_registry.get_all_tool_specs() + + streamed_events = stream_messages( agent.model, agent.system_prompt, agent.messages, @@ -337,8 +359,12 @@ async def _handle_model_execution( tool_choice=structured_output_context.tool_choice, invocation_state=invocation_state, cancel_signal=agent._cancel_signal, - ): - yield event + ) + try: + async for event in streamed_events: + yield event + finally: + await streamed_events.aclose() stop_reason, message, usage, metrics = event["stop"] invocation_state.setdefault("request_state", {}) @@ -360,17 +386,17 @@ async def _handle_model_execution( "stop_reason=<%s>, retry_requested= | hook requested model retry", stop_reason, ) + tracer.end_model_invoke_span(model_invoke_span, message, usage, metrics, stop_reason) continue # Retry the model call if stop_reason == "max_tokens": message = recover_message_on_max_tokens_reached(message) - # Set attributes before span auto-closes tracer.end_model_invoke_span(model_invoke_span, message, usage, metrics, stop_reason) break # Success! Break out of retry loop except Exception as e: - # Exception is automatically recorded by use_span with end_on_exit=True + tracer.end_span_with_error(model_invoke_span, str(e), e) after_model_call_event = AfterModelCallEvent( agent=agent, invocation_state=invocation_state, @@ -399,6 +425,9 @@ async def _handle_model_execution( # No retry requested, raise the exception yield ForceStopEvent(reason=e) raise e + except BaseException as e: + tracer.end_span_with_error(model_invoke_span, str(e), e) + raise try: # Add message in trace and mark the end of the stream messages trace @@ -538,7 +567,7 @@ async def _handle_tool_execution( interrupts, structured_output=structured_output_result, ) - # Set attributes before span auto-closes (span is managed by use_span in event_loop_cycle) + # End the cycle span before yielding the recursive cycle. if cycle_span: tracer.end_event_loop_cycle_span(span=cycle_span, message=message) @@ -556,7 +585,7 @@ async def _handle_tool_execution( yield ToolResultMessageEvent(message=tool_result_message) - # Set attributes before span auto-closes (span is managed by use_span in event_loop_cycle) + # End the cycle span before yielding the recursive cycle. if cycle_span: tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message) diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 0471a7fcc..3786ec48f 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -184,7 +184,9 @@ def _end_span( self, span: Span, attributes: dict[str, AttributeValue] | None = None, - error: Exception | None = None, + error: BaseException | None = None, + error_message: str | None = None, + flush: bool = False, ) -> None: """Generic helper method to end a span. @@ -192,8 +194,10 @@ def _end_span( span: The span to end attributes: Optional attributes to set before ending the span error: Optional exception if an error occurred + error_message: Optional error message to set in the span status + flush: Force the tracer provider to flush after ending the span """ - if not span: + if not span or not span.is_recording(): return try: @@ -206,7 +210,8 @@ def _end_span( # Handle error if present if error: - span.set_status(StatusCode.ERROR, str(error)) + status_description = error_message or str(error) or type(error).__name__ + span.set_status(StatusCode.ERROR, status_description) span.record_exception(error) else: span.set_status(StatusCode.OK) @@ -214,14 +219,13 @@ def _end_span( logger.warning("error=<%s> | error while ending span", e, exc_info=True) finally: span.end() - # Force flush to ensure spans are exported - if self.tracer_provider and hasattr(self.tracer_provider, "force_flush"): + if flush and self.tracer_provider and hasattr(self.tracer_provider, "force_flush"): try: self.tracer_provider.force_flush() except Exception as e: logger.warning("error=<%s> | failed to force flush tracer provider", e) - def end_span_with_error(self, span: Span, error_message: str, exception: Exception | None = None) -> None: + def end_span_with_error(self, span: Span, error_message: str, exception: BaseException | None = None) -> None: """End a span with error status. Args: @@ -229,11 +233,11 @@ def end_span_with_error(self, span: Span, error_message: str, exception: Excepti error_message: Error message to set in the span status. exception: Optional exception to record in the span. """ - if not span: + if not span or not span.is_recording(): return error = exception or Exception(error_message) - self._end_span(span, error=error) + self._end_span(span, error=error, error_message=error_message) def _add_event( self, span: Span | None, event_name: str, event_attributes: Attributes, to_span_attributes: bool = False @@ -325,18 +329,15 @@ def end_model_invoke_span( ) -> None: """End a model invocation span with results and metrics. - Note: The span is automatically closed and exceptions recorded. This method just sets the necessary attributes. - Status in the span is automatically set to UNSET (OK) on success or ERROR on exception. - Args: - span: The span to set attributes on. + span: The span to end. message: The message response from the model. usage: Token usage information from the model call. metrics: Metrics from the model call. stop_reason: The reason the model stopped generating. """ - # Set end time attribute - span.set_attribute("gen_ai.event.end_time", datetime.now(timezone.utc).isoformat()) + if not span or not span.is_recording(): + return attributes: dict[str, AttributeValue] = { "gen_ai.usage.prompt_tokens": usage["inputTokens"], @@ -373,7 +374,7 @@ def end_model_invoke_span( event_attributes={"finish_reason": str(stop_reason), "message": serialize(message["content"])}, ) - span.set_attributes(attributes) + self._end_span(span, attributes) def start_tool_call_span( self, @@ -445,7 +446,9 @@ def start_tool_call_span( return span - def end_tool_call_span(self, span: Span, tool_result: ToolResult | None, error: Exception | None = None) -> None: + def end_tool_call_span( + self, span: Span, tool_result: ToolResult | None, error: BaseException | None = None + ) -> None: """End a tool call span with results. Args: @@ -548,20 +551,14 @@ def end_event_loop_cycle_span( ) -> None: """End an event loop cycle span with results. - Note: The span is automatically closed and exceptions recorded. This method just sets the necessary attributes. - Status in the span is automatically set to UNSET (OK) on success or ERROR on exception. - Args: - span: The span to set attributes on. + span: The span to end. message: The message response from this cycle. tool_result_message: Optional tool result message if a tool was called. """ - if not span: + if not span or not span.is_recording(): return - # Set end time attribute - span.set_attribute("gen_ai.event.end_time", datetime.now(timezone.utc).isoformat()) - event_attributes: dict[str, AttributeValue] = {"message": serialize(message["content"])} if tool_result_message: @@ -586,6 +583,8 @@ def end_event_loop_cycle_span( else: self._add_event(span, "gen_ai.choice", event_attributes=event_attributes) + self._end_span(span) + def start_agent_span( self, messages: Messages, @@ -649,7 +648,7 @@ def end_agent_span( self, span: Span, response: AgentResult | None = None, - error: Exception | None = None, + error: BaseException | None = None, ) -> None: """End an agent span with results and metrics. @@ -701,7 +700,7 @@ def end_agent_span( } ) - self._end_span(span, attributes, error) + self._end_span(span, attributes, error, flush=True) def _construct_tool_definitions(self, tools_config: dict) -> list[dict[str, Any]]: """Constructs a list of tool definitions from the provided tools_config.""" diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 967a0dafb..b5113411a 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1415,6 +1415,27 @@ async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tr mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception) +@pytest.mark.asyncio +@unittest.mock.patch("strands.agent.agent.get_tracer") +async def test_agent_stream_async_creates_and_ends_span_on_base_exception(mock_get_tracer, mock_model, alist): + """Test that stream_async ends the agent span when a BaseException occurs.""" + mock_tracer = unittest.mock.MagicMock() + mock_span = unittest.mock.MagicMock() + mock_tracer.start_agent_span.return_value = mock_span + mock_get_tracer.return_value = mock_tracer + + test_exception = KeyboardInterrupt("stop now") + mock_model.mock_stream.side_effect = test_exception + + agent = Agent(model=mock_model) + + with pytest.raises(KeyboardInterrupt, match="stop now"): + stream = agent.stream_async("test prompt") + await alist(stream) + + mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception) + + def test_agent_init_with_state_object(): agent = Agent(state=AgentState({"foo": "bar"})) assert agent.state.get("foo") == "bar" diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 0cabeaeee..b56a5c146 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -5,6 +5,9 @@ from unittest.mock import ANY, AsyncMock, MagicMock, call, patch import pytest +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter import strands import strands.telemetry @@ -19,6 +22,7 @@ ) from strands.interrupt import Interrupt, _InterruptState from strands.telemetry.metrics import EventLoopMetrics +from strands.telemetry.tracer import Tracer from strands.tools.executors import SequentialToolExecutor from strands.tools.registry import ToolRegistry from strands.types._events import EventLoopStopEvent @@ -578,6 +582,14 @@ async def test_event_loop_tracing_with_model_error( ) await alist(stream) + assert mock_tracer.end_span_with_error.call_count == 2 + mock_tracer.end_span_with_error.assert_has_calls( + [ + call(model_span, "Input too long", model.stream.side_effect), + call(cycle_span, "Input too long", model.stream.side_effect), + ] + ) + @pytest.mark.asyncio async def test_event_loop_cycle_max_tokens_exception( @@ -668,6 +680,168 @@ async def test_event_loop_tracing_with_tool_execution( assert mock_tracer.end_model_invoke_span.call_count == 2 +@patch("strands.event_loop.event_loop.get_tracer") +@pytest.mark.asyncio +async def test_event_loop_cycle_closes_spans_on_stream_aclose( + mock_get_tracer, + agent, + model, + mock_tracer, +): + mock_get_tracer.return_value = mock_tracer + cycle_span = MagicMock() + mock_tracer.start_event_loop_cycle_span.return_value = cycle_span + model_span = MagicMock() + mock_tracer.start_model_invoke_span.return_value = model_span + + async def interrupted_stream(): + yield {"contentBlockDelta": {"delta": {"text": "test text"}}} + await asyncio.sleep(10) + yield {"contentBlockStop": {}} + + model.stream.return_value = interrupted_stream() + + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + await anext(stream) + await anext(stream) + await anext(stream) + await stream.aclose() + + assert [call.args[0] for call in mock_tracer.end_span_with_error.call_args_list] == [model_span, cycle_span] + assert [call.args[1] for call in mock_tracer.end_span_with_error.call_args_list] == [ + "", + "", + ] + + +@patch("strands.event_loop.event_loop.get_tracer") +@pytest.mark.asyncio +async def test_event_loop_cycle_closes_spans_on_task_cancellation( + mock_get_tracer, + agent, + model, + mock_tracer, +): + mock_get_tracer.return_value = mock_tracer + cycle_span = MagicMock() + mock_tracer.start_event_loop_cycle_span.return_value = cycle_span + model_span = MagicMock() + mock_tracer.start_model_invoke_span.return_value = model_span + + blocked_on_stream = asyncio.Event() + release_stream = asyncio.Event() + + async def interrupted_stream(): + yield {"contentBlockDelta": {"delta": {"text": "test text"}}} + blocked_on_stream.set() + await release_stream.wait() + yield {"contentBlockStop": {}} + + model.stream.return_value = interrupted_stream() + + async def consume() -> None: + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + async for _ in stream: + pass + + task = asyncio.create_task(consume()) + await blocked_on_stream.wait() + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + assert [call.args[0] for call in mock_tracer.end_span_with_error.call_args_list] == [model_span, cycle_span] + assert [call.args[1] for call in mock_tracer.end_span_with_error.call_args_list] == [ + "", + "", + ] + + +@patch("strands.event_loop.event_loop.get_tracer") +@pytest.mark.asyncio +async def test_event_loop_cycle_closes_spans_on_keyboard_interrupt( + mock_get_tracer, + agent, + model, + mock_tracer, + alist, +): + mock_get_tracer.return_value = mock_tracer + cycle_span = MagicMock() + mock_tracer.start_event_loop_cycle_span.return_value = cycle_span + model_span = MagicMock() + mock_tracer.start_model_invoke_span.return_value = model_span + + test_exception = KeyboardInterrupt("stop now") + model.stream.side_effect = test_exception + + with pytest.raises(KeyboardInterrupt, match="stop now"): + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + await alist(stream) + + assert mock_tracer.end_span_with_error.call_args_list == [ + call(model_span, "stop now", test_exception), + call(cycle_span, "stop now", test_exception), + ] + + +@pytest.mark.asyncio +async def test_event_loop_cycle_closes_cycle_span_before_recursive_cycle( + agent, + model, + tool_stream, + agenerator, + alist, +): + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + + tracer = Tracer() + tracer.tracer_provider = provider + tracer.tracer = provider.get_tracer(tracer.service_name) + + async def delayed_text_stream(): + yield {"contentBlockDelta": {"delta": {"text": "test text"}}} + await asyncio.sleep(0.05) + yield {"contentBlockStop": {}} + + agent.trace_span = None + agent._system_prompt_content = None + model.config = {"model_id": "test-model"} + model.stream.side_effect = [ + agenerator(tool_stream), + delayed_text_stream(), + ] + + with patch("strands.event_loop.event_loop.get_tracer", return_value=tracer): + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + await alist(stream) + + provider.force_flush() + cycle_spans = sorted( + [span for span in exporter.get_finished_spans() if span.name == "execute_event_loop_cycle"], + key=lambda span: span.start_time, + ) + + assert len(cycle_spans) == 2 + assert cycle_spans[0].end_time <= cycle_spans[1].start_time + assert cycle_spans[0].end_time < cycle_spans[1].end_time + + @patch("strands.event_loop.event_loop.get_tracer") @pytest.mark.asyncio async def test_event_loop_tracing_with_throttling_exception( @@ -704,6 +878,7 @@ async def test_event_loop_tracing_with_throttling_exception( ) await alist(stream) + assert mock_tracer.end_span_with_error.call_count == 1 # Verify span was created for the successful retry assert mock_tracer.start_model_invoke_span.call_count == 2 assert mock_tracer.end_model_invoke_span.call_count == 1 diff --git a/tests/strands/event_loop/test_event_loop_structured_output.py b/tests/strands/event_loop/test_event_loop_structured_output.py index ad792f52c..2d1150712 100644 --- a/tests/strands/event_loop/test_event_loop_structured_output.py +++ b/tests/strands/event_loop/test_event_loop_structured_output.py @@ -4,16 +4,22 @@ from unittest.mock import AsyncMock, Mock, patch import pytest +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from opentelemetry.trace import StatusCode from pydantic import BaseModel from strands.event_loop.event_loop import event_loop_cycle, recurse_event_loop from strands.telemetry.metrics import EventLoopMetrics +from strands.telemetry.tracer import Tracer from strands.tools.registry import ToolRegistry from strands.tools.structured_output._structured_output_context import ( DEFAULT_STRUCTURED_OUTPUT_PROMPT, StructuredOutputContext, ) from strands.types._events import EventLoopStopEvent, StructuredOutputEvent +from strands.types.exceptions import EventLoopException, StructuredOutputException class UserModel(BaseModel): @@ -253,6 +259,85 @@ async def test_event_loop_forces_structured_output_with_custom_prompt(mock_agent assert args["content"][0]["text"] == custom_prompt +@patch("strands.event_loop.event_loop.get_tracer") +@pytest.mark.asyncio +async def test_event_loop_structured_output_failure_closes_cycle_span_with_error( + mock_get_tracer, + mock_agent, + structured_output_context, + agenerator, + alist, +): + mock_tracer = Mock() + cycle_span = Mock() + model_span = Mock() + mock_tracer.start_event_loop_cycle_span.return_value = cycle_span + mock_tracer.start_model_invoke_span.return_value = model_span + mock_get_tracer.return_value = mock_tracer + + structured_output_context.set_forced_mode() + mock_agent.model.stream.return_value = agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "Still not structured"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + ) + + expected_message = "The model failed to invoke the structured output tool even after it was forced." + with pytest.raises(StructuredOutputException, match=expected_message): + stream = event_loop_cycle( + agent=mock_agent, + invocation_state={}, + structured_output_context=structured_output_context, + ) + await alist(stream) + + mock_tracer.end_model_invoke_span.assert_called_once() + mock_tracer.end_event_loop_cycle_span.assert_not_called() + mock_tracer.end_span_with_error.assert_called_once() + assert mock_tracer.end_span_with_error.call_args.args[0] == cycle_span + assert mock_tracer.end_span_with_error.call_args.args[1] == expected_message + assert isinstance(mock_tracer.end_span_with_error.call_args.args[2], StructuredOutputException) + + +@pytest.mark.asyncio +async def test_event_loop_forced_structured_output_append_failure_records_error_span( + mock_agent, structured_output_context, agenerator, alist +): + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + + tracer = Tracer() + tracer.tracer_provider = provider + tracer.tracer = provider.get_tracer(tracer.service_name) + + mock_agent.model.stream.return_value = agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "Here is the user info"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + ) + + mock_agent._append_messages = AsyncMock(side_effect=RuntimeError("append failed")) + + with patch("strands.event_loop.event_loop.get_tracer", return_value=tracer): + with pytest.raises(EventLoopException, match="append failed"): + stream = event_loop_cycle( + agent=mock_agent, + invocation_state={}, + structured_output_context=structured_output_context, + ) + await alist(stream) + + finished_cycle_spans = [span for span in exporter.get_finished_spans() if span.name == "execute_event_loop_cycle"] + + assert len(finished_cycle_spans) == 1 + assert finished_cycle_spans[0].status.status_code == StatusCode.ERROR + + @pytest.mark.asyncio async def test_structured_output_tool_execution_extracts_result( mock_agent, structured_output_context, agenerator, alist diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 410db0c0c..f35a2798d 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -128,6 +128,42 @@ def test_end_span_with_error_message(mock_span): mock_span.end.assert_called_once() +def test_end_span_with_empty_exception_message_uses_exception_name(mock_span): + """Test that empty exception messages fall back to the exception type name.""" + tracer = Tracer() + error = Exception() + + tracer.end_span_with_error(mock_span, "", error) + + mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "Exception") + mock_span.record_exception.assert_called_once_with(error) + mock_span.end.assert_called_once() + + +def test_end_span_with_empty_base_exception_message_uses_exception_name(mock_span): + """Test that empty BaseException messages fall back to the exception type name.""" + tracer = Tracer() + error = KeyboardInterrupt() + + tracer.end_span_with_error(mock_span, "", error) + + mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "KeyboardInterrupt") + mock_span.record_exception.assert_called_once_with(error) + mock_span.end.assert_called_once() + + +def test_end_span_with_error_prefers_explicit_message(mock_span): + """Test that an explicit error message takes precedence over the exception text.""" + tracer = Tracer() + error = Exception() + + tracer.end_span_with_error(mock_span, "Explicit error message", error) + + mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "Explicit error message") + mock_span.record_exception.assert_called_once_with(error) + mock_span.end.assert_called_once() + + def test_start_model_invoke_span(mock_tracer): """Test starting a model invoke span.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): @@ -251,6 +287,8 @@ def test_end_model_invoke_span(mock_span): "gen_ai.choice", attributes={"message": json.dumps(message["content"]), "finish_reason": "end_turn"}, ) + mock_span.set_status.assert_called_once_with(StatusCode.OK) + mock_span.end.assert_called_once() def test_end_model_invoke_span_latest_conventions(mock_span, monkeypatch): @@ -290,6 +328,8 @@ def test_end_model_invoke_span_latest_conventions(mock_span, monkeypatch): ), }, ) + mock_span.set_status.assert_called_once_with(StatusCode.OK) + mock_span.end.assert_called_once() def test_start_tool_call_span(mock_tracer): @@ -690,6 +730,8 @@ def test_end_event_loop_cycle_span(mock_span): "tool.result": json.dumps(tool_result_message["content"]), }, ) + mock_span.set_status.assert_called_once_with(StatusCode.OK) + mock_span.end.assert_called_once() def test_end_event_loop_cycle_span_latest_conventions(mock_span, monkeypatch): @@ -725,6 +767,8 @@ def test_end_event_loop_cycle_span_latest_conventions(mock_span, monkeypatch): ) }, ) + mock_span.set_status.assert_called_once_with(StatusCode.OK) + mock_span.end.assert_called_once() def test_start_agent_span(mock_tracer): @@ -958,6 +1002,8 @@ def test_end_model_invoke_span_with_cache_metrics(mock_span): "gen_ai.server.time_to_first_token": 5, } ) + mock_span.set_status.assert_called_once_with(StatusCode.OK) + mock_span.end.assert_called_once() def test_end_agent_span_with_cache_metrics(mock_span): @@ -1043,21 +1089,64 @@ def test_end_span_with_exception_handling(mock_span): pytest.fail("_end_span should not raise exceptions") +def test_end_span_does_not_force_flush_by_default(mock_span, mock_get_tracer_provider): + """Test that ending a regular span does not force flush by default.""" + tracer = Tracer() + mock_tracer_provider = mock_get_tracer_provider.return_value + + tracer._end_span(mock_span) + + mock_tracer_provider.force_flush.assert_not_called() + + def test_force_flush_with_error(mock_span, mock_get_tracer_provider): """Test force flush with error handling.""" - # Setup the tracer with a provider that raises an exception on force_flush tracer = Tracer() mock_tracer_provider = mock_get_tracer_provider.return_value mock_tracer_provider.force_flush.side_effect = Exception("Force flush error") # Should not raise an exception - tracer._end_span(mock_span) + tracer._end_span(mock_span, flush=True) # Verify force_flush was called mock_tracer_provider.force_flush.assert_called_once() +def test_end_agent_span_force_flushes(mock_span, mock_get_tracer_provider): + """Test that ending an agent span forces a flush.""" + tracer = Tracer() + mock_tracer_provider = mock_get_tracer_provider.return_value + + tracer.end_agent_span(mock_span) + + mock_tracer_provider.force_flush.assert_called_once() + + +def test_end_agent_span_with_empty_error_message_uses_exception_name(mock_span): + """Test that agent spans fall back to the exception type name for empty errors.""" + tracer = Tracer() + error = Exception() + + tracer.end_agent_span(mock_span, error=error) + + mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "Exception") + mock_span.record_exception.assert_called_once_with(error) + mock_span.end.assert_called_once() + + +def test_end_tool_call_span_with_empty_error_message_uses_exception_name(mock_span): + """Test that tool call spans fall back to the exception type name for empty errors.""" + tracer = Tracer() + error = Exception() + + tracer.end_tool_call_span(mock_span, None, error=error) + + mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "Exception") + mock_span.record_exception.assert_called_once_with(error) + mock_span.end.assert_called_once() + + def test_end_tool_call_span_with_none(mock_span): """Test ending a tool call span with None result.""" tracer = Tracer()