author    Davies Liu <davies@databricks.com>  2016-03-29 15:06:29 -0700
committer Davies Liu <davies.liu@gmail.com>  2016-03-29 15:06:29 -0700
commit    a7a93a116dd9813853ba6f112beb7763931d2006 (patch)
tree      9818f89be4fe960ccfa7585335bbebbff3666810 /python/pyspark/sql/functions.py
parent    e58c4cb3c5a95f44e357b99a2f0d0e1201d91e7a (diff)
[SPARK-14215] [SQL] [PYSPARK] Support chained Python UDFs
## What changes were proposed in this pull request?

This PR brings support for chained Python UDFs, for example:

```sql
select udf1(udf2(a))
select udf1(udf2(a) + 3)
select udf1(udf2(a) + udf3(b))
```

Also, directly chained unary Python UDFs are put into a single batch of Python UDFs; others may require multiple batches. For example:

```python
>>> sqlContext.sql("select double(double(1))").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [pythonUDF#10 AS double(double(1))#9]
:     +- INPUT
+- !BatchPythonEvaluation double(double(1)), [pythonUDF#10]
   +- Scan OneRowRelation[]

>>> sqlContext.sql("select double(double(1) + double(2))").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [pythonUDF#19 AS double((double(1) + double(2)))#16]
:     +- INPUT
+- !BatchPythonEvaluation double((pythonUDF#17 + pythonUDF#18)), [pythonUDF#17,pythonUDF#18,pythonUDF#19]
   +- !BatchPythonEvaluation double(2), [pythonUDF#17,pythonUDF#18]
      +- !BatchPythonEvaluation double(1), [pythonUDF#17]
         +- Scan OneRowRelation[]
```

TODO: support multiple unrelated Python UDFs in one batch (another PR).

## How was this patch tested?

Added new unit tests for chained UDFs.

Author: Davies Liu <davies@databricks.com>

Closes #12014 from davies/py_udfs.
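For context, the `double` UDF referenced in the plans above could have been registered as follows (a minimal sketch, assuming an existing `sqlContext`; the lambda and return type are illustrative and not part of this patch):

```python
from pyspark.sql.types import IntegerType

# Hypothetical registration of the "double" UDF used in the plans above.
sqlContext.registerFunction("double", lambda x: x * 2, IntegerType())

# Directly chained unary Python UDFs are collapsed into one batch.
sqlContext.sql("select double(double(1))").explain()
```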
Diffstat (limited to 'python/pyspark/sql/functions.py')
-rw-r--r--  python/pyspark/sql/functions.py | 16
1 file changed, 11 insertions(+), 5 deletions(-)
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index f5d959ef98..3211834226 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -25,7 +25,7 @@ if sys.version < "3":
from itertools import imap as map
from pyspark import since, SparkContext
-from pyspark.rdd import _wrap_function, ignore_unicode_prefix
+from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
from pyspark.sql.types import StringType
from pyspark.sql.column import Column, _to_java_column, _to_seq
@@ -1648,6 +1648,14 @@ def sort_array(col, asc=True):
# ---------------------------- User Defined Function ----------------------------------
+def _wrap_function(sc, func, returnType):
+ ser = AutoBatchedSerializer(PickleSerializer())
+ command = (func, returnType, ser)
+ pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
+ return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
+ sc.pythonVer, broadcast_vars, sc._javaAccumulator)
+
+
class UserDefinedFunction(object):
"""
User defined function in Python
@@ -1662,14 +1670,12 @@ class UserDefinedFunction(object):
def _create_judf(self, name):
from pyspark.sql import SQLContext
- f, returnType = self.func, self.returnType # put them in closure `func`
- func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
- ser = AutoBatchedSerializer(PickleSerializer())
sc = SparkContext.getOrCreate()
- wrapped_func = _wrap_function(sc, func, ser, ser)
+ wrapped_func = _wrap_function(sc, self.func, self.returnType)
ctx = SQLContext.getOrCreate(sc)
jdt = ctx._ssql_ctx.parseDataType(self.returnType.json())
if name is None:
+ f = self.func
name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
name, wrapped_func, jdt)
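A minimal sketch of the same kind of chaining through the DataFrame API, using the `udf()` helper defined in this module (assuming a local Spark context; the example DataFrame, lambda, and return type are illustrative rather than part of this patch):

```python
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.functions import udf
from pyspark.sql.types import LongType

sc = SparkContext.getOrCreate()
sqlContext = SQLContext.getOrCreate(sc)

# Illustrative UDF; any Python callable with a declared return type works.
double = udf(lambda x: x * 2, LongType())

df = sqlContext.range(3)  # single LongType column named "id"

# Nested UDF calls produce chained Python UDFs.
df.select(double(double(df.id))).explain()
```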