Diffstat (limited to 'python/pyspark/worker.py')
-rw-r--r--    python/pyspark/worker.py    44
1 file changed, 19 insertions, 25 deletions
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index d63c2aaef7..f2b3f3c142 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -23,23 +23,22 @@ import sys
 import time
 import socket
 import traceback
-from base64 import standard_b64decode
 # CloudPickler needs to be imported so that depicklers are registered using the
 # copy_reg module.
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.cloudpickle import CloudPickler
 from pyspark.files import SparkFiles
-from pyspark.serializers import write_with_length, read_with_length, write_int, \
-    read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
+from pyspark.serializers import write_with_length, write_int, read_long, \
+    write_long, read_int, SpecialLengths, MUTF8Deserializer, PickleSerializer
 
 
-def load_obj(infile):
-    return load_pickle(standard_b64decode(infile.readline().strip()))
+pickleSer = PickleSerializer()
+mutf8_deserializer = MUTF8Deserializer()
 
 
 def report_times(outfile, boot, init, finish):
-    write_int(-3, outfile)
+    write_int(SpecialLengths.TIMING_DATA, outfile)
     write_long(1000 * boot, outfile)
     write_long(1000 * init, outfile)
     write_long(1000 * finish, outfile)
@@ -52,7 +51,7 @@ def main(infile, outfile):
         return
 
     # fetch name of workdir
-    spark_files_dir = load_pickle(read_with_length(infile))
+    spark_files_dir = mutf8_deserializer.loads(infile)
     SparkFiles._root_directory = spark_files_dir
     SparkFiles._is_running_on_worker = True
 
@@ -60,38 +59,33 @@ def main(infile, outfile):
     num_broadcast_variables = read_int(infile)
     for _ in range(num_broadcast_variables):
         bid = read_long(infile)
-        value = read_with_length(infile)
-        _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
+        value = pickleSer._read_with_length(infile)
+        _broadcastRegistry[bid] = Broadcast(bid, value)
 
     # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
     sys.path.append(spark_files_dir) # *.py files that were added will be copied here
     num_python_includes = read_int(infile)
     for _ in range(num_python_includes):
-        sys.path.append(os.path.join(spark_files_dir, load_pickle(read_with_length(infile))))
+        filename = mutf8_deserializer.loads(infile)
+        sys.path.append(os.path.join(spark_files_dir, filename))
 
-    # now load function
-    func = load_obj(infile)
-    bypassSerializer = load_obj(infile)
-    if bypassSerializer:
-        dumps = lambda x: x
-    else:
-        dumps = dump_pickle
+    command = pickleSer._read_with_length(infile)
+    (func, deserializer, serializer) = command
     init_time = time.time()
-    iterator = read_from_pickle_file(infile)
     try:
-        for obj in func(split_index, iterator):
-            write_with_length(dumps(obj), outfile)
+        iterator = deserializer.load_stream(infile)
+        serializer.dump_stream(func(split_index, iterator), outfile)
     except Exception as e:
-        write_int(-2, outfile)
+        write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
         write_with_length(traceback.format_exc(), outfile)
         sys.exit(-1)
     finish_time = time.time()
     report_times(outfile, boot_time, init_time, finish_time)
     # Mark the beginning of the accumulators section of the output
-    write_int(-1, outfile)
-    for aid, accum in _accumulatorRegistry.items():
-        write_with_length(dump_pickle((aid, accum._value)), outfile)
-    write_int(-1, outfile)
+    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
+    write_int(len(_accumulatorRegistry), outfile)
+    for (aid, accum) in _accumulatorRegistry.items():
+        pickleSer._write_with_length((aid, accum._value), outfile)
 
 
 if __name__ == '__main__':
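
Note: the SpecialLengths constants introduced by this change simply name the
sentinel values that the old code wrote as bare literals. Reconstructed from
the minus/plus sides of this diff (the canonical definition lives in
python/pyspark/serializers.py), the mapping is:

    class SpecialLengths(object):
        # Negative values written where a payload length would normally go,
        # letting the JVM side tell control signals apart from framed data.
        END_OF_DATA_SECTION = -1      # was write_int(-1, outfile)
        PYTHON_EXCEPTION_THROWN = -2  # was write_int(-2, outfile)
        TIMING_DATA = -3              # was write_int(-3, outfile)

Note also that the accumulator section now writes an explicit count,
write_int(len(_accumulatorRegistry), outfile), instead of terminating the
list with a second -1.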
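The protocol around these constants is length-prefixed framing: helpers such
as write_with_length and the serializers' _read_with_length /
_write_with_length put a 4-byte big-endian integer in front of every payload,
which is what leaves negative "lengths" free to act as control codes. A
minimal sketch of that framing, assuming the "!i" struct format used
elsewhere in pyspark.serializers (the *_sketch names are illustrative, not
the real API):

    import struct

    def write_with_length_sketch(payload, stream):
        # Prefix the payload with its byte length as a 4-byte big-endian int.
        stream.write(struct.pack("!i", len(payload)))
        stream.write(payload)

    def read_with_length_sketch(stream):
        # Read the 4-byte length prefix, then exactly that many payload bytes.
        length = struct.unpack("!i", stream.read(4))[0]
        return stream.read(length)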
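Similarly, mutf8_deserializer.loads(infile) replaces
load_pickle(read_with_length(infile)) for plain strings such as the work
directory and the python-include filenames, which the JVM writes with
java.io.DataOutputStream.writeUTF(): a 2-byte unsigned length followed by
modified-UTF-8 bytes. A sketch of that read path, under the assumption that
the real MUTF8Deserializer in pyspark.serializers does the equivalent:

    import struct

    class MUTF8DeserializerSketch(object):
        def loads(self, stream):
            # writeUTF() frames a string as an unsigned 16-bit big-endian
            # length followed by that many (modified) UTF-8 bytes.
            length = struct.unpack(">H", stream.read(2))[0]
            return stream.read(length).decode("utf8")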