From 7d68a81a8ed5f49fefb3bd0fa0b9d3835cc7d86e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 3 Nov 2013 11:03:02 -0800 Subject: Remove Pickle-wrapping of Java objects in PySpark. If we support custom serializers, the Python worker will know what type of input to expect, so we won't need to wrap Tuple2 and Strings into pickled tuples and strings. --- python/pyspark/context.py | 10 +++++----- python/pyspark/rdd.py | 11 +++++++---- python/pyspark/serializers.py | 18 ++++++++++++++++++ python/pyspark/worker.py | 14 +++++++++----- 4 files changed, 39 insertions(+), 14 deletions(-) (limited to 'python') diff --git a/python/pyspark/context.py b/python/pyspark/context.py index a7ca8bc888..0fec1a6bf6 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -42,7 +42,7 @@ class SparkContext(object): _gateway = None _jvm = None - _writeIteratorToPickleFile = None + _writeToFile = None _takePartition = None _next_accum_id = 0 _active_spark_context = None @@ -125,8 +125,8 @@ class SparkContext(object): if not SparkContext._gateway: SparkContext._gateway = launch_gateway() SparkContext._jvm = SparkContext._gateway.jvm - SparkContext._writeIteratorToPickleFile = \ - SparkContext._jvm.PythonRDD.writeIteratorToPickleFile + SparkContext._writeToFile = \ + SparkContext._jvm.PythonRDD.writeToFile SparkContext._takePartition = \ SparkContext._jvm.PythonRDD.takePartition @@ -190,8 +190,8 @@ class SparkContext(object): for x in c: write_with_length(dump_pickle(x), tempFile) tempFile.close() - readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile - jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices) + readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile + jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices) return RDD(jrdd, self) def textFile(self, name, minSplits=None): diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 7019fb8bee..d3c4d13a1e 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -54,6 +54,7 @@ class RDD(object): self.is_checkpointed = False self.ctx = ctx self._partitionFunc = None + self._stage_input_is_pairs = False @property def context(self): @@ -344,6 +345,7 @@ class RDD(object): yield pair else: yield pair + java_cartesian._stage_input_is_pairs = True return java_cartesian.flatMap(unpack_batches) def groupBy(self, f, numPartitions=None): @@ -391,8 +393,8 @@ class RDD(object): """ Return a list that contains all of the elements in this RDD. """ - picklesInJava = self._jrdd.collect().iterator() - return list(self._collect_iterator_through_file(picklesInJava)) + bytesInJava = self._jrdd.collect().iterator() + return list(self._collect_iterator_through_file(bytesInJava)) def _collect_iterator_through_file(self, iterator): # Transferring lots of data through Py4J can be slow because @@ -400,7 +402,7 @@ class RDD(object): # file and read it back. tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir) tempFile.close() - self.ctx._writeIteratorToPickleFile(iterator, tempFile.name) + self.ctx._writeToFile(iterator, tempFile.name) # Read the data into Python and deserialize it: with open(tempFile.name, 'rb') as tempFile: for item in read_from_pickle_file(tempFile): @@ -941,6 +943,7 @@ class PipelinedRDD(RDD): self.func = func self.preservesPartitioning = preservesPartitioning self._prev_jrdd = prev._jrdd + self._stage_input_is_pairs = prev._stage_input_is_pairs self.is_cached = False self.is_checkpointed = False self.ctx = prev.ctx @@ -959,7 +962,7 @@ class PipelinedRDD(RDD): def batched_func(split, iterator): return batched(oldfunc(split, iterator), batchSize) func = batched_func - cmds = [func, self._bypass_serializer] + cmds = [func, self._bypass_serializer, self._stage_input_is_pairs] pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index fbc280fd37..fd02e1ee8f 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -93,6 +93,14 @@ def write_with_length(obj, stream): stream.write(obj) +def read_mutf8(stream): + """ + Read a string written with Java's DataOutputStream.writeUTF() method. + """ + length = struct.unpack('>H', stream.read(2))[0] + return stream.read(length).decode('utf8') + + def read_with_length(stream): length = read_int(stream) obj = stream.read(length) @@ -112,3 +120,13 @@ def read_from_pickle_file(stream): yield obj except EOFError: return + + +def read_pairs_from_pickle_file(stream): + try: + while True: + a = load_pickle(read_with_length(stream)) + b = load_pickle(read_with_length(stream)) + yield (a, b) + except EOFError: + return \ No newline at end of file diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 7696df9d1c..4e64557fc4 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -31,8 +31,8 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, read_with_length, write_int, \ - read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file \ - SpecialLengths + read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file, \ + SpecialLengths, read_mutf8, read_pairs_from_pickle_file def load_obj(infile): @@ -53,7 +53,7 @@ def main(infile, outfile): return # fetch name of workdir - spark_files_dir = load_pickle(read_with_length(infile)) + spark_files_dir = read_mutf8(infile) SparkFiles._root_directory = spark_files_dir SparkFiles._is_running_on_worker = True @@ -68,17 +68,21 @@ def main(infile, outfile): sys.path.append(spark_files_dir) # *.py files that were added will be copied here num_python_includes = read_int(infile) for _ in range(num_python_includes): - sys.path.append(os.path.join(spark_files_dir, load_pickle(read_with_length(infile)))) + sys.path.append(os.path.join(spark_files_dir, read_mutf8(infile))) # now load function func = load_obj(infile) bypassSerializer = load_obj(infile) + stageInputIsPairs = load_obj(infile) if bypassSerializer: dumps = lambda x: x else: dumps = dump_pickle init_time = time.time() - iterator = read_from_pickle_file(infile) + if stageInputIsPairs: + iterator = read_pairs_from_pickle_file(infile) + else: + iterator = read_from_pickle_file(infile) try: for obj in func(split_index, iterator): write_with_length(dumps(obj), outfile) -- cgit v1.2.3