diff options
author | Josh Rosen <joshrosen@eecs.berkeley.edu> | 2012-12-24 17:20:10 -0800 |
---|---|---|
committer | Josh Rosen <joshrosen@eecs.berkeley.edu> | 2012-12-24 17:20:10 -0800 |
commit | 4608902fb87af64a15b97ab21fe6382cd6e5a644 (patch) | |
tree | ef2e9e7a9d3f88dccd70e7b6e39753354b605207 /pyspark/pyspark/rdd.py | |
parent | ccd075cf960df6c6c449b709515cdd81499a52be (diff) | |
download | spark-4608902fb87af64a15b97ab21fe6382cd6e5a644.tar.gz spark-4608902fb87af64a15b97ab21fe6382cd6e5a644.tar.bz2 spark-4608902fb87af64a15b97ab21fe6382cd6e5a644.zip |
Use filesystem to collect RDDs in PySpark.
Passing large volumes of data through Py4J seems
to be slow. It appears to be faster to write the
data to the local filesystem and read it back from
Python.
Diffstat (limited to 'pyspark/pyspark/rdd.py')
-rw-r--r-- | pyspark/pyspark/rdd.py | 34 |
1 files changed, 28 insertions, 6 deletions
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 708ea6eb55..01908cff96 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -1,13 +1,15 @@ +import atexit from base64 import standard_b64encode as b64enc from collections import defaultdict from itertools import chain, ifilter, imap import os import shlex from subprocess import Popen, PIPE +from tempfile import NamedTemporaryFile from threading import Thread from pyspark import cloudpickle -from pyspark.serializers import dump_pickle, load_pickle +from pyspark.serializers import dump_pickle, load_pickle, read_from_pickle_file from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup @@ -145,10 +147,30 @@ class RDD(object): self.map(f).collect() # Force evaluation def collect(self): + # To minimize the number of transfers between Python and Java, we'll + # flatten each partition into a list before collecting it. Due to + # pipelining, this should add minimal overhead. def asList(iterator): yield list(iterator) - pickles = self.mapPartitions(asList)._jrdd.rdd().collect() - return list(chain.from_iterable(load_pickle(bytes(p)) for p in pickles)) + picklesInJava = self.mapPartitions(asList)._jrdd.rdd().collect() + return list(chain.from_iterable(self._collect_array_through_file(picklesInJava))) + + def _collect_array_through_file(self, array): + # Transferring lots of data through Py4J can be slow because + # socket.readline() is inefficient. Instead, we'll dump the data to a + # file and read it back. + tempFile = NamedTemporaryFile(delete=False) + tempFile.close() + def clean_up_file(): + try: os.unlink(tempFile.name) + except: pass + atexit.register(clean_up_file) + self.ctx.writeArrayToPickleFile(array, tempFile.name) + # Read the data into Python and deserialize it: + with open(tempFile.name, 'rb') as tempFile: + for item in read_from_pickle_file(tempFile): + yield item + os.unlink(tempFile.name) def reduce(self, f): """ @@ -220,15 +242,15 @@ class RDD(object): >>> sc.parallelize([2, 3, 4]).take(2) [2, 3] """ - pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().take(num)) - return load_pickle(bytes(pickle)) + picklesInJava = self._jrdd.rdd().take(num) + return list(self._collect_array_through_file(picklesInJava)) def first(self): """ >>> sc.parallelize([2, 3, 4]).first() 2 """ - return load_pickle(bytes(self.ctx.asPickle(self._jrdd.first()))) + return self.take(1)[0] def saveAsTextFile(self, path): def func(iterator): |