aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src
diff options
context:
space:
mode:
authorDongjoon Hyun <dongjoon@apache.org>2016-05-02 12:40:21 -0700
committerMichael Armbrust <michael@databricks.com>2016-05-02 12:40:21 -0700
commit6e6320122ea84247c67e2d0fb0e6af54e2c5bb31 (patch)
treee11dffffc8d9762d8369025fb6101d57b58286fc /sql/catalyst/src
parenta35a67a83dbb450d26ce0d142ab106e952670842 (diff)
downloadspark-6e6320122ea84247c67e2d0fb0e6af54e2c5bb31.tar.gz
spark-6e6320122ea84247c67e2d0fb0e6af54e2c5bb31.tar.bz2
spark-6e6320122ea84247c67e2d0fb0e6af54e2c5bb31.zip
[SPARK-14830][SQL] Add RemoveRepetitionFromGroupExpressions optimizer.
## What changes were proposed in this pull request? This PR aims to optimize GroupExpressions by removing repeating expressions. `RemoveRepetitionFromGroupExpressions` is added. **Before** ```scala scala> sql("select a+1 from values 1,2 T(a) group by a+1, 1+a, A+1, 1+A").explain() == Physical Plan == WholeStageCodegen : +- TungstenAggregate(key=[(a#0 + 1)#6,(1 + a#0)#7,(A#0 + 1)#8,(1 + A#0)#9], functions=[], output=[(a + 1)#5]) : +- INPUT +- Exchange hashpartitioning((a#0 + 1)#6, (1 + a#0)#7, (A#0 + 1)#8, (1 + A#0)#9, 200), None +- WholeStageCodegen : +- TungstenAggregate(key=[(a#0 + 1) AS (a#0 + 1)#6,(1 + a#0) AS (1 + a#0)#7,(A#0 + 1) AS (A#0 + 1)#8,(1 + A#0) AS (1 + A#0)#9], functions=[], output=[(a#0 + 1)#6,(1 + a#0)#7,(A#0 + 1)#8,(1 + A#0)#9]) : +- INPUT +- LocalTableScan [a#0], [[1],[2]] ``` **After** ```scala scala> sql("select a+1 from values 1,2 T(a) group by a+1, 1+a, A+1, 1+A").explain() == Physical Plan == WholeStageCodegen : +- TungstenAggregate(key=[(a#0 + 1)#6], functions=[], output=[(a + 1)#5]) : +- INPUT +- Exchange hashpartitioning((a#0 + 1)#6, 200), None +- WholeStageCodegen : +- TungstenAggregate(key=[(a#0 + 1) AS (a#0 + 1)#6], functions=[], output=[(a#0 + 1)#6]) : +- INPUT +- LocalTableScan [a#0], [[1],[2]] ``` ## How was this patch tested? Pass the Jenkins tests (with a new testcase) Author: Dongjoon Hyun <dongjoon@apache.org> Closes #12590 from dongjoon-hyun/SPARK-14830.
Diffstat (limited to 'sql/catalyst/src')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala15
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala20
2 files changed, 33 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 0b70edec8e..a147fff274 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -68,7 +68,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
ReplaceExceptWithAntiJoin,
ReplaceDistinctWithAggregate) ::
Batch("Aggregate", fixedPoint,
- RemoveLiteralFromGroupExpressions) ::
+ RemoveLiteralFromGroupExpressions,
+ RemoveRepetitionFromGroupExpressions) ::
Batch("Operator Optimizations", fixedPoint,
// Operator push down
SetOperationPushDown,
@@ -1440,6 +1441,18 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
}
/**
+ * Removes repetition from group expressions in [[Aggregate]], as they have no effect to the result
+ * but only makes the grouping key bigger.
+ */
+object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case a @ Aggregate(grouping, _, _) =>
+ val newGrouping = ExpressionSet(grouping).toSeq
+ a.copy(groupingExpressions = newGrouping)
+ }
+}
+
+/**
* Computes the current date and time to make sure we return the same result in a single query.
*/
object ComputeCurrentTime extends Rule[LogicalPlan] {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala
index e458eb8a1d..c94dcb3354 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala
@@ -17,6 +17,9 @@
package org.apache.spark.sql.catalyst.optimizer
+import org.apache.spark.sql.catalyst.SimpleCatalystConf
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
+import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
@@ -25,10 +28,14 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
class AggregateOptimizeSuite extends PlanTest {
+ val conf = new SimpleCatalystConf(caseSensitiveAnalysis = false)
+ val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
+ val analyzer = new Analyzer(catalog, conf)
object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("Aggregate", FixedPoint(100),
- RemoveLiteralFromGroupExpressions) :: Nil
+ RemoveLiteralFromGroupExpressions,
+ RemoveRepetitionFromGroupExpressions) :: Nil
}
test("remove literals in grouping expression") {
@@ -42,4 +49,15 @@ class AggregateOptimizeSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+
+ test("remove repetition in grouping expression") {
+ val input = LocalRelation('a.int, 'b.int, 'c.int)
+
+ val query = input.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c))
+ val optimized = Optimize.execute(analyzer.execute(query))
+
+ val correctAnswer = analyzer.execute(input.groupBy('a + 1, 'b + 2)(sum('c)))
+
+ comparePlans(optimized, correctAnswer)
+ }
}