author: Davies Liu <davies@databricks.com> 2016-03-29 15:06:29 -0700
committer: Davies Liu <davies.liu@gmail.com> 2016-03-29 15:06:29 -0700
commit: a7a93a116dd9813853ba6f112beb7763931d2006
tree: 9818f89be4fe960ccfa7585335bbebbff3666810 /python/pyspark/sql/tests.py
parent: e58c4cb3c5a95f44e357b99a2f0d0e1201d91e7a
[SPARK-14215] [SQL] [PYSPARK] Support chained Python UDFs
## What changes were proposed in this pull request?

This PR adds support for chained Python UDFs, for example:

```sql
select udf1(udf2(a))
select udf1(udf2(a) + 3)
select udf1(udf2(a) + udf3(b))
```

Directly chained unary Python UDFs are evaluated in a single batch of Python UDFs; other expressions may require multiple batches. For example:

```python
>>> sqlContext.sql("select double(double(1))").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [pythonUDF#10 AS double(double(1))#9]
:     +- INPUT
+- !BatchPythonEvaluation double(double(1)), [pythonUDF#10]
   +- Scan OneRowRelation[]
>>> sqlContext.sql("select double(double(1) + double(2))").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [pythonUDF#19 AS double((double(1) + double(2)))#16]
:     +- INPUT
+- !BatchPythonEvaluation double((pythonUDF#17 + pythonUDF#18)), [pythonUDF#17,pythonUDF#18,pythonUDF#19]
   +- !BatchPythonEvaluation double(2), [pythonUDF#17,pythonUDF#18]
      +- !BatchPythonEvaluation double(1), [pythonUDF#17]
         +- Scan OneRowRelation[]
```

TODO: support multiple unrelated Python UDFs in one batch (in another PR).

## How was this patch tested?

Added new unit tests for chained UDFs.

Author: Davies Liu <davies@databricks.com>

Closes #12014 from davies/py_udfs.
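The batching distinction described above can be illustrated without Spark. The following is a minimal pure-Python sketch, not the code in this patch: `compose_chain` and `evaluate_batch` are hypothetical helpers invented for illustration, and each call to `evaluate_batch` stands in for one batch of Python UDF evaluation. It assumes that a directly chained unary call such as `double(double(a))` can be fused into one function, while an expression such as `double(double(a) + 1)` is evaluated in two passes.

```python
from functools import reduce


def compose_chain(*udfs):
    """Fuse directly chained unary functions (outermost first) into one callable.

    Hypothetical helper for illustration only; not part of PySpark.
    """
    def fused(value):
        # udf1(udf2(a)) applies udf2 first, so walk the chain innermost-first.
        return reduce(lambda acc, f: f(acc), reversed(udfs), value)
    return fused


def evaluate_batch(func, rows):
    """Apply one function to every input value; stands in for one Python batch."""
    return [func(row) for row in rows]


def double(x):
    # Same behavior as the lambda registered in the unit test.
    return x + x


rows = [1, 2, 3]

# double(double(a)): the chain is fused, so a single batch is enough.
single_batch = evaluate_batch(compose_chain(double, double), rows)

# double(double(a) + 1): the "+ 1" sits between the two calls, so this
# sketch evaluates the inner call first and the outer call in a second batch.
inner = evaluate_batch(double, rows)
two_batches = evaluate_batch(double, [v + 1 for v in inner])
```

For the input value `1`, the two results agree with the values asserted by the new unit test (`4` and `6`).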
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r-- python/pyspark/sql/tests.py 9
1 files changed, 9 insertions, 0 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1a5d422af9..84947560e7 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -305,6 +305,15 @@ class SQLTests(ReusedPySparkTestCase):
         [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
         self.assertEqual(4, res[0])
 
+    def test_chained_python_udf(self):
+        self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType())
+        [row] = self.sqlCtx.sql("SELECT double(1)").collect()
+        self.assertEqual(row[0], 2)
+        [row] = self.sqlCtx.sql("SELECT double(double(1))").collect()
+        self.assertEqual(row[0], 4)
+        [row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect()
+        self.assertEqual(row[0], 6)
+
     def test_udf_with_array_type(self):
         d = [Row(l=list(range(3)), d={"key": list(range(5))})]
         rdd = self.sc.parallelize(d)