author    Josh Rosen <joshrosen@eecs.berkeley.edu>    2013-01-03 14:52:21 -0800
committer Josh Rosen <joshrosen@eecs.berkeley.edu>    2013-01-03 14:52:21 -0800
commit    33beba39656fc64984db09a82fc69ca4edcc02d4 (patch)
tree      f1e84417d3d9d37c1ccf4d01d20bd4930f2ecf5d /python/pyspark/rdd.py
parent    ce9f1bbe20eff794cd1d588dc88f109d32588cfe (diff)
Change PySpark RDD.take() to not call iterator().
Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r--  python/pyspark/rdd.py  11 +++++------
1 file changed, 5 insertions(+), 6 deletions(-)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index cbffb6cc1f..4ba417b2a2 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -328,18 +328,17 @@ class RDD(object):
         a lot of partitions are required. In that case, use L{collect} to get
         the whole RDD instead.
 
-        >>> sc.parallelize([2, 3, 4, 5, 6]).take(2)
+        >>> sc.parallelize([2, 3, 4, 5, 6]).cache().take(2)
         [2, 3]
         >>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
         [2, 3, 4, 5, 6]
         """
         items = []
-        splits = self._jrdd.splits()
-        taskContext = self.ctx.jvm.spark.TaskContext(0, 0, 0)
-        while len(items) < num and splits:
-            split = splits.pop(0)
-            iterator = self._jrdd.iterator(split, taskContext)
+        for partition in range(self._jrdd.splits().size()):
+            iterator = self.ctx._takePartition(self._jrdd.rdd(), partition)
             items.extend(self._collect_iterator_through_file(iterator))
+            if len(items) >= num:
+                break
         return items[:num]
 
     def first(self):
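
For reference, the new implementation is an early-exit scan over the RDD's partitions: fetch one partition's elements at a time and stop as soon as enough items have been collected. Below is a minimal pure-Python sketch of that pattern; the take_sketch name and the plain list-of-lists input are illustrative stand-ins for the JVM-backed partition iterators that ctx._takePartition returns, not PySpark API.

    def take_sketch(partitions, num):
        # Scan partitions in order, stopping early once `num` items
        # have been collected (mirrors the loop in take() above).
        items = []
        for partition in partitions:
            items.extend(partition)
            if len(items) >= num:
                break
        return items[:num]

    print(take_sketch([[2, 3], [4, 5], [6]], 2))   # [2, 3]
    print(take_sketch([[2, 3], [4, 5], [6]], 10))  # [2, 3, 4, 5, 6]

Compared with the old while-loop, this shape avoids constructing a TaskContext and calling iterator() on the Java RDD from Python, and it does no work for partitions past the early-exit point.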