diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/processor/twostage/exchange/payload/CombineRequest.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/processor/twostage/exchange/payload/CombineRequest.java index 99c8bb67a1a74..ae6f7f9a5fbd4 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/processor/twostage/exchange/payload/CombineRequest.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/processor/twostage/exchange/payload/CombineRequest.java @@ -20,6 +20,7 @@ package org.apache.iotdb.db.pipe.processor.twostage.exchange.payload; import org.apache.iotdb.commons.pipe.sink.payload.thrift.request.IoTDBSinkRequestVersion; +import org.apache.iotdb.db.pipe.processor.twostage.state.CountState; import org.apache.iotdb.db.pipe.processor.twostage.state.State; import org.apache.iotdb.service.rpc.thrift.TPipeTransferReq; @@ -109,7 +110,7 @@ private CombineRequest translateFromTPipeTransferReq(TPipeTransferReq transferRe combineId = ReadWriteIOUtils.readString(transferReq.body); final String stateClassName = ReadWriteIOUtils.readString(transferReq.body); - state = (State) Class.forName(stateClassName).newInstance(); + state = instantiateState(stateClassName); state.deserialize(transferReq.body); version = transferReq.version; @@ -118,6 +119,13 @@ private CombineRequest translateFromTPipeTransferReq(TPipeTransferReq transferRe return this; } + private State instantiateState(final String stateClassName) throws Exception { + if (CountState.class.getName().equals(stateClassName)) { + return new CountState(); + } + throw new IllegalArgumentException("Unexpected state class: " + stateClassName); + } + @Override public String toString() { return "CombineRequest{" diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/PipeDataNodeThriftRequestTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/PipeDataNodeThriftRequestTest.java index d4af2135c95e0..1263e83ae08ec 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/PipeDataNodeThriftRequestTest.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/PipeDataNodeThriftRequestTest.java @@ -22,6 +22,8 @@ import org.apache.iotdb.commons.path.PartialPath; import org.apache.iotdb.commons.pipe.sink.payload.thrift.response.PipeTransferFilePieceResp; import org.apache.iotdb.commons.schema.SchemaConstant; +import org.apache.iotdb.db.pipe.processor.twostage.exchange.payload.CombineRequest; +import org.apache.iotdb.db.pipe.processor.twostage.state.CountState; import org.apache.iotdb.db.pipe.sink.payload.evolvable.request.PipeTransferDataNodeHandshakeV1Req; import org.apache.iotdb.db.pipe.sink.payload.evolvable.request.PipeTransferPlanNodeReq; import org.apache.iotdb.db.pipe.sink.payload.evolvable.request.PipeTransferSchemaSnapshotPieceReq; @@ -43,6 +45,7 @@ import org.apache.iotdb.db.queryengine.plan.statement.Statement; import org.apache.iotdb.db.queryengine.plan.statement.crud.InsertBaseStatement; import org.apache.iotdb.rpc.RpcUtils; +import org.apache.iotdb.service.rpc.thrift.TPipeTransferReq; import org.apache.tsfile.common.conf.TSFileConfig; import org.apache.tsfile.enums.TSDataType; @@ -69,6 +72,61 @@ public class PipeDataNodeThriftRequestTest { private static final String TIME_PRECISION = "ms"; + @Test + public void testCombineRequest() throws Exception { + final CombineRequest req = + CombineRequest.toTPipeTransferReq("pipe", 1L, 2, "combine", new CountState(123L)); + final CombineRequest deserializeReq = CombineRequest.fromTPipeTransferReq(req); + + Assert.assertEquals(req.getVersion(), deserializeReq.getVersion()); + Assert.assertEquals(req.getType(), deserializeReq.getType()); + Assert.assertEquals("pipe", deserializeReq.getPipeName()); + Assert.assertEquals(1L, deserializeReq.getCreationTime()); + Assert.assertEquals(2, deserializeReq.getRegionId()); + Assert.assertEquals("combine", deserializeReq.getCombineId()); + Assert.assertTrue(deserializeReq.getState() instanceof CountState); + Assert.assertEquals(123L, ((CountState) deserializeReq.getState()).getCount()); + } + + @Test + public void testCombineRequestWithUnexpectedStateClassName() throws Exception { + final CombineRequest req = + CombineRequest.toTPipeTransferReq("pipe", 1L, 2, "combine", new CountState(123L)); + + final ByteBuffer bodyBuffer = req.body.duplicate(); + final String pipeName = ReadWriteIOUtils.readString(bodyBuffer); + final long creationTime = ReadWriteIOUtils.readLong(bodyBuffer); + final int regionId = ReadWriteIOUtils.readInt(bodyBuffer); + final String combineId = ReadWriteIOUtils.readString(bodyBuffer); + ReadWriteIOUtils.readString(bodyBuffer); + final long count = ReadWriteIOUtils.readLong(bodyBuffer); + + final ByteBuffer tamperedBody; + try (final PublicBAOS byteArrayOutputStream = new PublicBAOS(); + final DataOutputStream outputStream = new DataOutputStream(byteArrayOutputStream)) { + ReadWriteIOUtils.write(pipeName, outputStream); + ReadWriteIOUtils.write(creationTime, outputStream); + ReadWriteIOUtils.write(regionId, outputStream); + ReadWriteIOUtils.write(combineId, outputStream); + ReadWriteIOUtils.write("java.lang.String", outputStream); + ReadWriteIOUtils.write(count, outputStream); + tamperedBody = + ByteBuffer.wrap(byteArrayOutputStream.getBuf(), 0, byteArrayOutputStream.size()); + } + + final TPipeTransferReq tamperedReq = new TPipeTransferReq(); + tamperedReq.version = req.version; + tamperedReq.type = req.type; + tamperedReq.body = tamperedBody; + + try { + CombineRequest.fromTPipeTransferReq(tamperedReq); + Assert.fail("Expected IllegalArgumentException"); + } catch (final IllegalArgumentException e) { + Assert.assertTrue(e.getMessage().contains("Unexpected state class")); + } + } + @Test public void testPipeTransferDataNodeHandshakeReq() throws IOException { final PipeTransferDataNodeHandshakeV1Req req =