aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/worker.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/worker.py')
-rw-r--r--python/pyspark/worker.py33
1 files changed, 29 insertions, 4 deletions
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 42c2f8b759..0f05fe31aa 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -50,6 +50,18 @@ def add_path(path):
sys.path.insert(1, path)
+def read_command(serializer, file):
+ command = serializer._read_with_length(file)
+ if isinstance(command, Broadcast):
+ command = serializer.loads(command.value)
+ return command
+
+
+def chain(f, g):
+ """chain two function together """
+ return lambda x: g(f(x))
+
+
def main(infile, outfile):
try:
boot_time = time.time()
@@ -95,10 +107,23 @@ def main(infile, outfile):
_broadcastRegistry.pop(bid)
_accumulatorRegistry.clear()
- command = pickleSer._read_with_length(infile)
- if isinstance(command, Broadcast):
- command = pickleSer.loads(command.value)
- func, profiler, deserializer, serializer = command
+ row_based = read_int(infile)
+ num_commands = read_int(infile)
+ if row_based:
+ profiler = None # profiling is not supported for UDF
+ row_func = None
+ for i in range(num_commands):
+ f, returnType, deserializer = read_command(pickleSer, infile)
+ if row_func is None:
+ row_func = f
+ else:
+ row_func = chain(row_func, f)
+ serializer = deserializer
+ func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it)
+ else:
+ assert num_commands == 1
+ func, profiler, deserializer, serializer = read_command(pickleSer, infile)
+
init_time = time.time()
def process():