diff options
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/sql/tests.py | 10 |
1 files changed, 9 insertions, 1 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1d5d691696..c631ad8a46 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -339,13 +339,21 @@ class SQLTests(ReusedPySparkTestCase): def test_udf_with_aggregate_function(self): df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) - from pyspark.sql.functions import udf, col + from pyspark.sql.functions import udf, col, sum from pyspark.sql.types import BooleanType my_filter = udf(lambda a: a == 1, BooleanType()) sel = df.select(col("key")).distinct().filter(my_filter(col("key"))) self.assertEqual(sel.collect(), [Row(key=1)]) + my_copy = udf(lambda x: x, IntegerType()) + my_add = udf(lambda a, b: int(a + b), IntegerType()) + my_strlen = udf(lambda x: len(x), IntegerType()) + sel = df.groupBy(my_copy(col("key")).alias("k"))\ + .agg(sum(my_strlen(col("value"))).alias("s"))\ + .select(my_add(col("k"), col("s")).alias("t")) + self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)]) + def test_basic_functions(self): rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) df = self.spark.read.json(rdd) |