diff options
author | Josh Rosen <joshrosen@eecs.berkeley.edu> | 2013-01-03 14:52:21 -0800 |
---|---|---|
committer | Josh Rosen <joshrosen@eecs.berkeley.edu> | 2013-01-03 14:52:21 -0800 |
commit | 33beba39656fc64984db09a82fc69ca4edcc02d4 (patch) | |
tree | f1e84417d3d9d37c1ccf4d01d20bd4930f2ecf5d /python/pyspark | |
parent | ce9f1bbe20eff794cd1d588dc88f109d32588cfe (diff) | |
download | spark-33beba39656fc64984db09a82fc69ca4edcc02d4.tar.gz spark-33beba39656fc64984db09a82fc69ca4edcc02d4.tar.bz2 spark-33beba39656fc64984db09a82fc69ca4edcc02d4.zip |
Change PySpark RDD.take() to not call iterator().
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/context.py | 1 | ||||
-rw-r--r-- | python/pyspark/rdd.py | 11 |
2 files changed, 6 insertions, 6 deletions
diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 6172d69dcf..4439356c1f 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -21,6 +21,7 @@ class SparkContext(object): jvm = gateway.jvm _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile + _takePartition = jvm.PythonRDD.takePartition def __init__(self, master, jobName, sparkHome=None, pyFiles=None, environment=None, batchSize=1024): diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index cbffb6cc1f..4ba417b2a2 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -328,18 +328,17 @@ class RDD(object): a lot of partitions are required. In that case, use L{collect} to get the whole RDD instead. - >>> sc.parallelize([2, 3, 4, 5, 6]).take(2) + >>> sc.parallelize([2, 3, 4, 5, 6]).cache().take(2) [2, 3] >>> sc.parallelize([2, 3, 4, 5, 6]).take(10) [2, 3, 4, 5, 6] """ items = [] - splits = self._jrdd.splits() - taskContext = self.ctx.jvm.spark.TaskContext(0, 0, 0) - while len(items) < num and splits: - split = splits.pop(0) - iterator = self._jrdd.iterator(split, taskContext) + for partition in range(self._jrdd.splits().size()): + iterator = self.ctx._takePartition(self._jrdd.rdd(), partition) items.extend(self._collect_iterator_through_file(iterator)) + if len(items) >= num: + break return items[:num] def first(self): |