aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark')
-rw-r--r--python/pyspark/sql/tests.py8
1 files changed, 8 insertions, 0 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1be0b72304..c2171c277c 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -328,6 +328,14 @@ class SQLTests(ReusedPySparkTestCase):
[row] = self.spark.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
self.assertEqual(tuple(row), (6, 5))
+ def test_udf_in_filter_on_top_of_outer_join(self):
+ from pyspark.sql.functions import udf
+ left = self.spark.createDataFrame([Row(a=1)])
+ right = self.spark.createDataFrame([Row(a=1)])
+ df = left.join(right, on='a', how='left_outer')
+ df = df.withColumn('b', udf(lambda x: 'x')(df.a))
+ self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')])
+
def test_udf_without_arguments(self):
self.spark.catalog.registerFunction("foo", lambda: "bar")
[row] = self.spark.sql("SELECT foo()").collect()