diff --git a/.jscpd.json b/.jscpd.json index 5a6fcad7..ed59a649 100644 --- a/.jscpd.json +++ b/.jscpd.json @@ -1,5 +1,13 @@ { - "ignore": ["**/.github/**", "**/.git/**", "**/tests/**", "**/src/a2a/grpc/**", "**/.nox/**", "**/.venv/**"], + "ignore": [ + "**/.github/**", + "**/.git/**", + "**/tests/**", + "**/src/a2a/grpc/**", + "**/src/a2a/compat/**", + "**/.nox/**", + "**/.venv/**" + ], "threshold": 3, "reporters": ["html", "markdown"] } diff --git a/pyproject.toml b/pyproject.toml index 370315e1..c57824ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ classifiers = [ [project.optional-dependencies] http-server = ["fastapi>=0.115.2", "sse-starlette", "starlette"] encryption = ["cryptography>=43.0.0"] -grpc = ["grpcio>=1.60", "grpcio-tools>=1.60", "grpcio_reflection>=1.7.0"] +grpc = ["grpcio>=1.60", "grpcio-tools>=1.60", "grpcio-status>=1.60", "grpcio_reflection>=1.7.0"] telemetry = ["opentelemetry-api>=1.33.0", "opentelemetry-sdk>=1.33.0"] postgresql = ["sqlalchemy[asyncio,postgresql-asyncpg]>=2.0.0"] mysql = ["sqlalchemy[asyncio,aiomysql]>=2.0.0"] diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index b33e5d34..5ca1ac4f 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -2,24 +2,29 @@ from collections.abc import AsyncGenerator, Callable from functools import wraps -from typing import Any, NoReturn +from typing import Any, NoReturn, cast +from a2a.client.errors import A2AClientError, A2AClientTimeoutError from a2a.client.middleware import ClientCallContext -from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP try: import grpc # type: ignore[reportMissingModuleSource] + + from grpc_status import rpc_status except ImportError as e: raise ImportError( - 'A2AGrpcClient requires grpcio and grpcio-tools to be installed. ' + 'A2AGrpcClient requires grpcio, grpcio-tools, and grpcio-status to be installed. ' 'Install with: ' "'pip install a2a-sdk[grpc]'" ) from e +from google.rpc import ( # type: ignore[reportMissingModuleSource] + error_details_pb2, +) + from a2a.client.client import ClientConfig -from a2a.client.errors import A2AClientError, A2AClientTimeoutError from a2a.client.middleware import ClientCallInterceptor from a2a.client.optionals import Channel from a2a.client.transports.base import ClientTransport @@ -43,27 +48,32 @@ TaskPushNotificationConfig, ) from a2a.utils.constants import PROTOCOL_VERSION_CURRENT, VERSION_HEADER +from a2a.utils.errors import A2A_REASON_TO_ERROR from a2a.utils.telemetry import SpanKind, trace_class logger = logging.getLogger(__name__) -_A2A_ERROR_NAME_TO_CLS = { - error_type.__name__: error_type for error_type in JSON_RPC_ERROR_CODE_MAP -} - def _map_grpc_error(e: grpc.aio.AioRpcError) -> NoReturn: + if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED: raise A2AClientTimeoutError('Client Request timed out') from e - details = e.details() - if isinstance(details, str) and ': ' in details: - error_type_name, error_message = details.split(': ', 1) - # TODO(#723): Resolving imports by name is temporary until proper error handling structure is added in #723. - exception_cls = _A2A_ERROR_NAME_TO_CLS.get(error_type_name) - if exception_cls: - raise exception_cls(error_message) from e + # Use grpc_status to cleanly extract the rich Status from the call + status = rpc_status.from_call(cast('grpc.Call', e)) + + if status is not None: + for detail in status.details: + if detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR): + error_info = error_details_pb2.ErrorInfo() + detail.Unpack(error_info) + + if error_info.domain == 'a2a-protocol.org': + exception_cls = A2A_REASON_TO_ERROR.get(error_info.reason) + if exception_cls: + raise exception_cls(status.message) from e + raise A2AClientError(f'gRPC Error {e.code().name}: {e.details()}') from e diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index d6348aa9..551891ee 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -3,22 +3,23 @@ import logging from abc import ABC, abstractmethod -from collections.abc import AsyncIterable, Awaitable +from collections.abc import AsyncIterable, Awaitable, Callable try: import grpc # type: ignore[reportMissingModuleSource] import grpc.aio # type: ignore[reportMissingModuleSource] + + from grpc_status import rpc_status except ImportError as e: raise ImportError( - 'GrpcHandler requires grpcio and grpcio-tools to be installed. ' + 'GrpcHandler requires grpcio, grpcio-tools, and grpcio-status to be installed. ' 'Install with: ' "'pip install a2a-sdk[grpc]'" ) from e -from collections.abc import Callable - -from google.protobuf import empty_pb2, message +from google.protobuf import any_pb2, empty_pb2, message +from google.rpc import error_details_pb2, status_pb2 import a2a.types.a2a_pb2_grpc as a2a_grpc @@ -33,7 +34,7 @@ from a2a.types import a2a_pb2 from a2a.types.a2a_pb2 import AgentCard from a2a.utils import proto_utils -from a2a.utils.errors import A2AError, TaskNotFoundError +from a2a.utils.errors import A2A_ERROR_REASONS, A2AError, TaskNotFoundError from a2a.utils.helpers import maybe_await, validate, validate_async_generator @@ -419,11 +420,41 @@ async def abort_context( ) -> None: """Sets the grpc errors appropriately in the context.""" code = _ERROR_CODE_MAP.get(type(error)) + if code: - await context.abort( - code, - f'{type(error).__name__}: {error.message}', + reason = A2A_ERROR_REASONS.get(type(error), 'UNKNOWN_ERROR') + error_info = error_details_pb2.ErrorInfo( + reason=reason, + domain='a2a-protocol.org', + ) + + status_code = ( + code.value[0] if code else grpc.StatusCode.UNKNOWN.value[0] ) + error_msg = ( + error.message if hasattr(error, 'message') else str(error) + ) + + # Create standard Status and pack the ErrorInfo + status = status_pb2.Status(code=status_code, message=error_msg) + detail = any_pb2.Any() + detail.Pack(error_info) + status.details.append(detail) + + # Use grpc_status to safely generate standard trailing metadata + rich_status = rpc_status.to_status(status) + + new_metadata: list[tuple[str, str | bytes]] = [] + trailing = context.trailing_metadata() + if trailing: + for k, v in trailing: + new_metadata.append((str(k), v)) + + for k, v in rich_status.trailing_metadata: + new_metadata.append((str(k), v)) + + context.set_trailing_metadata(tuple(new_metadata)) + await context.abort(rich_status.code, rich_status.details) else: await context.abort( grpc.StatusCode.UNKNOWN, diff --git a/src/a2a/server/request_handlers/response_helpers.py b/src/a2a/server/request_handlers/response_helpers.py index 5f38a0a6..f7bffd60 100644 --- a/src/a2a/server/request_handlers/response_helpers.py +++ b/src/a2a/server/request_handlers/response_helpers.py @@ -27,6 +27,7 @@ SendMessageResponse as SendMessageResponseProto, ) from a2a.utils.errors import ( + JSON_RPC_ERROR_CODE_MAP, A2AError, AuthenticatedExtendedCardNotConfiguredError, ContentTypeNotSupportedError, @@ -56,19 +57,6 @@ InternalError: JSONRPCInternalError, } -ERROR_CODE_MAP: dict[type[A2AError], int] = { - TaskNotFoundError: -32001, - TaskNotCancelableError: -32002, - PushNotificationNotSupportedError: -32003, - UnsupportedOperationError: -32004, - ContentTypeNotSupportedError: -32005, - InvalidAgentResponseError: -32006, - AuthenticatedExtendedCardNotConfiguredError: -32007, - InvalidParamsError: -32602, - InvalidRequestError: -32600, - MethodNotFoundError: -32601, -} - # Tuple of all A2AError types for isinstance checks _A2A_ERROR_TYPES: tuple[type, ...] = (A2AError,) @@ -136,7 +124,7 @@ def build_error_response( elif isinstance(error, A2AError): error_type = type(error) model_class = EXCEPTION_MAP.get(error_type, JSONRPCInternalError) - code = ERROR_CODE_MAP.get(error_type, -32603) + code = JSON_RPC_ERROR_CODE_MAP.get(error_type, -32603) jsonrpc_error = model_class( code=code, message=str(error), diff --git a/src/a2a/utils/errors.py b/src/a2a/utils/errors.py index 845bbfca..9353805e 100644 --- a/src/a2a/utils/errors.py +++ b/src/a2a/utils/errors.py @@ -82,11 +82,26 @@ class MethodNotFoundError(A2AError): message = 'Method not found' +class ExtensionSupportRequiredError(A2AError): + """Exception raised when extension support is required but not present.""" + + message = 'Extension support required' + + +class VersionNotSupportedError(A2AError): + """Exception raised when the requested version is not supported.""" + + message = 'Version not supported' + + # For backward compatibility if needed, or just aliases for clean refactor # We remove the Pydantic models here. __all__ = [ + 'A2A_ERROR_REASONS', + 'A2A_REASON_TO_ERROR', 'JSON_RPC_ERROR_CODE_MAP', + 'ExtensionSupportRequiredError', 'InternalError', 'InvalidAgentResponseError', 'InvalidParamsError', @@ -96,6 +111,7 @@ class MethodNotFoundError(A2AError): 'TaskNotCancelableError', 'TaskNotFoundError', 'UnsupportedOperationError', + 'VersionNotSupportedError', ] @@ -112,3 +128,18 @@ class MethodNotFoundError(A2AError): MethodNotFoundError: -32601, InternalError: -32603, } + + +A2A_ERROR_REASONS = { + TaskNotFoundError: 'TASK_NOT_FOUND', + TaskNotCancelableError: 'TASK_NOT_CANCELABLE', + PushNotificationNotSupportedError: 'PUSH_NOTIFICATION_NOT_SUPPORTED', + UnsupportedOperationError: 'UNSUPPORTED_OPERATION', + ContentTypeNotSupportedError: 'CONTENT_TYPE_NOT_SUPPORTED', + InvalidAgentResponseError: 'INVALID_AGENT_RESPONSE', + AuthenticatedExtendedCardNotConfiguredError: 'EXTENDED_AGENT_CARD_NOT_CONFIGURED', + ExtensionSupportRequiredError: 'EXTENSION_SUPPORT_REQUIRED', + VersionNotSupportedError: 'VERSION_NOT_SUPPORTED', +} + +A2A_REASON_TO_ERROR = {reason: cls for cls, reason in A2A_ERROR_REASONS.items()} diff --git a/tests/client/transports/test_grpc_client.py b/tests/client/transports/test_grpc_client.py index be4bf9c5..506d33d6 100644 --- a/tests/client/transports/test_grpc_client.py +++ b/tests/client/transports/test_grpc_client.py @@ -3,10 +3,14 @@ import grpc import pytest +from google.protobuf import any_pb2 +from google.rpc import error_details_pb2, status_pb2 + from a2a.client.middleware import ClientCallContext from a2a.client.transports.grpc import GrpcTransport from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.utils.constants import VERSION_HEADER, PROTOCOL_VERSION_CURRENT +from a2a.utils.errors import A2A_ERROR_REASONS from a2a.types import a2a_pb2 from a2a.types.a2a_pb2 import ( AgentCapabilities, @@ -32,7 +36,6 @@ TaskStatusUpdateEvent, ) from a2a.utils import get_text_parts -from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP @pytest.fixture @@ -245,28 +248,45 @@ async def test_send_message_with_timeout_context( assert kwargs['timeout'] == 12.5 -@pytest.mark.parametrize('error_cls', list(JSON_RPC_ERROR_CODE_MAP.keys())) +@pytest.mark.parametrize('error_cls', list(A2A_ERROR_REASONS.keys())) @pytest.mark.asyncio -async def test_grpc_mapped_errors( +async def test_grpc_mapped_errors_rich( grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_message_send_params: SendMessageRequest, error_cls, ) -> None: - """Test handling of mapped gRPC error responses.""" + """Test handling of rich gRPC error responses with Status metadata.""" + + reason = A2A_ERROR_REASONS.get(error_cls, 'UNKNOWN_ERROR') + + error_info = error_details_pb2.ErrorInfo( + reason=reason, + domain='a2a-protocol.org', + ) + error_details = f'{error_cls.__name__}: Mapped Error' + status = status_pb2.Status( + code=grpc.StatusCode.INTERNAL.value[0], message=error_details + ) + detail = any_pb2.Any() + detail.Pack(error_info) + status.details.append(detail) - # We must trigger it from a standard transport method call, for example `send_message`. mock_grpc_stub.SendMessage.side_effect = grpc.aio.AioRpcError( code=grpc.StatusCode.INTERNAL, initial_metadata=grpc.aio.Metadata(), - trailing_metadata=grpc.aio.Metadata(), + trailing_metadata=grpc.aio.Metadata( + ('grpc-status-details-bin', status.SerializeToString()), + ), details=error_details, ) - with pytest.raises(error_cls): + with pytest.raises(error_cls) as excinfo: await grpc_transport.send_message(sample_message_send_params) + assert str(excinfo.value) == error_details + @pytest.mark.asyncio async def test_send_message_message_response( diff --git a/tests/server/request_handlers/test_grpc_handler.py b/tests/server/request_handlers/test_grpc_handler.py index 802cbf66..4d121ca2 100644 --- a/tests/server/request_handlers/test_grpc_handler.py +++ b/tests/server/request_handlers/test_grpc_handler.py @@ -5,6 +5,7 @@ import grpc.aio import pytest +from google.rpc import error_details_pb2, status_pb2 from a2a import types from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.server.context import ServerCallContext @@ -99,7 +100,7 @@ async def test_send_message_server_error( await grpc_handler.SendMessage(request_proto, mock_grpc_context) mock_grpc_context.abort.assert_awaited_once_with( - grpc.StatusCode.INVALID_ARGUMENT, 'InvalidParamsError: Bad params' + grpc.StatusCode.INVALID_ARGUMENT, 'Bad params' ) @@ -138,7 +139,7 @@ async def test_get_task_not_found( await grpc_handler.GetTask(request_proto, mock_grpc_context) mock_grpc_context.abort.assert_awaited_once_with( - grpc.StatusCode.NOT_FOUND, 'TaskNotFoundError: Task not found' + grpc.StatusCode.NOT_FOUND, 'Task not found' ) @@ -157,7 +158,7 @@ async def test_cancel_task_server_error( mock_grpc_context.abort.assert_awaited_once_with( grpc.StatusCode.UNIMPLEMENTED, - 'TaskNotCancelableError: Task cannot be canceled', + 'Task cannot be canceled', ) @@ -379,7 +380,44 @@ async def test_abort_context_error_mapping( # noqa: PLR0913 mock_grpc_context.abort.assert_awaited_once() call_args, _ = mock_grpc_context.abort.call_args assert call_args[0] == grpc_status_code - assert error_message_part in call_args[1] + + # We shouldn't rely on the legacy ExceptionName: message string format + # But for backward compatability fallback it shouldn't fail + mock_grpc_context.set_trailing_metadata.assert_called_once() + metadata = mock_grpc_context.set_trailing_metadata.call_args[0][0] + + assert any(key == 'grpc-status-details-bin' for key, _ in metadata) + + +@pytest.mark.asyncio +async def test_abort_context_rich_error_format( + grpc_handler: GrpcHandler, + mock_request_handler: AsyncMock, + mock_grpc_context: AsyncMock, +) -> None: + + error = types.TaskNotFoundError('Could not find the task') + mock_request_handler.on_get_task.side_effect = error + request_proto = a2a_pb2.GetTaskRequest(id='any') + await grpc_handler.GetTask(request_proto, mock_grpc_context) + + mock_grpc_context.set_trailing_metadata.assert_called_once() + metadata = mock_grpc_context.set_trailing_metadata.call_args[0][0] + + bin_values = [v for k, v in metadata if k == 'grpc-status-details-bin'] + assert len(bin_values) == 1 + + status = status_pb2.Status.FromString(bin_values[0]) + assert status.code == grpc.StatusCode.NOT_FOUND.value[0] + assert status.message == 'Could not find the task' + + assert len(status.details) == 1 + + error_info = error_details_pb2.ErrorInfo() + status.details[0].Unpack(error_info) + + assert error_info.reason == 'TASK_NOT_FOUND' + assert error_info.domain == 'a2a-protocol.org' @pytest.mark.asyncio diff --git a/uv.lock b/uv.lock index f6f0cc5a..bfcde562 100644 --- a/uv.lock +++ b/uv.lock @@ -29,6 +29,7 @@ all = [ { name = "google-cloud-aiplatform" }, { name = "grpcio" }, { name = "grpcio-reflection" }, + { name = "grpcio-status" }, { name = "grpcio-tools" }, { name = "opentelemetry-api" }, { name = "opentelemetry-sdk" }, @@ -46,6 +47,7 @@ encryption = [ grpc = [ { name = "grpcio" }, { name = "grpcio-reflection" }, + { name = "grpcio-status" }, { name = "grpcio-tools" }, ] http-server = [ @@ -117,6 +119,8 @@ requires-dist = [ { name = "grpcio", marker = "extra == 'grpc'", specifier = ">=1.60" }, { name = "grpcio-reflection", marker = "extra == 'all'", specifier = ">=1.7.0" }, { name = "grpcio-reflection", marker = "extra == 'grpc'", specifier = ">=1.7.0" }, + { name = "grpcio-status", marker = "extra == 'all'", specifier = ">=1.60" }, + { name = "grpcio-status", marker = "extra == 'grpc'", specifier = ">=1.60" }, { name = "grpcio-tools", marker = "extra == 'all'", specifier = ">=1.60" }, { name = "grpcio-tools", marker = "extra == 'grpc'", specifier = ">=1.60" }, { name = "httpx", specifier = ">=0.28.1" },