Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 66 additions & 41 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ import org.apache.spark.metrics.source.JVMCPUSource
import org.apache.spark.resource.ResourceInformation
import org.apache.spark.rpc.RpcTimeout
import org.apache.spark.scheduler._
import org.apache.spark.serializer.SerializerHelper
import org.apache.spark.serializer.{SerializerHelper, SerializerInstance}
import org.apache.spark.shuffle.{FetchFailedException, ShuffleBlockPusher}
import org.apache.spark.status.api.v1.ThreadStackTrace
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
Expand Down Expand Up @@ -803,36 +803,39 @@ private[spark] class Executor(
}

override def run(): Unit = {

// Classloader isolation
val isolatedSession = taskDescription.artifacts.state match {
case Some(jobArtifactState) =>
obtainSession(jobArtifactState)
case _ =>
// The default session is never in the cache and never evicted,
// so no need to acquire/release.
defaultSessionState
}

setMDCForTask(taskName, mdcProperties)
threadId = Thread.currentThread.getId
Thread.currentThread.setName(threadName)
val threadMXBean = ManagementFactory.getThreadMXBean
val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
val deserializeStartTimeNs = System.nanoTime()
val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
} else 0L
Thread.currentThread.setContextClassLoader(isolatedSession.replClassLoader)
val ser = env.closureSerializer.newInstance()
logInfo(log"Running ${MDC(TASK_NAME, taskName)}")
execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
// SPARK-55661: isolatedSession and ser are declared before the try block so that
// the catch and finally blocks can access them for cleanup and error reporting.
var isolatedSession: IsolatedSessionState = defaultSessionState
var ser: SerializerInstance = null
var taskStartTimeNs: Long = 0
var taskStartCpu: Long = 0
startGCTime = computeTotalGcTime()
var taskStarted: Boolean = false

try {
// Classloader isolation
isolatedSession = taskDescription.artifacts.state match {
case Some(jobArtifactState) =>
obtainSession(jobArtifactState)
case _ =>
// The default session is never in the cache and never evicted,
// so no need to acquire/release.
defaultSessionState
}

setMDCForTask(taskName, mdcProperties)
threadId = Thread.currentThread.getId
Thread.currentThread.setName(threadName)
val threadMXBean = ManagementFactory.getThreadMXBean
val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
val deserializeStartTimeNs = System.nanoTime()
val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
} else 0L
Thread.currentThread.setContextClassLoader(isolatedSession.replClassLoader)
ser = env.closureSerializer.newInstance()
logInfo(log"Running ${MDC(TASK_NAME, taskName)}")
execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
startGCTime = computeTotalGcTime()
// Must be set before updateDependencies() is called, in case fetching dependencies
// requires access to properties contained within (e.g. for access control).
Executor.taskDeserializationProps.set(taskDescription.properties)
Expand Down Expand Up @@ -1083,25 +1086,47 @@ private[spark] class Executor(
// the task failure would not be ignored if the shutdown happened because of preemption,
// instead of an app issue).
if (!ShutdownHookManager.inShutdown()) {
val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs)
val metricPeaks = metricsPoller.getTaskMetricPeaks(taskId).toImmutableArraySeq

val (taskFailureReason, serializedTaskFailureReason) = {
try {
val ef = new ExceptionFailure(t, accUpdates).withAccums(accums)
.withMetricPeaks(metricPeaks)
(ef, ser.serialize(ef))
} catch {
case _: NotSerializableException =>
// t is not serializable so just send the stacktrace
val ef = new ExceptionFailure(t, accUpdates, false).withAccums(accums)
if (ser != null) {
val (accums, accUpdates) =
collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs)
val metricPeaks = metricsPoller.getTaskMetricPeaks(taskId).toImmutableArraySeq

val (taskFailureReason, serializedTaskFailureReason) = {
try {
val ef = new ExceptionFailure(t, accUpdates).withAccums(accums)
.withMetricPeaks(metricPeaks)
(ef, ser.serialize(ef))
} catch {
case _: NotSerializableException =>
// t is not serializable so just send the stacktrace
val ef = new ExceptionFailure(t, accUpdates, false).withAccums(accums)
.withMetricPeaks(metricPeaks)
(ef, ser.serialize(ef))
}
}
setTaskFinishedAndClearInterruptStatus()
plugins.foreach(_.onTaskFailed(taskFailureReason))
execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskFailureReason)
} else {
// SPARK-55661: Setup failed before the serializer was created. Send a
// StatusUpdate so the driver releases resources allocated for this task.
try {
val killReason = reasonIfKilled
val failSer = env.closureSerializer.newInstance()
if (killReason.isDefined) {
val reason = TaskKilled(killReason.get)
execBackend.statusUpdate(taskId, TaskState.KILLED, failSer.serialize(reason))
} else {
val ef = new ExceptionFailure(t, Seq.empty)
execBackend.statusUpdate(taskId, TaskState.FAILED, failSer.serialize(ef))
}
} catch {
case NonFatal(inner) =>
logError(
log"Failed to report task setup failure for " +
log"${MDC(TASK_NAME, taskName)}", inner)
}
}
setTaskFinishedAndClearInterruptStatus()
plugins.foreach(_.onTaskFailed(taskFailureReason))
execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskFailureReason)
} else {
logInfo("Not reporting error to driver during JVM shutdown.")
}
Expand Down
79 changes: 79 additions & 0 deletions core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,85 @@ class ExecutorSuite extends SparkFunSuite
}
}

test("SPARK-55661: TaskRunner.run() setup failure should send StatusUpdate " +
  "to prevent driver resource leak") {
  // Regression test for SPARK-55661: if TaskRunner.run() throws during setup
  // (here, while creating the closure serializer), the executor must still
  // report a FAILED StatusUpdate to the backend so the driver can release
  // the resources it allocated for this task.
  val conf = new SparkConf
  val serializer = new JavaSerializer(conf)
  val env = createMockEnv(conf, serializer)
  val serializedTask = serializer.newInstance().serialize(new FakeTask(0, 0))
  val taskDescription = createFakeTaskDescription(serializedTask)

  val mockExecutorBackend = mock[ExecutorBackend]
  // Captures the serialized failure reason passed to statusUpdate.
  val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])

  withExecutor("id", "localhost", env) { executor =>
    // First newInstance() call (inside TaskRunner.run()'s try block) throws,
    // simulating a setup failure before `ser` is assigned; the second call
    // (the fallback path that serializes the failure reason) succeeds.
    val mockClosureSerializer = mock[JavaSerializer]
    when(mockClosureSerializer.newInstance())
      .thenThrow(new RuntimeException("simulated setup failure in TaskRunner"))
      .thenReturn(serializer.newInstance())
    when(env.closureSerializer).thenReturn(mockClosureSerializer)

    executor.launchTask(mockExecutorBackend, taskDescription)

    // Wait until the TaskRunner has finished and deregistered the task.
    eventually(timeout(5.seconds), interval(10.milliseconds)) {
      assert(executor.numRunningTasks === 0)
    }

    // The backend must have been told the task FAILED despite setup dying
    // before the normal serializer existed.
    verify(mockExecutorBackend).statusUpdate(
      meq(0L),
      meq(TaskState.FAILED),
      statusCaptor.capture()
    )

    // The serialized payload must round-trip to an ExceptionFailure carrying
    // the original setup exception.
    val failureData = statusCaptor.getValue
    val failReason = serializer.newInstance()
      .deserialize[ExceptionFailure](failureData)
    assert(failReason.exception.isDefined)
    assert(failReason.exception.get.getMessage ===
      "simulated setup failure in TaskRunner")
  }
}

test("SPARK-55661: TaskRunner.run() setup failure on killed task should send " +
  "StatusUpdate(KILLED) to prevent driver resource leak") {
  // Companion to the test above: when the task was already marked killed
  // before setup failed, the executor must report KILLED (with the kill
  // reason) rather than FAILED, so the driver both releases resources and
  // records the correct terminal state.
  val conf = new SparkConf
  val serializer = new JavaSerializer(conf)
  val env = createMockEnv(conf, serializer)
  val serializedTask = serializer.newInstance().serialize(new FakeTask(0, 0))
  val taskDescription = createFakeTaskDescription(serializedTask)

  val mockExecutorBackend = mock[ExecutorBackend]
  // Captures the serialized TaskKilled reason passed to statusUpdate.
  val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])

  withExecutor("id", "localhost", env) { executor =>
    // As in the FAILED-path test: the first newInstance() throws to simulate
    // a setup failure before `ser` exists; the second succeeds for the
    // fallback serialization of the status payload.
    val mockClosureSerializer = mock[JavaSerializer]
    when(mockClosureSerializer.newInstance())
      .thenThrow(new RuntimeException("simulated setup failure in TaskRunner"))
      .thenReturn(serializer.newInstance())
    when(env.closureSerializer).thenReturn(mockClosureSerializer)

    // Pre-mark the task as killed so reasonIfKilled is defined when the
    // setup-failure fallback runs.
    executor.killMarks.put(taskDescription.taskId,
      (true, "stage cancelled", System.currentTimeMillis()))

    executor.launchTask(mockExecutorBackend, taskDescription)

    // Wait until the TaskRunner has finished and deregistered the task.
    eventually(timeout(5.seconds), interval(10.milliseconds)) {
      assert(executor.numRunningTasks === 0)
    }

    // Killed-before-setup-failure must surface as KILLED, not FAILED.
    verify(mockExecutorBackend).statusUpdate(
      meq(0L),
      meq(TaskState.KILLED),
      statusCaptor.capture()
    )

    // The payload must deserialize to TaskKilled carrying the kill reason
    // that was placed in killMarks.
    val failureData = statusCaptor.getValue
    val failReason = serializer.newInstance()
      .deserialize[TaskKilled](failureData)
    assert(failReason.reason === "stage cancelled")
  }
}

private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = {
val mockEnv = mock[SparkEnv]
val mockRpcEnv = mock[RpcEnv]
Expand Down