aboutsummaryrefslogtreecommitdiff
path: root/core/src/test
diff options
context:
space:
mode:
authorHossein Falaki <falaki@gmail.com>2013-12-30 15:08:34 -0800
committerHossein Falaki <falaki@gmail.com>2013-12-30 15:08:34 -0800
commitd50ccc5ca9f9f0fa6418c88e7fbfb4a87b1a0e68 (patch)
tree2f458388f4773621e34afb68a5efcec8e5209a6f /core/src/test
parent49bf47e1b792b82561b164f4f8006ddd4dd350ee (diff)
parentd63856c361cf47b1a508397ee9de38a7b5899fa0 (diff)
downloadspark-d50ccc5ca9f9f0fa6418c88e7fbfb4a87b1a0e68.tar.gz
spark-d50ccc5ca9f9f0fa6418c88e7fbfb4a87b1a0e68.tar.bz2
spark-d50ccc5ca9f9f0fa6418c88e7fbfb4a87b1a0e68.zip
Using origin version
Diffstat (limited to 'core/src/test')
-rw-r--r--core/src/test/scala/org/apache/spark/AccumulatorSuite.scala32
-rw-r--r--core/src/test/scala/org/apache/spark/BroadcastSuite.scala52
-rw-r--r--core/src/test/scala/org/apache/spark/CheckpointSuite.scala7
-rw-r--r--core/src/test/scala/org/apache/spark/DistributedSuite.scala5
-rw-r--r--core/src/test/scala/org/apache/spark/DriverSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/FileServerSuite.scala16
-rw-r--r--core/src/test/scala/org/apache/spark/JavaAPISuite.java68
-rw-r--r--core/src/test/scala/org/apache/spark/JobCancellationSuite.scala34
-rw-r--r--core/src/test/scala/org/apache/spark/LocalSparkContext.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala26
-rw-r--r--core/src/test/scala/org/apache/spark/PartitioningSuite.scala10
-rw-r--r--core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala140
-rw-r--r--core/src/test/scala/org/apache/spark/UnpersistSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala (renamed from core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala)35
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala26
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala271
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala86
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala28
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala51
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala24
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala132
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala5
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala4
-rw-r--r--core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala111
-rw-r--r--core/src/test/scala/org/apache/spark/ui/UISuite.scala1
-rw-r--r--core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala73
-rw-r--r--core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala72
-rw-r--r--core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala76
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala73
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala177
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala180
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala119
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala117
35 files changed, 1905 insertions, 156 deletions
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index 4434f3b87c..c443c5266e 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -27,6 +27,21 @@ import org.apache.spark.SparkContext._
class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
+
+ implicit def setAccum[A] = new AccumulableParam[mutable.Set[A], A] {
+ def addInPlace(t1: mutable.Set[A], t2: mutable.Set[A]) : mutable.Set[A] = {
+ t1 ++= t2
+ t1
+ }
+ def addAccumulator(t1: mutable.Set[A], t2: A) : mutable.Set[A] = {
+ t1 += t2
+ t1
+ }
+ def zero(t: mutable.Set[A]) : mutable.Set[A] = {
+ new mutable.HashSet[A]()
+ }
+ }
+
test ("basic accumulation"){
sc = new SparkContext("local", "test")
val acc : Accumulator[Int] = sc.accumulator(0)
@@ -51,7 +66,6 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkConte
}
test ("add value to collection accumulators") {
- import SetAccum._
val maxI = 1000
for (nThreads <- List(1, 10)) { //test single & multi-threaded
sc = new SparkContext("local[" + nThreads + "]", "test")
@@ -68,22 +82,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkConte
}
}
- implicit object SetAccum extends AccumulableParam[mutable.Set[Any], Any] {
- def addInPlace(t1: mutable.Set[Any], t2: mutable.Set[Any]) : mutable.Set[Any] = {
- t1 ++= t2
- t1
- }
- def addAccumulator(t1: mutable.Set[Any], t2: Any) : mutable.Set[Any] = {
- t1 += t2
- t1
- }
- def zero(t: mutable.Set[Any]) : mutable.Set[Any] = {
- new mutable.HashSet[Any]()
- }
- }
-
test ("value not readable in tasks") {
- import SetAccum._
val maxI = 1000
for (nThreads <- List(1, 10)) { //test single & multi-threaded
sc = new SparkContext("local[" + nThreads + "]", "test")
@@ -125,7 +124,6 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkConte
}
test ("localValue readable in tasks") {
- import SetAccum._
val maxI = 1000
for (nThreads <- List(1, 10)) { //test single & multi-threaded
sc = new SparkContext("local[" + nThreads + "]", "test")
diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
index b3a53d928b..e022accee6 100644
--- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
@@ -20,8 +20,42 @@ package org.apache.spark
import org.scalatest.FunSuite
class BroadcastSuite extends FunSuite with LocalSparkContext {
-
- test("basic broadcast") {
+
+ override def afterEach() {
+ super.afterEach()
+ System.clearProperty("spark.broadcast.factory")
+ }
+
+ test("Using HttpBroadcast locally") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
+ sc = new SparkContext("local", "test")
+ val list = List(1, 2, 3, 4)
+ val listBroadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum))
+ assert(results.collect.toSet === Set((1, 10), (2, 10)))
+ }
+
+ test("Accessing HttpBroadcast variables from multiple threads") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
+ sc = new SparkContext("local[10]", "test")
+ val list = List(1, 2, 3, 4)
+ val listBroadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
+ assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
+ }
+
+ test("Accessing HttpBroadcast variables in a local cluster") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
+ val numSlaves = 4
+ sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
+ val list = List(1, 2, 3, 4)
+ val listBroadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
+ assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
+ }
+
+ test("Using TorrentBroadcast locally") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
sc = new SparkContext("local", "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
@@ -29,11 +63,23 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
assert(results.collect.toSet === Set((1, 10), (2, 10)))
}
- test("broadcast variables accessed in multiple threads") {
+ test("Accessing TorrentBroadcast variables from multiple threads") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
sc = new SparkContext("local[10]", "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
}
+
+ test("Accessing TorrentBroadcast variables in a local cluster") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
+ val numSlaves = 4
+ sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
+ val list = List(1, 2, 3, 4)
+ val listBroadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
+ assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
+ }
+
}
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index f26c44d3e7..f25d921d3f 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark
+import scala.reflect.ClassTag
import org.scalatest.FunSuite
import java.io.File
import org.apache.spark.rdd._
@@ -62,8 +63,6 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
testCheckpointing(_.sample(false, 0.5, 0))
testCheckpointing(_.glom())
testCheckpointing(_.mapPartitions(_.map(_.toString)))
- testCheckpointing(r => new MapPartitionsWithContextRDD(r,
- (context: TaskContext, iter: Iterator[Int]) => iter.map(_.toString), false ))
testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
testCheckpointing(_.pipe(Seq("cat")))
@@ -207,7 +206,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
* not, but this is not done by default as usually the partitions do not refer to any RDD and
* therefore never store the lineage.
*/
- def testCheckpointing[U: ClassManifest](
+ def testCheckpointing[U: ClassTag](
op: (RDD[Int]) => RDD[U],
testRDDSize: Boolean = true,
testRDDPartitionSize: Boolean = false
@@ -276,7 +275,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
* RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed,
* this RDD will remember the partitions and therefore potentially the whole lineage.
*/
- def testParentCheckpointing[U: ClassManifest](
+ def testParentCheckpointing[U: ClassTag](
op: (RDD[Int]) => RDD[U],
testRDDSize: Boolean,
testRDDPartitionSize: Boolean
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 480bac84f3..d9cb7fead5 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -122,7 +122,7 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
sc.parallelize(1 to 10, 10).foreach(x => println(x / 0))
}
assert(thrown.getClass === classOf[SparkException])
- assert(thrown.getMessage.contains("more than 4 times"))
+ assert(thrown.getMessage.contains("failed 4 times"))
}
test("caching") {
@@ -303,12 +303,13 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
Thread.sleep(200)
}
} catch {
- case _ => { Thread.sleep(10) }
+ case _: Throwable => { Thread.sleep(10) }
// Do nothing. We might see exceptions because block manager
// is racing this thread to remove entries from the driver.
}
}
}
+
}
object DistributedSuite {
diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala
index 01a72d8401..6d1695eae7 100644
--- a/core/src/test/scala/org/apache/spark/DriverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala
@@ -34,7 +34,7 @@ class DriverSuite extends FunSuite with Timeouts {
// Regression test for SPARK-530: "Spark driver process doesn't exit after finishing"
val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]"))
forAll(masters) { (master: String) =>
- failAfter(30 seconds) {
+ failAfter(60 seconds) {
Utils.execute(Seq("./spark-class", "org.apache.spark.DriverWithoutCleanup", master),
new File(System.getenv("SPARK_HOME")))
}
diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
index 35d1d41af1..c210dd5c3b 100644
--- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
@@ -120,4 +120,20 @@ class FileServerSuite extends FunSuite with LocalSparkContext {
}.collect()
assert(result.toSet === Set((1,2), (2,7), (3,121)))
}
+
+ test ("Dynamically adding JARS on a standalone cluster using local: URL") {
+ sc = new SparkContext("local-cluster[1,1,512]", "test")
+ val sampleJarFile = getClass.getClassLoader.getResource("uncommons-maths-1.2.2.jar").getFile()
+ sc.addJar(sampleJarFile.replace("file", "local"))
+ val testData = Array((1,1), (1,1), (2,1), (3,5), (2,3), (3,0))
+ val result = sc.parallelize(testData).reduceByKey { (x,y) =>
+ val fac = Thread.currentThread.getContextClassLoader()
+ .loadClass("org.uncommons.maths.Maths")
+ .getDeclaredMethod("factorial", classOf[Int])
+ val a = fac.invoke(null, x.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt
+ val b = fac.invoke(null, y.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt
+ a + b
+ }.collect()
+ assert(result.toSet === Set((1,2), (2,7), (3,121)))
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 7b0bb89ab2..79913dc718 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -365,6 +365,20 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void javaDoubleRDDHistoGram() {
+ JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0));
+ // Test using generated buckets
+ Tuple2<double[], long[]> results = rdd.histogram(2);
+ double[] expected_buckets = {1.0, 2.5, 4.0};
+ long[] expected_counts = {2, 2};
+ Assert.assertArrayEquals(expected_buckets, results._1, 0.1);
+ Assert.assertArrayEquals(expected_counts, results._2);
+ // Test with provided buckets
+ long[] histogram = rdd.histogram(expected_buckets);
+ Assert.assertArrayEquals(expected_counts, histogram);
+ }
+
+ @Test
public void map() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
JavaDoubleRDD doubles = rdd.map(new DoubleFunction<Integer>() {
@@ -473,6 +487,27 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void repartition() {
+ // Shrinking number of partitions
+ JavaRDD<Integer> in1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 2);
+ JavaRDD<Integer> repartitioned1 = in1.repartition(4);
+ List<List<Integer>> result1 = repartitioned1.glom().collect();
+ Assert.assertEquals(4, result1.size());
+ for (List<Integer> l: result1) {
+ Assert.assertTrue(l.size() > 0);
+ }
+
+ // Growing number of partitions
+ JavaRDD<Integer> in2 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 4);
+ JavaRDD<Integer> repartitioned2 = in2.repartition(2);
+ List<List<Integer>> result2 = repartitioned2.glom().collect();
+ Assert.assertEquals(2, result2.size());
+ for (List<Integer> l: result2) {
+ Assert.assertTrue(l.size() > 0);
+ }
+ }
+
+ @Test
public void persist() {
JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0));
doubleRDD = doubleRDD.persist(StorageLevel.DISK_ONLY());
@@ -862,4 +897,37 @@ public class JavaAPISuite implements Serializable {
new Tuple2<Integer, Integer>(0, 4)), rdd3.collect());
}
+
+ @Test
+ public void collectPartitions() {
+ JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3);
+
+ JavaPairRDD<Integer, Integer> rdd2 = rdd1.map(new PairFunction<Integer, Integer, Integer>() {
+ @Override
+ public Tuple2<Integer, Integer> call(Integer i) throws Exception {
+ return new Tuple2<Integer, Integer>(i, i % 2);
+ }
+ });
+
+ List[] parts = rdd1.collectPartitions(new int[] {0});
+ Assert.assertEquals(Arrays.asList(1, 2), parts[0]);
+
+ parts = rdd1.collectPartitions(new int[] {1, 2});
+ Assert.assertEquals(Arrays.asList(3, 4), parts[0]);
+ Assert.assertEquals(Arrays.asList(5, 6, 7), parts[1]);
+
+ Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(1, 1),
+ new Tuple2<Integer, Integer>(2, 0)),
+ rdd2.collectPartitions(new int[] {0})[0]);
+
+ parts = rdd2.collectPartitions(new int[] {1, 2});
+ Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(3, 1),
+ new Tuple2<Integer, Integer>(4, 0)),
+ parts[0]);
+ Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(5, 1),
+ new Tuple2<Integer, Integer>(6, 0),
+ new Tuple2<Integer, Integer>(7, 1)),
+ parts[1]);
+ }
+
}
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
index a192651491..1121e06e2e 100644
--- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark
import java.util.concurrent.Semaphore
+import scala.concurrent.Await
+import scala.concurrent.duration.Duration
import scala.concurrent.future
import scala.concurrent.ExecutionContext.Implicits.global
@@ -83,6 +85,36 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
assert(sc.parallelize(1 to 10, 2).count === 10)
}
+ test("job group") {
+ sc = new SparkContext("local[2]", "test")
+
+ // Add a listener to release the semaphore once any tasks are launched.
+ val sem = new Semaphore(0)
+ sc.dagScheduler.addSparkListener(new SparkListener {
+ override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ sem.release()
+ }
+ })
+
+ // jobA is the one to be cancelled.
+ val jobA = future {
+ sc.setJobGroup("jobA", "this is a job to be cancelled")
+ sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count()
+ }
+
+ sc.clearJobGroup()
+ val jobB = sc.parallelize(1 to 100, 2).countAsync()
+
+ // Block until both tasks of job A have started and cancel job A.
+ sem.acquire(2)
+ sc.cancelJobGroup("jobA")
+ val e = intercept[SparkException] { Await.result(jobA, Duration.Inf) }
+ assert(e.getMessage contains "cancel")
+
+ // Once A is cancelled, job B should finish fairly quickly.
+ assert(jobB.get() === 100)
+ }
+/*
test("two jobs sharing the same stage") {
// sem1: make sure cancel is issued after some tasks are launched
// sem2: make sure the first stage is not finished until cancel is issued
@@ -116,7 +148,7 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
intercept[SparkException] { f1.get() }
intercept[SparkException] { f2.get() }
}
-
+ */
def testCount() {
// Cancel before launching any tasks
{
diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
index 459e257d79..8dd5786da6 100644
--- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
+++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
@@ -30,7 +30,7 @@ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self
@transient var sc: SparkContext = _
override def beforeAll() {
- InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory());
+ InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory())
super.beforeAll()
}
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 6013320eaa..271dc905bc 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -48,15 +48,15 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("master start and stop") {
val actorSystem = ActorSystem("test")
- val tracker = new MapOutputTracker()
- tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
+ val tracker = new MapOutputTrackerMaster()
+ tracker.trackerActor = Left(actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker))))
tracker.stop()
}
test("master register and fetch") {
val actorSystem = ActorSystem("test")
- val tracker = new MapOutputTracker()
- tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
+ val tracker = new MapOutputTrackerMaster()
+ tracker.trackerActor = Left(actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker))))
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@@ -74,19 +74,17 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("master register and unregister and fetch") {
val actorSystem = ActorSystem("test")
- val tracker = new MapOutputTracker()
- tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
+ val tracker = new MapOutputTrackerMaster()
+ tracker.trackerActor = Left(actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker))))
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
- val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
- val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0),
Array(compressedSize1000, compressedSize1000, compressedSize1000)))
tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0),
Array(compressedSize10000, compressedSize1000, compressedSize1000)))
- // As if we had two simulatenous fetch failures
+ // As if we had two simultaneous fetch failures
tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
@@ -102,14 +100,14 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
System.setProperty("spark.hostPort", hostname + ":" + boundPort)
- val masterTracker = new MapOutputTracker()
- masterTracker.trackerActor = actorSystem.actorOf(
- Props(new MapOutputTrackerActor(masterTracker)), "MapOutputTracker")
+ val masterTracker = new MapOutputTrackerMaster()
+ masterTracker.trackerActor = Left(actorSystem.actorOf(
+ Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker"))
val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0)
val slaveTracker = new MapOutputTracker()
- slaveTracker.trackerActor = slaveSystem.actorFor(
- "akka://spark@localhost:" + boundPort + "/user/MapOutputTracker")
+ slaveTracker.trackerActor = Right(slaveSystem.actorSelection(
+ "akka.tcp://spark@localhost:" + boundPort + "/user/MapOutputTracker"))
masterTracker.registerShuffle(10, 1)
masterTracker.incrementEpoch()
diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
index 7d938917f2..1374d01774 100644
--- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
@@ -142,11 +142,11 @@ class PartitioningSuite extends FunSuite with SharedSparkContext {
.filter(_ >= 0.0)
// Run the partitions, including the consecutive empty ones, through StatCounter
- val stats: StatCounter = rdd.stats();
- assert(abs(6.0 - stats.sum) < 0.01);
- assert(abs(6.0/2 - rdd.mean) < 0.01);
- assert(abs(1.0 - rdd.variance) < 0.01);
- assert(abs(1.0 - rdd.stdev) < 0.01);
+ val stats: StatCounter = rdd.stats()
+ assert(abs(6.0 - stats.sum) < 0.01)
+ assert(abs(6.0/2 - rdd.mean) < 0.01)
+ assert(abs(1.0 - rdd.variance) < 0.01)
+ assert(abs(1.0 - rdd.stdev) < 0.01)
// Add other tests here for classes that should be able to handle empty partitions correctly
}
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
new file mode 100644
index 0000000000..151af0d213
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.{FunSuite, PrivateMethodTester}
+
+import org.apache.spark.scheduler.TaskScheduler
+import org.apache.spark.scheduler.cluster.{ClusterScheduler, SimrSchedulerBackend, SparkDeploySchedulerBackend}
+import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
+import org.apache.spark.scheduler.local.LocalScheduler
+
+class SparkContextSchedulerCreationSuite
+ extends FunSuite with PrivateMethodTester with LocalSparkContext with Logging {
+
+ def createTaskScheduler(master: String): TaskScheduler = {
+ // Create local SparkContext to setup a SparkEnv. We don't actually want to start() the
+ // real schedulers, so we don't want to create a full SparkContext with the desired scheduler.
+ sc = new SparkContext("local", "test")
+ val createTaskSchedulerMethod = PrivateMethod[TaskScheduler]('createTaskScheduler)
+ SparkContext invokePrivate createTaskSchedulerMethod(sc, master, "test")
+ }
+
+ test("bad-master") {
+ val e = intercept[SparkException] {
+ createTaskScheduler("localhost:1234")
+ }
+ assert(e.getMessage.contains("Could not parse Master URL"))
+ }
+
+ test("local") {
+ createTaskScheduler("local") match {
+ case s: LocalScheduler =>
+ assert(s.threads === 1)
+ assert(s.maxFailures === 0)
+ case _ => fail()
+ }
+ }
+
+ test("local-n") {
+ createTaskScheduler("local[5]") match {
+ case s: LocalScheduler =>
+ assert(s.threads === 5)
+ assert(s.maxFailures === 0)
+ case _ => fail()
+ }
+ }
+
+ test("local-n-failures") {
+ createTaskScheduler("local[4, 2]") match {
+ case s: LocalScheduler =>
+ assert(s.threads === 4)
+ assert(s.maxFailures === 2)
+ case _ => fail()
+ }
+ }
+
+ test("simr") {
+ createTaskScheduler("simr://uri") match {
+ case s: ClusterScheduler =>
+ assert(s.backend.isInstanceOf[SimrSchedulerBackend])
+ case _ => fail()
+ }
+ }
+
+ test("local-cluster") {
+ createTaskScheduler("local-cluster[3, 14, 512]") match {
+ case s: ClusterScheduler =>
+ assert(s.backend.isInstanceOf[SparkDeploySchedulerBackend])
+ case _ => fail()
+ }
+ }
+
+ def testYarn(master: String, expectedClassName: String) {
+ try {
+ createTaskScheduler(master) match {
+ case s: ClusterScheduler =>
+ assert(s.getClass === Class.forName(expectedClassName))
+ case _ => fail()
+ }
+ } catch {
+ case e: SparkException =>
+ assert(e.getMessage.contains("YARN mode not available"))
+ logWarning("YARN not available, could not test actual YARN scheduler creation")
+ case e: Throwable => fail(e)
+ }
+ }
+
+ test("yarn-standalone") {
+ testYarn("yarn-standalone", "org.apache.spark.scheduler.cluster.YarnClusterScheduler")
+ }
+
+ test("yarn-client") {
+ testYarn("yarn-client", "org.apache.spark.scheduler.cluster.YarnClientClusterScheduler")
+ }
+
+ def testMesos(master: String, expectedClass: Class[_]) {
+ try {
+ createTaskScheduler(master) match {
+ case s: ClusterScheduler =>
+ assert(s.backend.getClass === expectedClass)
+ case _ => fail()
+ }
+ } catch {
+ case e: UnsatisfiedLinkError =>
+ assert(e.getMessage.contains("no mesos in"))
+ logWarning("Mesos not available, could not test actual Mesos scheduler creation")
+ case e: Throwable => fail(e)
+ }
+ }
+
+ test("mesos fine-grained") {
+ System.setProperty("spark.mesos.coarse", "false")
+ testMesos("mesos://localhost:1234", classOf[MesosSchedulerBackend])
+ }
+
+ test("mesos coarse-grained") {
+ System.setProperty("spark.mesos.coarse", "true")
+ testMesos("mesos://localhost:1234", classOf[CoarseMesosSchedulerBackend])
+ }
+
+ test("mesos with zookeeper") {
+ System.setProperty("spark.mesos.coarse", "false")
+ testMesos("zk://localhost:1234,localhost:2345", classOf[MesosSchedulerBackend])
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala
index 46a2da1724..768ca3850e 100644
--- a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala
+++ b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala
@@ -37,7 +37,7 @@ class UnpersistSuite extends FunSuite with LocalSparkContext {
Thread.sleep(200)
}
} catch {
- case _ => { Thread.sleep(10) }
+ case _: Throwable => { Thread.sleep(10) }
// Do nothing. We might see exceptions because block manager
// is racing this thread to remove entries from the driver.
}
diff --git a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
index 21f16ef2c6..4cb4ddc9cd 100644
--- a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
@@ -15,31 +15,22 @@
* limitations under the License.
*/
-package org.apache.spark
+package org.apache.spark.deploy.worker
+import java.io.File
import org.scalatest.FunSuite
-import org.apache.spark.SparkContext._
-import org.apache.spark.rdd.{RDD, PartitionPruningRDD}
+import org.apache.spark.deploy.{ExecutorState, Command, ApplicationDescription}
+class ExecutorRunnerTest extends FunSuite {
+ test("command includes appId") {
+ def f(s:String) = new File(s)
+ val sparkHome = sys.env("SPARK_HOME")
+ val appDesc = new ApplicationDescription("app name", 8, 500, Command("foo", Seq(),Map()),
+ sparkHome, "appUiUrl")
+ val appId = "12345-worker321-9876"
+ val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", f(sparkHome),
+ f("ooga"), ExecutorState.RUNNING)
-class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext {
-
- test("Pruned Partitions inherit locality prefs correctly") {
- class TestPartition(i: Int) extends Partition {
- def index = i
- }
- val rdd = new RDD[Int](sc, Nil) {
- override protected def getPartitions = {
- Array[Partition](
- new TestPartition(1),
- new TestPartition(2),
- new TestPartition(3))
- }
- def compute(split: Partition, context: TaskContext) = {Iterator()}
- }
- val prunedRDD = PartitionPruningRDD.create(rdd, {x => if (x==2) true else false})
- val p = prunedRDD.partitions(0)
- assert(p.index == 2)
- assert(prunedRDD.partitions.length == 1)
+ assert(er.buildCommandSeq().last === appId)
}
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
index da032b17d9..0d4c10db8e 100644
--- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.rdd
import java.util.concurrent.Semaphore
+import scala.concurrent.{Await, TimeoutException}
+import scala.concurrent.duration.Duration
import scala.concurrent.ExecutionContext.Implicits.global
import org.scalatest.{BeforeAndAfterAll, FunSuite}
@@ -173,4 +175,28 @@ class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll with Timeouts
sem.acquire(2)
}
}
+
+ /**
+ * Awaiting FutureAction results
+ */
+ test("FutureAction result, infinite wait") {
+ val f = sc.parallelize(1 to 100, 4)
+ .countAsync()
+ assert(Await.result(f, Duration.Inf) === 100)
+ }
+
+ test("FutureAction result, finite wait") {
+ val f = sc.parallelize(1 to 100, 4)
+ .countAsync()
+ assert(Await.result(f, Duration(30, "seconds")) === 100)
+ }
+
+ test("FutureAction result, timeout") {
+ val f = sc.parallelize(1 to 100, 4)
+ .mapPartitions(itr => { Thread.sleep(20); itr })
+ .countAsync()
+ intercept[TimeoutException] {
+ Await.result(f, Duration(20, "milliseconds"))
+ }
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala
new file mode 100644
index 0000000000..7f50a5a47c
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala
@@ -0,0 +1,271 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.math.abs
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.rdd._
+import org.apache.spark._
+
+class DoubleRDDSuite extends FunSuite with SharedSparkContext {
+ // Verify tests on the histogram functionality. We test with both evenly
+ // and non-evenly spaced buckets as the bucket lookup function changes.
+ test("WorksOnEmpty") {
+ // Make sure that it works on an empty input
+ val rdd: RDD[Double] = sc.parallelize(Seq())
+ val buckets = Array(0.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksWithOutOfRangeWithOneBucket") {
+ // Verify that if all of the elements are out of range the counts are zero
+ val rdd = sc.parallelize(Seq(10.01, -0.01))
+ val buckets = Array(0.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithOneBucket") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+ val buckets = Array(0.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(4)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithOneBucketExactMatch") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+ val buckets = Array(1.0, 4.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(4)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksWithOutOfRangeWithTwoBuckets") {
+ // Verify that out of range works with two buckets
+ val rdd = sc.parallelize(Seq(10.01, -0.01))
+ val buckets = Array(0.0, 5.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(0, 0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksWithOutOfRangeWithTwoUnEvenBuckets") {
+ // Verify that out of range works with two un even buckets
+ val rdd = sc.parallelize(Seq(10.01, -0.01))
+ val buckets = Array(0.0, 4.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(0, 0)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithTwoBuckets") {
+ // Make sure that it works with two equally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6))
+ val buckets = Array(0.0, 5.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(3, 2)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithTwoBucketsAndNaN") {
+ // Make sure that it works with two equally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6, Double.NaN))
+ val buckets = Array(0.0, 5.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(3, 2)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithTwoUnevenBuckets") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6))
+ val buckets = Array(0.0, 5.0, 11.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(3, 2)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksMixedRangeWithTwoUnevenBuckets") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01))
+ val buckets = Array(0.0, 5.0, 11.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 3)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksMixedRangeWithFourUnevenBuckets") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+ 200.0, 200.1))
+ val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 2, 1, 3)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksMixedRangeWithUnevenBucketsAndNaN") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+ 200.0, 200.1, Double.NaN))
+ val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 2, 1, 3)
+ assert(histogramResults === expectedHistogramResults)
+ }
+ // Make sure this works with a NaN end bucket
+ test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRange") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+ 200.0, 200.1, Double.NaN))
+ val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 2, 1, 2, 3)
+ assert(histogramResults === expectedHistogramResults)
+ }
+ // Make sure this works with a NaN end bucket and an inifity
+ test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRangeAndInfity") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+ 200.0, 200.1, 1.0/0.0, -1.0/0.0, Double.NaN))
+ val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 2, 1, 2, 4)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksWithOutOfRangeWithInfiniteBuckets") {
+ // Verify that out of range works with two buckets
+ val rdd = sc.parallelize(Seq(10.01, -0.01, Double.NaN))
+ val buckets = Array(-1.0/0.0 , 0.0, 1.0/0.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(1, 1)
+ assert(histogramResults === expectedHistogramResults)
+ }
+ // Test the failure mode with an invalid bucket array
+ test("ThrowsExceptionOnInvalidBucketArray") {
+ val rdd = sc.parallelize(Seq(1.0))
+ // Empty array
+ intercept[IllegalArgumentException] {
+ val buckets = Array.empty[Double]
+ val result = rdd.histogram(buckets)
+ }
+ // Single element array
+ intercept[IllegalArgumentException] {
+ val buckets = Array(1.0)
+ val result = rdd.histogram(buckets)
+ }
+ }
+
+ // Test automatic histogram function
+ test("WorksWithoutBucketsBasic") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+ val (histogramBuckets, histogramResults) = rdd.histogram(1)
+ val expectedHistogramResults = Array(4)
+ val expectedHistogramBuckets = Array(1.0, 4.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+ // Test automatic histogram function with a single element
+ test("WorksWithoutBucketsBasicSingleElement") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1))
+ val (histogramBuckets, histogramResults) = rdd.histogram(1)
+ val expectedHistogramResults = Array(1)
+ val expectedHistogramBuckets = Array(1.0, 1.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+ // Test automatic histogram function with a single element
+ test("WorksWithoutBucketsBasicNoRange") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 1, 1, 1))
+ val (histogramBuckets, histogramResults) = rdd.histogram(1)
+ val expectedHistogramResults = Array(4)
+ val expectedHistogramBuckets = Array(1.0, 1.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+
+ test("WorksWithoutBucketsBasicTwo") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+ val (histogramBuckets, histogramResults) = rdd.histogram(2)
+ val expectedHistogramResults = Array(2, 2)
+ val expectedHistogramBuckets = Array(1.0, 2.5, 4.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+
+ test("WorksWithoutBucketsWithMoreRequestedThanElements") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2))
+ val (histogramBuckets, histogramResults) = rdd.histogram(10)
+ val expectedHistogramResults =
+ Array(1, 0, 0, 0, 0, 0, 0, 0, 0, 1)
+ val expectedHistogramBuckets =
+ Array(1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+
+ // Test the failure mode with an invalid RDD
+ test("ThrowsExceptionOnInvalidRDDs") {
+ // infinity
+ intercept[UnsupportedOperationException] {
+ val rdd = sc.parallelize(Seq(1, 1.0/0.0))
+ val result = rdd.histogram(1)
+ }
+ // NaN
+ intercept[UnsupportedOperationException] {
+ val rdd = sc.parallelize(Seq(1, Double.NaN))
+ val result = rdd.histogram(1)
+ }
+ // Empty
+ intercept[UnsupportedOperationException] {
+ val rdd: RDD[Double] = sc.parallelize(Seq())
+ val result = rdd.histogram(1)
+ }
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala
new file mode 100644
index 0000000000..53a7b7c44d
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.scalatest.FunSuite
+import org.apache.spark.{TaskContext, Partition, SharedSparkContext}
+
+
+class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext {
+
+
+ test("Pruned Partitions inherit locality prefs correctly") {
+
+ val rdd = new RDD[Int](sc, Nil) {
+ override protected def getPartitions = {
+ Array[Partition](
+ new TestPartition(0, 1),
+ new TestPartition(1, 1),
+ new TestPartition(2, 1))
+ }
+
+ def compute(split: Partition, context: TaskContext) = {
+ Iterator()
+ }
+ }
+ val prunedRDD = PartitionPruningRDD.create(rdd, {
+ x => if (x == 2) true else false
+ })
+ assert(prunedRDD.partitions.length == 1)
+ val p = prunedRDD.partitions(0)
+ assert(p.index == 0)
+ assert(p.asInstanceOf[PartitionPruningRDDPartition].parentSplit.index == 2)
+ }
+
+
+ test("Pruned Partitions can be unioned ") {
+
+ val rdd = new RDD[Int](sc, Nil) {
+ override protected def getPartitions = {
+ Array[Partition](
+ new TestPartition(0, 4),
+ new TestPartition(1, 5),
+ new TestPartition(2, 6))
+ }
+
+ def compute(split: Partition, context: TaskContext) = {
+ List(split.asInstanceOf[TestPartition].testValue).iterator
+ }
+ }
+ val prunedRDD1 = PartitionPruningRDD.create(rdd, {
+ x => if (x == 0) true else false
+ })
+
+ val prunedRDD2 = PartitionPruningRDD.create(rdd, {
+ x => if (x == 2) true else false
+ })
+
+ val merged = prunedRDD1 ++ prunedRDD2
+ assert(merged.count() == 2)
+ val take = merged.take(2)
+ assert(take.apply(0) == 4)
+ assert(take.apply(1) == 6)
+ }
+}
+
+class TestPartition(i: Int, value: Int) extends Partition with Serializable {
+ def index = i
+
+ def testValue = this.value
+
+}
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 413ea85322..2f81b81797 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -152,6 +152,26 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(rdd.union(emptyKv).collect().size === 2)
}
+ test("repartitioned RDDs") {
+ val data = sc.parallelize(1 to 1000, 10)
+
+ // Coalesce partitions
+ val repartitioned1 = data.repartition(2)
+ assert(repartitioned1.partitions.size == 2)
+ val partitions1 = repartitioned1.glom().collect()
+ assert(partitions1(0).length > 0)
+ assert(partitions1(1).length > 0)
+ assert(repartitioned1.collect().toSet === (1 to 1000).toSet)
+
+ // Split partitions
+ val repartitioned2 = data.repartition(20)
+ assert(repartitioned2.partitions.size == 20)
+ val partitions2 = repartitioned2.glom().collect()
+ assert(partitions2(0).length > 0)
+ assert(partitions2(19).length > 0)
+ assert(repartitioned2.collect().toSet === (1 to 1000).toSet)
+ }
+
test("coalesced RDDs") {
val data = sc.parallelize(1 to 10, 10)
@@ -237,8 +257,8 @@ class RDDSuite extends FunSuite with SharedSparkContext {
// test that you get over 90% locality in each group
val minLocality = coalesced2.partitions
.map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction)
- .foldLeft(1.)((perc, loc) => math.min(perc,loc))
- assert(minLocality >= 0.90, "Expected 90% locality but got " + (minLocality*100.).toInt + "%")
+ .foldLeft(1.0)((perc, loc) => math.min(perc,loc))
+ assert(minLocality >= 0.90, "Expected 90% locality but got " + (minLocality*100.0).toInt + "%")
// test that the groups are load balanced with 100 +/- 20 elements in each
val maxImbalance = coalesced2.partitions
@@ -250,9 +270,9 @@ class RDDSuite extends FunSuite with SharedSparkContext {
val coalesced3 = data3.coalesce(numMachines*2)
val minLocality2 = coalesced3.partitions
.map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction)
- .foldLeft(1.)((perc, loc) => math.min(perc,loc))
+ .foldLeft(1.0)((perc, loc) => math.min(perc,loc))
assert(minLocality2 >= 0.90, "Expected 90% locality for derived RDD but got " +
- (minLocality2*100.).toInt + "%")
+ (minLocality2*100.0).toInt + "%")
}
test("zipped RDDs") {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 2a2f828be6..706d84a58b 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -23,7 +23,7 @@ import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import org.apache.spark.LocalSparkContext
-import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTrackerMaster
import org.apache.spark.SparkContext
import org.apache.spark.Partition
import org.apache.spark.TaskContext
@@ -64,7 +64,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
override def defaultParallelism() = 2
}
- var mapOutputTracker: MapOutputTracker = null
+ var mapOutputTracker: MapOutputTrackerMaster = null
var scheduler: DAGScheduler = null
/**
@@ -99,8 +99,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
taskSets.clear()
cacheLocations.clear()
results.clear()
- mapOutputTracker = new MapOutputTracker()
- scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) {
+ mapOutputTracker = new MapOutputTrackerMaster()
+ scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, sc.env) {
override def runLocally(job: ActiveJob) {
// don't bother with the thread while unit testing
runLocallyWithinThread(job)
@@ -206,6 +206,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
submit(rdd, Array(0))
complete(taskSets(0), List((Success, 42)))
assert(results === Map(0 -> 42))
+ assertDataStructuresEmpty
}
test("local job") {
@@ -219,6 +220,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
val jobId = scheduler.nextJobId.getAndIncrement()
runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, null, listener))
assert(results === Map(0 -> 42))
+ assertDataStructuresEmpty
}
test("run trivial job w/ dependency") {
@@ -227,6 +229,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
submit(finalRdd, Array(0))
complete(taskSets(0), Seq((Success, 42)))
assert(results === Map(0 -> 42))
+ assertDataStructuresEmpty
}
test("cache location preferences w/ dependency") {
@@ -239,12 +242,14 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
assertLocations(taskSet, Seq(Seq("hostA", "hostB")))
complete(taskSet, Seq((Success, 42)))
assert(results === Map(0 -> 42))
+ assertDataStructuresEmpty
}
test("trivial job failure") {
submit(makeRdd(1, Nil), Array(0))
failed(taskSets(0), "some failure")
assert(failure.getMessage === "Job aborted: some failure")
+ assertDataStructuresEmpty
}
test("run trivial shuffle") {
@@ -260,6 +265,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
complete(taskSets(1), Seq((Success, 42)))
assert(results === Map(0 -> 42))
+ assertDataStructuresEmpty
}
test("run trivial shuffle with fetch failure") {
@@ -285,6 +291,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB"))
complete(taskSets(3), Seq((Success, 43)))
assert(results === Map(0 -> 42, 1 -> 43))
+ assertDataStructuresEmpty
}
test("ignore late map task completions") {
@@ -313,6 +320,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
complete(taskSets(1), Seq((Success, 42), (Success, 43)))
assert(results === Map(0 -> 42, 1 -> 43))
+ assertDataStructuresEmpty
}
test("run trivial shuffle with out-of-band failure and retry") {
@@ -329,15 +337,16 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
complete(taskSets(0), Seq(
(Success, makeMapStatus("hostA", 1)),
(Success, makeMapStatus("hostB", 1))))
- // have hostC complete the resubmitted task
- complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1))))
- assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
- Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
- complete(taskSets(2), Seq((Success, 42)))
- assert(results === Map(0 -> 42))
- }
-
- test("recursive shuffle failures") {
+ // have hostC complete the resubmitted task
+ complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1))))
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
+ complete(taskSets(2), Seq((Success, 42)))
+ assert(results === Map(0 -> 42))
+ assertDataStructuresEmpty
+ }
+
+ test("recursive shuffle failures") {
val shuffleOneRdd = makeRdd(2, Nil)
val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
@@ -363,6 +372,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
complete(taskSets(4), Seq((Success, makeMapStatus("hostA", 1))))
complete(taskSets(5), Seq((Success, 42)))
assert(results === Map(0 -> 42))
+ assertDataStructuresEmpty
}
test("cached post-shuffle") {
@@ -394,6 +404,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
complete(taskSets(3), Seq((Success, makeMapStatus("hostD", 1))))
complete(taskSets(4), Seq((Success, 42)))
assert(results === Map(0 -> 42))
+ assertDataStructuresEmpty
}
/**
@@ -413,4 +424,18 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
private def makeBlockManagerId(host: String): BlockManagerId =
BlockManagerId("exec-" + host, host, 12345, 0)
+ private def assertDataStructuresEmpty = {
+ assert(scheduler.pendingTasks.isEmpty)
+ assert(scheduler.activeJobs.isEmpty)
+ assert(scheduler.failed.isEmpty)
+ assert(scheduler.idToActiveJob.isEmpty)
+ assert(scheduler.jobIdToStageIds.isEmpty)
+ assert(scheduler.stageIdToJobIds.isEmpty)
+ assert(scheduler.stageIdToStage.isEmpty)
+ assert(scheduler.stageToInfos.isEmpty)
+ assert(scheduler.resultStageToJob.isEmpty)
+ assert(scheduler.running.isEmpty)
+ assert(scheduler.shuffleToMapStage.isEmpty)
+ assert(scheduler.waiting.isEmpty)
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
index cece60dda7..002368ff55 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.rdd.RDD
class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+ val WAIT_TIMEOUT_MILLIS = 10000
test("inner method") {
sc = new SparkContext("local", "joblogger")
@@ -58,11 +59,14 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
val parentRdd = makeRdd(4, Nil)
val shuffleDep = new ShuffleDependency(parentRdd, null)
val rootRdd = makeRdd(4, List(shuffleDep))
- val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID, None)
- val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID, None)
-
- joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStage, 4, null))
- joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName)
+ val shuffleMapStage =
+ new Stage(1, parentRdd, parentRdd.partitions.size, Some(shuffleDep), Nil, jobID, None)
+ val rootStage =
+ new Stage(0, rootRdd, rootRdd.partitions.size, None, List(shuffleMapStage), jobID, None)
+ val rootStageInfo = new StageInfo(rootStage)
+
+ joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStageInfo, null))
+ joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getSimpleName)
parentRdd.setName("MyRDD")
joblogger.getRddNameTest(parentRdd) should be ("MyRDD")
joblogger.createLogWriterTest(jobID)
@@ -88,8 +92,12 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
sc.addSparkListener(joblogger)
val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
rdd.reduceByKey(_+_).collect()
+
+ assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+
+ val user = System.getProperty("user.name", SparkContext.SPARK_UNKNOWN_USER)
- joblogger.getLogDir should be ("/tmp/spark")
+ joblogger.getLogDir should be ("/tmp/spark-%s".format(user))
joblogger.getJobIDtoPrintWriter.size should be (1)
joblogger.getStageIDToJobID.size should be (2)
joblogger.getStageIDToJobID.get(0) should be (Some(0))
@@ -115,7 +123,9 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
sc.addSparkListener(joblogger)
val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
rdd.reduceByKey(_+_).collect()
-
+
+ assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+
joblogger.onJobStartCount should be (1)
joblogger.onJobEndCount should be (1)
joblogger.onTaskEndCount should be (8)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index a549417a47..2e41438a52 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -17,16 +17,62 @@
package org.apache.spark.scheduler
-import org.scalatest.FunSuite
-import org.apache.spark.{SparkContext, LocalSparkContext}
-import scala.collection.mutable
+import scala.collection.mutable.{Buffer, HashSet}
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.scalatest.matchers.ShouldMatchers
+
+import org.apache.spark.{LocalSparkContext, SparkContext}
import org.apache.spark.SparkContext._
-class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
+ with BeforeAndAfterAll {
+ /** Length of time to wait while draining listener events. */
+ val WAIT_TIMEOUT_MILLIS = 10000
+
+ override def afterAll {
+ System.clearProperty("spark.akka.frameSize")
+ }
+
+ test("basic creation of StageInfo") {
+ sc = new SparkContext("local", "DAGSchedulerSuite")
+ val listener = new SaveStageInfo
+ sc.addSparkListener(listener)
+ val rdd1 = sc.parallelize(1 to 100, 4)
+ val rdd2 = rdd1.map(x => x.toString)
+ rdd2.setName("Target RDD")
+ rdd2.count
+
+ assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+
+ listener.stageInfos.size should be {1}
+ val first = listener.stageInfos.head
+ first.rddName should be {"Target RDD"}
+ first.numTasks should be {4}
+ first.numPartitions should be {4}
+ first.submissionTime should be ('defined)
+ first.completionTime should be ('defined)
+ first.taskInfos.length should be {4}
+ }
+
+ test("StageInfo with fewer tasks than partitions") {
+ sc = new SparkContext("local", "DAGSchedulerSuite")
+ val listener = new SaveStageInfo
+ sc.addSparkListener(listener)
+ val rdd1 = sc.parallelize(1 to 100, 4)
+ val rdd2 = rdd1.map(x => x.toString)
+ sc.runJob(rdd2, (items: Iterator[String]) => items.size, Seq(0, 1), true)
+
+ assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+
+ listener.stageInfos.size should be {1}
+ val first = listener.stageInfos.head
+ first.numTasks should be {2}
+ first.numPartitions should be {4}
+ }
test("local metrics") {
- sc = new SparkContext("local[4]", "test")
+ sc = new SparkContext("local", "DAGSchedulerSuite")
val listener = new SaveStageInfo
sc.addSparkListener(listener)
sc.addSparkListener(new StatsReportListener)
@@ -37,9 +83,8 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
i
}
- val d = sc.parallelize(1 to 1e4.toInt, 64).map{i => w(i)}
+ val d = sc.parallelize(0 to 1e4.toInt, 64).map{i => w(i)}
d.count()
- val WAIT_TIMEOUT_MILLIS = 10000
assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
listener.stageInfos.size should be (1)
@@ -64,7 +109,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
checkNonZeroAvg(
stageInfo.taskInfos.map{_._2.executorDeserializeTime.toLong},
stageInfo + " executorDeserializeTime")
- if (stageInfo.stage.rdd.name == d4.name) {
+ if (stageInfo.rddName == d4.name) {
checkNonZeroAvg(
stageInfo.taskInfos.map{_._2.shuffleReadMetrics.get.fetchWaitTime},
stageInfo + " fetchWaitTime")
@@ -72,11 +117,11 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
stageInfo.taskInfos.foreach { case (taskInfo, taskMetrics) =>
taskMetrics.resultSize should be > (0l)
- if (isStage(stageInfo, Set(d2.name, d3.name), Set(d4.name))) {
+ if (stageInfo.rddName == d2.name || stageInfo.rddName == d3.name) {
taskMetrics.shuffleWriteMetrics should be ('defined)
taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0l)
}
- if (stageInfo.stage.rdd.name == d4.name) {
+ if (stageInfo.rddName == d4.name) {
taskMetrics.shuffleReadMetrics should be ('defined)
val sm = taskMetrics.shuffleReadMetrics.get
sm.totalBlocksFetched should be > (0)
@@ -89,20 +134,73 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
}
}
- def checkNonZeroAvg(m: Traversable[Long], msg: String) {
- assert(m.sum / m.size.toDouble > 0.0, msg)
+ test("onTaskGettingResult() called when result fetched remotely") {
+ // Need to use local cluster mode here, because results are not ever returned through the
+ // block manager when using the LocalScheduler.
+ sc = new SparkContext("local-cluster[1,1,512]", "test")
+
+ val listener = new SaveTaskEvents
+ sc.addSparkListener(listener)
+
+ // Make a task whose result is larger than the akka frame size
+ System.setProperty("spark.akka.frameSize", "1")
+ val akkaFrameSize =
+ sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt
+ val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x,y) => x)
+ assert(result === 1.to(akkaFrameSize).toArray)
+
+ assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+ val TASK_INDEX = 0
+ assert(listener.startedTasks.contains(TASK_INDEX))
+ assert(listener.startedGettingResultTasks.contains(TASK_INDEX))
+ assert(listener.endedTasks.contains(TASK_INDEX))
}
- def isStage(stageInfo: StageInfo, rddNames: Set[String], excludedNames: Set[String]) = {
- val names = Set(stageInfo.stage.rdd.name) ++ stageInfo.stage.rdd.dependencies.map{_.rdd.name}
- !names.intersect(rddNames).isEmpty && names.intersect(excludedNames).isEmpty
+ test("onTaskGettingResult() not called when result sent directly") {
+ // Need to use local cluster mode here, because results are not ever returned through the
+ // block manager when using the LocalScheduler.
+ sc = new SparkContext("local-cluster[1,1,512]", "test")
+
+ val listener = new SaveTaskEvents
+ sc.addSparkListener(listener)
+
+ // Make a task whose result is larger than the akka frame size
+ val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x)
+ assert(result === 2)
+
+ assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+ val TASK_INDEX = 0
+ assert(listener.startedTasks.contains(TASK_INDEX))
+ assert(listener.startedGettingResultTasks.isEmpty == true)
+ assert(listener.endedTasks.contains(TASK_INDEX))
+ }
+
+ def checkNonZeroAvg(m: Traversable[Long], msg: String) {
+ assert(m.sum / m.size.toDouble > 0.0, msg)
}
class SaveStageInfo extends SparkListener {
- val stageInfos = mutable.Buffer[StageInfo]()
+ val stageInfos = Buffer[StageInfo]()
override def onStageCompleted(stage: StageCompleted) {
- stageInfos += stage.stageInfo
+ stageInfos += stage.stage
}
}
+ class SaveTaskEvents extends SparkListener {
+ val startedTasks = new HashSet[Int]()
+ val startedGettingResultTasks = new HashSet[Int]()
+ val endedTasks = new HashSet[Int]()
+
+ override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ startedTasks += taskStart.taskInfo.index
+ }
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ endedTasks += taskEnd.taskInfo.index
+ }
+
+ override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) {
+ startedGettingResultTasks += taskGettingResult.taskInfo.index
+ }
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
index b97f2b19b5..bb28a31a99 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
@@ -283,7 +283,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
// Fail the task MAX_TASK_FAILURES times, and check that the task set is aborted
// after the last failure.
- (0 until manager.MAX_TASK_FAILURES).foreach { index =>
+ (1 to manager.MAX_TASK_FAILURES).foreach { index =>
val offerResult = manager.resourceOffer("exec1", "host1", 1, ANY)
assert(offerResult != None,
"Expect resource offer on iteration %s to return a task".format(index))
@@ -313,6 +313,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
}
def createTaskResult(id: Int): DirectTaskResult[Int] = {
- new DirectTaskResult[Int](id, mutable.Map.empty, new TaskMetrics)
+ val valueSer = SparkEnv.get.serializer.newInstance()
+ new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics)
}
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala
index ee150a3107..27c2d53361 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala
@@ -82,7 +82,7 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA
test("handling results larger than Akka frame size") {
val akkaFrameSize =
- sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt
+ sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt
val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x)
assert(result === 1.to(akkaFrameSize).toArray)
@@ -103,7 +103,7 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA
}
scheduler.taskResultGetter = new ResultDeletingTaskResultGetter(sc.env, scheduler)
val akkaFrameSize =
- sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt
+ sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt
val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x)
assert(result === 1.to(akkaFrameSize).toArray)
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
index cb76275e39..b647e8a672 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
@@ -39,7 +39,7 @@ class BlockIdSuite extends FunSuite {
fail()
} catch {
case e: IllegalStateException => // OK
- case _ => fail()
+ case _: Throwable => fail()
}
}
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 484a654108..5b4d63b954 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -56,7 +56,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
System.setProperty("spark.hostPort", "localhost:" + boundPort)
master = new BlockManagerMaster(
- actorSystem.actorOf(Props(new BlockManagerMasterActor(true))))
+ Left(actorSystem.actorOf(Props(new BlockManagerMasterActor(true)))))
// Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
oldArch = System.setProperty("os.arch", "amd64")
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
new file mode 100644
index 0000000000..070982e798
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.{FileWriter, File}
+
+import scala.collection.mutable
+
+import com.google.common.io.Files
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
+
+class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
+
+ val rootDir0 = Files.createTempDir()
+ rootDir0.deleteOnExit()
+ val rootDir1 = Files.createTempDir()
+ rootDir1.deleteOnExit()
+ val rootDirs = rootDir0.getName + "," + rootDir1.getName
+ println("Created root dirs: " + rootDirs)
+
+ // This suite focuses primarily on consolidation features,
+ // so we coerce consolidation if not already enabled.
+ val consolidateProp = "spark.shuffle.consolidateFiles"
+ val oldConsolidate = Option(System.getProperty(consolidateProp))
+ System.setProperty(consolidateProp, "true")
+
+ val shuffleBlockManager = new ShuffleBlockManager(null) {
+ var idToSegmentMap = mutable.Map[ShuffleBlockId, FileSegment]()
+ override def getBlockLocation(id: ShuffleBlockId) = idToSegmentMap(id)
+ }
+
+ var diskBlockManager: DiskBlockManager = _
+
+ override def afterAll() {
+ oldConsolidate.map(c => System.setProperty(consolidateProp, c))
+ }
+
+ override def beforeEach() {
+ diskBlockManager = new DiskBlockManager(shuffleBlockManager, rootDirs)
+ shuffleBlockManager.idToSegmentMap.clear()
+ }
+
+ test("basic block creation") {
+ val blockId = new TestBlockId("test")
+ assertSegmentEquals(blockId, blockId.name, 0, 0)
+
+ val newFile = diskBlockManager.getFile(blockId)
+ writeToFile(newFile, 10)
+ assertSegmentEquals(blockId, blockId.name, 0, 10)
+
+ newFile.delete()
+ }
+
+ test("block appending") {
+ val blockId = new TestBlockId("test")
+ val newFile = diskBlockManager.getFile(blockId)
+ writeToFile(newFile, 15)
+ assertSegmentEquals(blockId, blockId.name, 0, 15)
+ val newFile2 = diskBlockManager.getFile(blockId)
+ assert(newFile === newFile2)
+ writeToFile(newFile2, 12)
+ assertSegmentEquals(blockId, blockId.name, 0, 27)
+ newFile.delete()
+ }
+
+ test("block remapping") {
+ val filename = "test"
+ val blockId0 = new ShuffleBlockId(1, 2, 3)
+ val newFile = diskBlockManager.getFile(filename)
+ writeToFile(newFile, 15)
+ shuffleBlockManager.idToSegmentMap(blockId0) = new FileSegment(newFile, 0, 15)
+ assertSegmentEquals(blockId0, filename, 0, 15)
+
+ val blockId1 = new ShuffleBlockId(1, 2, 4)
+ val newFile2 = diskBlockManager.getFile(filename)
+ writeToFile(newFile2, 12)
+ shuffleBlockManager.idToSegmentMap(blockId1) = new FileSegment(newFile, 15, 12)
+ assertSegmentEquals(blockId1, filename, 15, 12)
+
+ assert(newFile === newFile2)
+ newFile.delete()
+ }
+
+ def assertSegmentEquals(blockId: BlockId, filename: String, offset: Int, length: Int) {
+ val segment = diskBlockManager.getBlockLocation(blockId)
+ assert(segment.file.getName === filename)
+ assert(segment.offset === offset)
+ assert(segment.length === length)
+ }
+
+ def writeToFile(file: File, numBytes: Int) {
+ val writer = new FileWriter(file, true)
+ for (i <- 0 until numBytes) writer.write(i)
+ writer.close()
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
index 8f0ec6683b..3764f4d1a0 100644
--- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
@@ -34,7 +34,6 @@ class UISuite extends FunSuite {
}
val (jettyServer1, boundPort1) = JettyUtils.startJettyServer("localhost", startPort, Seq())
val (jettyServer2, boundPort2) = JettyUtils.startJettyServer("localhost", startPort, Seq())
-
// Allow some wiggle room in case ports on the machine are under contention
assert(boundPort1 > startPort && boundPort1 < startPort + 10)
assert(boundPort2 > boundPort1 && boundPort2 < boundPort1 + 10)
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
new file mode 100644
index 0000000000..67a57a0e7f
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+import org.scalatest.FunSuite
+import org.apache.spark.scheduler._
+import org.apache.spark.{LocalSparkContext, SparkContext, Success}
+import org.apache.spark.scheduler.SparkListenerTaskStart
+import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics}
+
+class JobProgressListenerSuite extends FunSuite with LocalSparkContext {
+ test("test executor id to summary") {
+ val sc = new SparkContext("local", "test")
+ val listener = new JobProgressListener(sc)
+ val taskMetrics = new TaskMetrics()
+ val shuffleReadMetrics = new ShuffleReadMetrics()
+
+ // nothing in it
+ assert(listener.stageIdToExecutorSummaries.size == 0)
+
+ // finish this task, should get updated shuffleRead
+ shuffleReadMetrics.remoteBytesRead = 1000
+ taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics)
+ var taskInfo = new TaskInfo(1234L, 0, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL)
+ taskInfo.finishTime = 1
+ listener.onTaskEnd(new SparkListenerTaskEnd(
+ new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics))
+ assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-1", fail())
+ .shuffleRead == 1000)
+
+ // finish a task with unknown executor-id, nothing should happen
+ taskInfo = new TaskInfo(1234L, 0, 1000L, "exe-unknown", "host1", TaskLocality.NODE_LOCAL)
+ taskInfo.finishTime = 1
+ listener.onTaskEnd(new SparkListenerTaskEnd(
+ new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics))
+ assert(listener.stageIdToExecutorSummaries.size == 1)
+
+ // finish this task, should get updated duration
+ shuffleReadMetrics.remoteBytesRead = 1000
+ taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics)
+ taskInfo = new TaskInfo(1235L, 0, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL)
+ taskInfo.finishTime = 1
+ listener.onTaskEnd(new SparkListenerTaskEnd(
+ new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics))
+ assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-1", fail())
+ .shuffleRead == 2000)
+
+ // finish this task, should get updated duration
+ shuffleReadMetrics.remoteBytesRead = 1000
+ taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics)
+ taskInfo = new TaskInfo(1236L, 0, 0L, "exe-2", "host1", TaskLocality.NODE_LOCAL)
+ taskInfo.finishTime = 1
+ listener.onTaskEnd(new SparkListenerTaskEnd(
+ new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics))
+ assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-2", fail())
+ .shuffleRead == 1000)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala
index 4e40dcbdee..5aff26f9fc 100644
--- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala
@@ -63,54 +63,53 @@ class SizeEstimatorSuite
}
test("simple classes") {
- assert(SizeEstimator.estimate(new DummyClass1) === 16)
- assert(SizeEstimator.estimate(new DummyClass2) === 16)
- assert(SizeEstimator.estimate(new DummyClass3) === 24)
- assert(SizeEstimator.estimate(new DummyClass4(null)) === 24)
- assert(SizeEstimator.estimate(new DummyClass4(new DummyClass3)) === 48)
+ expectResult(16)(SizeEstimator.estimate(new DummyClass1))
+ expectResult(16)(SizeEstimator.estimate(new DummyClass2))
+ expectResult(24)(SizeEstimator.estimate(new DummyClass3))
+ expectResult(24)(SizeEstimator.estimate(new DummyClass4(null)))
+ expectResult(48)(SizeEstimator.estimate(new DummyClass4(new DummyClass3)))
}
// NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors
// (Sun vs IBM). Use a DummyString class to make tests deterministic.
test("strings") {
- assert(SizeEstimator.estimate(DummyString("")) === 40)
- assert(SizeEstimator.estimate(DummyString("a")) === 48)
- assert(SizeEstimator.estimate(DummyString("ab")) === 48)
- assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56)
+ expectResult(40)(SizeEstimator.estimate(DummyString("")))
+ expectResult(48)(SizeEstimator.estimate(DummyString("a")))
+ expectResult(48)(SizeEstimator.estimate(DummyString("ab")))
+ expectResult(56)(SizeEstimator.estimate(DummyString("abcdefgh")))
}
test("primitive arrays") {
- assert(SizeEstimator.estimate(new Array[Byte](10)) === 32)
- assert(SizeEstimator.estimate(new Array[Char](10)) === 40)
- assert(SizeEstimator.estimate(new Array[Short](10)) === 40)
- assert(SizeEstimator.estimate(new Array[Int](10)) === 56)
- assert(SizeEstimator.estimate(new Array[Long](10)) === 96)
- assert(SizeEstimator.estimate(new Array[Float](10)) === 56)
- assert(SizeEstimator.estimate(new Array[Double](10)) === 96)
- assert(SizeEstimator.estimate(new Array[Int](1000)) === 4016)
- assert(SizeEstimator.estimate(new Array[Long](1000)) === 8016)
+ expectResult(32)(SizeEstimator.estimate(new Array[Byte](10)))
+ expectResult(40)(SizeEstimator.estimate(new Array[Char](10)))
+ expectResult(40)(SizeEstimator.estimate(new Array[Short](10)))
+ expectResult(56)(SizeEstimator.estimate(new Array[Int](10)))
+ expectResult(96)(SizeEstimator.estimate(new Array[Long](10)))
+ expectResult(56)(SizeEstimator.estimate(new Array[Float](10)))
+ expectResult(96)(SizeEstimator.estimate(new Array[Double](10)))
+ expectResult(4016)(SizeEstimator.estimate(new Array[Int](1000)))
+ expectResult(8016)(SizeEstimator.estimate(new Array[Long](1000)))
}
test("object arrays") {
// Arrays containing nulls should just have one pointer per element
- assert(SizeEstimator.estimate(new Array[String](10)) === 56)
- assert(SizeEstimator.estimate(new Array[AnyRef](10)) === 56)
-
+ expectResult(56)(SizeEstimator.estimate(new Array[String](10)))
+ expectResult(56)(SizeEstimator.estimate(new Array[AnyRef](10)))
// For object arrays with non-null elements, each object should take one pointer plus
// however many bytes that class takes. (Note that Array.fill calls the code in its
// second parameter separately for each object, so we get distinct objects.)
- assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass1)) === 216)
- assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass2)) === 216)
- assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass3)) === 296)
- assert(SizeEstimator.estimate(Array(new DummyClass1, new DummyClass2)) === 56)
+ expectResult(216)(SizeEstimator.estimate(Array.fill(10)(new DummyClass1)))
+ expectResult(216)(SizeEstimator.estimate(Array.fill(10)(new DummyClass2)))
+ expectResult(296)(SizeEstimator.estimate(Array.fill(10)(new DummyClass3)))
+ expectResult(56)(SizeEstimator.estimate(Array(new DummyClass1, new DummyClass2)))
// Past size 100, our samples 100 elements, but we should still get the right size.
- assert(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3)) === 28016)
+ expectResult(28016)(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3)))
// If an array contains the *same* element many times, we should only count it once.
val d1 = new DummyClass1
- assert(SizeEstimator.estimate(Array.fill(10)(d1)) === 72) // 10 pointers plus 8-byte object
- assert(SizeEstimator.estimate(Array.fill(100)(d1)) === 432) // 100 pointers plus 8-byte object
+ expectResult(72)(SizeEstimator.estimate(Array.fill(10)(d1))) // 10 pointers plus 8-byte object
+ expectResult(432)(SizeEstimator.estimate(Array.fill(100)(d1))) // 100 pointers plus 8-byte object
// Same thing with huge array containing the same element many times. Note that this won't
// return exactly 4032 because it can't tell that *all* the elements will equal the first
@@ -128,11 +127,10 @@ class SizeEstimatorSuite
val initialize = PrivateMethod[Unit]('initialize)
SizeEstimator invokePrivate initialize()
- assert(SizeEstimator.estimate(DummyString("")) === 40)
- assert(SizeEstimator.estimate(DummyString("a")) === 48)
- assert(SizeEstimator.estimate(DummyString("ab")) === 48)
- assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56)
-
+ expectResult(40)(SizeEstimator.estimate(DummyString("")))
+ expectResult(48)(SizeEstimator.estimate(DummyString("a")))
+ expectResult(48)(SizeEstimator.estimate(DummyString("ab")))
+ expectResult(56)(SizeEstimator.estimate(DummyString("abcdefgh")))
resetOrClear("os.arch", arch)
}
@@ -145,10 +143,10 @@ class SizeEstimatorSuite
val initialize = PrivateMethod[Unit]('initialize)
SizeEstimator invokePrivate initialize()
- assert(SizeEstimator.estimate(DummyString("")) === 56)
- assert(SizeEstimator.estimate(DummyString("a")) === 64)
- assert(SizeEstimator.estimate(DummyString("ab")) === 64)
- assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 72)
+ expectResult(56)(SizeEstimator.estimate(DummyString("")))
+ expectResult(64)(SizeEstimator.estimate(DummyString("a")))
+ expectResult(64)(SizeEstimator.estimate(DummyString("ab")))
+ expectResult(72)(SizeEstimator.estimate(DummyString("abcdefgh")))
resetOrClear("os.arch", arch)
resetOrClear("spark.test.useCompressedOops", oops)
diff --git a/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala
new file mode 100644
index 0000000000..b78367b6ca
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.Random
+import org.scalatest.FlatSpec
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import org.apache.spark.util.Utils.times
+
+class XORShiftRandomSuite extends FunSuite with ShouldMatchers {
+
+ def fixture = new {
+ val seed = 1L
+ val xorRand = new XORShiftRandom(seed)
+ val hundMil = 1e8.toInt
+ }
+
+ /*
+ * This test is based on a chi-squared test for randomness. The values are hard-coded
+ * so as not to create Spark's dependency on apache.commons.math3 just to call one
+ * method for calculating the exact p-value for a given number of random numbers
+ * and bins. In case one would want to move to a full-fledged test based on
+ * apache.commons.math3, the relevant class is here:
+ * org.apache.commons.math3.stat.inference.ChiSquareTest
+ */
+ test ("XORShift generates valid random numbers") {
+
+ val f = fixture
+
+ val numBins = 10
+ // create 10 bins
+ val bins = Array.fill(numBins)(0)
+
+ // populate bins based on modulus of the random number
+ times(f.hundMil) {bins(math.abs(f.xorRand.nextInt) % 10) += 1}
+
+ /* since the seed is deterministic, until the algorithm is changed, we know the result will be
+ * exactly this: Array(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272,
+ * 10000790, 10002286, 9998699), so the test will never fail at the prespecified (5%)
+ * significance level. However, should the RNG implementation change, the test should still
+ * pass at the same significance level. The chi-squared test done in R gave the following
+ * results:
+ * > chisq.test(c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272,
+ * 10000790, 10002286, 9998699))
+ * Chi-squared test for given probabilities
+ * data: c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, 10000790,
+ * 10002286, 9998699)
+ * X-squared = 11.975, df = 9, p-value = 0.2147
+ * Note that the p-value was ~0.22. The test will fail if alpha < 0.05, which for 100 million
+ * random numbers
+ * and 10 bins will happen at X-squared of ~16.9196. So, the test will fail if X-squared
+ * is greater than or equal to that number.
+ */
+ val binSize = f.hundMil/numBins
+ val xSquared = bins.map(x => math.pow((binSize - x), 2)/binSize).sum
+ xSquared should be < (16.9196)
+
+ }
+
+} \ No newline at end of file
diff --git a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala
new file mode 100644
index 0000000000..0f1ab3d20e
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import org.scalatest.FunSuite
+
+
+class BitSetSuite extends FunSuite {
+
+ test("basic set and get") {
+ val setBits = Seq(0, 9, 1, 10, 90, 96)
+ val bitset = new BitSet(100)
+
+ for (i <- 0 until 100) {
+ assert(!bitset.get(i))
+ }
+
+ setBits.foreach(i => bitset.set(i))
+
+ for (i <- 0 until 100) {
+ if (setBits.contains(i)) {
+ assert(bitset.get(i))
+ } else {
+ assert(!bitset.get(i))
+ }
+ }
+ assert(bitset.cardinality() === setBits.size)
+ }
+
+ test("100% full bit set") {
+ val bitset = new BitSet(10000)
+ for (i <- 0 until 10000) {
+ assert(!bitset.get(i))
+ bitset.set(i)
+ }
+ for (i <- 0 until 10000) {
+ assert(bitset.get(i))
+ }
+ assert(bitset.cardinality() === 10000)
+ }
+
+ test("nextSetBit") {
+ val setBits = Seq(0, 9, 1, 10, 90, 96)
+ val bitset = new BitSet(100)
+ setBits.foreach(i => bitset.set(i))
+
+ assert(bitset.nextSetBit(0) === 0)
+ assert(bitset.nextSetBit(1) === 1)
+ assert(bitset.nextSetBit(2) === 9)
+ assert(bitset.nextSetBit(9) === 9)
+ assert(bitset.nextSetBit(10) === 10)
+ assert(bitset.nextSetBit(11) === 90)
+ assert(bitset.nextSetBit(80) === 90)
+ assert(bitset.nextSetBit(91) === 96)
+ assert(bitset.nextSetBit(96) === 96)
+ assert(bitset.nextSetBit(97) === -1)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
new file mode 100644
index 0000000000..e9b62ea70d
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import scala.collection.mutable.HashSet
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import org.apache.spark.util.SizeEstimator
+
+class OpenHashMapSuite extends FunSuite with ShouldMatchers {
+
+ test("size for specialized, primitive value (int)") {
+ val capacity = 1024
+ val map = new OpenHashMap[String, Int](capacity)
+ val actualSize = SizeEstimator.estimate(map)
+ // 64 bit for pointers, 32 bit for ints, and 1 bit for the bitset.
+ val expectedSize = capacity * (64 + 32 + 1) / 8
+ // Make sure we are not allocating a significant amount of memory beyond our expected.
+ actualSize should be <= (expectedSize * 1.1).toLong
+ }
+
+ test("initialization") {
+ val goodMap1 = new OpenHashMap[String, Int](1)
+ assert(goodMap1.size === 0)
+ val goodMap2 = new OpenHashMap[String, Int](255)
+ assert(goodMap2.size === 0)
+ val goodMap3 = new OpenHashMap[String, String](256)
+ assert(goodMap3.size === 0)
+ intercept[IllegalArgumentException] {
+ new OpenHashMap[String, Int](1 << 30) // Invalid map size: bigger than 2^29
+ }
+ intercept[IllegalArgumentException] {
+ new OpenHashMap[String, Int](-1)
+ }
+ intercept[IllegalArgumentException] {
+ new OpenHashMap[String, String](0)
+ }
+ }
+
+ test("primitive value") {
+ val map = new OpenHashMap[String, Int]
+
+ for (i <- 1 to 1000) {
+ map(i.toString) = i
+ assert(map(i.toString) === i)
+ }
+
+ assert(map.size === 1000)
+ assert(map(null) === 0)
+
+ map(null) = -1
+ assert(map.size === 1001)
+ assert(map(null) === -1)
+
+ for (i <- 1 to 1000) {
+ assert(map(i.toString) === i)
+ }
+
+ // Test iterator
+ val set = new HashSet[(String, Int)]
+ for ((k, v) <- map) {
+ set.add((k, v))
+ }
+ val expected = (1 to 1000).map(x => (x.toString, x)) :+ (null.asInstanceOf[String], -1)
+ assert(set === expected.toSet)
+ }
+
+ test("non-primitive value") {
+ val map = new OpenHashMap[String, String]
+
+ for (i <- 1 to 1000) {
+ map(i.toString) = i.toString
+ assert(map(i.toString) === i.toString)
+ }
+
+ assert(map.size === 1000)
+ assert(map(null) === null)
+
+ map(null) = "-1"
+ assert(map.size === 1001)
+ assert(map(null) === "-1")
+
+ for (i <- 1 to 1000) {
+ assert(map(i.toString) === i.toString)
+ }
+
+ // Test iterator
+ val set = new HashSet[(String, String)]
+ for ((k, v) <- map) {
+ set.add((k, v))
+ }
+ val expected = (1 to 1000).map(_.toString).map(x => (x, x)) :+ (null.asInstanceOf[String], "-1")
+ assert(set === expected.toSet)
+ }
+
+ test("null keys") {
+ val map = new OpenHashMap[String, String]()
+ for (i <- 1 to 100) {
+ map(i.toString) = i.toString
+ }
+ assert(map.size === 100)
+ assert(map(null) === null)
+ map(null) = "hello"
+ assert(map.size === 101)
+ assert(map(null) === "hello")
+ }
+
+ test("null values") {
+ val map = new OpenHashMap[String, String]()
+ for (i <- 1 to 100) {
+ map(i.toString) = null
+ }
+ assert(map.size === 100)
+ assert(map("1") === null)
+ assert(map(null) === null)
+ assert(map.size === 100)
+ map(null) = null
+ assert(map.size === 101)
+ assert(map(null) === null)
+ }
+
+ test("changeValue") {
+ val map = new OpenHashMap[String, String]()
+ for (i <- 1 to 100) {
+ map(i.toString) = i.toString
+ }
+ assert(map.size === 100)
+ for (i <- 1 to 100) {
+ val res = map.changeValue(i.toString, { assert(false); "" }, v => {
+ assert(v === i.toString)
+ v + "!"
+ })
+ assert(res === i + "!")
+ }
+ // Iterate from 101 to 400 to make sure the map grows a couple of times, because we had a
+ // bug where changeValue would return the wrong result when the map grew on that insert
+ for (i <- 101 to 400) {
+ val res = map.changeValue(i.toString, { i + "!" }, v => { assert(false); v })
+ assert(res === i + "!")
+ }
+ assert(map.size === 400)
+ assert(map(null) === null)
+ map.changeValue(null, { "null!" }, v => { assert(false); v })
+ assert(map.size === 401)
+ map.changeValue(null, { assert(false); "" }, v => {
+ assert(v === "null!")
+ "null!!"
+ })
+ assert(map.size === 401)
+ }
+
+ test("inserting in capacity-1 map") {
+ val map = new OpenHashMap[String, String](1)
+ for (i <- 1 to 100) {
+ map(i.toString) = i.toString
+ }
+ assert(map.size === 100)
+ for (i <- 1 to 100) {
+ assert(map(i.toString) === i.toString)
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
new file mode 100644
index 0000000000..1b24f8f287
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+
+import org.apache.spark.util.SizeEstimator
+
+
+class OpenHashSetSuite extends FunSuite with ShouldMatchers {
+
+ test("size for specialized, primitive int") {
+ val loadFactor = 0.7
+ val set = new OpenHashSet[Int](64, loadFactor)
+ for (i <- 0 until 1024) {
+ set.add(i)
+ }
+ assert(set.size === 1024)
+ assert(set.capacity > 1024)
+ val actualSize = SizeEstimator.estimate(set)
+ // 32 bits for the ints + 1 bit for the bitset
+ val expectedSize = set.capacity * (32 + 1) / 8
+ // Make sure we are not allocating a significant amount of memory beyond our expected.
+ actualSize should be <= (expectedSize * 1.1).toLong
+ }
+
+ test("primitive int") {
+ val set = new OpenHashSet[Int]
+ assert(set.size === 0)
+ assert(!set.contains(10))
+ assert(!set.contains(50))
+ assert(!set.contains(999))
+ assert(!set.contains(10000))
+
+ set.add(10)
+ assert(set.contains(10))
+ assert(!set.contains(50))
+ assert(!set.contains(999))
+ assert(!set.contains(10000))
+
+ set.add(50)
+ assert(set.size === 2)
+ assert(set.contains(10))
+ assert(set.contains(50))
+ assert(!set.contains(999))
+ assert(!set.contains(10000))
+
+ set.add(999)
+ assert(set.size === 3)
+ assert(set.contains(10))
+ assert(set.contains(50))
+ assert(set.contains(999))
+ assert(!set.contains(10000))
+
+ set.add(50)
+ assert(set.size === 3)
+ assert(set.contains(10))
+ assert(set.contains(50))
+ assert(set.contains(999))
+ assert(!set.contains(10000))
+ }
+
+ test("primitive long") {
+ val set = new OpenHashSet[Long]
+ assert(set.size === 0)
+ assert(!set.contains(10L))
+ assert(!set.contains(50L))
+ assert(!set.contains(999L))
+ assert(!set.contains(10000L))
+
+ set.add(10L)
+ assert(set.size === 1)
+ assert(set.contains(10L))
+ assert(!set.contains(50L))
+ assert(!set.contains(999L))
+ assert(!set.contains(10000L))
+
+ set.add(50L)
+ assert(set.size === 2)
+ assert(set.contains(10L))
+ assert(set.contains(50L))
+ assert(!set.contains(999L))
+ assert(!set.contains(10000L))
+
+ set.add(999L)
+ assert(set.size === 3)
+ assert(set.contains(10L))
+ assert(set.contains(50L))
+ assert(set.contains(999L))
+ assert(!set.contains(10000L))
+
+ set.add(50L)
+ assert(set.size === 3)
+ assert(set.contains(10L))
+ assert(set.contains(50L))
+ assert(set.contains(999L))
+ assert(!set.contains(10000L))
+ }
+
+ test("non-primitive") {
+ val set = new OpenHashSet[String]
+ assert(set.size === 0)
+ assert(!set.contains(10.toString))
+ assert(!set.contains(50.toString))
+ assert(!set.contains(999.toString))
+ assert(!set.contains(10000.toString))
+
+ set.add(10.toString)
+ assert(set.size === 1)
+ assert(set.contains(10.toString))
+ assert(!set.contains(50.toString))
+ assert(!set.contains(999.toString))
+ assert(!set.contains(10000.toString))
+
+ set.add(50.toString)
+ assert(set.size === 2)
+ assert(set.contains(10.toString))
+ assert(set.contains(50.toString))
+ assert(!set.contains(999.toString))
+ assert(!set.contains(10000.toString))
+
+ set.add(999.toString)
+ assert(set.size === 3)
+ assert(set.contains(10.toString))
+ assert(set.contains(50.toString))
+ assert(set.contains(999.toString))
+ assert(!set.contains(10000.toString))
+
+ set.add(50.toString)
+ assert(set.size === 3)
+ assert(set.contains(10.toString))
+ assert(set.contains(50.toString))
+ assert(set.contains(999.toString))
+ assert(!set.contains(10000.toString))
+ }
+
+ test("non-primitive set growth") {
+ val set = new OpenHashSet[String]
+ for (i <- 1 to 1000) {
+ set.add(i.toString)
+ }
+ assert(set.size === 1000)
+ assert(set.capacity > 1000)
+ for (i <- 1 to 100) {
+ set.add(i.toString)
+ }
+ assert(set.size === 1000)
+ assert(set.capacity > 1000)
+ }
+
+ test("primitive set growth") {
+ val set = new OpenHashSet[Long]
+ for (i <- 1 to 1000) {
+ set.add(i.toLong)
+ }
+ assert(set.size === 1000)
+ assert(set.capacity > 1000)
+ for (i <- 1 to 100) {
+ set.add(i.toLong)
+ }
+ assert(set.size === 1000)
+ assert(set.capacity > 1000)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
new file mode 100644
index 0000000000..3b60decee9
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import scala.collection.mutable.HashSet
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import org.apache.spark.util.SizeEstimator
+
+class PrimitiveKeyOpenHashMapSuite extends FunSuite with ShouldMatchers {
+
+ test("size for specialized, primitive key, value (int, int)") {
+ val capacity = 1024
+ val map = new PrimitiveKeyOpenHashMap[Int, Int](capacity)
+ val actualSize = SizeEstimator.estimate(map)
+ // 32 bit for keys, 32 bit for values, and 1 bit for the bitset.
+ val expectedSize = capacity * (32 + 32 + 1) / 8
+ // Make sure we are not allocating a significant amount of memory beyond our expected.
+ actualSize should be <= (expectedSize * 1.1).toLong
+ }
+
+ test("initialization") {
+ val goodMap1 = new PrimitiveKeyOpenHashMap[Int, Int](1)
+ assert(goodMap1.size === 0)
+ val goodMap2 = new PrimitiveKeyOpenHashMap[Int, Int](255)
+ assert(goodMap2.size === 0)
+ val goodMap3 = new PrimitiveKeyOpenHashMap[Int, Int](256)
+ assert(goodMap3.size === 0)
+ intercept[IllegalArgumentException] {
+ new PrimitiveKeyOpenHashMap[Int, Int](1 << 30) // Invalid map size: bigger than 2^29
+ }
+ intercept[IllegalArgumentException] {
+ new PrimitiveKeyOpenHashMap[Int, Int](-1)
+ }
+ intercept[IllegalArgumentException] {
+ new PrimitiveKeyOpenHashMap[Int, Int](0)
+ }
+ }
+
+ test("basic operations") {
+ val longBase = 1000000L
+ val map = new PrimitiveKeyOpenHashMap[Long, Int]
+
+ for (i <- 1 to 1000) {
+ map(i + longBase) = i
+ assert(map(i + longBase) === i)
+ }
+
+ assert(map.size === 1000)
+
+ for (i <- 1 to 1000) {
+ assert(map(i + longBase) === i)
+ }
+
+ // Test iterator
+ val set = new HashSet[(Long, Int)]
+ for ((k, v) <- map) {
+ set.add((k, v))
+ }
+ assert(set === (1 to 1000).map(x => (x + longBase, x)).toSet)
+ }
+
+ test("null values") {
+ val map = new PrimitiveKeyOpenHashMap[Long, String]()
+ for (i <- 1 to 100) {
+ map(i.toLong) = null
+ }
+ assert(map.size === 100)
+ assert(map(1.toLong) === null)
+ }
+
+ test("changeValue") {
+ val map = new PrimitiveKeyOpenHashMap[Long, String]()
+ for (i <- 1 to 100) {
+ map(i.toLong) = i.toString
+ }
+ assert(map.size === 100)
+ for (i <- 1 to 100) {
+ val res = map.changeValue(i.toLong, { assert(false); "" }, v => {
+ assert(v === i.toString)
+ v + "!"
+ })
+ assert(res === i + "!")
+ }
+ // Iterate from 101 to 400 to make sure the map grows a couple of times, because we had a
+ // bug where changeValue would return the wrong result when the map grew on that insert
+ for (i <- 101 to 400) {
+ val res = map.changeValue(i.toLong, { i + "!" }, v => { assert(false); v })
+ assert(res === i + "!")
+ }
+ assert(map.size === 400)
+ }
+
+ test("inserting in capacity-1 map") {
+ val map = new PrimitiveKeyOpenHashMap[Long, String](1)
+ for (i <- 1 to 100) {
+ map(i.toLong) = i.toString
+ }
+ assert(map.size === 100)
+ for (i <- 1 to 100) {
+ assert(map(i.toLong) === i.toString)
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala
new file mode 100644
index 0000000000..970dade628
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.util.SizeEstimator
+
+class PrimitiveVectorSuite extends FunSuite {
+
+ test("primitive value") {
+ val vector = new PrimitiveVector[Int]
+
+ for (i <- 0 until 1000) {
+ vector += i
+ assert(vector(i) === i)
+ }
+
+ assert(vector.size === 1000)
+ assert(vector.size == vector.length)
+ intercept[IllegalArgumentException] {
+ vector(1000)
+ }
+
+ for (i <- 0 until 1000) {
+ assert(vector(i) == i)
+ }
+ }
+
+ test("non-primitive value") {
+ val vector = new PrimitiveVector[String]
+
+ for (i <- 0 until 1000) {
+ vector += i.toString
+ assert(vector(i) === i.toString)
+ }
+
+ assert(vector.size === 1000)
+ assert(vector.size == vector.length)
+ intercept[IllegalArgumentException] {
+ vector(1000)
+ }
+
+ for (i <- 0 until 1000) {
+ assert(vector(i) == i.toString)
+ }
+ }
+
+ test("ideal growth") {
+ val vector = new PrimitiveVector[Long](initialSize = 1)
+ vector += 1
+ for (i <- 1 until 1024) {
+ vector += i
+ assert(vector.size === i + 1)
+ assert(vector.capacity === Integer.highestOneBit(i) * 2)
+ }
+ assert(vector.capacity === 1024)
+ vector += 1024
+ assert(vector.capacity === 2048)
+ }
+
+ test("ideal size") {
+ val vector = new PrimitiveVector[Long](8192)
+ for (i <- 0 until 8192) {
+ vector += i
+ }
+ assert(vector.size === 8192)
+ assert(vector.capacity === 8192)
+ val actualSize = SizeEstimator.estimate(vector)
+ val expectedSize = 8192 * 8
+ // Make sure we are not allocating a significant amount of memory beyond our expected.
+ // Due to specialization wonkiness, we need to ensure we don't have 2 copies of the array.
+ assert(actualSize < expectedSize * 1.1)
+ }
+
+ test("resizing") {
+ val vector = new PrimitiveVector[Long]
+ for (i <- 0 until 4097) {
+ vector += i
+ }
+ assert(vector.size === 4097)
+ assert(vector.capacity === 8192)
+ vector.trim()
+ assert(vector.size === 4097)
+ assert(vector.capacity === 4097)
+ vector.resize(5000)
+ assert(vector.size === 4097)
+ assert(vector.capacity === 5000)
+ vector.resize(4000)
+ assert(vector.size === 4000)
+ assert(vector.capacity === 4000)
+ vector.resize(5000)
+ assert(vector.size === 4000)
+ assert(vector.capacity === 5000)
+ for (i <- 0 until 4000) {
+ assert(vector(i) == i)
+ }
+ intercept[IllegalArgumentException] {
+ vector(4000)
+ }
+ }
+}