Diffstat (limited to 'core/src/test')
-rw-r--r--  core/src/test/scala/spark/AccumulatorSuite.scala              38
-rw-r--r--  core/src/test/scala/spark/BoundedMemoryCacheSuite.scala       58
-rw-r--r--  core/src/test/scala/spark/BroadcastSuite.scala                14
-rw-r--r--  core/src/test/scala/spark/CacheTrackerSuite.scala            131
-rw-r--r--  core/src/test/scala/spark/CheckpointSuite.scala               48
-rw-r--r--  core/src/test/scala/spark/ClosureCleanerSuite.scala           73
-rw-r--r--  core/src/test/scala/spark/DistributedSuite.scala              92
-rw-r--r--  core/src/test/scala/spark/DriverSuite.scala                   33
-rw-r--r--  core/src/test/scala/spark/FailureSuite.scala                  14
-rw-r--r--  core/src/test/scala/spark/FileServerSuite.scala               29
-rw-r--r--  core/src/test/scala/spark/FileSuite.scala                     16
-rw-r--r--  core/src/test/scala/spark/JavaAPISuite.java                   46
-rw-r--r--  core/src/test/scala/spark/LocalSparkContext.scala             41
-rw-r--r--  core/src/test/scala/spark/MapOutputTrackerSuite.scala         57
-rw-r--r--  core/src/test/scala/spark/PartitioningSuite.scala             20
-rw-r--r--  core/src/test/scala/spark/PipedRDDSuite.scala                 16
-rw-r--r--  core/src/test/scala/spark/RDDSuite.scala                      51
-rw-r--r--  core/src/test/scala/spark/ShuffleSuite.scala                  21
-rw-r--r--  core/src/test/scala/spark/SizeEstimatorSuite.scala            48
-rw-r--r--  core/src/test/scala/spark/SortingSuite.scala                  13
-rw-r--r--  core/src/test/scala/spark/ThreadingSuite.scala                14
-rw-r--r--  core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala  663
-rw-r--r--  core/src/test/scala/spark/scheduler/TaskContextSuite.scala    32
-rw-r--r--  core/src/test/scala/spark/storage/BlockManagerSuite.scala    132
24 files changed, 1185 insertions, 515 deletions
diff --git a/core/src/test/scala/spark/AccumulatorSuite.scala b/core/src/test/scala/spark/AccumulatorSuite.scala
index d8be99dde7..ac8ae7d308 100644
--- a/core/src/test/scala/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/spark/AccumulatorSuite.scala
@@ -1,6 +1,5 @@
package spark
-import org.scalatest.BeforeAndAfter
import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
import collection.mutable
@@ -9,18 +8,7 @@ import scala.math.exp
import scala.math.signum
import spark.SparkContext._
-class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
-
- var sc: SparkContext = null
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
test ("basic accumulation"){
sc = new SparkContext("local", "test")
@@ -29,6 +17,12 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
val d = sc.parallelize(1 to 20)
d.foreach{x => acc += x}
acc.value should be (210)
+
+
+ val longAcc = sc.accumulator(0L)
+ val maxInt = Integer.MAX_VALUE.toLong
+ d.foreach{x => longAcc += maxInt + x}
+ longAcc.value should be (210L + maxInt * 20)
}
test ("value not assignable from tasks") {
@@ -53,10 +47,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
for (i <- 1 to maxI) {
v should contain(i)
}
- sc.stop()
- sc = null
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
+ resetSparkContext()
}
}
@@ -86,10 +77,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
x => acc.value += x
}
} should produce [SparkException]
- sc.stop()
- sc = null
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
+ resetSparkContext()
}
}
@@ -115,10 +103,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
bufferAcc.value should contain(i)
mapAcc.value should contain (i -> i.toString)
}
- sc.stop()
- sc = null
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
+ resetSparkContext()
}
}
@@ -134,8 +119,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
x => acc.localValue ++= x
}
acc.value should be ( (0 to maxI).toSet)
- sc.stop()
- sc = null
+ resetSparkContext()
}
}
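
The same migration repeats across the suites below: per-suite BeforeAndAfter boilerplate is replaced by the LocalSparkContext trait added later in this diff. As a reading aid, here is a minimal sketch (not part of the commit; suite name and assertions are illustrative only) of a suite written against that trait:

package spark

import org.scalatest.FunSuite

// Hypothetical suite showing the LocalSparkContext pattern: the trait owns
// `sc`, stops it in afterEach(), and clears spark.driver.port so Akka can
// rebind; tests only need to create the context.
class ExampleLocalSuite extends FunSuite with LocalSparkContext {

  test("basic sum") {
    sc = new SparkContext("local", "test")
    assert(sc.parallelize(1 to 20).reduce(_ + _) === 210)
  }

  test("two contexts in one test") {
    sc = new SparkContext("local", "test")
    assert(sc.parallelize(1 to 2).count() === 2)
    resetSparkContext() // stop and null out sc before starting another context
    sc = new SparkContext("local", "test2")
    assert(sc.parallelize(1 to 3).count() === 3)
  }
}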
diff --git a/core/src/test/scala/spark/BoundedMemoryCacheSuite.scala b/core/src/test/scala/spark/BoundedMemoryCacheSuite.scala
deleted file mode 100644
index 37cafd1e8e..0000000000
--- a/core/src/test/scala/spark/BoundedMemoryCacheSuite.scala
+++ /dev/null
@@ -1,58 +0,0 @@
-package spark
-
-import org.scalatest.FunSuite
-import org.scalatest.PrivateMethodTester
-import org.scalatest.matchers.ShouldMatchers
-
-// TODO: Replace this with a test of MemoryStore
-class BoundedMemoryCacheSuite extends FunSuite with PrivateMethodTester with ShouldMatchers {
- test("constructor test") {
- val cache = new BoundedMemoryCache(60)
- expect(60)(cache.getCapacity)
- }
-
- test("caching") {
- // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
- val oldArch = System.setProperty("os.arch", "amd64")
- val oldOops = System.setProperty("spark.test.useCompressedOops", "true")
- val initialize = PrivateMethod[Unit]('initialize)
- SizeEstimator invokePrivate initialize()
-
- val cache = new BoundedMemoryCache(60) {
- //TODO sorry about this, but there is no better way to skip 'cacheTracker.dropEntry'
- override protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) {
- logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size))
- }
- }
-
- // NOTE: The String class definition changed in JDK 7 to exclude the int fields count and length
- // This means that the size of strings will be lesser by 8 bytes in JDK 7 compared to JDK 6.
- // http://mail.openjdk.java.net/pipermail/core-libs-dev/2012-May/010257.html
- // Work around to check for either.
-
- //should be OK
- cache.put("1", 0, "Meh") should (equal (CachePutSuccess(56)) or equal (CachePutSuccess(48)))
-
- //we cannot add this to cache (there is not enough space in cache) & we cannot evict the only value from
- //cache because it's from the same dataset
- expect(CachePutFailure())(cache.put("1", 1, "Meh"))
-
- //should be OK, dataset '1' can be evicted from cache
- cache.put("2", 0, "Meh") should (equal (CachePutSuccess(56)) or equal (CachePutSuccess(48)))
-
- //should fail, cache should obey its capacity
- expect(CachePutFailure())(cache.put("3", 0, "Very_long_and_useless_string"))
-
- if (oldArch != null) {
- System.setProperty("os.arch", oldArch)
- } else {
- System.clearProperty("os.arch")
- }
-
- if (oldOops != null) {
- System.setProperty("spark.test.useCompressedOops", oldOops)
- } else {
- System.clearProperty("spark.test.useCompressedOops")
- }
- }
-}
diff --git a/core/src/test/scala/spark/BroadcastSuite.scala b/core/src/test/scala/spark/BroadcastSuite.scala
index 2d3302f0aa..362a31fb0d 100644
--- a/core/src/test/scala/spark/BroadcastSuite.scala
+++ b/core/src/test/scala/spark/BroadcastSuite.scala
@@ -1,20 +1,8 @@
package spark
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
-class BroadcastSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class BroadcastSuite extends FunSuite with LocalSparkContext {
test("basic broadcast") {
sc = new SparkContext("local", "test")
diff --git a/core/src/test/scala/spark/CacheTrackerSuite.scala b/core/src/test/scala/spark/CacheTrackerSuite.scala
deleted file mode 100644
index 467605981b..0000000000
--- a/core/src/test/scala/spark/CacheTrackerSuite.scala
+++ /dev/null
@@ -1,131 +0,0 @@
-package spark
-
-import org.scalatest.FunSuite
-
-import scala.collection.mutable.HashMap
-
-import akka.actor._
-import akka.dispatch._
-import akka.pattern.ask
-import akka.remote._
-import akka.util.Duration
-import akka.util.Timeout
-import akka.util.duration._
-
-class CacheTrackerSuite extends FunSuite {
- // Send a message to an actor and wait for a reply, in a blocking manner
- private def ask(actor: ActorRef, message: Any): Any = {
- try {
- val timeout = 10.seconds
- val future = actor.ask(message)(timeout)
- return Await.result(future, timeout)
- } catch {
- case e: Exception =>
- throw new SparkException("Error communicating with actor", e)
- }
- }
-
- test("CacheTrackerActor slave initialization & cache status") {
- //System.setProperty("spark.master.port", "1345")
- val initialSize = 2L << 20
-
- val actorSystem = ActorSystem("test")
- val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
-
- assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
-
- assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 0L)))
-
- assert(ask(tracker, StopCacheTracker) === true)
-
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- }
-
- test("RegisterRDD") {
- //System.setProperty("spark.master.port", "1345")
- val initialSize = 2L << 20
-
- val actorSystem = ActorSystem("test")
- val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
-
- assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
-
- assert(ask(tracker, RegisterRDD(1, 3)) === true)
- assert(ask(tracker, RegisterRDD(2, 1)) === true)
-
- assert(getCacheLocations(tracker) === Map(1 -> List(Nil, Nil, Nil), 2 -> List(Nil)))
-
- assert(ask(tracker, StopCacheTracker) === true)
-
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- }
-
- test("AddedToCache") {
- //System.setProperty("spark.master.port", "1345")
- val initialSize = 2L << 20
-
- val actorSystem = ActorSystem("test")
- val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
-
- assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
-
- assert(ask(tracker, RegisterRDD(1, 2)) === true)
- assert(ask(tracker, RegisterRDD(2, 1)) === true)
-
- assert(ask(tracker, AddedToCache(1, 0, "host001", 2L << 15)) === true)
- assert(ask(tracker, AddedToCache(1, 1, "host001", 2L << 11)) === true)
- assert(ask(tracker, AddedToCache(2, 0, "host001", 3L << 10)) === true)
-
- assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 72704L)))
-
- assert(getCacheLocations(tracker) ===
- Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
-
- assert(ask(tracker, StopCacheTracker) === true)
-
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- }
-
- test("DroppedFromCache") {
- //System.setProperty("spark.master.port", "1345")
- val initialSize = 2L << 20
-
- val actorSystem = ActorSystem("test")
- val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
-
- assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
-
- assert(ask(tracker, RegisterRDD(1, 2)) === true)
- assert(ask(tracker, RegisterRDD(2, 1)) === true)
-
- assert(ask(tracker, AddedToCache(1, 0, "host001", 2L << 15)) === true)
- assert(ask(tracker, AddedToCache(1, 1, "host001", 2L << 11)) === true)
- assert(ask(tracker, AddedToCache(2, 0, "host001", 3L << 10)) === true)
-
- assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 72704L)))
- assert(getCacheLocations(tracker) ===
- Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
-
- assert(ask(tracker, DroppedFromCache(1, 1, "host001", 2L << 11)) === true)
-
- assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 68608L)))
- assert(getCacheLocations(tracker) ===
- Map(1 -> List(List("host001"),List()), 2 -> List(List("host001"))))
-
- assert(ask(tracker, StopCacheTracker) === true)
-
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- }
-
- /**
- * Helper function to get cacheLocations from CacheTracker
- */
- def getCacheLocations(tracker: ActorRef): HashMap[Int, List[List[String]]] = {
- val answer = ask(tracker, GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]]
- answer.map { case (i, arr) => (i, arr.toList) }
- }
-}
diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala
index 51573254ca..4425949f46 100644
--- a/core/src/test/scala/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/spark/CheckpointSuite.scala
@@ -1,34 +1,27 @@
package spark
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.FunSuite
import java.io.File
import spark.rdd._
import spark.SparkContext._
import storage.StorageLevel
-class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging {
+class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
initLogging()
- var sc: SparkContext = _
var checkpointDir: File = _
val partitioner = new HashPartitioner(2)
- before {
+ override def beforeEach() {
+ super.beforeEach()
checkpointDir = File.createTempFile("temp", "")
checkpointDir.delete()
-
sc = new SparkContext("local", "test")
sc.setCheckpointDir(checkpointDir.toString)
}
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
-
+ override def afterEach() {
+ super.afterEach()
if (checkpointDir != null) {
checkpointDir.delete()
}
@@ -106,7 +99,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging {
// the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits.
// Note that this test is very specific to the current implementation of CartesianRDD.
val ones = sc.makeRDD(1 to 100, 10).map(x => x)
- ones.checkpoint // checkpoint that MappedRDD
+ ones.checkpoint() // checkpoint that MappedRDD
val cartesian = new CartesianRDD(sc, ones, ones)
val splitBeforeCheckpoint =
serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit])
@@ -132,7 +125,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging {
// the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits.
// Note that this test is very specific to the current implementation of CoalescedRDDSplits
val ones = sc.makeRDD(1 to 100, 10).map(x => x)
- ones.checkpoint // checkpoint that MappedRDD
+ ones.checkpoint() // checkpoint that MappedRDD
val coalesced = new CoalescedRDD(ones, 2)
val splitBeforeCheckpoint =
serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit])
@@ -167,7 +160,16 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging {
// so only the RDD will reduce in serialized size, not the splits.
testParentCheckpointing(
rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false)
+ }
+ test("CheckpointRDD with zero partitions") {
+ val rdd = new BlockRDD[Int](sc, Array[String]())
+ assert(rdd.splits.size === 0)
+ assert(rdd.isCheckpointed === false)
+ rdd.checkpoint()
+ assert(rdd.count() === 0)
+ assert(rdd.isCheckpointed === true)
+ assert(rdd.splits.size === 0)
}
/**
@@ -183,7 +185,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging {
testRDDSplitSize: Boolean = false
) {
// Generate the final RDD using given RDD operation
- val baseRDD = generateLongLineageRDD
+ val baseRDD = generateLongLineageRDD()
val operatedRDD = op(baseRDD)
val parentRDD = operatedRDD.dependencies.headOption.orNull
val rddType = operatedRDD.getClass.getSimpleName
@@ -252,12 +254,16 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging {
testRDDSplitSize: Boolean
) {
// Generate the final RDD using given RDD operation
- val baseRDD = generateLongLineageRDD
+ val baseRDD = generateLongLineageRDD()
val operatedRDD = op(baseRDD)
val parentRDD = operatedRDD.dependencies.head.rdd
val rddType = operatedRDD.getClass.getSimpleName
val parentRDDType = parentRDD.getClass.getSimpleName
+ // Get the splits and dependencies of the parent in case they're lazily computed
+ parentRDD.dependencies
+ parentRDD.splits
+
// Find serialized sizes before and after the checkpoint
val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
parentRDD.checkpoint() // checkpoint the parent RDD, not the generated one
@@ -274,7 +280,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging {
if (testRDDSize) {
assert(
rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
- "Size of " + rddType + " did not reduce after parent checkpointing parent " + parentRDDType +
+ "Size of " + rddType + " did not reduce after checkpointing parent " + parentRDDType +
"[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
)
}
@@ -325,10 +331,12 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging {
}
/**
- * Get serialized sizes of the RDD and its splits
+ * Get serialized sizes of the RDD and its splits, in order to test whether the size shrinks
+ * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint.
*/
def getSerializedSizes(rdd: RDD[_]): (Int, Int) = {
- (Utils.serialize(rdd).size, Utils.serialize(rdd.splits).size)
+ (Utils.serialize(rdd).length - Utils.serialize(rdd.checkpointData).length,
+ Utils.serialize(rdd.splits).length)
}
/**
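
Why the new accounting in getSerializedSizes subtracts checkpointData: checkpointing truncates the lineage (shrinking the serialized RDD) but may populate the checkpointData field (growing it), so the raw serialized length alone could mask the shrinkage. A sketch of the idea, assuming code running inside CheckpointSuite (package spark); the helper name is hypothetical:

// lineageSize isolates the lineage shrinkage by excluding checkpointData,
// which can grow once an RDD is checkpointed.
def lineageSize(rdd: RDD[_]): Int =
  Utils.serialize(rdd).length - Utils.serialize(rdd.checkpointData).length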
diff --git a/core/src/test/scala/spark/ClosureCleanerSuite.scala b/core/src/test/scala/spark/ClosureCleanerSuite.scala
index dfa2de80e6..b2d0dd4627 100644
--- a/core/src/test/scala/spark/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/spark/ClosureCleanerSuite.scala
@@ -3,6 +3,7 @@ package spark
import java.io.NotSerializableException
import org.scalatest.FunSuite
+import spark.LocalSparkContext._
import SparkContext._
class ClosureCleanerSuite extends FunSuite {
@@ -43,13 +44,10 @@ object TestObject {
def run(): Int = {
var nonSer = new NonSerializable
var x = 5
- val sc = new SparkContext("local", "test")
- val nums = sc.parallelize(Array(1, 2, 3, 4))
- val answer = nums.map(_ + x).reduce(_ + _)
- sc.stop()
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- return answer
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ nums.map(_ + x).reduce(_ + _)
+ }
}
}
@@ -60,11 +58,10 @@ class TestClass extends Serializable {
def run(): Int = {
var nonSer = new NonSerializable
- val sc = new SparkContext("local", "test")
- val nums = sc.parallelize(Array(1, 2, 3, 4))
- val answer = nums.map(_ + getX).reduce(_ + _)
- sc.stop()
- return answer
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ nums.map(_ + getX).reduce(_ + _)
+ }
}
}
@@ -73,11 +70,10 @@ class TestClassWithoutDefaultConstructor(x: Int) extends Serializable {
def run(): Int = {
var nonSer = new NonSerializable
- val sc = new SparkContext("local", "test")
- val nums = sc.parallelize(Array(1, 2, 3, 4))
- val answer = nums.map(_ + getX).reduce(_ + _)
- sc.stop()
- return answer
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ nums.map(_ + getX).reduce(_ + _)
+ }
}
}
@@ -89,11 +85,10 @@ class TestClassWithoutFieldAccess {
def run(): Int = {
var nonSer2 = new NonSerializable
var x = 5
- val sc = new SparkContext("local", "test")
- val nums = sc.parallelize(Array(1, 2, 3, 4))
- val answer = nums.map(_ + x).reduce(_ + _)
- sc.stop()
- return answer
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ nums.map(_ + x).reduce(_ + _)
+ }
}
}
@@ -102,16 +97,16 @@ object TestObjectWithNesting {
def run(): Int = {
var nonSer = new NonSerializable
var answer = 0
- val sc = new SparkContext("local", "test")
- val nums = sc.parallelize(Array(1, 2, 3, 4))
- var y = 1
- for (i <- 1 to 4) {
- var nonSer2 = new NonSerializable
- var x = i
- answer += nums.map(_ + x + y).reduce(_ + _)
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ var y = 1
+ for (i <- 1 to 4) {
+ var nonSer2 = new NonSerializable
+ var x = i
+ answer += nums.map(_ + x + y).reduce(_ + _)
+ }
+ answer
}
- sc.stop()
- return answer
}
}
@@ -121,14 +116,14 @@ class TestClassWithNesting(val y: Int) extends Serializable {
def run(): Int = {
var nonSer = new NonSerializable
var answer = 0
- val sc = new SparkContext("local", "test")
- val nums = sc.parallelize(Array(1, 2, 3, 4))
- for (i <- 1 to 4) {
- var nonSer2 = new NonSerializable
- var x = i
- answer += nums.map(_ + x + getY).reduce(_ + _)
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ for (i <- 1 to 4) {
+ var nonSer2 = new NonSerializable
+ var x = i
+ answer += nums.map(_ + x + getY).reduce(_ + _)
+ }
+ answer
}
- sc.stop()
- return answer
}
}
diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala
index cacc2796b6..0e2585daa4 100644
--- a/core/src/test/scala/spark/DistributedSuite.scala
+++ b/core/src/test/scala/spark/DistributedSuite.scala
@@ -15,41 +15,28 @@ import scala.collection.mutable.ArrayBuffer
import SparkContext._
import storage.StorageLevel
-class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
+class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter with LocalSparkContext {
val clusterUrl = "local-cluster[2,1,512]"
- @transient var sc: SparkContext = _
-
after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
System.clearProperty("spark.reducer.maxMbInFlight")
System.clearProperty("spark.storage.memoryFraction")
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
}
test("local-cluster format") {
sc = new SparkContext("local-cluster[2,1,512]", "test")
assert(sc.parallelize(1 to 2, 2).count() == 2)
- sc.stop()
- System.clearProperty("spark.master.port")
+ resetSparkContext()
sc = new SparkContext("local-cluster[2 , 1 , 512]", "test")
assert(sc.parallelize(1 to 2, 2).count() == 2)
- sc.stop()
- System.clearProperty("spark.master.port")
+ resetSparkContext()
sc = new SparkContext("local-cluster[2, 1, 512]", "test")
assert(sc.parallelize(1 to 2, 2).count() == 2)
- sc.stop()
- System.clearProperty("spark.master.port")
+ resetSparkContext()
sc = new SparkContext("local-cluster[ 2, 1, 512 ]", "test")
assert(sc.parallelize(1 to 2, 2).count() == 2)
- sc.stop()
- System.clearProperty("spark.master.port")
- sc = null
+ resetSparkContext()
}
test("simple groupByKey") {
@@ -188,4 +175,73 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
val values = sc.parallelize(1 to 2, 2).map(x => System.getenv("TEST_VAR")).collect()
assert(values.toSeq === Seq("TEST_VALUE", "TEST_VALUE"))
}
+
+ test("recover from node failures") {
+ import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity}
+ DistributedSuite.amMaster = true
+ sc = new SparkContext(clusterUrl, "test")
+ val data = sc.parallelize(Seq(true, true), 2)
+ assert(data.count === 2) // force executors to start
+ val masterId = SparkEnv.get.blockManager.blockManagerId
+ assert(data.map(markNodeIfIdentity).collect.size === 2)
+ assert(data.map(failOnMarkedIdentity).collect.size === 2)
+ }
+
+ test("recover from repeated node failures during shuffle-map") {
+ import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity}
+ DistributedSuite.amMaster = true
+ sc = new SparkContext(clusterUrl, "test")
+ for (i <- 1 to 3) {
+ val data = sc.parallelize(Seq(true, false), 2)
+ assert(data.count === 2)
+ assert(data.map(markNodeIfIdentity).collect.size === 2)
+ assert(data.map(failOnMarkedIdentity).map(x => x -> x).groupByKey.count === 2)
+ }
+ }
+
+ test("recover from repeated node failures during shuffle-reduce") {
+ import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity}
+ DistributedSuite.amMaster = true
+ sc = new SparkContext(clusterUrl, "test")
+ for (i <- 1 to 3) {
+ val data = sc.parallelize(Seq(true, true), 2)
+ assert(data.count === 2)
+ assert(data.map(markNodeIfIdentity).collect.size === 2)
+ // This relies on mergeCombiners being used to perform the actual reduce; otherwise
+ // this test would not be testing what it claims.
+ val grouped = data.map(x => x -> x).combineByKey(
+ x => x,
+ (x: Boolean, y: Boolean) => x,
+ (x: Boolean, y: Boolean) => failOnMarkedIdentity(x)
+ )
+ assert(grouped.collect.size === 1)
+ }
+ }
+}
+
+object DistributedSuite {
+ // Indicates whether this JVM is marked for failure.
+ var mark = false
+
+ // Set by test to remember if we are in the driver program so we can assert
+ // that we are not.
+ var amMaster = false
+
+ // Act like an identity function, but if the argument is true, set mark to true.
+ def markNodeIfIdentity(item: Boolean): Boolean = {
+ if (item) {
+ assert(!amMaster)
+ mark = true
+ }
+ item
+ }
+
+ // Act like an identity function, but if mark was set to true previously, fail,
+ // crashing the entire JVM.
+ def failOnMarkedIdentity(item: Boolean): Boolean = {
+ if (mark) {
+ System.exit(42)
+ }
+ item
+ }
}
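
The shuffle-reduce test above relies on mergeCombiners performing the cross-partition reduce. For orientation, a generic sketch of combineByKey's three functions (the example data and counting logic are assumed for illustration, not taken from the commit):

// createCombiner: V => C       first value seen for a key within a partition
// mergeValue:     (C, V) => C  folds later values within the same partition
// mergeCombiners: (C, C) => C  merges partial results across partitions --
//                              the step the test above makes fail
val counts = sc.parallelize(Seq("a" -> 1, "b" -> 1, "a" -> 1)).combineByKey(
  (v: Int) => v,                  // start a count for a new key
  (c: Int, v: Int) => c + v,      // accumulate within a partition
  (c1: Int, c2: Int) => c1 + c2)  // combine across partitions (reduce side)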
diff --git a/core/src/test/scala/spark/DriverSuite.scala b/core/src/test/scala/spark/DriverSuite.scala
new file mode 100644
index 0000000000..5e84b3a66a
--- /dev/null
+++ b/core/src/test/scala/spark/DriverSuite.scala
@@ -0,0 +1,33 @@
+package spark
+
+import java.io.File
+
+import org.scalatest.FunSuite
+import org.scalatest.concurrent.Timeouts
+import org.scalatest.prop.TableDrivenPropertyChecks._
+import org.scalatest.time.SpanSugar._
+
+class DriverSuite extends FunSuite with Timeouts {
+ test("driver should exit after finishing") {
+ assert(System.getenv("SPARK_HOME") != null)
+ // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing"
+ val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]"))
+ forAll(masters) { (master: String) =>
+ failAfter(30 seconds) {
+ Utils.execute(Seq("./run", "spark.DriverWithoutCleanup", master),
+ new File(System.getenv("SPARK_HOME")))
+ }
+ }
+ }
+}
+
+/**
+ * Program that creates a Spark driver but doesn't call SparkContext.stop() or
+ * System.exit() after finishing.
+ */
+object DriverWithoutCleanup {
+ def main(args: Array[String]) {
+ val sc = new SparkContext(args(0), "DriverWithoutCleanup")
+ sc.parallelize(1 to 100, 4).count()
+ }
+}
diff --git a/core/src/test/scala/spark/FailureSuite.scala b/core/src/test/scala/spark/FailureSuite.scala
index a3454f25f6..8c1445a465 100644
--- a/core/src/test/scala/spark/FailureSuite.scala
+++ b/core/src/test/scala/spark/FailureSuite.scala
@@ -1,7 +1,6 @@
package spark
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
import org.scalatest.prop.Checkers
import scala.collection.mutable.ArrayBuffer
@@ -23,18 +22,7 @@ object FailureSuiteState {
}
}
-class FailureSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class FailureSuite extends FunSuite with LocalSparkContext {
// Run a 3-task map job in which task 1 deterministically fails once, and check
// whether the job completes successfully and we ran 4 tasks in total.
diff --git a/core/src/test/scala/spark/FileServerSuite.scala b/core/src/test/scala/spark/FileServerSuite.scala
index b4283d9604..f1a35bced3 100644
--- a/core/src/test/scala/spark/FileServerSuite.scala
+++ b/core/src/test/scala/spark/FileServerSuite.scala
@@ -2,17 +2,16 @@ package spark
import com.google.common.io.Files
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
import java.io.{File, PrintWriter, FileReader, BufferedReader}
import SparkContext._
-class FileServerSuite extends FunSuite with BeforeAndAfter {
+class FileServerSuite extends FunSuite with LocalSparkContext {
- @transient var sc: SparkContext = _
- @transient var tmpFile : File = _
- @transient var testJarFile : File = _
+ @transient var tmpFile: File = _
+ @transient var testJarFile: File = _
- before {
+ override def beforeEach() {
+ super.beforeEach()
// Create a sample text file
val tmpdir = new File(Files.createTempDir(), "test")
tmpdir.mkdir()
@@ -22,17 +21,12 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
pw.close()
}
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
+ override def afterEach() {
+ super.afterEach()
// Clean up downloaded file
if (tmpFile.exists) {
tmpFile.delete()
}
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
}
test("Distributing files locally") {
@@ -40,7 +34,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
sc.addFile(tmpFile.toString)
val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
val result = sc.parallelize(testData).reduceByKey {
- val in = new BufferedReader(new FileReader("FileServerSuite.txt"))
+ val path = SparkFiles.get("FileServerSuite.txt")
+ val in = new BufferedReader(new FileReader(path))
val fileVal = in.readLine().toInt
in.close()
_ * fileVal + _ * fileVal
@@ -54,7 +49,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
sc.addFile((new File(tmpFile.toString)).toURL.toString)
val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
val result = sc.parallelize(testData).reduceByKey {
- val in = new BufferedReader(new FileReader("FileServerSuite.txt"))
+ val path = SparkFiles.get("FileServerSuite.txt")
+ val in = new BufferedReader(new FileReader(path))
val fileVal = in.readLine().toInt
in.close()
_ * fileVal + _ * fileVal
@@ -83,7 +79,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
sc.addFile(tmpFile.toString)
val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
val result = sc.parallelize(testData).reduceByKey {
- val in = new BufferedReader(new FileReader("FileServerSuite.txt"))
+ val path = SparkFiles.get("FileServerSuite.txt")
+ val in = new BufferedReader(new FileReader(path))
val fileVal = in.readLine().toInt
in.close()
_ * fileVal + _ * fileVal
diff --git a/core/src/test/scala/spark/FileSuite.scala b/core/src/test/scala/spark/FileSuite.scala
index 554bea53a9..91b48c7456 100644
--- a/core/src/test/scala/spark/FileSuite.scala
+++ b/core/src/test/scala/spark/FileSuite.scala
@@ -6,24 +6,12 @@ import scala.io.Source
import com.google.common.io.Files
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
import org.apache.hadoop.io._
import SparkContext._
-class FileSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
-
+class FileSuite extends FunSuite with LocalSparkContext {
+
test("text files") {
sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index 0b5354774b..934e4c2f67 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -46,7 +46,7 @@ public class JavaAPISuite implements Serializable {
sc.stop();
sc = null;
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port");
+ System.clearProperty("spark.driver.port");
}
static class ReverseIntComparator implements Comparator<Integer>, Serializable {
@@ -356,6 +356,34 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void mapsFromPairsToPairs() {
+ List<Tuple2<Integer, String>> pairs = Arrays.asList(
+ new Tuple2<Integer, String>(1, "a"),
+ new Tuple2<Integer, String>(2, "aa"),
+ new Tuple2<Integer, String>(3, "aaa")
+ );
+ JavaPairRDD<Integer, String> pairRDD = sc.parallelizePairs(pairs);
+
+ // Regression test for SPARK-668:
+ JavaPairRDD<String, Integer> swapped = pairRDD.flatMap(
+ new PairFlatMapFunction<Tuple2<Integer, String>, String, Integer>() {
+ @Override
+ public Iterable<Tuple2<String, Integer>> call(Tuple2<Integer, String> item) throws Exception {
+ return Collections.singletonList(item.swap());
+ }
+ });
+ swapped.collect();
+
+ // There was never a bug here, but it's worth testing:
+ pairRDD.map(new PairFunction<Tuple2<Integer, String>, String, Integer>() {
+ @Override
+ public Tuple2<String, Integer> call(Tuple2<Integer, String> item) throws Exception {
+ return item.swap();
+ }
+ }).collect();
+ }
+
+ @Test
public void mapPartitions() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
JavaRDD<Integer> partitionSums = rdd.mapPartitions(
@@ -624,6 +652,22 @@ public class JavaAPISuite implements Serializable {
}
});
Assert.assertEquals((Float) 25.0f, floatAccum.value());
+
+ // Test the setValue method
+ floatAccum.setValue(5.0f);
+ Assert.assertEquals((Float) 5.0f, floatAccum.value());
+ }
+
+ @Test
+ public void keyBy() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2));
+ List<Tuple2<String, Integer>> s = rdd.keyBy(new Function<Integer, String>() {
+ public String call(Integer t) throws Exception {
+ return t.toString();
+ }
+ }).collect();
+ Assert.assertEquals(new Tuple2<String, Integer>("1", 1), s.get(0));
+ Assert.assertEquals(new Tuple2<String, Integer>("2", 2), s.get(1));
}
@Test
diff --git a/core/src/test/scala/spark/LocalSparkContext.scala b/core/src/test/scala/spark/LocalSparkContext.scala
new file mode 100644
index 0000000000..ff00dd05dd
--- /dev/null
+++ b/core/src/test/scala/spark/LocalSparkContext.scala
@@ -0,0 +1,41 @@
+package spark
+
+import org.scalatest.Suite
+import org.scalatest.BeforeAndAfterEach
+
+/** Manages a local `sc` [[SparkContext]] variable, correctly stopping it after each test. */
+trait LocalSparkContext extends BeforeAndAfterEach { self: Suite =>
+
+ @transient var sc: SparkContext = _
+
+ override def afterEach() {
+ resetSparkContext()
+ super.afterEach()
+ }
+
+ def resetSparkContext() = {
+ if (sc != null) {
+ LocalSparkContext.stop(sc)
+ sc = null
+ }
+ }
+
+}
+
+object LocalSparkContext {
+ def stop(sc: SparkContext) {
+ sc.stop()
+ // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+ System.clearProperty("spark.driver.port")
+ }
+
+ /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */
+ def withSpark[T](sc: SparkContext)(f: SparkContext => T) = {
+ try {
+ f(sc)
+ } finally {
+ stop(sc)
+ }
+ }
+
+}
\ No newline at end of file
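
A brief usage sketch (not part of the commit) of the withSpark loan pattern this file introduces, matching how the rewritten ClosureCleanerSuite uses it above:

import spark.SparkContext
import spark.LocalSparkContext._

// withSpark stops the context (and clears spark.driver.port) even if the body
// throws, so callers cannot leak a running SparkContext between tests.
val sum = withSpark(new SparkContext("local", "example")) { sc =>
  sc.parallelize(1 to 4).map(_ * 2).reduce(_ + _)
}
// sum == 20, and the context is already stopped here.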
diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
index 5b4b198960..dd19442dcb 100644
--- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
@@ -5,8 +5,10 @@ import org.scalatest.FunSuite
import akka.actor._
import spark.scheduler.MapStatus
import spark.storage.BlockManagerId
+import spark.util.AkkaUtils
-class MapOutputTrackerSuite extends FunSuite {
+class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
+
test("compressSize") {
assert(MapOutputTracker.compressSize(0L) === 0)
assert(MapOutputTracker.compressSize(1L) === 1)
@@ -41,13 +43,13 @@ class MapOutputTrackerSuite extends FunSuite {
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
- tracker.registerMapOutput(10, 0, new MapStatus(new BlockManagerId("hostA", 1000),
+ tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000),
Array(compressedSize1000, compressedSize10000)))
- tracker.registerMapOutput(10, 1, new MapStatus(new BlockManagerId("hostB", 1000),
+ tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000),
Array(compressedSize10000, compressedSize1000)))
val statuses = tracker.getServerStatuses(10, 0)
- assert(statuses.toSeq === Seq((new BlockManagerId("hostA", 1000), size1000),
- (new BlockManagerId("hostB", 1000), size10000)))
+ assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000),
+ (BlockManagerId("b", "hostB", 1000), size10000)))
tracker.stop()
}
@@ -59,18 +61,51 @@ class MapOutputTrackerSuite extends FunSuite {
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
- tracker.registerMapOutput(10, 0, new MapStatus(new BlockManagerId("hostA", 1000),
+ tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000),
Array(compressedSize1000, compressedSize1000, compressedSize1000)))
- tracker.registerMapOutput(10, 1, new MapStatus(new BlockManagerId("hostB", 1000),
+ tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000),
Array(compressedSize10000, compressedSize1000, compressedSize1000)))
// As if we had two simultaneous fetch failures
- tracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000))
- tracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000))
+ tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
+ tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
- // The remaining reduce task might try to grab the output dispite the shuffle failure;
+ // The remaining reduce task might try to grab the output despite the shuffle failure;
// this should cause it to fail, and the scheduler will ignore the failure due to the
// stage already being aborted.
- intercept[Exception] { tracker.getServerStatuses(10, 1) }
+ intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) }
+ }
+
+ test("remote fetch") {
+ try {
+ System.clearProperty("spark.driver.host") // In case some previous test had set it
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", "localhost", 0)
+ System.setProperty("spark.driver.port", boundPort.toString)
+ val masterTracker = new MapOutputTracker(actorSystem, true)
+ val slaveTracker = new MapOutputTracker(actorSystem, false)
+ masterTracker.registerShuffle(10, 1)
+ masterTracker.incrementGeneration()
+ slaveTracker.updateGeneration(masterTracker.getGeneration)
+ intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+
+ val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+ val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
+ masterTracker.registerMapOutput(10, 0, new MapStatus(
+ BlockManagerId("a", "hostA", 1000), Array(compressedSize1000)))
+ masterTracker.incrementGeneration()
+ slaveTracker.updateGeneration(masterTracker.getGeneration)
+ assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
+ Seq((BlockManagerId("a", "hostA", 1000), size1000)))
+
+ masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
+ masterTracker.incrementGeneration()
+ slaveTracker.updateGeneration(masterTracker.getGeneration)
+ intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+
+ // failure should be cached
+ intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+ } finally {
+ System.clearProperty("spark.driver.port")
+ }
}
}
diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala
index f09b602a7b..af1107cd19 100644
--- a/core/src/test/scala/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/spark/PartitioningSuite.scala
@@ -1,25 +1,12 @@
package spark
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
import scala.collection.mutable.ArrayBuffer
import SparkContext._
-class PartitioningSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if(sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
-
+class PartitioningSuite extends FunSuite with LocalSparkContext {
test("HashPartitioner equality") {
val p2 = new HashPartitioner(2)
@@ -106,6 +93,11 @@ class PartitioningSuite extends FunSuite with BeforeAndAfter {
assert(grouped2.leftOuterJoin(reduced2).partitioner === grouped2.partitioner)
assert(grouped2.rightOuterJoin(reduced2).partitioner === grouped2.partitioner)
assert(grouped2.cogroup(reduced2).partitioner === grouped2.partitioner)
+
+ assert(grouped2.map(_ => 1).partitioner === None)
+ assert(grouped2.mapValues(_ => 1).partitioner === grouped2.partitioner)
+ assert(grouped2.flatMapValues(_ => Seq(1)).partitioner === grouped2.partitioner)
+ assert(grouped2.filter(_._1 > 4).partitioner === grouped2.partitioner)
}
test("partitioning Java arrays should fail") {
diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala
index 9b84b29227..a6344edf8f 100644
--- a/core/src/test/scala/spark/PipedRDDSuite.scala
+++ b/core/src/test/scala/spark/PipedRDDSuite.scala
@@ -1,21 +1,9 @@
package spark
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
import SparkContext._
-class PipedRDDSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class PipedRDDSuite extends FunSuite with LocalSparkContext {
test("basic pipe") {
sc = new SparkContext("local", "test")
@@ -51,5 +39,3 @@ class PipedRDDSuite extends FunSuite with BeforeAndAfter {
}
}
-
-
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index e5a59dc7d6..fe7deb10d6 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -2,32 +2,20 @@ package spark
import scala.collection.mutable.HashMap
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
+import spark.SparkContext._
+import spark.rdd.{CoalescedRDD, PartitionPruningRDD}
-import spark.rdd.CoalescedRDD
-import SparkContext._
-
-class RDDSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class RDDSuite extends FunSuite with LocalSparkContext {
test("basic operations") {
sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
assert(nums.collect().toList === List(1, 2, 3, 4))
val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2)
- assert(dups.distinct.count === 4)
- assert(dups.distinct().collect === dups.distinct.collect)
- assert(dups.distinct(2).collect === dups.distinct.collect)
+ assert(dups.distinct().count() === 4)
+ assert(dups.distinct.count === 4) // Can distinct and count be called without parentheses?
+ assert(dups.distinct.collect === dups.distinct().collect)
+ assert(dups.distinct(2).collect === dups.distinct().collect)
assert(nums.reduce(_ + _) === 10)
assert(nums.fold(0)(_ + _) === 10)
assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4"))
@@ -35,6 +23,8 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4))
assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4))
assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4)))
+ assert(nums.collect({ case i if i >= 3 => i.toString }).collect().toList === List("3", "4"))
+ assert(nums.keyBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4)))
val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _)))
assert(partitionSums.collect().toList === List(3, 7))
@@ -42,6 +32,10 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
case(split, iter) => Iterator((split, iter.reduceLeft(_ + _)))
}
assert(partitionSumsWithSplit.collect().toList === List((0, 3), (1, 7)))
+
+ intercept[UnsupportedOperationException] {
+ nums.filter(_ > 5).reduce(_ + _)
+ }
}
test("SparkContext.union") {
@@ -102,7 +96,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
}
test("caching with failures") {
- sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val onlySplit = new Split { override def index: Int = 0 }
var shouldFail = true
val rdd = new RDD[Int](sc, Nil) {
@@ -134,8 +128,10 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10)))
// Check that the narrow dependency is also specified correctly
- assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList === List(0, 1, 2, 3, 4))
- assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList === List(5, 6, 7, 8, 9))
+ assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList ===
+ List(0, 1, 2, 3, 4))
+ assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList ===
+ List(5, 6, 7, 8, 9))
val coalesced2 = new CoalescedRDD(data, 3)
assert(coalesced2.collect().toList === (1 to 10).toList)
@@ -166,4 +162,15 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
nums.zip(sc.parallelize(1 to 4, 1)).collect()
}
}
+
+ test("partition pruning") {
+ sc = new SparkContext("local", "test")
+ val data = sc.parallelize(1 to 10, 10)
+ // Note that split numbers start from 0, so > 8 keeps only the 10th partition.
+ val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8)
+ assert(prunedRdd.splits.size === 1)
+ val prunedData = prunedRdd.collect()
+ assert(prunedData.size === 1)
+ assert(prunedData(0) === 10)
+ }
}
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 8170100f1d..3493b9511f 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -3,7 +3,6 @@ package spark
import scala.collection.mutable.ArrayBuffer
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
import org.scalatest.matchers.ShouldMatchers
import org.scalatest.prop.Checkers
import org.scalacheck.Arbitrary._
@@ -15,18 +14,7 @@ import com.google.common.io.Files
import spark.rdd.ShuffledRDD
import spark.SparkContext._
-class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
test("groupByKey") {
sc = new SparkContext("local", "test")
@@ -216,6 +204,13 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
// Test that a shuffle on the file works, because this used to be a bug
assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
}
+
+ test("keys and values") {
+ sc = new SparkContext("local", "test")
+ val rdd = sc.parallelize(Array((1, "a"), (2, "b")))
+ assert(rdd.keys.collect().toList === List(1, 2))
+ assert(rdd.values.collect().toList === List("a", "b"))
+ }
}
object ShuffleSuite {
diff --git a/core/src/test/scala/spark/SizeEstimatorSuite.scala b/core/src/test/scala/spark/SizeEstimatorSuite.scala
index 17f366212b..e235ef2f67 100644
--- a/core/src/test/scala/spark/SizeEstimatorSuite.scala
+++ b/core/src/test/scala/spark/SizeEstimatorSuite.scala
@@ -3,7 +3,6 @@ package spark
import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfterAll
import org.scalatest.PrivateMethodTester
-import org.scalatest.matchers.ShouldMatchers
class DummyClass1 {}
@@ -20,8 +19,17 @@ class DummyClass4(val d: DummyClass3) {
val x: Int = 0
}
+object DummyString {
+ def apply(str: String) : DummyString = new DummyString(str.toArray)
+}
+class DummyString(val arr: Array[Char]) {
+ override val hashCode: Int = 0
+ // JDK-7 has an extra hash32 field http://hg.openjdk.java.net/jdk7u/jdk7u6/jdk/rev/11987e85555f
+ @transient val hash32: Int = 0
+}
+
class SizeEstimatorSuite
- extends FunSuite with BeforeAndAfterAll with PrivateMethodTester with ShouldMatchers {
+ extends FunSuite with BeforeAndAfterAll with PrivateMethodTester {
var oldArch: String = _
var oldOops: String = _
@@ -45,15 +53,13 @@ class SizeEstimatorSuite
expect(48)(SizeEstimator.estimate(new DummyClass4(new DummyClass3)))
}
- // NOTE: The String class definition changed in JDK 7 to exclude the int fields count and length.
- // This means that the size of strings will be lesser by 8 bytes in JDK 7 compared to JDK 6.
- // http://mail.openjdk.java.net/pipermail/core-libs-dev/2012-May/010257.html
- // Work around to check for either.
+ // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors
+ // (Sun vs IBM). Use a DummyString class to make tests deterministic.
test("strings") {
- SizeEstimator.estimate("") should (equal (48) or equal (40))
- SizeEstimator.estimate("a") should (equal (56) or equal (48))
- SizeEstimator.estimate("ab") should (equal (56) or equal (48))
- SizeEstimator.estimate("abcdefgh") should (equal(64) or equal(56))
+ expect(40)(SizeEstimator.estimate(DummyString("")))
+ expect(48)(SizeEstimator.estimate(DummyString("a")))
+ expect(48)(SizeEstimator.estimate(DummyString("ab")))
+ expect(56)(SizeEstimator.estimate(DummyString("abcdefgh")))
}
test("primitive arrays") {
@@ -105,18 +111,16 @@ class SizeEstimatorSuite
val initialize = PrivateMethod[Unit]('initialize)
SizeEstimator invokePrivate initialize()
- expect(40)(SizeEstimator.estimate(""))
- expect(48)(SizeEstimator.estimate("a"))
- expect(48)(SizeEstimator.estimate("ab"))
- expect(56)(SizeEstimator.estimate("abcdefgh"))
+ expect(40)(SizeEstimator.estimate(DummyString("")))
+ expect(48)(SizeEstimator.estimate(DummyString("a")))
+ expect(48)(SizeEstimator.estimate(DummyString("ab")))
+ expect(56)(SizeEstimator.estimate(DummyString("abcdefgh")))
resetOrClear("os.arch", arch)
}
- // NOTE: The String class definition changed in JDK 7 to exclude the int fields count and length.
- // This means that the size of strings will be lesser by 8 bytes in JDK 7 compared to JDK 6.
- // http://mail.openjdk.java.net/pipermail/core-libs-dev/2012-May/010257.html
- // Work around to check for either.
+ // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors
+ // (Sun vs IBM). Use a DummyString class to make tests deterministic.
test("64-bit arch with no compressed oops") {
val arch = System.setProperty("os.arch", "amd64")
val oops = System.setProperty("spark.test.useCompressedOops", "false")
@@ -124,10 +128,10 @@ class SizeEstimatorSuite
val initialize = PrivateMethod[Unit]('initialize)
SizeEstimator invokePrivate initialize()
- SizeEstimator.estimate("") should (equal (64) or equal (56))
- SizeEstimator.estimate("a") should (equal (72) or equal (64))
- SizeEstimator.estimate("ab") should (equal (72) or equal (64))
- SizeEstimator.estimate("abcdefgh") should (equal (80) or equal (72))
+ expect(56)(SizeEstimator.estimate(DummyString("")))
+ expect(64)(SizeEstimator.estimate(DummyString("a")))
+ expect(64)(SizeEstimator.estimate(DummyString("ab")))
+ expect(72)(SizeEstimator.estimate(DummyString("abcdefgh")))
resetOrClear("os.arch", arch)
resetOrClear("spark.test.useCompressedOops", oops)
diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala
index 1ad11ff4c3..edb8c839fc 100644
--- a/core/src/test/scala/spark/SortingSuite.scala
+++ b/core/src/test/scala/spark/SortingSuite.scala
@@ -5,18 +5,7 @@ import org.scalatest.BeforeAndAfter
import org.scalatest.matchers.ShouldMatchers
import SparkContext._
-class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with Logging {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers with Logging {
test("sortByKey") {
sc = new SparkContext("local", "test")
diff --git a/core/src/test/scala/spark/ThreadingSuite.scala b/core/src/test/scala/spark/ThreadingSuite.scala
index e9b1837d89..ff315b6693 100644
--- a/core/src/test/scala/spark/ThreadingSuite.scala
+++ b/core/src/test/scala/spark/ThreadingSuite.scala
@@ -22,19 +22,7 @@ object ThreadingSuiteState {
}
}
-class ThreadingSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if(sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
-
+class ThreadingSuite extends FunSuite with LocalSparkContext {
test("accessing SparkContext form a different thread") {
sc = new SparkContext("local", "test")
diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
new file mode 100644
index 0000000000..83663ac702
--- /dev/null
+++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
@@ -0,0 +1,663 @@
+package spark.scheduler
+
+import scala.collection.mutable.{Map, HashMap}
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+import org.scalatest.concurrent.TimeLimitedTests
+import org.scalatest.mock.EasyMockSugar
+import org.scalatest.time.{Span, Seconds}
+
+import org.easymock.EasyMock._
+import org.easymock.Capture
+import org.easymock.EasyMock
+import org.easymock.{IAnswer, IArgumentMatcher}
+
+import akka.actor.ActorSystem
+
+import spark.storage.BlockManager
+import spark.storage.BlockManagerId
+import spark.storage.BlockManagerMaster
+import spark.{Dependency, ShuffleDependency, OneToOneDependency}
+import spark.FetchFailedException
+import spark.MapOutputTracker
+import spark.RDD
+import spark.SparkContext
+import spark.SparkException
+import spark.Split
+import spark.TaskContext
+import spark.TaskEndReason
+
+import spark.{FetchFailed, Success}
+
+/**
+ * Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler
+ * rather than spawning an event loop thread as happens in the real code. They use EasyMock
+ * to mock out two classes that DAGScheduler interacts with: TaskScheduler (to which TaskSets are
+ * submitted) and BlockManagerMaster (from which cache locations are retrieved and to which dead
+ * host notifications are sent). In addition, tests may check for side effects on a non-mocked
+ * MapOutputTracker instance.
+ *
+ * Tests primarily consist of running DAGScheduler#processEvent and
+ * DAGScheduler#submitWaitingStages (via test utility functions like runEvent or respondToTaskSet)
+ * and capturing the resulting TaskSets from the mock TaskScheduler.
+ */
+class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar with TimeLimitedTests {
+
+ // impose a time limit on this test in case we don't let the job finish, in which case
+ // JobWaiter#getResult will hang.
+ override val timeLimit = Span(5, Seconds)
+
+ val sc: SparkContext = new SparkContext("local", "DAGSchedulerSuite")
+ var scheduler: DAGScheduler = null
+ val taskScheduler = mock[TaskScheduler]
+ val blockManagerMaster = mock[BlockManagerMaster]
+ var mapOutputTracker: MapOutputTracker = null
+ var schedulerThread: Thread = null
+ var schedulerException: Throwable = null
+
+ /**
+ * Set of EasyMock argument matchers that match a TaskSet for a given RDD.
+ * We cache these so we do not create duplicate matchers for the same RDD.
+ * This allows us to easily setup a sequence of expectations for task sets for
+ * that RDD.
+ */
+ val taskSetMatchers = new HashMap[MyRDD, IArgumentMatcher]
+
+ /**
+ * Set of cache locations to return from our mock BlockManagerMaster.
+ * Keys are (rdd ID, partition ID). Anything not present will return an empty
+ * list of cache locations silently.
+ */
+ val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]
+
+ /**
+ * JobWaiter for the last JobSubmitted event we pushed. To keep tests (most of which
+ * will only submit one job) from needing to explicitly track it.
+ */
+ var lastJobWaiter: JobWaiter[Int] = null
+
+ /**
+ * Array into which we are accumulating the results from the last job asynchronously.
+ */
+ var lastJobResult: Array[Int] = null
+
+ /**
+ * Tell EasyMockSugar what mock objects we want to be configured by expecting {...}
+ * and whenExecuting {...}
+ */
+ implicit val mocks = MockObjects(taskScheduler, blockManagerMaster)
+
+ /**
+   * Utility function to reset mocks and set expectations on them. EasyMock requires mocks
+   * to be reset each time new expectations are set on them, and we tend to check mock
+   * calls over a single call into DAGScheduler.
+ *
+ * We also set a default expectation here that blockManagerMaster.getLocations can be called
+ * and will return values from cacheLocations.
+ */
+ def resetExpecting(f: => Unit) {
+ reset(taskScheduler)
+ reset(blockManagerMaster)
+ expecting {
+ expectGetLocations()
+ f
+ }
+ }
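+
+  // resetExpecting and whenExecuting are always used as a pair; sketch of the pattern:
+  //
+  //   resetExpecting {
+  //     taskScheduler.stop()   // record the call we expect
+  //   }
+  //   whenExecuting {
+  //     scheduler.stop()       // replay; the mock verifies the expected call happens
+  //   }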
+
+ before {
+ taskSetMatchers.clear()
+ cacheLocations.clear()
+ val actorSystem = ActorSystem("test")
+ mapOutputTracker = new MapOutputTracker(actorSystem, true)
+ resetExpecting {
+ taskScheduler.setListener(anyObject())
+ }
+ whenExecuting {
+ scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null)
+ }
+ }
+
+ after {
+ assert(scheduler.processEvent(StopDAGScheduler))
+ resetExpecting {
+ taskScheduler.stop()
+ }
+ whenExecuting {
+ scheduler.stop()
+ }
+ sc.stop()
+ System.clearProperty("spark.master.port")
+ }
+
+ def makeBlockManagerId(host: String): BlockManagerId =
+ BlockManagerId("exec-" + host, host, 12345)
+
+ /**
+ * Type of RDD we use for testing. Note that we should never call the real RDD compute methods.
+ * This is a pair RDD type so it can always be used in ShuffleDependencies.
+ */
+ type MyRDD = RDD[(Int, Int)]
+
+ /**
+ * Create an RDD for passing to DAGScheduler. These RDDs will use the dependencies and
+ * preferredLocations (if any) that are passed to them. They are deliberately not executable
+ * so we can test that DAGScheduler does not try to execute RDDs locally.
+ */
+ def makeRdd(
+ numSplits: Int,
+ dependencies: List[Dependency[_]],
+ locations: Seq[Seq[String]] = Nil
+ ): MyRDD = {
+ val maxSplit = numSplits - 1
+ return new MyRDD(sc, dependencies) {
+ override def compute(split: Split, context: TaskContext): Iterator[(Int, Int)] =
+ throw new RuntimeException("should not be reached")
+ override def getSplits() = (0 to maxSplit).map(i => new Split {
+ override def index = i
+ }).toArray
+ override def getPreferredLocations(split: Split): Seq[String] =
+ if (locations.isDefinedAt(split.index))
+ locations(split.index)
+ else
+ Nil
+ override def toString: String = "DAGSchedulerSuiteRDD " + id
+ }
+ }
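+
+  // For example, the shuffle tests below build a two-stage graph along these lines
+  // (a sketch; passing null as the partitioner is harmless because these RDDs are
+  // deliberately never executed):
+  //
+  //   val shuffleMapRdd = makeRdd(2, Nil)
+  //   val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+  //   val reduceRdd = makeRdd(1, List(shuffleDep))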
+
+ /**
+ * EasyMock matcher method. For use as an argument matcher for a TaskSet whose first task
+ * is from a particular RDD.
+ */
+ def taskSetForRdd(rdd: MyRDD): TaskSet = {
+ val matcher = taskSetMatchers.getOrElseUpdate(rdd,
+ new IArgumentMatcher {
+ override def matches(actual: Any): Boolean = {
+ val taskSet = actual.asInstanceOf[TaskSet]
+ taskSet.tasks(0) match {
+ case rt: ResultTask[_, _] => rt.rdd.id == rdd.id
+ case smt: ShuffleMapTask => smt.rdd.id == rdd.id
+ case _ => false
+ }
+ }
+ override def appendTo(buf: StringBuffer) {
+ buf.append("taskSetForRdd(" + rdd + ")")
+ }
+ })
+ EasyMock.reportMatcher(matcher)
+ return null
+ }
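+
+  // This follows EasyMock's custom argument-matcher protocol: reportMatcher() registers
+  // the matcher for the next recorded argument and the null return value is discarded,
+  // which lets the method be used inline in an expectation, e.g.:
+  //   taskScheduler.submitTasks(and(capture(taskSetCapture), taskSetForRdd(rdd)))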
+
+ /**
+   * Set up an EasyMock expectation to respond to blockManagerMaster.getLocations() with
+   * the values stored in cacheLocations.
+ */
+ def expectGetLocations(): Unit = {
+ EasyMock.expect(blockManagerMaster.getLocations(anyObject().asInstanceOf[Array[String]])).
+ andAnswer(new IAnswer[Seq[Seq[BlockManagerId]]] {
+ override def answer(): Seq[Seq[BlockManagerId]] = {
+ val blocks = getCurrentArguments()(0).asInstanceOf[Array[String]]
+ return blocks.map { name =>
+ val pieces = name.split("_")
+ if (pieces(0) == "rdd") {
+ val key = pieces(1).toInt -> pieces(2).toInt
+ if (cacheLocations.contains(key)) {
+ cacheLocations(key)
+ } else {
+ Seq[BlockManagerId]()
+ }
+ } else {
+ Seq[BlockManagerId]()
+ }
+ }.toSeq
+ }
+ }).anyTimes()
+ }
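+
+  // Block names follow the "rdd_<rddId>_<splitIndex>" convention, so a request for
+  // "rdd_1_0" is answered from cacheLocations((1, 0)), while any other block name
+  // (e.g. a shuffle or broadcast block) silently gets an empty location list.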
+
+ /**
+ * Process the supplied event as if it were the top of the DAGScheduler event queue, expecting
+ * the scheduler not to exit.
+ *
+ * After processing the event, submit waiting stages as is done on most iterations of the
+ * DAGScheduler event loop.
+ */
+ def runEvent(event: DAGSchedulerEvent) {
+ assert(!scheduler.processEvent(event))
+ scheduler.submitWaitingStages()
+ }
+
+ /**
+ * Expect a TaskSet for the specified RDD to be submitted to the TaskScheduler. Should be
+ * called from a resetExpecting { ... } block.
+ *
+   * Returns an EasyMock Capture that will contain the task set after the stage is submitted.
+ * Most tests should use interceptStage() instead of this directly.
+ */
+ def expectStage(rdd: MyRDD): Capture[TaskSet] = {
+ val taskSetCapture = new Capture[TaskSet]
+ taskScheduler.submitTasks(and(capture(taskSetCapture), taskSetForRdd(rdd)))
+ return taskSetCapture
+ }
+
+ /**
+ * Expect the supplied code snippet to submit a stage for the specified RDD.
+   * Return the resulting TaskSet, after marking all of its tasks as belonging to the
+   * current MapOutputTracker generation.
+ */
+ def interceptStage(rdd: MyRDD)(f: => Unit): TaskSet = {
+ var capture: Capture[TaskSet] = null
+ resetExpecting {
+ capture = expectStage(rdd)
+ }
+ whenExecuting {
+ f
+ }
+ val taskSet = capture.getValue
+ for (task <- taskSet.tasks) {
+ task.generation = mapOutputTracker.getGeneration
+ }
+ return taskSet
+ }
+
+ /**
+ * Send the given CompletionEvent messages for the tasks in the TaskSet.
+ */
+ def respondToTaskSet(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) {
+ assert(taskSet.tasks.size >= results.size)
+ for ((result, i) <- results.zipWithIndex) {
+ if (i < taskSet.tasks.size) {
+ runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any]()))
+ }
+ }
+ }
+
+ /**
+ * Assert that the supplied TaskSet has exactly the given preferredLocations.
+ */
+ def expectTaskSetLocations(taskSet: TaskSet, locations: Seq[Seq[String]]) {
+ assert(locations.size === taskSet.tasks.size)
+ for ((expectLocs, taskLocs) <-
+ taskSet.tasks.map(_.preferredLocations).zip(locations)) {
+ assert(expectLocs === taskLocs)
+ }
+ }
+
+ /**
+ * When we submit dummy Jobs, this is the compute function we supply. Except in a local test
+ * below, we do not expect this function to ever be executed; instead, we will return results
+ * directly through CompletionEvents.
+ */
+ def jobComputeFunc(context: TaskContext, it: Iterator[(Int, Int)]): Int =
+ it.next._1.asInstanceOf[Int]
+
+ /**
+   * Start a job to compute the given RDD. Returns the JobWaiter that will collect the
+   * result of the job via callbacks from DAGScheduler, along with the array into which
+   * the results are written.
+ */
+ def submitRdd(rdd: MyRDD, allowLocal: Boolean = false): (JobWaiter[Int], Array[Int]) = {
+ val resultArray = new Array[Int](rdd.splits.size)
+ val (toSubmit, waiter) = scheduler.prepareJob[(Int, Int), Int](
+ rdd,
+ jobComputeFunc,
+ (0 to (rdd.splits.size - 1)),
+ "test-site",
+ allowLocal,
+ (i: Int, value: Int) => resultArray(i) = value
+ )
+ lastJobWaiter = waiter
+ lastJobResult = resultArray
+ runEvent(toSubmit)
+ return (waiter, resultArray)
+ }
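+
+  // Note that prepareJob only builds the JobSubmitted event and its JobWaiter; submitRdd
+  // pushes the event through runEvent itself rather than going through the scheduler's
+  // event loop thread (which these tests never start).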
+
+ /**
+ * Assert that a job we started has failed.
+ */
+ def expectJobException(waiter: JobWaiter[Int] = lastJobWaiter) {
+ waiter.awaitResult() match {
+ case JobSucceeded => fail()
+ case JobFailed(_) => return
+ }
+ }
+
+ /**
+ * Assert that a job we started has succeeded and has the given result.
+ */
+ def expectJobResult(expected: Array[Int], waiter: JobWaiter[Int] = lastJobWaiter,
+ result: Array[Int] = lastJobResult) {
+ waiter.awaitResult match {
+ case JobSucceeded =>
+ assert(expected === result)
+ case JobFailed(_) =>
+ fail()
+ }
+ }
+
+ def makeMapStatus(host: String, reduces: Int): MapStatus =
+ new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2))
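+
+  // The Array.fill[Byte](reduces)(2) stands in for MapStatus's per-reducer compressed
+  // output sizes; these tests only need some plausible size for each reduce partition,
+  // so a constant byte works.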
+
+ test("zero split job") {
+ val rdd = makeRdd(0, Nil)
+ var numResults = 0
+ def accumulateResult(partition: Int, value: Int) {
+ numResults += 1
+ }
+ scheduler.runJob(rdd, jobComputeFunc, Seq(), "test-site", false, accumulateResult)
+ assert(numResults === 0)
+ }
+
+ test("run trivial job") {
+ val rdd = makeRdd(1, Nil)
+ val taskSet = interceptStage(rdd) { submitRdd(rdd) }
+ respondToTaskSet(taskSet, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("local job") {
+ val rdd = new MyRDD(sc, Nil) {
+ override def compute(split: Split, context: TaskContext): Iterator[(Int, Int)] =
+ Array(42 -> 0).iterator
+ override def getSplits() = Array( new Split { override def index = 0 } )
+ override def getPreferredLocations(split: Split) = Nil
+ override def toString = "DAGSchedulerSuite Local RDD"
+ }
+ submitRdd(rdd, true)
+ expectJobResult(Array(42))
+ }
+
+ test("run trivial job w/ dependency") {
+ val baseRdd = makeRdd(1, Nil)
+ val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
+ val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) }
+ respondToTaskSet(taskSet, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("cache location preferences w/ dependency") {
+ val baseRdd = makeRdd(1, Nil)
+ val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
+ cacheLocations(baseRdd.id -> 0) =
+ Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
+ val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) }
+ expectTaskSetLocations(taskSet, List(Seq("hostA", "hostB")))
+ respondToTaskSet(taskSet, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("trivial job failure") {
+ val rdd = makeRdd(1, Nil)
+ val taskSet = interceptStage(rdd) { submitRdd(rdd) }
+ runEvent(TaskSetFailed(taskSet, "test failure"))
+ expectJobException()
+ }
+
+ test("run trivial shuffle") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(1, List(shuffleDep))
+
+ val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+ val secondStage = interceptStage(reduceRdd) {
+ respondToTaskSet(firstStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
+ respondToTaskSet(secondStage, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("run trivial shuffle with fetch failure") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(2, List(shuffleDep))
+
+ val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+ val secondStage = interceptStage(reduceRdd) {
+ respondToTaskSet(firstStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ respondToTaskSet(secondStage, List(
+ (Success, 42),
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null)
+ ))
+ }
+ val thirdStage = interceptStage(shuffleMapRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ val fourthStage = interceptStage(reduceRdd) {
+ respondToTaskSet(thirdStage, List( (Success, makeMapStatus("hostA", 1)) ))
+ }
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
+ respondToTaskSet(fourthStage, List( (Success, 43) ))
+ expectJobResult(Array(42, 43))
+ }
+
+ test("ignore late map task completions") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(2, List(shuffleDep))
+
+ val taskSet = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+ val oldGeneration = mapOutputTracker.getGeneration
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ runEvent(ExecutorLost("exec-hostA"))
+ }
+ val newGeneration = mapOutputTracker.getGeneration
+ assert(newGeneration > oldGeneration)
+ val noAccum = Map[Long, Any]()
+    // We rely on the event queue being ordered and increasing the generation number by 1.
+    // This completion from hostA should be ignored because the task's generation is
+    // older than hostA's failure:
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum))
+ // should work because it's a non-failed host
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum))
+    // this second completion from hostA should again be ignored for being too old
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum))
+ taskSet.tasks(1).generation = newGeneration
+ val secondStage = interceptStage(reduceRdd) {
+ runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum))
+ }
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
+ respondToTaskSet(secondStage, List( (Success, 42), (Success, 43) ))
+ expectJobResult(Array(42, 43))
+ }
+
+ test("run trivial shuffle with out-of-band failure and retry") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(1, List(shuffleDep))
+
+ val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ runEvent(ExecutorLost("exec-hostA"))
+ }
+    // DAGScheduler will immediately resubmit the stage once it appears to have no pending tasks,
+    // rather than marking it as failed and waiting.
+ val secondStage = interceptStage(shuffleMapRdd) {
+ respondToTaskSet(firstStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ val thirdStage = interceptStage(reduceRdd) {
+ respondToTaskSet(secondStage, List(
+ (Success, makeMapStatus("hostC", 1))
+ ))
+ }
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
+ respondToTaskSet(thirdStage, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("recursive shuffle failures") {
+ val shuffleOneRdd = makeRdd(2, Nil)
+ val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
+ val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
+ val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
+ val finalRdd = makeRdd(1, List(shuffleDepTwo))
+
+ val firstStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
+ val secondStage = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(firstStage, List(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))
+ ))
+ }
+ val thirdStage = interceptStage(finalRdd) {
+ respondToTaskSet(secondStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostC", 1))
+ ))
+ }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ respondToTaskSet(thirdStage, List(
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
+ ))
+ }
+ val recomputeOne = interceptStage(shuffleOneRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ val recomputeTwo = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(recomputeOne, List(
+ (Success, makeMapStatus("hostA", 2))
+ ))
+ }
+ val finalStage = interceptStage(finalRdd) {
+ respondToTaskSet(recomputeTwo, List(
+ (Success, makeMapStatus("hostA", 1))
+ ))
+ }
+ respondToTaskSet(finalStage, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("cached post-shuffle") {
+ val shuffleOneRdd = makeRdd(2, Nil)
+ val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
+ val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
+ val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
+ val finalRdd = makeRdd(1, List(shuffleDepTwo))
+
+ val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
+ cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
+ cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
+ val secondShuffleStage = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(firstShuffleStage, List(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))
+ ))
+ }
+ val reduceStage = interceptStage(finalRdd) {
+ respondToTaskSet(secondShuffleStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ respondToTaskSet(reduceStage, List(
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
+ ))
+ }
+ // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun.
+ val recomputeTwo = interceptStage(shuffleTwoRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ expectTaskSetLocations(recomputeTwo, Seq(Seq("hostD")))
+ val finalRetry = interceptStage(finalRdd) {
+ respondToTaskSet(recomputeTwo, List(
+ (Success, makeMapStatus("hostD", 1))
+ ))
+ }
+ respondToTaskSet(finalRetry, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("cached post-shuffle but fails") {
+ val shuffleOneRdd = makeRdd(2, Nil)
+ val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
+ val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
+ val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
+ val finalRdd = makeRdd(1, List(shuffleDepTwo))
+
+ val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
+ cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
+ cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
+ val secondShuffleStage = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(firstShuffleStage, List(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))
+ ))
+ }
+ val reduceStage = interceptStage(finalRdd) {
+ respondToTaskSet(secondShuffleStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ respondToTaskSet(reduceStage, List(
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
+ ))
+ }
+ val recomputeTwoCached = interceptStage(shuffleTwoRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ expectTaskSetLocations(recomputeTwoCached, Seq(Seq("hostD")))
+ intercept[FetchFailedException]{
+ mapOutputTracker.getServerStatuses(shuffleDepOne.shuffleId, 0)
+ }
+
+ // Simulate the shuffle input data failing to be cached.
+ cacheLocations.remove(shuffleTwoRdd.id -> 0)
+ respondToTaskSet(recomputeTwoCached, List(
+ (FetchFailed(null, shuffleDepOne.shuffleId, 0, 0), null)
+ ))
+
+ // After the fetch failure, DAGScheduler should recheck the cache and decide to resubmit
+ // everything.
+ val recomputeOne = interceptStage(shuffleOneRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ // We use hostA here to make sure DAGScheduler doesn't think it's still dead.
+ val recomputeTwoUncached = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(recomputeOne, List( (Success, makeMapStatus("hostA", 1)) ))
+ }
+ expectTaskSetLocations(recomputeTwoUncached, Seq(Seq[String]()))
+ val finalRetry = interceptStage(finalRdd) {
+      respondToTaskSet(recomputeTwoUncached, List( (Success, makeMapStatus("hostA", 1)) ))
+    }
+ respondToTaskSet(finalRetry, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+}
diff --git a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala
new file mode 100644
index 0000000000..a5db7103f5
--- /dev/null
+++ b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala
@@ -0,0 +1,32 @@
+package spark.scheduler
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+import spark.TaskContext
+import spark.RDD
+import spark.SparkContext
+import spark.Split
+import spark.LocalSparkContext
+
+class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
+
+ test("Calls executeOnCompleteCallbacks after failure") {
+ var completed = false
+ sc = new SparkContext("local", "test")
+ val rdd = new RDD[String](sc, List()) {
+ override def getSplits = Array[Split](StubSplit(0))
+ override def compute(split: Split, context: TaskContext) = {
+ context.addOnCompleteCallback(() => completed = true)
+ sys.error("failed")
+ }
+ }
+ val func = (c: TaskContext, i: Iterator[String]) => i.next
+ val task = new ResultTask[String, String](0, rdd, func, 0, Seq(), 0)
+ intercept[RuntimeException] {
+ task.run(0)
+ }
+ assert(completed === true)
+ }
+
+ case class StubSplit(val index: Int) extends Split
+}
\ No newline at end of file
diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
index 8f86e3170e..2d177bbf67 100644
--- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
@@ -69,33 +69,41 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("StorageLevel object caching") {
- val level1 = new StorageLevel(false, false, false, 3)
- val level2 = new StorageLevel(false, false, false, 3)
+ val level1 = StorageLevel(false, false, false, 3)
+ val level2 = StorageLevel(false, false, false, 3) // this should return the same object as level1
+ val level3 = StorageLevel(false, false, false, 2) // this should return a different object
+ assert(level2 === level1, "level2 is not same as level1")
+ assert(level2.eq(level1), "level2 is not the same object as level1")
+ assert(level3 != level1, "level3 is same as level1")
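+    // (The StorageLevel(...) factory is presumably interning instances, which is why
+    // reference equality -- eq -- holds between equal levels here and below.)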
val bytes1 = spark.Utils.serialize(level1)
val level1_ = spark.Utils.deserialize[StorageLevel](bytes1)
val bytes2 = spark.Utils.serialize(level2)
val level2_ = spark.Utils.deserialize[StorageLevel](bytes2)
assert(level1_ === level1, "Deserialized level1 not same as original level1")
- assert(level2_ === level2, "Deserialized level2 not same as original level1")
- assert(level1_ === level2_, "Deserialized level1 not same as deserialized level2")
- assert(level2_.eq(level1_), "Deserialized level2 not the same object as deserialized level1")
+    assert(level1_.eq(level1), "Deserialized level1 not the same object as original level1")
+ assert(level2_ === level2, "Deserialized level2 not same as original level2")
+ assert(level2_.eq(level1), "Deserialized level2 not the same object as original level1")
}
test("BlockManagerId object caching") {
- val id1 = new StorageLevel(false, false, false, 3)
- val id2 = new StorageLevel(false, false, false, 3)
+ val id1 = BlockManagerId("e1", "XXX", 1)
+ val id2 = BlockManagerId("e1", "XXX", 1) // this should return the same object as id1
+ val id3 = BlockManagerId("e1", "XXX", 2) // this should return a different object
+ assert(id2 === id1, "id2 is not same as id1")
+ assert(id2.eq(id1), "id2 is not the same object as id1")
+ assert(id3 != id1, "id3 is same as id1")
val bytes1 = spark.Utils.serialize(id1)
- val id1_ = spark.Utils.deserialize[StorageLevel](bytes1)
+ val id1_ = spark.Utils.deserialize[BlockManagerId](bytes1)
val bytes2 = spark.Utils.serialize(id2)
- val id2_ = spark.Utils.deserialize[StorageLevel](bytes2)
- assert(id1_ === id1, "Deserialized id1 not same as original id1")
- assert(id2_ === id2, "Deserialized id2 not same as original id1")
- assert(id1_ === id2_, "Deserialized id1 not same as deserialized id2")
- assert(id2_.eq(id1_), "Deserialized id2 not the same object as deserialized level1")
+ val id2_ = spark.Utils.deserialize[BlockManagerId](bytes2)
+ assert(id1_ === id1, "Deserialized id1 is not same as original id1")
+ assert(id1_.eq(id1), "Deserialized id1 is not the same object as original id1")
+ assert(id2_ === id2, "Deserialized id2 is not same as original id2")
+ assert(id2_.eq(id1), "Deserialized id2 is not the same object as original id1")
}
test("master + 1 manager interaction") {
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -125,8 +133,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("master + 2 managers interaction") {
- store = new BlockManager(actorSystem, master, serializer, 2000)
- store2 = new BlockManager(actorSystem, master, new KryoSerializer, 2000)
+ store = new BlockManager("exec1", actorSystem, master, serializer, 2000)
+ store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer, 2000)
val peers = master.getPeers(store.blockManagerId, 1)
assert(peers.size === 1, "master did not return the other manager as a peer")
@@ -141,7 +149,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("removing block") {
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -190,7 +198,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
test("reregistration on heart beat") {
val heartBeat = PrivateMethod[Unit]('heartBeat)
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
val a1 = new Array[Byte](400)
store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
@@ -198,7 +206,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
assert(store.getSingle("a1") != None, "a1 was not in store")
assert(master.getLocations("a1").size > 0, "master was not told about a1")
- master.notifyADeadHost(store.blockManagerId.ip)
+ master.removeExecutor(store.blockManagerId.executorId)
assert(master.getLocations("a1").size == 0, "a1 was not removed from master")
store invokePrivate heartBeat()
@@ -206,25 +214,63 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("reregistration on block update") {
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
-
assert(master.getLocations("a1").size > 0, "master was not told about a1")
- master.notifyADeadHost(store.blockManagerId.ip)
+ master.removeExecutor(store.blockManagerId.executorId)
assert(master.getLocations("a1").size == 0, "a1 was not removed from master")
store.putSingle("a2", a1, StorageLevel.MEMORY_ONLY)
+ store.waitForAsyncReregister()
assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master")
assert(master.getLocations("a2").size > 0, "master was not told about a2")
}
+ test("reregistration doesn't dead lock") {
+ val heartBeat = PrivateMethod[Unit]('heartBeat)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
+ val a1 = new Array[Byte](400)
+ val a2 = List(new Array[Byte](400))
+
+ // try many times to trigger any deadlocks
+ for (i <- 1 to 100) {
+ master.removeExecutor(store.blockManagerId.executorId)
+ val t1 = new Thread {
+ override def run() {
+ store.put("a2", a2.iterator, StorageLevel.MEMORY_ONLY, true)
+ }
+ }
+ val t2 = new Thread {
+ override def run() {
+ store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
+ }
+ }
+ val t3 = new Thread {
+ override def run() {
+ store invokePrivate heartBeat()
+ }
+ }
+
+ t1.start()
+ t2.start()
+ t3.start()
+ t1.join()
+ t2.join()
+ t3.join()
+
+ store.dropFromMemory("a1", null)
+ store.dropFromMemory("a2", null)
+ store.waitForAsyncReregister()
+ }
+ }
+
test("in-memory LRU storage") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -243,7 +289,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU storage with serialization") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -262,14 +308,14 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU for partitions of same RDD") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
store.putSingle("rdd_0_1", a1, StorageLevel.MEMORY_ONLY)
store.putSingle("rdd_0_2", a2, StorageLevel.MEMORY_ONLY)
store.putSingle("rdd_0_3", a3, StorageLevel.MEMORY_ONLY)
- // Even though we accessed rdd_0_3 last, it should not have replaced partitiosn 1 and 2
+ // Even though we accessed rdd_0_3 last, it should not have replaced partitions 1 and 2
// from the same RDD
assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store")
assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store")
@@ -281,7 +327,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU for partitions of multiple RDDs") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
store.putSingle("rdd_0_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
store.putSingle("rdd_0_2", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
store.putSingle("rdd_1_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
@@ -304,7 +350,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("on-disk storage") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -317,7 +363,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -332,7 +378,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage with getLocalBytes") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -347,7 +393,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage with serialization") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -362,7 +408,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage with serialization and getLocalBytes") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -377,7 +423,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("LRU with mixed storage levels") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -402,7 +448,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU with streams") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val list1 = List(new Array[Byte](200), new Array[Byte](200))
val list2 = List(new Array[Byte](200), new Array[Byte](200))
val list3 = List(new Array[Byte](200), new Array[Byte](200))
@@ -426,7 +472,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("LRU with mixed storage levels and streams") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val list1 = List(new Array[Byte](200), new Array[Byte](200))
val list2 = List(new Array[Byte](200), new Array[Byte](200))
val list3 = List(new Array[Byte](200), new Array[Byte](200))
@@ -472,7 +518,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("overly large block") {
- store = new BlockManager(actorSystem, master, serializer, 500)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 500)
store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY)
assert(store.getSingle("a1") === None, "a1 was in store")
store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK)
@@ -483,49 +529,49 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
test("block compression") {
try {
System.setProperty("spark.shuffle.compress", "true")
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec1", actorSystem, master, serializer, 2000)
store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("shuffle_0_0_0") <= 100, "shuffle_0_0_0 was not compressed")
store.stop()
store = null
System.setProperty("spark.shuffle.compress", "false")
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec2", actorSystem, master, serializer, 2000)
store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("shuffle_0_0_0") >= 1000, "shuffle_0_0_0 was compressed")
store.stop()
store = null
System.setProperty("spark.broadcast.compress", "true")
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec3", actorSystem, master, serializer, 2000)
store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("broadcast_0") <= 100, "broadcast_0 was not compressed")
store.stop()
store = null
System.setProperty("spark.broadcast.compress", "false")
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec4", actorSystem, master, serializer, 2000)
store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("broadcast_0") >= 1000, "broadcast_0 was compressed")
store.stop()
store = null
System.setProperty("spark.rdd.compress", "true")
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec5", actorSystem, master, serializer, 2000)
store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("rdd_0_0") <= 100, "rdd_0_0 was not compressed")
store.stop()
store = null
System.setProperty("spark.rdd.compress", "false")
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec6", actorSystem, master, serializer, 2000)
store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("rdd_0_0") >= 1000, "rdd_0_0 was compressed")
store.stop()
store = null
// Check that any other block types are also kept uncompressed
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec7", actorSystem, master, serializer, 2000)
store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY)
assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed")
store.stop()