Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.planning.ExtractSingleColumnNullAwareAntiJo
import org.apache.spark.sql.catalyst.plans.logical.EmptyRelation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, LOGICAL_QUERY_STAGE, TRUE_OR_FALSE_LITERAL}
import org.apache.spark.sql.execution.{ColumnarToRowExec, ProjectExec, SortExec, SparkPlan}
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.exchange.{REPARTITION_BY_COL, REPARTITION_BY_NUM, ShuffleExchangeLike}
import org.apache.spark.sql.execution.joins.HashedRelationWithAllNullKeys
Expand Down Expand Up @@ -52,19 +53,37 @@ object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase {
// - positive value means an estimated row count which can be over-estimated
// - none means the plan has not materialized or the plan can not be estimated
// Returns the estimated row count of a logical plan.
// NOTE(review): the pasted diff interleaved the removed case
// (`case LogicalQueryStage(_, stage: QueryStageExec) if ... =>` with no body)
// with the added one; this is the reconstructed post-change method.
private def getEstimatedRowCount(plan: LogicalPlan): Option[BigInt] = plan match {
  // A LogicalQueryStage wraps a physical plan: delegate to the SparkPlan
  // overload, which looks through row-preserving wrappers down to the
  // materialized stage's runtime statistics.
  case LogicalQueryStage(_, physicalPlan) =>
    getEstimatedRowCount(physicalPlan)

  // A relation already propagated as empty is known to produce zero rows.
  case _: EmptyRelation => Some(0)

  // Anything else: no runtime estimate available.
  case _ => None
}

// Returns the estimated row count of a physical plan, per the contract above:
// - 0 means the plan must produce 0 row
// - a positive value is an estimate that may over-count
// - None means the plan is not materialized or cannot be estimated
// NOTE(review): the pasted diff spliced the removed pre-change lines (the old
// `LogicalQueryStage(_, agg: BaseAggregateExec)` case and a stray
// `case _: EmptyRelation`, which is a LogicalPlan and cannot appear in a
// SparkPlan match) into this body; this is the reconstructed post-change method.
private def getEstimatedRowCount(plan: SparkPlan): Option[BigInt] = plan match {
  // A materialized query stage reports its actual runtime row count.
  case stage: QueryStageExec if stage.isMaterialized =>
    stage.getRuntimeStatistics.rowCount

  // The following wrappers never change the number of output rows,
  // so look through them to the underlying stage.
  case read: AQEShuffleReadExec =>
    getEstimatedRowCount(read.child)

  case sort: SortExec =>
    getEstimatedRowCount(sort.child)

  case project: ProjectExec =>
    getEstimatedRowCount(project.child)

  case columnarToRow: ColumnarToRowExec =>
    getEstimatedRowCount(columnarToRow.child)

  // A global aggregate (no grouping keys) always emits exactly one row,
  // even over empty input, so a child count of 0 still means 1 output row.
  case aggregate: BaseAggregateExec if aggregate.groupingExpressions.isEmpty =>
    getEstimatedRowCount(aggregate.child).map { rowCount =>
      if (rowCount == 0) BigInt(1) else rowCount
    }

  // A grouped aggregate emits at most as many rows as its input, so the
  // child's count is a valid over-estimate (and 0 stays 0).
  case aggregate: BaseAggregateExec =>
    getEstimatedRowCount(aggregate.child)

  case _ => None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ case class AdaptiveSparkPlanExec(
var result = createQueryStages(fun, currentPhysicalPlan, firstRun = true)
val events = new LinkedBlockingQueue[StageMaterializationEvent]()
val errors = new mutable.ArrayBuffer[Throwable]()
val obsoleteCancelledStageIds = new mutable.HashSet[Int]
var stagesToReplace = Seq.empty[QueryStageExec]
while (!result.allChildStagesMaterialized) {
currentPhysicalPlan = result.newPlan
Expand Down Expand Up @@ -333,7 +334,9 @@ case class AdaptiveSparkPlanExec(
stage.resultOption.set(Some(res))
case StageFailure(stage, ex) =>
stage.error.set(Some(ex))
errors.append(ex)
if (!obsoleteCancelledStageIds.contains(stage.id)) {
errors.append(ex)
}
}

// In case of errors, we cancel all running stages and throw exception.
Expand Down Expand Up @@ -373,6 +376,7 @@ case class AdaptiveSparkPlanExec(
currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString("\n")
logOnLevel(log"Plan changed:\n${MDC(QUERY_PLAN, plans)}")
cleanUpTempTags(newPhysicalPlan)
obsoleteCancelledStageIds ++= cancelObsoleteStages(newPhysicalPlan, stagesToReplace)
currentPhysicalPlan = newPhysicalPlan
currentLogicalPlan = newLogicalPlan
stagesToReplace = Seq.empty[QueryStageExec]
Expand All @@ -395,6 +399,30 @@ case class AdaptiveSparkPlanExec(
.get.asInstanceOf[T]
}

/**
 * Cancels exchange query stages that were scheduled for replacement but are no
 * longer referenced by the new adaptive plan (neither by stage id nor by a shared
 * result holder), and returns the ids of all such obsolete stages — including
 * materialized ones, which are reported but never cancelled.
 */
private def cancelObsoleteStages(
    newPhysicalPlan: SparkPlan,
    stagesToReplace: Seq[QueryStageExec]): Seq[Int] = {
  // Stages the new plan still contains; an old stage is retained when the new
  // plan holds one with the same id or the very same resultOption reference.
  val retainedStages = newPhysicalPlan.collect { case s: QueryStageExec => s }
  def stillReferenced(old: QueryStageExec): Boolean =
    retainedStages.exists(s => s.id == old.id || s.resultOption.eq(old.resultOption))

  val obsolete = stagesToReplace.collect {
    case exchange: ExchangeQueryStageExec if !stillReferenced(exchange) => exchange
  }

  // Best-effort cancellation of still-running obsolete stages; a failed cancel
  // must not abort query execution, so non-fatal errors are only logged.
  for (stage <- obsolete if !stage.isMaterialized) {
    try {
      stage.cancel("The query stage is no longer referenced by the current adaptive plan.")
    } catch {
      case NonFatal(t) =>
        logError(s"Exception in cancelling obsolete query stage: ${stage.treeString}", t)
    }
  }
  obsolete.map(_.id)
}

// Use a lazy val to avoid this being called more than once.
@transient private lazy val finalPlanUpdate: Unit = {
// Do final plan update after result stage has materialized.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.adaptive

import java.io.File
import java.net.URI

import org.apache.logging.log4j.Level
import org.scalatest.PrivateMethodTester

Expand Down Expand Up @@ -352,6 +351,42 @@ class AdaptiveQueryExecSuite
}
}

test("empty materialized stage short-circuits AQE through sort wrappers") {
  withSQLConf(
      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
      SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
    // Left side is provably empty after the filter (single partition, id in [0, 1)).
    val emptySide = spark.range(0, 1, 1, 1).where("id < 0").select($"id".as("k"))
    // Right side is deliberately slow so the optimization, not timing, decides.
    val slowSide = spark.range(0, 200, 1, 20).as[Long].map { id =>
      Thread.sleep(200)
      id
    }.select($"value".as("k"))
    val joined = emptySide.join(slowSide, Seq("k"))

    checkAnswer(joined, Seq.empty)

    // Once the empty side materializes, AQE should eliminate the join entirely.
    val adaptivePlan = joined.queryExecution.executedPlan
      .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
    assert(collect(adaptivePlan) { case smj: SortMergeJoinExec => smj }.isEmpty)
  }
}

test("empty filtered global aggregate stage is not treated as non-empty") {
  withSQLConf(
      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
      SQLConf.SHUFFLE_PARTITIONS.key -> "2") {
    withTempView("empty_rows") {
      // A relation guaranteed to produce no rows.
      spark.range(1).where("id < 0").createOrReplaceTempView("empty_rows")

      // count(*) over empty input yields one row (value 0); the HAVING clause
      // filters it out, so the anti join must keep every left-side row.
      val query = "SELECT * FROM testData LEFT ANTI JOIN " +
        "(SELECT count(*) c FROM empty_rows HAVING c < 0) r"
      val result = sql(query)

      checkAnswer(result, testData.collect().toSeq)
    }
  }
}

test("Scalar subquery") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
Expand Down