diff options
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/sql/functions.py | 11 | ||||
-rw-r--r-- | python/pyspark/sql/tests.py | 25 |
2 files changed, 25 insertions, 11 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index d261720314..426a4a8c93 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1940,7 +1940,16 @@ def udf(f=None, returnType=StringType()):
     +----------+--------------+------------+
     """
     def _udf(f, returnType=StringType()):
-        return UserDefinedFunction(f, returnType)
+        udf_obj = UserDefinedFunction(f, returnType)
+
+        @functools.wraps(f)
+        def wrapper(*args):
+            return udf_obj(*args)
+
+        wrapper.func = udf_obj.func
+        wrapper.returnType = udf_obj.returnType
+
+        return wrapper
 
     # decorator @udf, @udf() or @udf(dataType())
     if f is None or isinstance(f, (str, DataType)):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index abd68bfd39..fd083e4868 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -266,9 +266,6 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(result[0][0], "a")
         self.assertEqual(result[0][1], "b")
 
-        with self.assertRaises(ValueError):
-            data.select(explode(data.mapfield).alias("a", "b", metadata={'max': 99})).count()
-
     def test_and_in_expression(self):
         self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count())
         self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2"))
@@ -578,6 +575,21 @@ class SQLTests(ReusedPySparkTestCase):
             [2, 3.0, "FOO", "foo", "foo", 3, 1.0]
         )
 
+    def test_udf_wrapper(self):
+        from pyspark.sql.functions import udf
+        from pyspark.sql.types import IntegerType
+
+        def f(x):
+            """Identity"""
+            return x
+
+        return_type = IntegerType()
+        f_ = udf(f, return_type)
+
+        self.assertTrue(f.__doc__ in f_.__doc__)
+        self.assertEqual(f, f_.func)
+        self.assertEqual(return_type, f_.returnType)
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)
@@ -963,13 +975,6 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(self.testData, df.select(df.key, df.value).collect())
         self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
 
-    def test_column_alias_metadata(self):
-        df = self.df
-        df_with_meta = df.select(df.key.alias('pk', metadata={'label': 'Primary Key'}))
-        self.assertEqual(df_with_meta.schema['pk'].metadata['label'], 'Primary Key')
-        with self.assertRaises(AssertionError):
-            df.select(df.key.alias('pk', metdata={'label': 'Primary Key'}))
-
     def test_freqItems(self):
         vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)]
         df = self.sc.parallelize(vals).toDF()