diff options
author | Davies Liu <davies@databricks.com> | 2016-03-29 15:06:29 -0700 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2016-03-29 15:06:29 -0700 |
commit | a7a93a116dd9813853ba6f112beb7763931d2006 (patch) | |
tree | 9818f89be4fe960ccfa7585335bbebbff3666810 /python/pyspark/sql | |
parent | e58c4cb3c5a95f44e357b99a2f0d0e1201d91e7a (diff) | |
download | spark-a7a93a116dd9813853ba6f112beb7763931d2006.tar.gz spark-a7a93a116dd9813853ba6f112beb7763931d2006.tar.bz2 spark-a7a93a116dd9813853ba6f112beb7763931d2006.zip |
[SPARK-14215] [SQL] [PYSPARK] Support chained Python UDFs
## What changes were proposed in this pull request?
This PR brings the support for chained Python UDFs, for example
```sql
select udf1(udf2(a))
select udf1(udf2(a) + 3)
select udf1(udf2(a) + udf3(b))
```
Also directly chained unary Python UDFs are put in single batch of Python UDFs, others may require multiple batches.
For example,
```python
>>> sqlContext.sql("select double(double(1))").explain()
== Physical Plan ==
WholeStageCodegen
: +- Project [pythonUDF#10 AS double(double(1))#9]
: +- INPUT
+- !BatchPythonEvaluation double(double(1)), [pythonUDF#10]
+- Scan OneRowRelation[]
>>> sqlContext.sql("select double(double(1) + double(2))").explain()
== Physical Plan ==
WholeStageCodegen
: +- Project [pythonUDF#19 AS double((double(1) + double(2)))#16]
: +- INPUT
+- !BatchPythonEvaluation double((pythonUDF#17 + pythonUDF#18)), [pythonUDF#17,pythonUDF#18,pythonUDF#19]
+- !BatchPythonEvaluation double(2), [pythonUDF#17,pythonUDF#18]
+- !BatchPythonEvaluation double(1), [pythonUDF#17]
+- Scan OneRowRelation[]
```
TODO: will support multiple unrelated Python UDFs in one batch (another PR).
## How was this patch tested?
Added new unit tests for chained UDFs.
Author: Davies Liu <davies@databricks.com>
Closes #12014 from davies/py_udfs.
Diffstat (limited to 'python/pyspark/sql')
-rw-r--r-- | python/pyspark/sql/functions.py | 16 | ||||
-rw-r--r-- | python/pyspark/sql/tests.py | 9 |
2 files changed, 20 insertions, 5 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index f5d959ef98..3211834226 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -25,7 +25,7 @@ if sys.version < "3":
     from itertools import imap as map

 from pyspark import since, SparkContext
-from pyspark.rdd import _wrap_function, ignore_unicode_prefix
+from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
 from pyspark.sql.types import StringType
 from pyspark.sql.column import Column, _to_java_column, _to_seq
@@ -1648,6 +1648,14 @@ def sort_array(col, asc=True):

 # ---------------------------- User Defined Function ----------------------------------

+def _wrap_function(sc, func, returnType):
+    ser = AutoBatchedSerializer(PickleSerializer())
+    command = (func, returnType, ser)
+    pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
+    return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
+                                  sc.pythonVer, broadcast_vars, sc._javaAccumulator)
+
+
 class UserDefinedFunction(object):
     """
     User defined function in Python
@@ -1662,14 +1670,12 @@ class UserDefinedFunction(object):

     def _create_judf(self, name):
         from pyspark.sql import SQLContext
-        f, returnType = self.func, self.returnType  # put them in closure `func`
-        func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
-        ser = AutoBatchedSerializer(PickleSerializer())
         sc = SparkContext.getOrCreate()
-        wrapped_func = _wrap_function(sc, func, ser, ser)
+        wrapped_func = _wrap_function(sc, self.func, self.returnType)
         ctx = SQLContext.getOrCreate(sc)
         jdt = ctx._ssql_ctx.parseDataType(self.returnType.json())
         if name is None:
+            f = self.func
             name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
             name, wrapped_func, jdt)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1a5d422af9..84947560e7 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -305,6 +305,15 @@ class SQLTests(ReusedPySparkTestCase):
         [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
         self.assertEqual(4, res[0])

+    def test_chained_python_udf(self):
+        self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType())
+        [row] = self.sqlCtx.sql("SELECT double(1)").collect()
+        self.assertEqual(row[0], 2)
+        [row] = self.sqlCtx.sql("SELECT double(double(1))").collect()
+        self.assertEqual(row[0], 4)
+        [row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect()
+        self.assertEqual(row[0], 6)
+
     def test_udf_with_array_type(self):
         d = [Row(l=list(range(3)), d={"key": list(range(5))})]
         rdd = self.sc.parallelize(d)