aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/org
diff options
context:
space:
mode:
authorJosh Rosen <joshrosen@databricks.com>2016-12-19 18:43:59 -0800
committerYin Huai <yhuai@databricks.com>2016-12-19 18:43:59 -0800
commitfa829ce21fb84028d90b739a49c4ece70a17ccfd (patch)
tree1ac5a7e18d76a3dfb94209c361fcb48f1b13eba0 /core/src/main/scala/org
parent5857b9ac2d9808d9b89a5b29620b5052e2beebf5 (diff)
downloadspark-fa829ce21fb84028d90b739a49c4ece70a17ccfd.tar.gz
spark-fa829ce21fb84028d90b739a49c4ece70a17ccfd.tar.bz2
spark-fa829ce21fb84028d90b739a49c4ece70a17ccfd.zip
[SPARK-18761][CORE] Introduce "task reaper" to oversee task killing in executors
## What changes were proposed in this pull request? Spark's current task cancellation / task killing mechanism is "best effort" because some tasks may not be interruptible or may not respond to their "killed" flags being set. If a significant fraction of a cluster's task slots are occupied by tasks that have been marked as killed but remain running then this can lead to a situation where new jobs and tasks are starved of resources that are being used by these zombie tasks. This patch aims to address this problem by adding a "task reaper" mechanism to executors. At a high-level, task killing now launches a new thread which attempts to kill the task and then watches the task and periodically checks whether it has been killed. The TaskReaper will periodically re-attempt to call `TaskRunner.kill()` and will log warnings if the task keeps running. I modified TaskRunner to rename its thread at the start of the task, allowing TaskReaper to take a thread dump and filter it in order to log stacktraces from the exact task thread that we are waiting to finish. If the task has not stopped after a configurable timeout then the TaskReaper will throw an exception to trigger executor JVM death, thereby forcibly freeing any resources consumed by the zombie tasks. This feature is flagged off by default and is controlled by four new configurations under the `spark.task.reaper.*` namespace. See the updated `configuration.md` doc for details. ## How was this patch tested? Tested via a new test case in `JobCancellationSuite`, plus manual testing. Author: Josh Rosen <joshrosen@databricks.com> Closes #16189 from JoshRosen/cancellation.
Diffstat (limited to 'core/src/main/scala/org')
-rw-r--r--core/src/main/scala/org/apache/spark/executor/Executor.scala169
-rw-r--r--core/src/main/scala/org/apache/spark/util/Utils.scala56
2 files changed, 197 insertions, 28 deletions
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 9501dd9cd8..3346f6dd1f 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -84,6 +84,16 @@ private[spark] class Executor(
// Start worker thread pool
private val threadPool = ThreadUtils.newDaemonCachedThreadPool("Executor task launch worker")
private val executorSource = new ExecutorSource(threadPool, executorId)
+ // Pool used for threads that supervise task killing / cancellation
+ private val taskReaperPool = ThreadUtils.newDaemonCachedThreadPool("Task reaper")
+ // For tasks which are in the process of being killed, this map holds the most recently created
+ // TaskReaper. All accesses to this map should be synchronized on the map itself (this isn't
+ // a ConcurrentHashMap because we use the synchronization for purposes other than simply guarding
+ // the integrity of the map's internal state). The purpose of this map is to prevent the creation
+ // of a separate TaskReaper for every killTask() of a given task. Instead, this map allows us to
+ // track whether an existing TaskReaper fulfills the role of a TaskReaper that we would otherwise
+ // create. The map key is a task id.
+ private val taskReaperForTask: HashMap[Long, TaskReaper] = HashMap[Long, TaskReaper]()
if (!isLocal) {
env.metricsSystem.registerSource(executorSource)
@@ -93,6 +103,9 @@ private[spark] class Executor(
// Whether to load classes in user jars before those in Spark jars
private val userClassPathFirst = conf.getBoolean("spark.executor.userClassPathFirst", false)
+ // Whether to monitor killed / interrupted tasks
+ private val taskReaperEnabled = conf.getBoolean("spark.task.reaper.enabled", false)
+
// Create our ClassLoader
// do this after SparkEnv creation so can access the SecurityManager
private val urlClassLoader = createClassLoader()
@@ -148,9 +161,27 @@ private[spark] class Executor(
}
def killTask(taskId: Long, interruptThread: Boolean): Unit = {
- val tr = runningTasks.get(taskId)
- if (tr != null) {
- tr.kill(interruptThread)
+ val taskRunner = runningTasks.get(taskId)
+ if (taskRunner != null) {
+ if (taskReaperEnabled) {
+ val maybeNewTaskReaper: Option[TaskReaper] = taskReaperForTask.synchronized {
+ val shouldCreateReaper = taskReaperForTask.get(taskId) match {
+ case None => true
+ case Some(existingReaper) => interruptThread && !existingReaper.interruptThread
+ }
+ if (shouldCreateReaper) {
+ val taskReaper = new TaskReaper(taskRunner, interruptThread = interruptThread)
+ taskReaperForTask(taskId) = taskReaper
+ Some(taskReaper)
+ } else {
+ None
+ }
+ }
+ // Execute the TaskReaper from outside of the synchronized block.
+ maybeNewTaskReaper.foreach(taskReaperPool.execute)
+ } else {
+ taskRunner.kill(interruptThread = interruptThread)
+ }
}
}
@@ -161,12 +192,7 @@ private[spark] class Executor(
* @param interruptThread whether to interrupt the task thread
*/
def killAllTasks(interruptThread: Boolean) : Unit = {
- // kill all the running tasks
- for (taskRunner <- runningTasks.values().asScala) {
- if (taskRunner != null) {
- taskRunner.kill(interruptThread)
- }
- }
+ runningTasks.keys().asScala.foreach(t => killTask(t, interruptThread = interruptThread))
}
def stop(): Unit = {
@@ -192,13 +218,21 @@ private[spark] class Executor(
serializedTask: ByteBuffer)
extends Runnable {
+ val threadName = s"Executor task launch worker for task $taskId"
+
/** Whether this task has been killed. */
@volatile private var killed = false
+ @volatile private var threadId: Long = -1
+
+ def getThreadId: Long = threadId
+
/** Whether this task has been finished. */
@GuardedBy("TaskRunner.this")
private var finished = false
+ def isFinished: Boolean = synchronized { finished }
+
/** How much the JVM process has spent in GC when the task starts to run. */
@volatile var startGCTime: Long = _
@@ -229,9 +263,15 @@ private[spark] class Executor(
// ClosedByInterruptException during execBackend.statusUpdate which causes
// Executor to crash
Thread.interrupted()
+ // Notify any waiting TaskReapers. Generally there will only be one reaper per task but there
+ // is a rare corner-case where one task can have two reapers in case cancel(interrupt=False)
+ // is followed by cancel(interrupt=True). Thus we use notifyAll() to avoid a lost wakeup:
+ notifyAll()
}
override def run(): Unit = {
+ threadId = Thread.currentThread.getId
+ Thread.currentThread.setName(threadName)
val threadMXBean = ManagementFactory.getThreadMXBean
val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
val deserializeStartTime = System.currentTimeMillis()
@@ -432,6 +472,117 @@ private[spark] class Executor(
}
/**
+ * Supervises the killing / cancellation of a task by sending the interrupted flag, optionally
+ * sending a Thread.interrupt(), and monitoring the task until it finishes.
+ *
+ * Spark's current task cancellation / task killing mechanism is "best effort" because some tasks
+ * may not be interruptable or may not respond to their "killed" flags being set. If a significant
+ * fraction of a cluster's task slots are occupied by tasks that have been marked as killed but
+ * remain running then this can lead to a situation where new jobs and tasks are starved of
+ * resources that are being used by these zombie tasks.
+ *
+ * The TaskReaper was introduced in SPARK-18761 as a mechanism to monitor and clean up zombie
+ * tasks. For backwards-compatibility / backportability this component is disabled by default
+ * and must be explicitly enabled by setting `spark.task.reaper.enabled=true`.
+ *
+ * A TaskReaper is created for a particular task when that task is killed / cancelled. Typically
+ * a task will have only one TaskReaper, but it's possible for a task to have up to two reapers
+ * in case kill is called twice with different values for the `interrupt` parameter.
+ *
+ * Once created, a TaskReaper will run until its supervised task has finished running. If the
+ * TaskReaper has not been configured to kill the JVM after a timeout (i.e. if
+ * `spark.task.reaper.killTimeout < 0`) then this implies that the TaskReaper may run indefinitely
+ * if the supervised task never exits.
+ */
+ private class TaskReaper(
+ taskRunner: TaskRunner,
+ val interruptThread: Boolean)
+ extends Runnable {
+
+ private[this] val taskId: Long = taskRunner.taskId
+
+ private[this] val killPollingIntervalMs: Long =
+ conf.getTimeAsMs("spark.task.reaper.pollingInterval", "10s")
+
+ private[this] val killTimeoutMs: Long = conf.getTimeAsMs("spark.task.reaper.killTimeout", "-1")
+
+ private[this] val takeThreadDump: Boolean =
+ conf.getBoolean("spark.task.reaper.threadDump", true)
+
+ override def run(): Unit = {
+ val startTimeMs = System.currentTimeMillis()
+ def elapsedTimeMs = System.currentTimeMillis() - startTimeMs
+ def timeoutExceeded(): Boolean = killTimeoutMs > 0 && elapsedTimeMs > killTimeoutMs
+ try {
+ // Only attempt to kill the task once. If interruptThread = false then a second kill
+ // attempt would be a no-op and if interruptThread = true then it may not be safe or
+ // effective to interrupt multiple times:
+ taskRunner.kill(interruptThread = interruptThread)
+ // Monitor the killed task until it exits. The synchronization logic here is complicated
+ // because we don't want to synchronize on the taskRunner while possibly taking a thread
+ // dump, but we also need to be careful to avoid races between checking whether the task
+ // has finished and wait()ing for it to finish.
+ var finished: Boolean = false
+ while (!finished && !timeoutExceeded()) {
+ taskRunner.synchronized {
+ // We need to synchronize on the TaskRunner while checking whether the task has
+ // finished in order to avoid a race where the task is marked as finished right after
+ // we check and before we call wait().
+ if (taskRunner.isFinished) {
+ finished = true
+ } else {
+ taskRunner.wait(killPollingIntervalMs)
+ }
+ }
+ if (taskRunner.isFinished) {
+ finished = true
+ } else {
+ logWarning(s"Killed task $taskId is still running after $elapsedTimeMs ms")
+ if (takeThreadDump) {
+ try {
+ Utils.getThreadDumpForThread(taskRunner.getThreadId).foreach { thread =>
+ if (thread.threadName == taskRunner.threadName) {
+ logWarning(s"Thread dump from task $taskId:\n${thread.stackTrace}")
+ }
+ }
+ } catch {
+ case NonFatal(e) =>
+ logWarning("Exception thrown while obtaining thread dump: ", e)
+ }
+ }
+ }
+ }
+
+ if (!taskRunner.isFinished && timeoutExceeded()) {
+ if (isLocal) {
+ logError(s"Killed task $taskId could not be stopped within $killTimeoutMs ms; " +
+ "not killing JVM because we are running in local mode.")
+ } else {
+ // In non-local-mode, the exception thrown here will bubble up to the uncaught exception
+ // handler and cause the executor JVM to exit.
+ throw new SparkException(
+ s"Killing executor JVM because killed task $taskId could not be stopped within " +
+ s"$killTimeoutMs ms.")
+ }
+ }
+ } finally {
+ // Clean up entries in the taskReaperForTask map.
+ taskReaperForTask.synchronized {
+ taskReaperForTask.get(taskId).foreach { taskReaperInMap =>
+ if (taskReaperInMap eq this) {
+ taskReaperForTask.remove(taskId)
+ } else {
+ // This must have been a TaskReaper where interruptThread == false where a subsequent
+ // killTask() call for the same task had interruptThread == true and overwrote the
+ // map entry.
+ }
+ }
+ }
+ }
+ }
+ }
+
+ /**
* Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes
* created by the interpreter to the search path
*/
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index c6ad154167..078cc3d5b4 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -18,7 +18,7 @@
package org.apache.spark.util
import java.io._
-import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo}
+import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo}
import java.net._
import java.nio.ByteBuffer
import java.nio.channels.Channels
@@ -2131,28 +2131,46 @@ private[spark] object Utils extends Logging {
// We need to filter out null values here because dumpAllThreads() may return null array
// elements for threads that are dead / don't exist.
val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null)
- threadInfos.sortBy(_.getThreadId).map { case threadInfo =>
- val monitors = threadInfo.getLockedMonitors.map(m => m.getLockedStackFrame -> m).toMap
- val stackTrace = threadInfo.getStackTrace.map { frame =>
- monitors.get(frame) match {
- case Some(monitor) =>
- monitor.getLockedStackFrame.toString + s" => holding ${monitor.lockString}"
- case None =>
- frame.toString
- }
- }.mkString("\n")
-
- // use a set to dedup re-entrant locks that are held at multiple places
- val heldLocks = (threadInfo.getLockedSynchronizers.map(_.lockString)
- ++ threadInfo.getLockedMonitors.map(_.lockString)
- ).toSet
+ threadInfos.sortBy(_.getThreadId).map(threadInfoToThreadStackTrace)
+ }
- ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName, threadInfo.getThreadState,
- stackTrace, if (threadInfo.getLockOwnerId < 0) None else Some(threadInfo.getLockOwnerId),
- Option(threadInfo.getLockInfo).map(_.lockString).getOrElse(""), heldLocks.toSeq)
+ def getThreadDumpForThread(threadId: Long): Option[ThreadStackTrace] = {
+ if (threadId <= 0) {
+ None
+ } else {
+ // The Int.MaxValue here requests the entire untruncated stack trace of the thread:
+ val threadInfo =
+ Option(ManagementFactory.getThreadMXBean.getThreadInfo(threadId, Int.MaxValue))
+ threadInfo.map(threadInfoToThreadStackTrace)
}
}
+ private def threadInfoToThreadStackTrace(threadInfo: ThreadInfo): ThreadStackTrace = {
+ val monitors = threadInfo.getLockedMonitors.map(m => m.getLockedStackFrame -> m).toMap
+ val stackTrace = threadInfo.getStackTrace.map { frame =>
+ monitors.get(frame) match {
+ case Some(monitor) =>
+ monitor.getLockedStackFrame.toString + s" => holding ${monitor.lockString}"
+ case None =>
+ frame.toString
+ }
+ }.mkString("\n")
+
+ // use a set to dedup re-entrant locks that are held at multiple places
+ val heldLocks =
+ (threadInfo.getLockedSynchronizers ++ threadInfo.getLockedMonitors).map(_.lockString).toSet
+
+ ThreadStackTrace(
+ threadId = threadInfo.getThreadId,
+ threadName = threadInfo.getThreadName,
+ threadState = threadInfo.getThreadState,
+ stackTrace = stackTrace,
+ blockedByThreadId =
+ if (threadInfo.getLockOwnerId < 0) None else Some(threadInfo.getLockOwnerId),
+ blockedByLock = Option(threadInfo.getLockInfo).map(_.lockString).getOrElse(""),
+ holdingLocks = heldLocks.toSeq)
+ }
+
/**
* Convert all spark properties set in the given SparkConf to a sequence of java options.
*/