Diffstat (limited to 'pyspark')
-rw-r--r--  pyspark/pyspark/context.py      |  9
-rw-r--r--  pyspark/pyspark/rdd.py          | 34
-rw-r--r--  pyspark/pyspark/serializers.py  |  8
-rw-r--r--  pyspark/pyspark/worker.py       | 12
4 files changed, 42 insertions(+), 21 deletions(-)
diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
index 50d57e5317..19f9f9e133 100644
--- a/pyspark/pyspark/context.py
+++ b/pyspark/pyspark/context.py
@@ -14,9 +14,8 @@ class SparkContext(object):
gateway = launch_gateway()
jvm = gateway.jvm
- pickleFile = jvm.spark.api.python.PythonRDD.pickleFile
- asPickle = jvm.spark.api.python.PythonRDD.asPickle
- arrayAsPickle = jvm.spark.api.python.PythonRDD.arrayAsPickle
+ readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
+ writeArrayToPickleFile = jvm.PythonRDD.writeArrayToPickleFile
def __init__(self, master, name, defaultParallelism=None):
self.master = master
@@ -45,11 +44,11 @@ class SparkContext(object):
# because it sends O(n) Py4J commands. As an alternative, serialized
# objects are written to a file and loaded through textFile().
tempFile = NamedTemporaryFile(delete=False)
+ atexit.register(lambda: os.unlink(tempFile.name))
for x in c:
write_with_length(dump_pickle(x), tempFile)
tempFile.close()
- atexit.register(lambda: os.unlink(tempFile.name))
- jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices)
+ jrdd = self.readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
return RDD(jrdd, self)
def textFile(self, name, minSplits=None):
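For context on the helpers used above: parallelize frames each element with write_with_length(dump_pickle(x), tempFile) so that readRDDFromPickleFile on the JVM side can split the file back into records. The framing itself is not shown in this diff; the sketch below is only an illustration of the assumed format (a 4-byte big-endian length prefix followed by the pickled payload), written in the Python 2 style of this codebase.

    # Sketch only, not part of this commit: the record framing assumed for
    # write_with_length / read_with_length (4-byte big-endian length prefix,
    # then the pickled bytes). The EOF check mirrors the one in serializers.py.
    import struct
    import cPickle

    def dump_pickle(obj):
        return cPickle.dumps(obj, 2)

    def write_with_length(obj, stream):
        stream.write(struct.pack("!i", len(obj)))  # length prefix
        stream.write(obj)                          # pickled payload

    def read_with_length(stream):
        header = stream.read(4)
        if len(header) < 4:
            raise EOFError
        obj = stream.read(struct.unpack("!i", header)[0])
        if obj == "":
            raise EOFError
        return obj
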
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index 708ea6eb55..01908cff96 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -1,13 +1,15 @@
+import atexit
from base64 import standard_b64encode as b64enc
from collections import defaultdict
from itertools import chain, ifilter, imap
import os
import shlex
from subprocess import Popen, PIPE
+from tempfile import NamedTemporaryFile
from threading import Thread
from pyspark import cloudpickle
-from pyspark.serializers import dump_pickle, load_pickle
+from pyspark.serializers import dump_pickle, load_pickle, read_from_pickle_file
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup
@@ -145,10 +147,30 @@ class RDD(object):
self.map(f).collect() # Force evaluation
def collect(self):
+ # To minimize the number of transfers between Python and Java, we'll
+ # flatten each partition into a list before collecting it. Due to
+ # pipelining, this should add minimal overhead.
def asList(iterator):
yield list(iterator)
- pickles = self.mapPartitions(asList)._jrdd.rdd().collect()
- return list(chain.from_iterable(load_pickle(bytes(p)) for p in pickles))
+ picklesInJava = self.mapPartitions(asList)._jrdd.rdd().collect()
+ return list(chain.from_iterable(self._collect_array_through_file(picklesInJava)))
+
+ def _collect_array_through_file(self, array):
+ # Transferring lots of data through Py4J can be slow because
+ # socket.readline() is inefficient. Instead, we'll dump the data to a
+ # file and read it back.
+ tempFile = NamedTemporaryFile(delete=False)
+ tempFile.close()
+ def clean_up_file():
+ try: os.unlink(tempFile.name)
+ except: pass
+ atexit.register(clean_up_file)
+ self.ctx.writeArrayToPickleFile(array, tempFile.name)
+ # Read the data into Python and deserialize it:
+ with open(tempFile.name, 'rb') as tempFile:
+ for item in read_from_pickle_file(tempFile):
+ yield item
+ os.unlink(tempFile.name)
def reduce(self, f):
"""
@@ -220,15 +242,15 @@ class RDD(object):
>>> sc.parallelize([2, 3, 4]).take(2)
[2, 3]
"""
- pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().take(num))
- return load_pickle(bytes(pickle))
+ picklesInJava = self._jrdd.rdd().take(num)
+ return list(self._collect_array_through_file(picklesInJava))
def first(self):
"""
>>> sc.parallelize([2, 3, 4]).first()
2
"""
- return load_pickle(bytes(self.ctx.asPickle(self._jrdd.first())))
+ return self.take(1)[0]
def saveAsTextFile(self, path):
def func(iterator):
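The temp-file round trip in _collect_array_through_file can be exercised without a JVM. In the sketch below, which is an illustration rather than code from this commit, plain Python stands in for writeArrayToPickleFile, assuming the JVM side writes the same length-prefixed pickle framing that read_from_pickle_file consumes.

    # Sketch: stand-in for the JVM writer, plus the Python-side read-back path.
    import atexit
    import os
    from tempfile import NamedTemporaryFile

    from pyspark.serializers import dump_pickle, write_with_length, \
        read_from_pickle_file

    tempFile = NamedTemporaryFile(delete=False)
    atexit.register(lambda: os.unlink(tempFile.name))  # best-effort cleanup

    # Pretend to be writeArrayToPickleFile: dump a batch of records to the file.
    for x in [("a", 1), ("b", 2)]:
        write_with_length(dump_pickle(x), tempFile)
    tempFile.close()

    # Python side: stream the records back instead of pushing them through Py4J.
    with open(tempFile.name, 'rb') as stream:
        for item in read_from_pickle_file(stream):
            print(item)
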
diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py
index 21ef8b106c..bfcdda8f12 100644
--- a/pyspark/pyspark/serializers.py
+++ b/pyspark/pyspark/serializers.py
@@ -33,3 +33,11 @@ def read_with_length(stream):
if obj == "":
raise EOFError
return obj
+
+
+def read_from_pickle_file(stream):
+ try:
+ while True:
+ yield load_pickle(read_with_length(stream))
+ except EOFError:
+ return
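Since read_from_pickle_file is an ordinary generator over read_with_length, it accepts any file-like object that carries the same framing, not just the temp files above. A minimal sketch, assuming write_with_length and dump_pickle produce that framing and that read_with_length raises EOFError at end of stream as the hunk shows:

    from StringIO import StringIO

    from pyspark.serializers import dump_pickle, write_with_length, \
        read_from_pickle_file

    # Frame a few records into an in-memory buffer, then stream them back out.
    buf = StringIO()
    for x in range(3):
        write_with_length(dump_pickle(x), buf)
    buf.seek(0)
    assert list(read_from_pickle_file(buf)) == [0, 1, 2]
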
diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py
index 62824a1c9b..9f6b507dbd 100644
--- a/pyspark/pyspark/worker.py
+++ b/pyspark/pyspark/worker.py
@@ -8,7 +8,7 @@ from base64 import standard_b64decode
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
from pyspark.serializers import write_with_length, read_with_length, \
- read_long, read_int, dump_pickle, load_pickle
+ read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
# Redirect stdout to stderr so that users must return values from functions.
@@ -20,14 +20,6 @@ def load_obj():
return load_pickle(standard_b64decode(sys.stdin.readline().strip()))
-def read_input():
- try:
- while True:
- yield load_pickle(read_with_length(sys.stdin))
- except EOFError:
- return
-
-
def main():
num_broadcast_variables = read_int(sys.stdin)
for _ in range(num_broadcast_variables):
@@ -40,7 +32,7 @@ def main():
dumps = lambda x: x
else:
dumps = dump_pickle
- for obj in func(read_input()):
+ for obj in func(read_from_pickle_file(sys.stdin)):
write_with_length(dumps(obj), old_stdout)
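
With read_input gone, the worker's data path is symmetrical: records stream in through read_from_pickle_file(sys.stdin) and results stream back out with write_with_length. A stripped-down sketch of that shape (broadcast handling and function loading omitted, so this is not the actual worker code):

    import sys

    from pyspark.serializers import dump_pickle, write_with_length, \
        read_from_pickle_file

    def run(func, stdin=sys.stdin, stdout=sys.stdout):
        # Deserialize records lazily, pass the whole iterator to the user
        # function, and write each result back with a length prefix.
        for obj in func(read_from_pickle_file(stdin)):
            write_with_length(dump_pickle(obj), stdout)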