diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala index e2a013b9e814c..cae7ccdcf0d5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.planning.ExtractSingleColumnNullAwareAntiJo import org.apache.spark.sql.catalyst.plans.logical.EmptyRelation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, LOGICAL_QUERY_STAGE, TRUE_OR_FALSE_LITERAL} +import org.apache.spark.sql.execution.{ColumnarToRowExec, ProjectExec, SortExec, SparkPlan} import org.apache.spark.sql.execution.aggregate.BaseAggregateExec import org.apache.spark.sql.execution.exchange.{REPARTITION_BY_COL, REPARTITION_BY_NUM, ShuffleExchangeLike} import org.apache.spark.sql.execution.joins.HashedRelationWithAllNullKeys @@ -52,19 +53,37 @@ object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase { // - positive value means an estimated row count which can be over-estimated // - none means the plan has not materialized or the plan can not be estimated private def getEstimatedRowCount(plan: LogicalPlan): Option[BigInt] = plan match { - case LogicalQueryStage(_, stage: QueryStageExec) if stage.isMaterialized => + case LogicalQueryStage(_, physicalPlan) => + getEstimatedRowCount(physicalPlan) + + case _: EmptyRelation => Some(0) + + case _ => None + } + + private def getEstimatedRowCount(plan: SparkPlan): Option[BigInt] = plan match { + case stage: QueryStageExec if stage.isMaterialized => stage.getRuntimeStatistics.rowCount - case LogicalQueryStage(_, agg: BaseAggregateExec) if agg.groupingExpressions.nonEmpty && - agg.child.isInstanceOf[QueryStageExec] => - val stage = agg.child.asInstanceOf[QueryStageExec] - if (stage.isMaterialized) { - stage.getRuntimeStatistics.rowCount - } else { - None + case read: AQEShuffleReadExec => + getEstimatedRowCount(read.child) + + case sort: SortExec => + getEstimatedRowCount(sort.child) + + case project: ProjectExec => + getEstimatedRowCount(project.child) + + case columnarToRow: ColumnarToRowExec => + getEstimatedRowCount(columnarToRow.child) + + case aggregate: BaseAggregateExec if aggregate.groupingExpressions.isEmpty => + getEstimatedRowCount(aggregate.child).map { rowCount => + if (rowCount == 0) BigInt(1) else rowCount } - case _: EmptyRelation => Some(0) + case aggregate: BaseAggregateExec => + getEstimatedRowCount(aggregate.child) case _ => None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 112ee82314c4b..18061257d6911 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -278,6 +278,7 @@ case class AdaptiveSparkPlanExec( var result = createQueryStages(fun, currentPhysicalPlan, firstRun = true) val events = new LinkedBlockingQueue[StageMaterializationEvent]() val errors = new mutable.ArrayBuffer[Throwable]() + val obsoleteCancelledStageIds = new mutable.HashSet[Int] var stagesToReplace = Seq.empty[QueryStageExec] while (!result.allChildStagesMaterialized) { currentPhysicalPlan = result.newPlan @@ -333,7 +334,9 @@ case class AdaptiveSparkPlanExec( stage.resultOption.set(Some(res)) case StageFailure(stage, ex) => stage.error.set(Some(ex)) - errors.append(ex) + if (!obsoleteCancelledStageIds.contains(stage.id)) { + errors.append(ex) + } } // In case of errors, we cancel all running stages and throw exception. @@ -373,6 +376,7 @@ case class AdaptiveSparkPlanExec( currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString("\n") logOnLevel(log"Plan changed:\n${MDC(QUERY_PLAN, plans)}") cleanUpTempTags(newPhysicalPlan) + obsoleteCancelledStageIds ++= cancelObsoleteStages(newPhysicalPlan, stagesToReplace) currentPhysicalPlan = newPhysicalPlan currentLogicalPlan = newLogicalPlan stagesToReplace = Seq.empty[QueryStageExec] @@ -395,6 +399,30 @@ case class AdaptiveSparkPlanExec( .get.asInstanceOf[T] } + private def cancelObsoleteStages( + newPhysicalPlan: SparkPlan, + stagesToReplace: Seq[QueryStageExec]): Seq[Int] = { + val newStages = newPhysicalPlan.collect { + case stage: QueryStageExec => stage + } + val obsoleteStages = stagesToReplace.collect { + case stage: ExchangeQueryStageExec + if !newStages.exists(newStage => + newStage.id == stage.id || newStage.resultOption.eq(stage.resultOption)) => stage + } + obsoleteStages.foreach { stage => + if (!stage.isMaterialized) { + try { + stage.cancel("The query stage is no longer referenced by the current adaptive plan.") + } catch { + case NonFatal(t) => + logError(s"Exception in cancelling obsolete query stage: ${stage.treeString}", t) + } + } + } + obsoleteStages.map(_.id) + } + // Use a lazy val to avoid this being called more than once. @transient private lazy val finalPlanUpdate: Unit = { // Do final plan update after result stage has materialized. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 50322905f29f3..05d46029e683c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.adaptive import java.io.File import java.net.URI - import org.apache.logging.log4j.Level import org.scalatest.PrivateMethodTester @@ -352,6 +351,42 @@ class AdaptiveQueryExecSuite } } + test("empty materialized stage short-circuits AQE through sort wrappers") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.SHUFFLE_PARTITIONS.key -> "10") { + val left = spark.range(0, 1, 1, 1).where("id < 0").select($"id".as("k")) + val right = spark.range(0, 200, 1, 20).as[Long].map { id => + Thread.sleep(200) + id + }.select($"value".as("k")) + val df = left.join(right, Seq("k")) + + checkAnswer(df, Seq.empty) + + val finalPlan = df.queryExecution.executedPlan + .asInstanceOf[AdaptiveSparkPlanExec].executedPlan + assert(collect(finalPlan) { case s: SortMergeJoinExec => s }.isEmpty) + } + } + + test("empty filtered global aggregate stage is not treated as non-empty") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.SHUFFLE_PARTITIONS.key -> "2") { + withTempView("empty_rows") { + spark.range(1).where("id < 0").createOrReplaceTempView("empty_rows") + + val df = sql( + "SELECT * FROM testData LEFT ANTI JOIN " + + "(SELECT count(*) c FROM empty_rows HAVING c < 0) r") + + checkAnswer(df, testData.collect().toSeq) + } + } + } + test("Scalar subquery") { withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",