diff --git a/ratis-netty/src/main/java/org/apache/ratis/netty/server/DataStreamManagement.java b/ratis-netty/src/main/java/org/apache/ratis/netty/server/DataStreamManagement.java index f3b98054d0..4e337e06b2 100644 --- a/ratis-netty/src/main/java/org/apache/ratis/netty/server/DataStreamManagement.java +++ b/ratis-netty/src/main/java/org/apache/ratis/netty/server/DataStreamManagement.java @@ -352,6 +352,7 @@ static DataStreamReplyByteBuffer newDataStreamReplyByteBuffer(DataStreamRequestB .setDataStreamPacket(request) .setBuffer(buffer) .setSuccess(reply.isSuccess()) + .setCommitInfos(reply.getCommitInfos()) .build(); } diff --git a/ratis-netty/src/main/java/org/apache/ratis/netty/server/NettyServerStreamRpc.java b/ratis-netty/src/main/java/org/apache/ratis/netty/server/NettyServerStreamRpc.java index 451040bb62..24303d867e 100644 --- a/ratis-netty/src/main/java/org/apache/ratis/netty/server/NettyServerStreamRpc.java +++ b/ratis-netty/src/main/java/org/apache/ratis/netty/server/NettyServerStreamRpc.java @@ -152,6 +152,7 @@ void close() { private final ChannelFuture channelFuture; private final DataStreamManagement requests; + private final ReadStreamManagement reads; private final ProxiesPool proxies; private final NettyServerStreamRpcMetrics metrics; @@ -162,6 +163,7 @@ public NettyServerStreamRpc(RaftServer server, Parameters parameters) { this.name = server.getId() + "-" + JavaUtils.getClassSimpleName(getClass()); this.metrics = new NettyServerStreamRpcMetrics(this.name); this.requests = new DataStreamManagement(server, metrics); + this.reads = new ReadStreamManagement(server); final RaftProperties properties = server.getProperties(); @@ -235,6 +237,9 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { final DataStreamRequestByteBuf request = (DataStreamRequestByteBuf)msg; try(UncheckedAutoCloseable autoReset = requestRef.set(request)) { + if (reads.process(request, ctx)) { + return; + } requests.read(request, ctx, proxies.get(request)::getDataStreamOutput); } } @@ -248,6 +253,7 @@ public void channelInactive(ChannelHandlerContext ctx) { public void exceptionCaught(ChannelHandlerContext ctx, Throwable throwable) { Optional.ofNullable(requestRef.getAndSetNull()) .ifPresent(request -> requests.replyDataStreamException(throwable, request, ctx)); + ctx.close(); } }; } diff --git a/ratis-netty/src/main/java/org/apache/ratis/netty/server/ReadStreamManagement.java b/ratis-netty/src/main/java/org/apache/ratis/netty/server/ReadStreamManagement.java new file mode 100644 index 0000000000..bcdced1eac --- /dev/null +++ b/ratis-netty/src/main/java/org/apache/ratis/netty/server/ReadStreamManagement.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.ratis.netty.server; + +import org.apache.ratis.datastream.impl.DataStreamReplyByteBuffer; +import org.apache.ratis.datastream.impl.DataStreamRequestByteBuf; +import org.apache.ratis.proto.RaftProtos.DataStreamPacketHeaderProto.Type; +import org.apache.ratis.proto.RaftProtos.RaftClientRequestProto; +import org.apache.ratis.proto.RaftProtos.RaftClientRequestProto.TypeCase; +import org.apache.ratis.protocol.ClientId; +import org.apache.ratis.protocol.RaftClientRequest; +import org.apache.ratis.protocol.exceptions.AlreadyClosedException; +import org.apache.ratis.server.RaftServer; +import org.apache.ratis.thirdparty.com.google.protobuf.InvalidProtocolBufferException; +import org.apache.ratis.thirdparty.io.netty.channel.ChannelFuture; +import org.apache.ratis.thirdparty.io.netty.channel.ChannelHandlerContext; +import org.apache.ratis.util.JavaUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.InterruptedIOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.concurrent.CompletableFuture; + +import static org.apache.ratis.client.impl.ClientProtoUtils.toRaftClientRequest; +import static org.apache.ratis.netty.server.DataStreamManagement.replyDataStreamException; + +public class ReadStreamManagement { + public static final Logger LOG = LoggerFactory.getLogger(ReadStreamManagement.class); + + static class ReadStream implements WritableByteChannel { + private final ClientId clientId; + private final long streamId; + private final ChannelHandlerContext ctx; + private final CompletableFuture closed = new CompletableFuture<>(); + private long streamOffset; + + ReadStream(DataStreamRequestByteBuf request, ChannelHandlerContext ctx) { + this.clientId = request.getClientId(); + this.streamId = request.getStreamId(); + this.ctx = ctx; + } + + @Override + public boolean isOpen() { + return !closed.isDone(); + } + + @Override + public void close() { + closed.complete(null); + } + + @Override + public synchronized int write(ByteBuffer buffer) throws IOException { + if (!isOpen()) { + throw new AlreadyClosedException("Channel closed at offset " + streamOffset); + } + buffer = buffer.asReadOnlyBuffer(); + final int length = buffer.remaining(); + final DataStreamReplyByteBuffer reply = newReply(buffer); + final ChannelFuture future = ctx.writeAndFlush(reply); + try { + future.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new InterruptedIOException( + "Interrupted while writing " + length + " bytes at offset " + streamOffset); + } + if (!future.isSuccess()) { + throw new IOException("Failed to write " + length + " bytes at offset " + streamOffset, future.cause()); + } + streamOffset += length; + return length; + } + + private synchronized DataStreamReplyByteBuffer newReply(ByteBuffer buffer) { + return DataStreamReplyByteBuffer.newBuilder() + .setClientId(clientId) + .setType(Type.STREAM_DATA) + .setStreamId(streamId) + .setStreamOffset(streamOffset) + .setBuffer(buffer) + .setSuccess(true) + .setBytesWritten(buffer.remaining()) + .build(); + } + } + + private final RaftServer server; + private final String name; + + ReadStreamManagement(RaftServer server) { + this.server = server; + this.name = server.getId() + "-" + JavaUtils.getClassSimpleName(getClass()); + } + + boolean process(DataStreamRequestByteBuf requestBuf, ChannelHandlerContext ctx) { + boolean processed = false; + try { + processed = processImpl(requestBuf, ctx); + } catch (Throwable e) { + LOG.error("Failed to process {}", requestBuf, e); + processed = true; + } finally { + if (processed) { + requestBuf.release(); + } + } + return processed; + } + + private boolean processImpl(DataStreamRequestByteBuf requestBuf, ChannelHandlerContext ctx) + throws InvalidProtocolBufferException { + if (requestBuf.getType() != Type.STREAM_HEADER) { + return false; + } + final RaftClientRequest request = toRaftClientRequest( + RaftClientRequestProto.parseFrom(requestBuf.slice().nioBuffer())); + if (!request.is(TypeCase.READ)) { + return false; + } + + final RaftServer.Division division; + try { + division = server.getDivision(request.getRaftGroupId()); + } catch (IOException e) { + replyDataStreamException(server, e, request, requestBuf, ctx); + return true; + } + + final ReadStream stream = new ReadStream(requestBuf, ctx); + division.getStateMachine().data().query(request.getMessage(), stream); + return true; + } + + @Override + public String toString() { + return name; + } +} diff --git a/ratis-server-api/src/main/java/org/apache/ratis/statemachine/StateMachine.java b/ratis-server-api/src/main/java/org/apache/ratis/statemachine/StateMachine.java index 98d4537847..61e708febb 100644 --- a/ratis-server-api/src/main/java/org/apache/ratis/statemachine/StateMachine.java +++ b/ratis-server-api/src/main/java/org/apache/ratis/statemachine/StateMachine.java @@ -116,6 +116,17 @@ default CompletableFuture stream(RaftClientRequest request) { return CompletableFuture.completedFuture(null); } + /** + * Similar to {@link #query(Message)} except that + * {@link #query(Message)} returns the result in a future + * while this method sends the result using the given stream. + * + * @param request the client request + * @param stream the output stream to send the results + */ + default void query(Message request, WritableByteChannel stream) { + } + /** * Link asynchronously the given stream with the given log entry. * The given stream can be null if it is unavailable due to errors. diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java index 989b6cd2b2..fe9c3f9ea1 100644 --- a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java +++ b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java @@ -57,9 +57,11 @@ import org.slf4j.LoggerFactory; import java.io.File; +import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.FileChannel; import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; import java.nio.file.StandardOpenOption; import java.util.ArrayList; import java.util.Collection; @@ -147,8 +149,14 @@ static RoutingTable getRoutingTableChainTopology(Iterable peers, Raf } class MultiDataStreamStateMachine extends BaseStateMachine { + static final int READ_ONLY_STREAM_CHUNKS = 3; + private final ConcurrentMap streams = new ConcurrentHashMap<>(); + static ByteString getReadOnlyStreamChunk(ByteString query, int index) { + return query.concat(ByteString.copyFromUtf8("-chunk-" + index)); + } + @Override public CompletableFuture stream(RaftClientRequest request) { final SingleDataStream s = new SingleDataStream(request); @@ -176,6 +184,34 @@ public CompletableFuture applyTransaction(TransactionContext trx) { return CompletableFuture.completedFuture(() -> bytesWritten); } + @Override + public CompletableFuture query(Message request) { + return CompletableFuture.completedFuture(request); + } + + @Override + public void query(Message request, WritableByteChannel stream) { + CompletableFuture.supplyAsync(() -> { + try { + streamReadOnlyImpl(request, stream); + } catch (IOException e) { + throw new CompletionException("Failed to streamReadOnly for " + request, e); + } + return null; + }); + } + + private void streamReadOnlyImpl(Message request, WritableByteChannel stream) throws IOException { + try { + for (int i = 0; i < READ_ONLY_STREAM_CHUNKS; i++) { + final ByteString chunk = getReadOnlyStreamChunk(request.getContent(), i); + stream.write(chunk.asReadOnlyByteBuffer()); + } + } finally { + stream.close(); + } + } + SingleDataStream getSingleDataStream(RaftClientRequest request) { return getSingleDataStream(ClientInvocationId.valueOf(request)); } diff --git a/ratis-test/src/test/java/org/apache/ratis/netty/server/TestDataStreamManagement.java b/ratis-test/src/test/java/org/apache/ratis/netty/server/TestDataStreamManagement.java index 5c06ddd319..1573a2a283 100644 --- a/ratis-test/src/test/java/org/apache/ratis/netty/server/TestDataStreamManagement.java +++ b/ratis-test/src/test/java/org/apache/ratis/netty/server/TestDataStreamManagement.java @@ -17,34 +17,124 @@ */ package org.apache.ratis.netty.server; +import org.apache.ratis.client.impl.ClientProtoUtils; import org.apache.ratis.client.impl.DataStreamClientImpl.DataStreamOutputImpl; import org.apache.ratis.conf.RaftProperties; +import org.apache.ratis.datastream.impl.DataStreamReplyByteBuffer; import org.apache.ratis.datastream.impl.DataStreamRequestByteBuf; import org.apache.ratis.io.StandardWriteOption; import org.apache.ratis.netty.metrics.NettyServerStreamRpcMetrics; import org.apache.ratis.proto.RaftProtos.DataStreamPacketHeaderProto.Type; import org.apache.ratis.protocol.ClientId; +import org.apache.ratis.protocol.DataStreamReply; +import org.apache.ratis.protocol.Message; import org.apache.ratis.protocol.RaftClientRequest; +import org.apache.ratis.protocol.RaftGroupId; import org.apache.ratis.protocol.RaftPeer; import org.apache.ratis.protocol.RaftPeerId; import org.apache.ratis.server.RaftServer; +import org.apache.ratis.statemachine.StateMachine; +import org.apache.ratis.statemachine.StateMachine.DataApi; +import org.apache.ratis.statemachine.impl.BaseStateMachine; +import org.apache.ratis.thirdparty.com.google.protobuf.ByteString; +import org.apache.ratis.thirdparty.io.netty.buffer.ByteBuf; import org.apache.ratis.thirdparty.io.netty.buffer.Unpooled; import org.apache.ratis.thirdparty.io.netty.channel.ChannelHandlerContext; import org.apache.ratis.thirdparty.io.netty.channel.ChannelId; import org.apache.ratis.thirdparty.io.netty.channel.ChannelInboundHandlerAdapter; import org.apache.ratis.thirdparty.io.netty.channel.embedded.EmbeddedChannel; +import org.apache.ratis.util.JavaUtils; +import org.apache.ratis.util.TimeDuration; import org.apache.ratis.util.function.CheckedBiFunction; import org.junit.jupiter.api.Test; import java.io.IOException; import java.lang.reflect.Proxy; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.ArrayList; import java.util.Collections; +import java.util.List; import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; class TestDataStreamManagement { + @Test + void readOnlyRequestInvokesReadStreamManagement() throws Exception { + final RaftPeerId serverId = RaftPeerId.valueOf("s1"); + final ClientId clientId = ClientId.randomId(); + final RaftGroupId groupId = RaftGroupId.randomId(); + final ByteString query = ByteString.copyFromUtf8("query"); + final ByteString response = ByteString.copyFromUtf8("response"); + + final AtomicReference messageRef = new AtomicReference<>(); + final AtomicReference streamRef = new AtomicReference<>(); + final DataApi dataApi = new DataApi() { + @Override + public void query(Message request, WritableByteChannel stream) { + messageRef.set(request); + streamRef.set(stream); + } + }; + final StateMachine stateMachine = new BaseStateMachine() { + @Override + public DataApi data() { + return dataApi; + } + }; + final RaftServer server = newRaftServer(serverId, new RaftProperties(), groupId, newDivision(stateMachine)); + final ReadStreamManagement management = new ReadStreamManagement(server); + final EmbeddedChannel embeddedChannel = new EmbeddedChannel(new ChannelInboundHandlerAdapter()); + + final RaftClientRequest raftClientRequest = RaftClientRequest.newBuilder() + .setClientId(clientId) + .setServerId(serverId) + .setGroupId(groupId) + .setCallId(1L) + .setMessage(Message.valueOf(query)) + .setType(RaftClientRequest.readRequestType()) + .build(); + final ByteBuffer header = ClientProtoUtils.toRaftClientRequestProtoByteBuffer(raftClientRequest); + final ByteBuf headerBuf = Unpooled.wrappedBuffer(header); + final DataStreamRequestByteBuf request = new DataStreamRequestByteBuf( + clientId, + Type.STREAM_HEADER, + raftClientRequest.getCallId(), + 0L, + Collections.singletonList(StandardWriteOption.FLUSH), + headerBuf); + + try { + assertTrue(management.process(request, embeddedChannel.pipeline().firstContext())); + assertEquals(0, headerBuf.refCnt()); + + final WritableByteChannel stream = streamRef.get(); + assertNotNull(stream); + stream.write(response.asReadOnlyByteBuffer()); + stream.close(); + + final List replies = new ArrayList<>(); + JavaUtils.attempt(() -> { + for (Object outbound; (outbound = embeddedChannel.readOutbound()) != null;) { + replies.add((DataStreamReply) outbound); + } + assertEquals(1, replies.size()); + }, 10, TimeDuration.valueOf(100, TimeUnit.MILLISECONDS), "read-only replies", null); + + assertEquals(query, messageRef.get().getContent()); + assertFalse(streamRef.get().isOpen(), "state machine should close the streaming query channel"); + assertSuccessReply(Type.STREAM_DATA, response.size(), replies.get(0)); + } finally { + embeddedChannel.finishAndReleaseAll(); + } + } + @Test void readCleansChannelMapOnEarlyException() throws Exception { // Scenario: STREAM_DATA arrives without prior STREAM_HEADER, so readImpl fails early. @@ -85,29 +175,62 @@ void readCleansChannelMapOnEarlyException() throws Exception { } } + private static void assertSuccessReply(Type expectedType, long expectedBytesWritten, DataStreamReply reply) { + assertEquals(expectedType, reply.getType()); + assertTrue(reply.isSuccess()); + assertEquals(expectedBytesWritten, reply.getBytesWritten()); + assertTrue(reply instanceof DataStreamReplyByteBuffer); + } + private static RaftServer newRaftServer(RaftPeerId serverId, RaftProperties properties) { - return (RaftServer) Proxy.newProxyInstance(TestDataStreamManagement.class.getClassLoader(), - new Class[]{RaftServer.class}, + return newRaftServer(serverId, properties, null, null); + } + + private static RaftServer newRaftServer(RaftPeerId serverId, RaftProperties properties, + RaftGroupId groupId, RaftServer.Division division) { + return (RaftServer) Proxy.newProxyInstance(RaftServer.class.getClassLoader(), new Class[]{RaftServer.class}, (proxy, method, args) -> { - if (method.getDeclaringClass() == Object.class) { - switch (method.getName()) { - case "toString": - return "RaftServerProxy(" + serverId + ")"; - case "hashCode": - return System.identityHashCode(proxy); - case "equals": - return proxy == args[0]; - default: - return null; + switch (method.getName()) { + case "getId": + return serverId; + case "getProperties": + return properties; + case "getDivision": + if (groupId != null && groupId.equals(args[0])) { + return division; } + throw new IOException("Division not found: " + args[0]); + case "close": + return null; + case "toString": + return serverId.toString(); + case "hashCode": + return System.identityHashCode(proxy); + case "equals": + return proxy == args[0]; + default: + throw new UnsupportedOperationException(method.toString()); } + }); + } + + private static RaftServer.Division newDivision(StateMachine stateMachine) { + return (RaftServer.Division) Proxy.newProxyInstance(RaftServer.Division.class.getClassLoader(), + new Class[]{RaftServer.Division.class}, + (proxy, method, args) -> { switch (method.getName()) { - case "getId": - return serverId; - case "getProperties": - return properties; - default: - throw new UnsupportedOperationException("Unexpected RaftServer call: " + method); + case "getStateMachine": + return stateMachine; + case "close": + return null; + case "toString": + return stateMachine.toString(); + case "hashCode": + return System.identityHashCode(proxy); + case "equals": + return proxy == args[0]; + default: + throw new UnsupportedOperationException(method.toString()); } }); }