diff --git a/src/mcp/server/fastmcp/resources/__init__.py b/src/mcp/server/fastmcp/resources/__init__.py index b5805fb348..fe80b67e12 100644 --- a/src/mcp/server/fastmcp/resources/__init__.py +++ b/src/mcp/server/fastmcp/resources/__init__.py @@ -1,4 +1,5 @@ from .base import Resource +from .async_resource import AsyncResource, AsyncStatus from .resource_manager import ResourceManager from .templates import ResourceTemplate from .types import ( @@ -12,6 +13,8 @@ __all__ = [ "Resource", + "AsyncResource", + "AsyncStatus", "TextResource", "BinaryResource", "FunctionResource", diff --git a/src/mcp/server/fastmcp/resources/async_resource.py b/src/mcp/server/fastmcp/resources/async_resource.py new file mode 100644 index 0000000000..94dfa89e4a --- /dev/null +++ b/src/mcp/server/fastmcp/resources/async_resource.py @@ -0,0 +1,176 @@ +"""Asynchronous resource implementation for long-running operations.""" + +import asyncio +import enum +# import time +from typing import Any, Optional + +import pydantic +from pydantic import Field + +from mcp.server.fastmcp.resources.base import Resource +from mcp.server.fastmcp.utilities.logging import get_logger + +logger = get_logger(__name__) + + +class AsyncStatus(str, enum.Enum): + """Status of an asynchronous operation.""" + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + CANCELED = "canceled" + + +class AsyncResource(Resource): + """A resource representing an asynchronous operation. + + This resource type is used to track long-running operations that are executed + asynchronously. It provides methods for updating progress, completing with a result, + failing with an error, and canceling the operation. + """ + + status: AsyncStatus = Field( + default=AsyncStatus.PENDING, + description="Current status of the asynchronous operation" + ) + # progress: float = Field( + # default=0, + # description="Current progress value (0-100 or raw count)" + # ) + error: Optional[str] = Field( + default=None, + description="Error message if the operation failed" + ) + # created_at: float = Field( + # default_factory=time.time, + # description="Timestamp when the resource was created" + # ) + # started_at: Optional[float] = Field( + # default=None, + # description="Timestamp when the operation started running" + # ) + # completed_at: Optional[float] = Field( + # default=None, + # description="Timestamp when the operation completed, failed, or was canceled" + # ) + + # Fields not included in serialization + _task: Optional[asyncio.Task[Any]] = pydantic.PrivateAttr(default=None) + # _mcp_server = pydantic.PrivateAttr(default=None) + + # def set_mcp_server(self, server: Any) -> None: + # """Set the MCP server reference. + + # Args: + # server: The MCP server instance + # """ + # self._mcp_server = server + + async def read(self) -> str: + """Read the current state of the resource as JSON. + + Returns the current status and progress information. + """ + # Convert the resource to a dictionary, excluding private fields + data = self.model_dump(exclude={"_task"}) + + # Return status info as JSON + import json + return json.dumps(data, indent=2) + + async def start(self, task: asyncio.Task[Any]) -> None: + """Mark the resource as running and store the task. + + Args: + task: The asyncio task that is executing the operation + """ + self._task = task + self.status = AsyncStatus.RUNNING + # self.started_at = time.time() + # await self._notify_changed() + + logger.debug( + "Started async operation", + extra={ + "uri": self.uri, + } + ) + + # async def update_progress(self, progress: float) -> None: + # """Update the progress information. + + # Args: + # progress: Current progress value + # total: Total expected progress value, if known + # """ + # self.progress = progress + # # await self._notify_changed() + + # logger.debug( + # "Updated async operation progress", + # extra={ + # "uri": self.uri, + # "progress": self.progress, + # } + # ) + + async def complete(self) -> None: + """Mark the resource as completed. + """ + self.status = AsyncStatus.COMPLETED + # self.completed_at = time.time() + + # await self._notify_changed() + + logger.info( + "Completed async operation", + extra={ + "uri": self.uri, + # "duration": self.completed_at - (self.started_at or self.created_at), + } + ) + + async def fail(self, error: str) -> None: + """Mark the resource as failed and store the error. + + Args: + error: Error message describing why the operation failed + """ + self.status = AsyncStatus.FAILED + self.error = error + # self.completed_at = time.time() + # await self._notify_changed() + + logger.error( + "Failed async operation", + extra={ + "uri": self.uri, + "error": error, + # "duration": self.completed_at - (self.started_at or self.created_at), + } + ) + + async def cancel(self) -> None: + """Cancel the operation if it's still running.""" + if self.status in (AsyncStatus.PENDING, AsyncStatus.RUNNING) and self._task: + self._task.cancel() + self.status = AsyncStatus.CANCELED + # self.completed_at = time.time() + # await self._notify_changed() + + logger.info( + "Canceled async operation", + extra={ + "uri": self.uri, + # "duration": self.completed_at - (self.started_at or self.created_at), + } + ) + + # async def _notify_changed(self) -> None: + # """Notify subscribers that the resource has changed.""" + # if self._mcp_server: + # # This will be implemented in the MCP server to notify clients + # # of resource changes via the notification protocol + # self._mcp_server.notify_resource_changed(self.uri) \ No newline at end of file diff --git a/src/mcp/server/fastmcp/resources/resource_manager.py b/src/mcp/server/fastmcp/resources/resource_manager.py index d27e6ac126..e6ffb853cb 100644 --- a/src/mcp/server/fastmcp/resources/resource_manager.py +++ b/src/mcp/server/fastmcp/resources/resource_manager.py @@ -1,11 +1,13 @@ """Resource manager functionality.""" +import uuid from collections.abc import Callable from typing import Any from pydantic import AnyUrl from mcp.server.fastmcp.resources.base import Resource +from mcp.server.fastmcp.resources.async_resource import AsyncResource from mcp.server.fastmcp.resources.templates import ResourceTemplate from mcp.server.fastmcp.utilities.logging import get_logger @@ -19,6 +21,7 @@ def __init__(self, warn_on_duplicate_resources: bool = True): self._resources: dict[str, Resource] = {} self._templates: dict[str, ResourceTemplate] = {} self.warn_on_duplicate_resources = warn_on_duplicate_resources + # self._mcp_server = None def add_resource(self, resource: Resource) -> Resource: """Add a resource to the manager. @@ -93,3 +96,41 @@ def list_templates(self) -> list[ResourceTemplate]: """List all registered templates.""" logger.debug("Listing templates", extra={"count": len(self._templates)}) return list(self._templates.values()) + + # def set_mcp_server(self, server: Any) -> None: + # """Set the MCP server reference. + + # This allows resources to notify the server when they change. + + # Args: + # server: The MCP server instance + # """ + # self._mcp_server = server + + def create_async_resource( + self, + name: str | None = None, + description: str | None = None, + ) -> AsyncResource: + """Create a new async resource. + + Args: + name: Optional name for the resource + description: Optional description of the resource + + Returns: + A new AsyncResource instance + """ + resource_uri = f"resource://tasks/{uuid.uuid4()}" + resource = AsyncResource( + uri=AnyUrl(resource_uri), + name=name, + description=description, + ) + + # # Set the MCP server reference if available + # if self._mcp_server: + # resource.set_mcp_server(self._mcp_server) + + self.add_resource(resource) + return resource diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index c31f29d4c3..4628bf62f0 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -161,6 +161,11 @@ def __init__( self._prompt_manager = PromptManager( warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts ) + + # Connect the resource manager and tool manager + self._tool_manager.set_resource_manager(self._resource_manager) + self._resource_manager.set_mcp_server(self._mcp_server) + if (self.settings.auth is not None) != (auth_server_provider is not None): # TODO: after we support separate authorization servers (see # https://github.com/modelcontextprotocol/modelcontextprotocol/pull/284) @@ -321,6 +326,7 @@ def add_tool( name: str | None = None, description: str | None = None, annotations: ToolAnnotations | None = None, + async_supported: bool = False, ) -> None: """Add a tool to the server. @@ -332,9 +338,11 @@ def add_tool( name: Optional name for the tool (defaults to function name) description: Optional description of what the tool does annotations: Optional ToolAnnotations providing additional tool information + async_supported: Whether this tool supports asynchronous execution """ self._tool_manager.add_tool( - fn, name=name, description=description, annotations=annotations + fn, name=name, description=description, annotations=annotations, + async_supported=async_supported ) def tool( @@ -342,6 +350,7 @@ def tool( name: str | None = None, description: str | None = None, annotations: ToolAnnotations | None = None, + async_supported: bool = False, ) -> Callable[[AnyFunction], AnyFunction]: """Decorator to register a tool. @@ -353,6 +362,7 @@ def tool( name: Optional name for the tool (defaults to function name) description: Optional description of what the tool does annotations: Optional ToolAnnotations providing additional tool information + async_supported: Whether this tool supports asynchronous execution Example: @server.tool() @@ -368,6 +378,16 @@ def tool_with_context(x: int, ctx: Context) -> str: async def async_tool(x: int, context: Context) -> str: await context.report_progress(50, 100) return str(x) + + @server.tool(async_supported=True) + async def long_running_tool(x: int, context: Context) -> str: + # This tool will be executed asynchronously + # The client will receive a resource URI immediately + # and can track progress through that resource + for i in range(100): + await asyncio.sleep(0.1) + await context.report_progress(i, 100) + return f"Processed {x}" """ # Check if user passed function directly instead of calling decorator if callable(name): @@ -378,7 +398,8 @@ async def async_tool(x: int, context: Context) -> str: def decorator(fn: AnyFunction) -> AnyFunction: self.add_tool( - fn, name=name, description=description, annotations=annotations + fn, name=name, description=description, annotations=annotations, + async_supported=async_supported ) return fn @@ -917,14 +938,38 @@ def my_tool(x: int, ctx: Context) -> str: client_id = ctx.client_id return str(x) + + @server.tool(async_supported=True) + async def long_running_tool(x: int, ctx: Context) -> str: + # For async tools, the context.resource will be set to an AsyncResource + # that can be used to update progress and status + + total_steps = 100 + for i in range(total_steps): + # Do some work + await asyncio.sleep(0.1) + + # Update progress through the AsyncResource + if ctx.resource: + await ctx.resource.update_progress(i, total_steps) + + # You can also use the standard progress reporting + await ctx.report_progress(i, total_steps) + + return f"Processed {x}" ``` The context parameter name can be anything as long as it's annotated with Context. The context is optional - tools that don't need it can omit the parameter. + + For asynchronous tools (marked with async_supported=True), the context will have + a resource attribute set to an AsyncResource instance that can be used to update + progress and status information. """ _request_context: RequestContext[ServerSessionT, LifespanContextT] | None _fastmcp: FastMCP | None + resource: Any = None # Can hold a reference to an AsyncResource for async operations def __init__( self, diff --git a/src/mcp/server/fastmcp/tools/base.py b/src/mcp/server/fastmcp/tools/base.py index 21eb1841d9..9d825de128 100644 --- a/src/mcp/server/fastmcp/tools/base.py +++ b/src/mcp/server/fastmcp/tools/base.py @@ -34,6 +34,9 @@ class Tool(BaseModel): annotations: ToolAnnotations | None = Field( None, description="Optional annotations for the tool" ) + async_supported: bool = Field( + False, description="Whether this tool supports asynchronous execution" + ) @classmethod def from_function( @@ -43,6 +46,7 @@ def from_function( description: str | None = None, context_kwarg: str | None = None, annotations: ToolAnnotations | None = None, + async_supported: bool = False, ) -> Tool: """Create a Tool from a function.""" from mcp.server.fastmcp.server import Context @@ -79,6 +83,7 @@ def from_function( is_async=is_async, context_kwarg=context_kwarg, annotations=annotations, + async_supported=async_supported, ) async def run( diff --git a/src/mcp/server/fastmcp/tools/tool_manager.py b/src/mcp/server/fastmcp/tools/tool_manager.py index cfdaeb350f..186d04ab3c 100644 --- a/src/mcp/server/fastmcp/tools/tool_manager.py +++ b/src/mcp/server/fastmcp/tools/tool_manager.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import asyncio from collections.abc import Callable from typing import TYPE_CHECKING, Any @@ -11,6 +12,7 @@ if TYPE_CHECKING: from mcp.server.fastmcp.server import Context + from mcp.server.fastmcp.resources.resource_manager import ResourceManager from mcp.server.session import ServerSessionT logger = get_logger(__name__) @@ -22,6 +24,7 @@ class ToolManager: def __init__(self, warn_on_duplicate_tools: bool = True): self._tools: dict[str, Tool] = {} self.warn_on_duplicate_tools = warn_on_duplicate_tools + self._resource_manager = None def get_tool(self, name: str) -> Tool | None: """Get tool by name.""" @@ -30,6 +33,14 @@ def get_tool(self, name: str) -> Tool | None: def list_tools(self) -> list[Tool]: """List all registered tools.""" return list(self._tools.values()) + + def set_resource_manager(self, resource_manager: ResourceManager) -> None: + """Set the resource manager reference. + + Args: + resource_manager: The ResourceManager instance + """ + self._resource_manager = resource_manager def add_tool( self, @@ -37,10 +48,12 @@ def add_tool( name: str | None = None, description: str | None = None, annotations: ToolAnnotations | None = None, + async_supported: bool = False, ) -> Tool: """Add a tool to the server.""" tool = Tool.from_function( - fn, name=name, description=description, annotations=annotations + fn, name=name, description=description, annotations=annotations, + async_supported=async_supported ) existing = self._tools.get(tool.name) if existing: @@ -60,5 +73,46 @@ async def call_tool( tool = self.get_tool(name) if not tool: raise ToolError(f"Unknown tool: {name}") - - return await tool.run(arguments, context=context) + + # Check if the tool supports async execution + if tool.async_supported and self._resource_manager: + # Create an async resource + resource = self._resource_manager.create_async_resource( + name=tool.name, + description=tool.description, + ) + + # Set the resource in the context if provided + if context: + context.resource = resource + + # Create a task to run the tool + async def run_tool_async(): + try: + # Run the tool + result = await tool.run(arguments, context=context) + + # Mark the resource as completed + await resource.complete() + + return result + except Exception as e: + # Mark the resource as failed + await resource.fail(str(e)) + raise + + # Create and start the task + task = asyncio.create_task(run_tool_async()) + + # Start the resource + await resource.start(task) + + # Return the resource URI + return { + "type": "resource", + "uri": str(resource.uri), + "status": resource.status.value, + } + else: + # Run the tool synchronously + return await tool.run(arguments, context=context) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 4b97b33dad..bc7e017c46 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -358,6 +358,7 @@ def subscribe_resource(self): def decorator(func: Callable[[AnyUrl], Awaitable[None]]): logger.debug("Registering handler for SubscribeRequest") + print('Registering handler for SubscribeRequest') async def handler(req: types.SubscribeRequest): await func(req.params.uri) return types.ServerResult(types.EmptyResult()) @@ -542,7 +543,8 @@ async def _handle_request( lifespan_context: LifespanResultT, raise_exceptions: bool, ): - logger.info(f"Processing request of type {type(req).__name__}") + logger.info(f"Process request of type {type(req).__name__}") + logger.info(f"Request handlers {self.request_handlers}") if type(req) in self.request_handlers: handler = self.request_handlers[type(req)] logger.debug(f"Dispatching request of type {type(req).__name__}")