diff --git a/crates/google-workspace-cli/src/executor.rs b/crates/google-workspace-cli/src/executor.rs index 46f31ac4..e4649297 100644 --- a/crates/google-workspace-cli/src/executor.rs +++ b/crates/google-workspace-cli/src/executor.rs @@ -187,10 +187,7 @@ async fn build_http_request( } } - // Set quota project from ADC for billing/quota attribution - if let Some(quota_project) = crate::auth::get_quota_project() { - request = request.header("x-goog-user-project", quota_project); - } + request = add_quota_project_header(request); let mut all_query_params = input.query_params.clone(); if let Some(pt) = page_token { @@ -384,6 +381,69 @@ async fn handle_binary_response( Ok(None) } +fn extract_download_uri(json_val: &Value) -> Option<&str> { + [ + "/response/downloadUri", + "/response/downloadUrl", + "/metadata/downloadUri", + "/metadata/downloadUrl", + "/downloadUri", + "/downloadUrl", + ] + .into_iter() + .find_map(|path| json_val.pointer(path).and_then(|v| v.as_str())) +} + +fn is_drive_download_operation(json_val: &Value) -> bool { + json_val + .get("kind") + .and_then(Value::as_str) + .is_some_and(|kind| kind == "drive#operation") +} + +fn parse_download_uri_host(uri: &str) -> Option { + let Ok(url) = reqwest::Url::parse(uri) else { + return None; + }; + if url.scheme() != "https" || !url.username().is_empty() || url.password().is_some() { + return None; + }; + url.host_str().map(ToOwned::to_owned) +} + +fn is_google_download_uri(uri: &str) -> bool { + parse_download_uri_host(uri).as_deref().is_some_and(|host| { + host == "googleapis.com" + || host.ends_with(".googleapis.com") + || host.ends_with(".googleusercontent.com") + }) +} + +fn is_google_api_download_host(uri: &str) -> bool { + matches!( + parse_download_uri_host(uri).as_deref(), + Some("googleapis.com" | "www.googleapis.com" | "storage.googleapis.com") + ) +} + +fn extract_google_download_uri(body_text: &str) -> Result, GwsError> { + let Ok(json_val) = serde_json::from_str::(body_text) else { + return Ok(None); + }; + if !is_drive_download_operation(&json_val) { + return Ok(None); + } + let Some(uri) = extract_download_uri(&json_val) else { + return Ok(None); + }; + if !is_google_download_uri(uri) { + return Err(GwsError::Validation( + "Refusing to follow non-Google downloadUri from API response".to_string(), + )); + } + Ok(Some(uri.to_string())) +} + /// Executes an API method call. /// /// This is the core function of the CLI that handles: @@ -464,7 +524,10 @@ pub async fn execute_method( .to_string(); if !status.is_success() { - let error_body = response.text().await.unwrap_or_default(); + let error_body = response + .text() + .await + .context("Failed to read API error response body")?; tracing::warn!( api_method = method_id, http_method = %method.http_method, @@ -495,6 +558,45 @@ pub async fn execute_method( .await .context("Failed to read response body")?; + if output_path.is_some() && method.id.as_deref() == Some("drive.files.download") { + if let Some(download_uri) = extract_google_download_uri(&body_text)? { + let download_request = + build_download_request(&client, &download_uri, token, &auth_method); + let download_response = download_request + .send() + .await + .context("HTTP download request failed")?; + let download_status = download_response.status(); + let download_content_type = download_response + .headers() + .get("content-type") + .and_then(|v| v.to_str().ok()) + .unwrap_or("application/octet-stream") + .to_string(); + + if !download_status.is_success() { + let error_body = download_response + .text() + .await + .context("Failed to read Drive download error response body")?; + return handle_error_response(download_status, &error_body, &auth_method); + } + + if let Some(res) = handle_binary_response( + download_response, + &download_content_type, + output_path, + output_format, + capture_output, + ) + .await? + { + captured_values.push(res); + } + break; + } + } + let should_continue = handle_json_response( &body_text, pagination, @@ -537,6 +639,50 @@ pub async fn execute_method( Ok(None) } +fn add_quota_project_header(request: reqwest::RequestBuilder) -> reqwest::RequestBuilder { + if let Some(quota_project) = crate::auth::get_quota_project() { + request.header("x-goog-user-project", quota_project) + } else { + request + } +} + +fn is_signed_download_uri(download_uri: &str) -> bool { + reqwest::Url::parse(download_uri) + .map(|url| { + url.query_pairs().any(|(key, _)| { + let key = key.as_ref(); + key.eq_ignore_ascii_case("GoogleAccessId") + || key.eq_ignore_ascii_case("Signature") + || key.to_ascii_lowercase().starts_with("x-goog-") + }) + }) + .unwrap_or(false) +} + +fn build_download_request( + client: &reqwest::Client, + download_uri: &str, + token: Option<&str>, + auth_method: &AuthMethod, +) -> reqwest::RequestBuilder { + // Keep secondary Drive downloads on the same reqwest client so they use + // the same client-level configuration as API requests. + let mut request = client.get(download_uri); + let is_signed = is_signed_download_uri(download_uri); + + if !is_signed { + request = add_quota_project_header(request); + if let Some(token) = token { + if *auth_method == AuthMethod::OAuth && is_google_api_download_host(download_uri) { + request = request.bearer_auth(token); + } + } + } + + request +} + fn build_url( doc: &RestDescription, method: &RestMethod, @@ -1209,6 +1355,243 @@ mod tests { assert_ne!(AuthMethod::OAuth, AuthMethod::None); } + #[test] + fn test_extract_download_uri_from_drive_operation_response() { + let operation = json!({ + "done": true, + "response": { + "downloadUri": "https://www.googleapis.com/download/drive/v3/files/abc?alt=media" + } + }); + + assert_eq!( + extract_download_uri(&operation), + Some("https://www.googleapis.com/download/drive/v3/files/abc?alt=media") + ); + } + + #[test] + fn test_extract_google_download_uri_ignores_user_json_file_content() { + let file_content = json!({ + "done": true, + "downloadUri": "https://www.googleapis.com/download/drive/v3/files/abc?alt=media" + }) + .to_string(); + + assert_eq!(extract_google_download_uri(&file_content).unwrap(), None); + } + + #[test] + fn test_extract_google_download_uri_accepts_drive_operation_kind() { + let operation = json!({ + "kind": "drive#operation", + "response": { + "downloadUrl": "https://www.googleapis.com/download/drive/v3/files/abc?alt=media" + } + }) + .to_string(); + + assert_eq!( + extract_google_download_uri(&operation).unwrap(), + Some("https://www.googleapis.com/download/drive/v3/files/abc?alt=media".to_string()) + ); + } + + #[test] + fn test_extract_google_download_uri_rejects_non_google_url() { + let operation = json!({ + "kind": "drive#operation", + "response": { + "downloadUri": "https://example.com/file.csv" + } + }) + .to_string(); + + let err = extract_google_download_uri(&operation).unwrap_err(); + assert!(err.to_string().contains("non-Google downloadUri")); + } + + #[test] + fn test_is_google_download_uri_allows_google_download_hosts() { + assert!(is_google_download_uri("https://googleapis.com/download")); + assert!(is_google_download_uri( + "https://storage.googleapis.com/download/storage/v1/b/bucket/o/file" + )); + assert!(is_google_download_uri( + "https://doc-0k-8s-docs.googleusercontent.com/document/export" + )); + assert!(is_google_download_uri( + "https://attacker-bucket.storage.googleapis.com/file" + )); + assert!(!is_google_download_uri( + "https://storage.googleapis.com.evil.example/file" + )); + } + + #[test] + #[serial_test::serial] + fn test_add_quota_project_header_uses_configured_project() { + unsafe { + std::env::set_var("GOOGLE_WORKSPACE_PROJECT_ID", "quota-project"); + } + + let request = add_quota_project_header(reqwest::Client::new().get("https://example.com")) + .build() + .unwrap(); + + unsafe { + std::env::remove_var("GOOGLE_WORKSPACE_PROJECT_ID"); + } + + assert_eq!( + request + .headers() + .get("x-goog-user-project") + .and_then(|value| value.to_str().ok()), + Some("quota-project") + ); + } + + #[test] + #[serial_test::serial] + fn test_build_download_request_keeps_client_auth_and_quota_headers() { + let client = reqwest::Client::new(); + + unsafe { + std::env::set_var("GOOGLE_WORKSPACE_PROJECT_ID", "quota-project"); + } + + let request = build_download_request( + &client, + "https://www.googleapis.com/download/drive/v3/files/abc?alt=media", + Some("access-token"), + &AuthMethod::OAuth, + ) + .build() + .unwrap(); + + unsafe { + std::env::remove_var("GOOGLE_WORKSPACE_PROJECT_ID"); + } + + assert_eq!( + request + .headers() + .get("x-goog-user-project") + .and_then(|value| value.to_str().ok()), + Some("quota-project") + ); + assert_eq!( + request + .headers() + .get(reqwest::header::AUTHORIZATION) + .and_then(|value| value.to_str().ok()), + Some("Bearer access-token") + ); + } + + #[test] + #[serial_test::serial] + fn test_build_download_request_skips_headers_for_signed_uri() { + let client = reqwest::Client::new(); + + unsafe { + std::env::set_var("GOOGLE_WORKSPACE_PROJECT_ID", "quota-project"); + } + + let request = build_download_request( + &client, + "https://storage.googleapis.com/download/storage/v1/b/bucket/o/file?X-Goog-Signature=sig&X-Goog-Credential=credential", + Some("access-token"), + &AuthMethod::OAuth, + ) + .build() + .unwrap(); + + unsafe { + std::env::remove_var("GOOGLE_WORKSPACE_PROJECT_ID"); + } + + assert!(request.headers().get("x-goog-user-project").is_none()); + assert!(request + .headers() + .get(reqwest::header::AUTHORIZATION) + .is_none()); + } + + #[test] + #[serial_test::serial] + fn test_build_download_request_sends_bearer_to_exact_storage_host() { + let client = reqwest::Client::new(); + + unsafe { + std::env::set_var("GOOGLE_WORKSPACE_PROJECT_ID", "quota-project"); + } + + let request = build_download_request( + &client, + "https://storage.googleapis.com/download/storage/v1/b/bucket/o/file", + Some("access-token"), + &AuthMethod::OAuth, + ) + .build() + .unwrap(); + + unsafe { + std::env::remove_var("GOOGLE_WORKSPACE_PROJECT_ID"); + } + + assert_eq!( + request + .headers() + .get("x-goog-user-project") + .and_then(|value| value.to_str().ok()), + Some("quota-project") + ); + assert_eq!( + request + .headers() + .get(reqwest::header::AUTHORIZATION) + .and_then(|value| value.to_str().ok()), + Some("Bearer access-token") + ); + } + + #[test] + #[serial_test::serial] + fn test_build_download_request_never_sends_bearer_to_storage_subdomain() { + let client = reqwest::Client::new(); + + unsafe { + std::env::set_var("GOOGLE_WORKSPACE_PROJECT_ID", "quota-project"); + } + + let request = build_download_request( + &client, + "https://attacker-bucket.storage.googleapis.com/file", + Some("access-token"), + &AuthMethod::OAuth, + ) + .build() + .unwrap(); + + unsafe { + std::env::remove_var("GOOGLE_WORKSPACE_PROJECT_ID"); + } + + assert_eq!( + request + .headers() + .get("x-goog-user-project") + .and_then(|value| value.to_str().ok()), + Some("quota-project") + ); + assert!(request + .headers() + .get(reqwest::header::AUTHORIZATION) + .is_none()); + } + #[test] fn test_mime_to_extension_more_types() { assert_eq!(mime_to_extension("text/plain"), "txt");