diff --git a/src/fetch/src/mcp_server_fetch/__init__.py b/src/fetch/src/mcp_server_fetch/__init__.py index 09744ce319..cddc3c913d 100644 --- a/src/fetch/src/mcp_server_fetch/__init__.py +++ b/src/fetch/src/mcp_server_fetch/__init__.py @@ -1,4 +1,12 @@ -from .server import serve +import os +import sys + +from .server import ACLConfigError, serve + + +def _env_flag(name: str) -> bool: + value = os.getenv(name, "").strip().lower() + return value in {"1", "true", "yes", "on"} def main(): @@ -16,9 +24,34 @@ def main(): help="Ignore robots.txt restrictions", ) parser.add_argument("--proxy-url", type=str, help="Proxy URL to use for requests") + parser.add_argument( + "--allow-host", + action="append", + default=[], + help="Allowed host (repeatable). Required when --strict-acl is enabled.", + ) + parser.add_argument( + "--strict-acl", + action="store_true", + help="Fail startup unless explicit ACL configuration is provided.", + ) args = parser.parse_args() - asyncio.run(serve(args.user_agent, args.ignore_robots_txt, args.proxy_url)) + strict_acl = args.strict_acl or _env_flag("MCP_SERVER_STRICT_ACL") + allowed_hosts = tuple(args.allow_host or []) + try: + asyncio.run( + serve( + args.user_agent, + args.ignore_robots_txt, + args.proxy_url, + strict_acl=strict_acl, + allowed_hosts=allowed_hosts, + ) + ) + except ACLConfigError as exc: + print(str(exc), file=sys.stderr) + raise SystemExit(2) from exc if __name__ == "__main__": diff --git a/src/fetch/src/mcp_server_fetch/server.py b/src/fetch/src/mcp_server_fetch/server.py index 2df9d3b604..46d5af3d76 100644 --- a/src/fetch/src/mcp_server_fetch/server.py +++ b/src/fetch/src/mcp_server_fetch/server.py @@ -24,6 +24,65 @@ DEFAULT_USER_AGENT_MANUAL = "ModelContextProtocol/1.0 (User-Specified; +https://github.com/modelcontextprotocol/servers)" +class ACLConfigError(ValueError): + """Raised when strict ACL startup requirements are not met.""" + + +def normalize_allowed_hosts( + allowed_hosts: tuple[str, ...] 
| list[str] | None, +) -> tuple[str, ...]: + """Normalize and deduplicate host allowlist entries.""" + if not allowed_hosts: + return () + + normalized: list[str] = [] + seen: set[str] = set() + for host in allowed_hosts: + cleaned = host.strip().lower() + if not cleaned: + continue + if cleaned.startswith("*."): + cleaned = cleaned[2:] + if cleaned not in seen: + seen.add(cleaned) + normalized.append(cleaned) + return tuple(normalized) + + +def validate_startup_acl(strict_acl: bool, allowed_hosts: tuple[str, ...]) -> None: + """Fail closed when strict ACL mode is enabled without explicit host ACL config.""" + if strict_acl and len(allowed_hosts) == 0: + raise ACLConfigError( + "ACL_CONFIG_MISSING: strict ACL mode requires at least one --allow-host value." + ) + + +def is_url_allowed(url: str, allowed_hosts: tuple[str, ...]) -> bool: + """Return true if URL host matches explicit allowlist entries.""" + if len(allowed_hosts) == 0: + return True + hostname = (urlparse(url).hostname or "").lower() + if hostname == "": + return False + return any( + hostname == allowed or hostname.endswith(f".{allowed}") + for allowed in allowed_hosts + ) + + +def enforce_url_acl(url: str, allowed_hosts: tuple[str, ...]) -> None: + """Raise MCP error when URL host is outside allowlist.""" + if is_url_allowed(url, allowed_hosts): + return + hostname = (urlparse(url).hostname or "").lower() or "" + raise McpError( + ErrorData( + code=INTERNAL_ERROR, + message=f"ACL_CONFIG_DENY: host '{hostname}' is not in allowed hosts.", + ) + ) + + def extract_content_from_html(html: str) -> str: """Extract and convert HTML content to Markdown format. 
@@ -63,7 +122,9 @@ def get_robots_txt_url(url: str) -> str: return robots_url -async def check_may_autonomously_fetch_url(url: str, user_agent: str, proxy_url: str | None = None) -> None: +async def check_may_autonomously_fetch_url( + url: str, user_agent: str, proxy_url: str | None = None +) -> None: """ Check if the URL can be fetched by the user agent according to the robots.txt file. Raises a McpError if not. @@ -80,15 +141,19 @@ async def check_may_autonomously_fetch_url(url: str, user_agent: str, proxy_url: headers={"User-Agent": user_agent}, ) except HTTPError: - raise McpError(ErrorData( - code=INTERNAL_ERROR, - message=f"Failed to fetch robots.txt {robot_txt_url} due to a connection issue", - )) + raise McpError( + ErrorData( + code=INTERNAL_ERROR, + message=f"Failed to fetch robots.txt {robot_txt_url} due to a connection issue", + ) + ) if response.status_code in (401, 403): - raise McpError(ErrorData( - code=INTERNAL_ERROR, - message=f"When fetching robots.txt ({robot_txt_url}), received status {response.status_code} so assuming that autonomous fetching is not allowed, the user can try manually fetching by using the fetch prompt", - )) + raise McpError( + ErrorData( + code=INTERNAL_ERROR, + message=f"When fetching robots.txt ({robot_txt_url}), received status {response.status_code} so assuming that autonomous fetching is not allowed, the user can try manually fetching by using the fetch prompt", + ) + ) elif 400 <= response.status_code < 500: return robot_txt = response.text @@ -97,15 +162,17 @@ async def check_may_autonomously_fetch_url(url: str, user_agent: str, proxy_url: ) robot_parser = Protego.parse(processed_robot_txt) if not robot_parser.can_fetch(str(url), user_agent): - raise McpError(ErrorData( - code=INTERNAL_ERROR, - message=f"The sites robots.txt ({robot_txt_url}), specifies that autonomous fetching of this page is not allowed, " - f"{user_agent}\n" - f"{url}" - f"\n{robot_txt}\n\n" - f"The assistant must let the user know that it failed to 
view the page. The assistant may provide further guidance based on the above information.\n" - f"The assistant can tell the user that they can try manually fetching the page by using the fetch prompt within their UI.", - )) + raise McpError( + ErrorData( + code=INTERNAL_ERROR, + message=f"The sites robots.txt ({robot_txt_url}), specifies that autonomous fetching of this page is not allowed, " + f"{user_agent}\n" + f"{url}" + f"\n{robot_txt}\n\n" + f"The assistant must let the user know that it failed to view the page. The assistant may provide further guidance based on the above information.\n" + f"The assistant can tell the user that they can try manually fetching the page by using the fetch prompt within their UI.", + ) + ) async def fetch_url( @@ -125,12 +192,16 @@ async def fetch_url( timeout=30, ) except HTTPError as e: - raise McpError(ErrorData(code=INTERNAL_ERROR, message=f"Failed to fetch {url}: {e!r}")) + raise McpError( + ErrorData(code=INTERNAL_ERROR, message=f"Failed to fetch {url}: {e!r}") + ) if response.status_code >= 400: - raise McpError(ErrorData( - code=INTERNAL_ERROR, - message=f"Failed to fetch {url} - status code {response.status_code}", - )) + raise McpError( + ErrorData( + code=INTERNAL_ERROR, + message=f"Failed to fetch {url} - status code {response.status_code}", + ) + ) page_raw = response.text @@ -182,6 +253,8 @@ async def serve( custom_user_agent: str | None = None, ignore_robots_txt: bool = False, proxy_url: str | None = None, + strict_acl: bool = False, + allowed_hosts: tuple[str, ...] = (), ) -> None: """Run the fetch MCP server. 
@@ -189,7 +262,12 @@ async def serve( custom_user_agent: Optional custom User-Agent string to use for requests ignore_robots_txt: Whether to ignore robots.txt restrictions proxy_url: Optional proxy URL to use for requests + strict_acl: Whether startup should fail without explicit ACL config + allowed_hosts: Explicit host allowlist for outbound fetches """ + normalized_allowed_hosts = normalize_allowed_hosts(allowed_hosts) + validate_startup_acl(strict_acl, normalized_allowed_hosts) + server = Server("mcp-fetch") user_agent_autonomous = custom_user_agent or DEFAULT_USER_AGENT_AUTONOMOUS user_agent_manual = custom_user_agent or DEFAULT_USER_AGENT_MANUAL @@ -230,9 +308,12 @@ async def call_tool(name, arguments: dict) -> list[TextContent]: url = str(args.url) if not url: raise McpError(ErrorData(code=INVALID_PARAMS, message="URL is required")) + enforce_url_acl(url, normalized_allowed_hosts) if not ignore_robots_txt: - await check_may_autonomously_fetch_url(url, user_agent_autonomous, proxy_url) + await check_may_autonomously_fetch_url( + url, user_agent_autonomous, proxy_url + ) content, prefix = await fetch_url( url, user_agent_autonomous, force_raw=args.raw, proxy_url=proxy_url @@ -241,13 +322,17 @@ async def call_tool(name, arguments: dict) -> list[TextContent]: if args.start_index >= original_length: content = "No more content available." else: - truncated_content = content[args.start_index : args.start_index + args.max_length] + truncated_content = content[ + args.start_index : args.start_index + args.max_length + ] if not truncated_content: content = "No more content available." 
else: content = truncated_content actual_content_length = len(truncated_content) - remaining_content = original_length - (args.start_index + actual_content_length) + remaining_content = original_length - ( + args.start_index + actual_content_length + ) # Only add the prompt to continue fetching if there is still remaining content if actual_content_length == args.max_length and remaining_content > 0: next_start = args.start_index + actual_content_length @@ -260,9 +345,12 @@ async def get_prompt(name: str, arguments: dict | None) -> GetPromptResult: raise McpError(ErrorData(code=INVALID_PARAMS, message="URL is required")) url = arguments["url"] + enforce_url_acl(str(url), normalized_allowed_hosts) try: - content, prefix = await fetch_url(url, user_agent_manual, proxy_url=proxy_url) + content, prefix = await fetch_url( + url, user_agent_manual, proxy_url=proxy_url + ) # TODO: after SDK bug is addressed, don't catch the exception except McpError as e: return GetPromptResult( diff --git a/src/fetch/tests/test_server.py b/src/fetch/tests/test_server.py index 10103b87c4..b49e0a7cb7 100644 --- a/src/fetch/tests/test_server.py +++ b/src/fetch/tests/test_server.py @@ -5,10 +5,14 @@ from mcp.shared.exceptions import McpError from mcp_server_fetch.server import ( + ACLConfigError, extract_content_from_html, get_robots_txt_url, check_may_autonomously_fetch_url, fetch_url, + is_url_allowed, + normalize_allowed_hosts, + validate_startup_acl, DEFAULT_USER_AGENT_AUTONOMOUS, ) @@ -100,13 +104,14 @@ async def test_allows_when_robots_txt_404(self): with patch("httpx.AsyncClient") as mock_client_class: mock_client = AsyncMock() mock_client.get = AsyncMock(return_value=mock_response) - mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client) + mock_client_class.return_value.__aenter__ = AsyncMock( + return_value=mock_client + ) mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None) # Should not raise await check_may_autonomously_fetch_url( - 
"https://example.com/page", - DEFAULT_USER_AGENT_AUTONOMOUS + "https://example.com/page", DEFAULT_USER_AGENT_AUTONOMOUS ) @pytest.mark.asyncio @@ -118,13 +123,14 @@ async def test_blocks_when_robots_txt_401(self): with patch("httpx.AsyncClient") as mock_client_class: mock_client = AsyncMock() mock_client.get = AsyncMock(return_value=mock_response) - mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client) + mock_client_class.return_value.__aenter__ = AsyncMock( + return_value=mock_client + ) mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None) with pytest.raises(McpError): await check_may_autonomously_fetch_url( - "https://example.com/page", - DEFAULT_USER_AGENT_AUTONOMOUS + "https://example.com/page", DEFAULT_USER_AGENT_AUTONOMOUS ) @pytest.mark.asyncio @@ -136,13 +142,14 @@ async def test_blocks_when_robots_txt_403(self): with patch("httpx.AsyncClient") as mock_client_class: mock_client = AsyncMock() mock_client.get = AsyncMock(return_value=mock_response) - mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client) + mock_client_class.return_value.__aenter__ = AsyncMock( + return_value=mock_client + ) mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None) with pytest.raises(McpError): await check_may_autonomously_fetch_url( - "https://example.com/page", - DEFAULT_USER_AGENT_AUTONOMOUS + "https://example.com/page", DEFAULT_USER_AGENT_AUTONOMOUS ) @pytest.mark.asyncio @@ -155,13 +162,14 @@ async def test_allows_when_robots_txt_allows_all(self): with patch("httpx.AsyncClient") as mock_client_class: mock_client = AsyncMock() mock_client.get = AsyncMock(return_value=mock_response) - mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client) + mock_client_class.return_value.__aenter__ = AsyncMock( + return_value=mock_client + ) mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None) # Should not raise await check_may_autonomously_fetch_url( - 
"https://example.com/page", - DEFAULT_USER_AGENT_AUTONOMOUS + "https://example.com/page", DEFAULT_USER_AGENT_AUTONOMOUS ) @pytest.mark.asyncio @@ -174,13 +182,14 @@ async def test_blocks_when_robots_txt_disallows_all(self): with patch("httpx.AsyncClient") as mock_client_class: mock_client = AsyncMock() mock_client.get = AsyncMock(return_value=mock_response) - mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client) + mock_client_class.return_value.__aenter__ = AsyncMock( + return_value=mock_client + ) mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None) with pytest.raises(McpError): await check_may_autonomously_fetch_url( - "https://example.com/page", - DEFAULT_USER_AGENT_AUTONOMOUS + "https://example.com/page", DEFAULT_USER_AGENT_AUTONOMOUS ) @@ -207,12 +216,13 @@ async def test_fetch_html_page(self): with patch("httpx.AsyncClient") as mock_client_class: mock_client = AsyncMock() mock_client.get = AsyncMock(return_value=mock_response) - mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client) + mock_client_class.return_value.__aenter__ = AsyncMock( + return_value=mock_client + ) mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None) content, prefix = await fetch_url( - "https://example.com/page", - DEFAULT_USER_AGENT_AUTONOMOUS + "https://example.com/page", DEFAULT_USER_AGENT_AUTONOMOUS ) # HTML is processed, so we check it returns something @@ -231,13 +241,15 @@ async def test_fetch_html_page_raw(self): with patch("httpx.AsyncClient") as mock_client_class: mock_client = AsyncMock() mock_client.get = AsyncMock(return_value=mock_response) - mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client) + mock_client_class.return_value.__aenter__ = AsyncMock( + return_value=mock_client + ) mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None) content, prefix = await fetch_url( "https://example.com/page", DEFAULT_USER_AGENT_AUTONOMOUS, - force_raw=True 
+ force_raw=True, ) assert content == html_content @@ -255,12 +267,13 @@ async def test_fetch_json_returns_raw(self): with patch("httpx.AsyncClient") as mock_client_class: mock_client = AsyncMock() mock_client.get = AsyncMock(return_value=mock_response) - mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client) + mock_client_class.return_value.__aenter__ = AsyncMock( + return_value=mock_client + ) mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None) content, prefix = await fetch_url( - "https://api.example.com/data", - DEFAULT_USER_AGENT_AUTONOMOUS + "https://api.example.com/data", DEFAULT_USER_AGENT_AUTONOMOUS ) assert content == json_content @@ -275,13 +288,14 @@ async def test_fetch_404_raises_error(self): with patch("httpx.AsyncClient") as mock_client_class: mock_client = AsyncMock() mock_client.get = AsyncMock(return_value=mock_response) - mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client) + mock_client_class.return_value.__aenter__ = AsyncMock( + return_value=mock_client + ) mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None) with pytest.raises(McpError): await fetch_url( - "https://example.com/notfound", - DEFAULT_USER_AGENT_AUTONOMOUS + "https://example.com/notfound", DEFAULT_USER_AGENT_AUTONOMOUS ) @pytest.mark.asyncio @@ -293,13 +307,14 @@ async def test_fetch_500_raises_error(self): with patch("httpx.AsyncClient") as mock_client_class: mock_client = AsyncMock() mock_client.get = AsyncMock(return_value=mock_response) - mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client) + mock_client_class.return_value.__aenter__ = AsyncMock( + return_value=mock_client + ) mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None) with pytest.raises(McpError): await fetch_url( - "https://example.com/error", - DEFAULT_USER_AGENT_AUTONOMOUS + "https://example.com/error", DEFAULT_USER_AGENT_AUTONOMOUS ) @pytest.mark.asyncio @@ -313,14 +328,42 @@ 
class TestStartupAcl:
    """Tests for strict-ACL startup validation and host allowlist matching."""

    def test_strict_acl_requires_allow_hosts(self):
        with pytest.raises(ACLConfigError, match="ACL_CONFIG_MISSING"):
            validate_startup_acl(strict_acl=True, allowed_hosts=())

    def test_non_strict_acl_allows_missing_allow_hosts(self):
        validate_startup_acl(strict_acl=False, allowed_hosts=())

    def test_strict_acl_allows_explicit_hosts(self):
        validate_startup_acl(strict_acl=True, allowed_hosts=("example.com",))

    def test_normalize_allowed_hosts(self):
        hosts = normalize_allowed_hosts(
            [" Example.com ", "*.example.com", "api.example.com"]
        )
        assert hosts == ("example.com", "api.example.com")

    def test_normalize_allowed_hosts_empty_input(self):
        # None, an empty list and blank entries all mean "no allowlist".
        assert normalize_allowed_hosts(None) == ()
        assert normalize_allowed_hosts([]) == ()
        assert normalize_allowed_hosts(["   ", ""]) == ()

    def test_is_url_allowed_exact_and_subdomain(self):
        allowed_hosts = ("example.com",)
        assert is_url_allowed("https://example.com/path", allowed_hosts)
        assert is_url_allowed("https://docs.example.com/path", allowed_hosts)
        assert not is_url_allowed("https://other.com/path", allowed_hosts)

    def test_is_url_allowed_rejects_lookalike_suffix(self):
        # A host that merely ends with the allowed name is not a subdomain.
        allowed_hosts = ("example.com",)
        assert not is_url_allowed("https://notexample.com/path", allowed_hosts)
        assert not is_url_allowed("https://example.com.evil.test/", allowed_hosts)

    def test_is_url_allowed_empty_allowlist_allows_everything(self):
        assert is_url_allowed("https://anything.test/path", ())

    def test_is_url_allowed_rejects_url_without_host(self):
        assert not is_url_allowed("not a url", ("example.com",))
// Command line argument parsing
//
// Strict ACL mode is enabled either by the --strict-acl flag or by setting
// the MCP_SERVER_STRICT_ACL environment variable to "1", "true", "yes" or
// "on". The value is trimmed and lowercased before comparison so that it is
// interpreted the same way as by the Python servers' _env_flag helper.
const strictAclEnv = (process.env.MCP_SERVER_STRICT_ACL ?? "").trim().toLowerCase();
const strictAclEnvEnabled = ["1", "true", "yes", "on"].includes(strictAclEnv);
const rawArgs = process.argv.slice(2);
// `args` holds only the directory arguments; the --strict-acl flag is
// consumed here so it is never treated as an allowed directory.
const args: string[] = [];
let strictAcl = strictAclEnvEnabled;
for (const arg of rawArgs) {
  if (arg === "--strict-acl") {
    strictAcl = true;
    continue;
  }
  args.push(arg);
}
only if ALL paths are inaccessible (and some were specified) if (accessibleDirectories.length === 0 && allowedDirectories.length > 0) { + if (strictAcl) { + console.error("ACL_CONFIG_INVALID: strict ACL mode found no accessible directories."); + process.exit(2); + } console.error("Error: None of the specified directories are accessible"); process.exit(1); } diff --git a/src/git/src/mcp_server_git/__init__.py b/src/git/src/mcp_server_git/__init__.py index 2270018733..73f5febe54 100644 --- a/src/git/src/mcp_server_git/__init__.py +++ b/src/git/src/mcp_server_git/__init__.py @@ -2,12 +2,24 @@ from pathlib import Path import logging import sys -from .server import serve +import os +from .server import ACLConfigError, serve + + +def _env_flag(name: str) -> bool: + value = os.getenv(name, "").strip().lower() + return value in {"1", "true", "yes", "on"} + @click.command() @click.option("--repository", "-r", type=Path, help="Git repository path") +@click.option( + "--strict-acl", + is_flag=True, + help="Fail startup unless repository ACL is explicitly configured.", +) @click.option("-v", "--verbose", count=True) -def main(repository: Path | None, verbose: bool) -> None: +def main(repository: Path | None, strict_acl: bool, verbose: bool) -> None: """MCP Git Server - Git functionality for MCP""" import asyncio @@ -18,7 +30,13 @@ def main(repository: Path | None, verbose: bool) -> None: logging_level = logging.DEBUG logging.basicConfig(level=logging_level, stream=sys.stderr) - asyncio.run(serve(repository)) + strict_acl = strict_acl or _env_flag("MCP_SERVER_STRICT_ACL") + try: + asyncio.run(serve(repository, strict_acl=strict_acl)) + except ACLConfigError as exc: + click.echo(str(exc), err=True) + raise SystemExit(2) from exc + if __name__ == "__main__": main() diff --git a/src/git/src/mcp_server_git/server.py b/src/git/src/mcp_server_git/server.py index 1d0298b465..12acd48b9f 100644 --- a/src/git/src/mcp_server_git/server.py +++ b/src/git/src/mcp_server_git/server.py @@ 
-19,60 +19,74 @@ # Default number of context lines to show in diff output DEFAULT_CONTEXT_LINES = 3 + +class ACLConfigError(ValueError): + """Raised when strict ACL startup requirements are not met.""" + + class GitStatus(BaseModel): repo_path: str + class GitDiffUnstaged(BaseModel): repo_path: str context_lines: int = DEFAULT_CONTEXT_LINES + class GitDiffStaged(BaseModel): repo_path: str context_lines: int = DEFAULT_CONTEXT_LINES + class GitDiff(BaseModel): repo_path: str target: str context_lines: int = DEFAULT_CONTEXT_LINES + class GitCommit(BaseModel): repo_path: str message: str + class GitAdd(BaseModel): repo_path: str files: list[str] + class GitReset(BaseModel): repo_path: str + class GitLog(BaseModel): repo_path: str max_count: int = 10 start_timestamp: Optional[str] = Field( None, - description="Start timestamp for filtering commits. Accepts: ISO 8601 format (e.g., '2024-01-15T14:30:25'), relative dates (e.g., '2 weeks ago', 'yesterday'), or absolute dates (e.g., '2024-01-15', 'Jan 15 2024')" + description="Start timestamp for filtering commits. Accepts: ISO 8601 format (e.g., '2024-01-15T14:30:25'), relative dates (e.g., '2 weeks ago', 'yesterday'), or absolute dates (e.g., '2024-01-15', 'Jan 15 2024')", ) end_timestamp: Optional[str] = Field( None, - description="End timestamp for filtering commits. Accepts: ISO 8601 format (e.g., '2024-01-15T14:30:25'), relative dates (e.g., '2 weeks ago', 'yesterday'), or absolute dates (e.g., '2024-01-15', 'Jan 15 2024')" + description="End timestamp for filtering commits. 
Accepts: ISO 8601 format (e.g., '2024-01-15T14:30:25'), relative dates (e.g., '2 weeks ago', 'yesterday'), or absolute dates (e.g., '2024-01-15', 'Jan 15 2024')", ) + class GitCreateBranch(BaseModel): repo_path: str branch_name: str base_branch: str | None = None + class GitCheckout(BaseModel): repo_path: str branch_name: str + class GitShow(BaseModel): repo_path: str revision: str - class GitBranch(BaseModel): repo_path: str = Field( ..., @@ -107,16 +121,24 @@ class GitTools(str, Enum): BRANCH = "git_branch" + def git_status(repo: git.Repo) -> str: return repo.git.status() -def git_diff_unstaged(repo: git.Repo, context_lines: int = DEFAULT_CONTEXT_LINES) -> str: + +def git_diff_unstaged( + repo: git.Repo, context_lines: int = DEFAULT_CONTEXT_LINES +) -> str: return repo.git.diff(f"--unified={context_lines}") + def git_diff_staged(repo: git.Repo, context_lines: int = DEFAULT_CONTEXT_LINES) -> str: return repo.git.diff(f"--unified={context_lines}", "--cached") -def git_diff(repo: git.Repo, target: str, context_lines: int = DEFAULT_CONTEXT_LINES) -> str: + +def git_diff( + repo: git.Repo, target: str, context_lines: int = DEFAULT_CONTEXT_LINES +) -> str: # Defense in depth: reject targets starting with '-' to prevent flag injection, # even if a malicious ref with that name exists (e.g. 
via filesystem manipulation) if target.startswith("-"): @@ -124,10 +146,12 @@ def git_diff(repo: git.Repo, target: str, context_lines: int = DEFAULT_CONTEXT_L repo.rev_parse(target) # Validates target is a real git ref, throws BadName if not return repo.git.diff(f"--unified={context_lines}", target) + def git_commit(repo: git.Repo, message: str) -> str: commit = repo.index.commit(message) return f"Changes committed successfully with hash {commit.hexsha}" + def git_add(repo: git.Repo, files: list[str]) -> str: if files == ["."]: repo.git.add(".") @@ -136,21 +160,28 @@ def git_add(repo: git.Repo, files: list[str]) -> str: repo.git.add("--", *files) return "Files staged successfully" + def git_reset(repo: git.Repo) -> str: repo.index.reset() return "All staged changes reset" -def git_log(repo: git.Repo, max_count: int = 10, start_timestamp: Optional[str] = None, end_timestamp: Optional[str] = None) -> list[str]: + +def git_log( + repo: git.Repo, + max_count: int = 10, + start_timestamp: Optional[str] = None, + end_timestamp: Optional[str] = None, +) -> list[str]: if start_timestamp or end_timestamp: # Use git log command with date filtering args = [] if start_timestamp: - args.extend(['--since', start_timestamp]) + args.extend(["--since", start_timestamp]) if end_timestamp: - args.extend(['--until', end_timestamp]) - args.extend(['--format=%H%n%an%n%ad%n%s%n']) + args.extend(["--until", end_timestamp]) + args.extend(["--format=%H%n%an%n%ad%n%s%n"]) - log_output = repo.git.log(*args).split('\n') + log_output = repo.git.log(*args).split("\n") log = [] # Process commits in groups of 4 (hash, author, date, message) @@ -176,7 +207,10 @@ def git_log(repo: git.Repo, max_count: int = 10, start_timestamp: Optional[str] ) return log -def git_create_branch(repo: git.Repo, branch_name: str, base_branch: str | None = None) -> str: + +def git_create_branch( + repo: git.Repo, branch_name: str, base_branch: str | None = None +) -> str: if base_branch: base = 
repo.references[base_branch] else: @@ -185,17 +219,19 @@ def git_create_branch(repo: git.Repo, branch_name: str, base_branch: str | None repo.create_head(branch_name, base) return f"Created branch '{branch_name}' from '{base.name}'" + def git_checkout(repo: git.Repo, branch_name: str) -> str: # Defense in depth: reject branch names starting with '-' to prevent flag injection, # even if a malicious ref with that name exists (e.g. via filesystem manipulation) if branch_name.startswith("-"): raise BadName(f"Invalid branch name: '{branch_name}' - cannot start with '-'") - repo.rev_parse(branch_name) # Validates branch_name is a real git ref, throws BadName if not + repo.rev_parse( + branch_name + ) # Validates branch_name is a real git ref, throws BadName if not repo.git.checkout(branch_name) return f"Switched to branch '{branch_name}'" - def git_show(repo: git.Repo, revision: str) -> str: commit = repo.commit(revision) output = [ @@ -214,11 +250,12 @@ def git_show(repo: git.Repo, revision: str) -> str: if d.diff is None: continue if isinstance(d.diff, bytes): - output.append(d.diff.decode('utf-8')) + output.append(d.diff.decode("utf-8")) else: output.append(d.diff) return "".join(output) + def validate_repo_path(repo_path: Path, allowed_repository: Path | None) -> None: """Validate that repo_path is within the allowed repository path.""" if allowed_repository is None: @@ -240,7 +277,29 @@ def validate_repo_path(repo_path: Path, allowed_repository: Path | None) -> None ) -def git_branch(repo: git.Repo, branch_type: str, contains: str | None = None, not_contains: str | None = None) -> str: +def validate_startup_acl(repository: Path | None, strict_acl: bool) -> None: + """Validate startup ACL requirements for strict mode.""" + if strict_acl and repository is None: + raise ACLConfigError( + "ACL_CONFIG_MISSING: strict ACL mode requires --repository." 
+ ) + + if repository is not None: + try: + git.Repo(repository) + except git.InvalidGitRepositoryError as exc: + if strict_acl: + raise ACLConfigError( + f"ACL_CONFIG_INVALID: {repository} is not a valid Git repository." + ) from exc + + +def git_branch( + repo: git.Repo, + branch_type: str, + contains: str | None = None, + not_contains: str | None = None, +) -> str: match contains: case None: contains_sha = (None,) @@ -254,11 +313,11 @@ def git_branch(repo: git.Repo, branch_type: str, contains: str | None = None, no not_contains_sha = ("--no-contains", not_contains) match branch_type: - case 'local': + case "local": b_type = None - case 'remote': + case "remote": b_type = "-r" - case 'all': + case "all": b_type = "-a" case _: return f"Invalid branch type: {branch_type}" @@ -269,9 +328,9 @@ def git_branch(repo: git.Repo, branch_type: str, contains: str | None = None, no return branch_info -async def serve(repository: Path | None) -> None: +async def serve(repository: Path | None, strict_acl: bool = False) -> None: logger = logging.getLogger(__name__) - + validate_startup_acl(repository, strict_acl) if repository is not None: try: git.Repo(repository) @@ -340,26 +399,28 @@ async def list_tools() -> list[Tool]: description="Shows the contents of a commit", inputSchema=GitShow.model_json_schema(), ), - Tool( name=GitTools.BRANCH, description="List Git branches", inputSchema=GitBranch.model_json_schema(), - - ) + ), ] async def list_repos() -> Sequence[str]: async def by_roots() -> Sequence[str]: if not isinstance(server.request_context.session, ServerSession): - raise TypeError("server.request_context.session must be a ServerSession") + raise TypeError( + "server.request_context.session must be a ServerSession" + ) if not server.request_context.session.check_client_capability( ClientCapabilities(roots=RootsCapability()) ): return [] - roots_result: ListRootsResult = await server.request_context.session.list_roots() + roots_result: ListRootsResult = ( + await 
server.request_context.session.list_roots() + ) logger.debug(f"Roots result: {roots_result}") repo_paths = [] for root in roots_result.roots: @@ -391,52 +452,43 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]: match name: case GitTools.STATUS: status = git_status(repo) - return [TextContent( - type="text", - text=f"Repository status:\n{status}" - )] + return [TextContent(type="text", text=f"Repository status:\n{status}")] case GitTools.DIFF_UNSTAGED: - diff = git_diff_unstaged(repo, arguments.get("context_lines", DEFAULT_CONTEXT_LINES)) - return [TextContent( - type="text", - text=f"Unstaged changes:\n{diff}" - )] + diff = git_diff_unstaged( + repo, arguments.get("context_lines", DEFAULT_CONTEXT_LINES) + ) + return [TextContent(type="text", text=f"Unstaged changes:\n{diff}")] case GitTools.DIFF_STAGED: - diff = git_diff_staged(repo, arguments.get("context_lines", DEFAULT_CONTEXT_LINES)) - return [TextContent( - type="text", - text=f"Staged changes:\n{diff}" - )] + diff = git_diff_staged( + repo, arguments.get("context_lines", DEFAULT_CONTEXT_LINES) + ) + return [TextContent(type="text", text=f"Staged changes:\n{diff}")] case GitTools.DIFF: - diff = git_diff(repo, arguments["target"], arguments.get("context_lines", DEFAULT_CONTEXT_LINES)) - return [TextContent( - type="text", - text=f"Diff with {arguments['target']}:\n{diff}" - )] + diff = git_diff( + repo, + arguments["target"], + arguments.get("context_lines", DEFAULT_CONTEXT_LINES), + ) + return [ + TextContent( + type="text", text=f"Diff with {arguments['target']}:\n{diff}" + ) + ] case GitTools.COMMIT: result = git_commit(repo, arguments["message"]) - return [TextContent( - type="text", - text=result - )] + return [TextContent(type="text", text=result)] case GitTools.ADD: result = git_add(repo, arguments["files"]) - return [TextContent( - type="text", - text=result - )] + return [TextContent(type="text", text=result)] case GitTools.RESET: result = git_reset(repo) - return [TextContent( - 
type="text", - text=result - )] + return [TextContent(type="text", text=result)] # Update the LOG case: case GitTools.LOG: @@ -444,49 +496,34 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]: repo, arguments.get("max_count", 10), arguments.get("start_timestamp"), - arguments.get("end_timestamp") + arguments.get("end_timestamp"), ) - return [TextContent( - type="text", - text="Commit history:\n" + "\n".join(log) - )] + return [ + TextContent(type="text", text="Commit history:\n" + "\n".join(log)) + ] case GitTools.CREATE_BRANCH: result = git_create_branch( - repo, - arguments["branch_name"], - arguments.get("base_branch") + repo, arguments["branch_name"], arguments.get("base_branch") ) - return [TextContent( - type="text", - text=result - )] + return [TextContent(type="text", text=result)] case GitTools.CHECKOUT: result = git_checkout(repo, arguments["branch_name"]) - return [TextContent( - type="text", - text=result - )] + return [TextContent(type="text", text=result)] case GitTools.SHOW: result = git_show(repo, arguments["revision"]) - return [TextContent( - type="text", - text=result - )] + return [TextContent(type="text", text=result)] case GitTools.BRANCH: result = git_branch( repo, - arguments.get("branch_type", 'local'), + arguments.get("branch_type", "local"), arguments.get("contains", None), arguments.get("not_contains", None), ) - return [TextContent( - type="text", - text=result - )] + return [TextContent(type="text", text=result)] case _: raise ValueError(f"Unknown tool: {name}") diff --git a/src/git/tests/test_server.py b/src/git/tests/test_server.py index 054bf8c756..2b2d0ce382 100644 --- a/src/git/tests/test_server.py +++ b/src/git/tests/test_server.py @@ -3,6 +3,7 @@ import git from git.exc import BadName from mcp_server_git.server import ( + ACLConfigError, git_checkout, git_branch, git_add, @@ -16,9 +17,11 @@ git_create_branch, git_show, validate_repo_path, + validate_startup_acl, ) import shutil + @pytest.fixture def 
test_repository(tmp_path: Path): repo_path = tmp_path / "temp_test_repo" @@ -32,6 +35,7 @@ def test_repository(tmp_path: Path): shutil.rmtree(repo_path) + def test_git_checkout_existing_branch(test_repository): test_repository.git.branch("test-branch") result = git_checkout(test_repository, "test-branch") @@ -39,31 +43,37 @@ def test_git_checkout_existing_branch(test_repository): assert "Switched to branch 'test-branch'" in result assert test_repository.active_branch.name == "test-branch" -def test_git_checkout_nonexistent_branch(test_repository): +def test_git_checkout_nonexistent_branch(test_repository): with pytest.raises(BadName): git_checkout(test_repository, "nonexistent-branch") + def test_git_branch_local(test_repository): test_repository.git.branch("new-branch-local") result = git_branch(test_repository, "local") assert "new-branch-local" in result + def test_git_branch_remote(test_repository): result = git_branch(test_repository, "remote") assert "" == result.strip() # Should be empty if no remote branches + def test_git_branch_all(test_repository): test_repository.git.branch("new-branch-all") result = git_branch(test_repository, "all") assert "new-branch-all" in result + def test_git_branch_contains(test_repository): # Get the default branch name (could be "main" or "master") default_branch = test_repository.active_branch.name # Create a new branch and commit to it test_repository.git.checkout("-b", "feature-branch") - Path(test_repository.working_dir / Path("feature.txt")).write_text("feature content") + Path(test_repository.working_dir / Path("feature.txt")).write_text( + "feature content" + ) test_repository.index.add(["feature.txt"]) commit = test_repository.index.commit("feature commit") test_repository.git.checkout(default_branch) @@ -72,12 +82,15 @@ def test_git_branch_contains(test_repository): assert "feature-branch" in result assert default_branch not in result + def test_git_branch_not_contains(test_repository): # Get the default branch name 
(could be "main" or "master") default_branch = test_repository.active_branch.name # Create a new branch and commit to it test_repository.git.checkout("-b", "another-feature-branch") - Path(test_repository.working_dir / Path("another_feature.txt")).write_text("another feature content") + Path(test_repository.working_dir / Path("another_feature.txt")).write_text( + "another feature content" + ) test_repository.index.add(["another_feature.txt"]) commit = test_repository.index.commit("another feature commit") test_repository.git.checkout(default_branch) @@ -86,6 +99,7 @@ def test_git_branch_not_contains(test_repository): assert "another-feature-branch" not in result assert default_branch in result + def test_git_add_all_files(test_repository): file_path = Path(test_repository.working_dir) / "all_file.txt" file_path.write_text("adding all") @@ -96,6 +110,7 @@ def test_git_add_all_files(test_repository): assert "all_file.txt" in staged_files assert result == "Files staged successfully" + def test_git_add_specific_files(test_repository): file1 = Path(test_repository.working_dir) / "file1.txt" file2 = Path(test_repository.working_dir) / "file2.txt" @@ -109,12 +124,14 @@ def test_git_add_specific_files(test_repository): assert "file2.txt" not in staged_files assert result == "Files staged successfully" + def test_git_status(test_repository): result = git_status(test_repository) assert result is not None assert "On branch" in result or "branch" in result.lower() + def test_git_diff_unstaged(test_repository): file_path = Path(test_repository.working_dir) / "test.txt" file_path.write_text("modified content") @@ -124,11 +141,13 @@ def test_git_diff_unstaged(test_repository): assert "test.txt" in result assert "modified content" in result + def test_git_diff_unstaged_empty(test_repository): result = git_diff_unstaged(test_repository) assert result == "" + def test_git_diff_staged(test_repository): file_path = Path(test_repository.working_dir) / "staged_file.txt" 
file_path.write_text("staged content") @@ -139,11 +158,13 @@ def test_git_diff_staged(test_repository): assert "staged_file.txt" in result assert "staged content" in result + def test_git_diff_staged_empty(test_repository): result = git_diff_staged(test_repository) assert result == "" + def test_git_diff(test_repository): # Get the default branch name (could be "main" or "master") default_branch = test_repository.active_branch.name @@ -158,6 +179,7 @@ def test_git_diff(test_repository): assert "test.txt" in result assert "feature changes" in result + def test_git_commit(test_repository): file_path = Path(test_repository.working_dir) / "commit_test.txt" file_path.write_text("content to commit") @@ -170,6 +192,7 @@ def test_git_commit(test_repository): latest_commit = test_repository.head.commit assert latest_commit.message.strip() == "test commit message" + def test_git_reset(test_repository): file_path = Path(test_repository.working_dir) / "reset_test.txt" file_path.write_text("content to reset") @@ -185,6 +208,7 @@ def test_git_reset(test_repository): staged_after = [item.a_path for item in test_repository.index.diff("HEAD")] assert "reset_test.txt" not in staged_after + def test_git_log(test_repository): for i in range(3): file_path = Path(test_repository.working_dir) / f"log_test_{i}.txt" @@ -201,6 +225,7 @@ def test_git_log(test_repository): assert "Date:" in result[0] assert "Message:" in result[0] + def test_git_log_default(test_repository): result = git_log(test_repository) @@ -208,6 +233,7 @@ def test_git_log_default(test_repository): assert len(result) >= 1 assert "initial commit" in result[0] + def test_git_create_branch(test_repository): result = git_create_branch(test_repository, "new-feature-branch") @@ -216,6 +242,7 @@ def test_git_create_branch(test_repository): branches = [ref.name for ref in test_repository.references] assert "new-feature-branch" in branches + def test_git_create_branch_from_base(test_repository): test_repository.git.checkout("-b", 
"base-branch") file_path = Path(test_repository.working_dir) / "base.txt" @@ -227,6 +254,7 @@ def test_git_create_branch_from_base(test_repository): assert "Created branch 'derived-branch' from 'base-branch'" in result + def test_git_show(test_repository): file_path = Path(test_repository.working_dir) / "show_test.txt" file_path.write_text("show content") @@ -242,6 +270,7 @@ def test_git_show(test_repository): assert "show test commit" in result assert "show_test.txt" in result + def test_git_show_initial_commit(test_repository): initial_commit = list(test_repository.iter_commits())[-1] @@ -254,6 +283,7 @@ def test_git_show_initial_commit(test_repository): # Tests for validate_repo_path (repository scoping security fix) + def test_validate_repo_path_no_restriction(): """When no repository restriction is configured, any path should be allowed.""" validate_repo_path(Path("/any/path"), None) # Should not raise @@ -313,8 +343,37 @@ def test_validate_repo_path_symlink_escape(tmp_path: Path): with pytest.raises(ValueError) as exc_info: validate_repo_path(symlink, allowed) assert "outside the allowed repository" in str(exc_info.value) + + +def test_validate_startup_acl_non_strict_allows_missing_repository(): + validate_startup_acl(None, strict_acl=False) + + +def test_validate_startup_acl_strict_requires_repository(): + with pytest.raises(ACLConfigError, match="ACL_CONFIG_MISSING"): + validate_startup_acl(None, strict_acl=True) + + +def test_validate_startup_acl_strict_rejects_invalid_repository(tmp_path: Path): + not_repo = tmp_path / "not-a-repo" + not_repo.mkdir() + with pytest.raises(ACLConfigError, match="ACL_CONFIG_INVALID"): + validate_startup_acl(not_repo, strict_acl=True) + + +def test_validate_startup_acl_non_strict_allows_invalid_repository(tmp_path: Path): + not_repo = tmp_path / "not-a-repo" + not_repo.mkdir() + validate_startup_acl(not_repo, strict_acl=False) + + +def test_validate_startup_acl_strict_accepts_valid_repository(test_repository): + 
validate_startup_acl(Path(test_repository.working_dir), strict_acl=True) + + # Tests for argument injection protection + def test_git_diff_rejects_flag_injection(test_repository): """git_diff should reject flags that could be used for argument injection.""" with pytest.raises(BadName):