From 5697d68e6514c2b157392a0c4ede68b11d7447ba Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 20 Mar 2026 16:01:24 +0000 Subject: [PATCH 1/5] chore: enforce type annotations on all functions via ruff ANN rules MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Enable the `ANN` (flake8-annotations) rule set to require return type and argument annotations on all function definitions. Pyright strict already enforces argument types but infers return types silently; this closes the gap so every function has an explicit return annotation. Configuration choices: - `ANN401` (any-type) is ignored: `Any` is sometimes the right answer and pyright strict already catches genuine misuse. - `allow-star-arg-any = true`: `*args: Any, **kwargs: Any` is a common valid pattern for pass-through wrappers. - Per-file ignores for `test_func_metadata.py` and `test_server.py` where untyped parameters are intentional (testing schema inference on unannotated signatures). - README code snippets (via pytest-examples) exempt — short doc examples shouldn't need full annotations. Auto-fix added `-> None` or inferred return types to ~910 functions. The remaining ~100 were annotated manually — mostly `@asynccontextmanager` generators, pytest fixtures, and ASGI handlers that Ruff couldn't infer. --- examples/mcpserver/memory.py | 36 ++-- .../mcp_simple_auth/auth_server.py | 4 +- .../mcp_simple_auth/legacy_as_server.py | 2 +- .../mcp_simple_auth/simple_auth_provider.py | 4 +- .../mcp_simple_auth/token_verifier.py | 2 +- .../__main__.py | 2 +- .../snippets/clients/completion_client.py | 4 +- .../snippets/clients/display_utilities.py | 8 +- examples/snippets/clients/oauth_client.py | 6 +- .../snippets/clients/parsing_tool_results.py | 4 +- examples/snippets/clients/stdio_client.py | 4 +- examples/snippets/clients/streamable_basic.py | 2 +- examples/snippets/servers/__init__.py | 2 +- examples/snippets/servers/direct_execution.py | 2 +- examples/snippets/servers/lowlevel/basic.py | 2 +- .../lowlevel/direct_call_tool_result.py | 2 +- .../snippets/servers/lowlevel/lifespan.py | 2 +- .../servers/lowlevel/structured_output.py | 2 +- .../servers/streamable_http_basic_mounting.py | 3 +- .../servers/streamable_http_host_mounting.py | 3 +- .../streamable_http_multiple_servers.py | 3 +- .../servers/streamable_starlette_mount.py | 3 +- .../snippets/servers/structured_output.py | 4 +- pyproject.toml | 13 +- scripts/update_readme_snippets.py | 2 +- src/mcp/cli/cli.py | 6 +- src/mcp/client/__main__.py | 6 +- .../auth/extensions/client_credentials.py | 2 +- src/mcp/client/auth/oauth2.py | 2 +- src/mcp/client/sse.py | 10 +- src/mcp/client/stdio.py | 13 +- src/mcp/client/streamable_http.py | 2 +- src/mcp/client/websocket.py | 4 +- src/mcp/os/win32/utilities.py | 10 +- src/mcp/server/__main__.py | 4 +- src/mcp/server/auth/handlers/authorize.py | 2 +- src/mcp/server/auth/handlers/token.py | 5 +- .../server/auth/middleware/auth_context.py | 4 +- src/mcp/server/auth/middleware/bearer_auth.py | 8 +- src/mcp/server/auth/middleware/client_auth.py | 4 +- src/mcp/server/auth/provider.py | 4 +- src/mcp/server/auth/routes.py | 2 +- src/mcp/server/experimental/task_context.py | 2 +- .../experimental/task_result_handler.py | 2 +- src/mcp/server/lowlevel/server.py | 12 +- src/mcp/server/mcpserver/context.py | 5 +- src/mcp/server/mcpserver/prompts/base.py | 6 +- src/mcp/server/mcpserver/prompts/manager.py | 2 +- .../mcpserver/resources/resource_manager.py | 2 +- src/mcp/server/mcpserver/server.py | 8 +- .../server/mcpserver/tools/tool_manager.py | 2 +- src/mcp/server/mcpserver/utilities/types.py | 4 +- src/mcp/server/session.py | 2 +- src/mcp/server/sse.py | 11 +- src/mcp/server/stdio.py | 11 +- src/mcp/server/streamable_http.py | 8 +- src/mcp/server/streamable_http_manager.py | 6 +- src/mcp/server/transport_security.py | 2 +- src/mcp/server/websocket.py | 11 +- src/mcp/shared/auth.py | 4 +- src/mcp/shared/exceptions.py | 6 +- src/mcp/shared/experimental/tasks/context.py | 2 +- tests/cli/test_claude.py | 22 +-- tests/cli/test_utils.py | 16 +- .../extensions/test_client_credentials.py | 42 +++-- tests/client/conftest.py | 26 +-- tests/client/test_auth.py | 138 ++++++++------- tests/client/test_client.py | 40 ++--- tests/client/test_list_methods_cursor.py | 8 +- tests/client/test_list_roots_callback.py | 4 +- tests/client/test_logging_callback.py | 4 +- tests/client/test_output_schema_validation.py | 10 +- tests/client/test_resource_cleanup.py | 6 +- tests/client/test_sampling_callback.py | 6 +- tests/client/test_scope_bug_1630.py | 2 +- tests/client/test_session.py | 40 ++--- tests/client/test_session_group.py | 22 +-- tests/client/test_stdio.py | 22 +-- tests/client/transports/test_memory.py | 12 +- tests/conftest.py | 2 +- .../tasks/client/test_capabilities.py | 16 +- tests/experimental/tasks/client/test_tasks.py | 8 +- .../tasks/server/test_integration.py | 8 +- tests/issues/test_100_tool_listing.py | 4 +- .../test_1027_win_unreachable_cleanup.py | 4 +- tests/issues/test_129_resource_templates.py | 2 +- tests/issues/test_1338_icons_and_metadata.py | 6 +- ...est_1363_race_condition_streamable_http.py | 12 +- tests/issues/test_141_resource_templates.py | 4 +- tests/issues/test_152_resource_mime_type.py | 4 +- .../test_1574_resource_uri_validation.py | 8 +- .../issues/test_1754_mime_type_parameters.py | 8 +- tests/issues/test_176_progress_token.py | 2 +- tests/issues/test_188_concurrency.py | 12 +- tests/issues/test_192_request_id.py | 2 +- tests/issues/test_342_base64_encoding.py | 2 +- tests/issues/test_355_type_error.py | 6 +- tests/issues/test_552_windows_hang.py | 2 +- tests/issues/test_88_random_error.py | 6 +- tests/issues/test_973_url_decoding.py | 8 +- tests/issues/test_malformed_input.py | 4 +- .../auth/middleware/test_auth_context.py | 6 +- .../auth/middleware/test_bearer_auth.py | 38 ++-- tests/server/auth/test_error_handling.py | 16 +- tests/server/auth/test_protected_resource.py | 27 +-- tests/server/auth/test_provider.py | 16 +- tests/server/auth/test_routes.py | 18 +- tests/server/lowlevel/test_helper_types.py | 6 +- .../mcpserver/auth/test_auth_integration.py | 111 ++++++------ tests/server/mcpserver/prompts/test_base.py | 22 +-- .../server/mcpserver/prompts/test_manager.py | 16 +- .../resources/test_file_resources.py | 17 +- .../resources/test_function_resources.py | 22 +-- .../resources/test_resource_manager.py | 23 +-- .../resources/test_resource_template.py | 30 ++-- .../mcpserver/resources/test_resources.py | 26 +-- .../mcpserver/servers/test_file_server.py | 10 +- tests/server/mcpserver/test_elicitation.py | 47 ++--- tests/server/mcpserver/test_func_metadata.py | 82 ++++----- tests/server/mcpserver/test_integration.py | 12 +- .../mcpserver/test_parameter_descriptions.py | 2 +- tests/server/mcpserver/test_server.py | 166 +++++++++--------- tests/server/mcpserver/test_title.py | 10 +- tests/server/mcpserver/test_tool_manager.py | 108 ++++++------ .../server/mcpserver/test_url_elicitation.py | 44 ++--- .../test_url_elicitation_error_throw.py | 6 +- tests/server/mcpserver/tools/test_base.py | 2 +- tests/server/test_cancel_handling.py | 12 +- tests/server/test_completion_with_context.py | 8 +- tests/server/test_lifespan.py | 8 +- .../test_lowlevel_exception_handling.py | 8 +- .../server/test_lowlevel_tool_annotations.py | 2 +- tests/server/test_read_resource.py | 4 +- tests/server/test_session.py | 28 +-- tests/server/test_session_race_condition.py | 6 +- tests/server/test_sse_security.py | 28 +-- tests/server/test_stateless_mode.py | 16 +- tests/server/test_stdio.py | 4 +- tests/server/test_streamable_http_manager.py | 57 +++--- tests/server/test_streamable_http_security.py | 24 +-- tests/shared/test_auth.py | 6 +- tests/shared/test_auth_utils.py | 30 ++-- tests/shared/test_httpx_utils.py | 4 +- tests/shared/test_progress_notifications.py | 6 +- tests/shared/test_session.py | 40 ++--- tests/shared/test_sse.py | 4 +- tests/shared/test_streamable_http.py | 102 ++++++----- tests/test_examples.py | 12 +- tests/test_types.py | 28 +-- 149 files changed, 1151 insertions(+), 1036 deletions(-) diff --git a/examples/mcpserver/memory.py b/examples/mcpserver/memory.py index fd0bd9362..5a7f6ab4d 100644 --- a/examples/mcpserver/memory.py +++ b/examples/mcpserver/memory.py @@ -50,11 +50,17 @@ def cosine_similarity(a: list[float], b: list[float]) -> float: return np.dot(a_array, b_array) / (np.linalg.norm(a_array) * np.linalg.norm(b_array)) +@dataclass +class Deps: + openai: AsyncOpenAI + pool: asyncpg.Pool + + async def do_ai( user_prompt: str, system_prompt: str, result_type: type[T] | Annotated, - deps=None, + deps: Deps | None = None, ) -> T: agent = Agent( DEFAULT_LLM_MODEL, @@ -65,14 +71,8 @@ async def do_ai( return result.data -@dataclass -class Deps: - openai: AsyncOpenAI - pool: asyncpg.Pool - - async def get_db_pool() -> asyncpg.Pool: - async def init(conn): + async def init(conn: asyncpg.Connection) -> None: await conn.execute("CREATE EXTENSION IF NOT EXISTS vector;") await register_vector(conn) @@ -90,11 +90,11 @@ class MemoryNode(BaseModel): embedding: list[float] @classmethod - async def from_content(cls, content: str, deps: Deps): + async def from_content(cls, content: str, deps: Deps) -> Self: embedding = await get_embedding(content, deps) return cls(content=content, embedding=embedding) - async def save(self, deps: Deps): + async def save(self, deps: Deps) -> None: async with deps.pool.acquire() as conn: if self.id is None: result = await conn.fetchrow( @@ -129,7 +129,7 @@ async def save(self, deps: Deps): self.id, ) - async def merge_with(self, other: Self, deps: Deps): + async def merge_with(self, other: Self, deps: Deps) -> None: self.content = await do_ai( f"{self.content}\n\n{other.content}", "Combine the following two texts into a single, coherent text.", @@ -145,7 +145,7 @@ async def merge_with(self, other: Self, deps: Deps): if other.id is not None: await delete_memory(other.id, deps) - def get_effective_importance(self): + def get_effective_importance(self) -> float: return self.importance * (1 + math.log(self.access_count + 1)) @@ -157,12 +157,12 @@ async def get_embedding(text: str, deps: Deps) -> list[float]: return embedding_response.data[0].embedding -async def delete_memory(memory_id: int, deps: Deps): +async def delete_memory(memory_id: int, deps: Deps) -> None: async with deps.pool.acquire() as conn: await conn.execute("DELETE FROM memories WHERE id = $1", memory_id) -async def add_memory(content: str, deps: Deps): +async def add_memory(content: str, deps: Deps) -> str: new_memory = await MemoryNode.from_content(content, deps) await new_memory.save(deps) @@ -204,7 +204,7 @@ async def find_similar_memories(embedding: list[float], deps: Deps) -> list[Memo return memories -async def update_importance(user_embedding: list[float], deps: Deps): +async def update_importance(user_embedding: list[float], deps: Deps) -> None: async with deps.pool.acquire() as conn: rows = await conn.fetch("SELECT id, importance, access_count, embedding FROM memories") for row in rows: @@ -228,7 +228,7 @@ async def update_importance(user_embedding: list[float], deps: Deps): ) -async def prune_memories(deps: Deps): +async def prune_memories(deps: Deps) -> None: async with deps.pool.acquire() as conn: rows = await conn.fetch( """ @@ -265,7 +265,7 @@ async def display_memory_tree(deps: Deps) -> str: @mcp.tool() async def remember( contents: list[str] = Field(description="List of observations or memories to store"), -): +) -> str: deps = Deps(openai=AsyncOpenAI(), pool=await get_db_pool()) try: return "\n".join(await asyncio.gather(*[add_memory(content, deps) for content in contents])) @@ -281,7 +281,7 @@ async def read_profile() -> str: return profile -async def initialize_database(): +async def initialize_database() -> None: pool = await asyncpg.create_pool("postgresql://postgres:postgres@localhost:54320/postgres") try: async with pool.acquire() as conn: diff --git a/examples/servers/simple-auth/mcp_simple_auth/auth_server.py b/examples/servers/simple-auth/mcp_simple_auth/auth_server.py index 9d13fffe4..0125f3659 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/auth_server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/auth_server.py @@ -47,7 +47,7 @@ class SimpleAuthProvider(SimpleOAuthProvider): 2. Stores token state for introspection by Resource Servers """ - def __init__(self, auth_settings: SimpleAuthSettings, auth_callback_path: str, server_url: str): + def __init__(self, auth_settings: SimpleAuthSettings, auth_callback_path: str, server_url: str) -> None: super().__init__(auth_settings, auth_callback_path, server_url) @@ -134,7 +134,7 @@ async def introspect_handler(request: Request) -> Response: return Starlette(routes=routes) -async def run_server(server_settings: AuthServerSettings, auth_settings: SimpleAuthSettings): +async def run_server(server_settings: AuthServerSettings, auth_settings: SimpleAuthSettings) -> None: """Run the Authorization Server.""" auth_server = create_authorization_server(server_settings, auth_settings) diff --git a/examples/servers/simple-auth/mcp_simple_auth/legacy_as_server.py b/examples/servers/simple-auth/mcp_simple_auth/legacy_as_server.py index ab7773b5b..41aed08c0 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/legacy_as_server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/legacy_as_server.py @@ -39,7 +39,7 @@ class ServerSettings(BaseModel): class LegacySimpleOAuthProvider(SimpleOAuthProvider): """Simple OAuth provider for legacy MCP server.""" - def __init__(self, auth_settings: SimpleAuthSettings, auth_callback_path: str, server_url: str): + def __init__(self, auth_settings: SimpleAuthSettings, auth_callback_path: str, server_url: str) -> None: super().__init__(auth_settings, auth_callback_path, server_url) diff --git a/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py index 3a3895cc5..cf4e3dc81 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py +++ b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py @@ -51,7 +51,7 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider[AuthorizationCode, Re 3. Maintaining token state for introspection """ - def __init__(self, settings: SimpleAuthSettings, auth_callback_url: str, server_url: str): + def __init__(self, settings: SimpleAuthSettings, auth_callback_url: str, server_url: str) -> None: self.settings = settings self.auth_callback_url = auth_callback_url self.server_url = server_url @@ -66,7 +66,7 @@ async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: """Get OAuth client information.""" return self.clients.get(client_id) - async def register_client(self, client_info: OAuthClientInformationFull): + async def register_client(self, client_info: OAuthClientInformationFull) -> None: """Register a new OAuth client.""" if not client_info.client_id: raise ValueError("No client_id provided") diff --git a/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py b/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py index 5228d034e..8f9407431 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py +++ b/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py @@ -25,7 +25,7 @@ def __init__( introspection_endpoint: str, server_url: str, validate_resource: bool = False, - ): + ) -> None: self.introspection_endpoint = introspection_endpoint self.server_url = server_url self.validate_resource = validate_resource diff --git a/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py b/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py index 95fb90854..a11b420e4 100644 --- a/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py +++ b/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py @@ -75,7 +75,7 @@ async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequ ) -async def run(): +async def run() -> None: """Run the low-level server using stdio transport.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( diff --git a/examples/snippets/clients/completion_client.py b/examples/snippets/clients/completion_client.py index dc0c1b4f7..9ef10c5d4 100644 --- a/examples/snippets/clients/completion_client.py +++ b/examples/snippets/clients/completion_client.py @@ -17,7 +17,7 @@ ) -async def run(): +async def run() -> None: """Run the completion client example.""" async with stdio_client(server_params) as (read, write): async with ClientSession(read, write) as session: @@ -68,7 +68,7 @@ async def run(): print(f"Completions for 'style' argument: {result.completion.values}") -def main(): +def main() -> None: """Entry point for the completion client.""" asyncio.run(run()) diff --git a/examples/snippets/clients/display_utilities.py b/examples/snippets/clients/display_utilities.py index baa2765a8..1a5e966ed 100644 --- a/examples/snippets/clients/display_utilities.py +++ b/examples/snippets/clients/display_utilities.py @@ -17,7 +17,7 @@ ) -async def display_tools(session: ClientSession): +async def display_tools(session: ClientSession) -> None: """Display available tools with human-readable names""" tools_response = await session.list_tools() @@ -29,7 +29,7 @@ async def display_tools(session: ClientSession): print(f" {tool.description}") -async def display_resources(session: ClientSession): +async def display_resources(session: ClientSession) -> None: """Display available resources with human-readable names""" resources_response = await session.list_resources() @@ -43,7 +43,7 @@ async def display_resources(session: ClientSession): print(f"Resource Template: {display_name}") -async def run(): +async def run() -> None: """Run the display utilities example.""" async with stdio_client(server_params) as (read, write): async with ClientSession(read, write) as session: @@ -57,7 +57,7 @@ async def run(): await display_resources(session) -def main(): +def main() -> None: """Entry point for the display utilities client.""" asyncio.run(run()) diff --git a/examples/snippets/clients/oauth_client.py b/examples/snippets/clients/oauth_client.py index 3887c5c8c..0f6cd5568 100644 --- a/examples/snippets/clients/oauth_client.py +++ b/examples/snippets/clients/oauth_client.py @@ -21,7 +21,7 @@ class InMemoryTokenStorage(TokenStorage): """Demo In-memory token storage implementation.""" - def __init__(self): + def __init__(self) -> None: self.tokens: OAuthToken | None = None self.client_info: OAuthClientInformationFull | None = None @@ -52,7 +52,7 @@ async def handle_callback() -> tuple[str, str | None]: return params["code"][0], params.get("state", [None])[0] -async def main(): +async def main() -> None: """Run the OAuth client example.""" oauth_auth = OAuthClientProvider( server_url="http://localhost:8001", @@ -80,7 +80,7 @@ async def main(): print(f"Available resources: {[r.uri for r in resources.resources]}") -def run(): +def run() -> None: asyncio.run(main()) diff --git a/examples/snippets/clients/parsing_tool_results.py b/examples/snippets/clients/parsing_tool_results.py index b16640677..945fa9d01 100644 --- a/examples/snippets/clients/parsing_tool_results.py +++ b/examples/snippets/clients/parsing_tool_results.py @@ -6,7 +6,7 @@ from mcp.client.stdio import stdio_client -async def parse_tool_results(): +async def parse_tool_results() -> None: """Demonstrates how to parse different types of content in CallToolResult.""" server_params = StdioServerParameters(command="python", args=["path/to/mcp_server.py"]) @@ -52,7 +52,7 @@ async def parse_tool_results(): print(f"Error: {content.text}") -async def main(): +async def main() -> None: await parse_tool_results() diff --git a/examples/snippets/clients/stdio_client.py b/examples/snippets/clients/stdio_client.py index c1f85f42a..44c8bd45b 100644 --- a/examples/snippets/clients/stdio_client.py +++ b/examples/snippets/clients/stdio_client.py @@ -33,7 +33,7 @@ async def handle_sampling_message( ) -async def run(): +async def run() -> None: async with stdio_client(server_params) as (read, write): async with ClientSession(read, write, sampling_callback=handle_sampling_message) as session: # Initialize the connection @@ -71,7 +71,7 @@ async def run(): print(f"Structured tool result: {result_structured}") -def main(): +def main() -> None: """Entry point for the client script.""" asyncio.run(run()) diff --git a/examples/snippets/clients/streamable_basic.py b/examples/snippets/clients/streamable_basic.py index 43bb6396c..d64a05bf3 100644 --- a/examples/snippets/clients/streamable_basic.py +++ b/examples/snippets/clients/streamable_basic.py @@ -8,7 +8,7 @@ from mcp.client.streamable_http import streamable_http_client -async def main(): +async def main() -> None: # Connect to a streamable HTTP server async with streamable_http_client("http://localhost:8000/mcp") as (read_stream, write_stream): # Create a session using the client streams diff --git a/examples/snippets/servers/__init__.py b/examples/snippets/servers/__init__.py index f132f875f..94220c728 100644 --- a/examples/snippets/servers/__init__.py +++ b/examples/snippets/servers/__init__.py @@ -12,7 +12,7 @@ from typing import Literal, cast -def run_server(): +def run_server() -> None: """Run a server by name with optional transport. Usage: server [transport] diff --git a/examples/snippets/servers/direct_execution.py b/examples/snippets/servers/direct_execution.py index acf7151d3..3fed250ef 100644 --- a/examples/snippets/servers/direct_execution.py +++ b/examples/snippets/servers/direct_execution.py @@ -18,7 +18,7 @@ def hello(name: str = "World") -> str: return f"Hello, {name}!" -def main(): +def main() -> None: """Entry point for the direct execution server.""" mcp.run() diff --git a/examples/snippets/servers/lowlevel/basic.py b/examples/snippets/servers/lowlevel/basic.py index 81f40e994..1981d22b0 100644 --- a/examples/snippets/servers/lowlevel/basic.py +++ b/examples/snippets/servers/lowlevel/basic.py @@ -49,7 +49,7 @@ async def handle_get_prompt(ctx: ServerRequestContext, params: types.GetPromptRe ) -async def run(): +async def run() -> None: """Run the basic low-level server.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( diff --git a/examples/snippets/servers/lowlevel/direct_call_tool_result.py b/examples/snippets/servers/lowlevel/direct_call_tool_result.py index 7e8fc4dcb..ba44fac9e 100644 --- a/examples/snippets/servers/lowlevel/direct_call_tool_result.py +++ b/examples/snippets/servers/lowlevel/direct_call_tool_result.py @@ -48,7 +48,7 @@ async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequ ) -async def run(): +async def run() -> None: """Run the server.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( diff --git a/examples/snippets/servers/lowlevel/lifespan.py b/examples/snippets/servers/lowlevel/lifespan.py index bcd96c893..d72c93c2f 100644 --- a/examples/snippets/servers/lowlevel/lifespan.py +++ b/examples/snippets/servers/lowlevel/lifespan.py @@ -85,7 +85,7 @@ async def handle_call_tool( ) -async def run(): +async def run() -> None: """Run the server with lifespan management.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( diff --git a/examples/snippets/servers/lowlevel/structured_output.py b/examples/snippets/servers/lowlevel/structured_output.py index f93c8875f..71107b792 100644 --- a/examples/snippets/servers/lowlevel/structured_output.py +++ b/examples/snippets/servers/lowlevel/structured_output.py @@ -66,7 +66,7 @@ async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequ ) -async def run(): +async def run() -> None: """Run the structured output server.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( diff --git a/examples/snippets/servers/streamable_http_basic_mounting.py b/examples/snippets/servers/streamable_http_basic_mounting.py index 9a53034f1..a1fd72a29 100644 --- a/examples/snippets/servers/streamable_http_basic_mounting.py +++ b/examples/snippets/servers/streamable_http_basic_mounting.py @@ -5,6 +5,7 @@ """ import contextlib +from collections.abc import AsyncGenerator from starlette.applications import Starlette from starlette.routing import Mount @@ -23,7 +24,7 @@ def hello() -> str: # Create a lifespan context manager to run the session manager @contextlib.asynccontextmanager -async def lifespan(app: Starlette): +async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: async with mcp.session_manager.run(): yield diff --git a/examples/snippets/servers/streamable_http_host_mounting.py b/examples/snippets/servers/streamable_http_host_mounting.py index 2a41f74a5..ea72a98a4 100644 --- a/examples/snippets/servers/streamable_http_host_mounting.py +++ b/examples/snippets/servers/streamable_http_host_mounting.py @@ -5,6 +5,7 @@ """ import contextlib +from collections.abc import AsyncGenerator from starlette.applications import Starlette from starlette.routing import Host @@ -23,7 +24,7 @@ def domain_info() -> str: # Create a lifespan context manager to run the session manager @contextlib.asynccontextmanager -async def lifespan(app: Starlette): +async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: async with mcp.session_manager.run(): yield diff --git a/examples/snippets/servers/streamable_http_multiple_servers.py b/examples/snippets/servers/streamable_http_multiple_servers.py index 71217bdfe..e46924e8a 100644 --- a/examples/snippets/servers/streamable_http_multiple_servers.py +++ b/examples/snippets/servers/streamable_http_multiple_servers.py @@ -5,6 +5,7 @@ """ import contextlib +from collections.abc import AsyncGenerator from starlette.applications import Starlette from starlette.routing import Mount @@ -30,7 +31,7 @@ def send_message(message: str) -> str: # Create a combined lifespan to manage both session managers @contextlib.asynccontextmanager -async def lifespan(app: Starlette): +async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: async with contextlib.AsyncExitStack() as stack: await stack.enter_async_context(api_mcp.session_manager.run()) await stack.enter_async_context(chat_mcp.session_manager.run()) diff --git a/examples/snippets/servers/streamable_starlette_mount.py b/examples/snippets/servers/streamable_starlette_mount.py index eb6f1b809..95186ad7f 100644 --- a/examples/snippets/servers/streamable_starlette_mount.py +++ b/examples/snippets/servers/streamable_starlette_mount.py @@ -3,6 +3,7 @@ """ import contextlib +from collections.abc import AsyncGenerator from starlette.applications import Starlette from starlette.routing import Mount @@ -31,7 +32,7 @@ def add_two(n: int) -> int: # Create a combined lifespan to manage both session managers @contextlib.asynccontextmanager -async def lifespan(app: Starlette): +async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: async with contextlib.AsyncExitStack() as stack: await stack.enter_async_context(echo_mcp.session_manager.run()) await stack.enter_async_context(math_mcp.session_manager.run()) diff --git a/examples/snippets/servers/structured_output.py b/examples/snippets/servers/structured_output.py index bea7b22c1..69f4b203a 100644 --- a/examples/snippets/servers/structured_output.py +++ b/examples/snippets/servers/structured_output.py @@ -57,7 +57,7 @@ class UserProfile: age: int email: str | None = None - def __init__(self, name: str, age: int, email: str | None = None): + def __init__(self, name: str, age: int, email: str | None = None) -> None: self.name = name self.age = age self.email = email @@ -71,7 +71,7 @@ def get_user(user_id: str) -> UserProfile: # Classes WITHOUT type hints cannot be used for structured output class UntypedConfig: - def __init__(self, setting1, setting2): # type: ignore[reportMissingParameterType] + def __init__(self, setting1, setting2) -> None: # type: ignore[reportMissingParameterType] # noqa: ANN001 self.setting1 = setting1 self.setting2 = setting2 diff --git a/pyproject.toml b/pyproject.toml index 624ade170..b38312577 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,6 +131,7 @@ extend-exclude = ["README.md", "README.v2.md"] [tool.ruff.lint] select = [ + "ANN", # flake8-annotations "C4", # flake8-comprehensions "C90", # mccabe "D212", # pydocstyle: multi-line docstring summary should start at the first line @@ -141,7 +142,13 @@ select = [ "UP", # pyupgrade "TID251", # https://docs.astral.sh/ruff/rules/banned-api/ ] -ignore = ["PERF203"] +ignore = [ + "ANN401", # `Any` is sometimes the right type; pyright strict handles real misuse + "PERF203", +] + +[tool.ruff.lint.flake8-annotations] +allow-star-arg-any = true [tool.ruff.lint.flake8-tidy-imports.banned-api] "pydantic.RootModel".msg = "Use `pydantic.TypeAdapter` instead." @@ -152,7 +159,9 @@ max-complexity = 24 # Default is 10 [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401"] -"tests/server/mcpserver/test_func_metadata.py" = ["E501"] +# ANN001: these files intentionally define untyped parameters to test schema inference +"tests/server/mcpserver/test_func_metadata.py" = ["ANN001", "E501"] +"tests/server/mcpserver/test_server.py" = ["ANN001"] "tests/shared/test_progress_notifications.py" = ["PLW0603"] [tool.ruff.lint.pylint] diff --git a/scripts/update_readme_snippets.py b/scripts/update_readme_snippets.py index 8a534e5cb..707759277 100755 --- a/scripts/update_readme_snippets.py +++ b/scripts/update_readme_snippets.py @@ -138,7 +138,7 @@ def update_readme_snippets(readme_path: Path = Path("README.md"), check_mode: bo return True -def main(): +def main() -> None: """Main entry point.""" parser = argparse.ArgumentParser(description="Update README code snippets from source files") parser.add_argument( diff --git a/src/mcp/cli/cli.py b/src/mcp/cli/cli.py index 62334a4a2..5419f912b 100644 --- a/src/mcp/cli/cli.py +++ b/src/mcp/cli/cli.py @@ -39,7 +39,7 @@ ) -def _get_npx_command(): +def _get_npx_command() -> str | None: """Get the correct npx command for the current platform.""" if sys.platform == "win32": # Try both npx.cmd and npx.exe on Windows @@ -116,7 +116,7 @@ def _parse_file_path(file_spec: str) -> tuple[Path, str | None]: return file_path, server_object -def _import_server(file: Path, server_object: str | None = None): # pragma: no cover +def _import_server(file: Path, server_object: str | None = None) -> Any: # pragma: no cover """Import an MCP server from a file. Args: @@ -140,7 +140,7 @@ def _import_server(file: Path, server_object: str | None = None): # pragma: no module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - def _check_server_object(server_object: Any, object_name: str): + def _check_server_object(server_object: Any, object_name: str) -> bool: """Helper function to check that the server object is supported Args: diff --git a/src/mcp/client/__main__.py b/src/mcp/client/__main__.py index f3db17906..c2b98d64d 100644 --- a/src/mcp/client/__main__.py +++ b/src/mcp/client/__main__.py @@ -36,7 +36,7 @@ async def run_session( read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], write_stream: MemoryObjectSendStream[SessionMessage], client_info: types.Implementation | None = None, -): +) -> None: async with ClientSession( read_stream, write_stream, @@ -48,7 +48,7 @@ async def run_session( logger.info("Initialized") -async def main(command_or_url: str, args: list[str], env: list[tuple[str, str]]): +async def main(command_or_url: str, args: list[str], env: list[tuple[str, str]]) -> None: env_dict = dict(env) if urlparse(command_or_url).scheme in ("http", "https"): @@ -62,7 +62,7 @@ async def main(command_or_url: str, args: list[str], env: list[tuple[str, str]]) await run_session(*streams) -def cli(): +def cli() -> None: parser = argparse.ArgumentParser() parser.add_argument("command_or_url", help="Command or URL to connect to") parser.add_argument("args", nargs="*", help="Additional arguments") diff --git a/src/mcp/client/auth/extensions/client_credentials.py b/src/mcp/client/auth/extensions/client_credentials.py index cb6dafb40..d7e9cae49 100644 --- a/src/mcp/client/auth/extensions/client_credentials.py +++ b/src/mcp/client/auth/extensions/client_credentials.py @@ -435,7 +435,7 @@ async def _perform_authorization(self) -> httpx.Request: # pragma: no cover else: return await super()._perform_authorization() - def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]): # pragma: no cover + def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]) -> None: # pragma: no cover """Add JWT assertion for client authentication to token endpoint parameters.""" if not self.jwt_parameters: raise OAuthTokenError("Missing JWT parameters for private_key_jwt flow") diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 25075dec3..78eb0c325 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -232,7 +232,7 @@ def __init__( timeout: float = 300.0, client_metadata_url: str | None = None, validate_resource_url: Callable[[str, str | None], Awaitable[None]] | None = None, - ): + ) -> None: """Initialize OAuth2 authentication. Args: diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 7b66b5c1b..78d93fb1b 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -1,5 +1,5 @@ import logging -from collections.abc import Callable +from collections.abc import AsyncGenerator, Callable from contextlib import asynccontextmanager from typing import Any from urllib.parse import parse_qs, urljoin, urlparse @@ -36,7 +36,9 @@ async def sse_client( httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, auth: httpx.Auth | None = None, on_session_created: Callable[[str], None] | None = None, -): +) -> AsyncGenerator[ + tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]], None +]: """Client transport for SSE. `sse_read_timeout` determines how long (in seconds) the client will wait for a new @@ -68,7 +70,7 @@ async def sse_client( read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): + async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED) -> None: try: async for sse in event_source.aiter_sse(): # pragma: no branch logger.debug(f"Received SSE event: {sse.event}") @@ -121,7 +123,7 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): finally: await read_stream_writer.aclose() - async def post_writer(endpoint_url: str): + async def post_writer(endpoint_url: str) -> None: try: async with write_stream_reader, write_stream: async for session_message in write_stream_reader: diff --git a/src/mcp/client/stdio.py b/src/mcp/client/stdio.py index 902dc8576..29c41a700 100644 --- a/src/mcp/client/stdio.py +++ b/src/mcp/client/stdio.py @@ -1,6 +1,7 @@ import logging import os import sys +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from pathlib import Path from typing import Literal, TextIO @@ -102,7 +103,11 @@ class StdioServerParameters(BaseModel): @asynccontextmanager -async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stderr): +async def stdio_client( + server: StdioServerParameters, errlog: TextIO = sys.stderr +) -> AsyncGenerator[ + tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]], None +]: """Client transport for stdio: this will connect to a server by spawning a process and communicating with it over stdin/stdout. """ @@ -134,7 +139,7 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder await write_stream_reader.aclose() raise - async def stdout_reader(): + async def stdout_reader() -> None: assert process.stdout, "Opened process is missing stdout" try: @@ -161,7 +166,7 @@ async def stdout_reader(): except anyio.ClosedResourceError: # pragma: lax no cover await anyio.lowlevel.checkpoint() - async def stdin_writer(): + async def stdin_writer() -> None: assert process.stdin, "Opened process is missing stdin" try: @@ -232,7 +237,7 @@ async def _create_platform_compatible_process( env: dict[str, str] | None = None, errlog: TextIO = sys.stderr, cwd: Path | str | None = None, -): +) -> Process | FallbackProcess: """Creates a subprocess in a platform-compatible way. Unix: Creates process in a new session/process group for killpg support diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 3afb94b03..4a31835cd 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -466,7 +466,7 @@ async def post_writer( read_stream_writer=read_stream_writer, ) - async def handle_request_async(): + async def handle_request_async() -> None: if is_resumption: await self._handle_resumption_request(ctx) else: diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index de473f36d..e5a7b58d8 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -43,7 +43,7 @@ async def websocket_client( read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - async def ws_reader(): + async def ws_reader() -> None: """Reads text messages from the WebSocket, parses them as JSON-RPC messages, and sends them into read_stream_writer. """ @@ -57,7 +57,7 @@ async def ws_reader(): # If JSON parse or model validation fails, send the exception await read_stream_writer.send(exc) - async def ws_writer(): + async def ws_writer() -> None: """Reads JSON-RPC messages from write_stream_reader and sends them to the server. """ diff --git a/src/mcp/os/win32/utilities.py b/src/mcp/os/win32/utilities.py index 6f68405f7..e99092182 100644 --- a/src/mcp/os/win32/utilities.py +++ b/src/mcp/os/win32/utilities.py @@ -71,7 +71,7 @@ class FallbackProcess: so that MCP clients expecting async streams can work properly. """ - def __init__(self, popen_obj: subprocess.Popen[bytes]): + def __init__(self, popen_obj: subprocess.Popen[bytes]) -> None: self.popen: subprocess.Popen[bytes] = popen_obj self.stdin_raw = popen_obj.stdin # type: ignore[assignment] self.stdout_raw = popen_obj.stdout # type: ignore[assignment] @@ -80,7 +80,7 @@ def __init__(self, popen_obj: subprocess.Popen[bytes]): self.stdin = FileWriteStream(cast(BinaryIO, self.stdin_raw)) if self.stdin_raw else None self.stdout = FileReadStream(cast(BinaryIO, self.stdout_raw)) if self.stdout_raw else None - async def __aenter__(self): + async def __aenter__(self) -> "FallbackProcess": """Support async context manager entry.""" return self @@ -106,11 +106,11 @@ async def __aexit__( if self.stderr: self.stderr.close() - async def wait(self): + async def wait(self) -> int: """Async wait for process completion.""" return await to_thread.run_sync(self.popen.wait) - def terminate(self): + def terminate(self) -> None: """Terminate the subprocess immediately.""" return self.popen.terminate() @@ -313,7 +313,7 @@ async def terminate_windows_process_tree(process: Process | FallbackProcess, tim "terminate_windows_process is deprecated and will be removed in a future version. " "Process termination is now handled internally by the stdio_client context manager." ) -async def terminate_windows_process(process: Process | FallbackProcess): +async def terminate_windows_process(process: Process | FallbackProcess) -> None: """Terminate a Windows process. Note: On Windows, terminating a process with process.terminate() doesn't diff --git a/src/mcp/server/__main__.py b/src/mcp/server/__main__.py index dbc50b8a7..5d66cfb63 100644 --- a/src/mcp/server/__main__.py +++ b/src/mcp/server/__main__.py @@ -17,7 +17,7 @@ logger = logging.getLogger("server") -async def receive_loop(session: ServerSession): +async def receive_loop(session: ServerSession) -> None: logger.info("Starting receive loop") async for message in session.incoming_messages: if isinstance(message, Exception): @@ -27,7 +27,7 @@ async def receive_loop(session: ServerSession): logger.info("Received message from client: %s", message) -async def main(): +async def main() -> None: version = importlib.metadata.version("mcp") async with stdio_server() as (read_stream, write_stream): async with ( diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index dec6713b1..e5763a028 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -80,7 +80,7 @@ async def error_response( error: AuthorizationErrorCode, error_description: str | None, attempt_load_client: bool = True, - ): + ) -> Response: # Error responses take two different formats: # 1. The request has a valid client ID & redirect_uri: we issue a redirect # back to the redirect_uri with the error response fields as query diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 534a478a9..486249f21 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -6,6 +6,7 @@ from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, TypeAdapter, ValidationError from starlette.requests import Request +from starlette.responses import Response from mcp.server.auth.errors import stringify_pydantic_error from mcp.server.auth.json_response import PydanticJSONResponse @@ -63,7 +64,7 @@ class TokenHandler: provider: OAuthAuthorizationServerProvider[Any, Any, Any] client_authenticator: ClientAuthenticator - def response(self, obj: TokenSuccessResponse | TokenErrorResponse): + def response(self, obj: TokenSuccessResponse | TokenErrorResponse) -> PydanticJSONResponse: status_code = 200 if isinstance(obj, TokenErrorResponse): status_code = 400 @@ -77,7 +78,7 @@ def response(self, obj: TokenSuccessResponse | TokenErrorResponse): }, ) - async def handle(self, request: Request): + async def handle(self, request: Request) -> Response: try: client_info = await self.client_authenticator.authenticate_request(request) except AuthenticationError as e: diff --git a/src/mcp/server/auth/middleware/auth_context.py b/src/mcp/server/auth/middleware/auth_context.py index 1d34a5546..686c2480f 100644 --- a/src/mcp/server/auth/middleware/auth_context.py +++ b/src/mcp/server/auth/middleware/auth_context.py @@ -29,10 +29,10 @@ class AuthContextMiddleware: being stored in the context. """ - def __init__(self, app: ASGIApp): + def __init__(self, app: ASGIApp) -> None: self.app = app - async def __call__(self, scope: Scope, receive: Receive, send: Send): + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: user = scope.get("user") if isinstance(user, AuthenticatedUser): # Set the authenticated user in the contextvar diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 6825c00b9..ef0d083c9 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -13,7 +13,7 @@ class AuthenticatedUser(SimpleUser): """User with authentication info.""" - def __init__(self, auth_info: AccessToken): + def __init__(self, auth_info: AccessToken) -> None: super().__init__(auth_info.client_id) self.access_token = auth_info self.scopes = auth_info.scopes @@ -22,10 +22,10 @@ def __init__(self, auth_info: AccessToken): class BearerAuthBackend(AuthenticationBackend): """Authentication backend that validates Bearer tokens using a TokenVerifier.""" - def __init__(self, token_verifier: TokenVerifier): + def __init__(self, token_verifier: TokenVerifier) -> None: self.token_verifier = token_verifier - async def authenticate(self, conn: HTTPConnection): + async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, AuthenticatedUser] | None: auth_header = next( (conn.headers.get(key) for key in conn.headers if key.lower() == "authorization"), None, @@ -59,7 +59,7 @@ def __init__( app: Any, required_scopes: list[str], resource_metadata_url: AnyHttpUrl | None = None, - ): + ) -> None: """Initialize the middleware. Args: diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index 2832f8352..2f8396e49 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -12,7 +12,7 @@ class AuthenticationError(Exception): - def __init__(self, message: str): + def __init__(self, message: str) -> None: self.message = message @@ -28,7 +28,7 @@ class ClientAuthenticator: logic is skipped. """ - def __init__(self, provider: OAuthAuthorizationServerProvider[Any, Any, Any]): + def __init__(self, provider: OAuthAuthorizationServerProvider[Any, Any, Any]) -> None: """Initialize the authenticator. Args: diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index 957082a85..9ea811d68 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -285,7 +285,9 @@ class ProviderTokenVerifier(TokenVerifier): the TokenVerifier protocol with a dedicated implementation like IntrospectionTokenVerifier. """ - def __init__(self, provider: "OAuthAuthorizationServerProvider[AuthorizationCode, RefreshToken, AccessToken]"): + def __init__( + self, provider: "OAuthAuthorizationServerProvider[AuthorizationCode, RefreshToken, AccessToken]" + ) -> None: self.provider = provider async def verify_token(self, token: str) -> AccessToken | None: diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index a72e81947..04905a1c0 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -21,7 +21,7 @@ from mcp.shared.auth import OAuthMetadata, ProtectedResourceMetadata -def validate_issuer_url(url: AnyHttpUrl): +def validate_issuer_url(url: AnyHttpUrl) -> None: """Validate that the issuer URL meets OAuth 2.0 requirements. Args: diff --git a/src/mcp/server/experimental/task_context.py b/src/mcp/server/experimental/task_context.py index 1fc45badf..74d2a3f53 100644 --- a/src/mcp/server/experimental/task_context.py +++ b/src/mcp/server/experimental/task_context.py @@ -80,7 +80,7 @@ def __init__( session: ServerSession, queue: TaskMessageQueue, handler: TaskResultHandler | None = None, - ): + ) -> None: """Create a ServerTaskContext. Args: diff --git a/src/mcp/server/experimental/task_result_handler.py b/src/mcp/server/experimental/task_result_handler.py index b2268bc1c..c158d4d3d 100644 --- a/src/mcp/server/experimental/task_result_handler.py +++ b/src/mcp/server/experimental/task_result_handler.py @@ -58,7 +58,7 @@ def __init__( self, store: TaskStore, queue: TaskMessageQueue, - ): + ) -> None: self._store = store self._queue = queue # Map from internal request ID to resolver for routing responses diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index c28842272..7fa745fd3 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -75,7 +75,9 @@ async def main(): class NotificationOptions: - def __init__(self, prompts_changed: bool = False, resources_changed: bool = False, tools_changed: bool = False): + def __init__( + self, prompts_changed: bool = False, resources_changed: bool = False, tools_changed: bool = False + ) -> None: self.prompts_changed = prompts_changed self.resources_changed = resources_changed self.tools_changed = tools_changed @@ -181,7 +183,7 @@ def __init__( Awaitable[None], ] | None = None, - ): + ) -> None: self.name = name self.version = version self.title = title @@ -368,7 +370,7 @@ async def run( # the initialization lifecycle, but can do so with any available node # rather than requiring initialization for each connection. stateless: bool = False, - ): + ) -> None: async with AsyncExitStack() as stack: lifespan_context = await stack.enter_async_context(self.lifespan(self)) session = await stack.enter_async_context( @@ -411,7 +413,7 @@ async def _handle_message( session: ServerSession, lifespan_context: LifespanResultT, raise_exceptions: bool = False, - ): + ) -> None: with warnings.catch_warnings(record=True) as w: match message: case RequestResponder() as responder: @@ -436,7 +438,7 @@ async def _handle_request( session: ServerSession, lifespan_context: LifespanResultT, raise_exceptions: bool, - ): + ) -> None: logger.info("Processing request of type %s", type(req).__name__) if handler := self._request_handlers.get(req.method): diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index 1538adc7c..e2689d6b2 100644 --- a/src/mcp/server/mcpserver/context.py +++ b/src/mcp/server/mcpserver/context.py @@ -14,6 +14,7 @@ elicit_with_validation, ) from mcp.server.lowlevel.helper_types import ReadResourceContents +from mcp.server.session import ServerSession if TYPE_CHECKING: from mcp.server.mcpserver.server import MCPServer @@ -64,7 +65,7 @@ def __init__( mcp_server: MCPServer | None = None, # TODO(Marcelo): We should drop this kwargs parameter. **kwargs: Any, - ): + ) -> None: super().__init__(**kwargs) self._request_context = request_context self._mcp_server = mcp_server @@ -224,7 +225,7 @@ def request_id(self) -> str: return str(self.request_context.request_id) @property - def session(self): + def session(self) -> ServerSession: """Access to the underlying session for advanced usage.""" return self.request_context.session diff --git a/src/mcp/server/mcpserver/prompts/base.py b/src/mcp/server/mcpserver/prompts/base.py index 0c319d53c..2f502250b 100644 --- a/src/mcp/server/mcpserver/prompts/base.py +++ b/src/mcp/server/mcpserver/prompts/base.py @@ -24,7 +24,7 @@ class Message(BaseModel): role: Literal["user", "assistant"] content: ContentBlock - def __init__(self, content: str | ContentBlock, **kwargs: Any): + def __init__(self, content: str | ContentBlock, **kwargs: Any) -> None: if isinstance(content, str): content = TextContent(type="text", text=content) super().__init__(content=content, **kwargs) @@ -35,7 +35,7 @@ class UserMessage(Message): role: Literal["user", "assistant"] = "user" - def __init__(self, content: str | ContentBlock, **kwargs: Any): + def __init__(self, content: str | ContentBlock, **kwargs: Any) -> None: super().__init__(content=content, **kwargs) @@ -44,7 +44,7 @@ class AssistantMessage(Message): role: Literal["user", "assistant"] = "assistant" - def __init__(self, content: str | ContentBlock, **kwargs: Any): + def __init__(self, content: str | ContentBlock, **kwargs: Any) -> None: super().__init__(content=content, **kwargs) diff --git a/src/mcp/server/mcpserver/prompts/manager.py b/src/mcp/server/mcpserver/prompts/manager.py index 28a7a6e98..303a5bbb9 100644 --- a/src/mcp/server/mcpserver/prompts/manager.py +++ b/src/mcp/server/mcpserver/prompts/manager.py @@ -17,7 +17,7 @@ class PromptManager: """Manages MCPServer prompts.""" - def __init__(self, warn_on_duplicate_prompts: bool = True): + def __init__(self, warn_on_duplicate_prompts: bool = True) -> None: self._prompts: dict[str, Prompt] = {} self.warn_on_duplicate_prompts = warn_on_duplicate_prompts diff --git a/src/mcp/server/mcpserver/resources/resource_manager.py b/src/mcp/server/mcpserver/resources/resource_manager.py index 6bf17376d..56b9e0e6d 100644 --- a/src/mcp/server/mcpserver/resources/resource_manager.py +++ b/src/mcp/server/mcpserver/resources/resource_manager.py @@ -22,7 +22,7 @@ class ResourceManager: """Manages MCPServer resources.""" - def __init__(self, warn_on_duplicate_resources: bool = True): + def __init__(self, warn_on_duplicate_resources: bool = True) -> None: self._resources: dict[str, Resource] = {} self._templates: dict[str, ResourceTemplate] = {} self.warn_on_duplicate_resources = warn_on_duplicate_resources diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 2a7a58117..b80035ade 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -144,7 +144,7 @@ def __init__( warn_on_duplicate_prompts: bool = True, lifespan: Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None = None, auth: AuthSettings | None = None, - ): + ) -> None: self.settings = Settings( debug=debug, log_level=log_level, @@ -570,7 +570,7 @@ def decorator(fn: _CallableT) -> _CallableT: return decorator - def completion(self): + def completion(self) -> Callable[[_CallableT], _CallableT]: """Decorator to register a completion handler. The completion handler receives: @@ -799,7 +799,7 @@ def custom_route( methods: list[str], name: str | None = None, include_in_schema: bool = True, - ): + ) -> Callable[[Callable[[Request], Awaitable[Response]]], Callable[[Request], Awaitable[Response]]]: """Decorator to register a custom HTTP route on the MCP server. Allows adding arbitrary HTTP endpoints outside the standard MCP protocol, @@ -926,7 +926,7 @@ def sse_app( sse = SseServerTransport(message_path, security_settings=transport_security) - async def handle_sse(scope: Scope, receive: Receive, send: Send): # pragma: no cover + async def handle_sse(scope: Scope, receive: Receive, send: Send) -> Response: # pragma: no cover # Add client ID from auth context into request context if available async with sse.connect_sse(scope, receive, send) as streams: diff --git a/src/mcp/server/mcpserver/tools/tool_manager.py b/src/mcp/server/mcpserver/tools/tool_manager.py index 32ed54797..dd1ac519f 100644 --- a/src/mcp/server/mcpserver/tools/tool_manager.py +++ b/src/mcp/server/mcpserver/tools/tool_manager.py @@ -23,7 +23,7 @@ def __init__( warn_on_duplicate_tools: bool = True, *, tools: list[Tool] | None = None, - ): + ) -> None: self._tools: dict[str, Tool] = {} if tools is not None: for tool in tools: diff --git a/src/mcp/server/mcpserver/utilities/types.py b/src/mcp/server/mcpserver/utilities/types.py index f092b245a..9e05a663e 100644 --- a/src/mcp/server/mcpserver/utilities/types.py +++ b/src/mcp/server/mcpserver/utilities/types.py @@ -14,7 +14,7 @@ def __init__( path: str | Path | None = None, data: bytes | None = None, format: str | None = None, - ): + ) -> None: if path is None and data is None: # pragma: no cover raise ValueError("Either path or data must be provided") if path is not None and data is not None: # pragma: no cover @@ -62,7 +62,7 @@ def __init__( path: str | Path | None = None, data: bytes | None = None, format: str | None = None, - ): + ) -> None: if not bool(path) ^ bool(data): # pragma: no cover raise ValueError("Either path or data can be provided") diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index ce467e6c9..79ad72466 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -161,7 +161,7 @@ async def _receive_loop(self) -> None: async with self._incoming_message_stream_writer: await super()._receive_loop() - async def _received_request(self, responder: RequestResponder[types.ClientRequest, types.ServerResult]): + async def _received_request(self, responder: RequestResponder[types.ClientRequest, types.ServerResult]) -> None: match responder.request: case types.InitializeRequest(params=params): requested_version = params.protocol_version diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 9dcee67f7..ab53a802e 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -37,6 +37,7 @@ async def handle_sse(request): """ import logging +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from typing import Any from urllib.parse import quote @@ -116,7 +117,11 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") @asynccontextmanager - async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # pragma: no cover + async def connect_sse( + self, scope: Scope, receive: Receive, send: Send + ) -> AsyncGenerator[ + tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]], None + ]: # pragma: no cover if scope["type"] != "http": logger.error("connect_sse received non-HTTP request") raise ValueError("connect_sse can only handle HTTP requests") @@ -159,7 +164,7 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # prag sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, Any]](0) - async def sse_writer(): + async def sse_writer() -> None: logger.debug("Starting SSE writer") async with sse_stream_writer, write_stream_reader: await sse_stream_writer.send({"event": "endpoint", "data": client_post_uri_data}) @@ -176,7 +181,7 @@ async def sse_writer(): async with anyio.create_task_group() as tg: - async def response_wrapper(scope: Scope, receive: Receive, send: Send): + async def response_wrapper(scope: Scope, receive: Receive, send: Send) -> None: """The EventSourceResponse returning signals a client close / disconnect. In this case we close our side of the streams to signal the client that the connection has been closed. diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 5ea6c4e77..f393fb9aa 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -18,6 +18,7 @@ async def run_server(): """ import sys +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from io import TextIOWrapper @@ -30,7 +31,11 @@ async def run_server(): @asynccontextmanager -async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio.AsyncFile[str] | None = None): +async def stdio_server( + stdin: anyio.AsyncFile[str] | None = None, stdout: anyio.AsyncFile[str] | None = None +) -> AsyncGenerator[ + tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]], None +]: """Server transport for stdio: this communicates with an MCP client by reading from the current process' stdin and writing to stdout. """ @@ -52,7 +57,7 @@ async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio. read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - async def stdin_reader(): + async def stdin_reader() -> None: try: async with read_stream_writer: async for line in stdin: @@ -67,7 +72,7 @@ async def stdin_reader(): except anyio.ClosedResourceError: # pragma: no cover await anyio.lowlevel.checkpoint() - async def stdout_writer(): + async def stdout_writer() -> None: try: async with write_stream_reader: async for session_message in write_stream_reader: diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index aa99e7c88..db1274fa6 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -577,7 +577,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re # Store writer reference so close_sse_stream() can close it self._sse_stream_writers[request_id] = sse_stream_writer - async def sse_writer(): # pragma: lax no cover + async def sse_writer() -> None: # pragma: lax no cover # Get the request ID from the incoming request message try: async with sse_stream_writer, request_stream_reader: @@ -696,7 +696,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # Create SSE stream sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) - async def standalone_sse_writer(): + async def standalone_sse_writer() -> None: try: # Create a standalone message stream for server-initiated messages @@ -890,7 +890,7 @@ async def _replay_events(self, last_event_id: str, request: Request, send: Send) # Create SSE stream for replay sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) - async def replay_sender(): + async def replay_sender() -> None: try: async with sse_stream_writer: # Define an async callback for sending events @@ -979,7 +979,7 @@ async def connect( # Start a task group for message routing async with anyio.create_task_group() as tg: # Create a message router that distributes messages to request streams - async def message_router(): + async def message_router() -> None: try: async for session_message in write_stream_reader: # pragma: no branch # Determine which request stream(s) should receive this message diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index c25314eab..20fc5f74a 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -72,7 +72,7 @@ def __init__( security_settings: TransportSecuritySettings | None = None, retry_interval: int | None = None, session_idle_timeout: float | None = None, - ): + ) -> None: if session_idle_timeout is not None and session_idle_timeout <= 0: raise ValueError("session_idle_timeout must be a positive number of seconds") if stateless and session_idle_timeout is not None: @@ -162,7 +162,7 @@ async def _handle_stateless_request(self, scope: Scope, receive: Receive, send: ) # Start server in a new task - async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED): + async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None: async with http_transport.connect() as streams: read_stream, write_stream = streams task_status.started() @@ -289,7 +289,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE class StreamableHTTPASGIApp: """ASGI application for Streamable HTTP server transport.""" - def __init__(self, session_manager: StreamableHTTPSessionManager): + def __init__(self, session_manager: StreamableHTTPSessionManager) -> None: self.session_manager = session_manager async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index 1ed9842c0..9acb41538 100644 --- a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -36,7 +36,7 @@ class TransportSecuritySettings(BaseModel): class TransportSecurityMiddleware: """Middleware to enforce DNS rebinding protection for MCP transport endpoints.""" - def __init__(self, settings: TransportSecuritySettings | None = None): + def __init__(self, settings: TransportSecuritySettings | None = None) -> None: # If not specified, disable DNS rebinding protection by default for backwards compatibility self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False) diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index 32b50560c..79623d7ba 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -1,3 +1,4 @@ +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager import anyio @@ -11,7 +12,11 @@ @asynccontextmanager -async def websocket_server(scope: Scope, receive: Receive, send: Send): +async def websocket_server( + scope: Scope, receive: Receive, send: Send +) -> AsyncGenerator[ + tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]], None +]: """WebSocket server transport for MCP. This is an ASGI application, suitable for use with a framework like Starlette and a server like Hypercorn. """ @@ -28,7 +33,7 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send): read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - async def ws_reader(): + async def ws_reader() -> None: try: async with read_stream_writer: async for msg in websocket.iter_text(): @@ -43,7 +48,7 @@ async def ws_reader(): except anyio.ClosedResourceError: # pragma: no cover await websocket.close() - async def ws_writer(): + async def ws_writer() -> None: try: async with write_stream_reader: async for session_message in write_stream_reader: diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index ca5b7b45a..c151ab104 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -23,12 +23,12 @@ def normalize_token_type(cls, v: str | None) -> str | None: class InvalidScopeError(Exception): - def __init__(self, message: str): + def __init__(self, message: str) -> None: self.message = message class InvalidRedirectUriError(Exception): - def __init__(self, message: str): + def __init__(self, message: str) -> None: self.message = message diff --git a/src/mcp/shared/exceptions.py b/src/mcp/shared/exceptions.py index f153ea319..ba203220c 100644 --- a/src/mcp/shared/exceptions.py +++ b/src/mcp/shared/exceptions.py @@ -10,7 +10,7 @@ class MCPError(Exception): error: ErrorData - def __init__(self, code: int, message: str, data: Any = None): + def __init__(self, code: int, message: str, data: Any = None) -> None: super().__init__(code, message, data) if data is not None: self.error = ErrorData(code=code, message=message, data=data) @@ -49,7 +49,7 @@ class StatelessModeNotSupported(RuntimeError): for bidirectional communication. """ - def __init__(self, method: str): + def __init__(self, method: str) -> None: super().__init__( f"Cannot use {method} in stateless HTTP mode. " "Stateless mode does not support server-to-client requests. " @@ -76,7 +76,7 @@ class UrlElicitationRequiredError(MCPError): ``` """ - def __init__(self, elicitations: list[ElicitRequestURLParams], message: str | None = None): + def __init__(self, elicitations: list[ElicitRequestURLParams], message: str | None = None) -> None: """Initialize UrlElicitationRequiredError.""" if message is None: message = f"URL elicitation{'s' if len(elicitations) > 1 else ''} required" diff --git a/src/mcp/shared/experimental/tasks/context.py b/src/mcp/shared/experimental/tasks/context.py index ed0d2b91b..124591a0e 100644 --- a/src/mcp/shared/experimental/tasks/context.py +++ b/src/mcp/shared/experimental/tasks/context.py @@ -31,7 +31,7 @@ async def worker_job(task_id: str): await ctx.complete(result) """ - def __init__(self, task: Task, store: TaskStore): + def __init__(self, task: Task, store: TaskStore) -> None: self._task = task self._store = store self._cancelled = False diff --git a/tests/cli/test_claude.py b/tests/cli/test_claude.py index 73d4f0eb5..c33489a74 100644 --- a/tests/cli/test_claude.py +++ b/tests/cli/test_claude.py @@ -24,7 +24,7 @@ def _read_server(config_dir: Path, name: str) -> dict[str, Any]: return config["mcpServers"][name] -def test_generates_uv_run_command(config_dir: Path): +def test_generates_uv_run_command(config_dir: Path) -> None: """Should write a uv run command that invokes mcp run on the resolved file spec.""" assert update_claude_config(file_spec="server.py:app", server_name="my_server") @@ -35,14 +35,14 @@ def test_generates_uv_run_command(config_dir: Path): } -def test_file_spec_without_object_suffix(config_dir: Path): +def test_file_spec_without_object_suffix(config_dir: Path) -> None: """File specs without :object should still resolve to an absolute path.""" assert update_claude_config(file_spec="server.py", server_name="s") assert _read_server(config_dir, "s")["args"][-1] == str(Path("server.py").resolve()) -def test_with_packages_sorted_and_deduplicated(config_dir: Path): +def test_with_packages_sorted_and_deduplicated(config_dir: Path) -> None: """Extra packages should appear as --with flags, sorted and deduplicated with mcp[cli].""" assert update_claude_config(file_spec="s.py:app", server_name="s", with_packages=["zebra", "aardvark", "zebra"]) @@ -50,7 +50,7 @@ def test_with_packages_sorted_and_deduplicated(config_dir: Path): assert args[:8] == ["run", "--frozen", "--with", "aardvark", "--with", "mcp[cli]", "--with", "zebra"] -def test_with_editable_adds_flag(config_dir: Path, tmp_path: Path): +def test_with_editable_adds_flag(config_dir: Path, tmp_path: Path) -> None: """with_editable should add --with-editable after the --with flags.""" editable = tmp_path / "project" assert update_claude_config(file_spec="s.py:app", server_name="s", with_editable=editable) @@ -59,14 +59,14 @@ def test_with_editable_adds_flag(config_dir: Path, tmp_path: Path): assert args[4:6] == ["--with-editable", str(editable)] -def test_env_vars_written(config_dir: Path): +def test_env_vars_written(config_dir: Path) -> None: """env_vars should be written under the server's env key.""" assert update_claude_config(file_spec="s.py:app", server_name="s", env_vars={"KEY": "val"}) assert _read_server(config_dir, "s")["env"] == {"KEY": "val"} -def test_existing_env_vars_merged_new_wins(config_dir: Path): +def test_existing_env_vars_merged_new_wins(config_dir: Path) -> None: """Re-installing should merge env vars, with new values overriding existing ones.""" (config_dir / "claude_desktop_config.json").write_text( json.dumps({"mcpServers": {"s": {"env": {"OLD": "keep", "KEY": "old"}}}}) @@ -77,7 +77,7 @@ def test_existing_env_vars_merged_new_wins(config_dir: Path): assert _read_server(config_dir, "s")["env"] == {"OLD": "keep", "KEY": "new"} -def test_existing_env_vars_preserved_without_new(config_dir: Path): +def test_existing_env_vars_preserved_without_new(config_dir: Path) -> None: """Re-installing without env_vars should keep the existing env block intact.""" (config_dir / "claude_desktop_config.json").write_text(json.dumps({"mcpServers": {"s": {"env": {"KEEP": "me"}}}})) @@ -86,7 +86,7 @@ def test_existing_env_vars_preserved_without_new(config_dir: Path): assert _read_server(config_dir, "s")["env"] == {"KEEP": "me"} -def test_other_servers_preserved(config_dir: Path): +def test_other_servers_preserved(config_dir: Path) -> None: """Installing a new server should not clobber existing mcpServers entries.""" (config_dir / "claude_desktop_config.json").write_text(json.dumps({"mcpServers": {"other": {"command": "x"}}})) @@ -97,7 +97,7 @@ def test_other_servers_preserved(config_dir: Path): assert config["mcpServers"]["other"] == {"command": "x"} -def test_raises_when_config_dir_missing(monkeypatch: pytest.MonkeyPatch): +def test_raises_when_config_dir_missing(monkeypatch: pytest.MonkeyPatch) -> None: """Should raise RuntimeError when Claude Desktop config dir can't be found.""" monkeypatch.setattr("mcp.cli.claude.get_claude_config_path", lambda: None) monkeypatch.setattr("mcp.cli.claude.get_uv_path", lambda: "/fake/bin/uv") @@ -107,7 +107,7 @@ def test_raises_when_config_dir_missing(monkeypatch: pytest.MonkeyPatch): @pytest.mark.parametrize("which_result, expected", [("/usr/local/bin/uv", "/usr/local/bin/uv"), (None, "uv")]) -def test_get_uv_path(monkeypatch: pytest.MonkeyPatch, which_result: str | None, expected: str): +def test_get_uv_path(monkeypatch: pytest.MonkeyPatch, which_result: str | None, expected: str) -> None: """Should return shutil.which's result, or fall back to bare 'uv' when not on PATH.""" def fake_which(cmd: str) -> str | None: @@ -126,7 +126,7 @@ def fake_which(cmd: str) -> str | None: ) def test_windows_drive_letter_not_split( config_dir: Path, monkeypatch: pytest.MonkeyPatch, file_spec: str, expected_last_arg: str -): +) -> None: """Drive-letter paths like 'C:\\server.py' must not be split on the drive colon. Before the fix, a bare 'C:\\path\\server.py' would hit rsplit(":", 1) and yield diff --git a/tests/cli/test_utils.py b/tests/cli/test_utils.py index 44f4ab4d3..50b2646c5 100644 --- a/tests/cli/test_utils.py +++ b/tests/cli/test_utils.py @@ -15,7 +15,7 @@ ("foo.py:srv_obj", "srv_obj"), ], ) -def test_parse_file_path_accepts_valid_specs(tmp_path: Path, spec: str, expected_obj: str | None): +def test_parse_file_path_accepts_valid_specs(tmp_path: Path, spec: str, expected_obj: str | None) -> None: """Should accept valid file specs.""" file = tmp_path / spec.split(":")[0] file.write_text("x = 1") @@ -24,13 +24,13 @@ def test_parse_file_path_accepts_valid_specs(tmp_path: Path, spec: str, expected assert obj == expected_obj -def test_parse_file_path_missing(tmp_path: Path): +def test_parse_file_path_missing(tmp_path: Path) -> None: """Should system exit if a file is missing.""" with pytest.raises(SystemExit): _parse_file_path(str(tmp_path / "missing.py")) -def test_parse_file_exit_on_dir(tmp_path: Path): +def test_parse_file_exit_on_dir(tmp_path: Path) -> None: """Should system exit if a directory is passed""" dir_path = tmp_path / "dir" dir_path.mkdir() @@ -38,13 +38,13 @@ def test_parse_file_exit_on_dir(tmp_path: Path): _parse_file_path(str(dir_path)) -def test_build_uv_command_minimal(): +def test_build_uv_command_minimal() -> None: """Should emit core command when no extras specified.""" cmd = _build_uv_command("foo.py") assert cmd == ["uv", "run", "--with", "mcp", "mcp", "run", "foo.py"] -def test_build_uv_command_adds_editable_and_packages(): +def test_build_uv_command_adds_editable_and_packages() -> None: """Should include --with-editable and every --with pkg in correct order.""" test_path = Path("/pkg") cmd = _build_uv_command( @@ -69,13 +69,13 @@ def test_build_uv_command_adds_editable_and_packages(): ] -def test_get_npx_unix_like(monkeypatch: pytest.MonkeyPatch): +def test_get_npx_unix_like(monkeypatch: pytest.MonkeyPatch) -> None: """Should return "npx" on unix-like systems.""" monkeypatch.setattr(sys, "platform", "linux") assert _get_npx_command() == "npx" -def test_get_npx_windows(monkeypatch: pytest.MonkeyPatch): +def test_get_npx_windows(monkeypatch: pytest.MonkeyPatch) -> None: """Should return one of the npx candidates on Windows.""" candidates = ["npx.cmd", "npx.exe", "npx"] @@ -90,7 +90,7 @@ def fake_run(cmd: list[str], **kw: Any) -> subprocess.CompletedProcess[bytes]: assert _get_npx_command() in candidates -def test_get_npx_returns_none_when_npx_missing(monkeypatch: pytest.MonkeyPatch): +def test_get_npx_returns_none_when_npx_missing(monkeypatch: pytest.MonkeyPatch) -> None: """Should give None if every candidate fails.""" monkeypatch.setattr(sys, "platform", "win32", raising=False) diff --git a/tests/client/auth/extensions/test_client_credentials.py b/tests/client/auth/extensions/test_client_credentials.py index 09760f453..b75b08453 100644 --- a/tests/client/auth/extensions/test_client_credentials.py +++ b/tests/client/auth/extensions/test_client_credentials.py @@ -19,7 +19,7 @@ class MockTokenStorage: """Mock token storage for testing.""" - def __init__(self): + def __init__(self) -> None: self._tokens: OAuthToken | None = None self._client_info: OAuthClientInformationFull | None = None @@ -37,12 +37,12 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None @pytest.fixture -def mock_storage(): +def mock_storage() -> MockTokenStorage: return MockTokenStorage() @pytest.fixture -def client_metadata(): +def client_metadata() -> OAuthClientMetadata: return OAuthClientMetadata( client_name="Test Client", client_uri=AnyHttpUrl("https://example.com"), @@ -52,7 +52,9 @@ def client_metadata(): @pytest.fixture -def rfc7523_oauth_provider(client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage): +def rfc7523_oauth_provider( + client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage +) -> RFC7523OAuthClientProvider: async def redirect_handler(url: str) -> None: # pragma: no cover """Mock redirect handler.""" pass @@ -76,7 +78,9 @@ class TestOAuthFlowClientCredentials: """Test OAuth flow behavior for client credentials flows.""" @pytest.mark.anyio - async def test_token_exchange_request_jwt_predefined(self, rfc7523_oauth_provider: RFC7523OAuthClientProvider): + async def test_token_exchange_request_jwt_predefined( + self, rfc7523_oauth_provider: RFC7523OAuthClientProvider + ) -> None: """Test token exchange request building with a predefined JWT assertion.""" # Set up required context rfc7523_oauth_provider.context.client_info = OAuthClientInformationFull( @@ -115,7 +119,7 @@ async def test_token_exchange_request_jwt_predefined(self, rfc7523_oauth_provide ) @pytest.mark.anyio - async def test_token_exchange_request_jwt(self, rfc7523_oauth_provider: RFC7523OAuthClientProvider): + async def test_token_exchange_request_jwt(self, rfc7523_oauth_provider: RFC7523OAuthClientProvider) -> None: """Test token exchange request building wiith a generated JWT assertion.""" # Set up required context rfc7523_oauth_provider.context.client_info = OAuthClientInformationFull( @@ -177,7 +181,7 @@ class TestClientCredentialsOAuthProvider: """Test ClientCredentialsOAuthProvider.""" @pytest.mark.anyio - async def test_init_sets_client_info(self, mock_storage: MockTokenStorage): + async def test_init_sets_client_info(self, mock_storage: MockTokenStorage) -> None: """Test that _initialize sets client_info.""" provider = ClientCredentialsOAuthProvider( server_url="https://api.example.com", @@ -196,7 +200,7 @@ async def test_init_sets_client_info(self, mock_storage: MockTokenStorage): assert provider.context.client_info.token_endpoint_auth_method == "client_secret_basic" @pytest.mark.anyio - async def test_init_with_scopes(self, mock_storage: MockTokenStorage): + async def test_init_with_scopes(self, mock_storage: MockTokenStorage) -> None: """Test that constructor accepts scopes.""" provider = ClientCredentialsOAuthProvider( server_url="https://api.example.com", @@ -211,7 +215,7 @@ async def test_init_with_scopes(self, mock_storage: MockTokenStorage): assert provider.context.client_info.scope == "read write" @pytest.mark.anyio - async def test_init_with_client_secret_post(self, mock_storage: MockTokenStorage): + async def test_init_with_client_secret_post(self, mock_storage: MockTokenStorage) -> None: """Test that constructor accepts client_secret_post auth method.""" provider = ClientCredentialsOAuthProvider( server_url="https://api.example.com", @@ -226,7 +230,7 @@ async def test_init_with_client_secret_post(self, mock_storage: MockTokenStorage assert provider.context.client_info.token_endpoint_auth_method == "client_secret_post" @pytest.mark.anyio - async def test_exchange_token_client_credentials(self, mock_storage: MockTokenStorage): + async def test_exchange_token_client_credentials(self, mock_storage: MockTokenStorage) -> None: """Test token exchange request building.""" provider = ClientCredentialsOAuthProvider( server_url="https://api.example.com/v1/mcp", @@ -253,7 +257,7 @@ async def test_exchange_token_client_credentials(self, mock_storage: MockTokenSt assert "resource=https://api.example.com/v1/mcp" in content @pytest.mark.anyio - async def test_exchange_token_client_secret_post_includes_client_id(self, mock_storage: MockTokenStorage): + async def test_exchange_token_client_secret_post_includes_client_id(self, mock_storage: MockTokenStorage) -> None: """Test that client_secret_post includes both client_id and client_secret in body (RFC 6749 §2.3.1).""" provider = ClientCredentialsOAuthProvider( server_url="https://api.example.com/v1/mcp", @@ -281,7 +285,7 @@ async def test_exchange_token_client_secret_post_includes_client_id(self, mock_s assert "Authorization" not in request.headers @pytest.mark.anyio - async def test_exchange_token_client_secret_post_without_client_id(self, mock_storage: MockTokenStorage): + async def test_exchange_token_client_secret_post_without_client_id(self, mock_storage: MockTokenStorage) -> None: """Test client_secret_post skips body credentials when client_id is None.""" provider = ClientCredentialsOAuthProvider( server_url="https://api.example.com/v1/mcp", @@ -319,7 +323,7 @@ async def test_exchange_token_client_secret_post_without_client_id(self, mock_st assert "Authorization" not in request.headers @pytest.mark.anyio - async def test_exchange_token_without_scopes(self, mock_storage: MockTokenStorage): + async def test_exchange_token_without_scopes(self, mock_storage: MockTokenStorage) -> None: """Test token exchange without scopes.""" provider = ClientCredentialsOAuthProvider( server_url="https://api.example.com/v1/mcp", @@ -346,7 +350,7 @@ class TestPrivateKeyJWTOAuthProvider: """Test PrivateKeyJWTOAuthProvider.""" @pytest.mark.anyio - async def test_init_sets_client_info(self, mock_storage: MockTokenStorage): + async def test_init_sets_client_info(self, mock_storage: MockTokenStorage) -> None: """Test that _initialize sets client_info.""" async def mock_assertion_provider(audience: str) -> str: # pragma: no cover @@ -368,7 +372,7 @@ async def mock_assertion_provider(audience: str) -> str: # pragma: no cover assert provider.context.client_info.token_endpoint_auth_method == "private_key_jwt" @pytest.mark.anyio - async def test_exchange_token_client_credentials(self, mock_storage: MockTokenStorage): + async def test_exchange_token_client_credentials(self, mock_storage: MockTokenStorage) -> None: """Test token exchange request building with assertion provider.""" async def mock_assertion_provider(audience: str) -> str: @@ -400,7 +404,7 @@ async def mock_assertion_provider(audience: str) -> str: assert "scope=read write" in content @pytest.mark.anyio - async def test_exchange_token_without_scopes(self, mock_storage: MockTokenStorage): + async def test_exchange_token_without_scopes(self, mock_storage: MockTokenStorage) -> None: """Test token exchange without scopes.""" async def mock_assertion_provider(audience: str) -> str: @@ -431,7 +435,7 @@ class TestSignedJWTParameters: """Test SignedJWTParameters.""" @pytest.mark.anyio - async def test_create_assertion_provider(self): + async def test_create_assertion_provider(self) -> None: """Test that create_assertion_provider creates valid JWTs.""" params = SignedJWTParameters( issuer="test-issuer", @@ -458,7 +462,7 @@ async def test_create_assertion_provider(self): assert "jti" in claims @pytest.mark.anyio - async def test_create_assertion_provider_with_additional_claims(self): + async def test_create_assertion_provider_with_additional_claims(self) -> None: """Test that additional_claims are included in the JWT.""" params = SignedJWTParameters( issuer="test-issuer", @@ -484,7 +488,7 @@ class TestStaticAssertionProvider: """Test static_assertion_provider helper.""" @pytest.mark.anyio - async def test_returns_static_token(self): + async def test_returns_static_token(self) -> None: """Test that static_assertion_provider returns the same token regardless of audience.""" token = "my-static-jwt-token" provider = static_assertion_provider(token) diff --git a/tests/client/conftest.py b/tests/client/conftest.py index 2e39f1363..a7f580b86 100644 --- a/tests/client/conftest.py +++ b/tests/client/conftest.py @@ -1,10 +1,10 @@ -from collections.abc import Callable, Generator +from collections.abc import AsyncGenerator, Callable, Generator from contextlib import asynccontextmanager from typing import Any from unittest.mock import patch import pytest -from anyio.streams.memory import MemoryObjectSendStream +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream import mcp.shared.memory from mcp.shared.message import SessionMessage @@ -12,26 +12,26 @@ class SpyMemoryObjectSendStream: - def __init__(self, original_stream: MemoryObjectSendStream[SessionMessage]): + def __init__(self, original_stream: MemoryObjectSendStream[SessionMessage]) -> None: self.original_stream = original_stream self.sent_messages: list[SessionMessage] = [] - async def send(self, message: SessionMessage): + async def send(self, message: SessionMessage) -> None: self.sent_messages.append(message) await self.original_stream.send(message) - async def aclose(self): + async def aclose(self) -> None: await self.original_stream.aclose() - async def __aenter__(self): + async def __aenter__(self) -> "SpyMemoryObjectSendStream": return self - async def __aexit__(self, *args: Any): + async def __aexit__(self, *args: Any) -> None: await self.aclose() class StreamSpyCollection: - def __init__(self, client_spy: SpyMemoryObjectSendStream, server_spy: SpyMemoryObjectSendStream): + def __init__(self, client_spy: SpyMemoryObjectSendStream, server_spy: SpyMemoryObjectSendStream) -> None: self.client = client_spy self.server = server_spy @@ -99,7 +99,7 @@ async def test_something(stream_spy): server_spy = None # Store references to our spy objects - def capture_spies(c_spy: SpyMemoryObjectSendStream, s_spy: SpyMemoryObjectSendStream): + def capture_spies(c_spy: SpyMemoryObjectSendStream, s_spy: SpyMemoryObjectSendStream) -> None: nonlocal client_spy, server_spy client_spy = c_spy server_spy = s_spy @@ -108,7 +108,13 @@ def capture_spies(c_spy: SpyMemoryObjectSendStream, s_spy: SpyMemoryObjectSendSt original_create_streams = mcp.shared.memory.create_client_server_memory_streams @asynccontextmanager - async def patched_create_streams(): + async def patched_create_streams() -> AsyncGenerator[ + tuple[ + tuple[MemoryObjectReceiveStream[SessionMessage | Exception], SpyMemoryObjectSendStream], + tuple[MemoryObjectReceiveStream[SessionMessage | Exception], SpyMemoryObjectSendStream], + ], + None, + ]: async with original_create_streams() as (client_streams, server_streams): client_read, client_write = client_streams server_read, server_write = server_streams diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 5aa985e36..964a5b333 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -40,7 +40,7 @@ class MockTokenStorage: """Mock token storage for testing.""" - def __init__(self): + def __init__(self) -> None: self._tokens: OAuthToken | None = None self._client_info: OAuthClientInformationFull | None = None @@ -58,12 +58,12 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None @pytest.fixture -def mock_storage(): +def mock_storage() -> MockTokenStorage: return MockTokenStorage() @pytest.fixture -def client_metadata(): +def client_metadata() -> OAuthClientMetadata: return OAuthClientMetadata( client_name="Test Client", client_uri=AnyHttpUrl("https://example.com"), @@ -73,7 +73,7 @@ def client_metadata(): @pytest.fixture -def valid_tokens(): +def valid_tokens() -> OAuthToken: return OAuthToken( access_token="test_access_token", token_type="Bearer", @@ -84,7 +84,7 @@ def valid_tokens(): @pytest.fixture -def oauth_provider(client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage): +def oauth_provider(client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage) -> OAuthClientProvider: async def redirect_handler(url: str) -> None: """Mock redirect handler.""" pass # pragma: no cover @@ -103,7 +103,7 @@ async def callback_handler() -> tuple[str, str | None]: @pytest.fixture -def prm_metadata_response(): +def prm_metadata_response() -> httpx.Response: """PRM metadata response with scopes.""" return httpx.Response( 200, @@ -116,7 +116,7 @@ def prm_metadata_response(): @pytest.fixture -def prm_metadata_without_scopes_response(): +def prm_metadata_without_scopes_response() -> httpx.Response: """PRM metadata response without scopes.""" return httpx.Response( 200, @@ -129,7 +129,7 @@ def prm_metadata_without_scopes_response(): @pytest.fixture -def init_response_with_www_auth_scope(): +def init_response_with_www_auth_scope() -> httpx.Response: """Initial 401 response with WWW-Authenticate header containing scope.""" return httpx.Response( 401, @@ -139,7 +139,7 @@ def init_response_with_www_auth_scope(): @pytest.fixture -def init_response_without_www_auth_scope(): +def init_response_without_www_auth_scope() -> httpx.Response: """Initial 401 response without WWW-Authenticate scope.""" return httpx.Response( 401, @@ -151,7 +151,7 @@ def init_response_without_www_auth_scope(): class TestPKCEParameters: """Test PKCE parameter generation.""" - def test_pkce_generation(self): + def test_pkce_generation(self) -> None: """Test PKCE parameter generation creates valid values.""" pkce = PKCEParameters.generate() @@ -166,7 +166,7 @@ def test_pkce_generation(self): # Verify base64url encoding in challenge (no padding) assert "=" not in pkce.code_challenge - def test_pkce_uniqueness(self): + def test_pkce_uniqueness(self) -> None: """Test PKCE generates unique values each time.""" pkce1 = PKCEParameters.generate() pkce2 = PKCEParameters.generate() @@ -181,7 +181,7 @@ class TestOAuthContext: @pytest.mark.anyio async def test_oauth_provider_initialization( self, oauth_provider: OAuthClientProvider, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage - ): + ) -> None: """Test OAuthClientProvider basic setup.""" assert oauth_provider.context.server_url == "https://api.example.com/v1/mcp" assert oauth_provider.context.client_metadata == client_metadata @@ -189,7 +189,7 @@ async def test_oauth_provider_initialization( assert oauth_provider.context.timeout == 300.0 assert oauth_provider.context is not None - def test_context_url_parsing(self, oauth_provider: OAuthClientProvider): + def test_context_url_parsing(self, oauth_provider: OAuthClientProvider) -> None: """Test get_authorization_base_url() extracts base URLs correctly.""" context = oauth_provider.context @@ -211,7 +211,7 @@ def test_context_url_parsing(self, oauth_provider: OAuthClientProvider): ) @pytest.mark.anyio - async def test_token_validity_checking(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken): + async def test_token_validity_checking(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken) -> None: """Test is_token_valid() and can_refresh_token() logic.""" context = oauth_provider.context @@ -246,7 +246,7 @@ async def test_token_validity_checking(self, oauth_provider: OAuthClientProvider context.client_info = None assert not context.can_refresh_token() - def test_clear_tokens(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken): + def test_clear_tokens(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken) -> None: """Test clear_tokens() removes token data.""" context = oauth_provider.context context.current_tokens = valid_tokens @@ -266,7 +266,7 @@ class TestOAuthFlow: @pytest.mark.anyio async def test_build_protected_resource_discovery_urls( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage - ): + ) -> None: """Test protected resource metadata discovery URL building with fallback.""" async def redirect_handler(url: str) -> None: @@ -307,7 +307,7 @@ async def callback_handler() -> tuple[str, str | None]: assert urls[1] == "https://api.example.com/.well-known/oauth-protected-resource" @pytest.mark.anyio - def test_create_oauth_metadata_request(self, oauth_provider: OAuthClientProvider): + def test_create_oauth_metadata_request(self, oauth_provider: OAuthClientProvider) -> None: """Test OAuth metadata discovery request building.""" request = create_oauth_metadata_request("https://example.com") @@ -321,7 +321,7 @@ class TestOAuthFallback: """Test OAuth discovery fallback behavior for legacy (act as AS not RS) servers.""" @pytest.mark.anyio - async def test_oauth_discovery_legacy_fallback_when_no_prm(self): + async def test_oauth_discovery_legacy_fallback_when_no_prm(self) -> None: """Test that when PRM discovery fails, only root OAuth URL is tried (March 2025 spec).""" # When auth_server_url is None (PRM failed), we use server_url and only try root discovery_urls = build_oauth_authorization_server_metadata_discovery_urls(None, "https://mcp.linear.app/sse") @@ -332,7 +332,7 @@ async def test_oauth_discovery_legacy_fallback_when_no_prm(self): ] @pytest.mark.anyio - async def test_oauth_discovery_path_aware_when_auth_server_has_path(self): + async def test_oauth_discovery_path_aware_when_auth_server_has_path(self) -> None: """Test that when auth server URL has a path, only path-based URLs are tried.""" discovery_urls = build_oauth_authorization_server_metadata_discovery_urls( "https://auth.example.com/tenant1", "https://api.example.com/mcp" @@ -346,7 +346,7 @@ async def test_oauth_discovery_path_aware_when_auth_server_has_path(self): ] @pytest.mark.anyio - async def test_oauth_discovery_root_when_auth_server_has_no_path(self): + async def test_oauth_discovery_root_when_auth_server_has_no_path(self) -> None: """Test that when auth server URL has no path, only root URLs are tried.""" discovery_urls = build_oauth_authorization_server_metadata_discovery_urls( "https://auth.example.com", "https://api.example.com/mcp" @@ -359,7 +359,7 @@ async def test_oauth_discovery_root_when_auth_server_has_no_path(self): ] @pytest.mark.anyio - async def test_oauth_discovery_root_when_auth_server_has_only_slash(self): + async def test_oauth_discovery_root_when_auth_server_has_only_slash(self) -> None: """Test that when auth server URL has only trailing slash, treated as root.""" discovery_urls = build_oauth_authorization_server_metadata_discovery_urls( "https://auth.example.com/", "https://api.example.com/mcp" @@ -372,7 +372,7 @@ async def test_oauth_discovery_root_when_auth_server_has_only_slash(self): ] @pytest.mark.anyio - async def test_oauth_discovery_fallback_order(self, oauth_provider: OAuthClientProvider): + async def test_oauth_discovery_fallback_order(self, oauth_provider: OAuthClientProvider) -> None: """Test fallback URL construction order when auth server URL has a path.""" # Simulate PRM discovery returning an auth server URL with a path oauth_provider.context.auth_server_url = oauth_provider.context.server_url @@ -388,7 +388,7 @@ async def test_oauth_discovery_fallback_order(self, oauth_provider: OAuthClientP ] @pytest.mark.anyio - async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthClientProvider): + async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthClientProvider) -> None: """Test the conditions during which an AS metadata discovery fallback will be attempted.""" # Ensure no tokens are stored oauth_provider.context.current_tokens = None @@ -507,7 +507,7 @@ async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthCl pass # Expected - generator should complete @pytest.mark.anyio - async def test_handle_metadata_response_success(self, oauth_provider: OAuthClientProvider): + async def test_handle_metadata_response_success(self, oauth_provider: OAuthClientProvider) -> None: """Test successful metadata response handling.""" # Create minimal valid OAuth metadata content = b"""{ @@ -528,7 +528,7 @@ async def test_prioritize_www_auth_scope_over_prm( oauth_provider: OAuthClientProvider, prm_metadata_response: httpx.Response, init_response_with_www_auth_scope: httpx.Response, - ): + ) -> None: """Test that WWW-Authenticate scope is prioritized over PRM scopes.""" # First, process PRM metadata to set protected_resource_metadata with scopes await oauth_provider._handle_protected_resource_response(prm_metadata_response) @@ -548,7 +548,7 @@ async def test_prioritize_prm_scopes_when_no_www_auth_scope( oauth_provider: OAuthClientProvider, prm_metadata_response: httpx.Response, init_response_without_www_auth_scope: httpx.Response, - ): + ) -> None: """Test that PRM scopes are prioritized when WWW-Authenticate header has no scopes.""" # Process the PRM metadata to set protected_resource_metadata with scopes await oauth_provider._handle_protected_resource_response(prm_metadata_response) @@ -568,7 +568,7 @@ async def test_omit_scope_when_no_prm_scopes_or_www_auth( oauth_provider: OAuthClientProvider, prm_metadata_without_scopes_response: httpx.Response, init_response_without_www_auth_scope: httpx.Response, - ): + ) -> None: """Test that scope is omitted when PRM has no scopes and WWW-Authenticate doesn't specify scope.""" # Process the PRM metadata without scopes await oauth_provider._handle_protected_resource_response(prm_metadata_without_scopes_response) @@ -582,7 +582,7 @@ async def test_omit_scope_when_no_prm_scopes_or_www_auth( assert scopes is None @pytest.mark.anyio - async def test_token_exchange_request_authorization_code(self, oauth_provider: OAuthClientProvider): + async def test_token_exchange_request_authorization_code(self, oauth_provider: OAuthClientProvider) -> None: """Test token exchange request building.""" # Set up required context oauth_provider.context.client_info = OAuthClientInformationFull( @@ -607,7 +607,7 @@ async def test_token_exchange_request_authorization_code(self, oauth_provider: O assert "client_secret=test_secret" in content @pytest.mark.anyio - async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken): + async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken) -> None: """Test refresh token request building.""" # Set up required context oauth_provider.context.current_tokens = valid_tokens @@ -632,7 +632,7 @@ async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider, assert "client_secret=test_secret" in content @pytest.mark.anyio - async def test_basic_auth_token_exchange(self, oauth_provider: OAuthClientProvider): + async def test_basic_auth_token_exchange(self, oauth_provider: OAuthClientProvider) -> None: """Test token exchange with client_secret_basic authentication.""" # Set up OAuth metadata to support basic auth oauth_provider.context.oauth_metadata = OAuthMetadata( @@ -677,7 +677,9 @@ async def test_basic_auth_token_exchange(self, oauth_provider: OAuthClientProvid assert "client_id=test%40client" in content # client_id still in body @pytest.mark.anyio - async def test_basic_auth_refresh_token(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken): + async def test_basic_auth_refresh_token( + self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken + ) -> None: """Test token refresh with client_secret_basic authentication.""" oauth_provider.context.current_tokens = valid_tokens @@ -712,7 +714,7 @@ async def test_basic_auth_refresh_token(self, oauth_provider: OAuthClientProvide assert "client_secret=" not in content @pytest.mark.anyio - async def test_none_auth_method(self, oauth_provider: OAuthClientProvider): + async def test_none_auth_method(self, oauth_provider: OAuthClientProvider) -> None: """Test 'none' authentication method (public client).""" oauth_provider.context.oauth_metadata = OAuthMetadata( issuer=AnyHttpUrl("https://auth.example.com"), @@ -744,7 +746,9 @@ class TestProtectedResourceMetadata: """Test protected resource handling.""" @pytest.mark.anyio - async def test_resource_param_included_with_recent_protocol_version(self, oauth_provider: OAuthClientProvider): + async def test_resource_param_included_with_recent_protocol_version( + self, oauth_provider: OAuthClientProvider + ) -> None: """Test resource parameter is included for protocol version >= 2025-06-18.""" # Set protocol version to 2025-06-18 oauth_provider.context.protocol_version = "2025-06-18" @@ -773,7 +777,7 @@ async def test_resource_param_included_with_recent_protocol_version(self, oauth_ assert "resource=" in refresh_content @pytest.mark.anyio - async def test_resource_param_excluded_with_old_protocol_version(self, oauth_provider: OAuthClientProvider): + async def test_resource_param_excluded_with_old_protocol_version(self, oauth_provider: OAuthClientProvider) -> None: """Test resource parameter is excluded for protocol version < 2025-06-18.""" # Set protocol version to older version oauth_provider.context.protocol_version = "2025-03-26" @@ -799,7 +803,9 @@ async def test_resource_param_excluded_with_old_protocol_version(self, oauth_pro assert "resource=" not in refresh_content @pytest.mark.anyio - async def test_resource_param_included_with_protected_resource_metadata(self, oauth_provider: OAuthClientProvider): + async def test_resource_param_included_with_protected_resource_metadata( + self, oauth_provider: OAuthClientProvider + ) -> None: """Test resource parameter is always included when protected resource metadata exists.""" # Set old protocol version but with protected resource metadata oauth_provider.context.protocol_version = "2025-03-26" @@ -953,22 +959,22 @@ class TestRegistrationResponse: """Test client registration response handling.""" @pytest.mark.anyio - async def test_handle_registration_response_reads_before_accessing_text(self): + async def test_handle_registration_response_reads_before_accessing_text(self) -> None: """Test that response.aread() is called before accessing response.text.""" # Track if aread() was called class MockResponse(httpx.Response): - def __init__(self): + def __init__(self) -> None: self.status_code = 400 self._aread_called = False self._text = "Registration failed with error" - async def aread(self): + async def aread(self) -> bytes: self._aread_called = True return b"test content" @property - def text(self): + def text(self) -> str: if not self._aread_called: raise RuntimeError("Response.text accessed before response.aread()") # pragma: no cover return self._text @@ -988,7 +994,7 @@ def text(self): class TestCreateClientRegistrationRequest: """Test client registration request creation.""" - def test_uses_registration_endpoint_from_metadata(self): + def test_uses_registration_endpoint_from_metadata(self) -> None: """Test that registration URL comes from metadata when available.""" oauth_metadata = OAuthMetadata( issuer=AnyHttpUrl("https://auth.example.com"), @@ -1003,7 +1009,7 @@ def test_uses_registration_endpoint_from_metadata(self): assert str(request.url) == "https://auth.example.com/register" assert request.method == "POST" - def test_falls_back_to_default_register_endpoint_when_no_metadata(self): + def test_falls_back_to_default_register_endpoint_when_no_metadata(self) -> None: """Test that registration uses fallback URL when auth_server_metadata is None.""" client_metadata = OAuthClientMetadata(redirect_uris=[AnyHttpUrl("http://localhost:3000/callback")]) @@ -1012,7 +1018,7 @@ def test_falls_back_to_default_register_endpoint_when_no_metadata(self): assert str(request.url) == "https://auth.example.com/register" assert request.method == "POST" - def test_falls_back_when_metadata_has_no_registration_endpoint(self): + def test_falls_back_when_metadata_has_no_registration_endpoint(self) -> None: """Test fallback when metadata exists but lacks registration_endpoint.""" oauth_metadata = OAuthMetadata( issuer=AnyHttpUrl("https://auth.example.com"), @@ -1034,7 +1040,7 @@ class TestAuthFlow: @pytest.mark.anyio async def test_auth_flow_with_valid_tokens( self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken - ): + ) -> None: """Test auth flow when tokens are already valid.""" # Pre-store valid tokens await mock_storage.set_tokens(valid_tokens) @@ -1060,7 +1066,9 @@ async def test_auth_flow_with_valid_tokens( pass # Expected @pytest.mark.anyio - async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage): + async def test_auth_flow_with_no_tokens( + self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage + ) -> None: """Test auth flow when no tokens are available, triggering the full OAuth flow.""" # Ensure no tokens are stored oauth_provider.context.current_tokens = None @@ -1170,7 +1178,7 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide @pytest.mark.anyio async def test_auth_flow_no_unnecessary_retry_after_oauth( self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken - ): + ) -> None: """Test that requests are not retried unnecessarily - the core bug that caused 2x performance degradation.""" # Pre-store valid tokens so no OAuth flow is needed await mock_storage.set_tokens(valid_tokens) @@ -1213,7 +1221,7 @@ async def test_auth_flow_no_unnecessary_retry_after_oauth( @pytest.mark.anyio async def test_token_exchange_accepts_201_status( self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage - ): + ) -> None: """Test that token exchange accepts both 200 and 201 status codes.""" # Ensure no tokens are stored oauth_provider.context.current_tokens = None @@ -1326,7 +1334,7 @@ async def test_403_insufficient_scope_updates_scope_from_header( oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken, - ): + ) -> None: """Test that 403 response correctly updates scope from WWW-Authenticate header.""" # Pre-store valid tokens and client info client_info = OAuthClientInformationFull( @@ -1462,7 +1470,7 @@ def test_build_metadata( token_endpoint: str, registration_endpoint: str, revocation_endpoint: str, -): +) -> None: metadata = build_metadata( issuer_url=AnyHttpUrl(issuer_url), service_documentation_url=AnyHttpUrl(service_documentation_url), @@ -1493,7 +1501,7 @@ class TestLegacyServerFallback: @pytest.mark.anyio async def test_legacy_server_no_prm_falls_back_to_root_oauth_discovery( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage - ): + ) -> None: """Test that when PRM discovery fails completely, we fall back to root OAuth discovery (March 2025 spec).""" async def redirect_handler(url: str) -> None: @@ -1592,7 +1600,7 @@ async def callback_handler() -> tuple[str, str | None]: @pytest.mark.anyio async def test_legacy_server_with_different_prm_and_root_urls( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage - ): + ) -> None: """Test PRM fallback with different WWW-Authenticate and root URLs.""" async def redirect_handler(url: str) -> None: @@ -1697,7 +1705,7 @@ class TestSEP985Discovery: @pytest.mark.anyio async def test_path_based_fallback_when_no_www_authenticate( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage - ): + ) -> None: """Test that client falls back to path-based well-known URI when WWW-Authenticate is absent.""" async def redirect_handler(url: str) -> None: @@ -1732,7 +1740,7 @@ async def callback_handler() -> tuple[str, str | None]: @pytest.mark.anyio async def test_root_based_fallback_after_path_based_404( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage - ): + ) -> None: """Test that client falls back to root-based URI when path-based returns 404.""" async def redirect_handler(url: str) -> None: @@ -1833,7 +1841,7 @@ async def callback_handler() -> tuple[str, str | None]: @pytest.mark.anyio async def test_www_authenticate_takes_priority_over_well_known( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage - ): + ) -> None: """Test that WWW-Authenticate header resource_metadata takes priority over well-known URIs.""" async def redirect_handler(url: str) -> None: @@ -1927,7 +1935,7 @@ def test_extract_field_from_www_auth_valid_cases( www_auth_header: str, field_name: str, expected_value: str, - ): + ) -> None: """Test extraction of various fields from valid WWW-Authenticate headers.""" init_response = httpx.Response( @@ -1961,7 +1969,7 @@ def test_extract_field_from_www_auth_invalid_cases( www_auth_header: str | None, field_name: str, description: str, - ): + ) -> None: """Test extraction returns None for invalid cases.""" headers = {"WWW-Authenticate": www_auth_header} if www_auth_header is not None else {} @@ -1996,11 +2004,11 @@ class TestCIMD: ("http://[::1/foo/", False), ], ) - def test_is_valid_client_metadata_url(self, url: str | None, expected: bool): + def test_is_valid_client_metadata_url(self, url: str | None, expected: bool) -> None: """Test CIMD URL validation.""" assert is_valid_client_metadata_url(url) == expected - def test_should_use_client_metadata_url_when_server_supports(self): + def test_should_use_client_metadata_url_when_server_supports(self) -> None: """Test that CIMD is used when server supports it and URL is provided.""" oauth_metadata = OAuthMetadata( issuer=AnyHttpUrl("https://auth.example.com"), @@ -2010,7 +2018,7 @@ def test_should_use_client_metadata_url_when_server_supports(self): ) assert should_use_client_metadata_url(oauth_metadata, "https://example.com/client") is True - def test_should_not_use_client_metadata_url_when_server_does_not_support(self): + def test_should_not_use_client_metadata_url_when_server_does_not_support(self) -> None: """Test that CIMD is not used when server doesn't support it.""" oauth_metadata = OAuthMetadata( issuer=AnyHttpUrl("https://auth.example.com"), @@ -2020,7 +2028,7 @@ def test_should_not_use_client_metadata_url_when_server_does_not_support(self): ) assert should_use_client_metadata_url(oauth_metadata, "https://example.com/client") is False - def test_should_not_use_client_metadata_url_when_not_provided(self): + def test_should_not_use_client_metadata_url_when_not_provided(self) -> None: """Test that CIMD is not used when no URL is provided.""" oauth_metadata = OAuthMetadata( issuer=AnyHttpUrl("https://auth.example.com"), @@ -2030,11 +2038,11 @@ def test_should_not_use_client_metadata_url_when_not_provided(self): ) assert should_use_client_metadata_url(oauth_metadata, None) is False - def test_should_not_use_client_metadata_url_when_no_metadata(self): + def test_should_not_use_client_metadata_url_when_no_metadata(self) -> None: """Test that CIMD is not used when OAuth metadata is None.""" assert should_use_client_metadata_url(None, "https://example.com/client") is False - def test_create_client_info_from_metadata_url(self): + def test_create_client_info_from_metadata_url(self) -> None: """Test creating client info from CIMD URL.""" client_info = create_client_info_from_metadata_url( "https://example.com/client", @@ -2047,7 +2055,7 @@ def test_create_client_info_from_metadata_url(self): def test_oauth_provider_with_valid_client_metadata_url( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage - ): + ) -> None: """Test OAuthClientProvider initialization with valid client_metadata_url.""" async def redirect_handler(url: str) -> None: @@ -2068,7 +2076,7 @@ async def callback_handler() -> tuple[str, str | None]: def test_oauth_provider_with_invalid_client_metadata_url_raises_error( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage - ): + ) -> None: """Test OAuthClientProvider raises error for invalid client_metadata_url.""" async def redirect_handler(url: str) -> None: @@ -2091,7 +2099,7 @@ async def callback_handler() -> tuple[str, str | None]: @pytest.mark.anyio async def test_auth_flow_uses_cimd_when_server_supports( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage - ): + ) -> None: """Test that auth flow uses CIMD URL as client_id when server supports it.""" async def redirect_handler(url: str) -> None: @@ -2182,7 +2190,7 @@ async def callback_handler() -> tuple[str, str | None]: @pytest.mark.anyio async def test_auth_flow_falls_back_to_dcr_when_no_cimd_support( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage - ): + ) -> None: """Test that auth flow falls back to DCR when server doesn't support CIMD.""" async def redirect_handler(url: str) -> None: diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 18368e6bb..62c454fde 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -96,7 +96,7 @@ def greeting_prompt(name: str) -> str: return server -async def test_client_is_initialized(app: MCPServer): +async def test_client_is_initialized(app: MCPServer) -> None: """Test that the client is initialized after entering context.""" async with Client(app) as client: assert client.initialize_result.capabilities == snapshot( @@ -110,7 +110,7 @@ async def test_client_is_initialized(app: MCPServer): assert client.initialize_result.server_info.name == "test" -async def test_client_with_simple_server(simple_server: Server): +async def test_client_with_simple_server(simple_server: Server) -> None: """Test that from_server works with a basic Server instance.""" async with Client(simple_server) as client: resources = await client.list_resources() @@ -121,13 +121,13 @@ async def test_client_with_simple_server(simple_server: Server): ) -async def test_client_send_ping(app: MCPServer): +async def test_client_send_ping(app: MCPServer) -> None: async with Client(app) as client: result = await client.send_ping() assert result == snapshot(EmptyResult()) -async def test_client_list_tools(app: MCPServer): +async def test_client_list_tools(app: MCPServer) -> None: async with Client(app) as client: result = await client.list_tools() assert result == snapshot( @@ -154,7 +154,7 @@ async def test_client_list_tools(app: MCPServer): ) -async def test_client_call_tool(app: MCPServer): +async def test_client_call_tool(app: MCPServer) -> None: async with Client(app) as client: result = await client.call_tool("greet", {"name": "World"}) assert result == snapshot( @@ -165,7 +165,7 @@ async def test_client_call_tool(app: MCPServer): ) -async def test_read_resource(app: MCPServer): +async def test_read_resource(app: MCPServer) -> None: """Test reading a resource.""" async with Client(app) as client: result = await client.read_resource("test://resource") @@ -176,7 +176,7 @@ async def test_read_resource(app: MCPServer): ) -async def test_read_resource_error_propagates(): +async def test_read_resource_error_propagates() -> None: """MCPError raised by a server handler propagates to the client with its code intact.""" async def handle_read_resource( @@ -191,7 +191,7 @@ async def handle_read_resource( assert exc_info.value.error.code == 404 -async def test_get_prompt(app: MCPServer): +async def test_get_prompt(app: MCPServer) -> None: """Test getting a prompt.""" async with Client(app) as client: result = await client.get_prompt("greeting_prompt", {"name": "Alice"}) @@ -203,21 +203,21 @@ async def test_get_prompt(app: MCPServer): ) -def test_client_session_property_before_enter(app: MCPServer): +def test_client_session_property_before_enter(app: MCPServer) -> None: """Test that accessing session before context manager raises RuntimeError.""" client = Client(app) with pytest.raises(RuntimeError, match="Client must be used within an async context manager"): client.session -async def test_client_reentry_raises_runtime_error(app: MCPServer): +async def test_client_reentry_raises_runtime_error(app: MCPServer) -> None: """Test that reentering a client raises RuntimeError.""" async with Client(app) as client: with pytest.raises(RuntimeError, match="Client is already entered"): await client.__aenter__() -async def test_client_send_progress_notification(): +async def test_client_send_progress_notification() -> None: """Test sending progress notification.""" received_from_client = None event = anyio.Event() @@ -235,26 +235,26 @@ async def handle_progress(ctx: ServerRequestContext, params: types.ProgressNotif assert received_from_client == snapshot({"progress_token": "token123", "progress": 50.0}) -async def test_client_subscribe_resource(simple_server: Server): +async def test_client_subscribe_resource(simple_server: Server) -> None: async with Client(simple_server) as client: result = await client.subscribe_resource("memory://test") assert result == snapshot(EmptyResult()) -async def test_client_unsubscribe_resource(simple_server: Server): +async def test_client_unsubscribe_resource(simple_server: Server) -> None: async with Client(simple_server) as client: result = await client.unsubscribe_resource("memory://test") assert result == snapshot(EmptyResult()) -async def test_client_set_logging_level(simple_server: Server): +async def test_client_set_logging_level(simple_server: Server) -> None: """Test setting logging level.""" async with Client(simple_server) as client: result = await client.set_logging_level("debug") assert result == snapshot(EmptyResult()) -async def test_client_list_resources_with_params(app: MCPServer): +async def test_client_list_resources_with_params(app: MCPServer) -> None: """Test listing resources with params parameter.""" async with Client(app) as client: result = await client.list_resources() @@ -272,14 +272,14 @@ async def test_client_list_resources_with_params(app: MCPServer): ) -async def test_client_list_resource_templates(app: MCPServer): +async def test_client_list_resource_templates(app: MCPServer) -> None: """Test listing resource templates with params parameter.""" async with Client(app) as client: result = await client.list_resource_templates() assert result == snapshot(ListResourceTemplatesResult(resource_templates=[])) -async def test_list_prompts(app: MCPServer): +async def test_list_prompts(app: MCPServer) -> None: """Test listing prompts with params parameter.""" async with Client(app) as client: result = await client.list_prompts() @@ -296,7 +296,7 @@ async def test_list_prompts(app: MCPServer): ) -async def test_complete_with_prompt_reference(simple_server: Server): +async def test_complete_with_prompt_reference(simple_server: Server) -> None: """Test getting completions for a prompt argument.""" async with Client(simple_server) as client: ref = types.PromptReference(type="ref/prompt", name="test_prompt") @@ -304,13 +304,13 @@ async def test_complete_with_prompt_reference(simple_server: Server): assert result == snapshot(types.CompleteResult(completion=types.Completion(values=[]))) -def test_client_with_url_initializes_streamable_http_transport(): +def test_client_with_url_initializes_streamable_http_transport() -> None: with patch("mcp.client.client.streamable_http_client") as mock: _ = Client("http://localhost:8000/mcp") mock.assert_called_once_with("http://localhost:8000/mcp") -async def test_client_uses_transport_directly(app: MCPServer): +async def test_client_uses_transport_directly(app: MCPServer) -> None: transport = InMemoryTransport(app) async with Client(transport) as client: result = await client.call_tool("greet", {"name": "Transport"}) diff --git a/tests/client/test_list_methods_cursor.py b/tests/client/test_list_methods_cursor.py index f70fb9277..5055c10c3 100644 --- a/tests/client/test_list_methods_cursor.py +++ b/tests/client/test_list_methods_cursor.py @@ -13,7 +13,7 @@ @pytest.fixture -async def full_featured_server(): +async def full_featured_server() -> MCPServer: """Create a server with tools, resources, prompts, and templates.""" server = MCPServer("test") @@ -57,7 +57,7 @@ async def test_list_methods_params_parameter( full_featured_server: MCPServer, method_name: str, request_method: str, -): +) -> None: """Test that the params parameter is accepted and correctly passed to the server. Covers: list_tools, list_resources, list_prompts, list_resource_templates @@ -95,7 +95,7 @@ async def test_list_methods_params_parameter( async def test_list_tools_with_strict_server_validation( full_featured_server: MCPServer, -): +) -> None: """Test pagination with a server that validates request format strictly.""" async with Client(full_featured_server) as client: result = await client.list_tools() @@ -103,7 +103,7 @@ async def test_list_tools_with_strict_server_validation( assert len(result.tools) > 0 -async def test_list_tools_with_lowlevel_server(): +async def test_list_tools_with_lowlevel_server() -> None: """Test that list_tools works with a lowlevel Server using params.""" async def handle_list_tools( diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index be4b9a97b..44cc3b943 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -9,7 +9,7 @@ @pytest.mark.anyio -async def test_list_roots_callback(): +async def test_list_roots_callback() -> None: server = MCPServer("test") callback_return = ListRootsResult( @@ -25,7 +25,7 @@ async def list_roots_callback( return callback_return @server.tool("test_list_roots") - async def test_list_roots(context: Context, message: str): + async def test_list_roots(context: Context, message: str) -> bool: roots = await context.session.list_roots() assert roots == callback_return return True diff --git a/tests/client/test_logging_callback.py b/tests/client/test_logging_callback.py index 1598fd55f..affb4469f 100644 --- a/tests/client/test_logging_callback.py +++ b/tests/client/test_logging_callback.py @@ -12,7 +12,7 @@ class LoggingCollector: - def __init__(self): + def __init__(self) -> None: self.log_messages: list[LoggingMessageNotificationParams] = [] async def __call__(self, params: LoggingMessageNotificationParams) -> None: @@ -20,7 +20,7 @@ async def __call__(self, params: LoggingMessageNotificationParams) -> None: @pytest.mark.anyio -async def test_logging_callback(): +async def test_logging_callback() -> None: server = MCPServer("test") logging_collector = LoggingCollector() diff --git a/tests/client/test_output_schema_validation.py b/tests/client/test_output_schema_validation.py index d78197b5c..2432293ef 100644 --- a/tests/client/test_output_schema_validation.py +++ b/tests/client/test_output_schema_validation.py @@ -34,7 +34,7 @@ async def on_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) @pytest.mark.anyio -async def test_tool_structured_output_client_side_validation_basemodel(): +async def test_tool_structured_output_client_side_validation_basemodel() -> None: """Test that client validates structured content against schema for BaseModel outputs""" output_schema = { "type": "object", @@ -62,7 +62,7 @@ async def test_tool_structured_output_client_side_validation_basemodel(): @pytest.mark.anyio -async def test_tool_structured_output_client_side_validation_primitive(): +async def test_tool_structured_output_client_side_validation_primitive() -> None: """Test that client validates structured content for primitive outputs""" output_schema = { "type": "object", @@ -90,7 +90,7 @@ async def test_tool_structured_output_client_side_validation_primitive(): @pytest.mark.anyio -async def test_tool_structured_output_client_side_validation_dict_typed(): +async def test_tool_structured_output_client_side_validation_dict_typed() -> None: """Test that client validates dict[str, T] structured content""" output_schema = {"type": "object", "additionalProperties": {"type": "integer"}, "title": "get_scores_Output"} @@ -113,7 +113,7 @@ async def test_tool_structured_output_client_side_validation_dict_typed(): @pytest.mark.anyio -async def test_tool_structured_output_client_side_validation_missing_required(): +async def test_tool_structured_output_client_side_validation_missing_required() -> None: """Test that client validates missing required fields""" output_schema = { "type": "object", @@ -141,7 +141,7 @@ async def test_tool_structured_output_client_side_validation_missing_required(): @pytest.mark.anyio -async def test_tool_not_listed_warning(caplog: pytest.LogCaptureFixture): +async def test_tool_not_listed_warning(caplog: pytest.LogCaptureFixture) -> None: """Test that client logs warning when tool is not in list_tools but has output_schema""" async def on_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: diff --git a/tests/client/test_resource_cleanup.py b/tests/client/test_resource_cleanup.py index c7bf8fafa..16ba40cd6 100644 --- a/tests/client/test_resource_cleanup.py +++ b/tests/client/test_resource_cleanup.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, NoReturn from unittest.mock import patch import anyio @@ -11,7 +11,7 @@ @pytest.mark.anyio -async def test_send_request_stream_cleanup(): +async def test_send_request_stream_cleanup() -> None: """Test that send_request properly cleans up streams when an exception occurs. This test mocks out most of the session functionality to focus on stream cleanup. @@ -43,7 +43,7 @@ def _receive_notification_adapter(self) -> TypeAdapter[Any]: request = PingRequest() # Patch the _write_stream.send method to raise an exception - async def mock_send(*args: Any, **kwargs: Any): + async def mock_send(*args: Any, **kwargs: Any) -> NoReturn: raise RuntimeError("Simulated network error") # Record the response streams before the test diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index 6efcac0a5..484d14e55 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -15,7 +15,7 @@ @pytest.mark.anyio -async def test_sampling_callback(): +async def test_sampling_callback() -> None: server = MCPServer("test") callback_return = CreateMessageResult( @@ -58,7 +58,7 @@ async def test_sampling_tool(message: str, ctx: Context) -> bool: @pytest.mark.anyio -async def test_create_message_backwards_compat_single_content(): +async def test_create_message_backwards_compat_single_content() -> None: """Test backwards compatibility: create_message without tools returns single content.""" server = MCPServer("test") @@ -100,7 +100,7 @@ async def test_tool(message: str, ctx: Context) -> bool: @pytest.mark.anyio -async def test_create_message_result_with_tools_type(): +async def test_create_message_result_with_tools_type() -> None: """Test that CreateMessageResultWithTools supports content_as_list.""" # Test the type itself, not the overload (overload requires client capability setup) result = CreateMessageResultWithTools( diff --git a/tests/client/test_scope_bug_1630.py b/tests/client/test_scope_bug_1630.py index fafa51007..f273688a3 100644 --- a/tests/client/test_scope_bug_1630.py +++ b/tests/client/test_scope_bug_1630.py @@ -35,7 +35,7 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None @pytest.mark.anyio -async def test_401_uses_www_auth_scope_not_resource_metadata_url(): +async def test_401_uses_www_auth_scope_not_resource_metadata_url() -> None: """Regression test for #1630: Ensure scope is extracted from WWW-Authenticate header, not the resource_metadata URL. diff --git a/tests/client/test_session.py b/tests/client/test_session.py index f25c964f0..e90b680d5 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -28,14 +28,14 @@ @pytest.mark.anyio -async def test_client_session_initialize(): +async def test_client_session_initialize() -> None: client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) initialized_notification = None result = None - async def mock_server(): + async def mock_server() -> None: nonlocal initialized_notification session_message = await client_to_server_receive.receive() @@ -111,14 +111,14 @@ async def message_handler( # pragma: no cover @pytest.mark.anyio -async def test_client_session_custom_client_info(): +async def test_client_session_custom_client_info() -> None: client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) custom_client_info = Implementation(name="test-client", version="1.2.3") received_client_info = None - async def mock_server(): + async def mock_server() -> None: nonlocal received_client_info session_message = await client_to_server_receive.receive() @@ -169,13 +169,13 @@ async def mock_server(): @pytest.mark.anyio -async def test_client_session_default_client_info(): +async def test_client_session_default_client_info() -> None: client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) received_client_info = None - async def mock_server(): + async def mock_server() -> None: nonlocal received_client_info session_message = await client_to_server_receive.receive() @@ -222,13 +222,13 @@ async def mock_server(): @pytest.mark.anyio -async def test_client_session_version_negotiation_success(): +async def test_client_session_version_negotiation_success() -> None: """Test successful version negotiation with supported version""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) result = None - async def mock_server(): + async def mock_server() -> None: session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) @@ -278,12 +278,12 @@ async def mock_server(): @pytest.mark.anyio -async def test_client_session_version_negotiation_failure(): +async def test_client_session_version_negotiation_failure() -> None: """Test version negotiation failure with unsupported version""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - async def mock_server(): + async def mock_server() -> None: session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) @@ -326,14 +326,14 @@ async def mock_server(): @pytest.mark.anyio -async def test_client_capabilities_default(): +async def test_client_capabilities_default() -> None: """Test that client capabilities are properly set with default callbacks""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) received_capabilities = None - async def mock_server(): + async def mock_server() -> None: nonlocal received_capabilities session_message = await client_to_server_receive.receive() @@ -382,7 +382,7 @@ async def mock_server(): @pytest.mark.anyio -async def test_client_capabilities_with_custom_callbacks(): +async def test_client_capabilities_with_custom_callbacks() -> None: """Test that client capabilities are properly set with custom callbacks""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) @@ -404,7 +404,7 @@ async def custom_list_roots_callback( # pragma: no cover ) -> types.ListRootsResult | types.ErrorData: return types.ListRootsResult(roots=[]) - async def mock_server(): + async def mock_server() -> None: nonlocal received_capabilities session_message = await client_to_server_receive.receive() @@ -466,7 +466,7 @@ async def mock_server(): @pytest.mark.anyio -async def test_client_capabilities_with_sampling_tools(): +async def test_client_capabilities_with_sampling_tools() -> None: """Test that sampling capabilities with tools are properly advertised""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) @@ -483,7 +483,7 @@ async def custom_sampling_callback( # pragma: no cover model="test-model", ) - async def mock_server(): + async def mock_server() -> None: nonlocal received_capabilities session_message = await client_to_server_receive.receive() @@ -540,7 +540,7 @@ async def mock_server(): @pytest.mark.anyio -async def test_initialize_result(): +async def test_initialize_result() -> None: """Test that initialize_result is None before init and contains the full result after.""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) @@ -554,7 +554,7 @@ async def test_initialize_result(): expected_server_info = Implementation(name="mock-server", version="0.1.0") expected_instructions = "Use the tools wisely." - async def mock_server(): + async def mock_server() -> None: session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) @@ -608,14 +608,14 @@ async def mock_server(): @pytest.mark.anyio @pytest.mark.parametrize(argnames="meta", argvalues=[None, {"toolMeta": "value"}]) -async def test_client_tool_call_with_meta(meta: RequestParamsMeta | None): +async def test_client_tool_call_with_meta(meta: RequestParamsMeta | None) -> None: """Test that client tool call requests can include metadata""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) mocked_tool = types.Tool(name="sample_tool", input_schema={}) - async def mock_server(): + async def mock_server() -> None: # Receive initialization request from client session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 6a58b39f3..f7f101a7d 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -17,7 +17,7 @@ @pytest.fixture -def mock_exit_stack(): +def mock_exit_stack() -> mock.MagicMock: """Fixture for a mocked AsyncExitStack.""" # Use unittest.mock.Mock directly if needed, or just a plain object # if only attribute access/existence is needed. @@ -25,7 +25,7 @@ def mock_exit_stack(): return mock.MagicMock(spec=contextlib.AsyncExitStack) -def test_client_session_group_init(): +def test_client_session_group_init() -> None: mcp_session_group = ClientSessionGroup() assert not mcp_session_group._tools assert not mcp_session_group._resources @@ -33,7 +33,7 @@ def test_client_session_group_init(): assert not mcp_session_group._tool_to_session -def test_client_session_group_component_properties(): +def test_client_session_group_component_properties() -> None: # --- Mock Dependencies --- mock_prompt = mock.Mock() mock_resource = mock.Mock() @@ -52,7 +52,7 @@ def test_client_session_group_component_properties(): @pytest.mark.anyio -async def test_client_session_group_call_tool(): +async def test_client_session_group_call_tool() -> None: # --- Mock Dependencies --- mock_session = mock.AsyncMock() @@ -87,7 +87,7 @@ def hook(name: str, server_info: types.Implementation) -> str: # pragma: no cov @pytest.mark.anyio -async def test_client_session_group_connect_to_server(mock_exit_stack: contextlib.AsyncExitStack): +async def test_client_session_group_connect_to_server(mock_exit_stack: contextlib.AsyncExitStack) -> None: """Test connecting to a server and aggregating components.""" # --- Mock Dependencies --- mock_server_info = mock.Mock(spec=types.Implementation) @@ -126,7 +126,9 @@ async def test_client_session_group_connect_to_server(mock_exit_stack: contextli @pytest.mark.anyio -async def test_client_session_group_connect_to_server_with_name_hook(mock_exit_stack: contextlib.AsyncExitStack): +async def test_client_session_group_connect_to_server_with_name_hook( + mock_exit_stack: contextlib.AsyncExitStack, +) -> None: """Test connecting with a component name hook.""" # --- Mock Dependencies --- mock_server_info = mock.Mock(spec=types.Implementation) @@ -157,7 +159,7 @@ def name_hook(name: str, server_info: types.Implementation) -> str: @pytest.mark.anyio -async def test_client_session_group_disconnect_from_server(): +async def test_client_session_group_disconnect_from_server() -> None: """Test disconnecting from a server.""" # --- Test Setup --- group = ClientSessionGroup() @@ -224,7 +226,7 @@ async def test_client_session_group_disconnect_from_server(): @pytest.mark.anyio async def test_client_session_group_connect_to_server_duplicate_tool_raises_error( mock_exit_stack: contextlib.AsyncExitStack, -): +) -> None: """Test MCPError raised when connecting a server with a dup name.""" # --- Setup Pre-existing State --- group = ClientSessionGroup(exit_stack=mock_exit_stack) @@ -270,7 +272,7 @@ async def test_client_session_group_connect_to_server_duplicate_tool_raises_erro @pytest.mark.anyio -async def test_client_session_group_disconnect_non_existent_server(): +async def test_client_session_group_disconnect_non_existent_server() -> None: """Test disconnecting a server that isn't connected.""" session = mock.Mock(spec=mcp.ClientSession) group = ClientSessionGroup() @@ -304,7 +306,7 @@ async def test_client_session_group_establish_session_parameterized( server_params_instance: StdioServerParameters | SseServerParameters | StreamableHttpParameters, client_type_name: str, # Just for clarity or conditional logic if needed patch_target_for_client_func: str, -): +) -> None: with mock.patch("mcp.client.session_group.mcp.ClientSession") as mock_ClientSession_class: with mock.patch(patch_target_for_client_func) as mock_specific_client_func: mock_client_cm_instance = mock.AsyncMock(name=f"{client_type_name}ClientCM") diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 06e2cba4b..7965ba025 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -31,7 +31,7 @@ @pytest.mark.anyio @pytest.mark.skipif(tee is None, reason="could not find tee command") -async def test_stdio_context_manager_exiting(): +async def test_stdio_context_manager_exiting() -> None: assert tee is not None async with stdio_client(StdioServerParameters(command=tee)) as (_, _): pass @@ -39,7 +39,7 @@ async def test_stdio_context_manager_exiting(): @pytest.mark.anyio @pytest.mark.skipif(tee is None, reason="could not find tee command") -async def test_stdio_client(): +async def test_stdio_client() -> None: assert tee is not None server_parameters = StdioServerParameters(command=tee) @@ -71,7 +71,7 @@ async def test_stdio_client(): @pytest.mark.anyio -async def test_stdio_client_bad_path(): +async def test_stdio_client_bad_path() -> None: """Check that the connection doesn't hang if process errors.""" server_params = StdioServerParameters(command=sys.executable, args=["-c", "non-existent-file.py"]) async with stdio_client(server_params) as (read_stream, write_stream): @@ -86,7 +86,7 @@ async def test_stdio_client_bad_path(): @pytest.mark.anyio -async def test_stdio_client_nonexistent_command(): +async def test_stdio_client_nonexistent_command() -> None: """Test that stdio_client raises an error for non-existent commands.""" # Create a server with a non-existent command server_params = StdioServerParameters( @@ -104,7 +104,7 @@ async def test_stdio_client_nonexistent_command(): @pytest.mark.anyio -async def test_stdio_client_universal_cleanup(): +async def test_stdio_client_universal_cleanup() -> None: """Test that stdio_client completes cleanup within reasonable time even when connected to processes that exit slowly. """ @@ -156,7 +156,7 @@ async def test_stdio_client_universal_cleanup(): @pytest.mark.anyio @pytest.mark.skipif(sys.platform == "win32", reason="Windows signal handling is different") -async def test_stdio_client_sigint_only_process(): # pragma: lax no cover +async def test_stdio_client_sigint_only_process() -> None: # pragma: lax no cover """Test cleanup with a process that ignores SIGTERM but responds to SIGINT.""" # Create a Python script that ignores SIGTERM but handles SIGINT script_content = textwrap.dedent( @@ -353,7 +353,7 @@ class TestChildProcessCleanup: """ @pytest.mark.anyio - async def test_basic_child_process_cleanup(self): + async def test_basic_child_process_cleanup(self) -> None: """Parent spawns one child; terminating the tree kills both.""" async with AsyncExitStack() as stack: sock, port = await _open_liveness_listener() @@ -377,7 +377,7 @@ async def test_basic_child_process_cleanup(self): await _assert_stream_closed(stream) @pytest.mark.anyio - async def test_nested_process_tree(self): + async def test_nested_process_tree(self) -> None: """Parent → child → grandchild; terminating the tree kills all three.""" async with AsyncExitStack() as stack: sock, port = await _open_liveness_listener() @@ -413,7 +413,7 @@ async def test_nested_process_tree(self): await _assert_stream_closed(stream) @pytest.mark.anyio - async def test_early_parent_exit(self): + async def test_early_parent_exit(self) -> None: """Parent exits immediately on SIGTERM; process-group termination still catches the child (exercises the race where the parent dies mid-cleanup). """ @@ -447,7 +447,7 @@ async def test_early_parent_exit(self): @pytest.mark.anyio -async def test_stdio_client_graceful_stdin_exit(): +async def test_stdio_client_graceful_stdin_exit() -> None: """Test that a process exits gracefully when stdin is closed, without needing SIGTERM or SIGKILL. """ @@ -502,7 +502,7 @@ async def test_stdio_client_graceful_stdin_exit(): @pytest.mark.anyio -async def test_stdio_client_stdin_close_ignored(): +async def test_stdio_client_stdin_close_ignored() -> None: """Test that when a process ignores stdin closure, the shutdown sequence properly escalates to SIGTERM. """ diff --git a/tests/client/transports/test_memory.py b/tests/client/transports/test_memory.py index c8fc41fd5..406f69f40 100644 --- a/tests/client/transports/test_memory.py +++ b/tests/client/transports/test_memory.py @@ -50,7 +50,7 @@ def test_resource() -> str: # pragma: no cover pytestmark = pytest.mark.anyio -async def test_with_server(simple_server: Server): +async def test_with_server(simple_server: Server) -> None: """Test creating transport with a Server instance.""" transport = InMemoryTransport(simple_server) async with transport as (read_stream, write_stream): @@ -58,7 +58,7 @@ async def test_with_server(simple_server: Server): assert write_stream is not None -async def test_with_mcpserver(mcpserver_server: MCPServer): +async def test_with_mcpserver(mcpserver_server: MCPServer) -> None: """Test creating transport with an MCPServer instance.""" transport = InMemoryTransport(mcpserver_server) async with transport as (read_stream, write_stream): @@ -66,13 +66,13 @@ async def test_with_mcpserver(mcpserver_server: MCPServer): assert write_stream is not None -async def test_server_is_running(mcpserver_server: MCPServer): +async def test_server_is_running(mcpserver_server: MCPServer) -> None: """Test that the server is running and responding to requests.""" async with Client(mcpserver_server) as client: assert client.initialize_result.capabilities.tools is not None -async def test_list_tools(mcpserver_server: MCPServer): +async def test_list_tools(mcpserver_server: MCPServer) -> None: """Test listing tools through the transport.""" async with Client(mcpserver_server) as client: tools_result = await client.list_tools() @@ -81,7 +81,7 @@ async def test_list_tools(mcpserver_server: MCPServer): assert "greet" in tool_names -async def test_call_tool(mcpserver_server: MCPServer): +async def test_call_tool(mcpserver_server: MCPServer) -> None: """Test calling a tool through the transport.""" async with Client(mcpserver_server) as client: result = await client.call_tool("greet", {"name": "World"}) @@ -90,7 +90,7 @@ async def test_call_tool(mcpserver_server: MCPServer): assert "Hello, World!" in str(result.content[0]) -async def test_raise_exceptions(mcpserver_server: MCPServer): +async def test_raise_exceptions(mcpserver_server: MCPServer) -> None: """Test that raise_exceptions parameter is passed through.""" transport = InMemoryTransport(mcpserver_server, raise_exceptions=True) async with transport as (read_stream, _write_stream): diff --git a/tests/conftest.py b/tests/conftest.py index af7e47993..5c53fe0ac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,5 +2,5 @@ @pytest.fixture -def anyio_backend(): +def anyio_backend() -> str: return "asyncio" diff --git a/tests/experimental/tasks/client/test_capabilities.py b/tests/experimental/tasks/client/test_capabilities.py index 1ea2199e8..6bbc5699b 100644 --- a/tests/experimental/tasks/client/test_capabilities.py +++ b/tests/experimental/tasks/client/test_capabilities.py @@ -21,14 +21,14 @@ @pytest.mark.anyio -async def test_client_capabilities_without_tasks(): +async def test_client_capabilities_without_tasks() -> None: """Test that tasks capability is None when not provided.""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) received_capabilities = None - async def mock_server(): + async def mock_server() -> None: nonlocal received_capabilities session_message = await client_to_server_receive.receive() @@ -78,7 +78,7 @@ async def mock_server(): @pytest.mark.anyio -async def test_client_capabilities_with_tasks(): +async def test_client_capabilities_with_tasks() -> None: """Test that tasks capability is properly set when handlers are provided.""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) @@ -98,7 +98,7 @@ async def my_cancel_task_handler( ) -> types.CancelTaskResult | types.ErrorData: raise NotImplementedError - async def mock_server(): + async def mock_server() -> None: nonlocal received_capabilities session_message = await client_to_server_receive.receive() @@ -158,7 +158,7 @@ async def mock_server(): @pytest.mark.anyio -async def test_client_capabilities_auto_built_from_handlers(): +async def test_client_capabilities_auto_built_from_handlers() -> None: """Test that tasks capability is automatically built from provided handlers.""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) @@ -178,7 +178,7 @@ async def my_cancel_task_handler( ) -> types.CancelTaskResult | types.ErrorData: raise NotImplementedError - async def mock_server(): + async def mock_server() -> None: nonlocal received_capabilities session_message = await client_to_server_receive.receive() @@ -239,7 +239,7 @@ async def mock_server(): @pytest.mark.anyio -async def test_client_capabilities_with_task_augmented_handlers(): +async def test_client_capabilities_with_task_augmented_handlers() -> None: """Test that requests capability is built when augmented handlers are provided.""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) @@ -254,7 +254,7 @@ async def my_augmented_sampling_handler( ) -> types.CreateTaskResult | types.ErrorData: raise NotImplementedError - async def mock_server(): + async def mock_server() -> None: nonlocal received_capabilities session_message = await client_to_server_receive.receive() diff --git a/tests/experimental/tasks/client/test_tasks.py b/tests/experimental/tasks/client/test_tasks.py index 613c794eb..fc3361f99 100644 --- a/tests/experimental/tasks/client/test_tasks.py +++ b/tests/experimental/tasks/client/test_tasks.py @@ -1,7 +1,7 @@ """Tests for the experimental client task methods (session.experimental).""" -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager +from collections.abc import AsyncIterator, Callable +from contextlib import AbstractAsyncContextManager, asynccontextmanager from dataclasses import dataclass, field import anyio @@ -72,7 +72,9 @@ async def do_work() -> None: raise NotImplementedError -def _make_lifespan(store: InMemoryTaskStore, task_done_events: dict[str, Event]): +def _make_lifespan( + store: InMemoryTaskStore, task_done_events: dict[str, Event] +) -> Callable[[Server[AppContext]], AbstractAsyncContextManager[AppContext]]: @asynccontextmanager async def app_lifespan(server: Server[AppContext]) -> AsyncIterator[AppContext]: async with anyio.create_task_group() as tg: diff --git a/tests/experimental/tasks/server/test_integration.py b/tests/experimental/tasks/server/test_integration.py index b5b79033d..0f098152d 100644 --- a/tests/experimental/tasks/server/test_integration.py +++ b/tests/experimental/tasks/server/test_integration.py @@ -8,8 +8,8 @@ 5. Client retrieves result with tasks/result """ -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager +from collections.abc import AsyncIterator, Callable +from contextlib import AbstractAsyncContextManager, asynccontextmanager from dataclasses import dataclass, field import anyio @@ -49,7 +49,9 @@ class AppContext: task_done_events: dict[str, Event] = field(default_factory=lambda: {}) -def _make_lifespan(store: InMemoryTaskStore, task_done_events: dict[str, Event]): +def _make_lifespan( + store: InMemoryTaskStore, task_done_events: dict[str, Event] +) -> Callable[[Server[AppContext]], AbstractAsyncContextManager[AppContext]]: @asynccontextmanager async def app_lifespan(server: Server[AppContext]) -> AsyncIterator[AppContext]: async with anyio.create_task_group() as tg: diff --git a/tests/issues/test_100_tool_listing.py b/tests/issues/test_100_tool_listing.py index e59fb632d..259970cce 100644 --- a/tests/issues/test_100_tool_listing.py +++ b/tests/issues/test_100_tool_listing.py @@ -5,7 +5,7 @@ pytestmark = pytest.mark.anyio -async def test_list_tools_returns_all_tools(): +async def test_list_tools_returns_all_tools() -> None: mcp = MCPServer("TestTools") # Create 100 tools with unique names @@ -13,7 +13,7 @@ async def test_list_tools_returns_all_tools(): for i in range(num_tools): @mcp.tool(name=f"tool_{i}") - def dummy_tool_func(): # pragma: no cover + def dummy_tool_func() -> int: # pragma: no cover f"""Tool number {i}""" return i diff --git a/tests/issues/test_1027_win_unreachable_cleanup.py b/tests/issues/test_1027_win_unreachable_cleanup.py index c59c5aeca..ac2ca8937 100644 --- a/tests/issues/test_1027_win_unreachable_cleanup.py +++ b/tests/issues/test_1027_win_unreachable_cleanup.py @@ -21,7 +21,7 @@ @pytest.mark.anyio -async def test_lifespan_cleanup_executed(): +async def test_lifespan_cleanup_executed() -> None: """Regression test ensuring MCP server cleanup code runs during shutdown. This test verifies that the fix for issue #1027 works correctly by: @@ -121,7 +121,7 @@ def echo(text: str) -> str: @pytest.mark.anyio @pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default") -async def test_stdin_close_triggers_cleanup(): +async def test_stdin_close_triggers_cleanup() -> None: """Regression test verifying the stdin-based graceful shutdown mechanism. This test ensures the core fix for issue #1027 continues to work by: diff --git a/tests/issues/test_129_resource_templates.py b/tests/issues/test_129_resource_templates.py index bb4735121..43931290f 100644 --- a/tests/issues/test_129_resource_templates.py +++ b/tests/issues/test_129_resource_templates.py @@ -5,7 +5,7 @@ @pytest.mark.anyio -async def test_resource_templates(): +async def test_resource_templates() -> None: mcp = MCPServer("Demo") @mcp.resource("greeting://{name}") diff --git a/tests/issues/test_1338_icons_and_metadata.py b/tests/issues/test_1338_icons_and_metadata.py index a003f75b8..985d92865 100644 --- a/tests/issues/test_1338_icons_and_metadata.py +++ b/tests/issues/test_1338_icons_and_metadata.py @@ -8,7 +8,7 @@ pytestmark = pytest.mark.anyio -async def test_icons_and_website_url(): +async def test_icons_and_website_url() -> None: """Test that icons and websiteUrl are properly returned in API calls.""" # Create test icon @@ -92,7 +92,7 @@ def test_resource_template(city: str) -> str: # pragma: no cover assert template.icons[0].src == test_icon.src -async def test_multiple_icons(): +async def test_multiple_icons() -> None: """Test that multiple icons can be added to tools, resources, and prompts.""" # Create multiple test icons @@ -119,7 +119,7 @@ def multi_icon_tool() -> str: # pragma: no cover assert tool.icons[2].sizes == ["64x64"] -async def test_no_icons_or_website(): +async def test_no_icons_or_website() -> None: """Test that server works without icons or websiteUrl.""" mcp = MCPServer("BasicServer") diff --git a/tests/issues/test_1363_race_condition_streamable_http.py b/tests/issues/test_1363_race_condition_streamable_http.py index db2a82d07..a813c596d 100644 --- a/tests/issues/test_1363_race_condition_streamable_http.py +++ b/tests/issues/test_1363_race_condition_streamable_http.py @@ -33,7 +33,7 @@ class RaceConditionTestServer(Server): - def __init__(self): + def __init__(self) -> None: super().__init__(SERVER_NAME) @@ -64,7 +64,7 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: class ServerThread(threading.Thread): """Thread that runs the ASGI application lifespan in a separate event loop.""" - def __init__(self, app: Starlette): + def __init__(self, app: Starlette) -> None: super().__init__(daemon=True) self.app = app self._stop_event = threading.Event() @@ -73,7 +73,7 @@ def run(self) -> None: """Run the lifespan in a new event loop.""" # Create a new event loop for this thread - async def run_lifespan(): + async def run_lifespan() -> None: # Use the lifespan context (always present in our tests) lifespan_context = getattr(self.app.router, "lifespan_context", None) assert lifespan_context is not None # Tests always create apps with lifespan @@ -119,7 +119,7 @@ def check_logs_for_race_condition_errors(caplog: pytest.LogCaptureFixture, test_ @pytest.mark.anyio -async def test_race_condition_invalid_accept_headers(caplog: pytest.LogCaptureFixture): +async def test_race_condition_invalid_accept_headers(caplog: pytest.LogCaptureFixture) -> None: """Test the race condition with invalid Accept headers. This test reproduces the exact scenario described in issue #1363: @@ -193,7 +193,7 @@ async def test_race_condition_invalid_accept_headers(caplog: pytest.LogCaptureFi @pytest.mark.anyio -async def test_race_condition_invalid_content_type(caplog: pytest.LogCaptureFixture): +async def test_race_condition_invalid_content_type(caplog: pytest.LogCaptureFixture) -> None: """Test the race condition with invalid Content-Type headers. This test reproduces the race condition scenario with Content-Type validation failure. @@ -233,7 +233,7 @@ async def test_race_condition_invalid_content_type(caplog: pytest.LogCaptureFixt @pytest.mark.anyio -async def test_race_condition_message_router_async_for(caplog: pytest.LogCaptureFixture): +async def test_race_condition_message_router_async_for(caplog: pytest.LogCaptureFixture) -> None: """Uses json_response=True to trigger the `if self.is_json_response_enabled` branch, which reproduces the ClosedResourceError when message_router is suspended in async for loop while transport cleanup closes streams concurrently. diff --git a/tests/issues/test_141_resource_templates.py b/tests/issues/test_141_resource_templates.py index f5c5081c3..04cff1271 100644 --- a/tests/issues/test_141_resource_templates.py +++ b/tests/issues/test_141_resource_templates.py @@ -10,7 +10,7 @@ @pytest.mark.anyio -async def test_resource_template_edge_cases(): +async def test_resource_template_edge_cases() -> None: """Test server-side resource template validation""" mcp = MCPServer("Demo") @@ -63,7 +63,7 @@ def get_user_profile_missing(user_id: str) -> str: # pragma: no cover @pytest.mark.anyio -async def test_resource_template_client_interaction(): +async def test_resource_template_client_interaction() -> None: """Test client-side resource template interaction""" mcp = MCPServer("Demo") diff --git a/tests/issues/test_152_resource_mime_type.py b/tests/issues/test_152_resource_mime_type.py index 851e89979..7ae03198b 100644 --- a/tests/issues/test_152_resource_mime_type.py +++ b/tests/issues/test_152_resource_mime_type.py @@ -17,7 +17,7 @@ pytestmark = pytest.mark.anyio -async def test_mcpserver_resource_mime_type(): +async def test_mcpserver_resource_mime_type() -> None: """Test that mime_type parameter is respected for resources.""" mcp = MCPServer("test") @@ -63,7 +63,7 @@ def get_image_as_bytes() -> bytes: assert bytes_result.contents[0].mime_type == "image/png", "Bytes content mime type not preserved" -async def test_lowlevel_resource_mime_type(): +async def test_lowlevel_resource_mime_type() -> None: """Test that mime_type parameter is respected for resources.""" # Create a small test image as bytes diff --git a/tests/issues/test_1574_resource_uri_validation.py b/tests/issues/test_1574_resource_uri_validation.py index c67708128..d992e2bd9 100644 --- a/tests/issues/test_1574_resource_uri_validation.py +++ b/tests/issues/test_1574_resource_uri_validation.py @@ -25,7 +25,7 @@ pytestmark = pytest.mark.anyio -async def test_relative_uri_roundtrip(): +async def test_relative_uri_roundtrip() -> None: """Relative URIs survive the full server-client JSON-RPC roundtrip. This is the critical regression test - if someone reintroduces AnyUrl, @@ -67,7 +67,7 @@ async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRe assert result.contents[0].uri == uri_str -async def test_custom_scheme_uri_roundtrip(): +async def test_custom_scheme_uri_roundtrip() -> None: """Custom scheme URIs work through the protocol. Some MCP servers use custom schemes like "custom://resource". @@ -103,7 +103,7 @@ async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRe assert len(result.contents) == 1 -def test_uri_json_roundtrip_preserves_value(): +def test_uri_json_roundtrip_preserves_value() -> None: """URI is preserved exactly through JSON serialization. This catches any Pydantic validation or normalization that would @@ -125,7 +125,7 @@ def test_uri_json_roundtrip_preserves_value(): assert restored.uri == uri_str, f"URI mutated: {uri_str} -> {restored.uri}" -def test_resource_contents_uri_json_roundtrip(): +def test_resource_contents_uri_json_roundtrip() -> None: """TextResourceContents URI is preserved through JSON serialization.""" test_uris = ["users/me", "./relative", "custom://resource"] diff --git a/tests/issues/test_1754_mime_type_parameters.py b/tests/issues/test_1754_mime_type_parameters.py index 7903fd560..a798fcab1 100644 --- a/tests/issues/test_1754_mime_type_parameters.py +++ b/tests/issues/test_1754_mime_type_parameters.py @@ -12,7 +12,7 @@ pytestmark = pytest.mark.anyio -async def test_mime_type_with_parameters(): +async def test_mime_type_with_parameters() -> None: """Test that MIME types with parameters are accepted (RFC 2045).""" mcp = MCPServer("test") @@ -26,7 +26,7 @@ def widget() -> str: assert resources[0].mime_type == "text/html;profile=mcp-app" -async def test_mime_type_with_parameters_and_space(): +async def test_mime_type_with_parameters_and_space() -> None: """Test MIME type with space after semicolon.""" mcp = MCPServer("test") @@ -39,7 +39,7 @@ def data() -> str: assert resources[0].mime_type == "application/json; charset=utf-8" -async def test_mime_type_with_multiple_parameters(): +async def test_mime_type_with_multiple_parameters() -> None: """Test MIME type with multiple parameters.""" mcp = MCPServer("test") @@ -52,7 +52,7 @@ def data() -> str: assert resources[0].mime_type == "text/plain; charset=utf-8; format=fixed" -async def test_mime_type_preserved_in_read_resource(): +async def test_mime_type_preserved_in_read_resource() -> None: """Test that MIME type with parameters is preserved when reading resource.""" mcp = MCPServer("test") diff --git a/tests/issues/test_176_progress_token.py b/tests/issues/test_176_progress_token.py index 5d5f8b8fc..21cad62b6 100644 --- a/tests/issues/test_176_progress_token.py +++ b/tests/issues/test_176_progress_token.py @@ -9,7 +9,7 @@ pytestmark = pytest.mark.anyio -async def test_progress_token_zero_first_call(): +async def test_progress_token_zero_first_call() -> None: """Test that progress notifications work when progress_token is 0 on first call.""" # Create mock session with progress notification tracking diff --git a/tests/issues/test_188_concurrency.py b/tests/issues/test_188_concurrency.py index 0e11f6148..61d113341 100644 --- a/tests/issues/test_188_concurrency.py +++ b/tests/issues/test_188_concurrency.py @@ -6,14 +6,14 @@ @pytest.mark.anyio -async def test_messages_are_executed_concurrently_tools(): +async def test_messages_are_executed_concurrently_tools() -> None: server = MCPServer("test") event = anyio.Event() tool_started = anyio.Event() call_order: list[str] = [] @server.tool("sleep") - async def sleep_tool(): + async def sleep_tool() -> str: call_order.append("waiting_for_event") tool_started.set() await event.wait() @@ -21,7 +21,7 @@ async def sleep_tool(): return "done" @server.tool("trigger") - async def trigger(): + async def trigger() -> str: # Wait for tool to start before setting the event await tool_started.wait() call_order.append("trigger_started") @@ -47,14 +47,14 @@ async def trigger(): @pytest.mark.anyio -async def test_messages_are_executed_concurrently_tools_and_resources(): +async def test_messages_are_executed_concurrently_tools_and_resources() -> None: server = MCPServer("test") event = anyio.Event() tool_started = anyio.Event() call_order: list[str] = [] @server.tool("sleep") - async def sleep_tool(): + async def sleep_tool() -> str: call_order.append("waiting_for_event") tool_started.set() await event.wait() @@ -62,7 +62,7 @@ async def sleep_tool(): return "done" @server.resource("slow://slow_resource") - async def slow_resource(): + async def slow_resource() -> str: # Wait for tool to start before setting the event await tool_started.wait() event.set() diff --git a/tests/issues/test_192_request_id.py b/tests/issues/test_192_request_id.py index de96dbe23..d3c7d9497 100644 --- a/tests/issues/test_192_request_id.py +++ b/tests/issues/test_192_request_id.py @@ -28,7 +28,7 @@ async def test_request_id_match() -> None: server_writer, server_reader = anyio.create_memory_object_stream[SessionMessage | Exception](1) # Server task to process the request - async def run_server(): + async def run_server() -> None: async with client_reader, server_writer: await server.run( client_reader, diff --git a/tests/issues/test_342_base64_encoding.py b/tests/issues/test_342_base64_encoding.py index 2bccedf8d..8294c82eb 100644 --- a/tests/issues/test_342_base64_encoding.py +++ b/tests/issues/test_342_base64_encoding.py @@ -15,7 +15,7 @@ pytestmark = pytest.mark.anyio -async def test_server_base64_encoding(): +async def test_server_base64_encoding() -> None: """Tests that binary resource data round-trips correctly through base64 encoding. The test uses binary data that produces different results with urlsafe vs standard diff --git a/tests/issues/test_355_type_error.py b/tests/issues/test_355_type_error.py index 905cf7eee..29cb03d7a 100644 --- a/tests/issues/test_355_type_error.py +++ b/tests/issues/test_355_type_error.py @@ -7,13 +7,13 @@ class Database: # Replace with your actual DB type @classmethod - async def connect(cls): # pragma: no cover + async def connect(cls) -> "Database": # pragma: no cover return cls() - async def disconnect(self): # pragma: no cover + async def disconnect(self) -> None: # pragma: no cover pass - def query(self): # pragma: no cover + def query(self) -> str: # pragma: no cover return "Hello, World!" diff --git a/tests/issues/test_552_windows_hang.py b/tests/issues/test_552_windows_hang.py index 1adb5d80c..f254e2183 100644 --- a/tests/issues/test_552_windows_hang.py +++ b/tests/issues/test_552_windows_hang.py @@ -12,7 +12,7 @@ @pytest.mark.skipif(sys.platform != "win32", reason="Windows-specific test") # pragma: no cover @pytest.mark.anyio -async def test_windows_stdio_client_with_session(): +async def test_windows_stdio_client_with_session() -> None: """Test the exact scenario from issue #552: Using ClientSession with stdio_client. This reproduces the original bug report where stdio_client hangs on Windows 11 diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index 6b593d2a5..92d2ee3d6 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -16,7 +16,7 @@ @pytest.mark.anyio -async def test_notification_validation_error(tmp_path: Path): +async def test_notification_validation_error(tmp_path: Path) -> None: """Test that timeouts are handled gracefully and don't break the server. This test verifies that when a client request times out: @@ -67,7 +67,7 @@ async def server_handler( read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], write_stream: MemoryObjectSendStream[SessionMessage], task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED, - ): + ) -> None: with anyio.CancelScope() as scope: task_status.started(scope) # type: ignore await server.run( @@ -81,7 +81,7 @@ async def client( read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], write_stream: MemoryObjectSendStream[SessionMessage], scope: anyio.CancelScope, - ): + ) -> None: # No session-level timeout to avoid race conditions with fast operations async with ClientSession(read_stream, write_stream) as session: await session.initialize() diff --git a/tests/issues/test_973_url_decoding.py b/tests/issues/test_973_url_decoding.py index 01cf222b9..65214c61d 100644 --- a/tests/issues/test_973_url_decoding.py +++ b/tests/issues/test_973_url_decoding.py @@ -6,7 +6,7 @@ from mcp.server.mcpserver.resources import ResourceTemplate -def test_template_matches_decodes_space(): +def test_template_matches_decodes_space() -> None: """Test that %20 is decoded to space.""" def search(query: str) -> str: # pragma: no cover @@ -23,7 +23,7 @@ def search(query: str) -> str: # pragma: no cover assert params["query"] == "hello world" -def test_template_matches_decodes_accented_characters(): +def test_template_matches_decodes_accented_characters() -> None: """Test that %C3%A9 is decoded to e with accent.""" def search(query: str) -> str: # pragma: no cover @@ -40,7 +40,7 @@ def search(query: str) -> str: # pragma: no cover assert params["query"] == "café" -def test_template_matches_decodes_complex_phrase(): +def test_template_matches_decodes_complex_phrase() -> None: """Test complex French phrase from the original issue.""" def search(query: str) -> str: # pragma: no cover @@ -57,7 +57,7 @@ def search(query: str) -> str: # pragma: no cover assert params["query"] == "stick correcteur teinté anti-imperfections" -def test_template_matches_preserves_plus_sign(): +def test_template_matches_preserves_plus_sign() -> None: """Test that plus sign remains as plus (not converted to space). In URI encoding, %20 is space. Plus-as-space is only for diff --git a/tests/issues/test_malformed_input.py b/tests/issues/test_malformed_input.py index da586f309..3c40f8157 100644 --- a/tests/issues/test_malformed_input.py +++ b/tests/issues/test_malformed_input.py @@ -11,7 +11,7 @@ @pytest.mark.anyio -async def test_malformed_initialize_request_does_not_crash_server(): +async def test_malformed_initialize_request_does_not_crash_server() -> None: """Test that malformed initialize requests return proper error responses instead of crashing the server (HackerOne #3156202). """ @@ -91,7 +91,7 @@ async def test_malformed_initialize_request_does_not_crash_server(): @pytest.mark.anyio -async def test_multiple_concurrent_malformed_requests(): +async def test_multiple_concurrent_malformed_requests() -> None: """Test that multiple concurrent malformed requests don't crash the server.""" # Create in-memory streams for testing read_send_stream, read_receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](100) diff --git a/tests/server/auth/middleware/test_auth_context.py b/tests/server/auth/middleware/test_auth_context.py index 66481bcf7..a844bc368 100644 --- a/tests/server/auth/middleware/test_auth_context.py +++ b/tests/server/auth/middleware/test_auth_context.py @@ -17,7 +17,7 @@ class MockApp: """Mock ASGI app for testing.""" - def __init__(self): + def __init__(self) -> None: self.called = False self.scope: Scope | None = None self.receive: Receive | None = None @@ -45,7 +45,7 @@ def valid_access_token() -> AccessToken: @pytest.mark.anyio -async def test_auth_context_middleware_with_authenticated_user(valid_access_token: AccessToken): +async def test_auth_context_middleware_with_authenticated_user(valid_access_token: AccessToken) -> None: """Test middleware with an authenticated user in scope.""" app = MockApp() middleware = AuthContextMiddleware(app) @@ -84,7 +84,7 @@ async def send(message: Message) -> None: # pragma: no cover @pytest.mark.anyio -async def test_auth_context_middleware_with_no_user(): +async def test_auth_context_middleware_with_no_user() -> None: """Test middleware with no user in scope.""" app = MockApp() middleware = AuthContextMiddleware(app) diff --git a/tests/server/auth/middleware/test_bearer_auth.py b/tests/server/auth/middleware/test_bearer_auth.py index bd14e294c..1baf5818b 100644 --- a/tests/server/auth/middleware/test_bearer_auth.py +++ b/tests/server/auth/middleware/test_bearer_auth.py @@ -20,7 +20,7 @@ class MockOAuthProvider: the BearerAuthMiddleware components. """ - def __init__(self): + def __init__(self) -> None: self.tokens: dict[str, AccessToken] = {} # token -> AccessToken def add_token(self, token: str, access_token: AccessToken) -> None: @@ -49,7 +49,7 @@ def add_token_to_provider( class MockApp: """Mock ASGI app for testing.""" - def __init__(self): + def __init__(self) -> None: self.called = False self.scope: Scope | None = None self.receive: Receive | None = None @@ -106,14 +106,16 @@ def no_expiry_access_token() -> AccessToken: class TestBearerAuthBackend: """Tests for the BearerAuthBackend class.""" - async def test_no_auth_header(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]): + async def test_no_auth_header(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]) -> None: """Test authentication with no Authorization header.""" backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) request = Request({"type": "http", "headers": []}) result = await backend.authenticate(request) assert result is None - async def test_non_bearer_auth_header(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]): + async def test_non_bearer_auth_header( + self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any] + ) -> None: """Test authentication with non-Bearer Authorization header.""" backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) request = Request( @@ -125,7 +127,7 @@ async def test_non_bearer_auth_header(self, mock_oauth_provider: OAuthAuthorizat result = await backend.authenticate(request) assert result is None - async def test_invalid_token(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]): + async def test_invalid_token(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]) -> None: """Test authentication with invalid token.""" backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) request = Request( @@ -141,7 +143,7 @@ async def test_expired_token( self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any], expired_access_token: AccessToken, - ): + ) -> None: """Test authentication with expired token.""" backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) add_token_to_provider(mock_oauth_provider, "expired_token", expired_access_token) @@ -158,7 +160,7 @@ async def test_valid_token( self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any], valid_access_token: AccessToken, - ): + ) -> None: """Test authentication with valid token.""" backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token) @@ -182,7 +184,7 @@ async def test_token_without_expiry( self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any], no_expiry_access_token: AccessToken, - ): + ) -> None: """Test authentication with token that has no expiry.""" backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) add_token_to_provider(mock_oauth_provider, "no_expiry_token", no_expiry_access_token) @@ -206,7 +208,7 @@ async def test_lowercase_bearer_prefix( self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any], valid_access_token: AccessToken, - ): + ) -> None: """Test with lowercase 'bearer' prefix in Authorization header""" backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token) @@ -226,7 +228,7 @@ async def test_mixed_case_bearer_prefix( self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any], valid_access_token: AccessToken, - ): + ) -> None: """Test with mixed 'BeArEr' prefix in Authorization header""" backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token) @@ -246,7 +248,7 @@ async def test_mixed_case_authorization_header( self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any], valid_access_token: AccessToken, - ): + ) -> None: """Test authentication with mixed 'Authorization' header.""" backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token) @@ -267,7 +269,7 @@ async def test_mixed_case_authorization_header( class TestRequireAuthMiddleware: """Tests for the RequireAuthMiddleware class.""" - async def test_no_user(self): + async def test_no_user(self) -> None: """Test middleware with no user in scope.""" app = MockApp() middleware = RequireAuthMiddleware(app, required_scopes=["read"]) @@ -291,7 +293,7 @@ async def send(message: Message) -> None: assert any(h[0] == b"www-authenticate" for h in sent_messages[0]["headers"]) assert not app.called - async def test_non_authenticated_user(self): + async def test_non_authenticated_user(self) -> None: """Test middleware with non-authenticated user in scope.""" app = MockApp() middleware = RequireAuthMiddleware(app, required_scopes=["read"]) @@ -315,7 +317,7 @@ async def send(message: Message) -> None: assert any(h[0] == b"www-authenticate" for h in sent_messages[0]["headers"]) assert not app.called - async def test_missing_required_scope(self, valid_access_token: AccessToken): + async def test_missing_required_scope(self, valid_access_token: AccessToken) -> None: """Test middleware with user missing required scope.""" app = MockApp() middleware = RequireAuthMiddleware(app, required_scopes=["admin"]) @@ -344,7 +346,7 @@ async def send(message: Message) -> None: assert any(h[0] == b"www-authenticate" for h in sent_messages[0]["headers"]) assert not app.called - async def test_no_auth_credentials(self, valid_access_token: AccessToken): + async def test_no_auth_credentials(self, valid_access_token: AccessToken) -> None: """Test middleware with no auth credentials in scope.""" app = MockApp() middleware = RequireAuthMiddleware(app, required_scopes=["read"]) @@ -372,7 +374,7 @@ async def send(message: Message) -> None: assert any(h[0] == b"www-authenticate" for h in sent_messages[0]["headers"]) assert not app.called - async def test_has_required_scopes(self, valid_access_token: AccessToken): + async def test_has_required_scopes(self, valid_access_token: AccessToken) -> None: """Test middleware with user having all required scopes.""" app = MockApp() middleware = RequireAuthMiddleware(app, required_scopes=["read"]) @@ -397,7 +399,7 @@ async def send(message: Message) -> None: # pragma: no cover assert app.receive == receive assert app.send == send - async def test_multiple_required_scopes(self, valid_access_token: AccessToken): + async def test_multiple_required_scopes(self, valid_access_token: AccessToken) -> None: """Test middleware with multiple required scopes.""" app = MockApp() middleware = RequireAuthMiddleware(app, required_scopes=["read", "write"]) @@ -422,7 +424,7 @@ async def send(message: Message) -> None: # pragma: no cover assert app.receive == receive assert app.send == send - async def test_no_required_scopes(self, valid_access_token: AccessToken): + async def test_no_required_scopes(self, valid_access_token: AccessToken) -> None: """Test middleware with no required scopes.""" app = MockApp() middleware = RequireAuthMiddleware(app, required_scopes=[]) diff --git a/tests/server/auth/test_error_handling.py b/tests/server/auth/test_error_handling.py index 7c5c43582..977106078 100644 --- a/tests/server/auth/test_error_handling.py +++ b/tests/server/auth/test_error_handling.py @@ -20,13 +20,13 @@ @pytest.fixture -def oauth_provider(): +def oauth_provider() -> MockOAuthProvider: """Return a MockOAuthProvider instance that can be configured to raise errors.""" return MockOAuthProvider() @pytest.fixture -def app(oauth_provider: MockOAuthProvider): +def app(oauth_provider: MockOAuthProvider) -> Starlette: # Enable client registration client_registration_options = ClientRegistrationOptions(enabled=True) revocation_options = RevocationOptions(enabled=True) @@ -44,14 +44,14 @@ def app(oauth_provider: MockOAuthProvider): @pytest.fixture -def client(app: Starlette): +def client(app: Starlette) -> httpx.AsyncClient: transport = ASGITransport(app=app) # Use base_url without a path since routes are directly on the app return httpx.AsyncClient(transport=transport, base_url="http://localhost") @pytest.fixture -def pkce_challenge(): +def pkce_challenge() -> dict[str, str]: """Create a PKCE challenge with code_verifier and code_challenge.""" # Generate a code verifier code_verifier = secrets.token_urlsafe(64)[:128] @@ -84,7 +84,7 @@ async def registered_client(client: httpx.AsyncClient) -> dict[str, Any]: @pytest.mark.anyio -async def test_registration_error_handling(client: httpx.AsyncClient, oauth_provider: MockOAuthProvider): +async def test_registration_error_handling(client: httpx.AsyncClient, oauth_provider: MockOAuthProvider) -> None: # Mock the register_client method to raise a registration error with unittest.mock.patch.object( oauth_provider, @@ -122,7 +122,7 @@ async def test_authorize_error_handling( oauth_provider: MockOAuthProvider, registered_client: dict[str, Any], pkce_challenge: dict[str, str], -): +) -> None: # Mock the authorize method to raise an authorize error with unittest.mock.patch.object( oauth_provider, @@ -163,7 +163,7 @@ async def test_token_error_handling_auth_code( oauth_provider: MockOAuthProvider, registered_client: dict[str, Any], pkce_challenge: dict[str, str], -): +) -> None: # Register the client and get an auth code client_id = registered_client["client_id"] client_secret = registered_client["client_secret"] @@ -222,7 +222,7 @@ async def test_token_error_handling_refresh_token( oauth_provider: MockOAuthProvider, registered_client: dict[str, Any], pkce_challenge: dict[str, str], -): +) -> None: # Register the client and get tokens client_id = registered_client["client_id"] client_secret = registered_client["client_secret"] diff --git a/tests/server/auth/test_protected_resource.py b/tests/server/auth/test_protected_resource.py index 413a80276..35fd3de2a 100644 --- a/tests/server/auth/test_protected_resource.py +++ b/tests/server/auth/test_protected_resource.py @@ -1,5 +1,6 @@ """Integration tests for MCP Oauth Protected Resource.""" +from collections.abc import AsyncGenerator from urllib.parse import urlparse import httpx @@ -12,7 +13,7 @@ @pytest.fixture -def test_app(): +def test_app() -> Starlette: """Fixture to create protected resource routes for testing.""" # Create the protected resource routes @@ -29,14 +30,14 @@ def test_app(): @pytest.fixture -async def test_client(test_app: Starlette): +async def test_client(test_app: Starlette) -> AsyncGenerator[httpx.AsyncClient, None]: """Fixture to create an HTTP client for the protected resource app.""" async with httpx.AsyncClient(transport=httpx.ASGITransport(app=test_app), base_url="https://mcptest.com") as client: yield client @pytest.mark.anyio -async def test_metadata_endpoint_with_path(test_client: httpx.AsyncClient): +async def test_metadata_endpoint_with_path(test_client: httpx.AsyncClient) -> None: """Test the OAuth 2.0 Protected Resource metadata endpoint for path-based resource.""" # For resource with path "/resource", metadata should be accessible at the path-aware location @@ -54,7 +55,7 @@ async def test_metadata_endpoint_with_path(test_client: httpx.AsyncClient): @pytest.mark.anyio -async def test_metadata_endpoint_root_path_returns_404(test_client: httpx.AsyncClient): +async def test_metadata_endpoint_root_path_returns_404(test_client: httpx.AsyncClient) -> None: """Test that root path returns 404 for path-based resource.""" # Root path should return 404 for path-based resources @@ -63,7 +64,7 @@ async def test_metadata_endpoint_root_path_returns_404(test_client: httpx.AsyncC @pytest.fixture -def root_resource_app(): +def root_resource_app() -> Starlette: """Fixture to create protected resource routes for root-level resource.""" # Create routes for a resource without path component @@ -79,7 +80,7 @@ def root_resource_app(): @pytest.fixture -async def root_resource_client(root_resource_app: Starlette): +async def root_resource_client(root_resource_app: Starlette) -> AsyncGenerator[httpx.AsyncClient, None]: """Fixture to create an HTTP client for the root resource app.""" async with httpx.AsyncClient( transport=httpx.ASGITransport(app=root_resource_app), base_url="https://mcptest.com" @@ -88,7 +89,7 @@ async def root_resource_client(root_resource_app: Starlette): @pytest.mark.anyio -async def test_metadata_endpoint_without_path(root_resource_client: httpx.AsyncClient): +async def test_metadata_endpoint_without_path(root_resource_client: httpx.AsyncClient) -> None: """Test metadata endpoint for root-level resource.""" # For root resource, metadata should be at standard location @@ -108,21 +109,21 @@ async def test_metadata_endpoint_without_path(root_resource_client: httpx.AsyncC # Tests for URL construction utility function -def test_metadata_url_construction_url_without_path(): +def test_metadata_url_construction_url_without_path() -> None: """Test URL construction for resource without path component.""" resource_url = AnyHttpUrl("https://example.com") result = build_resource_metadata_url(resource_url) assert str(result) == "https://example.com/.well-known/oauth-protected-resource" -def test_metadata_url_construction_url_with_path_component(): +def test_metadata_url_construction_url_with_path_component() -> None: """Test URL construction for resource with path component.""" resource_url = AnyHttpUrl("https://example.com/mcp") result = build_resource_metadata_url(resource_url) assert str(result) == "https://example.com/.well-known/oauth-protected-resource/mcp" -def test_metadata_url_construction_url_with_trailing_slash_only(): +def test_metadata_url_construction_url_with_trailing_slash_only() -> None: """Test URL construction for resource with trailing slash only.""" resource_url = AnyHttpUrl("https://example.com/") result = build_resource_metadata_url(resource_url) @@ -139,7 +140,7 @@ def test_metadata_url_construction_url_with_trailing_slash_only(): ("http://localhost:8001/mcp", "http://localhost:8001/.well-known/oauth-protected-resource/mcp"), ], ) -def test_metadata_url_construction_various_resource_configurations(resource_url: str, expected_url: str): +def test_metadata_url_construction_various_resource_configurations(resource_url: str, expected_url: str) -> None: """Test URL construction with various resource configurations.""" result = build_resource_metadata_url(AnyHttpUrl(resource_url)) assert str(result) == expected_url @@ -148,7 +149,7 @@ def test_metadata_url_construction_various_resource_configurations(resource_url: # Tests for consistency between URL generation and route registration -def test_route_consistency_route_path_matches_metadata_url(): +def test_route_consistency_route_path_matches_metadata_url() -> None: """Test that route path matches the generated metadata URL.""" resource_url = AnyHttpUrl("https://example.com/mcp") @@ -177,7 +178,7 @@ def test_route_consistency_route_path_matches_metadata_url(): ("https://example.com/mcp", "/.well-known/oauth-protected-resource/mcp"), ], ) -def test_route_consistency_consistent_paths_for_various_resources(resource_url: str, expected_path: str): +def test_route_consistency_consistent_paths_for_various_resources(resource_url: str, expected_path: str) -> None: """Test that URL generation and route creation are consistent.""" resource_url_obj = AnyHttpUrl(resource_url) diff --git a/tests/server/auth/test_provider.py b/tests/server/auth/test_provider.py index aaaeb413a..b71b6ff5b 100644 --- a/tests/server/auth/test_provider.py +++ b/tests/server/auth/test_provider.py @@ -3,7 +3,7 @@ from mcp.server.auth.provider import construct_redirect_uri -def test_construct_redirect_uri_no_existing_params(): +def test_construct_redirect_uri_no_existing_params() -> None: """Test construct_redirect_uri with no existing query parameters.""" base_uri = "http://localhost:8000/callback" result = construct_redirect_uri(base_uri, code="auth_code", state="test_state") @@ -11,7 +11,7 @@ def test_construct_redirect_uri_no_existing_params(): assert "http://localhost:8000/callback?code=auth_code&state=test_state" == result -def test_construct_redirect_uri_with_existing_params(): +def test_construct_redirect_uri_with_existing_params() -> None: """Test construct_redirect_uri with existing query parameters (regression test for #1279).""" base_uri = "http://localhost:8000/callback?session_id=1234" result = construct_redirect_uri(base_uri, code="auth_code", state="test_state") @@ -23,7 +23,7 @@ def test_construct_redirect_uri_with_existing_params(): assert result.startswith("http://localhost:8000/callback?") -def test_construct_redirect_uri_multiple_existing_params(): +def test_construct_redirect_uri_multiple_existing_params() -> None: """Test construct_redirect_uri with multiple existing query parameters.""" base_uri = "http://localhost:8000/callback?session_id=1234&user=test" result = construct_redirect_uri(base_uri, code="auth_code") @@ -33,7 +33,7 @@ def test_construct_redirect_uri_multiple_existing_params(): assert "code=auth_code" in result -def test_construct_redirect_uri_with_none_values(): +def test_construct_redirect_uri_with_none_values() -> None: """Test construct_redirect_uri filters out None values.""" base_uri = "http://localhost:8000/callback" result = construct_redirect_uri(base_uri, code="auth_code", state=None) @@ -42,7 +42,7 @@ def test_construct_redirect_uri_with_none_values(): assert "state" not in result -def test_construct_redirect_uri_empty_params(): +def test_construct_redirect_uri_empty_params() -> None: """Test construct_redirect_uri with no additional parameters.""" base_uri = "http://localhost:8000/callback?existing=param" result = construct_redirect_uri(base_uri) @@ -50,7 +50,7 @@ def test_construct_redirect_uri_empty_params(): assert result == "http://localhost:8000/callback?existing=param" -def test_construct_redirect_uri_duplicate_param_names(): +def test_construct_redirect_uri_duplicate_param_names() -> None: """Test construct_redirect_uri when adding param that already exists.""" base_uri = "http://localhost:8000/callback?code=existing" result = construct_redirect_uri(base_uri, code="new_code") @@ -60,7 +60,7 @@ def test_construct_redirect_uri_duplicate_param_names(): assert "code=new_code" in result -def test_construct_redirect_uri_multivalued_existing_params(): +def test_construct_redirect_uri_multivalued_existing_params() -> None: """Test construct_redirect_uri with existing multi-valued parameters.""" base_uri = "http://localhost:8000/callback?scope=read&scope=write" result = construct_redirect_uri(base_uri, code="auth_code") @@ -70,7 +70,7 @@ def test_construct_redirect_uri_multivalued_existing_params(): assert "code=auth_code" in result -def test_construct_redirect_uri_encoded_values(): +def test_construct_redirect_uri_encoded_values() -> None: """Test construct_redirect_uri handles URL encoding properly.""" base_uri = "http://localhost:8000/callback" result = construct_redirect_uri(base_uri, state="test state with spaces") diff --git a/tests/server/auth/test_routes.py b/tests/server/auth/test_routes.py index 3d13b5ba5..a910edc8b 100644 --- a/tests/server/auth/test_routes.py +++ b/tests/server/auth/test_routes.py @@ -4,44 +4,44 @@ from mcp.server.auth.routes import validate_issuer_url -def test_validate_issuer_url_https_allowed(): +def test_validate_issuer_url_https_allowed() -> None: validate_issuer_url(AnyHttpUrl("https://example.com/path")) -def test_validate_issuer_url_http_localhost_allowed(): +def test_validate_issuer_url_http_localhost_allowed() -> None: validate_issuer_url(AnyHttpUrl("http://localhost:8080/path")) -def test_validate_issuer_url_http_127_0_0_1_allowed(): +def test_validate_issuer_url_http_127_0_0_1_allowed() -> None: validate_issuer_url(AnyHttpUrl("http://127.0.0.1:8080/path")) -def test_validate_issuer_url_http_ipv6_loopback_allowed(): +def test_validate_issuer_url_http_ipv6_loopback_allowed() -> None: validate_issuer_url(AnyHttpUrl("http://[::1]:8080/path")) -def test_validate_issuer_url_http_non_loopback_rejected(): +def test_validate_issuer_url_http_non_loopback_rejected() -> None: with pytest.raises(ValueError, match="Issuer URL must be HTTPS"): validate_issuer_url(AnyHttpUrl("http://evil.com/path")) -def test_validate_issuer_url_http_127_prefix_domain_rejected(): +def test_validate_issuer_url_http_127_prefix_domain_rejected() -> None: """A domain like 127.0.0.1.evil.com is not loopback.""" with pytest.raises(ValueError, match="Issuer URL must be HTTPS"): validate_issuer_url(AnyHttpUrl("http://127.0.0.1.evil.com/path")) -def test_validate_issuer_url_http_127_prefix_subdomain_rejected(): +def test_validate_issuer_url_http_127_prefix_subdomain_rejected() -> None: """A domain like 127.0.0.1something.example.com is not loopback.""" with pytest.raises(ValueError, match="Issuer URL must be HTTPS"): validate_issuer_url(AnyHttpUrl("http://127.0.0.1something.example.com/path")) -def test_validate_issuer_url_fragment_rejected(): +def test_validate_issuer_url_fragment_rejected() -> None: with pytest.raises(ValueError, match="fragment"): validate_issuer_url(AnyHttpUrl("https://example.com/path#frag")) -def test_validate_issuer_url_query_rejected(): +def test_validate_issuer_url_query_rejected() -> None: with pytest.raises(ValueError, match="query"): validate_issuer_url(AnyHttpUrl("https://example.com/path?q=1")) diff --git a/tests/server/lowlevel/test_helper_types.py b/tests/server/lowlevel/test_helper_types.py index e29273d3f..d1a8ca0a4 100644 --- a/tests/server/lowlevel/test_helper_types.py +++ b/tests/server/lowlevel/test_helper_types.py @@ -10,7 +10,7 @@ from mcp.server.lowlevel.helper_types import ReadResourceContents -def test_read_resource_contents_with_metadata(): +def test_read_resource_contents_with_metadata() -> None: """Test that ReadResourceContents accepts meta parameter. ReadResourceContents is an internal helper type used by the low-level MCP server. @@ -33,7 +33,7 @@ def test_read_resource_contents_with_metadata(): assert contents.meta["cached"] is True -def test_read_resource_contents_without_metadata(): +def test_read_resource_contents_without_metadata() -> None: """Test that ReadResourceContents meta defaults to None.""" # Ensures backward compatibility - meta defaults to None, _meta omitted from protocol (helper_types.py:11) contents = ReadResourceContents( @@ -44,7 +44,7 @@ def test_read_resource_contents_without_metadata(): assert contents.meta is None -def test_read_resource_contents_with_bytes(): +def test_read_resource_contents_with_bytes() -> None: """Test that ReadResourceContents works with bytes content and meta.""" # Verifies meta works with both str and bytes content (binary resources like images, PDFs) metadata = {"encoding": "utf-8"} diff --git a/tests/server/mcpserver/auth/test_auth_integration.py b/tests/server/mcpserver/auth/test_auth_integration.py index 602f5cc75..66ce88b2d 100644 --- a/tests/server/mcpserver/auth/test_auth_integration.py +++ b/tests/server/mcpserver/auth/test_auth_integration.py @@ -5,6 +5,7 @@ import secrets import time import unittest.mock +from collections.abc import AsyncGenerator from typing import Any from urllib.parse import parse_qs, urlparse @@ -28,7 +29,7 @@ # Mock OAuth provider for testing class MockOAuthProvider(OAuthAuthorizationServerProvider[AuthorizationCode, RefreshToken, AccessToken]): - def __init__(self): + def __init__(self) -> None: self.clients: dict[str, OAuthClientInformationFull] = {} self.auth_codes: dict[str, AuthorizationCode] = {} # code -> {client_id, code_challenge, redirect_uri} self.tokens: dict[str, AccessToken] = {} # token -> {client_id, scopes, expires_at} @@ -37,7 +38,7 @@ def __init__(self): async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: return self.clients.get(client_id) - async def register_client(self, client_info: OAuthClientInformationFull): + async def register_client(self, client_info: OAuthClientInformationFull) -> None: assert client_info.client_id is not None self.clients[client_info.client_id] = client_info @@ -188,12 +189,12 @@ async def revoke_token(self, token: AccessToken | RefreshToken) -> None: @pytest.fixture -def mock_oauth_provider(): +def mock_oauth_provider() -> MockOAuthProvider: return MockOAuthProvider() @pytest.fixture -def auth_app(mock_oauth_provider: MockOAuthProvider): +def auth_app(mock_oauth_provider: MockOAuthProvider) -> Starlette: # Create auth router auth_routes = create_auth_routes( mock_oauth_provider, @@ -214,7 +215,7 @@ def auth_app(mock_oauth_provider: MockOAuthProvider): @pytest.fixture -async def test_client(auth_app: Starlette): +async def test_client(auth_app: Starlette) -> AsyncGenerator[httpx.AsyncClient, None]: async with httpx.AsyncClient(transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com") as client: yield client @@ -249,7 +250,7 @@ async def registered_client( @pytest.fixture -def pkce_challenge(): +def pkce_challenge() -> dict[str, str]: """Create a PKCE challenge with code_verifier and code_challenge.""" code_verifier = "some_random_verifier_string" code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).decode().rstrip("=") @@ -263,7 +264,7 @@ async def auth_code( registered_client: dict[str, Any], pkce_challenge: dict[str, str], request: pytest.FixtureRequest, -): +) -> dict[str, str | None]: """Get an authorization code. Parameters can be customized via indirect parameterization: @@ -305,7 +306,7 @@ async def auth_code( class TestAuthEndpoints: @pytest.mark.anyio - async def test_metadata_endpoint(self, test_client: httpx.AsyncClient): + async def test_metadata_endpoint(self, test_client: httpx.AsyncClient) -> None: """Test the OAuth 2.0 metadata endpoint.""" response = await test_client.get("/.well-known/oauth-authorization-server") @@ -327,7 +328,7 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient): assert metadata["service_documentation"] == "https://docs.example.com/" @pytest.mark.anyio - async def test_token_validation_error(self, test_client: httpx.AsyncClient): + async def test_token_validation_error(self, test_client: httpx.AsyncClient) -> None: """Test token endpoint error - validation error.""" # Missing required fields response = await test_client.post( @@ -350,7 +351,7 @@ async def test_token_invalid_client_secret_returns_invalid_client( registered_client: dict[str, Any], pkce_challenge: dict[str, str], mock_oauth_provider: MockOAuthProvider, - ): + ) -> None: """Test token endpoint returns 'invalid_client' for wrong client_secret per RFC 6749. RFC 6749 Section 5.2 defines: @@ -397,7 +398,7 @@ async def test_token_invalid_auth_code( test_client: httpx.AsyncClient, registered_client: dict[str, Any], pkce_challenge: dict[str, str], - ): + ) -> None: """Test token endpoint error - authorization code does not exist.""" # Try to use a non-existent authorization code response = await test_client.post( @@ -425,7 +426,7 @@ async def test_token_expired_auth_code( auth_code: dict[str, str], pkce_challenge: dict[str, str], mock_oauth_provider: MockOAuthProvider, - ): + ) -> None: """Test token endpoint error - authorization code has expired.""" # Get the current time for our time mocking current_time = time.time() @@ -479,7 +480,7 @@ async def test_token_redirect_uri_mismatch( registered_client: dict[str, Any], auth_code: dict[str, str], pkce_challenge: dict[str, str], - ): + ) -> None: """Test token endpoint error - redirect URI mismatch.""" # Try to use the code with a different redirect URI response = await test_client.post( @@ -502,7 +503,7 @@ async def test_token_redirect_uri_mismatch( @pytest.mark.anyio async def test_token_code_verifier_mismatch( self, test_client: httpx.AsyncClient, registered_client: dict[str, Any], auth_code: dict[str, str] - ): + ) -> None: """Test token endpoint error - PKCE code verifier mismatch.""" # Try to use the code with an incorrect code verifier response = await test_client.post( @@ -523,7 +524,9 @@ async def test_token_code_verifier_mismatch( assert "incorrect code_verifier" in error_response["error_description"] @pytest.mark.anyio - async def test_token_invalid_refresh_token(self, test_client: httpx.AsyncClient, registered_client: dict[str, Any]): + async def test_token_invalid_refresh_token( + self, test_client: httpx.AsyncClient, registered_client: dict[str, Any] + ) -> None: """Test token endpoint error - refresh token does not exist.""" # Try to use a non-existent refresh token response = await test_client.post( @@ -547,7 +550,7 @@ async def test_token_expired_refresh_token( registered_client: dict[str, Any], auth_code: dict[str, str], pkce_challenge: dict[str, str], - ): + ) -> None: """Test token endpoint error - refresh token has expired.""" # Step 1: First, let's create a token and refresh token at the current time current_time = time.time() @@ -595,7 +598,7 @@ async def test_token_invalid_scope( registered_client: dict[str, Any], auth_code: dict[str, str], pkce_challenge: dict[str, str], - ): + ) -> None: """Test token endpoint error - invalid scope in refresh token request.""" # Exchange authorization code for tokens token_response = await test_client.post( @@ -631,7 +634,9 @@ async def test_token_invalid_scope( assert "cannot request scope" in error_response["error_description"] @pytest.mark.anyio - async def test_client_registration(self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider): + async def test_client_registration( + self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider + ) -> None: """Test client registration.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -657,7 +662,7 @@ async def test_client_registration(self, test_client: httpx.AsyncClient, mock_oa # ) is not None @pytest.mark.anyio - async def test_client_registration_missing_required_fields(self, test_client: httpx.AsyncClient): + async def test_client_registration_missing_required_fields(self, test_client: httpx.AsyncClient) -> None: """Test client registration with missing required fields.""" # Missing redirect_uris which is a required field client_metadata = { @@ -676,7 +681,7 @@ async def test_client_registration_missing_required_fields(self, test_client: ht assert error_data["error_description"] == "redirect_uris: Field required" @pytest.mark.anyio - async def test_client_registration_invalid_uri(self, test_client: httpx.AsyncClient): + async def test_client_registration_invalid_uri(self, test_client: httpx.AsyncClient) -> None: """Test client registration with invalid URIs.""" # Invalid redirect_uri format client_metadata = { @@ -697,7 +702,7 @@ async def test_client_registration_invalid_uri(self, test_client: httpx.AsyncCli ) @pytest.mark.anyio - async def test_client_registration_empty_redirect_uris(self, test_client: httpx.AsyncClient): + async def test_client_registration_empty_redirect_uris(self, test_client: httpx.AsyncClient) -> None: """Test client registration with empty redirect_uris array.""" redirect_uris: list[str] = [] client_metadata = { @@ -718,7 +723,7 @@ async def test_client_registration_empty_redirect_uris(self, test_client: httpx. ) @pytest.mark.anyio - async def test_authorize_form_post(self, test_client: httpx.AsyncClient, pkce_challenge: dict[str, str]): + async def test_authorize_form_post(self, test_client: httpx.AsyncClient, pkce_challenge: dict[str, str]) -> None: """Test the authorization endpoint using POST with form-encoded data.""" # Register a client client_metadata = { @@ -762,7 +767,7 @@ async def test_authorization_get( test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str], - ): + ) -> None: """Test the full authorization flow.""" # 1. Register a client client_metadata = { @@ -867,7 +872,9 @@ async def test_authorization_get( assert await mock_oauth_provider.load_access_token(new_token_response["access_token"]) is None @pytest.mark.anyio - async def test_revoke_invalid_token(self, test_client: httpx.AsyncClient, registered_client: dict[str, Any]): + async def test_revoke_invalid_token( + self, test_client: httpx.AsyncClient, registered_client: dict[str, Any] + ) -> None: """Test revoking an invalid token.""" response = await test_client.post( "/revoke", @@ -881,7 +888,9 @@ async def test_revoke_invalid_token(self, test_client: httpx.AsyncClient, regist assert response.status_code == 200 @pytest.mark.anyio - async def test_revoke_with_malformed_token(self, test_client: httpx.AsyncClient, registered_client: dict[str, Any]): + async def test_revoke_with_malformed_token( + self, test_client: httpx.AsyncClient, registered_client: dict[str, Any] + ) -> None: response = await test_client.post( "/revoke", data={ @@ -897,7 +906,7 @@ async def test_revoke_with_malformed_token(self, test_client: httpx.AsyncClient, assert "token_type_hint" in error_response["error_description"] @pytest.mark.anyio - async def test_client_registration_disallowed_scopes(self, test_client: httpx.AsyncClient): + async def test_client_registration_disallowed_scopes(self, test_client: httpx.AsyncClient) -> None: """Test client registration with scopes that are not allowed.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -916,7 +925,7 @@ async def test_client_registration_disallowed_scopes(self, test_client: httpx.As @pytest.mark.anyio async def test_client_registration_default_scopes( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider - ): + ) -> None: client_metadata = { "redirect_uris": ["https://client.example.com/callback"], "client_name": "Test Client", @@ -938,7 +947,7 @@ async def test_client_registration_default_scopes( assert registered_client.scope == "read write" @pytest.mark.anyio - async def test_client_registration_with_authorization_code_only(self, test_client: httpx.AsyncClient): + async def test_client_registration_with_authorization_code_only(self, test_client: httpx.AsyncClient) -> None: """Test that registration succeeds with only authorization_code (refresh_token is optional per RFC 7591).""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -953,7 +962,7 @@ async def test_client_registration_with_authorization_code_only(self, test_clien assert client_info["grant_types"] == ["authorization_code"] @pytest.mark.anyio - async def test_client_registration_missing_authorization_code(self, test_client: httpx.AsyncClient): + async def test_client_registration_missing_authorization_code(self, test_client: httpx.AsyncClient) -> None: """Test that registration fails when authorization_code grant type is missing.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -969,7 +978,7 @@ async def test_client_registration_missing_authorization_code(self, test_client: assert error_data["error_description"] == "grant_types must include 'authorization_code'" @pytest.mark.anyio - async def test_client_registration_with_additional_grant_type(self, test_client: httpx.AsyncClient): + async def test_client_registration_with_additional_grant_type(self, test_client: httpx.AsyncClient) -> None: client_metadata = { "redirect_uris": ["https://client.example.com/callback"], "client_name": "Test Client", @@ -988,7 +997,7 @@ async def test_client_registration_with_additional_grant_type(self, test_client: @pytest.mark.anyio async def test_client_registration_with_additional_response_types( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider - ): + ) -> None: """Test that registration accepts additional response_types values alongside 'code'.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -1006,7 +1015,7 @@ async def test_client_registration_with_additional_response_types( assert "code" in client.response_types @pytest.mark.anyio - async def test_client_registration_response_types_without_code(self, test_client: httpx.AsyncClient): + async def test_client_registration_response_types_without_code(self, test_client: httpx.AsyncClient) -> None: """Test that registration rejects response_types that don't include 'code'.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -1025,7 +1034,7 @@ async def test_client_registration_response_types_without_code(self, test_client @pytest.mark.anyio async def test_client_registration_default_response_types( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider - ): + ) -> None: """Test that registration uses default response_types of ['code'] when not specified.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -1044,7 +1053,7 @@ async def test_client_registration_default_response_types( @pytest.mark.anyio async def test_client_secret_basic_authentication( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] - ): + ) -> None: """Test that client_secret_basic authentication works correctly.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -1090,7 +1099,7 @@ async def test_client_secret_basic_authentication( @pytest.mark.anyio async def test_wrong_auth_method_without_valid_credentials_fails( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] - ): + ) -> None: """Test that using the wrong authentication method fails when credentials are missing.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -1142,7 +1151,7 @@ async def test_wrong_auth_method_without_valid_credentials_fails( @pytest.mark.anyio async def test_basic_auth_without_header_fails( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] - ): + ) -> None: """Test that omitting Basic auth when client_secret_basic is registered fails.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -1187,7 +1196,7 @@ async def test_basic_auth_without_header_fails( @pytest.mark.anyio async def test_basic_auth_invalid_base64_fails( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] - ): + ) -> None: """Test that invalid base64 in Basic auth header fails.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -1232,7 +1241,7 @@ async def test_basic_auth_invalid_base64_fails( @pytest.mark.anyio async def test_basic_auth_no_colon_fails( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] - ): + ) -> None: """Test that Basic auth without colon separator fails.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -1278,7 +1287,7 @@ async def test_basic_auth_no_colon_fails( @pytest.mark.anyio async def test_basic_auth_client_id_mismatch_fails( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] - ): + ) -> None: """Test that client_id mismatch between body and Basic auth fails.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -1324,7 +1333,7 @@ async def test_basic_auth_client_id_mismatch_fails( @pytest.mark.anyio async def test_none_auth_method_public_client( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] - ): + ) -> None: """Test that 'none' authentication method works for public clients.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -1371,7 +1380,9 @@ class TestAuthorizeEndpointErrors: """Test error handling in the OAuth authorization endpoint.""" @pytest.mark.anyio - async def test_authorize_missing_client_id(self, test_client: httpx.AsyncClient, pkce_challenge: dict[str, str]): + async def test_authorize_missing_client_id( + self, test_client: httpx.AsyncClient, pkce_challenge: dict[str, str] + ) -> None: """Test authorization endpoint with missing client_id. According to the OAuth2.0 spec, if client_id is missing, the server should @@ -1395,7 +1406,9 @@ async def test_authorize_missing_client_id(self, test_client: httpx.AsyncClient, assert "client_id" in response.text.lower() @pytest.mark.anyio - async def test_authorize_invalid_client_id(self, test_client: httpx.AsyncClient, pkce_challenge: dict[str, str]): + async def test_authorize_invalid_client_id( + self, test_client: httpx.AsyncClient, pkce_challenge: dict[str, str] + ) -> None: """Test authorization endpoint with invalid client_id. According to the OAuth2.0 spec, if client_id is invalid, the server should @@ -1421,7 +1434,7 @@ async def test_authorize_invalid_client_id(self, test_client: httpx.AsyncClient, @pytest.mark.anyio async def test_authorize_missing_redirect_uri( self, test_client: httpx.AsyncClient, registered_client: dict[str, Any], pkce_challenge: dict[str, str] - ): + ) -> None: """Test authorization endpoint with missing redirect_uri. If client has only one registered redirect_uri, it can be omitted. @@ -1447,7 +1460,7 @@ async def test_authorize_missing_redirect_uri( @pytest.mark.anyio async def test_authorize_invalid_redirect_uri( self, test_client: httpx.AsyncClient, registered_client: dict[str, Any], pkce_challenge: dict[str, str] - ): + ) -> None: """Test authorization endpoint with invalid redirect_uri. According to the OAuth2.0 spec, if redirect_uri is invalid or doesn't match, @@ -1487,7 +1500,7 @@ async def test_authorize_invalid_redirect_uri( ) async def test_authorize_missing_redirect_uri_multiple_registered( self, test_client: httpx.AsyncClient, registered_client: dict[str, Any], pkce_challenge: dict[str, str] - ): + ) -> None: """Test endpoint with missing redirect_uri with multiple registered URIs. If client has multiple registered redirect_uris, redirect_uri must be provided. @@ -1513,7 +1526,7 @@ async def test_authorize_missing_redirect_uri_multiple_registered( @pytest.mark.anyio async def test_authorize_unsupported_response_type( self, test_client: httpx.AsyncClient, registered_client: dict[str, Any], pkce_challenge: dict[str, str] - ): + ) -> None: """Test authorization endpoint with unsupported response_type. According to the OAuth2.0 spec, for other errors like unsupported_response_type, @@ -1547,7 +1560,7 @@ async def test_authorize_unsupported_response_type( @pytest.mark.anyio async def test_authorize_missing_response_type( self, test_client: httpx.AsyncClient, registered_client: dict[str, Any], pkce_challenge: dict[str, str] - ): + ) -> None: """Test authorization endpoint with missing response_type. Missing required parameter should result in invalid_request error. @@ -1580,7 +1593,7 @@ async def test_authorize_missing_response_type( @pytest.mark.anyio async def test_authorize_missing_pkce_challenge( self, test_client: httpx.AsyncClient, registered_client: dict[str, Any] - ): + ) -> None: """Test authorization endpoint with missing PKCE code_challenge. Missing PKCE parameters should result in invalid_request error. @@ -1611,7 +1624,7 @@ async def test_authorize_missing_pkce_challenge( @pytest.mark.anyio async def test_authorize_invalid_scope( self, test_client: httpx.AsyncClient, registered_client: dict[str, Any], pkce_challenge: dict[str, str] - ): + ) -> None: """Test authorization endpoint with invalid scope. Invalid scope should redirect with invalid_scope error. diff --git a/tests/server/mcpserver/prompts/test_base.py b/tests/server/mcpserver/prompts/test_base.py index fe18e91bd..19fc5130c 100644 --- a/tests/server/mcpserver/prompts/test_base.py +++ b/tests/server/mcpserver/prompts/test_base.py @@ -9,7 +9,7 @@ class TestRenderPrompt: @pytest.mark.anyio - async def test_basic_fn(self): + async def test_basic_fn(self) -> None: def fn() -> str: return "Hello, world!" @@ -19,7 +19,7 @@ def fn() -> str: ] @pytest.mark.anyio - async def test_async_fn(self): + async def test_async_fn(self) -> None: async def fn() -> str: return "Hello, world!" @@ -29,7 +29,7 @@ async def fn() -> str: ] @pytest.mark.anyio - async def test_fn_with_args(self): + async def test_fn_with_args(self) -> None: async def fn(name: str, age: int = 30) -> str: return f"Hello, {name}! You're {age} years old." @@ -39,7 +39,7 @@ async def fn(name: str, age: int = 30) -> str: ] @pytest.mark.anyio - async def test_fn_with_invalid_kwargs(self): + async def test_fn_with_invalid_kwargs(self) -> None: async def fn(name: str, age: int = 30) -> str: # pragma: no cover return f"Hello, {name}! You're {age} years old." @@ -48,7 +48,7 @@ async def fn(name: str, age: int = 30) -> str: # pragma: no cover await prompt.render({"age": 40}, Context()) @pytest.mark.anyio - async def test_fn_returns_message(self): + async def test_fn_returns_message(self) -> None: async def fn() -> UserMessage: return UserMessage(content="Hello, world!") @@ -58,7 +58,7 @@ async def fn() -> UserMessage: ] @pytest.mark.anyio - async def test_fn_returns_assistant_message(self): + async def test_fn_returns_assistant_message(self) -> None: async def fn() -> AssistantMessage: return AssistantMessage(content=TextContent(type="text", text="Hello, world!")) @@ -68,7 +68,7 @@ async def fn() -> AssistantMessage: ] @pytest.mark.anyio - async def test_fn_returns_multiple_messages(self): + async def test_fn_returns_multiple_messages(self) -> None: expected: list[Message] = [ UserMessage("Hello, world!"), AssistantMessage("How can I help you today?"), @@ -82,7 +82,7 @@ async def fn() -> list[Message]: assert await prompt.render(None, Context()) == expected @pytest.mark.anyio - async def test_fn_returns_list_of_strings(self): + async def test_fn_returns_list_of_strings(self) -> None: expected = [ "Hello, world!", "I'm looking for a restaurant in the center of town.", @@ -95,7 +95,7 @@ async def fn() -> list[str]: assert await prompt.render(None, Context()) == [UserMessage(t) for t in expected] @pytest.mark.anyio - async def test_fn_returns_resource_content(self): + async def test_fn_returns_resource_content(self) -> None: """Test returning a message with resource content.""" async def fn() -> UserMessage: @@ -125,7 +125,7 @@ async def fn() -> UserMessage: ] @pytest.mark.anyio - async def test_fn_returns_mixed_content(self): + async def test_fn_returns_mixed_content(self) -> None: """Test returning messages with mixed content types.""" async def fn() -> list[Message]: @@ -161,7 +161,7 @@ async def fn() -> list[Message]: ] @pytest.mark.anyio - async def test_fn_returns_dict_with_resource(self): + async def test_fn_returns_dict_with_resource(self) -> None: """Test returning a dict with resource content.""" async def fn() -> dict[str, Any]: diff --git a/tests/server/mcpserver/prompts/test_manager.py b/tests/server/mcpserver/prompts/test_manager.py index 99a03db56..9a41931ab 100644 --- a/tests/server/mcpserver/prompts/test_manager.py +++ b/tests/server/mcpserver/prompts/test_manager.py @@ -7,7 +7,7 @@ class TestPromptManager: - def test_add_prompt(self): + def test_add_prompt(self) -> None: """Test adding a prompt to the manager.""" def fn() -> str: # pragma: no cover @@ -19,7 +19,7 @@ def fn() -> str: # pragma: no cover assert added == prompt assert manager.get_prompt("fn") == prompt - def test_add_duplicate_prompt(self, caplog: pytest.LogCaptureFixture): + def test_add_duplicate_prompt(self, caplog: pytest.LogCaptureFixture) -> None: """Test adding the same prompt twice.""" def fn() -> str: # pragma: no cover @@ -32,7 +32,7 @@ def fn() -> str: # pragma: no cover assert first == second assert "Prompt already exists" in caplog.text - def test_disable_warn_on_duplicate_prompts(self, caplog: pytest.LogCaptureFixture): + def test_disable_warn_on_duplicate_prompts(self, caplog: pytest.LogCaptureFixture) -> None: """Test disabling warning on duplicate prompts.""" def fn() -> str: # pragma: no cover @@ -45,7 +45,7 @@ def fn() -> str: # pragma: no cover assert first == second assert "Prompt already exists" not in caplog.text - def test_list_prompts(self): + def test_list_prompts(self) -> None: """Test listing all prompts.""" def fn1() -> str: # pragma: no cover @@ -64,7 +64,7 @@ def fn2() -> str: # pragma: no cover assert prompts == [prompt1, prompt2] @pytest.mark.anyio - async def test_render_prompt(self): + async def test_render_prompt(self) -> None: """Test rendering a prompt.""" def fn() -> str: @@ -77,7 +77,7 @@ def fn() -> str: assert messages == [UserMessage(content=TextContent(type="text", text="Hello, world!"))] @pytest.mark.anyio - async def test_render_prompt_with_args(self): + async def test_render_prompt_with_args(self) -> None: """Test rendering a prompt with arguments.""" def fn(name: str) -> str: @@ -90,14 +90,14 @@ def fn(name: str) -> str: assert messages == [UserMessage(content=TextContent(type="text", text="Hello, World!"))] @pytest.mark.anyio - async def test_render_unknown_prompt(self): + async def test_render_unknown_prompt(self) -> None: """Test rendering a non-existent prompt.""" manager = PromptManager() with pytest.raises(ValueError, match="Unknown prompt: unknown"): await manager.render_prompt("unknown", None, Context()) @pytest.mark.anyio - async def test_render_prompt_with_missing_args(self): + async def test_render_prompt_with_missing_args(self) -> None: """Test rendering a prompt with missing required arguments.""" def fn(name: str) -> str: # pragma: no cover diff --git a/tests/server/mcpserver/resources/test_file_resources.py b/tests/server/mcpserver/resources/test_file_resources.py index 94885113a..26ce4e475 100644 --- a/tests/server/mcpserver/resources/test_file_resources.py +++ b/tests/server/mcpserver/resources/test_file_resources.py @@ -1,4 +1,5 @@ import os +from collections.abc import Generator from pathlib import Path from tempfile import NamedTemporaryFile @@ -8,7 +9,7 @@ @pytest.fixture -def temp_file(): +def temp_file() -> Generator[Path, None, None]: """Create a temporary file for testing. File is automatically cleaned up after the test if it still exists. @@ -27,7 +28,7 @@ def temp_file(): class TestFileResource: """Test FileResource functionality.""" - def test_file_resource_creation(self, temp_file: Path): + def test_file_resource_creation(self, temp_file: Path) -> None: """Test creating a FileResource.""" resource = FileResource( uri=temp_file.as_uri(), @@ -42,7 +43,7 @@ def test_file_resource_creation(self, temp_file: Path): assert resource.path == temp_file assert resource.is_binary is False # default - def test_file_resource_str_path_conversion(self, temp_file: Path): + def test_file_resource_str_path_conversion(self, temp_file: Path) -> None: """Test FileResource handles string paths.""" resource = FileResource( uri=f"file://{temp_file}", @@ -53,7 +54,7 @@ def test_file_resource_str_path_conversion(self, temp_file: Path): assert resource.path.is_absolute() @pytest.mark.anyio - async def test_read_text_file(self, temp_file: Path): + async def test_read_text_file(self, temp_file: Path) -> None: """Test reading a text file.""" resource = FileResource( uri=f"file://{temp_file}", @@ -65,7 +66,7 @@ async def test_read_text_file(self, temp_file: Path): assert resource.mime_type == "text/plain" @pytest.mark.anyio - async def test_read_binary_file(self, temp_file: Path): + async def test_read_binary_file(self, temp_file: Path) -> None: """Test reading a file as binary.""" resource = FileResource( uri=f"file://{temp_file}", @@ -77,7 +78,7 @@ async def test_read_binary_file(self, temp_file: Path): assert isinstance(content, bytes) assert content == b"test content" - def test_relative_path_error(self): + def test_relative_path_error(self) -> None: """Test error on relative path.""" with pytest.raises(ValueError, match="Path must be absolute"): FileResource( @@ -87,7 +88,7 @@ def test_relative_path_error(self): ) @pytest.mark.anyio - async def test_missing_file_error(self, temp_file: Path): + async def test_missing_file_error(self, temp_file: Path) -> None: """Test error when file doesn't exist.""" # Create path to non-existent file missing = temp_file.parent / "missing.txt" @@ -101,7 +102,7 @@ async def test_missing_file_error(self, temp_file: Path): @pytest.mark.skipif(os.name == "nt", reason="File permissions behave differently on Windows") @pytest.mark.anyio - async def test_permission_error(self, temp_file: Path): # pragma: lax no cover + async def test_permission_error(self, temp_file: Path) -> None: # pragma: lax no cover """Test reading a file without permissions.""" temp_file.chmod(0o000) # Remove all permissions try: diff --git a/tests/server/mcpserver/resources/test_function_resources.py b/tests/server/mcpserver/resources/test_function_resources.py index 5f5c216ed..8cc498530 100644 --- a/tests/server/mcpserver/resources/test_function_resources.py +++ b/tests/server/mcpserver/resources/test_function_resources.py @@ -7,7 +7,7 @@ class TestFunctionResource: """Test FunctionResource functionality.""" - def test_function_resource_creation(self): + def test_function_resource_creation(self) -> None: """Test creating a FunctionResource.""" def my_func() -> str: # pragma: no cover @@ -26,7 +26,7 @@ def my_func() -> str: # pragma: no cover assert resource.fn == my_func @pytest.mark.anyio - async def test_read_text(self): + async def test_read_text(self) -> None: """Test reading text from a FunctionResource.""" def get_data() -> str: @@ -42,7 +42,7 @@ def get_data() -> str: assert resource.mime_type == "text/plain" @pytest.mark.anyio - async def test_read_binary(self): + async def test_read_binary(self) -> None: """Test reading binary data from a FunctionResource.""" def get_data() -> bytes: @@ -57,7 +57,7 @@ def get_data() -> bytes: assert content == b"Hello, world!" @pytest.mark.anyio - async def test_json_conversion(self): + async def test_json_conversion(self) -> None: """Test automatic JSON conversion of non-string results.""" def get_data() -> dict[str, str]: @@ -73,7 +73,7 @@ def get_data() -> dict[str, str]: assert '"key": "value"' in content @pytest.mark.anyio - async def test_error_handling(self): + async def test_error_handling(self) -> None: """Test error handling in FunctionResource.""" def failing_func() -> str: @@ -88,7 +88,7 @@ def failing_func() -> str: await resource.read() @pytest.mark.anyio - async def test_basemodel_conversion(self): + async def test_basemodel_conversion(self) -> None: """Test handling of BaseModel types.""" class MyModel(BaseModel): @@ -103,7 +103,7 @@ class MyModel(BaseModel): assert content == '{\n "name": "test"\n}' @pytest.mark.anyio - async def test_custom_type_conversion(self): + async def test_custom_type_conversion(self) -> None: """Test handling of custom types.""" class CustomData: @@ -122,7 +122,7 @@ def get_data() -> CustomData: assert isinstance(content, str) @pytest.mark.anyio - async def test_async_read_text(self): + async def test_async_read_text(self) -> None: """Test reading text from async FunctionResource.""" async def get_data() -> str: @@ -138,7 +138,7 @@ async def get_data() -> str: assert resource.mime_type == "text/plain" @pytest.mark.anyio - async def test_from_function(self): + async def test_from_function(self) -> None: """Test creating a FunctionResource from a function.""" async def get_data() -> str: # pragma: no cover @@ -158,7 +158,7 @@ async def get_data() -> str: # pragma: no cover class TestFunctionResourceMetadata: - def test_from_function_with_metadata(self): + def test_from_function_with_metadata(self) -> None: # from_function() accepts meta dict and stores it on the resource for static resources def get_data() -> str: # pragma: no cover @@ -178,7 +178,7 @@ def get_data() -> str: # pragma: no cover assert "data" in resource.meta["tags"] assert "readonly" in resource.meta["tags"] - def test_from_function_without_metadata(self): + def test_from_function_without_metadata(self) -> None: # meta parameter is optional and defaults to None for backward compatibility def get_data() -> str: # pragma: no cover diff --git a/tests/server/mcpserver/resources/test_resource_manager.py b/tests/server/mcpserver/resources/test_resource_manager.py index 724b57997..763a004ad 100644 --- a/tests/server/mcpserver/resources/test_resource_manager.py +++ b/tests/server/mcpserver/resources/test_resource_manager.py @@ -1,3 +1,4 @@ +from collections.abc import Generator from pathlib import Path from tempfile import NamedTemporaryFile @@ -9,7 +10,7 @@ @pytest.fixture -def temp_file(): +def temp_file() -> Generator[Path, None, None]: """Create a temporary file for testing. File is automatically cleaned up after the test if it still exists. @@ -28,7 +29,7 @@ def temp_file(): class TestResourceManager: """Test ResourceManager functionality.""" - def test_add_resource(self, temp_file: Path): + def test_add_resource(self, temp_file: Path) -> None: """Test adding a resource.""" manager = ResourceManager() resource = FileResource( @@ -40,7 +41,7 @@ def test_add_resource(self, temp_file: Path): assert added == resource assert manager.list_resources() == [resource] - def test_add_duplicate_resource(self, temp_file: Path): + def test_add_duplicate_resource(self, temp_file: Path) -> None: """Test adding the same resource twice.""" manager = ResourceManager() resource = FileResource( @@ -53,7 +54,7 @@ def test_add_duplicate_resource(self, temp_file: Path): assert first == second assert manager.list_resources() == [resource] - def test_warn_on_duplicate_resources(self, temp_file: Path, caplog: pytest.LogCaptureFixture): + def test_warn_on_duplicate_resources(self, temp_file: Path, caplog: pytest.LogCaptureFixture) -> None: """Test warning on duplicate resources.""" manager = ResourceManager() resource = FileResource( @@ -65,7 +66,7 @@ def test_warn_on_duplicate_resources(self, temp_file: Path, caplog: pytest.LogCa manager.add_resource(resource) assert "Resource already exists" in caplog.text - def test_disable_warn_on_duplicate_resources(self, temp_file: Path, caplog: pytest.LogCaptureFixture): + def test_disable_warn_on_duplicate_resources(self, temp_file: Path, caplog: pytest.LogCaptureFixture) -> None: """Test disabling warning on duplicate resources.""" manager = ResourceManager(warn_on_duplicate_resources=False) resource = FileResource( @@ -78,7 +79,7 @@ def test_disable_warn_on_duplicate_resources(self, temp_file: Path, caplog: pyte assert "Resource already exists" not in caplog.text @pytest.mark.anyio - async def test_get_resource(self, temp_file: Path): + async def test_get_resource(self, temp_file: Path) -> None: """Test getting a resource by URI.""" manager = ResourceManager() resource = FileResource( @@ -91,7 +92,7 @@ async def test_get_resource(self, temp_file: Path): assert retrieved == resource @pytest.mark.anyio - async def test_get_resource_from_template(self): + async def test_get_resource_from_template(self) -> None: """Test getting a resource through a template.""" manager = ResourceManager() @@ -111,13 +112,13 @@ def greet(name: str) -> str: assert content == "Hello, world!" @pytest.mark.anyio - async def test_get_unknown_resource(self): + async def test_get_unknown_resource(self) -> None: """Test getting a non-existent resource.""" manager = ResourceManager() with pytest.raises(ValueError, match="Unknown resource"): await manager.get_resource(AnyUrl("unknown://test"), Context()) - def test_list_resources(self, temp_file: Path): + def test_list_resources(self, temp_file: Path) -> None: """Test listing all resources.""" manager = ResourceManager() resource1 = FileResource( @@ -140,7 +141,7 @@ def test_list_resources(self, temp_file: Path): class TestResourceManagerMetadata: """Test ResourceManager Metadata""" - def test_add_template_with_metadata(self): + def test_add_template_with_metadata(self) -> None: """Test that ResourceManager.add_template() accepts and passes meta parameter.""" manager = ResourceManager() @@ -161,7 +162,7 @@ def get_item(id: str) -> str: # pragma: no cover assert template.meta["source"] == "database" assert template.meta["cached"] is True - def test_add_template_without_metadata(self): + def test_add_template_without_metadata(self) -> None: """Test that ResourceManager.add_template() works without meta parameter.""" manager = ResourceManager() diff --git a/tests/server/mcpserver/resources/test_resource_template.py b/tests/server/mcpserver/resources/test_resource_template.py index 640cfe803..818f13841 100644 --- a/tests/server/mcpserver/resources/test_resource_template.py +++ b/tests/server/mcpserver/resources/test_resource_template.py @@ -12,7 +12,7 @@ class TestResourceTemplate: """Test ResourceTemplate functionality.""" - def test_template_creation(self): + def test_template_creation(self) -> None: """Test creating a template from a function.""" def my_func(key: str, value: int) -> dict[str, Any]: @@ -28,7 +28,7 @@ def my_func(key: str, value: int) -> dict[str, Any]: assert template.mime_type == "text/plain" # default assert template.fn(key="test", value=42) == my_func(key="test", value=42) - def test_template_matches(self): + def test_template_matches(self) -> None: """Test matching URIs against a template.""" def my_func(key: str, value: int) -> dict[str, Any]: # pragma: no cover @@ -49,7 +49,7 @@ def my_func(key: str, value: int) -> dict[str, Any]: # pragma: no cover assert template.matches("other://foo/123") is None @pytest.mark.anyio - async def test_create_resource(self): + async def test_create_resource(self) -> None: """Test creating a resource from a template.""" def my_func(key: str, value: int) -> dict[str, Any]: @@ -74,7 +74,7 @@ def my_func(key: str, value: int) -> dict[str, Any]: assert data == {"key": "foo", "value": 123} @pytest.mark.anyio - async def test_template_error(self): + async def test_template_error(self) -> None: """Test error handling in template resource creation.""" def failing_func(x: str) -> str: @@ -90,7 +90,7 @@ def failing_func(x: str) -> str: await template.create_resource("fail://test", {"x": "test"}, Context()) @pytest.mark.anyio - async def test_async_text_resource(self): + async def test_async_text_resource(self) -> None: """Test creating a text resource from async function.""" async def greet(name: str) -> str: @@ -113,7 +113,7 @@ async def greet(name: str) -> str: assert content == "Hello, world!" @pytest.mark.anyio - async def test_async_binary_resource(self): + async def test_async_binary_resource(self) -> None: """Test creating a binary resource from async function.""" async def get_bytes(value: str) -> bytes: @@ -136,7 +136,7 @@ async def get_bytes(value: str) -> bytes: assert content == b"test" @pytest.mark.anyio - async def test_basemodel_conversion(self): + async def test_basemodel_conversion(self) -> None: """Test handling of BaseModel types.""" class MyModel(BaseModel): @@ -165,11 +165,11 @@ def get_data(key: str, value: int) -> MyModel: assert data == {"key": "foo", "value": 123} @pytest.mark.anyio - async def test_custom_type_conversion(self): + async def test_custom_type_conversion(self) -> None: """Test handling of custom types.""" class CustomData: - def __init__(self, value: str): + def __init__(self, value: str) -> None: self.value = value def __str__(self) -> str: @@ -198,7 +198,7 @@ def get_data(value: str) -> CustomData: class TestResourceTemplateAnnotations: """Test annotations on resource templates.""" - def test_template_with_annotations(self): + def test_template_with_annotations(self) -> None: """Test creating a template with annotations.""" def get_user_data(user_id: str) -> str: # pragma: no cover @@ -213,7 +213,7 @@ def get_user_data(user_id: str) -> str: # pragma: no cover assert template.annotations is not None assert template.annotations.priority == 0.9 - def test_template_without_annotations(self): + def test_template_without_annotations(self) -> None: """Test that annotations are optional for templates.""" def get_user_data(user_id: str) -> str: # pragma: no cover @@ -224,7 +224,7 @@ def get_user_data(user_id: str) -> str: # pragma: no cover assert template.annotations is None @pytest.mark.anyio - async def test_template_annotations_in_mcpserver(self): + async def test_template_annotations_in_mcpserver(self) -> None: """Test template annotations via an MCPServer decorator.""" mcp = MCPServer() @@ -241,7 +241,7 @@ def get_dynamic(id: str) -> str: # pragma: no cover assert templates[0].annotations.priority == 0.7 @pytest.mark.anyio - async def test_template_created_resources_inherit_annotations(self): + async def test_template_created_resources_inherit_annotations(self) -> None: """Test that resources created from templates inherit annotations.""" def get_item(item_id: str) -> str: @@ -268,7 +268,7 @@ def get_item(item_id: str) -> str: class TestResourceTemplateMetadata: """Test ResourceTemplate meta handling.""" - def test_template_from_function_with_metadata(self): + def test_template_from_function_with_metadata(self) -> None: """Test that ResourceTemplate.from_function() accepts and stores meta parameter.""" def get_user(user_id: str) -> str: # pragma: no cover @@ -288,7 +288,7 @@ def get_user(user_id: str) -> str: # pragma: no cover assert template.meta["rate_limit"] == 100 @pytest.mark.anyio - async def test_template_created_resources_inherit_metadata(self): + async def test_template_created_resources_inherit_metadata(self) -> None: """Test that resources created from templates inherit meta from template.""" def get_item(item_id: str) -> str: diff --git a/tests/server/mcpserver/resources/test_resources.py b/tests/server/mcpserver/resources/test_resources.py index 5d36beda8..cc428a7af 100644 --- a/tests/server/mcpserver/resources/test_resources.py +++ b/tests/server/mcpserver/resources/test_resources.py @@ -8,7 +8,7 @@ class TestResourceValidation: """Test base Resource validation.""" - def test_resource_uri_accepts_any_string(self): + def test_resource_uri_accepts_any_string(self) -> None: """Test that URI field accepts any string per MCP spec.""" def dummy_func() -> str: # pragma: no cover @@ -38,7 +38,7 @@ def dummy_func() -> str: # pragma: no cover ) assert resource.uri == "custom://resource" - def test_resource_name_from_uri(self): + def test_resource_name_from_uri(self) -> None: """Test name is extracted from URI if not provided.""" def dummy_func() -> str: # pragma: no cover @@ -50,7 +50,7 @@ def dummy_func() -> str: # pragma: no cover ) assert resource.name == "resource://my-resource" - def test_resource_name_validation(self): + def test_resource_name_validation(self) -> None: """Test name validation.""" def dummy_func() -> str: # pragma: no cover @@ -70,7 +70,7 @@ def dummy_func() -> str: # pragma: no cover ) assert resource.name == "explicit-name" - def test_resource_mime_type(self): + def test_resource_mime_type(self) -> None: """Test mime type handling.""" def dummy_func() -> str: # pragma: no cover @@ -100,7 +100,7 @@ def dummy_func() -> str: # pragma: no cover assert resource.mime_type == 'text/plain; charset="utf-8"' @pytest.mark.anyio - async def test_resource_read_abstract(self): + async def test_resource_read_abstract(self) -> None: """Test that Resource.read() is abstract.""" class ConcreteResource(Resource): @@ -113,7 +113,7 @@ class ConcreteResource(Resource): class TestResourceAnnotations: """Test annotations on resources.""" - def test_resource_with_annotations(self): + def test_resource_with_annotations(self) -> None: """Test creating a resource with annotations.""" def get_data() -> str: # pragma: no cover @@ -127,7 +127,7 @@ def get_data() -> str: # pragma: no cover assert resource.annotations.audience == ["user"] assert resource.annotations.priority == 0.8 - def test_resource_without_annotations(self): + def test_resource_without_annotations(self) -> None: """Test that annotations are optional.""" def get_data() -> str: # pragma: no cover @@ -138,7 +138,7 @@ def get_data() -> str: # pragma: no cover assert resource.annotations is None @pytest.mark.anyio - async def test_resource_annotations_in_mcpserver(self): + async def test_resource_annotations_in_mcpserver(self) -> None: """Test resource annotations via MCPServer decorator.""" mcp = MCPServer() @@ -155,7 +155,7 @@ def get_annotated() -> str: # pragma: no cover assert resources[0].annotations.priority == 0.5 @pytest.mark.anyio - async def test_resource_annotations_with_both_audiences(self): + async def test_resource_annotations_with_both_audiences(self) -> None: """Test resource with both user and assistant audience.""" mcp = MCPServer() @@ -173,7 +173,7 @@ def get_both() -> str: # pragma: no cover class TestAnnotationsValidation: """Test validation of annotation values.""" - def test_priority_validation(self): + def test_priority_validation(self) -> None: """Test that priority is validated to be between 0.0 and 1.0.""" # Valid priorities @@ -188,7 +188,7 @@ def test_priority_validation(self): with pytest.raises(Exception): Annotations(priority=1.1) - def test_audience_validation(self): + def test_audience_validation(self) -> None: """Test that audience only accepts valid roles.""" # Valid audiences @@ -205,7 +205,7 @@ def test_audience_validation(self): class TestResourceMetadata: """Test metadata field on base Resource class.""" - def test_resource_with_metadata(self): + def test_resource_with_metadata(self) -> None: """Test that Resource base class accepts meta parameter.""" def dummy_func() -> str: # pragma: no cover @@ -225,7 +225,7 @@ def dummy_func() -> str: # pragma: no cover assert resource.meta["version"] == "1.0" assert resource.meta["category"] == "test" - def test_resource_without_metadata(self): + def test_resource_without_metadata(self) -> None: """Test that meta field defaults to None.""" def dummy_func() -> str: # pragma: no cover diff --git a/tests/server/mcpserver/servers/test_file_server.py b/tests/server/mcpserver/servers/test_file_server.py index 9c3fe265c..3ee02a28f 100644 --- a/tests/server/mcpserver/servers/test_file_server.py +++ b/tests/server/mcpserver/servers/test_file_server.py @@ -74,7 +74,7 @@ def delete_file(path: str) -> bool: @pytest.mark.anyio -async def test_list_resources(mcp: MCPServer): +async def test_list_resources(mcp: MCPServer) -> None: resources = await mcp.list_resources() assert len(resources) == 4 @@ -87,7 +87,7 @@ async def test_list_resources(mcp: MCPServer): @pytest.mark.anyio -async def test_read_resource_dir(mcp: MCPServer): +async def test_read_resource_dir(mcp: MCPServer) -> None: res_iter = await mcp.read_resource("dir://test_dir") res_list = list(res_iter) assert len(res_list) == 1 @@ -104,7 +104,7 @@ async def test_read_resource_dir(mcp: MCPServer): @pytest.mark.anyio -async def test_read_resource_file(mcp: MCPServer): +async def test_read_resource_file(mcp: MCPServer) -> None: res_iter = await mcp.read_resource("file://test_dir/example.py") res_list = list(res_iter) assert len(res_list) == 1 @@ -113,13 +113,13 @@ async def test_read_resource_file(mcp: MCPServer): @pytest.mark.anyio -async def test_delete_file(mcp: MCPServer, test_dir: Path): +async def test_delete_file(mcp: MCPServer, test_dir: Path) -> None: await mcp.call_tool("delete_file", arguments={"path": str(test_dir / "example.py")}) assert not (test_dir / "example.py").exists() @pytest.mark.anyio -async def test_delete_file_and_check_resources(mcp: MCPServer, test_dir: Path): +async def test_delete_file_and_check_resources(mcp: MCPServer, test_dir: Path) -> None: await mcp.call_tool("delete_file", arguments={"path": str(test_dir / "example.py")}) res_iter = await mcp.read_resource("file://test_dir/example.py") res_list = list(res_iter) diff --git a/tests/server/mcpserver/test_elicitation.py b/tests/server/mcpserver/test_elicitation.py index 679fb848f..5fd7cbc77 100644 --- a/tests/server/mcpserver/test_elicitation.py +++ b/tests/server/mcpserver/test_elicitation.py @@ -1,5 +1,6 @@ """Test the elicitation feature using stdio transport.""" +from collections.abc import Callable, Coroutine from typing import Any import pytest @@ -9,7 +10,7 @@ from mcp.client.session import ClientSession, ElicitationFnT from mcp.server.mcpserver import Context, MCPServer from mcp.shared._context import RequestContext -from mcp.types import ElicitRequestParams, ElicitResult, TextContent +from mcp.types import CallToolResult, ElicitRequestParams, ElicitResult, TextContent # Shared schema for basic tests @@ -17,7 +18,7 @@ class AnswerSchema(BaseModel): answer: str = Field(description="The user's answer to the question") -def create_ask_user_tool(mcp: MCPServer): +def create_ask_user_tool(mcp: MCPServer) -> Callable[[str, Context], Coroutine[Any, Any, str]]: """Create a standard ask_user tool that handles all elicitation responses.""" @mcp.tool(description="A tool that uses elicitation") @@ -41,7 +42,7 @@ async def call_tool_and_assert( args: dict[str, Any], expected_text: str | None = None, text_contains: list[str] | None = None, -): +) -> CallToolResult: """Helper to create session, call tool, and assert result.""" async with Client(mcp, elicitation_callback=elicitation_callback) as client: result = await client.call_tool(tool_name, args) @@ -58,13 +59,13 @@ async def call_tool_and_assert( @pytest.mark.anyio -async def test_stdio_elicitation(): +async def test_stdio_elicitation() -> None: """Test the elicitation feature using stdio transport.""" mcp = MCPServer(name="StdioElicitationServer") create_ask_user_tool(mcp) # Create a custom handler for elicitation requests - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: if params.message == "Tool wants to ask: What is your name?": return ElicitResult(action="accept", content={"answer": "Test User"}) else: # pragma: no cover @@ -76,12 +77,12 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E @pytest.mark.anyio -async def test_stdio_elicitation_decline(): +async def test_stdio_elicitation_decline() -> None: """Test elicitation with user declining.""" mcp = MCPServer(name="StdioElicitationDeclineServer") create_ask_user_tool(mcp) - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: return ElicitResult(action="decline") await call_tool_and_assert( @@ -90,11 +91,13 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E @pytest.mark.anyio -async def test_elicitation_schema_validation(): +async def test_elicitation_schema_validation() -> None: """Test that elicitation schemas must only contain primitive types.""" mcp = MCPServer(name="ValidationTestServer") - def create_validation_tool(name: str, schema_class: type[BaseModel]): + def create_validation_tool( + name: str, schema_class: type[BaseModel] + ) -> Callable[[Context], Coroutine[Any, Any, str]]: @mcp.tool(name=name, description=f"Tool testing {name}") async def tool(ctx: Context) -> str: try: @@ -121,7 +124,7 @@ class InvalidNestedSchema(BaseModel): # Dummy callback (won't be called due to validation failure) async def elicitation_callback( context: RequestContext[ClientSession], params: ElicitRequestParams - ): # pragma: no cover + ) -> ElicitResult: # pragma: no cover return ElicitResult(action="accept", content={}) async with Client(mcp, elicitation_callback=elicitation_callback) as client: @@ -135,7 +138,7 @@ async def elicitation_callback( @pytest.mark.anyio -async def test_elicitation_with_optional_fields(): +async def test_elicitation_with_optional_fields() -> None: """Test that Optional fields work correctly in elicitation schemas.""" mcp = MCPServer(name="OptionalFieldServer") @@ -176,7 +179,7 @@ async def optional_tool(ctx: Context) -> str: for content, expected in test_cases: - async def callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: return ElicitResult(action="accept", content=content) await call_tool_and_assert(mcp, callback, "optional_tool", {}, expected) @@ -196,7 +199,7 @@ async def invalid_optional_tool(ctx: Context) -> str: async def elicitation_callback( context: RequestContext[ClientSession], params: ElicitRequestParams - ): # pragma: no cover + ) -> ElicitResult: # pragma: no cover return ElicitResult(action="accept", content={}) await call_tool_and_assert( @@ -219,7 +222,7 @@ async def valid_multiselect_tool(ctx: Context) -> str: return f"Name: {result.data.name}, Tags: {', '.join(result.data.tags)}" return f"User {result.action}" # pragma: no cover - async def multiselect_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def multiselect_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: if "Please provide tags" in params.message: return ElicitResult(action="accept", content={"name": "Test", "tags": ["tag1", "tag2"]}) return ElicitResult(action="decline") # pragma: no cover @@ -239,7 +242,9 @@ async def optional_multiselect_tool(ctx: Context) -> str: return f"Name: {result.data.name}, Tags: {tags_str}" return f"User {result.action}" # pragma: no cover - async def optional_multiselect_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def optional_multiselect_callback( + context: RequestContext[ClientSession], params: ElicitRequestParams + ) -> ElicitResult: if "Please provide optional tags" in params.message: return ElicitResult(action="accept", content={"name": "Test", "tags": ["tag1", "tag2"]}) return ElicitResult(action="decline") # pragma: no cover @@ -250,7 +255,7 @@ async def optional_multiselect_callback(context: RequestContext[ClientSession], @pytest.mark.anyio -async def test_elicitation_with_default_values(): +async def test_elicitation_with_default_values() -> None: """Test that default values work correctly in elicitation schemas and are included in JSON.""" mcp = MCPServer(name="DefaultValuesServer") @@ -273,7 +278,9 @@ async def defaults_tool(ctx: Context) -> str: return f"User {result.action}" # First verify that defaults are present in the JSON schema sent to clients - async def callback_schema_verify(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def callback_schema_verify( + context: RequestContext[ClientSession], params: ElicitRequestParams + ) -> ElicitResult: # Verify the schema includes defaults assert isinstance(params, types.ElicitRequestFormParams), "Expected form mode elicitation" schema = params.requested_schema @@ -295,7 +302,7 @@ async def callback_schema_verify(context: RequestContext[ClientSession], params: ) # Test overriding defaults - async def callback_override(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def callback_override(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: return ElicitResult( action="accept", content={"email": "john@example.com", "name": "John", "age": 25, "subscribe": False} ) @@ -306,7 +313,7 @@ async def callback_override(context: RequestContext[ClientSession], params: Elic @pytest.mark.anyio -async def test_elicitation_with_enum_titles(): +async def test_elicitation_with_enum_titles() -> None: """Test elicitation with enum schemas using oneOf/anyOf for titles.""" mcp = MCPServer(name="ColorPreferencesApp") @@ -371,7 +378,7 @@ async def select_color_legacy(ctx: Context) -> str: return f"User: {result.data.user_name}, Color: {result.data.color}" return f"User {result.action}" # pragma: no cover - async def enum_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def enum_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: if "colors" in params.message and "legacy" not in params.message: return ElicitResult(action="accept", content={"user_name": "Bob", "favorite_colors": ["red", "green"]}) elif "color" in params.message: diff --git a/tests/server/mcpserver/test_func_metadata.py b/tests/server/mcpserver/test_func_metadata.py index c57d1ee9f..b0e57ce92 100644 --- a/tests/server/mcpserver/test_func_metadata.py +++ b/tests/server/mcpserver/test_func_metadata.py @@ -92,7 +92,7 @@ def complex_arguments_fn( @pytest.mark.anyio -async def test_complex_function_runtime_arg_validation_non_json(): +async def test_complex_function_runtime_arg_validation_non_json() -> None: """Test that basic non-JSON arguments are validated correctly""" meta = func_metadata(complex_arguments_fn) @@ -129,7 +129,7 @@ async def test_complex_function_runtime_arg_validation_non_json(): @pytest.mark.anyio -async def test_complex_function_runtime_arg_validation_with_json(): +async def test_complex_function_runtime_arg_validation_with_json() -> None: """Test that JSON string arguments are parsed and validated correctly""" meta = func_metadata(complex_arguments_fn) @@ -155,14 +155,14 @@ async def test_complex_function_runtime_arg_validation_with_json(): assert result == "ok!" -def test_str_vs_list_str(): +def test_str_vs_list_str() -> None: """Test handling of string vs list[str] type annotations. This is tricky as '"hello"' can be parsed as a JSON string or a Python string. We want to make sure it's kept as a python string. """ - def func_with_str_types(str_or_list: str | list[str]): # pragma: no cover + def func_with_str_types(str_or_list: str | list[str]) -> str | list[str]: # pragma: no cover return str_or_list meta = func_metadata(func_with_str_types) @@ -182,10 +182,12 @@ def func_with_str_types(str_or_list: str | list[str]): # pragma: no cover assert result["str_or_list"] == ["hello", "world"] -def test_skip_names(): +def test_skip_names() -> None: """Test that skipped parameters are not included in the model""" - def func_with_many_params(keep_this: int, skip_this: str, also_keep: float, also_skip: bool): # pragma: no cover + def func_with_many_params( + keep_this: int, skip_this: str, also_keep: float, also_skip: bool + ) -> tuple[int, str, float, bool]: # pragma: no cover return keep_this, skip_this, also_keep, also_skip # Skip some parameters @@ -203,7 +205,7 @@ def func_with_many_params(keep_this: int, skip_this: str, also_keep: float, also assert model.also_keep == 2.5 # type: ignore -def test_structured_output_dict_str_types(): +def test_structured_output_dict_str_types() -> None: """Test that dict[str, T] types are handled without wrapping.""" # Test dict[str, Any] @@ -246,7 +248,7 @@ def func_dict_int_key() -> dict[int, str]: # pragma: no cover @pytest.mark.anyio -async def test_lambda_function(): +async def test_lambda_function() -> None: """Test lambda function schema and validation""" fn: Callable[[str, int], str] = lambda x, y=5: x # noqa: E731 meta = func_metadata(lambda x, y=5: x) @@ -262,7 +264,7 @@ async def test_lambda_function(): "type": "object", } - async def check_call(args): + async def check_call(args: dict[str, Any]) -> Any: return await meta.call_fn_with_arg_validation( fn, fn_is_async=False, @@ -280,7 +282,7 @@ async def check_call(args): await check_call({"y": "world"}) -def test_complex_function_json_schema(): +def test_complex_function_json_schema() -> None: """Test JSON schema generation for complex function arguments. Note: Different versions of pydantic output slightly different @@ -447,12 +449,12 @@ def test_complex_function_json_schema(): } -def test_str_vs_int(): +def test_str_vs_int() -> None: """Test that string values are kept as strings even when they contain numbers, while numbers are parsed correctly. """ - def func_with_str_and_int(a: str, b: int): # pragma: no cover + def func_with_str_and_int(a: str, b: int) -> str: # pragma: no cover return a meta = func_metadata(func_with_str_and_int) @@ -461,7 +463,7 @@ def func_with_str_and_int(a: str, b: int): # pragma: no cover assert result["b"] == 123 -def test_str_annotation_preserves_json_string(): +def test_str_annotation_preserves_json_string() -> None: """Regression test for PR #1113: Ensure that when a parameter is annotated as str, valid JSON strings are NOT parsed into Python objects. @@ -511,7 +513,7 @@ def process_json_config(config: str, enabled: bool = True) -> str: # pragma: no @pytest.mark.anyio -async def test_str_annotation_runtime_validation(): +async def test_str_annotation_runtime_validation() -> None: """Regression test for PR #1113: Test runtime validation with string parameters containing valid JSON to ensure they are passed as strings, not parsed objects. """ @@ -554,10 +556,10 @@ def handle_json_payload(payload: str, strict_mode: bool = False) -> str: # Tests for structured output functionality -def test_structured_output_requires_return_annotation(): +def test_structured_output_requires_return_annotation() -> None: """Test that structured_output=True requires a return annotation""" - def func_no_annotation(): # pragma: no cover + def func_no_annotation(): # noqa: ANN202 # pragma: no cover return "hello" def func_none_annotation() -> None: # pragma: no cover @@ -577,7 +579,7 @@ def func_none_annotation() -> None: # pragma: no cover } -def test_structured_output_basemodel(): +def test_structured_output_basemodel() -> None: """Test structured output with BaseModel return types""" class PersonModel(BaseModel): @@ -601,7 +603,7 @@ def func_returning_person() -> PersonModel: # pragma: no cover } -def test_structured_output_primitives(): +def test_structured_output_primitives() -> None: """Test structured output with primitive return types""" def func_str() -> str: # pragma: no cover @@ -665,7 +667,7 @@ def func_bytes() -> bytes: # pragma: no cover } -def test_structured_output_generic_types(): +def test_structured_output_generic_types() -> None: """Test structured output with generic types (list, dict, Union, etc.)""" def func_list_str() -> list[str]: # pragma: no cover @@ -716,7 +718,7 @@ def func_optional() -> str | None: # pragma: no cover } -def test_structured_output_dataclass(): +def test_structured_output_dataclass() -> None: """Test structured output with dataclass return types""" @dataclass @@ -747,7 +749,7 @@ def func_returning_dataclass() -> PersonDataClass: # pragma: no cover } -def test_structured_output_typeddict(): +def test_structured_output_typeddict() -> None: """Test structured output with TypedDict return types""" class PersonTypedDictOptional(TypedDict, total=False): @@ -789,7 +791,7 @@ def func_returning_typeddict_required() -> PersonTypedDictRequired: # pragma: n } -def test_structured_output_ordinary_class(): +def test_structured_output_ordinary_class() -> None: """Test structured output with ordinary annotated classes""" class PersonClass: @@ -797,7 +799,7 @@ class PersonClass: age: int email: str | None - def __init__(self, name: str, age: int, email: str | None = None): # pragma: no cover + def __init__(self, name: str, age: int, email: str | None = None) -> None: # pragma: no cover self.name = name self.age = age self.email = email @@ -818,10 +820,10 @@ def func_returning_class() -> PersonClass: # pragma: no cover } -def test_unstructured_output_unannotated_class(): +def test_unstructured_output_unannotated_class() -> None: # Test with class that has no annotations class UnannotatedClass: - def __init__(self, x, y): # pragma: no cover + def __init__(self, x, y) -> None: # pragma: no cover self.x = x self.y = y @@ -832,7 +834,7 @@ def func_returning_unannotated() -> UnannotatedClass: # pragma: no cover assert meta.output_schema is None -def test_tool_call_result_is_unstructured_and_not_converted(): +def test_tool_call_result_is_unstructured_and_not_converted() -> None: def func_returning_call_tool_result() -> CallToolResult: return CallToolResult(content=[]) @@ -842,7 +844,7 @@ def func_returning_call_tool_result() -> CallToolResult: assert isinstance(meta.convert_result(func_returning_call_tool_result()), CallToolResult) -def test_tool_call_result_annotated_is_structured_and_converted(): +def test_tool_call_result_annotated_is_structured_and_converted() -> None: class PersonClass(BaseModel): name: str @@ -862,7 +864,7 @@ def func_returning_annotated_tool_call_result() -> Annotated[CallToolResult, Per assert isinstance(meta.convert_result(func_returning_annotated_tool_call_result()), CallToolResult) -def test_tool_call_result_annotated_is_structured_and_invalid(): +def test_tool_call_result_annotated_is_structured_and_invalid() -> None: class PersonClass(BaseModel): name: str @@ -875,7 +877,7 @@ def func_returning_annotated_tool_call_result() -> Annotated[CallToolResult, Per meta.convert_result(func_returning_annotated_tool_call_result()) -def test_tool_call_result_in_optional_is_rejected(): +def test_tool_call_result_in_optional_is_rejected() -> None: """Test that Optional[CallToolResult] raises InvalidSignature""" def func_optional_call_tool_result() -> CallToolResult | None: # pragma: no cover @@ -888,7 +890,7 @@ def func_optional_call_tool_result() -> CallToolResult | None: # pragma: no cov assert "CallToolResult" in str(exc_info.value) -def test_tool_call_result_in_union_is_rejected(): +def test_tool_call_result_in_union_is_rejected() -> None: """Test that Union[str, CallToolResult] raises InvalidSignature""" def func_union_call_tool_result() -> str | CallToolResult: # pragma: no cover @@ -901,7 +903,7 @@ def func_union_call_tool_result() -> str | CallToolResult: # pragma: no cover assert "CallToolResult" in str(exc_info.value) -def test_tool_call_result_in_pipe_union_is_rejected(): +def test_tool_call_result_in_pipe_union_is_rejected() -> None: """Test that str | CallToolResult raises InvalidSignature""" def func_pipe_union_call_tool_result() -> str | CallToolResult: # pragma: no cover @@ -914,7 +916,7 @@ def func_pipe_union_call_tool_result() -> str | CallToolResult: # pragma: no co assert "CallToolResult" in str(exc_info.value) -def test_structured_output_with_field_descriptions(): +def test_structured_output_with_field_descriptions() -> None: """Test that Field descriptions are preserved in structured output""" class ModelWithDescriptions(BaseModel): @@ -936,7 +938,7 @@ def func_with_descriptions() -> ModelWithDescriptions: # pragma: no cover } -def test_structured_output_nested_models(): +def test_structured_output_nested_models() -> None: """Test structured output with nested models""" class Address(BaseModel): @@ -975,7 +977,7 @@ def func_nested() -> PersonWithAddress: # pragma: no cover } -def test_structured_output_unserializable_type_error(): +def test_structured_output_unserializable_type_error() -> None: """Test error when structured_output=True is used with unserializable types""" # Test with a class that has non-serializable default values @@ -1016,7 +1018,7 @@ def func_returning_namedtuple() -> Point: # pragma: no cover assert "Point" in str(exc_info.value) -def test_structured_output_aliases(): +def test_structured_output_aliases() -> None: """Test that field aliases are consistent between schema and output""" class ModelWithAliases(BaseModel): @@ -1061,7 +1063,7 @@ def func_with_aliases() -> ModelWithAliases: # pragma: no cover assert structured_content_defaults["second"] is None -def test_basemodel_reserved_names(): +def test_basemodel_reserved_names() -> None: """Test that functions with parameters named after BaseModel methods work correctly""" def func_with_reserved_names( # pragma: no cover @@ -1089,7 +1091,7 @@ def func_with_reserved_names( # pragma: no cover @pytest.mark.anyio -async def test_basemodel_reserved_names_validation(): +async def test_basemodel_reserved_names_validation() -> None: """Test that validation and calling works with reserved parameter names""" def func_with_reserved_names( @@ -1147,7 +1149,7 @@ def func_with_reserved_names( assert dumped["normal_param"] == "test" -def test_basemodel_reserved_names_with_json_preparsing(): +def test_basemodel_reserved_names_with_json_preparsing() -> None: """Test that pre_parse_json works correctly with reserved parameter names""" def func_with_reserved_json( # pragma: no cover @@ -1173,7 +1175,7 @@ def func_with_reserved_json( # pragma: no cover assert result["normal"] == "plain string" -def test_disallowed_type_qualifier(): +def test_disallowed_type_qualifier() -> None: def func_disallowed_qualifier() -> Final[int]: # type: ignore pass # pragma: no cover @@ -1182,7 +1184,7 @@ def func_disallowed_qualifier() -> Final[int]: # type: ignore assert "return annotation contains an invalid type qualifier" in str(exc_info.value) -def test_preserves_pydantic_metadata(): +def test_preserves_pydantic_metadata() -> None: def func_with_metadata() -> Annotated[int, Field(gt=1)]: ... # pragma: no branch meta = func_metadata(func_with_metadata) diff --git a/tests/server/mcpserver/test_integration.py b/tests/server/mcpserver/test_integration.py index f71c0574c..c6ca50685 100644 --- a/tests/server/mcpserver/test_integration.py +++ b/tests/server/mcpserver/test_integration.py @@ -59,7 +59,7 @@ class NotificationCollector: """Collects notifications from the server for testing.""" - def __init__(self): + def __init__(self) -> None: self.progress_notifications: list[ProgressNotificationParams] = [] self.log_messages: list[LoggingMessageNotificationParams] = [] self.resource_notifications: list[NotificationParams | None] = [] @@ -94,7 +94,7 @@ async def sampling_callback( ) -async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): +async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: """Elicitation callback for tests.""" # For restaurant booking test if "No tables available" in params.message: @@ -184,7 +184,9 @@ async def test_tool_progress() -> None: """Test tool progress reporting.""" collector = NotificationCollector() - async def message_handler(message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception): + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: await collector.handle_generic_notification(message) if isinstance(message, Exception): # pragma: no cover raise message @@ -263,7 +265,9 @@ async def test_notifications() -> None: """Test notifications and logging functionality.""" collector = NotificationCollector() - async def message_handler(message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception): + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: await collector.handle_generic_notification(message) if isinstance(message, Exception): # pragma: no cover raise message diff --git a/tests/server/mcpserver/test_parameter_descriptions.py b/tests/server/mcpserver/test_parameter_descriptions.py index ec9f22c25..a47b29e08 100644 --- a/tests/server/mcpserver/test_parameter_descriptions.py +++ b/tests/server/mcpserver/test_parameter_descriptions.py @@ -7,7 +7,7 @@ @pytest.mark.anyio -async def test_parameter_descriptions(): +async def test_parameter_descriptions() -> None: mcp = MCPServer("Test Server") @mcp.tool() diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 3ef06d038..205a63334 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -1,6 +1,6 @@ import base64 from pathlib import Path -from typing import Any +from typing import Any, NoReturn from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -46,7 +46,7 @@ class TestServer: - async def test_create_server(self): + async def test_create_server(self) -> None: mcp = MCPServer( title="MCPServer Server", description="Server description", @@ -65,7 +65,7 @@ async def test_create_server(self): assert len(mcp.icons) == 1 assert mcp.icons[0].src == "https://example.com/icon.png" - async def test_sse_app_returns_starlette_app(self): + async def test_sse_app_returns_starlette_app(self) -> None: """Test that sse_app returns a Starlette application with correct routes.""" mcp = MCPServer("test") # Use host="0.0.0.0" to avoid auto DNS protection @@ -82,7 +82,7 @@ async def test_sse_app_returns_starlette_app(self): assert sse_routes[0].path == "/sse" assert mount_routes[0].path == "/messages" - async def test_non_ascii_description(self): + async def test_non_ascii_description(self) -> None: """Test that MCPServer handles non-ASCII characters in descriptions correctly""" mcp = MCPServer() @@ -105,7 +105,7 @@ def hello_world(name: str = "世界") -> str: assert isinstance(content, TextContent) assert "¡Hola, 世界! 👋" == content.text - async def test_add_tool_decorator(self): + async def test_add_tool_decorator(self) -> None: mcp = MCPServer() @mcp.tool() @@ -114,7 +114,7 @@ def sum(x: int, y: int) -> int: # pragma: no cover assert len(mcp._tool_manager.list_tools()) == 1 - async def test_add_tool_decorator_incorrect_usage(self): + async def test_add_tool_decorator_incorrect_usage(self) -> None: mcp = MCPServer() with pytest.raises(TypeError, match="The @tool decorator was used incorrectly"): @@ -123,7 +123,7 @@ async def test_add_tool_decorator_incorrect_usage(self): def sum(x: int, y: int) -> int: # pragma: no cover return x + y - async def test_add_resource_decorator(self): + async def test_add_resource_decorator(self) -> None: mcp = MCPServer() @mcp.resource("r://{x}") @@ -132,7 +132,7 @@ def get_data(x: str) -> str: # pragma: no cover assert len(mcp._resource_manager._templates) == 1 - async def test_add_resource_decorator_incorrect_usage(self): + async def test_add_resource_decorator_incorrect_usage(self) -> None: mcp = MCPServer() with pytest.raises(TypeError, match="The @resource decorator was used incorrectly"): @@ -149,7 +149,7 @@ class TestDnsRebindingProtection: based on the host parameter passed to those methods. """ - def test_auto_enabled_for_127_0_0_1_sse(self): + def test_auto_enabled_for_127_0_0_1_sse(self) -> None: """DNS rebinding protection should auto-enable for host=127.0.0.1 in SSE app.""" mcp = MCPServer() # Call sse_app with host=127.0.0.1 to trigger auto-config @@ -158,31 +158,31 @@ def test_auto_enabled_for_127_0_0_1_sse(self): app = mcp.sse_app(host="127.0.0.1") assert app is not None - def test_auto_enabled_for_127_0_0_1_streamable_http(self): + def test_auto_enabled_for_127_0_0_1_streamable_http(self) -> None: """DNS rebinding protection should auto-enable for host=127.0.0.1 in StreamableHTTP app.""" mcp = MCPServer() app = mcp.streamable_http_app(host="127.0.0.1") assert app is not None - def test_auto_enabled_for_localhost_sse(self): + def test_auto_enabled_for_localhost_sse(self) -> None: """DNS rebinding protection should auto-enable for host=localhost in SSE app.""" mcp = MCPServer() app = mcp.sse_app(host="localhost") assert app is not None - def test_auto_enabled_for_ipv6_localhost_sse(self): + def test_auto_enabled_for_ipv6_localhost_sse(self) -> None: """DNS rebinding protection should auto-enable for host=::1 (IPv6 localhost) in SSE app.""" mcp = MCPServer() app = mcp.sse_app(host="::1") assert app is not None - def test_not_auto_enabled_for_other_hosts_sse(self): + def test_not_auto_enabled_for_other_hosts_sse(self) -> None: """DNS rebinding protection should NOT auto-enable for other hosts in SSE app.""" mcp = MCPServer() app = mcp.sse_app(host="0.0.0.0") assert app is not None - def test_explicit_settings_not_overridden_sse(self): + def test_explicit_settings_not_overridden_sse(self) -> None: """Explicit transport_security settings should not be overridden in SSE app.""" custom_settings = TransportSecuritySettings( enable_dns_rebinding_protection=False, @@ -192,7 +192,7 @@ def test_explicit_settings_not_overridden_sse(self): app = mcp.sse_app(host="127.0.0.1", transport_security=custom_settings) assert app is not None - def test_explicit_settings_not_overridden_streamable_http(self): + def test_explicit_settings_not_overridden_streamable_http(self) -> None: """Explicit transport_security settings should not be overridden in StreamableHTTP app.""" custom_settings = TransportSecuritySettings( enable_dns_rebinding_protection=False, @@ -228,20 +228,20 @@ def mixed_content_tool_fn() -> list[ContentBlock]: class TestServerTools: - async def test_add_tool(self): + async def test_add_tool(self) -> None: mcp = MCPServer() mcp.add_tool(tool_fn) mcp.add_tool(tool_fn) assert len(mcp._tool_manager.list_tools()) == 1 - async def test_list_tools(self): + async def test_list_tools(self) -> None: mcp = MCPServer() mcp.add_tool(tool_fn) async with Client(mcp) as client: tools = await client.list_tools() assert len(tools.tools) == 1 - async def test_call_tool(self): + async def test_call_tool(self) -> None: mcp = MCPServer() mcp.add_tool(tool_fn) async with Client(mcp) as client: @@ -249,7 +249,7 @@ async def test_call_tool(self): assert not hasattr(result, "error") assert len(result.content) > 0 - async def test_tool_exception_handling(self): + async def test_tool_exception_handling(self) -> None: mcp = MCPServer() mcp.add_tool(error_tool_fn) async with Client(mcp) as client: @@ -260,7 +260,7 @@ async def test_tool_exception_handling(self): assert "Test error" in content.text assert result.is_error is True - async def test_tool_error_handling(self): + async def test_tool_error_handling(self) -> None: mcp = MCPServer() mcp.add_tool(error_tool_fn) async with Client(mcp) as client: @@ -271,7 +271,7 @@ async def test_tool_error_handling(self): assert "Test error" in content.text assert result.is_error is True - async def test_tool_error_details(self): + async def test_tool_error_details(self) -> None: """Test that exception details are properly formatted in the response""" mcp = MCPServer() mcp.add_tool(error_tool_fn) @@ -283,7 +283,7 @@ async def test_tool_error_details(self): assert "Test error" in content.text assert result.is_error is True - async def test_tool_return_value_conversion(self): + async def test_tool_return_value_conversion(self) -> None: mcp = MCPServer() mcp.add_tool(tool_fn) async with Client(mcp) as client: @@ -296,7 +296,7 @@ async def test_tool_return_value_conversion(self): assert result.structured_content is not None assert result.structured_content == {"result": 3} - async def test_tool_image_helper(self, tmp_path: Path): + async def test_tool_image_helper(self, tmp_path: Path) -> None: # Create a test image image_path = tmp_path / "test.png" image_path.write_bytes(b"fake png data") @@ -316,7 +316,7 @@ async def test_tool_image_helper(self, tmp_path: Path): # Check structured content - Image return type should NOT have structured output assert result.structured_content is None - async def test_tool_audio_helper(self, tmp_path: Path): + async def test_tool_audio_helper(self, tmp_path: Path) -> None: # Create a test audio audio_path = tmp_path / "test.wav" audio_path.write_bytes(b"fake wav data") @@ -348,7 +348,7 @@ async def test_tool_audio_helper(self, tmp_path: Path): ("test.unknown", "application/octet-stream"), # Unknown extension fallback ], ) - async def test_tool_audio_suffix_detection(self, tmp_path: Path, filename: str, expected_mime_type: str): + async def test_tool_audio_suffix_detection(self, tmp_path: Path, filename: str, expected_mime_type: str) -> None: """Test that Audio helper correctly detects MIME types from file suffixes""" mcp = MCPServer() mcp.add_tool(audio_tool_fn) @@ -368,7 +368,7 @@ async def test_tool_audio_suffix_detection(self, tmp_path: Path, filename: str, decoded = base64.b64decode(content.data) assert decoded == b"fake audio data" - async def test_tool_mixed_content(self): + async def test_tool_mixed_content(self) -> None: mcp = MCPServer() mcp.add_tool(mixed_content_tool_fn) async with Client(mcp) as client: @@ -398,7 +398,7 @@ async def test_tool_mixed_content(self): for key, value in expected.items(): assert structured_result[i][key] == value - async def test_tool_mixed_list_with_audio_and_image(self, tmp_path: Path): + async def test_tool_mixed_list_with_audio_and_image(self, tmp_path: Path) -> None: """Test that lists containing Image objects and other types are handled correctly""" # Create a test image @@ -450,7 +450,7 @@ def mixed_list_fn() -> list: # type: ignore # Check structured content - untyped list with Image objects should NOT have structured output assert result.structured_content is None - async def test_tool_structured_output_basemodel(self): + async def test_tool_structured_output_basemodel(self) -> None: """Test tool with structured output returning BaseModel""" class UserOutput(BaseModel): @@ -484,7 +484,7 @@ def get_user(user_id: int) -> UserOutput: assert isinstance(result.content[0], TextContent) assert '"name": "John Doe"' in result.content[0].text - async def test_tool_structured_output_primitive(self): + async def test_tool_structured_output_primitive(self) -> None: """Test tool with structured output returning primitive type""" def calculate_sum(a: int, b: int) -> int: @@ -510,7 +510,7 @@ def calculate_sum(a: int, b: int) -> int: assert result.structured_content is not None assert result.structured_content == {"result": 12} - async def test_tool_structured_output_list(self): + async def test_tool_structured_output_list(self) -> None: """Test tool with structured output returning list""" def get_numbers() -> list[int]: @@ -526,7 +526,7 @@ def get_numbers() -> list[int]: assert result.structured_content is not None assert result.structured_content == {"result": [1, 2, 3, 4, 5]} - async def test_tool_structured_output_server_side_validation_error(self): + async def test_tool_structured_output_server_side_validation_error(self) -> None: """Test that server-side validation errors are handled properly""" def get_numbers() -> list[int]: @@ -542,7 +542,7 @@ def get_numbers() -> list[int]: assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) - async def test_tool_structured_output_dict_str_any(self): + async def test_tool_structured_output_dict_str_any(self) -> None: """Test tool with dict[str, Any] structured output""" def get_metadata() -> dict[str, Any]: @@ -583,7 +583,7 @@ def get_metadata() -> dict[str, Any]: } assert result.structured_content == expected - async def test_tool_structured_output_dict_str_typed(self): + async def test_tool_structured_output_dict_str_typed(self) -> None: """Test tool with dict[str, T] structured output for specific T""" def get_settings() -> dict[str, str]: @@ -606,7 +606,7 @@ def get_settings() -> dict[str, str]: assert result.is_error is False assert result.structured_content == {"theme": "dark", "language": "en", "timezone": "UTC"} - async def test_remove_tool(self): + async def test_remove_tool(self) -> None: """Test removing a tool from the server.""" mcp = MCPServer() mcp.add_tool(tool_fn) @@ -620,14 +620,14 @@ async def test_remove_tool(self): # Verify tool is removed assert len(mcp._tool_manager.list_tools()) == 0 - async def test_remove_nonexistent_tool(self): + async def test_remove_nonexistent_tool(self) -> None: """Test that removing a non-existent tool raises ToolError.""" mcp = MCPServer() with pytest.raises(ToolError, match="Unknown tool: nonexistent"): mcp.remove_tool("nonexistent") - async def test_remove_tool_and_list(self): + async def test_remove_tool_and_list(self) -> None: """Test that a removed tool doesn't appear in list_tools.""" mcp = MCPServer() mcp.add_tool(tool_fn) @@ -650,7 +650,7 @@ async def test_remove_tool_and_list(self): assert len(tools.tools) == 1 assert tools.tools[0].name == "error_tool_fn" - async def test_remove_tool_and_call(self): + async def test_remove_tool_and_call(self) -> None: """Test that calling a removed tool fails appropriately.""" mcp = MCPServer() mcp.add_tool(tool_fn) @@ -676,10 +676,10 @@ async def test_remove_tool_and_call(self): class TestServerResources: - async def test_text_resource(self): + async def test_text_resource(self) -> None: mcp = MCPServer() - def get_text(): + def get_text() -> str: return "Hello, world!" resource = FunctionResource(uri="resource://test", name="test", fn=get_text) @@ -691,7 +691,7 @@ def get_text(): assert isinstance(result.contents[0], TextResourceContents) assert result.contents[0].text == "Hello, world!" - async def test_read_unknown_resource(self): + async def test_read_unknown_resource(self) -> None: """Test that reading an unknown resource raises MCPError.""" mcp = MCPServer() @@ -699,22 +699,22 @@ async def test_read_unknown_resource(self): with pytest.raises(MCPError, match="Unknown resource: unknown://missing"): await client.read_resource("unknown://missing") - async def test_read_resource_error(self): + async def test_read_resource_error(self) -> None: """Test that resource read errors are properly wrapped in MCPError.""" mcp = MCPServer() @mcp.resource("resource://failing") - def failing_resource(): + def failing_resource() -> NoReturn: raise ValueError("Resource read failed") async with Client(mcp) as client: with pytest.raises(MCPError, match="Error reading resource resource://failing"): await client.read_resource("resource://failing") - async def test_binary_resource(self): + async def test_binary_resource(self) -> None: mcp = MCPServer() - def get_binary(): + def get_binary() -> bytes: return b"Binary data" resource = FunctionResource( @@ -731,7 +731,7 @@ def get_binary(): assert isinstance(result.contents[0], BlobResourceContents) assert result.contents[0].blob == base64.b64encode(b"Binary data").decode() - async def test_file_resource_text(self, tmp_path: Path): + async def test_file_resource_text(self, tmp_path: Path) -> None: mcp = MCPServer() # Create a text file @@ -747,7 +747,7 @@ async def test_file_resource_text(self, tmp_path: Path): assert isinstance(result.contents[0], TextResourceContents) assert result.contents[0].text == "Hello from file!" - async def test_file_resource_binary(self, tmp_path: Path): + async def test_file_resource_binary(self, tmp_path: Path) -> None: mcp = MCPServer() # Create a binary file @@ -768,7 +768,7 @@ async def test_file_resource_binary(self, tmp_path: Path): assert isinstance(result.contents[0], BlobResourceContents) assert result.contents[0].blob == base64.b64encode(b"Binary file data").decode() - async def test_function_resource(self): + async def test_function_resource(self) -> None: mcp = MCPServer() @mcp.resource("function://test", name="test_get_data") @@ -787,7 +787,7 @@ def get_data() -> str: # pragma: no cover class TestServerResourceTemplates: - async def test_resource_with_params(self): + async def test_resource_with_params(self) -> None: """Test that a resource with function parameters raises an error if the URI parameters don't match""" mcp = MCPServer() @@ -798,7 +798,7 @@ async def test_resource_with_params(self): def get_data_fn(param: str) -> str: # pragma: no cover return f"Data: {param}" - async def test_resource_with_uri_params(self): + async def test_resource_with_uri_params(self) -> None: """Test that a resource with URI parameters is automatically a template""" mcp = MCPServer() @@ -808,7 +808,7 @@ async def test_resource_with_uri_params(self): def get_data() -> str: # pragma: no cover return "Data" - async def test_resource_with_untyped_params(self): + async def test_resource_with_untyped_params(self) -> None: """Test that a resource with untyped parameters raises an error""" mcp = MCPServer() @@ -816,7 +816,7 @@ async def test_resource_with_untyped_params(self): def get_data(param) -> str: # type: ignore # pragma: no cover return "Data" - async def test_resource_matching_params(self): + async def test_resource_matching_params(self) -> None: """Test that a resource with matching URI and function parameters works""" mcp = MCPServer() @@ -830,7 +830,7 @@ def get_data(name: str) -> str: assert isinstance(result.contents[0], TextResourceContents) assert result.contents[0].text == "Data for test" - async def test_resource_mismatched_params(self): + async def test_resource_mismatched_params(self) -> None: """Test that mismatched parameters raise an error""" mcp = MCPServer() @@ -840,7 +840,7 @@ async def test_resource_mismatched_params(self): def get_data(user: str) -> str: # pragma: no cover return f"Data for {user}" - async def test_resource_multiple_params(self): + async def test_resource_multiple_params(self) -> None: """Test that multiple parameters work correctly""" mcp = MCPServer() @@ -854,7 +854,7 @@ def get_data(org: str, repo: str) -> str: assert isinstance(result.contents[0], TextResourceContents) assert result.contents[0].text == "Data for cursor/myrepo" - async def test_resource_multiple_mismatched_params(self): + async def test_resource_multiple_mismatched_params(self) -> None: """Test that mismatched parameters raise an error""" mcp = MCPServer() @@ -877,7 +877,7 @@ def get_static_data() -> str: assert isinstance(result.contents[0], TextResourceContents) assert result.contents[0].text == "Static data" - async def test_template_to_resource_conversion(self): + async def test_template_to_resource_conversion(self) -> None: """Test that templates are properly converted to resources when accessed""" mcp = MCPServer() @@ -895,7 +895,7 @@ def get_data(name: str) -> str: result = await resource.read() assert result == "Data for test" - async def test_resource_template_includes_mime_type(self): + async def test_resource_template_includes_mime_type(self) -> None: """Test that list resource templates includes the correct mimeType.""" mcp = MCPServer() @@ -928,7 +928,7 @@ class TestServerResourceMetadata: Note: read_resource does NOT pass meta to protocol response (lowlevel/server.py only extracts content/mime_type). """ - async def test_resource_decorator_with_metadata(self): + async def test_resource_decorator_with_metadata(self) -> None: """Test that @resource decorator accepts and passes meta parameter.""" # Tests static resource flow: decorator -> FunctionResource -> list_resources (server.py:544,635,361) mcp = MCPServer() @@ -949,7 +949,7 @@ def get_config() -> str: ... # pragma: no branch ] ) - async def test_resource_template_decorator_with_metadata(self): + async def test_resource_template_decorator_with_metadata(self) -> None: """Test that @resource decorator passes meta to templates.""" # Tests template resource flow: decorator -> add_template() -> list_resource_templates (server.py:544,622,377) mcp = MCPServer() @@ -970,7 +970,7 @@ def get_weather(city: str) -> str: ... # pragma: no branch ] ) - async def test_read_resource_returns_meta(self): + async def test_read_resource_returns_meta(self) -> None: """Test that read_resource includes meta in response.""" # Tests end-to-end: Resource.meta -> ReadResourceContents.meta -> protocol _meta (lowlevel/server.py:341,371) mcp = MCPServer() @@ -998,7 +998,7 @@ def get_data() -> str: class TestContextInjection: """Test context injection in tools, resources, and prompts.""" - async def test_context_detection(self): + async def test_context_detection(self) -> None: """Test that context parameters are properly detected.""" mcp = MCPServer() @@ -1008,7 +1008,7 @@ def tool_with_context(x: int, ctx: Context) -> str: # pragma: no cover tool = mcp._tool_manager.add_tool(tool_with_context) assert tool.context_kwarg == "ctx" - async def test_context_injection(self): + async def test_context_injection(self) -> None: """Test that context is properly injected into tool calls.""" mcp = MCPServer() @@ -1025,7 +1025,7 @@ def tool_with_context(x: int, ctx: Context) -> str: assert "Request" in content.text assert "42" in content.text - async def test_async_context(self): + async def test_async_context(self) -> None: """Test that context works in async functions.""" mcp = MCPServer() @@ -1042,7 +1042,7 @@ async def async_tool(x: int, ctx: Context) -> str: assert "Async request" in content.text assert "42" in content.text - async def test_context_logging(self): + async def test_context_logging(self) -> None: """Test that context logging methods work.""" mcp = MCPServer() @@ -1069,7 +1069,7 @@ async def logging_tool(msg: str, ctx: Context) -> str: mock_log.assert_any_call(level="warning", data="Warning message", logger=None, related_request_id="1") mock_log.assert_any_call(level="error", data="Error message", logger=None, related_request_id="1") - async def test_optional_context(self): + async def test_optional_context(self) -> None: """Test that context is optional.""" mcp = MCPServer() @@ -1084,7 +1084,7 @@ def no_context(x: int) -> int: assert isinstance(content, TextContent) assert content.text == "42" - async def test_context_resource_access(self): + async def test_context_resource_access(self) -> None: """Test that context can access resources.""" mcp = MCPServer() @@ -1107,7 +1107,7 @@ async def tool_with_resource(ctx: Context) -> str: assert isinstance(content, TextContent) assert "Read resource: resource data" in content.text - async def test_resource_with_context(self): + async def test_resource_with_context(self) -> None: """Test that resources can receive context parameter.""" mcp = MCPServer() @@ -1133,7 +1133,7 @@ def resource_with_context(name: str, ctx: Context) -> str: # Should have either request_id or indication that context was injected assert "Resource test - context injected" == content.text - async def test_resource_without_context(self): + async def test_resource_without_context(self) -> None: """Test that resources without context work normally.""" mcp = MCPServer() @@ -1160,7 +1160,7 @@ def resource_no_context(name: str) -> str: ) ) - async def test_resource_context_custom_name(self): + async def test_resource_context_custom_name(self) -> None: """Test resource context with custom parameter name.""" mcp = MCPServer() @@ -1188,7 +1188,7 @@ def resource_custom_ctx(id: str, my_ctx: Context) -> str: ) ) - async def test_prompt_with_context(self): + async def test_prompt_with_context(self) -> None: """Test that prompts can receive context parameter.""" mcp = MCPServer() @@ -1208,7 +1208,7 @@ def prompt_with_context(text: str, ctx: Context) -> str: assert isinstance(content, TextContent) assert "Prompt 'test' - context injected" in content.text - async def test_prompt_without_context(self): + async def test_prompt_without_context(self) -> None: """Test that prompts without context work normally.""" mcp = MCPServer() @@ -1230,7 +1230,7 @@ def prompt_no_context(text: str) -> str: class TestServerPrompts: """Test prompt functionality in MCPServer server.""" - async def test_get_prompt_direct_call_without_context(self): + async def test_get_prompt_direct_call_without_context(self) -> None: """Test calling mcp.get_prompt() directly without passing context.""" mcp = MCPServer() @@ -1243,7 +1243,7 @@ def fn() -> str: assert isinstance(content, TextContent) assert content.text == "Hello, world!" - async def test_prompt_decorator(self): + async def test_prompt_decorator(self) -> None: """Test that the prompt decorator registers prompts correctly.""" mcp = MCPServer() @@ -1259,7 +1259,7 @@ def fn() -> str: assert isinstance(content[0].content, TextContent) assert content[0].content.text == "Hello, world!" - async def test_prompt_decorator_with_name(self): + async def test_prompt_decorator_with_name(self) -> None: """Test prompt decorator with custom name.""" mcp = MCPServer() @@ -1274,7 +1274,7 @@ def fn() -> str: assert isinstance(content[0].content, TextContent) assert content[0].content.text == "Hello, world!" - async def test_prompt_decorator_with_description(self): + async def test_prompt_decorator_with_description(self) -> None: """Test prompt decorator with custom description.""" mcp = MCPServer() @@ -1289,7 +1289,7 @@ def fn() -> str: assert isinstance(content[0].content, TextContent) assert content[0].content.text == "Hello, world!" - def test_prompt_decorator_error(self): + def test_prompt_decorator_error(self) -> None: """Test error when decorator is used incorrectly.""" mcp = MCPServer() with pytest.raises(TypeError, match="decorator was used incorrectly"): @@ -1297,7 +1297,7 @@ def test_prompt_decorator_error(self): @mcp.prompt # type: ignore def fn() -> str: ... # pragma: no branch - async def test_list_prompts(self): + async def test_list_prompts(self) -> None: """Test listing prompts through MCP protocol.""" mcp = MCPServer() @@ -1321,7 +1321,7 @@ def fn(name: str, optional: str = "default") -> str: ... # pragma: no branch ) ) - async def test_get_prompt(self): + async def test_get_prompt(self) -> None: """Test getting a prompt through MCP protocol.""" mcp = MCPServer() @@ -1338,7 +1338,7 @@ def fn(name: str) -> str: ) ) - async def test_get_prompt_with_description(self): + async def test_get_prompt_with_description(self) -> None: """Test getting a prompt through MCP protocol.""" mcp = MCPServer() @@ -1350,7 +1350,7 @@ def fn(name: str) -> str: result = await client.get_prompt("fn", {"name": "World"}) assert result.description == "Test prompt description" - async def test_get_prompt_with_docstring_description(self): + async def test_get_prompt_with_docstring_description(self) -> None: """Test prompt uses docstring as description when not explicitly provided.""" mcp = MCPServer() @@ -1368,7 +1368,7 @@ def fn(name: str) -> str: ) ) - async def test_get_prompt_with_resource(self): + async def test_get_prompt_with_resource(self) -> None: """Test getting a prompt that returns resource content.""" mcp = MCPServer() @@ -1399,7 +1399,7 @@ def fn() -> Message: ) ) - async def test_get_unknown_prompt(self): + async def test_get_unknown_prompt(self) -> None: """Test error when getting unknown prompt.""" mcp = MCPServer() @@ -1407,7 +1407,7 @@ async def test_get_unknown_prompt(self): with pytest.raises(MCPError, match="Unknown prompt"): await client.get_prompt("unknown") - async def test_get_prompt_missing_args(self): + async def test_get_prompt_missing_args(self) -> None: """Test error when required arguments are missing.""" mcp = MCPServer() @@ -1452,7 +1452,7 @@ def test_streamable_http_no_redirect() -> None: assert streamable_routes[0].path == "/mcp", "Streamable route path should be /mcp" -async def test_report_progress_passes_related_request_id(): +async def test_report_progress_passes_related_request_id() -> None: """Test that report_progress passes the request_id as related_request_id. Without related_request_id, the streamable HTTP transport cannot route diff --git a/tests/server/mcpserver/test_title.py b/tests/server/mcpserver/test_title.py index 662464757..70218fff1 100644 --- a/tests/server/mcpserver/test_title.py +++ b/tests/server/mcpserver/test_title.py @@ -10,7 +10,7 @@ @pytest.mark.anyio -async def test_server_name_title_description_version(): +async def test_server_name_title_description_version() -> None: """Test that server title and description are set and retrievable correctly.""" mcp = MCPServer( name="TestServer", @@ -34,7 +34,7 @@ async def test_server_name_title_description_version(): @pytest.mark.anyio -async def test_tool_title_precedence(): +async def test_tool_title_precedence() -> None: """Test that tool title precedence works correctly: title > annotations.title > name.""" # Create server with various tool configurations mcp = MCPServer(name="TitleTestServer") @@ -88,7 +88,7 @@ def tool_with_both(message: str) -> str: # pragma: no cover @pytest.mark.anyio -async def test_prompt_title(): +async def test_prompt_title() -> None: """Test that prompt titles work correctly.""" mcp = MCPServer(name="PromptTitleServer") @@ -121,7 +121,7 @@ def titled_prompt(topic: str) -> str: # pragma: no cover @pytest.mark.anyio -async def test_resource_title(): +async def test_resource_title() -> None: """Test that resource titles work correctly.""" mcp = MCPServer(name="ResourceTitleServer") @@ -194,7 +194,7 @@ def titled_dynamic_resource(id: str) -> str: # pragma: no cover @pytest.mark.anyio -async def test_get_display_name_utility(): +async def test_get_display_name_utility() -> None: """Test the get_display_name utility function.""" # Test tool precedence: title > annotations.title > name diff --git a/tests/server/mcpserver/test_tool_manager.py b/tests/server/mcpserver/test_tool_manager.py index e4dfd4ff9..781386c2e 100644 --- a/tests/server/mcpserver/test_tool_manager.py +++ b/tests/server/mcpserver/test_tool_manager.py @@ -15,7 +15,7 @@ class TestAddTools: - def test_basic_function(self): + def test_basic_function(self) -> None: """Test registering and running a basic function.""" def sum(a: int, b: int) -> int: # pragma: no cover @@ -33,7 +33,7 @@ def sum(a: int, b: int) -> int: # pragma: no cover assert tool.parameters["properties"]["a"]["type"] == "integer" assert tool.parameters["properties"]["b"]["type"] == "integer" - def test_init_with_tools(self, caplog: pytest.LogCaptureFixture): + def test_init_with_tools(self, caplog: pytest.LogCaptureFixture) -> None: def sum(a: int, b: int) -> int: # pragma: no cover return a + b @@ -64,7 +64,7 @@ class AddArguments(ArgModelBase): assert "Tool already exists: sum" in caplog.text @pytest.mark.anyio - async def test_async_function(self): + async def test_async_function(self) -> None: """Test registering and running an async function.""" async def fetch_data(url: str) -> str: # pragma: no cover @@ -81,7 +81,7 @@ async def fetch_data(url: str) -> str: # pragma: no cover assert tool.is_async is True assert tool.parameters["properties"]["url"]["type"] == "string" - def test_pydantic_model_function(self): + def test_pydantic_model_function(self) -> None: """Test registering a function that takes a Pydantic model.""" class UserInput(BaseModel): @@ -104,11 +104,11 @@ def create_user(user: UserInput, flag: bool) -> dict[str, Any]: # pragma: no co assert "age" in tool.parameters["$defs"]["UserInput"]["properties"] assert "flag" in tool.parameters["properties"] - def test_add_callable_object(self): + def test_add_callable_object(self) -> None: """Test registering a callable object.""" class MyTool: - def __init__(self): + def __init__(self) -> None: self.__name__ = "MyTool" def __call__(self, x: int) -> int: # pragma: no cover @@ -121,11 +121,11 @@ def __call__(self, x: int) -> int: # pragma: no cover assert tool.parameters["properties"]["x"]["type"] == "integer" @pytest.mark.anyio - async def test_add_async_callable_object(self): + async def test_add_async_callable_object(self) -> None: """Test registering an async callable object.""" class MyAsyncTool: - def __init__(self): + def __init__(self) -> None: self.__name__ = "MyAsyncTool" async def __call__(self, x: int) -> int: # pragma: no cover @@ -137,22 +137,22 @@ async def __call__(self, x: int) -> int: # pragma: no cover assert tool.is_async is True assert tool.parameters["properties"]["x"]["type"] == "integer" - def test_add_invalid_tool(self): + def test_add_invalid_tool(self) -> None: manager = ToolManager() with pytest.raises(AttributeError): manager.add_tool(1) # type: ignore - def test_add_lambda(self): + def test_add_lambda(self) -> None: manager = ToolManager() tool = manager.add_tool(lambda x: x, name="my_tool") # type: ignore[reportUnknownLambdaType] assert tool.name == "my_tool" - def test_add_lambda_with_no_name(self): + def test_add_lambda_with_no_name(self) -> None: manager = ToolManager() with pytest.raises(ValueError, match="You must provide a name for lambda functions"): manager.add_tool(lambda x: x) # type: ignore[reportUnknownLambdaType] - def test_warn_on_duplicate_tools(self, caplog: pytest.LogCaptureFixture): + def test_warn_on_duplicate_tools(self, caplog: pytest.LogCaptureFixture) -> None: """Test warning on duplicate tools.""" def f(x: int) -> int: # pragma: no cover @@ -164,7 +164,7 @@ def f(x: int) -> int: # pragma: no cover manager.add_tool(f) assert "Tool already exists: f" in caplog.text - def test_disable_warn_on_duplicate_tools(self, caplog: pytest.LogCaptureFixture): + def test_disable_warn_on_duplicate_tools(self, caplog: pytest.LogCaptureFixture) -> None: """Test disabling warning on duplicate tools.""" def f(x: int) -> int: # pragma: no cover @@ -180,7 +180,7 @@ def f(x: int) -> int: # pragma: no cover class TestCallTools: @pytest.mark.anyio - async def test_call_tool(self): + async def test_call_tool(self) -> None: def sum(a: int, b: int) -> int: """Add two numbers.""" return a + b @@ -191,7 +191,7 @@ def sum(a: int, b: int) -> int: assert result == 3 @pytest.mark.anyio - async def test_call_async_tool(self): + async def test_call_async_tool(self) -> None: async def double(n: int) -> int: """Double a number.""" return n * 2 @@ -202,9 +202,9 @@ async def double(n: int) -> int: assert result == 10 @pytest.mark.anyio - async def test_call_object_tool(self): + async def test_call_object_tool(self) -> None: class MyTool: - def __init__(self): + def __init__(self) -> None: self.__name__ = "MyTool" def __call__(self, x: int) -> int: @@ -216,9 +216,9 @@ def __call__(self, x: int) -> int: assert result == 10 @pytest.mark.anyio - async def test_call_async_object_tool(self): + async def test_call_async_object_tool(self) -> None: class MyAsyncTool: - def __init__(self): + def __init__(self) -> None: self.__name__ = "MyAsyncTool" async def __call__(self, x: int) -> int: @@ -230,7 +230,7 @@ async def __call__(self, x: int) -> int: assert result == 10 @pytest.mark.anyio - async def test_call_tool_with_default_args(self): + async def test_call_tool_with_default_args(self) -> None: def sum(a: int, b: int = 1) -> int: """Add two numbers.""" return a + b @@ -241,7 +241,7 @@ def sum(a: int, b: int = 1) -> int: assert result == 2 @pytest.mark.anyio - async def test_call_tool_with_missing_args(self): + async def test_call_tool_with_missing_args(self) -> None: def sum(a: int, b: int) -> int: # pragma: no cover """Add two numbers.""" return a + b @@ -252,13 +252,13 @@ def sum(a: int, b: int) -> int: # pragma: no cover await manager.call_tool("sum", {"a": 1}, Context()) @pytest.mark.anyio - async def test_call_unknown_tool(self): + async def test_call_unknown_tool(self) -> None: manager = ToolManager() with pytest.raises(ToolError): await manager.call_tool("unknown", {"a": 1}, Context()) @pytest.mark.anyio - async def test_call_tool_with_list_int_input(self): + async def test_call_tool_with_list_int_input(self) -> None: def sum_vals(vals: list[int]) -> int: return sum(vals) @@ -271,7 +271,7 @@ def sum_vals(vals: list[int]) -> int: assert result == 6 @pytest.mark.anyio - async def test_call_tool_with_list_str_or_str_input(self): + async def test_call_tool_with_list_str_or_str_input(self) -> None: def concat_strs(vals: list[str] | str) -> str: return vals if isinstance(vals, str) else "".join(vals) @@ -288,7 +288,7 @@ def concat_strs(vals: list[str] | str) -> str: assert result == '"a"' @pytest.mark.anyio - async def test_call_tool_with_complex_model(self): + async def test_call_tool_with_complex_model(self) -> None: class MyShrimpTank(BaseModel): class Shrimp(BaseModel): name: str @@ -317,7 +317,7 @@ def name_shrimp(tank: MyShrimpTank) -> list[str]: class TestToolSchema: @pytest.mark.anyio - async def test_context_arg_excluded_from_schema(self): + async def test_context_arg_excluded_from_schema(self) -> None: def something(a: int, ctx: Context) -> int: # pragma: no cover return a @@ -331,7 +331,7 @@ def something(a: int, ctx: Context) -> int: # pragma: no cover class TestContextHandling: """Test context handling in the tool manager.""" - def test_context_parameter_detection(self): + def test_context_parameter_detection(self) -> None: """Test that context parameters are properly detected in Tool.from_function().""" @@ -355,7 +355,7 @@ def tool_with_parametrized_context(x: int, ctx: Context[LifespanContextT, Reques assert tool.context_kwarg == "ctx" @pytest.mark.anyio - async def test_context_injection(self): + async def test_context_injection(self) -> None: """Test that context is properly injected during tool execution.""" def tool_with_context(x: int, ctx: Context) -> str: @@ -369,7 +369,7 @@ def tool_with_context(x: int, ctx: Context) -> str: assert result == "42" @pytest.mark.anyio - async def test_context_injection_async(self): + async def test_context_injection_async(self) -> None: """Test that context is properly injected in async tools.""" async def async_tool(x: int, ctx: Context) -> str: @@ -383,7 +383,7 @@ async def async_tool(x: int, ctx: Context) -> str: assert result == "42" @pytest.mark.anyio - async def test_context_error_handling(self): + async def test_context_error_handling(self) -> None: """Test error handling when context injection fails.""" def tool_with_context(x: int, ctx: Context) -> str: @@ -397,7 +397,7 @@ def tool_with_context(x: int, ctx: Context) -> str: class TestToolAnnotations: - def test_tool_annotations(self): + def test_tool_annotations(self) -> None: """Test that tool annotations are correctly added to tools.""" def read_data(path: str) -> str: # pragma: no cover @@ -419,7 +419,7 @@ def read_data(path: str) -> str: # pragma: no cover assert tool.annotations.open_world_hint is False @pytest.mark.anyio - async def test_tool_annotations_in_mcpserver(self): + async def test_tool_annotations_in_mcpserver(self) -> None: """Test that tool annotations are included in MCPTool conversion.""" app = MCPServer() @@ -440,7 +440,7 @@ class TestStructuredOutput: """Test structured output functionality in tools.""" @pytest.mark.anyio - async def test_tool_with_basemodel_output(self): + async def test_tool_with_basemodel_output(self) -> None: """Test tool with BaseModel return type.""" class UserOutput(BaseModel): @@ -458,7 +458,7 @@ def get_user(user_id: int) -> UserOutput: assert len(result) == 2 and result[1] == {"name": "John", "age": 30} @pytest.mark.anyio - async def test_tool_with_primitive_output(self): + async def test_tool_with_primitive_output(self) -> None: """Test tool with primitive return type.""" def double_number(n: int) -> int: @@ -473,7 +473,7 @@ def double_number(n: int) -> int: assert isinstance(result[0][0], TextContent) and result[1] == {"result": 10} @pytest.mark.anyio - async def test_tool_with_typeddict_output(self): + async def test_tool_with_typeddict_output(self) -> None: """Test tool with TypedDict return type.""" class UserDict(TypedDict): @@ -492,7 +492,7 @@ def get_user_dict(user_id: int) -> UserDict: assert result == expected_output @pytest.mark.anyio - async def test_tool_with_dataclass_output(self): + async def test_tool_with_dataclass_output(self) -> None: """Test tool with dataclass return type.""" @dataclass @@ -513,7 +513,7 @@ def get_person() -> Person: assert len(result) == 2 and result[1] == expected_output @pytest.mark.anyio - async def test_tool_with_list_output(self): + async def test_tool_with_list_output(self) -> None: """Test tool with list return type.""" expected_list = [1, 2, 3, 4, 5] @@ -531,7 +531,7 @@ def get_numbers() -> list[int]: assert isinstance(result[0][0], TextContent) and result[1] == expected_output @pytest.mark.anyio - async def test_tool_without_structured_output(self): + async def test_tool_without_structured_output(self) -> None: """Test that tools work normally when structured_output=False.""" def get_dict() -> dict[str, Any]: @@ -544,7 +544,7 @@ def get_dict() -> dict[str, Any]: assert isinstance(result, dict) assert result == {"key": "value"} - def test_tool_output_schema_property(self): + def test_tool_output_schema_property(self) -> None: """Test that Tool.output_schema property works correctly.""" class UserOutput(BaseModel): @@ -567,7 +567,7 @@ def get_user() -> UserOutput: # pragma: no cover assert tool.output_schema == expected_schema @pytest.mark.anyio - async def test_tool_with_dict_str_any_output(self): + async def test_tool_with_dict_str_any_output(self) -> None: """Test tool with dict[str, Any] return type.""" def get_config() -> dict[str, Any]: @@ -592,7 +592,7 @@ def get_config() -> dict[str, Any]: assert result == expected @pytest.mark.anyio - async def test_tool_with_dict_str_typed_output(self): + async def test_tool_with_dict_str_typed_output(self) -> None: """Test tool with dict[str, T] return type for specific T.""" def get_scores() -> dict[str, int]: @@ -620,7 +620,7 @@ def get_scores() -> dict[str, int]: class TestToolMetadata: """Test tool metadata functionality.""" - def test_add_tool_with_metadata(self): + def test_add_tool_with_metadata(self) -> None: """Test adding a tool with metadata via ToolManager.""" def process_data(input_data: str) -> str: # pragma: no cover @@ -637,7 +637,7 @@ def process_data(input_data: str) -> str: # pragma: no cover assert tool.meta["ui"]["type"] == "form" assert tool.meta["version"] == "1.0" - def test_add_tool_without_metadata(self): + def test_add_tool_without_metadata(self) -> None: """Test that tools without metadata have None as meta value.""" def simple_tool(x: int) -> int: # pragma: no cover @@ -650,7 +650,7 @@ def simple_tool(x: int) -> int: # pragma: no cover assert tool.meta is None @pytest.mark.anyio - async def test_metadata_in_mcpserver_decorator(self): + async def test_metadata_in_mcpserver_decorator(self) -> None: """Test that metadata is correctly added via MCPServer.tool decorator.""" app = MCPServer() @@ -671,7 +671,7 @@ def upload_file(filename: str) -> str: # pragma: no cover assert tool.meta["priority"] == "high" @pytest.mark.anyio - async def test_metadata_in_list_tools(self): + async def test_metadata_in_list_tools(self) -> None: """Test that metadata is included in MCPTool when listing tools.""" app = MCPServer() @@ -692,7 +692,7 @@ def analyze_text(text: str) -> dict[str, Any]: # pragma: no cover assert tools[0].meta == metadata @pytest.mark.anyio - async def test_multiple_tools_with_different_metadata(self): + async def test_multiple_tools_with_different_metadata(self) -> None: """Test multiple tools with different metadata values.""" app = MCPServer() @@ -725,7 +725,7 @@ def tool3(z: bool) -> bool: # pragma: no cover assert tools_by_name["tool2"].meta == metadata2 assert tools_by_name["tool3"].meta is None - def test_metadata_with_complex_structure(self): + def test_metadata_with_complex_structure(self) -> None: """Test metadata with complex nested structures.""" def complex_tool(data: str) -> str: # pragma: no cover @@ -754,7 +754,7 @@ def complex_tool(data: str) -> str: # pragma: no cover assert "read" in tool.meta["permissions"] assert "data-processing" in tool.meta["tags"] - def test_metadata_empty_dict(self): + def test_metadata_empty_dict(self) -> None: """Test that empty dict metadata is preserved.""" def tool_with_empty_meta(x: int) -> int: # pragma: no cover @@ -768,7 +768,7 @@ def tool_with_empty_meta(x: int) -> int: # pragma: no cover assert tool.meta == {} @pytest.mark.anyio - async def test_metadata_with_annotations(self): + async def test_metadata_with_annotations(self) -> None: """Test that metadata and annotations can coexist.""" app = MCPServer() @@ -792,7 +792,7 @@ def combined_tool(data: str) -> str: # pragma: no cover class TestRemoveTools: """Test tool removal functionality in the tool manager.""" - def test_remove_existing_tool(self): + def test_remove_existing_tool(self) -> None: """Test removing an existing tool.""" def add(a: int, b: int) -> int: # pragma: no cover @@ -813,14 +813,14 @@ def add(a: int, b: int) -> int: # pragma: no cover assert manager.get_tool("add") is None assert len(manager.list_tools()) == 0 - def test_remove_nonexistent_tool(self): + def test_remove_nonexistent_tool(self) -> None: """Test removing a non-existent tool raises ToolError.""" manager = ToolManager() with pytest.raises(ToolError, match="Unknown tool: nonexistent"): manager.remove_tool("nonexistent") - def test_remove_tool_from_multiple_tools(self): + def test_remove_tool_from_multiple_tools(self) -> None: """Test removing one tool when multiple tools exist.""" def add(a: int, b: int) -> int: # pragma: no cover @@ -856,7 +856,7 @@ def divide(a: int, b: int) -> float: # pragma: no cover assert manager.get_tool("divide") is not None @pytest.mark.anyio - async def test_call_removed_tool_raises_error(self): + async def test_call_removed_tool_raises_error(self) -> None: """Test that calling a removed tool raises ToolError.""" def greet(name: str) -> str: @@ -877,7 +877,7 @@ def greet(name: str) -> str: with pytest.raises(ToolError, match="Unknown tool: greet"): await manager.call_tool("greet", {"name": "World"}, Context()) - def test_remove_tool_case_sensitive(self): + def test_remove_tool_case_sensitive(self) -> None: """Test that tool removal is case-sensitive.""" def test_func() -> str: # pragma: no cover diff --git a/tests/server/mcpserver/test_url_elicitation.py b/tests/server/mcpserver/test_url_elicitation.py index af90dc208..f2a02d580 100644 --- a/tests/server/mcpserver/test_url_elicitation.py +++ b/tests/server/mcpserver/test_url_elicitation.py @@ -13,7 +13,7 @@ @pytest.mark.anyio -async def test_url_elicitation_accept(): +async def test_url_elicitation_accept() -> None: """Test URL mode elicitation with user acceptance.""" mcp = MCPServer(name="URLElicitationServer") @@ -28,7 +28,7 @@ async def request_api_key(ctx: Context) -> str: return f"User {result.action}" # Create elicitation callback that accepts URL mode - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: assert params.mode == "url" assert params.url == "https://example.com/api_key_setup" assert params.elicitation_id == "test-elicitation-001" @@ -43,7 +43,7 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E @pytest.mark.anyio -async def test_url_elicitation_decline(): +async def test_url_elicitation_decline() -> None: """Test URL mode elicitation with user declining.""" mcp = MCPServer(name="URLElicitationDeclineServer") @@ -57,7 +57,7 @@ async def oauth_flow(ctx: Context) -> str: # Test only checks decline path return f"User {result.action} authorization" - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: assert params.mode == "url" return ElicitResult(action="decline") @@ -69,7 +69,7 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E @pytest.mark.anyio -async def test_url_elicitation_cancel(): +async def test_url_elicitation_cancel() -> None: """Test URL mode elicitation with user cancelling.""" mcp = MCPServer(name="URLElicitationCancelServer") @@ -83,7 +83,7 @@ async def payment_flow(ctx: Context) -> str: # Test only checks cancel path return f"User {result.action} payment" - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: assert params.mode == "url" return ElicitResult(action="cancel") @@ -95,7 +95,7 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E @pytest.mark.anyio -async def test_url_elicitation_helper_function(): +async def test_url_elicitation_helper_function() -> None: """Test the elicit_url helper function.""" mcp = MCPServer(name="URLElicitationHelperServer") @@ -110,7 +110,7 @@ async def setup_credentials(ctx: Context) -> str: # Test only checks accept path - return the type name return type(result).__name__ - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: return ElicitResult(action="accept") async with Client(mcp, elicitation_callback=elicitation_callback) as client: @@ -121,7 +121,7 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E @pytest.mark.anyio -async def test_url_no_content_in_response(): +async def test_url_no_content_in_response() -> None: """Test that URL mode elicitation responses don't include content field.""" mcp = MCPServer(name="URLContentCheckServer") @@ -137,7 +137,7 @@ async def check_url_response(ctx: Context) -> str: assert result.content is None return f"Action: {result.action}, Content: {result.content}" - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: # Verify that this is URL mode assert params.mode == "url" assert isinstance(params, types.ElicitRequestURLParams) @@ -155,7 +155,7 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E @pytest.mark.anyio -async def test_form_mode_still_works(): +async def test_form_mode_still_works() -> None: """Ensure form mode elicitation still works after SEP 1036.""" mcp = MCPServer(name="FormModeBackwardCompatServer") @@ -170,7 +170,7 @@ async def ask_name(ctx: Context) -> str: assert result.data is not None return f"Hello, {result.data.name}!" - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: # Verify form mode parameters assert params.mode == "form" assert isinstance(params, types.ElicitRequestFormParams) @@ -186,7 +186,7 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E @pytest.mark.anyio -async def test_elicit_complete_notification(): +async def test_elicit_complete_notification() -> None: """Test that elicitation completion notifications can be sent and received.""" mcp = MCPServer(name="ElicitCompleteServer") @@ -206,7 +206,7 @@ async def trigger_elicitation(ctx: Context) -> str: return "Elicitation completed" - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: return ElicitResult(action="accept") # pragma: no cover async with Client(mcp, elicitation_callback=elicitation_callback) as client: @@ -223,7 +223,7 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E @pytest.mark.anyio -async def test_url_elicitation_required_error_code(): +async def test_url_elicitation_required_error_code() -> None: """Test that the URL_ELICITATION_REQUIRED error code is correct.""" # Verify the error code matches the specification (SEP 1036) assert types.URL_ELICITATION_REQUIRED == -32042, ( @@ -232,7 +232,7 @@ async def test_url_elicitation_required_error_code(): @pytest.mark.anyio -async def test_elicit_url_typed_results(): +async def test_elicit_url_typed_results() -> None: """Test that elicit_url returns properly typed result objects.""" mcp = MCPServer(name="TypedResultsServer") @@ -263,7 +263,7 @@ async def test_cancel(ctx: Context) -> str: return "Not cancelled" # pragma: no cover # Test declined result - async def decline_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def decline_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: return ElicitResult(action="decline") async with Client(mcp, elicitation_callback=decline_callback) as client: @@ -273,7 +273,7 @@ async def decline_callback(context: RequestContext[ClientSession], params: Elici assert result.content[0].text == "Declined" # Test cancelled result - async def cancel_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def cancel_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: return ElicitResult(action="cancel") async with Client(mcp, elicitation_callback=cancel_callback) as client: @@ -284,7 +284,7 @@ async def cancel_callback(context: RequestContext[ClientSession], params: Elicit @pytest.mark.anyio -async def test_deprecated_elicit_method(): +async def test_deprecated_elicit_method() -> None: """Test the deprecated elicit() method for backward compatibility.""" mcp = MCPServer(name="DeprecatedElicitServer") @@ -303,7 +303,7 @@ async def use_deprecated_elicit(ctx: Context) -> str: return f"Email: {result.content.get('email', 'none')}" return "No email provided" # pragma: no cover - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: # Verify this is form mode assert params.mode == "form" assert params.requested_schema is not None @@ -317,7 +317,7 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E @pytest.mark.anyio -async def test_ctx_elicit_url_convenience_method(): +async def test_ctx_elicit_url_convenience_method() -> None: """Test the ctx.elicit_url() convenience method (vs ctx.session.elicit_url()).""" mcp = MCPServer(name="CtxElicitUrlServer") @@ -331,7 +331,7 @@ async def direct_elicit_url(ctx: Context) -> str: ) return f"Result: {result.action}" - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: assert params.mode == "url" assert params.elicitation_id == "ctx-test-001" return ElicitResult(action="accept") diff --git a/tests/server/mcpserver/test_url_elicitation_error_throw.py b/tests/server/mcpserver/test_url_elicitation_error_throw.py index 1f45fd60f..173849eec 100644 --- a/tests/server/mcpserver/test_url_elicitation_error_throw.py +++ b/tests/server/mcpserver/test_url_elicitation_error_throw.py @@ -9,7 +9,7 @@ @pytest.mark.anyio -async def test_url_elicitation_error_thrown_from_tool(): +async def test_url_elicitation_error_thrown_from_tool() -> None: """Test that UrlElicitationRequiredError raised from a tool is received as MCPError by client.""" mcp = MCPServer(name="UrlElicitationErrorServer") @@ -50,7 +50,7 @@ async def connect_service(service_name: str, ctx: Context) -> str: @pytest.mark.anyio -async def test_url_elicitation_error_from_error(): +async def test_url_elicitation_error_from_error() -> None: """Test that client can reconstruct UrlElicitationRequiredError from MCPError.""" mcp = MCPServer(name="UrlElicitationErrorServer") @@ -91,7 +91,7 @@ async def multi_auth(ctx: Context) -> str: @pytest.mark.anyio -async def test_normal_exceptions_still_return_error_result(): +async def test_normal_exceptions_still_return_error_result() -> None: """Test that normal exceptions still return CallToolResult with is_error=True.""" mcp = MCPServer(name="NormalErrorServer") diff --git a/tests/server/mcpserver/tools/test_base.py b/tests/server/mcpserver/tools/test_base.py index 22d5f973e..dce688554 100644 --- a/tests/server/mcpserver/tools/test_base.py +++ b/tests/server/mcpserver/tools/test_base.py @@ -2,7 +2,7 @@ from mcp.server.mcpserver.tools.base import Tool -def test_context_detected_in_union_annotation(): +def test_context_detected_in_union_annotation() -> None: def my_tool(x: int, ctx: Context | None) -> str: raise NotImplementedError diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index cff5a37c1..33cfc56b4 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -27,7 +27,7 @@ @pytest.mark.anyio -async def test_server_remains_functional_after_cancel(): +async def test_server_remains_functional_after_cancel() -> None: """Verify server can handle new requests after a cancellation.""" # Track tool calls @@ -61,7 +61,7 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar async with Client(server) as client: # First request (will be cancelled) - async def first_request(): + async def first_request() -> None: try: await client.session.send_request( CallToolRequest(params=CallToolRequestParams(name="test_tool", arguments={})), @@ -100,7 +100,7 @@ async def first_request(): @pytest.mark.anyio -async def test_server_cancels_in_flight_handlers_on_transport_close(): +async def test_server_cancels_in_flight_handlers_on_transport_close() -> None: """When the transport closes mid-request, server.run() must cancel in-flight handlers rather than join on them. @@ -129,7 +129,7 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10) server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10) - async def run_server(): + async def run_server() -> None: await server.run(server_read, server_write, server.create_initialization_options()) server_run_returned.set() @@ -173,7 +173,7 @@ async def run_server(): @pytest.mark.anyio -async def test_server_handles_transport_close_with_pending_server_to_client_requests(): +async def test_server_handles_transport_close_with_pending_server_to_client_requests() -> None: """When the transport closes while handlers are blocked on server→client requests (sampling, roots, elicitation), server.run() must still exit cleanly. @@ -203,7 +203,7 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10) server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10) - async def run_server(): + async def run_server() -> None: await server.run(server_read, server_write, server.create_initialization_options()) server_run_returned.set() diff --git a/tests/server/test_completion_with_context.py b/tests/server/test_completion_with_context.py index a01d0d4d7..3f61b6dff 100644 --- a/tests/server/test_completion_with_context.py +++ b/tests/server/test_completion_with_context.py @@ -14,7 +14,7 @@ @pytest.mark.anyio -async def test_completion_handler_receives_context(): +async def test_completion_handler_receives_context() -> None: """Test that the completion handler receives context correctly.""" # Track what the handler receives received_params: CompleteRequestParams | None = None @@ -42,7 +42,7 @@ async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestPa @pytest.mark.anyio -async def test_completion_backward_compatibility(): +async def test_completion_backward_compatibility() -> None: """Test that completion works without context (backward compatibility).""" context_was_none = False @@ -65,7 +65,7 @@ async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestPa @pytest.mark.anyio -async def test_dependent_completion_scenario(): +async def test_dependent_completion_scenario() -> None: """Test a real-world scenario with dependent completions.""" async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestParams) -> CompleteResult: @@ -120,7 +120,7 @@ async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestPa @pytest.mark.anyio -async def test_completion_error_on_missing_context(): +async def test_completion_error_on_missing_context() -> None: """Test that server can raise error when required context is missing.""" async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestParams) -> CompleteResult: diff --git a/tests/server/test_lifespan.py b/tests/server/test_lifespan.py index 0d8790504..9539d0eea 100644 --- a/tests/server/test_lifespan.py +++ b/tests/server/test_lifespan.py @@ -27,7 +27,7 @@ @pytest.mark.anyio -async def test_lowlevel_server_lifespan(): +async def test_lowlevel_server_lifespan() -> None: """Test that lifespan works in low-level server.""" @asynccontextmanager @@ -58,7 +58,7 @@ async def check_lifespan( # Run server in background task async with anyio.create_task_group() as tg, send_stream1, receive_stream1, send_stream2, receive_stream2: - async def run_server(): + async def run_server() -> None: await server.run( receive_stream1, send_stream2, @@ -121,7 +121,7 @@ async def run_server(): @pytest.mark.anyio -async def test_mcpserver_server_lifespan(): +async def test_mcpserver_server_lifespan() -> None: """Test that lifespan works in MCPServer server.""" @asynccontextmanager @@ -152,7 +152,7 @@ def check_lifespan(ctx: Context) -> bool: # Run server in background task async with anyio.create_task_group() as tg, send_stream1, receive_stream1, send_stream2, receive_stream2: - async def run_server(): + async def run_server() -> None: await server._lowlevel_server.run( receive_stream1, send_stream2, diff --git a/tests/server/test_lowlevel_exception_handling.py b/tests/server/test_lowlevel_exception_handling.py index 46925916d..6834d3ddc 100644 --- a/tests/server/test_lowlevel_exception_handling.py +++ b/tests/server/test_lowlevel_exception_handling.py @@ -11,7 +11,7 @@ @pytest.mark.anyio -async def test_exception_handling_with_raise_exceptions_true(): +async def test_exception_handling_with_raise_exceptions_true() -> None: """Transport exceptions are re-raised when raise_exceptions=True.""" server = Server("test-server") session = Mock(spec=ServerSession) @@ -23,7 +23,7 @@ async def test_exception_handling_with_raise_exceptions_true(): @pytest.mark.anyio -async def test_exception_handling_with_raise_exceptions_false(): +async def test_exception_handling_with_raise_exceptions_false() -> None: """Transport exceptions are logged locally but not sent to the client. The transport that reported the error is likely broken; writing back @@ -40,7 +40,7 @@ async def test_exception_handling_with_raise_exceptions_false(): @pytest.mark.anyio -async def test_normal_message_handling_not_affected(): +async def test_normal_message_handling_not_affected() -> None: """Test that normal messages still work correctly""" server = Server("test-server") session = Mock(spec=ServerSession) @@ -62,7 +62,7 @@ async def test_normal_message_handling_not_affected(): @pytest.mark.anyio -async def test_server_run_exits_cleanly_when_transport_yields_exception_then_closes(): +async def test_server_run_exits_cleanly_when_transport_yields_exception_then_closes() -> None: """Regression test for #1967 / #2064. Exercises the real Server.run() path with real memory streams, reproducing diff --git a/tests/server/test_lowlevel_tool_annotations.py b/tests/server/test_lowlevel_tool_annotations.py index 705abdfe8..807cc7502 100644 --- a/tests/server/test_lowlevel_tool_annotations.py +++ b/tests/server/test_lowlevel_tool_annotations.py @@ -8,7 +8,7 @@ @pytest.mark.anyio -async def test_lowlevel_server_tool_annotations(): +async def test_lowlevel_server_tool_annotations() -> None: """Test that tool annotations work in low-level server.""" async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: diff --git a/tests/server/test_read_resource.py b/tests/server/test_read_resource.py index 102a58d03..887b8527a 100644 --- a/tests/server/test_read_resource.py +++ b/tests/server/test_read_resource.py @@ -14,7 +14,7 @@ pytestmark = pytest.mark.anyio -async def test_read_resource_text(): +async def test_read_resource_text() -> None: async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: return ReadResourceResult( contents=[TextResourceContents(uri=str(params.uri), text="Hello World", mime_type="text/plain")] @@ -32,7 +32,7 @@ async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRe assert content.mime_type == "text/plain" -async def test_read_resource_binary(): +async def test_read_resource_binary() -> None: binary_data = b"Hello World" async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: diff --git a/tests/server/test_session.py b/tests/server/test_session.py index a2786d865..5cdd5b1ce 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -23,7 +23,7 @@ @pytest.mark.anyio -async def test_server_session_initialize(): +async def test_server_session_initialize() -> None: server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) @@ -36,7 +36,7 @@ async def message_handler( # pragma: no cover received_initialized = False - async def run_server(): + async def run_server() -> None: nonlocal received_initialized async with ServerSession( @@ -77,7 +77,7 @@ async def run_server(): @pytest.mark.anyio -async def test_server_capabilities(): +async def test_server_capabilities() -> None: notification_options = NotificationOptions() experimental_capabilities: dict[str, Any] = {} @@ -129,7 +129,7 @@ async def noop_completion(ctx: ServerRequestContext, params: types.CompleteReque @pytest.mark.anyio -async def test_server_session_initialize_with_older_protocol_version(): +async def test_server_session_initialize_with_older_protocol_version() -> None: """Test that server accepts and responds with older protocol (2024-11-05).""" server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) @@ -137,7 +137,7 @@ async def test_server_session_initialize_with_older_protocol_version(): received_initialized = False received_protocol_version = None - async def run_server(): + async def run_server() -> None: nonlocal received_initialized async with ServerSession( @@ -159,7 +159,7 @@ async def run_server(): received_initialized = True return - async def mock_client(): + async def mock_client() -> None: nonlocal received_protocol_version # Send initialization request with older protocol version (2024-11-05) @@ -208,7 +208,7 @@ async def mock_client(): @pytest.mark.anyio -async def test_ping_request_before_initialization(): +async def test_ping_request_before_initialization() -> None: """Test that ping requests are allowed before initialization is complete.""" server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) @@ -216,7 +216,7 @@ async def test_ping_request_before_initialization(): ping_response_received = False ping_response_id = None - async def run_server(): + async def run_server() -> None: async with ServerSession( client_to_server_receive, server_to_client_send, @@ -239,7 +239,7 @@ async def run_server(): await message.respond(types.EmptyResult()) return - async def mock_client(): + async def mock_client() -> None: nonlocal ping_response_received, ping_response_id # Send ping request before any initialization @@ -267,7 +267,7 @@ async def mock_client(): @pytest.mark.anyio -async def test_create_message_tool_result_validation(): +async def test_create_message_tool_result_validation() -> None: """Test tool_use/tool_result validation in create_message.""" server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) @@ -392,7 +392,7 @@ async def test_create_message_tool_result_validation(): @pytest.mark.anyio -async def test_create_message_without_tools_capability(): +async def test_create_message_without_tools_capability() -> None: """Test that create_message raises MCPError when tools are provided without capability.""" server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) @@ -442,7 +442,7 @@ async def test_create_message_without_tools_capability(): @pytest.mark.anyio -async def test_other_requests_blocked_before_initialization(): +async def test_other_requests_blocked_before_initialization() -> None: """Test that non-ping requests are still blocked before initialization.""" server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) @@ -450,7 +450,7 @@ async def test_other_requests_blocked_before_initialization(): error_response_received = False error_code = None - async def run_server(): + async def run_server() -> None: async with ServerSession( client_to_server_receive, server_to_client_send, @@ -464,7 +464,7 @@ async def run_server(): # No need to process incoming_messages since the error is handled automatically await anyio.sleep(0.1) # Give time for the request to be processed - async def mock_client(): + async def mock_client() -> None: nonlocal error_response_received, error_code # Try to send a non-ping request before initialization diff --git a/tests/server/test_session_race_condition.py b/tests/server/test_session_race_condition.py index 81041152b..0dcaf3097 100644 --- a/tests/server/test_session_race_condition.py +++ b/tests/server/test_session_race_condition.py @@ -18,7 +18,7 @@ @pytest.mark.anyio -async def test_request_immediately_after_initialize_response(): +async def test_request_immediately_after_initialize_response() -> None: """Test that requests are accepted immediately after initialize response. This reproduces the race condition in stateful HTTP mode where: @@ -37,7 +37,7 @@ async def test_request_immediately_after_initialize_response(): tools_list_success = False error_received = None - async def run_server(): + async def run_server() -> None: nonlocal tools_list_success async with ServerSession( @@ -79,7 +79,7 @@ async def run_server(): # Done - exit gracefully return - async def mock_client(): + async def mock_client() -> None: nonlocal error_received # Step 1: Send InitializeRequest diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index 010eaf6a2..a5e2c78db 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -35,19 +35,21 @@ def server_url(server_port: int) -> str: # pragma: no cover class SecurityTestServer(Server): # pragma: no cover - def __init__(self): + def __init__(self) -> None: super().__init__(SERVER_NAME) async def on_list_tools(self) -> list[Tool]: return [] -def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover +def run_server_with_settings( + port: int, security_settings: TransportSecuritySettings | None = None +) -> None: # pragma: no cover """Run the SSE server with specified security settings.""" app = SecurityTestServer() sse_transport = SseServerTransport("/messages/", security_settings) - async def handle_sse(request: Request): + async def handle_sse(request: Request) -> Response: try: async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams: if streams: @@ -66,7 +68,9 @@ async def handle_sse(request: Request): uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") -def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): +def start_server_process( + port: int, security_settings: TransportSecuritySettings | None = None +) -> multiprocessing.Process: """Start server in a separate process.""" process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) process.start() @@ -76,7 +80,7 @@ def start_server_process(port: int, security_settings: TransportSecuritySettings @pytest.mark.anyio -async def test_sse_security_default_settings(server_port: int): +async def test_sse_security_default_settings(server_port: int) -> None: """Test SSE with default security settings (protection disabled).""" process = start_server_process(server_port) @@ -92,7 +96,7 @@ async def test_sse_security_default_settings(server_port: int): @pytest.mark.anyio -async def test_sse_security_invalid_host_header(server_port: int): +async def test_sse_security_invalid_host_header(server_port: int) -> None: """Test SSE with invalid Host header.""" # Enable security by providing settings with an empty allowed_hosts list security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["example.com"]) @@ -113,7 +117,7 @@ async def test_sse_security_invalid_host_header(server_port: int): @pytest.mark.anyio -async def test_sse_security_invalid_origin_header(server_port: int): +async def test_sse_security_invalid_origin_header(server_port: int) -> None: """Test SSE with invalid Origin header.""" # Configure security to allow the host but restrict origins security_settings = TransportSecuritySettings( @@ -136,7 +140,7 @@ async def test_sse_security_invalid_origin_header(server_port: int): @pytest.mark.anyio -async def test_sse_security_post_invalid_content_type(server_port: int): +async def test_sse_security_post_invalid_content_type(server_port: int) -> None: """Test POST endpoint with invalid Content-Type header.""" # Configure security to allow the host security_settings = TransportSecuritySettings( @@ -169,7 +173,7 @@ async def test_sse_security_post_invalid_content_type(server_port: int): @pytest.mark.anyio -async def test_sse_security_disabled(server_port: int): +async def test_sse_security_disabled(server_port: int) -> None: """Test SSE with security disabled.""" settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) process = start_server_process(server_port, settings) @@ -190,7 +194,7 @@ async def test_sse_security_disabled(server_port: int): @pytest.mark.anyio -async def test_sse_security_custom_allowed_hosts(server_port: int): +async def test_sse_security_custom_allowed_hosts(server_port: int) -> None: """Test SSE with custom allowed hosts.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, @@ -223,7 +227,7 @@ async def test_sse_security_custom_allowed_hosts(server_port: int): @pytest.mark.anyio -async def test_sse_security_wildcard_ports(server_port: int): +async def test_sse_security_wildcard_ports(server_port: int) -> None: """Test SSE with wildcard port patterns.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, @@ -257,7 +261,7 @@ async def test_sse_security_wildcard_ports(server_port: int): @pytest.mark.anyio -async def test_sse_security_post_valid_content_type(server_port: int): +async def test_sse_security_post_valid_content_type(server_port: int) -> None: """Test POST endpoint with valid Content-Type headers.""" # Configure security to allow the host security_settings = TransportSecuritySettings( diff --git a/tests/server/test_stateless_mode.py b/tests/server/test_stateless_mode.py index 3bfc6e674..3378a0cab 100644 --- a/tests/server/test_stateless_mode.py +++ b/tests/server/test_stateless_mode.py @@ -49,14 +49,14 @@ async def stateless_session() -> AsyncGenerator[ServerSession, None]: @pytest.mark.anyio -async def test_list_roots_fails_in_stateless_mode(stateless_session: ServerSession): +async def test_list_roots_fails_in_stateless_mode(stateless_session: ServerSession) -> None: """Test that list_roots raises StatelessModeNotSupported in stateless mode.""" with pytest.raises(StatelessModeNotSupported, match="list_roots"): await stateless_session.list_roots() @pytest.mark.anyio -async def test_create_message_fails_in_stateless_mode(stateless_session: ServerSession): +async def test_create_message_fails_in_stateless_mode(stateless_session: ServerSession) -> None: """Test that create_message raises StatelessModeNotSupported in stateless mode.""" with pytest.raises(StatelessModeNotSupported, match="sampling"): await stateless_session.create_message( @@ -71,7 +71,7 @@ async def test_create_message_fails_in_stateless_mode(stateless_session: ServerS @pytest.mark.anyio -async def test_elicit_form_fails_in_stateless_mode(stateless_session: ServerSession): +async def test_elicit_form_fails_in_stateless_mode(stateless_session: ServerSession) -> None: """Test that elicit_form raises StatelessModeNotSupported in stateless mode.""" with pytest.raises(StatelessModeNotSupported, match="elicitation"): await stateless_session.elicit_form( @@ -81,7 +81,7 @@ async def test_elicit_form_fails_in_stateless_mode(stateless_session: ServerSess @pytest.mark.anyio -async def test_elicit_url_fails_in_stateless_mode(stateless_session: ServerSession): +async def test_elicit_url_fails_in_stateless_mode(stateless_session: ServerSession) -> None: """Test that elicit_url raises StatelessModeNotSupported in stateless mode.""" with pytest.raises(StatelessModeNotSupported, match="elicitation"): await stateless_session.elicit_url( @@ -92,7 +92,7 @@ async def test_elicit_url_fails_in_stateless_mode(stateless_session: ServerSessi @pytest.mark.anyio -async def test_elicit_deprecated_fails_in_stateless_mode(stateless_session: ServerSession): +async def test_elicit_deprecated_fails_in_stateless_mode(stateless_session: ServerSession) -> None: """Test that the deprecated elicit method also fails in stateless mode.""" with pytest.raises(StatelessModeNotSupported, match="elicitation"): await stateless_session.elicit( @@ -102,7 +102,7 @@ async def test_elicit_deprecated_fails_in_stateless_mode(stateless_session: Serv @pytest.mark.anyio -async def test_stateless_error_message_is_actionable(stateless_session: ServerSession): +async def test_stateless_error_message_is_actionable(stateless_session: ServerSession) -> None: """Test that the error message provides actionable guidance.""" with pytest.raises(StatelessModeNotSupported) as exc_info: await stateless_session.list_roots() @@ -117,7 +117,7 @@ async def test_stateless_error_message_is_actionable(stateless_session: ServerSe @pytest.mark.anyio -async def test_exception_has_method_attribute(stateless_session: ServerSession): +async def test_exception_has_method_attribute(stateless_session: ServerSession) -> None: """Test that the exception has a method attribute for programmatic access.""" with pytest.raises(StatelessModeNotSupported) as exc_info: await stateless_session.list_roots() @@ -155,7 +155,7 @@ async def stateful_session() -> AsyncGenerator[ServerSession, None]: @pytest.mark.anyio async def test_stateful_mode_does_not_raise_stateless_error( stateful_session: ServerSession, monkeypatch: pytest.MonkeyPatch -): +) -> None: """Test that StatelessModeNotSupported is not raised in stateful mode. We mock send_request to avoid blocking on I/O while still verifying diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 677a99356..fbaeaed31 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -11,7 +11,7 @@ @pytest.mark.anyio -async def test_stdio_server(): +async def test_stdio_server() -> None: stdin = io.StringIO() stdout = io.StringIO() @@ -64,7 +64,7 @@ async def test_stdio_server(): @pytest.mark.anyio -async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch): +async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch) -> None: """Non-UTF-8 bytes on stdin must not crash the server. Invalid bytes are replaced with U+FFFD, which then fails JSON parsing and diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 47cfbf14a..281c6b22f 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -2,6 +2,7 @@ import json import logging +from collections.abc import AsyncGenerator from typing import Any from unittest.mock import AsyncMock, patch @@ -19,7 +20,7 @@ @pytest.mark.anyio -async def test_run_can_only_be_called_once(): +async def test_run_can_only_be_called_once() -> None: """Test that run() can only be called once per instance.""" app = Server("test-server") manager = StreamableHTTPSessionManager(app=app) @@ -37,14 +38,14 @@ async def test_run_can_only_be_called_once(): @pytest.mark.anyio -async def test_run_prevents_concurrent_calls(): +async def test_run_prevents_concurrent_calls() -> None: """Test that concurrent calls to run() are prevented.""" app = Server("test-server") manager = StreamableHTTPSessionManager(app=app) errors: list[Exception] = [] - async def try_run(): + async def try_run() -> None: try: async with manager.run(): # Simulate some work @@ -63,7 +64,7 @@ async def try_run(): @pytest.mark.anyio -async def test_handle_request_without_run_raises_error(): +async def test_handle_request_without_run_raises_error() -> None: """Test that handle_request raises error if run() hasn't been called.""" app = Server("test-server") manager = StreamableHTTPSessionManager(app=app) @@ -71,10 +72,10 @@ async def test_handle_request_without_run_raises_error(): # Mock ASGI parameters scope = {"type": "http", "method": "POST", "path": "/test"} - async def receive(): # pragma: no cover + async def receive() -> Message: # pragma: no cover return {"type": "http.request", "body": b""} - async def send(message: Message): # pragma: no cover + async def send(message: Message) -> None: # pragma: no cover pass # Should raise error because run() hasn't been called @@ -90,7 +91,7 @@ class TestException(Exception): @pytest.fixture -async def running_manager(): +async def running_manager() -> AsyncGenerator[tuple[StreamableHTTPSessionManager, Server], None]: app = Server("test-cleanup-server") # It's important that the app instance used by the manager is the one we can patch manager = StreamableHTTPSessionManager(app=app) @@ -100,7 +101,9 @@ async def running_manager(): @pytest.mark.anyio -async def test_stateful_session_cleanup_on_graceful_exit(running_manager: tuple[StreamableHTTPSessionManager, Server]): +async def test_stateful_session_cleanup_on_graceful_exit( + running_manager: tuple[StreamableHTTPSessionManager, Server], +) -> None: manager, app = running_manager mock_mcp_run = AsyncMock(return_value=None) @@ -109,7 +112,7 @@ async def test_stateful_session_cleanup_on_graceful_exit(running_manager: tuple[ sent_messages: list[Message] = [] - async def mock_send(message: Message): + async def mock_send(message: Message) -> None: sent_messages.append(message) scope = { @@ -119,7 +122,7 @@ async def mock_send(message: Message): "headers": [(b"content-type", b"application/json")], } - async def mock_receive(): # pragma: no cover + async def mock_receive() -> Message: # pragma: no cover return {"type": "http.request", "body": b"", "more_body": False} # Trigger session creation @@ -155,7 +158,9 @@ async def mock_receive(): # pragma: no cover @pytest.mark.anyio -async def test_stateful_session_cleanup_on_exception(running_manager: tuple[StreamableHTTPSessionManager, Server]): +async def test_stateful_session_cleanup_on_exception( + running_manager: tuple[StreamableHTTPSessionManager, Server], +) -> None: manager, app = running_manager mock_mcp_run = AsyncMock(side_effect=TestException("Simulated crash")) @@ -163,7 +168,7 @@ async def test_stateful_session_cleanup_on_exception(running_manager: tuple[Stre sent_messages: list[Message] = [] - async def mock_send(message: Message): + async def mock_send(message: Message) -> None: sent_messages.append(message) # If an exception occurs, the transport might try to send an error response # For this test, we mostly care that the session is established enough @@ -178,7 +183,7 @@ async def mock_send(message: Message): "headers": [(b"content-type", b"application/json")], } - async def mock_receive(): # pragma: no cover + async def mock_receive() -> Message: # pragma: no cover return {"type": "http.request", "body": b"", "more_body": False} # Trigger session creation @@ -208,7 +213,7 @@ async def mock_receive(): # pragma: no cover @pytest.mark.anyio -async def test_stateless_requests_memory_cleanup(): +async def test_stateless_requests_memory_cleanup() -> None: """Test that stateless requests actually clean up resources using real transports.""" app = Server("test-stateless-real-cleanup") manager = StreamableHTTPSessionManager(app=app, stateless=True) @@ -233,7 +238,7 @@ def track_transport(*args: Any, **kwargs: Any) -> StreamableHTTPServerTransport: # Send a simple request sent_messages: list[Message] = [] - async def mock_send(message: Message): + async def mock_send(message: Message) -> None: sent_messages.append(message) scope = { @@ -247,7 +252,7 @@ async def mock_send(message: Message): } # Empty body to trigger early return - async def mock_receive(): + async def mock_receive() -> Message: return { "type": "http.request", "body": b"", @@ -270,7 +275,7 @@ async def mock_receive(): @pytest.mark.anyio -async def test_unknown_session_id_returns_404(caplog: pytest.LogCaptureFixture): +async def test_unknown_session_id_returns_404(caplog: pytest.LogCaptureFixture) -> None: """Test that requests with unknown session IDs return HTTP 404 per MCP spec.""" app = Server("test-unknown-session") manager = StreamableHTTPSessionManager(app=app) @@ -279,7 +284,7 @@ async def test_unknown_session_id_returns_404(caplog: pytest.LogCaptureFixture): sent_messages: list[Message] = [] response_body = b"" - async def mock_send(message: Message): + async def mock_send(message: Message) -> None: nonlocal response_body sent_messages.append(message) if message["type"] == "http.response.body": @@ -297,7 +302,7 @@ async def mock_send(message: Message): ], } - async def mock_receive(): + async def mock_receive() -> Message: return {"type": "http.request", "body": b"{}", "more_body": False} # pragma: no cover with caplog.at_level(logging.INFO): @@ -321,7 +326,7 @@ async def mock_receive(): @pytest.mark.anyio -async def test_e2e_streamable_http_server_cleanup(): +async def test_e2e_streamable_http_server_cleanup() -> None: host = "testserver" async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: @@ -339,7 +344,7 @@ async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestP @pytest.mark.anyio -async def test_idle_session_is_reaped(): +async def test_idle_session_is_reaped() -> None: """After idle timeout fires, the session returns 404.""" app = Server("test-idle-reap") manager = StreamableHTTPSessionManager(app=app, session_idle_timeout=0.05) @@ -347,7 +352,7 @@ async def test_idle_session_is_reaped(): async with manager.run(): sent_messages: list[Message] = [] - async def mock_send(message: Message): + async def mock_send(message: Message) -> None: sent_messages.append(message) scope = { @@ -357,7 +362,7 @@ async def mock_send(message: Message): "headers": [(b"content-type", b"application/json")], } - async def mock_receive(): # pragma: no cover + async def mock_receive() -> Message: # pragma: no cover return {"type": "http.request", "body": b"", "more_body": False} await manager.handle_request(scope, mock_receive, mock_send) @@ -380,7 +385,7 @@ async def mock_receive(): # pragma: no cover # Verify via public API: old session ID now returns 404 response_messages: list[Message] = [] - async def capture_send(message: Message): + async def capture_send(message: Message) -> None: response_messages.append(message) scope_with_session = { @@ -403,13 +408,13 @@ async def capture_send(message: Message): assert response_start["status"] == 404 -def test_session_idle_timeout_rejects_non_positive(): +def test_session_idle_timeout_rejects_non_positive() -> None: with pytest.raises(ValueError, match="positive number"): StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=-1) with pytest.raises(ValueError, match="positive number"): StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=0) -def test_session_idle_timeout_rejects_stateless(): +def test_session_idle_timeout_rejects_stateless() -> None: with pytest.raises(RuntimeError, match="not supported in stateless"): StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=30, stateless=True) diff --git a/tests/server/test_streamable_http_security.py b/tests/server/test_streamable_http_security.py index 897555353..f5dcff821 100644 --- a/tests/server/test_streamable_http_security.py +++ b/tests/server/test_streamable_http_security.py @@ -34,14 +34,16 @@ def server_url(server_port: int) -> str: # pragma: no cover class SecurityTestServer(Server): # pragma: no cover - def __init__(self): + def __init__(self) -> None: super().__init__(SERVER_NAME) async def on_list_tools(self) -> list[Tool]: return [] -def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover +def run_server_with_settings( + port: int, security_settings: TransportSecuritySettings | None = None +) -> None: # pragma: no cover """Run the StreamableHTTP server with specified security settings.""" app = SecurityTestServer() @@ -71,7 +73,9 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") -def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): +def start_server_process( + port: int, security_settings: TransportSecuritySettings | None = None +) -> multiprocessing.Process: """Start server in a separate process.""" process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) process.start() @@ -81,7 +85,7 @@ def start_server_process(port: int, security_settings: TransportSecuritySettings @pytest.mark.anyio -async def test_streamable_http_security_default_settings(server_port: int): +async def test_streamable_http_security_default_settings(server_port: int) -> None: """Test StreamableHTTP with default security settings (protection enabled).""" process = start_server_process(server_port) @@ -106,7 +110,7 @@ async def test_streamable_http_security_default_settings(server_port: int): @pytest.mark.anyio -async def test_streamable_http_security_invalid_host_header(server_port: int): +async def test_streamable_http_security_invalid_host_header(server_port: int) -> None: """Test StreamableHTTP with invalid Host header.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True) process = start_server_process(server_port, security_settings) @@ -134,7 +138,7 @@ async def test_streamable_http_security_invalid_host_header(server_port: int): @pytest.mark.anyio -async def test_streamable_http_security_invalid_origin_header(server_port: int): +async def test_streamable_http_security_invalid_origin_header(server_port: int) -> None: """Test StreamableHTTP with invalid Origin header.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"]) process = start_server_process(server_port, security_settings) @@ -162,7 +166,7 @@ async def test_streamable_http_security_invalid_origin_header(server_port: int): @pytest.mark.anyio -async def test_streamable_http_security_invalid_content_type(server_port: int): +async def test_streamable_http_security_invalid_content_type(server_port: int) -> None: """Test StreamableHTTP POST with invalid Content-Type header.""" process = start_server_process(server_port) @@ -195,7 +199,7 @@ async def test_streamable_http_security_invalid_content_type(server_port: int): @pytest.mark.anyio -async def test_streamable_http_security_disabled(server_port: int): +async def test_streamable_http_security_disabled(server_port: int) -> None: """Test StreamableHTTP with security disabled.""" settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) process = start_server_process(server_port, settings) @@ -223,7 +227,7 @@ async def test_streamable_http_security_disabled(server_port: int): @pytest.mark.anyio -async def test_streamable_http_security_custom_allowed_hosts(server_port: int): +async def test_streamable_http_security_custom_allowed_hosts(server_port: int) -> None: """Test StreamableHTTP with custom allowed hosts.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, @@ -254,7 +258,7 @@ async def test_streamable_http_security_custom_allowed_hosts(server_port: int): @pytest.mark.anyio -async def test_streamable_http_security_get_request(server_port: int): +async def test_streamable_http_security_get_request(server_port: int) -> None: """Test StreamableHTTP GET request with security.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1"]) process = start_server_process(server_port, security_settings) diff --git a/tests/shared/test_auth.py b/tests/shared/test_auth.py index cd3c35332..a1138a6ee 100644 --- a/tests/shared/test_auth.py +++ b/tests/shared/test_auth.py @@ -3,7 +3,7 @@ from mcp.shared.auth import OAuthMetadata -def test_oauth(): +def test_oauth() -> None: """Should not throw when parsing OAuth metadata.""" OAuthMetadata.model_validate( { @@ -17,7 +17,7 @@ def test_oauth(): ) -def test_oidc(): +def test_oidc() -> None: """Should not throw when parsing OIDC metadata.""" OAuthMetadata.model_validate( { @@ -37,7 +37,7 @@ def test_oidc(): ) -def test_oauth_with_jarm(): +def test_oauth_with_jarm() -> None: """Should not throw when parsing OAuth metadata that includes JARM response modes.""" OAuthMetadata.model_validate( { diff --git a/tests/shared/test_auth_utils.py b/tests/shared/test_auth_utils.py index 5ae0e22b0..ee6c4347f 100644 --- a/tests/shared/test_auth_utils.py +++ b/tests/shared/test_auth_utils.py @@ -7,13 +7,13 @@ # Tests for resource_url_from_server_url function -def test_resource_url_from_server_url_removes_fragment(): +def test_resource_url_from_server_url_removes_fragment() -> None: """Fragment should be removed per RFC 8707.""" assert resource_url_from_server_url("https://example.com/path#fragment") == "https://example.com/path" assert resource_url_from_server_url("https://example.com/#fragment") == "https://example.com/" -def test_resource_url_from_server_url_preserves_path(): +def test_resource_url_from_server_url_preserves_path() -> None: """Path should be preserved.""" assert ( resource_url_from_server_url("https://example.com/path/to/resource") == "https://example.com/path/to/resource" @@ -22,25 +22,25 @@ def test_resource_url_from_server_url_preserves_path(): assert resource_url_from_server_url("https://example.com") == "https://example.com" -def test_resource_url_from_server_url_preserves_query(): +def test_resource_url_from_server_url_preserves_query() -> None: """Query parameters should be preserved.""" assert resource_url_from_server_url("https://example.com/path?foo=bar") == "https://example.com/path?foo=bar" assert resource_url_from_server_url("https://example.com/?key=value") == "https://example.com/?key=value" -def test_resource_url_from_server_url_preserves_port(): +def test_resource_url_from_server_url_preserves_port() -> None: """Non-default ports should be preserved.""" assert resource_url_from_server_url("https://example.com:8443/path") == "https://example.com:8443/path" assert resource_url_from_server_url("http://example.com:8080/") == "http://example.com:8080/" -def test_resource_url_from_server_url_lowercase_scheme_and_host(): +def test_resource_url_from_server_url_lowercase_scheme_and_host() -> None: """Scheme and host should be lowercase for canonical form.""" assert resource_url_from_server_url("HTTPS://EXAMPLE.COM/path") == "https://example.com/path" assert resource_url_from_server_url("Http://Example.Com:8080/") == "http://example.com:8080/" -def test_resource_url_from_server_url_handles_pydantic_urls(): +def test_resource_url_from_server_url_handles_pydantic_urls() -> None: """Should handle Pydantic URL types.""" url = HttpUrl("https://example.com/path") assert resource_url_from_server_url(url) == "https://example.com/path" @@ -49,32 +49,32 @@ def test_resource_url_from_server_url_handles_pydantic_urls(): # Tests for check_resource_allowed function -def test_check_resource_allowed_identical_urls(): +def test_check_resource_allowed_identical_urls() -> None: """Identical URLs should match.""" assert check_resource_allowed("https://example.com/path", "https://example.com/path") is True assert check_resource_allowed("https://example.com/", "https://example.com/") is True assert check_resource_allowed("https://example.com", "https://example.com") is True -def test_check_resource_allowed_different_schemes(): +def test_check_resource_allowed_different_schemes() -> None: """Different schemes should not match.""" assert check_resource_allowed("https://example.com/path", "http://example.com/path") is False assert check_resource_allowed("http://example.com/", "https://example.com/") is False -def test_check_resource_allowed_different_domains(): +def test_check_resource_allowed_different_domains() -> None: """Different domains should not match.""" assert check_resource_allowed("https://example.com/path", "https://example.org/path") is False assert check_resource_allowed("https://sub.example.com/", "https://example.com/") is False -def test_check_resource_allowed_different_ports(): +def test_check_resource_allowed_different_ports() -> None: """Different ports should not match.""" assert check_resource_allowed("https://example.com:8443/path", "https://example.com/path") is False assert check_resource_allowed("https://example.com:8080/", "https://example.com:8443/") is False -def test_check_resource_allowed_hierarchical_matching(): +def test_check_resource_allowed_hierarchical_matching() -> None: """Child paths should match parent paths.""" # Parent resource allows child resources assert check_resource_allowed("https://example.com/api/v1/users", "https://example.com/api") is True @@ -89,7 +89,7 @@ def test_check_resource_allowed_hierarchical_matching(): assert check_resource_allowed("https://example.com/", "https://example.com/api") is False -def test_check_resource_allowed_path_boundary_matching(): +def test_check_resource_allowed_path_boundary_matching() -> None: """Path matching should respect boundaries.""" # Should not match partial path segments assert check_resource_allowed("https://example.com/apiextra", "https://example.com/api") is False @@ -100,7 +100,7 @@ def test_check_resource_allowed_path_boundary_matching(): assert check_resource_allowed("https://example.com/api/v1", "https://example.com/api/") is True -def test_check_resource_allowed_trailing_slash_handling(): +def test_check_resource_allowed_trailing_slash_handling() -> None: """Trailing slashes should be handled correctly.""" # With and without trailing slashes assert check_resource_allowed("https://example.com/api/", "https://example.com/api") is True @@ -109,14 +109,14 @@ def test_check_resource_allowed_trailing_slash_handling(): assert check_resource_allowed("https://example.com/api/v1", "https://example.com/api/") is True -def test_check_resource_allowed_case_insensitive_origin(): +def test_check_resource_allowed_case_insensitive_origin() -> None: """Origin comparison should be case-insensitive.""" assert check_resource_allowed("https://EXAMPLE.COM/path", "https://example.com/path") is True assert check_resource_allowed("HTTPS://example.com/path", "https://example.com/path") is True assert check_resource_allowed("https://Example.Com:8080/api", "https://example.com:8080/api") is True -def test_check_resource_allowed_empty_paths(): +def test_check_resource_allowed_empty_paths() -> None: """Empty paths should be handled correctly.""" assert check_resource_allowed("https://example.com", "https://example.com") is True assert check_resource_allowed("https://example.com/", "https://example.com") is True diff --git a/tests/shared/test_httpx_utils.py b/tests/shared/test_httpx_utils.py index dcc6fd003..493f5f100 100644 --- a/tests/shared/test_httpx_utils.py +++ b/tests/shared/test_httpx_utils.py @@ -5,7 +5,7 @@ from mcp.shared._httpx_utils import create_mcp_http_client -def test_default_settings(): +def test_default_settings() -> None: """Test that default settings are applied correctly.""" client = create_mcp_http_client() @@ -13,7 +13,7 @@ def test_default_settings(): assert client.timeout.connect == 30.0 -def test_custom_parameters(): +def test_custom_parameters() -> None: """Test custom headers and timeout are set correctly.""" headers = {"Authorization": "Bearer token"} timeout = httpx.Timeout(60.0) diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index aad9e5d43..8ad4d8c0d 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -15,14 +15,14 @@ @pytest.mark.anyio -async def test_bidirectional_progress_notifications(): +async def test_bidirectional_progress_notifications() -> None: """Test that both client and server can send progress notifications.""" # Create memory streams for client/server server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5) # Run a server session so we can send progress updates in tool - async def run_server(): + async def run_server() -> None: # Create a server session async with ServerSession( client_to_server_receive, @@ -197,7 +197,7 @@ async def handle_client_message( @pytest.mark.anyio -async def test_progress_callback_exception_logging(): +async def test_progress_callback_exception_logging() -> None: """Test that exceptions in progress callbacks are logged and \ don't crash the session.""" # Track logged warnings diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index d7c6cc3b5..285efefcf 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -24,7 +24,7 @@ @pytest.mark.anyio -async def test_in_flight_requests_cleared_after_completion(): +async def test_in_flight_requests_cleared_after_completion() -> None: """Verify that _in_flight is empty after all requests complete.""" server = Server(name="test server") async with Client(server) as client: @@ -37,7 +37,7 @@ async def test_in_flight_requests_cleared_after_completion(): @pytest.mark.anyio -async def test_request_cancellation(): +async def test_request_cancellation() -> None: """Test that requests can be cancelled while in-flight.""" ev_tool_called = anyio.Event() ev_cancelled = anyio.Event() @@ -64,7 +64,7 @@ async def handle_list_tools( on_list_tools=handle_list_tools, ) - async def make_request(client: Client): + async def make_request(client: Client) -> None: nonlocal ev_cancelled try: await client.session.send_request( @@ -99,7 +99,7 @@ async def make_request(client: Client): @pytest.mark.anyio -async def test_response_id_type_mismatch_string_to_int(): +async def test_response_id_type_mismatch_string_to_int() -> None: """Test that responses with string IDs are correctly matched to requests sent with integer IDs. @@ -113,7 +113,7 @@ async def test_response_id_type_mismatch_string_to_int(): client_read, client_write = client_streams server_read, server_write = server_streams - async def mock_server(): + async def mock_server() -> None: """Receive a request and respond with a string ID instead of integer.""" message = await server_read.receive() assert isinstance(message, SessionMessage) @@ -131,7 +131,7 @@ async def mock_server(): ) await server_write.send(SessionMessage(message=response)) - async def make_request(client_session: ClientSession): + async def make_request(client_session: ClientSession) -> None: nonlocal result_holder # Send a ping request (uses integer ID internally) result = await client_session.send_ping() @@ -153,7 +153,7 @@ async def make_request(client_session: ClientSession): @pytest.mark.anyio -async def test_error_response_id_type_mismatch_string_to_int(): +async def test_error_response_id_type_mismatch_string_to_int() -> None: """Test that error responses with string IDs are correctly matched to requests sent with integer IDs. @@ -167,7 +167,7 @@ async def test_error_response_id_type_mismatch_string_to_int(): client_read, client_write = client_streams server_read, server_write = server_streams - async def mock_server(): + async def mock_server() -> None: """Receive a request and respond with an error using a string ID.""" message = await server_read.receive() assert isinstance(message, SessionMessage) @@ -184,7 +184,7 @@ async def mock_server(): ) await server_write.send(SessionMessage(message=error_response)) - async def make_request(client_session: ClientSession): + async def make_request(client_session: ClientSession) -> None: nonlocal error_holder try: await client_session.send_ping() @@ -208,7 +208,7 @@ async def make_request(client_session: ClientSession): @pytest.mark.anyio -async def test_response_id_non_numeric_string_no_match(): +async def test_response_id_non_numeric_string_no_match() -> None: """Test that responses with non-numeric string IDs don't incorrectly match integer request IDs. @@ -221,7 +221,7 @@ async def test_response_id_non_numeric_string_no_match(): client_read, client_write = client_streams server_read, server_write = server_streams - async def mock_server(): + async def mock_server() -> None: """Receive a request and respond with a non-numeric string ID.""" message = await server_read.receive() assert isinstance(message, SessionMessage) @@ -234,7 +234,7 @@ async def mock_server(): ) await server_write.send(SessionMessage(message=response)) - async def make_request(client_session: ClientSession): + async def make_request(client_session: ClientSession) -> None: try: # Use a short timeout since we expect this to fail await client_session.send_request( @@ -259,7 +259,7 @@ async def make_request(client_session: ClientSession): @pytest.mark.anyio -async def test_connection_closed(): +async def test_connection_closed() -> None: """Test that pending requests are cancelled when the connection is closed remotely.""" ev_closed = anyio.Event() @@ -269,7 +269,7 @@ async def test_connection_closed(): client_read, client_write = client_streams server_read, server_write = server_streams - async def make_request(client_session: ClientSession): + async def make_request(client_session: ClientSession) -> None: """Send a request in a separate task""" nonlocal ev_response try: @@ -281,7 +281,7 @@ async def make_request(client_session: ClientSession): assert "Connection closed" in str(e) ev_response.set() - async def mock_server(): + async def mock_server() -> None: """Wait for a request, then close the connection""" nonlocal ev_closed # Wait for a request @@ -305,7 +305,7 @@ async def mock_server(): @pytest.mark.anyio -async def test_null_id_error_surfaced_via_message_handler(): +async def test_null_id_error_surfaced_via_message_handler() -> None: """Test that a JSONRPCError with id=None is surfaced to the message handler. Per JSON-RPC 2.0, error responses use id=null when the request id could not @@ -328,7 +328,7 @@ async def capture_errors( client_read, client_write = client_streams _server_read, server_write = server_streams - async def mock_server(): + async def mock_server() -> None: """Send a null-id error (simulating a parse error).""" error_response = JSONRPCError(jsonrpc="2.0", id=None, error=sent_error) await server_write.send(SessionMessage(message=error_response)) @@ -351,7 +351,7 @@ async def mock_server(): @pytest.mark.anyio -async def test_null_id_error_does_not_affect_pending_request(): +async def test_null_id_error_does_not_affect_pending_request() -> None: """Test that a null-id error doesn't interfere with an in-flight request. When a null-id error arrives while a request is pending, the error should @@ -376,7 +376,7 @@ async def capture_errors( client_read, client_write = client_streams server_read, server_write = server_streams - async def mock_server(): + async def mock_server() -> None: """Read a request, inject a null-id error, then respond normally.""" message = await server_read.receive() assert isinstance(message, SessionMessage) @@ -389,7 +389,7 @@ async def mock_server(): # Then, respond normally to the pending request await server_write.send(SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=request_id, result={}))) - async def make_request(client_session: ClientSession): + async def make_request(client_session: ClientSession) -> None: result = await client_session.send_ping() result_holder.append(result) ev_response_received.set() diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 5629a5707..4dec30549 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -497,7 +497,7 @@ async def test_request_context_isolation(context_server: None, server_url: str) assert ctx["headers"].get("x-custom-value") == f"value-{i}" -def test_sse_message_id_coercion(): +def test_sse_message_id_coercion() -> None: """Previously, the `RequestId` would coerce a string that looked like an integer into an integer. See for more details. @@ -531,7 +531,7 @@ def test_sse_message_id_coercion(): ("/messages/#fragment", ValueError), ], ) -def test_sse_server_transport_endpoint_validation(endpoint: str, expected_result: str | type[Exception]): +def test_sse_server_transport_endpoint_validation(endpoint: str, expected_result: str | type[Exception]) -> None: """Test that SseServerTransport properly validates and normalizes endpoints.""" if isinstance(expected_result, type): # Test invalid endpoints that should raise an exception diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index f8ca30441..9d0b6adff 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -10,10 +10,10 @@ import socket import time import traceback -from collections.abc import AsyncIterator, Generator +from collections.abc import AsyncGenerator, AsyncIterator, Generator from contextlib import asynccontextmanager from dataclasses import dataclass, field -from typing import Any +from typing import Any, NoReturn from unittest.mock import MagicMock from urllib.parse import urlparse @@ -97,7 +97,7 @@ def extract_protocol_version_from_sse(response: requests.Response) -> str: class SimpleEventStore(EventStore): """Simple in-memory event store for testing.""" - def __init__(self): + def __init__(self) -> None: self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage | None]] = [] self._event_id_counter = 0 @@ -570,7 +570,7 @@ def json_server_url(json_server_port: int) -> str: # Basic request validation tests -def test_accept_header_validation(basic_server: None, basic_server_url: str): +def test_accept_header_validation(basic_server: None, basic_server_url: str) -> None: """Test that Accept header is properly validated.""" # Test without Accept header (suppress requests library default Accept: */*) session = requests.Session() @@ -595,7 +595,7 @@ def test_accept_header_validation(basic_server: None, basic_server_url: str): "application/*;q=0.9, text/*;q=0.8", ], ) -def test_accept_header_wildcard(basic_server: None, basic_server_url: str, accept_header: str): +def test_accept_header_wildcard(basic_server: None, basic_server_url: str, accept_header: str) -> None: """Test that wildcard Accept headers are accepted per RFC 7231.""" response = requests.post( f"{basic_server_url}/mcp", @@ -616,7 +616,7 @@ def test_accept_header_wildcard(basic_server: None, basic_server_url: str, accep "text/*", ], ) -def test_accept_header_incompatible(basic_server: None, basic_server_url: str, accept_header: str): +def test_accept_header_incompatible(basic_server: None, basic_server_url: str, accept_header: str) -> None: """Test that incompatible Accept headers are rejected for SSE mode.""" response = requests.post( f"{basic_server_url}/mcp", @@ -630,7 +630,7 @@ def test_accept_header_incompatible(basic_server: None, basic_server_url: str, a assert "Not Acceptable" in response.text -def test_content_type_validation(basic_server: None, basic_server_url: str): +def test_content_type_validation(basic_server: None, basic_server_url: str) -> None: """Test that Content-Type header is properly validated.""" # Test with incorrect Content-Type response = requests.post( @@ -646,7 +646,7 @@ def test_content_type_validation(basic_server: None, basic_server_url: str): assert "Invalid Content-Type" in response.text -def test_json_validation(basic_server: None, basic_server_url: str): +def test_json_validation(basic_server: None, basic_server_url: str) -> None: """Test that JSON content is properly validated.""" # Test with invalid JSON response = requests.post( @@ -661,7 +661,7 @@ def test_json_validation(basic_server: None, basic_server_url: str): assert "Parse error" in response.text -def test_json_parsing(basic_server: None, basic_server_url: str): +def test_json_parsing(basic_server: None, basic_server_url: str) -> None: """Test that JSON content is properly parse.""" # Test with valid JSON but invalid JSON-RPC response = requests.post( @@ -676,7 +676,7 @@ def test_json_parsing(basic_server: None, basic_server_url: str): assert "Validation error" in response.text -def test_method_not_allowed(basic_server: None, basic_server_url: str): +def test_method_not_allowed(basic_server: None, basic_server_url: str) -> None: """Test that unsupported HTTP methods are rejected.""" # Test with unsupported method (PUT) response = requests.put( @@ -691,7 +691,7 @@ def test_method_not_allowed(basic_server: None, basic_server_url: str): assert "Method Not Allowed" in response.text -def test_session_validation(basic_server: None, basic_server_url: str): +def test_session_validation(basic_server: None, basic_server_url: str) -> None: """Test session ID validation.""" # session_id not used directly in this test @@ -708,7 +708,7 @@ def test_session_validation(basic_server: None, basic_server_url: str): assert "Missing session ID" in response.text -def test_session_id_pattern(): +def test_session_id_pattern() -> None: """Test that SESSION_ID_PATTERN correctly validates session IDs.""" # Valid session IDs (visible ASCII characters from 0x21 to 0x7E) valid_session_ids = [ @@ -743,7 +743,7 @@ def test_session_id_pattern(): assert SESSION_ID_PATTERN.fullmatch(session_id) is None -def test_streamable_http_transport_init_validation(): +def test_streamable_http_transport_init_validation() -> None: """Test that StreamableHTTPServerTransport validates session ID on init.""" # Valid session ID should initialize without errors valid_transport = StreamableHTTPServerTransport(mcp_session_id="valid-id") @@ -766,7 +766,7 @@ def test_streamable_http_transport_init_validation(): StreamableHTTPServerTransport(mcp_session_id="test\n") -def test_session_termination(basic_server: None, basic_server_url: str): +def test_session_termination(basic_server: None, basic_server_url: str) -> None: """Test session termination via DELETE and subsequent request handling.""" response = requests.post( f"{basic_server_url}/mcp", @@ -806,7 +806,7 @@ def test_session_termination(basic_server: None, basic_server_url: str): assert "Session has been terminated" in response.text -def test_response(basic_server: None, basic_server_url: str): +def test_response(basic_server: None, basic_server_url: str) -> None: """Test response handling for a valid request.""" mcp_url = f"{basic_server_url}/mcp" response = requests.post( @@ -841,7 +841,7 @@ def test_response(basic_server: None, basic_server_url: str): assert tools_response.headers.get("Content-Type") == "text/event-stream" -def test_json_response(json_response_server: None, json_server_url: str): +def test_json_response(json_response_server: None, json_server_url: str) -> None: """Test response handling when is_json_response_enabled is True.""" mcp_url = f"{json_server_url}/mcp" response = requests.post( @@ -856,7 +856,7 @@ def test_json_response(json_response_server: None, json_server_url: str): assert response.headers.get("Content-Type") == "application/json" -def test_json_response_accept_json_only(json_response_server: None, json_server_url: str): +def test_json_response_accept_json_only(json_response_server: None, json_server_url: str) -> None: """Test that json_response servers only require application/json in Accept header.""" mcp_url = f"{json_server_url}/mcp" response = requests.post( @@ -871,7 +871,7 @@ def test_json_response_accept_json_only(json_response_server: None, json_server_ assert response.headers.get("Content-Type") == "application/json" -def test_json_response_missing_accept_header(json_response_server: None, json_server_url: str): +def test_json_response_missing_accept_header(json_response_server: None, json_server_url: str) -> None: """Test that json_response servers reject requests without Accept header.""" mcp_url = f"{json_server_url}/mcp" # Suppress requests library default Accept: */* header @@ -888,7 +888,7 @@ def test_json_response_missing_accept_header(json_response_server: None, json_se assert "Not Acceptable" in response.text -def test_json_response_incorrect_accept_header(json_response_server: None, json_server_url: str): +def test_json_response_incorrect_accept_header(json_response_server: None, json_server_url: str) -> None: """Test that json_response servers reject requests with incorrect Accept header.""" mcp_url = f"{json_server_url}/mcp" # Test with only text/event-stream (wrong for JSON server) @@ -912,7 +912,9 @@ def test_json_response_incorrect_accept_header(json_response_server: None, json_ "application/*;q=0.9", ], ) -def test_json_response_wildcard_accept_header(json_response_server: None, json_server_url: str, accept_header: str): +def test_json_response_wildcard_accept_header( + json_response_server: None, json_server_url: str, accept_header: str +) -> None: """Test that json_response servers accept wildcard Accept headers per RFC 7231.""" mcp_url = f"{json_server_url}/mcp" response = requests.post( @@ -927,7 +929,7 @@ def test_json_response_wildcard_accept_header(json_response_server: None, json_s assert response.headers.get("Content-Type") == "application/json" -def test_get_sse_stream(basic_server: None, basic_server_url: str): +def test_get_sse_stream(basic_server: None, basic_server_url: str) -> None: """Test establishing an SSE stream via GET request.""" # First, we need to initialize a session mcp_url = f"{basic_server_url}/mcp" @@ -987,7 +989,7 @@ def test_get_sse_stream(basic_server: None, basic_server_url: str): assert second_get.status_code == 409 -def test_get_validation(basic_server: None, basic_server_url: str): +def test_get_validation(basic_server: None, basic_server_url: str) -> None: """Test validation for GET requests.""" # First, we need to initialize a session mcp_url = f"{basic_server_url}/mcp" @@ -1044,14 +1046,16 @@ def test_get_validation(basic_server: None, basic_server_url: str): # Client-specific fixtures @pytest.fixture -async def http_client(basic_server: None, basic_server_url: str): # pragma: no cover +async def http_client( + basic_server: None, basic_server_url: str +) -> AsyncGenerator[httpx.AsyncClient, None]: # pragma: no cover """Create test client matching the SSE test pattern.""" async with httpx.AsyncClient(base_url=basic_server_url) as client: yield client @pytest.fixture -async def initialized_client_session(basic_server: None, basic_server_url: str): +async def initialized_client_session(basic_server: None, basic_server_url: str) -> AsyncGenerator[ClientSession, None]: """Create initialized StreamableHTTP client session.""" async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1060,7 +1064,7 @@ async def initialized_client_session(basic_server: None, basic_server_url: str): @pytest.mark.anyio -async def test_streamable_http_client_basic_connection(basic_server: None, basic_server_url: str): +async def test_streamable_http_client_basic_connection(basic_server: None, basic_server_url: str) -> None: """Test basic client connection with initialization.""" async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1071,7 +1075,7 @@ async def test_streamable_http_client_basic_connection(basic_server: None, basic @pytest.mark.anyio -async def test_streamable_http_client_resource_read(initialized_client_session: ClientSession): +async def test_streamable_http_client_resource_read(initialized_client_session: ClientSession) -> None: """Test client resource read functionality.""" response = await initialized_client_session.read_resource(uri="foobar://test-resource") assert len(response.contents) == 1 @@ -1081,7 +1085,7 @@ async def test_streamable_http_client_resource_read(initialized_client_session: @pytest.mark.anyio -async def test_streamable_http_client_tool_invocation(initialized_client_session: ClientSession): +async def test_streamable_http_client_tool_invocation(initialized_client_session: ClientSession) -> None: """Test client tool invocation.""" # First list tools tools = await initialized_client_session.list_tools() @@ -1096,7 +1100,7 @@ async def test_streamable_http_client_tool_invocation(initialized_client_session @pytest.mark.anyio -async def test_streamable_http_client_error_handling(initialized_client_session: ClientSession): +async def test_streamable_http_client_error_handling(initialized_client_session: ClientSession) -> None: """Test error handling in client.""" with pytest.raises(MCPError) as exc_info: await initialized_client_session.read_resource(uri="unknown://test-error") @@ -1105,7 +1109,7 @@ async def test_streamable_http_client_error_handling(initialized_client_session: @pytest.mark.anyio -async def test_streamable_http_client_session_persistence(basic_server: None, basic_server_url: str): +async def test_streamable_http_client_session_persistence(basic_server: None, basic_server_url: str) -> None: """Test that session ID persists across requests.""" async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1126,7 +1130,7 @@ async def test_streamable_http_client_session_persistence(basic_server: None, ba @pytest.mark.anyio -async def test_streamable_http_client_json_response(json_response_server: None, json_server_url: str): +async def test_streamable_http_client_json_response(json_response_server: None, json_server_url: str) -> None: """Test client with JSON response mode.""" async with streamable_http_client(f"{json_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1147,7 +1151,7 @@ async def test_streamable_http_client_json_response(json_response_server: None, @pytest.mark.anyio -async def test_streamable_http_client_get_stream(basic_server: None, basic_server_url: str): +async def test_streamable_http_client_get_stream(basic_server: None, basic_server_url: str) -> None: """Test GET stream functionality for server-initiated messages.""" notifications_received: list[types.ServerNotification] = [] @@ -1198,7 +1202,7 @@ async def capture_session_id(response: httpx.Response) -> None: @pytest.mark.anyio -async def test_streamable_http_client_session_termination(basic_server: None, basic_server_url: str): +async def test_streamable_http_client_session_termination(basic_server: None, basic_server_url: str) -> None: """Test client session termination functionality.""" # Use httpx client with event hooks to capture session ID httpx_client, captured_ids = create_session_id_capturing_client() @@ -1235,7 +1239,7 @@ async def test_streamable_http_client_session_termination(basic_server: None, ba @pytest.mark.anyio async def test_streamable_http_client_session_termination_204( basic_server: None, basic_server_url: str, monkeypatch: pytest.MonkeyPatch -): +) -> None: """Test client session termination functionality with a 204 response. This test patches the httpx client to return a 204 response for DELETEs. @@ -1294,7 +1298,7 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt @pytest.mark.anyio -async def test_streamable_http_client_resumption(event_server: tuple[SimpleEventStore, str]): +async def test_streamable_http_client_resumption(event_server: tuple[SimpleEventStore, str]) -> None: """Test client session resumption using sync primitives for reliable coordination.""" _, server_url = event_server @@ -1345,7 +1349,7 @@ async def on_resumption_token_update(token: str) -> None: # Start the tool that will wait on lock in a task async with anyio.create_task_group() as tg: # pragma: no branch - async def run_tool(): + async def run_tool() -> None: metadata = ClientMessageMetadata( on_resumption_token_update=on_resumption_token_update, ) @@ -1412,7 +1416,7 @@ async def run_tool(): @pytest.mark.anyio -async def test_streamablehttp_server_sampling(basic_server: None, basic_server_url: str): +async def test_streamablehttp_server_sampling(basic_server: None, basic_server_url: str) -> None: """Test server-initiated sampling request through streamable HTTP transport.""" # Variable to track if sampling callback was invoked sampling_callback_invoked = False @@ -1517,7 +1521,7 @@ async def _handle_context_call_tool( # pragma: no cover # Server runner for context-aware testing -def run_context_aware_server(port: int): # pragma: no cover +def run_context_aware_server(port: int) -> None: # pragma: no cover """Run the context-aware test server.""" server = Server( "ContextAwareServer", @@ -1639,7 +1643,9 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No @pytest.mark.anyio -async def test_client_includes_protocol_version_header_after_init(context_aware_server: None, basic_server_url: str): +async def test_client_includes_protocol_version_header_after_init( + context_aware_server: None, basic_server_url: str +) -> None: """Test that client includes mcp-protocol-version header after initialization.""" async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1659,7 +1665,7 @@ async def test_client_includes_protocol_version_header_after_init(context_aware_ assert headers_data[MCP_PROTOCOL_VERSION_HEADER] == negotiated_version -def test_server_validates_protocol_version_header(basic_server: None, basic_server_url: str): +def test_server_validates_protocol_version_header(basic_server: None, basic_server_url: str) -> None: """Test that server returns 400 Bad Request version if header unsupported or invalid.""" # First initialize a session to get a valid session ID init_response = requests.post( @@ -1717,7 +1723,7 @@ def test_server_validates_protocol_version_header(basic_server: None, basic_serv assert response.status_code == 200 -def test_server_backwards_compatibility_no_protocol_version(basic_server: None, basic_server_url: str): +def test_server_backwards_compatibility_no_protocol_version(basic_server: None, basic_server_url: str) -> None: """Test server accepts requests without protocol version header.""" # First initialize a session to get a valid session ID init_response = requests.post( @@ -1747,11 +1753,11 @@ def test_server_backwards_compatibility_no_protocol_version(basic_server: None, @pytest.mark.anyio -async def test_client_crash_handled(basic_server: None, basic_server_url: str): +async def test_client_crash_handled(basic_server: None, basic_server_url: str) -> None: """Test that cases where the client crashes are handled gracefully.""" # Simulate bad client that crashes after init - async def bad_client(): + async def bad_client() -> NoReturn: """Client that triggers ClosedResourceError""" async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1776,7 +1782,7 @@ async def bad_client(): @pytest.mark.anyio -async def test_handle_sse_event_skips_empty_data(): +async def test_handle_sse_event_skips_empty_data() -> None: """Test that _handle_sse_event skips empty SSE data (keep-alive pings).""" transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") @@ -1802,7 +1808,7 @@ async def test_handle_sse_event_skips_empty_data(): @pytest.mark.anyio -async def test_priming_event_not_sent_for_old_protocol_version(): +async def test_priming_event_not_sent_for_old_protocol_version() -> None: """Test that _maybe_send_priming_event skips for old protocol versions (backwards compat).""" # Create a transport with an event store transport = StreamableHTTPServerTransport( @@ -1831,7 +1837,7 @@ async def test_priming_event_not_sent_for_old_protocol_version(): @pytest.mark.anyio -async def test_priming_event_not_sent_without_event_store(): +async def test_priming_event_not_sent_without_event_store() -> None: """Test that _maybe_send_priming_event returns early when no event_store is configured.""" # Create a transport WITHOUT an event store transport = StreamableHTTPServerTransport("/mcp") @@ -1851,7 +1857,7 @@ async def test_priming_event_not_sent_without_event_store(): @pytest.mark.anyio -async def test_priming_event_includes_retry_interval(): +async def test_priming_event_includes_retry_interval() -> None: """Test that _maybe_send_priming_event includes retry field when retry_interval is set.""" # Create a transport with an event store AND retry_interval transport = StreamableHTTPServerTransport( @@ -1880,7 +1886,7 @@ async def test_priming_event_includes_retry_interval(): @pytest.mark.anyio -async def test_close_sse_stream_callback_not_provided_for_old_protocol_version(): +async def test_close_sse_stream_callback_not_provided_for_old_protocol_version() -> None: """Test that close_sse_stream callbacks are NOT provided for old protocol versions.""" # Create a transport with an event store transport = StreamableHTTPServerTransport( @@ -2119,7 +2125,7 @@ async def message_handler( @pytest.mark.anyio async def test_streamable_http_multiple_reconnections( event_server: tuple[SimpleEventStore, str], -): +) -> None: """Verify multiple close_sse_stream() calls each trigger a client reconnect. Server uses retry_interval=500ms, tool sleeps 600ms after each close to ensure diff --git a/tests/test_examples.py b/tests/test_examples.py index 3af82f04c..0a7fb2ba8 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -16,7 +16,7 @@ @pytest.mark.anyio -async def test_simple_echo(): +async def test_simple_echo() -> None: """Test the simple echo server""" from examples.mcpserver.simple_echo import mcp @@ -28,7 +28,7 @@ async def test_simple_echo(): @pytest.mark.anyio -async def test_complex_inputs(): +async def test_complex_inputs() -> None: """Test the complex inputs server""" from examples.mcpserver.complex_inputs import mcp @@ -48,7 +48,7 @@ async def test_complex_inputs(): @pytest.mark.anyio -async def test_direct_call_tool_result_return(): +async def test_direct_call_tool_result_return() -> None: """Test the CallToolResult echo server""" from examples.mcpserver.direct_call_tool_result_return import mcp @@ -64,7 +64,7 @@ async def test_direct_call_tool_result_return(): @pytest.mark.anyio -async def test_desktop(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): +async def test_desktop(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: """Test the desktop server""" # Build a real Desktop directory under tmp_path rather than patching # Path.iterdir — a class-level patch breaks jsonschema_specifications' @@ -95,8 +95,8 @@ async def test_desktop(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): # TODO(v2): Change back to README.md when v2 is released @pytest.mark.parametrize("example", find_examples("README.v2.md"), ids=str) -def test_docs_examples(example: CodeExample, eval_example: EvalExample): - ruff_ignore: list[str] = ["F841", "I001", "F821"] # F821: undefined names (snippets lack imports) +def test_docs_examples(example: CodeExample, eval_example: EvalExample) -> None: + ruff_ignore: list[str] = ["F841", "I001", "F821", "ANN"] # F821: undefined names (snippets lack imports) # Use project's actual line length of 120 eval_example.set_config(ruff_ignore=ruff_ignore, target_version="py310", line_length=120) diff --git a/tests/test_types.py b/tests/test_types.py index f424efdbf..28df823a1 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -26,7 +26,7 @@ @pytest.mark.anyio -async def test_jsonrpc_request(): +async def test_jsonrpc_request() -> None: json_data = { "jsonrpc": "2.0", "id": 1, @@ -50,7 +50,7 @@ async def test_jsonrpc_request(): @pytest.mark.anyio -async def test_method_initialization(): +async def test_method_initialization() -> None: """Test that the method is automatically set on object creation. Testing just for InitializeRequest to keep the test simple, but should be set for other types as well. """ @@ -71,7 +71,7 @@ async def test_method_initialization(): @pytest.mark.anyio -async def test_tool_use_content(): +async def test_tool_use_content() -> None: """Test ToolUseContent type for SEP-1577.""" tool_use_data = { "type": "tool_use", @@ -93,7 +93,7 @@ async def test_tool_use_content(): @pytest.mark.anyio -async def test_tool_result_content(): +async def test_tool_result_content() -> None: """Test ToolResultContent type for SEP-1577.""" tool_result_data = { "type": "tool_result", @@ -115,7 +115,7 @@ async def test_tool_result_content(): @pytest.mark.anyio -async def test_tool_choice(): +async def test_tool_choice() -> None: """Test ToolChoice type for SEP-1577.""" # Test with mode tool_choice_data = {"mode": "required"} @@ -135,7 +135,7 @@ async def test_tool_choice(): @pytest.mark.anyio -async def test_sampling_message_with_user_role(): +async def test_sampling_message_with_user_role() -> None: """Test SamplingMessage with user role for SEP-1577.""" # Test with single content user_msg_data = {"role": "user", "content": {"type": "text", "text": "Hello"}} @@ -158,7 +158,7 @@ async def test_sampling_message_with_user_role(): @pytest.mark.anyio -async def test_sampling_message_with_assistant_role(): +async def test_sampling_message_with_assistant_role() -> None: """Test SamplingMessage with assistant role for SEP-1577.""" # Test with tool use content assistant_msg_data = { @@ -188,7 +188,7 @@ async def test_sampling_message_with_assistant_role(): @pytest.mark.anyio -async def test_sampling_message_backward_compatibility(): +async def test_sampling_message_backward_compatibility() -> None: """Test that SamplingMessage maintains backward compatibility.""" # Old-style message (single content, no tools) old_style_data = {"role": "user", "content": {"type": "text", "text": "Hello"}} @@ -215,7 +215,7 @@ async def test_sampling_message_backward_compatibility(): @pytest.mark.anyio -async def test_create_message_request_params_with_tools(): +async def test_create_message_request_params_with_tools() -> None: """Test CreateMessageRequestParams with tools for SEP-1577.""" tool = Tool( name="get_weather", @@ -238,7 +238,7 @@ async def test_create_message_request_params_with_tools(): @pytest.mark.anyio -async def test_create_message_result_with_tool_use(): +async def test_create_message_result_with_tool_use() -> None: """Test CreateMessageResultWithTools with tool use content for SEP-1577.""" result_data = { "role": "assistant", @@ -261,7 +261,7 @@ async def test_create_message_result_with_tool_use(): @pytest.mark.anyio -async def test_create_message_result_basic(): +async def test_create_message_result_basic() -> None: """Test CreateMessageResult with basic text content (backwards compatible).""" result_data = { "role": "assistant", @@ -280,7 +280,7 @@ async def test_create_message_result_basic(): @pytest.mark.anyio -async def test_client_capabilities_with_sampling_tools(): +async def test_client_capabilities_with_sampling_tools() -> None: """Test ClientCapabilities with nested sampling capabilities for SEP-1577.""" # New structured format capabilities_data: dict[str, Any] = { @@ -299,7 +299,7 @@ async def test_client_capabilities_with_sampling_tools(): assert full_caps.sampling.tools is not None -def test_tool_preserves_json_schema_2020_12_fields(): +def test_tool_preserves_json_schema_2020_12_fields() -> None: """Verify that JSON Schema 2020-12 keywords are preserved in Tool.inputSchema. SEP-1613 establishes JSON Schema 2020-12 as the default dialect for MCP. @@ -336,7 +336,7 @@ def test_tool_preserves_json_schema_2020_12_fields(): assert serialized["inputSchema"]["additionalProperties"] is False -def test_list_tools_result_preserves_json_schema_2020_12_fields(): +def test_list_tools_result_preserves_json_schema_2020_12_fields() -> None: """Verify JSON Schema 2020-12 fields survive ListToolsResult deserialization.""" raw_response = { "tools": [ From 8b7399c060aed9461a424d5882f5ec731b26564d Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 24 Mar 2026 22:15:42 +0000 Subject: [PATCH 2/5] chore: use AsyncIterator for lifespan annotations, sync README snippets - Switch AsyncGenerator[None, None] to AsyncIterator[None] in 4 example snippets to match codebase convention for @asynccontextmanager lifespans - Regenerate README.v2.md to sync snippet changes from eb3d0d0 --- README.v2.md | 50 ++++++++++--------- .../servers/streamable_http_basic_mounting.py | 4 +- .../servers/streamable_http_host_mounting.py | 4 +- .../streamable_http_multiple_servers.py | 4 +- .../servers/streamable_starlette_mount.py | 4 +- 5 files changed, 35 insertions(+), 31 deletions(-) diff --git a/README.v2.md b/README.v2.md index 55d867586..c9c1f5b97 100644 --- a/README.v2.md +++ b/README.v2.md @@ -518,7 +518,7 @@ class UserProfile: age: int email: str | None = None - def __init__(self, name: str, age: int, email: str | None = None): + def __init__(self, name: str, age: int, email: str | None = None) -> None: self.name = name self.age = age self.email = email @@ -532,7 +532,7 @@ def get_user(user_id: str) -> UserProfile: # Classes WITHOUT type hints cannot be used for structured output class UntypedConfig: - def __init__(self, setting1, setting2): # type: ignore[reportMissingParameterType] + def __init__(self, setting1, setting2) -> None: # type: ignore[reportMissingParameterType] # noqa: ANN001 self.setting1 = setting1 self.setting2 = setting2 @@ -744,7 +744,7 @@ server_params = StdioServerParameters( ) -async def run(): +async def run() -> None: """Run the completion client example.""" async with stdio_client(server_params) as (read, write): async with ClientSession(read, write) as session: @@ -795,7 +795,7 @@ async def run(): print(f"Completions for 'style' argument: {result.completion.values}") -def main(): +def main() -> None: """Entry point for the completion client.""" asyncio.run(run()) @@ -1210,7 +1210,7 @@ def hello(name: str = "World") -> str: return f"Hello, {name}!" -def main(): +def main() -> None: """Entry point for the direct execution server.""" mcp.run() @@ -1280,6 +1280,7 @@ uvicorn examples.snippets.servers.streamable_starlette_mount:app --reload """ import contextlib +from collections.abc import AsyncIterator from starlette.applications import Starlette from starlette.routing import Mount @@ -1308,7 +1309,7 @@ def add_two(n: int) -> int: # Create a combined lifespan to manage both session managers @contextlib.asynccontextmanager -async def lifespan(app: Starlette): +async def lifespan(app: Starlette) -> AsyncIterator[None]: async with contextlib.AsyncExitStack() as stack: await stack.enter_async_context(echo_mcp.session_manager.run()) await stack.enter_async_context(math_mcp.session_manager.run()) @@ -1392,6 +1393,7 @@ Run from the repository root: """ import contextlib +from collections.abc import AsyncIterator from starlette.applications import Starlette from starlette.routing import Mount @@ -1410,7 +1412,7 @@ def hello() -> str: # Create a lifespan context manager to run the session manager @contextlib.asynccontextmanager -async def lifespan(app: Starlette): +async def lifespan(app: Starlette) -> AsyncIterator[None]: async with mcp.session_manager.run(): yield @@ -1439,6 +1441,7 @@ Run from the repository root: """ import contextlib +from collections.abc import AsyncIterator from starlette.applications import Starlette from starlette.routing import Host @@ -1457,7 +1460,7 @@ def domain_info() -> str: # Create a lifespan context manager to run the session manager @contextlib.asynccontextmanager -async def lifespan(app: Starlette): +async def lifespan(app: Starlette) -> AsyncIterator[None]: async with mcp.session_manager.run(): yield @@ -1486,6 +1489,7 @@ Run from the repository root: """ import contextlib +from collections.abc import AsyncIterator from starlette.applications import Starlette from starlette.routing import Mount @@ -1511,7 +1515,7 @@ def send_message(message: str) -> str: # Create a combined lifespan to manage both session managers @contextlib.asynccontextmanager -async def lifespan(app: Starlette): +async def lifespan(app: Starlette) -> AsyncIterator[None]: async with contextlib.AsyncExitStack() as stack: await stack.enter_async_context(api_mcp.session_manager.run()) await stack.enter_async_context(chat_mcp.session_manager.run()) @@ -1718,7 +1722,7 @@ server = Server( ) -async def run(): +async def run() -> None: """Run the server with lifespan management.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( @@ -1796,7 +1800,7 @@ server = Server( ) -async def run(): +async def run() -> None: """Run the basic low-level server.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( @@ -1889,7 +1893,7 @@ server = Server( ) -async def run(): +async def run() -> None: """Run the structured output server.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( @@ -1964,7 +1968,7 @@ server = Server( ) -async def run(): +async def run() -> None: """Run the server.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( @@ -2127,7 +2131,7 @@ async def handle_sampling_message( ) -async def run(): +async def run() -> None: async with stdio_client(server_params) as (read, write): async with ClientSession(read, write, sampling_callback=handle_sampling_message) as session: # Initialize the connection @@ -2165,7 +2169,7 @@ async def run(): print(f"Structured tool result: {result_structured}") -def main(): +def main() -> None: """Entry point for the client script.""" asyncio.run(run()) @@ -2191,7 +2195,7 @@ from mcp import ClientSession from mcp.client.streamable_http import streamable_http_client -async def main(): +async def main() -> None: # Connect to a streamable HTTP server async with streamable_http_client("http://localhost:8000/mcp") as (read_stream, write_stream): # Create a session using the client streams @@ -2235,7 +2239,7 @@ server_params = StdioServerParameters( ) -async def display_tools(session: ClientSession): +async def display_tools(session: ClientSession) -> None: """Display available tools with human-readable names""" tools_response = await session.list_tools() @@ -2247,7 +2251,7 @@ async def display_tools(session: ClientSession): print(f" {tool.description}") -async def display_resources(session: ClientSession): +async def display_resources(session: ClientSession) -> None: """Display available resources with human-readable names""" resources_response = await session.list_resources() @@ -2261,7 +2265,7 @@ async def display_resources(session: ClientSession): print(f"Resource Template: {display_name}") -async def run(): +async def run() -> None: """Run the display utilities example.""" async with stdio_client(server_params) as (read, write): async with ClientSession(read, write) as session: @@ -2275,7 +2279,7 @@ async def run(): await display_resources(session) -def main(): +def main() -> None: """Entry point for the display utilities client.""" asyncio.run(run()) @@ -2323,7 +2327,7 @@ from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAu class InMemoryTokenStorage(TokenStorage): """Demo In-memory token storage implementation.""" - def __init__(self): + def __init__(self) -> None: self.tokens: OAuthToken | None = None self.client_info: OAuthClientInformationFull | None = None @@ -2354,7 +2358,7 @@ async def handle_callback() -> tuple[str, str | None]: return params["code"][0], params.get("state", [None])[0] -async def main(): +async def main() -> None: """Run the OAuth client example.""" oauth_auth = OAuthClientProvider( server_url="http://localhost:8001", @@ -2382,7 +2386,7 @@ async def main(): print(f"Available resources: {[r.uri for r in resources.resources]}") -def run(): +def run() -> None: asyncio.run(main()) diff --git a/examples/snippets/servers/streamable_http_basic_mounting.py b/examples/snippets/servers/streamable_http_basic_mounting.py index a1fd72a29..e1687ce99 100644 --- a/examples/snippets/servers/streamable_http_basic_mounting.py +++ b/examples/snippets/servers/streamable_http_basic_mounting.py @@ -5,7 +5,7 @@ """ import contextlib -from collections.abc import AsyncGenerator +from collections.abc import AsyncIterator from starlette.applications import Starlette from starlette.routing import Mount @@ -24,7 +24,7 @@ def hello() -> str: # Create a lifespan context manager to run the session manager @contextlib.asynccontextmanager -async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: +async def lifespan(app: Starlette) -> AsyncIterator[None]: async with mcp.session_manager.run(): yield diff --git a/examples/snippets/servers/streamable_http_host_mounting.py b/examples/snippets/servers/streamable_http_host_mounting.py index ea72a98a4..73cbdd54d 100644 --- a/examples/snippets/servers/streamable_http_host_mounting.py +++ b/examples/snippets/servers/streamable_http_host_mounting.py @@ -5,7 +5,7 @@ """ import contextlib -from collections.abc import AsyncGenerator +from collections.abc import AsyncIterator from starlette.applications import Starlette from starlette.routing import Host @@ -24,7 +24,7 @@ def domain_info() -> str: # Create a lifespan context manager to run the session manager @contextlib.asynccontextmanager -async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: +async def lifespan(app: Starlette) -> AsyncIterator[None]: async with mcp.session_manager.run(): yield diff --git a/examples/snippets/servers/streamable_http_multiple_servers.py b/examples/snippets/servers/streamable_http_multiple_servers.py index e46924e8a..b95e34d22 100644 --- a/examples/snippets/servers/streamable_http_multiple_servers.py +++ b/examples/snippets/servers/streamable_http_multiple_servers.py @@ -5,7 +5,7 @@ """ import contextlib -from collections.abc import AsyncGenerator +from collections.abc import AsyncIterator from starlette.applications import Starlette from starlette.routing import Mount @@ -31,7 +31,7 @@ def send_message(message: str) -> str: # Create a combined lifespan to manage both session managers @contextlib.asynccontextmanager -async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: +async def lifespan(app: Starlette) -> AsyncIterator[None]: async with contextlib.AsyncExitStack() as stack: await stack.enter_async_context(api_mcp.session_manager.run()) await stack.enter_async_context(chat_mcp.session_manager.run()) diff --git a/examples/snippets/servers/streamable_starlette_mount.py b/examples/snippets/servers/streamable_starlette_mount.py index 95186ad7f..ae97982ff 100644 --- a/examples/snippets/servers/streamable_starlette_mount.py +++ b/examples/snippets/servers/streamable_starlette_mount.py @@ -3,7 +3,7 @@ """ import contextlib -from collections.abc import AsyncGenerator +from collections.abc import AsyncIterator from starlette.applications import Starlette from starlette.routing import Mount @@ -32,7 +32,7 @@ def add_two(n: int) -> int: # Create a combined lifespan to manage both session managers @contextlib.asynccontextmanager -async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: +async def lifespan(app: Starlette) -> AsyncIterator[None]: async with contextlib.AsyncExitStack() as stack: await stack.enter_async_context(echo_mcp.session_manager.run()) await stack.enter_async_context(math_mcp.session_manager.run()) From cfde91dacd6dbd7c2a3574e3abcf29962207dc37 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 24 Mar 2026 22:48:16 +0000 Subject: [PATCH 3/5] chore: exempt tests from ANN, enable mypy-init-return Align with ecosystem norms per review feedback: - Add mypy-init-return = true to skip -> None on __init__ when args are typed - Exempt tests/** from ANN rules (near-universal practice) - Revert test file annotations and __init__-only changes Shrinks diff from 149 to ~47 files, keeping the valuable src/ return type annotations that close the pyright gap. --- README.v2.md | 4 +- .../mcp_simple_auth/auth_server.py | 2 +- .../mcp_simple_auth/legacy_as_server.py | 2 +- .../mcp_simple_auth/simple_auth_provider.py | 2 +- .../mcp_simple_auth/token_verifier.py | 2 +- .../snippets/servers/structured_output.py | 4 +- pyproject.toml | 6 +- src/mcp/client/auth/oauth2.py | 2 +- src/mcp/os/win32/utilities.py | 2 +- .../server/auth/middleware/auth_context.py | 2 +- src/mcp/server/auth/middleware/bearer_auth.py | 6 +- src/mcp/server/auth/middleware/client_auth.py | 4 +- src/mcp/server/auth/provider.py | 4 +- src/mcp/server/experimental/task_context.py | 2 +- .../experimental/task_result_handler.py | 2 +- src/mcp/server/lowlevel/server.py | 6 +- src/mcp/server/mcpserver/context.py | 2 +- src/mcp/server/mcpserver/prompts/base.py | 6 +- src/mcp/server/mcpserver/prompts/manager.py | 2 +- .../mcpserver/resources/resource_manager.py | 2 +- src/mcp/server/mcpserver/server.py | 2 +- .../server/mcpserver/tools/tool_manager.py | 2 +- src/mcp/server/mcpserver/utilities/types.py | 4 +- src/mcp/server/streamable_http_manager.py | 4 +- src/mcp/server/transport_security.py | 2 +- src/mcp/shared/auth.py | 4 +- src/mcp/shared/exceptions.py | 6 +- src/mcp/shared/experimental/tasks/context.py | 2 +- tests/cli/test_claude.py | 22 +-- tests/cli/test_utils.py | 16 +- .../extensions/test_client_credentials.py | 42 ++--- tests/client/conftest.py | 26 ++- tests/client/test_auth.py | 138 +++++++-------- tests/client/test_client.py | 40 ++--- tests/client/test_list_methods_cursor.py | 8 +- tests/client/test_list_roots_callback.py | 4 +- tests/client/test_logging_callback.py | 4 +- tests/client/test_output_schema_validation.py | 10 +- tests/client/test_resource_cleanup.py | 6 +- tests/client/test_sampling_callback.py | 6 +- tests/client/test_scope_bug_1630.py | 2 +- tests/client/test_session.py | 40 ++--- tests/client/test_session_group.py | 22 ++- tests/client/test_stdio.py | 22 +-- tests/client/transports/test_memory.py | 12 +- tests/conftest.py | 2 +- .../tasks/client/test_capabilities.py | 16 +- tests/experimental/tasks/client/test_tasks.py | 8 +- .../tasks/server/test_integration.py | 8 +- tests/issues/test_100_tool_listing.py | 4 +- .../test_1027_win_unreachable_cleanup.py | 4 +- tests/issues/test_129_resource_templates.py | 2 +- tests/issues/test_1338_icons_and_metadata.py | 6 +- ...est_1363_race_condition_streamable_http.py | 12 +- tests/issues/test_141_resource_templates.py | 4 +- tests/issues/test_152_resource_mime_type.py | 4 +- .../test_1574_resource_uri_validation.py | 8 +- .../issues/test_1754_mime_type_parameters.py | 8 +- tests/issues/test_176_progress_token.py | 2 +- tests/issues/test_188_concurrency.py | 12 +- tests/issues/test_192_request_id.py | 2 +- tests/issues/test_342_base64_encoding.py | 2 +- tests/issues/test_355_type_error.py | 6 +- tests/issues/test_552_windows_hang.py | 2 +- tests/issues/test_88_random_error.py | 6 +- tests/issues/test_973_url_decoding.py | 8 +- tests/issues/test_malformed_input.py | 4 +- .../auth/middleware/test_auth_context.py | 6 +- .../auth/middleware/test_bearer_auth.py | 38 ++-- tests/server/auth/test_error_handling.py | 16 +- tests/server/auth/test_protected_resource.py | 27 ++- tests/server/auth/test_provider.py | 16 +- tests/server/auth/test_routes.py | 18 +- tests/server/lowlevel/test_helper_types.py | 6 +- .../mcpserver/auth/test_auth_integration.py | 111 ++++++------ tests/server/mcpserver/prompts/test_base.py | 22 +-- .../server/mcpserver/prompts/test_manager.py | 16 +- .../resources/test_file_resources.py | 17 +- .../resources/test_function_resources.py | 22 +-- .../resources/test_resource_manager.py | 23 ++- .../resources/test_resource_template.py | 30 ++-- .../mcpserver/resources/test_resources.py | 26 +-- .../mcpserver/servers/test_file_server.py | 10 +- tests/server/mcpserver/test_elicitation.py | 47 +++-- tests/server/mcpserver/test_func_metadata.py | 82 +++++---- tests/server/mcpserver/test_integration.py | 12 +- .../mcpserver/test_parameter_descriptions.py | 2 +- tests/server/mcpserver/test_server.py | 166 +++++++++--------- tests/server/mcpserver/test_title.py | 10 +- tests/server/mcpserver/test_tool_manager.py | 108 ++++++------ .../server/mcpserver/test_url_elicitation.py | 44 ++--- .../test_url_elicitation_error_throw.py | 6 +- tests/server/mcpserver/tools/test_base.py | 2 +- tests/server/test_cancel_handling.py | 12 +- tests/server/test_completion_with_context.py | 8 +- tests/server/test_lifespan.py | 8 +- .../test_lowlevel_exception_handling.py | 8 +- .../server/test_lowlevel_tool_annotations.py | 2 +- tests/server/test_read_resource.py | 4 +- tests/server/test_session.py | 28 +-- tests/server/test_session_race_condition.py | 6 +- tests/server/test_sse_security.py | 28 ++- tests/server/test_stateless_mode.py | 16 +- tests/server/test_stdio.py | 4 +- tests/server/test_streamable_http_manager.py | 57 +++--- tests/server/test_streamable_http_security.py | 24 ++- tests/shared/test_auth.py | 6 +- tests/shared/test_auth_utils.py | 30 ++-- tests/shared/test_httpx_utils.py | 4 +- tests/shared/test_progress_notifications.py | 6 +- tests/shared/test_session.py | 40 ++--- tests/shared/test_sse.py | 4 +- tests/shared/test_streamable_http.py | 102 +++++------ tests/test_examples.py | 12 +- tests/test_types.py | 28 +-- 115 files changed, 941 insertions(+), 1019 deletions(-) diff --git a/README.v2.md b/README.v2.md index c9c1f5b97..02c133b0d 100644 --- a/README.v2.md +++ b/README.v2.md @@ -518,7 +518,7 @@ class UserProfile: age: int email: str | None = None - def __init__(self, name: str, age: int, email: str | None = None) -> None: + def __init__(self, name: str, age: int, email: str | None = None): self.name = name self.age = age self.email = email @@ -532,7 +532,7 @@ def get_user(user_id: str) -> UserProfile: # Classes WITHOUT type hints cannot be used for structured output class UntypedConfig: - def __init__(self, setting1, setting2) -> None: # type: ignore[reportMissingParameterType] # noqa: ANN001 + def __init__(self, setting1, setting2): # type: ignore[reportMissingParameterType] # noqa: ANN001, ANN204 self.setting1 = setting1 self.setting2 = setting2 diff --git a/examples/servers/simple-auth/mcp_simple_auth/auth_server.py b/examples/servers/simple-auth/mcp_simple_auth/auth_server.py index 0125f3659..996cbfa44 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/auth_server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/auth_server.py @@ -47,7 +47,7 @@ class SimpleAuthProvider(SimpleOAuthProvider): 2. Stores token state for introspection by Resource Servers """ - def __init__(self, auth_settings: SimpleAuthSettings, auth_callback_path: str, server_url: str) -> None: + def __init__(self, auth_settings: SimpleAuthSettings, auth_callback_path: str, server_url: str): super().__init__(auth_settings, auth_callback_path, server_url) diff --git a/examples/servers/simple-auth/mcp_simple_auth/legacy_as_server.py b/examples/servers/simple-auth/mcp_simple_auth/legacy_as_server.py index 41aed08c0..ab7773b5b 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/legacy_as_server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/legacy_as_server.py @@ -39,7 +39,7 @@ class ServerSettings(BaseModel): class LegacySimpleOAuthProvider(SimpleOAuthProvider): """Simple OAuth provider for legacy MCP server.""" - def __init__(self, auth_settings: SimpleAuthSettings, auth_callback_path: str, server_url: str) -> None: + def __init__(self, auth_settings: SimpleAuthSettings, auth_callback_path: str, server_url: str): super().__init__(auth_settings, auth_callback_path, server_url) diff --git a/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py index cf4e3dc81..9bd6e3c4e 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py +++ b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py @@ -51,7 +51,7 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider[AuthorizationCode, Re 3. Maintaining token state for introspection """ - def __init__(self, settings: SimpleAuthSettings, auth_callback_url: str, server_url: str) -> None: + def __init__(self, settings: SimpleAuthSettings, auth_callback_url: str, server_url: str): self.settings = settings self.auth_callback_url = auth_callback_url self.server_url = server_url diff --git a/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py b/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py index 8f9407431..5228d034e 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py +++ b/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py @@ -25,7 +25,7 @@ def __init__( introspection_endpoint: str, server_url: str, validate_resource: bool = False, - ) -> None: + ): self.introspection_endpoint = introspection_endpoint self.server_url = server_url self.validate_resource = validate_resource diff --git a/examples/snippets/servers/structured_output.py b/examples/snippets/servers/structured_output.py index 69f4b203a..d7a2a4b51 100644 --- a/examples/snippets/servers/structured_output.py +++ b/examples/snippets/servers/structured_output.py @@ -57,7 +57,7 @@ class UserProfile: age: int email: str | None = None - def __init__(self, name: str, age: int, email: str | None = None) -> None: + def __init__(self, name: str, age: int, email: str | None = None): self.name = name self.age = age self.email = email @@ -71,7 +71,7 @@ def get_user(user_id: str) -> UserProfile: # Classes WITHOUT type hints cannot be used for structured output class UntypedConfig: - def __init__(self, setting1, setting2) -> None: # type: ignore[reportMissingParameterType] # noqa: ANN001 + def __init__(self, setting1, setting2): # type: ignore[reportMissingParameterType] # noqa: ANN001, ANN204 self.setting1 = setting1 self.setting2 = setting2 diff --git a/pyproject.toml b/pyproject.toml index b38312577..b58aa3c13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -149,6 +149,7 @@ ignore = [ [tool.ruff.lint.flake8-annotations] allow-star-arg-any = true +mypy-init-return = true [tool.ruff.lint.flake8-tidy-imports.banned-api] "pydantic.RootModel".msg = "Use `pydantic.TypeAdapter` instead." @@ -159,9 +160,8 @@ max-complexity = 24 # Default is 10 [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401"] -# ANN001: these files intentionally define untyped parameters to test schema inference -"tests/server/mcpserver/test_func_metadata.py" = ["ANN001", "E501"] -"tests/server/mcpserver/test_server.py" = ["ANN001"] +"tests/**" = ["ANN"] +"tests/server/mcpserver/test_func_metadata.py" = ["E501"] "tests/shared/test_progress_notifications.py" = ["PLW0603"] [tool.ruff.lint.pylint] diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 78eb0c325..25075dec3 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -232,7 +232,7 @@ def __init__( timeout: float = 300.0, client_metadata_url: str | None = None, validate_resource_url: Callable[[str, str | None], Awaitable[None]] | None = None, - ) -> None: + ): """Initialize OAuth2 authentication. Args: diff --git a/src/mcp/os/win32/utilities.py b/src/mcp/os/win32/utilities.py index e99092182..22dc600c0 100644 --- a/src/mcp/os/win32/utilities.py +++ b/src/mcp/os/win32/utilities.py @@ -71,7 +71,7 @@ class FallbackProcess: so that MCP clients expecting async streams can work properly. """ - def __init__(self, popen_obj: subprocess.Popen[bytes]) -> None: + def __init__(self, popen_obj: subprocess.Popen[bytes]): self.popen: subprocess.Popen[bytes] = popen_obj self.stdin_raw = popen_obj.stdin # type: ignore[assignment] self.stdout_raw = popen_obj.stdout # type: ignore[assignment] diff --git a/src/mcp/server/auth/middleware/auth_context.py b/src/mcp/server/auth/middleware/auth_context.py index 686c2480f..682b3e47a 100644 --- a/src/mcp/server/auth/middleware/auth_context.py +++ b/src/mcp/server/auth/middleware/auth_context.py @@ -29,7 +29,7 @@ class AuthContextMiddleware: being stored in the context. """ - def __init__(self, app: ASGIApp) -> None: + def __init__(self, app: ASGIApp): self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index ef0d083c9..bc8b5263e 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -13,7 +13,7 @@ class AuthenticatedUser(SimpleUser): """User with authentication info.""" - def __init__(self, auth_info: AccessToken) -> None: + def __init__(self, auth_info: AccessToken): super().__init__(auth_info.client_id) self.access_token = auth_info self.scopes = auth_info.scopes @@ -22,7 +22,7 @@ def __init__(self, auth_info: AccessToken) -> None: class BearerAuthBackend(AuthenticationBackend): """Authentication backend that validates Bearer tokens using a TokenVerifier.""" - def __init__(self, token_verifier: TokenVerifier) -> None: + def __init__(self, token_verifier: TokenVerifier): self.token_verifier = token_verifier async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, AuthenticatedUser] | None: @@ -59,7 +59,7 @@ def __init__( app: Any, required_scopes: list[str], resource_metadata_url: AnyHttpUrl | None = None, - ) -> None: + ): """Initialize the middleware. Args: diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index 2f8396e49..2832f8352 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -12,7 +12,7 @@ class AuthenticationError(Exception): - def __init__(self, message: str) -> None: + def __init__(self, message: str): self.message = message @@ -28,7 +28,7 @@ class ClientAuthenticator: logic is skipped. """ - def __init__(self, provider: OAuthAuthorizationServerProvider[Any, Any, Any]) -> None: + def __init__(self, provider: OAuthAuthorizationServerProvider[Any, Any, Any]): """Initialize the authenticator. Args: diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index 9ea811d68..957082a85 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -285,9 +285,7 @@ class ProviderTokenVerifier(TokenVerifier): the TokenVerifier protocol with a dedicated implementation like IntrospectionTokenVerifier. """ - def __init__( - self, provider: "OAuthAuthorizationServerProvider[AuthorizationCode, RefreshToken, AccessToken]" - ) -> None: + def __init__(self, provider: "OAuthAuthorizationServerProvider[AuthorizationCode, RefreshToken, AccessToken]"): self.provider = provider async def verify_token(self, token: str) -> AccessToken | None: diff --git a/src/mcp/server/experimental/task_context.py b/src/mcp/server/experimental/task_context.py index 74d2a3f53..1fc45badf 100644 --- a/src/mcp/server/experimental/task_context.py +++ b/src/mcp/server/experimental/task_context.py @@ -80,7 +80,7 @@ def __init__( session: ServerSession, queue: TaskMessageQueue, handler: TaskResultHandler | None = None, - ) -> None: + ): """Create a ServerTaskContext. Args: diff --git a/src/mcp/server/experimental/task_result_handler.py b/src/mcp/server/experimental/task_result_handler.py index c158d4d3d..b2268bc1c 100644 --- a/src/mcp/server/experimental/task_result_handler.py +++ b/src/mcp/server/experimental/task_result_handler.py @@ -58,7 +58,7 @@ def __init__( self, store: TaskStore, queue: TaskMessageQueue, - ) -> None: + ): self._store = store self._queue = queue # Map from internal request ID to resolver for routing responses diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 7fa745fd3..408f2536b 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -75,9 +75,7 @@ async def main(): class NotificationOptions: - def __init__( - self, prompts_changed: bool = False, resources_changed: bool = False, tools_changed: bool = False - ) -> None: + def __init__(self, prompts_changed: bool = False, resources_changed: bool = False, tools_changed: bool = False): self.prompts_changed = prompts_changed self.resources_changed = resources_changed self.tools_changed = tools_changed @@ -183,7 +181,7 @@ def __init__( Awaitable[None], ] | None = None, - ) -> None: + ): self.name = name self.version = version self.title = title diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index e2689d6b2..3ad8391ed 100644 --- a/src/mcp/server/mcpserver/context.py +++ b/src/mcp/server/mcpserver/context.py @@ -65,7 +65,7 @@ def __init__( mcp_server: MCPServer | None = None, # TODO(Marcelo): We should drop this kwargs parameter. **kwargs: Any, - ) -> None: + ): super().__init__(**kwargs) self._request_context = request_context self._mcp_server = mcp_server diff --git a/src/mcp/server/mcpserver/prompts/base.py b/src/mcp/server/mcpserver/prompts/base.py index 2f502250b..0c319d53c 100644 --- a/src/mcp/server/mcpserver/prompts/base.py +++ b/src/mcp/server/mcpserver/prompts/base.py @@ -24,7 +24,7 @@ class Message(BaseModel): role: Literal["user", "assistant"] content: ContentBlock - def __init__(self, content: str | ContentBlock, **kwargs: Any) -> None: + def __init__(self, content: str | ContentBlock, **kwargs: Any): if isinstance(content, str): content = TextContent(type="text", text=content) super().__init__(content=content, **kwargs) @@ -35,7 +35,7 @@ class UserMessage(Message): role: Literal["user", "assistant"] = "user" - def __init__(self, content: str | ContentBlock, **kwargs: Any) -> None: + def __init__(self, content: str | ContentBlock, **kwargs: Any): super().__init__(content=content, **kwargs) @@ -44,7 +44,7 @@ class AssistantMessage(Message): role: Literal["user", "assistant"] = "assistant" - def __init__(self, content: str | ContentBlock, **kwargs: Any) -> None: + def __init__(self, content: str | ContentBlock, **kwargs: Any): super().__init__(content=content, **kwargs) diff --git a/src/mcp/server/mcpserver/prompts/manager.py b/src/mcp/server/mcpserver/prompts/manager.py index 303a5bbb9..28a7a6e98 100644 --- a/src/mcp/server/mcpserver/prompts/manager.py +++ b/src/mcp/server/mcpserver/prompts/manager.py @@ -17,7 +17,7 @@ class PromptManager: """Manages MCPServer prompts.""" - def __init__(self, warn_on_duplicate_prompts: bool = True) -> None: + def __init__(self, warn_on_duplicate_prompts: bool = True): self._prompts: dict[str, Prompt] = {} self.warn_on_duplicate_prompts = warn_on_duplicate_prompts diff --git a/src/mcp/server/mcpserver/resources/resource_manager.py b/src/mcp/server/mcpserver/resources/resource_manager.py index 56b9e0e6d..6bf17376d 100644 --- a/src/mcp/server/mcpserver/resources/resource_manager.py +++ b/src/mcp/server/mcpserver/resources/resource_manager.py @@ -22,7 +22,7 @@ class ResourceManager: """Manages MCPServer resources.""" - def __init__(self, warn_on_duplicate_resources: bool = True) -> None: + def __init__(self, warn_on_duplicate_resources: bool = True): self._resources: dict[str, Resource] = {} self._templates: dict[str, ResourceTemplate] = {} self.warn_on_duplicate_resources = warn_on_duplicate_resources diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index b80035ade..afa3653ee 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -144,7 +144,7 @@ def __init__( warn_on_duplicate_prompts: bool = True, lifespan: Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None = None, auth: AuthSettings | None = None, - ) -> None: + ): self.settings = Settings( debug=debug, log_level=log_level, diff --git a/src/mcp/server/mcpserver/tools/tool_manager.py b/src/mcp/server/mcpserver/tools/tool_manager.py index dd1ac519f..32ed54797 100644 --- a/src/mcp/server/mcpserver/tools/tool_manager.py +++ b/src/mcp/server/mcpserver/tools/tool_manager.py @@ -23,7 +23,7 @@ def __init__( warn_on_duplicate_tools: bool = True, *, tools: list[Tool] | None = None, - ) -> None: + ): self._tools: dict[str, Tool] = {} if tools is not None: for tool in tools: diff --git a/src/mcp/server/mcpserver/utilities/types.py b/src/mcp/server/mcpserver/utilities/types.py index 9e05a663e..f092b245a 100644 --- a/src/mcp/server/mcpserver/utilities/types.py +++ b/src/mcp/server/mcpserver/utilities/types.py @@ -14,7 +14,7 @@ def __init__( path: str | Path | None = None, data: bytes | None = None, format: str | None = None, - ) -> None: + ): if path is None and data is None: # pragma: no cover raise ValueError("Either path or data must be provided") if path is not None and data is not None: # pragma: no cover @@ -62,7 +62,7 @@ def __init__( path: str | Path | None = None, data: bytes | None = None, format: str | None = None, - ) -> None: + ): if not bool(path) ^ bool(data): # pragma: no cover raise ValueError("Either path or data can be provided") diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 20fc5f74a..8e60863ed 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -72,7 +72,7 @@ def __init__( security_settings: TransportSecuritySettings | None = None, retry_interval: int | None = None, session_idle_timeout: float | None = None, - ) -> None: + ): if session_idle_timeout is not None and session_idle_timeout <= 0: raise ValueError("session_idle_timeout must be a positive number of seconds") if stateless and session_idle_timeout is not None: @@ -289,7 +289,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE class StreamableHTTPASGIApp: """ASGI application for Streamable HTTP server transport.""" - def __init__(self, session_manager: StreamableHTTPSessionManager) -> None: + def __init__(self, session_manager: StreamableHTTPSessionManager): self.session_manager = session_manager async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index 9acb41538..1ed9842c0 100644 --- a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -36,7 +36,7 @@ class TransportSecuritySettings(BaseModel): class TransportSecurityMiddleware: """Middleware to enforce DNS rebinding protection for MCP transport endpoints.""" - def __init__(self, settings: TransportSecuritySettings | None = None) -> None: + def __init__(self, settings: TransportSecuritySettings | None = None): # If not specified, disable DNS rebinding protection by default for backwards compatibility self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False) diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index c151ab104..ca5b7b45a 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -23,12 +23,12 @@ def normalize_token_type(cls, v: str | None) -> str | None: class InvalidScopeError(Exception): - def __init__(self, message: str) -> None: + def __init__(self, message: str): self.message = message class InvalidRedirectUriError(Exception): - def __init__(self, message: str) -> None: + def __init__(self, message: str): self.message = message diff --git a/src/mcp/shared/exceptions.py b/src/mcp/shared/exceptions.py index ba203220c..f153ea319 100644 --- a/src/mcp/shared/exceptions.py +++ b/src/mcp/shared/exceptions.py @@ -10,7 +10,7 @@ class MCPError(Exception): error: ErrorData - def __init__(self, code: int, message: str, data: Any = None) -> None: + def __init__(self, code: int, message: str, data: Any = None): super().__init__(code, message, data) if data is not None: self.error = ErrorData(code=code, message=message, data=data) @@ -49,7 +49,7 @@ class StatelessModeNotSupported(RuntimeError): for bidirectional communication. """ - def __init__(self, method: str) -> None: + def __init__(self, method: str): super().__init__( f"Cannot use {method} in stateless HTTP mode. " "Stateless mode does not support server-to-client requests. " @@ -76,7 +76,7 @@ class UrlElicitationRequiredError(MCPError): ``` """ - def __init__(self, elicitations: list[ElicitRequestURLParams], message: str | None = None) -> None: + def __init__(self, elicitations: list[ElicitRequestURLParams], message: str | None = None): """Initialize UrlElicitationRequiredError.""" if message is None: message = f"URL elicitation{'s' if len(elicitations) > 1 else ''} required" diff --git a/src/mcp/shared/experimental/tasks/context.py b/src/mcp/shared/experimental/tasks/context.py index 124591a0e..ed0d2b91b 100644 --- a/src/mcp/shared/experimental/tasks/context.py +++ b/src/mcp/shared/experimental/tasks/context.py @@ -31,7 +31,7 @@ async def worker_job(task_id: str): await ctx.complete(result) """ - def __init__(self, task: Task, store: TaskStore) -> None: + def __init__(self, task: Task, store: TaskStore): self._task = task self._store = store self._cancelled = False diff --git a/tests/cli/test_claude.py b/tests/cli/test_claude.py index c33489a74..73d4f0eb5 100644 --- a/tests/cli/test_claude.py +++ b/tests/cli/test_claude.py @@ -24,7 +24,7 @@ def _read_server(config_dir: Path, name: str) -> dict[str, Any]: return config["mcpServers"][name] -def test_generates_uv_run_command(config_dir: Path) -> None: +def test_generates_uv_run_command(config_dir: Path): """Should write a uv run command that invokes mcp run on the resolved file spec.""" assert update_claude_config(file_spec="server.py:app", server_name="my_server") @@ -35,14 +35,14 @@ def test_generates_uv_run_command(config_dir: Path) -> None: } -def test_file_spec_without_object_suffix(config_dir: Path) -> None: +def test_file_spec_without_object_suffix(config_dir: Path): """File specs without :object should still resolve to an absolute path.""" assert update_claude_config(file_spec="server.py", server_name="s") assert _read_server(config_dir, "s")["args"][-1] == str(Path("server.py").resolve()) -def test_with_packages_sorted_and_deduplicated(config_dir: Path) -> None: +def test_with_packages_sorted_and_deduplicated(config_dir: Path): """Extra packages should appear as --with flags, sorted and deduplicated with mcp[cli].""" assert update_claude_config(file_spec="s.py:app", server_name="s", with_packages=["zebra", "aardvark", "zebra"]) @@ -50,7 +50,7 @@ def test_with_packages_sorted_and_deduplicated(config_dir: Path) -> None: assert args[:8] == ["run", "--frozen", "--with", "aardvark", "--with", "mcp[cli]", "--with", "zebra"] -def test_with_editable_adds_flag(config_dir: Path, tmp_path: Path) -> None: +def test_with_editable_adds_flag(config_dir: Path, tmp_path: Path): """with_editable should add --with-editable after the --with flags.""" editable = tmp_path / "project" assert update_claude_config(file_spec="s.py:app", server_name="s", with_editable=editable) @@ -59,14 +59,14 @@ def test_with_editable_adds_flag(config_dir: Path, tmp_path: Path) -> None: assert args[4:6] == ["--with-editable", str(editable)] -def test_env_vars_written(config_dir: Path) -> None: +def test_env_vars_written(config_dir: Path): """env_vars should be written under the server's env key.""" assert update_claude_config(file_spec="s.py:app", server_name="s", env_vars={"KEY": "val"}) assert _read_server(config_dir, "s")["env"] == {"KEY": "val"} -def test_existing_env_vars_merged_new_wins(config_dir: Path) -> None: +def test_existing_env_vars_merged_new_wins(config_dir: Path): """Re-installing should merge env vars, with new values overriding existing ones.""" (config_dir / "claude_desktop_config.json").write_text( json.dumps({"mcpServers": {"s": {"env": {"OLD": "keep", "KEY": "old"}}}}) @@ -77,7 +77,7 @@ def test_existing_env_vars_merged_new_wins(config_dir: Path) -> None: assert _read_server(config_dir, "s")["env"] == {"OLD": "keep", "KEY": "new"} -def test_existing_env_vars_preserved_without_new(config_dir: Path) -> None: +def test_existing_env_vars_preserved_without_new(config_dir: Path): """Re-installing without env_vars should keep the existing env block intact.""" (config_dir / "claude_desktop_config.json").write_text(json.dumps({"mcpServers": {"s": {"env": {"KEEP": "me"}}}})) @@ -86,7 +86,7 @@ def test_existing_env_vars_preserved_without_new(config_dir: Path) -> None: assert _read_server(config_dir, "s")["env"] == {"KEEP": "me"} -def test_other_servers_preserved(config_dir: Path) -> None: +def test_other_servers_preserved(config_dir: Path): """Installing a new server should not clobber existing mcpServers entries.""" (config_dir / "claude_desktop_config.json").write_text(json.dumps({"mcpServers": {"other": {"command": "x"}}})) @@ -97,7 +97,7 @@ def test_other_servers_preserved(config_dir: Path) -> None: assert config["mcpServers"]["other"] == {"command": "x"} -def test_raises_when_config_dir_missing(monkeypatch: pytest.MonkeyPatch) -> None: +def test_raises_when_config_dir_missing(monkeypatch: pytest.MonkeyPatch): """Should raise RuntimeError when Claude Desktop config dir can't be found.""" monkeypatch.setattr("mcp.cli.claude.get_claude_config_path", lambda: None) monkeypatch.setattr("mcp.cli.claude.get_uv_path", lambda: "/fake/bin/uv") @@ -107,7 +107,7 @@ def test_raises_when_config_dir_missing(monkeypatch: pytest.MonkeyPatch) -> None @pytest.mark.parametrize("which_result, expected", [("/usr/local/bin/uv", "/usr/local/bin/uv"), (None, "uv")]) -def test_get_uv_path(monkeypatch: pytest.MonkeyPatch, which_result: str | None, expected: str) -> None: +def test_get_uv_path(monkeypatch: pytest.MonkeyPatch, which_result: str | None, expected: str): """Should return shutil.which's result, or fall back to bare 'uv' when not on PATH.""" def fake_which(cmd: str) -> str | None: @@ -126,7 +126,7 @@ def fake_which(cmd: str) -> str | None: ) def test_windows_drive_letter_not_split( config_dir: Path, monkeypatch: pytest.MonkeyPatch, file_spec: str, expected_last_arg: str -) -> None: +): """Drive-letter paths like 'C:\\server.py' must not be split on the drive colon. Before the fix, a bare 'C:\\path\\server.py' would hit rsplit(":", 1) and yield diff --git a/tests/cli/test_utils.py b/tests/cli/test_utils.py index 50b2646c5..44f4ab4d3 100644 --- a/tests/cli/test_utils.py +++ b/tests/cli/test_utils.py @@ -15,7 +15,7 @@ ("foo.py:srv_obj", "srv_obj"), ], ) -def test_parse_file_path_accepts_valid_specs(tmp_path: Path, spec: str, expected_obj: str | None) -> None: +def test_parse_file_path_accepts_valid_specs(tmp_path: Path, spec: str, expected_obj: str | None): """Should accept valid file specs.""" file = tmp_path / spec.split(":")[0] file.write_text("x = 1") @@ -24,13 +24,13 @@ def test_parse_file_path_accepts_valid_specs(tmp_path: Path, spec: str, expected assert obj == expected_obj -def test_parse_file_path_missing(tmp_path: Path) -> None: +def test_parse_file_path_missing(tmp_path: Path): """Should system exit if a file is missing.""" with pytest.raises(SystemExit): _parse_file_path(str(tmp_path / "missing.py")) -def test_parse_file_exit_on_dir(tmp_path: Path) -> None: +def test_parse_file_exit_on_dir(tmp_path: Path): """Should system exit if a directory is passed""" dir_path = tmp_path / "dir" dir_path.mkdir() @@ -38,13 +38,13 @@ def test_parse_file_exit_on_dir(tmp_path: Path) -> None: _parse_file_path(str(dir_path)) -def test_build_uv_command_minimal() -> None: +def test_build_uv_command_minimal(): """Should emit core command when no extras specified.""" cmd = _build_uv_command("foo.py") assert cmd == ["uv", "run", "--with", "mcp", "mcp", "run", "foo.py"] -def test_build_uv_command_adds_editable_and_packages() -> None: +def test_build_uv_command_adds_editable_and_packages(): """Should include --with-editable and every --with pkg in correct order.""" test_path = Path("/pkg") cmd = _build_uv_command( @@ -69,13 +69,13 @@ def test_build_uv_command_adds_editable_and_packages() -> None: ] -def test_get_npx_unix_like(monkeypatch: pytest.MonkeyPatch) -> None: +def test_get_npx_unix_like(monkeypatch: pytest.MonkeyPatch): """Should return "npx" on unix-like systems.""" monkeypatch.setattr(sys, "platform", "linux") assert _get_npx_command() == "npx" -def test_get_npx_windows(monkeypatch: pytest.MonkeyPatch) -> None: +def test_get_npx_windows(monkeypatch: pytest.MonkeyPatch): """Should return one of the npx candidates on Windows.""" candidates = ["npx.cmd", "npx.exe", "npx"] @@ -90,7 +90,7 @@ def fake_run(cmd: list[str], **kw: Any) -> subprocess.CompletedProcess[bytes]: assert _get_npx_command() in candidates -def test_get_npx_returns_none_when_npx_missing(monkeypatch: pytest.MonkeyPatch) -> None: +def test_get_npx_returns_none_when_npx_missing(monkeypatch: pytest.MonkeyPatch): """Should give None if every candidate fails.""" monkeypatch.setattr(sys, "platform", "win32", raising=False) diff --git a/tests/client/auth/extensions/test_client_credentials.py b/tests/client/auth/extensions/test_client_credentials.py index b75b08453..09760f453 100644 --- a/tests/client/auth/extensions/test_client_credentials.py +++ b/tests/client/auth/extensions/test_client_credentials.py @@ -19,7 +19,7 @@ class MockTokenStorage: """Mock token storage for testing.""" - def __init__(self) -> None: + def __init__(self): self._tokens: OAuthToken | None = None self._client_info: OAuthClientInformationFull | None = None @@ -37,12 +37,12 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None @pytest.fixture -def mock_storage() -> MockTokenStorage: +def mock_storage(): return MockTokenStorage() @pytest.fixture -def client_metadata() -> OAuthClientMetadata: +def client_metadata(): return OAuthClientMetadata( client_name="Test Client", client_uri=AnyHttpUrl("https://example.com"), @@ -52,9 +52,7 @@ def client_metadata() -> OAuthClientMetadata: @pytest.fixture -def rfc7523_oauth_provider( - client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage -) -> RFC7523OAuthClientProvider: +def rfc7523_oauth_provider(client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage): async def redirect_handler(url: str) -> None: # pragma: no cover """Mock redirect handler.""" pass @@ -78,9 +76,7 @@ class TestOAuthFlowClientCredentials: """Test OAuth flow behavior for client credentials flows.""" @pytest.mark.anyio - async def test_token_exchange_request_jwt_predefined( - self, rfc7523_oauth_provider: RFC7523OAuthClientProvider - ) -> None: + async def test_token_exchange_request_jwt_predefined(self, rfc7523_oauth_provider: RFC7523OAuthClientProvider): """Test token exchange request building with a predefined JWT assertion.""" # Set up required context rfc7523_oauth_provider.context.client_info = OAuthClientInformationFull( @@ -119,7 +115,7 @@ async def test_token_exchange_request_jwt_predefined( ) @pytest.mark.anyio - async def test_token_exchange_request_jwt(self, rfc7523_oauth_provider: RFC7523OAuthClientProvider) -> None: + async def test_token_exchange_request_jwt(self, rfc7523_oauth_provider: RFC7523OAuthClientProvider): """Test token exchange request building wiith a generated JWT assertion.""" # Set up required context rfc7523_oauth_provider.context.client_info = OAuthClientInformationFull( @@ -181,7 +177,7 @@ class TestClientCredentialsOAuthProvider: """Test ClientCredentialsOAuthProvider.""" @pytest.mark.anyio - async def test_init_sets_client_info(self, mock_storage: MockTokenStorage) -> None: + async def test_init_sets_client_info(self, mock_storage: MockTokenStorage): """Test that _initialize sets client_info.""" provider = ClientCredentialsOAuthProvider( server_url="https://api.example.com", @@ -200,7 +196,7 @@ async def test_init_sets_client_info(self, mock_storage: MockTokenStorage) -> No assert provider.context.client_info.token_endpoint_auth_method == "client_secret_basic" @pytest.mark.anyio - async def test_init_with_scopes(self, mock_storage: MockTokenStorage) -> None: + async def test_init_with_scopes(self, mock_storage: MockTokenStorage): """Test that constructor accepts scopes.""" provider = ClientCredentialsOAuthProvider( server_url="https://api.example.com", @@ -215,7 +211,7 @@ async def test_init_with_scopes(self, mock_storage: MockTokenStorage) -> None: assert provider.context.client_info.scope == "read write" @pytest.mark.anyio - async def test_init_with_client_secret_post(self, mock_storage: MockTokenStorage) -> None: + async def test_init_with_client_secret_post(self, mock_storage: MockTokenStorage): """Test that constructor accepts client_secret_post auth method.""" provider = ClientCredentialsOAuthProvider( server_url="https://api.example.com", @@ -230,7 +226,7 @@ async def test_init_with_client_secret_post(self, mock_storage: MockTokenStorage assert provider.context.client_info.token_endpoint_auth_method == "client_secret_post" @pytest.mark.anyio - async def test_exchange_token_client_credentials(self, mock_storage: MockTokenStorage) -> None: + async def test_exchange_token_client_credentials(self, mock_storage: MockTokenStorage): """Test token exchange request building.""" provider = ClientCredentialsOAuthProvider( server_url="https://api.example.com/v1/mcp", @@ -257,7 +253,7 @@ async def test_exchange_token_client_credentials(self, mock_storage: MockTokenSt assert "resource=https://api.example.com/v1/mcp" in content @pytest.mark.anyio - async def test_exchange_token_client_secret_post_includes_client_id(self, mock_storage: MockTokenStorage) -> None: + async def test_exchange_token_client_secret_post_includes_client_id(self, mock_storage: MockTokenStorage): """Test that client_secret_post includes both client_id and client_secret in body (RFC 6749 §2.3.1).""" provider = ClientCredentialsOAuthProvider( server_url="https://api.example.com/v1/mcp", @@ -285,7 +281,7 @@ async def test_exchange_token_client_secret_post_includes_client_id(self, mock_s assert "Authorization" not in request.headers @pytest.mark.anyio - async def test_exchange_token_client_secret_post_without_client_id(self, mock_storage: MockTokenStorage) -> None: + async def test_exchange_token_client_secret_post_without_client_id(self, mock_storage: MockTokenStorage): """Test client_secret_post skips body credentials when client_id is None.""" provider = ClientCredentialsOAuthProvider( server_url="https://api.example.com/v1/mcp", @@ -323,7 +319,7 @@ async def test_exchange_token_client_secret_post_without_client_id(self, mock_st assert "Authorization" not in request.headers @pytest.mark.anyio - async def test_exchange_token_without_scopes(self, mock_storage: MockTokenStorage) -> None: + async def test_exchange_token_without_scopes(self, mock_storage: MockTokenStorage): """Test token exchange without scopes.""" provider = ClientCredentialsOAuthProvider( server_url="https://api.example.com/v1/mcp", @@ -350,7 +346,7 @@ class TestPrivateKeyJWTOAuthProvider: """Test PrivateKeyJWTOAuthProvider.""" @pytest.mark.anyio - async def test_init_sets_client_info(self, mock_storage: MockTokenStorage) -> None: + async def test_init_sets_client_info(self, mock_storage: MockTokenStorage): """Test that _initialize sets client_info.""" async def mock_assertion_provider(audience: str) -> str: # pragma: no cover @@ -372,7 +368,7 @@ async def mock_assertion_provider(audience: str) -> str: # pragma: no cover assert provider.context.client_info.token_endpoint_auth_method == "private_key_jwt" @pytest.mark.anyio - async def test_exchange_token_client_credentials(self, mock_storage: MockTokenStorage) -> None: + async def test_exchange_token_client_credentials(self, mock_storage: MockTokenStorage): """Test token exchange request building with assertion provider.""" async def mock_assertion_provider(audience: str) -> str: @@ -404,7 +400,7 @@ async def mock_assertion_provider(audience: str) -> str: assert "scope=read write" in content @pytest.mark.anyio - async def test_exchange_token_without_scopes(self, mock_storage: MockTokenStorage) -> None: + async def test_exchange_token_without_scopes(self, mock_storage: MockTokenStorage): """Test token exchange without scopes.""" async def mock_assertion_provider(audience: str) -> str: @@ -435,7 +431,7 @@ class TestSignedJWTParameters: """Test SignedJWTParameters.""" @pytest.mark.anyio - async def test_create_assertion_provider(self) -> None: + async def test_create_assertion_provider(self): """Test that create_assertion_provider creates valid JWTs.""" params = SignedJWTParameters( issuer="test-issuer", @@ -462,7 +458,7 @@ async def test_create_assertion_provider(self) -> None: assert "jti" in claims @pytest.mark.anyio - async def test_create_assertion_provider_with_additional_claims(self) -> None: + async def test_create_assertion_provider_with_additional_claims(self): """Test that additional_claims are included in the JWT.""" params = SignedJWTParameters( issuer="test-issuer", @@ -488,7 +484,7 @@ class TestStaticAssertionProvider: """Test static_assertion_provider helper.""" @pytest.mark.anyio - async def test_returns_static_token(self) -> None: + async def test_returns_static_token(self): """Test that static_assertion_provider returns the same token regardless of audience.""" token = "my-static-jwt-token" provider = static_assertion_provider(token) diff --git a/tests/client/conftest.py b/tests/client/conftest.py index a7f580b86..2e39f1363 100644 --- a/tests/client/conftest.py +++ b/tests/client/conftest.py @@ -1,10 +1,10 @@ -from collections.abc import AsyncGenerator, Callable, Generator +from collections.abc import Callable, Generator from contextlib import asynccontextmanager from typing import Any from unittest.mock import patch import pytest -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from anyio.streams.memory import MemoryObjectSendStream import mcp.shared.memory from mcp.shared.message import SessionMessage @@ -12,26 +12,26 @@ class SpyMemoryObjectSendStream: - def __init__(self, original_stream: MemoryObjectSendStream[SessionMessage]) -> None: + def __init__(self, original_stream: MemoryObjectSendStream[SessionMessage]): self.original_stream = original_stream self.sent_messages: list[SessionMessage] = [] - async def send(self, message: SessionMessage) -> None: + async def send(self, message: SessionMessage): self.sent_messages.append(message) await self.original_stream.send(message) - async def aclose(self) -> None: + async def aclose(self): await self.original_stream.aclose() - async def __aenter__(self) -> "SpyMemoryObjectSendStream": + async def __aenter__(self): return self - async def __aexit__(self, *args: Any) -> None: + async def __aexit__(self, *args: Any): await self.aclose() class StreamSpyCollection: - def __init__(self, client_spy: SpyMemoryObjectSendStream, server_spy: SpyMemoryObjectSendStream) -> None: + def __init__(self, client_spy: SpyMemoryObjectSendStream, server_spy: SpyMemoryObjectSendStream): self.client = client_spy self.server = server_spy @@ -99,7 +99,7 @@ async def test_something(stream_spy): server_spy = None # Store references to our spy objects - def capture_spies(c_spy: SpyMemoryObjectSendStream, s_spy: SpyMemoryObjectSendStream) -> None: + def capture_spies(c_spy: SpyMemoryObjectSendStream, s_spy: SpyMemoryObjectSendStream): nonlocal client_spy, server_spy client_spy = c_spy server_spy = s_spy @@ -108,13 +108,7 @@ def capture_spies(c_spy: SpyMemoryObjectSendStream, s_spy: SpyMemoryObjectSendSt original_create_streams = mcp.shared.memory.create_client_server_memory_streams @asynccontextmanager - async def patched_create_streams() -> AsyncGenerator[ - tuple[ - tuple[MemoryObjectReceiveStream[SessionMessage | Exception], SpyMemoryObjectSendStream], - tuple[MemoryObjectReceiveStream[SessionMessage | Exception], SpyMemoryObjectSendStream], - ], - None, - ]: + async def patched_create_streams(): async with original_create_streams() as (client_streams, server_streams): client_read, client_write = client_streams server_read, server_write = server_streams diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 964a5b333..5aa985e36 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -40,7 +40,7 @@ class MockTokenStorage: """Mock token storage for testing.""" - def __init__(self) -> None: + def __init__(self): self._tokens: OAuthToken | None = None self._client_info: OAuthClientInformationFull | None = None @@ -58,12 +58,12 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None @pytest.fixture -def mock_storage() -> MockTokenStorage: +def mock_storage(): return MockTokenStorage() @pytest.fixture -def client_metadata() -> OAuthClientMetadata: +def client_metadata(): return OAuthClientMetadata( client_name="Test Client", client_uri=AnyHttpUrl("https://example.com"), @@ -73,7 +73,7 @@ def client_metadata() -> OAuthClientMetadata: @pytest.fixture -def valid_tokens() -> OAuthToken: +def valid_tokens(): return OAuthToken( access_token="test_access_token", token_type="Bearer", @@ -84,7 +84,7 @@ def valid_tokens() -> OAuthToken: @pytest.fixture -def oauth_provider(client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage) -> OAuthClientProvider: +def oauth_provider(client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage): async def redirect_handler(url: str) -> None: """Mock redirect handler.""" pass # pragma: no cover @@ -103,7 +103,7 @@ async def callback_handler() -> tuple[str, str | None]: @pytest.fixture -def prm_metadata_response() -> httpx.Response: +def prm_metadata_response(): """PRM metadata response with scopes.""" return httpx.Response( 200, @@ -116,7 +116,7 @@ def prm_metadata_response() -> httpx.Response: @pytest.fixture -def prm_metadata_without_scopes_response() -> httpx.Response: +def prm_metadata_without_scopes_response(): """PRM metadata response without scopes.""" return httpx.Response( 200, @@ -129,7 +129,7 @@ def prm_metadata_without_scopes_response() -> httpx.Response: @pytest.fixture -def init_response_with_www_auth_scope() -> httpx.Response: +def init_response_with_www_auth_scope(): """Initial 401 response with WWW-Authenticate header containing scope.""" return httpx.Response( 401, @@ -139,7 +139,7 @@ def init_response_with_www_auth_scope() -> httpx.Response: @pytest.fixture -def init_response_without_www_auth_scope() -> httpx.Response: +def init_response_without_www_auth_scope(): """Initial 401 response without WWW-Authenticate scope.""" return httpx.Response( 401, @@ -151,7 +151,7 @@ def init_response_without_www_auth_scope() -> httpx.Response: class TestPKCEParameters: """Test PKCE parameter generation.""" - def test_pkce_generation(self) -> None: + def test_pkce_generation(self): """Test PKCE parameter generation creates valid values.""" pkce = PKCEParameters.generate() @@ -166,7 +166,7 @@ def test_pkce_generation(self) -> None: # Verify base64url encoding in challenge (no padding) assert "=" not in pkce.code_challenge - def test_pkce_uniqueness(self) -> None: + def test_pkce_uniqueness(self): """Test PKCE generates unique values each time.""" pkce1 = PKCEParameters.generate() pkce2 = PKCEParameters.generate() @@ -181,7 +181,7 @@ class TestOAuthContext: @pytest.mark.anyio async def test_oauth_provider_initialization( self, oauth_provider: OAuthClientProvider, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage - ) -> None: + ): """Test OAuthClientProvider basic setup.""" assert oauth_provider.context.server_url == "https://api.example.com/v1/mcp" assert oauth_provider.context.client_metadata == client_metadata @@ -189,7 +189,7 @@ async def test_oauth_provider_initialization( assert oauth_provider.context.timeout == 300.0 assert oauth_provider.context is not None - def test_context_url_parsing(self, oauth_provider: OAuthClientProvider) -> None: + def test_context_url_parsing(self, oauth_provider: OAuthClientProvider): """Test get_authorization_base_url() extracts base URLs correctly.""" context = oauth_provider.context @@ -211,7 +211,7 @@ def test_context_url_parsing(self, oauth_provider: OAuthClientProvider) -> None: ) @pytest.mark.anyio - async def test_token_validity_checking(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken) -> None: + async def test_token_validity_checking(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken): """Test is_token_valid() and can_refresh_token() logic.""" context = oauth_provider.context @@ -246,7 +246,7 @@ async def test_token_validity_checking(self, oauth_provider: OAuthClientProvider context.client_info = None assert not context.can_refresh_token() - def test_clear_tokens(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken) -> None: + def test_clear_tokens(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken): """Test clear_tokens() removes token data.""" context = oauth_provider.context context.current_tokens = valid_tokens @@ -266,7 +266,7 @@ class TestOAuthFlow: @pytest.mark.anyio async def test_build_protected_resource_discovery_urls( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage - ) -> None: + ): """Test protected resource metadata discovery URL building with fallback.""" async def redirect_handler(url: str) -> None: @@ -307,7 +307,7 @@ async def callback_handler() -> tuple[str, str | None]: assert urls[1] == "https://api.example.com/.well-known/oauth-protected-resource" @pytest.mark.anyio - def test_create_oauth_metadata_request(self, oauth_provider: OAuthClientProvider) -> None: + def test_create_oauth_metadata_request(self, oauth_provider: OAuthClientProvider): """Test OAuth metadata discovery request building.""" request = create_oauth_metadata_request("https://example.com") @@ -321,7 +321,7 @@ class TestOAuthFallback: """Test OAuth discovery fallback behavior for legacy (act as AS not RS) servers.""" @pytest.mark.anyio - async def test_oauth_discovery_legacy_fallback_when_no_prm(self) -> None: + async def test_oauth_discovery_legacy_fallback_when_no_prm(self): """Test that when PRM discovery fails, only root OAuth URL is tried (March 2025 spec).""" # When auth_server_url is None (PRM failed), we use server_url and only try root discovery_urls = build_oauth_authorization_server_metadata_discovery_urls(None, "https://mcp.linear.app/sse") @@ -332,7 +332,7 @@ async def test_oauth_discovery_legacy_fallback_when_no_prm(self) -> None: ] @pytest.mark.anyio - async def test_oauth_discovery_path_aware_when_auth_server_has_path(self) -> None: + async def test_oauth_discovery_path_aware_when_auth_server_has_path(self): """Test that when auth server URL has a path, only path-based URLs are tried.""" discovery_urls = build_oauth_authorization_server_metadata_discovery_urls( "https://auth.example.com/tenant1", "https://api.example.com/mcp" @@ -346,7 +346,7 @@ async def test_oauth_discovery_path_aware_when_auth_server_has_path(self) -> Non ] @pytest.mark.anyio - async def test_oauth_discovery_root_when_auth_server_has_no_path(self) -> None: + async def test_oauth_discovery_root_when_auth_server_has_no_path(self): """Test that when auth server URL has no path, only root URLs are tried.""" discovery_urls = build_oauth_authorization_server_metadata_discovery_urls( "https://auth.example.com", "https://api.example.com/mcp" @@ -359,7 +359,7 @@ async def test_oauth_discovery_root_when_auth_server_has_no_path(self) -> None: ] @pytest.mark.anyio - async def test_oauth_discovery_root_when_auth_server_has_only_slash(self) -> None: + async def test_oauth_discovery_root_when_auth_server_has_only_slash(self): """Test that when auth server URL has only trailing slash, treated as root.""" discovery_urls = build_oauth_authorization_server_metadata_discovery_urls( "https://auth.example.com/", "https://api.example.com/mcp" @@ -372,7 +372,7 @@ async def test_oauth_discovery_root_when_auth_server_has_only_slash(self) -> Non ] @pytest.mark.anyio - async def test_oauth_discovery_fallback_order(self, oauth_provider: OAuthClientProvider) -> None: + async def test_oauth_discovery_fallback_order(self, oauth_provider: OAuthClientProvider): """Test fallback URL construction order when auth server URL has a path.""" # Simulate PRM discovery returning an auth server URL with a path oauth_provider.context.auth_server_url = oauth_provider.context.server_url @@ -388,7 +388,7 @@ async def test_oauth_discovery_fallback_order(self, oauth_provider: OAuthClientP ] @pytest.mark.anyio - async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthClientProvider) -> None: + async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthClientProvider): """Test the conditions during which an AS metadata discovery fallback will be attempted.""" # Ensure no tokens are stored oauth_provider.context.current_tokens = None @@ -507,7 +507,7 @@ async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthCl pass # Expected - generator should complete @pytest.mark.anyio - async def test_handle_metadata_response_success(self, oauth_provider: OAuthClientProvider) -> None: + async def test_handle_metadata_response_success(self, oauth_provider: OAuthClientProvider): """Test successful metadata response handling.""" # Create minimal valid OAuth metadata content = b"""{ @@ -528,7 +528,7 @@ async def test_prioritize_www_auth_scope_over_prm( oauth_provider: OAuthClientProvider, prm_metadata_response: httpx.Response, init_response_with_www_auth_scope: httpx.Response, - ) -> None: + ): """Test that WWW-Authenticate scope is prioritized over PRM scopes.""" # First, process PRM metadata to set protected_resource_metadata with scopes await oauth_provider._handle_protected_resource_response(prm_metadata_response) @@ -548,7 +548,7 @@ async def test_prioritize_prm_scopes_when_no_www_auth_scope( oauth_provider: OAuthClientProvider, prm_metadata_response: httpx.Response, init_response_without_www_auth_scope: httpx.Response, - ) -> None: + ): """Test that PRM scopes are prioritized when WWW-Authenticate header has no scopes.""" # Process the PRM metadata to set protected_resource_metadata with scopes await oauth_provider._handle_protected_resource_response(prm_metadata_response) @@ -568,7 +568,7 @@ async def test_omit_scope_when_no_prm_scopes_or_www_auth( oauth_provider: OAuthClientProvider, prm_metadata_without_scopes_response: httpx.Response, init_response_without_www_auth_scope: httpx.Response, - ) -> None: + ): """Test that scope is omitted when PRM has no scopes and WWW-Authenticate doesn't specify scope.""" # Process the PRM metadata without scopes await oauth_provider._handle_protected_resource_response(prm_metadata_without_scopes_response) @@ -582,7 +582,7 @@ async def test_omit_scope_when_no_prm_scopes_or_www_auth( assert scopes is None @pytest.mark.anyio - async def test_token_exchange_request_authorization_code(self, oauth_provider: OAuthClientProvider) -> None: + async def test_token_exchange_request_authorization_code(self, oauth_provider: OAuthClientProvider): """Test token exchange request building.""" # Set up required context oauth_provider.context.client_info = OAuthClientInformationFull( @@ -607,7 +607,7 @@ async def test_token_exchange_request_authorization_code(self, oauth_provider: O assert "client_secret=test_secret" in content @pytest.mark.anyio - async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken) -> None: + async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken): """Test refresh token request building.""" # Set up required context oauth_provider.context.current_tokens = valid_tokens @@ -632,7 +632,7 @@ async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider, assert "client_secret=test_secret" in content @pytest.mark.anyio - async def test_basic_auth_token_exchange(self, oauth_provider: OAuthClientProvider) -> None: + async def test_basic_auth_token_exchange(self, oauth_provider: OAuthClientProvider): """Test token exchange with client_secret_basic authentication.""" # Set up OAuth metadata to support basic auth oauth_provider.context.oauth_metadata = OAuthMetadata( @@ -677,9 +677,7 @@ async def test_basic_auth_token_exchange(self, oauth_provider: OAuthClientProvid assert "client_id=test%40client" in content # client_id still in body @pytest.mark.anyio - async def test_basic_auth_refresh_token( - self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken - ) -> None: + async def test_basic_auth_refresh_token(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken): """Test token refresh with client_secret_basic authentication.""" oauth_provider.context.current_tokens = valid_tokens @@ -714,7 +712,7 @@ async def test_basic_auth_refresh_token( assert "client_secret=" not in content @pytest.mark.anyio - async def test_none_auth_method(self, oauth_provider: OAuthClientProvider) -> None: + async def test_none_auth_method(self, oauth_provider: OAuthClientProvider): """Test 'none' authentication method (public client).""" oauth_provider.context.oauth_metadata = OAuthMetadata( issuer=AnyHttpUrl("https://auth.example.com"), @@ -746,9 +744,7 @@ class TestProtectedResourceMetadata: """Test protected resource handling.""" @pytest.mark.anyio - async def test_resource_param_included_with_recent_protocol_version( - self, oauth_provider: OAuthClientProvider - ) -> None: + async def test_resource_param_included_with_recent_protocol_version(self, oauth_provider: OAuthClientProvider): """Test resource parameter is included for protocol version >= 2025-06-18.""" # Set protocol version to 2025-06-18 oauth_provider.context.protocol_version = "2025-06-18" @@ -777,7 +773,7 @@ async def test_resource_param_included_with_recent_protocol_version( assert "resource=" in refresh_content @pytest.mark.anyio - async def test_resource_param_excluded_with_old_protocol_version(self, oauth_provider: OAuthClientProvider) -> None: + async def test_resource_param_excluded_with_old_protocol_version(self, oauth_provider: OAuthClientProvider): """Test resource parameter is excluded for protocol version < 2025-06-18.""" # Set protocol version to older version oauth_provider.context.protocol_version = "2025-03-26" @@ -803,9 +799,7 @@ async def test_resource_param_excluded_with_old_protocol_version(self, oauth_pro assert "resource=" not in refresh_content @pytest.mark.anyio - async def test_resource_param_included_with_protected_resource_metadata( - self, oauth_provider: OAuthClientProvider - ) -> None: + async def test_resource_param_included_with_protected_resource_metadata(self, oauth_provider: OAuthClientProvider): """Test resource parameter is always included when protected resource metadata exists.""" # Set old protocol version but with protected resource metadata oauth_provider.context.protocol_version = "2025-03-26" @@ -959,22 +953,22 @@ class TestRegistrationResponse: """Test client registration response handling.""" @pytest.mark.anyio - async def test_handle_registration_response_reads_before_accessing_text(self) -> None: + async def test_handle_registration_response_reads_before_accessing_text(self): """Test that response.aread() is called before accessing response.text.""" # Track if aread() was called class MockResponse(httpx.Response): - def __init__(self) -> None: + def __init__(self): self.status_code = 400 self._aread_called = False self._text = "Registration failed with error" - async def aread(self) -> bytes: + async def aread(self): self._aread_called = True return b"test content" @property - def text(self) -> str: + def text(self): if not self._aread_called: raise RuntimeError("Response.text accessed before response.aread()") # pragma: no cover return self._text @@ -994,7 +988,7 @@ def text(self) -> str: class TestCreateClientRegistrationRequest: """Test client registration request creation.""" - def test_uses_registration_endpoint_from_metadata(self) -> None: + def test_uses_registration_endpoint_from_metadata(self): """Test that registration URL comes from metadata when available.""" oauth_metadata = OAuthMetadata( issuer=AnyHttpUrl("https://auth.example.com"), @@ -1009,7 +1003,7 @@ def test_uses_registration_endpoint_from_metadata(self) -> None: assert str(request.url) == "https://auth.example.com/register" assert request.method == "POST" - def test_falls_back_to_default_register_endpoint_when_no_metadata(self) -> None: + def test_falls_back_to_default_register_endpoint_when_no_metadata(self): """Test that registration uses fallback URL when auth_server_metadata is None.""" client_metadata = OAuthClientMetadata(redirect_uris=[AnyHttpUrl("http://localhost:3000/callback")]) @@ -1018,7 +1012,7 @@ def test_falls_back_to_default_register_endpoint_when_no_metadata(self) -> None: assert str(request.url) == "https://auth.example.com/register" assert request.method == "POST" - def test_falls_back_when_metadata_has_no_registration_endpoint(self) -> None: + def test_falls_back_when_metadata_has_no_registration_endpoint(self): """Test fallback when metadata exists but lacks registration_endpoint.""" oauth_metadata = OAuthMetadata( issuer=AnyHttpUrl("https://auth.example.com"), @@ -1040,7 +1034,7 @@ class TestAuthFlow: @pytest.mark.anyio async def test_auth_flow_with_valid_tokens( self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken - ) -> None: + ): """Test auth flow when tokens are already valid.""" # Pre-store valid tokens await mock_storage.set_tokens(valid_tokens) @@ -1066,9 +1060,7 @@ async def test_auth_flow_with_valid_tokens( pass # Expected @pytest.mark.anyio - async def test_auth_flow_with_no_tokens( - self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage - ) -> None: + async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage): """Test auth flow when no tokens are available, triggering the full OAuth flow.""" # Ensure no tokens are stored oauth_provider.context.current_tokens = None @@ -1178,7 +1170,7 @@ async def test_auth_flow_with_no_tokens( @pytest.mark.anyio async def test_auth_flow_no_unnecessary_retry_after_oauth( self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken - ) -> None: + ): """Test that requests are not retried unnecessarily - the core bug that caused 2x performance degradation.""" # Pre-store valid tokens so no OAuth flow is needed await mock_storage.set_tokens(valid_tokens) @@ -1221,7 +1213,7 @@ async def test_auth_flow_no_unnecessary_retry_after_oauth( @pytest.mark.anyio async def test_token_exchange_accepts_201_status( self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage - ) -> None: + ): """Test that token exchange accepts both 200 and 201 status codes.""" # Ensure no tokens are stored oauth_provider.context.current_tokens = None @@ -1334,7 +1326,7 @@ async def test_403_insufficient_scope_updates_scope_from_header( oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken, - ) -> None: + ): """Test that 403 response correctly updates scope from WWW-Authenticate header.""" # Pre-store valid tokens and client info client_info = OAuthClientInformationFull( @@ -1470,7 +1462,7 @@ def test_build_metadata( token_endpoint: str, registration_endpoint: str, revocation_endpoint: str, -) -> None: +): metadata = build_metadata( issuer_url=AnyHttpUrl(issuer_url), service_documentation_url=AnyHttpUrl(service_documentation_url), @@ -1501,7 +1493,7 @@ class TestLegacyServerFallback: @pytest.mark.anyio async def test_legacy_server_no_prm_falls_back_to_root_oauth_discovery( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage - ) -> None: + ): """Test that when PRM discovery fails completely, we fall back to root OAuth discovery (March 2025 spec).""" async def redirect_handler(url: str) -> None: @@ -1600,7 +1592,7 @@ async def callback_handler() -> tuple[str, str | None]: @pytest.mark.anyio async def test_legacy_server_with_different_prm_and_root_urls( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage - ) -> None: + ): """Test PRM fallback with different WWW-Authenticate and root URLs.""" async def redirect_handler(url: str) -> None: @@ -1705,7 +1697,7 @@ class TestSEP985Discovery: @pytest.mark.anyio async def test_path_based_fallback_when_no_www_authenticate( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage - ) -> None: + ): """Test that client falls back to path-based well-known URI when WWW-Authenticate is absent.""" async def redirect_handler(url: str) -> None: @@ -1740,7 +1732,7 @@ async def callback_handler() -> tuple[str, str | None]: @pytest.mark.anyio async def test_root_based_fallback_after_path_based_404( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage - ) -> None: + ): """Test that client falls back to root-based URI when path-based returns 404.""" async def redirect_handler(url: str) -> None: @@ -1841,7 +1833,7 @@ async def callback_handler() -> tuple[str, str | None]: @pytest.mark.anyio async def test_www_authenticate_takes_priority_over_well_known( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage - ) -> None: + ): """Test that WWW-Authenticate header resource_metadata takes priority over well-known URIs.""" async def redirect_handler(url: str) -> None: @@ -1935,7 +1927,7 @@ def test_extract_field_from_www_auth_valid_cases( www_auth_header: str, field_name: str, expected_value: str, - ) -> None: + ): """Test extraction of various fields from valid WWW-Authenticate headers.""" init_response = httpx.Response( @@ -1969,7 +1961,7 @@ def test_extract_field_from_www_auth_invalid_cases( www_auth_header: str | None, field_name: str, description: str, - ) -> None: + ): """Test extraction returns None for invalid cases.""" headers = {"WWW-Authenticate": www_auth_header} if www_auth_header is not None else {} @@ -2004,11 +1996,11 @@ class TestCIMD: ("http://[::1/foo/", False), ], ) - def test_is_valid_client_metadata_url(self, url: str | None, expected: bool) -> None: + def test_is_valid_client_metadata_url(self, url: str | None, expected: bool): """Test CIMD URL validation.""" assert is_valid_client_metadata_url(url) == expected - def test_should_use_client_metadata_url_when_server_supports(self) -> None: + def test_should_use_client_metadata_url_when_server_supports(self): """Test that CIMD is used when server supports it and URL is provided.""" oauth_metadata = OAuthMetadata( issuer=AnyHttpUrl("https://auth.example.com"), @@ -2018,7 +2010,7 @@ def test_should_use_client_metadata_url_when_server_supports(self) -> None: ) assert should_use_client_metadata_url(oauth_metadata, "https://example.com/client") is True - def test_should_not_use_client_metadata_url_when_server_does_not_support(self) -> None: + def test_should_not_use_client_metadata_url_when_server_does_not_support(self): """Test that CIMD is not used when server doesn't support it.""" oauth_metadata = OAuthMetadata( issuer=AnyHttpUrl("https://auth.example.com"), @@ -2028,7 +2020,7 @@ def test_should_not_use_client_metadata_url_when_server_does_not_support(self) - ) assert should_use_client_metadata_url(oauth_metadata, "https://example.com/client") is False - def test_should_not_use_client_metadata_url_when_not_provided(self) -> None: + def test_should_not_use_client_metadata_url_when_not_provided(self): """Test that CIMD is not used when no URL is provided.""" oauth_metadata = OAuthMetadata( issuer=AnyHttpUrl("https://auth.example.com"), @@ -2038,11 +2030,11 @@ def test_should_not_use_client_metadata_url_when_not_provided(self) -> None: ) assert should_use_client_metadata_url(oauth_metadata, None) is False - def test_should_not_use_client_metadata_url_when_no_metadata(self) -> None: + def test_should_not_use_client_metadata_url_when_no_metadata(self): """Test that CIMD is not used when OAuth metadata is None.""" assert should_use_client_metadata_url(None, "https://example.com/client") is False - def test_create_client_info_from_metadata_url(self) -> None: + def test_create_client_info_from_metadata_url(self): """Test creating client info from CIMD URL.""" client_info = create_client_info_from_metadata_url( "https://example.com/client", @@ -2055,7 +2047,7 @@ def test_create_client_info_from_metadata_url(self) -> None: def test_oauth_provider_with_valid_client_metadata_url( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage - ) -> None: + ): """Test OAuthClientProvider initialization with valid client_metadata_url.""" async def redirect_handler(url: str) -> None: @@ -2076,7 +2068,7 @@ async def callback_handler() -> tuple[str, str | None]: def test_oauth_provider_with_invalid_client_metadata_url_raises_error( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage - ) -> None: + ): """Test OAuthClientProvider raises error for invalid client_metadata_url.""" async def redirect_handler(url: str) -> None: @@ -2099,7 +2091,7 @@ async def callback_handler() -> tuple[str, str | None]: @pytest.mark.anyio async def test_auth_flow_uses_cimd_when_server_supports( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage - ) -> None: + ): """Test that auth flow uses CIMD URL as client_id when server supports it.""" async def redirect_handler(url: str) -> None: @@ -2190,7 +2182,7 @@ async def callback_handler() -> tuple[str, str | None]: @pytest.mark.anyio async def test_auth_flow_falls_back_to_dcr_when_no_cimd_support( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage - ) -> None: + ): """Test that auth flow falls back to DCR when server doesn't support CIMD.""" async def redirect_handler(url: str) -> None: diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 62c454fde..18368e6bb 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -96,7 +96,7 @@ def greeting_prompt(name: str) -> str: return server -async def test_client_is_initialized(app: MCPServer) -> None: +async def test_client_is_initialized(app: MCPServer): """Test that the client is initialized after entering context.""" async with Client(app) as client: assert client.initialize_result.capabilities == snapshot( @@ -110,7 +110,7 @@ async def test_client_is_initialized(app: MCPServer) -> None: assert client.initialize_result.server_info.name == "test" -async def test_client_with_simple_server(simple_server: Server) -> None: +async def test_client_with_simple_server(simple_server: Server): """Test that from_server works with a basic Server instance.""" async with Client(simple_server) as client: resources = await client.list_resources() @@ -121,13 +121,13 @@ async def test_client_with_simple_server(simple_server: Server) -> None: ) -async def test_client_send_ping(app: MCPServer) -> None: +async def test_client_send_ping(app: MCPServer): async with Client(app) as client: result = await client.send_ping() assert result == snapshot(EmptyResult()) -async def test_client_list_tools(app: MCPServer) -> None: +async def test_client_list_tools(app: MCPServer): async with Client(app) as client: result = await client.list_tools() assert result == snapshot( @@ -154,7 +154,7 @@ async def test_client_list_tools(app: MCPServer) -> None: ) -async def test_client_call_tool(app: MCPServer) -> None: +async def test_client_call_tool(app: MCPServer): async with Client(app) as client: result = await client.call_tool("greet", {"name": "World"}) assert result == snapshot( @@ -165,7 +165,7 @@ async def test_client_call_tool(app: MCPServer) -> None: ) -async def test_read_resource(app: MCPServer) -> None: +async def test_read_resource(app: MCPServer): """Test reading a resource.""" async with Client(app) as client: result = await client.read_resource("test://resource") @@ -176,7 +176,7 @@ async def test_read_resource(app: MCPServer) -> None: ) -async def test_read_resource_error_propagates() -> None: +async def test_read_resource_error_propagates(): """MCPError raised by a server handler propagates to the client with its code intact.""" async def handle_read_resource( @@ -191,7 +191,7 @@ async def handle_read_resource( assert exc_info.value.error.code == 404 -async def test_get_prompt(app: MCPServer) -> None: +async def test_get_prompt(app: MCPServer): """Test getting a prompt.""" async with Client(app) as client: result = await client.get_prompt("greeting_prompt", {"name": "Alice"}) @@ -203,21 +203,21 @@ async def test_get_prompt(app: MCPServer) -> None: ) -def test_client_session_property_before_enter(app: MCPServer) -> None: +def test_client_session_property_before_enter(app: MCPServer): """Test that accessing session before context manager raises RuntimeError.""" client = Client(app) with pytest.raises(RuntimeError, match="Client must be used within an async context manager"): client.session -async def test_client_reentry_raises_runtime_error(app: MCPServer) -> None: +async def test_client_reentry_raises_runtime_error(app: MCPServer): """Test that reentering a client raises RuntimeError.""" async with Client(app) as client: with pytest.raises(RuntimeError, match="Client is already entered"): await client.__aenter__() -async def test_client_send_progress_notification() -> None: +async def test_client_send_progress_notification(): """Test sending progress notification.""" received_from_client = None event = anyio.Event() @@ -235,26 +235,26 @@ async def handle_progress(ctx: ServerRequestContext, params: types.ProgressNotif assert received_from_client == snapshot({"progress_token": "token123", "progress": 50.0}) -async def test_client_subscribe_resource(simple_server: Server) -> None: +async def test_client_subscribe_resource(simple_server: Server): async with Client(simple_server) as client: result = await client.subscribe_resource("memory://test") assert result == snapshot(EmptyResult()) -async def test_client_unsubscribe_resource(simple_server: Server) -> None: +async def test_client_unsubscribe_resource(simple_server: Server): async with Client(simple_server) as client: result = await client.unsubscribe_resource("memory://test") assert result == snapshot(EmptyResult()) -async def test_client_set_logging_level(simple_server: Server) -> None: +async def test_client_set_logging_level(simple_server: Server): """Test setting logging level.""" async with Client(simple_server) as client: result = await client.set_logging_level("debug") assert result == snapshot(EmptyResult()) -async def test_client_list_resources_with_params(app: MCPServer) -> None: +async def test_client_list_resources_with_params(app: MCPServer): """Test listing resources with params parameter.""" async with Client(app) as client: result = await client.list_resources() @@ -272,14 +272,14 @@ async def test_client_list_resources_with_params(app: MCPServer) -> None: ) -async def test_client_list_resource_templates(app: MCPServer) -> None: +async def test_client_list_resource_templates(app: MCPServer): """Test listing resource templates with params parameter.""" async with Client(app) as client: result = await client.list_resource_templates() assert result == snapshot(ListResourceTemplatesResult(resource_templates=[])) -async def test_list_prompts(app: MCPServer) -> None: +async def test_list_prompts(app: MCPServer): """Test listing prompts with params parameter.""" async with Client(app) as client: result = await client.list_prompts() @@ -296,7 +296,7 @@ async def test_list_prompts(app: MCPServer) -> None: ) -async def test_complete_with_prompt_reference(simple_server: Server) -> None: +async def test_complete_with_prompt_reference(simple_server: Server): """Test getting completions for a prompt argument.""" async with Client(simple_server) as client: ref = types.PromptReference(type="ref/prompt", name="test_prompt") @@ -304,13 +304,13 @@ async def test_complete_with_prompt_reference(simple_server: Server) -> None: assert result == snapshot(types.CompleteResult(completion=types.Completion(values=[]))) -def test_client_with_url_initializes_streamable_http_transport() -> None: +def test_client_with_url_initializes_streamable_http_transport(): with patch("mcp.client.client.streamable_http_client") as mock: _ = Client("http://localhost:8000/mcp") mock.assert_called_once_with("http://localhost:8000/mcp") -async def test_client_uses_transport_directly(app: MCPServer) -> None: +async def test_client_uses_transport_directly(app: MCPServer): transport = InMemoryTransport(app) async with Client(transport) as client: result = await client.call_tool("greet", {"name": "Transport"}) diff --git a/tests/client/test_list_methods_cursor.py b/tests/client/test_list_methods_cursor.py index 5055c10c3..f70fb9277 100644 --- a/tests/client/test_list_methods_cursor.py +++ b/tests/client/test_list_methods_cursor.py @@ -13,7 +13,7 @@ @pytest.fixture -async def full_featured_server() -> MCPServer: +async def full_featured_server(): """Create a server with tools, resources, prompts, and templates.""" server = MCPServer("test") @@ -57,7 +57,7 @@ async def test_list_methods_params_parameter( full_featured_server: MCPServer, method_name: str, request_method: str, -) -> None: +): """Test that the params parameter is accepted and correctly passed to the server. Covers: list_tools, list_resources, list_prompts, list_resource_templates @@ -95,7 +95,7 @@ async def test_list_methods_params_parameter( async def test_list_tools_with_strict_server_validation( full_featured_server: MCPServer, -) -> None: +): """Test pagination with a server that validates request format strictly.""" async with Client(full_featured_server) as client: result = await client.list_tools() @@ -103,7 +103,7 @@ async def test_list_tools_with_strict_server_validation( assert len(result.tools) > 0 -async def test_list_tools_with_lowlevel_server() -> None: +async def test_list_tools_with_lowlevel_server(): """Test that list_tools works with a lowlevel Server using params.""" async def handle_list_tools( diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index 44cc3b943..be4b9a97b 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -9,7 +9,7 @@ @pytest.mark.anyio -async def test_list_roots_callback() -> None: +async def test_list_roots_callback(): server = MCPServer("test") callback_return = ListRootsResult( @@ -25,7 +25,7 @@ async def list_roots_callback( return callback_return @server.tool("test_list_roots") - async def test_list_roots(context: Context, message: str) -> bool: + async def test_list_roots(context: Context, message: str): roots = await context.session.list_roots() assert roots == callback_return return True diff --git a/tests/client/test_logging_callback.py b/tests/client/test_logging_callback.py index affb4469f..1598fd55f 100644 --- a/tests/client/test_logging_callback.py +++ b/tests/client/test_logging_callback.py @@ -12,7 +12,7 @@ class LoggingCollector: - def __init__(self) -> None: + def __init__(self): self.log_messages: list[LoggingMessageNotificationParams] = [] async def __call__(self, params: LoggingMessageNotificationParams) -> None: @@ -20,7 +20,7 @@ async def __call__(self, params: LoggingMessageNotificationParams) -> None: @pytest.mark.anyio -async def test_logging_callback() -> None: +async def test_logging_callback(): server = MCPServer("test") logging_collector = LoggingCollector() diff --git a/tests/client/test_output_schema_validation.py b/tests/client/test_output_schema_validation.py index 2432293ef..d78197b5c 100644 --- a/tests/client/test_output_schema_validation.py +++ b/tests/client/test_output_schema_validation.py @@ -34,7 +34,7 @@ async def on_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) @pytest.mark.anyio -async def test_tool_structured_output_client_side_validation_basemodel() -> None: +async def test_tool_structured_output_client_side_validation_basemodel(): """Test that client validates structured content against schema for BaseModel outputs""" output_schema = { "type": "object", @@ -62,7 +62,7 @@ async def test_tool_structured_output_client_side_validation_basemodel() -> None @pytest.mark.anyio -async def test_tool_structured_output_client_side_validation_primitive() -> None: +async def test_tool_structured_output_client_side_validation_primitive(): """Test that client validates structured content for primitive outputs""" output_schema = { "type": "object", @@ -90,7 +90,7 @@ async def test_tool_structured_output_client_side_validation_primitive() -> None @pytest.mark.anyio -async def test_tool_structured_output_client_side_validation_dict_typed() -> None: +async def test_tool_structured_output_client_side_validation_dict_typed(): """Test that client validates dict[str, T] structured content""" output_schema = {"type": "object", "additionalProperties": {"type": "integer"}, "title": "get_scores_Output"} @@ -113,7 +113,7 @@ async def test_tool_structured_output_client_side_validation_dict_typed() -> Non @pytest.mark.anyio -async def test_tool_structured_output_client_side_validation_missing_required() -> None: +async def test_tool_structured_output_client_side_validation_missing_required(): """Test that client validates missing required fields""" output_schema = { "type": "object", @@ -141,7 +141,7 @@ async def test_tool_structured_output_client_side_validation_missing_required() @pytest.mark.anyio -async def test_tool_not_listed_warning(caplog: pytest.LogCaptureFixture) -> None: +async def test_tool_not_listed_warning(caplog: pytest.LogCaptureFixture): """Test that client logs warning when tool is not in list_tools but has output_schema""" async def on_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: diff --git a/tests/client/test_resource_cleanup.py b/tests/client/test_resource_cleanup.py index 16ba40cd6..c7bf8fafa 100644 --- a/tests/client/test_resource_cleanup.py +++ b/tests/client/test_resource_cleanup.py @@ -1,4 +1,4 @@ -from typing import Any, NoReturn +from typing import Any from unittest.mock import patch import anyio @@ -11,7 +11,7 @@ @pytest.mark.anyio -async def test_send_request_stream_cleanup() -> None: +async def test_send_request_stream_cleanup(): """Test that send_request properly cleans up streams when an exception occurs. This test mocks out most of the session functionality to focus on stream cleanup. @@ -43,7 +43,7 @@ def _receive_notification_adapter(self) -> TypeAdapter[Any]: request = PingRequest() # Patch the _write_stream.send method to raise an exception - async def mock_send(*args: Any, **kwargs: Any) -> NoReturn: + async def mock_send(*args: Any, **kwargs: Any): raise RuntimeError("Simulated network error") # Record the response streams before the test diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index 484d14e55..6efcac0a5 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -15,7 +15,7 @@ @pytest.mark.anyio -async def test_sampling_callback() -> None: +async def test_sampling_callback(): server = MCPServer("test") callback_return = CreateMessageResult( @@ -58,7 +58,7 @@ async def test_sampling_tool(message: str, ctx: Context) -> bool: @pytest.mark.anyio -async def test_create_message_backwards_compat_single_content() -> None: +async def test_create_message_backwards_compat_single_content(): """Test backwards compatibility: create_message without tools returns single content.""" server = MCPServer("test") @@ -100,7 +100,7 @@ async def test_tool(message: str, ctx: Context) -> bool: @pytest.mark.anyio -async def test_create_message_result_with_tools_type() -> None: +async def test_create_message_result_with_tools_type(): """Test that CreateMessageResultWithTools supports content_as_list.""" # Test the type itself, not the overload (overload requires client capability setup) result = CreateMessageResultWithTools( diff --git a/tests/client/test_scope_bug_1630.py b/tests/client/test_scope_bug_1630.py index f273688a3..fafa51007 100644 --- a/tests/client/test_scope_bug_1630.py +++ b/tests/client/test_scope_bug_1630.py @@ -35,7 +35,7 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None @pytest.mark.anyio -async def test_401_uses_www_auth_scope_not_resource_metadata_url() -> None: +async def test_401_uses_www_auth_scope_not_resource_metadata_url(): """Regression test for #1630: Ensure scope is extracted from WWW-Authenticate header, not the resource_metadata URL. diff --git a/tests/client/test_session.py b/tests/client/test_session.py index e90b680d5..f25c964f0 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -28,14 +28,14 @@ @pytest.mark.anyio -async def test_client_session_initialize() -> None: +async def test_client_session_initialize(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) initialized_notification = None result = None - async def mock_server() -> None: + async def mock_server(): nonlocal initialized_notification session_message = await client_to_server_receive.receive() @@ -111,14 +111,14 @@ async def message_handler( # pragma: no cover @pytest.mark.anyio -async def test_client_session_custom_client_info() -> None: +async def test_client_session_custom_client_info(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) custom_client_info = Implementation(name="test-client", version="1.2.3") received_client_info = None - async def mock_server() -> None: + async def mock_server(): nonlocal received_client_info session_message = await client_to_server_receive.receive() @@ -169,13 +169,13 @@ async def mock_server() -> None: @pytest.mark.anyio -async def test_client_session_default_client_info() -> None: +async def test_client_session_default_client_info(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) received_client_info = None - async def mock_server() -> None: + async def mock_server(): nonlocal received_client_info session_message = await client_to_server_receive.receive() @@ -222,13 +222,13 @@ async def mock_server() -> None: @pytest.mark.anyio -async def test_client_session_version_negotiation_success() -> None: +async def test_client_session_version_negotiation_success(): """Test successful version negotiation with supported version""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) result = None - async def mock_server() -> None: + async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) @@ -278,12 +278,12 @@ async def mock_server() -> None: @pytest.mark.anyio -async def test_client_session_version_negotiation_failure() -> None: +async def test_client_session_version_negotiation_failure(): """Test version negotiation failure with unsupported version""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - async def mock_server() -> None: + async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) @@ -326,14 +326,14 @@ async def mock_server() -> None: @pytest.mark.anyio -async def test_client_capabilities_default() -> None: +async def test_client_capabilities_default(): """Test that client capabilities are properly set with default callbacks""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) received_capabilities = None - async def mock_server() -> None: + async def mock_server(): nonlocal received_capabilities session_message = await client_to_server_receive.receive() @@ -382,7 +382,7 @@ async def mock_server() -> None: @pytest.mark.anyio -async def test_client_capabilities_with_custom_callbacks() -> None: +async def test_client_capabilities_with_custom_callbacks(): """Test that client capabilities are properly set with custom callbacks""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) @@ -404,7 +404,7 @@ async def custom_list_roots_callback( # pragma: no cover ) -> types.ListRootsResult | types.ErrorData: return types.ListRootsResult(roots=[]) - async def mock_server() -> None: + async def mock_server(): nonlocal received_capabilities session_message = await client_to_server_receive.receive() @@ -466,7 +466,7 @@ async def mock_server() -> None: @pytest.mark.anyio -async def test_client_capabilities_with_sampling_tools() -> None: +async def test_client_capabilities_with_sampling_tools(): """Test that sampling capabilities with tools are properly advertised""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) @@ -483,7 +483,7 @@ async def custom_sampling_callback( # pragma: no cover model="test-model", ) - async def mock_server() -> None: + async def mock_server(): nonlocal received_capabilities session_message = await client_to_server_receive.receive() @@ -540,7 +540,7 @@ async def mock_server() -> None: @pytest.mark.anyio -async def test_initialize_result() -> None: +async def test_initialize_result(): """Test that initialize_result is None before init and contains the full result after.""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) @@ -554,7 +554,7 @@ async def test_initialize_result() -> None: expected_server_info = Implementation(name="mock-server", version="0.1.0") expected_instructions = "Use the tools wisely." - async def mock_server() -> None: + async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) @@ -608,14 +608,14 @@ async def mock_server() -> None: @pytest.mark.anyio @pytest.mark.parametrize(argnames="meta", argvalues=[None, {"toolMeta": "value"}]) -async def test_client_tool_call_with_meta(meta: RequestParamsMeta | None) -> None: +async def test_client_tool_call_with_meta(meta: RequestParamsMeta | None): """Test that client tool call requests can include metadata""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) mocked_tool = types.Tool(name="sample_tool", input_schema={}) - async def mock_server() -> None: + async def mock_server(): # Receive initialization request from client session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index f7f101a7d..6a58b39f3 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -17,7 +17,7 @@ @pytest.fixture -def mock_exit_stack() -> mock.MagicMock: +def mock_exit_stack(): """Fixture for a mocked AsyncExitStack.""" # Use unittest.mock.Mock directly if needed, or just a plain object # if only attribute access/existence is needed. @@ -25,7 +25,7 @@ def mock_exit_stack() -> mock.MagicMock: return mock.MagicMock(spec=contextlib.AsyncExitStack) -def test_client_session_group_init() -> None: +def test_client_session_group_init(): mcp_session_group = ClientSessionGroup() assert not mcp_session_group._tools assert not mcp_session_group._resources @@ -33,7 +33,7 @@ def test_client_session_group_init() -> None: assert not mcp_session_group._tool_to_session -def test_client_session_group_component_properties() -> None: +def test_client_session_group_component_properties(): # --- Mock Dependencies --- mock_prompt = mock.Mock() mock_resource = mock.Mock() @@ -52,7 +52,7 @@ def test_client_session_group_component_properties() -> None: @pytest.mark.anyio -async def test_client_session_group_call_tool() -> None: +async def test_client_session_group_call_tool(): # --- Mock Dependencies --- mock_session = mock.AsyncMock() @@ -87,7 +87,7 @@ def hook(name: str, server_info: types.Implementation) -> str: # pragma: no cov @pytest.mark.anyio -async def test_client_session_group_connect_to_server(mock_exit_stack: contextlib.AsyncExitStack) -> None: +async def test_client_session_group_connect_to_server(mock_exit_stack: contextlib.AsyncExitStack): """Test connecting to a server and aggregating components.""" # --- Mock Dependencies --- mock_server_info = mock.Mock(spec=types.Implementation) @@ -126,9 +126,7 @@ async def test_client_session_group_connect_to_server(mock_exit_stack: contextli @pytest.mark.anyio -async def test_client_session_group_connect_to_server_with_name_hook( - mock_exit_stack: contextlib.AsyncExitStack, -) -> None: +async def test_client_session_group_connect_to_server_with_name_hook(mock_exit_stack: contextlib.AsyncExitStack): """Test connecting with a component name hook.""" # --- Mock Dependencies --- mock_server_info = mock.Mock(spec=types.Implementation) @@ -159,7 +157,7 @@ def name_hook(name: str, server_info: types.Implementation) -> str: @pytest.mark.anyio -async def test_client_session_group_disconnect_from_server() -> None: +async def test_client_session_group_disconnect_from_server(): """Test disconnecting from a server.""" # --- Test Setup --- group = ClientSessionGroup() @@ -226,7 +224,7 @@ async def test_client_session_group_disconnect_from_server() -> None: @pytest.mark.anyio async def test_client_session_group_connect_to_server_duplicate_tool_raises_error( mock_exit_stack: contextlib.AsyncExitStack, -) -> None: +): """Test MCPError raised when connecting a server with a dup name.""" # --- Setup Pre-existing State --- group = ClientSessionGroup(exit_stack=mock_exit_stack) @@ -272,7 +270,7 @@ async def test_client_session_group_connect_to_server_duplicate_tool_raises_erro @pytest.mark.anyio -async def test_client_session_group_disconnect_non_existent_server() -> None: +async def test_client_session_group_disconnect_non_existent_server(): """Test disconnecting a server that isn't connected.""" session = mock.Mock(spec=mcp.ClientSession) group = ClientSessionGroup() @@ -306,7 +304,7 @@ async def test_client_session_group_establish_session_parameterized( server_params_instance: StdioServerParameters | SseServerParameters | StreamableHttpParameters, client_type_name: str, # Just for clarity or conditional logic if needed patch_target_for_client_func: str, -) -> None: +): with mock.patch("mcp.client.session_group.mcp.ClientSession") as mock_ClientSession_class: with mock.patch(patch_target_for_client_func) as mock_specific_client_func: mock_client_cm_instance = mock.AsyncMock(name=f"{client_type_name}ClientCM") diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 7965ba025..06e2cba4b 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -31,7 +31,7 @@ @pytest.mark.anyio @pytest.mark.skipif(tee is None, reason="could not find tee command") -async def test_stdio_context_manager_exiting() -> None: +async def test_stdio_context_manager_exiting(): assert tee is not None async with stdio_client(StdioServerParameters(command=tee)) as (_, _): pass @@ -39,7 +39,7 @@ async def test_stdio_context_manager_exiting() -> None: @pytest.mark.anyio @pytest.mark.skipif(tee is None, reason="could not find tee command") -async def test_stdio_client() -> None: +async def test_stdio_client(): assert tee is not None server_parameters = StdioServerParameters(command=tee) @@ -71,7 +71,7 @@ async def test_stdio_client() -> None: @pytest.mark.anyio -async def test_stdio_client_bad_path() -> None: +async def test_stdio_client_bad_path(): """Check that the connection doesn't hang if process errors.""" server_params = StdioServerParameters(command=sys.executable, args=["-c", "non-existent-file.py"]) async with stdio_client(server_params) as (read_stream, write_stream): @@ -86,7 +86,7 @@ async def test_stdio_client_bad_path() -> None: @pytest.mark.anyio -async def test_stdio_client_nonexistent_command() -> None: +async def test_stdio_client_nonexistent_command(): """Test that stdio_client raises an error for non-existent commands.""" # Create a server with a non-existent command server_params = StdioServerParameters( @@ -104,7 +104,7 @@ async def test_stdio_client_nonexistent_command() -> None: @pytest.mark.anyio -async def test_stdio_client_universal_cleanup() -> None: +async def test_stdio_client_universal_cleanup(): """Test that stdio_client completes cleanup within reasonable time even when connected to processes that exit slowly. """ @@ -156,7 +156,7 @@ async def test_stdio_client_universal_cleanup() -> None: @pytest.mark.anyio @pytest.mark.skipif(sys.platform == "win32", reason="Windows signal handling is different") -async def test_stdio_client_sigint_only_process() -> None: # pragma: lax no cover +async def test_stdio_client_sigint_only_process(): # pragma: lax no cover """Test cleanup with a process that ignores SIGTERM but responds to SIGINT.""" # Create a Python script that ignores SIGTERM but handles SIGINT script_content = textwrap.dedent( @@ -353,7 +353,7 @@ class TestChildProcessCleanup: """ @pytest.mark.anyio - async def test_basic_child_process_cleanup(self) -> None: + async def test_basic_child_process_cleanup(self): """Parent spawns one child; terminating the tree kills both.""" async with AsyncExitStack() as stack: sock, port = await _open_liveness_listener() @@ -377,7 +377,7 @@ async def test_basic_child_process_cleanup(self) -> None: await _assert_stream_closed(stream) @pytest.mark.anyio - async def test_nested_process_tree(self) -> None: + async def test_nested_process_tree(self): """Parent → child → grandchild; terminating the tree kills all three.""" async with AsyncExitStack() as stack: sock, port = await _open_liveness_listener() @@ -413,7 +413,7 @@ async def test_nested_process_tree(self) -> None: await _assert_stream_closed(stream) @pytest.mark.anyio - async def test_early_parent_exit(self) -> None: + async def test_early_parent_exit(self): """Parent exits immediately on SIGTERM; process-group termination still catches the child (exercises the race where the parent dies mid-cleanup). """ @@ -447,7 +447,7 @@ async def test_early_parent_exit(self) -> None: @pytest.mark.anyio -async def test_stdio_client_graceful_stdin_exit() -> None: +async def test_stdio_client_graceful_stdin_exit(): """Test that a process exits gracefully when stdin is closed, without needing SIGTERM or SIGKILL. """ @@ -502,7 +502,7 @@ async def test_stdio_client_graceful_stdin_exit() -> None: @pytest.mark.anyio -async def test_stdio_client_stdin_close_ignored() -> None: +async def test_stdio_client_stdin_close_ignored(): """Test that when a process ignores stdin closure, the shutdown sequence properly escalates to SIGTERM. """ diff --git a/tests/client/transports/test_memory.py b/tests/client/transports/test_memory.py index 406f69f40..c8fc41fd5 100644 --- a/tests/client/transports/test_memory.py +++ b/tests/client/transports/test_memory.py @@ -50,7 +50,7 @@ def test_resource() -> str: # pragma: no cover pytestmark = pytest.mark.anyio -async def test_with_server(simple_server: Server) -> None: +async def test_with_server(simple_server: Server): """Test creating transport with a Server instance.""" transport = InMemoryTransport(simple_server) async with transport as (read_stream, write_stream): @@ -58,7 +58,7 @@ async def test_with_server(simple_server: Server) -> None: assert write_stream is not None -async def test_with_mcpserver(mcpserver_server: MCPServer) -> None: +async def test_with_mcpserver(mcpserver_server: MCPServer): """Test creating transport with an MCPServer instance.""" transport = InMemoryTransport(mcpserver_server) async with transport as (read_stream, write_stream): @@ -66,13 +66,13 @@ async def test_with_mcpserver(mcpserver_server: MCPServer) -> None: assert write_stream is not None -async def test_server_is_running(mcpserver_server: MCPServer) -> None: +async def test_server_is_running(mcpserver_server: MCPServer): """Test that the server is running and responding to requests.""" async with Client(mcpserver_server) as client: assert client.initialize_result.capabilities.tools is not None -async def test_list_tools(mcpserver_server: MCPServer) -> None: +async def test_list_tools(mcpserver_server: MCPServer): """Test listing tools through the transport.""" async with Client(mcpserver_server) as client: tools_result = await client.list_tools() @@ -81,7 +81,7 @@ async def test_list_tools(mcpserver_server: MCPServer) -> None: assert "greet" in tool_names -async def test_call_tool(mcpserver_server: MCPServer) -> None: +async def test_call_tool(mcpserver_server: MCPServer): """Test calling a tool through the transport.""" async with Client(mcpserver_server) as client: result = await client.call_tool("greet", {"name": "World"}) @@ -90,7 +90,7 @@ async def test_call_tool(mcpserver_server: MCPServer) -> None: assert "Hello, World!" in str(result.content[0]) -async def test_raise_exceptions(mcpserver_server: MCPServer) -> None: +async def test_raise_exceptions(mcpserver_server: MCPServer): """Test that raise_exceptions parameter is passed through.""" transport = InMemoryTransport(mcpserver_server, raise_exceptions=True) async with transport as (read_stream, _write_stream): diff --git a/tests/conftest.py b/tests/conftest.py index 5c53fe0ac..af7e47993 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,5 +2,5 @@ @pytest.fixture -def anyio_backend() -> str: +def anyio_backend(): return "asyncio" diff --git a/tests/experimental/tasks/client/test_capabilities.py b/tests/experimental/tasks/client/test_capabilities.py index 6bbc5699b..1ea2199e8 100644 --- a/tests/experimental/tasks/client/test_capabilities.py +++ b/tests/experimental/tasks/client/test_capabilities.py @@ -21,14 +21,14 @@ @pytest.mark.anyio -async def test_client_capabilities_without_tasks() -> None: +async def test_client_capabilities_without_tasks(): """Test that tasks capability is None when not provided.""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) received_capabilities = None - async def mock_server() -> None: + async def mock_server(): nonlocal received_capabilities session_message = await client_to_server_receive.receive() @@ -78,7 +78,7 @@ async def mock_server() -> None: @pytest.mark.anyio -async def test_client_capabilities_with_tasks() -> None: +async def test_client_capabilities_with_tasks(): """Test that tasks capability is properly set when handlers are provided.""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) @@ -98,7 +98,7 @@ async def my_cancel_task_handler( ) -> types.CancelTaskResult | types.ErrorData: raise NotImplementedError - async def mock_server() -> None: + async def mock_server(): nonlocal received_capabilities session_message = await client_to_server_receive.receive() @@ -158,7 +158,7 @@ async def mock_server() -> None: @pytest.mark.anyio -async def test_client_capabilities_auto_built_from_handlers() -> None: +async def test_client_capabilities_auto_built_from_handlers(): """Test that tasks capability is automatically built from provided handlers.""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) @@ -178,7 +178,7 @@ async def my_cancel_task_handler( ) -> types.CancelTaskResult | types.ErrorData: raise NotImplementedError - async def mock_server() -> None: + async def mock_server(): nonlocal received_capabilities session_message = await client_to_server_receive.receive() @@ -239,7 +239,7 @@ async def mock_server() -> None: @pytest.mark.anyio -async def test_client_capabilities_with_task_augmented_handlers() -> None: +async def test_client_capabilities_with_task_augmented_handlers(): """Test that requests capability is built when augmented handlers are provided.""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) @@ -254,7 +254,7 @@ async def my_augmented_sampling_handler( ) -> types.CreateTaskResult | types.ErrorData: raise NotImplementedError - async def mock_server() -> None: + async def mock_server(): nonlocal received_capabilities session_message = await client_to_server_receive.receive() diff --git a/tests/experimental/tasks/client/test_tasks.py b/tests/experimental/tasks/client/test_tasks.py index fc3361f99..613c794eb 100644 --- a/tests/experimental/tasks/client/test_tasks.py +++ b/tests/experimental/tasks/client/test_tasks.py @@ -1,7 +1,7 @@ """Tests for the experimental client task methods (session.experimental).""" -from collections.abc import AsyncIterator, Callable -from contextlib import AbstractAsyncContextManager, asynccontextmanager +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager from dataclasses import dataclass, field import anyio @@ -72,9 +72,7 @@ async def do_work() -> None: raise NotImplementedError -def _make_lifespan( - store: InMemoryTaskStore, task_done_events: dict[str, Event] -) -> Callable[[Server[AppContext]], AbstractAsyncContextManager[AppContext]]: +def _make_lifespan(store: InMemoryTaskStore, task_done_events: dict[str, Event]): @asynccontextmanager async def app_lifespan(server: Server[AppContext]) -> AsyncIterator[AppContext]: async with anyio.create_task_group() as tg: diff --git a/tests/experimental/tasks/server/test_integration.py b/tests/experimental/tasks/server/test_integration.py index 0f098152d..b5b79033d 100644 --- a/tests/experimental/tasks/server/test_integration.py +++ b/tests/experimental/tasks/server/test_integration.py @@ -8,8 +8,8 @@ 5. Client retrieves result with tasks/result """ -from collections.abc import AsyncIterator, Callable -from contextlib import AbstractAsyncContextManager, asynccontextmanager +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager from dataclasses import dataclass, field import anyio @@ -49,9 +49,7 @@ class AppContext: task_done_events: dict[str, Event] = field(default_factory=lambda: {}) -def _make_lifespan( - store: InMemoryTaskStore, task_done_events: dict[str, Event] -) -> Callable[[Server[AppContext]], AbstractAsyncContextManager[AppContext]]: +def _make_lifespan(store: InMemoryTaskStore, task_done_events: dict[str, Event]): @asynccontextmanager async def app_lifespan(server: Server[AppContext]) -> AsyncIterator[AppContext]: async with anyio.create_task_group() as tg: diff --git a/tests/issues/test_100_tool_listing.py b/tests/issues/test_100_tool_listing.py index 259970cce..e59fb632d 100644 --- a/tests/issues/test_100_tool_listing.py +++ b/tests/issues/test_100_tool_listing.py @@ -5,7 +5,7 @@ pytestmark = pytest.mark.anyio -async def test_list_tools_returns_all_tools() -> None: +async def test_list_tools_returns_all_tools(): mcp = MCPServer("TestTools") # Create 100 tools with unique names @@ -13,7 +13,7 @@ async def test_list_tools_returns_all_tools() -> None: for i in range(num_tools): @mcp.tool(name=f"tool_{i}") - def dummy_tool_func() -> int: # pragma: no cover + def dummy_tool_func(): # pragma: no cover f"""Tool number {i}""" return i diff --git a/tests/issues/test_1027_win_unreachable_cleanup.py b/tests/issues/test_1027_win_unreachable_cleanup.py index ac2ca8937..c59c5aeca 100644 --- a/tests/issues/test_1027_win_unreachable_cleanup.py +++ b/tests/issues/test_1027_win_unreachable_cleanup.py @@ -21,7 +21,7 @@ @pytest.mark.anyio -async def test_lifespan_cleanup_executed() -> None: +async def test_lifespan_cleanup_executed(): """Regression test ensuring MCP server cleanup code runs during shutdown. This test verifies that the fix for issue #1027 works correctly by: @@ -121,7 +121,7 @@ def echo(text: str) -> str: @pytest.mark.anyio @pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default") -async def test_stdin_close_triggers_cleanup() -> None: +async def test_stdin_close_triggers_cleanup(): """Regression test verifying the stdin-based graceful shutdown mechanism. This test ensures the core fix for issue #1027 continues to work by: diff --git a/tests/issues/test_129_resource_templates.py b/tests/issues/test_129_resource_templates.py index 43931290f..bb4735121 100644 --- a/tests/issues/test_129_resource_templates.py +++ b/tests/issues/test_129_resource_templates.py @@ -5,7 +5,7 @@ @pytest.mark.anyio -async def test_resource_templates() -> None: +async def test_resource_templates(): mcp = MCPServer("Demo") @mcp.resource("greeting://{name}") diff --git a/tests/issues/test_1338_icons_and_metadata.py b/tests/issues/test_1338_icons_and_metadata.py index 985d92865..a003f75b8 100644 --- a/tests/issues/test_1338_icons_and_metadata.py +++ b/tests/issues/test_1338_icons_and_metadata.py @@ -8,7 +8,7 @@ pytestmark = pytest.mark.anyio -async def test_icons_and_website_url() -> None: +async def test_icons_and_website_url(): """Test that icons and websiteUrl are properly returned in API calls.""" # Create test icon @@ -92,7 +92,7 @@ def test_resource_template(city: str) -> str: # pragma: no cover assert template.icons[0].src == test_icon.src -async def test_multiple_icons() -> None: +async def test_multiple_icons(): """Test that multiple icons can be added to tools, resources, and prompts.""" # Create multiple test icons @@ -119,7 +119,7 @@ def multi_icon_tool() -> str: # pragma: no cover assert tool.icons[2].sizes == ["64x64"] -async def test_no_icons_or_website() -> None: +async def test_no_icons_or_website(): """Test that server works without icons or websiteUrl.""" mcp = MCPServer("BasicServer") diff --git a/tests/issues/test_1363_race_condition_streamable_http.py b/tests/issues/test_1363_race_condition_streamable_http.py index a813c596d..db2a82d07 100644 --- a/tests/issues/test_1363_race_condition_streamable_http.py +++ b/tests/issues/test_1363_race_condition_streamable_http.py @@ -33,7 +33,7 @@ class RaceConditionTestServer(Server): - def __init__(self) -> None: + def __init__(self): super().__init__(SERVER_NAME) @@ -64,7 +64,7 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: class ServerThread(threading.Thread): """Thread that runs the ASGI application lifespan in a separate event loop.""" - def __init__(self, app: Starlette) -> None: + def __init__(self, app: Starlette): super().__init__(daemon=True) self.app = app self._stop_event = threading.Event() @@ -73,7 +73,7 @@ def run(self) -> None: """Run the lifespan in a new event loop.""" # Create a new event loop for this thread - async def run_lifespan() -> None: + async def run_lifespan(): # Use the lifespan context (always present in our tests) lifespan_context = getattr(self.app.router, "lifespan_context", None) assert lifespan_context is not None # Tests always create apps with lifespan @@ -119,7 +119,7 @@ def check_logs_for_race_condition_errors(caplog: pytest.LogCaptureFixture, test_ @pytest.mark.anyio -async def test_race_condition_invalid_accept_headers(caplog: pytest.LogCaptureFixture) -> None: +async def test_race_condition_invalid_accept_headers(caplog: pytest.LogCaptureFixture): """Test the race condition with invalid Accept headers. This test reproduces the exact scenario described in issue #1363: @@ -193,7 +193,7 @@ async def test_race_condition_invalid_accept_headers(caplog: pytest.LogCaptureFi @pytest.mark.anyio -async def test_race_condition_invalid_content_type(caplog: pytest.LogCaptureFixture) -> None: +async def test_race_condition_invalid_content_type(caplog: pytest.LogCaptureFixture): """Test the race condition with invalid Content-Type headers. This test reproduces the race condition scenario with Content-Type validation failure. @@ -233,7 +233,7 @@ async def test_race_condition_invalid_content_type(caplog: pytest.LogCaptureFixt @pytest.mark.anyio -async def test_race_condition_message_router_async_for(caplog: pytest.LogCaptureFixture) -> None: +async def test_race_condition_message_router_async_for(caplog: pytest.LogCaptureFixture): """Uses json_response=True to trigger the `if self.is_json_response_enabled` branch, which reproduces the ClosedResourceError when message_router is suspended in async for loop while transport cleanup closes streams concurrently. diff --git a/tests/issues/test_141_resource_templates.py b/tests/issues/test_141_resource_templates.py index 04cff1271..f5c5081c3 100644 --- a/tests/issues/test_141_resource_templates.py +++ b/tests/issues/test_141_resource_templates.py @@ -10,7 +10,7 @@ @pytest.mark.anyio -async def test_resource_template_edge_cases() -> None: +async def test_resource_template_edge_cases(): """Test server-side resource template validation""" mcp = MCPServer("Demo") @@ -63,7 +63,7 @@ def get_user_profile_missing(user_id: str) -> str: # pragma: no cover @pytest.mark.anyio -async def test_resource_template_client_interaction() -> None: +async def test_resource_template_client_interaction(): """Test client-side resource template interaction""" mcp = MCPServer("Demo") diff --git a/tests/issues/test_152_resource_mime_type.py b/tests/issues/test_152_resource_mime_type.py index 7ae03198b..851e89979 100644 --- a/tests/issues/test_152_resource_mime_type.py +++ b/tests/issues/test_152_resource_mime_type.py @@ -17,7 +17,7 @@ pytestmark = pytest.mark.anyio -async def test_mcpserver_resource_mime_type() -> None: +async def test_mcpserver_resource_mime_type(): """Test that mime_type parameter is respected for resources.""" mcp = MCPServer("test") @@ -63,7 +63,7 @@ def get_image_as_bytes() -> bytes: assert bytes_result.contents[0].mime_type == "image/png", "Bytes content mime type not preserved" -async def test_lowlevel_resource_mime_type() -> None: +async def test_lowlevel_resource_mime_type(): """Test that mime_type parameter is respected for resources.""" # Create a small test image as bytes diff --git a/tests/issues/test_1574_resource_uri_validation.py b/tests/issues/test_1574_resource_uri_validation.py index d992e2bd9..c67708128 100644 --- a/tests/issues/test_1574_resource_uri_validation.py +++ b/tests/issues/test_1574_resource_uri_validation.py @@ -25,7 +25,7 @@ pytestmark = pytest.mark.anyio -async def test_relative_uri_roundtrip() -> None: +async def test_relative_uri_roundtrip(): """Relative URIs survive the full server-client JSON-RPC roundtrip. This is the critical regression test - if someone reintroduces AnyUrl, @@ -67,7 +67,7 @@ async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRe assert result.contents[0].uri == uri_str -async def test_custom_scheme_uri_roundtrip() -> None: +async def test_custom_scheme_uri_roundtrip(): """Custom scheme URIs work through the protocol. Some MCP servers use custom schemes like "custom://resource". @@ -103,7 +103,7 @@ async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRe assert len(result.contents) == 1 -def test_uri_json_roundtrip_preserves_value() -> None: +def test_uri_json_roundtrip_preserves_value(): """URI is preserved exactly through JSON serialization. This catches any Pydantic validation or normalization that would @@ -125,7 +125,7 @@ def test_uri_json_roundtrip_preserves_value() -> None: assert restored.uri == uri_str, f"URI mutated: {uri_str} -> {restored.uri}" -def test_resource_contents_uri_json_roundtrip() -> None: +def test_resource_contents_uri_json_roundtrip(): """TextResourceContents URI is preserved through JSON serialization.""" test_uris = ["users/me", "./relative", "custom://resource"] diff --git a/tests/issues/test_1754_mime_type_parameters.py b/tests/issues/test_1754_mime_type_parameters.py index a798fcab1..7903fd560 100644 --- a/tests/issues/test_1754_mime_type_parameters.py +++ b/tests/issues/test_1754_mime_type_parameters.py @@ -12,7 +12,7 @@ pytestmark = pytest.mark.anyio -async def test_mime_type_with_parameters() -> None: +async def test_mime_type_with_parameters(): """Test that MIME types with parameters are accepted (RFC 2045).""" mcp = MCPServer("test") @@ -26,7 +26,7 @@ def widget() -> str: assert resources[0].mime_type == "text/html;profile=mcp-app" -async def test_mime_type_with_parameters_and_space() -> None: +async def test_mime_type_with_parameters_and_space(): """Test MIME type with space after semicolon.""" mcp = MCPServer("test") @@ -39,7 +39,7 @@ def data() -> str: assert resources[0].mime_type == "application/json; charset=utf-8" -async def test_mime_type_with_multiple_parameters() -> None: +async def test_mime_type_with_multiple_parameters(): """Test MIME type with multiple parameters.""" mcp = MCPServer("test") @@ -52,7 +52,7 @@ def data() -> str: assert resources[0].mime_type == "text/plain; charset=utf-8; format=fixed" -async def test_mime_type_preserved_in_read_resource() -> None: +async def test_mime_type_preserved_in_read_resource(): """Test that MIME type with parameters is preserved when reading resource.""" mcp = MCPServer("test") diff --git a/tests/issues/test_176_progress_token.py b/tests/issues/test_176_progress_token.py index 21cad62b6..5d5f8b8fc 100644 --- a/tests/issues/test_176_progress_token.py +++ b/tests/issues/test_176_progress_token.py @@ -9,7 +9,7 @@ pytestmark = pytest.mark.anyio -async def test_progress_token_zero_first_call() -> None: +async def test_progress_token_zero_first_call(): """Test that progress notifications work when progress_token is 0 on first call.""" # Create mock session with progress notification tracking diff --git a/tests/issues/test_188_concurrency.py b/tests/issues/test_188_concurrency.py index 61d113341..0e11f6148 100644 --- a/tests/issues/test_188_concurrency.py +++ b/tests/issues/test_188_concurrency.py @@ -6,14 +6,14 @@ @pytest.mark.anyio -async def test_messages_are_executed_concurrently_tools() -> None: +async def test_messages_are_executed_concurrently_tools(): server = MCPServer("test") event = anyio.Event() tool_started = anyio.Event() call_order: list[str] = [] @server.tool("sleep") - async def sleep_tool() -> str: + async def sleep_tool(): call_order.append("waiting_for_event") tool_started.set() await event.wait() @@ -21,7 +21,7 @@ async def sleep_tool() -> str: return "done" @server.tool("trigger") - async def trigger() -> str: + async def trigger(): # Wait for tool to start before setting the event await tool_started.wait() call_order.append("trigger_started") @@ -47,14 +47,14 @@ async def trigger() -> str: @pytest.mark.anyio -async def test_messages_are_executed_concurrently_tools_and_resources() -> None: +async def test_messages_are_executed_concurrently_tools_and_resources(): server = MCPServer("test") event = anyio.Event() tool_started = anyio.Event() call_order: list[str] = [] @server.tool("sleep") - async def sleep_tool() -> str: + async def sleep_tool(): call_order.append("waiting_for_event") tool_started.set() await event.wait() @@ -62,7 +62,7 @@ async def sleep_tool() -> str: return "done" @server.resource("slow://slow_resource") - async def slow_resource() -> str: + async def slow_resource(): # Wait for tool to start before setting the event await tool_started.wait() event.set() diff --git a/tests/issues/test_192_request_id.py b/tests/issues/test_192_request_id.py index d3c7d9497..de96dbe23 100644 --- a/tests/issues/test_192_request_id.py +++ b/tests/issues/test_192_request_id.py @@ -28,7 +28,7 @@ async def test_request_id_match() -> None: server_writer, server_reader = anyio.create_memory_object_stream[SessionMessage | Exception](1) # Server task to process the request - async def run_server() -> None: + async def run_server(): async with client_reader, server_writer: await server.run( client_reader, diff --git a/tests/issues/test_342_base64_encoding.py b/tests/issues/test_342_base64_encoding.py index 8294c82eb..2bccedf8d 100644 --- a/tests/issues/test_342_base64_encoding.py +++ b/tests/issues/test_342_base64_encoding.py @@ -15,7 +15,7 @@ pytestmark = pytest.mark.anyio -async def test_server_base64_encoding() -> None: +async def test_server_base64_encoding(): """Tests that binary resource data round-trips correctly through base64 encoding. The test uses binary data that produces different results with urlsafe vs standard diff --git a/tests/issues/test_355_type_error.py b/tests/issues/test_355_type_error.py index 29cb03d7a..905cf7eee 100644 --- a/tests/issues/test_355_type_error.py +++ b/tests/issues/test_355_type_error.py @@ -7,13 +7,13 @@ class Database: # Replace with your actual DB type @classmethod - async def connect(cls) -> "Database": # pragma: no cover + async def connect(cls): # pragma: no cover return cls() - async def disconnect(self) -> None: # pragma: no cover + async def disconnect(self): # pragma: no cover pass - def query(self) -> str: # pragma: no cover + def query(self): # pragma: no cover return "Hello, World!" diff --git a/tests/issues/test_552_windows_hang.py b/tests/issues/test_552_windows_hang.py index f254e2183..1adb5d80c 100644 --- a/tests/issues/test_552_windows_hang.py +++ b/tests/issues/test_552_windows_hang.py @@ -12,7 +12,7 @@ @pytest.mark.skipif(sys.platform != "win32", reason="Windows-specific test") # pragma: no cover @pytest.mark.anyio -async def test_windows_stdio_client_with_session() -> None: +async def test_windows_stdio_client_with_session(): """Test the exact scenario from issue #552: Using ClientSession with stdio_client. This reproduces the original bug report where stdio_client hangs on Windows 11 diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index 92d2ee3d6..6b593d2a5 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -16,7 +16,7 @@ @pytest.mark.anyio -async def test_notification_validation_error(tmp_path: Path) -> None: +async def test_notification_validation_error(tmp_path: Path): """Test that timeouts are handled gracefully and don't break the server. This test verifies that when a client request times out: @@ -67,7 +67,7 @@ async def server_handler( read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], write_stream: MemoryObjectSendStream[SessionMessage], task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED, - ) -> None: + ): with anyio.CancelScope() as scope: task_status.started(scope) # type: ignore await server.run( @@ -81,7 +81,7 @@ async def client( read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], write_stream: MemoryObjectSendStream[SessionMessage], scope: anyio.CancelScope, - ) -> None: + ): # No session-level timeout to avoid race conditions with fast operations async with ClientSession(read_stream, write_stream) as session: await session.initialize() diff --git a/tests/issues/test_973_url_decoding.py b/tests/issues/test_973_url_decoding.py index 65214c61d..01cf222b9 100644 --- a/tests/issues/test_973_url_decoding.py +++ b/tests/issues/test_973_url_decoding.py @@ -6,7 +6,7 @@ from mcp.server.mcpserver.resources import ResourceTemplate -def test_template_matches_decodes_space() -> None: +def test_template_matches_decodes_space(): """Test that %20 is decoded to space.""" def search(query: str) -> str: # pragma: no cover @@ -23,7 +23,7 @@ def search(query: str) -> str: # pragma: no cover assert params["query"] == "hello world" -def test_template_matches_decodes_accented_characters() -> None: +def test_template_matches_decodes_accented_characters(): """Test that %C3%A9 is decoded to e with accent.""" def search(query: str) -> str: # pragma: no cover @@ -40,7 +40,7 @@ def search(query: str) -> str: # pragma: no cover assert params["query"] == "café" -def test_template_matches_decodes_complex_phrase() -> None: +def test_template_matches_decodes_complex_phrase(): """Test complex French phrase from the original issue.""" def search(query: str) -> str: # pragma: no cover @@ -57,7 +57,7 @@ def search(query: str) -> str: # pragma: no cover assert params["query"] == "stick correcteur teinté anti-imperfections" -def test_template_matches_preserves_plus_sign() -> None: +def test_template_matches_preserves_plus_sign(): """Test that plus sign remains as plus (not converted to space). In URI encoding, %20 is space. Plus-as-space is only for diff --git a/tests/issues/test_malformed_input.py b/tests/issues/test_malformed_input.py index 3c40f8157..da586f309 100644 --- a/tests/issues/test_malformed_input.py +++ b/tests/issues/test_malformed_input.py @@ -11,7 +11,7 @@ @pytest.mark.anyio -async def test_malformed_initialize_request_does_not_crash_server() -> None: +async def test_malformed_initialize_request_does_not_crash_server(): """Test that malformed initialize requests return proper error responses instead of crashing the server (HackerOne #3156202). """ @@ -91,7 +91,7 @@ async def test_malformed_initialize_request_does_not_crash_server() -> None: @pytest.mark.anyio -async def test_multiple_concurrent_malformed_requests() -> None: +async def test_multiple_concurrent_malformed_requests(): """Test that multiple concurrent malformed requests don't crash the server.""" # Create in-memory streams for testing read_send_stream, read_receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](100) diff --git a/tests/server/auth/middleware/test_auth_context.py b/tests/server/auth/middleware/test_auth_context.py index a844bc368..66481bcf7 100644 --- a/tests/server/auth/middleware/test_auth_context.py +++ b/tests/server/auth/middleware/test_auth_context.py @@ -17,7 +17,7 @@ class MockApp: """Mock ASGI app for testing.""" - def __init__(self) -> None: + def __init__(self): self.called = False self.scope: Scope | None = None self.receive: Receive | None = None @@ -45,7 +45,7 @@ def valid_access_token() -> AccessToken: @pytest.mark.anyio -async def test_auth_context_middleware_with_authenticated_user(valid_access_token: AccessToken) -> None: +async def test_auth_context_middleware_with_authenticated_user(valid_access_token: AccessToken): """Test middleware with an authenticated user in scope.""" app = MockApp() middleware = AuthContextMiddleware(app) @@ -84,7 +84,7 @@ async def send(message: Message) -> None: # pragma: no cover @pytest.mark.anyio -async def test_auth_context_middleware_with_no_user() -> None: +async def test_auth_context_middleware_with_no_user(): """Test middleware with no user in scope.""" app = MockApp() middleware = AuthContextMiddleware(app) diff --git a/tests/server/auth/middleware/test_bearer_auth.py b/tests/server/auth/middleware/test_bearer_auth.py index 1baf5818b..bd14e294c 100644 --- a/tests/server/auth/middleware/test_bearer_auth.py +++ b/tests/server/auth/middleware/test_bearer_auth.py @@ -20,7 +20,7 @@ class MockOAuthProvider: the BearerAuthMiddleware components. """ - def __init__(self) -> None: + def __init__(self): self.tokens: dict[str, AccessToken] = {} # token -> AccessToken def add_token(self, token: str, access_token: AccessToken) -> None: @@ -49,7 +49,7 @@ def add_token_to_provider( class MockApp: """Mock ASGI app for testing.""" - def __init__(self) -> None: + def __init__(self): self.called = False self.scope: Scope | None = None self.receive: Receive | None = None @@ -106,16 +106,14 @@ def no_expiry_access_token() -> AccessToken: class TestBearerAuthBackend: """Tests for the BearerAuthBackend class.""" - async def test_no_auth_header(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]) -> None: + async def test_no_auth_header(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]): """Test authentication with no Authorization header.""" backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) request = Request({"type": "http", "headers": []}) result = await backend.authenticate(request) assert result is None - async def test_non_bearer_auth_header( - self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any] - ) -> None: + async def test_non_bearer_auth_header(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]): """Test authentication with non-Bearer Authorization header.""" backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) request = Request( @@ -127,7 +125,7 @@ async def test_non_bearer_auth_header( result = await backend.authenticate(request) assert result is None - async def test_invalid_token(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]) -> None: + async def test_invalid_token(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]): """Test authentication with invalid token.""" backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) request = Request( @@ -143,7 +141,7 @@ async def test_expired_token( self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any], expired_access_token: AccessToken, - ) -> None: + ): """Test authentication with expired token.""" backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) add_token_to_provider(mock_oauth_provider, "expired_token", expired_access_token) @@ -160,7 +158,7 @@ async def test_valid_token( self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any], valid_access_token: AccessToken, - ) -> None: + ): """Test authentication with valid token.""" backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token) @@ -184,7 +182,7 @@ async def test_token_without_expiry( self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any], no_expiry_access_token: AccessToken, - ) -> None: + ): """Test authentication with token that has no expiry.""" backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) add_token_to_provider(mock_oauth_provider, "no_expiry_token", no_expiry_access_token) @@ -208,7 +206,7 @@ async def test_lowercase_bearer_prefix( self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any], valid_access_token: AccessToken, - ) -> None: + ): """Test with lowercase 'bearer' prefix in Authorization header""" backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token) @@ -228,7 +226,7 @@ async def test_mixed_case_bearer_prefix( self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any], valid_access_token: AccessToken, - ) -> None: + ): """Test with mixed 'BeArEr' prefix in Authorization header""" backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token) @@ -248,7 +246,7 @@ async def test_mixed_case_authorization_header( self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any], valid_access_token: AccessToken, - ) -> None: + ): """Test authentication with mixed 'Authorization' header.""" backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token) @@ -269,7 +267,7 @@ async def test_mixed_case_authorization_header( class TestRequireAuthMiddleware: """Tests for the RequireAuthMiddleware class.""" - async def test_no_user(self) -> None: + async def test_no_user(self): """Test middleware with no user in scope.""" app = MockApp() middleware = RequireAuthMiddleware(app, required_scopes=["read"]) @@ -293,7 +291,7 @@ async def send(message: Message) -> None: assert any(h[0] == b"www-authenticate" for h in sent_messages[0]["headers"]) assert not app.called - async def test_non_authenticated_user(self) -> None: + async def test_non_authenticated_user(self): """Test middleware with non-authenticated user in scope.""" app = MockApp() middleware = RequireAuthMiddleware(app, required_scopes=["read"]) @@ -317,7 +315,7 @@ async def send(message: Message) -> None: assert any(h[0] == b"www-authenticate" for h in sent_messages[0]["headers"]) assert not app.called - async def test_missing_required_scope(self, valid_access_token: AccessToken) -> None: + async def test_missing_required_scope(self, valid_access_token: AccessToken): """Test middleware with user missing required scope.""" app = MockApp() middleware = RequireAuthMiddleware(app, required_scopes=["admin"]) @@ -346,7 +344,7 @@ async def send(message: Message) -> None: assert any(h[0] == b"www-authenticate" for h in sent_messages[0]["headers"]) assert not app.called - async def test_no_auth_credentials(self, valid_access_token: AccessToken) -> None: + async def test_no_auth_credentials(self, valid_access_token: AccessToken): """Test middleware with no auth credentials in scope.""" app = MockApp() middleware = RequireAuthMiddleware(app, required_scopes=["read"]) @@ -374,7 +372,7 @@ async def send(message: Message) -> None: assert any(h[0] == b"www-authenticate" for h in sent_messages[0]["headers"]) assert not app.called - async def test_has_required_scopes(self, valid_access_token: AccessToken) -> None: + async def test_has_required_scopes(self, valid_access_token: AccessToken): """Test middleware with user having all required scopes.""" app = MockApp() middleware = RequireAuthMiddleware(app, required_scopes=["read"]) @@ -399,7 +397,7 @@ async def send(message: Message) -> None: # pragma: no cover assert app.receive == receive assert app.send == send - async def test_multiple_required_scopes(self, valid_access_token: AccessToken) -> None: + async def test_multiple_required_scopes(self, valid_access_token: AccessToken): """Test middleware with multiple required scopes.""" app = MockApp() middleware = RequireAuthMiddleware(app, required_scopes=["read", "write"]) @@ -424,7 +422,7 @@ async def send(message: Message) -> None: # pragma: no cover assert app.receive == receive assert app.send == send - async def test_no_required_scopes(self, valid_access_token: AccessToken) -> None: + async def test_no_required_scopes(self, valid_access_token: AccessToken): """Test middleware with no required scopes.""" app = MockApp() middleware = RequireAuthMiddleware(app, required_scopes=[]) diff --git a/tests/server/auth/test_error_handling.py b/tests/server/auth/test_error_handling.py index 977106078..7c5c43582 100644 --- a/tests/server/auth/test_error_handling.py +++ b/tests/server/auth/test_error_handling.py @@ -20,13 +20,13 @@ @pytest.fixture -def oauth_provider() -> MockOAuthProvider: +def oauth_provider(): """Return a MockOAuthProvider instance that can be configured to raise errors.""" return MockOAuthProvider() @pytest.fixture -def app(oauth_provider: MockOAuthProvider) -> Starlette: +def app(oauth_provider: MockOAuthProvider): # Enable client registration client_registration_options = ClientRegistrationOptions(enabled=True) revocation_options = RevocationOptions(enabled=True) @@ -44,14 +44,14 @@ def app(oauth_provider: MockOAuthProvider) -> Starlette: @pytest.fixture -def client(app: Starlette) -> httpx.AsyncClient: +def client(app: Starlette): transport = ASGITransport(app=app) # Use base_url without a path since routes are directly on the app return httpx.AsyncClient(transport=transport, base_url="http://localhost") @pytest.fixture -def pkce_challenge() -> dict[str, str]: +def pkce_challenge(): """Create a PKCE challenge with code_verifier and code_challenge.""" # Generate a code verifier code_verifier = secrets.token_urlsafe(64)[:128] @@ -84,7 +84,7 @@ async def registered_client(client: httpx.AsyncClient) -> dict[str, Any]: @pytest.mark.anyio -async def test_registration_error_handling(client: httpx.AsyncClient, oauth_provider: MockOAuthProvider) -> None: +async def test_registration_error_handling(client: httpx.AsyncClient, oauth_provider: MockOAuthProvider): # Mock the register_client method to raise a registration error with unittest.mock.patch.object( oauth_provider, @@ -122,7 +122,7 @@ async def test_authorize_error_handling( oauth_provider: MockOAuthProvider, registered_client: dict[str, Any], pkce_challenge: dict[str, str], -) -> None: +): # Mock the authorize method to raise an authorize error with unittest.mock.patch.object( oauth_provider, @@ -163,7 +163,7 @@ async def test_token_error_handling_auth_code( oauth_provider: MockOAuthProvider, registered_client: dict[str, Any], pkce_challenge: dict[str, str], -) -> None: +): # Register the client and get an auth code client_id = registered_client["client_id"] client_secret = registered_client["client_secret"] @@ -222,7 +222,7 @@ async def test_token_error_handling_refresh_token( oauth_provider: MockOAuthProvider, registered_client: dict[str, Any], pkce_challenge: dict[str, str], -) -> None: +): # Register the client and get tokens client_id = registered_client["client_id"] client_secret = registered_client["client_secret"] diff --git a/tests/server/auth/test_protected_resource.py b/tests/server/auth/test_protected_resource.py index 35fd3de2a..413a80276 100644 --- a/tests/server/auth/test_protected_resource.py +++ b/tests/server/auth/test_protected_resource.py @@ -1,6 +1,5 @@ """Integration tests for MCP Oauth Protected Resource.""" -from collections.abc import AsyncGenerator from urllib.parse import urlparse import httpx @@ -13,7 +12,7 @@ @pytest.fixture -def test_app() -> Starlette: +def test_app(): """Fixture to create protected resource routes for testing.""" # Create the protected resource routes @@ -30,14 +29,14 @@ def test_app() -> Starlette: @pytest.fixture -async def test_client(test_app: Starlette) -> AsyncGenerator[httpx.AsyncClient, None]: +async def test_client(test_app: Starlette): """Fixture to create an HTTP client for the protected resource app.""" async with httpx.AsyncClient(transport=httpx.ASGITransport(app=test_app), base_url="https://mcptest.com") as client: yield client @pytest.mark.anyio -async def test_metadata_endpoint_with_path(test_client: httpx.AsyncClient) -> None: +async def test_metadata_endpoint_with_path(test_client: httpx.AsyncClient): """Test the OAuth 2.0 Protected Resource metadata endpoint for path-based resource.""" # For resource with path "/resource", metadata should be accessible at the path-aware location @@ -55,7 +54,7 @@ async def test_metadata_endpoint_with_path(test_client: httpx.AsyncClient) -> No @pytest.mark.anyio -async def test_metadata_endpoint_root_path_returns_404(test_client: httpx.AsyncClient) -> None: +async def test_metadata_endpoint_root_path_returns_404(test_client: httpx.AsyncClient): """Test that root path returns 404 for path-based resource.""" # Root path should return 404 for path-based resources @@ -64,7 +63,7 @@ async def test_metadata_endpoint_root_path_returns_404(test_client: httpx.AsyncC @pytest.fixture -def root_resource_app() -> Starlette: +def root_resource_app(): """Fixture to create protected resource routes for root-level resource.""" # Create routes for a resource without path component @@ -80,7 +79,7 @@ def root_resource_app() -> Starlette: @pytest.fixture -async def root_resource_client(root_resource_app: Starlette) -> AsyncGenerator[httpx.AsyncClient, None]: +async def root_resource_client(root_resource_app: Starlette): """Fixture to create an HTTP client for the root resource app.""" async with httpx.AsyncClient( transport=httpx.ASGITransport(app=root_resource_app), base_url="https://mcptest.com" @@ -89,7 +88,7 @@ async def root_resource_client(root_resource_app: Starlette) -> AsyncGenerator[h @pytest.mark.anyio -async def test_metadata_endpoint_without_path(root_resource_client: httpx.AsyncClient) -> None: +async def test_metadata_endpoint_without_path(root_resource_client: httpx.AsyncClient): """Test metadata endpoint for root-level resource.""" # For root resource, metadata should be at standard location @@ -109,21 +108,21 @@ async def test_metadata_endpoint_without_path(root_resource_client: httpx.AsyncC # Tests for URL construction utility function -def test_metadata_url_construction_url_without_path() -> None: +def test_metadata_url_construction_url_without_path(): """Test URL construction for resource without path component.""" resource_url = AnyHttpUrl("https://example.com") result = build_resource_metadata_url(resource_url) assert str(result) == "https://example.com/.well-known/oauth-protected-resource" -def test_metadata_url_construction_url_with_path_component() -> None: +def test_metadata_url_construction_url_with_path_component(): """Test URL construction for resource with path component.""" resource_url = AnyHttpUrl("https://example.com/mcp") result = build_resource_metadata_url(resource_url) assert str(result) == "https://example.com/.well-known/oauth-protected-resource/mcp" -def test_metadata_url_construction_url_with_trailing_slash_only() -> None: +def test_metadata_url_construction_url_with_trailing_slash_only(): """Test URL construction for resource with trailing slash only.""" resource_url = AnyHttpUrl("https://example.com/") result = build_resource_metadata_url(resource_url) @@ -140,7 +139,7 @@ def test_metadata_url_construction_url_with_trailing_slash_only() -> None: ("http://localhost:8001/mcp", "http://localhost:8001/.well-known/oauth-protected-resource/mcp"), ], ) -def test_metadata_url_construction_various_resource_configurations(resource_url: str, expected_url: str) -> None: +def test_metadata_url_construction_various_resource_configurations(resource_url: str, expected_url: str): """Test URL construction with various resource configurations.""" result = build_resource_metadata_url(AnyHttpUrl(resource_url)) assert str(result) == expected_url @@ -149,7 +148,7 @@ def test_metadata_url_construction_various_resource_configurations(resource_url: # Tests for consistency between URL generation and route registration -def test_route_consistency_route_path_matches_metadata_url() -> None: +def test_route_consistency_route_path_matches_metadata_url(): """Test that route path matches the generated metadata URL.""" resource_url = AnyHttpUrl("https://example.com/mcp") @@ -178,7 +177,7 @@ def test_route_consistency_route_path_matches_metadata_url() -> None: ("https://example.com/mcp", "/.well-known/oauth-protected-resource/mcp"), ], ) -def test_route_consistency_consistent_paths_for_various_resources(resource_url: str, expected_path: str) -> None: +def test_route_consistency_consistent_paths_for_various_resources(resource_url: str, expected_path: str): """Test that URL generation and route creation are consistent.""" resource_url_obj = AnyHttpUrl(resource_url) diff --git a/tests/server/auth/test_provider.py b/tests/server/auth/test_provider.py index b71b6ff5b..aaaeb413a 100644 --- a/tests/server/auth/test_provider.py +++ b/tests/server/auth/test_provider.py @@ -3,7 +3,7 @@ from mcp.server.auth.provider import construct_redirect_uri -def test_construct_redirect_uri_no_existing_params() -> None: +def test_construct_redirect_uri_no_existing_params(): """Test construct_redirect_uri with no existing query parameters.""" base_uri = "http://localhost:8000/callback" result = construct_redirect_uri(base_uri, code="auth_code", state="test_state") @@ -11,7 +11,7 @@ def test_construct_redirect_uri_no_existing_params() -> None: assert "http://localhost:8000/callback?code=auth_code&state=test_state" == result -def test_construct_redirect_uri_with_existing_params() -> None: +def test_construct_redirect_uri_with_existing_params(): """Test construct_redirect_uri with existing query parameters (regression test for #1279).""" base_uri = "http://localhost:8000/callback?session_id=1234" result = construct_redirect_uri(base_uri, code="auth_code", state="test_state") @@ -23,7 +23,7 @@ def test_construct_redirect_uri_with_existing_params() -> None: assert result.startswith("http://localhost:8000/callback?") -def test_construct_redirect_uri_multiple_existing_params() -> None: +def test_construct_redirect_uri_multiple_existing_params(): """Test construct_redirect_uri with multiple existing query parameters.""" base_uri = "http://localhost:8000/callback?session_id=1234&user=test" result = construct_redirect_uri(base_uri, code="auth_code") @@ -33,7 +33,7 @@ def test_construct_redirect_uri_multiple_existing_params() -> None: assert "code=auth_code" in result -def test_construct_redirect_uri_with_none_values() -> None: +def test_construct_redirect_uri_with_none_values(): """Test construct_redirect_uri filters out None values.""" base_uri = "http://localhost:8000/callback" result = construct_redirect_uri(base_uri, code="auth_code", state=None) @@ -42,7 +42,7 @@ def test_construct_redirect_uri_with_none_values() -> None: assert "state" not in result -def test_construct_redirect_uri_empty_params() -> None: +def test_construct_redirect_uri_empty_params(): """Test construct_redirect_uri with no additional parameters.""" base_uri = "http://localhost:8000/callback?existing=param" result = construct_redirect_uri(base_uri) @@ -50,7 +50,7 @@ def test_construct_redirect_uri_empty_params() -> None: assert result == "http://localhost:8000/callback?existing=param" -def test_construct_redirect_uri_duplicate_param_names() -> None: +def test_construct_redirect_uri_duplicate_param_names(): """Test construct_redirect_uri when adding param that already exists.""" base_uri = "http://localhost:8000/callback?code=existing" result = construct_redirect_uri(base_uri, code="new_code") @@ -60,7 +60,7 @@ def test_construct_redirect_uri_duplicate_param_names() -> None: assert "code=new_code" in result -def test_construct_redirect_uri_multivalued_existing_params() -> None: +def test_construct_redirect_uri_multivalued_existing_params(): """Test construct_redirect_uri with existing multi-valued parameters.""" base_uri = "http://localhost:8000/callback?scope=read&scope=write" result = construct_redirect_uri(base_uri, code="auth_code") @@ -70,7 +70,7 @@ def test_construct_redirect_uri_multivalued_existing_params() -> None: assert "code=auth_code" in result -def test_construct_redirect_uri_encoded_values() -> None: +def test_construct_redirect_uri_encoded_values(): """Test construct_redirect_uri handles URL encoding properly.""" base_uri = "http://localhost:8000/callback" result = construct_redirect_uri(base_uri, state="test state with spaces") diff --git a/tests/server/auth/test_routes.py b/tests/server/auth/test_routes.py index a910edc8b..3d13b5ba5 100644 --- a/tests/server/auth/test_routes.py +++ b/tests/server/auth/test_routes.py @@ -4,44 +4,44 @@ from mcp.server.auth.routes import validate_issuer_url -def test_validate_issuer_url_https_allowed() -> None: +def test_validate_issuer_url_https_allowed(): validate_issuer_url(AnyHttpUrl("https://example.com/path")) -def test_validate_issuer_url_http_localhost_allowed() -> None: +def test_validate_issuer_url_http_localhost_allowed(): validate_issuer_url(AnyHttpUrl("http://localhost:8080/path")) -def test_validate_issuer_url_http_127_0_0_1_allowed() -> None: +def test_validate_issuer_url_http_127_0_0_1_allowed(): validate_issuer_url(AnyHttpUrl("http://127.0.0.1:8080/path")) -def test_validate_issuer_url_http_ipv6_loopback_allowed() -> None: +def test_validate_issuer_url_http_ipv6_loopback_allowed(): validate_issuer_url(AnyHttpUrl("http://[::1]:8080/path")) -def test_validate_issuer_url_http_non_loopback_rejected() -> None: +def test_validate_issuer_url_http_non_loopback_rejected(): with pytest.raises(ValueError, match="Issuer URL must be HTTPS"): validate_issuer_url(AnyHttpUrl("http://evil.com/path")) -def test_validate_issuer_url_http_127_prefix_domain_rejected() -> None: +def test_validate_issuer_url_http_127_prefix_domain_rejected(): """A domain like 127.0.0.1.evil.com is not loopback.""" with pytest.raises(ValueError, match="Issuer URL must be HTTPS"): validate_issuer_url(AnyHttpUrl("http://127.0.0.1.evil.com/path")) -def test_validate_issuer_url_http_127_prefix_subdomain_rejected() -> None: +def test_validate_issuer_url_http_127_prefix_subdomain_rejected(): """A domain like 127.0.0.1something.example.com is not loopback.""" with pytest.raises(ValueError, match="Issuer URL must be HTTPS"): validate_issuer_url(AnyHttpUrl("http://127.0.0.1something.example.com/path")) -def test_validate_issuer_url_fragment_rejected() -> None: +def test_validate_issuer_url_fragment_rejected(): with pytest.raises(ValueError, match="fragment"): validate_issuer_url(AnyHttpUrl("https://example.com/path#frag")) -def test_validate_issuer_url_query_rejected() -> None: +def test_validate_issuer_url_query_rejected(): with pytest.raises(ValueError, match="query"): validate_issuer_url(AnyHttpUrl("https://example.com/path?q=1")) diff --git a/tests/server/lowlevel/test_helper_types.py b/tests/server/lowlevel/test_helper_types.py index d1a8ca0a4..e29273d3f 100644 --- a/tests/server/lowlevel/test_helper_types.py +++ b/tests/server/lowlevel/test_helper_types.py @@ -10,7 +10,7 @@ from mcp.server.lowlevel.helper_types import ReadResourceContents -def test_read_resource_contents_with_metadata() -> None: +def test_read_resource_contents_with_metadata(): """Test that ReadResourceContents accepts meta parameter. ReadResourceContents is an internal helper type used by the low-level MCP server. @@ -33,7 +33,7 @@ def test_read_resource_contents_with_metadata() -> None: assert contents.meta["cached"] is True -def test_read_resource_contents_without_metadata() -> None: +def test_read_resource_contents_without_metadata(): """Test that ReadResourceContents meta defaults to None.""" # Ensures backward compatibility - meta defaults to None, _meta omitted from protocol (helper_types.py:11) contents = ReadResourceContents( @@ -44,7 +44,7 @@ def test_read_resource_contents_without_metadata() -> None: assert contents.meta is None -def test_read_resource_contents_with_bytes() -> None: +def test_read_resource_contents_with_bytes(): """Test that ReadResourceContents works with bytes content and meta.""" # Verifies meta works with both str and bytes content (binary resources like images, PDFs) metadata = {"encoding": "utf-8"} diff --git a/tests/server/mcpserver/auth/test_auth_integration.py b/tests/server/mcpserver/auth/test_auth_integration.py index 66ce88b2d..602f5cc75 100644 --- a/tests/server/mcpserver/auth/test_auth_integration.py +++ b/tests/server/mcpserver/auth/test_auth_integration.py @@ -5,7 +5,6 @@ import secrets import time import unittest.mock -from collections.abc import AsyncGenerator from typing import Any from urllib.parse import parse_qs, urlparse @@ -29,7 +28,7 @@ # Mock OAuth provider for testing class MockOAuthProvider(OAuthAuthorizationServerProvider[AuthorizationCode, RefreshToken, AccessToken]): - def __init__(self) -> None: + def __init__(self): self.clients: dict[str, OAuthClientInformationFull] = {} self.auth_codes: dict[str, AuthorizationCode] = {} # code -> {client_id, code_challenge, redirect_uri} self.tokens: dict[str, AccessToken] = {} # token -> {client_id, scopes, expires_at} @@ -38,7 +37,7 @@ def __init__(self) -> None: async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: return self.clients.get(client_id) - async def register_client(self, client_info: OAuthClientInformationFull) -> None: + async def register_client(self, client_info: OAuthClientInformationFull): assert client_info.client_id is not None self.clients[client_info.client_id] = client_info @@ -189,12 +188,12 @@ async def revoke_token(self, token: AccessToken | RefreshToken) -> None: @pytest.fixture -def mock_oauth_provider() -> MockOAuthProvider: +def mock_oauth_provider(): return MockOAuthProvider() @pytest.fixture -def auth_app(mock_oauth_provider: MockOAuthProvider) -> Starlette: +def auth_app(mock_oauth_provider: MockOAuthProvider): # Create auth router auth_routes = create_auth_routes( mock_oauth_provider, @@ -215,7 +214,7 @@ def auth_app(mock_oauth_provider: MockOAuthProvider) -> Starlette: @pytest.fixture -async def test_client(auth_app: Starlette) -> AsyncGenerator[httpx.AsyncClient, None]: +async def test_client(auth_app: Starlette): async with httpx.AsyncClient(transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com") as client: yield client @@ -250,7 +249,7 @@ async def registered_client( @pytest.fixture -def pkce_challenge() -> dict[str, str]: +def pkce_challenge(): """Create a PKCE challenge with code_verifier and code_challenge.""" code_verifier = "some_random_verifier_string" code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).decode().rstrip("=") @@ -264,7 +263,7 @@ async def auth_code( registered_client: dict[str, Any], pkce_challenge: dict[str, str], request: pytest.FixtureRequest, -) -> dict[str, str | None]: +): """Get an authorization code. Parameters can be customized via indirect parameterization: @@ -306,7 +305,7 @@ async def auth_code( class TestAuthEndpoints: @pytest.mark.anyio - async def test_metadata_endpoint(self, test_client: httpx.AsyncClient) -> None: + async def test_metadata_endpoint(self, test_client: httpx.AsyncClient): """Test the OAuth 2.0 metadata endpoint.""" response = await test_client.get("/.well-known/oauth-authorization-server") @@ -328,7 +327,7 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient) -> None: assert metadata["service_documentation"] == "https://docs.example.com/" @pytest.mark.anyio - async def test_token_validation_error(self, test_client: httpx.AsyncClient) -> None: + async def test_token_validation_error(self, test_client: httpx.AsyncClient): """Test token endpoint error - validation error.""" # Missing required fields response = await test_client.post( @@ -351,7 +350,7 @@ async def test_token_invalid_client_secret_returns_invalid_client( registered_client: dict[str, Any], pkce_challenge: dict[str, str], mock_oauth_provider: MockOAuthProvider, - ) -> None: + ): """Test token endpoint returns 'invalid_client' for wrong client_secret per RFC 6749. RFC 6749 Section 5.2 defines: @@ -398,7 +397,7 @@ async def test_token_invalid_auth_code( test_client: httpx.AsyncClient, registered_client: dict[str, Any], pkce_challenge: dict[str, str], - ) -> None: + ): """Test token endpoint error - authorization code does not exist.""" # Try to use a non-existent authorization code response = await test_client.post( @@ -426,7 +425,7 @@ async def test_token_expired_auth_code( auth_code: dict[str, str], pkce_challenge: dict[str, str], mock_oauth_provider: MockOAuthProvider, - ) -> None: + ): """Test token endpoint error - authorization code has expired.""" # Get the current time for our time mocking current_time = time.time() @@ -480,7 +479,7 @@ async def test_token_redirect_uri_mismatch( registered_client: dict[str, Any], auth_code: dict[str, str], pkce_challenge: dict[str, str], - ) -> None: + ): """Test token endpoint error - redirect URI mismatch.""" # Try to use the code with a different redirect URI response = await test_client.post( @@ -503,7 +502,7 @@ async def test_token_redirect_uri_mismatch( @pytest.mark.anyio async def test_token_code_verifier_mismatch( self, test_client: httpx.AsyncClient, registered_client: dict[str, Any], auth_code: dict[str, str] - ) -> None: + ): """Test token endpoint error - PKCE code verifier mismatch.""" # Try to use the code with an incorrect code verifier response = await test_client.post( @@ -524,9 +523,7 @@ async def test_token_code_verifier_mismatch( assert "incorrect code_verifier" in error_response["error_description"] @pytest.mark.anyio - async def test_token_invalid_refresh_token( - self, test_client: httpx.AsyncClient, registered_client: dict[str, Any] - ) -> None: + async def test_token_invalid_refresh_token(self, test_client: httpx.AsyncClient, registered_client: dict[str, Any]): """Test token endpoint error - refresh token does not exist.""" # Try to use a non-existent refresh token response = await test_client.post( @@ -550,7 +547,7 @@ async def test_token_expired_refresh_token( registered_client: dict[str, Any], auth_code: dict[str, str], pkce_challenge: dict[str, str], - ) -> None: + ): """Test token endpoint error - refresh token has expired.""" # Step 1: First, let's create a token and refresh token at the current time current_time = time.time() @@ -598,7 +595,7 @@ async def test_token_invalid_scope( registered_client: dict[str, Any], auth_code: dict[str, str], pkce_challenge: dict[str, str], - ) -> None: + ): """Test token endpoint error - invalid scope in refresh token request.""" # Exchange authorization code for tokens token_response = await test_client.post( @@ -634,9 +631,7 @@ async def test_token_invalid_scope( assert "cannot request scope" in error_response["error_description"] @pytest.mark.anyio - async def test_client_registration( - self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider - ) -> None: + async def test_client_registration(self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider): """Test client registration.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -662,7 +657,7 @@ async def test_client_registration( # ) is not None @pytest.mark.anyio - async def test_client_registration_missing_required_fields(self, test_client: httpx.AsyncClient) -> None: + async def test_client_registration_missing_required_fields(self, test_client: httpx.AsyncClient): """Test client registration with missing required fields.""" # Missing redirect_uris which is a required field client_metadata = { @@ -681,7 +676,7 @@ async def test_client_registration_missing_required_fields(self, test_client: ht assert error_data["error_description"] == "redirect_uris: Field required" @pytest.mark.anyio - async def test_client_registration_invalid_uri(self, test_client: httpx.AsyncClient) -> None: + async def test_client_registration_invalid_uri(self, test_client: httpx.AsyncClient): """Test client registration with invalid URIs.""" # Invalid redirect_uri format client_metadata = { @@ -702,7 +697,7 @@ async def test_client_registration_invalid_uri(self, test_client: httpx.AsyncCli ) @pytest.mark.anyio - async def test_client_registration_empty_redirect_uris(self, test_client: httpx.AsyncClient) -> None: + async def test_client_registration_empty_redirect_uris(self, test_client: httpx.AsyncClient): """Test client registration with empty redirect_uris array.""" redirect_uris: list[str] = [] client_metadata = { @@ -723,7 +718,7 @@ async def test_client_registration_empty_redirect_uris(self, test_client: httpx. ) @pytest.mark.anyio - async def test_authorize_form_post(self, test_client: httpx.AsyncClient, pkce_challenge: dict[str, str]) -> None: + async def test_authorize_form_post(self, test_client: httpx.AsyncClient, pkce_challenge: dict[str, str]): """Test the authorization endpoint using POST with form-encoded data.""" # Register a client client_metadata = { @@ -767,7 +762,7 @@ async def test_authorization_get( test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str], - ) -> None: + ): """Test the full authorization flow.""" # 1. Register a client client_metadata = { @@ -872,9 +867,7 @@ async def test_authorization_get( assert await mock_oauth_provider.load_access_token(new_token_response["access_token"]) is None @pytest.mark.anyio - async def test_revoke_invalid_token( - self, test_client: httpx.AsyncClient, registered_client: dict[str, Any] - ) -> None: + async def test_revoke_invalid_token(self, test_client: httpx.AsyncClient, registered_client: dict[str, Any]): """Test revoking an invalid token.""" response = await test_client.post( "/revoke", @@ -888,9 +881,7 @@ async def test_revoke_invalid_token( assert response.status_code == 200 @pytest.mark.anyio - async def test_revoke_with_malformed_token( - self, test_client: httpx.AsyncClient, registered_client: dict[str, Any] - ) -> None: + async def test_revoke_with_malformed_token(self, test_client: httpx.AsyncClient, registered_client: dict[str, Any]): response = await test_client.post( "/revoke", data={ @@ -906,7 +897,7 @@ async def test_revoke_with_malformed_token( assert "token_type_hint" in error_response["error_description"] @pytest.mark.anyio - async def test_client_registration_disallowed_scopes(self, test_client: httpx.AsyncClient) -> None: + async def test_client_registration_disallowed_scopes(self, test_client: httpx.AsyncClient): """Test client registration with scopes that are not allowed.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -925,7 +916,7 @@ async def test_client_registration_disallowed_scopes(self, test_client: httpx.As @pytest.mark.anyio async def test_client_registration_default_scopes( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider - ) -> None: + ): client_metadata = { "redirect_uris": ["https://client.example.com/callback"], "client_name": "Test Client", @@ -947,7 +938,7 @@ async def test_client_registration_default_scopes( assert registered_client.scope == "read write" @pytest.mark.anyio - async def test_client_registration_with_authorization_code_only(self, test_client: httpx.AsyncClient) -> None: + async def test_client_registration_with_authorization_code_only(self, test_client: httpx.AsyncClient): """Test that registration succeeds with only authorization_code (refresh_token is optional per RFC 7591).""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -962,7 +953,7 @@ async def test_client_registration_with_authorization_code_only(self, test_clien assert client_info["grant_types"] == ["authorization_code"] @pytest.mark.anyio - async def test_client_registration_missing_authorization_code(self, test_client: httpx.AsyncClient) -> None: + async def test_client_registration_missing_authorization_code(self, test_client: httpx.AsyncClient): """Test that registration fails when authorization_code grant type is missing.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -978,7 +969,7 @@ async def test_client_registration_missing_authorization_code(self, test_client: assert error_data["error_description"] == "grant_types must include 'authorization_code'" @pytest.mark.anyio - async def test_client_registration_with_additional_grant_type(self, test_client: httpx.AsyncClient) -> None: + async def test_client_registration_with_additional_grant_type(self, test_client: httpx.AsyncClient): client_metadata = { "redirect_uris": ["https://client.example.com/callback"], "client_name": "Test Client", @@ -997,7 +988,7 @@ async def test_client_registration_with_additional_grant_type(self, test_client: @pytest.mark.anyio async def test_client_registration_with_additional_response_types( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider - ) -> None: + ): """Test that registration accepts additional response_types values alongside 'code'.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -1015,7 +1006,7 @@ async def test_client_registration_with_additional_response_types( assert "code" in client.response_types @pytest.mark.anyio - async def test_client_registration_response_types_without_code(self, test_client: httpx.AsyncClient) -> None: + async def test_client_registration_response_types_without_code(self, test_client: httpx.AsyncClient): """Test that registration rejects response_types that don't include 'code'.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -1034,7 +1025,7 @@ async def test_client_registration_response_types_without_code(self, test_client @pytest.mark.anyio async def test_client_registration_default_response_types( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider - ) -> None: + ): """Test that registration uses default response_types of ['code'] when not specified.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -1053,7 +1044,7 @@ async def test_client_registration_default_response_types( @pytest.mark.anyio async def test_client_secret_basic_authentication( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] - ) -> None: + ): """Test that client_secret_basic authentication works correctly.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -1099,7 +1090,7 @@ async def test_client_secret_basic_authentication( @pytest.mark.anyio async def test_wrong_auth_method_without_valid_credentials_fails( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] - ) -> None: + ): """Test that using the wrong authentication method fails when credentials are missing.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -1151,7 +1142,7 @@ async def test_wrong_auth_method_without_valid_credentials_fails( @pytest.mark.anyio async def test_basic_auth_without_header_fails( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] - ) -> None: + ): """Test that omitting Basic auth when client_secret_basic is registered fails.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -1196,7 +1187,7 @@ async def test_basic_auth_without_header_fails( @pytest.mark.anyio async def test_basic_auth_invalid_base64_fails( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] - ) -> None: + ): """Test that invalid base64 in Basic auth header fails.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -1241,7 +1232,7 @@ async def test_basic_auth_invalid_base64_fails( @pytest.mark.anyio async def test_basic_auth_no_colon_fails( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] - ) -> None: + ): """Test that Basic auth without colon separator fails.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -1287,7 +1278,7 @@ async def test_basic_auth_no_colon_fails( @pytest.mark.anyio async def test_basic_auth_client_id_mismatch_fails( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] - ) -> None: + ): """Test that client_id mismatch between body and Basic auth fails.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -1333,7 +1324,7 @@ async def test_basic_auth_client_id_mismatch_fails( @pytest.mark.anyio async def test_none_auth_method_public_client( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] - ) -> None: + ): """Test that 'none' authentication method works for public clients.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -1380,9 +1371,7 @@ class TestAuthorizeEndpointErrors: """Test error handling in the OAuth authorization endpoint.""" @pytest.mark.anyio - async def test_authorize_missing_client_id( - self, test_client: httpx.AsyncClient, pkce_challenge: dict[str, str] - ) -> None: + async def test_authorize_missing_client_id(self, test_client: httpx.AsyncClient, pkce_challenge: dict[str, str]): """Test authorization endpoint with missing client_id. According to the OAuth2.0 spec, if client_id is missing, the server should @@ -1406,9 +1395,7 @@ async def test_authorize_missing_client_id( assert "client_id" in response.text.lower() @pytest.mark.anyio - async def test_authorize_invalid_client_id( - self, test_client: httpx.AsyncClient, pkce_challenge: dict[str, str] - ) -> None: + async def test_authorize_invalid_client_id(self, test_client: httpx.AsyncClient, pkce_challenge: dict[str, str]): """Test authorization endpoint with invalid client_id. According to the OAuth2.0 spec, if client_id is invalid, the server should @@ -1434,7 +1421,7 @@ async def test_authorize_invalid_client_id( @pytest.mark.anyio async def test_authorize_missing_redirect_uri( self, test_client: httpx.AsyncClient, registered_client: dict[str, Any], pkce_challenge: dict[str, str] - ) -> None: + ): """Test authorization endpoint with missing redirect_uri. If client has only one registered redirect_uri, it can be omitted. @@ -1460,7 +1447,7 @@ async def test_authorize_missing_redirect_uri( @pytest.mark.anyio async def test_authorize_invalid_redirect_uri( self, test_client: httpx.AsyncClient, registered_client: dict[str, Any], pkce_challenge: dict[str, str] - ) -> None: + ): """Test authorization endpoint with invalid redirect_uri. According to the OAuth2.0 spec, if redirect_uri is invalid or doesn't match, @@ -1500,7 +1487,7 @@ async def test_authorize_invalid_redirect_uri( ) async def test_authorize_missing_redirect_uri_multiple_registered( self, test_client: httpx.AsyncClient, registered_client: dict[str, Any], pkce_challenge: dict[str, str] - ) -> None: + ): """Test endpoint with missing redirect_uri with multiple registered URIs. If client has multiple registered redirect_uris, redirect_uri must be provided. @@ -1526,7 +1513,7 @@ async def test_authorize_missing_redirect_uri_multiple_registered( @pytest.mark.anyio async def test_authorize_unsupported_response_type( self, test_client: httpx.AsyncClient, registered_client: dict[str, Any], pkce_challenge: dict[str, str] - ) -> None: + ): """Test authorization endpoint with unsupported response_type. According to the OAuth2.0 spec, for other errors like unsupported_response_type, @@ -1560,7 +1547,7 @@ async def test_authorize_unsupported_response_type( @pytest.mark.anyio async def test_authorize_missing_response_type( self, test_client: httpx.AsyncClient, registered_client: dict[str, Any], pkce_challenge: dict[str, str] - ) -> None: + ): """Test authorization endpoint with missing response_type. Missing required parameter should result in invalid_request error. @@ -1593,7 +1580,7 @@ async def test_authorize_missing_response_type( @pytest.mark.anyio async def test_authorize_missing_pkce_challenge( self, test_client: httpx.AsyncClient, registered_client: dict[str, Any] - ) -> None: + ): """Test authorization endpoint with missing PKCE code_challenge. Missing PKCE parameters should result in invalid_request error. @@ -1624,7 +1611,7 @@ async def test_authorize_missing_pkce_challenge( @pytest.mark.anyio async def test_authorize_invalid_scope( self, test_client: httpx.AsyncClient, registered_client: dict[str, Any], pkce_challenge: dict[str, str] - ) -> None: + ): """Test authorization endpoint with invalid scope. Invalid scope should redirect with invalid_scope error. diff --git a/tests/server/mcpserver/prompts/test_base.py b/tests/server/mcpserver/prompts/test_base.py index 19fc5130c..fe18e91bd 100644 --- a/tests/server/mcpserver/prompts/test_base.py +++ b/tests/server/mcpserver/prompts/test_base.py @@ -9,7 +9,7 @@ class TestRenderPrompt: @pytest.mark.anyio - async def test_basic_fn(self) -> None: + async def test_basic_fn(self): def fn() -> str: return "Hello, world!" @@ -19,7 +19,7 @@ def fn() -> str: ] @pytest.mark.anyio - async def test_async_fn(self) -> None: + async def test_async_fn(self): async def fn() -> str: return "Hello, world!" @@ -29,7 +29,7 @@ async def fn() -> str: ] @pytest.mark.anyio - async def test_fn_with_args(self) -> None: + async def test_fn_with_args(self): async def fn(name: str, age: int = 30) -> str: return f"Hello, {name}! You're {age} years old." @@ -39,7 +39,7 @@ async def fn(name: str, age: int = 30) -> str: ] @pytest.mark.anyio - async def test_fn_with_invalid_kwargs(self) -> None: + async def test_fn_with_invalid_kwargs(self): async def fn(name: str, age: int = 30) -> str: # pragma: no cover return f"Hello, {name}! You're {age} years old." @@ -48,7 +48,7 @@ async def fn(name: str, age: int = 30) -> str: # pragma: no cover await prompt.render({"age": 40}, Context()) @pytest.mark.anyio - async def test_fn_returns_message(self) -> None: + async def test_fn_returns_message(self): async def fn() -> UserMessage: return UserMessage(content="Hello, world!") @@ -58,7 +58,7 @@ async def fn() -> UserMessage: ] @pytest.mark.anyio - async def test_fn_returns_assistant_message(self) -> None: + async def test_fn_returns_assistant_message(self): async def fn() -> AssistantMessage: return AssistantMessage(content=TextContent(type="text", text="Hello, world!")) @@ -68,7 +68,7 @@ async def fn() -> AssistantMessage: ] @pytest.mark.anyio - async def test_fn_returns_multiple_messages(self) -> None: + async def test_fn_returns_multiple_messages(self): expected: list[Message] = [ UserMessage("Hello, world!"), AssistantMessage("How can I help you today?"), @@ -82,7 +82,7 @@ async def fn() -> list[Message]: assert await prompt.render(None, Context()) == expected @pytest.mark.anyio - async def test_fn_returns_list_of_strings(self) -> None: + async def test_fn_returns_list_of_strings(self): expected = [ "Hello, world!", "I'm looking for a restaurant in the center of town.", @@ -95,7 +95,7 @@ async def fn() -> list[str]: assert await prompt.render(None, Context()) == [UserMessage(t) for t in expected] @pytest.mark.anyio - async def test_fn_returns_resource_content(self) -> None: + async def test_fn_returns_resource_content(self): """Test returning a message with resource content.""" async def fn() -> UserMessage: @@ -125,7 +125,7 @@ async def fn() -> UserMessage: ] @pytest.mark.anyio - async def test_fn_returns_mixed_content(self) -> None: + async def test_fn_returns_mixed_content(self): """Test returning messages with mixed content types.""" async def fn() -> list[Message]: @@ -161,7 +161,7 @@ async def fn() -> list[Message]: ] @pytest.mark.anyio - async def test_fn_returns_dict_with_resource(self) -> None: + async def test_fn_returns_dict_with_resource(self): """Test returning a dict with resource content.""" async def fn() -> dict[str, Any]: diff --git a/tests/server/mcpserver/prompts/test_manager.py b/tests/server/mcpserver/prompts/test_manager.py index 9a41931ab..99a03db56 100644 --- a/tests/server/mcpserver/prompts/test_manager.py +++ b/tests/server/mcpserver/prompts/test_manager.py @@ -7,7 +7,7 @@ class TestPromptManager: - def test_add_prompt(self) -> None: + def test_add_prompt(self): """Test adding a prompt to the manager.""" def fn() -> str: # pragma: no cover @@ -19,7 +19,7 @@ def fn() -> str: # pragma: no cover assert added == prompt assert manager.get_prompt("fn") == prompt - def test_add_duplicate_prompt(self, caplog: pytest.LogCaptureFixture) -> None: + def test_add_duplicate_prompt(self, caplog: pytest.LogCaptureFixture): """Test adding the same prompt twice.""" def fn() -> str: # pragma: no cover @@ -32,7 +32,7 @@ def fn() -> str: # pragma: no cover assert first == second assert "Prompt already exists" in caplog.text - def test_disable_warn_on_duplicate_prompts(self, caplog: pytest.LogCaptureFixture) -> None: + def test_disable_warn_on_duplicate_prompts(self, caplog: pytest.LogCaptureFixture): """Test disabling warning on duplicate prompts.""" def fn() -> str: # pragma: no cover @@ -45,7 +45,7 @@ def fn() -> str: # pragma: no cover assert first == second assert "Prompt already exists" not in caplog.text - def test_list_prompts(self) -> None: + def test_list_prompts(self): """Test listing all prompts.""" def fn1() -> str: # pragma: no cover @@ -64,7 +64,7 @@ def fn2() -> str: # pragma: no cover assert prompts == [prompt1, prompt2] @pytest.mark.anyio - async def test_render_prompt(self) -> None: + async def test_render_prompt(self): """Test rendering a prompt.""" def fn() -> str: @@ -77,7 +77,7 @@ def fn() -> str: assert messages == [UserMessage(content=TextContent(type="text", text="Hello, world!"))] @pytest.mark.anyio - async def test_render_prompt_with_args(self) -> None: + async def test_render_prompt_with_args(self): """Test rendering a prompt with arguments.""" def fn(name: str) -> str: @@ -90,14 +90,14 @@ def fn(name: str) -> str: assert messages == [UserMessage(content=TextContent(type="text", text="Hello, World!"))] @pytest.mark.anyio - async def test_render_unknown_prompt(self) -> None: + async def test_render_unknown_prompt(self): """Test rendering a non-existent prompt.""" manager = PromptManager() with pytest.raises(ValueError, match="Unknown prompt: unknown"): await manager.render_prompt("unknown", None, Context()) @pytest.mark.anyio - async def test_render_prompt_with_missing_args(self) -> None: + async def test_render_prompt_with_missing_args(self): """Test rendering a prompt with missing required arguments.""" def fn(name: str) -> str: # pragma: no cover diff --git a/tests/server/mcpserver/resources/test_file_resources.py b/tests/server/mcpserver/resources/test_file_resources.py index 26ce4e475..94885113a 100644 --- a/tests/server/mcpserver/resources/test_file_resources.py +++ b/tests/server/mcpserver/resources/test_file_resources.py @@ -1,5 +1,4 @@ import os -from collections.abc import Generator from pathlib import Path from tempfile import NamedTemporaryFile @@ -9,7 +8,7 @@ @pytest.fixture -def temp_file() -> Generator[Path, None, None]: +def temp_file(): """Create a temporary file for testing. File is automatically cleaned up after the test if it still exists. @@ -28,7 +27,7 @@ def temp_file() -> Generator[Path, None, None]: class TestFileResource: """Test FileResource functionality.""" - def test_file_resource_creation(self, temp_file: Path) -> None: + def test_file_resource_creation(self, temp_file: Path): """Test creating a FileResource.""" resource = FileResource( uri=temp_file.as_uri(), @@ -43,7 +42,7 @@ def test_file_resource_creation(self, temp_file: Path) -> None: assert resource.path == temp_file assert resource.is_binary is False # default - def test_file_resource_str_path_conversion(self, temp_file: Path) -> None: + def test_file_resource_str_path_conversion(self, temp_file: Path): """Test FileResource handles string paths.""" resource = FileResource( uri=f"file://{temp_file}", @@ -54,7 +53,7 @@ def test_file_resource_str_path_conversion(self, temp_file: Path) -> None: assert resource.path.is_absolute() @pytest.mark.anyio - async def test_read_text_file(self, temp_file: Path) -> None: + async def test_read_text_file(self, temp_file: Path): """Test reading a text file.""" resource = FileResource( uri=f"file://{temp_file}", @@ -66,7 +65,7 @@ async def test_read_text_file(self, temp_file: Path) -> None: assert resource.mime_type == "text/plain" @pytest.mark.anyio - async def test_read_binary_file(self, temp_file: Path) -> None: + async def test_read_binary_file(self, temp_file: Path): """Test reading a file as binary.""" resource = FileResource( uri=f"file://{temp_file}", @@ -78,7 +77,7 @@ async def test_read_binary_file(self, temp_file: Path) -> None: assert isinstance(content, bytes) assert content == b"test content" - def test_relative_path_error(self) -> None: + def test_relative_path_error(self): """Test error on relative path.""" with pytest.raises(ValueError, match="Path must be absolute"): FileResource( @@ -88,7 +87,7 @@ def test_relative_path_error(self) -> None: ) @pytest.mark.anyio - async def test_missing_file_error(self, temp_file: Path) -> None: + async def test_missing_file_error(self, temp_file: Path): """Test error when file doesn't exist.""" # Create path to non-existent file missing = temp_file.parent / "missing.txt" @@ -102,7 +101,7 @@ async def test_missing_file_error(self, temp_file: Path) -> None: @pytest.mark.skipif(os.name == "nt", reason="File permissions behave differently on Windows") @pytest.mark.anyio - async def test_permission_error(self, temp_file: Path) -> None: # pragma: lax no cover + async def test_permission_error(self, temp_file: Path): # pragma: lax no cover """Test reading a file without permissions.""" temp_file.chmod(0o000) # Remove all permissions try: diff --git a/tests/server/mcpserver/resources/test_function_resources.py b/tests/server/mcpserver/resources/test_function_resources.py index 8cc498530..5f5c216ed 100644 --- a/tests/server/mcpserver/resources/test_function_resources.py +++ b/tests/server/mcpserver/resources/test_function_resources.py @@ -7,7 +7,7 @@ class TestFunctionResource: """Test FunctionResource functionality.""" - def test_function_resource_creation(self) -> None: + def test_function_resource_creation(self): """Test creating a FunctionResource.""" def my_func() -> str: # pragma: no cover @@ -26,7 +26,7 @@ def my_func() -> str: # pragma: no cover assert resource.fn == my_func @pytest.mark.anyio - async def test_read_text(self) -> None: + async def test_read_text(self): """Test reading text from a FunctionResource.""" def get_data() -> str: @@ -42,7 +42,7 @@ def get_data() -> str: assert resource.mime_type == "text/plain" @pytest.mark.anyio - async def test_read_binary(self) -> None: + async def test_read_binary(self): """Test reading binary data from a FunctionResource.""" def get_data() -> bytes: @@ -57,7 +57,7 @@ def get_data() -> bytes: assert content == b"Hello, world!" @pytest.mark.anyio - async def test_json_conversion(self) -> None: + async def test_json_conversion(self): """Test automatic JSON conversion of non-string results.""" def get_data() -> dict[str, str]: @@ -73,7 +73,7 @@ def get_data() -> dict[str, str]: assert '"key": "value"' in content @pytest.mark.anyio - async def test_error_handling(self) -> None: + async def test_error_handling(self): """Test error handling in FunctionResource.""" def failing_func() -> str: @@ -88,7 +88,7 @@ def failing_func() -> str: await resource.read() @pytest.mark.anyio - async def test_basemodel_conversion(self) -> None: + async def test_basemodel_conversion(self): """Test handling of BaseModel types.""" class MyModel(BaseModel): @@ -103,7 +103,7 @@ class MyModel(BaseModel): assert content == '{\n "name": "test"\n}' @pytest.mark.anyio - async def test_custom_type_conversion(self) -> None: + async def test_custom_type_conversion(self): """Test handling of custom types.""" class CustomData: @@ -122,7 +122,7 @@ def get_data() -> CustomData: assert isinstance(content, str) @pytest.mark.anyio - async def test_async_read_text(self) -> None: + async def test_async_read_text(self): """Test reading text from async FunctionResource.""" async def get_data() -> str: @@ -138,7 +138,7 @@ async def get_data() -> str: assert resource.mime_type == "text/plain" @pytest.mark.anyio - async def test_from_function(self) -> None: + async def test_from_function(self): """Test creating a FunctionResource from a function.""" async def get_data() -> str: # pragma: no cover @@ -158,7 +158,7 @@ async def get_data() -> str: # pragma: no cover class TestFunctionResourceMetadata: - def test_from_function_with_metadata(self) -> None: + def test_from_function_with_metadata(self): # from_function() accepts meta dict and stores it on the resource for static resources def get_data() -> str: # pragma: no cover @@ -178,7 +178,7 @@ def get_data() -> str: # pragma: no cover assert "data" in resource.meta["tags"] assert "readonly" in resource.meta["tags"] - def test_from_function_without_metadata(self) -> None: + def test_from_function_without_metadata(self): # meta parameter is optional and defaults to None for backward compatibility def get_data() -> str: # pragma: no cover diff --git a/tests/server/mcpserver/resources/test_resource_manager.py b/tests/server/mcpserver/resources/test_resource_manager.py index 763a004ad..724b57997 100644 --- a/tests/server/mcpserver/resources/test_resource_manager.py +++ b/tests/server/mcpserver/resources/test_resource_manager.py @@ -1,4 +1,3 @@ -from collections.abc import Generator from pathlib import Path from tempfile import NamedTemporaryFile @@ -10,7 +9,7 @@ @pytest.fixture -def temp_file() -> Generator[Path, None, None]: +def temp_file(): """Create a temporary file for testing. File is automatically cleaned up after the test if it still exists. @@ -29,7 +28,7 @@ def temp_file() -> Generator[Path, None, None]: class TestResourceManager: """Test ResourceManager functionality.""" - def test_add_resource(self, temp_file: Path) -> None: + def test_add_resource(self, temp_file: Path): """Test adding a resource.""" manager = ResourceManager() resource = FileResource( @@ -41,7 +40,7 @@ def test_add_resource(self, temp_file: Path) -> None: assert added == resource assert manager.list_resources() == [resource] - def test_add_duplicate_resource(self, temp_file: Path) -> None: + def test_add_duplicate_resource(self, temp_file: Path): """Test adding the same resource twice.""" manager = ResourceManager() resource = FileResource( @@ -54,7 +53,7 @@ def test_add_duplicate_resource(self, temp_file: Path) -> None: assert first == second assert manager.list_resources() == [resource] - def test_warn_on_duplicate_resources(self, temp_file: Path, caplog: pytest.LogCaptureFixture) -> None: + def test_warn_on_duplicate_resources(self, temp_file: Path, caplog: pytest.LogCaptureFixture): """Test warning on duplicate resources.""" manager = ResourceManager() resource = FileResource( @@ -66,7 +65,7 @@ def test_warn_on_duplicate_resources(self, temp_file: Path, caplog: pytest.LogCa manager.add_resource(resource) assert "Resource already exists" in caplog.text - def test_disable_warn_on_duplicate_resources(self, temp_file: Path, caplog: pytest.LogCaptureFixture) -> None: + def test_disable_warn_on_duplicate_resources(self, temp_file: Path, caplog: pytest.LogCaptureFixture): """Test disabling warning on duplicate resources.""" manager = ResourceManager(warn_on_duplicate_resources=False) resource = FileResource( @@ -79,7 +78,7 @@ def test_disable_warn_on_duplicate_resources(self, temp_file: Path, caplog: pyte assert "Resource already exists" not in caplog.text @pytest.mark.anyio - async def test_get_resource(self, temp_file: Path) -> None: + async def test_get_resource(self, temp_file: Path): """Test getting a resource by URI.""" manager = ResourceManager() resource = FileResource( @@ -92,7 +91,7 @@ async def test_get_resource(self, temp_file: Path) -> None: assert retrieved == resource @pytest.mark.anyio - async def test_get_resource_from_template(self) -> None: + async def test_get_resource_from_template(self): """Test getting a resource through a template.""" manager = ResourceManager() @@ -112,13 +111,13 @@ def greet(name: str) -> str: assert content == "Hello, world!" @pytest.mark.anyio - async def test_get_unknown_resource(self) -> None: + async def test_get_unknown_resource(self): """Test getting a non-existent resource.""" manager = ResourceManager() with pytest.raises(ValueError, match="Unknown resource"): await manager.get_resource(AnyUrl("unknown://test"), Context()) - def test_list_resources(self, temp_file: Path) -> None: + def test_list_resources(self, temp_file: Path): """Test listing all resources.""" manager = ResourceManager() resource1 = FileResource( @@ -141,7 +140,7 @@ def test_list_resources(self, temp_file: Path) -> None: class TestResourceManagerMetadata: """Test ResourceManager Metadata""" - def test_add_template_with_metadata(self) -> None: + def test_add_template_with_metadata(self): """Test that ResourceManager.add_template() accepts and passes meta parameter.""" manager = ResourceManager() @@ -162,7 +161,7 @@ def get_item(id: str) -> str: # pragma: no cover assert template.meta["source"] == "database" assert template.meta["cached"] is True - def test_add_template_without_metadata(self) -> None: + def test_add_template_without_metadata(self): """Test that ResourceManager.add_template() works without meta parameter.""" manager = ResourceManager() diff --git a/tests/server/mcpserver/resources/test_resource_template.py b/tests/server/mcpserver/resources/test_resource_template.py index 818f13841..640cfe803 100644 --- a/tests/server/mcpserver/resources/test_resource_template.py +++ b/tests/server/mcpserver/resources/test_resource_template.py @@ -12,7 +12,7 @@ class TestResourceTemplate: """Test ResourceTemplate functionality.""" - def test_template_creation(self) -> None: + def test_template_creation(self): """Test creating a template from a function.""" def my_func(key: str, value: int) -> dict[str, Any]: @@ -28,7 +28,7 @@ def my_func(key: str, value: int) -> dict[str, Any]: assert template.mime_type == "text/plain" # default assert template.fn(key="test", value=42) == my_func(key="test", value=42) - def test_template_matches(self) -> None: + def test_template_matches(self): """Test matching URIs against a template.""" def my_func(key: str, value: int) -> dict[str, Any]: # pragma: no cover @@ -49,7 +49,7 @@ def my_func(key: str, value: int) -> dict[str, Any]: # pragma: no cover assert template.matches("other://foo/123") is None @pytest.mark.anyio - async def test_create_resource(self) -> None: + async def test_create_resource(self): """Test creating a resource from a template.""" def my_func(key: str, value: int) -> dict[str, Any]: @@ -74,7 +74,7 @@ def my_func(key: str, value: int) -> dict[str, Any]: assert data == {"key": "foo", "value": 123} @pytest.mark.anyio - async def test_template_error(self) -> None: + async def test_template_error(self): """Test error handling in template resource creation.""" def failing_func(x: str) -> str: @@ -90,7 +90,7 @@ def failing_func(x: str) -> str: await template.create_resource("fail://test", {"x": "test"}, Context()) @pytest.mark.anyio - async def test_async_text_resource(self) -> None: + async def test_async_text_resource(self): """Test creating a text resource from async function.""" async def greet(name: str) -> str: @@ -113,7 +113,7 @@ async def greet(name: str) -> str: assert content == "Hello, world!" @pytest.mark.anyio - async def test_async_binary_resource(self) -> None: + async def test_async_binary_resource(self): """Test creating a binary resource from async function.""" async def get_bytes(value: str) -> bytes: @@ -136,7 +136,7 @@ async def get_bytes(value: str) -> bytes: assert content == b"test" @pytest.mark.anyio - async def test_basemodel_conversion(self) -> None: + async def test_basemodel_conversion(self): """Test handling of BaseModel types.""" class MyModel(BaseModel): @@ -165,11 +165,11 @@ def get_data(key: str, value: int) -> MyModel: assert data == {"key": "foo", "value": 123} @pytest.mark.anyio - async def test_custom_type_conversion(self) -> None: + async def test_custom_type_conversion(self): """Test handling of custom types.""" class CustomData: - def __init__(self, value: str) -> None: + def __init__(self, value: str): self.value = value def __str__(self) -> str: @@ -198,7 +198,7 @@ def get_data(value: str) -> CustomData: class TestResourceTemplateAnnotations: """Test annotations on resource templates.""" - def test_template_with_annotations(self) -> None: + def test_template_with_annotations(self): """Test creating a template with annotations.""" def get_user_data(user_id: str) -> str: # pragma: no cover @@ -213,7 +213,7 @@ def get_user_data(user_id: str) -> str: # pragma: no cover assert template.annotations is not None assert template.annotations.priority == 0.9 - def test_template_without_annotations(self) -> None: + def test_template_without_annotations(self): """Test that annotations are optional for templates.""" def get_user_data(user_id: str) -> str: # pragma: no cover @@ -224,7 +224,7 @@ def get_user_data(user_id: str) -> str: # pragma: no cover assert template.annotations is None @pytest.mark.anyio - async def test_template_annotations_in_mcpserver(self) -> None: + async def test_template_annotations_in_mcpserver(self): """Test template annotations via an MCPServer decorator.""" mcp = MCPServer() @@ -241,7 +241,7 @@ def get_dynamic(id: str) -> str: # pragma: no cover assert templates[0].annotations.priority == 0.7 @pytest.mark.anyio - async def test_template_created_resources_inherit_annotations(self) -> None: + async def test_template_created_resources_inherit_annotations(self): """Test that resources created from templates inherit annotations.""" def get_item(item_id: str) -> str: @@ -268,7 +268,7 @@ def get_item(item_id: str) -> str: class TestResourceTemplateMetadata: """Test ResourceTemplate meta handling.""" - def test_template_from_function_with_metadata(self) -> None: + def test_template_from_function_with_metadata(self): """Test that ResourceTemplate.from_function() accepts and stores meta parameter.""" def get_user(user_id: str) -> str: # pragma: no cover @@ -288,7 +288,7 @@ def get_user(user_id: str) -> str: # pragma: no cover assert template.meta["rate_limit"] == 100 @pytest.mark.anyio - async def test_template_created_resources_inherit_metadata(self) -> None: + async def test_template_created_resources_inherit_metadata(self): """Test that resources created from templates inherit meta from template.""" def get_item(item_id: str) -> str: diff --git a/tests/server/mcpserver/resources/test_resources.py b/tests/server/mcpserver/resources/test_resources.py index cc428a7af..5d36beda8 100644 --- a/tests/server/mcpserver/resources/test_resources.py +++ b/tests/server/mcpserver/resources/test_resources.py @@ -8,7 +8,7 @@ class TestResourceValidation: """Test base Resource validation.""" - def test_resource_uri_accepts_any_string(self) -> None: + def test_resource_uri_accepts_any_string(self): """Test that URI field accepts any string per MCP spec.""" def dummy_func() -> str: # pragma: no cover @@ -38,7 +38,7 @@ def dummy_func() -> str: # pragma: no cover ) assert resource.uri == "custom://resource" - def test_resource_name_from_uri(self) -> None: + def test_resource_name_from_uri(self): """Test name is extracted from URI if not provided.""" def dummy_func() -> str: # pragma: no cover @@ -50,7 +50,7 @@ def dummy_func() -> str: # pragma: no cover ) assert resource.name == "resource://my-resource" - def test_resource_name_validation(self) -> None: + def test_resource_name_validation(self): """Test name validation.""" def dummy_func() -> str: # pragma: no cover @@ -70,7 +70,7 @@ def dummy_func() -> str: # pragma: no cover ) assert resource.name == "explicit-name" - def test_resource_mime_type(self) -> None: + def test_resource_mime_type(self): """Test mime type handling.""" def dummy_func() -> str: # pragma: no cover @@ -100,7 +100,7 @@ def dummy_func() -> str: # pragma: no cover assert resource.mime_type == 'text/plain; charset="utf-8"' @pytest.mark.anyio - async def test_resource_read_abstract(self) -> None: + async def test_resource_read_abstract(self): """Test that Resource.read() is abstract.""" class ConcreteResource(Resource): @@ -113,7 +113,7 @@ class ConcreteResource(Resource): class TestResourceAnnotations: """Test annotations on resources.""" - def test_resource_with_annotations(self) -> None: + def test_resource_with_annotations(self): """Test creating a resource with annotations.""" def get_data() -> str: # pragma: no cover @@ -127,7 +127,7 @@ def get_data() -> str: # pragma: no cover assert resource.annotations.audience == ["user"] assert resource.annotations.priority == 0.8 - def test_resource_without_annotations(self) -> None: + def test_resource_without_annotations(self): """Test that annotations are optional.""" def get_data() -> str: # pragma: no cover @@ -138,7 +138,7 @@ def get_data() -> str: # pragma: no cover assert resource.annotations is None @pytest.mark.anyio - async def test_resource_annotations_in_mcpserver(self) -> None: + async def test_resource_annotations_in_mcpserver(self): """Test resource annotations via MCPServer decorator.""" mcp = MCPServer() @@ -155,7 +155,7 @@ def get_annotated() -> str: # pragma: no cover assert resources[0].annotations.priority == 0.5 @pytest.mark.anyio - async def test_resource_annotations_with_both_audiences(self) -> None: + async def test_resource_annotations_with_both_audiences(self): """Test resource with both user and assistant audience.""" mcp = MCPServer() @@ -173,7 +173,7 @@ def get_both() -> str: # pragma: no cover class TestAnnotationsValidation: """Test validation of annotation values.""" - def test_priority_validation(self) -> None: + def test_priority_validation(self): """Test that priority is validated to be between 0.0 and 1.0.""" # Valid priorities @@ -188,7 +188,7 @@ def test_priority_validation(self) -> None: with pytest.raises(Exception): Annotations(priority=1.1) - def test_audience_validation(self) -> None: + def test_audience_validation(self): """Test that audience only accepts valid roles.""" # Valid audiences @@ -205,7 +205,7 @@ def test_audience_validation(self) -> None: class TestResourceMetadata: """Test metadata field on base Resource class.""" - def test_resource_with_metadata(self) -> None: + def test_resource_with_metadata(self): """Test that Resource base class accepts meta parameter.""" def dummy_func() -> str: # pragma: no cover @@ -225,7 +225,7 @@ def dummy_func() -> str: # pragma: no cover assert resource.meta["version"] == "1.0" assert resource.meta["category"] == "test" - def test_resource_without_metadata(self) -> None: + def test_resource_without_metadata(self): """Test that meta field defaults to None.""" def dummy_func() -> str: # pragma: no cover diff --git a/tests/server/mcpserver/servers/test_file_server.py b/tests/server/mcpserver/servers/test_file_server.py index 3ee02a28f..9c3fe265c 100644 --- a/tests/server/mcpserver/servers/test_file_server.py +++ b/tests/server/mcpserver/servers/test_file_server.py @@ -74,7 +74,7 @@ def delete_file(path: str) -> bool: @pytest.mark.anyio -async def test_list_resources(mcp: MCPServer) -> None: +async def test_list_resources(mcp: MCPServer): resources = await mcp.list_resources() assert len(resources) == 4 @@ -87,7 +87,7 @@ async def test_list_resources(mcp: MCPServer) -> None: @pytest.mark.anyio -async def test_read_resource_dir(mcp: MCPServer) -> None: +async def test_read_resource_dir(mcp: MCPServer): res_iter = await mcp.read_resource("dir://test_dir") res_list = list(res_iter) assert len(res_list) == 1 @@ -104,7 +104,7 @@ async def test_read_resource_dir(mcp: MCPServer) -> None: @pytest.mark.anyio -async def test_read_resource_file(mcp: MCPServer) -> None: +async def test_read_resource_file(mcp: MCPServer): res_iter = await mcp.read_resource("file://test_dir/example.py") res_list = list(res_iter) assert len(res_list) == 1 @@ -113,13 +113,13 @@ async def test_read_resource_file(mcp: MCPServer) -> None: @pytest.mark.anyio -async def test_delete_file(mcp: MCPServer, test_dir: Path) -> None: +async def test_delete_file(mcp: MCPServer, test_dir: Path): await mcp.call_tool("delete_file", arguments={"path": str(test_dir / "example.py")}) assert not (test_dir / "example.py").exists() @pytest.mark.anyio -async def test_delete_file_and_check_resources(mcp: MCPServer, test_dir: Path) -> None: +async def test_delete_file_and_check_resources(mcp: MCPServer, test_dir: Path): await mcp.call_tool("delete_file", arguments={"path": str(test_dir / "example.py")}) res_iter = await mcp.read_resource("file://test_dir/example.py") res_list = list(res_iter) diff --git a/tests/server/mcpserver/test_elicitation.py b/tests/server/mcpserver/test_elicitation.py index 5fd7cbc77..679fb848f 100644 --- a/tests/server/mcpserver/test_elicitation.py +++ b/tests/server/mcpserver/test_elicitation.py @@ -1,6 +1,5 @@ """Test the elicitation feature using stdio transport.""" -from collections.abc import Callable, Coroutine from typing import Any import pytest @@ -10,7 +9,7 @@ from mcp.client.session import ClientSession, ElicitationFnT from mcp.server.mcpserver import Context, MCPServer from mcp.shared._context import RequestContext -from mcp.types import CallToolResult, ElicitRequestParams, ElicitResult, TextContent +from mcp.types import ElicitRequestParams, ElicitResult, TextContent # Shared schema for basic tests @@ -18,7 +17,7 @@ class AnswerSchema(BaseModel): answer: str = Field(description="The user's answer to the question") -def create_ask_user_tool(mcp: MCPServer) -> Callable[[str, Context], Coroutine[Any, Any, str]]: +def create_ask_user_tool(mcp: MCPServer): """Create a standard ask_user tool that handles all elicitation responses.""" @mcp.tool(description="A tool that uses elicitation") @@ -42,7 +41,7 @@ async def call_tool_and_assert( args: dict[str, Any], expected_text: str | None = None, text_contains: list[str] | None = None, -) -> CallToolResult: +): """Helper to create session, call tool, and assert result.""" async with Client(mcp, elicitation_callback=elicitation_callback) as client: result = await client.call_tool(tool_name, args) @@ -59,13 +58,13 @@ async def call_tool_and_assert( @pytest.mark.anyio -async def test_stdio_elicitation() -> None: +async def test_stdio_elicitation(): """Test the elicitation feature using stdio transport.""" mcp = MCPServer(name="StdioElicitationServer") create_ask_user_tool(mcp) # Create a custom handler for elicitation requests - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): if params.message == "Tool wants to ask: What is your name?": return ElicitResult(action="accept", content={"answer": "Test User"}) else: # pragma: no cover @@ -77,12 +76,12 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E @pytest.mark.anyio -async def test_stdio_elicitation_decline() -> None: +async def test_stdio_elicitation_decline(): """Test elicitation with user declining.""" mcp = MCPServer(name="StdioElicitationDeclineServer") create_ask_user_tool(mcp) - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): return ElicitResult(action="decline") await call_tool_and_assert( @@ -91,13 +90,11 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E @pytest.mark.anyio -async def test_elicitation_schema_validation() -> None: +async def test_elicitation_schema_validation(): """Test that elicitation schemas must only contain primitive types.""" mcp = MCPServer(name="ValidationTestServer") - def create_validation_tool( - name: str, schema_class: type[BaseModel] - ) -> Callable[[Context], Coroutine[Any, Any, str]]: + def create_validation_tool(name: str, schema_class: type[BaseModel]): @mcp.tool(name=name, description=f"Tool testing {name}") async def tool(ctx: Context) -> str: try: @@ -124,7 +121,7 @@ class InvalidNestedSchema(BaseModel): # Dummy callback (won't be called due to validation failure) async def elicitation_callback( context: RequestContext[ClientSession], params: ElicitRequestParams - ) -> ElicitResult: # pragma: no cover + ): # pragma: no cover return ElicitResult(action="accept", content={}) async with Client(mcp, elicitation_callback=elicitation_callback) as client: @@ -138,7 +135,7 @@ async def elicitation_callback( @pytest.mark.anyio -async def test_elicitation_with_optional_fields() -> None: +async def test_elicitation_with_optional_fields(): """Test that Optional fields work correctly in elicitation schemas.""" mcp = MCPServer(name="OptionalFieldServer") @@ -179,7 +176,7 @@ async def optional_tool(ctx: Context) -> str: for content, expected in test_cases: - async def callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: + async def callback(context: RequestContext[ClientSession], params: ElicitRequestParams): return ElicitResult(action="accept", content=content) await call_tool_and_assert(mcp, callback, "optional_tool", {}, expected) @@ -199,7 +196,7 @@ async def invalid_optional_tool(ctx: Context) -> str: async def elicitation_callback( context: RequestContext[ClientSession], params: ElicitRequestParams - ) -> ElicitResult: # pragma: no cover + ): # pragma: no cover return ElicitResult(action="accept", content={}) await call_tool_and_assert( @@ -222,7 +219,7 @@ async def valid_multiselect_tool(ctx: Context) -> str: return f"Name: {result.data.name}, Tags: {', '.join(result.data.tags)}" return f"User {result.action}" # pragma: no cover - async def multiselect_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: + async def multiselect_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): if "Please provide tags" in params.message: return ElicitResult(action="accept", content={"name": "Test", "tags": ["tag1", "tag2"]}) return ElicitResult(action="decline") # pragma: no cover @@ -242,9 +239,7 @@ async def optional_multiselect_tool(ctx: Context) -> str: return f"Name: {result.data.name}, Tags: {tags_str}" return f"User {result.action}" # pragma: no cover - async def optional_multiselect_callback( - context: RequestContext[ClientSession], params: ElicitRequestParams - ) -> ElicitResult: + async def optional_multiselect_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): if "Please provide optional tags" in params.message: return ElicitResult(action="accept", content={"name": "Test", "tags": ["tag1", "tag2"]}) return ElicitResult(action="decline") # pragma: no cover @@ -255,7 +250,7 @@ async def optional_multiselect_callback( @pytest.mark.anyio -async def test_elicitation_with_default_values() -> None: +async def test_elicitation_with_default_values(): """Test that default values work correctly in elicitation schemas and are included in JSON.""" mcp = MCPServer(name="DefaultValuesServer") @@ -278,9 +273,7 @@ async def defaults_tool(ctx: Context) -> str: return f"User {result.action}" # First verify that defaults are present in the JSON schema sent to clients - async def callback_schema_verify( - context: RequestContext[ClientSession], params: ElicitRequestParams - ) -> ElicitResult: + async def callback_schema_verify(context: RequestContext[ClientSession], params: ElicitRequestParams): # Verify the schema includes defaults assert isinstance(params, types.ElicitRequestFormParams), "Expected form mode elicitation" schema = params.requested_schema @@ -302,7 +295,7 @@ async def callback_schema_verify( ) # Test overriding defaults - async def callback_override(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: + async def callback_override(context: RequestContext[ClientSession], params: ElicitRequestParams): return ElicitResult( action="accept", content={"email": "john@example.com", "name": "John", "age": 25, "subscribe": False} ) @@ -313,7 +306,7 @@ async def callback_override(context: RequestContext[ClientSession], params: Elic @pytest.mark.anyio -async def test_elicitation_with_enum_titles() -> None: +async def test_elicitation_with_enum_titles(): """Test elicitation with enum schemas using oneOf/anyOf for titles.""" mcp = MCPServer(name="ColorPreferencesApp") @@ -378,7 +371,7 @@ async def select_color_legacy(ctx: Context) -> str: return f"User: {result.data.user_name}, Color: {result.data.color}" return f"User {result.action}" # pragma: no cover - async def enum_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: + async def enum_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): if "colors" in params.message and "legacy" not in params.message: return ElicitResult(action="accept", content={"user_name": "Bob", "favorite_colors": ["red", "green"]}) elif "color" in params.message: diff --git a/tests/server/mcpserver/test_func_metadata.py b/tests/server/mcpserver/test_func_metadata.py index b0e57ce92..c57d1ee9f 100644 --- a/tests/server/mcpserver/test_func_metadata.py +++ b/tests/server/mcpserver/test_func_metadata.py @@ -92,7 +92,7 @@ def complex_arguments_fn( @pytest.mark.anyio -async def test_complex_function_runtime_arg_validation_non_json() -> None: +async def test_complex_function_runtime_arg_validation_non_json(): """Test that basic non-JSON arguments are validated correctly""" meta = func_metadata(complex_arguments_fn) @@ -129,7 +129,7 @@ async def test_complex_function_runtime_arg_validation_non_json() -> None: @pytest.mark.anyio -async def test_complex_function_runtime_arg_validation_with_json() -> None: +async def test_complex_function_runtime_arg_validation_with_json(): """Test that JSON string arguments are parsed and validated correctly""" meta = func_metadata(complex_arguments_fn) @@ -155,14 +155,14 @@ async def test_complex_function_runtime_arg_validation_with_json() -> None: assert result == "ok!" -def test_str_vs_list_str() -> None: +def test_str_vs_list_str(): """Test handling of string vs list[str] type annotations. This is tricky as '"hello"' can be parsed as a JSON string or a Python string. We want to make sure it's kept as a python string. """ - def func_with_str_types(str_or_list: str | list[str]) -> str | list[str]: # pragma: no cover + def func_with_str_types(str_or_list: str | list[str]): # pragma: no cover return str_or_list meta = func_metadata(func_with_str_types) @@ -182,12 +182,10 @@ def func_with_str_types(str_or_list: str | list[str]) -> str | list[str]: # pra assert result["str_or_list"] == ["hello", "world"] -def test_skip_names() -> None: +def test_skip_names(): """Test that skipped parameters are not included in the model""" - def func_with_many_params( - keep_this: int, skip_this: str, also_keep: float, also_skip: bool - ) -> tuple[int, str, float, bool]: # pragma: no cover + def func_with_many_params(keep_this: int, skip_this: str, also_keep: float, also_skip: bool): # pragma: no cover return keep_this, skip_this, also_keep, also_skip # Skip some parameters @@ -205,7 +203,7 @@ def func_with_many_params( assert model.also_keep == 2.5 # type: ignore -def test_structured_output_dict_str_types() -> None: +def test_structured_output_dict_str_types(): """Test that dict[str, T] types are handled without wrapping.""" # Test dict[str, Any] @@ -248,7 +246,7 @@ def func_dict_int_key() -> dict[int, str]: # pragma: no cover @pytest.mark.anyio -async def test_lambda_function() -> None: +async def test_lambda_function(): """Test lambda function schema and validation""" fn: Callable[[str, int], str] = lambda x, y=5: x # noqa: E731 meta = func_metadata(lambda x, y=5: x) @@ -264,7 +262,7 @@ async def test_lambda_function() -> None: "type": "object", } - async def check_call(args: dict[str, Any]) -> Any: + async def check_call(args): return await meta.call_fn_with_arg_validation( fn, fn_is_async=False, @@ -282,7 +280,7 @@ async def check_call(args: dict[str, Any]) -> Any: await check_call({"y": "world"}) -def test_complex_function_json_schema() -> None: +def test_complex_function_json_schema(): """Test JSON schema generation for complex function arguments. Note: Different versions of pydantic output slightly different @@ -449,12 +447,12 @@ def test_complex_function_json_schema() -> None: } -def test_str_vs_int() -> None: +def test_str_vs_int(): """Test that string values are kept as strings even when they contain numbers, while numbers are parsed correctly. """ - def func_with_str_and_int(a: str, b: int) -> str: # pragma: no cover + def func_with_str_and_int(a: str, b: int): # pragma: no cover return a meta = func_metadata(func_with_str_and_int) @@ -463,7 +461,7 @@ def func_with_str_and_int(a: str, b: int) -> str: # pragma: no cover assert result["b"] == 123 -def test_str_annotation_preserves_json_string() -> None: +def test_str_annotation_preserves_json_string(): """Regression test for PR #1113: Ensure that when a parameter is annotated as str, valid JSON strings are NOT parsed into Python objects. @@ -513,7 +511,7 @@ def process_json_config(config: str, enabled: bool = True) -> str: # pragma: no @pytest.mark.anyio -async def test_str_annotation_runtime_validation() -> None: +async def test_str_annotation_runtime_validation(): """Regression test for PR #1113: Test runtime validation with string parameters containing valid JSON to ensure they are passed as strings, not parsed objects. """ @@ -556,10 +554,10 @@ def handle_json_payload(payload: str, strict_mode: bool = False) -> str: # Tests for structured output functionality -def test_structured_output_requires_return_annotation() -> None: +def test_structured_output_requires_return_annotation(): """Test that structured_output=True requires a return annotation""" - def func_no_annotation(): # noqa: ANN202 # pragma: no cover + def func_no_annotation(): # pragma: no cover return "hello" def func_none_annotation() -> None: # pragma: no cover @@ -579,7 +577,7 @@ def func_none_annotation() -> None: # pragma: no cover } -def test_structured_output_basemodel() -> None: +def test_structured_output_basemodel(): """Test structured output with BaseModel return types""" class PersonModel(BaseModel): @@ -603,7 +601,7 @@ def func_returning_person() -> PersonModel: # pragma: no cover } -def test_structured_output_primitives() -> None: +def test_structured_output_primitives(): """Test structured output with primitive return types""" def func_str() -> str: # pragma: no cover @@ -667,7 +665,7 @@ def func_bytes() -> bytes: # pragma: no cover } -def test_structured_output_generic_types() -> None: +def test_structured_output_generic_types(): """Test structured output with generic types (list, dict, Union, etc.)""" def func_list_str() -> list[str]: # pragma: no cover @@ -718,7 +716,7 @@ def func_optional() -> str | None: # pragma: no cover } -def test_structured_output_dataclass() -> None: +def test_structured_output_dataclass(): """Test structured output with dataclass return types""" @dataclass @@ -749,7 +747,7 @@ def func_returning_dataclass() -> PersonDataClass: # pragma: no cover } -def test_structured_output_typeddict() -> None: +def test_structured_output_typeddict(): """Test structured output with TypedDict return types""" class PersonTypedDictOptional(TypedDict, total=False): @@ -791,7 +789,7 @@ def func_returning_typeddict_required() -> PersonTypedDictRequired: # pragma: n } -def test_structured_output_ordinary_class() -> None: +def test_structured_output_ordinary_class(): """Test structured output with ordinary annotated classes""" class PersonClass: @@ -799,7 +797,7 @@ class PersonClass: age: int email: str | None - def __init__(self, name: str, age: int, email: str | None = None) -> None: # pragma: no cover + def __init__(self, name: str, age: int, email: str | None = None): # pragma: no cover self.name = name self.age = age self.email = email @@ -820,10 +818,10 @@ def func_returning_class() -> PersonClass: # pragma: no cover } -def test_unstructured_output_unannotated_class() -> None: +def test_unstructured_output_unannotated_class(): # Test with class that has no annotations class UnannotatedClass: - def __init__(self, x, y) -> None: # pragma: no cover + def __init__(self, x, y): # pragma: no cover self.x = x self.y = y @@ -834,7 +832,7 @@ def func_returning_unannotated() -> UnannotatedClass: # pragma: no cover assert meta.output_schema is None -def test_tool_call_result_is_unstructured_and_not_converted() -> None: +def test_tool_call_result_is_unstructured_and_not_converted(): def func_returning_call_tool_result() -> CallToolResult: return CallToolResult(content=[]) @@ -844,7 +842,7 @@ def func_returning_call_tool_result() -> CallToolResult: assert isinstance(meta.convert_result(func_returning_call_tool_result()), CallToolResult) -def test_tool_call_result_annotated_is_structured_and_converted() -> None: +def test_tool_call_result_annotated_is_structured_and_converted(): class PersonClass(BaseModel): name: str @@ -864,7 +862,7 @@ def func_returning_annotated_tool_call_result() -> Annotated[CallToolResult, Per assert isinstance(meta.convert_result(func_returning_annotated_tool_call_result()), CallToolResult) -def test_tool_call_result_annotated_is_structured_and_invalid() -> None: +def test_tool_call_result_annotated_is_structured_and_invalid(): class PersonClass(BaseModel): name: str @@ -877,7 +875,7 @@ def func_returning_annotated_tool_call_result() -> Annotated[CallToolResult, Per meta.convert_result(func_returning_annotated_tool_call_result()) -def test_tool_call_result_in_optional_is_rejected() -> None: +def test_tool_call_result_in_optional_is_rejected(): """Test that Optional[CallToolResult] raises InvalidSignature""" def func_optional_call_tool_result() -> CallToolResult | None: # pragma: no cover @@ -890,7 +888,7 @@ def func_optional_call_tool_result() -> CallToolResult | None: # pragma: no cov assert "CallToolResult" in str(exc_info.value) -def test_tool_call_result_in_union_is_rejected() -> None: +def test_tool_call_result_in_union_is_rejected(): """Test that Union[str, CallToolResult] raises InvalidSignature""" def func_union_call_tool_result() -> str | CallToolResult: # pragma: no cover @@ -903,7 +901,7 @@ def func_union_call_tool_result() -> str | CallToolResult: # pragma: no cover assert "CallToolResult" in str(exc_info.value) -def test_tool_call_result_in_pipe_union_is_rejected() -> None: +def test_tool_call_result_in_pipe_union_is_rejected(): """Test that str | CallToolResult raises InvalidSignature""" def func_pipe_union_call_tool_result() -> str | CallToolResult: # pragma: no cover @@ -916,7 +914,7 @@ def func_pipe_union_call_tool_result() -> str | CallToolResult: # pragma: no co assert "CallToolResult" in str(exc_info.value) -def test_structured_output_with_field_descriptions() -> None: +def test_structured_output_with_field_descriptions(): """Test that Field descriptions are preserved in structured output""" class ModelWithDescriptions(BaseModel): @@ -938,7 +936,7 @@ def func_with_descriptions() -> ModelWithDescriptions: # pragma: no cover } -def test_structured_output_nested_models() -> None: +def test_structured_output_nested_models(): """Test structured output with nested models""" class Address(BaseModel): @@ -977,7 +975,7 @@ def func_nested() -> PersonWithAddress: # pragma: no cover } -def test_structured_output_unserializable_type_error() -> None: +def test_structured_output_unserializable_type_error(): """Test error when structured_output=True is used with unserializable types""" # Test with a class that has non-serializable default values @@ -1018,7 +1016,7 @@ def func_returning_namedtuple() -> Point: # pragma: no cover assert "Point" in str(exc_info.value) -def test_structured_output_aliases() -> None: +def test_structured_output_aliases(): """Test that field aliases are consistent between schema and output""" class ModelWithAliases(BaseModel): @@ -1063,7 +1061,7 @@ def func_with_aliases() -> ModelWithAliases: # pragma: no cover assert structured_content_defaults["second"] is None -def test_basemodel_reserved_names() -> None: +def test_basemodel_reserved_names(): """Test that functions with parameters named after BaseModel methods work correctly""" def func_with_reserved_names( # pragma: no cover @@ -1091,7 +1089,7 @@ def func_with_reserved_names( # pragma: no cover @pytest.mark.anyio -async def test_basemodel_reserved_names_validation() -> None: +async def test_basemodel_reserved_names_validation(): """Test that validation and calling works with reserved parameter names""" def func_with_reserved_names( @@ -1149,7 +1147,7 @@ def func_with_reserved_names( assert dumped["normal_param"] == "test" -def test_basemodel_reserved_names_with_json_preparsing() -> None: +def test_basemodel_reserved_names_with_json_preparsing(): """Test that pre_parse_json works correctly with reserved parameter names""" def func_with_reserved_json( # pragma: no cover @@ -1175,7 +1173,7 @@ def func_with_reserved_json( # pragma: no cover assert result["normal"] == "plain string" -def test_disallowed_type_qualifier() -> None: +def test_disallowed_type_qualifier(): def func_disallowed_qualifier() -> Final[int]: # type: ignore pass # pragma: no cover @@ -1184,7 +1182,7 @@ def func_disallowed_qualifier() -> Final[int]: # type: ignore assert "return annotation contains an invalid type qualifier" in str(exc_info.value) -def test_preserves_pydantic_metadata() -> None: +def test_preserves_pydantic_metadata(): def func_with_metadata() -> Annotated[int, Field(gt=1)]: ... # pragma: no branch meta = func_metadata(func_with_metadata) diff --git a/tests/server/mcpserver/test_integration.py b/tests/server/mcpserver/test_integration.py index c6ca50685..f71c0574c 100644 --- a/tests/server/mcpserver/test_integration.py +++ b/tests/server/mcpserver/test_integration.py @@ -59,7 +59,7 @@ class NotificationCollector: """Collects notifications from the server for testing.""" - def __init__(self) -> None: + def __init__(self): self.progress_notifications: list[ProgressNotificationParams] = [] self.log_messages: list[LoggingMessageNotificationParams] = [] self.resource_notifications: list[NotificationParams | None] = [] @@ -94,7 +94,7 @@ async def sampling_callback( ) -async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: +async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): """Elicitation callback for tests.""" # For restaurant booking test if "No tables available" in params.message: @@ -184,9 +184,7 @@ async def test_tool_progress() -> None: """Test tool progress reporting.""" collector = NotificationCollector() - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: + async def message_handler(message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception): await collector.handle_generic_notification(message) if isinstance(message, Exception): # pragma: no cover raise message @@ -265,9 +263,7 @@ async def test_notifications() -> None: """Test notifications and logging functionality.""" collector = NotificationCollector() - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: + async def message_handler(message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception): await collector.handle_generic_notification(message) if isinstance(message, Exception): # pragma: no cover raise message diff --git a/tests/server/mcpserver/test_parameter_descriptions.py b/tests/server/mcpserver/test_parameter_descriptions.py index a47b29e08..ec9f22c25 100644 --- a/tests/server/mcpserver/test_parameter_descriptions.py +++ b/tests/server/mcpserver/test_parameter_descriptions.py @@ -7,7 +7,7 @@ @pytest.mark.anyio -async def test_parameter_descriptions() -> None: +async def test_parameter_descriptions(): mcp = MCPServer("Test Server") @mcp.tool() diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 205a63334..3ef06d038 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -1,6 +1,6 @@ import base64 from pathlib import Path -from typing import Any, NoReturn +from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -46,7 +46,7 @@ class TestServer: - async def test_create_server(self) -> None: + async def test_create_server(self): mcp = MCPServer( title="MCPServer Server", description="Server description", @@ -65,7 +65,7 @@ async def test_create_server(self) -> None: assert len(mcp.icons) == 1 assert mcp.icons[0].src == "https://example.com/icon.png" - async def test_sse_app_returns_starlette_app(self) -> None: + async def test_sse_app_returns_starlette_app(self): """Test that sse_app returns a Starlette application with correct routes.""" mcp = MCPServer("test") # Use host="0.0.0.0" to avoid auto DNS protection @@ -82,7 +82,7 @@ async def test_sse_app_returns_starlette_app(self) -> None: assert sse_routes[0].path == "/sse" assert mount_routes[0].path == "/messages" - async def test_non_ascii_description(self) -> None: + async def test_non_ascii_description(self): """Test that MCPServer handles non-ASCII characters in descriptions correctly""" mcp = MCPServer() @@ -105,7 +105,7 @@ def hello_world(name: str = "世界") -> str: assert isinstance(content, TextContent) assert "¡Hola, 世界! 👋" == content.text - async def test_add_tool_decorator(self) -> None: + async def test_add_tool_decorator(self): mcp = MCPServer() @mcp.tool() @@ -114,7 +114,7 @@ def sum(x: int, y: int) -> int: # pragma: no cover assert len(mcp._tool_manager.list_tools()) == 1 - async def test_add_tool_decorator_incorrect_usage(self) -> None: + async def test_add_tool_decorator_incorrect_usage(self): mcp = MCPServer() with pytest.raises(TypeError, match="The @tool decorator was used incorrectly"): @@ -123,7 +123,7 @@ async def test_add_tool_decorator_incorrect_usage(self) -> None: def sum(x: int, y: int) -> int: # pragma: no cover return x + y - async def test_add_resource_decorator(self) -> None: + async def test_add_resource_decorator(self): mcp = MCPServer() @mcp.resource("r://{x}") @@ -132,7 +132,7 @@ def get_data(x: str) -> str: # pragma: no cover assert len(mcp._resource_manager._templates) == 1 - async def test_add_resource_decorator_incorrect_usage(self) -> None: + async def test_add_resource_decorator_incorrect_usage(self): mcp = MCPServer() with pytest.raises(TypeError, match="The @resource decorator was used incorrectly"): @@ -149,7 +149,7 @@ class TestDnsRebindingProtection: based on the host parameter passed to those methods. """ - def test_auto_enabled_for_127_0_0_1_sse(self) -> None: + def test_auto_enabled_for_127_0_0_1_sse(self): """DNS rebinding protection should auto-enable for host=127.0.0.1 in SSE app.""" mcp = MCPServer() # Call sse_app with host=127.0.0.1 to trigger auto-config @@ -158,31 +158,31 @@ def test_auto_enabled_for_127_0_0_1_sse(self) -> None: app = mcp.sse_app(host="127.0.0.1") assert app is not None - def test_auto_enabled_for_127_0_0_1_streamable_http(self) -> None: + def test_auto_enabled_for_127_0_0_1_streamable_http(self): """DNS rebinding protection should auto-enable for host=127.0.0.1 in StreamableHTTP app.""" mcp = MCPServer() app = mcp.streamable_http_app(host="127.0.0.1") assert app is not None - def test_auto_enabled_for_localhost_sse(self) -> None: + def test_auto_enabled_for_localhost_sse(self): """DNS rebinding protection should auto-enable for host=localhost in SSE app.""" mcp = MCPServer() app = mcp.sse_app(host="localhost") assert app is not None - def test_auto_enabled_for_ipv6_localhost_sse(self) -> None: + def test_auto_enabled_for_ipv6_localhost_sse(self): """DNS rebinding protection should auto-enable for host=::1 (IPv6 localhost) in SSE app.""" mcp = MCPServer() app = mcp.sse_app(host="::1") assert app is not None - def test_not_auto_enabled_for_other_hosts_sse(self) -> None: + def test_not_auto_enabled_for_other_hosts_sse(self): """DNS rebinding protection should NOT auto-enable for other hosts in SSE app.""" mcp = MCPServer() app = mcp.sse_app(host="0.0.0.0") assert app is not None - def test_explicit_settings_not_overridden_sse(self) -> None: + def test_explicit_settings_not_overridden_sse(self): """Explicit transport_security settings should not be overridden in SSE app.""" custom_settings = TransportSecuritySettings( enable_dns_rebinding_protection=False, @@ -192,7 +192,7 @@ def test_explicit_settings_not_overridden_sse(self) -> None: app = mcp.sse_app(host="127.0.0.1", transport_security=custom_settings) assert app is not None - def test_explicit_settings_not_overridden_streamable_http(self) -> None: + def test_explicit_settings_not_overridden_streamable_http(self): """Explicit transport_security settings should not be overridden in StreamableHTTP app.""" custom_settings = TransportSecuritySettings( enable_dns_rebinding_protection=False, @@ -228,20 +228,20 @@ def mixed_content_tool_fn() -> list[ContentBlock]: class TestServerTools: - async def test_add_tool(self) -> None: + async def test_add_tool(self): mcp = MCPServer() mcp.add_tool(tool_fn) mcp.add_tool(tool_fn) assert len(mcp._tool_manager.list_tools()) == 1 - async def test_list_tools(self) -> None: + async def test_list_tools(self): mcp = MCPServer() mcp.add_tool(tool_fn) async with Client(mcp) as client: tools = await client.list_tools() assert len(tools.tools) == 1 - async def test_call_tool(self) -> None: + async def test_call_tool(self): mcp = MCPServer() mcp.add_tool(tool_fn) async with Client(mcp) as client: @@ -249,7 +249,7 @@ async def test_call_tool(self) -> None: assert not hasattr(result, "error") assert len(result.content) > 0 - async def test_tool_exception_handling(self) -> None: + async def test_tool_exception_handling(self): mcp = MCPServer() mcp.add_tool(error_tool_fn) async with Client(mcp) as client: @@ -260,7 +260,7 @@ async def test_tool_exception_handling(self) -> None: assert "Test error" in content.text assert result.is_error is True - async def test_tool_error_handling(self) -> None: + async def test_tool_error_handling(self): mcp = MCPServer() mcp.add_tool(error_tool_fn) async with Client(mcp) as client: @@ -271,7 +271,7 @@ async def test_tool_error_handling(self) -> None: assert "Test error" in content.text assert result.is_error is True - async def test_tool_error_details(self) -> None: + async def test_tool_error_details(self): """Test that exception details are properly formatted in the response""" mcp = MCPServer() mcp.add_tool(error_tool_fn) @@ -283,7 +283,7 @@ async def test_tool_error_details(self) -> None: assert "Test error" in content.text assert result.is_error is True - async def test_tool_return_value_conversion(self) -> None: + async def test_tool_return_value_conversion(self): mcp = MCPServer() mcp.add_tool(tool_fn) async with Client(mcp) as client: @@ -296,7 +296,7 @@ async def test_tool_return_value_conversion(self) -> None: assert result.structured_content is not None assert result.structured_content == {"result": 3} - async def test_tool_image_helper(self, tmp_path: Path) -> None: + async def test_tool_image_helper(self, tmp_path: Path): # Create a test image image_path = tmp_path / "test.png" image_path.write_bytes(b"fake png data") @@ -316,7 +316,7 @@ async def test_tool_image_helper(self, tmp_path: Path) -> None: # Check structured content - Image return type should NOT have structured output assert result.structured_content is None - async def test_tool_audio_helper(self, tmp_path: Path) -> None: + async def test_tool_audio_helper(self, tmp_path: Path): # Create a test audio audio_path = tmp_path / "test.wav" audio_path.write_bytes(b"fake wav data") @@ -348,7 +348,7 @@ async def test_tool_audio_helper(self, tmp_path: Path) -> None: ("test.unknown", "application/octet-stream"), # Unknown extension fallback ], ) - async def test_tool_audio_suffix_detection(self, tmp_path: Path, filename: str, expected_mime_type: str) -> None: + async def test_tool_audio_suffix_detection(self, tmp_path: Path, filename: str, expected_mime_type: str): """Test that Audio helper correctly detects MIME types from file suffixes""" mcp = MCPServer() mcp.add_tool(audio_tool_fn) @@ -368,7 +368,7 @@ async def test_tool_audio_suffix_detection(self, tmp_path: Path, filename: str, decoded = base64.b64decode(content.data) assert decoded == b"fake audio data" - async def test_tool_mixed_content(self) -> None: + async def test_tool_mixed_content(self): mcp = MCPServer() mcp.add_tool(mixed_content_tool_fn) async with Client(mcp) as client: @@ -398,7 +398,7 @@ async def test_tool_mixed_content(self) -> None: for key, value in expected.items(): assert structured_result[i][key] == value - async def test_tool_mixed_list_with_audio_and_image(self, tmp_path: Path) -> None: + async def test_tool_mixed_list_with_audio_and_image(self, tmp_path: Path): """Test that lists containing Image objects and other types are handled correctly""" # Create a test image @@ -450,7 +450,7 @@ def mixed_list_fn() -> list: # type: ignore # Check structured content - untyped list with Image objects should NOT have structured output assert result.structured_content is None - async def test_tool_structured_output_basemodel(self) -> None: + async def test_tool_structured_output_basemodel(self): """Test tool with structured output returning BaseModel""" class UserOutput(BaseModel): @@ -484,7 +484,7 @@ def get_user(user_id: int) -> UserOutput: assert isinstance(result.content[0], TextContent) assert '"name": "John Doe"' in result.content[0].text - async def test_tool_structured_output_primitive(self) -> None: + async def test_tool_structured_output_primitive(self): """Test tool with structured output returning primitive type""" def calculate_sum(a: int, b: int) -> int: @@ -510,7 +510,7 @@ def calculate_sum(a: int, b: int) -> int: assert result.structured_content is not None assert result.structured_content == {"result": 12} - async def test_tool_structured_output_list(self) -> None: + async def test_tool_structured_output_list(self): """Test tool with structured output returning list""" def get_numbers() -> list[int]: @@ -526,7 +526,7 @@ def get_numbers() -> list[int]: assert result.structured_content is not None assert result.structured_content == {"result": [1, 2, 3, 4, 5]} - async def test_tool_structured_output_server_side_validation_error(self) -> None: + async def test_tool_structured_output_server_side_validation_error(self): """Test that server-side validation errors are handled properly""" def get_numbers() -> list[int]: @@ -542,7 +542,7 @@ def get_numbers() -> list[int]: assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) - async def test_tool_structured_output_dict_str_any(self) -> None: + async def test_tool_structured_output_dict_str_any(self): """Test tool with dict[str, Any] structured output""" def get_metadata() -> dict[str, Any]: @@ -583,7 +583,7 @@ def get_metadata() -> dict[str, Any]: } assert result.structured_content == expected - async def test_tool_structured_output_dict_str_typed(self) -> None: + async def test_tool_structured_output_dict_str_typed(self): """Test tool with dict[str, T] structured output for specific T""" def get_settings() -> dict[str, str]: @@ -606,7 +606,7 @@ def get_settings() -> dict[str, str]: assert result.is_error is False assert result.structured_content == {"theme": "dark", "language": "en", "timezone": "UTC"} - async def test_remove_tool(self) -> None: + async def test_remove_tool(self): """Test removing a tool from the server.""" mcp = MCPServer() mcp.add_tool(tool_fn) @@ -620,14 +620,14 @@ async def test_remove_tool(self) -> None: # Verify tool is removed assert len(mcp._tool_manager.list_tools()) == 0 - async def test_remove_nonexistent_tool(self) -> None: + async def test_remove_nonexistent_tool(self): """Test that removing a non-existent tool raises ToolError.""" mcp = MCPServer() with pytest.raises(ToolError, match="Unknown tool: nonexistent"): mcp.remove_tool("nonexistent") - async def test_remove_tool_and_list(self) -> None: + async def test_remove_tool_and_list(self): """Test that a removed tool doesn't appear in list_tools.""" mcp = MCPServer() mcp.add_tool(tool_fn) @@ -650,7 +650,7 @@ async def test_remove_tool_and_list(self) -> None: assert len(tools.tools) == 1 assert tools.tools[0].name == "error_tool_fn" - async def test_remove_tool_and_call(self) -> None: + async def test_remove_tool_and_call(self): """Test that calling a removed tool fails appropriately.""" mcp = MCPServer() mcp.add_tool(tool_fn) @@ -676,10 +676,10 @@ async def test_remove_tool_and_call(self) -> None: class TestServerResources: - async def test_text_resource(self) -> None: + async def test_text_resource(self): mcp = MCPServer() - def get_text() -> str: + def get_text(): return "Hello, world!" resource = FunctionResource(uri="resource://test", name="test", fn=get_text) @@ -691,7 +691,7 @@ def get_text() -> str: assert isinstance(result.contents[0], TextResourceContents) assert result.contents[0].text == "Hello, world!" - async def test_read_unknown_resource(self) -> None: + async def test_read_unknown_resource(self): """Test that reading an unknown resource raises MCPError.""" mcp = MCPServer() @@ -699,22 +699,22 @@ async def test_read_unknown_resource(self) -> None: with pytest.raises(MCPError, match="Unknown resource: unknown://missing"): await client.read_resource("unknown://missing") - async def test_read_resource_error(self) -> None: + async def test_read_resource_error(self): """Test that resource read errors are properly wrapped in MCPError.""" mcp = MCPServer() @mcp.resource("resource://failing") - def failing_resource() -> NoReturn: + def failing_resource(): raise ValueError("Resource read failed") async with Client(mcp) as client: with pytest.raises(MCPError, match="Error reading resource resource://failing"): await client.read_resource("resource://failing") - async def test_binary_resource(self) -> None: + async def test_binary_resource(self): mcp = MCPServer() - def get_binary() -> bytes: + def get_binary(): return b"Binary data" resource = FunctionResource( @@ -731,7 +731,7 @@ def get_binary() -> bytes: assert isinstance(result.contents[0], BlobResourceContents) assert result.contents[0].blob == base64.b64encode(b"Binary data").decode() - async def test_file_resource_text(self, tmp_path: Path) -> None: + async def test_file_resource_text(self, tmp_path: Path): mcp = MCPServer() # Create a text file @@ -747,7 +747,7 @@ async def test_file_resource_text(self, tmp_path: Path) -> None: assert isinstance(result.contents[0], TextResourceContents) assert result.contents[0].text == "Hello from file!" - async def test_file_resource_binary(self, tmp_path: Path) -> None: + async def test_file_resource_binary(self, tmp_path: Path): mcp = MCPServer() # Create a binary file @@ -768,7 +768,7 @@ async def test_file_resource_binary(self, tmp_path: Path) -> None: assert isinstance(result.contents[0], BlobResourceContents) assert result.contents[0].blob == base64.b64encode(b"Binary file data").decode() - async def test_function_resource(self) -> None: + async def test_function_resource(self): mcp = MCPServer() @mcp.resource("function://test", name="test_get_data") @@ -787,7 +787,7 @@ def get_data() -> str: # pragma: no cover class TestServerResourceTemplates: - async def test_resource_with_params(self) -> None: + async def test_resource_with_params(self): """Test that a resource with function parameters raises an error if the URI parameters don't match""" mcp = MCPServer() @@ -798,7 +798,7 @@ async def test_resource_with_params(self) -> None: def get_data_fn(param: str) -> str: # pragma: no cover return f"Data: {param}" - async def test_resource_with_uri_params(self) -> None: + async def test_resource_with_uri_params(self): """Test that a resource with URI parameters is automatically a template""" mcp = MCPServer() @@ -808,7 +808,7 @@ async def test_resource_with_uri_params(self) -> None: def get_data() -> str: # pragma: no cover return "Data" - async def test_resource_with_untyped_params(self) -> None: + async def test_resource_with_untyped_params(self): """Test that a resource with untyped parameters raises an error""" mcp = MCPServer() @@ -816,7 +816,7 @@ async def test_resource_with_untyped_params(self) -> None: def get_data(param) -> str: # type: ignore # pragma: no cover return "Data" - async def test_resource_matching_params(self) -> None: + async def test_resource_matching_params(self): """Test that a resource with matching URI and function parameters works""" mcp = MCPServer() @@ -830,7 +830,7 @@ def get_data(name: str) -> str: assert isinstance(result.contents[0], TextResourceContents) assert result.contents[0].text == "Data for test" - async def test_resource_mismatched_params(self) -> None: + async def test_resource_mismatched_params(self): """Test that mismatched parameters raise an error""" mcp = MCPServer() @@ -840,7 +840,7 @@ async def test_resource_mismatched_params(self) -> None: def get_data(user: str) -> str: # pragma: no cover return f"Data for {user}" - async def test_resource_multiple_params(self) -> None: + async def test_resource_multiple_params(self): """Test that multiple parameters work correctly""" mcp = MCPServer() @@ -854,7 +854,7 @@ def get_data(org: str, repo: str) -> str: assert isinstance(result.contents[0], TextResourceContents) assert result.contents[0].text == "Data for cursor/myrepo" - async def test_resource_multiple_mismatched_params(self) -> None: + async def test_resource_multiple_mismatched_params(self): """Test that mismatched parameters raise an error""" mcp = MCPServer() @@ -877,7 +877,7 @@ def get_static_data() -> str: assert isinstance(result.contents[0], TextResourceContents) assert result.contents[0].text == "Static data" - async def test_template_to_resource_conversion(self) -> None: + async def test_template_to_resource_conversion(self): """Test that templates are properly converted to resources when accessed""" mcp = MCPServer() @@ -895,7 +895,7 @@ def get_data(name: str) -> str: result = await resource.read() assert result == "Data for test" - async def test_resource_template_includes_mime_type(self) -> None: + async def test_resource_template_includes_mime_type(self): """Test that list resource templates includes the correct mimeType.""" mcp = MCPServer() @@ -928,7 +928,7 @@ class TestServerResourceMetadata: Note: read_resource does NOT pass meta to protocol response (lowlevel/server.py only extracts content/mime_type). """ - async def test_resource_decorator_with_metadata(self) -> None: + async def test_resource_decorator_with_metadata(self): """Test that @resource decorator accepts and passes meta parameter.""" # Tests static resource flow: decorator -> FunctionResource -> list_resources (server.py:544,635,361) mcp = MCPServer() @@ -949,7 +949,7 @@ def get_config() -> str: ... # pragma: no branch ] ) - async def test_resource_template_decorator_with_metadata(self) -> None: + async def test_resource_template_decorator_with_metadata(self): """Test that @resource decorator passes meta to templates.""" # Tests template resource flow: decorator -> add_template() -> list_resource_templates (server.py:544,622,377) mcp = MCPServer() @@ -970,7 +970,7 @@ def get_weather(city: str) -> str: ... # pragma: no branch ] ) - async def test_read_resource_returns_meta(self) -> None: + async def test_read_resource_returns_meta(self): """Test that read_resource includes meta in response.""" # Tests end-to-end: Resource.meta -> ReadResourceContents.meta -> protocol _meta (lowlevel/server.py:341,371) mcp = MCPServer() @@ -998,7 +998,7 @@ def get_data() -> str: class TestContextInjection: """Test context injection in tools, resources, and prompts.""" - async def test_context_detection(self) -> None: + async def test_context_detection(self): """Test that context parameters are properly detected.""" mcp = MCPServer() @@ -1008,7 +1008,7 @@ def tool_with_context(x: int, ctx: Context) -> str: # pragma: no cover tool = mcp._tool_manager.add_tool(tool_with_context) assert tool.context_kwarg == "ctx" - async def test_context_injection(self) -> None: + async def test_context_injection(self): """Test that context is properly injected into tool calls.""" mcp = MCPServer() @@ -1025,7 +1025,7 @@ def tool_with_context(x: int, ctx: Context) -> str: assert "Request" in content.text assert "42" in content.text - async def test_async_context(self) -> None: + async def test_async_context(self): """Test that context works in async functions.""" mcp = MCPServer() @@ -1042,7 +1042,7 @@ async def async_tool(x: int, ctx: Context) -> str: assert "Async request" in content.text assert "42" in content.text - async def test_context_logging(self) -> None: + async def test_context_logging(self): """Test that context logging methods work.""" mcp = MCPServer() @@ -1069,7 +1069,7 @@ async def logging_tool(msg: str, ctx: Context) -> str: mock_log.assert_any_call(level="warning", data="Warning message", logger=None, related_request_id="1") mock_log.assert_any_call(level="error", data="Error message", logger=None, related_request_id="1") - async def test_optional_context(self) -> None: + async def test_optional_context(self): """Test that context is optional.""" mcp = MCPServer() @@ -1084,7 +1084,7 @@ def no_context(x: int) -> int: assert isinstance(content, TextContent) assert content.text == "42" - async def test_context_resource_access(self) -> None: + async def test_context_resource_access(self): """Test that context can access resources.""" mcp = MCPServer() @@ -1107,7 +1107,7 @@ async def tool_with_resource(ctx: Context) -> str: assert isinstance(content, TextContent) assert "Read resource: resource data" in content.text - async def test_resource_with_context(self) -> None: + async def test_resource_with_context(self): """Test that resources can receive context parameter.""" mcp = MCPServer() @@ -1133,7 +1133,7 @@ def resource_with_context(name: str, ctx: Context) -> str: # Should have either request_id or indication that context was injected assert "Resource test - context injected" == content.text - async def test_resource_without_context(self) -> None: + async def test_resource_without_context(self): """Test that resources without context work normally.""" mcp = MCPServer() @@ -1160,7 +1160,7 @@ def resource_no_context(name: str) -> str: ) ) - async def test_resource_context_custom_name(self) -> None: + async def test_resource_context_custom_name(self): """Test resource context with custom parameter name.""" mcp = MCPServer() @@ -1188,7 +1188,7 @@ def resource_custom_ctx(id: str, my_ctx: Context) -> str: ) ) - async def test_prompt_with_context(self) -> None: + async def test_prompt_with_context(self): """Test that prompts can receive context parameter.""" mcp = MCPServer() @@ -1208,7 +1208,7 @@ def prompt_with_context(text: str, ctx: Context) -> str: assert isinstance(content, TextContent) assert "Prompt 'test' - context injected" in content.text - async def test_prompt_without_context(self) -> None: + async def test_prompt_without_context(self): """Test that prompts without context work normally.""" mcp = MCPServer() @@ -1230,7 +1230,7 @@ def prompt_no_context(text: str) -> str: class TestServerPrompts: """Test prompt functionality in MCPServer server.""" - async def test_get_prompt_direct_call_without_context(self) -> None: + async def test_get_prompt_direct_call_without_context(self): """Test calling mcp.get_prompt() directly without passing context.""" mcp = MCPServer() @@ -1243,7 +1243,7 @@ def fn() -> str: assert isinstance(content, TextContent) assert content.text == "Hello, world!" - async def test_prompt_decorator(self) -> None: + async def test_prompt_decorator(self): """Test that the prompt decorator registers prompts correctly.""" mcp = MCPServer() @@ -1259,7 +1259,7 @@ def fn() -> str: assert isinstance(content[0].content, TextContent) assert content[0].content.text == "Hello, world!" - async def test_prompt_decorator_with_name(self) -> None: + async def test_prompt_decorator_with_name(self): """Test prompt decorator with custom name.""" mcp = MCPServer() @@ -1274,7 +1274,7 @@ def fn() -> str: assert isinstance(content[0].content, TextContent) assert content[0].content.text == "Hello, world!" - async def test_prompt_decorator_with_description(self) -> None: + async def test_prompt_decorator_with_description(self): """Test prompt decorator with custom description.""" mcp = MCPServer() @@ -1289,7 +1289,7 @@ def fn() -> str: assert isinstance(content[0].content, TextContent) assert content[0].content.text == "Hello, world!" - def test_prompt_decorator_error(self) -> None: + def test_prompt_decorator_error(self): """Test error when decorator is used incorrectly.""" mcp = MCPServer() with pytest.raises(TypeError, match="decorator was used incorrectly"): @@ -1297,7 +1297,7 @@ def test_prompt_decorator_error(self) -> None: @mcp.prompt # type: ignore def fn() -> str: ... # pragma: no branch - async def test_list_prompts(self) -> None: + async def test_list_prompts(self): """Test listing prompts through MCP protocol.""" mcp = MCPServer() @@ -1321,7 +1321,7 @@ def fn(name: str, optional: str = "default") -> str: ... # pragma: no branch ) ) - async def test_get_prompt(self) -> None: + async def test_get_prompt(self): """Test getting a prompt through MCP protocol.""" mcp = MCPServer() @@ -1338,7 +1338,7 @@ def fn(name: str) -> str: ) ) - async def test_get_prompt_with_description(self) -> None: + async def test_get_prompt_with_description(self): """Test getting a prompt through MCP protocol.""" mcp = MCPServer() @@ -1350,7 +1350,7 @@ def fn(name: str) -> str: result = await client.get_prompt("fn", {"name": "World"}) assert result.description == "Test prompt description" - async def test_get_prompt_with_docstring_description(self) -> None: + async def test_get_prompt_with_docstring_description(self): """Test prompt uses docstring as description when not explicitly provided.""" mcp = MCPServer() @@ -1368,7 +1368,7 @@ def fn(name: str) -> str: ) ) - async def test_get_prompt_with_resource(self) -> None: + async def test_get_prompt_with_resource(self): """Test getting a prompt that returns resource content.""" mcp = MCPServer() @@ -1399,7 +1399,7 @@ def fn() -> Message: ) ) - async def test_get_unknown_prompt(self) -> None: + async def test_get_unknown_prompt(self): """Test error when getting unknown prompt.""" mcp = MCPServer() @@ -1407,7 +1407,7 @@ async def test_get_unknown_prompt(self) -> None: with pytest.raises(MCPError, match="Unknown prompt"): await client.get_prompt("unknown") - async def test_get_prompt_missing_args(self) -> None: + async def test_get_prompt_missing_args(self): """Test error when required arguments are missing.""" mcp = MCPServer() @@ -1452,7 +1452,7 @@ def test_streamable_http_no_redirect() -> None: assert streamable_routes[0].path == "/mcp", "Streamable route path should be /mcp" -async def test_report_progress_passes_related_request_id() -> None: +async def test_report_progress_passes_related_request_id(): """Test that report_progress passes the request_id as related_request_id. Without related_request_id, the streamable HTTP transport cannot route diff --git a/tests/server/mcpserver/test_title.py b/tests/server/mcpserver/test_title.py index 70218fff1..662464757 100644 --- a/tests/server/mcpserver/test_title.py +++ b/tests/server/mcpserver/test_title.py @@ -10,7 +10,7 @@ @pytest.mark.anyio -async def test_server_name_title_description_version() -> None: +async def test_server_name_title_description_version(): """Test that server title and description are set and retrievable correctly.""" mcp = MCPServer( name="TestServer", @@ -34,7 +34,7 @@ async def test_server_name_title_description_version() -> None: @pytest.mark.anyio -async def test_tool_title_precedence() -> None: +async def test_tool_title_precedence(): """Test that tool title precedence works correctly: title > annotations.title > name.""" # Create server with various tool configurations mcp = MCPServer(name="TitleTestServer") @@ -88,7 +88,7 @@ def tool_with_both(message: str) -> str: # pragma: no cover @pytest.mark.anyio -async def test_prompt_title() -> None: +async def test_prompt_title(): """Test that prompt titles work correctly.""" mcp = MCPServer(name="PromptTitleServer") @@ -121,7 +121,7 @@ def titled_prompt(topic: str) -> str: # pragma: no cover @pytest.mark.anyio -async def test_resource_title() -> None: +async def test_resource_title(): """Test that resource titles work correctly.""" mcp = MCPServer(name="ResourceTitleServer") @@ -194,7 +194,7 @@ def titled_dynamic_resource(id: str) -> str: # pragma: no cover @pytest.mark.anyio -async def test_get_display_name_utility() -> None: +async def test_get_display_name_utility(): """Test the get_display_name utility function.""" # Test tool precedence: title > annotations.title > name diff --git a/tests/server/mcpserver/test_tool_manager.py b/tests/server/mcpserver/test_tool_manager.py index 781386c2e..e4dfd4ff9 100644 --- a/tests/server/mcpserver/test_tool_manager.py +++ b/tests/server/mcpserver/test_tool_manager.py @@ -15,7 +15,7 @@ class TestAddTools: - def test_basic_function(self) -> None: + def test_basic_function(self): """Test registering and running a basic function.""" def sum(a: int, b: int) -> int: # pragma: no cover @@ -33,7 +33,7 @@ def sum(a: int, b: int) -> int: # pragma: no cover assert tool.parameters["properties"]["a"]["type"] == "integer" assert tool.parameters["properties"]["b"]["type"] == "integer" - def test_init_with_tools(self, caplog: pytest.LogCaptureFixture) -> None: + def test_init_with_tools(self, caplog: pytest.LogCaptureFixture): def sum(a: int, b: int) -> int: # pragma: no cover return a + b @@ -64,7 +64,7 @@ class AddArguments(ArgModelBase): assert "Tool already exists: sum" in caplog.text @pytest.mark.anyio - async def test_async_function(self) -> None: + async def test_async_function(self): """Test registering and running an async function.""" async def fetch_data(url: str) -> str: # pragma: no cover @@ -81,7 +81,7 @@ async def fetch_data(url: str) -> str: # pragma: no cover assert tool.is_async is True assert tool.parameters["properties"]["url"]["type"] == "string" - def test_pydantic_model_function(self) -> None: + def test_pydantic_model_function(self): """Test registering a function that takes a Pydantic model.""" class UserInput(BaseModel): @@ -104,11 +104,11 @@ def create_user(user: UserInput, flag: bool) -> dict[str, Any]: # pragma: no co assert "age" in tool.parameters["$defs"]["UserInput"]["properties"] assert "flag" in tool.parameters["properties"] - def test_add_callable_object(self) -> None: + def test_add_callable_object(self): """Test registering a callable object.""" class MyTool: - def __init__(self) -> None: + def __init__(self): self.__name__ = "MyTool" def __call__(self, x: int) -> int: # pragma: no cover @@ -121,11 +121,11 @@ def __call__(self, x: int) -> int: # pragma: no cover assert tool.parameters["properties"]["x"]["type"] == "integer" @pytest.mark.anyio - async def test_add_async_callable_object(self) -> None: + async def test_add_async_callable_object(self): """Test registering an async callable object.""" class MyAsyncTool: - def __init__(self) -> None: + def __init__(self): self.__name__ = "MyAsyncTool" async def __call__(self, x: int) -> int: # pragma: no cover @@ -137,22 +137,22 @@ async def __call__(self, x: int) -> int: # pragma: no cover assert tool.is_async is True assert tool.parameters["properties"]["x"]["type"] == "integer" - def test_add_invalid_tool(self) -> None: + def test_add_invalid_tool(self): manager = ToolManager() with pytest.raises(AttributeError): manager.add_tool(1) # type: ignore - def test_add_lambda(self) -> None: + def test_add_lambda(self): manager = ToolManager() tool = manager.add_tool(lambda x: x, name="my_tool") # type: ignore[reportUnknownLambdaType] assert tool.name == "my_tool" - def test_add_lambda_with_no_name(self) -> None: + def test_add_lambda_with_no_name(self): manager = ToolManager() with pytest.raises(ValueError, match="You must provide a name for lambda functions"): manager.add_tool(lambda x: x) # type: ignore[reportUnknownLambdaType] - def test_warn_on_duplicate_tools(self, caplog: pytest.LogCaptureFixture) -> None: + def test_warn_on_duplicate_tools(self, caplog: pytest.LogCaptureFixture): """Test warning on duplicate tools.""" def f(x: int) -> int: # pragma: no cover @@ -164,7 +164,7 @@ def f(x: int) -> int: # pragma: no cover manager.add_tool(f) assert "Tool already exists: f" in caplog.text - def test_disable_warn_on_duplicate_tools(self, caplog: pytest.LogCaptureFixture) -> None: + def test_disable_warn_on_duplicate_tools(self, caplog: pytest.LogCaptureFixture): """Test disabling warning on duplicate tools.""" def f(x: int) -> int: # pragma: no cover @@ -180,7 +180,7 @@ def f(x: int) -> int: # pragma: no cover class TestCallTools: @pytest.mark.anyio - async def test_call_tool(self) -> None: + async def test_call_tool(self): def sum(a: int, b: int) -> int: """Add two numbers.""" return a + b @@ -191,7 +191,7 @@ def sum(a: int, b: int) -> int: assert result == 3 @pytest.mark.anyio - async def test_call_async_tool(self) -> None: + async def test_call_async_tool(self): async def double(n: int) -> int: """Double a number.""" return n * 2 @@ -202,9 +202,9 @@ async def double(n: int) -> int: assert result == 10 @pytest.mark.anyio - async def test_call_object_tool(self) -> None: + async def test_call_object_tool(self): class MyTool: - def __init__(self) -> None: + def __init__(self): self.__name__ = "MyTool" def __call__(self, x: int) -> int: @@ -216,9 +216,9 @@ def __call__(self, x: int) -> int: assert result == 10 @pytest.mark.anyio - async def test_call_async_object_tool(self) -> None: + async def test_call_async_object_tool(self): class MyAsyncTool: - def __init__(self) -> None: + def __init__(self): self.__name__ = "MyAsyncTool" async def __call__(self, x: int) -> int: @@ -230,7 +230,7 @@ async def __call__(self, x: int) -> int: assert result == 10 @pytest.mark.anyio - async def test_call_tool_with_default_args(self) -> None: + async def test_call_tool_with_default_args(self): def sum(a: int, b: int = 1) -> int: """Add two numbers.""" return a + b @@ -241,7 +241,7 @@ def sum(a: int, b: int = 1) -> int: assert result == 2 @pytest.mark.anyio - async def test_call_tool_with_missing_args(self) -> None: + async def test_call_tool_with_missing_args(self): def sum(a: int, b: int) -> int: # pragma: no cover """Add two numbers.""" return a + b @@ -252,13 +252,13 @@ def sum(a: int, b: int) -> int: # pragma: no cover await manager.call_tool("sum", {"a": 1}, Context()) @pytest.mark.anyio - async def test_call_unknown_tool(self) -> None: + async def test_call_unknown_tool(self): manager = ToolManager() with pytest.raises(ToolError): await manager.call_tool("unknown", {"a": 1}, Context()) @pytest.mark.anyio - async def test_call_tool_with_list_int_input(self) -> None: + async def test_call_tool_with_list_int_input(self): def sum_vals(vals: list[int]) -> int: return sum(vals) @@ -271,7 +271,7 @@ def sum_vals(vals: list[int]) -> int: assert result == 6 @pytest.mark.anyio - async def test_call_tool_with_list_str_or_str_input(self) -> None: + async def test_call_tool_with_list_str_or_str_input(self): def concat_strs(vals: list[str] | str) -> str: return vals if isinstance(vals, str) else "".join(vals) @@ -288,7 +288,7 @@ def concat_strs(vals: list[str] | str) -> str: assert result == '"a"' @pytest.mark.anyio - async def test_call_tool_with_complex_model(self) -> None: + async def test_call_tool_with_complex_model(self): class MyShrimpTank(BaseModel): class Shrimp(BaseModel): name: str @@ -317,7 +317,7 @@ def name_shrimp(tank: MyShrimpTank) -> list[str]: class TestToolSchema: @pytest.mark.anyio - async def test_context_arg_excluded_from_schema(self) -> None: + async def test_context_arg_excluded_from_schema(self): def something(a: int, ctx: Context) -> int: # pragma: no cover return a @@ -331,7 +331,7 @@ def something(a: int, ctx: Context) -> int: # pragma: no cover class TestContextHandling: """Test context handling in the tool manager.""" - def test_context_parameter_detection(self) -> None: + def test_context_parameter_detection(self): """Test that context parameters are properly detected in Tool.from_function().""" @@ -355,7 +355,7 @@ def tool_with_parametrized_context(x: int, ctx: Context[LifespanContextT, Reques assert tool.context_kwarg == "ctx" @pytest.mark.anyio - async def test_context_injection(self) -> None: + async def test_context_injection(self): """Test that context is properly injected during tool execution.""" def tool_with_context(x: int, ctx: Context) -> str: @@ -369,7 +369,7 @@ def tool_with_context(x: int, ctx: Context) -> str: assert result == "42" @pytest.mark.anyio - async def test_context_injection_async(self) -> None: + async def test_context_injection_async(self): """Test that context is properly injected in async tools.""" async def async_tool(x: int, ctx: Context) -> str: @@ -383,7 +383,7 @@ async def async_tool(x: int, ctx: Context) -> str: assert result == "42" @pytest.mark.anyio - async def test_context_error_handling(self) -> None: + async def test_context_error_handling(self): """Test error handling when context injection fails.""" def tool_with_context(x: int, ctx: Context) -> str: @@ -397,7 +397,7 @@ def tool_with_context(x: int, ctx: Context) -> str: class TestToolAnnotations: - def test_tool_annotations(self) -> None: + def test_tool_annotations(self): """Test that tool annotations are correctly added to tools.""" def read_data(path: str) -> str: # pragma: no cover @@ -419,7 +419,7 @@ def read_data(path: str) -> str: # pragma: no cover assert tool.annotations.open_world_hint is False @pytest.mark.anyio - async def test_tool_annotations_in_mcpserver(self) -> None: + async def test_tool_annotations_in_mcpserver(self): """Test that tool annotations are included in MCPTool conversion.""" app = MCPServer() @@ -440,7 +440,7 @@ class TestStructuredOutput: """Test structured output functionality in tools.""" @pytest.mark.anyio - async def test_tool_with_basemodel_output(self) -> None: + async def test_tool_with_basemodel_output(self): """Test tool with BaseModel return type.""" class UserOutput(BaseModel): @@ -458,7 +458,7 @@ def get_user(user_id: int) -> UserOutput: assert len(result) == 2 and result[1] == {"name": "John", "age": 30} @pytest.mark.anyio - async def test_tool_with_primitive_output(self) -> None: + async def test_tool_with_primitive_output(self): """Test tool with primitive return type.""" def double_number(n: int) -> int: @@ -473,7 +473,7 @@ def double_number(n: int) -> int: assert isinstance(result[0][0], TextContent) and result[1] == {"result": 10} @pytest.mark.anyio - async def test_tool_with_typeddict_output(self) -> None: + async def test_tool_with_typeddict_output(self): """Test tool with TypedDict return type.""" class UserDict(TypedDict): @@ -492,7 +492,7 @@ def get_user_dict(user_id: int) -> UserDict: assert result == expected_output @pytest.mark.anyio - async def test_tool_with_dataclass_output(self) -> None: + async def test_tool_with_dataclass_output(self): """Test tool with dataclass return type.""" @dataclass @@ -513,7 +513,7 @@ def get_person() -> Person: assert len(result) == 2 and result[1] == expected_output @pytest.mark.anyio - async def test_tool_with_list_output(self) -> None: + async def test_tool_with_list_output(self): """Test tool with list return type.""" expected_list = [1, 2, 3, 4, 5] @@ -531,7 +531,7 @@ def get_numbers() -> list[int]: assert isinstance(result[0][0], TextContent) and result[1] == expected_output @pytest.mark.anyio - async def test_tool_without_structured_output(self) -> None: + async def test_tool_without_structured_output(self): """Test that tools work normally when structured_output=False.""" def get_dict() -> dict[str, Any]: @@ -544,7 +544,7 @@ def get_dict() -> dict[str, Any]: assert isinstance(result, dict) assert result == {"key": "value"} - def test_tool_output_schema_property(self) -> None: + def test_tool_output_schema_property(self): """Test that Tool.output_schema property works correctly.""" class UserOutput(BaseModel): @@ -567,7 +567,7 @@ def get_user() -> UserOutput: # pragma: no cover assert tool.output_schema == expected_schema @pytest.mark.anyio - async def test_tool_with_dict_str_any_output(self) -> None: + async def test_tool_with_dict_str_any_output(self): """Test tool with dict[str, Any] return type.""" def get_config() -> dict[str, Any]: @@ -592,7 +592,7 @@ def get_config() -> dict[str, Any]: assert result == expected @pytest.mark.anyio - async def test_tool_with_dict_str_typed_output(self) -> None: + async def test_tool_with_dict_str_typed_output(self): """Test tool with dict[str, T] return type for specific T.""" def get_scores() -> dict[str, int]: @@ -620,7 +620,7 @@ def get_scores() -> dict[str, int]: class TestToolMetadata: """Test tool metadata functionality.""" - def test_add_tool_with_metadata(self) -> None: + def test_add_tool_with_metadata(self): """Test adding a tool with metadata via ToolManager.""" def process_data(input_data: str) -> str: # pragma: no cover @@ -637,7 +637,7 @@ def process_data(input_data: str) -> str: # pragma: no cover assert tool.meta["ui"]["type"] == "form" assert tool.meta["version"] == "1.0" - def test_add_tool_without_metadata(self) -> None: + def test_add_tool_without_metadata(self): """Test that tools without metadata have None as meta value.""" def simple_tool(x: int) -> int: # pragma: no cover @@ -650,7 +650,7 @@ def simple_tool(x: int) -> int: # pragma: no cover assert tool.meta is None @pytest.mark.anyio - async def test_metadata_in_mcpserver_decorator(self) -> None: + async def test_metadata_in_mcpserver_decorator(self): """Test that metadata is correctly added via MCPServer.tool decorator.""" app = MCPServer() @@ -671,7 +671,7 @@ def upload_file(filename: str) -> str: # pragma: no cover assert tool.meta["priority"] == "high" @pytest.mark.anyio - async def test_metadata_in_list_tools(self) -> None: + async def test_metadata_in_list_tools(self): """Test that metadata is included in MCPTool when listing tools.""" app = MCPServer() @@ -692,7 +692,7 @@ def analyze_text(text: str) -> dict[str, Any]: # pragma: no cover assert tools[0].meta == metadata @pytest.mark.anyio - async def test_multiple_tools_with_different_metadata(self) -> None: + async def test_multiple_tools_with_different_metadata(self): """Test multiple tools with different metadata values.""" app = MCPServer() @@ -725,7 +725,7 @@ def tool3(z: bool) -> bool: # pragma: no cover assert tools_by_name["tool2"].meta == metadata2 assert tools_by_name["tool3"].meta is None - def test_metadata_with_complex_structure(self) -> None: + def test_metadata_with_complex_structure(self): """Test metadata with complex nested structures.""" def complex_tool(data: str) -> str: # pragma: no cover @@ -754,7 +754,7 @@ def complex_tool(data: str) -> str: # pragma: no cover assert "read" in tool.meta["permissions"] assert "data-processing" in tool.meta["tags"] - def test_metadata_empty_dict(self) -> None: + def test_metadata_empty_dict(self): """Test that empty dict metadata is preserved.""" def tool_with_empty_meta(x: int) -> int: # pragma: no cover @@ -768,7 +768,7 @@ def tool_with_empty_meta(x: int) -> int: # pragma: no cover assert tool.meta == {} @pytest.mark.anyio - async def test_metadata_with_annotations(self) -> None: + async def test_metadata_with_annotations(self): """Test that metadata and annotations can coexist.""" app = MCPServer() @@ -792,7 +792,7 @@ def combined_tool(data: str) -> str: # pragma: no cover class TestRemoveTools: """Test tool removal functionality in the tool manager.""" - def test_remove_existing_tool(self) -> None: + def test_remove_existing_tool(self): """Test removing an existing tool.""" def add(a: int, b: int) -> int: # pragma: no cover @@ -813,14 +813,14 @@ def add(a: int, b: int) -> int: # pragma: no cover assert manager.get_tool("add") is None assert len(manager.list_tools()) == 0 - def test_remove_nonexistent_tool(self) -> None: + def test_remove_nonexistent_tool(self): """Test removing a non-existent tool raises ToolError.""" manager = ToolManager() with pytest.raises(ToolError, match="Unknown tool: nonexistent"): manager.remove_tool("nonexistent") - def test_remove_tool_from_multiple_tools(self) -> None: + def test_remove_tool_from_multiple_tools(self): """Test removing one tool when multiple tools exist.""" def add(a: int, b: int) -> int: # pragma: no cover @@ -856,7 +856,7 @@ def divide(a: int, b: int) -> float: # pragma: no cover assert manager.get_tool("divide") is not None @pytest.mark.anyio - async def test_call_removed_tool_raises_error(self) -> None: + async def test_call_removed_tool_raises_error(self): """Test that calling a removed tool raises ToolError.""" def greet(name: str) -> str: @@ -877,7 +877,7 @@ def greet(name: str) -> str: with pytest.raises(ToolError, match="Unknown tool: greet"): await manager.call_tool("greet", {"name": "World"}, Context()) - def test_remove_tool_case_sensitive(self) -> None: + def test_remove_tool_case_sensitive(self): """Test that tool removal is case-sensitive.""" def test_func() -> str: # pragma: no cover diff --git a/tests/server/mcpserver/test_url_elicitation.py b/tests/server/mcpserver/test_url_elicitation.py index f2a02d580..af90dc208 100644 --- a/tests/server/mcpserver/test_url_elicitation.py +++ b/tests/server/mcpserver/test_url_elicitation.py @@ -13,7 +13,7 @@ @pytest.mark.anyio -async def test_url_elicitation_accept() -> None: +async def test_url_elicitation_accept(): """Test URL mode elicitation with user acceptance.""" mcp = MCPServer(name="URLElicitationServer") @@ -28,7 +28,7 @@ async def request_api_key(ctx: Context) -> str: return f"User {result.action}" # Create elicitation callback that accepts URL mode - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): assert params.mode == "url" assert params.url == "https://example.com/api_key_setup" assert params.elicitation_id == "test-elicitation-001" @@ -43,7 +43,7 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E @pytest.mark.anyio -async def test_url_elicitation_decline() -> None: +async def test_url_elicitation_decline(): """Test URL mode elicitation with user declining.""" mcp = MCPServer(name="URLElicitationDeclineServer") @@ -57,7 +57,7 @@ async def oauth_flow(ctx: Context) -> str: # Test only checks decline path return f"User {result.action} authorization" - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): assert params.mode == "url" return ElicitResult(action="decline") @@ -69,7 +69,7 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E @pytest.mark.anyio -async def test_url_elicitation_cancel() -> None: +async def test_url_elicitation_cancel(): """Test URL mode elicitation with user cancelling.""" mcp = MCPServer(name="URLElicitationCancelServer") @@ -83,7 +83,7 @@ async def payment_flow(ctx: Context) -> str: # Test only checks cancel path return f"User {result.action} payment" - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): assert params.mode == "url" return ElicitResult(action="cancel") @@ -95,7 +95,7 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E @pytest.mark.anyio -async def test_url_elicitation_helper_function() -> None: +async def test_url_elicitation_helper_function(): """Test the elicit_url helper function.""" mcp = MCPServer(name="URLElicitationHelperServer") @@ -110,7 +110,7 @@ async def setup_credentials(ctx: Context) -> str: # Test only checks accept path - return the type name return type(result).__name__ - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): return ElicitResult(action="accept") async with Client(mcp, elicitation_callback=elicitation_callback) as client: @@ -121,7 +121,7 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E @pytest.mark.anyio -async def test_url_no_content_in_response() -> None: +async def test_url_no_content_in_response(): """Test that URL mode elicitation responses don't include content field.""" mcp = MCPServer(name="URLContentCheckServer") @@ -137,7 +137,7 @@ async def check_url_response(ctx: Context) -> str: assert result.content is None return f"Action: {result.action}, Content: {result.content}" - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): # Verify that this is URL mode assert params.mode == "url" assert isinstance(params, types.ElicitRequestURLParams) @@ -155,7 +155,7 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E @pytest.mark.anyio -async def test_form_mode_still_works() -> None: +async def test_form_mode_still_works(): """Ensure form mode elicitation still works after SEP 1036.""" mcp = MCPServer(name="FormModeBackwardCompatServer") @@ -170,7 +170,7 @@ async def ask_name(ctx: Context) -> str: assert result.data is not None return f"Hello, {result.data.name}!" - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): # Verify form mode parameters assert params.mode == "form" assert isinstance(params, types.ElicitRequestFormParams) @@ -186,7 +186,7 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E @pytest.mark.anyio -async def test_elicit_complete_notification() -> None: +async def test_elicit_complete_notification(): """Test that elicitation completion notifications can be sent and received.""" mcp = MCPServer(name="ElicitCompleteServer") @@ -206,7 +206,7 @@ async def trigger_elicitation(ctx: Context) -> str: return "Elicitation completed" - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): return ElicitResult(action="accept") # pragma: no cover async with Client(mcp, elicitation_callback=elicitation_callback) as client: @@ -223,7 +223,7 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E @pytest.mark.anyio -async def test_url_elicitation_required_error_code() -> None: +async def test_url_elicitation_required_error_code(): """Test that the URL_ELICITATION_REQUIRED error code is correct.""" # Verify the error code matches the specification (SEP 1036) assert types.URL_ELICITATION_REQUIRED == -32042, ( @@ -232,7 +232,7 @@ async def test_url_elicitation_required_error_code() -> None: @pytest.mark.anyio -async def test_elicit_url_typed_results() -> None: +async def test_elicit_url_typed_results(): """Test that elicit_url returns properly typed result objects.""" mcp = MCPServer(name="TypedResultsServer") @@ -263,7 +263,7 @@ async def test_cancel(ctx: Context) -> str: return "Not cancelled" # pragma: no cover # Test declined result - async def decline_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: + async def decline_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): return ElicitResult(action="decline") async with Client(mcp, elicitation_callback=decline_callback) as client: @@ -273,7 +273,7 @@ async def decline_callback(context: RequestContext[ClientSession], params: Elici assert result.content[0].text == "Declined" # Test cancelled result - async def cancel_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: + async def cancel_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): return ElicitResult(action="cancel") async with Client(mcp, elicitation_callback=cancel_callback) as client: @@ -284,7 +284,7 @@ async def cancel_callback(context: RequestContext[ClientSession], params: Elicit @pytest.mark.anyio -async def test_deprecated_elicit_method() -> None: +async def test_deprecated_elicit_method(): """Test the deprecated elicit() method for backward compatibility.""" mcp = MCPServer(name="DeprecatedElicitServer") @@ -303,7 +303,7 @@ async def use_deprecated_elicit(ctx: Context) -> str: return f"Email: {result.content.get('email', 'none')}" return "No email provided" # pragma: no cover - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): # Verify this is form mode assert params.mode == "form" assert params.requested_schema is not None @@ -317,7 +317,7 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E @pytest.mark.anyio -async def test_ctx_elicit_url_convenience_method() -> None: +async def test_ctx_elicit_url_convenience_method(): """Test the ctx.elicit_url() convenience method (vs ctx.session.elicit_url()).""" mcp = MCPServer(name="CtxElicitUrlServer") @@ -331,7 +331,7 @@ async def direct_elicit_url(ctx: Context) -> str: ) return f"Result: {result.action}" - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams) -> ElicitResult: + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): assert params.mode == "url" assert params.elicitation_id == "ctx-test-001" return ElicitResult(action="accept") diff --git a/tests/server/mcpserver/test_url_elicitation_error_throw.py b/tests/server/mcpserver/test_url_elicitation_error_throw.py index 173849eec..1f45fd60f 100644 --- a/tests/server/mcpserver/test_url_elicitation_error_throw.py +++ b/tests/server/mcpserver/test_url_elicitation_error_throw.py @@ -9,7 +9,7 @@ @pytest.mark.anyio -async def test_url_elicitation_error_thrown_from_tool() -> None: +async def test_url_elicitation_error_thrown_from_tool(): """Test that UrlElicitationRequiredError raised from a tool is received as MCPError by client.""" mcp = MCPServer(name="UrlElicitationErrorServer") @@ -50,7 +50,7 @@ async def connect_service(service_name: str, ctx: Context) -> str: @pytest.mark.anyio -async def test_url_elicitation_error_from_error() -> None: +async def test_url_elicitation_error_from_error(): """Test that client can reconstruct UrlElicitationRequiredError from MCPError.""" mcp = MCPServer(name="UrlElicitationErrorServer") @@ -91,7 +91,7 @@ async def multi_auth(ctx: Context) -> str: @pytest.mark.anyio -async def test_normal_exceptions_still_return_error_result() -> None: +async def test_normal_exceptions_still_return_error_result(): """Test that normal exceptions still return CallToolResult with is_error=True.""" mcp = MCPServer(name="NormalErrorServer") diff --git a/tests/server/mcpserver/tools/test_base.py b/tests/server/mcpserver/tools/test_base.py index dce688554..22d5f973e 100644 --- a/tests/server/mcpserver/tools/test_base.py +++ b/tests/server/mcpserver/tools/test_base.py @@ -2,7 +2,7 @@ from mcp.server.mcpserver.tools.base import Tool -def test_context_detected_in_union_annotation() -> None: +def test_context_detected_in_union_annotation(): def my_tool(x: int, ctx: Context | None) -> str: raise NotImplementedError diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index 33cfc56b4..cff5a37c1 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -27,7 +27,7 @@ @pytest.mark.anyio -async def test_server_remains_functional_after_cancel() -> None: +async def test_server_remains_functional_after_cancel(): """Verify server can handle new requests after a cancellation.""" # Track tool calls @@ -61,7 +61,7 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar async with Client(server) as client: # First request (will be cancelled) - async def first_request() -> None: + async def first_request(): try: await client.session.send_request( CallToolRequest(params=CallToolRequestParams(name="test_tool", arguments={})), @@ -100,7 +100,7 @@ async def first_request() -> None: @pytest.mark.anyio -async def test_server_cancels_in_flight_handlers_on_transport_close() -> None: +async def test_server_cancels_in_flight_handlers_on_transport_close(): """When the transport closes mid-request, server.run() must cancel in-flight handlers rather than join on them. @@ -129,7 +129,7 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10) server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10) - async def run_server() -> None: + async def run_server(): await server.run(server_read, server_write, server.create_initialization_options()) server_run_returned.set() @@ -173,7 +173,7 @@ async def run_server() -> None: @pytest.mark.anyio -async def test_server_handles_transport_close_with_pending_server_to_client_requests() -> None: +async def test_server_handles_transport_close_with_pending_server_to_client_requests(): """When the transport closes while handlers are blocked on server→client requests (sampling, roots, elicitation), server.run() must still exit cleanly. @@ -203,7 +203,7 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10) server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10) - async def run_server() -> None: + async def run_server(): await server.run(server_read, server_write, server.create_initialization_options()) server_run_returned.set() diff --git a/tests/server/test_completion_with_context.py b/tests/server/test_completion_with_context.py index 3f61b6dff..a01d0d4d7 100644 --- a/tests/server/test_completion_with_context.py +++ b/tests/server/test_completion_with_context.py @@ -14,7 +14,7 @@ @pytest.mark.anyio -async def test_completion_handler_receives_context() -> None: +async def test_completion_handler_receives_context(): """Test that the completion handler receives context correctly.""" # Track what the handler receives received_params: CompleteRequestParams | None = None @@ -42,7 +42,7 @@ async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestPa @pytest.mark.anyio -async def test_completion_backward_compatibility() -> None: +async def test_completion_backward_compatibility(): """Test that completion works without context (backward compatibility).""" context_was_none = False @@ -65,7 +65,7 @@ async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestPa @pytest.mark.anyio -async def test_dependent_completion_scenario() -> None: +async def test_dependent_completion_scenario(): """Test a real-world scenario with dependent completions.""" async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestParams) -> CompleteResult: @@ -120,7 +120,7 @@ async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestPa @pytest.mark.anyio -async def test_completion_error_on_missing_context() -> None: +async def test_completion_error_on_missing_context(): """Test that server can raise error when required context is missing.""" async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestParams) -> CompleteResult: diff --git a/tests/server/test_lifespan.py b/tests/server/test_lifespan.py index 9539d0eea..0d8790504 100644 --- a/tests/server/test_lifespan.py +++ b/tests/server/test_lifespan.py @@ -27,7 +27,7 @@ @pytest.mark.anyio -async def test_lowlevel_server_lifespan() -> None: +async def test_lowlevel_server_lifespan(): """Test that lifespan works in low-level server.""" @asynccontextmanager @@ -58,7 +58,7 @@ async def check_lifespan( # Run server in background task async with anyio.create_task_group() as tg, send_stream1, receive_stream1, send_stream2, receive_stream2: - async def run_server() -> None: + async def run_server(): await server.run( receive_stream1, send_stream2, @@ -121,7 +121,7 @@ async def run_server() -> None: @pytest.mark.anyio -async def test_mcpserver_server_lifespan() -> None: +async def test_mcpserver_server_lifespan(): """Test that lifespan works in MCPServer server.""" @asynccontextmanager @@ -152,7 +152,7 @@ def check_lifespan(ctx: Context) -> bool: # Run server in background task async with anyio.create_task_group() as tg, send_stream1, receive_stream1, send_stream2, receive_stream2: - async def run_server() -> None: + async def run_server(): await server._lowlevel_server.run( receive_stream1, send_stream2, diff --git a/tests/server/test_lowlevel_exception_handling.py b/tests/server/test_lowlevel_exception_handling.py index 6834d3ddc..46925916d 100644 --- a/tests/server/test_lowlevel_exception_handling.py +++ b/tests/server/test_lowlevel_exception_handling.py @@ -11,7 +11,7 @@ @pytest.mark.anyio -async def test_exception_handling_with_raise_exceptions_true() -> None: +async def test_exception_handling_with_raise_exceptions_true(): """Transport exceptions are re-raised when raise_exceptions=True.""" server = Server("test-server") session = Mock(spec=ServerSession) @@ -23,7 +23,7 @@ async def test_exception_handling_with_raise_exceptions_true() -> None: @pytest.mark.anyio -async def test_exception_handling_with_raise_exceptions_false() -> None: +async def test_exception_handling_with_raise_exceptions_false(): """Transport exceptions are logged locally but not sent to the client. The transport that reported the error is likely broken; writing back @@ -40,7 +40,7 @@ async def test_exception_handling_with_raise_exceptions_false() -> None: @pytest.mark.anyio -async def test_normal_message_handling_not_affected() -> None: +async def test_normal_message_handling_not_affected(): """Test that normal messages still work correctly""" server = Server("test-server") session = Mock(spec=ServerSession) @@ -62,7 +62,7 @@ async def test_normal_message_handling_not_affected() -> None: @pytest.mark.anyio -async def test_server_run_exits_cleanly_when_transport_yields_exception_then_closes() -> None: +async def test_server_run_exits_cleanly_when_transport_yields_exception_then_closes(): """Regression test for #1967 / #2064. Exercises the real Server.run() path with real memory streams, reproducing diff --git a/tests/server/test_lowlevel_tool_annotations.py b/tests/server/test_lowlevel_tool_annotations.py index 807cc7502..705abdfe8 100644 --- a/tests/server/test_lowlevel_tool_annotations.py +++ b/tests/server/test_lowlevel_tool_annotations.py @@ -8,7 +8,7 @@ @pytest.mark.anyio -async def test_lowlevel_server_tool_annotations() -> None: +async def test_lowlevel_server_tool_annotations(): """Test that tool annotations work in low-level server.""" async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: diff --git a/tests/server/test_read_resource.py b/tests/server/test_read_resource.py index 887b8527a..102a58d03 100644 --- a/tests/server/test_read_resource.py +++ b/tests/server/test_read_resource.py @@ -14,7 +14,7 @@ pytestmark = pytest.mark.anyio -async def test_read_resource_text() -> None: +async def test_read_resource_text(): async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: return ReadResourceResult( contents=[TextResourceContents(uri=str(params.uri), text="Hello World", mime_type="text/plain")] @@ -32,7 +32,7 @@ async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRe assert content.mime_type == "text/plain" -async def test_read_resource_binary() -> None: +async def test_read_resource_binary(): binary_data = b"Hello World" async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 5cdd5b1ce..a2786d865 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -23,7 +23,7 @@ @pytest.mark.anyio -async def test_server_session_initialize() -> None: +async def test_server_session_initialize(): server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) @@ -36,7 +36,7 @@ async def message_handler( # pragma: no cover received_initialized = False - async def run_server() -> None: + async def run_server(): nonlocal received_initialized async with ServerSession( @@ -77,7 +77,7 @@ async def run_server() -> None: @pytest.mark.anyio -async def test_server_capabilities() -> None: +async def test_server_capabilities(): notification_options = NotificationOptions() experimental_capabilities: dict[str, Any] = {} @@ -129,7 +129,7 @@ async def noop_completion(ctx: ServerRequestContext, params: types.CompleteReque @pytest.mark.anyio -async def test_server_session_initialize_with_older_protocol_version() -> None: +async def test_server_session_initialize_with_older_protocol_version(): """Test that server accepts and responds with older protocol (2024-11-05).""" server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) @@ -137,7 +137,7 @@ async def test_server_session_initialize_with_older_protocol_version() -> None: received_initialized = False received_protocol_version = None - async def run_server() -> None: + async def run_server(): nonlocal received_initialized async with ServerSession( @@ -159,7 +159,7 @@ async def run_server() -> None: received_initialized = True return - async def mock_client() -> None: + async def mock_client(): nonlocal received_protocol_version # Send initialization request with older protocol version (2024-11-05) @@ -208,7 +208,7 @@ async def mock_client() -> None: @pytest.mark.anyio -async def test_ping_request_before_initialization() -> None: +async def test_ping_request_before_initialization(): """Test that ping requests are allowed before initialization is complete.""" server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) @@ -216,7 +216,7 @@ async def test_ping_request_before_initialization() -> None: ping_response_received = False ping_response_id = None - async def run_server() -> None: + async def run_server(): async with ServerSession( client_to_server_receive, server_to_client_send, @@ -239,7 +239,7 @@ async def run_server() -> None: await message.respond(types.EmptyResult()) return - async def mock_client() -> None: + async def mock_client(): nonlocal ping_response_received, ping_response_id # Send ping request before any initialization @@ -267,7 +267,7 @@ async def mock_client() -> None: @pytest.mark.anyio -async def test_create_message_tool_result_validation() -> None: +async def test_create_message_tool_result_validation(): """Test tool_use/tool_result validation in create_message.""" server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) @@ -392,7 +392,7 @@ async def test_create_message_tool_result_validation() -> None: @pytest.mark.anyio -async def test_create_message_without_tools_capability() -> None: +async def test_create_message_without_tools_capability(): """Test that create_message raises MCPError when tools are provided without capability.""" server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) @@ -442,7 +442,7 @@ async def test_create_message_without_tools_capability() -> None: @pytest.mark.anyio -async def test_other_requests_blocked_before_initialization() -> None: +async def test_other_requests_blocked_before_initialization(): """Test that non-ping requests are still blocked before initialization.""" server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) @@ -450,7 +450,7 @@ async def test_other_requests_blocked_before_initialization() -> None: error_response_received = False error_code = None - async def run_server() -> None: + async def run_server(): async with ServerSession( client_to_server_receive, server_to_client_send, @@ -464,7 +464,7 @@ async def run_server() -> None: # No need to process incoming_messages since the error is handled automatically await anyio.sleep(0.1) # Give time for the request to be processed - async def mock_client() -> None: + async def mock_client(): nonlocal error_response_received, error_code # Try to send a non-ping request before initialization diff --git a/tests/server/test_session_race_condition.py b/tests/server/test_session_race_condition.py index 0dcaf3097..81041152b 100644 --- a/tests/server/test_session_race_condition.py +++ b/tests/server/test_session_race_condition.py @@ -18,7 +18,7 @@ @pytest.mark.anyio -async def test_request_immediately_after_initialize_response() -> None: +async def test_request_immediately_after_initialize_response(): """Test that requests are accepted immediately after initialize response. This reproduces the race condition in stateful HTTP mode where: @@ -37,7 +37,7 @@ async def test_request_immediately_after_initialize_response() -> None: tools_list_success = False error_received = None - async def run_server() -> None: + async def run_server(): nonlocal tools_list_success async with ServerSession( @@ -79,7 +79,7 @@ async def run_server() -> None: # Done - exit gracefully return - async def mock_client() -> None: + async def mock_client(): nonlocal error_received # Step 1: Send InitializeRequest diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index a5e2c78db..010eaf6a2 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -35,21 +35,19 @@ def server_url(server_port: int) -> str: # pragma: no cover class SecurityTestServer(Server): # pragma: no cover - def __init__(self) -> None: + def __init__(self): super().__init__(SERVER_NAME) async def on_list_tools(self) -> list[Tool]: return [] -def run_server_with_settings( - port: int, security_settings: TransportSecuritySettings | None = None -) -> None: # pragma: no cover +def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover """Run the SSE server with specified security settings.""" app = SecurityTestServer() sse_transport = SseServerTransport("/messages/", security_settings) - async def handle_sse(request: Request) -> Response: + async def handle_sse(request: Request): try: async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams: if streams: @@ -68,9 +66,7 @@ async def handle_sse(request: Request) -> Response: uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") -def start_server_process( - port: int, security_settings: TransportSecuritySettings | None = None -) -> multiprocessing.Process: +def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): """Start server in a separate process.""" process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) process.start() @@ -80,7 +76,7 @@ def start_server_process( @pytest.mark.anyio -async def test_sse_security_default_settings(server_port: int) -> None: +async def test_sse_security_default_settings(server_port: int): """Test SSE with default security settings (protection disabled).""" process = start_server_process(server_port) @@ -96,7 +92,7 @@ async def test_sse_security_default_settings(server_port: int) -> None: @pytest.mark.anyio -async def test_sse_security_invalid_host_header(server_port: int) -> None: +async def test_sse_security_invalid_host_header(server_port: int): """Test SSE with invalid Host header.""" # Enable security by providing settings with an empty allowed_hosts list security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["example.com"]) @@ -117,7 +113,7 @@ async def test_sse_security_invalid_host_header(server_port: int) -> None: @pytest.mark.anyio -async def test_sse_security_invalid_origin_header(server_port: int) -> None: +async def test_sse_security_invalid_origin_header(server_port: int): """Test SSE with invalid Origin header.""" # Configure security to allow the host but restrict origins security_settings = TransportSecuritySettings( @@ -140,7 +136,7 @@ async def test_sse_security_invalid_origin_header(server_port: int) -> None: @pytest.mark.anyio -async def test_sse_security_post_invalid_content_type(server_port: int) -> None: +async def test_sse_security_post_invalid_content_type(server_port: int): """Test POST endpoint with invalid Content-Type header.""" # Configure security to allow the host security_settings = TransportSecuritySettings( @@ -173,7 +169,7 @@ async def test_sse_security_post_invalid_content_type(server_port: int) -> None: @pytest.mark.anyio -async def test_sse_security_disabled(server_port: int) -> None: +async def test_sse_security_disabled(server_port: int): """Test SSE with security disabled.""" settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) process = start_server_process(server_port, settings) @@ -194,7 +190,7 @@ async def test_sse_security_disabled(server_port: int) -> None: @pytest.mark.anyio -async def test_sse_security_custom_allowed_hosts(server_port: int) -> None: +async def test_sse_security_custom_allowed_hosts(server_port: int): """Test SSE with custom allowed hosts.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, @@ -227,7 +223,7 @@ async def test_sse_security_custom_allowed_hosts(server_port: int) -> None: @pytest.mark.anyio -async def test_sse_security_wildcard_ports(server_port: int) -> None: +async def test_sse_security_wildcard_ports(server_port: int): """Test SSE with wildcard port patterns.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, @@ -261,7 +257,7 @@ async def test_sse_security_wildcard_ports(server_port: int) -> None: @pytest.mark.anyio -async def test_sse_security_post_valid_content_type(server_port: int) -> None: +async def test_sse_security_post_valid_content_type(server_port: int): """Test POST endpoint with valid Content-Type headers.""" # Configure security to allow the host security_settings = TransportSecuritySettings( diff --git a/tests/server/test_stateless_mode.py b/tests/server/test_stateless_mode.py index 3378a0cab..3bfc6e674 100644 --- a/tests/server/test_stateless_mode.py +++ b/tests/server/test_stateless_mode.py @@ -49,14 +49,14 @@ async def stateless_session() -> AsyncGenerator[ServerSession, None]: @pytest.mark.anyio -async def test_list_roots_fails_in_stateless_mode(stateless_session: ServerSession) -> None: +async def test_list_roots_fails_in_stateless_mode(stateless_session: ServerSession): """Test that list_roots raises StatelessModeNotSupported in stateless mode.""" with pytest.raises(StatelessModeNotSupported, match="list_roots"): await stateless_session.list_roots() @pytest.mark.anyio -async def test_create_message_fails_in_stateless_mode(stateless_session: ServerSession) -> None: +async def test_create_message_fails_in_stateless_mode(stateless_session: ServerSession): """Test that create_message raises StatelessModeNotSupported in stateless mode.""" with pytest.raises(StatelessModeNotSupported, match="sampling"): await stateless_session.create_message( @@ -71,7 +71,7 @@ async def test_create_message_fails_in_stateless_mode(stateless_session: ServerS @pytest.mark.anyio -async def test_elicit_form_fails_in_stateless_mode(stateless_session: ServerSession) -> None: +async def test_elicit_form_fails_in_stateless_mode(stateless_session: ServerSession): """Test that elicit_form raises StatelessModeNotSupported in stateless mode.""" with pytest.raises(StatelessModeNotSupported, match="elicitation"): await stateless_session.elicit_form( @@ -81,7 +81,7 @@ async def test_elicit_form_fails_in_stateless_mode(stateless_session: ServerSess @pytest.mark.anyio -async def test_elicit_url_fails_in_stateless_mode(stateless_session: ServerSession) -> None: +async def test_elicit_url_fails_in_stateless_mode(stateless_session: ServerSession): """Test that elicit_url raises StatelessModeNotSupported in stateless mode.""" with pytest.raises(StatelessModeNotSupported, match="elicitation"): await stateless_session.elicit_url( @@ -92,7 +92,7 @@ async def test_elicit_url_fails_in_stateless_mode(stateless_session: ServerSessi @pytest.mark.anyio -async def test_elicit_deprecated_fails_in_stateless_mode(stateless_session: ServerSession) -> None: +async def test_elicit_deprecated_fails_in_stateless_mode(stateless_session: ServerSession): """Test that the deprecated elicit method also fails in stateless mode.""" with pytest.raises(StatelessModeNotSupported, match="elicitation"): await stateless_session.elicit( @@ -102,7 +102,7 @@ async def test_elicit_deprecated_fails_in_stateless_mode(stateless_session: Serv @pytest.mark.anyio -async def test_stateless_error_message_is_actionable(stateless_session: ServerSession) -> None: +async def test_stateless_error_message_is_actionable(stateless_session: ServerSession): """Test that the error message provides actionable guidance.""" with pytest.raises(StatelessModeNotSupported) as exc_info: await stateless_session.list_roots() @@ -117,7 +117,7 @@ async def test_stateless_error_message_is_actionable(stateless_session: ServerSe @pytest.mark.anyio -async def test_exception_has_method_attribute(stateless_session: ServerSession) -> None: +async def test_exception_has_method_attribute(stateless_session: ServerSession): """Test that the exception has a method attribute for programmatic access.""" with pytest.raises(StatelessModeNotSupported) as exc_info: await stateless_session.list_roots() @@ -155,7 +155,7 @@ async def stateful_session() -> AsyncGenerator[ServerSession, None]: @pytest.mark.anyio async def test_stateful_mode_does_not_raise_stateless_error( stateful_session: ServerSession, monkeypatch: pytest.MonkeyPatch -) -> None: +): """Test that StatelessModeNotSupported is not raised in stateful mode. We mock send_request to avoid blocking on I/O while still verifying diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index fbaeaed31..677a99356 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -11,7 +11,7 @@ @pytest.mark.anyio -async def test_stdio_server() -> None: +async def test_stdio_server(): stdin = io.StringIO() stdout = io.StringIO() @@ -64,7 +64,7 @@ async def test_stdio_server() -> None: @pytest.mark.anyio -async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch) -> None: +async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch): """Non-UTF-8 bytes on stdin must not crash the server. Invalid bytes are replaced with U+FFFD, which then fails JSON parsing and diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 281c6b22f..47cfbf14a 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -2,7 +2,6 @@ import json import logging -from collections.abc import AsyncGenerator from typing import Any from unittest.mock import AsyncMock, patch @@ -20,7 +19,7 @@ @pytest.mark.anyio -async def test_run_can_only_be_called_once() -> None: +async def test_run_can_only_be_called_once(): """Test that run() can only be called once per instance.""" app = Server("test-server") manager = StreamableHTTPSessionManager(app=app) @@ -38,14 +37,14 @@ async def test_run_can_only_be_called_once() -> None: @pytest.mark.anyio -async def test_run_prevents_concurrent_calls() -> None: +async def test_run_prevents_concurrent_calls(): """Test that concurrent calls to run() are prevented.""" app = Server("test-server") manager = StreamableHTTPSessionManager(app=app) errors: list[Exception] = [] - async def try_run() -> None: + async def try_run(): try: async with manager.run(): # Simulate some work @@ -64,7 +63,7 @@ async def try_run() -> None: @pytest.mark.anyio -async def test_handle_request_without_run_raises_error() -> None: +async def test_handle_request_without_run_raises_error(): """Test that handle_request raises error if run() hasn't been called.""" app = Server("test-server") manager = StreamableHTTPSessionManager(app=app) @@ -72,10 +71,10 @@ async def test_handle_request_without_run_raises_error() -> None: # Mock ASGI parameters scope = {"type": "http", "method": "POST", "path": "/test"} - async def receive() -> Message: # pragma: no cover + async def receive(): # pragma: no cover return {"type": "http.request", "body": b""} - async def send(message: Message) -> None: # pragma: no cover + async def send(message: Message): # pragma: no cover pass # Should raise error because run() hasn't been called @@ -91,7 +90,7 @@ class TestException(Exception): @pytest.fixture -async def running_manager() -> AsyncGenerator[tuple[StreamableHTTPSessionManager, Server], None]: +async def running_manager(): app = Server("test-cleanup-server") # It's important that the app instance used by the manager is the one we can patch manager = StreamableHTTPSessionManager(app=app) @@ -101,9 +100,7 @@ async def running_manager() -> AsyncGenerator[tuple[StreamableHTTPSessionManager @pytest.mark.anyio -async def test_stateful_session_cleanup_on_graceful_exit( - running_manager: tuple[StreamableHTTPSessionManager, Server], -) -> None: +async def test_stateful_session_cleanup_on_graceful_exit(running_manager: tuple[StreamableHTTPSessionManager, Server]): manager, app = running_manager mock_mcp_run = AsyncMock(return_value=None) @@ -112,7 +109,7 @@ async def test_stateful_session_cleanup_on_graceful_exit( sent_messages: list[Message] = [] - async def mock_send(message: Message) -> None: + async def mock_send(message: Message): sent_messages.append(message) scope = { @@ -122,7 +119,7 @@ async def mock_send(message: Message) -> None: "headers": [(b"content-type", b"application/json")], } - async def mock_receive() -> Message: # pragma: no cover + async def mock_receive(): # pragma: no cover return {"type": "http.request", "body": b"", "more_body": False} # Trigger session creation @@ -158,9 +155,7 @@ async def mock_receive() -> Message: # pragma: no cover @pytest.mark.anyio -async def test_stateful_session_cleanup_on_exception( - running_manager: tuple[StreamableHTTPSessionManager, Server], -) -> None: +async def test_stateful_session_cleanup_on_exception(running_manager: tuple[StreamableHTTPSessionManager, Server]): manager, app = running_manager mock_mcp_run = AsyncMock(side_effect=TestException("Simulated crash")) @@ -168,7 +163,7 @@ async def test_stateful_session_cleanup_on_exception( sent_messages: list[Message] = [] - async def mock_send(message: Message) -> None: + async def mock_send(message: Message): sent_messages.append(message) # If an exception occurs, the transport might try to send an error response # For this test, we mostly care that the session is established enough @@ -183,7 +178,7 @@ async def mock_send(message: Message) -> None: "headers": [(b"content-type", b"application/json")], } - async def mock_receive() -> Message: # pragma: no cover + async def mock_receive(): # pragma: no cover return {"type": "http.request", "body": b"", "more_body": False} # Trigger session creation @@ -213,7 +208,7 @@ async def mock_receive() -> Message: # pragma: no cover @pytest.mark.anyio -async def test_stateless_requests_memory_cleanup() -> None: +async def test_stateless_requests_memory_cleanup(): """Test that stateless requests actually clean up resources using real transports.""" app = Server("test-stateless-real-cleanup") manager = StreamableHTTPSessionManager(app=app, stateless=True) @@ -238,7 +233,7 @@ def track_transport(*args: Any, **kwargs: Any) -> StreamableHTTPServerTransport: # Send a simple request sent_messages: list[Message] = [] - async def mock_send(message: Message) -> None: + async def mock_send(message: Message): sent_messages.append(message) scope = { @@ -252,7 +247,7 @@ async def mock_send(message: Message) -> None: } # Empty body to trigger early return - async def mock_receive() -> Message: + async def mock_receive(): return { "type": "http.request", "body": b"", @@ -275,7 +270,7 @@ async def mock_receive() -> Message: @pytest.mark.anyio -async def test_unknown_session_id_returns_404(caplog: pytest.LogCaptureFixture) -> None: +async def test_unknown_session_id_returns_404(caplog: pytest.LogCaptureFixture): """Test that requests with unknown session IDs return HTTP 404 per MCP spec.""" app = Server("test-unknown-session") manager = StreamableHTTPSessionManager(app=app) @@ -284,7 +279,7 @@ async def test_unknown_session_id_returns_404(caplog: pytest.LogCaptureFixture) sent_messages: list[Message] = [] response_body = b"" - async def mock_send(message: Message) -> None: + async def mock_send(message: Message): nonlocal response_body sent_messages.append(message) if message["type"] == "http.response.body": @@ -302,7 +297,7 @@ async def mock_send(message: Message) -> None: ], } - async def mock_receive() -> Message: + async def mock_receive(): return {"type": "http.request", "body": b"{}", "more_body": False} # pragma: no cover with caplog.at_level(logging.INFO): @@ -326,7 +321,7 @@ async def mock_receive() -> Message: @pytest.mark.anyio -async def test_e2e_streamable_http_server_cleanup() -> None: +async def test_e2e_streamable_http_server_cleanup(): host = "testserver" async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: @@ -344,7 +339,7 @@ async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestP @pytest.mark.anyio -async def test_idle_session_is_reaped() -> None: +async def test_idle_session_is_reaped(): """After idle timeout fires, the session returns 404.""" app = Server("test-idle-reap") manager = StreamableHTTPSessionManager(app=app, session_idle_timeout=0.05) @@ -352,7 +347,7 @@ async def test_idle_session_is_reaped() -> None: async with manager.run(): sent_messages: list[Message] = [] - async def mock_send(message: Message) -> None: + async def mock_send(message: Message): sent_messages.append(message) scope = { @@ -362,7 +357,7 @@ async def mock_send(message: Message) -> None: "headers": [(b"content-type", b"application/json")], } - async def mock_receive() -> Message: # pragma: no cover + async def mock_receive(): # pragma: no cover return {"type": "http.request", "body": b"", "more_body": False} await manager.handle_request(scope, mock_receive, mock_send) @@ -385,7 +380,7 @@ async def mock_receive() -> Message: # pragma: no cover # Verify via public API: old session ID now returns 404 response_messages: list[Message] = [] - async def capture_send(message: Message) -> None: + async def capture_send(message: Message): response_messages.append(message) scope_with_session = { @@ -408,13 +403,13 @@ async def capture_send(message: Message) -> None: assert response_start["status"] == 404 -def test_session_idle_timeout_rejects_non_positive() -> None: +def test_session_idle_timeout_rejects_non_positive(): with pytest.raises(ValueError, match="positive number"): StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=-1) with pytest.raises(ValueError, match="positive number"): StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=0) -def test_session_idle_timeout_rejects_stateless() -> None: +def test_session_idle_timeout_rejects_stateless(): with pytest.raises(RuntimeError, match="not supported in stateless"): StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=30, stateless=True) diff --git a/tests/server/test_streamable_http_security.py b/tests/server/test_streamable_http_security.py index f5dcff821..897555353 100644 --- a/tests/server/test_streamable_http_security.py +++ b/tests/server/test_streamable_http_security.py @@ -34,16 +34,14 @@ def server_url(server_port: int) -> str: # pragma: no cover class SecurityTestServer(Server): # pragma: no cover - def __init__(self) -> None: + def __init__(self): super().__init__(SERVER_NAME) async def on_list_tools(self) -> list[Tool]: return [] -def run_server_with_settings( - port: int, security_settings: TransportSecuritySettings | None = None -) -> None: # pragma: no cover +def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover """Run the StreamableHTTP server with specified security settings.""" app = SecurityTestServer() @@ -73,9 +71,7 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") -def start_server_process( - port: int, security_settings: TransportSecuritySettings | None = None -) -> multiprocessing.Process: +def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): """Start server in a separate process.""" process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) process.start() @@ -85,7 +81,7 @@ def start_server_process( @pytest.mark.anyio -async def test_streamable_http_security_default_settings(server_port: int) -> None: +async def test_streamable_http_security_default_settings(server_port: int): """Test StreamableHTTP with default security settings (protection enabled).""" process = start_server_process(server_port) @@ -110,7 +106,7 @@ async def test_streamable_http_security_default_settings(server_port: int) -> No @pytest.mark.anyio -async def test_streamable_http_security_invalid_host_header(server_port: int) -> None: +async def test_streamable_http_security_invalid_host_header(server_port: int): """Test StreamableHTTP with invalid Host header.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True) process = start_server_process(server_port, security_settings) @@ -138,7 +134,7 @@ async def test_streamable_http_security_invalid_host_header(server_port: int) -> @pytest.mark.anyio -async def test_streamable_http_security_invalid_origin_header(server_port: int) -> None: +async def test_streamable_http_security_invalid_origin_header(server_port: int): """Test StreamableHTTP with invalid Origin header.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"]) process = start_server_process(server_port, security_settings) @@ -166,7 +162,7 @@ async def test_streamable_http_security_invalid_origin_header(server_port: int) @pytest.mark.anyio -async def test_streamable_http_security_invalid_content_type(server_port: int) -> None: +async def test_streamable_http_security_invalid_content_type(server_port: int): """Test StreamableHTTP POST with invalid Content-Type header.""" process = start_server_process(server_port) @@ -199,7 +195,7 @@ async def test_streamable_http_security_invalid_content_type(server_port: int) - @pytest.mark.anyio -async def test_streamable_http_security_disabled(server_port: int) -> None: +async def test_streamable_http_security_disabled(server_port: int): """Test StreamableHTTP with security disabled.""" settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) process = start_server_process(server_port, settings) @@ -227,7 +223,7 @@ async def test_streamable_http_security_disabled(server_port: int) -> None: @pytest.mark.anyio -async def test_streamable_http_security_custom_allowed_hosts(server_port: int) -> None: +async def test_streamable_http_security_custom_allowed_hosts(server_port: int): """Test StreamableHTTP with custom allowed hosts.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, @@ -258,7 +254,7 @@ async def test_streamable_http_security_custom_allowed_hosts(server_port: int) - @pytest.mark.anyio -async def test_streamable_http_security_get_request(server_port: int) -> None: +async def test_streamable_http_security_get_request(server_port: int): """Test StreamableHTTP GET request with security.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1"]) process = start_server_process(server_port, security_settings) diff --git a/tests/shared/test_auth.py b/tests/shared/test_auth.py index a1138a6ee..cd3c35332 100644 --- a/tests/shared/test_auth.py +++ b/tests/shared/test_auth.py @@ -3,7 +3,7 @@ from mcp.shared.auth import OAuthMetadata -def test_oauth() -> None: +def test_oauth(): """Should not throw when parsing OAuth metadata.""" OAuthMetadata.model_validate( { @@ -17,7 +17,7 @@ def test_oauth() -> None: ) -def test_oidc() -> None: +def test_oidc(): """Should not throw when parsing OIDC metadata.""" OAuthMetadata.model_validate( { @@ -37,7 +37,7 @@ def test_oidc() -> None: ) -def test_oauth_with_jarm() -> None: +def test_oauth_with_jarm(): """Should not throw when parsing OAuth metadata that includes JARM response modes.""" OAuthMetadata.model_validate( { diff --git a/tests/shared/test_auth_utils.py b/tests/shared/test_auth_utils.py index ee6c4347f..5ae0e22b0 100644 --- a/tests/shared/test_auth_utils.py +++ b/tests/shared/test_auth_utils.py @@ -7,13 +7,13 @@ # Tests for resource_url_from_server_url function -def test_resource_url_from_server_url_removes_fragment() -> None: +def test_resource_url_from_server_url_removes_fragment(): """Fragment should be removed per RFC 8707.""" assert resource_url_from_server_url("https://example.com/path#fragment") == "https://example.com/path" assert resource_url_from_server_url("https://example.com/#fragment") == "https://example.com/" -def test_resource_url_from_server_url_preserves_path() -> None: +def test_resource_url_from_server_url_preserves_path(): """Path should be preserved.""" assert ( resource_url_from_server_url("https://example.com/path/to/resource") == "https://example.com/path/to/resource" @@ -22,25 +22,25 @@ def test_resource_url_from_server_url_preserves_path() -> None: assert resource_url_from_server_url("https://example.com") == "https://example.com" -def test_resource_url_from_server_url_preserves_query() -> None: +def test_resource_url_from_server_url_preserves_query(): """Query parameters should be preserved.""" assert resource_url_from_server_url("https://example.com/path?foo=bar") == "https://example.com/path?foo=bar" assert resource_url_from_server_url("https://example.com/?key=value") == "https://example.com/?key=value" -def test_resource_url_from_server_url_preserves_port() -> None: +def test_resource_url_from_server_url_preserves_port(): """Non-default ports should be preserved.""" assert resource_url_from_server_url("https://example.com:8443/path") == "https://example.com:8443/path" assert resource_url_from_server_url("http://example.com:8080/") == "http://example.com:8080/" -def test_resource_url_from_server_url_lowercase_scheme_and_host() -> None: +def test_resource_url_from_server_url_lowercase_scheme_and_host(): """Scheme and host should be lowercase for canonical form.""" assert resource_url_from_server_url("HTTPS://EXAMPLE.COM/path") == "https://example.com/path" assert resource_url_from_server_url("Http://Example.Com:8080/") == "http://example.com:8080/" -def test_resource_url_from_server_url_handles_pydantic_urls() -> None: +def test_resource_url_from_server_url_handles_pydantic_urls(): """Should handle Pydantic URL types.""" url = HttpUrl("https://example.com/path") assert resource_url_from_server_url(url) == "https://example.com/path" @@ -49,32 +49,32 @@ def test_resource_url_from_server_url_handles_pydantic_urls() -> None: # Tests for check_resource_allowed function -def test_check_resource_allowed_identical_urls() -> None: +def test_check_resource_allowed_identical_urls(): """Identical URLs should match.""" assert check_resource_allowed("https://example.com/path", "https://example.com/path") is True assert check_resource_allowed("https://example.com/", "https://example.com/") is True assert check_resource_allowed("https://example.com", "https://example.com") is True -def test_check_resource_allowed_different_schemes() -> None: +def test_check_resource_allowed_different_schemes(): """Different schemes should not match.""" assert check_resource_allowed("https://example.com/path", "http://example.com/path") is False assert check_resource_allowed("http://example.com/", "https://example.com/") is False -def test_check_resource_allowed_different_domains() -> None: +def test_check_resource_allowed_different_domains(): """Different domains should not match.""" assert check_resource_allowed("https://example.com/path", "https://example.org/path") is False assert check_resource_allowed("https://sub.example.com/", "https://example.com/") is False -def test_check_resource_allowed_different_ports() -> None: +def test_check_resource_allowed_different_ports(): """Different ports should not match.""" assert check_resource_allowed("https://example.com:8443/path", "https://example.com/path") is False assert check_resource_allowed("https://example.com:8080/", "https://example.com:8443/") is False -def test_check_resource_allowed_hierarchical_matching() -> None: +def test_check_resource_allowed_hierarchical_matching(): """Child paths should match parent paths.""" # Parent resource allows child resources assert check_resource_allowed("https://example.com/api/v1/users", "https://example.com/api") is True @@ -89,7 +89,7 @@ def test_check_resource_allowed_hierarchical_matching() -> None: assert check_resource_allowed("https://example.com/", "https://example.com/api") is False -def test_check_resource_allowed_path_boundary_matching() -> None: +def test_check_resource_allowed_path_boundary_matching(): """Path matching should respect boundaries.""" # Should not match partial path segments assert check_resource_allowed("https://example.com/apiextra", "https://example.com/api") is False @@ -100,7 +100,7 @@ def test_check_resource_allowed_path_boundary_matching() -> None: assert check_resource_allowed("https://example.com/api/v1", "https://example.com/api/") is True -def test_check_resource_allowed_trailing_slash_handling() -> None: +def test_check_resource_allowed_trailing_slash_handling(): """Trailing slashes should be handled correctly.""" # With and without trailing slashes assert check_resource_allowed("https://example.com/api/", "https://example.com/api") is True @@ -109,14 +109,14 @@ def test_check_resource_allowed_trailing_slash_handling() -> None: assert check_resource_allowed("https://example.com/api/v1", "https://example.com/api/") is True -def test_check_resource_allowed_case_insensitive_origin() -> None: +def test_check_resource_allowed_case_insensitive_origin(): """Origin comparison should be case-insensitive.""" assert check_resource_allowed("https://EXAMPLE.COM/path", "https://example.com/path") is True assert check_resource_allowed("HTTPS://example.com/path", "https://example.com/path") is True assert check_resource_allowed("https://Example.Com:8080/api", "https://example.com:8080/api") is True -def test_check_resource_allowed_empty_paths() -> None: +def test_check_resource_allowed_empty_paths(): """Empty paths should be handled correctly.""" assert check_resource_allowed("https://example.com", "https://example.com") is True assert check_resource_allowed("https://example.com/", "https://example.com") is True diff --git a/tests/shared/test_httpx_utils.py b/tests/shared/test_httpx_utils.py index 493f5f100..dcc6fd003 100644 --- a/tests/shared/test_httpx_utils.py +++ b/tests/shared/test_httpx_utils.py @@ -5,7 +5,7 @@ from mcp.shared._httpx_utils import create_mcp_http_client -def test_default_settings() -> None: +def test_default_settings(): """Test that default settings are applied correctly.""" client = create_mcp_http_client() @@ -13,7 +13,7 @@ def test_default_settings() -> None: assert client.timeout.connect == 30.0 -def test_custom_parameters() -> None: +def test_custom_parameters(): """Test custom headers and timeout are set correctly.""" headers = {"Authorization": "Bearer token"} timeout = httpx.Timeout(60.0) diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index 8ad4d8c0d..aad9e5d43 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -15,14 +15,14 @@ @pytest.mark.anyio -async def test_bidirectional_progress_notifications() -> None: +async def test_bidirectional_progress_notifications(): """Test that both client and server can send progress notifications.""" # Create memory streams for client/server server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5) # Run a server session so we can send progress updates in tool - async def run_server() -> None: + async def run_server(): # Create a server session async with ServerSession( client_to_server_receive, @@ -197,7 +197,7 @@ async def handle_client_message( @pytest.mark.anyio -async def test_progress_callback_exception_logging() -> None: +async def test_progress_callback_exception_logging(): """Test that exceptions in progress callbacks are logged and \ don't crash the session.""" # Track logged warnings diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 285efefcf..d7c6cc3b5 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -24,7 +24,7 @@ @pytest.mark.anyio -async def test_in_flight_requests_cleared_after_completion() -> None: +async def test_in_flight_requests_cleared_after_completion(): """Verify that _in_flight is empty after all requests complete.""" server = Server(name="test server") async with Client(server) as client: @@ -37,7 +37,7 @@ async def test_in_flight_requests_cleared_after_completion() -> None: @pytest.mark.anyio -async def test_request_cancellation() -> None: +async def test_request_cancellation(): """Test that requests can be cancelled while in-flight.""" ev_tool_called = anyio.Event() ev_cancelled = anyio.Event() @@ -64,7 +64,7 @@ async def handle_list_tools( on_list_tools=handle_list_tools, ) - async def make_request(client: Client) -> None: + async def make_request(client: Client): nonlocal ev_cancelled try: await client.session.send_request( @@ -99,7 +99,7 @@ async def make_request(client: Client) -> None: @pytest.mark.anyio -async def test_response_id_type_mismatch_string_to_int() -> None: +async def test_response_id_type_mismatch_string_to_int(): """Test that responses with string IDs are correctly matched to requests sent with integer IDs. @@ -113,7 +113,7 @@ async def test_response_id_type_mismatch_string_to_int() -> None: client_read, client_write = client_streams server_read, server_write = server_streams - async def mock_server() -> None: + async def mock_server(): """Receive a request and respond with a string ID instead of integer.""" message = await server_read.receive() assert isinstance(message, SessionMessage) @@ -131,7 +131,7 @@ async def mock_server() -> None: ) await server_write.send(SessionMessage(message=response)) - async def make_request(client_session: ClientSession) -> None: + async def make_request(client_session: ClientSession): nonlocal result_holder # Send a ping request (uses integer ID internally) result = await client_session.send_ping() @@ -153,7 +153,7 @@ async def make_request(client_session: ClientSession) -> None: @pytest.mark.anyio -async def test_error_response_id_type_mismatch_string_to_int() -> None: +async def test_error_response_id_type_mismatch_string_to_int(): """Test that error responses with string IDs are correctly matched to requests sent with integer IDs. @@ -167,7 +167,7 @@ async def test_error_response_id_type_mismatch_string_to_int() -> None: client_read, client_write = client_streams server_read, server_write = server_streams - async def mock_server() -> None: + async def mock_server(): """Receive a request and respond with an error using a string ID.""" message = await server_read.receive() assert isinstance(message, SessionMessage) @@ -184,7 +184,7 @@ async def mock_server() -> None: ) await server_write.send(SessionMessage(message=error_response)) - async def make_request(client_session: ClientSession) -> None: + async def make_request(client_session: ClientSession): nonlocal error_holder try: await client_session.send_ping() @@ -208,7 +208,7 @@ async def make_request(client_session: ClientSession) -> None: @pytest.mark.anyio -async def test_response_id_non_numeric_string_no_match() -> None: +async def test_response_id_non_numeric_string_no_match(): """Test that responses with non-numeric string IDs don't incorrectly match integer request IDs. @@ -221,7 +221,7 @@ async def test_response_id_non_numeric_string_no_match() -> None: client_read, client_write = client_streams server_read, server_write = server_streams - async def mock_server() -> None: + async def mock_server(): """Receive a request and respond with a non-numeric string ID.""" message = await server_read.receive() assert isinstance(message, SessionMessage) @@ -234,7 +234,7 @@ async def mock_server() -> None: ) await server_write.send(SessionMessage(message=response)) - async def make_request(client_session: ClientSession) -> None: + async def make_request(client_session: ClientSession): try: # Use a short timeout since we expect this to fail await client_session.send_request( @@ -259,7 +259,7 @@ async def make_request(client_session: ClientSession) -> None: @pytest.mark.anyio -async def test_connection_closed() -> None: +async def test_connection_closed(): """Test that pending requests are cancelled when the connection is closed remotely.""" ev_closed = anyio.Event() @@ -269,7 +269,7 @@ async def test_connection_closed() -> None: client_read, client_write = client_streams server_read, server_write = server_streams - async def make_request(client_session: ClientSession) -> None: + async def make_request(client_session: ClientSession): """Send a request in a separate task""" nonlocal ev_response try: @@ -281,7 +281,7 @@ async def make_request(client_session: ClientSession) -> None: assert "Connection closed" in str(e) ev_response.set() - async def mock_server() -> None: + async def mock_server(): """Wait for a request, then close the connection""" nonlocal ev_closed # Wait for a request @@ -305,7 +305,7 @@ async def mock_server() -> None: @pytest.mark.anyio -async def test_null_id_error_surfaced_via_message_handler() -> None: +async def test_null_id_error_surfaced_via_message_handler(): """Test that a JSONRPCError with id=None is surfaced to the message handler. Per JSON-RPC 2.0, error responses use id=null when the request id could not @@ -328,7 +328,7 @@ async def capture_errors( client_read, client_write = client_streams _server_read, server_write = server_streams - async def mock_server() -> None: + async def mock_server(): """Send a null-id error (simulating a parse error).""" error_response = JSONRPCError(jsonrpc="2.0", id=None, error=sent_error) await server_write.send(SessionMessage(message=error_response)) @@ -351,7 +351,7 @@ async def mock_server() -> None: @pytest.mark.anyio -async def test_null_id_error_does_not_affect_pending_request() -> None: +async def test_null_id_error_does_not_affect_pending_request(): """Test that a null-id error doesn't interfere with an in-flight request. When a null-id error arrives while a request is pending, the error should @@ -376,7 +376,7 @@ async def capture_errors( client_read, client_write = client_streams server_read, server_write = server_streams - async def mock_server() -> None: + async def mock_server(): """Read a request, inject a null-id error, then respond normally.""" message = await server_read.receive() assert isinstance(message, SessionMessage) @@ -389,7 +389,7 @@ async def mock_server() -> None: # Then, respond normally to the pending request await server_write.send(SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=request_id, result={}))) - async def make_request(client_session: ClientSession) -> None: + async def make_request(client_session: ClientSession): result = await client_session.send_ping() result_holder.append(result) ev_response_received.set() diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 4dec30549..5629a5707 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -497,7 +497,7 @@ async def test_request_context_isolation(context_server: None, server_url: str) assert ctx["headers"].get("x-custom-value") == f"value-{i}" -def test_sse_message_id_coercion() -> None: +def test_sse_message_id_coercion(): """Previously, the `RequestId` would coerce a string that looked like an integer into an integer. See for more details. @@ -531,7 +531,7 @@ def test_sse_message_id_coercion() -> None: ("/messages/#fragment", ValueError), ], ) -def test_sse_server_transport_endpoint_validation(endpoint: str, expected_result: str | type[Exception]) -> None: +def test_sse_server_transport_endpoint_validation(endpoint: str, expected_result: str | type[Exception]): """Test that SseServerTransport properly validates and normalizes endpoints.""" if isinstance(expected_result, type): # Test invalid endpoints that should raise an exception diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 9d0b6adff..f8ca30441 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -10,10 +10,10 @@ import socket import time import traceback -from collections.abc import AsyncGenerator, AsyncIterator, Generator +from collections.abc import AsyncIterator, Generator from contextlib import asynccontextmanager from dataclasses import dataclass, field -from typing import Any, NoReturn +from typing import Any from unittest.mock import MagicMock from urllib.parse import urlparse @@ -97,7 +97,7 @@ def extract_protocol_version_from_sse(response: requests.Response) -> str: class SimpleEventStore(EventStore): """Simple in-memory event store for testing.""" - def __init__(self) -> None: + def __init__(self): self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage | None]] = [] self._event_id_counter = 0 @@ -570,7 +570,7 @@ def json_server_url(json_server_port: int) -> str: # Basic request validation tests -def test_accept_header_validation(basic_server: None, basic_server_url: str) -> None: +def test_accept_header_validation(basic_server: None, basic_server_url: str): """Test that Accept header is properly validated.""" # Test without Accept header (suppress requests library default Accept: */*) session = requests.Session() @@ -595,7 +595,7 @@ def test_accept_header_validation(basic_server: None, basic_server_url: str) -> "application/*;q=0.9, text/*;q=0.8", ], ) -def test_accept_header_wildcard(basic_server: None, basic_server_url: str, accept_header: str) -> None: +def test_accept_header_wildcard(basic_server: None, basic_server_url: str, accept_header: str): """Test that wildcard Accept headers are accepted per RFC 7231.""" response = requests.post( f"{basic_server_url}/mcp", @@ -616,7 +616,7 @@ def test_accept_header_wildcard(basic_server: None, basic_server_url: str, accep "text/*", ], ) -def test_accept_header_incompatible(basic_server: None, basic_server_url: str, accept_header: str) -> None: +def test_accept_header_incompatible(basic_server: None, basic_server_url: str, accept_header: str): """Test that incompatible Accept headers are rejected for SSE mode.""" response = requests.post( f"{basic_server_url}/mcp", @@ -630,7 +630,7 @@ def test_accept_header_incompatible(basic_server: None, basic_server_url: str, a assert "Not Acceptable" in response.text -def test_content_type_validation(basic_server: None, basic_server_url: str) -> None: +def test_content_type_validation(basic_server: None, basic_server_url: str): """Test that Content-Type header is properly validated.""" # Test with incorrect Content-Type response = requests.post( @@ -646,7 +646,7 @@ def test_content_type_validation(basic_server: None, basic_server_url: str) -> N assert "Invalid Content-Type" in response.text -def test_json_validation(basic_server: None, basic_server_url: str) -> None: +def test_json_validation(basic_server: None, basic_server_url: str): """Test that JSON content is properly validated.""" # Test with invalid JSON response = requests.post( @@ -661,7 +661,7 @@ def test_json_validation(basic_server: None, basic_server_url: str) -> None: assert "Parse error" in response.text -def test_json_parsing(basic_server: None, basic_server_url: str) -> None: +def test_json_parsing(basic_server: None, basic_server_url: str): """Test that JSON content is properly parse.""" # Test with valid JSON but invalid JSON-RPC response = requests.post( @@ -676,7 +676,7 @@ def test_json_parsing(basic_server: None, basic_server_url: str) -> None: assert "Validation error" in response.text -def test_method_not_allowed(basic_server: None, basic_server_url: str) -> None: +def test_method_not_allowed(basic_server: None, basic_server_url: str): """Test that unsupported HTTP methods are rejected.""" # Test with unsupported method (PUT) response = requests.put( @@ -691,7 +691,7 @@ def test_method_not_allowed(basic_server: None, basic_server_url: str) -> None: assert "Method Not Allowed" in response.text -def test_session_validation(basic_server: None, basic_server_url: str) -> None: +def test_session_validation(basic_server: None, basic_server_url: str): """Test session ID validation.""" # session_id not used directly in this test @@ -708,7 +708,7 @@ def test_session_validation(basic_server: None, basic_server_url: str) -> None: assert "Missing session ID" in response.text -def test_session_id_pattern() -> None: +def test_session_id_pattern(): """Test that SESSION_ID_PATTERN correctly validates session IDs.""" # Valid session IDs (visible ASCII characters from 0x21 to 0x7E) valid_session_ids = [ @@ -743,7 +743,7 @@ def test_session_id_pattern() -> None: assert SESSION_ID_PATTERN.fullmatch(session_id) is None -def test_streamable_http_transport_init_validation() -> None: +def test_streamable_http_transport_init_validation(): """Test that StreamableHTTPServerTransport validates session ID on init.""" # Valid session ID should initialize without errors valid_transport = StreamableHTTPServerTransport(mcp_session_id="valid-id") @@ -766,7 +766,7 @@ def test_streamable_http_transport_init_validation() -> None: StreamableHTTPServerTransport(mcp_session_id="test\n") -def test_session_termination(basic_server: None, basic_server_url: str) -> None: +def test_session_termination(basic_server: None, basic_server_url: str): """Test session termination via DELETE and subsequent request handling.""" response = requests.post( f"{basic_server_url}/mcp", @@ -806,7 +806,7 @@ def test_session_termination(basic_server: None, basic_server_url: str) -> None: assert "Session has been terminated" in response.text -def test_response(basic_server: None, basic_server_url: str) -> None: +def test_response(basic_server: None, basic_server_url: str): """Test response handling for a valid request.""" mcp_url = f"{basic_server_url}/mcp" response = requests.post( @@ -841,7 +841,7 @@ def test_response(basic_server: None, basic_server_url: str) -> None: assert tools_response.headers.get("Content-Type") == "text/event-stream" -def test_json_response(json_response_server: None, json_server_url: str) -> None: +def test_json_response(json_response_server: None, json_server_url: str): """Test response handling when is_json_response_enabled is True.""" mcp_url = f"{json_server_url}/mcp" response = requests.post( @@ -856,7 +856,7 @@ def test_json_response(json_response_server: None, json_server_url: str) -> None assert response.headers.get("Content-Type") == "application/json" -def test_json_response_accept_json_only(json_response_server: None, json_server_url: str) -> None: +def test_json_response_accept_json_only(json_response_server: None, json_server_url: str): """Test that json_response servers only require application/json in Accept header.""" mcp_url = f"{json_server_url}/mcp" response = requests.post( @@ -871,7 +871,7 @@ def test_json_response_accept_json_only(json_response_server: None, json_server_ assert response.headers.get("Content-Type") == "application/json" -def test_json_response_missing_accept_header(json_response_server: None, json_server_url: str) -> None: +def test_json_response_missing_accept_header(json_response_server: None, json_server_url: str): """Test that json_response servers reject requests without Accept header.""" mcp_url = f"{json_server_url}/mcp" # Suppress requests library default Accept: */* header @@ -888,7 +888,7 @@ def test_json_response_missing_accept_header(json_response_server: None, json_se assert "Not Acceptable" in response.text -def test_json_response_incorrect_accept_header(json_response_server: None, json_server_url: str) -> None: +def test_json_response_incorrect_accept_header(json_response_server: None, json_server_url: str): """Test that json_response servers reject requests with incorrect Accept header.""" mcp_url = f"{json_server_url}/mcp" # Test with only text/event-stream (wrong for JSON server) @@ -912,9 +912,7 @@ def test_json_response_incorrect_accept_header(json_response_server: None, json_ "application/*;q=0.9", ], ) -def test_json_response_wildcard_accept_header( - json_response_server: None, json_server_url: str, accept_header: str -) -> None: +def test_json_response_wildcard_accept_header(json_response_server: None, json_server_url: str, accept_header: str): """Test that json_response servers accept wildcard Accept headers per RFC 7231.""" mcp_url = f"{json_server_url}/mcp" response = requests.post( @@ -929,7 +927,7 @@ def test_json_response_wildcard_accept_header( assert response.headers.get("Content-Type") == "application/json" -def test_get_sse_stream(basic_server: None, basic_server_url: str) -> None: +def test_get_sse_stream(basic_server: None, basic_server_url: str): """Test establishing an SSE stream via GET request.""" # First, we need to initialize a session mcp_url = f"{basic_server_url}/mcp" @@ -989,7 +987,7 @@ def test_get_sse_stream(basic_server: None, basic_server_url: str) -> None: assert second_get.status_code == 409 -def test_get_validation(basic_server: None, basic_server_url: str) -> None: +def test_get_validation(basic_server: None, basic_server_url: str): """Test validation for GET requests.""" # First, we need to initialize a session mcp_url = f"{basic_server_url}/mcp" @@ -1046,16 +1044,14 @@ def test_get_validation(basic_server: None, basic_server_url: str) -> None: # Client-specific fixtures @pytest.fixture -async def http_client( - basic_server: None, basic_server_url: str -) -> AsyncGenerator[httpx.AsyncClient, None]: # pragma: no cover +async def http_client(basic_server: None, basic_server_url: str): # pragma: no cover """Create test client matching the SSE test pattern.""" async with httpx.AsyncClient(base_url=basic_server_url) as client: yield client @pytest.fixture -async def initialized_client_session(basic_server: None, basic_server_url: str) -> AsyncGenerator[ClientSession, None]: +async def initialized_client_session(basic_server: None, basic_server_url: str): """Create initialized StreamableHTTP client session.""" async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1064,7 +1060,7 @@ async def initialized_client_session(basic_server: None, basic_server_url: str) @pytest.mark.anyio -async def test_streamable_http_client_basic_connection(basic_server: None, basic_server_url: str) -> None: +async def test_streamable_http_client_basic_connection(basic_server: None, basic_server_url: str): """Test basic client connection with initialization.""" async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1075,7 +1071,7 @@ async def test_streamable_http_client_basic_connection(basic_server: None, basic @pytest.mark.anyio -async def test_streamable_http_client_resource_read(initialized_client_session: ClientSession) -> None: +async def test_streamable_http_client_resource_read(initialized_client_session: ClientSession): """Test client resource read functionality.""" response = await initialized_client_session.read_resource(uri="foobar://test-resource") assert len(response.contents) == 1 @@ -1085,7 +1081,7 @@ async def test_streamable_http_client_resource_read(initialized_client_session: @pytest.mark.anyio -async def test_streamable_http_client_tool_invocation(initialized_client_session: ClientSession) -> None: +async def test_streamable_http_client_tool_invocation(initialized_client_session: ClientSession): """Test client tool invocation.""" # First list tools tools = await initialized_client_session.list_tools() @@ -1100,7 +1096,7 @@ async def test_streamable_http_client_tool_invocation(initialized_client_session @pytest.mark.anyio -async def test_streamable_http_client_error_handling(initialized_client_session: ClientSession) -> None: +async def test_streamable_http_client_error_handling(initialized_client_session: ClientSession): """Test error handling in client.""" with pytest.raises(MCPError) as exc_info: await initialized_client_session.read_resource(uri="unknown://test-error") @@ -1109,7 +1105,7 @@ async def test_streamable_http_client_error_handling(initialized_client_session: @pytest.mark.anyio -async def test_streamable_http_client_session_persistence(basic_server: None, basic_server_url: str) -> None: +async def test_streamable_http_client_session_persistence(basic_server: None, basic_server_url: str): """Test that session ID persists across requests.""" async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1130,7 +1126,7 @@ async def test_streamable_http_client_session_persistence(basic_server: None, ba @pytest.mark.anyio -async def test_streamable_http_client_json_response(json_response_server: None, json_server_url: str) -> None: +async def test_streamable_http_client_json_response(json_response_server: None, json_server_url: str): """Test client with JSON response mode.""" async with streamable_http_client(f"{json_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1151,7 +1147,7 @@ async def test_streamable_http_client_json_response(json_response_server: None, @pytest.mark.anyio -async def test_streamable_http_client_get_stream(basic_server: None, basic_server_url: str) -> None: +async def test_streamable_http_client_get_stream(basic_server: None, basic_server_url: str): """Test GET stream functionality for server-initiated messages.""" notifications_received: list[types.ServerNotification] = [] @@ -1202,7 +1198,7 @@ async def capture_session_id(response: httpx.Response) -> None: @pytest.mark.anyio -async def test_streamable_http_client_session_termination(basic_server: None, basic_server_url: str) -> None: +async def test_streamable_http_client_session_termination(basic_server: None, basic_server_url: str): """Test client session termination functionality.""" # Use httpx client with event hooks to capture session ID httpx_client, captured_ids = create_session_id_capturing_client() @@ -1239,7 +1235,7 @@ async def test_streamable_http_client_session_termination(basic_server: None, ba @pytest.mark.anyio async def test_streamable_http_client_session_termination_204( basic_server: None, basic_server_url: str, monkeypatch: pytest.MonkeyPatch -) -> None: +): """Test client session termination functionality with a 204 response. This test patches the httpx client to return a 204 response for DELETEs. @@ -1298,7 +1294,7 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt @pytest.mark.anyio -async def test_streamable_http_client_resumption(event_server: tuple[SimpleEventStore, str]) -> None: +async def test_streamable_http_client_resumption(event_server: tuple[SimpleEventStore, str]): """Test client session resumption using sync primitives for reliable coordination.""" _, server_url = event_server @@ -1349,7 +1345,7 @@ async def on_resumption_token_update(token: str) -> None: # Start the tool that will wait on lock in a task async with anyio.create_task_group() as tg: # pragma: no branch - async def run_tool() -> None: + async def run_tool(): metadata = ClientMessageMetadata( on_resumption_token_update=on_resumption_token_update, ) @@ -1416,7 +1412,7 @@ async def run_tool() -> None: @pytest.mark.anyio -async def test_streamablehttp_server_sampling(basic_server: None, basic_server_url: str) -> None: +async def test_streamablehttp_server_sampling(basic_server: None, basic_server_url: str): """Test server-initiated sampling request through streamable HTTP transport.""" # Variable to track if sampling callback was invoked sampling_callback_invoked = False @@ -1521,7 +1517,7 @@ async def _handle_context_call_tool( # pragma: no cover # Server runner for context-aware testing -def run_context_aware_server(port: int) -> None: # pragma: no cover +def run_context_aware_server(port: int): # pragma: no cover """Run the context-aware test server.""" server = Server( "ContextAwareServer", @@ -1643,9 +1639,7 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No @pytest.mark.anyio -async def test_client_includes_protocol_version_header_after_init( - context_aware_server: None, basic_server_url: str -) -> None: +async def test_client_includes_protocol_version_header_after_init(context_aware_server: None, basic_server_url: str): """Test that client includes mcp-protocol-version header after initialization.""" async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1665,7 +1659,7 @@ async def test_client_includes_protocol_version_header_after_init( assert headers_data[MCP_PROTOCOL_VERSION_HEADER] == negotiated_version -def test_server_validates_protocol_version_header(basic_server: None, basic_server_url: str) -> None: +def test_server_validates_protocol_version_header(basic_server: None, basic_server_url: str): """Test that server returns 400 Bad Request version if header unsupported or invalid.""" # First initialize a session to get a valid session ID init_response = requests.post( @@ -1723,7 +1717,7 @@ def test_server_validates_protocol_version_header(basic_server: None, basic_serv assert response.status_code == 200 -def test_server_backwards_compatibility_no_protocol_version(basic_server: None, basic_server_url: str) -> None: +def test_server_backwards_compatibility_no_protocol_version(basic_server: None, basic_server_url: str): """Test server accepts requests without protocol version header.""" # First initialize a session to get a valid session ID init_response = requests.post( @@ -1753,11 +1747,11 @@ def test_server_backwards_compatibility_no_protocol_version(basic_server: None, @pytest.mark.anyio -async def test_client_crash_handled(basic_server: None, basic_server_url: str) -> None: +async def test_client_crash_handled(basic_server: None, basic_server_url: str): """Test that cases where the client crashes are handled gracefully.""" # Simulate bad client that crashes after init - async def bad_client() -> NoReturn: + async def bad_client(): """Client that triggers ClosedResourceError""" async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1782,7 +1776,7 @@ async def bad_client() -> NoReturn: @pytest.mark.anyio -async def test_handle_sse_event_skips_empty_data() -> None: +async def test_handle_sse_event_skips_empty_data(): """Test that _handle_sse_event skips empty SSE data (keep-alive pings).""" transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") @@ -1808,7 +1802,7 @@ async def test_handle_sse_event_skips_empty_data() -> None: @pytest.mark.anyio -async def test_priming_event_not_sent_for_old_protocol_version() -> None: +async def test_priming_event_not_sent_for_old_protocol_version(): """Test that _maybe_send_priming_event skips for old protocol versions (backwards compat).""" # Create a transport with an event store transport = StreamableHTTPServerTransport( @@ -1837,7 +1831,7 @@ async def test_priming_event_not_sent_for_old_protocol_version() -> None: @pytest.mark.anyio -async def test_priming_event_not_sent_without_event_store() -> None: +async def test_priming_event_not_sent_without_event_store(): """Test that _maybe_send_priming_event returns early when no event_store is configured.""" # Create a transport WITHOUT an event store transport = StreamableHTTPServerTransport("/mcp") @@ -1857,7 +1851,7 @@ async def test_priming_event_not_sent_without_event_store() -> None: @pytest.mark.anyio -async def test_priming_event_includes_retry_interval() -> None: +async def test_priming_event_includes_retry_interval(): """Test that _maybe_send_priming_event includes retry field when retry_interval is set.""" # Create a transport with an event store AND retry_interval transport = StreamableHTTPServerTransport( @@ -1886,7 +1880,7 @@ async def test_priming_event_includes_retry_interval() -> None: @pytest.mark.anyio -async def test_close_sse_stream_callback_not_provided_for_old_protocol_version() -> None: +async def test_close_sse_stream_callback_not_provided_for_old_protocol_version(): """Test that close_sse_stream callbacks are NOT provided for old protocol versions.""" # Create a transport with an event store transport = StreamableHTTPServerTransport( @@ -2125,7 +2119,7 @@ async def message_handler( @pytest.mark.anyio async def test_streamable_http_multiple_reconnections( event_server: tuple[SimpleEventStore, str], -) -> None: +): """Verify multiple close_sse_stream() calls each trigger a client reconnect. Server uses retry_interval=500ms, tool sleeps 600ms after each close to ensure diff --git a/tests/test_examples.py b/tests/test_examples.py index 0a7fb2ba8..3af82f04c 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -16,7 +16,7 @@ @pytest.mark.anyio -async def test_simple_echo() -> None: +async def test_simple_echo(): """Test the simple echo server""" from examples.mcpserver.simple_echo import mcp @@ -28,7 +28,7 @@ async def test_simple_echo() -> None: @pytest.mark.anyio -async def test_complex_inputs() -> None: +async def test_complex_inputs(): """Test the complex inputs server""" from examples.mcpserver.complex_inputs import mcp @@ -48,7 +48,7 @@ async def test_complex_inputs() -> None: @pytest.mark.anyio -async def test_direct_call_tool_result_return() -> None: +async def test_direct_call_tool_result_return(): """Test the CallToolResult echo server""" from examples.mcpserver.direct_call_tool_result_return import mcp @@ -64,7 +64,7 @@ async def test_direct_call_tool_result_return() -> None: @pytest.mark.anyio -async def test_desktop(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: +async def test_desktop(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): """Test the desktop server""" # Build a real Desktop directory under tmp_path rather than patching # Path.iterdir — a class-level patch breaks jsonschema_specifications' @@ -95,8 +95,8 @@ async def test_desktop(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: # TODO(v2): Change back to README.md when v2 is released @pytest.mark.parametrize("example", find_examples("README.v2.md"), ids=str) -def test_docs_examples(example: CodeExample, eval_example: EvalExample) -> None: - ruff_ignore: list[str] = ["F841", "I001", "F821", "ANN"] # F821: undefined names (snippets lack imports) +def test_docs_examples(example: CodeExample, eval_example: EvalExample): + ruff_ignore: list[str] = ["F841", "I001", "F821"] # F821: undefined names (snippets lack imports) # Use project's actual line length of 120 eval_example.set_config(ruff_ignore=ruff_ignore, target_version="py310", line_length=120) diff --git a/tests/test_types.py b/tests/test_types.py index 28df823a1..f424efdbf 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -26,7 +26,7 @@ @pytest.mark.anyio -async def test_jsonrpc_request() -> None: +async def test_jsonrpc_request(): json_data = { "jsonrpc": "2.0", "id": 1, @@ -50,7 +50,7 @@ async def test_jsonrpc_request() -> None: @pytest.mark.anyio -async def test_method_initialization() -> None: +async def test_method_initialization(): """Test that the method is automatically set on object creation. Testing just for InitializeRequest to keep the test simple, but should be set for other types as well. """ @@ -71,7 +71,7 @@ async def test_method_initialization() -> None: @pytest.mark.anyio -async def test_tool_use_content() -> None: +async def test_tool_use_content(): """Test ToolUseContent type for SEP-1577.""" tool_use_data = { "type": "tool_use", @@ -93,7 +93,7 @@ async def test_tool_use_content() -> None: @pytest.mark.anyio -async def test_tool_result_content() -> None: +async def test_tool_result_content(): """Test ToolResultContent type for SEP-1577.""" tool_result_data = { "type": "tool_result", @@ -115,7 +115,7 @@ async def test_tool_result_content() -> None: @pytest.mark.anyio -async def test_tool_choice() -> None: +async def test_tool_choice(): """Test ToolChoice type for SEP-1577.""" # Test with mode tool_choice_data = {"mode": "required"} @@ -135,7 +135,7 @@ async def test_tool_choice() -> None: @pytest.mark.anyio -async def test_sampling_message_with_user_role() -> None: +async def test_sampling_message_with_user_role(): """Test SamplingMessage with user role for SEP-1577.""" # Test with single content user_msg_data = {"role": "user", "content": {"type": "text", "text": "Hello"}} @@ -158,7 +158,7 @@ async def test_sampling_message_with_user_role() -> None: @pytest.mark.anyio -async def test_sampling_message_with_assistant_role() -> None: +async def test_sampling_message_with_assistant_role(): """Test SamplingMessage with assistant role for SEP-1577.""" # Test with tool use content assistant_msg_data = { @@ -188,7 +188,7 @@ async def test_sampling_message_with_assistant_role() -> None: @pytest.mark.anyio -async def test_sampling_message_backward_compatibility() -> None: +async def test_sampling_message_backward_compatibility(): """Test that SamplingMessage maintains backward compatibility.""" # Old-style message (single content, no tools) old_style_data = {"role": "user", "content": {"type": "text", "text": "Hello"}} @@ -215,7 +215,7 @@ async def test_sampling_message_backward_compatibility() -> None: @pytest.mark.anyio -async def test_create_message_request_params_with_tools() -> None: +async def test_create_message_request_params_with_tools(): """Test CreateMessageRequestParams with tools for SEP-1577.""" tool = Tool( name="get_weather", @@ -238,7 +238,7 @@ async def test_create_message_request_params_with_tools() -> None: @pytest.mark.anyio -async def test_create_message_result_with_tool_use() -> None: +async def test_create_message_result_with_tool_use(): """Test CreateMessageResultWithTools with tool use content for SEP-1577.""" result_data = { "role": "assistant", @@ -261,7 +261,7 @@ async def test_create_message_result_with_tool_use() -> None: @pytest.mark.anyio -async def test_create_message_result_basic() -> None: +async def test_create_message_result_basic(): """Test CreateMessageResult with basic text content (backwards compatible).""" result_data = { "role": "assistant", @@ -280,7 +280,7 @@ async def test_create_message_result_basic() -> None: @pytest.mark.anyio -async def test_client_capabilities_with_sampling_tools() -> None: +async def test_client_capabilities_with_sampling_tools(): """Test ClientCapabilities with nested sampling capabilities for SEP-1577.""" # New structured format capabilities_data: dict[str, Any] = { @@ -299,7 +299,7 @@ async def test_client_capabilities_with_sampling_tools() -> None: assert full_caps.sampling.tools is not None -def test_tool_preserves_json_schema_2020_12_fields() -> None: +def test_tool_preserves_json_schema_2020_12_fields(): """Verify that JSON Schema 2020-12 keywords are preserved in Tool.inputSchema. SEP-1613 establishes JSON Schema 2020-12 as the default dialect for MCP. @@ -336,7 +336,7 @@ def test_tool_preserves_json_schema_2020_12_fields() -> None: assert serialized["inputSchema"]["additionalProperties"] is False -def test_list_tools_result_preserves_json_schema_2020_12_fields() -> None: +def test_list_tools_result_preserves_json_schema_2020_12_fields(): """Verify JSON Schema 2020-12 fields survive ListToolsResult deserialization.""" raw_response = { "tools": [ From 91a983fc4cb4b4bf5a82290bb8ee8c257422e668 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 24 Mar 2026 23:12:43 +0000 Subject: [PATCH 4/5] chore: ignore ANN204 to allow __init__ without return type Replaces mypy-init-return (which only exempts __init__ when args are typed) with a blanket ANN204 ignore. Special methods have well-known return types and pyright infers them correctly. --- README.v2.md | 4 ++-- examples/snippets/clients/oauth_client.py | 2 +- examples/snippets/servers/structured_output.py | 2 +- pyproject.toml | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/README.v2.md b/README.v2.md index 02c133b0d..8049d1020 100644 --- a/README.v2.md +++ b/README.v2.md @@ -532,7 +532,7 @@ def get_user(user_id: str) -> UserProfile: # Classes WITHOUT type hints cannot be used for structured output class UntypedConfig: - def __init__(self, setting1, setting2): # type: ignore[reportMissingParameterType] # noqa: ANN001, ANN204 + def __init__(self, setting1, setting2): # type: ignore[reportMissingParameterType] # noqa: ANN001 self.setting1 = setting1 self.setting2 = setting2 @@ -2327,7 +2327,7 @@ from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAu class InMemoryTokenStorage(TokenStorage): """Demo In-memory token storage implementation.""" - def __init__(self) -> None: + def __init__(self): self.tokens: OAuthToken | None = None self.client_info: OAuthClientInformationFull | None = None diff --git a/examples/snippets/clients/oauth_client.py b/examples/snippets/clients/oauth_client.py index 0f6cd5568..115a6b6c8 100644 --- a/examples/snippets/clients/oauth_client.py +++ b/examples/snippets/clients/oauth_client.py @@ -21,7 +21,7 @@ class InMemoryTokenStorage(TokenStorage): """Demo In-memory token storage implementation.""" - def __init__(self) -> None: + def __init__(self): self.tokens: OAuthToken | None = None self.client_info: OAuthClientInformationFull | None = None diff --git a/examples/snippets/servers/structured_output.py b/examples/snippets/servers/structured_output.py index d7a2a4b51..422d9001c 100644 --- a/examples/snippets/servers/structured_output.py +++ b/examples/snippets/servers/structured_output.py @@ -71,7 +71,7 @@ def get_user(user_id: str) -> UserProfile: # Classes WITHOUT type hints cannot be used for structured output class UntypedConfig: - def __init__(self, setting1, setting2): # type: ignore[reportMissingParameterType] # noqa: ANN001, ANN204 + def __init__(self, setting1, setting2): # type: ignore[reportMissingParameterType] # noqa: ANN001 self.setting1 = setting1 self.setting2 = setting2 diff --git a/pyproject.toml b/pyproject.toml index b58aa3c13..decfdbd31 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -143,13 +143,13 @@ select = [ "TID251", # https://docs.astral.sh/ruff/rules/banned-api/ ] ignore = [ + "ANN204", # special methods (__init__, __enter__, etc.) have well-known return types "ANN401", # `Any` is sometimes the right type; pyright strict handles real misuse "PERF203", ] [tool.ruff.lint.flake8-annotations] allow-star-arg-any = true -mypy-init-return = true [tool.ruff.lint.flake8-tidy-imports.banned-api] "pydantic.RootModel".msg = "Use `pydantic.TypeAdapter` instead." From a9cf2faec458d7871e3abf3371f54360dbe01862 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 25 Mar 2026 11:17:51 +0000 Subject: [PATCH 5/5] test: ignore ANN in README code example linting README doc snippets are illustrative and shouldn't require full type annotations. Lost this when reverting tests/** in d6b3ae4. --- tests/test_examples.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_examples.py b/tests/test_examples.py index 3af82f04c..3b9cdb23a 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -96,7 +96,7 @@ async def test_desktop(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): # TODO(v2): Change back to README.md when v2 is released @pytest.mark.parametrize("example", find_examples("README.v2.md"), ids=str) def test_docs_examples(example: CodeExample, eval_example: EvalExample): - ruff_ignore: list[str] = ["F841", "I001", "F821"] # F821: undefined names (snippets lack imports) + ruff_ignore: list[str] = ["F841", "I001", "F821", "ANN"] # F821: undefined names (snippets lack imports) # Use project's actual line length of 120 eval_example.set_config(ruff_ignore=ruff_ignore, target_version="py310", line_length=120)