-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala                     | 31
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala | 16
2 files changed, 39 insertions(+), 8 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 82ab111aa2..b7458910da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -57,20 +57,37 @@ object ConstantFolding extends Rule[LogicalPlan] {
* Reorder associative integral-type operators and fold all constants into one.
*/
object ReorderAssociativeOperator extends Rule[LogicalPlan] {
- private def flattenAdd(e: Expression): Seq[Expression] = e match {
- case Add(l, r) => flattenAdd(l) ++ flattenAdd(r)
+ private def flattenAdd(
+ expression: Expression,
+ groupSet: ExpressionSet): Seq[Expression] = expression match {
+ case expr @ Add(l, r) if !groupSet.contains(expr) =>
+ flattenAdd(l, groupSet) ++ flattenAdd(r, groupSet)
case other => other :: Nil
}
- private def flattenMultiply(e: Expression): Seq[Expression] = e match {
- case Multiply(l, r) => flattenMultiply(l) ++ flattenMultiply(r)
+ private def flattenMultiply(
+ expression: Expression,
+ groupSet: ExpressionSet): Seq[Expression] = expression match {
+ case expr @ Multiply(l, r) if !groupSet.contains(expr) =>
+ flattenMultiply(l, groupSet) ++ flattenMultiply(r, groupSet)
case other => other :: Nil
}
+ private def collectGroupingExpressions(plan: LogicalPlan): ExpressionSet = plan match {
+ case Aggregate(groupingExpressions, _, _) =>
+ ExpressionSet(groupingExpressions)
+ case _ => ExpressionSet(Seq.empty)
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case q: LogicalPlan => q transformExpressionsDown {
+ case q: LogicalPlan =>
+ // We have to respect aggregate expressions that also exist in the grouping expressions when
+ // the plan is an Aggregate operator, otherwise the optimized expression cannot be derived
+ // from the grouping expressions.
+ val groupingExpressionSet = collectGroupingExpressions(q)
+ q transformExpressionsDown {
case a: Add if a.deterministic && a.dataType.isInstanceOf[IntegralType] =>
- val (foldables, others) = flattenAdd(a).partition(_.foldable)
+ val (foldables, others) = flattenAdd(a, groupingExpressionSet).partition(_.foldable)
if (foldables.size > 1) {
val foldableExpr = foldables.reduce((x, y) => Add(x, y))
val c = Literal.create(foldableExpr.eval(EmptyRow), a.dataType)
@@ -79,7 +96,7 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] {
a
}
case m: Multiply if m.deterministic && m.dataType.isInstanceOf[IntegralType] =>
- val (foldables, others) = flattenMultiply(m).partition(_.foldable)
+ val (foldables, others) = flattenMultiply(m, groupingExpressionSet).partition(_.foldable)
if (foldables.size > 1) {
val foldableExpr = foldables.reduce((x, y) => Multiply(x, y))
val c = Literal.create(foldableExpr.eval(EmptyRow), m.dataType)
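The core of the change above is that both flatten helpers now stop descending as soon as they reach a sub-expression that is itself a grouping expression of the enclosing Aggregate. That keeps an aggregate output such as (t1.a + 1) + (t2.a + 1) from being refolded into a shape like (t1.a + t2.a) + 2, which could no longer be matched against the grouping expressions t1.a + 1 and t2.a + 1. A minimal, self-contained sketch of this "flatten, but stop at protected sub-trees" idea, using a toy Expr ADT instead of Catalyst's Expression and ExpressionSet (all names here are illustrative, not part of the patch):

// Toy stand-ins for Catalyst's Add / Literal / AttributeReference.
sealed trait Expr
case class Lit(v: Int) extends Expr
case class Ref(name: String) extends Expr
case class Plus(l: Expr, r: Expr) extends Expr

// Flatten a chain of Plus nodes into its operands, but never look inside a
// sub-tree that is in `protectedExprs`: it has to survive unchanged so it can
// still be matched against a grouping expression afterwards.
def flattenPlus(e: Expr, protectedExprs: Set[Expr]): Seq[Expr] = e match {
  case p @ Plus(l, r) if !protectedExprs.contains(p) =>
    flattenPlus(l, protectedExprs) ++ flattenPlus(r, protectedExprs)
  case other => other :: Nil
}

val grouping = Set[Expr](Plus(Ref("t1.a"), Lit(1)), Plus(Ref("t2.a"), Lit(1)))
val output   = Plus(Plus(Ref("t1.a"), Lit(1)), Plus(Ref("t2.a"), Lit(1)))

// Without the guard every operand is exposed, so the two literals would get folded together:
flattenPlus(output, Set.empty)  // Seq(Ref(t1.a), Lit(1), Ref(t2.a), Lit(1))
// With the grouping set, the grouped sub-trees stay intact and nothing is refolded:
flattenPlus(output, grouping)   // Seq(Plus(Ref(t1.a),Lit(1)), Plus(Ref(t2.a),Lit(1)))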
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala
index 05e15e9ec4..a1ab0a8344 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -60,4 +60,18 @@ class ReorderAssociativeOperatorSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+
+ test("nested expression with aggregate operator") {
+ val originalQuery =
+ testRelation.as("t1")
+ .join(testRelation.as("t2"), Inner, Some("t1.a".attr === "t2.a".attr))
+ .groupBy("t1.a".attr + 1, "t2.a".attr + 1)(
+ (("t1.a".attr + 1) + ("t2.a".attr + 1)).as("col"))
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = originalQuery.analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
}
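The new test builds the problematic shape directly with the Catalyst test DSL: a self-join grouped by t1.a + 1 and t2.a + 1 whose only output is (t1.a + 1) + (t2.a + 1), and it asserts that the optimized plan equals the analyzed plan, i.e. the rule now leaves that expression alone. For illustration only, the same shape at the SQL level would look roughly like the snippet below; the local SparkSession, the view name t, and the bigint column a are assumptions made for the example, not part of this patch. Per the comment added to the rule, the old behavior would have folded the select-list expression into (t1.a + t2.a) + 2, which cannot be derived from the GROUP BY expressions.

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("reorder-demo").master("local[*]").getOrCreate()
// A tiny table with one integral column, registered as a temporary view.
spark.range(0, 10).toDF("a").createOrReplaceTempView("t")

spark.sql(
  """
    |SELECT (t1.a + 1) + (t2.a + 1) AS col
    |FROM t AS t1 JOIN t AS t2 ON t1.a = t2.a
    |GROUP BY t1.a + 1, t2.a + 1
  """.stripMargin).show()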