about summary refs log tree commit diff
path: root/python
diff options
context:
space:
mode:
author: Davies Liu <davies.liu@gmail.com> 2014-08-19 22:43:49 -0700
committer: Patrick Wendell <pwendell@gmail.com> 2014-08-19 22:43:49 -0700
commit 0a7ef6339f18e68d703599aff7db2dd9c2003866 (patch)
tree cd77d203174069d00d1af05f03e0a98eeec8b8cf /python
parent 8a74e4b2a8c7dab154b406539487cf29d578d208 (diff)
download spark-0a7ef6339f18e68d703599aff7db2dd9c2003866.tar.gz
spark-0a7ef6339f18e68d703599aff7db2dd9c2003866.tar.bz2
spark-0a7ef6339f18e68d703599aff7db2dd9c2003866.zip
[SPARK-3141] [PySpark] fix sortByKey() with take()
Fix sortByKey() with take(). The function `f` used in mapPartitions should always return an iterator.

Author: Davies Liu <davies.liu@gmail.com>

Closes #2045 from davies/fix_sortbykey and squashes the following commits:

1160f59 [Davies Liu] fix sortByKey() with take()
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/rdd.py18
1 file changed, 8 insertions(+), 10 deletions(-)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 140cbe05a4..3eefc878d2 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -575,6 +575,8 @@ class RDD(object):
# noqa
>>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
+ >>> sc.parallelize(tmp).sortByKey().first()
+ ('1', 3)
>>> sc.parallelize(tmp).sortByKey(True, 1).collect()
[('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
>>> sc.parallelize(tmp).sortByKey(True, 2).collect()
@@ -587,14 +589,13 @@ class RDD(object):
if numPartitions is None:
numPartitions = self._defaultReducePartitions()
+ def sortPartition(iterator):
+ return iter(sorted(iterator, key=lambda (k, v): keyfunc(k), reverse=not ascending))
+
if numPartitions == 1:
if self.getNumPartitions() > 1:
self = self.coalesce(1)
-
- def sort(iterator):
- return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
-
- return self.mapPartitions(sort)
+ return self.mapPartitions(sortPartition)
# first compute the boundary of each part via sampling: we want to partition
# the key-space into bins such that the bins have roughly the same
@@ -610,17 +611,14 @@ class RDD(object):
bounds = [samples[len(samples) * (i + 1) / numPartitions]
for i in range(0, numPartitions - 1)]
- def rangePartitionFunc(k):
+ def rangePartitioner(k):
p = bisect.bisect_left(bounds, keyfunc(k))
if ascending:
return p
else:
return numPartitions - 1 - p
- def mapFunc(iterator):
- return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
-
- return self.partitionBy(numPartitions, rangePartitionFunc).mapPartitions(mapFunc, True)
+ return self.partitionBy(numPartitions, rangePartitioner).mapPartitions(sortPartition, True)
def sortBy(self, keyfunc, ascending=True, numPartitions=None):
"""