aboutsummaryrefslogtreecommitdiff
path: root/core/src/test/scala
diff options
context:
space:
mode:
authorPrashant Sharma <prashant.s@imaginea.com>2013-04-22 14:13:56 +0530
committerPrashant Sharma <prashant.s@imaginea.com>2013-04-22 14:14:03 +0530
commit185bb9525a3a48313cd5e446e1b80d2d697465d8 (patch)
treebdbab31c488d6140973ef99d93cc8ff959cad50f /core/src/test/scala
parent4b57f83209b94f87890ef307af45fa493e7fdba8 (diff)
parent17e076de800ea0d4c55f2bd657348641f6f9c55b (diff)
downloadspark-185bb9525a3a48313cd5e446e1b80d2d697465d8.tar.gz
spark-185bb9525a3a48313cd5e446e1b80d2d697465d8.tar.bz2
spark-185bb9525a3a48313cd5e446e1b80d2d697465d8.zip
Manually merged scala-2.10 and master
Diffstat (limited to 'core/src/test/scala')
-rw-r--r--core/src/test/scala/spark/AccumulatorSuite.scala37
-rw-r--r--core/src/test/scala/spark/BroadcastSuite.scala14
-rw-r--r--core/src/test/scala/spark/CacheTrackerSuite.scala130
-rw-r--r--core/src/test/scala/spark/CheckpointSuite.scala365
-rw-r--r--core/src/test/scala/spark/ClosureCleanerSuite.scala71
-rw-r--r--core/src/test/scala/spark/DistributedSuite.scala129
-rw-r--r--core/src/test/scala/spark/DriverSuite.scala33
-rw-r--r--core/src/test/scala/spark/FailureSuite.scala14
-rw-r--r--core/src/test/scala/spark/FileServerSuite.scala29
-rw-r--r--core/src/test/scala/spark/FileSuite.scala16
-rw-r--r--core/src/test/scala/spark/JavaAPISuite.java109
-rw-r--r--core/src/test/scala/spark/KryoSerializerSuite.scala1
-rw-r--r--core/src/test/scala/spark/LocalSparkContext.scala41
-rw-r--r--core/src/test/scala/spark/MapOutputTrackerSuite.scala55
-rw-r--r--core/src/test/scala/spark/PartitioningSuite.scala23
-rw-r--r--core/src/test/scala/spark/PipedRDDSuite.scala16
-rw-r--r--core/src/test/scala/spark/RDDSuite.scala186
-rw-r--r--core/src/test/scala/spark/ShuffleSuite.scala110
-rw-r--r--core/src/test/scala/spark/SortingSuite.scala23
-rw-r--r--core/src/test/scala/spark/ThreadingSuite.scala14
-rw-r--r--core/src/test/scala/spark/rdd/ParallelCollectionSplitSuite.scala (renamed from core/src/test/scala/spark/ParallelCollectionSplitSuite.scala)40
-rw-r--r--core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala403
-rw-r--r--core/src/test/scala/spark/scheduler/SparkListenerSuite.scala86
-rw-r--r--core/src/test/scala/spark/scheduler/TaskContextSuite.scala27
-rw-r--r--core/src/test/scala/spark/storage/BlockManagerSuite.scala153
-rw-r--r--core/src/test/scala/spark/util/DistributionSuite.scala25
-rw-r--r--core/src/test/scala/spark/util/NextIteratorSuite.scala68
-rw-r--r--core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala23
28 files changed, 1749 insertions, 492 deletions
diff --git a/core/src/test/scala/spark/AccumulatorSuite.scala b/core/src/test/scala/spark/AccumulatorSuite.scala
index 9f5335978f..f59334a033 100644
--- a/core/src/test/scala/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/spark/AccumulatorSuite.scala
@@ -1,6 +1,5 @@
package spark
-import org.scalatest.BeforeAndAfter
import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
import collection.mutable
@@ -9,9 +8,8 @@ import scala.math.exp
import scala.math.signum
import spark.SparkContext._
-class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
+class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter with LocalSparkContext {
- var sc: SparkContext = null
implicit def setAccum[A] = new AccumulableParam[mutable.Set[A], A] {
def addInPlace(t1: mutable.Set[A], t2: mutable.Set[A]) : mutable.Set[A] = {
@@ -27,15 +25,6 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
}
}
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
-
test ("basic accumulation"){
sc = new SparkContext("local", "test")
val acc : Accumulator[Int] = sc.accumulator(0)
@@ -43,6 +32,12 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
val d = sc.parallelize(1 to 20)
d.foreach{x => acc += x}
acc.value should be (210)
+
+
+ val longAcc = sc.accumulator(0l)
+ val maxInt = Integer.MAX_VALUE.toLong
+ d.foreach{x => longAcc += maxInt + x}
+ longAcc.value should be (210l + maxInt * 20)
}
test ("value not assignable from tasks") {
@@ -66,10 +61,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
for (i <- 1 to maxI) {
v should contain(i)
}
- sc.stop()
- sc = null
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
+ resetSparkContext()
}
}
@@ -84,10 +76,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
x => acc.value += x
}
} should produce [SparkException]
- sc.stop()
- sc = null
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
+ resetSparkContext()
}
}
@@ -113,10 +102,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
bufferAcc.value should contain(i)
mapAcc.value should contain (i -> i.toString)
}
- sc.stop()
- sc = null
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
+ resetSparkContext()
}
}
@@ -131,8 +117,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
x => acc.localValue ++= x
}
acc.value should be ( (0 to maxI).toSet)
- sc.stop()
- sc = null
+ resetSparkContext()
}
}
diff --git a/core/src/test/scala/spark/BroadcastSuite.scala b/core/src/test/scala/spark/BroadcastSuite.scala
index 2d3302f0aa..362a31fb0d 100644
--- a/core/src/test/scala/spark/BroadcastSuite.scala
+++ b/core/src/test/scala/spark/BroadcastSuite.scala
@@ -1,20 +1,8 @@
package spark
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
-class BroadcastSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class BroadcastSuite extends FunSuite with LocalSparkContext {
test("basic broadcast") {
sc = new SparkContext("local", "test")
diff --git a/core/src/test/scala/spark/CacheTrackerSuite.scala b/core/src/test/scala/spark/CacheTrackerSuite.scala
deleted file mode 100644
index d0c2bd47fc..0000000000
--- a/core/src/test/scala/spark/CacheTrackerSuite.scala
+++ /dev/null
@@ -1,130 +0,0 @@
-package spark
-
-import org.scalatest.FunSuite
-
-import scala.collection.mutable.HashMap
-
-import akka.actor._
-import scala.concurrent.{Await, Future}
-import akka.remote._
-import scala.concurrent.duration.Duration
-import akka.util.Timeout
-import scala.concurrent.duration._
-
-class CacheTrackerSuite extends FunSuite {
- // Send a message to an actor and wait for a reply, in a blocking manner
- private def ask(actor: ActorRef, message: Any): Any = {
- try {
- val timeout = 10.seconds
- val future: Future[Any] = akka.pattern.ask(actor, message)(timeout)
- Await.result(future, timeout)
- } catch {
- case e: Exception =>
- throw new SparkException("Error communicating with actor", e)
- }
- }
-
- test("CacheTrackerActor slave initialization & cache status") {
- //System.setProperty("spark.master.port", "1345")
- val initialSize = 2L << 20
-
- val actorSystem = ActorSystem("test")
- val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
-
- assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
-
- assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 0L)))
-
- assert(ask(tracker, StopCacheTracker) === true)
-
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- }
-
- test("RegisterRDD") {
- //System.setProperty("spark.master.port", "1345")
- val initialSize = 2L << 20
-
- val actorSystem = ActorSystem("test")
- val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
-
- assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
-
- assert(ask(tracker, RegisterRDD(1, 3)) === true)
- assert(ask(tracker, RegisterRDD(2, 1)) === true)
-
- assert(getCacheLocations(tracker) === Map(1 -> List(Nil, Nil, Nil), 2 -> List(Nil)))
-
- assert(ask(tracker, StopCacheTracker) === true)
-
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- }
-
- test("AddedToCache") {
- //System.setProperty("spark.master.port", "1345")
- val initialSize = 2L << 20
-
- val actorSystem = ActorSystem("test")
- val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
-
- assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
-
- assert(ask(tracker, RegisterRDD(1, 2)) === true)
- assert(ask(tracker, RegisterRDD(2, 1)) === true)
-
- assert(ask(tracker, AddedToCache(1, 0, "host001", 2L << 15)) === true)
- assert(ask(tracker, AddedToCache(1, 1, "host001", 2L << 11)) === true)
- assert(ask(tracker, AddedToCache(2, 0, "host001", 3L << 10)) === true)
-
- assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 72704L)))
-
- assert(getCacheLocations(tracker) ===
- Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
-
- assert(ask(tracker, StopCacheTracker) === true)
-
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- }
-
- test("DroppedFromCache") {
- //System.setProperty("spark.master.port", "1345")
- val initialSize = 2L << 20
-
- val actorSystem = ActorSystem("test")
- val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
-
- assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
-
- assert(ask(tracker, RegisterRDD(1, 2)) === true)
- assert(ask(tracker, RegisterRDD(2, 1)) === true)
-
- assert(ask(tracker, AddedToCache(1, 0, "host001", 2L << 15)) === true)
- assert(ask(tracker, AddedToCache(1, 1, "host001", 2L << 11)) === true)
- assert(ask(tracker, AddedToCache(2, 0, "host001", 3L << 10)) === true)
-
- assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 72704L)))
- assert(getCacheLocations(tracker) ===
- Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
-
- assert(ask(tracker, DroppedFromCache(1, 1, "host001", 2L << 11)) === true)
-
- assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 68608L)))
- assert(getCacheLocations(tracker) ===
- Map(1 -> List(List("host001"),List()), 2 -> List(List("host001"))))
-
- assert(ask(tracker, StopCacheTracker) === true)
-
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- }
-
- /**
- * Helper function to get cacheLocations from CacheTracker
- */
- def getCacheLocations(tracker: ActorRef): HashMap[Int, List[List[String]]] = {
- val answer = ask(tracker, GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]]
- answer.map { case (i, arr) => (i, arr.toList) }
- }
-}
diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala
new file mode 100644
index 0000000000..ca385972fb
--- /dev/null
+++ b/core/src/test/scala/spark/CheckpointSuite.scala
@@ -0,0 +1,365 @@
+package spark
+
+import org.scalatest.FunSuite
+import java.io.File
+import spark.rdd._
+import spark.SparkContext._
+import storage.StorageLevel
+
+class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
+ initLogging()
+
+ var checkpointDir: File = _
+ val partitioner = new HashPartitioner(2)
+
+ override def beforeEach() {
+ super.beforeEach()
+ checkpointDir = File.createTempFile("temp", "")
+ checkpointDir.delete()
+ sc = new SparkContext("local", "test")
+ sc.setCheckpointDir(checkpointDir.toString)
+ }
+
+ override def afterEach() {
+ super.afterEach()
+ if (checkpointDir != null) {
+ checkpointDir.delete()
+ }
+ }
+
+ test("RDDs with one-to-one dependencies") {
+ testCheckpointing(_.map(x => x.toString))
+ testCheckpointing(_.flatMap(x => 1 to x))
+ testCheckpointing(_.filter(_ % 2 == 0))
+ testCheckpointing(_.sample(false, 0.5, 0))
+ testCheckpointing(_.glom())
+ testCheckpointing(_.mapPartitions(_.map(_.toString)))
+ testCheckpointing(r => new MapPartitionsWithIndexRDD(r,
+ (i: Int, iter: Iterator[Int]) => iter.map(_.toString), false ))
+ testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
+ testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
+ testCheckpointing(_.pipe(Seq("cat")))
+ }
+
+ test("ParallelCollection") {
+ val parCollection = sc.makeRDD(1 to 4, 2)
+ val numPartitions = parCollection.partitions.size
+ parCollection.checkpoint()
+ assert(parCollection.dependencies === Nil)
+ val result = parCollection.collect()
+ assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result)
+ assert(parCollection.dependencies != Nil)
+ assert(parCollection.partitions.length === numPartitions)
+ assert(parCollection.partitions.toList === parCollection.checkpointData.get.getPartitions.toList)
+ assert(parCollection.collect() === result)
+ }
+
+ test("BlockRDD") {
+ val blockId = "id"
+ val blockManager = SparkEnv.get.blockManager
+ blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY)
+ val blockRDD = new BlockRDD[String](sc, Array(blockId))
+ val numPartitions = blockRDD.partitions.size
+ blockRDD.checkpoint()
+ val result = blockRDD.collect()
+ assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result)
+ assert(blockRDD.dependencies != Nil)
+ assert(blockRDD.partitions.length === numPartitions)
+ assert(blockRDD.partitions.toList === blockRDD.checkpointData.get.getPartitions.toList)
+ assert(blockRDD.collect() === result)
+ }
+
+ test("ShuffledRDD") {
+ testCheckpointing(rdd => {
+ // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD
+ new ShuffledRDD(rdd.map(x => (x % 2, 1)), partitioner)
+ })
+ }
+
+ test("UnionRDD") {
+ def otherRDD = sc.makeRDD(1 to 10, 1)
+
+ // Test whether the size of UnionRDDPartitions reduce in size after parent RDD is checkpointed.
+ // Current implementation of UnionRDD has transient reference to parent RDDs,
+ // so only the partitions will reduce in serialized size, not the RDD.
+ testCheckpointing(_.union(otherRDD), false, true)
+ testParentCheckpointing(_.union(otherRDD), false, true)
+ }
+
+ test("CartesianRDD") {
+ def otherRDD = sc.makeRDD(1 to 10, 1)
+ testCheckpointing(new CartesianRDD(sc, _, otherRDD))
+
+ // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed
+ // Current implementation of CoalescedRDDPartition has transient reference to parent RDD,
+ // so only the RDD will reduce in serialized size, not the partitions.
+ testParentCheckpointing(new CartesianRDD(sc, _, otherRDD), true, false)
+
+ // Test that the CartesianRDD updates parent partitions (CartesianRDD.s1/s2) after
+ // the parent RDD has been checkpointed and parent partitions have been changed to HadoopPartitions.
+ // Note that this test is very specific to the current implementation of CartesianRDD.
+ val ones = sc.makeRDD(1 to 100, 10).map(x => x)
+ ones.checkpoint() // checkpoint that MappedRDD
+ val cartesian = new CartesianRDD(sc, ones, ones)
+ val splitBeforeCheckpoint =
+ serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition])
+ cartesian.count() // do the checkpointing
+ val splitAfterCheckpoint =
+ serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition])
+ assert(
+ (splitAfterCheckpoint.s1 != splitBeforeCheckpoint.s1) &&
+ (splitAfterCheckpoint.s2 != splitBeforeCheckpoint.s2),
+ "CartesianRDD.parents not updated after parent RDD checkpointed"
+ )
+ }
+
+ test("CoalescedRDD") {
+ testCheckpointing(_.coalesce(2))
+
+ // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed
+ // Current implementation of CoalescedRDDPartition has transient reference to parent RDD,
+ // so only the RDD will reduce in serialized size, not the partitions.
+ testParentCheckpointing(_.coalesce(2), true, false)
+
+ // Test that the CoalescedRDDPartition updates parent partitions (CoalescedRDDPartition.parents) after
+ // the parent RDD has been checkpointed and parent partitions have been changed to HadoopPartitions.
+ // Note that this test is very specific to the current implementation of CoalescedRDDPartitions
+ val ones = sc.makeRDD(1 to 100, 10).map(x => x)
+ ones.checkpoint() // checkpoint that MappedRDD
+ val coalesced = new CoalescedRDD(ones, 2)
+ val splitBeforeCheckpoint =
+ serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition])
+ coalesced.count() // do the checkpointing
+ val splitAfterCheckpoint =
+ serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition])
+ assert(
+ splitAfterCheckpoint.parents.head != splitBeforeCheckpoint.parents.head,
+ "CoalescedRDDPartition.parents not updated after parent RDD checkpointed"
+ )
+ }
+
+ test("CoGroupedRDD") {
+ val longLineageRDD1 = generateLongLineageRDDForCoGroupedRDD()
+ testCheckpointing(rdd => {
+ CheckpointSuite.cogroup(longLineageRDD1, rdd.map(x => (x % 2, 1)), partitioner)
+ }, false, true)
+
+ val longLineageRDD2 = generateLongLineageRDDForCoGroupedRDD()
+ testParentCheckpointing(rdd => {
+ CheckpointSuite.cogroup(
+ longLineageRDD2, sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)), partitioner)
+ }, false, true)
+ }
+
+ test("ZippedRDD") {
+ testCheckpointing(
+ rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false)
+
+ // Test whether size of ZippedRDD reduce in size after parent RDD is checkpointed
+ // Current implementation of ZippedRDDPartitions has transient references to parent RDDs,
+ // so only the RDD will reduce in serialized size, not the partitions.
+ testParentCheckpointing(
+ rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false)
+ }
+
+ test("CheckpointRDD with zero partitions") {
+ val rdd = new BlockRDD[Int](sc, Array[String]())
+ assert(rdd.partitions.size === 0)
+ assert(rdd.isCheckpointed === false)
+ rdd.checkpoint()
+ assert(rdd.count() === 0)
+ assert(rdd.isCheckpointed === true)
+ assert(rdd.partitions.size === 0)
+ }
+
+ /**
+ * Test checkpointing of the final RDD generated by the given operation. By default,
+ * this method tests whether the size of serialized RDD has reduced after checkpointing or not.
+ * It can also test whether the size of serialized RDD partitions has reduced after checkpointing or
+ * not, but this is not done by default as usually the partitions do not refer to any RDD and
+ * therefore never store the lineage.
+ */
+ def testCheckpointing[U: ClassManifest](
+ op: (RDD[Int]) => RDD[U],
+ testRDDSize: Boolean = true,
+ testRDDPartitionSize: Boolean = false
+ ) {
+ // Generate the final RDD using given RDD operation
+ val baseRDD = generateLongLineageRDD()
+ val operatedRDD = op(baseRDD)
+ val parentRDD = operatedRDD.dependencies.headOption.orNull
+ val rddType = operatedRDD.getClass.getSimpleName
+ val numPartitions = operatedRDD.partitions.length
+
+ // Find serialized sizes before and after the checkpoint
+ val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
+ operatedRDD.checkpoint()
+ val result = operatedRDD.collect()
+ val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
+
+ // Test whether the checkpoint file has been created
+ assert(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get).collect() === result)
+
+ // Test whether dependencies have been changed from its earlier parent RDD
+ assert(operatedRDD.dependencies.head.rdd != parentRDD)
+
+ // Test whether the partitions have been changed to the new Hadoop partitions
+ assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList)
+
+ // Test whether the number of partitions is same as before
+ assert(operatedRDD.partitions.length === numPartitions)
+
+ // Test whether the data in the checkpointed RDD is same as original
+ assert(operatedRDD.collect() === result)
+
+ // Test whether serialized size of the RDD has reduced. If the RDD
+ // does not have any dependency to another RDD (e.g., ParallelCollection,
+ // ShuffleRDD with ShuffleDependency), it may not reduce in size after checkpointing.
+ if (testRDDSize) {
+ logInfo("Size of " + rddType +
+ "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]")
+ assert(
+ rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
+ "Size of " + rddType + " did not reduce after checkpointing " +
+ "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
+ )
+ }
+
+ // Test whether serialized size of the partitions has reduced. If the partitions
+ // do not have any non-transient reference to another RDD or another RDD's partitions, it
+ // does not refer to a lineage and therefore may not reduce in size after checkpointing.
+ // However, if the original partitions before checkpointing do refer to a parent RDD, the partitions
+ // must be forgotten after checkpointing (to remove all reference to parent RDDs) and
+ // replaced with the HadooPartitions of the checkpointed RDD.
+ if (testRDDPartitionSize) {
+ logInfo("Size of " + rddType + " partitions "
+ + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]")
+ assert(
+ splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint,
+ "Size of " + rddType + " partitions did not reduce after checkpointing " +
+ "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]"
+ )
+ }
+ }
+
+ /**
+ * Test whether checkpointing of the parent of the generated RDD also
+ * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent
+ * RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed,
+ * this RDD will remember the partitions and therefore potentially the whole lineage.
+ */
+ def testParentCheckpointing[U: ClassManifest](
+ op: (RDD[Int]) => RDD[U],
+ testRDDSize: Boolean,
+ testRDDPartitionSize: Boolean
+ ) {
+ // Generate the final RDD using given RDD operation
+ val baseRDD = generateLongLineageRDD()
+ val operatedRDD = op(baseRDD)
+ val parentRDD = operatedRDD.dependencies.head.rdd
+ val rddType = operatedRDD.getClass.getSimpleName
+ val parentRDDType = parentRDD.getClass.getSimpleName
+
+ // Get the partitions and dependencies of the parent in case they're lazily computed
+ parentRDD.dependencies
+ parentRDD.partitions
+
+ // Find serialized sizes before and after the checkpoint
+ val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
+ parentRDD.checkpoint() // checkpoint the parent RDD, not the generated one
+ val result = operatedRDD.collect()
+ val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
+
+ // Test whether the data in the checkpointed RDD is same as original
+ assert(operatedRDD.collect() === result)
+
+ // Test whether serialized size of the RDD has reduced because of its parent being
+ // checkpointed. If this RDD or its parent RDD do not have any dependency
+ // to another RDD (e.g., ParallelCollection, ShuffleRDD with ShuffleDependency), it may
+ // not reduce in size after checkpointing.
+ if (testRDDSize) {
+ assert(
+ rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
+ "Size of " + rddType + " did not reduce after checkpointing parent " + parentRDDType +
+ "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
+ )
+ }
+
+ // Test whether serialized size of the partitions has reduced because of its parent being
+ // checkpointed. If the partitions do not have any non-transient reference to another RDD
+ // or another RDD's partitions, it does not refer to a lineage and therefore may not reduce
+ // in size after checkpointing. However, if the partitions do refer to the *partitions* of a parent
+ // RDD, then these partitions must update reference to the parent RDD partitions as the parent RDD's
+ // partitions must have changed after checkpointing.
+ if (testRDDPartitionSize) {
+ assert(
+ splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint,
+ "Size of " + rddType + " partitions did not reduce after checkpointing parent " + parentRDDType +
+ "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]"
+ )
+ }
+
+ }
+
+ /**
+ * Generate an RDD with a long lineage of one-to-one dependencies.
+ */
+ def generateLongLineageRDD(): RDD[Int] = {
+ var rdd = sc.makeRDD(1 to 100, 4)
+ for (i <- 1 to 50) {
+ rdd = rdd.map(x => x + 1)
+ }
+ rdd
+ }
+
+ /**
+ * Generate an RDD with a long lineage specifically for CoGroupedRDD.
+ * A CoGroupedRDD can have a long lineage only one of its parents have a long lineage
+ * and narrow dependency with this RDD. This method generate such an RDD by a sequence
+ * of cogroups and mapValues which creates a long lineage of narrow dependencies.
+ */
+ def generateLongLineageRDDForCoGroupedRDD() = {
+ val add = (x: (Seq[Int], Seq[Int])) => (x._1 ++ x._2).reduce(_ + _)
+
+ def ones: RDD[(Int, Int)] = sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _)
+
+ var cogrouped: RDD[(Int, (Seq[Int], Seq[Int]))] = ones.cogroup(ones)
+ for(i <- 1 to 10) {
+ cogrouped = cogrouped.mapValues(add).cogroup(ones)
+ }
+ cogrouped.mapValues(add)
+ }
+
+ /**
+ * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks
+ * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint.
+ */
+ def getSerializedSizes(rdd: RDD[_]): (Int, Int) = {
+ (Utils.serialize(rdd).length - Utils.serialize(rdd.checkpointData).length,
+ Utils.serialize(rdd.partitions).length)
+ }
+
+ /**
+ * Serialize and deserialize an object. This is useful to verify the objects
+ * contents after deserialization (e.g., the contents of an RDD split after
+ * it is sent to a slave along with a task)
+ */
+ def serializeDeserialize[T](obj: T): T = {
+ val bytes = Utils.serialize(obj)
+ Utils.deserialize[T](bytes)
+ }
+}
+
+
+object CheckpointSuite {
+ // This is a custom cogroup function that does not use mapValues like
+ // the PairRDDFunctions.cogroup()
+ def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = {
+ //println("First = " + first + ", second = " + second)
+ new CoGroupedRDD[K](
+ Seq(first.asInstanceOf[RDD[(K, _)]], second.asInstanceOf[RDD[(K, _)]]),
+ part
+ ).asInstanceOf[RDD[(K, Seq[Seq[V]])]]
+ }
+
+}
diff --git a/core/src/test/scala/spark/ClosureCleanerSuite.scala b/core/src/test/scala/spark/ClosureCleanerSuite.scala
index 7c0334d957..b2d0dd4627 100644
--- a/core/src/test/scala/spark/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/spark/ClosureCleanerSuite.scala
@@ -3,6 +3,7 @@ package spark
import java.io.NotSerializableException
import org.scalatest.FunSuite
+import spark.LocalSparkContext._
import SparkContext._
class ClosureCleanerSuite extends FunSuite {
@@ -43,11 +44,10 @@ object TestObject {
def run(): Int = {
var nonSer = new NonSerializable
var x = 5
- val sc = new SparkContext("local", "test")
- val nums = sc.parallelize(Array(1, 2, 3, 4))
- val answer = nums.map(_ + x).reduce(_ + _)
- sc.stop()
- return answer
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ nums.map(_ + x).reduce(_ + _)
+ }
}
}
@@ -58,11 +58,10 @@ class TestClass extends Serializable {
def run(): Int = {
var nonSer = new NonSerializable
- val sc = new SparkContext("local", "test")
- val nums = sc.parallelize(Array(1, 2, 3, 4))
- val answer = nums.map(_ + getX).reduce(_ + _)
- sc.stop()
- return answer
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ nums.map(_ + getX).reduce(_ + _)
+ }
}
}
@@ -71,11 +70,10 @@ class TestClassWithoutDefaultConstructor(x: Int) extends Serializable {
def run(): Int = {
var nonSer = new NonSerializable
- val sc = new SparkContext("local", "test")
- val nums = sc.parallelize(Array(1, 2, 3, 4))
- val answer = nums.map(_ + getX).reduce(_ + _)
- sc.stop()
- return answer
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ nums.map(_ + getX).reduce(_ + _)
+ }
}
}
@@ -87,11 +85,10 @@ class TestClassWithoutFieldAccess {
def run(): Int = {
var nonSer2 = new NonSerializable
var x = 5
- val sc = new SparkContext("local", "test")
- val nums = sc.parallelize(Array(1, 2, 3, 4))
- val answer = nums.map(_ + x).reduce(_ + _)
- sc.stop()
- return answer
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ nums.map(_ + x).reduce(_ + _)
+ }
}
}
@@ -100,16 +97,16 @@ object TestObjectWithNesting {
def run(): Int = {
var nonSer = new NonSerializable
var answer = 0
- val sc = new SparkContext("local", "test")
- val nums = sc.parallelize(Array(1, 2, 3, 4))
- var y = 1
- for (i <- 1 to 4) {
- var nonSer2 = new NonSerializable
- var x = i
- answer += nums.map(_ + x + y).reduce(_ + _)
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ var y = 1
+ for (i <- 1 to 4) {
+ var nonSer2 = new NonSerializable
+ var x = i
+ answer += nums.map(_ + x + y).reduce(_ + _)
+ }
+ answer
}
- sc.stop()
- return answer
}
}
@@ -119,14 +116,14 @@ class TestClassWithNesting(val y: Int) extends Serializable {
def run(): Int = {
var nonSer = new NonSerializable
var answer = 0
- val sc = new SparkContext("local", "test")
- val nums = sc.parallelize(Array(1, 2, 3, 4))
- for (i <- 1 to 4) {
- var nonSer2 = new NonSerializable
- var x = i
- answer += nums.map(_ + x + getY).reduce(_ + _)
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ for (i <- 1 to 4) {
+ var nonSer2 = new NonSerializable
+ var x = i
+ answer += nums.map(_ + x + getY).reduce(_ + _)
+ }
+ answer
}
- sc.stop()
- return answer
}
}
diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala
index cacc2796b6..4104b33c8b 100644
--- a/core/src/test/scala/spark/DistributedSuite.scala
+++ b/core/src/test/scala/spark/DistributedSuite.scala
@@ -1,5 +1,6 @@
package spark
+import network.ConnectionManagerId
import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import org.scalatest.matchers.ShouldMatchers
@@ -13,43 +14,30 @@ import com.google.common.io.Files
import scala.collection.mutable.ArrayBuffer
import SparkContext._
-import storage.StorageLevel
+import storage.{GetBlock, BlockManagerWorker, StorageLevel}
-class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
+class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter with LocalSparkContext {
val clusterUrl = "local-cluster[2,1,512]"
- @transient var sc: SparkContext = _
-
after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
System.clearProperty("spark.reducer.maxMbInFlight")
System.clearProperty("spark.storage.memoryFraction")
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
}
test("local-cluster format") {
sc = new SparkContext("local-cluster[2,1,512]", "test")
assert(sc.parallelize(1 to 2, 2).count() == 2)
- sc.stop()
- System.clearProperty("spark.master.port")
+ resetSparkContext()
sc = new SparkContext("local-cluster[2 , 1 , 512]", "test")
assert(sc.parallelize(1 to 2, 2).count() == 2)
- sc.stop()
- System.clearProperty("spark.master.port")
+ resetSparkContext()
sc = new SparkContext("local-cluster[2, 1, 512]", "test")
assert(sc.parallelize(1 to 2, 2).count() == 2)
- sc.stop()
- System.clearProperty("spark.master.port")
+ resetSparkContext()
sc = new SparkContext("local-cluster[ 2, 1, 512 ]", "test")
assert(sc.parallelize(1 to 2, 2).count() == 2)
- sc.stop()
- System.clearProperty("spark.master.port")
- sc = null
+ resetSparkContext()
}
test("simple groupByKey") {
@@ -153,9 +141,22 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
test("caching in memory and disk, serialized, replicated") {
sc = new SparkContext(clusterUrl, "test")
val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_AND_DISK_SER_2)
+
assert(data.count() === 1000)
assert(data.count() === 1000)
assert(data.count() === 1000)
+
+ // Get all the locations of the first partition and try to fetch the partitions
+ // from those locations.
+ val blockIds = data.partitions.indices.map(index => "rdd_%d_%d".format(data.id, index)).toArray
+ val blockId = blockIds(0)
+ val blockManager = SparkEnv.get.blockManager
+ blockManager.master.getLocations(blockId).foreach(id => {
+ val bytes = BlockManagerWorker.syncGetBlock(
+ GetBlock(blockId), ConnectionManagerId(id.ip, id.port))
+ val deserialized = blockManager.dataDeserialize(blockId, bytes).asInstanceOf[Iterator[Int]].toList
+ assert(deserialized === (1 to 100).toList)
+ })
}
test("compute without caching when no partitions fit in memory") {
@@ -188,4 +189,94 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
val values = sc.parallelize(1 to 2, 2).map(x => System.getenv("TEST_VAR")).collect()
assert(values.toSeq === Seq("TEST_VALUE", "TEST_VALUE"))
}
+
+ test("recover from node failures") {
+ import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity}
+ DistributedSuite.amMaster = true
+ sc = new SparkContext(clusterUrl, "test")
+ val data = sc.parallelize(Seq(true, true), 2)
+ assert(data.count === 2) // force executors to start
+ val masterId = SparkEnv.get.blockManager.blockManagerId
+ assert(data.map(markNodeIfIdentity).collect.size === 2)
+ assert(data.map(failOnMarkedIdentity).collect.size === 2)
+ }
+
+ test("recover from repeated node failures during shuffle-map") {
+ import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity}
+ DistributedSuite.amMaster = true
+ sc = new SparkContext(clusterUrl, "test")
+ for (i <- 1 to 3) {
+ val data = sc.parallelize(Seq(true, false), 2)
+ assert(data.count === 2)
+ assert(data.map(markNodeIfIdentity).collect.size === 2)
+ assert(data.map(failOnMarkedIdentity).map(x => x -> x).groupByKey.count === 2)
+ }
+ }
+
+ test("recover from repeated node failures during shuffle-reduce") {
+ import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity}
+ DistributedSuite.amMaster = true
+ sc = new SparkContext(clusterUrl, "test")
+ for (i <- 1 to 3) {
+ val data = sc.parallelize(Seq(true, true), 2)
+ assert(data.count === 2)
+ assert(data.map(markNodeIfIdentity).collect.size === 2)
+ // This relies on mergeCombiners being used to perform the actual reduce for this
+ // test to actually be testing what it claims.
+ val grouped = data.map(x => x -> x).combineByKey(
+ x => x,
+ (x: Boolean, y: Boolean) => x,
+ (x: Boolean, y: Boolean) => failOnMarkedIdentity(x)
+ )
+ assert(grouped.collect.size === 1)
+ }
+ }
+
+ test("recover from node failures with replication") {
+ import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity}
+ DistributedSuite.amMaster = true
+ // Using more than two nodes so we don't have a symmetric communication pattern and might
+ // cache a partially correct list of peers.
+ sc = new SparkContext("local-cluster[3,1,512]", "test")
+ for (i <- 1 to 3) {
+ val data = sc.parallelize(Seq(true, false, false, false), 4)
+ data.persist(StorageLevel.MEMORY_ONLY_2)
+
+ assert(data.count === 4)
+ assert(data.map(markNodeIfIdentity).collect.size === 4)
+ assert(data.map(failOnMarkedIdentity).collect.size === 4)
+
+ // Create a new replicated RDD to make sure that cached peer information doesn't cause
+ // problems.
+ val data2 = sc.parallelize(Seq(true, true), 2).persist(StorageLevel.MEMORY_ONLY_2)
+ assert(data2.count === 2)
+ }
+ }
+}
+
+object DistributedSuite {
+ // Indicates whether this JVM is marked for failure.
+ var mark = false
+
+ // Set by test to remember if we are in the driver program so we can assert
+ // that we are not.
+ var amMaster = false
+
+ // Act like an identity function, but if the argument is true, set mark to true.
+ def markNodeIfIdentity(item: Boolean): Boolean = {
+ if (item) {
+ assert(!amMaster)
+ mark = true
+ }
+ item
+ }
+
+ // Act like an identity function, but if mark was set to true previously, fail,
+ // crashing the entire JVM.
+ def failOnMarkedIdentity(item: Boolean): Boolean = {
+ if (mark) {
+ System.exit(42)
+ }
+ item
+ }
}
diff --git a/core/src/test/scala/spark/DriverSuite.scala b/core/src/test/scala/spark/DriverSuite.scala
new file mode 100644
index 0000000000..5e84b3a66a
--- /dev/null
+++ b/core/src/test/scala/spark/DriverSuite.scala
@@ -0,0 +1,33 @@
+package spark
+
+import java.io.File
+
+import org.scalatest.FunSuite
+import org.scalatest.concurrent.Timeouts
+import org.scalatest.prop.TableDrivenPropertyChecks._
+import org.scalatest.time.SpanSugar._
+
+class DriverSuite extends FunSuite with Timeouts {
+ test("driver should exit after finishing") {
+ assert(System.getenv("SPARK_HOME") != null)
+ // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing"
+ val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]"))
+ forAll(masters) { (master: String) =>
+ failAfter(30 seconds) {
+ Utils.execute(Seq("./run", "spark.DriverWithoutCleanup", master),
+ new File(System.getenv("SPARK_HOME")))
+ }
+ }
+ }
+}
+
+/**
+ * Program that creates a Spark driver but doesn't call SparkContext.stop() or
+ * Sys.exit() after finishing.
+ */
+object DriverWithoutCleanup {
+ def main(args: Array[String]) {
+ val sc = new SparkContext(args(0), "DriverWithoutCleanup")
+ sc.parallelize(1 to 100, 4).count()
+ }
+}
diff --git a/core/src/test/scala/spark/FailureSuite.scala b/core/src/test/scala/spark/FailureSuite.scala
index a3454f25f6..8c1445a465 100644
--- a/core/src/test/scala/spark/FailureSuite.scala
+++ b/core/src/test/scala/spark/FailureSuite.scala
@@ -1,7 +1,6 @@
package spark
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
import org.scalatest.prop.Checkers
import scala.collection.mutable.ArrayBuffer
@@ -23,18 +22,7 @@ object FailureSuiteState {
}
}
-class FailureSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class FailureSuite extends FunSuite with LocalSparkContext {
// Run a 3-task map job in which task 1 deterministically fails once, and check
// whether the job completes successfully and we ran 4 tasks in total.
diff --git a/core/src/test/scala/spark/FileServerSuite.scala b/core/src/test/scala/spark/FileServerSuite.scala
index b4283d9604..f1a35bced3 100644
--- a/core/src/test/scala/spark/FileServerSuite.scala
+++ b/core/src/test/scala/spark/FileServerSuite.scala
@@ -2,17 +2,16 @@ package spark
import com.google.common.io.Files
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
import java.io.{File, PrintWriter, FileReader, BufferedReader}
import SparkContext._
-class FileServerSuite extends FunSuite with BeforeAndAfter {
+class FileServerSuite extends FunSuite with LocalSparkContext {
- @transient var sc: SparkContext = _
- @transient var tmpFile : File = _
- @transient var testJarFile : File = _
+ @transient var tmpFile: File = _
+ @transient var testJarFile: File = _
- before {
+ override def beforeEach() {
+ super.beforeEach()
// Create a sample text file
val tmpdir = new File(Files.createTempDir(), "test")
tmpdir.mkdir()
@@ -22,17 +21,12 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
pw.close()
}
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
+ override def afterEach() {
+ super.afterEach()
// Clean up downloaded file
if (tmpFile.exists) {
tmpFile.delete()
}
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
}
test("Distributing files locally") {
@@ -40,7 +34,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
sc.addFile(tmpFile.toString)
val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
val result = sc.parallelize(testData).reduceByKey {
- val in = new BufferedReader(new FileReader("FileServerSuite.txt"))
+ val path = SparkFiles.get("FileServerSuite.txt")
+ val in = new BufferedReader(new FileReader(path))
val fileVal = in.readLine().toInt
in.close()
_ * fileVal + _ * fileVal
@@ -54,7 +49,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
sc.addFile((new File(tmpFile.toString)).toURL.toString)
val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
val result = sc.parallelize(testData).reduceByKey {
- val in = new BufferedReader(new FileReader("FileServerSuite.txt"))
+ val path = SparkFiles.get("FileServerSuite.txt")
+ val in = new BufferedReader(new FileReader(path))
val fileVal = in.readLine().toInt
in.close()
_ * fileVal + _ * fileVal
@@ -83,7 +79,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
sc.addFile(tmpFile.toString)
val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
val result = sc.parallelize(testData).reduceByKey {
- val in = new BufferedReader(new FileReader("FileServerSuite.txt"))
+ val path = SparkFiles.get("FileServerSuite.txt")
+ val in = new BufferedReader(new FileReader(path))
val fileVal = in.readLine().toInt
in.close()
_ * fileVal + _ * fileVal
diff --git a/core/src/test/scala/spark/FileSuite.scala b/core/src/test/scala/spark/FileSuite.scala
index 554bea53a9..91b48c7456 100644
--- a/core/src/test/scala/spark/FileSuite.scala
+++ b/core/src/test/scala/spark/FileSuite.scala
@@ -6,24 +6,12 @@ import scala.io.Source
import com.google.common.io.Files
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
import org.apache.hadoop.io._
import SparkContext._
-class FileSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
-
+class FileSuite extends FunSuite with LocalSparkContext {
+
test("text files") {
sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index c61913fc82..d3dcd3bbeb 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -46,7 +46,7 @@ public class JavaAPISuite implements Serializable {
sc.stop();
sc = null;
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port");
+ System.clearProperty("spark.driver.port");
}
static class ReverseIntComparator implements Comparator<Integer>, Serializable {
@@ -197,6 +197,28 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void foldByKey() {
+ List<Tuple2<Integer, Integer>> pairs = Arrays.asList(
+ new Tuple2<Integer, Integer>(2, 1),
+ new Tuple2<Integer, Integer>(2, 1),
+ new Tuple2<Integer, Integer>(1, 1),
+ new Tuple2<Integer, Integer>(3, 2),
+ new Tuple2<Integer, Integer>(3, 1)
+ );
+ JavaPairRDD<Integer, Integer> rdd = sc.parallelizePairs(pairs);
+ JavaPairRDD<Integer, Integer> sums = rdd.foldByKey(0,
+ new Function2<Integer, Integer, Integer>() {
+ @Override
+ public Integer call(Integer a, Integer b) {
+ return a + b;
+ }
+ });
+ Assert.assertEquals(1, sums.lookup(1).get(0).intValue());
+ Assert.assertEquals(2, sums.lookup(2).get(0).intValue());
+ Assert.assertEquals(3, sums.lookup(3).get(0).intValue());
+ }
+
+ @Test
public void reduceByKey() {
List<Tuple2<Integer, Integer>> pairs = Arrays.asList(
new Tuple2<Integer, Integer>(2, 1),
@@ -356,6 +378,34 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void mapsFromPairsToPairs() {
+ List<Tuple2<Integer, String>> pairs = Arrays.asList(
+ new Tuple2<Integer, String>(1, "a"),
+ new Tuple2<Integer, String>(2, "aa"),
+ new Tuple2<Integer, String>(3, "aaa")
+ );
+ JavaPairRDD<Integer, String> pairRDD = sc.parallelizePairs(pairs);
+
+ // Regression test for SPARK-668:
+ JavaPairRDD<String, Integer> swapped = pairRDD.flatMap(
+ new PairFlatMapFunction<Tuple2<Integer, String>, String, Integer>() {
+ @Override
+ public Iterable<Tuple2<String, Integer>> call(Tuple2<Integer, String> item) throws Exception {
+ return Collections.singletonList(item.swap());
+ }
+ });
+ swapped.collect();
+
+ // There was never a bug here, but it's worth testing:
+ pairRDD.map(new PairFunction<Tuple2<Integer, String>, String, Integer>() {
+ @Override
+ public Tuple2<String, Integer> call(Tuple2<Integer, String> item) throws Exception {
+ return item.swap();
+ }
+ }).collect();
+ }
+
+ @Test
public void mapPartitions() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
JavaRDD<Integer> partitionSums = rdd.mapPartitions(
@@ -395,7 +445,7 @@ public class JavaAPISuite implements Serializable {
@Test
public void iterator() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
- TaskContext context = new TaskContext(0, 0, 0);
+ TaskContext context = new TaskContext(0, 0, 0, null);
Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue());
}
@@ -586,7 +636,7 @@ public class JavaAPISuite implements Serializable {
public void accumulators() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
- final Accumulator<Integer> intAccum = sc.accumulator(10);
+ final Accumulator<Integer> intAccum = sc.intAccumulator(10);
rdd.foreach(new VoidFunction<Integer>() {
public void call(Integer x) {
intAccum.add(x);
@@ -594,7 +644,7 @@ public class JavaAPISuite implements Serializable {
});
Assert.assertEquals((Integer) 25, intAccum.value());
- final Accumulator<Double> doubleAccum = sc.accumulator(10.0);
+ final Accumulator<Double> doubleAccum = sc.doubleAccumulator(10.0);
rdd.foreach(new VoidFunction<Integer>() {
public void call(Integer x) {
doubleAccum.add((double) x);
@@ -641,4 +691,55 @@ public class JavaAPISuite implements Serializable {
Assert.assertEquals(new Tuple2<String, Integer>("1", 1), s.get(0));
Assert.assertEquals(new Tuple2<String, Integer>("2", 2), s.get(1));
}
+
+ @Test
+ public void checkpointAndComputation() {
+ File tempDir = Files.createTempDir();
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+ sc.setCheckpointDir(tempDir.getAbsolutePath(), true);
+ Assert.assertEquals(false, rdd.isCheckpointed());
+ rdd.checkpoint();
+ rdd.count(); // Forces the DAG to cause a checkpoint
+ Assert.assertEquals(true, rdd.isCheckpointed());
+ Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), rdd.collect());
+ }
+
+ @Test
+ public void checkpointAndRestore() {
+ File tempDir = Files.createTempDir();
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+ sc.setCheckpointDir(tempDir.getAbsolutePath(), true);
+ Assert.assertEquals(false, rdd.isCheckpointed());
+ rdd.checkpoint();
+ rdd.count(); // Forces the DAG to cause a checkpoint
+ Assert.assertEquals(true, rdd.isCheckpointed());
+
+ Assert.assertTrue(rdd.getCheckpointFile().isPresent());
+ JavaRDD<Integer> recovered = sc.checkpointFile(rdd.getCheckpointFile().get());
+ Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect());
+ }
+
+ @Test
+ public void mapOnPairRDD() {
+ JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1,2,3,4));
+ JavaPairRDD<Integer, Integer> rdd2 = rdd1.map(new PairFunction<Integer, Integer, Integer>() {
+ @Override
+ public Tuple2<Integer, Integer> call(Integer i) throws Exception {
+ return new Tuple2<Integer, Integer>(i, i % 2);
+ }
+ });
+ JavaPairRDD<Integer, Integer> rdd3 = rdd2.map(
+ new PairFunction<Tuple2<Integer, Integer>, Integer, Integer>() {
+ @Override
+ public Tuple2<Integer, Integer> call(Tuple2<Integer, Integer> in) throws Exception {
+ return new Tuple2<Integer, Integer>(in._2(), in._1());
+ }
+ });
+ Assert.assertEquals(Arrays.asList(
+ new Tuple2<Integer, Integer>(1, 1),
+ new Tuple2<Integer, Integer>(0, 2),
+ new Tuple2<Integer, Integer>(1, 3),
+ new Tuple2<Integer, Integer>(0, 4)), rdd3.collect());
+
+ }
}
diff --git a/core/src/test/scala/spark/KryoSerializerSuite.scala b/core/src/test/scala/spark/KryoSerializerSuite.scala
index 06d446ea24..327e2ff848 100644
--- a/core/src/test/scala/spark/KryoSerializerSuite.scala
+++ b/core/src/test/scala/spark/KryoSerializerSuite.scala
@@ -82,6 +82,7 @@ class KryoSerializerSuite extends FunSuite {
check(mutable.HashMap(1 -> "one", 2 -> "two"))
check(mutable.HashMap("one" -> 1, "two" -> 2))
check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4))))
+ check(List(mutable.HashMap("one" -> 1, "two" -> 2),mutable.HashMap(1->"one",2->"two",3->"three")))
}
test("custom registrator") {
diff --git a/core/src/test/scala/spark/LocalSparkContext.scala b/core/src/test/scala/spark/LocalSparkContext.scala
new file mode 100644
index 0000000000..ff00dd05dd
--- /dev/null
+++ b/core/src/test/scala/spark/LocalSparkContext.scala
@@ -0,0 +1,41 @@
+package spark
+
+import org.scalatest.Suite
+import org.scalatest.BeforeAndAfterEach
+
+/** Manages a local `sc` {@link SparkContext} variable, correctly stopping it after each test. */
+trait LocalSparkContext extends BeforeAndAfterEach { self: Suite =>
+
+ @transient var sc: SparkContext = _
+
+ override def afterEach() {
+ resetSparkContext()
+ super.afterEach()
+ }
+
+ def resetSparkContext() = {
+ if (sc != null) {
+ LocalSparkContext.stop(sc)
+ sc = null
+ }
+ }
+
+}
+
+object LocalSparkContext {
+ def stop(sc: SparkContext) {
+ sc.stop()
+ // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+ System.clearProperty("spark.driver.port")
+ }
+
+ /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */
+ def withSpark[T](sc: SparkContext)(f: SparkContext => T) = {
+ try {
+ f(sc)
+ } finally {
+ stop(sc)
+ }
+ }
+
+} \ No newline at end of file
diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
index d3dd3a8fa4..3abc584b6a 100644
--- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
@@ -1,17 +1,13 @@
package spark
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
import akka.actor._
import spark.scheduler.MapStatus
import spark.storage.BlockManagerId
import spark.util.AkkaUtils
-class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter {
- after {
- System.clearProperty("spark.master.port")
- }
+class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("compressSize") {
assert(MapOutputTracker.compressSize(0L) === 0)
@@ -35,58 +31,65 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter {
test("master start and stop") {
val actorSystem = ActorSystem("test")
- val tracker = new MapOutputTracker(actorSystem, true)
+ val tracker = new MapOutputTracker()
+ tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
tracker.stop()
}
test("master register and fetch") {
val actorSystem = ActorSystem("test")
- val tracker = new MapOutputTracker(actorSystem, true)
+ val tracker = new MapOutputTracker()
+ tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
- tracker.registerMapOutput(10, 0, new MapStatus(new BlockManagerId("hostA", 1000),
+ tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000),
Array(compressedSize1000, compressedSize10000)))
- tracker.registerMapOutput(10, 1, new MapStatus(new BlockManagerId("hostB", 1000),
+ tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000),
Array(compressedSize10000, compressedSize1000)))
val statuses = tracker.getServerStatuses(10, 0)
- assert(statuses.toSeq === Seq((new BlockManagerId("hostA", 1000), size1000),
- (new BlockManagerId("hostB", 1000), size10000)))
+ assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000),
+ (BlockManagerId("b", "hostB", 1000), size10000)))
tracker.stop()
}
test("master register and unregister and fetch") {
val actorSystem = ActorSystem("test")
- val tracker = new MapOutputTracker(actorSystem, true)
+ val tracker = new MapOutputTracker()
+ tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
- tracker.registerMapOutput(10, 0, new MapStatus(new BlockManagerId("hostA", 1000),
+ tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000),
Array(compressedSize1000, compressedSize1000, compressedSize1000)))
- tracker.registerMapOutput(10, 1, new MapStatus(new BlockManagerId("hostB", 1000),
+ tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000),
Array(compressedSize10000, compressedSize1000, compressedSize1000)))
// As if we had two simulatenous fetch failures
- tracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000))
- tracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000))
+ tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
+ tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
- // The remaining reduce task might try to grab the output dispite the shuffle failure;
+ // The remaining reduce task might try to grab the output despite the shuffle failure;
// this should cause it to fail, and the scheduler will ignore the failure due to the
// stage already being aborted.
intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) }
}
test("remote fetch") {
- System.clearProperty("spark.master.host")
- val (actorSystem, boundPort) =
- AkkaUtils.createActorSystem("test", "localhost", 0)
- System.setProperty("spark.master.port", boundPort.toString)
- val masterTracker = new MapOutputTracker(actorSystem, true)
- val slaveTracker = new MapOutputTracker(actorSystem, false)
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", "localhost", 0)
+ val masterTracker = new MapOutputTracker()
+ masterTracker.trackerActor = actorSystem.actorOf(
+ Props(new MapOutputTrackerActor(masterTracker)), "MapOutputTracker")
+
+ val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", "localhost", 0)
+ val slaveTracker = new MapOutputTracker()
+ slaveTracker.trackerActor = slaveSystem.actorFor(
+ "akka://spark@localhost:" + boundPort + "/user/MapOutputTracker")
+
masterTracker.registerShuffle(10, 1)
masterTracker.incrementGeneration()
slaveTracker.updateGeneration(masterTracker.getGeneration)
@@ -95,13 +98,13 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter {
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
masterTracker.registerMapOutput(10, 0, new MapStatus(
- new BlockManagerId("hostA", 1000), Array(compressedSize1000)))
+ BlockManagerId("a", "hostA", 1000), Array(compressedSize1000)))
masterTracker.incrementGeneration()
slaveTracker.updateGeneration(masterTracker.getGeneration)
assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
- Seq((new BlockManagerId("hostA", 1000), size1000)))
+ Seq((BlockManagerId("a", "hostA", 1000), size1000)))
- masterTracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000))
+ masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
masterTracker.incrementGeneration()
slaveTracker.updateGeneration(masterTracker.getGeneration)
intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala
index eb3c8f238f..60db759c25 100644
--- a/core/src/test/scala/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/spark/PartitioningSuite.scala
@@ -1,25 +1,12 @@
package spark
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
import scala.collection.mutable.ArrayBuffer
import SparkContext._
-class PartitioningSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if(sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
-
+class PartitioningSuite extends FunSuite with LocalSparkContext {
test("HashPartitioner equality") {
val p2 = new HashPartitioner(2)
@@ -97,10 +84,10 @@ class PartitioningSuite extends FunSuite with BeforeAndAfter {
assert(grouped4.groupByKey(3).partitioner != grouped4.partitioner)
assert(grouped4.groupByKey(4).partitioner === grouped4.partitioner)
- assert(grouped2.join(grouped4).partitioner === grouped2.partitioner)
- assert(grouped2.leftOuterJoin(grouped4).partitioner === grouped2.partitioner)
- assert(grouped2.rightOuterJoin(grouped4).partitioner === grouped2.partitioner)
- assert(grouped2.cogroup(grouped4).partitioner === grouped2.partitioner)
+ assert(grouped2.join(grouped4).partitioner === grouped4.partitioner)
+ assert(grouped2.leftOuterJoin(grouped4).partitioner === grouped4.partitioner)
+ assert(grouped2.rightOuterJoin(grouped4).partitioner === grouped4.partitioner)
+ assert(grouped2.cogroup(grouped4).partitioner === grouped4.partitioner)
assert(grouped2.join(reduced2).partitioner === grouped2.partitioner)
assert(grouped2.leftOuterJoin(reduced2).partitioner === grouped2.partitioner)
diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala
index 9b84b29227..a6344edf8f 100644
--- a/core/src/test/scala/spark/PipedRDDSuite.scala
+++ b/core/src/test/scala/spark/PipedRDDSuite.scala
@@ -1,21 +1,9 @@
package spark
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
import SparkContext._
-class PipedRDDSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class PipedRDDSuite extends FunSuite with LocalSparkContext {
test("basic pipe") {
sc = new SparkContext("local", "test")
@@ -51,5 +39,3 @@ class PipedRDDSuite extends FunSuite with BeforeAndAfter {
}
}
-
-
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index d74e9786c3..7fbdd44340 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -2,32 +2,20 @@ package spark
import scala.collection.mutable.HashMap
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
+import spark.SparkContext._
+import spark.rdd.{CoalescedRDD, CoGroupedRDD, PartitionPruningRDD, ShuffledRDD}
-import spark.rdd.CoalescedRDD
-import SparkContext._
-
-class RDDSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class RDDSuite extends FunSuite with LocalSparkContext {
test("basic operations") {
sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
assert(nums.collect().toList === List(1, 2, 3, 4))
val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2)
- assert(dups.distinct.count === 4)
- assert(dups.distinct().collect === dups.distinct.collect)
- assert(dups.distinct(2).collect === dups.distinct.collect)
+ assert(dups.distinct().count() === 4)
+ assert(dups.distinct.count === 4) // Can distinct and count be called without parentheses?
+ assert(dups.distinct.collect === dups.distinct().collect)
+ assert(dups.distinct(2).collect === dups.distinct().collect)
assert(nums.reduce(_ + _) === 10)
assert(nums.fold(0)(_ + _) === 10)
assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4"))
@@ -44,6 +32,15 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
case(split, iter) => Iterator((split, iter.reduceLeft(_ + _)))
}
assert(partitionSumsWithSplit.collect().toList === List((0, 3), (1, 7)))
+
+ val partitionSumsWithIndex = nums.mapPartitionsWithIndex {
+ case(split, iter) => Iterator((split, iter.reduceLeft(_ + _)))
+ }
+ assert(partitionSumsWithIndex.collect().toList === List((0, 3), (1, 7)))
+
+ intercept[UnsupportedOperationException] {
+ nums.filter(_ > 5).reduce(_ + _)
+ }
}
test("SparkContext.union") {
@@ -76,10 +73,23 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
}
- test("checkpointing") {
+ test("basic checkpointing") {
+ import java.io.File
+ val checkpointDir = File.createTempFile("temp", "")
+ checkpointDir.delete()
+
sc = new SparkContext("local", "test")
- val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).flatMap(x => 1 to x).checkpoint()
- assert(rdd.collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4))
+ sc.setCheckpointDir(checkpointDir.toString)
+ val parCollection = sc.makeRDD(1 to 4)
+ val flatMappedRDD = parCollection.flatMap(x => 1 to x)
+ flatMappedRDD.checkpoint()
+ assert(flatMappedRDD.dependencies.head.rdd == parCollection)
+ val result = flatMappedRDD.collect()
+ Thread.sleep(1000)
+ assert(flatMappedRDD.dependencies.head.rdd != parCollection)
+ assert(flatMappedRDD.collect() === result)
+
+ checkpointDir.deleteOnExit()
}
test("basic caching") {
@@ -91,13 +101,13 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
}
test("caching with failures") {
- sc = new SparkContext("local", "test")
- val onlySplit = new Split { override def index: Int = 0 }
+ sc = new SparkContext("local", "test")
+ val onlySplit = new Partition { override def index: Int = 0 }
var shouldFail = true
- val rdd = new RDD[Int](sc) {
- override def splits: Array[Split] = Array(onlySplit)
- override val dependencies = List[Dependency[_]]()
- override def compute(split: Split, context: TaskContext): Iterator[Int] = {
+ val rdd = new RDD[Int](sc, Nil) {
+ override def getPartitions: Array[Partition] = Array(onlySplit)
+ override val getDependencies = List[Dependency[_]]()
+ override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
if (shouldFail) {
throw new Exception("injected failure")
} else {
@@ -113,35 +123,72 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
assert(rdd.collect().toList === List(1, 2, 3, 4))
}
+ test("cogrouped RDDs") {
+ sc = new SparkContext("local", "test")
+ val rdd1 = sc.makeRDD(Array((1, "one"), (1, "another one"), (2, "two"), (3, "three")), 2)
+ val rdd2 = sc.makeRDD(Array((1, "one1"), (1, "another one1"), (2, "two1")), 2)
+
+ // Use cogroup function
+ val cogrouped = rdd1.cogroup(rdd2).collectAsMap()
+ assert(cogrouped(1) === (Seq("one", "another one"), Seq("one1", "another one1")))
+ assert(cogrouped(2) === (Seq("two"), Seq("two1")))
+ assert(cogrouped(3) === (Seq("three"), Seq()))
+
+ // Construct CoGroupedRDD directly, with map side combine enabled
+ val cogrouped1 = new CoGroupedRDD[Int](
+ Seq(rdd1.asInstanceOf[RDD[(Int, Any)]], rdd2.asInstanceOf[RDD[(Int, Any)]]),
+ new HashPartitioner(3),
+ true).collectAsMap()
+ assert(cogrouped1(1).toSeq === Seq(Seq("one", "another one"), Seq("one1", "another one1")))
+ assert(cogrouped1(2).toSeq === Seq(Seq("two"), Seq("two1")))
+ assert(cogrouped1(3).toSeq === Seq(Seq("three"), Seq()))
+
+ // Construct CoGroupedRDD directly, with map side combine disabled
+ val cogrouped2 = new CoGroupedRDD[Int](
+ Seq(rdd1.asInstanceOf[RDD[(Int, Any)]], rdd2.asInstanceOf[RDD[(Int, Any)]]),
+ new HashPartitioner(3),
+ false).collectAsMap()
+ assert(cogrouped2(1).toSeq === Seq(Seq("one", "another one"), Seq("one1", "another one1")))
+ assert(cogrouped2(2).toSeq === Seq(Seq("two"), Seq("two1")))
+ assert(cogrouped2(3).toSeq === Seq(Seq("three"), Seq()))
+ }
+
test("coalesced RDDs") {
sc = new SparkContext("local", "test")
val data = sc.parallelize(1 to 10, 10)
- val coalesced1 = new CoalescedRDD(data, 2)
+ val coalesced1 = data.coalesce(2)
assert(coalesced1.collect().toList === (1 to 10).toList)
assert(coalesced1.glom().collect().map(_.toList).toList ===
List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10)))
// Check that the narrow dependency is also specified correctly
- assert(coalesced1.dependencies.head.getParents(0).toList === List(0, 1, 2, 3, 4))
- assert(coalesced1.dependencies.head.getParents(1).toList === List(5, 6, 7, 8, 9))
+ assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList ===
+ List(0, 1, 2, 3, 4))
+ assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList ===
+ List(5, 6, 7, 8, 9))
- val coalesced2 = new CoalescedRDD(data, 3)
+ val coalesced2 = data.coalesce(3)
assert(coalesced2.collect().toList === (1 to 10).toList)
assert(coalesced2.glom().collect().map(_.toList).toList ===
List(List(1, 2, 3), List(4, 5, 6), List(7, 8, 9, 10)))
- val coalesced3 = new CoalescedRDD(data, 10)
+ val coalesced3 = data.coalesce(10)
assert(coalesced3.collect().toList === (1 to 10).toList)
assert(coalesced3.glom().collect().map(_.toList).toList ===
(1 to 10).map(x => List(x)).toList)
// If we try to coalesce into more partitions than the original RDD, it should just
// keep the original number of partitions.
- val coalesced4 = new CoalescedRDD(data, 20)
+ val coalesced4 = data.coalesce(20)
assert(coalesced4.collect().toList === (1 to 10).toList)
assert(coalesced4.glom().collect().map(_.toList).toList ===
(1 to 10).map(x => List(x)).toList)
+
+ // we can optionally shuffle to keep the upstream parallel
+ val coalesced5 = data.coalesce(1, shuffle = true)
+ assert(coalesced5.dependencies.head.rdd.dependencies.head.rdd.asInstanceOf[ShuffledRDD[_, _]] !=
+ null)
}
test("zipped RDDs") {
@@ -155,4 +202,75 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
nums.zip(sc.parallelize(1 to 4, 1)).collect()
}
}
+
+ test("partition pruning") {
+ sc = new SparkContext("local", "test")
+ val data = sc.parallelize(1 to 10, 10)
+ // Note that split number starts from 0, so > 8 means only 10th partition left.
+ val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8)
+ assert(prunedRdd.partitions.size === 1)
+ val prunedData = prunedRdd.collect()
+ assert(prunedData.size === 1)
+ assert(prunedData(0) === 10)
+ }
+
+ test("mapWith") {
+ import java.util.Random
+ sc = new SparkContext("local", "test")
+ val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
+ val randoms = ones.mapWith(
+ (index: Int) => new Random(index + 42))
+ {(t: Int, prng: Random) => prng.nextDouble * t}.collect()
+ val prn42_3 = {
+ val prng42 = new Random(42)
+ prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble()
+ }
+ val prn43_3 = {
+ val prng43 = new Random(43)
+ prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble()
+ }
+ assert(randoms(2) === prn42_3)
+ assert(randoms(5) === prn43_3)
+ }
+
+ test("flatMapWith") {
+ import java.util.Random
+ sc = new SparkContext("local", "test")
+ val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
+ val randoms = ones.flatMapWith(
+ (index: Int) => new Random(index + 42))
+ {(t: Int, prng: Random) =>
+ val random = prng.nextDouble()
+ Seq(random * t, random * t * 10)}.
+ collect()
+ val prn42_3 = {
+ val prng42 = new Random(42)
+ prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble()
+ }
+ val prn43_3 = {
+ val prng43 = new Random(43)
+ prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble()
+ }
+ assert(randoms(5) === prn42_3 * 10)
+ assert(randoms(11) === prn43_3 * 10)
+ }
+
+ test("filterWith") {
+ import java.util.Random
+ sc = new SparkContext("local", "test")
+ val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2)
+ val sample = ints.filterWith(
+ (index: Int) => new Random(index + 42))
+ {(t: Int, prng: Random) => prng.nextInt(3) == 0}.
+ collect()
+ val checkSample = {
+ val prng42 = new Random(42)
+ val prng43 = new Random(43)
+ Array(1, 2, 3, 4, 5, 6).filter{i =>
+ if (i < 4) 0 == prng42.nextInt(3)
+ else 0 == prng43.nextInt(3)}
+ }
+ assert(sample.size === checkSample.size)
+ for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
+ }
}
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index bebb8ebe86..2b2a90defa 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -1,9 +1,9 @@
package spark
import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashSet
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
import org.scalatest.matchers.ShouldMatchers
import org.scalatest.prop.Checkers
import org.scalacheck.Arbitrary._
@@ -15,18 +15,7 @@ import com.google.common.io.Files
import spark.rdd.ShuffledRDD
import spark.SparkContext._
-class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
test("groupByKey") {
sc = new SparkContext("local", "test")
@@ -110,6 +99,28 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
val sums = pairs.reduceByKey(_+_, 10).collect()
assert(sums.toSet === Set((1, 7), (2, 1)))
}
+
+ test("reduceByKey with partitioner") {
+ sc = new SparkContext("local", "test")
+ val p = new Partitioner() {
+ def numPartitions = 2
+ def getPartition(key: Any) = key.asInstanceOf[Int]
+ }
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p)
+ val sums = pairs.reduceByKey(_+_)
+ assert(sums.collect().toSet === Set((1, 4), (0, 1)))
+ assert(sums.partitioner === Some(p))
+ // count the dependencies to make sure there is only 1 ShuffledRDD
+ val deps = new HashSet[RDD[_]]()
+ def visit(r: RDD[_]) {
+ for (dep <- r.dependencies) {
+ deps += dep.rdd
+ visit(dep.rdd)
+ }
+ }
+ visit(sums)
+ assert(deps.size === 2) // ShuffledRDD, ParallelCollection
+ }
test("join") {
sc = new SparkContext("local", "test")
@@ -211,7 +222,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
sc = new SparkContext("local", "test")
val emptyDir = Files.createTempDir()
val file = sc.textFile(emptyDir.getAbsolutePath)
- assert(file.splits.size == 0)
+ assert(file.partitions.size == 0)
assert(file.collect().toList === Nil)
// Test that a shuffle on the file works, because this used to be a bug
assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
@@ -223,6 +234,77 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
assert(rdd.keys.collect().toList === List(1, 2))
assert(rdd.values.collect().toList === List("a", "b"))
}
+
+ test("default partitioner uses partition size") {
+ sc = new SparkContext("local", "test")
+ // specify 2000 partitions
+ val a = sc.makeRDD(Array(1, 2, 3, 4), 2000)
+ // do a map, which loses the partitioner
+ val b = a.map(a => (a, (a * 2).toString))
+ // then a group by, and see we didn't revert to 2 partitions
+ val c = b.groupByKey()
+ assert(c.partitions.size === 2000)
+ }
+
+ test("default partitioner uses largest partitioner") {
+ sc = new SparkContext("local", "test")
+ val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2)
+ val b = sc.makeRDD(Array((1, "a"), (2, "b")), 2000)
+ val c = a.join(b)
+ assert(c.partitions.size === 2000)
+ }
+
+ test("subtract") {
+ sc = new SparkContext("local", "test")
+ val a = sc.parallelize(Array(1, 2, 3), 2)
+ val b = sc.parallelize(Array(2, 3, 4), 4)
+ val c = a.subtract(b)
+ assert(c.collect().toSet === Set(1))
+ assert(c.partitions.size === a.partitions.size)
+ }
+
+ test("subtract with narrow dependency") {
+ sc = new SparkContext("local", "test")
+ // use a deterministic partitioner
+ val p = new Partitioner() {
+ def numPartitions = 5
+ def getPartition(key: Any) = key.asInstanceOf[Int]
+ }
+ // partitionBy so we have a narrow dependency
+ val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
+ // more partitions/no partitioner so a shuffle dependency
+ val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
+ val c = a.subtract(b)
+ assert(c.collect().toSet === Set((1, "a"), (3, "c")))
+ // Ideally we could keep the original partitioner...
+ assert(c.partitioner === None)
+ }
+
+ test("subtractByKey") {
+ sc = new SparkContext("local", "test")
+ val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2)
+ val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4)
+ val c = a.subtractByKey(b)
+ assert(c.collect().toSet === Set((1, "a"), (1, "a")))
+ assert(c.partitions.size === a.partitions.size)
+ }
+
+ test("subtractByKey with narrow dependency") {
+ sc = new SparkContext("local", "test")
+ // use a deterministic partitioner
+ val p = new Partitioner() {
+ def numPartitions = 5
+ def getPartition(key: Any) = key.asInstanceOf[Int]
+ }
+ // partitionBy so we have a narrow dependency
+ val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
+ // more partitions/no partitioner so a shuffle dependency
+ val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
+ val c = a.subtractByKey(b)
+ assert(c.collect().toSet === Set((1, "a"), (1, "a")))
+ assert(c.partitioner.get === p)
+ }
+
}
object ShuffleSuite {
diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala
index 1ad11ff4c3..495f957e53 100644
--- a/core/src/test/scala/spark/SortingSuite.scala
+++ b/core/src/test/scala/spark/SortingSuite.scala
@@ -5,18 +5,7 @@ import org.scalatest.BeforeAndAfter
import org.scalatest.matchers.ShouldMatchers
import SparkContext._
-class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with Logging {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers with Logging {
test("sortByKey") {
sc = new SparkContext("local", "test")
@@ -30,7 +19,7 @@ class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
val sorted = pairs.sortByKey()
- assert(sorted.splits.size === 2)
+ assert(sorted.partitions.size === 2)
assert(sorted.collect() === pairArr.sortBy(_._1))
}
@@ -40,17 +29,17 @@ class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
val sorted = pairs.sortByKey(true, 1)
- assert(sorted.splits.size === 1)
+ assert(sorted.partitions.size === 1)
assert(sorted.collect() === pairArr.sortBy(_._1))
}
- test("large array with many splits") {
+ test("large array with many partitions") {
sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
val sorted = pairs.sortByKey(true, 20)
- assert(sorted.splits.size === 20)
+ assert(sorted.partitions.size === 20)
assert(sorted.collect() === pairArr.sortBy(_._1))
}
@@ -70,7 +59,7 @@ class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with
assert(pairs.sortByKey(false, 1).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
}
- test("sort descending with many splits") {
+ test("sort descending with many partitions") {
sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
diff --git a/core/src/test/scala/spark/ThreadingSuite.scala b/core/src/test/scala/spark/ThreadingSuite.scala
index e9b1837d89..ff315b6693 100644
--- a/core/src/test/scala/spark/ThreadingSuite.scala
+++ b/core/src/test/scala/spark/ThreadingSuite.scala
@@ -22,19 +22,7 @@ object ThreadingSuiteState {
}
}
-class ThreadingSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if(sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
-
+class ThreadingSuite extends FunSuite with LocalSparkContext {
test("accessing SparkContext form a different thread") {
sc = new SparkContext("local", "test")
diff --git a/core/src/test/scala/spark/ParallelCollectionSplitSuite.scala b/core/src/test/scala/spark/rdd/ParallelCollectionSplitSuite.scala
index 450c69bd58..d27a2538e4 100644
--- a/core/src/test/scala/spark/ParallelCollectionSplitSuite.scala
+++ b/core/src/test/scala/spark/rdd/ParallelCollectionSplitSuite.scala
@@ -1,4 +1,4 @@
-package spark
+package spark.rdd
import scala.collection.immutable.NumericRange
@@ -11,7 +11,7 @@ import org.scalacheck.Prop._
class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("one element per slice") {
val data = Array(1, 2, 3)
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices(0).mkString(",") === "1")
assert(slices(1).mkString(",") === "2")
@@ -20,14 +20,14 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("one slice") {
val data = Array(1, 2, 3)
- val slices = ParallelCollection.slice(data, 1)
+ val slices = ParallelCollectionRDD.slice(data, 1)
assert(slices.size === 1)
assert(slices(0).mkString(",") === "1,2,3")
}
test("equal slices") {
val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9)
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices(0).mkString(",") === "1,2,3")
assert(slices(1).mkString(",") === "4,5,6")
@@ -36,7 +36,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("non-equal slices") {
val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices(0).mkString(",") === "1,2,3")
assert(slices(1).mkString(",") === "4,5,6")
@@ -45,7 +45,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("splitting exclusive range") {
val data = 0 until 100
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices(0).mkString(",") === (0 to 32).mkString(","))
assert(slices(1).mkString(",") === (33 to 65).mkString(","))
@@ -54,7 +54,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("splitting inclusive range") {
val data = 0 to 100
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices(0).mkString(",") === (0 to 32).mkString(","))
assert(slices(1).mkString(",") === (33 to 66).mkString(","))
@@ -63,24 +63,24 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("empty data") {
val data = new Array[Int](0)
- val slices = ParallelCollection.slice(data, 5)
+ val slices = ParallelCollectionRDD.slice(data, 5)
assert(slices.size === 5)
for (slice <- slices) assert(slice.size === 0)
}
test("zero slices") {
val data = Array(1, 2, 3)
- intercept[IllegalArgumentException] { ParallelCollection.slice(data, 0) }
+ intercept[IllegalArgumentException] { ParallelCollectionRDD.slice(data, 0) }
}
test("negative number of slices") {
val data = Array(1, 2, 3)
- intercept[IllegalArgumentException] { ParallelCollection.slice(data, -5) }
+ intercept[IllegalArgumentException] { ParallelCollectionRDD.slice(data, -5) }
}
test("exclusive ranges sliced into ranges") {
val data = 1 until 100
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices.map(_.size).reduceLeft(_+_) === 99)
assert(slices.forall(_.isInstanceOf[Range]))
@@ -88,7 +88,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("inclusive ranges sliced into ranges") {
val data = 1 to 100
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices.map(_.size).reduceLeft(_+_) === 100)
assert(slices.forall(_.isInstanceOf[Range]))
@@ -97,7 +97,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("large ranges don't overflow") {
val N = 100 * 1000 * 1000
val data = 0 until N
- val slices = ParallelCollection.slice(data, 40)
+ val slices = ParallelCollectionRDD.slice(data, 40)
assert(slices.size === 40)
for (i <- 0 until 40) {
assert(slices(i).isInstanceOf[Range])
@@ -117,7 +117,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
(tuple: (List[Int], Int)) =>
val d = tuple._1
val n = tuple._2
- val slices = ParallelCollection.slice(d, n)
+ val slices = ParallelCollectionRDD.slice(d, n)
("n slices" |: slices.size == n) &&
("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) &&
("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1))
@@ -134,7 +134,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
} yield (a until b by step, n)
val prop = forAll(gen) {
case (d: Range, n: Int) =>
- val slices = ParallelCollection.slice(d, n)
+ val slices = ParallelCollectionRDD.slice(d, n)
("n slices" |: slices.size == n) &&
("all ranges" |: slices.forall(_.isInstanceOf[Range])) &&
("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) &&
@@ -152,7 +152,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
} yield (a to b by step, n)
val prop = forAll(gen) {
case (d: Range, n: Int) =>
- val slices = ParallelCollection.slice(d, n)
+ val slices = ParallelCollectionRDD.slice(d, n)
("n slices" |: slices.size == n) &&
("all ranges" |: slices.forall(_.isInstanceOf[Range])) &&
("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) &&
@@ -163,7 +163,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("exclusive ranges of longs") {
val data = 1L until 100L
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices.map(_.size).reduceLeft(_+_) === 99)
assert(slices.forall(_.isInstanceOf[NumericRange[_]]))
@@ -171,7 +171,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("inclusive ranges of longs") {
val data = 1L to 100L
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices.map(_.size).reduceLeft(_+_) === 100)
assert(slices.forall(_.isInstanceOf[NumericRange[_]]))
@@ -179,7 +179,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("exclusive ranges of doubles") {
val data = 1.0 until 100.0 by 1.0
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices.map(_.size).reduceLeft(_+_) === 99)
assert(slices.forall(_.isInstanceOf[NumericRange[_]]))
@@ -187,7 +187,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("inclusive ranges of doubles") {
val data = 1.0 to 100.0 by 1.0
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices.map(_.size).reduceLeft(_+_) === 100)
assert(slices.forall(_.isInstanceOf[NumericRange[_]]))
diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
new file mode 100644
index 0000000000..6da58a0f6e
--- /dev/null
+++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
@@ -0,0 +1,403 @@
+package spark.scheduler
+
+import scala.collection.mutable.{Map, HashMap}
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+
+import spark.LocalSparkContext
+
+import spark.storage.BlockManager
+import spark.storage.BlockManagerId
+import spark.storage.BlockManagerMaster
+import spark.{Dependency, ShuffleDependency, OneToOneDependency}
+import spark.FetchFailedException
+import spark.MapOutputTracker
+import spark.RDD
+import spark.SparkContext
+import spark.SparkException
+import spark.Partition
+import spark.TaskContext
+import spark.TaskEndReason
+
+import spark.{FetchFailed, Success}
+
+/**
+ * Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler
+ * rather than spawning an event loop thread as happens in the real code. They use EasyMock
+ * to mock out two classes that DAGScheduler interacts with: TaskScheduler (to which TaskSets are
+ * submitted) and BlockManagerMaster (from which cache locations are retrieved and to which dead
+ * host notifications are sent). In addition, tests may check for side effects on a non-mocked
+ * MapOutputTracker instance.
+ *
+ * Tests primarily consist of running DAGScheduler#processEvent and
+ * DAGScheduler#submitWaitingStages (via test utility functions like runEvent or respondToTaskSet)
+ * and capturing the resulting TaskSets from the mock TaskScheduler.
+ */
+class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
+
+ /** Set of TaskSets the DAGScheduler has requested executed. */
+ val taskSets = scala.collection.mutable.Buffer[TaskSet]()
+ val taskScheduler = new TaskScheduler() {
+ override def start() = {}
+ override def stop() = {}
+ override def submitTasks(taskSet: TaskSet) = {
+ // normally done by TaskSetManager
+ taskSet.tasks.foreach(_.generation = mapOutputTracker.getGeneration)
+ taskSets += taskSet
+ }
+ override def setListener(listener: TaskSchedulerListener) = {}
+ override def defaultParallelism() = 2
+ }
+
+ var mapOutputTracker: MapOutputTracker = null
+ var scheduler: DAGScheduler = null
+
+ /**
+ * Set of cache locations to return from our mock BlockManagerMaster.
+ * Keys are (rdd ID, partition ID). Anything not present will return an empty
+ * list of cache locations silently.
+ */
+ val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]
+ // stub out BlockManagerMaster.getLocations to use our cacheLocations
+ val blockManagerMaster = new BlockManagerMaster(null) {
+ override def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
+ blockIds.map { name =>
+ val pieces = name.split("_")
+ if (pieces(0) == "rdd") {
+ val key = pieces(1).toInt -> pieces(2).toInt
+ cacheLocations.getOrElse(key, Seq())
+ } else {
+ Seq()
+ }
+ }.toSeq
+ }
+ override def removeExecutor(execId: String) {
+ // don't need to propagate to the driver, which we don't have
+ }
+ }
+
+ /** The list of results that DAGScheduler has collected. */
+ val results = new HashMap[Int, Any]()
+ var failure: Exception = _
+ val listener = new JobListener() {
+ override def taskSucceeded(index: Int, result: Any) = results.put(index, result)
+ override def jobFailed(exception: Exception) = { failure = exception }
+ }
+
+ before {
+ sc = new SparkContext("local", "DAGSchedulerSuite")
+ taskSets.clear()
+ cacheLocations.clear()
+ results.clear()
+ mapOutputTracker = new MapOutputTracker()
+ scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) {
+ override def runLocally(job: ActiveJob) {
+ // don't bother with the thread while unit testing
+ runLocallyWithinThread(job)
+ }
+ }
+ }
+
+ after {
+ scheduler.stop()
+ }
+
+ /**
+ * Type of RDD we use for testing. Note that we should never call the real RDD compute methods.
+ * This is a pair RDD type so it can always be used in ShuffleDependencies.
+ */
+ type MyRDD = RDD[(Int, Int)]
+
+ /**
+ * Create an RDD for passing to DAGScheduler. These RDDs will use the dependencies and
+ * preferredLocations (if any) that are passed to them. They are deliberately not executable
+ * so we can test that DAGScheduler does not try to execute RDDs locally.
+ */
+ private def makeRdd(
+ numPartitions: Int,
+ dependencies: List[Dependency[_]],
+ locations: Seq[Seq[String]] = Nil
+ ): MyRDD = {
+ val maxPartition = numPartitions - 1
+ return new MyRDD(sc, dependencies) {
+ override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+ throw new RuntimeException("should not be reached")
+ override def getPartitions = (0 to maxPartition).map(i => new Partition {
+ override def index = i
+ }).toArray
+ override def getPreferredLocations(split: Partition): Seq[String] =
+ if (locations.isDefinedAt(split.index))
+ locations(split.index)
+ else
+ Nil
+ override def toString: String = "DAGSchedulerSuiteRDD " + id
+ }
+ }
+
+ /**
+ * Process the supplied event as if it were the top of the DAGScheduler event queue, expecting
+ * the scheduler not to exit.
+ *
+ * After processing the event, submit waiting stages as is done on most iterations of the
+ * DAGScheduler event loop.
+ */
+ private def runEvent(event: DAGSchedulerEvent) {
+ assert(!scheduler.processEvent(event))
+ scheduler.submitWaitingStages()
+ }
+
+ /**
+ * When we submit dummy Jobs, this is the compute function we supply. Except in a local test
+ * below, we do not expect this function to ever be executed; instead, we will return results
+ * directly through CompletionEvents.
+ */
+ private val jobComputeFunc = (context: TaskContext, it: Iterator[(_)]) =>
+ it.next.asInstanceOf[Tuple2[_, _]]._1
+
+ /** Send the given CompletionEvent messages for the tasks in the TaskSet. */
+ private def complete(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) {
+ assert(taskSet.tasks.size >= results.size)
+ for ((result, i) <- results.zipWithIndex) {
+ if (i < taskSet.tasks.size) {
+ runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any](), null, null))
+ }
+ }
+ }
+
+ /** Sends the rdd to the scheduler for scheduling. */
+ private def submit(
+ rdd: RDD[_],
+ partitions: Array[Int],
+ func: (TaskContext, Iterator[_]) => _ = jobComputeFunc,
+ allowLocal: Boolean = false,
+ listener: JobListener = listener) {
+ runEvent(JobSubmitted(rdd, func, partitions, allowLocal, null, listener))
+ }
+
+ /** Sends TaskSetFailed to the scheduler. */
+ private def failed(taskSet: TaskSet, message: String) {
+ runEvent(TaskSetFailed(taskSet, message))
+ }
+
+ test("zero split job") {
+ val rdd = makeRdd(0, Nil)
+ var numResults = 0
+ val fakeListener = new JobListener() {
+ override def taskSucceeded(partition: Int, value: Any) = numResults += 1
+ override def jobFailed(exception: Exception) = throw exception
+ }
+ submit(rdd, Array(), listener = fakeListener)
+ assert(numResults === 0)
+ }
+
+ test("run trivial job") {
+ val rdd = makeRdd(1, Nil)
+ submit(rdd, Array(0))
+ complete(taskSets(0), List((Success, 42)))
+ assert(results === Map(0 -> 42))
+ }
+
+ test("local job") {
+ val rdd = new MyRDD(sc, Nil) {
+ override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+ Array(42 -> 0).iterator
+ override def getPartitions = Array( new Partition { override def index = 0 } )
+ override def getPreferredLocations(split: Partition) = Nil
+ override def toString = "DAGSchedulerSuite Local RDD"
+ }
+ runEvent(JobSubmitted(rdd, jobComputeFunc, Array(0), true, null, listener))
+ assert(results === Map(0 -> 42))
+ }
+
+ test("run trivial job w/ dependency") {
+ val baseRdd = makeRdd(1, Nil)
+ val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
+ submit(finalRdd, Array(0))
+ complete(taskSets(0), Seq((Success, 42)))
+ assert(results === Map(0 -> 42))
+ }
+
+ test("cache location preferences w/ dependency") {
+ val baseRdd = makeRdd(1, Nil)
+ val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
+ cacheLocations(baseRdd.id -> 0) =
+ Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
+ submit(finalRdd, Array(0))
+ val taskSet = taskSets(0)
+ assertLocations(taskSet, Seq(Seq("hostA", "hostB")))
+ complete(taskSet, Seq((Success, 42)))
+ assert(results === Map(0 -> 42))
+ }
+
+ test("trivial job failure") {
+ submit(makeRdd(1, Nil), Array(0))
+ failed(taskSets(0), "some failure")
+ assert(failure.getMessage === "Job failed: some failure")
+ }
+
+ test("run trivial shuffle") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(1, List(shuffleDep))
+ submit(reduceRdd, Array(0))
+ complete(taskSets(0), Seq(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))))
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
+ complete(taskSets(1), Seq((Success, 42)))
+ assert(results === Map(0 -> 42))
+ }
+
+ test("run trivial shuffle with fetch failure") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(2, List(shuffleDep))
+ submit(reduceRdd, Array(0, 1))
+ complete(taskSets(0), Seq(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))))
+ // the 2nd ResultTask failed
+ complete(taskSets(1), Seq(
+ (Success, 42),
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null)))
+ // this will get called
+ // blockManagerMaster.removeExecutor("exec-hostA")
+ // ask the scheduler to try it again
+ scheduler.resubmitFailedStages()
+ // have the 2nd attempt pass
+ complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))
+ // we can see both result blocks now
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.ip) === Array("hostA", "hostB"))
+ complete(taskSets(3), Seq((Success, 43)))
+ assert(results === Map(0 -> 42, 1 -> 43))
+ }
+
+ test("ignore late map task completions") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(2, List(shuffleDep))
+ submit(reduceRdd, Array(0, 1))
+ // pretend we were told hostA went away
+ val oldGeneration = mapOutputTracker.getGeneration
+ runEvent(ExecutorLost("exec-hostA"))
+ val newGeneration = mapOutputTracker.getGeneration
+ assert(newGeneration > oldGeneration)
+ val noAccum = Map[Long, Any]()
+ val taskSet = taskSets(0)
+ // should be ignored for being too old
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null))
+ // should work because it's a non-failed host
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum, null, null))
+ // should be ignored for being too old
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null))
+ // should work because it's a new generation
+ taskSet.tasks(1).generation = newGeneration
+ runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum, null, null))
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
+ complete(taskSets(1), Seq((Success, 42), (Success, 43)))
+ assert(results === Map(0 -> 42, 1 -> 43))
+ }
+
+ test("run trivial shuffle with out-of-band failure and retry") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(1, List(shuffleDep))
+ submit(reduceRdd, Array(0))
+ // blockManagerMaster.removeExecutor("exec-hostA")
+ // pretend we were told hostA went away
+ runEvent(ExecutorLost("exec-hostA"))
+ // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks
+ // rather than marking it is as failed and waiting.
+ complete(taskSets(0), Seq(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))))
+ // have hostC complete the resubmitted task
+ complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1))))
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
+ complete(taskSets(2), Seq((Success, 42)))
+ assert(results === Map(0 -> 42))
+ }
+
+ test("recursive shuffle failures") {
+ val shuffleOneRdd = makeRdd(2, Nil)
+ val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
+ val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
+ val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
+ val finalRdd = makeRdd(1, List(shuffleDepTwo))
+ submit(finalRdd, Array(0))
+ // have the first stage complete normally
+ complete(taskSets(0), Seq(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))))
+ // have the second stage complete normally
+ complete(taskSets(1), Seq(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostC", 1))))
+ // fail the third stage because hostA went down
+ complete(taskSets(2), Seq(
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)))
+ // TODO assert this:
+ // blockManagerMaster.removeExecutor("exec-hostA")
+ // have DAGScheduler try again
+ scheduler.resubmitFailedStages()
+ complete(taskSets(3), Seq((Success, makeMapStatus("hostA", 2))))
+ complete(taskSets(4), Seq((Success, makeMapStatus("hostA", 1))))
+ complete(taskSets(5), Seq((Success, 42)))
+ assert(results === Map(0 -> 42))
+ }
+
+ test("cached post-shuffle") {
+ val shuffleOneRdd = makeRdd(2, Nil)
+ val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
+ val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
+ val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
+ val finalRdd = makeRdd(1, List(shuffleDepTwo))
+ submit(finalRdd, Array(0))
+ cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
+ cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
+ // complete stage 2
+ complete(taskSets(0), Seq(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))))
+ // complete stage 1
+ complete(taskSets(1), Seq(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))))
+ // pretend stage 0 failed because hostA went down
+ complete(taskSets(2), Seq(
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)))
+ // TODO assert this:
+ // blockManagerMaster.removeExecutor("exec-hostA")
+ // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun.
+ scheduler.resubmitFailedStages()
+ assertLocations(taskSets(3), Seq(Seq("hostD")))
+ // allow hostD to recover
+ complete(taskSets(3), Seq((Success, makeMapStatus("hostD", 1))))
+ complete(taskSets(4), Seq((Success, 42)))
+ assert(results === Map(0 -> 42))
+ }
+
+ /** Assert that the supplied TaskSet has exactly the given preferredLocations. */
+ private def assertLocations(taskSet: TaskSet, locations: Seq[Seq[String]]) {
+ assert(locations.size === taskSet.tasks.size)
+ for ((expectLocs, taskLocs) <-
+ taskSet.tasks.map(_.preferredLocations).zip(locations)) {
+ assert(expectLocs === taskLocs)
+ }
+ }
+
+ private def makeMapStatus(host: String, reduces: Int): MapStatus =
+ new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2))
+
+ private def makeBlockManagerId(host: String): BlockManagerId =
+ BlockManagerId("exec-" + host, host, 12345)
+
+}
diff --git a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala
new file mode 100644
index 0000000000..2f5af10e69
--- /dev/null
+++ b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala
@@ -0,0 +1,86 @@
+package spark.scheduler
+
+import org.scalatest.FunSuite
+import spark.{SparkContext, LocalSparkContext}
+import scala.collection.mutable
+import org.scalatest.matchers.ShouldMatchers
+import spark.SparkContext._
+
+/**
+ *
+ */
+
+class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+
+ test("local metrics") {
+ sc = new SparkContext("local[4]", "test")
+ val listener = new SaveStageInfo
+ sc.addSparkListener(listener)
+ sc.addSparkListener(new StatsReportListener)
+ //just to make sure some of the tasks take a noticeable amount of time
+ val w = {i:Int =>
+ if (i == 0)
+ Thread.sleep(100)
+ i
+ }
+
+ val d = sc.parallelize(1 to 1e4.toInt, 64).map{i => w(i)}
+ d.count
+ listener.stageInfos.size should be (1)
+
+ val d2 = d.map{i => w(i) -> i * 2}.setName("shuffle input 1")
+
+ val d3 = d.map{i => w(i) -> (0 to (i % 5))}.setName("shuffle input 2")
+
+ val d4 = d2.cogroup(d3, 64).map{case(k,(v1,v2)) => w(k) -> (v1.size, v2.size)}
+ d4.setName("A Cogroup")
+
+ d4.collectAsMap
+
+ listener.stageInfos.size should be (4)
+ listener.stageInfos.foreach {stageInfo =>
+ //small test, so some tasks might take less than 1 millisecond, but average should be greater than 1 ms
+ checkNonZeroAvg(stageInfo.taskInfos.map{_._1.duration}, stageInfo + " duration")
+ checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorRunTime.toLong}, stageInfo + " executorRunTime")
+ checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorDeserializeTime.toLong}, stageInfo + " executorDeserializeTime")
+ if (stageInfo.stage.rdd.name == d4.name) {
+ checkNonZeroAvg(stageInfo.taskInfos.map{_._2.shuffleReadMetrics.get.fetchWaitTime}, stageInfo + " fetchWaitTime")
+ }
+
+ stageInfo.taskInfos.foreach{case (taskInfo, taskMetrics) =>
+ taskMetrics.resultSize should be > (0l)
+ if (isStage(stageInfo, Set(d2.name, d3.name), Set(d4.name))) {
+ taskMetrics.shuffleWriteMetrics should be ('defined)
+ taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0l)
+ }
+ if (stageInfo.stage.rdd.name == d4.name) {
+ taskMetrics.shuffleReadMetrics should be ('defined)
+ val sm = taskMetrics.shuffleReadMetrics.get
+ sm.totalBlocksFetched should be > (0)
+ sm.shuffleReadMillis should be > (0l)
+ sm.localBlocksFetched should be > (0)
+ sm.remoteBlocksFetched should be (0)
+ sm.remoteBytesRead should be (0l)
+ sm.remoteFetchTime should be (0l)
+ }
+ }
+ }
+ }
+
+ def checkNonZeroAvg(m: Traversable[Long], msg: String) {
+ assert(m.sum / m.size.toDouble > 0.0, msg)
+ }
+
+ def isStage(stageInfo: StageInfo, rddNames: Set[String], excludedNames: Set[String]) = {
+ val names = Set(stageInfo.stage.rdd.name) ++ stageInfo.stage.rdd.dependencies.map{_.rdd.name}
+ !names.intersect(rddNames).isEmpty && names.intersect(excludedNames).isEmpty
+ }
+
+ class SaveStageInfo extends SparkListener {
+ val stageInfos = mutable.Buffer[StageInfo]()
+ def onStageCompleted(stage: StageCompleted) {
+ stageInfos += stage.stageInfo
+ }
+ }
+
+}
diff --git a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala
index f937877340..647bcaf860 100644
--- a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala
@@ -5,28 +5,17 @@ import org.scalatest.BeforeAndAfter
import spark.TaskContext
import spark.RDD
import spark.SparkContext
-import spark.Split
+import spark.Partition
+import spark.LocalSparkContext
-class TaskContextSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
test("Calls executeOnCompleteCallbacks after failure") {
var completed = false
sc = new SparkContext("local", "test")
- val rdd = new RDD[String](sc) {
- override val splits = Array[Split](StubSplit(0))
- override val dependencies = List()
- override def compute(split: Split, context: TaskContext) = {
+ val rdd = new RDD[String](sc, List()) {
+ override def getPartitions = Array[Partition](StubPartition(0))
+ override def compute(split: Partition, context: TaskContext) = {
context.addOnCompleteCallback(() => completed = true)
sys.error("failed")
}
@@ -39,5 +28,5 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter {
assert(completed === true)
}
- case class StubSplit(val index: Int) extends Split
-} \ No newline at end of file
+ case class StubPartition(val index: Int) extends Partition
+}
diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
index 8f86e3170e..b8c0f6fb76 100644
--- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
@@ -12,6 +12,7 @@ import org.scalatest.concurrent.Timeouts._
import org.scalatest.matchers.ShouldMatchers._
import org.scalatest.time.SpanSugar._
+import spark.JavaSerializer
import spark.KryoSerializer
import spark.SizeEstimator
import spark.util.ByteBufferInputStream
@@ -31,7 +32,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
before {
actorSystem = ActorSystem("test")
- master = new BlockManagerMaster(actorSystem, true, true, "localhost", 7077)
+ master = new BlockManagerMaster(
+ actorSystem.actorOf(Props(new spark.storage.BlockManagerMasterActor(true))))
// Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
oldArch = System.setProperty("os.arch", "amd64")
@@ -69,33 +71,41 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("StorageLevel object caching") {
- val level1 = new StorageLevel(false, false, false, 3)
- val level2 = new StorageLevel(false, false, false, 3)
+ val level1 = StorageLevel(false, false, false, 3)
+ val level2 = StorageLevel(false, false, false, 3) // this should return the same object as level1
+ val level3 = StorageLevel(false, false, false, 2) // this should return a different object
+ assert(level2 === level1, "level2 is not same as level1")
+ assert(level2.eq(level1), "level2 is not the same object as level1")
+ assert(level3 != level1, "level3 is same as level1")
val bytes1 = spark.Utils.serialize(level1)
val level1_ = spark.Utils.deserialize[StorageLevel](bytes1)
val bytes2 = spark.Utils.serialize(level2)
val level2_ = spark.Utils.deserialize[StorageLevel](bytes2)
assert(level1_ === level1, "Deserialized level1 not same as original level1")
- assert(level2_ === level2, "Deserialized level2 not same as original level1")
- assert(level1_ === level2_, "Deserialized level1 not same as deserialized level2")
- assert(level2_.eq(level1_), "Deserialized level2 not the same object as deserialized level1")
+ assert(level1_.eq(level1), "Deserialized level1 not the same object as original level2")
+ assert(level2_ === level2, "Deserialized level2 not same as original level2")
+ assert(level2_.eq(level1), "Deserialized level2 not the same object as original level1")
}
test("BlockManagerId object caching") {
- val id1 = new StorageLevel(false, false, false, 3)
- val id2 = new StorageLevel(false, false, false, 3)
+ val id1 = BlockManagerId("e1", "XXX", 1)
+ val id2 = BlockManagerId("e1", "XXX", 1) // this should return the same object as id1
+ val id3 = BlockManagerId("e1", "XXX", 2) // this should return a different object
+ assert(id2 === id1, "id2 is not same as id1")
+ assert(id2.eq(id1), "id2 is not the same object as id1")
+ assert(id3 != id1, "id3 is same as id1")
val bytes1 = spark.Utils.serialize(id1)
- val id1_ = spark.Utils.deserialize[StorageLevel](bytes1)
+ val id1_ = spark.Utils.deserialize[BlockManagerId](bytes1)
val bytes2 = spark.Utils.serialize(id2)
- val id2_ = spark.Utils.deserialize[StorageLevel](bytes2)
- assert(id1_ === id1, "Deserialized id1 not same as original id1")
- assert(id2_ === id2, "Deserialized id2 not same as original id1")
- assert(id1_ === id2_, "Deserialized id1 not same as deserialized id2")
- assert(id2_.eq(id1_), "Deserialized id2 not the same object as deserialized level1")
+ val id2_ = spark.Utils.deserialize[BlockManagerId](bytes2)
+ assert(id1_ === id1, "Deserialized id1 is not same as original id1")
+ assert(id1_.eq(id1), "Deserialized id1 is not the same object as original id1")
+ assert(id2_ === id2, "Deserialized id2 is not same as original id2")
+ assert(id2_.eq(id1), "Deserialized id2 is not the same object as original id1")
}
test("master + 1 manager interaction") {
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -125,8 +135,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("master + 2 managers interaction") {
- store = new BlockManager(actorSystem, master, serializer, 2000)
- store2 = new BlockManager(actorSystem, master, new KryoSerializer, 2000)
+ store = new BlockManager("exec1", actorSystem, master, serializer, 2000)
+ store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer, 2000)
val peers = master.getPeers(store.blockManagerId, 1)
assert(peers.size === 1, "master did not return the other manager as a peer")
@@ -141,7 +151,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("removing block") {
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -190,7 +200,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
test("reregistration on heart beat") {
val heartBeat = PrivateMethod[Unit]('heartBeat)
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
val a1 = new Array[Byte](400)
store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
@@ -198,7 +208,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
assert(store.getSingle("a1") != None, "a1 was not in store")
assert(master.getLocations("a1").size > 0, "master was not told about a1")
- master.notifyADeadHost(store.blockManagerId.ip)
+ master.removeExecutor(store.blockManagerId.executorId)
assert(master.getLocations("a1").size == 0, "a1 was not removed from master")
store invokePrivate heartBeat()
@@ -206,25 +216,63 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("reregistration on block update") {
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
-
assert(master.getLocations("a1").size > 0, "master was not told about a1")
- master.notifyADeadHost(store.blockManagerId.ip)
+ master.removeExecutor(store.blockManagerId.executorId)
assert(master.getLocations("a1").size == 0, "a1 was not removed from master")
store.putSingle("a2", a1, StorageLevel.MEMORY_ONLY)
+ store.waitForAsyncReregister()
assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master")
assert(master.getLocations("a2").size > 0, "master was not told about a2")
}
+ test("reregistration doesn't dead lock") {
+ val heartBeat = PrivateMethod[Unit]('heartBeat)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
+ val a1 = new Array[Byte](400)
+ val a2 = List(new Array[Byte](400))
+
+ // try many times to trigger any deadlocks
+ for (i <- 1 to 100) {
+ master.removeExecutor(store.blockManagerId.executorId)
+ val t1 = new Thread {
+ override def run() {
+ store.put("a2", a2.iterator, StorageLevel.MEMORY_ONLY, true)
+ }
+ }
+ val t2 = new Thread {
+ override def run() {
+ store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
+ }
+ }
+ val t3 = new Thread {
+ override def run() {
+ store invokePrivate heartBeat()
+ }
+ }
+
+ t1.start()
+ t2.start()
+ t3.start()
+ t1.join()
+ t2.join()
+ t3.join()
+
+ store.dropFromMemory("a1", null)
+ store.dropFromMemory("a2", null)
+ store.waitForAsyncReregister()
+ }
+ }
+
test("in-memory LRU storage") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -243,7 +291,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU storage with serialization") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -262,14 +310,14 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU for partitions of same RDD") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
store.putSingle("rdd_0_1", a1, StorageLevel.MEMORY_ONLY)
store.putSingle("rdd_0_2", a2, StorageLevel.MEMORY_ONLY)
store.putSingle("rdd_0_3", a3, StorageLevel.MEMORY_ONLY)
- // Even though we accessed rdd_0_3 last, it should not have replaced partitiosn 1 and 2
+ // Even though we accessed rdd_0_3 last, it should not have replaced partitions 1 and 2
// from the same RDD
assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store")
assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store")
@@ -281,7 +329,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU for partitions of multiple RDDs") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
store.putSingle("rdd_0_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
store.putSingle("rdd_0_2", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
store.putSingle("rdd_1_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
@@ -304,7 +352,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("on-disk storage") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -317,7 +365,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -332,7 +380,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage with getLocalBytes") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -347,7 +395,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage with serialization") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -362,7 +410,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage with serialization and getLocalBytes") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -377,7 +425,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("LRU with mixed storage levels") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -402,7 +450,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU with streams") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val list1 = List(new Array[Byte](200), new Array[Byte](200))
val list2 = List(new Array[Byte](200), new Array[Byte](200))
val list3 = List(new Array[Byte](200), new Array[Byte](200))
@@ -426,7 +474,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("LRU with mixed storage levels and streams") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val list1 = List(new Array[Byte](200), new Array[Byte](200))
val list2 = List(new Array[Byte](200), new Array[Byte](200))
val list3 = List(new Array[Byte](200), new Array[Byte](200))
@@ -472,7 +520,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("overly large block") {
- store = new BlockManager(actorSystem, master, serializer, 500)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 500)
store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY)
assert(store.getSingle("a1") === None, "a1 was in store")
store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK)
@@ -483,49 +531,49 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
test("block compression") {
try {
System.setProperty("spark.shuffle.compress", "true")
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec1", actorSystem, master, serializer, 2000)
store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("shuffle_0_0_0") <= 100, "shuffle_0_0_0 was not compressed")
store.stop()
store = null
System.setProperty("spark.shuffle.compress", "false")
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec2", actorSystem, master, serializer, 2000)
store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("shuffle_0_0_0") >= 1000, "shuffle_0_0_0 was compressed")
store.stop()
store = null
System.setProperty("spark.broadcast.compress", "true")
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec3", actorSystem, master, serializer, 2000)
store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("broadcast_0") <= 100, "broadcast_0 was not compressed")
store.stop()
store = null
System.setProperty("spark.broadcast.compress", "false")
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec4", actorSystem, master, serializer, 2000)
store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("broadcast_0") >= 1000, "broadcast_0 was compressed")
store.stop()
store = null
System.setProperty("spark.rdd.compress", "true")
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec5", actorSystem, master, serializer, 2000)
store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("rdd_0_0") <= 100, "rdd_0_0 was not compressed")
store.stop()
store = null
System.setProperty("spark.rdd.compress", "false")
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec6", actorSystem, master, serializer, 2000)
store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("rdd_0_0") >= 1000, "rdd_0_0 was compressed")
store.stop()
store = null
// Check that any other block types are also kept uncompressed
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec7", actorSystem, master, serializer, 2000)
store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY)
assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed")
store.stop()
@@ -536,4 +584,21 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
System.clearProperty("spark.rdd.compress")
}
}
+
+ test("block store put failure") {
+ // Use Java serializer so we can create an unserializable error.
+ store = new BlockManager("<driver>", actorSystem, master, new JavaSerializer, 1200)
+
+ // The put should fail since a1 is not serializable.
+ class UnserializableClass
+ val a1 = new UnserializableClass
+ intercept[java.io.NotSerializableException] {
+ store.putSingle("a1", a1, StorageLevel.DISK_ONLY)
+ }
+
+ // Make sure get a1 doesn't hang and returns None.
+ failAfter(1 second) {
+ assert(store.getSingle("a1") == None, "a1 should not be in store")
+ }
+ }
}
diff --git a/core/src/test/scala/spark/util/DistributionSuite.scala b/core/src/test/scala/spark/util/DistributionSuite.scala
new file mode 100644
index 0000000000..cc6249b1dd
--- /dev/null
+++ b/core/src/test/scala/spark/util/DistributionSuite.scala
@@ -0,0 +1,25 @@
+package spark.util
+
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+
+/**
+ *
+ */
+
+class DistributionSuite extends FunSuite with ShouldMatchers {
+ test("summary") {
+ val d = new Distribution((1 to 100).toArray.map{_.toDouble})
+ val stats = d.statCounter
+ stats.count should be (100)
+ stats.mean should be (50.5)
+ stats.sum should be (50 * 101)
+
+ val quantiles = d.getQuantiles()
+ quantiles(0) should be (1)
+ quantiles(1) should be (26)
+ quantiles(2) should be (51)
+ quantiles(3) should be (76)
+ quantiles(4) should be (100)
+ }
+}
diff --git a/core/src/test/scala/spark/util/NextIteratorSuite.scala b/core/src/test/scala/spark/util/NextIteratorSuite.scala
new file mode 100644
index 0000000000..ed5b36da73
--- /dev/null
+++ b/core/src/test/scala/spark/util/NextIteratorSuite.scala
@@ -0,0 +1,68 @@
+package spark.util
+
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import scala.collection.mutable.Buffer
+import java.util.NoSuchElementException
+
+class NextIteratorSuite extends FunSuite with ShouldMatchers {
+ test("one iteration") {
+ val i = new StubIterator(Buffer(1))
+ i.hasNext should be === true
+ i.next should be === 1
+ i.hasNext should be === false
+ intercept[NoSuchElementException] { i.next() }
+ }
+
+ test("two iterations") {
+ val i = new StubIterator(Buffer(1, 2))
+ i.hasNext should be === true
+ i.next should be === 1
+ i.hasNext should be === true
+ i.next should be === 2
+ i.hasNext should be === false
+ intercept[NoSuchElementException] { i.next() }
+ }
+
+ test("empty iteration") {
+ val i = new StubIterator(Buffer())
+ i.hasNext should be === false
+ intercept[NoSuchElementException] { i.next() }
+ }
+
+ test("close is called once for empty iterations") {
+ val i = new StubIterator(Buffer())
+ i.hasNext should be === false
+ i.hasNext should be === false
+ i.closeCalled should be === 1
+ }
+
+ test("close is called once for non-empty iterations") {
+ val i = new StubIterator(Buffer(1, 2))
+ i.next should be === 1
+ i.next should be === 2
+ // close isn't called until we check for the next element
+ i.closeCalled should be === 0
+ i.hasNext should be === false
+ i.closeCalled should be === 1
+ i.hasNext should be === false
+ i.closeCalled should be === 1
+ }
+
+ class StubIterator(ints: Buffer[Int]) extends NextIterator[Int] {
+ var closeCalled = 0
+
+ override def getNext() = {
+ if (ints.size == 0) {
+ finished = true
+ 0
+ } else {
+ ints.remove(0)
+ }
+ }
+
+ override def close() {
+ closeCalled += 1
+ }
+ }
+}
diff --git a/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala
new file mode 100644
index 0000000000..794063fb6d
--- /dev/null
+++ b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala
@@ -0,0 +1,23 @@
+package spark.util
+
+import org.scalatest.FunSuite
+import java.io.ByteArrayOutputStream
+import java.util.concurrent.TimeUnit._
+
+class RateLimitedOutputStreamSuite extends FunSuite {
+
+ private def benchmark[U](f: => U): Long = {
+ val start = System.nanoTime
+ f
+ System.nanoTime - start
+ }
+
+ test("write") {
+ val underlying = new ByteArrayOutputStream
+ val data = "X" * 41000
+ val stream = new RateLimitedOutputStream(underlying, 10000)
+ val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) }
+ assert(SECONDS.convert(elapsedNs, NANOSECONDS) == 4)
+ assert(underlying.toString("UTF-8") == data)
+ }
+}