author     zero323 <zero323@users.noreply.github.com>  2017-02-13 10:37:34 -0800
committer  Holden Karau <holden@us.ibm.com>             2017-02-13 10:37:34 -0800
commit     ab88b2410623e5fdb06d558017bd6d50220e466a (patch)
tree       9e082d08828457ae4c83e5b84879950dbcedcfa6 /python/pyspark/sql/functions.py
parent     5e7cd3322b04f1dd207829b70546bc7ffdd63363 (diff)
[SPARK-19427][PYTHON][SQL] Support data type string as a returnType argument of UDF
## What changes were proposed in this pull request?

Add support for data type string as a return type argument of `UserDefinedFunction`:

```python
f = udf(lambda x: x, "integer")
f.returnType
## IntegerType
```

## How was this patch tested?

Existing unit tests, additional unit tests covering new feature.

Author: zero323 <zero323@users.noreply.github.com>

Closes #16769 from zero323/SPARK-19427.
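As a quick illustration of the feature (a minimal sketch; the SparkSession setup, sample DataFrame, and column name `word` are invented for this example and are not part of the patch):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("abc",), ("de",)], ["word"])

# With this change, a data type string is accepted wherever a DataType object was required.
slen_str = udf(lambda s: len(s), "integer")
slen_obj = udf(lambda s: len(s), IntegerType())

# Both forms resolve to the same parsed return type.
assert slen_str.returnType == slen_obj.returnType

df.select(slen_str(df.word).alias("length")).show()
# +------+
# |length|
# +------+
# |     3|
# |     2|
# +------+
```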
Diffstat (limited to 'python/pyspark/sql/functions.py')
-rw-r--r--  python/pyspark/sql/functions.py  8
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 40727ab12b..5213a3c358 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -27,7 +27,7 @@ if sys.version < "3":
from pyspark import since, SparkContext
from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
-from pyspark.sql.types import StringType
+from pyspark.sql.types import StringType, DataType, _parse_datatype_string
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql.dataframe import DataFrame
@@ -1865,7 +1865,9 @@ class UserDefinedFunction(object):
"""
def __init__(self, func, returnType, name=None):
self.func = func
- self.returnType = returnType
+ self.returnType = (
+ returnType if isinstance(returnType, DataType)
+ else _parse_datatype_string(returnType))
# Stores UserDefinedPythonFunctions jobj, once initialized
self._judf_placeholder = None
self._name = name or (
@@ -1909,7 +1911,7 @@ def udf(f, returnType=StringType()):
it is present in the query.
:param f: python function
- :param returnType: a :class:`pyspark.sql.types.DataType` object
+ :param returnType: a :class:`pyspark.sql.types.DataType` object or data type string.
>>> from pyspark.sql.types import IntegerType
>>> slen = udf(lambda s: len(s), IntegerType())
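The docstring change above relies on `_parse_datatype_string`, which also accepts compound type strings. A minimal sketch of the equivalence (assuming an active SparkSession, since the parser calls into the JVM; the example values are illustrative only):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType, StringType, _parse_datatype_string

# _parse_datatype_string requires an active SparkContext behind the scenes.
spark = SparkSession.builder.getOrCreate()

# Simple and compound type strings both parse to regular DataType instances.
assert _parse_datatype_string("integer") == IntegerType()
assert _parse_datatype_string("array<int>") == ArrayType(IntegerType())

# So a UDF returning a list can declare its type with a plain string...
to_chars = udf(lambda s: list(s), "array<string>")
# ...which ends up equivalent to passing the DataType object explicitly.
assert to_chars.returnType == ArrayType(StringType())
```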