Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
e38d44f
wip
guglielmo-san Mar 9, 2026
bba49a6
Merge remote-tracking branch 'upstream/1.0-dev' into guglielmoc/rewor…
guglielmo-san Mar 9, 2026
61b7304
feat: Implement and apply interceptors directly within BaseClient met…
guglielmo-san Mar 9, 2026
c8b2550
wip
guglielmo-san Mar 10, 2026
68db81e
add tests
guglielmo-san Mar 10, 2026
1235d9f
run ruff
guglielmo-san Mar 10, 2026
4967d97
add cast
guglielmo-san Mar 10, 2026
eaf1792
refactor: Simplify BeforeArgs initialization in send_message_streamin…
guglielmo-san Mar 10, 2026
1d5319c
refactor: Add explicit type hints to `BeforeArgs` and remove redundan…
guglielmo-san Mar 10, 2026
f379679
revert change
guglielmo-san Mar 10, 2026
8affcb3
refactor: Rename `CallInterceptor` to `ClientCallInterceptor` and sim…
guglielmo-san Mar 10, 2026
ef658ff
fix
guglielmo-san Mar 10, 2026
ab134b3
refactor: centralize stream interception and execution logic into `_e…
guglielmo-san Mar 10, 2026
92bc02f
refactor: Update stream event processing to use a task manager, adjus…
guglielmo-san Mar 10, 2026
0dad45d
refactor: Migrate authentication logic to the new interceptor pattern…
guglielmo-san Mar 10, 2026
191c970
refactor: simplify interceptor argument types by removing generic typ…
guglielmo-san Mar 10, 2026
2d21b05
Merge branch '1.0-dev' into guglielmoc/rework_client_interceptors
guglielmo-san Mar 10, 2026
afb7df9
refactor: make ClientCallInterceptor an abstract base class
guglielmo-san Mar 13, 2026
021c8b3
Merge remote-tracking branch 'refs/remotes/origin/guglielmoc/rework_c…
guglielmo-san Mar 13, 2026
b54e2ea
Merge remote-tracking branch 'upstream/1.0-dev' into guglielmoc/rewor…
guglielmo-san Mar 13, 2026
19da397
fix
guglielmo-san Mar 13, 2026
8f562e4
refactor
guglielmo-san Mar 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/a2a/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,21 @@
)
from a2a.client.base_client import BaseClient
from a2a.client.card_resolver import A2ACardResolver
from a2a.client.client import Client, ClientConfig, ClientEvent, Consumer
from a2a.client.client import (
Client,
ClientCallContext,
ClientConfig,
ClientEvent,
Consumer,
)
from a2a.client.client_factory import ClientFactory, minimal_agent_card
from a2a.client.errors import (
A2AClientError,
A2AClientTimeoutError,
AgentCardResolutionError,
)
from a2a.client.helpers import create_text_message_object
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
from a2a.client.interceptors import ClientCallInterceptor


logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion src/a2a/client/auth/credentials.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod

from a2a.client.middleware import ClientCallContext
from a2a.client.client import ClientCallContext


class CredentialService(ABC):
Expand Down
59 changes: 31 additions & 28 deletions src/a2a/client/auth/interceptor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import logging # noqa: I001
from typing import Any

from a2a.client.auth.credentials import CredentialService
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
from a2a.types.a2a_pb2 import AgentCard
from a2a.client.client import ClientCallContext
from a2a.client.interceptors import (
AfterArgs,
BeforeArgs,
ClientCallInterceptor,
)

logger = logging.getLogger(__name__)

Expand All @@ -17,79 +20,79 @@ class AuthInterceptor(ClientCallInterceptor):
def __init__(self, credential_service: CredentialService):
self._credential_service = credential_service

async def intercept(
self,
method_name: str,
request_payload: dict[str, Any],
http_kwargs: dict[str, Any],
agent_card: AgentCard | None,
context: ClientCallContext | None,
) -> tuple[dict[str, Any], dict[str, Any]]:
async def before(self, args: BeforeArgs) -> None:
"""Applies authentication headers to the request if credentials are available."""
agent_card = args.agent_card

# Proto3 repeated fields (security) and maps (security_schemes) do not track presence.
# HasField() raises ValueError for them.
# We check for truthiness to see if they are non-empty.
if (
agent_card is None
or not agent_card.security_requirements
not agent_card.security_requirements
or not agent_card.security_schemes
):
return request_payload, http_kwargs
return

for requirement in agent_card.security_requirements:
for scheme_name in requirement.schemes:
credential = await self._credential_service.get_credentials(
scheme_name, context
scheme_name, args.context
)
if credential and scheme_name in agent_card.security_schemes:
scheme = agent_card.security_schemes.get(scheme_name)
if not scheme:
continue

headers = http_kwargs.get('headers', {})
if args.context is None:
args.context = ClientCallContext()

if args.context.service_parameters is None:
args.context.service_parameters = {}

# HTTP Bearer authentication
if (
scheme.HasField('http_auth_security_scheme')
and scheme.http_auth_security_scheme.scheme.lower()
== 'bearer'
):
headers['Authorization'] = f'Bearer {credential}'
args.context.service_parameters['Authorization'] = (
f'Bearer {credential}'
)
logger.debug(
"Added Bearer token for scheme '%s'.",
scheme_name,
)
http_kwargs['headers'] = headers
return request_payload, http_kwargs
return

# OAuth2 and OIDC schemes are implicitly Bearer
if scheme.HasField(
'oauth2_security_scheme'
) or scheme.HasField('open_id_connect_security_scheme'):
headers['Authorization'] = f'Bearer {credential}'
args.context.service_parameters['Authorization'] = (
f'Bearer {credential}'
)
logger.debug(
"Added Bearer token for scheme '%s'.",
scheme_name,
)
http_kwargs['headers'] = headers
return request_payload, http_kwargs
return

# API Key in Header
if (
scheme.HasField('api_key_security_scheme')
and scheme.api_key_security_scheme.location.lower()
== 'header'
):
headers[scheme.api_key_security_scheme.name] = (
credential
)
args.context.service_parameters[
scheme.api_key_security_scheme.name
] = credential
logger.debug(
"Added API Key Header for scheme '%s'.",
scheme_name,
)
http_kwargs['headers'] = headers
return request_payload, http_kwargs
return

# Note: Other cases like API keys in query/cookie are not handled and will be skipped.

return request_payload, http_kwargs
async def after(self, args: AfterArgs) -> None:
"""Invoked after the method is executed."""
Loading
Loading