aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/rdd.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r--python/pyspark/rdd.py7
1 files changed, 6 insertions, 1 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index d81b7c90c1..7015119551 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -576,8 +576,13 @@ class RDD(object):
# Take only up to num elements from each partition we try
mapped = self.mapPartitions(takeUpToNum)
items = []
+ # TODO(shivaram): Similar to the scala implementation, update the take
+ # method to scan multiple splits based on an estimate of how many elements
+ # we have per-split.
for partition in range(mapped._jrdd.splits().size()):
- iterator = mapped._jrdd.collectPartition(partition).iterator()
+ partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
+ partitionsToTake[0] = partition
+ iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
items.extend(mapped._collect_iterator_through_file(iterator))
if len(items) >= num:
break