diff options
author | Matei Zaharia <matei@eecs.berkeley.edu> | 2013-07-28 23:28:42 -0400 |
---|---|---|
committer | Matei Zaharia <matei@eecs.berkeley.edu> | 2013-07-29 02:51:43 -0400 |
commit | b9d6783f36d527f5082bf13a4ee6fd108e97795c (patch) | |
tree | 922a260da5591d675b58246ed1cbd23c38abc4e7 /python | |
parent | 72ff62a37c7310bab02f0231e91d3ba4d423217a (diff) | |
download | spark-b9d6783f36d527f5082bf13a4ee6fd108e97795c.tar.gz spark-b9d6783f36d527f5082bf13a4ee6fd108e97795c.tar.bz2 spark-b9d6783f36d527f5082bf13a4ee6fd108e97795c.zip |
Optimize Python take() to not compute entire first partition
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/rdd.py | 15 |
1 files changed, 9 insertions, 6 deletions
```diff
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index c6a6b24c5a..6efa61aa66 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -386,13 +386,16 @@ class RDD(object):
         >>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
         [2, 3, 4, 5, 6]
         """
+        def takeUpToNum(iterator):
+            taken = 0
+            while taken < num:
+                yield next(iterator)
+                taken += 1
+        # Take only up to num elements from each partition we try
+        mapped = self.mapPartitions(takeUpToNum)
         items = []
-        for partition in range(self._jrdd.splits().size()):
-            iterator = self.ctx._takePartition(self._jrdd.rdd(), partition)
-            # Each item in the iterator is a string, Python object, batch of
-            # Python objects. Regardless, it is sufficient to take `num`
-            # of these objects in order to collect `num` Python objects:
-            iterator = iterator.take(num)
+        for partition in range(mapped._jrdd.splits().size()):
+            iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
             items.extend(self._collect_iterator_through_file(iterator))
             if len(items) >= num:
                 break
```