diff options
author | Reynold Xin <rxin@cs.berkeley.edu> | 2012-09-26 17:11:28 -0700 |
---|---|---|
committer | Reynold Xin <rxin@cs.berkeley.edu> | 2012-09-26 17:11:28 -0700 |
commit | 1ad1331a340b7d52b1218d5a835db71d28fb4467 (patch) | |
tree | a57de42e1793e62155e2689c9757eb93647cf850 /core/src | |
parent | 58eb44acbb81a7d619a015ed32dffd6da6b15436 (diff) | |
download | spark-1ad1331a340b7d52b1218d5a835db71d28fb4467.tar.gz spark-1ad1331a340b7d52b1218d5a835db71d28fb4467.tar.bz2 spark-1ad1331a340b7d52b1218d5a835db71d28fb4467.zip |
Added MapPartitionsWithSplitRDD.
Diffstat (limited to 'core/src')
-rw-r--r-- | core/src/main/scala/spark/RDD.scala | 18 | ||||
-rw-r--r-- | core/src/test/scala/spark/RDDSuite.scala | 5 |
2 files changed, 23 insertions, 0 deletions
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 5fac955286..cce0ea2183 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -196,6 +196,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U]): RDD[U] = new MapPartitionsRDD(this, sc.clean(f)) + def mapPartitionsWithSplit[U: ClassManifest](f: (Int, Iterator[T]) => Iterator[U]): RDD[U] = + new MapPartitionsWithSplitRDD(this, sc.clean(f)) + // Actions (launch a job to return a value to the user program) def foreach(f: T => Unit) { @@ -417,3 +420,18 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( override val dependencies = List(new OneToOneDependency(prev)) override def compute(split: Split) = f(prev.iterator(split)) } + +/** + * A variant of the MapPartitionsRDD that passes the split index into the + * closure. This can be used to generate or collect partition specific + * information such as the number of tuples in a partition. + */ +class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest]( + prev: RDD[T], + f: (Int, Iterator[T]) => Iterator[U]) + extends RDD[U](prev.context) { + + override def splits = prev.splits + override val dependencies = List(new OneToOneDependency(prev)) + override def compute(split: Split) = f(split.index, prev.iterator(split)) +} diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index ba9b36adb7..04dbe3a3e4 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -29,6 +29,11 @@ class RDDSuite extends FunSuite with BeforeAndAfter { assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4))) val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _))) assert(partitionSums.collect().toList === List(3, 7)) + + val partitionSumsWithSplit = nums.mapPartitionsWithSplit { + case(split, iter) => Iterator((split, iter.reduceLeft(_ + _))) + } + assert(partitionSumsWithSplit.collect().toList === List((0, 3), (1, 7))) } test("SparkContext.union") { |