From fdbae41e88af0994e97ac8741f86547184a05de9 Mon Sep 17 00:00:00 2001
From: Andre Schumacher
Date: Mon, 7 Oct 2013 10:42:39 -0700
Subject: SPARK-705: implement sortByKey() in PySpark

---
 python/pyspark/rdd.py | 48 +++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 47 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 39c402b412..7dfabb0b7d 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -263,7 +263,53 @@ class RDD(object):
             raise TypeError
         return self.union(other)

-    # TODO: sort
+    def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
+        """
+        Sorts this RDD, which is assumed to consist of (key, value) pairs.
+
+        >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
+        >>> sc.parallelize(tmp).sortByKey(True, 2).collect()
+        [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
+        >>> tmp2 = [('Mary', 1), ('had', 2), ('a', 3), ('little', 4), ('lamb', 5)]
+        >>> tmp2.extend([('whose', 6), ('fleece', 7), ('was', 8), ('white', 9)])
+        >>> sc.parallelize(tmp2).sortByKey(True, 3, keyfunc=lambda k: k.lower()).collect()
+        [('a', 3), ('fleece', 7), ('had', 2), ('lamb', 5), ('little', 4), ('Mary', 1), ('was', 8), ('white', 9), ('whose', 6)]
+        """
+        if numPartitions is None:
+            numPartitions = self.ctx.defaultParallelism
+
+        bounds = list()
+
+        # first compute the boundary of each part via sampling: we want to partition
+        # the key-space into bins such that the bins have roughly the same
+        # number of (key, value) pairs falling into them
+        if numPartitions > 1:
+            rddSize = self.count()
+            maxSampleSize = numPartitions * 20.0  # constant from Spark's RangePartitioner
+            fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
+
+            samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect()
+            samples = sorted(samples, reverse=(not ascending), key=keyfunc)
+
+            # we have numPartitions many parts but one of them has
+            # an implicit boundary
+            for i in range(0, numPartitions - 1):
+                index = (len(samples) - 1) * (i + 1) / numPartitions
+                bounds.append(samples[index])
+
+        def rangePartitionFunc(k):
+            p = 0
+            while p < len(bounds) and keyfunc(k) > bounds[p]:
+                p += 1
+            if ascending:
+                return p
+            else:
+                return numPartitions - 1 - p
+
+        def mapFunc(iterator):
+            yield sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
+
+        return self.partitionBy(numPartitions, partitionFunc=rangePartitionFunc).mapPartitions(mapFunc, preservesPartitioning=True).flatMap(lambda x: x, preservesPartitioning=True)

     def glom(self):
         """
--
cgit v1.2.3
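For readers unfamiliar with range partitioning, the sketch below distills the patch's sampling-and-bounds logic into plain, self-contained Python with no Spark dependency. The helper names compute_bounds and range_partition are mine for illustration, not part of the patch, and '//' stands in for the Python 2 integer division the patch relies on.

def compute_bounds(keys, num_partitions, keyfunc=lambda x: x, ascending=True):
    # Pick num_partitions - 1 boundary keys from a sorted sample so that
    # each partition receives a roughly equal slice of the key space.
    samples = sorted(keys, reverse=(not ascending), key=keyfunc)
    bounds = []
    for i in range(num_partitions - 1):
        # '//' mirrors the patch's Python 2 integer division
        bounds.append(samples[(len(samples) - 1) * (i + 1) // num_partitions])
    return bounds

def range_partition(key, bounds, num_partitions, keyfunc=lambda x: x, ascending=True):
    # Walk the boundary list until the key no longer exceeds a bound;
    # mirror the index for descending sorts, as rangePartitionFunc does.
    p = 0
    while p < len(bounds) and keyfunc(key) > bounds[p]:
        p += 1
    return p if ascending else num_partitions - 1 - p

keys = ['d', 'a', 'c', 'b', 'e', 'f']
bounds = compute_bounds(keys, 3)  # ['b', 'd']
print([range_partition(k, bounds, 3) for k in sorted(keys)])
# [0, 0, 1, 1, 2, 2] -- partition indices increase with key order, so
# sorting each partition locally and concatenating them in order yields
# a globally sorted result, which is what sortByKey() exploits.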