diff --git a/stream/storage/impl/src/main/java/org/apache/bookkeeper/stream/storage/impl/routing/RoutingHeaderProxyInterceptor.java b/stream/storage/impl/src/main/java/org/apache/bookkeeper/stream/storage/impl/routing/RoutingHeaderProxyInterceptor.java index 664a40da71a..d2d9108bb88 100644 --- a/stream/storage/impl/src/main/java/org/apache/bookkeeper/stream/storage/impl/routing/RoutingHeaderProxyInterceptor.java +++ b/stream/storage/impl/src/main/java/org/apache/bookkeeper/stream/storage/impl/routing/RoutingHeaderProxyInterceptor.java @@ -337,8 +337,15 @@ private ReqT interceptMessage( if (null == descriptor) { return message; } else { + byte[] requestBytes = null; try { - return interceptTableRequest(method, descriptor, message, sid, rid, rk); + if (message.getClass() == descriptor.getClz()) { + return interceptTableRequest(method, descriptor, message, sid, rid, rk); + } else { + InputStream is = method.getRequestMarshaller().stream(message); + requestBytes = is.readAllBytes(); + return interceptTableRequest(method, descriptor, requestBytes, sid, rid, rk); + } } catch (Throwable t) { log.error() .attr("streamId", sid) @@ -346,6 +353,9 @@ private ReqT interceptMessage( .attr("routingKey", Hex.encodeHexString(rk)) .exception(t) .log("Failed to intercept table request"); + if (null != requestBytes) { + return method.getRequestMarshaller().parse(new ByteArrayInputStream(requestBytes)); + } return message; } } @@ -363,9 +373,8 @@ private ReqT interceptTableRequest( if (message.getClass() == interceptor.getClz()) { request = (TableReqT) message; } else { - InputStream is = method.getRequestMarshaller().stream(message); - byte[] bytes = is.readAllBytes(); - request = interceptor.getParser().parseFrom(bytes); + return interceptTableRequest(method, interceptor, + method.getRequestMarshaller().stream(message).readAllBytes(), sid, rid, rk); } TableReqT interceptedMessage = interceptor.getInterceptor().intercept( request, sid, rid, rk @@ -378,4 +387,18 @@ private ReqT interceptTableRequest( } } + + private ReqT interceptTableRequest( + MethodDescriptor method, + InterceptorDescriptor interceptor, + byte[] requestBytes, + Long sid, Long rid, byte[] rk + ) { + TableReqT request = interceptor.getParser().parseFrom(requestBytes); + TableReqT interceptedMessage = interceptor.getInterceptor().intercept( + request, sid, rid, rk + ); + byte[] reqBytes = interceptor.getSerializer().toByteArray(interceptedMessage); + return method.getRequestMarshaller().parse(new ByteArrayInputStream(reqBytes)); + } } diff --git a/stream/storage/impl/src/test/java/org/apache/bookkeeper/stream/storage/impl/routing/RoutingHeaderProxyInterceptorTest.java b/stream/storage/impl/src/test/java/org/apache/bookkeeper/stream/storage/impl/routing/RoutingHeaderProxyInterceptorTest.java index 17f2ce93195..88e95368943 100644 --- a/stream/storage/impl/src/test/java/org/apache/bookkeeper/stream/storage/impl/routing/RoutingHeaderProxyInterceptorTest.java +++ b/stream/storage/impl/src/test/java/org/apache/bookkeeper/stream/storage/impl/routing/RoutingHeaderProxyInterceptorTest.java @@ -23,7 +23,9 @@ import static org.apache.bookkeeper.stream.protocol.ProtocolConstants.RID_METADATA_KEY; import static org.apache.bookkeeper.stream.protocol.ProtocolConstants.RK_METADATA_KEY; import static org.apache.bookkeeper.stream.protocol.ProtocolConstants.SID_METADATA_KEY; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; import io.grpc.CallOptions; import io.grpc.Channel; @@ -34,11 +36,14 @@ import io.grpc.MethodDescriptor; import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.stub.StreamObserver; +import java.io.ByteArrayInputStream; +import java.io.InputStream; import java.util.Optional; import java.util.concurrent.atomic.AtomicReference; import lombok.CustomLog; import org.apache.bookkeeper.clients.grpc.GrpcClientTestBase; import org.apache.bookkeeper.clients.impl.channel.StorageServerChannel; +import org.apache.bookkeeper.common.grpc.netty.IdentityInputStreamMarshaller; import org.apache.bookkeeper.stream.proto.kv.rpc.DeleteRangeRequest; import org.apache.bookkeeper.stream.proto.kv.rpc.DeleteRangeResponse; import org.apache.bookkeeper.stream.proto.kv.rpc.IncrementRequest; @@ -49,6 +54,7 @@ import org.apache.bookkeeper.stream.proto.kv.rpc.RangeResponse; import org.apache.bookkeeper.stream.proto.kv.rpc.ResponseHeader; import org.apache.bookkeeper.stream.proto.kv.rpc.RoutingHeader; +import org.apache.bookkeeper.stream.proto.kv.rpc.TableServiceGrpc; import org.apache.bookkeeper.stream.proto.kv.rpc.TableServiceGrpc.TableServiceImplBase; import org.apache.bookkeeper.stream.proto.kv.rpc.TxnRequest; import org.apache.bookkeeper.stream.proto.kv.rpc.TxnResponse; @@ -271,4 +277,85 @@ public void testTxnRequest() throws Exception { assertEquals(expectedRequest.getHeader(), response.getHeader().getRoutingHeader()); } + @Test + public void testForwardOriginalInputStreamWhenInterceptionFails() throws Exception { + byte[] requestBytes = new byte[] { 4, 1, 2, 3 }; + CapturingClientCall delegateCall = new CapturingClientCall(); + RoutingHeaderProxyInterceptor interceptor = new RoutingHeaderProxyInterceptor(); + + ClientCall interceptedCall = interceptor.interceptCall( + proxyTxnMethod(), + CallOptions.DEFAULT, + new CapturingChannel(delegateCall)); + + Metadata headers = new Metadata(); + headers.put(SID_METADATA_KEY, 1026L); + headers.put(RID_METADATA_KEY, 1030L); + headers.put(RK_METADATA_KEY, "txn-key".getBytes(UTF_8)); + + ByteArrayInputStream originalMessage = new ByteArrayInputStream(requestBytes); + interceptedCall.start(new ClientCall.Listener() { }, headers); + interceptedCall.sendMessage(originalMessage); + + assertNotSame(originalMessage, delegateCall.message); + assertArrayEquals(requestBytes, delegateCall.message.readAllBytes()); + } + + private static MethodDescriptor proxyTxnMethod() { + return MethodDescriptor.newBuilder( + IdentityInputStreamMarshaller.of(), + IdentityInputStreamMarshaller.of()) + .setFullMethodName(TableServiceGrpc.getTxnMethod().getFullMethodName()) + .setType(TableServiceGrpc.getTxnMethod().getType()) + .build(); + } + + private static class CapturingChannel extends Channel { + + private final CapturingClientCall call; + + CapturingChannel(CapturingClientCall call) { + this.call = call; + } + + @SuppressWarnings("unchecked") + @Override + public ClientCall newCall( + MethodDescriptor methodDescriptor, + CallOptions callOptions) { + return (ClientCall) call; + } + + @Override + public String authority() { + return "test-authority"; + } + } + + private static class CapturingClientCall extends ClientCall { + + private InputStream message; + + @Override + public void start(Listener responseListener, Metadata headers) { + } + + @Override + public void request(int numMessages) { + } + + @Override + public void cancel(String message, Throwable cause) { + } + + @Override + public void halfClose() { + } + + @Override + public void sendMessage(InputStream message) { + this.message = message; + } + } + }