about summary refs log tree commit diff
path: root/python
diff options
context:
space:
mode:
authorzero323 <zero323@users.noreply.github.com>2017-02-13 10:37:34 -0800
committerHolden Karau <holden@us.ibm.com>2017-02-13 10:37:34 -0800
commitab88b2410623e5fdb06d558017bd6d50220e466a (patch)
tree9e082d08828457ae4c83e5b84879950dbcedcfa6 /python
parent5e7cd3322b04f1dd207829b70546bc7ffdd63363 (diff)
downloadspark-ab88b2410623e5fdb06d558017bd6d50220e466a.tar.gz
spark-ab88b2410623e5fdb06d558017bd6d50220e466a.tar.bz2
spark-ab88b2410623e5fdb06d558017bd6d50220e466a.zip
[SPARK-19427][PYTHON][SQL] Support data type string as a returnType argument of UDF
## What changes were proposed in this pull request?

Add support for data type string as a return type argument of `UserDefinedFunction`:

```python
f = udf(lambda x: x, "integer")
f.returnType

## IntegerType
```

## How was this patch tested?

Existing unit tests, additional unit tests covering new feature.

Author: zero323 <zero323@users.noreply.github.com>

Closes #16769 from zero323/SPARK-19427.
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/sql/functions.py8
-rw-r--r--python/pyspark/sql/tests.py15
2 files changed, 20 insertions, 3 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 40727ab12b..5213a3c358 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -27,7 +27,7 @@ if sys.version < "3":
from pyspark import since, SparkContext
from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
-from pyspark.sql.types import StringType
+from pyspark.sql.types import StringType, DataType, _parse_datatype_string
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql.dataframe import DataFrame
@@ -1865,7 +1865,9 @@ class UserDefinedFunction(object):
"""
def __init__(self, func, returnType, name=None):
self.func = func
- self.returnType = returnType
+ self.returnType = (
+ returnType if isinstance(returnType, DataType)
+ else _parse_datatype_string(returnType))
# Stores UserDefinedPythonFunctions jobj, once initialized
self._judf_placeholder = None
self._name = name or (
@@ -1909,7 +1911,7 @@ def udf(f, returnType=StringType()):
it is present in the query.
:param f: python function
- :param returnType: a :class:`pyspark.sql.types.DataType` object
+ :param returnType: a :class:`pyspark.sql.types.DataType` object or data type string.
>>> from pyspark.sql.types import IntegerType
>>> slen = udf(lambda s: len(s), IntegerType())
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 710585cbe2..ab9d3f6c94 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -489,6 +489,21 @@ class SQLTests(ReusedPySparkTestCase):
"judf should be initialized after UDF has been called."
)
+ def test_udf_with_string_return_type(self):
+ from pyspark.sql.functions import UserDefinedFunction
+
+ add_one = UserDefinedFunction(lambda x: x + 1, "integer")
+ make_pair = UserDefinedFunction(lambda x: (-x, x), "struct<x:integer,y:integer>")
+ make_array = UserDefinedFunction(
+ lambda x: [float(x) for x in range(x, x + 3)], "array<double>")
+
+ expected = (2, Row(x=-1, y=1), [1.0, 2.0, 3.0])
+ actual = (self.spark.range(1, 2).toDF("x")
+ .select(add_one("x"), make_pair("x"), make_array("x"))
+ .first())
+
+ self.assertTupleEqual(expected, actual)
+
def test_basic_functions(self):
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
df = self.spark.read.json(rdd)