aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorStephen Haberman <stephen@exigencecorp.com>2013-02-05 20:27:09 -0600
committerStephen Haberman <stephen@exigencecorp.com>2013-02-05 20:27:09 -0600
commit870b2aaf5d1398704f69a5b1a8be30de522b284c (patch)
treea0572102d2545a2189334ecd7425e65273332711 /core
parent0e19093fd89ec9740f98cdcffd1ec09f4faf2490 (diff)
parenta4611d66f0d2ebc4425f385988d541b8f930e505 (diff)
downloadspark-870b2aaf5d1398704f69a5b1a8be30de522b284c.tar.gz
spark-870b2aaf5d1398704f69a5b1a8be30de522b284c.tar.bz2
spark-870b2aaf5d1398704f69a5b1a8be30de522b284c.zip
Merge branch 'master' into fixdeathpactexception
Conflicts: core/src/main/scala/spark/deploy/worker/Worker.scala
Diffstat (limited to 'core')
-rw-r--r--core/pom.xml5
-rw-r--r--core/src/main/scala/spark/SparkContext.scala22
-rw-r--r--core/src/main/scala/spark/api/python/PythonRDD.scala11
-rw-r--r--core/src/main/scala/spark/deploy/LocalSparkCluster.scala38
-rw-r--r--core/src/main/scala/spark/deploy/client/Client.scala13
-rw-r--r--core/src/main/scala/spark/deploy/master/Master.scala24
-rw-r--r--core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala3
-rw-r--r--core/src/main/scala/spark/deploy/worker/Worker.scala58
-rw-r--r--core/src/main/scala/spark/scheduler/DAGScheduler.scala224
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala12
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala10
-rw-r--r--core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala10
-rw-r--r--core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala14
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerMaster.scala6
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerUI.scala39
-rw-r--r--core/src/main/scala/spark/storage/StorageUtils.scala24
-rw-r--r--core/src/main/scala/spark/util/AkkaUtils.scala6
-rw-r--r--core/src/main/twirl/spark/storage/rdd.scala.html6
-rw-r--r--core/src/main/twirl/spark/storage/rdd_table.scala.html6
-rw-r--r--core/src/test/scala/spark/DriverSuite.scala3
-rw-r--r--core/src/test/scala/spark/MapOutputTrackerSuite.scala3
-rw-r--r--core/src/test/scala/spark/RDDSuite.scala2
-rw-r--r--core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala663
23 files changed, 965 insertions, 237 deletions
diff --git a/core/pom.xml b/core/pom.xml
index 873e8a1d0f..66c62151fe 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -99,6 +99,11 @@
<scope>test</scope>
</dependency>
<dependency>
+ <groupId>org.easymock</groupId>
+ <artifactId>easymock</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>com.novocode</groupId>
<artifactId>junit-interface</artifactId>
<scope>test</scope>
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index ddbf8f95d9..0efc00d5dd 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -46,6 +46,7 @@ import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, C
import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
import storage.BlockManagerUI
import util.{MetadataCleaner, TimeStampedHashMap}
+import storage.{StorageStatus, StorageUtils, RDDInfo}
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@@ -107,8 +108,9 @@ class SparkContext(
// Environment variables to pass to our executors
private[spark] val executorEnvs = HashMap[String, String]()
+ // Note: SPARK_MEM is included for Mesos, but overwritten for standalone mode in ExecutorRunner
for (key <- Seq("SPARK_MEM", "SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS",
- "SPARK_TESTING")) {
+ "SPARK_TESTING")) {
val value = System.getenv(key)
if (value != null) {
executorEnvs(key) = value
@@ -187,6 +189,7 @@ class SparkContext(
taskScheduler.start()
private var dagScheduler = new DAGScheduler(taskScheduler)
+ dagScheduler.start()
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
val hadoopConfiguration = {
@@ -467,13 +470,28 @@ class SparkContext(
* Return a map from the slave to the max memory available for caching and the remaining
* memory available for caching.
*/
- def getSlavesMemoryStatus: Map[String, (Long, Long)] = {
+ def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
(blockManagerId.ip + ":" + blockManagerId.port, mem)
}
}
/**
+ * Return information about what RDDs are cached, if they are in mem or on disk, how much space
+ * they take, etc.
+ */
+ def getRDDStorageInfo : Array[RDDInfo] = {
+ StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this)
+ }
+
+ /**
+ * Return information about blocks stored in all of the slaves
+ */
+ def getExecutorStorageStatus : Array[StorageStatus] = {
+ env.blockManager.master.getStorageStatus
+ }
+
+ /**
* Clear the job's list of files added by `addFile` so that they do not get downloaded to
* any new nodes.
*/
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 39758e94f4..ab8351e55e 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -238,6 +238,11 @@ private[spark] object PythonRDD {
}
def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
+ import scala.collection.JavaConverters._
+ writeIteratorToPickleFile(items.asScala, filename)
+ }
+
+ def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) {
val file = new DataOutputStream(new FileOutputStream(filename))
for (item <- items) {
writeAsPickle(item, file)
@@ -245,8 +250,10 @@ private[spark] object PythonRDD {
file.close()
}
- def takePartition[T](rdd: RDD[T], partition: Int): java.util.Iterator[T] =
- rdd.context.runJob(rdd, ((x: Iterator[T]) => x), Seq(partition), true).head
+ def takePartition[T](rdd: RDD[T], partition: Int): Iterator[T] = {
+ implicit val cm : ClassManifest[T] = rdd.elementClassManifest
+ rdd.context.runJob(rdd, ((x: Iterator[T]) => x.toArray), Seq(partition), true).head.iterator
+ }
}
private object Pickle {
diff --git a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
index 2836574ecb..22319a96ca 100644
--- a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
+++ b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
@@ -18,35 +18,23 @@ import scala.collection.mutable.ArrayBuffer
private[spark]
class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging {
- val localIpAddress = Utils.localIpAddress
+ private val localIpAddress = Utils.localIpAddress
+ private val masterActorSystems = ArrayBuffer[ActorSystem]()
+ private val workerActorSystems = ArrayBuffer[ActorSystem]()
- var masterActor : ActorRef = _
- var masterActorSystem : ActorSystem = _
- var masterPort : Int = _
- var masterUrl : String = _
-
- val workerActorSystems = ArrayBuffer[ActorSystem]()
- val workerActors = ArrayBuffer[ActorRef]()
-
- def start() : String = {
+ def start(): String = {
logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")
/* Start the Master */
- val (actorSystem, masterPort) = AkkaUtils.createActorSystem("sparkMaster", localIpAddress, 0)
- masterActorSystem = actorSystem
- masterUrl = "spark://" + localIpAddress + ":" + masterPort
- masterActor = masterActorSystem.actorOf(
- Props(new Master(localIpAddress, masterPort, 0)), name = "Master")
+ val (masterSystem, masterPort) = Master.startSystemAndActor(localIpAddress, 0, 0)
+ masterActorSystems += masterSystem
+ val masterUrl = "spark://" + localIpAddress + ":" + masterPort
- /* Start the Slaves */
+ /* Start the Workers */
for (workerNum <- 1 to numWorkers) {
- val (actorSystem, boundPort) =
- AkkaUtils.createActorSystem("sparkWorker" + workerNum, localIpAddress, 0)
- workerActorSystems += actorSystem
- val actor = actorSystem.actorOf(
- Props(new Worker(localIpAddress, boundPort, 0, coresPerWorker, memoryPerWorker, masterUrl)),
- name = "Worker")
- workerActors += actor
+ val (workerSystem, _) = Worker.startSystemAndActor(localIpAddress, 0, 0, coresPerWorker,
+ memoryPerWorker, masterUrl, null, Some(workerNum))
+ workerActorSystems += workerSystem
}
return masterUrl
@@ -57,7 +45,7 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I
// Stop the workers before the master so they don't get upset that it disconnected
workerActorSystems.foreach(_.shutdown())
workerActorSystems.foreach(_.awaitTermination())
- masterActorSystem.shutdown()
- masterActorSystem.awaitTermination()
+ masterActorSystems.foreach(_.shutdown())
+ masterActorSystems.foreach(_.awaitTermination())
}
}
diff --git a/core/src/main/scala/spark/deploy/client/Client.scala b/core/src/main/scala/spark/deploy/client/Client.scala
index 90fe9508cd..a63eee1233 100644
--- a/core/src/main/scala/spark/deploy/client/Client.scala
+++ b/core/src/main/scala/spark/deploy/client/Client.scala
@@ -9,6 +9,7 @@ import spark.{SparkException, Logging}
import akka.remote.RemoteClientLifeCycleEvent
import akka.remote.RemoteClientShutdown
import spark.deploy.RegisterJob
+import spark.deploy.master.Master
import akka.remote.RemoteClientDisconnected
import akka.actor.Terminated
import akka.dispatch.Await
@@ -24,26 +25,18 @@ private[spark] class Client(
listener: ClientListener)
extends Logging {
- val MASTER_REGEX = "spark://([^:]+):([0-9]+)".r
-
var actor: ActorRef = null
var jobId: String = null
- if (MASTER_REGEX.unapplySeq(masterUrl) == None) {
- throw new SparkException("Invalid master URL: " + masterUrl)
- }
-
class ClientActor extends Actor with Logging {
var master: ActorRef = null
var masterAddress: Address = null
var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times
override def preStart() {
- val Seq(masterHost, masterPort) = MASTER_REGEX.unapplySeq(masterUrl).get
- logInfo("Connecting to master spark://" + masterHost + ":" + masterPort)
- val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort)
+ logInfo("Connecting to master " + masterUrl)
try {
- master = context.actorFor(akkaUrl)
+ master = context.actorFor(Master.toAkkaUrl(masterUrl))
masterAddress = master.path.address
master ! RegisterJob(jobDescription)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala
index c618e87cdd..92e7914b1b 100644
--- a/core/src/main/scala/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/spark/deploy/master/Master.scala
@@ -262,11 +262,29 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
}
private[spark] object Master {
+ private val systemName = "sparkMaster"
+ private val actorName = "Master"
+ private val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r
+
def main(argStrings: Array[String]) {
val args = new MasterArguments(argStrings)
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port)
- val actor = actorSystem.actorOf(
- Props(new Master(args.ip, boundPort, args.webUiPort)), name = "Master")
+ val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort)
actorSystem.awaitTermination()
}
+
+ /** Returns an `akka://...` URL for the Master actor given a sparkUrl `spark://host:ip`. */
+ def toAkkaUrl(sparkUrl: String): String = {
+ sparkUrl match {
+ case sparkUrlRegex(host, port) =>
+ "akka://%s@%s:%s/user/%s".format(systemName, host, port, actorName)
+ case _ =>
+ throw new SparkException("Invalid master URL: " + sparkUrl)
+ }
+ }
+
+ def startSystemAndActor(host: String, port: Int, webUiPort: Int): (ActorSystem, Int) = {
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
+ val actor = actorSystem.actorOf(Props(new Master(host, boundPort, webUiPort)), name = actorName)
+ (actorSystem, boundPort)
+ }
}
diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
index f5ff267d44..4ef637090c 100644
--- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
@@ -113,8 +113,7 @@ private[spark] class ExecutorRunner(
for ((key, value) <- jobDesc.command.environment) {
env.put(key, value)
}
- env.put("SPARK_CORES", cores.toString)
- env.put("SPARK_MEMORY", memory.toString)
+ env.put("SPARK_MEM", memory.toString + "m")
// In case we are running this from within the Spark Shell, avoid creating a "scala"
// parent process for the executor command
env.put("SPARK_LAUNCH_WITH_SCALA", "0")
diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala
index 48177a638a..38547ec4f1 100644
--- a/core/src/main/scala/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/spark/deploy/worker/Worker.scala
@@ -1,7 +1,7 @@
package spark.deploy.worker
import scala.collection.mutable.{ArrayBuffer, HashMap}
-import akka.actor.{ActorRef, Props, Actor, Terminated}
+import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated}
import spark.{Logging, Utils}
import spark.util.AkkaUtils
import spark.deploy._
@@ -11,6 +11,7 @@ import java.util.Date
import spark.deploy.RegisterWorker
import spark.deploy.LaunchExecutor
import spark.deploy.RegisterWorkerFailed
+import spark.deploy.master.Master
import java.io.File
private[spark] class Worker(
@@ -24,7 +25,6 @@ private[spark] class Worker(
extends Actor with Logging {
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs
- val MASTER_REGEX = "spark://([^:]+):([0-9]+)".r
var master: ActorRef = null
var masterWebUiUrl : String = ""
@@ -45,11 +45,7 @@ private[spark] class Worker(
def memoryFree: Int = memory - memoryUsed
def createWorkDir() {
- workDir = if (workDirPath != null) {
- new File(workDirPath)
- } else {
- new File(sparkHome, "work")
- }
+ workDir = Option(workDirPath).map(new File(_)).getOrElse(new File(sparkHome, "work"))
try {
if (!workDir.exists() && !workDir.mkdirs()) {
logError("Failed to create work directory " + workDir)
@@ -65,8 +61,7 @@ private[spark] class Worker(
override def preStart() {
logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format(
ip, port, cores, Utils.memoryMegabytesToString(memory)))
- val envVar = System.getenv("SPARK_HOME")
- sparkHome = new File(if (envVar == null) "." else envVar)
+ sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse("."))
logInfo("Spark home: " + sparkHome)
createWorkDir()
connectToMaster()
@@ -74,24 +69,15 @@ private[spark] class Worker(
}
def connectToMaster() {
- masterUrl match {
- case MASTER_REGEX(masterHost, masterPort) => {
- logInfo("Connecting to master spark://" + masterHost + ":" + masterPort)
- val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort)
- try {
- master = context.actorFor(akkaUrl)
- master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress)
- context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
- context.watch(master) // Doesn't work with remote actors, but useful for testing
- } catch {
- case e: Exception =>
- logError("Failed to connect to master", e)
- System.exit(1)
- }
- }
-
- case _ =>
- logError("Invalid master URL: " + masterUrl)
+ logInfo("Connecting to master " + masterUrl)
+ try {
+ master = context.actorFor(Master.toAkkaUrl(masterUrl))
+ master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress)
+ context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
+ context.watch(master) // Doesn't work with remote actors, but useful for testing
+ } catch {
+ case e: Exception =>
+ logError("Failed to connect to master", e)
System.exit(1)
}
}
@@ -180,11 +166,19 @@ private[spark] class Worker(
private[spark] object Worker {
def main(argStrings: Array[String]) {
val args = new WorkerArguments(argStrings)
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port)
- val actor = actorSystem.actorOf(
- Props(new Worker(args.ip, boundPort, args.webUiPort, args.cores, args.memory,
- args.master, args.workDir)),
- name = "Worker")
+ val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort, args.cores,
+ args.memory, args.master, args.workDir)
actorSystem.awaitTermination()
}
+
+ def startSystemAndActor(host: String, port: Int, webUiPort: Int, cores: Int, memory: Int,
+ masterUrl: String, workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = {
+ // The LocalSparkCluster runs multiple local sparkWorkerX actor systems
+ val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("")
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
+ val actor = actorSystem.actorOf(Props(new Worker(host, boundPort, webUiPort, cores, memory,
+ masterUrl, workDir)), name = "Worker")
+ (actorSystem, boundPort)
+ }
+
}
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index 908a22b2df..319eef6978 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -23,7 +23,16 @@ import util.{MetadataCleaner, TimeStampedHashMap}
* and to report fetch failures (the submitTasks method, and code to add CompletionEvents).
*/
private[spark]
-class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with Logging {
+class DAGScheduler(
+ taskSched: TaskScheduler,
+ mapOutputTracker: MapOutputTracker,
+ blockManagerMaster: BlockManagerMaster,
+ env: SparkEnv)
+ extends TaskSchedulerListener with Logging {
+
+ def this(taskSched: TaskScheduler) {
+ this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get)
+ }
taskSched.setListener(this)
// Called by TaskScheduler to report task completions or failures.
@@ -66,10 +75,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
var cacheLocs = new HashMap[Int, Array[List[String]]]
- val env = SparkEnv.get
- val mapOutputTracker = env.mapOutputTracker
- val blockManagerMaster = env.blockManager.master
-
// For tracking failed nodes, we use the MapOutputTracker's generation number, which is
// sent with every task. When we detect a node failing, we note the current generation number
// and failed executor, increment it for new tasks, and use this to ignore stray ShuffleMapTask
@@ -90,12 +95,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup)
// Start a thread to run the DAGScheduler event loop
- new Thread("DAGScheduler") {
- setDaemon(true)
- override def run() {
- DAGScheduler.this.run()
- }
- }.start()
+ def start() {
+ new Thread("DAGScheduler") {
+ setDaemon(true)
+ override def run() {
+ DAGScheduler.this.run()
+ }
+ }.start()
+ }
private def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
if (!cacheLocs.contains(rdd.id)) {
@@ -176,19 +183,16 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
def visit(rdd: RDD[_]) {
if (!visited(rdd)) {
visited += rdd
- val locs = getCacheLocs(rdd)
- for (p <- 0 until rdd.splits.size) {
- if (locs(p) == Nil) {
- for (dep <- rdd.dependencies) {
- dep match {
- case shufDep: ShuffleDependency[_,_] =>
- val mapStage = getShuffleMapStage(shufDep, stage.priority)
- if (!mapStage.isAvailable) {
- missing += mapStage
- }
- case narrowDep: NarrowDependency[_] =>
- visit(narrowDep.rdd)
- }
+ if (getCacheLocs(rdd).contains(Nil)) {
+ for (dep <- rdd.dependencies) {
+ dep match {
+ case shufDep: ShuffleDependency[_,_] =>
+ val mapStage = getShuffleMapStage(shufDep, stage.priority)
+ if (!mapStage.isAvailable) {
+ missing += mapStage
+ }
+ case narrowDep: NarrowDependency[_] =>
+ visit(narrowDep.rdd)
}
}
}
@@ -198,6 +202,29 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
missing.toList
}
+ /**
+ * Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a
+ * JobWaiter whose getResult() method will return the result of the job when it is complete.
+ *
+ * The job is assumed to have at least one partition; zero partition jobs should be handled
+ * without a JobSubmitted event.
+ */
+ private[scheduler] def prepareJob[T, U: ClassManifest](
+ finalRdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ partitions: Seq[Int],
+ callSite: String,
+ allowLocal: Boolean,
+ resultHandler: (Int, U) => Unit)
+ : (JobSubmitted, JobWaiter[U]) =
+ {
+ assert(partitions.size > 0)
+ val waiter = new JobWaiter(partitions.size, resultHandler)
+ val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
+ val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter)
+ return (toSubmit, waiter)
+ }
+
def runJob[T, U: ClassManifest](
finalRdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
@@ -209,9 +236,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
if (partitions.size == 0) {
return
}
- val waiter = new JobWaiter(partitions.size, resultHandler)
- val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
- eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter))
+ val (toSubmit, waiter) = prepareJob(
+ finalRdd, func, partitions, callSite, allowLocal, resultHandler)
+ eventQueue.put(toSubmit)
waiter.awaitResult() match {
case JobSucceeded => {}
case JobFailed(exception: Exception) =>
@@ -236,6 +263,84 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}
/**
+ * Process one event retrieved from the event queue.
+ * Returns true if we should stop the event loop.
+ */
+ private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
+ event match {
+ case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) =>
+ val runId = nextRunId.getAndIncrement()
+ val finalStage = newStage(finalRDD, None, runId)
+ val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
+ clearCacheLocs()
+ logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
+ " output partitions (allowLocal=" + allowLocal + ")")
+ logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")")
+ logInfo("Parents of final stage: " + finalStage.parents)
+ logInfo("Missing parents: " + getMissingParentStages(finalStage))
+ if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
+ // Compute very short actions like first() or take() with no parent stages locally.
+ runLocally(job)
+ } else {
+ activeJobs += job
+ resultStageToJob(finalStage) = job
+ submitStage(finalStage)
+ }
+
+ case ExecutorLost(execId) =>
+ handleExecutorLost(execId)
+
+ case completion: CompletionEvent =>
+ handleTaskCompletion(completion)
+
+ case TaskSetFailed(taskSet, reason) =>
+ abortStage(idToStage(taskSet.stageId), reason)
+
+ case StopDAGScheduler =>
+ // Cancel any active jobs
+ for (job <- activeJobs) {
+ val error = new SparkException("Job cancelled because SparkContext was shut down")
+ job.listener.jobFailed(error)
+ }
+ return true
+ }
+ return false
+ }
+
+ /**
+ * Resubmit any failed stages. Ordinarily called after a small amount of time has passed since
+ * the last fetch failure.
+ */
+ private[scheduler] def resubmitFailedStages() {
+ logInfo("Resubmitting failed stages")
+ clearCacheLocs()
+ val failed2 = failed.toArray
+ failed.clear()
+ for (stage <- failed2.sortBy(_.priority)) {
+ submitStage(stage)
+ }
+ }
+
+ /**
+ * Check for waiting or failed stages which are now eligible for resubmission.
+ * Ordinarily run on every iteration of the event loop.
+ */
+ private[scheduler] def submitWaitingStages() {
+ // TODO: We might want to run this less often, when we are sure that something has become
+ // runnable that wasn't before.
+ logTrace("Checking for newly runnable parent stages")
+ logTrace("running: " + running)
+ logTrace("waiting: " + waiting)
+ logTrace("failed: " + failed)
+ val waiting2 = waiting.toArray
+ waiting.clear()
+ for (stage <- waiting2.sortBy(_.priority)) {
+ submitStage(stage)
+ }
+ }
+
+
+ /**
* The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
* events and responds by launching tasks. This runs in a dedicated thread and receives events
* via the eventQueue.
@@ -245,77 +350,26 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
while (true) {
val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
- val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
if (event != null) {
logDebug("Got event of type " + event.getClass.getName)
}
- event match {
- case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) =>
- val runId = nextRunId.getAndIncrement()
- val finalStage = newStage(finalRDD, None, runId)
- val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
- clearCacheLocs()
- logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
- " output partitions")
- logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")")
- logInfo("Parents of final stage: " + finalStage.parents)
- logInfo("Missing parents: " + getMissingParentStages(finalStage))
- if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
- // Compute very short actions like first() or take() with no parent stages locally.
- runLocally(job)
- } else {
- activeJobs += job
- resultStageToJob(finalStage) = job
- submitStage(finalStage)
- }
-
- case ExecutorLost(execId) =>
- handleExecutorLost(execId)
-
- case completion: CompletionEvent =>
- handleTaskCompletion(completion)
-
- case TaskSetFailed(taskSet, reason) =>
- abortStage(idToStage(taskSet.stageId), reason)
-
- case StopDAGScheduler =>
- // Cancel any active jobs
- for (job <- activeJobs) {
- val error = new SparkException("Job cancelled because SparkContext was shut down")
- job.listener.jobFailed(error)
- }
+ if (event != null) {
+ if (processEvent(event)) {
return
-
- case null =>
- // queue.poll() timed out, ignore it
+ }
}
+ val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
// Periodically resubmit failed stages if some map output fetches have failed and we have
// waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails,
// tasks on many other nodes are bound to get a fetch failure, and they won't all get it at
// the same time, so we want to make sure we've identified all the reduce tasks that depend
// on the failed node.
if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
- logInfo("Resubmitting failed stages")
- clearCacheLocs()
- val failed2 = failed.toArray
- failed.clear()
- for (stage <- failed2.sortBy(_.priority)) {
- submitStage(stage)
- }
+ resubmitFailedStages()
} else {
- // TODO: We might want to run this less often, when we are sure that something has become
- // runnable that wasn't before.
- logTrace("Checking for newly runnable parent stages")
- logTrace("running: " + running)
- logTrace("waiting: " + waiting)
- logTrace("failed: " + failed)
- val waiting2 = waiting.toArray
- waiting.clear()
- for (stage <- waiting2.sortBy(_.priority)) {
- submitStage(stage)
- }
+ submitWaitingStages()
}
}
}
@@ -547,7 +601,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
if (!failedGeneration.contains(execId) || failedGeneration(execId) < currentGeneration) {
failedGeneration(execId) = currentGeneration
logInfo("Executor lost: %s (generation %d)".format(execId, currentGeneration))
- env.blockManager.master.removeExecutor(execId)
+ blockManagerMaster.removeExecutor(execId)
// TODO: This will be really slow if we keep accumulating shuffle map stages
for ((shuffleId, stage) <- shuffleToMapStage) {
stage.removeOutputsOnExecutor(execId)
diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala
index ddcd64d7c6..9ac875de3a 100644
--- a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala
@@ -1,5 +1,7 @@
package spark.scheduler.cluster
+import spark.Utils
+
/**
* A backend interface for cluster scheduling systems that allows plugging in different ones under
* ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as
@@ -11,5 +13,15 @@ private[spark] trait SchedulerBackend {
def reviveOffers(): Unit
def defaultParallelism(): Int
+ // Memory used by each executor (in megabytes)
+ protected val executorMemory = {
+ // TODO: Might need to add some extra memory for the non-heap parts of the JVM
+ Option(System.getProperty("spark.executor.memory"))
+ .orElse(Option(System.getenv("SPARK_MEM")))
+ .map(Utils.memoryStringToMb)
+ .getOrElse(512)
+ }
+
+
// TODO: Probably want to add a killTask too
}
diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index 9760d23072..59ff8bcb90 100644
--- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -20,16 +20,6 @@ private[spark] class SparkDeploySchedulerBackend(
val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt
- // Memory used by each executor (in megabytes)
- val executorMemory = {
- if (System.getenv("SPARK_MEM") != null) {
- Utils.memoryStringToMb(System.getenv("SPARK_MEM"))
- // TODO: Might need to add some extra memory for the non-heap parts of the JVM
- } else {
- 512
- }
- }
-
override def start() {
super.start()
diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
index 7bf56a05d6..b481ec0a72 100644
--- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
@@ -35,16 +35,6 @@ private[spark] class CoarseMesosSchedulerBackend(
val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures
- // Memory used by each executor (in megabytes)
- val executorMemory = {
- if (System.getenv("SPARK_MEM") != null) {
- Utils.memoryStringToMb(System.getenv("SPARK_MEM"))
- // TODO: Might need to add some extra memory for the non-heap parts of the JVM
- } else {
- 512
- }
- }
-
// Lock used to wait for scheduler to be registered
var isRegistered = false
val registeredLock = new Object()
diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
index eab1c60e0b..300766d0f5 100644
--- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
@@ -29,16 +29,6 @@ private[spark] class MesosSchedulerBackend(
with MScheduler
with Logging {
- // Memory used by each executor (in megabytes)
- val EXECUTOR_MEMORY = {
- if (System.getenv("SPARK_MEM") != null) {
- Utils.memoryStringToMb(System.getenv("SPARK_MEM"))
- // TODO: Might need to add some extra memory for the non-heap parts of the JVM
- } else {
- 512
- }
- }
-
// Lock used to wait for scheduler to be registered
var isRegistered = false
val registeredLock = new Object()
@@ -89,7 +79,7 @@ private[spark] class MesosSchedulerBackend(
val memory = Resource.newBuilder()
.setName("mem")
.setType(Value.Type.SCALAR)
- .setScalar(Value.Scalar.newBuilder().setValue(EXECUTOR_MEMORY).build())
+ .setScalar(Value.Scalar.newBuilder().setValue(executorMemory).build())
.build()
val command = CommandInfo.newBuilder()
.setValue(execScript)
@@ -161,7 +151,7 @@ private[spark] class MesosSchedulerBackend(
def enoughMemory(o: Offer) = {
val mem = getResource(o.getResourcesList, "mem")
val slaveId = o.getSlaveId.getValue
- mem >= EXECUTOR_MEMORY || slaveIdsWithExecutors.contains(slaveId)
+ mem >= executorMemory || slaveIdsWithExecutors.contains(slaveId)
}
for ((offer, index) <- offers.zipWithIndex if enoughMemory(offer)) {
diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
index 36398095a2..7389bee150 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
@@ -27,8 +27,6 @@ private[spark] class BlockManagerMaster(
val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt
val DRIVER_AKKA_ACTOR_NAME = "BlockMasterManager"
- val SLAVE_AKKA_ACTOR_NAME = "BlockSlaveManager"
- val DEFAULT_MANAGER_IP: String = Utils.localHostName()
val timeout = 10.seconds
var driverActor: ActorRef = {
@@ -117,6 +115,10 @@ private[spark] class BlockManagerMaster(
askDriverWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus)
}
+ def getStorageStatus: Array[StorageStatus] = {
+ askDriverWithReply[ArrayBuffer[StorageStatus]](GetStorageStatus).toArray
+ }
+
/** Stop the driver actor, called only on the Spark driver node */
def stop() {
if (driverActor != null) {
diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala
index eda320fa47..9e6721ec17 100644
--- a/core/src/main/scala/spark/storage/BlockManagerUI.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala
@@ -1,13 +1,10 @@
package spark.storage
import akka.actor.{ActorRef, ActorSystem}
-import akka.pattern.ask
import akka.util.Timeout
import akka.util.duration._
-import cc.spray.directives._
import cc.spray.typeconversion.TwirlSupport._
import cc.spray.Directives
-import scala.collection.mutable.ArrayBuffer
import spark.{Logging, SparkContext}
import spark.util.AkkaUtils
import spark.Utils
@@ -48,32 +45,26 @@ class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef,
path("") {
completeWith {
// Request the current storage status from the Master
- val future = blockManagerMaster ? GetStorageStatus
- future.map { status =>
- // Calculate macro-level statistics
- val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray
- val maxMem = storageStatusList.map(_.maxMem).reduce(_+_)
- val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_)
- val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize))
- .reduceOption(_+_).getOrElse(0L)
- val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc)
- spark.storage.html.index.
- render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList)
- }
+ val storageStatusList = sc.getExecutorStorageStatus
+ // Calculate macro-level statistics
+ val maxMem = storageStatusList.map(_.maxMem).reduce(_+_)
+ val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_)
+ val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize))
+ .reduceOption(_+_).getOrElse(0L)
+ val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc)
+ spark.storage.html.index.
+ render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList)
}
} ~
path("rdd") {
parameter("id") { id =>
completeWith {
- val future = blockManagerMaster ? GetStorageStatus
- future.map { status =>
- val prefix = "rdd_" + id.toString
- val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray
- val filteredStorageStatusList = StorageUtils.
- filterStorageStatusByPrefix(storageStatusList, prefix)
- val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head
- spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList)
- }
+ val prefix = "rdd_" + id.toString
+ val storageStatusList = sc.getExecutorStorageStatus
+ val filteredStorageStatusList = StorageUtils.
+ filterStorageStatusByPrefix(storageStatusList, prefix)
+ val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head
+ spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList)
}
}
} ~
diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala
index a10e3a95c6..5f72b67b2c 100644
--- a/core/src/main/scala/spark/storage/StorageUtils.scala
+++ b/core/src/main/scala/spark/storage/StorageUtils.scala
@@ -1,6 +1,6 @@
package spark.storage
-import spark.SparkContext
+import spark.{Utils, SparkContext}
import BlockManagerMasterActor.BlockStatus
private[spark]
@@ -22,8 +22,13 @@ case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
}
case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel,
- numPartitions: Int, memSize: Long, diskSize: Long)
-
+ numCachedPartitions: Int, numPartitions: Int, memSize: Long, diskSize: Long) {
+ override def toString = {
+ import Utils.memoryBytesToString
+ "RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s; DiskSize: %s".format(name, id,
+ storageLevel.toString, numCachedPartitions, numPartitions, memoryBytesToString(memSize), memoryBytesToString(diskSize))
+ }
+}
/* Helper methods for storage-related objects */
private[spark]
@@ -38,8 +43,6 @@ object StorageUtils {
/* Given a list of BlockStatus objets, returns information for each RDD */
def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus],
sc: SparkContext) : Array[RDDInfo] = {
- // Find all RDD Blocks (ignore broadcast variables)
- val rddBlocks = infos.filterKeys(_.startsWith("rdd"))
// Group by rddId, ignore the partition name
val groupedRddBlocks = infos.groupBy { case(k, v) =>
@@ -56,10 +59,11 @@ object StorageUtils {
// Find the id of the RDD, e.g. rdd_1 => 1
val rddId = rddKey.split("_").last.toInt
// Get the friendly name for the rdd, if available.
- val rddName = Option(sc.persistentRdds(rddId).name).getOrElse(rddKey)
- val rddStorageLevel = sc.persistentRdds(rddId).getStorageLevel
-
- RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, memSize, diskSize)
+ val rdd = sc.persistentRdds(rddId)
+ val rddName = Option(rdd.name).getOrElse(rddKey)
+ val rddStorageLevel = rdd.getStorageLevel
+
+ RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, rdd.splits.size, memSize, diskSize)
}.toArray
}
@@ -75,4 +79,4 @@ object StorageUtils {
}
-} \ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala
index e43fbd6b1c..30aec5a663 100644
--- a/core/src/main/scala/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/spark/util/AkkaUtils.scala
@@ -18,9 +18,13 @@ import java.util.concurrent.TimeoutException
* Various utility classes for working with Akka.
*/
private[spark] object AkkaUtils {
+
/**
* Creates an ActorSystem ready for remoting, with various Spark features. Returns both the
* ActorSystem itself and its port (which is hard to get from Akka).
+ *
+ * Note: the `name` parameter is important, as even if a client sends a message to right
+ * host + port, if the system name is incorrect, Akka will drop the message.
*/
def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = {
val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt
@@ -42,7 +46,7 @@ private[spark] object AkkaUtils {
akka.actor.default-dispatcher.throughput = %d
""".format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize))
- val actorSystem = ActorSystem("spark", akkaConf, getClass.getClassLoader)
+ val actorSystem = ActorSystem(name, akkaConf, getClass.getClassLoader)
// Figure out the port number we bound to, in case port was passed as 0. This is a bit of a
// hack because Akka doesn't let you figure out the port through the public API yet.
diff --git a/core/src/main/twirl/spark/storage/rdd.scala.html b/core/src/main/twirl/spark/storage/rdd.scala.html
index ac7f8c981f..d85addeb17 100644
--- a/core/src/main/twirl/spark/storage/rdd.scala.html
+++ b/core/src/main/twirl/spark/storage/rdd.scala.html
@@ -11,7 +11,11 @@
<strong>Storage Level:</strong>
@(rddInfo.storageLevel.description)
<li>
- <strong>Partitions:</strong>
+ <strong>Cached Partitions:</strong>
+ @(rddInfo.numCachedPartitions)
+ </li>
+ <li>
+ <strong>Total Partitions:</strong>
@(rddInfo.numPartitions)
</li>
<li>
diff --git a/core/src/main/twirl/spark/storage/rdd_table.scala.html b/core/src/main/twirl/spark/storage/rdd_table.scala.html
index af801cf229..a51e64aed0 100644
--- a/core/src/main/twirl/spark/storage/rdd_table.scala.html
+++ b/core/src/main/twirl/spark/storage/rdd_table.scala.html
@@ -6,7 +6,8 @@
<tr>
<th>RDD Name</th>
<th>Storage Level</th>
- <th>Partitions</th>
+ <th>Cached Partitions</th>
+ <th>Fraction Partitions Cached</th>
<th>Size in Memory</th>
<th>Size on Disk</th>
</tr>
@@ -21,7 +22,8 @@
</td>
<td>@(rdd.storageLevel.description)
</td>
- <td>@rdd.numPartitions</td>
+ <td>@rdd.numCachedPartitions</td>
+ <td>@(rdd.numCachedPartitions / rdd.numPartitions.toDouble)</td>
<td>@{Utils.memoryBytesToString(rdd.memSize)}</td>
<td>@{Utils.memoryBytesToString(rdd.diskSize)}</td>
</tr>
diff --git a/core/src/test/scala/spark/DriverSuite.scala b/core/src/test/scala/spark/DriverSuite.scala
index 342610e1dd..5e84b3a66a 100644
--- a/core/src/test/scala/spark/DriverSuite.scala
+++ b/core/src/test/scala/spark/DriverSuite.scala
@@ -9,10 +9,11 @@ import org.scalatest.time.SpanSugar._
class DriverSuite extends FunSuite with Timeouts {
test("driver should exit after finishing") {
+ assert(System.getenv("SPARK_HOME") != null)
// Regression test for SPARK-530: "Spark driver process doesn't exit after finishing"
val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]"))
forAll(masters) { (master: String) =>
- failAfter(10 seconds) {
+ failAfter(30 seconds) {
Utils.execute(Seq("./run", "spark.DriverWithoutCleanup", master),
new File(System.getenv("SPARK_HOME")))
}
diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
index f4e7ec39fe..dd19442dcb 100644
--- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
@@ -79,8 +79,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("remote fetch") {
try {
System.clearProperty("spark.driver.host") // In case some previous test had set it
- val (actorSystem, boundPort) =
- AkkaUtils.createActorSystem("test", "localhost", 0)
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", "localhost", 0)
System.setProperty("spark.driver.port", boundPort.toString)
val masterTracker = new MapOutputTracker(actorSystem, true)
val slaveTracker = new MapOutputTracker(actorSystem, false)
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index 89a3687386..fe7deb10d6 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -14,7 +14,7 @@ class RDDSuite extends FunSuite with LocalSparkContext {
val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2)
assert(dups.distinct().count() === 4)
assert(dups.distinct.count === 4) // Can distinct and count be called without parentheses?
- assert(dups.distinct().collect === dups.distinct().collect)
+ assert(dups.distinct.collect === dups.distinct().collect)
assert(dups.distinct(2).collect === dups.distinct().collect)
assert(nums.reduce(_ + _) === 10)
assert(nums.fold(0)(_ + _) === 10)
diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
new file mode 100644
index 0000000000..83663ac702
--- /dev/null
+++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
@@ -0,0 +1,663 @@
+package spark.scheduler
+
+import scala.collection.mutable.{Map, HashMap}
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+import org.scalatest.concurrent.TimeLimitedTests
+import org.scalatest.mock.EasyMockSugar
+import org.scalatest.time.{Span, Seconds}
+
+import org.easymock.EasyMock._
+import org.easymock.Capture
+import org.easymock.EasyMock
+import org.easymock.{IAnswer, IArgumentMatcher}
+
+import akka.actor.ActorSystem
+
+import spark.storage.BlockManager
+import spark.storage.BlockManagerId
+import spark.storage.BlockManagerMaster
+import spark.{Dependency, ShuffleDependency, OneToOneDependency}
+import spark.FetchFailedException
+import spark.MapOutputTracker
+import spark.RDD
+import spark.SparkContext
+import spark.SparkException
+import spark.Split
+import spark.TaskContext
+import spark.TaskEndReason
+
+import spark.{FetchFailed, Success}
+
+/**
+ * Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler
+ * rather than spawning an event loop thread as happens in the real code. They use EasyMock
+ * to mock out two classes that DAGScheduler interacts with: TaskScheduler (to which TaskSets are
+ * submitted) and BlockManagerMaster (from which cache locations are retrieved and to which dead
+ * host notifications are sent). In addition, tests may check for side effects on a non-mocked
+ * MapOutputTracker instance.
+ *
+ * Tests primarily consist of running DAGScheduler#processEvent and
+ * DAGScheduler#submitWaitingStages (via test utility functions like runEvent or respondToTaskSet)
+ * and capturing the resulting TaskSets from the mock TaskScheduler.
+ */
+class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar with TimeLimitedTests {
+
+ // impose a time limit on this test in case we don't let the job finish, in which case
+ // JobWaiter#getResult will hang.
+ override val timeLimit = Span(5, Seconds)
+
+ val sc: SparkContext = new SparkContext("local", "DAGSchedulerSuite")
+ var scheduler: DAGScheduler = null
+ val taskScheduler = mock[TaskScheduler]
+ val blockManagerMaster = mock[BlockManagerMaster]
+ var mapOutputTracker: MapOutputTracker = null
+ var schedulerThread: Thread = null
+ var schedulerException: Throwable = null
+
+ /**
+ * Set of EasyMock argument matchers that match a TaskSet for a given RDD.
+ * We cache these so we do not create duplicate matchers for the same RDD.
+ * This allows us to easily setup a sequence of expectations for task sets for
+ * that RDD.
+ */
+ val taskSetMatchers = new HashMap[MyRDD, IArgumentMatcher]
+
+ /**
+ * Set of cache locations to return from our mock BlockManagerMaster.
+ * Keys are (rdd ID, partition ID). Anything not present will return an empty
+ * list of cache locations silently.
+ */
+ val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]
+
+ /**
+ * JobWaiter for the last JobSubmitted event we pushed. To keep tests (most of which
+ * will only submit one job) from needing to explicitly track it.
+ */
+ var lastJobWaiter: JobWaiter[Int] = null
+
+ /**
+ * Array into which we are accumulating the results from the last job asynchronously.
+ */
+ var lastJobResult: Array[Int] = null
+
+ /**
+ * Tell EasyMockSugar what mock objects we want to be configured by expecting {...}
+ * and whenExecuting {...} */
+ implicit val mocks = MockObjects(taskScheduler, blockManagerMaster)
+
+ /**
+ * Utility function to reset mocks and set expectations on them. EasyMock wants mock objects
+ * to be reset after each time their expectations are set, and we tend to check mock object
+ * calls over a single call to DAGScheduler.
+ *
+ * We also set a default expectation here that blockManagerMaster.getLocations can be called
+ * and will return values from cacheLocations.
+ */
+ def resetExpecting(f: => Unit) {
+ reset(taskScheduler)
+ reset(blockManagerMaster)
+ expecting {
+ expectGetLocations()
+ f
+ }
+ }
+
+ before {
+ taskSetMatchers.clear()
+ cacheLocations.clear()
+ val actorSystem = ActorSystem("test")
+ mapOutputTracker = new MapOutputTracker(actorSystem, true)
+ resetExpecting {
+ taskScheduler.setListener(anyObject())
+ }
+ whenExecuting {
+ scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null)
+ }
+ }
+
+ after {
+ assert(scheduler.processEvent(StopDAGScheduler))
+ resetExpecting {
+ taskScheduler.stop()
+ }
+ whenExecuting {
+ scheduler.stop()
+ }
+ sc.stop()
+ System.clearProperty("spark.master.port")
+ }
+
+ def makeBlockManagerId(host: String): BlockManagerId =
+ BlockManagerId("exec-" + host, host, 12345)
+
+ /**
+ * Type of RDD we use for testing. Note that we should never call the real RDD compute methods.
+ * This is a pair RDD type so it can always be used in ShuffleDependencies.
+ */
+ type MyRDD = RDD[(Int, Int)]
+
+ /**
+ * Create an RDD for passing to DAGScheduler. These RDDs will use the dependencies and
+ * preferredLocations (if any) that are passed to them. They are deliberately not executable
+ * so we can test that DAGScheduler does not try to execute RDDs locally.
+ */
+ def makeRdd(
+ numSplits: Int,
+ dependencies: List[Dependency[_]],
+ locations: Seq[Seq[String]] = Nil
+ ): MyRDD = {
+ val maxSplit = numSplits - 1
+ return new MyRDD(sc, dependencies) {
+ override def compute(split: Split, context: TaskContext): Iterator[(Int, Int)] =
+ throw new RuntimeException("should not be reached")
+ override def getSplits() = (0 to maxSplit).map(i => new Split {
+ override def index = i
+ }).toArray
+ override def getPreferredLocations(split: Split): Seq[String] =
+ if (locations.isDefinedAt(split.index))
+ locations(split.index)
+ else
+ Nil
+ override def toString: String = "DAGSchedulerSuiteRDD " + id
+ }
+ }
+
+ /**
+ * EasyMock matcher method. For use as an argument matcher for a TaskSet whose first task
+ * is from a particular RDD.
+ */
+ def taskSetForRdd(rdd: MyRDD): TaskSet = {
+ val matcher = taskSetMatchers.getOrElseUpdate(rdd,
+ new IArgumentMatcher {
+ override def matches(actual: Any): Boolean = {
+ val taskSet = actual.asInstanceOf[TaskSet]
+ taskSet.tasks(0) match {
+ case rt: ResultTask[_, _] => rt.rdd.id == rdd.id
+ case smt: ShuffleMapTask => smt.rdd.id == rdd.id
+ case _ => false
+ }
+ }
+ override def appendTo(buf: StringBuffer) {
+ buf.append("taskSetForRdd(" + rdd + ")")
+ }
+ })
+ EasyMock.reportMatcher(matcher)
+ return null
+ }
+
+ /**
+ * Setup an EasyMock expectation to repsond to blockManagerMaster.getLocations() called from
+ * cacheLocations.
+ */
+ def expectGetLocations(): Unit = {
+ EasyMock.expect(blockManagerMaster.getLocations(anyObject().asInstanceOf[Array[String]])).
+ andAnswer(new IAnswer[Seq[Seq[BlockManagerId]]] {
+ override def answer(): Seq[Seq[BlockManagerId]] = {
+ val blocks = getCurrentArguments()(0).asInstanceOf[Array[String]]
+ return blocks.map { name =>
+ val pieces = name.split("_")
+ if (pieces(0) == "rdd") {
+ val key = pieces(1).toInt -> pieces(2).toInt
+ if (cacheLocations.contains(key)) {
+ cacheLocations(key)
+ } else {
+ Seq[BlockManagerId]()
+ }
+ } else {
+ Seq[BlockManagerId]()
+ }
+ }.toSeq
+ }
+ }).anyTimes()
+ }
+
+ /**
+ * Process the supplied event as if it were the top of the DAGScheduler event queue, expecting
+ * the scheduler not to exit.
+ *
+ * After processing the event, submit waiting stages as is done on most iterations of the
+ * DAGScheduler event loop.
+ */
+ def runEvent(event: DAGSchedulerEvent) {
+ assert(!scheduler.processEvent(event))
+ scheduler.submitWaitingStages()
+ }
+
+ /**
+ * Expect a TaskSet for the specified RDD to be submitted to the TaskScheduler. Should be
+ * called from a resetExpecting { ... } block.
+ *
+ * Returns a easymock Capture that will contain the task set after the stage is submitted.
+ * Most tests should use interceptStage() instead of this directly.
+ */
+ def expectStage(rdd: MyRDD): Capture[TaskSet] = {
+ val taskSetCapture = new Capture[TaskSet]
+ taskScheduler.submitTasks(and(capture(taskSetCapture), taskSetForRdd(rdd)))
+ return taskSetCapture
+ }
+
+ /**
+ * Expect the supplied code snippet to submit a stage for the specified RDD.
+ * Return the resulting TaskSet. First marks all the tasks are belonging to the
+ * current MapOutputTracker generation.
+ */
+ def interceptStage(rdd: MyRDD)(f: => Unit): TaskSet = {
+ var capture: Capture[TaskSet] = null
+ resetExpecting {
+ capture = expectStage(rdd)
+ }
+ whenExecuting {
+ f
+ }
+ val taskSet = capture.getValue
+ for (task <- taskSet.tasks) {
+ task.generation = mapOutputTracker.getGeneration
+ }
+ return taskSet
+ }
+
+ /**
+ * Send the given CompletionEvent messages for the tasks in the TaskSet.
+ */
+ def respondToTaskSet(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) {
+ assert(taskSet.tasks.size >= results.size)
+ for ((result, i) <- results.zipWithIndex) {
+ if (i < taskSet.tasks.size) {
+ runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any]()))
+ }
+ }
+ }
+
+ /**
+ * Assert that the supplied TaskSet has exactly the given preferredLocations.
+ */
+ def expectTaskSetLocations(taskSet: TaskSet, locations: Seq[Seq[String]]) {
+ assert(locations.size === taskSet.tasks.size)
+ for ((expectLocs, taskLocs) <-
+ taskSet.tasks.map(_.preferredLocations).zip(locations)) {
+ assert(expectLocs === taskLocs)
+ }
+ }
+
+ /**
+ * When we submit dummy Jobs, this is the compute function we supply. Except in a local test
+ * below, we do not expect this function to ever be executed; instead, we will return results
+ * directly through CompletionEvents.
+ */
+ def jobComputeFunc(context: TaskContext, it: Iterator[(Int, Int)]): Int =
+ it.next._1.asInstanceOf[Int]
+
+
+ /**
+ * Start a job to compute the given RDD. Returns the JobWaiter that will
+ * collect the result of the job via callbacks from DAGScheduler.
+ */
+ def submitRdd(rdd: MyRDD, allowLocal: Boolean = false): (JobWaiter[Int], Array[Int]) = {
+ val resultArray = new Array[Int](rdd.splits.size)
+ val (toSubmit, waiter) = scheduler.prepareJob[(Int, Int), Int](
+ rdd,
+ jobComputeFunc,
+ (0 to (rdd.splits.size - 1)),
+ "test-site",
+ allowLocal,
+ (i: Int, value: Int) => resultArray(i) = value
+ )
+ lastJobWaiter = waiter
+ lastJobResult = resultArray
+ runEvent(toSubmit)
+ return (waiter, resultArray)
+ }
+
+ /**
+ * Assert that a job we started has failed.
+ */
+ def expectJobException(waiter: JobWaiter[Int] = lastJobWaiter) {
+ waiter.awaitResult() match {
+ case JobSucceeded => fail()
+ case JobFailed(_) => return
+ }
+ }
+
+ /**
+ * Assert that a job we started has succeeded and has the given result.
+ */
+ def expectJobResult(expected: Array[Int], waiter: JobWaiter[Int] = lastJobWaiter,
+ result: Array[Int] = lastJobResult) {
+ waiter.awaitResult match {
+ case JobSucceeded =>
+ assert(expected === result)
+ case JobFailed(_) =>
+ fail()
+ }
+ }
+
+ def makeMapStatus(host: String, reduces: Int): MapStatus =
+ new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2))
+
+ test("zero split job") {
+ val rdd = makeRdd(0, Nil)
+ var numResults = 0
+ def accumulateResult(partition: Int, value: Int) {
+ numResults += 1
+ }
+ scheduler.runJob(rdd, jobComputeFunc, Seq(), "test-site", false, accumulateResult)
+ assert(numResults === 0)
+ }
+
+ test("run trivial job") {
+ val rdd = makeRdd(1, Nil)
+ val taskSet = interceptStage(rdd) { submitRdd(rdd) }
+ respondToTaskSet(taskSet, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("local job") {
+ val rdd = new MyRDD(sc, Nil) {
+ override def compute(split: Split, context: TaskContext): Iterator[(Int, Int)] =
+ Array(42 -> 0).iterator
+ override def getSplits() = Array( new Split { override def index = 0 } )
+ override def getPreferredLocations(split: Split) = Nil
+ override def toString = "DAGSchedulerSuite Local RDD"
+ }
+ submitRdd(rdd, true)
+ expectJobResult(Array(42))
+ }
+
+ test("run trivial job w/ dependency") {
+ val baseRdd = makeRdd(1, Nil)
+ val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
+ val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) }
+ respondToTaskSet(taskSet, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("cache location preferences w/ dependency") {
+ val baseRdd = makeRdd(1, Nil)
+ val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
+ cacheLocations(baseRdd.id -> 0) =
+ Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
+ val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) }
+ expectTaskSetLocations(taskSet, List(Seq("hostA", "hostB")))
+ respondToTaskSet(taskSet, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("trivial job failure") {
+ val rdd = makeRdd(1, Nil)
+ val taskSet = interceptStage(rdd) { submitRdd(rdd) }
+ runEvent(TaskSetFailed(taskSet, "test failure"))
+ expectJobException()
+ }
+
+ test("run trivial shuffle") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(1, List(shuffleDep))
+
+ val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+ val secondStage = interceptStage(reduceRdd) {
+ respondToTaskSet(firstStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
+ respondToTaskSet(secondStage, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("run trivial shuffle with fetch failure") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(2, List(shuffleDep))
+
+ val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+ val secondStage = interceptStage(reduceRdd) {
+ respondToTaskSet(firstStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ respondToTaskSet(secondStage, List(
+ (Success, 42),
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null)
+ ))
+ }
+ val thirdStage = interceptStage(shuffleMapRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ val fourthStage = interceptStage(reduceRdd) {
+ respondToTaskSet(thirdStage, List( (Success, makeMapStatus("hostA", 1)) ))
+ }
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
+ respondToTaskSet(fourthStage, List( (Success, 43) ))
+ expectJobResult(Array(42, 43))
+ }
+
+ test("ignore late map task completions") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(2, List(shuffleDep))
+
+ val taskSet = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+ val oldGeneration = mapOutputTracker.getGeneration
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ runEvent(ExecutorLost("exec-hostA"))
+ }
+ val newGeneration = mapOutputTracker.getGeneration
+ assert(newGeneration > oldGeneration)
+ val noAccum = Map[Long, Any]()
+ // We rely on the event queue being ordered and increasing the generation number by 1
+ // should be ignored for being too old
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum))
+ // should work because it's a non-failed host
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum))
+ // should be ignored for being too old
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum))
+ taskSet.tasks(1).generation = newGeneration
+ val secondStage = interceptStage(reduceRdd) {
+ runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum))
+ }
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
+ respondToTaskSet(secondStage, List( (Success, 42), (Success, 43) ))
+ expectJobResult(Array(42, 43))
+ }
+
+ test("run trivial shuffle with out-of-band failure and retry") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(1, List(shuffleDep))
+
+ val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ runEvent(ExecutorLost("exec-hostA"))
+ }
+ // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks
+ // rather than marking it is as failed and waiting.
+ val secondStage = interceptStage(shuffleMapRdd) {
+ respondToTaskSet(firstStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ val thirdStage = interceptStage(reduceRdd) {
+ respondToTaskSet(secondStage, List(
+ (Success, makeMapStatus("hostC", 1))
+ ))
+ }
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
+ respondToTaskSet(thirdStage, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("recursive shuffle failures") {
+ val shuffleOneRdd = makeRdd(2, Nil)
+ val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
+ val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
+ val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
+ val finalRdd = makeRdd(1, List(shuffleDepTwo))
+
+ val firstStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
+ val secondStage = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(firstStage, List(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))
+ ))
+ }
+ val thirdStage = interceptStage(finalRdd) {
+ respondToTaskSet(secondStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostC", 1))
+ ))
+ }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ respondToTaskSet(thirdStage, List(
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
+ ))
+ }
+ val recomputeOne = interceptStage(shuffleOneRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ val recomputeTwo = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(recomputeOne, List(
+ (Success, makeMapStatus("hostA", 2))
+ ))
+ }
+ val finalStage = interceptStage(finalRdd) {
+ respondToTaskSet(recomputeTwo, List(
+ (Success, makeMapStatus("hostA", 1))
+ ))
+ }
+ respondToTaskSet(finalStage, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("cached post-shuffle") {
+ val shuffleOneRdd = makeRdd(2, Nil)
+ val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
+ val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
+ val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
+ val finalRdd = makeRdd(1, List(shuffleDepTwo))
+
+ val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
+ cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
+ cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
+ val secondShuffleStage = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(firstShuffleStage, List(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))
+ ))
+ }
+ val reduceStage = interceptStage(finalRdd) {
+ respondToTaskSet(secondShuffleStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ respondToTaskSet(reduceStage, List(
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
+ ))
+ }
+ // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun.
+ val recomputeTwo = interceptStage(shuffleTwoRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ expectTaskSetLocations(recomputeTwo, Seq(Seq("hostD")))
+ val finalRetry = interceptStage(finalRdd) {
+ respondToTaskSet(recomputeTwo, List(
+ (Success, makeMapStatus("hostD", 1))
+ ))
+ }
+ respondToTaskSet(finalRetry, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("cached post-shuffle but fails") {
+ val shuffleOneRdd = makeRdd(2, Nil)
+ val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
+ val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
+ val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
+ val finalRdd = makeRdd(1, List(shuffleDepTwo))
+
+ val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
+ cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
+ cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
+ val secondShuffleStage = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(firstShuffleStage, List(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))
+ ))
+ }
+ val reduceStage = interceptStage(finalRdd) {
+ respondToTaskSet(secondShuffleStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ respondToTaskSet(reduceStage, List(
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
+ ))
+ }
+ val recomputeTwoCached = interceptStage(shuffleTwoRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ expectTaskSetLocations(recomputeTwoCached, Seq(Seq("hostD")))
+ intercept[FetchFailedException]{
+ mapOutputTracker.getServerStatuses(shuffleDepOne.shuffleId, 0)
+ }
+
+ // Simulate the shuffle input data failing to be cached.
+ cacheLocations.remove(shuffleTwoRdd.id -> 0)
+ respondToTaskSet(recomputeTwoCached, List(
+ (FetchFailed(null, shuffleDepOne.shuffleId, 0, 0), null)
+ ))
+
+ // After the fetch failure, DAGScheduler should recheck the cache and decide to resubmit
+ // everything.
+ val recomputeOne = interceptStage(shuffleOneRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ // We use hostA here to make sure DAGScheduler doesn't think it's still dead.
+ val recomputeTwoUncached = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(recomputeOne, List( (Success, makeMapStatus("hostA", 1)) ))
+ }
+ expectTaskSetLocations(recomputeTwoUncached, Seq(Seq[String]()))
+ val finalRetry = interceptStage(finalRdd) {
+ respondToTaskSet(recomputeTwoUncached, List( (Success, makeMapStatus("hostA", 1)) ))
+
+ }
+ respondToTaskSet(finalRetry, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+}