diff --git a/docs/experimental/index.md b/docs/experimental/index.md index 1d496b3f1..c97fe2a3d 100644 --- a/docs/experimental/index.md +++ b/docs/experimental/index.md @@ -27,10 +27,9 @@ Tasks are useful for: Experimental features are accessed via the `.experimental` property: ```python -# Server-side -@server.experimental.get_task() -async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: - ... +# Server-side: enable task support (auto-registers default handlers) +server = Server(name="my-server") +server.experimental.enable_tasks() # Client-side result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"}) diff --git a/docs/migration.md b/docs/migration.md index 7d30f0ac9..8bd95ff0b 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -351,7 +351,6 @@ The nested `RequestParams.Meta` Pydantic model class has been replaced with a to - `RequestParams.Meta` (Pydantic model) → `RequestParamsMeta` (TypedDict) - Attribute access (`meta.progress_token`) → Dictionary access (`meta.get("progress_token")`) - `progress_token` field changed from `ProgressToken | None = None` to `NotRequired[ProgressToken]` -` **In request context handlers:** @@ -364,11 +363,12 @@ async def handle_tool(name: str, arguments: dict) -> list[TextContent]: await ctx.session.send_progress_notification(ctx.meta.progress_token, 0.5, 100) # After (v2) -@server.call_tool() -async def handle_tool(name: str, arguments: dict) -> list[TextContent]: - ctx = server.request_context +async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: if ctx.meta and "progress_token" in ctx.meta: await ctx.session.send_progress_notification(ctx.meta["progress_token"], 0.5, 100) + ... + +server = Server("my-server", on_call_tool=handle_call_tool) ``` ### `RequestContext` and `ProgressContext` type parameters simplified @@ -471,6 +471,157 @@ await client.read_resource("test://resource") await client.read_resource(str(my_any_url)) ``` +### Lowlevel `Server`: constructor parameters are now keyword-only + +All parameters after `name` are now keyword-only. If you were passing `version` or other parameters positionally, use keyword arguments instead: + +```python +# Before (v1) +server = Server("my-server", "1.0") + +# After (v2) +server = Server("my-server", version="1.0") +``` + +### Lowlevel `Server`: decorator-based handlers replaced with constructor `on_*` params + +The lowlevel `Server` class no longer uses decorator methods for handler registration. Instead, handlers are passed as `on_*` keyword arguments to the constructor. + +**Before (v1):** + +```python +from mcp.server.lowlevel.server import Server + +server = Server("my-server") + +@server.list_tools() +async def handle_list_tools(): + return [types.Tool(name="my_tool", description="A tool", inputSchema={})] + +@server.call_tool() +async def handle_call_tool(name: str, arguments: dict): + return [types.TextContent(type="text", text=f"Called {name}")] +``` + +**After (v2):** + +```python +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + ListToolsResult, + PaginatedRequestParams, + TextContent, + Tool, +) + +async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="my_tool", description="A tool", inputSchema={})]) + + +async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + return CallToolResult( + content=[TextContent(type="text", text=f"Called {params.name}")], + is_error=False, + ) + +server = Server("my-server", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) +``` + +**Key differences:** + +- Handlers receive `(ctx, params)` instead of the full request object or unpacked arguments. `ctx` is a `RequestContext` with `session`, `lifespan_context`, and `experimental` fields (plus `request_id`, `meta`, etc. for request handlers). `params` is the typed request params object. +- Handlers return the full result type (e.g. `ListToolsResult`) rather than unwrapped values (e.g. `list[Tool]`). +- The automatic `jsonschema` input/output validation that the old `call_tool()` decorator performed has been removed. There is no built-in replacement — if you relied on schema validation in the lowlevel server, you will need to validate inputs yourself in your handler. + +**Notification handlers:** + +```python +from mcp.server import Server, ServerRequestContext +from mcp.types import ProgressNotificationParams + + +async def handle_progress(ctx: ServerRequestContext, params: ProgressNotificationParams) -> None: + print(f"Progress: {params.progress}/{params.total}") + +server = Server("my-server", on_progress=handle_progress) +``` + +### Lowlevel `Server`: `request_context` property removed + +The `server.request_context` property has been removed. Request context is now passed directly to handlers as the first argument (`ctx`). The `request_ctx` module-level contextvar still exists but should not be needed — use `ctx` directly instead. + +**Before (v1):** + +```python +from mcp.server.lowlevel.server import request_ctx + +@server.call_tool() +async def handle_call_tool(name: str, arguments: dict): + ctx = server.request_context # or request_ctx.get() + await ctx.session.send_log_message(level="info", data="Processing...") + return [types.TextContent(type="text", text="Done")] +``` + +**After (v2):** + +```python +from mcp.server import ServerRequestContext +from mcp.types import CallToolRequestParams, CallToolResult, TextContent + + +async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + await ctx.session.send_log_message(level="info", data="Processing...") + return CallToolResult( + content=[TextContent(type="text", text="Done")], + is_error=False, + ) +``` + +### `RequestContext`: request-specific fields are now optional + +The `RequestContext` class now uses optional fields for request-specific data (`request_id`, `meta`, etc.) so it can be used for both request and notification handlers. In notification handlers, these fields are `None`. + +```python +from mcp.server import ServerRequestContext + +# request_id, meta, etc. are available in request handlers +# but None in notification handlers +``` + +### Experimental: task handler decorators removed + +The experimental decorator methods on `ExperimentalHandlers` (`@server.experimental.list_tasks()`, `@server.experimental.get_task()`, etc.) have been removed. + +Default task handlers are still registered automatically via `server.experimental.enable_tasks()`. Custom handlers can be passed as `on_*` kwargs to override specific defaults. + +**Before (v1):** + +```python +server = Server("my-server") +server.experimental.enable_tasks(task_store) + +@server.experimental.get_task() +async def custom_get_task(request: GetTaskRequest) -> GetTaskResult: + ... +``` + +**After (v2):** + +```python +from mcp.server import Server, ServerRequestContext +from mcp.types import GetTaskRequestParams, GetTaskResult + + +async def custom_get_task(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: + ... + + +server = Server("my-server") +server.experimental.enable_tasks(on_get_task=custom_get_task) +``` + ## Deprecations @@ -506,16 +657,16 @@ params = CallToolRequestParams( The `streamable_http_app()` method is now available directly on the lowlevel `Server` class, not just `MCPServer`. This allows using the streamable HTTP transport without the MCPServer wrapper. ```python -from mcp.server.lowlevel.server import Server +from mcp.server import Server, ServerRequestContext +from mcp.types import ListToolsResult, PaginatedRequestParams -server = Server("my-server") -# Register handlers... -@server.list_tools() -async def list_tools(): - return [...] +async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[...]) + + +server = Server("my-server", on_list_tools=handle_list_tools) -# Create a Starlette app for streamable HTTP app = server.streamable_http_app( streamable_http_path="/mcp", json_response=False, diff --git a/src/mcp/server/__init__.py b/src/mcp/server/__init__.py index a2dada3af..aab5c33f7 100644 --- a/src/mcp/server/__init__.py +++ b/src/mcp/server/__init__.py @@ -1,5 +1,6 @@ +from .context import ServerRequestContext from .lowlevel import NotificationOptions, Server from .mcpserver import MCPServer from .models import InitializationOptions -__all__ = ["Server", "MCPServer", "NotificationOptions", "InitializationOptions"] +__all__ = ["Server", "ServerRequestContext", "MCPServer", "NotificationOptions", "InitializationOptions"] diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index 43b9d3800..d8e11d78b 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -10,7 +10,7 @@ from mcp.shared._context import RequestContext from mcp.shared.message import CloseSSEStreamCallback -LifespanContextT = TypeVar("LifespanContextT") +LifespanContextT = TypeVar("LifespanContextT", default=dict[str, Any]) RequestT = TypeVar("RequestT", default=Any) diff --git a/src/mcp/server/experimental/request_context.py b/src/mcp/server/experimental/request_context.py index 80ae5912b..91aa9a645 100644 --- a/src/mcp/server/experimental/request_context.py +++ b/src/mcp/server/experimental/request_context.py @@ -160,10 +160,7 @@ async def run_task( RuntimeError: If task support is not enabled or task_metadata is missing Example: - @server.call_tool() - async def handle_tool(name: str, args: dict): - ctx = server.request_context - + async def handle_tool(ctx: RequestContext, params: CallToolRequestParams) -> CallToolResult: async def work(task: ServerTaskContext) -> CallToolResult: result = await task.elicit( message="Are you sure?", diff --git a/src/mcp/server/experimental/task_result_handler.py b/src/mcp/server/experimental/task_result_handler.py index 991221bd0..b2268bc1c 100644 --- a/src/mcp/server/experimental/task_result_handler.py +++ b/src/mcp/server/experimental/task_result_handler.py @@ -44,17 +44,14 @@ class TaskResultHandler: 5. Returns the final result Usage: - # Create handler with store and queue - handler = TaskResultHandler(task_store, message_queue) - - # Register it with the server - @server.experimental.get_task_result() - async def handle_task_result(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: - ctx = server.request_context - return await handler.handle(req, ctx.session, ctx.request_id) - - # Or use the convenience method - handler.register(server) + async def handle_task_result( + ctx: ServerRequestContext, params: GetTaskPayloadRequestParams + ) -> GetTaskPayloadResult: + ... + + server.experimental.enable_tasks( + on_task_result=handle_task_result, + ) """ def __init__( diff --git a/src/mcp/server/lowlevel/__init__.py b/src/mcp/server/lowlevel/__init__.py index 66df38991..37191ba1a 100644 --- a/src/mcp/server/lowlevel/__init__.py +++ b/src/mcp/server/lowlevel/__init__.py @@ -1,3 +1,3 @@ from .server import NotificationOptions, Server -__all__ = ["Server", "NotificationOptions"] +__all__ = ["NotificationOptions", "Server"] diff --git a/src/mcp/server/lowlevel/experimental.py b/src/mcp/server/lowlevel/experimental.py index 9b472c023..8ac268728 100644 --- a/src/mcp/server/lowlevel/experimental.py +++ b/src/mcp/server/lowlevel/experimental.py @@ -7,10 +7,12 @@ import logging from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING +from typing import Any, Generic +from typing_extensions import TypeVar + +from mcp.server.context import ServerRequestContext from mcp.server.experimental.task_support import TaskSupport -from mcp.server.lowlevel.func_inspection import create_call_wrapper from mcp.shared.exceptions import MCPError from mcp.shared.experimental.tasks.helpers import cancel_task from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore @@ -18,16 +20,16 @@ from mcp.shared.experimental.tasks.store import TaskStore from mcp.types import ( INVALID_PARAMS, - CancelTaskRequest, + CancelTaskRequestParams, CancelTaskResult, GetTaskPayloadRequest, + GetTaskPayloadRequestParams, GetTaskPayloadResult, - GetTaskRequest, + GetTaskRequestParams, GetTaskResult, - ListTasksRequest, ListTasksResult, + PaginatedRequestParams, ServerCapabilities, - ServerResult, ServerTasksCapability, ServerTasksRequestsCapability, TasksCallCapability, @@ -36,13 +38,12 @@ TasksToolsCapability, ) -if TYPE_CHECKING: - from mcp.server.lowlevel.server import Server - logger = logging.getLogger(__name__) +LifespanResultT = TypeVar("LifespanResultT", default=Any) + -class ExperimentalHandlers: +class ExperimentalHandlers(Generic[LifespanResultT]): """Experimental request/notification handlers. WARNING: These APIs are experimental and may change without notice. @@ -50,13 +51,13 @@ class ExperimentalHandlers: def __init__( self, - server: Server, - request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]], - notification_handlers: dict[type, Callable[..., Awaitable[None]]], - ): - self._server = server - self._request_handlers = request_handlers - self._notification_handlers = notification_handlers + add_request_handler: Callable[ + [str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]]], None + ], + has_handler: Callable[[str], bool], + ) -> None: + self._add_request_handler = add_request_handler + self._has_handler = has_handler self._task_support: TaskSupport | None = None @property @@ -66,16 +67,13 @@ def task_support(self) -> TaskSupport | None: def update_capabilities(self, capabilities: ServerCapabilities) -> None: # Only add tasks capability if handlers are registered - if not any( - req_type in self._request_handlers - for req_type in [GetTaskRequest, ListTasksRequest, CancelTaskRequest, GetTaskPayloadRequest] - ): + if not any(self._has_handler(method) for method in ["tasks/get", "tasks/list", "tasks/cancel", "tasks/result"]): return capabilities.tasks = ServerTasksCapability() - if ListTasksRequest in self._request_handlers: + if self._has_handler("tasks/list"): capabilities.tasks.list = TasksListCapability() - if CancelTaskRequest in self._request_handlers: + if self._has_handler("tasks/cancel"): capabilities.tasks.cancel = TasksCancelCapability() capabilities.tasks.requests = ServerTasksRequestsCapability( @@ -86,15 +84,35 @@ def enable_tasks( self, store: TaskStore | None = None, queue: TaskMessageQueue | None = None, + *, + on_get_task: Callable[[ServerRequestContext[LifespanResultT], GetTaskRequestParams], Awaitable[GetTaskResult]] + | None = None, + on_task_result: Callable[ + [ServerRequestContext[LifespanResultT], GetTaskPayloadRequestParams], Awaitable[GetTaskPayloadResult] + ] + | None = None, + on_list_tasks: Callable[ + [ServerRequestContext[LifespanResultT], PaginatedRequestParams | None], Awaitable[ListTasksResult] + ] + | None = None, + on_cancel_task: Callable[ + [ServerRequestContext[LifespanResultT], CancelTaskRequestParams], Awaitable[CancelTaskResult] + ] + | None = None, ) -> TaskSupport: """Enable experimental task support. - This sets up the task infrastructure and auto-registers default handlers - for tasks/get, tasks/result, tasks/list, and tasks/cancel. + This sets up the task infrastructure and registers handlers for + tasks/get, tasks/result, tasks/list, and tasks/cancel. Custom handlers + can be provided via the on_* kwargs; any not provided will use defaults. Args: store: Custom TaskStore implementation (defaults to InMemoryTaskStore) queue: Custom TaskMessageQueue implementation (defaults to InMemoryTaskMessageQueue) + on_get_task: Custom handler for tasks/get + on_task_result: Custom handler for tasks/result + on_list_tasks: Custom handler for tasks/list + on_cancel_task: Custom handler for tasks/cancel Returns: The TaskSupport configuration object @@ -117,24 +135,27 @@ def enable_tasks( queue = InMemoryTaskMessageQueue() self._task_support = TaskSupport(store=store, queue=queue) - - # Auto-register default handlers - self._register_default_task_handlers() - - return self._task_support - - def _register_default_task_handlers(self) -> None: - """Register default handlers for task operations.""" - assert self._task_support is not None - support = self._task_support - - # Register get_task handler if not already registered - if GetTaskRequest not in self._request_handlers: - - async def _default_get_task(req: GetTaskRequest) -> ServerResult: - task = await support.store.get_task(req.params.task_id) + task_support = self._task_support + + # Register user-provided handlers + if on_get_task is not None: + self._add_request_handler("tasks/get", on_get_task) + if on_task_result is not None: + self._add_request_handler("tasks/result", on_task_result) + if on_list_tasks is not None: + self._add_request_handler("tasks/list", on_list_tasks) + if on_cancel_task is not None: + self._add_request_handler("tasks/cancel", on_cancel_task) + + # Fill in defaults for any not provided + if not self._has_handler("tasks/get"): + + async def _default_get_task( + ctx: ServerRequestContext[LifespanResultT], params: GetTaskRequestParams + ) -> GetTaskResult: + task = await task_support.store.get_task(params.task_id) if task is None: - raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {req.params.task_id}") + raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {params.task_id}") return GetTaskResult( task_id=task.task_id, status=task.status, @@ -145,136 +166,39 @@ async def _default_get_task(req: GetTaskRequest) -> ServerResult: poll_interval=task.poll_interval, ) - self._request_handlers[GetTaskRequest] = _default_get_task + self._add_request_handler("tasks/get", _default_get_task) - # Register get_task_result handler if not already registered - if GetTaskPayloadRequest not in self._request_handlers: + if not self._has_handler("tasks/result"): - async def _default_get_task_result(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: - ctx = self._server.request_context - result = await support.handler.handle(req, ctx.session, ctx.request_id) + async def _default_get_task_result( + ctx: ServerRequestContext[LifespanResultT], params: GetTaskPayloadRequestParams + ) -> GetTaskPayloadResult: + assert ctx.request_id is not None + req = GetTaskPayloadRequest(params=params) + result = await task_support.handler.handle(req, ctx.session, ctx.request_id) return result - self._request_handlers[GetTaskPayloadRequest] = _default_get_task_result + self._add_request_handler("tasks/result", _default_get_task_result) - # Register list_tasks handler if not already registered - if ListTasksRequest not in self._request_handlers: + if not self._has_handler("tasks/list"): - async def _default_list_tasks(req: ListTasksRequest) -> ListTasksResult: - cursor = req.params.cursor if req.params else None - tasks, next_cursor = await support.store.list_tasks(cursor) + async def _default_list_tasks( + ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None + ) -> ListTasksResult: + cursor = params.cursor if params else None + tasks, next_cursor = await task_support.store.list_tasks(cursor) return ListTasksResult(tasks=tasks, next_cursor=next_cursor) - self._request_handlers[ListTasksRequest] = _default_list_tasks - - # Register cancel_task handler if not already registered - if CancelTaskRequest not in self._request_handlers: - - async def _default_cancel_task(req: CancelTaskRequest) -> CancelTaskResult: - result = await cancel_task(support.store, req.params.task_id) - return result - - self._request_handlers[CancelTaskRequest] = _default_cancel_task - - def list_tasks( - self, - ) -> Callable[ - [Callable[[ListTasksRequest], Awaitable[ListTasksResult]]], - Callable[[ListTasksRequest], Awaitable[ListTasksResult]], - ]: - """Register a handler for listing tasks. - - WARNING: This API is experimental and may change without notice. - """ - - def decorator( - func: Callable[[ListTasksRequest], Awaitable[ListTasksResult]], - ) -> Callable[[ListTasksRequest], Awaitable[ListTasksResult]]: - logger.debug("Registering handler for ListTasksRequest") - wrapper = create_call_wrapper(func, ListTasksRequest) - - async def handler(req: ListTasksRequest) -> ListTasksResult: - result = await wrapper(req) - return result - - self._request_handlers[ListTasksRequest] = handler - return func - - return decorator - - def get_task( - self, - ) -> Callable[ - [Callable[[GetTaskRequest], Awaitable[GetTaskResult]]], Callable[[GetTaskRequest], Awaitable[GetTaskResult]] - ]: - """Register a handler for getting task status. - - WARNING: This API is experimental and may change without notice. - """ - - def decorator( - func: Callable[[GetTaskRequest], Awaitable[GetTaskResult]], - ) -> Callable[[GetTaskRequest], Awaitable[GetTaskResult]]: - logger.debug("Registering handler for GetTaskRequest") - wrapper = create_call_wrapper(func, GetTaskRequest) - - async def handler(req: GetTaskRequest) -> GetTaskResult: - result = await wrapper(req) - return result - - self._request_handlers[GetTaskRequest] = handler - return func - - return decorator - - def get_task_result( - self, - ) -> Callable[ - [Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]]], - Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]], - ]: - """Register a handler for getting task results/payload. - - WARNING: This API is experimental and may change without notice. - """ - - def decorator( - func: Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]], - ) -> Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]]: - logger.debug("Registering handler for GetTaskPayloadRequest") - wrapper = create_call_wrapper(func, GetTaskPayloadRequest) - - async def handler(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: - result = await wrapper(req) - return result - - self._request_handlers[GetTaskPayloadRequest] = handler - return func - - return decorator - - def cancel_task( - self, - ) -> Callable[ - [Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]]], - Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]], - ]: - """Register a handler for cancelling tasks. - - WARNING: This API is experimental and may change without notice. - """ + self._add_request_handler("tasks/list", _default_list_tasks) - def decorator( - func: Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]], - ) -> Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]]: - logger.debug("Registering handler for CancelTaskRequest") - wrapper = create_call_wrapper(func, CancelTaskRequest) + if not self._has_handler("tasks/cancel"): - async def handler(req: CancelTaskRequest) -> CancelTaskResult: - result = await wrapper(req) + async def _default_cancel_task( + ctx: ServerRequestContext[LifespanResultT], params: CancelTaskRequestParams + ) -> CancelTaskResult: + result = await cancel_task(task_support.store, params.task_id) return result - self._request_handlers[CancelTaskRequest] = handler - return func + self._add_request_handler("tasks/cancel", _default_cancel_task) - return decorator + return task_support diff --git a/src/mcp/server/lowlevel/func_inspection.py b/src/mcp/server/lowlevel/func_inspection.py deleted file mode 100644 index d17697090..000000000 --- a/src/mcp/server/lowlevel/func_inspection.py +++ /dev/null @@ -1,53 +0,0 @@ -import inspect -from collections.abc import Callable -from typing import Any, TypeVar, get_type_hints - -T = TypeVar("T") -R = TypeVar("R") - - -def create_call_wrapper(func: Callable[..., R], request_type: type[T]) -> Callable[[T], R]: - """Create a wrapper function that knows how to call func with the request object. - - Returns a wrapper function that takes the request and calls func appropriately. - - The wrapper handles three calling patterns: - 1. Positional-only parameter typed as request_type (no default): func(req) - 2. Positional/keyword parameter typed as request_type (no default): func(**{param_name: req}) - 3. No request parameter or parameter with default: func() - """ - try: - sig = inspect.signature(func) - type_hints = get_type_hints(func) - except (ValueError, TypeError, NameError): # pragma: no cover - return lambda _: func() - - # Check for positional-only parameter typed as request_type - for param_name, param in sig.parameters.items(): - if param.kind == inspect.Parameter.POSITIONAL_ONLY: - param_type = type_hints.get(param_name) - if param_type == request_type: # pragma: no branch - # Check if it has a default - if so, treat as old style - if param.default is not inspect.Parameter.empty: # pragma: no cover - return lambda _: func() - # Found positional-only parameter with correct type and no default - return lambda req: func(req) - - # Check for any positional/keyword parameter typed as request_type - for param_name, param in sig.parameters.items(): - if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY): # pragma: no branch - param_type = type_hints.get(param_name) - if param_type == request_type: - # Check if it has a default - if so, treat as old style - if param.default is not inspect.Parameter.empty: # pragma: no cover - return lambda _: func() - - # Found keyword parameter with correct type and no default - # Need to capture param_name in closure properly - def make_keyword_wrapper(name: str) -> Callable[[Any], Any]: - return lambda req: func(**{name: req}) - - return make_keyword_wrapper(param_name) - - # No request parameter found - use old style - return lambda _: func() diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 7bd79bb37..8f647e12e 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -2,82 +2,49 @@ This module provides a framework for creating an MCP (Model Context Protocol) server. It allows you to easily define and handle various types of requests and notifications -in an asynchronous manner. +using constructor-based handler registration. Usage: -1. Create a Server instance: - server = Server("your_server_name") - -2. Define request handlers using decorators: - @server.list_prompts() - async def handle_list_prompts(request: types.ListPromptsRequest) -> types.ListPromptsResult: - # Implementation - - @server.get_prompt() - async def handle_get_prompt( - name: str, arguments: dict[str, str] | None - ) -> types.GetPromptResult: - # Implementation - - @server.list_tools() - async def handle_list_tools(request: types.ListToolsRequest) -> types.ListToolsResult: - # Implementation - - @server.call_tool() - async def handle_call_tool( - name: str, arguments: dict | None - ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: - # Implementation - - @server.list_resource_templates() - async def handle_list_resource_templates() -> list[types.ResourceTemplate]: - # Implementation - -3. Define notification handlers if needed: - @server.progress_notification() - async def handle_progress( - progress_token: str | int, progress: float, total: float | None, - message: str | None - ) -> None: - # Implementation - -4. Run the server: +1. Define handler functions: + async def my_list_tools(ctx, params): + return types.ListToolsResult(tools=[...]) + + async def my_call_tool(ctx, params): + return types.CallToolResult(content=[...]) + +2. Create a Server instance with on_* handlers: + server = Server( + "your_server_name", + on_list_tools=my_list_tools, + on_call_tool=my_call_tool, + ) + +3. Run the server: async def main(): async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="your_server_name", - server_version="your_version", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) asyncio.run(main()) -The Server class provides methods to register handlers for various MCP requests and -notifications. It automatically manages the request context and handles incoming -messages from the client. +The Server class dispatches incoming requests and notifications to registered +handler callables by method string. """ from __future__ import annotations -import base64 import contextvars -import json import logging import warnings -from collections.abc import AsyncIterator, Awaitable, Callable, Iterable +from collections.abc import AsyncIterator, Awaitable, Callable from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager from importlib.metadata import version as importlib_version -from typing import Any, Generic, TypeAlias, cast +from typing import Any, Generic import anyio -import jsonschema from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from starlette.applications import Starlette from starlette.middleware import Middleware @@ -94,30 +61,20 @@ async def main(): from mcp.server.context import ServerRequestContext from mcp.server.experimental.request_context import Experimental from mcp.server.lowlevel.experimental import ExperimentalHandlers -from mcp.server.lowlevel.func_inspection import create_call_wrapper -from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings -from mcp.shared.exceptions import MCPError, UrlElicitationRequiredError +from mcp.shared.exceptions import MCPError from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder -from mcp.shared.tool_name_validation import validate_and_warn_tool_name logger = logging.getLogger(__name__) LifespanResultT = TypeVar("LifespanResultT", default=Any) -RequestT = TypeVar("RequestT", default=Any) - -# type aliases for tool call results -StructuredContent: TypeAlias = dict[str, Any] -UnstructuredContent: TypeAlias = Iterable[types.ContentBlock] -CombinationContent: TypeAlias = tuple[UnstructuredContent, StructuredContent] -# This will be properly typed in each Server instance's context -request_ctx: contextvars.ContextVar[ServerRequestContext[Any, Any]] = contextvars.ContextVar("request_ctx") +request_ctx: contextvars.ContextVar[ServerRequestContext[Any]] = contextvars.ContextVar("request_ctx") class NotificationOptions: @@ -128,7 +85,7 @@ def __init__(self, prompts_changed: bool = False, resources_changed: bool = Fals @asynccontextmanager -async def lifespan(_: Server[LifespanResultT, RequestT]) -> AsyncIterator[dict[str, Any]]: +async def lifespan(_: Server[LifespanResultT]) -> AsyncIterator[dict[str, Any]]: """Default lifespan context manager that does nothing. Args: @@ -140,10 +97,15 @@ async def lifespan(_: Server[LifespanResultT, RequestT]) -> AsyncIterator[dict[s yield {} -class Server(Generic[LifespanResultT, RequestT]): +async def _ping_handler(ctx: ServerRequestContext[Any], params: types.RequestParams | None) -> types.EmptyResult: + return types.EmptyResult() + + +class Server(Generic[LifespanResultT]): def __init__( self, name: str, + *, version: str | None = None, title: str | None = None, description: str | None = None, @@ -151,9 +113,80 @@ def __init__( website_url: str | None = None, icons: list[types.Icon] | None = None, lifespan: Callable[ - [Server[LifespanResultT, RequestT]], + [Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT], ] = lifespan, + # Request handlers + on_list_tools: Callable[ + [ServerRequestContext[LifespanResultT], types.PaginatedRequestParams | None], + Awaitable[types.ListToolsResult], + ] + | None = None, + on_call_tool: Callable[ + [ServerRequestContext[LifespanResultT], types.CallToolRequestParams], + Awaitable[types.CallToolResult | types.CreateTaskResult], + ] + | None = None, + on_list_resources: Callable[ + [ServerRequestContext[LifespanResultT], types.PaginatedRequestParams | None], + Awaitable[types.ListResourcesResult], + ] + | None = None, + on_list_resource_templates: Callable[ + [ServerRequestContext[LifespanResultT], types.PaginatedRequestParams | None], + Awaitable[types.ListResourceTemplatesResult], + ] + | None = None, + on_read_resource: Callable[ + [ServerRequestContext[LifespanResultT], types.ReadResourceRequestParams], + Awaitable[types.ReadResourceResult], + ] + | None = None, + on_subscribe_resource: Callable[ + [ServerRequestContext[LifespanResultT], types.SubscribeRequestParams], + Awaitable[types.EmptyResult], + ] + | None = None, + on_unsubscribe_resource: Callable[ + [ServerRequestContext[LifespanResultT], types.UnsubscribeRequestParams], + Awaitable[types.EmptyResult], + ] + | None = None, + on_list_prompts: Callable[ + [ServerRequestContext[LifespanResultT], types.PaginatedRequestParams | None], + Awaitable[types.ListPromptsResult], + ] + | None = None, + on_get_prompt: Callable[ + [ServerRequestContext[LifespanResultT], types.GetPromptRequestParams], + Awaitable[types.GetPromptResult], + ] + | None = None, + on_completion: Callable[ + [ServerRequestContext[LifespanResultT], types.CompleteRequestParams], + Awaitable[types.CompleteResult], + ] + | None = None, + on_set_logging_level: Callable[ + [ServerRequestContext[LifespanResultT], types.SetLevelRequestParams], + Awaitable[types.EmptyResult], + ] + | None = None, + on_ping: Callable[ + [ServerRequestContext[LifespanResultT], types.RequestParams | None], + Awaitable[types.EmptyResult], + ] = _ping_handler, + # Notification handlers + on_roots_list_changed: Callable[ + [ServerRequestContext[LifespanResultT], types.NotificationParams | None], + Awaitable[None], + ] + | None = None, + on_progress: Callable[ + [ServerRequestContext[LifespanResultT], types.ProgressNotificationParams], + Awaitable[None], + ] + | None = None, ): self.name = name self.version = version @@ -163,15 +196,72 @@ def __init__( self.website_url = website_url self.icons = icons self.lifespan = lifespan - self.request_handlers: dict[type, Callable[..., Awaitable[types.ServerResult]]] = { - types.PingRequest: _ping_handler, - } - self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} - self._tool_cache: dict[str, types.Tool] = {} - self._experimental_handlers: ExperimentalHandlers | None = None + self._request_handlers: dict[str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]]] = {} + self._notification_handlers: dict[ + str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[None]] + ] = {} + self._experimental_handlers: ExperimentalHandlers[LifespanResultT] | None = None self._session_manager: StreamableHTTPSessionManager | None = None logger.debug("Initializing server %r", name) + # Populate internal handler dicts from on_* kwargs + self._request_handlers.update( + { + method: handler + for method, handler in { + "ping": on_ping, + "prompts/list": on_list_prompts, + "prompts/get": on_get_prompt, + "resources/list": on_list_resources, + "resources/templates/list": on_list_resource_templates, + "resources/read": on_read_resource, + "resources/subscribe": on_subscribe_resource, + "resources/unsubscribe": on_unsubscribe_resource, + "tools/list": on_list_tools, + "tools/call": on_call_tool, + "logging/setLevel": on_set_logging_level, + "completion/complete": on_completion, + }.items() + if handler is not None + } + ) + + self._notification_handlers.update( + { + method: handler + for method, handler in { + "notifications/roots/list_changed": on_roots_list_changed, + "notifications/progress": on_progress, + }.items() + if handler is not None + } + ) + + def _add_request_handler( + self, + method: str, + handler: Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]], + ) -> None: + """Add a request handler, silently replacing any existing handler for the same method.""" + self._request_handlers[method] = handler + + def _add_notification_handler( + self, + method: str, + handler: Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[None]], + ) -> None: + """Add a notification handler, silently replacing any existing handler for the same method.""" + self._notification_handlers[method] = handler + + def _has_handler(self, method: str) -> bool: + """Check if a handler is registered for the given method.""" + return method in self._request_handlers or method in self._notification_handlers + + # TODO: Rethink capabilities API. Currently capabilities are derived from registered + # handlers but require NotificationOptions to be passed externally for list_changed + # flags, and experimental_capabilities as a separate dict. Consider deriving capabilities + # entirely from server state (e.g. constructor params for list_changed) instead of + # requiring callers to assemble them at create_initialization_options() time. def create_initialization_options( self, notification_options: NotificationOptions | None = None, @@ -214,25 +304,26 @@ def get_capabilities( completions_capability = None # Set prompt capabilities if handler exists - if types.ListPromptsRequest in self.request_handlers: + if "prompts/list" in self._request_handlers: prompts_capability = types.PromptsCapability(list_changed=notification_options.prompts_changed) # Set resource capabilities if handler exists - if types.ListResourcesRequest in self.request_handlers: + if "resources/list" in self._request_handlers: resources_capability = types.ResourcesCapability( - subscribe=False, list_changed=notification_options.resources_changed + subscribe="resources/subscribe" in self._request_handlers, + list_changed=notification_options.resources_changed, ) # Set tool capabilities if handler exists - if types.ListToolsRequest in self.request_handlers: + if "tools/list" in self._request_handlers: tools_capability = types.ToolsCapability(list_changed=notification_options.tools_changed) # Set logging capabilities if handler exists - if types.SetLevelRequest in self.request_handlers: + if "logging/setLevel" in self._request_handlers: logging_capability = types.LoggingCapability() # Set completions capabilities if handler exists - if types.CompleteRequest in self.request_handlers: + if "completion/complete" in self._request_handlers: completions_capability = types.CompletionsCapability() capabilities = types.ServerCapabilities( @@ -248,12 +339,7 @@ def get_capabilities( return capabilities @property - def request_context(self) -> ServerRequestContext[LifespanResultT, RequestT]: - """If called outside of a request context, this will raise a LookupError.""" - return request_ctx.get() - - @property - def experimental(self) -> ExperimentalHandlers: + def experimental(self) -> ExperimentalHandlers[LifespanResultT]: """Experimental APIs for tasks and other features. WARNING: These APIs are experimental and may change without notice. @@ -261,7 +347,10 @@ def experimental(self) -> ExperimentalHandlers: # We create this inline so we only add these capabilities _if_ they're actually used if self._experimental_handlers is None: - self._experimental_handlers = ExperimentalHandlers(self, self.request_handlers, self.notification_handlers) + self._experimental_handlers = ExperimentalHandlers( + add_request_handler=self._add_request_handler, + has_handler=self._has_handler, + ) return self._experimental_handlers @property @@ -278,374 +367,6 @@ def session_manager(self) -> StreamableHTTPSessionManager: ) return self._session_manager # pragma: no cover - def list_prompts(self): - def decorator( - func: Callable[[], Awaitable[list[types.Prompt]]] - | Callable[[types.ListPromptsRequest], Awaitable[types.ListPromptsResult]], - ): - logger.debug("Registering handler for PromptListRequest") - - wrapper = create_call_wrapper(func, types.ListPromptsRequest) - - async def handler(req: types.ListPromptsRequest): - result = await wrapper(req) - # Handle both old style (list[Prompt]) and new style (ListPromptsResult) - if isinstance(result, types.ListPromptsResult): - return result - else: - # Old style returns list[Prompt] - return types.ListPromptsResult(prompts=result) - - self.request_handlers[types.ListPromptsRequest] = handler - return func - - return decorator - - def get_prompt(self): - def decorator( - func: Callable[[str, dict[str, str] | None], Awaitable[types.GetPromptResult]], - ): - logger.debug("Registering handler for GetPromptRequest") - - async def handler(req: types.GetPromptRequest): - prompt_get = await func(req.params.name, req.params.arguments) - return prompt_get - - self.request_handlers[types.GetPromptRequest] = handler - return func - - return decorator - - def list_resources(self): - def decorator( - func: Callable[[], Awaitable[list[types.Resource]]] - | Callable[[types.ListResourcesRequest], Awaitable[types.ListResourcesResult]], - ): - logger.debug("Registering handler for ListResourcesRequest") - - wrapper = create_call_wrapper(func, types.ListResourcesRequest) - - async def handler(req: types.ListResourcesRequest): - result = await wrapper(req) - # Handle both old style (list[Resource]) and new style (ListResourcesResult) - if isinstance(result, types.ListResourcesResult): - return result - else: - # Old style returns list[Resource] - return types.ListResourcesResult(resources=result) - - self.request_handlers[types.ListResourcesRequest] = handler - return func - - return decorator - - def list_resource_templates(self): - def decorator(func: Callable[[], Awaitable[list[types.ResourceTemplate]]]): - logger.debug("Registering handler for ListResourceTemplatesRequest") - - async def handler(_: Any): - templates = await func() - return types.ListResourceTemplatesResult(resource_templates=templates) - - self.request_handlers[types.ListResourceTemplatesRequest] = handler - return func - - return decorator - - def read_resource(self): - def decorator( - func: Callable[[str], Awaitable[str | bytes | Iterable[ReadResourceContents]]], - ): - logger.debug("Registering handler for ReadResourceRequest") - - async def handler(req: types.ReadResourceRequest): - result = await func(req.params.uri) - - def create_content(data: str | bytes, mime_type: str | None, meta: dict[str, Any] | None = None): - # Note: ResourceContents uses Field(alias="_meta"), so we must use the alias key - meta_kwargs: dict[str, Any] = {"_meta": meta} if meta is not None else {} - match data: - case str() as data: - return types.TextResourceContents( - uri=req.params.uri, - text=data, - mime_type=mime_type or "text/plain", - **meta_kwargs, - ) - case bytes() as data: # pragma: no branch - return types.BlobResourceContents( - uri=req.params.uri, - blob=base64.b64encode(data).decode(), - mime_type=mime_type or "application/octet-stream", - **meta_kwargs, - ) - - match result: - case str() | bytes() as data: # pragma: lax no cover - warnings.warn( - "Returning str or bytes from read_resource is deprecated. " - "Use Iterable[ReadResourceContents] instead.", - DeprecationWarning, - stacklevel=2, - ) - content = create_content(data, None) - case Iterable() as contents: - contents_list = [ - create_content( - content_item.content, content_item.mime_type, getattr(content_item, "meta", None) - ) - for content_item in contents - ] - return types.ReadResourceResult(contents=contents_list) - case _: # pragma: no cover - raise ValueError(f"Unexpected return type from read_resource: {type(result)}") - - return types.ReadResourceResult(contents=[content]) # pragma: no cover - - self.request_handlers[types.ReadResourceRequest] = handler - return func - - return decorator - - def set_logging_level(self): - def decorator(func: Callable[[types.LoggingLevel], Awaitable[None]]): - logger.debug("Registering handler for SetLevelRequest") - - async def handler(req: types.SetLevelRequest): - await func(req.params.level) - return types.EmptyResult() - - self.request_handlers[types.SetLevelRequest] = handler - return func - - return decorator - - def subscribe_resource(self): - def decorator(func: Callable[[str], Awaitable[None]]): - logger.debug("Registering handler for SubscribeRequest") - - async def handler(req: types.SubscribeRequest): - await func(req.params.uri) - return types.EmptyResult() - - self.request_handlers[types.SubscribeRequest] = handler - return func - - return decorator - - def unsubscribe_resource(self): - def decorator(func: Callable[[str], Awaitable[None]]): - logger.debug("Registering handler for UnsubscribeRequest") - - async def handler(req: types.UnsubscribeRequest): - await func(req.params.uri) - return types.EmptyResult() - - self.request_handlers[types.UnsubscribeRequest] = handler - return func - - return decorator - - def list_tools(self): - def decorator( - func: Callable[[], Awaitable[list[types.Tool]]] - | Callable[[types.ListToolsRequest], Awaitable[types.ListToolsResult]], - ): - logger.debug("Registering handler for ListToolsRequest") - - wrapper = create_call_wrapper(func, types.ListToolsRequest) - - async def handler(req: types.ListToolsRequest): - result = await wrapper(req) - - # Handle both old style (list[Tool]) and new style (ListToolsResult) - if isinstance(result, types.ListToolsResult): - # Refresh the tool cache with returned tools - for tool in result.tools: - validate_and_warn_tool_name(tool.name) - self._tool_cache[tool.name] = tool - return result - else: - # Old style returns list[Tool] - # Clear and refresh the entire tool cache - self._tool_cache.clear() - for tool in result: - validate_and_warn_tool_name(tool.name) - self._tool_cache[tool.name] = tool - return types.ListToolsResult(tools=result) - - self.request_handlers[types.ListToolsRequest] = handler - return func - - return decorator - - def _make_error_result(self, error_message: str) -> types.CallToolResult: - """Create a CallToolResult with an error.""" - return types.CallToolResult( - content=[types.TextContent(type="text", text=error_message)], - is_error=True, - ) - - async def _get_cached_tool_definition(self, tool_name: str) -> types.Tool | None: - """Get tool definition from cache, refreshing if necessary. - - Returns the Tool object if found, None otherwise. - """ - if tool_name not in self._tool_cache: - if types.ListToolsRequest in self.request_handlers: - logger.debug("Tool cache miss for %s, refreshing cache", tool_name) - await self.request_handlers[types.ListToolsRequest](None) - - tool = self._tool_cache.get(tool_name) - if tool is None: - logger.warning("Tool '%s' not listed, no validation will be performed", tool_name) - - return tool - - def call_tool(self, *, validate_input: bool = True): - """Register a tool call handler. - - Args: - validate_input: If True, validates input against inputSchema. Default is True. - - The handler validates input against inputSchema (if validate_input=True), calls the tool function, - and builds a CallToolResult with the results: - - Unstructured content (iterable of ContentBlock): returned in content - - Structured content (dict): returned in structuredContent, serialized JSON text returned in content - - Both: returned in content and structuredContent - - If outputSchema is defined, validates structuredContent or errors if missing. - """ - - def decorator( - func: Callable[ - [str, dict[str, Any]], - Awaitable[ - UnstructuredContent - | StructuredContent - | CombinationContent - | types.CallToolResult - | types.CreateTaskResult - ], - ], - ): - logger.debug("Registering handler for CallToolRequest") - - async def handler(req: types.CallToolRequest): - try: - tool_name = req.params.name - arguments = req.params.arguments or {} - tool = await self._get_cached_tool_definition(tool_name) - - # input validation - if validate_input and tool: - try: - jsonschema.validate(instance=arguments, schema=tool.input_schema) - except jsonschema.ValidationError as e: - return self._make_error_result(f"Input validation error: {e.message}") - - # tool call - results = await func(tool_name, arguments) - - # output normalization - unstructured_content: UnstructuredContent - maybe_structured_content: StructuredContent | None - if isinstance(results, types.CallToolResult): - return results - elif isinstance(results, types.CreateTaskResult): - # Task-augmented execution returns task info instead of result - return results - elif isinstance(results, tuple) and len(results) == 2: - # tool returned both structured and unstructured content - unstructured_content, maybe_structured_content = cast(CombinationContent, results) - elif isinstance(results, dict): - # tool returned structured content only - maybe_structured_content = cast(StructuredContent, results) - unstructured_content = [types.TextContent(type="text", text=json.dumps(results, indent=2))] - elif hasattr(results, "__iter__"): - # tool returned unstructured content only - unstructured_content = cast(UnstructuredContent, results) - maybe_structured_content = None - else: # pragma: no cover - return self._make_error_result(f"Unexpected return type from tool: {type(results).__name__}") - - # output validation - if tool and tool.output_schema is not None: - if maybe_structured_content is None: - return self._make_error_result( - "Output validation error: outputSchema defined but no structured output returned" - ) - else: - try: - jsonschema.validate(instance=maybe_structured_content, schema=tool.output_schema) - except jsonschema.ValidationError as e: - return self._make_error_result(f"Output validation error: {e.message}") - - # result - return types.CallToolResult( - content=list(unstructured_content), - structured_content=maybe_structured_content, - is_error=False, - ) - except UrlElicitationRequiredError: - # Re-raise UrlElicitationRequiredError so it can be properly handled - # by _handle_request, which converts it to an error response with code -32042 - raise - except Exception as e: - return self._make_error_result(str(e)) - - self.request_handlers[types.CallToolRequest] = handler - return func - - return decorator - - def progress_notification(self): - def decorator( - func: Callable[[str | int, float, float | None, str | None], Awaitable[None]], - ): - logger.debug("Registering handler for ProgressNotification") - - async def handler(req: types.ProgressNotification): - await func( - req.params.progress_token, - req.params.progress, - req.params.total, - req.params.message, - ) - - self.notification_handlers[types.ProgressNotification] = handler - return func - - return decorator - - def completion(self): - """Provides completions for prompts and resource templates""" - - def decorator( - func: Callable[ - [ - types.PromptReference | types.ResourceTemplateReference, - types.CompletionArgument, - types.CompletionContext | None, - ], - Awaitable[types.Completion | None], - ], - ): - logger.debug("Registering handler for CompleteRequest") - - async def handler(req: types.CompleteRequest): - completion = await func(req.params.ref, req.params.argument, req.params.context) - return types.CompleteResult( - completion=completion - if completion is not None - else types.Completion(values=[], total=None, has_more=None), - ) - - self.request_handlers[types.CompleteRequest] = handler - return func - - return decorator - async def run( self, read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], @@ -715,7 +436,7 @@ async def _handle_message( if raise_exceptions: raise message case _: - await self._handle_notification(message) + await self._handle_notification(message, session, lifespan_context) for warning in w: # pragma: lax no cover logger.info("Warning: %s: %s", warning.category.__name__, warning.message) @@ -730,10 +451,9 @@ async def _handle_request( ): logger.info("Processing request of type %s", type(req).__name__) - if handler := self.request_handlers.get(type(req)): + if handler := self._request_handlers.get(req.method): logger.debug("Dispatching request of type %s", type(req).__name__) - token = None try: # Extract request context and close_sse_stream from message metadata request_data = None @@ -744,32 +464,32 @@ async def _handle_request( close_sse_stream_cb = message.message_metadata.close_sse_stream close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream - # Set our global state that can be retrieved via - # app.get_request_context() client_capabilities = session.client_params.capabilities if session.client_params else None task_support = self._experimental_handlers.task_support if self._experimental_handlers else None # Get task metadata from request params if present task_metadata = None if hasattr(req, "params") and req.params is not None: task_metadata = getattr(req.params, "task", None) - token = request_ctx.set( - ServerRequestContext( - request_id=message.request_id, - meta=message.request_meta, - session=session, - lifespan_context=lifespan_context, - experimental=Experimental( - task_metadata=task_metadata, - _client_capabilities=client_capabilities, - _session=session, - _task_support=task_support, - ), - request=request_data, - close_sse_stream=close_sse_stream_cb, - close_standalone_sse_stream=close_standalone_sse_stream_cb, - ) + ctx = ServerRequestContext( + request_id=message.request_id, + meta=message.request_meta, + session=session, + lifespan_context=lifespan_context, + experimental=Experimental( + task_metadata=task_metadata, + _client_capabilities=client_capabilities, + _session=session, + _task_support=task_support, + ), + request=request_data, + close_sse_stream=close_sse_stream_cb, + close_standalone_sse_stream=close_standalone_sse_stream_cb, ) - response = await handler(req) + token = request_ctx.set(ctx) + try: + response = await handler(ctx, req.params) + finally: + request_ctx.reset(token) except MCPError as err: response = err.error except anyio.get_cancelled_exc_class(): @@ -779,10 +499,6 @@ async def _handle_request( if raise_exceptions: # pragma: no cover raise err response = types.ErrorData(code=0, message=str(err), data=None) - finally: - # Reset the global state after we are done - if token is not None: # pragma: no branch - request_ctx.reset(token) await message.respond(response) else: # pragma: no cover @@ -790,12 +506,29 @@ async def _handle_request( logger.debug("Response sent") - async def _handle_notification(self, notify: Any): - if handler := self.notification_handlers.get(type(notify)): # type: ignore + async def _handle_notification( + self, + notify: types.ClientNotification, + session: ServerSession, + lifespan_context: LifespanResultT, + ) -> None: + if handler := self._notification_handlers.get(notify.method): logger.debug("Dispatching notification of type %s", type(notify).__name__) try: - await handler(notify) + client_capabilities = session.client_params.capabilities if session.client_params else None + task_support = self._experimental_handlers.task_support if self._experimental_handlers else None + ctx = ServerRequestContext( + session=session, + lifespan_context=lifespan_context, + experimental=Experimental( + task_metadata=None, + _client_capabilities=client_capabilities, + _session=session, + _task_support=task_support, + ), + ) + await handler(ctx, notify.params) except Exception: # pragma: no cover logger.exception("Uncaught exception in notification handler") @@ -910,7 +643,3 @@ def streamable_http_app( middleware=middleware, lifespan=lambda app: session_manager.run(), ) - - -async def _ping_handler(request: types.PingRequest) -> types.ServerResult: - return types.EmptyResult() diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 8c1fc342b..a26254632 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -2,7 +2,9 @@ from __future__ import annotations +import base64 import inspect +import json import re from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence from contextlib import AbstractAsyncContextManager, asynccontextmanager @@ -29,7 +31,7 @@ from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, UrlElicitationResult, elicit_with_validation from mcp.server.elicitation import elicit_url as _elicit_url from mcp.server.lowlevel.helper_types import ReadResourceContents -from mcp.server.lowlevel.server import LifespanResultT, Server +from mcp.server.lowlevel.server import LifespanResultT, Server, request_ctx from mcp.server.lowlevel.server import lifespan as default_lifespan from mcp.server.mcpserver.exceptions import ResourceError from mcp.server.mcpserver.prompts import Prompt, PromptManager @@ -42,7 +44,30 @@ from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings -from mcp.types import Annotations, ContentBlock, GetPromptResult, Icon, ToolAnnotations +from mcp.shared.exceptions import MCPError +from mcp.types import ( + Annotations, + BlobResourceContents, + CallToolRequestParams, + CallToolResult, + CompleteRequestParams, + CompleteResult, + Completion, + ContentBlock, + GetPromptRequestParams, + GetPromptResult, + Icon, + ListPromptsResult, + ListResourcesResult, + ListResourceTemplatesResult, + ListToolsResult, + PaginatedRequestParams, + ReadResourceRequestParams, + ReadResourceResult, + TextContent, + TextResourceContents, + ToolAnnotations, +) from mcp.types import Prompt as MCPPrompt from mcp.types import PromptArgument as MCPPromptArgument from mcp.types import Resource as MCPResource @@ -91,9 +116,9 @@ class Settings(BaseSettings, Generic[LifespanResultT]): def lifespan_wrapper( app: MCPServer[LifespanResultT], lifespan: Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]], -) -> Callable[[Server[LifespanResultT, Request]], AbstractAsyncContextManager[LifespanResultT]]: +) -> Callable[[Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]]: @asynccontextmanager - async def wrap(_: Server[LifespanResultT, Request]) -> AsyncIterator[LifespanResultT]: + async def wrap(_: Server[LifespanResultT]) -> AsyncIterator[LifespanResultT]: async with lifespan(app) as context: yield context @@ -132,6 +157,9 @@ def __init__( auth=auth, ) + self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) + self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) + self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts) self._lowlevel_server = Server( name=name or "mcp-server", title=title, @@ -140,13 +168,17 @@ def __init__( website_url=website_url, icons=icons, version=version, + on_list_tools=self._handle_list_tools, + on_call_tool=self._handle_call_tool, + on_list_resources=self._handle_list_resources, + on_read_resource=self._handle_read_resource, + on_list_resource_templates=self._handle_list_resource_templates, + on_list_prompts=self._handle_list_prompts, + on_get_prompt=self._handle_get_prompt, # TODO(Marcelo): It seems there's a type mismatch between the lifespan type from an MCPServer and Server. # We need to create a Lifespan type that is a generic on the server type, like Starlette does. lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan), # type: ignore ) - self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) - self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) - self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts) # Validate auth configuration if self.settings.auth is not None: if auth_server_provider and token_verifier: # pragma: no cover @@ -164,9 +196,6 @@ def __init__( self._token_verifier = ProviderTokenVerifier(auth_server_provider) self._custom_starlette_routes: list[Route] = [] - # Set up MCP protocol handlers - self._setup_handlers() - # Configure logging configure_logging(self.settings.log_level) @@ -263,18 +292,80 @@ def run( case "streamable-http": # pragma: no cover anyio.run(lambda: self.run_streamable_http_async(**kwargs)) - def _setup_handlers(self) -> None: - """Set up core MCP protocol handlers.""" - self._lowlevel_server.list_tools()(self.list_tools) - # Note: we disable the lowlevel server's input validation. - # MCPServer does ad hoc conversion of incoming data before validating - - # for now we preserve this for backwards compatibility. - self._lowlevel_server.call_tool(validate_input=False)(self.call_tool) - self._lowlevel_server.list_resources()(self.list_resources) - self._lowlevel_server.read_resource()(self.read_resource) - self._lowlevel_server.list_prompts()(self.list_prompts) - self._lowlevel_server.get_prompt()(self.get_prompt) - self._lowlevel_server.list_resource_templates()(self.list_resource_templates) + async def _handle_list_tools( + self, ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None + ) -> ListToolsResult: + return ListToolsResult(tools=await self.list_tools()) + + async def _handle_call_tool( + self, ctx: ServerRequestContext[LifespanResultT], params: CallToolRequestParams + ) -> CallToolResult: + try: + result = await self.call_tool(params.name, params.arguments or {}) + except MCPError: + raise + except Exception as e: + return CallToolResult(content=[TextContent(type="text", text=str(e))], is_error=True) + if isinstance(result, CallToolResult): + return result + if isinstance(result, tuple) and len(result) == 2: + unstructured_content, structured_content = result + return CallToolResult( + content=list(unstructured_content), # type: ignore[arg-type] + structured_content=structured_content, # type: ignore[arg-type] + ) + if isinstance(result, dict): + return CallToolResult( + content=[TextContent(type="text", text=json.dumps(result, indent=2))], + structured_content=result, + ) + return CallToolResult(content=list(result)) + + async def _handle_list_resources( + self, ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult(resources=await self.list_resources()) + + async def _handle_read_resource( + self, ctx: ServerRequestContext[LifespanResultT], params: ReadResourceRequestParams + ) -> ReadResourceResult: + results = await self.read_resource(params.uri) + contents: list[TextResourceContents | BlobResourceContents] = [] + for item in results: + if isinstance(item.content, bytes): + contents.append( + BlobResourceContents( + uri=params.uri, + blob=base64.b64encode(item.content).decode(), + mime_type=item.mime_type or "application/octet-stream", + _meta=item.meta, + ) + ) + else: + contents.append( + TextResourceContents( + uri=params.uri, + text=item.content, + mime_type=item.mime_type or "text/plain", + _meta=item.meta, + ) + ) + return ReadResourceResult(contents=contents) + + async def _handle_list_resource_templates( + self, ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None + ) -> ListResourceTemplatesResult: + return ListResourceTemplatesResult(resource_templates=await self.list_resource_templates()) + + async def _handle_list_prompts( + self, ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None + ) -> ListPromptsResult: + return ListPromptsResult(prompts=await self.list_prompts()) + + async def _handle_get_prompt( + self, ctx: ServerRequestContext[LifespanResultT], params: GetPromptRequestParams + ) -> GetPromptResult: + return await self.get_prompt(params.name, params.arguments) async def list_tools(self) -> list[MCPTool]: """List all available tools.""" @@ -298,7 +389,7 @@ def get_context(self) -> Context[LifespanResultT, Request]: during a request; outside a request, most methods will error. """ try: - request_context = self._lowlevel_server.request_context + request_context = request_ctx.get() except LookupError: request_context = None return Context(request_context=request_context, mcp_server=self) @@ -486,7 +577,24 @@ async def handle_completion(ref, argument, context): return Completion(values=["option1", "option2"]) return None """ - return self._lowlevel_server.completion() + + def decorator(func: _CallableT) -> _CallableT: + async def handler( + ctx: ServerRequestContext[LifespanResultT], params: CompleteRequestParams + ) -> CompleteResult: + result = await func(params.ref, params.argument, params.context) + return CompleteResult( + completion=result if result is not None else Completion(values=[], total=None, has_more=None), + ) + + # TODO(maxisbey): remove private access — completion needs post-construction + # handler registration, find a better pattern for this + self._lowlevel_server._add_request_handler( # pyright: ignore[reportPrivateUsage] + "completion/complete", handler + ) + return func + + return decorator def add_resource(self, resource: Resource) -> None: """Add a resource to the server. diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index f496121a3..6925aa556 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -6,30 +6,22 @@ Common usage pattern: ``` - server = Server(name) - - @server.call_tool() - async def handle_tool_call(ctx: RequestContext, arguments: dict[str, Any]) -> Any: + async def handle_call_tool(ctx: RequestContext, params: CallToolRequestParams) -> CallToolResult: # Check client capabilities before proceeding if ctx.session.check_client_capability( types.ClientCapabilities(experimental={"advanced_tools": dict()}) ): - # Perform advanced tool operations - result = await perform_advanced_tool_operation(arguments) + result = await perform_advanced_tool_operation(params.arguments) else: - # Fall back to basic tool operations - result = await perform_basic_tool_operation(arguments) - + result = await perform_basic_tool_operation(params.arguments) return result - @server.list_prompts() - async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: - # Access session for any necessary checks or operations + async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult: if ctx.session.client_params: - # Customize prompts based on client initialization parameters - return generate_custom_prompts(ctx.session.client_params) - else: - return default_prompts + return ListPromptsResult(prompts=generate_custom_prompts(ctx.session.client_params)) + return ListPromptsResult(prompts=default_prompts) + + server = Server(name, on_call_tool=handle_call_tool, on_list_prompts=handle_list_prompts) ``` The ServerSession class is typically used internally by the Server class and should not diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index ddc6e5014..8eb29c4d4 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -60,7 +60,7 @@ class StreamableHTTPSessionManager: def __init__( self, - app: Server[Any, Any], + app: Server[Any], event_store: EventStore | None = None, json_response: bool = False, stateless: bool = False, diff --git a/src/mcp/shared/_context.py b/src/mcp/shared/_context.py index 2facc2a49..bbcee2d02 100644 --- a/src/mcp/shared/_context.py +++ b/src/mcp/shared/_context.py @@ -13,8 +13,12 @@ @dataclass(kw_only=True) class RequestContext(Generic[SessionT]): - """Common context for handling incoming requests.""" + """Common context for handling incoming requests. + + For request handlers, request_id is always populated. + For notification handlers, request_id is None. + """ - request_id: RequestId - meta: RequestParamsMeta | None session: SessionT + request_id: RequestId | None = None + meta: RequestParamsMeta | None = None diff --git a/src/mcp/shared/experimental/tasks/helpers.py b/src/mcp/shared/experimental/tasks/helpers.py index 38ca802da..bd1781cb5 100644 --- a/src/mcp/shared/experimental/tasks/helpers.py +++ b/src/mcp/shared/experimental/tasks/helpers.py @@ -72,9 +72,8 @@ async def cancel_task( - Task is already in a terminal state (completed, failed, cancelled) Example: - @server.experimental.cancel_task() - async def handle_cancel(request: CancelTaskRequest) -> CancelTaskResult: - return await cancel_task(store, request.params.taskId) + async def handle_cancel(ctx, params: CancelTaskRequestParams) -> CancelTaskResult: + return await cancel_task(store, params.task_id) """ task = await store.get_task(task_id) if task is None: diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py index 9dedd2e5d..1858eeac3 100644 --- a/src/mcp/shared/message.py +++ b/src/mcp/shared/message.py @@ -6,6 +6,7 @@ from collections.abc import Awaitable, Callable from dataclasses import dataclass +from typing import Any from mcp.types import JSONRPCMessage, RequestId @@ -30,8 +31,10 @@ class ServerMessageMetadata: """Metadata specific to server messages.""" related_request_id: RequestId | None = None - # Request-specific context (e.g., headers, auth info) - request_context: object | None = None + # Transport-specific request context (e.g. starlette Request for HTTP + # transports, None for stdio). Typed as Any because the server layer is + # transport-agnostic. + request_context: Any = None # Callback to close SSE stream for the current request without terminating close_sse_stream: CloseSSEStreamCallback | None = None # Callback to close the standalone GET SSE stream (for unsolicited notifications) diff --git a/tests/client/test_client.py b/tests/client/test_client.py index d483ae54b..45300063a 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -11,7 +11,7 @@ from mcp import types from mcp.client._memory import InMemoryTransport from mcp.client.client import Client -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.mcpserver import MCPServer from mcp.types import ( CallToolResult, @@ -41,33 +41,36 @@ @pytest.fixture def simple_server() -> Server: """Create a simple MCP server for testing.""" - server = Server(name="test_server") - @server.list_resources() - async def handle_list_resources(): - return [Resource(uri="memory://test", name="Test Resource", description="A test resource")] + async def handle_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult( + resources=[Resource(uri="memory://test", name="Test Resource", description="A test resource")] + ) - @server.subscribe_resource() - async def handle_subscribe_resource(uri: str): - pass + async def handle_subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> EmptyResult: + return EmptyResult() - @server.unsubscribe_resource() - async def handle_unsubscribe_resource(uri: str): - pass + async def handle_unsubscribe_resource( + ctx: ServerRequestContext, params: types.UnsubscribeRequestParams + ) -> EmptyResult: + return EmptyResult() - @server.set_logging_level() - async def handle_set_logging_level(level: str): - pass + async def handle_set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + return EmptyResult() - @server.completion() - async def handle_completion( - ref: types.PromptReference | types.ResourceTemplateReference, - argument: types.CompletionArgument, - context: types.CompletionContext | None, - ) -> types.Completion | None: - return types.Completion(values=[]) + async def handle_completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> types.CompleteResult: + return types.CompleteResult(completion=types.Completion(values=[])) - return server + return Server( + name="test_server", + on_list_resources=handle_list_resources, + on_subscribe_resource=handle_subscribe_resource, + on_unsubscribe_resource=handle_unsubscribe_resource, + on_set_logging_level=handle_set_logging_level, + on_completion=handle_completion, + ) @pytest.fixture @@ -202,19 +205,14 @@ async def test_client_send_progress_notification(): """Test sending progress notification.""" received_from_client = None event = anyio.Event() - server = Server(name="test_server") - - @server.progress_notification() - async def handle_progress_notification( - progress_token: str | int, - progress: float = 0.0, - total: float | None = None, - message: str | None = None, - ) -> None: + + async def handle_progress(ctx: ServerRequestContext, params: types.ProgressNotificationParams) -> None: nonlocal received_from_client - received_from_client = {"progress_token": progress_token, "progress": progress} + received_from_client = {"progress_token": params.progress_token, "progress": params.progress} event.set() + server = Server(name="test_server", on_progress=handle_progress) + async with Client(server) as client: await client.send_progress_notification(progress_token="token123", progress=50.0) await event.wait() diff --git a/tests/client/test_http_unicode.py b/tests/client/test_http_unicode.py index 5cca8c194..cc2e14e46 100644 --- a/tests/client/test_http_unicode.py +++ b/tests/client/test_http_unicode.py @@ -8,7 +8,6 @@ import socket from collections.abc import AsyncGenerator, Generator from contextlib import asynccontextmanager -from typing import Any import pytest from starlette.applications import Starlette @@ -17,7 +16,7 @@ from mcp import types from mcp.client.session import ClientSession from mcp.client.streamable_http import streamable_http_client -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.types import TextContent, Tool from tests.test_helpers import wait_for_server @@ -47,54 +46,56 @@ def run_unicode_server(port: int) -> None: # pragma: no cover import uvicorn # Need to recreate the server setup in this process - server = Server(name="unicode_test_server") - - @server.list_tools() - async def list_tools() -> list[Tool]: - """List tools with Unicode descriptions.""" - return [ - Tool( - name="echo_unicode", - description="🔤 Echo Unicode text - Hello 👋 World 🌍 - Testing 🧪 Unicode ✨", - input_schema={ - "type": "object", - "properties": { - "text": {"type": "string", "description": "Text to echo back"}, + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + Tool( + name="echo_unicode", + description="🔤 Echo Unicode text - Hello 👋 World 🌍 - Testing 🧪 Unicode ✨", + input_schema={ + "type": "object", + "properties": { + "text": {"type": "string", "description": "Text to echo back"}, + }, + "required": ["text"], }, - "required": ["text"], - }, - ), - ] - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any] | None) -> list[TextContent]: - """Handle tool calls with Unicode content.""" - if name == "echo_unicode": - text = arguments.get("text", "") if arguments else "" - return [ - TextContent( - type="text", - text=f"Echo: {text}", - ) + ), ] - else: - raise ValueError(f"Unknown tool: {name}") - - @server.list_prompts() - async def list_prompts() -> list[types.Prompt]: - """List prompts with Unicode names and descriptions.""" - return [ - types.Prompt( - name="unicode_prompt", - description="Unicode prompt - Слой хранилища, где располагаются", - arguments=[], + ) + + async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + if params.name == "echo_unicode": + text = params.arguments.get("text", "") if params.arguments else "" + return types.CallToolResult( + content=[ + TextContent( + type="text", + text=f"Echo: {text}", + ) + ] ) - ] + else: + raise ValueError(f"Unknown tool: {params.name}") + + async def handle_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListPromptsResult: + return types.ListPromptsResult( + prompts=[ + types.Prompt( + name="unicode_prompt", + description="Unicode prompt - Слой хранилища, где располагаются", + arguments=[], + ) + ] + ) - @server.get_prompt() - async def get_prompt(name: str, arguments: dict[str, Any] | None) -> types.GetPromptResult: - """Get a prompt with Unicode content.""" - if name == "unicode_prompt": + async def handle_get_prompt( + ctx: ServerRequestContext, params: types.GetPromptRequestParams + ) -> types.GetPromptResult: + if params.name == "unicode_prompt": return types.GetPromptResult( messages=[ types.PromptMessage( @@ -106,7 +107,15 @@ async def get_prompt(name: str, arguments: dict[str, Any] | None) -> types.GetPr ) ] ) - raise ValueError(f"Unknown prompt: {name}") + raise ValueError(f"Unknown prompt: {params.name}") + + server = Server( + name="unicode_test_server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + on_list_prompts=handle_list_prompts, + on_get_prompt=handle_get_prompt, + ) # Create the session manager session_manager = StreamableHTTPSessionManager( diff --git a/tests/client/test_list_methods_cursor.py b/tests/client/test_list_methods_cursor.py index 4d7c53db2..f70fb9277 100644 --- a/tests/client/test_list_methods_cursor.py +++ b/tests/client/test_list_methods_cursor.py @@ -3,9 +3,9 @@ import pytest from mcp import Client, types -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.mcpserver import MCPServer -from mcp.types import ListToolsRequest, ListToolsResult +from mcp.types import ListToolsResult from .conftest import StreamSpyCollection @@ -105,14 +105,16 @@ async def test_list_tools_with_strict_server_validation( async def test_list_tools_with_lowlevel_server(): """Test that list_tools works with a lowlevel Server using params.""" - server = Server("test-lowlevel") - @server.list_tools() - async def handle_list_tools(request: ListToolsRequest) -> ListToolsResult: + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListToolsResult: # Echo back what cursor we received in the tool description - cursor = request.params.cursor if request.params else None + cursor = params.cursor if params else None return ListToolsResult(tools=[types.Tool(name="test_tool", description=f"cursor={cursor}", input_schema={})]) + server = Server("test-lowlevel", on_list_tools=handle_list_tools) + async with Client(server) as client: result = await client.list_tools() assert result.tools[0].description == "cursor=None" diff --git a/tests/client/transports/test_memory.py b/tests/client/transports/test_memory.py index 30ecb0ac3..47be3e208 100644 --- a/tests/client/transports/test_memory.py +++ b/tests/client/transports/test_memory.py @@ -2,31 +2,31 @@ import pytest -from mcp import Client +from mcp import Client, types from mcp.client._memory import InMemoryTransport -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.mcpserver import MCPServer -from mcp.types import Resource +from mcp.types import ListResourcesResult, Resource @pytest.fixture def simple_server() -> Server: """Create a simple MCP server for testing.""" - server = Server(name="test_server") - - # pragma: no cover - handler exists only to register a resource capability. - # Transport tests verify stream creation, not handler invocation. - @server.list_resources() - async def handle_list_resources(): # pragma: no cover - return [ - Resource( - uri="memory://test", - name="Test Resource", - description="A test resource", - ) - ] - return server + async def handle_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: # pragma: no cover + return ListResourcesResult( + resources=[ + Resource( + uri="memory://test", + name="Test Resource", + description="A test resource", + ) + ] + ) + + return Server(name="test_server", on_list_resources=handle_list_resources) @pytest.fixture diff --git a/tests/experimental/tasks/client/test_tasks.py b/tests/experimental/tasks/client/test_tasks.py index f21abf4d0..0da4ce3bf 100644 --- a/tests/experimental/tasks/client/test_tasks.py +++ b/tests/experimental/tasks/client/test_tasks.py @@ -1,43 +1,39 @@ """Tests for the experimental client task methods (session.experimental).""" +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager from dataclasses import dataclass, field -from typing import Any import anyio import pytest from anyio import Event from anyio.abc import TaskGroup -from mcp.client.session import ClientSession -from mcp.server import Server -from mcp.server.lowlevel import NotificationOptions -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession +from mcp import Client +from mcp.server import Server, ServerRequestContext from mcp.shared.experimental.tasks.helpers import task_execution from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder from mcp.types import ( CallToolRequest, CallToolRequestParams, CallToolResult, - CancelTaskRequest, + CancelTaskRequestParams, CancelTaskResult, - ClientResult, CreateTaskResult, - GetTaskPayloadRequest, + GetTaskPayloadRequestParams, GetTaskPayloadResult, - GetTaskRequest, + GetTaskRequestParams, GetTaskResult, - ListTasksRequest, ListTasksResult, - ServerNotification, - ServerRequest, + ListToolsResult, + PaginatedRequestParams, TaskMetadata, TextContent, Tool, ) +pytestmark = pytest.mark.anyio + @dataclass class AppContext: @@ -48,44 +44,53 @@ class AppContext: task_done_events: dict[str, Event] = field(default_factory=lambda: {}) -@pytest.mark.anyio -async def test_session_experimental_get_task() -> None: - """Test session.experimental.get_task() method.""" - # Note: We bypass the normal lifespan mechanism - server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] - store = InMemoryTaskStore() +async def _handle_list_tools( + ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None +) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="test_tool", description="Test", input_schema={"type": "object"})]) - @server.list_tools() - async def list_tools(): - return [Tool(name="test_tool", description="Test", input_schema={"type": "object"})] - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = server.request_context - app = ctx.lifespan_context - if ctx.experimental.is_task: - task_metadata = ctx.experimental.task_metadata - assert task_metadata is not None - task = await app.store.create_task(task_metadata) +async def _handle_call_tool_with_done_event( + ctx: ServerRequestContext[AppContext], params: CallToolRequestParams, *, result_text: str = "Done" +) -> CallToolResult | CreateTaskResult: + app = ctx.lifespan_context + if ctx.experimental.is_task: + task_metadata = ctx.experimental.task_metadata + assert task_metadata is not None + task = await app.store.create_task(task_metadata) - done_event = Event() - app.task_done_events[task.task_id] = done_event + done_event = Event() + app.task_done_events[task.task_id] = done_event - async def do_work(): - async with task_execution(task.task_id, app.store) as task_ctx: - await task_ctx.complete(CallToolResult(content=[TextContent(type="text", text="Done")])) - done_event.set() + async def do_work() -> None: + async with task_execution(task.task_id, app.store) as task_ctx: + await task_ctx.complete(CallToolResult(content=[TextContent(type="text", text=result_text)])) + done_event.set() - app.task_group.start_soon(do_work) - return CreateTaskResult(task=task) + app.task_group.start_soon(do_work) + return CreateTaskResult(task=task) - raise NotImplementedError + raise NotImplementedError + + +def _make_lifespan(store: InMemoryTaskStore, task_done_events: dict[str, Event]): + @asynccontextmanager + async def app_lifespan(server: Server[AppContext]) -> AsyncIterator[AppContext]: + async with anyio.create_task_group() as tg: + yield AppContext(task_group=tg, store=store, task_done_events=task_done_events) - @server.experimental.get_task() - async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: - app = server.request_context.lifespan_context - task = await app.store.get_task(request.params.task_id) - assert task is not None, f"Test setup error: task {request.params.task_id} should exist" + return app_lifespan + + +async def test_session_experimental_get_task() -> None: + """Test session.experimental.get_task() method.""" + store = InMemoryTaskStore() + task_done_events: dict[str, Event] = {} + + async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: + app = ctx.lifespan_context + task = await app.store.get_task(params.task_id) + assert task is not None, f"Test setup error: task {params.task_id} should exist" return GetTaskResult( task_id=task.task_id, status=task.status, @@ -96,280 +101,141 @@ async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: poll_interval=task.poll_interval, ) - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... # pragma: no branch - - async def run_server(app_context: AppContext): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), + server: Server[AppContext] = Server( + "test-server", + lifespan=_make_lifespan(store, task_done_events), + on_list_tools=_handle_list_tools, + on_call_tool=_handle_call_tool_with_done_event, + ) + server.experimental.enable_tasks(on_get_task=handle_get_task) + + async with Client(server) as client: + # Create a task + create_result = await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) ), - ) as server_session: - async for message in server_session.incoming_messages: - await server._handle_message(message, server_session, app_context, raise_exceptions=False) - - async with anyio.create_task_group() as tg: - app_context = AppContext(task_group=tg, store=store) - tg.start_soon(run_server, app_context) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Create a task - create_result = await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - task_id = create_result.task.task_id - - # Wait for task to complete - await app_context.task_done_events[task_id].wait() + CreateTaskResult, + ) + task_id = create_result.task.task_id - # Use session.experimental to get task status - task_status = await client_session.experimental.get_task(task_id) + # Wait for task to complete + await task_done_events[task_id].wait() - assert task_status.task_id == task_id - assert task_status.status == "completed" + # Use session.experimental to get task status + task_status = await client.session.experimental.get_task(task_id) - tg.cancel_scope.cancel() + assert task_status.task_id == task_id + assert task_status.status == "completed" -@pytest.mark.anyio async def test_session_experimental_get_task_result() -> None: """Test session.experimental.get_task_result() method.""" - server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] store = InMemoryTaskStore() + task_done_events: dict[str, Event] = {} - @server.list_tools() - async def list_tools(): - return [Tool(name="test_tool", description="Test", input_schema={"type": "object"})] + async def handle_call_tool( + ctx: ServerRequestContext[AppContext], params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: + return await _handle_call_tool_with_done_event(ctx, params, result_text="Task result content") - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = server.request_context - app = ctx.lifespan_context - if ctx.experimental.is_task: - task_metadata = ctx.experimental.task_metadata - assert task_metadata is not None - task = await app.store.create_task(task_metadata) - - done_event = Event() - app.task_done_events[task.task_id] = done_event - - async def do_work(): - async with task_execution(task.task_id, app.store) as task_ctx: - await task_ctx.complete( - CallToolResult(content=[TextContent(type="text", text="Task result content")]) - ) - done_event.set() - - app.task_group.start_soon(do_work) - return CreateTaskResult(task=task) - - raise NotImplementedError - - @server.experimental.get_task_result() async def handle_get_task_result( - request: GetTaskPayloadRequest, + ctx: ServerRequestContext[AppContext], params: GetTaskPayloadRequestParams ) -> GetTaskPayloadResult: - app = server.request_context.lifespan_context - result = await app.store.get_result(request.params.task_id) - assert result is not None, f"Test setup error: result for {request.params.task_id} should exist" + app = ctx.lifespan_context + result = await app.store.get_result(params.task_id) + assert result is not None, f"Test setup error: result for {params.task_id} should exist" assert isinstance(result, CallToolResult) return GetTaskPayloadResult(**result.model_dump()) - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... # pragma: no branch - - async def run_server(app_context: AppContext): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), + server: Server[AppContext] = Server( + "test-server", + lifespan=_make_lifespan(store, task_done_events), + on_list_tools=_handle_list_tools, + on_call_tool=handle_call_tool, + ) + server.experimental.enable_tasks(on_task_result=handle_get_task_result) + + async with Client(server) as client: + # Create a task + create_result = await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) ), - ) as server_session: - async for message in server_session.incoming_messages: - await server._handle_message(message, server_session, app_context, raise_exceptions=False) - - async with anyio.create_task_group() as tg: - app_context = AppContext(task_group=tg, store=store) - tg.start_soon(run_server, app_context) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Create a task - create_result = await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - task_id = create_result.task.task_id - - # Wait for task to complete - await app_context.task_done_events[task_id].wait() + CreateTaskResult, + ) + task_id = create_result.task.task_id - # Use TaskClient to get task result - task_result = await client_session.experimental.get_task_result(task_id, CallToolResult) + # Wait for task to complete + await task_done_events[task_id].wait() - assert len(task_result.content) == 1 - content = task_result.content[0] - assert isinstance(content, TextContent) - assert content.text == "Task result content" + # Use TaskClient to get task result + task_result = await client.session.experimental.get_task_result(task_id, CallToolResult) - tg.cancel_scope.cancel() + assert len(task_result.content) == 1 + content = task_result.content[0] + assert isinstance(content, TextContent) + assert content.text == "Task result content" -@pytest.mark.anyio async def test_session_experimental_list_tasks() -> None: """Test TaskClient.list_tasks() method.""" - server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] store = InMemoryTaskStore() + task_done_events: dict[str, Event] = {} - @server.list_tools() - async def list_tools(): - return [Tool(name="test_tool", description="Test", input_schema={"type": "object"})] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = server.request_context + async def handle_list_tasks( + ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None + ) -> ListTasksResult: app = ctx.lifespan_context - if ctx.experimental.is_task: - task_metadata = ctx.experimental.task_metadata - assert task_metadata is not None - task = await app.store.create_task(task_metadata) - - done_event = Event() - app.task_done_events[task.task_id] = done_event - - async def do_work(): - async with task_execution(task.task_id, app.store) as task_ctx: - await task_ctx.complete(CallToolResult(content=[TextContent(type="text", text="Done")])) - done_event.set() - - app.task_group.start_soon(do_work) - return CreateTaskResult(task=task) - - raise NotImplementedError - - @server.experimental.list_tasks() - async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: - app = server.request_context.lifespan_context - tasks_list, next_cursor = await app.store.list_tasks(cursor=request.params.cursor if request.params else None) + cursor = params.cursor if params else None + tasks_list, next_cursor = await app.store.list_tasks(cursor=cursor) return ListTasksResult(tasks=tasks_list, next_cursor=next_cursor) - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... # pragma: no branch - - async def run_server(app_context: AppContext): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, + server: Server[AppContext] = Server( + "test-server", + lifespan=_make_lifespan(store, task_done_events), + on_list_tools=_handle_list_tools, + on_call_tool=_handle_call_tool_with_done_event, + ) + server.experimental.enable_tasks(on_list_tasks=handle_list_tasks) + + async with Client(server) as client: + # Create two tasks + for _ in range(2): + create_result = await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) ), - ), - ) as server_session: - async for message in server_session.incoming_messages: - await server._handle_message(message, server_session, app_context, raise_exceptions=False) - - async with anyio.create_task_group() as tg: - app_context = AppContext(task_group=tg, store=store) - tg.start_soon(run_server, app_context) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Create two tasks - for _ in range(2): - create_result = await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - await app_context.task_done_events[create_result.task.task_id].wait() - - # Use TaskClient to list tasks - list_result = await client_session.experimental.list_tasks() + CreateTaskResult, + ) + await task_done_events[create_result.task.task_id].wait() - assert len(list_result.tasks) == 2 + # Use TaskClient to list tasks + list_result = await client.session.experimental.list_tasks() - tg.cancel_scope.cancel() + assert len(list_result.tasks) == 2 -@pytest.mark.anyio async def test_session_experimental_cancel_task() -> None: """Test TaskClient.cancel_task() method.""" - server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] store = InMemoryTaskStore() + task_done_events: dict[str, Event] = {} - @server.list_tools() - async def list_tools(): - return [Tool(name="test_tool", description="Test", input_schema={"type": "object"})] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = server.request_context + async def handle_call_tool_no_work( + ctx: ServerRequestContext[AppContext], params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: app = ctx.lifespan_context if ctx.experimental.is_task: task_metadata = ctx.experimental.task_metadata @@ -377,14 +243,12 @@ async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextCon task = await app.store.create_task(task_metadata) # Don't start any work - task stays in "working" status return CreateTaskResult(task=task) - raise NotImplementedError - @server.experimental.get_task() - async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: - app = server.request_context.lifespan_context - task = await app.store.get_task(request.params.task_id) - assert task is not None, f"Test setup error: task {request.params.task_id} should exist" + async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: + app = ctx.lifespan_context + task = await app.store.get_task(params.task_id) + assert task is not None, f"Test setup error: task {params.task_id} should exist" return GetTaskResult( task_id=task.task_id, status=task.status, @@ -395,14 +259,14 @@ async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: poll_interval=task.poll_interval, ) - @server.experimental.cancel_task() - async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: - app = server.request_context.lifespan_context - task = await app.store.get_task(request.params.task_id) - assert task is not None, f"Test setup error: task {request.params.task_id} should exist" - await app.store.update_task(request.params.task_id, status="cancelled") - # CancelTaskResult extends Task, so we need to return the updated task info - updated_task = await app.store.get_task(request.params.task_id) + async def handle_cancel_task( + ctx: ServerRequestContext[AppContext], params: CancelTaskRequestParams + ) -> CancelTaskResult: + app = ctx.lifespan_context + task = await app.store.get_task(params.task_id) + assert task is not None, f"Test setup error: task {params.task_id} should exist" + await app.store.update_task(params.task_id, status="cancelled") + updated_task = await app.store.get_task(params.task_id) assert updated_task is not None return CancelTaskResult( task_id=updated_task.task_id, @@ -412,63 +276,35 @@ async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: ttl=updated_task.ttl, ) - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... # pragma: no branch - - async def run_server(app_context: AppContext): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), + server: Server[AppContext] = Server( + "test-server", + lifespan=_make_lifespan(store, task_done_events), + on_list_tools=_handle_list_tools, + on_call_tool=handle_call_tool_no_work, + ) + server.experimental.enable_tasks(on_get_task=handle_get_task, on_cancel_task=handle_cancel_task) + + async with Client(server) as client: + # Create a task (but don't complete it) + create_result = await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) ), - ) as server_session: - async for message in server_session.incoming_messages: - await server._handle_message(message, server_session, app_context, raise_exceptions=False) - - async with anyio.create_task_group() as tg: - app_context = AppContext(task_group=tg, store=store) - tg.start_soon(run_server, app_context) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Create a task (but don't complete it) - create_result = await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - task_id = create_result.task.task_id - - # Verify task is working - status_before = await client_session.experimental.get_task(task_id) - assert status_before.status == "working" + CreateTaskResult, + ) + task_id = create_result.task.task_id - # Cancel the task - await client_session.experimental.cancel_task(task_id) + # Verify task is working + status_before = await client.session.experimental.get_task(task_id) + assert status_before.status == "working" - # Verify task is cancelled - status_after = await client_session.experimental.get_task(task_id) - assert status_after.status == "cancelled" + # Cancel the task + await client.session.experimental.cancel_task(task_id) - tg.cancel_scope.cancel() + # Verify task is cancelled + status_after = await client.session.experimental.get_task(task_id) + assert status_after.status == "cancelled" diff --git a/tests/experimental/tasks/test_elicitation_scenarios.py b/tests/experimental/tasks/test_elicitation_scenarios.py index 57122da7b..c91a87659 100644 --- a/tests/experimental/tasks/test_elicitation_scenarios.py +++ b/tests/experimental/tasks/test_elicitation_scenarios.py @@ -17,7 +17,7 @@ from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers from mcp.client.session import ClientSession -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.experimental.task_context import ServerTaskContext from mcp.server.lowlevel import NotificationOptions from mcp.shared._context import RequestContext @@ -26,6 +26,7 @@ from mcp.shared.message import SessionMessage from mcp.types import ( TASK_REQUIRED, + CallToolRequestParams, CallToolResult, CreateMessageRequestParams, CreateMessageResult, @@ -35,6 +36,8 @@ ErrorData, GetTaskPayloadResult, GetTaskResult, + ListToolsResult, + PaginatedRequestParams, SamplingMessage, TaskMetadata, TextContent, @@ -181,24 +184,21 @@ async def test_scenario1_normal_tool_normal_elicitation() -> None: Server calls session.elicit() directly, client responds immediately. """ - server = Server("test-scenario1") elicit_received = Event() tool_result: list[str] = [] - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="confirm_action", - description="Confirm an action", - input_schema={"type": "object"}, - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: - ctx = server.request_context + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="confirm_action", + description="Confirm an action", + input_schema={"type": "object"}, + ) + ] + ) + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: # Normal elicitation - expects immediate response result = await ctx.session.elicit( message="Please confirm the action", @@ -209,6 +209,8 @@ async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResu tool_result.append("confirmed" if confirmed else "cancelled") return CallToolResult(content=[TextContent(type="text", text="confirmed" if confirmed else "cancelled")]) + server = Server("test-scenario1", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) + # Elicitation callback for client async def elicitation_callback( context: RequestContext[ClientSession], @@ -262,27 +264,24 @@ async def test_scenario2_normal_tool_task_augmented_elicitation() -> None: Server calls session.experimental.elicit_as_task(), client creates a task for the elicitation and returns CreateTaskResult. Server polls client. """ - server = Server("test-scenario2") elicit_received = Event() tool_result: list[str] = [] # Client-side task store for handling task-augmented elicitation client_task_store = InMemoryTaskStore() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="confirm_action", - description="Confirm an action", - input_schema={"type": "object"}, - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: - ctx = server.request_context + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="confirm_action", + description="Confirm an action", + input_schema={"type": "object"}, + ) + ] + ) + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: # Task-augmented elicitation - server polls client result = await ctx.session.experimental.elicit_as_task( message="Please confirm the action", @@ -294,6 +293,7 @@ async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResu tool_result.append("confirmed" if confirmed else "cancelled") return CallToolResult(content=[TextContent(type="text", text="confirmed" if confirmed else "cancelled")]) + server = Server("test-scenario2", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) task_handlers = create_client_task_handlers(client_task_store, elicit_received) # Set up streams @@ -342,26 +342,22 @@ async def test_scenario3_task_augmented_tool_normal_elicitation() -> None: Client calls tool as task. Inside the task, server uses task.elicit() which queues the request and delivers via tasks/result. """ - server = Server("test-scenario3") - server.experimental.enable_tasks() - elicit_received = Event() work_completed = Event() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="confirm_action", - description="Confirm an action", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="confirm_action", + description="Confirm an action", + input_schema={"type": "object"}, + execution=ToolExecution(task_support=TASK_REQUIRED), + ) + ] + ) - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CreateTaskResult: - ctx = server.request_context + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: @@ -377,6 +373,9 @@ async def work(task: ServerTaskContext) -> CallToolResult: return await ctx.experimental.run_task(work) + server = Server("test-scenario3", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) + server.experimental.enable_tasks() + # Elicitation callback for client async def elicitation_callback( context: RequestContext[ClientSession], @@ -452,29 +451,25 @@ async def test_scenario4_task_augmented_tool_task_augmented_elicitation() -> Non 5. Server gets the ElicitResult and completes the tool task 6. Client's tasks/result returns with the CallToolResult """ - server = Server("test-scenario4") - server.experimental.enable_tasks() - elicit_received = Event() work_completed = Event() # Client-side task store for handling task-augmented elicitation client_task_store = InMemoryTaskStore() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="confirm_action", - description="Confirm an action", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="confirm_action", + description="Confirm an action", + input_schema={"type": "object"}, + execution=ToolExecution(task_support=TASK_REQUIRED), + ) + ] + ) - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CreateTaskResult: - ctx = server.request_context + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: @@ -491,6 +486,8 @@ async def work(task: ServerTaskContext) -> CallToolResult: return await ctx.experimental.run_task(work) + server = Server("test-scenario4", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) + server.experimental.enable_tasks() task_handlers = create_client_task_handlers(client_task_store, elicit_received) # Set up streams @@ -553,27 +550,24 @@ async def test_scenario2_sampling_normal_tool_task_augmented_sampling() -> None: Server calls session.experimental.create_message_as_task(), client creates a task for the sampling and returns CreateTaskResult. Server polls client. """ - server = Server("test-scenario2-sampling") sampling_received = Event() tool_result: list[str] = [] # Client-side task store for handling task-augmented sampling client_task_store = InMemoryTaskStore() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="generate_text", - description="Generate text using sampling", - input_schema={"type": "object"}, - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: - ctx = server.request_context + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="generate_text", + description="Generate text using sampling", + input_schema={"type": "object"}, + ) + ] + ) + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: # Task-augmented sampling - server polls client result = await ctx.session.experimental.create_message_as_task( messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], @@ -587,6 +581,7 @@ async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResu tool_result.append(response_text) return CallToolResult(content=[TextContent(type="text", text=response_text)]) + server = Server("test-scenario2-sampling", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) task_handlers = create_sampling_task_handlers(client_task_store, sampling_received) # Set up streams @@ -636,29 +631,25 @@ async def test_scenario4_sampling_task_augmented_tool_task_augmented_sampling() which sends task-augmented sampling. Client creates its own task for the sampling, and server polls the client. """ - server = Server("test-scenario4-sampling") - server.experimental.enable_tasks() - sampling_received = Event() work_completed = Event() # Client-side task store for handling task-augmented sampling client_task_store = InMemoryTaskStore() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="generate_text", - description="Generate text using sampling", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="generate_text", + description="Generate text using sampling", + input_schema={"type": "object"}, + execution=ToolExecution(task_support=TASK_REQUIRED), + ) + ] + ) - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CreateTaskResult: - ctx = server.request_context + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: @@ -677,6 +668,8 @@ async def work(task: ServerTaskContext) -> CallToolResult: return await ctx.experimental.run_task(work) + server = Server("test-scenario4-sampling", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) + server.experimental.enable_tasks() task_handlers = create_sampling_task_handlers(client_task_store, sampling_received) # Set up streams diff --git a/tests/experimental/tasks/test_spec_compliance.py b/tests/experimental/tasks/test_spec_compliance.py index d00ce40a4..dc1e090e1 100644 --- a/tests/experimental/tasks/test_spec_compliance.py +++ b/tests/experimental/tasks/test_spec_compliance.py @@ -10,17 +10,17 @@ import pytest -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.lowlevel import NotificationOptions from mcp.shared.experimental.tasks.helpers import MODEL_IMMEDIATE_RESPONSE_KEY from mcp.types import ( - CancelTaskRequest, + CancelTaskRequestParams, CancelTaskResult, CreateTaskResult, - GetTaskRequest, + GetTaskRequestParams, GetTaskResult, - ListTasksRequest, ListTasksResult, + PaginatedRequestParams, ServerCapabilities, Task, ) @@ -44,13 +44,22 @@ def test_server_without_task_handlers_has_no_tasks_capability() -> None: assert caps.tasks is None +async def _noop_get_task(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: + raise NotImplementedError + + +async def _noop_list_tasks(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListTasksResult: + raise NotImplementedError + + +async def _noop_cancel_task(ctx: ServerRequestContext, params: CancelTaskRequestParams) -> CancelTaskResult: + raise NotImplementedError + + def test_server_with_list_tasks_handler_declares_list_capability() -> None: """Server with list_tasks handler declares tasks.list capability.""" server: Server = Server("test") - - @server.experimental.list_tasks() - async def handle_list(req: ListTasksRequest) -> ListTasksResult: - raise NotImplementedError + server.experimental.enable_tasks(on_list_tasks=_noop_list_tasks) caps = _get_capabilities(server) assert caps.tasks is not None @@ -60,10 +69,7 @@ async def handle_list(req: ListTasksRequest) -> ListTasksResult: def test_server_with_cancel_task_handler_declares_cancel_capability() -> None: """Server with cancel_task handler declares tasks.cancel capability.""" server: Server = Server("test") - - @server.experimental.cancel_task() - async def handle_cancel(req: CancelTaskRequest) -> CancelTaskResult: - raise NotImplementedError + server.experimental.enable_tasks(on_cancel_task=_noop_cancel_task) caps = _get_capabilities(server) assert caps.tasks is not None @@ -75,10 +81,7 @@ def test_server_with_get_task_handler_declares_requests_tools_call_capability() (get_task is required for task-augmented tools/call support) """ server: Server = Server("test") - - @server.experimental.get_task() - async def handle_get(req: GetTaskRequest) -> GetTaskResult: - raise NotImplementedError + server.experimental.enable_tasks(on_get_task=_noop_get_task) caps = _get_capabilities(server) assert caps.tasks is not None @@ -89,11 +92,7 @@ async def handle_get(req: GetTaskRequest) -> GetTaskResult: def test_server_without_list_handler_has_no_list_capability() -> None: """Server without list_tasks handler has no tasks.list capability.""" server: Server = Server("test") - - # Register only get_task (not list_tasks) - @server.experimental.get_task() - async def handle_get(req: GetTaskRequest) -> GetTaskResult: - raise NotImplementedError + server.experimental.enable_tasks(on_get_task=_noop_get_task) caps = _get_capabilities(server) assert caps.tasks is not None @@ -103,11 +102,7 @@ async def handle_get(req: GetTaskRequest) -> GetTaskResult: def test_server_without_cancel_handler_has_no_cancel_capability() -> None: """Server without cancel_task handler has no tasks.cancel capability.""" server: Server = Server("test") - - # Register only get_task (not cancel_task) - @server.experimental.get_task() - async def handle_get(req: GetTaskRequest) -> GetTaskResult: - raise NotImplementedError + server.experimental.enable_tasks(on_get_task=_noop_get_task) caps = _get_capabilities(server) assert caps.tasks is not None @@ -117,18 +112,11 @@ async def handle_get(req: GetTaskRequest) -> GetTaskResult: def test_server_with_all_task_handlers_has_full_capability() -> None: """Server with all task handlers declares complete tasks capability.""" server: Server = Server("test") - - @server.experimental.list_tasks() - async def handle_list(req: ListTasksRequest) -> ListTasksResult: - raise NotImplementedError - - @server.experimental.cancel_task() - async def handle_cancel(req: CancelTaskRequest) -> CancelTaskResult: - raise NotImplementedError - - @server.experimental.get_task() - async def handle_get(req: GetTaskRequest) -> GetTaskResult: - raise NotImplementedError + server.experimental.enable_tasks( + on_list_tasks=_noop_list_tasks, + on_cancel_task=_noop_cancel_task, + on_get_task=_noop_get_task, + ) caps = _get_capabilities(server) assert caps.tasks is not None diff --git a/tests/issues/test_152_resource_mime_type.py b/tests/issues/test_152_resource_mime_type.py index e738017f8..851e89979 100644 --- a/tests/issues/test_152_resource_mime_type.py +++ b/tests/issues/test_152_resource_mime_type.py @@ -3,9 +3,16 @@ import pytest from mcp import Client, types -from mcp.server.lowlevel import Server -from mcp.server.lowlevel.helper_types import ReadResourceContents +from mcp.server import Server, ServerRequestContext from mcp.server.mcpserver import MCPServer +from mcp.types import ( + BlobResourceContents, + ListResourcesResult, + PaginatedRequestParams, + ReadResourceRequestParams, + ReadResourceResult, + TextResourceContents, +) pytestmark = pytest.mark.anyio @@ -58,7 +65,6 @@ def get_image_as_bytes() -> bytes: async def test_lowlevel_resource_mime_type(): """Test that mime_type parameter is respected for resources.""" - server = Server("test") # Create a small test image as bytes image_bytes = b"fake_image_data" @@ -74,17 +80,24 @@ async def test_lowlevel_resource_mime_type(): ), ] - @server.list_resources() - async def handle_list_resources(): - return test_resources - - @server.read_resource() - async def handle_read_resource(uri: str): - if str(uri) == "test://image": - return [ReadResourceContents(content=base64_string, mime_type="image/png")] - elif str(uri) == "test://image_bytes": - return [ReadResourceContents(content=bytes(image_bytes), mime_type="image/png")] - raise Exception(f"Resource not found: {uri}") # pragma: no cover + async def handle_list_resources( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult(resources=test_resources) + + resource_contents: dict[str, list[TextResourceContents | BlobResourceContents]] = { + "test://image": [TextResourceContents(uri="test://image", text=base64_string, mime_type="image/png")], + "test://image_bytes": [ + BlobResourceContents( + uri="test://image_bytes", blob=base64.b64encode(image_bytes).decode("utf-8"), mime_type="image/png" + ) + ], + } + + async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult(contents=resource_contents[str(params.uri)]) + + server = Server("test", on_list_resources=handle_list_resources, on_read_resource=handle_read_resource) # Test that resources are listed with correct mime type async with Client(server) as client: diff --git a/tests/issues/test_1574_resource_uri_validation.py b/tests/issues/test_1574_resource_uri_validation.py index e6ff56877..c67708128 100644 --- a/tests/issues/test_1574_resource_uri_validation.py +++ b/tests/issues/test_1574_resource_uri_validation.py @@ -13,8 +13,14 @@ import pytest from mcp import Client, types -from mcp.server.lowlevel import Server -from mcp.server.lowlevel.helper_types import ReadResourceContents +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + ListResourcesResult, + PaginatedRequestParams, + ReadResourceRequestParams, + ReadResourceResult, + TextResourceContents, +) pytestmark = pytest.mark.anyio @@ -26,24 +32,24 @@ async def test_relative_uri_roundtrip(): the server would fail to serialize resources with relative URIs, or the URI would be transformed during the roundtrip. """ - server = Server("test") - - @server.list_resources() - async def list_resources(): - return [ - types.Resource(name="user", uri="users/me"), - types.Resource(name="config", uri="./config"), - types.Resource(name="parent", uri="../parent/resource"), - ] - - @server.read_resource() - async def read_resource(uri: str): - return [ - ReadResourceContents( - content=f"data for {uri}", - mime_type="text/plain", - ) - ] + + async def handle_list_resources( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult( + resources=[ + types.Resource(name="user", uri="users/me"), + types.Resource(name="config", uri="./config"), + types.Resource(name="parent", uri="../parent/resource"), + ] + ) + + async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult( + contents=[TextResourceContents(uri=str(params.uri), text=f"data for {params.uri}", mime_type="text/plain")] + ) + + server = Server("test", on_list_resources=handle_list_resources, on_read_resource=handle_read_resource) async with Client(server) as client: # List should return the exact URIs we specified @@ -67,18 +73,23 @@ async def test_custom_scheme_uri_roundtrip(): Some MCP servers use custom schemes like "custom://resource". These should work end-to-end. """ - server = Server("test") - - @server.list_resources() - async def list_resources(): - return [ - types.Resource(name="custom", uri="custom://my-resource"), - types.Resource(name="file", uri="file:///path/to/file"), - ] - - @server.read_resource() - async def read_resource(uri: str): - return [ReadResourceContents(content="data", mime_type="text/plain")] + + async def handle_list_resources( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult( + resources=[ + types.Resource(name="custom", uri="custom://my-resource"), + types.Resource(name="file", uri="file:///path/to/file"), + ] + ) + + async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult( + contents=[TextResourceContents(uri=str(params.uri), text="data", mime_type="text/plain")] + ) + + server = Server("test", on_list_resources=handle_list_resources, on_read_resource=handle_read_resource) async with Client(server) as client: resources = await client.list_resources() diff --git a/tests/issues/test_342_base64_encoding.py b/tests/issues/test_342_base64_encoding.py index 44b17d337..2bccedf8d 100644 --- a/tests/issues/test_342_base64_encoding.py +++ b/tests/issues/test_342_base64_encoding.py @@ -1,83 +1,52 @@ """Test for base64 encoding issue in MCP server. -This test demonstrates the issue in server.py where the server uses -urlsafe_b64encode but the BlobResourceContents validator expects standard -base64 encoding. - -The test should FAIL before fixing server.py to use b64encode instead of -urlsafe_b64encode. -After the fix, the test should PASS. +This test verifies that binary resource data is encoded with standard base64 +(not urlsafe_b64encode), so BlobResourceContents validation succeeds. """ import base64 -from typing import cast import pytest -from mcp.server.lowlevel.helper_types import ReadResourceContents -from mcp.server.lowlevel.server import Server -from mcp.types import ( - BlobResourceContents, - ReadResourceRequest, - ReadResourceRequestParams, - ReadResourceResult, - ServerResult, -) +from mcp import Client +from mcp.server.mcpserver import MCPServer +from mcp.types import BlobResourceContents +pytestmark = pytest.mark.anyio -@pytest.mark.anyio -async def test_server_base64_encoding_issue(): - """Tests that server response can be validated by BlobResourceContents. - This test will: - 1. Set up a server that returns binary data - 2. Extract the base64-encoded blob from the server's response - 3. Verify the encoded data can be properly validated by BlobResourceContents +async def test_server_base64_encoding(): + """Tests that binary resource data round-trips correctly through base64 encoding. - BEFORE FIX: The test will fail because server uses urlsafe_b64encode - AFTER FIX: The test will pass because server uses standard b64encode + The test uses binary data that produces different results with urlsafe vs standard + base64, ensuring the server uses standard encoding. """ - server = Server("test") + mcp = MCPServer("test") # Create binary data that will definitely result in + and / characters # when encoded with standard base64 binary_data = bytes(list(range(255)) * 4) - # Register a resource handler that returns our test data - @server.read_resource() - async def read_resource(uri: str) -> list[ReadResourceContents]: - return [ReadResourceContents(content=binary_data, mime_type="application/octet-stream")] - - # Get the handler directly from the server - handler = server.request_handlers[ReadResourceRequest] - - # Create a request - request = ReadResourceRequest( - params=ReadResourceRequestParams(uri="test://resource"), - ) - - # Call the handler to get the response - result: ServerResult = await handler(request) - - # After (fixed code): - read_result: ReadResourceResult = cast(ReadResourceResult, result) - blob_content = read_result.contents[0] - - # First verify our test data actually produces different encodings + # Sanity check: our test data produces different encodings urlsafe_b64 = base64.urlsafe_b64encode(binary_data).decode() standard_b64 = base64.b64encode(binary_data).decode() - assert urlsafe_b64 != standard_b64, "Test data doesn't demonstrate" - " encoding difference" + assert urlsafe_b64 != standard_b64, "Test data doesn't demonstrate encoding difference" + + @mcp.resource("test://binary", mime_type="application/octet-stream") + def get_binary() -> bytes: + """Return binary test data.""" + return binary_data + + async with Client(mcp) as client: + result = await client.read_resource("test://binary") + assert len(result.contents) == 1 - # Now validate the server's output with BlobResourceContents.model_validate - # Before the fix: This should fail with "Invalid base64" because server - # uses urlsafe_b64encode - # After the fix: This should pass because server will use standard b64encode - model_dict = blob_content.model_dump() + blob_content = result.contents[0] + assert isinstance(blob_content, BlobResourceContents) - # Direct validation - this will fail before fix, pass after fix - blob_model = BlobResourceContents.model_validate(model_dict) + # Verify standard base64 was used (not urlsafe) + assert blob_content.blob == standard_b64 - # Verify we can decode the data back correctly - decoded = base64.b64decode(blob_model.blob) - assert decoded == binary_data + # Verify we can decode the data back correctly + decoded = base64.b64decode(blob_content.blob) + assert decoded == binary_data diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index cd27698e6..9a039fefc 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -1,8 +1,6 @@ """Test to reproduce issue #88: Random error thrown on response.""" -from collections.abc import Sequence from pathlib import Path -from typing import Any import anyio import pytest @@ -11,10 +9,10 @@ from mcp import types from mcp.client.session import ClientSession -from mcp.server.lowlevel import Server +from mcp.server import Server, ServerRequestContext from mcp.shared.exceptions import MCPError from mcp.shared.message import SessionMessage -from mcp.types import ContentBlock, TextContent +from mcp.types import CallToolRequestParams, CallToolResult, ListToolsResult, PaginatedRequestParams, TextContent @pytest.mark.anyio @@ -32,36 +30,37 @@ async def test_notification_validation_error(tmp_path: Path): - Slow operations use minimal timeout (10ms) for quick test execution """ - server = Server(name="test") request_count = 0 slow_request_lock = anyio.Event() - @server.list_tools() - async def list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="slow", - description="A slow tool", - input_schema={"type": "object"}, - ), - types.Tool( - name="fast", - description="A fast tool", - input_schema={"type": "object"}, - ), - ] - - @server.call_tool() - async def slow_tool(name: str, arguments: dict[str, Any]) -> Sequence[ContentBlock]: + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + types.Tool( + name="slow", + description="A slow tool", + input_schema={"type": "object"}, + ), + types.Tool( + name="fast", + description="A fast tool", + input_schema={"type": "object"}, + ), + ] + ) + + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: nonlocal request_count request_count += 1 - if name == "slow": + if params.name == "slow": await slow_request_lock.wait() # it should timeout here - return [TextContent(type="text", text=f"slow {request_count}")] - elif name == "fast": - return [TextContent(type="text", text=f"fast {request_count}")] - return [TextContent(type="text", text=f"unknown {request_count}")] # pragma: no cover + return CallToolResult(content=[TextContent(type="text", text=f"slow {request_count}")]) + elif params.name == "fast": + return CallToolResult(content=[TextContent(type="text", text=f"fast {request_count}")]) + pytest.fail(f"Unknown tool: {params.name}") + + server = Server(name="test", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) async def server_handler( read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], diff --git a/tests/server/lowlevel/test_server_listing.py b/tests/server/lowlevel/test_server_listing.py index 6bf4cddb3..2c3d303a9 100644 --- a/tests/server/lowlevel/test_server_listing.py +++ b/tests/server/lowlevel/test_server_listing.py @@ -1,20 +1,16 @@ -"""Basic tests for list_prompts, list_resources, and list_tools decorators without pagination.""" - -import warnings +"""Basic tests for list_prompts, list_resources, and list_tools handlers without pagination.""" import pytest -from mcp.server import Server +from mcp import Client +from mcp.server import Server, ServerRequestContext from mcp.types import ( - ListPromptsRequest, ListPromptsResult, - ListResourcesRequest, ListResourcesResult, - ListToolsRequest, ListToolsResult, + PaginatedRequestParams, Prompt, Resource, - ServerResult, Tool, ) @@ -22,60 +18,44 @@ @pytest.mark.anyio async def test_list_prompts_basic() -> None: """Test basic prompt listing without pagination.""" - server = Server("test") - test_prompts = [ Prompt(name="prompt1", description="First prompt"), Prompt(name="prompt2", description="Second prompt"), ] - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - - @server.list_prompts() - async def handle_list_prompts() -> list[Prompt]: - return test_prompts - - handler = server.request_handlers[ListPromptsRequest] - request = ListPromptsRequest(method="prompts/list", params=None) - result = await handler(request) + async def handle_list_prompts( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListPromptsResult: + return ListPromptsResult(prompts=test_prompts) - assert isinstance(result, ServerResult) - assert isinstance(result, ListPromptsResult) - assert result.prompts == test_prompts + server = Server("test", on_list_prompts=handle_list_prompts) + async with Client(server) as client: + result = await client.list_prompts() + assert result.prompts == test_prompts @pytest.mark.anyio async def test_list_resources_basic() -> None: """Test basic resource listing without pagination.""" - server = Server("test") - test_resources = [ Resource(uri="file:///test1.txt", name="Test 1"), Resource(uri="file:///test2.txt", name="Test 2"), ] - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - - @server.list_resources() - async def handle_list_resources() -> list[Resource]: - return test_resources + async def handle_list_resources( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult(resources=test_resources) - handler = server.request_handlers[ListResourcesRequest] - request = ListResourcesRequest(method="resources/list", params=None) - result = await handler(request) - - assert isinstance(result, ServerResult) - assert isinstance(result, ListResourcesResult) - assert result.resources == test_resources + server = Server("test", on_list_resources=handle_list_resources) + async with Client(server) as client: + result = await client.list_resources() + assert result.resources == test_resources @pytest.mark.anyio async def test_list_tools_basic() -> None: """Test basic tool listing without pagination.""" - server = Server("test") - test_tools = [ Tool( name="tool1", @@ -102,80 +82,53 @@ async def test_list_tools_basic() -> None: ), ] - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=test_tools) - @server.list_tools() - async def handle_list_tools() -> list[Tool]: - return test_tools - - handler = server.request_handlers[ListToolsRequest] - request = ListToolsRequest(method="tools/list", params=None) - result = await handler(request) - - assert isinstance(result, ServerResult) - assert isinstance(result, ListToolsResult) - assert result.tools == test_tools + server = Server("test", on_list_tools=handle_list_tools) + async with Client(server) as client: + result = await client.list_tools() + assert result.tools == test_tools @pytest.mark.anyio async def test_list_prompts_empty() -> None: """Test listing with empty results.""" - server = Server("test") - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - - @server.list_prompts() - async def handle_list_prompts() -> list[Prompt]: - return [] - handler = server.request_handlers[ListPromptsRequest] - request = ListPromptsRequest(method="prompts/list", params=None) - result = await handler(request) + async def handle_list_prompts( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListPromptsResult: + return ListPromptsResult(prompts=[]) - assert isinstance(result, ServerResult) - assert isinstance(result, ListPromptsResult) - assert result.prompts == [] + server = Server("test", on_list_prompts=handle_list_prompts) + async with Client(server) as client: + result = await client.list_prompts() + assert result.prompts == [] @pytest.mark.anyio async def test_list_resources_empty() -> None: """Test listing with empty results.""" - server = Server("test") - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) + async def handle_list_resources( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult(resources=[]) - @server.list_resources() - async def handle_list_resources() -> list[Resource]: - return [] - - handler = server.request_handlers[ListResourcesRequest] - request = ListResourcesRequest(method="resources/list", params=None) - result = await handler(request) - - assert isinstance(result, ServerResult) - assert isinstance(result, ListResourcesResult) - assert result.resources == [] + server = Server("test", on_list_resources=handle_list_resources) + async with Client(server) as client: + result = await client.list_resources() + assert result.resources == [] @pytest.mark.anyio async def test_list_tools_empty() -> None: """Test listing with empty results.""" - server = Server("test") - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - - @server.list_tools() - async def handle_list_tools() -> list[Tool]: - return [] - handler = server.request_handlers[ListToolsRequest] - request = ListToolsRequest(method="tools/list", params=None) - result = await handler(request) + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[]) - assert isinstance(result, ServerResult) - assert isinstance(result, ListToolsResult) - assert result.tools == [] + server = Server("test", on_list_tools=handle_list_tools) + async with Client(server) as client: + result = await client.list_tools() + assert result.tools == [] diff --git a/tests/server/lowlevel/test_server_pagination.py b/tests/server/lowlevel/test_server_pagination.py index 081fb262a..a4627b316 100644 --- a/tests/server/lowlevel/test_server_pagination.py +++ b/tests/server/lowlevel/test_server_pagination.py @@ -1,111 +1,83 @@ import pytest -from mcp.server import Server +from mcp import Client +from mcp.server import Server, ServerRequestContext from mcp.types import ( - ListPromptsRequest, ListPromptsResult, - ListResourcesRequest, ListResourcesResult, - ListToolsRequest, ListToolsResult, PaginatedRequestParams, - ServerResult, ) @pytest.mark.anyio async def test_list_prompts_pagination() -> None: - server = Server("test") test_cursor = "test-cursor-123" + received_params: PaginatedRequestParams | None = None - # Track what request was received - received_request: ListPromptsRequest | None = None - - @server.list_prompts() - async def handle_list_prompts(request: ListPromptsRequest) -> ListPromptsResult: - nonlocal received_request - received_request = request + async def handle_list_prompts( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListPromptsResult: + nonlocal received_params + received_params = params return ListPromptsResult(prompts=[], next_cursor="next") - handler = server.request_handlers[ListPromptsRequest] - - # Test: No cursor provided -> handler receives request with None params - request = ListPromptsRequest(method="prompts/list", params=None) - result = await handler(request) - assert received_request is not None - assert received_request.params is None - assert isinstance(result, ServerResult) + server = Server("test", on_list_prompts=handle_list_prompts) + async with Client(server) as client: + # No cursor provided + await client.list_prompts() + assert received_params is not None + assert received_params.cursor is None - # Test: Cursor provided -> handler receives request with cursor in params - request_with_cursor = ListPromptsRequest(method="prompts/list", params=PaginatedRequestParams(cursor=test_cursor)) - result2 = await handler(request_with_cursor) - assert received_request is not None - assert received_request.params is not None - assert received_request.params.cursor == test_cursor - assert isinstance(result2, ServerResult) + # Cursor provided + await client.list_prompts(cursor=test_cursor) + assert received_params is not None + assert received_params.cursor == test_cursor @pytest.mark.anyio async def test_list_resources_pagination() -> None: - server = Server("test") test_cursor = "resource-cursor-456" + received_params: PaginatedRequestParams | None = None - # Track what request was received - received_request: ListResourcesRequest | None = None - - @server.list_resources() - async def handle_list_resources(request: ListResourcesRequest) -> ListResourcesResult: - nonlocal received_request - received_request = request + async def handle_list_resources( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListResourcesResult: + nonlocal received_params + received_params = params return ListResourcesResult(resources=[], next_cursor="next") - handler = server.request_handlers[ListResourcesRequest] + server = Server("test", on_list_resources=handle_list_resources) + async with Client(server) as client: + # No cursor provided + await client.list_resources() + assert received_params is not None + assert received_params.cursor is None - # Test: No cursor provided -> handler receives request with None params - request = ListResourcesRequest(method="resources/list", params=None) - result = await handler(request) - assert received_request is not None - assert received_request.params is None - assert isinstance(result, ServerResult) - - # Test: Cursor provided -> handler receives request with cursor in params - request_with_cursor = ListResourcesRequest( - method="resources/list", params=PaginatedRequestParams(cursor=test_cursor) - ) - result2 = await handler(request_with_cursor) - assert received_request is not None - assert received_request.params is not None - assert received_request.params.cursor == test_cursor - assert isinstance(result2, ServerResult) + # Cursor provided + await client.list_resources(cursor=test_cursor) + assert received_params is not None + assert received_params.cursor == test_cursor @pytest.mark.anyio async def test_list_tools_pagination() -> None: - server = Server("test") test_cursor = "tools-cursor-789" + received_params: PaginatedRequestParams | None = None - # Track what request was received - received_request: ListToolsRequest | None = None - - @server.list_tools() - async def handle_list_tools(request: ListToolsRequest) -> ListToolsResult: - nonlocal received_request - received_request = request + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + nonlocal received_params + received_params = params return ListToolsResult(tools=[], next_cursor="next") - handler = server.request_handlers[ListToolsRequest] - - # Test: No cursor provided -> handler receives request with None params - request = ListToolsRequest(method="tools/list", params=None) - result = await handler(request) - assert received_request is not None - assert received_request.params is None - assert isinstance(result, ServerResult) - - # Test: Cursor provided -> handler receives request with cursor in params - request_with_cursor = ListToolsRequest(method="tools/list", params=PaginatedRequestParams(cursor=test_cursor)) - result2 = await handler(request_with_cursor) - assert received_request is not None - assert received_request.params is not None - assert received_request.params.cursor == test_cursor - assert isinstance(result2, ServerResult) + server = Server("test", on_list_tools=handle_list_tools) + async with Client(server) as client: + # No cursor provided + await client.list_tools() + assert received_params is not None + assert received_params.cursor is None + + # Cursor provided + await client.list_tools(cursor=test_cursor) + assert received_params is not None + assert received_params.cursor == test_cursor diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index 6d1634f2e..297f3d6a5 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -1,12 +1,10 @@ """Test that cancelled requests don't cause double responses.""" -from typing import Any - import anyio import pytest -from mcp import Client, types -from mcp.server.lowlevel.server import Server +from mcp import Client +from mcp.server import Server, ServerRequestContext from mcp.shared.exceptions import MCPError from mcp.types import ( CallToolRequest, @@ -14,6 +12,9 @@ CallToolResult, CancelledNotification, CancelledNotificationParams, + ListToolsResult, + PaginatedRequestParams, + TextContent, Tool, ) @@ -22,34 +23,34 @@ async def test_server_remains_functional_after_cancel(): """Verify server can handle new requests after a cancellation.""" - server = Server("test-server") - # Track tool calls call_count = 0 ev_first_call = anyio.Event() first_request_id = None - @server.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="test_tool", - description="Tool for testing", - input_schema={}, - ) - ] + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="test_tool", + description="Tool for testing", + input_schema={}, + ) + ] + ) - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[types.TextContent]: + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: nonlocal call_count, first_request_id - if name == "test_tool": + if params.name == "test_tool": call_count += 1 if call_count == 1: - first_request_id = server.request_context.request_id + first_request_id = ctx.request_id ev_first_call.set() await anyio.sleep(5) # First call is slow - return [types.TextContent(type="text", text=f"Call number: {call_count}")] - raise ValueError(f"Unknown tool: {name}") # pragma: no cover + return CallToolResult(content=[TextContent(type="text", text=f"Call number: {call_count}")]) + raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover + + server = Server("test-server", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) async with Client(server) as client: # First request (will be cancelled) @@ -86,6 +87,6 @@ async def first_request(): # Type narrowing for pyright content = result.content[0] assert content.type == "text" - assert isinstance(content, types.TextContent) + assert isinstance(content, TextContent) assert content.text == "Call number: 2" assert call_count == 2 diff --git a/tests/server/test_completion_with_context.py b/tests/server/test_completion_with_context.py index 5a8d67f09..8d11a228d 100644 --- a/tests/server/test_completion_with_context.py +++ b/tests/server/test_completion_with_context.py @@ -1,15 +1,13 @@ """Tests for completion handler with context functionality.""" -from typing import Any - import pytest from mcp import Client -from mcp.server.lowlevel import Server +from mcp.server import Server, ServerRequestContext from mcp.types import ( + CompleteRequestParams, + CompleteResult, Completion, - CompletionArgument, - CompletionContext, PromptReference, ResourceTemplateReference, ) @@ -18,23 +16,15 @@ @pytest.mark.anyio async def test_completion_handler_receives_context(): """Test that the completion handler receives context correctly.""" - server = Server("test-server") - # Track what the handler receives - received_args: dict[str, Any] = {} + received_params: CompleteRequestParams | None = None - @server.completion() - async def handle_completion( - ref: PromptReference | ResourceTemplateReference, - argument: CompletionArgument, - context: CompletionContext | None, - ) -> Completion | None: - received_args["ref"] = ref - received_args["argument"] = argument - received_args["context"] = context + async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestParams) -> CompleteResult: + nonlocal received_params + received_params = params + return CompleteResult(completion=Completion(values=["test-completion"], total=1, has_more=False)) - # Return test completion - return Completion(values=["test-completion"], total=1, has_more=False) + server = Server("test-server", on_completion=handle_completion) async with Client(server) as client: # Test with context @@ -45,28 +35,23 @@ async def handle_completion( ) # Verify handler received the context - assert received_args["context"] is not None - assert received_args["context"].arguments == {"previous": "value"} + assert received_params is not None + assert received_params.context is not None + assert received_params.context.arguments == {"previous": "value"} assert result.completion.values == ["test-completion"] @pytest.mark.anyio async def test_completion_backward_compatibility(): """Test that completion works without context (backward compatibility).""" - server = Server("test-server") - context_was_none = False - @server.completion() - async def handle_completion( - ref: PromptReference | ResourceTemplateReference, - argument: CompletionArgument, - context: CompletionContext | None, - ) -> Completion | None: + async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestParams) -> CompleteResult: nonlocal context_was_none - context_was_none = context is None + context_was_none = params.context is None + return CompleteResult(completion=Completion(values=["no-context-completion"], total=1, has_more=False)) - return Completion(values=["no-context-completion"], total=1, has_more=False) + server = Server("test-server", on_completion=handle_completion) async with Client(server) as client: # Test without context @@ -82,30 +67,36 @@ async def handle_completion( @pytest.mark.anyio async def test_dependent_completion_scenario(): """Test a real-world scenario with dependent completions.""" - server = Server("test-server") - - @server.completion() - async def handle_completion( - ref: PromptReference | ResourceTemplateReference, - argument: CompletionArgument, - context: CompletionContext | None, - ) -> Completion | None: + + async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestParams) -> CompleteResult: # Simulate database/table completion scenario - if isinstance(ref, ResourceTemplateReference): - if ref.uri == "db://{database}/{table}": - if argument.name == "database": - # Complete database names - return Completion(values=["users_db", "products_db", "analytics_db"], total=3, has_more=False) - elif argument.name == "table": - # Complete table names based on selected database - if context and context.arguments: - db = context.arguments.get("database") + if isinstance(params.ref, ResourceTemplateReference): + if params.ref.uri == "db://{database}/{table}": + if params.argument.name == "database": + return CompleteResult( + completion=Completion( + values=["users_db", "products_db", "analytics_db"], total=3, has_more=False + ) + ) + elif params.argument.name == "table": + if params.context and params.context.arguments: + db = params.context.arguments.get("database") if db == "users_db": - return Completion(values=["users", "sessions", "permissions"], total=3, has_more=False) + return CompleteResult( + completion=Completion( + values=["users", "sessions", "permissions"], total=3, has_more=False + ) + ) elif db == "products_db": - return Completion(values=["products", "categories", "inventory"], total=3, has_more=False) + return CompleteResult( + completion=Completion( + values=["products", "categories", "inventory"], total=3, has_more=False + ) + ) - return Completion(values=[], total=0, has_more=False) # pragma: no cover + pytest.fail("Unexpected completion request") + + server = Server("test-server", on_completion=handle_completion) async with Client(server) as client: # First, complete database @@ -136,27 +127,22 @@ async def handle_completion( @pytest.mark.anyio async def test_completion_error_on_missing_context(): """Test that server can raise error when required context is missing.""" - server = Server("test-server") - - @server.completion() - async def handle_completion( - ref: PromptReference | ResourceTemplateReference, - argument: CompletionArgument, - context: CompletionContext | None, - ) -> Completion | None: - if isinstance(ref, ResourceTemplateReference): - if ref.uri == "db://{database}/{table}": - if argument.name == "table": - # Check if database context is provided - if not context or not context.arguments or "database" not in context.arguments: - # Raise an error instead of returning error as completion + + async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestParams) -> CompleteResult: + if isinstance(params.ref, ResourceTemplateReference): + if params.ref.uri == "db://{database}/{table}": + if params.argument.name == "table": + if not params.context or not params.context.arguments or "database" not in params.context.arguments: raise ValueError("Please select a database first to see available tables") - # Normal completion if context is provided - db = context.arguments.get("database") + db = params.context.arguments.get("database") if db == "test_db": - return Completion(values=["users", "orders", "products"], total=3, has_more=False) + return CompleteResult( + completion=Completion(values=["users", "orders", "products"], total=3, has_more=False) + ) + + pytest.fail("Unexpected completion request") - return Completion(values=[], total=0, has_more=False) # pragma: no cover + server = Server("test-server", on_completion=handle_completion) async with Client(server) as client: # Try to complete table without database context - should raise error diff --git a/tests/server/test_lifespan.py b/tests/server/test_lifespan.py index a303664a5..0f8840d29 100644 --- a/tests/server/test_lifespan.py +++ b/tests/server/test_lifespan.py @@ -2,18 +2,20 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from typing import Any import anyio import pytest from pydantic import TypeAdapter +from mcp.server import ServerRequestContext from mcp.server.lowlevel.server import NotificationOptions, Server from mcp.server.mcpserver import Context, MCPServer from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.shared.message import SessionMessage from mcp.types import ( + CallToolRequestParams, + CallToolResult, ClientCapabilities, Implementation, InitializeRequestParams, @@ -39,20 +41,20 @@ async def test_lifespan(server: Server) -> AsyncIterator[dict[str, bool]]: finally: context["shutdown"] = True - server = Server[dict[str, bool]]("test", lifespan=test_lifespan) - - # Create memory streams for testing - send_stream1, receive_stream1 = anyio.create_memory_object_stream[SessionMessage](100) - send_stream2, receive_stream2 = anyio.create_memory_object_stream[SessionMessage](100) - # Create a tool that accesses lifespan context - @server.call_tool() - async def check_lifespan(name: str, arguments: dict[str, Any]) -> list[TextContent]: - ctx = server.request_context + async def check_lifespan( + ctx: ServerRequestContext[dict[str, bool]], params: CallToolRequestParams + ) -> CallToolResult: assert isinstance(ctx.lifespan_context, dict) assert ctx.lifespan_context["started"] assert not ctx.lifespan_context["shutdown"] - return [TextContent(type="text", text="true")] + return CallToolResult(content=[TextContent(type="text", text="true")]) + + server = Server[dict[str, bool]]("test", lifespan=test_lifespan, on_call_tool=check_lifespan) + + # Create memory streams for testing + send_stream1, receive_stream1 = anyio.create_memory_object_stream[SessionMessage](100) + send_stream2, receive_stream2 = anyio.create_memory_object_stream[SessionMessage](100) # Run server in background task async with anyio.create_task_group() as tg, send_stream1, receive_stream1, send_stream2, receive_stream2: diff --git a/tests/server/test_lowlevel_tool_annotations.py b/tests/server/test_lowlevel_tool_annotations.py index 68543136e..705abdfe8 100644 --- a/tests/server/test_lowlevel_tool_annotations.py +++ b/tests/server/test_lowlevel_tool_annotations.py @@ -1,100 +1,44 @@ """Tests for tool annotations in low-level server.""" -import anyio import pytest -from mcp.client.session import ClientSession -from mcp.server import Server -from mcp.server.lowlevel import NotificationOptions -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder -from mcp.types import ClientResult, ServerNotification, ServerRequest, Tool, ToolAnnotations +from mcp import Client +from mcp.server import Server, ServerRequestContext +from mcp.types import ListToolsResult, PaginatedRequestParams, Tool, ToolAnnotations @pytest.mark.anyio async def test_lowlevel_server_tool_annotations(): """Test that tool annotations work in low-level server.""" - server = Server("test") - # Create a tool with annotations - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="echo", - description="Echo a message back", - input_schema={ - "type": "object", - "properties": { - "message": {"type": "string"}, + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="echo", + description="Echo a message back", + input_schema={ + "type": "object", + "properties": { + "message": {"type": "string"}, + }, + "required": ["message"], }, - "required": ["message"], - }, - annotations=ToolAnnotations( - title="Echo Tool", - read_only_hint=True, - ), - ) - ] - - tools_result = None - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - # Message handler for client - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: - if isinstance(message, Exception): # pragma: no cover - raise message - - # Server task - async def run_server(): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), - ) as server_session: - async with anyio.create_task_group() as tg: - - async def handle_messages(): - async for message in server_session.incoming_messages: # pragma: no branch - await server._handle_message(message, server_session, {}, False) - - tg.start_soon(handle_messages) - await anyio.sleep_forever() - - # Run the test - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - # Initialize the session - await client_session.initialize() - - # List tools - tools_result = await client_session.list_tools() - - # Cancel the server task - tg.cancel_scope.cancel() - - # Verify results - assert tools_result is not None - assert len(tools_result.tools) == 1 - assert tools_result.tools[0].name == "echo" - assert tools_result.tools[0].annotations is not None - assert tools_result.tools[0].annotations.title == "Echo Tool" - assert tools_result.tools[0].annotations.read_only_hint is True + annotations=ToolAnnotations( + title="Echo Tool", + read_only_hint=True, + ), + ) + ] + ) + + server = Server("test", on_list_tools=handle_list_tools) + + async with Client(server) as client: + tools_result = await client.list_tools() + + assert len(tools_result.tools) == 1 + assert tools_result.tools[0].name == "echo" + assert tools_result.tools[0].annotations is not None + assert tools_result.tools[0].annotations.title == "Echo Tool" + assert tools_result.tools[0].annotations.read_only_hint is True diff --git a/tests/server/test_session.py b/tests/server/test_session.py index d353e46e4..a2786d865 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -5,7 +5,7 @@ from mcp import types from mcp.client.session import ClientSession -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession @@ -14,17 +14,10 @@ from mcp.shared.session import RequestResponder from mcp.types import ( ClientNotification, - Completion, - CompletionArgument, - CompletionContext, CompletionsCapability, InitializedNotification, - Prompt, - PromptReference, PromptsCapability, - Resource, ResourcesCapability, - ResourceTemplateReference, ServerCapabilities, ) @@ -85,47 +78,50 @@ async def run_server(): @pytest.mark.anyio async def test_server_capabilities(): - server = Server("test") notification_options = NotificationOptions() experimental_capabilities: dict[str, Any] = {} - # Initially no capabilities + async def noop_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListPromptsResult: + raise NotImplementedError + + async def noop_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListResourcesResult: + raise NotImplementedError + + async def noop_completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> types.CompleteResult: + raise NotImplementedError + + # No capabilities + server = Server("test") caps = server.get_capabilities(notification_options, experimental_capabilities) assert caps.prompts is None assert caps.resources is None assert caps.completions is None - # Add a prompts handler - @server.list_prompts() - async def list_prompts() -> list[Prompt]: # pragma: no cover - return [] - + # With prompts handler + server = Server("test", on_list_prompts=noop_list_prompts) caps = server.get_capabilities(notification_options, experimental_capabilities) assert caps.prompts == PromptsCapability(list_changed=False) assert caps.resources is None assert caps.completions is None - # Add a resources handler - @server.list_resources() - async def list_resources() -> list[Resource]: # pragma: no cover - return [] - + # With prompts + resources handlers + server = Server("test", on_list_prompts=noop_list_prompts, on_list_resources=noop_list_resources) caps = server.get_capabilities(notification_options, experimental_capabilities) assert caps.prompts == PromptsCapability(list_changed=False) assert caps.resources == ResourcesCapability(subscribe=False, list_changed=False) assert caps.completions is None - # Add a complete handler - @server.completion() - async def complete( # pragma: no cover - ref: PromptReference | ResourceTemplateReference, - argument: CompletionArgument, - context: CompletionContext | None, - ) -> Completion | None: - return Completion( - values=["completion1", "completion2"], - ) - + # With prompts + resources + completion handlers + server = Server( + "test", + on_list_prompts=noop_list_prompts, + on_list_resources=noop_list_resources, + on_completion=noop_completion, + ) caps = server.get_capabilities(notification_options, experimental_capabilities) assert caps.prompts == PromptsCapability(list_changed=False) assert caps.resources == ResourcesCapability(subscribe=False, list_changed=False) diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 51572baa9..e8870d9a9 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -9,13 +9,12 @@ import pytest from starlette.types import Message -from mcp import Client, types +from mcp import Client from mcp.client.streamable_http import streamable_http_client -from mcp.server import streamable_http_manager -from mcp.server.lowlevel import Server +from mcp.server import Server, ServerRequestContext, streamable_http_manager from mcp.server.streamable_http import MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport from mcp.server.streamable_http_manager import StreamableHTTPSessionManager -from mcp.types import INVALID_REQUEST +from mcp.types import INVALID_REQUEST, ListToolsResult, PaginatedRequestParams @pytest.mark.anyio @@ -321,12 +320,11 @@ async def mock_receive(): @pytest.mark.anyio async def test_e2e_streamable_http_server_cleanup(): host = "testserver" - app = Server("test-server") - @app.list_tools() - async def list_tools(req: types.ListToolsRequest) -> types.ListToolsResult: - return types.ListToolsResult(tools=[]) + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[]) + app = Server("test-server", on_list_tools=handle_list_tools) mcp_app = app.streamable_http_app(host=host) async with ( mcp_app.router.lifespan_context(mcp_app), diff --git a/tests/shared/test_memory.py b/tests/shared/test_memory.py deleted file mode 100644 index 31238b9ff..000000000 --- a/tests/shared/test_memory.py +++ /dev/null @@ -1,30 +0,0 @@ -import pytest - -from mcp import Client -from mcp.server import Server -from mcp.types import EmptyResult, Resource - - -@pytest.fixture -def mcp_server() -> Server: - server = Server(name="test_server") - - @server.list_resources() - async def handle_list_resources(): # pragma: no cover - return [ - Resource( - uri="memory://test", - name="Test Resource", - description="A test resource", - ) - ] - - return server - - -@pytest.mark.anyio -async def test_memory_server_and_client_connection(mcp_server: Server): - """Shows how a client and server can communicate over memory streams.""" - async with Client(mcp_server) as client: - response = await client.send_ping() - assert isinstance(response, EmptyResult) diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index ab117f1f0..6b87774c0 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -6,7 +6,7 @@ from mcp import Client, types from mcp.client.session import ClientSession -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession @@ -35,9 +35,6 @@ async def run_server(): capabilities=server.get_capabilities(NotificationOptions(), {}), ), ) as server_session: - global serv_sesh - - serv_sesh = server_session async for message in server_session.incoming_messages: try: await server._handle_message(message, server_session, {}) @@ -52,79 +49,73 @@ async def run_server(): server_progress_token = "server_token_123" client_progress_token = "client_token_456" - # Create a server with progress capability - server = Server(name="ProgressTestServer") - # Register progress handler - @server.progress_notification() - async def handle_progress( - progress_token: str | int, - progress: float, - total: float | None, - message: str | None, - ): + async def handle_progress(ctx: ServerRequestContext, params: types.ProgressNotificationParams) -> None: server_progress_updates.append( { - "token": progress_token, - "progress": progress, - "total": total, - "message": message, + "token": params.progress_token, + "progress": params.progress, + "total": params.total, + "message": params.message, } ) # Register list tool handler - @server.list_tools() - async def handle_list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="test_tool", - description="A tool that sends progress notifications types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="test_tool", + description="A tool that sends progress notifications list[types.TextContent]: + async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: # Make sure we received a progress token - if name == "test_tool": - if arguments and "_meta" in arguments: - progressToken = arguments["_meta"]["progressToken"] - - if not progressToken: # pragma: no cover - raise ValueError("Empty progress token received") - - if progressToken != client_progress_token: # pragma: no cover - raise ValueError("Server sending back incorrect progressToken") - - # Send progress notifications - await serv_sesh.send_progress_notification( - progress_token=progressToken, - progress=0.25, - total=1.0, - message="Server progress 25%", - ) + if params.name == "test_tool": + assert params.meta is not None + progress_token = params.meta.get("progress_token") + assert progress_token is not None + assert progress_token == client_progress_token + + # Send progress notifications using ctx.session + await ctx.session.send_progress_notification( + progress_token=progress_token, + progress=0.25, + total=1.0, + message="Server progress 25%", + ) - await serv_sesh.send_progress_notification( - progress_token=progressToken, - progress=0.5, - total=1.0, - message="Server progress 50%", - ) + await ctx.session.send_progress_notification( + progress_token=progress_token, + progress=0.5, + total=1.0, + message="Server progress 50%", + ) - await serv_sesh.send_progress_notification( - progress_token=progressToken, - progress=1.0, - total=1.0, - message="Server progress 100%", - ) + await ctx.session.send_progress_notification( + progress_token=progress_token, + progress=1.0, + total=1.0, + message="Server progress 100%", + ) - else: # pragma: no cover - raise ValueError("Progress token not sent.") + return types.CallToolResult(content=[types.TextContent(type="text", text="Tool executed successfully")]) - return [types.TextContent(type="text", text="Tool executed successfully")] + raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover - raise ValueError(f"Unknown tool: {name}") # pragma: no cover + # Create a server with progress capability + server = Server( + name="ProgressTestServer", + on_progress=handle_progress, + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) # Client message handler to store progress notifications async def handle_client_message( @@ -164,7 +155,7 @@ async def handle_client_message( await client_session.list_tools() # Call test_tool with progress token - await client_session.call_tool("test_tool", {"_meta": {"progressToken": client_progress_token}}) + await client_session.call_tool("test_tool", meta={"progress_token": client_progress_token}) # Send progress notifications from client to server await client_session.send_progress_notification( @@ -217,22 +208,21 @@ async def test_progress_context_manager(): # Track progress updates server_progress_updates: list[dict[str, Any]] = [] - server = Server(name="ProgressContextTestServer") - progress_token = None # Register progress handler - @server.progress_notification() - async def handle_progress( - progress_token: str | int, - progress: float, - total: float | None, - message: str | None, - ): + async def handle_progress(ctx: ServerRequestContext, params: types.ProgressNotificationParams) -> None: server_progress_updates.append( - {"token": progress_token, "progress": progress, "total": total, "message": message} + { + "token": params.progress_token, + "progress": params.progress, + "total": params.total, + "message": params.message, + } ) + server = Server(name="ProgressContextTestServer", on_progress=handle_progress) + # Run server session to receive progress updates async def run_server(): # Create a server session @@ -334,30 +324,37 @@ async def failing_progress_callback(progress: float, total: float | None, messag raise ValueError("Progress callback failed!") # Create a server with a tool that sends progress notifications - server = Server(name="TestProgressServer") - - @server.call_tool() - async def handle_call_tool(name: str, arguments: Any) -> list[types.TextContent]: - if name == "progress_tool": + async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + if params.name == "progress_tool": + assert ctx.request_id is not None # Send a progress notification - await server.request_context.session.send_progress_notification( - progress_token=server.request_context.request_id, + await ctx.session.send_progress_notification( + progress_token=ctx.request_id, progress=50.0, total=100.0, message="Halfway done", ) - return [types.TextContent(type="text", text="progress_result")] - raise ValueError(f"Unknown tool: {name}") # pragma: no cover - - @server.list_tools() - async def handle_list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="progress_tool", - description="A tool that sends progress notifications", - input_schema={}, - ) - ] + return types.CallToolResult(content=[types.TextContent(type="text", text="progress_result")]) + raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover + + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="progress_tool", + description="A tool that sends progress notifications", + input_schema={}, + ) + ] + ) + + server = Server( + name="TestProgressServer", + on_call_tool=handle_call_tool, + on_list_tools=handle_list_tools, + ) # Test with mocked logging with patch("mcp.shared.session.logging.exception", side_effect=mock_log_exception): diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 182b4671d..42c969148 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -1,11 +1,9 @@ -from typing import Any - import anyio import pytest from mcp import Client, types from mcp.client.session import ClientSession -from mcp.server.lowlevel.server import Server +from mcp.server import Server, ServerRequestContext from mcp.shared.exceptions import MCPError from mcp.shared.memory import create_client_server_memory_streams from mcp.shared.message import SessionMessage @@ -17,7 +15,6 @@ JSONRPCError, JSONRPCRequest, JSONRPCResponse, - TextContent, ) @@ -42,29 +39,33 @@ async def test_request_cancellation(): request_id = None # Create a server with a slow tool - server = Server(name="TestSessionServer") - - # Register the tool handler - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[TextContent]: + async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: nonlocal request_id, ev_tool_called - if name == "slow_tool": - request_id = server.request_context.request_id + if params.name == "slow_tool": + request_id = ctx.request_id ev_tool_called.set() await anyio.sleep(10) # Long enough to ensure we can cancel - return [] # pragma: no cover - raise ValueError(f"Unknown tool: {name}") # pragma: no cover - - # Register the tool so it shows up in list_tools - @server.list_tools() - async def handle_list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="slow_tool", - description="A slow tool that takes 10 seconds to complete", - input_schema={}, - ) - ] + return types.CallToolResult(content=[]) # pragma: no cover + raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover + + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="slow_tool", + description="A slow tool that takes 10 seconds to complete", + input_schema={}, + ) + ] + ) + + server = Server( + name="TestSessionServer", + on_call_tool=handle_call_tool, + on_list_tools=handle_list_tools, + ) async def make_request(client: Client): nonlocal ev_cancelled