Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,35 @@ case class ColumnarSubqueryBroadcastExec(
val relation = child.executeBroadcast[Any]().value
relation match {
case b: BuildSideRelation =>
val index = indices(0) // TODO(): fixme
// Transform columnar broadcast value to Array[InternalRow] by key.
if (canRewriteAsLongType(buildKeys)) {
b.transform(HashJoin.extractKeyExprAt(buildKeys, index)).distinct
// Build key expressions for all indices (multi-key DPP support).
val keyExprs = if (canRewriteAsLongType(buildKeys)) {
indices.map(idx => HashJoin.extractKeyExprAt(buildKeys, idx))
} else {
b.transform(
BoundReference(index, buildKeys(index).dataType, buildKeys(index).nullable))
.distinct
indices.map {
idx =>
BoundReference(
idx,
buildKeys(idx).dataType,
buildKeys(idx).nullable): Expression
}
}
if (keyExprs.size == 1) {
b.transform(keyExprs.head).distinct
} else {
// For multi-key DPP, pack all keys into a struct via transform(),
// then flatten back to individual columns to match the expected
// output schema (InSubqueryExec expects N-column rows).
val structExpr = CreateStruct(keyExprs)
val structRows = b.transform(structExpr)
val structType = structExpr.dataType
val flattenExprs = keyExprs.indices.map {
i =>
GetStructField(
BoundReference(0, structType, nullable = true),
i): Expression
}
val flatProj = UnsafeProjection.create(flattenExprs)
structRows.map(r => flatProj(r).copy()).distinct
}
case h: HashedRelation =>
val (iter, exprs) = if (h.isInstanceOf[LongHashedRelation]) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,38 @@ class GlutenDynamicPartitionPruningV1SuiteAEOn
}
}
}

testGluten("multi-key DPP with columnar broadcast") {
withSQLConf(
SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true",
SQLConf.ANSI_ENABLED.key -> "false"
) {
withTable("fact_mk", "dim_mk") {
sql("""CREATE TABLE fact_mk (id BIGINT, value INT, a STRING, b STRING)
|USING parquet PARTITIONED BY (a, b)""".stripMargin)
sql(
"INSERT INTO fact_mk VALUES " +
(0 until 10).map(i => s"($i, 1, '1', '1')").mkString(", "))
sql(
"INSERT INTO fact_mk VALUES " +
(0 until 10).map(i => s"($i, 2, '2', '2')").mkString(", "))
sql(
"INSERT INTO fact_mk VALUES " +
(0 until 10).map(i => s"($i, 3, '3', '3')").mkString(", "))

sql("CREATE TABLE dim_mk (x STRING, y STRING, z INT) USING parquet")
sql("INSERT INTO dim_mk VALUES ('1', '1', 10), ('2', '2', 20)")

val df = sql("""SELECT f.id, f.a, f.b FROM fact_mk f
|JOIN dim_mk d ON f.a = d.x AND f.b = d.y
|WHERE d.z < 15""".stripMargin)

val result = df.collect()
assert(result.length == 10)
checkAnswer(df, result)
}
}
}
}

abstract class GlutenDynamicPartitionPruningV2Suite extends GlutenDynamicPartitionPruningSuiteBase {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,40 @@ class GlutenDynamicPartitionPruningV1SuiteAEOn
}
}
}

testGluten("multi-key DPP with columnar broadcast") {
withSQLConf(
SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true",
SQLConf.ANSI_ENABLED.key -> "false"
) {
withTable("fact_mk", "dim_mk") {
sql("""CREATE TABLE fact_mk (id BIGINT, value INT,
| a STRING, b STRING) USING parquet
|PARTITIONED BY (a, b)""".stripMargin)
sql(
"INSERT INTO fact_mk VALUES " +
(0 until 10).map(i => s"($i, 1, '1', '1')").mkString(", "))
sql(
"INSERT INTO fact_mk VALUES " +
(0 until 10).map(i => s"($i, 2, '2', '2')").mkString(", "))
sql(
"INSERT INTO fact_mk VALUES " +
(0 until 10).map(i => s"($i, 3, '3', '3')").mkString(", "))

sql("""CREATE TABLE dim_mk (x STRING, y STRING, z INT)
|USING parquet""".stripMargin)
sql("INSERT INTO dim_mk VALUES ('1','1',10), ('2','2',20)")

val df = sql("""SELECT f.id, f.a, f.b FROM fact_mk f
|JOIN dim_mk d ON f.a = d.x AND f.b = d.y
|WHERE d.z < 15""".stripMargin)

val result = df.collect()
assert(result.length == 10)
checkAnswer(df, result)
}
}
}
}

abstract class GlutenDynamicPartitionPruningV2Suite extends GlutenDynamicPartitionPruningSuiteBase {
Expand Down
Loading