aboutsummaryrefslogtreecommitdiff
path: root/core/src/main
diff options
context:
space:
mode:
authorReynold Xin <rxin@cs.berkeley.edu>2012-09-26 17:11:28 -0700
committerReynold Xin <rxin@cs.berkeley.edu>2012-09-26 17:11:28 -0700
commit1ad1331a340b7d52b1218d5a835db71d28fb4467 (patch)
treea57de42e1793e62155e2689c9757eb93647cf850 /core/src/main
parent58eb44acbb81a7d619a015ed32dffd6da6b15436 (diff)
downloadspark-1ad1331a340b7d52b1218d5a835db71d28fb4467.tar.gz
spark-1ad1331a340b7d52b1218d5a835db71d28fb4467.tar.bz2
spark-1ad1331a340b7d52b1218d5a835db71d28fb4467.zip
Added MapPartitionsWithSplitRDD.
Diffstat (limited to 'core/src/main')
-rw-r--r--core/src/main/scala/spark/RDD.scala18
1 files changed, 18 insertions, 0 deletions
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 5fac955286..cce0ea2183 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -196,6 +196,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U]): RDD[U] =
new MapPartitionsRDD(this, sc.clean(f))
+ def mapPartitionsWithSplit[U: ClassManifest](f: (Int, Iterator[T]) => Iterator[U]): RDD[U] =
+ new MapPartitionsWithSplitRDD(this, sc.clean(f))
+
// Actions (launch a job to return a value to the user program)
def foreach(f: T => Unit) {
@@ -417,3 +420,18 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split) = f(prev.iterator(split))
}
+
+/**
+ * A variant of the MapPartitionsRDD that passes the split index into the
+ * closure. This can be used to generate or collect partition specific
+ * information such as the number of tuples in a partition.
+ */
+class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
+ prev: RDD[T],
+ f: (Int, Iterator[T]) => Iterator[U])
+ extends RDD[U](prev.context) {
+
+ override def splits = prev.splits
+ override val dependencies = List(new OneToOneDependency(prev))
+ override def compute(split: Split) = f(split.index, prev.iterator(split))
+}