Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,7 @@ async def stream_async(

self._end_agent_trace_span(response=result)

except Exception as e:
except BaseException as e:
self._end_agent_trace_span(error=e)
raise

Expand Down Expand Up @@ -988,7 +988,7 @@ def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span:
def _end_agent_trace_span(
self,
response: AgentResult | None = None,
error: Exception | None = None,
error: BaseException | None = None,
) -> None:
"""Ends a trace span for the agent.
Expand Down
151 changes: 90 additions & 61 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,26 +138,37 @@ async def event_loop_cycle(
custom_trace_attributes=agent.trace_attributes,
)
invocation_state["event_loop_cycle_span"] = cycle_span
model_events: AsyncGenerator[TypedEvent, None] | None = None

with trace_api.use_span(cycle_span, end_on_exit=True):
# Skipping model invocation if in interrupt state as interrupts are currently only supported for tool calls.
if agent._interrupt_state.activated:
stop_reason: StopReason = "tool_use"
message = agent._interrupt_state.context["tool_use_message"]
# Skip model invocation if the latest message contains ToolUse
elif _has_tool_use_in_latest_message(agent.messages):
stop_reason = "tool_use"
message = agent.messages[-1]
else:
model_events = _handle_model_execution(
agent, cycle_span, cycle_trace, invocation_state, tracer, structured_output_context
)
async for model_event in model_events:
if not isinstance(model_event, ModelStopReason):
yield model_event

stop_reason, message, *_ = model_event["stop"]
yield ModelMessageEvent(message=message)
with trace_api.use_span(cycle_span, end_on_exit=False):
try:
# Skipping model invocation if in interrupt state as interrupts are currently only supported for tool calls.
if agent._interrupt_state.activated:
stop_reason: StopReason = "tool_use"
message = agent._interrupt_state.context["tool_use_message"]
# Skip model invocation if the latest message contains ToolUse
elif _has_tool_use_in_latest_message(agent.messages):
stop_reason = "tool_use"
message = agent.messages[-1]
else:
model_events = _handle_model_execution(
agent, cycle_span, cycle_trace, invocation_state, tracer, structured_output_context
)
try:
async for model_event in model_events:
if not isinstance(model_event, ModelStopReason):
yield model_event
finally:
await model_events.aclose()

stop_reason, message, *_ = model_event["stop"]
yield ModelMessageEvent(message=message)
except Exception as e:
tracer.end_span_with_error(cycle_span, str(e), e)
raise
except BaseException as e:
tracer.end_span_with_error(cycle_span, str(e), e)
raise

try:
if stop_reason == "max_tokens":
Expand Down Expand Up @@ -196,41 +207,50 @@ async def event_loop_cycle(

# End the cycle and return results
agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes)
# Set attributes before span auto-closes

# Force structured output tool call if LLM didn't use it automatically
if structured_output_context.is_enabled and stop_reason == "end_turn":
if structured_output_context.force_attempted:
raise StructuredOutputException(
"The model failed to invoke the structured output tool even after it was forced."
)
structured_output_context.set_forced_mode()
logger.debug("Forcing structured output tool")
await agent._append_messages(
{"role": "user", "content": [{"text": structured_output_context.structured_output_prompt}]}
)

tracer.end_event_loop_cycle_span(cycle_span, message)
events = recurse_event_loop(
agent=agent, invocation_state=invocation_state, structured_output_context=structured_output_context
)
async for typed_event in events:
yield typed_event
return

tracer.end_event_loop_cycle_span(cycle_span, message)
except EventLoopException:
yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])
except StructuredOutputException as e:
tracer.end_span_with_error(cycle_span, str(e), e)
raise
except EventLoopException as e:
tracer.end_span_with_error(cycle_span, str(e), e)
# Don't yield or log the exception - we already did it when we
# raised the exception and we don't need that duplication.
raise
except (ContextWindowOverflowException, MaxTokensReachedException) as e:
# Special cased exceptions which we want to bubble up rather than get wrapped in an EventLoopException
tracer.end_span_with_error(cycle_span, str(e), e)
raise e
except Exception as e:
tracer.end_span_with_error(cycle_span, str(e), e)
# Handle any other exceptions
yield ForceStopEvent(reason=e)
logger.exception("cycle failed")
raise EventLoopException(e, invocation_state["request_state"]) from e

# Force structured output tool call if LLM didn't use it automatically
if structured_output_context.is_enabled and stop_reason == "end_turn":
if structured_output_context.force_attempted:
raise StructuredOutputException(
"The model failed to invoke the structured output tool even after it was forced."
)
structured_output_context.set_forced_mode()
logger.debug("Forcing structured output tool")
await agent._append_messages(
{"role": "user", "content": [{"text": structured_output_context.structured_output_prompt}]}
)

events = recurse_event_loop(
agent=agent, invocation_state=invocation_state, structured_output_context=structured_output_context
)
async for typed_event in events:
yield typed_event
return

yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])
except BaseException as e:
tracer.end_span_with_error(cycle_span, str(e), e)
raise


async def recurse_event_loop(
Expand Down Expand Up @@ -314,21 +334,23 @@ async def _handle_model_execution(
model_id=model_id,
custom_trace_attributes=agent.trace_attributes,
)
with trace_api.use_span(model_invoke_span, end_on_exit=True):
await agent.hooks.invoke_callbacks_async(
BeforeModelCallEvent(
agent=agent,
invocation_state=invocation_state,
streamed_events: AsyncGenerator[TypedEvent, None] | None = None
with trace_api.use_span(model_invoke_span, end_on_exit=False):
try:
await agent.hooks.invoke_callbacks_async(
BeforeModelCallEvent(
agent=agent,
invocation_state=invocation_state,
)
)
)

if structured_output_context.forced_mode:
tool_spec = structured_output_context.get_tool_spec()
tool_specs = [tool_spec] if tool_spec else []
else:
tool_specs = agent.tool_registry.get_all_tool_specs()
try:
async for event in stream_messages(
if structured_output_context.forced_mode:
tool_spec = structured_output_context.get_tool_spec()
tool_specs = [tool_spec] if tool_spec else []
else:
tool_specs = agent.tool_registry.get_all_tool_specs()

streamed_events = stream_messages(
agent.model,
agent.system_prompt,
agent.messages,
Expand All @@ -337,8 +359,12 @@ async def _handle_model_execution(
tool_choice=structured_output_context.tool_choice,
invocation_state=invocation_state,
cancel_signal=agent._cancel_signal,
):
yield event
)
try:
async for event in streamed_events:
yield event
finally:
await streamed_events.aclose()

stop_reason, message, usage, metrics = event["stop"]
invocation_state.setdefault("request_state", {})
Expand All @@ -360,17 +386,17 @@ async def _handle_model_execution(
"stop_reason=<%s>, retry_requested=<True> | hook requested model retry",
stop_reason,
)
tracer.end_model_invoke_span(model_invoke_span, message, usage, metrics, stop_reason)
continue # Retry the model call

if stop_reason == "max_tokens":
message = recover_message_on_max_tokens_reached(message)

# Set attributes before span auto-closes
tracer.end_model_invoke_span(model_invoke_span, message, usage, metrics, stop_reason)
break # Success! Break out of retry loop

except Exception as e:
# Exception is automatically recorded by use_span with end_on_exit=True
tracer.end_span_with_error(model_invoke_span, str(e), e)
after_model_call_event = AfterModelCallEvent(
agent=agent,
invocation_state=invocation_state,
Expand Down Expand Up @@ -399,6 +425,9 @@ async def _handle_model_execution(
# No retry requested, raise the exception
yield ForceStopEvent(reason=e)
raise e
except BaseException as e:
tracer.end_span_with_error(model_invoke_span, str(e), e)
raise

try:
# Add message in trace and mark the end of the stream messages trace
Expand Down Expand Up @@ -538,7 +567,7 @@ async def _handle_tool_execution(
interrupts,
structured_output=structured_output_result,
)
# Set attributes before span auto-closes (span is managed by use_span in event_loop_cycle)
# End the cycle span before yielding the recursive cycle.
if cycle_span:
tracer.end_event_loop_cycle_span(span=cycle_span, message=message)

Expand All @@ -556,7 +585,7 @@ async def _handle_tool_execution(

yield ToolResultMessageEvent(message=tool_result_message)

# Set attributes before span auto-closes (span is managed by use_span in event_loop_cycle)
# End the cycle span before yielding the recursive cycle.
if cycle_span:
tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message)

Expand Down
Loading