Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 64 additions & 2 deletions crates/rmcp/src/transport/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl<'c> AsyncHttpClient<'c> for OAuthReqwestClient {
const DEFAULT_EXCHANGE_URL: &str = "http://localhost";

/// Stored credentials for OAuth2 authorization
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Clone, Serialize, Deserialize)]
pub struct StoredCredentials {
pub client_id: String,
pub token_response: Option<OAuthTokenResponse>,
Expand All @@ -69,6 +69,20 @@ pub struct StoredCredentials {
pub token_received_at: Option<u64>,
}

impl std::fmt::Debug for StoredCredentials {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("StoredCredentials")
.field("client_id", &self.client_id)
.field(
"token_response",
&self.token_response.as_ref().map(|_| "[REDACTED]"),
)
.field("granted_scopes", &self.granted_scopes)
.field("token_received_at", &self.token_received_at)
.finish()
}
}

/// Trait for storing and retrieving OAuth2 credentials
///
/// Implementations of this trait can provide custom storage backends
Expand Down Expand Up @@ -119,13 +133,23 @@ impl CredentialStore for InMemoryCredentialStore {
}

/// Stored authorization state for OAuth2 PKCE flow
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Clone, Serialize, Deserialize)]
pub struct StoredAuthorizationState {
pub pkce_verifier: String,
pub csrf_token: String,
pub created_at: u64,
}

impl std::fmt::Debug for StoredAuthorizationState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("StoredAuthorizationState")
.field("pkce_verifier", &"[REDACTED]")
.field("csrf_token", &"[REDACTED]")
.field("created_at", &self.created_at)
.finish()
}
}

/// A transparent wrapper around a JSON object that captures any extra fields returned by the
/// authorization server during token exchange that are not part of the standard OAuth 2.0 token
/// response.
Expand Down Expand Up @@ -2776,6 +2800,44 @@ mod tests {
assert_eq!(deserialized.csrf_token, "my-csrf");
}

#[test]
fn test_stored_authorization_state_debug_redacts_secrets() {
let pkce = PkceCodeVerifier::new("super-secret-verifier".to_string());
let csrf = CsrfToken::new("super-secret-csrf".to_string());
let state = StoredAuthorizationState::new(&pkce, &csrf);
let debug_output = format!("{:?}", state);

assert!(!debug_output.contains("super-secret-verifier"));
assert!(!debug_output.contains("super-secret-csrf"));
assert!(debug_output.contains("[REDACTED]"));
assert!(debug_output.contains("created_at"));
assert!(debug_output.contains("created_at"));
}

#[test]
fn test_stored_credentials_debug_redacts_token_response() {
use oauth2::{AccessToken, basic::BasicTokenType};

use super::{OAuthTokenResponse, StoredCredentials};

let token_response = OAuthTokenResponse::new(
AccessToken::new("super-secret-access-token".to_string()),
BasicTokenType::Bearer,
VendorExtraTokenFields::default(),
);
let creds = StoredCredentials {
client_id: "my-client".to_string(),
token_response: Some(token_response),
granted_scopes: vec![],
token_received_at: None,
};
let debug_output = format!("{:?}", creds);

assert!(!debug_output.contains("super-secret-access-token"));
assert!(debug_output.contains("[REDACTED]"));
assert!(debug_output.contains("my-client"));
}

#[test]
fn test_stored_authorization_state_into_pkce_verifier() {
let pkce = PkceCodeVerifier::new("original-verifier".to_string());
Expand Down
Loading