author    Matei Zaharia <matei@eecs.berkeley.edu>    2013-02-02 19:09:59 -0800
committer Matei Zaharia <matei@eecs.berkeley.edu>    2013-02-02 19:09:59 -0800
commit    85019d76a44df9c4d19592da1484f925ef5b5b56 (patch)
tree      75d3ea27bae9d4a72461cdd651c7acb1cfe364c6
parent    ae26911ec0d768dcdae8b7d706ca4544e36535e6 (diff)
parent    610795796257ff63e6a5ac0473b183de461a72d4 (diff)
Merge pull request #427 from woggling/dag-sched-tests
Tests for DAGScheduler
-rw-r--r--  core/pom.xml                                                   5
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala                   1
-rw-r--r--  core/src/main/scala/spark/scheduler/DAGScheduler.scala       197
-rw-r--r--  core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala  663
-rw-r--r--  pom.xml                                                        6
-rw-r--r--  project/SparkBuild.scala                                       3
6 files changed, 802 insertions(+), 73 deletions(-)
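
The patch makes DAGScheduler unit-testable: its collaborators (TaskScheduler, MapOutputTracker, BlockManagerMaster, SparkEnv) become constructor parameters, the event-loop thread is only started by an explicit start() call, and the loop body is split into prepareJob / processEvent / submitWaitingStages / resubmitFailedStages so a test can step through it synchronously. A minimal sketch of that usage, assuming test doubles named taskScheduler, blockManagerMaster, mapOutputTracker and env plus an rdd fixture as in the new DAGSchedulerSuite (illustrative only, not part of the patch):

    // Construct the scheduler with injected dependencies; no event-loop thread is started.
    val scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, env)

    val results = new Array[Int](rdd.splits.size)
    val (jobSubmitted, waiter) = scheduler.prepareJob[(Int, Int), Int](
      rdd,                                  // RDD under test; compute() is never called here
      (ctx, it) => it.next._1,              // dummy per-partition function
      0 until rdd.splits.size,              // run all partitions
      "test-site",                          // call site string used in log messages
      false,                                // allowLocal
      (i, v) => results(i) = v)             // collect each partition's result by index

    // processEvent returns true only for StopDAGScheduler; false means "keep looping".
    assert(!scheduler.processEvent(jobSubmitted))
    // What the real event loop does on each iteration after handling an event.
    scheduler.submitWaitingStages()
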
diff --git a/core/pom.xml b/core/pom.xml
index 873e8a1d0f..66c62151fe 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -99,6 +99,11 @@
<scope>test</scope>
</dependency>
<dependency>
+ <groupId>org.easymock</groupId>
+ <artifactId>easymock</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>com.novocode</groupId>
<artifactId>junit-interface</artifactId>
<scope>test</scope>
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index ddbf8f95d9..2ed458c6fe 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -187,6 +187,7 @@ class SparkContext(
taskScheduler.start()
private var dagScheduler = new DAGScheduler(taskScheduler)
+ dagScheduler.start()
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
val hadoopConfiguration = {
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index 908a22b2df..8cfc08e5ac 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -23,7 +23,16 @@ import util.{MetadataCleaner, TimeStampedHashMap}
* and to report fetch failures (the submitTasks method, and code to add CompletionEvents).
*/
private[spark]
-class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with Logging {
+class DAGScheduler(
+ taskSched: TaskScheduler,
+ mapOutputTracker: MapOutputTracker,
+ blockManagerMaster: BlockManagerMaster,
+ env: SparkEnv)
+ extends TaskSchedulerListener with Logging {
+
+ def this(taskSched: TaskScheduler) {
+ this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get)
+ }
taskSched.setListener(this)
// Called by TaskScheduler to report task completions or failures.
@@ -66,10 +75,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
var cacheLocs = new HashMap[Int, Array[List[String]]]
- val env = SparkEnv.get
- val mapOutputTracker = env.mapOutputTracker
- val blockManagerMaster = env.blockManager.master
-
// For tracking failed nodes, we use the MapOutputTracker's generation number, which is
// sent with every task. When we detect a node failing, we note the current generation number
// and failed executor, increment it for new tasks, and use this to ignore stray ShuffleMapTask
@@ -90,12 +95,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup)
// Start a thread to run the DAGScheduler event loop
- new Thread("DAGScheduler") {
- setDaemon(true)
- override def run() {
- DAGScheduler.this.run()
- }
- }.start()
+ def start() {
+ new Thread("DAGScheduler") {
+ setDaemon(true)
+ override def run() {
+ DAGScheduler.this.run()
+ }
+ }.start()
+ }
private def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
if (!cacheLocs.contains(rdd.id)) {
@@ -198,6 +205,28 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
missing.toList
}
+ /** Returns (and does not submit) a JobSubmitted event suitable to run a given job, and
+ * a JobWaiter whose getResult() method will return the result of the job when it is complete.
+ *
+ * The job is assumed to have at least one partition; zero partition jobs should be handled
+ * without a JobSubmitted event.
+ */
+ private[scheduler] def prepareJob[T, U: ClassManifest](
+ finalRdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ partitions: Seq[Int],
+ callSite: String,
+ allowLocal: Boolean,
+ resultHandler: (Int, U) => Unit)
+ : (JobSubmitted, JobWaiter[U]) =
+ {
+ assert(partitions.size > 0)
+ val waiter = new JobWaiter(partitions.size, resultHandler)
+ val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
+ val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter)
+ return (toSubmit, waiter)
+ }
+
def runJob[T, U: ClassManifest](
finalRdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
@@ -209,9 +238,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
if (partitions.size == 0) {
return
}
- val waiter = new JobWaiter(partitions.size, resultHandler)
- val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
- eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter))
+ val (toSubmit, waiter) = prepareJob(
+ finalRdd, func, partitions, callSite, allowLocal, resultHandler)
+ eventQueue.put(toSubmit)
waiter.awaitResult() match {
case JobSucceeded => {}
case JobFailed(exception: Exception) =>
@@ -235,6 +264,81 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
return listener.awaitResult() // Will throw an exception if the job fails
}
+ /** Process one event retrieved from the event queue.
+ * Returns true if we should stop the event loop.
+ */
+ private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
+ event match {
+ case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) =>
+ val runId = nextRunId.getAndIncrement()
+ val finalStage = newStage(finalRDD, None, runId)
+ val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
+ clearCacheLocs()
+ logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
+ " output partitions (allowLocal=" + allowLocal + ")")
+ logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")")
+ logInfo("Parents of final stage: " + finalStage.parents)
+ logInfo("Missing parents: " + getMissingParentStages(finalStage))
+ if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
+ // Compute very short actions like first() or take() with no parent stages locally.
+ runLocally(job)
+ } else {
+ activeJobs += job
+ resultStageToJob(finalStage) = job
+ submitStage(finalStage)
+ }
+
+ case ExecutorLost(execId) =>
+ handleExecutorLost(execId)
+
+ case completion: CompletionEvent =>
+ handleTaskCompletion(completion)
+
+ case TaskSetFailed(taskSet, reason) =>
+ abortStage(idToStage(taskSet.stageId), reason)
+
+ case StopDAGScheduler =>
+ // Cancel any active jobs
+ for (job <- activeJobs) {
+ val error = new SparkException("Job cancelled because SparkContext was shut down")
+ job.listener.jobFailed(error)
+ }
+ return true
+ }
+ return false
+ }
+
+ /** Resubmit any failed stages. Ordinarily called after a small amount of time has passed since
+ * the last fetch failure.
+ */
+ private[scheduler] def resubmitFailedStages() {
+ logInfo("Resubmitting failed stages")
+ clearCacheLocs()
+ val failed2 = failed.toArray
+ failed.clear()
+ for (stage <- failed2.sortBy(_.priority)) {
+ submitStage(stage)
+ }
+ }
+
+ /** Check for waiting or failed stages which are now eligible for resubmission.
+ * Ordinarily run on every iteration of the event loop.
+ */
+ private[scheduler] def submitWaitingStages() {
+ // TODO: We might want to run this less often, when we are sure that something has become
+ // runnable that wasn't before.
+ logTrace("Checking for newly runnable parent stages")
+ logTrace("running: " + running)
+ logTrace("waiting: " + waiting)
+ logTrace("failed: " + failed)
+ val waiting2 = waiting.toArray
+ waiting.clear()
+ for (stage <- waiting2.sortBy(_.priority)) {
+ submitStage(stage)
+ }
+ }
+
+
/**
* The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
* events and responds by launching tasks. This runs in a dedicated thread and receives events
@@ -245,77 +349,26 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
while (true) {
val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
- val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
if (event != null) {
logDebug("Got event of type " + event.getClass.getName)
}
- event match {
- case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) =>
- val runId = nextRunId.getAndIncrement()
- val finalStage = newStage(finalRDD, None, runId)
- val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
- clearCacheLocs()
- logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
- " output partitions")
- logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")")
- logInfo("Parents of final stage: " + finalStage.parents)
- logInfo("Missing parents: " + getMissingParentStages(finalStage))
- if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
- // Compute very short actions like first() or take() with no parent stages locally.
- runLocally(job)
- } else {
- activeJobs += job
- resultStageToJob(finalStage) = job
- submitStage(finalStage)
- }
-
- case ExecutorLost(execId) =>
- handleExecutorLost(execId)
-
- case completion: CompletionEvent =>
- handleTaskCompletion(completion)
-
- case TaskSetFailed(taskSet, reason) =>
- abortStage(idToStage(taskSet.stageId), reason)
-
- case StopDAGScheduler =>
- // Cancel any active jobs
- for (job <- activeJobs) {
- val error = new SparkException("Job cancelled because SparkContext was shut down")
- job.listener.jobFailed(error)
- }
+ if (event != null) {
+ if (processEvent(event)) {
return
-
- case null =>
- // queue.poll() timed out, ignore it
+ }
}
+ val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
// Periodically resubmit failed stages if some map output fetches have failed and we have
// waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails,
// tasks on many other nodes are bound to get a fetch failure, and they won't all get it at
// the same time, so we want to make sure we've identified all the reduce tasks that depend
// on the failed node.
if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
- logInfo("Resubmitting failed stages")
- clearCacheLocs()
- val failed2 = failed.toArray
- failed.clear()
- for (stage <- failed2.sortBy(_.priority)) {
- submitStage(stage)
- }
+ resubmitFailedStages()
} else {
- // TODO: We might want to run this less often, when we are sure that something has become
- // runnable that wasn't before.
- logTrace("Checking for newly runnable parent stages")
- logTrace("running: " + running)
- logTrace("waiting: " + waiting)
- logTrace("failed: " + failed)
- val waiting2 = waiting.toArray
- waiting.clear()
- for (stage <- waiting2.sortBy(_.priority)) {
- submitStage(stage)
- }
+ submitWaitingStages()
}
}
}
@@ -547,7 +600,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
if (!failedGeneration.contains(execId) || failedGeneration(execId) < currentGeneration) {
failedGeneration(execId) = currentGeneration
logInfo("Executor lost: %s (generation %d)".format(execId, currentGeneration))
- env.blockManager.master.removeExecutor(execId)
+ blockManagerMaster.removeExecutor(execId)
// TODO: This will be really slow if we keep accumulating shuffle map stages
for ((shuffleId, stage) <- shuffleToMapStage) {
stage.removeOutputsOnExecutor(execId)
diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
new file mode 100644
index 0000000000..83663ac702
--- /dev/null
+++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
@@ -0,0 +1,663 @@
+package spark.scheduler
+
+import scala.collection.mutable.{Map, HashMap}
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+import org.scalatest.concurrent.TimeLimitedTests
+import org.scalatest.mock.EasyMockSugar
+import org.scalatest.time.{Span, Seconds}
+
+import org.easymock.EasyMock._
+import org.easymock.Capture
+import org.easymock.EasyMock
+import org.easymock.{IAnswer, IArgumentMatcher}
+
+import akka.actor.ActorSystem
+
+import spark.storage.BlockManager
+import spark.storage.BlockManagerId
+import spark.storage.BlockManagerMaster
+import spark.{Dependency, ShuffleDependency, OneToOneDependency}
+import spark.FetchFailedException
+import spark.MapOutputTracker
+import spark.RDD
+import spark.SparkContext
+import spark.SparkException
+import spark.Split
+import spark.TaskContext
+import spark.TaskEndReason
+
+import spark.{FetchFailed, Success}
+
+/**
+ * Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler
+ * rather than spawning an event loop thread as happens in the real code. They use EasyMock
+ * to mock out two classes that DAGScheduler interacts with: TaskScheduler (to which TaskSets are
+ * submitted) and BlockManagerMaster (from which cache locations are retrieved and to which dead
+ * host notifications are sent). In addition, tests may check for side effects on a non-mocked
+ * MapOutputTracker instance.
+ *
+ * Tests primarily consist of running DAGScheduler#processEvent and
+ * DAGScheduler#submitWaitingStages (via test utility functions like runEvent or respondToTaskSet)
+ * and capturing the resulting TaskSets from the mock TaskScheduler.
+ */
+class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar with TimeLimitedTests {
+
+ // impose a time limit on this test in case we don't let the job finish, in which case
+ // JobWaiter#getResult will hang.
+ override val timeLimit = Span(5, Seconds)
+
+ val sc: SparkContext = new SparkContext("local", "DAGSchedulerSuite")
+ var scheduler: DAGScheduler = null
+ val taskScheduler = mock[TaskScheduler]
+ val blockManagerMaster = mock[BlockManagerMaster]
+ var mapOutputTracker: MapOutputTracker = null
+ var schedulerThread: Thread = null
+ var schedulerException: Throwable = null
+
+ /**
+ * Set of EasyMock argument matchers that match a TaskSet for a given RDD.
+ * We cache these so we do not create duplicate matchers for the same RDD.
+ * This allows us to easily set up a sequence of expectations for task sets for
+ * that RDD.
+ */
+ val taskSetMatchers = new HashMap[MyRDD, IArgumentMatcher]
+
+ /**
+ * Set of cache locations to return from our mock BlockManagerMaster.
+ * Keys are (rdd ID, partition ID). Anything not present will return an empty
+ * list of cache locations silently.
+ */
+ val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]
+
+ /**
+ * JobWaiter for the last JobSubmitted event we pushed, kept so that tests (most of which
+ * submit only one job) do not need to track it explicitly.
+ */
+ var lastJobWaiter: JobWaiter[Int] = null
+
+ /**
+ * Array into which we are accumulating the results from the last job asynchronously.
+ */
+ var lastJobResult: Array[Int] = null
+
+ /**
+ * Tell EasyMockSugar what mock objects we want to be configured by expecting {...}
+ * and whenExecuting {...} */
+ implicit val mocks = MockObjects(taskScheduler, blockManagerMaster)
+
+ /**
+ * Utility function to reset mocks and set expectations on them. EasyMock wants mock objects
+ * to be reset after each time their expectations are set, and we tend to check mock object
+ * calls over a single call to DAGScheduler.
+ *
+ * We also set a default expectation here that blockManagerMaster.getLocations can be called
+ * and will return values from cacheLocations.
+ */
+ def resetExpecting(f: => Unit) {
+ reset(taskScheduler)
+ reset(blockManagerMaster)
+ expecting {
+ expectGetLocations()
+ f
+ }
+ }
+
+ before {
+ taskSetMatchers.clear()
+ cacheLocations.clear()
+ val actorSystem = ActorSystem("test")
+ mapOutputTracker = new MapOutputTracker(actorSystem, true)
+ resetExpecting {
+ taskScheduler.setListener(anyObject())
+ }
+ whenExecuting {
+ scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null)
+ }
+ }
+
+ after {
+ assert(scheduler.processEvent(StopDAGScheduler))
+ resetExpecting {
+ taskScheduler.stop()
+ }
+ whenExecuting {
+ scheduler.stop()
+ }
+ sc.stop()
+ System.clearProperty("spark.master.port")
+ }
+
+ def makeBlockManagerId(host: String): BlockManagerId =
+ BlockManagerId("exec-" + host, host, 12345)
+
+ /**
+ * Type of RDD we use for testing. Note that we should never call the real RDD compute methods.
+ * This is a pair RDD type so it can always be used in ShuffleDependencies.
+ */
+ type MyRDD = RDD[(Int, Int)]
+
+ /**
+ * Create an RDD for passing to DAGScheduler. These RDDs will use the dependencies and
+ * preferredLocations (if any) that are passed to them. They are deliberately not executable
+ * so we can test that DAGScheduler does not try to execute RDDs locally.
+ */
+ def makeRdd(
+ numSplits: Int,
+ dependencies: List[Dependency[_]],
+ locations: Seq[Seq[String]] = Nil
+ ): MyRDD = {
+ val maxSplit = numSplits - 1
+ return new MyRDD(sc, dependencies) {
+ override def compute(split: Split, context: TaskContext): Iterator[(Int, Int)] =
+ throw new RuntimeException("should not be reached")
+ override def getSplits() = (0 to maxSplit).map(i => new Split {
+ override def index = i
+ }).toArray
+ override def getPreferredLocations(split: Split): Seq[String] =
+ if (locations.isDefinedAt(split.index))
+ locations(split.index)
+ else
+ Nil
+ override def toString: String = "DAGSchedulerSuiteRDD " + id
+ }
+ }
+
+ /**
+ * EasyMock matcher method. For use as an argument matcher for a TaskSet whose first task
+ * is from a particular RDD.
+ */
+ def taskSetForRdd(rdd: MyRDD): TaskSet = {
+ val matcher = taskSetMatchers.getOrElseUpdate(rdd,
+ new IArgumentMatcher {
+ override def matches(actual: Any): Boolean = {
+ val taskSet = actual.asInstanceOf[TaskSet]
+ taskSet.tasks(0) match {
+ case rt: ResultTask[_, _] => rt.rdd.id == rdd.id
+ case smt: ShuffleMapTask => smt.rdd.id == rdd.id
+ case _ => false
+ }
+ }
+ override def appendTo(buf: StringBuffer) {
+ buf.append("taskSetForRdd(" + rdd + ")")
+ }
+ })
+ EasyMock.reportMatcher(matcher)
+ return null
+ }
+
+ /**
+ * Set up an EasyMock expectation to respond to blockManagerMaster.getLocations() with values
+ * from cacheLocations.
+ */
+ def expectGetLocations(): Unit = {
+ EasyMock.expect(blockManagerMaster.getLocations(anyObject().asInstanceOf[Array[String]])).
+ andAnswer(new IAnswer[Seq[Seq[BlockManagerId]]] {
+ override def answer(): Seq[Seq[BlockManagerId]] = {
+ val blocks = getCurrentArguments()(0).asInstanceOf[Array[String]]
+ return blocks.map { name =>
+ val pieces = name.split("_")
+ if (pieces(0) == "rdd") {
+ val key = pieces(1).toInt -> pieces(2).toInt
+ if (cacheLocations.contains(key)) {
+ cacheLocations(key)
+ } else {
+ Seq[BlockManagerId]()
+ }
+ } else {
+ Seq[BlockManagerId]()
+ }
+ }.toSeq
+ }
+ }).anyTimes()
+ }
+
+ /**
+ * Process the supplied event as if it were the top of the DAGScheduler event queue, expecting
+ * the scheduler not to exit.
+ *
+ * After processing the event, submit waiting stages as is done on most iterations of the
+ * DAGScheduler event loop.
+ */
+ def runEvent(event: DAGSchedulerEvent) {
+ assert(!scheduler.processEvent(event))
+ scheduler.submitWaitingStages()
+ }
+
+ /**
+ * Expect a TaskSet for the specified RDD to be submitted to the TaskScheduler. Should be
+ * called from a resetExpecting { ... } block.
+ *
+ * Returns an EasyMock Capture that will contain the task set after the stage is submitted.
+ * Most tests should use interceptStage() instead of this directly.
+ */
+ def expectStage(rdd: MyRDD): Capture[TaskSet] = {
+ val taskSetCapture = new Capture[TaskSet]
+ taskScheduler.submitTasks(and(capture(taskSetCapture), taskSetForRdd(rdd)))
+ return taskSetCapture
+ }
+
+ /**
+ * Expect the supplied code snippet to submit a stage for the specified RDD.
+ * Return the resulting TaskSet, with all of its tasks marked as belonging to the
+ * current MapOutputTracker generation.
+ */
+ def interceptStage(rdd: MyRDD)(f: => Unit): TaskSet = {
+ var capture: Capture[TaskSet] = null
+ resetExpecting {
+ capture = expectStage(rdd)
+ }
+ whenExecuting {
+ f
+ }
+ val taskSet = capture.getValue
+ for (task <- taskSet.tasks) {
+ task.generation = mapOutputTracker.getGeneration
+ }
+ return taskSet
+ }
+
+ /**
+ * Send the given CompletionEvent messages for the tasks in the TaskSet.
+ */
+ def respondToTaskSet(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) {
+ assert(taskSet.tasks.size >= results.size)
+ for ((result, i) <- results.zipWithIndex) {
+ if (i < taskSet.tasks.size) {
+ runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any]()))
+ }
+ }
+ }
+
+ /**
+ * Assert that the supplied TaskSet has exactly the given preferredLocations.
+ */
+ def expectTaskSetLocations(taskSet: TaskSet, locations: Seq[Seq[String]]) {
+ assert(locations.size === taskSet.tasks.size)
+ for ((expectLocs, taskLocs) <-
+ taskSet.tasks.map(_.preferredLocations).zip(locations)) {
+ assert(expectLocs === taskLocs)
+ }
+ }
+
+ /**
+ * When we submit dummy Jobs, this is the compute function we supply. Except in a local test
+ * below, we do not expect this function to ever be executed; instead, we will return results
+ * directly through CompletionEvents.
+ */
+ def jobComputeFunc(context: TaskContext, it: Iterator[(Int, Int)]): Int =
+ it.next._1.asInstanceOf[Int]
+
+
+ /**
+ * Start a job to compute the given RDD. Returns the JobWaiter that will collect the result
+ * of the job (via callbacks from DAGScheduler) and the array the results are written into.
+ */
+ def submitRdd(rdd: MyRDD, allowLocal: Boolean = false): (JobWaiter[Int], Array[Int]) = {
+ val resultArray = new Array[Int](rdd.splits.size)
+ val (toSubmit, waiter) = scheduler.prepareJob[(Int, Int), Int](
+ rdd,
+ jobComputeFunc,
+ (0 to (rdd.splits.size - 1)),
+ "test-site",
+ allowLocal,
+ (i: Int, value: Int) => resultArray(i) = value
+ )
+ lastJobWaiter = waiter
+ lastJobResult = resultArray
+ runEvent(toSubmit)
+ return (waiter, resultArray)
+ }
+
+ /**
+ * Assert that a job we started has failed.
+ */
+ def expectJobException(waiter: JobWaiter[Int] = lastJobWaiter) {
+ waiter.awaitResult() match {
+ case JobSucceeded => fail()
+ case JobFailed(_) => return
+ }
+ }
+
+ /**
+ * Assert that a job we started has succeeded and has the given result.
+ */
+ def expectJobResult(expected: Array[Int], waiter: JobWaiter[Int] = lastJobWaiter,
+ result: Array[Int] = lastJobResult) {
+ waiter.awaitResult match {
+ case JobSucceeded =>
+ assert(expected === result)
+ case JobFailed(_) =>
+ fail()
+ }
+ }
+
+ def makeMapStatus(host: String, reduces: Int): MapStatus =
+ new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2))
+
+ test("zero split job") {
+ val rdd = makeRdd(0, Nil)
+ var numResults = 0
+ def accumulateResult(partition: Int, value: Int) {
+ numResults += 1
+ }
+ scheduler.runJob(rdd, jobComputeFunc, Seq(), "test-site", false, accumulateResult)
+ assert(numResults === 0)
+ }
+
+ test("run trivial job") {
+ val rdd = makeRdd(1, Nil)
+ val taskSet = interceptStage(rdd) { submitRdd(rdd) }
+ respondToTaskSet(taskSet, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("local job") {
+ val rdd = new MyRDD(sc, Nil) {
+ override def compute(split: Split, context: TaskContext): Iterator[(Int, Int)] =
+ Array(42 -> 0).iterator
+ override def getSplits() = Array( new Split { override def index = 0 } )
+ override def getPreferredLocations(split: Split) = Nil
+ override def toString = "DAGSchedulerSuite Local RDD"
+ }
+ submitRdd(rdd, true)
+ expectJobResult(Array(42))
+ }
+
+ test("run trivial job w/ dependency") {
+ val baseRdd = makeRdd(1, Nil)
+ val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
+ val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) }
+ respondToTaskSet(taskSet, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("cache location preferences w/ dependency") {
+ val baseRdd = makeRdd(1, Nil)
+ val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
+ cacheLocations(baseRdd.id -> 0) =
+ Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
+ val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) }
+ expectTaskSetLocations(taskSet, List(Seq("hostA", "hostB")))
+ respondToTaskSet(taskSet, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("trivial job failure") {
+ val rdd = makeRdd(1, Nil)
+ val taskSet = interceptStage(rdd) { submitRdd(rdd) }
+ runEvent(TaskSetFailed(taskSet, "test failure"))
+ expectJobException()
+ }
+
+ test("run trivial shuffle") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(1, List(shuffleDep))
+
+ val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+ val secondStage = interceptStage(reduceRdd) {
+ respondToTaskSet(firstStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
+ respondToTaskSet(secondStage, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("run trivial shuffle with fetch failure") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(2, List(shuffleDep))
+
+ val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+ val secondStage = interceptStage(reduceRdd) {
+ respondToTaskSet(firstStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ respondToTaskSet(secondStage, List(
+ (Success, 42),
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null)
+ ))
+ }
+ val thirdStage = interceptStage(shuffleMapRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ val fourthStage = interceptStage(reduceRdd) {
+ respondToTaskSet(thirdStage, List( (Success, makeMapStatus("hostA", 1)) ))
+ }
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
+ respondToTaskSet(fourthStage, List( (Success, 43) ))
+ expectJobResult(Array(42, 43))
+ }
+
+ test("ignore late map task completions") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(2, List(shuffleDep))
+
+ val taskSet = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+ val oldGeneration = mapOutputTracker.getGeneration
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ runEvent(ExecutorLost("exec-hostA"))
+ }
+ val newGeneration = mapOutputTracker.getGeneration
+ assert(newGeneration > oldGeneration)
+ val noAccum = Map[Long, Any]()
+ // We rely on the event queue being ordered. The generation number has been incremented, so
+ // this completion from the failed host has an old generation and should be ignored
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum))
+ // should work because it's a non-failed host
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum))
+ // should be ignored for being too old
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum))
+ taskSet.tasks(1).generation = newGeneration
+ val secondStage = interceptStage(reduceRdd) {
+ runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum))
+ }
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
+ respondToTaskSet(secondStage, List( (Success, 42), (Success, 43) ))
+ expectJobResult(Array(42, 43))
+ }
+
+ test("run trivial shuffle with out-of-band failure and retry") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(1, List(shuffleDep))
+
+ val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ runEvent(ExecutorLost("exec-hostA"))
+ }
+ // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks
+ // rather than marking it as failed and waiting.
+ val secondStage = interceptStage(shuffleMapRdd) {
+ respondToTaskSet(firstStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ val thirdStage = interceptStage(reduceRdd) {
+ respondToTaskSet(secondStage, List(
+ (Success, makeMapStatus("hostC", 1))
+ ))
+ }
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
+ respondToTaskSet(thirdStage, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("recursive shuffle failures") {
+ val shuffleOneRdd = makeRdd(2, Nil)
+ val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
+ val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
+ val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
+ val finalRdd = makeRdd(1, List(shuffleDepTwo))
+
+ val firstStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
+ val secondStage = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(firstStage, List(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))
+ ))
+ }
+ val thirdStage = interceptStage(finalRdd) {
+ respondToTaskSet(secondStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostC", 1))
+ ))
+ }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ respondToTaskSet(thirdStage, List(
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
+ ))
+ }
+ val recomputeOne = interceptStage(shuffleOneRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ val recomputeTwo = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(recomputeOne, List(
+ (Success, makeMapStatus("hostA", 2))
+ ))
+ }
+ val finalStage = interceptStage(finalRdd) {
+ respondToTaskSet(recomputeTwo, List(
+ (Success, makeMapStatus("hostA", 1))
+ ))
+ }
+ respondToTaskSet(finalStage, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("cached post-shuffle") {
+ val shuffleOneRdd = makeRdd(2, Nil)
+ val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
+ val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
+ val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
+ val finalRdd = makeRdd(1, List(shuffleDepTwo))
+
+ val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
+ cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
+ cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
+ val secondShuffleStage = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(firstShuffleStage, List(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))
+ ))
+ }
+ val reduceStage = interceptStage(finalRdd) {
+ respondToTaskSet(secondShuffleStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ respondToTaskSet(reduceStage, List(
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
+ ))
+ }
+ // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun.
+ val recomputeTwo = interceptStage(shuffleTwoRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ expectTaskSetLocations(recomputeTwo, Seq(Seq("hostD")))
+ val finalRetry = interceptStage(finalRdd) {
+ respondToTaskSet(recomputeTwo, List(
+ (Success, makeMapStatus("hostD", 1))
+ ))
+ }
+ respondToTaskSet(finalRetry, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("cached post-shuffle but fails") {
+ val shuffleOneRdd = makeRdd(2, Nil)
+ val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
+ val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
+ val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
+ val finalRdd = makeRdd(1, List(shuffleDepTwo))
+
+ val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
+ cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
+ cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
+ val secondShuffleStage = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(firstShuffleStage, List(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))
+ ))
+ }
+ val reduceStage = interceptStage(finalRdd) {
+ respondToTaskSet(secondShuffleStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ respondToTaskSet(reduceStage, List(
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
+ ))
+ }
+ val recomputeTwoCached = interceptStage(shuffleTwoRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ expectTaskSetLocations(recomputeTwoCached, Seq(Seq("hostD")))
+ intercept[FetchFailedException]{
+ mapOutputTracker.getServerStatuses(shuffleDepOne.shuffleId, 0)
+ }
+
+ // Simulate the shuffle input data failing to be cached.
+ cacheLocations.remove(shuffleTwoRdd.id -> 0)
+ respondToTaskSet(recomputeTwoCached, List(
+ (FetchFailed(null, shuffleDepOne.shuffleId, 0, 0), null)
+ ))
+
+ // After the fetch failure, DAGScheduler should recheck the cache and decide to resubmit
+ // everything.
+ val recomputeOne = interceptStage(shuffleOneRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ // We use hostA here to make sure DAGScheduler doesn't think it's still dead.
+ val recomputeTwoUncached = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(recomputeOne, List( (Success, makeMapStatus("hostA", 1)) ))
+ }
+ expectTaskSetLocations(recomputeTwoUncached, Seq(Seq[String]()))
+ val finalRetry = interceptStage(finalRdd) {
+ respondToTaskSet(recomputeTwoUncached, List( (Success, makeMapStatus("hostA", 1)) ))
+
+ }
+ respondToTaskSet(finalRetry, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+}
diff --git a/pom.xml b/pom.xml
index c6b9012dc6..7e06cae052 100644
--- a/pom.xml
+++ b/pom.xml
@@ -274,6 +274,12 @@
<scope>test</scope>
</dependency>
<dependency>
+ <groupId>org.easymock</groupId>
+ <artifactId>easymock</artifactId>
+ <version>3.1</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>org.scalacheck</groupId>
<artifactId>scalacheck_${scala.version}</artifactId>
<version>1.9</version>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 03b8094f7d..af8b5ba017 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -92,7 +92,8 @@ object SparkBuild extends Build {
"org.eclipse.jetty" % "jetty-server" % "7.5.3.v20111011",
"org.scalatest" %% "scalatest" % "1.8" % "test",
"org.scalacheck" %% "scalacheck" % "1.9" % "test",
- "com.novocode" % "junit-interface" % "0.8" % "test"
+ "com.novocode" % "junit-interface" % "0.8" % "test",
+ "org.easymock" % "easymock" % "3.1" % "test"
),
parallelExecution := false,
/* Workaround for issue #206 (fixed after SBT 0.11.0) */