Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.rules.SubstitutionRule;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
Expand All @@ -31,6 +32,8 @@
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.druid.sql.calcite.aggregation.builtin.EarliestLatestAnySqlAggregator;

Expand Down Expand Up @@ -92,7 +95,27 @@ public RexNode visitCall(RexCall call)
}
} else if (isCoalesceWhenThen(rexBuilder.getTypeFactory(), caseArgs.get(i), caseArgs.get(i + 1))) {
// WHEN x IS NOT NULL THEN x
coalesceArgs.add(((RexCall) caseArgs.get(i)).getOperands().get(0));

// Use x from the 'when' arg, potentially with a cast to the type of 'then', ignoring nullability.
final RexNode whenIsNotNullArg = ((RexCall) caseArgs.get(i)).getOperands().get(0);
final RexNode thenArg = caseArgs.get(i + 1);
final boolean typesMatch = SqlTypeUtil.equalSansNullability(
rexBuilder.getTypeFactory(),
whenIsNotNullArg.getType(),
thenArg.getType()
);

if (typesMatch) {
coalesceArgs.add(whenIsNotNullArg);
} else {
coalesceArgs.add(
rexBuilder.makeCast(
rexBuilder.getTypeFactory()
.createTypeWithNullability(thenArg.getType(), whenIsNotNullArg.getType().isNullable()),
RexUtil.removeNullabilityCast(rexBuilder.getTypeFactory(), whenIsNotNullArg)
)
);
}
} else {
return super.visitCall(call);
}
Expand All @@ -106,7 +129,8 @@ public RexNode visitCall(RexCall call)
}

/**
* Returns whether "when" is like "then IS NOT NULL". Ignores nullability casts on "then".
* Returns whether "when" is like "then IS NOT NULL". Ignores irrelevant casts, as defined by
* {@link #isIrrelevantCast(RelDataTypeFactory, RexNode, RelDataType)}.
*/
private static boolean isCoalesceWhenThen(
final RelDataTypeFactory typeFactory,
Expand All @@ -115,11 +139,56 @@ private static boolean isCoalesceWhenThen(
)
{
if (when.isA(SqlKind.IS_NOT_NULL)) {
// Remove any casts that don't change the type name. (We don't do anything different during execution based on
// features of the type other than its name, so they can be safely ignored.)
final RexNode whenIsNotNullArg =
RexUtil.removeNullabilityCast(typeFactory, ((RexCall) when).getOperands().get(0));
return whenIsNotNullArg.equals(RexUtil.removeNullabilityCast(typeFactory, then));
removeIrrelevantCasts(typeFactory, ((RexCall) when).getOperands().get(0));
return whenIsNotNullArg.equals(removeIrrelevantCasts(typeFactory, then));
} else {
return false;
}
}

/**
 * Peels casts off the outside of "rexNode" for as long as
 * {@link #isIrrelevantCast(RelDataTypeFactory, RexNode, RelDataType)} reports them as irrelevant, and returns the
 * first expression that is not such a cast. If "rexNode" itself is not an irrelevant cast, it is returned as-is.
 *
 * Every nested cast is judged against the type of the original, outermost "rexNode" rather than against the type of
 * the cast currently being examined.
 */
private static RexNode removeIrrelevantCasts(final RelDataTypeFactory typeFactory, final RexNode rexNode)
{
  // Reference type is captured once, before any unwrapping happens.
  final RelDataType outermostType = rexNode.getType();

  RexNode current = rexNode;
  while (true) {
    if (!isIrrelevantCast(typeFactory, current, outermostType)) {
      return current;
    }
    // isIrrelevantCast only returns true for CAST nodes, so "current" is a call; step into its single operand.
    current = ((RexCall) current).getOperands().get(0);
  }
}

/**
 * Returns whether "rexNode" is a {@link SqlKind#CAST} whose type change does not matter to the CASE-to-COALESCE
 * analysis done by {@link #isCoalesceWhenThen}. This means ignoring nullability, and ignoring type changes that
 * don't affect runtime execution behavior.
 *
 * @param typeFactory type factory used for the nullability-insensitive type comparison
 * @param rexNode     expression to examine; anything other than a CAST yields false
 * @param castType    type that the cast operand's type is compared against
 *
 * @return true if "rexNode" is a CAST and its operand's type is considered equivalent to "castType"
 */
private static boolean isIrrelevantCast(
    final RelDataTypeFactory typeFactory,
    final RexNode rexNode,
    final RelDataType castType
)
{
  if (rexNode.isA(SqlKind.CAST)) {
    final RelDataType operandType = ((RexCall) rexNode).getOperands().get(0).getType();
    final SqlTypeName operandTypeName = operandType.getSqlTypeName();

    // For these type families, we have no difference in runtime behavior that is affected by anything about the
    // type other than its name, so comparing names is sufficient.
    final boolean comparableByNameOnly =
        SqlTypeName.NUMERIC_TYPES.contains(operandTypeName)
        || SqlTypeName.CHAR_TYPES.contains(operandTypeName)
        || SqlTypeName.BOOLEAN_TYPES.contains(operandTypeName)
        || SqlTypeName.DATETIME_TYPES.contains(operandTypeName)
        || SqlTypeName.INTERVAL_TYPES.contains(operandTypeName);

    if (comparableByNameOnly) {
      return operandTypeName == castType.getSqlTypeName();
    }

    // Everything else must match fully, apart from nullability.
    return SqlTypeUtil.equalSansNullability(typeFactory, operandType, castType);
  }

  return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2738,7 +2738,6 @@ public void testGroupByPathSelectorFilter()
@Test
public void testGroupByPathSelectorFilterCoalesce()
{
cannotVectorizeUnlessFallback();
testQuery(
"SELECT "
+ "JSON_VALUE(nest, '$.x'), "
Expand All @@ -2752,16 +2751,14 @@ public void testGroupByPathSelectorFilterCoalesce()
.setVirtualColumns(
new ExpressionVirtualColumn(
"v0",
"case_searched(notnull(\"v1\"),(\"v1\" == '100'),0)",
ColumnType.LONG,
"nvl(\"v1\",'0')",
ColumnType.STRING,
queryFramework().macroTable()
),
new NestedFieldVirtualColumn("nest", "$.x", "v1", ColumnType.STRING)
)
.setDimensions(new DefaultDimensionSpec("v1", "d0"))
.setDimFilter(
expressionFilter("\"v0\"")
)
.setDimFilter(equality("v0", "100", ColumnType.STRING))
.setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt")))
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
Expand Down Expand Up @@ -3051,6 +3048,45 @@ public void testGroupByPathSelectorFilterVariant()
);
}

/**
 * Groups by COALESCE(JSON_VALUE(...), 'unknown') and checks that the native query uses an "nvl" expression virtual
 * column layered on a STRING nested-field virtual column as the grouping dimension.
 */
@Test
public void testGroupByCoalesceJsonValue()
{
  final String sql = "SELECT "
                     + "COALESCE(JSON_VALUE(nest, '$.x'), 'unknown'), "
                     + "SUM(cnt) "
                     + "FROM druid.nested\n"
                     + "GROUP BY 1";

  // One row per distinct coalesced value, with the summed count.
  final ImmutableList<Object[]> expectedRows = ImmutableList.of(
      new Object[]{"100", 2L},
      new Object[]{"200", 1L},
      new Object[]{"unknown", 4L}
  );

  final RowSignature expectedSignature = RowSignature.builder()
                                                     .add("EXPR$0", ColumnType.STRING)
                                                     .add("EXPR$1", ColumnType.LONG)
                                                     .build();

  testQuery(
      sql,
      ImmutableList.of(
          GroupByQuery.builder()
                      .setDataSource(DATA_SOURCE)
                      .setInterval(querySegmentSpec(Filtration.eternity()))
                      .setGranularity(Granularities.ALL)
                      .setVirtualColumns(
                          expressionVirtualColumn("v0", "nvl(\"v1\",'unknown')", ColumnType.STRING),
                          new NestedFieldVirtualColumn("nest", "$.x", "v1", ColumnType.STRING)
                      )
                      .setDimensions(dimensions(new DefaultDimensionSpec("v0", "d0")))
                      .setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt")))
                      .setContext(QUERY_CONTEXT_DEFAULT)
                      .build()
      ),
      expectedRows,
      expectedSignature
  );
}

@Test
public void testGroupByPathSelectorFilterVariant2()
{
Expand Down Expand Up @@ -6683,6 +6719,35 @@ public void testFilterJsonIsNull()
);
}

/**
 * Filters on COALESCE(JSON_VALUE(...), 'unknown') = '200' and checks that the native query applies a STRING
 * equality filter to an "nvl" expression virtual column layered on a STRING nested-field virtual column.
 */
@Test
public void testFilterCoalesceJsonValue()
{
  final String sql = "SELECT "
                     + "SUM(cnt) "
                     + "FROM druid.nested\n"
                     + "WHERE COALESCE(JSON_VALUE(nest, '$.x'), 'unknown') = '200'";

  // A single aggregate row is expected.
  final ImmutableList<Object[]> expectedRows = ImmutableList.of(new Object[]{1L});

  final RowSignature expectedSignature = RowSignature.builder()
                                                     .add("EXPR$0", ColumnType.LONG)
                                                     .build();

  testQuery(
      sql,
      ImmutableList.of(
          Druids.newTimeseriesQueryBuilder()
                .dataSource(DATA_SOURCE)
                .intervals(querySegmentSpec(Filtration.eternity()))
                .granularity(Granularities.ALL)
                .virtualColumns(
                    expressionVirtualColumn("v0", "nvl(\"v1\",'unknown')", ColumnType.STRING),
                    new NestedFieldVirtualColumn("nest", "$.x", "v1", ColumnType.STRING)
                )
                .filters(equality("v0", "200", ColumnType.STRING))
                .aggregators(aggregators(new LongSumAggregatorFactory("a0", "cnt")))
                .context(QUERY_CONTEXT_DEFAULT)
                .build()
      ),
      expectedRows,
      expectedSignature
  );
}

@Test
public void testCoalesceOnNestedColumns()
{
Expand Down
Loading