From b57dd0f16024a82dfc223e69528b9908b931f068 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 8 Jan 2013 16:04:41 -0800 Subject: Add mapPartitionsWithSplit() to PySpark. --- python/pyspark/rdd.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) (limited to 'python/pyspark/rdd.py') diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 4ba417b2a2..1d36da42b0 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -55,7 +55,7 @@ class RDD(object): """ Return a new RDD containing the distinct elements in this RDD. """ - def func(iterator): return imap(f, iterator) + def func(split, iterator): return imap(f, iterator) return PipelinedRDD(self, func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): @@ -69,8 +69,8 @@ class RDD(object): >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect()) [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ - def func(iterator): return chain.from_iterable(imap(f, iterator)) - return self.mapPartitions(func, preservesPartitioning) + def func(s, iterator): return chain.from_iterable(imap(f, iterator)) + return self.mapPartitionsWithSplit(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): """ @@ -81,9 +81,20 @@ class RDD(object): >>> rdd.mapPartitions(f).collect() [3, 7] """ - return PipelinedRDD(self, f, preservesPartitioning) + def func(s, iterator): return f(iterator) + return self.mapPartitionsWithSplit(func) + + def mapPartitionsWithSplit(self, f, preservesPartitioning=False): + """ + Return a new RDD by applying a function to each partition of this RDD, + while tracking the index of the original partition. - # TODO: mapPartitionsWithSplit + >>> rdd = sc.parallelize([1, 2, 3, 4], 4) + >>> def f(splitIndex, iterator): yield splitIndex + >>> rdd.mapPartitionsWithSplit(f).sum() + 6 + """ + return PipelinedRDD(self, f, preservesPartitioning) def filter(self, f): """ @@ -362,7 +373,7 @@ class RDD(object): >>> ''.join(input(glob(tempFile.name + "/part-0000*"))) '0\\n1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n' """ - def func(iterator): + def func(split, iterator): return (str(x).encode("utf-8") for x in iterator) keyed = PipelinedRDD(self, func) keyed._bypass_serializer = True @@ -500,7 +511,7 @@ class RDD(object): # Transferring O(n) objects to Java is too expensive. Instead, we'll # form the hash buckets in Python, transferring O(numSplits) objects # to Java. Each object is a (splitNumber, [objects]) pair. - def add_shuffle_key(iterator): + def add_shuffle_key(split, iterator): buckets = defaultdict(list) for (k, v) in iterator: buckets[hashFunc(k) % numSplits].append((k, v)) @@ -653,8 +664,8 @@ class PipelinedRDD(RDD): def __init__(self, prev, func, preservesPartitioning=False): if isinstance(prev, PipelinedRDD) and not prev.is_cached: prev_func = prev.func - def pipeline_func(iterator): - return func(prev_func(iterator)) + def pipeline_func(split, iterator): + return func(split, prev_func(split, iterator)) self.func = pipeline_func self.preservesPartitioning = \ prev.preservesPartitioning and preservesPartitioning @@ -677,8 +688,8 @@ class PipelinedRDD(RDD): if not self._bypass_serializer and self.ctx.batchSize != 1: oldfunc = self.func batchSize = self.ctx.batchSize - def batched_func(iterator): - return batched(oldfunc(iterator), batchSize) + def batched_func(split, iterator): + return batched(oldfunc(split, iterator), batchSize) func = batched_func cmds = [func, self._bypass_serializer] pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) -- cgit v1.2.3