diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/SubQueryDecorrelator.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/SubQueryDecorrelator.java index 04832a8079839..9ac57e63973ca 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/SubQueryDecorrelator.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/SubQueryDecorrelator.java @@ -532,9 +532,20 @@ public Frame decorrelateRel(LogicalFilter rel) { unsupportedCorConditions); assert unsupportedCorConditions.isEmpty(); - final RexNode remainingCondition = + RexNode remainingCondition = RexUtil.composeConjunction(rexBuilder, nonCorConditions, false); + // Re-index the remaining (non-correlated) condition against the rewritten input. + // The child may have shifted its row type during decorrelation (e.g. an Aggregate + // injects correlated columns into its group key), so RexInputRefs in HAVING / + // Filter predicates that survive in nonCorConditions must be remapped through + // frame.oldToNewOutputs. Otherwise they silently point at the wrong column. + if (remainingCondition != null) { + remainingCondition = + adjustInputRefs( + remainingCondition, frame.oldToNewOutputs, frame.r.getRowType()); + } + // Using LogicalFilter.create instead of RelBuilder.filter to create Filter // because RelBuilder.filter method does not have VariablesSet arg. final RelNode newFilter = diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQuerySemiJoinTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQuerySemiJoinTest.xml index 831d380b7af38..47b74cfa36e2b 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQuerySemiJoinTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQuerySemiJoinTest.xml @@ -366,6 +366,35 @@ LogicalProject(a=[$0], b=[$1], c=[$2]) +- LogicalProject(e=[$1], f=[$2]) +- LogicalFilter(condition=[true]) +- LogicalTableScan(table=[[default_catalog, default_database, r]]) +]]> + + + + + 10 GROUP BY r.f)]]> + + + ($1, 10))]) + LogicalTableScan(table=[[default_catalog, default_database, r]]) +})], variablesSet=[[$cor0]]) + +- LogicalTableScan(table=[[default_catalog, default_database, l]]) +]]> + + + ($1, 10)]) + +- LogicalTableScan(table=[[default_catalog, default_database, r]]) ]]> @@ -446,6 +475,130 @@ LogicalProject(a=[$0], b=[$1]) +- LogicalProject(d=[$1]) +- LogicalFilter(condition=[true]) +- LogicalTableScan(table=[[default_catalog, default_database, y]]) +]]> + + + + + = 3)]]> + + + =($1, 3)]) + LogicalAggregate(group=[{0}], agg#0=[SUM($1)]) + LogicalProject(f=[$2], e=[$1]) + LogicalFilter(condition=[=($cor0.a, $0)]) + LogicalTableScan(table=[[default_catalog, default_database, r]]) +})], variablesSet=[[$cor0]]) + +- LogicalTableScan(table=[[default_catalog, default_database, l]]) +]]> + + + =($2, 3)]) + +- LogicalAggregate(group=[{0, 1}], agg#0=[SUM($2)]) + +- LogicalProject(f=[$2], d=[$0], e=[$1]) + +- LogicalFilter(condition=[true]) + +- LogicalTableScan(table=[[default_catalog, default_database, r]]) +]]> + + + + + = 3 AND MAX(r.e) < 100)]]> + + + =($1, 3), <($2, 100))]) + LogicalAggregate(group=[{0}], agg#0=[SUM($1)], agg#1=[MAX($1)]) + LogicalProject(f=[$2], e=[$1]) + LogicalFilter(condition=[=($cor0.a, $0)]) + LogicalTableScan(table=[[default_catalog, default_database, r]]) +})], variablesSet=[[$cor0]]) + +- LogicalTableScan(table=[[default_catalog, default_database, l]]) +]]> + + + =($2, 3), <($3, 100))]) + +- LogicalAggregate(group=[{0, 1}], agg#0=[SUM($2)], agg#1=[MAX($2)]) + +- LogicalProject(f=[$2], d=[$0], e=[$1]) + +- LogicalFilter(condition=[true]) + +- LogicalTableScan(table=[[default_catalog, default_database, r]]) +]]> + + + + + = 3 AND COUNT(*) > 1)]]> + + + =($1, 3), >($2, 1))]) + LogicalAggregate(group=[{0}], agg#0=[SUM($1)], agg#1=[COUNT()]) + LogicalProject(f=[$2], e=[$1]) + LogicalFilter(condition=[=($cor0.a, $0)]) + LogicalTableScan(table=[[default_catalog, default_database, r]]) +})], variablesSet=[[$cor0]]) + +- LogicalTableScan(table=[[default_catalog, default_database, l]]) +]]> + + + =($2, 3), >($3, 1))]) + +- LogicalAggregate(group=[{0, 1}], agg#0=[SUM($2)], agg#1=[COUNT()]) + +- LogicalProject(f=[$2], d=[$0], e=[$1]) + +- LogicalFilter(condition=[true]) + +- LogicalTableScan(table=[[default_catalog, default_database, r]]) +]]> + + + + + = 2)]]> + + + =($1, 2)]) + LogicalAggregate(group=[{0}], agg#0=[COUNT($1)]) + LogicalProject(f=[$2], d=[$0]) + LogicalFilter(condition=[AND(=($cor0.a, $0), =($cor0.b, $1))]) + LogicalTableScan(table=[[default_catalog, default_database, r]]) +})], variablesSet=[[$cor0]]) + +- LogicalTableScan(table=[[default_catalog, default_database, l]]) +]]> + + + =($3, 2)]) + +- LogicalAggregate(group=[{0, 1, 2}], agg#0=[COUNT($1)]) + +- LogicalProject(f=[$2], d=[$0], e=[$1]) + +- LogicalFilter(condition=[true]) + +- LogicalTableScan(table=[[default_catalog, default_database, r]]) ]]> diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQuerySemiJoinTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQuerySemiJoinTest.scala index d2a2c686cf8ae..5e70a13395e11 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQuerySemiJoinTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQuerySemiJoinTest.scala @@ -1323,6 +1323,51 @@ class SubQuerySemiJoinTest extends SubQueryTestBase { util.verifyRelPlanNotExpected(sqlQuery, "joinType=[semi]") } + @Test + def testExistsWithCorrelatedOnWhere_Having1(): Unit = { + // Correlated WHERE plus HAVING on a single aggregate output. + // Regression for SubQueryDecorrelator: the non-correlated HAVING predicate must be + // re-indexed against the rewritten Aggregate (which receives the correlated column + // injected into its group key). + val sqlQuery = "SELECT * FROM l WHERE EXISTS " + + "(SELECT 1 FROM r WHERE l.a = r.d GROUP BY r.f HAVING SUM(r.e) >= 3)" + util.verifyRelPlan(sqlQuery) + } + + @Test + def testExistsWithCorrelatedOnWhere_Having2(): Unit = { + // Compound HAVING with multiple aggregate refs. + val sqlQuery = "SELECT * FROM l WHERE EXISTS " + + "(SELECT 1 FROM r WHERE l.a = r.d GROUP BY r.f HAVING SUM(r.e) >= 3 AND MAX(r.e) < 100)" + util.verifyRelPlan(sqlQuery) + } + + @Test + def testExistsWithCorrelatedOnWhere_Having3(): Unit = { + // HAVING that mixes an aggregate ref with COUNT(*). + val sqlQuery = "SELECT * FROM l WHERE EXISTS " + + "(SELECT 1 FROM r WHERE l.a = r.d GROUP BY r.f HAVING SUM(r.e) >= 3 AND COUNT(*) > 1)" + util.verifyRelPlan(sqlQuery) + } + + @Test + def testExistsWithCorrelatedOnWhere_Having4(): Unit = { + // Multiple correlated WHERE columns combined with a HAVING on aggregate output. + val sqlQuery = "SELECT * FROM l WHERE EXISTS " + + "(SELECT 1 FROM r WHERE l.a = r.d AND l.b = r.e GROUP BY r.f HAVING COUNT(r.d) >= 2)" + util.verifyRelPlan(sqlQuery) + } + + @Test + def testExistsWithCorrelatedOnWhere_Aggregate_LocalWhere(): Unit = { + // Mixed correlated + local WHERE, no HAVING. Guards against an over-eager fix: + // the local predicate `r.e > 10` sits below the Aggregate, so its RexInputRef must + // remain stable through decorrelation. + val sqlQuery = "SELECT * FROM l WHERE EXISTS " + + "(SELECT 1 FROM r WHERE l.a = r.d AND r.e > 10 GROUP BY r.f)" + util.verifyRelPlan(sqlQuery) + } + @Test def testExistsWithCorrelatedOnWhere_UnsupportedAggregate1(): Unit = { util.addTableSource[(Int, Long)]("l1", 'a, 'b)