diff options
author | Davies Liu <davies@databricks.com> | 2016-03-29 15:06:29 -0700 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2016-03-29 15:06:29 -0700 |
commit | a7a93a116dd9813853ba6f112beb7763931d2006 (patch) | |
tree | 9818f89be4fe960ccfa7585335bbebbff3666810 /python/pyspark/sql/tests.py | |
parent | e58c4cb3c5a95f44e357b99a2f0d0e1201d91e7a (diff) | |
download | spark-a7a93a116dd9813853ba6f112beb7763931d2006.tar.gz spark-a7a93a116dd9813853ba6f112beb7763931d2006.tar.bz2 spark-a7a93a116dd9813853ba6f112beb7763931d2006.zip |
[SPARK-14215] [SQL] [PYSPARK] Support chained Python UDFs
## What changes were proposed in this pull request?
This PR brings the support for chained Python UDFs, for example
```sql
select udf1(udf2(a))
select udf1(udf2(a) + 3)
select udf1(udf2(a) + udf3(b))
```
Also directly chained unary Python UDFs are put in single batch of Python UDFs, others may require multiple batches.
For example,
```python
>>> sqlContext.sql("select double(double(1))").explain()
== Physical Plan ==
WholeStageCodegen
: +- Project [pythonUDF#10 AS double(double(1))#9]
: +- INPUT
+- !BatchPythonEvaluation double(double(1)), [pythonUDF#10]
+- Scan OneRowRelation[]
>>> sqlContext.sql("select double(double(1) + double(2))").explain()
== Physical Plan ==
WholeStageCodegen
: +- Project [pythonUDF#19 AS double((double(1) + double(2)))#16]
: +- INPUT
+- !BatchPythonEvaluation double((pythonUDF#17 + pythonUDF#18)), [pythonUDF#17,pythonUDF#18,pythonUDF#19]
+- !BatchPythonEvaluation double(2), [pythonUDF#17,pythonUDF#18]
+- !BatchPythonEvaluation double(1), [pythonUDF#17]
+- Scan OneRowRelation[]
```
TODO: will support multiple unrelated Python UDFs in one batch (another PR).
## How was this patch tested?
Added new unit tests for chained UDFs.
Author: Davies Liu <davies@databricks.com>
Closes #12014 from davies/py_udfs.
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r-- | python/pyspark/sql/tests.py | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1a5d422af9..84947560e7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -305,6 +305,15 @@ class SQLTests(ReusedPySparkTestCase): [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() self.assertEqual(4, res[0]) + def test_chained_python_udf(self): + self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType()) + [row] = self.sqlCtx.sql("SELECT double(1)").collect() + self.assertEqual(row[0], 2) + [row] = self.sqlCtx.sql("SELECT double(double(1))").collect() + self.assertEqual(row[0], 4) + [row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect() + self.assertEqual(row[0], 6) + def test_udf_with_array_type(self): d = [Row(l=list(range(3)), d={"key": list(range(5))})] rdd = self.sc.parallelize(d) |