aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 9
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala | 37
2 files changed, 44 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 0da081ed1a..1a75fcf354 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -119,6 +119,15 @@ object ColumnPruning extends Rule[LogicalPlan] {
case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
a.copy(child = Project(a.references.toSeq, child))
+ // Drop aggregate expressions whose output the enclosing Project never references;
+ // the Project itself and the grouping expressions are kept unchanged.
+ case p @ Project(projectList, a @ Aggregate(groupingExpressions, aggregateExpressions, child))
+     if (a.outputSet -- p.references).nonEmpty =>
+   val referenced = p.references
+   val neededAggregates = aggregateExpressions.filter(e => referenced.contains(e))
+   val prunedAggregate = Aggregate(groupingExpressions, neededAggregates, child)
+   Project(projectList, prunedAggregate)
+
// Eliminate unneeded attributes from either side of a Join.
case Project(projectList, Join(left, right, joinType, condition)) =>
// Collect the list of all references required either above or to evaluate the condition.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 0b74bacb18..55c6766520 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
-import org.apache.spark.sql.catalyst.expressions.Explode
+import org.apache.spark.sql.catalyst.expressions.{Count, Explode}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{PlanTest, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.rules._
@@ -37,7 +37,8 @@ class FilterPushdownSuite extends PlanTest {
CombineFilters,
PushPredicateThroughProject,
PushPredicateThroughJoin,
- PushPredicateThroughGenerate) :: Nil
+ PushPredicateThroughGenerate,
+ ColumnPruning) :: Nil
}
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
@@ -58,6 +59,38 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+ test("column pruning for group") {
+   // Count('b) is computed by the aggregate but never selected above it.
+   val query = testRelation.groupBy('a)('a, Count('b)).select('a)
+   val actual = Optimize(query.analyze)
+
+   // Expected: the aggregate keeps only 'a, and only 'a is selected from the relation.
+   val expected =
+     testRelation
+       .select('a)
+       .groupBy('a)('a)
+       .select('a)
+       .analyze
+
+   comparePlans(actual, expected)
+ }
+
+ test("column pruning for group with alias") {
+   // The grouping column is exposed as 'c; Count('b) is again never selected above.
+   val query = testRelation.groupBy('a)('a as 'c, Count('b)).select('c)
+   val actual = Optimize(query.analyze)
+
+   // Expected: the aliased output survives, and only 'a is selected from the relation.
+   val expected =
+     testRelation
+       .select('a)
+       .groupBy('a)('a as 'c)
+       .select('c)
+       .analyze
+
+   comparePlans(actual, expected)
+ }
+
// After this line is unimplemented.
test("simple push down") {
val originalQuery =