author     Shivaram Venkataraman <shivaram@eecs.berkeley.edu>  2013-12-19 11:40:34 -0800
committer  Shivaram Venkataraman <shivaram@eecs.berkeley.edu>  2013-12-19 11:40:34 -0800
commit  d3234f9726db3917af4688ba70933938b078b0bd (patch)
tree    b343f29b81bbaf8d1165f76e8e5748876b3fb008 /python
parent  af0cd6bd27dda73b326bcb6a66addceadebf5e54 (diff)
Make collectPartitions take an array of partitions
Change the implementation to use runJob instead of PartitionPruningRDD. Also update the unit tests and the Python take implementation to use the new interface.
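For context, a minimal sketch of what the new interface looks like from the Python side; sc and rdd stand for an assumed live SparkContext and RDD, and the Py4J calls mirror those in the diff below:

# Before: one JVM call per partition, fetching a single partition at a time.
iterator = rdd._jrdd.collectPartition(0).iterator()

# After: collectPartitions takes a Java int[] of partition indices and
# returns one collected partition per requested index.
parts = sc._gateway.new_array(sc._jvm.int, 2)
parts[0], parts[1] = 0, 1
iterators = [p.iterator() for p in rdd._jrdd.collectPartitions(parts)]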
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/rdd.py | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index d81b7c90c1..7015119551 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -576,8 +576,13 @@ class RDD(object):
# Take only up to num elements from each partition we try
mapped = self.mapPartitions(takeUpToNum)
items = []
+ # TODO(shivaram): Similar to the scala implementation, update the take
+ # method to scan multiple splits based on an estimate of how many elements
+ # we have per-split.
for partition in range(mapped._jrdd.splits().size()):
- iterator = mapped._jrdd.collectPartition(partition).iterator()
+ partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
+ partitionsToTake[0] = partition
+ iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
items.extend(mapped._collect_iterator_through_file(iterator))
if len(items) >= num:
break
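The new_array call in the hunk above is needed because Py4J does not automatically convert a Python list into a Java primitive array; the int[] has to be constructed through the gateway. A minimal sketch of that pattern in isolation, assuming an active SparkContext sc:

# Build a Java int[] of length 3 via the Py4J gateway and fill it from Python.
int_array = sc._gateway.new_array(sc._jvm.int, 3)
for i in range(3):
    int_array[i] = i
# int_array can now be passed to any JVM method expecting an int[],
# such as the collectPartitions call above.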