Skip to content

Commit d1f22c5

Browse files
feat(client): add explicit resume flow for known streamable-http sessions
1 parent ae9e8ec commit d1f22c5

File tree

4 files changed

+147
-5
lines changed

4 files changed

+147
-5
lines changed

src/mcp/client/client.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,12 @@ async def main():
9898
streamable_http_session_id: str | None = None
9999
"""Optional pre-existing MCP session ID used when server is a StreamableHTTP URL."""
100100

101+
streamable_http_initialize_result: InitializeResult | None = None
102+
"""Previously negotiated InitializeResult used to resume a StreamableHTTP session."""
103+
104+
streamable_http_terminate_on_close: bool = True
105+
"""Whether a URL-based StreamableHTTP client should terminate the session on close."""
106+
101107
_session: ClientSession | None = field(init=False, default=None)
102108
_exit_stack: AsyncExitStack | None = field(init=False, default=None)
103109
_transport: Transport = field(init=False)
@@ -106,10 +112,31 @@ def __post_init__(self) -> None:
106112
if isinstance(self.server, Server | MCPServer):
107113
self._transport = InMemoryTransport(self.server, raise_exceptions=self.raise_exceptions)
108114
elif isinstance(self.server, str):
109-
self._transport = streamable_http_client(self.server, session_id=self.streamable_http_session_id)
115+
self._transport = streamable_http_client(
116+
self.server,
117+
session_id=self.streamable_http_session_id,
118+
terminate_on_close=self.streamable_http_terminate_on_close,
119+
)
110120
else:
111121
self._transport = self.server
112122

123+
@classmethod
124+
def resume_session(
125+
cls,
126+
server: str,
127+
*,
128+
session_id: str,
129+
initialize_result: InitializeResult,
130+
**kwargs: Any,
131+
) -> Client:
132+
"""Create a URL-based client configured to resume an existing StreamableHTTP session."""
133+
return cls(
134+
server=server,
135+
streamable_http_session_id=session_id,
136+
streamable_http_initialize_result=initialize_result,
137+
**kwargs,
138+
)
139+
113140
async def __aenter__(self) -> Client:
114141
"""Enter the async context manager."""
115142
if self._session is not None:
@@ -132,7 +159,15 @@ async def __aenter__(self) -> Client:
132159
)
133160
)
134161

135-
await self._session.initialize()
162+
if self.streamable_http_session_id is None and self.streamable_http_initialize_result is not None:
163+
raise RuntimeError(
164+
"streamable_http_initialize_result requires streamable_http_session_id for session resumption"
165+
)
166+
167+
if self.streamable_http_session_id is not None and self.streamable_http_initialize_result is not None:
168+
self._session.resume(self.streamable_http_initialize_result)
169+
else:
170+
await self._session.initialize()
136171

137172
# Transfer ownership to self for __aexit__ to handle
138173
self._exit_stack = exit_stack.pop_all()

src/mcp/client/session.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,14 @@ async def initialize(self) -> types.InitializeResult:
191191

192192
return result
193193

194+
def resume(self, initialize_result: types.InitializeResult) -> None:
195+
"""Mark this session as resumed using previously negotiated initialization data.
196+
197+
This bypasses the initialize/initialized handshake and seeds the session with
198+
server capabilities and metadata from an earlier connection.
199+
"""
200+
self._initialize_result = initialize_result
201+
194202
@property
195203
def initialize_result(self) -> types.InitializeResult | None:
196204
"""The server's InitializeResult. None until initialize() has been called.

src/mcp/client/streamable_http.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,10 +561,19 @@ async def streamable_http_client(
561561
write_stream_reader,
562562
anyio.create_task_group() as tg,
563563
):
564+
get_stream_started = False
564565

565566
def start_get_stream() -> None:
567+
nonlocal get_stream_started
568+
if get_stream_started:
569+
return
570+
get_stream_started = True
566571
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)
567572

573+
# If we're resuming an existing session, start the GET stream immediately.
574+
if session_id:
575+
start_get_stream()
576+
568577
tg.start_soon(
569578
transport.post_writer,
570579
client,

tests/client/test_client.py

Lines changed: 93 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from unittest.mock import patch
5+
from unittest.mock import AsyncMock, MagicMock, patch
66

77
import anyio
88
import pytest
@@ -307,13 +307,103 @@ async def test_complete_with_prompt_reference(simple_server: Server):
307307
def test_client_with_url_initializes_streamable_http_transport():
308308
with patch("mcp.client.client.streamable_http_client") as mock:
309309
_ = Client("http://localhost:8000/mcp")
310-
mock.assert_called_once_with("http://localhost:8000/mcp", session_id=None)
310+
mock.assert_called_once_with("http://localhost:8000/mcp", session_id=None, terminate_on_close=True)
311311

312312

313313
def test_client_with_url_and_session_id_initializes_streamable_http_transport():
314314
with patch("mcp.client.client.streamable_http_client") as mock:
315315
_ = Client("http://localhost:8000/mcp", streamable_http_session_id="resume-session-id")
316-
mock.assert_called_once_with("http://localhost:8000/mcp", session_id="resume-session-id")
316+
mock.assert_called_once_with(
317+
"http://localhost:8000/mcp",
318+
session_id="resume-session-id",
319+
terminate_on_close=True,
320+
)
321+
322+
323+
def test_client_with_url_and_terminate_on_close_false_initializes_streamable_http_transport():
324+
with patch("mcp.client.client.streamable_http_client") as mock:
325+
_ = Client("http://localhost:8000/mcp", streamable_http_terminate_on_close=False)
326+
mock.assert_called_once_with("http://localhost:8000/mcp", session_id=None, terminate_on_close=False)
327+
328+
329+
def test_client_resume_session_builder_initializes_streamable_http_transport():
330+
initialize_result = types.InitializeResult(
331+
protocol_version="2025-03-26",
332+
capabilities=types.ServerCapabilities(),
333+
server_info=types.Implementation(name="server", version="1.0"),
334+
)
335+
with patch("mcp.client.client.streamable_http_client") as mock:
336+
_ = Client.resume_session(
337+
"http://localhost:8000/mcp",
338+
session_id="resume-session-id",
339+
initialize_result=initialize_result,
340+
)
341+
mock.assert_called_once_with(
342+
"http://localhost:8000/mcp",
343+
session_id="resume-session-id",
344+
terminate_on_close=True,
345+
)
346+
347+
348+
async def test_client_resume_session_skips_initialize():
349+
initialize_result = types.InitializeResult(
350+
protocol_version="2025-03-26",
351+
capabilities=types.ServerCapabilities(),
352+
server_info=types.Implementation(name="server", version="1.0"),
353+
)
354+
355+
transport_cm = AsyncMock()
356+
transport_cm.__aenter__.return_value = (MagicMock(), MagicMock())
357+
transport_cm.__aexit__.return_value = None
358+
359+
session = MagicMock()
360+
session.initialize = AsyncMock()
361+
session.resume = MagicMock()
362+
session.initialize_result = initialize_result
363+
session_cm = AsyncMock()
364+
session_cm.__aenter__.return_value = session
365+
session_cm.__aexit__.return_value = None
366+
367+
with (
368+
patch("mcp.client.client.streamable_http_client", return_value=transport_cm),
369+
patch("mcp.client.client.ClientSession", return_value=session_cm),
370+
):
371+
async with Client.resume_session(
372+
"http://localhost:8000/mcp",
373+
session_id="resume-session-id",
374+
initialize_result=initialize_result,
375+
) as client:
376+
assert client.initialize_result == initialize_result
377+
378+
session.initialize.assert_not_awaited()
379+
session.resume.assert_called_once_with(initialize_result)
380+
381+
382+
async def test_client_streamable_initialize_result_requires_session_id():
383+
initialize_result = types.InitializeResult(
384+
protocol_version="2025-03-26",
385+
capabilities=types.ServerCapabilities(),
386+
server_info=types.Implementation(name="server", version="1.0"),
387+
)
388+
389+
transport_cm = AsyncMock()
390+
transport_cm.__aenter__.return_value = (MagicMock(), MagicMock())
391+
transport_cm.__aexit__.return_value = None
392+
393+
session_cm = AsyncMock()
394+
session_cm.__aenter__.return_value = AsyncMock()
395+
session_cm.__aexit__.return_value = None
396+
397+
with (
398+
patch("mcp.client.client.streamable_http_client", return_value=transport_cm),
399+
patch("mcp.client.client.ClientSession", return_value=session_cm),
400+
):
401+
client = Client(
402+
"http://localhost:8000/mcp",
403+
streamable_http_initialize_result=initialize_result,
404+
)
405+
with pytest.raises(RuntimeError, match="requires streamable_http_session_id"):
406+
await client.__aenter__()
317407

318408

319409
async def test_client_uses_transport_directly(app: MCPServer):

0 commit comments

Comments
 (0)