From a4ffc4bb11eb97dd20f19b2015a87f72a6b9a373 Mon Sep 17 00:00:00 2001 From: Raashish Aggarwal <94279692+raashish1601@users.noreply.github.com> Date: Sat, 28 Mar 2026 04:05:47 +0530 Subject: [PATCH 1/4] Fix stateful StreamableHTTP auth context rebinding --- src/mcp/server/lowlevel/server.py | 73 +++++++----- .../test_auth_context_streamable_http.py | 104 ++++++++++++++++++ 2 files changed, 151 insertions(+), 26 deletions(-) create mode 100644 tests/server/auth/middleware/test_auth_context_streamable_http.py diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index c28842272..0253dfd12 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -38,8 +38,8 @@ async def main(): import logging import warnings -from collections.abc import AsyncIterator, Awaitable, Callable -from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager +from collections.abc import AsyncIterator, Awaitable, Callable, Iterator +from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager from importlib.metadata import version as importlib_version from typing import Any, Generic @@ -52,8 +52,8 @@ async def main(): from typing_extensions import TypeVar from mcp import types -from mcp.server.auth.middleware.auth_context import AuthContextMiddleware -from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware +from mcp.server.auth.middleware.auth_context import AuthContextMiddleware, auth_context_var +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, BearerAuthBackend, RequireAuthMiddleware from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes from mcp.server.auth.settings import AuthSettings @@ -74,6 +74,23 @@ async def main(): LifespanResultT = TypeVar("LifespanResultT", default=Any) +@contextmanager +def _bind_request_auth_context(request_context: Any) -> Iterator[None]: + """Rebind auth context from the current transport request while handling a message.""" + authenticated_user = None + scope = getattr(request_context, "scope", None) + if isinstance(scope, dict): + scope_user = scope.get("user") + if isinstance(scope_user, AuthenticatedUser): + authenticated_user = scope_user + + token = auth_context_var.set(authenticated_user) + try: + yield + finally: + auth_context_var.reset(token) + + class NotificationOptions: def __init__(self, prompts_changed: bool = False, resources_changed: bool = False, tools_changed: bool = False): self.prompts_changed = prompts_changed @@ -452,28 +469,32 @@ async def _handle_request( close_sse_stream_cb = message.message_metadata.close_sse_stream close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream - client_capabilities = session.client_params.capabilities if session.client_params else None - task_support = self._experimental_handlers.task_support if self._experimental_handlers else None - # Get task metadata from request params if present - task_metadata = None - if hasattr(req, "params") and req.params is not None: - task_metadata = getattr(req.params, "task", None) - ctx = ServerRequestContext( - request_id=message.request_id, - meta=message.request_meta, - session=session, - lifespan_context=lifespan_context, - experimental=Experimental( - task_metadata=task_metadata, - _client_capabilities=client_capabilities, - _session=session, - _task_support=task_support, - ), - request=request_data, - close_sse_stream=close_sse_stream_cb, - close_standalone_sse_stream=close_standalone_sse_stream_cb, - ) - response = await handler(ctx, req.params) + # Stateful HTTP sessions process later requests on tasks that were + # created during session setup, so ContextVar snapshots can lag + # behind the current request unless we rebind them here. + with _bind_request_auth_context(request_data): + client_capabilities = session.client_params.capabilities if session.client_params else None + task_support = self._experimental_handlers.task_support if self._experimental_handlers else None + # Get task metadata from request params if present + task_metadata = None + if hasattr(req, "params") and req.params is not None: + task_metadata = getattr(req.params, "task", None) + ctx = ServerRequestContext( + request_id=message.request_id, + meta=message.request_meta, + session=session, + lifespan_context=lifespan_context, + experimental=Experimental( + task_metadata=task_metadata, + _client_capabilities=client_capabilities, + _session=session, + _task_support=task_support, + ), + request=request_data, + close_sse_stream=close_sse_stream_cb, + close_standalone_sse_stream=close_standalone_sse_stream_cb, + ) + response = await handler(ctx, req.params) except MCPError as err: response = err.error except anyio.get_cancelled_exc_class(): diff --git a/tests/server/auth/middleware/test_auth_context_streamable_http.py b/tests/server/auth/middleware/test_auth_context_streamable_http.py new file mode 100644 index 000000000..afc7a5d12 --- /dev/null +++ b/tests/server/auth/middleware/test_auth_context_streamable_http.py @@ -0,0 +1,104 @@ +"""Regression tests for auth context in StreamableHTTP servers.""" + +import time +from collections.abc import Generator + +import httpx +import pytest +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.routing import Mount + +from mcp.client.session import ClientSession +from mcp.client.streamable_http import streamable_http_client +from mcp.server import Server, ServerRequestContext +from mcp.server.auth.middleware.auth_context import AuthContextMiddleware, get_access_token +from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend +from mcp.server.auth.provider import AccessToken +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + ListToolsResult, + PaginatedRequestParams, + TextContent, + Tool, +) +from tests.test_helpers import run_uvicorn_in_thread + + +class _EchoTokenVerifier: + async def verify_token(self, token: str) -> AccessToken | None: + return AccessToken(token=token, client_id=token, scopes=[], expires_at=int(time.time()) + 3600) + + +async def _handle_whoami(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + access = get_access_token() + text = access.token if access else "" + return CallToolResult(content=[TextContent(type="text", text=text)]) + + +async def _handle_list_tools( + ctx: ServerRequestContext, + params: PaginatedRequestParams | None, +) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="whoami", input_schema={"type": "object", "properties": {}})]) + + +class _MutableBearerAuth(httpx.Auth): + def __init__(self, token: str) -> None: + self.token = token + + def auth_flow(self, request: httpx.Request): + request.headers["Authorization"] = f"Bearer {self.token}" + yield request + + +@pytest.fixture +def stateful_auth_server() -> Generator[str, None, None]: + server = Server( + "auth-test-server", + on_call_tool=_handle_whoami, + on_list_tools=_handle_list_tools, + ) + session_manager = StreamableHTTPSessionManager(app=server, stateless=False) + app = Starlette( + routes=[Mount("/mcp", app=session_manager.handle_request)], + middleware=[ + Middleware(AuthenticationMiddleware, backend=BearerAuthBackend(_EchoTokenVerifier())), + Middleware(AuthContextMiddleware), + ], + lifespan=lambda app: session_manager.run(), + ) + + with run_uvicorn_in_thread(app, host="127.0.0.1", log_level="error") as base_url: + yield f"{base_url}/mcp" + + +@pytest.mark.anyio +async def test_get_access_token_reflects_current_request_in_stateful_session(stateful_auth_server: str) -> None: + auth = _MutableBearerAuth("token-A") + async with httpx.AsyncClient( + auth=auth, + timeout=httpx.Timeout(30, read=30), + follow_redirects=True, + ) as http_client: + async with streamable_http_client(stateful_auth_server, http_client=http_client) as ( + read_stream, + write_stream, + ): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + first_response = await session.call_tool("whoami", {}) + assert len(first_response.content) == 1 + assert isinstance(first_response.content[0], TextContent) + assert first_response.content[0].text == "token-A" + + auth.token = "token-B" + + second_response = await session.call_tool("whoami", {}) + assert len(second_response.content) == 1 + assert isinstance(second_response.content[0], TextContent) + assert second_response.content[0].text == "token-B" From 08ae6dfae2035de535bac415bc62dfe5facec932 Mon Sep 17 00:00:00 2001 From: Raashish Aggarwal <94279692+raashish1601@users.noreply.github.com> Date: Sat, 28 Mar 2026 04:58:58 +0530 Subject: [PATCH 2/4] fix: accept mapping-backed auth scopes --- src/mcp/server/lowlevel/server.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 0253dfd12..61e6afd47 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -38,10 +38,10 @@ async def main(): import logging import warnings -from collections.abc import AsyncIterator, Awaitable, Callable, Iterator +from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Mapping from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager from importlib.metadata import version as importlib_version -from typing import Any, Generic +from typing import Any, Generic, cast import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -79,8 +79,8 @@ def _bind_request_auth_context(request_context: Any) -> Iterator[None]: """Rebind auth context from the current transport request while handling a message.""" authenticated_user = None scope = getattr(request_context, "scope", None) - if isinstance(scope, dict): - scope_user = scope.get("user") + if isinstance(scope, Mapping): + scope_user = cast(Mapping[str, object], scope).get("user") if isinstance(scope_user, AuthenticatedUser): authenticated_user = scope_user From a040142b4f87c52666a15de47c5ae527ba0f2aa7 Mon Sep 17 00:00:00 2001 From: Raashish Aggarwal <94279692+raashish1601@users.noreply.github.com> Date: Sat, 28 Mar 2026 05:31:50 +0530 Subject: [PATCH 3/4] test: isolate streamable auth regression server lifecycle --- .../test_auth_context_streamable_http.py | 49 ++++++++++++++++--- 1 file changed, 43 insertions(+), 6 deletions(-) diff --git a/tests/server/auth/middleware/test_auth_context_streamable_http.py b/tests/server/auth/middleware/test_auth_context_streamable_http.py index afc7a5d12..4001b5e0c 100644 --- a/tests/server/auth/middleware/test_auth_context_streamable_http.py +++ b/tests/server/auth/middleware/test_auth_context_streamable_http.py @@ -1,10 +1,13 @@ """Regression tests for auth context in StreamableHTTP servers.""" +import multiprocessing +import socket import time from collections.abc import Generator import httpx import pytest +import uvicorn from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware @@ -25,7 +28,7 @@ TextContent, Tool, ) -from tests.test_helpers import run_uvicorn_in_thread +from tests.test_helpers import wait_for_server class _EchoTokenVerifier: @@ -55,15 +58,14 @@ def auth_flow(self, request: httpx.Request): yield request -@pytest.fixture -def stateful_auth_server() -> Generator[str, None, None]: +def _create_stateful_auth_app() -> Starlette: server = Server( "auth-test-server", on_call_tool=_handle_whoami, on_list_tools=_handle_list_tools, ) session_manager = StreamableHTTPSessionManager(app=server, stateless=False) - app = Starlette( + return Starlette( routes=[Mount("/mcp", app=session_manager.handle_request)], middleware=[ Middleware(AuthenticationMiddleware, backend=BearerAuthBackend(_EchoTokenVerifier())), @@ -72,8 +74,43 @@ def stateful_auth_server() -> Generator[str, None, None]: lifespan=lambda app: session_manager.run(), ) - with run_uvicorn_in_thread(app, host="127.0.0.1", log_level="error") as base_url: - yield f"{base_url}/mcp" + +def run_stateful_auth_server(port: int) -> None: # pragma: no cover + config = uvicorn.Config( + app=_create_stateful_auth_app(), + host="127.0.0.1", + port=port, + log_level="error", + access_log=False, + ) + uvicorn.Server(config).run() + + +@pytest.fixture +def stateful_auth_server_port() -> int: + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def stateful_auth_server(stateful_auth_server_port: int) -> Generator[str, None, None]: + proc = multiprocessing.Process( + target=run_stateful_auth_server, + kwargs={"port": stateful_auth_server_port}, + daemon=True, + ) + proc.start() + wait_for_server(stateful_auth_server_port) + + try: + yield f"http://127.0.0.1:{stateful_auth_server_port}/mcp" + finally: + proc.terminate() + proc.join(timeout=2) + if proc.is_alive(): # pragma: no cover + proc.kill() + proc.join(timeout=1) @pytest.mark.anyio From 08c6c0dfbfe120ffc1ab4f6d8c8e6994c3932f54 Mon Sep 17 00:00:00 2001 From: Raashish Aggarwal <94279692+raashish1601@users.noreply.github.com> Date: Sat, 28 Mar 2026 05:40:45 +0530 Subject: [PATCH 4/4] fix: rebind auth context for notifications --- src/mcp/client/session.py | 3 +- src/mcp/server/lowlevel/server.py | 48 ++++++++---- src/mcp/server/session.py | 17 ++++- src/mcp/shared/session.py | 15 +++- .../test_auth_context_streamable_http.py | 73 +++++++++++++++++-- 5 files changed, 126 insertions(+), 30 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 7c964a334..608ec284d 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -11,7 +11,7 @@ from mcp.client.experimental import ExperimentalClientFeatures from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers from mcp.shared._context import RequestContext -from mcp.shared.message import SessionMessage +from mcp.shared.message import MessageMetadata, SessionMessage from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.types._types import RequestParamsMeta @@ -461,6 +461,7 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques async def _handle_incoming( self, req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + message_metadata: MessageMetadata = None, ) -> None: """Handle incoming messages by forwarding to the message handler.""" await self._message_handler(req) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 61e6afd47..183f345f0 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -66,8 +66,8 @@ async def main(): from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.exceptions import MCPError -from mcp.shared.message import ServerMessageMetadata, SessionMessage -from mcp.shared.session import RequestResponder +from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage +from mcp.shared.session import NotificationWithMetadata, RequestResponder logger = logging.getLogger(__name__) @@ -424,7 +424,9 @@ async def run( async def _handle_message( self, - message: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception, + message: RequestResponder[types.ClientRequest, types.ServerResult] + | NotificationWithMetadata[types.ClientNotification] + | Exception, session: ServerSession, lifespan_context: LifespanResultT, raise_exceptions: bool = False, @@ -436,6 +438,13 @@ async def _handle_message( await self._handle_request( message, responder.request, session, lifespan_context, raise_exceptions ) + case NotificationWithMetadata() as notification: + await self._handle_notification( + notification.notification, + session, + lifespan_context, + notification.message_metadata, + ) case Exception(): logger.error(f"Received exception from stream: {message}") if raise_exceptions: @@ -532,24 +541,31 @@ async def _handle_notification( notify: types.ClientNotification, session: ServerSession, lifespan_context: LifespanResultT, + message_metadata: MessageMetadata = None, ) -> None: if handler := self._notification_handlers.get(notify.method): logger.debug("Dispatching notification of type %s", type(notify).__name__) try: - client_capabilities = session.client_params.capabilities if session.client_params else None - task_support = self._experimental_handlers.task_support if self._experimental_handlers else None - ctx = ServerRequestContext( - session=session, - lifespan_context=lifespan_context, - experimental=Experimental( - task_metadata=None, - _client_capabilities=client_capabilities, - _session=session, - _task_support=task_support, - ), - ) - await handler(ctx, notify.params) + request_data = None + if isinstance(message_metadata, ServerMessageMetadata): + request_data = message_metadata.request_context + + with _bind_request_auth_context(request_data): + client_capabilities = session.client_params.capabilities if session.client_params else None + task_support = self._experimental_handlers.task_support if self._experimental_handlers else None + ctx = ServerRequestContext( + session=session, + lifespan_context=lifespan_context, + experimental=Experimental( + task_metadata=None, + _client_capabilities=client_capabilities, + _session=session, + _task_support=task_support, + ), + request=request_data, + ) + await handler(ctx, notify.params) except Exception: # pragma: no cover logger.exception("Uncaught exception in notification handler") diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index ce467e6c9..0ffda24f6 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -43,9 +43,10 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult: from mcp.shared.exceptions import StatelessModeNotSupported from mcp.shared.experimental.tasks.capabilities import check_tasks_capability from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY -from mcp.shared.message import ServerMessageMetadata, SessionMessage +from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.session import ( BaseSession, + NotificationWithMetadata, RequestResponder, ) from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -60,7 +61,9 @@ class InitializationState(Enum): ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession") ServerRequestResponder = ( - RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception + RequestResponder[types.ClientRequest, types.ServerResult] + | NotificationWithMetadata[types.ClientNotification] + | Exception ) @@ -683,7 +686,15 @@ async def send_message(self, message: SessionMessage) -> None: """ await self._write_stream.send(message) - async def _handle_incoming(self, req: ServerRequestResponder) -> None: + async def _handle_incoming( + self, + req: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception, + message_metadata: MessageMetadata = None, + ) -> None: + if isinstance(req, types.ClientNotification): + await self._incoming_message_stream_writer.send(NotificationWithMetadata(req, message_metadata)) + return + await self._incoming_message_stream_writer.send(req) @property diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 6fc59923f..b024f6221 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -3,6 +3,7 @@ import logging from collections.abc import Callable from contextlib import AsyncExitStack +from dataclasses import dataclass from types import TracebackType from typing import Any, Generic, Protocol, TypeVar @@ -53,6 +54,14 @@ async def __call__( ) -> None: ... # pragma: no branch +@dataclass +class NotificationWithMetadata(Generic[ReceiveNotificationT]): + """A validated notification paired with its transport metadata.""" + + notification: ReceiveNotificationT + message_metadata: MessageMetadata = None + + class RequestResponder(Generic[ReceiveRequestT, SendResultT]): """Handles responding to MCP requests and manages request lifecycle. @@ -396,7 +405,7 @@ async def _receive_loop(self) -> None: except Exception: logging.exception("Progress callback raised an exception") await self._received_notification(notification) - await self._handle_incoming(notification) + await self._handle_incoming(notification, message.metadata) except Exception: # For other validation errors, log and continue logging.warning( # pragma: no cover @@ -515,6 +524,8 @@ async def send_progress_notification( """Sends a progress notification for a request that is currently being processed.""" async def _handle_incoming( - self, req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception + self, + req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception, + message_metadata: MessageMetadata = None, ) -> None: """A generic handler for incoming messages. Overridden by subclasses.""" diff --git a/tests/server/auth/middleware/test_auth_context_streamable_http.py b/tests/server/auth/middleware/test_auth_context_streamable_http.py index 4001b5e0c..28a91da13 100644 --- a/tests/server/auth/middleware/test_auth_context_streamable_http.py +++ b/tests/server/auth/middleware/test_auth_context_streamable_http.py @@ -1,10 +1,15 @@ """Regression tests for auth context in StreamableHTTP servers.""" +from __future__ import annotations + import multiprocessing +import queue import socket import time from collections.abc import Generator +from multiprocessing.queues import Queue +import anyio import httpx import pytest import uvicorn @@ -58,11 +63,19 @@ def auth_flow(self, request: httpx.Request): yield request -def _create_stateful_auth_app() -> Starlette: +def _create_stateful_auth_app(progress_tokens: Queue[str] | None = None) -> Starlette: + async def _handle_progress(ctx: ServerRequestContext, params: object) -> None: + if progress_tokens is None: + return + + access = get_access_token() + progress_tokens.put(access.token if access else "") + server = Server( "auth-test-server", on_call_tool=_handle_whoami, on_list_tools=_handle_list_tools, + on_progress=_handle_progress, ) session_manager = StreamableHTTPSessionManager(app=server, stateless=False) return Starlette( @@ -75,9 +88,12 @@ def _create_stateful_auth_app() -> Starlette: ) -def run_stateful_auth_server(port: int) -> None: # pragma: no cover +def run_stateful_auth_server( + port: int, + progress_tokens: Queue[str] | None = None, +) -> None: # pragma: no cover config = uvicorn.Config( - app=_create_stateful_auth_app(), + app=_create_stateful_auth_app(progress_tokens), host="127.0.0.1", port=port, log_level="error", @@ -94,34 +110,45 @@ def stateful_auth_server_port() -> int: @pytest.fixture -def stateful_auth_server(stateful_auth_server_port: int) -> Generator[str, None, None]: +def stateful_auth_server( + stateful_auth_server_port: int, +) -> Generator[tuple[str, Queue[str]], None, None]: + progress_tokens: Queue[str] = multiprocessing.Queue() proc = multiprocessing.Process( target=run_stateful_auth_server, - kwargs={"port": stateful_auth_server_port}, + kwargs={ + "port": stateful_auth_server_port, + "progress_tokens": progress_tokens, + }, daemon=True, ) proc.start() wait_for_server(stateful_auth_server_port) try: - yield f"http://127.0.0.1:{stateful_auth_server_port}/mcp" + yield f"http://127.0.0.1:{stateful_auth_server_port}/mcp", progress_tokens finally: proc.terminate() proc.join(timeout=2) if proc.is_alive(): # pragma: no cover proc.kill() proc.join(timeout=1) + progress_tokens.close() + progress_tokens.join_thread() @pytest.mark.anyio -async def test_get_access_token_reflects_current_request_in_stateful_session(stateful_auth_server: str) -> None: +async def test_get_access_token_reflects_current_request_in_stateful_session( + stateful_auth_server: tuple[str, Queue[str]], +) -> None: + server_url, _ = stateful_auth_server auth = _MutableBearerAuth("token-A") async with httpx.AsyncClient( auth=auth, timeout=httpx.Timeout(30, read=30), follow_redirects=True, ) as http_client: - async with streamable_http_client(stateful_auth_server, http_client=http_client) as ( + async with streamable_http_client(server_url, http_client=http_client) as ( read_stream, write_stream, ): @@ -139,3 +166,33 @@ async def test_get_access_token_reflects_current_request_in_stateful_session(sta assert len(second_response.content) == 1 assert isinstance(second_response.content[0], TextContent) assert second_response.content[0].text == "token-B" + + +@pytest.mark.anyio +async def test_get_access_token_reflects_current_notification_in_stateful_session( + stateful_auth_server: tuple[str, Queue[str]], +) -> None: + server_url, progress_tokens = stateful_auth_server + auth = _MutableBearerAuth("token-A") + async with httpx.AsyncClient( + auth=auth, + timeout=httpx.Timeout(30, read=30), + follow_redirects=True, + ) as http_client: + async with streamable_http_client(server_url, http_client=http_client) as ( + read_stream, + write_stream, + ): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + auth.token = "token-B" + await session.send_progress_notification(progress_token="progress-1", progress=1) + + with anyio.fail_after(5): + while True: + try: + assert progress_tokens.get_nowait() == "token-B" + break + except queue.Empty: + await anyio.sleep(0.01)