From 2a07260b24b2b47387909d006ac24f07a497751e Mon Sep 17 00:00:00 2001 From: Didier Durand Date: Wed, 14 May 2025 14:31:08 +0200 Subject: [PATCH] Adding authorization at http layer on streamable_http via httpx auth param --- src/mcp/client/streamable_http.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 3324dab5a7..e738cf55cb 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -83,6 +83,7 @@ def __init__( headers: dict[str, Any] | None = None, timeout: timedelta = timedelta(seconds=30), sse_read_timeout: timedelta = timedelta(seconds=60 * 5), + auth= None, ) -> None: """Initialize the StreamableHTTP transport. @@ -93,6 +94,7 @@ def __init__( sse_read_timeout: Timeout for SSE read operations. """ self.url = url + self.auth= auth, self.headers = headers or {} self.timeout = timeout self.sse_read_timeout = sse_read_timeout @@ -245,11 +247,16 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: message = ctx.session_message.message is_initialization = self._is_initialization_request(message) + kwargs: dict[str, Any]={ + "json": message.model_dump(by_alias=True, mode="json", exclude_none=True), + "headers": headers, + } + if self.auth is not None: + kwargs["auth"] = self.auth async with ctx.client.stream( "POST", self.url, - json=message.model_dump(by_alias=True, mode="json", exclude_none=True), - headers=headers, + **kwargs ) as response: if response.status_code == 202: logger.debug("Received 202 Accepted") @@ -406,7 +413,12 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None: try: headers = self._update_headers_with_session(self.request_headers) - response = await client.delete(self.url, headers=headers) + kwargs: dict[str, Any] = { + "headers": headers + } + if self.auth is not None: + kwargs["auth"] = self.auth + response = await client.delete(self.url, **kwargs) if response.status_code == 405: logger.debug("Server does not allow session termination") @@ -427,6 +439,7 @@ async def streamablehttp_client( timeout: timedelta = timedelta(seconds=30), sse_read_timeout: timedelta = timedelta(seconds=60 * 5), terminate_on_close: bool = True, + auth=None, ) -> AsyncGenerator[ tuple[ MemoryObjectReceiveStream[SessionMessage | Exception], @@ -448,6 +461,8 @@ async def streamablehttp_client( - get_session_id_callback: Function to retrieve the current session ID """ transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout) + if auth is not None: + transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout,auth) read_stream_writer, read_stream = anyio.create_memory_object_stream[ SessionMessage | Exception