author     Josh Rosen <joshrosen@eecs.berkeley.edu>  2012-08-24 23:09:15 -0700
committer  Josh Rosen <joshrosen@eecs.berkeley.edu>  2012-08-27 00:19:22 -0700
commit     8b64b7ecd80c52f2f09a517f1517c0ece7a3d57f (patch)
tree       5693f0983405691821015bb9eaeedd6be02dc7fb /pyspark
parent     08b201d810c0dc0933d00d78ec2c1d9135e100c3 (diff)
Add countByKey(), reduceByKeyLocally() to Python API
Diffstat (limited to 'pyspark')
-rw-r--r--  pyspark/pyspark/rdd.py | 52
1 file changed, 39 insertions(+), 13 deletions(-)
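
This commit's countByKey() is built on countByValue(); its doctest calls .most_common() on the result, which suggests a collections.Counter-like return value (an assumption inferred from the doctest, not something the diff states). As a rough plain-Python illustration of the counting it performs, with no Spark involved:

    # Sketch only: count how often each key appears in a list of (key, value)
    # pairs, mirroring countByKey() on a pair RDD. Counter is an assumption
    # about the return type, inferred from the .most_common() doctest below.
    from collections import Counter

    pairs = [("a", 1), ("b", 1), ("a", 1)]
    counts = Counter(k for k, _ in pairs)   # keys only; values are ignored
    print(counts.most_common())             # [('a', 2), ('b', 1)]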
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index f0d665236a..fd41ea0b17 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -99,9 +99,17 @@ class RDD(object):
         """
         return RDD(self._jrdd.union(other._jrdd), self.ctx)
 
-    # TODO: sort
+    def __add__(self, other):
+        """
+        >>> rdd = sc.parallelize([1, 1, 2, 3])
+        >>> (rdd + rdd).collect()
+        [1, 1, 2, 3, 1, 1, 2, 3]
+        """
+        if not isinstance(other, RDD):
+            raise TypeError
+        return self.union(other)
 
-    # TODO: Overload __add___?
+    # TODO: sort
 
     def glom(self):
         """
@@ -120,7 +128,6 @@ class RDD(object):
         """
         return RDD(self._jrdd.cartesian(other._jrdd), self.ctx)
 
-    # numsplits
     def groupBy(self, f, numSplits=None):
         """
         >>> rdd = sc.parallelize([1, 1, 2, 3, 5, 8])
@@ -236,17 +243,38 @@ class RDD(object):
 
     def reduceByKey(self, func, numSplits=None):
         """
-        >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
-        >>> sorted(x.reduceByKey(lambda a, b: a + b).collect())
+        >>> from operator import add
+        >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
+        >>> sorted(rdd.reduceByKey(add).collect())
         [('a', 2), ('b', 1)]
         """
         return self.combineByKey(lambda x: x, func, func, numSplits)
 
-    # TODO: reduceByKeyLocally()
-
-    # TODO: countByKey()
+    def reduceByKeyLocally(self, func):
+        """
+        >>> from operator import add
+        >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
+        >>> sorted(rdd.reduceByKeyLocally(add).items())
+        [('a', 2), ('b', 1)]
+        """
+        def reducePartition(iterator):
+            m = {}
+            for (k, v) in iterator:
+                m[k] = v if k not in m else func(m[k], v)
+            yield m
+        def mergeMaps(m1, m2):
+            for (k, v) in m2.iteritems():
+                m1[k] = v if k not in m1 else func(m1[k], v)
+            return m1
+        return self.mapPartitions(reducePartition).reduce(mergeMaps)
 
-    # TODO: partitionBy
+    def countByKey(self):
+        """
+        >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
+        >>> rdd.countByKey().most_common()
+        [('a', 2), ('b', 1)]
+        """
+        return self.map(lambda x: x[0]).countByValue()
 
     def join(self, other, numSplits=None):
         """
@@ -277,7 +305,7 @@ class RDD(object):
 
     # TODO: pipelining
     # TODO: optimizations
-    def shuffle(self, numSplits, hashFunc=hash):
+    def partitionBy(self, numSplits, hashFunc=hash):
         if numSplits is None:
             numSplits = self.ctx.defaultParallelism
         (pipe_command, broadcast_vars) = \
@@ -302,7 +330,7 @@ class RDD(object):
         """
         if numSplits is None:
             numSplits = self.ctx.defaultParallelism
-        shuffled = self.shuffle(numSplits)
+        shuffled = self.partitionBy(numSplits)
         functions = [createCombiner, mergeValue, mergeCombiners]
         jpairs = shuffled._pipe(functions, "combine_by_key")
         return RDD(jpairs, self.ctx)
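
For context, reduceByKey() above is defined via combineByKey() with an identity createCombiner and the same function for mergeValue and mergeCombiners, and combineByKey() now routes the data through the renamed partitionBy(). A minimal single-partition sketch of that combiner contract in plain Python (function names are illustrative, not part of the API):

    # Sketch of the createCombiner/mergeValue contract that reduceByKey relies on.
    from operator import add

    def combine_by_key_local(pairs, createCombiner, mergeValue):
        combined = {}
        for k, v in pairs:
            if k not in combined:
                combined[k] = createCombiner(v)
            else:
                combined[k] = mergeValue(combined[k], v)
        return combined

    # reduceByKey(add) behaves like combineByKey(lambda x: x, add, add, ...).
    print(combine_by_key_local([("a", 1), ("b", 1), ("a", 1)], lambda x: x, add))
    # {'a': 2, 'b': 1}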
@@ -353,8 +381,6 @@ class RDD(object):
     # keys in the pairs. This could be an expensive operation, since those
     # hashes aren't retained.
 
-    # TODO: file saving
-
 
 class PipelinedRDD(RDD):
     """