path: root/pyspark/pyspark/rdd.py
author     Josh Rosen <joshrosen@eecs.berkeley.edu>    2012-08-24 22:51:45 -0700
committer  Josh Rosen <joshrosen@eecs.berkeley.edu>    2012-08-27 00:19:14 -0700
commit     08b201d810c0dc0933d00d78ec2c1d9135e100c3 (patch)
tree       fe3995be912db8dbe10905d0091a440d01564b4d /pyspark/pyspark/rdd.py
parent     f79a1e4d2a8643157136de69b8d7de84f0034712 (diff)
Add mapPartitions(), glom(), countByValue() to Python API.
Diffstat (limited to 'pyspark/pyspark/rdd.py')
-rw-r--r--    pyspark/pyspark/rdd.py    32
1 file changed, 28 insertions(+), 4 deletions(-)
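
For orientation, here is a minimal usage sketch of the three methods this commit adds. It assumes an interactive PySpark shell of this era with an existing SparkContext bound to `sc` (Python 2 syntax, matching the doctests in the diff below); the variable names are illustrative only.

    # Sketch only; `sc` is assumed to be a live SparkContext.
    rdd = sc.parallelize([1, 2, 3, 4], 2)            # two partitions: [1, 2] and [3, 4]

    def f(iterator): yield sum(iterator)
    rdd.mapPartitions(f).collect()                   # [3, 7] -- one sum per partition

    rdd.glom().collect()                             # [[1, 2], [3, 4]] -- each partition as a list

    sc.parallelize([1, 2, 1, 2, 2]).countByValue()   # collections.Counter({2: 3, 1: 2})
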
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index 4459095391..f0d665236a 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -1,4 +1,5 @@
 from base64 import standard_b64encode as b64enc
+from collections import Counter
 from itertools import chain, ifilter, imap
 
 from pyspark import cloudpickle
@@ -47,6 +48,15 @@ class RDD(object):
         def func(iterator): return chain.from_iterable(imap(f, iterator))
         return PipelinedRDD(self, func)
 
+    def mapPartitions(self, f):
+        """
+        >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
+        >>> def f(iterator): yield sum(iterator)
+        >>> rdd.mapPartitions(f).collect()
+        [3, 7]
+        """
+        return PipelinedRDD(self, f)
+
     def filter(self, f):
         """
         >>> rdd = sc.parallelize([1, 2, 3, 4, 5])
@@ -93,7 +103,14 @@ class RDD(object):
 
     # TODO: Overload __add___?
 
-    # TODO: glom
+    def glom(self):
+        """
+        >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
+        >>> rdd.glom().first()
+        [1, 2]
+        """
+        def func(iterator): yield list(iterator)
+        return PipelinedRDD(self, func)
 
     def cartesian(self, other):
         """
@@ -115,8 +132,6 @@ class RDD(object):
 
     # TODO: pipe
 
-    # TODO: mapPartitions
-
     def foreach(self, f):
         """
         >>> def f(x): print x
@@ -177,7 +192,16 @@ class RDD(object):
         """
         return self._jrdd.count()
 
-    # TODO: count approx methods
+    def countByValue(self):
+        """
+        >>> sc.parallelize([1, 2, 1, 2, 2]).countByValue().most_common()
+        [(2, 3), (1, 2)]
+        """
+        def countPartition(iterator):
+            yield Counter(iterator)
+        def mergeMaps(m1, m2):
+            return m1 + m2
+        return self.mapPartitions(countPartition).reduce(mergeMaps)
 
     def take(self, num):
         """