diff options
Diffstat (limited to 'python/pyspark/worker.py')
-rw-r--r-- | python/pyspark/worker.py | 59 |
1 files changed, 59 insertions, 0 deletions
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py new file mode 100644 index 0000000000..812e7a9da5 --- /dev/null +++ b/python/pyspark/worker.py @@ -0,0 +1,59 @@ +""" +Worker that receives input from Piped RDD. +""" +import os +import sys +import traceback +from base64 import standard_b64decode +# CloudPickler needs to be imported so that depicklers are registered using the +# copy_reg module. +from pyspark.accumulators import _accumulatorRegistry +from pyspark.broadcast import Broadcast, _broadcastRegistry +from pyspark.cloudpickle import CloudPickler +from pyspark.files import SparkFiles +from pyspark.serializers import write_with_length, read_with_length, write_int, \ + read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file + + +# Redirect stdout to stderr so that users must return values from functions. +old_stdout = os.fdopen(os.dup(1), 'w') +os.dup2(2, 1) + + +def load_obj(): + return load_pickle(standard_b64decode(sys.stdin.readline().strip())) + + +def main(): + split_index = read_int(sys.stdin) + spark_files_dir = load_pickle(read_with_length(sys.stdin)) + SparkFiles._root_directory = spark_files_dir + SparkFiles._is_running_on_worker = True + sys.path.append(spark_files_dir) + num_broadcast_variables = read_int(sys.stdin) + for _ in range(num_broadcast_variables): + bid = read_long(sys.stdin) + value = read_with_length(sys.stdin) + _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value)) + func = load_obj() + bypassSerializer = load_obj() + if bypassSerializer: + dumps = lambda x: x + else: + dumps = dump_pickle + iterator = read_from_pickle_file(sys.stdin) + try: + for obj in func(split_index, iterator): + write_with_length(dumps(obj), old_stdout) + except Exception as e: + write_int(-2, old_stdout) + write_with_length(traceback.format_exc(), old_stdout) + sys.exit(-1) + # Mark the beginning of the accumulators section of the output + write_int(-1, old_stdout) + for aid, accum in _accumulatorRegistry.items(): + write_with_length(dump_pickle((aid, accum._value)), old_stdout) + + +if __name__ == '__main__': + main() |