diff --git a/src/git/src/mcp_server_git/server.py b/src/git/src/mcp_server_git/server.py index 5ce953e545..9cd588fd4b 100644 --- a/src/git/src/mcp_server_git/server.py +++ b/src/git/src/mcp_server_git/server.py @@ -584,4 +584,4 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]: options = server.create_initialization_options() async with stdio_server() as (read_stream, write_stream): - await server.run(read_stream, write_stream, options, raise_exceptions=True) + await server.run(read_stream, write_stream, options) diff --git a/src/git/tests/test_server.py b/src/git/tests/test_server.py index a5492adc85..5a22a86e6b 100644 --- a/src/git/tests/test_server.py +++ b/src/git/tests/test_server.py @@ -1,7 +1,10 @@ +import asyncio +from contextlib import asynccontextmanager import pytest from pathlib import Path import git from git.exc import BadName +import mcp_server_git.server as server_module from mcp_server_git.server import ( git_checkout, git_branch, @@ -39,6 +42,39 @@ def test_git_checkout_existing_branch(test_repository): assert "Switched to branch 'test-branch'" in result assert test_repository.active_branch.name == "test-branch" +def test_stdio_transport_errors_do_not_abort_server(monkeypatch): + captured_run_kwargs = {} + + class FakeServer: + def __init__(self, name): + self.name = name + + def list_roots(self): + return lambda fn: fn + + def list_tools(self): + return lambda fn: fn + + def call_tool(self): + return lambda fn: fn + + def create_initialization_options(self): + return object() + + async def run(self, read_stream, write_stream, options, **kwargs): + captured_run_kwargs.update(kwargs) + + @asynccontextmanager + async def fake_stdio_server(): + yield object(), object() + + monkeypatch.setattr(server_module, "Server", FakeServer) + monkeypatch.setattr(server_module, "stdio_server", fake_stdio_server) + + asyncio.run(server_module.serve(repository=None)) + + assert "raise_exceptions" not in captured_run_kwargs + def test_git_checkout_nonexistent_branch(test_repository): with pytest.raises(BadName):