author    Matei Zaharia <matei@eecs.berkeley.edu>  2011-07-13 00:19:52 -0400
committer Matei Zaharia <matei@eecs.berkeley.edu>  2011-07-13 00:19:52 -0400
commit    842e14d567a436ef1e6975f51e8845a448717688 (patch)
tree      c36c05de0776d21101cde6c4367419d61e8a5d2b /core
parent    d05fea24f3c697e3b62f10fb794e9d51ef4441ea (diff)
Added mapPartitions operation and a bunch of tests for RDD ops
Diffstat (limited to 'core')
 core/src/main/scala/spark/RDD.scala      | 33 ++++++++++++-----
 core/src/test/scala/spark/RDDSuite.scala | 26 +++++++++++++
 2 files changed, 49 insertions(+), 10 deletions(-)
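
For orientation before the diff: the headline addition is mapPartitions, which applies a function once per partition rather than once per element. A minimal usage sketch, assuming the same local-mode SparkContext the new test suite uses (the variable names here are illustrative, not from this commit):

```scala
// Sum each partition in a single pass over its iterator.
// With two splits of Array(1, 2, 3, 4), this yields List(3, 7).
val sc = new SparkContext("local", "example")
val nums = sc.parallelize(Array(1, 2, 3, 4), 2)
val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _)))
println(partitionSums.collect().toList)
sc.stop()
```

This is the same pattern the new RDDSuite exercises below.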
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 1d86062012..624e56582d 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -80,7 +80,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) {
     }
   }
 
-  // Transformations
+  // Transformations (return a new RDD)
 
   def map[U: ClassManifest](f: T => U): RDD[U] = new MappedRDD(this, sc.clean(f))
@@ -89,23 +89,25 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) {
   def filter(f: T => Boolean): RDD[T] = new FilteredRDD(this, sc.clean(f))
 
-  def sample(withReplacement: Boolean, frac: Double, seed: Int): RDD[T] =
-    new SampledRDD(this, withReplacement, frac, seed)
+  def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] =
+    new SampledRDD(this, withReplacement, fraction, seed)
 
   def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other))
 
   def ++(other: RDD[T]): RDD[T] = this.union(other)
 
-  def glom(): RDD[Array[T]] = new SplitRDD(this)
+  def glom(): RDD[Array[T]] = new GlommedRDD(this)
 
   def cartesian[U: ClassManifest](other: RDD[U]): RDD[(T, U)] =
     new CartesianRDD(sc, this, other)
 
-  def groupBy[K: ClassManifest](func: T => K, numSplits: Int): RDD[(K, Seq[T])] =
-    this.map(t => (func(t), t)).groupByKey(numSplits)
+  def groupBy[K: ClassManifest](f: T => K, numSplits: Int): RDD[(K, Seq[T])] = {
+    val cleanF = sc.clean(f)
+    this.map(t => (cleanF(t), t)).groupByKey(numSplits)
+  }
 
-  def groupBy[K: ClassManifest](func: T => K): RDD[(K, Seq[T])] =
-    groupBy[K](func, sc.numCores)
+  def groupBy[K: ClassManifest](f: T => K): RDD[(K, Seq[T])] =
+    groupBy[K](f, sc.numCores)
 
   def pipe(command: String): RDD[String] =
     new PipedRDD(this, command)
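
Worth noting in the hunk above: groupBy now routes the key function through sc.clean before capturing it in the map closure, the same closure-cleaning step map and filter already apply, so key functions defined inside enclosing objects stay serializable. The result shape is unchanged: each element is keyed by f and shuffled with groupByKey. A quick sketch (hypothetical values, reusing the nums RDD from the earlier example):

```scala
// Key integers by parity; the result is an RDD[(Int, Seq[Int])].
val byParity = nums.groupBy(x => x % 2)
// collect() would yield two groups, e.g. (0, Seq(2, 4)) and (1, Seq(1, 3)).
```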
@@ -113,7 +115,10 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) {
   def pipe(command: Seq[String]): RDD[String] =
     new PipedRDD(this, command)
 
-  // Parallel operations
+  def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U]): RDD[U] =
+    new MapPartitionsRDD(this, sc.clean(f))
+
+  // Actions (launch a job to return a value to the user program)
 
   def foreach(f: T => Unit) {
     val cleanF = sc.clean(f)
@@ -217,9 +222,17 @@ extends RDD[T](prev.context) {
   override def compute(split: Split) = prev.iterator(split).filter(f)
 }
 
-class SplitRDD[T: ClassManifest](prev: RDD[T])
+class GlommedRDD[T: ClassManifest](prev: RDD[T])
 extends RDD[Array[T]](prev.context) {
   override def splits = prev.splits
   override val dependencies = List(new OneToOneDependency(prev))
   override def compute(split: Split) = Array(prev.iterator(split).toArray).iterator
 }
+
+class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
+    prev: RDD[T], f: Iterator[T] => Iterator[U])
+extends RDD[U](prev.context) {
+  override def splits = prev.splits
+  override val dependencies = List(new OneToOneDependency(prev))
+  override def compute(split: Split) = f(prev.iterator(split))
+}
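
Both new RDD subclasses above follow the same narrow-dependency pattern: splits pass through unchanged, the dependency is one-to-one (so no shuffle is involved), and compute only wraps the parent's iterator. One consequence is that glom is now just the special case of mapPartitions that packs each split into a single array; a sketch of the equivalence (illustrative, not from this commit):

```scala
// Equivalent to nums.glom(): one Array per split, e.g.
// Array(Array(1, 2), Array(3, 4)) for two splits of Array(1, 2, 3, 4).
val glommed = nums.mapPartitions(iter => Iterator(iter.toArray))
```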
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
new file mode 100644
index 0000000000..d31fdb7f8a
--- /dev/null
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -0,0 +1,26 @@
+package spark
+
+import org.scalatest.FunSuite
+import org.scalatest.prop.Checkers
+import org.scalacheck.Arbitrary._
+import org.scalacheck.Gen
+import org.scalacheck.Prop._
+import SparkContext._
+import scala.collection.mutable.ArrayBuffer
+
+class RDDSuite extends FunSuite {
+ test("basic operations") {
+ val sc = new SparkContext("local", "test")
+ val nums = sc.parallelize(Array(1, 2, 3, 4), 2)
+ assert(nums.collect().toList === List(1, 2, 3, 4))
+ assert(nums.reduce(_ + _) === 10)
+ assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4"))
+ assert(nums.filter(_ > 2).collect().toList === List(3, 4))
+ assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4))
+ assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4))
+ assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4)))
+ val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _)))
+ assert(partitionSums.collect().toList === List(3, 7))
+ sc.stop()
+ }
+}
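
One edge case the suite leaves untested: iter.reduceLeft(_ + _) throws on an empty iterator, so the partition-sum pattern would fail on an RDD with more splits than elements. A hypothetical follow-up test (not part of this commit) sketching the safer fold:

```scala
// Hypothetical test: foldLeft tolerates empty partitions where
// reduceLeft would throw. Splits with no elements fold to 0, so the
// per-partition sums of Array(1, 2) always total 3, however the
// elements are distributed across the 4 splits.
test("mapPartitions with empty partitions") {
  val sc = new SparkContext("local", "test")
  val nums = sc.parallelize(Array(1, 2), 4)
  val sums = nums.mapPartitions(iter => Iterator(iter.foldLeft(0)(_ + _)))
  assert(sums.collect().sum === 3)
  sc.stop()
}
```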