diff options
Diffstat (limited to 'python/pyspark/worker.py')
-rw-r--r-- | python/pyspark/worker.py | 14 |
1 files changed, 9 insertions, 5 deletions
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 7696df9d1c..4e64557fc4 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -31,8 +31,8 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, read_with_length, write_int, \ - read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file \ - SpecialLengths + read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file, \ + SpecialLengths, read_mutf8, read_pairs_from_pickle_file def load_obj(infile): @@ -53,7 +53,7 @@ def main(infile, outfile): return # fetch name of workdir - spark_files_dir = load_pickle(read_with_length(infile)) + spark_files_dir = read_mutf8(infile) SparkFiles._root_directory = spark_files_dir SparkFiles._is_running_on_worker = True @@ -68,17 +68,21 @@ def main(infile, outfile): sys.path.append(spark_files_dir) # *.py files that were added will be copied here num_python_includes = read_int(infile) for _ in range(num_python_includes): - sys.path.append(os.path.join(spark_files_dir, load_pickle(read_with_length(infile)))) + sys.path.append(os.path.join(spark_files_dir, read_mutf8(infile))) # now load function func = load_obj(infile) bypassSerializer = load_obj(infile) + stageInputIsPairs = load_obj(infile) if bypassSerializer: dumps = lambda x: x else: dumps = dump_pickle init_time = time.time() - iterator = read_from_pickle_file(infile) + if stageInputIsPairs: + iterator = read_pairs_from_pickle_file(infile) + else: + iterator = read_from_pickle_file(infile) try: for obj in func(split_index, iterator): write_with_length(dumps(obj), outfile) |