Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/mcp/server/streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,18 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]:
yield # Let the application run
finally:
logger.info("StreamableHTTP session manager shutting down")
active_transports = list(self._server_instances.values())
self._server_instances.clear()

for transport in active_transports:
try:
await transport.terminate()
except Exception: # pragma: no cover
logger.exception("Failed to terminate active streamable HTTP session during shutdown")

# Cancel task group to stop all spawned tasks
tg.cancel_scope.cancel()
self._task_group = None
# Clear any remaining server instances
self._server_instances.clear()

async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Process ASGI request with proper session handling and transport setup.
Expand Down
59 changes: 59 additions & 0 deletions tests/server/test_streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,65 @@ async def mock_receive(): # pragma: no cover
assert not manager._server_instances, "No sessions should be tracked after the only session crashes"


@pytest.mark.anyio
async def test_run_terminates_active_stateful_sessions_on_shutdown():
app = Server("test-shutdown-cleanup")
manager = StreamableHTTPSessionManager(app=app)
created_transports: list[StreamableHTTPServerTransport] = []
run_started = anyio.Event()

original_constructor = StreamableHTTPServerTransport

def track_transport(*args: Any, **kwargs: Any) -> StreamableHTTPServerTransport:
transport = original_constructor(*args, **kwargs)
created_transports.append(transport)
return transport

async def block_run(*args: Any, **kwargs: Any) -> None:
run_started.set()
await anyio.sleep_forever()

app.run = AsyncMock(side_effect=block_run)

sent_messages: list[Message] = []

async def mock_send(message: Message):
sent_messages.append(message)

scope = {
"type": "http",
"method": "POST",
"path": "/mcp",
"headers": [(b"content-type", b"application/json")],
}

async def mock_receive():
return {"type": "http.request", "body": b"", "more_body": False} # pragma: no cover

with patch.object(streamable_http_manager, "StreamableHTTPServerTransport", side_effect=track_transport):
async with manager.run():
await manager.handle_request(scope, mock_receive, mock_send)
await run_started.wait()

assert len(created_transports) == 1
transport = created_transports[0]
terminate_spy = AsyncMock(side_effect=transport.terminate)
transport.terminate = terminate_spy

response_start = next(
(msg for msg in sent_messages if msg["type"] == "http.response.start"),
None,
)
assert response_start is not None
assert manager._server_instances

await anyio.sleep(0)

terminate_spy.assert_awaited_once()
assert transport._terminated
assert not manager._server_instances


@pytest.mark.anyio
async def test_stateless_requests_memory_cleanup():
"""Test that stateless requests actually clean up resources using real transports."""
Expand Down
Loading