author    Reynold Xin <rxin@databricks.com>  2015-03-19 22:12:01 -0400
committer Reynold Xin <rxin@databricks.com>  2015-03-19 22:12:01 -0400
commit    0745a305fac622a6eeb8aa4a7401205a14252939 (patch)
tree      6d325a3e021cb861a6ea1b800288ad82cc5eb1d9 /core
parent    f17d43b033d928dbc46aef8e367aa08902e698ad (diff)
Tighten up field/method visibility in Executor and make some code clearer to read.
I was reading Executor just now and found that some recent changes introduced a weird code path with too much monadic chaining and unnecessary fields. I cleaned it up a bit, tightened up the visibility of various fields/methods, and added some inline documentation to make this code easier to understand.

Author: Reynold Xin <rxin@databricks.com>

Closes #4850 from rxin/executor and squashes the following commits:

866fc60 [Reynold Xin] Code review feedback.
020efbb [Reynold Xin] Tighten up field/method visibility in Executor and made some code more clear to read.
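To make the "monadic chaining and unnecessary fields" remark concrete, here is a minimal, self-contained sketch of the pattern applied in the Executor.scala hunk below. Metrics, Task, and OptionChainingSketch are illustrative stand-ins, not the actual Spark classes.

// Instead of maintaining a separate attemptedTask field just so the failure
// path can reach the metrics, wrap the possibly-null task reference directly.
object OptionChainingSketch {
  final case class Metrics(var runTimeMs: Long = 0L)
  final class Task(val metrics: Option[Metrics])

  // Set later by the task runner; may still be null if deserialization failed.
  @volatile private var task: Task = _

  private def metricsOnFailure(taskStartMs: Long): Option[Metrics] =
    Option(task).flatMap { t =>
      t.metrics.map { m =>
        m.runTimeMs = System.currentTimeMillis() - taskStartMs
        m
      }
    }

  def main(args: Array[String]): Unit = {
    task = new Task(Some(Metrics()))
    println(metricsOnFailure(System.currentTimeMillis() - 42))
  }
}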
Diffstat (limited to 'core')
 core/src/main/scala/org/apache/spark/TaskEndReason.scala                  |   6
 core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala |   6
 core/src/main/scala/org/apache/spark/executor/Executor.scala              | 196
 core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala        |  16
 core/src/main/scala/org/apache/spark/scheduler/Task.scala                 |   2
 5 files changed, 120 insertions(+), 106 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index 29a5cd5fda..48fd3e7e23 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -151,11 +151,7 @@ case object TaskKilled extends TaskFailedReason {
* Task requested the driver to commit, but was denied.
*/
@DeveloperApi
-case class TaskCommitDenied(
- jobID: Int,
- partitionID: Int,
- attemptID: Int)
- extends TaskFailedReason {
+case class TaskCommitDenied(jobID: Int, partitionID: Int, attemptID: Int) extends TaskFailedReason {
override def toErrorString: String = s"TaskCommitDenied (Driver denied task commit)" +
s" for job: $jobID, partition: $partitionID, attempt: $attemptID"
}
diff --git a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala
index f7604a321f..f47d7ef511 100644
--- a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala
@@ -22,14 +22,12 @@ import org.apache.spark.{TaskCommitDenied, TaskEndReason}
/**
* Exception thrown when a task attempts to commit output to HDFS but is denied by the driver.
*/
-class CommitDeniedException(
+private[spark] class CommitDeniedException(
msg: String,
jobID: Int,
splitID: Int,
attemptID: Int)
extends Exception(msg) {
- def toTaskEndReason: TaskEndReason = new TaskCommitDenied(jobID, splitID, attemptID)
-
+ def toTaskEndReason: TaskEndReason = TaskCommitDenied(jobID, splitID, attemptID)
}
-
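Taken together, the two hunks above mean the exception carries its own conversion to a task end reason and now builds the case class via apply() rather than new. Below is a self-contained sketch of that pattern using stand-in copies of the two classes; the real ones live in org.apache.spark and org.apache.spark.executor.

// Stand-in types, illustrative only.
sealed trait TaskEndReason
final case class TaskCommitDenied(jobID: Int, partitionID: Int, attemptID: Int)
    extends TaskEndReason {
  def toErrorString: String =
    s"TaskCommitDenied (Driver denied task commit)" +
      s" for job: $jobID, partition: $partitionID, attempt: $attemptID"
}

class CommitDeniedException(msg: String, jobID: Int, splitID: Int, attemptID: Int)
    extends Exception(msg) {
  // Case-class apply() instead of `new`, as in the hunk above.
  def toTaskEndReason: TaskEndReason = TaskCommitDenied(jobID, splitID, attemptID)
}

object CommitDeniedSketch {
  def main(args: Array[String]): Unit = {
    val reason: TaskEndReason =
      try {
        throw new CommitDeniedException("denied", jobID = 1, splitID = 2, attemptID = 0)
      } catch {
        case cde: CommitDeniedException => cde.toTaskEndReason
      }
    println(reason)
  }
}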
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 6196f7b165..bf3135ef08 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -21,7 +21,7 @@ import java.io.File
import java.lang.management.ManagementFactory
import java.net.URL
import java.nio.ByteBuffer
-import java.util.concurrent._
+import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
@@ -31,15 +31,17 @@ import akka.actor.Props
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.scheduler._
+import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
-import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader,
- SparkUncaughtExceptionHandler, AkkaUtils, Utils}
+import org.apache.spark.util._
/**
- * Spark executor used with Mesos, YARN, and the standalone scheduler.
- * In coarse-grained mode, an existing actor system is provided.
+ * Spark executor, backed by a threadpool to run tasks.
+ *
+ * This can be used with Mesos, YARN, and the standalone scheduler.
+ * An internal RPC interface (at the moment Akka) is used for communication with the driver,
+ * except in the case of Mesos fine-grained mode.
*/
private[spark] class Executor(
executorId: String,
@@ -47,8 +49,8 @@ private[spark] class Executor(
env: SparkEnv,
userClassPath: Seq[URL] = Nil,
isLocal: Boolean = false)
- extends Logging
-{
+ extends Logging {
+
logInfo(s"Starting executor ID $executorId on host $executorHostname")
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
@@ -78,9 +80,8 @@ private[spark] class Executor(
}
// Start worker thread pool
- val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
-
- val executorSource = new ExecutorSource(this, executorId)
+ private val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
+ private val executorSource = new ExecutorSource(threadPool, executorId)
if (!isLocal) {
env.metricsSystem.registerSource(executorSource)
@@ -122,21 +123,21 @@ private[spark] class Executor(
taskId: Long,
attemptNumber: Int,
taskName: String,
- serializedTask: ByteBuffer) {
+ serializedTask: ByteBuffer): Unit = {
val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName,
serializedTask)
runningTasks.put(taskId, tr)
threadPool.execute(tr)
}
- def killTask(taskId: Long, interruptThread: Boolean) {
+ def killTask(taskId: Long, interruptThread: Boolean): Unit = {
val tr = runningTasks.get(taskId)
if (tr != null) {
tr.kill(interruptThread)
}
}
- def stop() {
+ def stop(): Unit = {
env.metricsSystem.report()
env.actorSystem.stop(executorActor)
isStopped = true
@@ -146,7 +147,10 @@ private[spark] class Executor(
}
}
- private def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
+ /** Returns the total amount of time this JVM process has spent in garbage collection. */
+ private def computeTotalGcTime(): Long = {
+ ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
+ }
class TaskRunner(
execBackend: ExecutorBackend,
@@ -156,12 +160,19 @@ private[spark] class Executor(
serializedTask: ByteBuffer)
extends Runnable {
+ /** Whether this task has been killed. */
@volatile private var killed = false
- @volatile var task: Task[Any] = _
- @volatile var attemptedTask: Option[Task[Any]] = None
+
+ /** How much the JVM process has spent in GC when the task starts to run. */
@volatile var startGCTime: Long = _
- def kill(interruptThread: Boolean) {
+ /**
+ * The task to run. This will be set in run() by deserializing the task binary coming
+ * from the driver. Once it is set, it will never be changed.
+ */
+ @volatile var task: Task[Any] = _
+
+ def kill(interruptThread: Boolean): Unit = {
logInfo(s"Executor is trying to kill $taskName (TID $taskId)")
killed = true
if (task != null) {
@@ -169,14 +180,14 @@ private[spark] class Executor(
}
}
- override def run() {
+ override def run(): Unit = {
val deserializeStartTime = System.currentTimeMillis()
Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = env.closureSerializer.newInstance()
logInfo(s"Running $taskName (TID $taskId)")
execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
var taskStart: Long = 0
- startGCTime = gcTime
+ startGCTime = computeTotalGcTime()
try {
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
@@ -193,7 +204,6 @@ private[spark] class Executor(
throw new TaskKilledException
}
- attemptedTask = Some(task)
logDebug("Task " + taskId + "'s epoch is " + task.epoch)
env.mapOutputTracker.updateEpoch(task.epoch)
@@ -215,18 +225,17 @@ private[spark] class Executor(
for (m <- task.metrics) {
m.setExecutorDeserializeTime(taskStart - deserializeStartTime)
m.setExecutorRunTime(taskFinish - taskStart)
- m.setJvmGCTime(gcTime - startGCTime)
+ m.setJvmGCTime(computeTotalGcTime() - startGCTime)
m.setResultSerializationTime(afterSerialization - beforeSerialization)
}
val accumUpdates = Accumulators.values
-
val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)
val serializedDirectResult = ser.serialize(directResult)
val resultSize = serializedDirectResult.limit
// directSend = sending directly back to the driver
- val serializedResult = {
+ val serializedResult: ByteBuffer = {
if (maxResultSize > 0 && resultSize > maxResultSize) {
logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
@@ -248,42 +257,40 @@ private[spark] class Executor(
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
} catch {
- case ffe: FetchFailedException => {
+ case ffe: FetchFailedException =>
val reason = ffe.toTaskEndReason
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
- }
- case _: TaskKilledException | _: InterruptedException if task.killed => {
+ case _: TaskKilledException | _: InterruptedException if task.killed =>
logInfo(s"Executor killed $taskName (TID $taskId)")
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
- }
- case cDE: CommitDeniedException => {
+ case cDE: CommitDeniedException =>
val reason = cDE.toTaskEndReason
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
- }
- case t: Throwable => {
+ case t: Throwable =>
// Attempt to exit cleanly by informing the driver of our failure.
// If anything goes wrong (or this was a fatal exception), we will delegate to
// the default uncaught exception handler, which will terminate the Executor.
logError(s"Exception in $taskName (TID $taskId)", t)
- val serviceTime = System.currentTimeMillis() - taskStart
- val metrics = attemptedTask.flatMap(t => t.metrics)
- for (m <- metrics) {
- m.setExecutorRunTime(serviceTime)
- m.setJvmGCTime(gcTime - startGCTime)
+ val metrics: Option[TaskMetrics] = Option(task).flatMap { task =>
+ task.metrics.map { m =>
+ m.setExecutorRunTime(System.currentTimeMillis() - taskStart)
+ m.setJvmGCTime(computeTotalGcTime() - startGCTime)
+ m
+ }
}
- val reason = new ExceptionFailure(t, metrics)
- execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
+ val taskEndReason = new ExceptionFailure(t, metrics)
+ execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(taskEndReason))
// Don't forcibly exit unless the exception was inherently fatal, to avoid
// stopping other tasks unnecessarily.
if (Utils.isFatalError(t)) {
SparkUncaughtExceptionHandler.uncaughtException(t)
}
- }
+
} finally {
// Release memory used by this thread for shuffles
env.shuffleMemoryManager.releaseMemoryForThisThread()
@@ -358,7 +365,7 @@ private[spark] class Executor(
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
// Fetch file with useCache mode, close cache for local mode.
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf,
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
currentFiles(name) = timestamp
}
@@ -370,12 +377,12 @@ private[spark] class Executor(
if (currentTimeStamp < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
// Fetch file with useCache mode, close cache for local mode.
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf,
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
currentJars(name) = timestamp
// Add it to our class loader
- val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
- if (!urlClassLoader.getURLs.contains(url)) {
+ val url = new File(SparkFiles.getRootDirectory(), localName).toURI.toURL
+ if (!urlClassLoader.getURLs().contains(url)) {
logInfo("Adding " + url + " to class loader")
urlClassLoader.addURL(url)
}
@@ -384,61 +391,70 @@ private[spark] class Executor(
}
}
- def startDriverHeartbeater() {
- val interval = conf.getInt("spark.executor.heartbeatInterval", 10000)
- val timeout = AkkaUtils.lookupTimeout(conf)
- val retryAttempts = AkkaUtils.numRetries(conf)
- val retryIntervalMs = AkkaUtils.retryWaitMs(conf)
- val heartbeatReceiverRef = AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem)
+ private val timeout = AkkaUtils.lookupTimeout(conf)
+ private val retryAttempts = AkkaUtils.numRetries(conf)
+ private val retryIntervalMs = AkkaUtils.retryWaitMs(conf)
+ private val heartbeatReceiverRef =
+ AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem)
+
+ /** Reports heartbeat and metrics for active tasks to the driver. */
+ private def reportHeartBeat(): Unit = {
+ // list of (task id, metrics) to send back to the driver
+ val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]()
+ val curGCTime = computeTotalGcTime()
+
+ for (taskRunner <- runningTasks.values()) {
+ if (taskRunner.task != null) {
+ taskRunner.task.metrics.foreach { metrics =>
+ metrics.updateShuffleReadMetrics()
+ metrics.updateInputMetrics()
+ metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
+
+ if (isLocal) {
+ // JobProgressListener will hold an reference of it during
+ // onExecutorMetricsUpdate(), then JobProgressListener can not see
+ // the changes of metrics any more, so make a deep copy of it
+ val copiedMetrics = Utils.deserialize[TaskMetrics](Utils.serialize(metrics))
+ tasksMetrics += ((taskRunner.taskId, copiedMetrics))
+ } else {
+ // It will be copied by serialization
+ tasksMetrics += ((taskRunner.taskId, metrics))
+ }
+ }
+ }
+ }
- val t = new Thread() {
+ val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId)
+ try {
+ val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef,
+ retryAttempts, retryIntervalMs, timeout)
+ if (response.reregisterBlockManager) {
+ logWarning("Told to re-register on heartbeat")
+ env.blockManager.reregister()
+ }
+ } catch {
+ case NonFatal(e) => logWarning("Issue communicating with driver in heartbeater", e)
+ }
+ }
+
+ /**
+ * Starts a thread to report heartbeat and partial metrics for active tasks to driver.
+ * This thread stops running when the executor is stopped.
+ */
+ private def startDriverHeartbeater(): Unit = {
+ val interval = conf.getInt("spark.executor.heartbeatInterval", 10000)
+ val thread = new Thread() {
override def run() {
// Sleep a random interval so the heartbeats don't end up in sync
Thread.sleep(interval + (math.random * interval).asInstanceOf[Int])
-
while (!isStopped) {
- val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]()
- val curGCTime = gcTime
-
- for (taskRunner <- runningTasks.values()) {
- if (taskRunner.attemptedTask.nonEmpty) {
- Option(taskRunner.task).flatMap(_.metrics).foreach { metrics =>
- metrics.updateShuffleReadMetrics()
- metrics.updateInputMetrics()
- metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
-
- if (isLocal) {
- // JobProgressListener will hold an reference of it during
- // onExecutorMetricsUpdate(), then JobProgressListener can not see
- // the changes of metrics any more, so make a deep copy of it
- val copiedMetrics = Utils.deserialize[TaskMetrics](Utils.serialize(metrics))
- tasksMetrics += ((taskRunner.taskId, copiedMetrics))
- } else {
- // It will be copied by serialization
- tasksMetrics += ((taskRunner.taskId, metrics))
- }
- }
- }
- }
-
- val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId)
- try {
- val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef,
- retryAttempts, retryIntervalMs, timeout)
- if (response.reregisterBlockManager) {
- logWarning("Told to re-register on heartbeat")
- env.blockManager.reregister()
- }
- } catch {
- case NonFatal(t) => logWarning("Issue communicating with driver in heartbeater", t)
- }
-
+ reportHeartBeat()
Thread.sleep(interval)
}
}
}
- t.setDaemon(true)
- t.setName("Driver Heartbeater")
- t.start()
+ thread.setDaemon(true)
+ thread.setName("driver-heartbeater")
+ thread.start()
}
}
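The structural point of the Executor.scala hunk above is that the per-beat work moves into a private reportHeartBeat() method, leaving startDriverHeartbeater() responsible only for the daemon thread. Here is a self-contained sketch of that shape with the same interval and randomized start; HeartbeaterSketch and the println bodies are stand-ins for the real Akka ask and logging.

import scala.util.control.NonFatal

object HeartbeaterSketch {
  @volatile private var stopped = false
  private val intervalMs = 1000

  /** Reports one heartbeat; any failure is logged and swallowed. */
  private def reportHeartBeat(): Unit = {
    try {
      println(s"heartbeat at ${System.currentTimeMillis()}")
    } catch {
      case NonFatal(e) => println(s"Issue communicating with driver in heartbeater: $e")
    }
  }

  private def startDriverHeartbeater(): Unit = {
    val thread = new Thread() {
      override def run(): Unit = {
        // Sleep a random interval first so heartbeats don't end up in sync.
        Thread.sleep(intervalMs + (math.random * intervalMs).toInt)
        while (!stopped) {
          reportHeartBeat()
          Thread.sleep(intervalMs)
        }
      }
    }
    thread.setDaemon(true)
    thread.setName("driver-heartbeater")
    thread.start()
  }

  def main(args: Array[String]): Unit = {
    startDriverHeartbeater()
    Thread.sleep(3000)  // let a couple of beats through, then stop
    stopped = true
  }
}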
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala
index c4d73622c4..293c512f8b 100644
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala
@@ -17,6 +17,8 @@
package org.apache.spark.executor
+import java.util.concurrent.ThreadPoolExecutor
+
import scala.collection.JavaConversions._
import com.codahale.metrics.{Gauge, MetricRegistry}
@@ -24,9 +26,11 @@ import org.apache.hadoop.fs.FileSystem
import org.apache.spark.metrics.source.Source
-private[spark] class ExecutorSource(val executor: Executor, executorId: String) extends Source {
+private[spark]
+class ExecutorSource(threadPool: ThreadPoolExecutor, executorId: String) extends Source {
+
private def fileStats(scheme: String) : Option[FileSystem.Statistics] =
- FileSystem.getAllStatistics().filter(s => s.getScheme.equals(scheme)).headOption
+ FileSystem.getAllStatistics().find(s => s.getScheme.equals(scheme))
private def registerFileSystemStat[T](
scheme: String, name: String, f: FileSystem.Statistics => T, defaultValue: T) = {
@@ -41,23 +45,23 @@ private[spark] class ExecutorSource(val executor: Executor, executorId: String)
// Gauge for executor thread pool's actively executing task counts
metricRegistry.register(MetricRegistry.name("threadpool", "activeTasks"), new Gauge[Int] {
- override def getValue: Int = executor.threadPool.getActiveCount()
+ override def getValue: Int = threadPool.getActiveCount()
})
// Gauge for executor thread pool's approximate total number of tasks that have been completed
metricRegistry.register(MetricRegistry.name("threadpool", "completeTasks"), new Gauge[Long] {
- override def getValue: Long = executor.threadPool.getCompletedTaskCount()
+ override def getValue: Long = threadPool.getCompletedTaskCount()
})
// Gauge for executor thread pool's current number of threads
metricRegistry.register(MetricRegistry.name("threadpool", "currentPool_size"), new Gauge[Int] {
- override def getValue: Int = executor.threadPool.getPoolSize()
+ override def getValue: Int = threadPool.getPoolSize()
})
// Gauge got executor thread pool's largest number of threads that have ever simultaneously
// been in th pool
metricRegistry.register(MetricRegistry.name("threadpool", "maxPool_size"), new Gauge[Int] {
- override def getValue: Int = executor.threadPool.getMaximumPoolSize()
+ override def getValue: Int = threadPool.getMaximumPoolSize()
})
// Gauge for file system stats of this executor
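The ExecutorSource change above is a dependency narrowing: the source now takes the ThreadPoolExecutor it actually reads from instead of the whole Executor. A self-contained sketch of the same Codahale gauge registration follows; it assumes metrics-core on the classpath, and MetricsSketch plus the cached pool are illustrative, not Spark code.

import java.util.concurrent.{Executors, ThreadPoolExecutor}

import com.codahale.metrics.{Gauge, MetricRegistry}

object MetricsSketch {
  def main(args: Array[String]): Unit = {
    val threadPool = Executors.newCachedThreadPool().asInstanceOf[ThreadPoolExecutor]
    val registry = new MetricRegistry()

    // Same shape as the gauges above: each one reads straight off the pool.
    registry.register(MetricRegistry.name("threadpool", "activeTasks"), new Gauge[Int] {
      override def getValue: Int = threadPool.getActiveCount()
    })
    registry.register(MetricRegistry.name("threadpool", "completeTasks"), new Gauge[Long] {
      override def getValue: Long = threadPool.getCompletedTaskCount()
    })

    println(registry.getGauges().keySet())
    threadPool.shutdown()
  }
}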
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 847a4912ee..4d9f940813 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -45,7 +45,7 @@ import org.apache.spark.util.Utils
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
/**
- * Called by Executor to run this task.
+ * Called by [[Executor]] to run this task.
*
* @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext.
* @param attemptNumber how many times this task has been attempted (0 for the first attempt)