author    Josh Rosen <joshrosen@eecs.berkeley.edu>  2012-12-28 22:19:12 -0800
committer Josh Rosen <joshrosen@eecs.berkeley.edu>  2012-12-28 22:21:16 -0800
commit    7ec3595de28d53839cb3a45e940ec16f81ffdf45 (patch)
tree      2933cb5d71d76fdcea27125168f346ad38d4fca2 /pyspark/pyspark/rdd.py
parent    fbadb1cda504b256e3d12c4ce389e723b6f2503c (diff)
Fix bug (introduced by batching) in PySpark take()
Diffstat (limited to 'pyspark/pyspark/rdd.py')
-rw-r--r--  pyspark/pyspark/rdd.py | 27
1 file changed, 17 insertions(+), 10 deletions(-)
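
For context on the bug being fixed: with batched serialization, each element that crosses the Py4J boundary is a pickled batch of Python objects rather than a single object, so the old take(), which called take(num) on the underlying Java RDD, counted batches instead of elements and could hand back more values than were asked for. A minimal pure-Python sketch of the failure mode, assuming a batch size of 2 (the names batches and unbatch are illustrative, not from the Spark codebase):

# Hypothetical illustration of the batching bug; not Spark code.
batches = [[2, 3], [4, 5], [6]]  # five Python values, three pickled batches

def unbatch(taken):
    # Flatten whatever the JVM-side take() handed back.
    return [x for batch in taken for x in batch]

# Old behaviour: the JVM-side take(2) returns the first two *batches*,
# so the caller sees four values instead of two.
print(unbatch(batches[:2]))  # [2, 3, 4, 5]
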
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index bf32472d25..111476d274 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -152,8 +152,8 @@ class RDD(object):
into a list.
>>> rdd = sc.parallelize([1, 2, 3, 4], 2)
- >>> rdd.glom().first()
- [1, 2]
+ >>> sorted(rdd.glom().collect())
+ [[1, 2], [3, 4]]
"""
def func(iterator): yield list(iterator)
return self.mapPartitions(func)
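
The doctest above changes because first() surfaces only a single partition, which makes the expected output depend on partitioning; sorting the full collect() is deterministic. As the two unchanged lines show, glom is just mapPartitions with a function that materializes each partition into a list; a plain-Python sketch of the same semantics (no SparkContext needed, names illustrative):

# Stand-in for the two partitions of sc.parallelize([1, 2, 3, 4], 2).
partitions = [iter([1, 2]), iter([3, 4])]

def glom(parts):
    # One list per partition, mirroring RDD.glom().
    for iterator in parts:
        yield list(iterator)

print(sorted(glom(partitions)))  # [[1, 2], [3, 4]]
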
@@ -211,10 +211,10 @@ class RDD(object):
"""
Return a list that contains all of the elements in this RDD.
"""
- picklesInJava = self._jrdd.rdd().collect()
- return list(self._collect_array_through_file(picklesInJava))
+ picklesInJava = self._jrdd.collect().iterator()
+ return list(self._collect_iterator_through_file(picklesInJava))
- def _collect_array_through_file(self, array):
+ def _collect_iterator_through_file(self, iterator):
# Transferring lots of data through Py4J can be slow because
# socket.readline() is inefficient. Instead, we'll dump the data to a
# file and read it back.
@@ -224,7 +224,7 @@ class RDD(object):
try: os.unlink(tempFile.name)
except: pass
atexit.register(clean_up_file)
- self.ctx.writeArrayToPickleFile(array, tempFile.name)
+ self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
# Read the data into Python and deserialize it:
with open(tempFile.name, 'rb') as tempFile:
for item in read_from_pickle_file(tempFile):
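
The comment in this hunk explains the design choice: streaming many small values over the Py4J socket is slow because each read goes through socket.readline(), so the pickled data is dumped to a temporary file on the JVM side and read back in bulk from Python. A hedged sketch of the reading half, with a simple stand-in for the read_from_pickle_file helper used above (the JVM writer is simulated here; names are illustrative):

import atexit
import os
import pickle
import tempfile

def read_pickles(f):
    # Yield back-to-back pickled objects until EOF; a stand-in for
    # the read_from_pickle_file helper used in the real code.
    while True:
        try:
            yield pickle.load(f)
        except EOFError:
            return

# Simulate the JVM side writing pickled batches into a temp file.
tempFile = tempfile.NamedTemporaryFile(delete=False)
atexit.register(lambda: os.unlink(tempFile.name))
for batch in ([1, 2], [3, 4]):
    pickle.dump(batch, tempFile)
tempFile.close()

# Read the data back into Python and deserialize it.
with open(tempFile.name, 'rb') as f:
    print(list(read_pickles(f)))  # [[1, 2], [3, 4]]
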
@@ -325,11 +325,18 @@ class RDD(object):
a lot of partitions are required. In that case, use L{collect} to get
the whole RDD instead.
- >>> sc.parallelize([2, 3, 4]).take(2)
+ >>> sc.parallelize([2, 3, 4, 5, 6]).take(2)
[2, 3]
- """
- picklesInJava = self._jrdd.rdd().take(num)
- return list(self._collect_array_through_file(picklesInJava))
+ >>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
+ [2, 3, 4, 5, 6]
+ """
+ items = []
+ splits = self._jrdd.splits()
+ while len(items) < num and splits:
+ split = splits.pop(0)
+ iterator = self._jrdd.iterator(split)
+ items.extend(self._collect_iterator_through_file(iterator))
+ return items[:num]
def first(self):
"""