Skip to content

Commit 0172037

Browse files
Varun SharmaCopilot
andcommitted
fix: preserve API gateway path prefix in SSE client URL resolution
Fixes #795. When an MCP server sits behind a reverse proxy or API gateway that adds a path prefix, urljoin drops the prefix for absolute endpoint paths (starting with '/'). Add _resolve_endpoint_url(), which detects gateway prefixes by finding where the endpoint's first path segment appears in the base URL path; everything before that match is preserved as the prefix. Example: base_url: https://host/gateway/v1/sse; endpoint: /v1/messages/?session_id=abc; before: https://host/v1/messages/?session_id=abc (prefix lost); after: https://host/gateway/v1/messages/?session_id=abc. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>. Reported-by: lizzzcai <https://github.com/lizzzcai>
1 parent 7ba41dc commit 0172037

File tree

2 files changed

+127
-3
lines changed

2 files changed

+127
-3
lines changed

src/mcp/client/sse.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from collections.abc import Callable
33
from contextlib import asynccontextmanager
44
from typing import Any
5-
from urllib.parse import parse_qs, urljoin, urlparse
5+
from urllib.parse import parse_qs, urljoin, urlparse, urlunparse
66

77
import anyio
88
import httpx
@@ -27,6 +27,65 @@ def _extract_session_id_from_endpoint(endpoint_url: str) -> str | None:
2727
return query_params.get("sessionId", [None])[0] or query_params.get("session_id", [None])[0]
2828

2929

30+
def _resolve_endpoint_url(base_url: str, endpoint: str) -> str:
    """Resolve an endpoint URL, preserving any reverse proxy/API gateway path prefix.

    When an MCP server sits behind a reverse proxy or API gateway that adds a
    path prefix (e.g., ``/gateway``), the server's endpoint events contain paths
    without that prefix. Standard ``urljoin`` drops the base URL's path prefix
    for absolute paths (starting with ``/``). This function detects and
    preserves such prefixes.

    Args:
        base_url: The URL the SSE connection was opened against.
        endpoint: The endpoint reference received from the server. May be a
            full URL, a network-path reference (``//host/path``), an absolute
            path (``/path``), or a relative path.

    Returns:
        The resolved endpoint URL. Full URLs are returned unchanged; relative
        paths and network-path references are resolved with ``urljoin``;
        absolute paths are re-rooted under the detected gateway prefix when one
        is found, and resolved with ``urljoin`` otherwise.

    Example::

        >>> _resolve_endpoint_url(
        ...     "https://host/gateway/v1/sse",
        ...     "/v1/messages/?session_id=abc",
        ... )
        'https://host/gateway/v1/messages/?session_id=abc'
    """
    parsed_ep = urlparse(endpoint)

    # Full URL — use as-is
    if parsed_ep.scheme:
        return endpoint

    # Network-path reference (``//host/path``) — it names its own authority, so
    # it must not be rewritten onto the base host. Relative path (no leading
    # ``/``) — urljoin already keeps the base directory. Both use standard
    # resolution.
    if parsed_ep.netloc or not endpoint.startswith("/"):
        return urljoin(base_url, endpoint)

    # For absolute paths, detect and preserve any gateway prefix.
    # Strategy: find the first path segment of the endpoint inside the base URL
    # path. If it appears at a position > 0, everything before it is the
    # gateway prefix that must be preserved.
    parsed_base = urlparse(base_url)
    base_path = parsed_base.path
    ep_path = parsed_ep.path

    ep_segments = [s for s in ep_path.split("/") if s]
    if ep_segments:
        first_seg = "/" + ep_segments[0]
        # Match on whole segments only: either "/seg/" somewhere in the base
        # path, or "/seg" as the final segment of the base path.
        idx = base_path.find(first_seg + "/")
        if idx < 0 and base_path.endswith(first_seg):
            idx = len(base_path) - len(first_seg)

        if idx > 0:
            prefix = base_path[:idx]
            return urlunparse(
                (
                    parsed_base.scheme,
                    parsed_base.netloc,
                    prefix + ep_path,
                    # Carry over every remaining endpoint component so this
                    # branch is consistent with what urljoin would keep.
                    parsed_ep.params,
                    parsed_ep.query,
                    parsed_ep.fragment,
                )
            )

    # No prefix detected — fall back to standard resolution
    return urljoin(base_url, endpoint)
87+
88+
3089
@asynccontextmanager
3190
async def sse_client(
3291
url: str,
@@ -80,7 +139,7 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
80139
logger.debug(f"Received SSE event: {sse.event}")
81140
match sse.event:
82141
case "endpoint":
83-
endpoint_url = urljoin(url, sse.data)
142+
endpoint_url = _resolve_endpoint_url(url, sse.data)
84143
logger.debug(f"Received endpoint URL: {endpoint_url}")
85144

86145
url_parsed = urlparse(url)

tests/shared/test_sse.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import mcp.client.sse
2121
from mcp import types
2222
from mcp.client.session import ClientSession
23-
from mcp.client.sse import _extract_session_id_from_endpoint, sse_client
23+
from mcp.client.sse import _extract_session_id_from_endpoint, _resolve_endpoint_url, sse_client
2424
from mcp.server import Server, ServerRequestContext
2525
from mcp.server.sse import SseServerTransport
2626
from mcp.server.transport_security import TransportSecuritySettings
@@ -229,6 +229,71 @@ def test_extract_session_id_from_endpoint(endpoint_url: str, expected: str | Non
229229
assert _extract_session_id_from_endpoint(endpoint_url) == expected
230230

231231

232+
@pytest.mark.parametrize(
    ("base_url", "endpoint", "expected"),
    [
        # Gateway / reverse proxy prefix in front of the server (issue #795).
        pytest.param(
            "https://example.com/gateway/v1/sse",
            "/v1/messages/?session_id=abc",
            "https://example.com/gateway/v1/messages/?session_id=abc",
            id="gateway_prefix",
        ),
        pytest.param(
            "https://example.com/gateway_prefix/v1/sse",
            "/v1/messages/?session_id=abc",
            "https://example.com/gateway_prefix/v1/messages/?session_id=abc",
            id="gateway_prefix_underscore",
        ),
        # Prefix made of more than one path segment.
        pytest.param(
            "https://example.com/org/team/v1/sse",
            "/v1/messages/?session_id=abc",
            "https://example.com/org/team/v1/messages/?session_id=abc",
            id="deep_gateway_prefix",
        ),
        # No prefix present: result must equal plain urljoin resolution.
        pytest.param(
            "https://example.com/v1/sse",
            "/v1/messages/?session_id=abc",
            "https://example.com/v1/messages/?session_id=abc",
            id="no_prefix_same_root",
        ),
        pytest.param(
            "https://example.com/sse",
            "/messages/?session_id=abc",
            "https://example.com/messages/?session_id=abc",
            id="no_prefix_root_level",
        ),
        # Relative endpoint: resolved against the base URL's directory.
        pytest.param(
            "https://example.com/gateway/v1/sse",
            "messages/?session_id=abc",
            "https://example.com/gateway/v1/messages/?session_id=abc",
            id="relative_path",
        ),
        # Endpoint that is already a full URL: returned untouched.
        pytest.param(
            "https://example.com/gateway/v1/sse",
            "https://example.com/v1/messages/?session_id=abc",
            "https://example.com/v1/messages/?session_id=abc",
            id="absolute_url",
        ),
        # Endpoint's first segment is the last segment of the base path.
        pytest.param(
            "https://example.com/gw/api",
            "/api/messages/?session_id=abc",
            "https://example.com/gw/api/messages/?session_id=abc",
            id="endpoint_at_path_end",
        ),
    ],
)
def test_resolve_endpoint_url(base_url: str, endpoint: str, expected: str) -> None:
    """Check that each base URL / endpoint pair resolves to the expected URL."""
    resolved = _resolve_endpoint_url(base_url, endpoint)
    assert resolved == expected
295+
296+
232297
@pytest.mark.anyio
233298
async def test_sse_client_on_session_created_not_called_when_no_session_id(
234299
server: None, server_url: str, monkeypatch: pytest.MonkeyPatch

0 commit comments

Comments
 (0)