author      Davies Liu <davies.liu@gmail.com>      2014-08-13 14:57:12 -0700
committer   Matei Zaharia <matei@databricks.com>   2014-08-13 14:57:12 -0700
commit      434bea1c002b597cff9db899da101490e1f1e9ed (patch)
tree        4ddf5a151343d3c4012ff248af6efe3ef295c04d
parent      c974a716e17c9fe2628b1ba1d4309ead1bd855ad (diff)
[SPARK-2983] [PySpark] improve performance of sortByKey()
1. Skip partitionBy() when numPartitions is 1.
2. Use bisect_left (O(log N)) instead of a linear scan (O(N)) in the range partitioner.

Author: Davies Liu <davies.liu@gmail.com>

Closes #1898 from davies/sort and squashes the following commits:

0a9608b [Davies Liu] Merge branch 'master' into sort
1cf9565 [Davies Liu] improve performance of sortByKey()
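For context, here is a minimal standalone sketch of the second optimization, independent of PySpark: given a sorted list of partition boundaries, bisect_left locates a key's partition in O(log N) instead of scanning the list one bound at a time. The helper names and the example bounds below are illustrative only, not part of the patch.

import bisect

bounds = ['d', 'm', 't']  # upper bounds of partitions 0..2; keys above 't' fall into partition 3

def find_partition_linear(key):
    # old approach: walk the boundary list until the key fits, O(N)
    p = 0
    while p < len(bounds) and key > bounds[p]:
        p += 1
    return p

def find_partition_bisect(key):
    # new approach: binary search over the same boundaries, O(log N)
    return bisect.bisect_left(bounds, key)

for key in ['apple', 'dog', 'mango', 'zebra']:
    assert find_partition_linear(key) == find_partition_bisect(key)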
-rw-r--r--   python/pyspark/rdd.py   47
1 file changed, 24 insertions(+), 23 deletions(-)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 756e8f35fb..3934bdda0a 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -30,6 +30,7 @@ from tempfile import NamedTemporaryFile
from threading import Thread
import warnings
import heapq
+import bisect
from random import Random
from math import sqrt, log
@@ -574,6 +575,8 @@ class RDD(object):
# noqa
>>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
+ >>> sc.parallelize(tmp).sortByKey(True, 1).collect()
+ [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
>>> sc.parallelize(tmp).sortByKey(True, 2).collect()
[('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
>>> tmp2 = [('Mary', 1), ('had', 2), ('a', 3), ('little', 4), ('lamb', 5)]
@@ -584,42 +587,40 @@ class RDD(object):
if numPartitions is None:
numPartitions = self._defaultReducePartitions()
- bounds = list()
+ if numPartitions == 1:
+ if self.getNumPartitions() > 1:
+ self = self.coalesce(1)
+
+ def sort(iterator):
+ return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
+
+ return self.mapPartitions(sort)
# first compute the boundary of each part via sampling: we want to partition
# the key-space into bins such that the bins have roughly the same
# number of (key, value) pairs falling into them
- if numPartitions > 1:
- rddSize = self.count()
- # constant from Spark's RangePartitioner
- maxSampleSize = numPartitions * 20.0
- fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
-
- samples = self.sample(False, fraction, 1).map(
- lambda (k, v): k).collect()
- samples = sorted(samples, reverse=(not ascending), key=keyfunc)
-
- # we have numPartitions many parts but one of the them has
- # an implicit boundary
- for i in range(0, numPartitions - 1):
- index = (len(samples) - 1) * (i + 1) / numPartitions
- bounds.append(samples[index])
+ rddSize = self.count()
+ maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner
+ fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
+ samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect()
+ samples = sorted(samples, reverse=(not ascending), key=keyfunc)
+
+ # we have numPartitions many parts but one of the them has
+ # an implicit boundary
+ bounds = [samples[len(samples) * (i + 1) / numPartitions]
+ for i in range(0, numPartitions - 1)]
def rangePartitionFunc(k):
- p = 0
- while p < len(bounds) and keyfunc(k) > bounds[p]:
- p += 1
+ p = bisect.bisect_left(bounds, keyfunc(k))
if ascending:
return p
else:
return numPartitions - 1 - p
def mapFunc(iterator):
- yield sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
+ return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
- return (self.partitionBy(numPartitions, partitionFunc=rangePartitionFunc)
- .mapPartitions(mapFunc, preservesPartitioning=True)
- .flatMap(lambda x: x, preservesPartitioning=True))
+ return self.partitionBy(numPartitions, rangePartitionFunc).mapPartitions(mapFunc, True)
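To summarize the restructured method, below is a rough, self-contained Python sketch of the flow the patched sortByKey() follows, using plain lists instead of RDDs. sort_by_key_sketch and its parameters are made-up names for illustration, and every key is used as a "sample" here rather than drawing a random sample the way the real code does.

import bisect

def sort_by_key_sketch(pairs, num_partitions, ascending=True, keyfunc=lambda x: x):
    # Fast path added by this patch: with a single output partition there is
    # nothing to range-partition, so just sort everything in one place.
    if num_partitions == 1:
        return [sorted(pairs, key=lambda kv: keyfunc(kv[0]), reverse=not ascending)]

    # Pick partition boundaries from the keys so each bin gets a similar share
    # (the real code samples the RDD; here every key acts as a sample).
    samples = sorted(keyfunc(k) for k, _ in pairs)
    bounds = [samples[len(samples) * (i + 1) // num_partitions]
              for i in range(num_partitions - 1)] if samples else []

    # Route each pair to a partition via binary search, then sort each partition
    # locally; the boundaries stay in ascending order and the index is flipped
    # when a descending sort is requested.
    partitions = [[] for _ in range(num_partitions)]
    for k, v in pairs:
        p = bisect.bisect_left(bounds, keyfunc(k))
        partitions[p if ascending else num_partitions - 1 - p].append((k, v))
    return [sorted(part, key=lambda kv: keyfunc(kv[0]), reverse=not ascending)
            for part in partitions]

# Concatenating the partitions in order yields a globally sorted sequence,
# mirroring what partitionBy(...).mapPartitions(sort) produces across a cluster.
parts = sort_by_key_sketch([('b', 2), ('d', 4), ('a', 1), ('c', 3)], 2)
assert [kv for part in parts for kv in part] == [('a', 1), ('b', 2), ('c', 3), ('d', 4)]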
def sortBy(self, keyfunc, ascending=True, numPartitions=None):
"""