about | summary | refs | log | tree | commit | diff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r-- python/pyspark/sql/tests.py 10
1 file changed, 9 insertions, 1 deletion
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1d5d691696..c631ad8a46 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -339,13 +339,21 @@ class SQLTests(ReusedPySparkTestCase):
def test_udf_with_aggregate_function(self):
df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
- from pyspark.sql.functions import udf, col
+ from pyspark.sql.functions import udf, col, sum
from pyspark.sql.types import BooleanType
my_filter = udf(lambda a: a == 1, BooleanType())
sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
self.assertEqual(sel.collect(), [Row(key=1)])
+ my_copy = udf(lambda x: x, IntegerType())
+ my_add = udf(lambda a, b: int(a + b), IntegerType())
+ my_strlen = udf(lambda x: len(x), IntegerType())
+ sel = df.groupBy(my_copy(col("key")).alias("k"))\
+ .agg(sum(my_strlen(col("value"))).alias("s"))\
+ .select(my_add(col("k"), col("s")).alias("t"))
+ self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
+
def test_basic_functions(self):
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
df = self.spark.read.json(rdd)