about summary refs log tree commit diff
path: root/python
diff options
context:
space:
mode:
authorMatei Zaharia <matei@eecs.berkeley.edu>2013-07-28 23:28:42 -0400
committerMatei Zaharia <matei@eecs.berkeley.edu>2013-07-29 02:51:43 -0400
commitb9d6783f36d527f5082bf13a4ee6fd108e97795c (patch)
tree922a260da5591d675b58246ed1cbd23c38abc4e7 /python
parent72ff62a37c7310bab02f0231e91d3ba4d423217a (diff)
downloadspark-b9d6783f36d527f5082bf13a4ee6fd108e97795c.tar.gz
spark-b9d6783f36d527f5082bf13a4ee6fd108e97795c.tar.bz2
spark-b9d6783f36d527f5082bf13a4ee6fd108e97795c.zip
Optimize Python take() to not compute entire first partition
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/rdd.py15
1 file changed, 9 insertions, 6 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index c6a6b24c5a..6efa61aa66 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -386,13 +386,16 @@ class RDD(object):
>>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
[2, 3, 4, 5, 6]
"""
+ def takeUpToNum(iterator):
+ taken = 0
+ while taken < num:
+ yield next(iterator)
+ taken += 1
+ # Take only up to num elements from each partition we try
+ mapped = self.mapPartitions(takeUpToNum)
items = []
- for partition in range(self._jrdd.splits().size()):
- iterator = self.ctx._takePartition(self._jrdd.rdd(), partition)
- # Each item in the iterator is a string, Python object, batch of
- # Python objects. Regardless, it is sufficient to take `num`
- # of these objects in order to collect `num` Python objects:
- iterator = iterator.take(num)
+ for partition in range(mapped._jrdd.splits().size()):
+ iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
items.extend(self._collect_iterator_through_file(iterator))
if len(items) >= num:
break