diff options
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/sql/tests.py | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1be0b72304..c2171c277c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -328,6 +328,14 @@ class SQLTests(ReusedPySparkTestCase): [row] = self.spark.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect() self.assertEqual(tuple(row), (6, 5)) + def test_udf_in_filter_on_top_of_outer_join(self): + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1)]) + right = self.spark.createDataFrame([Row(a=1)]) + df = left.join(right, on='a', how='left_outer') + df = df.withColumn('b', udf(lambda x: 'x')(df.a)) + self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')]) + def test_udf_without_arguments(self): self.spark.catalog.registerFunction("foo", lambda: "bar") [row] = self.spark.sql("SELECT foo()").collect() |