From 1ad1331a340b7d52b1218d5a835db71d28fb4467 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 26 Sep 2012 17:11:28 -0700 Subject: Added MapPartitionsWithSplitRDD. --- core/src/main/scala/spark/RDD.scala | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) (limited to 'core/src/main') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 5fac955286..cce0ea2183 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -196,6 +196,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U]): RDD[U] = new MapPartitionsRDD(this, sc.clean(f)) + def mapPartitionsWithSplit[U: ClassManifest](f: (Int, Iterator[T]) => Iterator[U]): RDD[U] = + new MapPartitionsWithSplitRDD(this, sc.clean(f)) + // Actions (launch a job to return a value to the user program) def foreach(f: T => Unit) { @@ -417,3 +420,18 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( override val dependencies = List(new OneToOneDependency(prev)) override def compute(split: Split) = f(prev.iterator(split)) } + +/** + * A variant of the MapPartitionsRDD that passes the split index into the + * closure. This can be used to generate or collect partition specific + * information such as the number of tuples in a partition. + */ +class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest]( + prev: RDD[T], + f: (Int, Iterator[T]) => Iterator[U]) + extends RDD[U](prev.context) { + + override def splits = prev.splits + override val dependencies = List(new OneToOneDependency(prev)) + override def compute(split: Split) = f(split.index, prev.iterator(split)) +} -- cgit v1.2.3