 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala             | 31 ++++++++++++++++++++-----------
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala    |  5 +++--
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala  |  3 ++-
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala |  1 +
 4 files changed, 26 insertions(+), 14 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 85776670e5..2de92d06ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -71,6 +71,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
PushPredicateThroughAggregate,
LimitPushDown,
ColumnPruning,
+ EliminateOperators,
// Operator combine
CollapseRepartition,
CollapseProject,
@@ -315,11 +316,7 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
* - LeftSemiJoin
*/
object ColumnPruning extends Rule[LogicalPlan] {
- private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean =
- output1.size == output2.size &&
- output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2))
-
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
// Prunes the unused columns from project list of Project/Aggregate/Expand
case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty =>
p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains)))
@@ -380,12 +377,6 @@ object ColumnPruning extends Rule[LogicalPlan] {
p.copy(child = w.copy(
windowExpressions = w.windowExpressions.filter(p.references.contains)))
- // Eliminate no-op Window
- case w: Window if w.windowExpressions.isEmpty => w.child
-
- // Eliminate no-op Projects
- case p @ Project(projectList, child) if sameOutput(child.output, p.output) => child
-
// Can't prune the columns on LeafNode
case p @ Project(_, l: LeafNode) => p
@@ -410,6 +401,24 @@ object ColumnPruning extends Rule[LogicalPlan] {
}
/**
+ * Eliminate no-op Project and Window.
+ *
+ * Note: this rule should be executed just after ColumnPruning.
+ */
+object EliminateOperators extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+ // Eliminate no-op Projects
+ case p @ Project(projectList, child) if sameOutput(child.output, p.output) => child
+ // Eliminate no-op Window
+ case w: Window if w.windowExpressions.isEmpty => w.child
+ }
+
+ private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean =
+ output1.size == output2.size &&
+ output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2))
+}
+
+/**
 * Combines two adjacent [[Project]] operators into one, performing alias substitution
 * and merging the expressions into a single expression.
*/
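
To make the extraction above concrete, here is a minimal sketch of the new rule's effect, written against Catalyst's test DSL (the LocalRelation, 'a.int, and analyze helpers are borrowed from the test suites below; the snippet is illustrative, not part of the patch):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val r = LocalRelation('a.int, 'b.string)

    // A Project that re-lists exactly the child's output is a no-op:
    // sameOutput(child.output, p.output) holds, so the rule returns the child.
    val noopProject = r.select('a, 'b).analyze
    assert(EliminateOperators(noopProject).sameResult(r))

    // Likewise, a Window whose windowExpressions have all been pruned away
    // (for example by ColumnPruning) is replaced by its child.

ColumnPruning can itself leave such no-op nodes behind, which is why the scaladoc above says the rule should run just after it; the switch from transform to transformUp in ColumnPruning means the pruning cases are now applied bottom-up.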
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index dd7d65ddc9..6187fb9e2f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -35,6 +35,7 @@ class ColumnPruningSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("Column pruning", FixedPoint(100),
ColumnPruning,
+ EliminateOperators,
CollapseProject) :: Nil
}
@@ -327,8 +328,8 @@ class ColumnPruningSuite extends PlanTest {
val input2 = LocalRelation('c.int, 'd.string, 'e.double)
val query = Project('b :: Nil,
Union(input1 :: input2 :: Nil)).analyze
- val expected = Project('b :: Nil,
- Union(Project('b :: Nil, input1) :: Project('d :: Nil, input2) :: Nil)).analyze
+ val expected =
+ Union(Project('b :: Nil, input1) :: Project('d :: Nil, input2) :: Nil).analyze
comparePlans(Optimize.execute(query), expected)
}
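
The expected-plan change in this hunk follows directly from the new rule: once ColumnPruning pushes a projection into each Union branch, the Union already outputs exactly 'b, so the outer Project becomes a no-op and EliminateOperators removes it. A hedged sketch of the rewrite, reusing the test's own input1/input2:

    // After pruning, the Union's output is just 'b:
    val pruned = Union(Project('b :: Nil, input1) :: Project('d :: Nil, input2) :: Nil)
    // The outer Project's output is semantically equal to pruned.output,
    // so the rule rewrites Project('b, pruned) back to pruned itself.
    assert(EliminateOperators(Project('b :: Nil, pruned).analyze).sameResult(pruned.analyze))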
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
index 87ad81db11..e0e9b6d93e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
@@ -28,7 +28,8 @@ class CombiningLimitsSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Filter Pushdown", FixedPoint(100),
- ColumnPruning) ::
+ ColumnPruning,
+ EliminateOperators) ::
Batch("Combine Limit", FixedPoint(10),
CombineLimits) ::
Batch("Constant Folding", FixedPoint(10),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
index e2f8146bee..51468fa5ce 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
@@ -43,6 +43,7 @@ class JoinOptimizationSuite extends PlanTest {
PushPredicateThroughGenerate,
PushPredicateThroughAggregate,
ColumnPruning,
+ EliminateOperators,
CollapseProject) :: Nil
}