author    Davies Liu <davies@databricks.com>    2016-03-31 16:40:20 -0700
committer Davies Liu <davies.liu@gmail.com>    2016-03-31 16:40:20 -0700
commit    f0afafdc5dfee80d7e5cd2fc1fa8187def7f262d (patch)
tree      b374ad4a7c98e11f8f85fbd44618422bd4fe6a1b /python
parent    8de201baedc8e839e06098c536ba31b3dafd54b5 (diff)
download  spark-f0afafdc5dfee80d7e5cd2fc1fa8187def7f262d.tar.gz
          spark-f0afafdc5dfee80d7e5cd2fc1fa8187def7f262d.tar.bz2
          spark-f0afafdc5dfee80d7e5cd2fc1fa8187def7f262d.zip
[SPARK-14267] [SQL] [PYSPARK] execute multiple Python UDFs within single batch
## What changes were proposed in this pull request?

This PR supports executing multiple Python UDFs within a single batch, and also improves performance.

```python
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.registerFunction("double", lambda x: x * 2, IntegerType())
>>> sqlContext.registerFunction("add", lambda x, y: x + y, IntegerType())
>>> sqlContext.sql("SELECT double(add(1, 2)), add(double(2), 1)").explain(True)
== Parsed Logical Plan ==
'Project [unresolvedalias('double('add(1, 2)), None),unresolvedalias('add('double(2), 1), None)]
+- OneRowRelation$

== Analyzed Logical Plan ==
double(add(1, 2)): int, add(double(2), 1): int
Project [double(add(1, 2))#14,add(double(2), 1)#15]
+- Project [double(add(1, 2))#14,add(double(2), 1)#15]
   +- Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
      +- EvaluatePython [add(pythonUDF1#17, 1)], [pythonUDF0#18]
         +- EvaluatePython [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
            +- OneRowRelation$

== Optimized Logical Plan ==
Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
+- EvaluatePython [add(pythonUDF1#17, 1)], [pythonUDF0#18]
   +- EvaluatePython [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
      +- OneRowRelation$

== Physical Plan ==
WholeStageCodegen
:  +- Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
:     +- INPUT
+- !BatchPythonEvaluation [add(pythonUDF1#17, 1)], [pythonUDF0#16,pythonUDF1#17,pythonUDF0#18]
   +- !BatchPythonEvaluation [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
      +- Scan OneRowRelation[]
```

## How was this patch tested?

Added new tests.

The following script was used to benchmark 1, 2, and 3 UDFs:
```
df = sqlContext.range(1, 1 << 23, 1, 4)
double = F.udf(lambda x: x * 2, LongType())
print df.select(double(df.id)).count()
print df.select(double(df.id), double(df.id + 1)).count()
print df.select(double(df.id), double(df.id + 1), double(df.id + 2)).count()
```

Here are the results:

N | Before | After | Speedup
---- | ------------ | ------------- | ------
1 | 22 s | 7 s | 3.1X
2 | 38 s | 13 s | 2.9X
3 | 58 s | 16 s | 3.6X

This benchmark ran locally with 4 CPUs. For 3 UDFs, it launched 12 Python processes before this patch and 4 processes after it. After this patch, multiple UDFs also use less memory than before (less buffering).

Author: Davies Liu <davies@databricks.com>

Closes #12057 from davies/multi_udfs.
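For a rough picture of the worker-side batching this patch introduces: one generated mapper evaluates every UDF against the same input row, so a task needs a single Python worker process no matter how many UDFs appear in the projection. A minimal standalone sketch of that idea (plain Python, no Spark internals; the functions and argument offsets below are illustrative only):

```python
# Two illustrative "UDFs" and the offsets of the input columns each one reads.
udfs = {"f0": lambda x: x * 2,        # like double(col0)
        "f1": lambda x, y: x + y}     # like add(col1, col2)
arg_offsets = [[0], [1, 2]]

# Build "lambda a: (f0(a[0]), f1(a[1], a[2]))" and eval it once, so every
# input row is mapped to a tuple holding all UDF results in a single pass.
calls = ", ".join("f%d(%s)" % (i, ", ".join("a[%d]" % o for o in offs))
                  for i, offs in enumerate(arg_offsets))
mapper = eval("lambda a: (%s)" % calls, udfs)

print(mapper((3, 4, 5)))   # (6, 9)
```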
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/sql/functions.py |  3
-rw-r--r--  python/pyspark/sql/tests.py     | 12
-rw-r--r--  python/pyspark/worker.py        | 68
3 files changed, 64 insertions, 19 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 3211834226..3b20ba5177 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1649,8 +1649,7 @@ def sort_array(col, asc=True):
# ---------------------------- User Defined Function ----------------------------------
def _wrap_function(sc, func, returnType):
- ser = AutoBatchedSerializer(PickleSerializer())
- command = (func, returnType, ser)
+ command = (func, returnType)
pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
sc.pythonVer, broadcast_vars, sc._javaAccumulator)
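Side note on this hunk: the pickled UDF command now carries only (func, returnType); the output serializer is no longer shipped with every function but chosen once by the worker (see read_udfs in worker.py below). A hedged sketch of how the slimmer command round-trips, using the same serializer family _prepare_for_python_RDD relies on:

```python
from pyspark.serializers import CloudPickleSerializer
from pyspark.sql.types import IntegerType

# The slimmer UDF command: just the Python function and its declared return type.
command = (lambda x: x * 2, IntegerType())

# _prepare_for_python_RDD pickles this tuple (CloudPickleSerializer handles lambdas);
# the worker's read_command unpickles it and wrap_udf adds any needed type conversion.
ser = CloudPickleSerializer()
func, return_type = ser.loads(ser.dumps(command))
print(func(21), return_type.needConversion())   # 42 False
```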
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 84947560e7..536ef55251 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -305,7 +305,7 @@ class SQLTests(ReusedPySparkTestCase):
[res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
self.assertEqual(4, res[0])
- def test_chained_python_udf(self):
+ def test_chained_udf(self):
self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType())
[row] = self.sqlCtx.sql("SELECT double(1)").collect()
self.assertEqual(row[0], 2)
@@ -314,6 +314,16 @@ class SQLTests(ReusedPySparkTestCase):
[row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect()
self.assertEqual(row[0], 6)
+ def test_multiple_udfs(self):
+ self.sqlCtx.registerFunction("double", lambda x: x * 2, IntegerType())
+ [row] = self.sqlCtx.sql("SELECT double(1), double(2)").collect()
+ self.assertEqual(tuple(row), (2, 4))
+ [row] = self.sqlCtx.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
+ self.assertEqual(tuple(row), (4, 12))
+ self.sqlCtx.registerFunction("add", lambda x, y: x + y, IntegerType())
+ [row] = self.sqlCtx.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
+ self.assertEqual(tuple(row), (6, 5))
+
def test_udf_with_array_type(self):
d = [Row(l=list(range(3)), d={"key": list(range(5))})]
rdd = self.sc.parallelize(d)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 0f05fe31aa..cf47ab8f96 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -29,7 +29,7 @@ from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
- write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
+ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer
from pyspark import shuffle
pickleSer = PickleSerializer()
@@ -59,7 +59,54 @@ def read_command(serializer, file):
def chain(f, g):
"""chain two function together """
- return lambda x: g(f(x))
+ return lambda *a: g(f(*a))
+
+
+def wrap_udf(f, return_type):
+ if return_type.needConversion():
+ toInternal = return_type.toInternal
+ return lambda *a: toInternal(f(*a))
+ else:
+ return lambda *a: f(*a)
+
+
+def read_single_udf(pickleSer, infile):
+ num_arg = read_int(infile)
+ arg_offsets = [read_int(infile) for i in range(num_arg)]
+ row_func = None
+ for i in range(read_int(infile)):
+ f, return_type = read_command(pickleSer, infile)
+ if row_func is None:
+ row_func = f
+ else:
+ row_func = chain(row_func, f)
+ # the last returnType will be the return type of UDF
+ return arg_offsets, wrap_udf(row_func, return_type)
+
+
+def read_udfs(pickleSer, infile):
+ num_udfs = read_int(infile)
+ if num_udfs == 1:
+ # fast path for single UDF
+ _, udf = read_single_udf(pickleSer, infile)
+ mapper = lambda a: udf(*a)
+ else:
+ udfs = {}
+ call_udf = []
+ for i in range(num_udfs):
+ arg_offsets, udf = read_single_udf(pickleSer, infile)
+ udfs['f%d' % i] = udf
+ args = ["a[%d]" % o for o in arg_offsets]
+ call_udf.append("f%d(%s)" % (i, ", ".join(args)))
+ # Create function like this:
+ # lambda a: (f0(a0), f1(a1, a2), f2(a3))
+ mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
+ mapper = eval(mapper_str, udfs)
+
+ func = lambda _, it: map(mapper, it)
+ ser = BatchedSerializer(PickleSerializer(), 100)
+ # profiling is not supported for UDF
+ return func, None, ser, ser
def main(infile, outfile):
@@ -107,21 +154,10 @@ def main(infile, outfile):
_broadcastRegistry.pop(bid)
_accumulatorRegistry.clear()
- row_based = read_int(infile)
- num_commands = read_int(infile)
- if row_based:
- profiler = None # profiling is not supported for UDF
- row_func = None
- for i in range(num_commands):
- f, returnType, deserializer = read_command(pickleSer, infile)
- if row_func is None:
- row_func = f
- else:
- row_func = chain(row_func, f)
- serializer = deserializer
- func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it)
+ is_sql_udf = read_int(infile)
+ if is_sql_udf:
+ func, profiler, deserializer, serializer = read_udfs(pickleSer, infile)
else:
- assert num_commands == 1
func, profiler, deserializer, serializer = read_command(pickleSer, infile)
init_time = time.time()
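To illustrate what wrap_udf adds on top of the raw Python function: when the declared return type needs conversion (e.g. DateType), the result is mapped to Spark's internal representation before being serialized back; simple types pass through unchanged. A small hedged sketch using the public pyspark.sql.types API, mirroring the helper added above:

```python
import datetime
from pyspark.sql.types import DateType, IntegerType

def wrap_udf(f, return_type):
    # Same shape as the wrap_udf helper added to worker.py above.
    if return_type.needConversion():
        to_internal = return_type.toInternal
        return lambda *a: to_internal(f(*a))
    return lambda *a: f(*a)

plus_days = wrap_udf(lambda d, n: d + datetime.timedelta(days=n), DateType())
double = wrap_udf(lambda x: x * 2, IntegerType())

# DateType needs conversion: dates become days since the epoch internally.
print(plus_days(datetime.date(1970, 1, 1), 3))   # 3
# IntegerType does not: the value passes through unchanged.
print(double(21))                                # 42
```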