Diffstat (limited to 'core')
-rw-r--r--  core/pom.xml                                                        |  16
-rw-r--r--  core/src/main/scala/spark/Dependency.scala                          |  14
-rw-r--r--  core/src/main/scala/spark/MapOutputTracker.scala                    |   2
-rw-r--r--  core/src/main/scala/spark/PairRDDFunctions.scala                    |   4
-rw-r--r--  core/src/main/scala/spark/RDD.scala                                 |  43
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala                        |  54
-rw-r--r--  core/src/main/scala/spark/Utils.scala                               |   8
-rw-r--r--  core/src/main/scala/spark/api/python/PythonRDD.scala                |  39
-rw-r--r--  core/src/main/scala/spark/deploy/LocalSparkCluster.scala            |  38
-rw-r--r--  core/src/main/scala/spark/deploy/client/Client.scala                |  13
-rw-r--r--  core/src/main/scala/spark/deploy/master/Master.scala                |  24
-rw-r--r--  core/src/main/scala/spark/deploy/master/MasterWebUI.scala           |  22
-rw-r--r--  core/src/main/scala/spark/deploy/worker/Worker.scala                |  58
-rw-r--r--  core/src/main/scala/spark/partial/ApproximateActionListener.scala   |   4
-rw-r--r--  core/src/main/scala/spark/rdd/PartitionPruningRDD.scala             |  30
-rw-r--r--  core/src/main/scala/spark/scheduler/DAGScheduler.scala              | 283
-rw-r--r--  core/src/main/scala/spark/scheduler/JobResult.scala                 |   2
-rw-r--r--  core/src/main/scala/spark/scheduler/JobWaiter.scala                 |  14
-rw-r--r--  core/src/main/scala/spark/scheduler/ShuffleMapTask.scala            |   7
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala  |   2
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala    |  19
-rw-r--r--  core/src/main/scala/spark/scheduler/local/LocalScheduler.scala      |   4
-rw-r--r--  core/src/main/scala/spark/storage/BlockManager.scala                |   1
-rw-r--r--  core/src/main/scala/spark/storage/BlockManagerMaster.scala          |   2
-rw-r--r--  core/src/main/scala/spark/util/AkkaUtils.scala                      |   7
-rw-r--r--  core/src/main/scala/spark/util/MetadataCleaner.scala                |  10
-rw-r--r--  core/src/test/scala/spark/RDDSuite.scala                            |  13
-rw-r--r--  core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala         | 663
28 files changed, 1085 insertions(+), 311 deletions(-)
diff --git a/core/pom.xml b/core/pom.xml
index 862d3ec37a..66c62151fe 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -99,6 +99,11 @@
<scope>test</scope>
</dependency>
<dependency>
+ <groupId>org.easymock</groupId>
+ <artifactId>easymock</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>com.novocode</groupId>
<artifactId>junit-interface</artifactId>
<scope>test</scope>
@@ -163,11 +168,6 @@
<profiles>
<profile>
<id>hadoop1</id>
- <activation>
- <property>
- <name>!hadoopVersion</name>
- </property>
- </activation>
<dependencies>
<dependency>
<groupId>org.apache.hadoop</groupId>
@@ -220,12 +220,6 @@
</profile>
<profile>
<id>hadoop2</id>
- <activation>
- <property>
- <name>hadoopVersion</name>
- <value>2</value>
- </property>
- </activation>
<dependencies>
<dependency>
<groupId>org.apache.hadoop</groupId>
diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala
index 647aee6eb5..5eea907322 100644
--- a/core/src/main/scala/spark/Dependency.scala
+++ b/core/src/main/scala/spark/Dependency.scala
@@ -61,17 +61,3 @@ class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int)
}
}
}
-
-
-/**
- * Represents a dependency between the PartitionPruningRDD and its parent. In this
- * case, the child RDD contains a subset of partitions of the parents'.
- */
-class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean)
- extends NarrowDependency[T](rdd) {
-
- @transient
- val partitions: Array[Split] = rdd.splits.filter(s => partitionFilterFunc(s.index))
-
- override def getParents(partitionId: Int) = List(partitions(partitionId).index)
-}
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index aaf433b324..4735207585 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -170,7 +170,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolea
}
}
- def cleanup(cleanupTime: Long) {
+ private def cleanup(cleanupTime: Long) {
mapStatuses.clearOldValues(cleanupTime)
cachedSerializedStatuses.clearOldValues(cleanupTime)
}
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 231e23a7de..cc3cca2571 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -465,7 +465,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
val res = self.context.runJob(self, process _, Array(index), false)
res(0)
case None =>
- self.filter(_._1 == key).map(_._2).collect
+ self.filter(_._1 == key).map(_._2).collect()
}
}
@@ -590,7 +590,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
var count = 0
while(iter.hasNext) {
- val record = iter.next
+ val record = iter.next()
count += 1
writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
}
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 210404d540..9d6ea782bd 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -385,20 +385,22 @@ abstract class RDD[T: ClassManifest](
val reducePartition: Iterator[T] => Option[T] = iter => {
if (iter.hasNext) {
Some(iter.reduceLeft(cleanF))
- }else {
+ } else {
None
}
}
- val options = sc.runJob(this, reducePartition)
- val results = new ArrayBuffer[T]
- for (opt <- options; elem <- opt) {
- results += elem
- }
- if (results.size == 0) {
- throw new UnsupportedOperationException("empty collection")
- } else {
- return results.reduceLeft(cleanF)
+ var jobResult: Option[T] = None
+ val mergeResult = (index: Int, taskResult: Option[T]) => {
+ if (taskResult != None) {
+ jobResult = jobResult match {
+ case Some(value) => Some(f(value, taskResult.get))
+ case None => taskResult
+ }
+ }
}
+ sc.runJob(this, reducePartition, mergeResult)
+ // Get the final result out of our Option, or throw an exception if the RDD was empty
+ jobResult.getOrElse(throw new UnsupportedOperationException("empty collection"))
}
/**
@@ -408,9 +410,13 @@ abstract class RDD[T: ClassManifest](
* modify t2.
*/
def fold(zeroValue: T)(op: (T, T) => T): T = {
+ // Clone the zero value since we will also be serializing it as part of tasks
+ var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
val cleanOp = sc.clean(op)
- val results = sc.runJob(this, (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp))
- return results.fold(zeroValue)(cleanOp)
+ val foldPartition = (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp)
+ val mergeResult = (index: Int, taskResult: T) => jobResult = op(jobResult, taskResult)
+ sc.runJob(this, foldPartition, mergeResult)
+ jobResult
}
/**
@@ -422,11 +428,14 @@ abstract class RDD[T: ClassManifest](
* allocation.
*/
def aggregate[U: ClassManifest](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = {
+ // Clone the zero value since we will also be serializing it as part of tasks
+ var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
val cleanSeqOp = sc.clean(seqOp)
val cleanCombOp = sc.clean(combOp)
- val results = sc.runJob(this,
- (iter: Iterator[T]) => iter.aggregate(zeroValue)(cleanSeqOp, cleanCombOp))
- return results.fold(zeroValue)(cleanCombOp)
+ val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
+ val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
+ sc.runJob(this, aggregatePartition, mergeResult)
+ jobResult
}
/**
@@ -437,7 +446,7 @@ abstract class RDD[T: ClassManifest](
var result = 0L
while (iter.hasNext) {
result += 1L
- iter.next
+ iter.next()
}
result
}).sum
@@ -452,7 +461,7 @@ abstract class RDD[T: ClassManifest](
var result = 0L
while (iter.hasNext) {
result += 1L
- iter.next
+ iter.next()
}
result
}
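
The rewritten reduce(), fold() and aggregate() above no longer collect one value per partition into an array and fold it afterwards; each partition returns a partial result and a driver-side mergeResult callback combines results as they arrive. Below is a minimal, self-contained sketch of that pattern in plain Scala (no Spark dependency; the three iterators stand in for partitions and are invented for illustration):

    object MergeOnDriverSketch {
      def main(args: Array[String]): Unit = {
        // Pretend these are the iterators produced by three RDD partitions.
        val partitions: Seq[Iterator[Int]] =
          Seq(Iterator(1, 2, 3), Iterator.empty, Iterator(10, 20))

        val f = (a: Int, b: Int) => a + b

        // Per-partition work, as in the new reduce(): None for an empty partition.
        val reducePartition: Iterator[Int] => Option[Int] =
          iter => if (iter.hasNext) Some(iter.reduceLeft(f)) else None

        // Driver-side merge, as in mergeResult in the diff above.
        var jobResult: Option[Int] = None
        val mergeResult = (index: Int, taskResult: Option[Int]) => {
          if (taskResult.isDefined) {
            jobResult = jobResult match {
              case Some(value) => Some(f(value, taskResult.get))
              case None        => taskResult
            }
          }
        }

        // Stand-in for sc.runJob(this, reducePartition, mergeResult).
        partitions.zipWithIndex.foreach { case (iter, i) => mergeResult(i, reducePartition(iter)) }

        // 1+2+3 + 10+20 = 36; an all-empty input would throw, as in RDD.reduce.
        println(jobResult.getOrElse(throw new UnsupportedOperationException("empty collection")))
      }
    }
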
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index d0eefd5229..b95049da0c 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -188,6 +188,7 @@ class SparkContext(
taskScheduler.start()
private var dagScheduler = new DAGScheduler(taskScheduler)
+ dagScheduler.start()
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
val hadoopConfiguration = {
@@ -559,27 +560,43 @@ class SparkContext(
}
/**
- * Run a function on a given set of partitions in an RDD and return the results. This is the main
- * entry point to the scheduler, by which all actions get launched. The allowLocal flag specifies
- * whether the scheduler can run the computation on the driver rather than shipping it out to the
- * cluster, for short actions like first().
+ * Run a function on a given set of partitions in an RDD and pass the results to the given
+ * handler function. This is the main entry point for all actions in Spark. The allowLocal
+ * flag specifies whether the scheduler can run the computation on the driver rather than
+ * shipping it out to the cluster, for short actions like first().
*/
def runJob[T, U: ClassManifest](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
- allowLocal: Boolean
- ): Array[U] = {
+ allowLocal: Boolean,
+ resultHandler: (Int, U) => Unit) {
val callSite = Utils.getSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
- val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal)
+ val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
rdd.doCheckpoint()
result
}
/**
+ * Run a function on a given set of partitions in an RDD and return the results as an array. The
+ * allowLocal flag specifies whether the scheduler can run the computation on the driver rather
+ * than shipping it out to the cluster, for short actions like first().
+ */
+ def runJob[T, U: ClassManifest](
+ rdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ partitions: Seq[Int],
+ allowLocal: Boolean
+ ): Array[U] = {
+ val results = new Array[U](partitions.size)
+ runJob[T, U](rdd, func, partitions, allowLocal, (index, res) => results(index) = res)
+ results
+ }
+
+ /**
* Run a job on a given set of partitions of an RDD, but take a function of type
* `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`.
*/
@@ -607,6 +624,29 @@ class SparkContext(
}
/**
+ * Run a job on all partitions in an RDD and pass the results to a handler function.
+ */
+ def runJob[T, U: ClassManifest](
+ rdd: RDD[T],
+ processPartition: (TaskContext, Iterator[T]) => U,
+ resultHandler: (Int, U) => Unit)
+ {
+ runJob[T, U](rdd, processPartition, 0 until rdd.splits.size, false, resultHandler)
+ }
+
+ /**
+ * Run a job on all partitions in an RDD and pass the results to a handler function.
+ */
+ def runJob[T, U: ClassManifest](
+ rdd: RDD[T],
+ processPartition: Iterator[T] => U,
+ resultHandler: (Int, U) => Unit)
+ {
+ val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter)
+ runJob[T, U](rdd, processFunc, 0 until rdd.splits.size, false, resultHandler)
+ }
+
+ /**
* Run a job that can return approximate results.
*/
def runApproximateJob[T, U, R](
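
A short usage sketch of the new resultHandler-based entry point, in spark-shell / REPL style. It assumes a live SparkContext named sc and the runJob(rdd, processPartition, resultHandler) overload added above; the RDD contents are invented:

    // Count records per partition on the executors and merge the counts on the
    // driver as they arrive, without materializing an Array[U] first.
    val lines = sc.parallelize(Seq("a", "b", "c", "d"), 3)

    var total = 0L
    sc.runJob(
      lines,
      (iter: Iterator[String]) => iter.size.toLong,   // runs in the tasks
      (index: Int, count: Long) => total += count     // runs on the driver per finished partition
    )
    println("total records: " + total)
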
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index 1e58d01273..28d643abca 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -12,6 +12,7 @@ import scala.io.Source
import com.google.common.io.Files
import com.google.common.util.concurrent.ThreadFactoryBuilder
import scala.Some
+import spark.serializer.SerializerInstance
/**
* Various utility methods used by Spark.
@@ -446,4 +447,11 @@ private object Utils extends Logging {
socket.close()
portBound
}
+
+ /**
+ * Clone an object using a Spark serializer.
+ */
+ def clone[T](value: T, serializer: SerializerInstance): T = {
+ serializer.deserialize[T](serializer.serialize(value))
+ }
}
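
Utils.clone is what fold() and aggregate() use so the driver-side accumulator is a private copy of the mutable zero value rather than the instance that also gets serialized into tasks. Utils is private[spark], so the REPL-style sketch below only illustrates how Spark's own code calls it; it assumes a SparkEnv is set on the current thread:

    import scala.collection.mutable.ArrayBuffer

    val zeroValue = ArrayBuffer[Int]()
    val ser = SparkEnv.get.closureSerializer.newInstance()  // as in RDD.fold above
    val jobResult = Utils.clone(zeroValue, ser)             // deep copy via serialize/deserialize

    jobResult += 42               // mutating the clone...
    assert(zeroValue.isEmpty)     // ...leaves the original zero value untouched
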
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index f43a152ca7..39758e94f4 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -103,21 +103,27 @@ private[spark] class PythonRDD[T: ClassManifest](
private def read(): Array[Byte] = {
try {
- val length = stream.readInt()
- if (length != -1) {
- val obj = new Array[Byte](length)
- stream.readFully(obj)
- obj
- } else {
- // We've finished the data section of the output, but we can still read some
- // accumulator updates; let's do that, breaking when we get EOFException
- while (true) {
- val len2 = stream.readInt()
- val update = new Array[Byte](len2)
- stream.readFully(update)
- accumulator += Collections.singletonList(update)
- }
- new Array[Byte](0)
+ stream.readInt() match {
+ case length if length > 0 =>
+ val obj = new Array[Byte](length)
+ stream.readFully(obj)
+ obj
+ case -2 =>
+ // Signals that an exception has been thrown in python
+ val exLength = stream.readInt()
+ val obj = new Array[Byte](exLength)
+ stream.readFully(obj)
+ throw new PythonException(new String(obj))
+ case -1 =>
+ // We've finished the data section of the output, but we can still read some
+ // accumulator updates; let's do that, breaking when we get EOFException
+ while (true) {
+ val len2 = stream.readInt()
+ val update = new Array[Byte](len2)
+ stream.readFully(update)
+ accumulator += Collections.singletonList(update)
+ }
+ new Array[Byte](0)
}
} catch {
case eof: EOFException => {
@@ -140,6 +146,9 @@ private[spark] class PythonRDD[T: ClassManifest](
val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
}
+/** Thrown for exceptions in user Python code. */
+private class PythonException(msg: String) extends Exception(msg)
+
/**
* Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python.
* This is used by PySpark's shuffle operations.
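
The Python worker protocol is a sequence of length-prefixed byte frames: a positive length announces a data frame, -1 ends the data section (accumulator updates follow), and with this change -2 announces an error frame carrying the Python traceback, which the JVM rethrows as a PythonException. The self-contained sketch below mirrors that framing with java.io only; the class and method names in it are stand-ins, not Spark API:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, IOException}

    object FrameProtocolSketch {
      final class UserCodeException(msg: String) extends Exception(msg)

      /** Return a data frame, null at end-of-data (-1), or throw for an error frame (-2). */
      def readFrame(in: DataInputStream): Array[Byte] = in.readInt() match {
        case length if length > 0 =>
          val obj = new Array[Byte](length)
          in.readFully(obj)
          obj
        case -2 =>
          val err = new Array[Byte](in.readInt())
          in.readFully(err)
          throw new UserCodeException(new String(err, "utf-8"))
        case -1 =>
          null  // end of data; PythonRDD goes on to read accumulator updates here
        case other =>
          throw new IOException("unexpected frame length: " + other)
      }

      def main(args: Array[String]): Unit = {
        // Build a tiny stream: one data frame followed by an error frame.
        val bytes = new ByteArrayOutputStream()
        val out = new DataOutputStream(bytes)
        val payload = "hello".getBytes("utf-8")
        out.writeInt(payload.length); out.write(payload)
        val err = "Traceback: boom".getBytes("utf-8")
        out.writeInt(-2); out.writeInt(err.length); out.write(err)

        val in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray))
        println(new String(readFrame(in), "utf-8"))                          // hello
        try readFrame(in) catch { case e: UserCodeException => println("caught: " + e.getMessage) }
      }
    }
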
diff --git a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
index 2836574ecb..22319a96ca 100644
--- a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
+++ b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
@@ -18,35 +18,23 @@ import scala.collection.mutable.ArrayBuffer
private[spark]
class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging {
- val localIpAddress = Utils.localIpAddress
+ private val localIpAddress = Utils.localIpAddress
+ private val masterActorSystems = ArrayBuffer[ActorSystem]()
+ private val workerActorSystems = ArrayBuffer[ActorSystem]()
- var masterActor : ActorRef = _
- var masterActorSystem : ActorSystem = _
- var masterPort : Int = _
- var masterUrl : String = _
-
- val workerActorSystems = ArrayBuffer[ActorSystem]()
- val workerActors = ArrayBuffer[ActorRef]()
-
- def start() : String = {
+ def start(): String = {
logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")
/* Start the Master */
- val (actorSystem, masterPort) = AkkaUtils.createActorSystem("sparkMaster", localIpAddress, 0)
- masterActorSystem = actorSystem
- masterUrl = "spark://" + localIpAddress + ":" + masterPort
- masterActor = masterActorSystem.actorOf(
- Props(new Master(localIpAddress, masterPort, 0)), name = "Master")
+ val (masterSystem, masterPort) = Master.startSystemAndActor(localIpAddress, 0, 0)
+ masterActorSystems += masterSystem
+ val masterUrl = "spark://" + localIpAddress + ":" + masterPort
- /* Start the Slaves */
+ /* Start the Workers */
for (workerNum <- 1 to numWorkers) {
- val (actorSystem, boundPort) =
- AkkaUtils.createActorSystem("sparkWorker" + workerNum, localIpAddress, 0)
- workerActorSystems += actorSystem
- val actor = actorSystem.actorOf(
- Props(new Worker(localIpAddress, boundPort, 0, coresPerWorker, memoryPerWorker, masterUrl)),
- name = "Worker")
- workerActors += actor
+ val (workerSystem, _) = Worker.startSystemAndActor(localIpAddress, 0, 0, coresPerWorker,
+ memoryPerWorker, masterUrl, null, Some(workerNum))
+ workerActorSystems += workerSystem
}
return masterUrl
@@ -57,7 +45,7 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I
// Stop the workers before the master so they don't get upset that it disconnected
workerActorSystems.foreach(_.shutdown())
workerActorSystems.foreach(_.awaitTermination())
- masterActorSystem.shutdown()
- masterActorSystem.awaitTermination()
+ masterActorSystems.foreach(_.shutdown())
+ masterActorSystems.foreach(_.awaitTermination())
}
}
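
After this refactor LocalSparkCluster simply delegates to Master.startSystemAndActor and Worker.startSystemAndActor and keeps only the actor systems it must shut down. LocalSparkCluster is private[spark] (users normally reach it through a local-cluster master URL), so the REPL-style sketch below only shows how Spark-internal code drives it; the worker count, cores and memory figures are invented:

    // Spin up an in-process standalone cluster: 2 workers, 1 core and 512 MB each.
    val cluster = new LocalSparkCluster(numWorkers = 2, coresPerWorker = 1, memoryPerWorker = 512)
    val masterUrl = cluster.start()              // e.g. "spark://<local-ip>:<port>"

    val sc = new SparkContext(masterUrl, "local-cluster-sketch")
    try {
      println(sc.parallelize(1 to 100, 4).count())
    } finally {
      sc.stop()
      cluster.stop()                             // workers first, then the master
    }
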
diff --git a/core/src/main/scala/spark/deploy/client/Client.scala b/core/src/main/scala/spark/deploy/client/Client.scala
index 90fe9508cd..a63eee1233 100644
--- a/core/src/main/scala/spark/deploy/client/Client.scala
+++ b/core/src/main/scala/spark/deploy/client/Client.scala
@@ -9,6 +9,7 @@ import spark.{SparkException, Logging}
import akka.remote.RemoteClientLifeCycleEvent
import akka.remote.RemoteClientShutdown
import spark.deploy.RegisterJob
+import spark.deploy.master.Master
import akka.remote.RemoteClientDisconnected
import akka.actor.Terminated
import akka.dispatch.Await
@@ -24,26 +25,18 @@ private[spark] class Client(
listener: ClientListener)
extends Logging {
- val MASTER_REGEX = "spark://([^:]+):([0-9]+)".r
-
var actor: ActorRef = null
var jobId: String = null
- if (MASTER_REGEX.unapplySeq(masterUrl) == None) {
- throw new SparkException("Invalid master URL: " + masterUrl)
- }
-
class ClientActor extends Actor with Logging {
var master: ActorRef = null
var masterAddress: Address = null
var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times
override def preStart() {
- val Seq(masterHost, masterPort) = MASTER_REGEX.unapplySeq(masterUrl).get
- logInfo("Connecting to master spark://" + masterHost + ":" + masterPort)
- val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort)
+ logInfo("Connecting to master " + masterUrl)
try {
- master = context.actorFor(akkaUrl)
+ master = context.actorFor(Master.toAkkaUrl(masterUrl))
masterAddress = master.path.address
master ! RegisterJob(jobDescription)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala
index c618e87cdd..92e7914b1b 100644
--- a/core/src/main/scala/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/spark/deploy/master/Master.scala
@@ -262,11 +262,29 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
}
private[spark] object Master {
+ private val systemName = "sparkMaster"
+ private val actorName = "Master"
+ private val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r
+
def main(argStrings: Array[String]) {
val args = new MasterArguments(argStrings)
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port)
- val actor = actorSystem.actorOf(
- Props(new Master(args.ip, boundPort, args.webUiPort)), name = "Master")
+ val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort)
actorSystem.awaitTermination()
}
+
+ /** Returns an `akka://...` URL for the Master actor given a sparkUrl `spark://host:ip`. */
+ def toAkkaUrl(sparkUrl: String): String = {
+ sparkUrl match {
+ case sparkUrlRegex(host, port) =>
+ "akka://%s@%s:%s/user/%s".format(systemName, host, port, actorName)
+ case _ =>
+ throw new SparkException("Invalid master URL: " + sparkUrl)
+ }
+ }
+
+ def startSystemAndActor(host: String, port: Int, webUiPort: Int): (ActorSystem, Int) = {
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
+ val actor = actorSystem.actorOf(Props(new Master(host, boundPort, webUiPort)), name = actorName)
+ (actorSystem, boundPort)
+ }
}
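
Master.toAkkaUrl replaces the MASTER_REGEX parsing that Client and Worker each duplicated. A small script-style sketch of the translation it performs (regex and format string copied from the diff; host and port invented, and IllegalArgumentException stands in for SparkException to keep it self-contained):

    val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r

    def toAkkaUrl(sparkUrl: String): String = sparkUrl match {
      case sparkUrlRegex(host, port) =>
        "akka://%s@%s:%s/user/%s".format("sparkMaster", host, port, "Master")
      case _ =>
        throw new IllegalArgumentException("Invalid master URL: " + sparkUrl)
    }

    println(toAkkaUrl("spark://host1:7077"))   // akka://sparkMaster@host1:7077/user/Master
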
diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
index a01774f511..529f72e9da 100644
--- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
+++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
@@ -45,13 +45,9 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
case (jobId, Some(js)) if (js.equalsIgnoreCase("json")) =>
val future = master ? RequestMasterState
val jobInfo = for (masterState <- future.mapTo[MasterState]) yield {
- masterState.activeJobs.find(_.id == jobId) match {
- case Some(job) => job
- case _ => masterState.completedJobs.find(_.id == jobId) match {
- case Some(job) => job
- case _ => null
- }
- }
+ masterState.activeJobs.find(_.id == jobId).getOrElse({
+ masterState.completedJobs.find(_.id == jobId).getOrElse(null)
+ })
}
respondWithMediaType(MediaTypes.`application/json`) { ctx =>
ctx.complete(jobInfo.mapTo[JobInfo])
@@ -61,14 +57,10 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
val future = master ? RequestMasterState
future.map { state =>
val masterState = state.asInstanceOf[MasterState]
-
- masterState.activeJobs.find(_.id == jobId) match {
- case Some(job) => spark.deploy.master.html.job_details.render(job)
- case _ => masterState.completedJobs.find(_.id == jobId) match {
- case Some(job) => spark.deploy.master.html.job_details.render(job)
- case _ => null
- }
- }
+ val job = masterState.activeJobs.find(_.id == jobId).getOrElse({
+ masterState.completedJobs.find(_.id == jobId).getOrElse(null)
+ })
+ spark.deploy.master.html.job_details.render(job)
}
}
}
diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala
index 8b41620d98..2219dd6262 100644
--- a/core/src/main/scala/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/spark/deploy/worker/Worker.scala
@@ -1,7 +1,7 @@
package spark.deploy.worker
import scala.collection.mutable.{ArrayBuffer, HashMap}
-import akka.actor.{ActorRef, Props, Actor}
+import akka.actor.{ActorRef, Props, Actor, ActorSystem}
import spark.{Logging, Utils}
import spark.util.AkkaUtils
import spark.deploy._
@@ -13,6 +13,7 @@ import akka.remote.RemoteClientDisconnected
import spark.deploy.RegisterWorker
import spark.deploy.LaunchExecutor
import spark.deploy.RegisterWorkerFailed
+import spark.deploy.master.Master
import akka.actor.Terminated
import java.io.File
@@ -27,7 +28,6 @@ private[spark] class Worker(
extends Actor with Logging {
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs
- val MASTER_REGEX = "spark://([^:]+):([0-9]+)".r
var master: ActorRef = null
var masterWebUiUrl : String = ""
@@ -48,11 +48,7 @@ private[spark] class Worker(
def memoryFree: Int = memory - memoryUsed
def createWorkDir() {
- workDir = if (workDirPath != null) {
- new File(workDirPath)
- } else {
- new File(sparkHome, "work")
- }
+ workDir = Option(workDirPath).map(new File(_)).getOrElse(new File(sparkHome, "work"))
try {
if (!workDir.exists() && !workDir.mkdirs()) {
logError("Failed to create work directory " + workDir)
@@ -68,8 +64,7 @@ private[spark] class Worker(
override def preStart() {
logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format(
ip, port, cores, Utils.memoryMegabytesToString(memory)))
- val envVar = System.getenv("SPARK_HOME")
- sparkHome = new File(if (envVar == null) "." else envVar)
+ sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse("."))
logInfo("Spark home: " + sparkHome)
createWorkDir()
connectToMaster()
@@ -77,24 +72,15 @@ private[spark] class Worker(
}
def connectToMaster() {
- masterUrl match {
- case MASTER_REGEX(masterHost, masterPort) => {
- logInfo("Connecting to master spark://" + masterHost + ":" + masterPort)
- val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort)
- try {
- master = context.actorFor(akkaUrl)
- master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress)
- context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
- context.watch(master) // Doesn't work with remote actors, but useful for testing
- } catch {
- case e: Exception =>
- logError("Failed to connect to master", e)
- System.exit(1)
- }
- }
-
- case _ =>
- logError("Invalid master URL: " + masterUrl)
+ logInfo("Connecting to master " + masterUrl)
+ try {
+ master = context.actorFor(Master.toAkkaUrl(masterUrl))
+ master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress)
+ context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
+ context.watch(master) // Doesn't work with remote actors, but useful for testing
+ } catch {
+ case e: Exception =>
+ logError("Failed to connect to master", e)
System.exit(1)
}
}
@@ -183,11 +169,19 @@ private[spark] class Worker(
private[spark] object Worker {
def main(argStrings: Array[String]) {
val args = new WorkerArguments(argStrings)
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port)
- val actor = actorSystem.actorOf(
- Props(new Worker(args.ip, boundPort, args.webUiPort, args.cores, args.memory,
- args.master, args.workDir)),
- name = "Worker")
+ val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort, args.cores,
+ args.memory, args.master, args.workDir)
actorSystem.awaitTermination()
}
+
+ def startSystemAndActor(host: String, port: Int, webUiPort: Int, cores: Int, memory: Int,
+ masterUrl: String, workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = {
+ // The LocalSparkCluster runs multiple local sparkWorkerX actor systems
+ val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("")
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
+ val actor = actorSystem.actorOf(Props(new Worker(host, boundPort, webUiPort, cores, memory,
+ masterUrl, workDir)), name = "Worker")
+ (actorSystem, boundPort)
+ }
+
}
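
The optional workerNumber exists because LocalSparkCluster starts several worker actor systems inside one JVM, and each Akka system needs a distinct name (see the AkkaUtils note later in this diff about names mattering for message delivery). A tiny script-style sketch of the naming rule, copied from startSystemAndActor:

    def workerSystemName(workerNumber: Option[Int]): String =
      "sparkWorker" + workerNumber.map(_.toString).getOrElse("")

    assert(workerSystemName(Some(2)) == "sparkWorker2")   // under LocalSparkCluster
    assert(workerSystemName(None) == "sparkWorker")       // standalone worker
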
diff --git a/core/src/main/scala/spark/partial/ApproximateActionListener.scala b/core/src/main/scala/spark/partial/ApproximateActionListener.scala
index 42f46e06ed..24b4909380 100644
--- a/core/src/main/scala/spark/partial/ApproximateActionListener.scala
+++ b/core/src/main/scala/spark/partial/ApproximateActionListener.scala
@@ -32,7 +32,7 @@ private[spark] class ApproximateActionListener[T, U, R](
if (finishedTasks == totalTasks) {
// If we had already returned a PartialResult, set its final value
resultObject.foreach(r => r.setFinalValue(evaluator.currentResult()))
- // Notify any waiting thread that may have called getResult
+ // Notify any waiting thread that may have called awaitResult
this.notifyAll()
}
}
@@ -49,7 +49,7 @@ private[spark] class ApproximateActionListener[T, U, R](
* Waits for up to timeout milliseconds since the listener was created and then returns a
* PartialResult with the result so far. This may be complete if the whole job is done.
*/
- def getResult(): PartialResult[R] = synchronized {
+ def awaitResult(): PartialResult[R] = synchronized {
val finishTime = startTime + timeout
while (true) {
val time = System.currentTimeMillis()
diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala
index b8482338c6..a50ce75171 100644
--- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala
+++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala
@@ -1,24 +1,42 @@
package spark.rdd
-import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext}
+import spark.{NarrowDependency, RDD, SparkEnv, Split, TaskContext}
+
+
+class PartitionPruningRDDSplit(idx: Int, val parentSplit: Split) extends Split {
+ override val index = idx
+}
+
+
+/**
+ * Represents a dependency between the PartitionPruningRDD and its parent. In this
+ * case, the child RDD contains a subset of partitions of the parents'.
+ */
+class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean)
+ extends NarrowDependency[T](rdd) {
+
+ @transient
+ val partitions: Array[Split] = rdd.splits.filter(s => partitionFilterFunc(s.index))
+ .zipWithIndex.map { case(split, idx) => new PartitionPruningRDDSplit(idx, split) : Split }
+
+ override def getParents(partitionId: Int) = List(partitions(partitionId).index)
+}
+
/**
* A RDD used to prune RDD partitions/splits so we can avoid launching tasks on
* all partitions. An example use case: If we know the RDD is partitioned by range,
* and the execution DAG has a filter on the key, we can avoid launching tasks
* on partitions that don't have the range covering the key.
- *
- * TODO: This currently doesn't give partition IDs properly!
*/
class PartitionPruningRDD[T: ClassManifest](
@transient prev: RDD[T],
@transient partitionFilterFunc: Int => Boolean)
extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) {
- override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context)
+ override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(
+ split.asInstanceOf[PartitionPruningRDDSplit].parentSplit, context)
override protected def getSplits =
getDependencies.head.asInstanceOf[PruneDependency[T]].partitions
-
- override val partitioner = firstParent[T].partitioner
}
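
Because PruneDependency now wraps each surviving parent split in a PartitionPruningRDDSplit with a dense index, the pruned RDD's partition IDs start at 0 again (resolving the old TODO) and compute() maps each one back to its parent split. A REPL-style usage sketch, assuming a live SparkContext named sc; the filter and partition counts are invented:

    // Keep only the parent partitions with an even index: parent indices
    // 0, 2, 4, ... become 0, 1, 2, ... in the pruned RDD.
    val parent = sc.parallelize(1 to 100, 10)
    val pruned = new PartitionPruningRDD(parent, partitionIndex => partitionIndex % 2 == 0)

    println(pruned.splits.size)   // 5 under these assumptions
    println(pruned.count())       // only the kept partitions are ever computed
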
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index b130be6a38..319eef6978 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -23,7 +23,16 @@ import util.{MetadataCleaner, TimeStampedHashMap}
* and to report fetch failures (the submitTasks method, and code to add CompletionEvents).
*/
private[spark]
-class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with Logging {
+class DAGScheduler(
+ taskSched: TaskScheduler,
+ mapOutputTracker: MapOutputTracker,
+ blockManagerMaster: BlockManagerMaster,
+ env: SparkEnv)
+ extends TaskSchedulerListener with Logging {
+
+ def this(taskSched: TaskScheduler) {
+ this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get)
+ }
taskSched.setListener(this)
// Called by TaskScheduler to report task completions or failures.
@@ -66,10 +75,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
var cacheLocs = new HashMap[Int, Array[List[String]]]
- val env = SparkEnv.get
- val mapOutputTracker = env.mapOutputTracker
- val blockManagerMaster = env.blockManager.master
-
// For tracking failed nodes, we use the MapOutputTracker's generation number, which is
// sent with every task. When we detect a node failing, we note the current generation number
// and failed executor, increment it for new tasks, and use this to ignore stray ShuffleMapTask
@@ -90,14 +95,16 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup)
// Start a thread to run the DAGScheduler event loop
- new Thread("DAGScheduler") {
- setDaemon(true)
- override def run() {
- DAGScheduler.this.run()
- }
- }.start()
+ def start() {
+ new Thread("DAGScheduler") {
+ setDaemon(true)
+ override def run() {
+ DAGScheduler.this.run()
+ }
+ }.start()
+ }
- def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
+ private def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
if (!cacheLocs.contains(rdd.id)) {
val blockIds = rdd.splits.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map {
@@ -107,7 +114,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
cacheLocs(rdd.id)
}
- def clearCacheLocs() {
+ private def clearCacheLocs() {
cacheLocs.clear()
}
@@ -116,7 +123,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* The priority value passed in will be used if the stage doesn't already exist with
* a lower priority (we assume that priorities always increase across jobs for now).
*/
- def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = {
+ private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = {
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
@@ -131,11 +138,11 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* as a result stage for the final RDD used directly in an action. The stage will also be given
* the provided priority.
*/
- def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = {
- // Kind of ugly: need to register RDDs with the cache and map output tracker here
- // since we can't do it in the RDD constructor because # of splits is unknown
- logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
+ private def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = {
if (shuffleDep != None) {
+ // Kind of ugly: need to register RDDs with the cache and map output tracker here
+ // since we can't do it in the RDD constructor because # of splits is unknown
+ logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
}
val id = nextStageId.getAndIncrement()
@@ -148,7 +155,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* Get or create the list of parent stages for a given RDD. The stages will be assigned the
* provided priority if they haven't already been created with a lower priority.
*/
- def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = {
+ private def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = {
val parents = new HashSet[Stage]
val visited = new HashSet[RDD[_]]
def visit(r: RDD[_]) {
@@ -170,25 +177,22 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
parents.toList
}
- def getMissingParentStages(stage: Stage): List[Stage] = {
+ private def getMissingParentStages(stage: Stage): List[Stage] = {
val missing = new HashSet[Stage]
val visited = new HashSet[RDD[_]]
def visit(rdd: RDD[_]) {
if (!visited(rdd)) {
visited += rdd
- val locs = getCacheLocs(rdd)
- for (p <- 0 until rdd.splits.size) {
- if (locs(p) == Nil) {
- for (dep <- rdd.dependencies) {
- dep match {
- case shufDep: ShuffleDependency[_,_] =>
- val mapStage = getShuffleMapStage(shufDep, stage.priority)
- if (!mapStage.isAvailable) {
- missing += mapStage
- }
- case narrowDep: NarrowDependency[_] =>
- visit(narrowDep.rdd)
- }
+ if (getCacheLocs(rdd).contains(Nil)) {
+ for (dep <- rdd.dependencies) {
+ dep match {
+ case shufDep: ShuffleDependency[_,_] =>
+ val mapStage = getShuffleMapStage(shufDep, stage.priority)
+ if (!mapStage.isAvailable) {
+ missing += mapStage
+ }
+ case narrowDep: NarrowDependency[_] =>
+ visit(narrowDep.rdd)
}
}
}
@@ -198,23 +202,45 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
missing.toList
}
+ /**
+ * Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a
+ * JobWaiter whose getResult() method will return the result of the job when it is complete.
+ *
+ * The job is assumed to have at least one partition; zero partition jobs should be handled
+ * without a JobSubmitted event.
+ */
+ private[scheduler] def prepareJob[T, U: ClassManifest](
+ finalRdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ partitions: Seq[Int],
+ callSite: String,
+ allowLocal: Boolean,
+ resultHandler: (Int, U) => Unit)
+ : (JobSubmitted, JobWaiter[U]) =
+ {
+ assert(partitions.size > 0)
+ val waiter = new JobWaiter(partitions.size, resultHandler)
+ val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
+ val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter)
+ return (toSubmit, waiter)
+ }
+
def runJob[T, U: ClassManifest](
finalRdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: String,
- allowLocal: Boolean)
- : Array[U] =
+ allowLocal: Boolean,
+ resultHandler: (Int, U) => Unit)
{
if (partitions.size == 0) {
- return new Array[U](0)
+ return
}
- val waiter = new JobWaiter(partitions.size)
- val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
- eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter))
- waiter.getResult() match {
- case JobSucceeded(results: Seq[_]) =>
- return results.asInstanceOf[Seq[U]].toArray
+ val (toSubmit, waiter) = prepareJob(
+ finalRdd, func, partitions, callSite, allowLocal, resultHandler)
+ eventQueue.put(toSubmit)
+ waiter.awaitResult() match {
+ case JobSucceeded => {}
case JobFailed(exception: Exception) =>
logInfo("Failed to run " + callSite)
throw exception
@@ -233,90 +259,117 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.splits.size).toArray
eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener))
- return listener.getResult() // Will throw an exception if the job fails
+ return listener.awaitResult() // Will throw an exception if the job fails
+ }
+
+ /**
+ * Process one event retrieved from the event queue.
+ * Returns true if we should stop the event loop.
+ */
+ private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
+ event match {
+ case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) =>
+ val runId = nextRunId.getAndIncrement()
+ val finalStage = newStage(finalRDD, None, runId)
+ val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
+ clearCacheLocs()
+ logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
+ " output partitions (allowLocal=" + allowLocal + ")")
+ logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")")
+ logInfo("Parents of final stage: " + finalStage.parents)
+ logInfo("Missing parents: " + getMissingParentStages(finalStage))
+ if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
+ // Compute very short actions like first() or take() with no parent stages locally.
+ runLocally(job)
+ } else {
+ activeJobs += job
+ resultStageToJob(finalStage) = job
+ submitStage(finalStage)
+ }
+
+ case ExecutorLost(execId) =>
+ handleExecutorLost(execId)
+
+ case completion: CompletionEvent =>
+ handleTaskCompletion(completion)
+
+ case TaskSetFailed(taskSet, reason) =>
+ abortStage(idToStage(taskSet.stageId), reason)
+
+ case StopDAGScheduler =>
+ // Cancel any active jobs
+ for (job <- activeJobs) {
+ val error = new SparkException("Job cancelled because SparkContext was shut down")
+ job.listener.jobFailed(error)
+ }
+ return true
+ }
+ return false
}
/**
+ * Resubmit any failed stages. Ordinarily called after a small amount of time has passed since
+ * the last fetch failure.
+ */
+ private[scheduler] def resubmitFailedStages() {
+ logInfo("Resubmitting failed stages")
+ clearCacheLocs()
+ val failed2 = failed.toArray
+ failed.clear()
+ for (stage <- failed2.sortBy(_.priority)) {
+ submitStage(stage)
+ }
+ }
+
+ /**
+ * Check for waiting or failed stages which are now eligible for resubmission.
+ * Ordinarily run on every iteration of the event loop.
+ */
+ private[scheduler] def submitWaitingStages() {
+ // TODO: We might want to run this less often, when we are sure that something has become
+ // runnable that wasn't before.
+ logTrace("Checking for newly runnable parent stages")
+ logTrace("running: " + running)
+ logTrace("waiting: " + waiting)
+ logTrace("failed: " + failed)
+ val waiting2 = waiting.toArray
+ waiting.clear()
+ for (stage <- waiting2.sortBy(_.priority)) {
+ submitStage(stage)
+ }
+ }
+
+
+ /**
* The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
* events and responds by launching tasks. This runs in a dedicated thread and receives events
* via the eventQueue.
*/
- def run() {
+ private def run() {
SparkEnv.set(env)
while (true) {
val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
- val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
if (event != null) {
logDebug("Got event of type " + event.getClass.getName)
}
- event match {
- case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) =>
- val runId = nextRunId.getAndIncrement()
- val finalStage = newStage(finalRDD, None, runId)
- val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
- clearCacheLocs()
- logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
- " output partitions")
- logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")")
- logInfo("Parents of final stage: " + finalStage.parents)
- logInfo("Missing parents: " + getMissingParentStages(finalStage))
- if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
- // Compute very short actions like first() or take() with no parent stages locally.
- runLocally(job)
- } else {
- activeJobs += job
- resultStageToJob(finalStage) = job
- submitStage(finalStage)
- }
-
- case ExecutorLost(execId) =>
- handleExecutorLost(execId)
-
- case completion: CompletionEvent =>
- handleTaskCompletion(completion)
-
- case TaskSetFailed(taskSet, reason) =>
- abortStage(idToStage(taskSet.stageId), reason)
-
- case StopDAGScheduler =>
- // Cancel any active jobs
- for (job <- activeJobs) {
- val error = new SparkException("Job cancelled because SparkContext was shut down")
- job.listener.jobFailed(error)
- }
+ if (event != null) {
+ if (processEvent(event)) {
return
-
- case null =>
- // queue.poll() timed out, ignore it
+ }
}
+ val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
// Periodically resubmit failed stages if some map output fetches have failed and we have
// waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails,
// tasks on many other nodes are bound to get a fetch failure, and they won't all get it at
// the same time, so we want to make sure we've identified all the reduce tasks that depend
// on the failed node.
if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
- logInfo("Resubmitting failed stages")
- clearCacheLocs()
- val failed2 = failed.toArray
- failed.clear()
- for (stage <- failed2.sortBy(_.priority)) {
- submitStage(stage)
- }
+ resubmitFailedStages()
} else {
- // TODO: We might want to run this less often, when we are sure that something has become
- // runnable that wasn't before.
- logTrace("Checking for newly runnable parent stages")
- logTrace("running: " + running)
- logTrace("waiting: " + waiting)
- logTrace("failed: " + failed)
- val waiting2 = waiting.toArray
- waiting.clear()
- for (stage <- waiting2.sortBy(_.priority)) {
- submitStage(stage)
- }
+ submitWaitingStages()
}
}
}
@@ -326,7 +379,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* We run the operation in a separate thread just in case it takes a bunch of time, so that we
* don't block the DAGScheduler event loop or other concurrent jobs.
*/
- def runLocally(job: ActiveJob) {
+ private def runLocally(job: ActiveJob) {
logInfo("Computing the requested partition locally")
new Thread("Local computation of job " + job.runId) {
override def run() {
@@ -349,13 +402,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}.start()
}
- def submitStage(stage: Stage) {
+ /** Submits stage, but first recursively submits any missing parents. */
+ private def submitStage(stage: Stage) {
logDebug("submitStage(" + stage + ")")
if (!waiting(stage) && !running(stage) && !failed(stage)) {
val missing = getMissingParentStages(stage).sortBy(_.id)
logDebug("missing: " + missing)
if (missing == Nil) {
- logInfo("Submitting " + stage + " (" + stage.origin + "), which has no missing parents")
+ logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
submitMissingTasks(stage)
running += stage
} else {
@@ -367,7 +421,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}
}
- def submitMissingTasks(stage: Stage) {
+ /** Called when stage's parents are available and we can now do its task. */
+ private def submitMissingTasks(stage: Stage) {
logDebug("submitMissingTasks(" + stage + ")")
// Get our pending tasks and remember them in our pendingTasks entry
val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
@@ -388,7 +443,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}
}
if (tasks.size > 0) {
- logInfo("Submitting " + tasks.size + " missing tasks from " + stage)
+ logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
myPending ++= tasks
logDebug("New pending tasks: " + myPending)
taskSched.submitTasks(
@@ -407,7 +462,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* Responds to a task finishing. This is called inside the event loop so it assumes that it can
* modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
*/
- def handleTaskCompletion(event: CompletionEvent) {
+ private def handleTaskCompletion(event: CompletionEvent) {
val task = event.task
val stage = idToStage(task.stageId)
@@ -492,7 +547,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
waiting --= newlyRunnable
running ++= newlyRunnable
for (stage <- newlyRunnable.sortBy(_.id)) {
- logInfo("Submitting " + stage + " (" + stage.origin + "), which is now runnable")
+ logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable")
submitMissingTasks(stage)
}
}
@@ -541,12 +596,12 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* Optionally the generation during which the failure was caught can be passed to avoid allowing
* stray fetch failures from possibly retriggering the detection of a node as lost.
*/
- def handleExecutorLost(execId: String, maybeGeneration: Option[Long] = None) {
+ private def handleExecutorLost(execId: String, maybeGeneration: Option[Long] = None) {
val currentGeneration = maybeGeneration.getOrElse(mapOutputTracker.getGeneration)
if (!failedGeneration.contains(execId) || failedGeneration(execId) < currentGeneration) {
failedGeneration(execId) = currentGeneration
logInfo("Executor lost: %s (generation %d)".format(execId, currentGeneration))
- env.blockManager.master.removeExecutor(execId)
+ blockManagerMaster.removeExecutor(execId)
// TODO: This will be really slow if we keep accumulating shuffle map stages
for ((shuffleId, stage) <- shuffleToMapStage) {
stage.removeOutputsOnExecutor(execId)
@@ -567,7 +622,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
* being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
*/
- def abortStage(failedStage: Stage, reason: String) {
+ private def abortStage(failedStage: Stage, reason: String) {
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
@@ -583,7 +638,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
/**
* Return true if one of stage's ancestors is target.
*/
- def stageDependsOn(stage: Stage, target: Stage): Boolean = {
+ private def stageDependsOn(stage: Stage, target: Stage): Boolean = {
if (stage == target) {
return true
}
@@ -610,7 +665,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
visitedRdds.contains(target.rdd)
}
- def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
+ private def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
// If the partition is cached, return the cache locations
val cached = getCacheLocs(rdd)(partition)
if (cached != Nil) {
@@ -636,7 +691,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
return Nil
}
- def cleanup(cleanupTime: Long) {
+ private def cleanup(cleanupTime: Long) {
var sizeBefore = idToStage.size
idToStage.clearOldValues(cleanupTime)
logInfo("idToStage " + sizeBefore + " --> " + idToStage.size)
diff --git a/core/src/main/scala/spark/scheduler/JobResult.scala b/core/src/main/scala/spark/scheduler/JobResult.scala
index c4a74e526f..654131ee84 100644
--- a/core/src/main/scala/spark/scheduler/JobResult.scala
+++ b/core/src/main/scala/spark/scheduler/JobResult.scala
@@ -5,5 +5,5 @@ package spark.scheduler
*/
private[spark] sealed trait JobResult
-private[spark] case class JobSucceeded(results: Seq[_]) extends JobResult
+private[spark] case object JobSucceeded extends JobResult
private[spark] case class JobFailed(exception: Exception) extends JobResult
diff --git a/core/src/main/scala/spark/scheduler/JobWaiter.scala b/core/src/main/scala/spark/scheduler/JobWaiter.scala
index b3d4feebe5..3cc6a86345 100644
--- a/core/src/main/scala/spark/scheduler/JobWaiter.scala
+++ b/core/src/main/scala/spark/scheduler/JobWaiter.scala
@@ -3,10 +3,12 @@ package spark.scheduler
import scala.collection.mutable.ArrayBuffer
/**
- * An object that waits for a DAGScheduler job to complete.
+ * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their
+ * results to the given handler function.
*/
-private[spark] class JobWaiter(totalTasks: Int) extends JobListener {
- private val taskResults = ArrayBuffer.fill[Any](totalTasks)(null)
+private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit)
+ extends JobListener {
+
private var finishedTasks = 0
private var jobFinished = false // Is the job as a whole finished (succeeded or failed)?
@@ -17,11 +19,11 @@ private[spark] class JobWaiter(totalTasks: Int) extends JobListener {
if (jobFinished) {
throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
}
- taskResults(index) = result
+ resultHandler(index, result.asInstanceOf[T])
finishedTasks += 1
if (finishedTasks == totalTasks) {
jobFinished = true
- jobResult = JobSucceeded(taskResults)
+ jobResult = JobSucceeded
this.notifyAll()
}
}
@@ -38,7 +40,7 @@ private[spark] class JobWaiter(totalTasks: Int) extends JobListener {
}
}
- def getResult(): JobResult = synchronized {
+ def awaitResult(): JobResult = synchronized {
while (!jobFinished) {
this.wait()
}
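
JobWaiter now streams each task result straight to the caller's handler instead of buffering a Seq, which is how SparkContext.runJob can either fill an Array[U] or discard results entirely. The script-style sketch below uses a stand-in class (MiniJobWaiter, not the real JobWaiter) to show the handler-plus-awaitResult shape:

    class MiniJobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit) {
      private var finishedTasks = 0
      private var failure: Option[Exception] = None

      def taskSucceeded(index: Int, result: T): Unit = synchronized {
        resultHandler(index, result)      // hand the result to the caller immediately
        finishedTasks += 1
        if (finishedTasks == totalTasks) notifyAll()
      }

      def jobFailed(exception: Exception): Unit = synchronized {
        failure = Some(exception)
        notifyAll()
      }

      /** Block until every task reported in, or rethrow the job's failure. */
      def awaitResult(): Unit = synchronized {
        while (finishedTasks < totalTasks && failure.isEmpty) wait()
        failure.foreach(e => throw e)
      }
    }

    // Usage: collect three task results into an array, as SparkContext.runJob does.
    val results = new Array[String](3)
    val waiter = new MiniJobWaiter[String](3, (i, r) => results(i) = r)
    new Thread(new Runnable {
      def run(): Unit = for (i <- 0 until 3) waiter.taskSucceeded(i, "partition-" + i)
    }).start()
    waiter.awaitResult()
    println(results.mkString(", "))   // partition-0, partition-1, partition-2
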
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index 83641a2a84..bed9f1864f 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -32,7 +32,7 @@ private[spark] object ShuffleMapTask {
return old
} else {
val out = new ByteArrayOutputStream
- val ser = SparkEnv.get.closureSerializer.newInstance
+ val ser = SparkEnv.get.closureSerializer.newInstance()
val objOut = ser.serializeStream(new GZIPOutputStream(out))
objOut.writeObject(rdd)
objOut.writeObject(dep)
@@ -48,7 +48,7 @@ private[spark] object ShuffleMapTask {
synchronized {
val loader = Thread.currentThread.getContextClassLoader
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
- val ser = SparkEnv.get.closureSerializer.newInstance
+ val ser = SparkEnv.get.closureSerializer.newInstance()
val objIn = ser.deserializeStream(in)
val rdd = objIn.readObject().asInstanceOf[RDD[_]]
val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]]
@@ -127,7 +127,6 @@ private[spark] class ShuffleMapTask(
val bucketId = dep.partitioner.getPartition(pair._1)
buckets(bucketId) += pair
}
- val bucketIterators = buckets.map(_.iterator)
val compressedSizes = new Array[Byte](numOutputSplits)
@@ -135,7 +134,7 @@ private[spark] class ShuffleMapTask(
for (i <- 0 until numOutputSplits) {
val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
// Get a Scala iterator from Java map
- val iter: Iterator[(Any, Any)] = bucketIterators(i)
+ val iter: Iterator[(Any, Any)] = buckets(i).iterator
val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
compressedSizes(i) = MapOutputTracker.compressSize(size)
}
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
index 0b4177805b..1e4fbdb874 100644
--- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
@@ -86,7 +86,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
- def submitTasks(taskSet: TaskSet) {
+ override def submitTasks(taskSet: TaskSet) {
val tasks = taskSet.tasks
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
index 26201ad0dd..3dabdd76b1 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
@@ -17,10 +17,7 @@ import java.nio.ByteBuffer
/**
* Schedules the tasks within a single TaskSet in the ClusterScheduler.
*/
-private[spark] class TaskSetManager(
- sched: ClusterScheduler,
- val taskSet: TaskSet)
- extends Logging {
+private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSet) extends Logging {
// Maximum time to wait to run a task in a preferred location (in ms)
val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
@@ -100,7 +97,7 @@ private[spark] class TaskSetManager(
}
// Add a task to all the pending-task lists that it should be on.
- def addPendingTask(index: Int) {
+ private def addPendingTask(index: Int) {
val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
if (locations.size == 0) {
pendingTasksWithNoPrefs += index
@@ -115,7 +112,7 @@ private[spark] class TaskSetManager(
// Return the pending tasks list for a given host, or an empty list if
// there is no map entry for that host
- def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
+ private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
pendingTasksForHost.getOrElse(host, ArrayBuffer())
}
@@ -123,7 +120,7 @@ private[spark] class TaskSetManager(
// Return None if the list is empty.
// This method also cleans up any tasks in the list that have already
// been launched, since we want that to happen lazily.
- def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
+ private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
while (!list.isEmpty) {
val index = list.last
list.trimEnd(1)
@@ -137,7 +134,7 @@ private[spark] class TaskSetManager(
// Return a speculative task for a given host if any are available. The task should not have an
// attempt running on this host, in case the host is slow. In addition, if localOnly is set, the
// task must have a preference for this host (or no preferred locations at all).
- def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = {
+ private def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = {
val hostsAlive = sched.hostsAlive
speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
val localTask = speculatableTasks.find {
@@ -162,7 +159,7 @@ private[spark] class TaskSetManager(
// Dequeue a pending task for a given node and return its index.
// If localOnly is set to false, allow non-local tasks as well.
- def findTask(host: String, localOnly: Boolean): Option[Int] = {
+ private def findTask(host: String, localOnly: Boolean): Option[Int] = {
val localTask = findTaskFromList(getPendingTasksForHost(host))
if (localTask != None) {
return localTask
@@ -184,7 +181,7 @@ private[spark] class TaskSetManager(
// Does a host count as a preferred location for a task? This is true if
// either the task has preferred locations and this host is one, or it has
// no preferred locations (in which we still count the launch as preferred).
- def isPreferredLocation(task: Task[_], host: String): Boolean = {
+ private def isPreferredLocation(task: Task[_], host: String): Boolean = {
val locs = task.preferredLocations
return (locs.contains(host) || locs.isEmpty)
}
@@ -335,7 +332,7 @@ private[spark] class TaskSetManager(
if (numFailures(index) > MAX_TASK_FAILURES) {
logError("Task %s:%d failed more than %d times; aborting job".format(
taskSet.id, index, MAX_TASK_FAILURES))
- abort("Task %d failed more than %d times".format(index, MAX_TASK_FAILURES))
+ abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
}
}
} else {
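
The helpers made private above implement the selection order hinted at in their comments: pending tasks local to the offered host are dequeued first, and non-local tasks are only considered when localOnly is false. A rough standalone sketch of that logic follows (simplified, ignoring the speculative-task path; the queue names mirror TaskSetManager's fields but this is not the actual implementation):

import scala.collection.mutable.{ArrayBuffer, HashMap}

// Toy model of locality-aware task dequeueing, loosely following findTask/findTaskFromList.
class ToyTaskQueue(launched: Int => Boolean) {
  val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
  val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
  val allPendingTasks = new ArrayBuffer[Int]

  // Pop the most recently added index that has not been launched yet
  // (lazy cleanup of stale entries, as in findTaskFromList).
  private def popFrom(list: ArrayBuffer[Int]): Option[Int] = {
    while (!list.isEmpty) {
      val index = list.last
      list.trimEnd(1)
      if (!launched(index)) return Some(index)
    }
    None
  }

  // Host-local tasks first, then tasks with no preference, then (if allowed) anything else.
  def findTask(host: String, localOnly: Boolean): Option[Int] = {
    popFrom(pendingTasksForHost.getOrElse(host, ArrayBuffer[Int]()))
      .orElse(popFrom(pendingTasksWithNoPrefs))
      .orElse(if (localOnly) None else popFrom(allPendingTasks))
  }
}
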
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index 9ff7c02097..482d1cc853 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -53,7 +53,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
}
def runTask(task: Task[_], idInJob: Int, attemptId: Int) {
- logInfo("Running task " + idInJob)
+ logInfo("Running " + task)
// Set the Spark execution environment for the worker thread
SparkEnv.set(env)
try {
@@ -80,7 +80,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
val resultToReturn = ser.deserialize[Any](ser.serialize(result))
val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
ser.serialize(Accumulators.values))
- logInfo("Finished task " + idInJob)
+ logInfo("Finished " + task)
// If the threadpool has not already been shutdown, notify DAGScheduler
if (!Thread.currentThread().isInterrupted)
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index c61fd75c2b..9893e9625d 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -950,6 +950,7 @@ class BlockManager(
blockInfo.clear()
memoryStore.clear()
diskStore.clear()
+ metadataCleaner.cancel()
logInfo("BlockManager stopped")
}
}
diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
index 0372cb080a..7389bee150 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
@@ -27,8 +27,6 @@ private[spark] class BlockManagerMaster(
val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt
val DRIVER_AKKA_ACTOR_NAME = "BlockMasterManager"
- val SLAVE_AKKA_ACTOR_NAME = "BlockSlaveManager"
- val DEFAULT_MANAGER_IP: String = Utils.localHostName()
val timeout = 10.seconds
var driverActor: ActorRef = {
diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala
index e0fdeffbc4..30aec5a663 100644
--- a/core/src/main/scala/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/spark/util/AkkaUtils.scala
@@ -18,9 +18,13 @@ import java.util.concurrent.TimeoutException
* Various utility classes for working with Akka.
*/
private[spark] object AkkaUtils {
+
/**
* Creates an ActorSystem ready for remoting, with various Spark features. Returns both the
* ActorSystem itself and its port (which is hard to get from Akka).
+ *
+ * Note: the `name` parameter is important, as even if a client sends a message to the right
+ * host and port, Akka will silently drop the message if the actor system name is incorrect.
*/
def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = {
val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt
@@ -30,6 +34,7 @@ private[spark] object AkkaUtils {
val akkaConf = ConfigFactory.parseString("""
akka.daemonic = on
akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"]
+ akka.stdout-loglevel = "ERROR"
akka.actor.provider = "akka.remote.RemoteActorRefProvider"
akka.remote.transport = "akka.remote.netty.NettyRemoteTransport"
akka.remote.log-remote-lifecycle-events = on
@@ -41,7 +46,7 @@ private[spark] object AkkaUtils {
akka.actor.default-dispatcher.throughput = %d
""".format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize))
- val actorSystem = ActorSystem("spark", akkaConf, getClass.getClassLoader)
+ val actorSystem = ActorSystem(name, akkaConf, getClass.getClassLoader)
// Figure out the port number we bound to, in case port was passed as 0. This is a bit of a
// hack because Akka doesn't let you figure out the port through the public API yet.
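
The new doc note on createActorSystem reflects how Akka's remote addressing works in this version: a remote actor is looked up by a URL that embeds the actor system name, so a client with the wrong name reaches the right TCP endpoint but its messages are dropped. A hedged sketch of such a URL (the helper and example values below are illustrative, not part of AkkaUtils):

object AkkaAddressing {
  // The systemName must match the `name` given to createActorSystem on the remote side;
  // host and port alone are not enough for the message to be delivered.
  def actorUrl(systemName: String, host: String, port: Int, actorName: String): String =
    "akka://%s@%s:%d/user/%s".format(systemName, host, port, actorName)
}

// Example with assumed values, using the Akka 2.0.x-era lookup API:
//   val url = AkkaAddressing.actorUrl("spark", "localhost", 7077, "BlockMasterManager")
//   val ref = someActorSystem.actorFor(url)
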
diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala
index eaff7ae581..a342d378ff 100644
--- a/core/src/main/scala/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/spark/util/MetadataCleaner.scala
@@ -9,12 +9,12 @@ import spark.Logging
* Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries)
*/
class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging {
- val delaySeconds = MetadataCleaner.getDelaySeconds
- val periodSeconds = math.max(10, delaySeconds / 10)
- val timer = new Timer(name + " cleanup timer", true)
+ private val delaySeconds = MetadataCleaner.getDelaySeconds
+ private val periodSeconds = math.max(10, delaySeconds / 10)
+ private val timer = new Timer(name + " cleanup timer", true)
- val task = new TimerTask {
- def run() {
+ private val task = new TimerTask {
+ override def run() {
try {
cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000))
logInfo("Ran metadata cleaner for " + name)
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index ed03e65153..89a3687386 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -12,9 +12,10 @@ class RDDSuite extends FunSuite with LocalSparkContext {
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
assert(nums.collect().toList === List(1, 2, 3, 4))
val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2)
- assert(dups.distinct.count === 4)
- assert(dups.distinct().collect === dups.distinct.collect)
- assert(dups.distinct(2).collect === dups.distinct.collect)
+ assert(dups.distinct().count() === 4)
+ assert(dups.distinct.count === 4) // check that distinct and count can be called without parentheses
+ assert(dups.distinct().collect === dups.distinct().collect)
+ assert(dups.distinct(2).collect === dups.distinct().collect)
assert(nums.reduce(_ + _) === 10)
assert(nums.fold(0)(_ + _) === 10)
assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4"))
@@ -31,6 +32,10 @@ class RDDSuite extends FunSuite with LocalSparkContext {
case(split, iter) => Iterator((split, iter.reduceLeft(_ + _)))
}
assert(partitionSumsWithSplit.collect().toList === List((0, 3), (1, 7)))
+
+ intercept[UnsupportedOperationException] {
+ nums.filter(_ > 5).reduce(_ + _)
+ }
}
test("SparkContext.union") {
@@ -164,7 +169,7 @@ class RDDSuite extends FunSuite with LocalSparkContext {
// Note that split numbers start from 0, so > 8 means only the 10th partition is left.
val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8)
assert(prunedRdd.splits.size === 1)
- val prunedData = prunedRdd.collect
+ val prunedData = prunedRdd.collect()
assert(prunedData.size === 1)
assert(prunedData(0) === 10)
}
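
The intercept[UnsupportedOperationException] block added to the basic-operations test relies on reduce having no zero element to fall back on when every partition is empty, whereas fold is given one explicitly. The same distinction can be seen with plain Scala collections (an illustration only, not Spark code):

val empty = Seq[Int]().filter(_ > 5)
// empty.reduce(_ + _)          // throws UnsupportedOperationException on an empty collection
val sum = empty.fold(0)(_ + _)  // returns 0, since the zero element is supplied
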
diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
new file mode 100644
index 0000000000..83663ac702
--- /dev/null
+++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
@@ -0,0 +1,663 @@
+package spark.scheduler
+
+import scala.collection.mutable.{Map, HashMap}
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+import org.scalatest.concurrent.TimeLimitedTests
+import org.scalatest.mock.EasyMockSugar
+import org.scalatest.time.{Span, Seconds}
+
+import org.easymock.EasyMock._
+import org.easymock.Capture
+import org.easymock.EasyMock
+import org.easymock.{IAnswer, IArgumentMatcher}
+
+import akka.actor.ActorSystem
+
+import spark.storage.BlockManager
+import spark.storage.BlockManagerId
+import spark.storage.BlockManagerMaster
+import spark.{Dependency, ShuffleDependency, OneToOneDependency}
+import spark.FetchFailedException
+import spark.MapOutputTracker
+import spark.RDD
+import spark.SparkContext
+import spark.SparkException
+import spark.Split
+import spark.TaskContext
+import spark.TaskEndReason
+
+import spark.{FetchFailed, Success}
+
+/**
+ * Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler
+ * rather than spawning an event loop thread as happens in the real code. They use EasyMock
+ * to mock out two classes that DAGScheduler interacts with: TaskScheduler (to which TaskSets are
+ * submitted) and BlockManagerMaster (from which cache locations are retrieved and to which dead
+ * host notifications are sent). In addition, tests may check for side effects on a non-mocked
+ * MapOutputTracker instance.
+ *
+ * Tests primarily consist of running DAGScheduler#processEvent and
+ * DAGScheduler#submitWaitingStages (via test utility functions like runEvent or respondToTaskSet)
+ * and capturing the resulting TaskSets from the mock TaskScheduler.
+ */
+class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar with TimeLimitedTests {
+
+ // Impose a time limit on this test in case we don't let the job finish, in which case
+ // JobWaiter#awaitResult will hang.
+ override val timeLimit = Span(5, Seconds)
+
+ val sc: SparkContext = new SparkContext("local", "DAGSchedulerSuite")
+ var scheduler: DAGScheduler = null
+ val taskScheduler = mock[TaskScheduler]
+ val blockManagerMaster = mock[BlockManagerMaster]
+ var mapOutputTracker: MapOutputTracker = null
+ var schedulerThread: Thread = null
+ var schedulerException: Throwable = null
+
+ /**
+ * Set of EasyMock argument matchers that match a TaskSet for a given RDD.
+ * We cache these so we do not create duplicate matchers for the same RDD.
+ * This allows us to easily setup a sequence of expectations for task sets for
+ * that RDD.
+ */
+ val taskSetMatchers = new HashMap[MyRDD, IArgumentMatcher]
+
+ /**
+ * Set of cache locations to return from our mock BlockManagerMaster.
+ * Keys are (rdd ID, partition ID). Anything not present will return an empty
+ * list of cache locations silently.
+ */
+ val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]
+
+ /**
+ * JobWaiter for the last JobSubmitted event we pushed, kept so that tests (most of which
+ * submit only one job) do not need to track it explicitly.
+ */
+ var lastJobWaiter: JobWaiter[Int] = null
+
+ /**
+ * Array into which we are accumulating the results from the last job asynchronously.
+ */
+ var lastJobResult: Array[Int] = null
+
+ /**
+ * Tell EasyMockSugar which mock objects should be configured by the expecting {...}
+ * and whenExecuting {...} blocks.
+ */
+ implicit val mocks = MockObjects(taskScheduler, blockManagerMaster)
+
+ /**
+ * Utility function to reset mocks and set expectations on them. EasyMock wants mock objects
+ * to be reset after each time their expectations are set, and we tend to check mock object
+ * calls over a single call to DAGScheduler.
+ *
+ * We also set a default expectation here that blockManagerMaster.getLocations can be called
+ * and will return values from cacheLocations.
+ */
+ def resetExpecting(f: => Unit) {
+ reset(taskScheduler)
+ reset(blockManagerMaster)
+ expecting {
+ expectGetLocations()
+ f
+ }
+ }
+
+ before {
+ taskSetMatchers.clear()
+ cacheLocations.clear()
+ val actorSystem = ActorSystem("test")
+ mapOutputTracker = new MapOutputTracker(actorSystem, true)
+ resetExpecting {
+ taskScheduler.setListener(anyObject())
+ }
+ whenExecuting {
+ scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null)
+ }
+ }
+
+ after {
+ assert(scheduler.processEvent(StopDAGScheduler))
+ resetExpecting {
+ taskScheduler.stop()
+ }
+ whenExecuting {
+ scheduler.stop()
+ }
+ sc.stop()
+ System.clearProperty("spark.master.port")
+ }
+
+ def makeBlockManagerId(host: String): BlockManagerId =
+ BlockManagerId("exec-" + host, host, 12345)
+
+ /**
+ * Type of RDD we use for testing. Note that we should never call the real RDD compute methods.
+ * This is a pair RDD type so it can always be used in ShuffleDependencies.
+ */
+ type MyRDD = RDD[(Int, Int)]
+
+ /**
+ * Create an RDD for passing to DAGScheduler. These RDDs will use the dependencies and
+ * preferredLocations (if any) that are passed to them. They are deliberately not executable
+ * so we can test that DAGScheduler does not try to execute RDDs locally.
+ */
+ def makeRdd(
+ numSplits: Int,
+ dependencies: List[Dependency[_]],
+ locations: Seq[Seq[String]] = Nil
+ ): MyRDD = {
+ val maxSplit = numSplits - 1
+ return new MyRDD(sc, dependencies) {
+ override def compute(split: Split, context: TaskContext): Iterator[(Int, Int)] =
+ throw new RuntimeException("should not be reached")
+ override def getSplits() = (0 to maxSplit).map(i => new Split {
+ override def index = i
+ }).toArray
+ override def getPreferredLocations(split: Split): Seq[String] =
+ if (locations.isDefinedAt(split.index))
+ locations(split.index)
+ else
+ Nil
+ override def toString: String = "DAGSchedulerSuiteRDD " + id
+ }
+ }
+
+ /**
+ * EasyMock matcher method. For use as an argument matcher for a TaskSet whose first task
+ * is from a particular RDD.
+ */
+ def taskSetForRdd(rdd: MyRDD): TaskSet = {
+ val matcher = taskSetMatchers.getOrElseUpdate(rdd,
+ new IArgumentMatcher {
+ override def matches(actual: Any): Boolean = {
+ val taskSet = actual.asInstanceOf[TaskSet]
+ taskSet.tasks(0) match {
+ case rt: ResultTask[_, _] => rt.rdd.id == rdd.id
+ case smt: ShuffleMapTask => smt.rdd.id == rdd.id
+ case _ => false
+ }
+ }
+ override def appendTo(buf: StringBuffer) {
+ buf.append("taskSetForRdd(" + rdd + ")")
+ }
+ })
+ EasyMock.reportMatcher(matcher)
+ return null
+ }
+
+ /**
+ * Set up an EasyMock expectation so that blockManagerMaster.getLocations() is answered
+ * from cacheLocations.
+ */
+ def expectGetLocations(): Unit = {
+ EasyMock.expect(blockManagerMaster.getLocations(anyObject().asInstanceOf[Array[String]])).
+ andAnswer(new IAnswer[Seq[Seq[BlockManagerId]]] {
+ override def answer(): Seq[Seq[BlockManagerId]] = {
+ val blocks = getCurrentArguments()(0).asInstanceOf[Array[String]]
+ return blocks.map { name =>
+ val pieces = name.split("_")
+ if (pieces(0) == "rdd") {
+ val key = pieces(1).toInt -> pieces(2).toInt
+ if (cacheLocations.contains(key)) {
+ cacheLocations(key)
+ } else {
+ Seq[BlockManagerId]()
+ }
+ } else {
+ Seq[BlockManagerId]()
+ }
+ }.toSeq
+ }
+ }).anyTimes()
+ }
+
+ /**
+ * Process the supplied event as if it were the top of the DAGScheduler event queue, expecting
+ * the scheduler not to exit.
+ *
+ * After processing the event, submit waiting stages as is done on most iterations of the
+ * DAGScheduler event loop.
+ */
+ def runEvent(event: DAGSchedulerEvent) {
+ assert(!scheduler.processEvent(event))
+ scheduler.submitWaitingStages()
+ }
+
+ /**
+ * Expect a TaskSet for the specified RDD to be submitted to the TaskScheduler. Should be
+ * called from a resetExpecting { ... } block.
+ *
+ * Returns an EasyMock Capture that will contain the task set after the stage is submitted.
+ * Most tests should use interceptStage() instead of this directly.
+ */
+ def expectStage(rdd: MyRDD): Capture[TaskSet] = {
+ val taskSetCapture = new Capture[TaskSet]
+ taskScheduler.submitTasks(and(capture(taskSetCapture), taskSetForRdd(rdd)))
+ return taskSetCapture
+ }
+
+ /**
+ * Expect the supplied code snippet to submit a stage for the specified RDD.
+ * Return the resulting TaskSet, after marking all of its tasks as belonging to the
+ * current MapOutputTracker generation.
+ */
+ def interceptStage(rdd: MyRDD)(f: => Unit): TaskSet = {
+ var capture: Capture[TaskSet] = null
+ resetExpecting {
+ capture = expectStage(rdd)
+ }
+ whenExecuting {
+ f
+ }
+ val taskSet = capture.getValue
+ for (task <- taskSet.tasks) {
+ task.generation = mapOutputTracker.getGeneration
+ }
+ return taskSet
+ }
+
+ /**
+ * Send the given CompletionEvent messages for the tasks in the TaskSet.
+ */
+ def respondToTaskSet(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) {
+ assert(taskSet.tasks.size >= results.size)
+ for ((result, i) <- results.zipWithIndex) {
+ if (i < taskSet.tasks.size) {
+ runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any]()))
+ }
+ }
+ }
+
+ /**
+ * Assert that the supplied TaskSet has exactly the given preferredLocations.
+ */
+ def expectTaskSetLocations(taskSet: TaskSet, locations: Seq[Seq[String]]) {
+ assert(locations.size === taskSet.tasks.size)
+ for ((expectLocs, taskLocs) <-
+ taskSet.tasks.map(_.preferredLocations).zip(locations)) {
+ assert(expectLocs === taskLocs)
+ }
+ }
+
+ /**
+ * When we submit dummy Jobs, this is the compute function we supply. Except in a local test
+ * below, we do not expect this function to ever be executed; instead, we will return results
+ * directly through CompletionEvents.
+ */
+ def jobComputeFunc(context: TaskContext, it: Iterator[(Int, Int)]): Int =
+ it.next._1.asInstanceOf[Int]
+
+
+ /**
+ * Start a job to compute the given RDD. Returns the JobWaiter and the result array
+ * that DAGScheduler will fill in via callbacks as the job completes.
+ */
+ def submitRdd(rdd: MyRDD, allowLocal: Boolean = false): (JobWaiter[Int], Array[Int]) = {
+ val resultArray = new Array[Int](rdd.splits.size)
+ val (toSubmit, waiter) = scheduler.prepareJob[(Int, Int), Int](
+ rdd,
+ jobComputeFunc,
+ (0 to (rdd.splits.size - 1)),
+ "test-site",
+ allowLocal,
+ (i: Int, value: Int) => resultArray(i) = value
+ )
+ lastJobWaiter = waiter
+ lastJobResult = resultArray
+ runEvent(toSubmit)
+ return (waiter, resultArray)
+ }
+
+ /**
+ * Assert that a job we started has failed.
+ */
+ def expectJobException(waiter: JobWaiter[Int] = lastJobWaiter) {
+ waiter.awaitResult() match {
+ case JobSucceeded => fail()
+ case JobFailed(_) => return
+ }
+ }
+
+ /**
+ * Assert that a job we started has succeeded and has the given result.
+ */
+ def expectJobResult(expected: Array[Int], waiter: JobWaiter[Int] = lastJobWaiter,
+ result: Array[Int] = lastJobResult) {
+ waiter.awaitResult match {
+ case JobSucceeded =>
+ assert(expected === result)
+ case JobFailed(_) =>
+ fail()
+ }
+ }
+
+ def makeMapStatus(host: String, reduces: Int): MapStatus =
+ new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2))
+
+ test("zero split job") {
+ val rdd = makeRdd(0, Nil)
+ var numResults = 0
+ def accumulateResult(partition: Int, value: Int) {
+ numResults += 1
+ }
+ scheduler.runJob(rdd, jobComputeFunc, Seq(), "test-site", false, accumulateResult)
+ assert(numResults === 0)
+ }
+
+ test("run trivial job") {
+ val rdd = makeRdd(1, Nil)
+ val taskSet = interceptStage(rdd) { submitRdd(rdd) }
+ respondToTaskSet(taskSet, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("local job") {
+ val rdd = new MyRDD(sc, Nil) {
+ override def compute(split: Split, context: TaskContext): Iterator[(Int, Int)] =
+ Array(42 -> 0).iterator
+ override def getSplits() = Array( new Split { override def index = 0 } )
+ override def getPreferredLocations(split: Split) = Nil
+ override def toString = "DAGSchedulerSuite Local RDD"
+ }
+ submitRdd(rdd, true)
+ expectJobResult(Array(42))
+ }
+
+ test("run trivial job w/ dependency") {
+ val baseRdd = makeRdd(1, Nil)
+ val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
+ val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) }
+ respondToTaskSet(taskSet, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("cache location preferences w/ dependency") {
+ val baseRdd = makeRdd(1, Nil)
+ val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
+ cacheLocations(baseRdd.id -> 0) =
+ Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
+ val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) }
+ expectTaskSetLocations(taskSet, List(Seq("hostA", "hostB")))
+ respondToTaskSet(taskSet, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("trivial job failure") {
+ val rdd = makeRdd(1, Nil)
+ val taskSet = interceptStage(rdd) { submitRdd(rdd) }
+ runEvent(TaskSetFailed(taskSet, "test failure"))
+ expectJobException()
+ }
+
+ test("run trivial shuffle") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(1, List(shuffleDep))
+
+ val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+ val secondStage = interceptStage(reduceRdd) {
+ respondToTaskSet(firstStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
+ respondToTaskSet(secondStage, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("run trivial shuffle with fetch failure") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(2, List(shuffleDep))
+
+ val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+ val secondStage = interceptStage(reduceRdd) {
+ respondToTaskSet(firstStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ respondToTaskSet(secondStage, List(
+ (Success, 42),
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null)
+ ))
+ }
+ val thirdStage = interceptStage(shuffleMapRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ val fourthStage = interceptStage(reduceRdd) {
+ respondToTaskSet(thirdStage, List( (Success, makeMapStatus("hostA", 1)) ))
+ }
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
+ respondToTaskSet(fourthStage, List( (Success, 43) ))
+ expectJobResult(Array(42, 43))
+ }
+
+ test("ignore late map task completions") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(2, List(shuffleDep))
+
+ val taskSet = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+ val oldGeneration = mapOutputTracker.getGeneration
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ runEvent(ExecutorLost("exec-hostA"))
+ }
+ val newGeneration = mapOutputTracker.getGeneration
+ assert(newGeneration > oldGeneration)
+ val noAccum = Map[Long, Any]()
+ // We rely on the event queue being ordered: the ExecutorLost event above has already
+ // bumped the generation, so this completion, still tagged with the old generation,
+ // should be ignored for being too old
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum))
+ // should work because it's a non-failed host
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum))
+ // should be ignored for being too old
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum))
+ taskSet.tasks(1).generation = newGeneration
+ val secondStage = interceptStage(reduceRdd) {
+ runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum))
+ }
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
+ respondToTaskSet(secondStage, List( (Success, 42), (Success, 43) ))
+ expectJobResult(Array(42, 43))
+ }
+
+ test("run trivial shuffle with out-of-band failure and retry") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(1, List(shuffleDep))
+
+ val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ runEvent(ExecutorLost("exec-hostA"))
+ }
+ // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks
+ // rather than marking it as failed and waiting.
+ val secondStage = interceptStage(shuffleMapRdd) {
+ respondToTaskSet(firstStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ val thirdStage = interceptStage(reduceRdd) {
+ respondToTaskSet(secondStage, List(
+ (Success, makeMapStatus("hostC", 1))
+ ))
+ }
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
+ respondToTaskSet(thirdStage, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("recursive shuffle failures") {
+ val shuffleOneRdd = makeRdd(2, Nil)
+ val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
+ val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
+ val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
+ val finalRdd = makeRdd(1, List(shuffleDepTwo))
+
+ val firstStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
+ val secondStage = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(firstStage, List(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))
+ ))
+ }
+ val thirdStage = interceptStage(finalRdd) {
+ respondToTaskSet(secondStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostC", 1))
+ ))
+ }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ respondToTaskSet(thirdStage, List(
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
+ ))
+ }
+ val recomputeOne = interceptStage(shuffleOneRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ val recomputeTwo = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(recomputeOne, List(
+ (Success, makeMapStatus("hostA", 2))
+ ))
+ }
+ val finalStage = interceptStage(finalRdd) {
+ respondToTaskSet(recomputeTwo, List(
+ (Success, makeMapStatus("hostA", 1))
+ ))
+ }
+ respondToTaskSet(finalStage, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("cached post-shuffle") {
+ val shuffleOneRdd = makeRdd(2, Nil)
+ val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
+ val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
+ val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
+ val finalRdd = makeRdd(1, List(shuffleDepTwo))
+
+ val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
+ cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
+ cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
+ val secondShuffleStage = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(firstShuffleStage, List(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))
+ ))
+ }
+ val reduceStage = interceptStage(finalRdd) {
+ respondToTaskSet(secondShuffleStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ respondToTaskSet(reduceStage, List(
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
+ ))
+ }
+ // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun.
+ val recomputeTwo = interceptStage(shuffleTwoRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ expectTaskSetLocations(recomputeTwo, Seq(Seq("hostD")))
+ val finalRetry = interceptStage(finalRdd) {
+ respondToTaskSet(recomputeTwo, List(
+ (Success, makeMapStatus("hostD", 1))
+ ))
+ }
+ respondToTaskSet(finalRetry, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("cached post-shuffle but fails") {
+ val shuffleOneRdd = makeRdd(2, Nil)
+ val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
+ val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
+ val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
+ val finalRdd = makeRdd(1, List(shuffleDepTwo))
+
+ val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
+ cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
+ cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
+ val secondShuffleStage = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(firstShuffleStage, List(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))
+ ))
+ }
+ val reduceStage = interceptStage(finalRdd) {
+ respondToTaskSet(secondShuffleStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ respondToTaskSet(reduceStage, List(
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
+ ))
+ }
+ val recomputeTwoCached = interceptStage(shuffleTwoRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ expectTaskSetLocations(recomputeTwoCached, Seq(Seq("hostD")))
+ intercept[FetchFailedException]{
+ mapOutputTracker.getServerStatuses(shuffleDepOne.shuffleId, 0)
+ }
+
+ // Simulate the shuffle input data failing to be cached.
+ cacheLocations.remove(shuffleTwoRdd.id -> 0)
+ respondToTaskSet(recomputeTwoCached, List(
+ (FetchFailed(null, shuffleDepOne.shuffleId, 0, 0), null)
+ ))
+
+ // After the fetch failure, DAGScheduler should recheck the cache and decide to resubmit
+ // everything.
+ val recomputeOne = interceptStage(shuffleOneRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ // We use hostA here to make sure DAGScheduler doesn't think it's still dead.
+ val recomputeTwoUncached = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(recomputeOne, List( (Success, makeMapStatus("hostA", 1)) ))
+ }
+ expectTaskSetLocations(recomputeTwoUncached, Seq(Seq[String]()))
+ val finalRetry = interceptStage(finalRdd) {
+ respondToTaskSet(recomputeTwoUncached, List( (Success, makeMapStatus("hostA", 1)) ))
+
+ }
+ respondToTaskSet(finalRetry, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+}
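
The suite above drives DAGScheduler directly and leans on ScalaTest's EasyMockSugar: expectations are recorded inside expecting { ... } and then replayed and verified by whenExecuting { ... }. A stripped-down sketch of that record/replay/verify cycle in isolation (the Greeter trait is made up purely for illustration):

import org.easymock.EasyMock
import org.scalatest.FunSuite
import org.scalatest.mock.EasyMockSugar

// Stand-in collaborator, purely for illustration.
trait Greeter {
  def greet(name: String): String
}

class EasyMockPatternSuite extends FunSuite with EasyMockSugar {
  test("record expectations, then exercise the code under test") {
    val greeter = mock[Greeter]
    expecting {
      // Record the expected call and its canned return value.
      EasyMock.expect(greeter.greet("world")).andReturn("hello, world")
    }
    whenExecuting(greeter) {
      // The mock is replayed here; EasyMock verifies the recorded expectations
      // when the block exits.
      assert(greeter.greet("world") === "hello, world")
    }
  }
}
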