From c85501ac65bb424d30b079be25ba619ce85d6d02 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 24 Mar 2026 13:46:08 +0000 Subject: [PATCH 1/2] feat(auth): add BearerAuth for minimal bearer-token authentication MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds BearerAuth, a lightweight httpx.Auth implementation with a two-method contract (token() + optional on_unauthorized()). This covers the many deployments that don't fit the OAuth authorization-code flow: gateway/proxy patterns, service accounts with pre-provisioned tokens, enterprise SSO where tokens come from a separate pipeline. For simple cases, it's a one-liner: auth = BearerAuth("my-api-key") async with Client(url, auth=auth) as client: ... For token rotation, pass a callable (sync or async): auth = BearerAuth(lambda: os.environ.get("MCP_TOKEN")) For custom 401 handling, pass or override on_unauthorized(). The handler receives the 401 response (body pre-read, WWW-Authenticate available), refreshes credentials, and the request retries once. Retry state is naturally per-operation via httpx's generator-per-request pattern — no shared counter to reset or leak. OAuthClientProvider is unchanged. Both are httpx.Auth subclasses and plug into the same auth parameter — no adapter or type guard needed. Also adds: - auth= convenience parameter on streamable_http_client() and Client (mutually exclusive with http_client=, raises ValueError if both given) - UnauthorizedError exception for unrecoverable 401s - sync_auth_flow override that raises a clear error instead of silently no-oping - docs/authorization.md with bearer-token and OAuth sections - examples/snippets/clients/bearer_auth_client.py - 21 tests covering generator-driven unit tests and httpx wire-level integration --- docs/authorization.md | 149 ++++++- .../snippets/clients/bearer_auth_client.py | 45 ++ examples/snippets/pyproject.toml | 1 + src/mcp/client/auth/__init__.py | 21 +- src/mcp/client/auth/bearer.py | 175 ++++++++ src/mcp/client/auth/exceptions.py | 9 + src/mcp/client/client.py | 10 +- src/mcp/client/streamable_http.py | 19 +- src/mcp/shared/_httpx_utils.py | 2 +- tests/client/auth/test_bearer.py | 392 ++++++++++++++++++ tests/client/test_client.py | 14 +- 11 files changed, 827 insertions(+), 10 deletions(-) create mode 100644 examples/snippets/clients/bearer_auth_client.py create mode 100644 src/mcp/client/auth/bearer.py create mode 100644 tests/client/auth/test_bearer.py diff --git a/docs/authorization.md b/docs/authorization.md index 4b6208bdf..48d8d6ebd 100644 --- a/docs/authorization.md +++ b/docs/authorization.md @@ -1,5 +1,150 @@ # Authorization -!!! warning "Under Construction" +MCP HTTP transports authenticate via `httpx.Auth`. The SDK provides two +implementations that plug into the same `auth` parameter: - This page is currently being written. Check back soon for complete documentation. +- **`BearerAuth`** — a minimal two-method provider for API keys, gateway-managed + tokens, service accounts, or any scenario where the token comes from an + external pipeline. +- **`OAuthClientProvider`** — full OAuth 2.1 authorization-code flow with PKCE, + Protected Resource Metadata discovery (RFC 9728), dynamic client registration, + and automatic token refresh. + +Both are `httpx.Auth` subclasses. Pass either to `Client(url, auth=...)`, +`streamable_http_client(url, auth=...)`, or directly to +`httpx.AsyncClient(auth=...)`. + +## Bearer tokens + +For a static token (API key, pre-provisioned credential): + +```python +from mcp.client import Client +from mcp.client.auth import BearerAuth + +async with Client("https://api.example.com/mcp", auth=BearerAuth("my-api-key")) as client: + tools = await client.list_tools() +``` + +For a dynamic token (environment variable, cache, external service), pass a +callable — sync or async: + +```python +import os +from mcp.client.auth import BearerAuth + +auth = BearerAuth(lambda: os.environ.get("MCP_TOKEN")) +``` + +`token()` is called before every request, so the callable can return a freshly +rotated value each time. Keep it fast — return a cached value and refresh in the +background rather than blocking on network calls. + +### Handling 401 + +By default, `BearerAuth` raises `UnauthorizedError` immediately on 401. To +refresh credentials and retry once, pass an `on_unauthorized` handler: + +```python +from mcp.client.auth import BearerAuth, UnauthorizedContext + +token_cache = TokenCache() + +async def refresh(ctx: UnauthorizedContext) -> None: + # ctx.response.headers["WWW-Authenticate"] has scope/resource_metadata hints + await token_cache.invalidate() + +auth = BearerAuth(token_cache.get, on_unauthorized=refresh) +``` + +After `on_unauthorized` returns, `token()` is called again and the request is +retried once. If the retry also gets 401, `UnauthorizedError` is raised. Retry +state is scoped per-request — a failed retry on one request does not block +retries on subsequent requests. + +To abort without retrying (for example, when interactive user action is +required), raise from the handler: + +```python +async def signal_host(ctx: UnauthorizedContext) -> None: + ui.show_reauth_prompt() + raise UnauthorizedError("User action required before retry") +``` + +### Subclassing + +For more complex providers, subclass `BearerAuth` and override `token()` and +`on_unauthorized()`: + +```python +from mcp.client.auth import BearerAuth, UnauthorizedContext + +class MyAuth(BearerAuth): + async def token(self) -> str | None: + return await self._store.get_access_token() + + async def on_unauthorized(self, context: UnauthorizedContext) -> None: + await self._store.refresh() +``` + +## OAuth 2.1 + +For the full OAuth authorization-code flow with PKCE — including Protected +Resource Metadata discovery, authorization server metadata discovery, dynamic +client registration, and automatic token refresh — use `OAuthClientProvider`: + +```python +import httpx +from mcp.client.auth import OAuthClientProvider, TokenStorage +from mcp.client.streamable_http import streamable_http_client +from mcp.shared.auth import OAuthClientMetadata + +auth = OAuthClientProvider( + server_url="https://api.example.com", + client_metadata=OAuthClientMetadata( + client_name="My MCP Client", + redirect_uris=["http://localhost:3000/callback"], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + ), + storage=my_token_storage, + redirect_handler=open_browser, + callback_handler=wait_for_callback, +) + +async with streamable_http_client("https://api.example.com/mcp", auth=auth) as (read, write): + ... +``` + +See `examples/snippets/clients/oauth_client.py` for a complete working example. + +### Non-interactive grants + +For machine-to-machine authentication without a browser redirect, use the +extensions in `mcp.client.auth.extensions`: + +- `ClientCredentialsOAuthProvider` — `client_credentials` grant with client ID + and secret +- `PrivateKeyJWTOAuthProvider` — `client_credentials` with `private_key_jwt` + client authentication (RFC 7523) + +## Custom `httpx.Auth` + +Any `httpx.Auth` implementation works. To combine authentication with custom +HTTP settings (headers, timeouts, proxies), configure an `httpx.AsyncClient` +directly: + +```python +import httpx +from mcp.client.streamable_http import streamable_http_client + +http_client = httpx.AsyncClient( + auth=my_auth, + headers={"X-Custom": "value"}, + timeout=httpx.Timeout(60.0), +) + +async with http_client: + async with streamable_http_client(url, http_client=http_client) as (read, write): + ... +``` diff --git a/examples/snippets/clients/bearer_auth_client.py b/examples/snippets/clients/bearer_auth_client.py new file mode 100644 index 000000000..8f4242524 --- /dev/null +++ b/examples/snippets/clients/bearer_auth_client.py @@ -0,0 +1,45 @@ +"""Minimal bearer-token authentication example. + +Demonstrates the simplest possible MCP client authentication: a bearer token +from an environment variable. `BearerAuth` is an `httpx.Auth` implementation +that calls `token()` before every request and optionally `on_unauthorized()` +on 401 before retrying once. + +For full OAuth flows (authorization code, PKCE, dynamic client registration), +see `oauth_client.py` and use `OAuthClientProvider` instead — both plug into +the same `auth` parameter. + +Run against any MCP server that accepts bearer tokens: + + MCP_TOKEN=your-token MCP_SERVER_URL=http://localhost:8001/mcp uv run bearer-auth-client +""" + +import asyncio +import os + +from mcp.client import Client +from mcp.client.auth import BearerAuth + + +async def main() -> None: + server_url = os.environ.get("MCP_SERVER_URL", "http://localhost:8001/mcp") + token = os.environ.get("MCP_TOKEN") + + if not token: + raise SystemExit("Set MCP_TOKEN to your bearer token") + + # token() is called before every request. With no on_unauthorized handler, + # a 401 raises UnauthorizedError immediately — no retry. + auth = BearerAuth(token) + + async with Client(server_url, auth=auth) as client: + tools = await client.list_tools() + print(f"Available tools: {[t.name for t in tools.tools]}") + + +def run() -> None: + asyncio.run(main()) + + +if __name__ == "__main__": + run() diff --git a/examples/snippets/pyproject.toml b/examples/snippets/pyproject.toml index 4e68846a0..10b38e5ee 100644 --- a/examples/snippets/pyproject.toml +++ b/examples/snippets/pyproject.toml @@ -21,4 +21,5 @@ completion-client = "clients.completion_client:main" direct-execution-server = "servers.direct_execution:main" display-utilities-client = "clients.display_utilities:main" oauth-client = "clients.oauth_client:run" +bearer-auth-client = "clients.bearer_auth_client:run" elicitation-client = "clients.url_elicitation_client:run" diff --git a/src/mcp/client/auth/__init__.py b/src/mcp/client/auth/__init__.py index ab3179ecb..834f3c539 100644 --- a/src/mcp/client/auth/__init__.py +++ b/src/mcp/client/auth/__init__.py @@ -1,9 +1,19 @@ -"""OAuth2 Authentication implementation for HTTPX. +"""Client-side authentication for MCP HTTP transports. -Implements authorization code flow with PKCE and automatic token refresh. +Two `httpx.Auth` implementations are provided: + +- `BearerAuth` — minimal two-method provider (`token()` + optional + `on_unauthorized()`) for API keys, gateway-managed tokens, service accounts, + or any scenario where the token comes from an external pipeline. +- `OAuthClientProvider` — full OAuth 2.1 authorization-code flow with PKCE, + Protected Resource Metadata discovery (RFC 9728), dynamic client registration, + and automatic token refresh. + +Both are `httpx.Auth` subclasses and plug into the same `auth` parameter. """ -from mcp.client.auth.exceptions import OAuthFlowError, OAuthRegistrationError, OAuthTokenError +from mcp.client.auth.bearer import BearerAuth, TokenSource, UnauthorizedContext, UnauthorizedHandler +from mcp.client.auth.exceptions import OAuthFlowError, OAuthRegistrationError, OAuthTokenError, UnauthorizedError from mcp.client.auth.oauth2 import ( OAuthClientProvider, PKCEParameters, @@ -11,10 +21,15 @@ ) __all__ = [ + "BearerAuth", "OAuthClientProvider", "OAuthFlowError", "OAuthRegistrationError", "OAuthTokenError", "PKCEParameters", + "TokenSource", "TokenStorage", + "UnauthorizedContext", + "UnauthorizedError", + "UnauthorizedHandler", ] diff --git a/src/mcp/client/auth/bearer.py b/src/mcp/client/auth/bearer.py new file mode 100644 index 000000000..7a077c276 --- /dev/null +++ b/src/mcp/client/auth/bearer.py @@ -0,0 +1,175 @@ +"""Minimal bearer-token authentication for MCP HTTP transports. + +Provides `BearerAuth`, a lightweight `httpx.Auth` implementation with a two-method +contract (`token()` and `on_unauthorized()`). Use this when you have a token from +an external source — API keys, gateway-managed tokens, service accounts, enterprise +SSO pipelines — and don't need the full OAuth authorization-code flow. + +For OAuth flows (authorization code with PKCE, dynamic client registration, token +refresh), use `OAuthClientProvider` instead. Both are `httpx.Auth` subclasses and +plug into the same `auth` parameter. +""" + +from __future__ import annotations + +import inspect +from collections.abc import AsyncGenerator, Awaitable, Callable, Generator +from dataclasses import dataclass + +import httpx + +from mcp.client.auth.exceptions import UnauthorizedError + +TokenSource = str | Callable[[], str | None] | Callable[[], Awaitable[str | None]] +"""A bearer-token source: a static string, or a sync/async callable returning one.""" + +UnauthorizedHandler = Callable[["UnauthorizedContext"], Awaitable[None]] +"""Async handler invoked when the server responds with 401.""" + + +@dataclass +class UnauthorizedContext: + """Context passed to `on_unauthorized` when the server responds with 401. + + Handlers can inspect `response.headers["WWW-Authenticate"]` for resource metadata + URLs and scope hints per RFC 6750 §3 and RFC 9728, then refresh credentials before + the single retry. + """ + + response: httpx.Response + """The 401 response. Body has been read — `response.text` / `response.json()` are safe.""" + + request: httpx.Request + """The request that was rejected. `request.url` is the MCP server URL.""" + + +class BearerAuth(httpx.Auth): + """Minimal bearer-token authentication for MCP HTTP transports. + + Implements `httpx.Auth` with a two-method contract: + + - `token()` — called before every request to obtain the current bearer token. + - `on_unauthorized()` — called when the server responds with 401, giving the + provider a chance to refresh credentials before the transport retries once. + + For static tokens (API keys, pre-provisioned credentials):: + + auth = BearerAuth("my-api-key") + + For dynamic tokens (read from environment, cache, or external service):: + + auth = BearerAuth(lambda: os.environ.get("MCP_TOKEN")) + auth = BearerAuth(get_token_async) # async callable + + For custom 401 handling (token refresh, re-authentication signal):: + + async def refresh(ctx: UnauthorizedContext) -> None: + await my_token_cache.invalidate() + + auth = BearerAuth(get_token, on_unauthorized=refresh) + + Subclass and override `token()` / `on_unauthorized()` for more complex providers. + + For full OAuth 2.1 flows (authorization code with PKCE, discovery, registration), + use `OAuthClientProvider` — both are `httpx.Auth` subclasses and accepted by the + same `auth` parameter on transports. + """ + + def __init__( + self, + token: TokenSource | None = None, + on_unauthorized: UnauthorizedHandler | None = None, + ) -> None: + """Initialize bearer-token authentication. + + Args: + token: The bearer token source. A static string, a sync callable + returning `str | None`, or an async callable returning `str | None`. + Called before every request. If `None`, subclasses must override + `token()`. + on_unauthorized: Optional async handler called when the server responds + with 401. After the handler returns, `token()` is called again and + the request retried once. If not provided, 401 raises + `UnauthorizedError` immediately. If the retry also gets 401, + `UnauthorizedError` is raised. + """ + self._token = token + self._on_unauthorized = on_unauthorized + + async def token(self) -> str | None: + """Return the current bearer token, or `None` if unavailable. + + Called before every request. The default implementation resolves the + `token` argument passed to `__init__` (string, sync callable, or async + callable). Override for custom retrieval logic. + + Implementations should be fast — return a cached value and refresh in the + background rather than blocking on network calls here. + """ + src = self._token + if src is None or isinstance(src, str): + return src + result = src() + if inspect.isawaitable(result): + return await result + return result + + async def on_unauthorized(self, context: UnauthorizedContext) -> None: + """Handle a 401 response. Called once before the single retry. + + The default implementation delegates to the `on_unauthorized` callable + passed to `__init__`, or raises `UnauthorizedError` if none was provided. + Override to implement custom refresh logic. + + Implementations should refresh tokens, clear caches, or signal the host + application — whatever is needed so the next `token()` call returns a + valid token. Raise an exception to abort without retrying (e.g., when + interactive user action is required before a retry could succeed). + """ + if self._on_unauthorized is None: + www_auth = context.response.headers.get("WWW-Authenticate", "") + hint = f" (WWW-Authenticate: {www_auth})" if www_auth else "" + raise UnauthorizedError( + f"Server at {context.request.url} returned 401 Unauthorized{hint}; " + "no on_unauthorized handler configured" + ) + await self._on_unauthorized(context) + + def sync_auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Response, None]: + """Not supported — `BearerAuth` is async-only. + + Raises: + RuntimeError: Always. Use `httpx.AsyncClient`, not `httpx.Client`. + """ + raise RuntimeError( + "BearerAuth is async-only because token() and on_unauthorized() are " + "coroutines; use httpx.AsyncClient, not httpx.Client" + ) + yield request # pragma: no cover — unreachable; makes this a generator for type compat + + async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]: + """httpx auth-flow integration. + + Each request gets a fresh generator instance, so retry state is naturally + scoped per-operation — there is no shared retry counter to reset or leak + across concurrent requests. + """ + await self._apply_token(request) + response = yield request + + if response.status_code == 401: + await response.aread() + await self.on_unauthorized(UnauthorizedContext(response=response, request=request)) + + await self._apply_token(request) + response = yield request + + if response.status_code == 401: + raise UnauthorizedError(f"Server at {request.url} returned 401 Unauthorized after re-authentication") + + async def _apply_token(self, request: httpx.Request) -> None: + token = await self.token() + if token: + request.headers["Authorization"] = f"Bearer {token}" + else: + request.headers.pop("Authorization", None) diff --git a/src/mcp/client/auth/exceptions.py b/src/mcp/client/auth/exceptions.py index 5ce8777b8..c281212f8 100644 --- a/src/mcp/client/auth/exceptions.py +++ b/src/mcp/client/auth/exceptions.py @@ -1,3 +1,12 @@ +class UnauthorizedError(Exception): + """Raised when the server responds with 401 and the auth provider cannot recover. + + Raised by `BearerAuth` when no `on_unauthorized` handler is configured, or when + the single retry after `on_unauthorized` also receives 401. Callers can catch + this to trigger an interactive re-authentication flow or surface a login prompt. + """ + + class OAuthFlowError(Exception): """Base exception for OAuth flow errors.""" diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 34d6a360f..51b53c2d3 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -6,6 +6,8 @@ from dataclasses import KW_ONLY, dataclass, field from typing import Any +import httpx + from mcp.client._memory import InMemoryTransport from mcp.client._transport import Transport from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT @@ -95,6 +97,12 @@ async def main(): elicitation_callback: ElicitationFnT | None = None """Callback for handling elicitation requests.""" + auth: httpx.Auth | None = None + """Optional HTTP authentication provider (e.g., `BearerAuth` or `OAuthClientProvider`). + + Only used when `server` is a URL string. Ignored for in-memory and custom transports. + """ + _session: ClientSession | None = field(init=False, default=None) _exit_stack: AsyncExitStack | None = field(init=False, default=None) _transport: Transport = field(init=False) @@ -103,7 +111,7 @@ def __post_init__(self) -> None: if isinstance(self.server, Server | MCPServer): self._transport = InMemoryTransport(self.server, raise_exceptions=self.raise_exceptions) elif isinstance(self.server, str): - self._transport = streamable_http_client(self.server) + self._transport = streamable_http_client(self.server, auth=self.auth) else: self._transport = self.server diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 3afb94b03..15a0e3a44 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -511,6 +511,7 @@ async def streamable_http_client( url: str, *, http_client: httpx.AsyncClient | None = None, + auth: httpx.Auth | None = None, terminate_on_close: bool = True, ) -> AsyncGenerator[TransportStreams, None]: """Client transport for StreamableHTTP. @@ -519,7 +520,12 @@ async def streamable_http_client( url: The MCP server endpoint URL. http_client: Optional pre-configured httpx.AsyncClient. If None, a default client with recommended MCP timeouts will be created. To configure headers, - authentication, or other HTTP settings, create an httpx.AsyncClient and pass it here. + timeouts, or other HTTP settings, create an httpx.AsyncClient and pass it here. + Mutually exclusive with `auth`. + auth: Optional httpx.Auth provider (e.g., `BearerAuth` or `OAuthClientProvider`). + Shortcut for creating a default http_client with this auth configured. + Mutually exclusive with `http_client` — to combine auth with custom HTTP + settings, pass `http_client=httpx.AsyncClient(auth=..., ...)` instead. terminate_on_close: If True, send a DELETE request to terminate the session when the context exits. Yields: @@ -527,16 +533,25 @@ async def streamable_http_client( - read_stream: Stream for reading messages from the server - write_stream: Stream for sending messages to the server + Raises: + ValueError: If both `http_client` and `auth` are provided. + Example: See examples/snippets/clients/ for usage patterns. """ + if http_client is not None and auth is not None: + raise ValueError( + "Pass either `http_client` or `auth`, not both. " + "To combine auth with custom HTTP settings, set auth on the httpx.AsyncClient." + ) + # Determine if we need to create and manage the client client_provided = http_client is not None client = http_client if client is None: # Create default client with recommended MCP timeouts - client = create_mcp_http_client() + client = create_mcp_http_client(auth=auth) transport = StreamableHTTPTransport(url) diff --git a/src/mcp/shared/_httpx_utils.py b/src/mcp/shared/_httpx_utils.py index 251469eaa..1a609e5dd 100644 --- a/src/mcp/shared/_httpx_utils.py +++ b/src/mcp/shared/_httpx_utils.py @@ -91,7 +91,7 @@ def create_mcp_http_client( kwargs["headers"] = headers # Handle authentication - if auth is not None: # pragma: no cover + if auth is not None: kwargs["auth"] = auth return httpx.AsyncClient(**kwargs) diff --git a/tests/client/auth/test_bearer.py b/tests/client/auth/test_bearer.py new file mode 100644 index 000000000..995943ee6 --- /dev/null +++ b/tests/client/auth/test_bearer.py @@ -0,0 +1,392 @@ +"""Tests for BearerAuth — the minimal two-method bearer-token provider.""" + +from __future__ import annotations + +import httpx +import pytest + +from mcp.client.auth import BearerAuth, UnauthorizedContext, UnauthorizedError +from mcp.client.streamable_http import streamable_http_client +from mcp.shared._httpx_utils import create_mcp_http_client + +pytestmark = pytest.mark.anyio + + +def make_request(url: str = "https://api.example.com/mcp") -> httpx.Request: + return httpx.Request("POST", url) + + +def make_response(status: int, *, request: httpx.Request, www_auth: str | None = None) -> httpx.Response: + headers = {"WWW-Authenticate": www_auth} if www_auth else {} + return httpx.Response(status, headers=headers, request=request) + + +# --- token() resolution ------------------------------------------------------ + + +async def test_static_string_token_sets_authorization_header(): + auth = BearerAuth("my-api-key") + request = make_request() + + flow = auth.async_auth_flow(request) + sent = await flow.__anext__() + + assert sent.headers["Authorization"] == "Bearer my-api-key" + + with pytest.raises(StopAsyncIteration): + await flow.asend(make_response(200, request=request)) + + +async def test_sync_callable_token_resolved_per_request(): + calls = 0 + + def get_token() -> str: + nonlocal calls + calls += 1 + return f"token-{calls}" + + auth = BearerAuth(get_token) + + for expected in ("token-1", "token-2"): + request = make_request() + flow = auth.async_auth_flow(request) + sent = await flow.__anext__() + assert sent.headers["Authorization"] == f"Bearer {expected}" + with pytest.raises(StopAsyncIteration): + await flow.asend(make_response(200, request=request)) + + assert calls == 2 + + +async def test_async_callable_token_awaited(): + async def get_token() -> str: + return "async-token" + + auth = BearerAuth(get_token) + request = make_request() + + flow = auth.async_auth_flow(request) + sent = await flow.__anext__() + + assert sent.headers["Authorization"] == "Bearer async-token" + + +async def test_none_token_omits_authorization_header(): + auth = BearerAuth(lambda: None) + request = make_request() + + flow = auth.async_auth_flow(request) + sent = await flow.__anext__() + + assert "Authorization" not in sent.headers + + +async def test_no_token_source_omits_authorization_header(): + auth = BearerAuth() + request = make_request() + + flow = auth.async_auth_flow(request) + sent = await flow.__anext__() + + assert "Authorization" not in sent.headers + + +# --- 401 handling: no on_unauthorized handler -------------------------------- + + +async def test_401_without_handler_raises_unauthorized_error(): + auth = BearerAuth("rejected-token") + request = make_request() + + flow = auth.async_auth_flow(request) + await flow.__anext__() + + with pytest.raises(UnauthorizedError, match="401 Unauthorized"): + await flow.asend(make_response(401, request=request)) + + +async def test_401_without_handler_includes_www_authenticate_in_error(): + auth = BearerAuth("rejected-token") + request = make_request() + + flow = auth.async_auth_flow(request) + await flow.__anext__() + + www_auth = 'Bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource"' + with pytest.raises(UnauthorizedError, match="WWW-Authenticate"): + await flow.asend(make_response(401, request=request, www_auth=www_auth)) + + +# --- 401 handling: with on_unauthorized handler ------------------------------ + + +async def test_401_with_handler_retries_once_with_fresh_token(): + current = "old-token" + token_calls = 0 + handler_calls = 0 + + def get_token() -> str: + nonlocal token_calls + token_calls += 1 + return current + + async def refresh(ctx: UnauthorizedContext) -> None: + nonlocal current, handler_calls + handler_calls += 1 + assert ctx.response.status_code == 401 + assert ctx.request.url == "https://api.example.com/mcp" + current = "new-token" + + auth = BearerAuth(get_token, on_unauthorized=refresh) + request = make_request() + + flow = auth.async_auth_flow(request) + + first = await flow.__anext__() + assert first.headers["Authorization"] == "Bearer old-token" + + retry = await flow.asend(make_response(401, request=request)) + assert retry.headers["Authorization"] == "Bearer new-token" + + with pytest.raises(StopAsyncIteration): + await flow.asend(make_response(200, request=request)) + + assert token_calls == 2 + assert handler_calls == 1 + + +async def test_401_on_retry_raises_unauthorized_error(): + async def noop(ctx: UnauthorizedContext) -> None: + pass + + auth = BearerAuth("still-bad", on_unauthorized=noop) + request = make_request() + + flow = auth.async_auth_flow(request) + await flow.__anext__() + await flow.asend(make_response(401, request=request)) + + with pytest.raises(UnauthorizedError, match="after re-authentication"): + await flow.asend(make_response(401, request=request)) + + +async def test_handler_exception_propagates_without_retry(): + token_calls = 0 + + def get_token() -> str: + nonlocal token_calls + token_calls += 1 + return "token" + + async def signal_and_abort(ctx: UnauthorizedContext) -> None: + raise RuntimeError("user action required") + + auth = BearerAuth(get_token, on_unauthorized=signal_and_abort) + request = make_request() + + flow = auth.async_auth_flow(request) + await flow.__anext__() + + with pytest.raises(RuntimeError, match="user action required"): + await flow.asend(make_response(401, request=request)) + + assert token_calls == 1 # no retry attempted + + +async def test_retry_state_is_per_operation_not_shared(): + """Each request gets a fresh generator, so a failed retry on one request + doesn't prevent retry on the next. This is the httpx.Auth generator pattern's + natural per-operation isolation — no instance state to reset or leak.""" + attempts: list[str] = [] + + async def track(ctx: UnauthorizedContext) -> None: + attempts.append("refresh") + + auth = BearerAuth("token", on_unauthorized=track) + + # First request: 401 → retry → 401 → UnauthorizedError + request1 = make_request() + flow1 = auth.async_auth_flow(request1) + await flow1.__anext__() + await flow1.asend(make_response(401, request=request1)) + with pytest.raises(UnauthorizedError): + await flow1.asend(make_response(401, request=request1)) + + # Second request: fresh generator, retry allowed again + request2 = make_request() + flow2 = auth.async_auth_flow(request2) + await flow2.__anext__() + retry = await flow2.asend(make_response(401, request=request2)) + assert retry.headers["Authorization"] == "Bearer token" + with pytest.raises(StopAsyncIteration): + await flow2.asend(make_response(200, request=request2)) + + assert attempts == ["refresh", "refresh"] + + +async def test_retry_clears_stale_header_when_token_becomes_none(): + """If token() returns None on retry, the stale Authorization header from the + first attempt must be cleared — not silently re-sent.""" + tokens = iter(["first", None]) + + async def refresh(ctx: UnauthorizedContext) -> None: + pass + + auth = BearerAuth(lambda: next(tokens), on_unauthorized=refresh) + request = make_request() + + flow = auth.async_auth_flow(request) + first = await flow.__anext__() + assert first.headers["Authorization"] == "Bearer first" + + retry = await flow.asend(make_response(401, request=request)) + assert "Authorization" not in retry.headers + + +async def test_handler_can_read_response_body(): + """Response body is read before on_unauthorized, so handlers can inspect it + even when the transport uses streaming (httpx stream() mode).""" + captured: list[str] = [] + + async def inspect_body(ctx: UnauthorizedContext) -> None: + captured.append(ctx.response.text) + + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(401, json={"error": "invalid_token"}) + + auth = BearerAuth("bad", on_unauthorized=inspect_body) + async with httpx.AsyncClient(transport=httpx.MockTransport(handler), auth=auth) as client: + with pytest.raises(UnauthorizedError): + async with client.stream("POST", "https://api.example.com/mcp"): + pass # pragma: no cover — auth flow raises before stream body opens + + assert captured == ['{"error":"invalid_token"}'] + + +async def test_handler_receives_www_authenticate_header(): + captured: list[str] = [] + + async def inspect(ctx: UnauthorizedContext) -> None: + captured.append(ctx.response.headers.get("WWW-Authenticate", "")) + + auth = BearerAuth("token", on_unauthorized=inspect) + request = make_request() + + flow = auth.async_auth_flow(request) + await flow.__anext__() + + www_auth = 'Bearer scope="read write", resource_metadata="https://example.com/prm"' + await flow.asend(make_response(401, request=request, www_auth=www_auth)) + + assert captured == [www_auth] + + +# --- subclassing ------------------------------------------------------------- + + +async def test_subclass_override_token_and_on_unauthorized(): + class RefreshingAuth(BearerAuth): + def __init__(self) -> None: + super().__init__() + self.current = "initial" + self.refreshed = False + + async def token(self) -> str | None: + return self.current + + async def on_unauthorized(self, context: UnauthorizedContext) -> None: + self.current = "refreshed" + self.refreshed = True + + auth = RefreshingAuth() + request = make_request() + + flow = auth.async_auth_flow(request) + first = await flow.__anext__() + assert first.headers["Authorization"] == "Bearer initial" + + retry = await flow.asend(make_response(401, request=request)) + assert retry.headers["Authorization"] == "Bearer refreshed" + assert auth.refreshed is True + + +# --- httpx integration (wire-level) ------------------------------------------ + + +async def test_e2e_with_mock_transport_sets_header(): + captured: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(200, json={"ok": True}) + + auth = BearerAuth("wire-token") + async with httpx.AsyncClient(transport=httpx.MockTransport(handler), auth=auth) as client: + response = await client.post("https://api.example.com/mcp", json={}) + + assert response.status_code == 200 + assert captured[0].headers["Authorization"] == "Bearer wire-token" + + +async def test_e2e_with_mock_transport_retries_on_401(): + seen_tokens: list[str | None] = [] + + def handler(request: httpx.Request) -> httpx.Response: + token = request.headers.get("Authorization") + seen_tokens.append(token) + if token == "Bearer old": + return httpx.Response(401, headers={"WWW-Authenticate": "Bearer"}) + return httpx.Response(200, json={"ok": True}) + + current = "old" + + async def refresh(ctx: UnauthorizedContext) -> None: + nonlocal current + current = "new" + + auth = BearerAuth(lambda: current, on_unauthorized=refresh) + async with httpx.AsyncClient(transport=httpx.MockTransport(handler), auth=auth) as client: + response = await client.post("https://api.example.com/mcp", json={}) + + assert response.status_code == 200 + assert seen_tokens == ["Bearer old", "Bearer new"] + + +async def test_e2e_unauthorized_error_propagates(): + def always_401(request: httpx.Request) -> httpx.Response: + return httpx.Response(401) + + auth = BearerAuth("rejected") + async with httpx.AsyncClient(transport=httpx.MockTransport(always_401), auth=auth) as client: + with pytest.raises(UnauthorizedError): + await client.post("https://api.example.com/mcp", json={}) + + +# --- sync client guard ------------------------------------------------------- + + +def test_sync_client_raises_clear_error(): + auth = BearerAuth("token") + with pytest.raises(RuntimeError, match="async-only"): + with httpx.Client(auth=auth) as client: + client.get("https://api.example.com/mcp") + + +# --- streamable_http_client integration -------------------------------------- + + +async def test_streamable_http_client_rejects_both_auth_and_http_client(): + auth = BearerAuth("token") + http_client = httpx.AsyncClient() + + with pytest.raises(ValueError, match="either `http_client` or `auth`"): + async with streamable_http_client("https://example.com/mcp", auth=auth, http_client=http_client): + pass # pragma: no cover + + await http_client.aclose() + + +async def test_create_mcp_http_client_passes_auth(): + auth = BearerAuth("factory-token") + async with create_mcp_http_client(auth=auth) as client: + assert client.auth is auth diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 18368e6bb..559ab9f76 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -5,6 +5,7 @@ from unittest.mock import patch import anyio +import httpx import pytest from inline_snapshot import snapshot @@ -307,7 +308,18 @@ async def test_complete_with_prompt_reference(simple_server: Server): def test_client_with_url_initializes_streamable_http_transport(): with patch("mcp.client.client.streamable_http_client") as mock: _ = Client("http://localhost:8000/mcp") - mock.assert_called_once_with("http://localhost:8000/mcp") + mock.assert_called_once_with("http://localhost:8000/mcp", auth=None) + + +def test_client_with_url_passes_auth_to_transport(): + class FakeAuth(httpx.Auth): + def auth_flow(self, request: httpx.Request): + yield request # pragma: no cover + + auth = FakeAuth() + with patch("mcp.client.client.streamable_http_client") as mock: + _ = Client("http://localhost:8000/mcp", auth=auth) + mock.assert_called_once_with("http://localhost:8000/mcp", auth=auth) async def test_client_uses_transport_directly(app: MCPServer): From 07c95314351a65caa95405abc47d59346b76120d Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 25 Mar 2026 14:07:09 +0000 Subject: [PATCH 2/2] test: parse JSON instead of asserting exact whitespace in response body test Older httpx versions (lowest-direct CI matrix) serialize JSON with a space after the colon; newer versions use compact separators. Parse the captured body instead of asserting the exact string. --- tests/client/auth/test_bearer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/client/auth/test_bearer.py b/tests/client/auth/test_bearer.py index 995943ee6..e8d180e8b 100644 --- a/tests/client/auth/test_bearer.py +++ b/tests/client/auth/test_bearer.py @@ -2,6 +2,8 @@ from __future__ import annotations +import json + import httpx import pytest @@ -260,7 +262,7 @@ def handler(request: httpx.Request) -> httpx.Response: async with client.stream("POST", "https://api.example.com/mcp"): pass # pragma: no cover — auth flow raises before stream body opens - assert captured == ['{"error":"invalid_token"}'] + assert [json.loads(body) for body in captured] == [{"error": "invalid_token"}] async def test_handler_receives_www_authenticate_header():