30 files changed, 468 insertions, 193 deletions
diff --git a/bagel/src/test/scala/bagel/BagelSuite.scala b/bagel/src/test/scala/bagel/BagelSuite.scala index 0eda80af64..d2189169d2 100644 --- a/bagel/src/test/scala/bagel/BagelSuite.scala +++ b/bagel/src/test/scala/bagel/BagelSuite.scala @@ -1,6 +1,6 @@ package spark.bagel -import org.scalatest.{FunSuite, Assertions} +import org.scalatest.{FunSuite, Assertions, BeforeAndAfter} import org.scalatest.prop.Checkers import org.scalacheck.Arbitrary._ import org.scalacheck.Gen @@ -13,9 +13,16 @@ import spark._ class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable class TestMessage(val targetId: String) extends Message[String] with Serializable -class BagelSuite extends FunSuite with Assertions { +class BagelSuite extends FunSuite with Assertions with BeforeAndAfter { + + var sc: SparkContext = _ + + after { + sc.stop() + } + test("halting by voting") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(true, 0)))) val msgs = sc.parallelize(Array[(String, TestMessage)]()) val numSupersteps = 5 @@ -26,11 +33,10 @@ class BagelSuite extends FunSuite with Assertions { } for ((id, vert) <- result.collect) assert(vert.age === numSupersteps) - sc.stop() } test("halting by message silence") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(false, 0)))) val msgs = sc.parallelize(Array("a" -> new TestMessage("a"))) val numSupersteps = 5 @@ -48,6 +54,5 @@ class BagelSuite extends FunSuite with Assertions { } for ((id, vert) <- result.collect) assert(vert.age === numSupersteps) - sc.stop() } } diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala index 86e2061b9f..bf77417852 100644 --- a/core/src/main/scala/spark/Accumulators.scala +++ b/core/src/main/scala/spark/Accumulators.scala @@ -4,21 +4,39 @@ import java.io._ import scala.collection.mutable.Map -class Accumulator[T] ( +class Accumulable[T,R] ( @transient initialValue: T, - param: AccumulatorParam[T]) + param: AccumulableParam[T,R]) extends Serializable { val id = Accumulators.newId @transient - var value_ = initialValue // Current value on master + private var value_ = initialValue // Current value on master val zero = param.zero(initialValue) // Zero value to be passed to workers var deserialized = false Accumulators.register(this, true) - def += (term: T) { value_ = param.addInPlace(value_, term) } - def value = this.value_ + /** + * add more data to this accumulator / accumulable + * @param term the data to add + */ + def += (term: R) { value_ = param.addAccumulator(value_, term) } + + /** + * merge two accumulable objects together + * + * Normally, a user will not want to use this version, but will instead call `+=`. 
+ * @param term the other Accumulable that will get merged with this + */ + def ++= (term: T) { value_ = param.addInPlace(value_, term)} + def value = { + if (!deserialized) value_ + else throw new UnsupportedOperationException("Can't use read value in task") + } + + private[spark] def localValue = value_ + def value_= (t: T) { if (!deserialized) value_ = t else throw new UnsupportedOperationException("Can't use value_= in task") @@ -35,17 +53,58 @@ class Accumulator[T] ( override def toString = value_.toString } -trait AccumulatorParam[T] extends Serializable { - def addInPlace(t1: T, t2: T): T - def zero(initialValue: T): T +class Accumulator[T]( + @transient initialValue: T, + param: AccumulatorParam[T]) extends Accumulable[T,T](initialValue, param) + +/** + * A simpler version of [[spark.AccumulableParam]] where the only datatype you can add in is the same type + * as the accumulated value + * @tparam T + */ +trait AccumulatorParam[T] extends AccumulableParam[T,T] { + def addAccumulator(t1: T, t2: T) : T = { + addInPlace(t1, t2) + } +} + +/** + * A datatype that can be accumulated, ie. has a commutative & associative +. + * + * You must define how to add data, and how to merge two of these together. For some datatypes, these might be + * the same operation (eg., a counter). In that case, you might want to use [[spark.AccumulatorParam]]. They won't + * always be the same, though -- eg., imagine you are accumulating a set. You will add items to the set, and you + * will union two sets together. + * + * @tparam R the full accumulated data + * @tparam T partial data that can be added in + */ +trait AccumulableParam[R,T] extends Serializable { + /** + * Add additional data to the accumulator value. + * @param t1 the current value of the accumulator + * @param t2 the data to be added to the accumulator + * @return the new value of the accumulator + */ + def addAccumulator(t1: R, t2: T) : R + + /** + * merge two accumulated values together + * @param t1 one set of accumulated data + * @param t2 another set of accumulated data + * @return both data sets merged together + */ + def addInPlace(t1: R, t2: R): R + + def zero(initialValue: R): R } // TODO: The multi-thread support in accumulators is kind of lame; check // if there's a more intuitive way of doing it right private object Accumulators { // TODO: Use soft references? 
=> need to make readObject work properly then - val originals = Map[Long, Accumulator[_]]() - val localAccums = Map[Thread, Map[Long, Accumulator[_]]]() + val originals = Map[Long, Accumulable[_,_]]() + val localAccums = Map[Thread, Map[Long, Accumulable[_,_]]]() var lastId: Long = 0 def newId: Long = synchronized { @@ -53,7 +112,7 @@ private object Accumulators { return lastId } - def register(a: Accumulator[_], original: Boolean): Unit = synchronized { + def register(a: Accumulable[_,_], original: Boolean): Unit = synchronized { if (original) { originals(a.id) = a } else { @@ -71,7 +130,7 @@ private object Accumulators { def values: Map[Long, Any] = synchronized { val ret = Map[Long, Any]() for ((id, accum) <- localAccums.getOrElse(Thread.currentThread, Map())) { - ret(id) = accum.value + ret(id) = accum.localValue } return ret } @@ -80,7 +139,7 @@ private object Accumulators { def add(values: Map[Long, Any]): Unit = synchronized { for ((id, value) <- values) { if (originals.contains(id)) { - originals(id).asInstanceOf[Accumulator[Any]] += value + originals(id).asInstanceOf[Accumulable[Any, Any]] ++= value } } } diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala index 4867829c17..76d1c92a12 100644 --- a/core/src/main/scala/spark/CacheTracker.scala +++ b/core/src/main/scala/spark/CacheTracker.scala @@ -225,9 +225,10 @@ class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging { // Called by the Cache to report that an entry has been dropped from it def dropEntry(datasetId: Any, partition: Int) { - datasetId match { - //TODO - do we really want to use '!!' when nobody checks returned future? '!' seems to enough here. - case (cache.keySpaceId, rddId: Int) => trackerActor !! DroppedFromCache(rddId, partition, Utils.getHost) + val (keySpaceId, innerId) = datasetId.asInstanceOf[(Any, Any)] + if (keySpaceId == cache.keySpaceId) { + // TODO - do we really want to use '!!' when nobody checks returned future? '!' seems to enough here. + trackerActor !! 
DroppedFromCache(innerId.asInstanceOf[Int], partition, Utils.getHost) } } diff --git a/core/src/main/scala/spark/Executor.scala b/core/src/main/scala/spark/Executor.scala index c795b6c351..c8cb730d14 100644 --- a/core/src/main/scala/spark/Executor.scala +++ b/core/src/main/scala/spark/Executor.scala @@ -37,17 +37,17 @@ class Executor extends org.apache.mesos.Executor with Logging { // Make sure an appropriate class loader is set for remote actors RemoteActor.classLoader = getClass.getClassLoader - + + // Create our ClassLoader (using spark properties) and set it on this thread + classLoader = createClassLoader() + Thread.currentThread.setContextClassLoader(classLoader) + // Initialize Spark environment (using system properties read above) env = SparkEnv.createFromSystemProperties(false) SparkEnv.set(env) // Old stuff that isn't yet using env Broadcast.initialize(false) - // Create our ClassLoader (using spark properties) and set it on this thread - classLoader = createClassLoader() - Thread.currentThread.setContextClassLoader(classLoader) - // Start worker thread pool threadPool = new ThreadPoolExecutor( 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable]) diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala index 0d11ab9cbd..07dafabf2e 100644 --- a/core/src/main/scala/spark/Logging.scala +++ b/core/src/main/scala/spark/Logging.scala @@ -38,10 +38,10 @@ trait Logging { // Log methods that take Throwables (Exceptions/Errors) too def logInfo(msg: => String, throwable: Throwable) = - if (log.isInfoEnabled) log.info(msg) + if (log.isInfoEnabled) log.info(msg, throwable) def logDebug(msg: => String, throwable: Throwable) = - if (log.isDebugEnabled) log.debug(msg) + if (log.isDebugEnabled) log.debug(msg, throwable) def logWarning(msg: => String, throwable: Throwable) = if (log.isWarnEnabled) log.warn(msg, throwable) diff --git a/core/src/main/scala/spark/Partitioner.scala b/core/src/main/scala/spark/Partitioner.scala index 024a4580ac..d05ef0ab5f 100644 --- a/core/src/main/scala/spark/Partitioner.scala +++ b/core/src/main/scala/spark/Partitioner.scala @@ -8,12 +8,16 @@ abstract class Partitioner extends Serializable { class HashPartitioner(partitions: Int) extends Partitioner { def numPartitions = partitions - def getPartition(key: Any) = { - val mod = key.hashCode % partitions - if (mod < 0) { - mod + partitions + def getPartition(key: Any): Int = { + if (key == null) { + return 0 } else { - mod // Guard against negative hash codes + val mod = key.hashCode % partitions + if (mod < 0) { + mod + partitions + } else { + mod // Guard against negative hash codes + } } } @@ -31,36 +35,41 @@ class RangePartitioner[K <% Ordered[K]: ClassManifest, V]( private val ascending: Boolean = true) extends Partitioner { + // An array of upper bounds for the first (partitions - 1) partitions private val rangeBounds: Array[K] = { - val rddSize = rdd.count() - val maxSampleSize = partitions * 10.0 - val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0) - val rddSample = rdd.sample(true, frac, 1).map(_._1).collect() - .sortWith((x, y) => if (ascending) x < y else x > y) - if (rddSample.length == 0) { + if (partitions == 1) { Array() } else { - val bounds = new Array[K](partitions) - for (i <- 0 until partitions) { - bounds(i) = rddSample(i * rddSample.length / partitions) + val rddSize = rdd.count() + val maxSampleSize = partitions * 10.0 + val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0) + val rddSample = rdd.sample(true, frac, 
1).map(_._1).collect().sortWith(_ < _) + if (rddSample.length == 0) { + Array() + } else { + val bounds = new Array[K](partitions - 1) + for (i <- 0 until partitions - 1) { + val index = (rddSample.length - 1) * (i + 1) / partitions + bounds(i) = rddSample(index) + } + bounds } - bounds } } - def numPartitions = rangeBounds.length + def numPartitions = partitions def getPartition(key: Any): Int = { // TODO: Use a binary search here if number of partitions is large val k = key.asInstanceOf[K] var partition = 0 - while (partition < rangeBounds.length - 1 && k > rangeBounds(partition)) { + while (partition < rangeBounds.length && k > rangeBounds(partition)) { partition += 1 } if (ascending) { partition } else { - rangeBounds.length - 1 - partition + rangeBounds.length - partition } } diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 4c4b2ee30d..ede7571bf6 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -261,6 +261,11 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial .map(x => (NullWritable.get(), new BytesWritable(Utils.serialize(x)))) .saveAsSequenceFile(path) } + + /** A private method for tests, to look at the contents of each partition */ + private[spark] def collectPartitions(): Array[Array[T]] = { + sc.runJob(this, (iter: Iterator[T]) => iter.toArray) + } } class MappedRDD[U: ClassManifest, T: ClassManifest]( diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 9fa2180269..e220972e8f 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -148,15 +148,12 @@ class SparkContext( /** Get an RDD for a Hadoop file with an arbitrary new API InputFormat. */ def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](path: String) (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] = { - val job = new NewHadoopJob - NewFileInputFormat.addInputPath(job, new Path(path)) - val conf = job.getConfiguration newAPIHadoopFile( path, fm.erasure.asInstanceOf[Class[F]], km.erasure.asInstanceOf[Class[K]], vm.erasure.asInstanceOf[Class[V]], - conf) + new Configuration) } /** @@ -248,6 +245,15 @@ class SparkContext( def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = new Accumulator(initialValue, param) + /** + * create an accumulatable shared variable, with a `+=` method + * @tparam T accumulator type + * @tparam R type that can be added to the accumulator + */ + def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) = + new Accumulable(initialValue, param) + + // Keep around a weak hash map of values to Cached versions? 
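The accumulable API introduced above splits the accumulator contract in two: AccumulableParam.addAccumulator folds one new element into the running value inside a task, while addInPlace merges the partial results coming back from different tasks. A minimal sketch of how a caller might wire this up, assuming a local SparkContext; the StringSetParam name is illustrative and mirrors the SetAccum param defined in the new AccumulatorSuite further down this diff:

    import scala.collection.mutable
    import spark.{SparkContext, Accumulable, AccumulableParam}

    implicit object StringSetParam extends AccumulableParam[mutable.Set[String], String] {
      // fold one element produced by a task into the running set
      def addAccumulator(set: mutable.Set[String], s: String): mutable.Set[String] = { set += s; set }
      // merge the partial sets coming back from two tasks
      def addInPlace(s1: mutable.Set[String], s2: mutable.Set[String]): mutable.Set[String] = { s1 ++= s2; s1 }
      def zero(initial: mutable.Set[String]): mutable.Set[String] = new mutable.HashSet[String]()
    }

    val sc = new SparkContext("local", "accumulable sketch")
    val seen: Accumulable[mutable.Set[String], String] = sc.accumulable(mutable.Set[String]())
    sc.parallelize(Seq("a", "b", "a")).foreach(word => seen += word)
    println(seen.value)   // Set(a, b) -- only readable back on the master
    sc.stop()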
def broadcast[T](value: T) = Broadcast.getBroadcastFactory.newBroadcast[T] (value, isLocal) diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index cd752f8b65..7e07811c90 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -26,7 +26,7 @@ object SparkEnv { val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache] val serializerClass = System.getProperty("spark.serializer", "spark.JavaSerializer") - val serializer = Class.forName(serializerClass).newInstance().asInstanceOf[Serializer] + val serializer = Class.forName(serializerClass, true, Thread.currentThread.getContextClassLoader).newInstance().asInstanceOf[Serializer] val closureSerializerClass = System.getProperty("spark.closure.serializer", "spark.JavaSerializer") diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala index 06049749a9..07094a034e 100644 --- a/core/src/main/scala/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/spark/broadcast/Broadcast.scala @@ -175,7 +175,7 @@ object Broadcast extends Logging with Serializable { } private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = { - val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){ + val in = new ObjectInputStream (new ByteArrayInputStream (bytes)) { override def resolveClass(desc: ObjectStreamClass) = Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader) } diff --git a/examples/src/main/scala/spark/examples/Vector.scala b/core/src/main/scala/spark/util/Vector.scala index 2abccbafce..4e95ac2ac6 100644 --- a/examples/src/main/scala/spark/examples/Vector.scala +++ b/core/src/main/scala/spark/util/Vector.scala @@ -1,8 +1,8 @@ -package spark.examples +package spark.util class Vector(val elements: Array[Double]) extends Serializable { def length = elements.length - + def apply(index: Int) = elements(index) def + (other: Vector): Vector = { @@ -29,12 +29,43 @@ class Vector(val elements: Array[Double]) extends Serializable { return ans } + /** + * return (this + plus) dot other, but without creating any intermediate storage + * @param plus + * @param other + * @return + */ + def plusDot(plus: Vector, other: Vector): Double = { + if (length != other.length) + throw new IllegalArgumentException("Vectors of different length") + if (length != plus.length) + throw new IllegalArgumentException("Vectors of different length") + var ans = 0.0 + var i = 0 + while (i < length) { + ans += (this(i) + plus(i)) * other(i) + i += 1 + } + return ans + } + + def +=(other: Vector) { + if (length != other.length) + throw new IllegalArgumentException("Vectors of different length") + var ans = 0.0 + var i = 0 + while (i < length) { + elements(i) += other(i) + i += 1 + } + } + def * (scale: Double): Vector = Vector(length, i => this(i) * scale) def / (d: Double): Vector = this * (1 / d) def unary_- = this * -1 - + def sum = elements.reduceLeft(_ + _) def squaredDist(other: Vector): Double = { @@ -76,6 +107,8 @@ object Vector { implicit object VectorAccumParam extends spark.AccumulatorParam[Vector] { def addInPlace(t1: Vector, t2: Vector) = t1 + t2 + def zero(initialValue: Vector) = Vector.zeros(initialValue.length) } + } diff --git a/core/src/test/scala/spark/AccumulatorSuite.scala b/core/src/test/scala/spark/AccumulatorSuite.scala new file mode 100644 index 0000000000..a59b77fc85 --- /dev/null +++ b/core/src/test/scala/spark/AccumulatorSuite.scala @@ -0,0 +1,82 @@ +package 
spark + +import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers +import collection.mutable +import java.util.Random +import scala.math.exp +import scala.math.signum +import spark.SparkContext._ + +class AccumulatorSuite extends FunSuite with ShouldMatchers { + + test ("basic accumulation"){ + val sc = new SparkContext("local", "test") + val acc : Accumulator[Int] = sc.accumulator(0) + + val d = sc.parallelize(1 to 20) + d.foreach{x => acc += x} + acc.value should be (210) + sc.stop() + } + + test ("value not assignable from tasks") { + val sc = new SparkContext("local", "test") + val acc : Accumulator[Int] = sc.accumulator(0) + + val d = sc.parallelize(1 to 20) + evaluating {d.foreach{x => acc.value = x}} should produce [Exception] + sc.stop() + } + + test ("add value to collection accumulators") { + import SetAccum._ + val maxI = 1000 + for (nThreads <- List(1, 10)) { //test single & multi-threaded + val sc = new SparkContext("local[" + nThreads + "]", "test") + val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]()) + val d = sc.parallelize(1 to maxI) + d.foreach { + x => acc += x + } + val v = acc.value.asInstanceOf[mutable.Set[Int]] + for (i <- 1 to maxI) { + v should contain(i) + } + sc.stop() + } + } + + + implicit object SetAccum extends AccumulableParam[mutable.Set[Any], Any] { + def addInPlace(t1: mutable.Set[Any], t2: mutable.Set[Any]) : mutable.Set[Any] = { + t1 ++= t2 + t1 + } + def addAccumulator(t1: mutable.Set[Any], t2: Any) : mutable.Set[Any] = { + t1 += t2 + t1 + } + def zero(t: mutable.Set[Any]) : mutable.Set[Any] = { + new mutable.HashSet[Any]() + } + } + + + test ("value not readable in tasks") { + import SetAccum._ + val maxI = 1000 + for (nThreads <- List(1, 10)) { //test single & multi-threaded + val sc = new SparkContext("local[" + nThreads + "]", "test") + val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]()) + val d = sc.parallelize(1 to maxI) + val thrown = evaluating { + d.foreach { + x => acc.value += x + } + } should produce [SparkException] + println(thrown) + } + } + +}
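As the new AccumulatorSuite above exercises, the reworked accumulators only expose value on the master: once an Accumulable has been deserialized into a task it refuses reads and writes of value, and the framework harvests partial results through the private localValue instead. A condensed illustration of those semantics, assuming a local SparkContext named sc and the implicits from spark.SparkContext._:

    import spark.SparkContext._   // brings the implicit AccumulatorParam[Int] into scope

    val acc = sc.accumulator(0)
    sc.parallelize(1 to 10).foreach { x =>
      acc += x          // allowed: tasks may only add to an accumulator
      // acc.value      // would fail here -- reading the value is reserved for the master
    }
    println(acc.value)  // 55, read back on the master once the job has finished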
\ No newline at end of file diff --git a/core/src/test/scala/spark/BroadcastSuite.scala b/core/src/test/scala/spark/BroadcastSuite.scala index 750703de30..1e0b587421 100644 --- a/core/src/test/scala/spark/BroadcastSuite.scala +++ b/core/src/test/scala/spark/BroadcastSuite.scala @@ -1,23 +1,31 @@ package spark import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter -class BroadcastSuite extends FunSuite { +class BroadcastSuite extends FunSuite with BeforeAndAfter { + + var sc: SparkContext = _ + + after { + if(sc != null) { + sc.stop() + } + } + test("basic broadcast") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val list = List(1, 2, 3, 4) val listBroadcast = sc.broadcast(list) val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum)) assert(results.collect.toSet === Set((1, 10), (2, 10))) - sc.stop() } test("broadcast variables accessed in multiple threads") { - val sc = new SparkContext("local[10]", "test") + sc = new SparkContext("local[10]", "test") val list = List(1, 2, 3, 4) val listBroadcast = sc.broadcast(list) val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum)) assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet) - sc.stop() } } diff --git a/core/src/test/scala/spark/FailureSuite.scala b/core/src/test/scala/spark/FailureSuite.scala index 75df4bee09..6145baee7b 100644 --- a/core/src/test/scala/spark/FailureSuite.scala +++ b/core/src/test/scala/spark/FailureSuite.scala @@ -1,6 +1,7 @@ package spark import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter import org.scalatest.prop.Checkers import scala.collection.mutable.ArrayBuffer @@ -20,11 +21,20 @@ object FailureSuiteState { } } -class FailureSuite extends FunSuite { +class FailureSuite extends FunSuite with BeforeAndAfter { + + var sc: SparkContext = _ + + after { + if(sc != null) { + sc.stop() + } + } + // Run a 3-task map job in which task 1 deterministically fails once, and check // whether the job completes successfully and we ran 4 tasks in total. test("failure in a single-stage job") { - val sc = new SparkContext("local[1,1]", "test") + sc = new SparkContext("local[1,1]", "test") val results = sc.makeRDD(1 to 3, 3).map { x => FailureSuiteState.synchronized { FailureSuiteState.tasksRun += 1 @@ -39,13 +49,12 @@ class FailureSuite extends FunSuite { assert(FailureSuiteState.tasksRun === 4) } assert(results.toList === List(1,4,9)) - sc.stop() FailureSuiteState.clear() } // Run a map-reduce job in which a reduce task deterministically fails once. 
test("failure in a two-stage job") { - val sc = new SparkContext("local[1,1]", "test") + sc = new SparkContext("local[1,1]", "test") val results = sc.makeRDD(1 to 3).map(x => (x, x)).groupByKey(3).map { case (k, v) => FailureSuiteState.synchronized { @@ -61,12 +70,11 @@ class FailureSuite extends FunSuite { assert(FailureSuiteState.tasksRun === 4) } assert(results.toSet === Set((1, 1), (2, 4), (3, 9))) - sc.stop() FailureSuiteState.clear() } test("failure because task results are not serializable") { - val sc = new SparkContext("local[1,1]", "test") + sc = new SparkContext("local[1,1]", "test") val results = sc.makeRDD(1 to 3).map(x => new NonSerializable) val thrown = intercept[spark.SparkException] { @@ -75,7 +83,6 @@ class FailureSuite extends FunSuite { assert(thrown.getClass === classOf[spark.SparkException]) assert(thrown.getMessage.contains("NotSerializableException")) - sc.stop() FailureSuiteState.clear() } diff --git a/core/src/test/scala/spark/FileSuite.scala b/core/src/test/scala/spark/FileSuite.scala index b12014e6be..4cb9c7802f 100644 --- a/core/src/test/scala/spark/FileSuite.scala +++ b/core/src/test/scala/spark/FileSuite.scala @@ -6,13 +6,23 @@ import scala.io.Source import com.google.common.io.Files import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter import org.apache.hadoop.io._ import SparkContext._ -class FileSuite extends FunSuite { +class FileSuite extends FunSuite with BeforeAndAfter { + + var sc: SparkContext = _ + + after { + if(sc != null) { + sc.stop() + } + } + test("text files") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val tempDir = Files.createTempDir() val outputDir = new File(tempDir, "output").getAbsolutePath val nums = sc.makeRDD(1 to 4) @@ -23,11 +33,10 @@ class FileSuite extends FunSuite { assert(content === "1\n2\n3\n4\n") // Also try reading it in as a text file RDD assert(sc.textFile(outputDir).collect().toList === List("1", "2", "3", "4")) - sc.stop() } test("SequenceFiles") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val tempDir = Files.createTempDir() val outputDir = new File(tempDir, "output").getAbsolutePath val nums = sc.makeRDD(1 to 3).map(x => (x, "a" * x)) // (1,a), (2,aa), (3,aaa) @@ -35,11 +44,10 @@ class FileSuite extends FunSuite { // Try reading the output back as a SequenceFile val output = sc.sequenceFile[IntWritable, Text](outputDir) assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) - sc.stop() } test("SequenceFile with writable key") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val tempDir = Files.createTempDir() val outputDir = new File(tempDir, "output").getAbsolutePath val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), "a" * x)) @@ -47,11 +55,10 @@ class FileSuite extends FunSuite { // Try reading the output back as a SequenceFile val output = sc.sequenceFile[IntWritable, Text](outputDir) assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) - sc.stop() } test("SequenceFile with writable value") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val tempDir = Files.createTempDir() val outputDir = new File(tempDir, "output").getAbsolutePath val nums = sc.makeRDD(1 to 3).map(x => (x, new Text("a" * x))) @@ -59,11 +66,10 @@ class FileSuite extends FunSuite { // Try reading the output back as a SequenceFile val output = sc.sequenceFile[IntWritable, Text](outputDir) 
assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) - sc.stop() } test("SequenceFile with writable key and value") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val tempDir = Files.createTempDir() val outputDir = new File(tempDir, "output").getAbsolutePath val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x))) @@ -71,11 +77,10 @@ class FileSuite extends FunSuite { // Try reading the output back as a SequenceFile val output = sc.sequenceFile[IntWritable, Text](outputDir) assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) - sc.stop() } test("implicit conversions in reading SequenceFiles") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val tempDir = Files.createTempDir() val outputDir = new File(tempDir, "output").getAbsolutePath val nums = sc.makeRDD(1 to 3).map(x => (x, "a" * x)) // (1,a), (2,aa), (3,aaa) @@ -89,11 +94,10 @@ class FileSuite extends FunSuite { assert(output2.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) val output3 = sc.sequenceFile[IntWritable, String](outputDir) assert(output3.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) - sc.stop() } test("object files of ints") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val tempDir = Files.createTempDir() val outputDir = new File(tempDir, "output").getAbsolutePath val nums = sc.makeRDD(1 to 4) @@ -101,11 +105,10 @@ class FileSuite extends FunSuite { // Try reading the output back as an object file val output = sc.objectFile[Int](outputDir) assert(output.collect().toList === List(1, 2, 3, 4)) - sc.stop() } test("object files of complex types") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val tempDir = Files.createTempDir() val outputDir = new File(tempDir, "output").getAbsolutePath val nums = sc.makeRDD(1 to 3).map(x => (x, "a" * x)) @@ -113,12 +116,11 @@ class FileSuite extends FunSuite { // Try reading the output back as an object file val output = sc.objectFile[(Int, String)](outputDir) assert(output.collect().toList === List((1, "a"), (2, "aa"), (3, "aaa"))) - sc.stop() } test("write SequenceFile using new Hadoop API") { import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val tempDir = Files.createTempDir() val outputDir = new File(tempDir, "output").getAbsolutePath val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x))) @@ -126,12 +128,11 @@ class FileSuite extends FunSuite { outputDir) val output = sc.sequenceFile[IntWritable, Text](outputDir) assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) - sc.stop() } test("read SequenceFile using new Hadoop API") { import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val tempDir = Files.createTempDir() val outputDir = new File(tempDir, "output").getAbsolutePath val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x))) @@ -139,6 +140,5 @@ class FileSuite extends FunSuite { val output = sc.newAPIHadoopFile[IntWritable, Text, SequenceFileInputFormat[IntWritable, Text]](outputDir) assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) - sc.stop() } } diff --git 
a/core/src/test/scala/spark/KryoSerializerSuite.scala b/core/src/test/scala/spark/KryoSerializerSuite.scala index 078071209a..7fdb3847ec 100644 --- a/core/src/test/scala/spark/KryoSerializerSuite.scala +++ b/core/src/test/scala/spark/KryoSerializerSuite.scala @@ -8,7 +8,8 @@ import com.esotericsoftware.kryo._ import SparkContext._ -class KryoSerializerSuite extends FunSuite { +class KryoSerializerSuite extends FunSuite{ + test("basic types") { val ser = (new KryoSerializer).newInstance() def check[T](t: T): Unit = diff --git a/core/src/test/scala/spark/MesosSchedulerSuite.scala b/core/src/test/scala/spark/MesosSchedulerSuite.scala index 0e6820cbdc..2f1bea58b5 100644 --- a/core/src/test/scala/spark/MesosSchedulerSuite.scala +++ b/core/src/test/scala/spark/MesosSchedulerSuite.scala @@ -3,7 +3,7 @@ package spark import org.scalatest.FunSuite class MesosSchedulerSuite extends FunSuite { - test("memoryStringToMb"){ + test("memoryStringToMb") { assert(MesosScheduler.memoryStringToMb("1") == 0) assert(MesosScheduler.memoryStringToMb("1048575") == 0) diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala index 7f7f9493dc..cf2ffeb9b1 100644 --- a/core/src/test/scala/spark/PartitioningSuite.scala +++ b/core/src/test/scala/spark/PartitioningSuite.scala @@ -1,12 +1,23 @@ package spark import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter import scala.collection.mutable.ArrayBuffer import SparkContext._ -class PartitioningSuite extends FunSuite { +class PartitioningSuite extends FunSuite with BeforeAndAfter { + + var sc: SparkContext = _ + + after { + if(sc != null) { + sc.stop() + } + } + + test("HashPartitioner equality") { val p2 = new HashPartitioner(2) val p4 = new HashPartitioner(4) @@ -20,7 +31,7 @@ class PartitioningSuite extends FunSuite { } test("RangePartitioner equality") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") // Make an RDD where all the elements are the same so that the partition range bounds // are deterministically all the same. 
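For context on the RangePartitioner tests that follow: the partitioner now keeps partitions - 1 sampled upper bounds and getPartition walks those bounds, so an evenly distributed key space should land in roughly equal partitions, which is exactly what the new partition-balancing tests in SortingSuite assert. A small sketch under that assumption; the quoted bound values are approximate, since they come from a random sample:

    import spark.{SparkContext, RangePartitioner}

    val sc = new SparkContext("local", "range partitioner sketch")
    val rdd = sc.parallelize((1 to 1000).map(x => (x, x)), 4)
    val part = new RangePartitioner(4, rdd)
    // Three upper bounds near 250 / 500 / 750 split the keys four ways,
    // so every partition index should show up:
    println((1 to 1000).map(k => part.getPartition(k)).distinct.sorted)  // expect Vector(0, 1, 2, 3)
    sc.stop()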
@@ -46,12 +57,10 @@ class PartitioningSuite extends FunSuite { assert(p4 != descendingP4) assert(descendingP2 != p2) assert(descendingP4 != p4) - - sc.stop() } test("HashPartitioner not equal to RangePartitioner") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val rdd = sc.parallelize(1 to 10).map(x => (x, x)) val rangeP2 = new RangePartitioner(2, rdd) val hashP2 = new HashPartitioner(2) @@ -59,11 +68,10 @@ class PartitioningSuite extends FunSuite { assert(hashP2 === hashP2) assert(hashP2 != rangeP2) assert(rangeP2 != hashP2) - sc.stop() } test("partitioner preservation") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val rdd = sc.parallelize(1 to 10, 4).map(x => (x, x)) @@ -95,7 +103,5 @@ class PartitioningSuite extends FunSuite { assert(grouped2.leftOuterJoin(reduced2).partitioner === grouped2.partitioner) assert(grouped2.rightOuterJoin(reduced2).partitioner === grouped2.partitioner) assert(grouped2.cogroup(reduced2).partitioner === grouped2.partitioner) - - sc.stop() } } diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala index d5dc2efd91..db1b9835a0 100644 --- a/core/src/test/scala/spark/PipedRDDSuite.scala +++ b/core/src/test/scala/spark/PipedRDDSuite.scala @@ -1,12 +1,21 @@ package spark import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter import SparkContext._ -class PipedRDDSuite extends FunSuite { - +class PipedRDDSuite extends FunSuite with BeforeAndAfter { + + var sc: SparkContext = _ + + after { + if(sc != null) { + sc.stop() + } + } + test("basic pipe") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) val piped = nums.pipe(Seq("cat")) @@ -18,18 +27,16 @@ class PipedRDDSuite extends FunSuite { assert(c(1) === "2") assert(c(2) === "3") assert(c(3) === "4") - sc.stop() } test("pipe with env variable") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA")) val c = piped.collect() assert(c.size === 2) assert(c(0) === "LALALA") assert(c(1) === "LALALA") - sc.stop() } } diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 7199b634b7..3924a6890b 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -2,11 +2,21 @@ package spark import scala.collection.mutable.HashMap import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter import SparkContext._ -class RDDSuite extends FunSuite { +class RDDSuite extends FunSuite with BeforeAndAfter { + + var sc: SparkContext = _ + + after { + if(sc != null) { + sc.stop() + } + } + test("basic operations") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) assert(nums.collect().toList === List(1, 2, 3, 4)) assert(nums.reduce(_ + _) === 10) @@ -18,11 +28,10 @@ class RDDSuite extends FunSuite { assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4))) val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _))) assert(partitionSums.collect().toList === List(3, 7)) - sc.stop() } test("aggregate") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), 
("a", 2), ("c", 5), ("a", 3))) type StringMap = HashMap[String, Int] val emptyMap = new StringMap { @@ -40,6 +49,5 @@ class RDDSuite extends FunSuite { } val result = pairs.aggregate(emptyMap)(mergeElement, mergeMaps) assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5))) - sc.stop() } } diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index c61cb90f82..3ba0e274b7 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -1,6 +1,7 @@ package spark import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter import org.scalatest.prop.Checkers import org.scalacheck.Arbitrary._ import org.scalacheck.Gen @@ -12,9 +13,18 @@ import scala.collection.mutable.ArrayBuffer import SparkContext._ -class ShuffleSuite extends FunSuite { +class ShuffleSuite extends FunSuite with BeforeAndAfter { + + var sc: SparkContext = _ + + after { + if(sc != null) { + sc.stop() + } + } + test("groupByKey") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1))) val groups = pairs.groupByKey().collect() assert(groups.size === 2) @@ -22,11 +32,10 @@ class ShuffleSuite extends FunSuite { assert(valuesFor1.toList.sorted === List(1, 2, 3)) val valuesFor2 = groups.find(_._1 == 2).get._2 assert(valuesFor2.toList.sorted === List(1)) - sc.stop() } test("groupByKey with duplicates") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) val groups = pairs.groupByKey().collect() assert(groups.size === 2) @@ -34,11 +43,10 @@ class ShuffleSuite extends FunSuite { assert(valuesFor1.toList.sorted === List(1, 1, 2, 3)) val valuesFor2 = groups.find(_._1 == 2).get._2 assert(valuesFor2.toList.sorted === List(1)) - sc.stop() } test("groupByKey with negative key hash codes") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1))) val groups = pairs.groupByKey().collect() assert(groups.size === 2) @@ -46,11 +54,10 @@ class ShuffleSuite extends FunSuite { assert(valuesForMinus1.toList.sorted === List(1, 2, 3)) val valuesFor2 = groups.find(_._1 == 2).get._2 assert(valuesFor2.toList.sorted === List(1)) - sc.stop() } test("groupByKey with many output partitions") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1))) val groups = pairs.groupByKey(10).collect() assert(groups.size === 2) @@ -58,37 +65,33 @@ class ShuffleSuite extends FunSuite { assert(valuesFor1.toList.sorted === List(1, 2, 3)) val valuesFor2 = groups.find(_._1 == 2).get._2 assert(valuesFor2.toList.sorted === List(1)) - sc.stop() } test("reduceByKey") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) val sums = pairs.reduceByKey(_+_).collect() assert(sums.toSet === Set((1, 7), (2, 1))) - sc.stop() } test("reduceByKey with collectAsMap") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) val sums = pairs.reduceByKey(_+_).collectAsMap() assert(sums.size === 2) assert(sums(1) === 7) assert(sums(2) === 1) - sc.stop() } test("reduceByKey with many 
output partitons") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) val sums = pairs.reduceByKey(_+_, 10).collect() assert(sums.toSet === Set((1, 7), (2, 1))) - sc.stop() } test("join") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) val joined = rdd1.join(rdd2).collect() @@ -99,11 +102,10 @@ class ShuffleSuite extends FunSuite { (2, (1, 'y')), (2, (1, 'z')) )) - sc.stop() } test("join all-to-all") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3))) val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y'))) val joined = rdd1.join(rdd2).collect() @@ -116,11 +118,10 @@ class ShuffleSuite extends FunSuite { (1, (3, 'x')), (1, (3, 'y')) )) - sc.stop() } test("leftOuterJoin") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) val joined = rdd1.leftOuterJoin(rdd2).collect() @@ -132,11 +133,10 @@ class ShuffleSuite extends FunSuite { (2, (1, Some('z'))), (3, (1, None)) )) - sc.stop() } test("rightOuterJoin") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) val joined = rdd1.rightOuterJoin(rdd2).collect() @@ -148,20 +148,18 @@ class ShuffleSuite extends FunSuite { (2, (Some(1), 'z')), (4, (None, 'w')) )) - sc.stop() } test("join with no matches") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w'))) val joined = rdd1.join(rdd2).collect() assert(joined.size === 0) - sc.stop() } test("join with many output partitions") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) val joined = rdd1.join(rdd2, 10).collect() @@ -172,11 +170,10 @@ class ShuffleSuite extends FunSuite { (2, (1, 'y')), (2, (1, 'z')) )) - sc.stop() } test("groupWith") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) val joined = rdd1.groupWith(rdd2).collect() @@ -187,17 +184,15 @@ class ShuffleSuite extends FunSuite { (3, (ArrayBuffer(1), ArrayBuffer())), (4, (ArrayBuffer(), ArrayBuffer('w'))) )) - sc.stop() } test("zero-partition RDD") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val emptyDir = Files.createTempDir() val file = sc.textFile(emptyDir.getAbsolutePath) assert(file.splits.size == 0) assert(file.collect().toList === Nil) // Test that a shuffle on the file works, because this used to be a bug - assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) - sc.stop() + assert(file.map(line => (line, 1)).reduceByKey(_ + 
_).collect().toList === Nil) } } diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala index caff884966..8fa1442a4d 100644 --- a/core/src/test/scala/spark/SortingSuite.scala +++ b/core/src/test/scala/spark/SortingSuite.scala @@ -1,50 +1,87 @@ package spark import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter +import org.scalatest.matchers.ShouldMatchers import SparkContext._ -class SortingSuite extends FunSuite { - test("sortByKey") { - val sc = new SparkContext("local", "test") - val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0))) - assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0))) +class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with Logging { + + var sc: SparkContext = _ + + after { + if (sc != null) { sc.stop() + } + } + + test("sortByKey") { + sc = new SparkContext("local", "test") + val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0))) + assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0))) } - test("sortLargeArray") { - val sc = new SparkContext("local", "test") - val rand = new scala.util.Random() - val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr) - assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) - sc.stop() + test("large array") { + sc = new SparkContext("local", "test") + val rand = new scala.util.Random() + val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } + val pairs = sc.parallelize(pairArr) + assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) } - test("sortDescending") { - val sc = new SparkContext("local", "test") - val rand = new scala.util.Random() - val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr) - assert(pairs.sortByKey(false).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) - sc.stop() + test("sort descending") { + sc = new SparkContext("local", "test") + val rand = new scala.util.Random() + val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } + val pairs = sc.parallelize(pairArr) + assert(pairs.sortByKey(false).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) } - test("morePartitionsThanElements") { - val sc = new SparkContext("local", "test") - val rand = new scala.util.Random() - val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 30) - assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) - sc.stop() + test("more partitions than elements") { + sc = new SparkContext("local", "test") + val rand = new scala.util.Random() + val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) } + val pairs = sc.parallelize(pairArr, 30) + assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) } - test("emptyRDD") { - val sc = new SparkContext("local", "test") - val rand = new scala.util.Random() - val pairArr = new Array[(Int, Int)](0) - val pairs = sc.parallelize(pairArr) - assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) - sc.stop() + test("empty RDD") { + sc = new SparkContext("local", "test") + val pairArr = new Array[(Int, Int)](0) + val pairs = sc.parallelize(pairArr) + assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) + } + + test("partition balancing") { + sc = new SparkContext("local", "test") + val pairArr = (1 to 1000).map(x => (x, x)).toArray + val sorted = sc.parallelize(pairArr, 4).sortByKey() + assert(sorted.collect() === 
pairArr.sortBy(_._1)) + val partitions = sorted.collectPartitions() + logInfo("partition lengths: " + partitions.map(_.length).mkString(", ")) + partitions(0).length should be > 200 + partitions(1).length should be > 200 + partitions(2).length should be > 200 + partitions(3).length should be > 200 + partitions(0).last should be < partitions(1).head + partitions(1).last should be < partitions(2).head + partitions(2).last should be < partitions(3).head + } + + test("partition balancing for descending sort") { + sc = new SparkContext("local", "test") + val pairArr = (1 to 1000).map(x => (x, x)).toArray + val sorted = sc.parallelize(pairArr, 4).sortByKey(false) + assert(sorted.collect() === pairArr.sortBy(_._1).reverse) + val partitions = sorted.collectPartitions() + logInfo("partition lengths: " + partitions.map(_.length).mkString(", ")) + partitions(0).length should be > 200 + partitions(1).length should be > 200 + partitions(2).length should be > 200 + partitions(3).length should be > 200 + partitions(0).last should be > partitions(1).head + partitions(1).last should be > partitions(2).head + partitions(2).last should be > partitions(3).head } } diff --git a/core/src/test/scala/spark/ThreadingSuite.scala b/core/src/test/scala/spark/ThreadingSuite.scala index cadf01432f..a8b5ccf721 100644 --- a/core/src/test/scala/spark/ThreadingSuite.scala +++ b/core/src/test/scala/spark/ThreadingSuite.scala @@ -5,6 +5,7 @@ import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicInteger import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter import SparkContext._ @@ -21,9 +22,19 @@ object ThreadingSuiteState { } } -class ThreadingSuite extends FunSuite { +class ThreadingSuite extends FunSuite with BeforeAndAfter { + + var sc: SparkContext = _ + + after { + if(sc != null) { + sc.stop() + } + } + + test("accessing SparkContext form a different thread") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val nums = sc.parallelize(1 to 10, 2) val sem = new Semaphore(0) @volatile var answer1: Int = 0 @@ -38,11 +49,10 @@ class ThreadingSuite extends FunSuite { sem.acquire() assert(answer1 === 55) assert(answer2 === 1) - sc.stop() } test("accessing SparkContext form multiple threads") { - val sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val nums = sc.parallelize(1 to 10, 2) val sem = new Semaphore(0) @volatile var ok = true @@ -67,11 +77,10 @@ class ThreadingSuite extends FunSuite { if (!ok) { fail("One or more threads got the wrong answer from an RDD operation") } - sc.stop() } test("accessing multi-threaded SparkContext form multiple threads") { - val sc = new SparkContext("local[4]", "test") + sc = new SparkContext("local[4]", "test") val nums = sc.parallelize(1 to 10, 2) val sem = new Semaphore(0) @volatile var ok = true @@ -96,13 +105,12 @@ class ThreadingSuite extends FunSuite { if (!ok) { fail("One or more threads got the wrong answer from an RDD operation") } - sc.stop() } test("parallel job execution") { // This test launches two jobs with two threads each on a 4-core local cluster. Each thread // waits until there are 4 threads running at once, to test that both jobs have been launched. 
- val sc = new SparkContext("local[4]", "test") + sc = new SparkContext("local[4]", "test") val nums = sc.parallelize(1 to 2, 2) val sem = new Semaphore(0) ThreadingSuiteState.clear() @@ -132,6 +140,5 @@ class ThreadingSuite extends FunSuite { if (ThreadingSuiteState.failed.get()) { fail("One or more threads didn't see runningThreads = 4") } - sc.stop() } } diff --git a/core/src/test/scala/spark/UtilsSuite.scala b/core/src/test/scala/spark/UtilsSuite.scala index f31251e509..1ac4737f04 100644 --- a/core/src/test/scala/spark/UtilsSuite.scala +++ b/core/src/test/scala/spark/UtilsSuite.scala @@ -2,7 +2,7 @@ package spark import org.scalatest.FunSuite import java.io.{ByteArrayOutputStream, ByteArrayInputStream} -import util.Random +import scala.util.Random class UtilsSuite extends FunSuite { diff --git a/examples/src/main/scala/spark/examples/LocalFileLR.scala b/examples/src/main/scala/spark/examples/LocalFileLR.scala index b819fe80fe..f958ef9f72 100644 --- a/examples/src/main/scala/spark/examples/LocalFileLR.scala +++ b/examples/src/main/scala/spark/examples/LocalFileLR.scala @@ -1,7 +1,7 @@ package spark.examples import java.util.Random -import Vector._ +import spark.util.Vector object LocalFileLR { val D = 10 // Numer of dimensions diff --git a/examples/src/main/scala/spark/examples/LocalKMeans.scala b/examples/src/main/scala/spark/examples/LocalKMeans.scala index 7e8e7a6959..b442c604cd 100644 --- a/examples/src/main/scala/spark/examples/LocalKMeans.scala +++ b/examples/src/main/scala/spark/examples/LocalKMeans.scala @@ -1,8 +1,7 @@ package spark.examples import java.util.Random -import Vector._ -import spark.SparkContext +import spark.util.Vector import spark.SparkContext._ import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet diff --git a/examples/src/main/scala/spark/examples/LocalLR.scala b/examples/src/main/scala/spark/examples/LocalLR.scala index 72c5009109..f2ac2b3e06 100644 --- a/examples/src/main/scala/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/spark/examples/LocalLR.scala @@ -1,7 +1,7 @@ package spark.examples import java.util.Random -import Vector._ +import spark.util.Vector object LocalLR { val N = 10000 // Number of data points diff --git a/examples/src/main/scala/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/spark/examples/SparkHdfsLR.scala index a87e0a408c..1a3c1c8264 100644 --- a/examples/src/main/scala/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/spark/examples/SparkHdfsLR.scala @@ -2,7 +2,7 @@ package spark.examples import java.util.Random import scala.math.exp -import Vector._ +import spark.util.Vector import spark._ object SparkHdfsLR { diff --git a/examples/src/main/scala/spark/examples/SparkKMeans.scala b/examples/src/main/scala/spark/examples/SparkKMeans.scala index f310dffe23..9a30148130 100644 --- a/examples/src/main/scala/spark/examples/SparkKMeans.scala +++ b/examples/src/main/scala/spark/examples/SparkKMeans.scala @@ -1,8 +1,8 @@ package spark.examples import java.util.Random -import Vector._ import spark.SparkContext +import spark.util.Vector import spark.SparkContext._ import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet diff --git a/examples/src/main/scala/spark/examples/SparkLR.scala b/examples/src/main/scala/spark/examples/SparkLR.scala index 38af1f4080..9b801ed31e 100644 --- a/examples/src/main/scala/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/spark/examples/SparkLR.scala @@ -2,7 +2,7 @@ package spark.examples import java.util.Random import 
scala.math.exp -import Vector._ +import spark.util.Vector import spark._ object SparkLR { |
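One consequence of moving Vector into spark.util and giving its accumulator param a real zero is that the example jobs above can accumulate whole vectors across tasks. A hedged sketch of that combination, not code from this change; the names sc, D, and sum are illustrative:

    import spark.SparkContext
    import spark.util.Vector

    val sc = new SparkContext("local", "vector accumulator sketch")
    val D = 10
    // Vector.VectorAccumParam (whose zero is now Vector.zeros of the right length)
    // lets a Vector act as an accumulated value:
    val sum = sc.accumulator(Vector.zeros(D))(Vector.VectorAccumParam)
    sc.parallelize(1 to 100).foreach(_ => sum += new Vector(Array.fill(D)(1.0)))
    println(sum.value)   // every component should come out as 100.0
    sc.stop()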