diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java index bd44d5db4f6c4..aa635d8288db0 100644 --- a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java @@ -4388,6 +4388,42 @@ public void approxPercentileTest() { DATABASE_NAME); } + @Test + public void percentileTest() { + tableResultSetEqualTest( + "select percentile(time, 0.5),percentile(s1,0.5),percentile(s2,0.5),percentile(s3,0.5),percentile(s4,0.5),percentile(s9,0.5) from table1", + buildHeaders(6), + new String[] {"2024-09-24T06:15:40.000Z,40,43000,37.5,43.0,2024-09-24T06:15:40.000Z,"}, + DATABASE_NAME); + + tableResultSetEqualTest( + "select time,province,percentile(time, 0.5),percentile(s1,0.5),percentile(s2,0.5) from table1 group by 1,2 order by 2,1", + new String[] {"time", "province", "_col2", "_col3", "_col4"}, + new String[] { + "2024-09-24T06:15:30.000Z,beijing,2024-09-24T06:15:30.000Z,30,null,", + "2024-09-24T06:15:31.000Z,beijing,2024-09-24T06:15:31.000Z,null,31000,", + "2024-09-24T06:15:35.000Z,beijing,2024-09-24T06:15:35.000Z,null,35000,", + "2024-09-24T06:15:36.000Z,beijing,2024-09-24T06:15:36.000Z,36,null,", + "2024-09-24T06:15:40.000Z,beijing,2024-09-24T06:15:40.000Z,40,40000,", + "2024-09-24T06:15:41.000Z,beijing,2024-09-24T06:15:41.000Z,41,null,", + "2024-09-24T06:15:46.000Z,beijing,2024-09-24T06:15:46.000Z,null,46000,", + "2024-09-24T06:15:50.000Z,beijing,2024-09-24T06:15:50.000Z,null,50000,", + "2024-09-24T06:15:51.000Z,beijing,2024-09-24T06:15:51.000Z,null,null,", + "2024-09-24T06:15:55.000Z,beijing,2024-09-24T06:15:55.000Z,55,null,", + "2024-09-24T06:15:30.000Z,shanghai,2024-09-24T06:15:30.000Z,30,null,", + 
"2024-09-24T06:15:31.000Z,shanghai,2024-09-24T06:15:31.000Z,null,31000,", + "2024-09-24T06:15:35.000Z,shanghai,2024-09-24T06:15:35.000Z,null,35000,", + "2024-09-24T06:15:36.000Z,shanghai,2024-09-24T06:15:36.000Z,36,null,", + "2024-09-24T06:15:40.000Z,shanghai,2024-09-24T06:15:40.000Z,40,40000,", + "2024-09-24T06:15:41.000Z,shanghai,2024-09-24T06:15:41.000Z,41,null,", + "2024-09-24T06:15:46.000Z,shanghai,2024-09-24T06:15:46.000Z,null,46000,", + "2024-09-24T06:15:50.000Z,shanghai,2024-09-24T06:15:50.000Z,null,50000,", + "2024-09-24T06:15:51.000Z,shanghai,2024-09-24T06:15:51.000Z,null,null,", + "2024-09-24T06:15:55.000Z,shanghai,2024-09-24T06:15:55.000Z,55,null,", + }, + DATABASE_NAME); + } + @Test public void exceptionTest() { tableAssertTestFail( @@ -4478,6 +4514,22 @@ public void exceptionTest() { "select 1 as g, approx_percentile(s1,s2,0.5) from table1 group by 1", "701: Aggregation functions [approx_percentile] do not support weight as INT64 type", DATABASE_NAME); + tableAssertTestFail( + "select percentile() from table1", + "701: Aggregation functions [percentile] should only have two arguments", + DATABASE_NAME); + tableAssertTestFail( + "select percentile(s1,1.1) from table1", + "701: percentage should be in [0,1], got 1.1", + DATABASE_NAME); + tableAssertTestFail( + "select percentile(s1,'test') from table1", + "701: The second argument of 'percentile' function percentage must be a double literal", + DATABASE_NAME); + tableAssertTestFail( + "select percentile(s5,0.5) from table1", + "701: Aggregation functions [percentile] should have value column as numeric type [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", + DATABASE_NAME); } // ================================================================== diff --git a/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/Percentile.java b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/Percentile.java new file mode 100644 index 
0000000000000..5e02d0a1def2c
--- /dev/null
+++ b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/Percentile.java
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.calc.execution.operator.source.relational;
+
+import org.apache.iotdb.commons.exception.SemanticException;
+
+import org.apache.tsfile.utils.RamUsageEstimator;
+import org.apache.tsfile.utils.ReadWriteIOUtils;
+
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+
+/**
+ * Exact percentile calculator that buffers every added value in a growable {@code double[]} and
+ * sorts the buffer lazily when a percentile is requested.
+ *
+ * <p>Every input is stored as a {@code double}, so integral inputs that need more than 53
+ * significant bits (large {@code long} values) are rounded when they are added.
+ */
+public class Percentile {
+  private static final int INITIAL_CAPACITY = 32;
+  private static final double GROWTH_FACTOR = 1.5;
+
+  private double[] values;
+  private int size;
+  private int capacity;
+  private boolean sorted;
+
+  public Percentile() {
+    this.capacity = INITIAL_CAPACITY;
+    this.values = new double[capacity];
+    this.size = 0;
+    this.sorted = true;
+  }
+
+  /** Appends a single value and marks the buffer as unsorted. */
+  public void addValue(double value) {
+    ensureCapacity();
+    values[size++] = value;
+    sorted = false;
+  }
+
+  /**
+   * Appends all given values.
+   *
+   * @param vals values to append; {@code null} or an empty array is a no-op
+   * @throws ArithmeticException if the resulting element count does not fit in an {@code int}
+   */
+  public void addValues(double... vals) {
+    if (vals == null || vals.length == 0) {
+      return;
+    }
+    append(vals, vals.length);
+  }
+
+  /**
+   * Appends every value buffered in {@code other} to this instance.
+   *
+   * @param other instance to merge in; {@code null} or an empty instance is a no-op
+   * @throws ArithmeticException if the resulting element count does not fit in an {@code int}
+   */
+  public void merge(Percentile other) {
+    if (other == null || other.size == 0) {
+      return;
+    }
+    append(other.values, other.size);
+  }
+
+  /**
+   * Returns the value at the given rank, linearly interpolating between the two nearest buffered
+   * values.
+   *
+   * @param percentile requested rank as a fraction in {@code [0, 1]}
+   * @return the interpolated value, or {@link Double#NaN} when no value has been added
+   * @throws SemanticException if {@code percentile} is NaN or lies outside {@code [0, 1]}
+   */
+  public double getPercentile(double percentile) {
+    // Validate before the empty check so a bad argument is reported regardless of the data. The
+    // range test is negated so that NaN, for which every comparison is false, is rejected too.
+    if (!(percentile >= 0.0 && percentile <= 1.0)) {
+      throw new SemanticException("percentage should be in [0,1], got " + percentile);
+    }
+    if (size == 0) {
+      return Double.NaN;
+    }
+
+    ensureSorted();
+
+    if (size == 1) {
+      return values[0];
+    }
+
+    double realIndex = percentile * (size - 1);
+    int index = (int) realIndex;
+    double fraction = realIndex - index;
+
+    if (index >= size - 1) {
+      return values[size - 1];
+    }
+
+    return values[index] + fraction * (values[index + 1] - values[index]);
+  }
+
+  /** Returns the number of buffered values. */
+  public int getSize() {
+    return size;
+  }
+
+  /** Removes all buffered values. */
+  public void clear() {
+    // Shrink the backing array back to the initial capacity so the memory held by a large group is
+    // actually released on reset, instead of staying reserved at the historical peak capacity.
+    if (capacity > INITIAL_CAPACITY) {
+      capacity = INITIAL_CAPACITY;
+      values = new double[capacity];
+    }
+    size = 0;
+    sorted = true;
+  }
+
+  /** Copies the first {@code count} elements of {@code source} to the end of the buffer. */
+  private void append(double[] source, int count) {
+    // addExact reports an int overflow as an ArithmeticException; a wrapped, negative size would
+    // instead skip the growth step and fail inside System.arraycopy.
+    int newSize = Math.addExact(size, count);
+    if (newSize > capacity) {
+      grow(newSize);
+    }
+    System.arraycopy(source, 0, values, size, count);
+    size = newSize;
+    sorted = false;
+  }
+
+  private void ensureCapacity() {
+    if (size >= capacity) {
+      grow(Math.addExact(size, 1));
+    }
+  }
+
+  /** Replaces the backing array with a larger one of at least {@code minCapacity} elements. */
+  private void grow(int minCapacity) {
+    int newCapacity = Math.max((int) (capacity * GROWTH_FACTOR), minCapacity);
+    values = Arrays.copyOf(values, newCapacity);
+    capacity = newCapacity;
+  }
+
+  private void ensureSorted() {
+    if (!sorted && size > 1) {
+      Arrays.sort(values, 0, size);
+      sorted = true;
+    }
+  }
+
+  /** Writes the value count followed by every buffered value to {@code buffer}. */
+  public void serialize(ByteBuffer buffer) {
+    ReadWriteIOUtils.write(size, buffer);
+    for (int i = 0; i < size; i++) {
+      ReadWriteIOUtils.write(values[i], buffer);
+    }
+  }
+
+  /**
+   * Reads an instance previously written by {@link #serialize(ByteBuffer)}.
+   *
+   * @throws IllegalArgumentException if the declared value count is negative or larger than the
+   *     number of values the remaining bytes can hold
+   */
+  public static Percentile deserialize(ByteBuffer buffer) {
+    int size = ReadWriteIOUtils.readInt(buffer);
+    // Reject a corrupt count up front instead of allocating from it or leaving size negative.
+    if (size < 0 || (long) size * Double.BYTES > buffer.remaining()) {
+      throw new IllegalArgumentException(
+          "Invalid Percentile payload: "
+              + size
+              + " values declared, "
+              + buffer.remaining()
+              + " bytes remaining");
+    }
+    Percentile percentile = new Percentile();
+    if (size > percentile.capacity) {
+      percentile.capacity = size;
+      percentile.values = new double[size];
+    }
+    percentile.size = size;
+    for (int i = 0; i < size; i++) {
+      percentile.values[i] = ReadWriteIOUtils.readDouble(buffer);
+    }
+    percentile.sorted = false;
+    return percentile;
+  }
+
+  /**
+   * Returns {@code Integer.BYTES} plus {@code Double.BYTES} per buffered value.
+   *
+   * @throws ArithmeticException if that number does not fit in an {@code int}
+   */
+  public int getSerializedSize() {
+    // The sum is formed in long and narrowed with toIntExact; a plain (int) cast would silently
+    // wrap for large buffers and hand callers a bogus, possibly negative, length.
+    return Math.toIntExact(Integer.BYTES + (long) size * Double.BYTES);
+  }
+
+  /** Returns the shallow instance size plus the bytes reserved by the backing array. */
+  public long getEstimatedSize() {
+    return RamUsageEstimator.shallowSizeOfInstance(Percentile.class)
+        + (long) capacity * Double.BYTES;
+  }
+}
diff --git a/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/AccumulatorFactory.java b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/AccumulatorFactory.java index f4c0d98ce65dc..a6388a135d945 100644 --- a/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/AccumulatorFactory.java +++
b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/AccumulatorFactory.java @@ -50,12 +50,14 @@ import org.apache.iotdb.calc.execution.operator.source.relational.aggregation.grouped.GroupedMinByAccumulator; import org.apache.iotdb.calc.execution.operator.source.relational.aggregation.grouped.GroupedModeAccumulator; import org.apache.iotdb.calc.execution.operator.source.relational.aggregation.grouped.GroupedRegressionAccumulator; +import org.apache.iotdb.calc.execution.operator.source.relational.aggregation.grouped.GroupedPercentileAccumulator; import org.apache.iotdb.calc.execution.operator.source.relational.aggregation.grouped.GroupedSumAccumulator; import org.apache.iotdb.calc.execution.operator.source.relational.aggregation.grouped.GroupedUserDefinedAggregateAccumulator; import org.apache.iotdb.calc.execution.operator.source.relational.aggregation.grouped.GroupedVarianceAccumulator; import org.apache.iotdb.calc.execution.operator.source.relational.aggregation.grouped.IntGroupedApproxMostFrequentAccumulator; import org.apache.iotdb.calc.execution.operator.source.relational.aggregation.grouped.LongGroupedApproxMostFrequentAccumulator; import org.apache.iotdb.calc.i18n.CalcMessages; +import org.apache.iotdb.calc.plan.planner.memory.MemoryReservationManager; import org.apache.iotdb.common.rpc.thrift.TAggregationType; import org.apache.iotdb.commons.queryengine.execution.operator.source.relational.aggregation.grouped.UpdateMemory; import org.apache.iotdb.commons.queryengine.execution.operator.source.relational.aggregation.grouped.hash.MarkDistinctHash; @@ -106,7 +108,8 @@ public static TableAccumulator createAccumulator( boolean isAggTableScan, String timeColumnName, Set measurementColumnNames, - boolean distinct) { + boolean distinct, + MemoryReservationManager memoryReservationManager) { TableAccumulator result; // Input expression size of 1 indicates aggregation split has occurred and this is a final @@ 
-166,7 +169,7 @@ public static TableAccumulator createAccumulator( ? new FirstAccumulator(inputDataTypes.get(0), isAggTableScan) : new FirstDescAccumulator(inputDataTypes.get(0)); } else { - result = createBuiltinAccumulator(aggregationType, inputDataTypes); + result = createBuiltinAccumulator(aggregationType, inputDataTypes, memoryReservationManager); } if (distinct) { @@ -188,7 +191,8 @@ public static GroupedAccumulator createGroupedAccumulator( List inputExpressions, Map inputAttributes, boolean ascending, - boolean distinct) { + boolean distinct, + MemoryReservationManager memoryReservationManager) { GroupedAccumulator result; if (aggregationType == TAggregationType.UDAF) { @@ -197,7 +201,12 @@ public static GroupedAccumulator createGroupedAccumulator( } else { result = createBuiltinGroupedAccumulator( - aggregationType, inputDataTypes, inputExpressions, inputAttributes, ascending); + aggregationType, + inputDataTypes, + inputExpressions, + inputAttributes, + ascending, + memoryReservationManager); } if (distinct) { @@ -242,7 +251,8 @@ private static GroupedAccumulator createBuiltinGroupedAccumulator( List inputDataTypes, List inputExpressions, Map inputAttributes, - boolean ascending) { + boolean ascending, + MemoryReservationManager memoryReservationManager) { switch (aggregationType) { case COUNT: return new GroupedCountAccumulator(); @@ -326,6 +336,8 @@ private static GroupedAccumulator createBuiltinGroupedAccumulator( case KURTOSIS: return new GroupedCentralMomentAccumulator( inputDataTypes.get(0), CentralMomentAccumulator.MomentType.KURTOSIS); + case PERCENTILE: + return new GroupedPercentileAccumulator(inputDataTypes.get(0), memoryReservationManager); default: throw new IllegalArgumentException( CalcMessages.INVALID_AGGREGATION_FUNCTION + aggregationType); @@ -333,7 +345,9 @@ private static GroupedAccumulator createBuiltinGroupedAccumulator( } public static TableAccumulator createBuiltinAccumulator( - TAggregationType aggregationType, List 
inputDataTypes) { + TAggregationType aggregationType, + List inputDataTypes, + MemoryReservationManager memoryReservationManager) { switch (aggregationType) { case COUNT: return new CountAccumulator(); @@ -418,6 +432,8 @@ public static TableAccumulator createBuiltinAccumulator( case KURTOSIS: return new TableCentralMomentAccumulator( inputDataTypes.get(0), CentralMomentAccumulator.MomentType.KURTOSIS); + case PERCENTILE: + return new PercentileAccumulator(inputDataTypes.get(0), memoryReservationManager); default: throw new IllegalArgumentException( CalcMessages.INVALID_AGGREGATION_FUNCTION + aggregationType); diff --git a/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/PercentileAccumulator.java b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/PercentileAccumulator.java new file mode 100644 index 0000000000000..4a6b075d0c4f6 --- /dev/null +++ b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/PercentileAccumulator.java @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.iotdb.calc.execution.operator.source.relational.aggregation;
+
+import org.apache.iotdb.calc.execution.operator.source.relational.Percentile;
+import org.apache.iotdb.calc.plan.planner.memory.MemoryReservationManager;
+import org.apache.iotdb.commons.exception.SemanticException;
+
+import org.apache.tsfile.block.column.Column;
+import org.apache.tsfile.block.column.ColumnBuilder;
+import org.apache.tsfile.enums.TSDataType;
+import org.apache.tsfile.file.metadata.statistics.Statistics;
+import org.apache.tsfile.utils.Binary;
+import org.apache.tsfile.utils.RamUsageEstimator;
+import org.apache.tsfile.utils.ReadWriteIOUtils;
+import org.apache.tsfile.write.UnSupportedDataTypeException;
+
+import java.nio.ByteBuffer;
+
+/**
+ * Non-grouped accumulator for the exact {@code percentile} aggregation. Every non-null input value
+ * is buffered in a {@link Percentile}, and the buffer's estimated size is reported to the {@link
+ * MemoryReservationManager} whenever it changes.
+ */
+public class PercentileAccumulator implements TableAccumulator {
+  private static final long INSTANCE_SIZE =
+      RamUsageEstimator.shallowSizeOfInstance(PercentileAccumulator.class);
+
+  private final TSDataType seriesDataType;
+  private Percentile percentile = new Percentile();
+  // percentage is a query-level constant; it is captured once (from the first raw input, or from
+  // the first intermediate that actually carries values) and kept fixed until reset().
+  private double percentage;
+  private boolean percentageInitialized;
+
+  private final MemoryReservationManager memoryReservationManager;
+  private long previousPercentileSize;
+
+  public PercentileAccumulator(
+      TSDataType seriesDataType, MemoryReservationManager memoryReservationManager) {
+    this.seriesDataType = seriesDataType;
+    this.memoryReservationManager = memoryReservationManager;
+    // Account for the initial, still empty, value buffer.
+    updateMemoryReservation();
+  }
+
+  @Override
+  public long getEstimatedSize() {
+    return INSTANCE_SIZE + percentile.getEstimatedSize();
+  }
+
+  /** Returns a fresh, empty accumulator for the same data type and reservation manager. */
+  @Override
+  public TableAccumulator copy() {
+    return new PercentileAccumulator(seriesDataType, memoryReservationManager);
+  }
+
+  /**
+   * Buffers the selected, non-null values of {@code arguments[0]}.
+   *
+   * @param arguments value column at index 0, percentage column at index 1
+   * @param mask selects the positions of the batch that take part in the aggregation
+   * @throws SemanticException if the number of arguments is not 2
+   * @throws UnSupportedDataTypeException if the value type is not INT32, INT64, TIMESTAMP, FLOAT
+   *     or DOUBLE
+   */
+  @Override
+  public void addInput(Column[] arguments, AggregationMask mask) {
+    if (arguments.length != 2) {
+      throw new SemanticException(
+          String.format("PERCENTILE requires 2 arguments, but got %d", arguments.length));
+    }
+    if (!percentageInitialized) {
+      // NOTE(review): reads row 0 of the percentage column unconditionally; this assumes the
+      // column is readable at position 0 even for an empty batch - confirm with the operator.
+      percentage = arguments[1].getDouble(0);
+      percentageInitialized = true;
+    }
+    switch (seriesDataType) {
+      case INT32:
+        addIntInput(arguments[0], mask);
+        break;
+      case INT64:
+      case TIMESTAMP:
+        addLongInput(arguments[0], mask);
+        break;
+      case FLOAT:
+        addFloatInput(arguments[0], mask);
+        break;
+      case DOUBLE:
+        addDoubleInput(arguments[0], mask);
+        break;
+      default:
+        throw new UnSupportedDataTypeException(
+            String.format("Unsupported data type in Percentile Aggregation: %s", seriesDataType));
+    }
+    updateMemoryReservation();
+  }
+
+  /**
+   * Merges serialized partial states, each laid out as written by {@link
+   * #evaluateIntermediate(ColumnBuilder)}: an 8-byte percentage followed by the serialized {@link
+   * Percentile}. Null positions are skipped.
+   */
+  @Override
+  public void addIntermediate(Column argument) {
+    for (int i = 0; i < argument.getPositionCount(); i++) {
+      if (!argument.isNull(i)) {
+        byte[] data = argument.getBinary(i).getValues();
+        ByteBuffer buffer = ByteBuffer.wrap(data);
+        // Always consume the leading 8 bytes so the buffer position is correct for deserialize.
+        double serializedPercentage = ReadWriteIOUtils.readDouble(buffer);
+        Percentile partial = Percentile.deserialize(buffer);
+        // A partial accumulator that never received input serializes the default percentage 0.0
+        // together with an empty buffer. Adopting that value would pin the percentage to 0.0 and
+        // make every later partial's real percentage be ignored, so only a partial that carries
+        // values may initialize it. Partials without values cannot influence the result anyway.
+        if (!percentageInitialized && partial.getSize() > 0) {
+          percentage = serializedPercentage;
+          percentageInitialized = true;
+        }
+        percentile.merge(partial);
+      }
+    }
+    updateMemoryReservation();
+  }
+
+  /** Writes the percentage followed by the serialized value buffer as one binary value. */
+  @Override
+  public void evaluateIntermediate(ColumnBuilder columnBuilder) {
+    int percentileDataLength = percentile.getSerializedSize();
+    // Use long arithmetic to avoid integer overflow
+    ByteBuffer buffer = ByteBuffer.allocate(Math.toIntExact(8L + percentileDataLength));
+    ReadWriteIOUtils.write(percentage, buffer);
+    percentile.serialize(buffer);
+    columnBuilder.writeBinary(new Binary(buffer.array()));
+  }
+
+  /**
+   * Writes the percentile converted back to the input data type, or null when {@link
+   * Percentile#getPercentile(double)} yields NaN (which it does when no value was buffered).
+   */
+  @Override
+  public void evaluateFinal(ColumnBuilder columnBuilder) {
+    double result = percentile.getPercentile(percentage);
+    if (Double.isNaN(result)) {
+      columnBuilder.appendNull();
+      return;
+    }
+    // NOTE(review): the narrowing casts below truncate an interpolated result toward zero for the
+    // integral types - confirm this is the intended rounding.
+    switch (seriesDataType) {
+      case INT32:
+        columnBuilder.writeInt((int) result);
+        break;
+      case INT64:
+      case TIMESTAMP:
+        columnBuilder.writeLong((long) result);
+        break;
+      case FLOAT:
+        columnBuilder.writeFloat((float) result);
+        break;
+      case DOUBLE:
+        columnBuilder.writeDouble(result);
+        break;
+      default:
+        throw new UnSupportedDataTypeException(
+            String.format("Unsupported data type in PERCENTILE Aggregation: %s", seriesDataType));
+    }
+  }
+
+  @Override
+  public boolean hasFinalResult() {
+    return false;
+  }
+
+  @Override
+  public void addStatistics(Statistics[] statistics) {
+    throw new UnsupportedOperationException("PercentileAccumulator does not support statistics");
+  }
+
+  /** Drops all buffered values and forgets the captured percentage. */
+  @Override
+  public void reset() {
+    percentile = new Percentile();
+    percentageInitialized = false;
+    updateMemoryReservation();
+  }
+
+  /** Reserves or releases the difference between the current and the last reported size. */
+  private void updateMemoryReservation() {
+    long currentSize = percentile.getEstimatedSize();
+    long delta = currentSize - previousPercentileSize;
+    if (delta > 0) {
+      memoryReservationManager.reserveMemoryCumulatively(delta);
+    } else if (delta < 0) {
+      memoryReservationManager.releaseMemoryCumulatively(-delta);
+    }
+    previousPercentileSize = currentSize;
+  }
+
+  private void addIntInput(Column column, AggregationMask mask) {
+    int positionCount = mask.getSelectedPositionCount();
+
+    if (mask.isSelectAll()) {
+      for (int i = 0; i < positionCount; i++) {
+        if (!column.isNull(i)) {
+          percentile.addValue(column.getInt(i));
+        }
+      }
+    } else {
+      int[] selectedPositions = mask.getSelectedPositions();
+      int position;
+      for (int i = 0; i < positionCount; i++) {
+        position = selectedPositions[i];
+        if (!column.isNull(position)) {
+          percentile.addValue(column.getInt(position));
+        }
+      }
+    }
+  }
+
+  private void addLongInput(Column column, AggregationMask mask) {
+    int positionCount = mask.getSelectedPositionCount();
+
+    if (mask.isSelectAll()) {
+      for (int i = 0; i < positionCount; i++) {
+        if (!column.isNull(i)) {
+          percentile.addValue(column.getLong(i));
+        }
+      }
+    } else {
+      int[] selectedPositions = mask.getSelectedPositions();
+      int position;
+      for (int i = 0; i < positionCount; i++) {
+        position = selectedPositions[i];
+        if (!column.isNull(position)) {
+          percentile.addValue(column.getLong(position));
+        }
+      }
+    }
+  }
+
+  private void addFloatInput(Column column, AggregationMask mask) {
+    int positionCount = mask.getSelectedPositionCount();
+
+    if (mask.isSelectAll()) {
+      for (int i = 0; i < positionCount; i++) {
+        if (!column.isNull(i)) {
+          percentile.addValue(column.getFloat(i));
+        }
+      }
+    } else {
+      int[] selectedPositions = mask.getSelectedPositions();
+      int position;
+      for (int i = 0; i < positionCount; i++) {
+        position = selectedPositions[i];
+        if (!column.isNull(position)) {
+          percentile.addValue(column.getFloat(position));
+        }
+      }
+    }
+  }
+
+  private void addDoubleInput(Column column, AggregationMask mask) {
+    int positionCount = mask.getSelectedPositionCount();
+
+    if (mask.isSelectAll()) {
+      for (int i = 0; i < positionCount; i++) {
+        if (!column.isNull(i)) {
+          percentile.addValue(column.getDouble(i));
+        }
+      }
+    } else {
+      int[] selectedPositions = mask.getSelectedPositions();
+      int position;
+      for (int i = 0; i < positionCount; i++) {
+        position = selectedPositions[i];
+        if (!column.isNull(position)) {
+          percentile.addValue(column.getDouble(position));
+        }
+      }
+    }
+  }
+}
diff --git a/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/grouped/GroupedPercentileAccumulator.java b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/grouped/GroupedPercentileAccumulator.java new file mode 100644 index 0000000000000..79b9017d1d2e9 --- /dev/null +++ b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/grouped/GroupedPercentileAccumulator.java @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */
+
+package org.apache.iotdb.calc.execution.operator.source.relational.aggregation.grouped;
+
+import org.apache.iotdb.calc.execution.operator.source.relational.Percentile;
+import org.apache.iotdb.calc.execution.operator.source.relational.aggregation.AggregationMask;
+import org.apache.iotdb.calc.execution.operator.source.relational.aggregation.grouped.array.PercentileBigArray;
+import org.apache.iotdb.calc.plan.planner.memory.MemoryReservationManager;
+import org.apache.iotdb.commons.exception.SemanticException;
+
+import org.apache.tsfile.block.column.Column;
+import org.apache.tsfile.block.column.ColumnBuilder;
+import org.apache.tsfile.enums.TSDataType;
+import org.apache.tsfile.utils.Binary;
+import org.apache.tsfile.utils.RamUsageEstimator;
+import org.apache.tsfile.utils.ReadWriteIOUtils;
+import org.apache.tsfile.write.UnSupportedDataTypeException;
+
+import java.nio.ByteBuffer;
+
+/**
+ * Grouped accumulator for the exact {@code percentile} aggregation. It keeps one {@link
+ * Percentile} per group id in a {@link PercentileBigArray} and reports the array's size to the
+ * {@link MemoryReservationManager} whenever it changes.
+ */
+public class GroupedPercentileAccumulator implements GroupedAccumulator {
+  private static final long INSTANCE_SIZE =
+      RamUsageEstimator.shallowSizeOfInstance(GroupedPercentileAccumulator.class);
+  private final TSDataType seriesDataType;
+  // percentage is a query-level constant shared by all groups; it is captured once (from the first
+  // raw input, or from the first intermediate that actually carries values) and kept fixed until
+  // reset().
+  private double percentage;
+  private boolean percentageInitialized;
+  private final MemoryReservationManager memoryReservationManager;
+  private long previousArraySize;
+  private final PercentileBigArray array = new PercentileBigArray();
+
+  public GroupedPercentileAccumulator(
+      TSDataType seriesDataType, MemoryReservationManager memoryReservationManager) {
+    this.seriesDataType = seriesDataType;
+    this.memoryReservationManager = memoryReservationManager;
+    // Account for the initial, still empty, per-group array.
+    updateMemoryReservation();
+  }
+
+  @Override
+  public long getEstimatedSize() {
+    return INSTANCE_SIZE + array.sizeOf();
+  }
+
+  @Override
+  public void setGroupCount(long groupCount) {
+    array.ensureCapacity(groupCount);
+  }
+
+  /**
+   * Buffers the selected, non-null values of {@code arguments[0]} into the group named by the
+   * matching entry of {@code groupIds}.
+   *
+   * @param groupIds group id per batch position
+   * @param arguments value column at index 0, percentage column at index 1
+   * @param mask selects the positions of the batch that take part in the aggregation
+   * @throws SemanticException if the number of arguments is not 2
+   * @throws UnSupportedDataTypeException if the value type is not INT32, INT64, TIMESTAMP, FLOAT
+   *     or DOUBLE
+   */
+  @Override
+  public void addInput(int[] groupIds, Column[] arguments, AggregationMask mask) {
+    if (arguments.length != 2) {
+      throw new SemanticException(
+          String.format("PERCENTILE requires 2 arguments, but got %d", arguments.length));
+    }
+    if (!percentageInitialized) {
+      // NOTE(review): reads row 0 of the percentage column unconditionally; this assumes the
+      // column is readable at position 0 even for an empty batch - confirm with the operator.
+      percentage = arguments[1].getDouble(0);
+      percentageInitialized = true;
+    }
+
+    switch (seriesDataType) {
+      case INT32:
+        addIntInput(groupIds, arguments, mask);
+        break;
+      case INT64:
+      case TIMESTAMP:
+        addLongInput(groupIds, arguments, mask);
+        break;
+      case FLOAT:
+        addFloatInput(groupIds, arguments, mask);
+        break;
+      case DOUBLE:
+        addDoubleInput(groupIds, arguments, mask);
+        break;
+      default:
+        throw new UnSupportedDataTypeException(
+            String.format("Unsupported data type in PERCENTILE Aggregation: %s", seriesDataType));
+    }
+    updateMemoryReservation();
+  }
+
+  /**
+   * Merges serialized partial states, each laid out as written by {@link
+   * #evaluateIntermediate(int, ColumnBuilder)}: an 8-byte percentage followed by the serialized
+   * {@link Percentile}. Null positions are skipped.
+   */
+  @Override
+  public void addIntermediate(int[] groupIds, Column argument) {
+    for (int i = 0; i < groupIds.length; i++) {
+      int groupId = groupIds[i];
+      if (!argument.isNull(i)) {
+        byte[] data = argument.getBinary(i).getValues();
+        ByteBuffer buffer = ByteBuffer.wrap(data);
+        // Always consume the leading 8 bytes so the buffer position is correct for deserialize.
+        double serializedPercentage = ReadWriteIOUtils.readDouble(buffer);
+        Percentile other = Percentile.deserialize(buffer);
+        // A partial state written before any percentage was captured carries the default 0.0 and
+        // no values. Only a partial that carries values may initialize the percentage, so such a
+        // placeholder can never pin it to 0.0; partials without values cannot change the result.
+        if (!percentageInitialized && other.getSize() > 0) {
+          percentage = serializedPercentage;
+          percentageInitialized = true;
+        }
+        array.get(groupId).merge(other);
+      }
+    }
+    updateMemoryReservation();
+  }
+
+  /** Writes the percentage followed by the group's serialized value buffer as one binary. */
+  @Override
+  public void evaluateIntermediate(int groupId, ColumnBuilder columnBuilder) {
+    Percentile percentile = array.get(groupId);
+    int percentileDataLength = percentile.getSerializedSize();
+    // Use long arithmetic to avoid integer overflow
+    ByteBuffer buffer = ByteBuffer.allocate(Math.toIntExact(8L + percentileDataLength));
+    ReadWriteIOUtils.write(percentage, buffer);
+    percentile.serialize(buffer);
+    columnBuilder.writeBinary(new Binary(buffer.array()));
+  }
+
+  /**
+   * Writes the group's percentile converted back to the input data type, or null when {@link
+   * Percentile#getPercentile(double)} yields NaN (which it does when no value was buffered).
+   */
+  @Override
+  public void evaluateFinal(int groupId, ColumnBuilder columnBuilder) {
+    Percentile percentile = array.get(groupId);
+    double result = percentile.getPercentile(percentage);
+    if (Double.isNaN(result)) {
+      columnBuilder.appendNull();
+      return;
+    }
+    // NOTE(review): the narrowing casts below truncate an interpolated result toward zero for the
+    // integral types - confirm this is the intended rounding.
+    switch (seriesDataType) {
+      case INT32:
+        columnBuilder.writeInt((int) result);
+        break;
+      case INT64:
+      case TIMESTAMP:
+        columnBuilder.writeLong((long) result);
+        break;
+      case FLOAT:
+        columnBuilder.writeFloat((float) result);
+        break;
+      case DOUBLE:
+        columnBuilder.writeDouble(result);
+        break;
+      default:
+        throw new UnSupportedDataTypeException(
+            String.format("Unsupported data type in PERCENTILE Aggregation: %s", seriesDataType));
+    }
+  }
+
+  @Override
+  public void prepareFinal() {}
+
+  /** Clears every group's buffer and forgets the captured percentage. */
+  @Override
+  public void reset() {
+    array.reset();
+    percentageInitialized = false;
+    updateMemoryReservation();
+  }
+
+  /** Reserves or releases the difference between the current and the last reported size. */
+  private void updateMemoryReservation() {
+    long currentSize = array.sizeOf();
+    long delta = currentSize - previousArraySize;
+    if (delta > 0) {
+      memoryReservationManager.reserveMemoryCumulatively(delta);
+    } else if (delta < 0) {
+      memoryReservationManager.releaseMemoryCumulatively(-delta);
+    }
+    previousArraySize = currentSize;
+  }
+
+  public void addIntInput(int[] groupIds, Column[] arguments, AggregationMask mask) {
+    Column valueColumn = arguments[0];
+
+    int positionCount = mask.getSelectedPositionCount();
+
+    if (mask.isSelectAll()) {
+      for (int i = 0; i < positionCount; i++) {
+        int groupId = groupIds[i];
+        Percentile percentile = array.get(groupId);
+        if (!valueColumn.isNull(i)) {
+          percentile.addValue(valueColumn.getInt(i));
+        }
+      }
+    } else {
+      int[] selectedPositions = mask.getSelectedPositions();
+      int position;
+      int groupId;
+      for (int i = 0; i < positionCount; i++) {
+        position = selectedPositions[i];
+        groupId = groupIds[position];
+        Percentile percentile = array.get(groupId);
+        if (!valueColumn.isNull(position)) {
+          percentile.addValue(valueColumn.getInt(position));
+        }
+      }
+    }
+  }
+
+  public void addLongInput(int[] groupIds, Column[] arguments, AggregationMask mask) {
+    Column valueColumn = arguments[0];
+
+    int positionCount = mask.getSelectedPositionCount();
+
+    if (mask.isSelectAll()) {
+      for (int i = 0; i < positionCount; i++) {
+        int groupId = groupIds[i];
+        Percentile percentile = array.get(groupId);
+        if (!valueColumn.isNull(i)) {
+          percentile.addValue(valueColumn.getLong(i));
+        }
+      }
+    } else {
+      int[] selectedPositions = mask.getSelectedPositions();
+      int position;
+      int groupId;
+      for (int i = 0; i < positionCount; i++) {
+        position = selectedPositions[i];
+        groupId = groupIds[position];
+        Percentile percentile = array.get(groupId);
+        if (!valueColumn.isNull(position)) {
+          percentile.addValue(valueColumn.getLong(position));
+        }
+      }
+    }
+  }
+
+  public void addFloatInput(int[] groupIds, Column[] arguments, AggregationMask mask) {
+    Column valueColumn = arguments[0];
+
+    int positionCount = mask.getSelectedPositionCount();
+
+    if (mask.isSelectAll()) {
+      for (int i = 0; i < positionCount; i++) {
+        int groupId = groupIds[i];
+        Percentile percentile = array.get(groupId);
+        if (!valueColumn.isNull(i)) {
+          percentile.addValue(valueColumn.getFloat(i));
+        }
+      }
+    } else {
+      int[] selectedPositions = mask.getSelectedPositions();
+      int position;
+      int groupId;
+      for (int i = 0; i < positionCount; i++) {
+        position = selectedPositions[i];
+        groupId = groupIds[position];
+        Percentile percentile = array.get(groupId);
+        if (!valueColumn.isNull(position)) {
+          percentile.addValue(valueColumn.getFloat(position));
+        }
+      }
+    }
+  }
+
+  public void addDoubleInput(int[] groupIds, Column[] arguments, AggregationMask mask) {
+    Column valueColumn = arguments[0];
+
+    int positionCount = mask.getSelectedPositionCount();
+
+    if (mask.isSelectAll()) {
+      for (int i = 0; i < positionCount; i++) {
+        int groupId = groupIds[i];
+        Percentile percentile = array.get(groupId);
+        if (!valueColumn.isNull(i)) {
+          percentile.addValue(valueColumn.getDouble(i));
+        }
+      }
+    } else {
+      int[] selectedPositions = mask.getSelectedPositions();
+      int position;
+      int groupId;
+      for (int i = 0; i < positionCount; i++) {
+        position = selectedPositions[i];
+        groupId = groupIds[position];
+        Percentile percentile = array.get(groupId);
+        if (!valueColumn.isNull(position)) {
+          percentile.addValue(valueColumn.getDouble(position));
+        }
+      }
+    }
+  }
+}
diff --git a/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/grouped/array/PercentileBigArray.java b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/grouped/array/PercentileBigArray.java new file mode 100644 index 0000000000000..7dad32c43d99e --- /dev/null +++ b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/grouped/array/PercentileBigArray.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.calc.execution.operator.source.relational.aggregation.grouped.array; + +import org.apache.iotdb.calc.execution.operator.source.relational.Percentile; + +import static org.apache.tsfile.utils.RamUsageEstimator.shallowSizeOf; +import static org.apache.tsfile.utils.RamUsageEstimator.shallowSizeOfInstance; + +/** Array of {@link Percentile} instances addressed by a {@code long} index; empty slots are filled lazily by {@link #get(long)}. */ public final class PercentileBigArray { + private static final long INSTANCE_SIZE = shallowSizeOfInstance(PercentileBigArray.class); + /* Element type made explicit: get() and the lambdas below use the elements as Percentile, which a raw ObjectBigArray cannot type-check. */ private final ObjectBigArray<Percentile> array; + + /** Creates an array with no Percentile stored yet. */ public PercentileBigArray() { + array = new ObjectBigArray<>(); + } + + /** + * Unlike fixed-size sketches (e.g. TDigest), each {@link Percentile} stores all raw values and + * grows unboundedly as values are appended through {@link #get(long)}. Caching the retained size + * and only refreshing it on {@code set} would therefore drift far below the real footprint, so we + * sum the live estimated size of every Percentile on demand to keep memory accounting accurate. 
+ */ + public long sizeOf() { + /* single-element holder so the lambda can accumulate into it */ long[] sizeOfPercentile = {0}; + array.forEach( + item -> { + if (item != null) { + sizeOfPercentile[0] += item.getEstimatedSize(); + } + }); + return INSTANCE_SIZE + shallowSizeOf(array) + sizeOfPercentile[0]; + } + + /** Returns the Percentile stored at {@code index}; if the slot is null, a new Percentile is created, stored at that index and returned. */ public Percentile get(long index) { + Percentile percentile = array.get(index); + if (percentile == null) { + percentile = new Percentile(); + array.set(index, percentile); + } + return percentile; + } + + /** Forwards {@code length} to the backing array's {@code ensureCapacity}. */ public void ensureCapacity(long length) { + array.ensureCapacity(length); + } + + /** Calls {@code clear()} on every non-null Percentile; the instances themselves stay in the array. */ public void reset() { + array.forEach( + item -> { + if (item != null) { + item.clear(); + } + }); + } +} diff --git a/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/plan/planner/TableOperatorGenerator.java b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/plan/planner/TableOperatorGenerator.java index a59016030e7c8..13e91af2f197f 100644 --- a/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/plan/planner/TableOperatorGenerator.java +++ b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/plan/planner/TableOperatorGenerator.java @@ -95,6 +95,7 @@ import org.apache.iotdb.calc.execution.operator.source.relational.aggregation.grouped.StreamingHashAggregationOperator; import org.apache.iotdb.calc.execution.relational.ColumnTransformerBuilder; import org.apache.iotdb.calc.i18n.CalcMessages; +import org.apache.iotdb.calc.plan.planner.memory.MemoryReservationManager; import org.apache.iotdb.calc.plan.relational.metadata.ITypeMetadata; import org.apache.iotdb.calc.plan.relational.planner.CastToBlobLiteralVisitor; import org.apache.iotdb.calc.plan.relational.planner.CastToBooleanLiteralVisitor; @@ -1327,7 +1328,8 @@ private Operator planGlobalAggregation( true, false, null, - Collections.emptySet()))); + Collections.emptySet(), + operatorContext.getMemoryReservationContext()))); return new AggregationOperator(operatorContext, child, aggregatorBuilder.build()); } @@ -1341,7 +1343,8 @@ protected 
TableAggregator buildAggregator( boolean scanAscending, boolean isAggTableScan, String timeColumnName, - Set measurementColumnNames) { + Set measurementColumnNames, + MemoryReservationManager memoryReservationManager) { List argumentChannels = new ArrayList<>(); for (Expression argument : aggregation.getArguments()) { Symbol argumentSymbol = Symbol.from(argument); @@ -1364,7 +1367,8 @@ protected TableAggregator buildAggregator( isAggTableScan, timeColumnName, measurementColumnNames, - aggregation.isDistinct()); + aggregation.isDistinct(), + memoryReservationManager); OptionalInt maskChannel = OptionalInt.empty(); if (aggregation.hasMask()) { @@ -1406,7 +1410,8 @@ protected Operator planGroupByAggregation( true, false, null, - Collections.emptySet()))); + Collections.emptySet(), + context.getMemoryReservationManager()))); CommonOperatorContext operatorContext = addOperatorContext( @@ -1428,7 +1433,13 @@ protected Operator planGroupByAggregation( .forEach( (k, v) -> aggregatorBuilder.add( - buildGroupByAggregator(childLayout, k, v, node.getStep(), typeProvider))); + buildGroupByAggregator( + childLayout, + k, + v, + node.getStep(), + typeProvider, + context.getMemoryReservationManager()))); Set preGroupedKeys = ImmutableSet.copyOf(node.getPreGroupedSymbols()); List groupingKeys = node.getGroupingKeys(); @@ -1479,7 +1490,13 @@ protected Operator planGroupByAggregation( .forEach( (k, v) -> aggregatorBuilder.add( - buildGroupByAggregator(childLayout, k, v, node.getStep(), typeProvider))); + buildGroupByAggregator( + childLayout, + k, + v, + node.getStep(), + typeProvider, + context.getMemoryReservationManager()))); CommonOperatorContext operatorContext = addOperatorContext( context, node.getPlanNodeId(), HashAggregationOperator.class.getSimpleName()); @@ -1520,7 +1537,8 @@ protected GroupedAggregator buildGroupByAggregator( Symbol symbol, AggregationNode.Aggregation aggregation, AggregationNode.Step step, - ITableTypeProvider typeProvider) { + ITableTypeProvider 
typeProvider, + MemoryReservationManager memoryReservationManager) { List argumentChannels = new ArrayList<>(); for (Expression argument : aggregation.getArguments()) { Symbol argumentSymbol = Symbol.from(argument); @@ -1540,7 +1558,8 @@ protected GroupedAggregator buildGroupByAggregator( Collections.emptyList(), Collections.emptyMap(), true, - aggregation.isDistinct()); + aggregation.isDistinct(), + memoryReservationManager); OptionalInt maskChannel = OptionalInt.empty(); if (aggregation.hasMask()) { @@ -1635,7 +1654,8 @@ private PatternAggregator buildPatternAggregator( ResolvedFunction resolvedFunction, List> arguments, List argumentChannels, - PatternAggregationTracker patternAggregationTracker) { + PatternAggregationTracker patternAggregationTracker, + MemoryReservationManager memoryReservationManager) { String functionName = resolvedFunction.getSignature().getName(); List originalArgumentTypes = resolvedFunction.getSignature().getArgumentTypes().stream() @@ -1643,7 +1663,10 @@ private PatternAggregator buildPatternAggregator( .collect(Collectors.toList()); TableAccumulator accumulator = - createBuiltinAccumulator(getAggregationTypeByFuncName(functionName), originalArgumentTypes); + createBuiltinAccumulator( + getAggregationTypeByFuncName(functionName), + originalArgumentTypes, + memoryReservationManager); BoundSignature signature = resolvedFunction.getSignature(); @@ -1804,7 +1827,11 @@ public Operator visitPatternRecognition(PatternRecognitionNode node, C context) PatternAggregator variableRecognizerAggregator = buildPatternAggregator( - resolvedFunction, arguments, valueChannels, patternAggregationTracker); + resolvedFunction, + arguments, + valueChannels, + patternAggregationTracker, + context.getMemoryReservationManager()); variableRecognizerAggregatorBuilder.add(variableRecognizerAggregator); @@ -1895,7 +1922,11 @@ public Operator visitPatternRecognition(PatternRecognitionNode node, C context) PatternAggregator measurePatternAggregator = 
buildPatternAggregator( - resolvedFunction, arguments, valueChannels, patternAggregationTracker); + resolvedFunction, + arguments, + valueChannels, + patternAggregationTracker, + context.getMemoryReservationManager()); measurePatternAggregatorBuilder.add(measurePatternAggregator); @@ -2133,7 +2164,12 @@ public Operator visitWindowFunction(WindowNode node, C context) { FunctionKind functionKind = resolvedFunction.getFunctionKind(); if (functionKind == FunctionKind.AGGREGATE) { WindowAggregator tableWindowAggregator = - buildWindowAggregator(symbol, function, typeProvider, argumentChannels); + buildWindowAggregator( + symbol, + function, + typeProvider, + argumentChannels, + context.getMemoryReservationManager()); windowFunction = new AggregationWindowFunction(tableWindowAggregator); } else if (functionKind == FunctionKind.WINDOW) { String functionName = function.getResolvedFunction().getSignature().getName(); @@ -2178,7 +2214,8 @@ private WindowAggregator buildWindowAggregator( Symbol symbol, WindowNode.Function function, ITableTypeProvider typeProvider, - List argumentChannels) { + List argumentChannels, + MemoryReservationManager memoryReservationManager) { // Create accumulator first String functionName = function.getResolvedFunction().getSignature().getName(); List originalArgumentTypes = @@ -2186,7 +2223,10 @@ private WindowAggregator buildWindowAggregator( .map(InternalTypeManager::getTSDataType) .collect(Collectors.toList()); TableAccumulator accumulator = - createBuiltinAccumulator(getAggregationTypeByFuncName(functionName), originalArgumentTypes); + createBuiltinAccumulator( + getAggregationTypeByFuncName(functionName), + originalArgumentTypes, + memoryReservationManager); // Create aggregator by accumulator return new WindowAggregator( diff --git a/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/utils/constant/SqlConstant.java b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/utils/constant/SqlConstant.java index 
9d015cb3ed111..d542528ec1edc 100644 --- a/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/utils/constant/SqlConstant.java +++ b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/utils/constant/SqlConstant.java @@ -88,6 +88,7 @@ protected SqlConstant() { public static final String APPROX_COUNT_DISTINCT = "approx_count_distinct"; public static final String APPROX_MOST_FREQUENT = "approx_most_frequent"; public static final String APPROX_PERCENTILE = "approx_percentile"; + public static final String PERCENTILE = "percentile"; // names of scalar functions public static final String DIFF = "diff"; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/DataNodeTableOperatorGenerator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/DataNodeTableOperatorGenerator.java index b5de9a5561e23..d96ed0845f723 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/DataNodeTableOperatorGenerator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/DataNodeTableOperatorGenerator.java @@ -1463,7 +1463,8 @@ public Operator visitNonAlignedAggregationTreeDeviceViewScan( scanAscending, true, timeColumnName, - measurementColumnsIndexMap.keySet())); + measurementColumnsIndexMap.keySet(), + context.getMemoryReservationManager())); } ITableTimeRangeIterator timeRangeIterator = null; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java index 2f0f12b9f8c32..c023e69a374c7 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java @@ -76,6 +76,8 @@ import 
java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; +import static org.apache.iotdb.calc.plan.relational.metadata.CommonMetadataUtils.isDecimalType; +import static org.apache.iotdb.calc.plan.relational.metadata.CommonMetadataUtils.isNumericType; import static org.apache.iotdb.calc.transformation.dag.column.FailFunctionColumnTransformer.FAIL_FUNCTION_NAME; import static org.apache.tsfile.read.common.type.BlobType.BLOB; import static org.apache.tsfile.read.common.type.BooleanType.BOOLEAN; @@ -217,7 +219,7 @@ public static Type getFunctionType(String functionName, List arg if (TableBuiltinScalarFunction.DIFF.getFunctionName().equalsIgnoreCase(functionName)) { if (!CommonMetadataUtils.isOneNumericType(argumentTypes) && !(argumentTypes.size() == 2 - && CommonMetadataUtils.isNumericType(argumentTypes.get(0)) + && isNumericType(argumentTypes.get(0)) && BOOLEAN.equals(argumentTypes.get(1)))) { throw new SemanticException( "Scalar function " @@ -1265,7 +1267,7 @@ public static Type getFunctionType(String functionName, List arg } Type valueColumnType = argumentTypes.get(0); - if (!CommonMetadataUtils.isNumericType(valueColumnType)) { + if (!isNumericType(valueColumnType)) { throw new SemanticException( String.format( "Aggregation functions [%s] should have value column as numeric type [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", @@ -1273,7 +1275,7 @@ public static Type getFunctionType(String functionName, List arg } Type percentageType = argumentTypes.get(argumentSize - 1); - if (!CommonMetadataUtils.isDecimalType(percentageType)) { + if (!isDecimalType(percentageType)) { throw new SemanticException( String.format( "Aggregation functions [%s] should have percentage as decimal type", @@ -1288,7 +1290,26 @@ public static Type getFunctionType(String functionName, List arg functionName, weightType.getDisplayName())); } } + break; + case SqlConstant.PERCENTILE: + if (argumentTypes.size() != 2) { + throw new SemanticException( + String.format( + 
"Aggregation functions [%s] should only have two arguments", functionName)); + } + if (!isNumericType(argumentTypes.get(0))) { + throw new SemanticException( + String.format( + "Aggregation functions [%s] should have value column as numeric type [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", + functionName)); + } + if (!isDecimalType(argumentTypes.get(1))) { + throw new SemanticException( + String.format( + "Aggregation functions [%s] should have percentage as decimal type", + functionName)); + } break; case SqlConstant.COUNT: break; @@ -1314,6 +1335,7 @@ public static Type getFunctionType(String functionName, List arg case SqlConstant.MAX_BY: case SqlConstant.MIN_BY: case SqlConstant.APPROX_PERCENTILE: + case SqlConstant.PERCENTILE: return argumentTypes.get(0); case SqlConstant.AVG: case SqlConstant.SUM: diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java index a19168b50f42f..6ed3893784456 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java @@ -317,6 +317,7 @@ import static org.apache.iotdb.calc.utils.constant.SqlConstant.APPROX_COUNT_DISTINCT; import static org.apache.iotdb.calc.utils.constant.SqlConstant.APPROX_MOST_FREQUENT; import static org.apache.iotdb.calc.utils.constant.SqlConstant.APPROX_PERCENTILE; +import static org.apache.iotdb.calc.utils.constant.SqlConstant.PERCENTILE; import static org.apache.iotdb.commons.queryengine.plan.relational.sql.ast.AnchorPattern.Type.PARTITION_END; import static org.apache.iotdb.commons.queryengine.plan.relational.sql.ast.AnchorPattern.Type.PARTITION_START; import static org.apache.iotdb.commons.queryengine.plan.relational.sql.ast.GroupingSets.Type.CUBE; @@ -3656,6 
+3657,11 @@ public Node visitFunctionCall(RelationalSqlParser.FunctionCallContext ctx) { throw new SemanticException( "The third argument of 'approx_percentile' function percentage must be a double literal"); } + } else if (name.toString().equalsIgnoreCase(PERCENTILE)) { + if (arguments.size() == 2 && !(arguments.get(1) instanceof DoubleLiteral)) { + throw new SemanticException( + "The second argument of 'percentile' function percentage must be a double literal"); + } } return new FunctionCall(getLocation(ctx), name, window, nulls, distinct, mode, arguments); diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/process/window/function/FunctionTestUtils.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/process/window/function/FunctionTestUtils.java index 9da867894fcd7..17bc22b2a0a8e 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/process/window/function/FunctionTestUtils.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/process/window/function/FunctionTestUtils.java @@ -27,6 +27,7 @@ import org.apache.iotdb.calc.execution.operator.source.relational.aggregation.AccumulatorFactory; import org.apache.iotdb.calc.execution.operator.source.relational.aggregation.TableAccumulator; import org.apache.iotdb.common.rpc.thrift.TAggregationType; +import org.apache.iotdb.db.queryengine.plan.planner.memory.FakedMemoryReservationManager; import org.apache.tsfile.enums.TSDataType; import org.apache.tsfile.read.common.block.TsBlock; @@ -101,7 +102,9 @@ public static AggregationWindowFunction createAggregationWindowFunction( // inputExpressions and inputAttributes are not used in this method TableAccumulator accumulator = AccumulatorFactory.createBuiltinAccumulator( - aggregationType, Collections.singletonList(inputDataType)); + aggregationType, + Collections.singletonList(inputDataType), + new 
FakedMemoryReservationManager()); WindowAggregator aggregator = new WindowAggregator(accumulator, outputDataType, Collections.singletonList(0)); return new AggregationWindowFunction(aggregator); diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/AggregationTableScanTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/AggregationTableScanTest.java index 9dc5ecce8639c..872d755a5aec4 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/AggregationTableScanTest.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/AggregationTableScanTest.java @@ -32,6 +32,7 @@ import org.apache.iotdb.common.rpc.thrift.TAggregationType; import org.apache.iotdb.commons.queryengine.plan.relational.sql.ast.Expression; import org.apache.iotdb.commons.queryengine.plan.relational.sql.ast.SymbolReference; +import org.apache.iotdb.db.queryengine.plan.planner.memory.FakedMemoryReservationManager; import org.apache.tsfile.enums.TSDataType; import org.junit.Test; @@ -161,7 +162,8 @@ private void doCreateAndAssert( isAggTableScan, TIME_COL, measurementColumnNames, - distinct); + distinct, + new FakedMemoryReservationManager()); String msg = String.format( diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/AggregationCornerCaseTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/AggregationCornerCaseTest.java index 424eb75118c79..d40c73dbf6a2f 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/AggregationCornerCaseTest.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/AggregationCornerCaseTest.java @@ -439,7 +439,8 @@ public long ramBytesUsed() { Collections.emptyList(), 
Collections.emptyMap(), true, - false), + false, + operatorContext.getMemoryReservationContext()), AggregationNode.Step.SINGLE, TSDataType.INT32, ImmutableList.of(1, 0), @@ -453,7 +454,8 @@ public long ramBytesUsed() { Collections.emptyList(), Collections.emptyMap(), true, - false), + false, + operatorContext.getMemoryReservationContext()), AggregationNode.Step.SINGLE, TSDataType.INT32, ImmutableList.of(1, 0), @@ -467,7 +469,8 @@ public long ramBytesUsed() { Collections.emptyList(), Collections.emptyMap(), true, - false), + false, + operatorContext.getMemoryReservationContext()), AggregationNode.Step.SINGLE, TSDataType.DOUBLE, ImmutableList.of(1), @@ -481,7 +484,8 @@ public long ramBytesUsed() { Collections.emptyList(), Collections.emptyMap(), true, - false), + false, + operatorContext.getMemoryReservationContext()), AggregationNode.Step.SINGLE, TSDataType.INT32, ImmutableList.of(1), diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java index d1ff1f06a5d84..e8c5806fbc526 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java @@ -66,7 +66,8 @@ public enum TableBuiltinAggregationFunction { REGR_SLOPE("regr_slope"), REGR_INTERCEPT("regr_intercept"), SKEWNESS("skewness"), - KURTOSIS("kurtosis"); + KURTOSIS("kurtosis"), + PERCENTILE("percentile"); private final String functionName; @@ -120,6 +121,7 @@ public static Type getIntermediateType(String name, List originalArgumentT case "kurtosis": case "approx_count_distinct": case "approx_percentile": + case "percentile": return RowType.anonymous(Collections.emptyList()); case "extreme": case "max": diff --git 
a/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift b/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift index c825d8900afea..a107b0259ff09 100644 --- a/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift +++ b/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift @@ -317,6 +317,7 @@ enum TAggregationType { REGR_INTERCEPT, SKEWNESS, KURTOSIS + PERCENTILE, } struct TShowConfigurationTemplateResp {