Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from mcp.client.experimental import ExperimentalClientFeatures
from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers
from mcp.shared._context import RequestContext
from mcp.shared.message import SessionMessage
from mcp.shared.message import MessageMetadata, SessionMessage
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
from mcp.types._types import RequestParamsMeta
Expand Down Expand Up @@ -461,6 +461,7 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
async def _handle_incoming(
self,
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
message_metadata: MessageMetadata = None,
) -> None:
"""Handle incoming messages by forwarding to the message handler."""
await self._message_handler(req)
Expand Down
123 changes: 80 additions & 43 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ async def main():

import logging
import warnings
from collections.abc import AsyncIterator, Awaitable, Callable
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Mapping
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
from importlib.metadata import version as importlib_version
from typing import Any, Generic
from typing import Any, Generic, cast

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
Expand All @@ -52,8 +52,8 @@ async def main():
from typing_extensions import TypeVar

from mcp import types
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware, auth_context_var
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, BearerAuthBackend, RequireAuthMiddleware
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier
from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes
from mcp.server.auth.settings import AuthSettings
Expand All @@ -66,14 +66,31 @@ async def main():
from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager
from mcp.server.transport_security import TransportSecuritySettings
from mcp.shared.exceptions import MCPError
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.session import RequestResponder
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
from mcp.shared.session import NotificationWithMetadata, RequestResponder

logger = logging.getLogger(__name__)

LifespanResultT = TypeVar("LifespanResultT", default=Any)


@contextmanager
def _bind_request_auth_context(request_context: Any) -> Iterator[None]:
"""Rebind auth context from the current transport request while handling a message."""
authenticated_user = None
scope = getattr(request_context, "scope", None)
if isinstance(scope, Mapping):
scope_user = cast(Mapping[str, object], scope).get("user")
if isinstance(scope_user, AuthenticatedUser):
authenticated_user = scope_user

token = auth_context_var.set(authenticated_user)
try:
yield
finally:
auth_context_var.reset(token)


class NotificationOptions:
def __init__(self, prompts_changed: bool = False, resources_changed: bool = False, tools_changed: bool = False):
self.prompts_changed = prompts_changed
Expand Down Expand Up @@ -407,7 +424,9 @@ async def run(

async def _handle_message(
self,
message: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception,
message: RequestResponder[types.ClientRequest, types.ServerResult]
| NotificationWithMetadata[types.ClientNotification]
| Exception,
session: ServerSession,
lifespan_context: LifespanResultT,
raise_exceptions: bool = False,
Expand All @@ -419,6 +438,13 @@ async def _handle_message(
await self._handle_request(
message, responder.request, session, lifespan_context, raise_exceptions
)
case NotificationWithMetadata() as notification:
await self._handle_notification(
notification.notification,
session,
lifespan_context,
notification.message_metadata,
)
case Exception():
logger.error(f"Received exception from stream: {message}")
if raise_exceptions:
Expand Down Expand Up @@ -452,28 +478,32 @@ async def _handle_request(
close_sse_stream_cb = message.message_metadata.close_sse_stream
close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream

client_capabilities = session.client_params.capabilities if session.client_params else None
task_support = self._experimental_handlers.task_support if self._experimental_handlers else None
# Get task metadata from request params if present
task_metadata = None
if hasattr(req, "params") and req.params is not None:
task_metadata = getattr(req.params, "task", None)
ctx = ServerRequestContext(
request_id=message.request_id,
meta=message.request_meta,
session=session,
lifespan_context=lifespan_context,
experimental=Experimental(
task_metadata=task_metadata,
_client_capabilities=client_capabilities,
_session=session,
_task_support=task_support,
),
request=request_data,
close_sse_stream=close_sse_stream_cb,
close_standalone_sse_stream=close_standalone_sse_stream_cb,
)
response = await handler(ctx, req.params)
# Stateful HTTP sessions process later requests on tasks that were
# created during session setup, so ContextVar snapshots can lag
# behind the current request unless we rebind them here.
with _bind_request_auth_context(request_data):
client_capabilities = session.client_params.capabilities if session.client_params else None
task_support = self._experimental_handlers.task_support if self._experimental_handlers else None
# Get task metadata from request params if present
task_metadata = None
if hasattr(req, "params") and req.params is not None:
task_metadata = getattr(req.params, "task", None)
ctx = ServerRequestContext(
request_id=message.request_id,
meta=message.request_meta,
session=session,
lifespan_context=lifespan_context,
experimental=Experimental(
task_metadata=task_metadata,
_client_capabilities=client_capabilities,
_session=session,
_task_support=task_support,
),
request=request_data,
close_sse_stream=close_sse_stream_cb,
close_standalone_sse_stream=close_standalone_sse_stream_cb,
)
response = await handler(ctx, req.params)
except MCPError as err:
response = err.error
except anyio.get_cancelled_exc_class():
Expand Down Expand Up @@ -511,24 +541,31 @@ async def _handle_notification(
notify: types.ClientNotification,
session: ServerSession,
lifespan_context: LifespanResultT,
message_metadata: MessageMetadata = None,
) -> None:
if handler := self._notification_handlers.get(notify.method):
logger.debug("Dispatching notification of type %s", type(notify).__name__)

try:
client_capabilities = session.client_params.capabilities if session.client_params else None
task_support = self._experimental_handlers.task_support if self._experimental_handlers else None
ctx = ServerRequestContext(
session=session,
lifespan_context=lifespan_context,
experimental=Experimental(
task_metadata=None,
_client_capabilities=client_capabilities,
_session=session,
_task_support=task_support,
),
)
await handler(ctx, notify.params)
request_data = None
if isinstance(message_metadata, ServerMessageMetadata):
request_data = message_metadata.request_context

with _bind_request_auth_context(request_data):
client_capabilities = session.client_params.capabilities if session.client_params else None
task_support = self._experimental_handlers.task_support if self._experimental_handlers else None
ctx = ServerRequestContext(
session=session,
lifespan_context=lifespan_context,
experimental=Experimental(
task_metadata=None,
_client_capabilities=client_capabilities,
_session=session,
_task_support=task_support,
),
request=request_data,
)
await handler(ctx, notify.params)
except Exception: # pragma: no cover
logger.exception("Uncaught exception in notification handler")

Expand Down
17 changes: 14 additions & 3 deletions src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,10 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult:
from mcp.shared.exceptions import StatelessModeNotSupported
from mcp.shared.experimental.tasks.capabilities import check_tasks_capability
from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
from mcp.shared.session import (
BaseSession,
NotificationWithMetadata,
RequestResponder,
)
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
Expand All @@ -60,7 +61,9 @@ class InitializationState(Enum):
ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession")

ServerRequestResponder = (
RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception
RequestResponder[types.ClientRequest, types.ServerResult]
| NotificationWithMetadata[types.ClientNotification]
| Exception
)


Expand Down Expand Up @@ -683,7 +686,15 @@ async def send_message(self, message: SessionMessage) -> None:
"""
await self._write_stream.send(message)

async def _handle_incoming(self, req: ServerRequestResponder) -> None:
async def _handle_incoming(
self,
req: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception,
message_metadata: MessageMetadata = None,
) -> None:
if isinstance(req, types.ClientNotification):
await self._incoming_message_stream_writer.send(NotificationWithMetadata(req, message_metadata))
return

await self._incoming_message_stream_writer.send(req)

@property
Expand Down
15 changes: 13 additions & 2 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
from collections.abc import Callable
from contextlib import AsyncExitStack
from dataclasses import dataclass
from types import TracebackType
from typing import Any, Generic, Protocol, TypeVar

Expand Down Expand Up @@ -53,6 +54,14 @@ async def __call__(
) -> None: ... # pragma: no branch


@dataclass
class NotificationWithMetadata(Generic[ReceiveNotificationT]):
"""A validated notification paired with its transport metadata."""

notification: ReceiveNotificationT
message_metadata: MessageMetadata = None


class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
"""Handles responding to MCP requests and manages request lifecycle.

Expand Down Expand Up @@ -396,7 +405,7 @@ async def _receive_loop(self) -> None:
except Exception:
logging.exception("Progress callback raised an exception")
await self._received_notification(notification)
await self._handle_incoming(notification)
await self._handle_incoming(notification, message.metadata)
except Exception:
# For other validation errors, log and continue
logging.warning( # pragma: no cover
Expand Down Expand Up @@ -515,6 +524,8 @@ async def send_progress_notification(
"""Sends a progress notification for a request that is currently being processed."""

async def _handle_incoming(
self, req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception
self,
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
message_metadata: MessageMetadata = None,
) -> None:
"""A generic handler for incoming messages. Overridden by subclasses."""
Loading
Loading