Diffstat (limited to 'python/pyspark/worker.py')
 -rw-r--r-- python/pyspark/worker.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 7696df9d1c..4e64557fc4 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -31,8 +31,8 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.cloudpickle import CloudPickler
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, read_with_length, write_int, \
-    read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file \
-    SpecialLengths
+    read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file, \
+    SpecialLengths, read_mutf8, read_pairs_from_pickle_file
 
 
 def load_obj(infile):
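
The newly imported read_mutf8 (presumably added to pyspark/serializers.py by this same change) decodes a string framed the way Java's DataOutputStream.writeUTF() writes it: an unsigned 2-byte big-endian length followed by the string bytes in Java's modified UTF-8 encoding. A minimal sketch, ignoring the surrogate-pair and embedded-NUL quirks of modified UTF-8, which coincide with plain UTF-8 for the paths and names passed here:

    import struct

    def read_mutf8(stream):
        # writeUTF() frames a string as an unsigned 2-byte big-endian
        # length followed by that many bytes of (modified) UTF-8.
        length = struct.unpack(">H", stream.read(2))[0]
        return stream.read(length).decode("utf-8")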
@@ -53,7 +53,7 @@ def main(infile, outfile):
         return
 
     # fetch name of workdir
-    spark_files_dir = load_pickle(read_with_length(infile))
+    spark_files_dir = read_mutf8(infile)
     SparkFiles._root_directory = spark_files_dir
     SparkFiles._is_running_on_worker = True
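
Concretely, the workdir name now crosses the pipe as a writeUTF()-framed string instead of a length-prefixed pickle. A self-contained illustration of the framing (the path is made up):

    import struct
    from io import BytesIO

    # What the JVM side now writes for the workdir name: a 2-byte
    # big-endian length followed by the UTF-8 bytes of the string.
    path = "/tmp/spark-abc123".encode("utf-8")
    stream = BytesIO(struct.pack(">H", len(path)) + path)

    # What read_mutf8(infile) recovers on the worker side.
    length = struct.unpack(">H", stream.read(2))[0]
    assert stream.read(length).decode("utf-8") == "/tmp/spark-abc123"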
@@ -68,17 +68,21 @@ def main(infile, outfile):
     sys.path.append(spark_files_dir) # *.py files that were added will be copied here
     num_python_includes = read_int(infile)
     for _ in range(num_python_includes):
-        sys.path.append(os.path.join(spark_files_dir, load_pickle(read_with_length(infile))))
+        sys.path.append(os.path.join(spark_files_dir, read_mutf8(infile)))
 
     # now load function
     func = load_obj(infile)
     bypassSerializer = load_obj(infile)
+    stageInputIsPairs = load_obj(infile)
     if bypassSerializer:
         dumps = lambda x: x
     else:
         dumps = dump_pickle
     init_time = time.time()
-    iterator = read_from_pickle_file(infile)
+    if stageInputIsPairs:
+        iterator = read_pairs_from_pickle_file(infile)
+    else:
+        iterator = read_from_pickle_file(infile)
     try:
         for obj in func(split_index, iterator):
             write_with_length(dumps(obj), outfile)
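
The other new serializer helper, read_pairs_from_pickle_file, presumably mirrors the existing read_from_pickle_file generator but consumes two length-prefixed pickled objects per record and yields them as a tuple; stageInputIsPairs would then flag stages whose input stream interleaves keys and values rather than carrying single elements. A hedged sketch of both readers under that assumption, reusing the module's load_pickle and read_with_length helpers:

    def read_from_pickle_file(stream):
        # Yield one object per length-prefixed pickle until EOF.
        try:
            while True:
                yield load_pickle(read_with_length(stream))
        except EOFError:
            return

    def read_pairs_from_pickle_file(stream):
        # Same framing, but read two pickles per record and pair them,
        # so downstream code sees (key, value) tuples.
        try:
            while True:
                key = load_pickle(read_with_length(stream))
                val = load_pickle(read_with_length(stream))
                yield (key, val)
        except EOFError:
            return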