diff options
author | zero323 <zero323@users.noreply.github.com> | 2017-02-13 10:37:34 -0800 |
---|---|---|
committer | Holden Karau <holden@us.ibm.com> | 2017-02-13 10:37:34 -0800 |
commit | ab88b2410623e5fdb06d558017bd6d50220e466a (patch) | |
tree | 9e082d08828457ae4c83e5b84879950dbcedcfa6 /python/pyspark/sql/functions.py | |
parent | 5e7cd3322b04f1dd207829b70546bc7ffdd63363 (diff) | |
download | spark-ab88b2410623e5fdb06d558017bd6d50220e466a.tar.gz spark-ab88b2410623e5fdb06d558017bd6d50220e466a.tar.bz2 spark-ab88b2410623e5fdb06d558017bd6d50220e466a.zip |
[SPARK-19427][PYTHON][SQL] Support data type string as a returnType argument of UDF
## What changes were proposed in this pull request?
Add support for data type string as a return type argument of `UserDefinedFunction`:
```python
f = udf(lambda x: x, "integer")
f.returnType
## IntegerType
```
## How was this patch tested?
Existing unit tests, additional unit tests covering new feature.
Author: zero323 <zero323@users.noreply.github.com>
Closes #16769 from zero323/SPARK-19427.
Diffstat (limited to 'python/pyspark/sql/functions.py')
-rw-r--r-- | python/pyspark/sql/functions.py | 8 |
1 file changed, 5 insertions, 3 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 40727ab12b..5213a3c358 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -27,7 +27,7 @@ if sys.version < "3": from pyspark import since, SparkContext from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import PickleSerializer, AutoBatchedSerializer -from pyspark.sql.types import StringType +from pyspark.sql.types import StringType, DataType, _parse_datatype_string from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.dataframe import DataFrame @@ -1865,7 +1865,9 @@ class UserDefinedFunction(object): """ def __init__(self, func, returnType, name=None): self.func = func - self.returnType = returnType + self.returnType = ( + returnType if isinstance(returnType, DataType) + else _parse_datatype_string(returnType)) # Stores UserDefinedPythonFunctions jobj, once initialized self._judf_placeholder = None self._name = name or ( @@ -1909,7 +1911,7 @@ def udf(f, returnType=StringType()): it is present in the query. :param f: python function - :param returnType: a :class:`pyspark.sql.types.DataType` object + :param returnType: a :class:`pyspark.sql.types.DataType` object or data type string. >>> from pyspark.sql.types import IntegerType >>> slen = udf(lambda s: len(s), IntegerType()) |