diff options
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/context.py | 10 | ||||
-rw-r--r-- | python/pyspark/rdd.py | 11 | ||||
-rw-r--r-- | python/pyspark/serializers.py | 18 | ||||
-rw-r--r-- | python/pyspark/worker.py | 14 |
4 files changed, 39 insertions, 14 deletions
diff --git a/python/pyspark/context.py b/python/pyspark/context.py index a7ca8bc888..0fec1a6bf6 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -42,7 +42,7 @@ class SparkContext(object): _gateway = None _jvm = None - _writeIteratorToPickleFile = None + _writeToFile = None _takePartition = None _next_accum_id = 0 _active_spark_context = None @@ -125,8 +125,8 @@ class SparkContext(object): if not SparkContext._gateway: SparkContext._gateway = launch_gateway() SparkContext._jvm = SparkContext._gateway.jvm - SparkContext._writeIteratorToPickleFile = \ - SparkContext._jvm.PythonRDD.writeIteratorToPickleFile + SparkContext._writeToFile = \ + SparkContext._jvm.PythonRDD.writeToFile SparkContext._takePartition = \ SparkContext._jvm.PythonRDD.takePartition @@ -190,8 +190,8 @@ class SparkContext(object): for x in c: write_with_length(dump_pickle(x), tempFile) tempFile.close() - readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile - jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices) + readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile + jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices) return RDD(jrdd, self) def textFile(self, name, minSplits=None): diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 7019fb8bee..d3c4d13a1e 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -54,6 +54,7 @@ class RDD(object): self.is_checkpointed = False self.ctx = ctx self._partitionFunc = None + self._stage_input_is_pairs = False @property def context(self): @@ -344,6 +345,7 @@ class RDD(object): yield pair else: yield pair + java_cartesian._stage_input_is_pairs = True return java_cartesian.flatMap(unpack_batches) def groupBy(self, f, numPartitions=None): @@ -391,8 +393,8 @@ class RDD(object): """ Return a list that contains all of the elements in this RDD. """ - picklesInJava = self._jrdd.collect().iterator() - return list(self._collect_iterator_through_file(picklesInJava)) + bytesInJava = self._jrdd.collect().iterator() + return list(self._collect_iterator_through_file(bytesInJava)) def _collect_iterator_through_file(self, iterator): # Transferring lots of data through Py4J can be slow because @@ -400,7 +402,7 @@ class RDD(object): # file and read it back. tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir) tempFile.close() - self.ctx._writeIteratorToPickleFile(iterator, tempFile.name) + self.ctx._writeToFile(iterator, tempFile.name) # Read the data into Python and deserialize it: with open(tempFile.name, 'rb') as tempFile: for item in read_from_pickle_file(tempFile): @@ -941,6 +943,7 @@ class PipelinedRDD(RDD): self.func = func self.preservesPartitioning = preservesPartitioning self._prev_jrdd = prev._jrdd + self._stage_input_is_pairs = prev._stage_input_is_pairs self.is_cached = False self.is_checkpointed = False self.ctx = prev.ctx @@ -959,7 +962,7 @@ class PipelinedRDD(RDD): def batched_func(split, iterator): return batched(oldfunc(split, iterator), batchSize) func = batched_func - cmds = [func, self._bypass_serializer] + cmds = [func, self._bypass_serializer, self._stage_input_is_pairs] pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index fbc280fd37..fd02e1ee8f 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -93,6 +93,14 @@ def write_with_length(obj, stream): stream.write(obj) +def read_mutf8(stream): + """ + Read a string written with Java's DataOutputStream.writeUTF() method. + """ + length = struct.unpack('>H', stream.read(2))[0] + return stream.read(length).decode('utf8') + + def read_with_length(stream): length = read_int(stream) obj = stream.read(length) @@ -112,3 +120,13 @@ def read_from_pickle_file(stream): yield obj except EOFError: return + + +def read_pairs_from_pickle_file(stream): + try: + while True: + a = load_pickle(read_with_length(stream)) + b = load_pickle(read_with_length(stream)) + yield (a, b) + except EOFError: + return
\ No newline at end of file diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 7696df9d1c..4e64557fc4 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -31,8 +31,8 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, read_with_length, write_int, \ - read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file \ - SpecialLengths + read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file, \ + SpecialLengths, read_mutf8, read_pairs_from_pickle_file def load_obj(infile): @@ -53,7 +53,7 @@ def main(infile, outfile): return # fetch name of workdir - spark_files_dir = load_pickle(read_with_length(infile)) + spark_files_dir = read_mutf8(infile) SparkFiles._root_directory = spark_files_dir SparkFiles._is_running_on_worker = True @@ -68,17 +68,21 @@ def main(infile, outfile): sys.path.append(spark_files_dir) # *.py files that were added will be copied here num_python_includes = read_int(infile) for _ in range(num_python_includes): - sys.path.append(os.path.join(spark_files_dir, load_pickle(read_with_length(infile)))) + sys.path.append(os.path.join(spark_files_dir, read_mutf8(infile))) # now load function func = load_obj(infile) bypassSerializer = load_obj(infile) + stageInputIsPairs = load_obj(infile) if bypassSerializer: dumps = lambda x: x else: dumps = dump_pickle init_time = time.time() - iterator = read_from_pickle_file(infile) + if stageInputIsPairs: + iterator = read_pairs_from_pickle_file(infile) + else: + iterator = read_from_pickle_file(infile) try: for obj in func(split_index, iterator): write_with_length(dumps(obj), outfile) |