author      jiangxingbo <jiangxb1987@gmail.com>            2016-09-13 17:04:51 +0200
committer   Herman van Hovell <hvanhovell@databricks.com>  2016-09-13 17:04:51 +0200
commit      4ba63b193c1ac292493e06343d9d618c12c5ef3f (patch)
tree        a2de62841e287d2e52681bba9421468a337ca739 /sql
parent      3f6a2bb3f7beac4ce928eb660ee36258b5b9e8c8 (diff)
[SPARK-17142][SQL] Complex query triggers binding error in HashAggregateExec
## What changes were proposed in this pull request?

The `ReorderAssociativeOperator` rule extracts the foldable operands of Add/Multiply arithmetic and replaces them with a single evaluated literal. For example, `(a + 1) + (b + 2)` is optimized to `(a + b + 3)` by this rule. For an aggregate operator, however, the output expressions must be derivable from the grouping expressions, and the current implementation of `ReorderAssociativeOperator` may break this promise. An instance:

```
SELECT ((t1.a + 1) + (t2.a + 2)) AS out_col
FROM testdata2 AS t1
INNER JOIN testdata2 AS t2
ON (t1.a = t2.a)
GROUP BY (t1.a + 1), (t2.a + 2)
```

`((t1.a + 1) + (t2.a + 2))` is optimized to `(t1.a + t2.a + 3)`, which cannot be derived from `ExpressionSet((t1.a + 1), (t2.a + 2))`, so binding fails in `HashAggregateExec`. This patch improves the rule by collecting `Aggregate.groupingExpressions` into an `ExpressionSet` and respecting those expressions during optimization: a sub-expression that appears in the grouping set is treated as atomic and is not flattened.

## How was this patch tested?

Added a new test case in `ReorderAssociativeOperatorSuite`.

Author: jiangxingbo <jiangxb1987@gmail.com>

Closes #14917 from jiangxb1987/rao.
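To make the mechanics concrete, here is a minimal, self-contained sketch of the flatten-and-fold technique and the grouping-set guard, using a toy `Expr` ADT rather than Spark's Catalyst classes (all names below are illustrative):

```scala
// A toy model (not Spark's Catalyst classes) of the flatten-and-fold
// technique, plus the grouping-set guard this patch adds.
sealed trait Expr
case class Attr(name: String) extends Expr
case class Lit(value: Int) extends Expr
case class Add(l: Expr, r: Expr) extends Expr

object ReorderSketch {
  // Flatten a tree of Adds into its operand list, treating any expression
  // found in `groupSet` as atomic -- the essence of this patch.
  def flattenAdd(e: Expr, groupSet: Set[Expr]): Seq[Expr] = e match {
    case add @ Add(l, r) if !groupSet.contains(add) =>
      flattenAdd(l, groupSet) ++ flattenAdd(r, groupSet)
    case other => Seq(other)
  }

  // Fold all literal operands into a single literal and rebuild the chain.
  def reorder(e: Expr, groupSet: Set[Expr]): Expr = {
    val (lits, others) = flattenAdd(e, groupSet).partition(_.isInstanceOf[Lit])
    if (lits.size > 1) {
      val folded = Lit(lits.collect { case Lit(v) => v }.sum)
      (others :+ (folded: Expr)).reduce(Add.apply)
    } else e
  }

  def main(args: Array[String]): Unit = {
    val (a, b) = (Attr("t1.a"), Attr("t2.a"))
    val expr = Add(Add(a, Lit(1)), Add(b, Lit(2)))

    // Without the guard: (t1.a + 1) + (t2.a + 2) folds to t1.a + t2.a + 3,
    // which an Aggregate can no longer derive from its grouping expressions.
    println(reorder(expr, Set.empty))

    // With the guard: both operands are grouping expressions, so the tree
    // is preserved and still binds to the grouping output.
    println(reorder(expr, Set(Add(a, Lit(1)), Add(b, Lit(2)))))
  }
}
```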
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala                      | 31
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala | 16
2 files changed, 39 insertions(+), 8 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 82ab111aa2..b7458910da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -57,20 +57,37 @@ object ConstantFolding extends Rule[LogicalPlan] {
* Reorder associative integral-type operators and fold all constants into one.
*/
object ReorderAssociativeOperator extends Rule[LogicalPlan] {
- private def flattenAdd(e: Expression): Seq[Expression] = e match {
- case Add(l, r) => flattenAdd(l) ++ flattenAdd(r)
+ private def flattenAdd(
+ expression: Expression,
+ groupSet: ExpressionSet): Seq[Expression] = expression match {
+ case expr @ Add(l, r) if !groupSet.contains(expr) =>
+ flattenAdd(l, groupSet) ++ flattenAdd(r, groupSet)
case other => other :: Nil
}
- private def flattenMultiply(e: Expression): Seq[Expression] = e match {
- case Multiply(l, r) => flattenMultiply(l) ++ flattenMultiply(r)
+ private def flattenMultiply(
+ expression: Expression,
+ groupSet: ExpressionSet): Seq[Expression] = expression match {
+ case expr @ Multiply(l, r) if !groupSet.contains(expr) =>
+ flattenMultiply(l, groupSet) ++ flattenMultiply(r, groupSet)
case other => other :: Nil
}
+ private def collectGroupingExpressions(plan: LogicalPlan): ExpressionSet = plan match {
+ case Aggregate(groupingExpressions, aggregateExpressions, child) =>
+ ExpressionSet.apply(groupingExpressions)
+ case _ => ExpressionSet(Seq())
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case q: LogicalPlan => q transformExpressionsDown {
+ case q: LogicalPlan =>
+ // We have to respect the aggregate expressions that exist in the grouping expressions when
+ // the plan is an Aggregate operator; otherwise the optimized expression could not be derived
+ // from the grouping expressions.
+ val groupingExpressionSet = collectGroupingExpressions(q)
+ q transformExpressionsDown {
case a: Add if a.deterministic && a.dataType.isInstanceOf[IntegralType] =>
- val (foldables, others) = flattenAdd(a).partition(_.foldable)
+ val (foldables, others) = flattenAdd(a, groupingExpressionSet).partition(_.foldable)
if (foldables.size > 1) {
val foldableExpr = foldables.reduce((x, y) => Add(x, y))
val c = Literal.create(foldableExpr.eval(EmptyRow), a.dataType)
@@ -79,7 +96,7 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] {
a
}
case m: Multiply if m.deterministic && m.dataType.isInstanceOf[IntegralType] =>
- val (foldables, others) = flattenMultiply(m).partition(_.foldable)
+ val (foldables, others) = flattenMultiply(m, groupingExpressionSet).partition(_.foldable)
if (foldables.size > 1) {
val foldableExpr = foldables.reduce((x, y) => Multiply(x, y))
val c = Literal.create(foldableExpr.eval(EmptyRow), m.dataType)
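Note that the `groupSet.contains(expr)` guard above relies on `ExpressionSet`'s semantic equality: members are compared by their canonicalized form, not by object identity. A hedged sketch of that behavior (the attribute construction is illustrative, not part of this diff):

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
val groupSet = ExpressionSet(Seq(Add(a, Literal(1))))

// Canonicalization reorders commutative operands, so a semantically
// equal expression is found even when written the other way around.
assert(groupSet.contains(Add(a, Literal(1))))
assert(groupSet.contains(Add(Literal(1), a)))
```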
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala
index 05e15e9ec4..a1ab0a8344 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -60,4 +60,18 @@ class ReorderAssociativeOperatorSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+
+ test("nested expression with aggregate operator") {
+ val originalQuery =
+ testRelation.as("t1")
+ .join(testRelation.as("t2"), Inner, Some("t1.a".attr === "t2.a".attr))
+ .groupBy("t1.a".attr + 1, "t2.a".attr + 1)(
+ (("t1.a".attr + 1) + ("t2.a".attr + 1)).as("col"))
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = originalQuery.analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
}
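For context, the `Optimize` executor invoked in the new test is defined near the top of the suite, outside the hunks shown. A plausible definition, following the pattern used by other Catalyst optimizer suites (a hedged sketch, not part of this diff):

```scala
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.RuleExecutor

object Optimize extends RuleExecutor[LogicalPlan] {
  val batches =
    Batch("ReorderAssociativeOperator", Once,
      ReorderAssociativeOperator) :: Nil
}
```

With the patch, `comparePlans(optimized, correctAnswer)` succeeds because the rule leaves `(t1.a + 1) + (t2.a + 1)` untouched: both operands appear in the grouping set. Before the patch, the rule would have rewritten it to `t1.a + t2.a + 2`, which cannot be bound to the aggregate's grouping output.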