From 9614eb42af2b89926a09c576094b6046f3c270ea Mon Sep 17 00:00:00 2001 From: fredliu-data Date: Tue, 5 May 2026 10:30:08 -0700 Subject: [PATCH 1/2] [SPARK-56842][SQL] Short-circuit AQE when materialized stages are empty --- .../adaptive/AQEPropagateEmptyRelation.scala | 35 ++++++++--- .../adaptive/AdaptiveSparkPlanExec.scala | 31 +++++++++- .../adaptive/AdaptiveQueryExecSuite.scala | 62 ++++++++++++++++++- 3 files changed, 118 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala index e2a013b9e814c..2130d72e90248 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.planning.ExtractSingleColumnNullAwareAntiJo import org.apache.spark.sql.catalyst.plans.logical.EmptyRelation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, LOGICAL_QUERY_STAGE, TRUE_OR_FALSE_LITERAL} +import org.apache.spark.sql.execution.{ColumnarToRowExec, ProjectExec, SortExec, SparkPlan} import org.apache.spark.sql.execution.aggregate.BaseAggregateExec import org.apache.spark.sql.execution.exchange.{REPARTITION_BY_COL, REPARTITION_BY_NUM, ShuffleExchangeLike} import org.apache.spark.sql.execution.joins.HashedRelationWithAllNullKeys @@ -52,20 +53,38 @@ object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase { // - positive value means an estimated row count which can be over-estimated // - none means the plan has not materialized or the plan can not be estimated private def getEstimatedRowCount(plan: LogicalPlan): Option[BigInt] = plan match { - case LogicalQueryStage(_, stage: QueryStageExec) if stage.isMaterialized => + case LogicalQueryStage(_, physicalPlan) => + getEstimatedRowCount(physicalPlan) + + case _ => None + } + + private def getEstimatedRowCount(plan: SparkPlan): Option[BigInt] = plan match { + case stage: QueryStageExec if stage.isMaterialized => stage.getRuntimeStatistics.rowCount - case LogicalQueryStage(_, agg: BaseAggregateExec) if agg.groupingExpressions.nonEmpty && - agg.child.isInstanceOf[QueryStageExec] => - val stage = agg.child.asInstanceOf[QueryStageExec] - if (stage.isMaterialized) { - stage.getRuntimeStatistics.rowCount - } else { - None + case read: AQEShuffleReadExec => + getEstimatedRowCount(read.child) + + case sort: SortExec => + getEstimatedRowCount(sort.child) + + case project: ProjectExec => + getEstimatedRowCount(project.child) + + case columnarToRow: ColumnarToRowExec => + getEstimatedRowCount(columnarToRow.child) + + case aggregate: BaseAggregateExec if aggregate.groupingExpressions.isEmpty => + getEstimatedRowCount(aggregate.child).map { rowCount => + if (rowCount == 0) BigInt(1) else rowCount } case _: EmptyRelation => Some(0) + case aggregate: BaseAggregateExec => + getEstimatedRowCount(aggregate.child) + case _ => None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 112ee82314c4b..879aa49e1a243 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -278,6 +278,7 @@ case class AdaptiveSparkPlanExec( var result = createQueryStages(fun, currentPhysicalPlan, firstRun = true) val events = new LinkedBlockingQueue[StageMaterializationEvent]() val errors = new mutable.ArrayBuffer[Throwable]() + val obsoleteCancelledStageIds = new mutable.HashSet[Int] var stagesToReplace = Seq.empty[QueryStageExec] while (!result.allChildStagesMaterialized) { currentPhysicalPlan = result.newPlan @@ -333,7 +334,9 @@ case class AdaptiveSparkPlanExec( stage.resultOption.set(Some(res)) case StageFailure(stage, ex) => stage.error.set(Some(ex)) - errors.append(ex) + if (!obsoleteCancelledStageIds.contains(stage.id)) { + errors.append(ex) + } } // In case of errors, we cancel all running stages and throw exception. @@ -373,6 +376,7 @@ case class AdaptiveSparkPlanExec( currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString("\n") logOnLevel(log"Plan changed:\n${MDC(QUERY_PLAN, plans)}") cleanUpTempTags(newPhysicalPlan) + obsoleteCancelledStageIds ++= cancelObsoleteStages(newPhysicalPlan, stagesToReplace) currentPhysicalPlan = newPhysicalPlan currentLogicalPlan = newLogicalPlan stagesToReplace = Seq.empty[QueryStageExec] @@ -395,6 +399,31 @@ case class AdaptiveSparkPlanExec( .get.asInstanceOf[T] } + private def cancelObsoleteStages( + newPhysicalPlan: SparkPlan, + stagesToReplace: Seq[QueryStageExec]): Seq[Int] = { + val newStages = newPhysicalPlan.collect { + case stage: QueryStageExec => stage + } + val obsoleteStages = stagesToReplace.collect { + case stage: ExchangeQueryStageExec + if !newStages.exists(newStage => + newStage.id == stage.id || newStage.resultOption.eq(stage.resultOption)) => stage + } + obsoleteStages.foreach { stage => + if (!stage.isMaterialized) { + removeStageFromCache(stage) + try { + stage.cancel() + } catch { + case NonFatal(t) => + logError(s"Exception in cancelling obsolete query stage: ${stage.treeString}", t) + } + } + } + obsoleteStages.map(_.id) + } + // Use a lazy val to avoid this being called more than once. @transient private lazy val finalPlanUpdate: Unit = { // Do final plan update after result stage has materialized. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 50322905f29f3..56914d500535a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -19,13 +19,16 @@ package org.apache.spark.sql.execution.adaptive import java.io.File import java.net.URI +import java.util.Locale + +import scala.collection.mutable import org.apache.logging.log4j.Level import org.scalatest.PrivateMethodTester import org.apache.spark.SparkException import org.apache.spark.rdd.RDD -import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} +import org.apache.spark.scheduler.{JobFailed, SparkListener, SparkListenerEvent, SparkListenerJobEnd, SparkListenerJobStart} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow @@ -352,6 +355,63 @@ class AdaptiveQueryExecSuite } } + test("empty materialized stage short-circuits AQE through sort wrappers") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.SHUFFLE_PARTITIONS.key -> "10") { + val jobEndEvents = new mutable.ArrayBuffer[SparkListenerJobEnd] + val listener = new SparkListener { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = jobEndEvents.synchronized { + jobEndEvents += jobEnd + } + } + spark.sparkContext.addSparkListener(listener) + try { + val left = spark.range(0, 1, 1, 1).where("id < 0").select($"id".as("k")) + val right = spark.range(0, 200, 1, 20).as[Long].map { id => + Thread.sleep(200) + id + }.select($"value".as("k")) + val df = left.join(right, Seq("k")) + + checkAnswer(df, Seq.empty) + + val finalPlan = df.queryExecution.executedPlan + .asInstanceOf[AdaptiveSparkPlanExec].executedPlan + assert(collect(finalPlan) { case s: SortMergeJoinExec => s }.isEmpty) + assert(collect(finalPlan) { case s: LocalTableScanExec => s }.nonEmpty) + + spark.sparkContext.listenerBus.waitUntilEmpty() + assert(jobEndEvents.synchronized { + jobEndEvents.exists { + case SparkListenerJobEnd(_, _, JobFailed(e)) => + e.getMessage.toLowerCase(Locale.ROOT).contains("cancel") + case _ => false + } + }) + } finally { + spark.sparkContext.removeSparkListener(listener) + } + } + } + + test("empty filtered global aggregate stage is not treated as non-empty") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.SHUFFLE_PARTITIONS.key -> "2") { + withTempView("empty_rows") { + spark.range(1).where("id < 0").createOrReplaceTempView("empty_rows") + + val df = sql( + "SELECT * FROM testData LEFT ANTI JOIN " + + "(SELECT count(*) c FROM empty_rows HAVING c < 0) r") + + checkAnswer(df, testData.collect().toSeq) + } + } + } + test("Scalar subquery") { withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", From 13dfe34b64a51d817a16e5dc3b00a03507b3751c Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Tue, 12 May 2026 18:00:05 -0700 Subject: [PATCH 2/2] [SPARK-56842][SQL] Fix OSS carryover for AQE empty-stage handling --- .../adaptive/AQEPropagateEmptyRelation.scala | 4 +- .../adaptive/AdaptiveSparkPlanExec.scala | 3 +- .../adaptive/AdaptiveQueryExecSuite.scala | 51 +++++-------------- 3 files changed, 16 insertions(+), 42 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala index 2130d72e90248..cae7ccdcf0d5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala @@ -56,6 +56,8 @@ object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase { case LogicalQueryStage(_, physicalPlan) => getEstimatedRowCount(physicalPlan) + case _: EmptyRelation => Some(0) + case _ => None } @@ -80,8 +82,6 @@ object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase { if (rowCount == 0) BigInt(1) else rowCount } - case _: EmptyRelation => Some(0) - case aggregate: BaseAggregateExec => getEstimatedRowCount(aggregate.child) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 879aa49e1a243..18061257d6911 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -412,9 +412,8 @@ case class AdaptiveSparkPlanExec( } obsoleteStages.foreach { stage => if (!stage.isMaterialized) { - removeStageFromCache(stage) try { - stage.cancel() + stage.cancel("The query stage is no longer referenced by the current adaptive plan.") } catch { case NonFatal(t) => logError(s"Exception in cancelling obsolete query stage: ${stage.treeString}", t) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 56914d500535a..05d46029e683c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -19,16 +19,12 @@ package org.apache.spark.sql.execution.adaptive import java.io.File import java.net.URI -import java.util.Locale - -import scala.collection.mutable - import org.apache.logging.log4j.Level import org.scalatest.PrivateMethodTester import org.apache.spark.SparkException import org.apache.spark.rdd.RDD -import org.apache.spark.scheduler.{JobFailed, SparkListener, SparkListenerEvent, SparkListenerJobEnd, SparkListenerJobStart} +import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow @@ -360,39 +356,18 @@ class AdaptiveQueryExecSuite SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", SQLConf.SHUFFLE_PARTITIONS.key -> "10") { - val jobEndEvents = new mutable.ArrayBuffer[SparkListenerJobEnd] - val listener = new SparkListener { - override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = jobEndEvents.synchronized { - jobEndEvents += jobEnd - } - } - spark.sparkContext.addSparkListener(listener) - try { - val left = spark.range(0, 1, 1, 1).where("id < 0").select($"id".as("k")) - val right = spark.range(0, 200, 1, 20).as[Long].map { id => - Thread.sleep(200) - id - }.select($"value".as("k")) - val df = left.join(right, Seq("k")) - - checkAnswer(df, Seq.empty) - - val finalPlan = df.queryExecution.executedPlan - .asInstanceOf[AdaptiveSparkPlanExec].executedPlan - assert(collect(finalPlan) { case s: SortMergeJoinExec => s }.isEmpty) - assert(collect(finalPlan) { case s: LocalTableScanExec => s }.nonEmpty) - - spark.sparkContext.listenerBus.waitUntilEmpty() - assert(jobEndEvents.synchronized { - jobEndEvents.exists { - case SparkListenerJobEnd(_, _, JobFailed(e)) => - e.getMessage.toLowerCase(Locale.ROOT).contains("cancel") - case _ => false - } - }) - } finally { - spark.sparkContext.removeSparkListener(listener) - } + val left = spark.range(0, 1, 1, 1).where("id < 0").select($"id".as("k")) + val right = spark.range(0, 200, 1, 20).as[Long].map { id => + Thread.sleep(200) + id + }.select($"value".as("k")) + val df = left.join(right, Seq("k")) + + checkAnswer(df, Seq.empty) + + val finalPlan = df.queryExecution.executedPlan + .asInstanceOf[AdaptiveSparkPlanExec].executedPlan + assert(collect(finalPlan) { case s: SortMergeJoinExec => s }.isEmpty) } }