From 55214b0bd929347493149c62ab8e14c8a2baabd0 Mon Sep 17 00:00:00 2001 From: LEDazzio01 <170764058+LEDazzio01@users.noreply.github.com> Date: Fri, 10 Apr 2026 17:48:57 -0400 Subject: [PATCH 1/2] fix: A2AAgent.run() ignores session parameter (#4663) --- .../a2a/agent_framework_a2a/_agent.py | 201 ++++++++++++++++-- 1 file changed, 187 insertions(+), 14 deletions(-) diff --git a/python/packages/a2a/agent_framework_a2a/_agent.py b/python/packages/a2a/agent_framework_a2a/_agent.py index c954c90fc0..46ee4f22b7 100644 --- a/python/packages/a2a/agent_framework_a2a/_agent.py +++ b/python/packages/a2a/agent_framework_a2a/_agent.py @@ -37,8 +37,10 @@ BaseAgent, Content, ContinuationToken, + HistoryProvider, Message, ResponseStream, + SessionContext, normalize_messages, prepend_agent_framework_to_user_agent, ) @@ -267,6 +269,9 @@ def run( # pyright: ignore[reportIncompatibleMethodOverride] Keyword Args: stream: Whether to stream the response. Defaults to False. session: The conversation session associated with the message(s). + When provided, the session's ``session_id`` is used as the A2A + ``context_id`` so that the remote agent can correlate + messages belonging to the same conversation. function_invocation_kwargs: Present for compatibility with the shared agent interface. A2AAgent does not use these values directly. client_kwargs: Present for compatibility with the shared agent interface. @@ -284,17 +289,41 @@ def run( # pyright: ignore[reportIncompatibleMethodOverride] When stream=True: A ResponseStream of AgentResponseUpdate items. """ del function_invocation_kwargs, client_kwargs, kwargs + normalized_messages = normalize_messages(messages) + + # Derive context_id from session when available so the remote agent + # can correlate messages belonging to the same conversation. + context_id: str | None = session.session_id if session else None + if continuation_token is not None: a2a_stream: AsyncIterable[A2AStreamItem] = self.client.resubscribe( TaskIdParams(id=continuation_token["task_id"]) ) else: - normalized_messages = normalize_messages(messages) - a2a_message = self._prepare_message_for_a2a(normalized_messages[-1]) + if not normalized_messages: + raise ValueError("At least one message is required when starting a new task (no continuation_token).") + a2a_message = self._prepare_message_for_a2a(normalized_messages[-1], context_id=context_id) a2a_stream = self.client.send_message(a2a_message) + provider_session = session + if provider_session is None and self.context_providers: + provider_session = AgentSession() + + session_context = SessionContext( + session_id=provider_session.session_id if provider_session else None, + service_session_id=provider_session.service_session_id if provider_session else None, + input_messages=normalized_messages or [], + options={}, + ) + response = ResponseStream( - self._map_a2a_stream(a2a_stream, background=background), + self._map_a2a_stream( + a2a_stream, + background=background, + emit_intermediate=stream, + session=provider_session, + session_context=session_context, + ), finalizer=AgentResponse.from_updates, ) if stream: @@ -306,6 +335,9 @@ async def _map_a2a_stream( a2a_stream: AsyncIterable[A2AStreamItem], *, background: bool = False, + emit_intermediate: bool = False, + session: AgentSession | None = None, + session_context: SessionContext | None = None, ) -> AsyncIterable[AgentResponseUpdate]: """Map raw A2A protocol items to AgentResponseUpdates. @@ -316,38 +348,110 @@ async def _map_a2a_stream( background: When False, in-progress task updates are silently consumed (the stream keeps iterating until a terminal state). When True, they are yielded with a continuation token. + emit_intermediate: When True, in-progress status updates that + carry message content are yielded to the caller. Typically + set for streaming callers so non-streaming consumers only + receive terminal task outputs. + session: The agent session for context providers. + session_context: The session context for context providers. """ + if session_context is None: + session_context = SessionContext(input_messages=[], options={}) + + # Run before_run providers (forward order) + for provider in self.context_providers: + if isinstance(provider, HistoryProvider) and not provider.load_messages: + continue + if session is None: + raise RuntimeError("Provider session must be available when context providers are configured.") + await provider.before_run( + agent=self, # type: ignore[arg-type] + session=session, + context=session_context, + state=session.state.setdefault(provider.source_id, {}), + ) + + all_updates: list[AgentResponseUpdate] = [] + streamed_artifact_ids_by_task: dict[str, set[str]] = {} async for item in a2a_stream: if isinstance(item, A2AMessage): # Process A2A Message contents = self._parse_contents_from_a2a(item.parts) - yield AgentResponseUpdate( + update = AgentResponseUpdate( contents=contents, role="assistant" if item.role == A2ARole.agent else "user", response_id=str(getattr(item, "message_id", uuid.uuid4())), raw_representation=item, ) + all_updates.append(update) + yield update elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], Task): - task, _update_event = item - for update in self._updates_from_task(task, background=background): + task, update_event = item + updates = self._updates_from_task( + task, + update_event=update_event, + background=background, + emit_intermediate=emit_intermediate, + streamed_artifact_ids=streamed_artifact_ids_by_task.get(task.id), + ) + if isinstance(update_event, TaskArtifactUpdateEvent) and any( + update.raw_representation is update_event for update in updates + ): + streamed_artifact_ids_by_task.setdefault(task.id, set()).add(update_event.artifact.artifact_id) + if task.status.state in TERMINAL_TASK_STATES: + streamed_artifact_ids_by_task.pop(task.id, None) + for update in updates: + all_updates.append(update) yield update else: raise NotImplementedError("Only Message and Task responses are supported") + # Set the response on the context for after_run providers + if all_updates: + session_context._response = AgentResponse.from_updates(all_updates) # type: ignore[assignment] + + await self._run_after_providers(session=session, context=session_context) + # ------------------------------------------------------------------ # Task helpers # ------------------------------------------------------------------ - def _updates_from_task(self, task: Task, *, background: bool = False) -> list[AgentResponseUpdate]: + def _updates_from_task( + self, + task: Task, + *, + update_event: TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None = None, + background: bool = False, + emit_intermediate: bool = False, + streamed_artifact_ids: set[str] | None = None, + ) -> list[AgentResponseUpdate]: """Convert an A2A Task into AgentResponseUpdate(s). Terminal tasks produce updates from their artifacts/history. - In-progress tasks produce a continuation token update only when - ``background=True``; otherwise they are silently skipped so the - caller keeps consuming the stream until completion. + In-progress tasks produce a continuation token update when + ``background=True``. When ``emit_intermediate=True`` (typically + set for streaming callers), any message content attached to an + in-progress status update is surfaced; otherwise the update is + silently skipped so the caller keeps consuming the stream until + completion. """ - if task.status.state in TERMINAL_TASK_STATES: + status = task.status + + if ( + emit_intermediate + and update_event is not None + and (event_updates := self._updates_from_task_update_event(update_event)) + ): + return event_updates + + if status.state in TERMINAL_TASK_STATES: task_messages = self._parse_messages_from_task(task) + if task.artifacts is not None and streamed_artifact_ids: + task_messages = [ + message + for message in task_messages + if getattr(message.raw_representation, "artifact_id", None) not in streamed_artifact_ids + ] if task_messages: return [ AgentResponseUpdate( @@ -359,9 +463,11 @@ def _updates_from_task(self, task: Task, *, background: bool = False) -> list[Ag ) for message in task_messages ] + if task.artifacts is not None: + return [] return [AgentResponseUpdate(contents=[], role="assistant", response_id=task.id, raw_representation=task)] - if background and task.status.state in IN_PROGRESS_TASK_STATES: + if background and status.state in IN_PROGRESS_TASK_STATES: token = self._build_continuation_token(task) return [ AgentResponseUpdate( @@ -373,8 +479,66 @@ def _updates_from_task(self, task: Task, *, background: bool = False) -> list[Ag ) ] + # Surface message content from in-progress status updates (e.g. working state) + # Only emitted when the caller opts in (streaming), so non-streaming + # consumers keep receiving only terminal task outputs. + if ( + emit_intermediate + and status.state in IN_PROGRESS_TASK_STATES + and status.message is not None + and status.message.parts + ): + contents = self._parse_contents_from_a2a(status.message.parts) + if contents: + return [ + AgentResponseUpdate( + contents=contents, + role="assistant" if status.message.role == A2ARole.agent else "user", + response_id=task.id, + raw_representation=task, + ) + ] + return [] + def _updates_from_task_update_event( + self, update_event: TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ) -> list[AgentResponseUpdate]: + """Convert A2A task update events into streaming AgentResponseUpdates.""" + if isinstance(update_event, TaskArtifactUpdateEvent): + contents = self._parse_contents_from_a2a(update_event.artifact.parts) + if not contents: + return [] + return [ + AgentResponseUpdate( + contents=contents, + role="assistant", + response_id=update_event.task_id, + message_id=update_event.artifact.artifact_id, + raw_representation=update_event, + ) + ] + + if not isinstance(update_event, TaskStatusUpdateEvent): + return [] + + message = update_event.status.message + if message is None or not message.parts: + return [] + + contents = self._parse_contents_from_a2a(message.parts) + if not contents: + return [] + + return [ + AgentResponseUpdate( + contents=contents, + role="assistant" if message.role == A2ARole.agent else "user", + response_id=update_event.task_id, + raw_representation=update_event, + ) + ] + @staticmethod def _build_continuation_token(task: Task) -> A2AContinuationToken | None: """Build an A2AContinuationToken from an A2A Task if it is still in progress.""" @@ -403,7 +567,7 @@ async def poll_task(self, continuation_token: A2AContinuationToken) -> AgentResp return AgentResponse.from_updates(updates) return AgentResponse(messages=[], response_id=task.id, raw_representation=task) - def _prepare_message_for_a2a(self, message: Message) -> A2AMessage: + def _prepare_message_for_a2a(self, message: Message, *, context_id: str | None = None) -> A2AMessage: """Prepare a Message for the A2A protocol. Transforms Agent Framework Message objects into A2A protocol Messages by: @@ -412,6 +576,14 @@ def _prepare_message_for_a2a(self, message: Message) -> A2AMessage: - Converting file references (URI/data/hosted_file) to FilePart objects - Preserving metadata and additional properties from the original message - Setting the role to 'user' as framework messages are treated as user input + + Args: + message: The framework Message to convert. + + Keyword Args: + context_id: Optional A2A context ID to associate this message with a + conversation session. When provided, the remote agent can correlate + multiple messages belonging to the same conversation. """ parts: list[A2APart] = [] if not message.contents: @@ -486,13 +658,14 @@ def _prepare_message_for_a2a(self, message: Message) -> A2AMessage: raise ValueError(f"Unknown content type: {content.type}") # Exclude framework-internal keys (e.g. attribution) from wire metadata - internal_keys = {"_attribution"} + internal_keys = {"_attribution", "context_id"} metadata = {k: v for k, v in message.additional_properties.items() if k not in internal_keys} or None return A2AMessage( role=A2ARole("user"), parts=parts, message_id=message.message_id or uuid.uuid4().hex, + context_id=context_id or message.additional_properties.get("context_id") or uuid.uuid4().hex, metadata=metadata, ) From d802c8e38e6610adf81a985c5faf43f86745f868 Mon Sep 17 00:00:00 2001 From: "L. Elaine Dazzio" Date: Fri, 10 Apr 2026 18:04:17 -0400 Subject: [PATCH 2/2] test: add context_id propagation tests for A2AAgent --- .../packages/a2a/tests/test_a2a_context_id.py | 88 +++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 python/packages/a2a/tests/test_a2a_context_id.py diff --git a/python/packages/a2a/tests/test_a2a_context_id.py b/python/packages/a2a/tests/test_a2a_context_id.py new file mode 100644 index 0000000000..27e86ddc25 --- /dev/null +++ b/python/packages/a2a/tests/test_a2a_context_id.py @@ -0,0 +1,88 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Tests for A2AAgent context_id propagation from session.""" + +from unittest.mock import MagicMock + +from agent_framework import AgentSession, Content, Message +from agent_framework.a2a import A2AAgent + + +class MockA2AClient: + """Minimal mock for capturing sent messages.""" + + def __init__(self) -> None: + self.sent_messages: list = [] + + async def send_message(self, message): # type: ignore[no-untyped-def] + self.sent_messages.append(message) + return + yield # make it an async generator # noqa: RET504 + + +def test_context_id_derived_from_session() -> None: + """When a session is provided, _prepare_message_for_a2a uses session.session_id as context_id.""" + agent = A2AAgent(name="test", client=MagicMock(), http_client=None) + message = Message(role="user", contents=[Content.from_text(text="Hello")]) + + session = AgentSession() + a2a_msg = agent._prepare_message_for_a2a(message, context_id=session.session_id) + + assert a2a_msg.context_id == session.session_id + + +def test_context_id_falls_back_to_additional_properties() -> None: + """When context_id kwarg is None, additional_properties['context_id'] is used.""" + agent = A2AAgent(name="test", client=MagicMock(), http_client=None) + message = Message( + role="user", + contents=[Content.from_text(text="Hello")], + additional_properties={"context_id": "from-props"}, + ) + + a2a_msg = agent._prepare_message_for_a2a(message, context_id=None) + + assert a2a_msg.context_id == "from-props" + + +def test_context_id_generates_uuid_when_no_source() -> None: + """When no context_id is available from session or properties, a UUID is generated.""" + agent = A2AAgent(name="test", client=MagicMock(), http_client=None) + message = Message(role="user", contents=[Content.from_text(text="Hello")]) + + a2a_msg = agent._prepare_message_for_a2a(message, context_id=None) + + # Should be a non-empty hex string (uuid4().hex) + assert a2a_msg.context_id is not None + assert len(a2a_msg.context_id) == 32 # uuid4().hex is 32 chars + + +def test_explicit_context_id_overrides_additional_properties() -> None: + """When both context_id kwarg and additional_properties are set, kwarg wins.""" + agent = A2AAgent(name="test", client=MagicMock(), http_client=None) + message = Message( + role="user", + contents=[Content.from_text(text="Hello")], + additional_properties={"context_id": "from-props"}, + ) + + a2a_msg = agent._prepare_message_for_a2a(message, context_id="from-session") + + assert a2a_msg.context_id == "from-session" + + +def test_context_id_not_duplicated_in_metadata() -> None: + """context_id should be filtered from wire metadata to avoid duplication.""" + agent = A2AAgent(name="test", client=MagicMock(), http_client=None) + message = Message( + role="user", + contents=[Content.from_text(text="Hello")], + additional_properties={"context_id": "ctx-123", "trace_id": "trace-456"}, + ) + + a2a_msg = agent._prepare_message_for_a2a(message, context_id="ctx-123") + + # context_id should NOT appear in metadata + assert a2a_msg.metadata == {"trace_id": "trace-456"} + # But should be set on the message itself + assert a2a_msg.context_id == "ctx-123"