author     Sandy Ryza <sandy@cloudera.com>        2014-08-01 11:08:39 -0700
committer  Patrick Wendell <pwendell@gmail.com>   2014-08-01 11:08:39 -0700
commit     8d338f64c4eda45d22ae33f61ef7928011cc2846 (patch)
tree       e15459a74699b857b99ad80c86510b041a69bbff /core
parent     5328c0aaa09911c848f9b3e1e1f2397bef932d0f (diff)
SPARK-2099. Report progress while task is running.
This is a sketch of a patch that allows the UI to show metrics for tasks that have not yet completed. It adds a heartbeat every 2 seconds from the executors to the driver, reporting metrics for all of the executor's tasks. It still needs unit tests, polish, and cluster testing, but I wanted to put it up to get feedback on the approach.

Author: Sandy Ryza <sandy@cloudera.com>

Closes #1056 from sryza/sandy-spark-2099 and squashes the following commits:

93b9fdb [Sandy Ryza] Up heartbeat interval to 10 seconds and other tidying
132aec7 [Sandy Ryza] Heartbeat and HeartbeatResponse are already Serializable as case classes
38dffde [Sandy Ryza] Additional review feedback and restore test that was removed in BlockManagerSuite
51fa396 [Sandy Ryza] Remove hostname race, add better comments about threading, and some stylistic improvements
3084f10 [Sandy Ryza] Make TaskUIData a case class again
3bda974 [Sandy Ryza] Stylistic fixes
0dae734 [Sandy Ryza] SPARK-2099. Report progress while task is running.
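As a hedged illustration of what the patch enables (not part of the diff below): once the driver posts SparkListenerExecutorMetricsUpdate events, application code that registers a custom SparkListener via SparkContext.addSparkListener could observe in-progress metrics roughly as in the sketch; the listener name and log format are hypothetical.

import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorMetricsUpdate}

// Hypothetical listener: logs partial metrics as executor heartbeats reach the driver.
class InProgressMetricsListener extends SparkListener {
  override def onExecutorMetricsUpdate(update: SparkListenerExecutorMetricsUpdate) {
    for ((taskId, stageId, metrics) <- update.taskMetrics) {
      println(s"executor ${update.execId} stage $stageId task $taskId " +
        s"runTime=${metrics.executorRunTime}ms")
    }
  }
}

// Usage, assuming an existing SparkContext `sc`:
//   sc.addSparkListener(new InProgressMetricsListener)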
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala | 46
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala | 8
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/Executor.scala | 55
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala | 10
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 21
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala | 11
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Task.scala | 3
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala | 10
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala | 23
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala | 9
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala | 25
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala | 43
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala | 29
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala | 117
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala | 9
-rw-r--r--  core/src/main/scala/org/apache/spark/util/AkkaUtils.scala | 66
-rw-r--r--  core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala | 6
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala | 5
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala | 23
-rw-r--r--  core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala | 86
23 files changed, 460 insertions, 157 deletions
diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
new file mode 100644
index 0000000000..24ccce21b6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import akka.actor.Actor
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.scheduler.TaskScheduler
+
+/**
+ * A heartbeat from executors to the driver. This is a shared message used by several internal
+ * components to convey liveness or execution information for in-progress tasks.
+ */
+private[spark] case class Heartbeat(
+ executorId: String,
+ taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics
+ blockManagerId: BlockManagerId)
+
+private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean)
+
+/**
+ * Lives in the driver to receive heartbeats from executors.
+ */
+private[spark] class HeartbeatReceiver(scheduler: TaskScheduler) extends Actor {
+ override def receive = {
+ case Heartbeat(executorId, taskMetrics, blockManagerId) =>
+ val response = HeartbeatResponse(
+ !scheduler.executorHeartbeatReceived(executorId, taskMetrics, blockManagerId))
+ sender ! response
+ }
+}
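A minimal, hypothetical sketch (not part of the patch) of how this actor is exchanged with, for example from a test inside the org.apache.spark package (the messages are private[spark]); `taskScheduler`, `execId`, and `blockManagerId` are assumed to be in scope.

import java.util.concurrent.TimeUnit
import scala.concurrent.Await
import akka.actor.{ActorSystem, Props}
import akka.pattern.ask
import akka.util.Timeout
import org.apache.spark.executor.TaskMetrics

// Start the receiver under the well-known name that executors look up.
val actorSystem = ActorSystem("driver-sketch")
val receiverRef = actorSystem.actorOf(
  Props(new HeartbeatReceiver(taskScheduler)), "HeartbeatReceiver")

// Ask with an empty metrics array; reregisterBlockManager == true means the driver
// does not know this executor's BlockManager and it should re-register.
implicit val timeout = Timeout(30, TimeUnit.SECONDS)
val future = receiverRef ? Heartbeat(execId, Array.empty[(Long, TaskMetrics)], blockManagerId)
val response = Await.result(future, timeout.duration).asInstanceOf[HeartbeatResponse]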
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 0e513568b0..5f75c1dd2c 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -36,6 +36,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, Sequence
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHadoopJob}
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.mesos.MesosNativeLibrary
+import akka.actor.Props
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.broadcast.Broadcast
@@ -307,6 +308,8 @@ class SparkContext(config: SparkConf) extends Logging {
// Create and start the scheduler
private[spark] var taskScheduler = SparkContext.createTaskScheduler(this, master)
+ private val heartbeatReceiver = env.actorSystem.actorOf(
+ Props(new HeartbeatReceiver(taskScheduler)), "HeartbeatReceiver")
@volatile private[spark] var dagScheduler: DAGScheduler = _
try {
dagScheduler = new DAGScheduler(this)
@@ -992,6 +995,7 @@ class SparkContext(config: SparkConf) extends Logging {
if (dagSchedulerCopy != null) {
env.metricsSystem.report()
metadataCleaner.cancel()
+ env.actorSystem.stop(heartbeatReceiver)
cleaner.foreach(_.stop())
dagSchedulerCopy.stop()
taskScheduler = null
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 6ee731b22c..92c809d854 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -193,13 +193,7 @@ object SparkEnv extends Logging {
logInfo("Registering " + name)
actorSystem.actorOf(Props(newActor), name = name)
} else {
- val driverHost: String = conf.get("spark.driver.host", "localhost")
- val driverPort: Int = conf.getInt("spark.driver.port", 7077)
- Utils.checkHost(driverHost, "Expected hostname")
- val url = s"akka.tcp://spark@$driverHost:$driverPort/user/$name"
- val timeout = AkkaUtils.lookupTimeout(conf)
- logInfo(s"Connecting to $name: $url")
- Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout)
+ AkkaUtils.makeDriverRef(name, conf, actorSystem)
}
}
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 99d650a363..1bb1b4aae9 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -23,7 +23,7 @@ import java.nio.ByteBuffer
import java.util.concurrent._
import scala.collection.JavaConversions._
-import scala.collection.mutable.HashMap
+import scala.collection.mutable.{ArrayBuffer, HashMap}
import org.apache.spark._
import org.apache.spark.scheduler._
@@ -48,6 +48,8 @@ private[spark] class Executor(
private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
+ @volatile private var isStopped = false
+
// No ip or host:port - just hostname
Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname")
// must not have port specified.
@@ -107,6 +109,8 @@ private[spark] class Executor(
// Maintains the list of running tasks.
private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
+ startDriverHeartbeater()
+
def launchTask(
context: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) {
val tr = new TaskRunner(context, taskId, taskName, serializedTask)
@@ -121,8 +125,10 @@ private[spark] class Executor(
}
}
- def stop(): Unit = {
+ def stop() {
env.metricsSystem.report()
+ isStopped = true
+ threadPool.shutdown()
}
/** Get the Yarn approved local directories. */
@@ -141,11 +147,12 @@ private[spark] class Executor(
}
class TaskRunner(
- execBackend: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer)
+ execBackend: ExecutorBackend, val taskId: Long, taskName: String, serializedTask: ByteBuffer)
extends Runnable {
@volatile private var killed = false
- @volatile private var task: Task[Any] = _
+ @volatile var task: Task[Any] = _
+ @volatile var attemptedTask: Option[Task[Any]] = None
def kill(interruptThread: Boolean) {
logInfo(s"Executor is trying to kill $taskName (TID $taskId)")
@@ -162,7 +169,6 @@ private[spark] class Executor(
val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo(s"Running $taskName (TID $taskId)")
execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
- var attemptedTask: Option[Task[Any]] = None
var taskStart: Long = 0
def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
val startGCTime = gcTime
@@ -204,7 +210,6 @@ private[spark] class Executor(
val afterSerialization = System.currentTimeMillis()
for (m <- task.metrics) {
- m.hostname = Utils.localHostName()
m.executorDeserializeTime = taskStart - startTime
m.executorRunTime = taskFinish - taskStart
m.jvmGCTime = gcTime - startGCTime
@@ -354,4 +359,42 @@ private[spark] class Executor(
}
}
}
+
+ def startDriverHeartbeater() {
+ val interval = conf.getInt("spark.executor.heartbeatInterval", 10000)
+ val timeout = AkkaUtils.lookupTimeout(conf)
+ val retryAttempts = AkkaUtils.numRetries(conf)
+ val retryIntervalMs = AkkaUtils.retryWaitMs(conf)
+ val heartbeatReceiverRef = AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem)
+
+ val t = new Thread() {
+ override def run() {
+ // Sleep a random interval so the heartbeats don't end up in sync
+ Thread.sleep(interval + (math.random * interval).asInstanceOf[Int])
+
+ while (!isStopped) {
+ val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]()
+ for (taskRunner <- runningTasks.values()) {
+ if (!taskRunner.attemptedTask.isEmpty) {
+ Option(taskRunner.task).flatMap(_.metrics).foreach { metrics =>
+ tasksMetrics += ((taskRunner.taskId, metrics))
+ }
+ }
+ }
+
+ val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId)
+ val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef,
+ retryAttempts, retryIntervalMs, timeout)
+ if (response.reregisterBlockManager) {
+ logWarning("Told to re-register on heartbeat")
+ env.blockManager.reregister()
+ }
+ Thread.sleep(interval)
+ }
+ }
+ }
+ t.setDaemon(true)
+ t.setName("Driver Heartbeater")
+ t.start()
+ }
}
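The heartbeat cadence introduced above is read from spark.executor.heartbeatInterval (milliseconds, defaulting to 10000 in this patch). A hedged sketch of tuning it when building a context; the 5000 value and app name are just examples.

import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf()
  .setAppName("heartbeat-tuning-example")            // hypothetical app name
  .set("spark.executor.heartbeatInterval", "5000")   // report in-progress metrics every 5s
val sc = new SparkContext(conf)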
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 21fe643b8d..56cd8723a3 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -23,6 +23,14 @@ import org.apache.spark.storage.{BlockId, BlockStatus}
/**
* :: DeveloperApi ::
* Metrics tracked during the execution of a task.
+ *
+ * This class is used to house metrics both for in-progress and completed tasks. In executors,
+ * both the task thread and the heartbeat thread write to the TaskMetrics. The heartbeat thread
+ * reads it to send in-progress metrics, and the task thread reads it to send metrics along with
+ * the completed task.
+ *
+ * So, when adding new fields, take into consideration that the whole object can be serialized for
+ * shipping off at any time to consumers of the SparkListener interface.
*/
@DeveloperApi
class TaskMetrics extends Serializable {
@@ -143,7 +151,7 @@ class ShuffleReadMetrics extends Serializable {
/**
* Absolute time when this task finished reading shuffle data
*/
- var shuffleFinishTime: Long = _
+ var shuffleFinishTime: Long = -1
/**
* Number of blocks fetched in this shuffle by this task (remote or local)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 50186d097a..c7e3d7c5f8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -29,7 +29,6 @@ import scala.reflect.ClassTag
import scala.util.control.NonFatal
import akka.actor._
-import akka.actor.OneForOneStrategy
import akka.actor.SupervisorStrategy.Stop
import akka.pattern.ask
import akka.util.Timeout
@@ -39,8 +38,9 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, RDDBlockId}
+import org.apache.spark.storage._
import org.apache.spark.util.{CallSite, SystemClock, Clock, Utils}
+import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
/**
* The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of
@@ -154,6 +154,23 @@ class DAGScheduler(
eventProcessActor ! CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)
}
+ /**
+ * Update metrics for in-progress tasks and let the master know that the BlockManager is still
+ * alive. Return true if the driver knows about the given block manager. Otherwise, return false,
+ * indicating that the block manager should re-register.
+ */
+ def executorHeartbeatReceived(
+ execId: String,
+ taskMetrics: Array[(Long, Int, TaskMetrics)], // (taskId, stageId, metrics)
+ blockManagerId: BlockManagerId): Boolean = {
+ listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics))
+ implicit val timeout = Timeout(600 seconds)
+
+ Await.result(
+ blockManagerMaster.driverActor ? BlockManagerHeartbeat(blockManagerId),
+ timeout.duration).asInstanceOf[Boolean]
+ }
+
// Called by TaskScheduler when an executor fails.
def executorLost(execId: String) {
eventProcessActor ! ExecutorLost(execId)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 82163eadd5..d01d318633 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -76,6 +76,12 @@ case class SparkListenerBlockManagerRemoved(blockManagerId: BlockManagerId)
case class SparkListenerUnpersistRDD(rddId: Int) extends SparkListenerEvent
@DeveloperApi
+case class SparkListenerExecutorMetricsUpdate(
+ execId: String,
+ taskMetrics: Seq[(Long, Int, TaskMetrics)])
+ extends SparkListenerEvent
+
+@DeveloperApi
case class SparkListenerApplicationStart(appName: String, time: Long, sparkUser: String)
extends SparkListenerEvent
@@ -158,6 +164,11 @@ trait SparkListener {
* Called when the application ends
*/
def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd) { }
+
+ /**
+ * Called when the driver receives task metrics from an executor in a heartbeat.
+ */
+ def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { }
}
/**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
index ed9fb24bc8..e79ffd7a35 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
@@ -68,6 +68,8 @@ private[spark] trait SparkListenerBus extends Logging {
foreachListener(_.onApplicationStart(applicationStart))
case applicationEnd: SparkListenerApplicationEnd =>
foreachListener(_.onApplicationEnd(applicationEnd))
+ case metricsUpdate: SparkListenerExecutorMetricsUpdate =>
+ foreachListener(_.onExecutorMetricsUpdate(metricsUpdate))
case SparkListenerShutdown =>
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 5871edeb85..5c5e421404 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -26,6 +26,8 @@ import org.apache.spark.TaskContext
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.util.ByteBufferInputStream
+import org.apache.spark.util.Utils
+
/**
* A unit of execution. We have two kinds of Task's in Spark:
@@ -44,6 +46,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
final def run(attemptId: Long): T = {
context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
+ context.taskMetrics.hostname = Utils.localHostName();
taskThread = Thread.currentThread()
if (_killed) {
kill(interruptThread = false)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 819c35257b..1a0b877c8a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -18,6 +18,8 @@
package org.apache.spark.scheduler
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.storage.BlockManagerId
/**
* Low-level task scheduler interface, currently implemented exclusively by TaskSchedulerImpl.
@@ -54,4 +56,12 @@ private[spark] trait TaskScheduler {
// Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
def defaultParallelism(): Int
+
+ /**
+ * Update metrics for in-progress tasks and let the master know that the BlockManager is still
+ * alive. Return true if the driver knows about the given block manager. Otherwise, return false,
+ * indicating that the block manager should re-register.
+ */
+ def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)],
+ blockManagerId: BlockManagerId): Boolean
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index be3673c48e..d2f764fc22 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -32,6 +32,9 @@ import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.util.Utils
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.storage.BlockManagerId
+import akka.actor.Props
/**
* Schedules tasks for multiple types of clusters by acting through a SchedulerBackend.
@@ -320,6 +323,26 @@ private[spark] class TaskSchedulerImpl(
}
}
+ /**
+ * Update metrics for in-progress tasks and let the master know that the BlockManager is still
+ * alive. Return true if the driver knows about the given block manager. Otherwise, return false,
+ * indicating that the block manager should re-register.
+ */
+ override def executorHeartbeatReceived(
+ execId: String,
+ taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics
+ blockManagerId: BlockManagerId): Boolean = {
+ val metricsWithStageIds = taskMetrics.flatMap {
+ case (id, metrics) => {
+ taskIdToTaskSetId.get(id)
+ .flatMap(activeTaskSets.get)
+ .map(_.stageId)
+ .map(x => (id, x, metrics))
+ }
+ }
+ dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId)
+ }
+
def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long) {
taskSetManager.handleTaskGettingResult(tid)
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
index 5b897597fa..3d1cf312cc 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -23,8 +23,9 @@ import akka.actor.{Actor, ActorRef, Props}
import org.apache.spark.{Logging, SparkEnv, TaskState}
import org.apache.spark.TaskState.TaskState
-import org.apache.spark.executor.{Executor, ExecutorBackend}
+import org.apache.spark.executor.{TaskMetrics, Executor, ExecutorBackend}
import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer}
+import org.apache.spark.storage.BlockManagerId
private case class ReviveOffers()
@@ -32,6 +33,8 @@ private case class StatusUpdate(taskId: Long, state: TaskState, serializedData:
private case class KillTask(taskId: Long, interruptThread: Boolean)
+private case class StopExecutor()
+
/**
* Calls to LocalBackend are all serialized through LocalActor. Using an actor makes the calls on
* LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend
@@ -63,6 +66,9 @@ private[spark] class LocalActor(
case KillTask(taskId, interruptThread) =>
executor.killTask(taskId, interruptThread)
+
+ case StopExecutor =>
+ executor.stop()
}
def reviveOffers() {
@@ -91,6 +97,7 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores:
}
override def stop() {
+ localActor ! StopExecutor
}
override def reviveOffers() {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index d746526639..c0a0601794 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -116,15 +116,6 @@ private[spark] class BlockManager(
private var asyncReregisterTask: Future[Unit] = null
private val asyncReregisterLock = new Object
- private def heartBeat(): Unit = {
- if (!master.sendHeartBeat(blockManagerId)) {
- reregister()
- }
- }
-
- private val heartBeatFrequency = BlockManager.getHeartBeatFrequency(conf)
- private var heartBeatTask: Cancellable = null
-
private val metadataCleaner = new MetadataCleaner(
MetadataCleanerType.BLOCK_MANAGER, this.dropOldNonBroadcastBlocks, conf)
private val broadcastCleaner = new MetadataCleaner(
@@ -161,11 +152,6 @@ private[spark] class BlockManager(
private def initialize(): Unit = {
master.registerBlockManager(blockManagerId, maxMemory, slaveActor)
BlockManagerWorker.startBlockManagerWorker(this)
- if (!BlockManager.getDisableHeartBeatsForTesting(conf)) {
- heartBeatTask = actorSystem.scheduler.schedule(0.seconds, heartBeatFrequency.milliseconds) {
- Utils.tryOrExit { heartBeat() }
- }
- }
}
/**
@@ -195,7 +181,7 @@ private[spark] class BlockManager(
*
* Note that this method must be called without any BlockInfo locks held.
*/
- private def reregister(): Unit = {
+ def reregister(): Unit = {
// TODO: We might need to rate limit re-registering.
logInfo("BlockManager re-registering with master")
master.registerBlockManager(blockManagerId, maxMemory, slaveActor)
@@ -1065,9 +1051,6 @@ private[spark] class BlockManager(
}
def stop(): Unit = {
- if (heartBeatTask != null) {
- heartBeatTask.cancel()
- }
connectionManager.stop()
shuffleBlockManager.stop()
diskBlockManager.stop()
@@ -1095,12 +1078,6 @@ private[spark] object BlockManager extends Logging {
(Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
}
- def getHeartBeatFrequency(conf: SparkConf): Long =
- conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000) / 4
-
- def getDisableHeartBeatsForTesting(conf: SparkConf): Boolean =
- conf.getBoolean("spark.test.disableBlockManagerHeartBeat", false)
-
/**
* Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that
* might cause errors if one attempts to read from the unmapped buffer, but it's better than
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index 7897fade2d..669307765d 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -21,7 +21,6 @@ import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import akka.actor._
-import akka.pattern.ask
import org.apache.spark.{Logging, SparkConf, SparkException}
import org.apache.spark.storage.BlockManagerMessages._
@@ -29,8 +28,8 @@ import org.apache.spark.util.AkkaUtils
private[spark]
class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Logging {
- val AKKA_RETRY_ATTEMPTS: Int = conf.getInt("spark.akka.num.retries", 3)
- val AKKA_RETRY_INTERVAL_MS: Int = conf.getInt("spark.akka.retry.wait", 3000)
+ private val AKKA_RETRY_ATTEMPTS: Int = AkkaUtils.numRetries(conf)
+ private val AKKA_RETRY_INTERVAL_MS: Int = AkkaUtils.retryWaitMs(conf)
val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster"
@@ -42,15 +41,6 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
logInfo("Removed " + execId + " successfully in removeExecutor")
}
- /**
- * Send the driver actor a heart beat from the slave. Returns true if everything works out,
- * false if the driver does not know about the given block manager, which means the block
- * manager should re-register.
- */
- def sendHeartBeat(blockManagerId: BlockManagerId): Boolean = {
- askDriverWithReply[Boolean](HeartBeat(blockManagerId))
- }
-
/** Register the BlockManager's id with the driver. */
def registerBlockManager(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
logInfo("Trying to register BlockManager")
@@ -223,33 +213,8 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
* throw a SparkException if this fails.
*/
private def askDriverWithReply[T](message: Any): T = {
- // TODO: Consider removing multiple attempts
- if (driverActor == null) {
- throw new SparkException("Error sending message to BlockManager as driverActor is null " +
- "[message = " + message + "]")
- }
- var attempts = 0
- var lastException: Exception = null
- while (attempts < AKKA_RETRY_ATTEMPTS) {
- attempts += 1
- try {
- val future = driverActor.ask(message)(timeout)
- val result = Await.result(future, timeout)
- if (result == null) {
- throw new SparkException("BlockManagerMaster returned null")
- }
- return result.asInstanceOf[T]
- } catch {
- case ie: InterruptedException => throw ie
- case e: Exception =>
- lastException = e
- logWarning("Error sending message to BlockManagerMaster in " + attempts + " attempts", e)
- }
- Thread.sleep(AKKA_RETRY_INTERVAL_MS)
- }
-
- throw new SparkException(
- "Error sending message to BlockManagerMaster [message = " + message + "]", lastException)
+ AkkaUtils.askWithReply(message, driverActor, AKKA_RETRY_ATTEMPTS, AKKA_RETRY_INTERVAL_MS,
+ timeout)
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index de1cc5539f..94f5a4bb2e 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -52,25 +52,24 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
private val akkaTimeout = AkkaUtils.askTimeout(conf)
- val slaveTimeout = conf.get("spark.storage.blockManagerSlaveTimeoutMs",
- "" + (BlockManager.getHeartBeatFrequency(conf) * 3)).toLong
+ val slaveTimeout = conf.getLong("spark.storage.blockManagerSlaveTimeoutMs",
+ math.max(conf.getInt("spark.executor.heartbeatInterval", 10000) * 3, 45000))
- val checkTimeoutInterval = conf.get("spark.storage.blockManagerTimeoutIntervalMs",
- "60000").toLong
+ val checkTimeoutInterval = conf.getLong("spark.storage.blockManagerTimeoutIntervalMs",
+ 60000)
var timeoutCheckingTask: Cancellable = null
override def preStart() {
- if (!BlockManager.getDisableHeartBeatsForTesting(conf)) {
- import context.dispatcher
- timeoutCheckingTask = context.system.scheduler.schedule(0.seconds,
- checkTimeoutInterval.milliseconds, self, ExpireDeadHosts)
- }
+ import context.dispatcher
+ timeoutCheckingTask = context.system.scheduler.schedule(0.seconds,
+ checkTimeoutInterval.milliseconds, self, ExpireDeadHosts)
super.preStart()
}
def receive = {
case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) =>
+ logInfo("received a register")
register(blockManagerId, maxMemSize, slaveActor)
sender ! true
@@ -129,8 +128,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
case ExpireDeadHosts =>
expireDeadHosts()
- case HeartBeat(blockManagerId) =>
- sender ! heartBeat(blockManagerId)
+ case BlockManagerHeartbeat(blockManagerId) =>
+ sender ! heartbeatReceived(blockManagerId)
case other =>
logWarning("Got unknown message: " + other)
@@ -216,7 +215,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
val minSeenTime = now - slaveTimeout
val toRemove = new mutable.HashSet[BlockManagerId]
for (info <- blockManagerInfo.values) {
- if (info.lastSeenMs < minSeenTime) {
+ if (info.lastSeenMs < minSeenTime && info.blockManagerId.executorId != "<driver>") {
logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats: "
+ (now - info.lastSeenMs) + "ms exceeds " + slaveTimeout + "ms")
toRemove += info.blockManagerId
@@ -230,7 +229,11 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
blockManagerIdByExecutor.get(execId).foreach(removeBlockManager)
}
- private def heartBeat(blockManagerId: BlockManagerId): Boolean = {
+ /**
+ * Return true if the driver knows about the given block manager. Otherwise, return false,
+ * indicating that the block manager should re-register.
+ */
+ private def heartbeatReceived(blockManagerId: BlockManagerId): Boolean = {
if (!blockManagerInfo.contains(blockManagerId)) {
blockManagerId.executorId == "<driver>" && !isLocal
} else {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index 2b53bf33b5..10b65286fb 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -21,7 +21,7 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput}
import akka.actor.ActorRef
-private[storage] object BlockManagerMessages {
+private[spark] object BlockManagerMessages {
//////////////////////////////////////////////////////////////////////////////////
// Messages from the master to slaves.
//////////////////////////////////////////////////////////////////////////////////
@@ -53,8 +53,6 @@ private[storage] object BlockManagerMessages {
sender: ActorRef)
extends ToBlockManagerMaster
- case class HeartBeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
-
class UpdateBlockInfo(
var blockManagerId: BlockManagerId,
var blockId: BlockId,
@@ -124,5 +122,7 @@ private[storage] object BlockManagerMessages {
case class GetMatchingBlockIds(filter: BlockId => Boolean, askSlaves: Boolean = true)
extends ToBlockManagerMaster
+ case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
+
case object ExpireDeadHosts extends ToBlockManagerMaster
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index efb527b4f0..da2f5d3172 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -130,32 +130,16 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
new StageUIData
})
- // create executor summary map if necessary
- val executorSummaryMap = stageData.executorSummary
- executorSummaryMap.getOrElseUpdate(key = info.executorId, op = new ExecutorSummary)
-
- executorSummaryMap.get(info.executorId).foreach { y =>
- // first update failed-task, succeed-task
- taskEnd.reason match {
- case Success =>
- y.succeededTasks += 1
- case _ =>
- y.failedTasks += 1
- }
-
- // update duration
- y.taskTime += info.duration
-
- val metrics = taskEnd.taskMetrics
- if (metrics != null) {
- metrics.inputMetrics.foreach { y.inputBytes += _.bytesRead }
- metrics.shuffleReadMetrics.foreach { y.shuffleRead += _.remoteBytesRead }
- metrics.shuffleWriteMetrics.foreach { y.shuffleWrite += _.shuffleBytesWritten }
- y.memoryBytesSpilled += metrics.memoryBytesSpilled
- y.diskBytesSpilled += metrics.diskBytesSpilled
- }
+ val execSummaryMap = stageData.executorSummary
+ val execSummary = execSummaryMap.getOrElseUpdate(info.executorId, new ExecutorSummary)
+
+ taskEnd.reason match {
+ case Success =>
+ execSummary.succeededTasks += 1
+ case _ =>
+ execSummary.failedTasks += 1
}
-
+ execSummary.taskTime += info.duration
stageData.numActiveTasks -= 1
val (errorMessage, metrics): (Option[String], Option[TaskMetrics]) =
@@ -171,28 +155,75 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
(Some(e.toErrorString), None)
}
+ if (!metrics.isEmpty) {
+ val oldMetrics = stageData.taskData.get(info.taskId).flatMap(_.taskMetrics)
+ updateAggregateMetrics(stageData, info.executorId, metrics.get, oldMetrics)
+ }
- val taskRunTime = metrics.map(_.executorRunTime).getOrElse(0L)
- stageData.executorRunTime += taskRunTime
- val inputBytes = metrics.flatMap(_.inputMetrics).map(_.bytesRead).getOrElse(0L)
- stageData.inputBytes += inputBytes
-
- val shuffleRead = metrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead).getOrElse(0L)
- stageData.shuffleReadBytes += shuffleRead
-
- val shuffleWrite =
- metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleBytesWritten).getOrElse(0L)
- stageData.shuffleWriteBytes += shuffleWrite
-
- val memoryBytesSpilled = metrics.map(_.memoryBytesSpilled).getOrElse(0L)
- stageData.memoryBytesSpilled += memoryBytesSpilled
+ val taskData = stageData.taskData.getOrElseUpdate(info.taskId, new TaskUIData(info))
+ taskData.taskInfo = info
+ taskData.taskMetrics = metrics
+ taskData.errorMessage = errorMessage
+ }
+ }
- val diskBytesSpilled = metrics.map(_.diskBytesSpilled).getOrElse(0L)
- stageData.diskBytesSpilled += diskBytesSpilled
+ /**
+ * Upon receiving new metrics for a task, updates the per-stage and per-executor-per-stage
+ * aggregate metrics by calculating deltas between the currently recorded metrics and the new
+ * metrics.
+ */
+ def updateAggregateMetrics(
+ stageData: StageUIData,
+ execId: String,
+ taskMetrics: TaskMetrics,
+ oldMetrics: Option[TaskMetrics]) {
+ val execSummary = stageData.executorSummary.getOrElseUpdate(execId, new ExecutorSummary)
+
+ val shuffleWriteDelta =
+ (taskMetrics.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L)
+ - oldMetrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleBytesWritten).getOrElse(0L))
+ stageData.shuffleWriteBytes += shuffleWriteDelta
+ execSummary.shuffleWrite += shuffleWriteDelta
+
+ val shuffleReadDelta =
+ (taskMetrics.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L)
+ - oldMetrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead).getOrElse(0L))
+ stageData.shuffleReadBytes += shuffleReadDelta
+ execSummary.shuffleRead += shuffleReadDelta
+
+ val diskSpillDelta =
+ taskMetrics.diskBytesSpilled - oldMetrics.map(_.diskBytesSpilled).getOrElse(0L)
+ stageData.diskBytesSpilled += diskSpillDelta
+ execSummary.diskBytesSpilled += diskSpillDelta
+
+ val memorySpillDelta =
+ taskMetrics.memoryBytesSpilled - oldMetrics.map(_.memoryBytesSpilled).getOrElse(0L)
+ stageData.memoryBytesSpilled += memorySpillDelta
+ execSummary.memoryBytesSpilled += memorySpillDelta
+
+ val timeDelta =
+ taskMetrics.executorRunTime - oldMetrics.map(_.executorRunTime).getOrElse(0L)
+ stageData.executorRunTime += timeDelta
+ }
- stageData.taskData(info.taskId) = new TaskUIData(info, metrics, errorMessage)
+ override def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) {
+ for ((taskId, sid, taskMetrics) <- executorMetricsUpdate.taskMetrics) {
+ val stageData = stageIdToData.getOrElseUpdate(sid, {
+ logWarning("Metrics update for task in unknown stage " + sid)
+ new StageUIData
+ })
+ val taskData = stageData.taskData.get(taskId)
+ taskData.map { t =>
+ if (!t.taskInfo.finished) {
+ updateAggregateMetrics(stageData, executorMetricsUpdate.execId, taskMetrics,
+ t.taskMetrics)
+
+ // Overwrite task metrics
+ t.taskMetrics = Some(taskMetrics)
+ }
+ }
}
- } // end of onTaskEnd
+ }
override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) {
synchronized {
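To make the delta logic in updateAggregateMetrics concrete: heartbeats carry cumulative per-task metrics, so the stage and per-executor aggregates must add only the growth since the previously recorded snapshot. A small, hypothetical worked example with invented values:

// The previous heartbeat recorded executorRunTime = 400 ms for a task; the new one reports 650 ms.
val oldRunTime = 400L                      // from oldMetrics, or 0L if no snapshot was seen yet
val newRunTime = 650L                      // from the incoming taskMetrics
val timeDelta = newRunTime - oldRunTime    // 250 ms
// stageData.executorRunTime += timeDelta  // the stage total grows by 250, not by another 650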
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
index be11a11695..2f96f7909c 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
@@ -55,8 +55,11 @@ private[jobs] object UIData {
var executorSummary = new HashMap[String, ExecutorSummary]
}
+ /**
+ * These are kept mutable and reused throughout a task's lifetime to avoid excessive reallocation.
+ */
case class TaskUIData(
- taskInfo: TaskInfo,
- taskMetrics: Option[TaskMetrics] = None,
- errorMessage: Option[String] = None)
+ var taskInfo: TaskInfo,
+ var taskMetrics: Option[TaskMetrics] = None,
+ var errorMessage: Option[String] = None)
}
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
index 9930c71749..feafd654e9 100644
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -18,13 +18,16 @@
package org.apache.spark.util
import scala.collection.JavaConversions.mapAsJavaMap
+import scala.concurrent.Await
import scala.concurrent.duration.{Duration, FiniteDuration}
-import akka.actor.{ActorSystem, ExtendedActorSystem}
+import akka.actor.{Actor, ActorRef, ActorSystem, ExtendedActorSystem}
+import akka.pattern.ask
+
import com.typesafe.config.ConfigFactory
import org.apache.log4j.{Level, Logger}
-import org.apache.spark.{Logging, SecurityManager, SparkConf}
+import org.apache.spark.{SparkException, Logging, SecurityManager, SparkConf}
/**
* Various utility classes for working with Akka.
@@ -124,4 +127,63 @@ private[spark] object AkkaUtils extends Logging {
/** Space reserved for extra data in an Akka message besides serialized task or task result. */
val reservedSizeBytes = 200 * 1024
+
+ /** Returns the configured number of times to retry connecting */
+ def numRetries(conf: SparkConf): Int = {
+ conf.getInt("spark.akka.num.retries", 3)
+ }
+
+ /** Returns the configured number of milliseconds to wait on each retry */
+ def retryWaitMs(conf: SparkConf): Int = {
+ conf.getInt("spark.akka.retry.wait", 3000)
+ }
+
+ /**
+ * Send a message to the given actor and get its result within a default timeout, or
+ * throw a SparkException if this fails.
+ */
+ def askWithReply[T](
+ message: Any,
+ actor: ActorRef,
+ retryAttempts: Int,
+ retryInterval: Int,
+ timeout: FiniteDuration): T = {
+ // TODO: Consider removing multiple attempts
+ if (actor == null) {
+ throw new SparkException("Error sending message as driverActor is null " +
+ "[message = " + message + "]")
+ }
+ var attempts = 0
+ var lastException: Exception = null
+ while (attempts < retryAttempts) {
+ attempts += 1
+ try {
+ val future = actor.ask(message)(timeout)
+ val result = Await.result(future, timeout)
+ if (result == null) {
+ throw new SparkException("Actor returned null")
+ }
+ return result.asInstanceOf[T]
+ } catch {
+ case ie: InterruptedException => throw ie
+ case e: Exception =>
+ lastException = e
+ logWarning("Error sending message in " + attempts + " attempts", e)
+ }
+ Thread.sleep(retryInterval)
+ }
+
+ throw new SparkException(
+ "Error sending message [message = " + message + "]", lastException)
+ }
+
+ def makeDriverRef(name: String, conf: SparkConf, actorSystem: ActorSystem): ActorRef = {
+ val driverHost: String = conf.get("spark.driver.host", "localhost")
+ val driverPort: Int = conf.getInt("spark.driver.port", 7077)
+ Utils.checkHost(driverHost, "Expected hostname")
+ val url = s"akka.tcp://spark@$driverHost:$driverPort/user/$name"
+ val timeout = AkkaUtils.lookupTimeout(conf)
+ logInfo(s"Connecting to $name: $url")
+ Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout)
+ }
}
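A hedged usage sketch of the two new helpers, mirroring the executor heartbeat loop added earlier in this patch rather than adding anything new; it assumes code inside the org.apache.spark package with `conf: SparkConf`, `actorSystem: ActorSystem`, `executorId`, and `blockManagerId` in scope.

import org.apache.spark.executor.TaskMetrics
import org.apache.spark.util.AkkaUtils

// Resolve the driver-side HeartbeatReceiver and ask it, with bounded retries.
val driverRef = AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, actorSystem)
val response = AkkaUtils.askWithReply[HeartbeatResponse](
  Heartbeat(executorId, Array.empty[(Long, TaskMetrics)], blockManagerId),
  driverRef,
  AkkaUtils.numRetries(conf),
  AkkaUtils.retryWaitMs(conf),
  AkkaUtils.lookupTimeout(conf))
if (response.reregisterBlockManager) {
  // The driver no longer knows this BlockManager; trigger re-registration.
}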
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
index 4b727e50db..495a0d4863 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark
-import org.scalatest.{FunSuite, PrivateMethodTester}
+import org.scalatest.{BeforeAndAfterEach, FunSuite, PrivateMethodTester}
import org.apache.spark.scheduler.{TaskScheduler, TaskSchedulerImpl}
import org.apache.spark.scheduler.cluster.{SimrSchedulerBackend, SparkDeploySchedulerBackend}
@@ -25,12 +25,12 @@ import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, Me
import org.apache.spark.scheduler.local.LocalBackend
class SparkContextSchedulerCreationSuite
- extends FunSuite with PrivateMethodTester with LocalSparkContext with Logging {
+ extends FunSuite with PrivateMethodTester with Logging with BeforeAndAfterEach {
def createTaskScheduler(master: String): TaskSchedulerImpl = {
// Create local SparkContext to setup a SparkEnv. We don't actually want to start() the
// real schedulers, so we don't want to create a full SparkContext with the desired scheduler.
- sc = new SparkContext("local", "test")
+ val sc = new SparkContext("local", "test")
val createTaskSchedulerMethod = PrivateMethod[TaskScheduler]('createTaskScheduler)
val sched = SparkContext invokePrivate createTaskSchedulerMethod(sc, master)
sched.asInstanceOf[TaskSchedulerImpl]
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 9021662bcf..0ce13d015d 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
import org.apache.spark.util.CallSite
+import org.apache.spark.executor.TaskMetrics
class BuggyDAGEventProcessActor extends Actor {
val state = 0
@@ -77,6 +78,8 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
override def schedulingMode: SchedulingMode = SchedulingMode.NONE
override def start() = {}
override def stop() = {}
+ override def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)],
+ blockManagerId: BlockManagerId): Boolean = true
override def submitTasks(taskSet: TaskSet) = {
// normally done by TaskSetManager
taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch)
@@ -342,6 +345,8 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
}
override def setDAGScheduler(dagScheduler: DAGScheduler) = {}
override def defaultParallelism() = 2
+ override def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)],
+ blockManagerId: BlockManagerId): Boolean = true
}
val noKillScheduler = new DAGScheduler(
sc,
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 58ea0cc30e..0ac0269d7c 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -19,22 +19,28 @@ package org.apache.spark.storage
import java.nio.{ByteBuffer, MappedByteBuffer}
import java.util.Arrays
+import java.util.concurrent.TimeUnit
import akka.actor._
+import akka.pattern.ask
+import akka.util.Timeout
+
import org.mockito.Mockito.{mock, when}
import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester}
import org.scalatest.concurrent.Eventually._
import org.scalatest.concurrent.Timeouts._
import org.scalatest.Matchers
-import org.scalatest.time.SpanSugar._
import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
import org.apache.spark.executor.DataReadMethod
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
+import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils}
import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.Await
+import scala.concurrent.duration._
import scala.language.implicitConversions
import scala.language.postfixOps
@@ -73,7 +79,6 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
oldArch = System.setProperty("os.arch", "amd64")
conf.set("os.arch", "amd64")
conf.set("spark.test.useCompressedOops", "true")
- conf.set("spark.storage.disableBlockManagerHeartBeat", "true")
conf.set("spark.driver.port", boundPort.toString)
conf.set("spark.storage.unrollFraction", "0.4")
conf.set("spark.storage.unrollMemoryThreshold", "512")
@@ -341,7 +346,6 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
}
test("reregistration on heart beat") {
- val heartBeat = PrivateMethod[Unit]('heartBeat)
store = makeBlockManager(2000)
val a1 = new Array[Byte](400)
@@ -353,13 +357,15 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
master.removeExecutor(store.blockManagerId.executorId)
assert(master.getLocations("a1").size == 0, "a1 was not removed from master")
- store invokePrivate heartBeat()
- assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master")
+ implicit val timeout = Timeout(30, TimeUnit.SECONDS)
+ val reregister = !Await.result(
+ master.driverActor ? BlockManagerHeartbeat(store.blockManagerId),
+ timeout.duration).asInstanceOf[Boolean]
+ assert(reregister == true)
}
test("reregistration on block update") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf,
- securityMgr, mapOutputTracker)
+ store = makeBlockManager(2000)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
@@ -377,7 +383,6 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
}
test("reregistration doesn't dead lock") {
- val heartBeat = PrivateMethod[Unit]('heartBeat)
store = makeBlockManager(2000)
val a1 = new Array[Byte](400)
val a2 = List(new Array[Byte](400))
@@ -397,7 +402,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
}
val t3 = new Thread {
override def run() {
- store invokePrivate heartBeat()
+ store.reregister()
}
}
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index 86a271eb67..cb82525152 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -21,7 +21,8 @@ import org.scalatest.FunSuite
import org.scalatest.Matchers
import org.apache.spark._
-import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics}
+import org.apache.spark.{LocalSparkContext, SparkConf, Success}
+import org.apache.spark.executor.{ShuffleWriteMetrics, ShuffleReadMetrics, TaskMetrics}
import org.apache.spark.scheduler._
import org.apache.spark.util.Utils
@@ -129,4 +130,87 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
assert(listener.stageIdToData(task.stageId).numCompleteTasks === 1)
assert(listener.stageIdToData(task.stageId).numFailedTasks === failCount)
}
+
+ test("test update metrics") {
+ val conf = new SparkConf()
+ val listener = new JobProgressListener(conf)
+
+ val taskType = Utils.getFormattedClassName(new ShuffleMapTask(0))
+ val execId = "exe-1"
+
+ def makeTaskMetrics(base: Int) = {
+ val taskMetrics = new TaskMetrics()
+ val shuffleReadMetrics = new ShuffleReadMetrics()
+ val shuffleWriteMetrics = new ShuffleWriteMetrics()
+ taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics)
+ taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics)
+ shuffleReadMetrics.remoteBytesRead = base + 1
+ shuffleReadMetrics.remoteBlocksFetched = base + 2
+ shuffleWriteMetrics.shuffleBytesWritten = base + 3
+ taskMetrics.executorRunTime = base + 4
+ taskMetrics.diskBytesSpilled = base + 5
+ taskMetrics.memoryBytesSpilled = base + 6
+ taskMetrics
+ }
+
+ def makeTaskInfo(taskId: Long, finishTime: Int = 0) = {
+ val taskInfo = new TaskInfo(taskId, 0, 1, 0L, execId, "host1", TaskLocality.NODE_LOCAL,
+ false)
+ taskInfo.finishTime = finishTime
+ taskInfo
+ }
+
+ listener.onTaskStart(SparkListenerTaskStart(0, makeTaskInfo(1234L)))
+ listener.onTaskStart(SparkListenerTaskStart(0, makeTaskInfo(1235L)))
+ listener.onTaskStart(SparkListenerTaskStart(1, makeTaskInfo(1236L)))
+ listener.onTaskStart(SparkListenerTaskStart(1, makeTaskInfo(1237L)))
+
+ listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate(execId, Array(
+ (1234L, 0, makeTaskMetrics(0)),
+ (1235L, 0, makeTaskMetrics(100)),
+ (1236L, 1, makeTaskMetrics(200)))))
+
+ var stage0Data = listener.stageIdToData.get(0).get
+ var stage1Data = listener.stageIdToData.get(1).get
+ assert(stage0Data.shuffleReadBytes == 102)
+ assert(stage1Data.shuffleReadBytes == 201)
+ assert(stage0Data.shuffleWriteBytes == 106)
+ assert(stage1Data.shuffleWriteBytes == 203)
+ assert(stage0Data.executorRunTime == 108)
+ assert(stage1Data.executorRunTime == 204)
+ assert(stage0Data.diskBytesSpilled == 110)
+ assert(stage1Data.diskBytesSpilled == 205)
+ assert(stage0Data.memoryBytesSpilled == 112)
+ assert(stage1Data.memoryBytesSpilled == 206)
+ assert(stage0Data.taskData.get(1234L).get.taskMetrics.get.shuffleReadMetrics.get
+ .totalBlocksFetched == 2)
+ assert(stage0Data.taskData.get(1235L).get.taskMetrics.get.shuffleReadMetrics.get
+ .totalBlocksFetched == 102)
+ assert(stage1Data.taskData.get(1236L).get.taskMetrics.get.shuffleReadMetrics.get
+ .totalBlocksFetched == 202)
+
+ // task that was included in a heartbeat
+ listener.onTaskEnd(SparkListenerTaskEnd(0, taskType, Success, makeTaskInfo(1234L, 1),
+ makeTaskMetrics(300)))
+ // task that wasn't included in a heartbeat
+ listener.onTaskEnd(SparkListenerTaskEnd(1, taskType, Success, makeTaskInfo(1237L, 1),
+ makeTaskMetrics(400)))
+
+ stage0Data = listener.stageIdToData.get(0).get
+ stage1Data = listener.stageIdToData.get(1).get
+ assert(stage0Data.shuffleReadBytes == 402)
+ assert(stage1Data.shuffleReadBytes == 602)
+ assert(stage0Data.shuffleWriteBytes == 406)
+ assert(stage1Data.shuffleWriteBytes == 606)
+ assert(stage0Data.executorRunTime == 408)
+ assert(stage1Data.executorRunTime == 608)
+ assert(stage0Data.diskBytesSpilled == 410)
+ assert(stage1Data.diskBytesSpilled == 610)
+ assert(stage0Data.memoryBytesSpilled == 412)
+ assert(stage1Data.memoryBytesSpilled == 612)
+ assert(stage0Data.taskData.get(1234L).get.taskMetrics.get.shuffleReadMetrics.get
+ .totalBlocksFetched == 302)
+ assert(stage1Data.taskData.get(1237L).get.taskMetrics.get.shuffleReadMetrics.get
+ .totalBlocksFetched == 402)
+ }
}