author	zero323 <zero323@users.noreply.github.com>	2017-02-24 08:22:30 -0800
committer	Holden Karau <holden@us.ibm.com>	2017-02-24 08:22:30 -0800
commit	4a5e38f5747148022988631cae0248ae1affadd3 (patch)
tree	6294b22089734ebeda4782a2262e333e337c013f /python/pyspark
parent	8f33731e796750e6f60dc9e2fc33a94d29d198b4 (diff)
[SPARK-19161][PYTHON][SQL] Improving UDF Docstrings
## What changes were proposed in this pull request?

Replaces the `UserDefinedFunction` object returned from `udf` with a function wrapper providing docstring and argument information, as proposed in [SPARK-19161](https://issues.apache.org/jira/browse/SPARK-19161).

### Backward incompatible changes:

- `pyspark.sql.functions.udf` will return a `function` instead of a `UserDefinedFunction`. To keep the public API backward compatible, we use function attributes to mimic the `UserDefinedFunction` API (`func` and `returnType` attributes). This should have minimal impact on user code. An alternative implementation could use dynamic sub-classing, which would ensure full backward compatibility but is more fragile in practice.

### Limitations:

Full functionality (a retained docstring and argument list) is achieved only on recent Python versions. Legacy Python versions preserve only the docstring, not the argument list. This should be an acceptable trade-off between the improvement achieved and the overall complexity.

### Possible impact on other tickets:

This can affect [SPARK-18777](https://issues.apache.org/jira/browse/SPARK-18777).

## How was this patch tested?

Existing unit tests to ensure backward compatibility, plus additional tests targeting the proposed changes.

Author: zero323 <zero323@users.noreply.github.com>

Closes #16534 from zero323/SPARK-19161.
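For illustration, a minimal self-contained sketch of the wrapper pattern described above (plain Python, no Spark required; `UserDefinedFunctionStub` and `udf_sketch` are hypothetical stand-ins for PySpark's `UserDefinedFunction` and `udf`):

    import functools

    class UserDefinedFunctionStub(object):
        """Hypothetical stand-in for pyspark.sql.functions.UserDefinedFunction."""
        def __init__(self, func, returnType):
            self.func = func
            self.returnType = returnType

        def __call__(self, *args):
            return self.func(*args)

    def udf_sketch(f, returnType="string"):
        udf_obj = UserDefinedFunctionStub(f, returnType)

        # functools.wraps copies f's __doc__ (and, on recent Pythons,
        # signature metadata via __wrapped__) onto the wrapper.
        @functools.wraps(f)
        def wrapper(*args):
            return udf_obj(*args)

        # Function attributes mimic the old UserDefinedFunction public API.
        wrapper.func = udf_obj.func
        wrapper.returnType = udf_obj.returnType
        return wrapper

    def identity(x):
        """Returns its argument unchanged."""
        return x

    wrapped = udf_sketch(identity)
    assert wrapped.__doc__ == "Returns its argument unchanged."
    assert wrapped.func is identity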
Diffstat (limited to 'python/pyspark')
-rw-r--r--	python/pyspark/sql/functions.py	11
-rw-r--r--	python/pyspark/sql/tests.py	25
2 files changed, 25 insertions, 11 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index d261720314..426a4a8c93 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1940,7 +1940,16 @@ def udf(f=None, returnType=StringType()):
     +----------+--------------+------------+
     """
     def _udf(f, returnType=StringType()):
-        return UserDefinedFunction(f, returnType)
+        udf_obj = UserDefinedFunction(f, returnType)
+
+        @functools.wraps(f)
+        def wrapper(*args):
+            return udf_obj(*args)
+
+        wrapper.func = udf_obj.func
+        wrapper.returnType = udf_obj.returnType
+
+        return wrapper
 
     # decorator @udf, @udf() or @udf(dataType())
     if f is None or isinstance(f, (str, DataType)):
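To see what `functools.wraps` buys in the hunk above, a small self-contained comparison (plain Python; `to_upper`, `plain_wrapper`, and `wrapped_wrapper` are illustrative names, not part of the patch):

    import functools
    import inspect

    def to_upper(s):
        """Uppercases its argument."""
        return s.upper()

    def plain_wrapper(*args):
        # Roughly the pre-patch situation: the returned callable carries
        # neither to_upper's docstring nor its argument list.
        return to_upper(*args)

    @functools.wraps(to_upper)
    def wrapped_wrapper(*args):
        # The post-patch pattern used inside _udf.
        return to_upper(*args)

    print(plain_wrapper.__doc__)     # None
    print(wrapped_wrapper.__doc__)   # Uppercases its argument.

    # On Python 3, functools.wraps also sets __wrapped__, which
    # inspect.signature follows, so the original argument list shows:
    print(inspect.signature(wrapped_wrapper))  # (s)
    # Legacy Python preserves only the docstring, matching the
    # limitation noted in the commit message.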
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index abd68bfd39..fd083e4868 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -266,9 +266,6 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(result[0][0], "a")
         self.assertEqual(result[0][1], "b")
 
-        with self.assertRaises(ValueError):
-            data.select(explode(data.mapfield).alias("a", "b", metadata={'max': 99})).count()
-
     def test_and_in_expression(self):
         self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count())
         self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2"))
@@ -578,6 +575,21 @@ class SQLTests(ReusedPySparkTestCase):
             [2, 3.0, "FOO", "foo", "foo", 3, 1.0]
         )
 
+    def test_udf_wrapper(self):
+        from pyspark.sql.functions import udf
+        from pyspark.sql.types import IntegerType
+
+        def f(x):
+            """Identity"""
+            return x
+
+        return_type = IntegerType()
+        f_ = udf(f, return_type)
+
+        self.assertTrue(f.__doc__ in f_.__doc__)
+        self.assertEqual(f, f_.func)
+        self.assertEqual(return_type, f_.returnType)
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)
@@ -963,13 +975,6 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(self.testData, df.select(df.key, df.value).collect())
         self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
 
-    def test_column_alias_metadata(self):
-        df = self.df
-        df_with_meta = df.select(df.key.alias('pk', metadata={'label': 'Primary Key'}))
-        self.assertEqual(df_with_meta.schema['pk'].metadata['label'], 'Primary Key')
-        with self.assertRaises(AssertionError):
-            df.select(df.key.alias('pk', metdata={'label': 'Primary Key'}))
-
     def test_freqItems(self):
         vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)]
         df = self.sc.parallelize(vals).toDF()
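As a usage sketch of the end result (assuming a PySpark build containing this patch; `add_one` is an illustrative function, not part of the patch, and the commented lines assume an active SparkSession named `spark`):

    from pyspark.sql.functions import udf
    from pyspark.sql.types import IntegerType

    @udf(IntegerType())
    def add_one(x):
        """Adds one to the input value."""
        return x + 1 if x is not None else None

    print(add_one.__doc__)     # Adds one to the input value.
    print(add_one.returnType)  # IntegerType -- kept for backward compatibility

    # With an active SparkSession named `spark`:
    # spark.range(3).select(add_one("id")).show()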