author    Josh Rosen <joshrosen@eecs.berkeley.edu>  2013-01-08 16:04:41 -0800
committer Josh Rosen <joshrosen@eecs.berkeley.edu>  2013-01-08 16:05:02 -0800
commit    b57dd0f16024a82dfc223e69528b9908b931f068 (patch)
tree      8ad1222593da58eaeb7746aecaef2c41c5313f71 /python/pyspark/rdd.py
parent    33beba39656fc64984db09a82fc69ca4edcc02d4 (diff)
Add mapPartitionsWithSplit() to PySpark.
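
The new method passes the partition index to the user's function as its first argument. A minimal usage sketch, assuming an active SparkContext bound to sc as in the doctests below; the two-partition layout here is illustrative, not part of the patch:

    >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
    >>> def f(splitIndex, iterator): yield (splitIndex, sum(iterator))
    >>> rdd.mapPartitionsWithSplit(f).collect()
    [(0, 3), (1, 7)]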
Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r--  python/pyspark/rdd.py  33
1 file changed, 22 insertions(+), 11 deletions(-)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 4ba417b2a2..1d36da42b0 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -55,7 +55,7 @@ class RDD(object):
"""
Return a new RDD containing the distinct elements in this RDD.
"""
- def func(iterator): return imap(f, iterator)
+ def func(split, iterator): return imap(f, iterator)
return PipelinedRDD(self, func, preservesPartitioning)
def flatMap(self, f, preservesPartitioning=False):
@@ -69,8 +69,8 @@ class RDD(object):
>>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect())
[(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
"""
- def func(iterator): return chain.from_iterable(imap(f, iterator))
- return self.mapPartitions(func, preservesPartitioning)
+ def func(s, iterator): return chain.from_iterable(imap(f, iterator))
+ return self.mapPartitionsWithSplit(func, preservesPartitioning)
def mapPartitions(self, f, preservesPartitioning=False):
"""
@@ -81,9 +81,20 @@ class RDD(object):
>>> rdd.mapPartitions(f).collect()
[3, 7]
"""
- return PipelinedRDD(self, f, preservesPartitioning)
+ def func(s, iterator): return f(iterator)
+ return self.mapPartitionsWithSplit(func)

- # TODO: mapPartitionsWithSplit
+ def mapPartitionsWithSplit(self, f, preservesPartitioning=False):
+ """
+ Return a new RDD by applying a function to each partition of this RDD,
+ while tracking the index of the original partition.
+
+ >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
+ >>> def f(splitIndex, iterator): yield splitIndex
+ >>> rdd.mapPartitionsWithSplit(f).sum()
+ 6
+ """
+ return PipelinedRDD(self, f, preservesPartitioning)
def filter(self, f):
"""
@@ -362,7 +373,7 @@ class RDD(object):
>>> ''.join(input(glob(tempFile.name + "/part-0000*")))
'0\\n1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n'
"""
- def func(iterator):
+ def func(split, iterator):
return (str(x).encode("utf-8") for x in iterator)
keyed = PipelinedRDD(self, func)
keyed._bypass_serializer = True
@@ -500,7 +511,7 @@ class RDD(object):
# Transferring O(n) objects to Java is too expensive. Instead, we'll
# form the hash buckets in Python, transferring O(numSplits) objects
# to Java. Each object is a (splitNumber, [objects]) pair.
- def add_shuffle_key(iterator):
+ def add_shuffle_key(split, iterator):
buckets = defaultdict(list)
for (k, v) in iterator:
buckets[hashFunc(k) % numSplits].append((k, v))
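
The bucketing idea in add_shuffle_key can be shown standalone. A minimal sketch of the same hash-partitioning step outside Spark (hash_partition and its names are illustrative, not part of the patch):

    from collections import defaultdict

    def hash_partition(pairs, num_splits, hash_func=hash):
        # Group (key, value) pairs by hash of the key so that only
        # num_splits (splitNumber, [pairs]) objects are transferred,
        # rather than one object per record.
        buckets = defaultdict(list)
        for k, v in pairs:
            buckets[hash_func(k) % num_splits].append((k, v))
        return list(buckets.items())

    print(sorted(hash_partition([(0, 'a'), (1, 'b'), (2, 'c')], 2)))
    # [(0, [(0, 'a'), (2, 'c')]), (1, [(1, 'b')])]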
@@ -653,8 +664,8 @@ class PipelinedRDD(RDD):
def __init__(self, prev, func, preservesPartitioning=False):
if isinstance(prev, PipelinedRDD) and not prev.is_cached:
prev_func = prev.func
- def pipeline_func(iterator):
- return func(prev_func(iterator))
+ def pipeline_func(split, iterator):
+ return func(split, prev_func(split, iterator))
self.func = pipeline_func
self.preservesPartitioning = \
prev.preservesPartitioning and preservesPartitioning
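
The pipelining here fuses consecutive stages into one pass over a partition by composing their functions, now threading the split index through both. A standalone sketch of that composition (names are illustrative):

    def compose_stages(prev_func, func):
        # Run prev_func's stage, then func's stage, in one pass over
        # the partition, giving both the same partition index.
        def pipeline_func(split, iterator):
            return func(split, prev_func(split, iterator))
        return pipeline_func

    double = lambda split, it: (x * 2 for x in it)      # first stage
    offset = lambda split, it: (x + split for x in it)  # second stage
    combined = compose_stages(double, offset)
    print(list(combined(1, [1, 2, 3])))  # [3, 5, 7]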
@@ -677,8 +688,8 @@ class PipelinedRDD(RDD):
if not self._bypass_serializer and self.ctx.batchSize != 1:
oldfunc = self.func
batchSize = self.ctx.batchSize
- def batched_func(iterator):
- return batched(oldfunc(iterator), batchSize)
+ def batched_func(split, iterator):
+ return batched(oldfunc(split, iterator), batchSize)
func = batched_func
cmds = [func, self._bypass_serializer]
pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
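
When batching is enabled, the wrapped function's output is grouped into chunks of batchSize before serialization, cutting per-object overhead on the pipe to the JVM. PySpark's batched() is defined elsewhere in the codebase; the stand-in below only sketches the idea:

    def batched(iterator, batch_size):
        # Illustrative stand-in: yield lists of up to batch_size items.
        batch = []
        for item in iterator:
            batch.append(item)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if batch:
            yield batch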