diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index 5b0ae78f6..5081aa2a2 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -5,6 +5,8 @@ agent lifecycle. """ +import json +import uuid from collections.abc import Sequence from typing import TYPE_CHECKING, Any, cast @@ -24,6 +26,44 @@ from ..multiagent.base import MultiAgentResult, NodeResult +def _is_json_serializable(value: Any) -> bool: + """Check if a value is JSON-serializable. + + Args: + value: The value to check. + + Returns: + True if the value can be serialized to JSON, False otherwise. + """ + try: + json.dumps(value) + return True + except (TypeError, ValueError, OverflowError): + return False + + +def _sanitize_invocation_state(invocation_state: dict) -> dict: + """Filter invocation_state to only include JSON-serializable values. + + Non-serializable objects (Agent instances, OpenTelemetry Spans, etc.) are + silently dropped so that events yielded from stream_async() remain + JSON-serializable. UUID values are converted to their string representation. + + Args: + invocation_state: The raw invocation state dict. + + Returns: + A new dict containing only JSON-serializable key-value pairs. + """ + result = {} + for k, v in invocation_state.items(): + if isinstance(v, uuid.UUID): + result[k] = str(v) + elif _is_json_serializable(v): + result[k] = v + return result + + class TypedEvent(dict): """Base class for all typed events in the agent system.""" @@ -69,7 +109,7 @@ def __init__(self) -> None: @override def prepare(self, invocation_state: dict) -> None: - self.update(invocation_state) + self.update(_sanitize_invocation_state(invocation_state)) class StartEvent(TypedEvent): @@ -138,7 +178,7 @@ def is_callback_event(self) -> bool: @override def prepare(self, invocation_state: dict) -> None: if "delta" in self: - self.update(invocation_state) + self.update(_sanitize_invocation_state(invocation_state)) class ToolUseStreamEvent(ModelStreamEvent): @@ -270,7 +310,7 @@ def __init__(self, delay: int) -> None: @override def prepare(self, invocation_state: dict) -> None: - self.update(invocation_state) + self.update(_sanitize_invocation_state(invocation_state)) class ToolResultEvent(TypedEvent): diff --git a/tests/strands/agent/hooks/test_agent_events.py b/tests/strands/agent/hooks/test_agent_events.py index 02c367ccc..2b342dc12 100644 --- a/tests/strands/agent/hooks/test_agent_events.py +++ b/tests/strands/agent/hooks/test_agent_events.py @@ -39,10 +39,7 @@ def mock_sleep(): any_props = { - "agent": ANY, "event_loop_cycle_id": ANY, - "event_loop_cycle_span": ANY, - "event_loop_cycle_trace": ANY, "request_state": {}, } @@ -177,7 +174,6 @@ async def test_stream_e2e_success(alist): "delta": {"text": "Invoking async tool"}, "event_loop_parent_cycle_id": ANY, "messages": ANY, - "model": ANY, "system_prompt": None, "tool_config": tool_config, }, @@ -191,7 +187,6 @@ async def test_stream_e2e_success(alist): "delta": {"toolUse": {"input": "{}"}}, "event_loop_parent_cycle_id": ANY, "messages": ANY, - "model": ANY, "system_prompt": None, "tool_config": tool_config, "type": "tool_use_stream", @@ -235,7 +230,6 @@ async def test_stream_e2e_success(alist): "delta": {"text": "Invoking streaming tool"}, "event_loop_parent_cycle_id": ANY, "messages": ANY, - "model": ANY, "system_prompt": None, "tool_config": tool_config, }, @@ -249,7 +243,6 @@ async def test_stream_e2e_success(alist): "delta": {"toolUse": {"input": "{}"}}, "event_loop_parent_cycle_id": ANY, "messages": ANY, - "model": ANY, "system_prompt": None, "tool_config": tool_config, "type": "tool_use_stream", @@ -301,7 +294,6 @@ async def test_stream_e2e_success(alist): "delta": {"text": "I invoked the tools!"}, "event_loop_parent_cycle_id": ANY, "messages": ANY, - "model": ANY, "system_prompt": None, "tool_config": tool_config, }, diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 967a0dafb..e954c0902 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -722,34 +722,25 @@ def test_agent__call__callback(mock_model, agent, callback_handler, agenerator): unittest.mock.call(event={"contentBlockDelta": {"delta": {"toolUse": {"input": '{"value"}'}}}}), unittest.mock.call( type="tool_use_stream", - agent=agent, current_tool_use={"toolUseId": "123", "name": "test", "input": {}}, delta={"toolUse": {"input": '{"value"}'}}, event_loop_cycle_id=unittest.mock.ANY, - event_loop_cycle_span=unittest.mock.ANY, - event_loop_cycle_trace=unittest.mock.ANY, request_state={}, ), unittest.mock.call(event={"contentBlockStop": {}}), unittest.mock.call(event={"contentBlockStart": {"start": {}}}), unittest.mock.call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"text": "value"}}}}), unittest.mock.call( - agent=agent, delta={"reasoningContent": {"text": "value"}}, event_loop_cycle_id=unittest.mock.ANY, - event_loop_cycle_span=unittest.mock.ANY, - event_loop_cycle_trace=unittest.mock.ANY, reasoning=True, reasoningText="value", request_state={}, ), unittest.mock.call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "value"}}}}), unittest.mock.call( - agent=agent, delta={"reasoningContent": {"signature": "value"}}, event_loop_cycle_id=unittest.mock.ANY, - event_loop_cycle_span=unittest.mock.ANY, - event_loop_cycle_trace=unittest.mock.ANY, reasoning=True, reasoning_signature="value", request_state={}, @@ -758,12 +749,9 @@ def test_agent__call__callback(mock_model, agent, callback_handler, agenerator): unittest.mock.call(event={"contentBlockStart": {"start": {}}}), unittest.mock.call(event={"contentBlockDelta": {"delta": {"text": "value"}}}), unittest.mock.call( - agent=agent, data="value", delta={"text": "value"}, event_loop_cycle_id=unittest.mock.ANY, - event_loop_cycle_span=unittest.mock.ANY, - event_loop_cycle_trace=unittest.mock.ANY, request_state={}, ), unittest.mock.call(event={"contentBlockStop": {}}), @@ -1075,7 +1063,7 @@ async def test_event_loop(*args, **kwargs): tru_events = await alist(stream) exp_events = [ - {"init_event_loop": True, "callback_handler": mock_callback}, + {"init_event_loop": True}, {"data": "First chunk"}, {"data": "Second chunk"}, {"complete": True, "data": "Final chunk"}, diff --git a/tests/strands/types/test__events.py b/tests/strands/types/test__events.py index 6163faeb6..a508fa753 100644 --- a/tests/strands/types/test__events.py +++ b/tests/strands/types/test__events.py @@ -162,6 +162,30 @@ def test_prepare_without_delta(self): event.prepare(invocation_state) assert "request_id" not in event + def test_prepare_filters_non_serializable(self): + """Test prepare method filters out non-JSON-serializable values.""" + import json + + class NonSerializable: + pass + + event = ModelStreamEvent({"delta": "content"}) + invocation_state = { + "request_id": "456", + "agent": NonSerializable(), + "span": NonSerializable(), + "cycle_id": 123, + } + event.prepare(invocation_state) + # Serializable values should be present + assert event["request_id"] == "456" + assert event["cycle_id"] == 123 + # Non-serializable values should be filtered out + assert "agent" not in event + assert "span" not in event + # The resulting event dict must be JSON-serializable + json.dumps(dict(event)) # should not raise + class TestToolUseStreamEvent: """Tests for ToolUseStreamEvent."""