aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Matei Zaharia <matei@eecs.berkeley.edu> 2013-01-05 10:17:20 -0800
committer Matei Zaharia <matei@eecs.berkeley.edu> 2013-01-05 10:17:20 -0800
commit 7ab9f091408dad2c264cede965ab284c9d33deb9 (patch)
tree 05dda23cdb33ff09bc81d2cd71e2df754d1fcb11
parent 55809fbc6db5929332cd45fa3281f9190098c6c6 (diff)
parent f4e6b9361ffeec1018d5834f09db9fd86f2ba7bd (diff)
download spark-7ab9f091408dad2c264cede965ab284c9d33deb9.tar.gz
spark-7ab9f091408dad2c264cede965ab284c9d33deb9.tar.bz2
spark-7ab9f091408dad2c264cede965ab284c9d33deb9.zip
Merge pull request #352 from stephenh/collect
Add RDD.collect(PartialFunction).
-rw-r--r-- core/src/main/scala/spark/RDD.scala 7
-rw-r--r-- core/src/test/scala/spark/RDDSuite.scala 1
2 files changed, 8 insertions, 0 deletions
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 7e38583391..5163c80134 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -330,6 +330,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
def toArray(): Array[T] = collect()
/**
+ * Return an RDD that contains all matching values by applying `f`.
+ */
+ def collect[U: ClassManifest](f: PartialFunction[T, U]): RDD[U] = {
+ filter(f.isDefinedAt).map(f)
+ }
+
+ /**
* Reduces the elements of this RDD using the specified associative binary operator.
*/
def reduce(f: (T, T) => T): T = {
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index 45e6c5f840..872b06fd08 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -35,6 +35,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4))
assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4))
assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4)))
+ assert(nums.collect({ case i if i >= 3 => i.toString }).collect().toList === List("3", "4"))
val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _)))
assert(partitionSums.collect().toList === List(3, 7))