Skip to content

Commit e4b237a

Browse files
committed
fix: clean up legacy grpc error mapping and status handling
- Removed unreleased legacy string-split fallback parsing for gRPC errors. - Refactored client and server handlers to use pure grpcio-status helpers for encoding/decoding google.rpc.Status to standard trailing metadata. - Cleaned up A2A_ERROR_REASONS map to strictly contain only the 9 domain errors defined in the specification (removing generic JSON-RPC errors). - Added grpcio-status to pyproject.toml dependencies.
1 parent fdb156b commit e4b237a

8 files changed

Lines changed: 75 additions & 110 deletions

File tree

.jscpd.json

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
{
2-
"ignore": ["**/.github/**", "**/.git/**", "**/tests/**", "**/src/a2a/grpc/**", "**/.nox/**", "**/.venv/**"],
2+
"ignore": [
3+
"**/.github/**",
4+
"**/.git/**",
5+
"**/tests/**",
6+
"**/src/a2a/grpc/**",
7+
"**/src/a2a/compat/**",
8+
"**/.nox/**",
9+
"**/.venv/**"
10+
],
311
"threshold": 3,
412
"reporters": ["html", "markdown"]
513
}

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ classifiers = [
3333
[project.optional-dependencies]
3434
http-server = ["fastapi>=0.115.2", "sse-starlette", "starlette"]
3535
encryption = ["cryptography>=43.0.0"]
36-
grpc = ["grpcio>=1.60", "grpcio-tools>=1.60", "grpcio_reflection>=1.7.0"]
36+
grpc = ["grpcio>=1.60", "grpcio-tools>=1.60", "grpcio-status>=1.60", "grpcio_reflection>=1.7.0"]
3737
telemetry = ["opentelemetry-api>=1.33.0", "opentelemetry-sdk>=1.33.0"]
3838
postgresql = ["sqlalchemy[asyncio,postgresql-asyncpg]>=2.0.0"]
3939
mysql = ["sqlalchemy[asyncio,aiomysql]>=2.0.0"]

src/a2a/client/transports/grpc.py

Lines changed: 13 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,27 @@
11
import logging
22

3-
from collections.abc import AsyncGenerator, Callable, Iterable
3+
from collections.abc import AsyncGenerator, Callable
44
from functools import wraps
55
from typing import Any, NoReturn, cast
66

77
from a2a.client.errors import A2AClientError, A2AClientTimeoutError
88
from a2a.client.middleware import ClientCallContext
9-
from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP, A2AError
109

1110

1211
try:
1312
import grpc # type: ignore[reportMissingModuleSource]
13+
14+
from grpc_status import rpc_status
1415
except ImportError as e:
1516
raise ImportError(
16-
'A2AGrpcClient requires grpcio and grpcio-tools to be installed. '
17+
'A2AGrpcClient requires grpcio, grpcio-tools, and grpcio-status to be installed. '
1718
'Install with: '
1819
"'pip install a2a-sdk[grpc]'"
1920
) from e
2021

2122

2223
from google.rpc import ( # type: ignore[reportMissingModuleSource]
2324
error_details_pb2,
24-
status_pb2,
2525
)
2626

2727
from a2a.client.client import ClientConfig
@@ -55,16 +55,16 @@
5555

5656
logger = logging.getLogger(__name__)
5757

58-
_A2A_ERROR_NAME_TO_CLS = {
59-
error_type.__name__: error_type for error_type in JSON_RPC_ERROR_CODE_MAP
60-
}
6158

59+
def _map_grpc_error(e: grpc.aio.AioRpcError) -> NoReturn:
60+
61+
if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
62+
raise A2AClientTimeoutError('Client Request timed out') from e
63+
64+
# 1. Use grpc_status to cleanly extract the rich Status from the call
65+
status = rpc_status.from_call(cast('grpc.Call', e))
6266

63-
def _parse_rich_grpc_error(
64-
value: bytes, original_error: grpc.aio.AioRpcError
65-
) -> None:
66-
try:
67-
status = status_pb2.Status.FromString(value)
67+
if status is not None:
6868
for detail in status.details:
6969
if detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR):
7070
error_info = error_details_pb2.ErrorInfo()
@@ -73,34 +73,8 @@ def _parse_rich_grpc_error(
7373
if error_info.domain == 'a2a-protocol.org':
7474
exception_cls = A2A_REASON_TO_ERROR.get(error_info.reason)
7575
if exception_cls:
76-
raise exception_cls(status.message) from original_error # noqa: TRY301
77-
except Exception as parse_e:
78-
# Don't swallow A2A errors generated above
79-
if isinstance(parse_e, (A2AError, A2AClientError)):
80-
raise parse_e
81-
logger.warning(
82-
'Failed to parse grpc-status-details-bin', exc_info=parse_e
83-
)
84-
85-
86-
def _map_grpc_error(e: grpc.aio.AioRpcError) -> NoReturn:
87-
if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
88-
raise A2AClientTimeoutError('Client Request timed out') from e
76+
raise exception_cls(status.message) from e
8977

90-
metadata = e.trailing_metadata()
91-
if metadata:
92-
iterable_metadata = cast('Iterable[tuple[str, str | bytes]]', metadata)
93-
for key, value in iterable_metadata:
94-
if key == 'grpc-status-details-bin' and isinstance(value, bytes):
95-
_parse_rich_grpc_error(value, e)
96-
97-
details = e.details()
98-
if isinstance(details, str) and ': ' in details:
99-
error_type_name, error_message = details.split(': ', 1)
100-
# Leaving as fallback for errors that don't use the rich error details.
101-
exception_cls = _A2A_ERROR_NAME_TO_CLS.get(error_type_name)
102-
if exception_cls:
103-
raise exception_cls(error_message) from e
10478
raise A2AClientError(f'gRPC Error {e.code().name}: {e.details()}') from e
10579

10680

src/a2a/server/request_handlers/grpc_handler.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,16 @@
44

55
from abc import ABC, abstractmethod
66
from collections.abc import AsyncIterable, Awaitable, Callable
7-
from typing import cast
87

98

109
try:
1110
import grpc # type: ignore[reportMissingModuleSource]
1211
import grpc.aio # type: ignore[reportMissingModuleSource]
12+
13+
from grpc_status import rpc_status
1314
except ImportError as e:
1415
raise ImportError(
15-
'GrpcHandler requires grpcio and grpcio-tools to be installed. '
16+
'GrpcHandler requires grpcio, grpcio-tools, and grpcio-status to be installed. '
1617
'Install with: '
1718
"'pip install a2a-sdk[grpc]'"
1819
) from e
@@ -420,37 +421,40 @@ async def abort_context(
420421
"""Sets the grpc errors appropriately in the context."""
421422
code = _ERROR_CODE_MAP.get(type(error))
422423

423-
status_value = code.value if code else grpc.StatusCode.UNKNOWN.value
424-
status_code = (
425-
status_value[0] if isinstance(status_value, tuple) else status_value
426-
)
427-
error_msg = error.message if hasattr(error, 'message') else str(error)
428-
status = status_pb2.Status(code=status_code, message=error_msg)
429-
430424
if code:
431425
reason = A2A_ERROR_REASONS.get(type(error), 'UNKNOWN_ERROR')
432-
433426
error_info = error_details_pb2.ErrorInfo(
434427
reason=reason,
435428
domain='a2a-protocol.org',
436429
)
437430

431+
status_code = (
432+
code.value[0] if code else grpc.StatusCode.UNKNOWN.value[0]
433+
)
434+
error_msg = (
435+
error.message if hasattr(error, 'message') else str(error)
436+
)
437+
438+
# Create standard Status and pack the ErrorInfo
439+
status = status_pb2.Status(code=status_code, message=error_msg)
438440
detail = any_pb2.Any()
439441
detail.Pack(error_info)
440442
status.details.append(detail)
441443

442-
context.set_trailing_metadata(
443-
cast(
444-
'tuple[tuple[str, str | bytes], ...]',
445-
(('grpc-status-details-bin', status.SerializeToString()),),
446-
)
447-
)
444+
# Use grpc_status to safely generate standard trailing metadata
445+
rich_status = rpc_status.to_status(status)
448446

449-
if code:
450-
await context.abort(
451-
code,
452-
status.message,
453-
)
447+
new_metadata: list[tuple[str, str | bytes]] = []
448+
trailing = context.trailing_metadata()
449+
if trailing:
450+
for k, v in trailing:
451+
new_metadata.append((str(k), v))
452+
453+
for k, v in rich_status.trailing_metadata:
454+
new_metadata.append((str(k), v))
455+
456+
context.set_trailing_metadata(tuple(new_metadata))
457+
await context.abort(rich_status.code, rich_status.details)
454458
else:
455459
await context.abort(
456460
grpc.StatusCode.UNKNOWN,

src/a2a/server/request_handlers/response_helpers.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
SendMessageResponse as SendMessageResponseProto,
2828
)
2929
from a2a.utils.errors import (
30+
JSON_RPC_ERROR_CODE_MAP,
3031
A2AError,
3132
AuthenticatedExtendedCardNotConfiguredError,
3233
ContentTypeNotSupportedError,
@@ -56,19 +57,6 @@
5657
InternalError: JSONRPCInternalError,
5758
}
5859

59-
ERROR_CODE_MAP: dict[type[A2AError], int] = {
60-
TaskNotFoundError: -32001,
61-
TaskNotCancelableError: -32002,
62-
PushNotificationNotSupportedError: -32003,
63-
UnsupportedOperationError: -32004,
64-
ContentTypeNotSupportedError: -32005,
65-
InvalidAgentResponseError: -32006,
66-
AuthenticatedExtendedCardNotConfiguredError: -32007,
67-
InvalidParamsError: -32602,
68-
InvalidRequestError: -32600,
69-
MethodNotFoundError: -32601,
70-
}
71-
7260

7361
# Tuple of all A2AError types for isinstance checks
7462
_A2A_ERROR_TYPES: tuple[type, ...] = (A2AError,)
@@ -136,7 +124,7 @@ def build_error_response(
136124
elif isinstance(error, A2AError):
137125
error_type = type(error)
138126
model_class = EXCEPTION_MAP.get(error_type, JSONRPCInternalError)
139-
code = ERROR_CODE_MAP.get(error_type, -32603)
127+
code = JSON_RPC_ERROR_CODE_MAP.get(error_type, -32603)
140128
jsonrpc_error = model_class(
141129
code=code,
142130
message=str(error),

src/a2a/utils/errors.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,10 +140,6 @@ class VersionNotSupportedError(A2AError):
140140
AuthenticatedExtendedCardNotConfiguredError: 'EXTENDED_AGENT_CARD_NOT_CONFIGURED',
141141
ExtensionSupportRequiredError: 'EXTENSION_SUPPORT_REQUIRED',
142142
VersionNotSupportedError: 'VERSION_NOT_SUPPORTED',
143-
InvalidParamsError: 'INVALID_PARAMS',
144-
InvalidRequestError: 'INVALID_REQUEST',
145-
MethodNotFoundError: 'METHOD_NOT_FOUND',
146-
InternalError: 'INTERNAL_ERROR',
147143
}
148144

149145
A2A_REASON_TO_ERROR = {reason: cls for cls, reason in A2A_ERROR_REASONS.items()}

tests/client/transports/test_grpc_client.py

Lines changed: 2 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
TaskStatusUpdateEvent,
3737
)
3838
from a2a.utils import get_text_parts
39-
from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP
4039

4140

4241
@pytest.fixture
@@ -259,29 +258,7 @@ async def test_send_message_with_timeout_context(
259258
assert kwargs['timeout'] == 12.5
260259

261260

262-
@pytest.mark.parametrize('error_cls', list(JSON_RPC_ERROR_CODE_MAP.keys()))
263-
@pytest.mark.asyncio
264-
async def test_grpc_mapped_errors_legacy(
265-
grpc_transport: GrpcTransport,
266-
mock_grpc_stub: AsyncMock,
267-
sample_message_send_params: SendMessageRequest,
268-
error_cls,
269-
) -> None:
270-
"""Test handling of legacy gRPC error responses."""
271-
error_details = f'{error_cls.__name__}: Mapped Error'
272-
273-
mock_grpc_stub.SendMessage.side_effect = grpc.aio.AioRpcError(
274-
code=grpc.StatusCode.INTERNAL,
275-
initial_metadata=grpc.aio.Metadata(),
276-
trailing_metadata=grpc.aio.Metadata(),
277-
details=error_details,
278-
)
279-
280-
with pytest.raises(error_cls):
281-
await grpc_transport.send_message(sample_message_send_params)
282-
283-
284-
@pytest.mark.parametrize('error_cls', list(JSON_RPC_ERROR_CODE_MAP.keys()))
261+
@pytest.mark.parametrize('error_cls', list(A2A_ERROR_REASONS.keys()))
285262
@pytest.mark.asyncio
286263
async def test_grpc_mapped_errors_rich(
287264
grpc_transport: GrpcTransport,
@@ -312,7 +289,7 @@ async def test_grpc_mapped_errors_rich(
312289
trailing_metadata=grpc.aio.Metadata(
313290
('grpc-status-details-bin', status.SerializeToString()),
314291
),
315-
details='A generic error message',
292+
details=error_details,
316293
)
317294

318295
with pytest.raises(error_cls) as excinfo:

uv.lock

Lines changed: 23 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)