From 6d54b7a3c8fd6a60862a857a91995539e9c7e1be Mon Sep 17 00:00:00 2001 From: Amruth Ashok Date: Fri, 27 Feb 2026 00:41:25 +0530 Subject: [PATCH 1/2] [SPARK-55661][CORE] Ensure TaskRunner.run() sends StatusUpdate on setup failure Move TaskRunner.run() setup code (classloader isolation, thread naming, serializer creation) inside the existing try/catch/finally block so that exceptions during setup are caught and reported to the driver via StatusUpdate. Previously, setup code ran outside the try block, causing silent failures that leaked GPU/CPU resources on the driver. The fix changes 'val isolatedSession' and 'val ser' to 'var' declarations before the try block (with safe defaults), and adds a setup-failure branch in the catch-all handler that sends StatusUpdate(FAILED) or StatusUpdate(KILLED) when the serializer was never initialized (ser == null). Closes: https://issues.apache.org/jira/browse/SPARK-55661 --- .../org/apache/spark/executor/Executor.scala | 107 +++++++++++------- .../apache/spark/executor/ExecutorSuite.scala | 79 +++++++++++++ 2 files changed, 145 insertions(+), 41 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 36f4fb5ac970e..9ad6f2c26f933 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -51,7 +51,7 @@ import org.apache.spark.metrics.source.JVMCPUSource import org.apache.spark.resource.ResourceInformation import org.apache.spark.rpc.RpcTimeout import org.apache.spark.scheduler._ -import org.apache.spark.serializer.SerializerHelper +import org.apache.spark.serializer.{SerializerHelper, SerializerInstance} import org.apache.spark.shuffle.{FetchFailedException, ShuffleBlockPusher} import org.apache.spark.status.api.v1.ThreadStackTrace import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} @@ -803,36 +803,39 @@ private[spark] class Executor( } override def run(): Unit = { - - // Classloader isolation - val isolatedSession = taskDescription.artifacts.state match { - case Some(jobArtifactState) => - obtainSession(jobArtifactState) - case _ => - // The default session is never in the cache and never evicted, - // so no need to acquire/release. - defaultSessionState - } - - setMDCForTask(taskName, mdcProperties) - threadId = Thread.currentThread.getId - Thread.currentThread.setName(threadName) - val threadMXBean = ManagementFactory.getThreadMXBean - val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId) - val deserializeStartTimeNs = System.nanoTime() - val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) { - threadMXBean.getCurrentThreadCpuTime - } else 0L - Thread.currentThread.setContextClassLoader(isolatedSession.replClassLoader) - val ser = env.closureSerializer.newInstance() - logInfo(log"Running ${MDC(TASK_NAME, taskName)}") - execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) + // SPARK-55661: isolatedSession and ser are declared before the try block so that + // the catch and finally blocks can access them for cleanup and error reporting. + var isolatedSession: IsolatedSessionState = defaultSessionState + var ser: SerializerInstance = null var taskStartTimeNs: Long = 0 var taskStartCpu: Long = 0 - startGCTime = computeTotalGcTime() var taskStarted: Boolean = false try { + // Classloader isolation + isolatedSession = taskDescription.artifacts.state match { + case Some(jobArtifactState) => + obtainSession(jobArtifactState) + case _ => + // The default session is never in the cache and never evicted, + // so no need to acquire/release. + defaultSessionState + } + + setMDCForTask(taskName, mdcProperties) + threadId = Thread.currentThread.getId + Thread.currentThread.setName(threadName) + val threadMXBean = ManagementFactory.getThreadMXBean + val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId) + val deserializeStartTimeNs = System.nanoTime() + val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) { + threadMXBean.getCurrentThreadCpuTime + } else 0L + Thread.currentThread.setContextClassLoader(isolatedSession.replClassLoader) + ser = env.closureSerializer.newInstance() + logInfo(log"Running ${MDC(TASK_NAME, taskName)}") + execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) + startGCTime = computeTotalGcTime() // Must be set before updateDependencies() is called, in case fetching dependencies // requires access to properties contained within (e.g. for access control). Executor.taskDeserializationProps.set(taskDescription.properties) @@ -1083,25 +1086,47 @@ private[spark] class Executor( // the task failure would not be ignored if the shutdown happened because of preemption, // instead of an app issue). if (!ShutdownHookManager.inShutdown()) { - val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs) - val metricPeaks = metricsPoller.getTaskMetricPeaks(taskId).toImmutableArraySeq - - val (taskFailureReason, serializedTaskFailureReason) = { - try { - val ef = new ExceptionFailure(t, accUpdates).withAccums(accums) - .withMetricPeaks(metricPeaks) - (ef, ser.serialize(ef)) - } catch { - case _: NotSerializableException => - // t is not serializable so just send the stacktrace - val ef = new ExceptionFailure(t, accUpdates, false).withAccums(accums) + if (ser != null) { + val (accums, accUpdates) = + collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs) + val metricPeaks = metricsPoller.getTaskMetricPeaks(taskId).toImmutableArraySeq + + val (taskFailureReason, serializedTaskFailureReason) = { + try { + val ef = new ExceptionFailure(t, accUpdates).withAccums(accums) .withMetricPeaks(metricPeaks) (ef, ser.serialize(ef)) + } catch { + case _: NotSerializableException => + // t is not serializable so just send the stacktrace + val ef = new ExceptionFailure(t, accUpdates, false).withAccums(accums) + .withMetricPeaks(metricPeaks) + (ef, ser.serialize(ef)) + } + } + setTaskFinishedAndClearInterruptStatus() + plugins.foreach(_.onTaskFailed(taskFailureReason)) + execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskFailureReason) + } else { + // SPARK-55661: Setup failed before the serializer was created. Send a + // StatusUpdate so the driver releases resources allocated for this task. + try { + val killReason = reasonIfKilled + val failSer = env.closureSerializer.newInstance() + if (killReason.isDefined) { + val reason = TaskKilled(killReason.get) + execBackend.statusUpdate(taskId, TaskState.KILLED, failSer.serialize(reason)) + } else { + val ef = new ExceptionFailure(t, Seq.empty) + execBackend.statusUpdate(taskId, TaskState.FAILED, failSer.serialize(ef)) + } + } catch { + case NonFatal(inner) => + logError( + log"Failed to report task setup failure for " + + log"${MDC(TASK_NAME, taskName)}", inner) } } - setTaskFinishedAndClearInterruptStatus() - plugins.foreach(_.onTaskFailed(taskFailureReason)) - execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskFailureReason) } else { logInfo("Not reporting error to driver during JVM shutdown.") } diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index f63cf4f05d61e..726090d767153 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -722,6 +722,85 @@ class ExecutorSuite extends SparkFunSuite } } + test("SPARK-55661: TaskRunner.run() setup failure should send StatusUpdate " + + "to prevent driver resource leak") { + val conf = new SparkConf + val serializer = new JavaSerializer(conf) + val env = createMockEnv(conf, serializer) + val serializedTask = serializer.newInstance().serialize(new FakeTask(0, 0)) + val taskDescription = createFakeTaskDescription(serializedTask) + + val mockExecutorBackend = mock[ExecutorBackend] + val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer]) + + withExecutor("id", "localhost", env) { executor => + val mockClosureSerializer = mock[JavaSerializer] + when(mockClosureSerializer.newInstance()) + .thenThrow(new RuntimeException("simulated setup failure in TaskRunner")) + .thenReturn(serializer.newInstance()) + when(env.closureSerializer).thenReturn(mockClosureSerializer) + + executor.launchTask(mockExecutorBackend, taskDescription) + + eventually(timeout(5.seconds), interval(10.milliseconds)) { + assert(executor.numRunningTasks === 0) + } + + verify(mockExecutorBackend).statusUpdate( + meq(0L), + meq(TaskState.FAILED), + statusCaptor.capture() + ) + + val failureData = statusCaptor.getValue + val failReason = serializer.newInstance() + .deserialize[ExceptionFailure](failureData) + assert(failReason.exception.isDefined) + assert(failReason.exception.get.getMessage === + "simulated setup failure in TaskRunner") + } + } + + test("SPARK-55661: TaskRunner.run() setup failure on killed task should send " + + "StatusUpdate(KILLED) to prevent driver resource leak") { + val conf = new SparkConf + val serializer = new JavaSerializer(conf) + val env = createMockEnv(conf, serializer) + val serializedTask = serializer.newInstance().serialize(new FakeTask(0, 0)) + val taskDescription = createFakeTaskDescription(serializedTask) + + val mockExecutorBackend = mock[ExecutorBackend] + val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer]) + + withExecutor("id", "localhost", env) { executor => + val mockClosureSerializer = mock[JavaSerializer] + when(mockClosureSerializer.newInstance()) + .thenThrow(new RuntimeException("simulated setup failure in TaskRunner")) + .thenReturn(serializer.newInstance()) + when(env.closureSerializer).thenReturn(mockClosureSerializer) + + executor.killMarks.put(taskDescription.taskId, + (true, "AQE stage cancellation", System.currentTimeMillis())) + + executor.launchTask(mockExecutorBackend, taskDescription) + + eventually(timeout(5.seconds), interval(10.milliseconds)) { + assert(executor.numRunningTasks === 0) + } + + verify(mockExecutorBackend).statusUpdate( + meq(0L), + meq(TaskState.KILLED), + statusCaptor.capture() + ) + + val failureData = statusCaptor.getValue + val failReason = serializer.newInstance() + .deserialize[TaskKilled](failureData) + assert(failReason.reason === "AQE stage cancellation") + } + } + private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = { val mockEnv = mock[SparkEnv] val mockRpcEnv = mock[RpcEnv] From 5a1590c7c6c7eea59f5f0c37165ec3e98bd1b934 Mon Sep 17 00:00:00 2001 From: Amruth Ashok Date: Fri, 27 Feb 2026 01:06:57 +0530 Subject: [PATCH 2/2] [SPARK-55661][CORE] Updated kill reason in setup failure test --- .../test/scala/org/apache/spark/executor/ExecutorSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 726090d767153..941689fff29a1 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -780,7 +780,7 @@ class ExecutorSuite extends SparkFunSuite when(env.closureSerializer).thenReturn(mockClosureSerializer) executor.killMarks.put(taskDescription.taskId, - (true, "AQE stage cancellation", System.currentTimeMillis())) + (true, "stage cancelled", System.currentTimeMillis())) executor.launchTask(mockExecutorBackend, taskDescription) @@ -797,7 +797,7 @@ class ExecutorSuite extends SparkFunSuite val failureData = statusCaptor.getValue val failReason = serializer.newInstance() .deserialize[TaskKilled](failureData) - assert(failReason.reason === "AQE stage cancellation") + assert(failReason.reason === "stage cancelled") } }