path: root/python/pyspark/rdd.py
author     Prashant Sharma <prashant.s@imaginea.com>    2014-04-03 15:42:17 -0700
committer  Matei Zaharia <matei@databricks.com>         2014-04-03 15:42:17 -0700
commit     c1ea3afb516c204925259f0928dfb17d0fa89621 (patch)
tree       7d8a3485101dd9e191d858058d40e11ec3a7461e /python/pyspark/rdd.py
parent     5d1feda217d25616d190f9bb369664e57417cd45 (diff)
download   spark-c1ea3afb516c204925259f0928dfb17d0fa89621.tar.gz
           spark-c1ea3afb516c204925259f0928dfb17d0fa89621.tar.bz2
           spark-c1ea3afb516c204925259f0928dfb17d0fa89621.zip
Spark 1162 Implemented takeOrdered in pyspark.
Since Python does not have a library for a max heap, and the usual tricks such as inverting values do not work for all cases, we use our own implementation of a max heap.

Author: Prashant Sharma <prashant.s@imaginea.com>

Closes #97 from ScrapCodes/SPARK-1162/pyspark-top-takeOrdered2 and squashes the following commits:

35f86ba [Prashant Sharma] code review
2b1124d [Prashant Sharma] fixed tests
e8a08e2 [Prashant Sharma] Code review comments.
49e6ba7 [Prashant Sharma] SPARK-1162 added takeOrdered to pyspark
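In short, takeOrdered() keeps a bounded max-heap of at most num elements per partition: once the heap is full, an incoming element smaller than the current maximum evicts that maximum, so only the num smallest elements survive, and the per-partition results are then merged with reduce. The sketch below is not part of the patch; it shows the same bounded-heap idea with heapq on negated values, a trick that only works for numeric elements, which is exactly the limitation that motivates the MaxHeapQ class added in this commit.

    # Illustrative only (not from the patch): the bounded max-heap idea behind
    # takeOrdered, using heapq on negated values. Negation restricts this to
    # numbers; MaxHeapQ avoids that so any comparable elements work.
    import heapq

    def n_smallest(iterator, num):
        heap = []                       # negated values; current maximum is -heap[0]
        for v in iterator:
            if len(heap) < num:
                heapq.heappush(heap, -v)
            elif -heap[0] > v:          # new value beats the current maximum, evict it
                heapq.heappushpop(heap, -v)
        return sorted(-v for v in heap)

    print(n_smallest(iter([10, 1, 2, 9, 3, 4, 5, 6, 7]), 6))   # [1, 2, 3, 4, 5, 6]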
Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r--  python/pyspark/rdd.py  107
1 file changed, 102 insertions(+), 5 deletions(-)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 019c249699..9943296b92 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -29,7 +29,7 @@ from subprocess import Popen, PIPE
 from tempfile import NamedTemporaryFile
 from threading import Thread
 import warnings
-from heapq import heappush, heappop, heappushpop
+import heapq
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long
@@ -41,9 +41,9 @@ from pyspark.storagelevel import StorageLevel
 from py4j.java_collections import ListConverter, MapConverter
-
 __all__ = ["RDD"]
+
 def _extract_concise_traceback():
     """
     This function returns the traceback info for a callsite, returns a dict
@@ -91,6 +91,73 @@ class _JavaStackTrace(object):
             if _spark_stack_depth == 0:
                 self._context._jsc.setCallSite(None)
+class MaxHeapQ(object):
+    """
+    An implementation of MaxHeap.
+    >>> import pyspark.rdd
+    >>> heap = pyspark.rdd.MaxHeapQ(5)
+    >>> [heap.insert(i) for i in range(10)]
+    [None, None, None, None, None, None, None, None, None, None]
+    >>> sorted(heap.getElements())
+    [0, 1, 2, 3, 4]
+    >>> heap = pyspark.rdd.MaxHeapQ(5)
+    >>> [heap.insert(i) for i in range(9, -1, -1)]
+    [None, None, None, None, None, None, None, None, None, None]
+    >>> sorted(heap.getElements())
+    [0, 1, 2, 3, 4]
+    >>> heap = pyspark.rdd.MaxHeapQ(1)
+    >>> [heap.insert(i) for i in range(9, -1, -1)]
+    [None, None, None, None, None, None, None, None, None, None]
+    >>> heap.getElements()
+    [0]
+    """
+
+    def __init__(self, maxsize):
+        # we start from q[1], this makes calculating children as trivial as 2 * k
+        self.q = [0]
+        self.maxsize = maxsize
+
+    def _swim(self, k):
+        while (k > 1) and (self.q[k/2] < self.q[k]):
+            self._swap(k, k/2)
+            k = k/2
+
+    def _swap(self, i, j):
+        t = self.q[i]
+        self.q[i] = self.q[j]
+        self.q[j] = t
+
+    def _sink(self, k):
+        N = self.size()
+        while 2 * k <= N:
+            j = 2 * k
+            # Here we test if both children are greater than parent
+            # if not swap with larger one.
+            if j < N and self.q[j] < self.q[j + 1]:
+                j = j + 1
+            if(self.q[k] > self.q[j]):
+                break
+            self._swap(k, j)
+            k = j
+
+    def size(self):
+        return len(self.q) - 1
+
+    def insert(self, value):
+        if (self.size()) < self.maxsize:
+            self.q.append(value)
+            self._swim(self.size())
+        else:
+            self._replaceRoot(value)
+
+    def getElements(self):
+        return self.q[1:]
+
+    def _replaceRoot(self, value):
+        if(self.q[1] > value):
+            self.q[1] = value
+            self._sink(1)
+
 class RDD(object):
     """
     A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
@@ -696,16 +763,16 @@ class RDD(object):
         Note: It returns the list sorted in descending order.
         >>> sc.parallelize([10, 4, 2, 12, 3]).top(1)
         [12]
-        >>> sc.parallelize([2, 3, 4, 5, 6]).cache().top(2)
+        >>> sc.parallelize([2, 3, 4, 5, 6], 2).cache().top(2)
         [6, 5]
         """
         def topIterator(iterator):
             q = []
             for k in iterator:
                 if len(q) < num:
-                    heappush(q, k)
+                    heapq.heappush(q, k)
                 else:
-                    heappushpop(q, k)
+                    heapq.heappushpop(q, k)
             yield q
         def merge(a, b):
@@ -713,6 +780,36 @@ class RDD(object):
         return sorted(self.mapPartitions(topIterator).reduce(merge), reverse=True)
+    def takeOrdered(self, num, key=None):
+        """
+        Get the N elements from a RDD ordered in ascending order or as specified
+        by the optional key function.
+
+        >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7]).takeOrdered(6)
+        [1, 2, 3, 4, 5, 6]
+        >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7], 2).takeOrdered(6, key=lambda x: -x)
+        [10, 9, 7, 6, 5, 4]
+        """
+
+        def topNKeyedElems(iterator, key_=None):
+            q = MaxHeapQ(num)
+            for k in iterator:
+                if key_ != None:
+                    k = (key_(k), k)
+                q.insert(k)
+            yield q.getElements()
+
+        def unKey(x, key_=None):
+            if key_ != None:
+                x = [i[1] for i in x]
+            return x
+
+        def merge(a, b):
+            return next(topNKeyedElems(a + b))
+        result = self.mapPartitions(lambda i: topNKeyedElems(i, key)).reduce(merge)
+        return sorted(unKey(result, key), key=key)
+
+
     def take(self, num):
         """
         Take the first num elements of the RDD.
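A closing note on the key handling in takeOrdered above (a sketch, not part of the patch): when a key function is given, each element is wrapped as a (key(element), element) tuple so that ordinary tuple comparison orders the heap by key, and unKey strips the wrapper again before the final sort.

    # Illustrative sketch of the (key(x), x) wrapping used by takeOrdered above.
    data = [10, 1, 2, 9, 3]
    key = lambda x: -x

    keyed = [(key(x), x) for x in data]      # (-10, 10), (-1, 1), (-2, 2), ...
    smallest_by_key = sorted(keyed)[:3]      # ascending by key, i.e. largest originals
    print([x for _, x in smallest_by_key])   # [10, 9, 3]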