From c1ea3afb516c204925259f0928dfb17d0fa89621 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Thu, 3 Apr 2014 15:42:17 -0700 Subject: Spark 1162 Implemented takeOrdered in pyspark. Since python does not have a library for max heap and usual tricks like inverting values etc.. does not work for all cases. We have our own implementation of max heap. Author: Prashant Sharma Closes #97 from ScrapCodes/SPARK-1162/pyspark-top-takeOrdered2 and squashes the following commits: 35f86ba [Prashant Sharma] code review 2b1124d [Prashant Sharma] fixed tests e8a08e2 [Prashant Sharma] Code review comments. 49e6ba7 [Prashant Sharma] SPARK-1162 added takeOrdered to pyspark --- python/pyspark/rdd.py | 107 +++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 102 insertions(+), 5 deletions(-) (limited to 'python/pyspark/rdd.py') diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 019c249699..9943296b92 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -29,7 +29,7 @@ from subprocess import Popen, PIPE from tempfile import NamedTemporaryFile from threading import Thread import warnings -from heapq import heappush, heappop, heappushpop +import heapq from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long @@ -41,9 +41,9 @@ from pyspark.storagelevel import StorageLevel from py4j.java_collections import ListConverter, MapConverter - __all__ = ["RDD"] + def _extract_concise_traceback(): """ This function returns the traceback info for a callsite, returns a dict @@ -91,6 +91,73 @@ class _JavaStackTrace(object): if _spark_stack_depth == 0: self._context._jsc.setCallSite(None) +class MaxHeapQ(object): + """ + An implementation of MaxHeap. + >>> import pyspark.rdd + >>> heap = pyspark.rdd.MaxHeapQ(5) + >>> [heap.insert(i) for i in range(10)] + [None, None, None, None, None, None, None, None, None, None] + >>> sorted(heap.getElements()) + [0, 1, 2, 3, 4] + >>> heap = pyspark.rdd.MaxHeapQ(5) + >>> [heap.insert(i) for i in range(9, -1, -1)] + [None, None, None, None, None, None, None, None, None, None] + >>> sorted(heap.getElements()) + [0, 1, 2, 3, 4] + >>> heap = pyspark.rdd.MaxHeapQ(1) + >>> [heap.insert(i) for i in range(9, -1, -1)] + [None, None, None, None, None, None, None, None, None, None] + >>> heap.getElements() + [0] + """ + + def __init__(self, maxsize): + # we start from q[1], this makes calculating children as trivial as 2 * k + self.q = [0] + self.maxsize = maxsize + + def _swim(self, k): + while (k > 1) and (self.q[k/2] < self.q[k]): + self._swap(k, k/2) + k = k/2 + + def _swap(self, i, j): + t = self.q[i] + self.q[i] = self.q[j] + self.q[j] = t + + def _sink(self, k): + N = self.size() + while 2 * k <= N: + j = 2 * k + # Here we test if both children are greater than parent + # if not swap with larger one. + if j < N and self.q[j] < self.q[j + 1]: + j = j + 1 + if(self.q[k] > self.q[j]): + break + self._swap(k, j) + k = j + + def size(self): + return len(self.q) - 1 + + def insert(self, value): + if (self.size()) < self.maxsize: + self.q.append(value) + self._swim(self.size()) + else: + self._replaceRoot(value) + + def getElements(self): + return self.q[1:] + + def _replaceRoot(self, value): + if(self.q[1] > value): + self.q[1] = value + self._sink(1) + class RDD(object): """ A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. @@ -696,16 +763,16 @@ class RDD(object): Note: It returns the list sorted in descending order. >>> sc.parallelize([10, 4, 2, 12, 3]).top(1) [12] - >>> sc.parallelize([2, 3, 4, 5, 6]).cache().top(2) + >>> sc.parallelize([2, 3, 4, 5, 6], 2).cache().top(2) [6, 5] """ def topIterator(iterator): q = [] for k in iterator: if len(q) < num: - heappush(q, k) + heapq.heappush(q, k) else: - heappushpop(q, k) + heapq.heappushpop(q, k) yield q def merge(a, b): @@ -713,6 +780,36 @@ class RDD(object): return sorted(self.mapPartitions(topIterator).reduce(merge), reverse=True) + def takeOrdered(self, num, key=None): + """ + Get the N elements from a RDD ordered in ascending order or as specified + by the optional key function. + + >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7]).takeOrdered(6) + [1, 2, 3, 4, 5, 6] + >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7], 2).takeOrdered(6, key=lambda x: -x) + [10, 9, 7, 6, 5, 4] + """ + + def topNKeyedElems(iterator, key_=None): + q = MaxHeapQ(num) + for k in iterator: + if key_ != None: + k = (key_(k), k) + q.insert(k) + yield q.getElements() + + def unKey(x, key_=None): + if key_ != None: + x = [i[1] for i in x] + return x + + def merge(a, b): + return next(topNKeyedElems(a + b)) + result = self.mapPartitions(lambda i: topNKeyedElems(i, key)).reduce(merge) + return sorted(unKey(result, key), key=key) + + def take(self, num): """ Take the first num elements of the RDD. -- cgit v1.2.3