diff options
author | Wenchen Fan <cloud0fan@outlook.com> | 2015-07-27 11:23:29 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-07-27 11:23:29 -0700 |
commit | dd9ae7945ab65d353ed2b113e0c1a00a0533ffd6 (patch) | |
tree | a2caa36e9a26bbea5de30580e1285395b67c35aa | |
parent | 1f7b3d9dc7c2ed9d31f9083284cf900fd4c21e42 (diff) | |
download | spark-dd9ae7945ab65d353ed2b113e0c1a00a0533ffd6.tar.gz spark-dd9ae7945ab65d353ed2b113e0c1a00a0533ffd6.tar.bz2 spark-dd9ae7945ab65d353ed2b113e0c1a00a0533ffd6.zip |
[SPARK-9351] [SQL] remove literals from grouping expressions in Aggregate
literals in grouping expressions have no effect at all, only make our grouping key bigger, so we should remove them in Optimizer.
I also make old and new aggregation code consistent about literals in grouping here. In old aggregation, actually literals in grouping are already removed but new aggregation is not. So I explicitly make it a rule in Optimizer.
Author: Wenchen Fan <cloud0fan@outlook.com>
Closes #7583 from cloud-fan/minor and squashes the following commits:
471adff [Wenchen Fan] add test
0839925 [Wenchen Fan] use transformDown when rewrite final result expressions
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 17 | ||||
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala | 4 | ||||
-rw-r--r-- | sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala (renamed from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala) | 19 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 29 |
4 files changed, 57 insertions, 12 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b59f800e7c..813c620096 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -36,8 +36,9 @@ object DefaultOptimizer extends Optimizer { // SubQueries are only needed for analysis and can be removed before execution. Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: - Batch("Distinct", FixedPoint(100), - ReplaceDistinctWithAggregate) :: + Batch("Aggregate", FixedPoint(100), + ReplaceDistinctWithAggregate, + RemoveLiteralFromGroupExpressions) :: Batch("Operator Optimizations", FixedPoint(100), // Operator push down SetOperationPushDown, @@ -799,3 +800,15 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] { case Distinct(child) => Aggregate(child.output, child.output, child) } } + +/** + * Removes literals from group expressions in [[Aggregate]], as they have no effect to the result + * but only makes the grouping key bigger. + */ +object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case a @ Aggregate(grouping, _, _) => + val newGrouping = grouping.filter(!_.foldable) + a.copy(groupingExpressions = newGrouping) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 1e7b2a536a..b9ca712c1e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -144,14 +144,14 @@ object PartialAggregation { // time. However some of them might be unnamed so we alias them allowing them to be // referenced in the second aggregation. val namedGroupingExpressions: Seq[(Expression, NamedExpression)] = - groupingExpressions.filter(!_.isInstanceOf[Literal]).map { + groupingExpressions.map { case n: NamedExpression => (n, n) case other => (other, Alias(other, "PartialGroup")()) } // Replace aggregations with a new expression that computes the result from the already // computed partial evaluations and grouping values. - val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { + val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformDown { case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) => partialEvaluations(new TreeNodeRef(e)).finalEvaluation diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index df29a62ff0..2d080b95b1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -19,14 +19,17 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor -class ReplaceDistinctWithAggregateSuite extends PlanTest { +class AggregateOptimizeSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { - val batches = Batch("ProjectCollapsing", Once, ReplaceDistinctWithAggregate) :: Nil + val batches = Batch("Aggregate", FixedPoint(100), + ReplaceDistinctWithAggregate, + RemoveLiteralFromGroupExpressions) :: Nil } test("replace distinct with aggregate") { @@ -39,4 +42,16 @@ class ReplaceDistinctWithAggregateSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("remove literals in grouping expression") { + val input = LocalRelation('a.int, 'b.int) + + val query = + input.groupBy('a, Literal(1), Literal(1) + Literal(2))(sum('b)) + val optimized = Optimize.execute(query) + + val correctAnswer = input.groupBy('a)(sum('b)) + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 8cef0b39f8..358e319476 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -463,12 +463,29 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("literal in agg grouping expressions") { - checkAnswer( - sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), - Seq(Row(1, 2), Row(2, 2), Row(3, 2))) - checkAnswer( - sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), - Seq(Row(1, 2), Row(2, 2), Row(3, 2))) + def literalInAggTest(): Unit = { + checkAnswer( + sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) + checkAnswer( + sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) + + checkAnswer( + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + checkAnswer( + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + checkAnswer( + sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"), + sql("SELECT 1, 2, sum(b) FROM testData2")) + } + + literalInAggTest() + withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { + literalInAggTest() + } } test("aggregates with nulls") { |