author    Matei Zaharia <matei@eecs.berkeley.edu>  2013-10-09 11:59:47 -0700
committer Matei Zaharia <matei@eecs.berkeley.edu>  2013-10-09 11:59:47 -0700
commit    b4fa11f6c96ee37ecd30231c1e22630055f52115 (patch)
tree      bd87ea2fa9a34d8c7da631f940ecd45e06f3a0bb
parent    19d445d37c38a35cc68da91f53a8780b89a1f8c9 (diff)
parent    fdbae41e88af0994e97ac8741f86547184a05de9 (diff)
Merge pull request #38 from AndreSchumacher/pyspark_sorting
SPARK-705: implement sortByKey() in PySpark

This PR contains the implementation of a RangePartitioner in Python and uses its partition IDs to get a global sort in PySpark.
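The idea, as an illustrative sketch (not part of the patch; assumes a live SparkContext bound to sc, as in the doctests below): range-partition the keys so that every key in partition i sorts before every key in partition i+1, then sort each partition locally; reading the partitions in order then yields a global sort.

    # illustrative only: range partitioning plus a per-partition sort
    pairs = sc.parallelize([('d', 4), ('a', 1), ('c', 3), ('b', 2)])
    result = pairs.sortByKey(True, 2)
    # glom() collects each partition into a list; every key in partition 0
    # precedes every key in partition 1 (the exact split depends on sampling)
    result.glom().collect()   # e.g. [[('a', 1), ('b', 2)], [('c', 3), ('d', 4)]]
    result.collect()          # [('a', 1), ('b', 2), ('c', 3), ('d', 4)]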
-rw-r--r--  python/pyspark/rdd.py | 48
1 file changed, 47 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 39c402b412..7dfabb0b7d 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -263,7 +263,53 @@ class RDD(object):
             raise TypeError
         return self.union(other)
 
-    # TODO: sort
+    def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
+        """
+        Sorts this RDD, which is assumed to consist of (key, value) pairs.
+
+        >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
+        >>> sc.parallelize(tmp).sortByKey(True, 2).collect()
+        [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
+        >>> tmp2 = [('Mary', 1), ('had', 2), ('a', 3), ('little', 4), ('lamb', 5)]
+        >>> tmp2.extend([('whose', 6), ('fleece', 7), ('was', 8), ('white', 9)])
+        >>> sc.parallelize(tmp2).sortByKey(True, 3, keyfunc=lambda k: k.lower()).collect()
+        [('a', 3), ('fleece', 7), ('had', 2), ('lamb', 5), ('little', 4), ('Mary', 1), ('was', 8), ('white', 9), ('whose', 6)]
+        """
+        if numPartitions is None:
+            numPartitions = self.ctx.defaultParallelism
+
+        bounds = list()
+
+        # first compute the boundary of each part via sampling: we want to
+        # partition the key space into bins such that each bin contains
+        # roughly the same number of (key, value) pairs
+        if numPartitions > 1:
+            rddSize = self.count()
+            maxSampleSize = numPartitions * 20.0  # constant from Spark's RangePartitioner
+            fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
+
+            samples = self.sample(False, fraction, 1).map(lambda kv: kv[0]).collect()
+            samples = sorted(samples, reverse=(not ascending), key=keyfunc)
+
+            # we have numPartitions parts, but one of them has an implicit
+            # boundary: everything beyond the last explicit bound
+            for i in range(0, numPartitions - 1):
+                index = (len(samples) - 1) * (i + 1) // numPartitions
+                bounds.append(samples[index])
+
+        def rangePartitionFunc(k):
+            p = 0
+            while p < len(bounds) and keyfunc(k) > bounds[p]:
+                p += 1
+            if ascending:
+                return p
+            else:
+                return numPartitions - 1 - p
+
+        def mapFunc(iterator):
+            yield sorted(iterator, reverse=(not ascending), key=lambda kv: keyfunc(kv[0]))
+
+        return self.partitionBy(numPartitions, partitionFunc=rangePartitionFunc).mapPartitions(mapFunc, preservesPartitioning=True).flatMap(lambda x: x, preservesPartitioning=True)
 
     def glom(self):
         """