diff options
10 files changed, 130 insertions, 8 deletions
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index dda194d953..4cef0825dd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -68,6 +68,11 @@ class DAGScheduler( eventQueue.put(BeginEvent(task, taskInfo)) } + // Called to report that a task has completed and results are being fetched remotely. + def taskGettingResult(task: Task[_], taskInfo: TaskInfo) { + eventQueue.put(GettingResultEvent(task, taskInfo)) + } + // Called by TaskScheduler to report task completions or failures. def taskEnded( task: Task[_], @@ -415,6 +420,9 @@ class DAGScheduler( case begin: BeginEvent => listenerBus.post(SparkListenerTaskStart(begin.task, begin.taskInfo)) + case gettingResult: GettingResultEvent => + listenerBus.post(SparkListenerTaskGettingResult(gettingResult.task, gettingResult.taskInfo)) + case completion: CompletionEvent => listenerBus.post(SparkListenerTaskEnd( completion.task, completion.reason, completion.taskInfo, completion.taskMetrics)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index a5769c6041..708d221d60 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -53,6 +53,9 @@ private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent private[scheduler] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent +private[scheduler] +case class GettingResultEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent + private[scheduler] case class CompletionEvent( task: Task[_], reason: TaskEndReason, diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 324cd639b0..a35081f7b1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -31,6 +31,9 @@ case class StageCompleted(val stage: StageInfo) extends SparkListenerEvents case class SparkListenerTaskStart(task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents +case class SparkListenerTaskGettingResult( + task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents + case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo, taskMetrics: TaskMetrics) extends SparkListenerEvents @@ -57,6 +60,12 @@ trait SparkListener { def onTaskStart(taskStart: SparkListenerTaskStart) { } /** + * Called when a task begins remotely fetching its result (will not be called for tasks that do + * not need to fetch the result remotely). + */ + def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { } + + /** * Called when a task ends */ def onTaskEnd(taskEnd: SparkListenerTaskEnd) { } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 4d3e4a17ba..d5824e7954 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -49,6 +49,8 @@ private[spark] class SparkListenerBus() extends Logging { sparkListeners.foreach(_.onJobEnd(jobEnd))
case taskStart: SparkListenerTaskStart =>
sparkListeners.foreach(_.onTaskStart(taskStart))
+ case taskGettingResult: SparkListenerTaskGettingResult =>
+ sparkListeners.foreach(_.onTaskGettingResult(taskGettingResult))
case taskEnd: SparkListenerTaskEnd =>
sparkListeners.foreach(_.onTaskEnd(taskEnd))
case _ =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 7c2a422aff..4bae26f3a6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -31,9 +31,25 @@ class TaskInfo( val host: String, val taskLocality: TaskLocality.TaskLocality) { + /** + * The time when the task started remotely getting the result. Will not be set if the + * task result was sent immediately when the task finished (as opposed to sending an + * IndirectTaskResult and later fetching the result from the block manager). + */ + var gettingResultTime: Long = 0 + + /** + * The time when the task has completed successfully (including the time to remotely fetch + * results, if necessary). + */ var finishTime: Long = 0 + var failed = false + def markGettingResult(time: Long = System.currentTimeMillis) { + gettingResultTime = time + } + def markSuccessful(time: Long = System.currentTimeMillis) { finishTime = time } @@ -43,6 +59,8 @@ class TaskInfo( failed = true } + def gettingResult: Boolean = gettingResultTime != 0 + def finished: Boolean = finishTime != 0 def successful: Boolean = finished && !failed @@ -52,6 +70,8 @@ class TaskInfo( def status: String = { if (running) "RUNNING" + else if (gettingResult) + "GET RESULT" else if (failed) "FAILED" else if (successful) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala index 4ea8bf8853..85033958ef 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala @@ -306,6 +306,10 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } + def handleTaskGettingResult(taskSetManager: ClusterTaskSetManager, tid: Long) { + taskSetManager.handleTaskGettingResult(tid) + } + def handleSuccessfulTask( taskSetManager: ClusterTaskSetManager, tid: Long, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala index 29093e3b4f..ee47aaffca 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala @@ -418,6 +418,12 @@ private[spark] class ClusterTaskSetManager( sched.dagScheduler.taskStarted(task, info) } + def handleTaskGettingResult(tid: Long) = { + val info = taskInfos(tid) + info.markGettingResult() + sched.dagScheduler.taskGettingResult(tasks(info.index), info) + } + /** * Marks the task as successful and notifies the DAGScheduler that a task has ended. */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala index 4312c46cc1..2064d97b49 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala @@ -50,6 +50,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche case directResult: DirectTaskResult[_] => directResult case IndirectTaskResult(blockId) => logDebug("Fetching indirect task result for TID %s".format(tid)) + scheduler.handleTaskGettingResult(taskSetManager, tid) val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId) if (!serializedTaskResult.isDefined) { /* We won't be able to get the task result if the machine that ran the task failed diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 9bb8a13ec4..6b854740d6 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -115,7 +115,13 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList taskList += ((taskStart.taskInfo, None, None)) stageIdToTaskInfos(sid) = taskList } - + + override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) + = synchronized { + // Do nothing: because we don't do a deep copy of the TaskInfo, the TaskInfo in + // stageToTaskInfos already has the updated status. + } + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { val sid = taskEnd.task.stageId val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]()) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 42ca988f7a..f7f599532a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -17,22 +17,25 @@ package org.apache.spark.scheduler -import org.scalatest.{BeforeAndAfter, FunSuite} -import org.apache.spark.{LocalSparkContext, SparkContext} -import scala.collection.mutable +import scala.collection.mutable.{Buffer, HashSet} + +import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.scalatest.matchers.ShouldMatchers + +import org.apache.spark.{LocalSparkContext, SparkContext} import org.apache.spark.SparkContext._ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers - with BeforeAndAfter { + with BeforeAndAfterAll { /** Length of time to wait while draining listener events. */ val WAIT_TIMEOUT_MILLIS = 10000 - before { - sc = new SparkContext("local", "DAGSchedulerSuite") + override def afterAll { + System.clearProperty("spark.akka.frameSize") } test("basic creation of StageInfo") { + sc = new SparkContext("local", "DAGSchedulerSuite") val listener = new SaveStageInfo sc.addSparkListener(listener) val rdd1 = sc.parallelize(1 to 100, 4) @@ -53,6 +56,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc } test("StageInfo with fewer tasks than partitions") { + sc = new SparkContext("local", "DAGSchedulerSuite") val listener = new SaveStageInfo sc.addSparkListener(listener) val rdd1 = sc.parallelize(1 to 100, 4) @@ -68,6 +72,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc } test("local metrics") { + sc = new SparkContext("local", "DAGSchedulerSuite") val listener = new SaveStageInfo sc.addSparkListener(listener) sc.addSparkListener(new StatsReportListener) @@ -129,15 +134,73 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc } } + test("onTaskGettingResult() called when result fetched remotely") { + // Need to use local cluster mode here, because results are not ever returned through the + // block manager when using the LocalScheduler. + sc = new SparkContext("local-cluster[1,1,512]", "test") + + val listener = new SaveTaskEvents + sc.addSparkListener(listener) + + // Make a task whose result is larger than the akka frame size + System.setProperty("spark.akka.frameSize", "1") + val akkaFrameSize = + sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt + val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x,y) => x) + assert(result === 1.to(akkaFrameSize).toArray) + + assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + val TASK_INDEX = 0 + assert(listener.startedTasks.contains(TASK_INDEX)) + assert(listener.startedGettingResultTasks.contains(TASK_INDEX)) + assert(listener.endedTasks.contains(TASK_INDEX)) + } + + test("onTaskGettingResult() not called when result sent directly") { + // Need to use local cluster mode here, because results are not ever returned through the + // block manager when using the LocalScheduler. + sc = new SparkContext("local-cluster[1,1,512]", "test") + + val listener = new SaveTaskEvents + sc.addSparkListener(listener) + + // Make a task whose result is larger than the akka frame size + val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x) + assert(result === 2) + + assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + val TASK_INDEX = 0 + assert(listener.startedTasks.contains(TASK_INDEX)) + assert(listener.startedGettingResultTasks.isEmpty == true) + assert(listener.endedTasks.contains(TASK_INDEX)) + } + def checkNonZeroAvg(m: Traversable[Long], msg: String) { assert(m.sum / m.size.toDouble > 0.0, msg) } class SaveStageInfo extends SparkListener { - val stageInfos = mutable.Buffer[StageInfo]() + val stageInfos = Buffer[StageInfo]() override def onStageCompleted(stage: StageCompleted) { stageInfos += stage.stage } } + class SaveTaskEvents extends SparkListener { + val startedTasks = new HashSet[Int]() + val startedGettingResultTasks = new HashSet[Int]() + val endedTasks = new HashSet[Int]() + + override def onTaskStart(taskStart: SparkListenerTaskStart) { + startedTasks += taskStart.taskInfo.index + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + endedTasks += taskEnd.taskInfo.index + } + + override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { + startedGettingResultTasks += taskGettingResult.taskInfo.index + } + } } |