diff --git a/README.v2.md b/README.v2.md index 55d867586..8049d1020 100644 --- a/README.v2.md +++ b/README.v2.md @@ -532,7 +532,7 @@ def get_user(user_id: str) -> UserProfile: # Classes WITHOUT type hints cannot be used for structured output class UntypedConfig: - def __init__(self, setting1, setting2): # type: ignore[reportMissingParameterType] + def __init__(self, setting1, setting2): # type: ignore[reportMissingParameterType] # noqa: ANN001 self.setting1 = setting1 self.setting2 = setting2 @@ -744,7 +744,7 @@ server_params = StdioServerParameters( ) -async def run(): +async def run() -> None: """Run the completion client example.""" async with stdio_client(server_params) as (read, write): async with ClientSession(read, write) as session: @@ -795,7 +795,7 @@ async def run(): print(f"Completions for 'style' argument: {result.completion.values}") -def main(): +def main() -> None: """Entry point for the completion client.""" asyncio.run(run()) @@ -1210,7 +1210,7 @@ def hello(name: str = "World") -> str: return f"Hello, {name}!" -def main(): +def main() -> None: """Entry point for the direct execution server.""" mcp.run() @@ -1280,6 +1280,7 @@ uvicorn examples.snippets.servers.streamable_starlette_mount:app --reload """ import contextlib +from collections.abc import AsyncIterator from starlette.applications import Starlette from starlette.routing import Mount @@ -1308,7 +1309,7 @@ def add_two(n: int) -> int: # Create a combined lifespan to manage both session managers @contextlib.asynccontextmanager -async def lifespan(app: Starlette): +async def lifespan(app: Starlette) -> AsyncIterator[None]: async with contextlib.AsyncExitStack() as stack: await stack.enter_async_context(echo_mcp.session_manager.run()) await stack.enter_async_context(math_mcp.session_manager.run()) @@ -1392,6 +1393,7 @@ Run from the repository root: """ import contextlib +from collections.abc import AsyncIterator from starlette.applications import Starlette from starlette.routing import Mount @@ -1410,7 +1412,7 @@ def hello() -> str: # Create a lifespan context manager to run the session manager @contextlib.asynccontextmanager -async def lifespan(app: Starlette): +async def lifespan(app: Starlette) -> AsyncIterator[None]: async with mcp.session_manager.run(): yield @@ -1439,6 +1441,7 @@ Run from the repository root: """ import contextlib +from collections.abc import AsyncIterator from starlette.applications import Starlette from starlette.routing import Host @@ -1457,7 +1460,7 @@ def domain_info() -> str: # Create a lifespan context manager to run the session manager @contextlib.asynccontextmanager -async def lifespan(app: Starlette): +async def lifespan(app: Starlette) -> AsyncIterator[None]: async with mcp.session_manager.run(): yield @@ -1486,6 +1489,7 @@ Run from the repository root: """ import contextlib +from collections.abc import AsyncIterator from starlette.applications import Starlette from starlette.routing import Mount @@ -1511,7 +1515,7 @@ def send_message(message: str) -> str: # Create a combined lifespan to manage both session managers @contextlib.asynccontextmanager -async def lifespan(app: Starlette): +async def lifespan(app: Starlette) -> AsyncIterator[None]: async with contextlib.AsyncExitStack() as stack: await stack.enter_async_context(api_mcp.session_manager.run()) await stack.enter_async_context(chat_mcp.session_manager.run()) @@ -1718,7 +1722,7 @@ server = Server( ) -async def run(): +async def run() -> None: """Run the server with lifespan management.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( @@ -1796,7 +1800,7 @@ server = Server( ) -async def run(): +async def run() -> None: """Run the basic low-level server.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( @@ -1889,7 +1893,7 @@ server = Server( ) -async def run(): +async def run() -> None: """Run the structured output server.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( @@ -1964,7 +1968,7 @@ server = Server( ) -async def run(): +async def run() -> None: """Run the server.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( @@ -2127,7 +2131,7 @@ async def handle_sampling_message( ) -async def run(): +async def run() -> None: async with stdio_client(server_params) as (read, write): async with ClientSession(read, write, sampling_callback=handle_sampling_message) as session: # Initialize the connection @@ -2165,7 +2169,7 @@ async def run(): print(f"Structured tool result: {result_structured}") -def main(): +def main() -> None: """Entry point for the client script.""" asyncio.run(run()) @@ -2191,7 +2195,7 @@ from mcp import ClientSession from mcp.client.streamable_http import streamable_http_client -async def main(): +async def main() -> None: # Connect to a streamable HTTP server async with streamable_http_client("http://localhost:8000/mcp") as (read_stream, write_stream): # Create a session using the client streams @@ -2235,7 +2239,7 @@ server_params = StdioServerParameters( ) -async def display_tools(session: ClientSession): +async def display_tools(session: ClientSession) -> None: """Display available tools with human-readable names""" tools_response = await session.list_tools() @@ -2247,7 +2251,7 @@ async def display_tools(session: ClientSession): print(f" {tool.description}") -async def display_resources(session: ClientSession): +async def display_resources(session: ClientSession) -> None: """Display available resources with human-readable names""" resources_response = await session.list_resources() @@ -2261,7 +2265,7 @@ async def display_resources(session: ClientSession): print(f"Resource Template: {display_name}") -async def run(): +async def run() -> None: """Run the display utilities example.""" async with stdio_client(server_params) as (read, write): async with ClientSession(read, write) as session: @@ -2275,7 +2279,7 @@ async def run(): await display_resources(session) -def main(): +def main() -> None: """Entry point for the display utilities client.""" asyncio.run(run()) @@ -2354,7 +2358,7 @@ async def handle_callback() -> tuple[str, str | None]: return params["code"][0], params.get("state", [None])[0] -async def main(): +async def main() -> None: """Run the OAuth client example.""" oauth_auth = OAuthClientProvider( server_url="http://localhost:8001", @@ -2382,7 +2386,7 @@ async def main(): print(f"Available resources: {[r.uri for r in resources.resources]}") -def run(): +def run() -> None: asyncio.run(main()) diff --git a/examples/mcpserver/memory.py b/examples/mcpserver/memory.py index fd0bd9362..5a7f6ab4d 100644 --- a/examples/mcpserver/memory.py +++ b/examples/mcpserver/memory.py @@ -50,11 +50,17 @@ def cosine_similarity(a: list[float], b: list[float]) -> float: return np.dot(a_array, b_array) / (np.linalg.norm(a_array) * np.linalg.norm(b_array)) +@dataclass +class Deps: + openai: AsyncOpenAI + pool: asyncpg.Pool + + async def do_ai( user_prompt: str, system_prompt: str, result_type: type[T] | Annotated, - deps=None, + deps: Deps | None = None, ) -> T: agent = Agent( DEFAULT_LLM_MODEL, @@ -65,14 +71,8 @@ async def do_ai( return result.data -@dataclass -class Deps: - openai: AsyncOpenAI - pool: asyncpg.Pool - - async def get_db_pool() -> asyncpg.Pool: - async def init(conn): + async def init(conn: asyncpg.Connection) -> None: await conn.execute("CREATE EXTENSION IF NOT EXISTS vector;") await register_vector(conn) @@ -90,11 +90,11 @@ class MemoryNode(BaseModel): embedding: list[float] @classmethod - async def from_content(cls, content: str, deps: Deps): + async def from_content(cls, content: str, deps: Deps) -> Self: embedding = await get_embedding(content, deps) return cls(content=content, embedding=embedding) - async def save(self, deps: Deps): + async def save(self, deps: Deps) -> None: async with deps.pool.acquire() as conn: if self.id is None: result = await conn.fetchrow( @@ -129,7 +129,7 @@ async def save(self, deps: Deps): self.id, ) - async def merge_with(self, other: Self, deps: Deps): + async def merge_with(self, other: Self, deps: Deps) -> None: self.content = await do_ai( f"{self.content}\n\n{other.content}", "Combine the following two texts into a single, coherent text.", @@ -145,7 +145,7 @@ async def merge_with(self, other: Self, deps: Deps): if other.id is not None: await delete_memory(other.id, deps) - def get_effective_importance(self): + def get_effective_importance(self) -> float: return self.importance * (1 + math.log(self.access_count + 1)) @@ -157,12 +157,12 @@ async def get_embedding(text: str, deps: Deps) -> list[float]: return embedding_response.data[0].embedding -async def delete_memory(memory_id: int, deps: Deps): +async def delete_memory(memory_id: int, deps: Deps) -> None: async with deps.pool.acquire() as conn: await conn.execute("DELETE FROM memories WHERE id = $1", memory_id) -async def add_memory(content: str, deps: Deps): +async def add_memory(content: str, deps: Deps) -> str: new_memory = await MemoryNode.from_content(content, deps) await new_memory.save(deps) @@ -204,7 +204,7 @@ async def find_similar_memories(embedding: list[float], deps: Deps) -> list[Memo return memories -async def update_importance(user_embedding: list[float], deps: Deps): +async def update_importance(user_embedding: list[float], deps: Deps) -> None: async with deps.pool.acquire() as conn: rows = await conn.fetch("SELECT id, importance, access_count, embedding FROM memories") for row in rows: @@ -228,7 +228,7 @@ async def update_importance(user_embedding: list[float], deps: Deps): ) -async def prune_memories(deps: Deps): +async def prune_memories(deps: Deps) -> None: async with deps.pool.acquire() as conn: rows = await conn.fetch( """ @@ -265,7 +265,7 @@ async def display_memory_tree(deps: Deps) -> str: @mcp.tool() async def remember( contents: list[str] = Field(description="List of observations or memories to store"), -): +) -> str: deps = Deps(openai=AsyncOpenAI(), pool=await get_db_pool()) try: return "\n".join(await asyncio.gather(*[add_memory(content, deps) for content in contents])) @@ -281,7 +281,7 @@ async def read_profile() -> str: return profile -async def initialize_database(): +async def initialize_database() -> None: pool = await asyncpg.create_pool("postgresql://postgres:postgres@localhost:54320/postgres") try: async with pool.acquire() as conn: diff --git a/examples/servers/simple-auth/mcp_simple_auth/auth_server.py b/examples/servers/simple-auth/mcp_simple_auth/auth_server.py index 9d13fffe4..996cbfa44 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/auth_server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/auth_server.py @@ -134,7 +134,7 @@ async def introspect_handler(request: Request) -> Response: return Starlette(routes=routes) -async def run_server(server_settings: AuthServerSettings, auth_settings: SimpleAuthSettings): +async def run_server(server_settings: AuthServerSettings, auth_settings: SimpleAuthSettings) -> None: """Run the Authorization Server.""" auth_server = create_authorization_server(server_settings, auth_settings) diff --git a/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py index 3a3895cc5..9bd6e3c4e 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py +++ b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py @@ -66,7 +66,7 @@ async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: """Get OAuth client information.""" return self.clients.get(client_id) - async def register_client(self, client_info: OAuthClientInformationFull): + async def register_client(self, client_info: OAuthClientInformationFull) -> None: """Register a new OAuth client.""" if not client_info.client_id: raise ValueError("No client_id provided") diff --git a/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py b/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py index 95fb90854..a11b420e4 100644 --- a/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py +++ b/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py @@ -75,7 +75,7 @@ async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequ ) -async def run(): +async def run() -> None: """Run the low-level server using stdio transport.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( diff --git a/examples/snippets/clients/completion_client.py b/examples/snippets/clients/completion_client.py index dc0c1b4f7..9ef10c5d4 100644 --- a/examples/snippets/clients/completion_client.py +++ b/examples/snippets/clients/completion_client.py @@ -17,7 +17,7 @@ ) -async def run(): +async def run() -> None: """Run the completion client example.""" async with stdio_client(server_params) as (read, write): async with ClientSession(read, write) as session: @@ -68,7 +68,7 @@ async def run(): print(f"Completions for 'style' argument: {result.completion.values}") -def main(): +def main() -> None: """Entry point for the completion client.""" asyncio.run(run()) diff --git a/examples/snippets/clients/display_utilities.py b/examples/snippets/clients/display_utilities.py index baa2765a8..1a5e966ed 100644 --- a/examples/snippets/clients/display_utilities.py +++ b/examples/snippets/clients/display_utilities.py @@ -17,7 +17,7 @@ ) -async def display_tools(session: ClientSession): +async def display_tools(session: ClientSession) -> None: """Display available tools with human-readable names""" tools_response = await session.list_tools() @@ -29,7 +29,7 @@ async def display_tools(session: ClientSession): print(f" {tool.description}") -async def display_resources(session: ClientSession): +async def display_resources(session: ClientSession) -> None: """Display available resources with human-readable names""" resources_response = await session.list_resources() @@ -43,7 +43,7 @@ async def display_resources(session: ClientSession): print(f"Resource Template: {display_name}") -async def run(): +async def run() -> None: """Run the display utilities example.""" async with stdio_client(server_params) as (read, write): async with ClientSession(read, write) as session: @@ -57,7 +57,7 @@ async def run(): await display_resources(session) -def main(): +def main() -> None: """Entry point for the display utilities client.""" asyncio.run(run()) diff --git a/examples/snippets/clients/oauth_client.py b/examples/snippets/clients/oauth_client.py index 3887c5c8c..115a6b6c8 100644 --- a/examples/snippets/clients/oauth_client.py +++ b/examples/snippets/clients/oauth_client.py @@ -52,7 +52,7 @@ async def handle_callback() -> tuple[str, str | None]: return params["code"][0], params.get("state", [None])[0] -async def main(): +async def main() -> None: """Run the OAuth client example.""" oauth_auth = OAuthClientProvider( server_url="http://localhost:8001", @@ -80,7 +80,7 @@ async def main(): print(f"Available resources: {[r.uri for r in resources.resources]}") -def run(): +def run() -> None: asyncio.run(main()) diff --git a/examples/snippets/clients/parsing_tool_results.py b/examples/snippets/clients/parsing_tool_results.py index b16640677..945fa9d01 100644 --- a/examples/snippets/clients/parsing_tool_results.py +++ b/examples/snippets/clients/parsing_tool_results.py @@ -6,7 +6,7 @@ from mcp.client.stdio import stdio_client -async def parse_tool_results(): +async def parse_tool_results() -> None: """Demonstrates how to parse different types of content in CallToolResult.""" server_params = StdioServerParameters(command="python", args=["path/to/mcp_server.py"]) @@ -52,7 +52,7 @@ async def parse_tool_results(): print(f"Error: {content.text}") -async def main(): +async def main() -> None: await parse_tool_results() diff --git a/examples/snippets/clients/stdio_client.py b/examples/snippets/clients/stdio_client.py index c1f85f42a..44c8bd45b 100644 --- a/examples/snippets/clients/stdio_client.py +++ b/examples/snippets/clients/stdio_client.py @@ -33,7 +33,7 @@ async def handle_sampling_message( ) -async def run(): +async def run() -> None: async with stdio_client(server_params) as (read, write): async with ClientSession(read, write, sampling_callback=handle_sampling_message) as session: # Initialize the connection @@ -71,7 +71,7 @@ async def run(): print(f"Structured tool result: {result_structured}") -def main(): +def main() -> None: """Entry point for the client script.""" asyncio.run(run()) diff --git a/examples/snippets/clients/streamable_basic.py b/examples/snippets/clients/streamable_basic.py index 43bb6396c..d64a05bf3 100644 --- a/examples/snippets/clients/streamable_basic.py +++ b/examples/snippets/clients/streamable_basic.py @@ -8,7 +8,7 @@ from mcp.client.streamable_http import streamable_http_client -async def main(): +async def main() -> None: # Connect to a streamable HTTP server async with streamable_http_client("http://localhost:8000/mcp") as (read_stream, write_stream): # Create a session using the client streams diff --git a/examples/snippets/servers/__init__.py b/examples/snippets/servers/__init__.py index f132f875f..94220c728 100644 --- a/examples/snippets/servers/__init__.py +++ b/examples/snippets/servers/__init__.py @@ -12,7 +12,7 @@ from typing import Literal, cast -def run_server(): +def run_server() -> None: """Run a server by name with optional transport. Usage: server [transport] diff --git a/examples/snippets/servers/direct_execution.py b/examples/snippets/servers/direct_execution.py index acf7151d3..3fed250ef 100644 --- a/examples/snippets/servers/direct_execution.py +++ b/examples/snippets/servers/direct_execution.py @@ -18,7 +18,7 @@ def hello(name: str = "World") -> str: return f"Hello, {name}!" -def main(): +def main() -> None: """Entry point for the direct execution server.""" mcp.run() diff --git a/examples/snippets/servers/lowlevel/basic.py b/examples/snippets/servers/lowlevel/basic.py index 81f40e994..1981d22b0 100644 --- a/examples/snippets/servers/lowlevel/basic.py +++ b/examples/snippets/servers/lowlevel/basic.py @@ -49,7 +49,7 @@ async def handle_get_prompt(ctx: ServerRequestContext, params: types.GetPromptRe ) -async def run(): +async def run() -> None: """Run the basic low-level server.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( diff --git a/examples/snippets/servers/lowlevel/direct_call_tool_result.py b/examples/snippets/servers/lowlevel/direct_call_tool_result.py index 7e8fc4dcb..ba44fac9e 100644 --- a/examples/snippets/servers/lowlevel/direct_call_tool_result.py +++ b/examples/snippets/servers/lowlevel/direct_call_tool_result.py @@ -48,7 +48,7 @@ async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequ ) -async def run(): +async def run() -> None: """Run the server.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( diff --git a/examples/snippets/servers/lowlevel/lifespan.py b/examples/snippets/servers/lowlevel/lifespan.py index bcd96c893..d72c93c2f 100644 --- a/examples/snippets/servers/lowlevel/lifespan.py +++ b/examples/snippets/servers/lowlevel/lifespan.py @@ -85,7 +85,7 @@ async def handle_call_tool( ) -async def run(): +async def run() -> None: """Run the server with lifespan management.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( diff --git a/examples/snippets/servers/lowlevel/structured_output.py b/examples/snippets/servers/lowlevel/structured_output.py index f93c8875f..71107b792 100644 --- a/examples/snippets/servers/lowlevel/structured_output.py +++ b/examples/snippets/servers/lowlevel/structured_output.py @@ -66,7 +66,7 @@ async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequ ) -async def run(): +async def run() -> None: """Run the structured output server.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( diff --git a/examples/snippets/servers/streamable_http_basic_mounting.py b/examples/snippets/servers/streamable_http_basic_mounting.py index 9a53034f1..e1687ce99 100644 --- a/examples/snippets/servers/streamable_http_basic_mounting.py +++ b/examples/snippets/servers/streamable_http_basic_mounting.py @@ -5,6 +5,7 @@ """ import contextlib +from collections.abc import AsyncIterator from starlette.applications import Starlette from starlette.routing import Mount @@ -23,7 +24,7 @@ def hello() -> str: # Create a lifespan context manager to run the session manager @contextlib.asynccontextmanager -async def lifespan(app: Starlette): +async def lifespan(app: Starlette) -> AsyncIterator[None]: async with mcp.session_manager.run(): yield diff --git a/examples/snippets/servers/streamable_http_host_mounting.py b/examples/snippets/servers/streamable_http_host_mounting.py index 2a41f74a5..73cbdd54d 100644 --- a/examples/snippets/servers/streamable_http_host_mounting.py +++ b/examples/snippets/servers/streamable_http_host_mounting.py @@ -5,6 +5,7 @@ """ import contextlib +from collections.abc import AsyncIterator from starlette.applications import Starlette from starlette.routing import Host @@ -23,7 +24,7 @@ def domain_info() -> str: # Create a lifespan context manager to run the session manager @contextlib.asynccontextmanager -async def lifespan(app: Starlette): +async def lifespan(app: Starlette) -> AsyncIterator[None]: async with mcp.session_manager.run(): yield diff --git a/examples/snippets/servers/streamable_http_multiple_servers.py b/examples/snippets/servers/streamable_http_multiple_servers.py index 71217bdfe..b95e34d22 100644 --- a/examples/snippets/servers/streamable_http_multiple_servers.py +++ b/examples/snippets/servers/streamable_http_multiple_servers.py @@ -5,6 +5,7 @@ """ import contextlib +from collections.abc import AsyncIterator from starlette.applications import Starlette from starlette.routing import Mount @@ -30,7 +31,7 @@ def send_message(message: str) -> str: # Create a combined lifespan to manage both session managers @contextlib.asynccontextmanager -async def lifespan(app: Starlette): +async def lifespan(app: Starlette) -> AsyncIterator[None]: async with contextlib.AsyncExitStack() as stack: await stack.enter_async_context(api_mcp.session_manager.run()) await stack.enter_async_context(chat_mcp.session_manager.run()) diff --git a/examples/snippets/servers/streamable_starlette_mount.py b/examples/snippets/servers/streamable_starlette_mount.py index eb6f1b809..ae97982ff 100644 --- a/examples/snippets/servers/streamable_starlette_mount.py +++ b/examples/snippets/servers/streamable_starlette_mount.py @@ -3,6 +3,7 @@ """ import contextlib +from collections.abc import AsyncIterator from starlette.applications import Starlette from starlette.routing import Mount @@ -31,7 +32,7 @@ def add_two(n: int) -> int: # Create a combined lifespan to manage both session managers @contextlib.asynccontextmanager -async def lifespan(app: Starlette): +async def lifespan(app: Starlette) -> AsyncIterator[None]: async with contextlib.AsyncExitStack() as stack: await stack.enter_async_context(echo_mcp.session_manager.run()) await stack.enter_async_context(math_mcp.session_manager.run()) diff --git a/examples/snippets/servers/structured_output.py b/examples/snippets/servers/structured_output.py index bea7b22c1..422d9001c 100644 --- a/examples/snippets/servers/structured_output.py +++ b/examples/snippets/servers/structured_output.py @@ -71,7 +71,7 @@ def get_user(user_id: str) -> UserProfile: # Classes WITHOUT type hints cannot be used for structured output class UntypedConfig: - def __init__(self, setting1, setting2): # type: ignore[reportMissingParameterType] + def __init__(self, setting1, setting2): # type: ignore[reportMissingParameterType] # noqa: ANN001 self.setting1 = setting1 self.setting2 = setting2 diff --git a/pyproject.toml b/pyproject.toml index 624ade170..decfdbd31 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,6 +131,7 @@ extend-exclude = ["README.md", "README.v2.md"] [tool.ruff.lint] select = [ + "ANN", # flake8-annotations "C4", # flake8-comprehensions "C90", # mccabe "D212", # pydocstyle: multi-line docstring summary should start at the first line @@ -141,7 +142,14 @@ select = [ "UP", # pyupgrade "TID251", # https://docs.astral.sh/ruff/rules/banned-api/ ] -ignore = ["PERF203"] +ignore = [ + "ANN204", # special methods (__init__, __enter__, etc.) have well-known return types + "ANN401", # `Any` is sometimes the right type; pyright strict handles real misuse + "PERF203", +] + +[tool.ruff.lint.flake8-annotations] +allow-star-arg-any = true [tool.ruff.lint.flake8-tidy-imports.banned-api] "pydantic.RootModel".msg = "Use `pydantic.TypeAdapter` instead." @@ -152,6 +160,7 @@ max-complexity = 24 # Default is 10 [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401"] +"tests/**" = ["ANN"] "tests/server/mcpserver/test_func_metadata.py" = ["E501"] "tests/shared/test_progress_notifications.py" = ["PLW0603"] diff --git a/scripts/update_readme_snippets.py b/scripts/update_readme_snippets.py index 8a534e5cb..707759277 100755 --- a/scripts/update_readme_snippets.py +++ b/scripts/update_readme_snippets.py @@ -138,7 +138,7 @@ def update_readme_snippets(readme_path: Path = Path("README.md"), check_mode: bo return True -def main(): +def main() -> None: """Main entry point.""" parser = argparse.ArgumentParser(description="Update README code snippets from source files") parser.add_argument( diff --git a/src/mcp/cli/cli.py b/src/mcp/cli/cli.py index 62334a4a2..5419f912b 100644 --- a/src/mcp/cli/cli.py +++ b/src/mcp/cli/cli.py @@ -39,7 +39,7 @@ ) -def _get_npx_command(): +def _get_npx_command() -> str | None: """Get the correct npx command for the current platform.""" if sys.platform == "win32": # Try both npx.cmd and npx.exe on Windows @@ -116,7 +116,7 @@ def _parse_file_path(file_spec: str) -> tuple[Path, str | None]: return file_path, server_object -def _import_server(file: Path, server_object: str | None = None): # pragma: no cover +def _import_server(file: Path, server_object: str | None = None) -> Any: # pragma: no cover """Import an MCP server from a file. Args: @@ -140,7 +140,7 @@ def _import_server(file: Path, server_object: str | None = None): # pragma: no module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - def _check_server_object(server_object: Any, object_name: str): + def _check_server_object(server_object: Any, object_name: str) -> bool: """Helper function to check that the server object is supported Args: diff --git a/src/mcp/client/__main__.py b/src/mcp/client/__main__.py index f3db17906..c2b98d64d 100644 --- a/src/mcp/client/__main__.py +++ b/src/mcp/client/__main__.py @@ -36,7 +36,7 @@ async def run_session( read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], write_stream: MemoryObjectSendStream[SessionMessage], client_info: types.Implementation | None = None, -): +) -> None: async with ClientSession( read_stream, write_stream, @@ -48,7 +48,7 @@ async def run_session( logger.info("Initialized") -async def main(command_or_url: str, args: list[str], env: list[tuple[str, str]]): +async def main(command_or_url: str, args: list[str], env: list[tuple[str, str]]) -> None: env_dict = dict(env) if urlparse(command_or_url).scheme in ("http", "https"): @@ -62,7 +62,7 @@ async def main(command_or_url: str, args: list[str], env: list[tuple[str, str]]) await run_session(*streams) -def cli(): +def cli() -> None: parser = argparse.ArgumentParser() parser.add_argument("command_or_url", help="Command or URL to connect to") parser.add_argument("args", nargs="*", help="Additional arguments") diff --git a/src/mcp/client/auth/extensions/client_credentials.py b/src/mcp/client/auth/extensions/client_credentials.py index cb6dafb40..d7e9cae49 100644 --- a/src/mcp/client/auth/extensions/client_credentials.py +++ b/src/mcp/client/auth/extensions/client_credentials.py @@ -435,7 +435,7 @@ async def _perform_authorization(self) -> httpx.Request: # pragma: no cover else: return await super()._perform_authorization() - def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]): # pragma: no cover + def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]) -> None: # pragma: no cover """Add JWT assertion for client authentication to token endpoint parameters.""" if not self.jwt_parameters: raise OAuthTokenError("Missing JWT parameters for private_key_jwt flow") diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 7b66b5c1b..78d93fb1b 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -1,5 +1,5 @@ import logging -from collections.abc import Callable +from collections.abc import AsyncGenerator, Callable from contextlib import asynccontextmanager from typing import Any from urllib.parse import parse_qs, urljoin, urlparse @@ -36,7 +36,9 @@ async def sse_client( httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, auth: httpx.Auth | None = None, on_session_created: Callable[[str], None] | None = None, -): +) -> AsyncGenerator[ + tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]], None +]: """Client transport for SSE. `sse_read_timeout` determines how long (in seconds) the client will wait for a new @@ -68,7 +70,7 @@ async def sse_client( read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): + async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED) -> None: try: async for sse in event_source.aiter_sse(): # pragma: no branch logger.debug(f"Received SSE event: {sse.event}") @@ -121,7 +123,7 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): finally: await read_stream_writer.aclose() - async def post_writer(endpoint_url: str): + async def post_writer(endpoint_url: str) -> None: try: async with write_stream_reader, write_stream: async for session_message in write_stream_reader: diff --git a/src/mcp/client/stdio.py b/src/mcp/client/stdio.py index 902dc8576..29c41a700 100644 --- a/src/mcp/client/stdio.py +++ b/src/mcp/client/stdio.py @@ -1,6 +1,7 @@ import logging import os import sys +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from pathlib import Path from typing import Literal, TextIO @@ -102,7 +103,11 @@ class StdioServerParameters(BaseModel): @asynccontextmanager -async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stderr): +async def stdio_client( + server: StdioServerParameters, errlog: TextIO = sys.stderr +) -> AsyncGenerator[ + tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]], None +]: """Client transport for stdio: this will connect to a server by spawning a process and communicating with it over stdin/stdout. """ @@ -134,7 +139,7 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder await write_stream_reader.aclose() raise - async def stdout_reader(): + async def stdout_reader() -> None: assert process.stdout, "Opened process is missing stdout" try: @@ -161,7 +166,7 @@ async def stdout_reader(): except anyio.ClosedResourceError: # pragma: lax no cover await anyio.lowlevel.checkpoint() - async def stdin_writer(): + async def stdin_writer() -> None: assert process.stdin, "Opened process is missing stdin" try: @@ -232,7 +237,7 @@ async def _create_platform_compatible_process( env: dict[str, str] | None = None, errlog: TextIO = sys.stderr, cwd: Path | str | None = None, -): +) -> Process | FallbackProcess: """Creates a subprocess in a platform-compatible way. Unix: Creates process in a new session/process group for killpg support diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 3afb94b03..4a31835cd 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -466,7 +466,7 @@ async def post_writer( read_stream_writer=read_stream_writer, ) - async def handle_request_async(): + async def handle_request_async() -> None: if is_resumption: await self._handle_resumption_request(ctx) else: diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index de473f36d..e5a7b58d8 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -43,7 +43,7 @@ async def websocket_client( read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - async def ws_reader(): + async def ws_reader() -> None: """Reads text messages from the WebSocket, parses them as JSON-RPC messages, and sends them into read_stream_writer. """ @@ -57,7 +57,7 @@ async def ws_reader(): # If JSON parse or model validation fails, send the exception await read_stream_writer.send(exc) - async def ws_writer(): + async def ws_writer() -> None: """Reads JSON-RPC messages from write_stream_reader and sends them to the server. """ diff --git a/src/mcp/os/win32/utilities.py b/src/mcp/os/win32/utilities.py index 6f68405f7..22dc600c0 100644 --- a/src/mcp/os/win32/utilities.py +++ b/src/mcp/os/win32/utilities.py @@ -80,7 +80,7 @@ def __init__(self, popen_obj: subprocess.Popen[bytes]): self.stdin = FileWriteStream(cast(BinaryIO, self.stdin_raw)) if self.stdin_raw else None self.stdout = FileReadStream(cast(BinaryIO, self.stdout_raw)) if self.stdout_raw else None - async def __aenter__(self): + async def __aenter__(self) -> "FallbackProcess": """Support async context manager entry.""" return self @@ -106,11 +106,11 @@ async def __aexit__( if self.stderr: self.stderr.close() - async def wait(self): + async def wait(self) -> int: """Async wait for process completion.""" return await to_thread.run_sync(self.popen.wait) - def terminate(self): + def terminate(self) -> None: """Terminate the subprocess immediately.""" return self.popen.terminate() @@ -313,7 +313,7 @@ async def terminate_windows_process_tree(process: Process | FallbackProcess, tim "terminate_windows_process is deprecated and will be removed in a future version. " "Process termination is now handled internally by the stdio_client context manager." ) -async def terminate_windows_process(process: Process | FallbackProcess): +async def terminate_windows_process(process: Process | FallbackProcess) -> None: """Terminate a Windows process. Note: On Windows, terminating a process with process.terminate() doesn't diff --git a/src/mcp/server/__main__.py b/src/mcp/server/__main__.py index dbc50b8a7..5d66cfb63 100644 --- a/src/mcp/server/__main__.py +++ b/src/mcp/server/__main__.py @@ -17,7 +17,7 @@ logger = logging.getLogger("server") -async def receive_loop(session: ServerSession): +async def receive_loop(session: ServerSession) -> None: logger.info("Starting receive loop") async for message in session.incoming_messages: if isinstance(message, Exception): @@ -27,7 +27,7 @@ async def receive_loop(session: ServerSession): logger.info("Received message from client: %s", message) -async def main(): +async def main() -> None: version = importlib.metadata.version("mcp") async with stdio_server() as (read_stream, write_stream): async with ( diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index dec6713b1..e5763a028 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -80,7 +80,7 @@ async def error_response( error: AuthorizationErrorCode, error_description: str | None, attempt_load_client: bool = True, - ): + ) -> Response: # Error responses take two different formats: # 1. The request has a valid client ID & redirect_uri: we issue a redirect # back to the redirect_uri with the error response fields as query diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 534a478a9..486249f21 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -6,6 +6,7 @@ from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, TypeAdapter, ValidationError from starlette.requests import Request +from starlette.responses import Response from mcp.server.auth.errors import stringify_pydantic_error from mcp.server.auth.json_response import PydanticJSONResponse @@ -63,7 +64,7 @@ class TokenHandler: provider: OAuthAuthorizationServerProvider[Any, Any, Any] client_authenticator: ClientAuthenticator - def response(self, obj: TokenSuccessResponse | TokenErrorResponse): + def response(self, obj: TokenSuccessResponse | TokenErrorResponse) -> PydanticJSONResponse: status_code = 200 if isinstance(obj, TokenErrorResponse): status_code = 400 @@ -77,7 +78,7 @@ def response(self, obj: TokenSuccessResponse | TokenErrorResponse): }, ) - async def handle(self, request: Request): + async def handle(self, request: Request) -> Response: try: client_info = await self.client_authenticator.authenticate_request(request) except AuthenticationError as e: diff --git a/src/mcp/server/auth/middleware/auth_context.py b/src/mcp/server/auth/middleware/auth_context.py index 1d34a5546..682b3e47a 100644 --- a/src/mcp/server/auth/middleware/auth_context.py +++ b/src/mcp/server/auth/middleware/auth_context.py @@ -32,7 +32,7 @@ class AuthContextMiddleware: def __init__(self, app: ASGIApp): self.app = app - async def __call__(self, scope: Scope, receive: Receive, send: Send): + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: user = scope.get("user") if isinstance(user, AuthenticatedUser): # Set the authenticated user in the contextvar diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 6825c00b9..bc8b5263e 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -25,7 +25,7 @@ class BearerAuthBackend(AuthenticationBackend): def __init__(self, token_verifier: TokenVerifier): self.token_verifier = token_verifier - async def authenticate(self, conn: HTTPConnection): + async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, AuthenticatedUser] | None: auth_header = next( (conn.headers.get(key) for key in conn.headers if key.lower() == "authorization"), None, diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index a72e81947..04905a1c0 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -21,7 +21,7 @@ from mcp.shared.auth import OAuthMetadata, ProtectedResourceMetadata -def validate_issuer_url(url: AnyHttpUrl): +def validate_issuer_url(url: AnyHttpUrl) -> None: """Validate that the issuer URL meets OAuth 2.0 requirements. Args: diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index c28842272..408f2536b 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -368,7 +368,7 @@ async def run( # the initialization lifecycle, but can do so with any available node # rather than requiring initialization for each connection. stateless: bool = False, - ): + ) -> None: async with AsyncExitStack() as stack: lifespan_context = await stack.enter_async_context(self.lifespan(self)) session = await stack.enter_async_context( @@ -411,7 +411,7 @@ async def _handle_message( session: ServerSession, lifespan_context: LifespanResultT, raise_exceptions: bool = False, - ): + ) -> None: with warnings.catch_warnings(record=True) as w: match message: case RequestResponder() as responder: @@ -436,7 +436,7 @@ async def _handle_request( session: ServerSession, lifespan_context: LifespanResultT, raise_exceptions: bool, - ): + ) -> None: logger.info("Processing request of type %s", type(req).__name__) if handler := self._request_handlers.get(req.method): diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index 1538adc7c..3ad8391ed 100644 --- a/src/mcp/server/mcpserver/context.py +++ b/src/mcp/server/mcpserver/context.py @@ -14,6 +14,7 @@ elicit_with_validation, ) from mcp.server.lowlevel.helper_types import ReadResourceContents +from mcp.server.session import ServerSession if TYPE_CHECKING: from mcp.server.mcpserver.server import MCPServer @@ -224,7 +225,7 @@ def request_id(self) -> str: return str(self.request_context.request_id) @property - def session(self): + def session(self) -> ServerSession: """Access to the underlying session for advanced usage.""" return self.request_context.session diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 2a7a58117..afa3653ee 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -570,7 +570,7 @@ def decorator(fn: _CallableT) -> _CallableT: return decorator - def completion(self): + def completion(self) -> Callable[[_CallableT], _CallableT]: """Decorator to register a completion handler. The completion handler receives: @@ -799,7 +799,7 @@ def custom_route( methods: list[str], name: str | None = None, include_in_schema: bool = True, - ): + ) -> Callable[[Callable[[Request], Awaitable[Response]]], Callable[[Request], Awaitable[Response]]]: """Decorator to register a custom HTTP route on the MCP server. Allows adding arbitrary HTTP endpoints outside the standard MCP protocol, @@ -926,7 +926,7 @@ def sse_app( sse = SseServerTransport(message_path, security_settings=transport_security) - async def handle_sse(scope: Scope, receive: Receive, send: Send): # pragma: no cover + async def handle_sse(scope: Scope, receive: Receive, send: Send) -> Response: # pragma: no cover # Add client ID from auth context into request context if available async with sse.connect_sse(scope, receive, send) as streams: diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index ce467e6c9..79ad72466 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -161,7 +161,7 @@ async def _receive_loop(self) -> None: async with self._incoming_message_stream_writer: await super()._receive_loop() - async def _received_request(self, responder: RequestResponder[types.ClientRequest, types.ServerResult]): + async def _received_request(self, responder: RequestResponder[types.ClientRequest, types.ServerResult]) -> None: match responder.request: case types.InitializeRequest(params=params): requested_version = params.protocol_version diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 9dcee67f7..ab53a802e 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -37,6 +37,7 @@ async def handle_sse(request): """ import logging +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from typing import Any from urllib.parse import quote @@ -116,7 +117,11 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") @asynccontextmanager - async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # pragma: no cover + async def connect_sse( + self, scope: Scope, receive: Receive, send: Send + ) -> AsyncGenerator[ + tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]], None + ]: # pragma: no cover if scope["type"] != "http": logger.error("connect_sse received non-HTTP request") raise ValueError("connect_sse can only handle HTTP requests") @@ -159,7 +164,7 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # prag sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, Any]](0) - async def sse_writer(): + async def sse_writer() -> None: logger.debug("Starting SSE writer") async with sse_stream_writer, write_stream_reader: await sse_stream_writer.send({"event": "endpoint", "data": client_post_uri_data}) @@ -176,7 +181,7 @@ async def sse_writer(): async with anyio.create_task_group() as tg: - async def response_wrapper(scope: Scope, receive: Receive, send: Send): + async def response_wrapper(scope: Scope, receive: Receive, send: Send) -> None: """The EventSourceResponse returning signals a client close / disconnect. In this case we close our side of the streams to signal the client that the connection has been closed. diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 5ea6c4e77..f393fb9aa 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -18,6 +18,7 @@ async def run_server(): """ import sys +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from io import TextIOWrapper @@ -30,7 +31,11 @@ async def run_server(): @asynccontextmanager -async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio.AsyncFile[str] | None = None): +async def stdio_server( + stdin: anyio.AsyncFile[str] | None = None, stdout: anyio.AsyncFile[str] | None = None +) -> AsyncGenerator[ + tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]], None +]: """Server transport for stdio: this communicates with an MCP client by reading from the current process' stdin and writing to stdout. """ @@ -52,7 +57,7 @@ async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio. read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - async def stdin_reader(): + async def stdin_reader() -> None: try: async with read_stream_writer: async for line in stdin: @@ -67,7 +72,7 @@ async def stdin_reader(): except anyio.ClosedResourceError: # pragma: no cover await anyio.lowlevel.checkpoint() - async def stdout_writer(): + async def stdout_writer() -> None: try: async with write_stream_reader: async for session_message in write_stream_reader: diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index aa99e7c88..db1274fa6 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -577,7 +577,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re # Store writer reference so close_sse_stream() can close it self._sse_stream_writers[request_id] = sse_stream_writer - async def sse_writer(): # pragma: lax no cover + async def sse_writer() -> None: # pragma: lax no cover # Get the request ID from the incoming request message try: async with sse_stream_writer, request_stream_reader: @@ -696,7 +696,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # Create SSE stream sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) - async def standalone_sse_writer(): + async def standalone_sse_writer() -> None: try: # Create a standalone message stream for server-initiated messages @@ -890,7 +890,7 @@ async def _replay_events(self, last_event_id: str, request: Request, send: Send) # Create SSE stream for replay sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) - async def replay_sender(): + async def replay_sender() -> None: try: async with sse_stream_writer: # Define an async callback for sending events @@ -979,7 +979,7 @@ async def connect( # Start a task group for message routing async with anyio.create_task_group() as tg: # Create a message router that distributes messages to request streams - async def message_router(): + async def message_router() -> None: try: async for session_message in write_stream_reader: # pragma: no branch # Determine which request stream(s) should receive this message diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index c25314eab..8e60863ed 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -162,7 +162,7 @@ async def _handle_stateless_request(self, scope: Scope, receive: Receive, send: ) # Start server in a new task - async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED): + async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None: async with http_transport.connect() as streams: read_stream, write_stream = streams task_status.started() diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index 32b50560c..79623d7ba 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -1,3 +1,4 @@ +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager import anyio @@ -11,7 +12,11 @@ @asynccontextmanager -async def websocket_server(scope: Scope, receive: Receive, send: Send): +async def websocket_server( + scope: Scope, receive: Receive, send: Send +) -> AsyncGenerator[ + tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]], None +]: """WebSocket server transport for MCP. This is an ASGI application, suitable for use with a framework like Starlette and a server like Hypercorn. """ @@ -28,7 +33,7 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send): read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - async def ws_reader(): + async def ws_reader() -> None: try: async with read_stream_writer: async for msg in websocket.iter_text(): @@ -43,7 +48,7 @@ async def ws_reader(): except anyio.ClosedResourceError: # pragma: no cover await websocket.close() - async def ws_writer(): + async def ws_writer() -> None: try: async with write_stream_reader: async for session_message in write_stream_reader: diff --git a/tests/test_examples.py b/tests/test_examples.py index 3af82f04c..3b9cdb23a 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -96,7 +96,7 @@ async def test_desktop(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): # TODO(v2): Change back to README.md when v2 is released @pytest.mark.parametrize("example", find_examples("README.v2.md"), ids=str) def test_docs_examples(example: CodeExample, eval_example: EvalExample): - ruff_ignore: list[str] = ["F841", "I001", "F821"] # F821: undefined names (snippets lack imports) + ruff_ignore: list[str] = ["F841", "I001", "F821", "ANN"] # F821: undefined names (snippets lack imports) # Use project's actual line length of 120 eval_example.set_config(ruff_ignore=ruff_ignore, target_version="py310", line_length=120)