aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-06-15 13:38:04 -0700
committerDavies Liu <davies.liu@gmail.com>2016-06-15 13:38:04 -0700
commit5389013acc99367729dfc6deeb2cecc9edd1e24c (patch)
tree20dcb3a447fe5c72ba7c1b9edf7aa96d94844948 /python
parent279bd4aa5fddbabdb0383a3f6f0fc8d91780e092 (diff)
downloadspark-5389013acc99367729dfc6deeb2cecc9edd1e24c.tar.gz
spark-5389013acc99367729dfc6deeb2cecc9edd1e24c.tar.bz2
spark-5389013acc99367729dfc6deeb2cecc9edd1e24c.zip
[SPARK-15888] [SQL] fix Python UDF with aggregate
## What changes were proposed in this pull request? After we move the ExtractPythonUDF rule into physical plan, Python UDF can't work on top of aggregate anymore, because they can't be evaluated before aggregate, should be evaluated after aggregate. This PR add another rule to extract these kind of Python UDF from logical aggregate, create a Project on top of Aggregate. ## How was this patch tested? Added regression tests. The plan of added test query looks like this: ``` == Parsed Logical Plan == 'Project [<lambda>('k, 's) AS t#26] +- Aggregate [<lambda>(key#5L)], [<lambda>(key#5L) AS k#17, sum(cast(<lambda>(value#6) as bigint)) AS s#22L] +- LogicalRDD [key#5L, value#6] == Analyzed Logical Plan == t: int Project [<lambda>(k#17, s#22L) AS t#26] +- Aggregate [<lambda>(key#5L)], [<lambda>(key#5L) AS k#17, sum(cast(<lambda>(value#6) as bigint)) AS s#22L] +- LogicalRDD [key#5L, value#6] == Optimized Logical Plan == Project [<lambda>(agg#29, agg#30L) AS t#26] +- Aggregate [<lambda>(key#5L)], [<lambda>(key#5L) AS agg#29, sum(cast(<lambda>(value#6) as bigint)) AS agg#30L] +- LogicalRDD [key#5L, value#6] == Physical Plan == *Project [pythonUDF0#37 AS t#26] +- BatchEvalPython [<lambda>(agg#29, agg#30L)], [agg#29, agg#30L, pythonUDF0#37] +- *HashAggregate(key=[<lambda>(key#5L)#31], functions=[sum(cast(<lambda>(value#6) as bigint))], output=[agg#29,agg#30L]) +- Exchange hashpartitioning(<lambda>(key#5L)#31, 200) +- *HashAggregate(key=[pythonUDF0#34 AS <lambda>(key#5L)#31], functions=[partial_sum(cast(pythonUDF1#35 as bigint))], output=[<lambda>(key#5L)#31,sum#33L]) +- BatchEvalPython [<lambda>(key#5L), <lambda>(value#6)], [key#5L, value#6, pythonUDF0#34, pythonUDF1#35] +- Scan ExistingRDD[key#5L,value#6] ``` Author: Davies Liu <davies@databricks.com> Closes #13682 from davies/fix_py_udf.
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/sql/tests.py10
1 files changed, 9 insertions, 1 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1d5d691696..c631ad8a46 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -339,13 +339,21 @@ class SQLTests(ReusedPySparkTestCase):
def test_udf_with_aggregate_function(self):
df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
- from pyspark.sql.functions import udf, col
+ from pyspark.sql.functions import udf, col, sum
from pyspark.sql.types import BooleanType
my_filter = udf(lambda a: a == 1, BooleanType())
sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
self.assertEqual(sel.collect(), [Row(key=1)])
+ my_copy = udf(lambda x: x, IntegerType())
+ my_add = udf(lambda a, b: int(a + b), IntegerType())
+ my_strlen = udf(lambda x: len(x), IntegerType())
+ sel = df.groupBy(my_copy(col("key")).alias("k"))\
+ .agg(sum(my_strlen(col("value"))).alias("s"))\
+ .select(my_add(col("k"), col("s")).alias("t"))
+ self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
+
def test_basic_functions(self):
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
df = self.spark.read.json(rdd)