From 7cff8c54f1a4057f5138e2d99f4ea591b42134e3 Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 10 Mar 2026 16:32:33 +0100 Subject: [PATCH 1/2] Default to http2 for proxy-client to proxy-server connections --- Cargo.lock | 1 + Cargo.toml | 1 + src/lib.rs | 178 ++++++++++++++++++++++++++++++++++++++++++++--------- 3 files changed, 152 insertions(+), 28 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f30b0c1..3124798 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -613,6 +613,7 @@ dependencies = [ "tower-http", "tracing", "tracing-subscriber", + "webpki-roots", "x509-parser", ] diff --git a/Cargo.toml b/Cargo.toml index 305efca..b482c75 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,7 @@ serde = "1.0.228" reqwest = { version = "0.12.24", default-features = false, features = [ "rustls-tls-webpki-roots-no-provider", ] } +webpki-roots = "1.0.4" tracing = "0.1.41" tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] } axum = "0.8.6" diff --git a/src/lib.rs b/src/lib.rs index 58e0ec9..d145349 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,8 +24,10 @@ use thiserror::Error; use tokio::io; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; use tokio::sync::{mpsc, oneshot}; -use tokio_rustls::rustls::server::VerifierBuilderError; -use tokio_rustls::rustls::{ClientConfig, ServerConfig, pki_types::CertificateDer}; +use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier}; +use tokio_rustls::rustls::{ + self, ClientConfig, RootCertStore, ServerConfig, pki_types::CertificateDer, +}; use tracing::{debug, error, warn}; use crate::http_version::{ALPN_H2, ALPN_HTTP11, HttpConnection, HttpSender, HttpVersion}; @@ -59,6 +61,17 @@ type RequestWithResponseSender = ( oneshot::Sender>, hyper::Error>>, ); +/// Adds HTTP 1 and 2 to the list of allowed protocols +fn ensure_proxy_alpn_protocols(alpn_protocols: &mut Vec>) { + for protocol in [ALPN_H2, ALPN_HTTP11] { + let already_present = alpn_protocols.iter().any(|p| p.as_slice() == protocol); + + if !already_present { + alpn_protocols.push(protocol.to_vec()); + } + } +} + /// Retrieve the attested remote TLS certificate. pub async fn get_tls_cert( server_name: String, @@ -101,11 +114,32 @@ impl ProxyServer { attestation_verifier: AttestationVerifier, client_auth: bool, ) -> Result { - let attested_tls_server = AttestedTlsServer::new( - cert_and_key, + let mut server_config = if client_auth { + let root_store = + RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + let verifier = WebPkiClientVerifier::builder(Arc::new(root_store)).build()?; + + ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_client_cert_verifier(verifier) + .with_single_cert( + cert_and_key.cert_chain.clone(), + cert_and_key.key.clone_key(), + )? + } else { + ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_no_client_auth() + .with_single_cert( + cert_and_key.cert_chain.clone(), + cert_and_key.key.clone_key(), + )? + }; + ensure_proxy_alpn_protocols(&mut server_config.alpn_protocols); + + let attested_tls_server = AttestedTlsServer::new_with_tls_config( + cert_and_key.cert_chain, + server_config, attestation_generator, attestation_verifier, - client_auth, )?; let listener = TcpListener::bind(local).await?; @@ -126,16 +160,7 @@ impl ProxyServer { attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, ) -> Result { - for protocol in [ALPN_H2, ALPN_HTTP11] { - let already_present = server_config - .alpn_protocols - .iter() - .any(|p| p.as_slice() == protocol); - - if !already_present { - server_config.alpn_protocols.push(protocol.to_vec()); - } - } + ensure_proxy_alpn_protocols(&mut server_config.alpn_protocols); let attested_tls_server = AttestedTlsServer::new_with_tls_config( cert_chain, @@ -347,11 +372,34 @@ impl ProxyClient { attestation_verifier: AttestationVerifier, remote_certificate: Option>, ) -> Result { - let attested_tls_client = AttestedTlsClient::new( - cert_and_key, + let root_store = match remote_certificate { + Some(remote_certificate) => { + let mut root_store = RootCertStore::empty(); + root_store.add(remote_certificate)?; + root_store + } + None => RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()), + }; + + let mut client_config = if let Some(ref cert_and_key) = cert_and_key { + ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_root_certificates(root_store) + .with_client_auth_cert( + cert_and_key.cert_chain.clone(), + cert_and_key.key.clone_key(), + )? + } else { + ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_root_certificates(root_store) + .with_no_client_auth() + }; + ensure_proxy_alpn_protocols(&mut client_config.alpn_protocols); + + let attested_tls_client = AttestedTlsClient::new_with_tls_config( + client_config, attestation_generator, attestation_verifier, - remote_certificate, + cert_and_key.map(|c| c.cert_chain), )?; Self::new_with_inner(address, attested_tls_client, &server_name).await @@ -366,16 +414,7 @@ impl ProxyClient { attestation_verifier: AttestationVerifier, cert_chain: Option>>, ) -> Result { - for protocol in [ALPN_H2, ALPN_HTTP11] { - let already_present = client_config - .alpn_protocols - .iter() - .any(|p| p.as_slice() == protocol); - - if !already_present { - client_config.alpn_protocols.push(protocol.to_vec()); - } - } + ensure_proxy_alpn_protocols(&mut client_config.alpn_protocols); let attested_tls_client = AttestedTlsClient::new_with_tls_config( client_config, @@ -763,6 +802,89 @@ mod tests { generate_tls_config_with_client_auth, init_tracing, mock_dcap_measurements, }; + #[test] + fn proxy_alpn_protocols_prefer_http2() { + let mut protocols = Vec::new(); + ensure_proxy_alpn_protocols(&mut protocols); + + assert_eq!(protocols, vec![ALPN_H2.to_vec(), ALPN_HTTP11.to_vec()]); + } + + #[test] + fn proxy_alpn_protocols_preserve_existing_order_without_duplicates() { + let mut protocols = vec![ALPN_HTTP11.to_vec(), ALPN_H2.to_vec()]; + ensure_proxy_alpn_protocols(&mut protocols); + + assert_eq!(protocols, vec![ALPN_HTTP11.to_vec(), ALPN_H2.to_vec()]); + } + + #[tokio::test] + async fn http_proxy_default_constructors_work() { + let target_addr = example_http_service().await; + + let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + let server_cert = cert_chain[0].clone(); + + let proxy_server = ProxyServer::new( + TlsCertAndKey { + cert_chain, + key: private_key, + }, + "127.0.0.1:0", + target_addr.to_string(), + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + AttestationVerifier::expect_none(), + false, + ) + .await + .unwrap(); + + let proxy_addr = proxy_server.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_server.accept().await.unwrap(); + }); + + let proxy_client = ProxyClient::new( + None, + "127.0.0.1:0".to_string(), + proxy_addr.to_string(), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::mock(), + Some(server_cert), + ) + .await + .unwrap(); + + let proxy_client_addr = proxy_client.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_client.accept().await.unwrap(); + }); + + let res = reqwest::get(format!("http://{}", proxy_client_addr)) + .await + .unwrap(); + + let headers = res.headers(); + + let attestation_type = headers + .get(ATTESTATION_TYPE_HEADER) + .unwrap() + .to_str() + .unwrap(); + assert_eq!(attestation_type, AttestationType::DcapTdx.as_str()); + + let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); + let measurements = + MultiMeasurements::from_header_format(measurements_json, AttestationType::DcapTdx) + .unwrap(); + assert_eq!(measurements, mock_dcap_measurements()); + + let res_body = res.text().await.unwrap(); + assert_eq!(res_body, "No measurements"); + } + // Server has mock DCAP, client has no attestation and no client auth #[tokio::test] async fn http_proxy_with_server_attestation() { From 9540908374f7346c3cd983df7ba7df54841ad2dd Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 10 Mar 2026 16:41:21 +0100 Subject: [PATCH 2/2] Improve logging --- src/http_version.rs | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/http_version.rs b/src/http_version.rs index 9653cf0..bef817c 100644 --- a/src/http_version.rs +++ b/src/http_version.rs @@ -21,8 +21,12 @@ impl HttpVersion { pub fn from_negotiated_protocol_server(tls: &tokio_rustls::server::TlsStream) -> Self { let (_io, conn) = tls.get_ref(); - let chosen_protocol = Self::from_alpn_bytes(conn.alpn_protocol()); - tracing::debug!("[server] Chosen protocol {chosen_protocol:?}",); + let negotiated_alpn = conn.alpn_protocol(); + let chosen_protocol = Self::from_alpn_bytes(negotiated_alpn); + tracing::debug!( + "[server] Negotiated ALPN {:?}, chosen protocol {chosen_protocol:?}", + negotiated_alpn.map(String::from_utf8_lossy) + ); chosen_protocol } @@ -30,8 +34,12 @@ impl HttpVersion { pub fn from_negotiated_protocol_client(tls: &tokio_rustls::client::TlsStream) -> Self { let (_io, conn) = tls.get_ref(); - let chosen_protocol = Self::from_alpn_bytes(conn.alpn_protocol()); - tracing::debug!("[client] Chosen protocol {chosen_protocol:?}",); + let negotiated_alpn = conn.alpn_protocol(); + let chosen_protocol = Self::from_alpn_bytes(negotiated_alpn); + tracing::debug!( + "[client] Negotiated ALPN {:?}, chosen protocol {chosen_protocol:?}", + negotiated_alpn.map(String::from_utf8_lossy) + ); chosen_protocol }