From 0a7ef6339f18e68d703599aff7db2dd9c2003866 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 19 Aug 2014 22:43:49 -0700 Subject: [SPARK-3141] [PySpark] fix sortByKey() with take() Fix sortByKey() with take() The function `f` used in mapPartitions should always return an iterator. Author: Davies Liu Closes #2045 from davies/fix_sortbykey and squashes the following commits: 1160f59 [Davies Liu] fix sortByKey() with take() --- python/pyspark/rdd.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 140cbe05a4..3eefc878d2 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -575,6 +575,8 @@ class RDD(object): # noqa >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)] + >>> sc.parallelize(tmp).sortByKey().first() + ('1', 3) >>> sc.parallelize(tmp).sortByKey(True, 1).collect() [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)] >>> sc.parallelize(tmp).sortByKey(True, 2).collect() @@ -587,14 +589,13 @@ class RDD(object): if numPartitions is None: numPartitions = self._defaultReducePartitions() + def sortPartition(iterator): + return iter(sorted(iterator, key=lambda (k, v): keyfunc(k), reverse=not ascending)) + if numPartitions == 1: if self.getNumPartitions() > 1: self = self.coalesce(1) - - def sort(iterator): - return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k)) - - return self.mapPartitions(sort) + return self.mapPartitions(sortPartition) # first compute the boundary of each part via sampling: we want to partition # the key-space into bins such that the bins have roughly the same @@ -610,17 +611,14 @@ class RDD(object): bounds = [samples[len(samples) * (i + 1) / numPartitions] for i in range(0, numPartitions - 1)] - def rangePartitionFunc(k): + def rangePartitioner(k): p = bisect.bisect_left(bounds, keyfunc(k)) if ascending: return p else: return numPartitions - 1 - p - def mapFunc(iterator): - return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k)) - - return self.partitionBy(numPartitions, rangePartitionFunc).mapPartitions(mapFunc, True) + return self.partitionBy(numPartitions, rangePartitioner).mapPartitions(sortPartition, True) def sortBy(self, keyfunc, ascending=True, numPartitions=None): """ -- cgit v1.2.3