about summary refs log tree commit diff
path: root/python/pyspark
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark')
-rw-r--r--python/pyspark/sql/functions.py11
-rw-r--r--python/pyspark/sql/tests.py25
2 files changed, 25 insertions, 11 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index d261720314..426a4a8c93 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1940,7 +1940,16 @@ def udf(f=None, returnType=StringType()):
+----------+--------------+------------+
"""
def _udf(f, returnType=StringType()):
- return UserDefinedFunction(f, returnType)
+ udf_obj = UserDefinedFunction(f, returnType)
+
+ @functools.wraps(f)
+ def wrapper(*args):
+ return udf_obj(*args)
+
+ wrapper.func = udf_obj.func
+ wrapper.returnType = udf_obj.returnType
+
+ return wrapper
# decorator @udf, @udf() or @udf(dataType())
if f is None or isinstance(f, (str, DataType)):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index abd68bfd39..fd083e4868 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -266,9 +266,6 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual(result[0][0], "a")
self.assertEqual(result[0][1], "b")
- with self.assertRaises(ValueError):
- data.select(explode(data.mapfield).alias("a", "b", metadata={'max': 99})).count()
-
def test_and_in_expression(self):
self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count())
self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2"))
@@ -578,6 +575,21 @@ class SQLTests(ReusedPySparkTestCase):
[2, 3.0, "FOO", "foo", "foo", 3, 1.0]
)
+ def test_udf_wrapper(self):
+ from pyspark.sql.functions import udf
+ from pyspark.sql.types import IntegerType
+
+ def f(x):
+ """Identity"""
+ return x
+
+ return_type = IntegerType()
+ f_ = udf(f, return_type)
+
+ self.assertTrue(f.__doc__ in f_.__doc__)
+ self.assertEqual(f, f_.func)
+ self.assertEqual(return_type, f_.returnType)
+
def test_basic_functions(self):
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
df = self.spark.read.json(rdd)
@@ -963,13 +975,6 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual(self.testData, df.select(df.key, df.value).collect())
self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
- def test_column_alias_metadata(self):
- df = self.df
- df_with_meta = df.select(df.key.alias('pk', metadata={'label': 'Primary Key'}))
- self.assertEqual(df_with_meta.schema['pk'].metadata['label'], 'Primary Key')
- with self.assertRaises(AssertionError):
- df.select(df.key.alias('pk', metdata={'label': 'Primary Key'}))
-
def test_freqItems(self):
vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)]
df = self.sc.parallelize(vals).toDF()