From 9909efc10aaa62c47fd7c4c9da73ac8c56a454d5 Mon Sep 17 00:00:00 2001
From: Aaron Davidson
Date: Sat, 31 May 2014 13:04:57 -0700
Subject: SPARK-1839: PySpark RDD#take() shouldn't always read from driver

This patch simply ports over the Scala implementation of RDD#take(), which
reads the first partition at the driver, then decides how many more partitions
it needs to read and will possibly start a real job if it's more than 1. (Note
that SparkContext#runJob(allowLocal=true) only runs the job locally if there's
1 partition selected and no parent stages.)

Author: Aaron Davidson

Closes #922 from aarondav/take and squashes the following commits:

fa06df9 [Aaron Davidson] SPARK-1839: PySpark RDD#take() shouldn't always read from driver
---
 python/pyspark/context.py | 26 +++++++++++++++++++++
 python/pyspark/rdd.py     | 59 ++++++++++++++++++++++++++++++-----------------
 2 files changed, 64 insertions(+), 21 deletions(-)

(limited to 'python')

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 56746cb7aa..9ae9305d4f 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -537,6 +537,32 @@ class SparkContext(object):
         """
         self._jsc.sc().cancelAllJobs()
 
+    def runJob(self, rdd, partitionFunc, partitions = None, allowLocal = False):
+        """
+        Executes the given partitionFunc on the specified set of partitions,
+        returning the result as an array of elements.
+
+        If 'partitions' is not specified, this will run over all partitions.
+
+        >>> myRDD = sc.parallelize(range(6), 3)
+        >>> sc.runJob(myRDD, lambda part: [x * x for x in part])
+        [0, 1, 4, 9, 16, 25]
+
+        >>> myRDD = sc.parallelize(range(6), 3)
+        >>> sc.runJob(myRDD, lambda part: [x * x for x in part], [0, 2], True)
+        [0, 1, 16, 25]
+        """
+        if partitions == None:
+            partitions = range(rdd._jrdd.splits().size())
+        javaPartitions = ListConverter().convert(partitions, self._gateway._gateway_client)
+
+        # Implementation note: This is implemented as a mapPartitions followed
+        # by runJob() in order to avoid having to pass a Python lambda into
+        # SparkContext#runJob.
+        mappedRDD = rdd.mapPartitions(partitionFunc)
+        it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
+        return list(mappedRDD._collect_iterator_through_file(it))
+
 def _test():
     import atexit
     import doctest
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 07578b8d93..f3b1f1a665 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -841,34 +841,51 @@ class RDD(object):
         """
         Take the first num elements of the RDD.
 
-        This currently scans the partitions *one by one*, so it will be slow if
-        a lot of partitions are required. In that case, use L{collect} to get
-        the whole RDD instead.
+        It works by first scanning one partition, and use the results from
+        that partition to estimate the number of additional partitions needed
+        to satisfy the limit.
+
+        Translated from the Scala implementation in RDD#take().
 
         >>> sc.parallelize([2, 3, 4, 5, 6]).cache().take(2)
         [2, 3]
         >>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
         [2, 3, 4, 5, 6]
+        >>> sc.parallelize(range(100), 100).filter(lambda x: x > 90).take(3)
+        [91, 92, 93]
         """
-        def takeUpToNum(iterator):
-            taken = 0
-            while taken < num:
-                yield next(iterator)
-                taken += 1
-        # Take only up to num elements from each partition we try
-        mapped = self.mapPartitions(takeUpToNum)
         items = []
-        # TODO(shivaram): Similar to the scala implementation, update the take
-        # method to scan multiple splits based on an estimate of how many elements
-        # we have per-split.
-        with _JavaStackTrace(self.context) as st:
-            for partition in range(mapped._jrdd.splits().size()):
-                partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
-                partitionsToTake[0] = partition
-                iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
-                items.extend(mapped._collect_iterator_through_file(iterator))
-                if len(items) >= num:
-                    break
+        totalParts = self._jrdd.splits().size()
+        partsScanned = 0
+
+        while len(items) < num and partsScanned < totalParts:
+            # The number of partitions to try in this iteration.
+            # It is ok for this number to be greater than totalParts because
+            # we actually cap it at totalParts in runJob.
+            numPartsToTry = 1
+            if partsScanned > 0:
+                # If we didn't find any rows after the first iteration, just
+                # try all partitions next. Otherwise, interpolate the number
+                # of partitions we need to try, but overestimate it by 50%.
+                if len(items) == 0:
+                    numPartsToTry = totalParts - 1
+                else:
+                    numPartsToTry = int(1.5 * num * partsScanned / len(items))
+
+            left = num - len(items)
+
+            def takeUpToNumLeft(iterator):
+                taken = 0
+                while taken < left:
+                    yield next(iterator)
+                    taken += 1
+
+            p = range(partsScanned, min(partsScanned + numPartsToTry, totalParts))
+            res = self.context.runJob(self, takeUpToNumLeft, p, True)
+
+            items += res
+            partsScanned += numPartsToTry
+
         return items[:num]
 
     def first(self):
--
cgit v1.2.3
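
For illustration only, the following is a minimal, self-contained Python sketch of the partition-scanning strategy the commit message describes: read one partition first, then, if the limit is not yet satisfied, estimate how many more partitions to read from the yield so far and overestimate that count by 50%. It is not part of the patch and does not use the PySpark API; run_job and simulated_take are hypothetical stand-ins, with plain Python lists playing the role of partitions.

from itertools import islice

def run_job(partitions, partition_func, partition_ids):
    # Hypothetical stand-in for SparkContext.runJob(): apply partition_func to
    # each selected partition and concatenate the results.
    results = []
    for i in partition_ids:
        results.extend(partition_func(iter(partitions[i])))
    return results

def simulated_take(partitions, num):
    # Adaptive scan: try one partition, then estimate how many more are needed
    # based on how many rows the scanned partitions produced, padded by 50%.
    items = []
    total_parts = len(partitions)
    parts_scanned = 0
    while len(items) < num and parts_scanned < total_parts:
        num_parts_to_try = 1
        if parts_scanned > 0:
            if len(items) == 0:
                # Nothing found yet; try all remaining partitions next.
                num_parts_to_try = total_parts - 1
            else:
                # Interpolate from the observed density, overestimating by 50%.
                num_parts_to_try = int(1.5 * num * parts_scanned / len(items))
        left = num - len(items)

        def take_up_to_num_left(iterator):
            # islice stops cleanly if the partition has fewer than `left` rows.
            return islice(iterator, left)

        p = range(parts_scanned, min(parts_scanned + num_parts_to_try, total_parts))
        items += run_job(partitions, take_up_to_num_left, p)
        parts_scanned += num_parts_to_try
    return items[:num]

if __name__ == "__main__":
    # Three partitions of two elements each, like sc.parallelize(range(6), 3).
    parts = [[0, 1], [2, 3], [4, 5]]
    print(simulated_take(parts, 2))  # [0, 1]
    print(simulated_take(parts, 5))  # [0, 1, 2, 3, 4]

In this toy run, simulated_take(parts, 2) is answered from the first partition alone, while simulated_take(parts, 5) triggers one additional, wider pass over the remaining partitions, which mirrors the behaviour the patch gives RDD#take().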