diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 2a7a58117..d3502f2dd 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -238,7 +238,13 @@ def session_manager(self) -> StreamableHTTPSessionManager: return self._lowlevel_server.session_manager # pragma: no cover @overload - def run(self, transport: Literal["stdio"] = ...) -> None: ... + def run( + self, + transport: Literal["stdio"] = ..., + *, + stdin: anyio.AsyncFile[str] | None = ..., + stdout: anyio.AsyncFile[str] | None = ..., + ) -> None: ... @overload def run( @@ -284,7 +290,7 @@ def run( match transport: case "stdio": - anyio.run(self.run_stdio_async) + anyio.run(lambda: self.run_stdio_async(**kwargs)) case "sse": # pragma: no cover anyio.run(lambda: self.run_sse_async(**kwargs)) case "streamable-http": # pragma: no cover @@ -836,9 +842,25 @@ def decorator( # pragma: no cover return decorator # pragma: no cover - async def run_stdio_async(self) -> None: - """Run the server using stdio transport.""" - async with stdio_server() as (read_stream, write_stream): + async def run_stdio_async( + self, + *, + stdin: anyio.AsyncFile[str] | None = None, + stdout: anyio.AsyncFile[str] | None = None, + ) -> None: + """Run the server using stdio transport. + + Args: + stdin: Async text stream to read JSON-RPC lines from. When ``None``, + uses the process stdin (see :func:`mcp.server.stdio.stdio_server`). + stdout: Async text stream to write JSON-RPC lines to. When ``None``, + uses the process stdout. + + Custom streams are useful when the process ``sys.stdout`` / ``sys.stdin`` + must be redirected (for example so logging or subprocess output does not + corrupt the MCP JSON-RPC stream on fd 1). + """ + async with stdio_server(stdin=stdin, stdout=stdout) as (read_stream, write_stream): await self._lowlevel_server.run( read_stream, write_stream, diff --git a/tests/server/mcpserver/test_run_stdio_custom_streams.py b/tests/server/mcpserver/test_run_stdio_custom_streams.py new file mode 100644 index 000000000..9e6fd87b9 --- /dev/null +++ b/tests/server/mcpserver/test_run_stdio_custom_streams.py @@ -0,0 +1,49 @@ +"""MCPServer.run_stdio_async forwards optional stdin/stdout to stdio_server.""" + +from __future__ import annotations + +import io +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any +from unittest.mock import AsyncMock + +import anyio +import pytest + +from mcp.server.mcpserver import MCPServer + + +@pytest.mark.anyio +async def test_run_stdio_async_passes_streams_to_stdio_server(monkeypatch: pytest.MonkeyPatch) -> None: + captured: dict[str, object] = {} + + @asynccontextmanager + async def spy_stdio_server( + stdin: anyio.AsyncFile[str] | None = None, + stdout: anyio.AsyncFile[str] | None = None, + ) -> AsyncIterator[tuple[AsyncMock, AsyncMock]]: + captured["stdin"] = stdin + captured["stdout"] = stdout + read_stream = AsyncMock() + write_stream = AsyncMock() + yield read_stream, write_stream + + async def noop_run(*_args: Any, **_kwargs: Any) -> None: + return None + + monkeypatch.setattr("mcp.server.mcpserver.server.stdio_server", spy_stdio_server) + + server = MCPServer("test-stdio-spy") + monkeypatch.setattr(server._lowlevel_server, "run", noop_run) + monkeypatch.setattr(server._lowlevel_server, "create_initialization_options", lambda: object()) + + sin = io.StringIO() + sout = io.StringIO() + await server.run_stdio_async( + stdin=anyio.AsyncFile(sin), + stdout=anyio.AsyncFile(sout), + ) + + assert captured["stdin"] is not None + assert captured["stdout"] is not None