author     Mosharaf Chowdhury <mosharaf@cs.berkeley.edu>   2012-07-27 15:18:23 -0700
committer  Mosharaf Chowdhury <mosharaf@cs.berkeley.edu>   2012-07-27 15:18:23 -0700
commit     1f19fbb8db96135efc79fe56fcd96f1f02598b86 (patch)
tree       025d32f88267885837aa0e27df99251730a9f8f5  /core
parent     85cd9979f2ba0ed6c0b3d458ab4d3d4f0a7909b2 (diff)
parent     b51d733a5783ef29077951e842882bb002a4139e (diff)
Merge remote-tracking branch 'upstream/dev' into dev
Conflicts: core/src/main/scala/spark/broadcast/Broadcast.scala
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/spark/RDD.scala  2
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala  11
-rw-r--r--  core/src/main/scala/spark/api/java/JavaDoubleRDD.scala  71
-rw-r--r--  core/src/main/scala/spark/api/java/JavaPairRDD.scala  277
-rw-r--r--  core/src/main/scala/spark/api/java/JavaRDD.scala  38
-rw-r--r--  core/src/main/scala/spark/api/java/JavaRDDLike.scala  170
-rw-r--r--  core/src/main/scala/spark/api/java/JavaSparkContext.scala  219
-rw-r--r--  core/src/main/scala/spark/api/java/JavaSparkContextVarargsWorkaround.java  47
-rw-r--r--  core/src/main/scala/spark/api/java/function/DoubleFlatMapFunction.java  13
-rw-r--r--  core/src/main/scala/spark/api/java/function/DoubleFunction.java  13
-rw-r--r--  core/src/main/scala/spark/api/java/function/FlatMapFunction.scala  7
-rw-r--r--  core/src/main/scala/spark/api/java/function/Function.java  21
-rw-r--r--  core/src/main/scala/spark/api/java/function/Function2.java  17
-rw-r--r--  core/src/main/scala/spark/api/java/function/PairFlatMapFunction.java  25
-rw-r--r--  core/src/main/scala/spark/api/java/function/PairFunction.java  25
-rw-r--r--  core/src/main/scala/spark/api/java/function/VoidFunction.scala  12
-rw-r--r--  core/src/main/scala/spark/partial/PartialResult.scala  28
-rw-r--r--  core/src/main/scala/spark/util/StatCounter.scala  2
-rw-r--r--  core/src/test/scala/spark/BroadcastSuite.scala  18
-rw-r--r--  core/src/test/scala/spark/FailureSuite.scala  21
-rw-r--r--  core/src/test/scala/spark/FileSuite.scala  42
-rw-r--r--  core/src/test/scala/spark/JavaAPISuite.java  552
-rw-r--r--  core/src/test/scala/spark/KryoSerializerSuite.scala  3
-rw-r--r--  core/src/test/scala/spark/PartitioningSuite.scala  24
-rw-r--r--  core/src/test/scala/spark/PipedRDDSuite.scala  19
-rw-r--r--  core/src/test/scala/spark/RDDSuite.scala  18
-rw-r--r--  core/src/test/scala/spark/ShuffleSuite.scala  61
-rw-r--r--  core/src/test/scala/spark/SortingSuite.scala  29
-rw-r--r--  core/src/test/scala/spark/ThreadingSuite.scala  25
29 files changed, 1695 insertions, 115 deletions
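
A rough usage sketch of the Java API brought in by this merge (all class, variable, and app names below are illustrative and not part of the commit; the pattern mirrors the JavaAPISuite test added under core/src/test): a simple map/reduce from Java looks like this.

    import java.util.Arrays;
    import spark.api.java.JavaRDD;
    import spark.api.java.JavaSparkContext;
    import spark.api.java.function.Function;
    import spark.api.java.function.Function2;

    public class JavaApiSketch {
      public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext("local", "JavaApiSketch");
        JavaRDD<Integer> nums = sc.parallelize(Arrays.asList(1, 2, 3, 4));
        // map() with a Function produces another JavaRDD
        JavaRDD<Integer> doubled = nums.map(new Function<Integer, Integer>() {
          public Integer apply(Integer x) { return 2 * x; }
        });
        // reduce() with a Function2 brings a single value back to the driver
        Integer sum = doubled.reduce(new Function2<Integer, Integer, Integer>() {
          public Integer apply(Integer a, Integer b) { return a + b; }
        });
        System.out.println("sum = " + sum);  // 20
        sc.stop();
      }
    }
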
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 1710ff58b3..429e9c936f 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -112,6 +112,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
def filter(f: T => Boolean): RDD[T] = new FilteredRDD(this, sc.clean(f))
+ def distinct(): RDD[T] = map(x => (x, "")).reduceByKey((x, y) => x).map(_._1)
+
def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] =
new SampledRDD(this, withReplacement, fraction, seed)
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 01b0a29ce8..bfd3e8d732 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -48,10 +48,12 @@ import spark.storage.BlockManagerMaster
class SparkContext(
master: String,
frameworkName: String,
- val sparkHome: String = null,
- val jars: Seq[String] = Nil)
+ val sparkHome: String,
+ val jars: Seq[String])
extends Logging {
+ def this(master: String, frameworkName: String) = this(master, frameworkName, null, Nil)
+
// Ensure logging is initialized before we spawn any threads
initLogging()
@@ -182,15 +184,12 @@ class SparkContext(
/** Get an RDD for a Hadoop file with an arbitrary new API InputFormat. */
def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](path: String)
(implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] = {
- val job = new NewHadoopJob
- NewFileInputFormat.addInputPath(job, new Path(path))
- val conf = job.getConfiguration
newAPIHadoopFile(
path,
fm.erasure.asInstanceOf[Class[F]],
km.erasure.asInstanceOf[Class[K]],
vm.erasure.asInstanceOf[Class[V]],
- conf)
+ new Configuration)
}
/**
diff --git a/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala
new file mode 100644
index 0000000000..7c0b17c45e
--- /dev/null
+++ b/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala
@@ -0,0 +1,71 @@
+package spark.api.java
+
+import spark.RDD
+import spark.SparkContext.doubleRDDToDoubleRDDFunctions
+import spark.api.java.function.{Function => JFunction}
+import spark.util.StatCounter
+import spark.partial.{BoundedDouble, PartialResult}
+import spark.storage.StorageLevel
+
+import java.lang.Double
+
+class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, JavaDoubleRDD] {
+
+ override val classManifest: ClassManifest[Double] = implicitly[ClassManifest[Double]]
+
+ override val rdd: RDD[Double] = srdd.map(x => Double.valueOf(x))
+
+ override def wrapRDD(rdd: RDD[Double]): JavaDoubleRDD =
+ new JavaDoubleRDD(rdd.map(_.doubleValue))
+
+ // Common RDD functions
+
+ import JavaDoubleRDD.fromRDD
+
+ def cache(): JavaDoubleRDD = fromRDD(srdd.cache())
+
+ def persist(newLevel: StorageLevel): JavaDoubleRDD = fromRDD(srdd.persist(newLevel))
+
+ // first() has to be overridden here in order for its return type to be Double instead of Object.
+ override def first(): Double = srdd.first()
+
+ // Transformations (return a new RDD)
+
+ def distinct(): JavaDoubleRDD = fromRDD(srdd.distinct())
+
+ def filter(f: JFunction[Double, java.lang.Boolean]): JavaDoubleRDD =
+ fromRDD(srdd.filter(x => f(x).booleanValue()))
+
+ def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaDoubleRDD =
+ fromRDD(srdd.sample(withReplacement, fraction, seed))
+
+ def union(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.union(other.srdd))
+
+ // Double RDD functions
+
+ def sum(): Double = srdd.sum()
+
+ def stats(): StatCounter = srdd.stats()
+
+ def mean(): Double = srdd.mean()
+
+ def variance(): Double = srdd.variance()
+
+ def stdev(): Double = srdd.stdev()
+
+ def meanApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] =
+ srdd.meanApprox(timeout, confidence)
+
+ def meanApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.meanApprox(timeout)
+
+ def sumApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] =
+ srdd.sumApprox(timeout, confidence)
+
+ def sumApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.sumApprox(timeout)
+}
+
+object JavaDoubleRDD {
+ def fromRDD(rdd: RDD[scala.Double]): JavaDoubleRDD = new JavaDoubleRDD(rdd)
+
+ implicit def toRDD(rdd: JavaDoubleRDD): RDD[scala.Double] = rdd.srdd
+}
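
A minimal sketch of reaching JavaDoubleRDD from Java (assuming a JavaSparkContext named sc as in the sketch above, plus imports for JavaDoubleRDD, DoubleFunction and StatCounter; the data values are made up): the DoubleFunction overload of map() yields a JavaDoubleRDD, which exposes the double-specific methods defined here.

    JavaRDD<String> strings = sc.parallelize(Arrays.asList("1.5", "2.5", "4.0"));
    // The DoubleFunction overload of map() returns a JavaDoubleRDD
    JavaDoubleRDD values = strings.map(new DoubleFunction<String>() {
      public Double apply(String s) { return Double.parseDouble(s); }
    });
    Double total = values.sum();          // 8.0
    Double mean = values.mean();          // ~2.67
    StatCounter stats = values.stats();   // count, mean, variance, stdev in one pass
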
diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
new file mode 100644
index 0000000000..c28a13b061
--- /dev/null
+++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
@@ -0,0 +1,277 @@
+package spark.api.java
+
+import spark.SparkContext.rddToPairRDDFunctions
+import spark.api.java.function.{Function2 => JFunction2}
+import spark.api.java.function.{Function => JFunction}
+import spark.partial.BoundedDouble
+import spark.partial.PartialResult
+import spark.storage.StorageLevel
+import spark._
+
+import java.util.{List => JList}
+import java.util.Comparator
+
+import scala.Tuple2
+import scala.collection.JavaConversions._
+
+import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.mapred.OutputFormat
+import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
+import org.apache.hadoop.conf.Configuration
+
+class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManifest[K],
+ implicit val vManifest: ClassManifest[V]) extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] {
+
+ override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd)
+
+ override val classManifest: ClassManifest[(K, V)] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]]
+
+ import JavaPairRDD._
+
+ // Common RDD functions
+
+ def cache(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.cache())
+
+ def persist(newLevel: StorageLevel): JavaPairRDD[K, V] =
+ new JavaPairRDD[K, V](rdd.persist(newLevel))
+
+ // Transformations (return a new RDD)
+
+ def distinct(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct())
+
+ def filter(f: Function[(K, V), java.lang.Boolean]): JavaPairRDD[K, V] =
+ new JavaPairRDD[K, V](rdd.filter(x => f(x).booleanValue()))
+
+ def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaPairRDD[K, V] =
+ new JavaPairRDD[K, V](rdd.sample(withReplacement, fraction, seed))
+
+ def union(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] =
+ new JavaPairRDD[K, V](rdd.union(other.rdd))
+
+ // first() has to be overridden here so that the generated method has the signature
+ // 'public scala.Tuple2 first()'; if the trait's definition is used,
+ // then the method has the signature 'public java.lang.Object first()',
+ // causing NoSuchMethodErrors at runtime.
+ override def first(): (K, V) = rdd.first()
+
+ // Pair RDD functions
+
+ def combineByKey[C](createCombiner: Function[V, C],
+ mergeValue: JFunction2[C, V, C],
+ mergeCombiners: JFunction2[C, C, C],
+ partitioner: Partitioner): JavaPairRDD[K, C] = {
+ implicit val cm: ClassManifest[C] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]]
+ fromRDD(rdd.combineByKey(
+ createCombiner,
+ mergeValue,
+ mergeCombiners,
+ partitioner
+ ))
+ }
+
+ def combineByKey[C](createCombiner: JFunction[V, C],
+ mergeValue: JFunction2[C, V, C],
+ mergeCombiners: JFunction2[C, C, C],
+ numSplits: Int): JavaPairRDD[K, C] =
+ combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numSplits))
+
+ def reduceByKey(partitioner: Partitioner, func: JFunction2[V, V, V]): JavaPairRDD[K, V] =
+ fromRDD(rdd.reduceByKey(partitioner, func))
+
+ def reduceByKeyLocally(func: JFunction2[V, V, V]): java.util.Map[K, V] =
+ mapAsJavaMap(rdd.reduceByKeyLocally(func))
+
+ def countByKey(): java.util.Map[K, Long] = mapAsJavaMap(rdd.countByKey())
+
+ def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] =
+ rdd.countByKeyApprox(timeout).map(mapAsJavaMap)
+
+ def countByKeyApprox(timeout: Long, confidence: Double = 0.95)
+ : PartialResult[java.util.Map[K, BoundedDouble]] =
+ rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap)
+
+ def reduceByKey(func: JFunction2[V, V, V], numSplits: Int): JavaPairRDD[K, V] =
+ fromRDD(rdd.reduceByKey(func, numSplits))
+
+ def groupByKey(partitioner: Partitioner): JavaPairRDD[K, JList[V]] =
+ fromRDD(groupByResultToJava(rdd.groupByKey(partitioner)))
+
+ def groupByKey(numSplits: Int): JavaPairRDD[K, JList[V]] =
+ fromRDD(groupByResultToJava(rdd.groupByKey(numSplits)))
+
+ def partitionBy(partitioner: Partitioner): JavaPairRDD[K, V] =
+ fromRDD(rdd.partitionBy(partitioner))
+
+ def join[W](other: JavaPairRDD[K, W], partitioner: Partitioner): JavaPairRDD[K, (V, W)] =
+ fromRDD(rdd.join(other, partitioner))
+
+ def leftOuterJoin[W](other: JavaPairRDD[K, W], partitioner: Partitioner)
+ : JavaPairRDD[K, (V, Option[W])] =
+ fromRDD(rdd.leftOuterJoin(other, partitioner))
+
+ def rightOuterJoin[W](other: JavaPairRDD[K, W], partitioner: Partitioner)
+ : JavaPairRDD[K, (Option[V], W)] =
+ fromRDD(rdd.rightOuterJoin(other, partitioner))
+
+ def combineByKey[C](createCombiner: JFunction[V, C],
+ mergeValue: JFunction2[C, V, C],
+ mergeCombiners: JFunction2[C, C, C]): JavaPairRDD[K, C] = {
+ implicit val cm: ClassManifest[C] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]]
+ fromRDD(combineByKey(createCombiner, mergeValue, mergeCombiners))
+ }
+
+ def reduceByKey(func: JFunction2[V, V, V]): JavaPairRDD[K, V] = {
+ val partitioner = rdd.defaultPartitioner(rdd)
+ fromRDD(reduceByKey(partitioner, func))
+ }
+
+ def groupByKey(): JavaPairRDD[K, JList[V]] =
+ fromRDD(groupByResultToJava(rdd.groupByKey()))
+
+ def join[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (V, W)] =
+ fromRDD(rdd.join(other))
+
+ def join[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (V, W)] =
+ fromRDD(rdd.join(other, numSplits))
+
+ def leftOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (V, Option[W])] =
+ fromRDD(rdd.leftOuterJoin(other))
+
+ def leftOuterJoin[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (V, Option[W])] =
+ fromRDD(rdd.leftOuterJoin(other, numSplits))
+
+ def rightOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (Option[V], W)] =
+ fromRDD(rdd.rightOuterJoin(other))
+
+ def rightOuterJoin[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (Option[V], W)] =
+ fromRDD(rdd.rightOuterJoin(other, numSplits))
+
+ def collectAsMap(): java.util.Map[K, V] = mapAsJavaMap(rdd.collectAsMap())
+
+ def mapValues[U](f: Function[V, U]): JavaPairRDD[K, U] = {
+ implicit val cm: ClassManifest[U] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]]
+ fromRDD(rdd.mapValues(f))
+ }
+
+ def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairRDD[K, U] = {
+ import scala.collection.JavaConverters._
+ def fn = (x: V) => f.apply(x).asScala
+ implicit val cm: ClassManifest[U] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]]
+ fromRDD(rdd.flatMapValues(fn))
+ }
+
+ def cogroup[W](other: JavaPairRDD[K, W], partitioner: Partitioner)
+ : JavaPairRDD[K, (JList[V], JList[W])] =
+ fromRDD(cogroupResultToJava(rdd.cogroup(other, partitioner)))
+
+ def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2], partitioner: Partitioner)
+ : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] =
+ fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, partitioner)))
+
+ def cogroup[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JList[V], JList[W])] =
+ fromRDD(cogroupResultToJava(rdd.cogroup(other)))
+
+ def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2])
+ : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] =
+ fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2)))
+
+ def cogroup[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (JList[V], JList[W])]
+ = fromRDD(cogroupResultToJava(rdd.cogroup(other, numSplits)))
+
+ def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2], numSplits: Int)
+ : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] =
+ fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, numSplits)))
+
+ def groupWith[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JList[V], JList[W])] =
+ fromRDD(cogroupResultToJava(rdd.groupWith(other)))
+
+ def groupWith[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2])
+ : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] =
+ fromRDD(cogroupResult2ToJava(rdd.groupWith(other1, other2)))
+
+ def lookup(key: K): JList[V] = seqAsJavaList(rdd.lookup(key))
+
+ def saveAsHadoopFile[F <: OutputFormat[_, _]](
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[F],
+ conf: JobConf) {
+ rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, conf)
+ }
+
+ def saveAsHadoopFile[F <: OutputFormat[_, _]](
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[F]) {
+ rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass)
+ }
+
+ def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]](
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[F],
+ conf: Configuration) {
+ rdd.saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, conf)
+ }
+
+ def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]](
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[F]) {
+ rdd.saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass)
+ }
+
+ def saveAsHadoopDataset(conf: JobConf) {
+ rdd.saveAsHadoopDataset(conf)
+ }
+
+
+ // Ordered RDD Functions
+ def sortByKey(): JavaPairRDD[K, V] = sortByKey(true)
+
+ def sortByKey(ascending: Boolean): JavaPairRDD[K, V] = {
+ val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[K]]
+ sortByKey(comp, true)
+ }
+
+ def sortByKey(comp: Comparator[K]): JavaPairRDD[K, V] = sortByKey(comp, true)
+
+ def sortByKey(comp: Comparator[K], ascending: Boolean): JavaPairRDD[K, V] = {
+ class KeyOrdering(val a: K) extends Ordered[K] {
+ override def compare(b: K) = comp.compare(a, b)
+ }
+ implicit def toOrdered(x: K): Ordered[K] = new KeyOrdering(x)
+ fromRDD(new OrderedRDDFunctions(rdd).sortByKey(ascending))
+ }
+}
+
+object JavaPairRDD {
+ def groupByResultToJava[K, T](rdd: RDD[(K, Seq[T])])(implicit kcm: ClassManifest[K],
+ vcm: ClassManifest[T]): RDD[(K, JList[T])] =
+ rddToPairRDDFunctions(rdd).mapValues(seqAsJavaList _)
+
+ def cogroupResultToJava[W, K, V](rdd: RDD[(K, (Seq[V], Seq[W]))])(implicit kcm: ClassManifest[K],
+ vcm: ClassManifest[V]): RDD[(K, (JList[V], JList[W]))] = rddToPairRDDFunctions(rdd).mapValues((x: (Seq[V],
+ Seq[W])) => (seqAsJavaList(x._1), seqAsJavaList(x._2)))
+
+ def cogroupResult2ToJava[W1, W2, K, V](rdd: RDD[(K, (Seq[V], Seq[W1],
+ Seq[W2]))])(implicit kcm: ClassManifest[K]) : RDD[(K, (JList[V], JList[W1],
+ JList[W2]))] = rddToPairRDDFunctions(rdd).mapValues(
+ (x: (Seq[V], Seq[W1], Seq[W2])) => (seqAsJavaList(x._1),
+ seqAsJavaList(x._2),
+ seqAsJavaList(x._3)))
+
+ def fromRDD[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]): JavaPairRDD[K, V] =
+ new JavaPairRDD[K, V](rdd)
+
+ implicit def toRDD[K, V](rdd: JavaPairRDD[K, V]): RDD[(K, V)] = rdd.rdd
+}
\ No newline at end of file
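
A minimal word-count style sketch of the pair operations above (again assuming a JavaSparkContext sc; imports for Tuple2, PairFunction, java.util.Map and java.util.List are assumed): a PairFunction turns a JavaRDD into a JavaPairRDD, which then offers reduceByKey, sortByKey, lookup and collectAsMap.

    JavaRDD<String> words = sc.parallelize(Arrays.asList("a", "b", "a", "c", "b", "a"));
    // The PairFunction overload of map() returns a JavaPairRDD
    JavaPairRDD<String, Integer> ones = words.map(new PairFunction<String, String, Integer>() {
      public Tuple2<String, Integer> apply(String w) { return new Tuple2<String, Integer>(w, 1); }
    });
    JavaPairRDD<String, Integer> counts = ones.reduceByKey(new Function2<Integer, Integer, Integer>() {
      public Integer apply(Integer a, Integer b) { return a + b; }
    });
    Map<String, Integer> countMap = counts.collectAsMap();     // {a=3, b=2, c=1}
    List<Integer> aCount = counts.lookup("a");                 // [3]
    JavaPairRDD<String, Integer> sorted = counts.sortByKey();  // keys in natural String order
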
diff --git a/core/src/main/scala/spark/api/java/JavaRDD.scala b/core/src/main/scala/spark/api/java/JavaRDD.scala
new file mode 100644
index 0000000000..541aa1e60b
--- /dev/null
+++ b/core/src/main/scala/spark/api/java/JavaRDD.scala
@@ -0,0 +1,38 @@
+package spark.api.java
+
+import spark._
+import spark.api.java.function.{Function => JFunction}
+import spark.storage.StorageLevel
+
+class JavaRDD[T](val rdd: RDD[T])(implicit val classManifest: ClassManifest[T]) extends
+JavaRDDLike[T, JavaRDD[T]] {
+
+ override def wrapRDD(rdd: RDD[T]): JavaRDD[T] = JavaRDD.fromRDD(rdd)
+
+ // Common RDD functions
+
+ def cache(): JavaRDD[T] = wrapRDD(rdd.cache())
+
+ def persist(newLevel: StorageLevel): JavaRDD[T] = wrapRDD(rdd.persist(newLevel))
+
+ // Transformations (return a new RDD)
+
+ def distinct(): JavaRDD[T] = wrapRDD(rdd.distinct())
+
+ def filter(f: JFunction[T, java.lang.Boolean]): JavaRDD[T] =
+ wrapRDD(rdd.filter((x => f(x).booleanValue())))
+
+ def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] =
+ wrapRDD(rdd.sample(withReplacement, fraction, seed))
+
+ def union(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.union(other.rdd))
+
+}
+
+object JavaRDD {
+
+ implicit def fromRDD[T: ClassManifest](rdd: RDD[T]): JavaRDD[T] = new JavaRDD[T](rdd)
+
+ implicit def toRDD[T](rdd: JavaRDD[T]): RDD[T] = rdd.rdd
+}
+
diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
new file mode 100644
index 0000000000..785dd96394
--- /dev/null
+++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
@@ -0,0 +1,170 @@
+package spark.api.java
+
+import spark.{SparkContext, Split, RDD}
+import spark.api.java.JavaPairRDD._
+import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
+import spark.partial.{PartialResult, BoundedDouble}
+import spark.storage.StorageLevel
+
+import java.util.{List => JList}
+
+import scala.collection.JavaConversions._
+import java.{util, lang}
+import scala.Tuple2
+
+trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
+ def wrapRDD(rdd: RDD[T]): This
+
+ implicit val classManifest: ClassManifest[T]
+
+ def rdd: RDD[T]
+
+ def splits: JList[Split] = new java.util.ArrayList(rdd.splits.toSeq)
+
+ def context: SparkContext = rdd.context
+
+ def id: Int = rdd.id
+
+ def getStorageLevel: StorageLevel = rdd.getStorageLevel
+
+ def iterator(split: Split): java.util.Iterator[T] = asJavaIterator(rdd.iterator(split))
+
+ // Transformations (return a new RDD)
+
+ def map[R](f: JFunction[T, R]): JavaRDD[R] =
+ new JavaRDD(rdd.map(f)(f.returnType()))(f.returnType())
+
+ def map[R](f: DoubleFunction[T]): JavaDoubleRDD =
+ new JavaDoubleRDD(rdd.map(x => f(x).doubleValue()))
+
+ def map[K2, V2](f: PairFunction[T, K2, V2]): JavaPairRDD[K2, V2] = {
+ def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]]
+ new JavaPairRDD(rdd.map(f)(cm))(f.keyType(), f.valueType())
+ }
+
+ def flatMap[U](f: FlatMapFunction[T, U]): JavaRDD[U] = {
+ import scala.collection.JavaConverters._
+ def fn = (x: T) => f.apply(x).asScala
+ JavaRDD.fromRDD(rdd.flatMap(fn)(f.elementType()))(f.elementType())
+ }
+
+ def flatMap(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = {
+ import scala.collection.JavaConverters._
+ def fn = (x: T) => f.apply(x).asScala
+ new JavaDoubleRDD(rdd.flatMap(fn).map((x: java.lang.Double) => x.doubleValue()))
+ }
+
+ def flatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = {
+ import scala.collection.JavaConverters._
+ def fn = (x: T) => f.apply(x).asScala
+ def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]]
+ JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(f.keyType(), f.valueType())
+ }
+
+ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = {
+ def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
+ JavaRDD.fromRDD(rdd.mapPartitions(fn)(f.elementType()))(f.elementType())
+ }
+
+ def mapPartitions(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = {
+ def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
+ new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: java.lang.Double) => x.doubleValue()))
+ }
+
+ def mapPartitions[K, V](f: PairFlatMapFunction[java.util.Iterator[T], K, V]):
+ JavaPairRDD[K, V] = {
+ def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
+ JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(f.keyType(), f.valueType())
+ }
+
+ def glom(): JavaRDD[JList[T]] =
+ new JavaRDD(rdd.glom().map(x => new java.util.ArrayList[T](x.toSeq)))
+
+ def cartesian[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] =
+ JavaPairRDD.fromRDD(rdd.cartesian(other.rdd)(other.classManifest))(classManifest,
+ other.classManifest)
+
+ def groupBy[K](f: JFunction[T, K]): JavaPairRDD[K, JList[T]] = {
+ implicit val kcm: ClassManifest[K] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
+ implicit val vcm: ClassManifest[JList[T]] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[JList[T]]]
+ JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f)(f.returnType)))(kcm, vcm)
+ }
+
+ def groupBy[K](f: JFunction[T, K], numSplits: Int): JavaPairRDD[K, JList[T]] = {
+ implicit val kcm: ClassManifest[K] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
+ implicit val vcm: ClassManifest[JList[T]] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[JList[T]]]
+ JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numSplits)(f.returnType)))(kcm, vcm)
+ }
+
+ def pipe(command: String): JavaRDD[String] = rdd.pipe(command)
+
+ def pipe(command: JList[String]): JavaRDD[String] =
+ rdd.pipe(asScalaBuffer(command))
+
+ def pipe(command: JList[String], env: java.util.Map[String, String]): JavaRDD[String] =
+ rdd.pipe(asScalaBuffer(command), mapAsScalaMap(env))
+
+ // Actions (launch a job to return a value to the user program)
+
+ def foreach(f: VoidFunction[T]) {
+ val cleanF = rdd.context.clean(f)
+ rdd.foreach(cleanF)
+ }
+
+ def collect(): JList[T] = {
+ import scala.collection.JavaConversions._
+ val arr: java.util.Collection[T] = rdd.collect().toSeq
+ new java.util.ArrayList(arr)
+ }
+
+ def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f)
+
+ def fold(zeroValue: T)(f: JFunction2[T, T, T]): T =
+ rdd.fold(zeroValue)(f)
+
+ def aggregate[U](zeroValue: U)(seqOp: JFunction2[U, T, U],
+ combOp: JFunction2[U, U, U]): U =
+ rdd.aggregate(zeroValue)(seqOp, combOp)(seqOp.returnType)
+
+ def count(): Long = rdd.count()
+
+ def countApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] =
+ rdd.countApprox(timeout, confidence)
+
+ def countApprox(timeout: Long): PartialResult[BoundedDouble] =
+ rdd.countApprox(timeout)
+
+ def countByValue(): java.util.Map[T, java.lang.Long] =
+ mapAsJavaMap(rdd.countByValue().map((x => (x._1, new lang.Long(x._2)))))
+
+ def countByValueApprox(
+ timeout: Long,
+ confidence: Double
+ ): PartialResult[java.util.Map[T, BoundedDouble]] =
+ rdd.countByValueApprox(timeout, confidence).map(mapAsJavaMap)
+
+ def countByValueApprox(timeout: Long): PartialResult[java.util.Map[T, BoundedDouble]] =
+ rdd.countByValueApprox(timeout).map(mapAsJavaMap)
+
+ def take(num: Int): JList[T] = {
+ import scala.collection.JavaConversions._
+ val arr: java.util.Collection[T] = rdd.take(num).toSeq
+ new java.util.ArrayList(arr)
+ }
+
+ def takeSample(withReplacement: Boolean, num: Int, seed: Int): JList[T] = {
+ import scala.collection.JavaConversions._
+ val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq
+ new java.util.ArrayList(arr)
+ }
+
+ def first(): T = rdd.first()
+
+ def saveAsTextFile(path: String) = rdd.saveAsTextFile(path)
+
+ def saveAsObjectFile(path: String) = rdd.saveAsObjectFile(path)
+}
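
A minimal sketch of flatMap through the FlatMapFunction wrapper declared above (assuming a JavaSparkContext sc as before): the function returns a java.lang.Iterable, which JavaRDDLike converts to a Scala collection before calling the underlying RDD.

    JavaRDD<String> lines = sc.parallelize(Arrays.asList("to be", "or not"));
    JavaRDD<String> tokens = lines.flatMap(new FlatMapFunction<String, String>() {
      public Iterable<String> apply(String line) { return Arrays.asList(line.split(" ")); }
    });
    List<String> all = tokens.collect();  // [to, be, or, not]
    long n = tokens.count();              // 4
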
diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
new file mode 100644
index 0000000000..2d43bfa4ef
--- /dev/null
+++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
@@ -0,0 +1,219 @@
+package spark.api.java
+
+import spark.{Accumulator, AccumulatorParam, RDD, SparkContext}
+import spark.SparkContext.IntAccumulatorParam
+import spark.SparkContext.DoubleAccumulatorParam
+import spark.broadcast.Broadcast
+
+import scala.collection.JavaConversions._
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.mapred.InputFormat
+import org.apache.hadoop.mapred.JobConf
+
+import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
+
+
+import scala.collection.JavaConversions
+
+class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWorkaround {
+
+ def this(master: String, frameworkName: String) = this(new SparkContext(master, frameworkName))
+
+ val env = sc.env
+
+ def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = {
+ implicit val cm: ClassManifest[T] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices)
+ }
+
+ def parallelize[T](list: java.util.List[T]): JavaRDD[T] =
+ parallelize(list, sc.defaultParallelism)
+
+
+ def parallelizePairs[K, V](list: java.util.List[Tuple2[K, V]], numSlices: Int)
+ : JavaPairRDD[K, V] = {
+ implicit val kcm: ClassManifest[K] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
+ implicit val vcm: ClassManifest[V] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
+ JavaPairRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices))
+ }
+
+ def parallelizePairs[K, V](list: java.util.List[Tuple2[K, V]]): JavaPairRDD[K, V] =
+ parallelizePairs(list, sc.defaultParallelism)
+
+ def parallelizeDoubles(list: java.util.List[java.lang.Double], numSlices: Int): JavaDoubleRDD =
+ JavaDoubleRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list).map(_.doubleValue()),
+ numSlices))
+
+ def parallelizeDoubles(list: java.util.List[java.lang.Double]): JavaDoubleRDD =
+ parallelizeDoubles(list, sc.defaultParallelism)
+
+ def textFile(path: String): JavaRDD[String] = sc.textFile(path)
+
+ def textFile(path: String, minSplits: Int): JavaRDD[String] = sc.textFile(path, minSplits)
+
+ /**Get an RDD for a Hadoop SequenceFile with given key and value types */
+ def sequenceFile[K, V](path: String,
+ keyClass: Class[K],
+ valueClass: Class[V],
+ minSplits: Int
+ ): JavaPairRDD[K, V] = {
+ implicit val kcm = ClassManifest.fromClass(keyClass)
+ implicit val vcm = ClassManifest.fromClass(valueClass)
+ new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass, minSplits))
+ }
+
+ def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]):
+ JavaPairRDD[K, V] = {
+ implicit val kcm = ClassManifest.fromClass(keyClass)
+ implicit val vcm = ClassManifest.fromClass(valueClass)
+ new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass))
+ }
+
+ /**
+ * Load an RDD saved as a SequenceFile containing serialized objects, with NullWritable keys and
+ * BytesWritable values that contain a serialized partition. This is still an experimental storage
+ * format and may not be supported exactly as is in future Spark releases. It will also be pretty
+ * slow if you use the default serializer (Java serialization), though the nice thing about it is
+ * that there's very little effort required to save arbitrary objects.
+ */
+ def objectFile[T](path: String, minSplits: Int): JavaRDD[T] = {
+ implicit val cm: ClassManifest[T] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ sc.objectFile(path, minSplits)(cm)
+ }
+
+ def objectFile[T](path: String): JavaRDD[T] = {
+ implicit val cm: ClassManifest[T] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ sc.objectFile(path)(cm)
+ }
+
+ /**
+ * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf giving its InputFormat and any
+ * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable,
+ * etc).
+ */
+ def hadoopRDD[K, V, F <: InputFormat[K, V]](
+ conf: JobConf,
+ inputFormatClass: Class[F],
+ keyClass: Class[K],
+ valueClass: Class[V],
+ minSplits: Int
+ ): JavaPairRDD[K, V] = {
+ implicit val kcm = ClassManifest.fromClass(keyClass)
+ implicit val vcm = ClassManifest.fromClass(valueClass)
+ new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass, minSplits))
+ }
+
+ def hadoopRDD[K, V, F <: InputFormat[K, V]](
+ conf: JobConf,
+ inputFormatClass: Class[F],
+ keyClass: Class[K],
+ valueClass: Class[V]
+ ): JavaPairRDD[K, V] = {
+ implicit val kcm = ClassManifest.fromClass(keyClass)
+ implicit val vcm = ClassManifest.fromClass(valueClass)
+ new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass))
+ }
+
+ /**Get an RDD for a Hadoop file with an arbitrary InputFormat */
+ def hadoopFile[K, V, F <: InputFormat[K, V]](
+ path: String,
+ inputFormatClass: Class[F],
+ keyClass: Class[K],
+ valueClass: Class[V],
+ minSplits: Int
+ ): JavaPairRDD[K, V] = {
+ implicit val kcm = ClassManifest.fromClass(keyClass)
+ implicit val vcm = ClassManifest.fromClass(valueClass)
+ new JavaPairRDD(sc.hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits))
+ }
+
+ def hadoopFile[K, V, F <: InputFormat[K, V]](
+ path: String,
+ inputFormatClass: Class[F],
+ keyClass: Class[K],
+ valueClass: Class[V]
+ ): JavaPairRDD[K, V] = {
+ implicit val kcm = ClassManifest.fromClass(keyClass)
+ implicit val vcm = ClassManifest.fromClass(valueClass)
+ new JavaPairRDD(sc.hadoopFile(path,
+ inputFormatClass, keyClass, valueClass))
+ }
+
+ /**
+ * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat
+ * and extra configuration options to pass to the input format.
+ */
+ def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](
+ path: String,
+ fClass: Class[F],
+ kClass: Class[K],
+ vClass: Class[V],
+ conf: Configuration): JavaPairRDD[K, V] = {
+ implicit val kcm = ClassManifest.fromClass(kClass)
+ implicit val vcm = ClassManifest.fromClass(vClass)
+ new JavaPairRDD(sc.newAPIHadoopFile(path, fClass, kClass, vClass, conf))
+ }
+
+ /**
+ * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat
+ * and extra configuration options to pass to the input format.
+ */
+ def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]](
+ conf: Configuration,
+ fClass: Class[F],
+ kClass: Class[K],
+ vClass: Class[V]): JavaPairRDD[K, V] = {
+ implicit val kcm = ClassManifest.fromClass(kClass)
+ implicit val vcm = ClassManifest.fromClass(vClass)
+ new JavaPairRDD(sc.newAPIHadoopRDD(conf, fClass, kClass, vClass))
+ }
+
+ override def union[T](first: JavaRDD[T], rest: java.util.List[JavaRDD[T]]): JavaRDD[T] = {
+ val rdds: Seq[RDD[T]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd)
+ implicit val cm: ClassManifest[T] = first.classManifest
+ sc.union(rdds: _*)(cm)
+ }
+
+ override def union[K, V](first: JavaPairRDD[K, V], rest: java.util.List[JavaPairRDD[K, V]])
+ : JavaPairRDD[K, V] = {
+ val rdds: Seq[RDD[(K, V)]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd)
+ implicit val cm: ClassManifest[(K, V)] = first.classManifest
+ implicit val kcm: ClassManifest[K] = first.kManifest
+ implicit val vcm: ClassManifest[V] = first.vManifest
+ new JavaPairRDD(sc.union(rdds: _*)(cm))(kcm, vcm)
+ }
+
+ override def union(first: JavaDoubleRDD, rest: java.util.List[JavaDoubleRDD]): JavaDoubleRDD = {
+ val rdds: Seq[RDD[Double]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.srdd)
+ new JavaDoubleRDD(sc.union(rdds: _*))
+ }
+
+ def intAccumulator(initialValue: Int): Accumulator[Int] =
+ sc.accumulator(initialValue)(IntAccumulatorParam)
+
+ def doubleAccumulator(initialValue: Double): Accumulator[Double] =
+ sc.accumulator(initialValue)(DoubleAccumulatorParam)
+
+ def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] =
+ sc.accumulator(initialValue)(accumulatorParam)
+
+ def broadcast[T](value: T): Broadcast[T] = sc.broadcast(value)
+
+ def stop() {
+ sc.stop()
+ }
+
+ def getSparkHome(): Option[String] = sc.getSparkHome()
+}
+
+object JavaSparkContext {
+ implicit def fromSparkContext(sc: SparkContext): JavaSparkContext = new JavaSparkContext(sc)
+
+ implicit def toSparkContext(jsc: JavaSparkContext): SparkContext = jsc.sc
+}
diff --git a/core/src/main/scala/spark/api/java/JavaSparkContextVarargsWorkaround.java b/core/src/main/scala/spark/api/java/JavaSparkContextVarargsWorkaround.java
new file mode 100644
index 0000000000..97344e73da
--- /dev/null
+++ b/core/src/main/scala/spark/api/java/JavaSparkContextVarargsWorkaround.java
@@ -0,0 +1,47 @@
+package spark.api.java;
+
+import java.util.Arrays;
+import java.util.ArrayList;
+import java.util.List;
+
+// See
+// http://scala-programming-language.1934581.n4.nabble.com/Workaround-for-implementing-java-varargs-in-2-7-2-final-tp1944767p1944772.html
+abstract class JavaSparkContextVarargsWorkaround {
+ public <T> JavaRDD<T> union(JavaRDD<T>... rdds) {
+ if (rdds.length == 0) {
+ throw new IllegalArgumentException("Union called on empty list");
+ }
+ ArrayList<JavaRDD<T>> rest = new ArrayList<JavaRDD<T>>(rdds.length - 1);
+ for (int i = 1; i < rdds.length; i++) {
+ rest.add(rdds[i]);
+ }
+ return union(rdds[0], rest);
+ }
+
+ public JavaDoubleRDD union(JavaDoubleRDD... rdds) {
+ if (rdds.length == 0) {
+ throw new IllegalArgumentException("Union called on empty list");
+ }
+ ArrayList<JavaDoubleRDD> rest = new ArrayList<JavaDoubleRDD>(rdds.length - 1);
+ for (int i = 1; i < rdds.length; i++) {
+ rest.add(rdds[i]);
+ }
+ return union(rdds[0], rest);
+ }
+
+ public <K, V> JavaPairRDD<K, V> union(JavaPairRDD<K, V>... rdds) {
+ if (rdds.length == 0) {
+ throw new IllegalArgumentException("Union called on empty list");
+ }
+ ArrayList<JavaPairRDD<K, V>> rest = new ArrayList<JavaPairRDD<K, V>>(rdds.length - 1);
+ for (int i = 1; i < rdds.length; i++) {
+ rest.add(rdds[i]);
+ }
+ return union(rdds[0], rest);
+ }
+
+ // These methods take separate "first" and "rest" elements to avoid having the same type erasure
+ abstract public <T> JavaRDD<T> union(JavaRDD<T> first, List<JavaRDD<T>> rest);
+ abstract public JavaDoubleRDD union(JavaDoubleRDD first, List<JavaDoubleRDD> rest);
+ abstract public <K, V> JavaPairRDD<K, V> union(JavaPairRDD<K, V> first, List<JavaPairRDD<K, V>> rest);
+}
diff --git a/core/src/main/scala/spark/api/java/function/DoubleFlatMapFunction.java b/core/src/main/scala/spark/api/java/function/DoubleFlatMapFunction.java
new file mode 100644
index 0000000000..240747390c
--- /dev/null
+++ b/core/src/main/scala/spark/api/java/function/DoubleFlatMapFunction.java
@@ -0,0 +1,13 @@
+package spark.api.java.function;
+
+
+import scala.runtime.AbstractFunction1;
+
+import java.io.Serializable;
+
+// DoubleFlatMapFunction does not extend FlatMapFunction because flatMap is
+// overloaded for both FlatMapFunction and DoubleFlatMapFunction.
+public abstract class DoubleFlatMapFunction<T> extends AbstractFunction1<T, Iterable<Double>>
+ implements Serializable {
+ public abstract Iterable<Double> apply(T t);
+}
diff --git a/core/src/main/scala/spark/api/java/function/DoubleFunction.java b/core/src/main/scala/spark/api/java/function/DoubleFunction.java
new file mode 100644
index 0000000000..378ffd427d
--- /dev/null
+++ b/core/src/main/scala/spark/api/java/function/DoubleFunction.java
@@ -0,0 +1,13 @@
+package spark.api.java.function;
+
+
+import scala.runtime.AbstractFunction1;
+
+import java.io.Serializable;
+
+// DoubleFunction does not extend Function because some UDF functions, like map,
+// are overloaded for both Function and DoubleFunction.
+public abstract class DoubleFunction<T> extends AbstractFunction1<T, Double>
+ implements Serializable {
+ public abstract Double apply(T t);
+}
diff --git a/core/src/main/scala/spark/api/java/function/FlatMapFunction.scala b/core/src/main/scala/spark/api/java/function/FlatMapFunction.scala
new file mode 100644
index 0000000000..1045e006a0
--- /dev/null
+++ b/core/src/main/scala/spark/api/java/function/FlatMapFunction.scala
@@ -0,0 +1,7 @@
+package spark.api.java.function
+
+abstract class FlatMapFunction[T, R] extends Function[T, java.lang.Iterable[R]] {
+ def apply(x: T) : java.lang.Iterable[R]
+
+ def elementType() : ClassManifest[R] = ClassManifest.Any.asInstanceOf[ClassManifest[R]]
+}
diff --git a/core/src/main/scala/spark/api/java/function/Function.java b/core/src/main/scala/spark/api/java/function/Function.java
new file mode 100644
index 0000000000..ad38b89f0f
--- /dev/null
+++ b/core/src/main/scala/spark/api/java/function/Function.java
@@ -0,0 +1,21 @@
+package spark.api.java.function;
+
+import scala.reflect.ClassManifest;
+import scala.reflect.ClassManifest$;
+import scala.runtime.AbstractFunction1;
+
+import java.io.Serializable;
+
+
+/**
+ * Base class for functions whose return types do not have special RDDs; DoubleFunction is
+ * handled separately, to allow DoubleRDDs to be constructed when mapping RDDs to doubles.
+ */
+public abstract class Function<T, R> extends AbstractFunction1<T, R> implements Serializable {
+ public abstract R apply(T t);
+
+ public ClassManifest<R> returnType() {
+ return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class);
+ }
+}
+
diff --git a/core/src/main/scala/spark/api/java/function/Function2.java b/core/src/main/scala/spark/api/java/function/Function2.java
new file mode 100644
index 0000000000..883804dfe4
--- /dev/null
+++ b/core/src/main/scala/spark/api/java/function/Function2.java
@@ -0,0 +1,17 @@
+package spark.api.java.function;
+
+import scala.reflect.ClassManifest;
+import scala.reflect.ClassManifest$;
+import scala.runtime.AbstractFunction2;
+
+import java.io.Serializable;
+
+public abstract class Function2<T1, T2, R> extends AbstractFunction2<T1, T2, R>
+ implements Serializable {
+ public ClassManifest<R> returnType() {
+ return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class);
+ }
+
+ public abstract R apply(T1 t1, T2 t2);
+}
+
diff --git a/core/src/main/scala/spark/api/java/function/PairFlatMapFunction.java b/core/src/main/scala/spark/api/java/function/PairFlatMapFunction.java
new file mode 100644
index 0000000000..aead6c4e03
--- /dev/null
+++ b/core/src/main/scala/spark/api/java/function/PairFlatMapFunction.java
@@ -0,0 +1,25 @@
+package spark.api.java.function;
+
+import scala.Tuple2;
+import scala.reflect.ClassManifest;
+import scala.reflect.ClassManifest$;
+import scala.runtime.AbstractFunction1;
+
+import java.io.Serializable;
+
+// PairFlatMapFunction does not extend FlatMapFunction because flatMap is
+// overloaded for both FlatMapFunction and PairFlatMapFunction.
+public abstract class PairFlatMapFunction<T, K, V> extends AbstractFunction1<T, Iterable<Tuple2<K,
+ V>>> implements Serializable {
+
+ public ClassManifest<K> keyType() {
+ return (ClassManifest<K>) ClassManifest$.MODULE$.fromClass(Object.class);
+ }
+
+ public ClassManifest<V> valueType() {
+ return (ClassManifest<V>) ClassManifest$.MODULE$.fromClass(Object.class);
+ }
+
+ public abstract Iterable<Tuple2<K, V>> apply(T t);
+
+}
diff --git a/core/src/main/scala/spark/api/java/function/PairFunction.java b/core/src/main/scala/spark/api/java/function/PairFunction.java
new file mode 100644
index 0000000000..3284bfb11e
--- /dev/null
+++ b/core/src/main/scala/spark/api/java/function/PairFunction.java
@@ -0,0 +1,25 @@
+package spark.api.java.function;
+
+import scala.Tuple2;
+import scala.reflect.ClassManifest;
+import scala.reflect.ClassManifest$;
+import scala.runtime.AbstractFunction1;
+
+import java.io.Serializable;
+
+// PairFunction does not extend Function because some UDF functions, like map,
+// are overloaded for both Function and PairFunction.
+public abstract class PairFunction<T, K, V> extends AbstractFunction1<T, Tuple2<K,
+ V>> implements Serializable {
+
+ public ClassManifest<K> keyType() {
+ return (ClassManifest<K>) ClassManifest$.MODULE$.fromClass(Object.class);
+ }
+
+ public ClassManifest<V> valueType() {
+ return (ClassManifest<V>) ClassManifest$.MODULE$.fromClass(Object.class);
+ }
+
+ public abstract Tuple2<K, V> apply(T t);
+
+}
diff --git a/core/src/main/scala/spark/api/java/function/VoidFunction.scala b/core/src/main/scala/spark/api/java/function/VoidFunction.scala
new file mode 100644
index 0000000000..be4cbaff39
--- /dev/null
+++ b/core/src/main/scala/spark/api/java/function/VoidFunction.scala
@@ -0,0 +1,12 @@
+package spark.api.java.function
+
+// This allows Java users to write void methods without having to return Unit.
+abstract class VoidFunction[T] extends Serializable {
+ def apply(t: T) : Unit
+}
+
+// VoidFunction cannot extend AbstractFunction1 (because that would force users to explicitly
+// return Unit), so it is implicitly converted to a Function1[T, Unit]:
+object VoidFunction {
+ implicit def toFunction[T](f: VoidFunction[T]) : Function1[T, Unit] = ((x : T) => f(x))
+}
\ No newline at end of file
diff --git a/core/src/main/scala/spark/partial/PartialResult.scala b/core/src/main/scala/spark/partial/PartialResult.scala
index e7d2d4e8cc..200ed4ea1e 100644
--- a/core/src/main/scala/spark/partial/PartialResult.scala
+++ b/core/src/main/scala/spark/partial/PartialResult.scala
@@ -57,6 +57,32 @@ class PartialResult[R](initialVal: R, isFinal: Boolean) {
}
}
+ /**
+ * Transform this PartialResult into a PartialResult of type T.
+ */
+ def map[T](f: R => T) : PartialResult[T] = {
+ new PartialResult[T](f(initialVal), isFinal) {
+ override def getFinalValue() : T = synchronized {
+ f(PartialResult.this.getFinalValue())
+ }
+ override def onComplete(handler: T => Unit): PartialResult[T] = synchronized {
+ PartialResult.this.onComplete(handler.compose(f)).map(f)
+ }
+ override def onFail(handler: Exception => Unit) {
+ synchronized {
+ PartialResult.this.onFail(handler)
+ }
+ }
+ override def toString : String = synchronized {
+ PartialResult.this.getFinalValueInternal() match {
+ case Some(value) => "(final: " + f(value) + ")"
+ case None => "(partial: " + initialValue + ")"
+ }
+ }
+ def getFinalValueInternal() = PartialResult.this.getFinalValueInternal().map(f)
+ }
+ }
+
private[spark] def setFinalValue(value: R) {
synchronized {
if (finalValue != None) {
@@ -70,6 +96,8 @@ class PartialResult[R](initialVal: R, isFinal: Boolean) {
}
}
+ private def getFinalValueInternal() = finalValue
+
private[spark] def setFailure(exception: Exception) {
synchronized {
if (failure != None) {
diff --git a/core/src/main/scala/spark/util/StatCounter.scala b/core/src/main/scala/spark/util/StatCounter.scala
index efb1ae7529..11d7939204 100644
--- a/core/src/main/scala/spark/util/StatCounter.scala
+++ b/core/src/main/scala/spark/util/StatCounter.scala
@@ -5,7 +5,7 @@ package spark.util
* numerically robust way. Includes support for merging two StatCounters. Based on Welford and
* Chan's algorithms described at http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance.
*/
-class StatCounter(values: TraversableOnce[Double]) {
+class StatCounter(values: TraversableOnce[Double]) extends Serializable {
private var n: Long = 0 // Running count of our values
private var mu: Double = 0 // Running mean of our values
private var m2: Double = 0 // Running variance numerator (sum of (x - mean)^2)
diff --git a/core/src/test/scala/spark/BroadcastSuite.scala b/core/src/test/scala/spark/BroadcastSuite.scala
index 750703de30..1e0b587421 100644
--- a/core/src/test/scala/spark/BroadcastSuite.scala
+++ b/core/src/test/scala/spark/BroadcastSuite.scala
@@ -1,23 +1,31 @@
package spark
import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
-class BroadcastSuite extends FunSuite {
+class BroadcastSuite extends FunSuite with BeforeAndAfter {
+
+ var sc: SparkContext = _
+
+ after {
+ if(sc != null) {
+ sc.stop()
+ }
+ }
+
test("basic broadcast") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === Set((1, 10), (2, 10)))
- sc.stop()
}
test("broadcast variables accessed in multiple threads") {
- val sc = new SparkContext("local[10]", "test")
+ sc = new SparkContext("local[10]", "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
- sc.stop()
}
}
diff --git a/core/src/test/scala/spark/FailureSuite.scala b/core/src/test/scala/spark/FailureSuite.scala
index 816411debe..0aaa16dca4 100644
--- a/core/src/test/scala/spark/FailureSuite.scala
+++ b/core/src/test/scala/spark/FailureSuite.scala
@@ -1,6 +1,7 @@
package spark
import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
import org.scalatest.prop.Checkers
import scala.collection.mutable.ArrayBuffer
@@ -22,11 +23,20 @@ object FailureSuiteState {
}
}
-class FailureSuite extends FunSuite {
+class FailureSuite extends FunSuite with BeforeAndAfter {
+
+ var sc: SparkContext = _
+
+ after {
+ if(sc != null) {
+ sc.stop()
+ }
+ }
+
// Run a 3-task map job in which task 1 deterministically fails once, and check
// whether the job completes successfully and we ran 4 tasks in total.
test("failure in a single-stage job") {
- val sc = new SparkContext("local[1,1]", "test")
+ sc = new SparkContext("local[1,1]", "test")
val results = sc.makeRDD(1 to 3, 3).map { x =>
FailureSuiteState.synchronized {
FailureSuiteState.tasksRun += 1
@@ -41,13 +51,12 @@ class FailureSuite extends FunSuite {
assert(FailureSuiteState.tasksRun === 4)
}
assert(results.toList === List(1,4,9))
- sc.stop()
FailureSuiteState.clear()
}
// Run a map-reduce job in which a reduce task deterministically fails once.
test("failure in a two-stage job") {
- val sc = new SparkContext("local[1,1]", "test")
+ sc = new SparkContext("local[1,1]", "test")
val results = sc.makeRDD(1 to 3).map(x => (x, x)).groupByKey(3).map {
case (k, v) =>
FailureSuiteState.synchronized {
@@ -63,12 +72,11 @@ class FailureSuite extends FunSuite {
assert(FailureSuiteState.tasksRun === 4)
}
assert(results.toSet === Set((1, 1), (2, 4), (3, 9)))
- sc.stop()
FailureSuiteState.clear()
}
test("failure because task results are not serializable") {
- val sc = new SparkContext("local[1,1]", "test")
+ sc = new SparkContext("local[1,1]", "test")
val results = sc.makeRDD(1 to 3).map(x => new NonSerializable)
val thrown = intercept[spark.SparkException] {
@@ -77,7 +85,6 @@ class FailureSuite extends FunSuite {
assert(thrown.getClass === classOf[spark.SparkException])
assert(thrown.getMessage.contains("NotSerializableException"))
- sc.stop()
FailureSuiteState.clear()
}
diff --git a/core/src/test/scala/spark/FileSuite.scala b/core/src/test/scala/spark/FileSuite.scala
index b12014e6be..4cb9c7802f 100644
--- a/core/src/test/scala/spark/FileSuite.scala
+++ b/core/src/test/scala/spark/FileSuite.scala
@@ -6,13 +6,23 @@ import scala.io.Source
import com.google.common.io.Files
import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
import org.apache.hadoop.io._
import SparkContext._
-class FileSuite extends FunSuite {
+class FileSuite extends FunSuite with BeforeAndAfter {
+
+ var sc: SparkContext = _
+
+ after {
+ if(sc != null) {
+ sc.stop()
+ }
+ }
+
test("text files") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
val outputDir = new File(tempDir, "output").getAbsolutePath
val nums = sc.makeRDD(1 to 4)
@@ -23,11 +33,10 @@ class FileSuite extends FunSuite {
assert(content === "1\n2\n3\n4\n")
// Also try reading it in as a text file RDD
assert(sc.textFile(outputDir).collect().toList === List("1", "2", "3", "4"))
- sc.stop()
}
test("SequenceFiles") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
val outputDir = new File(tempDir, "output").getAbsolutePath
val nums = sc.makeRDD(1 to 3).map(x => (x, "a" * x)) // (1,a), (2,aa), (3,aaa)
@@ -35,11 +44,10 @@ class FileSuite extends FunSuite {
// Try reading the output back as a SequenceFile
val output = sc.sequenceFile[IntWritable, Text](outputDir)
assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
- sc.stop()
}
test("SequenceFile with writable key") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
val outputDir = new File(tempDir, "output").getAbsolutePath
val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), "a" * x))
@@ -47,11 +55,10 @@ class FileSuite extends FunSuite {
// Try reading the output back as a SequenceFile
val output = sc.sequenceFile[IntWritable, Text](outputDir)
assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
- sc.stop()
}
test("SequenceFile with writable value") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
val outputDir = new File(tempDir, "output").getAbsolutePath
val nums = sc.makeRDD(1 to 3).map(x => (x, new Text("a" * x)))
@@ -59,11 +66,10 @@ class FileSuite extends FunSuite {
// Try reading the output back as a SequenceFile
val output = sc.sequenceFile[IntWritable, Text](outputDir)
assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
- sc.stop()
}
test("SequenceFile with writable key and value") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
val outputDir = new File(tempDir, "output").getAbsolutePath
val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x)))
@@ -71,11 +77,10 @@ class FileSuite extends FunSuite {
// Try reading the output back as a SequenceFile
val output = sc.sequenceFile[IntWritable, Text](outputDir)
assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
- sc.stop()
}
test("implicit conversions in reading SequenceFiles") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
val outputDir = new File(tempDir, "output").getAbsolutePath
val nums = sc.makeRDD(1 to 3).map(x => (x, "a" * x)) // (1,a), (2,aa), (3,aaa)
@@ -89,11 +94,10 @@ class FileSuite extends FunSuite {
assert(output2.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
val output3 = sc.sequenceFile[IntWritable, String](outputDir)
assert(output3.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
- sc.stop()
}
test("object files of ints") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
val outputDir = new File(tempDir, "output").getAbsolutePath
val nums = sc.makeRDD(1 to 4)
@@ -101,11 +105,10 @@ class FileSuite extends FunSuite {
// Try reading the output back as an object file
val output = sc.objectFile[Int](outputDir)
assert(output.collect().toList === List(1, 2, 3, 4))
- sc.stop()
}
test("object files of complex types") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
val outputDir = new File(tempDir, "output").getAbsolutePath
val nums = sc.makeRDD(1 to 3).map(x => (x, "a" * x))
@@ -113,12 +116,11 @@ class FileSuite extends FunSuite {
// Try reading the output back as an object file
val output = sc.objectFile[(Int, String)](outputDir)
assert(output.collect().toList === List((1, "a"), (2, "aa"), (3, "aaa")))
- sc.stop()
}
test("write SequenceFile using new Hadoop API") {
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
val outputDir = new File(tempDir, "output").getAbsolutePath
val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x)))
@@ -126,12 +128,11 @@ class FileSuite extends FunSuite {
outputDir)
val output = sc.sequenceFile[IntWritable, Text](outputDir)
assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
- sc.stop()
}
test("read SequenceFile using new Hadoop API") {
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
val outputDir = new File(tempDir, "output").getAbsolutePath
val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x)))
@@ -139,6 +140,5 @@ class FileSuite extends FunSuite {
val output =
sc.newAPIHadoopFile[IntWritable, Text, SequenceFileInputFormat[IntWritable, Text]](outputDir)
assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
- sc.stop()
}
}
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
new file mode 100644
index 0000000000..5f0293e55b
--- /dev/null
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -0,0 +1,552 @@
+package spark;
+
+import com.google.common.base.Charsets;
+import com.google.common.io.Files;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.SequenceFileInputFormat;
+import org.apache.hadoop.mapred.SequenceFileOutputFormat;
+import org.apache.hadoop.mapreduce.Job;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import scala.Tuple2;
+
+import spark.api.java.JavaDoubleRDD;
+import spark.api.java.JavaPairRDD;
+import spark.api.java.JavaRDD;
+import spark.api.java.JavaSparkContext;
+import spark.api.java.function.*;
+import spark.partial.BoundedDouble;
+import spark.partial.PartialResult;
+import spark.storage.StorageLevel;
+import spark.util.StatCounter;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.*;
+
+// The test suite itself is Serializable so that anonymous Function implementations can be
+// serialized, as an alternative to converting these anonymous classes to static inner classes;
+// see http://stackoverflow.com/questions/758570/.
+public class JavaAPISuite implements Serializable {
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaAPISuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ static class ReverseIntComparator implements Comparator<Integer>, Serializable {
+
+ @Override
+ public int compare(Integer a, Integer b) {
+ if (a > b) return -1;
+ else if (a < b) return 1;
+ else return 0;
+ }
+ };
+
+ @Test
+ public void sparkContextUnion() {
+ // Union of non-specialized JavaRDDs
+ List<String> strings = Arrays.asList("Hello", "World");
+ JavaRDD<String> s1 = sc.parallelize(strings);
+ JavaRDD<String> s2 = sc.parallelize(strings);
+ // Varargs
+ JavaRDD<String> sUnion = sc.union(s1, s2);
+ Assert.assertEquals(4, sUnion.count());
+ // List
+ List<JavaRDD<String>> list = new ArrayList<JavaRDD<String>>();
+ list.add(s2);
+ sUnion = sc.union(s1, list);
+ Assert.assertEquals(4, sUnion.count());
+
+ // Union of JavaDoubleRDDs
+ List<Double> doubles = Arrays.asList(1.0, 2.0);
+ JavaDoubleRDD d1 = sc.parallelizeDoubles(doubles);
+ JavaDoubleRDD d2 = sc.parallelizeDoubles(doubles);
+ JavaDoubleRDD dUnion = sc.union(d1, d2);
+ Assert.assertEquals(4, dUnion.count());
+
+ // Union of JavaPairRDDs
+ List<Tuple2<Integer, Integer>> pairs = new ArrayList<Tuple2<Integer, Integer>>();
+ pairs.add(new Tuple2<Integer, Integer>(1, 2));
+ pairs.add(new Tuple2<Integer, Integer>(3, 4));
+ JavaPairRDD<Integer, Integer> p1 = sc.parallelizePairs(pairs);
+ JavaPairRDD<Integer, Integer> p2 = sc.parallelizePairs(pairs);
+ JavaPairRDD<Integer, Integer> pUnion = sc.union(p1, p2);
+ Assert.assertEquals(4, pUnion.count());
+ }
+
+ @Test
+ public void sortByKey() {
+ List<Tuple2<Integer, Integer>> pairs = new ArrayList<Tuple2<Integer, Integer>>();
+ pairs.add(new Tuple2<Integer, Integer>(0, 4));
+ pairs.add(new Tuple2<Integer, Integer>(3, 2));
+ pairs.add(new Tuple2<Integer, Integer>(-1, 1));
+
+ JavaPairRDD<Integer, Integer> rdd = sc.parallelizePairs(pairs);
+
+ // Default comparator
+ JavaPairRDD<Integer, Integer> sortedRDD = rdd.sortByKey();
+ Assert.assertEquals(new Tuple2<Integer, Integer>(-1, 1), sortedRDD.first());
+ List<Tuple2<Integer, Integer>> sortedPairs = sortedRDD.collect();
+ Assert.assertEquals(new Tuple2<Integer, Integer>(0, 4), sortedPairs.get(1));
+ Assert.assertEquals(new Tuple2<Integer, Integer>(3, 2), sortedPairs.get(2));
+
+ // Custom comparator
+ sortedRDD = rdd.sortByKey(new ReverseIntComparator(), false);
+ Assert.assertEquals(new Tuple2<Integer, Integer>(-1, 1), sortedRDD.first());
+ sortedPairs = sortedRDD.collect();
+ Assert.assertEquals(new Tuple2<Integer, Integer>(0, 4), sortedPairs.get(1));
+ Assert.assertEquals(new Tuple2<Integer, Integer>(3, 2), sortedPairs.get(2));
+ }
+
+ @Test
+ public void foreach() {
+ JavaRDD<String> rdd = sc.parallelize(Arrays.asList("Hello", "World"));
+ rdd.foreach(new VoidFunction<String>() {
+ @Override
+ public void apply(String s) {
+ System.out.println(s);
+ }
+ });
+ }
+
+ @Test
+ public void groupBy() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13));
+    Function<Integer, Boolean> isEven = new Function<Integer, Boolean>() {
+ @Override
+ public Boolean apply(Integer x) {
+ return x % 2 == 0;
+ }
+ };
+    JavaPairRDD<Boolean, List<Integer>> oddsAndEvens = rdd.groupBy(isEven);
+ Assert.assertEquals(2, oddsAndEvens.count());
+ Assert.assertEquals(2, oddsAndEvens.lookup(true).get(0).size()); // Evens
+ Assert.assertEquals(5, oddsAndEvens.lookup(false).get(0).size()); // Odds
+
+    oddsAndEvens = rdd.groupBy(isEven, 1);
+ Assert.assertEquals(2, oddsAndEvens.count());
+ Assert.assertEquals(2, oddsAndEvens.lookup(true).get(0).size()); // Evens
+ Assert.assertEquals(5, oddsAndEvens.lookup(false).get(0).size()); // Odds
+ }
+
+ @Test
+ public void cogroup() {
+ JavaPairRDD<String, String> categories = sc.parallelizePairs(Arrays.asList(
+ new Tuple2<String, String>("Apples", "Fruit"),
+ new Tuple2<String, String>("Oranges", "Fruit"),
+ new Tuple2<String, String>("Oranges", "Citrus")
+ ));
+ JavaPairRDD<String, Integer> prices = sc.parallelizePairs(Arrays.asList(
+ new Tuple2<String, Integer>("Oranges", 2),
+ new Tuple2<String, Integer>("Apples", 3)
+ ));
+ JavaPairRDD<String, Tuple2<List<String>, List<Integer>>> cogrouped = categories.cogroup(prices);
+ Assert.assertEquals("[Fruit, Citrus]", cogrouped.lookup("Oranges").get(0)._1().toString());
+ Assert.assertEquals("[2]", cogrouped.lookup("Oranges").get(0)._2().toString());
+
+ cogrouped.collect();
+ }
+
+ @Test
+ public void foldReduce() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13));
+ Function2<Integer, Integer, Integer> add = new Function2<Integer, Integer, Integer>() {
+ @Override
+ public Integer apply(Integer a, Integer b) {
+ return a + b;
+ }
+ };
+
+ int sum = rdd.fold(0, add);
+ Assert.assertEquals(33, sum);
+
+ sum = rdd.reduce(add);
+ Assert.assertEquals(33, sum);
+ }
+
+ @Test
+ public void reduceByKey() {
+ List<Tuple2<Integer, Integer>> pairs = Arrays.asList(
+ new Tuple2<Integer, Integer>(2, 1),
+ new Tuple2<Integer, Integer>(2, 1),
+ new Tuple2<Integer, Integer>(1, 1),
+ new Tuple2<Integer, Integer>(3, 2),
+ new Tuple2<Integer, Integer>(3, 1)
+ );
+ JavaPairRDD<Integer, Integer> rdd = sc.parallelizePairs(pairs);
+ JavaPairRDD<Integer, Integer> counts = rdd.reduceByKey(
+ new Function2<Integer, Integer, Integer>() {
+ @Override
+ public Integer apply(Integer a, Integer b) {
+ return a + b;
+ }
+ });
+ Assert.assertEquals(1, counts.lookup(1).get(0).intValue());
+ Assert.assertEquals(2, counts.lookup(2).get(0).intValue());
+ Assert.assertEquals(3, counts.lookup(3).get(0).intValue());
+
+ Map<Integer, Integer> localCounts = counts.collectAsMap();
+ Assert.assertEquals(1, localCounts.get(1).intValue());
+ Assert.assertEquals(2, localCounts.get(2).intValue());
+ Assert.assertEquals(3, localCounts.get(3).intValue());
+
+ localCounts = rdd.reduceByKeyLocally(new Function2<Integer, Integer,
+ Integer>() {
+ @Override
+ public Integer apply(Integer a, Integer b) {
+ return a + b;
+ }
+ });
+ Assert.assertEquals(1, localCounts.get(1).intValue());
+ Assert.assertEquals(2, localCounts.get(2).intValue());
+ Assert.assertEquals(3, localCounts.get(3).intValue());
+ }
+
+ @Test
+ public void approximateResults() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13));
+ Map<Integer, Long> countsByValue = rdd.countByValue();
+ Assert.assertEquals(2, countsByValue.get(1).longValue());
+ Assert.assertEquals(1, countsByValue.get(13).longValue());
+
+ PartialResult<Map<Integer, BoundedDouble>> approx = rdd.countByValueApprox(1);
+ Map<Integer, BoundedDouble> finalValue = approx.getFinalValue();
+ Assert.assertEquals(2.0, finalValue.get(1).mean(), 0.01);
+ Assert.assertEquals(1.0, finalValue.get(13).mean(), 0.01);
+ }
+
+ @Test
+ public void take() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13));
+ Assert.assertEquals(1, rdd.first().intValue());
+ List<Integer> firstTwo = rdd.take(2);
+ List<Integer> sample = rdd.takeSample(false, 2, 42);
+ }
+
+ @Test
+ public void cartesian() {
+ JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0));
+ JavaRDD<String> stringRDD = sc.parallelize(Arrays.asList("Hello", "World"));
+ JavaPairRDD<String, Double> cartesian = stringRDD.cartesian(doubleRDD);
+ Assert.assertEquals(new Tuple2<String, Double>("Hello", 1.0), cartesian.first());
+ }
+
+ @Test
+ public void javaDoubleRDD() {
+ JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0));
+ JavaDoubleRDD distinct = rdd.distinct();
+ Assert.assertEquals(5, distinct.count());
+ JavaDoubleRDD filter = rdd.filter(new Function<Double, Boolean>() {
+ @Override
+ public Boolean apply(Double x) {
+ return x > 2.0;
+ }
+ });
+ Assert.assertEquals(3, filter.count());
+ JavaDoubleRDD union = rdd.union(rdd);
+ Assert.assertEquals(12, union.count());
+ union = union.cache();
+ Assert.assertEquals(12, union.count());
+
+ Assert.assertEquals(20, rdd.sum(), 0.01);
+ StatCounter stats = rdd.stats();
+ Assert.assertEquals(20, stats.sum(), 0.01);
+    Assert.assertEquals(20/6.0, rdd.mean(), 0.01);
+ Assert.assertEquals(6.22222, rdd.variance(), 0.01);
+ Assert.assertEquals(2.49444, rdd.stdev(), 0.01);
+
+ Double first = rdd.first();
+ List<Double> take = rdd.take(5);
+ }
+
+ @Test
+ public void map() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+ JavaDoubleRDD doubles = rdd.map(new DoubleFunction<Integer>() {
+ @Override
+ public Double apply(Integer x) {
+ return 1.0 * x;
+ }
+ }).cache();
+ JavaPairRDD<Integer, Integer> pairs = rdd.map(new PairFunction<Integer, Integer, Integer>() {
+ @Override
+ public Tuple2<Integer, Integer> apply(Integer x) {
+ return new Tuple2<Integer, Integer>(x, x);
+ }
+ }).cache();
+ JavaRDD<String> strings = rdd.map(new Function<Integer, String>() {
+ @Override
+ public String apply(Integer x) {
+ return x.toString();
+ }
+ }).cache();
+ }
+
+ @Test
+ public void flatMap() {
+ JavaRDD<String> rdd = sc.parallelize(Arrays.asList("Hello World!",
+ "The quick brown fox jumps over the lazy dog."));
+ JavaRDD<String> words = rdd.flatMap(new FlatMapFunction<String, String>() {
+ @Override
+ public Iterable<String> apply(String x) {
+ return Arrays.asList(x.split(" "));
+ }
+ });
+ Assert.assertEquals("Hello", words.first());
+ Assert.assertEquals(11, words.count());
+
+ JavaPairRDD<String, String> pairs = rdd.flatMap(
+ new PairFlatMapFunction<String, String, String>() {
+
+ @Override
+ public Iterable<Tuple2<String, String>> apply(String s) {
+ List<Tuple2<String, String>> pairs = new LinkedList<Tuple2<String, String>>();
+ for (String word : s.split(" ")) pairs.add(new Tuple2<String, String>(word, word));
+ return pairs;
+ }
+ }
+ );
+ Assert.assertEquals(new Tuple2<String, String>("Hello", "Hello"), pairs.first());
+ Assert.assertEquals(11, pairs.count());
+
+ JavaDoubleRDD doubles = rdd.flatMap(new DoubleFlatMapFunction<String>() {
+ @Override
+ public Iterable<Double> apply(String s) {
+ List<Double> lengths = new LinkedList<Double>();
+ for (String word : s.split(" ")) lengths.add(word.length() * 1.0);
+ return lengths;
+ }
+ });
+ Double x = doubles.first();
+ Assert.assertEquals(5.0, doubles.first().doubleValue(), 0.01);
+    Assert.assertEquals(11, doubles.count());
+ }
+
+ @Test
+ public void mapPartitions() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
+ JavaRDD<Integer> partitionSums = rdd.mapPartitions(
+ new FlatMapFunction<Iterator<Integer>, Integer>() {
+ @Override
+ public Iterable<Integer> apply(Iterator<Integer> iter) {
+ int sum = 0;
+ while (iter.hasNext()) {
+ sum += iter.next();
+ }
+ return Collections.singletonList(sum);
+ }
+ });
+ Assert.assertEquals("[3, 7]", partitionSums.collect().toString());
+ }
+
+ @Test
+ public void persist() {
+ JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0));
+ doubleRDD = doubleRDD.persist(StorageLevel.DISK_ONLY());
+ Assert.assertEquals(20, doubleRDD.sum(), 0.1);
+
+ List<Tuple2<Integer, String>> pairs = Arrays.asList(
+ new Tuple2<Integer, String>(1, "a"),
+ new Tuple2<Integer, String>(2, "aa"),
+ new Tuple2<Integer, String>(3, "aaa")
+ );
+ JavaPairRDD<Integer, String> pairRDD = sc.parallelizePairs(pairs);
+ pairRDD = pairRDD.persist(StorageLevel.DISK_ONLY());
+ Assert.assertEquals("a", pairRDD.first()._2());
+
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+ rdd = rdd.persist(StorageLevel.DISK_ONLY());
+ Assert.assertEquals(1, rdd.first().intValue());
+ }
+
+ @Test
+ public void iterator() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
+ Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0)).next().intValue());
+ }
+
+ @Test
+ public void glom() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
+ Assert.assertEquals("[1, 2]", rdd.glom().first().toString());
+ }
+
+ // File input / output tests are largely adapted from FileSuite:
+
+ @Test
+ public void textFiles() throws IOException {
+ File tempDir = Files.createTempDir();
+ String outputDir = new File(tempDir, "output").getAbsolutePath();
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4));
+ rdd.saveAsTextFile(outputDir);
+ // Read the plain text file and check it's OK
+ File outputFile = new File(outputDir, "part-00000");
+ String content = Files.toString(outputFile, Charsets.UTF_8);
+ Assert.assertEquals("1\n2\n3\n4\n", content);
+ // Also try reading it in as a text file RDD
+ List<String> expected = Arrays.asList("1", "2", "3", "4");
+ JavaRDD<String> readRDD = sc.textFile(outputDir);
+ Assert.assertEquals(expected, readRDD.collect());
+ }
+
+ @Test
+ public void sequenceFile() {
+ File tempDir = Files.createTempDir();
+ String outputDir = new File(tempDir, "output").getAbsolutePath();
+ List<Tuple2<Integer, String>> pairs = Arrays.asList(
+ new Tuple2<Integer, String>(1, "a"),
+ new Tuple2<Integer, String>(2, "aa"),
+ new Tuple2<Integer, String>(3, "aaa")
+ );
+ JavaPairRDD<Integer, String> rdd = sc.parallelizePairs(pairs);
+
+ rdd.map(new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
+ @Override
+ public Tuple2<IntWritable, Text> apply(Tuple2<Integer, String> pair) {
+ return new Tuple2<IntWritable, Text>(new IntWritable(pair._1()), new Text(pair._2()));
+ }
+ }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class);
+
+    // Try reading the output back as a sequence file
+ JavaPairRDD<Integer, String> readRDD = sc.sequenceFile(outputDir, IntWritable.class,
+ Text.class).map(new PairFunction<Tuple2<IntWritable, Text>, Integer, String>() {
+ @Override
+ public Tuple2<Integer, String> apply(Tuple2<IntWritable, Text> pair) {
+ return new Tuple2<Integer, String>(pair._1().get(), pair._2().toString());
+ }
+ });
+ Assert.assertEquals(pairs, readRDD.collect());
+ }
+
+ @Test
+ public void writeWithNewAPIHadoopFile() {
+ File tempDir = Files.createTempDir();
+ String outputDir = new File(tempDir, "output").getAbsolutePath();
+ List<Tuple2<Integer, String>> pairs = Arrays.asList(
+ new Tuple2<Integer, String>(1, "a"),
+ new Tuple2<Integer, String>(2, "aa"),
+ new Tuple2<Integer, String>(3, "aaa")
+ );
+ JavaPairRDD<Integer, String> rdd = sc.parallelizePairs(pairs);
+
+ rdd.map(new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
+ @Override
+ public Tuple2<IntWritable, Text> apply(Tuple2<Integer, String> pair) {
+ return new Tuple2<IntWritable, Text>(new IntWritable(pair._1()), new Text(pair._2()));
+ }
+ }).saveAsNewAPIHadoopFile(outputDir, IntWritable.class, Text.class,
+ org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class);
+
+ JavaPairRDD<IntWritable, Text> output = sc.sequenceFile(outputDir, IntWritable.class,
+ Text.class);
+ Assert.assertEquals(pairs.toString(), output.map(new Function<Tuple2<IntWritable, Text>,
+ String>() {
+ @Override
+ public String apply(Tuple2<IntWritable, Text> x) {
+ return x.toString();
+ }
+ }).collect().toString());
+ }
+
+ @Test
+ public void readWithNewAPIHadoopFile() throws IOException {
+ File tempDir = Files.createTempDir();
+ String outputDir = new File(tempDir, "output").getAbsolutePath();
+ List<Tuple2<Integer, String>> pairs = Arrays.asList(
+ new Tuple2<Integer, String>(1, "a"),
+ new Tuple2<Integer, String>(2, "aa"),
+ new Tuple2<Integer, String>(3, "aaa")
+ );
+ JavaPairRDD<Integer, String> rdd = sc.parallelizePairs(pairs);
+
+ rdd.map(new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
+ @Override
+ public Tuple2<IntWritable, Text> apply(Tuple2<Integer, String> pair) {
+ return new Tuple2<IntWritable, Text>(new IntWritable(pair._1()), new Text(pair._2()));
+ }
+ }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class);
+
+ JavaPairRDD<IntWritable, Text> output = sc.newAPIHadoopFile(outputDir,
+ org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, IntWritable.class,
+ Text.class, new Job().getConfiguration());
+ Assert.assertEquals(pairs.toString(), output.map(new Function<Tuple2<IntWritable, Text>,
+ String>() {
+ @Override
+ public String apply(Tuple2<IntWritable, Text> x) {
+ return x.toString();
+ }
+ }).collect().toString());
+ }
+
+ @Test
+ public void objectFilesOfInts() {
+ File tempDir = Files.createTempDir();
+ String outputDir = new File(tempDir, "output").getAbsolutePath();
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4));
+ rdd.saveAsObjectFile(outputDir);
+ // Try reading the output back as an object file
+ List<Integer> expected = Arrays.asList(1, 2, 3, 4);
+ JavaRDD<Integer> readRDD = sc.objectFile(outputDir);
+ Assert.assertEquals(expected, readRDD.collect());
+ }
+
+ @Test
+ public void objectFilesOfComplexTypes() {
+ File tempDir = Files.createTempDir();
+ String outputDir = new File(tempDir, "output").getAbsolutePath();
+ List<Tuple2<Integer, String>> pairs = Arrays.asList(
+ new Tuple2<Integer, String>(1, "a"),
+ new Tuple2<Integer, String>(2, "aa"),
+ new Tuple2<Integer, String>(3, "aaa")
+ );
+ JavaPairRDD<Integer, String> rdd = sc.parallelizePairs(pairs);
+ rdd.saveAsObjectFile(outputDir);
+ // Try reading the output back as an object file
+ JavaRDD<Tuple2<Integer, String>> readRDD = sc.objectFile(outputDir);
+ Assert.assertEquals(pairs, readRDD.collect());
+ }
+
+ @Test
+ public void hadoopFile() {
+ File tempDir = Files.createTempDir();
+ String outputDir = new File(tempDir, "output").getAbsolutePath();
+ List<Tuple2<Integer, String>> pairs = Arrays.asList(
+ new Tuple2<Integer, String>(1, "a"),
+ new Tuple2<Integer, String>(2, "aa"),
+ new Tuple2<Integer, String>(3, "aaa")
+ );
+ JavaPairRDD<Integer, String> rdd = sc.parallelizePairs(pairs);
+
+ rdd.map(new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
+ @Override
+ public Tuple2<IntWritable, Text> apply(Tuple2<Integer, String> pair) {
+ return new Tuple2<IntWritable, Text>(new IntWritable(pair._1()), new Text(pair._2()));
+ }
+ }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class);
+
+ JavaPairRDD<IntWritable, Text> output = sc.hadoopFile(outputDir,
+ SequenceFileInputFormat.class, IntWritable.class, Text.class);
+ Assert.assertEquals(pairs.toString(), output.map(new Function<Tuple2<IntWritable, Text>,
+ String>() {
+ @Override
+ public String apply(Tuple2<IntWritable, Text> x) {
+ return x.toString();
+ }
+ }).collect().toString());
+ }
+}
diff --git a/core/src/test/scala/spark/KryoSerializerSuite.scala b/core/src/test/scala/spark/KryoSerializerSuite.scala
index 06d446ea24..e889769b9a 100644
--- a/core/src/test/scala/spark/KryoSerializerSuite.scala
+++ b/core/src/test/scala/spark/KryoSerializerSuite.scala
@@ -8,7 +8,8 @@ import com.esotericsoftware.kryo._
import SparkContext._
-class KryoSerializerSuite extends FunSuite {
+class KryoSerializerSuite extends FunSuite {
+
test("basic types") {
val ser = (new KryoSerializer).newInstance()
def check[T](t: T) {
diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala
index 7f7f9493dc..cf2ffeb9b1 100644
--- a/core/src/test/scala/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/spark/PartitioningSuite.scala
@@ -1,12 +1,23 @@
package spark
import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
import scala.collection.mutable.ArrayBuffer
import SparkContext._
-class PartitioningSuite extends FunSuite {
+class PartitioningSuite extends FunSuite with BeforeAndAfter {
+
+ var sc: SparkContext = _
+
+ after {
+    if (sc != null) {
+ sc.stop()
+ }
+ }
+
+
test("HashPartitioner equality") {
val p2 = new HashPartitioner(2)
val p4 = new HashPartitioner(4)
@@ -20,7 +31,7 @@ class PartitioningSuite extends FunSuite {
}
test("RangePartitioner equality") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
// Make an RDD where all the elements are the same so that the partition range bounds
// are deterministically all the same.
@@ -46,12 +57,10 @@ class PartitioningSuite extends FunSuite {
assert(p4 != descendingP4)
assert(descendingP2 != p2)
assert(descendingP4 != p4)
-
- sc.stop()
}
test("HashPartitioner not equal to RangePartitioner") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val rdd = sc.parallelize(1 to 10).map(x => (x, x))
val rangeP2 = new RangePartitioner(2, rdd)
val hashP2 = new HashPartitioner(2)
@@ -59,11 +68,10 @@ class PartitioningSuite extends FunSuite {
assert(hashP2 === hashP2)
assert(hashP2 != rangeP2)
assert(rangeP2 != hashP2)
- sc.stop()
}
test("partitioner preservation") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val rdd = sc.parallelize(1 to 10, 4).map(x => (x, x))
@@ -95,7 +103,5 @@ class PartitioningSuite extends FunSuite {
assert(grouped2.leftOuterJoin(reduced2).partitioner === grouped2.partitioner)
assert(grouped2.rightOuterJoin(reduced2).partitioner === grouped2.partitioner)
assert(grouped2.cogroup(reduced2).partitioner === grouped2.partitioner)
-
- sc.stop()
}
}
diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala
index d5dc2efd91..db1b9835a0 100644
--- a/core/src/test/scala/spark/PipedRDDSuite.scala
+++ b/core/src/test/scala/spark/PipedRDDSuite.scala
@@ -1,12 +1,21 @@
package spark
import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
import SparkContext._
-class PipedRDDSuite extends FunSuite {
-
+class PipedRDDSuite extends FunSuite with BeforeAndAfter {
+
+ var sc: SparkContext = _
+
+ after {
+    if (sc != null) {
+ sc.stop()
+ }
+ }
+
test("basic pipe") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val piped = nums.pipe(Seq("cat"))
@@ -18,18 +27,16 @@ class PipedRDDSuite extends FunSuite {
assert(c(1) === "2")
assert(c(2) === "3")
assert(c(3) === "4")
- sc.stop()
}
test("pipe with env variable") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA"))
val c = piped.collect()
assert(c.size === 2)
assert(c(0) === "LALALA")
assert(c(1) === "LALALA")
- sc.stop()
}
}
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index 7199b634b7..3924a6890b 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -2,11 +2,21 @@ package spark
import scala.collection.mutable.HashMap
import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
import SparkContext._
-class RDDSuite extends FunSuite {
+class RDDSuite extends FunSuite with BeforeAndAfter {
+
+ var sc: SparkContext = _
+
+ after {
+    if (sc != null) {
+ sc.stop()
+ }
+ }
+
test("basic operations") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
assert(nums.collect().toList === List(1, 2, 3, 4))
assert(nums.reduce(_ + _) === 10)
@@ -18,11 +28,10 @@ class RDDSuite extends FunSuite {
assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4)))
val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _)))
assert(partitionSums.collect().toList === List(3, 7))
- sc.stop()
}
test("aggregate") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3)))
type StringMap = HashMap[String, Int]
val emptyMap = new StringMap {
@@ -40,6 +49,5 @@ class RDDSuite extends FunSuite {
}
val result = pairs.aggregate(emptyMap)(mergeElement, mergeMaps)
assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
- sc.stop()
}
}
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 00b24464a6..5fa494160f 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -1,6 +1,7 @@
package spark
import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
import org.scalatest.prop.Checkers
import org.scalacheck.Arbitrary._
import org.scalacheck.Gen
@@ -12,9 +13,18 @@ import scala.collection.mutable.ArrayBuffer
import SparkContext._
-class ShuffleSuite extends FunSuite {
+class ShuffleSuite extends FunSuite with BeforeAndAfter {
+
+ var sc: SparkContext = _
+
+ after {
+    if (sc != null) {
+ sc.stop()
+ }
+ }
+
test("groupByKey") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
val groups = pairs.groupByKey().collect()
assert(groups.size === 2)
@@ -22,11 +32,10 @@ class ShuffleSuite extends FunSuite {
assert(valuesFor1.toList.sorted === List(1, 2, 3))
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
- sc.stop()
}
test("groupByKey with duplicates") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
val groups = pairs.groupByKey().collect()
assert(groups.size === 2)
@@ -34,11 +43,10 @@ class ShuffleSuite extends FunSuite {
assert(valuesFor1.toList.sorted === List(1, 1, 2, 3))
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
- sc.stop()
}
test("groupByKey with negative key hash codes") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1)))
val groups = pairs.groupByKey().collect()
assert(groups.size === 2)
@@ -46,11 +54,10 @@ class ShuffleSuite extends FunSuite {
assert(valuesForMinus1.toList.sorted === List(1, 2, 3))
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
- sc.stop()
}
test("groupByKey with many output partitions") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
val groups = pairs.groupByKey(10).collect()
assert(groups.size === 2)
@@ -58,37 +65,33 @@ class ShuffleSuite extends FunSuite {
assert(valuesFor1.toList.sorted === List(1, 2, 3))
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
- sc.stop()
}
test("reduceByKey") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
val sums = pairs.reduceByKey(_+_).collect()
assert(sums.toSet === Set((1, 7), (2, 1)))
- sc.stop()
}
test("reduceByKey with collectAsMap") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
val sums = pairs.reduceByKey(_+_).collectAsMap()
assert(sums.size === 2)
assert(sums(1) === 7)
assert(sums(2) === 1)
- sc.stop()
}
test("reduceByKey with many output partitons") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
val sums = pairs.reduceByKey(_+_, 10).collect()
assert(sums.toSet === Set((1, 7), (2, 1)))
- sc.stop()
}
test("join") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val joined = rdd1.join(rdd2).collect()
@@ -99,11 +102,10 @@ class ShuffleSuite extends FunSuite {
(2, (1, 'y')),
(2, (1, 'z'))
))
- sc.stop()
}
test("join all-to-all") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3)))
val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y')))
val joined = rdd1.join(rdd2).collect()
@@ -116,11 +118,10 @@ class ShuffleSuite extends FunSuite {
(1, (3, 'x')),
(1, (3, 'y'))
))
- sc.stop()
}
test("leftOuterJoin") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val joined = rdd1.leftOuterJoin(rdd2).collect()
@@ -132,11 +133,10 @@ class ShuffleSuite extends FunSuite {
(2, (1, Some('z'))),
(3, (1, None))
))
- sc.stop()
}
test("rightOuterJoin") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val joined = rdd1.rightOuterJoin(rdd2).collect()
@@ -148,20 +148,18 @@ class ShuffleSuite extends FunSuite {
(2, (Some(1), 'z')),
(4, (None, 'w'))
))
- sc.stop()
}
test("join with no matches") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w')))
val joined = rdd1.join(rdd2).collect()
assert(joined.size === 0)
- sc.stop()
}
test("join with many output partitions") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val joined = rdd1.join(rdd2, 10).collect()
@@ -172,11 +170,10 @@ class ShuffleSuite extends FunSuite {
(2, (1, 'y')),
(2, (1, 'z'))
))
- sc.stop()
}
test("groupWith") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val joined = rdd1.groupWith(rdd2).collect()
@@ -187,17 +184,15 @@ class ShuffleSuite extends FunSuite {
(3, (ArrayBuffer(1), ArrayBuffer())),
(4, (ArrayBuffer(), ArrayBuffer('w')))
))
- sc.stop()
}
test("zero-partition RDD") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val emptyDir = Files.createTempDir()
val file = sc.textFile(emptyDir.getAbsolutePath)
assert(file.splits.size == 0)
assert(file.collect().toList === Nil)
// Test that a shuffle on the file works, because this used to be a bug
- assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
- sc.stop()
- }
+ assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
+ }
}
diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala
index caff884966..d2dd514edb 100644
--- a/core/src/test/scala/spark/SortingSuite.scala
+++ b/core/src/test/scala/spark/SortingSuite.scala
@@ -1,50 +1,55 @@
package spark
import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
import SparkContext._
-class SortingSuite extends FunSuite {
+class SortingSuite extends FunSuite with BeforeAndAfter {
+
+ var sc: SparkContext = _
+
+ after {
+    if (sc != null) {
+ sc.stop()
+ }
+ }
+
test("sortByKey") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)))
- assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0)))
- sc.stop()
+ assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0)))
}
test("sortLargeArray") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr)
assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
- sc.stop()
}
test("sortDescending") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr)
assert(pairs.sortByKey(false).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
- sc.stop()
}
test("morePartitionsThanElements") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 30)
assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
- sc.stop()
}
test("emptyRDD") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = new Array[(Int, Int)](0)
val pairs = sc.parallelize(pairArr)
assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
- sc.stop()
}
}
diff --git a/core/src/test/scala/spark/ThreadingSuite.scala b/core/src/test/scala/spark/ThreadingSuite.scala
index d38e72d8b8..90409a54ec 100644
--- a/core/src/test/scala/spark/ThreadingSuite.scala
+++ b/core/src/test/scala/spark/ThreadingSuite.scala
@@ -5,6 +5,7 @@ import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger
import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
import SparkContext._
@@ -21,9 +22,19 @@ object ThreadingSuiteState {
}
}
-class ThreadingSuite extends FunSuite {
+class ThreadingSuite extends FunSuite with BeforeAndAfter {
+
+ var sc: SparkContext = _
+
+ after {
+    if (sc != null) {
+ sc.stop()
+ }
+ }
+
+
test("accessing SparkContext form a different thread") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val nums = sc.parallelize(1 to 10, 2)
val sem = new Semaphore(0)
@volatile var answer1: Int = 0
@@ -38,11 +49,10 @@ class ThreadingSuite extends FunSuite {
sem.acquire()
assert(answer1 === 55)
assert(answer2 === 1)
- sc.stop()
}
test("accessing SparkContext form multiple threads") {
- val sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val nums = sc.parallelize(1 to 10, 2)
val sem = new Semaphore(0)
@volatile var ok = true
@@ -67,11 +77,10 @@ class ThreadingSuite extends FunSuite {
if (!ok) {
fail("One or more threads got the wrong answer from an RDD operation")
}
- sc.stop()
}
test("accessing multi-threaded SparkContext form multiple threads") {
- val sc = new SparkContext("local[4]", "test")
+ sc = new SparkContext("local[4]", "test")
val nums = sc.parallelize(1 to 10, 2)
val sem = new Semaphore(0)
@volatile var ok = true
@@ -96,13 +105,12 @@ class ThreadingSuite extends FunSuite {
if (!ok) {
fail("One or more threads got the wrong answer from an RDD operation")
}
- sc.stop()
}
test("parallel job execution") {
// This test launches two jobs with two threads each on a 4-core local cluster. Each thread
// waits until there are 4 threads running at once, to test that both jobs have been launched.
- val sc = new SparkContext("local[4]", "test")
+ sc = new SparkContext("local[4]", "test")
val nums = sc.parallelize(1 to 2, 2)
val sem = new Semaphore(0)
ThreadingSuiteState.clear()
@@ -132,6 +140,5 @@ class ThreadingSuite extends FunSuite {
if (ThreadingSuiteState.failed.get()) {
fail("One or more threads didn't see runningThreads = 4")
}
- sc.stop()
}
}