Diffstat (limited to 'pyspark')
-rw-r--r-- | pyspark/pyspark/context.py     |  9
-rw-r--r-- | pyspark/pyspark/rdd.py         | 34
-rw-r--r-- | pyspark/pyspark/serializers.py |  8
-rw-r--r-- | pyspark/pyspark/worker.py      | 12
4 files changed, 42 insertions, 21 deletions
diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
index 50d57e5317..19f9f9e133 100644
--- a/pyspark/pyspark/context.py
+++ b/pyspark/pyspark/context.py
@@ -14,9 +14,8 @@ class SparkContext(object):
 
     gateway = launch_gateway()
     jvm = gateway.jvm
-    pickleFile = jvm.spark.api.python.PythonRDD.pickleFile
-    asPickle = jvm.spark.api.python.PythonRDD.asPickle
-    arrayAsPickle = jvm.spark.api.python.PythonRDD.arrayAsPickle
+    readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
+    writeArrayToPickleFile = jvm.PythonRDD.writeArrayToPickleFile
 
     def __init__(self, master, name, defaultParallelism=None):
         self.master = master
@@ -45,11 +44,11 @@ class SparkContext(object):
         # because it sends O(n) Py4J commands. As an alternative, serialized
         # objects are written to a file and loaded through textFile().
         tempFile = NamedTemporaryFile(delete=False)
+        atexit.register(lambda: os.unlink(tempFile.name))
         for x in c:
             write_with_length(dump_pickle(x), tempFile)
         tempFile.close()
-        atexit.register(lambda: os.unlink(tempFile.name))
-        jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices)
+        jrdd = self.readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
         return RDD(jrdd, self)
 
     def textFile(self, name, minSplits=None):
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index 708ea6eb55..01908cff96 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -1,13 +1,15 @@
+import atexit
 from base64 import standard_b64encode as b64enc
 from collections import defaultdict
 from itertools import chain, ifilter, imap
 import os
 import shlex
 from subprocess import Popen, PIPE
+from tempfile import NamedTemporaryFile
 from threading import Thread
 
 from pyspark import cloudpickle
-from pyspark.serializers import dump_pickle, load_pickle
+from pyspark.serializers import dump_pickle, load_pickle, read_from_pickle_file
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_cogroup
 
@@ -145,10 +147,30 @@ class RDD(object):
         self.map(f).collect()  # Force evaluation
 
     def collect(self):
+        # To minimize the number of transfers between Python and Java, we'll
+        # flatten each partition into a list before collecting it. Due to
+        # pipelining, this should add minimal overhead.
         def asList(iterator):
             yield list(iterator)
-        pickles = self.mapPartitions(asList)._jrdd.rdd().collect()
-        return list(chain.from_iterable(load_pickle(bytes(p)) for p in pickles))
+        picklesInJava = self.mapPartitions(asList)._jrdd.rdd().collect()
+        return list(chain.from_iterable(self._collect_array_through_file(picklesInJava)))
+
+    def _collect_array_through_file(self, array):
+        # Transferring lots of data through Py4J can be slow because
+        # socket.readline() is inefficient. Instead, we'll dump the data to a
+        # file and read it back.
+        tempFile = NamedTemporaryFile(delete=False)
+        tempFile.close()
+        def clean_up_file():
+            try: os.unlink(tempFile.name)
+            except: pass
+        atexit.register(clean_up_file)
+        self.ctx.writeArrayToPickleFile(array, tempFile.name)
+        # Read the data into Python and deserialize it:
+        with open(tempFile.name, 'rb') as tempFile:
+            for item in read_from_pickle_file(tempFile):
+                yield item
+        os.unlink(tempFile.name)
 
     def reduce(self, f):
         """
@@ -220,15 +242,15 @@ class RDD(object):
         >>> sc.parallelize([2, 3, 4]).take(2)
         [2, 3]
         """
-        pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().take(num))
-        return load_pickle(bytes(pickle))
+        picklesInJava = self._jrdd.rdd().take(num)
+        return list(self._collect_array_through_file(picklesInJava))
 
     def first(self):
         """
         >>> sc.parallelize([2, 3, 4]).first()
         2
         """
-        return load_pickle(bytes(self.ctx.asPickle(self._jrdd.first())))
+        return self.take(1)[0]
 
     def saveAsTextFile(self, path):
         def func(iterator):
diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py
index 21ef8b106c..bfcdda8f12 100644
--- a/pyspark/pyspark/serializers.py
+++ b/pyspark/pyspark/serializers.py
@@ -33,3 +33,11 @@ def read_with_length(stream):
     if obj == "":
         raise EOFError
     return obj
+
+
+def read_from_pickle_file(stream):
+    try:
+        while True:
+            yield load_pickle(read_with_length(stream))
+    except EOFError:
+        return
diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py
index 62824a1c9b..9f6b507dbd 100644
--- a/pyspark/pyspark/worker.py
+++ b/pyspark/pyspark/worker.py
@@ -8,7 +8,7 @@ from base64 import standard_b64decode
 from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.cloudpickle import CloudPickler
 from pyspark.serializers import write_with_length, read_with_length, \
-    read_long, read_int, dump_pickle, load_pickle
+    read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
 
 
 # Redirect stdout to stderr so that users must return values from functions.
@@ -20,14 +20,6 @@ def load_obj():
     return load_pickle(standard_b64decode(sys.stdin.readline().strip()))
 
 
-def read_input():
-    try:
-        while True:
-            yield load_pickle(read_with_length(sys.stdin))
-    except EOFError:
-        return
-
-
 def main():
     num_broadcast_variables = read_int(sys.stdin)
     for _ in range(num_broadcast_variables):
@@ -40,7 +32,7 @@ def main():
         dumps = lambda x: x
     else:
         dumps = dump_pickle
-    for obj in func(read_input()):
+    for obj in func(read_from_pickle_file(sys.stdin)):
         write_with_length(dumps(obj), old_stdout)
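All of the helpers touched by this patch share one on-disk format: a flat file of length-prefixed pickles, written one record at a time by write_with_length(dump_pickle(x), f) and consumed lazily by read_from_pickle_file(). The standalone sketch below illustrates that framing. It assumes a 4-byte big-endian length prefix (the exact encoding lives in serializers.py's write_with_length/read_with_length and is not shown in this diff), and the names dump_to_pickle_file/load_from_pickle_file are illustrative only, not part of PySpark.

import struct
import cPickle

def dump_to_pickle_file(objects, path):
    # Frame each record as a 4-byte big-endian length followed by the pickled
    # bytes, mirroring write_with_length(dump_pickle(x), f). The 4-byte prefix
    # is an assumption for this sketch.
    with open(path, 'wb') as f:
        for obj in objects:
            data = cPickle.dumps(obj, 2)
            f.write(struct.pack("!i", len(data)))
            f.write(data)

def load_from_pickle_file(path):
    # Generator that reads length-prefixed pickles until EOF, playing the same
    # role as read_from_pickle_file() in serializers.py.
    with open(path, 'rb') as f:
        while True:
            header = f.read(4)
            if len(header) < 4:
                return  # clean EOF: no more framed records
            length, = struct.unpack("!i", header)
            yield cPickle.loads(f.read(length))

if __name__ == '__main__':
    dump_to_pickle_file([1, "two", {"three": 3}], "/tmp/pickles.bin")
    print list(load_from_pickle_file("/tmp/pickles.bin"))

Because every record carries its own length header, the reader never needs to call socket.readline() or scan for delimiters, which is the Py4J inefficiency the temp-file approach in _collect_array_through_file works around.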