diff --git a/src/a2a/client/__init__.py b/src/a2a/client/__init__.py index 90237d8e5..3f1588a0b 100644 --- a/src/a2a/client/__init__.py +++ b/src/a2a/client/__init__.py @@ -9,7 +9,13 @@ ) from a2a.client.base_client import BaseClient from a2a.client.card_resolver import A2ACardResolver -from a2a.client.client import Client, ClientConfig, ClientEvent, Consumer +from a2a.client.client import ( + Client, + ClientCallContext, + ClientConfig, + ClientEvent, + Consumer, +) from a2a.client.client_factory import ClientFactory, minimal_agent_card from a2a.client.errors import ( A2AClientError, @@ -17,7 +23,7 @@ AgentCardResolutionError, ) from a2a.client.helpers import create_text_message_object -from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.client.interceptors import ClientCallInterceptor logger = logging.getLogger(__name__) diff --git a/src/a2a/client/auth/credentials.py b/src/a2a/client/auth/credentials.py index 11f323709..e3d74e4af 100644 --- a/src/a2a/client/auth/credentials.py +++ b/src/a2a/client/auth/credentials.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod -from a2a.client.middleware import ClientCallContext +from a2a.client.client import ClientCallContext class CredentialService(ABC): diff --git a/src/a2a/client/auth/interceptor.py b/src/a2a/client/auth/interceptor.py index a19c7a8ed..a29f9881c 100644 --- a/src/a2a/client/auth/interceptor.py +++ b/src/a2a/client/auth/interceptor.py @@ -1,9 +1,12 @@ import logging # noqa: I001 -from typing import Any from a2a.client.auth.credentials import CredentialService -from a2a.client.middleware import ClientCallContext, ClientCallInterceptor -from a2a.types.a2a_pb2 import AgentCard +from a2a.client.client import ClientCallContext +from a2a.client.interceptors import ( + AfterArgs, + BeforeArgs, + ClientCallInterceptor, +) logger = logging.getLogger(__name__) @@ -17,36 +20,34 @@ class AuthInterceptor(ClientCallInterceptor): def __init__(self, credential_service: CredentialService): self._credential_service = credential_service - async def intercept( - self, - method_name: str, - request_payload: dict[str, Any], - http_kwargs: dict[str, Any], - agent_card: AgentCard | None, - context: ClientCallContext | None, - ) -> tuple[dict[str, Any], dict[str, Any]]: + async def before(self, args: BeforeArgs) -> None: """Applies authentication headers to the request if credentials are available.""" + agent_card = args.agent_card + # Proto3 repeated fields (security) and maps (security_schemes) do not track presence. # HasField() raises ValueError for them. # We check for truthiness to see if they are non-empty. if ( - agent_card is None - or not agent_card.security_requirements + not agent_card.security_requirements or not agent_card.security_schemes ): - return request_payload, http_kwargs + return for requirement in agent_card.security_requirements: for scheme_name in requirement.schemes: credential = await self._credential_service.get_credentials( - scheme_name, context + scheme_name, args.context ) if credential and scheme_name in agent_card.security_schemes: scheme = agent_card.security_schemes.get(scheme_name) if not scheme: continue - headers = http_kwargs.get('headers', {}) + if args.context is None: + args.context = ClientCallContext() + + if args.context.service_parameters is None: + args.context.service_parameters = {} # HTTP Bearer authentication if ( @@ -54,25 +55,27 @@ async def intercept( and scheme.http_auth_security_scheme.scheme.lower() == 'bearer' ): - headers['Authorization'] = f'Bearer {credential}' + args.context.service_parameters['Authorization'] = ( + f'Bearer {credential}' + ) logger.debug( "Added Bearer token for scheme '%s'.", scheme_name, ) - http_kwargs['headers'] = headers - return request_payload, http_kwargs + return # OAuth2 and OIDC schemes are implicitly Bearer if scheme.HasField( 'oauth2_security_scheme' ) or scheme.HasField('open_id_connect_security_scheme'): - headers['Authorization'] = f'Bearer {credential}' + args.context.service_parameters['Authorization'] = ( + f'Bearer {credential}' + ) logger.debug( "Added Bearer token for scheme '%s'.", scheme_name, ) - http_kwargs['headers'] = headers - return request_payload, http_kwargs + return # API Key in Header if ( @@ -80,16 +83,16 @@ async def intercept( and scheme.api_key_security_scheme.location.lower() == 'header' ): - headers[scheme.api_key_security_scheme.name] = ( - credential - ) + args.context.service_parameters[ + scheme.api_key_security_scheme.name + ] = credential logger.debug( "Added API Key Header for scheme '%s'.", scheme_name, ) - http_kwargs['headers'] = headers - return request_payload, http_kwargs + return # Note: Other cases like API keys in query/cookie are not handled and will be skipped. - return request_payload, http_kwargs + async def after(self, args: AfterArgs) -> None: + """Invoked after the method is executed.""" diff --git a/src/a2a/client/base_client.py b/src/a2a/client/base_client.py index cc17b0349..a825ef50c 100644 --- a/src/a2a/client/base_client.py +++ b/src/a2a/client/base_client.py @@ -1,13 +1,19 @@ -from collections.abc import AsyncGenerator, AsyncIterator, Callable +from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable +from typing import Any from a2a.client.client import ( Client, + ClientCallContext, ClientConfig, ClientEvent, Consumer, ) from a2a.client.client_task_manager import ClientTaskManager -from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.client.interceptors import ( + AfterArgs, + BeforeArgs, + ClientCallInterceptor, +) from a2a.client.transports.base import ClientTransport from a2a.types.a2a_pb2 import ( AgentCard, @@ -37,12 +43,13 @@ def __init__( config: ClientConfig, transport: ClientTransport, consumers: list[Consumer], - middleware: list[ClientCallInterceptor], + interceptors: list[ClientCallInterceptor], ): - super().__init__(consumers, middleware) + super().__init__(consumers, interceptors) self._card = card self._config = config self._transport = transport + self._interceptors = interceptors async def send_message( self, @@ -65,8 +72,13 @@ async def send_message( """ self._apply_client_config(request) if not self._config.streaming or not self._card.capabilities.streaming: - response = await self._transport.send_message( - request, context=context + response = await self._execute_with_interceptors( + input_data=request, + method='send_message', + context=context, + transport_call=lambda req, ctx: self._transport.send_message( + req, context=ctx + ), ) # In non-streaming case we convert to a StreamResponse so that the @@ -87,11 +99,15 @@ async def send_message( yield client_event return - stream = self._transport.send_message_streaming( - request, context=context - ) - async for client_event in self._process_stream(stream): - yield client_event + async for event in self._execute_stream_with_interceptors( + input_data=request, + method='send_message_streaming', + context=context, + transport_call=lambda req, ctx: ( + self._transport.send_message_streaming(req, context=ctx) + ), + ): + yield event def _apply_client_config(self, request: SendMessageRequest) -> None: request.configuration.return_immediately |= self._config.polling @@ -111,25 +127,26 @@ def _apply_client_config(self, request: SendMessageRequest) -> None: ) async def _process_stream( - self, stream: AsyncIterator[StreamResponse] + self, + stream: AsyncIterator[StreamResponse], + before_args: BeforeArgs, ) -> AsyncGenerator[ClientEvent]: tracker = ClientTaskManager() async for stream_response in stream: - client_event: ClientEvent - # When we get a message in the stream then we don't expect any - # further messages so yield and return - if stream_response.HasField('message'): - client_event = (stream_response, None) - await self.consume(client_event, self._card) - yield client_event - return - - # Otherwise track the task / task update then yield to the client - await tracker.process(stream_response) - updated_task = tracker.get_task_or_raise() - client_event = (stream_response, updated_task) - await self.consume(client_event, self._card) + after_args = AfterArgs( + result=stream_response, + method=before_args.method, + agent_card=self._card, + context=before_args.context, + ) + await self._intercept_after(after_args) + intercepted_response = after_args.result + client_event = await self._format_stream_event( + intercepted_response, tracker + ) yield client_event + if intercepted_response.HasField('message'): + return async def get_task( self, @@ -146,7 +163,14 @@ async def get_task( Returns: A `Task` object representing the current state of the task. """ - return await self._transport.get_task(request, context=context) + return await self._execute_with_interceptors( + input_data=request, + method='get_task', + context=context, + transport_call=lambda req, ctx: self._transport.get_task( + req, context=ctx + ), + ) async def list_tasks( self, @@ -155,7 +179,14 @@ async def list_tasks( context: ClientCallContext | None = None, ) -> ListTasksResponse: """Retrieves tasks for an agent.""" - return await self._transport.list_tasks(request, context=context) + return await self._execute_with_interceptors( + input_data=request, + method='list_tasks', + context=context, + transport_call=lambda req, ctx: self._transport.list_tasks( + req, context=ctx + ), + ) async def cancel_task( self, @@ -172,7 +203,14 @@ async def cancel_task( Returns: A `Task` object containing the updated task status. """ - return await self._transport.cancel_task(request, context=context) + return await self._execute_with_interceptors( + input_data=request, + method='cancel_task', + context=context, + transport_call=lambda req, ctx: self._transport.cancel_task( + req, context=ctx + ), + ) async def create_task_push_notification_config( self, @@ -189,8 +227,15 @@ async def create_task_push_notification_config( Returns: The created or updated `TaskPushNotificationConfig` object. """ - return await self._transport.create_task_push_notification_config( - request, context=context + return await self._execute_with_interceptors( + input_data=request, + method='create_task_push_notification_config', + context=context, + transport_call=lambda req, ctx: ( + self._transport.create_task_push_notification_config( + req, context=ctx + ) + ), ) async def get_task_push_notification_config( @@ -208,8 +253,15 @@ async def get_task_push_notification_config( Returns: A `TaskPushNotificationConfig` object containing the configuration. """ - return await self._transport.get_task_push_notification_config( - request, context=context + return await self._execute_with_interceptors( + input_data=request, + method='get_task_push_notification_config', + context=context, + transport_call=lambda req, ctx: ( + self._transport.get_task_push_notification_config( + req, context=ctx + ) + ), ) async def list_task_push_notification_configs( @@ -227,8 +279,15 @@ async def list_task_push_notification_configs( Returns: A `ListTaskPushNotificationConfigsResponse` object. """ - return await self._transport.list_task_push_notification_configs( - request, context=context + return await self._execute_with_interceptors( + input_data=request, + method='list_task_push_notification_configs', + context=context, + transport_call=lambda req, ctx: ( + self._transport.list_task_push_notification_configs( + req, context=ctx + ) + ), ) async def delete_task_push_notification_config( @@ -243,8 +302,15 @@ async def delete_task_push_notification_config( request: The `DeleteTaskPushNotificationConfigRequest` object specifying the request. context: Optional client call context. """ - await self._transport.delete_task_push_notification_config( - request, context=context + return await self._execute_with_interceptors( + input_data=request, + method='delete_task_push_notification_config', + context=context, + transport_call=lambda req, ctx: ( + self._transport.delete_task_push_notification_config( + req, context=ctx + ) + ), ) async def subscribe( @@ -272,12 +338,15 @@ async def subscribe( 'client and/or server do not support resubscription.' ) - # Note: resubscribe can only be called on an existing task. As such, - # we should never see Message updates, despite the typing of the service - # definition indicating it may be possible. - stream = self._transport.subscribe(request, context=context) - async for client_event in self._process_stream(stream): - yield client_event + async for event in self._execute_stream_with_interceptors( + input_data=request, + method='subscribe', + context=context, + transport_call=lambda req, ctx: self._transport.subscribe( + req, context=ctx + ), + ): + yield event async def get_extended_agent_card( self, @@ -299,9 +368,13 @@ async def get_extended_agent_card( Returns: The `AgentCard` for the agent. """ - card = await self._transport.get_extended_agent_card( - request, + card = await self._execute_with_interceptors( + input_data=request, + method='get_extended_agent_card', context=context, + transport_call=lambda req, ctx: ( + self._transport.get_extended_agent_card(req, context=ctx) + ), ) if signature_verifier: signature_verifier(card) @@ -312,3 +385,129 @@ async def get_extended_agent_card( async def close(self) -> None: """Closes the underlying transport.""" await self._transport.close() + + async def _execute_with_interceptors( + self, + input_data: Any, + method: str, + context: ClientCallContext | None, + transport_call: Callable[ + [Any, ClientCallContext | None], Awaitable[Any] + ], + ) -> Any: + before_args = BeforeArgs( + input=input_data, + method=method, + agent_card=self._card, + context=context, + ) + before_result = await self._intercept_before(before_args) + + if before_result is not None: + early_after_args = AfterArgs( + result=before_result['early_return'], + method=method, + agent_card=self._card, + context=before_args.context, + ) + await self._intercept_after( + early_after_args, + before_result['executed'], + ) + return early_after_args.result + + result = await transport_call(before_args.input, before_args.context) + + after_args = AfterArgs( + result=result, + method=method, + agent_card=self._card, + context=before_args.context, + ) + await self._intercept_after(after_args) + + return after_args.result + + async def _execute_stream_with_interceptors( + self, + input_data: Any, + method: str, + context: ClientCallContext | None, + transport_call: Callable[ + [Any, ClientCallContext | None], AsyncIterator[StreamResponse] + ], + ) -> AsyncIterator[ClientEvent]: + + before_args = BeforeArgs( + input=input_data, + method=method, + agent_card=self._card, + context=context, + ) + before_result = await self._intercept_before(before_args) + + if before_result: + after_args = AfterArgs( + result=before_result['early_return'], + method=method, + agent_card=self._card, + context=before_args.context, + ) + await self._intercept_after(after_args, before_result['executed']) + + tracker = ClientTaskManager() + yield await self._format_stream_event(after_args.result, tracker) + return + + stream = transport_call(before_args.input, before_args.context) + + async for client_event in self._process_stream(stream, before_args): + yield client_event + + async def _intercept_before( + self, + args: BeforeArgs, + ) -> dict[str, Any] | None: + if not self._interceptors: + return None + executed: list[ClientCallInterceptor] = [] + for interceptor in self._interceptors: + await interceptor.before(args) + executed.append(interceptor) + if args.early_return: + return { + 'early_return': args.early_return, + 'executed': executed, + } + return None + + async def _intercept_after( + self, + args: AfterArgs, + interceptors: list[ClientCallInterceptor] | None = None, + ) -> None: + interceptors_to_use = ( + interceptors if interceptors is not None else self._interceptors + ) + + reversed_interceptors = list(reversed(interceptors_to_use)) + for interceptor in reversed_interceptors: + await interceptor.after(args) + if args.early_return: + return + + async def _format_stream_event( + self, stream_response: StreamResponse, tracker: ClientTaskManager + ) -> ClientEvent: + client_event: ClientEvent + if stream_response.HasField('message'): + client_event = (stream_response, None) + await self.consume(client_event, self._card) + return client_event + + await tracker.process(stream_response) + updated_task = tracker.get_task_or_raise() + client_event = (stream_response, updated_task) + + await self.consume(client_event, self._card) + return client_event diff --git a/src/a2a/client/client.py b/src/a2a/client/client.py index b19b2219d..6c715e5f0 100644 --- a/src/a2a/client/client.py +++ b/src/a2a/client/client.py @@ -2,16 +2,18 @@ import logging from abc import ABC, abstractmethod -from collections.abc import AsyncIterator, Callable, Coroutine +from collections.abc import AsyncIterator, Callable, Coroutine, MutableMapping from types import TracebackType from typing import Any import httpx +from pydantic import BaseModel, Field from typing_extensions import Self -from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.client.interceptors import ClientCallInterceptor from a2a.client.optionals import Channel +from a2a.client.service_parameters import ServiceParameters from a2a.types.a2a_pb2 import ( AgentCard, CancelTaskRequest, @@ -82,6 +84,18 @@ class ClientConfig: Consumer = Callable[[ClientEvent, AgentCard], Coroutine[None, Any, Any]] +class ClientCallContext(BaseModel): + """A context passed with each client call, allowing for call-specific. + + configuration and data passing. Such as authentication details or + request deadlines. + """ + + state: MutableMapping[str, Any] = Field(default_factory=dict) + timeout: float | None = None + service_parameters: ServiceParameters | None = None + + class Client(ABC): """Abstract base class defining the interface for an A2A client. @@ -93,20 +107,16 @@ class Client(ABC): def __init__( self, consumers: list[Consumer] | None = None, - middleware: list[ClientCallInterceptor] | None = None, + interceptors: list[ClientCallInterceptor] | None = None, ): - """Initializes the client with consumers and middleware. + """Initializes the client with consumers and interceptors. Args: consumers: A list of callables to process events from the agent. - middleware: A list of interceptors to process requests and responses. + interceptors: A list of interceptors to process requests and responses. """ - if middleware is None: - middleware = [] - if consumers is None: - consumers = [] - self._consumers = consumers - self._middleware = middleware + self._consumers = consumers or [] + self._interceptors = interceptors or [] async def __aenter__(self) -> Self: """Enters the async context manager.""" @@ -227,11 +237,9 @@ async def add_event_consumer(self, consumer: Consumer) -> None: """Attaches additional consumers to the `Client`.""" self._consumers.append(consumer) - async def add_request_middleware( - self, middleware: ClientCallInterceptor - ) -> None: - """Attaches additional middleware to the `Client`.""" - self._middleware.append(middleware) + async def add_interceptor(self, interceptor: ClientCallInterceptor) -> None: + """Attaches additional interceptors to the `Client`.""" + self._interceptors.append(interceptor) async def consume( self, diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index 30016d02c..400647b59 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -3,7 +3,7 @@ import logging from collections.abc import Callable -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast import httpx @@ -12,7 +12,6 @@ from a2a.client.base_client import BaseClient from a2a.client.card_resolver import A2ACardResolver from a2a.client.client import Client, ClientConfig, Consumer -from a2a.client.middleware import ClientCallInterceptor from a2a.client.transports.base import ClientTransport from a2a.client.transports.jsonrpc import JsonRpcTransport from a2a.client.transports.rest import RestTransport @@ -31,6 +30,10 @@ ) +if TYPE_CHECKING: + from a2a.client.interceptors import ClientCallInterceptor + + try: from a2a.client.transports.grpc import GrpcTransport except ImportError: @@ -46,7 +49,7 @@ TransportProducer = Callable[ - [AgentCard, str, ClientConfig, list[ClientCallInterceptor]], + [AgentCard, str, ClientConfig], ClientTransport, ] @@ -96,7 +99,6 @@ def jsonrpc_transport_producer( card: AgentCard, url: str, config: ClientConfig, - interceptors: list[ClientCallInterceptor], ) -> ClientTransport: interface = ClientFactory._find_best_interface( list(card.supported_interfaces), @@ -118,14 +120,12 @@ def jsonrpc_transport_producer( cast('httpx.AsyncClient', config.httpx_client), card, url, - interceptors, ) return JsonRpcTransport( cast('httpx.AsyncClient', config.httpx_client), card, url, - interceptors, ) self.register( @@ -138,7 +138,6 @@ def rest_transport_producer( card: AgentCard, url: str, config: ClientConfig, - interceptors: list[ClientCallInterceptor], ) -> ClientTransport: interface = ClientFactory._find_best_interface( list(card.supported_interfaces), @@ -160,14 +159,12 @@ def rest_transport_producer( cast('httpx.AsyncClient', config.httpx_client), card, url, - interceptors, ) return RestTransport( cast('httpx.AsyncClient', config.httpx_client), card, url, - interceptors, ) self.register( @@ -185,7 +182,6 @@ def grpc_transport_producer( card: AgentCard, url: str, config: ClientConfig, - interceptors: list[ClientCallInterceptor], ) -> ClientTransport: # The interface has already been selected and passed as `url`. # We determine its version to use the appropriate transport implementation. @@ -204,12 +200,10 @@ def grpc_transport_producer( ClientFactory._is_legacy_version(version) and CompatGrpcTransport is not None ): - return CompatGrpcTransport.create( - card, url, config, interceptors - ) + return CompatGrpcTransport.create(card, url, config) if GrpcTransport is not None: - return GrpcTransport.create(card, url, config, interceptors) + return GrpcTransport.create(card, url, config) raise ImportError( 'GrpcTransport is not available. ' @@ -410,7 +404,7 @@ def create( all_consumers.extend(consumers) transport = self._registry[transport_protocol]( - card, selected_interface.url, self._config, interceptors or [] + card, selected_interface.url, self._config ) if selected_interface.tenant: diff --git a/src/a2a/client/interceptors.py b/src/a2a/client/interceptors.py new file mode 100644 index 000000000..9903708f3 --- /dev/null +++ b/src/a2a/client/interceptors.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from a2a.client.client import ClientCallContext + +from a2a.types.a2a_pb2 import ( # noqa: TC001 + AgentCard, +) + + +@dataclass +class BeforeArgs: + """Arguments passed to the interceptor before a method call.""" + + input: Any + method: str + agent_card: AgentCard + context: ClientCallContext | None = None + early_return: Any | None = None + + +@dataclass +class AfterArgs: + """Arguments passed to the interceptor after a method call completes.""" + + result: Any + method: str + agent_card: AgentCard + context: ClientCallContext | None = None + early_return: bool = False + + +class ClientCallInterceptor(ABC): + """An abstract base class for client-side call interceptors. + + Interceptors can inspect and modify requests before they are sent, + which is ideal for concerns like authentication, logging, or tracing. + """ + + @abstractmethod + async def before(self, args: BeforeArgs) -> None: + """Invoked before transport method.""" + + @abstractmethod + async def after(self, args: AfterArgs) -> None: + """Invoked after transport method.""" diff --git a/src/a2a/client/middleware.py b/src/a2a/client/middleware.py deleted file mode 100644 index a852c93a7..000000000 --- a/src/a2a/client/middleware.py +++ /dev/null @@ -1,57 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from collections.abc import MutableMapping # noqa: TC003 -from typing import TYPE_CHECKING, Any - -from pydantic import BaseModel, Field - -from a2a.client.service_parameters import ServiceParameters # noqa: TC001 - - -if TYPE_CHECKING: - from a2a.types.a2a_pb2 import AgentCard - - -class ClientCallContext(BaseModel): - """A context passed with each client call, allowing for call-specific. - - configuration and data passing. Such as authentication details or - request deadlines. - """ - - state: MutableMapping[str, Any] = Field(default_factory=dict) - timeout: float | None = None - service_parameters: ServiceParameters | None = None - - -class ClientCallInterceptor(ABC): - """An abstract base class for client-side call interceptors. - - Interceptors can inspect and modify requests before they are sent, - which is ideal for concerns like authentication, logging, or tracing. - """ - - @abstractmethod - async def intercept( - self, - method_name: str, - request_payload: dict[str, Any], - http_kwargs: dict[str, Any], - agent_card: AgentCard | None, - context: ClientCallContext | None, - ) -> tuple[dict[str, Any], dict[str, Any]]: - """ - Intercepts a client call before the request is sent. - - Args: - method_name: The name of the RPC method (e.g., 'message/send'). - request_payload: The JSON RPC request payload dictionary. - http_kwargs: The keyword arguments for the httpx request. - agent_card: The AgentCard associated with the client. - context: The ClientCallContext for this specific call. - - Returns: - A tuple containing the (potentially modified) request_payload - and http_kwargs. - """ diff --git a/src/a2a/client/transports/base.py b/src/a2a/client/transports/base.py index b840b9597..e46aae25e 100644 --- a/src/a2a/client/transports/base.py +++ b/src/a2a/client/transports/base.py @@ -4,7 +4,7 @@ from typing_extensions import Self -from a2a.client.middleware import ClientCallContext +from a2a.client.client import ClientCallContext from a2a.types.a2a_pb2 import ( AgentCard, CancelTaskRequest, diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index 5ca1ac4f5..02c418eb3 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -4,8 +4,8 @@ from functools import wraps from typing import Any, NoReturn, cast +from a2a.client.client import ClientCallContext from a2a.client.errors import A2AClientError, A2AClientTimeoutError -from a2a.client.middleware import ClientCallContext try: @@ -25,7 +25,6 @@ ) from a2a.client.client import ClientConfig -from a2a.client.middleware import ClientCallInterceptor from a2a.client.optionals import Channel from a2a.client.transports.base import ClientTransport from a2a.types import a2a_pb2_grpc @@ -122,7 +121,6 @@ def create( card: AgentCard, url: str, config: ClientConfig, - interceptors: list[ClientCallInterceptor], ) -> 'GrpcTransport': """Creates a gRPC transport for the A2A client.""" if config.grpc_channel_factory is None: diff --git a/src/a2a/client/transports/http_helpers.py b/src/a2a/client/transports/http_helpers.py index 43969dc40..0a5721b50 100644 --- a/src/a2a/client/transports/http_helpers.py +++ b/src/a2a/client/transports/http_helpers.py @@ -8,8 +8,8 @@ from httpx_sse import SSEError, aconnect_sse +from a2a.client.client import ClientCallContext from a2a.client.errors import A2AClientError, A2AClientTimeoutError -from a2a.client.middleware import ClientCallContext @contextmanager diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index d40f1a0e1..9854aabb0 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -9,8 +9,8 @@ from google.protobuf import json_format from jsonrpc.jsonrpc2 import JSONRPC20Request, JSONRPC20Response +from a2a.client.client import ClientCallContext from a2a.client.errors import A2AClientError -from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.client.transports.base import ClientTransport from a2a.client.transports.http_helpers import ( get_http_args, @@ -55,13 +55,11 @@ def __init__( httpx_client: httpx.AsyncClient, agent_card: AgentCard, url: str, - interceptors: list[ClientCallInterceptor] | None = None, ): """Initializes the JsonRpcTransport.""" self.url = url self.httpx_client = httpx_client self.agent_card = agent_card - self.interceptors = interceptors or [] async def send_message( self, diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 65ae850ae..27c0b6a0a 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -8,8 +8,8 @@ from google.protobuf.json_format import MessageToDict, Parse, ParseDict +from a2a.client.client import ClientCallContext from a2a.client.errors import A2AClientError -from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.client.transports.base import ClientTransport from a2a.client.transports.http_helpers import ( get_http_args, @@ -54,13 +54,11 @@ def __init__( httpx_client: httpx.AsyncClient, agent_card: AgentCard, url: str, - interceptors: list[ClientCallInterceptor] | None = None, ): """Initializes the RestTransport.""" self.url = url.removesuffix('/') self.httpx_client = httpx_client self.agent_card = agent_card - self.interceptors = interceptors or [] async def send_message( self, diff --git a/src/a2a/client/transports/tenant_decorator.py b/src/a2a/client/transports/tenant_decorator.py index 07ef8213b..d1059d757 100644 --- a/src/a2a/client/transports/tenant_decorator.py +++ b/src/a2a/client/transports/tenant_decorator.py @@ -1,6 +1,6 @@ from collections.abc import AsyncGenerator -from a2a.client.middleware import ClientCallContext +from a2a.client.client import ClientCallContext from a2a.client.transports.base import ClientTransport from a2a.types.a2a_pb2 import ( AgentCard, diff --git a/src/a2a/compat/v0_3/grpc_transport.py b/src/a2a/compat/v0_3/grpc_transport.py index e862bcfa2..32ce7f27b 100644 --- a/src/a2a/compat/v0_3/grpc_transport.py +++ b/src/a2a/compat/v0_3/grpc_transport.py @@ -18,8 +18,7 @@ ) from e -from a2a.client.client import ClientConfig -from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.client.client import ClientCallContext, ClientConfig from a2a.client.optionals import Channel from a2a.client.transports.base import ClientTransport from a2a.compat.v0_3 import ( @@ -97,7 +96,6 @@ def create( card: a2a_pb2.AgentCard, url: str, config: ClientConfig, - interceptors: list[ClientCallInterceptor], ) -> 'CompatGrpcTransport': """Creates a gRPC transport for the A2A client.""" if config.grpc_channel_factory is None: diff --git a/src/a2a/compat/v0_3/jsonrpc_transport.py b/src/a2a/compat/v0_3/jsonrpc_transport.py index 0bfb854fd..6153ccfc0 100644 --- a/src/a2a/compat/v0_3/jsonrpc_transport.py +++ b/src/a2a/compat/v0_3/jsonrpc_transport.py @@ -9,8 +9,8 @@ from jsonrpc.jsonrpc2 import JSONRPC20Request, JSONRPC20Response +from a2a.client.client import ClientCallContext from a2a.client.errors import A2AClientError -from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.client.transports.base import ClientTransport from a2a.client.transports.http_helpers import ( get_http_args, @@ -58,13 +58,11 @@ def __init__( httpx_client: httpx.AsyncClient, agent_card: AgentCard | None, url: str, - interceptors: list[ClientCallInterceptor] | None = None, ): """Initializes the CompatJsonRpcTransport.""" self.url = url self.httpx_client = httpx_client self.agent_card = agent_card - self.interceptors = interceptors or [] async def send_message( self, diff --git a/src/a2a/compat/v0_3/rest_transport.py b/src/a2a/compat/v0_3/rest_transport.py index f7f2d71c5..7b04f9d70 100644 --- a/src/a2a/compat/v0_3/rest_transport.py +++ b/src/a2a/compat/v0_3/rest_transport.py @@ -8,8 +8,8 @@ from google.protobuf.json_format import MessageToDict, Parse, ParseDict +from a2a.client.client import ClientCallContext from a2a.client.errors import A2AClientError -from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.client.transports.base import ClientTransport from a2a.client.transports.http_helpers import ( get_http_args, @@ -63,13 +63,11 @@ def __init__( httpx_client: httpx.AsyncClient, agent_card: AgentCard | None, url: str, - interceptors: list[ClientCallInterceptor] | None = None, ): """Initializes the CompatRestTransport.""" self.url = url.removesuffix('/') self.httpx_client = httpx_client self.agent_card = agent_card - self.interceptors = interceptors or [] async def send_message( self, diff --git a/tests/client/test_auth_middleware.py b/tests/client/test_auth_interceptor.py similarity index 77% rename from tests/client/test_auth_middleware.py rename to tests/client/test_auth_interceptor.py index 4d7f9f7fa..8713c54eb 100644 --- a/tests/client/test_auth_middleware.py +++ b/tests/client/test_auth_interceptor.py @@ -1,3 +1,4 @@ +# ruff: noqa: INP001, S106 import json from collections.abc import Callable @@ -8,16 +9,17 @@ import pytest import respx +from google.protobuf import json_format + from a2a.client import ( AuthInterceptor, Client, ClientCallContext, - ClientCallInterceptor, ClientConfig, ClientFactory, InMemoryContextCredentialStore, ) -from a2a.utils.constants import TransportProtocol +from a2a.client.interceptors import BeforeArgs from a2a.types.a2a_pb2 import ( APIKeySecurityScheme, AgentCapabilities, @@ -36,35 +38,11 @@ SendMessageResponse, StringList, ) - - -class HeaderInterceptor(ClientCallInterceptor): - """A simple mock interceptor for testing basic middleware functionality.""" - - def __init__(self, header_name: str, header_value: str): - self.header_name = header_name - self.header_value = header_value - - async def intercept( - self, - method_name: str, - request_payload: dict[str, Any], - http_kwargs: dict[str, Any], - agent_card: AgentCard | None, - context: ClientCallContext | None, - ) -> tuple[dict[str, Any], dict[str, Any]]: - headers = http_kwargs.get('headers', {}) - headers[self.header_name] = self.header_value - http_kwargs['headers'] = headers - return request_payload, http_kwargs - - -from google.protobuf import json_format +from a2a.utils.constants import TransportProtocol def build_success_response(request: httpx.Request) -> httpx.Response: """Creates a valid JSON-RPC success response based on the request.""" - from a2a.types.a2a_pb2 import SendMessageResponse request_payload = json.loads(request.content) message = Message( @@ -120,19 +98,18 @@ async def test_auth_interceptor_skips_when_no_agent_card( store: InMemoryContextCredentialStore, ) -> None: """Tests that the AuthInterceptor does not modify the request when no AgentCard is provided.""" - request_payload = {'foo': 'bar'} - http_kwargs = {'fizz': 'buzz'} auth_interceptor = AuthInterceptor(credential_service=store) - - new_payload, new_kwargs = await auth_interceptor.intercept( - method_name='SendMessage', - request_payload=request_payload, - http_kwargs=http_kwargs, - agent_card=None, - context=ClientCallContext(state={}), + request = SendMessageRequest(message=Message()) + context = ClientCallContext(state={}) + args = BeforeArgs( + input=request, + method='send_message', + agent_card=AgentCard(), + context=context, ) - assert new_payload == request_payload - assert new_kwargs == http_kwargs + + await auth_interceptor.before(args) + assert context.service_parameters is None @pytest.mark.asyncio @@ -172,52 +149,17 @@ async def test_in_memory_context_credential_store( assert await store.get_credentials(scheme_name, context) == new_credential -@pytest.mark.skip( - reason='Interceptors not explicitly being tested as per use request' -) -@pytest.mark.asyncio -@respx.mock -async def test_client_with_simple_interceptor() -> None: - """Ensures that a custom HeaderInterceptor correctly injects a static header into outbound HTTP requests from the A2AClient.""" - url = 'http://agent.com/rpc' - interceptor = HeaderInterceptor('X-Test-Header', 'Test-Value-123') - card = AgentCard( - supported_interfaces=[ - AgentInterface(url=url, protocol_binding=TransportProtocol.JSONRPC) - ], - name='testbot', - description='test bot', - version='1.0', - default_input_modes=[], - default_output_modes=[], - skills=[], - capabilities=AgentCapabilities(), - ) - - async with httpx.AsyncClient() as http_client: - config = ClientConfig( - httpx_client=http_client, - supported_protocol_bindings=[TransportProtocol.JSONRPC], - ) - factory = ClientFactory(config) - client = factory.create(card, interceptors=[interceptor]) - - request = await send_message(client, url) - assert request.headers['x-test-header'] == 'Test-Value-123' - - def wrap_security_scheme(scheme: Any) -> SecurityScheme: """Wraps a security scheme in the correct SecurityScheme proto field.""" if isinstance(scheme, APIKeySecurityScheme): return SecurityScheme(api_key_security_scheme=scheme) - elif isinstance(scheme, HTTPAuthSecurityScheme): + if isinstance(scheme, HTTPAuthSecurityScheme): return SecurityScheme(http_auth_security_scheme=scheme) - elif isinstance(scheme, OAuth2SecurityScheme): + if isinstance(scheme, OAuth2SecurityScheme): return SecurityScheme(oauth2_security_scheme=scheme) - elif isinstance(scheme, OpenIdConnectSecurityScheme): + if isinstance(scheme, OpenIdConnectSecurityScheme): return SecurityScheme(open_id_connect_security_scheme=scheme) - else: - raise ValueError(f'Unknown security scheme type: {type(scheme)}') + raise ValueError(f'Unknown security scheme type: {type(scheme)}') @dataclass @@ -363,8 +305,6 @@ async def test_auth_interceptor_skips_when_scheme_not_in_security_schemes( scheme_name = 'missing' session_id = 'session-id' credential = 'test-token' - request_payload = {'foo': 'bar'} - http_kwargs = {'fizz': 'buzz'} await store.set_credentials(session_id, scheme_name, credential) auth_interceptor = AuthInterceptor(credential_service=store) agent_card = AgentCard( @@ -386,13 +326,14 @@ async def test_auth_interceptor_skips_when_scheme_not_in_security_schemes( ], security_schemes={}, ) - - new_payload, new_kwargs = await auth_interceptor.intercept( - method_name='SendMessage', - request_payload=request_payload, - http_kwargs=http_kwargs, + request = SendMessageRequest(message=Message()) + context = ClientCallContext(state={'sessionId': session_id}) + args = BeforeArgs( + input=request, + method='send_message', agent_card=agent_card, - context=ClientCallContext(state={'sessionId': session_id}), + context=context, ) - assert new_payload == request_payload - assert new_kwargs == http_kwargs + + await auth_interceptor.before(args) + assert context.service_parameters is None diff --git a/tests/client/test_base_client.py b/tests/client/test_base_client.py index a278eb7fe..4aa243377 100644 --- a/tests/client/test_base_client.py +++ b/tests/client/test_base_client.py @@ -73,7 +73,7 @@ def base_client( config=config, transport=mock_transport, consumers=[], - middleware=[], + interceptors=[], ) diff --git a/tests/client/test_base_client_interceptors.py b/tests/client/test_base_client_interceptors.py new file mode 100644 index 000000000..0e7328440 --- /dev/null +++ b/tests/client/test_base_client_interceptors.py @@ -0,0 +1,241 @@ +# ruff: noqa: INP001 +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from a2a.client.base_client import BaseClient +from a2a.client.client import ClientConfig +from a2a.client.interceptors import ( + AfterArgs, + BeforeArgs, + ClientCallInterceptor, +) +from a2a.client.transports.base import ClientTransport +from a2a.types.a2a_pb2 import ( + AgentCapabilities, + AgentCard, + AgentInterface, + Message, + StreamResponse, +) + + +@pytest.fixture +def mock_transport() -> AsyncMock: + return AsyncMock(spec=ClientTransport) + + +@pytest.fixture +def sample_agent_card() -> AgentCard: + return AgentCard( + name='Test Agent', + description='An agent for testing', + supported_interfaces=[ + AgentInterface(url='http://test.com', protocol_binding='HTTP+JSON') + ], + version='1.0', + capabilities=AgentCapabilities(streaming=True), + default_input_modes=['text/plain'], + default_output_modes=['text/plain'], + skills=[], + ) + + +@pytest.fixture +def mock_interceptor() -> AsyncMock: + return AsyncMock(spec=ClientCallInterceptor) + + +@pytest.fixture +def base_client( + sample_agent_card: AgentCard, + mock_transport: AsyncMock, + mock_interceptor: AsyncMock, +) -> BaseClient: + config = ClientConfig(streaming=True) + return BaseClient( + card=sample_agent_card, + config=config, + transport=mock_transport, + consumers=[], + interceptors=[mock_interceptor], + ) + + +class TestBaseClientInterceptors: + @pytest.mark.asyncio + async def test_execute_with_interceptors_normal_flow( + self, + base_client: BaseClient, + mock_interceptor: AsyncMock, + ): + input_data = MagicMock() + method = 'get_task' + context = MagicMock() + mock_transport_call = AsyncMock(return_value='transport_result') + + # Set up mock interceptor to just pass through + mock_interceptor.before.return_value = None + + result = await base_client._execute_with_interceptors( + input_data=input_data, + method=method, + context=context, + transport_call=mock_transport_call, + ) + + assert result == 'transport_result' + + # Verify before was called + mock_interceptor.before.assert_called_once() + before_args = mock_interceptor.before.call_args[0][0] + assert isinstance(before_args, BeforeArgs) + assert before_args.input == input_data + assert before_args.context == context + + # Verify transport call was made + mock_transport_call.assert_called_once_with(input_data, context) + + # Verify after was called + mock_interceptor.after.assert_called_once() + after_args = mock_interceptor.after.call_args[0][0] + assert isinstance(after_args, AfterArgs) + assert after_args.method == method + assert after_args.result == 'transport_result' + assert after_args.context == context + + @pytest.mark.asyncio + async def test_execute_with_interceptors_early_return( + self, + base_client: BaseClient, + mock_interceptor: AsyncMock, + ): + input_data = MagicMock() + method = 'get_task' + context = MagicMock() + mock_transport_call = AsyncMock() + + # Set up early return in before + early_return_result = 'early_result' + + async def mock_before_with_early_return(args: BeforeArgs): + args.early_return = early_return_result + + mock_interceptor.before.side_effect = mock_before_with_early_return + + result = await base_client._execute_with_interceptors( + input_data=input_data, + method=method, + context=context, + transport_call=mock_transport_call, + ) + + assert result == 'early_result' + + # Verify before was called + mock_interceptor.before.assert_called_once() + + # Verify transport call was NOT made + mock_transport_call.assert_not_called() + + # Verify after was called with early return value + mock_interceptor.after.assert_called_once() + after_args = mock_interceptor.after.call_args[0][0] + assert isinstance(after_args, AfterArgs) + assert after_args.result == 'early_result' + assert after_args.context == context + + @pytest.mark.asyncio + async def test_execute_stream_with_interceptors_normal_flow( + self, + base_client: BaseClient, + mock_interceptor: AsyncMock, + ): + input_data = MagicMock() + method = 'send_message_streaming' + context = MagicMock() + + async def mock_transport_call(*args, **kwargs): + yield StreamResponse(message=Message(message_id='1')) + + # Set up mock interceptor to just pass through + mock_interceptor.before.return_value = None + + events = [ + e + async for e in base_client._execute_stream_with_interceptors( + input_data=input_data, + method=method, + context=context, + transport_call=mock_transport_call, + ) + ] + + assert len(events) == 1 + + # Verify before was called + mock_interceptor.before.assert_called_once() + before_args = mock_interceptor.before.call_args[0][0] + assert isinstance(before_args, BeforeArgs) + assert before_args.input == input_data + assert before_args.context == context + + # Verify after was called + mock_interceptor.after.assert_called_once() + after_args = mock_interceptor.after.call_args[0][0] + assert isinstance(after_args, AfterArgs) + assert after_args.method == method + + @pytest.mark.asyncio + async def test_execute_stream_with_interceptors_early_return( + self, + base_client: BaseClient, + mock_interceptor: AsyncMock, + ): + input_data = MagicMock() + method = 'send_message_streaming' + context = MagicMock() + mock_transport_call = AsyncMock() + + # Set up early return in before + early_return_result = StreamResponse(message=Message(message_id='2')) + + async def mock_before_with_early_return(args: BeforeArgs): + args.early_return = early_return_result + return { + 'early_return': early_return_result, + 'executed': [mock_interceptor], + } + + mock_interceptor.before.side_effect = mock_before_with_early_return + + # Override BaseClient's _intercept_before to respect our early return setup + # as the test's mock interceptor replaces the actual list items + base_client._intercept_before = AsyncMock( # type: ignore + return_value={ + 'early_return': early_return_result, + 'executed': [mock_interceptor], + } + ) + + events = [ + e + async for e in base_client._execute_stream_with_interceptors( + input_data=input_data, + method=method, + context=context, + transport_call=mock_transport_call, + ) + ] + + assert len(events) == 1 + + # Verify transport call was NOT made + mock_transport_call.assert_not_called() + + # Verify after was called with early return value + mock_interceptor.after.assert_called_once() + after_args = mock_interceptor.after.call_args[0][0] + assert isinstance(after_args, AfterArgs) + assert after_args.method == method + assert after_args.context == context diff --git a/tests/client/test_client_factory_grpc.py b/tests/client/test_client_factory_grpc.py index 1e7563248..47423d0ab 100644 --- a/tests/client/test_client_factory_grpc.py +++ b/tests/client/test_client_factory_grpc.py @@ -60,7 +60,7 @@ def test_grpc_priority_1_0(grpc_agent_card): # Priority 1: 1.0 -> GrpcTransport mock_grpc.create.assert_called_once_with( - grpc_agent_card, 'url10', config, [] + grpc_agent_card, 'url10', config ) mock_compat.create.assert_not_called() @@ -101,7 +101,7 @@ def test_grpc_priority_gt_1_0(grpc_agent_card): # Priority 2: > 1.0 -> GrpcTransport (first matching is 1.1) mock_grpc.create.assert_called_once_with( - grpc_agent_card, 'url11', config, [] + grpc_agent_card, 'url11', config ) mock_compat.create.assert_not_called() @@ -171,5 +171,5 @@ def test_grpc_unspecified_version_uses_grpc_transport(grpc_agent_card): factory.create(grpc_agent_card) mock_grpc.create.assert_called_once_with( - grpc_agent_card, 'url_no_version', config, [] + grpc_agent_card, 'url_no_version', config ) diff --git a/tests/client/transports/test_grpc_client.py b/tests/client/transports/test_grpc_client.py index 506d33d6e..9e81bd71e 100644 --- a/tests/client/transports/test_grpc_client.py +++ b/tests/client/transports/test_grpc_client.py @@ -6,7 +6,7 @@ from google.protobuf import any_pb2 from google.rpc import error_details_pb2, status_pb2 -from a2a.client.middleware import ClientCallContext +from a2a.client.client import ClientCallContext from a2a.client.transports.grpc import GrpcTransport from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.utils.constants import VERSION_HEADER, PROTOCOL_VERSION_CURRENT @@ -230,7 +230,7 @@ async def test_send_message_with_timeout_context( sample_task: Task, ) -> None: """Test send_message passes context timeout to grpc stub.""" - from a2a.client.middleware import ClientCallContext + from a2a.client.client import ClientCallContext mock_grpc_stub.SendMessage.return_value = a2a_pb2.SendMessageResponse( task=sample_task diff --git a/tests/client/transports/test_jsonrpc_client.py b/tests/client/transports/test_jsonrpc_client.py index e5de809db..b568865e6 100644 --- a/tests/client/transports/test_jsonrpc_client.py +++ b/tests/client/transports/test_jsonrpc_client.py @@ -117,17 +117,6 @@ def test_init_with_agent_card(self, mock_httpx_client, agent_card): assert transport.url == 'http://test-agent.example.com' assert transport.agent_card == agent_card - def test_init_with_interceptors(self, mock_httpx_client, agent_card): - """Test initialization with interceptors.""" - interceptor = MagicMock() - transport = JsonRpcTransport( - httpx_client=mock_httpx_client, - agent_card=agent_card, - url='http://test-agent.example.com', - interceptors=[interceptor], - ) - assert transport.interceptors == [interceptor] - class TestSendMessage: """Tests for the send_message method.""" @@ -229,7 +218,7 @@ async def test_send_message_with_timeout_context( self, transport, mock_httpx_client ): """Test that send_message passes context timeout to build_request.""" - from a2a.client.middleware import ClientCallContext + from a2a.client.client import ClientCallContext mock_response = MagicMock() mock_response.json.return_value = { @@ -544,7 +533,7 @@ async def test_extensions_added_to_request( request = create_send_message_request() - from a2a.client.middleware import ClientCallContext + from a2a.client.client import ClientCallContext context = ClientCallContext( service_parameters={'X-A2A-Extensions': 'https://example.com/ext1'} @@ -631,7 +620,7 @@ async def test_get_card_with_extended_card_support_with_extensions( 'result': json_format.MessageToDict(extended_card), } - from a2a.client.middleware import ClientCallContext + from a2a.client.client import ClientCallContext context = ClientCallContext( service_parameters={HTTP_EXTENSION_HEADER: extensions_header_val} diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index ec29ddc56..d76873918 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -148,7 +148,7 @@ async def test_send_message_with_timeout_context( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): """Test that send_message passes context timeout to build_request.""" - from a2a.client.middleware import ClientCallContext + from a2a.client.client import ClientCallContext client = RestTransport( httpx_client=mock_httpx_client, @@ -244,7 +244,7 @@ async def test_send_message_with_default_extensions( mock_response.status_code = 200 mock_httpx_client.send.return_value = mock_response - from a2a.client.middleware import ClientCallContext + from a2a.client.client import ClientCallContext context = ClientCallContext( service_parameters={ @@ -288,7 +288,7 @@ async def test_send_message_streaming_with_new_extensions( mock_event_source ) - from a2a.client.middleware import ClientCallContext + from a2a.client.client import ClientCallContext context = ClientCallContext( service_parameters={ @@ -390,7 +390,7 @@ async def test_get_card_with_extended_card_support_with_extensions( request = GetExtendedAgentCardRequest() - from a2a.client.middleware import ClientCallContext + from a2a.client.client import ClientCallContext context = ClientCallContext( service_parameters={HTTP_EXTENSION_HEADER: extensions_str} diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index 82c14ce6d..e239d780f 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -17,7 +17,7 @@ from a2a.client.base_client import BaseClient from a2a.client.card_resolver import A2ACardResolver from a2a.client.client_factory import ClientFactory -from a2a.client.middleware import ClientCallContext +from a2a.client.client import ClientCallContext from a2a.client.service_parameters import ( ServiceParametersFactory, with_a2a_extensions, @@ -545,7 +545,7 @@ async def test_json_transport_base_client_send_message_with_extensions( config=ClientConfig(streaming=False), transport=transport, consumers=[], - middleware=[], + interceptors=[], ) message_to_send = Message( @@ -705,7 +705,7 @@ async def test_client_get_signed_extended_card( config=ClientConfig(streaming=False), transport=transport, consumers=[], - middleware=[], + interceptors=[], ) signature_verifier = create_signature_verifier( @@ -791,7 +791,7 @@ async def test_client_get_signed_base_and_extended_cards( config=ClientConfig(streaming=False), transport=transport, consumers=[], - middleware=[], + interceptors=[], ) # 3. Fetch extended card via client