diff options
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/rdd.py | 7 |
1 files changed, 6 insertions, 1 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d81b7c90c1..7015119551 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -576,8 +576,13 @@ class RDD(object): # Take only up to num elements from each partition we try mapped = self.mapPartitions(takeUpToNum) items = [] + # TODO(shivaram): Similar to the scala implementation, update the take + # method to scan multiple splits based on an estimate of how many elements + # we have per-split. for partition in range(mapped._jrdd.splits().size()): - iterator = mapped._jrdd.collectPartition(partition).iterator() + partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1) + partitionsToTake[0] = partition + iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator() items.extend(mapped._collect_iterator_through_file(iterator)) if len(items) >= num: break |