author:    Josh Rosen <joshrosen@eecs.berkeley.edu>  2012-12-28 22:19:12 -0800
committer: Josh Rosen <joshrosen@eecs.berkeley.edu>  2012-12-28 22:21:16 -0800
commit:    7ec3595de28d53839cb3a45e940ec16f81ffdf45
tree:      2933cb5d71d76fdcea27125168f346ad38d4fca2
parent:    fbadb1cda504b256e3d12c4ce389e723b6f2503c
Fix bug (introduced by batching) in PySpark take()
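
The bug: with pickle batching (batchSize defaults to 1024 in SparkContext.__init__),
each element of the underlying Java RDD is a pickled batch of Python objects, so the
old self._jrdd.rdd().take(num) fetched num batches rather than num elements and could
return too many results after deserialization. The fix drains one partition at a time
and truncates to exactly num items. A minimal standalone sketch of the failure mode
(illustrative helper names, not PySpark code):

    # Mimic PySpark's pickle batching: group elements into fixed-size batches.
    def batched(elements, batch_size):
        for i in range(0, len(elements), batch_size):
            yield elements[i:i + batch_size]

    data = [2, 3, 4, 5, 6]
    batches = list(batched(data, 2))   # [[2, 3], [4, 5], [6]]

    # Old behavior: take num Java-side elements (= batches), then flatten.
    old_take = [x for batch in batches[:2] for x in batch]
    print(old_take)                    # [2, 3, 4, 5] -- four items, not two

    # New behavior: drain one partition/batch at a time, stop once enough
    # elements have been collected, and truncate to exactly num.
    def fixed_take(batches, num):
        items, remaining = [], list(batches)
        while len(items) < num and remaining:
            items.extend(remaining.pop(0))
        return items[:num]

    print(fixed_take(batches, 2))      # [2, 3]

The real fix (in rdd.py below) does the same partition-at-a-time drain via
self._jrdd.iterator(split).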
Diffstat (limited to 'pyspark')
 -rw-r--r--  pyspark/pyspark/context.py       |  6
 -rw-r--r--  pyspark/pyspark/java_gateway.py  |  2
 -rw-r--r--  pyspark/pyspark/rdd.py           | 27
 3 files changed, 21 insertions(+), 14 deletions(-)
diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
index 988c81cd5d..b90596ecc2 100644
--- a/pyspark/pyspark/context.py
+++ b/pyspark/pyspark/context.py
@@ -19,8 +19,8 @@ class SparkContext(object):
 
     gateway = launch_gateway()
     jvm = gateway.jvm
-    readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
-    writeArrayToPickleFile = jvm.PythonRDD.writeArrayToPickleFile
+    _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
+    _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
 
     def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
         environment=None, batchSize=1024):
@@ -94,7 +94,7 @@ class SparkContext(object):
         for x in c:
             write_with_length(dump_pickle(x), tempFile)
         tempFile.close()
-        jrdd = self.readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
+        jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
         return RDD(jrdd, self)
 
     def textFile(self, name, minSplits=None):
diff --git a/pyspark/pyspark/java_gateway.py b/pyspark/pyspark/java_gateway.py
index eb2a875762..2329e536cc 100644
--- a/pyspark/pyspark/java_gateway.py
+++ b/pyspark/pyspark/java_gateway.py
@@ -30,7 +30,7 @@ def launch_gateway():
                 sys.stderr.write(line)
     EchoOutputThread(proc.stdout).start()
     # Connect to the gateway
-    gateway = JavaGateway(GatewayClient(port=port))
+    gateway = JavaGateway(GatewayClient(port=port), auto_convert=False)
     # Import the classes used by PySpark
     java_import(gateway.jvm, "spark.api.java.*")
     java_import(gateway.jvm, "spark.api.python.*")
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index bf32472d25..111476d274 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -152,8 +152,8 @@ class RDD(object):
         into a list.
 
         >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
-        >>> rdd.glom().first()
-        [1, 2]
+        >>> sorted(rdd.glom().collect())
+        [[1, 2], [3, 4]]
         """
         def func(iterator): yield list(iterator)
         return self.mapPartitions(func)
@@ -211,10 +211,10 @@ class RDD(object):
"""
Return a list that contains all of the elements in this RDD.
"""
- picklesInJava = self._jrdd.rdd().collect()
- return list(self._collect_array_through_file(picklesInJava))
+ picklesInJava = self._jrdd.collect().iterator()
+ return list(self._collect_iterator_through_file(picklesInJava))
- def _collect_array_through_file(self, array):
+ def _collect_iterator_through_file(self, iterator):
# Transferring lots of data through Py4J can be slow because
# socket.readline() is inefficient. Instead, we'll dump the data to a
# file and read it back.
@@ -224,7 +224,7 @@ class RDD(object):
             try: os.unlink(tempFile.name)
             except: pass
         atexit.register(clean_up_file)
-        self.ctx.writeArrayToPickleFile(array, tempFile.name)
+        self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
         # Read the data into Python and deserialize it:
         with open(tempFile.name, 'rb') as tempFile:
             for item in read_from_pickle_file(tempFile):
@@ -325,11 +325,18 @@ class RDD(object):
         a lot of partitions are required. In that case, use L{collect} to get
         the whole RDD instead.
 
-        >>> sc.parallelize([2, 3, 4]).take(2)
+        >>> sc.parallelize([2, 3, 4, 5, 6]).take(2)
         [2, 3]
-        """
-        picklesInJava = self._jrdd.rdd().take(num)
-        return list(self._collect_array_through_file(picklesInJava))
+        >>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
+        [2, 3, 4, 5, 6]
+        """
+        items = []
+        splits = self._jrdd.splits()
+        while len(items) < num and splits:
+            split = splits.pop(0)
+            iterator = self._jrdd.iterator(split)
+            items.extend(self._collect_iterator_through_file(iterator))
+        return items[:num]
 
     def first(self):
         """