aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorPrashant Sharma <prashant.s@imaginea.com>2013-10-10 09:42:55 +0530
committerPrashant Sharma <prashant.s@imaginea.com>2013-10-10 09:42:55 +0530
commit026ab7566167e6c8ab1b0cce75b9e09bbd485bee (patch)
treea713bacba391eb9b8e07ca0d2f6521cd2b061b49 /python
parent26860639c5fee7fc23db1e686f8eb202921e4314 (diff)
parent320418f7c8b42d4ce781b32c9ee47a9b54550b89 (diff)
downloadspark-026ab7566167e6c8ab1b0cce75b9e09bbd485bee.tar.gz
spark-026ab7566167e6c8ab1b0cce75b9e09bbd485bee.tar.bz2
spark-026ab7566167e6c8ab1b0cce75b9e09bbd485bee.zip
Merge branch 'master' of github.com:apache/incubator-spark into scala-2.10
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/rdd.py60
1 file changed, 53 insertions, 7 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 33dc865256..245a132dfd 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -117,8 +117,6 @@ class RDD(object):
else:
return None
- # TODO persist(self, storageLevel)
-
def map(self, f, preservesPartitioning=False):
"""
Return a new RDD containing the distinct elements in this RDD.
@@ -227,7 +225,7 @@ class RDD(object):
total = num
samples = self.sample(withReplacement, fraction, seed).collect()
-
+
# If the first sample didn't turn out large enough, keep trying to take samples;
# this shouldn't happen often because we use a big multiplier for their initial size.
# See: scala/spark/RDD.scala
@@ -263,7 +261,55 @@ class RDD(object):
raise TypeError
return self.union(other)
- # TODO: sort
+ def sortByKey(self, ascending=True, numPartitions=None, keyfunc = lambda x: x):
+ """
+ Sorts this RDD, which is assumed to consist of (key, value) pairs.
+
+ >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
+ >>> sc.parallelize(tmp).sortByKey(True, 2).collect()
+ [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
+ >>> tmp2 = [('Mary', 1), ('had', 2), ('a', 3), ('little', 4), ('lamb', 5)]
+ >>> tmp2.extend([('whose', 6), ('fleece', 7), ('was', 8), ('white', 9)])
+ >>> sc.parallelize(tmp2).sortByKey(True, 3, keyfunc=lambda k: k.lower()).collect()
+ [('a', 3), ('fleece', 7), ('had', 2), ('lamb', 5), ('little', 4), ('Mary', 1), ('was', 8), ('white', 9), ('whose', 6)]
+ """
+ if numPartitions is None:
+ numPartitions = self.ctx.defaultParallelism
+
+ bounds = list()
+
+ # first compute the boundary of each part via sampling: we want to partition
+ # the key-space into bins such that the bins have roughly the same
+ # number of (key, value) pairs falling into them
+ if numPartitions > 1:
+ rddSize = self.count()
+ maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner
+ fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
+
+ samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect()
+ samples = sorted(samples, reverse=(not ascending), key=keyfunc)
+
+ # we have numPartitions many parts but one of the them has
+ # an implicit boundary
+ for i in range(0, numPartitions - 1):
+ index = (len(samples) - 1) * (i + 1) / numPartitions
+ bounds.append(samples[index])
+
+ def rangePartitionFunc(k):
+ p = 0
+ while p < len(bounds) and keyfunc(k) > bounds[p]:
+ p += 1
+ if ascending:
+ return p
+ else:
+ return numPartitions-1-p
+
+ def mapFunc(iterator):
+ yield sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
+
+ return (self.partitionBy(numPartitions, partitionFunc=rangePartitionFunc)
+ .mapPartitions(mapFunc,preservesPartitioning=True)
+ .flatMap(lambda x: x, preservesPartitioning=True))
def glom(self):
"""
@@ -425,7 +471,7 @@ class RDD(object):
3
"""
return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
-
+
def stats(self):
"""
Return a L{StatCounter} object that captures the mean, variance
@@ -462,7 +508,7 @@ class RDD(object):
0.816...
"""
return self.stats().stdev()
-
+
def sampleStdev(self):
"""
Compute the sample standard deviation of this RDD's elements (which corrects for bias in
@@ -832,7 +878,7 @@ class RDD(object):
>>> y = sc.parallelize([("a", 3), ("c", None)])
>>> sorted(x.subtractByKey(y).collect())
[('b', 4), ('b', 5)]
- """
+ """
filter_func = lambda (key, vals): len(vals[0]) > 0 and len(vals[1]) == 0
map_func = lambda (key, vals): [(key, val) for val in vals[0]]
return self.cogroup(other, numPartitions).filter(filter_func).flatMap(map_func)