path: root/sql
author     Davies Liu <davies@databricks.com>   2016-06-15 13:38:04 -0700
committer  Davies Liu <davies.liu@gmail.com>    2016-06-15 13:38:04 -0700
commit     5389013acc99367729dfc6deeb2cecc9edd1e24c (patch)
tree       20dcb3a447fe5c72ba7c1b9edf7aa96d94844948 /sql
parent     279bd4aa5fddbabdb0383a3f6f0fc8d91780e092 (diff)
[SPARK-15888] [SQL] fix Python UDF with aggregate
## What changes were proposed in this pull request?

After moving the ExtractPythonUDFs rule into the physical plan, Python UDFs could no longer be used on top of an aggregate, because they cannot be evaluated before the aggregate and must instead be evaluated after it. This PR adds another rule that extracts this kind of Python UDF from the logical Aggregate and creates a Project on top of the Aggregate.

## How was this patch tested?

Added regression tests. The plan of the added test query looks like this:

```
== Parsed Logical Plan ==
'Project [<lambda>('k, 's) AS t#26]
+- Aggregate [<lambda>(key#5L)], [<lambda>(key#5L) AS k#17, sum(cast(<lambda>(value#6) as bigint)) AS s#22L]
   +- LogicalRDD [key#5L, value#6]

== Analyzed Logical Plan ==
t: int
Project [<lambda>(k#17, s#22L) AS t#26]
+- Aggregate [<lambda>(key#5L)], [<lambda>(key#5L) AS k#17, sum(cast(<lambda>(value#6) as bigint)) AS s#22L]
   +- LogicalRDD [key#5L, value#6]

== Optimized Logical Plan ==
Project [<lambda>(agg#29, agg#30L) AS t#26]
+- Aggregate [<lambda>(key#5L)], [<lambda>(key#5L) AS agg#29, sum(cast(<lambda>(value#6) as bigint)) AS agg#30L]
   +- LogicalRDD [key#5L, value#6]

== Physical Plan ==
*Project [pythonUDF0#37 AS t#26]
+- BatchEvalPython [<lambda>(agg#29, agg#30L)], [agg#29, agg#30L, pythonUDF0#37]
   +- *HashAggregate(key=[<lambda>(key#5L)#31], functions=[sum(cast(<lambda>(value#6) as bigint))], output=[agg#29,agg#30L])
      +- Exchange hashpartitioning(<lambda>(key#5L)#31, 200)
         +- *HashAggregate(key=[pythonUDF0#34 AS <lambda>(key#5L)#31], functions=[partial_sum(cast(pythonUDF1#35 as bigint))], output=[<lambda>(key#5L)#31,sum#33L])
            +- BatchEvalPython [<lambda>(key#5L), <lambda>(value#6)], [key#5L, value#6, pythonUDF0#34, pythonUDF1#35]
               +- Scan ExistingRDD[key#5L,value#6]
```

Author: Davies Liu <davies@databricks.com>

Closes #13682 from davies/fix_py_udf.
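For context, the following is a minimal PySpark sketch of the shape of query this fix enables (a Python UDF used as the grouping key, inside the aggregate, and on top of the aggregate result). It is only an illustration, not the actual regression test from this patch; the UDF names, sample data, and SparkSession setup are assumptions.

```python
# Illustrative only: Python UDFs around an aggregate, the case fixed by SPARK-15888.
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.master("local[*]").appName("spark-15888-demo").getOrCreate()

df = spark.createDataFrame([(1, "one"), (1, "two"), (2, "three")], ["key", "value"])

my_copy = F.udf(lambda x: x, IntegerType())             # used as the grouping key
my_strlen = F.udf(lambda s: len(s), IntegerType())      # evaluated before the aggregate
my_add = F.udf(lambda a, b: int(a + b), IntegerType())  # evaluated after the aggregate

result = (df.groupBy(my_copy(F.col("key")).alias("k"))
            .agg(F.sum(my_strlen(F.col("value"))).alias("s"))
            .select(my_add(F.col("k"), F.col("s")).alias("t")))

result.explain(True)   # BatchEvalPython should appear above HashAggregate, as in the plan above
print(result.collect())
```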
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala               6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala   2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala     70
3 files changed, 68 insertions, 10 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 08b2d7fcd4..12a10cba20 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.ExperimentalMethods
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer.Optimizer
+import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate
import org.apache.spark.sql.internal.SQLConf
class SparkOptimizer(
@@ -28,6 +29,7 @@ class SparkOptimizer(
experimentalMethods: ExperimentalMethods)
extends Optimizer(catalog, conf) {
-  override def batches: Seq[Batch] = super.batches :+ Batch(
-    "User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
+  override def batches: Seq[Batch] = super.batches :+
+    Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
+    Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index 061d7c7f79..d9bf4d3ccf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -46,6 +46,8 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
  def children: Seq[SparkPlan] = child :: Nil
+  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
+
  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
    udf.children match {
      case Seq(u: PythonUDF) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index ab192360e1..668470ee6a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -18,12 +18,68 @@
package org.apache.spark.sql.execution.python
import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution.SparkPlan
+
+/**
+ * Extracts all the Python UDFs in a logical Aggregate that depend on an aggregate expression or
+ * grouping key, so that they can be evaluated after the aggregate.
+ */
+private[spark] object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
+
+  /**
+   * Returns whether the expression could only be evaluated within aggregate.
+   */
+  private def belongAggregate(e: Expression, agg: Aggregate): Boolean = {
+    e.isInstanceOf[AggregateExpression] ||
+      agg.groupingExpressions.exists(_.semanticEquals(e))
+  }
+
+  private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = {
+    expr.find {
+      e => e.isInstanceOf[PythonUDF] && e.find(belongAggregate(_, agg)).isDefined
+    }.isDefined
+  }
+
+  private def extract(agg: Aggregate): LogicalPlan = {
+    val projList = new ArrayBuffer[NamedExpression]()
+    val aggExpr = new ArrayBuffer[NamedExpression]()
+    agg.aggregateExpressions.foreach { expr =>
+      if (hasPythonUdfOverAggregate(expr, agg)) {
+        // Python UDF can only be evaluated after aggregate
+        val newE = expr transformDown {
+          case e: Expression if belongAggregate(e, agg) =>
+            val alias = e match {
+              case a: NamedExpression => a
+              case o => Alias(e, "agg")()
+            }
+            aggExpr += alias
+            alias.toAttribute
+        }
+        projList += newE.asInstanceOf[NamedExpression]
+      } else {
+        aggExpr += expr
+        projList += expr.toAttribute
+      }
+    }
+    // The rewritten Aggregate no longer contains Python UDFs over aggregate expressions; they are evaluated in the Project above
+    Project(projList, agg.copy(aggregateExpressions = aggExpr))
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    case agg: Aggregate if agg.aggregateExpressions.exists(hasPythonUdfOverAggregate(_, agg)) =>
+      extract(agg)
+  }
+}
+
+
/**
* Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
* alone in a batch.
@@ -59,10 +115,12 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
}
/**
- * Extract all the PythonUDFs from the current operator.
+ * Extract all the PythonUDFs from the current operator and evaluate them before the operator.
*/
-  def extract(plan: SparkPlan): SparkPlan = {
+  private def extract(plan: SparkPlan): SparkPlan = {
    val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
+      // Ignore Python UDFs that come from the second/third aggregate; they are not used here
+      .filter(udf => udf.references.subsetOf(plan.inputSet))
    if (udfs.isEmpty) {
      // If there aren't any, we are done.
      plan
@@ -89,11 +147,7 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
// Other cases are disallowed as they are ambiguous or would require a cartesian
// product.
      udfs.filterNot(attributeMap.contains).foreach { udf =>
-        if (udf.references.subsetOf(plan.inputSet)) {
-          sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
-        } else {
-          sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.")
-        }
+        sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
      }
val rewritten = plan.transformExpressions {