From 5389013acc99367729dfc6deeb2cecc9edd1e24c Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 15 Jun 2016 13:38:04 -0700
Subject: [SPARK-15888] [SQL] fix Python UDF with aggregate

## What changes were proposed in this pull request?

After we moved the ExtractPythonUDF rule into the physical plan, Python UDFs could no longer work on top of an aggregate: they cannot be evaluated before the aggregate, so they should be evaluated after it. This PR adds another rule that extracts this kind of Python UDF from the logical Aggregate and creates a Project on top of the Aggregate.

## How was this patch tested?

Added regression tests. The plan of the added test query looks like this:

```
== Parsed Logical Plan ==
'Project [<lambda>('k, 's) AS t#26]
+- Aggregate [<lambda>(key#5L)], [<lambda>(key#5L) AS k#17, sum(cast(<lambda>(value#6) as bigint)) AS s#22L]
   +- LogicalRDD [key#5L, value#6]

== Analyzed Logical Plan ==
t: int
Project [<lambda>(k#17, s#22L) AS t#26]
+- Aggregate [<lambda>(key#5L)], [<lambda>(key#5L) AS k#17, sum(cast(<lambda>(value#6) as bigint)) AS s#22L]
   +- LogicalRDD [key#5L, value#6]

== Optimized Logical Plan ==
Project [<lambda>(agg#29, agg#30L) AS t#26]
+- Aggregate [<lambda>(key#5L)], [<lambda>(key#5L) AS agg#29, sum(cast(<lambda>(value#6) as bigint)) AS agg#30L]
   +- LogicalRDD [key#5L, value#6]

== Physical Plan ==
*Project [pythonUDF0#37 AS t#26]
+- BatchEvalPython [<lambda>(agg#29, agg#30L)], [agg#29, agg#30L, pythonUDF0#37]
   +- *HashAggregate(key=[<lambda>(key#5L)#31], functions=[sum(cast(<lambda>(value#6) as bigint))], output=[agg#29,agg#30L])
      +- Exchange hashpartitioning(<lambda>(key#5L)#31, 200)
         +- *HashAggregate(key=[pythonUDF0#34 AS <lambda>(key#5L)#31], functions=[partial_sum(cast(pythonUDF1#35 as bigint))], output=[<lambda>(key#5L)#31,sum#33L])
            +- BatchEvalPython [<lambda>(key#5L), <lambda>(value#6)], [key#5L, value#6, pythonUDF0#34, pythonUDF1#35]
               +- Scan ExistingRDD[key#5L,value#6]
```

Author: Davies Liu

Closes #13682 from davies/fix_py_udf.
---
 python/pyspark/sql/tests.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

(limited to 'python')

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1d5d691696..c631ad8a46 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -339,13 +339,21 @@ class SQLTests(ReusedPySparkTestCase):
 
     def test_udf_with_aggregate_function(self):
         df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
-        from pyspark.sql.functions import udf, col
+        from pyspark.sql.functions import udf, col, sum
         from pyspark.sql.types import BooleanType
 
         my_filter = udf(lambda a: a == 1, BooleanType())
         sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
         self.assertEqual(sel.collect(), [Row(key=1)])
 
+        my_copy = udf(lambda x: x, IntegerType())
+        my_add = udf(lambda a, b: int(a + b), IntegerType())
+        my_strlen = udf(lambda x: len(x), IntegerType())
+        sel = df.groupBy(my_copy(col("key")).alias("k"))\
+            .agg(sum(my_strlen(col("value"))).alias("s"))\
+            .select(my_add(col("k"), col("s")).alias("t"))
+        self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)
--
cgit v1.2.3
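For readers who want to reproduce the pattern outside the test suite, below is a minimal, self-contained sketch of the query this patch fixes: Python UDFs feeding into an aggregate, and another Python UDF consuming the aggregate's output. The SparkSession setup, the app name, and the sorted assertion are illustrative additions, not part of the patch; the UDFs and data mirror the regression test in the diff above.

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col, sum
from pyspark.sql.types import IntegerType

# Session setup is illustrative; any SparkSession works.
spark = SparkSession.builder.appName("udf-over-aggregate").getOrCreate()

df = spark.createDataFrame(
    [(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])

# UDFs used both below the aggregate (in the grouping key and the
# aggregated column) and above it (combining the aggregated results).
my_copy = udf(lambda x: x, IntegerType())
my_strlen = udf(lambda x: len(x), IntegerType())
my_add = udf(lambda a, b: int(a + b), IntegerType())

sel = (df.groupBy(my_copy(col("key")).alias("k"))
         .agg(sum(my_strlen(col("value"))).alias("s"))
         .select(my_add(col("k"), col("s")).alias("t")))

# key 1 groups values "1", "2", "2" (strlen sum 3, t = 1 + 3 = 4);
# key 2 groups "2" (strlen sum 1, t = 2 + 1 = 3).
assert sorted(r.t for r in sel.collect()) == [3, 4]
```

In the physical plan shown above, this query is what produces the two BatchEvalPython nodes: one below the partial HashAggregate for the UDFs feeding the aggregate, and one above the final HashAggregate, inserted by the new rule, for the UDF over the aggregated columns.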