Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/proto/h2/upgrade.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ where

if me.h2_tx.capacity() == 0 {
// poll_capacity oddly needs a loop
'capacity: loop {
loop {
match me.h2_tx.poll_capacity(cx) {
Poll::Ready(Some(Ok(0))) => {}
Poll::Ready(Some(Ok(_))) => break,
Expand All @@ -95,7 +95,7 @@ where
"send stream capacity unexpectedly closed",
)));
}
Poll::Pending => break 'capacity,
Poll::Pending => return Poll::Pending,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment from L71-L74 I think is relevant to why this previously did not return early. We want to make sure the waker is registered with each of the futures, so that if one side "cancels", the task can clean up quickly.

  • We want to notice when capacity has become available.
  • Or when the remote has sent a RST_STREAM (or other error)
  • Or when our bytes sender (on the me.rx side) has closed and no longer expects to send more data.

Said another way, if we're waiting for capacity, and the user drops the Upgraded type (meaning they no longer want to write), this UpgradedSendStreamTask will not notice and will hang around until capacity is eventually given (if the peer ever gives it), and only then hang up.

I get what you're trying to do, but I think the types or channels would need to adjusted a little to handle those cases.

}
}
}
Expand Down
277 changes: 277 additions & 0 deletions tests/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2300,6 +2300,283 @@ async fn h2_connect_empty_frames() {
.unwrap();
}

#[tokio::test]
async fn h2_connect_backpressure_respected() {
let (listener, addr) = setup_tcp_listener();
let conn = connect_async(addr).await;

let mut builder = h2::client::Builder::new();
builder.initial_window_size(1024);
builder.initial_connection_window_size(1024);
let (h2, connection) = builder.handshake::<_, Bytes>(conn).await.unwrap();
tokio::spawn(async move {
connection.await.unwrap();
});
let mut h2 = h2.ready().await.unwrap();

const CHUNK: &[u8] = b"backpressure test data chunk!\n";
const TOTAL_LEN: usize = CHUNK.len() * 2000;

let client_handle = tokio::spawn(async move {
let request = Request::connect("localhost").body(()).unwrap();
let (response, _send_stream) = h2.send_request(request, false).unwrap();
let response = response.await.unwrap();
assert_eq!(response.status(), StatusCode::OK);

let mut body = response.into_body();
let mut received = 0usize;

while let Some(chunk) = body.data().await {
let chunk = chunk.unwrap();
if chunk.is_empty() {
break;
}
let len = chunk.len();
received += len;
let _ = body.flow_control().release_capacity(len);
}

assert_eq!(received, TOTAL_LEN);
});

let svc = service_fn(move |req: Request<IncomingBody>| {
let on_upgrade = hyper::upgrade::on(req);

tokio::spawn(async move {
let mut upgraded = TokioIo::new(on_upgrade.await.expect("on_upgrade"));

for _ in 0..2000 {
upgraded.write_all(CHUNK).await.unwrap();
}

upgraded.shutdown().await.unwrap();
});

future::ok::<_, hyper::Error>(
Response::builder()
.status(200)
.body(Empty::<Bytes>::new())
.unwrap(),
)
});

let (socket, _) = listener.accept().await.unwrap();
let socket = TokioIo::new(socket);
http2::Builder::new(TokioExecutor)
.serve_connection(socket, svc)
.await
.unwrap();

client_handle.await.unwrap();
}

#[tokio::test]
async fn h2_connect_zero_window_then_release() {
let (listener, addr) = setup_tcp_listener();
let conn = connect_async(addr).await;

let mut builder = h2::client::Builder::new();
builder.initial_window_size(65535);
let (h2, connection) = builder.handshake::<_, Bytes>(conn).await.unwrap();
tokio::spawn(async move {
connection.await.unwrap();
});
let mut h2 = h2.ready().await.unwrap();

const DATA: &[u8] = b"Hello from upgraded stream";

let client_handle = tokio::spawn(async move {
let request = Request::connect("localhost").body(()).unwrap();
let (response, _send_stream) = h2.send_request(request, false).unwrap();
let response = response.await.unwrap();
assert_eq!(response.status(), StatusCode::OK);

let mut body = response.into_body();
let mut received = Vec::new();

while let Some(chunk) = body.data().await {
let chunk = chunk.unwrap();
if chunk.is_empty() {
break;
}
let len = chunk.len();
received.extend_from_slice(&chunk);
let _ = body.flow_control().release_capacity(len);
}

assert_eq!(&received[..], DATA);
});

let svc = service_fn(move |req: Request<IncomingBody>| {
let on_upgrade = hyper::upgrade::on(req);

tokio::spawn(async move {
let mut upgraded = TokioIo::new(on_upgrade.await.expect("on_upgrade"));
upgraded.write_all(DATA).await.unwrap();
upgraded.shutdown().await.unwrap();
});

future::ok::<_, hyper::Error>(
Response::builder()
.status(200)
.body(Empty::<Bytes>::new())
.unwrap(),
)
});

let (socket, _) = listener.accept().await.unwrap();
let socket = TokioIo::new(socket);
http2::Builder::new(TokioExecutor)
.serve_connection(socket, svc)
.await
.unwrap();

client_handle.await.unwrap();
}

#[tokio::test]
async fn h2_connect_reset_during_backpressure() {
let (listener, addr) = setup_tcp_listener();
let conn = connect_async(addr).await;

let mut builder = h2::client::Builder::new();
builder.initial_window_size(1024);
builder.initial_connection_window_size(1024);
let (h2, connection) = builder.handshake::<_, Bytes>(conn).await.unwrap();
tokio::spawn(async move {
let _ = connection.await;
});
let mut h2 = h2.ready().await.unwrap();

let (write_err_tx, write_err_rx) = oneshot::channel::<bool>();
let write_err_tx = Arc::new(Mutex::new(Some(write_err_tx)));

tokio::spawn(async move {
let request = Request::connect("localhost").body(()).unwrap();
let (response, mut send_stream) = h2.send_request(request, false).unwrap();
let response = response.await.unwrap();
assert_eq!(response.status(), StatusCode::OK);

let mut body = response.into_body();
let bytes = body.data().await.unwrap().unwrap();
let _ = body.flow_control().release_capacity(bytes.len());

send_stream.send_reset(h2::Reason::CANCEL);
drop(body);
drop(send_stream);

let got_err = write_err_rx.await.unwrap_or(false);
assert!(got_err, "server write should have failed after RST_STREAM");
});

let svc = service_fn(move |req: Request<IncomingBody>| {
let on_upgrade = hyper::upgrade::on(req);
let write_err_tx = write_err_tx.clone();

tokio::spawn(async move {
let mut upgraded = TokioIo::new(on_upgrade.await.expect("on_upgrade"));
upgraded.write_all(b"initial").await.unwrap();

let large_data = vec![b'x'; 1024 * 1024];
let write_result = upgraded.write_all(&large_data).await;

if let Some(tx) = write_err_tx.lock().unwrap().take() {
let _ = tx.send(write_result.is_err());
}
});

future::ok::<_, hyper::Error>(
Response::builder()
.status(200)
.body(Empty::<Bytes>::new())
.unwrap(),
)
});

let (socket, _) = listener.accept().await.unwrap();
let socket = TokioIo::new(socket);
let _ = http2::Builder::new(TokioExecutor)
.serve_connection(socket, svc)
.await;
}

#[tokio::test]
async fn h2_connect_backpressure_bidirectional() {
let (listener, addr) = setup_tcp_listener();
let conn = connect_async(addr).await;

let mut builder = h2::client::Builder::new();
builder.initial_window_size(2048);
builder.initial_connection_window_size(4096);
let (h2, connection) = builder.handshake::<_, Bytes>(conn).await.unwrap();
tokio::spawn(async move {
connection.await.unwrap();
});
let mut h2 = h2.ready().await.unwrap();

const PATTERN: &[u8] = b"All work and no bread makes nox a dull boy.\n";
const REPEAT: usize = 500;
let expected_len = PATTERN.len() * REPEAT;

let client_handle = tokio::spawn(async move {
let request = Request::connect("localhost").body(()).unwrap();
let (response, mut send_stream) = h2.send_request(request, false).unwrap();
let response = response.await.unwrap();
assert_eq!(response.status(), StatusCode::OK);

let mut body = response.into_body();
let mut received = 0usize;

while let Some(chunk) = body.data().await {
let chunk = chunk.unwrap();
if chunk.is_empty() {
break;
}
let len = chunk.len();
received += len;
let _ = body.flow_control().release_capacity(len);
}

assert_eq!(received, expected_len);

send_stream.send_data("client done".into(), true).unwrap();
});

let svc = service_fn(move |req: Request<IncomingBody>| {
let on_upgrade = hyper::upgrade::on(req);

tokio::spawn(async move {
let mut upgraded = TokioIo::new(on_upgrade.await.expect("on_upgrade"));

for _ in 0..REPEAT {
upgraded.write_all(PATTERN).await.unwrap();
}

upgraded.shutdown().await.unwrap();

let mut response_buf = vec![0u8; 64];
let n = upgraded.read(&mut response_buf).await.unwrap();
assert_eq!(&response_buf[..n], b"client done");
});

future::ok::<_, hyper::Error>(
Response::builder()
.status(200)
.body(Empty::<Bytes>::new())
.unwrap(),
)
});

let (socket, _) = listener.accept().await.unwrap();
let socket = TokioIo::new(socket);
http2::Builder::new(TokioExecutor)
.serve_connection(socket, svc)
.await
.unwrap();

client_handle.await.unwrap();
}

#[tokio::test]
async fn parse_errors_send_4xx_response() {
let (listener, addr) = setup_tcp_listener();
Expand Down
Loading