Skip to content

Commit a67a98a

Browse files
committed
- Remove dead is_token_valid() guard from _perform_authorization; parent only calls this when new tokens are needed, so it runs unconditionally now
- Replace JWTBearerGrantRequestData TypedDict + double-cast with plain dict[str, str] matching prepare_token_auth signature - Stop pre-adding client_secret to POST body in JWT bearer grant; prepare_token_auth already handles placement per auth method - Add override_audience_with_issuer constructor param (default True) so callers in federated/multi-tenant setups can opt out of audience override - Stop mutating injected TokenExchangeParameters; use internal _subject_token field for refresh_with_new_id_token - Call validate_token_exchange_params in __init__ for fail-fast validation - Raise ValueError for idp_client_secret without idp_client_id - Use str.removesuffix("/mcp") in conformance client instead of str.replace
1 parent ab07f34 commit a67a98a

File tree

3 files changed

+194
-165
lines changed

3 files changed

+194
-165
lines changed

.github/actions/conformance/client.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -352,9 +352,8 @@ async def run_cross_app_access_complete_flow(server_url: str) -> None:
352352
if not idp_issuer:
353353
raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'idp_issuer'")
354354

355-
# Extract base URL and construct auth issuer and resource ID
356-
# The conformance test sets up auth server at a known location
357-
base_url = server_url.replace("/mcp", "")
355+
# Extract base URL by stripping trailing /mcp path (Python 3.9+)
356+
base_url = server_url.removesuffix("/mcp")
358357
auth_issuer = context.get("auth_issuer", base_url)
359358
resource_id = context.get("resource_id", server_url)
360359

src/mcp/client/auth/extensions/enterprise_managed_auth.py

Lines changed: 66 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import logging
88
import time
99
from json import JSONDecodeError
10-
from typing import cast
1110

1211
import httpx
1312
import jwt
@@ -38,19 +37,6 @@ class TokenExchangeRequestData(TypedDict):
3837
client_secret: NotRequired[str]
3938

4039

41-
class JWTBearerGrantRequestData(TypedDict):
42-
"""Type definition for RFC 7523 JWT Bearer Grant request data.
43-
44-
Required fields are those mandated by RFC 7523.
45-
Optional fields (NotRequired) are for client authentication.
46-
"""
47-
48-
grant_type: Required[str]
49-
assertion: Required[str]
50-
client_id: NotRequired[str]
51-
client_secret: NotRequired[str]
52-
53-
5440
class TokenExchangeParameters(BaseModel):
5541
"""Parameters for RFC 8693 Token Exchange request."""
5642

@@ -187,15 +173,16 @@ class EnterpriseAuthOAuthClientProvider(OAuthClientProvider):
187173
- RFC 7523: JWT Bearer Grant (ID-JAG → Access Token)
188174
189175
Concurrency & Thread Safety:
190-
- SAFE: Concurrent requests within a single asyncio event loop. Token operations
191-
are protected by the parent class's `OAuthContext.lock`.
192-
- UNSAFE: Sharing a provider instance across multiple OS threads. Each thread
193-
must instantiate its own provider and event loop.
194-
- Note: Ensure any shared `TokenStorage` implementation is async-safe.
176+
- SAFE: Concurrent requests within a single asyncio event loop. Token
177+
operations (including ``_id_jag`` / ``_id_jag_expiry``) are protected
178+
by the parent class's ``OAuthContext.lock`` via ``async_auth_flow``.
179+
- UNSAFE: Sharing a provider instance across multiple OS threads. Each
180+
thread must instantiate its own provider and event loop.
181+
- Note: Ensure any shared ``TokenStorage`` implementation is async-safe.
195182
"""
196183

197-
# Default ID-JAG expiry when IdP doesn't provide expires_in
198-
# 15 minutes is a conservative default for enterprise environments
184+
# Default ID-JAG expiry when IdP doesn't provide expires_in.
185+
# 15 minutes is a conservative default for enterprise environments.
199186
DEFAULT_ID_JAG_EXPIRY_SECONDS = 900
200187

201188
def __init__(
@@ -209,6 +196,7 @@ def __init__(
209196
idp_client_id: str | None = None,
210197
idp_client_secret: str | None = None,
211198
default_id_jag_expiry: int = DEFAULT_ID_JAG_EXPIRY_SECONDS,
199+
override_audience_with_issuer: bool = True,
212200
) -> None:
213201
"""Initialize Enterprise Auth OAuth Client.
214202
@@ -217,13 +205,21 @@ def __init__(
217205
client_metadata: OAuth client metadata
218206
storage: Token storage implementation
219207
idp_token_endpoint: Enterprise IdP token endpoint URL
220-
token_exchange_params: Token exchange parameters
208+
token_exchange_params: Token exchange parameters (not mutated)
221209
timeout: Request timeout in seconds
222210
idp_client_id: Optional client ID registered with the IdP for token exchange
223-
idp_client_secret: Optional client secret registered with the IdP for token exchange
211+
idp_client_secret: Optional client secret registered with the IdP.
212+
Must be accompanied by ``idp_client_id``; providing a secret
213+
without an ID raises ``ValueError``.
224214
default_id_jag_expiry: Fallback ID-JAG expiry in seconds if the IdP
225-
omits `expires_in` (default: 900s/15m). Adjust to balance token
226-
freshness against IdP request load.
215+
omits ``expires_in`` (default: 900 s / 15 min).
216+
override_audience_with_issuer: If True (default), replaces the IdP
217+
audience with the discovered OAuth issuer URL. Set to False for
218+
federated identity setups where the audience must differ.
219+
220+
Raises:
221+
ValueError: If ``idp_client_secret`` is provided without ``idp_client_id``.
222+
OAuthFlowError: If ``token_exchange_params`` fail validation.
227223
"""
228224
super().__init__(
229225
server_url=server_url,
@@ -232,30 +228,31 @@ def __init__(
232228
timeout=timeout,
233229
)
234230
self.idp_token_endpoint = idp_token_endpoint
231+
# Keep original params immutable; track mutable subject_token separately
235232
self.token_exchange_params = token_exchange_params
233+
self._subject_token = token_exchange_params.subject_token
236234
self.idp_client_id = idp_client_id
237235
self.idp_client_secret = idp_client_secret
238236
self.default_id_jag_expiry = default_id_jag_expiry
237+
self.override_audience_with_issuer = override_audience_with_issuer
239238
self._id_jag: str | None = None
240239
self._id_jag_expiry: float | None = None
241240

242-
# Validate client authentication configuration
241+
# Fail-fast: secret without ID is almost certainly a misconfiguration
243242
if idp_client_secret is not None and idp_client_id is None:
244-
logger.warning(
245-
"idp_client_secret provided without idp_client_id. "
246-
"The secret will be sent to the IdP but may be ignored. "
247-
"Consider providing both idp_client_id and idp_client_secret together."
243+
raise ValueError(
244+
"idp_client_secret was provided without idp_client_id. Provide both together, or omit the secret."
248245
)
249246

247+
# Validate token exchange params at construction time
248+
validate_token_exchange_params(token_exchange_params)
249+
250250
async def exchange_token_for_id_jag(
251251
self,
252252
client: httpx.AsyncClient,
253253
) -> str:
254254
"""Exchange ID Token for ID-JAG using RFC 8693 Token Exchange.
255255
256-
Note: Overrides the configured `audience` with the discovered OAuth
257-
issuer URL (if available) to satisfy MCP server `aud` claim requirements.
258-
259256
Args:
260257
client: HTTP client for making requests
261258
@@ -268,22 +265,23 @@ async def exchange_token_for_id_jag(
268265
logger.debug("Starting token exchange for ID-JAG")
269266

270267
audience = self.token_exchange_params.audience
271-
if self.context.oauth_metadata and self.context.oauth_metadata.issuer:
272-
discovered_issuer = str(self.context.oauth_metadata.issuer)
273-
if audience != discovered_issuer:
274-
logger.warning(
275-
f"Overriding audience '{audience}' with discovered issuer '{discovered_issuer}'. "
276-
f"To prevent this, set token_exchange_params.audience to the issuer URL."
277-
)
278-
audience = discovered_issuer
268+
if self.override_audience_with_issuer:
269+
if self.context.oauth_metadata and self.context.oauth_metadata.issuer:
270+
discovered_issuer = str(self.context.oauth_metadata.issuer)
271+
if audience != discovered_issuer:
272+
logger.warning(
273+
f"Overriding audience '{audience}' with discovered issuer "
274+
f"'{discovered_issuer}'. To prevent this, pass "
275+
f"override_audience_with_issuer=False."
276+
)
277+
audience = discovered_issuer
279278

280-
# Build token exchange request
281279
token_data: TokenExchangeRequestData = {
282280
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
283281
"requested_token_type": self.token_exchange_params.requested_token_type,
284282
"audience": audience,
285283
"resource": self.token_exchange_params.resource,
286-
"subject_token": self.token_exchange_params.subject_token,
284+
"subject_token": self._subject_token,
287285
"subject_token_type": self.token_exchange_params.subject_token_type,
288286
}
289287

@@ -309,7 +307,6 @@ async def exchange_token_for_id_jag(
309307
if response.headers.get("content-type", "").startswith("application/json"):
310308
error_data = response.json()
311309
except JSONDecodeError:
312-
# Response is not valid JSON, use default error handling
313310
pass
314311

315312
error: str = error_data.get("error", "unknown_error")
@@ -318,10 +315,8 @@ async def exchange_token_for_id_jag(
318315
)
319316
raise OAuthTokenError(f"Token exchange failed: {error} - {error_description}")
320317

321-
# Parse response
322318
token_response = IDJAGTokenExchangeResponse.model_validate_json(response.content)
323319

324-
# Validate response
325320
if token_response.issued_token_type != "urn:ietf:params:oauth:token-type:id-jag":
326321
raise OAuthTokenError(f"Unexpected token type: {token_response.issued_token_type}")
327322

@@ -331,16 +326,11 @@ async def exchange_token_for_id_jag(
331326
logger.debug("Successfully obtained ID-JAG")
332327
self._id_jag = token_response.id_jag
333328

334-
# Track ID-JAG expiry to avoid using stale cached tokens
335329
if token_response.expires_in:
336330
self._id_jag_expiry = time.time() + token_response.expires_in
337331
else:
338-
# If no expires_in, use configured default expiry
339332
self._id_jag_expiry = time.time() + self.default_id_jag_expiry
340-
logger.debug(
341-
f"IdP did not provide expires_in, using default expiry of "
342-
f"{self.default_id_jag_expiry} seconds for ID-JAG"
343-
)
333+
logger.debug(f"IdP omitted expires_in; using default of {self.default_id_jag_expiry}s for ID-JAG")
344334

345335
return token_response.id_jag
346336

@@ -365,74 +355,54 @@ async def exchange_id_jag_for_access_token(
365355
Raises:
366356
OAuthFlowError: If OAuth metadata not discovered
367357
"""
368-
logger.info("Building JWT bearer grant request for ID-JAG")
358+
logger.debug("Building JWT bearer grant request for ID-JAG")
369359

370-
# Discover token endpoint from MCP server if not already done
371360
if not self.context.oauth_metadata or not self.context.oauth_metadata.token_endpoint:
372361
raise OAuthFlowError("MCP server token endpoint not discovered")
373362

374363
token_endpoint = str(self.context.oauth_metadata.token_endpoint)
375364

376-
# Build JWT bearer grant request
377-
token_data: JWTBearerGrantRequestData = {
365+
# Build as a plain dict — avoids the double-cast through TypedDict
366+
token_data: dict[str, str] = {
378367
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
379368
"assertion": id_jag,
380369
}
381370

382-
# Add client authentication
371+
# Add client_id to body. prepare_token_auth handles client_secret
372+
# placement (Basic header vs. POST body) based on auth method.
383373
if self.context.client_info:
384-
# Default to client_secret_basic if not specified (per OAuth 2.0 spec)
385374
if self.context.client_info.token_endpoint_auth_method is None:
386375
self.context.client_info.token_endpoint_auth_method = "client_secret_basic"
387376

388377
if self.context.client_info.client_id is not None:
389378
token_data["client_id"] = self.context.client_info.client_id
390-
if self.context.client_info.client_secret is not None:
391-
token_data["client_secret"] = self.context.client_info.client_secret
392379

393-
# Apply client authentication method (handles client_secret_basic vs client_secret_post)
394380
headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"}
395-
# Cast to dict[str, str] for prepare_token_auth compatibility
396-
# Double-cast to bypass TypedDict strictness for prepare_token_auth
397-
data_dict = cast(dict[str, str], cast(object, token_data))
398-
data_dict, headers = self.context.prepare_token_auth(data_dict, headers)
381+
token_data, headers = self.context.prepare_token_auth(token_data, headers)
399382

400-
return httpx.Request("POST", token_endpoint, data=data_dict, headers=headers)
383+
return httpx.Request("POST", token_endpoint, data=token_data, headers=headers)
401384

402385
async def _perform_authorization(self) -> httpx.Request:
403386
"""Perform enterprise authorization flow.
404387
405-
Overrides parent method to use token exchange + JWT bearer grant
406-
instead of standard authorization code flow.
388+
Called by the parent's ``async_auth_flow`` when a new access token is needed.
389+
Unconditionally performs full token exchange as the parent already handles
390+
token validity checks.
407391
408-
This method:
409-
1. Exchanges IDP ID token for ID-JAG at the IDP server (direct HTTP call)
410-
2. Returns an httpx.Request for JWT bearer grant (ID-JAG → Access token)
392+
Flow:
393+
1. Exchange IdP subject token for ID-JAG (RFC 8693, direct HTTP call)
394+
2. Return an ``httpx.Request`` for the JWT bearer grant (RFC 7523)
395+
that the parent will execute and pass to ``_handle_token_response``
411396
412397
Returns:
413398
httpx.Request for the JWT bearer grant to the MCP authorization server
414399
"""
415-
# Check if we already have valid tokens
416-
if self.context.is_token_valid():
417-
# Reuse unexpired cached ID-JAG to prevent auth failures
418-
if self._id_jag and self._id_jag_expiry:
419-
if time.time() < self._id_jag_expiry:
420-
logger.debug("Reusing cached ID-JAG for JWT bearer grant")
421-
return await self.exchange_id_jag_for_access_token(self._id_jag)
422-
else:
423-
logger.debug("Cached ID-JAG expired, will obtain a new one")
424-
# Fall through to full flow if ID-JAG is expired or missing (e.g., loaded from storage)
425-
426-
# Step 1: Exchange IDP ID token for ID-JAG (RFC 8693)
427-
# This is an external call to the IDP, so we make it directly
400+
# Step 1: Exchange IDP subject token for ID-JAG (RFC 8693)
428401
async with httpx.AsyncClient(timeout=self.context.timeout) as client:
429402
id_jag = await self.exchange_token_for_id_jag(client)
430-
# Cache the ID-JAG for potential reuse
431403
self._id_jag = id_jag
432404

433405
# Step 2: Build JWT bearer grant request (RFC 7523)
434-
# This request will be yielded by the parent's async_auth_flow
435-
# and the response will be handled by _handle_token_response
436406
jwt_bearer_request = await self.exchange_id_jag_for_access_token(id_jag)
437407

438408
logger.debug("Returning JWT bearer grant request to async_auth_flow")
@@ -441,42 +411,39 @@ async def _perform_authorization(self) -> httpx.Request:
441411
async def refresh_with_new_id_token(self, new_id_token: str) -> None:
442412
"""Refresh MCP server access tokens using a fresh ID token from the IdP.
443413
444-
Updates the subject token and clears cached tokens (including ID-JAG),
445-
triggering re-authentication on the next API request.
414+
Updates the subject token and clears cached state so that the next API
415+
request triggers a full re-authentication.
446416
447417
Note: OAuth metadata is not re-discovered. If the MCP server's OAuth
448-
configuration has changed, you must create a new provider instance.
418+
configuration has changed, create a new provider instance instead.
449419
450420
Args:
451421
new_id_token: Fresh ID token obtained from your enterprise IdP.
452422
"""
453423
logger.info("Refreshing tokens with new ID token from IdP")
454-
self.token_exchange_params.subject_token = new_id_token
424+
# Update the mutable subject token (does NOT mutate the original params object)
425+
self._subject_token = new_id_token
455426

456-
# Clear caches to force ID-JAG re-exchange and re-authentication
427+
# Clear caches to force full re-exchange on next request
457428
self._id_jag = None
458429
self._id_jag_expiry = None
459430
self.context.clear_tokens()
460-
logger.debug("Token refresh prepared - will re-authenticate on next request")
431+
logger.debug("Token refresh prepared will re-authenticate on next request")
461432

462433

463434
def decode_id_jag(id_jag: str) -> IDJAGClaims:
464435
"""Decode an ID-JAG token without verification.
465436
437+
Relies on the receiving server to validate the JWT signature.
438+
466439
Args:
467440
id_jag: The ID-JAG token string
468441
469442
Returns:
470443
Decoded ID-JAG claims
471-
472-
Note:
473-
This function does not verify the JWT, instead relying on the receiving server to validate it.
474444
"""
475-
# Decode without verification for inspection
476445
claims = jwt.decode(id_jag, options={"verify_signature": False})
477446
header = jwt.get_unverified_header(id_jag)
478-
479-
# Add typ from header to claims
480447
claims["typ"] = header.get("typ", "")
481448

482449
return IDJAGClaims.model_validate(claims)

0 commit comments

Comments
 (0)