Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ async def send_log_message(
related_request_id,
)

async def send_resource_updated(self, uri: str | AnyUrl) -> None: # pragma: no cover
async def send_resource_updated(self, uri: str | AnyUrl) -> None:
"""Send a resource updated notification."""
await self.send_notification(
types.ResourceUpdatedNotification(
Expand Down
12 changes: 6 additions & 6 deletions src/mcp/server/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings |
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")

@asynccontextmanager
async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # pragma: no cover
if scope["type"] != "http":
async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
if scope["type"] != "http": # pragma: no cover
logger.error("connect_sse received non-HTTP request")
raise ValueError("connect_sse can only handle HTTP requests")

Expand Down Expand Up @@ -195,7 +195,7 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send):
logger.debug("Yielding read and write streams")
yield (read_stream, write_stream)

async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: no cover
async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None:
logger.debug("Handling POST message")
request = Request(scope, receive)

Expand All @@ -205,15 +205,15 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send)
return await error_response(scope, receive, send)

session_id_param = request.query_params.get("session_id")
if session_id_param is None:
if session_id_param is None: # pragma: no cover
logger.warning("Received request without session_id")
response = Response("session_id is required", status_code=400)
return await response(scope, receive, send)

try:
session_id = UUID(hex=session_id_param)
logger.debug(f"Parsed session ID: {session_id}")
except ValueError:
except ValueError: # pragma: no cover
logger.warning(f"Received invalid session ID: {session_id_param}")
response = Response("Invalid session ID", status_code=400)
return await response(scope, receive, send)
Expand All @@ -230,7 +230,7 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send)
try:
message = types.jsonrpc_message_adapter.validate_json(body, by_name=False)
logger.debug(f"Validated client message: {message}")
except ValidationError as err:
except ValidationError as err: # pragma: no cover
logger.exception("Failed to parse message")
response = Response("Could not parse message", status_code=400)
await response(scope, receive, send)
Expand Down
72 changes: 38 additions & 34 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def is_terminated(self) -> bool:
"""Check if this transport has been explicitly terminated."""
return self._terminated

def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover
def close_sse_stream(self, request_id: RequestId) -> None:
"""Close SSE connection for a specific request without terminating the stream.

This method closes the HTTP connection for the specified request, triggering
Expand All @@ -200,12 +200,12 @@ def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover
writer.close()

# Also close and remove request streams
if request_id in self._request_streams:
if request_id in self._request_streams: # pragma: no branch
send_stream, receive_stream = self._request_streams.pop(request_id)
send_stream.close()
receive_stream.close()

def close_standalone_sse_stream(self) -> None: # pragma: no cover
def close_standalone_sse_stream(self) -> None:
"""Close the standalone GET SSE stream, triggering client reconnection.

This method closes the HTTP connection for the standalone GET stream used
Expand Down Expand Up @@ -240,10 +240,10 @@ def _create_session_message(
# Only provide close callbacks when client supports resumability
if self._event_store and protocol_version >= "2025-11-25":

async def close_stream_callback() -> None: # pragma: no cover
async def close_stream_callback() -> None:
self.close_sse_stream(request_id)

async def close_standalone_stream_callback() -> None: # pragma: no cover
async def close_standalone_stream_callback() -> None:
self.close_standalone_sse_stream()

metadata = ServerMessageMetadata(
Expand Down Expand Up @@ -291,7 +291,7 @@ def _create_error_response(
) -> Response:
"""Create an error response with a simple string message."""
response_headers = {"Content-Type": CONTENT_TYPE_JSON}
if headers: # pragma: no cover
if headers:
response_headers.update(headers)

if self.mcp_session_id:
Expand Down Expand Up @@ -342,7 +342,7 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]:
}

# If an event ID was provided, include it
if event_message.event_id: # pragma: no cover
if event_message.event_id:
event_data["id"] = event_message.event_id

return event_data
Expand Down Expand Up @@ -372,7 +372,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No
await error_response(scope, receive, send)
return

if self._terminated: # pragma: no cover
if self._terminated:
# If the session has been terminated, return 404 Not Found
response = self._create_error_response(
"Not Found: Session has been terminated",
Expand All @@ -387,7 +387,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No
await self._handle_get_request(request, send)
elif request.method == "DELETE":
await self._handle_delete_request(request, send)
else: # pragma: no cover
else:
await self._handle_unsupported_request(request, send)

def _check_accept_headers(self, request: Request) -> tuple[bool, bool]:
Expand Down Expand Up @@ -467,7 +467,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re

try:
message = jsonrpc_message_adapter.validate_python(raw_message, by_name=False)
except ValidationError as e: # pragma: no cover
except ValidationError as e:
response = self._create_error_response(
f"Validation error: {str(e)}",
HTTPStatus.BAD_REQUEST,
Expand All @@ -493,7 +493,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
)
await response(scope, receive, send)
return
elif not await self._validate_request_headers(request, send): # pragma: no cover
elif not await self._validate_request_headers(request, send):
return

# For notifications and responses only, return 202 Accepted
Expand Down Expand Up @@ -633,7 +633,9 @@ async def sse_writer(): # pragma: lax no cover
finally:
await sse_stream_reader.aclose()

except Exception as err: # pragma: no cover
# Fires on some CI matrix entries (3.12 lowest-direct ubuntu) but not
# others — concurrent-shutdown exception timing varies by dep version.
except Exception as err: # pragma: lax no cover
logger.exception("Error handling POST request")
response = self._create_error_response(
f"Error handling POST request: {err}",
Expand All @@ -659,19 +661,19 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
# Validate Accept header - must include text/event-stream
_, has_sse = self._check_accept_headers(request)

if not has_sse: # pragma: no cover
if not has_sse:
response = self._create_error_response(
"Not Acceptable: Client must accept text/event-stream",
HTTPStatus.NOT_ACCEPTABLE,
)
await response(request.scope, request.receive, send)
return

if not await self._validate_request_headers(request, send): # pragma: no cover
if not await self._validate_request_headers(request, send):
return

# Handle resumability: check for Last-Event-ID header
if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): # pragma: no cover
if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER):
await self._replay_events(last_event_id, request, send)
return

Expand All @@ -681,11 +683,11 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
"Content-Type": CONTENT_TYPE_SSE,
}

if self.mcp_session_id:
if self.mcp_session_id: # pragma: no branch
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id

# Check if we already have an active GET stream
if GET_STREAM_KEY in self._request_streams: # pragma: no cover
if GET_STREAM_KEY in self._request_streams:
response = self._create_error_response(
"Conflict: Only one SSE stream is allowed per session",
HTTPStatus.CONFLICT,
Expand Down Expand Up @@ -714,7 +716,9 @@ async def standalone_sse_writer():
# Send the message via SSE
event_data = self._create_event_data(event_message)
await sse_stream_writer.send(event_data)
except Exception: # pragma: no cover
# Fires on most CI matrix entries but not 3.14 locked ubuntu —
# EventSourceResponse cancellation vs stream-close ordering varies.
except Exception: # pragma: lax no cover
logger.exception("Error in standalone SSE writer")
finally:
logger.debug("Closing standalone SSE writer")
Expand Down Expand Up @@ -791,13 +795,13 @@ async def terminate(self) -> None:
# During cleanup, we catch all exceptions since streams might be in various states
logger.debug(f"Error closing streams: {e}")

async def _handle_unsupported_request(self, request: Request, send: Send) -> None: # pragma: no cover
async def _handle_unsupported_request(self, request: Request, send: Send) -> None:
"""Handle unsupported HTTP methods."""
headers = {
"Content-Type": CONTENT_TYPE_JSON,
"Allow": "GET, POST, DELETE",
}
if self.mcp_session_id:
if self.mcp_session_id: # pragma: no branch
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id

response = self._create_error_response(
Expand All @@ -824,7 +828,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool:
request_session_id = self._get_session_id(request)

# If no session ID provided but required, return error
if not request_session_id: # pragma: no cover
if not request_session_id:
response = self._create_error_response(
"Bad Request: Missing session ID",
HTTPStatus.BAD_REQUEST,
Expand All @@ -849,11 +853,11 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool
protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER)

# If no protocol version provided, assume default version
if protocol_version is None: # pragma: no cover
if protocol_version is None:
protocol_version = DEFAULT_NEGOTIATED_VERSION

# Check if the protocol version is supported
if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: # pragma: no cover
if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS:
supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS)
response = self._create_error_response(
f"Bad Request: Unsupported protocol version: {protocol_version}. "
Expand All @@ -865,13 +869,13 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool

return True

async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: # pragma: no cover
async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None:
"""Replays events that would have been sent after the specified event ID.

Only used when resumability is enabled.
"""
event_store = self._event_store
if not event_store:
if not event_store: # pragma: no cover
return

try:
Expand All @@ -881,7 +885,7 @@ async def _replay_events(self, last_event_id: str, request: Request, send: Send)
"Content-Type": CONTENT_TYPE_SSE,
}

if self.mcp_session_id:
if self.mcp_session_id: # pragma: no branch
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id

# Get protocol version from header (already validated in _validate_protocol_version)
Expand All @@ -902,7 +906,7 @@ async def send_event(event_message: EventMessage) -> None:
stream_id = await event_store.replay_events_after(last_event_id, send_event)

# If stream ID not in mapping, create it
if stream_id and stream_id not in self._request_streams:
if stream_id and stream_id not in self._request_streams: # pragma: no branch
# Register SSE writer so close_sse_stream() can close it
self._sse_stream_writers[stream_id] = sse_stream_writer

Expand All @@ -921,9 +925,9 @@ async def send_event(event_message: EventMessage) -> None:
await sse_stream_writer.send(event_data)
except anyio.ClosedResourceError:
# Expected when close_sse_stream() is called
logger.debug("Replay SSE stream closed by close_sse_stream()")
logger.debug("Replay SSE stream closed by close_sse_stream()") # pragma: no cover
except Exception:
logger.exception("Error in replay sender")
logger.exception("Error in replay sender") # pragma: no cover

# Create and start EventSourceResponse
response = EventSourceResponse(
Expand All @@ -934,13 +938,13 @@ async def send_event(event_message: EventMessage) -> None:

try:
await response(request.scope, request.receive, send)
except Exception:
except Exception: # pragma: no cover
logger.exception("Error in replay response")
finally:
await sse_stream_writer.aclose()
await sse_stream_reader.aclose()

except Exception:
except Exception: # pragma: no cover
logger.exception("Error replaying events")
response = self._create_error_response(
"Error replaying events",
Expand Down Expand Up @@ -991,7 +995,7 @@ async def message_router():
if isinstance(message, JSONRPCResponse | JSONRPCError) and message.id is not None:
target_request_id = str(message.id)
# Extract related_request_id from meta if it exists
elif ( # pragma: no cover
elif (
session_message.metadata is not None
and isinstance(
session_message.metadata,
Expand All @@ -1015,10 +1019,10 @@ async def message_router():
try:
# Send both the message and the event ID
await self._request_streams[request_stream_id][0].send(EventMessage(message, event_id))
except (anyio.BrokenResourceError, anyio.ClosedResourceError): # pragma: no cover
except (anyio.BrokenResourceError, anyio.ClosedResourceError):
# Stream might be closed, remove from registry
self._request_streams.pop(request_stream_id, None)
else: # pragma: no cover
else:
logger.debug(
f"""Request stream {request_stream_id} not found
for message. Still processing message as the client
Expand Down
24 changes: 12 additions & 12 deletions src/mcp/server/transport_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ def __init__(self, settings: TransportSecuritySettings | None = None):
# If not specified, disable DNS rebinding protection by default for backwards compatibility
self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False)

def _validate_host(self, host: str | None) -> bool: # pragma: no cover
def _validate_host(self, host: str | None) -> bool:
"""Validate the Host header against allowed values."""
if not host:
if not host: # pragma: no cover
logger.warning("Missing Host header in request")
return False

Expand All @@ -62,19 +62,19 @@ def _validate_host(self, host: str | None) -> bool: # pragma: no cover
logger.warning(f"Invalid Host header: {host}")
return False

def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover
def _validate_origin(self, origin: str | None) -> bool:
"""Validate the Origin header against allowed values."""
# Origin can be absent for same-origin requests
if not origin:
return True

# Check exact match first
if origin in self.settings.allowed_origins:
if origin in self.settings.allowed_origins: # pragma: no cover
return True

# Check wildcard port patterns
for allowed in self.settings.allowed_origins:
if allowed.endswith(":*"):
if allowed.endswith(":*"): # pragma: no branch
# Extract base origin from pattern
base_origin = allowed[:-2]
# Check if the actual origin starts with base origin and has a port
Expand Down Expand Up @@ -104,13 +104,13 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res
return None

# Validate Host header # pragma: no cover
host = request.headers.get("host") # pragma: no cover
if not self._validate_host(host): # pragma: no cover
return Response("Invalid Host header", status_code=421) # pragma: no cover
host = request.headers.get("host")
if not self._validate_host(host):
return Response("Invalid Host header", status_code=421)

# Validate Origin header # pragma: no cover
origin = request.headers.get("origin") # pragma: no cover
if not self._validate_origin(origin): # pragma: no cover
return Response("Invalid Origin header", status_code=403) # pragma: no cover
origin = request.headers.get("origin")
if not self._validate_origin(origin):
return Response("Invalid Origin header", status_code=403)

return None # pragma: no cover
return None
Loading
Loading