diff options
Diffstat (limited to 'python/pyspark/worker.py')
-rw-r--r-- | python/pyspark/worker.py | 33 |
1 files changed, 29 insertions, 4 deletions
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 42c2f8b759..0f05fe31aa 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -50,6 +50,18 @@ def add_path(path): sys.path.insert(1, path) +def read_command(serializer, file): + command = serializer._read_with_length(file) + if isinstance(command, Broadcast): + command = serializer.loads(command.value) + return command + + +def chain(f, g): + """chain two function together """ + return lambda x: g(f(x)) + + def main(infile, outfile): try: boot_time = time.time() @@ -95,10 +107,23 @@ def main(infile, outfile): _broadcastRegistry.pop(bid) _accumulatorRegistry.clear() - command = pickleSer._read_with_length(infile) - if isinstance(command, Broadcast): - command = pickleSer.loads(command.value) - func, profiler, deserializer, serializer = command + row_based = read_int(infile) + num_commands = read_int(infile) + if row_based: + profiler = None # profiling is not supported for UDF + row_func = None + for i in range(num_commands): + f, returnType, deserializer = read_command(pickleSer, infile) + if row_func is None: + row_func = f + else: + row_func = chain(row_func, f) + serializer = deserializer + func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it) + else: + assert num_commands == 1 + func, profiler, deserializer, serializer = read_command(pickleSer, infile) + init_time = time.time() def process(): |