diff options
author | Matei Zaharia <matei@eecs.berkeley.edu> | 2013-01-05 10:17:20 -0800 |
---|---|---|
committer | Matei Zaharia <matei@eecs.berkeley.edu> | 2013-01-05 10:17:20 -0800 |
commit | 7ab9f091408dad2c264cede965ab284c9d33deb9 (patch) | |
tree | 05dda23cdb33ff09bc81d2cd71e2df754d1fcb11 | |
parent | 55809fbc6db5929332cd45fa3281f9190098c6c6 (diff) | |
parent | f4e6b9361ffeec1018d5834f09db9fd86f2ba7bd (diff) | |
download | spark-7ab9f091408dad2c264cede965ab284c9d33deb9.tar.gz spark-7ab9f091408dad2c264cede965ab284c9d33deb9.tar.bz2 spark-7ab9f091408dad2c264cede965ab284c9d33deb9.zip |
Merge pull request #352 from stephenh/collect
Add RDD.collect(PartialFunction).
-rw-r--r-- | core/src/main/scala/spark/RDD.scala | 7 | ||||
-rw-r--r-- | core/src/test/scala/spark/RDDSuite.scala | 1 |
2 files changed, 8 insertions, 0 deletions
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 7e38583391..5163c80134 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -330,6 +330,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial def toArray(): Array[T] = collect() /** + * Return an RDD that contains all matching values by applying `f`. + */ + def collect[U: ClassManifest](f: PartialFunction[T, U]): RDD[U] = { + filter(f.isDefinedAt).map(f) + } + + /** * Reduces the elements of this RDD using the specified associative binary operator. */ def reduce(f: (T, T) => T): T = { diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 45e6c5f840..872b06fd08 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -35,6 +35,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter { assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4)) assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4))) + assert(nums.collect({ case i if i >= 3 => i.toString }).collect().toList === List("3", "4")) val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _))) assert(partitionSums.collect().toList === List(3, 7)) |