Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,107 @@ public void testCapacityFunction() {
expectedHeader,
retArray,
DATABASE_NAME);

// CAPACITY with SLIDE=2 (same as SIZE=2, should behave identically to no SLIDE)
expectedHeader = new String[] {"window_index", "time", "stock_id", "price", "s1"};
retArray =
new String[] {
"0,2021-01-01T09:05:00.000Z,AAPL,100.0,101.0,",
"0,2021-01-01T09:07:00.000Z,AAPL,103.0,101.0,",
"1,2021-01-01T09:09:00.000Z,AAPL,102.0,101.0,",
"0,2021-01-01T09:06:00.000Z,TESL,200.0,102.0,",
"0,2021-01-01T09:07:00.000Z,TESL,202.0,202.0,",
"1,2021-01-01T09:15:00.000Z,TESL,195.0,332.0,",
};
tableResultSetEqualTest(
"SELECT * FROM CAPACITY(DATA => bid PARTITION BY stock_id ORDER BY time, SIZE => 2, SLIDE => 2) ORDER BY stock_id, time",
expectedHeader,
retArray,
DATABASE_NAME);

// CAPACITY with SIZE=2, SLIDE=1 (overlapping windows)
expectedHeader = new String[] {"window_index", "time", "stock_id", "price", "s1"};
retArray =
new String[] {
"0,2021-01-01T09:05:00.000Z,AAPL,100.0,101.0,",
"0,2021-01-01T09:07:00.000Z,AAPL,103.0,101.0,",
"1,2021-01-01T09:07:00.000Z,AAPL,103.0,101.0,",
"1,2021-01-01T09:09:00.000Z,AAPL,102.0,101.0,",
"2,2021-01-01T09:09:00.000Z,AAPL,102.0,101.0,",
"0,2021-01-01T09:06:00.000Z,TESL,200.0,102.0,",
"0,2021-01-01T09:07:00.000Z,TESL,202.0,202.0,",
"1,2021-01-01T09:07:00.000Z,TESL,202.0,202.0,",
"1,2021-01-01T09:15:00.000Z,TESL,195.0,332.0,",
"2,2021-01-01T09:15:00.000Z,TESL,195.0,332.0,",
};
tableResultSetEqualTest(
"SELECT * FROM CAPACITY(DATA => bid PARTITION BY stock_id ORDER BY time, SIZE => 2, SLIDE => 1) ORDER BY stock_id, window_index, time",
expectedHeader,
retArray,
DATABASE_NAME);

// CAPACITY with SIZE=3, SLIDE=2 (overlapping windows, different params)
expectedHeader = new String[] {"window_index", "time", "stock_id", "price", "s1"};
retArray =
new String[] {
"0,2021-01-01T09:05:00.000Z,AAPL,100.0,101.0,",
"0,2021-01-01T09:07:00.000Z,AAPL,103.0,101.0,",
"0,2021-01-01T09:09:00.000Z,AAPL,102.0,101.0,",
"1,2021-01-01T09:09:00.000Z,AAPL,102.0,101.0,",
"0,2021-01-01T09:06:00.000Z,TESL,200.0,102.0,",
"0,2021-01-01T09:07:00.000Z,TESL,202.0,202.0,",
"0,2021-01-01T09:15:00.000Z,TESL,195.0,332.0,",
"1,2021-01-01T09:15:00.000Z,TESL,195.0,332.0,",
};
tableResultSetEqualTest(
"SELECT * FROM CAPACITY(DATA => bid PARTITION BY stock_id ORDER BY time, SIZE => 3, SLIDE => 2) ORDER BY stock_id, window_index, time",
expectedHeader,
retArray,
DATABASE_NAME);

// CAPACITY with SIZE=2, SLIDE=3 (gap windows, some rows discarded)
expectedHeader = new String[] {"window_index", "time", "stock_id", "price", "s1"};
retArray =
new String[] {
"0,2021-01-01T09:05:00.000Z,AAPL,100.0,101.0,",
"0,2021-01-01T09:07:00.000Z,AAPL,103.0,101.0,",
"0,2021-01-01T09:06:00.000Z,TESL,200.0,102.0,",
"0,2021-01-01T09:07:00.000Z,TESL,202.0,202.0,",
};
tableResultSetEqualTest(
"SELECT * FROM CAPACITY(DATA => bid PARTITION BY stock_id ORDER BY time, SIZE => 2, SLIDE => 3) ORDER BY stock_id, window_index, time",
expectedHeader,
retArray,
DATABASE_NAME);

// CAPACITY with SIZE=2, SLIDE=1 + GROUP BY (verify aggregation with overlapping windows)
expectedHeader = new String[] {"stock_id", "window_index", "avg"};
retArray =
new String[] {
"AAPL,0,101.5,",
"AAPL,1,102.5,",
"AAPL,2,102.0,",
"TESL,0,201.0,",
"TESL,1,198.5,",
"TESL,2,195.0,",
};
tableResultSetEqualTest(
"SELECT stock_id, window_index, avg(price) as avg FROM CAPACITY(DATA => bid PARTITION BY stock_id ORDER BY time, SIZE => 2, SLIDE => 1) GROUP BY window_index, stock_id ORDER BY stock_id, window_index",
expectedHeader,
retArray,
DATABASE_NAME);

// CAPACITY with negative SLIDE (error case)
tableAssertTestFail(
"SELECT * FROM CAPACITY(DATA => bid PARTITION BY stock_id ORDER BY time, SIZE => 2, SLIDE => -1) ORDER BY stock_id, time",
"Invalid scalar argument SLIDE, should be a positive value",
DATABASE_NAME);

// CAPACITY with SLIDE=0 (error case)
tableAssertTestFail(
"SELECT * FROM CAPACITY(DATA => bid PARTITION BY stock_id ORDER BY time, SIZE => 2, SLIDE => 0) ORDER BY stock_id, time",
"Invalid scalar argument SLIDE, should be a positive value",
DATABASE_NAME);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,12 @@
import java.util.List;
import java.util.Map;

import static org.apache.iotdb.udf.api.relational.table.argument.ScalarArgumentChecker.POSITIVE_LONG_CHECKER;

public class CapacityTableFunction implements TableFunction {
private static final String DATA_PARAMETER_NAME = "DATA";
private static final String SIZE_PARAMETER_NAME = "SIZE";
private static final String SLIDE_PARAMETER_NAME = "SLIDE";

@Override
public List<ParameterSpecification> getArgumentsSpecifications() {
Expand All @@ -53,7 +56,17 @@ public List<ParameterSpecification> getArgumentsSpecifications() {
.name(DATA_PARAMETER_NAME)
.passThroughColumns()
.build(),
ScalarParameterSpecification.builder().name(SIZE_PARAMETER_NAME).type(Type.INT64).build());
ScalarParameterSpecification.builder()
.name(SIZE_PARAMETER_NAME)
.addChecker(POSITIVE_LONG_CHECKER)
.type(Type.INT64)
.build(),
ScalarParameterSpecification.builder()
.name(SLIDE_PARAMETER_NAME)
.addChecker(POSITIVE_LONG_CHECKER)
.type(Type.INT64)
.defaultValue(-1L)
.build());
}

@Override
Expand All @@ -62,8 +75,16 @@ public TableFunctionAnalysis analyze(Map<String, Argument> arguments) throws UDF
if (size <= 0) {
throw new UDFException("Size must be greater than 0");
}
long slide = (long) ((ScalarArgument) arguments.get(SLIDE_PARAMETER_NAME)).getValue();
// default SLIDE to SIZE when not specified (sentinel value -1)
if (slide == -1L) {
slide = size;
}
MapTableFunctionHandle handle =
new MapTableFunctionHandle.Builder().addProperty(SIZE_PARAMETER_NAME, size).build();
new MapTableFunctionHandle.Builder()
.addProperty(SIZE_PARAMETER_NAME, size)
.addProperty(SLIDE_PARAMETER_NAME, slide)
.build();
return TableFunctionAnalysis.builder()
.properColumnSchema(
new DescribedSchema.Builder().addField("window_index", Type.INT64).build())
Expand All @@ -81,52 +102,48 @@ public TableFunctionHandle createTableFunctionHandle() {
@Override
public TableFunctionProcessorProvider getProcessorProvider(
TableFunctionHandle tableFunctionHandle) {
long sz =
(long) ((MapTableFunctionHandle) tableFunctionHandle).getProperty(SIZE_PARAMETER_NAME);
MapTableFunctionHandle handle = (MapTableFunctionHandle) tableFunctionHandle;
long size = (long) handle.getProperty(SIZE_PARAMETER_NAME);
long slide = (long) handle.getProperty(SLIDE_PARAMETER_NAME);
return new TableFunctionProcessorProvider() {
@Override
public TableFunctionDataProcessor getDataProcessor() {
return new CapacityDataProcessor(sz);
return new CapacityDataProcessor(size, slide);
}
};
}

private static class CapacityDataProcessor implements TableFunctionDataProcessor {

private final long size;
private long currentStartIndex = 0;
private final long slide;
private long curIndex = 0;
private long windowIndex = 0;

public CapacityDataProcessor(long size) {
public CapacityDataProcessor(long size, long slide) {
this.size = size;
this.slide = slide;
}

@Override
public void process(
Record input,
List<ColumnBuilder> properColumnBuilders,
ColumnBuilder passThroughIndexBuilder) {
if (curIndex - currentStartIndex == size) {
outputWindow(properColumnBuilders, passThroughIndexBuilder);
currentStartIndex = curIndex;
// For each row at curIndex, find all windows k such that:
// k * slide <= curIndex < k * slide + size, and k >= 0
// The first valid k: max(0, ceil((curIndex - size + 1) / slide))
// The last valid k: floor(curIndex / slide)
long firstWindow = Math.max(0, (curIndex - size + slide) / slide);
long lastWindow = curIndex / slide;
for (long k = firstWindow; k <= lastWindow; k++) {
// Verify: k * slide <= curIndex < k * slide + size
long windowStart = k * slide;
if (windowStart <= curIndex && curIndex < windowStart + size) {
properColumnBuilders.get(0).writeLong(k);
passThroughIndexBuilder.writeLong(curIndex);
}
}
curIndex++;
}

@Override
public void finish(
List<ColumnBuilder> properColumnBuilders, ColumnBuilder passThroughIndexBuilder) {
outputWindow(properColumnBuilders, passThroughIndexBuilder);
}

private void outputWindow(
List<ColumnBuilder> properColumnBuilders, ColumnBuilder passThroughIndexBuilder) {
for (long i = currentStartIndex; i < curIndex; i++) {
properColumnBuilders.get(0).writeLong(windowIndex);
passThroughIndexBuilder.writeLong(i);
}
windowIndex++;
}
}
}
Loading
Loading