author     Matei Zaharia <matei@eecs.berkeley.edu>  2012-09-26 23:16:45 -0700
committer  Matei Zaharia <matei@eecs.berkeley.edu>  2012-09-26 23:16:45 -0700
commit     920fab23c3ee68945687f1b03280c9c7d0f61597 (patch)
tree       c0d39d160817a44d60a62fcaed59a3c7cd984969 /core/src/main
parent     ea05fc130b64ce356ab7524a3d5bd1e022cf51b5 (diff)
parent     1ad1331a340b7d52b1218d5a835db71d28fb4467 (diff)
download   spark-920fab23c3ee68945687f1b03280c9c7d0f61597.tar.gz
           spark-920fab23c3ee68945687f1b03280c9c7d0f61597.tar.bz2
           spark-920fab23c3ee68945687f1b03280c9c7d0f61597.zip
Merge pull request #222 from rxin/dev
Added MapPartitionsWithSplitRDD.
Diffstat (limited to 'core/src/main')
-rw-r--r--  core/src/main/scala/spark/RDD.scala  18
1 file changed, 18 insertions, 0 deletions
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 5fac955286..cce0ea2183 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -196,6 +196,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
   def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U]): RDD[U] =
     new MapPartitionsRDD(this, sc.clean(f))
 
+  def mapPartitionsWithSplit[U: ClassManifest](f: (Int, Iterator[T]) => Iterator[U]): RDD[U] =
+    new MapPartitionsWithSplitRDD(this, sc.clean(f))
+
   // Actions (launch a job to return a value to the user program)
 
   def foreach(f: T => Unit) {
@@ -417,3 +420,18 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
   override val dependencies = List(new OneToOneDependency(prev))
   override def compute(split: Split) = f(prev.iterator(split))
 }
+
+/**
+ * A variant of the MapPartitionsRDD that passes the split index into the
+ * closure. This can be used to generate or collect partition specific
+ * information such as the number of tuples in a partition.
+ */
+class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
+    prev: RDD[T],
+    f: (Int, Iterator[T]) => Iterator[U])
+  extends RDD[U](prev.context) {
+
+  override def splits = prev.splits
+  override val dependencies = List(new OneToOneDependency(prev))
+  override def compute(split: Split) = f(split.index, prev.iterator(split))
+}
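
Usage note: the new mapPartitionsWithSplit method passes the index of the split (partition) into the user's closure along with the partition's iterator, which makes per-partition bookkeeping such as counting tuples straightforward. The sketch below is illustrative only and is not part of the commit; it assumes a SparkContext named sc and the parallelize API of this era.

// Minimal usage sketch (hypothetical example, not from the commit).
// Counts the tuples in each partition and pairs the count with that
// partition's split index.
val rdd = sc.parallelize(1 to 100, 4)
val countsPerSplit = rdd.mapPartitionsWithSplit { (splitIndex, iter) =>
  // Emit a single (splitIndex, count) pair for this partition.
  Iterator((splitIndex, iter.size))
}
println(countsPerSplit.collect().mkString(", "))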