77import logging
88import time
99from json import JSONDecodeError
10- from typing import cast
1110
1211import httpx
1312import jwt
@@ -38,19 +37,6 @@ class TokenExchangeRequestData(TypedDict):
3837 client_secret : NotRequired [str ]
3938
4039
41- class JWTBearerGrantRequestData (TypedDict ):
42- """Type definition for RFC 7523 JWT Bearer Grant request data.
43-
44- Required fields are those mandated by RFC 7523.
45- Optional fields (NotRequired) are for client authentication.
46- """
47-
48- grant_type : Required [str ]
49- assertion : Required [str ]
50- client_id : NotRequired [str ]
51- client_secret : NotRequired [str ]
52-
53-
5440class TokenExchangeParameters (BaseModel ):
5541 """Parameters for RFC 8693 Token Exchange request."""
5642
@@ -187,15 +173,16 @@ class EnterpriseAuthOAuthClientProvider(OAuthClientProvider):
187173 - RFC 7523: JWT Bearer Grant (ID-JAG → Access Token)
188174
189175 Concurrency & Thread Safety:
190- - SAFE: Concurrent requests within a single asyncio event loop. Token operations
191- are protected by the parent class's `OAuthContext.lock`.
192- - UNSAFE: Sharing a provider instance across multiple OS threads. Each thread
193- must instantiate its own provider and event loop.
194- - Note: Ensure any shared `TokenStorage` implementation is async-safe.
176+ - SAFE: Concurrent requests within a single asyncio event loop. Token
177+ operations (including ``_id_jag`` / ``_id_jag_expiry``) are protected
178+ by the parent class's ``OAuthContext.lock`` via ``async_auth_flow``.
179+ - UNSAFE: Sharing a provider instance across multiple OS threads. Each
180+ thread must instantiate its own provider and event loop.
181+ - Note: Ensure any shared ``TokenStorage`` implementation is async-safe.
195182 """
196183
197- # Default ID-JAG expiry when IdP doesn't provide expires_in
198- # 15 minutes is a conservative default for enterprise environments
184+ # Default ID-JAG expiry when IdP doesn't provide expires_in.
185+ # 15 minutes is a conservative default for enterprise environments.
199186 DEFAULT_ID_JAG_EXPIRY_SECONDS = 900
200187
201188 def __init__ (
@@ -209,6 +196,7 @@ def __init__(
209196 idp_client_id : str | None = None ,
210197 idp_client_secret : str | None = None ,
211198 default_id_jag_expiry : int = DEFAULT_ID_JAG_EXPIRY_SECONDS ,
199+ override_audience_with_issuer : bool = True ,
212200 ) -> None :
213201 """Initialize Enterprise Auth OAuth Client.
214202
@@ -217,13 +205,21 @@ def __init__(
217205 client_metadata: OAuth client metadata
218206 storage: Token storage implementation
219207 idp_token_endpoint: Enterprise IdP token endpoint URL
220- token_exchange_params: Token exchange parameters
208+ token_exchange_params: Token exchange parameters (not mutated)
221209 timeout: Request timeout in seconds
222210 idp_client_id: Optional client ID registered with the IdP for token exchange
223- idp_client_secret: Optional client secret registered with the IdP for token exchange
211+ idp_client_secret: Optional client secret registered with the IdP.
212+ Must be accompanied by ``idp_client_id``; providing a secret
213+ without an ID raises ``ValueError``.
224214 default_id_jag_expiry: Fallback ID-JAG expiry in seconds if the IdP
225- omits `expires_in` (default: 900s/15m). Adjust to balance token
226- freshness against IdP request load.
215+ omits ``expires_in`` (default: 900 s / 15 min).
216+ override_audience_with_issuer: If True (default), replaces the IdP
217+ audience with the discovered OAuth issuer URL. Set to False for
218+ federated identity setups where the audience must differ.
219+
220+ Raises:
221+ ValueError: If ``idp_client_secret`` is provided without ``idp_client_id``.
222+ OAuthFlowError: If ``token_exchange_params`` fail validation.
227223 """
228224 super ().__init__ (
229225 server_url = server_url ,
@@ -232,30 +228,31 @@ def __init__(
232228 timeout = timeout ,
233229 )
234230 self .idp_token_endpoint = idp_token_endpoint
231+ # Keep original params immutable; track mutable subject_token separately
235232 self .token_exchange_params = token_exchange_params
233+ self ._subject_token = token_exchange_params .subject_token
236234 self .idp_client_id = idp_client_id
237235 self .idp_client_secret = idp_client_secret
238236 self .default_id_jag_expiry = default_id_jag_expiry
237+ self .override_audience_with_issuer = override_audience_with_issuer
239238 self ._id_jag : str | None = None
240239 self ._id_jag_expiry : float | None = None
241240
242- # Validate client authentication configuration
241+ # Fail-fast: secret without ID is almost certainly a misconfiguration
243242 if idp_client_secret is not None and idp_client_id is None :
244- logger .warning (
245- "idp_client_secret provided without idp_client_id. "
246- "The secret will be sent to the IdP but may be ignored. "
247- "Consider providing both idp_client_id and idp_client_secret together."
243+ raise ValueError (
244+ "idp_client_secret was provided without idp_client_id. Provide both together, or omit the secret."
248245 )
249246
247+ # Validate token exchange params at construction time
248+ validate_token_exchange_params (token_exchange_params )
249+
250250 async def exchange_token_for_id_jag (
251251 self ,
252252 client : httpx .AsyncClient ,
253253 ) -> str :
254254 """Exchange ID Token for ID-JAG using RFC 8693 Token Exchange.
255255
256- Note: Overrides the configured `audience` with the discovered OAuth
257- issuer URL (if available) to satisfy MCP server `aud` claim requirements.
258-
259256 Args:
260257 client: HTTP client for making requests
261258
@@ -268,22 +265,23 @@ async def exchange_token_for_id_jag(
268265 logger .debug ("Starting token exchange for ID-JAG" )
269266
270267 audience = self .token_exchange_params .audience
271- if self .context .oauth_metadata and self .context .oauth_metadata .issuer :
272- discovered_issuer = str (self .context .oauth_metadata .issuer )
273- if audience != discovered_issuer :
274- logger .warning (
275- f"Overriding audience '{ audience } ' with discovered issuer '{ discovered_issuer } '. "
276- f"To prevent this, set token_exchange_params.audience to the issuer URL."
277- )
278- audience = discovered_issuer
268+ if self .override_audience_with_issuer :
269+ if self .context .oauth_metadata and self .context .oauth_metadata .issuer :
270+ discovered_issuer = str (self .context .oauth_metadata .issuer )
271+ if audience != discovered_issuer :
272+ logger .warning (
273+ f"Overriding audience '{ audience } ' with discovered issuer "
274+ f"'{ discovered_issuer } '. To prevent this, pass "
275+ f"override_audience_with_issuer=False."
276+ )
277+ audience = discovered_issuer
279278
280- # Build token exchange request
281279 token_data : TokenExchangeRequestData = {
282280 "grant_type" : "urn:ietf:params:oauth:grant-type:token-exchange" ,
283281 "requested_token_type" : self .token_exchange_params .requested_token_type ,
284282 "audience" : audience ,
285283 "resource" : self .token_exchange_params .resource ,
286- "subject_token" : self .token_exchange_params . subject_token ,
284+ "subject_token" : self ._subject_token ,
287285 "subject_token_type" : self .token_exchange_params .subject_token_type ,
288286 }
289287
@@ -309,7 +307,6 @@ async def exchange_token_for_id_jag(
309307 if response .headers .get ("content-type" , "" ).startswith ("application/json" ):
310308 error_data = response .json ()
311309 except JSONDecodeError :
312- # Response is not valid JSON, use default error handling
313310 pass
314311
315312 error : str = error_data .get ("error" , "unknown_error" )
@@ -318,10 +315,8 @@ async def exchange_token_for_id_jag(
318315 )
319316 raise OAuthTokenError (f"Token exchange failed: { error } - { error_description } " )
320317
321- # Parse response
322318 token_response = IDJAGTokenExchangeResponse .model_validate_json (response .content )
323319
324- # Validate response
325320 if token_response .issued_token_type != "urn:ietf:params:oauth:token-type:id-jag" :
326321 raise OAuthTokenError (f"Unexpected token type: { token_response .issued_token_type } " )
327322
@@ -331,16 +326,11 @@ async def exchange_token_for_id_jag(
331326 logger .debug ("Successfully obtained ID-JAG" )
332327 self ._id_jag = token_response .id_jag
333328
334- # Track ID-JAG expiry to avoid using stale cached tokens
335329 if token_response .expires_in :
336330 self ._id_jag_expiry = time .time () + token_response .expires_in
337331 else :
338- # If no expires_in, use configured default expiry
339332 self ._id_jag_expiry = time .time () + self .default_id_jag_expiry
340- logger .debug (
341- f"IdP did not provide expires_in, using default expiry of "
342- f"{ self .default_id_jag_expiry } seconds for ID-JAG"
343- )
333+ logger .debug (f"IdP omitted expires_in; using default of { self .default_id_jag_expiry } s for ID-JAG" )
344334
345335 return token_response .id_jag
346336
@@ -365,74 +355,54 @@ async def exchange_id_jag_for_access_token(
365355 Raises:
366356 OAuthFlowError: If OAuth metadata not discovered
367357 """
368- logger .info ("Building JWT bearer grant request for ID-JAG" )
358+ logger .debug ("Building JWT bearer grant request for ID-JAG" )
369359
370- # Discover token endpoint from MCP server if not already done
371360 if not self .context .oauth_metadata or not self .context .oauth_metadata .token_endpoint :
372361 raise OAuthFlowError ("MCP server token endpoint not discovered" )
373362
374363 token_endpoint = str (self .context .oauth_metadata .token_endpoint )
375364
376- # Build JWT bearer grant request
377- token_data : JWTBearerGrantRequestData = {
365+ # Build as a plain dict — avoids the double-cast through TypedDict
366+ token_data : dict [ str , str ] = {
378367 "grant_type" : "urn:ietf:params:oauth:grant-type:jwt-bearer" ,
379368 "assertion" : id_jag ,
380369 }
381370
382- # Add client authentication
371+ # Add client_id to body. prepare_token_auth handles client_secret
372+ # placement (Basic header vs. POST body) based on auth method.
383373 if self .context .client_info :
384- # Default to client_secret_basic if not specified (per OAuth 2.0 spec)
385374 if self .context .client_info .token_endpoint_auth_method is None :
386375 self .context .client_info .token_endpoint_auth_method = "client_secret_basic"
387376
388377 if self .context .client_info .client_id is not None :
389378 token_data ["client_id" ] = self .context .client_info .client_id
390- if self .context .client_info .client_secret is not None :
391- token_data ["client_secret" ] = self .context .client_info .client_secret
392379
393- # Apply client authentication method (handles client_secret_basic vs client_secret_post)
394380 headers : dict [str , str ] = {"Content-Type" : "application/x-www-form-urlencoded" }
395- # Cast to dict[str, str] for prepare_token_auth compatibility
396- # Double-cast to bypass TypedDict strictness for prepare_token_auth
397- data_dict = cast (dict [str , str ], cast (object , token_data ))
398- data_dict , headers = self .context .prepare_token_auth (data_dict , headers )
381+ token_data , headers = self .context .prepare_token_auth (token_data , headers )
399382
400- return httpx .Request ("POST" , token_endpoint , data = data_dict , headers = headers )
383+ return httpx .Request ("POST" , token_endpoint , data = token_data , headers = headers )
401384
402385 async def _perform_authorization (self ) -> httpx .Request :
403386 """Perform enterprise authorization flow.
404387
405- Overrides parent method to use token exchange + JWT bearer grant
406- instead of standard authorization code flow.
388+ Called by the parent's ``async_auth_flow`` when a new access token is needed.
389+ Unconditionally performs full token exchange as the parent already handles
390+ token validity checks.
407391
408- This method:
409- 1. Exchanges IDP ID token for ID-JAG at the IDP server (direct HTTP call)
410- 2. Returns an httpx.Request for JWT bearer grant (ID-JAG → Access token)
392+ Flow:
393+ 1. Exchange IdP subject token for ID-JAG (RFC 8693, direct HTTP call)
394+ 2. Return an ``httpx.Request`` for the JWT bearer grant (RFC 7523)
395+ that the parent will execute and pass to ``_handle_token_response``
411396
412397 Returns:
413398 httpx.Request for the JWT bearer grant to the MCP authorization server
414399 """
415- # Check if we already have valid tokens
416- if self .context .is_token_valid ():
417- # Reuse unexpired cached ID-JAG to prevent auth failures
418- if self ._id_jag and self ._id_jag_expiry :
419- if time .time () < self ._id_jag_expiry :
420- logger .debug ("Reusing cached ID-JAG for JWT bearer grant" )
421- return await self .exchange_id_jag_for_access_token (self ._id_jag )
422- else :
423- logger .debug ("Cached ID-JAG expired, will obtain a new one" )
424- # Fall through to full flow if ID-JAG is expired or missing (e.g., loaded from storage)
425-
426- # Step 1: Exchange IDP ID token for ID-JAG (RFC 8693)
427- # This is an external call to the IDP, so we make it directly
400+ # Step 1: Exchange IDP subject token for ID-JAG (RFC 8693)
428401 async with httpx .AsyncClient (timeout = self .context .timeout ) as client :
429402 id_jag = await self .exchange_token_for_id_jag (client )
430- # Cache the ID-JAG for potential reuse
431403 self ._id_jag = id_jag
432404
433405 # Step 2: Build JWT bearer grant request (RFC 7523)
434- # This request will be yielded by the parent's async_auth_flow
435- # and the response will be handled by _handle_token_response
436406 jwt_bearer_request = await self .exchange_id_jag_for_access_token (id_jag )
437407
438408 logger .debug ("Returning JWT bearer grant request to async_auth_flow" )
@@ -441,42 +411,39 @@ async def _perform_authorization(self) -> httpx.Request:
441411 async def refresh_with_new_id_token (self , new_id_token : str ) -> None :
442412 """Refresh MCP server access tokens using a fresh ID token from the IdP.
443413
444- Updates the subject token and clears cached tokens (including ID-JAG),
445- triggering re-authentication on the next API request .
414+ Updates the subject token and clears cached state so that the next API
415+ request triggers a full re-authentication .
446416
447417 Note: OAuth metadata is not re-discovered. If the MCP server's OAuth
448- configuration has changed, you must create a new provider instance.
418+ configuration has changed, create a new provider instance instead .
449419
450420 Args:
451421 new_id_token: Fresh ID token obtained from your enterprise IdP.
452422 """
453423 logger .info ("Refreshing tokens with new ID token from IdP" )
454- self .token_exchange_params .subject_token = new_id_token
424+ # Update the mutable subject token (does NOT mutate the original params object)
425+ self ._subject_token = new_id_token
455426
456- # Clear caches to force ID-JAG re-exchange and re-authentication
427+ # Clear caches to force full re-exchange on next request
457428 self ._id_jag = None
458429 self ._id_jag_expiry = None
459430 self .context .clear_tokens ()
460- logger .debug ("Token refresh prepared - will re-authenticate on next request" )
431+ logger .debug ("Token refresh prepared — will re-authenticate on next request" )
461432
462433
463434def decode_id_jag (id_jag : str ) -> IDJAGClaims :
464435 """Decode an ID-JAG token without verification.
465436
437+ Relies on the receiving server to validate the JWT signature.
438+
466439 Args:
467440 id_jag: The ID-JAG token string
468441
469442 Returns:
470443 Decoded ID-JAG claims
471-
472- Note:
473- This function does not verify the JWT, instead relying on the receiving server to validate it.
474444 """
475- # Decode without verification for inspection
476445 claims = jwt .decode (id_jag , options = {"verify_signature" : False })
477446 header = jwt .get_unverified_header (id_jag )
478-
479- # Add typ from header to claims
480447 claims ["typ" ] = header .get ("typ" , "" )
481448
482449 return IDJAGClaims .model_validate (claims )
0 commit comments