From 4823b5d42b17140d57e8c1785350816b8a8fb83d Mon Sep 17 00:00:00 2001 From: Bartek Wolowiec Date: Thu, 12 Mar 2026 13:50:38 +0000 Subject: [PATCH] feat(compat): Tests and unified REST url. --- src/a2a/server/apps/rest/fastapi_app.py | 24 +-- .../v0_3/test_rest_fastapi_app_compat.py | 8 +- .../cross_version/client_server/client_0_3.py | 165 ++++++++++++--- .../cross_version/client_server/client_1_0.py | 198 ++++++++++++++++-- .../cross_version/client_server/server_0_3.py | 57 ++++- .../cross_version/client_server/server_1_0.py | 45 +++- .../client_server/server_common.py | 47 +++++ .../client_server/test_client_server.py | 27 ++- .../server/apps/rest/test_rest_fastapi_app.py | 3 +- 9 files changed, 476 insertions(+), 98 deletions(-) create mode 100644 tests/integration/cross_version/client_server/server_common.py diff --git a/src/a2a/server/apps/rest/fastapi_app.py b/src/a2a/server/apps/rest/fastapi_app.py index 0f9b91c6..c828610a 100644 --- a/src/a2a/server/apps/rest/fastapi_app.py +++ b/src/a2a/server/apps/rest/fastapi_app.py @@ -121,6 +121,15 @@ def build( A configured FastAPI application instance. """ app = FastAPI(**kwargs) + if self.enable_v0_3_compat and self._v03_adapter: + v03_adapter = self._v03_adapter + v03_router = APIRouter() + for route, callback in v03_adapter.routes().items(): + v03_router.add_api_route( + f'{rpc_url}{route[0]}', callback, methods=[route[1]] + ) + app.include_router(v03_router) + router = APIRouter() for route, callback in self._adapter.routes().items(): router.add_api_route( @@ -134,19 +143,4 @@ async def get_agent_card(request: Request) -> Response: app.include_router(router) - if self.enable_v0_3_compat and self._v03_adapter: - v03_adapter = self._v03_adapter - v03_router = APIRouter() - for route, callback in v03_adapter.routes().items(): - v03_router.add_api_route( - f'{rpc_url}/v0.3{route[0]}', callback, methods=[route[1]] - ) - - @v03_router.get(f'{rpc_url}/v0.3{agent_card_url}') - async def get_v03_agent_card(request: Request) -> Response: - card = await v03_adapter.handle_get_agent_card(request) - return JSONResponse(card) - - app.include_router(v03_router) - return app diff --git a/tests/compat/v0_3/test_rest_fastapi_app_compat.py b/tests/compat/v0_3/test_rest_fastapi_app_compat.py index 7084d15d..8625b7e0 100644 --- a/tests/compat/v0_3/test_rest_fastapi_app_compat.py +++ b/tests/compat/v0_3/test_rest_fastapi_app_compat.py @@ -92,7 +92,7 @@ async def test_send_message_success_message_v03( ) response = await client.post( - '/v0.3/v1/message:send', json=json_format.MessageToDict(request) + '/v1/message:send', json=json_format.MessageToDict(request) ) response.raise_for_status() @@ -127,7 +127,7 @@ async def test_send_message_success_task_v03( ) response = await client.post( - '/v0.3/v1/message:send', json=json_format.MessageToDict(request) + '/v1/message:send', json=json_format.MessageToDict(request) ) response.raise_for_status() @@ -155,7 +155,7 @@ async def test_get_task_v03( ), ) - response = await client.get('/v0.3/v1/tasks/test_task_id') + response = await client.get('/v1/tasks/test_task_id') response.raise_for_status() actual_response = a2a_v0_3_pb2.Task() @@ -182,7 +182,7 @@ async def test_cancel_task_v03( ), ) - response = await client.post('/v0.3/v1/tasks/test_task_id:cancel') + response = await client.post('/v1/tasks/test_task_id:cancel') response.raise_for_status() actual_response = a2a_v0_3_pb2.Task() diff --git a/tests/integration/cross_version/client_server/client_0_3.py b/tests/integration/cross_version/client_server/client_0_3.py index 2c599122..8e0db514 100644 --- a/tests/integration/cross_version/client_server/client_0_3.py +++ b/tests/integration/cross_version/client_server/client_0_3.py @@ -14,20 +14,45 @@ TransportProtocol, TaskQueryParams, TaskIdParams, + TaskState, TaskPushNotificationConfig, PushNotificationConfig, + FilePart, + FileWithUri, + FileWithBytes, + DataPart, ) from a2a.client.errors import A2AClientJSONRPCError, A2AClientHTTPError import sys +import traceback async def test_send_message_stream(client): print('Testing send_message (streaming)...') + msg = Message( role=Role.user, message_id=f'stream-{uuid4()}', - parts=[Part(root=TextPart(text='stream'))], - metadata={'test_key': 'test_value'}, + parts=[ + Part(root=TextPart(text='stream')), + Part( + root=FilePart( + file=FileWithUri( + uri='https://example.com/file.txt', + mime_type='text/plain', + ) + ) + ), + Part( + root=FilePart( + file=FileWithBytes( + bytes=b'aGVsbG8=', mime_type='application/octet-stream' + ) + ) + ), + Part(root=DataPart(data={'key': 'value'})), + ], + metadata={'test_key': 'full_message'}, ) events = [] @@ -62,38 +87,43 @@ async def test_send_message_sync(url, protocol_enum): role=Role.user, message_id=f'sync-{uuid4()}', parts=[Part(root=TextPart(text='sync'))], - metadata={'test_key': 'test_value'}, + metadata={'test_key': 'simple_message'}, ) - # In v0.3 SDK, send_message ALWAYS returns an async generator async for event in client.send_message(request=msg): assert event is not None event_obj = event[0] if isinstance(event, tuple) else event - if ( - getattr(event_obj, 'status', None) - and getattr(event_obj.status, 'state', None) - == 'TASK_STATE_COMPLETED' - ): - assert ( - getattr(event_obj.status.message, 'metadata', {}).get( - 'response_key' - ) - == 'response_value' - ), ( - f'Missing response metadata: {getattr(event_obj.status.message, "metadata", {})}' + + status = getattr(event_obj, 'status', None) + if status and str(getattr(status, 'state', '')).endswith('completed'): + # In 0.3 SDK, the message on the status might be exposed as 'message' or 'update' + status_msg = getattr( + status, 'message', getattr(status, 'update', None) ) - elif getattr(event_obj, 'status', None) and str( - getattr(event_obj.status, 'state', None) - ).endswith('completed'): - assert ( - getattr(event_obj.status.message, 'metadata', {}).get( - 'response_key' - ) - == 'response_value' - ), ( - f'Missing response metadata: {getattr(event_obj.status.message, "metadata", {})}' + assert status_msg is not None, ( + 'TaskStatus message/update is missing' ) - break + + metadata = getattr(status_msg, 'metadata', {}) + assert metadata.get('response_key') == 'response_value', ( + f'Missing response metadata: {metadata}' + ) + + # Check Part translation (root text part in 0.3) + parts = getattr( + status_msg, 'parts', getattr(status_msg, 'content', []) + ) + assert len(parts) > 0, 'No parts found in TaskStatus message' + first_part = parts[0] + text = getattr(first_part, 'text', '') + if ( + not text + and hasattr(first_part, 'root') + and hasattr(first_part.root, 'text') + ): + text = first_part.root.text + assert text == 'done', f"Expected 'done' text in Part, got '{text}'" + break print(f'Success: send_message (synchronous) passed.') @@ -102,20 +132,73 @@ async def test_get_task(client, task_id): print(f'Testing get_task ({task_id})...') task = await client.get_task(request=TaskQueryParams(id=task_id)) assert task.id == task_id + + user_msgs = [ + m for m in task.history if getattr(m, 'role', None) == Role.user + ] + assert user_msgs, 'Expected at least one ROLE_USER message in task history' + + client_msg = user_msgs[0] + + parts = client_msg.parts + assert len(parts) == 4, f'Expected 4 parts, got {len(parts)}' + + # 1. text part + text = getattr(parts[0].root, 'text', '') + assert text == 'stream', f"Expected 'stream', got {text}" + + # 2. uri part + file_uri = getattr(parts[1].root, 'file', None) + assert ( + file_uri is not None + and getattr(file_uri, 'uri', None) == 'https://example.com/file.txt' + ) + + # 3. bytes part + file_bytes = getattr(parts[2].root, 'file', None) + actual_bytes = getattr(file_bytes, 'bytes', None) + assert actual_bytes == 'aGVsbG8=', ( + f"Expected base64 'hello', got {actual_bytes}" + ) + + # 4. data part + data_val = getattr(parts[3].root, 'data', None) + assert data_val is not None + assert data_val == {'key': 'value'} + print('Success: get_task passed.') async def test_cancel_task(client, task_id): print(f'Testing cancel_task ({task_id})...') await client.cancel_task(request=TaskIdParams(id=task_id)) + task = await client.get_task(request=TaskQueryParams(id=task_id)) + assert task.status.state == TaskState.canceled, ( + f'Expected a canceled state, got {task.status.state}' + ) print('Success: cancel_task passed.') async def test_subscribe(client, task_id): print(f'Testing subscribe ({task_id})...') + has_artifact = False async for event in client.resubscribe(request=TaskIdParams(id=task_id)): - print(f'Received event: {event}') - break + # event is tuple (Task, UpdateEvent) + task, update = event + if update and hasattr(update, 'artifact'): + has_artifact = True + artifact = update.artifact + assert artifact.name == 'test-artifact' + assert artifact.metadata.get('artifact_key') == 'artifact_value' + # part check + assert len(artifact.parts) > 0 + p = artifact.parts[0] + text = getattr(p.root, 'text', '') + assert text == 'artifact-chunk' + print('Success: received artifact update.') + + if has_artifact: + break print('Success: subscribe passed.') @@ -124,7 +207,27 @@ async def test_get_extended_agent_card(client): # In v0.3, extended card is fetched via get_card() on the client card = await client.get_card() assert card is not None - # the MockAgentExecutor might not have a name or has one, just assert card exists + assert card.name in ('Server 0.3', 'Server 1.0') + assert card.version == '1.0.0' + assert 'Server running on a2a v' in card.description + + assert card.capabilities is not None + assert card.capabilities.streaming is True + assert card.capabilities.push_notifications is True + + if card.name == 'Server 0.3': + assert card.url is not None + assert card.preferred_transport == TransportProtocol.jsonrpc + assert len(card.additional_interfaces) == 2 + assert card.supports_authenticated_extended_card is False + else: + assert card.url is not None + assert card.preferred_transport is not None + print( + f'card.supports_authenticated_extended_card is: {card.supports_authenticated_extended_card}' + ) + assert card.supports_authenticated_extended_card in (False, None) + print(f'Success: get_extended_agent_card passed.') @@ -177,8 +280,6 @@ def main(): try: asyncio.run(run_client(args.url, protocol)) except Exception as e: - import traceback - traceback.print_exc() print(f'FAILED protocol {protocol}: {e}') failed = True diff --git a/tests/integration/cross_version/client_server/client_1_0.py b/tests/integration/cross_version/client_server/client_1_0.py index 9fa14852..537a7360 100644 --- a/tests/integration/cross_version/client_server/client_1_0.py +++ b/tests/integration/cross_version/client_server/client_1_0.py @@ -16,16 +16,32 @@ SubscribeToTaskRequest, GetExtendedAgentCardRequest, SendMessageRequest, + TaskPushNotificationConfig, + GetTaskPushNotificationConfigRequest, + ListTaskPushNotificationConfigsRequest, + DeleteTaskPushNotificationConfigRequest, + TaskState, ) +from a2a.client.errors import A2AClientError +from google.protobuf.struct_pb2 import Struct, Value async def test_send_message_stream(client): print('Testing send_message (streaming)...') + + s = Struct() + s.update({'key': 'value'}) + msg = Message( role=Role.ROLE_USER, message_id=f'stream-{uuid4()}', - parts=[Part(text='stream')], - metadata={'test_key': 'test_value'}, + parts=[ + Part(text='stream'), + Part(url='https://example.com/file.txt', media_type='text/plain'), + Part(raw=b'hello', media_type='application/octet-stream'), + Part(data=Value(struct_value=s)), + ], + metadata={'test_key': 'full_message'}, ) events = [] @@ -69,7 +85,7 @@ async def test_send_message_sync(url, protocol_enum): role=Role.ROLE_USER, message_id=f'sync-{uuid4()}', parts=[Part(text='sync')], - metadata={'test_key': 'test_value'}, + metadata={'test_key': 'simple_message'}, ) async for event in client.send_message( @@ -78,22 +94,21 @@ async def test_send_message_sync(url, protocol_enum): assert event is not None stream_response = event[0] - # In v1.0, check task status in StreamResponse + status = None if stream_response.HasField('task'): - task = stream_response.task - if task.status.state == 3: # TASK_STATE_COMPLETED - metadata = dict(task.status.message.metadata) - assert metadata.get('response_key') == 'response_value', ( - f'Missing response metadata: {metadata}' - ) + status = stream_response.task.status elif stream_response.HasField('status_update'): - status_update = stream_response.status_update - if status_update.status.state == 3: # TASK_STATE_COMPLETED - metadata = dict(status_update.status.message.metadata) - assert metadata.get('response_key') == 'response_value', ( - f'Missing response metadata: {metadata}' - ) - break + status = stream_response.status_update.status + + if status and status.state == TaskState.TASK_STATE_COMPLETED: + metadata = dict(status.message.metadata) + assert metadata.get('response_key') == 'response_value', ( + f'Missing response metadata: {metadata}' + ) + assert status.message.parts[0].text == 'done' + break + else: + print(f'Ignore message: {stream_response}') print(f'Success: send_message (synchronous) passed.') @@ -102,32 +117,169 @@ async def test_get_task(client, task_id): print(f'Testing get_task ({task_id})...') task = await client.get_task(request=GetTaskRequest(id=task_id)) assert task.id == task_id + + user_msgs = [m for m in task.history if m.role == Role.ROLE_USER] + assert user_msgs, 'Expected at least one ROLE_USER message in task history' + client_msg = user_msgs[0] + + assert len(client_msg.parts) == 4, ( + f'Expected 4 parts, got {len(client_msg.parts)}' + ) + + # 1. text part + assert client_msg.parts[0].text == 'stream', ( + f"Expected 'stream', got {client_msg.parts[0].text}" + ) + + # 2. uri part + assert client_msg.parts[1].url == 'https://example.com/file.txt' + + # 3. bytes part + assert client_msg.parts[2].raw == b'hello' + + # 4. data part + data_dict = dict(client_msg.parts[3].data.struct_value.fields) + assert data_dict['key'].string_value == 'value' + print('Success: get_task passed.') async def test_cancel_task(client, task_id): print(f'Testing cancel_task ({task_id})...') await client.cancel_task(request=CancelTaskRequest(id=task_id)) + task = await client.get_task(request=GetTaskRequest(id=task_id)) + assert task.status.state == TaskState.TASK_STATE_CANCELED, ( + f'Expected {TaskState.TASK_STATE_CANCELED}, got {task.status.state}' + ) print('Success: cancel_task passed.') async def test_subscribe(client, task_id): print(f'Testing subscribe ({task_id})...') + has_artifact = False async for event in client.subscribe( request=SubscribeToTaskRequest(id=task_id) ): - print(f'Received event: {event}') - break + assert event is not None + stream_response = event[0] + if stream_response.HasField('artifact_update'): + has_artifact = True + artifact = stream_response.artifact_update.artifact + assert artifact.name == 'test-artifact' + val = artifact.metadata['artifact_key'] + if hasattr(val, 'string_value'): + assert val.string_value == 'artifact_value' + else: + assert val == 'artifact_value' + assert artifact.parts[0].text == 'artifact-chunk' + print('Success: received artifact update.') + + if has_artifact: + break print('Success: subscribe passed.') +async def test_list_tasks(client, server_name): + from a2a.types import ListTasksRequest + from a2a.client.errors import A2AClientError + + print('Testing list_tasks...') + try: + resp = await client.list_tasks(request=ListTasksRequest()) + assert resp is not None + print(f'Success: list_tasks returned {len(resp.tasks)} tasks') + except NotImplementedError as e: + if server_name == 'Server 0.3': + print(f'Success: list_tasks gracefully failed on 0.3 Server: {e}') + else: + raise e + + async def test_get_extended_agent_card(client): print('Testing get_extended_agent_card...') card = await client.get_extended_agent_card( request=GetExtendedAgentCardRequest() ) assert card is not None + assert card.name in ('Server 0.3', 'Server 1.0') + assert card.version == '1.0.0' + assert 'Server running on a2a v' in card.description + + assert card.capabilities is not None + assert card.capabilities.streaming is True + assert card.capabilities.push_notifications is True + + if card.name == 'Server 1.0': + assert len(card.supported_interfaces) == 4 + assert card.capabilities.extended_agent_card in (False, None) + else: + assert len(card.supported_interfaces) > 0 + assert card.capabilities.extended_agent_card in (False, None) + print(f'Success: get_extended_agent_card passed.') + return card.name + + +async def test_push_notification_lifecycle(client, task_id, server_name): + print(f'Testing Push Notification lifecycle for task {task_id}...') + config_id = f'push-{uuid4()}' + + # 1. Create + task_push_cfg = TaskPushNotificationConfig( + task_id=task_id, id=config_id, url='http://127.0.0.1:9999/webhook' + ) + + created = await client.create_task_push_notification_config( + request=task_push_cfg + ) + assert created.id == config_id + print('Success: create_task_push_notification_config passed.') + + # 2. Get + get_req = GetTaskPushNotificationConfigRequest( + task_id=task_id, id=config_id + ) + fetched = await client.get_task_push_notification_config(request=get_req) + assert fetched.id == config_id + print('Success: get_task_push_notification_config passed.') + + # 3. List + try: + list_req = ListTaskPushNotificationConfigsRequest(task_id=task_id) + listed = await client.list_task_push_notification_configs( + request=list_req + ) + assert any(c.id == config_id for c in listed.configs) + except (NotImplementedError, A2AClientError) as e: + if server_name == 'Server 0.3': + print( + 'EXPECTED: list_task_push_notification_configs not implemented' + ) + else: + raise e + print('Success: list_task_push_notification_configs passed.') + + try: + # 4. Delete + del_req = DeleteTaskPushNotificationConfigRequest( + task_id=task_id, id=config_id + ) + await client.delete_task_push_notification_config(request=del_req) + print('Success: delete_task_push_notification_config passed.') + + # Verify deletion + listed_after = await client.list_task_push_notification_configs( + request=list_req + ) + assert not any(c.id == config_id for c in listed_after.configs) + print('Success: verified deletion.') + except (NotImplementedError, A2AClientError) as e: + if server_name == 'Server 0.3': + print( + 'EXPECTED: delete_task_push_notification_config not implemented' + ) + else: + raise e async def run_client(url: str, protocol: str): @@ -147,7 +299,10 @@ async def run_client(url: str, protocol: str): client = await ClientFactory.connect(url, client_config=config) # 1. Get Extended Agent Card - await test_get_extended_agent_card(client) + server_name = await test_get_extended_agent_card(client) + + # 1.5. List Tasks + await test_list_tasks(client, server_name) # 2. Send Streaming Message task_id = await test_send_message_stream(client) @@ -155,6 +310,9 @@ async def run_client(url: str, protocol: str): # 3. Get Task await test_get_task(client, task_id) + # 3.5 Push Notification Lifecycle + await test_push_notification_lifecycle(client, task_id, server_name) + # 4. Subscribe to Task await test_subscribe(client, task_id) diff --git a/tests/integration/cross_version/client_server/server_0_3.py b/tests/integration/cross_version/client_server/server_0_3.py index aa0b14de..7bd5f7e7 100644 --- a/tests/integration/cross_version/client_server/server_0_3.py +++ b/tests/integration/cross_version/client_server/server_0_3.py @@ -17,6 +17,9 @@ ) from a2a.server.request_handlers.grpc_handler import GrpcHandler from a2a.server.tasks.task_updater import TaskUpdater +from a2a.server.tasks.inmemory_push_notification_config_store import ( + InMemoryPushNotificationConfigStore, +) from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore from a2a.types import ( AgentCapabilities, @@ -25,9 +28,18 @@ Part, TaskState, TextPart, + FilePart, TransportProtocol, + FileWithBytes, + FileWithUri, + DataPart, ) from a2a.grpc import a2a_pb2_grpc +from starlette.requests import Request +from starlette.concurrency import iterate_in_threadpool +import time + +from server_common import CustomLoggingMiddleware class MockAgentExecutor(AgentExecutor): @@ -57,12 +69,35 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): if context.message and context.message.metadata else {} ) - if metadata.get('test_key') != 'test_value': + if metadata.get('test_key') not in ('full_message', 'simple_message'): print(f'SERVER: WARNING: Missing or incorrect metadata: {metadata}') raise ValueError( f'Missing expected metadata from client. Got: {metadata}' ) + if metadata.get('test_key') == 'full_message': + expected_parts = [ + Part(root=TextPart(text='stream')), + Part( + root=FilePart( + file=FileWithUri( + uri='https://example.com/file.txt', + mime_type='text/plain', + ) + ) + ), + Part( + root=FilePart( + file=FileWithBytes( + bytes=b'aGVsbG8=', + mime_type='application/octet-stream', + ) + ) + ), + Part(root=DataPart(data={'key': 'value'})), + ] + assert context.message.parts == expected_parts + print(f"SERVER: request message text='{text}'") if 'stream' in text: @@ -79,13 +114,20 @@ async def emit_periodic(): [Part(root=TextPart(text='ping'))] ), ) + await task_updater.add_artifact( + [Part(root=TextPart(text='artifact-chunk'))], + name='test-artifact', + metadata={'artifact_key': 'artifact_value'}, + ) await asyncio.sleep(0.1) except asyncio.CancelledError: pass bg_task = asyncio.create_task(emit_periodic()) + await event.wait() bg_task.cancel() + print(f'SERVER: stream event triggered for task {context.task_id}') await task_updater.update_status( @@ -99,8 +141,8 @@ async def emit_periodic(): async def cancel(self, context: RequestContext, event_queue: EventQueue): print(f'SERVER: cancel called for task {context.task_id}') - if context.task_id in self.events: - self.events[context.task_id].set() + assert context.task_id in self.events + self.events[context.task_id].set() task_updater = TaskUpdater( event_queue, context.task_id, @@ -121,9 +163,7 @@ async def main_async(http_port: int, grpc_port: int): url=f'http://127.0.0.1:{http_port}/jsonrpc/', preferred_transport=TransportProtocol.jsonrpc, skills=[], - capabilities=AgentCapabilities( - streaming=True, push_notifications=False - ), + capabilities=AgentCapabilities(streaming=True, push_notifications=True), default_input_modes=['text/plain'], default_output_modes=['text/plain'], additional_interfaces=[ @@ -144,6 +184,7 @@ async def main_async(http_port: int, grpc_port: int): agent_executor=MockAgentExecutor(), task_store=task_store, queue_manager=InMemoryQueueManager(), + push_config_store=InMemoryPushNotificationConfigStore(), ) app = FastAPI() @@ -166,9 +207,11 @@ async def main_async(http_port: int, grpc_port: int): server.add_insecure_port(f'127.0.0.1:{grpc_port}') await server.start() + app.add_middleware(CustomLoggingMiddleware) + # Start Uvicorn config = uvicorn.Config( - app, host='127.0.0.1', port=http_port, log_level='warning' + app, host='127.0.0.1', port=http_port, log_level='info', access_log=True ) uvicorn_server = uvicorn.Server(config) await uvicorn_server.serve() diff --git a/tests/integration/cross_version/client_server/server_1_0.py b/tests/integration/cross_version/client_server/server_1_0.py index f3058771..e079fdf2 100644 --- a/tests/integration/cross_version/client_server/server_1_0.py +++ b/tests/integration/cross_version/client_server/server_1_0.py @@ -10,6 +10,9 @@ from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler from a2a.server.tasks import TaskUpdater +from a2a.server.tasks.inmemory_push_notification_config_store import ( + InMemoryPushNotificationConfigStore, +) from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore from a2a.types.a2a_pb2 import ( AgentCapabilities, @@ -22,6 +25,8 @@ from a2a.compat.v0_3 import a2a_v0_3_pb2_grpc from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler from a2a.utils import TransportProtocol +from server_common import CustomLoggingMiddleware +from google.protobuf.struct_pb2 import Struct, Value class MockAgentExecutor(AgentExecutor): @@ -47,13 +52,29 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): if context.message and context.message.metadata else {} ) - if metadata.get('test_key') != 'test_value': + if metadata.get('test_key') not in ('full_message', 'simple_message'): print(f'SERVER: WARNING: Missing or incorrect metadata: {metadata}') raise ValueError( f'Missing expected metadata from client. Got: {metadata}' ) - print(f'SERVER: request message text={text}\nmessage={context.message}') + for part in context.message.parts: + if part.HasField('raw'): + assert part.raw == b'hello' + + if metadata.get('test_key') == 'full_message': + s = Struct() + s.update({'key': 'value'}) + + expected_parts = [ + Part(text='stream'), + Part( + url='https://example.com/file.txt', media_type='text/plain' + ), + Part(raw=b'hello', media_type='application/octet-stream'), + Part(data=Value(struct_value=s)), + ] + assert context.message.parts == expected_parts if 'stream' in text: print(f'SERVER: waiting on stream event for task {context.task_id}') @@ -69,6 +90,11 @@ async def emit_periodic(): [Part(text='ping')] ), ) + await task_updater.add_artifact( + [Part(text='artifact-chunk')], + name='test-artifact', + metadata={'artifact_key': 'artifact_value'}, + ) await asyncio.sleep(0.1) except asyncio.CancelledError: pass @@ -88,8 +114,8 @@ async def emit_periodic(): async def cancel(self, context: RequestContext, event_queue: EventQueue): print(f'SERVER: cancel called for task {context.task_id}') - if context.task_id in self.events: - self.events[context.task_id].set() + assert context.task_id in self.events + self.events[context.task_id].set() task_updater = TaskUpdater( event_queue, context.task_id, @@ -104,9 +130,7 @@ async def main_async(http_port: int, grpc_port: int): description='Server running on a2a v1.0', version='1.0.0', skills=[], - capabilities=AgentCapabilities( - streaming=True, push_notifications=False - ), + capabilities=AgentCapabilities(streaming=True, push_notifications=True), default_input_modes=['text/plain'], default_output_modes=['text/plain'], supported_interfaces=[ @@ -121,7 +145,7 @@ async def main_async(http_port: int, grpc_port: int): ), AgentInterface( protocol_binding=TransportProtocol.HTTP_JSON, - url=f'http://127.0.0.1:{http_port}/rest/v0.3/', + url=f'http://127.0.0.1:{http_port}/rest/', protocol_version='0.3', ), AgentInterface( @@ -136,9 +160,12 @@ async def main_async(http_port: int, grpc_port: int): agent_executor=MockAgentExecutor(), task_store=task_store, queue_manager=InMemoryQueueManager(), + push_config_store=InMemoryPushNotificationConfigStore(), ) app = FastAPI() + app.add_middleware(CustomLoggingMiddleware) + jsonrpc_app = A2AFastAPIApplication( http_handler=handler, agent_card=agent_card, enable_v0_3_compat=True ).build() @@ -164,7 +191,7 @@ async def main_async(http_port: int, grpc_port: int): # Start Uvicorn config = uvicorn.Config( - app, host='127.0.0.1', port=http_port, log_level='warning' + app, host='127.0.0.1', port=http_port, log_level='info', access_log=True ) uvicorn_server = uvicorn.Server(config) await uvicorn_server.serve() diff --git a/tests/integration/cross_version/client_server/server_common.py b/tests/integration/cross_version/client_server/server_common.py new file mode 100644 index 00000000..d66c1eb4 --- /dev/null +++ b/tests/integration/cross_version/client_server/server_common.py @@ -0,0 +1,47 @@ +import collections.abc +from typing import AsyncGenerator +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request + + +class PrintingAsyncGenerator(collections.abc.AsyncGenerator): + """ + Wraps an async generator to print items as they are yielded, + fully supporting bi-directional flow (asend, athrow, aclose). + """ + + def __init__(self, url: str, ag: AsyncGenerator): + self.url = url + self._ag = ag + + async def asend(self, value): + # Forward the sent value to the underlying async generator + result = await self._ag.asend(value) + print(f'PrintingAsyncGenerator::Generated: {self.url} {result}') + return result + + async def athrow(self, typ, val=None, tb=None): + # Forward exceptions to the underlying async generator + result = await self._ag.athrow(typ, val, tb) + print( + f'PrintingAsyncGenerator::Generated (via athrow): {self.url} {result}' + ) + return result + + async def aclose(self): + # Gracefully shut down the underlying generator + await self._ag.aclose() + + +class CustomLoggingMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + print('-' * 80) + print(f'REQUEST: {request.method} {request.url}') + print(f'REQUEST BODY: {await request.body()}') + + response = await call_next(request) + # Disabled by default. Can hang the test if enabled. + # response.body_iterator = PrintingAsyncGenerator(request.url, response.body_iterator) + + print('-' * 80) + return response diff --git a/tests/integration/cross_version/client_server/test_client_server.py b/tests/integration/cross_version/client_server/test_client_server.py index eeeb47f9..e65aa185 100644 --- a/tests/integration/cross_version/client_server/test_client_server.py +++ b/tests/integration/cross_version/client_server/test_client_server.py @@ -5,6 +5,8 @@ import time import pytest +import select +import signal def get_free_port(): @@ -46,7 +48,7 @@ def finalize_process( proc: subprocess.Popen, name: str, expected_return_code=None, - timeout: int = 5, + timeout: float = 5.0, ): failure = False if expected_return_code is not None: @@ -59,19 +61,23 @@ def finalize_process( failure = True except subprocess.TimeoutExpired: print(f'Process {name} timed out after {timeout} seconds') + os.killpg(os.getpgid(proc.pid), signal.SIGTERM) failure = True else: if proc.poll() is None: - proc.terminate() + os.killpg(os.getpgid(proc.pid), signal.SIGTERM) else: print(f'Process {name} already terminated!') failure = True - try: - proc.wait(timeout=2) - except subprocess.TimeoutExpired: - proc.kill() - stdout_text, stderr_text = proc.communicate() + try: + proc.wait(timeout=2) + except subprocess.TimeoutExpired: + os.killpg(os.getpgid(proc.pid), signal.SIGKILL) + + print(f'Process {name} finished with code {proc.wait()}') + + stdout_text, stderr_text = proc.communicate(timeout=3.0) print('-' * 80) print(f'Process {name} STDOUT:\n{stdout_text}') @@ -110,6 +116,7 @@ def running_servers(): stderr=subprocess.PIPE, env=get_env('server_1_0.py'), text=True, + start_new_session=True, ) # Server 0.3 setup @@ -142,6 +149,7 @@ def running_servers(): stderr=subprocess.PIPE, env=get_env('server_0_3.py'), text=True, + start_new_session=True, ) try: @@ -177,7 +185,7 @@ def running_servers(): finalize_process(proc, name) -@pytest.mark.timeout(10) +@pytest.mark.timeout(15) @pytest.mark.parametrize( 'server_script, client_script, client_deps, protocols', [ @@ -207,7 +215,7 @@ def running_servers(): 'server_0_3.py', 'client_1_0.py', [], - ['grpc', 'rest', 'jsonrpc'], + ['grpc', 'jsonrpc', 'rest'], ), ], ) @@ -237,5 +245,6 @@ def test_cross_version( stderr=subprocess.PIPE, env=get_env(client_script), text=True, + start_new_session=True, ) finalize_process(client_result, client_script, 0) diff --git a/tests/server/apps/rest/test_rest_fastapi_app.py b/tests/server/apps/rest/test_rest_fastapi_app.py index af94e5a6..19ee5173 100644 --- a/tests/server/apps/rest/test_rest_fastapi_app.py +++ b/tests/server/apps/rest/test_rest_fastapi_app.py @@ -198,8 +198,7 @@ async def test_create_a2a_rest_fastapi_app_with_v0_3_compat( ).build(agent_card_url='/well-known/agent.json', rpc_url='') routes = [getattr(route, 'path', '') for route in app.routes] - assert '/v0.3/well-known/agent.json' in routes - assert '/v0.3/v1/message:send' in routes + assert '/v1/message:send' in routes @pytest.mark.anyio