aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorDavies Liu <davies.liu@gmail.com>2014-08-23 18:55:13 -0700
committerJosh Rosen <joshrosen@apache.org>2014-08-23 18:55:13 -0700
commitdb436e36c4e20893de708a0bc07a5a8877c49563 (patch)
tree4ed5910b6ab36a97eee5d7dae35e8db992a82894 /python
parent3519b5e8e55b4530d7f7c0bcab254f863dbfa814 (diff)
downloadspark-db436e36c4e20893de708a0bc07a5a8877c49563.tar.gz
spark-db436e36c4e20893de708a0bc07a5a8877c49563.tar.bz2
spark-db436e36c4e20893de708a0bc07a5a8877c49563.zip
[SPARK-2871] [PySpark] add `key` argument for max(), min() and top(n)
RDD.max(key=None) param key: A function used to generate key for comparing >>> rdd = sc.parallelize([1.0, 5.0, 43.0, 10.0]) >>> rdd.max() 43.0 >>> rdd.max(key=str) 5.0 RDD.min(key=None) Find the minimum item in this RDD. param key: A function used to generate key for comparing >>> rdd = sc.parallelize([2.0, 5.0, 43.0, 10.0]) >>> rdd.min() 2.0 >>> rdd.min(key=str) 10.0 RDD.top(num, key=None) Get the top N elements from a RDD. Note: It returns the list sorted in descending order. >>> sc.parallelize([10, 4, 2, 12, 3]).top(1) [12] >>> sc.parallelize([2, 3, 4, 5, 6], 2).top(2) [6, 5] >>> sc.parallelize([10, 4, 2, 12, 3]).top(3, key=str) [4, 3, 2] Author: Davies Liu <davies.liu@gmail.com> Closes #2094 from davies/cmp and squashes the following commits: ccbaf25 [Davies Liu] add `key` to top() ad7e374 [Davies Liu] fix tests 2f63512 [Davies Liu] change `comp` to `key` in min/max dd91e08 [Davies Liu] add `comp` argument for RDD.max() and RDD.min()
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/rdd.py44
1 file changed, 27 insertions, 17 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 3eefc878d2..bdd8bc8286 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -810,23 +810,37 @@ class RDD(object):
return self.mapPartitions(func).fold(zeroValue, combOp)
- def max(self):
+ def max(self, key=None):
"""
Find the maximum item in this RDD.
- >>> sc.parallelize([1.0, 5.0, 43.0, 10.0]).max()
+ @param key: A function used to generate key for comparing
+
+ >>> rdd = sc.parallelize([1.0, 5.0, 43.0, 10.0])
+ >>> rdd.max()
43.0
+ >>> rdd.max(key=str)
+ 5.0
"""
- return self.reduce(max)
+ if key is None:
+ return self.reduce(max)
+ return self.reduce(lambda a, b: max(a, b, key=key))
- def min(self):
+ def min(self, key=None):
"""
Find the minimum item in this RDD.
- >>> sc.parallelize([1.0, 5.0, 43.0, 10.0]).min()
- 1.0
+ @param key: A function used to generate key for comparing
+
+ >>> rdd = sc.parallelize([2.0, 5.0, 43.0, 10.0])
+ >>> rdd.min()
+ 2.0
+ >>> rdd.min(key=str)
+ 10.0
"""
- return self.reduce(min)
+ if key is None:
+ return self.reduce(min)
+ return self.reduce(lambda a, b: min(a, b, key=key))
def sum(self):
"""
@@ -924,7 +938,7 @@ class RDD(object):
return m1
return self.mapPartitions(countPartition).reduce(mergeMaps)
- def top(self, num):
+ def top(self, num, key=None):
"""
Get the top N elements from a RDD.
@@ -933,20 +947,16 @@ class RDD(object):
[12]
>>> sc.parallelize([2, 3, 4, 5, 6], 2).top(2)
[6, 5]
+ >>> sc.parallelize([10, 4, 2, 12, 3]).top(3, key=str)
+ [4, 3, 2]
"""
def topIterator(iterator):
- q = []
- for k in iterator:
- if len(q) < num:
- heapq.heappush(q, k)
- else:
- heapq.heappushpop(q, k)
- yield q
+ yield heapq.nlargest(num, iterator, key=key)
def merge(a, b):
- return next(topIterator(a + b))
+ return heapq.nlargest(num, a + b, key=key)
- return sorted(self.mapPartitions(topIterator).reduce(merge), reverse=True)
+ return self.mapPartitions(topIterator).reduce(merge)
def takeOrdered(self, num, key=None):
"""