diff options
author | zero323 <zero323@users.noreply.github.com> | 2017-02-13 10:37:34 -0800 |
---|---|---|
committer | Holden Karau <holden@us.ibm.com> | 2017-02-13 10:37:34 -0800 |
commit | ab88b2410623e5fdb06d558017bd6d50220e466a (patch) | |
tree | 9e082d08828457ae4c83e5b84879950dbcedcfa6 /python | |
parent | 5e7cd3322b04f1dd207829b70546bc7ffdd63363 (diff) | |
download | spark-ab88b2410623e5fdb06d558017bd6d50220e466a.tar.gz spark-ab88b2410623e5fdb06d558017bd6d50220e466a.tar.bz2 spark-ab88b2410623e5fdb06d558017bd6d50220e466a.zip |
[SPARK-19427][PYTHON][SQL] Support data type string as a returnType argument of UDF
## What changes were proposed in this pull request?
Add support for data type string as a return type argument of `UserDefinedFunction`:
```python
f = udf(lambda x: x, "integer")
f.returnType
## IntegerType
```
## How was this patch tested?
Existing unit tests, additional unit tests covering new feature.
Author: zero323 <zero323@users.noreply.github.com>
Closes #16769 from zero323/SPARK-19427.
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/sql/functions.py | 8 | ||||
-rw-r--r-- | python/pyspark/sql/tests.py | 15 |
2 files changed, 20 insertions(+), 3 deletions(-)
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 40727ab12b..5213a3c358 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -27,7 +27,7 @@ if sys.version < "3": from pyspark import since, SparkContext from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import PickleSerializer, AutoBatchedSerializer -from pyspark.sql.types import StringType +from pyspark.sql.types import StringType, DataType, _parse_datatype_string from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.dataframe import DataFrame @@ -1865,7 +1865,9 @@ class UserDefinedFunction(object): """ def __init__(self, func, returnType, name=None): self.func = func - self.returnType = returnType + self.returnType = ( + returnType if isinstance(returnType, DataType) + else _parse_datatype_string(returnType)) # Stores UserDefinedPythonFunctions jobj, once initialized self._judf_placeholder = None self._name = name or ( @@ -1909,7 +1911,7 @@ def udf(f, returnType=StringType()): it is present in the query. :param f: python function - :param returnType: a :class:`pyspark.sql.types.DataType` object + :param returnType: a :class:`pyspark.sql.types.DataType` object or data type string. >>> from pyspark.sql.types import IntegerType >>> slen = udf(lambda s: len(s), IntegerType()) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 710585cbe2..ab9d3f6c94 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -489,6 +489,21 @@ class SQLTests(ReusedPySparkTestCase): "judf should be initialized after UDF has been called." 
) + def test_udf_with_string_return_type(self): + from pyspark.sql.functions import UserDefinedFunction + + add_one = UserDefinedFunction(lambda x: x + 1, "integer") + make_pair = UserDefinedFunction(lambda x: (-x, x), "struct<x:integer,y:integer>") + make_array = UserDefinedFunction( + lambda x: [float(x) for x in range(x, x + 3)], "array<double>") + + expected = (2, Row(x=-1, y=1), [1.0, 2.0, 3.0]) + actual = (self.spark.range(1, 2).toDF("x") + .select(add_one("x"), make_pair("x"), make_array("x")) + .first()) + + self.assertTupleEqual(expected, actual) + def test_basic_functions(self): rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) df = self.spark.read.json(rdd) |