Skip to content

Commit 349661c

Browse files
Varun SharmaCopilot
andcommitted
fix: collapse single-exception ExceptionGroups from task groups
Replace all 16 anyio.create_task_group() calls with create_mcp_task_group() which automatically unwraps BaseExceptionGroups containing a single exception. This allows callers to catch specific error types (e.g. except ConnectionError) instead of having to handle ExceptionGroup wrapping. - Add src/mcp/shared/_task_group.py with collapse_exception_group() utility and _CollapsingTaskGroup wrapper class - Update all client transports (sse, stdio, websocket, streamable_http, memory) - Update all server transports (sse, stdio, websocket, streamable_http) - Update shared session, session_group, lowlevel server, task_support, task_result_handler, and streamable_http_manager - Add builtins config for BaseExceptionGroup/ExceptionGroup in ruff - Add 12 comprehensive tests covering collapse logic and integration Closes #2114 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 7ba41dc commit 349661c

File tree

18 files changed

+279
-22
lines changed

18 files changed

+279
-22
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ executionEnvironments = [
126126
line-length = 120
127127
target-version = "py310"
128128
extend-exclude = ["README.md", "README.v2.md"]
129+
builtins = ["BaseExceptionGroup", "ExceptionGroup"]
129130

130131
[tool.ruff.lint]
131132
select = [

src/mcp/client/_memory.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,10 @@
77
from types import TracebackType
88
from typing import Any
99

10-
import anyio
11-
1210
from mcp.client._transport import TransportStreams
1311
from mcp.server import Server
1412
from mcp.server.mcpserver import MCPServer
13+
from mcp.shared._task_group import create_mcp_task_group
1514
from mcp.shared.memory import create_client_server_memory_streams
1615

1716

@@ -48,7 +47,7 @@ async def _connect(self) -> AsyncIterator[TransportStreams]:
4847
client_read, client_write = client_streams
4948
server_read, server_write = server_streams
5049

51-
async with anyio.create_task_group() as tg:
50+
async with create_mcp_task_group() as tg:
5251
# Start server in background
5352
tg.start_soon(
5453
lambda: actual_server.run(

src/mcp/client/session_group.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from types import TracebackType
1414
from typing import Any, TypeAlias
1515

16-
import anyio
1716
import httpx
1817
from pydantic import BaseModel, Field
1918
from typing_extensions import Self
@@ -25,6 +24,7 @@
2524
from mcp.client.stdio import StdioServerParameters
2625
from mcp.client.streamable_http import streamable_http_client
2726
from mcp.shared._httpx_utils import create_mcp_http_client
27+
from mcp.shared._task_group import create_mcp_task_group
2828
from mcp.shared.exceptions import MCPError
2929
from mcp.shared.session import ProgressFnT
3030

@@ -166,7 +166,7 @@ async def __aexit__(
166166
await self._exit_stack.aclose()
167167

168168
# Concurrently close session stacks.
169-
async with anyio.create_task_group() as tg:
169+
async with create_mcp_task_group() as tg:
170170
for exit_stack in self._session_exit_stacks.values():
171171
tg.start_soon(exit_stack.aclose)
172172

src/mcp/client/sse.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from mcp import types
1515
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
16+
from mcp.shared._task_group import create_mcp_task_group
1617
from mcp.shared.message import SessionMessage
1718

1819
logger = logging.getLogger(__name__)
@@ -60,7 +61,7 @@ async def sse_client(
6061
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
6162
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
6263

63-
async with anyio.create_task_group() as tg:
64+
async with create_mcp_task_group() as tg:
6465
try:
6566
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
6667
async with httpx_client_factory(

src/mcp/client/stdio.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
get_windows_executable_command,
2121
terminate_windows_process_tree,
2222
)
23+
from mcp.shared._task_group import create_mcp_task_group
2324
from mcp.shared.message import SessionMessage
2425

2526
logger = logging.getLogger(__name__)
@@ -177,7 +178,7 @@ async def stdin_writer():
177178
except anyio.ClosedResourceError: # pragma: no cover
178179
await anyio.lowlevel.checkpoint()
179180

180-
async with anyio.create_task_group() as tg, process:
181+
async with create_mcp_task_group() as tg, process:
181182
tg.start_soon(stdout_reader)
182183
tg.start_soon(stdin_writer)
183184
try:

src/mcp/client/streamable_http.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from mcp.client._transport import TransportStreams
1919
from mcp.shared._httpx_utils import create_mcp_http_client
20+
from mcp.shared._task_group import create_mcp_task_group
2021
from mcp.shared.message import ClientMessageMetadata, SessionMessage
2122
from mcp.types import (
2223
INTERNAL_ERROR,
@@ -546,7 +547,7 @@ async def streamable_http_client(
546547

547548
transport = StreamableHTTPTransport(url)
548549

549-
async with anyio.create_task_group() as tg:
550+
async with create_mcp_task_group() as tg:
550551
try:
551552
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")
552553

src/mcp/client/websocket.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from websockets.typing import Subprotocol
1010

1111
from mcp import types
12+
from mcp.shared._task_group import create_mcp_task_group
1213
from mcp.shared.message import SessionMessage
1314

1415

@@ -68,7 +69,7 @@ async def ws_writer():
6869
msg_dict = session_message.message.model_dump(by_alias=True, mode="json", exclude_unset=True)
6970
await ws.send(json.dumps(msg_dict))
7071

71-
async with anyio.create_task_group() as tg:
72+
async with create_mcp_task_group() as tg:
7273
# Start reader and writer tasks
7374
tg.start_soon(ws_reader)
7475
tg.start_soon(ws_writer)

src/mcp/server/experimental/task_result_handler.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,8 @@
1212
import logging
1313
from typing import Any
1414

15-
import anyio
16-
1715
from mcp.server.session import ServerSession
16+
from mcp.shared._task_group import create_mcp_task_group
1817
from mcp.shared.exceptions import MCPError
1918
from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY, is_terminal
2019
from mcp.shared.experimental.tasks.message_queue import TaskMessageQueue
@@ -162,7 +161,7 @@ async def _wait_for_task_update(self, task_id: str) -> None:
162161
163162
Races between store update and queue message - first one wins.
164163
"""
165-
async with anyio.create_task_group() as tg:
164+
async with create_mcp_task_group() as tg:
166165

167166
async def wait_for_store() -> None:
168167
try:

src/mcp/server/experimental/task_support.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88
from contextlib import asynccontextmanager
99
from dataclasses import dataclass, field
1010

11-
import anyio
1211
from anyio.abc import TaskGroup
1312

1413
from mcp.server.experimental.task_result_handler import TaskResultHandler
1514
from mcp.server.session import ServerSession
15+
from mcp.shared._task_group import create_mcp_task_group
1616
from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore
1717
from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, TaskMessageQueue
1818
from mcp.shared.experimental.tasks.store import TaskStore
@@ -79,7 +79,7 @@ async def run(self) -> AsyncIterator[None]:
7979
# Task group is now available
8080
...
8181
"""
82-
async with anyio.create_task_group() as tg:
82+
async with create_mcp_task_group() as tg:
8383
self._task_group = tg
8484
try:
8585
yield

src/mcp/server/lowlevel/server.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ async def main():
6565
from mcp.server.streamable_http import EventStore
6666
from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager
6767
from mcp.server.transport_security import TransportSecuritySettings
68+
from mcp.shared._task_group import create_mcp_task_group
6869
from mcp.shared.exceptions import MCPError
6970
from mcp.shared.message import ServerMessageMetadata, SessionMessage
7071
from mcp.shared.session import RequestResponder
@@ -386,7 +387,7 @@ async def run(
386387
task_support.configure_session(session)
387388
await stack.enter_async_context(task_support.run())
388389

389-
async with anyio.create_task_group() as tg:
390+
async with create_mcp_task_group() as tg:
390391
async for message in session.incoming_messages:
391392
logger.debug("Received message: %s", message)
392393

0 commit comments

Comments
 (0)