Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion python/packages/claude/agent_framework_claude/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,12 +618,24 @@ def run(
"""
response = ResponseStream(
self._get_stream(messages, session=session, options=options, **kwargs),
finalizer=AgentResponse.from_updates,
finalizer=self._finalize_response,
)
if stream:
return response
return response.get_final_response()

def _finalize_response(self, updates: Sequence[AgentResponseUpdate]) -> AgentResponse[Any]:
"""Build AgentResponse and propagate structured_output as value.

Args:
updates: The collected stream updates.

Returns:
An AgentResponse with structured_output set as value if present.
"""
structured_output = getattr(self, "_structured_output", None)
return AgentResponse.from_updates(updates, value=structured_output)

async def _get_stream(
self,
messages: AgentRunInputs | None = None,
Expand All @@ -647,6 +659,7 @@ async def _get_stream(
await self._apply_runtime_options(dict(options) if options else None)

session_id: str | None = None
structured_output: Any = None

await self._client.query(prompt)
async for message in self._client.receive_response():
Expand Down Expand Up @@ -700,7 +713,11 @@ async def _get_stream(
error_msg = message.result or "Unknown error from Claude API"
raise AgentException(f"Claude API error: {error_msg}")
session_id = message.session_id
structured_output = message.structured_output

# Update session with session ID
if session_id:
session.service_session_id = session_id

# Store structured output for the finalizer
self._structured_output = structured_output
160 changes: 160 additions & 0 deletions python/packages/claude/tests/test_claude_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,3 +785,163 @@ async def test_apply_runtime_options_none(self) -> None:
await agent._apply_runtime_options(None) # type: ignore[reportPrivateUsage]
mock_client.set_model.assert_not_called()
mock_client.set_permission_mode.assert_not_called()


# region Test ClaudeAgent Structured Output


class TestClaudeAgentStructuredOutput:
"""Tests for ClaudeAgent structured output propagation."""

@staticmethod
async def _create_async_generator(items: list[Any]) -> Any:
"""Helper to create async generator from list."""
for item in items:
yield item

def _create_mock_client(self, messages: list[Any]) -> MagicMock:
"""Create a mock ClaudeSDKClient that yields given messages."""
mock_client = MagicMock()
mock_client.connect = AsyncMock()
mock_client.disconnect = AsyncMock()
mock_client.query = AsyncMock()
mock_client.set_model = AsyncMock()
mock_client.set_permission_mode = AsyncMock()
mock_client.receive_response = MagicMock(return_value=self._create_async_generator(messages))
return mock_client

async def test_structured_output_propagated_to_response(self) -> None:
"""Test that structured_output from ResultMessage is propagated to response.value."""
from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock
from claude_agent_sdk.types import StreamEvent

structured_data = {"name": "Alice", "age": 30}
messages = [
StreamEvent(
event={
"type": "content_block_delta",
"delta": {"type": "text_delta", "text": '{"name": "Alice", "age": 30}'},
},
uuid="event-1",
session_id="session-123",
),
AssistantMessage(
content=[TextBlock(text='{"name": "Alice", "age": 30}')],
model="claude-sonnet",
),
ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="session-123",
structured_output=structured_data,
),
]
mock_client = self._create_mock_client(messages)

with patch("agent_framework_claude._agent.ClaudeSDKClient", return_value=mock_client):
agent = ClaudeAgent()
response = await agent.run("Return structured data")
assert response.value == structured_data

async def test_structured_output_none_when_not_present(self) -> None:
"""Test that response.value is None when structured_output is not present."""
from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock
from claude_agent_sdk.types import StreamEvent

messages = [
StreamEvent(
event={
"type": "content_block_delta",
"delta": {"type": "text_delta", "text": "Hello!"},
},
uuid="event-1",
session_id="session-123",
),
AssistantMessage(
content=[TextBlock(text="Hello!")],
model="claude-sonnet",
),
ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="session-123",
),
]
mock_client = self._create_mock_client(messages)

with patch("agent_framework_claude._agent.ClaudeSDKClient", return_value=mock_client):
agent = ClaudeAgent()
response = await agent.run("Hello")
assert response.value is None

async def test_structured_output_with_streaming(self) -> None:
"""Test that structured_output is available via get_final_response after streaming."""
from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock
from claude_agent_sdk.types import StreamEvent

structured_data = {"key": "value"}
messages = [
StreamEvent(
event={
"type": "content_block_delta",
"delta": {"type": "text_delta", "text": '{"key": "value"}'},
},
uuid="event-1",
session_id="session-123",
),
AssistantMessage(
content=[TextBlock(text='{"key": "value"}')],
model="claude-sonnet",
),
ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="session-123",
structured_output=structured_data,
),
]
mock_client = self._create_mock_client(messages)

with patch("agent_framework_claude._agent.ClaudeSDKClient", return_value=mock_client):
agent = ClaudeAgent()
stream = agent.run("Return structured data", stream=True)
# Consume the stream
async for _ in stream:
pass
# Structured output should be available via get_final_response
response = await stream.get_final_response()
assert response.value == structured_data

async def test_structured_output_with_error_does_not_propagate(self) -> None:
"""Test that structured_output is not propagated when ResultMessage is an error."""
from agent_framework.exceptions import AgentException
from claude_agent_sdk import ResultMessage

messages = [
ResultMessage(
subtype="error",
duration_ms=100,
duration_api_ms=50,
is_error=True,
num_turns=0,
session_id="error-session",
result="Something went wrong",
structured_output={"some": "data"},
),
]
mock_client = self._create_mock_client(messages)

with patch("agent_framework_claude._agent.ClaudeSDKClient", return_value=mock_client):
agent = ClaudeAgent()
with pytest.raises(AgentException) as exc_info:
await agent.run("Hello")
assert "Something went wrong" in str(exc_info.value)
6 changes: 5 additions & 1 deletion python/packages/core/agent_framework/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2256,6 +2256,7 @@ def from_updates(
updates: Sequence[AgentResponseUpdate],
*,
output_format_type: type[ResponseModelBoundT],
value: Any | None = None,
) -> AgentResponse[ResponseModelBoundT]: ...

@overload
Expand All @@ -2265,6 +2266,7 @@ def from_updates(
updates: Sequence[AgentResponseUpdate],
*,
output_format_type: None = None,
value: Any | None = None,
) -> AgentResponse[Any]: ...

@classmethod
Expand All @@ -2273,6 +2275,7 @@ def from_updates(
updates: Sequence[AgentResponseUpdate],
*,
output_format_type: type[BaseModel] | None = None,
value: Any | None = None,
) -> AgentResponseT:
"""Joins multiple updates into a single AgentResponse.

Expand All @@ -2281,8 +2284,9 @@ def from_updates(

Keyword Args:
output_format_type: Optional Pydantic model type to parse the response text into structured data.
value: Optional pre-parsed structured output value to set directly on the response.
"""
msg = cls(messages=[], response_format=output_format_type)
msg = cls(messages=[], response_format=output_format_type, value=value)
for update in updates:
_process_update(msg, update)
_finalize_response(msg)
Expand Down