author     zero323 <zero323@users.noreply.github.com>    2017-02-15 10:16:34 -0800
committer  Holden Karau <holden@us.ibm.com>              2017-02-15 10:16:34 -0800
commit     c97f4e17de0ce39e8172a5a4ae81f1914816a358
tree       54e6fd3ea35bec6f2350b0459e62db229931de6f
parent     6eca21ba881120f1ac7854621380ef8a92972384
[SPARK-19160][PYTHON][SQL] Add udf decorator
## What changes were proposed in this pull request?

This PR adds `udf` decorator syntax as proposed in [SPARK-19160](https://issues.apache.org/jira/browse/SPARK-19160). This allows users to define a UDF using simplified syntax:

```python
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

@udf(IntegerType())
def add_one(x):
    """Adds one"""
    if x is not None:
        return x + 1
```

without the need to define a function first and then wrap it in a separate `udf` call.

## How was this patch tested?

Existing unit tests to ensure backward compatibility, plus additional unit tests covering the new functionality.

Author: zero323 <zero323@users.noreply.github.com>

Closes #16533 from zero323/SPARK-19160.
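For quick orientation before the diff: the patched `udf` accepts four call forms, taken from the new docstring and tests below. A minimal sketch, assuming an active PySpark session (at this commit, constructing a UDF may require a live SparkContext):

```python
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

# 1. Plain call, unchanged for backward compatibility:
slen = udf(lambda s: len(s), IntegerType())

# 2. Bare decorator -- the return type defaults to StringType():
@udf
def to_upper(s):
    if s is not None:
        return s.upper()

# 3. Decorator called with no arguments, same StringType() default:
@udf()
def to_lower(s):
    if s is not None:
        return s.lower()

# 4. Decorator given a return type, positionally or by keyword:
@udf(IntegerType())
def add_one(x):
    if x is not None:
        return x + 1
```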
-rw-r--r--  python/pyspark/sql/functions.py  41
-rw-r--r--  python/pyspark/sql/tests.py      57
2 files changed, 91 insertions(+), 7 deletions(-)
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 4f4ae10892..d261720314 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -20,6 +20,7 @@ A collections of builtin functions
"""
import math
import sys
+import functools
if sys.version < "3":
from itertools import imap as map
@@ -1908,22 +1909,48 @@ class UserDefinedFunction(object):
@since(1.3)
-def udf(f, returnType=StringType()):
+def udf(f=None, returnType=StringType()):
"""Creates a :class:`Column` expression representing a user defined function (UDF).
.. note:: The user-defined functions must be deterministic. Due to optimization,
duplicate invocations may be eliminated or the function may even be invoked more times than
it is present in the query.
- :param f: python function
- :param returnType: a :class:`pyspark.sql.types.DataType` object or data type string.
+ :param f: python function if used as a standalone function
+ :param returnType: a :class:`pyspark.sql.types.DataType` object
>>> from pyspark.sql.types import IntegerType
>>> slen = udf(lambda s: len(s), IntegerType())
- >>> df.select(slen(df.name).alias('slen')).collect()
- [Row(slen=5), Row(slen=3)]
- """
- return UserDefinedFunction(f, returnType)
+ >>> @udf
+ ... def to_upper(s):
+ ... if s is not None:
+ ... return s.upper()
+ ...
+ >>> @udf(returnType=IntegerType())
+ ... def add_one(x):
+ ... if x is not None:
+ ... return x + 1
+ ...
+ >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
+ >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")).show()
+ +----------+--------------+------------+
+ |slen(name)|to_upper(name)|add_one(age)|
+ +----------+--------------+------------+
+ | 8| JOHN DOE| 22|
+ +----------+--------------+------------+
+ """
+ def _udf(f, returnType=StringType()):
+ return UserDefinedFunction(f, returnType)
+
+ # decorator @udf, @udf() or @udf(dataType())
+ if f is None or isinstance(f, (str, DataType)):
+ # If DataType has been passed as a positional argument
+ # for decorator use it as a returnType
+ return_type = f or returnType
+ return functools.partial(_udf, returnType=return_type)
+ else:
+ return _udf(f=f, returnType=returnType)
+
blacklist = ['map', 'since', 'ignore_unicode_prefix']
__all__ = [k for k, v in globals().items()
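
The branching at the end of the new `udf` is the classic optional-argument decorator pattern: when used bare, Python passes the decorated function as `f`; when called first, `f` is `None` (or a positional return type), so a `functools.partial` with the return type baked in is handed back to act as the actual decorator. A standalone sketch of the same pattern with hypothetical names (`tagged` is illustrative, not part of the patch):

```python
import functools

def tagged(f=None, tag="default"):
    """Usable as @tagged, @tagged(), @tagged("name"), or @tagged(tag="name")."""
    def _tagged(f, tag="default"):
        # The "real" decorator: annotate the function and return it unchanged.
        f.tag = tag
        return f

    if f is None or isinstance(f, str):
        # Called form: f is None (or a positional tag value), so return a
        # preconfigured decorator, exactly as udf() does for returnType.
        return functools.partial(_tagged, tag=f or tag)
    # Bare form: f is the decorated function itself.
    return _tagged(f, tag=tag)

@tagged("math")
def square(x):
    return x * x

print(square.tag)  # math
```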
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 62e1a8c363..d8b7b3137c 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -514,6 +514,63 @@ class SQLTests(ReusedPySparkTestCase):
non_callable = None
self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())
+ def test_udf_with_decorator(self):
+ from pyspark.sql.functions import lit, udf
+ from pyspark.sql.types import IntegerType, DoubleType
+
+ @udf(IntegerType())
+ def add_one(x):
+ if x is not None:
+ return x + 1
+
+ @udf(returnType=DoubleType())
+ def add_two(x):
+ if x is not None:
+ return float(x + 2)
+
+ @udf
+ def to_upper(x):
+ if x is not None:
+ return x.upper()
+
+ @udf()
+ def to_lower(x):
+ if x is not None:
+ return x.lower()
+
+ @udf
+ def substr(x, start, end):
+ if x is not None:
+ return x[start:end]
+
+ @udf("long")
+ def trunc(x):
+ return int(x)
+
+ @udf(returnType="double")
+ def as_double(x):
+ return float(x)
+
+ df = (
+ self.spark
+ .createDataFrame(
+ [(1, "Foo", "foobar", 3.0)], ("one", "Foo", "foobar", "float"))
+ .select(
+ add_one("one"), add_two("one"),
+ to_upper("Foo"), to_lower("Foo"),
+ substr("foobar", lit(0), lit(3)),
+ trunc("float"), as_double("one")))
+
+ self.assertListEqual(
+ [tpe for _, tpe in df.dtypes],
+ ["int", "double", "string", "string", "string", "bigint", "double"]
+ )
+
+ self.assertListEqual(
+ list(df.first()),
+ [2, 3.0, "FOO", "foo", "foo", 3, 1.0]
+ )
+
def test_basic_functions(self):
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
df = self.spark.read.json(rdd)
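
Worth noting when reading these tests: the decorator replaces the decorated name with a `UserDefinedFunction` instance rather than a plain function, so calling it builds a `Column` expression instead of running the Python body. A small sketch of what that implies, assuming an active SparkSession and the attribute names `func` and `returnType` as defined on `UserDefinedFunction` in `functions.py` around this commit:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, UserDefinedFunction
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.master("local[1]").getOrCreate()

@udf(IntegerType())
def add_one(x):
    if x is not None:
        return x + 1

# add_one is now a UserDefinedFunction wrapping the original function.
assert isinstance(add_one, UserDefinedFunction)
print(add_one.returnType)  # IntegerType
print(add_one.func(41))    # 42 -- the original Python function, kept on .func
```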