diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 7b66b5c1b..2058a1ece 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -9,7 +9,6 @@ from anyio.abc import TaskStatus from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import aconnect_sse -from httpx_sse._exceptions import SSEError from mcp import types from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client @@ -69,6 +68,12 @@ async def sse_client( write_stream, write_stream_reader = anyio.create_memory_object_stream(0) async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): + # Before task_status.started() fires, the caller is blocked inside + # tg.start() and nobody reads from read_stream. Sending to the + # zero-buffer stream in that phase would deadlock, so errors must + # be raised instead. After started(), the caller has the streams + # and errors are delivered through read_stream. + started = False try: async for sse in event_source.aiter_sse(): # pragma: no branch logger.debug(f"Received SSE event: {sse.event}") @@ -79,15 +84,13 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): url_parsed = urlparse(url) endpoint_parsed = urlparse(endpoint_url) - if ( # pragma: no cover + if ( url_parsed.netloc != endpoint_parsed.netloc or url_parsed.scheme != endpoint_parsed.scheme ): - error_msg = ( # pragma: no cover + raise ValueError( f"Endpoint origin does not match connection origin: {endpoint_url}" ) - logger.error(error_msg) # pragma: no cover - raise ValueError(error_msg) # pragma: no cover if on_session_created: session_id = _extract_session_id_from_endpoint(endpoint_url) @@ -95,11 +98,14 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): on_session_created(session_id) task_status.started(endpoint_url) + started = True case "message": # Skip empty data (keep-alive pings) if not sse.data: continue + if not started: + raise RuntimeError("Received message event before endpoint event") try: message = types.jsonrpc_message_adapter.validate_json(sse.data, by_name=False) logger.debug(f"Received server message: {message}") @@ -112,11 +118,10 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): await read_stream_writer.send(session_message) case _: # pragma: no cover logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover - except SSEError as sse_exc: # pragma: lax no cover - logger.exception("Encountered SSE exception") - raise sse_exc - except Exception as exc: # pragma: lax no cover + except Exception as exc: logger.exception("Error in sse_reader") + if not started: + raise await read_stream_writer.send(exc) finally: await read_stream_writer.aclose() diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 5629a5707..8aac3b32d 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,11 +1,17 @@ import json import multiprocessing import socket +import sys from collections.abc import AsyncGenerator, Generator from typing import Any from unittest.mock import AsyncMock, MagicMock, Mock, patch from urllib.parse import urlparse +# BaseExceptionGroup is builtin on 3.11+. On 3.10 it comes from the +# exceptiongroup backport, which anyio pulls in as a dependency. +if sys.version_info < (3, 11): # pragma: lax no cover + from exceptiongroup import BaseExceptionGroup + import anyio import httpx import pytest @@ -604,6 +610,105 @@ async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]: assert msg.message.id == 1 +def _mock_sse_connection(aiter_sse: AsyncGenerator[ServerSentEvent, None]) -> Any: + """Patch sse_client's HTTP layer to yield the given SSE event stream.""" + mock_event_source = MagicMock() + mock_event_source.aiter_sse.return_value = aiter_sse + mock_event_source.response.raise_for_status = MagicMock() + + mock_aconnect_sse = MagicMock() + mock_aconnect_sse.__aenter__ = AsyncMock(return_value=mock_event_source) + mock_aconnect_sse.__aexit__ = AsyncMock(return_value=None) + + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client.post = AsyncMock(return_value=MagicMock(status_code=200, raise_for_status=MagicMock())) + + return patch.multiple( + "mcp.client.sse", + create_mcp_http_client=Mock(return_value=mock_client), + aconnect_sse=Mock(return_value=mock_aconnect_sse), + ) + + +@pytest.mark.anyio +async def test_sse_client_raises_on_endpoint_origin_mismatch() -> None: + """Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/447 + + When the server sends an endpoint URL with a different origin than the + connection URL, sse_client must raise promptly instead of deadlocking. + Before the fix, the ValueError was caught and sent to a zero-buffer stream + with no reader, hanging forever. + """ + + async def events() -> AsyncGenerator[ServerSentEvent, None]: + yield ServerSentEvent(event="endpoint", data="http://wrong-host:9999/messages?sessionId=abc") + await anyio.sleep_forever() # pragma: no cover + + with _mock_sse_connection(events()), anyio.fail_after(5): + with pytest.raises(BaseExceptionGroup) as exc_info: + async with sse_client("http://test/sse"): # pragma: no branch + pytest.fail("sse_client should not yield on origin mismatch") # pragma: no cover + assert exc_info.group_contains(ValueError, match="Endpoint origin does not match") + + +@pytest.mark.anyio +async def test_sse_client_raises_on_error_before_endpoint() -> None: + """Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/447 + + Any exception raised while waiting for the endpoint event must propagate + instead of deadlocking on the zero-buffer read stream. + """ + + async def events() -> AsyncGenerator[ServerSentEvent, None]: + raise ConnectionError("connection reset by peer") + yield # pragma: no cover + + with _mock_sse_connection(events()), anyio.fail_after(5): + with pytest.raises(BaseExceptionGroup) as exc_info: + async with sse_client("http://test/sse"): # pragma: no branch + pytest.fail("sse_client should not yield on pre-endpoint error") # pragma: no cover + assert exc_info.group_contains(ConnectionError, match="connection reset") + + +@pytest.mark.anyio +async def test_sse_client_raises_on_message_before_endpoint() -> None: + """Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/447 + + If the server sends a message event before the endpoint event (protocol + violation), sse_client must raise rather than deadlock trying to send the + message to a stream nobody is reading yet. + """ + + async def events() -> AsyncGenerator[ServerSentEvent, None]: + yield ServerSentEvent(event="message", data='{"jsonrpc":"2.0","id":1,"result":{}}') + await anyio.sleep_forever() # pragma: no cover + + with _mock_sse_connection(events()), anyio.fail_after(5): + with pytest.raises(BaseExceptionGroup) as exc_info: + async with sse_client("http://test/sse"): # pragma: no branch + pytest.fail("sse_client should not yield on protocol violation") # pragma: no cover + assert exc_info.group_contains(RuntimeError, match="before endpoint event") + + +@pytest.mark.anyio +async def test_sse_client_delivers_post_endpoint_errors_via_stream() -> None: + """After the endpoint is received, errors in sse_reader are delivered on the + read stream so the session can handle them, rather than crashing the task group. + """ + + async def events() -> AsyncGenerator[ServerSentEvent, None]: + yield ServerSentEvent(event="endpoint", data="/messages/?session_id=abc") + raise ConnectionError("mid-stream failure") + + with _mock_sse_connection(events()), anyio.fail_after(5): + async with sse_client("http://test/sse") as (read_stream, _): + received = await read_stream.receive() + assert isinstance(received, ConnectionError) + assert "mid-stream failure" in str(received) + + @pytest.mark.anyio async def test_sse_session_cleanup_on_disconnect(server: None, server_url: str) -> None: """Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/1227