diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/MultiJoinUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/MultiJoinUtil.java index 83bf9be50c1a3..59a69599381d9 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/MultiJoinUtil.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/MultiJoinUtil.java @@ -71,9 +71,13 @@ private static void extractEqualityConditions( final SqlKind kind = call.getOperator().getKind(); if (kind != SqlKind.EQUALS) { - for (final RexNode operand : call.getOperands()) { - extractEqualityConditions( - operand, inputOffsets, inputFieldCounts, joinAttributeMap); + // Only conjunctions (AND) can contain equality conditions that are valid for multijoin. + // All other condition types are deferred to the postJoinFilter. + if (kind == SqlKind.AND) { + for (final RexNode operand : call.getOperands()) { + extractEqualityConditions( + operand, inputOffsets, inputFieldCounts, joinAttributeMap); + } } return; } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MultiJoinSemanticTests.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MultiJoinSemanticTests.java index d9f371bcd47e7..b3b51cc4f79c2 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MultiJoinSemanticTests.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MultiJoinSemanticTests.java @@ -51,7 +51,9 @@ public List programs() { MultiJoinTestPrograms.MULTI_JOIN_LEFT_OUTER_WITH_NULL_KEYS, MultiJoinTestPrograms.MULTI_JOIN_NULL_SAFE_JOIN_WITH_NULL_KEYS, MultiJoinTestPrograms.MULTI_JOIN_MIXED_CHANGELOG_MODES, - MultiJoinTestPrograms - 
.MULTI_JOIN_WITH_TIME_ATTRIBUTES_IN_CONDITIONS_MATERIALIZATION); + MultiJoinTestPrograms.MULTI_JOIN_WITH_TIME_ATTRIBUTES_IN_CONDITIONS_MATERIALIZATION, + MultiJoinTestPrograms.MULTI_JOIN_TWO_WAY_INNER_JOIN_WITH_WHERE_IN, + MultiJoinTestPrograms.MULTI_JOIN_THREE_WAY_INNER_JOIN_MULTI_KEY_TYPES, + MultiJoinTestPrograms.MULTI_JOIN_FOUR_WAY_MIXED_JOIN_MULTI_KEY_TYPES_SHUFFLED); } } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MultiJoinTestPrograms.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MultiJoinTestPrograms.java index 2137a1e5642d3..31442c08972d0 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MultiJoinTestPrograms.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MultiJoinTestPrograms.java @@ -19,7 +19,6 @@ package org.apache.flink.table.planner.plan.nodes.exec.stream; import org.apache.flink.table.api.config.OptimizerConfigOptions; -import org.apache.flink.table.api.config.TableConfigOptions; import org.apache.flink.table.test.program.SinkTestStep; import org.apache.flink.table.test.program.SourceTestStep; import org.apache.flink.table.test.program.TableTestProgram; @@ -195,7 +194,7 @@ public class MultiJoinTestPrograms { "order_id STRING PRIMARY KEY NOT ENFORCED", "product STRING", "user_id_1 STRING") - .addOption("changelog-mode", "I,D") + .addOption("changelog-mode", "I") .producedValues( Row.ofKind(RowKind.INSERT, "order0", "ProdB", "1"), Row.ofKind(RowKind.INSERT, "order6", "ProdF", "6"), @@ -276,6 +275,7 @@ public class MultiJoinTestPrograms { TableTestProgram.of( "three-way-left-outer-join-with-restore", "three way left outer join with restore") + .setupConfig(OptimizerConfigOptions.TABLE_OPTIMIZER_MULTI_JOIN_ENABLED, true) .setupTableSource( SourceTestStep.newBuilder("Users") .addSchema("user_id 
STRING", "name STRING") @@ -329,6 +329,7 @@ public class MultiJoinTestPrograms { TableTestProgram.of( "three-way-inner-join-with-restore", "three way inner join with restore") + .setupConfig(OptimizerConfigOptions.TABLE_OPTIMIZER_MULTI_JOIN_ENABLED, true) .setupTableSource( SourceTestStep.newBuilder("Users") .addSchema("user_id STRING", "name STRING") @@ -377,6 +378,61 @@ public class MultiJoinTestPrograms { + "INNER JOIN Payments p ON u.user_id = p.user_id") .build(); + public static final TableTestProgram + MULTI_JOIN_THREE_WAY_INNER_JOIN_WITH_TTL_HINTS_WITH_RESTORE = + TableTestProgram.of( + "three-way-inner-join-with-ttl-hints-with-restore", + "three way inner join with restore and ttl hints") + .setupConfig( + OptimizerConfigOptions.TABLE_OPTIMIZER_MULTI_JOIN_ENABLED, true) + .setupTableSource( + SourceTestStep.newBuilder("Users") + .addSchema("user_id STRING", "name STRING") + .producedBeforeRestore( + Row.ofKind(RowKind.INSERT, "1", "Gus"), + Row.ofKind(RowKind.INSERT, "2", "Bob")) + .producedAfterRestore( + Row.ofKind(RowKind.INSERT, "3", "Alice")) + .build()) + .setupTableSource( + SourceTestStep.newBuilder("Orders") + .addSchema("user_id STRING", "order_id STRING") + .producedBeforeRestore( + Row.ofKind(RowKind.INSERT, "1", "order1"), + Row.ofKind(RowKind.INSERT, "2", "order2")) + .producedAfterRestore( + Row.ofKind(RowKind.INSERT, "3", "order3")) + .build()) + .setupTableSource( + SourceTestStep.newBuilder("Payments") + .addSchema("user_id STRING", "payment_id STRING") + .producedBeforeRestore( + Row.ofKind(RowKind.INSERT, "1", "payment1"), + Row.ofKind(RowKind.INSERT, "2", "payment2")) + .producedAfterRestore( + Row.ofKind(RowKind.INSERT, "3", "payment3")) + .build()) + .setupTableSink( + SinkTestStep.newBuilder("sink") + .addSchema( + "user_id STRING", + "name STRING", + "order_id STRING", + "payment_id STRING") + .consumedBeforeRestore( + "+I[1, Gus, order1, payment1]", + "+I[2, Bob, order2, payment2]") + .consumedAfterRestore("+I[3, Alice, order3, 
payment3]") + .testMaterializedData() + .build()) + .runSql( + "INSERT INTO sink " + + "SELECT /*+ STATE_TTL('u' = '1d', 'p' = '2d') */u.user_id, u.name, o.order_id, p.payment_id " + + "FROM Users u " + + "INNER JOIN Orders o ON u.user_id = o.user_id " + + "INNER JOIN Payments p ON u.user_id = p.user_id") + .build(); + public static final TableTestProgram MULTI_JOIN_THREE_WAY_INNER_JOIN_NO_JOIN_KEY = TableTestProgram.of( "three-way-inner-join-no-join-key", @@ -454,7 +510,7 @@ public class MultiJoinTestPrograms { TableTestProgram.of( "four-way-join-no-common-join-key-with-restore", "four way join no common join key with restore") - .setupConfig(TableConfigOptions.PLAN_FORCE_RECOMPILE, true) + .setupConfig(OptimizerConfigOptions.TABLE_OPTIMIZER_MULTI_JOIN_ENABLED, true) .setupTableSource( SourceTestStep.newBuilder("Users") .addSchema( @@ -486,7 +542,7 @@ public class MultiJoinTestPrograms { "order_id STRING PRIMARY KEY NOT ENFORCED", "product STRING", "user_id_1 STRING") - .addOption("changelog-mode", "I,D") + .addOption("changelog-mode", "I") .producedBeforeRestore( Row.ofKind(RowKind.INSERT, "order2", "ProdB", "2"), Row.ofKind(RowKind.INSERT, "order0", "ProdB", "1"), @@ -536,7 +592,7 @@ public class MultiJoinTestPrograms { "order_id STRING", "payment_id STRING", "location STRING") - .addOption("sink-changelog-mode-enforced", "I,UA,UB,D") + .addOption("changelog-mode", "I,UA,UB,D") .consumedBeforeRestore( "+I[1, Gus, order0, 1, London]", "+I[1, Gus, order1, 1, London]", @@ -593,7 +649,7 @@ public class MultiJoinTestPrograms { "order_id STRING PRIMARY KEY NOT ENFORCED", "product STRING", "user_id_1 STRING") - .addOption("changelog-mode", "I,D") + .addOption("changelog-mode", "I") .producedValues( Row.ofKind(RowKind.INSERT, "order0", "ProdB", "1"), Row.ofKind(RowKind.INSERT, "order6", "ProdF", "6"), @@ -640,7 +696,7 @@ public class MultiJoinTestPrograms { "order_id STRING", "payment_id STRING", "location STRING") - .addOption("sink-changelog-mode-enforced", 
"I,UA,UB,D") + .addOption("changelog-mode", "I,UA,UB,D") .consumedValues( "+I[1, Gus, order0, 1, London]", "+I[1, Gus, order1, 1, London]", @@ -665,7 +721,7 @@ public class MultiJoinTestPrograms { TableTestProgram.of( "four-way-complex-updating-join-with-restore", "four way complex updating join with restore") - .setupConfig(TableConfigOptions.PLAN_FORCE_RECOMPILE, true) + .setupConfig(OptimizerConfigOptions.TABLE_OPTIMIZER_MULTI_JOIN_ENABLED, true) .setupTableSource( SourceTestStep.newBuilder("Users") .addSchema( @@ -697,7 +753,7 @@ public class MultiJoinTestPrograms { "order_id STRING PRIMARY KEY NOT ENFORCED", "product STRING", "user_id_1 STRING") - .addOption("changelog-mode", "I,D") + .addOption("changelog-mode", "I") .producedBeforeRestore( Row.ofKind(RowKind.INSERT, "order2", "ProdB", "2"), Row.ofKind(RowKind.INSERT, "order0", "ProdB", "1"), @@ -747,7 +803,7 @@ public class MultiJoinTestPrograms { "order_id STRING", "payment_id STRING", "location STRING") - .addOption("sink-changelog-mode-enforced", "I,UA,UB,D") + .addOption("changelog-mode", "I,UA,UB,D") .consumedBeforeRestore( "+I[1, Gus, order0, payment1, London]", "+I[1, Gus, order1, payment1, London]", @@ -769,6 +825,106 @@ public class MultiJoinTestPrograms { + "LEFT JOIN Shipments s ON p.user_id_2 = s.user_id_3") .build(); + public static final TableTestProgram + MULTI_JOIN_FOUR_WAY_COMPLEX_PRESERVES_UPSERT_KEY_WITH_RESTORE = + TableTestProgram.of( + "four-way-complex-preserves-upsert-key-with-restore", + "four way complex preserves upsert key with restore") + .setupConfig( + OptimizerConfigOptions.TABLE_OPTIMIZER_MULTI_JOIN_ENABLED, true) + .setupTableSource( + SourceTestStep.newBuilder("Users") + .addSchema( + "user_id_0 STRING PRIMARY KEY NOT ENFORCED", + "name STRING") + .producedBeforeRestore( + Row.ofKind(RowKind.INSERT, "1", "Gus"), + Row.ofKind(RowKind.INSERT, "3", "Joe")) + .producedAfterRestore( + Row.ofKind(RowKind.INSERT, "2", "Bob")) + .build()) + .setupTableSource( + 
SourceTestStep.newBuilder("Orders") + .addSchema( + "order_id STRING NOT NULL", + "user_id_1 STRING NOT NULL", + "product STRING", + "CONSTRAINT `PRIMARY` PRIMARY KEY (`order_id`, `user_id_1`) NOT ENFORCED") + .addOption("changelog-mode", "I") + .producedBeforeRestore( + Row.ofKind( + RowKind.INSERT, "order1", "1", "ProdA")) + .producedAfterRestore( + Row.ofKind( + RowKind.INSERT, "order2", "2", "ProdB")) + .build()) + .setupTableSource( + SourceTestStep.newBuilder("Payments") + .addSchema( + "payment_id STRING NOT NULL", + "user_id_2 STRING NOT NULL", + "price INT", + "CONSTRAINT `PRIMARY` PRIMARY KEY (`payment_id`, `user_id_2`) NOT ENFORCED") + .addOption("changelog-mode", "I") + .producedBeforeRestore( + Row.ofKind( + RowKind.INSERT, "payment1", "1", 100), + Row.ofKind( + RowKind.INSERT, "payment3", "3", 300)) + .producedAfterRestore( + Row.ofKind( + RowKind.INSERT, "payment2", "2", 200)) + .build()) + .setupTableSource( + SourceTestStep.newBuilder("Address") + .addSchema( + "user_id_3 STRING NOT NULL", + "location STRING", + "CONSTRAINT `PRIMARY` PRIMARY KEY (`user_id_3`) NOT ENFORCED") + .addOption("changelog-mode", "I") + .producedBeforeRestore( + Row.ofKind(RowKind.INSERT, "1", "London"), + Row.ofKind(RowKind.INSERT, "3", "Berlin")) + .producedAfterRestore( + Row.ofKind(RowKind.INSERT, "2", "Paris")) + .build()) + .setupTableSink( + SinkTestStep.newBuilder("sink") + .addOption("changelog-mode", "I,UA,D") + .addSchema( + "user_id_0 STRING NOT NULL", + "order_id STRING NOT NULL", + "user_id_1 STRING NOT NULL", + "payment_id STRING NOT NULL", + "user_id_2 STRING NOT NULL", + "name STRING", + "location STRING", + "CONSTRAINT `PRIMARY` PRIMARY KEY (`user_id_0`, `order_id`, `user_id_1`, `payment_id`, `user_id_2`) NOT ENFORCED") + .consumedBeforeRestore( + "+I[1, order1, 1, payment1, 1, Gus, London]") + .consumedAfterRestore( + "+I[2, order2, 2, payment2, 2, Bob, Paris]") + .testMaterializedData() + .build()) + .runSql( + "INSERT INTO `sink`\n" + + "SELECT\n" + + " 
u.user_id_0,\n" + + " o.order_id,\n" + + " o.user_id_1,\n" + + " p.payment_id,\n" + + " p.user_id_2,\n" + + " u.name,\n" + + " a.location\n" + + "FROM Users u\n" + + "JOIN Orders o\n" + + " ON u.user_id_0 = o.user_id_1 AND o.product IS NOT NULL\n" + + "JOIN Payments p\n" + + " ON u.user_id_0 = p.user_id_2 AND p.price >= 0\n" + + "JOIN Address a\n" + + " ON u.user_id_0 = a.user_id_3 AND a.location IS NOT NULL") + .build(); + public static final TableTestProgram MULTI_JOIN_WITH_TIME_ATTRIBUTES_MATERIALIZATION = TableTestProgram.of( "three-way-join-with-time-attributes", @@ -896,6 +1052,65 @@ public class MultiJoinTestPrograms { + "INNER JOIN Payments p ON u.user_id_0 = p.user_id_2") .build(); + public static final TableTestProgram MULTI_JOIN_MIXED_CHANGELOG_MODES = + TableTestProgram.of( + "three-way-mixed-changelog-modes", + "three way join with mixed changelog modes and primary key configurations") + .setupTableSource( + SourceTestStep.newBuilder("AppendTable") + .addSchema("id STRING PRIMARY KEY NOT ENFORCED, val STRING") + .addOption("changelog-mode", "I") + .producedValues( + Row.ofKind(RowKind.INSERT, "1", "append1"), + Row.ofKind(RowKind.INSERT, "2", "append2"), + Row.ofKind(RowKind.INSERT, "3", "append3")) + .build()) + .setupTableSource( + SourceTestStep.newBuilder("RetractTable") + .addSchema("ref_id STRING, data STRING") + .addOption("changelog-mode", "I,UA,UB,D") + .producedValues( + Row.ofKind(RowKind.INSERT, "1", "retract1"), + Row.ofKind(RowKind.INSERT, "2", "retract2"), + Row.ofKind(RowKind.INSERT, "3", "retract3"), + Row.ofKind(RowKind.DELETE, "3", "retract3"), + Row.ofKind(RowKind.INSERT, "1", "retract1_new")) + .build()) + .setupTableSource( + SourceTestStep.newBuilder("UpsertTable") + .addSchema( + "key_id STRING PRIMARY KEY NOT ENFORCED, status STRING") + .addOption("changelog-mode", "I,UA,D") + .producedValues( + Row.ofKind(RowKind.INSERT, "1", "active"), + Row.ofKind(RowKind.INSERT, "2", "pending"), + Row.ofKind(RowKind.UPDATE_AFTER, "2", 
"active"), + Row.ofKind(RowKind.INSERT, "3", "inactive"), + Row.ofKind(RowKind.DELETE, "3", "inactive")) + .build()) + .setupTableSink( + SinkTestStep.newBuilder("sink") + .addSchema( + "id STRING", + "val STRING", + "data STRING", + "status STRING") + .addOption("changelog-mode", "I,UA,UB,D") + .consumedValues( + "+I[1, append1, retract1, active]", + "+I[2, append2, retract2, active]", + "+I[1, append1, retract1_new, active]", + "+I[3, append3, null, null]") + .testMaterializedData() + .build()) + .runSql( + "INSERT INTO sink " + + "SELECT a.id, a.val, r.data, u.status " + + "FROM AppendTable a " + + "LEFT JOIN RetractTable r ON a.id = r.ref_id " + + "LEFT JOIN UpsertTable u ON a.id = u.key_id") + .build(); + public static final TableTestProgram MULTI_JOIN_THREE_WAY_LEFT_OUTER_JOIN_WITH_CTE = TableTestProgram.of( "left-outer-join-with-cte", @@ -907,7 +1122,7 @@ public class MultiJoinTestPrograms { "user_id STRING", "order_id STRING PRIMARY KEY NOT ENFORCED", "product STRING") - .addOption("changelog-mode", "I, UA,D") + .addOption("changelog-mode", "I,UA,D") .producedValues( Row.ofKind(RowKind.INSERT, "2", "order2", "Product B"), Row.ofKind( @@ -1028,64 +1243,151 @@ public class MultiJoinTestPrograms { + "INNER JOIN OrdersNullSafe o ON u.user_id IS NOT DISTINCT FROM o.user_id") .build(); - public static final TableTestProgram MULTI_JOIN_MIXED_CHANGELOG_MODES = + static final TableTestProgram MULTI_JOIN_TWO_WAY_LEFT_JOIN_PRESERVES_UPSERT_KEY_WITH_RESTORE = TableTestProgram.of( - "three-way-mixed-changelog-modes", - "three way join with mixed changelog modes and primary key configurations") + "two-way-left-join-preserves-upsert-key-with-restore", + "validates upsert with non key filter with restore") .setupConfig(OptimizerConfigOptions.TABLE_OPTIMIZER_MULTI_JOIN_ENABLED, true) .setupTableSource( - SourceTestStep.newBuilder("AppendTable") - .addSchema("id STRING PRIMARY KEY NOT ENFORCED, val STRING") - .addOption("changelog-mode", "I") - .producedValues( - 
Row.ofKind(RowKind.INSERT, "1", "append1"), - Row.ofKind(RowKind.INSERT, "2", "append2"), - Row.ofKind(RowKind.INSERT, "3", "append3")) + SourceTestStep.newBuilder("Users") + .addOption("changelog-mode", "I,UA,D") + .addSchema( + "user_id INT NOT NULL", + "shard_id INT NOT NULL", + "description STRING", + "CONSTRAINT `PRIMARY` PRIMARY KEY (`user_id`) NOT ENFORCED") + .producedBeforeRestore( + Row.ofKind(RowKind.INSERT, 1, 1, "shard_a")) + .producedAfterRestore( + Row.ofKind(RowKind.UPDATE_AFTER, 3, 1, "another shard")) .build()) .setupTableSource( - SourceTestStep.newBuilder("RetractTable") - .addSchema("ref_id STRING, data STRING") - .addOption("changelog-mode", "I,UA,UB,D") - .producedValues( - Row.ofKind(RowKind.INSERT, "1", "retract1"), - Row.ofKind(RowKind.INSERT, "2", "retract2"), - Row.ofKind(RowKind.INSERT, "3", "retract3"), - Row.ofKind(RowKind.DELETE, "3", "retract3"), - Row.ofKind(RowKind.INSERT, "1", "retract1_new")) + SourceTestStep.newBuilder("Orders") + .addOption("changelog-mode", "I,UA,D") + .addSchema( + "user_id INT NOT NULL", + "order_id BIGINT NOT NULL", + "product STRING", + "CONSTRAINT `PRIMARY` PRIMARY KEY (`user_id`, `order_id`) NOT ENFORCED") + .producedBeforeRestore( + Row.ofKind(RowKind.INSERT, 1, 1L, "a"), + Row.ofKind(RowKind.UPDATE_AFTER, 1, 1L, "a_updated")) + .producedAfterRestore( + Row.ofKind(RowKind.INSERT, 1, 2L, "b"), + Row.ofKind(RowKind.DELETE, 1, 2L, "b"), + Row.ofKind(RowKind.INSERT, 3, 1L, "b"), + Row.ofKind(RowKind.INSERT, 9, 9L, "b")) .build()) + .setupTableSink( + SinkTestStep.newBuilder("sink") + .addOption("changelog-mode", "I,UA,D") + .addSchema( + "`user_id` INT NOT NULL", + "`order_id` BIGINT NOT NULL", + "product STRING", + "user_shard_id INT", + "CONSTRAINT `PRIMARY` PRIMARY KEY (`user_id`, `order_id`) NOT ENFORCED") + .consumedBeforeRestore("+I[1, 1, a_updated, 1]") + .consumedAfterRestore("+I[3, 1, b, 1]", "+I[9, 9, b, null]") + .testMaterializedData() + .build()) + .runSql( + "INSERT INTO `sink`\n" + + 
"SELECT\n" + + " o.user_id,\n" + + " o.order_id,\n" + + " o.product,\n" + + " u.shard_id\n" + + "FROM Orders o\n" + + "LEFT JOIN Users u\n" + + " ON u.user_id = o.user_id\n") + .build(); + + static final TableTestProgram MULTI_JOIN_THREE_WAY_JOIN_PRESERVES_UPSERT_KEY_WITH_RESTORE = + TableTestProgram.of( + "three-way-upsert-preserves-key-with-restore", + "validates upsert with non key filter with restore") + .setupConfig(OptimizerConfigOptions.TABLE_OPTIMIZER_MULTI_JOIN_ENABLED, true) .setupTableSource( - SourceTestStep.newBuilder("UpsertTable") + SourceTestStep.newBuilder("Users") + .addOption("changelog-mode", "I,UA,D") .addSchema( - "key_id STRING PRIMARY KEY NOT ENFORCED, status STRING") + "user_id INT NOT NULL", + "shard_id INT NOT NULL", + "description STRING", + "CONSTRAINT `PRIMARY` PRIMARY KEY (`user_id`) NOT ENFORCED") + .producedBeforeRestore( + Row.ofKind(RowKind.INSERT, 1, 1, "shard_a")) + .producedAfterRestore( + Row.ofKind(RowKind.INSERT, 2, 1, "shard_b")) + .build()) + .setupTableSource( + SourceTestStep.newBuilder("Orders") .addOption("changelog-mode", "I,UA,D") - .producedValues( - Row.ofKind(RowKind.INSERT, "1", "active"), - Row.ofKind(RowKind.INSERT, "2", "pending"), - Row.ofKind(RowKind.UPDATE_AFTER, "2", "active"), - Row.ofKind(RowKind.INSERT, "3", "inactive"), - Row.ofKind(RowKind.DELETE, "3", "inactive")) + .addSchema( + "user_id INT NOT NULL", + "order_id BIGINT NOT NULL", + "product STRING", + "CONSTRAINT `PRIMARY` PRIMARY KEY (`user_id`, `order_id`) NOT ENFORCED") + .producedBeforeRestore( + Row.ofKind(RowKind.INSERT, 1, 1L, "a"), + Row.ofKind(RowKind.INSERT, 1, 2L, "b")) + .producedAfterRestore( + Row.ofKind(RowKind.INSERT, 2, 1L, "c"), + Row.ofKind(RowKind.INSERT, 2, 2L, "d")) + .build()) + .setupTableSource( + SourceTestStep.newBuilder("Payments") + .addOption("changelog-mode", "I,UA,D") + .addSchema( + "user_id INT NOT NULL", + "payment_id BIGINT NOT NULL", + "product STRING", + "CONSTRAINT `PRIMARY` PRIMARY KEY (`user_id`, 
`payment_id`) NOT ENFORCED") + .producedBeforeRestore( + Row.ofKind(RowKind.INSERT, 1, 1L, "a"), + Row.ofKind(RowKind.INSERT, 1, 2L, "b")) + .producedAfterRestore( + Row.ofKind(RowKind.INSERT, 2, 1L, "c"), + Row.ofKind(RowKind.INSERT, 2, 2L, "d")) .build()) .setupTableSink( SinkTestStep.newBuilder("sink") + .addOption("changelog-mode", "I,UA,D") .addSchema( - "id STRING", - "val STRING", - "data STRING", - "status STRING") - .addOption("sink-changelog-mode-enforced", "I,UA,UB,D") - .consumedValues( - "+I[1, append1, retract1, active]", - "+I[2, append2, retract2, active]", - "+I[1, append1, retract1_new, active]", - "+I[3, append3, null, null]") + "`user_id` INT NOT NULL", + "`order_id` BIGINT NOT NULL", + "`user_id2` INT NOT NULL", + "payment_id BIGINT NOT NULL", + "`user_id3` INT NOT NULL", + "description STRING", + "CONSTRAINT `PRIMARY` PRIMARY KEY (`user_id`, `order_id`, `user_id2`, `payment_id`, `user_id3`) NOT ENFORCED") + .consumedBeforeRestore( + "+I[1, 1, 1, 2, 1, shard_a]", + "+I[1, 1, 1, 1, 1, shard_a]", + "+I[1, 2, 1, 1, 1, shard_a]", + "+I[1, 2, 1, 2, 1, shard_a]") + .consumedAfterRestore( + "+I[2, 2, 2, 1, 2, shard_b]", + "+I[2, 1, 2, 2, 2, shard_b]", + "+I[2, 2, 2, 2, 2, shard_b]", + "+I[2, 1, 2, 1, 2, shard_b]") .testMaterializedData() .build()) .runSql( - "INSERT INTO sink " - + "SELECT a.id, a.val, r.data, u.status " - + "FROM AppendTable a " - + "LEFT JOIN RetractTable r ON a.id = r.ref_id " - + "LEFT JOIN UpsertTable u ON a.id = u.key_id") + "INSERT INTO `sink`\n" + + "SELECT\n" + + " o.user_id,\n" + + " o.order_id,\n" + + " p.user_id,\n" + + " p.payment_id,\n" + + " u.user_id,\n" + + " u.description\n" + + "FROM Users u\n" + + "JOIN `Orders` AS o\n" + + " ON o.user_id = u.user_id\n" + + "JOIN Payments p\n" + + " ON o.user_id = p.user_id") .build(); public static final TableTestProgram @@ -1184,4 +1486,340 @@ public class MultiJoinTestPrograms { + ") Q " + "GROUP BY Q.category") .build(); + + public static final TableTestProgram 
MULTI_JOIN_TWO_WAY_INNER_JOIN_WITH_WHERE_IN = + TableTestProgram.of( + "two-way-inner-join-with-where-in", + "two way inner join with WHERE IN clause") + .setupTableSource( + SourceTestStep.newBuilder("Records") + .addSchema( + "`record_id` STRING PRIMARY KEY NOT ENFORCED", + "`user_id` INT") + .addOption("changelog-mode", "I,UA,D") + .producedValues( + Row.ofKind(RowKind.INSERT, "record_1", 1), + Row.ofKind(RowKind.INSERT, "record_2", 2), + Row.ofKind(RowKind.INSERT, "record_3", 3)) + .build()) + .setupTableSource( + SourceTestStep.newBuilder("Users") + .addSchema( + "`user_id` INT PRIMARY KEY NOT ENFORCED", "`id` STRING") + .addOption("changelog-mode", "I,UA,D") + .producedValues( + Row.ofKind(RowKind.INSERT, 1, "record_1"), + Row.ofKind(RowKind.INSERT, 2, "record_2"), + Row.ofKind(RowKind.INSERT, 3, "record_3")) + .build()) + .setupTableSink( + SinkTestStep.newBuilder("sink") + .addSchema("record_id STRING", "`user_id` INT", "id STRING") + .addOption("changelog-mode", "I,UA,UB,D") + .consumedValues( + // Only records with user_id 1 and 2 should be + // included due to WHERE IN clause + "+I[record_1, 1, record_1]", + "+I[record_2, 2, record_2]") + .testMaterializedData() + .build()) + .runSql( + "INSERT INTO sink " + + "SELECT r.`record_id`, r.`user_id`, u.`id` " + + "FROM Records r " + + "INNER JOIN Users u ON u.`user_id` = r.`user_id` AND u.`id` = r.`record_id` " + + "WHERE r.`user_id` IN (1, 2)") + .build(); + + public static final TableTestProgram MULTI_JOIN_THREE_WAY_INNER_JOIN_MULTI_KEY_TYPES = + TableTestProgram.of( + "three-way-inner-join-three-keys-shuffled", + "three way inner join with three keys in shuffled order") + .setupTableSource( + SourceTestStep.newBuilder("Users3K") + .addSchema("k1 STRING", "k2 INT", "k3 BOOLEAN", "name STRING") + .producedValues( + Row.ofKind(RowKind.INSERT, "K1", 100, true, "Gus"), + Row.ofKind(RowKind.INSERT, "A1", 200, false, "Bob")) + .build()) + .setupTableSource( + SourceTestStep.newBuilder("Orders3K") + .addSchema( + 
// Shuffled key order: k2 first, then k3, then k1 + "k2 INT PRIMARY KEY NOT ENFORCED," + + "order_id STRING," + + "k3 BOOLEAN," + + "k1 STRING," + + "k4 STRING") + .producedValues( + // Matches Gus (Users3K: K1,100,true) + Row.ofKind( + RowKind.INSERT, + 100, + "order1", + true, + "K1", + "k4_val1"), + // Matches Bob (Users3K: A1,200,false) + Row.ofKind( + RowKind.INSERT, + 200, + "order2", + false, + "A1", + "k4_val2"), + // Partial matches that should NOT join (one key wrong) + Row.ofKind( + RowKind.INSERT, + 100, + "order_partial1", + true, + "WRONG", + "k4_u"), + Row.ofKind( + RowKind.INSERT, + 100, + "order_partial2", + false, + "K1", + "k4_u"), + Row.ofKind( + RowKind.INSERT, + 200, + "order_partial3", + false, + "WRONG", + "k4_u"), + Row.ofKind( + RowKind.INSERT, + 200, + "order_partial4", + true, + "A1", + "k4_u")) + .build()) + .setupTableSource( + SourceTestStep.newBuilder("Payments3K") + .addSchema( + // Shuffled key order: k1 first, then k3, then k2 + "k1 STRING", + "payment_id STRING", + "k3 BOOLEAN", + "k2 INT") + .producedValues( + // Matches Gus (Users3K: K1,100,true) + Row.ofKind(RowKind.INSERT, "K1", "payment1", true, 100), + // Matches Bob (Users3K: A1,200,false) + Row.ofKind( + RowKind.INSERT, "A1", "payment2", false, 200), + // Partial matches that should NOT join (one key wrong) + Row.ofKind( + RowKind.INSERT, + "K1", + "payment_partial1", + false, + 100), + Row.ofKind( + RowKind.INSERT, + "A1", + "payment_partial2", + false, + 300)) + .build()) + .setupTableSink( + SinkTestStep.newBuilder("sink") + .addSchema( + "name STRING", + "order_id STRING", + "payment_id STRING", + "k1 STRING", + "k2 INT", + "k3 BOOLEAN") + .addOption("changelog-mode", "I") + .consumedValues( + "+I[Gus, order1, payment1, K1, 100, true]", + "+I[Bob, order2, payment2, A1, 200, false]") + .testMaterializedData() + .build()) + .runSql( + "INSERT INTO sink " + + "SELECT U.name, O.order_id, P.payment_id, U.k1, U.k2, U.k3 " + + "FROM Users3K AS U " + + "JOIN Orders3K AS O ON 
U.k1 = O.k1 AND U.k2 = O.k2 AND U.k3 = O.k3 " + + "JOIN Payments3K AS P ON U.k1 = P.k1 AND U.k2 = P.k2 AND U.k3 = P.k3") + .build(); + + public static final TableTestProgram MULTI_JOIN_FOUR_WAY_MIXED_JOIN_MULTI_KEY_TYPES_SHUFFLED = + TableTestProgram.of( + "four-way-mixed-join-multi-key-types-shuffled", + "four way mixed join with three keys in shuffled order and a final join with two conditions") + .setupTableSource( + SourceTestStep.newBuilder("Users4K") + .addSchema( + "k1 STRING PRIMARY KEY NOT ENFORCED", + "k2 INT", + "k3 BOOLEAN", + "k4 STRING", + "name STRING") + .producedValues( + Row.ofKind( + RowKind.INSERT, + "K1", + 100, + true, + "k4_val1", + "Gus"), + Row.ofKind( + RowKind.INSERT, + "A1", + 200, + false, + "k4_val2", + "Bob"), + Row.ofKind( + RowKind.INSERT, + "A1", + 200, + false, + "k4_val2", + "John"), + Row.ofKind( + RowKind.DELETE, + "A1", + 200, + false, + "k4_val2", + "John"), + Row.ofKind( + RowKind.INSERT, + "U_NO_MATCH", + 1, + true, + "k4_u", + "UserNoMatch")) + .build()) + .setupTableSource( + SourceTestStep.newBuilder("Orders4K") + .addSchema( + // Shuffled key order: k2 first, then k3, then k1 + "k2 INT PRIMARY KEY NOT ENFORCED", + "order_id STRING", + "k3 BOOLEAN", + "k1 STRING", + "k4 STRING") + .producedValues( + // Matches Gus (Users4K: K1,100,true,k4_val1) + Row.ofKind( + RowKind.INSERT, + 100, + "order1", + true, + "K1", + "k4_val1"), + // Matches Bob (Users4K: A1,200,false,k4_val2) + Row.ofKind( + RowKind.INSERT, + 200, + "order2", + false, + "A1", + "k4_val2"), + // No match + Row.ofKind( + RowKind.INSERT, + 2, + "O_NO_MATCH", + true, + "K_O", + "k4_o")) + .build()) + .setupTableSource( + SourceTestStep.newBuilder("Payments4K") + .addSchema( + // Shuffled key order: k1 first, then k3, then + // k2 + "k1 STRING PRIMARY KEY NOT ENFORCED," + + "payment_id STRING," + + "k3 BOOLEAN," + + "k2 INT," + + "k4 STRING") + .producedValues( + // Matches Gus (Users4K: K1,100,true,k4_val1) + Row.ofKind( + RowKind.INSERT, + "K1", + "payment1", + 
true, + 100, + "k4_val1"), + // Matches Bob (Users4K: A1,200,false,k4_val2) + Row.ofKind( + RowKind.INSERT, + "A1", + "payment2", + false, + 200, + "k4_val2"), + // No match + Row.ofKind( + RowKind.INSERT, + "P_NO_MATCH", + "P_NO_MATCH_ID", + true, + 3, + "k4_p")) + .build()) + .setupTableSource( + SourceTestStep.newBuilder("Shipments4K") + .addSchema( + "k3 BOOLEAN PRIMARY KEY NOT ENFORCED", + "k4 STRING", + "ship_id STRING") + .producedValues( + // Matches Gus + Row.ofKind(RowKind.INSERT, true, "k4_val1", "ship1"), + // Does not match anyone + Row.ofKind( + RowKind.INSERT, + false, + "k4_val_no_match", + "ship2"), + // No match + Row.ofKind( + RowKind.INSERT, + true, + "k4_val_another_no_match", + "ship3"), + // No match + Row.ofKind( + RowKind.INSERT, + false, + "k4_s_no_match", + "ship_no_match")) + .build()) + .setupTableSink( + SinkTestStep.newBuilder("sink") + .addSchema( + "name STRING", + "order_id STRING", + "payment_id STRING", + "ship_id STRING", + "k1 STRING", + "k2 INT", + "k3 BOOLEAN", + "k4 STRING") + .addOption("changelog-mode", "I,UA,UB,D") + .consumedValues( + "+I[Bob, order2, payment2, null, A1, 200, false, k4_val2]") + .testMaterializedData() + .build()) + .runSql( + "INSERT INTO sink " + + "SELECT U.name, O.order_id, P.payment_id, S.ship_id, U.k1, U.k2, U.k3, U.k4 " + + "FROM Users4K AS U " + + "JOIN Orders4K AS O ON O.k2 > 100 AND U.k2 = O.k2 AND U.k1 = O.k1 AND U.k3 = O.k3 AND U.k4 = O.k4 " + + "JOIN Payments4K AS P ON U.k1 = P.k1 AND O.k3 = P.k3 AND U.k4 = P.k4 AND O.k2 = P.k2 AND P.k2 > 100 " + + "LEFT JOIN Shipments4K AS S ON U.k3 = S.k3 AND U.k2 > 150 AND U.k4 = S.k4 " + + "WHERE U.k2 > 50") + .build(); } diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/MultiJoinTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/MultiJoinTest.xml index 56bf1998204f0..d641e2836dbb0 100644 --- 
a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/MultiJoinTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/MultiJoinTest.xml @@ -23,25 +23,27 @@ Calc(select=[user_id, name, category_name, order_count, total_items, total_value +- GroupAggregate(groupBy=[user_id, name, category_name], select=[user_id, name, category_name, COUNT_RETRACT(DISTINCT order_id) AS order_count, SUM_RETRACT(quantity) AS total_items, SUM_RETRACT($f5) AS total_value, AVG_RETRACT(unit_price) AS avg_item_price, MAX_RETRACT(price) AS max_payment, MIN_RETRACT(price) AS min_payment, COUNT_RETRACT($f8) AS bulk_orders]) +- Exchange(distribution=[hash[user_id, name, category_name]]) +- Calc(select=[user_id, name, category_name, order_id, quantity, *(quantity, unit_price) AS $f5, unit_price, price, CASE(>(quantity, 5), 1, null:INTEGER) AS $f8]) - +- MultiJoin(joinFilter=[true], joinTypes=[[INNER, LEFT, LEFT]], joinConditions=[[true, =($8, $11), =($0, $16)]], joinAttributeMap=[{0=[LeftInputId:-1;LeftFieldIndex:-1;RightInputId:0;RightFieldIndex:8;, LeftInputId:-1;LeftFieldIndex:-1;RightInputId:0;RightFieldIndex:0;], 1=[LeftInputId:0;LeftFieldIndex:8;RightInputId:1;RightFieldIndex:0;], 2=[LeftInputId:0;LeftFieldIndex:0;RightInputId:2;RightFieldIndex:2;]}], select=[user_id,name,cash,order_id,user_id0,product,item_id,order_id0,product_name,quantity,unit_price,category_id,category_name,parent_category,payment_id,price,user_id1], rowType=[RecordType(VARCHAR(2147483647) user_id, VARCHAR(2147483647) name, INTEGER cash, VARCHAR(2147483647) order_id, VARCHAR(2147483647) user_id0, VARCHAR(2147483647) product, VARCHAR(2147483647) item_id, VARCHAR(2147483647) order_id0, VARCHAR(2147483647) product_name, INTEGER quantity, DOUBLE unit_price, VARCHAR(2147483647) category_id, VARCHAR(2147483647) category_name, VARCHAR(2147483647) parent_category, VARCHAR(2147483647) payment_id, INTEGER price, 
VARCHAR(2147483647) user_id1)]) + +- MultiJoin(joinFilter=[true], joinTypes=[[INNER, LEFT]], joinConditions=[[true, =($0, $16)]], joinAttributeMap=[{0=[LeftInputId:-1;LeftFieldIndex:-1;RightInputId:0;RightFieldIndex:0;], 1=[LeftInputId:0;LeftFieldIndex:0;RightInputId:1;RightFieldIndex:2;]}], select=[user_id,name,cash,order_id,user_id0,product,item_id,order_id0,product_name,quantity,unit_price,category_id,category_name,parent_category,payment_id,price,user_id1], rowType=[RecordType(VARCHAR(2147483647) user_id, VARCHAR(2147483647) name, INTEGER cash, VARCHAR(2147483647) order_id, VARCHAR(2147483647) user_id0, VARCHAR(2147483647) product, VARCHAR(2147483647) item_id, VARCHAR(2147483647) order_id0, VARCHAR(2147483647) product_name, INTEGER quantity, DOUBLE unit_price, VARCHAR(2147483647) category_id, VARCHAR(2147483647) category_name, VARCHAR(2147483647) parent_category, VARCHAR(2147483647) payment_id, INTEGER price, VARCHAR(2147483647) user_id1)]) :- Exchange(distribution=[hash[user_id]]) - : +- MultiJoin(joinFilter=[true], joinTypes=[[INNER, LEFT]], joinConditions=[[true, =($3, $7)]], joinAttributeMap=[{0=[LeftInputId:-1;LeftFieldIndex:-1;RightInputId:0;RightFieldIndex:3;], 1=[LeftInputId:0;LeftFieldIndex:3;RightInputId:1;RightFieldIndex:1;]}], select=[user_id,name,cash,order_id,user_id0,product,item_id,order_id0,product_name,quantity,unit_price], rowType=[RecordType(VARCHAR(2147483647) user_id, VARCHAR(2147483647) name, INTEGER cash, VARCHAR(2147483647) order_id, VARCHAR(2147483647) user_id0, VARCHAR(2147483647) product, VARCHAR(2147483647) item_id, VARCHAR(2147483647) order_id0, VARCHAR(2147483647) product_name, INTEGER quantity, DOUBLE unit_price)]) - : :- Exchange(distribution=[hash[order_id]]) - : : +- MultiJoin(joinFilter=[true], joinTypes=[[INNER, LEFT]], joinConditions=[[true, =($0, $4)]], joinAttributeMap=[{0=[LeftInputId:-1;LeftFieldIndex:-1;RightInputId:0;RightFieldIndex:0;], 1=[LeftInputId:0;LeftFieldIndex:0;RightInputId:1;RightFieldIndex:1;]}], 
select=[user_id,name,cash,order_id,user_id0,product], rowType=[RecordType(VARCHAR(2147483647) user_id, VARCHAR(2147483647) name, INTEGER cash, VARCHAR(2147483647) order_id, VARCHAR(2147483647) user_id0, VARCHAR(2147483647) product)]) - : : :- Exchange(distribution=[hash[user_id]]) - : : : +- ChangelogNormalize(key=[user_id]) + : +- MultiJoin(joinFilter=[true], joinTypes=[[INNER, LEFT]], joinConditions=[[true, =($8, $11)]], joinAttributeMap=[{0=[LeftInputId:-1;LeftFieldIndex:-1;RightInputId:0;RightFieldIndex:8;], 1=[LeftInputId:0;LeftFieldIndex:8;RightInputId:1;RightFieldIndex:0;]}], select=[user_id,name,cash,order_id,user_id0,product,item_id,order_id0,product_name,quantity,unit_price,category_id,category_name,parent_category], rowType=[RecordType(VARCHAR(2147483647) user_id, VARCHAR(2147483647) name, INTEGER cash, VARCHAR(2147483647) order_id, VARCHAR(2147483647) user_id0, VARCHAR(2147483647) product, VARCHAR(2147483647) item_id, VARCHAR(2147483647) order_id0, VARCHAR(2147483647) product_name, INTEGER quantity, DOUBLE unit_price, VARCHAR(2147483647) category_id, VARCHAR(2147483647) category_name, VARCHAR(2147483647) parent_category)]) + : :- Exchange(distribution=[hash[product_name]]) + : : +- MultiJoin(joinFilter=[true], joinTypes=[[INNER, LEFT]], joinConditions=[[true, =($3, $7)]], joinAttributeMap=[{0=[LeftInputId:-1;LeftFieldIndex:-1;RightInputId:0;RightFieldIndex:3;], 1=[LeftInputId:0;LeftFieldIndex:3;RightInputId:1;RightFieldIndex:1;]}], select=[user_id,name,cash,order_id,user_id0,product,item_id,order_id0,product_name,quantity,unit_price], rowType=[RecordType(VARCHAR(2147483647) user_id, VARCHAR(2147483647) name, INTEGER cash, VARCHAR(2147483647) order_id, VARCHAR(2147483647) user_id0, VARCHAR(2147483647) product, VARCHAR(2147483647) item_id, VARCHAR(2147483647) order_id0, VARCHAR(2147483647) product_name, INTEGER quantity, DOUBLE unit_price)]) + : : :- Exchange(distribution=[hash[order_id]]) + : : : +- MultiJoin(joinFilter=[true], joinTypes=[[INNER, LEFT]], 
joinConditions=[[true, =($0, $4)]], joinAttributeMap=[{0=[LeftInputId:-1;LeftFieldIndex:-1;RightInputId:0;RightFieldIndex:0;], 1=[LeftInputId:0;LeftFieldIndex:0;RightInputId:1;RightFieldIndex:1;]}], select=[user_id,name,cash,order_id,user_id0,product], rowType=[RecordType(VARCHAR(2147483647) user_id, VARCHAR(2147483647) name, INTEGER cash, VARCHAR(2147483647) order_id, VARCHAR(2147483647) user_id0, VARCHAR(2147483647) product)]) + : : : :- Exchange(distribution=[hash[user_id]]) + : : : : +- ChangelogNormalize(key=[user_id]) + : : : : +- Exchange(distribution=[hash[user_id]]) + : : : : +- TableSourceScan(table=[[default_catalog, default_database, Users]], fields=[user_id, name, cash]) : : : +- Exchange(distribution=[hash[user_id]]) - : : : +- TableSourceScan(table=[[default_catalog, default_database, Users]], fields=[user_id, name, cash]) - : : +- Exchange(distribution=[hash[user_id]]) - : : +- TableSourceScan(table=[[default_catalog, default_database, Orders]], fields=[order_id, user_id, product]) - : +- Exchange(distribution=[hash[order_id]]) - : +- ChangelogNormalize(key=[item_id]) - : +- Exchange(distribution=[hash[item_id]]) - : +- TableSourceScan(table=[[default_catalog, default_database, OrderItems]], fields=[item_id, order_id, product_name, quantity, unit_price]) - :- Exchange(distribution=[hash[category_id]]) - : +- ChangelogNormalize(key=[category_id]) + : : : +- TableSourceScan(table=[[default_catalog, default_database, Orders]], fields=[order_id, user_id, product]) + : : +- Exchange(distribution=[hash[order_id]]) + : : +- ChangelogNormalize(key=[item_id]) + : : +- Exchange(distribution=[hash[item_id]]) + : : +- TableSourceScan(table=[[default_catalog, default_database, OrderItems]], fields=[item_id, order_id, product_name, quantity, unit_price]) : +- Exchange(distribution=[hash[category_id]]) - : +- TableSourceScan(table=[[default_catalog, default_database, ProductCategories]], fields=[category_id, category_name, parent_category]) + : +- 
ChangelogNormalize(key=[category_id]) + : +- Exchange(distribution=[hash[category_id]]) + : +- TableSourceScan(table=[[default_catalog, default_database, ProductCategories]], fields=[category_id, category_name, parent_category]) +- Exchange(distribution=[hash[user_id]]) +- TableSourceScan(table=[[default_catalog, default_database, Payments]], fields=[payment_id, price, user_id]) ]]> @@ -101,22 +103,24 @@ Calc(select=[user_id, name, order_id, dept_name, project_name, budget]) GroupAggregate(groupBy=[user_id, name], select=[user_id, name, COUNT_RETRACT(order_id) AS total_orders, SUM_RETRACT(price) AS total_spent, AVG_RETRACT(price) AS avg_order_value, COUNT_RETRACT($f4) AS completed_orders]) +- Exchange(distribution=[hash[user_id, name]]) +- Calc(select=[user_id, name, order_id, price, CASE(is_final, 1, null:INTEGER) AS $f4]) - +- MultiJoin(joinFilter=[true], joinTypes=[[INNER, LEFT, LEFT]], joinConditions=[[true, =($2, $6), =($4, $9)]], joinAttributeMap=[{0=[LeftInputId:-1;LeftFieldIndex:-1;RightInputId:0;RightFieldIndex:2;, LeftInputId:-1;LeftFieldIndex:-1;RightInputId:0;RightFieldIndex:4;], 1=[LeftInputId:0;LeftFieldIndex:2;RightInputId:1;RightFieldIndex:0;], 2=[LeftInputId:0;LeftFieldIndex:4;RightInputId:2;RightFieldIndex:0;]}], select=[user_id,name,order_id,product,payment_id,price,status_id,status_name,is_final,method_id,method_name,processing_fee], rowType=[RecordType(VARCHAR(2147483647) user_id, VARCHAR(2147483647) name, VARCHAR(2147483647) order_id, VARCHAR(2147483647) product, VARCHAR(2147483647) payment_id, INTEGER price, VARCHAR(2147483647) status_id, VARCHAR(2147483647) status_name, BOOLEAN is_final, VARCHAR(2147483647) method_id, VARCHAR(2147483647) method_name, DOUBLE processing_fee)]) - :- Exchange(distribution=[hash[order_id]]) - : +- Calc(select=[user_id, name, order_id, product, payment_id, price]) - : +- MultiJoin(joinFilter=[true], joinTypes=[[INNER, LEFT, LEFT]], joinConditions=[[true, =($0, $4), =($0, $8)]], 
joinAttributeMap=[{0=[LeftInputId:-1;LeftFieldIndex:-1;RightInputId:0;RightFieldIndex:0;, LeftInputId:-1;LeftFieldIndex:-1;RightInputId:0;RightFieldIndex:0;], 1=[LeftInputId:0;LeftFieldIndex:0;RightInputId:1;RightFieldIndex:1;], 2=[LeftInputId:0;LeftFieldIndex:0;RightInputId:2;RightFieldIndex:2;]}], select=[user_id,name,cash,order_id,user_id0,product,payment_id,price,user_id1], rowType=[RecordType(VARCHAR(2147483647) user_id, VARCHAR(2147483647) name, INTEGER cash, VARCHAR(2147483647) order_id, VARCHAR(2147483647) user_id0, VARCHAR(2147483647) product, VARCHAR(2147483647) payment_id, INTEGER price, VARCHAR(2147483647) user_id1)]) - : :- Exchange(distribution=[hash[user_id]]) - : : +- ChangelogNormalize(key=[user_id]) - : : +- Exchange(distribution=[hash[user_id]]) - : : +- TableSourceScan(table=[[default_catalog, default_database, Users]], fields=[user_id, name, cash]) - : :- Exchange(distribution=[hash[user_id]]) - : : +- TableSourceScan(table=[[default_catalog, default_database, Orders]], fields=[order_id, user_id, product]) - : +- Exchange(distribution=[hash[user_id]]) - : +- TableSourceScan(table=[[default_catalog, default_database, Payments]], fields=[payment_id, price, user_id]) - :- Exchange(distribution=[hash[status_id]]) - : +- ChangelogNormalize(key=[status_id]) + +- MultiJoin(joinFilter=[true], joinTypes=[[INNER, LEFT]], joinConditions=[[true, =($4, $9)]], joinAttributeMap=[{0=[LeftInputId:-1;LeftFieldIndex:-1;RightInputId:0;RightFieldIndex:4;], 1=[LeftInputId:0;LeftFieldIndex:4;RightInputId:1;RightFieldIndex:0;]}], select=[user_id,name,order_id,product,payment_id,price,status_id,status_name,is_final,method_id,method_name,processing_fee], rowType=[RecordType(VARCHAR(2147483647) user_id, VARCHAR(2147483647) name, VARCHAR(2147483647) order_id, VARCHAR(2147483647) product, VARCHAR(2147483647) payment_id, INTEGER price, VARCHAR(2147483647) status_id, VARCHAR(2147483647) status_name, BOOLEAN is_final, VARCHAR(2147483647) method_id, VARCHAR(2147483647) 
method_name, DOUBLE processing_fee)]) + :- Exchange(distribution=[hash[payment_id]]) + : +- MultiJoin(joinFilter=[true], joinTypes=[[INNER, LEFT]], joinConditions=[[true, =($2, $6)]], joinAttributeMap=[{0=[LeftInputId:-1;LeftFieldIndex:-1;RightInputId:0;RightFieldIndex:2;], 1=[LeftInputId:0;LeftFieldIndex:2;RightInputId:1;RightFieldIndex:0;]}], select=[user_id,name,order_id,product,payment_id,price,status_id,status_name,is_final], rowType=[RecordType(VARCHAR(2147483647) user_id, VARCHAR(2147483647) name, VARCHAR(2147483647) order_id, VARCHAR(2147483647) product, VARCHAR(2147483647) payment_id, INTEGER price, VARCHAR(2147483647) status_id, VARCHAR(2147483647) status_name, BOOLEAN is_final)]) + : :- Exchange(distribution=[hash[order_id]]) + : : +- Calc(select=[user_id, name, order_id, product, payment_id, price]) + : : +- MultiJoin(joinFilter=[true], joinTypes=[[INNER, LEFT, LEFT]], joinConditions=[[true, =($0, $4), =($0, $8)]], joinAttributeMap=[{0=[LeftInputId:-1;LeftFieldIndex:-1;RightInputId:0;RightFieldIndex:0;, LeftInputId:-1;LeftFieldIndex:-1;RightInputId:0;RightFieldIndex:0;], 1=[LeftInputId:0;LeftFieldIndex:0;RightInputId:1;RightFieldIndex:1;], 2=[LeftInputId:0;LeftFieldIndex:0;RightInputId:2;RightFieldIndex:2;]}], select=[user_id,name,cash,order_id,user_id0,product,payment_id,price,user_id1], rowType=[RecordType(VARCHAR(2147483647) user_id, VARCHAR(2147483647) name, INTEGER cash, VARCHAR(2147483647) order_id, VARCHAR(2147483647) user_id0, VARCHAR(2147483647) product, VARCHAR(2147483647) payment_id, INTEGER price, VARCHAR(2147483647) user_id1)]) + : : :- Exchange(distribution=[hash[user_id]]) + : : : +- ChangelogNormalize(key=[user_id]) + : : : +- Exchange(distribution=[hash[user_id]]) + : : : +- TableSourceScan(table=[[default_catalog, default_database, Users]], fields=[user_id, name, cash]) + : : :- Exchange(distribution=[hash[user_id]]) + : : : +- TableSourceScan(table=[[default_catalog, default_database, Orders]], fields=[order_id, user_id, product]) + : 
: +- Exchange(distribution=[hash[user_id]]) + : : +- TableSourceScan(table=[[default_catalog, default_database, Payments]], fields=[payment_id, price, user_id]) : +- Exchange(distribution=[hash[status_id]]) - : +- TableSourceScan(table=[[default_catalog, default_database, OrderStatus]], fields=[status_id, status_name, is_final]) + : +- ChangelogNormalize(key=[status_id]) + : +- Exchange(distribution=[hash[status_id]]) + : +- TableSourceScan(table=[[default_catalog, default_database, OrderStatus]], fields=[status_id, status_name, is_final]) +- Exchange(distribution=[hash[method_id]]) +- ChangelogNormalize(key=[method_id]) +- Exchange(distribution=[hash[method_id]]) @@ -128,12 +132,12 @@ GroupAggregate(groupBy=[user_id, name], select=[user_id, name, COUNT_RETRACT(ord @@ -365,7 +376,7 @@ Calc(select=[user_id, name, order_id, payment_id, location]) :- Exchange(distribution=[hash[payment_id]]) : +- Calc(select=[user_id, name, cash, order_id, user_id0, product, payment_id, price, user_id1]) : +- MultiJoin(joinFilter=[true], joinTypes=[[INNER, LEFT]], joinConditions=[[true, AND(=($0, $9), =($6, $10), OR(>=(FLOOR($2), FLOOR($8)), <($8, 0)))]], joinAttributeMap=[{0=[LeftInputId:-1;LeftFieldIndex:-1;RightInputId:0;RightFieldIndex:0;, LeftInputId:-1;LeftFieldIndex:-1;RightInputId:0;RightFieldIndex:6;], 1=[LeftInputId:0;LeftFieldIndex:0;RightInputId:1;RightFieldIndex:2;, LeftInputId:0;LeftFieldIndex:6;RightInputId:1;RightFieldIndex:3;]}], select=[user_id,name,cash,order_id,user_id0,product,$f6,payment_id,price,user_id1,$f3], rowType=[RecordType(VARCHAR(2147483647) user_id, VARCHAR(2147483647) name, INTEGER cash, VARCHAR(2147483647) order_id, VARCHAR(2147483647) user_id0, VARCHAR(2147483647) product, VARCHAR(2147483647) $f6, VARCHAR(2147483647) payment_id, INTEGER price, VARCHAR(2147483647) user_id1, VARCHAR(2147483647) $f3)]) - : :- Exchange(distribution=[hash[user_id]]) + : :- Exchange(distribution=[hash[user_id, $f6]]) : : +- Calc(select=[user_id, name, cash, order_id, 
user_id0, product, UPPER(name) AS $f6]) : : +- MultiJoin(joinFilter=[true], joinTypes=[[INNER, LEFT]], joinConditions=[[true, =($4, $0)]], joinAttributeMap=[{0=[LeftInputId:-1;LeftFieldIndex:-1;RightInputId:0;RightFieldIndex:0;], 1=[LeftInputId:0;LeftFieldIndex:0;RightInputId:1;RightFieldIndex:1;]}], select=[user_id,name,cash,order_id,user_id0,product], rowType=[RecordType(VARCHAR(2147483647) user_id, VARCHAR(2147483647) name, INTEGER cash, VARCHAR(2147483647) order_id, VARCHAR(2147483647) user_id0, VARCHAR(2147483647) product)]) : : :- Exchange(distribution=[hash[user_id]]) @@ -374,7 +385,7 @@ Calc(select=[user_id, name, order_id, payment_id, location]) : : : +- TableSourceScan(table=[[default_catalog, default_database, Users]], fields=[user_id, name, cash]) : : +- Exchange(distribution=[hash[user_id]]) : : +- TableSourceScan(table=[[default_catalog, default_database, Orders]], fields=[order_id, user_id, product]) - : +- Exchange(distribution=[hash[user_id]]) + : +- Exchange(distribution=[hash[user_id, $f3]]) : +- Calc(select=[payment_id, price, user_id, UPPER(payment_id) AS $f3]) : +- TableSourceScan(table=[[default_catalog, default_database, Payments]], fields=[payment_id, price, user_id]) +- Exchange(distribution=[hash[location]]) @@ -404,21 +415,23 @@ Calc(select=[user_id, name, order_id, payment_id]) (price, 1000), 'High', >(price, 500), 'Medium', 'Low') AS price_tier, REGEXP_REPLACE(tags, ',', ' | ') AS formatted_tags, TO_TIMESTAMP_LTZ(created_date, 3) AS product_created, COALESCE(preferred_category, 'None') AS user_preference, CASE(=(notification_level, 'HIGH'), 'Frequent Updates', =(notification_level, 'MEDIUM'), 'Daily Updates', 'Weekly Updates') AS notification_frequency]) -+- MultiJoin(joinFilter=[true], joinTypes=[[INNER, LEFT, LEFT]], joinConditions=[[true, =($5, $9), =($0, $14)]], joinAttributeMap=[{0=[LeftInputId:-1;LeftFieldIndex:-1;RightInputId:0;RightFieldIndex:5;, LeftInputId:-1;LeftFieldIndex:-1;RightInputId:0;RightFieldIndex:0;], 
1=[LeftInputId:0;LeftFieldIndex:5;RightInputId:1;RightFieldIndex:0;], 2=[LeftInputId:0;LeftFieldIndex:0;RightInputId:2;RightFieldIndex:0;]}], select=[user_id,name,cash,order_id,user_id0,product,payment_id,price,user_id1,product_id,product_name,description,created_date,tags,user_id2,preferred_category,notification_level], rowType=[RecordType(VARCHAR(2147483647) user_id, VARCHAR(2147483647) name, INTEGER cash, VARCHAR(2147483647) order_id, VARCHAR(2147483647) user_id0, VARCHAR(2147483647) product, VARCHAR(2147483647) payment_id, INTEGER price, VARCHAR(2147483647) user_id1, VARCHAR(2147483647) product_id, VARCHAR(2147483647) product_name, VARCHAR(2147483647) description, BIGINT created_date, VARCHAR(2147483647) tags, VARCHAR(2147483647) user_id2, VARCHAR(2147483647) preferred_category, VARCHAR(2147483647) notification_level)]) - :- Exchange(distribution=[hash[user_id]]) - : +- MultiJoin(joinFilter=[true], joinTypes=[[INNER, LEFT, LEFT]], joinConditions=[[true, =($0, $4), =($0, $8)]], joinAttributeMap=[{0=[LeftInputId:-1;LeftFieldIndex:-1;RightInputId:0;RightFieldIndex:0;, LeftInputId:-1;LeftFieldIndex:-1;RightInputId:0;RightFieldIndex:0;], 1=[LeftInputId:0;LeftFieldIndex:0;RightInputId:1;RightFieldIndex:1;], 2=[LeftInputId:0;LeftFieldIndex:0;RightInputId:2;RightFieldIndex:2;]}], select=[user_id,name,cash,order_id,user_id0,product,payment_id,price,user_id1], rowType=[RecordType(VARCHAR(2147483647) user_id, VARCHAR(2147483647) name, INTEGER cash, VARCHAR(2147483647) order_id, VARCHAR(2147483647) user_id0, VARCHAR(2147483647) product, VARCHAR(2147483647) payment_id, INTEGER price, VARCHAR(2147483647) user_id1)]) - : :- Exchange(distribution=[hash[user_id]]) - : : +- ChangelogNormalize(key=[user_id]) ++- MultiJoin(joinFilter=[true], joinTypes=[[INNER, LEFT]], joinConditions=[[true, =($0, $14)]], joinAttributeMap=[{0=[LeftInputId:-1;LeftFieldIndex:-1;RightInputId:0;RightFieldIndex:0;], 1=[LeftInputId:0;LeftFieldIndex:0;RightInputId:1;RightFieldIndex:0;]}], 
select=[user_id,name,cash,order_id,user_id0,product,payment_id,price,user_id1,product_id,product_name,description,created_date,tags,user_id2,preferred_category,notification_level], rowType=[RecordType(VARCHAR(2147483647) user_id, VARCHAR(2147483647) name, INTEGER cash, VARCHAR(2147483647) order_id, VARCHAR(2147483647) user_id0, VARCHAR(2147483647) product, VARCHAR(2147483647) payment_id, INTEGER price, VARCHAR(2147483647) user_id1, VARCHAR(2147483647) product_id, VARCHAR(2147483647) product_name, VARCHAR(2147483647) description, BIGINT created_date, VARCHAR(2147483647) tags, VARCHAR(2147483647) user_id2, VARCHAR(2147483647) preferred_category, VARCHAR(2147483647) notification_level)]) + :- Exchange(distribution=[hash[user_id]]) + : +- MultiJoin(joinFilter=[true], joinTypes=[[INNER, LEFT]], joinConditions=[[true, =($5, $9)]], joinAttributeMap=[{0=[LeftInputId:-1;LeftFieldIndex:-1;RightInputId:0;RightFieldIndex:5;], 1=[LeftInputId:0;LeftFieldIndex:5;RightInputId:1;RightFieldIndex:0;]}], select=[user_id,name,cash,order_id,user_id0,product,payment_id,price,user_id1,product_id,product_name,description,created_date,tags], rowType=[RecordType(VARCHAR(2147483647) user_id, VARCHAR(2147483647) name, INTEGER cash, VARCHAR(2147483647) order_id, VARCHAR(2147483647) user_id0, VARCHAR(2147483647) product, VARCHAR(2147483647) payment_id, INTEGER price, VARCHAR(2147483647) user_id1, VARCHAR(2147483647) product_id, VARCHAR(2147483647) product_name, VARCHAR(2147483647) description, BIGINT created_date, VARCHAR(2147483647) tags)]) + : :- Exchange(distribution=[hash[product]]) + : : +- MultiJoin(joinFilter=[true], joinTypes=[[INNER, LEFT, LEFT]], joinConditions=[[true, =($0, $4), =($0, $8)]], joinAttributeMap=[{0=[LeftInputId:-1;LeftFieldIndex:-1;RightInputId:0;RightFieldIndex:0;, LeftInputId:-1;LeftFieldIndex:-1;RightInputId:0;RightFieldIndex:0;], 1=[LeftInputId:0;LeftFieldIndex:0;RightInputId:1;RightFieldIndex:1;], 
2=[LeftInputId:0;LeftFieldIndex:0;RightInputId:2;RightFieldIndex:2;]}], select=[user_id,name,cash,order_id,user_id0,product,payment_id,price,user_id1], rowType=[RecordType(VARCHAR(2147483647) user_id, VARCHAR(2147483647) name, INTEGER cash, VARCHAR(2147483647) order_id, VARCHAR(2147483647) user_id0, VARCHAR(2147483647) product, VARCHAR(2147483647) payment_id, INTEGER price, VARCHAR(2147483647) user_id1)]) + : : :- Exchange(distribution=[hash[user_id]]) + : : : +- ChangelogNormalize(key=[user_id]) + : : : +- Exchange(distribution=[hash[user_id]]) + : : : +- TableSourceScan(table=[[default_catalog, default_database, Users]], fields=[user_id, name, cash]) + : : :- Exchange(distribution=[hash[user_id]]) + : : : +- TableSourceScan(table=[[default_catalog, default_database, Orders]], fields=[order_id, user_id, product]) : : +- Exchange(distribution=[hash[user_id]]) - : : +- TableSourceScan(table=[[default_catalog, default_database, Users]], fields=[user_id, name, cash]) - : :- Exchange(distribution=[hash[user_id]]) - : : +- TableSourceScan(table=[[default_catalog, default_database, Orders]], fields=[order_id, user_id, product]) - : +- Exchange(distribution=[hash[user_id]]) - : +- TableSourceScan(table=[[default_catalog, default_database, Payments]], fields=[payment_id, price, user_id]) - :- Exchange(distribution=[hash[product_id]]) - : +- ChangelogNormalize(key=[product_id]) + : : +- TableSourceScan(table=[[default_catalog, default_database, Payments]], fields=[payment_id, price, user_id]) : +- Exchange(distribution=[hash[product_id]]) - : +- TableSourceScan(table=[[default_catalog, default_database, ProductDetails]], fields=[product_id, product_name, description, created_date, tags]) + : +- ChangelogNormalize(key=[product_id]) + : +- Exchange(distribution=[hash[product_id]]) + : +- TableSourceScan(table=[[default_catalog, default_database, ProductDetails]], fields=[product_id, product_name, description, created_date, tags]) +- Exchange(distribution=[hash[user_id]]) +- 
ChangelogNormalize(key=[user_id]) +- Exchange(distribution=[hash[user_id]]) @@ -533,6 +546,42 @@ MultiJoin(joinFilter=[=($0, $3)], joinTypes=[[INNER, INNER]], joinConditions=[[t +- Exchange(distribution=[hash[user_id]]) +- Calc(select=[payment_id, user_id, price], where=[>(price, 100)]) +- TableSourceScan(table=[[default_catalog, default_database, Payments, filter=[]]], fields=[payment_id, price, user_id]) +]]> + + + + + (b, 100)]) + +- TableSourceScan(table=[[default_catalog, default_database, src2, project=[b, c, d], metadata=[]]], fields=[b, c, d]) + +advice[1]: [ADVICE] You might want to enable local-global two-phase optimization by configuring ('table.exec.mini-batch.enabled' to 'true', 'table.exec.mini-batch.allow-latency' to a positive long value, 'table.exec.mini-batch.size' to a positive long value). +advice[2]: [WARNING] The column(s): day(generated by non-deterministic function: CURRENT_TIMESTAMP ) can not satisfy the determinism requirement for correctly processing update message('UB'/'UA'/'D' in changelogMode, not 'I' only), this usually happens when input node has no upsertKey(upsertKeys=[{}]) or current node outputs non-deterministic update messages. Please consider removing these non-deterministic columns or making them deterministic by using deterministic functions. 
+ +related rel plan: +Calc(select=[a, b, DATE_FORMAT(CURRENT_TIMESTAMP(), _UTF-16LE'yyMMdd') AS day], changelogMode=[I,UB,UA,D]) ++- TableSourceScan(table=[[default_catalog, default_database, src1, project=[a, b], metadata=[]]], fields=[a, b], changelogMode=[I,UB,UA,D]) + + ]]> @@ -755,11 +804,15 @@ Calc(select=[user_id, name, order_id, payment_id]) @@ -1063,53 +1116,21 @@ Calc(select=[product_id, -(price, discount) AS net_price, quantity, promo_text]) - - - - - (b, 100)]) - +- TableSourceScan(table=[[default_catalog, default_database, src2, project=[b, c, d], metadata=[]]], fields=[b, c, d]) - -advice[1]: [ADVICE] You might want to enable local-global two-phase optimization by configuring ('table.exec.mini-batch.enabled' to 'true', 'table.exec.mini-batch.allow-latency' to a positive long value, 'table.exec.mini-batch.size' to a positive long value). -advice[2]: [WARNING] The column(s): day(generated by non-deterministic function: CURRENT_TIMESTAMP ) can not satisfy the determinism requirement for correctly processing update message('UB'/'UA'/'D' in changelogMode, not 'I' only), this usually happens when input node has no upsertKey(upsertKeys=[{}]) or current node outputs non-deterministic update messages. Please consider removing these non-deterministic columns or making them deterministic by using deterministic functions. 
- -related rel plan: -Calc(select=[a, b, DATE_FORMAT(CURRENT_TIMESTAMP(), _UTF-16LE'yyMMdd') AS day], changelogMode=[I,UB,UA,D]) -+- TableSourceScan(table=[[default_catalog, default_database, src1, project=[a, b], metadata=[]]], fields=[a, b], changelogMode=[I,UB,UA,D]) - - ]]> diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/JoinITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/JoinITCase.scala index e55fa33070b56..fa5ef6a2752ca 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/JoinITCase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/JoinITCase.scala @@ -20,6 +20,7 @@ package org.apache.flink.table.planner.runtime.stream.sql import org.apache.flink.table.api._ import org.apache.flink.table.api.bridge.scala._ import org.apache.flink.table.api.config.ExecutionConfigOptions +import org.apache.flink.table.api.config.OptimizerConfigOptions import org.apache.flink.table.planner.expressions.utils.FuncWithOpen import org.apache.flink.table.planner.factories.TestValuesTableFactory import org.apache.flink.table.planner.runtime.utils._ @@ -623,6 +624,22 @@ class JoinITCase(miniBatch: MiniBatchMode, state: StateBackendMode, enableAsyncS assertThat(sink.getRetractResults.sorted).isEqualTo(expected.sorted) } + @TestTemplate + def testInnerMultiJoinWithEqualPk(): Unit = { + tEnv.getConfig.getConfiguration + .setString(OptimizerConfigOptions.TABLE_OPTIMIZER_MULTI_JOIN_ENABLED.key(), "true") + val query1 = "SELECT SUM(a2) AS a2, a1 FROM A group by a1" + val query2 = "SELECT SUM(b2) AS b2, b1 FROM B group by b1" + val query = s"SELECT a1, b1 FROM ($query1) JOIN ($query2) ON a1 = b1" + + val sink = new TestingRetractSink + tEnv.sqlQuery(query).toRetractStream[Row].addSink(sink).setParallelism(1) + env.execute() + + val expected = Seq("1,1", 
"2,2", "3,3") + assertThat(sink.getRetractResults.sorted).isEqualTo(expected.sorted) + } + @TestTemplate def testInnerJoinWithPk(): Unit = { val query1 = "SELECT SUM(a2) AS a2, a1 FROM A group by a1" diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/StreamingMultiJoinOperator.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/StreamingMultiJoinOperator.java index 3764b886d00d5..f92c43684d874 100644 --- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/StreamingMultiJoinOperator.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/StreamingMultiJoinOperator.java @@ -19,6 +19,7 @@ package org.apache.flink.table.runtime.operators.join.stream; import org.apache.flink.api.common.functions.DefaultOpenContext; +import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend; import org.apache.flink.streaming.api.operators.AbstractInput; import org.apache.flink.streaming.api.operators.AbstractStreamOperatorV2; import org.apache.flink.streaming.api.operators.Input; @@ -609,6 +610,10 @@ private Integer processRecords( } } + private boolean isHeapBackend() { + return getKeyedStateBackend() instanceof HeapKeyedStateBackend; + } + private static RowData newJoinedRowData(int depth, RowData joinedRowData, RowData record) { RowData newJoinedRowData; if (depth == 0) { @@ -808,6 +813,11 @@ private void initializeStateHandlers() { "Keyed state store not found when initializing keyed state store handlers."); } + boolean prohibitReuseRow = isHeapBackend(); + if (prohibitReuseRow) { + this.keyExtractor.requiresKeyDeepCopy(); + } + this.stateHandlers = new ArrayList<>(inputSpecs.size()); for (int i = 0; i < inputSpecs.size(); i++) { MultiJoinStateView stateView; diff --git 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/keyselector/AttributeBasedJoinKeyExtractor.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/keyselector/AttributeBasedJoinKeyExtractor.java index 10a2a9f3e0cc9..6de9e570adc6f 100644 --- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/keyselector/AttributeBasedJoinKeyExtractor.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/keyselector/AttributeBasedJoinKeyExtractor.java @@ -20,6 +20,7 @@ import org.apache.flink.table.data.GenericRowData; import org.apache.flink.table.data.RowData; +import org.apache.flink.table.runtime.typeutils.RowDataSerializer; import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.table.types.logical.RowType; import org.apache.flink.table.utils.NoCommonJoinKeyException; @@ -36,13 +37,23 @@ import java.util.Map; import java.util.Objects; import java.util.Set; -import java.util.stream.Collectors; /** - * A {@link JoinKeyExtractor} that derives keys based on {@link AttributeRef} mappings provided in - * {@code joinAttributeMap}. It defines how attributes from different input streams are related - * through equi-join conditions, assuming input 0 is the base and subsequent inputs join to - * preceding ones. + * A {@link JoinKeyExtractor} that derives keys from {@link AttributeRef} mappings in {@code + * joinAttributeMap}. It describes how attributes from different inputs are equated via equi-join + * conditions. Input 0 is the base; each subsequent input joins to one of the preceding inputs. + * + *

Example used throughout the comments: t1.id1 = t2.user_id2 and t3.user_id3 = t2.user_id2. All + * three attributes (t1.id1, t2.user_id2, t3.user_id3) represent the same conceptual key. + * + *

The {@code joinAttributeMap} for this example would be structured as follows: + * + *

*/ public class AttributeBasedJoinKeyExtractor implements JoinKeyExtractor, Serializable { private static final long serialVersionUID = 1L; @@ -50,11 +61,26 @@ public class AttributeBasedJoinKeyExtractor implements JoinKeyExtractor, Seriali private final Map> joinAttributeMap; private final List inputTypes; - // Cache for pre-computed key extraction structures - private final Map> inputIdToExtractorsMap; - private final Map> inputKeyFieldIndices; + // Cache for pre-computed key extraction structures. + // leftKeyExtractorsMap: extractors that read the left side (joined row so far) + // using the same attribute order as in joinAttributeMap. + private final Map> leftKeyExtractorsMap; + // RowData serializers for left key. + private final Map leftKeySerializersMap; + // rightKeyExtractorsMap: extractors to extract the right-side key from each input. + private final Map> rightKeyExtractorsMap; + // RowData serializers for right key. + private final Map rightKeySerializersMap; + + // Data structures for the "common join key" shared by all inputs. + // Input 0 provides the canonical order and defines commonJoinKeyType. private final Map> commonJoinKeyExtractors; + // RowData serializers for common join key. + private final Map commonJoinKeySerializersMap; private RowType commonJoinKeyType; + // Controls whether key rows built are serialized and copied. This is required for the + // Heap State Backend to prevent object reuse issues, which can lead to data corruption + private boolean requiresKeyDeepCopy; /** * Creates an AttributeBasedJoinKeyExtractor. 
@@ -70,18 +96,23 @@ public AttributeBasedJoinKeyExtractor( final List inputTypes) { this.joinAttributeMap = joinAttributeMap; this.inputTypes = inputTypes; - this.inputIdToExtractorsMap = new HashMap<>(); - this.inputKeyFieldIndices = new HashMap<>(); + this.leftKeyExtractorsMap = new HashMap<>(); + this.leftKeySerializersMap = new HashMap<>(); + this.rightKeyExtractorsMap = new HashMap<>(); + this.rightKeySerializersMap = new HashMap<>(); this.commonJoinKeyExtractors = new HashMap<>(); + this.commonJoinKeySerializersMap = new HashMap<>(); + this.requiresKeyDeepCopy = false; initializeCaches(); initializeCommonJoinKeyStructures(); + validateKeyStructures(); } // ==================== Public Interface Methods ==================== @Override - public RowData getJoinKey(RowData row, int inputId) { + public RowData getJoinKey(final RowData row, final int inputId) { if (inputId == 0) { return null; } @@ -91,50 +122,52 @@ public RowData getJoinKey(RowData row, int inputId) { return null; } - final List keyFieldIndices = inputKeyFieldIndices.get(inputId); - if (keyFieldIndices == null || keyFieldIndices.isEmpty()) { + final List keyExtractors = rightKeyExtractorsMap.get(inputId); + if (keyExtractors == null || keyExtractors.isEmpty()) { return null; } - return buildKeyRow(row, inputId, keyFieldIndices); + return buildKeyRowFromSourceRow(row, keyExtractors, rightKeySerializersMap.get(inputId)); } @Override - public RowData getLeftSideJoinKey(int depth, RowData joinedRowData) { + public RowData getLeftSideJoinKey(final int depth, final RowData joinedRowData) { if (depth == 0) { return null; } - List keyExtractors = inputIdToExtractorsMap.get(depth); + final List keyExtractors = leftKeyExtractorsMap.get(depth); if (keyExtractors == null || keyExtractors.isEmpty()) { return null; } - return buildKeyRow(keyExtractors, joinedRowData); + return buildKeyRowFromJoinedRow( + keyExtractors, joinedRowData, leftKeySerializersMap.get(depth)); } @Override @Nullable - public RowType 
getJoinKeyType(int inputId) { + public RowType getJoinKeyType(final int inputId) { if (inputId == 0) { return null; } - final List keyFieldIndices = createJoinKeyFieldInputExtractors(inputId); - if (keyFieldIndices.isEmpty()) { + final List keyExtractors = this.rightKeyExtractorsMap.get(inputId); + + if (keyExtractors == null || keyExtractors.isEmpty()) { return null; } - return buildJoinKeyType(inputId, keyFieldIndices); + return buildJoinKeyType(inputId, keyExtractors); } @Override - public int[] getJoinKeyIndices(int inputId) { - final List keyFieldIndices = this.inputKeyFieldIndices.get(inputId); + public int[] getJoinKeyIndices(final int inputId) { + final List keyFieldIndices = this.rightKeyExtractorsMap.get(inputId); if (keyFieldIndices == null) { return new int[0]; } - return keyFieldIndices.stream().mapToInt(i -> i).toArray(); + return keyFieldIndices.stream().mapToInt(KeyExtractor::getFieldIndexInSourceRow).toArray(); } @Override @@ -144,18 +177,18 @@ public RowType getCommonJoinKeyType() { } @Override - public @Nullable RowData getCommonJoinKey(RowData row, int inputId) { - List extractors = commonJoinKeyExtractors.get(inputId); + public @Nullable RowData getCommonJoinKey(final RowData row, final int inputId) { + final List extractors = commonJoinKeyExtractors.get(inputId); if (extractors == null || extractors.isEmpty()) { return null; } - return buildCommonJoinKey(row, extractors); + return buildKeyRowFromSourceRow(row, extractors, commonJoinKeySerializersMap.get(inputId)); } @Override - public int[] getCommonJoinKeyIndices(int inputId) { - List extractors = commonJoinKeyExtractors.get(inputId); + public int[] getCommonJoinKeyIndices(final int inputId) { + final List extractors = commonJoinKeyExtractors.get(inputId); if (extractors == null || extractors.isEmpty()) { return new int[0]; } @@ -163,60 +196,108 @@ public int[] getCommonJoinKeyIndices(int inputId) { return extractors.stream().mapToInt(KeyExtractor::getFieldIndexInSourceRow).toArray(); } + 
@Override + public void requiresKeyDeepCopy() { + this.requiresKeyDeepCopy = true; + } + // ==================== Initialization Methods ==================== private void initializeCaches() { if (this.inputTypes != null) { for (int i = 0; i < this.inputTypes.size(); i++) { - this.inputIdToExtractorsMap.put(i, createLeftJoinKeyFieldExtractors(i)); - this.inputKeyFieldIndices.put(i, createJoinKeyFieldInputExtractors(i)); + List leftJoinKeyExtractors = createLeftJoinKeyFieldExtractors(i); + this.leftKeyExtractorsMap.put(i, leftJoinKeyExtractors); + this.leftKeySerializersMap.put(i, createJoinKeySerializer(leftJoinKeyExtractors)); + + List rightJoinKeyExtractors = createRightJoinKeyExtractors(i); + this.rightKeyExtractorsMap.put(i, rightJoinKeyExtractors); + this.rightKeySerializersMap.put(i, createJoinKeySerializer(rightJoinKeyExtractors)); } } } - private List createLeftJoinKeyFieldExtractors(int depth) { - if (depth == 0) { + private List createLeftJoinKeyFieldExtractors(final int inputId) { + if (inputId == 0) { return Collections.emptyList(); } - List attributeMapping = joinAttributeMap.get(depth); + final List attributeMapping = joinAttributeMap.get(inputId); if (attributeMapping == null || attributeMapping.isEmpty()) { return Collections.emptyList(); } - List keyExtractors = new ArrayList<>(); - for (ConditionAttributeRef entry : attributeMapping) { - AttributeRef leftAttrRef = getLeftAttributeRef(depth, entry); + final List keyExtractors = new ArrayList<>(); + for (final ConditionAttributeRef entry : attributeMapping) { + final AttributeRef leftAttrRef = getLeftAttributeRef(inputId, entry); keyExtractors.add(createKeyExtractor(leftAttrRef)); } - keyExtractors.sort( - Comparator.comparingInt(KeyExtractor::getInputIdToAccess) - .thenComparingInt(KeyExtractor::getFieldIndexInSourceRow)); return keyExtractors; } - private static AttributeRef getLeftAttributeRef(int depth, ConditionAttributeRef entry) { - AttributeRef leftAttrRef = new AttributeRef(entry.leftInputId, 
entry.leftFieldIndex); - if (leftAttrRef.inputId >= depth) { + private List createRightJoinKeyExtractors(final int inputId) { + final List attributeMapping = joinAttributeMap.get(inputId); + if (attributeMapping == null) { + return Collections.emptyList(); + } + + final List keyExtractors = new ArrayList<>(); + for (final ConditionAttributeRef entry : attributeMapping) { + final AttributeRef rightAttrRef = getRightAttributeRef(inputId, entry); + keyExtractors.add(createKeyExtractor(rightAttrRef)); + } + + return keyExtractors; + } + + private RowDataSerializer createJoinKeySerializer(List keyExtractors) { + return new RowDataSerializer( + keyExtractors.stream().map(e -> e.fieldType).toArray(LogicalType[]::new)); + } + + private static AttributeRef getLeftAttributeRef( + final int inputId, final ConditionAttributeRef entry) { + final AttributeRef leftAttrRef = new AttributeRef(entry.leftInputId, entry.leftFieldIndex); + if (leftAttrRef.inputId >= inputId) { throw new IllegalStateException( - "Invalid joinAttributeMap configuration for depth " - + depth + "Invalid joinAttributeMap configuration for inputId " + + inputId + ". Left attribute " + leftAttrRef + " does not reference a previous input (< " - + depth + + inputId + ")."); } return leftAttrRef; } - private KeyExtractor createKeyExtractor(AttributeRef attrRef) { - RowType rowType = inputTypes.get(attrRef.inputId); + private static AttributeRef getRightAttributeRef( + final int inputId, final ConditionAttributeRef entry) { + final AttributeRef rightAttrRef = + new AttributeRef(entry.rightInputId, entry.rightFieldIndex); + // For a given join step (e.g., input 2 joining to a previous input), the "right" + // attribute in the join condition belongs to the current input (input 2). + if (rightAttrRef.inputId != inputId) { + throw new IllegalStateException( + "Invalid joinAttributeMap configuration for inputId " + + inputId + + ". 
Right attribute " + + rightAttrRef + + " must reference the current input (" + + inputId + + ")."); + } + return rightAttrRef; + } + + private KeyExtractor createKeyExtractor(final AttributeRef attrRef) { + final RowType rowType = inputTypes.get(attrRef.inputId); validateFieldIndex(attrRef.inputId, attrRef.fieldIndex, rowType); - LogicalType fieldType = rowType.getTypeAt(attrRef.fieldIndex); + final LogicalType fieldType = rowType.getTypeAt(attrRef.fieldIndex); - // Calculate absolute field index by summing up field counts of previous inputs + // Absolute field index within the concatenated joined row = + // current field index + sum(field counts of all previous inputs). int absoluteFieldIndex = attrRef.fieldIndex; for (int i = 0; i < attrRef.inputId; i++) { absoluteFieldIndex += inputTypes.get(i).getFieldCount(); @@ -225,71 +306,57 @@ private KeyExtractor createKeyExtractor(AttributeRef attrRef) { return new KeyExtractor(attrRef.inputId, attrRef.fieldIndex, absoluteFieldIndex, fieldType); } - private List createJoinKeyFieldInputExtractors(int inputId) { - final List attributeMapping = joinAttributeMap.get(inputId); - if (attributeMapping == null) { - return Collections.emptyList(); - } - - return attributeMapping.stream() - .filter(rightAttrRef -> rightAttrRef.rightInputId == inputId) - .map(rightAttrRef -> rightAttrRef.rightFieldIndex) - .distinct() - .sorted() - .collect(Collectors.toList()); - } - // ==================== Key Building Methods ==================== - private RowData buildKeyRow(List keyExtractors, RowData joinedRowData) { + private RowData buildKeyRowFromJoinedRow( + final List keyExtractors, + final RowData joinedRowData, + RowDataSerializer keySerializer) { if (keyExtractors.isEmpty()) { return null; } - GenericRowData keyRow = new GenericRowData(keyExtractors.size()); + final GenericRowData keyRow = new GenericRowData(keyExtractors.size()); for (int i = 0; i < keyExtractors.size(); i++) { keyRow.setField(i, 
keyExtractors.get(i).getLeftSideKey(joinedRowData)); } - return keyRow; - } - - private GenericRowData buildKeyRow( - RowData sourceRow, int inputId, List keyFieldIndices) { - final GenericRowData keyRow = new GenericRowData(keyFieldIndices.size()); - final RowType rowType = inputTypes.get(inputId); - - for (int i = 0; i < keyFieldIndices.size(); i++) { - final int fieldIndex = keyFieldIndices.get(i); - validateFieldIndex(inputId, fieldIndex, rowType); - - final LogicalType fieldType = rowType.getTypeAt(fieldIndex); - final RowData.FieldGetter fieldGetter = - RowData.createFieldGetter(fieldType, fieldIndex); - final Object value = fieldGetter.getFieldOrNull(sourceRow); - keyRow.setField(i, value); + if (requiresKeyDeepCopy) { + return keySerializer.toBinaryRow(keyRow, true); + } else { + return keyRow; } - return keyRow; } - private RowData buildCommonJoinKey(RowData row, List extractors) { - GenericRowData commonJoinKeyRow = new GenericRowData(extractors.size()); + private RowData buildKeyRowFromSourceRow( + final RowData sourceRow, + final List keyExtractors, + RowDataSerializer keySerializer) { + if (keyExtractors.isEmpty()) { + return null; + } - for (int i = 0; i < extractors.size(); i++) { - commonJoinKeyRow.setField(i, extractors.get(i).getRightSideKey(row)); + final GenericRowData keyRow = new GenericRowData(keyExtractors.size()); + for (int i = 0; i < keyExtractors.size(); i++) { + keyRow.setField(i, keyExtractors.get(i).getRightSideKey(sourceRow)); + } + if (requiresKeyDeepCopy) { + return keySerializer.toBinaryRow(keyRow, true); + } else { + return keyRow; } - return commonJoinKeyRow; } - private RowType buildJoinKeyType(int inputId, List keyFieldIndices) { + private RowType buildJoinKeyType(final int inputId, final List keyExtractors) { final RowType originalRowType = inputTypes.get(inputId); - final LogicalType[] keyTypes = new LogicalType[keyFieldIndices.size()]; - final String[] keyNames = new String[keyFieldIndices.size()]; + final LogicalType[] 
keyTypes = new LogicalType[keyExtractors.size()]; + final String[] keyNames = new String[keyExtractors.size()]; - for (int i = 0; i < keyFieldIndices.size(); i++) { - final int fieldIndex = keyFieldIndices.get(i); + for (int i = 0; i < keyExtractors.size(); i++) { + final KeyExtractor extractor = keyExtractors.get(i); + final int fieldIndex = extractor.getFieldIndexInSourceRow(); validateFieldIndex(inputId, fieldIndex, originalRowType); - keyTypes[i] = originalRowType.getTypeAt(fieldIndex); + keyTypes[i] = extractor.fieldType; keyNames[i] = originalRowType.getFieldNames().get(fieldIndex) + "_key"; } @@ -298,6 +365,34 @@ private RowType buildJoinKeyType(int inputId, List keyFieldIndices) { // ==================== Common Key Methods ==================== + /** + * Builds the data structures that describe the common join key shared by all inputs. + * + *

Algorithm: + * + *

    + *
  1. Collect all attributes referenced by any equality condition. + *
  2. Run union-find: for each equi-join condition, union the two attributes so that directly + * or transitively equal attributes end up in the same set. + *
  3. Form equivalence sets by grouping attributes that share the same union-find root. + *
  4. Keep only those sets that touch every join step (see {@link + * #isCommonConceptualAttributeSet(Set)}). Each kept set is one conceptual key common to + * all inputs. + *
  5. For each input, pick that input's attributes from the kept sets (in a consistent order) + * and create per-input key extractors. For input 0, also build the {@link RowType} of the + * common key. + *
+ * + *

Example (same throughout): + * + *

    + *
  1. Join conditions: t1.id1 = t2.user_id2 and t3.user_id3 = t2.user_id2. + *
  2. Union-find groups attributes into { (t1,id1), (t2,user_id2), (t3,user_id3) }. + *
  3. The group touches both steps (t2↔t1 and t3↔t2), so it is the common key. + *
  4. Create extractors for t1.id1, t2.user_id2, and t3.user_id3. + *
  5. Define a one-field commonJoinKeyType based on t1.id1. + *
+ */ private void initializeCommonJoinKeyStructures() { this.commonJoinKeyType = null; @@ -312,28 +407,43 @@ private void initializeCommonJoinKeyStructures() { return; } - Map parent = new HashMap<>(); - Map rank = new HashMap<>(); - Set allAttrRefs = collectAllAttributeRefs(); + // Maps an attribute to the representative of its equivalence set. + final Map attributeToRoot = new HashMap<>(); + // Used for the union-by-rank optimization to keep the union-find tree shallow. + final Map rootRank = new HashMap<>(); + final Set allAttrRefs = collectAllAttributeRefs(); if (allAttrRefs.isEmpty()) { return; } - initializeDisjointSets(parent, rank, allAttrRefs); - processJoinConditions(parent, rank); - Map> equivalenceSets = - buildEquivalenceSets(parent, allAttrRefs); - List> commonConceptualAttributeSets = - findCommonConceptualAttributeSets(equivalenceSets); + initializeDisjointSets(attributeToRoot, rootRank, allAttrRefs); + findAttributeRoots(attributeToRoot, rootRank); + final Map> equivalenceSets = + buildEquivalenceSets(attributeToRoot, allAttrRefs); - processCommonAttributes(commonConceptualAttributeSets); + // From all equivalence sets, keep only those that touch every join step. + final List> commonConceptualAttributeSets = + findCommonConceptualAttributeSets(equivalenceSets); + processCommonAttributes(commonConceptualAttributeSets, attributeToRoot); } + /** + * Returns every attribute reference (input id + field index) present in the configured equality + * conditions. + * + *

Example: with t1.id1 = t2.user_id2 and t3.user_id3 = t2.user_id2, this returns { (t1,id1), + * (t2,user_id2), (t3,user_id3) }. + */ private Set collectAllAttributeRefs() { - Set allAttrRefs = new HashSet<>(); - for (List mapping : joinAttributeMap.values()) { - for (ConditionAttributeRef attrRef : mapping) { + final Set allAttrRefs = new HashSet<>(); + for (final Map.Entry> entry : + joinAttributeMap.entrySet()) { + // Skip joinAttributeMap for key 0 where there are no join conditions to the left + if (entry.getKey() == 0) { + continue; + } + for (final ConditionAttributeRef attrRef : entry.getValue()) { allAttrRefs.add(new AttributeRef(attrRef.leftInputId, attrRef.leftFieldIndex)); allAttrRefs.add(new AttributeRef(attrRef.rightInputId, attrRef.rightFieldIndex)); } @@ -341,43 +451,70 @@ private Set collectAllAttributeRefs() { return allAttrRefs; } + /** + * Initializes union-find: every attribute is the root of its own set with rank 0. + * + *

Example: attributes {(t1,id1), (t2,user_id2), (t3,user_id3)} start with attributeToRoot[a] + * = a. + */ private void initializeDisjointSets( - Map parent, - Map rank, - Set allAttrRefs) { - for (AttributeRef attrRef : allAttrRefs) { - parent.put(attrRef, attrRef); - rank.put(attrRef, 0); + final Map attributeToRoot, + final Map rootRank, + final Set allAttrRefs) { + for (final AttributeRef attrRef : allAttrRefs) { + attributeToRoot.put(attrRef, attrRef); + rootRank.put(attrRef, 0); } } - private void processJoinConditions( - Map parent, Map rank) { - for (List mapping : joinAttributeMap.values()) { - for (ConditionAttributeRef condition : mapping) { + /** + * Applies all equi-join conditions to union-find by uniting the corresponding attribute sets. + * + *

Example: unite (t2,user_id2) with (t1,id1), then (t3,user_id3) with (t2,user_id2), + * yielding one set { (t1,id1), (t2,user_id2), (t3,user_id3) }. + */ + private void findAttributeRoots( + final Map attributeToRoot, + final Map rootRank) { + for (final Map.Entry> entry : + joinAttributeMap.entrySet()) { + // Skip joinAttributeMap for key 0 where there are no join conditions to the left + if (entry.getKey() == 0) { + continue; + } + for (final ConditionAttributeRef condition : entry.getValue()) { unionAttributeSets( - parent, - rank, + attributeToRoot, + rootRank, new AttributeRef(condition.leftInputId, condition.leftFieldIndex), new AttributeRef(condition.rightInputId, condition.rightFieldIndex)); } } } + /** Converts the union-find forest into a map from root → full equivalence set. */ private Map> buildEquivalenceSets( - Map parent, Set allAttrRefs) { - Map> equivalenceSets = new HashMap<>(); - for (AttributeRef attrRef : allAttrRefs) { - AttributeRef root = findAttributeSet(parent, attrRef); + final Map attributeToRoot, + final Set allAttrRefs) { + final Map> equivalenceSets = new HashMap<>(); + for (final AttributeRef attrRef : allAttrRefs) { + final AttributeRef root = findAttributeSet(attributeToRoot, attrRef); equivalenceSets.computeIfAbsent(root, k -> new HashSet<>()).add(attrRef); } return equivalenceSets; } + /** + * From all equivalence sets, keep only those that intersect every join step. A set is "common" + * only if each step has at least one attribute in the set. + * + *

Example: with steps (t2.user_id2 = t1.id1) and (t3.user_id3 = t2.user_id2), { (t1,id1), + * (t2,user_id2), (t3,user_id3) } is kept; a set touching only one step is discarded. + */ private List> findCommonConceptualAttributeSets( - Map> equivalenceSets) { - List> commonConceptualAttributeSets = new ArrayList<>(); - for (Set eqSet : equivalenceSets.values()) { + final Map> equivalenceSets) { + final List> commonConceptualAttributeSets = new ArrayList<>(); + for (final Set eqSet : equivalenceSets.values()) { if (isCommonConceptualAttributeSet(eqSet)) { commonConceptualAttributeSets.add(eqSet); } @@ -385,18 +522,25 @@ private List> findCommonConceptualAttributeSets( return commonConceptualAttributeSets; } - private boolean isCommonConceptualAttributeSet(Set eqSet) { + /** + * Returns true if the given equivalence set intersects every configured join step. For each + * step, we check whether the set contains any left or right attribute of that step. + * + *

Example: with steps (t2.user_id2 = t1.id1) and (t3.user_id3 = t2.user_id2): - { (t1,id1), + * (t2,user_id2), (t3,user_id3) } → true - { (t1,other), (t2,other) } → false + */ + private boolean isCommonConceptualAttributeSet(final Set eqSet) { if (joinAttributeMap.isEmpty()) { return false; } - for (List conditionsForStep : joinAttributeMap.values()) { + for (final List conditionsForStep : joinAttributeMap.values()) { if (conditionsForStep.isEmpty()) { return false; } boolean foundInThisStep = false; - for (ConditionAttributeRef condition : conditionsForStep) { + for (final ConditionAttributeRef condition : conditionsForStep) { if (eqSet.contains( new AttributeRef(condition.leftInputId, condition.leftFieldIndex)) || eqSet.contains( @@ -413,10 +557,20 @@ private boolean isCommonConceptualAttributeSet(Set eqSet) { return true; } - private void processCommonAttributes(List> commonConceptualAttributeSets) { + /** + * For each input, select its attributes that belong to the common sets and initialize + * extractors. If any input contributes none, the multi-join is invalid (exception). + * + *

Example: from { (t1,id1), (t2,user_id2), (t3,user_id3) }, input 0 uses id1, input 1 uses + * user_id2, input 2 uses user_id3. If some input had none, we would throw an exception. + */ + private void processCommonAttributes( + final List> commonConceptualAttributeSets, + final Map attributeToRoot) { for (int currentInputId = 0; currentInputId < inputTypes.size(); currentInputId++) { final List commonAttrsForThisInput = - findCommonAttributesForInput(currentInputId, commonConceptualAttributeSets); + findCommonAttributesForInput( + currentInputId, commonConceptualAttributeSets, attributeToRoot); if (commonAttrsForThisInput.isEmpty()) { // This indicates that there is no common join key among all inputs. @@ -433,40 +587,70 @@ private void processCommonAttributes(List> commonConceptualAtt } } + /** + * For a given input, pick at most one attribute per common set (first match), then sort by + * field index. This defines the field order of the common key for that input. + * + *

Example: if input 1 participates in two common sets with fields user_id and account_id, + * this returns [account_id, user_id] if account_id's index < user_id's index. + */ private List findCommonAttributesForInput( - int currentInputId, List> commonConceptualAttributeSets) { - List commonAttrsForThisInput = new ArrayList<>(); - for (Set eqSet : commonConceptualAttributeSets) { - for (AttributeRef attrRef : eqSet) { + final int currentInputId, + final List> commonConceptualAttributeSets, + final Map attributeToRoot) { + final List commonAttrsForThisInput = new ArrayList<>(); + for (final Set eqSet : commonConceptualAttributeSets) { + for (final AttributeRef attrRef : eqSet) { if (attrRef.inputId == currentInputId) { commonAttrsForThisInput.add(attrRef); break; } } } - commonAttrsForThisInput.sort(Comparator.comparingInt(attr -> attr.fieldIndex)); + + // Important: ensure a consistent conceptual attribute ordering derived from roots. + // The common key fields must have a canonical order across all inputs. This is + // achieved by sorting based on the properties of the union-find root attribute + // for each conceptual key set. We sort first by the root's field index and then + // by the current attribute's field index as a tie-breaker for stability. This + // ensures that input 0's attribute order defines the canonical order. + commonAttrsForThisInput.sort( + Comparator.comparingInt( + a -> { + AttributeRef root = attributeToRoot.get(a); + return root != null ? root.fieldIndex : -1; + }) + // This is for stable ordering when two roots happen to have the same + // fieldIndex + .thenComparingInt(a -> a.fieldIndex)); + return commonAttrsForThisInput; } + /** + * Creates extractor objects and key type metadata for an input. Input 0 defines {@code + * commonJoinKeyType} shared across all inputs. + * + *

Example: for input 1 with common attributes [user_id], we create an extractor for that + * field and name it "user_id_common" with its logical type. + */ private void processInputCommonAttributes( - int currentInputId, List commonAttrsForThisInput) { + final int currentInputId, final List commonAttrsForThisInput) { final List extractors = new ArrayList<>(); - final LogicalType[] keyFieldTypes = new LogicalType[commonAttrsForThisInput.size()]; final String[] keyFieldNames = new String[commonAttrsForThisInput.size()]; final RowType originalRowType = inputTypes.get(currentInputId); for (int i = 0; i < commonAttrsForThisInput.size(); i++) { final AttributeRef attr = commonAttrsForThisInput.get(i); - validateFieldIndex(currentInputId, attr.fieldIndex, originalRowType); - final LogicalType fieldType = originalRowType.getTypeAt(attr.fieldIndex); - extractors.add( - new KeyExtractor(currentInputId, attr.fieldIndex, attr.fieldIndex, fieldType)); - keyFieldTypes[i] = fieldType; + extractors.add(createKeyExtractor(attr)); keyFieldNames[i] = originalRowType.getFieldNames().get(attr.fieldIndex) + "_common"; } this.commonJoinKeyExtractors.put(currentInputId, extractors); + final LogicalType[] keyFieldTypes = + extractors.stream().map(e -> e.fieldType).toArray(LogicalType[]::new); + this.commonJoinKeySerializersMap.put(currentInputId, new RowDataSerializer(keyFieldTypes)); if (currentInputId == 0 && !extractors.isEmpty()) { this.commonJoinKeyType = RowType.of(keyFieldTypes, keyFieldNames); } @@ -474,7 +658,8 @@ private void processInputCommonAttributes( // ==================== Helper Methods ==================== - private void validateFieldIndex(int inputId, int fieldIndex, RowType rowType) { + private void validateFieldIndex( + final int inputId, final int fieldIndex, final RowType rowType) { if (fieldIndex >= rowType.getFieldCount() || fieldIndex < 0) { throw new IndexOutOfBoundsException( "joinAttributeMap references field index " @@ -487,39 +672,179 @@ private void 
validateFieldIndex(int inputId, int fieldIndex, RowType rowType) { } private static AttributeRef findAttributeSet( - Map parent, AttributeRef item) { - if (!parent.get(item).equals(item)) { - parent.put(item, findAttributeSet(parent, parent.get(item))); + final Map attributeToRoot, final AttributeRef item) { + if (!attributeToRoot.get(item).equals(item)) { + attributeToRoot.put(item, findAttributeSet(attributeToRoot, attributeToRoot.get(item))); } - return parent.get(item); + return attributeToRoot.get(item); } private static void unionAttributeSets( - Map parent, - Map rank, - AttributeRef a, - AttributeRef b) { - AttributeRef rootA = findAttributeSet(parent, a); - AttributeRef rootB = findAttributeSet(parent, b); - + final Map attributeToRoot, + final Map rootRank, + final AttributeRef a, + final AttributeRef b) { + final AttributeRef rootA = findAttributeSet(attributeToRoot, a); + final AttributeRef rootB = findAttributeSet(attributeToRoot, b); + + // Standard union-by-rank implementation to merge sets. The resulting root + // depends on the rank of the sets, not on attribute properties like inputId. + // A subsequent sorting step establishes a canonical ordering for the common key fields. + // Following this logic, the root will be the common join key attributes from input 0 if (!rootA.equals(rootB)) { - if (rank.get(rootA) < rank.get(rootB)) { - parent.put(rootA, rootB); - } else if (rank.get(rootA) > rank.get(rootB)) { - parent.put(rootB, rootA); + if (rootRank.get(rootA) < rootRank.get(rootB)) { + attributeToRoot.put(rootA, rootB); + } else if (rootRank.get(rootA) > rootRank.get(rootB)) { + attributeToRoot.put(rootB, rootA); } else { - parent.put(rootB, rootA); - rank.put(rootA, rank.get(rootA) + 1); + attributeToRoot.put(rootB, rootA); + rootRank.put(rootA, rootRank.get(rootA) + 1); } } } + // ==================== Validation Methods ==================== + + /** + * Validates internal key structures for consistency: + * + *

    + *
  1. For every input id except 0, the number of left-side extractors equals the number of + * right-side extractors (input 0 must have no left-side extractors).
  2. If a common join key is defined, then for every input id the number of common key + * extractors equals the number of fields in {@code commonJoinKeyType}. + *
  3. If a common join key is defined, each common key extractor's type root matches that of + * the corresponding field in {@code commonJoinKeyType}.
+ * + *

Throws {@link IllegalStateException} if any inconsistency is found. + */ + public void validateKeyStructures() { + final int numInputs = inputTypes == null ? 0 : inputTypes.size(); + + for (int inputId = 0; inputId < numInputs; inputId++) { + if (inputId == 0) { + if (!leftKeyExtractorsMap.get(inputId).isEmpty()) { + throw new IllegalStateException( + "Input 0 should not have left key extractors, but found left extractors " + + leftKeyExtractorsMap.get(inputId) + + "."); + } + + // We skip validating the extracted keys type equality for input 0 + // because it has no left-side extractors. + continue; + } + + final List leftExtractors = leftKeyExtractorsMap.get(inputId); + final List rightExtractors = rightKeyExtractorsMap.get(inputId); + + final int extractorsLength = + validateExtractorsLength(leftExtractors, rightExtractors, inputId); + + for (int j = 0; j < extractorsLength; j++) { + final KeyExtractor rightExtractor = rightExtractors.get(j); + final KeyExtractor leftExtractor = leftExtractors.get(j); + + if (leftExtractor == null || rightExtractor == null) { + throw new IllegalStateException( + "Null extractor found when validating key structures for field " + + j + + " on input " + + inputId + + ": left extractor " + + leftExtractor + + ", right extractor " + + rightExtractor + + "."); + } + + final LogicalType leftType = leftExtractor.fieldType; + final LogicalType rightType = rightExtractor.fieldType; + if (!Objects.equals(leftType.getTypeRoot(), rightType.getTypeRoot())) { + throw new IllegalStateException( + "Type mismatch for join key field " + + j + + " on input " + + inputId + + ": left type " + + leftType.getTypeRoot() + + " vs right type " + + rightType.getTypeRoot() + + "."); + } + } + } + + if (this.commonJoinKeyType != null) { + final int expectedCommonFields = this.commonJoinKeyType.getFieldCount(); + + for (int inputId = 0; inputId < numInputs; inputId++) { + final List commonExtractors = commonJoinKeyExtractors.get(inputId); + final int actual = 
commonExtractors == null ? 0 : commonExtractors.size(); + + if (actual != expectedCommonFields) { + throw new IllegalStateException( + "Mismatch in common key counts for input " + + inputId + + ": extractors (" + + actual + + ") vs commonJoinKeyType fields (" + + expectedCommonFields + + ")."); + } + + for (int i = 0; i < expectedCommonFields; i++) { + final LogicalType extractorType = commonExtractors.get(i).fieldType; + final LogicalType expectedType = this.commonJoinKeyType.getTypeAt(i); + + if (!Objects.equals(extractorType.getTypeRoot(), expectedType.getTypeRoot())) { + throw new IllegalStateException( + "Type mismatch for common key field " + + i + + " on input " + + inputId + + ": extractor type " + + extractorType.getTypeRoot() + + " vs commonJoinKeyType " + + expectedType.getTypeRoot() + + "."); + } + } + } + } + } + + private static int validateExtractorsLength( + final List leftExtractors, + final List rightExtractors, + final int inputId) { + final int leftSize = leftExtractors == null ? 0 : leftExtractors.size(); + final int rightSize = rightExtractors == null ? 0 : rightExtractors.size(); + + if (leftSize != rightSize) { + throw new IllegalStateException( + "Mismatch in key counts for input " + + inputId + + ": left extractors (" + + leftSize + + ") vs right extractors (" + + rightSize + + ")."); + } + return leftSize; + } + // ==================== Inner Classes ==================== - /** Helper class to store pre-computed information for extracting a key part. */ - // TODO we actually need int[] for the indices - // because we can have multiple common join keys as fields - // this whole file will be refactored in a next ticket + /** + * Helper class to store pre-computed information for extracting a key part. + * + *

This class uses two separate {@link RowData.FieldGetter} instances because a key extractor + * may need to extract a field from the left (already joined) side, which is a wide, + * concatenated row, or the right (current input) side, which is a single source row. The field + * indices for these two cases are different. + */ private static final class KeyExtractor implements Serializable { private static final long serialVersionUID = 1L; @@ -527,41 +852,57 @@ private static final class KeyExtractor implements Serializable { private final int fieldIndexInSourceRow; private final int absoluteFieldIndex; private final LogicalType fieldType; - private transient RowData.FieldGetter fieldGetter; + private transient RowData.FieldGetter leftFieldGetter; + private transient RowData.FieldGetter rightFieldGetter; public KeyExtractor( - int inputIdToAccess, - int fieldIndexInSourceRow, - int absoluteFieldIndex, - LogicalType fieldType) { + final int inputIdToAccess, + final int fieldIndexInSourceRow, + final int absoluteFieldIndex, + final LogicalType fieldType) { this.inputIdToAccess = inputIdToAccess; this.fieldIndexInSourceRow = fieldIndexInSourceRow; this.absoluteFieldIndex = absoluteFieldIndex; this.fieldType = fieldType; - this.fieldGetter = + this.rightFieldGetter = RowData.createFieldGetter(this.fieldType, this.fieldIndexInSourceRow); - } - - public Object getRightSideKey(RowData joinedRowData) { - if (joinedRowData == null) { + this.leftFieldGetter = + RowData.createFieldGetter(this.fieldType, this.absoluteFieldIndex); + } + + /** + * Returns the key value from a single input row (right/current side of the current join + * step). + * + *

Example: for a step joining t1 (input 0) and t2 (input 1) t1.id1 = t2.user_id2, the + * extractor configured for t2 (input 1) returns t2.user_id2 value from the provided t2 row. + */ + public Object getRightSideKey(final RowData row) { + if (row == null) { return null; } - if (this.fieldGetter == null) { - this.fieldGetter = + if (this.rightFieldGetter == null) { + this.rightFieldGetter = RowData.createFieldGetter(this.fieldType, this.fieldIndexInSourceRow); } - return this.fieldGetter.getFieldOrNull(joinedRowData); + return this.rightFieldGetter.getFieldOrNull(row); } - public Object getLeftSideKey(RowData joinedRowData) { + /** + * Returns the key value from the already-joined row (left side of the current join step). + * + *

Example: for a step joining t1 (input 0) and t2 (input 1) t1.id1 = t2.user_id2, the + * extractor configured for t2 (input 1) returns t1.id1 value from the provided t1 row. + */ + public Object getLeftSideKey(final RowData joinedRowData) { if (joinedRowData == null) { return null; } - if (this.fieldGetter == null) { - this.fieldGetter = + if (this.leftFieldGetter == null) { + this.leftFieldGetter = RowData.createFieldGetter(this.fieldType, this.absoluteFieldIndex); } - return this.fieldGetter.getFieldOrNull(joinedRowData); + return this.leftFieldGetter.getFieldOrNull(joinedRowData); } public int getInputIdToAccess() { @@ -572,12 +913,14 @@ public int getFieldIndexInSourceRow() { return fieldIndexInSourceRow; } - private void readObject(java.io.ObjectInputStream in) + private void readObject(final java.io.ObjectInputStream in) throws java.io.IOException, ClassNotFoundException { in.defaultReadObject(); if (this.fieldType != null) { - this.fieldGetter = + this.rightFieldGetter = RowData.createFieldGetter(this.fieldType, this.fieldIndexInSourceRow); + this.leftFieldGetter = + RowData.createFieldGetter(this.fieldType, this.absoluteFieldIndex); } } } @@ -591,20 +934,20 @@ public AttributeRef() { // Default constructor for deserialization } - public AttributeRef(int inputId, int fieldIndex) { + public AttributeRef(final int inputId, final int fieldIndex) { this.inputId = inputId; this.fieldIndex = fieldIndex; } @Override - public boolean equals(Object o) { + public boolean equals(final Object o) { if (this == o) { return true; } if (o == null || getClass() != o.getClass()) { return false; } - AttributeRef that = (AttributeRef) o; + final AttributeRef that = (AttributeRef) o; return inputId == that.inputId && fieldIndex == that.fieldIndex; } @@ -631,7 +974,10 @@ public ConditionAttributeRef() { } public ConditionAttributeRef( - int leftInputId, int leftFieldIndex, int rightInputId, int rightFieldIndex) { + final int leftInputId, + final int leftFieldIndex, + final int 
rightInputId, + final int rightFieldIndex) { this.leftInputId = leftInputId; this.leftFieldIndex = leftFieldIndex; this.rightInputId = rightInputId; @@ -639,14 +985,14 @@ public ConditionAttributeRef( } @Override - public boolean equals(Object o) { + public boolean equals(final Object o) { if (this == o) { return true; } if (o == null || getClass() != o.getClass()) { return false; } - ConditionAttributeRef that = (ConditionAttributeRef) o; + final ConditionAttributeRef that = (ConditionAttributeRef) o; return leftInputId == that.leftInputId && leftFieldIndex == that.leftFieldIndex && rightInputId == that.rightInputId diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/keyselector/JoinKeyExtractor.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/keyselector/JoinKeyExtractor.java index 7e6f91d6ff728..4fb96d28ed182 100644 --- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/keyselector/JoinKeyExtractor.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/keyselector/JoinKeyExtractor.java @@ -109,4 +109,7 @@ public interface JoinKeyExtractor extends Serializable { * common join key. */ int[] getCommonJoinKeyIndices(int inputId); + + /** Enables copying of row data. 
*/ + void requiresKeyDeepCopy(); } diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/typeutils/RowDataSerializer.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/typeutils/RowDataSerializer.java index 3106d9e65d5c3..a83ea7a5e5546 100644 --- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/typeutils/RowDataSerializer.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/typeutils/RowDataSerializer.java @@ -188,10 +188,14 @@ public int getArity() { /** Convert {@link RowData} into {@link BinaryRowData}. TODO modify it to code gen. */ @Override public BinaryRowData toBinaryRow(RowData row) { + return toBinaryRow(row, false); + } + + public BinaryRowData toBinaryRow(RowData row, boolean requiresDeepCopy) { if (row instanceof BinaryRowData) { return (BinaryRowData) row; } - if (reuseRow == null) { + if (reuseRow == null || requiresDeepCopy) { reuseRow = new BinaryRowData(types.length); reuseWriter = new BinaryRowWriter(reuseRow); } diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/stream/multijoin/StreamingFourWayMixedOuterJoinOperatorTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/stream/multijoin/StreamingFourWayMixedOuterJoinOperatorTest.java index 9b5faad697e76..cfa5b8ebf9ce6 100644 --- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/stream/multijoin/StreamingFourWayMixedOuterJoinOperatorTest.java +++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/stream/multijoin/StreamingFourWayMixedOuterJoinOperatorTest.java @@ -25,6 +25,7 @@ import org.apache.flink.table.runtime.operators.join.stream.utils.JoinInputSideSpec; import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; import 
org.apache.flink.table.types.logical.BigIntType; +import org.apache.flink.table.types.logical.CharType; import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.table.types.logical.RowType; import org.apache.flink.table.types.logical.VarCharType; @@ -1181,24 +1182,25 @@ protected RowType createInputTypeInfo(int inputIndex) { String userIdFieldName = String.format("user_id_%d", inputIndex); String pkFieldName = String.format("pk_%d", inputIndex); // Generic PK name - if (inputIndex == 0) { // Users: user_id (VARCHAR), pk (VARCHAR), details (BIGINT) + if (inputIndex == 0) { // Users: user_id (CHAR NOT NULL), pk (VARCHAR), details (BIGINT) return RowType.of( new LogicalType[] { - VarCharType.STRING_TYPE, VarCharType.STRING_TYPE, new BigIntType() + new CharType(false, 20), VarCharType.STRING_TYPE, new BigIntType() }, new String[] { userIdFieldName, pkFieldName, String.format("details_%d", inputIndex) }); } else if (inputIndex - == 3) { // Shipments: user_id (VARCHAR), pk (VARCHAR), details (BIGINT) + == 3) { // Shipments: user_id (CHAR NOT NULL), pk (VARCHAR), details (BIGINT) return RowType.of( new LogicalType[] { - VarCharType.STRING_TYPE, VarCharType.STRING_TYPE, new BigIntType() + new CharType(false, 20), VarCharType.STRING_TYPE, new BigIntType() }, new String[] { userIdFieldName, pkFieldName, String.format("details_%d", inputIndex) }); - } else { // Orders (1), Payments (2): user_id (VARCHAR), pk (VARCHAR), details (VARCHAR) + } else { // Orders (1), Payments (2): user_id (CHAR NOT NULL), pk (CHAR NOT NULL), details + // (VARCHAR) return super.createInputTypeInfo(inputIndex); } } diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/stream/multijoin/StreamingMultiJoinOperatorTestBase.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/stream/multijoin/StreamingMultiJoinOperatorTestBase.java index b09c7f13a9554..27861fafb11c2 100644 --- 
a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/stream/multijoin/StreamingMultiJoinOperatorTestBase.java +++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/stream/multijoin/StreamingMultiJoinOperatorTestBase.java @@ -362,6 +362,7 @@ private static class SerializableKeySelector public SerializableKeySelector(JoinKeyExtractor keyExtractor, int inputIndex) { this.keyExtractor = keyExtractor; + this.keyExtractor.requiresKeyDeepCopy(); this.inputIndex = inputIndex; }