From fdbae41e88af0994e97ac8741f86547184a05de9 Mon Sep 17 00:00:00 2001
From: Andre Schumacher
Date: Mon, 7 Oct 2013 10:42:39 -0700
Subject: SPARK-705: implement sortByKey() in PySpark

---
 python/pyspark/rdd.py | 48 +++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 47 insertions(+), 1 deletion(-)

(limited to 'python/pyspark/rdd.py')

diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 39c402b412..7dfabb0b7d 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -263,7 +263,53 @@ class RDD(object):
             raise TypeError
         return self.union(other)
 
-    # TODO: sort
+    def sortByKey(self, ascending=True, numPartitions=None, keyfunc = lambda x: x):
+        """
+        Sorts this RDD, which is assumed to consist of (key, value) pairs.
+
+        >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
+        >>> sc.parallelize(tmp).sortByKey(True, 2).collect()
+        [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
+        >>> tmp2 = [('Mary', 1), ('had', 2), ('a', 3), ('little', 4), ('lamb', 5)]
+        >>> tmp2.extend([('whose', 6), ('fleece', 7), ('was', 8), ('white', 9)])
+        >>> sc.parallelize(tmp2).sortByKey(True, 3, keyfunc=lambda k: k.lower()).collect()
+        [('a', 3), ('fleece', 7), ('had', 2), ('lamb', 5), ('little', 4), ('Mary', 1), ('was', 8), ('white', 9), ('whose', 6)]
+        """
+        if numPartitions is None:
+            numPartitions = self.ctx.defaultParallelism
+
+        bounds = list()
+
+        # first compute the boundary of each part via sampling: we want to partition
+        # the key-space into bins such that the bins have roughly the same
+        # number of (key, value) pairs falling into them
+        if numPartitions > 1:
+            rddSize = self.count()
+            maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner
+            fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
+
+            samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect()
+            samples = sorted(samples, reverse=(not ascending), key=keyfunc)
+
+            # we have numPartitions many parts but one of the them has
+            # an implicit boundary
+            for i in range(0, numPartitions - 1):
+                index = (len(samples) - 1) * (i + 1) / numPartitions
+                bounds.append(samples[index])
+
+        def rangePartitionFunc(k):
+            p = 0
+            while p < len(bounds) and keyfunc(k) > bounds[p]:
+                p += 1
+            if ascending:
+                return p
+            else:
+                return numPartitions-1-p
+
+        def mapFunc(iterator):
+            yield sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
+
+        return self.partitionBy(numPartitions, partitionFunc=rangePartitionFunc).mapPartitions(mapFunc,preservesPartitioning=True).flatMap(lambda x: x, preservesPartitioning=True)
 
     def glom(self):
         """
-- 
cgit v1.2.3

From 478b2b7edcf42fa3e16f625d4b8676f2bb31f8dc Mon Sep 17 00:00:00 2001
From: Matei Zaharia
Date: Wed, 9 Oct 2013 12:08:04 -0700
Subject: Fix PySpark docs and an overly long line of code after fdbae41e

---
 docs/python-programming-guide.md |  2 +-
 python/pyspark/rdd.py            | 16 ++++++++--------
 2 files changed, 9 insertions(+), 9 deletions(-)

(limited to 'python/pyspark/rdd.py')

diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md
index f67a1cc49c..6c2336ad0c 100644
--- a/docs/python-programming-guide.md
+++ b/docs/python-programming-guide.md
@@ -16,7 +16,7 @@ This guide will show how to use the Spark features described there in Python.
 There are a few key differences between the Python and Scala APIs:
 
 * Python is dynamically typed, so RDDs can hold objects of multiple types.
-* PySpark does not yet support a few API calls, such as `lookup`, `sort`, and non-text input files, though these will be added in future releases.
+* PySpark does not yet support a few API calls, such as `lookup` and non-text input files, though these will be added in future releases.
 
 In PySpark, RDDs support the same methods as their Scala counterparts but take Python functions and return Python collection types. Short functions can be passed to RDD methods using Python's [`lambda`](http://www.diveintopython.net/power_of_introspection/lambda_functions.html) syntax:
 
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 7dfabb0b7d..7019fb8bee 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -117,8 +117,6 @@ class RDD(object):
         else:
             return None
 
-    # TODO persist(self, storageLevel)
-
     def map(self, f, preservesPartitioning=False):
         """
         Return a new RDD containing the distinct elements in this RDD.
@@ -227,7 +225,7 @@ class RDD(object):
             total = num
 
         samples = self.sample(withReplacement, fraction, seed).collect()
-
+
         # If the first sample didn't turn out large enough, keep trying to take samples;
         # this shouldn't happen often because we use a big multiplier for their initial size.
         # See: scala/spark/RDD.scala
@@ -288,7 +286,7 @@ class RDD(object):
             maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner
             fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
 
-            samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect()
+            samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect()
             samples = sorted(samples, reverse=(not ascending), key=keyfunc)
 
             # we have numPartitions many parts but one of the them has
@@ -309,7 +307,9 @@ class RDD(object):
         def mapFunc(iterator):
             yield sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
 
-        return self.partitionBy(numPartitions, partitionFunc=rangePartitionFunc).mapPartitions(mapFunc,preservesPartitioning=True).flatMap(lambda x: x, preservesPartitioning=True)
+        return (self.partitionBy(numPartitions, partitionFunc=rangePartitionFunc)
+                .mapPartitions(mapFunc,preservesPartitioning=True)
+                .flatMap(lambda x: x, preservesPartitioning=True))
 
     def glom(self):
         """
@@ -471,7 +471,7 @@ class RDD(object):
         3
         """
         return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
-
+
     def stats(self):
         """
         Return a L{StatCounter} object that captures the mean, variance
@@ -508,7 +508,7 @@ class RDD(object):
         0.816...
         """
         return self.stats().stdev()
-
+
     def sampleStdev(self):
         """
         Compute the sample standard deviation of this RDD's elements (which corrects for bias in
@@ -878,7 +878,7 @@ class RDD(object):
         >>> y = sc.parallelize([("a", 3), ("c", None)])
         >>> sorted(x.subtractByKey(y).collect())
         [('b', 4), ('b', 5)]
-        """
+        """
         filter_func = lambda (key, vals): len(vals[0]) > 0 and len(vals[1]) == 0
         map_func = lambda (key, vals): [(key, val) for val in vals[0]]
         return self.cogroup(other, numPartitions).filter(filter_func).flatMap(map_func)
-- 
cgit v1.2.3
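
For readers following the sortByKey() change above: it samples the keys, sorts the sample, and picks numPartitions - 1 cut points so each partition receives a roughly equal share of keys; every key is then routed to the first partition whose bound is not smaller than it, and each partition is sorted locally, so reading the partitions in order yields a globally sorted RDD. The sketch below illustrates that range-partitioning idea in plain, standalone Python. It is an illustration only, not Spark code: the helper names (compute_bounds, range_partition) and the samples_per_part parameter are invented for this example, and Python 3 syntax replaces the patch's Python 2 tuple-parameter lambdas such as lambda (k, v): k.

    import random

    def compute_bounds(keys, num_partitions, samples_per_part=20):
        """Pick num_partitions - 1 upper bounds from a sample of the keys."""
        if num_partitions <= 1 or not keys:
            return []
        # Sample roughly samples_per_part keys per partition (mirrors the 20.0
        # constant the patch borrows from Spark's RangePartitioner).
        fraction = min(float(num_partitions * samples_per_part) / len(keys), 1.0)
        sample = sorted(k for k in keys if random.random() < fraction)
        if not sample:
            return []
        # One partition's boundary stays implicit, so num_partitions - 1 bounds suffice.
        return [sample[(len(sample) - 1) * (i + 1) // num_partitions]
                for i in range(num_partitions - 1)]

    def range_partition(key, bounds):
        """Route a key to the first partition whose upper bound is not smaller than it."""
        p = 0
        while p < len(bounds) and key > bounds[p]:
            p += 1
        return p

    # Example: spread 26 single-letter keys over 3 ordered, roughly balanced partitions.
    keys = [chr(ord('a') + i) for i in range(26)]
    bounds = compute_bounds(keys, 3)
    partitions = {}
    for k in keys:
        partitions.setdefault(range_partition(k, bounds), []).append(k)
    print(bounds)        # e.g. ['i', 'q'] when every key ends up in the sample
    print(partitions)    # keys <= 'i' go to 0, <= 'q' to 1, the rest to 2

Sampling keeps the driver-side work proportional to numPartitions * 20 keys rather than to the full RDD size; in the patch the same routing is performed by rangePartitionFunc inside partitionBy, followed by the per-partition sort in mapPartitions.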