aboutsummaryrefslogtreecommitdiff
path: root/core/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'core/src/main')
-rw-r--r--core/src/main/scala/spark/RDD.scala18
1 files changed, 18 insertions, 0 deletions
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 5fac955286..cce0ea2183 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -196,6 +196,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U]): RDD[U] =
new MapPartitionsRDD(this, sc.clean(f))
+ def mapPartitionsWithSplit[U: ClassManifest](f: (Int, Iterator[T]) => Iterator[U]): RDD[U] =
+ new MapPartitionsWithSplitRDD(this, sc.clean(f))
+
// Actions (launch a job to return a value to the user program)
def foreach(f: T => Unit) {
@@ -417,3 +420,18 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split) = f(prev.iterator(split))
}
+
+/**
+ * A variant of the MapPartitionsRDD that passes the split index into the
+ * closure. This can be used to generate or collect partition specific
+ * information such as the number of tuples in a partition.
+ */
+class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
+ prev: RDD[T],
+ f: (Int, Iterator[T]) => Iterator[U])
+ extends RDD[U](prev.context) {
+
+ override def splits = prev.splits
+ override val dependencies = List(new OneToOneDependency(prev))
+ override def compute(split: Split) = f(split.index, prev.iterator(split))
+}