-rw-r--r--   core/src/main/scala/spark/RDD.scala        33
-rw-r--r--   core/src/test/scala/spark/RDDSuite.scala   26
2 files changed, 49 insertions, 10 deletions
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 1d86062012..624e56582d 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -80,7 +80,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) {
     }
   }
 
-  // Transformations
+  // Transformations (return a new RDD)
 
   def map[U: ClassManifest](f: T => U): RDD[U] = new MappedRDD(this, sc.clean(f))
 
@@ -89,23 +89,25 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) {
   def filter(f: T => Boolean): RDD[T] = new FilteredRDD(this, sc.clean(f))
 
-  def sample(withReplacement: Boolean, frac: Double, seed: Int): RDD[T] =
-    new SampledRDD(this, withReplacement, frac, seed)
+  def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] =
+    new SampledRDD(this, withReplacement, fraction, seed)
 
   def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other))
 
   def ++(other: RDD[T]): RDD[T] = this.union(other)
 
-  def glom(): RDD[Array[T]] = new SplitRDD(this)
+  def glom(): RDD[Array[T]] = new GlommedRDD(this)
 
   def cartesian[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other)
 
-  def groupBy[K: ClassManifest](func: T => K, numSplits: Int): RDD[(K, Seq[T])] =
-    this.map(t => (func(t), t)).groupByKey(numSplits)
+  def groupBy[K: ClassManifest](f: T => K, numSplits: Int): RDD[(K, Seq[T])] = {
+    val cleanF = sc.clean(f)
+    this.map(t => (cleanF(t), t)).groupByKey(numSplits)
+  }
 
-  def groupBy[K: ClassManifest](func: T => K): RDD[(K, Seq[T])] =
-    groupBy[K](func, sc.numCores)
+  def groupBy[K: ClassManifest](f: T => K): RDD[(K, Seq[T])] =
+    groupBy[K](f, sc.numCores)
 
   def pipe(command: String): RDD[String] = new PipedRDD(this, command)
 
@@ -113,7 +115,10 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) {
   def pipe(command: Seq[String]): RDD[String] = new PipedRDD(this, command)
 
-  // Parallel operations
+  def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U]): RDD[U] =
+    new MapPartitionsRDD(this, sc.clean(f))
+
+  // Actions (launch a job to return a value to the user program)
 
   def foreach(f: T => Unit) {
     val cleanF = sc.clean(f)
@@ -217,9 +222,17 @@ extends RDD[T](prev.context) {
   override def compute(split: Split) = prev.iterator(split).filter(f)
 }
 
-class SplitRDD[T: ClassManifest](prev: RDD[T])
+class GlommedRDD[T: ClassManifest](prev: RDD[T])
 extends RDD[Array[T]](prev.context) {
   override def splits = prev.splits
   override val dependencies = List(new OneToOneDependency(prev))
   override def compute(split: Split) = Array(prev.iterator(split).toArray).iterator
 }
+
+class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
+    prev: RDD[T], f: Iterator[T] => Iterator[U])
+extends RDD[U](prev.context) {
+  override def splits = prev.splits
+  override val dependencies = List(new OneToOneDependency(prev))
+  override def compute(split: Split) = f(prev.iterator(split))
+}
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
new file mode 100644
index 0000000000..d31fdb7f8a
--- /dev/null
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -0,0 +1,26 @@
+package spark
+
+import org.scalatest.FunSuite
+import org.scalatest.prop.Checkers
+import org.scalacheck.Arbitrary._
+import org.scalacheck.Gen
+import org.scalacheck.Prop._
+import SparkContext._
+import scala.collection.mutable.ArrayBuffer
+
+class RDDSuite extends FunSuite {
+  test("basic operations") {
+    val sc = new SparkContext("local", "test")
+    val nums = sc.parallelize(Array(1, 2, 3, 4), 2)
+    assert(nums.collect().toList === List(1, 2, 3, 4))
+    assert(nums.reduce(_ + _) === 10)
+    assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4"))
+    assert(nums.filter(_ > 2).collect().toList === List(3, 4))
+    assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4))
+    assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4))
+    assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4)))
+    val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _)))
+    assert(partitionSums.collect().toList === List(3, 7))
+    sc.stop()
+  }
+}
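For context, a minimal usage sketch of the new mapPartitions transformation (not part of this commit), assuming only the API visible in this diff: a local SparkContext as in the test above, parallelize, mapPartitions, and collect. The application name and variable names are hypothetical, and the per-partition counts depend on how parallelize slices the input:

// Hypothetical sketch: count the elements in each partition without a shuffle,
// by running the supplied function once per partition instead of once per element.
val sc = new SparkContext("local", "mapPartitionsExample")
val data = sc.parallelize(1 to 100, 4)                            // 4 splits
val perPartitionCounts = data.mapPartitions(iter => Iterator(iter.size))
println(perPartitionCounts.collect().toList)                      // e.g. List(25, 25, 25, 25)
sc.stop()

The groupBy change above mirrors what map and filter already do: the user-supplied key function is routed through sc.clean(f) before being captured in the map closure, presumably so that closure cleaning is applied to the key function itself before it is shipped to tasks.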