diff options
author | Aaron Davidson <aaron@databricks.com> | 2014-05-31 13:04:57 -0700 |
---|---|---|
committer | Reynold Xin <rxin@apache.org> | 2014-05-31 13:04:57 -0700 |
commit | 9909efc10aaa62c47fd7c4c9da73ac8c56a454d5 (patch) | |
tree | 86ab8e6477ab4a631b3a91f0e89007ca69c78d37 /python/pyspark/rdd.py | |
parent | 7d52777effd0ff41aed545f53d2ab8de2364a188 (diff) | |
download | spark-9909efc10aaa62c47fd7c4c9da73ac8c56a454d5.tar.gz spark-9909efc10aaa62c47fd7c4c9da73ac8c56a454d5.tar.bz2 spark-9909efc10aaa62c47fd7c4c9da73ac8c56a454d5.zip |
SPARK-1839: PySpark RDD#take() shouldn't always read from driver
This patch simply ports over the Scala implementation of RDD#take(), which reads the first partition at the driver, then decides how many more partitions it needs to read and will possibly start a real job if it's more than 1. (Note that SparkContext#runJob(allowLocal=true) only runs the job locally if there's 1 partition selected and no parent stages.)
Author: Aaron Davidson <aaron@databricks.com>
Closes #922 from aarondav/take and squashes the following commits:
fa06df9 [Aaron Davidson] SPARK-1839: PySpark RDD#take() shouldn't always read from driver
Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r-- | python/pyspark/rdd.py | 59 |
1 file changed, 38 insertions, 21 deletions
def take(self, num):
    """
    Take the first num elements of the RDD.

    It works by first scanning one partition, and using the results from
    that partition to estimate the number of additional partitions needed
    to satisfy the limit.

    Translated from the Scala implementation in RDD#take().

    >>> sc.parallelize([2, 3, 4, 5, 6]).cache().take(2)
    [2, 3]
    >>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
    [2, 3, 4, 5, 6]
    >>> sc.parallelize(range(100), 100).filter(lambda x: x > 90).take(3)
    [91, 92, 93]
    """
    items = []
    totalParts = self._jrdd.splits().size()
    partsScanned = 0

    while len(items) < num and partsScanned < totalParts:
        # The number of partitions to try in this iteration.
        # It is ok for this number to be greater than totalParts because
        # we actually cap it at totalParts in runJob.
        numPartsToTry = 1
        if partsScanned > 0:
            # If we didn't find any rows after the first iteration, just
            # try all partitions next. Otherwise, interpolate the number
            # of partitions we need to try, but overestimate it by 50%.
            if len(items) == 0:
                numPartsToTry = totalParts - 1
            else:
                numPartsToTry = int(1.5 * num * partsScanned / len(items))

        left = num - len(items)

        def takeUpToNumLeft(iterator):
            # Yield at most `left` elements from this partition.
            # NOTE: catch StopIteration explicitly instead of letting it
            # escape the generator — under PEP 479 (Python 3.7+) a leaked
            # StopIteration inside a generator is re-raised as
            # RuntimeError, which would crash take() on any partition
            # holding fewer than `left` elements.
            taken = 0
            while taken < left:
                try:
                    yield next(iterator)
                except StopIteration:
                    return
                taken += 1

        p = range(partsScanned, min(partsScanned + numPartsToTry, totalParts))
        # allowLocal=True: runJob may evaluate locally when only one
        # partition with no parent stages is selected.
        res = self.context.runJob(self, takeUpToNumLeft, p, True)

        items += res
        partsScanned += numPartsToTry

    return items[:num]