about | summary | refs | log | tree | commit | diff
path: root/sql
diff options
context:
space:
mode:
authorWenchen Fan <cloud0fan@outlook.com>2015-07-27 11:23:29 -0700
committerMichael Armbrust <michael@databricks.com>2015-07-27 11:23:29 -0700
commitdd9ae7945ab65d353ed2b113e0c1a00a0533ffd6 (patch)
treea2caa36e9a26bbea5de30580e1285395b67c35aa /sql
parent1f7b3d9dc7c2ed9d31f9083284cf900fd4c21e42 (diff)
downloadspark-dd9ae7945ab65d353ed2b113e0c1a00a0533ffd6.tar.gz
spark-dd9ae7945ab65d353ed2b113e0c1a00a0533ffd6.tar.bz2
spark-dd9ae7945ab65d353ed2b113e0c1a00a0533ffd6.zip
[SPARK-9351] [SQL] remove literals from grouping expressions in Aggregate
Literals in grouping expressions have no effect at all; they only make the grouping key bigger, so we should remove them in the Optimizer. I also make the old and new aggregation code paths consistent about literals in grouping here. In the old aggregation, literals in grouping are actually already removed, but in the new aggregation they are not. So I explicitly make it a rule in the Optimizer. Author: Wenchen Fan <cloud0fan@outlook.com> Closes #7583 from cloud-fan/minor and squashes the following commits: 471adff [Wenchen Fan] add test 0839925 [Wenchen Fan] use transformDown when rewrite final result expressions
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala17
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala4
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala (renamed from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala)19
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala29
4 files changed, 57 insertions, 12 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index b59f800e7c..813c620096 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -36,8 +36,9 @@ object DefaultOptimizer extends Optimizer {
// SubQueries are only needed for analysis and can be removed before execution.
Batch("Remove SubQueries", FixedPoint(100),
EliminateSubQueries) ::
- Batch("Distinct", FixedPoint(100),
- ReplaceDistinctWithAggregate) ::
+ Batch("Aggregate", FixedPoint(100),
+ ReplaceDistinctWithAggregate,
+ RemoveLiteralFromGroupExpressions) ::
Batch("Operator Optimizations", FixedPoint(100),
// Operator push down
SetOperationPushDown,
@@ -799,3 +800,15 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] {
case Distinct(child) => Aggregate(child.output, child.output, child)
}
}
+
+/**
+ * Removes literals from group expressions in [[Aggregate]], as they have no effect on the result
+ * but only make the grouping key bigger.
+ */
+object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case a @ Aggregate(grouping, _, _) =>
+ val newGrouping = grouping.filter(!_.foldable)
+ a.copy(groupingExpressions = newGrouping)
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 1e7b2a536a..b9ca712c1e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -144,14 +144,14 @@ object PartialAggregation {
// time. However some of them might be unnamed so we alias them allowing them to be
// referenced in the second aggregation.
val namedGroupingExpressions: Seq[(Expression, NamedExpression)] =
- groupingExpressions.filter(!_.isInstanceOf[Literal]).map {
+ groupingExpressions.map {
case n: NamedExpression => (n, n)
case other => (other, Alias(other, "PartialGroup")())
}
// Replace aggregations with a new expression that computes the result from the already
// computed partial evaluations and grouping values.
- val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp {
+ val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformDown {
case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) =>
partialEvaluations(new TreeNodeRef(e)).finalEvaluation
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala
index df29a62ff0..2d080b95b1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala
@@ -19,14 +19,17 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
-class ReplaceDistinctWithAggregateSuite extends PlanTest {
+class AggregateOptimizeSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
- val batches = Batch("ProjectCollapsing", Once, ReplaceDistinctWithAggregate) :: Nil
+ val batches = Batch("Aggregate", FixedPoint(100),
+ ReplaceDistinctWithAggregate,
+ RemoveLiteralFromGroupExpressions) :: Nil
}
test("replace distinct with aggregate") {
@@ -39,4 +42,16 @@ class ReplaceDistinctWithAggregateSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+
+ test("remove literals in grouping expression") {
+ val input = LocalRelation('a.int, 'b.int)
+
+ val query =
+ input.groupBy('a, Literal(1), Literal(1) + Literal(2))(sum('b))
+ val optimized = Optimize.execute(query)
+
+ val correctAnswer = input.groupBy('a)(sum('b))
+
+ comparePlans(optimized, correctAnswer)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 8cef0b39f8..358e319476 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -463,12 +463,29 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
}
test("literal in agg grouping expressions") {
- checkAnswer(
- sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"),
- Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
- checkAnswer(
- sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"),
- Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
+ def literalInAggTest(): Unit = {
+ checkAnswer(
+ sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"),
+ Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
+ checkAnswer(
+ sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"),
+ Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
+
+ checkAnswer(
+ sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"),
+ sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
+ checkAnswer(
+ sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"),
+ sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
+ checkAnswer(
+ sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"),
+ sql("SELECT 1, 2, sum(b) FROM testData2"))
+ }
+
+ literalInAggTest()
+ withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") {
+ literalInAggTest()
+ }
}
test("aggregates with nulls") {