Diffstat (limited to 'core/src/test/scala/org')
-rw-r--r--  core/src/test/scala/org/apache/spark/AccumulatorSuite.scala | 143
-rw-r--r--  core/src/test/scala/org/apache/spark/BroadcastSuite.scala | 39
-rw-r--r--  core/src/test/scala/org/apache/spark/CheckpointSuite.scala | 392
-rw-r--r--  core/src/test/scala/org/apache/spark/ClosureCleanerSuite.scala | 146
-rw-r--r--  core/src/test/scala/org/apache/spark/DistributedSuite.scala | 362
-rw-r--r--  core/src/test/scala/org/apache/spark/DriverSuite.scala | 54
-rw-r--r--  core/src/test/scala/org/apache/spark/FailureSuite.scala | 127
-rw-r--r--  core/src/test/scala/org/apache/spark/FileServerSuite.scala | 123
-rw-r--r--  core/src/test/scala/org/apache/spark/FileSuite.scala | 212
-rw-r--r--  core/src/test/scala/org/apache/spark/JavaAPISuite.java | 865
-rw-r--r--  core/src/test/scala/org/apache/spark/KryoSerializerSuite.scala | 208
-rw-r--r--  core/src/test/scala/org/apache/spark/LocalSparkContext.scala | 68
-rw-r--r--  core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala | 136
-rw-r--r--  core/src/test/scala/org/apache/spark/PairRDDFunctionsSuite.scala | 299
-rw-r--r--  core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala | 28
-rw-r--r--  core/src/test/scala/org/apache/spark/PartitioningSuite.scala | 150
-rw-r--r--  core/src/test/scala/org/apache/spark/PipedRDDSuite.scala | 93
-rw-r--r--  core/src/test/scala/org/apache/spark/RDDSuite.scala | 389
-rw-r--r--  core/src/test/scala/org/apache/spark/SharedSparkContext.scala | 42
-rw-r--r--  core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala | 34
-rw-r--r--  core/src/test/scala/org/apache/spark/ShuffleSuite.scala | 210
-rw-r--r--  core/src/test/scala/org/apache/spark/SizeEstimatorSuite.scala | 164
-rw-r--r--  core/src/test/scala/org/apache/spark/SortingSuite.scala | 123
-rw-r--r--  core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala | 60
-rw-r--r--  core/src/test/scala/org/apache/spark/ThreadingSuite.scala | 152
-rw-r--r--  core/src/test/scala/org/apache/spark/UnpersistSuite.scala | 47
-rw-r--r--  core/src/test/scala/org/apache/spark/UtilsSuite.scala | 139
-rw-r--r--  core/src/test/scala/org/apache/spark/ZippedPartitionsSuite.scala | 50
-rw-r--r--  core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala | 62
-rw-r--r--  core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala | 89
-rw-r--r--  core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala | 54
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala | 73
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala | 212
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala | 421
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala | 121
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala | 102
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala | 49
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala | 266
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala | 273
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala | 26
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala | 223
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala | 666
-rw-r--r--  core/src/test/scala/org/apache/spark/ui/UISuite.scala | 47
-rw-r--r--  core/src/test/scala/org/apache/spark/util/DistributionSuite.scala | 42
-rw-r--r--  core/src/test/scala/org/apache/spark/util/FakeClock.scala | 26
-rw-r--r--  core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala | 85
-rw-r--r--  core/src/test/scala/org/apache/spark/util/RateLimitedOutputStreamSuite.scala | 40
47 files changed, 7732 insertions(+), 0 deletions(-)
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
new file mode 100644
index 0000000000..4434f3b87c
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import collection.mutable
+import java.util.Random
+import scala.math.exp
+import scala.math.signum
+import org.apache.spark.SparkContext._
+
+class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
+
+ test ("basic accumulation"){
+ sc = new SparkContext("local", "test")
+ val acc : Accumulator[Int] = sc.accumulator(0)
+
+ val d = sc.parallelize(1 to 20)
+ d.foreach{x => acc += x}
+ acc.value should be (210)
+
+
+ val longAcc = sc.accumulator(0l)
+ val maxInt = Integer.MAX_VALUE.toLong
+ d.foreach{x => longAcc += maxInt + x}
+ longAcc.value should be (210l + maxInt * 20)
+ }
+
+ test ("value not assignable from tasks") {
+ sc = new SparkContext("local", "test")
+ val acc : Accumulator[Int] = sc.accumulator(0)
+
+ val d = sc.parallelize(1 to 20)
+ evaluating {d.foreach{x => acc.value = x}} should produce [Exception]
+ }
+
+ test ("add value to collection accumulators") {
+ import SetAccum._
+ val maxI = 1000
+ for (nThreads <- List(1, 10)) { //test single & multi-threaded
+ sc = new SparkContext("local[" + nThreads + "]", "test")
+ val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]())
+ val d = sc.parallelize(1 to maxI)
+ d.foreach {
+ x => acc += x
+ }
+ val v = acc.value.asInstanceOf[mutable.Set[Int]]
+ for (i <- 1 to maxI) {
+ v should contain(i)
+ }
+ resetSparkContext()
+ }
+ }
+
+ implicit object SetAccum extends AccumulableParam[mutable.Set[Any], Any] {
+ def addInPlace(t1: mutable.Set[Any], t2: mutable.Set[Any]) : mutable.Set[Any] = {
+ t1 ++= t2
+ t1
+ }
+ def addAccumulator(t1: mutable.Set[Any], t2: Any) : mutable.Set[Any] = {
+ t1 += t2
+ t1
+ }
+ def zero(t: mutable.Set[Any]) : mutable.Set[Any] = {
+ new mutable.HashSet[Any]()
+ }
+ }
+
+ test ("value not readable in tasks") {
+ import SetAccum._
+ val maxI = 1000
+ for (nThreads <- List(1, 10)) { //test single & multi-threaded
+ sc = new SparkContext("local[" + nThreads + "]", "test")
+ val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]())
+ val d = sc.parallelize(1 to maxI)
+ evaluating {
+ d.foreach {
+ x => acc.value += x
+ }
+ } should produce [SparkException]
+ resetSparkContext()
+ }
+ }
+
+ test ("collection accumulators") {
+ val maxI = 1000
+ for (nThreads <- List(1, 10)) {
+ // test single & multi-threaded
+ sc = new SparkContext("local[" + nThreads + "]", "test")
+ val setAcc = sc.accumulableCollection(mutable.HashSet[Int]())
+ val bufferAcc = sc.accumulableCollection(mutable.ArrayBuffer[Int]())
+ val mapAcc = sc.accumulableCollection(mutable.HashMap[Int,String]())
+ val d = sc.parallelize((1 to maxI) ++ (1 to maxI))
+ d.foreach {
+ x => {setAcc += x; bufferAcc += x; mapAcc += (x -> x.toString)}
+ }
+
+ // Note that this is typed correctly -- no casts necessary
+ setAcc.value.size should be (maxI)
+ bufferAcc.value.size should be (2 * maxI)
+ mapAcc.value.size should be (maxI)
+ for (i <- 1 to maxI) {
+ setAcc.value should contain(i)
+ bufferAcc.value should contain(i)
+ mapAcc.value should contain (i -> i.toString)
+ }
+ resetSparkContext()
+ }
+ }
+
+ test ("localValue readable in tasks") {
+ import SetAccum._
+ val maxI = 1000
+ for (nThreads <- List(1, 10)) { //test single & multi-threaded
+ sc = new SparkContext("local[" + nThreads + "]", "test")
+ val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]())
+ val groupedInts = (1 to (maxI/20)).map {x => (20 * (x - 1) to 20 * x).toSet}
+ val d = sc.parallelize(groupedInts)
+ d.foreach {
+ x => acc.localValue ++= x
+ }
+ acc.value should be ( (0 to maxI).toSet)
+ resetSparkContext()
+ }
+ }
+
+}
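
The SetAccum object above is an instance of the AccumulableParam contract: addAccumulator folds one element into a worker-local value, addInPlace merges two partial values, and zero supplies the per-task starting value. Below is a minimal standalone sketch of the same pattern, built only from calls exercised by the suite; the StringSetParam and sketch object names, the master string, and the sample data are illustrative and not part of this patch.

    import scala.collection.mutable
    import org.apache.spark.{Accumulable, AccumulableParam, SparkContext}

    object StringSetAccumulatorSketch {
      // Same shape as SetAccum above, specialised to String elements.
      implicit object StringSetParam extends AccumulableParam[mutable.Set[String], String] {
        def addAccumulator(acc: mutable.Set[String], elem: String): mutable.Set[String] = { acc += elem; acc }
        def addInPlace(s1: mutable.Set[String], s2: mutable.Set[String]): mutable.Set[String] = { s1 ++= s2; s1 }
        def zero(initial: mutable.Set[String]): mutable.Set[String] = new mutable.HashSet[String]()
      }

      def main(args: Array[String]) {
        val sc = new SparkContext("local[2]", "accumulable-sketch")
        val seen: Accumulable[mutable.Set[String], String] = sc.accumulable(new mutable.HashSet[String]())
        // Tasks may only add to the accumulable; reading seen.value inside a task throws,
        // as the "value not readable in tasks" test above demonstrates.
        sc.parallelize(Seq("a", "b", "a", "c"), 2).foreach(word => seen += word)
        println(seen.value) // driver-side read: the set {a, b, c}
        sc.stop()
      }
    }
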
diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
new file mode 100644
index 0000000000..b3a53d928b
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.FunSuite
+
+class BroadcastSuite extends FunSuite with LocalSparkContext {
+
+ test("basic broadcast") {
+ sc = new SparkContext("local", "test")
+ val list = List(1, 2, 3, 4)
+ val listBroadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum))
+ assert(results.collect.toSet === Set((1, 10), (2, 10)))
+ }
+
+ test("broadcast variables accessed in multiple threads") {
+ sc = new SparkContext("local[10]", "test")
+ val list = List(1, 2, 3, 4)
+ val listBroadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
+ assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
new file mode 100644
index 0000000000..23b14f4245
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -0,0 +1,392 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.FunSuite
+import java.io.File
+import org.apache.spark.rdd._
+import org.apache.spark.SparkContext._
+import storage.StorageLevel
+
+class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
+ initLogging()
+
+ var checkpointDir: File = _
+ val partitioner = new HashPartitioner(2)
+
+ override def beforeEach() {
+ super.beforeEach()
+ checkpointDir = File.createTempFile("temp", "")
+ checkpointDir.delete()
+ sc = new SparkContext("local", "test")
+ sc.setCheckpointDir(checkpointDir.toString)
+ }
+
+ override def afterEach() {
+ super.afterEach()
+ if (checkpointDir != null) {
+ checkpointDir.delete()
+ }
+ }
+
+ test("basic checkpointing") {
+ val parCollection = sc.makeRDD(1 to 4)
+ val flatMappedRDD = parCollection.flatMap(x => 1 to x)
+ flatMappedRDD.checkpoint()
+ assert(flatMappedRDD.dependencies.head.rdd == parCollection)
+ val result = flatMappedRDD.collect()
+ assert(flatMappedRDD.dependencies.head.rdd != parCollection)
+ assert(flatMappedRDD.collect() === result)
+ }
+
+ test("RDDs with one-to-one dependencies") {
+ testCheckpointing(_.map(x => x.toString))
+ testCheckpointing(_.flatMap(x => 1 to x))
+ testCheckpointing(_.filter(_ % 2 == 0))
+ testCheckpointing(_.sample(false, 0.5, 0))
+ testCheckpointing(_.glom())
+ testCheckpointing(_.mapPartitions(_.map(_.toString)))
+ testCheckpointing(r => new MapPartitionsWithIndexRDD(r,
+ (i: Int, iter: Iterator[Int]) => iter.map(_.toString), false ))
+ testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
+ testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
+ testCheckpointing(_.pipe(Seq("cat")))
+ }
+
+ test("ParallelCollection") {
+ val parCollection = sc.makeRDD(1 to 4, 2)
+ val numPartitions = parCollection.partitions.size
+ parCollection.checkpoint()
+ assert(parCollection.dependencies === Nil)
+ val result = parCollection.collect()
+ assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result)
+ assert(parCollection.dependencies != Nil)
+ assert(parCollection.partitions.length === numPartitions)
+ assert(parCollection.partitions.toList === parCollection.checkpointData.get.getPartitions.toList)
+ assert(parCollection.collect() === result)
+ }
+
+ test("BlockRDD") {
+ val blockId = "id"
+ val blockManager = SparkEnv.get.blockManager
+ blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY)
+ val blockRDD = new BlockRDD[String](sc, Array(blockId))
+ val numPartitions = blockRDD.partitions.size
+ blockRDD.checkpoint()
+ val result = blockRDD.collect()
+ assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result)
+ assert(blockRDD.dependencies != Nil)
+ assert(blockRDD.partitions.length === numPartitions)
+ assert(blockRDD.partitions.toList === blockRDD.checkpointData.get.getPartitions.toList)
+ assert(blockRDD.collect() === result)
+ }
+
+ test("ShuffledRDD") {
+ testCheckpointing(rdd => {
+ // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD
+ new ShuffledRDD[Int, Int, (Int, Int)](rdd.map(x => (x % 2, 1)), partitioner)
+ })
+ }
+
+ test("UnionRDD") {
+ def otherRDD = sc.makeRDD(1 to 10, 1)
+
+ // Test whether UnionRDD partitions shrink in serialized size after the parent RDD is checkpointed.
+ // The current implementation of UnionRDD holds transient references to its parent RDDs,
+ // so only the partitions shrink in serialized size, not the RDD itself.
+ testCheckpointing(_.union(otherRDD), false, true)
+ testParentCheckpointing(_.union(otherRDD), false, true)
+ }
+
+ test("CartesianRDD") {
+ def otherRDD = sc.makeRDD(1 to 10, 1)
+ testCheckpointing(new CartesianRDD(sc, _, otherRDD))
+
+ // Test whether the CartesianRDD shrinks in serialized size after the parent RDD is checkpointed.
+ // The current implementation of CartesianPartition holds transient references to the parent RDDs,
+ // so only the RDD shrinks in serialized size, not the partitions.
+ testParentCheckpointing(new CartesianRDD(sc, _, otherRDD), true, false)
+
+ // Test that the CartesianRDD updates its parent partitions (CartesianPartition.s1/s2) after
+ // the parent RDD has been checkpointed and parent partitions have been changed to HadoopPartitions.
+ // Note that this test is very specific to the current implementation of CartesianRDD.
+ val ones = sc.makeRDD(1 to 100, 10).map(x => x)
+ ones.checkpoint() // checkpoint that MappedRDD
+ val cartesian = new CartesianRDD(sc, ones, ones)
+ val splitBeforeCheckpoint =
+ serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition])
+ cartesian.count() // do the checkpointing
+ val splitAfterCheckpoint =
+ serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition])
+ assert(
+ (splitAfterCheckpoint.s1 != splitBeforeCheckpoint.s1) &&
+ (splitAfterCheckpoint.s2 != splitBeforeCheckpoint.s2),
+ "CartesianRDD.parents not updated after parent RDD checkpointed"
+ )
+ }
+
+ test("CoalescedRDD") {
+ testCheckpointing(_.coalesce(2))
+
+ // Test whether the CoalescedRDD shrinks in serialized size after the parent RDD is checkpointed.
+ // The current implementation of CoalescedRDDPartition holds a transient reference to the parent RDD,
+ // so only the RDD shrinks in serialized size, not the partitions.
+ testParentCheckpointing(_.coalesce(2), true, false)
+
+ // Test that the CoalescedRDDPartition updates parent partitions (CoalescedRDDPartition.parents) after
+ // the parent RDD has been checkpointed and parent partitions have been changed to HadoopPartitions.
+ // Note that this test is very specific to the current implementation of CoalescedRDDPartitions
+ val ones = sc.makeRDD(1 to 100, 10).map(x => x)
+ ones.checkpoint() // checkpoint that MappedRDD
+ val coalesced = new CoalescedRDD(ones, 2)
+ val splitBeforeCheckpoint =
+ serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition])
+ coalesced.count() // do the checkpointing
+ val splitAfterCheckpoint =
+ serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition])
+ assert(
+ splitAfterCheckpoint.parents.head != splitBeforeCheckpoint.parents.head,
+ "CoalescedRDDPartition.parents not updated after parent RDD checkpointed"
+ )
+ }
+
+ test("CoGroupedRDD") {
+ val longLineageRDD1 = generateLongLineageRDDForCoGroupedRDD()
+ testCheckpointing(rdd => {
+ CheckpointSuite.cogroup(longLineageRDD1, rdd.map(x => (x % 2, 1)), partitioner)
+ }, false, true)
+
+ val longLineageRDD2 = generateLongLineageRDDForCoGroupedRDD()
+ testParentCheckpointing(rdd => {
+ CheckpointSuite.cogroup(
+ longLineageRDD2, sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)), partitioner)
+ }, false, true)
+ }
+
+ test("ZippedRDD") {
+ testCheckpointing(
+ rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false)
+
+ // Test whether the ZippedRDD shrinks in serialized size after the parent RDD is checkpointed.
+ // The current implementation of ZippedRDD partitions holds transient references to the parent RDDs,
+ // so only the RDD shrinks in serialized size, not the partitions.
+ testParentCheckpointing(
+ rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false)
+ }
+
+ test("CheckpointRDD with zero partitions") {
+ val rdd = new BlockRDD[Int](sc, Array[String]())
+ assert(rdd.partitions.size === 0)
+ assert(rdd.isCheckpointed === false)
+ rdd.checkpoint()
+ assert(rdd.count() === 0)
+ assert(rdd.isCheckpointed === true)
+ assert(rdd.partitions.size === 0)
+ }
+
+ /**
+ * Test checkpointing of the final RDD generated by the given operation. By default,
+ * this method tests whether the serialized size of the RDD has been reduced after checkpointing.
+ * It can also test whether the serialized size of the RDD's partitions has been reduced, but this
+ * is not done by default, because the partitions usually do not refer to any RDD and therefore
+ * never store the lineage.
+ */
+ def testCheckpointing[U: ClassManifest](
+ op: (RDD[Int]) => RDD[U],
+ testRDDSize: Boolean = true,
+ testRDDPartitionSize: Boolean = false
+ ) {
+ // Generate the final RDD using given RDD operation
+ val baseRDD = generateLongLineageRDD()
+ val operatedRDD = op(baseRDD)
+ val parentRDD = operatedRDD.dependencies.headOption.orNull
+ val rddType = operatedRDD.getClass.getSimpleName
+ val numPartitions = operatedRDD.partitions.length
+
+ // Find serialized sizes before and after the checkpoint
+ val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
+ operatedRDD.checkpoint()
+ val result = operatedRDD.collect()
+ val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
+
+ // Test whether the checkpoint file has been created
+ assert(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get).collect() === result)
+
+ // Test whether dependencies have been changed from its earlier parent RDD
+ assert(operatedRDD.dependencies.head.rdd != parentRDD)
+
+ // Test whether the partitions have been changed to the new Hadoop partitions
+ assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList)
+
+ // Test whether the number of partitions is same as before
+ assert(operatedRDD.partitions.length === numPartitions)
+
+ // Test whether the data in the checkpointed RDD is same as original
+ assert(operatedRDD.collect() === result)
+
+ // Test whether serialized size of the RDD has reduced. If the RDD
+ // does not have any dependency to another RDD (e.g., ParallelCollection,
+ // ShuffleRDD with ShuffleDependency), it may not reduce in size after checkpointing.
+ if (testRDDSize) {
+ logInfo("Size of " + rddType +
+ "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]")
+ assert(
+ rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
+ "Size of " + rddType + " did not reduce after checkpointing " +
+ "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
+ )
+ }
+
+ // Test whether serialized size of the partitions has reduced. If the partitions
+ // do not have any non-transient reference to another RDD or another RDD's partitions, it
+ // does not refer to a lineage and therefore may not reduce in size after checkpointing.
+ // However, if the original partitions before checkpointing do refer to a parent RDD, the partitions
+ // must be forgotten after checkpointing (to remove all references to parent RDDs) and
+ // replaced with the HadoopPartitions of the checkpointed RDD.
+ if (testRDDPartitionSize) {
+ logInfo("Size of " + rddType + " partitions "
+ + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]")
+ assert(
+ splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint,
+ "Size of " + rddType + " partitions did not reduce after checkpointing " +
+ "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]"
+ )
+ }
+ }
+
+ /**
+ * Test whether checkpointing the parent of the generated RDD also truncates the lineage.
+ * Some RDDs, like CoGroupedRDD, hold on to their parent RDDs' partitions. So even if the
+ * parent RDD is checkpointed and its partitions changed, this RDD will remember the old
+ * partitions and therefore potentially the whole lineage.
+ */
+ def testParentCheckpointing[U: ClassManifest](
+ op: (RDD[Int]) => RDD[U],
+ testRDDSize: Boolean,
+ testRDDPartitionSize: Boolean
+ ) {
+ // Generate the final RDD using given RDD operation
+ val baseRDD = generateLongLineageRDD()
+ val operatedRDD = op(baseRDD)
+ val parentRDD = operatedRDD.dependencies.head.rdd
+ val rddType = operatedRDD.getClass.getSimpleName
+ val parentRDDType = parentRDD.getClass.getSimpleName
+
+ // Get the partitions and dependencies of the parent in case they're lazily computed
+ parentRDD.dependencies
+ parentRDD.partitions
+
+ // Find serialized sizes before and after the checkpoint
+ val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
+ parentRDD.checkpoint() // checkpoint the parent RDD, not the generated one
+ val result = operatedRDD.collect()
+ val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
+
+ // Test whether the data in the checkpointed RDD is same as original
+ assert(operatedRDD.collect() === result)
+
+ // Test whether serialized size of the RDD has reduced because of its parent being
+ // checkpointed. If this RDD or its parent RDD do not have any dependency
+ // to another RDD (e.g., ParallelCollection, ShuffleRDD with ShuffleDependency), it may
+ // not reduce in size after checkpointing.
+ if (testRDDSize) {
+ assert(
+ rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
+ "Size of " + rddType + " did not reduce after checkpointing parent " + parentRDDType +
+ "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
+ )
+ }
+
+ // Test whether serialized size of the partitions has reduced because of its parent being
+ // checkpointed. If the partitions do not have any non-transient reference to another RDD
+ // or another RDD's partitions, it does not refer to a lineage and therefore may not reduce
+ // in size after checkpointing. However, if the partitions do refer to the *partitions* of a parent
+ // RDD, then these partitions must update reference to the parent RDD partitions as the parent RDD's
+ // partitions must have changed after checkpointing.
+ if (testRDDPartitionSize) {
+ assert(
+ splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint,
+ "Size of " + rddType + " partitions did not reduce after checkpointing parent " + parentRDDType +
+ "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]"
+ )
+ }
+
+ }
+
+ /**
+ * Generate an RDD with a long lineage of one-to-one dependencies.
+ */
+ def generateLongLineageRDD(): RDD[Int] = {
+ var rdd = sc.makeRDD(1 to 100, 4)
+ for (i <- 1 to 50) {
+ rdd = rdd.map(x => x + 1)
+ }
+ rdd
+ }
+
+ /**
+ * Generate an RDD with a long lineage specifically for CoGroupedRDD.
+ * A CoGroupedRDD can have a long lineage only if one of its parents has a long lineage
+ * and a narrow dependency with this RDD. This method generates such an RDD by a sequence
+ * of cogroups and mapValues, which creates a long lineage of narrow dependencies.
+ */
+ def generateLongLineageRDDForCoGroupedRDD() = {
+ val add = (x: (Seq[Int], Seq[Int])) => (x._1 ++ x._2).reduce(_ + _)
+
+ def ones: RDD[(Int, Int)] = sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _)
+
+ var cogrouped: RDD[(Int, (Seq[Int], Seq[Int]))] = ones.cogroup(ones)
+ for(i <- 1 to 10) {
+ cogrouped = cogrouped.mapValues(add).cogroup(ones)
+ }
+ cogrouped.mapValues(add)
+ }
+
+ /**
+ * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks
+ * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint.
+ */
+ def getSerializedSizes(rdd: RDD[_]): (Int, Int) = {
+ (Utils.serialize(rdd).length - Utils.serialize(rdd.checkpointData).length,
+ Utils.serialize(rdd.partitions).length)
+ }
+
+ /**
+ * Serialize and deserialize an object. This is useful to verify the object's
+ * contents after deserialization (e.g., the contents of an RDD split after
+ * it is sent to a slave along with a task)
+ */
+ def serializeDeserialize[T](obj: T): T = {
+ val bytes = Utils.serialize(obj)
+ Utils.deserialize[T](bytes)
+ }
+}
+
+
+object CheckpointSuite {
+ // This is a custom cogroup function that does not use mapValues like
+ // the PairRDDFunctions.cogroup()
+ def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = {
+ //println("First = " + first + ", second = " + second)
+ new CoGroupedRDD[K](
+ Seq(first.asInstanceOf[RDD[(K, _)]], second.asInstanceOf[RDD[(K, _)]]),
+ part
+ ).asInstanceOf[RDD[(K, Seq[Seq[V]])]]
+ }
+
+}
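
The helpers above verify lineage truncation indirectly, by comparing serialized sizes of the RDD and its partitions before and after checkpointing. The sketch below shows the underlying behaviour the suite depends on, built only from calls that appear in the suite; the object name, master string, and temp-directory handling are illustrative.

    import java.io.File
    import org.apache.spark.SparkContext

    object CheckpointSketch {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "checkpoint-sketch")
        val dir = File.createTempFile("checkpoint", "")
        dir.delete()
        sc.setCheckpointDir(dir.toString)

        // Build a long lineage of one-to-one dependencies, like generateLongLineageRDD above.
        var rdd = sc.makeRDD(1 to 100, 4)
        for (i <- 1 to 50) rdd = rdd.map(_ + 1)

        rdd.checkpoint()                     // mark for checkpointing; nothing is written yet
        val parentBefore = rdd.dependencies.head.rdd
        rdd.count()                          // the first action materialises the checkpoint files
        val parentAfter = rdd.dependencies.head.rdd

        // Once checkpointed, the RDD reads from the checkpoint files instead of recomputing
        // the 50-step lineage, so its parent dependency has been replaced.
        println(rdd.isCheckpointed)          // true
        println(parentBefore == parentAfter) // false: the lineage was truncated
        sc.stop()
      }
    }
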
diff --git a/core/src/test/scala/org/apache/spark/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ClosureCleanerSuite.scala
new file mode 100644
index 0000000000..8494899b98
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ClosureCleanerSuite.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io.NotSerializableException
+
+import org.scalatest.FunSuite
+import org.apache.spark.LocalSparkContext._
+import SparkContext._
+
+class ClosureCleanerSuite extends FunSuite {
+ test("closures inside an object") {
+ assert(TestObject.run() === 30) // 6 + 7 + 8 + 9
+ }
+
+ test("closures inside a class") {
+ val obj = new TestClass
+ assert(obj.run() === 30) // 6 + 7 + 8 + 9
+ }
+
+ test("closures inside a class with no default constructor") {
+ val obj = new TestClassWithoutDefaultConstructor(5)
+ assert(obj.run() === 30) // 6 + 7 + 8 + 9
+ }
+
+ test("closures that don't use fields of the outer class") {
+ val obj = new TestClassWithoutFieldAccess
+ assert(obj.run() === 30) // 6 + 7 + 8 + 9
+ }
+
+ test("nested closures inside an object") {
+ assert(TestObjectWithNesting.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1
+ }
+
+ test("nested closures inside a class") {
+ val obj = new TestClassWithNesting(1)
+ assert(obj.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1
+ }
+}
+
+// A non-serializable class we create in closures to make sure that we aren't
+// keeping references to unneeded variables from our outer closures.
+class NonSerializable {}
+
+object TestObject {
+ def run(): Int = {
+ var nonSer = new NonSerializable
+ var x = 5
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ nums.map(_ + x).reduce(_ + _)
+ }
+ }
+}
+
+class TestClass extends Serializable {
+ var x = 5
+
+ def getX = x
+
+ def run(): Int = {
+ var nonSer = new NonSerializable
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ nums.map(_ + getX).reduce(_ + _)
+ }
+ }
+}
+
+class TestClassWithoutDefaultConstructor(x: Int) extends Serializable {
+ def getX = x
+
+ def run(): Int = {
+ var nonSer = new NonSerializable
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ nums.map(_ + getX).reduce(_ + _)
+ }
+ }
+}
+
+// This class is not serializable, but we aren't using any of its fields in our
+// closures, so they won't have a $outer pointing to it and should still work.
+class TestClassWithoutFieldAccess {
+ var nonSer = new NonSerializable
+
+ def run(): Int = {
+ var nonSer2 = new NonSerializable
+ var x = 5
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ nums.map(_ + x).reduce(_ + _)
+ }
+ }
+}
+
+
+object TestObjectWithNesting {
+ def run(): Int = {
+ var nonSer = new NonSerializable
+ var answer = 0
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ var y = 1
+ for (i <- 1 to 4) {
+ var nonSer2 = new NonSerializable
+ var x = i
+ answer += nums.map(_ + x + y).reduce(_ + _)
+ }
+ answer
+ }
+ }
+}
+
+class TestClassWithNesting(val y: Int) extends Serializable {
+ def getY = y
+
+ def run(): Int = {
+ var nonSer = new NonSerializable
+ var answer = 0
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ for (i <- 1 to 4) {
+ var nonSer2 = new NonSerializable
+ var x = i
+ answer += nums.map(_ + x + getY).reduce(_ + _)
+ }
+ answer
+ }
+ }
+}
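
The test classes above probe how Scala closures capture their enclosing scope: a lambda that reads a field of the enclosing class carries a $outer reference to it, which the ClosureCleaner can only strip when the field is not actually used. The sketch below shows the usual manual workaround of copying the needed field into a local val so only that value is captured; the class and object names are illustrative.

    import org.apache.spark.SparkContext

    // Not serializable, like NonSerializable in the suite above.
    class Resource { val offset = 10 }

    class OffsetJob(resource: Resource) {
      def run(sc: SparkContext): Int = {
        // Referring to resource.offset inside the closure would capture `this`, dragging the
        // non-serializable OffsetJob/Resource pair into every task. Copying the field into a
        // local val first means the closure captures only an Int and serializes cleanly.
        val localOffset = resource.offset
        sc.parallelize(Array(1, 2, 3, 4)).map(_ + localOffset).reduce(_ + _)
      }
    }

    object ClosureCaptureSketch {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "closure-sketch")
        println(new OffsetJob(new Resource).run(sc)) // (1+2+3+4) + 4*10 = 50
        sc.stop()
      }
    }
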
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
new file mode 100644
index 0000000000..7a856d4081
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -0,0 +1,362 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import network.ConnectionManagerId
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+import org.scalatest.concurrent.Timeouts._
+import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.prop.Checkers
+import org.scalatest.time.{Span, Millis}
+import org.scalacheck.Arbitrary._
+import org.scalacheck.Gen
+import org.scalacheck.Prop._
+import org.eclipse.jetty.server.{Server, Request, Handler}
+
+import com.google.common.io.Files
+
+import scala.collection.mutable.ArrayBuffer
+
+import SparkContext._
+import storage.{GetBlock, BlockManagerWorker, StorageLevel}
+import ui.JettyUtils
+
+
+class NotSerializableClass
+class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {}
+
+
+class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
+ with LocalSparkContext {
+
+ val clusterUrl = "local-cluster[2,1,512]"
+
+ after {
+ System.clearProperty("spark.reducer.maxMbInFlight")
+ System.clearProperty("spark.storage.memoryFraction")
+ }
+
+ test("task throws not serializable exception") {
+ // Ensures that executors do not crash when an exn is not serializable. If executors crash,
+ // this test will hang. Correct behavior is that executors don't crash but fail tasks
+ // and the scheduler throws a SparkException.
+
+ // numSlaves must be less than numPartitions
+ val numSlaves = 3
+ val numPartitions = 10
+
+ sc = new SparkContext("local-cluster[%s,1,512]".format(numSlaves), "test")
+ val data = sc.parallelize(1 to 100, numPartitions).
+ map(x => throw new NotSerializableExn(new NotSerializableClass))
+ intercept[SparkException] {
+ data.count()
+ }
+ resetSparkContext()
+ }
+
+ test("local-cluster format") {
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
+ assert(sc.parallelize(1 to 2, 2).count() == 2)
+ resetSparkContext()
+ sc = new SparkContext("local-cluster[2 , 1 , 512]", "test")
+ assert(sc.parallelize(1 to 2, 2).count() == 2)
+ resetSparkContext()
+ sc = new SparkContext("local-cluster[2, 1, 512]", "test")
+ assert(sc.parallelize(1 to 2, 2).count() == 2)
+ resetSparkContext()
+ sc = new SparkContext("local-cluster[ 2, 1, 512 ]", "test")
+ assert(sc.parallelize(1 to 2, 2).count() == 2)
+ resetSparkContext()
+ }
+
+ test("simple groupByKey") {
+ sc = new SparkContext(clusterUrl, "test")
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 5)
+ val groups = pairs.groupByKey(5).collect()
+ assert(groups.size === 2)
+ val valuesFor1 = groups.find(_._1 == 1).get._2
+ assert(valuesFor1.toList.sorted === List(1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("groupByKey where map output sizes exceed maxMbInFlight") {
+ System.setProperty("spark.reducer.maxMbInFlight", "1")
+ sc = new SparkContext(clusterUrl, "test")
+ // This data should be around 20 MB, so even with 4 mappers and 2 reducers, each map output
+ // file should be about 2.5 MB
+ val pairs = sc.parallelize(1 to 2000, 4).map(x => (x % 16, new Array[Byte](10000)))
+ val groups = pairs.groupByKey(2).map(x => (x._1, x._2.size)).collect()
+ assert(groups.length === 16)
+ assert(groups.map(_._2).sum === 2000)
+ // Note that spark.reducer.maxMbInFlight will be cleared in the test suite's after{} block
+ }
+
+ test("accumulators") {
+ sc = new SparkContext(clusterUrl, "test")
+ val accum = sc.accumulator(0)
+ sc.parallelize(1 to 10, 10).foreach(x => accum += x)
+ assert(accum.value === 55)
+ }
+
+ test("broadcast variables") {
+ sc = new SparkContext(clusterUrl, "test")
+ val array = new Array[Int](100)
+ val bv = sc.broadcast(array)
+ array(2) = 3 // Change the array -- this should not be seen on workers
+ val rdd = sc.parallelize(1 to 10, 10)
+ val sum = rdd.map(x => bv.value.sum).reduce(_ + _)
+ assert(sum === 0)
+ }
+
+ test("repeatedly failing task") {
+ sc = new SparkContext(clusterUrl, "test")
+ val accum = sc.accumulator(0)
+ val thrown = intercept[SparkException] {
+ sc.parallelize(1 to 10, 10).foreach(x => println(x / 0))
+ }
+ assert(thrown.getClass === classOf[SparkException])
+ assert(thrown.getMessage.contains("more than 4 times"))
+ }
+
+ test("caching") {
+ sc = new SparkContext(clusterUrl, "test")
+ val data = sc.parallelize(1 to 1000, 10).cache()
+ assert(data.count() === 1000)
+ assert(data.count() === 1000)
+ assert(data.count() === 1000)
+ }
+
+ test("caching on disk") {
+ sc = new SparkContext(clusterUrl, "test")
+ val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.DISK_ONLY)
+ assert(data.count() === 1000)
+ assert(data.count() === 1000)
+ assert(data.count() === 1000)
+ }
+
+ test("caching in memory, replicated") {
+ sc = new SparkContext(clusterUrl, "test")
+ val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_ONLY_2)
+ assert(data.count() === 1000)
+ assert(data.count() === 1000)
+ assert(data.count() === 1000)
+ }
+
+ test("caching in memory, serialized, replicated") {
+ sc = new SparkContext(clusterUrl, "test")
+ val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_ONLY_SER_2)
+ assert(data.count() === 1000)
+ assert(data.count() === 1000)
+ assert(data.count() === 1000)
+ }
+
+ test("caching on disk, replicated") {
+ sc = new SparkContext(clusterUrl, "test")
+ val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.DISK_ONLY_2)
+ assert(data.count() === 1000)
+ assert(data.count() === 1000)
+ assert(data.count() === 1000)
+ }
+
+ test("caching in memory and disk, replicated") {
+ sc = new SparkContext(clusterUrl, "test")
+ val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_AND_DISK_2)
+ assert(data.count() === 1000)
+ assert(data.count() === 1000)
+ assert(data.count() === 1000)
+ }
+
+ test("caching in memory and disk, serialized, replicated") {
+ sc = new SparkContext(clusterUrl, "test")
+ val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_AND_DISK_SER_2)
+
+ assert(data.count() === 1000)
+ assert(data.count() === 1000)
+ assert(data.count() === 1000)
+
+ // Get all the locations of the first partition and try to fetch the partitions
+ // from those locations.
+ val blockIds = data.partitions.indices.map(index => "rdd_%d_%d".format(data.id, index)).toArray
+ val blockId = blockIds(0)
+ val blockManager = SparkEnv.get.blockManager
+ blockManager.master.getLocations(blockId).foreach(id => {
+ val bytes = BlockManagerWorker.syncGetBlock(
+ GetBlock(blockId), ConnectionManagerId(id.host, id.port))
+ val deserialized = blockManager.dataDeserialize(blockId, bytes).asInstanceOf[Iterator[Int]].toList
+ assert(deserialized === (1 to 100).toList)
+ })
+ }
+
+ test("compute without caching when no partitions fit in memory") {
+ System.setProperty("spark.storage.memoryFraction", "0.0001")
+ sc = new SparkContext(clusterUrl, "test")
+ // data will be 4 million * 4 bytes = 16 MB in size, but our memoryFraction set the cache
+ // to only 50 KB (0.0001 of 512 MB), so no partitions should fit in memory
+ val data = sc.parallelize(1 to 4000000, 2).persist(StorageLevel.MEMORY_ONLY_SER)
+ assert(data.count() === 4000000)
+ assert(data.count() === 4000000)
+ assert(data.count() === 4000000)
+ System.clearProperty("spark.storage.memoryFraction")
+ }
+
+ test("compute when only some partitions fit in memory") {
+ System.setProperty("spark.storage.memoryFraction", "0.01")
+ sc = new SparkContext(clusterUrl, "test")
+ // data will be 4 million * 4 bytes = 16 MB in size, but our memoryFraction set the cache
+ // to only 5 MB (0.01 of 512 MB), so not all of it will fit in memory; we use 20 partitions
+ // to make sure that *some* of them do fit though
+ val data = sc.parallelize(1 to 4000000, 20).persist(StorageLevel.MEMORY_ONLY_SER)
+ assert(data.count() === 4000000)
+ assert(data.count() === 4000000)
+ assert(data.count() === 4000000)
+ System.clearProperty("spark.storage.memoryFraction")
+ }
+
+ test("passing environment variables to cluster") {
+ sc = new SparkContext(clusterUrl, "test", null, Nil, Map("TEST_VAR" -> "TEST_VALUE"))
+ val values = sc.parallelize(1 to 2, 2).map(x => System.getenv("TEST_VAR")).collect()
+ assert(values.toSeq === Seq("TEST_VALUE", "TEST_VALUE"))
+ }
+
+ test("recover from node failures") {
+ import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity}
+ DistributedSuite.amMaster = true
+ sc = new SparkContext(clusterUrl, "test")
+ val data = sc.parallelize(Seq(true, true), 2)
+ assert(data.count === 2) // force executors to start
+ assert(data.map(markNodeIfIdentity).collect.size === 2)
+ assert(data.map(failOnMarkedIdentity).collect.size === 2)
+ }
+
+ test("recover from repeated node failures during shuffle-map") {
+ import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity}
+ DistributedSuite.amMaster = true
+ sc = new SparkContext(clusterUrl, "test")
+ for (i <- 1 to 3) {
+ val data = sc.parallelize(Seq(true, false), 2)
+ assert(data.count === 2)
+ assert(data.map(markNodeIfIdentity).collect.size === 2)
+ assert(data.map(failOnMarkedIdentity).map(x => x -> x).groupByKey.count === 2)
+ }
+ }
+
+ test("recover from repeated node failures during shuffle-reduce") {
+ import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity}
+ DistributedSuite.amMaster = true
+ sc = new SparkContext(clusterUrl, "test")
+ for (i <- 1 to 3) {
+ val data = sc.parallelize(Seq(true, true), 2)
+ assert(data.count === 2)
+ assert(data.map(markNodeIfIdentity).collect.size === 2)
+ // This relies on mergeCombiners being used to perform the actual reduce for this
+ // test to actually be testing what it claims.
+ val grouped = data.map(x => x -> x).combineByKey(
+ x => x,
+ (x: Boolean, y: Boolean) => x,
+ (x: Boolean, y: Boolean) => failOnMarkedIdentity(x)
+ )
+ assert(grouped.collect.size === 1)
+ }
+ }
+
+ test("recover from node failures with replication") {
+ import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity}
+ DistributedSuite.amMaster = true
+ // Use more than two nodes so that we don't have a symmetric communication pattern, which
+ // might cache a partially correct list of peers.
+ sc = new SparkContext("local-cluster[3,1,512]", "test")
+ for (i <- 1 to 3) {
+ val data = sc.parallelize(Seq(true, false, false, false), 4)
+ data.persist(StorageLevel.MEMORY_ONLY_2)
+
+ assert(data.count === 4)
+ assert(data.map(markNodeIfIdentity).collect.size === 4)
+ assert(data.map(failOnMarkedIdentity).collect.size === 4)
+
+ // Create a new replicated RDD to make sure that cached peer information doesn't cause
+ // problems.
+ val data2 = sc.parallelize(Seq(true, true), 2).persist(StorageLevel.MEMORY_ONLY_2)
+ assert(data2.count === 2)
+ }
+ }
+
+ test("unpersist RDDs") {
+ DistributedSuite.amMaster = true
+ sc = new SparkContext("local-cluster[3,1,512]", "test")
+ val data = sc.parallelize(Seq(true, false, false, false), 4)
+ data.persist(StorageLevel.MEMORY_ONLY_2)
+ data.count
+ assert(sc.persistentRdds.isEmpty === false)
+ data.unpersist()
+ assert(sc.persistentRdds.isEmpty === true)
+
+ failAfter(Span(3000, Millis)) {
+ try {
+ while (! sc.getRDDStorageInfo.isEmpty) {
+ Thread.sleep(200)
+ }
+ } catch {
+ case _ => { Thread.sleep(10) }
+ // Do nothing. We might see exceptions because block manager
+ // is racing this thread to remove entries from the driver.
+ }
+ }
+ }
+
+ test("job should fail if TaskResult exceeds Akka frame size") {
+ // We must use local-cluster mode since results are returned differently
+ // when running under LocalScheduler:
+ sc = new SparkContext("local-cluster[1,1,512]", "test")
+ val akkaFrameSize =
+ sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt
+ val rdd = sc.parallelize(Seq(1)).map{x => new Array[Byte](akkaFrameSize)}
+ val exception = intercept[SparkException] {
+ rdd.reduce((x, y) => x)
+ }
+ exception.getMessage should endWith("result exceeded Akka frame size")
+ }
+}
+
+object DistributedSuite {
+ // Indicates whether this JVM is marked for failure.
+ var mark = false
+
+ // Set by test to remember if we are in the driver program so we can assert
+ // that we are not.
+ var amMaster = false
+
+ // Act like an identity function, but if the argument is true, set mark to true.
+ def markNodeIfIdentity(item: Boolean): Boolean = {
+ if (item) {
+ assert(!amMaster)
+ mark = true
+ }
+ item
+ }
+
+ // Act like an identity function, but if mark was set to true previously, fail,
+ // crashing the entire JVM.
+ def failOnMarkedIdentity(item: Boolean): Boolean = {
+ if (mark) {
+ System.exit(42)
+ }
+ item
+ }
+}
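
The shuffle-reduce recovery test above leans on the three-argument form of combineByKey: createCombiner seeds a per-key value, mergeValue folds further values within a partition, and mergeCombiners merges partial results across partitions. The sketch below uses the same API to compute per-key averages, which makes each function's role explicit; the names and sample data are illustrative.

    import org.apache.spark.SparkContext
    import org.apache.spark.SparkContext._ // brings combineByKey in via the pair-RDD implicits

    object CombineByKeySketch {
      def main(args: Array[String]) {
        val sc = new SparkContext("local[2]", "combine-sketch")
        val scores = sc.parallelize(Seq(("a", 1), ("a", 3), ("b", 4), ("b", 6), ("b", 8)), 2)

        // Accumulate (sum, count) per key:
        //   createCombiner  - wraps the first value seen for a key within a partition
        //   mergeValue      - folds another value into that partition's (sum, count)
        //   mergeCombiners  - merges partial (sum, count) pairs from different partitions
        val sumCounts = scores.combineByKey(
          (v: Int) => (v, 1),
          (acc: (Int, Int), v: Int) => (acc._1 + v, acc._2 + 1),
          (a: (Int, Int), b: (Int, Int)) => (a._1 + b._1, a._2 + b._2)
        )

        val averages = sumCounts.map { case (k, (sum, count)) => (k, sum.toDouble / count) }
        println(averages.collect().toMap) // Map(a -> 2.0, b -> 6.0)
        sc.stop()
      }
    }
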
diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala
new file mode 100644
index 0000000000..b08aad1a6f
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io.File
+
+import org.apache.log4j.Logger
+import org.apache.log4j.Level
+
+import org.scalatest.FunSuite
+import org.scalatest.concurrent.Timeouts
+import org.scalatest.prop.TableDrivenPropertyChecks._
+import org.scalatest.time.SpanSugar._
+
+class DriverSuite extends FunSuite with Timeouts {
+ test("driver should exit after finishing") {
+ assert(System.getenv("SPARK_HOME") != null)
+ // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing"
+ val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]"))
+ forAll(masters) { (master: String) =>
+ failAfter(30 seconds) {
+ Utils.execute(Seq("./spark-class", "org.apache.spark.DriverWithoutCleanup", master),
+ new File(System.getenv("SPARK_HOME")))
+ }
+ }
+ }
+}
+
+/**
+ * Program that creates a Spark driver but doesn't call SparkContext.stop() or
+ * System.exit() after finishing.
+ */
+object DriverWithoutCleanup {
+ def main(args: Array[String]) {
+ Logger.getRootLogger().setLevel(Level.WARN)
+ val sc = new SparkContext(args(0), "DriverWithoutCleanup")
+ sc.parallelize(1 to 100, 4).count()
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
new file mode 100644
index 0000000000..ee89a7a387
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.FunSuite
+
+import SparkContext._
+
+// Common state shared by FailureSuite-launched tasks. We use a global object
+// for this because any local variables used in the task closures will rightfully
+// be copied for each task, so there's no other way for them to share state.
+object FailureSuiteState {
+ var tasksRun = 0
+ var tasksFailed = 0
+
+ def clear() {
+ synchronized {
+ tasksRun = 0
+ tasksFailed = 0
+ }
+ }
+}
+
+class FailureSuite extends FunSuite with LocalSparkContext {
+
+ // Run a 3-task map job in which task 1 deterministically fails once, and check
+ // whether the job completes successfully and we ran 4 tasks in total.
+ test("failure in a single-stage job") {
+ sc = new SparkContext("local[1,1]", "test")
+ val results = sc.makeRDD(1 to 3, 3).map { x =>
+ FailureSuiteState.synchronized {
+ FailureSuiteState.tasksRun += 1
+ if (x == 1 && FailureSuiteState.tasksFailed == 0) {
+ FailureSuiteState.tasksFailed += 1
+ throw new Exception("Intentional task failure")
+ }
+ }
+ x * x
+ }.collect()
+ FailureSuiteState.synchronized {
+ assert(FailureSuiteState.tasksRun === 4)
+ }
+ assert(results.toList === List(1,4,9))
+ FailureSuiteState.clear()
+ }
+
+ // Run a map-reduce job in which a reduce task deterministically fails once.
+ test("failure in a two-stage job") {
+ sc = new SparkContext("local[1,1]", "test")
+ val results = sc.makeRDD(1 to 3).map(x => (x, x)).groupByKey(3).map {
+ case (k, v) =>
+ FailureSuiteState.synchronized {
+ FailureSuiteState.tasksRun += 1
+ if (k == 1 && FailureSuiteState.tasksFailed == 0) {
+ FailureSuiteState.tasksFailed += 1
+ throw new Exception("Intentional task failure")
+ }
+ }
+ (k, v(0) * v(0))
+ }.collect()
+ FailureSuiteState.synchronized {
+ assert(FailureSuiteState.tasksRun === 4)
+ }
+ assert(results.toSet === Set((1, 1), (2, 4), (3, 9)))
+ FailureSuiteState.clear()
+ }
+
+ test("failure because task results are not serializable") {
+ sc = new SparkContext("local[1,1]", "test")
+ val results = sc.makeRDD(1 to 3).map(x => new NonSerializable)
+
+ val thrown = intercept[SparkException] {
+ results.collect()
+ }
+ assert(thrown.getClass === classOf[SparkException])
+ assert(thrown.getMessage.contains("NotSerializableException"))
+
+ FailureSuiteState.clear()
+ }
+
+ test("failure because task closure is not serializable") {
+ sc = new SparkContext("local[1,1]", "test")
+ val a = new NonSerializable
+
+ // Non-serializable closure in the final result stage
+ val thrown = intercept[SparkException] {
+ sc.parallelize(1 to 10, 2).map(x => a).count()
+ }
+ assert(thrown.getClass === classOf[SparkException])
+ assert(thrown.getMessage.contains("NotSerializableException"))
+
+ // Non-serializable closure in an earlier stage
+ val thrown1 = intercept[SparkException] {
+ sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count()
+ }
+ assert(thrown1.getClass === classOf[SparkException])
+ assert(thrown1.getMessage.contains("NotSerializableException"))
+
+ // Non-serializable closure in foreach function
+ val thrown2 = intercept[SparkException] {
+ sc.parallelize(1 to 10, 2).foreach(x => println(a))
+ }
+ assert(thrown2.getClass === classOf[SparkException])
+ assert(thrown2.getMessage.contains("NotSerializableException"))
+
+ FailureSuiteState.clear()
+ }
+
+ // TODO: Need to add tests with shuffle fetch failures.
+}
+
+
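
FailureSuiteState above is a JVM-wide singleton because any local variable referenced by a task closure is serialized into each task; increments made on such a copied local would never be visible to the driver or to other tasks. The sketch below applies the same pattern to count attempts around one deterministic failure, relying on the local[1,1] master string exactly as the suite does; the object and value names are illustrative.

    import org.apache.spark.SparkContext

    // JVM-wide shared state, mirroring FailureSuiteState above. Replacing this object with a
    // method-local `var attempts = 0` would not work: each task mutates its own deserialized
    // copy of the closure, so the driver's copy would stay at zero.
    object FlakyState {
      var tasksRun = 0
      var tasksFailed = 0
    }

    object RetrySketch {
      def main(args: Array[String]) {
        // "local[1,1]": the master string the suite uses so that one deterministic task
        // failure is retried instead of failing the job.
        val sc = new SparkContext("local[1,1]", "retry-sketch")
        val result = sc.parallelize(1 to 3, 3).map { x =>
          FlakyState.synchronized {
            FlakyState.tasksRun += 1
            if (x == 2 && FlakyState.tasksFailed == 0) {
              FlakyState.tasksFailed += 1
              throw new Exception("Simulated transient failure")
            }
          }
          x * 10
        }.collect()
        println(result.toList)       // List(10, 20, 30)
        println(FlakyState.tasksRun) // 4: three tasks plus one retry (local mode shares the JVM)
        sc.stop()
      }
    }
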
diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
new file mode 100644
index 0000000000..35d1d41af1
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import com.google.common.io.Files
+import org.scalatest.FunSuite
+import java.io.{File, PrintWriter, FileReader, BufferedReader}
+import SparkContext._
+
+class FileServerSuite extends FunSuite with LocalSparkContext {
+
+ @transient var tmpFile: File = _
+ @transient var testJarFile: File = _
+
+ override def beforeEach() {
+ super.beforeEach()
+ // Create a sample text file
+ val tmpdir = new File(Files.createTempDir(), "test")
+ tmpdir.mkdir()
+ tmpFile = new File(tmpdir, "FileServerSuite.txt")
+ val pw = new PrintWriter(tmpFile)
+ pw.println("100")
+ pw.close()
+ }
+
+ override def afterEach() {
+ super.afterEach()
+ // Clean up downloaded file
+ if (tmpFile.exists) {
+ tmpFile.delete()
+ }
+ }
+
+ test("Distributing files locally") {
+ sc = new SparkContext("local[4]", "test")
+ sc.addFile(tmpFile.toString)
+ val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
+ val result = sc.parallelize(testData).reduceByKey {
+ val path = SparkFiles.get("FileServerSuite.txt")
+ val in = new BufferedReader(new FileReader(path))
+ val fileVal = in.readLine().toInt
+ in.close()
+ _ * fileVal + _ * fileVal
+ }.collect()
+ assert(result.toSet === Set((1,200), (2,300), (3,500)))
+ }
+
+ test("Distributing files locally using URL as input") {
+ // addFile("file:///....")
+ sc = new SparkContext("local[4]", "test")
+ sc.addFile(new File(tmpFile.toString).toURI.toString)
+ val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
+ val result = sc.parallelize(testData).reduceByKey {
+ val path = SparkFiles.get("FileServerSuite.txt")
+ val in = new BufferedReader(new FileReader(path))
+ val fileVal = in.readLine().toInt
+ in.close()
+ _ * fileVal + _ * fileVal
+ }.collect()
+ assert(result.toSet === Set((1,200), (2,300), (3,500)))
+ }
+
+ test ("Dynamically adding JARS locally") {
+ sc = new SparkContext("local[4]", "test")
+ val sampleJarFile = getClass.getClassLoader.getResource("uncommons-maths-1.2.2.jar").getFile()
+ sc.addJar(sampleJarFile)
+ val testData = Array((1,1), (1,1), (2,1), (3,5), (2,3), (3,0))
+ val result = sc.parallelize(testData).reduceByKey { (x,y) =>
+ val fac = Thread.currentThread.getContextClassLoader()
+ .loadClass("org.uncommons.maths.Maths")
+ .getDeclaredMethod("factorial", classOf[Int])
+ val a = fac.invoke(null, x.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt
+ val b = fac.invoke(null, y.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt
+ a + b
+ }.collect()
+ assert(result.toSet === Set((1,2), (2,7), (3,121)))
+ }
+
+ test("Distributing files on a standalone cluster") {
+ sc = new SparkContext("local-cluster[1,1,512]", "test")
+ sc.addFile(tmpFile.toString)
+ val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
+ val result = sc.parallelize(testData).reduceByKey {
+ val path = SparkFiles.get("FileServerSuite.txt")
+ val in = new BufferedReader(new FileReader(path))
+ val fileVal = in.readLine().toInt
+ in.close()
+ _ * fileVal + _ * fileVal
+ }.collect()
+ assert(result.toSet === Set((1,200), (2,300), (3,500)))
+ }
+
+ test ("Dynamically adding JARS on a standalone cluster") {
+ sc = new SparkContext("local-cluster[1,1,512]", "test")
+ val sampleJarFile = getClass.getClassLoader.getResource("uncommons-maths-1.2.2.jar").getFile()
+ sc.addJar(sampleJarFile)
+ val testData = Array((1,1), (1,1), (2,1), (3,5), (2,3), (3,0))
+ val result = sc.parallelize(testData).reduceByKey { (x,y) =>
+ val fac = Thread.currentThread.getContextClassLoader()
+ .loadClass("org.uncommons.maths.Maths")
+ .getDeclaredMethod("factorial", classOf[Int])
+ val a = fac.invoke(null, x.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt
+ val b = fac.invoke(null, y.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt
+ a + b
+ }.collect()
+ assert(result.toSet === Set((1,2), (2,7), (3,121)))
+ }
+}
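
The suite above ships side files to executors with sc.addFile and resolves them on the worker through SparkFiles.get, keyed by file name rather than by the driver-side path. Below is a stripped-down sketch of that flow; the file contents, names, and master string are illustrative.

    import java.io.{BufferedReader, File, FileReader, PrintWriter}
    import org.apache.spark.{SparkContext, SparkFiles}

    object AddFileSketch {
      def main(args: Array[String]) {
        // Write a small lookup file on the driver.
        val lookup = File.createTempFile("multiplier", ".txt")
        val pw = new PrintWriter(lookup)
        pw.println("100")
        pw.close()

        val sc = new SparkContext("local[2]", "addfile-sketch")
        sc.addFile(lookup.toString)   // ship the file to every executor
        val fileName = lookup.getName // capture only the name in the closure below

        val scaled = sc.parallelize(1 to 4, 2).map { x =>
          // On the worker, resolve the shipped copy by name, not by the driver-side path.
          val in = new BufferedReader(new FileReader(SparkFiles.get(fileName)))
          val factor = in.readLine().toInt
          in.close()
          x * factor
        }.collect()

        println(scaled.toList) // List(100, 200, 300, 400)
        sc.stop()
      }
    }
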
diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala
new file mode 100644
index 0000000000..7b82a4cdd9
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/FileSuite.scala
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io.{FileWriter, PrintWriter, File}
+
+import scala.io.Source
+
+import com.google.common.io.Files
+import org.scalatest.FunSuite
+import org.apache.hadoop.io._
+import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodec, GzipCodec}
+
+
+import SparkContext._
+
+class FileSuite extends FunSuite with LocalSparkContext {
+
+ test("text files") {
+ sc = new SparkContext("local", "test")
+ val tempDir = Files.createTempDir()
+ val outputDir = new File(tempDir, "output").getAbsolutePath
+ val nums = sc.makeRDD(1 to 4)
+ nums.saveAsTextFile(outputDir)
+ // Read the plain text file and check it's OK
+ val outputFile = new File(outputDir, "part-00000")
+ val content = Source.fromFile(outputFile).mkString
+ assert(content === "1\n2\n3\n4\n")
+ // Also try reading it in as a text file RDD
+ assert(sc.textFile(outputDir).collect().toList === List("1", "2", "3", "4"))
+ }
+
+ test("text files (compressed)") {
+ sc = new SparkContext("local", "test")
+ val tempDir = Files.createTempDir()
+ val normalDir = new File(tempDir, "output_normal").getAbsolutePath
+ val compressedOutputDir = new File(tempDir, "output_compressed").getAbsolutePath
+ val codec = new DefaultCodec()
+
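+ // "a" * 10000 is a 10,000-character string; parallelizing it yields an RDD of Chars,
+ // so the saved text file contains one "a" per line.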
+ val data = sc.parallelize("a" * 10000, 1)
+ data.saveAsTextFile(normalDir)
+ data.saveAsTextFile(compressedOutputDir, classOf[DefaultCodec])
+
+ val normalFile = new File(normalDir, "part-00000")
+ val normalContent = sc.textFile(normalDir).collect
+ assert(normalContent === Array.fill(10000)("a"))
+
+ val compressedFile = new File(compressedOutputDir, "part-00000" + codec.getDefaultExtension)
+ val compressedContent = sc.textFile(compressedOutputDir).collect
+ assert(compressedContent === Array.fill(10000)("a"))
+
+ assert(compressedFile.length < normalFile.length)
+ }
+
+ test("SequenceFiles") {
+ sc = new SparkContext("local", "test")
+ val tempDir = Files.createTempDir()
+ val outputDir = new File(tempDir, "output").getAbsolutePath
+ val nums = sc.makeRDD(1 to 3).map(x => (x, "a" * x)) // (1,a), (2,aa), (3,aaa)
+ nums.saveAsSequenceFile(outputDir)
+ // Try reading the output back as a SequenceFile
+ val output = sc.sequenceFile[IntWritable, Text](outputDir)
+ assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
+ }
+
+ test("SequenceFile (compressed)") {
+ sc = new SparkContext("local", "test")
+ val tempDir = Files.createTempDir()
+ val normalDir = new File(tempDir, "output_normal").getAbsolutePath
+ val compressedOutputDir = new File(tempDir, "output_compressed").getAbsolutePath
+ val codec = new DefaultCodec()
+
+ val data = sc.parallelize(Seq.fill(100)("abc"), 1).map(x => (x, x))
+ data.saveAsSequenceFile(normalDir)
+ data.saveAsSequenceFile(compressedOutputDir, Some(classOf[DefaultCodec]))
+
+ val normalFile = new File(normalDir, "part-00000")
+ val normalContent = sc.sequenceFile[String, String](normalDir).collect
+ assert(normalContent === Array.fill(100)("abc", "abc"))
+
+ val compressedFile = new File(compressedOutputDir, "part-00000" + codec.getDefaultExtension)
+ val compressedContent = sc.sequenceFile[String, String](compressedOutputDir).collect
+ assert(compressedContent === Array.fill(100)("abc", "abc"))
+
+ assert(compressedFile.length < normalFile.length)
+ }
+
+ test("SequenceFile with writable key") {
+ sc = new SparkContext("local", "test")
+ val tempDir = Files.createTempDir()
+ val outputDir = new File(tempDir, "output").getAbsolutePath
+ val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), "a" * x))
+ nums.saveAsSequenceFile(outputDir)
+ // Try reading the output back as a SequenceFile
+ val output = sc.sequenceFile[IntWritable, Text](outputDir)
+ assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
+ }
+
+ test("SequenceFile with writable value") {
+ sc = new SparkContext("local", "test")
+ val tempDir = Files.createTempDir()
+ val outputDir = new File(tempDir, "output").getAbsolutePath
+ val nums = sc.makeRDD(1 to 3).map(x => (x, new Text("a" * x)))
+ nums.saveAsSequenceFile(outputDir)
+ // Try reading the output back as a SequenceFile
+ val output = sc.sequenceFile[IntWritable, Text](outputDir)
+ assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
+ }
+
+ test("SequenceFile with writable key and value") {
+ sc = new SparkContext("local", "test")
+ val tempDir = Files.createTempDir()
+ val outputDir = new File(tempDir, "output").getAbsolutePath
+ val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x)))
+ nums.saveAsSequenceFile(outputDir)
+ // Try reading the output back as a SequenceFile
+ val output = sc.sequenceFile[IntWritable, Text](outputDir)
+ assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
+ }
+
+ test("implicit conversions in reading SequenceFiles") {
+ sc = new SparkContext("local", "test")
+ val tempDir = Files.createTempDir()
+ val outputDir = new File(tempDir, "output").getAbsolutePath
+ val nums = sc.makeRDD(1 to 3).map(x => (x, "a" * x)) // (1,a), (2,aa), (3,aaa)
+ nums.saveAsSequenceFile(outputDir)
+ // Similar to the tests above, we read a SequenceFile, but this time we pass type params
+ // that are convertible to Writable instead of calling sequenceFile[IntWritable, Text]
+ val output1 = sc.sequenceFile[Int, String](outputDir)
+ assert(output1.collect().toList === List((1, "a"), (2, "aa"), (3, "aaa")))
+ // Also try having one type be a subclass of Writable and one not
+ val output2 = sc.sequenceFile[Int, Text](outputDir)
+ assert(output2.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
+ val output3 = sc.sequenceFile[IntWritable, String](outputDir)
+ assert(output3.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
+ }
+
+ test("object files of ints") {
+ sc = new SparkContext("local", "test")
+ val tempDir = Files.createTempDir()
+ val outputDir = new File(tempDir, "output").getAbsolutePath
+ val nums = sc.makeRDD(1 to 4)
+ nums.saveAsObjectFile(outputDir)
+ // Try reading the output back as an object file
+ val output = sc.objectFile[Int](outputDir)
+ assert(output.collect().toList === List(1, 2, 3, 4))
+ }
+
+ test("object files of complex types") {
+ sc = new SparkContext("local", "test")
+ val tempDir = Files.createTempDir()
+ val outputDir = new File(tempDir, "output").getAbsolutePath
+ val nums = sc.makeRDD(1 to 3).map(x => (x, "a" * x))
+ nums.saveAsObjectFile(outputDir)
+ // Try reading the output back as an object file
+ val output = sc.objectFile[(Int, String)](outputDir)
+ assert(output.collect().toList === List((1, "a"), (2, "aa"), (3, "aaa")))
+ }
+
+ test("write SequenceFile using new Hadoop API") {
+ import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat
+ sc = new SparkContext("local", "test")
+ val tempDir = Files.createTempDir()
+ val outputDir = new File(tempDir, "output").getAbsolutePath
+ val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x)))
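+ // Only the new-API OutputFormat is named explicitly; the key and value classes are
+ // inferred from the RDD's type parameters.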
+ nums.saveAsNewAPIHadoopFile[SequenceFileOutputFormat[IntWritable, Text]](
+ outputDir)
+ val output = sc.sequenceFile[IntWritable, Text](outputDir)
+ assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
+ }
+
+ test("read SequenceFile using new Hadoop API") {
+ import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat
+ sc = new SparkContext("local", "test")
+ val tempDir = Files.createTempDir()
+ val outputDir = new File(tempDir, "output").getAbsolutePath
+ val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x)))
+ nums.saveAsSequenceFile(outputDir)
+ val output =
+ sc.newAPIHadoopFile[IntWritable, Text, SequenceFileInputFormat[IntWritable, Text]](outputDir)
+ assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
+ }
+
+ test("file caching") {
+ sc = new SparkContext("local", "test")
+ val tempDir = Files.createTempDir()
+ val out = new FileWriter(tempDir + "/input")
+ out.write("Hello world!\n")
+ out.write("What's up?\n")
+ out.write("Goodbye\n")
+ out.close()
+ val rdd = sc.textFile(tempDir + "/input").cache()
+ assert(rdd.count() === 3)
+ assert(rdd.count() === 3)
+ assert(rdd.count() === 3)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
new file mode 100644
index 0000000000..8a869c9005
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -0,0 +1,865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.*;
+
+import com.google.common.base.Optional;
+import scala.Tuple2;
+
+import com.google.common.base.Charsets;
+import org.apache.hadoop.io.compress.DefaultCodec;
+import com.google.common.io.Files;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.SequenceFileInputFormat;
+import org.apache.hadoop.mapred.SequenceFileOutputFormat;
+import org.apache.hadoop.mapreduce.Job;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaDoubleRDD;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.*;
+import org.apache.spark.partial.BoundedDouble;
+import org.apache.spark.partial.PartialResult;
+import org.apache.spark.storage.StorageLevel;
+import org.apache.spark.util.StatCounter;
+
+
+// The test suite itself is Serializable so that anonymous Function implementations can be
+// serialized, as an alternative to converting these anonymous classes to static inner classes;
+// see http://stackoverflow.com/questions/758570/.
+public class JavaAPISuite implements Serializable {
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaAPISuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+ System.clearProperty("spark.driver.port");
+ }
+
+ static class ReverseIntComparator implements Comparator<Integer>, Serializable {
+
+ @Override
+ public int compare(Integer a, Integer b) {
+ if (a > b) return -1;
+ else if (a < b) return 1;
+ else return 0;
+ }
+ }
+
+ @Test
+ public void sparkContextUnion() {
+ // Union of non-specialized JavaRDDs
+ List<String> strings = Arrays.asList("Hello", "World");
+ JavaRDD<String> s1 = sc.parallelize(strings);
+ JavaRDD<String> s2 = sc.parallelize(strings);
+ // Varargs
+ JavaRDD<String> sUnion = sc.union(s1, s2);
+ Assert.assertEquals(4, sUnion.count());
+ // List
+ List<JavaRDD<String>> list = new ArrayList<JavaRDD<String>>();
+ list.add(s2);
+ sUnion = sc.union(s1, list);
+ Assert.assertEquals(4, sUnion.count());
+
+ // Union of JavaDoubleRDDs
+ List<Double> doubles = Arrays.asList(1.0, 2.0);
+ JavaDoubleRDD d1 = sc.parallelizeDoubles(doubles);
+ JavaDoubleRDD d2 = sc.parallelizeDoubles(doubles);
+ JavaDoubleRDD dUnion = sc.union(d1, d2);
+ Assert.assertEquals(4, dUnion.count());
+
+ // Union of JavaPairRDDs
+ List<Tuple2<Integer, Integer>> pairs = new ArrayList<Tuple2<Integer, Integer>>();
+ pairs.add(new Tuple2<Integer, Integer>(1, 2));
+ pairs.add(new Tuple2<Integer, Integer>(3, 4));
+ JavaPairRDD<Integer, Integer> p1 = sc.parallelizePairs(pairs);
+ JavaPairRDD<Integer, Integer> p2 = sc.parallelizePairs(pairs);
+ JavaPairRDD<Integer, Integer> pUnion = sc.union(p1, p2);
+ Assert.assertEquals(4, pUnion.count());
+ }
+
+ @Test
+ public void sortByKey() {
+ List<Tuple2<Integer, Integer>> pairs = new ArrayList<Tuple2<Integer, Integer>>();
+ pairs.add(new Tuple2<Integer, Integer>(0, 4));
+ pairs.add(new Tuple2<Integer, Integer>(3, 2));
+ pairs.add(new Tuple2<Integer, Integer>(-1, 1));
+
+ JavaPairRDD<Integer, Integer> rdd = sc.parallelizePairs(pairs);
+
+ // Default comparator
+ JavaPairRDD<Integer, Integer> sortedRDD = rdd.sortByKey();
+ Assert.assertEquals(new Tuple2<Integer, Integer>(-1, 1), sortedRDD.first());
+ List<Tuple2<Integer, Integer>> sortedPairs = sortedRDD.collect();
+ Assert.assertEquals(new Tuple2<Integer, Integer>(0, 4), sortedPairs.get(1));
+ Assert.assertEquals(new Tuple2<Integer, Integer>(3, 2), sortedPairs.get(2));
+
+ // Custom comparator
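+ // (sorting descending with a reversed comparator yields the natural ascending order again)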
+ sortedRDD = rdd.sortByKey(new ReverseIntComparator(), false);
+ Assert.assertEquals(new Tuple2<Integer, Integer>(-1, 1), sortedRDD.first());
+ sortedPairs = sortedRDD.collect();
+ Assert.assertEquals(new Tuple2<Integer, Integer>(0, 4), sortedPairs.get(1));
+ Assert.assertEquals(new Tuple2<Integer, Integer>(3, 2), sortedPairs.get(2));
+ }
+
+ static int foreachCalls = 0;
+
+ @Test
+ public void foreach() {
+ foreachCalls = 0;
+ JavaRDD<String> rdd = sc.parallelize(Arrays.asList("Hello", "World"));
+ rdd.foreach(new VoidFunction<String>() {
+ @Override
+ public void call(String s) {
+ foreachCalls++;
+ }
+ });
+ Assert.assertEquals(2, foreachCalls);
+ }
+
+ @Test
+ public void lookup() {
+ JavaPairRDD<String, String> categories = sc.parallelizePairs(Arrays.asList(
+ new Tuple2<String, String>("Apples", "Fruit"),
+ new Tuple2<String, String>("Oranges", "Fruit"),
+ new Tuple2<String, String>("Oranges", "Citrus")
+ ));
+ Assert.assertEquals(2, categories.lookup("Oranges").size());
+ Assert.assertEquals(2, categories.groupByKey().lookup("Oranges").get(0).size());
+ }
+
+ @Test
+ public void groupBy() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13));
+ Function<Integer, Boolean> isEven = new Function<Integer, Boolean>() {
+ @Override
+ public Boolean call(Integer x) {
+ return x % 2 == 0;
+ }
+ };
+ JavaPairRDD<Boolean, List<Integer>> oddsAndEvens = rdd.groupBy(isEven);
+ Assert.assertEquals(2, oddsAndEvens.count());
+ Assert.assertEquals(2, oddsAndEvens.lookup(true).get(0).size()); // Evens
+ Assert.assertEquals(5, oddsAndEvens.lookup(false).get(0).size()); // Odds
+
+ oddsAndEvens = rdd.groupBy(isEven, 1);
+ Assert.assertEquals(2, oddsAndEvens.count());
+ Assert.assertEquals(2, oddsAndEvens.lookup(true).get(0).size()); // Evens
+ Assert.assertEquals(5, oddsAndEvens.lookup(false).get(0).size()); // Odds
+ }
+
+ @Test
+ public void cogroup() {
+ JavaPairRDD<String, String> categories = sc.parallelizePairs(Arrays.asList(
+ new Tuple2<String, String>("Apples", "Fruit"),
+ new Tuple2<String, String>("Oranges", "Fruit"),
+ new Tuple2<String, String>("Oranges", "Citrus")
+ ));
+ JavaPairRDD<String, Integer> prices = sc.parallelizePairs(Arrays.asList(
+ new Tuple2<String, Integer>("Oranges", 2),
+ new Tuple2<String, Integer>("Apples", 3)
+ ));
+ JavaPairRDD<String, Tuple2<List<String>, List<Integer>>> cogrouped = categories.cogroup(prices);
+ Assert.assertEquals("[Fruit, Citrus]", cogrouped.lookup("Oranges").get(0)._1().toString());
+ Assert.assertEquals("[2]", cogrouped.lookup("Oranges").get(0)._2().toString());
+
+ cogrouped.collect();
+ }
+
+ @Test
+ public void leftOuterJoin() {
+ JavaPairRDD<Integer, Integer> rdd1 = sc.parallelizePairs(Arrays.asList(
+ new Tuple2<Integer, Integer>(1, 1),
+ new Tuple2<Integer, Integer>(1, 2),
+ new Tuple2<Integer, Integer>(2, 1),
+ new Tuple2<Integer, Integer>(3, 1)
+ ));
+ JavaPairRDD<Integer, Character> rdd2 = sc.parallelizePairs(Arrays.asList(
+ new Tuple2<Integer, Character>(1, 'x'),
+ new Tuple2<Integer, Character>(2, 'y'),
+ new Tuple2<Integer, Character>(2, 'z'),
+ new Tuple2<Integer, Character>(4, 'w')
+ ));
+ List<Tuple2<Integer,Tuple2<Integer,Optional<Character>>>> joined =
+ rdd1.leftOuterJoin(rdd2).collect();
+ Assert.assertEquals(5, joined.size());
+ Tuple2<Integer,Tuple2<Integer,Optional<Character>>> firstUnmatched =
+ rdd1.leftOuterJoin(rdd2).filter(
+ new Function<Tuple2<Integer, Tuple2<Integer, Optional<Character>>>, Boolean>() {
+ @Override
+ public Boolean call(Tuple2<Integer, Tuple2<Integer, Optional<Character>>> tup)
+ throws Exception {
+ return !tup._2()._2().isPresent();
+ }
+ }).first();
+ Assert.assertEquals(3, firstUnmatched._1().intValue());
+ }
+
+ @Test
+ public void foldReduce() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13));
+ Function2<Integer, Integer, Integer> add = new Function2<Integer, Integer, Integer>() {
+ @Override
+ public Integer call(Integer a, Integer b) {
+ return a + b;
+ }
+ };
+
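+ // 1 + 1 + 2 + 3 + 5 + 8 + 13 = 33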
+ int sum = rdd.fold(0, add);
+ Assert.assertEquals(33, sum);
+
+ sum = rdd.reduce(add);
+ Assert.assertEquals(33, sum);
+ }
+
+ @Test
+ public void foldByKey() {
+ List<Tuple2<Integer, Integer>> pairs = Arrays.asList(
+ new Tuple2<Integer, Integer>(2, 1),
+ new Tuple2<Integer, Integer>(2, 1),
+ new Tuple2<Integer, Integer>(1, 1),
+ new Tuple2<Integer, Integer>(3, 2),
+ new Tuple2<Integer, Integer>(3, 1)
+ );
+ JavaPairRDD<Integer, Integer> rdd = sc.parallelizePairs(pairs);
+ JavaPairRDD<Integer, Integer> sums = rdd.foldByKey(0,
+ new Function2<Integer, Integer, Integer>() {
+ @Override
+ public Integer call(Integer a, Integer b) {
+ return a + b;
+ }
+ });
+ Assert.assertEquals(1, sums.lookup(1).get(0).intValue());
+ Assert.assertEquals(2, sums.lookup(2).get(0).intValue());
+ Assert.assertEquals(3, sums.lookup(3).get(0).intValue());
+ }
+
+ @Test
+ public void reduceByKey() {
+ List<Tuple2<Integer, Integer>> pairs = Arrays.asList(
+ new Tuple2<Integer, Integer>(2, 1),
+ new Tuple2<Integer, Integer>(2, 1),
+ new Tuple2<Integer, Integer>(1, 1),
+ new Tuple2<Integer, Integer>(3, 2),
+ new Tuple2<Integer, Integer>(3, 1)
+ );
+ JavaPairRDD<Integer, Integer> rdd = sc.parallelizePairs(pairs);
+ JavaPairRDD<Integer, Integer> counts = rdd.reduceByKey(
+ new Function2<Integer, Integer, Integer>() {
+ @Override
+ public Integer call(Integer a, Integer b) {
+ return a + b;
+ }
+ });
+ Assert.assertEquals(1, counts.lookup(1).get(0).intValue());
+ Assert.assertEquals(2, counts.lookup(2).get(0).intValue());
+ Assert.assertEquals(3, counts.lookup(3).get(0).intValue());
+
+ Map<Integer, Integer> localCounts = counts.collectAsMap();
+ Assert.assertEquals(1, localCounts.get(1).intValue());
+ Assert.assertEquals(2, localCounts.get(2).intValue());
+ Assert.assertEquals(3, localCounts.get(3).intValue());
+
+ localCounts = rdd.reduceByKeyLocally(new Function2<Integer, Integer,
+ Integer>() {
+ @Override
+ public Integer call(Integer a, Integer b) {
+ return a + b;
+ }
+ });
+ Assert.assertEquals(1, localCounts.get(1).intValue());
+ Assert.assertEquals(2, localCounts.get(2).intValue());
+ Assert.assertEquals(3, localCounts.get(3).intValue());
+ }
+
+ @Test
+ public void approximateResults() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13));
+ Map<Integer, Long> countsByValue = rdd.countByValue();
+ Assert.assertEquals(2, countsByValue.get(1).longValue());
+ Assert.assertEquals(1, countsByValue.get(13).longValue());
+
+ PartialResult<Map<Integer, BoundedDouble>> approx = rdd.countByValueApprox(1);
+ Map<Integer, BoundedDouble> finalValue = approx.getFinalValue();
+ Assert.assertEquals(2.0, finalValue.get(1).mean(), 0.01);
+ Assert.assertEquals(1.0, finalValue.get(13).mean(), 0.01);
+ }
+
+ @Test
+ public void take() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13));
+ Assert.assertEquals(1, rdd.first().intValue());
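+ // take() and takeSample() are only smoke-tested here; their results are not asserted.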
+ List<Integer> firstTwo = rdd.take(2);
+ List<Integer> sample = rdd.takeSample(false, 2, 42);
+ }
+
+ @Test
+ public void cartesian() {
+ JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0));
+ JavaRDD<String> stringRDD = sc.parallelize(Arrays.asList("Hello", "World"));
+ JavaPairRDD<String, Double> cartesian = stringRDD.cartesian(doubleRDD);
+ Assert.assertEquals(new Tuple2<String, Double>("Hello", 1.0), cartesian.first());
+ }
+
+ @Test
+ public void javaDoubleRDD() {
+ JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0));
+ JavaDoubleRDD distinct = rdd.distinct();
+ Assert.assertEquals(5, distinct.count());
+ JavaDoubleRDD filter = rdd.filter(new Function<Double, Boolean>() {
+ @Override
+ public Boolean call(Double x) {
+ return x > 2.0;
+ }
+ });
+ Assert.assertEquals(3, filter.count());
+ JavaDoubleRDD union = rdd.union(rdd);
+ Assert.assertEquals(12, union.count());
+ union = union.cache();
+ Assert.assertEquals(12, union.count());
+
+ Assert.assertEquals(20, rdd.sum(), 0.01);
+ StatCounter stats = rdd.stats();
+ Assert.assertEquals(20, stats.sum(), 0.01);
+ Assert.assertEquals(20/6.0, rdd.mean(), 0.01);
+ Assert.assertEquals(20/6.0, stats.mean(), 0.01);
+ Assert.assertEquals(6.22222, rdd.variance(), 0.01);
+ Assert.assertEquals(7.46667, rdd.sampleVariance(), 0.01);
+ Assert.assertEquals(2.49444, rdd.stdev(), 0.01);
+ Assert.assertEquals(2.73252, rdd.sampleStdev(), 0.01);
+
+ Double first = rdd.first();
+ List<Double> take = rdd.take(5);
+ }
+
+ @Test
+ public void map() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+ JavaDoubleRDD doubles = rdd.map(new DoubleFunction<Integer>() {
+ @Override
+ public Double call(Integer x) {
+ return 1.0 * x;
+ }
+ }).cache();
+ JavaPairRDD<Integer, Integer> pairs = rdd.map(new PairFunction<Integer, Integer, Integer>() {
+ @Override
+ public Tuple2<Integer, Integer> call(Integer x) {
+ return new Tuple2<Integer, Integer>(x, x);
+ }
+ }).cache();
+ JavaRDD<String> strings = rdd.map(new Function<Integer, String>() {
+ @Override
+ public String call(Integer x) {
+ return x.toString();
+ }
+ }).cache();
+ }
+
+ @Test
+ public void flatMap() {
+ JavaRDD<String> rdd = sc.parallelize(Arrays.asList("Hello World!",
+ "The quick brown fox jumps over the lazy dog."));
+ JavaRDD<String> words = rdd.flatMap(new FlatMapFunction<String, String>() {
+ @Override
+ public Iterable<String> call(String x) {
+ return Arrays.asList(x.split(" "));
+ }
+ });
+ Assert.assertEquals("Hello", words.first());
+ Assert.assertEquals(11, words.count());
+
+ JavaPairRDD<String, String> pairs = rdd.flatMap(
+ new PairFlatMapFunction<String, String, String>() {
+
+ @Override
+ public Iterable<Tuple2<String, String>> call(String s) {
+ List<Tuple2<String, String>> pairs = new LinkedList<Tuple2<String, String>>();
+ for (String word : s.split(" ")) pairs.add(new Tuple2<String, String>(word, word));
+ return pairs;
+ }
+ }
+ );
+ Assert.assertEquals(new Tuple2<String, String>("Hello", "Hello"), pairs.first());
+ Assert.assertEquals(11, pairs.count());
+
+ JavaDoubleRDD doubles = rdd.flatMap(new DoubleFlatMapFunction<String>() {
+ @Override
+ public Iterable<Double> call(String s) {
+ List<Double> lengths = new LinkedList<Double>();
+ for (String word : s.split(" ")) lengths.add(word.length() * 1.0);
+ return lengths;
+ }
+ });
+ Double x = doubles.first();
+ Assert.assertEquals(5.0, doubles.first().doubleValue(), 0.01);
+ Assert.assertEquals(11, pairs.count());
+ }
+
+ @Test
+ public void mapsFromPairsToPairs() {
+ List<Tuple2<Integer, String>> pairs = Arrays.asList(
+ new Tuple2<Integer, String>(1, "a"),
+ new Tuple2<Integer, String>(2, "aa"),
+ new Tuple2<Integer, String>(3, "aaa")
+ );
+ JavaPairRDD<Integer, String> pairRDD = sc.parallelizePairs(pairs);
+
+ // Regression test for SPARK-668:
+ JavaPairRDD<String, Integer> swapped = pairRDD.flatMap(
+ new PairFlatMapFunction<Tuple2<Integer, String>, String, Integer>() {
+ @Override
+ public Iterable<Tuple2<String, Integer>> call(Tuple2<Integer, String> item) throws Exception {
+ return Collections.singletonList(item.swap());
+ }
+ });
+ swapped.collect();
+
+ // There was never a bug here, but it's worth testing:
+ pairRDD.map(new PairFunction<Tuple2<Integer, String>, String, Integer>() {
+ @Override
+ public Tuple2<String, Integer> call(Tuple2<Integer, String> item) throws Exception {
+ return item.swap();
+ }
+ }).collect();
+ }
+
+ @Test
+ public void mapPartitions() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
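+ // With two partitions, [1, 2] sums to 3 and [3, 4] sums to 7.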
+ JavaRDD<Integer> partitionSums = rdd.mapPartitions(
+ new FlatMapFunction<Iterator<Integer>, Integer>() {
+ @Override
+ public Iterable<Integer> call(Iterator<Integer> iter) {
+ int sum = 0;
+ while (iter.hasNext()) {
+ sum += iter.next();
+ }
+ return Collections.singletonList(sum);
+ }
+ });
+ Assert.assertEquals("[3, 7]", partitionSums.collect().toString());
+ }
+
+ @Test
+ public void persist() {
+ JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0));
+ doubleRDD = doubleRDD.persist(StorageLevel.DISK_ONLY());
+ Assert.assertEquals(20, doubleRDD.sum(), 0.1);
+
+ List<Tuple2<Integer, String>> pairs = Arrays.asList(
+ new Tuple2<Integer, String>(1, "a"),
+ new Tuple2<Integer, String>(2, "aa"),
+ new Tuple2<Integer, String>(3, "aaa")
+ );
+ JavaPairRDD<Integer, String> pairRDD = sc.parallelizePairs(pairs);
+ pairRDD = pairRDD.persist(StorageLevel.DISK_ONLY());
+ Assert.assertEquals("a", pairRDD.first()._2());
+
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+ rdd = rdd.persist(StorageLevel.DISK_ONLY());
+ Assert.assertEquals(1, rdd.first().intValue());
+ }
+
+ @Test
+ public void iterator() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
+ TaskContext context = new TaskContext(0, 0, 0, null);
+ Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue());
+ }
+
+ @Test
+ public void glom() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
+ Assert.assertEquals("[1, 2]", rdd.glom().first().toString());
+ }
+
+ // File input / output tests are largely adapted from FileSuite:
+
+ @Test
+ public void textFiles() throws IOException {
+ File tempDir = Files.createTempDir();
+ String outputDir = new File(tempDir, "output").getAbsolutePath();
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4));
+ rdd.saveAsTextFile(outputDir);
+ // Read the plain text file and check it's OK
+ File outputFile = new File(outputDir, "part-00000");
+ String content = Files.toString(outputFile, Charsets.UTF_8);
+ Assert.assertEquals("1\n2\n3\n4\n", content);
+ // Also try reading it in as a text file RDD
+ List<String> expected = Arrays.asList("1", "2", "3", "4");
+ JavaRDD<String> readRDD = sc.textFile(outputDir);
+ Assert.assertEquals(expected, readRDD.collect());
+ }
+
+ @Test
+ public void textFilesCompressed() throws IOException {
+ File tempDir = Files.createTempDir();
+ String outputDir = new File(tempDir, "output").getAbsolutePath();
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4));
+ rdd.saveAsTextFile(outputDir, DefaultCodec.class);
+
+ // Try reading it in as a text file RDD
+ List<String> expected = Arrays.asList("1", "2", "3", "4");
+ JavaRDD<String> readRDD = sc.textFile(outputDir);
+ Assert.assertEquals(expected, readRDD.collect());
+ }
+
+ @Test
+ public void sequenceFile() {
+ File tempDir = Files.createTempDir();
+ String outputDir = new File(tempDir, "output").getAbsolutePath();
+ List<Tuple2<Integer, String>> pairs = Arrays.asList(
+ new Tuple2<Integer, String>(1, "a"),
+ new Tuple2<Integer, String>(2, "aa"),
+ new Tuple2<Integer, String>(3, "aaa")
+ );
+ JavaPairRDD<Integer, String> rdd = sc.parallelizePairs(pairs);
+
+ rdd.map(new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
+ @Override
+ public Tuple2<IntWritable, Text> call(Tuple2<Integer, String> pair) {
+ return new Tuple2<IntWritable, Text>(new IntWritable(pair._1()), new Text(pair._2()));
+ }
+ }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class);
+
+ // Try reading the output back as a SequenceFile
+ JavaPairRDD<Integer, String> readRDD = sc.sequenceFile(outputDir, IntWritable.class,
+ Text.class).map(new PairFunction<Tuple2<IntWritable, Text>, Integer, String>() {
+ @Override
+ public Tuple2<Integer, String> call(Tuple2<IntWritable, Text> pair) {
+ return new Tuple2<Integer, String>(pair._1().get(), pair._2().toString());
+ }
+ });
+ Assert.assertEquals(pairs, readRDD.collect());
+ }
+
+ @Test
+ public void writeWithNewAPIHadoopFile() {
+ File tempDir = Files.createTempDir();
+ String outputDir = new File(tempDir, "output").getAbsolutePath();
+ List<Tuple2<Integer, String>> pairs = Arrays.asList(
+ new Tuple2<Integer, String>(1, "a"),
+ new Tuple2<Integer, String>(2, "aa"),
+ new Tuple2<Integer, String>(3, "aaa")
+ );
+ JavaPairRDD<Integer, String> rdd = sc.parallelizePairs(pairs);
+
+ rdd.map(new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
+ @Override
+ public Tuple2<IntWritable, Text> call(Tuple2<Integer, String> pair) {
+ return new Tuple2<IntWritable, Text>(new IntWritable(pair._1()), new Text(pair._2()));
+ }
+ }).saveAsNewAPIHadoopFile(outputDir, IntWritable.class, Text.class,
+ org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class);
+
+ JavaPairRDD<IntWritable, Text> output = sc.sequenceFile(outputDir, IntWritable.class,
+ Text.class);
+ Assert.assertEquals(pairs.toString(), output.map(new Function<Tuple2<IntWritable, Text>,
+ String>() {
+ @Override
+ public String call(Tuple2<IntWritable, Text> x) {
+ return x.toString();
+ }
+ }).collect().toString());
+ }
+
+ @Test
+ public void readWithNewAPIHadoopFile() throws IOException {
+ File tempDir = Files.createTempDir();
+ String outputDir = new File(tempDir, "output").getAbsolutePath();
+ List<Tuple2<Integer, String>> pairs = Arrays.asList(
+ new Tuple2<Integer, String>(1, "a"),
+ new Tuple2<Integer, String>(2, "aa"),
+ new Tuple2<Integer, String>(3, "aaa")
+ );
+ JavaPairRDD<Integer, String> rdd = sc.parallelizePairs(pairs);
+
+ rdd.map(new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
+ @Override
+ public Tuple2<IntWritable, Text> call(Tuple2<Integer, String> pair) {
+ return new Tuple2<IntWritable, Text>(new IntWritable(pair._1()), new Text(pair._2()));
+ }
+ }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class);
+
+ JavaPairRDD<IntWritable, Text> output = sc.newAPIHadoopFile(outputDir,
+ org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, IntWritable.class,
+ Text.class, new Job().getConfiguration());
+ Assert.assertEquals(pairs.toString(), output.map(new Function<Tuple2<IntWritable, Text>,
+ String>() {
+ @Override
+ public String call(Tuple2<IntWritable, Text> x) {
+ return x.toString();
+ }
+ }).collect().toString());
+ }
+
+ @Test
+ public void objectFilesOfInts() {
+ File tempDir = Files.createTempDir();
+ String outputDir = new File(tempDir, "output").getAbsolutePath();
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4));
+ rdd.saveAsObjectFile(outputDir);
+ // Try reading the output back as an object file
+ List<Integer> expected = Arrays.asList(1, 2, 3, 4);
+ JavaRDD<Integer> readRDD = sc.objectFile(outputDir);
+ Assert.assertEquals(expected, readRDD.collect());
+ }
+
+ @Test
+ public void objectFilesOfComplexTypes() {
+ File tempDir = Files.createTempDir();
+ String outputDir = new File(tempDir, "output").getAbsolutePath();
+ List<Tuple2<Integer, String>> pairs = Arrays.asList(
+ new Tuple2<Integer, String>(1, "a"),
+ new Tuple2<Integer, String>(2, "aa"),
+ new Tuple2<Integer, String>(3, "aaa")
+ );
+ JavaPairRDD<Integer, String> rdd = sc.parallelizePairs(pairs);
+ rdd.saveAsObjectFile(outputDir);
+ // Try reading the output back as an object file
+ JavaRDD<Tuple2<Integer, String>> readRDD = sc.objectFile(outputDir);
+ Assert.assertEquals(pairs, readRDD.collect());
+ }
+
+ @Test
+ public void hadoopFile() {
+ File tempDir = Files.createTempDir();
+ String outputDir = new File(tempDir, "output").getAbsolutePath();
+ List<Tuple2<Integer, String>> pairs = Arrays.asList(
+ new Tuple2<Integer, String>(1, "a"),
+ new Tuple2<Integer, String>(2, "aa"),
+ new Tuple2<Integer, String>(3, "aaa")
+ );
+ JavaPairRDD<Integer, String> rdd = sc.parallelizePairs(pairs);
+
+ rdd.map(new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
+ @Override
+ public Tuple2<IntWritable, Text> call(Tuple2<Integer, String> pair) {
+ return new Tuple2<IntWritable, Text>(new IntWritable(pair._1()), new Text(pair._2()));
+ }
+ }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class);
+
+ JavaPairRDD<IntWritable, Text> output = sc.hadoopFile(outputDir,
+ SequenceFileInputFormat.class, IntWritable.class, Text.class);
+ Assert.assertEquals(pairs.toString(), output.map(new Function<Tuple2<IntWritable, Text>,
+ String>() {
+ @Override
+ public String call(Tuple2<IntWritable, Text> x) {
+ return x.toString();
+ }
+ }).collect().toString());
+ }
+
+ @Test
+ public void hadoopFileCompressed() {
+ File tempDir = Files.createTempDir();
+ String outputDir = new File(tempDir, "output_compressed").getAbsolutePath();
+ List<Tuple2<Integer, String>> pairs = Arrays.asList(
+ new Tuple2<Integer, String>(1, "a"),
+ new Tuple2<Integer, String>(2, "aa"),
+ new Tuple2<Integer, String>(3, "aaa")
+ );
+ JavaPairRDD<Integer, String> rdd = sc.parallelizePairs(pairs);
+
+ rdd.map(new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
+ @Override
+ public Tuple2<IntWritable, Text> call(Tuple2<Integer, String> pair) {
+ return new Tuple2<IntWritable, Text>(new IntWritable(pair._1()), new Text(pair._2()));
+ }
+ }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class,
+ DefaultCodec.class);
+
+ JavaPairRDD<IntWritable, Text> output = sc.hadoopFile(outputDir,
+ SequenceFileInputFormat.class, IntWritable.class, Text.class);
+
+ Assert.assertEquals(pairs.toString(), output.map(new Function<Tuple2<IntWritable, Text>,
+ String>() {
+ @Override
+ public String call(Tuple2<IntWritable, Text> x) {
+ return x.toString();
+ }
+ }).collect().toString());
+ }
+
+ @Test
+ public void zip() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+ JavaDoubleRDD doubles = rdd.map(new DoubleFunction<Integer>() {
+ @Override
+ public Double call(Integer x) {
+ return 1.0 * x;
+ }
+ });
+ JavaPairRDD<Integer, Double> zipped = rdd.zip(doubles);
+ zipped.count();
+ }
+
+ @Test
+ public void zipPartitions() {
+ JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6), 2);
+ JavaRDD<String> rdd2 = sc.parallelize(Arrays.asList("1", "2", "3", "4"), 2);
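+ // Each of the two partition pairs holds 3 integers and 2 strings, hence the expected [3, 2, 3, 2].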
+ FlatMapFunction2<Iterator<Integer>, Iterator<String>, Integer> sizesFn =
+ new FlatMapFunction2<Iterator<Integer>, Iterator<String>, Integer>() {
+ @Override
+ public Iterable<Integer> call(Iterator<Integer> i, Iterator<String> s) {
+ int sizeI = 0;
+ int sizeS = 0;
+ while (i.hasNext()) {
+ sizeI += 1;
+ i.next();
+ }
+ while (s.hasNext()) {
+ sizeS += 1;
+ s.next();
+ }
+ return Arrays.asList(sizeI, sizeS);
+ }
+ };
+
+ JavaRDD<Integer> sizes = rdd1.zipPartitions(rdd2, sizesFn);
+ Assert.assertEquals("[3, 2, 3, 2]", sizes.collect().toString());
+ }
+
+ @Test
+ public void accumulators() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+
+ final Accumulator<Integer> intAccum = sc.intAccumulator(10);
+ rdd.foreach(new VoidFunction<Integer>() {
+ public void call(Integer x) {
+ intAccum.add(x);
+ }
+ });
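+ // 10 (initial value) + 1 + 2 + 3 + 4 + 5 = 25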
+ Assert.assertEquals((Integer) 25, intAccum.value());
+
+ final Accumulator<Double> doubleAccum = sc.doubleAccumulator(10.0);
+ rdd.foreach(new VoidFunction<Integer>() {
+ public void call(Integer x) {
+ doubleAccum.add((double) x);
+ }
+ });
+ Assert.assertEquals((Double) 25.0, doubleAccum.value());
+
+ // Try a custom accumulator type
+ AccumulatorParam<Float> floatAccumulatorParam = new AccumulatorParam<Float>() {
+ public Float addInPlace(Float r, Float t) {
+ return r + t;
+ }
+
+ public Float addAccumulator(Float r, Float t) {
+ return r + t;
+ }
+
+ public Float zero(Float initialValue) {
+ return 0.0f;
+ }
+ };
+
+ final Accumulator<Float> floatAccum = sc.accumulator((Float) 10.0f, floatAccumulatorParam);
+ rdd.foreach(new VoidFunction<Integer>() {
+ public void call(Integer x) {
+ floatAccum.add((float) x);
+ }
+ });
+ Assert.assertEquals((Float) 25.0f, floatAccum.value());
+
+ // Test the setValue method
+ floatAccum.setValue(5.0f);
+ Assert.assertEquals((Float) 5.0f, floatAccum.value());
+ }
+
+ @Test
+ public void keyBy() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2));
+ List<Tuple2<String, Integer>> s = rdd.keyBy(new Function<Integer, String>() {
+ public String call(Integer t) throws Exception {
+ return t.toString();
+ }
+ }).collect();
+ Assert.assertEquals(new Tuple2<String, Integer>("1", 1), s.get(0));
+ Assert.assertEquals(new Tuple2<String, Integer>("2", 2), s.get(1));
+ }
+
+ @Test
+ public void checkpointAndComputation() {
+ File tempDir = Files.createTempDir();
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+ sc.setCheckpointDir(tempDir.getAbsolutePath(), true);
+ Assert.assertEquals(false, rdd.isCheckpointed());
+ rdd.checkpoint();
+ rdd.count(); // Forces the RDD to be computed, which performs the checkpoint
+ Assert.assertEquals(true, rdd.isCheckpointed());
+ Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), rdd.collect());
+ }
+
+ @Test
+ public void checkpointAndRestore() {
+ File tempDir = Files.createTempDir();
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+ sc.setCheckpointDir(tempDir.getAbsolutePath(), true);
+ Assert.assertEquals(false, rdd.isCheckpointed());
+ rdd.checkpoint();
+ rdd.count(); // Forces the RDD to be computed, which performs the checkpoint
+ Assert.assertEquals(true, rdd.isCheckpointed());
+
+ Assert.assertTrue(rdd.getCheckpointFile().isPresent());
+ JavaRDD<Integer> recovered = sc.checkpointFile(rdd.getCheckpointFile().get());
+ Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect());
+ }
+
+ @Test
+ public void mapOnPairRDD() {
+ JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1,2,3,4));
+ JavaPairRDD<Integer, Integer> rdd2 = rdd1.map(new PairFunction<Integer, Integer, Integer>() {
+ @Override
+ public Tuple2<Integer, Integer> call(Integer i) throws Exception {
+ return new Tuple2<Integer, Integer>(i, i % 2);
+ }
+ });
+ JavaPairRDD<Integer, Integer> rdd3 = rdd2.map(
+ new PairFunction<Tuple2<Integer, Integer>, Integer, Integer>() {
+ @Override
+ public Tuple2<Integer, Integer> call(Tuple2<Integer, Integer> in) throws Exception {
+ return new Tuple2<Integer, Integer>(in._2(), in._1());
+ }
+ });
+ Assert.assertEquals(Arrays.asList(
+ new Tuple2<Integer, Integer>(1, 1),
+ new Tuple2<Integer, Integer>(0, 2),
+ new Tuple2<Integer, Integer>(1, 3),
+ new Tuple2<Integer, Integer>(0, 4)), rdd3.collect());
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/KryoSerializerSuite.scala
new file mode 100644
index 0000000000..d7b23c93fe
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/KryoSerializerSuite.scala
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.collection.mutable
+
+import org.scalatest.FunSuite
+import com.esotericsoftware.kryo._
+
+import KryoTest._
+
+class KryoSerializerSuite extends FunSuite with SharedSparkContext {
+ test("basic types") {
+ val ser = (new KryoSerializer).newInstance()
+ def check[T](t: T) {
+ assert(ser.deserialize[T](ser.serialize(t)) === t)
+ }
+ check(1)
+ check(1L)
+ check(1.0f)
+ check(1.0)
+ check(1.toByte)
+ check(1.toShort)
+ check("")
+ check("hello")
+ check(Integer.MAX_VALUE)
+ check(Integer.MIN_VALUE)
+ check(java.lang.Long.MAX_VALUE)
+ check(java.lang.Long.MIN_VALUE)
+ check[String](null)
+ check(Array(1, 2, 3))
+ check(Array(1L, 2L, 3L))
+ check(Array(1.0, 2.0, 3.0))
+ check(Array(1.0f, 2.9f, 3.9f))
+ check(Array("aaa", "bbb", "ccc"))
+ check(Array("aaa", "bbb", null))
+ check(Array(true, false, true))
+ check(Array('a', 'b', 'c'))
+ check(Array[Int]())
+ check(Array(Array("1", "2"), Array("1", "2", "3", "4")))
+ }
+
+ test("pairs") {
+ val ser = (new KryoSerializer).newInstance()
+ def check[T](t: T) {
+ assert(ser.deserialize[T](ser.serialize(t)) === t)
+ }
+ check((1, 1))
+ check((1, 1L))
+ check((1L, 1))
+ check((1L, 1L))
+ check((1.0, 1))
+ check((1, 1.0))
+ check((1.0, 1.0))
+ check((1.0, 1L))
+ check((1L, 1.0))
+ check((1.0, 1L))
+ check(("x", 1))
+ check(("x", 1.0))
+ check(("x", 1L))
+ check((1, "x"))
+ check((1.0, "x"))
+ check((1L, "x"))
+ check(("x", "x"))
+ }
+
+ test("Scala data structures") {
+ val ser = (new KryoSerializer).newInstance()
+ def check[T](t: T) {
+ assert(ser.deserialize[T](ser.serialize(t)) === t)
+ }
+ check(List[Int]())
+ check(List[Int](1, 2, 3))
+ check(List[String]())
+ check(List[String]("x", "y", "z"))
+ check(None)
+ check(Some(1))
+ check(Some("hi"))
+ check(mutable.ArrayBuffer(1, 2, 3))
+ check(mutable.ArrayBuffer("1", "2", "3"))
+ check(mutable.Map())
+ check(mutable.Map(1 -> "one", 2 -> "two"))
+ check(mutable.Map("one" -> 1, "two" -> 2))
+ check(mutable.HashMap(1 -> "one", 2 -> "two"))
+ check(mutable.HashMap("one" -> 1, "two" -> 2))
+ check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4))))
+ check(List(mutable.HashMap("one" -> 1, "two" -> 2),mutable.HashMap(1->"one",2->"two",3->"three")))
+ }
+
+ test("custom registrator") {
+ import KryoTest._
+ System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName)
+
+ val ser = (new KryoSerializer).newInstance()
+ def check[T](t: T) {
+ assert(ser.deserialize[T](ser.serialize(t)) === t)
+ }
+
+ check(CaseClass(17, "hello"))
+
+ val c1 = new ClassWithNoArgConstructor
+ c1.x = 32
+ check(c1)
+
+ val c2 = new ClassWithoutNoArgConstructor(47)
+ check(c2)
+
+ val hashMap = new java.util.HashMap[String, String]
+ hashMap.put("foo", "bar")
+ check(hashMap)
+
+ System.clearProperty("spark.kryo.registrator")
+ }
+
+ test("kryo with collect") {
+ val control = 1 :: 2 :: Nil
+ val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)).collect().map(_.x)
+ assert(control === result.toSeq)
+ }
+
+ test("kryo with parallelize") {
+ val control = 1 :: 2 :: Nil
+ val result = sc.parallelize(control.map(new ClassWithoutNoArgConstructor(_))).map(_.x).collect()
+ assert (control === result.toSeq)
+ }
+
+ test("kryo with parallelize for specialized tuples") {
+ assert (sc.parallelize( Array((1, 11), (2, 22), (3, 33)) ).count === 3)
+ }
+
+ test("kryo with parallelize for primitive arrays") {
+ assert (sc.parallelize( Array(1, 2, 3) ).count === 3)
+ }
+
+ test("kryo with collect for specialized tuples") {
+ assert (sc.parallelize( Array((1, 11), (2, 22), (3, 33)) ).collect().head === (1, 11))
+ }
+
+ test("kryo with reduce") {
+ val control = 1 :: 2 :: Nil
+ val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_))
+ .reduce((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x
+ assert(control.sum === result)
+ }
+
+ // TODO: this still doesn't work
+ ignore("kryo with fold") {
+ val control = 1 :: 2 :: Nil
+ val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_))
+ .fold(new ClassWithoutNoArgConstructor(10))((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x
+ assert(10 + control.sum === result)
+ }
+
+ override def beforeAll() {
+ System.setProperty("spark.serializer", "org.apache.spark.KryoSerializer")
+ System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName)
+ super.beforeAll()
+ }
+
+ override def afterAll() {
+ super.afterAll()
+ System.clearProperty("spark.kryo.registrator")
+ System.clearProperty("spark.serializer")
+ }
+}
+
+object KryoTest {
+ case class CaseClass(i: Int, s: String) {}
+
+ class ClassWithNoArgConstructor {
+ var x: Int = 0
+ override def equals(other: Any) = other match {
+ case c: ClassWithNoArgConstructor => x == c.x
+ case _ => false
+ }
+ }
+
+ class ClassWithoutNoArgConstructor(val x: Int) {
+ override def equals(other: Any) = other match {
+ case c: ClassWithoutNoArgConstructor => x == c.x
+ case _ => false
+ }
+ }
+
+ class MyRegistrator extends KryoRegistrator {
+ override def registerClasses(k: Kryo) {
+ k.register(classOf[CaseClass])
+ k.register(classOf[ClassWithNoArgConstructor])
+ k.register(classOf[ClassWithoutNoArgConstructor])
+ k.register(classOf[java.util.HashMap[_, _]])
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
new file mode 100644
index 0000000000..6ec124da9c
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.Suite
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.BeforeAndAfterAll
+
+import org.jboss.netty.logging.InternalLoggerFactory
+import org.jboss.netty.logging.Slf4JLoggerFactory
+
+/** Manages a local `sc` {@link SparkContext} variable, correctly stopping it after each test. */
+trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self: Suite =>
+
+ @transient var sc: SparkContext = _
+
+ override def beforeAll() {
+ InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory());
+ super.beforeAll()
+ }
+
+ override def afterEach() {
+ resetSparkContext()
+ super.afterEach()
+ }
+
+ def resetSparkContext() = {
+ if (sc != null) {
+ LocalSparkContext.stop(sc)
+ sc = null
+ }
+ }
+
+}
+
+object LocalSparkContext {
+ def stop(sc: SparkContext) {
+ sc.stop()
+ // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+ System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
+ }
+
+ /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */
+ def withSpark[T](sc: SparkContext)(f: SparkContext => T) = {
+ try {
+ f(sc)
+ } finally {
+ stop(sc)
+ }
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
new file mode 100644
index 0000000000..6013320eaa
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.FunSuite
+
+import akka.actor._
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.AkkaUtils
+
+class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
+
+ test("compressSize") {
+ assert(MapOutputTracker.compressSize(0L) === 0)
+ assert(MapOutputTracker.compressSize(1L) === 1)
+ assert(MapOutputTracker.compressSize(2L) === 8)
+ assert(MapOutputTracker.compressSize(10L) === 25)
+ assert((MapOutputTracker.compressSize(1000000L) & 0xFF) === 145)
+ assert((MapOutputTracker.compressSize(1000000000L) & 0xFF) === 218)
+ // This last size is bigger than we can encode in a byte, so check that we just return 255
+ assert((MapOutputTracker.compressSize(1000000000000000000L) & 0xFF) === 255)
+ }
+
+ test("decompressSize") {
+ assert(MapOutputTracker.decompressSize(0) === 0)
+ for (size <- Seq(2L, 10L, 100L, 50000L, 1000000L, 1000000000L)) {
+ val size2 = MapOutputTracker.decompressSize(MapOutputTracker.compressSize(size))
+ assert(size2 >= 0.99 * size && size2 <= 1.11 * size,
+ "size " + size + " decompressed to " + size2 + ", which is out of range")
+ }
+ }
+
+ test("master start and stop") {
+ val actorSystem = ActorSystem("test")
+ val tracker = new MapOutputTracker()
+ tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
+ tracker.stop()
+ }
+
+ test("master register and fetch") {
+ val actorSystem = ActorSystem("test")
+ val tracker = new MapOutputTracker()
+ tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
+ tracker.registerShuffle(10, 2)
+ val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+ val compressedSize10000 = MapOutputTracker.compressSize(10000L)
+ val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
+ val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
+ tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0),
+ Array(compressedSize1000, compressedSize10000)))
+ tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0),
+ Array(compressedSize10000, compressedSize1000)))
+ val statuses = tracker.getServerStatuses(10, 0)
+ assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000, 0), size1000),
+ (BlockManagerId("b", "hostB", 1000, 0), size10000)))
+ tracker.stop()
+ }
+
+ test("master register and unregister and fetch") {
+ val actorSystem = ActorSystem("test")
+ val tracker = new MapOutputTracker()
+ tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
+ tracker.registerShuffle(10, 2)
+ val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+ val compressedSize10000 = MapOutputTracker.compressSize(10000L)
+ val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
+ val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
+ tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0),
+ Array(compressedSize1000, compressedSize1000, compressedSize1000)))
+ tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0),
+ Array(compressedSize10000, compressedSize1000, compressedSize1000)))
+
+ // As if we had two simultaneous fetch failures
+ tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
+ tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
+
+ // The remaining reduce task might try to grab the output despite the shuffle failure;
+ // this should cause it to fail, and the scheduler will ignore the failure due to the
+ // stage already being aborted.
+ intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) }
+ }
+
+ test("remote fetch") {
+ val hostname = "localhost"
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0)
+ System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
+ System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+
+ val masterTracker = new MapOutputTracker()
+ masterTracker.trackerActor = actorSystem.actorOf(
+ Props(new MapOutputTrackerActor(masterTracker)), "MapOutputTracker")
+
+ val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0)
+ val slaveTracker = new MapOutputTracker()
+ slaveTracker.trackerActor = slaveSystem.actorFor(
+ "akka://spark@localhost:" + boundPort + "/user/MapOutputTracker")
+
+ masterTracker.registerShuffle(10, 1)
+ masterTracker.incrementEpoch()
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
+ intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+
+ val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+ val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
+ masterTracker.registerMapOutput(10, 0, new MapStatus(
+ BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000)))
+ masterTracker.incrementEpoch()
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
+ assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
+ Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
+
+ masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
+ masterTracker.incrementEpoch()
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
+ intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+
+ // failure should be cached
+ intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/PairRDDFunctionsSuite.scala
new file mode 100644
index 0000000000..f79752b34e
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/PairRDDFunctionsSuite.scala
@@ -0,0 +1,299 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashSet
+
+import org.scalatest.FunSuite
+
+import com.google.common.io.Files
+import org.apache.spark.SparkContext._
+
+
+class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
+ test("groupByKey") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
+ val groups = pairs.groupByKey().collect()
+ assert(groups.size === 2)
+ val valuesFor1 = groups.find(_._1 == 1).get._2
+ assert(valuesFor1.toList.sorted === List(1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("groupByKey with duplicates") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val groups = pairs.groupByKey().collect()
+ assert(groups.size === 2)
+ val valuesFor1 = groups.find(_._1 == 1).get._2
+ assert(valuesFor1.toList.sorted === List(1, 1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("groupByKey with negative key hash codes") {
+ val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1)))
+ val groups = pairs.groupByKey().collect()
+ assert(groups.size === 2)
+ val valuesForMinus1 = groups.find(_._1 == -1).get._2
+ assert(valuesForMinus1.toList.sorted === List(1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("groupByKey with many output partitions") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
+ val groups = pairs.groupByKey(10).collect()
+ assert(groups.size === 2)
+ val valuesFor1 = groups.find(_._1 == 1).get._2
+ assert(valuesFor1.toList.sorted === List(1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("reduceByKey") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val sums = pairs.reduceByKey(_+_).collect()
+ assert(sums.toSet === Set((1, 7), (2, 1)))
+ }
+
+ test("reduceByKey with collectAsMap") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val sums = pairs.reduceByKey(_+_).collectAsMap()
+ assert(sums.size === 2)
+ assert(sums(1) === 7)
+ assert(sums(2) === 1)
+ }
+
+ test("reduceByKey with many output partitons") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val sums = pairs.reduceByKey(_+_, 10).collect()
+ assert(sums.toSet === Set((1, 7), (2, 1)))
+ }
+
+ test("reduceByKey with partitioner") {
+ val p = new Partitioner() {
+ def numPartitions = 2
+ def getPartition(key: Any) = key.asInstanceOf[Int]
+ }
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p)
+ val sums = pairs.reduceByKey(_+_)
+ assert(sums.collect().toSet === Set((1, 4), (0, 1)))
+ assert(sums.partitioner === Some(p))
+ // count the dependencies to make sure there is only 1 ShuffledRDD
+ val deps = new HashSet[RDD[_]]()
+ def visit(r: RDD[_]) {
+ for (dep <- r.dependencies) {
+ deps += dep.rdd
+ visit(dep.rdd)
+ }
+ }
+ visit(sums)
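+ // reduceByKey reuses pairs' partitioner (asserted above), so the only shuffle in the lineage
+ // should be the one introduced by partitionBy.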
+ assert(deps.size === 2) // ShuffledRDD, ParallelCollection
+ }
+
+ test("join") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.join(rdd2).collect()
+ assert(joined.size === 4)
+ assert(joined.toSet === Set(
+ (1, (1, 'x')),
+ (1, (2, 'x')),
+ (2, (1, 'y')),
+ (2, (1, 'z'))
+ ))
+ }
+
+ test("join all-to-all") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y')))
+ val joined = rdd1.join(rdd2).collect()
+ assert(joined.size === 6)
+ assert(joined.toSet === Set(
+ (1, (1, 'x')),
+ (1, (1, 'y')),
+ (1, (2, 'x')),
+ (1, (2, 'y')),
+ (1, (3, 'x')),
+ (1, (3, 'y'))
+ ))
+ }
+
+ test("leftOuterJoin") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.leftOuterJoin(rdd2).collect()
+ assert(joined.size === 5)
+ assert(joined.toSet === Set(
+ (1, (1, Some('x'))),
+ (1, (2, Some('x'))),
+ (2, (1, Some('y'))),
+ (2, (1, Some('z'))),
+ (3, (1, None))
+ ))
+ }
+
+ test("rightOuterJoin") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.rightOuterJoin(rdd2).collect()
+ assert(joined.size === 5)
+ assert(joined.toSet === Set(
+ (1, (Some(1), 'x')),
+ (1, (Some(2), 'x')),
+ (2, (Some(1), 'y')),
+ (2, (Some(1), 'z')),
+ (4, (None, 'w'))
+ ))
+ }
+
+ test("join with no matches") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w')))
+ val joined = rdd1.join(rdd2).collect()
+ assert(joined.size === 0)
+ }
+
+ test("join with many output partitions") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.join(rdd2, 10).collect()
+ assert(joined.size === 4)
+ assert(joined.toSet === Set(
+ (1, (1, 'x')),
+ (1, (2, 'x')),
+ (2, (1, 'y')),
+ (2, (1, 'z'))
+ ))
+ }
+
+ test("groupWith") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.groupWith(rdd2).collect()
+ assert(joined.size === 4)
+ assert(joined.toSet === Set(
+ (1, (ArrayBuffer(1, 2), ArrayBuffer('x'))),
+ (2, (ArrayBuffer(1), ArrayBuffer('y', 'z'))),
+ (3, (ArrayBuffer(1), ArrayBuffer())),
+ (4, (ArrayBuffer(), ArrayBuffer('w')))
+ ))
+ }
+
+ test("zero-partition RDD") {
+ val emptyDir = Files.createTempDir()
+ val file = sc.textFile(emptyDir.getAbsolutePath)
+ assert(file.partitions.size == 0)
+ assert(file.collect().toList === Nil)
+ // Test that a shuffle on the file works, because this used to be a bug
+ assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
+ }
+
+ test("keys and values") {
+ val rdd = sc.parallelize(Array((1, "a"), (2, "b")))
+ assert(rdd.keys.collect().toList === List(1, 2))
+ assert(rdd.values.collect().toList === List("a", "b"))
+ }
+
+ test("default partitioner uses partition size") {
+ // specify 2000 partitions
+ val a = sc.makeRDD(Array(1, 2, 3, 4), 2000)
+ // do a map, which loses the partitioner
+ val b = a.map(a => (a, (a * 2).toString))
+ // then a group by, and check that all 2000 partitions are kept rather than falling back to a default
+ val c = b.groupByKey()
+ assert(c.partitions.size === 2000)
+ }
+
+ test("default partitioner uses largest partitioner") {
+ val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2)
+ val b = sc.makeRDD(Array((1, "a"), (2, "b")), 2000)
+ val c = a.join(b)
+ assert(c.partitions.size === 2000)
+ }
+
+ test("subtract") {
+ val a = sc.parallelize(Array(1, 2, 3), 2)
+ val b = sc.parallelize(Array(2, 3, 4), 4)
+ val c = a.subtract(b)
+ assert(c.collect().toSet === Set(1))
+ assert(c.partitions.size === a.partitions.size)
+ }
+
+ test("subtract with narrow dependency") {
+ // use a deterministic partitioner
+ val p = new Partitioner() {
+ def numPartitions = 5
+ def getPartition(key: Any) = key.asInstanceOf[Int]
+ }
+ // partitionBy so we have a narrow dependency
+ val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
+ // more partitions/no partitioner so a shuffle dependency
+ val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
+ val c = a.subtract(b)
+ assert(c.collect().toSet === Set((1, "a"), (3, "c")))
+ // Ideally we could keep the original partitioner...
+ assert(c.partitioner === None)
+ }
+
+ test("subtractByKey") {
+ val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2)
+ val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4)
+ val c = a.subtractByKey(b)
+ assert(c.collect().toSet === Set((1, "a"), (1, "a")))
+ assert(c.partitions.size === a.partitions.size)
+ }
+
+ test("subtractByKey with narrow dependency") {
+ // use a deterministic partitioner
+ val p = new Partitioner() {
+ def numPartitions = 5
+ def getPartition(key: Any) = key.asInstanceOf[Int]
+ }
+ // partitionBy so we have a narrow dependency
+ val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
+ // more partitions/no partitioner so a shuffle dependency
+ val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
+ val c = a.subtractByKey(b)
+ assert(c.collect().toSet === Set((1, "a"), (1, "a")))
+ assert(c.partitioner.get === p)
+ }
+
+ test("foldByKey") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val sums = pairs.foldByKey(0)(_+_).collect()
+ assert(sums.toSet === Set((1, 7), (2, 1)))
+ }
+
+ test("foldByKey with mutable result type") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val bufs = pairs.mapValues(v => ArrayBuffer(v)).cache()
+ // Fold the values using in-place mutation
+ val sums = bufs.foldByKey(new ArrayBuffer[Int])(_ ++= _).collect()
+ assert(sums.toSet === Set((1, ArrayBuffer(1, 2, 3, 1)), (2, ArrayBuffer(1))))
+ // Check that the mutable objects in the original RDD were not changed
+ assert(bufs.collect().toSet === Set(
+ (1, ArrayBuffer(1)),
+ (1, ArrayBuffer(2)),
+ (1, ArrayBuffer(3)),
+ (1, ArrayBuffer(1)),
+ (2, ArrayBuffer(1))))
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala
new file mode 100644
index 0000000000..adbe805916
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala
@@ -0,0 +1,28 @@
+package org.apache.spark
+
+import org.scalatest.FunSuite
+import org.apache.spark.SparkContext._
+import org.apache.spark.rdd.PartitionPruningRDD
+
+
+class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext {
+
+ test("Pruned Partitions inherit locality prefs correctly") {
+ class TestPartition(i: Int) extends Partition {
+ def index = i
+ }
+ val rdd = new RDD[Int](sc, Nil) {
+ override protected def getPartitions = {
+ Array[Partition](
+ new TestPartition(1),
+ new TestPartition(2),
+ new TestPartition(3))
+ }
+ def compute(split: Partition, context: TaskContext) = {Iterator()}
+ }
+ val prunedRDD = PartitionPruningRDD.create(rdd, { x => x == 2 })
+ val p = prunedRDD.partitions(0)
+ assert(p.index == 2)
+ assert(prunedRDD.partitions.length == 1)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
new file mode 100644
index 0000000000..7669cf6fb1
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.FunSuite
+import scala.collection.mutable.ArrayBuffer
+import SparkContext._
+import org.apache.spark.util.StatCounter
+import scala.math.abs
+
+class PartitioningSuite extends FunSuite with SharedSparkContext {
+
+ test("HashPartitioner equality") {
+ val p2 = new HashPartitioner(2)
+ val p4 = new HashPartitioner(4)
+ val anotherP4 = new HashPartitioner(4)
+ assert(p2 === p2)
+ assert(p4 === p4)
+ assert(p2 != p4)
+ assert(p4 != p2)
+ assert(p4 === anotherP4)
+ assert(anotherP4 === p4)
+ }
+
+ test("RangePartitioner equality") {
+ // Make an RDD where all the elements are the same so that the partition range bounds
+ // are deterministically all the same.
+ val rdd = sc.parallelize(Seq(1, 1, 1, 1)).map(x => (x, x))
+
+ val p2 = new RangePartitioner(2, rdd)
+ val p4 = new RangePartitioner(4, rdd)
+ val anotherP4 = new RangePartitioner(4, rdd)
+ val descendingP2 = new RangePartitioner(2, rdd, false)
+ val descendingP4 = new RangePartitioner(4, rdd, false)
+
+ assert(p2 === p2)
+ assert(p4 === p4)
+ assert(p2 != p4)
+ assert(p4 != p2)
+ assert(p4 === anotherP4)
+ assert(anotherP4 === p4)
+ assert(descendingP2 === descendingP2)
+ assert(descendingP4 === descendingP4)
+ assert(descendingP2 != descendingP4)
+ assert(descendingP4 != descendingP2)
+ assert(p2 != descendingP2)
+ assert(p4 != descendingP4)
+ assert(descendingP2 != p2)
+ assert(descendingP4 != p4)
+ }
+
+ test("HashPartitioner not equal to RangePartitioner") {
+ val rdd = sc.parallelize(1 to 10).map(x => (x, x))
+ val rangeP2 = new RangePartitioner(2, rdd)
+ val hashP2 = new HashPartitioner(2)
+ assert(rangeP2 === rangeP2)
+ assert(hashP2 === hashP2)
+ assert(hashP2 != rangeP2)
+ assert(rangeP2 != hashP2)
+ }
+
+ test("partitioner preservation") {
+ val rdd = sc.parallelize(1 to 10, 4).map(x => (x, x))
+
+ val grouped2 = rdd.groupByKey(2)
+ val grouped4 = rdd.groupByKey(4)
+ val reduced2 = rdd.reduceByKey(_ + _, 2)
+ val reduced4 = rdd.reduceByKey(_ + _, 4)
+
+ assert(rdd.partitioner === None)
+
+ assert(grouped2.partitioner === Some(new HashPartitioner(2)))
+ assert(grouped4.partitioner === Some(new HashPartitioner(4)))
+ assert(reduced2.partitioner === Some(new HashPartitioner(2)))
+ assert(reduced4.partitioner === Some(new HashPartitioner(4)))
+
+ assert(grouped2.groupByKey().partitioner === grouped2.partitioner)
+ assert(grouped2.groupByKey(3).partitioner != grouped2.partitioner)
+ assert(grouped2.groupByKey(2).partitioner === grouped2.partitioner)
+ assert(grouped4.groupByKey().partitioner === grouped4.partitioner)
+ assert(grouped4.groupByKey(3).partitioner != grouped4.partitioner)
+ assert(grouped4.groupByKey(4).partitioner === grouped4.partitioner)
+
+ assert(grouped2.join(grouped4).partitioner === grouped4.partitioner)
+ assert(grouped2.leftOuterJoin(grouped4).partitioner === grouped4.partitioner)
+ assert(grouped2.rightOuterJoin(grouped4).partitioner === grouped4.partitioner)
+ assert(grouped2.cogroup(grouped4).partitioner === grouped4.partitioner)
+
+ assert(grouped2.join(reduced2).partitioner === grouped2.partitioner)
+ assert(grouped2.leftOuterJoin(reduced2).partitioner === grouped2.partitioner)
+ assert(grouped2.rightOuterJoin(reduced2).partitioner === grouped2.partitioner)
+ assert(grouped2.cogroup(reduced2).partitioner === grouped2.partitioner)
+
+ assert(grouped2.map(_ => 1).partitioner === None)
+ assert(grouped2.mapValues(_ => 1).partitioner === grouped2.partitioner)
+ assert(grouped2.flatMapValues(_ => Seq(1)).partitioner === grouped2.partitioner)
+ assert(grouped2.filter(_._1 > 4).partitioner === grouped2.partitioner)
+ }
+
+ test("partitioning Java arrays should fail") {
+ val arrs: RDD[Array[Int]] = sc.parallelize(Array(1, 2, 3, 4), 2).map(x => Array(x))
+ val arrPairs: RDD[(Array[Int], Int)] =
+ sc.parallelize(Array(1, 2, 3, 4), 2).map(x => (Array(x), x))
+
+ assert(intercept[SparkException]{ arrs.distinct() }.getMessage.contains("array"))
+ // We can't catch all usages of arrays, since they might occur inside other collections:
+ //assert(fails { arrPairs.distinct() })
+ assert(intercept[SparkException]{ arrPairs.partitionBy(new HashPartitioner(2)) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.join(arrPairs) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.leftOuterJoin(arrPairs) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.rightOuterJoin(arrPairs) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.groupByKey() }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.countByKey() }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.countByKeyApprox(1) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.cogroup(arrPairs) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array"))
+ }
+
+ test("zero-length partitions should be correctly handled") {
+ // Create RDD with some consecutive empty partitions (including the "first" one)
+ val rdd: RDD[Double] = sc
+ .parallelize(Array(-1.0, -1.0, -1.0, -1.0, 2.0, 4.0, -1.0, -1.0), 8)
+ .filter(_ >= 0.0)
+
+ // Run the partitions, including the consecutive empty ones, through StatCounter
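+ // Only 2.0 and 4.0 survive the filter, so: sum = 6.0, mean = 3.0, variance = 1.0, stdev = 1.0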
+ val stats: StatCounter = rdd.stats()
+ assert(abs(6.0 - stats.sum) < 0.01)
+ assert(abs(6.0/2 - rdd.mean) < 0.01)
+ assert(abs(1.0 - rdd.variance) < 0.01)
+ assert(abs(1.0 - rdd.stdev) < 0.01)
+
+ // Add other tests here for classes that should be able to handle empty partitions correctly
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala
new file mode 100644
index 0000000000..2e851d892d
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.FunSuite
+import SparkContext._
+
+class PipedRDDSuite extends FunSuite with SharedSparkContext {
+
+ test("basic pipe") {
+ val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+
+ val piped = nums.pipe(Seq("cat"))
+
+ val c = piped.collect()
+ assert(c.size === 4)
+ assert(c(0) === "1")
+ assert(c(1) === "2")
+ assert(c(2) === "3")
+ assert(c(3) === "4")
+ }
+
+ test("advanced pipe") {
+ val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+ val bl = sc.broadcast(List("0"))
+
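+ // The third pipe argument prints the broadcast list plus a "\u0001" separator at the start of
+ // each partition, and the fourth prints each element with a trailing "_" (see the expected output below).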
+ val piped = nums.pipe(Seq("cat"),
+ Map[String, String](),
+ (f: String => Unit) => {bl.value.map(f(_));f("\u0001")},
+ (i:Int, f: String=> Unit) => f(i + "_"))
+
+ val c = piped.collect()
+
+ assert(c.size === 8)
+ assert(c(0) === "0")
+ assert(c(1) === "\u0001")
+ assert(c(2) === "1_")
+ assert(c(3) === "2_")
+ assert(c(4) === "0")
+ assert(c(5) === "\u0001")
+ assert(c(6) === "3_")
+ assert(c(7) === "4_")
+
+ val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2)
+ val d = nums1.groupBy(str=>str.split("\t")(0)).
+ pipe(Seq("cat"),
+ Map[String, String](),
+ (f: String => Unit) => {bl.value.map(f(_));f("\u0001")},
+ (i:Tuple2[String, Seq[String]], f: String=> Unit) => {for (e <- i._2){ f(e + "_")}}).collect()
+ assert(d.size === 8)
+ assert(d(0) === "0")
+ assert(d(1) === "\u0001")
+ assert(d(2) === "b\t2_")
+ assert(d(3) === "b\t4_")
+ assert(d(4) === "0")
+ assert(d(5) === "\u0001")
+ assert(d(6) === "a\t1_")
+ assert(d(7) === "a\t3_")
+ }
+
+ test("pipe with env variable") {
+ val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+ val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA"))
+ val c = piped.collect()
+ assert(c.size === 2)
+ assert(c(0) === "LALALA")
+ assert(c(1) === "LALALA")
+ }
+
+ test("pipe with non-zero exit status") {
+ val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+ val piped = nums.pipe(Seq("cat nonexistent_file", "2>", "/dev/null"))
+ intercept[SparkException] {
+ piped.collect()
+ }
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/RDDSuite.scala b/core/src/test/scala/org/apache/spark/RDDSuite.scala
new file mode 100644
index 0000000000..342ba8adb2
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/RDDSuite.scala
@@ -0,0 +1,389 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.collection.mutable.HashMap
+import org.scalatest.FunSuite
+import org.scalatest.concurrent.Timeouts._
+import org.scalatest.time.{Span, Millis}
+import org.apache.spark.SparkContext._
+import org.apache.spark.rdd._
+import scala.collection.parallel.mutable
+
+class RDDSuite extends FunSuite with SharedSparkContext {
+
+ test("basic operations") {
+ val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+ assert(nums.collect().toList === List(1, 2, 3, 4))
+ val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2)
+ assert(dups.distinct().count() === 4)
+ assert(dups.distinct.count === 4) // distinct and count should also work without parentheses
+ assert(dups.distinct.collect === dups.distinct().collect)
+ assert(dups.distinct(2).collect === dups.distinct().collect)
+ assert(nums.reduce(_ + _) === 10)
+ assert(nums.fold(0)(_ + _) === 10)
+ assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4"))
+ assert(nums.filter(_ > 2).collect().toList === List(3, 4))
+ assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4))
+ assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4))
+ assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4)))
+ assert(nums.collect({ case i if i >= 3 => i.toString }).collect().toList === List("3", "4"))
+ assert(nums.keyBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4)))
+ val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _)))
+ assert(partitionSums.collect().toList === List(3, 7))
+
+ val partitionSumsWithSplit = nums.mapPartitionsWithSplit {
+ case(split, iter) => Iterator((split, iter.reduceLeft(_ + _)))
+ }
+ assert(partitionSumsWithSplit.collect().toList === List((0, 3), (1, 7)))
+
+ val partitionSumsWithIndex = nums.mapPartitionsWithIndex {
+ case(split, iter) => Iterator((split, iter.reduceLeft(_ + _)))
+ }
+ assert(partitionSumsWithIndex.collect().toList === List((0, 3), (1, 7)))
+
+ intercept[UnsupportedOperationException] {
+ nums.filter(_ > 5).reduce(_ + _)
+ }
+ }
+
+ test("SparkContext.union") {
+ val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+ assert(sc.union(nums).collect().toList === List(1, 2, 3, 4))
+ assert(sc.union(nums, nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4))
+ assert(sc.union(Seq(nums)).collect().toList === List(1, 2, 3, 4))
+ assert(sc.union(Seq(nums, nums)).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4))
+ }
+
+ test("aggregate") {
+ val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3)))
+ type StringMap = HashMap[String, Int]
+ val emptyMap = new StringMap {
+ override def default(key: String): Int = 0
+ }
+ val mergeElement: (StringMap, (String, Int)) => StringMap = (map, pair) => {
+ map(pair._1) += pair._2
+ map
+ }
+ val mergeMaps: (StringMap, StringMap) => StringMap = (map1, map2) => {
+ for ((key, value) <- map2) {
+ map1(key) += value
+ }
+ map1
+ }
+ val result = pairs.aggregate(emptyMap)(mergeElement, mergeMaps)
+ assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
+ }
+
+ test("basic caching") {
+ val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
+ assert(rdd.collect().toList === List(1, 2, 3, 4))
+ assert(rdd.collect().toList === List(1, 2, 3, 4))
+ assert(rdd.collect().toList === List(1, 2, 3, 4))
+ }
+
+ test("caching with failures") {
+ val onlySplit = new Partition { override def index: Int = 0 }
+ var shouldFail = true
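+ // The first collect() should fail while computing the cached partition; after flipping
+ // shouldFail, the recompute should succeed and be cached normally.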
+ val rdd = new RDD[Int](sc, Nil) {
+ override def getPartitions: Array[Partition] = Array(onlySplit)
+ override val getDependencies = List[Dependency[_]]()
+ override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
+ if (shouldFail) {
+ throw new Exception("injected failure")
+ } else {
+ return Array(1, 2, 3, 4).iterator
+ }
+ }
+ }.cache()
+ val thrown = intercept[Exception]{
+ rdd.collect()
+ }
+ assert(thrown.getMessage.contains("injected failure"))
+ shouldFail = false
+ assert(rdd.collect().toList === List(1, 2, 3, 4))
+ }
+
+ test("empty RDD") {
+ val empty = new EmptyRDD[Int](sc)
+ assert(empty.count === 0)
+ assert(empty.collect().size === 0)
+
+ val thrown = intercept[UnsupportedOperationException]{
+ empty.reduce(_+_)
+ }
+ assert(thrown.getMessage.contains("empty"))
+
+ val emptyKv = new EmptyRDD[(Int, Int)](sc)
+ val rdd = sc.parallelize(1 to 2, 2).map(x => (x, x))
+ assert(rdd.join(emptyKv).collect().size === 0)
+ assert(rdd.rightOuterJoin(emptyKv).collect().size === 0)
+ assert(rdd.leftOuterJoin(emptyKv).collect().size === 2)
+ assert(rdd.cogroup(emptyKv).collect().size === 2)
+ assert(rdd.union(emptyKv).collect().size === 2)
+ }
+
+ test("cogrouped RDDs") {
+ val data = sc.parallelize(1 to 10, 10)
+
+ val coalesced1 = data.coalesce(2)
+ assert(coalesced1.collect().toList === (1 to 10).toList)
+ assert(coalesced1.glom().collect().map(_.toList).toList ===
+ List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10)))
+
+ // Check that the narrow dependency is also specified correctly
+ assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList ===
+ List(0, 1, 2, 3, 4))
+ assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList ===
+ List(5, 6, 7, 8, 9))
+
+ val coalesced2 = data.coalesce(3)
+ assert(coalesced2.collect().toList === (1 to 10).toList)
+ assert(coalesced2.glom().collect().map(_.toList).toList ===
+ List(List(1, 2, 3), List(4, 5, 6), List(7, 8, 9, 10)))
+
+ val coalesced3 = data.coalesce(10)
+ assert(coalesced3.collect().toList === (1 to 10).toList)
+ assert(coalesced3.glom().collect().map(_.toList).toList ===
+ (1 to 10).map(x => List(x)).toList)
+
+ // If we try to coalesce into more partitions than the original RDD, it should just
+ // keep the original number of partitions.
+ val coalesced4 = data.coalesce(20)
+ assert(coalesced4.collect().toList === (1 to 10).toList)
+ assert(coalesced4.glom().collect().map(_.toList).toList ===
+ (1 to 10).map(x => List(x)).toList)
+
+ // we can optionally shuffle to keep the upstream parallelism
+ val coalesced5 = data.coalesce(1, shuffle = true)
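+ // With shuffle = true a ShuffledRDD should appear in the lineage between data and the
+ // coalesced result, which is what the cast below checks.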
+ assert(coalesced5.dependencies.head.rdd.dependencies.head.rdd.asInstanceOf[ShuffledRDD[_, _, _]] !=
+ null)
+ }
+
+ test("cogrouped RDDs with locality") {
+ val data3 = sc.makeRDD(List((1,List("a","c")), (2,List("a","b","c")), (3,List("b"))))
+ val coal3 = data3.coalesce(3)
+ val list3 = coal3.partitions.map(p => p.asInstanceOf[CoalescedRDDPartition].preferredLocation)
+ assert(list3.sorted === Array("a","b","c"), "Locality preferences are dropped")
+
+ // RDD with locality preferences spread (non-randomly) over 6 machines, m0 through m5
+ val data = sc.makeRDD((1 to 9).map(i => (i, (i to (i+2)).map{ j => "m" + (j%6)})))
+ val coalesced1 = data.coalesce(3)
+ assert(coalesced1.collect().toList.sorted === (1 to 9).toList, "Data got *lost* in coalescing")
+
+ val splits = coalesced1.glom().collect().map(_.toList).toList
+ assert(splits.length === 3, "Supposed to coalesce to 3 but got " + splits.length)
+
+ assert(splits.forall(_.length >= 1) === true, "Some partitions were empty")
+
+ // If we try to coalesce into more partitions than the original RDD, it should just
+ // keep the original number of partitions.
+ val coalesced4 = data.coalesce(20)
+ val listOfLists = coalesced4.glom().collect().map(_.toList).toList
+ val sortedList = listOfLists.sortWith{ (x, y) => !x.isEmpty && (y.isEmpty || (x(0) < y(0))) }
+ assert( sortedList === (1 to 9).
+ map{x => List(x)}.toList, "Tried coalescing 9 partitions to 20 but didn't get 9 back")
+ }
+
+ test("cogrouped RDDs with locality, large scale (10K partitions)") {
+ // large scale experiment
+ import collection.mutable
+ val rnd = scala.util.Random
+ val partitions = 10000
+ val numMachines = 50
+ val machines = mutable.ListBuffer[String]()
+ (1 to numMachines).foreach(machines += "m"+_)
+
+ val blocks = (1 to partitions).map(i =>
+ { (i, Array.fill(3)(machines(rnd.nextInt(machines.size))).toList) } )
+
+ val data2 = sc.makeRDD(blocks)
+ val coalesced2 = data2.coalesce(numMachines*2)
+
+ // test that you get over 90% locality in each group
+ val minLocality = coalesced2.partitions
+ .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction)
+ .foldLeft(1.0)((perc, loc) => math.min(perc,loc))
+ assert(minLocality >= 0.90, "Expected 90% locality but got " + (minLocality*100.0).toInt + "%")
+
+ // test that the groups are load balanced with 100 +/- 20 elements in each
+ val maxImbalance = coalesced2.partitions
+ .map(part => part.asInstanceOf[CoalescedRDDPartition].parents.size)
+ .foldLeft(0)((dev, curr) => math.max(math.abs(100-curr),dev))
+ assert(maxImbalance <= 20, "Expected 100 +/- 20 per partition, but got " + maxImbalance)
+
+ val data3 = sc.makeRDD(blocks).map(i => i*2) // derived RDD to test *current* pref locs
+ val coalesced3 = data3.coalesce(numMachines*2)
+ val minLocality2 = coalesced3.partitions
+ .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction)
+ .foldLeft(1.0)((perc, loc) => math.min(perc,loc))
+ assert(minLocality2 >= 0.90, "Expected 90% locality for derived RDD but got " +
+ (minLocality2*100.0).toInt + "%")
+ }
+
+ test("zipped RDDs") {
+ val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+ val zipped = nums.zip(nums.map(_ + 1.0))
+ assert(zipped.glom().map(_.toList).collect().toList ===
+ List(List((1, 2.0), (2, 3.0)), List((3, 4.0), (4, 5.0))))
+
+ intercept[IllegalArgumentException] {
+ nums.zip(sc.parallelize(1 to 4, 1)).collect()
+ }
+ }
+
+ test("partition pruning") {
+ val data = sc.parallelize(1 to 10, 10)
+ // Note that split number starts from 0, so > 8 means only 10th partition left.
+ val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8)
+ assert(prunedRdd.partitions.size === 1)
+ val prunedData = prunedRdd.collect()
+ assert(prunedData.size === 1)
+ assert(prunedData(0) === 10)
+ }
+
+ test("mapWith") {
+ import java.util.Random
+ val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
+ val randoms = ones.mapWith(
+ (index: Int) => new Random(index + 42))
+ {(t: Int, prng: Random) => prng.nextDouble * t}.collect()
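+ // Partition 0 seeds Random(42) and partition 1 seeds Random(43); the values below recompute
+ // the third draw from each generator to match elements 2 and 5.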
+ val prn42_3 = {
+ val prng42 = new Random(42)
+ prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble()
+ }
+ val prn43_3 = {
+ val prng43 = new Random(43)
+ prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble()
+ }
+ assert(randoms(2) === prn42_3)
+ assert(randoms(5) === prn43_3)
+ }
+
+ test("flatMapWith") {
+ import java.util.Random
+ val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
+ val randoms = ones.flatMapWith(
+ (index: Int) => new Random(index + 42))
+ {(t: Int, prng: Random) =>
+ val random = prng.nextDouble()
+ Seq(random * t, random * t * 10)}.
+ collect()
+ val prn42_3 = {
+ val prng42 = new Random(42)
+ prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble()
+ }
+ val prn43_3 = {
+ val prng43 = new Random(43)
+ prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble()
+ }
+ assert(randoms(5) === prn42_3 * 10)
+ assert(randoms(11) === prn43_3 * 10)
+ }
+
+ test("filterWith") {
+ import java.util.Random
+ val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2)
+ val sample = ints.filterWith(
+ (index: Int) => new Random(index + 42))
+ {(t: Int, prng: Random) => prng.nextInt(3) == 0}.
+ collect()
+ val checkSample = {
+ val prng42 = new Random(42)
+ val prng43 = new Random(43)
+ Array(1, 2, 3, 4, 5, 6).filter{i =>
+ if (i < 4) 0 == prng42.nextInt(3)
+ else 0 == prng43.nextInt(3)}
+ }
+ assert(sample.size === checkSample.size)
+ for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
+ }
+
+ test("top with predefined ordering") {
+ val nums = Array.range(1, 100000)
+ val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2)
+ val topK = ints.top(5)
+ assert(topK.size === 5)
+ assert(topK === nums.reverse.take(5))
+ }
+
+ test("top with custom ordering") {
+ val words = Vector("a", "b", "c", "d")
+ implicit val ord = implicitly[Ordering[String]].reverse
+ val rdd = sc.makeRDD(words, 2)
+ val topK = rdd.top(2)
+ assert(topK.size === 2)
+ assert(topK.sorted === Array("b", "a"))
+ }
+
+ test("takeOrdered with predefined ordering") {
+ val nums = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
+ val rdd = sc.makeRDD(nums, 2)
+ val sortedLowerK = rdd.takeOrdered(5)
+ assert(sortedLowerK.size === 5)
+ assert(sortedLowerK === Array(1, 2, 3, 4, 5))
+ }
+
+ test("takeOrdered with custom ordering") {
+ val nums = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
+ implicit val ord = implicitly[Ordering[Int]].reverse
+ val rdd = sc.makeRDD(nums, 2)
+ val sortedTopK = rdd.takeOrdered(5)
+ assert(sortedTopK.size === 5)
+ assert(sortedTopK === Array(10, 9, 8, 7, 6))
+ assert(sortedTopK === nums.sorted(ord).take(5))
+ }
+
+ test("takeSample") {
+ val data = sc.parallelize(1 to 100, 2)
+ for (seed <- 1 to 5) {
+ val sample = data.takeSample(withReplacement=false, 20, seed)
+ assert(sample.size === 20) // Got exactly 20 elements
+ assert(sample.toSet.size === 20) // Elements are distinct
+ assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ }
+ for (seed <- 1 to 5) {
+ val sample = data.takeSample(withReplacement=false, 200, seed)
+ assert(sample.size === 100) // Got only 100 elements
+ assert(sample.toSet.size === 100) // Elements are distinct
+ assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ }
+ for (seed <- 1 to 5) {
+ val sample = data.takeSample(withReplacement=true, 20, seed)
+ assert(sample.size === 20) // Got exactly 20 elements
+ assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ }
+ for (seed <- 1 to 5) {
+ val sample = data.takeSample(withReplacement=true, 100, seed)
+ assert(sample.size === 100) // Got exactly 100 elements
+ // Chance of getting all distinct elements is astronomically low, so test we got < 100
+ assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
+ }
+ for (seed <- 1 to 5) {
+ val sample = data.takeSample(withReplacement=true, 200, seed)
+ assert(sample.size === 200) // Got exactly 200 elements
+ // Chance of getting all distinct elements is still quite low, so test we got < 100
+ assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
+ }
+ }
+
+ test("runJob on an invalid partition") {
+ intercept[IllegalArgumentException] {
+ sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2), false)
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
new file mode 100644
index 0000000000..97cbca09bf
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.Suite
+import org.scalatest.BeforeAndAfterAll
+
+/** Shares a local `SparkContext` between all tests in a suite and closes it at the end */
+trait SharedSparkContext extends BeforeAndAfterAll { self: Suite =>
+
+ @transient private var _sc: SparkContext = _
+
+ def sc: SparkContext = _sc
+
+ override def beforeAll() {
+ _sc = new SparkContext("local", "test")
+ super.beforeAll()
+ }
+
+ override def afterAll() {
+ if (_sc != null) {
+ LocalSparkContext.stop(_sc)
+ _sc = null
+ }
+ super.afterAll()
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
new file mode 100644
index 0000000000..e121b162ad
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.BeforeAndAfterAll
+
+
+class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll {
+
+ // This test suite should run all tests in ShuffleSuite with Netty shuffle mode.
+
+ override def beforeAll(configMap: Map[String, Any]) {
+ System.setProperty("spark.shuffle.use.netty", "true")
+ }
+
+ override def afterAll(configMap: Map[String, Any]) {
+ System.setProperty("spark.shuffle.use.netty", "false")
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
new file mode 100644
index 0000000000..357175e89e
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.ShuffleSuite.NonJavaSerializableClass
+import org.apache.spark.rdd.{SubtractedRDD, CoGroupedRDD, OrderedRDDFunctions, ShuffledRDD}
+import org.apache.spark.util.MutablePair
+
+
+class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
+ test("groupByKey without compression") {
+ try {
+ System.setProperty("spark.shuffle.compress", "false")
+ sc = new SparkContext("local", "test")
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 4)
+ val groups = pairs.groupByKey(4).collect()
+ assert(groups.size === 2)
+ val valuesFor1 = groups.find(_._1 == 1).get._2
+ assert(valuesFor1.toList.sorted === List(1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ } finally {
+ System.setProperty("spark.shuffle.compress", "true")
+ }
+ }
+
+ test("shuffle non-zero block size") {
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
+ val NUM_BLOCKS = 3
+
+ val a = sc.parallelize(1 to 10, 2)
+ val b = a.map { x =>
+ (x, new NonJavaSerializableClass(x * 2))
+ }
+ // If the Kryo serializer is not used correctly, the shuffle would fail because the
+ // default Java serializer cannot handle the non serializable class.
+ val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)](
+ b, new HashPartitioner(NUM_BLOCKS)).setSerializer(classOf[KryoSerializer].getName)
+ val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+
+ assert(c.count === 10)
+
+ // All blocks must have non-zero size
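+ // getServerStatuses returns (BlockManagerId, size) pairs for the given reduce id, so s._2 is the block size.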
+ (0 until NUM_BLOCKS).foreach { id =>
+ val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
+ assert(statuses.forall(s => s._2 > 0))
+ }
+ }
+
+ test("shuffle serializer") {
+ // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
+ val a = sc.parallelize(1 to 10, 2)
+ val b = a.map { x =>
+ (x, new NonJavaSerializableClass(x * 2))
+ }
+ // If the Kryo serializer is not used correctly, the shuffle would fail because the
+ // default Java serializer cannot handle the non serializable class.
+ val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)](
+ b, new HashPartitioner(3)).setSerializer(classOf[KryoSerializer].getName)
+ assert(c.count === 10)
+ }
+
+ test("zero sized blocks") {
+ // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
+
+ // 10 partitions from 4 keys
+ val NUM_BLOCKS = 10
+ val a = sc.parallelize(1 to 4, NUM_BLOCKS)
+ val b = a.map(x => (x, x*2))
+
+ // NOTE: The default Java serializer doesn't create zero-sized blocks.
+ // So, use Kryo
+ val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10))
+ .setSerializer(classOf[KryoSerializer].getName)
+
+ val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+ assert(c.count === 4)
+
+ val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
+ val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
+ statuses.map(x => x._2)
+ }
+ val nonEmptyBlocks = blockSizes.filter(x => x > 0)
+
+ // We should have at most 4 non-zero sized partitions
+ assert(nonEmptyBlocks.size <= 4)
+ }
+
+ test("zero sized blocks without kryo") {
+ // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
+
+ // 10 partitions from 4 keys
+ val NUM_BLOCKS = 10
+ val a = sc.parallelize(1 to 4, NUM_BLOCKS)
+ val b = a.map(x => (x, x*2))
+
+ // NOTE: The default Java serializer should create zero-sized blocks
+ val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10))
+
+ val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+ assert(c.count === 4)
+
+ val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
+ val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
+ statuses.map(x => x._2)
+ }
+ val nonEmptyBlocks = blockSizes.filter(x => x > 0)
+
+ // We should have at most 4 non-zero sized partitions
+ assert(nonEmptyBlocks.size <= 4)
+ }
+
+ test("shuffle using mutable pairs") {
+ // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
+ def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2)
+ val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1))
+ val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2)
+ val results = new ShuffledRDD[Int, Int, MutablePair[Int, Int]](pairs, new HashPartitioner(2))
+ .collect()
+
+ data.foreach { pair => results should contain (pair) }
+ }
+
+ test("sorting using mutable pairs") {
+ // This is not in SortingSuite because of the local cluster setup.
+ // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
+ def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2)
+ val data = Array(p(1, 11), p(3, 33), p(100, 100), p(2, 22))
+ val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2)
+ val results = new OrderedRDDFunctions[Int, Int, MutablePair[Int, Int]](pairs)
+ .sortByKey().collect()
+ results(0) should be (p(1, 11))
+ results(1) should be (p(2, 22))
+ results(2) should be (p(3, 33))
+ results(3) should be (p(100, 100))
+ }
+
+ test("cogroup using mutable pairs") {
+ // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
+ def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2)
+ val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1))
+ val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"), p(3, "3"))
+ val pairs1: RDD[MutablePair[Int, Int]] = sc.parallelize(data1, 2)
+ val pairs2: RDD[MutablePair[Int, String]] = sc.parallelize(data2, 2)
+ val results = new CoGroupedRDD[Int](Seq(pairs1, pairs2), new HashPartitioner(2)).collectAsMap()
+
+ assert(results(1)(0).length === 3)
+ assert(results(1)(0).contains(1))
+ assert(results(1)(0).contains(2))
+ assert(results(1)(0).contains(3))
+ assert(results(1)(1).length === 2)
+ assert(results(1)(1).contains("11"))
+ assert(results(1)(1).contains("12"))
+ assert(results(2)(0).length === 1)
+ assert(results(2)(0).contains(1))
+ assert(results(2)(1).length === 1)
+ assert(results(2)(1).contains("22"))
+ assert(results(3)(0).length === 0)
+ assert(results(3)(1).contains("3"))
+ }
+
+ test("subtract mutable pairs") {
+ // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
+ def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2)
+ val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1), p(3, 33))
+ val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"))
+ val pairs1: RDD[MutablePair[Int, Int]] = sc.parallelize(data1, 2)
+ val pairs2: RDD[MutablePair[Int, String]] = sc.parallelize(data2, 2)
+ val results = new SubtractedRDD(pairs1, pairs2, new HashPartitioner(2)).collect()
+ results should have length (1)
+ // the subtracted RDD returns its results as Tuple2s
+ results(0) should be ((3, 33))
+ }
+}
+
+object ShuffleSuite {
+
+ def mergeCombineException(x: Int, y: Int): Int = {
+ throw new SparkException("Exception for map-side combine.")
+ x + y
+ }
+
+ class NonJavaSerializableClass(val value: Int)
+}
diff --git a/core/src/test/scala/org/apache/spark/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/SizeEstimatorSuite.scala
new file mode 100644
index 0000000000..214ac74898
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/SizeEstimatorSuite.scala
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.PrivateMethodTester
+
+class DummyClass1 {}
+
+class DummyClass2 {
+ val x: Int = 0
+}
+
+class DummyClass3 {
+ val x: Int = 0
+ val y: Double = 0.0
+}
+
+class DummyClass4(val d: DummyClass3) {
+ val x: Int = 0
+}
+
+object DummyString {
+ def apply(str: String) : DummyString = new DummyString(str.toArray)
+}
+class DummyString(val arr: Array[Char]) {
+ override val hashCode: Int = 0
+ // JDK-7 has an extra hash32 field http://hg.openjdk.java.net/jdk7u/jdk7u6/jdk/rev/11987e85555f
+ @transient val hash32: Int = 0
+}
+
+class SizeEstimatorSuite
+ extends FunSuite with BeforeAndAfterAll with PrivateMethodTester {
+
+ var oldArch: String = _
+ var oldOops: String = _
+
+ override def beforeAll() {
+ // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
+ oldArch = System.setProperty("os.arch", "amd64")
+ oldOops = System.setProperty("spark.test.useCompressedOops", "true")
+ }
+
+ override def afterAll() {
+ resetOrClear("os.arch", oldArch)
+ resetOrClear("spark.test.useCompressedOops", oldOops)
+ }
+
+ test("simple classes") {
+ assert(SizeEstimator.estimate(new DummyClass1) === 16)
+ assert(SizeEstimator.estimate(new DummyClass2) === 16)
+ assert(SizeEstimator.estimate(new DummyClass3) === 24)
+ assert(SizeEstimator.estimate(new DummyClass4(null)) === 24)
+ assert(SizeEstimator.estimate(new DummyClass4(new DummyClass3)) === 48)
+ }
+
+ // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors
+ // (Sun vs IBM). Use a DummyString class to make tests deterministic.
+ test("strings") {
+ assert(SizeEstimator.estimate(DummyString("")) === 40)
+ assert(SizeEstimator.estimate(DummyString("a")) === 48)
+ assert(SizeEstimator.estimate(DummyString("ab")) === 48)
+ assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56)
+ }
+
+ test("primitive arrays") {
+ assert(SizeEstimator.estimate(new Array[Byte](10)) === 32)
+ assert(SizeEstimator.estimate(new Array[Char](10)) === 40)
+ assert(SizeEstimator.estimate(new Array[Short](10)) === 40)
+ assert(SizeEstimator.estimate(new Array[Int](10)) === 56)
+ assert(SizeEstimator.estimate(new Array[Long](10)) === 96)
+ assert(SizeEstimator.estimate(new Array[Float](10)) === 56)
+ assert(SizeEstimator.estimate(new Array[Double](10)) === 96)
+ assert(SizeEstimator.estimate(new Array[Int](1000)) === 4016)
+ assert(SizeEstimator.estimate(new Array[Long](1000)) === 8016)
+ }
+
+ test("object arrays") {
+ // Arrays containing nulls should just have one pointer per element
+ assert(SizeEstimator.estimate(new Array[String](10)) === 56)
+ assert(SizeEstimator.estimate(new Array[AnyRef](10)) === 56)
+
+ // For object arrays with non-null elements, each object should take one pointer plus
+ // however many bytes that class takes. (Note that Array.fill calls the code in its
+ // second parameter separately for each object, so we get distinct objects.)
+ assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass1)) === 216)
+ assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass2)) === 216)
+ assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass3)) === 296)
+ assert(SizeEstimator.estimate(Array(new DummyClass1, new DummyClass2)) === 56)
+
+ // Past size 100, the estimator only samples 100 elements, but we should still get the right size.
+ assert(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3)) === 28016)
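+ // (Roughly 4016 bytes for the 1000-element reference array plus 1000 * 24 bytes per DummyClass3.)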
+
+ // If an array contains the *same* element many times, we should only count it once.
+ val d1 = new DummyClass1
+ assert(SizeEstimator.estimate(Array.fill(10)(d1)) === 72) // 10 pointers plus one 16-byte object
+ assert(SizeEstimator.estimate(Array.fill(100)(d1)) === 432) // 100 pointers plus one 16-byte object
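+ // (Each figure includes the 16-byte array header plus 4 bytes per compressed-oops reference,
+ // as in the null-element arrays above.)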
+
+ // Same thing with huge array containing the same element many times. Note that this won't
+ // return exactly 4032 because it can't tell that *all* the elements will equal the first
+ // one it samples, but it should be close to that.
+
+ // TODO: If we sample 100 elements, this should always be 4176 ?
+ val estimatedSize = SizeEstimator.estimate(Array.fill(1000)(d1))
+ assert(estimatedSize >= 4000, "Estimated size " + estimatedSize + " should be more than 4000")
+ assert(estimatedSize <= 4200, "Estimated size " + estimatedSize + " should be less than 4200")
+ }
+
+ test("32-bit arch") {
+ val arch = System.setProperty("os.arch", "x86")
+
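+ // Re-run SizeEstimator's private initialize() so it picks up the overridden os.arch before estimating.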
+ val initialize = PrivateMethod[Unit]('initialize)
+ SizeEstimator invokePrivate initialize()
+
+ assert(SizeEstimator.estimate(DummyString("")) === 40)
+ assert(SizeEstimator.estimate(DummyString("a")) === 48)
+ assert(SizeEstimator.estimate(DummyString("ab")) === 48)
+ assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56)
+
+ resetOrClear("os.arch", arch)
+ }
+
+ // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors
+ // (Sun vs IBM). Use a DummyString class to make tests deterministic.
+ test("64-bit arch with no compressed oops") {
+ val arch = System.setProperty("os.arch", "amd64")
+ val oops = System.setProperty("spark.test.useCompressedOops", "false")
+
+ val initialize = PrivateMethod[Unit]('initialize)
+ SizeEstimator invokePrivate initialize()
+
+ assert(SizeEstimator.estimate(DummyString("")) === 56)
+ assert(SizeEstimator.estimate(DummyString("a")) === 64)
+ assert(SizeEstimator.estimate(DummyString("ab")) === 64)
+ assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 72)
+
+ resetOrClear("os.arch", arch)
+ resetOrClear("spark.test.useCompressedOops", oops)
+ }
+
+ def resetOrClear(prop: String, oldValue: String) {
+ if (oldValue != null) {
+ System.setProperty(prop, oldValue)
+ } else {
+ System.clearProperty(prop)
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/SortingSuite.scala b/core/src/test/scala/org/apache/spark/SortingSuite.scala
new file mode 100644
index 0000000000..f4fa9511dd
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/SortingSuite.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.ShouldMatchers
+import SparkContext._
+
+class SortingSuite extends FunSuite with SharedSparkContext with ShouldMatchers with Logging {
+
+ test("sortByKey") {
+ val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)), 2)
+ assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0)))
+ }
+
+ test("large array") {
+ val rand = new scala.util.Random()
+ val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
+ val pairs = sc.parallelize(pairArr, 2)
+ val sorted = pairs.sortByKey()
+ assert(sorted.partitions.size === 2)
+ assert(sorted.collect() === pairArr.sortBy(_._1))
+ }
+
+ test("large array with one split") {
+ val rand = new scala.util.Random()
+ val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
+ val pairs = sc.parallelize(pairArr, 2)
+ val sorted = pairs.sortByKey(true, 1)
+ assert(sorted.partitions.size === 1)
+ assert(sorted.collect() === pairArr.sortBy(_._1))
+ }
+
+ test("large array with many partitions") {
+ val rand = new scala.util.Random()
+ val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
+ val pairs = sc.parallelize(pairArr, 2)
+ val sorted = pairs.sortByKey(true, 20)
+ assert(sorted.partitions.size === 20)
+ assert(sorted.collect() === pairArr.sortBy(_._1))
+ }
+
+ test("sort descending") {
+ val rand = new scala.util.Random()
+ val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
+ val pairs = sc.parallelize(pairArr, 2)
+ assert(pairs.sortByKey(false).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
+ }
+
+ test("sort descending with one split") {
+ val rand = new scala.util.Random()
+ val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
+ val pairs = sc.parallelize(pairArr, 1)
+ assert(pairs.sortByKey(false, 1).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
+ }
+
+ test("sort descending with many partitions") {
+ val rand = new scala.util.Random()
+ val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
+ val pairs = sc.parallelize(pairArr, 2)
+ assert(pairs.sortByKey(false, 20).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
+ }
+
+ test("more partitions than elements") {
+ val rand = new scala.util.Random()
+ val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) }
+ val pairs = sc.parallelize(pairArr, 30)
+ assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
+ }
+
+ test("empty RDD") {
+ val pairArr = new Array[(Int, Int)](0)
+ val pairs = sc.parallelize(pairArr, 2)
+ assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
+ }
+
+ test("partition balancing") {
+ val pairArr = (1 to 1000).map(x => (x, x)).toArray
+ val sorted = sc.parallelize(pairArr, 4).sortByKey()
+ assert(sorted.collect() === pairArr.sortBy(_._1))
+ val partitions = sorted.collectPartitions()
+ logInfo("Partition lengths: " + partitions.map(_.length).mkString(", "))
+ partitions(0).length should be > 180
+ partitions(1).length should be > 180
+ partitions(2).length should be > 180
+ partitions(3).length should be > 180
+ partitions(0).last should be < partitions(1).head
+ partitions(1).last should be < partitions(2).head
+ partitions(2).last should be < partitions(3).head
+ }
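+
+  // Illustration (hypothetical helper, not part of this suite): the four length checks above could
+  // be expressed as a generic balance assertion, e.g.
+  //
+  //   def assertRoughlyBalanced(parts: Array[Array[(Int, Int)]], total: Int, slack: Double) {
+  //     val expected = total / parts.length
+  //     parts.foreach(p => p.length should be > (expected * (1 - slack)).toInt)
+  //   }
+  //
+  // With total = 1000, four partitions and slack = 0.28 this reproduces the "> 180" bounds, which
+  // leave headroom for the sampled range boundaries that sortByKey uses to split the key space.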
+
+ test("partition balancing for descending sort") {
+ val pairArr = (1 to 1000).map(x => (x, x)).toArray
+ val sorted = sc.parallelize(pairArr, 4).sortByKey(false)
+ assert(sorted.collect() === pairArr.sortBy(_._1).reverse)
+ val partitions = sorted.collectPartitions()
+ logInfo("partition lengths: " + partitions.map(_.length).mkString(", "))
+ partitions(0).length should be > 180
+ partitions(1).length should be > 180
+ partitions(2).length should be > 180
+ partitions(3).length should be > 180
+ partitions(0).last should be > partitions(1).head
+ partitions(1).last should be > partitions(2).head
+ partitions(2).last should be > partitions(3).head
+ }
+}
+
diff --git a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala
new file mode 100644
index 0000000000..939fe51801
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.FunSuite
+import org.apache.spark.SparkContext._
+
+class SparkContextInfoSuite extends FunSuite with LocalSparkContext {
+ test("getPersistentRDDs only returns RDDs that are marked as cached") {
+ sc = new SparkContext("local", "test")
+ assert(sc.getPersistentRDDs.isEmpty === true)
+
+ val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
+ assert(sc.getPersistentRDDs.isEmpty === true)
+
+ rdd.cache()
+ assert(sc.getPersistentRDDs.size === 1)
+ assert(sc.getPersistentRDDs.values.head === rdd)
+ }
+
+ test("getPersistentRDDs returns an immutable map") {
+ sc = new SparkContext("local", "test")
+ val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
+
+ val myRdds = sc.getPersistentRDDs
+ assert(myRdds.size === 1)
+ assert(myRdds.values.head === rdd1)
+
+ val rdd2 = sc.makeRDD(Array(5, 6, 7, 8), 1).cache()
+
+ // getPersistentRDDs should have 2 RDDs, but myRdds should not change
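+    // (presumably because getPersistentRDDs hands back a snapshot rather than a live view)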
+ assert(sc.getPersistentRDDs.size === 2)
+ assert(myRdds.size === 1)
+ }
+
+ test("getRDDStorageInfo only reports on RDDs that actually persist data") {
+ sc = new SparkContext("local", "test")
+ val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
+
+ assert(sc.getRDDStorageInfo.size === 0)
+
+ rdd.collect()
+ assert(sc.getRDDStorageInfo.size === 1)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala
new file mode 100644
index 0000000000..69383ddfb8
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.util.concurrent.Semaphore
+import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+
+import SparkContext._
+
+/**
+ * Holds state shared across task threads in some ThreadingSuite tests.
+ */
+object ThreadingSuiteState {
+ val runningThreads = new AtomicInteger
+ val failed = new AtomicBoolean
+
+ def clear() {
+ runningThreads.set(0)
+ failed.set(false)
+ }
+}
+
+class ThreadingSuite extends FunSuite with LocalSparkContext {
+
+ test("accessing SparkContext form a different thread") {
+ sc = new SparkContext("local", "test")
+ val nums = sc.parallelize(1 to 10, 2)
+ val sem = new Semaphore(0)
+ @volatile var answer1: Int = 0
+ @volatile var answer2: Int = 0
+ new Thread {
+ override def run() {
+ answer1 = nums.reduce(_ + _)
+ answer2 = nums.first() // This will run "locally" in the current thread
+ sem.release()
+ }
+ }.start()
+ sem.acquire()
+ assert(answer1 === 55)
+ assert(answer2 === 1)
+ }
+
+ test("accessing SparkContext form multiple threads") {
+ sc = new SparkContext("local", "test")
+ val nums = sc.parallelize(1 to 10, 2)
+ val sem = new Semaphore(0)
+ @volatile var ok = true
+ for (i <- 0 until 10) {
+ new Thread {
+ override def run() {
+ val answer1 = nums.reduce(_ + _)
+ if (answer1 != 55) {
+ printf("In thread %d: answer1 was %d\n", i, answer1)
+ ok = false
+ }
+ val answer2 = nums.first() // This will run "locally" in the current thread
+ if (answer2 != 1) {
+ printf("In thread %d: answer2 was %d\n", i, answer2)
+ ok = false
+ }
+ sem.release()
+ }
+ }.start()
+ }
+ sem.acquire(10)
+ if (!ok) {
+ fail("One or more threads got the wrong answer from an RDD operation")
+ }
+ }
+
+ test("accessing multi-threaded SparkContext form multiple threads") {
+ sc = new SparkContext("local[4]", "test")
+ val nums = sc.parallelize(1 to 10, 2)
+ val sem = new Semaphore(0)
+ @volatile var ok = true
+ for (i <- 0 until 10) {
+ new Thread {
+ override def run() {
+ val answer1 = nums.reduce(_ + _)
+ if (answer1 != 55) {
+ printf("In thread %d: answer1 was %d\n", i, answer1)
+ ok = false
+ }
+ val answer2 = nums.first() // This will run "locally" in the current thread
+ if (answer2 != 1) {
+ printf("In thread %d: answer2 was %d\n", i, answer2)
+ ok = false
+ }
+ sem.release()
+ }
+ }.start()
+ }
+ sem.acquire(10)
+ if (!ok) {
+ fail("One or more threads got the wrong answer from an RDD operation")
+ }
+ }
+
+ test("parallel job execution") {
+    // This test launches two jobs, each with two tasks, from two separate threads on a local[4]
+    // cluster. Each task waits until it sees 4 tasks running at once, to check that both jobs
+    // are running concurrently.
+ sc = new SparkContext("local[4]", "test")
+ val nums = sc.parallelize(1 to 2, 2)
+ val sem = new Semaphore(0)
+ ThreadingSuiteState.clear()
+ for (i <- 0 until 2) {
+ new Thread {
+ override def run() {
+ val ans = nums.map(number => {
+ val running = ThreadingSuiteState.runningThreads
+ running.getAndIncrement()
+ val time = System.currentTimeMillis()
+ while (running.get() != 4 && System.currentTimeMillis() < time + 1000) {
+ Thread.sleep(100)
+ }
+ if (running.get() != 4) {
+ println("Waited 1 second without seeing runningThreads = 4 (it was " +
+ running.get() + "); failing test")
+ ThreadingSuiteState.failed.set(true)
+ }
+ number
+ }).collect()
+ assert(ans.toList === List(1, 2))
+ sem.release()
+ }
+ }.start()
+ }
+ sem.acquire(2)
+ if (ThreadingSuiteState.failed.get()) {
+ fail("One or more threads didn't see runningThreads = 4")
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala
new file mode 100644
index 0000000000..46a2da1724
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.FunSuite
+import org.scalatest.concurrent.Timeouts._
+import org.scalatest.time.{Span, Millis}
+import org.apache.spark.SparkContext._
+
+class UnpersistSuite extends FunSuite with LocalSparkContext {
+ test("unpersist RDD") {
+ sc = new SparkContext("local", "test")
+ val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
+ rdd.count
+ assert(sc.persistentRdds.isEmpty === false)
+ rdd.unpersist()
+ assert(sc.persistentRdds.isEmpty === true)
+
+ failAfter(Span(3000, Millis)) {
+ try {
+ while (! sc.getRDDStorageInfo.isEmpty) {
+ Thread.sleep(200)
+ }
+ } catch {
+        // Swallow the exception and back off briefly: the block manager may be racing this
+        // thread to remove entries from the driver.
+        case _ => Thread.sleep(10)
+ }
+ }
+ assert(sc.getRDDStorageInfo.isEmpty === true)
+ }
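+
+  // Illustration (hypothetical helper, not part of this suite): the polling loop above is a
+  // hand-rolled "eventually" check and could be factored out as, e.g.
+  //
+  //   def eventually(timeoutMs: Long, intervalMs: Long)(condition: => Boolean) {
+  //     val deadline = System.currentTimeMillis() + timeoutMs
+  //     while (!condition && System.currentTimeMillis() < deadline) { Thread.sleep(intervalMs) }
+  //     assert(condition)
+  //   }
+  //
+  // Here the condition is sc.getRDDStorageInfo.isEmpty, which only holds once the block manager
+  // has finished removing the unpersisted RDD's blocks.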
+}
diff --git a/core/src/test/scala/org/apache/spark/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/UtilsSuite.scala
new file mode 100644
index 0000000000..3a908720a8
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/UtilsSuite.scala
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import com.google.common.base.Charsets
+import com.google.common.io.Files
+import java.io.{ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream, File}
+import org.scalatest.FunSuite
+import org.apache.commons.io.FileUtils
+import scala.util.Random
+
+class UtilsSuite extends FunSuite {
+
+ test("bytesToString") {
+ assert(Utils.bytesToString(10) === "10.0 B")
+ assert(Utils.bytesToString(1500) === "1500.0 B")
+ assert(Utils.bytesToString(2000000) === "1953.1 KB")
+ assert(Utils.bytesToString(2097152) === "2.0 MB")
+ assert(Utils.bytesToString(2306867) === "2.2 MB")
+ assert(Utils.bytesToString(5368709120L) === "5.0 GB")
+ assert(Utils.bytesToString(5L * 1024L * 1024L * 1024L * 1024L) === "5.0 TB")
+ }
+
+ test("copyStream") {
+    // Input array initialization
+ val bytes = Array.ofDim[Byte](9000)
+ Random.nextBytes(bytes)
+
+ val os = new ByteArrayOutputStream()
+ Utils.copyStream(new ByteArrayInputStream(bytes), os)
+
+ assert(os.toByteArray.toList.equals(bytes.toList))
+ }
+
+ test("memoryStringToMb") {
+ assert(Utils.memoryStringToMb("1") === 0)
+ assert(Utils.memoryStringToMb("1048575") === 0)
+ assert(Utils.memoryStringToMb("3145728") === 3)
+
+ assert(Utils.memoryStringToMb("1024k") === 1)
+ assert(Utils.memoryStringToMb("5000k") === 4)
+ assert(Utils.memoryStringToMb("4024k") === Utils.memoryStringToMb("4024K"))
+
+ assert(Utils.memoryStringToMb("1024m") === 1024)
+ assert(Utils.memoryStringToMb("5000m") === 5000)
+ assert(Utils.memoryStringToMb("4024m") === Utils.memoryStringToMb("4024M"))
+
+ assert(Utils.memoryStringToMb("2g") === 2048)
+ assert(Utils.memoryStringToMb("3g") === Utils.memoryStringToMb("3G"))
+
+ assert(Utils.memoryStringToMb("2t") === 2097152)
+ assert(Utils.memoryStringToMb("3t") === Utils.memoryStringToMb("3T"))
+ }
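+
+  // The kilobyte cases round down: "5000k" is 5000 / 1024 = 4 MB, and anything under 1048576 bytes
+  // (1 MB) maps to 0, so the conversion truncates rather than rounding to the nearest megabyte.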
+
+ test("splitCommandString") {
+ assert(Utils.splitCommandString("") === Seq())
+ assert(Utils.splitCommandString("a") === Seq("a"))
+ assert(Utils.splitCommandString("aaa") === Seq("aaa"))
+ assert(Utils.splitCommandString("a b c") === Seq("a", "b", "c"))
+ assert(Utils.splitCommandString(" a b\t c ") === Seq("a", "b", "c"))
+ assert(Utils.splitCommandString("a 'b c'") === Seq("a", "b c"))
+ assert(Utils.splitCommandString("a 'b c' d") === Seq("a", "b c", "d"))
+ assert(Utils.splitCommandString("'b c'") === Seq("b c"))
+ assert(Utils.splitCommandString("a \"b c\"") === Seq("a", "b c"))
+ assert(Utils.splitCommandString("a \"b c\" d") === Seq("a", "b c", "d"))
+ assert(Utils.splitCommandString("\"b c\"") === Seq("b c"))
+ assert(Utils.splitCommandString("a 'b\" c' \"d' e\"") === Seq("a", "b\" c", "d' e"))
+ assert(Utils.splitCommandString("a\t'b\nc'\nd") === Seq("a", "b\nc", "d"))
+ assert(Utils.splitCommandString("a \"b\\\\c\"") === Seq("a", "b\\c"))
+ assert(Utils.splitCommandString("a \"b\\\"c\"") === Seq("a", "b\"c"))
+ assert(Utils.splitCommandString("a 'b\\\"c'") === Seq("a", "b\\\"c"))
+ assert(Utils.splitCommandString("'a'b") === Seq("ab"))
+ assert(Utils.splitCommandString("'a''b'") === Seq("ab"))
+ assert(Utils.splitCommandString("\"a\"b") === Seq("ab"))
+ assert(Utils.splitCommandString("\"a\"\"b\"") === Seq("ab"))
+ assert(Utils.splitCommandString("''") === Seq(""))
+ assert(Utils.splitCommandString("\"\"") === Seq(""))
+ }
+
+ test("string formatting of time durations") {
+ val second = 1000
+ val minute = second * 60
+ val hour = minute * 60
+ def str = Utils.msDurationToString(_)
+
+ assert(str(123) === "123 ms")
+ assert(str(second) === "1.0 s")
+ assert(str(second + 462) === "1.5 s")
+ assert(str(hour) === "1.00 h")
+ assert(str(minute) === "1.0 m")
+ assert(str(minute + 4 * second + 34) === "1.1 m")
+ assert(str(10 * hour + minute + 4 * second) === "10.02 h")
+ assert(str(10 * hour + 59 * minute + 59 * second + 999) === "11.00 h")
+ }
+
+ test("reading offset bytes of a file") {
+ val tmpDir2 = Files.createTempDir()
+ val f1Path = tmpDir2 + "/f1"
+ val f1 = new FileOutputStream(f1Path)
+ f1.write("1\n2\n3\n4\n5\n6\n7\n8\n9\n".getBytes(Charsets.UTF_8))
+ f1.close()
+
+ // Read first few bytes
+ assert(Utils.offsetBytes(f1Path, 0, 5) === "1\n2\n3")
+
+ // Read some middle bytes
+ assert(Utils.offsetBytes(f1Path, 4, 11) === "3\n4\n5\n6")
+
+ // Read last few bytes
+ assert(Utils.offsetBytes(f1Path, 12, 18) === "7\n8\n9\n")
+
+ // Read some nonexistent bytes in the beginning
+ assert(Utils.offsetBytes(f1Path, -5, 5) === "1\n2\n3")
+
+ // Read some nonexistent bytes at the end
+ assert(Utils.offsetBytes(f1Path, 12, 22) === "7\n8\n9\n")
+
+ // Read some nonexistent bytes on both ends
+ assert(Utils.offsetBytes(f1Path, -3, 25) === "1\n2\n3\n4\n5\n6\n7\n8\n9\n")
+
+ FileUtils.deleteDirectory(tmpDir2)
+ }
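+
+  // Note on the out-of-range cases above: offsets are presumably clamped to the file bounds,
+  // roughly
+  //
+  //   val start = math.max(startOffset, 0L)
+  //   val end = math.min(endOffset, file.length)
+  //
+  // before the bytes in [start, end) are read and decoded. (Sketch of the assumed behavior, not a
+  // copy of the Utils implementation.)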
+}
+
diff --git a/core/src/test/scala/org/apache/spark/ZippedPartitionsSuite.scala b/core/src/test/scala/org/apache/spark/ZippedPartitionsSuite.scala
new file mode 100644
index 0000000000..618b9c113b
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ZippedPartitionsSuite.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.collection.immutable.NumericRange
+
+import org.scalatest.FunSuite
+import org.scalatest.prop.Checkers
+import org.scalacheck.Arbitrary._
+import org.scalacheck.Gen
+import org.scalacheck.Prop._
+
+import SparkContext._
+
+
+object ZippedPartitionsSuite {
+ def procZippedData(i: Iterator[Int], s: Iterator[String], d: Iterator[Double]) : Iterator[Int] = {
+ Iterator(i.toArray.size, s.toArray.size, d.toArray.size)
+ }
+}
+
+class ZippedPartitionsSuite extends FunSuite with SharedSparkContext {
+ test("print sizes") {
+ val data1 = sc.makeRDD(Array(1, 2, 3, 4), 2)
+ val data2 = sc.makeRDD(Array("1", "2", "3", "4", "5", "6"), 2)
+ val data3 = sc.makeRDD(Array(1.0, 2.0), 2)
+
+ val zippedRDD = data1.zipPartitions(data2, data3)(ZippedPartitionsSuite.procZippedData)
+
+ val obtainedSizes = zippedRDD.collect()
+ val expectedSizes = Array(2, 3, 1, 2, 3, 1)
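+    // Each of the two zipped partitions contributes (2, 3, 1): 2 elements from data1, 3 from
+    // data2 and 1 from data3, hence the repeated pattern.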
+ assert(obtainedSizes.size == 6)
+ assert(obtainedSizes.zip(expectedSizes).forall(x => x._1 == x._2))
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
new file mode 100644
index 0000000000..fd6f69041a
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.io
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+
+import org.scalatest.FunSuite
+
+
+class CompressionCodecSuite extends FunSuite {
+
+ def testCodec(codec: CompressionCodec) {
+    // Write the bytes 1 to 999 (mod 256) to the output stream, compressed.
+ val outputStream = new ByteArrayOutputStream()
+ val out = codec.compressedOutputStream(outputStream)
+ for (i <- 1 until 1000) {
+ out.write(i % 256)
+ }
+ out.close()
+
+    // Read the same bytes back and verify them.
+ val inputStream = new ByteArrayInputStream(outputStream.toByteArray)
+ val in = codec.compressedInputStream(inputStream)
+ for (i <- 1 until 1000) {
+ assert(in.read() === i % 256)
+ }
+ in.close()
+ }
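+
+  // Typical usage mirrors testCodec: pick a codec and wrap whatever stream you already have.
+  // Sketch only (the config key is assumed from this codebase's conventions; `file` is a
+  // placeholder):
+  //
+  //   val codec = CompressionCodec.createCodec()   // honors spark.io.compression.codec, if set
+  //   val out = codec.compressedOutputStream(new FileOutputStream(file))
+  //   try { out.write(bytes) } finally { out.close() }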
+
+ test("default compression codec") {
+ val codec = CompressionCodec.createCodec()
+ assert(codec.getClass === classOf[SnappyCompressionCodec])
+ testCodec(codec)
+ }
+
+ test("lzf compression codec") {
+ val codec = CompressionCodec.createCodec(classOf[LZFCompressionCodec].getName)
+ assert(codec.getClass === classOf[LZFCompressionCodec])
+ testCodec(codec)
+ }
+
+ test("snappy compression codec") {
+ val codec = CompressionCodec.createCodec(classOf[SnappyCompressionCodec].getName)
+ assert(codec.getClass === classOf[SnappyCompressionCodec])
+ testCodec(codec)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala
new file mode 100644
index 0000000000..58c94a162d
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.metrics
+
+import org.scalatest.{BeforeAndAfter, FunSuite}
+
+class MetricsConfigSuite extends FunSuite with BeforeAndAfter {
+ var filePath: String = _
+
+ before {
+ filePath = getClass.getClassLoader.getResource("test_metrics_config.properties").getFile()
+ }
+
+ test("MetricsConfig with default properties") {
+ val conf = new MetricsConfig(Option("dummy-file"))
+ conf.initialize()
+
+ assert(conf.properties.size() === 5)
+ assert(conf.properties.getProperty("test-for-dummy") === null)
+
+ val property = conf.getInstance("random")
+ assert(property.size() === 3)
+ assert(property.getProperty("sink.servlet.class") === "org.apache.spark.metrics.sink.MetricsServlet")
+ assert(property.getProperty("sink.servlet.uri") === "/metrics/json")
+ assert(property.getProperty("sink.servlet.sample") === "false")
+ }
+
+ test("MetricsConfig with properties set") {
+ val conf = new MetricsConfig(Option(filePath))
+ conf.initialize()
+
+ val masterProp = conf.getInstance("master")
+ assert(masterProp.size() === 6)
+ assert(masterProp.getProperty("sink.console.period") === "20")
+ assert(masterProp.getProperty("sink.console.unit") === "minutes")
+ assert(masterProp.getProperty("source.jvm.class") === "org.apache.spark.metrics.source.JvmSource")
+ assert(masterProp.getProperty("sink.servlet.class") === "org.apache.spark.metrics.sink.MetricsServlet")
+ assert(masterProp.getProperty("sink.servlet.uri") === "/metrics/master/json")
+ assert(masterProp.getProperty("sink.servlet.sample") === "false")
+
+ val workerProp = conf.getInstance("worker")
+ assert(workerProp.size() === 6)
+ assert(workerProp.getProperty("sink.console.period") === "10")
+ assert(workerProp.getProperty("sink.console.unit") === "seconds")
+ assert(workerProp.getProperty("source.jvm.class") === "org.apache.spark.metrics.source.JvmSource")
+ assert(workerProp.getProperty("sink.servlet.class") === "org.apache.spark.metrics.sink.MetricsServlet")
+ assert(workerProp.getProperty("sink.servlet.uri") === "/metrics/json")
+ assert(workerProp.getProperty("sink.servlet.sample") === "false")
+ }
+
+ test("MetricsConfig with subProperties") {
+ val conf = new MetricsConfig(Option(filePath))
+ conf.initialize()
+
+ val propCategories = conf.propertyCategories
+ assert(propCategories.size === 3)
+
+ val masterProp = conf.getInstance("master")
+ val sourceProps = conf.subProperties(masterProp, MetricsSystem.SOURCE_REGEX)
+ assert(sourceProps.size === 1)
+ assert(sourceProps("jvm").getProperty("class") === "org.apache.spark.metrics.source.JvmSource")
+
+ val sinkProps = conf.subProperties(masterProp, MetricsSystem.SINK_REGEX)
+ assert(sinkProps.size === 2)
+ assert(sinkProps.contains("console"))
+ assert(sinkProps.contains("servlet"))
+
+ val consoleProps = sinkProps("console")
+ assert(consoleProps.size() === 2)
+
+ val servletProps = sinkProps("servlet")
+ assert(servletProps.size() === 3)
+ }
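+
+  // For reference, the grouping above corresponds to flat property keys of the form
+  // <instance>.<sink|source>.<name>.<option> (assumed shape of test_metrics_config.properties):
+  //
+  //   master.sink.console.period = 20
+  //   master.sink.console.unit   = minutes
+  //   master.source.jvm.class    = org.apache.spark.metrics.source.JvmSource
+  //
+  // SINK_REGEX and SOURCE_REGEX strip the "sink."/"source." prefix and group what remains by name.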
+}
diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala
new file mode 100644
index 0000000000..7181333adf
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.metrics
+
+import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.apache.spark.deploy.master.MasterSource
+
+class MetricsSystemSuite extends FunSuite with BeforeAndAfter {
+ var filePath: String = _
+
+ before {
+ filePath = getClass.getClassLoader.getResource("test_metrics_system.properties").getFile()
+ System.setProperty("spark.metrics.conf", filePath)
+ }
+
+ test("MetricsSystem with default config") {
+ val metricsSystem = MetricsSystem.createMetricsSystem("default")
+ val sources = metricsSystem.sources
+ val sinks = metricsSystem.sinks
+
+ assert(sources.length === 0)
+ assert(sinks.length === 0)
+ assert(!metricsSystem.getServletHandlers.isEmpty)
+ }
+
+ test("MetricsSystem with sources add") {
+ val metricsSystem = MetricsSystem.createMetricsSystem("test")
+ val sources = metricsSystem.sources
+ val sinks = metricsSystem.sinks
+
+ assert(sources.length === 0)
+ assert(sinks.length === 1)
+ assert(!metricsSystem.getServletHandlers.isEmpty)
+
+ val source = new MasterSource(null)
+ metricsSystem.registerSource(source)
+ assert(sources.length === 1)
+ }
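+
+  // Illustration (assumed interface, not part of this suite): registerSource accepts anything
+  // implementing the Source trait, which presumably looks roughly like
+  //
+  //   trait Source {
+  //     def sourceName: String
+  //     def metricRegistry: com.codahale.metrics.MetricRegistry
+  //   }
+  //
+  // so a custom source only needs to expose a name and a registry of its metrics.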
+}
diff --git a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala
new file mode 100644
index 0000000000..3d39a31252
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.{ BeforeAndAfter, FunSuite }
+import org.apache.spark.SparkContext._
+import org.apache.spark.rdd.JdbcRDD
+import java.sql._
+
+class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
+
+ before {
+ Class.forName("org.apache.derby.jdbc.EmbeddedDriver")
+ val conn = DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true")
+ try {
+ val create = conn.createStatement
+ create.execute("""
+ CREATE TABLE FOO(
+ ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1),
+ DATA INTEGER
+ )""")
+ create.close
+ val insert = conn.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)")
+ (1 to 100).foreach { i =>
+ insert.setInt(1, i * 2)
+ insert.executeUpdate
+ }
+ insert.close
+ } catch {
+ case e: SQLException if e.getSQLState == "X0Y32" =>
+ // table exists
+ } finally {
+ conn.close
+ }
+ }
+
+ test("basic functionality") {
+ sc = new SparkContext("local", "test")
+ val rdd = new JdbcRDD(
+ sc,
+ () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") },
+ "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?",
+ 1, 100, 3,
+ (r: ResultSet) => { r.getInt(1) } ).cache
+
+ assert(rdd.count === 100)
+ assert(rdd.reduce(_+_) === 10100)
+ }
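+
+  // The (1, 100, 3) arguments above are the lower bound, upper bound and partition count: the ID
+  // range is split into three roughly equal sub-ranges, and each partition binds its own bounds to
+  // the two '?' placeholders. The expected sum is 2 * (1 + 2 + ... + 100) = 10100, since every row
+  // stores DATA = 2 * i.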
+
+ after {
+ try {
+ DriverManager.getConnection("jdbc:derby:;shutdown=true")
+ } catch {
+ case se: SQLException if se.getSQLState == "XJ015" =>
+ // normal shutdown
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala
new file mode 100644
index 0000000000..a80afdee7e
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.collection.immutable.NumericRange
+
+import org.scalatest.FunSuite
+import org.scalatest.prop.Checkers
+import org.scalacheck.Arbitrary._
+import org.scalacheck.Gen
+import org.scalacheck.Prop._
+
+class ParallelCollectionSplitSuite extends FunSuite with Checkers {
+ test("one element per slice") {
+ val data = Array(1, 2, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
+ assert(slices.size === 3)
+ assert(slices(0).mkString(",") === "1")
+ assert(slices(1).mkString(",") === "2")
+ assert(slices(2).mkString(",") === "3")
+ }
+
+ test("one slice") {
+ val data = Array(1, 2, 3)
+ val slices = ParallelCollectionRDD.slice(data, 1)
+ assert(slices.size === 1)
+ assert(slices(0).mkString(",") === "1,2,3")
+ }
+
+ test("equal slices") {
+ val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9)
+ val slices = ParallelCollectionRDD.slice(data, 3)
+ assert(slices.size === 3)
+ assert(slices(0).mkString(",") === "1,2,3")
+ assert(slices(1).mkString(",") === "4,5,6")
+ assert(slices(2).mkString(",") === "7,8,9")
+ }
+
+ test("non-equal slices") {
+ val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
+ val slices = ParallelCollectionRDD.slice(data, 3)
+ assert(slices.size === 3)
+ assert(slices(0).mkString(",") === "1,2,3")
+ assert(slices(1).mkString(",") === "4,5,6")
+ assert(slices(2).mkString(",") === "7,8,9,10")
+ }
+
+ test("splitting exclusive range") {
+ val data = 0 until 100
+ val slices = ParallelCollectionRDD.slice(data, 3)
+ assert(slices.size === 3)
+ assert(slices(0).mkString(",") === (0 to 32).mkString(","))
+ assert(slices(1).mkString(",") === (33 to 65).mkString(","))
+ assert(slices(2).mkString(",") === (66 to 99).mkString(","))
+ }
+
+ test("splitting inclusive range") {
+ val data = 0 to 100
+ val slices = ParallelCollectionRDD.slice(data, 3)
+ assert(slices.size === 3)
+ assert(slices(0).mkString(",") === (0 to 32).mkString(","))
+ assert(slices(1).mkString(",") === (33 to 66).mkString(","))
+ assert(slices(2).mkString(",") === (67 to 100).mkString(","))
+ }
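+
+  // The boundaries above follow positional slicing: with N elements and n slices, element k lands
+  // in slice i when i * N / n <= k < (i + 1) * N / n (integer arithmetic). For the inclusive range
+  // N = 101, giving slice sizes 33 / 34 / 34; for the exclusive range N = 100, giving 33 / 33 / 34.
+  // (Sketch of the assumed formula; "large ranges don't overflow" below checks the same arithmetic.)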
+
+ test("empty data") {
+ val data = new Array[Int](0)
+ val slices = ParallelCollectionRDD.slice(data, 5)
+ assert(slices.size === 5)
+ for (slice <- slices) assert(slice.size === 0)
+ }
+
+ test("zero slices") {
+ val data = Array(1, 2, 3)
+ intercept[IllegalArgumentException] { ParallelCollectionRDD.slice(data, 0) }
+ }
+
+ test("negative number of slices") {
+ val data = Array(1, 2, 3)
+ intercept[IllegalArgumentException] { ParallelCollectionRDD.slice(data, -5) }
+ }
+
+ test("exclusive ranges sliced into ranges") {
+ val data = 1 until 100
+ val slices = ParallelCollectionRDD.slice(data, 3)
+ assert(slices.size === 3)
+ assert(slices.map(_.size).reduceLeft(_+_) === 99)
+ assert(slices.forall(_.isInstanceOf[Range]))
+ }
+
+ test("inclusive ranges sliced into ranges") {
+ val data = 1 to 100
+ val slices = ParallelCollectionRDD.slice(data, 3)
+ assert(slices.size === 3)
+ assert(slices.map(_.size).reduceLeft(_+_) === 100)
+ assert(slices.forall(_.isInstanceOf[Range]))
+ }
+
+ test("large ranges don't overflow") {
+ val N = 100 * 1000 * 1000
+ val data = 0 until N
+ val slices = ParallelCollectionRDD.slice(data, 40)
+ assert(slices.size === 40)
+ for (i <- 0 until 40) {
+ assert(slices(i).isInstanceOf[Range])
+ val range = slices(i).asInstanceOf[Range]
+ assert(range.start === i * (N / 40), "slice " + i + " start")
+ assert(range.end === (i+1) * (N / 40), "slice " + i + " end")
+ assert(range.step === 1, "slice " + i + " step")
+ }
+ }
+
+ test("random array tests") {
+ val gen = for {
+ d <- arbitrary[List[Int]]
+ n <- Gen.choose(1, 100)
+ } yield (d, n)
+ val prop = forAll(gen) {
+ (tuple: (List[Int], Int)) =>
+ val d = tuple._1
+ val n = tuple._2
+ val slices = ParallelCollectionRDD.slice(d, n)
+ ("n slices" |: slices.size == n) &&
+ ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) &&
+ ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1))
+ }
+ check(prop)
+ }
+
+ test("random exclusive range tests") {
+ val gen = for {
+ a <- Gen.choose(-100, 100)
+ b <- Gen.choose(-100, 100)
+ step <- Gen.choose(-5, 5) suchThat (_ != 0)
+ n <- Gen.choose(1, 100)
+ } yield (a until b by step, n)
+ val prop = forAll(gen) {
+ case (d: Range, n: Int) =>
+ val slices = ParallelCollectionRDD.slice(d, n)
+ ("n slices" |: slices.size == n) &&
+ ("all ranges" |: slices.forall(_.isInstanceOf[Range])) &&
+ ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) &&
+ ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1))
+ }
+ check(prop)
+ }
+
+ test("random inclusive range tests") {
+ val gen = for {
+ a <- Gen.choose(-100, 100)
+ b <- Gen.choose(-100, 100)
+ step <- Gen.choose(-5, 5) suchThat (_ != 0)
+ n <- Gen.choose(1, 100)
+ } yield (a to b by step, n)
+ val prop = forAll(gen) {
+ case (d: Range, n: Int) =>
+ val slices = ParallelCollectionRDD.slice(d, n)
+ ("n slices" |: slices.size == n) &&
+ ("all ranges" |: slices.forall(_.isInstanceOf[Range])) &&
+ ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) &&
+ ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1))
+ }
+ check(prop)
+ }
+
+ test("exclusive ranges of longs") {
+ val data = 1L until 100L
+ val slices = ParallelCollectionRDD.slice(data, 3)
+ assert(slices.size === 3)
+ assert(slices.map(_.size).reduceLeft(_+_) === 99)
+ assert(slices.forall(_.isInstanceOf[NumericRange[_]]))
+ }
+
+ test("inclusive ranges of longs") {
+ val data = 1L to 100L
+ val slices = ParallelCollectionRDD.slice(data, 3)
+ assert(slices.size === 3)
+ assert(slices.map(_.size).reduceLeft(_+_) === 100)
+ assert(slices.forall(_.isInstanceOf[NumericRange[_]]))
+ }
+
+ test("exclusive ranges of doubles") {
+ val data = 1.0 until 100.0 by 1.0
+ val slices = ParallelCollectionRDD.slice(data, 3)
+ assert(slices.size === 3)
+ assert(slices.map(_.size).reduceLeft(_+_) === 99)
+ assert(slices.forall(_.isInstanceOf[NumericRange[_]]))
+ }
+
+ test("inclusive ranges of doubles") {
+ val data = 1.0 to 100.0 by 1.0
+ val slices = ParallelCollectionRDD.slice(data, 3)
+ assert(slices.size === 3)
+ assert(slices.map(_.size).reduceLeft(_+_) === 100)
+ assert(slices.forall(_.isInstanceOf[NumericRange[_]]))
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
new file mode 100644
index 0000000000..94df282b28
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -0,0 +1,421 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import scala.collection.mutable.{Map, HashMap}
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.LocalSparkContext
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.RDD
+import org.apache.spark.SparkContext
+import org.apache.spark.Partition
+import org.apache.spark.TaskContext
+import org.apache.spark.{Dependency, ShuffleDependency, OneToOneDependency}
+import org.apache.spark.{FetchFailed, Success, TaskEndReason}
+import org.apache.spark.storage.{BlockManagerId, BlockManagerMaster}
+
+import org.apache.spark.scheduler.cluster.Pool
+import org.apache.spark.scheduler.cluster.SchedulingMode
+import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode
+
+/**
+ * Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler
+ * rather than spawning an event loop thread as happens in the real code. They use hand-written
+ * stubs for two classes that DAGScheduler interacts with: TaskScheduler (to which TaskSets are
+ * submitted) and BlockManagerMaster (from which cache locations are retrieved and to which dead
+ * host notifications are sent). In addition, tests may check for side effects on a non-mocked
+ * MapOutputTracker instance.
+ *
+ * Tests primarily consist of running DAGScheduler#processEvent and
+ * DAGScheduler#submitWaitingStages (via test utility functions like runEvent or respondToTaskSet)
+ * and capturing the resulting TaskSets from the mock TaskScheduler.
+ */
+class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
+
+ /** Set of TaskSets the DAGScheduler has requested executed. */
+ val taskSets = scala.collection.mutable.Buffer[TaskSet]()
+ val taskScheduler = new TaskScheduler() {
+ override def rootPool: Pool = null
+ override def schedulingMode: SchedulingMode = SchedulingMode.NONE
+ override def start() = {}
+ override def stop() = {}
+ override def submitTasks(taskSet: TaskSet) = {
+ // normally done by TaskSetManager
+ taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch)
+ taskSets += taskSet
+ }
+ override def setListener(listener: TaskSchedulerListener) = {}
+ override def defaultParallelism() = 2
+ }
+
+ var mapOutputTracker: MapOutputTracker = null
+ var scheduler: DAGScheduler = null
+
+ /**
+ * Set of cache locations to return from our mock BlockManagerMaster.
+ * Keys are (rdd ID, partition ID). Anything not present will return an empty
+ * list of cache locations silently.
+ */
+ val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]
+ // stub out BlockManagerMaster.getLocations to use our cacheLocations
+ val blockManagerMaster = new BlockManagerMaster(null) {
+ override def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
+ blockIds.map { name =>
+ val pieces = name.split("_")
+ if (pieces(0) == "rdd") {
+ val key = pieces(1).toInt -> pieces(2).toInt
+ cacheLocations.getOrElse(key, Seq())
+ } else {
+ Seq()
+ }
+ }.toSeq
+ }
+ override def removeExecutor(execId: String) {
+ // don't need to propagate to the driver, which we don't have
+ }
+ }
+
+ /** The list of results that DAGScheduler has collected. */
+ val results = new HashMap[Int, Any]()
+ var failure: Exception = _
+ val listener = new JobListener() {
+ override def taskSucceeded(index: Int, result: Any) = results.put(index, result)
+ override def jobFailed(exception: Exception) = { failure = exception }
+ }
+
+ before {
+ sc = new SparkContext("local", "DAGSchedulerSuite")
+ taskSets.clear()
+ cacheLocations.clear()
+ results.clear()
+ mapOutputTracker = new MapOutputTracker()
+ scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) {
+ override def runLocally(job: ActiveJob) {
+ // don't bother with the thread while unit testing
+ runLocallyWithinThread(job)
+ }
+ }
+ }
+
+ after {
+ scheduler.stop()
+ }
+
+ /**
+ * Type of RDD we use for testing. Note that we should never call the real RDD compute methods.
+ * This is a pair RDD type so it can always be used in ShuffleDependencies.
+ */
+ type MyRDD = RDD[(Int, Int)]
+
+ /**
+ * Create an RDD for passing to DAGScheduler. These RDDs will use the dependencies and
+ * preferredLocations (if any) that are passed to them. They are deliberately not executable
+ * so we can test that DAGScheduler does not try to execute RDDs locally.
+ */
+ private def makeRdd(
+ numPartitions: Int,
+ dependencies: List[Dependency[_]],
+ locations: Seq[Seq[String]] = Nil
+ ): MyRDD = {
+ val maxPartition = numPartitions - 1
+ return new MyRDD(sc, dependencies) {
+ override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+ throw new RuntimeException("should not be reached")
+ override def getPartitions = (0 to maxPartition).map(i => new Partition {
+ override def index = i
+ }).toArray
+ override def getPreferredLocations(split: Partition): Seq[String] =
+ if (locations.isDefinedAt(split.index))
+ locations(split.index)
+ else
+ Nil
+ override def toString: String = "DAGSchedulerSuiteRDD " + id
+ }
+ }
+
+ /**
+ * Process the supplied event as if it were the top of the DAGScheduler event queue, expecting
+ * the scheduler not to exit.
+ *
+ * After processing the event, submit waiting stages as is done on most iterations of the
+ * DAGScheduler event loop.
+ */
+ private def runEvent(event: DAGSchedulerEvent) {
+ assert(!scheduler.processEvent(event))
+ scheduler.submitWaitingStages()
+ }
+
+ /**
+ * When we submit dummy Jobs, this is the compute function we supply. Except in a local test
+ * below, we do not expect this function to ever be executed; instead, we will return results
+ * directly through CompletionEvents.
+ */
+ private val jobComputeFunc = (context: TaskContext, it: Iterator[(_)]) =>
+ it.next.asInstanceOf[Tuple2[_, _]]._1
+
+ /** Send the given CompletionEvent messages for the tasks in the TaskSet. */
+ private def complete(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) {
+ assert(taskSet.tasks.size >= results.size)
+ for ((result, i) <- results.zipWithIndex) {
+ if (i < taskSet.tasks.size) {
+ runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any](), null, null))
+ }
+ }
+ }
+
+ /** Sends the rdd to the scheduler for scheduling. */
+ private def submit(
+ rdd: RDD[_],
+ partitions: Array[Int],
+ func: (TaskContext, Iterator[_]) => _ = jobComputeFunc,
+ allowLocal: Boolean = false,
+ listener: JobListener = listener) {
+ runEvent(JobSubmitted(rdd, func, partitions, allowLocal, null, listener))
+ }
+
+ /** Sends TaskSetFailed to the scheduler. */
+ private def failed(taskSet: TaskSet, message: String) {
+ runEvent(TaskSetFailed(taskSet, message))
+ }
+
+ test("zero split job") {
+ val rdd = makeRdd(0, Nil)
+ var numResults = 0
+ val fakeListener = new JobListener() {
+ override def taskSucceeded(partition: Int, value: Any) = numResults += 1
+ override def jobFailed(exception: Exception) = throw exception
+ }
+ submit(rdd, Array(), listener = fakeListener)
+ assert(numResults === 0)
+ }
+
+ test("run trivial job") {
+ val rdd = makeRdd(1, Nil)
+ submit(rdd, Array(0))
+ complete(taskSets(0), List((Success, 42)))
+ assert(results === Map(0 -> 42))
+ }
+
+ test("local job") {
+ val rdd = new MyRDD(sc, Nil) {
+ override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+ Array(42 -> 0).iterator
+ override def getPartitions = Array( new Partition { override def index = 0 } )
+ override def getPreferredLocations(split: Partition) = Nil
+ override def toString = "DAGSchedulerSuite Local RDD"
+ }
+ runEvent(JobSubmitted(rdd, jobComputeFunc, Array(0), true, null, listener))
+ assert(results === Map(0 -> 42))
+ }
+
+ test("run trivial job w/ dependency") {
+ val baseRdd = makeRdd(1, Nil)
+ val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
+ submit(finalRdd, Array(0))
+ complete(taskSets(0), Seq((Success, 42)))
+ assert(results === Map(0 -> 42))
+ }
+
+ test("cache location preferences w/ dependency") {
+ val baseRdd = makeRdd(1, Nil)
+ val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
+ cacheLocations(baseRdd.id -> 0) =
+ Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
+ submit(finalRdd, Array(0))
+ val taskSet = taskSets(0)
+ assertLocations(taskSet, Seq(Seq("hostA", "hostB")))
+ complete(taskSet, Seq((Success, 42)))
+ assert(results === Map(0 -> 42))
+ }
+
+ test("trivial job failure") {
+ submit(makeRdd(1, Nil), Array(0))
+ failed(taskSets(0), "some failure")
+ assert(failure.getMessage === "Job failed: some failure")
+ }
+
+ test("run trivial shuffle") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(1, List(shuffleDep))
+ submit(reduceRdd, Array(0))
+ complete(taskSets(0), Seq(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))))
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
+ complete(taskSets(1), Seq((Success, 42)))
+ assert(results === Map(0 -> 42))
+ }
+
+ test("run trivial shuffle with fetch failure") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(2, List(shuffleDep))
+ submit(reduceRdd, Array(0, 1))
+ complete(taskSets(0), Seq(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))))
+ // the 2nd ResultTask failed
+ complete(taskSets(1), Seq(
+ (Success, 42),
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null)))
+ // this will get called
+ // blockManagerMaster.removeExecutor("exec-hostA")
+ // ask the scheduler to try it again
+ scheduler.resubmitFailedStages()
+ // have the 2nd attempt pass
+ complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))
+ // we can see both result blocks now
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB"))
+ complete(taskSets(3), Seq((Success, 43)))
+ assert(results === Map(0 -> 42, 1 -> 43))
+ }
+
+ test("ignore late map task completions") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(2, List(shuffleDep))
+ submit(reduceRdd, Array(0, 1))
+ // pretend we were told hostA went away
+ val oldEpoch = mapOutputTracker.getEpoch
+ runEvent(ExecutorLost("exec-hostA"))
+ val newEpoch = mapOutputTracker.getEpoch
+ assert(newEpoch > oldEpoch)
+ val noAccum = Map[Long, Any]()
+ val taskSet = taskSets(0)
+ // should be ignored for being too old
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null))
+ // should work because it's a non-failed host
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum, null, null))
+ // should be ignored for being too old
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null))
+ // should work because it's a new epoch
+ taskSet.tasks(1).epoch = newEpoch
+ runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum, null, null))
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
+ complete(taskSets(1), Seq((Success, 42), (Success, 43)))
+ assert(results === Map(0 -> 42, 1 -> 43))
+ }
+
+ test("run trivial shuffle with out-of-band failure and retry") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(1, List(shuffleDep))
+ submit(reduceRdd, Array(0))
+ // blockManagerMaster.removeExecutor("exec-hostA")
+ // pretend we were told hostA went away
+ runEvent(ExecutorLost("exec-hostA"))
+ // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks
+    // rather than marking it as failed and waiting.
+ complete(taskSets(0), Seq(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))))
+ // have hostC complete the resubmitted task
+ complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1))))
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
+ complete(taskSets(2), Seq((Success, 42)))
+ assert(results === Map(0 -> 42))
+ }
+
+ test("recursive shuffle failures") {
+ val shuffleOneRdd = makeRdd(2, Nil)
+ val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
+ val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
+ val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
+ val finalRdd = makeRdd(1, List(shuffleDepTwo))
+ submit(finalRdd, Array(0))
+ // have the first stage complete normally
+ complete(taskSets(0), Seq(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))))
+ // have the second stage complete normally
+ complete(taskSets(1), Seq(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostC", 1))))
+ // fail the third stage because hostA went down
+ complete(taskSets(2), Seq(
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)))
+ // TODO assert this:
+ // blockManagerMaster.removeExecutor("exec-hostA")
+ // have DAGScheduler try again
+ scheduler.resubmitFailedStages()
+ complete(taskSets(3), Seq((Success, makeMapStatus("hostA", 2))))
+ complete(taskSets(4), Seq((Success, makeMapStatus("hostA", 1))))
+ complete(taskSets(5), Seq((Success, 42)))
+ assert(results === Map(0 -> 42))
+ }
+
+ test("cached post-shuffle") {
+ val shuffleOneRdd = makeRdd(2, Nil)
+ val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
+ val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
+ val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
+ val finalRdd = makeRdd(1, List(shuffleDepTwo))
+ submit(finalRdd, Array(0))
+ cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
+ cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
+ // complete stage 2
+ complete(taskSets(0), Seq(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))))
+ // complete stage 1
+ complete(taskSets(1), Seq(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))))
+ // pretend stage 0 failed because hostA went down
+ complete(taskSets(2), Seq(
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)))
+ // TODO assert this:
+ // blockManagerMaster.removeExecutor("exec-hostA")
+ // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun.
+ scheduler.resubmitFailedStages()
+ assertLocations(taskSets(3), Seq(Seq("hostD")))
+ // allow hostD to recover
+ complete(taskSets(3), Seq((Success, makeMapStatus("hostD", 1))))
+ complete(taskSets(4), Seq((Success, 42)))
+ assert(results === Map(0 -> 42))
+ }
+
+ /**
+ * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
+ * Note that this checks only the host and not the executor ID.
+ */
+ private def assertLocations(taskSet: TaskSet, hosts: Seq[Seq[String]]) {
+ assert(hosts.size === taskSet.tasks.size)
+ for ((taskLocs, expectedLocs) <- taskSet.tasks.map(_.preferredLocations).zip(hosts)) {
+ assert(taskLocs.map(_.host) === expectedLocs)
+ }
+ }
+
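+  /**
+   * Build a MapStatus that reports (fake) output from `host` for every reducer. The byte array
+   * presumably carries one compressed per-reducer output size; the constant 2 is arbitrary and
+   * just needs to be non-zero so the output is treated as present.
+   */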
+ private def makeMapStatus(host: String, reduces: Int): MapStatus =
+ new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2))
+
+ private def makeBlockManagerId(host: String): BlockManagerId =
+ BlockManagerId("exec-" + host, host, 12345, 0)
+
+}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
new file mode 100644
index 0000000000..f5b3e97222
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.Properties
+import java.util.concurrent.LinkedBlockingQueue
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import scala.collection.mutable
+import org.apache.spark._
+import org.apache.spark.SparkContext._
+
+
+class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+
+ test("inner method") {
+ sc = new SparkContext("local", "joblogger")
+ val joblogger = new JobLogger {
+ def createLogWriterTest(jobID: Int) = createLogWriter(jobID)
+ def closeLogWriterTest(jobID: Int) = closeLogWriter(jobID)
+ def getRddNameTest(rdd: RDD[_]) = getRddName(rdd)
+ def buildJobDepTest(jobID: Int, stage: Stage) = buildJobDep(jobID, stage)
+ }
+ type MyRDD = RDD[(Int, Int)]
+ def makeRdd(
+ numPartitions: Int,
+ dependencies: List[Dependency[_]]
+ ): MyRDD = {
+ val maxPartition = numPartitions - 1
+ return new MyRDD(sc, dependencies) {
+ override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+ throw new RuntimeException("should not be reached")
+ override def getPartitions = (0 to maxPartition).map(i => new Partition {
+ override def index = i
+ }).toArray
+ }
+ }
+ val jobID = 5
+ val parentRdd = makeRdd(4, Nil)
+ val shuffleDep = new ShuffleDependency(parentRdd, null)
+ val rootRdd = makeRdd(4, List(shuffleDep))
+ val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID, None)
+ val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID, None)
+
+ joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStage, 4, null))
+ joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName)
+ parentRdd.setName("MyRDD")
+ joblogger.getRddNameTest(parentRdd) should be ("MyRDD")
+ joblogger.createLogWriterTest(jobID)
+ joblogger.getJobIDtoPrintWriter.size should be (1)
+ joblogger.buildJobDepTest(jobID, rootStage)
+ joblogger.getJobIDToStages.get(jobID).get.size should be (2)
+ joblogger.getStageIDToJobID.get(0) should be (Some(jobID))
+ joblogger.getStageIDToJobID.get(1) should be (Some(jobID))
+ joblogger.closeLogWriterTest(jobID)
+ joblogger.getStageIDToJobID.size should be (0)
+ joblogger.getJobIDToStages.size should be (0)
+ joblogger.getJobIDtoPrintWriter.size should be (0)
+ }
+
+ test("inner variables") {
+ sc = new SparkContext("local[4]", "joblogger")
+ val joblogger = new JobLogger {
+ override protected def closeLogWriter(jobID: Int) =
+ getJobIDtoPrintWriter.get(jobID).foreach { fileWriter =>
+ fileWriter.close()
+ }
+ }
+ sc.addSparkListener(joblogger)
+ val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
+ rdd.reduceByKey(_+_).collect()
+
+ joblogger.getLogDir should be ("/tmp/spark")
+ joblogger.getJobIDtoPrintWriter.size should be (1)
+ joblogger.getStageIDToJobID.size should be (2)
+ joblogger.getStageIDToJobID.get(0) should be (Some(0))
+ joblogger.getStageIDToJobID.get(1) should be (Some(0))
+ joblogger.getJobIDToStages.size should be (1)
+ }
+
+ test("interface functions") {
+ sc = new SparkContext("local[4]", "joblogger")
+ val joblogger = new JobLogger {
+ var onTaskEndCount = 0
+ var onJobEndCount = 0
+ var onJobStartCount = 0
+ var onStageCompletedCount = 0
+ var onStageSubmittedCount = 0
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = onTaskEndCount += 1
+ override def onJobEnd(jobEnd: SparkListenerJobEnd) = onJobEndCount += 1
+ override def onJobStart(jobStart: SparkListenerJobStart) = onJobStartCount += 1
+ override def onStageCompleted(stageCompleted: StageCompleted) = onStageCompletedCount += 1
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = onStageSubmittedCount += 1
+ }
+ sc.addSparkListener(joblogger)
+ val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
+ rdd.reduceByKey(_+_).collect()
+
+ joblogger.onJobStartCount should be (1)
+ joblogger.onJobEndCount should be (1)
+ joblogger.onTaskEndCount should be (8)
+ joblogger.onStageSubmittedCount should be (2)
+ joblogger.onStageCompletedCount should be (2)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
new file mode 100644
index 0000000000..aac7c207cb
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import org.scalatest.FunSuite
+import org.apache.spark.{SparkContext, LocalSparkContext}
+import scala.collection.mutable
+import org.scalatest.matchers.ShouldMatchers
+import org.apache.spark.SparkContext._
+
+/**
+ * Checks that SparkListener callbacks observe accurate task and stage metrics for locally run jobs.
+ */
+
+class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+
+ test("local metrics") {
+ sc = new SparkContext("local[4]", "test")
+ val listener = new SaveStageInfo
+ sc.addSparkListener(listener)
+ sc.addSparkListener(new StatsReportListener)
+    // just to make sure some of the tasks take a noticeable amount of time
+    val w = { i: Int =>
+ if (i == 0)
+ Thread.sleep(100)
+ i
+ }
+
+ val d = sc.parallelize(1 to 1e4.toInt, 64).map{i => w(i)}
+ d.count
+ listener.stageInfos.size should be (1)
+
+ val d2 = d.map{i => w(i) -> i * 2}.setName("shuffle input 1")
+
+ val d3 = d.map{i => w(i) -> (0 to (i % 5))}.setName("shuffle input 2")
+
+ val d4 = d2.cogroup(d3, 64).map{case(k,(v1,v2)) => w(k) -> (v1.size, v2.size)}
+ d4.setName("A Cogroup")
+
+ d4.collectAsMap
+
+ listener.stageInfos.size should be (4)
+ listener.stageInfos.foreach {stageInfo =>
+      // small test, so some tasks might take less than 1 millisecond, but the average per stage should be non-zero
+ checkNonZeroAvg(stageInfo.taskInfos.map{_._1.duration}, stageInfo + " duration")
+ checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorRunTime.toLong}, stageInfo + " executorRunTime")
+ checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorDeserializeTime.toLong}, stageInfo + " executorDeserializeTime")
+ if (stageInfo.stage.rdd.name == d4.name) {
+ checkNonZeroAvg(stageInfo.taskInfos.map{_._2.shuffleReadMetrics.get.fetchWaitTime}, stageInfo + " fetchWaitTime")
+ }
+
+ stageInfo.taskInfos.foreach{case (taskInfo, taskMetrics) =>
+ taskMetrics.resultSize should be > (0l)
+ if (isStage(stageInfo, Set(d2.name, d3.name), Set(d4.name))) {
+ taskMetrics.shuffleWriteMetrics should be ('defined)
+ taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0l)
+ }
+ if (stageInfo.stage.rdd.name == d4.name) {
+ taskMetrics.shuffleReadMetrics should be ('defined)
+ val sm = taskMetrics.shuffleReadMetrics.get
+ sm.totalBlocksFetched should be > (0)
+ sm.localBlocksFetched should be > (0)
+ sm.remoteBlocksFetched should be (0)
+ sm.remoteBytesRead should be (0l)
+ sm.remoteFetchTime should be (0l)
+ }
+ }
+ }
+ }
+
+ def checkNonZeroAvg(m: Traversable[Long], msg: String) {
+ assert(m.sum / m.size.toDouble > 0.0, msg)
+ }
+
+ def isStage(stageInfo: StageInfo, rddNames: Set[String], excludedNames: Set[String]) = {
+ val names = Set(stageInfo.stage.rdd.name) ++ stageInfo.stage.rdd.dependencies.map{_.rdd.name}
+ !names.intersect(rddNames).isEmpty && names.intersect(excludedNames).isEmpty
+ }
+
+ class SaveStageInfo extends SparkListener {
+ val stageInfos = mutable.Buffer[StageInfo]()
+ override def onStageCompleted(stage: StageCompleted) {
+ stageInfos += stage.stageInfo
+ }
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
new file mode 100644
index 0000000000..0347cc02d7
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+import org.apache.spark.TaskContext
+import org.apache.spark.RDD
+import org.apache.spark.SparkContext
+import org.apache.spark.Partition
+import org.apache.spark.LocalSparkContext
+
+class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
+
+ test("Calls executeOnCompleteCallbacks after failure") {
+ var completed = false
+ sc = new SparkContext("local", "test")
+ val rdd = new RDD[String](sc, List()) {
+ override def getPartitions = Array[Partition](StubPartition(0))
+ override def compute(split: Partition, context: TaskContext) = {
+ context.addOnCompleteCallback(() => completed = true)
+ sys.error("failed")
+ }
+ }
+ val func = (c: TaskContext, i: Iterator[String]) => i.next
+ val task = new ResultTask[String, String](0, rdd, func, 0, Seq(), 0)
+ intercept[RuntimeException] {
+ task.run(0)
+ }
+ assert(completed === true)
+ }
+
+  case class StubPartition(index: Int) extends Partition
+}
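The property exercised above (the completion callback firing even though compute() throws) is what
makes addOnCompleteCallback suitable for resource cleanup in user-defined RDDs. A minimal sketch,
modeled on the stub RDD in this suite; the class name and the file-reading logic are illustrative,
not part of this patch:

    import java.io.{BufferedReader, FileReader}

    import org.apache.spark.{Partition, RDD, SparkContext, TaskContext}

    // Illustrative only: compute() opens a file and relies on the task-completion
    // callback to close it, even if iteration fails part-way through.
    class LineFileRDD(sc: SparkContext, path: String) extends RDD[String](sc, Nil) {

      override def getPartitions: Array[Partition] =
        Array(new Partition { override def index = 0 })

      override def compute(split: Partition, context: TaskContext): Iterator[String] = {
        val reader = new BufferedReader(new FileReader(path))
        // Runs when the task finishes, on both the success and the failure path,
        // which is exactly what the test above asserts.
        context.addOnCompleteCallback(() => reader.close())
        Iterator.continually(reader.readLine()).takeWhile(_ != null)
      }
    }

A try/finally inside compute() would not work here, because the iterator is consumed lazily after
compute() returns; the completion callback is the hook that survives both paths.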
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala
new file mode 100644
index 0000000000..92ad9f09b2
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark._
+import org.apache.spark.scheduler._
+import org.apache.spark.scheduler.cluster._
+import scala.collection.mutable.ArrayBuffer
+
+import java.util.Properties
+
+class FakeTaskSetManager(
+ initPriority: Int,
+ initStageId: Int,
+ initNumTasks: Int,
+ clusterScheduler: ClusterScheduler,
+ taskSet: TaskSet)
+ extends ClusterTaskSetManager(clusterScheduler, taskSet) {
+
+ parent = null
+ weight = 1
+ minShare = 2
+ runningTasks = 0
+ priority = initPriority
+ stageId = initStageId
+ name = "TaskSet_"+stageId
+ override val numTasks = initNumTasks
+ tasksFinished = 0
+
+ override def increaseRunningTasks(taskNum: Int) {
+ runningTasks += taskNum
+ if (parent != null) {
+ parent.increaseRunningTasks(taskNum)
+ }
+ }
+
+ override def decreaseRunningTasks(taskNum: Int) {
+ runningTasks -= taskNum
+ if (parent != null) {
+ parent.decreaseRunningTasks(taskNum)
+ }
+ }
+
+ override def addSchedulable(schedulable: Schedulable) {
+ }
+
+ override def removeSchedulable(schedulable: Schedulable) {
+ }
+
+ override def getSchedulableByName(name: String): Schedulable = {
+ return null
+ }
+
+ override def executorLost(executorId: String, host: String): Unit = {
+ }
+
+ override def resourceOffer(
+ execId: String,
+ host: String,
+ availableCpus: Int,
+ maxLocality: TaskLocality.TaskLocality)
+ : Option[TaskDescription] =
+ {
+ if (tasksFinished + runningTasks < numTasks) {
+ increaseRunningTasks(1)
+ return Some(new TaskDescription(0, execId, "task 0:0", 0, null))
+ }
+ return None
+ }
+
+ override def checkSpeculatableTasks(): Boolean = {
+ return true
+ }
+
+ def taskFinished() {
+ decreaseRunningTasks(1)
+    tasksFinished += 1
+ if (tasksFinished == numTasks) {
+ parent.removeSchedulable(this)
+ }
+ }
+
+ def abort() {
+ decreaseRunningTasks(runningTasks)
+ parent.removeSchedulable(this)
+ }
+}
+
+class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging {
+
+ def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: ClusterScheduler, taskSet: TaskSet): FakeTaskSetManager = {
+    new FakeTaskSetManager(priority, stage, numTasks, cs, taskSet)
+ }
+
+ def resourceOffer(rootPool: Pool): Int = {
+ val taskSetQueue = rootPool.getSortedTaskSetQueue()
+    // Log the sorted queue contents; this is only for debugging these tests
+    for (manager <- taskSetQueue) {
+      logInfo("parentName:%s, parent running tasks:%d, name:%s, runningTasks:%d".format(
+        manager.parent.name, manager.parent.runningTasks, manager.name, manager.runningTasks))
+    }
+ for (taskSet <- taskSetQueue) {
+ taskSet.resourceOffer("execId_1", "hostname_1", 1, TaskLocality.ANY) match {
+ case Some(task) =>
+ return taskSet.stageId
+ case None => {}
+ }
+ }
+ -1
+ }
+
+ def checkTaskSetId(rootPool: Pool, expectedTaskSetId: Int) {
+ assert(resourceOffer(rootPool) === expectedTaskSetId)
+ }
+
+ test("FIFO Scheduler Test") {
+ sc = new SparkContext("local", "ClusterSchedulerSuite")
+ val clusterScheduler = new ClusterScheduler(sc)
+ var tasks = ArrayBuffer[Task[_]]()
+ val task = new FakeTask(0)
+ tasks += task
+    val taskSet = new TaskSet(tasks.toArray, 0, 0, 0, null)
+
+ val rootPool = new Pool("", SchedulingMode.FIFO, 0, 0)
+ val schedulableBuilder = new FIFOSchedulableBuilder(rootPool)
+ schedulableBuilder.buildPools()
+
+ val taskSetManager0 = createDummyTaskSetManager(0, 0, 2, clusterScheduler, taskSet)
+ val taskSetManager1 = createDummyTaskSetManager(0, 1, 2, clusterScheduler, taskSet)
+ val taskSetManager2 = createDummyTaskSetManager(0, 2, 2, clusterScheduler, taskSet)
+ schedulableBuilder.addTaskSetManager(taskSetManager0, null)
+ schedulableBuilder.addTaskSetManager(taskSetManager1, null)
+ schedulableBuilder.addTaskSetManager(taskSetManager2, null)
+
+ checkTaskSetId(rootPool, 0)
+ resourceOffer(rootPool)
+ checkTaskSetId(rootPool, 1)
+ resourceOffer(rootPool)
+ taskSetManager1.abort()
+ checkTaskSetId(rootPool, 2)
+ }
+
+ test("Fair Scheduler Test") {
+ sc = new SparkContext("local", "ClusterSchedulerSuite")
+ val clusterScheduler = new ClusterScheduler(sc)
+ var tasks = ArrayBuffer[Task[_]]()
+ val task = new FakeTask(0)
+ tasks += task
+    val taskSet = new TaskSet(tasks.toArray, 0, 0, 0, null)
+
+ val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
+ System.setProperty("spark.fairscheduler.allocation.file", xmlPath)
+ val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0)
+ val schedulableBuilder = new FairSchedulableBuilder(rootPool)
+ schedulableBuilder.buildPools()
+
+ assert(rootPool.getSchedulableByName("default") != null)
+ assert(rootPool.getSchedulableByName("1") != null)
+ assert(rootPool.getSchedulableByName("2") != null)
+ assert(rootPool.getSchedulableByName("3") != null)
+ assert(rootPool.getSchedulableByName("1").minShare === 2)
+ assert(rootPool.getSchedulableByName("1").weight === 1)
+ assert(rootPool.getSchedulableByName("2").minShare === 3)
+ assert(rootPool.getSchedulableByName("2").weight === 1)
+ assert(rootPool.getSchedulableByName("3").minShare === 2)
+ assert(rootPool.getSchedulableByName("3").weight === 1)
+
+ val properties1 = new Properties()
+ properties1.setProperty("spark.scheduler.cluster.fair.pool","1")
+ val properties2 = new Properties()
+ properties2.setProperty("spark.scheduler.cluster.fair.pool","2")
+
+ val taskSetManager10 = createDummyTaskSetManager(1, 0, 1, clusterScheduler, taskSet)
+ val taskSetManager11 = createDummyTaskSetManager(1, 1, 1, clusterScheduler, taskSet)
+ val taskSetManager12 = createDummyTaskSetManager(1, 2, 2, clusterScheduler, taskSet)
+ schedulableBuilder.addTaskSetManager(taskSetManager10, properties1)
+ schedulableBuilder.addTaskSetManager(taskSetManager11, properties1)
+ schedulableBuilder.addTaskSetManager(taskSetManager12, properties1)
+
+ val taskSetManager23 = createDummyTaskSetManager(2, 3, 2, clusterScheduler, taskSet)
+ val taskSetManager24 = createDummyTaskSetManager(2, 4, 2, clusterScheduler, taskSet)
+ schedulableBuilder.addTaskSetManager(taskSetManager23, properties2)
+ schedulableBuilder.addTaskSetManager(taskSetManager24, properties2)
+
+ checkTaskSetId(rootPool, 0)
+ checkTaskSetId(rootPool, 3)
+ checkTaskSetId(rootPool, 3)
+ checkTaskSetId(rootPool, 1)
+ checkTaskSetId(rootPool, 4)
+ checkTaskSetId(rootPool, 2)
+ checkTaskSetId(rootPool, 2)
+ checkTaskSetId(rootPool, 4)
+
+ taskSetManager12.taskFinished()
+ assert(rootPool.getSchedulableByName("1").runningTasks === 3)
+ taskSetManager24.abort()
+ assert(rootPool.getSchedulableByName("2").runningTasks === 2)
+ }
+
+ test("Nested Pool Test") {
+ sc = new SparkContext("local", "ClusterSchedulerSuite")
+ val clusterScheduler = new ClusterScheduler(sc)
+ var tasks = ArrayBuffer[Task[_]]()
+ val task = new FakeTask(0)
+ tasks += task
+    val taskSet = new TaskSet(tasks.toArray, 0, 0, 0, null)
+
+ val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0)
+ val pool0 = new Pool("0", SchedulingMode.FAIR, 3, 1)
+ val pool1 = new Pool("1", SchedulingMode.FAIR, 4, 1)
+ rootPool.addSchedulable(pool0)
+ rootPool.addSchedulable(pool1)
+
+ val pool00 = new Pool("00", SchedulingMode.FAIR, 2, 2)
+ val pool01 = new Pool("01", SchedulingMode.FAIR, 1, 1)
+ pool0.addSchedulable(pool00)
+ pool0.addSchedulable(pool01)
+
+ val pool10 = new Pool("10", SchedulingMode.FAIR, 2, 2)
+ val pool11 = new Pool("11", SchedulingMode.FAIR, 2, 1)
+ pool1.addSchedulable(pool10)
+ pool1.addSchedulable(pool11)
+
+ val taskSetManager000 = createDummyTaskSetManager(0, 0, 5, clusterScheduler, taskSet)
+ val taskSetManager001 = createDummyTaskSetManager(0, 1, 5, clusterScheduler, taskSet)
+ pool00.addSchedulable(taskSetManager000)
+ pool00.addSchedulable(taskSetManager001)
+
+ val taskSetManager010 = createDummyTaskSetManager(1, 2, 5, clusterScheduler, taskSet)
+ val taskSetManager011 = createDummyTaskSetManager(1, 3, 5, clusterScheduler, taskSet)
+ pool01.addSchedulable(taskSetManager010)
+ pool01.addSchedulable(taskSetManager011)
+
+ val taskSetManager100 = createDummyTaskSetManager(2, 4, 5, clusterScheduler, taskSet)
+ val taskSetManager101 = createDummyTaskSetManager(2, 5, 5, clusterScheduler, taskSet)
+ pool10.addSchedulable(taskSetManager100)
+ pool10.addSchedulable(taskSetManager101)
+
+ val taskSetManager110 = createDummyTaskSetManager(3, 6, 5, clusterScheduler, taskSet)
+ val taskSetManager111 = createDummyTaskSetManager(3, 7, 5, clusterScheduler, taskSet)
+ pool11.addSchedulable(taskSetManager110)
+ pool11.addSchedulable(taskSetManager111)
+
+ checkTaskSetId(rootPool, 0)
+ checkTaskSetId(rootPool, 4)
+ checkTaskSetId(rootPool, 6)
+ checkTaskSetId(rootPool, 2)
+ }
+}
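The task-set IDs expected by checkTaskSetId above follow from the pool ordering: FIFO orders by
priority and then stage ID, while the fair ordering favours schedulables that are still below their
minShare, breaking ties by how much of that share they already use and then by running tasks per
unit of weight. A self-contained approximation of the fair ordering, for intuition only; PoolState
and fairBefore are illustrative names, not the FairSchedulingAlgorithm in this patch:

    // Minimal stand-in for the state the fair ordering inspects.
    case class PoolState(name: String, runningTasks: Int, minShare: Int, weight: Int)

    // Returns true if s1 should be offered resources before s2.
    def fairBefore(s1: PoolState, s2: PoolState): Boolean = {
      val s1Needy = s1.runningTasks < s1.minShare
      val s2Needy = s2.runningTasks < s2.minShare
      if (s1Needy != s2Needy) {
        s1Needy                                // a pool under its min share goes first
      } else if (s1Needy && s2Needy) {
        // both under their share: whichever uses the smaller fraction of it goes first
        s1.runningTasks.toDouble / math.max(s1.minShare, 1) <
          s2.runningTasks.toDouble / math.max(s2.minShare, 1)
      } else {
        // neither is starved: fall back to running tasks per unit of weight
        s1.runningTasks.toDouble / s1.weight < s2.runningTasks.toDouble / s2.weight
      }
    }

    // With one running task each, pool "2" (minShare 3) is served before pool "1"
    // (minShare 2), because it is further below its share.
    val queue = Seq(PoolState("1", 1, 2, 1), PoolState("2", 1, 3, 1)).sortWith(fairBefore)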
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
new file mode 100644
index 0000000000..a4f63baf3d
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
@@ -0,0 +1,273 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable
+
+import org.scalatest.FunSuite
+
+import org.apache.spark._
+import org.apache.spark.scheduler._
+import org.apache.spark.executor.TaskMetrics
+import java.nio.ByteBuffer
+import org.apache.spark.util.FakeClock
+
+/**
+ * A mock ClusterScheduler implementation that just remembers information about tasks started and
+ * feedback received from the TaskSetManagers. Note that it's important to initialize this with
+ * a list of "live" executors and their hostnames for isExecutorAlive and hasExecutorsAliveOnHost
+ * to work, and these are required for locality in ClusterTaskSetManager.
+ */
+class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */)
+ extends ClusterScheduler(sc)
+{
+ val startedTasks = new ArrayBuffer[Long]
+ val endedTasks = new mutable.HashMap[Long, TaskEndReason]
+ val finishedManagers = new ArrayBuffer[TaskSetManager]
+
+ val executors = new mutable.HashMap[String, String] ++ liveExecutors
+
+ listener = new TaskSchedulerListener {
+ def taskStarted(task: Task[_], taskInfo: TaskInfo) {
+ startedTasks += taskInfo.index
+ }
+
+ def taskEnded(
+ task: Task[_],
+ reason: TaskEndReason,
+ result: Any,
+ accumUpdates: mutable.Map[Long, Any],
+ taskInfo: TaskInfo,
+ taskMetrics: TaskMetrics)
+ {
+ endedTasks(taskInfo.index) = reason
+ }
+
+ def executorGained(execId: String, host: String) {}
+
+ def executorLost(execId: String) {}
+
+ def taskSetFailed(taskSet: TaskSet, reason: String) {}
+ }
+
+ def removeExecutor(execId: String): Unit = executors -= execId
+
+ override def taskSetFinished(manager: TaskSetManager): Unit = finishedManagers += manager
+
+ override def isExecutorAlive(execId: String): Boolean = executors.contains(execId)
+
+ override def hasExecutorsAliveOnHost(host: String): Boolean = executors.values.exists(_ == host)
+}
+
+class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
+ import TaskLocality.{ANY, PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL}
+
+ val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
+
+ test("TaskSet with no preferences") {
+ sc = new SparkContext("local", "test")
+ val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
+ val taskSet = createTaskSet(1)
+ val manager = new ClusterTaskSetManager(sched, taskSet)
+
+ // Offer a host with no CPUs
+ assert(manager.resourceOffer("exec1", "host1", 0, ANY) === None)
+
+ // Offer a host with process-local as the constraint; this should work because the TaskSet
+ // above won't have any locality preferences
+ val taskOption = manager.resourceOffer("exec1", "host1", 2, TaskLocality.PROCESS_LOCAL)
+ assert(taskOption.isDefined)
+ val task = taskOption.get
+ assert(task.executorId === "exec1")
+ assert(sched.startedTasks.contains(0))
+
+ // Re-offer the host -- now we should get no more tasks
+ assert(manager.resourceOffer("exec1", "host1", 2, PROCESS_LOCAL) === None)
+
+ // Tell it the task has finished
+ manager.statusUpdate(0, TaskState.FINISHED, createTaskResult(0))
+ assert(sched.endedTasks(0) === Success)
+ assert(sched.finishedManagers.contains(manager))
+ }
+
+ test("multiple offers with no preferences") {
+ sc = new SparkContext("local", "test")
+ val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
+ val taskSet = createTaskSet(3)
+ val manager = new ClusterTaskSetManager(sched, taskSet)
+
+ // First three offers should all find tasks
+ for (i <- 0 until 3) {
+ val taskOption = manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL)
+ assert(taskOption.isDefined)
+ val task = taskOption.get
+ assert(task.executorId === "exec1")
+ }
+ assert(sched.startedTasks.toSet === Set(0, 1, 2))
+
+ // Re-offer the host -- now we should get no more tasks
+ assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None)
+
+ // Finish the first two tasks
+ manager.statusUpdate(0, TaskState.FINISHED, createTaskResult(0))
+ manager.statusUpdate(1, TaskState.FINISHED, createTaskResult(1))
+ assert(sched.endedTasks(0) === Success)
+ assert(sched.endedTasks(1) === Success)
+ assert(!sched.finishedManagers.contains(manager))
+
+ // Finish the last task
+ manager.statusUpdate(2, TaskState.FINISHED, createTaskResult(2))
+ assert(sched.endedTasks(2) === Success)
+ assert(sched.finishedManagers.contains(manager))
+ }
+
+ test("basic delay scheduling") {
+ sc = new SparkContext("local", "test")
+ val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
+ val taskSet = createTaskSet(4,
+ Seq(TaskLocation("host1", "exec1")),
+ Seq(TaskLocation("host2", "exec2")),
+ Seq(TaskLocation("host1"), TaskLocation("host2", "exec2")),
+ Seq() // Last task has no locality prefs
+ )
+ val clock = new FakeClock
+ val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+
+ // First offer host1, exec1: first task should be chosen
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
+
+ // Offer host1, exec1 again: the last task, which has no prefs, should be chosen
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 3)
+
+ // Offer host1, exec1 again, at PROCESS_LOCAL level: nothing should get chosen
+ assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None)
+
+ clock.advance(LOCALITY_WAIT)
+
+ // Offer host1, exec1 again, at PROCESS_LOCAL level: nothing should get chosen
+ assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None)
+
+ // Offer host1, exec1 again, at NODE_LOCAL level: we should choose task 2
+ assert(manager.resourceOffer("exec1", "host1", 1, NODE_LOCAL).get.index == 2)
+
+ // Offer host1, exec1 again, at NODE_LOCAL level: nothing should get chosen
+ assert(manager.resourceOffer("exec1", "host1", 1, NODE_LOCAL) === None)
+
+ // Offer host1, exec1 again, at ANY level: nothing should get chosen
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None)
+
+ clock.advance(LOCALITY_WAIT)
+
+ // Offer host1, exec1 again, at ANY level: task 1 should get chosen
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1)
+
+ // Offer host1, exec1 again, at ANY level: nothing should be chosen as we've launched all tasks
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None)
+ }
+
+ test("delay scheduling with fallback") {
+ sc = new SparkContext("local", "test")
+ val sched = new FakeClusterScheduler(sc,
+ ("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3"))
+ val taskSet = createTaskSet(5,
+ Seq(TaskLocation("host1")),
+ Seq(TaskLocation("host2")),
+ Seq(TaskLocation("host2")),
+ Seq(TaskLocation("host3")),
+ Seq(TaskLocation("host2"))
+ )
+ val clock = new FakeClock
+ val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+
+ // First offer host1: first task should be chosen
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
+
+ // Offer host1 again: nothing should get chosen
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None)
+
+ clock.advance(LOCALITY_WAIT)
+
+ // Offer host1 again: second task (on host2) should get chosen
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1)
+
+ // Offer host1 again: third task (on host2) should get chosen
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 2)
+
+ // Offer host2: fifth task (also on host2) should get chosen
+ assert(manager.resourceOffer("exec2", "host2", 1, ANY).get.index === 4)
+
+ // Now that we've launched a local task, we should no longer launch the task for host3
+ assert(manager.resourceOffer("exec2", "host2", 1, ANY) === None)
+
+ clock.advance(LOCALITY_WAIT)
+
+ // After another delay, we can go ahead and launch that task non-locally
+ assert(manager.resourceOffer("exec2", "host2", 1, ANY).get.index === 3)
+ }
+
+ test("delay scheduling with failed hosts") {
+ sc = new SparkContext("local", "test")
+ val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
+ val taskSet = createTaskSet(3,
+ Seq(TaskLocation("host1")),
+ Seq(TaskLocation("host2")),
+ Seq(TaskLocation("host3"))
+ )
+ val clock = new FakeClock
+ val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+
+ // First offer host1: first task should be chosen
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
+
+ // Offer host1 again: third task should be chosen immediately because host3 is not up
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 2)
+
+ // After this, nothing should get chosen
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None)
+
+ // Now mark host2 as dead
+ sched.removeExecutor("exec2")
+ manager.executorLost("exec2", "host2")
+
+ // Task 1 should immediately be launched on host1 because its original host is gone
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1)
+
+ // Now that all tasks have launched, nothing new should be launched anywhere else
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None)
+ assert(manager.resourceOffer("exec2", "host2", 1, ANY) === None)
+ }
+
+ /**
+ * Utility method to create a TaskSet, potentially setting a particular sequence of preferred
+ * locations for each task (given as varargs) if this sequence is not empty.
+ */
+ def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
+ if (prefLocs.size != 0 && prefLocs.size != numTasks) {
+ throw new IllegalArgumentException("Wrong number of task locations")
+ }
+ val tasks = Array.tabulate[Task[_]](numTasks) { i =>
+ new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil)
+ }
+ new TaskSet(tasks, 0, 0, 0, null)
+ }
+
+ def createTaskResult(id: Int): ByteBuffer = {
+ ByteBuffer.wrap(Utils.serialize(new TaskResult[Int](id, mutable.Map.empty, new TaskMetrics)))
+ }
+}
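The delay-scheduling tests above all follow the same pattern: an offer at a stricter locality level
than any remaining task can satisfy returns None until LOCALITY_WAIT has elapsed on the FakeClock,
at which point the allowed level relaxes by one step. A standalone sketch of that escalation rule
(an illustration of the idea, with DelaySchedulingSketch as a made-up name; it is not the
ClusterTaskSetManager implementation):

    import org.apache.spark.scheduler.cluster.TaskLocality
    import org.apache.spark.scheduler.cluster.TaskLocality.TaskLocality

    // Illustrative only: tracks the most permissive locality level currently allowed,
    // relaxing it one step each time localityWait ms pass without a launch.
    class DelaySchedulingSketch(localityWait: Long) {
      private val levels = Array(TaskLocality.PROCESS_LOCAL, TaskLocality.NODE_LOCAL, TaskLocality.ANY)
      private var currentIndex = 0
      private var lastLaunchTime = 0L   // updated whenever a task launches at the current level

      def recordLaunch(now: Long) { lastLaunchTime = now }

      def allowedLocality(now: Long): TaskLocality = {
        while (currentIndex < levels.length - 1 && now - lastLaunchTime >= localityWait) {
          lastLaunchTime += localityWait   // consume one full wait per relaxation step
          currentIndex += 1                // PROCESS_LOCAL -> NODE_LOCAL -> ANY
        }
        levels(currentIndex)
      }
    }

In the tests, clock.advance(LOCALITY_WAIT) stands in for real time passing between offers.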
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala
new file mode 100644
index 0000000000..2f12aaed18
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.spark.scheduler.{TaskLocation, Task}
+
+class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId) {
+ override def run(attemptId: Long): Int = 0
+
+ override def preferredLocations: Seq[TaskLocation] = prefLocs
+}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala
new file mode 100644
index 0000000000..111340a65c
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.local
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark._
+import org.apache.spark.scheduler._
+import org.apache.spark.scheduler.cluster._
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ConcurrentMap, HashMap}
+import java.util.concurrent.Semaphore
+import java.util.concurrent.CountDownLatch
+import java.util.Properties
+
+class Lock() {
+ var finished = false
+ def jobWait() = {
+ synchronized {
+ while(!finished) {
+ this.wait()
+ }
+ }
+ }
+
+ def jobFinished() = {
+ synchronized {
+ finished = true
+ this.notifyAll()
+ }
+ }
+}
+
+object TaskThreadInfo {
+ val threadToLock = HashMap[Int, Lock]()
+ val threadToRunning = HashMap[Int, Boolean]()
+ val threadToStarted = HashMap[Int, CountDownLatch]()
+}
+
+/*
+ * Test setup:
+ * 1. Each thread runs one job.
+ * 2. Each job contains one stage.
+ * 3. Each stage contains only one task.
+ * 4. Launched tasks must be started in order (using threadToStarted) to make sure each one
+ *    actually gets a CPU core; a task then blocks until the test manually releases its "Lock",
+ *    at which point the cluster has a free CPU core again.
+ * 5. Pending tasks rely on "sleep" to make sure they have been added to the taskSetManager
+ *    queue, so that they are scheduled once the cluster has free CPU cores.
+ *
+ * A stripped-down sketch of this handshake appears after this file.
+ */
+class LocalSchedulerSuite extends FunSuite with LocalSparkContext {
+
+ def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) {
+
+ TaskThreadInfo.threadToRunning(threadIndex) = false
+ val nums = sc.parallelize(threadIndex to threadIndex, 1)
+ TaskThreadInfo.threadToLock(threadIndex) = new Lock()
+ TaskThreadInfo.threadToStarted(threadIndex) = new CountDownLatch(1)
+ new Thread {
+ if (poolName != null) {
+ sc.setLocalProperty("spark.scheduler.cluster.fair.pool", poolName)
+ }
+ override def run() {
+ val ans = nums.map(number => {
+ TaskThreadInfo.threadToRunning(number) = true
+ TaskThreadInfo.threadToStarted(number).countDown()
+ TaskThreadInfo.threadToLock(number).jobWait()
+ TaskThreadInfo.threadToRunning(number) = false
+ number
+ }).collect()
+ assert(ans.toList === List(threadIndex))
+ sem.release()
+ }
+ }.start()
+ }
+
+ test("Local FIFO scheduler end-to-end test") {
+ System.setProperty("spark.cluster.schedulingmode", "FIFO")
+ sc = new SparkContext("local[4]", "test")
+ val sem = new Semaphore(0)
+
+    createThread(1, null, sc, sem)
+    TaskThreadInfo.threadToStarted(1).await()
+    createThread(2, null, sc, sem)
+    TaskThreadInfo.threadToStarted(2).await()
+    createThread(3, null, sc, sem)
+    TaskThreadInfo.threadToStarted(3).await()
+    createThread(4, null, sc, sem)
+    TaskThreadInfo.threadToStarted(4).await()
+    // Threads 5 and 6 (whose stages stay pending) must satisfy two conditions:
+    // 1. the stages (taskSetManagers) of the jobs in threads 5 and 6 must be added to the
+    //    taskSetManager queue before TaskThreadInfo.threadToLock(1).jobFinished() runs
+    // 2. the stage in thread 5 must have higher priority than the stage in thread 6
+    // A 1s sleep after each thread is used to guarantee both.
+    // TODO: any better solution?
+    createThread(5, null, sc, sem)
+    Thread.sleep(1000)
+    createThread(6, null, sc, sem)
+ Thread.sleep(1000)
+
+ assert(TaskThreadInfo.threadToRunning(1) === true)
+ assert(TaskThreadInfo.threadToRunning(2) === true)
+ assert(TaskThreadInfo.threadToRunning(3) === true)
+ assert(TaskThreadInfo.threadToRunning(4) === true)
+ assert(TaskThreadInfo.threadToRunning(5) === false)
+ assert(TaskThreadInfo.threadToRunning(6) === false)
+
+ TaskThreadInfo.threadToLock(1).jobFinished()
+ TaskThreadInfo.threadToStarted(5).await()
+
+ assert(TaskThreadInfo.threadToRunning(1) === false)
+ assert(TaskThreadInfo.threadToRunning(2) === true)
+ assert(TaskThreadInfo.threadToRunning(3) === true)
+ assert(TaskThreadInfo.threadToRunning(4) === true)
+ assert(TaskThreadInfo.threadToRunning(5) === true)
+ assert(TaskThreadInfo.threadToRunning(6) === false)
+
+ TaskThreadInfo.threadToLock(3).jobFinished()
+ TaskThreadInfo.threadToStarted(6).await()
+
+ assert(TaskThreadInfo.threadToRunning(1) === false)
+ assert(TaskThreadInfo.threadToRunning(2) === true)
+ assert(TaskThreadInfo.threadToRunning(3) === false)
+ assert(TaskThreadInfo.threadToRunning(4) === true)
+ assert(TaskThreadInfo.threadToRunning(5) === true)
+ assert(TaskThreadInfo.threadToRunning(6) === true)
+
+ TaskThreadInfo.threadToLock(2).jobFinished()
+ TaskThreadInfo.threadToLock(4).jobFinished()
+ TaskThreadInfo.threadToLock(5).jobFinished()
+ TaskThreadInfo.threadToLock(6).jobFinished()
+ sem.acquire(6)
+ }
+
+ test("Local fair scheduler end-to-end test") {
+ sc = new SparkContext("local[8]", "LocalSchedulerSuite")
+ val sem = new Semaphore(0)
+ System.setProperty("spark.cluster.schedulingmode", "FAIR")
+ val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
+ System.setProperty("spark.fairscheduler.allocation.file", xmlPath)
+
+    createThread(10, "1", sc, sem)
+    TaskThreadInfo.threadToStarted(10).await()
+    createThread(20, "2", sc, sem)
+    TaskThreadInfo.threadToStarted(20).await()
+    createThread(30, "3", sc, sem)
+    TaskThreadInfo.threadToStarted(30).await()
+
+ assert(TaskThreadInfo.threadToRunning(10) === true)
+ assert(TaskThreadInfo.threadToRunning(20) === true)
+ assert(TaskThreadInfo.threadToRunning(30) === true)
+
+    createThread(11, "1", sc, sem)
+    TaskThreadInfo.threadToStarted(11).await()
+    createThread(21, "2", sc, sem)
+    TaskThreadInfo.threadToStarted(21).await()
+    createThread(31, "3", sc, sem)
+    TaskThreadInfo.threadToStarted(31).await()
+
+ assert(TaskThreadInfo.threadToRunning(11) === true)
+ assert(TaskThreadInfo.threadToRunning(21) === true)
+ assert(TaskThreadInfo.threadToRunning(31) === true)
+
+    createThread(12, "1", sc, sem)
+    TaskThreadInfo.threadToStarted(12).await()
+    createThread(22, "2", sc, sem)
+    TaskThreadInfo.threadToStarted(22).await()
+    createThread(32, "3", sc, sem)
+
+ assert(TaskThreadInfo.threadToRunning(12) === true)
+ assert(TaskThreadInfo.threadToRunning(22) === true)
+ assert(TaskThreadInfo.threadToRunning(32) === false)
+
+ TaskThreadInfo.threadToLock(10).jobFinished()
+ TaskThreadInfo.threadToStarted(32).await()
+
+ assert(TaskThreadInfo.threadToRunning(32) === true)
+
+    // 1. As in the scenario above, sleep 1s so that the stages for threads 23 and 33 are added
+    //    to the taskSetManager queue, letting the cluster assign a free CPU core to stage 23
+    //    once stage 11 has finished.
+    // 2. The relative priority of 23 and 33 is irrelevant here because the fair scheduler is used.
+    createThread(23, "2", sc, sem)
+    createThread(33, "3", sc, sem)
+ Thread.sleep(1000)
+
+ TaskThreadInfo.threadToLock(11).jobFinished()
+ TaskThreadInfo.threadToStarted(23).await()
+
+ assert(TaskThreadInfo.threadToRunning(23) === true)
+ assert(TaskThreadInfo.threadToRunning(33) === false)
+
+ TaskThreadInfo.threadToLock(12).jobFinished()
+ TaskThreadInfo.threadToStarted(33).await()
+
+ assert(TaskThreadInfo.threadToRunning(33) === true)
+
+ TaskThreadInfo.threadToLock(20).jobFinished()
+ TaskThreadInfo.threadToLock(21).jobFinished()
+ TaskThreadInfo.threadToLock(22).jobFinished()
+ TaskThreadInfo.threadToLock(23).jobFinished()
+ TaskThreadInfo.threadToLock(30).jobFinished()
+ TaskThreadInfo.threadToLock(31).jobFinished()
+ TaskThreadInfo.threadToLock(32).jobFinished()
+ TaskThreadInfo.threadToLock(33).jobFinished()
+
+ sem.acquire(11)
+ }
+}
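Stripped of the Spark pieces, the handshake described in the header comment of this suite is a
CountDownLatch for "the task has started" plus the Lock for "the test allows the task to finish".
A minimal sketch, assuming the Lock class defined at the top of this file is in scope
(HandshakeSketch is an illustrative name, not part of the patch):

    import java.util.concurrent.CountDownLatch

    object HandshakeSketch {
      def main(args: Array[String]) {
        val lock = new Lock()            // the Lock defined in this suite
        val started = new CountDownLatch(1)

        val worker = new Thread {
          override def run() {
            started.countDown()          // "the task is running and occupies a core"
            lock.jobWait()               // block until the test releases it
          }
        }
        worker.start()

        started.await()                  // proceed only once the task has started
        // ... assertions about which tasks are currently running would go here ...
        lock.jobFinished()               // release the task, freeing its CPU core
        worker.join()
      }
    }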
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
new file mode 100644
index 0000000000..88ba10f2f2
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -0,0 +1,666 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.nio.ByteBuffer
+
+import akka.actor._
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+import org.scalatest.PrivateMethodTester
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.concurrent.Timeouts._
+import org.scalatest.matchers.ShouldMatchers._
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.JavaSerializer
+import org.apache.spark.KryoSerializer
+import org.apache.spark.SizeEstimator
+import org.apache.spark.Utils
+import org.apache.spark.util.AkkaUtils
+import org.apache.spark.util.ByteBufferInputStream
+
+
+class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester {
+ var store: BlockManager = null
+ var store2: BlockManager = null
+ var actorSystem: ActorSystem = null
+ var master: BlockManagerMaster = null
+ var oldArch: String = null
+ var oldOops: String = null
+ var oldHeartBeat: String = null
+
+ // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test
+ System.setProperty("spark.kryoserializer.buffer.mb", "1")
+ val serializer = new KryoSerializer
+
+ before {
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0)
+ this.actorSystem = actorSystem
+ System.setProperty("spark.driver.port", boundPort.toString)
+ System.setProperty("spark.hostPort", "localhost:" + boundPort)
+
+ master = new BlockManagerMaster(
+ actorSystem.actorOf(Props(new BlockManagerMasterActor(true))))
+
+ // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
+ oldArch = System.setProperty("os.arch", "amd64")
+ oldOops = System.setProperty("spark.test.useCompressedOops", "true")
+ oldHeartBeat = System.setProperty("spark.storage.disableBlockManagerHeartBeat", "true")
+ val initialize = PrivateMethod[Unit]('initialize)
+ SizeEstimator invokePrivate initialize()
+ // Set some value ...
+ System.setProperty("spark.hostPort", Utils.localHostName() + ":" + 1111)
+ }
+
+ after {
+ System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
+
+ if (store != null) {
+ store.stop()
+ store = null
+ }
+ if (store2 != null) {
+ store2.stop()
+ store2 = null
+ }
+ actorSystem.shutdown()
+ actorSystem.awaitTermination()
+ actorSystem = null
+ master = null
+
+ if (oldArch != null) {
+ System.setProperty("os.arch", oldArch)
+ } else {
+ System.clearProperty("os.arch")
+ }
+
+ if (oldOops != null) {
+ System.setProperty("spark.test.useCompressedOops", oldOops)
+ } else {
+ System.clearProperty("spark.test.useCompressedOops")
+ }
+ }
+
+ test("StorageLevel object caching") {
+ val level1 = StorageLevel(false, false, false, 3)
+ val level2 = StorageLevel(false, false, false, 3) // this should return the same object as level1
+ val level3 = StorageLevel(false, false, false, 2) // this should return a different object
+ assert(level2 === level1, "level2 is not same as level1")
+ assert(level2.eq(level1), "level2 is not the same object as level1")
+ assert(level3 != level1, "level3 is same as level1")
+ val bytes1 = Utils.serialize(level1)
+ val level1_ = Utils.deserialize[StorageLevel](bytes1)
+ val bytes2 = Utils.serialize(level2)
+ val level2_ = Utils.deserialize[StorageLevel](bytes2)
+ assert(level1_ === level1, "Deserialized level1 not same as original level1")
+    assert(level1_.eq(level1), "Deserialized level1 not the same object as original level1")
+ assert(level2_ === level2, "Deserialized level2 not same as original level2")
+ assert(level2_.eq(level1), "Deserialized level2 not the same object as original level1")
+ }
+
+ test("BlockManagerId object caching") {
+ val id1 = BlockManagerId("e1", "XXX", 1, 0)
+ val id2 = BlockManagerId("e1", "XXX", 1, 0) // this should return the same object as id1
+ val id3 = BlockManagerId("e1", "XXX", 2, 0) // this should return a different object
+ assert(id2 === id1, "id2 is not same as id1")
+ assert(id2.eq(id1), "id2 is not the same object as id1")
+ assert(id3 != id1, "id3 is same as id1")
+ val bytes1 = Utils.serialize(id1)
+ val id1_ = Utils.deserialize[BlockManagerId](bytes1)
+ val bytes2 = Utils.serialize(id2)
+ val id2_ = Utils.deserialize[BlockManagerId](bytes2)
+ assert(id1_ === id1, "Deserialized id1 is not same as original id1")
+ assert(id1_.eq(id1), "Deserialized id1 is not the same object as original id1")
+ assert(id2_ === id2, "Deserialized id2 is not same as original id2")
+ assert(id2_.eq(id1), "Deserialized id2 is not the same object as original id1")
+ }
+
+ test("master + 1 manager interaction") {
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
+ val a1 = new Array[Byte](400)
+ val a2 = new Array[Byte](400)
+ val a3 = new Array[Byte](400)
+
+ // Putting a1, a2 and a3 in memory and telling master only about a1 and a2
+ store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
+ store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY)
+ store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY, tellMaster = false)
+
+ // Checking whether blocks are in memory
+ assert(store.getSingle("a1") != None, "a1 was not in store")
+ assert(store.getSingle("a2") != None, "a2 was not in store")
+ assert(store.getSingle("a3") != None, "a3 was not in store")
+
+ // Checking whether master knows about the blocks or not
+ assert(master.getLocations("a1").size > 0, "master was not told about a1")
+ assert(master.getLocations("a2").size > 0, "master was not told about a2")
+ assert(master.getLocations("a3").size === 0, "master was told about a3")
+
+ // Drop a1 and a2 from memory; this should be reported back to the master
+ store.dropFromMemory("a1", null)
+ store.dropFromMemory("a2", null)
+ assert(store.getSingle("a1") === None, "a1 not removed from store")
+ assert(store.getSingle("a2") === None, "a2 not removed from store")
+ assert(master.getLocations("a1").size === 0, "master did not remove a1")
+ assert(master.getLocations("a2").size === 0, "master did not remove a2")
+ }
+
+ test("master + 2 managers interaction") {
+ store = new BlockManager("exec1", actorSystem, master, serializer, 2000)
+ store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer, 2000)
+
+ val peers = master.getPeers(store.blockManagerId, 1)
+ assert(peers.size === 1, "master did not return the other manager as a peer")
+ assert(peers.head === store2.blockManagerId, "peer returned by master is not the other manager")
+
+ val a1 = new Array[Byte](400)
+ val a2 = new Array[Byte](400)
+ store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_2)
+ store2.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_2)
+ assert(master.getLocations("a1").size === 2, "master did not report 2 locations for a1")
+ assert(master.getLocations("a2").size === 2, "master did not report 2 locations for a2")
+ }
+
+ test("removing block") {
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
+ val a1 = new Array[Byte](400)
+ val a2 = new Array[Byte](400)
+ val a3 = new Array[Byte](400)
+
+ // Putting a1, a2 and a3 in memory and telling master only about a1 and a2
+ store.putSingle("a1-to-remove", a1, StorageLevel.MEMORY_ONLY)
+ store.putSingle("a2-to-remove", a2, StorageLevel.MEMORY_ONLY)
+ store.putSingle("a3-to-remove", a3, StorageLevel.MEMORY_ONLY, tellMaster = false)
+
+ // Checking whether blocks are in memory and memory size
+ val memStatus = master.getMemoryStatus.head._2
+ assert(memStatus._1 == 2000L, "total memory " + memStatus._1 + " should equal 2000")
+ assert(memStatus._2 <= 1200L, "remaining memory " + memStatus._2 + " should <= 1200")
+ assert(store.getSingle("a1-to-remove") != None, "a1 was not in store")
+ assert(store.getSingle("a2-to-remove") != None, "a2 was not in store")
+ assert(store.getSingle("a3-to-remove") != None, "a3 was not in store")
+
+ // Checking whether master knows about the blocks or not
+ assert(master.getLocations("a1-to-remove").size > 0, "master was not told about a1")
+ assert(master.getLocations("a2-to-remove").size > 0, "master was not told about a2")
+ assert(master.getLocations("a3-to-remove").size === 0, "master was told about a3")
+
+ // Remove a1 and a2 and a3. Should be no-op for a3.
+ master.removeBlock("a1-to-remove")
+ master.removeBlock("a2-to-remove")
+ master.removeBlock("a3-to-remove")
+
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ store.getSingle("a1-to-remove") should be (None)
+ master.getLocations("a1-to-remove") should have size 0
+ }
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ store.getSingle("a2-to-remove") should be (None)
+ master.getLocations("a2-to-remove") should have size 0
+ }
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ store.getSingle("a3-to-remove") should not be (None)
+ master.getLocations("a3-to-remove") should have size 0
+ }
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ val memStatus = master.getMemoryStatus.head._2
+ memStatus._1 should equal (2000L)
+ memStatus._2 should equal (2000L)
+ }
+ }
+
+ test("removing rdd") {
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
+ val a1 = new Array[Byte](400)
+ val a2 = new Array[Byte](400)
+ val a3 = new Array[Byte](400)
+ // Putting a1, a2 and a3 in memory.
+ store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY)
+ store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY)
+ store.putSingle("nonrddblock", a3, StorageLevel.MEMORY_ONLY)
+ master.removeRdd(0, blocking = false)
+
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ store.getSingle("rdd_0_0") should be (None)
+ master.getLocations("rdd_0_0") should have size 0
+ }
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ store.getSingle("rdd_0_1") should be (None)
+ master.getLocations("rdd_0_1") should have size 0
+ }
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ store.getSingle("nonrddblock") should not be (None)
+ master.getLocations("nonrddblock") should have size (1)
+ }
+
+ store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY)
+ store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY)
+ master.removeRdd(0, blocking = true)
+ store.getSingle("rdd_0_0") should be (None)
+ master.getLocations("rdd_0_0") should have size 0
+ store.getSingle("rdd_0_1") should be (None)
+ master.getLocations("rdd_0_1") should have size 0
+ }
+
+ test("reregistration on heart beat") {
+ val heartBeat = PrivateMethod[Unit]('heartBeat)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
+ val a1 = new Array[Byte](400)
+
+ store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
+
+ assert(store.getSingle("a1") != None, "a1 was not in store")
+ assert(master.getLocations("a1").size > 0, "master was not told about a1")
+
+ master.removeExecutor(store.blockManagerId.executorId)
+ assert(master.getLocations("a1").size == 0, "a1 was not removed from master")
+
+ store invokePrivate heartBeat()
+ assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master")
+ }
+
+ test("reregistration on block update") {
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
+ val a1 = new Array[Byte](400)
+ val a2 = new Array[Byte](400)
+
+ store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
+ assert(master.getLocations("a1").size > 0, "master was not told about a1")
+
+ master.removeExecutor(store.blockManagerId.executorId)
+ assert(master.getLocations("a1").size == 0, "a1 was not removed from master")
+
+ store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY)
+ store.waitForAsyncReregister()
+
+ assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master")
+ assert(master.getLocations("a2").size > 0, "master was not told about a2")
+ }
+
+ test("reregistration doesn't dead lock") {
+ val heartBeat = PrivateMethod[Unit]('heartBeat)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
+ val a1 = new Array[Byte](400)
+ val a2 = List(new Array[Byte](400))
+
+ // try many times to trigger any deadlocks
+ for (i <- 1 to 100) {
+ master.removeExecutor(store.blockManagerId.executorId)
+ val t1 = new Thread {
+ override def run() {
+ store.put("a2", a2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
+ }
+ }
+ val t2 = new Thread {
+ override def run() {
+ store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
+ }
+ }
+ val t3 = new Thread {
+ override def run() {
+ store invokePrivate heartBeat()
+ }
+ }
+
+ t1.start()
+ t2.start()
+ t3.start()
+ t1.join()
+ t2.join()
+ t3.join()
+
+ store.dropFromMemory("a1", null)
+ store.dropFromMemory("a2", null)
+ store.waitForAsyncReregister()
+ }
+ }
+
+ test("in-memory LRU storage") {
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
+ val a1 = new Array[Byte](400)
+ val a2 = new Array[Byte](400)
+ val a3 = new Array[Byte](400)
+ store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
+ store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY)
+ store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY)
+ assert(store.getSingle("a2") != None, "a2 was not in store")
+ assert(store.getSingle("a3") != None, "a3 was not in store")
+ assert(store.getSingle("a1") === None, "a1 was in store")
+ assert(store.getSingle("a2") != None, "a2 was not in store")
+    // At this point a2 was accessed last, so when a1 is put back in, LRU should evict a3
+ store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
+ assert(store.getSingle("a1") != None, "a1 was not in store")
+ assert(store.getSingle("a2") != None, "a2 was not in store")
+ assert(store.getSingle("a3") === None, "a3 was in store")
+ }
+
+ test("in-memory LRU storage with serialization") {
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
+ val a1 = new Array[Byte](400)
+ val a2 = new Array[Byte](400)
+ val a3 = new Array[Byte](400)
+ store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_SER)
+ store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_SER)
+ store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY_SER)
+ assert(store.getSingle("a2") != None, "a2 was not in store")
+ assert(store.getSingle("a3") != None, "a3 was not in store")
+ assert(store.getSingle("a1") === None, "a1 was in store")
+ assert(store.getSingle("a2") != None, "a2 was not in store")
+    // At this point a2 was accessed last, so when a1 is put back in, LRU should evict a3
+ store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_SER)
+ assert(store.getSingle("a1") != None, "a1 was not in store")
+ assert(store.getSingle("a2") != None, "a2 was not in store")
+ assert(store.getSingle("a3") === None, "a3 was in store")
+ }
+
+ test("in-memory LRU for partitions of same RDD") {
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
+ val a1 = new Array[Byte](400)
+ val a2 = new Array[Byte](400)
+ val a3 = new Array[Byte](400)
+ store.putSingle("rdd_0_1", a1, StorageLevel.MEMORY_ONLY)
+ store.putSingle("rdd_0_2", a2, StorageLevel.MEMORY_ONLY)
+ store.putSingle("rdd_0_3", a3, StorageLevel.MEMORY_ONLY)
+ // Even though we accessed rdd_0_3 last, it should not have replaced partitions 1 and 2
+ // from the same RDD
+ assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store")
+ assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store")
+ assert(store.getSingle("rdd_0_1") != None, "rdd_0_1 was not in store")
+ // Check that rdd_0_3 doesn't replace them even after further accesses
+ assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store")
+ assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store")
+ assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store")
+ }
+
+ test("in-memory LRU for partitions of multiple RDDs") {
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
+ store.putSingle("rdd_0_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
+ store.putSingle("rdd_0_2", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
+ store.putSingle("rdd_1_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
+ // At this point rdd_1_1 should've replaced rdd_0_1
+ assert(store.memoryStore.contains("rdd_1_1"), "rdd_1_1 was not in store")
+ assert(!store.memoryStore.contains("rdd_0_1"), "rdd_0_1 was in store")
+ assert(store.memoryStore.contains("rdd_0_2"), "rdd_0_2 was not in store")
+ // Do a get() on rdd_0_2 so that it is the most recently used item
+ assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store")
+ // Put in more partitions from RDD 0; they should replace rdd_1_1
+ store.putSingle("rdd_0_3", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
+ store.putSingle("rdd_0_4", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
+ // Now rdd_1_1 should be dropped to add rdd_0_3, but then rdd_0_2 should *not* be dropped
+ // when we try to add rdd_0_4.
+ assert(!store.memoryStore.contains("rdd_1_1"), "rdd_1_1 was in store")
+ assert(!store.memoryStore.contains("rdd_0_1"), "rdd_0_1 was in store")
+ assert(!store.memoryStore.contains("rdd_0_4"), "rdd_0_4 was in store")
+ assert(store.memoryStore.contains("rdd_0_2"), "rdd_0_2 was not in store")
+ assert(store.memoryStore.contains("rdd_0_3"), "rdd_0_3 was not in store")
+ }
+
+ test("on-disk storage") {
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
+ val a1 = new Array[Byte](400)
+ val a2 = new Array[Byte](400)
+ val a3 = new Array[Byte](400)
+ store.putSingle("a1", a1, StorageLevel.DISK_ONLY)
+ store.putSingle("a2", a2, StorageLevel.DISK_ONLY)
+ store.putSingle("a3", a3, StorageLevel.DISK_ONLY)
+ assert(store.getSingle("a2") != None, "a2 was in store")
+ assert(store.getSingle("a3") != None, "a3 was in store")
+ assert(store.getSingle("a1") != None, "a1 was in store")
+ }
+
+ test("disk and memory storage") {
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
+ val a1 = new Array[Byte](400)
+ val a2 = new Array[Byte](400)
+ val a3 = new Array[Byte](400)
+ store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK)
+ store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK)
+ store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK)
+ assert(store.getSingle("a2") != None, "a2 was not in store")
+ assert(store.getSingle("a3") != None, "a3 was not in store")
+ assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store")
+ assert(store.getSingle("a1") != None, "a1 was not in store")
+ assert(store.memoryStore.getValues("a1") != None, "a1 was not in memory store")
+ }
+
+ test("disk and memory storage with getLocalBytes") {
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
+ val a1 = new Array[Byte](400)
+ val a2 = new Array[Byte](400)
+ val a3 = new Array[Byte](400)
+ store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK)
+ store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK)
+ store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK)
+ assert(store.getLocalBytes("a2") != None, "a2 was not in store")
+ assert(store.getLocalBytes("a3") != None, "a3 was not in store")
+ assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store")
+ assert(store.getLocalBytes("a1") != None, "a1 was not in store")
+ assert(store.memoryStore.getValues("a1") != None, "a1 was not in memory store")
+ }
+
+ test("disk and memory storage with serialization") {
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
+ val a1 = new Array[Byte](400)
+ val a2 = new Array[Byte](400)
+ val a3 = new Array[Byte](400)
+ store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK_SER)
+ store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK_SER)
+ store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK_SER)
+ assert(store.getSingle("a2") != None, "a2 was not in store")
+ assert(store.getSingle("a3") != None, "a3 was not in store")
+ assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store")
+ assert(store.getSingle("a1") != None, "a1 was not in store")
+ assert(store.memoryStore.getValues("a1") != None, "a1 was not in memory store")
+ }
+
+ test("disk and memory storage with serialization and getLocalBytes") {
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
+ val a1 = new Array[Byte](400)
+ val a2 = new Array[Byte](400)
+ val a3 = new Array[Byte](400)
+ store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK_SER)
+ store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK_SER)
+ store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK_SER)
+ assert(store.getLocalBytes("a2") != None, "a2 was not in store")
+ assert(store.getLocalBytes("a3") != None, "a3 was not in store")
+ assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store")
+ assert(store.getLocalBytes("a1") != None, "a1 was not in store")
+ assert(store.memoryStore.getValues("a1") != None, "a1 was not in memory store")
+ }
+
+ test("LRU with mixed storage levels") {
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
+ val a1 = new Array[Byte](400)
+ val a2 = new Array[Byte](400)
+ val a3 = new Array[Byte](400)
+ val a4 = new Array[Byte](400)
+ // First store a1 and a2, both in memory, and a3 on disk only
+ store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_SER)
+ store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_SER)
+ store.putSingle("a3", a3, StorageLevel.DISK_ONLY)
+ // At this point LRU should not kick in because a3 is only on disk
+ assert(store.getSingle("a1") != None, "a2 was not in store")
+ assert(store.getSingle("a2") != None, "a3 was not in store")
+ assert(store.getSingle("a3") != None, "a1 was not in store")
+ assert(store.getSingle("a1") != None, "a2 was not in store")
+ assert(store.getSingle("a2") != None, "a3 was not in store")
+ assert(store.getSingle("a3") != None, "a1 was not in store")
+ // Now let's add in a4, which uses both disk and memory; a1 should drop out
+ store.putSingle("a4", a4, StorageLevel.MEMORY_AND_DISK_SER)
+ assert(store.getSingle("a1") == None, "a1 was in store")
+ assert(store.getSingle("a2") != None, "a2 was not in store")
+ assert(store.getSingle("a3") != None, "a3 was not in store")
+ assert(store.getSingle("a4") != None, "a4 was not in store")
+ }
+
+ test("in-memory LRU with streams") {
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
+ val list1 = List(new Array[Byte](200), new Array[Byte](200))
+ val list2 = List(new Array[Byte](200), new Array[Byte](200))
+ val list3 = List(new Array[Byte](200), new Array[Byte](200))
+ store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
+ store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
+ store.put("list3", list3.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
+ assert(store.get("list2") != None, "list2 was not in store")
+ assert(store.get("list2").get.size == 2)
+ assert(store.get("list3") != None, "list3 was not in store")
+ assert(store.get("list3").get.size == 2)
+ assert(store.get("list1") === None, "list1 was in store")
+ assert(store.get("list2") != None, "list2 was not in store")
+ assert(store.get("list2").get.size == 2)
+ // At this point list2 was accessed last, so LRU should evict list3
+ store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
+ assert(store.get("list1") != None, "list1 was not in store")
+ assert(store.get("list1").get.size == 2)
+ assert(store.get("list2") != None, "list2 was not in store")
+ assert(store.get("list2").get.size == 2)
+ assert(store.get("list3") === None, "list1 was in store")
+ }
+
+ test("LRU with mixed storage levels and streams") {
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
+ val list1 = List(new Array[Byte](200), new Array[Byte](200))
+ val list2 = List(new Array[Byte](200), new Array[Byte](200))
+ val list3 = List(new Array[Byte](200), new Array[Byte](200))
+ val list4 = List(new Array[Byte](200), new Array[Byte](200))
+ // First store list1 and list2, both in memory, and list3 on disk only
+ store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true)
+ store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true)
+ store.put("list3", list3.iterator, StorageLevel.DISK_ONLY, tellMaster = true)
+ // At this point LRU should not kick in because list3 is only on disk
+ assert(store.get("list1") != None, "list2 was not in store")
+ assert(store.get("list1").get.size === 2)
+ assert(store.get("list2") != None, "list3 was not in store")
+ assert(store.get("list2").get.size === 2)
+ assert(store.get("list3") != None, "list1 was not in store")
+ assert(store.get("list3").get.size === 2)
+ assert(store.get("list1") != None, "list2 was not in store")
+ assert(store.get("list1").get.size === 2)
+ assert(store.get("list2") != None, "list3 was not in store")
+ assert(store.get("list2").get.size === 2)
+ assert(store.get("list3") != None, "list1 was not in store")
+ assert(store.get("list3").get.size === 2)
+ // Now let's add in list4, which uses both disk and memory; list1 should drop out
+ store.put("list4", list4.iterator, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)
+ assert(store.get("list1") === None, "list1 was in store")
+ assert(store.get("list2") != None, "list3 was not in store")
+ assert(store.get("list2").get.size === 2)
+ assert(store.get("list3") != None, "list1 was not in store")
+ assert(store.get("list3").get.size === 2)
+ assert(store.get("list4") != None, "list4 was not in store")
+ assert(store.get("list4").get.size === 2)
+ }
+
+ test("negative byte values in ByteBufferInputStream") {
+ val buffer = ByteBuffer.wrap(Array[Int](254, 255, 0, 1, 2).map(_.toByte))
+ val stream = new ByteBufferInputStream(buffer)
+ val temp = new Array[Byte](10)
+ assert(stream.read() === 254, "unexpected byte read")
+ assert(stream.read() === 255, "unexpected byte read")
+ assert(stream.read() === 0, "unexpected byte read")
+ assert(stream.read(temp, 0, temp.length) === 2, "unexpected number of bytes read")
+ assert(stream.read() === -1, "end of stream not signalled")
+ assert(stream.read(temp, 0, temp.length) === -1, "end of stream not signalled")
+ }
+
+ test("overly large block") {
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 500)
+ store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY)
+ assert(store.getSingle("a1") === None, "a1 was in store")
+ store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK)
+ assert(store.memoryStore.getValues("a2") === None, "a2 was in memory store")
+ assert(store.getSingle("a2") != None, "a2 was not in store")
+ }
+
+ test("block compression") {
+ try {
+ System.setProperty("spark.shuffle.compress", "true")
+ store = new BlockManager("exec1", actorSystem, master, serializer, 2000)
+ store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
+ assert(store.memoryStore.getSize("shuffle_0_0_0") <= 100, "shuffle_0_0_0 was not compressed")
+ store.stop()
+ store = null
+
+ System.setProperty("spark.shuffle.compress", "false")
+ store = new BlockManager("exec2", actorSystem, master, serializer, 2000)
+ store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
+ assert(store.memoryStore.getSize("shuffle_0_0_0") >= 1000, "shuffle_0_0_0 was compressed")
+ store.stop()
+ store = null
+
+ System.setProperty("spark.broadcast.compress", "true")
+ store = new BlockManager("exec3", actorSystem, master, serializer, 2000)
+ store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
+ assert(store.memoryStore.getSize("broadcast_0") <= 100, "broadcast_0 was not compressed")
+ store.stop()
+ store = null
+
+ System.setProperty("spark.broadcast.compress", "false")
+ store = new BlockManager("exec4", actorSystem, master, serializer, 2000)
+ store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
+ assert(store.memoryStore.getSize("broadcast_0") >= 1000, "broadcast_0 was compressed")
+ store.stop()
+ store = null
+
+ System.setProperty("spark.rdd.compress", "true")
+ store = new BlockManager("exec5", actorSystem, master, serializer, 2000)
+ store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
+ assert(store.memoryStore.getSize("rdd_0_0") <= 100, "rdd_0_0 was not compressed")
+ store.stop()
+ store = null
+
+ System.setProperty("spark.rdd.compress", "false")
+ store = new BlockManager("exec6", actorSystem, master, serializer, 2000)
+ store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
+ assert(store.memoryStore.getSize("rdd_0_0") >= 1000, "rdd_0_0 was compressed")
+ store.stop()
+ store = null
+
+ // Check that any other block types are also kept uncompressed
+ store = new BlockManager("exec7", actorSystem, master, serializer, 2000)
+ store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY)
+ assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed")
+ store.stop()
+ store = null
+ } finally {
+ System.clearProperty("spark.shuffle.compress")
+ System.clearProperty("spark.broadcast.compress")
+ System.clearProperty("spark.rdd.compress")
+ }
+ }
+
+ test("block store put failure") {
+ // Use Java serializer so we can create an unserializable error.
+ store = new BlockManager("<driver>", actorSystem, master, new JavaSerializer, 1200)
+
+ // The put should fail since a1 is not serializable.
+ class UnserializableClass
+ val a1 = new UnserializableClass
+ intercept[java.io.NotSerializableException] {
+ store.putSingle("a1", a1, StorageLevel.DISK_ONLY)
+ }
+
+ // Make sure that getting a1 doesn't hang and that it returns None.
+ failAfter(1 second) {
+ assert(store.getSingle("a1") == None, "a1 should not be in store")
+ }
+ }
+}
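
The LRU tests above pin down two rules that the inline comments describe: a block is only evicted to make room for a new one if it is not a partition of the same RDD, and disk-only puts never force anything out of memory. A minimal sketch of that selection rule, assuming an LRU-ordered map of block sizes (the names and structure here are illustrative, not Spark's actual MemoryStore):

    import scala.collection.mutable

    object LruEvictionSketch {
      // "rdd_0_1" -> Some("rdd_0"); non-RDD blocks have no prefix
      private def rddPrefix(blockId: String): Option[String] =
        if (blockId.startsWith("rdd_")) Some(blockId.split("_").take(2).mkString("_")) else None

      /** Block ids that would be dropped (oldest first), or None if the new block cannot fit. */
      def blocksToDrop(
          lruOrdered: mutable.LinkedHashMap[String, Long],  // blockId -> size, oldest first
          newBlockId: String,
          newBlockSize: Long,
          freeMemory: Long): Option[Seq[String]] = {
        val samePrefix = rddPrefix(newBlockId)
        var freed = 0L
        val victims = mutable.ArrayBuffer[String]()
        for ((id, size) <- lruOrdered if freed + freeMemory < newBlockSize) {
          // Never drop another partition of the RDD we are trying to cache
          if (samePrefix.isEmpty || rddPrefix(id) != samePrefix) {
            victims += id
            freed += size
          }
        }
        if (freed + freeMemory >= newBlockSize) Some(victims) else None
      }
    }

Returning None corresponds to the behaviour asserted in "in-memory LRU for partitions of same RDD" and "overly large block": the put is skipped rather than evicting protected blocks or accepting a block that cannot fit.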
diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
new file mode 100644
index 0000000000..3321fb5eb7
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui
+
+import scala.util.{Failure, Success, Try}
+import java.net.ServerSocket
+import org.scalatest.FunSuite
+import org.eclipse.jetty.server.Server
+
+class UISuite extends FunSuite {
+ test("jetty port increases under contention") {
+ val startPort = 3030
+ val server = new Server(startPort)
+ server.start()
+ val (jettyServer1, boundPort1) = JettyUtils.startJettyServer("localhost", startPort, Seq())
+ val (jettyServer2, boundPort2) = JettyUtils.startJettyServer("localhost", startPort, Seq())
+
+ // Allow some wiggle room in case ports on the machine are under contention
+ assert(boundPort1 > startPort && boundPort1 < startPort + 10)
+ assert(boundPort2 > boundPort1 && boundPort2 < boundPort1 + 10)
+ }
+
+ test("jetty binds to port 0 correctly") {
+ val (jettyServer, boundPort) = JettyUtils.startJettyServer("localhost", 0, Seq())
+ assert(jettyServer.getState === "STARTED")
+ assert(boundPort != 0)
+ Try { new ServerSocket(boundPort) } match {
+ case Success(s) => s.close(); fail("Port %s does not seem to be bound by the jetty server".format(boundPort))
+ case Failure(e) =>
+ }
+ }
+}
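
The contention test only constrains the outcome (each server binds within ten ports of the requested one), so the underlying behaviour is presumably a retry loop that increments the port after a failed bind. A hedged sketch of that pattern with plain Jetty (this is not JettyUtils.startJettyServer; the retry bound is an assumption taken from the assertions):

    import org.eclipse.jetty.server.Server

    object PortRetrySketch {
      /** Try startPort, startPort + 1, ... until a server binds, giving up after maxRetries. */
      def startOnFirstFreePort(startPort: Int, maxRetries: Int = 10): (Server, Int) = {
        var attempt = 0
        while (attempt <= maxRetries) {
          val port = startPort + attempt
          val server = new Server(port)
          try {
            server.start()
            return (server, port)
          } catch {
            case e: Exception =>   // typically a (possibly wrapped) java.net.BindException
              server.stop()
              attempt += 1
          }
        }
        sys.error("No free port found in [%d, %d]".format(startPort, startPort + maxRetries))
      }
    }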
diff --git a/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala b/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala
new file mode 100644
index 0000000000..63642461e4
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+
+/**
+ * Tests summary statistics and quantiles computed by the Distribution utility.
+ */
+
+class DistributionSuite extends FunSuite with ShouldMatchers {
+ test("summary") {
+ val d = new Distribution((1 to 100).toArray.map{_.toDouble})
+ val stats = d.statCounter
+ stats.count should be (100)
+ stats.mean should be (50.5)
+ stats.sum should be (50 * 101)
+
+ val quantiles = d.getQuantiles()
+ quantiles(0) should be (1)
+ quantiles(1) should be (26)
+ quantiles(2) should be (51)
+ quantiles(3) should be (76)
+ quantiles(4) should be (100)
+ }
+}
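
The expected quantiles (1, 26, 51, 76, 100 over the values 1 to 100) are consistent with indexing the sorted data at floor(p * n), clamped to the last element, for p in {0, 0.25, 0.5, 0.75, 1.0}. A small self-contained sketch of that calculation (illustrative only, not Distribution's implementation):

    object QuantileSketch {
      def quantiles(sorted: IndexedSeq[Double],
                    probs: Seq[Double] = Seq(0.0, 0.25, 0.5, 0.75, 1.0)): Seq[Double] =
        probs.map { p =>
          val idx = math.min((p * sorted.length).toInt, sorted.length - 1)
          sorted(idx)
        }

      def main(args: Array[String]) {
        // Matches the assertions above: Vector(1.0, 26.0, 51.0, 76.0, 100.0)
        println(quantiles((1 to 100).map(_.toDouble)))
      }
    }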
diff --git a/core/src/test/scala/org/apache/spark/util/FakeClock.scala b/core/src/test/scala/org/apache/spark/util/FakeClock.scala
new file mode 100644
index 0000000000..0a45917b08
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/FakeClock.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+class FakeClock extends Clock {
+ private var time = 0L
+
+ def advance(millis: Long): Unit = time += millis
+
+ def getTime(): Long = time
+}
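
FakeClock exists so that code which reads time through the Clock interface can be tested without sleeping: the test advances time explicitly and observes the effect. A hedged usage sketch with an invented Expiring helper (only getTime() is assumed from the Clock interface, since FakeClock implements it above):

    // Hypothetical component that reads time from an injected Clock.
    class Expiring(clock: Clock, timeoutMs: Long) {
      private val createdAt = clock.getTime()
      def expired: Boolean = clock.getTime() - createdAt >= timeoutMs
    }

    val clock = new FakeClock
    val e = new Expiring(clock, timeoutMs = 1000)
    assert(!e.expired)
    clock.advance(999)
    assert(!e.expired)
    clock.advance(1)       // deterministic; no Thread.sleep needed
    assert(e.expired)

This is the kind of pattern a scheduler test can use to exercise timeout logic deterministically instead of relying on the wall clock.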
diff --git a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala
new file mode 100644
index 0000000000..45867463a5
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import scala.collection.mutable.Buffer
+import java.util.NoSuchElementException
+
+class NextIteratorSuite extends FunSuite with ShouldMatchers {
+ test("one iteration") {
+ val i = new StubIterator(Buffer(1))
+ i.hasNext should be === true
+ i.next should be === 1
+ i.hasNext should be === false
+ intercept[NoSuchElementException] { i.next() }
+ }
+
+ test("two iterations") {
+ val i = new StubIterator(Buffer(1, 2))
+ i.hasNext should be === true
+ i.next should be === 1
+ i.hasNext should be === true
+ i.next should be === 2
+ i.hasNext should be === false
+ intercept[NoSuchElementException] { i.next() }
+ }
+
+ test("empty iteration") {
+ val i = new StubIterator(Buffer())
+ i.hasNext should be === false
+ intercept[NoSuchElementException] { i.next() }
+ }
+
+ test("close is called once for empty iterations") {
+ val i = new StubIterator(Buffer())
+ i.hasNext should be === false
+ i.hasNext should be === false
+ i.closeCalled should be === 1
+ }
+
+ test("close is called once for non-empty iterations") {
+ val i = new StubIterator(Buffer(1, 2))
+ i.next should be === 1
+ i.next should be === 2
+ // close isn't called until we check for the next element
+ i.closeCalled should be === 0
+ i.hasNext should be === false
+ i.closeCalled should be === 1
+ i.hasNext should be === false
+ i.closeCalled should be === 1
+ }
+
+ class StubIterator(ints: Buffer[Int]) extends NextIterator[Int] {
+ var closeCalled = 0
+
+ override def getNext() = {
+ if (ints.size == 0) {
+ finished = true
+ 0
+ } else {
+ ints.remove(0)
+ }
+ }
+
+ override def close() {
+ closeCalled += 1
+ }
+ }
+}
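
The suite fixes the NextIterator contract: getNext() either produces an element or sets finished, and close() runs exactly once, the first time hasNext observes the end. A hedged sketch of a typical concrete subclass built on that contract, reading lines from a BufferedReader (the class name is invented; only the members exercised above are assumed):

    import java.io.BufferedReader

    class LineIterator(reader: BufferedReader) extends NextIterator[String] {
      override def getNext(): String = {
        val line = reader.readLine()
        if (line == null) {
          finished = true   // signals the end; the returned value is then ignored
          null
        } else {
          line
        }
      }

      override def close() {
        reader.close()      // per the tests above, this runs exactly once
      }
    }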
diff --git a/core/src/test/scala/org/apache/spark/util/RateLimitedOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/RateLimitedOutputStreamSuite.scala
new file mode 100644
index 0000000000..a9dd0b1a5b
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/RateLimitedOutputStreamSuite.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import org.scalatest.FunSuite
+import java.io.ByteArrayOutputStream
+import java.util.concurrent.TimeUnit._
+
+class RateLimitedOutputStreamSuite extends FunSuite {
+
+ private def benchmark[U](f: => U): Long = {
+ val start = System.nanoTime
+ f
+ System.nanoTime - start
+ }
+
+ test("write") {
+ val underlying = new ByteArrayOutputStream
+ val data = "X" * 41000
+ val stream = new RateLimitedOutputStream(underlying, 10000)
+ val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) }
+ assert(SECONDS.convert(elapsedNs, NANOSECONDS) == 4)
+ assert(underlying.toString("UTF-8") == data)
+ }
+}
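
The expected duration in the write test is plain arithmetic: 41000 bytes at a limit of 10000 bytes per second takes about 4.1 seconds, and SECONDS.convert truncates that to 4. A minimal pacing sketch that would produce this behaviour (a sleep-to-catch-up loop; not the actual RateLimitedOutputStream):

    import java.io.OutputStream

    class PacedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream {
      private val startNs = System.nanoTime
      private var written = 0L

      override def write(b: Int) {
        out.write(b)
        written += 1
        pace()
      }

      private def pace() {
        // Sleep until the average rate is back under bytesPerSec
        val minElapsedNs = (written * 1e9 / bytesPerSec).toLong
        val sleepMs = (minElapsedNs - (System.nanoTime - startNs)) / 1000000
        if (sleepMs > 0) Thread.sleep(sleepMs)
      }
    }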