diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index f8ca30441..fbcf877ed 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -8,8 +8,10 @@ import json import multiprocessing import socket +import threading import time import traceback +import warnings from collections.abc import AsyncIterator, Generator from contextlib import asynccontextmanager from dataclasses import dataclass, field @@ -1462,7 +1464,7 @@ async def sampling_callback( # Context-aware server implementation for testing request context propagation -async def _handle_context_list_tools( # pragma: no cover +async def _handle_context_list_tools( ctx: ServerRequestContext, params: PaginatedRequestParams | None ) -> ListToolsResult: return ListToolsResult( @@ -1487,15 +1489,13 @@ async def _handle_context_list_tools( # pragma: no cover ) -async def _handle_context_call_tool( # pragma: no cover - ctx: ServerRequestContext, params: CallToolRequestParams -) -> CallToolResult: +async def _handle_context_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: name = params.name args = params.arguments or {} if name == "echo_headers": headers_info: dict[str, Any] = {} - if ctx.request and isinstance(ctx.request, Request): + if ctx.request and isinstance(ctx.request, Request): # pragma: no branch headers_info = dict(ctx.request.headers) return CallToolResult(content=[TextContent(type="text", text=json.dumps(headers_info))]) @@ -1506,19 +1506,19 @@ async def _handle_context_call_tool( # pragma: no cover "method": None, "path": None, } - if ctx.request and isinstance(ctx.request, Request): + if ctx.request and isinstance(ctx.request, Request): # pragma: no branch request = ctx.request context_data["headers"] = dict(request.headers) context_data["method"] = request.method context_data["path"] = request.url.path return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) - return CallToolResult(content=[TextContent(type="text", text=f"Unknown tool: {name}")]) + return CallToolResult(content=[TextContent(type="text", text=f"Unknown tool: {name}")]) # pragma: no cover # Server runner for context-aware testing -def run_context_aware_server(port: int): # pragma: no cover - """Run the context-aware test server.""" +def _create_context_aware_server(port: int) -> uvicorn.Server: + """Create the context-aware test server app and uvicorn.Server.""" server = Server( "ContextAwareServer", on_list_tools=_handle_context_list_tools, @@ -1547,26 +1547,45 @@ def run_context_aware_server(port: int): # pragma: no cover log_level="error", ) ) - server_instance.run() + return server_instance @pytest.fixture def context_aware_server(basic_server_port: int) -> Generator[None, None, None]: - """Start the context-aware server in a separate process.""" - proc = multiprocessing.Process(target=run_context_aware_server, args=(basic_server_port,), daemon=True) - proc.start() + """Start the context-aware server on a background thread (in-process for coverage). + + Unlike multiprocessing, threads share the host process's warning filters. + Uvicorn and the Windows ProactorEventLoop emit DeprecationWarning / + ResourceWarning during startup and teardown that pytest's + ``filterwarnings = ["error"]`` would otherwise promote to hard failures. + We therefore run the server with all warnings suppressed (mirroring + the implicit isolation that multiprocessing provided). + """ + server_instance = _create_context_aware_server(basic_server_port) + + def _run() -> None: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + server_instance.run() + + thread = threading.Thread(target=_run, daemon=True) + thread.start() - # Wait for server to be running wait_for_server(basic_server_port) yield - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("Context-aware server process failed to terminate") + server_instance.should_exit = True + thread.join(timeout=5) + + +# Marker to suppress Windows ProactorEventLoop teardown warnings on threaded servers. +# When uvicorn runs in a thread (instead of a subprocess), transport finalizers fire +# during GC in the main process and trigger PytestUnraisableExceptionWarning. +_suppress_transport_teardown = pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") +@_suppress_transport_teardown @pytest.mark.anyio async def test_streamablehttp_request_context_propagation(context_aware_server: None, basic_server_url: str) -> None: """Test that request context is properly propagated through StreamableHTTP.""" @@ -1600,6 +1619,7 @@ async def test_streamablehttp_request_context_propagation(context_aware_server: assert headers_data.get("x-trace-id") == "trace-123" +@_suppress_transport_teardown @pytest.mark.anyio async def test_streamablehttp_request_context_isolation(context_aware_server: None, basic_server_url: str) -> None: """Test that request contexts are isolated between StreamableHTTP clients.""" @@ -1638,6 +1658,7 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No assert ctx["headers"].get("authorization") == f"Bearer token-{i}" +@_suppress_transport_teardown @pytest.mark.anyio async def test_client_includes_protocol_version_header_after_init(context_aware_server: None, basic_server_url: str): """Test that client includes mcp-protocol-version header after initialization.""" @@ -2251,6 +2272,7 @@ async def test_streamable_http_client_does_not_mutate_provided_client( assert custom_client.headers.get("Authorization") == "Bearer test-token" +@_suppress_transport_teardown @pytest.mark.anyio async def test_streamable_http_client_mcp_headers_override_defaults( context_aware_server: None, basic_server_url: str @@ -2282,6 +2304,7 @@ async def test_streamable_http_client_mcp_headers_override_defaults( assert headers_data["content-type"] == "application/json" +@_suppress_transport_teardown @pytest.mark.anyio async def test_streamable_http_client_preserves_custom_with_mcp_headers( context_aware_server: None, basic_server_url: str