From 33beba39656fc64984db09a82fc69ca4edcc02d4 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Thu, 3 Jan 2013 14:52:21 -0800
Subject: Change PySpark RDD.take() to not call iterator().

---
 python/pyspark/rdd.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index cbffb6cc1f..4ba417b2a2 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -328,18 +328,17 @@ class RDD(object):
         a lot of partitions are required. In that case, use L{collect} to get
         the whole RDD instead.
 
-        >>> sc.parallelize([2, 3, 4, 5, 6]).take(2)
+        >>> sc.parallelize([2, 3, 4, 5, 6]).cache().take(2)
         [2, 3]
         >>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
         [2, 3, 4, 5, 6]
         """
         items = []
-        splits = self._jrdd.splits()
-        taskContext = self.ctx.jvm.spark.TaskContext(0, 0, 0)
-        while len(items) < num and splits:
-            split = splits.pop(0)
-            iterator = self._jrdd.iterator(split, taskContext)
+        for partition in range(self._jrdd.splits().size()):
+            iterator = self.ctx._takePartition(self._jrdd.rdd(), partition)
             items.extend(self._collect_iterator_through_file(iterator))
+            if len(items) >= num:
+                break
         return items[:num]
 
     def first(self):
--
cgit v1.2.3
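
Context for the change above: the old implementation built a dummy TaskContext in the driver and called iterator() on each split directly, sidestepping Spark's normal job scheduling; the new code instead asks a JVM-side helper (_takePartition) to run a real job over one partition at a time, stopping once enough elements have been collected. The .cache() added to the first doctest presumably exercises this on a cached RDD, a case the direct iterator() call could not handle. Below is a minimal, self-contained sketch of the new control flow in plain Python with no Spark dependency; `partitions` stands in for the RDD's partition list and `compute` for the per-partition job that _takePartition runs, so both names are illustrative only.

def take(partitions, num, compute):
    # Collect the first `num` elements, fetching one partition at a time.
    items = []
    for p in partitions:
        items.extend(compute(p))   # one job per partition, in order
        if len(items) >= num:
            break                  # stop scheduling work once we have enough
    return items[:num]

# Three partitions, take four elements; only the first two partitions
# are ever computed.
print(take([[2, 3], [4, 5], [6]], 4, lambda p: p))  # prints [2, 3, 4, 5]

One consequence of this design, visible in the sketch, is that partitions are fetched serially until the quota is met, so take(n) for small n touches only as many partitions as it needs rather than computing the whole RDD.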