author | Prashant Sharma <prashant.s@imaginea.com> | 2013-10-10 09:42:55 +0530 |
---|---|---|
committer | Prashant Sharma <prashant.s@imaginea.com> | 2013-10-10 09:42:55 +0530 |
commit | 026ab7566167e6c8ab1b0cce75b9e09bbd485bee (patch) | |
tree | a713bacba391eb9b8e07ca0d2f6521cd2b061b49 /python/pyspark | |
parent | 26860639c5fee7fc23db1e686f8eb202921e4314 (diff) | |
parent | 320418f7c8b42d4ce781b32c9ee47a9b54550b89 (diff) | |
download | spark-026ab7566167e6c8ab1b0cce75b9e09bbd485bee.tar.gz spark-026ab7566167e6c8ab1b0cce75b9e09bbd485bee.tar.bz2 spark-026ab7566167e6c8ab1b0cce75b9e09bbd485bee.zip |
Merge branch 'master' of github.com:apache/incubator-spark into scala-2.10
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/rdd.py | 60 |
1 file changed, 53 insertions, 7 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 33dc865256..245a132dfd 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -117,8 +117,6 @@ class RDD(object):
         else:
             return None
 
-    # TODO persist(self, storageLevel)
-
     def map(self, f, preservesPartitioning=False):
         """
         Return a new RDD by applying a function to each element of this RDD.
@@ -227,7 +225,7 @@ class RDD(object):
             total = num
 
         samples = self.sample(withReplacement, fraction, seed).collect()
-
+
         # If the first sample didn't turn out large enough, keep trying to take samples;
         # this shouldn't happen often because we use a big multiplier for their initial size.
         # See: scala/spark/RDD.scala
@@ -263,7 +261,55 @@ class RDD(object):
             raise TypeError
         return self.union(other)
 
-    # TODO: sort
+    def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
+        """
+        Sorts this RDD, which is assumed to consist of (key, value) pairs.
+
+        >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
+        >>> sc.parallelize(tmp).sortByKey(True, 2).collect()
+        [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
+        >>> tmp2 = [('Mary', 1), ('had', 2), ('a', 3), ('little', 4), ('lamb', 5)]
+        >>> tmp2.extend([('whose', 6), ('fleece', 7), ('was', 8), ('white', 9)])
+        >>> sc.parallelize(tmp2).sortByKey(True, 3, keyfunc=lambda k: k.lower()).collect()
+        [('a', 3), ('fleece', 7), ('had', 2), ('lamb', 5), ('little', 4), ('Mary', 1), ('was', 8), ('white', 9), ('whose', 6)]
+        """
+        if numPartitions is None:
+            numPartitions = self.ctx.defaultParallelism
+
+        bounds = list()
+
+        # first compute the boundary of each part via sampling: we want to partition
+        # the key-space into bins such that the bins have roughly the same
+        # number of (key, value) pairs falling into them
+        if numPartitions > 1:
+            rddSize = self.count()
+            maxSampleSize = numPartitions * 20.0  # constant from Spark's RangePartitioner
+            fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
+
+            samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect()
+            samples = sorted(samples, reverse=(not ascending), key=keyfunc)
+
+            # we have numPartitions many parts but one of them has
+            # an implicit boundary
+            for i in range(0, numPartitions - 1):
+                index = (len(samples) - 1) * (i + 1) / numPartitions
+                bounds.append(samples[index])
+
+        def rangePartitionFunc(k):
+            p = 0
+            while p < len(bounds) and keyfunc(k) > bounds[p]:
+                p += 1
+            if ascending:
+                return p
+            else:
+                return numPartitions - 1 - p
+
+        def mapFunc(iterator):
+            yield sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
+
+        return (self.partitionBy(numPartitions, partitionFunc=rangePartitionFunc)
+                    .mapPartitions(mapFunc, preservesPartitioning=True)
+                    .flatMap(lambda x: x, preservesPartitioning=True))
 
     def glom(self):
         """
@@ -425,7 +471,7 @@ class RDD(object):
         3
         """
         return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
-
+
     def stats(self):
         """
         Return a L{StatCounter} object that captures the mean, variance
@@ -462,7 +508,7 @@ class RDD(object):
         0.816...
         """
        return self.stats().stdev()
-
+
     def sampleStdev(self):
         """
         Compute the sample standard deviation of this RDD's elements (which corrects for bias in
@@ -832,7 +878,7 @@ class RDD(object):
         >>> y = sc.parallelize([("a", 3), ("c", None)])
         >>> sorted(x.subtractByKey(y).collect())
         [('b', 4), ('b', 5)]
-        """
+        """
         filter_func = lambda (key, vals): len(vals[0]) > 0 and len(vals[1]) == 0
         map_func = lambda (key, vals): [(key, val) for val in vals[0]]
         return self.cogroup(other, numPartitions).filter(filter_func).flatMap(map_func)
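The patch replaces the old `# TODO: sort` placeholder with a sampling-based range-partitioned sort: it samples the RDD's keys, picks `numPartitions - 1` boundary keys from the sorted sample, routes each (key, value) pair to the first bucket whose boundary is not exceeded by its key, and sorts each partition locally, so that concatenating the partitions yields a globally sorted RDD (`mapFunc` yields each partition as a single sorted list, which the trailing `flatMap(lambda x: x)` flattens back into pairs). Below is a minimal, Spark-free sketch of that same logic. It is written for Python 3, since the patch's tuple-unpacking lambdas and `/` integer division are Python 2 only, it handles the ascending case only, and every name in it (`compute_bounds`, `range_partition`, `sort_by_key`) is illustrative rather than PySpark API.

```python
# Minimal, Spark-free sketch of the range-partitioned sort added by this patch.
# Python 3; ascending order only. All names are illustrative, not PySpark API.
import random


def compute_bounds(pairs, num_partitions, keyfunc=lambda x: x, fraction=1.0):
    """Sample keys and pick num_partitions - 1 boundary keys from the sorted sample."""
    sampled = [k for k, _ in pairs if random.random() <= fraction]
    samples = sorted(sampled, key=keyfunc)
    bounds = []
    for i in range(num_partitions - 1):
        # The patch writes '/', which is integer division in Python 2;
        # Python 3 needs '//' or samples[index] would be indexed by a float.
        index = (len(samples) - 1) * (i + 1) // num_partitions
        bounds.append(samples[index])
    return bounds


def range_partition(key, bounds, keyfunc=lambda x: x):
    """Linear scan over the boundaries, mirroring the patch's rangePartitionFunc."""
    p = 0
    while p < len(bounds) and keyfunc(key) > bounds[p]:
        p += 1
    return p


def sort_by_key(pairs, num_partitions, keyfunc=lambda x: x):
    """Range-partition the pairs, sort each partition locally, and concatenate."""
    bounds = compute_bounds(pairs, num_partitions, keyfunc) if num_partitions > 1 else []
    partitions = [[] for _ in range(num_partitions)]
    for k, v in pairs:
        partitions[range_partition(k, bounds, keyfunc)].append((k, v))
    # Because the partitions cover disjoint, ordered key ranges, sorting each
    # one locally makes the concatenation globally sorted.
    return [kv for part in partitions for kv in sorted(part, key=lambda kv: keyfunc(kv[0]))]


if __name__ == "__main__":
    tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
    print(sort_by_key(tmp, 2))
    # [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
```

Run on the first doctest's data, the sketch reproduces the doctest's output, since with two partitions the single boundary key splits `'1', '2', 'a'` from `'b', 'd'`.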
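One detail worth calling out: the sample size targets roughly 20 keys per output partition (the `numPartitions * 20.0` constant, which the patch's comment attributes to Spark's Scala `RangePartitioner`), and the fraction is capped at 1.0 so small RDDs are simply sampled in full. A quick worked check of that arithmetic, with made-up sizes:

```python
# Worked example of the patch's sampling arithmetic (sizes are made up).
numPartitions = 4
rddSize = 1000000                       # pretend count() returned a million pairs
maxSampleSize = numPartitions * 20.0    # target ~20 sampled keys per partition
fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
print(fraction)                         # 8e-05, i.e. about 80 sampled keys
```

With about 80 samples, each of the three boundary keys estimates a quantile of the key distribution, which is what keeps the output partitions roughly equal in size.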