diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index f929646df..956273936 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -143,7 +143,8 @@ schemars = ["dep:schemars"] [dev-dependencies] tokio = { version = "1", features = ["full"] } schemars = { version = "1.1.0", features = ["chrono04"] } - +axum = { version = "0.8", default-features = false, features = ["http1", "tokio"] } +url = "2.4" anyhow = "1.0" tracing-subscriber = { version = "0.3", features = [ "env-filter", diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs index 7cedc1c27..cc77a8623 100644 --- a/crates/rmcp/src/transport/auth.rs +++ b/crates/rmcp/src/transport/auth.rs @@ -1,4 +1,8 @@ -use std::{collections::HashMap, sync::Arc, time::Duration}; +use std::{ + collections::HashMap, + sync::Arc, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; use async_trait::async_trait; use oauth2::{ @@ -23,6 +27,10 @@ const DEFAULT_EXCHANGE_URL: &str = "http://localhost"; pub struct StoredCredentials { pub client_id: String, pub token_response: Option, + #[serde(default)] + pub granted_scopes: Vec, + #[serde(default)] + pub token_received_at: Option, } /// Trait for storing and retrieving OAuth2 credentials @@ -567,39 +575,78 @@ impl AuthorizationManager { debug!("exchange token result: {:?}", token_result); + let granted_scopes: Vec = token_result + .scopes() + .map(|scopes| scopes.iter().map(|s| s.to_string()).collect()) + .unwrap_or_default(); + // Store credentials in the credential store let client_id = oauth_client.client_id().to_string(); let stored = StoredCredentials { client_id, token_response: Some(token_result.clone()), + granted_scopes, + token_received_at: Some(Self::now_epoch_secs()), }; self.credential_store.save(stored).await?; Ok(token_result) } + fn now_epoch_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() + } + + /// Proactive refresh buffer: refresh tokens this many seconds before they expire + /// to avoid races between token retrieval and the actual HTTP request. + const REFRESH_BUFFER_SECS: u64 = 30; + /// get access token, if expired, refresh it automatically pub async fn get_access_token(&self) -> Result { - // Load credentials from store let stored = self.credential_store.load().await?; - let credentials = stored.and_then(|s| s.token_response); - - if let Some(creds) = credentials.as_ref() { - // check token expiry if we have a refresh token or an expiry time - if creds.refresh_token().is_some() || creds.expires_in().is_some() { - let expires_in = creds.expires_in().unwrap_or(Duration::from_secs(0)); - if expires_in <= Duration::from_secs(0) { - tracing::info!("Access token expired, refreshing."); - - let new_creds = self.refresh_token().await?; - tracing::info!("Refreshed access token."); - return Ok(new_creds.access_token().secret().to_string()); - } + let Some(stored_creds) = stored else { + return Err(AuthError::AuthorizationRequired); + }; + let Some(creds) = stored_creds.token_response.as_ref() else { + return Err(AuthError::AuthorizationRequired); + }; + + if let (Some(expires_in), Some(received_at)) = + (creds.expires_in(), stored_creds.token_received_at) + { + let elapsed = Self::now_epoch_secs().saturating_sub(received_at); + let remaining = expires_in.as_secs().saturating_sub(elapsed); + + if remaining < Self::REFRESH_BUFFER_SECS { + tracing::info!( + remaining_secs = remaining, + "Access token expired or nearly expired, refreshing." + ); + return self.try_refresh_or_reauth().await; } + } - Ok(creds.access_token().secret().to_string()) - } else { - Err(AuthError::AuthorizationRequired) + Ok(creds.access_token().secret().to_string()) + } + + /// Attempt to refresh the token. If refresh fails because there is no + /// refresh token or the server rejected it, return `AuthorizationRequired` + /// so the caller can re-prompt the user. Infrastructure errors (e.g. store + /// I/O failures, misconfigured client) are propagated as-is. + async fn try_refresh_or_reauth(&self) -> Result { + match self.refresh_token().await { + Ok(new_creds) => { + tracing::info!("Refreshed access token."); + Ok(new_creds.access_token().secret().to_string()) + } + Err(AuthError::AuthorizationRequired | AuthError::TokenRefreshFailed(_)) => { + tracing::warn!("Token refresh not possible, re-authorization required."); + Err(AuthError::AuthorizationRequired) + } + Err(e) => Err(e), } } @@ -613,25 +660,37 @@ impl AuthorizationManager { .ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?; let stored = self.credential_store.load().await?; - let current_credentials = stored - .and_then(|s| s.token_response) - .ok_or_else(|| AuthError::AuthorizationRequired)?; + let stored_credentials = stored.ok_or(AuthError::AuthorizationRequired)?; + let current_credentials = stored_credentials + .token_response + .ok_or(AuthError::AuthorizationRequired)?; let refresh_token = current_credentials.refresh_token().ok_or_else(|| { AuthError::TokenRefreshFailed("No refresh token available".to_string()) })?; debug!("refresh token: {:?}", refresh_token); - let token_result = oauth_client - .exchange_refresh_token(&RefreshToken::new(refresh_token.secret().to_string())) + let refresh_token_value = RefreshToken::new(refresh_token.secret().to_string()); + let mut refresh_request = oauth_client.exchange_refresh_token(&refresh_token_value); + for scope in &stored_credentials.granted_scopes { + refresh_request = refresh_request.add_scope(Scope::new(scope.clone())); + } + let token_result = refresh_request .request_async(&self.http_client) .await .map_err(|e| AuthError::TokenRefreshFailed(e.to_string()))?; + let granted_scopes: Vec = match token_result.scopes() { + Some(scopes) => scopes.iter().map(|s| s.to_string()).collect(), + None => vec![], + }; + let client_id = oauth_client.client_id().to_string(); let stored = StoredCredentials { client_id, token_response: Some(token_result.clone()), + granted_scopes, + token_received_at: Some(Self::now_epoch_secs()), }; self.credential_store.save(stored).await?; @@ -1102,9 +1161,16 @@ impl OAuthState { AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?, ); + let granted_scopes: Vec = credentials + .scopes() + .map(|scopes| scopes.iter().map(|s| s.to_string()).collect()) + .unwrap_or_default(); + let stored = StoredCredentials { client_id: client_id.to_string(), token_response: Some(credentials), + granted_scopes, + token_received_at: Some(AuthorizationManager::now_epoch_secs()), }; manager.credential_store.save(stored).await?; @@ -1259,7 +1325,35 @@ impl OAuthState { mod tests { use url::Url; - use super::AuthorizationManager; + use super::{ + AuthError, AuthorizationManager, AuthorizationMetadata, InMemoryStateStore, OAuthClientConfig, + OAuthTokenResponse, StoredCredentials, StateStore, StoredAuthorizationState, + is_https_url, + }; + + // SEP-991: URL-based Client IDs + // Tests adapted from the TypeScript SDK's isHttpsUrl test suite + #[test] + fn test_is_https_url_scenarios() { + // Returns true for valid https url with path + assert!(is_https_url("https://example.com/client-metadata.json")); + // Returns true for https url with query params + assert!(is_https_url("https://example.com/metadata?version=1")); + // Returns false for https url without path + assert!(!is_https_url("https://example.com")); + assert!(!is_https_url("https://example.com/")); + assert!(!is_https_url("https://")); + // Returns false for http url + assert!(!is_https_url("http://example.com/metadata")); + // Returns false for non-url strings + assert!(!is_https_url("not a url")); + // Returns false for empty string + assert!(!is_https_url("")); + // Returns false for javascript scheme + assert!(!is_https_url("javascript:alert(1)")); + // Returns false for data scheme + assert!(!is_https_url("data:text/html,")); + } #[test] fn parses_resource_metadata_parameter() { @@ -1364,4 +1458,560 @@ mod tests { ] ); } + + #[test] + fn generate_discovery_urls() { + // Test root URL (no path components): OAuth first, then OpenID Connect + let base_url = Url::parse("https://auth.example.com").unwrap(); + let urls = AuthorizationManager::generate_discovery_urls(&base_url); + assert_eq!(urls.len(), 2); + assert_eq!( + urls[0].as_str(), + "https://auth.example.com/.well-known/oauth-authorization-server" + ); + assert_eq!( + urls[1].as_str(), + "https://auth.example.com/.well-known/openid-configuration" + ); + + // Test URL with single path segment: follow spec priority order + let base_url = Url::parse("https://auth.example.com/tenant1").unwrap(); + let urls = AuthorizationManager::generate_discovery_urls(&base_url); + assert_eq!(urls.len(), 3); + assert_eq!( + urls[0].as_str(), + "https://auth.example.com/.well-known/oauth-authorization-server/tenant1" + ); + assert_eq!( + urls[1].as_str(), + "https://auth.example.com/.well-known/openid-configuration/tenant1" + ); + assert_eq!( + urls[2].as_str(), + "https://auth.example.com/tenant1/.well-known/openid-configuration" + ); + + // Test URL with path and trailing slash + let base_url = Url::parse("https://auth.example.com/v1/mcp/").unwrap(); + let urls = AuthorizationManager::generate_discovery_urls(&base_url); + assert_eq!(urls.len(), 3); + assert_eq!( + urls[0].as_str(), + "https://auth.example.com/.well-known/oauth-authorization-server/v1/mcp" + ); + assert_eq!( + urls[1].as_str(), + "https://auth.example.com/.well-known/openid-configuration/v1/mcp" + ); + assert_eq!( + urls[2].as_str(), + "https://auth.example.com/v1/mcp/.well-known/openid-configuration" + ); + + // Test URL with multiple path segments + let base_url = Url::parse("https://auth.example.com/tenant1/subtenant").unwrap(); + let urls = AuthorizationManager::generate_discovery_urls(&base_url); + assert_eq!(urls.len(), 3); + assert_eq!( + urls[0].as_str(), + "https://auth.example.com/.well-known/oauth-authorization-server/tenant1/subtenant" + ); + assert_eq!( + urls[1].as_str(), + "https://auth.example.com/.well-known/openid-configuration/tenant1/subtenant" + ); + assert_eq!( + urls[2].as_str(), + "https://auth.example.com/tenant1/subtenant/.well-known/openid-configuration" + ); + } + + // StateStore and StoredAuthorizationState tests + + #[tokio::test] + async fn test_in_memory_state_store_save_and_load() { + let store = InMemoryStateStore::new(); + let pkce = PkceCodeVerifier::new("test-verifier".to_string()); + let csrf = CsrfToken::new("test-csrf".to_string()); + let state = StoredAuthorizationState::new(&pkce, &csrf); + + // Save state + store.save("test-csrf", state).await.unwrap(); + + // Load state + let loaded = store.load("test-csrf").await.unwrap(); + assert!(loaded.is_some()); + let loaded = loaded.unwrap(); + assert_eq!(loaded.csrf_token, "test-csrf"); + assert_eq!(loaded.pkce_verifier, "test-verifier"); + } + + #[tokio::test] + async fn test_in_memory_state_store_load_nonexistent() { + let store = InMemoryStateStore::new(); + let result = store.load("nonexistent").await.unwrap(); + assert!(result.is_none()); + } + + #[tokio::test] + async fn test_in_memory_state_store_delete() { + let store = InMemoryStateStore::new(); + let pkce = PkceCodeVerifier::new("verifier".to_string()); + let csrf = CsrfToken::new("csrf".to_string()); + let state = StoredAuthorizationState::new(&pkce, &csrf); + + store.save("csrf", state).await.unwrap(); + store.delete("csrf").await.unwrap(); + + let result = store.load("csrf").await.unwrap(); + assert!(result.is_none()); + } + + #[test] + fn test_stored_authorization_state_serialization() { + let pkce = PkceCodeVerifier::new("my-verifier".to_string()); + let csrf = CsrfToken::new("my-csrf".to_string()); + let state = StoredAuthorizationState::new(&pkce, &csrf); + + // Serialize to JSON + let json = serde_json::to_string(&state).unwrap(); + + // Deserialize back + let deserialized: StoredAuthorizationState = serde_json::from_str(&json).unwrap(); + + assert_eq!(deserialized.pkce_verifier, "my-verifier"); + assert_eq!(deserialized.csrf_token, "my-csrf"); + } + + #[test] + fn test_stored_authorization_state_into_pkce_verifier() { + let pkce = PkceCodeVerifier::new("original-verifier".to_string()); + let csrf = CsrfToken::new("csrf-token".to_string()); + let state = StoredAuthorizationState::new(&pkce, &csrf); + + let recovered = state.into_pkce_verifier(); + assert_eq!(recovered.secret(), "original-verifier"); + } + + #[test] + fn test_stored_authorization_state_created_at() { + let pkce = PkceCodeVerifier::new("verifier".to_string()); + let csrf = CsrfToken::new("csrf".to_string()); + let state = StoredAuthorizationState::new(&pkce, &csrf); + + // created_at should be a reasonable timestamp (after year 2020) + assert!(state.created_at > 1577836800); // Jan 1, 2020 + } + + #[tokio::test] + async fn test_in_memory_state_store_overwrite() { + let store = InMemoryStateStore::new(); + let csrf_key = "same-csrf"; + + // Save first state + let pkce1 = PkceCodeVerifier::new("verifier-1".to_string()); + let csrf1 = CsrfToken::new(csrf_key.to_string()); + let state1 = StoredAuthorizationState::new(&pkce1, &csrf1); + store.save(csrf_key, state1).await.unwrap(); + + // Overwrite with second state + let pkce2 = PkceCodeVerifier::new("verifier-2".to_string()); + let csrf2 = CsrfToken::new(csrf_key.to_string()); + let state2 = StoredAuthorizationState::new(&pkce2, &csrf2); + store.save(csrf_key, state2).await.unwrap(); + + // Should get the second state + let loaded = store.load(csrf_key).await.unwrap().unwrap(); + assert_eq!(loaded.pkce_verifier, "verifier-2"); + } + + #[tokio::test] + async fn test_in_memory_state_store_concurrent_access() { + let store = Arc::new(InMemoryStateStore::new()); + let mut handles = vec![]; + + // Spawn 10 concurrent tasks that each save and load their own state + for i in 0..10 { + let store = Arc::clone(&store); + let handle = tokio::spawn(async move { + let csrf_key = format!("csrf-{}", i); + let verifier = format!("verifier-{}", i); + + let pkce = PkceCodeVerifier::new(verifier.clone()); + let csrf = CsrfToken::new(csrf_key.clone()); + let state = StoredAuthorizationState::new(&pkce, &csrf); + + store.save(&csrf_key, state).await.unwrap(); + let loaded = store.load(&csrf_key).await.unwrap().unwrap(); + assert_eq!(loaded.pkce_verifier, verifier); + + store.delete(&csrf_key).await.unwrap(); + let deleted = store.load(&csrf_key).await.unwrap(); + assert!(deleted.is_none()); + }); + handles.push(handle); + } + + // Wait for all tasks to complete + for handle in handles { + handle.await.unwrap(); + } + } + + #[tokio::test] + async fn test_custom_state_store_with_authorization_manager() { + use std::sync::atomic::{AtomicUsize, Ordering}; + + // Custom state store that tracks calls + #[derive(Debug, Default)] + struct TrackingStateStore { + inner: InMemoryStateStore, + save_count: AtomicUsize, + load_count: AtomicUsize, + delete_count: AtomicUsize, + } + + #[async_trait::async_trait] + impl StateStore for TrackingStateStore { + async fn save( + &self, + csrf_token: &str, + state: StoredAuthorizationState, + ) -> Result<(), AuthError> { + self.save_count.fetch_add(1, Ordering::SeqCst); + self.inner.save(csrf_token, state).await + } + + async fn load( + &self, + csrf_token: &str, + ) -> Result, AuthError> { + self.load_count.fetch_add(1, Ordering::SeqCst); + self.inner.load(csrf_token).await + } + + async fn delete(&self, csrf_token: &str) -> Result<(), AuthError> { + self.delete_count.fetch_add(1, Ordering::SeqCst); + self.inner.delete(csrf_token).await + } + } + + // Verify custom store works standalone + let store = TrackingStateStore::default(); + let pkce = PkceCodeVerifier::new("test-verifier".to_string()); + let csrf = CsrfToken::new("test-csrf".to_string()); + let state = StoredAuthorizationState::new(&pkce, &csrf); + + store.save("test-csrf", state).await.unwrap(); + assert_eq!(store.save_count.load(Ordering::SeqCst), 1); + + let _ = store.load("test-csrf").await.unwrap(); + assert_eq!(store.load_count.load(Ordering::SeqCst), 1); + + store.delete("test-csrf").await.unwrap(); + assert_eq!(store.delete_count.load(Ordering::SeqCst), 1); + + // Verify custom store can be set on AuthorizationManager + let mut manager = AuthorizationManager::new("http://localhost").await.unwrap(); + manager.set_state_store(TrackingStateStore::default()); + } + + /// Helper: create an AuthorizationManager with minimal metadata so + /// `configure_client` can be exercised without a live server. + async fn manager_with_metadata( + metadata_override: Option, + ) -> AuthorizationManager { + let mut mgr = AuthorizationManager::new("http://localhost").await.unwrap(); + mgr.set_metadata(metadata_override.unwrap_or(AuthorizationMetadata { + authorization_endpoint: "http://localhost/authorize".to_string(), + token_endpoint: "http://localhost/token".to_string(), + ..Default::default() + })); + mgr + } + + fn test_client_config() -> OAuthClientConfig { + OAuthClientConfig { + client_id: "my-client".to_string(), + client_secret: Some("my-secret".to_string()), + scopes: vec![], + redirect_uri: "http://localhost/callback".to_string(), + } + } + + // -- get_access_token -- + + fn make_token_response(access_token: &str, expires_in_secs: Option) -> OAuthTokenResponse { + use oauth2::{AccessToken, EmptyExtraTokenFields, basic::BasicTokenType}; + let mut resp = OAuthTokenResponse::new( + AccessToken::new(access_token.to_string()), + BasicTokenType::Bearer, + EmptyExtraTokenFields {}, + ); + if let Some(secs) = expires_in_secs { + resp.set_expires_in(Some(&std::time::Duration::from_secs(secs))); + } + resp + } + + #[tokio::test] + async fn get_access_token_returns_error_when_no_credentials() { + let manager = AuthorizationManager::new("http://localhost").await.unwrap(); + let err = manager.get_access_token().await.unwrap_err(); + assert!(matches!(err, AuthError::AuthorizationRequired)); + } + + #[tokio::test] + async fn get_access_token_returns_token_when_not_expired() { + let manager = AuthorizationManager::new("http://localhost").await.unwrap(); + let stored = StoredCredentials { + client_id: "test".to_string(), + token_response: Some(make_token_response("my-access-token", Some(3600))), + granted_scopes: vec![], + token_received_at: Some(AuthorizationManager::now_epoch_secs()), + }; + manager.credential_store.save(stored).await.unwrap(); + + let token = manager.get_access_token().await.unwrap(); + assert_eq!(token, "my-access-token"); + } + + #[tokio::test] + async fn get_access_token_requires_reauth_when_expired_and_no_refresh_token() { + let mut manager = manager_with_metadata(None).await; + manager.configure_client(test_client_config()).unwrap(); + + let stored = StoredCredentials { + client_id: "my-client".to_string(), + token_response: Some(make_token_response("stale-token", Some(3600))), + granted_scopes: vec![], + token_received_at: Some(AuthorizationManager::now_epoch_secs() - 7200), + }; + manager.credential_store.save(stored).await.unwrap(); + + let err = manager.get_access_token().await.unwrap_err(); + assert!( + matches!(err, AuthError::AuthorizationRequired), + "expected AuthorizationRequired when token is expired and refresh is impossible, got: {err:?}" + ); + } + + #[tokio::test] + async fn get_access_token_returns_token_without_expiry_info() { + let manager = AuthorizationManager::new("http://localhost").await.unwrap(); + let stored = StoredCredentials { + client_id: "test".to_string(), + token_response: Some(make_token_response("no-expiry-token", None)), + granted_scopes: vec![], + token_received_at: None, + }; + manager.credential_store.save(stored).await.unwrap(); + + let token = manager.get_access_token().await.unwrap(); + assert_eq!(token, "no-expiry-token"); + } + + #[tokio::test] + async fn get_access_token_requires_reauth_when_within_refresh_buffer() { + let mut manager = manager_with_metadata(None).await; + manager.configure_client(test_client_config()).unwrap(); + + let stored = StoredCredentials { + client_id: "my-client".to_string(), + token_response: Some(make_token_response("almost-expired", Some(3600))), + granted_scopes: vec![], + token_received_at: Some(AuthorizationManager::now_epoch_secs() - 3590), + }; + manager.credential_store.save(stored).await.unwrap(); + + let err = manager.get_access_token().await.unwrap_err(); + assert!( + matches!(err, AuthError::AuthorizationRequired), + "expected AuthorizationRequired when token is within refresh buffer, got: {err:?}" + ); + } + + #[tokio::test] + async fn get_access_token_propagates_internal_errors() { + let manager = AuthorizationManager::new("http://localhost").await.unwrap(); + let stored = StoredCredentials { + client_id: "test".to_string(), + token_response: Some(make_token_response("stale-token", Some(3600))), + granted_scopes: vec![], + token_received_at: Some(AuthorizationManager::now_epoch_secs() - 7200), + }; + manager.credential_store.save(stored).await.unwrap(); + + let err = manager.get_access_token().await.unwrap_err(); + assert!( + matches!(err, AuthError::InternalError(_)), + "expected InternalError when OAuth client is not configured, got: {err:?}" + ); + } + + // -- refresh_token -- + + fn make_token_response_with_refresh( + access_token: &str, + refresh_token_str: &str, + ) -> OAuthTokenResponse { + use oauth2::RefreshToken; + let mut resp = make_token_response(access_token, Some(3600)); + resp.set_refresh_token(Some(RefreshToken::new(refresh_token_str.to_string()))); + resp + } + + #[tokio::test] + async fn refresh_token_returns_error_when_no_stored_credentials() { + let mut manager = manager_with_metadata(None).await; + manager.configure_client(test_client_config()).unwrap(); + + let err = manager.refresh_token().await.unwrap_err(); + assert!( + matches!(err, AuthError::AuthorizationRequired), + "expected AuthorizationRequired when no credentials stored, got: {err:?}" + ); + } + + #[tokio::test] + async fn refresh_token_returns_error_when_no_token_response() { + let mut manager = manager_with_metadata(None).await; + manager.configure_client(test_client_config()).unwrap(); + + let stored = StoredCredentials { + client_id: "my-client".to_string(), + token_response: None, + granted_scopes: vec![], + token_received_at: None, + }; + manager.credential_store.save(stored).await.unwrap(); + + let err = manager.refresh_token().await.unwrap_err(); + assert!( + matches!(err, AuthError::AuthorizationRequired), + "expected AuthorizationRequired when token_response is None, got: {err:?}" + ); + } + + #[tokio::test] + async fn refresh_token_returns_error_when_no_refresh_token() { + let mut manager = manager_with_metadata(None).await; + manager.configure_client(test_client_config()).unwrap(); + + let stored = StoredCredentials { + client_id: "my-client".to_string(), + token_response: Some(make_token_response("old-token", Some(3600))), + granted_scopes: vec![], + token_received_at: Some(AuthorizationManager::now_epoch_secs()), + }; + manager.credential_store.save(stored).await.unwrap(); + + let err = manager.refresh_token().await.unwrap_err(); + assert!( + matches!(err, AuthError::TokenRefreshFailed(_)), + "expected TokenRefreshFailed when no refresh token, got: {err:?}" + ); + } + + async fn start_token_server() -> (String, Arc>>) { + use axum::{Router, body::Body, http::Response, routing::post}; + let captured: Arc>> = Arc::new(std::sync::Mutex::new(None)); + let captured_clone = Arc::clone(&captured); + + let app = Router::new().route( + "/token", + post(move |body: axum::body::Bytes| { + let cap = Arc::clone(&captured_clone); + async move { + *cap.lock().unwrap() = + Some(String::from_utf8(body.to_vec()).unwrap()); + Response::builder() + .status(200) + .header("content-type", "application/json") + .body(Body::from( + r#"{"access_token":"new-token","token_type":"Bearer","expires_in":3600}"#, + )) + .unwrap() + } + }), + ); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { axum::serve(listener, app).await.unwrap() }); + + (format!("http://{}", addr), captured) + } + + #[tokio::test] + async fn refresh_token_sends_granted_scopes_in_request() { + let (base_url, captured) = start_token_server().await; + + let mut manager = manager_with_metadata(Some(AuthorizationMetadata { + authorization_endpoint: format!("{}/authorize", base_url), + token_endpoint: format!("{}/token", base_url), + ..Default::default() + })) + .await; + manager.configure_client(test_client_config()).unwrap(); + + let stored = StoredCredentials { + client_id: "my-client".to_string(), + token_response: Some(make_token_response_with_refresh( + "old-token", + "my-refresh-token", + )), + granted_scopes: vec!["read".to_string(), "write".to_string()], + token_received_at: Some(AuthorizationManager::now_epoch_secs()), + }; + manager.credential_store.save(stored).await.unwrap(); + + manager.refresh_token().await.unwrap(); + + let body = captured.lock().unwrap().take().unwrap(); + let params: std::collections::HashMap<_, _> = url::form_urlencoded::parse(body.as_bytes()) + .into_owned() + .collect(); + let scope = params + .get("scope") + .expect("scope should be present in refresh request"); + let mut scope_parts: Vec<&str> = scope.split_whitespace().collect(); + scope_parts.sort_unstable(); + assert_eq!(scope_parts, vec!["read", "write"]); + } + + #[tokio::test] + async fn refresh_token_omits_scope_when_granted_scopes_is_empty() { + let (base_url, captured) = start_token_server().await; + + let mut manager = manager_with_metadata(Some(AuthorizationMetadata { + authorization_endpoint: format!("{}/authorize", base_url), + token_endpoint: format!("{}/token", base_url), + ..Default::default() + })) + .await; + manager.configure_client(test_client_config()).unwrap(); + + let stored = StoredCredentials { + client_id: "my-client".to_string(), + token_response: Some(make_token_response_with_refresh( + "old-token", + "my-refresh-token", + )), + granted_scopes: vec![], + token_received_at: Some(AuthorizationManager::now_epoch_secs()), + }; + manager.credential_store.save(stored).await.unwrap(); + + manager.refresh_token().await.unwrap(); + + let body = captured.lock().unwrap().take().unwrap(); + let params: std::collections::HashMap<_, _> = url::form_urlencoded::parse(body.as_bytes()) + .into_owned() + .collect(); + assert!( + !params.contains_key("scope"), + "scope should be absent when granted_scopes is empty, body: {body}" + ); + } } diff --git a/docs/OAUTH_SUPPORT.md b/docs/OAUTH_SUPPORT.md index 4b585786e..abaf62b5e 100644 --- a/docs/OAUTH_SUPPORT.md +++ b/docs/OAUTH_SUPPORT.md @@ -95,12 +95,15 @@ cargo run --example oauth-client ## Authorization Flow Description -1. **Metadata Discovery**: Client attempts to get authorization server metadata from `/.well-known/oauth-authorization-server` -2. **Client Registration**: If supported, client dynamically registers itself -3. **Authorization Request**: Build authorization URL with PKCE and guide user to access -4. **Authorization Code Exchange**: After user authorization, exchange authorization code for access token -5. **Token Usage**: Use access token for API calls -6. **Token Refresh**: Automatically use refresh token to get new access token when current one expires +1. **Resource Metadata Discovery**: Client probes the server and extracts `WWW-Authenticate` parameters including `resource_metadata` URL and `scope` +2. **Protected Resource Metadata**: Client fetches resource server metadata (RFC 9728) to find authorization server(s) and supported scopes +3. **AS Metadata Discovery**: Client discovers authorization server metadata via RFC 8414 and OpenID Connect well-known endpoints +4. **Client Registration**: If supported, client dynamically registers itself (or uses URL-based Client ID via SEP-991) +5. **Scope Selection**: SDK picks scopes from WWW-Authenticate > PRM > AS metadata > caller defaults +6. **Authorization Request**: Build authorization URL with PKCE (S256) and RFC 8707 resource parameter +7. **Authorization Code Exchange**: After user authorization, exchange code for access token (with resource parameter) +8. **Token Usage**: Use access token for API calls via `AuthClient` or `AuthorizedHttpClient` +9. **Token Refresh**: Automatically use refresh token to get new access token when current one expires; previously granted scopes are forwarded in the refresh request so providers that require them (e.g. Azure AD v2) work correctly ## Security Considerations @@ -123,4 +126,8 @@ If you encounter authorization issues, check the following: - [MCP Authorization Specification](https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/authorization/) - [OAuth 2.1 Specification Draft](https://oauth.net/2.1/) - [RFC 8414: OAuth 2.0 Authorization Server Metadata](https://datatracker.ietf.org/doc/html/rfc8414) -- [RFC 7591: OAuth 2.0 Dynamic Client Registration Protocol](https://datatracker.ietf.org/doc/html/rfc7591) +- [RFC 7591: OAuth 2.0 Dynamic Client Registration Protocol](https://datatracker.ietf.org/doc/html/rfc7591) +- [RFC 8707: Resource Indicators for OAuth 2.0](https://datatracker.ietf.org/doc/html/rfc8707) +- [RFC 9728: OAuth 2.0 Protected Resource Metadata](https://datatracker.ietf.org/doc/html/rfc9728) +- [RFC 7636: Proof Key for Code Exchange (PKCE)](https://datatracker.ietf.org/doc/html/rfc7636) +- [RFC 6749 ยง6: Refreshing an Access Token](https://www.rfc-editor.org/rfc/rfc6749#section-6)