-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala                  |  8
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala             |  3
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala                 |  9
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala              |  2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala                      | 20
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala      |  4
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala |  6
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala      |  1
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala             |  8
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala            | 77
10 files changed, 130 insertions(+), 8 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index dda194d953..4cef0825dd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -68,6 +68,11 @@ class DAGScheduler(
eventQueue.put(BeginEvent(task, taskInfo))
}
+ // Called to report that a task has completed and results are being fetched remotely.
+ def taskGettingResult(task: Task[_], taskInfo: TaskInfo) {
+ eventQueue.put(GettingResultEvent(task, taskInfo))
+ }
+
// Called by TaskScheduler to report task completions or failures.
def taskEnded(
task: Task[_],
@@ -415,6 +420,9 @@ class DAGScheduler(
case begin: BeginEvent =>
listenerBus.post(SparkListenerTaskStart(begin.task, begin.taskInfo))
+ case gettingResult: GettingResultEvent =>
+ listenerBus.post(SparkListenerTaskGettingResult(gettingResult.task, gettingResult.taskInfo))
+
case completion: CompletionEvent =>
listenerBus.post(SparkListenerTaskEnd(
completion.task, completion.reason, completion.taskInfo, completion.taskMetrics))
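
The hunks above follow the DAGScheduler's existing producer/consumer pattern: the thread that notices the remote fetch enqueues a GettingResultEvent, and the single event-processing loop later drains the queue and republishes the event on the listener bus. A minimal sketch of that pattern, with illustrative names rather than Spark's:

import java.util.concurrent.LinkedBlockingQueue

sealed trait SchedulerEvent
case class GettingResult(taskId: Long) extends SchedulerEvent

object EventLoopSketch {
  private val queue = new LinkedBlockingQueue[SchedulerEvent]()

  // Producer side: called from the thread that starts the remote fetch.
  def taskGettingResult(taskId: Long): Unit = queue.put(GettingResult(taskId))

  // Consumer side: a single loop drains the queue and fans events out,
  // mirroring how DAGScheduler forwards events to the listener bus.
  def run(post: SchedulerEvent => Unit): Unit =
    while (true) queue.take() match {
      case e: GettingResult => post(e)
    }
}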
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index a5769c6041..708d221d60 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -53,6 +53,9 @@ private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent
private[scheduler]
case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
+private[scheduler]
+case class GettingResultEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
+
private[scheduler] case class CompletionEvent(
task: Task[_],
reason: TaskEndReason,
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 324cd639b0..a35081f7b1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -31,6 +31,9 @@ case class StageCompleted(val stage: StageInfo) extends SparkListenerEvents
case class SparkListenerTaskStart(task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents
+case class SparkListenerTaskGettingResult(
+ task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents
+
case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo,
taskMetrics: TaskMetrics) extends SparkListenerEvents
@@ -57,6 +60,12 @@ trait SparkListener {
def onTaskStart(taskStart: SparkListenerTaskStart) { }
/**
+ * Called when a task begins remotely fetching its result (will not be called for tasks that do
+ * not need to fetch the result remotely).
+ */
+ def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { }
+
+ /**
* Called when a task ends
*/
def onTaskEnd(taskEnd: SparkListenerTaskEnd) { }
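
Because onTaskGettingResult ships with an empty default body, existing SparkListener implementations keep compiling unchanged; a listener that cares about the new event simply overrides it. A minimal user-side sketch, assuming an already-created SparkContext sc (the FetchLogger name and its counter are illustrative):

import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskGettingResult}

class FetchLogger extends SparkListener {
  @volatile var remoteFetches = 0
  override def onTaskGettingResult(event: SparkListenerTaskGettingResult) {
    remoteFetches += 1
    println("Task %d is remotely fetching its result".format(event.taskInfo.index))
  }
}

// sc.addSparkListener(new FetchLogger)  // registered like any other listener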
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
index 4d3e4a17ba..d5824e7954 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
@@ -49,6 +49,8 @@ private[spark] class SparkListenerBus() extends Logging {
sparkListeners.foreach(_.onJobEnd(jobEnd))
case taskStart: SparkListenerTaskStart =>
sparkListeners.foreach(_.onTaskStart(taskStart))
+ case taskGettingResult: SparkListenerTaskGettingResult =>
+ sparkListeners.foreach(_.onTaskGettingResult(taskGettingResult))
case taskEnd: SparkListenerTaskEnd =>
sparkListeners.foreach(_.onTaskEnd(taskEnd))
case _ =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
index 7c2a422aff..4bae26f3a6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
@@ -31,9 +31,25 @@ class TaskInfo(
val host: String,
val taskLocality: TaskLocality.TaskLocality) {
+ /**
+ * The time when the task started remotely getting the result. Will not be set if the
+ * task result was sent immediately when the task finished (as opposed to sending an
+ * IndirectTaskResult and later fetching the result from the block manager).
+ */
+ var gettingResultTime: Long = 0
+
+ /**
+ * The time when the task completed successfully (including the time to remotely fetch
+ * results, if necessary).
+ */
var finishTime: Long = 0
+
var failed = false
+ def markGettingResult(time: Long = System.currentTimeMillis) {
+ gettingResultTime = time
+ }
+
def markSuccessful(time: Long = System.currentTimeMillis) {
finishTime = time
}
@@ -43,6 +59,8 @@ class TaskInfo(
failed = true
}
+ def gettingResult: Boolean = gettingResultTime != 0
+
def finished: Boolean = finishTime != 0
def successful: Boolean = finished && !failed
@@ -52,6 +70,8 @@ class TaskInfo(
def status: String = {
if (running)
"RUNNING"
+ else if (gettingResult)
+ "GET RESULT"
else if (failed)
"FAILED"
else if (successful)
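
With both timestamps now recorded on TaskInfo, the time a task spent in the GET RESULT state can be derived after the fact. A small sketch built on the fields added above (resultFetchTime is a hypothetical helper, not part of this patch):

import org.apache.spark.scheduler.TaskInfo

// Milliseconds spent fetching the result, or None if the result was sent
// directly (gettingResultTime stays 0) or the task has not finished yet.
def resultFetchTime(info: TaskInfo): Option[Long] =
  if (info.gettingResult && info.finished) Some(info.finishTime - info.gettingResultTime)
  else None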
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
index 4ea8bf8853..85033958ef 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
@@ -306,6 +306,10 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
+ def handleTaskGettingResult(taskSetManager: ClusterTaskSetManager, tid: Long) {
+ taskSetManager.handleTaskGettingResult(tid)
+ }
+
def handleSuccessfulTask(
taskSetManager: ClusterTaskSetManager,
tid: Long,
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
index 29093e3b4f..ee47aaffca 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
@@ -418,6 +418,12 @@ private[spark] class ClusterTaskSetManager(
sched.dagScheduler.taskStarted(task, info)
}
+ def handleTaskGettingResult(tid: Long) = {
+ val info = taskInfos(tid)
+ info.markGettingResult()
+ sched.dagScheduler.taskGettingResult(tasks(info.index), info)
+ }
+
/**
* Marks the task as successful and notifies the DAGScheduler that a task has ended.
*/
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
index 4312c46cc1..2064d97b49 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
@@ -50,6 +50,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche
case directResult: DirectTaskResult[_] => directResult
case IndirectTaskResult(blockId) =>
logDebug("Fetching indirect task result for TID %s".format(tid))
+ scheduler.handleTaskGettingResult(taskSetManager, tid)
val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
if (!serializedTaskResult.isDefined) {
/* We won't be able to get the task result if the machine that ran the task failed
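
This hook fires only on the IndirectTaskResult branch: when an executor's serialized result exceeds the Akka frame size, it stores the bytes in the block manager and ships back only a small handle, which the getter then fetches remotely. A simplified sketch of that executor-side decision (the types and the store callback are stand-ins, not Spark's API):

sealed trait ResultEnvelope
case class Direct(bytes: Array[Byte]) extends ResultEnvelope  // sent inline over Akka
case class Indirect(blockId: String) extends ResultEnvelope   // fetched later from the block manager

def wrapResult(bytes: Array[Byte], frameSize: Int)(store: Array[Byte] => String): ResultEnvelope =
  if (bytes.length > frameSize) Indirect(store(bytes)) else Direct(bytes)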
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index 9bb8a13ec4..6b854740d6 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -115,7 +115,13 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
taskList += ((taskStart.taskInfo, None, None))
stageIdToTaskInfos(sid) = taskList
}
-
+
+ override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult)
+ = synchronized {
+ // Do nothing: because we don't do a deep copy of the TaskInfo, the TaskInfo in
+ // stageIdToTaskInfos already has the updated status.
+ }
+
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
val sid = taskEnd.task.stageId
val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
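
The empty override above is safe because the listener stored a reference to the scheduler's TaskInfo rather than a copy, so the mutation made by markGettingResult is already visible when the UI renders. A tiny illustration of that aliasing (illustrative class, not Spark's):

class Status { var s = "RUNNING" }

val shared  = new Status
val uiTable = Map(0 -> shared)        // what the listener captured at task start
shared.s = "GET RESULT"               // what markGettingResult effectively does
assert(uiTable(0).s == "GET RESULT")  // the UI row reflects the new state for free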
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 42ca988f7a..f7f599532a 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -17,22 +17,25 @@
package org.apache.spark.scheduler
-import org.scalatest.{BeforeAndAfter, FunSuite}
-import org.apache.spark.{LocalSparkContext, SparkContext}
-import scala.collection.mutable
+import scala.collection.mutable.{Buffer, HashSet}
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.scalatest.matchers.ShouldMatchers
+
+import org.apache.spark.{LocalSparkContext, SparkContext}
import org.apache.spark.SparkContext._
class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
- with BeforeAndAfter {
+ with BeforeAndAfterAll {
/** Length of time to wait while draining listener events. */
val WAIT_TIMEOUT_MILLIS = 10000
- before {
- sc = new SparkContext("local", "DAGSchedulerSuite")
+ override def afterAll {
+ System.clearProperty("spark.akka.frameSize")
}
test("basic creation of StageInfo") {
+ sc = new SparkContext("local", "DAGSchedulerSuite")
val listener = new SaveStageInfo
sc.addSparkListener(listener)
val rdd1 = sc.parallelize(1 to 100, 4)
@@ -53,6 +56,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
}
test("StageInfo with fewer tasks than partitions") {
+ sc = new SparkContext("local", "DAGSchedulerSuite")
val listener = new SaveStageInfo
sc.addSparkListener(listener)
val rdd1 = sc.parallelize(1 to 100, 4)
@@ -68,6 +72,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
}
test("local metrics") {
+ sc = new SparkContext("local", "DAGSchedulerSuite")
val listener = new SaveStageInfo
sc.addSparkListener(listener)
sc.addSparkListener(new StatsReportListener)
@@ -129,15 +134,73 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
}
}
+ test("onTaskGettingResult() called when result fetched remotely") {
+ // Need to use local cluster mode here, because results are never returned through the
+ // block manager when using the LocalScheduler.
+ sc = new SparkContext("local-cluster[1,1,512]", "test")
+
+ val listener = new SaveTaskEvents
+ sc.addSparkListener(listener)
+
+ // Make a task whose result is larger than the akka frame size
+ System.setProperty("spark.akka.frameSize", "1")
+ val akkaFrameSize =
+ sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt
+ val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x,y) => x)
+ assert(result === 1.to(akkaFrameSize).toArray)
+
+ assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+ val TASK_INDEX = 0
+ assert(listener.startedTasks.contains(TASK_INDEX))
+ assert(listener.startedGettingResultTasks.contains(TASK_INDEX))
+ assert(listener.endedTasks.contains(TASK_INDEX))
+ }
+
+ test("onTaskGettingResult() not called when result sent directly") {
+ // Need to use local cluster mode here, because results are never returned through the
+ // block manager when using the LocalScheduler.
+ sc = new SparkContext("local-cluster[1,1,512]", "test")
+
+ val listener = new SaveTaskEvents
+ sc.addSparkListener(listener)
+
+ // Make a task whose result is smaller than the akka frame size, so it is sent back directly
+ val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x)
+ assert(result === 2)
+
+ assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+ val TASK_INDEX = 0
+ assert(listener.startedTasks.contains(TASK_INDEX))
+ assert(listener.startedGettingResultTasks.isEmpty)
+ assert(listener.endedTasks.contains(TASK_INDEX))
+ }
+
def checkNonZeroAvg(m: Traversable[Long], msg: String) {
assert(m.sum / m.size.toDouble > 0.0, msg)
}
class SaveStageInfo extends SparkListener {
- val stageInfos = mutable.Buffer[StageInfo]()
+ val stageInfos = Buffer[StageInfo]()
override def onStageCompleted(stage: StageCompleted) {
stageInfos += stage.stage
}
}
+ class SaveTaskEvents extends SparkListener {
+ val startedTasks = new HashSet[Int]()
+ val startedGettingResultTasks = new HashSet[Int]()
+ val endedTasks = new HashSet[Int]()
+
+ override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ startedTasks += taskStart.taskInfo.index
+ }
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ endedTasks += taskEnd.taskInfo.index
+ }
+
+ override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) {
+ startedGettingResultTasks += taskGettingResult.taskInfo.index
+ }
+ }
}
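
Assuming the sbt launcher Spark shipped with at the time, the suite should be runnable on its own with something like:

sbt/sbt "test-only org.apache.spark.scheduler.SparkListenerSuite"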