path: root/python/pyspark/worker.py
Diffstat (limited to 'python/pyspark/worker.py')
-rw-r--r--    python/pyspark/worker.py    71
1 files changed, 66 insertions, 5 deletions
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 42c2f8b759..cf47ab8f96 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -29,7 +29,7 @@ from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
-    write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
+    write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer
from pyspark import shuffle
pickleSer = PickleSerializer()
@@ -50,6 +50,65 @@ def add_path(path):
sys.path.insert(1, path)
+def read_command(serializer, file):
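+    # the driver may ship a large command wrapped in a Broadcast; unwrap it to get the real command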
+    command = serializer._read_with_length(file)
+    if isinstance(command, Broadcast):
+        command = serializer.loads(command.value)
+    return command
+
+
+def chain(f, g):
+    """Chain two functions together, so that g is applied to the result of f."""
+    return lambda *a: g(f(*a))
+
+
+def wrap_udf(f, return_type):
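+    # convert results to Spark SQL's internal representation when the declared return type requires it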
+    if return_type.needConversion():
+        toInternal = return_type.toInternal
+        return lambda *a: toInternal(f(*a))
+    else:
+        return lambda *a: f(*a)
+
+
+def read_single_udf(pickleSer, infile):
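+    # read the offsets of the columns this UDF consumes, then one or more chained functions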
+    num_arg = read_int(infile)
+    arg_offsets = [read_int(infile) for i in range(num_arg)]
+    row_func = None
+    for i in range(read_int(infile)):
+        f, return_type = read_command(pickleSer, infile)
+        if row_func is None:
+            row_func = f
+        else:
+            row_func = chain(row_func, f)
+    # the last returnType will be the return type of the UDF
+    return arg_offsets, wrap_udf(row_func, return_type)
+
+
+def read_udfs(pickleSer, infile):
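+    # build a single mapper that evaluates every UDF in this batch against each input row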
+    num_udfs = read_int(infile)
+    if num_udfs == 1:
+        # fast path for single UDF
+        _, udf = read_single_udf(pickleSer, infile)
+        mapper = lambda a: udf(*a)
+    else:
+        udfs = {}
+        call_udf = []
+        for i in range(num_udfs):
+            arg_offsets, udf = read_single_udf(pickleSer, infile)
+            udfs['f%d' % i] = udf
+            args = ["a[%d]" % o for o in arg_offsets]
+            call_udf.append("f%d(%s)" % (i, ", ".join(args)))
+        # Create function like this:
+        #   lambda a: (f0(a0), f1(a1, a2), f2(a3))
+        mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
+        mapper = eval(mapper_str, udfs)
+
+    func = lambda _, it: map(mapper, it)
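+    # the same batched pickle serializer handles both the input rows and the UDF results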
+    ser = BatchedSerializer(PickleSerializer(), 100)
+    # profiling is not supported for UDFs
+    return func, None, ser, ser
+
+
def main(infile, outfile):
    try:
        boot_time = time.time()
@@ -95,10 +154,12 @@ def main(infile, outfile):
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
-        command = pickleSer._read_with_length(infile)
-        if isinstance(command, Broadcast):
-            command = pickleSer.loads(command.value)
-        func, profiler, deserializer, serializer = command
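+        # a leading flag tells the worker whether to read SQL UDFs or a plain pickled command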
+        is_sql_udf = read_int(infile)
+        if is_sql_udf:
+            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile)
+        else:
+            func, profiler, deserializer, serializer = read_command(pickleSer, infile)
+
        init_time = time.time()

        def process():