Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 187 additions & 14 deletions python/packages/a2a/agent_framework_a2a/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@
BaseAgent,
Content,
ContinuationToken,
HistoryProvider,
Message,
ResponseStream,
SessionContext,
normalize_messages,
prepend_agent_framework_to_user_agent,
)
Expand Down Expand Up @@ -267,6 +269,9 @@ def run( # pyright: ignore[reportIncompatibleMethodOverride]
Keyword Args:
stream: Whether to stream the response. Defaults to False.
session: The conversation session associated with the message(s).
When provided, the session's ``session_id`` is used as the A2A
``context_id`` so that the remote agent can correlate
messages belonging to the same conversation.
function_invocation_kwargs: Present for compatibility with the shared agent interface.
A2AAgent does not use these values directly.
client_kwargs: Present for compatibility with the shared agent interface.
Expand All @@ -284,17 +289,41 @@ def run( # pyright: ignore[reportIncompatibleMethodOverride]
When stream=True: A ResponseStream of AgentResponseUpdate items.
"""
del function_invocation_kwargs, client_kwargs, kwargs
normalized_messages = normalize_messages(messages)

# Derive context_id from session when available so the remote agent
# can correlate messages belonging to the same conversation.
context_id: str | None = session.session_id if session else None

if continuation_token is not None:
a2a_stream: AsyncIterable[A2AStreamItem] = self.client.resubscribe(
TaskIdParams(id=continuation_token["task_id"])
)
else:
normalized_messages = normalize_messages(messages)
a2a_message = self._prepare_message_for_a2a(normalized_messages[-1])
if not normalized_messages:
raise ValueError("At least one message is required when starting a new task (no continuation_token).")
a2a_message = self._prepare_message_for_a2a(normalized_messages[-1], context_id=context_id)
a2a_stream = self.client.send_message(a2a_message)
Comment on lines +292 to 306
Copy link

Copilot AI Apr 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This adds context_id derivation from session.session_id and new fallback behavior in _prepare_message_for_a2a, but there are existing unit tests for A2AAgent and none currently validate that run(..., session=...) actually sends an A2AMessage with the expected context_id (or that additional_properties['context_id'] is honored/filtered from metadata). Add focused tests to prevent regressions in session/context correlation.

Copilot uses AI. Check for mistakes.

provider_session = session
if provider_session is None and self.context_providers:
provider_session = AgentSession()

session_context = SessionContext(
session_id=provider_session.session_id if provider_session else None,
service_session_id=provider_session.service_session_id if provider_session else None,
input_messages=normalized_messages or [],
options={},
)

response = ResponseStream(
self._map_a2a_stream(a2a_stream, background=background),
self._map_a2a_stream(
a2a_stream,
background=background,
emit_intermediate=stream,
session=provider_session,
session_context=session_context,
),
finalizer=AgentResponse.from_updates,
)
if stream:
Expand All @@ -306,6 +335,9 @@ async def _map_a2a_stream(
a2a_stream: AsyncIterable[A2AStreamItem],
*,
background: bool = False,
emit_intermediate: bool = False,
session: AgentSession | None = None,
session_context: SessionContext | None = None,
) -> AsyncIterable[AgentResponseUpdate]:
"""Map raw A2A protocol items to AgentResponseUpdates.

Expand All @@ -316,38 +348,110 @@ async def _map_a2a_stream(
background: When False, in-progress task updates are silently
consumed (the stream keeps iterating until a terminal state).
When True, they are yielded with a continuation token.
emit_intermediate: When True, in-progress status updates that
carry message content are yielded to the caller. Typically
set for streaming callers so non-streaming consumers only
receive terminal task outputs.
session: The agent session for context providers.
session_context: The session context for context providers.
"""
if session_context is None:
session_context = SessionContext(input_messages=[], options={})

# Run before_run providers (forward order)
for provider in self.context_providers:
if isinstance(provider, HistoryProvider) and not provider.load_messages:
continue
if session is None:
raise RuntimeError("Provider session must be available when context providers are configured.")
await provider.before_run(
agent=self, # type: ignore[arg-type]
session=session,
context=session_context,
state=session.state.setdefault(provider.source_id, {}),
)

all_updates: list[AgentResponseUpdate] = []
streamed_artifact_ids_by_task: dict[str, set[str]] = {}
async for item in a2a_stream:
if isinstance(item, A2AMessage):
# Process A2A Message
contents = self._parse_contents_from_a2a(item.parts)
yield AgentResponseUpdate(
update = AgentResponseUpdate(
contents=contents,
role="assistant" if item.role == A2ARole.agent else "user",
response_id=str(getattr(item, "message_id", uuid.uuid4())),
raw_representation=item,
)
all_updates.append(update)
yield update
elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], Task):
task, _update_event = item
for update in self._updates_from_task(task, background=background):
task, update_event = item
updates = self._updates_from_task(
task,
update_event=update_event,
background=background,
emit_intermediate=emit_intermediate,
streamed_artifact_ids=streamed_artifact_ids_by_task.get(task.id),
)
if isinstance(update_event, TaskArtifactUpdateEvent) and any(
update.raw_representation is update_event for update in updates
):
streamed_artifact_ids_by_task.setdefault(task.id, set()).add(update_event.artifact.artifact_id)
if task.status.state in TERMINAL_TASK_STATES:
streamed_artifact_ids_by_task.pop(task.id, None)
for update in updates:
all_updates.append(update)
yield update
else:
raise NotImplementedError("Only Message and Task responses are supported")

# Set the response on the context for after_run providers
if all_updates:
session_context._response = AgentResponse.from_updates(all_updates) # type: ignore[assignment]

await self._run_after_providers(session=session, context=session_context)

# ------------------------------------------------------------------
# Task helpers
# ------------------------------------------------------------------

def _updates_from_task(self, task: Task, *, background: bool = False) -> list[AgentResponseUpdate]:
def _updates_from_task(
self,
task: Task,
*,
update_event: TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None = None,
background: bool = False,
emit_intermediate: bool = False,
streamed_artifact_ids: set[str] | None = None,
) -> list[AgentResponseUpdate]:
"""Convert an A2A Task into AgentResponseUpdate(s).

Terminal tasks produce updates from their artifacts/history.
In-progress tasks produce a continuation token update only when
``background=True``; otherwise they are silently skipped so the
caller keeps consuming the stream until completion.
In-progress tasks produce a continuation token update when
``background=True``. When ``emit_intermediate=True`` (typically
set for streaming callers), any message content attached to an
in-progress status update is surfaced; otherwise the update is
silently skipped so the caller keeps consuming the stream until
completion.
"""
if task.status.state in TERMINAL_TASK_STATES:
status = task.status

if (
emit_intermediate
and update_event is not None
and (event_updates := self._updates_from_task_update_event(update_event))
):
return event_updates

if status.state in TERMINAL_TASK_STATES:
task_messages = self._parse_messages_from_task(task)
if task.artifacts is not None and streamed_artifact_ids:
task_messages = [
message
for message in task_messages
if getattr(message.raw_representation, "artifact_id", None) not in streamed_artifact_ids
]
if task_messages:
return [
AgentResponseUpdate(
Expand All @@ -359,9 +463,11 @@ def _updates_from_task(self, task: Task, *, background: bool = False) -> list[Ag
)
for message in task_messages
]
if task.artifacts is not None:
return []
return [AgentResponseUpdate(contents=[], role="assistant", response_id=task.id, raw_representation=task)]

if background and task.status.state in IN_PROGRESS_TASK_STATES:
if background and status.state in IN_PROGRESS_TASK_STATES:
token = self._build_continuation_token(task)
return [
AgentResponseUpdate(
Expand All @@ -373,8 +479,66 @@ def _updates_from_task(self, task: Task, *, background: bool = False) -> list[Ag
)
]

# Surface message content from in-progress status updates (e.g. working state)
# Only emitted when the caller opts in (streaming), so non-streaming
# consumers keep receiving only terminal task outputs.
if (
emit_intermediate
and status.state in IN_PROGRESS_TASK_STATES
and status.message is not None
and status.message.parts
):
contents = self._parse_contents_from_a2a(status.message.parts)
if contents:
return [
AgentResponseUpdate(
contents=contents,
role="assistant" if status.message.role == A2ARole.agent else "user",
response_id=task.id,
raw_representation=task,
)
]

return []

def _updates_from_task_update_event(
self, update_event: TaskStatusUpdateEvent | TaskArtifactUpdateEvent
) -> list[AgentResponseUpdate]:
"""Convert A2A task update events into streaming AgentResponseUpdates."""
if isinstance(update_event, TaskArtifactUpdateEvent):
contents = self._parse_contents_from_a2a(update_event.artifact.parts)
if not contents:
return []
return [
AgentResponseUpdate(
contents=contents,
role="assistant",
response_id=update_event.task_id,
message_id=update_event.artifact.artifact_id,
raw_representation=update_event,
)
]

if not isinstance(update_event, TaskStatusUpdateEvent):
return []

message = update_event.status.message
if message is None or not message.parts:
return []

contents = self._parse_contents_from_a2a(message.parts)
if not contents:
return []

return [
AgentResponseUpdate(
contents=contents,
role="assistant" if message.role == A2ARole.agent else "user",
response_id=update_event.task_id,
raw_representation=update_event,
)
]

@staticmethod
def _build_continuation_token(task: Task) -> A2AContinuationToken | None:
"""Build an A2AContinuationToken from an A2A Task if it is still in progress."""
Expand Down Expand Up @@ -403,7 +567,7 @@ async def poll_task(self, continuation_token: A2AContinuationToken) -> AgentResp
return AgentResponse.from_updates(updates)
return AgentResponse(messages=[], response_id=task.id, raw_representation=task)

def _prepare_message_for_a2a(self, message: Message) -> A2AMessage:
def _prepare_message_for_a2a(self, message: Message, *, context_id: str | None = None) -> A2AMessage:
"""Prepare a Message for the A2A protocol.

Transforms Agent Framework Message objects into A2A protocol Messages by:
Expand All @@ -412,6 +576,14 @@ def _prepare_message_for_a2a(self, message: Message) -> A2AMessage:
- Converting file references (URI/data/hosted_file) to FilePart objects
- Preserving metadata and additional properties from the original message
- Setting the role to 'user' as framework messages are treated as user input

Args:
message: The framework Message to convert.

Keyword Args:
context_id: Optional A2A context ID to associate this message with a
conversation session. When provided, the remote agent can correlate
multiple messages belonging to the same conversation.
"""
parts: list[A2APart] = []
if not message.contents:
Expand Down Expand Up @@ -486,13 +658,14 @@ def _prepare_message_for_a2a(self, message: Message) -> A2AMessage:
raise ValueError(f"Unknown content type: {content.type}")

# Exclude framework-internal keys (e.g. attribution) from wire metadata
internal_keys = {"_attribution"}
internal_keys = {"_attribution", "context_id"}
metadata = {k: v for k, v in message.additional_properties.items() if k not in internal_keys} or None

return A2AMessage(
role=A2ARole("user"),
parts=parts,
message_id=message.message_id or uuid.uuid4().hex,
context_id=context_id or message.additional_properties.get("context_id") or uuid.uuid4().hex,
metadata=metadata,
)

Expand Down
Loading