author | Davies Liu <davies@databricks.com> | 2016-03-31 16:40:20 -0700
---|---|---
committer | Davies Liu <davies.liu@gmail.com> | 2016-03-31 16:40:20 -0700
commit | f0afafdc5dfee80d7e5cd2fc1fa8187def7f262d |
tree | b374ad4a7c98e11f8f85fbd44618422bd4fe6a1b /python/pyspark |
parent | 8de201baedc8e839e06098c536ba31b3dafd54b5 |
[SPARK-14267] [SQL] [PYSPARK] execute multiple Python UDFs within single batch
## What changes were proposed in this pull request?
This PR adds support for executing multiple Python UDFs within a single batch, which also improves performance. In the example below, the four UDF invocations are evaluated by only two batched Python evaluations (the second batch depends on the output of the first):
```python
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.registerFunction("double", lambda x: x * 2, IntegerType())
>>> sqlContext.registerFunction("add", lambda x, y: x + y, IntegerType())
>>> sqlContext.sql("SELECT double(add(1, 2)), add(double(2), 1)").explain(True)
== Parsed Logical Plan ==
'Project [unresolvedalias('double('add(1, 2)), None),unresolvedalias('add('double(2), 1), None)]
+- OneRowRelation$
== Analyzed Logical Plan ==
double(add(1, 2)): int, add(double(2), 1): int
Project [double(add(1, 2))#14,add(double(2), 1)#15]
+- Project [double(add(1, 2))#14,add(double(2), 1)#15]
+- Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
+- EvaluatePython [add(pythonUDF1#17, 1)], [pythonUDF0#18]
+- EvaluatePython [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
+- OneRowRelation$
== Optimized Logical Plan ==
Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
+- EvaluatePython [add(pythonUDF1#17, 1)], [pythonUDF0#18]
+- EvaluatePython [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
+- OneRowRelation$
== Physical Plan ==
WholeStageCodegen
: +- Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
: +- INPUT
+- !BatchPythonEvaluation [add(pythonUDF1#17, 1)], [pythonUDF0#16,pythonUDF1#17,pythonUDF0#18]
+- !BatchPythonEvaluation [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
+- Scan OneRowRelation[]
```
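On the worker side, a nested call such as `double(add(1, 2))` is collapsed into one composed Python function before evaluation, via the updated `chain` helper in `python/pyspark/worker.py` (see the diff below). A minimal standalone sketch of that composition, using plain Python callables in place of real UDFs:

```python
# Minimal sketch of the worker-side composition applied to chained UDFs;
# chain mirrors the updated helper in python/pyspark/worker.py below.
def chain(f, g):
    """Compose two functions: the result of f feeds g."""
    return lambda *a: g(f(*a))

double = lambda x: x * 2
add = lambda x, y: x + y

# double(add(1, 2)) runs as a single chained call in one Python worker:
double_add = chain(add, double)
print(double_add(1, 2))  # 6
```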
## How was this patch tested?
Added new tests.
The following script was used to benchmark 1, 2, and 3 UDFs:
```python
from pyspark.sql import functions as F
from pyspark.sql.types import LongType

df = sqlContext.range(1, 1 << 23, 1, 4)
double = F.udf(lambda x: x * 2, LongType())
print df.select(double(df.id)).count()
print df.select(double(df.id), double(df.id + 1)).count()
print df.select(double(df.id), double(df.id + 1), double(df.id + 2)).count()
```
Here are the results:

N | Before | After | Speedup
---- | ------ | ----- | -------
1 | 22 s | 7 s | 3.1X
2 | 38 s | 13 s | 2.9X
3 | 58 s | 16 s | 3.6X
This benchmark ran locally with 4 CPUs. For 3 UDFs, 12 Python worker processes were launched before this patch (one per UDF per partition, 3 × 4) but only 4 after it (one per partition, shared by all UDFs). After this patch, evaluating multiple UDFs also uses less memory than before (less buffering).
Author: Davies Liu <davies@databricks.com>
Closes #12057 from davies/multi_udfs.
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/sql/functions.py | 3
-rw-r--r-- | python/pyspark/sql/tests.py | 12
-rw-r--r-- | python/pyspark/worker.py | 68
3 files changed, 64 insertions, 19 deletions
```diff
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 3211834226..3b20ba5177 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1649,8 +1649,7 @@ def sort_array(col, asc=True):
 # ---------------------------- User Defined Function ----------------------------------
 
 def _wrap_function(sc, func, returnType):
-    ser = AutoBatchedSerializer(PickleSerializer())
-    command = (func, returnType, ser)
+    command = (func, returnType)
     pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
     return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
                                   sc.pythonVer, broadcast_vars, sc._javaAccumulator)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 84947560e7..536ef55251 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -305,7 +305,7 @@ class SQLTests(ReusedPySparkTestCase):
         [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
         self.assertEqual(4, res[0])
 
-    def test_chained_python_udf(self):
+    def test_chained_udf(self):
         self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType())
         [row] = self.sqlCtx.sql("SELECT double(1)").collect()
         self.assertEqual(row[0], 2)
@@ -314,6 +314,16 @@ class SQLTests(ReusedPySparkTestCase):
         [row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect()
         self.assertEqual(row[0], 6)
 
+    def test_multiple_udfs(self):
+        self.sqlCtx.registerFunction("double", lambda x: x * 2, IntegerType())
+        [row] = self.sqlCtx.sql("SELECT double(1), double(2)").collect()
+        self.assertEqual(tuple(row), (2, 4))
+        [row] = self.sqlCtx.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
+        self.assertEqual(tuple(row), (4, 12))
+        self.sqlCtx.registerFunction("add", lambda x, y: x + y, IntegerType())
+        [row] = self.sqlCtx.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
+        self.assertEqual(tuple(row), (6, 5))
+
     def test_udf_with_array_type(self):
         d = [Row(l=list(range(3)), d={"key": list(range(5))})]
         rdd = self.sc.parallelize(d)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 0f05fe31aa..cf47ab8f96 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -29,7 +29,7 @@ from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, write_int, read_long, \
-    write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
+    write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer
 from pyspark import shuffle
 
 pickleSer = PickleSerializer()
@@ -59,7 +59,54 @@ def read_command(serializer, file):
 
 def chain(f, g):
     """chain two function together """
-    return lambda x: g(f(x))
+    return lambda *a: g(f(*a))
+
+
+def wrap_udf(f, return_type):
+    if return_type.needConversion():
+        toInternal = return_type.toInternal
+        return lambda *a: toInternal(f(*a))
+    else:
+        return lambda *a: f(*a)
+
+
+def read_single_udf(pickleSer, infile):
+    num_arg = read_int(infile)
+    arg_offsets = [read_int(infile) for i in range(num_arg)]
+    row_func = None
+    for i in range(read_int(infile)):
+        f, return_type = read_command(pickleSer, infile)
+        if row_func is None:
+            row_func = f
+        else:
+            row_func = chain(row_func, f)
+    # the last returnType will be the return type of UDF
+    return arg_offsets, wrap_udf(row_func, return_type)
+
+
+def read_udfs(pickleSer, infile):
+    num_udfs = read_int(infile)
+    if num_udfs == 1:
+        # fast path for single UDF
+        _, udf = read_single_udf(pickleSer, infile)
+        mapper = lambda a: udf(*a)
+    else:
+        udfs = {}
+        call_udf = []
+        for i in range(num_udfs):
+            arg_offsets, udf = read_single_udf(pickleSer, infile)
+            udfs['f%d' % i] = udf
+            args = ["a[%d]" % o for o in arg_offsets]
+            call_udf.append("f%d(%s)" % (i, ", ".join(args)))
+        # Create function like this:
+        #   lambda a: (f0(a0), f1(a1, a2), f2(a3))
+        mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
+        mapper = eval(mapper_str, udfs)
+
+    func = lambda _, it: map(mapper, it)
+    ser = BatchedSerializer(PickleSerializer(), 100)
+    # profiling is not supported for UDF
+    return func, None, ser, ser
 
 
 def main(infile, outfile):
@@ -107,21 +154,10 @@ def main(infile, outfile):
         _broadcastRegistry.pop(bid)
 
         _accumulatorRegistry.clear()
-        row_based = read_int(infile)
-        num_commands = read_int(infile)
-        if row_based:
-            profiler = None  # profiling is not supported for UDF
-            row_func = None
-            for i in range(num_commands):
-                f, returnType, deserializer = read_command(pickleSer, infile)
-                if row_func is None:
-                    row_func = f
-                else:
-                    row_func = chain(row_func, f)
-            serializer = deserializer
-            func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it)
+        is_sql_udf = read_int(infile)
+        if is_sql_udf:
+            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile)
         else:
-            assert num_commands == 1
             func, profiler, deserializer, serializer = read_command(pickleSer, infile)
 
         init_time = time.time()
```
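For illustration, here is a small standalone sketch of the `eval`-based mapper construction that `read_udfs` above uses to evaluate several UDFs in one pass over each batch of rows. The helper name `make_mapper` and the sample UDFs are hypothetical, not part of this patch:

```python
# Standalone sketch of the mapper generation in read_udfs above.
# make_mapper is a hypothetical helper; the technique is the same:
# build the source "lambda a: (f0(a[0]), f1(a[1], a[2]))" and eval it
# against a namespace binding f0, f1, ... to the actual UDF callables.
def make_mapper(udfs_with_offsets):
    namespace = {}
    calls = []
    for i, (func, offsets) in enumerate(udfs_with_offsets):
        namespace['f%d' % i] = func
        args = ", ".join("a[%d]" % o for o in offsets)
        calls.append("f%d(%s)" % (i, args))
    return eval("lambda a: (%s)" % ", ".join(calls), namespace)

double = lambda x: x * 2
add = lambda x, y: x + y

mapper = make_mapper([(double, [0]), (add, [1, 2])])
print(mapper((3, 4, 5)))  # (6, 9)
```

Generating and `eval`-ing the lambda once per task keeps the per-row path a single compiled function call, rather than a Python-level loop over the UDFs for every row.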