author     Aaron Davidson <aaron@databricks.com>    2014-05-31 13:04:57 -0700
committer  Reynold Xin <rxin@apache.org>            2014-05-31 13:04:57 -0700
commit     9909efc10aaa62c47fd7c4c9da73ac8c56a454d5 (patch)
tree       86ab8e6477ab4a631b3a91f0e89007ca69c78d37 /python/pyspark/rdd.py
parent     7d52777effd0ff41aed545f53d2ab8de2364a188 (diff)
SPARK-1839: PySpark RDD#take() shouldn't always read from driver
This patch simply ports over the Scala implementation of RDD#take(), which reads the first partition at the driver, then decides how many more partitions it needs to read and will possibly start a real job if it's more than 1. (Note that SparkContext#runJob(allowLocal=true) only runs the job locally if there's 1 partition selected and no parent stages.)

Author: Aaron Davidson <aaron@databricks.com>

Closes #922 from aarondav/take and squashes the following commits:

fa06df9 [Aaron Davidson] SPARK-1839: PySpark RDD#take() shouldn't always read from driver
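To make the scan-and-estimate strategy described above concrete, here is a minimal standalone sketch of the same loop written against plain Python lists standing in for RDD partitions. The function simulate_take and its list-of-lists input are illustrative only and are not part of the PySpark API.

    # Illustrative stand-in for RDD#take(): partitions are plain lists, and
    # "running a job" is just slicing the list of partitions. Only the
    # estimation logic mirrors the patch below.
    def simulate_take(partitions, num):
        items = []
        total_parts = len(partitions)
        parts_scanned = 0

        while len(items) < num and parts_scanned < total_parts:
            # Start with one partition; afterwards either try everything
            # (if nothing was found) or interpolate, overestimating by 50%.
            num_parts_to_try = 1
            if parts_scanned > 0:
                if len(items) == 0:
                    num_parts_to_try = total_parts - 1
                else:
                    num_parts_to_try = int(1.5 * num * parts_scanned / len(items))

            left = num - len(items)
            for part in partitions[parts_scanned:parts_scanned + num_parts_to_try]:
                items.extend(part[:left])          # take only what is still needed
                left = num - len(items)
            parts_scanned += num_parts_to_try

        return items[:num]

    print(simulate_take([[2, 3], [4, 5], [6]], 4))  # [2, 3, 4, 5]

The actual patch below does the same bookkeeping, but ships takeUpToNumLeft to the selected partitions through SparkContext.runJob instead of slicing local lists.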
Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r--  python/pyspark/rdd.py  59
1 file changed, 38 insertions, 21 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 07578b8d93..f3b1f1a665 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -841,34 +841,51 @@ class RDD(object):
         """
         Take the first num elements of the RDD.
 
-        This currently scans the partitions *one by one*, so it will be slow if
-        a lot of partitions are required. In that case, use L{collect} to get
-        the whole RDD instead.
+        It works by first scanning one partition, and use the results from
+        that partition to estimate the number of additional partitions needed
+        to satisfy the limit.
+
+        Translated from the Scala implementation in RDD#take().
 
         >>> sc.parallelize([2, 3, 4, 5, 6]).cache().take(2)
         [2, 3]
         >>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
         [2, 3, 4, 5, 6]
+        >>> sc.parallelize(range(100), 100).filter(lambda x: x > 90).take(3)
+        [91, 92, 93]
         """
-        def takeUpToNum(iterator):
-            taken = 0
-            while taken < num:
-                yield next(iterator)
-                taken += 1
-        # Take only up to num elements from each partition we try
-        mapped = self.mapPartitions(takeUpToNum)
         items = []
-        # TODO(shivaram): Similar to the scala implementation, update the take
-        # method to scan multiple splits based on an estimate of how many elements
-        # we have per-split.
-        with _JavaStackTrace(self.context) as st:
-            for partition in range(mapped._jrdd.splits().size()):
-                partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
-                partitionsToTake[0] = partition
-                iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
-                items.extend(mapped._collect_iterator_through_file(iterator))
-                if len(items) >= num:
-                    break
+        totalParts = self._jrdd.splits().size()
+        partsScanned = 0
+
+        while len(items) < num and partsScanned < totalParts:
+            # The number of partitions to try in this iteration.
+            # It is ok for this number to be greater than totalParts because
+            # we actually cap it at totalParts in runJob.
+            numPartsToTry = 1
+            if partsScanned > 0:
+                # If we didn't find any rows after the first iteration, just
+                # try all partitions next. Otherwise, interpolate the number
+                # of partitions we need to try, but overestimate it by 50%.
+                if len(items) == 0:
+                    numPartsToTry = totalParts - 1
+                else:
+                    numPartsToTry = int(1.5 * num * partsScanned / len(items))
+
+            left = num - len(items)
+
+            def takeUpToNumLeft(iterator):
+                taken = 0
+                while taken < left:
+                    yield next(iterator)
+                    taken += 1
+
+            p = range(partsScanned, min(partsScanned + numPartsToTry, totalParts))
+            res = self.context.runJob(self, takeUpToNumLeft, p, True)
+
+            items += res
+            partsScanned += numPartsToTry
+
         return items[:num]
 
     def first(self):
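As a rough worked example of the interpolation in the new loop (the figures are hypothetical, not taken from the commit): if take(100) finds only 4 rows after scanning the first partition, the next pass tries int(1.5 * 100 * 1 / 4) = 37 partitions; had it found nothing, it would instead fall back to trying all remaining partitions at once.

    # Hypothetical figures: 4 of the 100 requested rows found after one partition.
    num, partsScanned, found = 100, 1, 4
    print(int(1.5 * num * partsScanned / found))  # 37 partitions on the next pass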