path: root/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
blob: f802b66cf13f914ba1eab1f144650c2aea59c4ab
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package spark.scheduler

import scala.collection.mutable.{Map, HashMap}

import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter

import spark.LocalSparkContext

import spark.storage.BlockManager
import spark.storage.BlockManagerId
import spark.storage.BlockManagerMaster
import spark.{Dependency, ShuffleDependency, OneToOneDependency}
import spark.FetchFailedException
import spark.MapOutputTracker
import spark.RDD
import spark.SparkContext
import spark.SparkException
import spark.Partition
import spark.TaskContext
import spark.TaskEndReason

import spark.{FetchFailed, Success}

/**
 * Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler
 * rather than spawning an event loop thread as happens in the real code. They use hand-written
 * stubs for the two classes that DAGScheduler interacts with: TaskScheduler (to which TaskSets are
 * submitted) and BlockManagerMaster (from which cache locations are retrieved and to which dead
 * host notifications are sent). In addition, tests may check for side effects on a real (non-stub)
 * MapOutputTracker instance.
 *
 * Tests primarily consist of running DAGScheduler#processEvent and
 * DAGScheduler#submitWaitingStages (via test utility functions like runEvent or complete)
 * and capturing the resulting TaskSets from the stub TaskScheduler.
 */
class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
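  // The usual flow in the tests below, sketched from the "run trivial job" test:
  //   submit(rdd, Array(0))                       // enqueue a JobSubmitted event for one partition
  //   complete(taskSets(0), List((Success, 42)))  // feed a task result back as a CompletionEvent
  //   assert(results === Map(0 -> 42))            // check what our JobListener collected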

  /** TaskSets that the DAGScheduler has asked the TaskScheduler to run, in order of submission. */
  val taskSets = scala.collection.mutable.Buffer[TaskSet]()
  val taskScheduler = new TaskScheduler() {
    override def start() = {}
    override def stop() = {}
    override def submitTasks(taskSet: TaskSet) = {
      // normally done by TaskSetManager
      taskSet.tasks.foreach(_.generation = mapOutputTracker.getGeneration)
      taskSets += taskSet
    }
    override def setListener(listener: TaskSchedulerListener) = {}
    override def defaultParallelism() = 2
  }

  var mapOutputTracker: MapOutputTracker = null
  var scheduler: DAGScheduler = null

  /**
   * Cache locations to return from our stub BlockManagerMaster, keyed by (rdd ID, partition ID).
   * Any key not present here silently yields an empty list of cache locations.
   */
  val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]
  // stub out BlockManagerMaster.getLocations to use our cacheLocations
  val blockManagerMaster = new BlockManagerMaster(null) {
      override def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
        blockIds.map { name =>
          val pieces = name.split("_")
          if (pieces(0) == "rdd") {
            val key = pieces(1).toInt -> pieces(2).toInt
            cacheLocations.getOrElse(key, Seq())
          } else {
            Seq()
          }
        }.toSeq
      }
      override def removeExecutor(execId: String) {
        // don't need to propagate to the driver, which we don't have
      }
    }

  /** Results collected through our JobListener, keyed by partition index. */
  val results = new HashMap[Int, Any]()
  var failure: Exception = _
  val listener = new JobListener() {
    override def taskSucceeded(index: Int, result: Any) = results.put(index, result)
    override def jobFailed(exception: Exception) = { failure = exception }
  }

  before {
    sc = new SparkContext("local", "DAGSchedulerSuite")
    taskSets.clear()
    cacheLocations.clear()
    results.clear()
    mapOutputTracker = new MapOutputTracker()
    scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) {
      override def runLocally(job: ActiveJob) {
        // don't bother with the thread while unit testing
        runLocallyWithinThread(job)
      }
    }
  }

  after {
    scheduler.stop()
  }

  /**
   * Type of RDD we use for testing. Note that we should never call the real RDD compute methods.
   * This is a pair RDD type so it can always be used in ShuffleDependencies.
   */
  type MyRDD = RDD[(Int, Int)]

  /**
   * Create an RDD for passing to DAGScheduler. These RDDs will use the dependencies and
   * preferredLocations (if any) that are passed to them. They are deliberately not executable
   * so we can test that DAGScheduler does not try to execute RDDs locally.
   */
  private def makeRdd(
        numPartitions: Int,
        dependencies: List[Dependency[_]],
        locations: Seq[Seq[String]] = Nil
      ): MyRDD = {
    val maxPartition = numPartitions - 1
    new MyRDD(sc, dependencies) {
      override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
        throw new RuntimeException("should not be reached")
      override def getPartitions = (0 to maxPartition).map(i => new Partition {
        override def index = i
      }).toArray
      override def getPreferredLocations(split: Partition): Seq[String] =
        if (locations.isDefinedAt(split.index))
          locations(split.index)
        else
          Nil
      override def toString: String = "DAGSchedulerSuiteRDD " + id
    }
  }

  /**
   * Process the supplied event as if it were the top of the DAGScheduler event queue, expecting
   * the scheduler not to exit.
   *
   * After processing the event, submit waiting stages as is done on most iterations of the
   * DAGScheduler event loop.
   */
  private def runEvent(event: DAGSchedulerEvent) {
    assert(!scheduler.processEvent(event))
    scheduler.submitWaitingStages()
  }

  /**
   * When we submit dummy Jobs, this is the compute function we supply. Except in a local test
   * below, we do not expect this function to ever be executed; instead, we will return results
   * directly through CompletionEvents.
   */
  private val jobComputeFunc = (context: TaskContext, it: Iterator[(_)]) =>
     it.next.asInstanceOf[Tuple2[_, _]]._1

  /** Send the given CompletionEvent messages for the tasks in the TaskSet. */
  private def complete(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) {
    assert(taskSet.tasks.size >= results.size)
    for ((result, i) <- results.zipWithIndex) {
      if (i < taskSet.tasks.size) {
        runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any](), null, null))
      }
    }
  }

  /** Sends JobSubmitted to the scheduler for the given RDD and partitions. */
  private def submit(
      rdd: RDD[_],
      partitions: Array[Int],
      func: (TaskContext, Iterator[_]) => _ = jobComputeFunc,
      allowLocal: Boolean = false,
      listener: JobListener = listener) {
    runEvent(JobSubmitted(rdd, func, partitions, allowLocal, null, listener))
  }

  /** Sends TaskSetFailed to the scheduler. */
  private def failed(taskSet: TaskSet, message: String) {
    runEvent(TaskSetFailed(taskSet, message))
  }

  test("zero split job") {
    val rdd = makeRdd(0, Nil)
    var numResults = 0
    val fakeListener = new JobListener() {
      override def taskSucceeded(partition: Int, value: Any) = numResults += 1
      override def jobFailed(exception: Exception) = throw exception
    }
    submit(rdd, Array(), listener = fakeListener)
    assert(numResults === 0)
  }

  test("run trivial job") {
    val rdd = makeRdd(1, Nil)
    submit(rdd, Array(0))
    complete(taskSets(0), List((Success, 42)))
    assert(results === Map(0 -> 42))
  }

  test("local job") {
    val rdd = new MyRDD(sc, Nil) {
      override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
        Array(42 -> 0).iterator
      override def getPartitions = Array( new Partition { override def index = 0 } )
      override def getPreferredLocations(split: Partition) = Nil
      override def toString = "DAGSchedulerSuite Local RDD"
    }
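    // allowLocal = true, so the scheduler should run this single-partition job via runLocally,
    // which the harness above overrides to execute inline on the calling thread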
    runEvent(JobSubmitted(rdd, jobComputeFunc, Array(0), true, null, listener))
    assert(results === Map(0 -> 42))
  }

  test("run trivial job w/ dependency") {
    val baseRdd = makeRdd(1, Nil)
    val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
    submit(finalRdd, Array(0))
    complete(taskSets(0), Seq((Success, 42)))
    assert(results === Map(0 -> 42))
  }

  test("cache location preferences w/ dependency") {
    val baseRdd = makeRdd(1, Nil)
    val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
    cacheLocations(baseRdd.id -> 0) =
      Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
    submit(finalRdd, Array(0))
    val taskSet = taskSets(0)
    assertLocations(taskSet, Seq(Seq("hostA", "hostB")))
    complete(taskSet, Seq((Success, 42)))
    assert(results === Map(0 -> 42))
  }

  test("trivial job failure") {
    submit(makeRdd(1, Nil), Array(0))
    failed(taskSets(0), "some failure")
    assert(failure.getMessage === "Job failed: some failure")
  }

  test("run trivial shuffle") {
    val shuffleMapRdd = makeRdd(2, Nil)
    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
    val shuffleId = shuffleDep.shuffleId
    val reduceRdd = makeRdd(1, List(shuffleDep))
    submit(reduceRdd, Array(0))
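    // the shuffle map stage is submitted first; completing both of its tasks registers their map
    // outputs with mapOutputTracker, which lets the reduce stage be submitted as taskSets(1)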
    complete(taskSets(0), Seq(
        (Success, makeMapStatus("hostA", 1)),
        (Success, makeMapStatus("hostB", 1))))
    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
           Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
    complete(taskSets(1), Seq((Success, 42)))
    assert(results === Map(0 -> 42))
  }

  test("run trivial shuffle with fetch failure") {
    val shuffleMapRdd = makeRdd(2, Nil)
    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
    val shuffleId = shuffleDep.shuffleId
    val reduceRdd = makeRdd(2, List(shuffleDep))
    submit(reduceRdd, Array(0, 1))
    complete(taskSets(0), Seq(
        (Success, makeMapStatus("hostA", 1)),
        (Success, makeMapStatus("hostB", 1))))
    // the 2nd ResultTask failed
    complete(taskSets(1), Seq(
        (Success, 42),
        (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null)))
    // blockManagerMaster.removeExecutor("exec-hostA") will have been called by this point
    // ask the scheduler to try it again
    scheduler.resubmitFailedStages()
    // have the 2nd attempt pass
    complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))
    // we can see both result blocks now
    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB"))
    complete(taskSets(3), Seq((Success, 43)))
    assert(results === Map(0 -> 42, 1 -> 43))
  }

  test("ignore late map task completions") {
    val shuffleMapRdd = makeRdd(2, Nil)
    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
    val shuffleId = shuffleDep.shuffleId
    val reduceRdd = makeRdd(2, List(shuffleDep))
    submit(reduceRdd, Array(0, 1))
    // pretend we were told hostA went away
    val oldGeneration = mapOutputTracker.getGeneration
    runEvent(ExecutorLost("exec-hostA"))
    val newGeneration = mapOutputTracker.getGeneration
    assert(newGeneration > oldGeneration)
    val noAccum = Map[Long, Any]()
    val taskSet = taskSets(0)
    // should be ignored for being too old
    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null))
    // should work because it's a non-failed host
    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum, null, null))
    // should be ignored for being too old
    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null))
    // should work because it's a new generation
    taskSet.tasks(1).generation = newGeneration
    runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum, null, null))
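    // the registered output for map partition 0 should now be the hostB status, and for map
    // partition 1 the new-generation hostA status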
    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
           Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
    complete(taskSets(1), Seq((Success, 42), (Success, 43)))
    assert(results === Map(0 -> 42, 1 -> 43))
  }

  test("run trivial shuffle with out-of-band failure and retry") {
    val shuffleMapRdd = makeRdd(2, Nil)
    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
    val shuffleId = shuffleDep.shuffleId
    val reduceRdd = makeRdd(1, List(shuffleDep))
    submit(reduceRdd, Array(0))
    // blockManagerMaster.removeExecutor("exec-hostA")
    // pretend we were told hostA went away
    runEvent(ExecutorLost("exec-hostA"))
    // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks
    // rather than marking it as failed and waiting.
    complete(taskSets(0), Seq(
        (Success, makeMapStatus("hostA", 1)),
       (Success, makeMapStatus("hostB", 1))))
   // have hostC complete the resubmitted task
   complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1))))
   assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
          Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
   complete(taskSets(2), Seq((Success, 42)))
   assert(results === Map(0 -> 42))
 }

 test("recursive shuffle failures") {
    val shuffleOneRdd = makeRdd(2, Nil)
    val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
    val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
    val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
    val finalRdd = makeRdd(1, List(shuffleDepTwo))
    submit(finalRdd, Array(0))
    // have the first stage complete normally
    complete(taskSets(0), Seq(
        (Success, makeMapStatus("hostA", 2)),
        (Success, makeMapStatus("hostB", 2))))
    // have the second stage complete normally
    complete(taskSets(1), Seq(
        (Success, makeMapStatus("hostA", 1)),
        (Success, makeMapStatus("hostC", 1))))
    // fail the third stage because hostA went down
    complete(taskSets(2), Seq(
        (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)))
    // TODO assert this:
    // blockManagerMaster.removeExecutor("exec-hostA")
    // have DAGScheduler try again
    scheduler.resubmitFailedStages()
    complete(taskSets(3), Seq((Success, makeMapStatus("hostA", 2))))
    complete(taskSets(4), Seq((Success, makeMapStatus("hostA", 1))))
    complete(taskSets(5), Seq((Success, 42)))
    assert(results === Map(0 -> 42))
  }

  test("cached post-shuffle") {
    val shuffleOneRdd = makeRdd(2, Nil)
    val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
    val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
    val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
    val finalRdd = makeRdd(1, List(shuffleDepTwo))
    submit(finalRdd, Array(0))
    cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
    cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
    // complete stage 2
    complete(taskSets(0), Seq(
        (Success, makeMapStatus("hostA", 2)),
        (Success, makeMapStatus("hostB", 2))))
    // complete stage 1
    complete(taskSets(1), Seq(
        (Success, makeMapStatus("hostA", 1)),
        (Success, makeMapStatus("hostB", 1))))
    // pretend stage 0 failed because hostA went down
    complete(taskSets(2), Seq(
        (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)))
    // TODO assert this:
    // blockManagerMaster.removeExecutor("exec-hostA")
    // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun.
    scheduler.resubmitFailedStages()
    assertLocations(taskSets(3), Seq(Seq("hostD")))
    // allow hostD to recover
    complete(taskSets(3), Seq((Success, makeMapStatus("hostD", 1))))
    complete(taskSets(4), Seq((Success, 42)))
    assert(results === Map(0 -> 42))
  }

  /**
   * Asserts that the supplied TaskSet has exactly the given preferredLocations.
   * Note: only the host portion of each task's preferred locations is compared.
   */
  private def assertLocations(taskSet: TaskSet, locations: Seq[Seq[String]]) {
    assert(locations.size === taskSet.tasks.size)
    for ((taskLocs, expectedLocs) <-
            taskSet.tasks.map(_.preferredLocations).zip(locations)) {
      assert(taskLocs.map(loc => spark.Utils.parseHostPort(loc)._1) === expectedLocs)
    }
  }

  /** Makes a MapStatus for the given host, with a dummy size byte (2) for each reduce partition. */
  private def makeMapStatus(host: String, reduces: Int): MapStatus =
    new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2))

  /** Makes a BlockManagerId for a fake executor ("exec-" + host) on the given host, with an arbitrary port. */
  private def makeBlockManagerId(host: String): BlockManagerId =
    BlockManagerId("exec-" + host, host, 12345, 0)

}