diff options
author | Davies Liu <davies@databricks.com> | 2016-03-29 15:06:29 -0700 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2016-03-29 15:06:29 -0700 |
commit | a7a93a116dd9813853ba6f112beb7763931d2006 (patch) | |
tree | 9818f89be4fe960ccfa7585335bbebbff3666810 /python | |
parent | e58c4cb3c5a95f44e357b99a2f0d0e1201d91e7a (diff) | |
download | spark-a7a93a116dd9813853ba6f112beb7763931d2006.tar.gz spark-a7a93a116dd9813853ba6f112beb7763931d2006.tar.bz2 spark-a7a93a116dd9813853ba6f112beb7763931d2006.zip |
[SPARK-14215] [SQL] [PYSPARK] Support chained Python UDFs
## What changes were proposed in this pull request?
This PR brings the support for chained Python UDFs, for example
```sql
select udf1(udf2(a))
select udf1(udf2(a) + 3)
select udf1(udf2(a) + udf3(b))
```
Also directly chained unary Python UDFs are put in single batch of Python UDFs, others may require multiple batches.
For example,
```python
>>> sqlContext.sql("select double(double(1))").explain()
== Physical Plan ==
WholeStageCodegen
: +- Project [pythonUDF#10 AS double(double(1))#9]
: +- INPUT
+- !BatchPythonEvaluation double(double(1)), [pythonUDF#10]
+- Scan OneRowRelation[]
>>> sqlContext.sql("select double(double(1) + double(2))").explain()
== Physical Plan ==
WholeStageCodegen
: +- Project [pythonUDF#19 AS double((double(1) + double(2)))#16]
: +- INPUT
+- !BatchPythonEvaluation double((pythonUDF#17 + pythonUDF#18)), [pythonUDF#17,pythonUDF#18,pythonUDF#19]
+- !BatchPythonEvaluation double(2), [pythonUDF#17,pythonUDF#18]
+- !BatchPythonEvaluation double(1), [pythonUDF#17]
+- Scan OneRowRelation[]
```
TODO: will support multiple unrelated Python UDFs in one batch (another PR).
## How was this patch tested?
Added new unit tests for chained UDFs.
Author: Davies Liu <davies@databricks.com>
Closes #12014 from davies/py_udfs.
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/sql/functions.py | 16 | ||||
-rw-r--r-- | python/pyspark/sql/tests.py | 9 | ||||
-rw-r--r-- | python/pyspark/worker.py | 33 |
3 files changed, 49 insertions, 9 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f5d959ef98..3211834226 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -25,7 +25,7 @@ if sys.version < "3": from itertools import imap as map from pyspark import since, SparkContext -from pyspark.rdd import _wrap_function, ignore_unicode_prefix +from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.sql.types import StringType from pyspark.sql.column import Column, _to_java_column, _to_seq @@ -1648,6 +1648,14 @@ def sort_array(col, asc=True): # ---------------------------- User Defined Function ---------------------------------- +def _wrap_function(sc, func, returnType): + ser = AutoBatchedSerializer(PickleSerializer()) + command = (func, returnType, ser) + pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) + return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec, + sc.pythonVer, broadcast_vars, sc._javaAccumulator) + + class UserDefinedFunction(object): """ User defined function in Python @@ -1662,14 +1670,12 @@ class UserDefinedFunction(object): def _create_judf(self, name): from pyspark.sql import SQLContext - f, returnType = self.func, self.returnType # put them in closure `func` - func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it) - ser = AutoBatchedSerializer(PickleSerializer()) sc = SparkContext.getOrCreate() - wrapped_func = _wrap_function(sc, func, ser, ser) + wrapped_func = _wrap_function(sc, self.func, self.returnType) ctx = SQLContext.getOrCreate(sc) jdt = ctx._ssql_ctx.parseDataType(self.returnType.json()) if name is None: + f = self.func name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__ judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( name, wrapped_func, jdt) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1a5d422af9..84947560e7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -305,6 +305,15 @@ class SQLTests(ReusedPySparkTestCase): [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() self.assertEqual(4, res[0]) + def test_chained_python_udf(self): + self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType()) + [row] = self.sqlCtx.sql("SELECT double(1)").collect() + self.assertEqual(row[0], 2) + [row] = self.sqlCtx.sql("SELECT double(double(1))").collect() + self.assertEqual(row[0], 4) + [row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect() + self.assertEqual(row[0], 6) + def test_udf_with_array_type(self): d = [Row(l=list(range(3)), d={"key": list(range(5))})] rdd = self.sc.parallelize(d) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 42c2f8b759..0f05fe31aa 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -50,6 +50,18 @@ def add_path(path): sys.path.insert(1, path) +def read_command(serializer, file): + command = serializer._read_with_length(file) + if isinstance(command, Broadcast): + command = serializer.loads(command.value) + return command + + +def chain(f, g): + """chain two function together """ + return lambda x: g(f(x)) + + def main(infile, outfile): try: boot_time = time.time() @@ -95,10 +107,23 @@ def main(infile, outfile): _broadcastRegistry.pop(bid) _accumulatorRegistry.clear() - command = pickleSer._read_with_length(infile) - if isinstance(command, Broadcast): - command = pickleSer.loads(command.value) - func, profiler, deserializer, serializer = command + row_based = read_int(infile) + num_commands = read_int(infile) + if row_based: + profiler = None # profiling is not supported for UDF + row_func = None + for i in range(num_commands): + f, returnType, deserializer = read_command(pickleSer, infile) + if row_func is None: + row_func = f + else: + row_func = chain(row_func, f) + serializer = deserializer + func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it) + else: + assert num_commands == 1 + func, profiler, deserializer, serializer = read_command(pickleSer, infile) + init_time = time.time() def process(): |