-rw-r--r--  bagel/src/test/scala/bagel/BagelSuite.scala  17
-rw-r--r--  core/src/main/scala/spark/Accumulators.scala  85
-rw-r--r--  core/src/main/scala/spark/CacheTracker.scala  7
-rw-r--r--  core/src/main/scala/spark/Executor.scala  10
-rw-r--r--  core/src/main/scala/spark/Logging.scala  4
-rw-r--r--  core/src/main/scala/spark/Partitioner.scala  45
-rw-r--r--  core/src/main/scala/spark/RDD.scala  5
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala  14
-rw-r--r--  core/src/main/scala/spark/SparkEnv.scala  2
-rw-r--r--  core/src/main/scala/spark/broadcast/Broadcast.scala  2
-rw-r--r--  core/src/main/scala/spark/util/Vector.scala (renamed from examples/src/main/scala/spark/examples/Vector.scala)  39
-rw-r--r--  core/src/test/scala/spark/AccumulatorSuite.scala  82
-rw-r--r--  core/src/test/scala/spark/BroadcastSuite.scala  18
-rw-r--r--  core/src/test/scala/spark/FailureSuite.scala  21
-rw-r--r--  core/src/test/scala/spark/FileSuite.scala  42
-rw-r--r--  core/src/test/scala/spark/KryoSerializerSuite.scala  3
-rw-r--r--  core/src/test/scala/spark/MesosSchedulerSuite.scala  2
-rw-r--r--  core/src/test/scala/spark/PartitioningSuite.scala  24
-rw-r--r--  core/src/test/scala/spark/PipedRDDSuite.scala  19
-rw-r--r--  core/src/test/scala/spark/RDDSuite.scala  18
-rw-r--r--  core/src/test/scala/spark/ShuffleSuite.scala  59
-rw-r--r--  core/src/test/scala/spark/SortingSuite.scala  103
-rw-r--r--  core/src/test/scala/spark/ThreadingSuite.scala  25
-rw-r--r--  core/src/test/scala/spark/UtilsSuite.scala  2
-rw-r--r--  examples/src/main/scala/spark/examples/LocalFileLR.scala  2
-rw-r--r--  examples/src/main/scala/spark/examples/LocalKMeans.scala  3
-rw-r--r--  examples/src/main/scala/spark/examples/LocalLR.scala  2
-rw-r--r--  examples/src/main/scala/spark/examples/SparkHdfsLR.scala  2
-rw-r--r--  examples/src/main/scala/spark/examples/SparkKMeans.scala  2
-rw-r--r--  examples/src/main/scala/spark/examples/SparkLR.scala  2
30 files changed, 468 insertions, 193 deletions
diff --git a/bagel/src/test/scala/bagel/BagelSuite.scala b/bagel/src/test/scala/bagel/BagelSuite.scala
index 0eda80af64..d2189169d2 100644
--- a/bagel/src/test/scala/bagel/BagelSuite.scala
+++ b/bagel/src/test/scala/bagel/BagelSuite.scala
@@ -1,6 +1,6 @@
package spark.bagel
-import org.scalatest.{FunSuite, Assertions}
+import org.scalatest.{FunSuite, Assertions, BeforeAndAfter}
import org.scalatest.prop.Checkers
import org.scalacheck.Arbitrary._
import org.scalacheck.Gen
@@ -13,9 +13,16 @@ import spark._
class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable
class TestMessage(val targetId: String) extends Message[String] with Serializable
-class BagelSuite extends FunSuite with Assertions {
+class BagelSuite extends FunSuite with Assertions with BeforeAndAfter {
+
+ var sc: SparkContext = _
+
+ after {
+ sc.stop()
+ }
+
test("halting by voting") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(true, 0))))
val msgs = sc.parallelize(Array[(String, TestMessage)]())
val numSupersteps = 5
@@ -26,11 +33,10 @@ class BagelSuite extends FunSuite with Assertions {
}
for ((id, vert) <- result.collect)
assert(vert.age === numSupersteps)
- sc.stop()
}
test("halting by message silence") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(false, 0))))
val msgs = sc.parallelize(Array("a" -> new TestMessage("a")))
val numSupersteps = 5
@@ -48,6 +54,5 @@ class BagelSuite extends FunSuite with Assertions {
}
for ((id, vert) <- result.collect)
assert(vert.age === numSupersteps)
- sc.stop()
}
}
diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala
index 86e2061b9f..bf77417852 100644
--- a/core/src/main/scala/spark/Accumulators.scala
+++ b/core/src/main/scala/spark/Accumulators.scala
@@ -4,21 +4,39 @@ import java.io._
import scala.collection.mutable.Map
-class Accumulator[T] (
+class Accumulable[T,R] (
@transient initialValue: T,
- param: AccumulatorParam[T])
+ param: AccumulableParam[T,R])
extends Serializable {
val id = Accumulators.newId
@transient
- var value_ = initialValue // Current value on master
+ private var value_ = initialValue // Current value on master
val zero = param.zero(initialValue) // Zero value to be passed to workers
var deserialized = false
Accumulators.register(this, true)
- def += (term: T) { value_ = param.addInPlace(value_, term) }
- def value = this.value_
+ /**
+ * add more data to this accumulator / accumulable
+ * @param term the data to add
+ */
+ def += (term: R) { value_ = param.addAccumulator(value_, term) }
+
+ /**
+ * merge two accumulable objects together
+ *
+ * Normally, a user will not want to use this version, but will instead call `+=`.
+ * @param term the other Accumulable that will get merged with this
+ */
+ def ++= (term: T) { value_ = param.addInPlace(value_, term)}
+ def value = {
+ if (!deserialized) value_
+ else throw new UnsupportedOperationException("Can't use read value in task")
+ }
+
+ private[spark] def localValue = value_
+
def value_= (t: T) {
if (!deserialized) value_ = t
else throw new UnsupportedOperationException("Can't use value_= in task")
@@ -35,17 +53,58 @@ class Accumulator[T] (
override def toString = value_.toString
}
-trait AccumulatorParam[T] extends Serializable {
- def addInPlace(t1: T, t2: T): T
- def zero(initialValue: T): T
+class Accumulator[T](
+ @transient initialValue: T,
+ param: AccumulatorParam[T]) extends Accumulable[T,T](initialValue, param)
+
+/**
+ * A simpler version of [[spark.AccumulableParam]] where the only datatype you can add in is the same type
+ * as the accumulated value
+ * @tparam T type of both the accumulated value and the data added in
+ */
+trait AccumulatorParam[T] extends AccumulableParam[T,T] {
+ def addAccumulator(t1: T, t2: T) : T = {
+ addInPlace(t1, t2)
+ }
+}
+
+/**
+ * A datatype that can be accumulated, i.e. has a commutative & associative +.
+ *
+ * You must define how to add data, and how to merge two of these together. For some datatypes, these might be
+ * the same operation (e.g., a counter). In that case, you might want to use [[spark.AccumulatorParam]]. They won't
+ * always be the same, though -- e.g., imagine you are accumulating a set. You will add items to the set, and you
+ * will union two sets together.
+ *
+ * @tparam R the full accumulated data
+ * @tparam T partial data that can be added in
+ */
+trait AccumulableParam[R,T] extends Serializable {
+ /**
+ * Add additional data to the accumulator value.
+ * @param t1 the current value of the accumulator
+ * @param t2 the data to be added to the accumulator
+ * @return the new value of the accumulator
+ */
+ def addAccumulator(t1: R, t2: T) : R
+
+ /**
+ * merge two accumulated values together
+ * @param t1 one set of accumulated data
+ * @param t2 another set of accumulated data
+ * @return both data sets merged together
+ */
+ def addInPlace(t1: R, t2: R): R
+
+ def zero(initialValue: R): R
}
// TODO: The multi-thread support in accumulators is kind of lame; check
// if there's a more intuitive way of doing it right
private object Accumulators {
// TODO: Use soft references? => need to make readObject work properly then
- val originals = Map[Long, Accumulator[_]]()
- val localAccums = Map[Thread, Map[Long, Accumulator[_]]]()
+ val originals = Map[Long, Accumulable[_,_]]()
+ val localAccums = Map[Thread, Map[Long, Accumulable[_,_]]]()
var lastId: Long = 0
def newId: Long = synchronized {
@@ -53,7 +112,7 @@ private object Accumulators {
return lastId
}
- def register(a: Accumulator[_], original: Boolean): Unit = synchronized {
+ def register(a: Accumulable[_,_], original: Boolean): Unit = synchronized {
if (original) {
originals(a.id) = a
} else {
@@ -71,7 +130,7 @@ private object Accumulators {
def values: Map[Long, Any] = synchronized {
val ret = Map[Long, Any]()
for ((id, accum) <- localAccums.getOrElse(Thread.currentThread, Map())) {
- ret(id) = accum.value
+ ret(id) = accum.localValue
}
return ret
}
@@ -80,7 +139,7 @@ private object Accumulators {
def add(values: Map[Long, Any]): Unit = synchronized {
for ((id, value) <- values) {
if (originals.contains(id)) {
- originals(id).asInstanceOf[Accumulator[Any]] += value
+ originals(id).asInstanceOf[Accumulable[Any, Any]] ++= value
}
}
}
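A sketch of the contract the new Accumulable / AccumulableParam API introduces (the param below is hypothetical and not part of this patch): accumulating per-word counts, where the element type added in (String) differs from the accumulated type (a mutable map), so addAccumulator and addInPlace are genuinely different operations.

import scala.collection.mutable

// Hypothetical param: R = mutable.Map[String, Int] (full result), T = String (element added in)
object WordCountParam extends AccumulableParam[mutable.Map[String, Int], String] {
  // Add a single element into a partial result
  def addAccumulator(counts: mutable.Map[String, Int], word: String) = {
    counts(word) = counts.getOrElse(word, 0) + 1
    counts
  }
  // Merge two partial results, e.g. from two different tasks
  def addInPlace(m1: mutable.Map[String, Int], m2: mutable.Map[String, Int]) = {
    for ((word, count) <- m2) m1(word) = m1.getOrElse(word, 0) + count
    m1
  }
  def zero(initial: mutable.Map[String, Int]) = mutable.Map[String, Int]()
}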
diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala
index 4867829c17..76d1c92a12 100644
--- a/core/src/main/scala/spark/CacheTracker.scala
+++ b/core/src/main/scala/spark/CacheTracker.scala
@@ -225,9 +225,10 @@ class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging {
// Called by the Cache to report that an entry has been dropped from it
def dropEntry(datasetId: Any, partition: Int) {
- datasetId match {
- //TODO - do we really want to use '!!' when nobody checks returned future? '!' seems to enough here.
- case (cache.keySpaceId, rddId: Int) => trackerActor !! DroppedFromCache(rddId, partition, Utils.getHost)
+ val (keySpaceId, innerId) = datasetId.asInstanceOf[(Any, Any)]
+ if (keySpaceId == cache.keySpaceId) {
+ // TODO - do we really want to use '!!' when nobody checks the returned future? '!' seems to be enough here.
+ trackerActor !! DroppedFromCache(innerId.asInstanceOf[Int], partition, Utils.getHost)
}
}
diff --git a/core/src/main/scala/spark/Executor.scala b/core/src/main/scala/spark/Executor.scala
index c795b6c351..c8cb730d14 100644
--- a/core/src/main/scala/spark/Executor.scala
+++ b/core/src/main/scala/spark/Executor.scala
@@ -37,17 +37,17 @@ class Executor extends org.apache.mesos.Executor with Logging {
// Make sure an appropriate class loader is set for remote actors
RemoteActor.classLoader = getClass.getClassLoader
-
+
+ // Create our ClassLoader (using spark properties) and set it on this thread
+ classLoader = createClassLoader()
+ Thread.currentThread.setContextClassLoader(classLoader)
+
// Initialize Spark environment (using system properties read above)
env = SparkEnv.createFromSystemProperties(false)
SparkEnv.set(env)
// Old stuff that isn't yet using env
Broadcast.initialize(false)
- // Create our ClassLoader (using spark properties) and set it on this thread
- classLoader = createClassLoader()
- Thread.currentThread.setContextClassLoader(classLoader)
-
// Start worker thread pool
threadPool = new ThreadPoolExecutor(
1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala
index 0d11ab9cbd..07dafabf2e 100644
--- a/core/src/main/scala/spark/Logging.scala
+++ b/core/src/main/scala/spark/Logging.scala
@@ -38,10 +38,10 @@ trait Logging {
// Log methods that take Throwables (Exceptions/Errors) too
def logInfo(msg: => String, throwable: Throwable) =
- if (log.isInfoEnabled) log.info(msg)
+ if (log.isInfoEnabled) log.info(msg, throwable)
def logDebug(msg: => String, throwable: Throwable) =
- if (log.isDebugEnabled) log.debug(msg)
+ if (log.isDebugEnabled) log.debug(msg, throwable)
def logWarning(msg: => String, throwable: Throwable) =
if (log.isWarnEnabled) log.warn(msg, throwable)
diff --git a/core/src/main/scala/spark/Partitioner.scala b/core/src/main/scala/spark/Partitioner.scala
index 024a4580ac..d05ef0ab5f 100644
--- a/core/src/main/scala/spark/Partitioner.scala
+++ b/core/src/main/scala/spark/Partitioner.scala
@@ -8,12 +8,16 @@ abstract class Partitioner extends Serializable {
class HashPartitioner(partitions: Int) extends Partitioner {
def numPartitions = partitions
- def getPartition(key: Any) = {
- val mod = key.hashCode % partitions
- if (mod < 0) {
- mod + partitions
+ def getPartition(key: Any): Int = {
+ if (key == null) {
+ return 0
} else {
- mod // Guard against negative hash codes
+ val mod = key.hashCode % partitions
+ if (mod < 0) {
+ mod + partitions
+ } else {
+ mod // Guard against negative hash codes
+ }
}
}
@@ -31,36 +35,41 @@ class RangePartitioner[K <% Ordered[K]: ClassManifest, V](
private val ascending: Boolean = true)
extends Partitioner {
+ // An array of upper bounds for the first (partitions - 1) partitions
private val rangeBounds: Array[K] = {
- val rddSize = rdd.count()
- val maxSampleSize = partitions * 10.0
- val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0)
- val rddSample = rdd.sample(true, frac, 1).map(_._1).collect()
- .sortWith((x, y) => if (ascending) x < y else x > y)
- if (rddSample.length == 0) {
+ if (partitions == 1) {
Array()
} else {
- val bounds = new Array[K](partitions)
- for (i <- 0 until partitions) {
- bounds(i) = rddSample(i * rddSample.length / partitions)
+ val rddSize = rdd.count()
+ val maxSampleSize = partitions * 10.0
+ val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0)
+ val rddSample = rdd.sample(true, frac, 1).map(_._1).collect().sortWith(_ < _)
+ if (rddSample.length == 0) {
+ Array()
+ } else {
+ val bounds = new Array[K](partitions - 1)
+ for (i <- 0 until partitions - 1) {
+ val index = (rddSample.length - 1) * (i + 1) / partitions
+ bounds(i) = rddSample(index)
+ }
+ bounds
}
- bounds
}
}
- def numPartitions = rangeBounds.length
+ def numPartitions = partitions
def getPartition(key: Any): Int = {
// TODO: Use a binary search here if number of partitions is large
val k = key.asInstanceOf[K]
var partition = 0
- while (partition < rangeBounds.length - 1 && k > rangeBounds(partition)) {
+ while (partition < rangeBounds.length && k > rangeBounds(partition)) {
partition += 1
}
if (ascending) {
partition
} else {
- rangeBounds.length - 1 - partition
+ rangeBounds.length - partition
}
}
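A worked sketch of the new bound-index arithmetic (the sample size and partition count are illustrative): with a sorted sample of 1000 keys and 4 partitions, the partitioner now picks partitions - 1 upper bounds spread evenly through the sample, and getPartition advances past each bound the key exceeds.

val sampleLength = 1000
val partitions = 4
val boundIndices = (0 until partitions - 1).map(i => (sampleLength - 1) * (i + 1) / partitions)
// boundIndices == Vector(249, 499, 749): keys <= sample(249) land in partition 0,
// keys in (sample(249), sample(499)] in partition 1, and keys above sample(749) in partition 3.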
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 4c4b2ee30d..ede7571bf6 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -261,6 +261,11 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
.map(x => (NullWritable.get(), new BytesWritable(Utils.serialize(x))))
.saveAsSequenceFile(path)
}
+
+ /** A private method for tests, to look at the contents of each partition */
+ private[spark] def collectPartitions(): Array[Array[T]] = {
+ sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
+ }
}
class MappedRDD[U: ClassManifest, T: ClassManifest](
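A brief sketch of what the new test helper returns (the RDD here is illustrative; collectPartitions is private[spark], so it is only callable from Spark's own code and tests):

// collect() would flatten this to Array(1, 2, 3, 4); collectPartitions keeps the
// per-partition grouping, which the new sorting tests use to check balance.
val parts = sc.parallelize(1 to 4, 2).collectPartitions()
// parts: Array[Array[Int]], e.g. Array(Array(1, 2), Array(3, 4)) (exact split may vary)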
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 9fa2180269..e220972e8f 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -148,15 +148,12 @@ class SparkContext(
/** Get an RDD for a Hadoop file with an arbitrary new API InputFormat. */
def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](path: String)
(implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] = {
- val job = new NewHadoopJob
- NewFileInputFormat.addInputPath(job, new Path(path))
- val conf = job.getConfiguration
newAPIHadoopFile(
path,
fm.erasure.asInstanceOf[Class[F]],
km.erasure.asInstanceOf[Class[K]],
vm.erasure.asInstanceOf[Class[V]],
- conf)
+ new Configuration)
}
/**
@@ -248,6 +245,15 @@ class SparkContext(
def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) =
new Accumulator(initialValue, param)
+ /**
+ * create an accumulable shared variable, with a `+=` method
+ * @tparam T accumulator type
+ * @tparam R type that can be added to the accumulator
+ */
+ def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) =
+ new Accumulable(initialValue, param)
+
+
// Keep around a weak hash map of values to Cached versions?
def broadcast[T](value: T) = Broadcast.getBroadcastFactory.newBroadcast[T] (value, isLocal)
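A hedged usage sketch for the new accumulable method (the set-valued param mirrors the SetAccum object defined in AccumulatorSuite below; sc is assumed to be an existing SparkContext):

import scala.collection.mutable

implicit object SetParam extends AccumulableParam[mutable.Set[Int], Int] {
  def addAccumulator(s: mutable.Set[Int], x: Int) = { s += x; s }
  def addInPlace(s1: mutable.Set[Int], s2: mutable.Set[Int]) = { s1 ++= s2; s1 }
  def zero(initial: mutable.Set[Int]) = new mutable.HashSet[Int]()
}

val seen: Accumulable[mutable.Set[Int], Int] = sc.accumulable(mutable.Set[Int]())
sc.parallelize(1 to 100).foreach(x => seen += x)  // tasks add elements with +=
println(seen.value.size)                          // value is only readable on the master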
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index cd752f8b65..7e07811c90 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -26,7 +26,7 @@ object SparkEnv {
val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
val serializerClass = System.getProperty("spark.serializer", "spark.JavaSerializer")
- val serializer = Class.forName(serializerClass).newInstance().asInstanceOf[Serializer]
+ val serializer = Class.forName(serializerClass, true, Thread.currentThread.getContextClassLoader).newInstance().asInstanceOf[Serializer]
val closureSerializerClass =
System.getProperty("spark.closure.serializer", "spark.JavaSerializer")
diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala
index 06049749a9..07094a034e 100644
--- a/core/src/main/scala/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/spark/broadcast/Broadcast.scala
@@ -175,7 +175,7 @@ object Broadcast extends Logging with Serializable {
}
private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = {
- val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){
+ val in = new ObjectInputStream (new ByteArrayInputStream (bytes)) {
override def resolveClass(desc: ObjectStreamClass) =
Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader)
}
diff --git a/examples/src/main/scala/spark/examples/Vector.scala b/core/src/main/scala/spark/util/Vector.scala
index 2abccbafce..4e95ac2ac6 100644
--- a/examples/src/main/scala/spark/examples/Vector.scala
+++ b/core/src/main/scala/spark/util/Vector.scala
@@ -1,8 +1,8 @@
-package spark.examples
+package spark.util
class Vector(val elements: Array[Double]) extends Serializable {
def length = elements.length
-
+
def apply(index: Int) = elements(index)
def + (other: Vector): Vector = {
@@ -29,12 +29,43 @@ class Vector(val elements: Array[Double]) extends Serializable {
return ans
}
+ /**
+ * return (this + plus) dot other, but without creating any intermediate storage
+ * @param plus the vector added to this one before taking the dot product
+ * @param other the vector to take the dot product with
+ * @return the dot product (this + plus) . other
+ */
+ def plusDot(plus: Vector, other: Vector): Double = {
+ if (length != other.length)
+ throw new IllegalArgumentException("Vectors of different length")
+ if (length != plus.length)
+ throw new IllegalArgumentException("Vectors of different length")
+ var ans = 0.0
+ var i = 0
+ while (i < length) {
+ ans += (this(i) + plus(i)) * other(i)
+ i += 1
+ }
+ return ans
+ }
+
+ def +=(other: Vector) {
+ if (length != other.length)
+ throw new IllegalArgumentException("Vectors of different length")
+ var ans = 0.0
+ var i = 0
+ while (i < length) {
+ elements(i) += other(i)
+ i += 1
+ }
+ }
+
def * (scale: Double): Vector = Vector(length, i => this(i) * scale)
def / (d: Double): Vector = this * (1 / d)
def unary_- = this * -1
-
+
def sum = elements.reduceLeft(_ + _)
def squaredDist(other: Vector): Double = {
@@ -76,6 +107,8 @@ object Vector {
implicit object VectorAccumParam extends spark.AccumulatorParam[Vector] {
def addInPlace(t1: Vector, t2: Vector) = t1 + t2
+
def zero(initialValue: Vector) = Vector.zeros(initialValue.length)
}
+
}
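A small usage sketch for the two methods added above (the element values are illustrative):

val a = new Vector(Array(1.0, 2.0, 3.0))
val b = new Vector(Array(0.5, 0.5, 0.5))
val c = new Vector(Array(2.0, 2.0, 2.0))

// (a + b) dot c, computed without allocating the intermediate (a + b)
val d = a.plusDot(b, c)   // (1.5 + 2.5 + 3.5) * 2.0 = 15.0

// In-place element-wise addition: a's elements become (1.5, 2.5, 3.5)
a += b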
diff --git a/core/src/test/scala/spark/AccumulatorSuite.scala b/core/src/test/scala/spark/AccumulatorSuite.scala
new file mode 100644
index 0000000000..a59b77fc85
--- /dev/null
+++ b/core/src/test/scala/spark/AccumulatorSuite.scala
@@ -0,0 +1,82 @@
+package spark
+
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import collection.mutable
+import java.util.Random
+import scala.math.exp
+import scala.math.signum
+import spark.SparkContext._
+
+class AccumulatorSuite extends FunSuite with ShouldMatchers {
+
+ test ("basic accumulation"){
+ val sc = new SparkContext("local", "test")
+ val acc : Accumulator[Int] = sc.accumulator(0)
+
+ val d = sc.parallelize(1 to 20)
+ d.foreach{x => acc += x}
+ acc.value should be (210)
+ sc.stop()
+ }
+
+ test ("value not assignable from tasks") {
+ val sc = new SparkContext("local", "test")
+ val acc : Accumulator[Int] = sc.accumulator(0)
+
+ val d = sc.parallelize(1 to 20)
+ evaluating {d.foreach{x => acc.value = x}} should produce [Exception]
+ sc.stop()
+ }
+
+ test ("add value to collection accumulators") {
+ import SetAccum._
+ val maxI = 1000
+ for (nThreads <- List(1, 10)) { //test single & multi-threaded
+ val sc = new SparkContext("local[" + nThreads + "]", "test")
+ val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]())
+ val d = sc.parallelize(1 to maxI)
+ d.foreach {
+ x => acc += x
+ }
+ val v = acc.value.asInstanceOf[mutable.Set[Int]]
+ for (i <- 1 to maxI) {
+ v should contain(i)
+ }
+ sc.stop()
+ }
+ }
+
+
+ implicit object SetAccum extends AccumulableParam[mutable.Set[Any], Any] {
+ def addInPlace(t1: mutable.Set[Any], t2: mutable.Set[Any]) : mutable.Set[Any] = {
+ t1 ++= t2
+ t1
+ }
+ def addAccumulator(t1: mutable.Set[Any], t2: Any) : mutable.Set[Any] = {
+ t1 += t2
+ t1
+ }
+ def zero(t: mutable.Set[Any]) : mutable.Set[Any] = {
+ new mutable.HashSet[Any]()
+ }
+ }
+
+
+ test ("value not readable in tasks") {
+ import SetAccum._
+ val maxI = 1000
+ for (nThreads <- List(1, 10)) { //test single & multi-threaded
+ val sc = new SparkContext("local[" + nThreads + "]", "test")
+ val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]())
+ val d = sc.parallelize(1 to maxI)
+ val thrown = evaluating {
+ d.foreach {
+ x => acc.value += x
+ }
+ } should produce [SparkException]
+ println(thrown)
+ }
+ }
+
+} \ No newline at end of file
diff --git a/core/src/test/scala/spark/BroadcastSuite.scala b/core/src/test/scala/spark/BroadcastSuite.scala
index 750703de30..1e0b587421 100644
--- a/core/src/test/scala/spark/BroadcastSuite.scala
+++ b/core/src/test/scala/spark/BroadcastSuite.scala
@@ -1,23 +1,31 @@
package spark
import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
-class BroadcastSuite extends FunSuite {
+class BroadcastSuite extends FunSuite with BeforeAndAfter {
+
+ var sc: SparkContext = _
+
+ after {
+ if(sc != null) {
+ sc.stop()
+ }
+ }
+
test("basic broadcast") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === Set((1, 10), (2, 10)))
- sc.stop()
}
test("broadcast variables accessed in multiple threads") {
- val sc = new SparkContext("local[10]", "test")
+ sc = new SparkContext("local[10]", "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
- sc.stop()
}
}
diff --git a/core/src/test/scala/spark/FailureSuite.scala b/core/src/test/scala/spark/FailureSuite.scala
index 75df4bee09..6145baee7b 100644
--- a/core/src/test/scala/spark/FailureSuite.scala
+++ b/core/src/test/scala/spark/FailureSuite.scala
@@ -1,6 +1,7 @@
package spark
import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
import org.scalatest.prop.Checkers
import scala.collection.mutable.ArrayBuffer
@@ -20,11 +21,20 @@ object FailureSuiteState {
}
}
-class FailureSuite extends FunSuite {
+class FailureSuite extends FunSuite with BeforeAndAfter {
+
+ var sc: SparkContext = _
+
+ after {
+ if(sc != null) {
+ sc.stop()
+ }
+ }
+
// Run a 3-task map job in which task 1 deterministically fails once, and check
// whether the job completes successfully and we ran 4 tasks in total.
test("failure in a single-stage job") {
- val sc = new SparkContext("local[1,1]", "test")
+ sc = new SparkContext("local[1,1]", "test")
val results = sc.makeRDD(1 to 3, 3).map { x =>
FailureSuiteState.synchronized {
FailureSuiteState.tasksRun += 1
@@ -39,13 +49,12 @@ class FailureSuite extends FunSuite {
assert(FailureSuiteState.tasksRun === 4)
}
assert(results.toList === List(1,4,9))
- sc.stop()
FailureSuiteState.clear()
}
// Run a map-reduce job in which a reduce task deterministically fails once.
test("failure in a two-stage job") {
- val sc = new SparkContext("local[1,1]", "test")
+ sc = new SparkContext("local[1,1]", "test")
val results = sc.makeRDD(1 to 3).map(x => (x, x)).groupByKey(3).map {
case (k, v) =>
FailureSuiteState.synchronized {
@@ -61,12 +70,11 @@ class FailureSuite extends FunSuite {
assert(FailureSuiteState.tasksRun === 4)
}
assert(results.toSet === Set((1, 1), (2, 4), (3, 9)))
- sc.stop()
FailureSuiteState.clear()
}
test("failure because task results are not serializable") {
- val sc = new SparkContext("local[1,1]", "test")
+ sc = new SparkContext("local[1,1]", "test")
val results = sc.makeRDD(1 to 3).map(x => new NonSerializable)
val thrown = intercept[spark.SparkException] {
@@ -75,7 +83,6 @@ class FailureSuite extends FunSuite {
assert(thrown.getClass === classOf[spark.SparkException])
assert(thrown.getMessage.contains("NotSerializableException"))
- sc.stop()
FailureSuiteState.clear()
}
diff --git a/core/src/test/scala/spark/FileSuite.scala b/core/src/test/scala/spark/FileSuite.scala
index b12014e6be..4cb9c7802f 100644
--- a/core/src/test/scala/spark/FileSuite.scala
+++ b/core/src/test/scala/spark/FileSuite.scala
@@ -6,13 +6,23 @@ import scala.io.Source
import com.google.common.io.Files
import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
import org.apache.hadoop.io._
import SparkContext._
-class FileSuite extends FunSuite {
+class FileSuite extends FunSuite with BeforeAndAfter {
+
+ var sc: SparkContext = _
+
+ after {
+ if(sc != null) {
+ sc.stop()
+ }
+ }
+
test("text files") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
val outputDir = new File(tempDir, "output").getAbsolutePath
val nums = sc.makeRDD(1 to 4)
@@ -23,11 +33,10 @@ class FileSuite extends FunSuite {
assert(content === "1\n2\n3\n4\n")
// Also try reading it in as a text file RDD
assert(sc.textFile(outputDir).collect().toList === List("1", "2", "3", "4"))
- sc.stop()
}
test("SequenceFiles") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
val outputDir = new File(tempDir, "output").getAbsolutePath
val nums = sc.makeRDD(1 to 3).map(x => (x, "a" * x)) // (1,a), (2,aa), (3,aaa)
@@ -35,11 +44,10 @@ class FileSuite extends FunSuite {
// Try reading the output back as a SequenceFile
val output = sc.sequenceFile[IntWritable, Text](outputDir)
assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
- sc.stop()
}
test("SequenceFile with writable key") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
val outputDir = new File(tempDir, "output").getAbsolutePath
val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), "a" * x))
@@ -47,11 +55,10 @@ class FileSuite extends FunSuite {
// Try reading the output back as a SequenceFile
val output = sc.sequenceFile[IntWritable, Text](outputDir)
assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
- sc.stop()
}
test("SequenceFile with writable value") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
val outputDir = new File(tempDir, "output").getAbsolutePath
val nums = sc.makeRDD(1 to 3).map(x => (x, new Text("a" * x)))
@@ -59,11 +66,10 @@ class FileSuite extends FunSuite {
// Try reading the output back as a SequenceFile
val output = sc.sequenceFile[IntWritable, Text](outputDir)
assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
- sc.stop()
}
test("SequenceFile with writable key and value") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
val outputDir = new File(tempDir, "output").getAbsolutePath
val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x)))
@@ -71,11 +77,10 @@ class FileSuite extends FunSuite {
// Try reading the output back as a SequenceFile
val output = sc.sequenceFile[IntWritable, Text](outputDir)
assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
- sc.stop()
}
test("implicit conversions in reading SequenceFiles") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
val outputDir = new File(tempDir, "output").getAbsolutePath
val nums = sc.makeRDD(1 to 3).map(x => (x, "a" * x)) // (1,a), (2,aa), (3,aaa)
@@ -89,11 +94,10 @@ class FileSuite extends FunSuite {
assert(output2.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
val output3 = sc.sequenceFile[IntWritable, String](outputDir)
assert(output3.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
- sc.stop()
}
test("object files of ints") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
val outputDir = new File(tempDir, "output").getAbsolutePath
val nums = sc.makeRDD(1 to 4)
@@ -101,11 +105,10 @@ class FileSuite extends FunSuite {
// Try reading the output back as an object file
val output = sc.objectFile[Int](outputDir)
assert(output.collect().toList === List(1, 2, 3, 4))
- sc.stop()
}
test("object files of complex types") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
val outputDir = new File(tempDir, "output").getAbsolutePath
val nums = sc.makeRDD(1 to 3).map(x => (x, "a" * x))
@@ -113,12 +116,11 @@ class FileSuite extends FunSuite {
// Try reading the output back as an object file
val output = sc.objectFile[(Int, String)](outputDir)
assert(output.collect().toList === List((1, "a"), (2, "aa"), (3, "aaa")))
- sc.stop()
}
test("write SequenceFile using new Hadoop API") {
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
val outputDir = new File(tempDir, "output").getAbsolutePath
val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x)))
@@ -126,12 +128,11 @@ class FileSuite extends FunSuite {
outputDir)
val output = sc.sequenceFile[IntWritable, Text](outputDir)
assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
- sc.stop()
}
test("read SequenceFile using new Hadoop API") {
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
val outputDir = new File(tempDir, "output").getAbsolutePath
val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x)))
@@ -139,6 +140,5 @@ class FileSuite extends FunSuite {
val output =
sc.newAPIHadoopFile[IntWritable, Text, SequenceFileInputFormat[IntWritable, Text]](outputDir)
assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
- sc.stop()
}
}
diff --git a/core/src/test/scala/spark/KryoSerializerSuite.scala b/core/src/test/scala/spark/KryoSerializerSuite.scala
index 078071209a..7fdb3847ec 100644
--- a/core/src/test/scala/spark/KryoSerializerSuite.scala
+++ b/core/src/test/scala/spark/KryoSerializerSuite.scala
@@ -8,7 +8,8 @@ import com.esotericsoftware.kryo._
import SparkContext._
-class KryoSerializerSuite extends FunSuite {
+class KryoSerializerSuite extends FunSuite{
+
test("basic types") {
val ser = (new KryoSerializer).newInstance()
def check[T](t: T): Unit =
diff --git a/core/src/test/scala/spark/MesosSchedulerSuite.scala b/core/src/test/scala/spark/MesosSchedulerSuite.scala
index 0e6820cbdc..2f1bea58b5 100644
--- a/core/src/test/scala/spark/MesosSchedulerSuite.scala
+++ b/core/src/test/scala/spark/MesosSchedulerSuite.scala
@@ -3,7 +3,7 @@ package spark
import org.scalatest.FunSuite
class MesosSchedulerSuite extends FunSuite {
- test("memoryStringToMb"){
+ test("memoryStringToMb") {
assert(MesosScheduler.memoryStringToMb("1") == 0)
assert(MesosScheduler.memoryStringToMb("1048575") == 0)
diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala
index 7f7f9493dc..cf2ffeb9b1 100644
--- a/core/src/test/scala/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/spark/PartitioningSuite.scala
@@ -1,12 +1,23 @@
package spark
import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
import scala.collection.mutable.ArrayBuffer
import SparkContext._
-class PartitioningSuite extends FunSuite {
+class PartitioningSuite extends FunSuite with BeforeAndAfter {
+
+ var sc: SparkContext = _
+
+ after {
+ if(sc != null) {
+ sc.stop()
+ }
+ }
+
+
test("HashPartitioner equality") {
val p2 = new HashPartitioner(2)
val p4 = new HashPartitioner(4)
@@ -20,7 +31,7 @@ class PartitioningSuite extends FunSuite {
}
test("RangePartitioner equality") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
// Make an RDD where all the elements are the same so that the partition range bounds
// are deterministically all the same.
@@ -46,12 +57,10 @@ class PartitioningSuite extends FunSuite {
assert(p4 != descendingP4)
assert(descendingP2 != p2)
assert(descendingP4 != p4)
-
- sc.stop()
}
test("HashPartitioner not equal to RangePartitioner") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val rdd = sc.parallelize(1 to 10).map(x => (x, x))
val rangeP2 = new RangePartitioner(2, rdd)
val hashP2 = new HashPartitioner(2)
@@ -59,11 +68,10 @@ class PartitioningSuite extends FunSuite {
assert(hashP2 === hashP2)
assert(hashP2 != rangeP2)
assert(rangeP2 != hashP2)
- sc.stop()
}
test("partitioner preservation") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val rdd = sc.parallelize(1 to 10, 4).map(x => (x, x))
@@ -95,7 +103,5 @@ class PartitioningSuite extends FunSuite {
assert(grouped2.leftOuterJoin(reduced2).partitioner === grouped2.partitioner)
assert(grouped2.rightOuterJoin(reduced2).partitioner === grouped2.partitioner)
assert(grouped2.cogroup(reduced2).partitioner === grouped2.partitioner)
-
- sc.stop()
}
}
diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala
index d5dc2efd91..db1b9835a0 100644
--- a/core/src/test/scala/spark/PipedRDDSuite.scala
+++ b/core/src/test/scala/spark/PipedRDDSuite.scala
@@ -1,12 +1,21 @@
package spark
import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
import SparkContext._
-class PipedRDDSuite extends FunSuite {
-
+class PipedRDDSuite extends FunSuite with BeforeAndAfter {
+
+ var sc: SparkContext = _
+
+ after {
+ if(sc != null) {
+ sc.stop()
+ }
+ }
+
test("basic pipe") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val piped = nums.pipe(Seq("cat"))
@@ -18,18 +27,16 @@ class PipedRDDSuite extends FunSuite {
assert(c(1) === "2")
assert(c(2) === "3")
assert(c(3) === "4")
- sc.stop()
}
test("pipe with env variable") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA"))
val c = piped.collect()
assert(c.size === 2)
assert(c(0) === "LALALA")
assert(c(1) === "LALALA")
- sc.stop()
}
}
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index 7199b634b7..3924a6890b 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -2,11 +2,21 @@ package spark
import scala.collection.mutable.HashMap
import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
import SparkContext._
-class RDDSuite extends FunSuite {
+class RDDSuite extends FunSuite with BeforeAndAfter {
+
+ var sc: SparkContext = _
+
+ after {
+ if(sc != null) {
+ sc.stop()
+ }
+ }
+
test("basic operations") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
assert(nums.collect().toList === List(1, 2, 3, 4))
assert(nums.reduce(_ + _) === 10)
@@ -18,11 +28,10 @@ class RDDSuite extends FunSuite {
assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4)))
val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _)))
assert(partitionSums.collect().toList === List(3, 7))
- sc.stop()
}
test("aggregate") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3)))
type StringMap = HashMap[String, Int]
val emptyMap = new StringMap {
@@ -40,6 +49,5 @@ class RDDSuite extends FunSuite {
}
val result = pairs.aggregate(emptyMap)(mergeElement, mergeMaps)
assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
- sc.stop()
}
}
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index c61cb90f82..3ba0e274b7 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -1,6 +1,7 @@
package spark
import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
import org.scalatest.prop.Checkers
import org.scalacheck.Arbitrary._
import org.scalacheck.Gen
@@ -12,9 +13,18 @@ import scala.collection.mutable.ArrayBuffer
import SparkContext._
-class ShuffleSuite extends FunSuite {
+class ShuffleSuite extends FunSuite with BeforeAndAfter {
+
+ var sc: SparkContext = _
+
+ after {
+ if(sc != null) {
+ sc.stop()
+ }
+ }
+
test("groupByKey") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
val groups = pairs.groupByKey().collect()
assert(groups.size === 2)
@@ -22,11 +32,10 @@ class ShuffleSuite extends FunSuite {
assert(valuesFor1.toList.sorted === List(1, 2, 3))
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
- sc.stop()
}
test("groupByKey with duplicates") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
val groups = pairs.groupByKey().collect()
assert(groups.size === 2)
@@ -34,11 +43,10 @@ class ShuffleSuite extends FunSuite {
assert(valuesFor1.toList.sorted === List(1, 1, 2, 3))
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
- sc.stop()
}
test("groupByKey with negative key hash codes") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1)))
val groups = pairs.groupByKey().collect()
assert(groups.size === 2)
@@ -46,11 +54,10 @@ class ShuffleSuite extends FunSuite {
assert(valuesForMinus1.toList.sorted === List(1, 2, 3))
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
- sc.stop()
}
test("groupByKey with many output partitions") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
val groups = pairs.groupByKey(10).collect()
assert(groups.size === 2)
@@ -58,37 +65,33 @@ class ShuffleSuite extends FunSuite {
assert(valuesFor1.toList.sorted === List(1, 2, 3))
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
- sc.stop()
}
test("reduceByKey") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
val sums = pairs.reduceByKey(_+_).collect()
assert(sums.toSet === Set((1, 7), (2, 1)))
- sc.stop()
}
test("reduceByKey with collectAsMap") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
val sums = pairs.reduceByKey(_+_).collectAsMap()
assert(sums.size === 2)
assert(sums(1) === 7)
assert(sums(2) === 1)
- sc.stop()
}
test("reduceByKey with many output partitons") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
val sums = pairs.reduceByKey(_+_, 10).collect()
assert(sums.toSet === Set((1, 7), (2, 1)))
- sc.stop()
}
test("join") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val joined = rdd1.join(rdd2).collect()
@@ -99,11 +102,10 @@ class ShuffleSuite extends FunSuite {
(2, (1, 'y')),
(2, (1, 'z'))
))
- sc.stop()
}
test("join all-to-all") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3)))
val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y')))
val joined = rdd1.join(rdd2).collect()
@@ -116,11 +118,10 @@ class ShuffleSuite extends FunSuite {
(1, (3, 'x')),
(1, (3, 'y'))
))
- sc.stop()
}
test("leftOuterJoin") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val joined = rdd1.leftOuterJoin(rdd2).collect()
@@ -132,11 +133,10 @@ class ShuffleSuite extends FunSuite {
(2, (1, Some('z'))),
(3, (1, None))
))
- sc.stop()
}
test("rightOuterJoin") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val joined = rdd1.rightOuterJoin(rdd2).collect()
@@ -148,20 +148,18 @@ class ShuffleSuite extends FunSuite {
(2, (Some(1), 'z')),
(4, (None, 'w'))
))
- sc.stop()
}
test("join with no matches") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w')))
val joined = rdd1.join(rdd2).collect()
assert(joined.size === 0)
- sc.stop()
}
test("join with many output partitions") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val joined = rdd1.join(rdd2, 10).collect()
@@ -172,11 +170,10 @@ class ShuffleSuite extends FunSuite {
(2, (1, 'y')),
(2, (1, 'z'))
))
- sc.stop()
}
test("groupWith") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val joined = rdd1.groupWith(rdd2).collect()
@@ -187,17 +184,15 @@ class ShuffleSuite extends FunSuite {
(3, (ArrayBuffer(1), ArrayBuffer())),
(4, (ArrayBuffer(), ArrayBuffer('w')))
))
- sc.stop()
}
test("zero-partition RDD") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val emptyDir = Files.createTempDir()
val file = sc.textFile(emptyDir.getAbsolutePath)
assert(file.splits.size == 0)
assert(file.collect().toList === Nil)
// Test that a shuffle on the file works, because this used to be a bug
- assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
- sc.stop()
+ assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
}
}
diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala
index caff884966..8fa1442a4d 100644
--- a/core/src/test/scala/spark/SortingSuite.scala
+++ b/core/src/test/scala/spark/SortingSuite.scala
@@ -1,50 +1,87 @@
package spark
import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.ShouldMatchers
import SparkContext._
-class SortingSuite extends FunSuite {
- test("sortByKey") {
- val sc = new SparkContext("local", "test")
- val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)))
- assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0)))
+class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with Logging {
+
+ var sc: SparkContext = _
+
+ after {
+ if (sc != null) {
sc.stop()
+ }
+ }
+
+ test("sortByKey") {
+ sc = new SparkContext("local", "test")
+ val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)))
+ assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0)))
}
- test("sortLargeArray") {
- val sc = new SparkContext("local", "test")
- val rand = new scala.util.Random()
- val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
- val pairs = sc.parallelize(pairArr)
- assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
- sc.stop()
+ test("large array") {
+ sc = new SparkContext("local", "test")
+ val rand = new scala.util.Random()
+ val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
+ val pairs = sc.parallelize(pairArr)
+ assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
}
- test("sortDescending") {
- val sc = new SparkContext("local", "test")
- val rand = new scala.util.Random()
- val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
- val pairs = sc.parallelize(pairArr)
- assert(pairs.sortByKey(false).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
- sc.stop()
+ test("sort descending") {
+ sc = new SparkContext("local", "test")
+ val rand = new scala.util.Random()
+ val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
+ val pairs = sc.parallelize(pairArr)
+ assert(pairs.sortByKey(false).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
}
- test("morePartitionsThanElements") {
- val sc = new SparkContext("local", "test")
- val rand = new scala.util.Random()
- val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) }
- val pairs = sc.parallelize(pairArr, 30)
- assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
- sc.stop()
+ test("more partitions than elements") {
+ sc = new SparkContext("local", "test")
+ val rand = new scala.util.Random()
+ val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) }
+ val pairs = sc.parallelize(pairArr, 30)
+ assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
}
- test("emptyRDD") {
- val sc = new SparkContext("local", "test")
- val rand = new scala.util.Random()
- val pairArr = new Array[(Int, Int)](0)
- val pairs = sc.parallelize(pairArr)
- assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
- sc.stop()
+ test("empty RDD") {
+ sc = new SparkContext("local", "test")
+ val pairArr = new Array[(Int, Int)](0)
+ val pairs = sc.parallelize(pairArr)
+ assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
+ }
+
+ test("partition balancing") {
+ sc = new SparkContext("local", "test")
+ val pairArr = (1 to 1000).map(x => (x, x)).toArray
+ val sorted = sc.parallelize(pairArr, 4).sortByKey()
+ assert(sorted.collect() === pairArr.sortBy(_._1))
+ val partitions = sorted.collectPartitions()
+ logInfo("partition lengths: " + partitions.map(_.length).mkString(", "))
+ partitions(0).length should be > 200
+ partitions(1).length should be > 200
+ partitions(2).length should be > 200
+ partitions(3).length should be > 200
+ partitions(0).last should be < partitions(1).head
+ partitions(1).last should be < partitions(2).head
+ partitions(2).last should be < partitions(3).head
+ }
+
+ test("partition balancing for descending sort") {
+ sc = new SparkContext("local", "test")
+ val pairArr = (1 to 1000).map(x => (x, x)).toArray
+ val sorted = sc.parallelize(pairArr, 4).sortByKey(false)
+ assert(sorted.collect() === pairArr.sortBy(_._1).reverse)
+ val partitions = sorted.collectPartitions()
+ logInfo("partition lengths: " + partitions.map(_.length).mkString(", "))
+ partitions(0).length should be > 200
+ partitions(1).length should be > 200
+ partitions(2).length should be > 200
+ partitions(3).length should be > 200
+ partitions(0).last should be > partitions(1).head
+ partitions(1).last should be > partitions(2).head
+ partitions(2).last should be > partitions(3).head
}
}
diff --git a/core/src/test/scala/spark/ThreadingSuite.scala b/core/src/test/scala/spark/ThreadingSuite.scala
index cadf01432f..a8b5ccf721 100644
--- a/core/src/test/scala/spark/ThreadingSuite.scala
+++ b/core/src/test/scala/spark/ThreadingSuite.scala
@@ -5,6 +5,7 @@ import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger
import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
import SparkContext._
@@ -21,9 +22,19 @@ object ThreadingSuiteState {
}
}
-class ThreadingSuite extends FunSuite {
+class ThreadingSuite extends FunSuite with BeforeAndAfter {
+
+ var sc: SparkContext = _
+
+ after {
+ if(sc != null) {
+ sc.stop()
+ }
+ }
+
+
test("accessing SparkContext form a different thread") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val nums = sc.parallelize(1 to 10, 2)
val sem = new Semaphore(0)
@volatile var answer1: Int = 0
@@ -38,11 +49,10 @@ class ThreadingSuite extends FunSuite {
sem.acquire()
assert(answer1 === 55)
assert(answer2 === 1)
- sc.stop()
}
test("accessing SparkContext form multiple threads") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val nums = sc.parallelize(1 to 10, 2)
val sem = new Semaphore(0)
@volatile var ok = true
@@ -67,11 +77,10 @@ class ThreadingSuite extends FunSuite {
if (!ok) {
fail("One or more threads got the wrong answer from an RDD operation")
}
- sc.stop()
}
test("accessing multi-threaded SparkContext form multiple threads") {
- val sc = new SparkContext("local[4]", "test")
+ sc = new SparkContext("local[4]", "test")
val nums = sc.parallelize(1 to 10, 2)
val sem = new Semaphore(0)
@volatile var ok = true
@@ -96,13 +105,12 @@ class ThreadingSuite extends FunSuite {
if (!ok) {
fail("One or more threads got the wrong answer from an RDD operation")
}
- sc.stop()
}
test("parallel job execution") {
// This test launches two jobs with two threads each on a 4-core local cluster. Each thread
// waits until there are 4 threads running at once, to test that both jobs have been launched.
- val sc = new SparkContext("local[4]", "test")
+ sc = new SparkContext("local[4]", "test")
val nums = sc.parallelize(1 to 2, 2)
val sem = new Semaphore(0)
ThreadingSuiteState.clear()
@@ -132,6 +140,5 @@ class ThreadingSuite extends FunSuite {
if (ThreadingSuiteState.failed.get()) {
fail("One or more threads didn't see runningThreads = 4")
}
- sc.stop()
}
}
diff --git a/core/src/test/scala/spark/UtilsSuite.scala b/core/src/test/scala/spark/UtilsSuite.scala
index f31251e509..1ac4737f04 100644
--- a/core/src/test/scala/spark/UtilsSuite.scala
+++ b/core/src/test/scala/spark/UtilsSuite.scala
@@ -2,7 +2,7 @@ package spark
import org.scalatest.FunSuite
import java.io.{ByteArrayOutputStream, ByteArrayInputStream}
-import util.Random
+import scala.util.Random
class UtilsSuite extends FunSuite {
diff --git a/examples/src/main/scala/spark/examples/LocalFileLR.scala b/examples/src/main/scala/spark/examples/LocalFileLR.scala
index b819fe80fe..f958ef9f72 100644
--- a/examples/src/main/scala/spark/examples/LocalFileLR.scala
+++ b/examples/src/main/scala/spark/examples/LocalFileLR.scala
@@ -1,7 +1,7 @@
package spark.examples
import java.util.Random
-import Vector._
+import spark.util.Vector
object LocalFileLR {
val D = 10 // Number of dimensions
diff --git a/examples/src/main/scala/spark/examples/LocalKMeans.scala b/examples/src/main/scala/spark/examples/LocalKMeans.scala
index 7e8e7a6959..b442c604cd 100644
--- a/examples/src/main/scala/spark/examples/LocalKMeans.scala
+++ b/examples/src/main/scala/spark/examples/LocalKMeans.scala
@@ -1,8 +1,7 @@
package spark.examples
import java.util.Random
-import Vector._
-import spark.SparkContext
+import spark.util.Vector
import spark.SparkContext._
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
diff --git a/examples/src/main/scala/spark/examples/LocalLR.scala b/examples/src/main/scala/spark/examples/LocalLR.scala
index 72c5009109..f2ac2b3e06 100644
--- a/examples/src/main/scala/spark/examples/LocalLR.scala
+++ b/examples/src/main/scala/spark/examples/LocalLR.scala
@@ -1,7 +1,7 @@
package spark.examples
import java.util.Random
-import Vector._
+import spark.util.Vector
object LocalLR {
val N = 10000 // Number of data points
diff --git a/examples/src/main/scala/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/spark/examples/SparkHdfsLR.scala
index a87e0a408c..1a3c1c8264 100644
--- a/examples/src/main/scala/spark/examples/SparkHdfsLR.scala
+++ b/examples/src/main/scala/spark/examples/SparkHdfsLR.scala
@@ -2,7 +2,7 @@ package spark.examples
import java.util.Random
import scala.math.exp
-import Vector._
+import spark.util.Vector
import spark._
object SparkHdfsLR {
diff --git a/examples/src/main/scala/spark/examples/SparkKMeans.scala b/examples/src/main/scala/spark/examples/SparkKMeans.scala
index f310dffe23..9a30148130 100644
--- a/examples/src/main/scala/spark/examples/SparkKMeans.scala
+++ b/examples/src/main/scala/spark/examples/SparkKMeans.scala
@@ -1,8 +1,8 @@
package spark.examples
import java.util.Random
-import Vector._
import spark.SparkContext
+import spark.util.Vector
import spark.SparkContext._
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
diff --git a/examples/src/main/scala/spark/examples/SparkLR.scala b/examples/src/main/scala/spark/examples/SparkLR.scala
index 38af1f4080..9b801ed31e 100644
--- a/examples/src/main/scala/spark/examples/SparkLR.scala
+++ b/examples/src/main/scala/spark/examples/SparkLR.scala
@@ -2,7 +2,7 @@ package spark.examples
import java.util.Random
import scala.math.exp
-import Vector._
+import spark.util.Vector
import spark._
object SparkLR {