From 560489f4e16ff18b5e66e7de1bb84d890369a462 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 11 Mar 2016 11:59:18 +0800 Subject: [SPARK-13732][SPARK-13797][SQL] Remove projectList from Window and Eliminate useless Window #### What changes were proposed in this pull request? `projectList` is redundant: its value is always the same as `child.output`. This PR removes it from the class `Window`, which simplifies the code in Analyzer and Optimizer. This PR is based on the discussion started by cloud-fan in a separate PR: https://github.com/apache/spark/pull/5604#discussion_r55140466 This PR also eliminates no-op `Window` operators, i.e. those whose `windowExpressions` list is empty. cloud-fan yhuai #### How was this patch tested? Existing test cases cover the `projectList` removal; new test cases in `ColumnPruningSuite` cover pruning and elimination of `Window`. Author: gatorsmile Author: xiaoli Author: Xiao Li Closes #11565 from gatorsmile/removeProjListWindow. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 11 +--- .../apache/spark/sql/catalyst/dsl/package.scala | 6 ++ .../spark/sql/catalyst/optimizer/Optimizer.scala | 20 ++++--- .../catalyst/plans/logical/basicOperators.scala | 5 +- .../catalyst/optimizer/ColumnPruningSuite.scala | 68 +++++++++++++++++++++- .../spark/sql/execution/SparkStrategies.scala | 5 +- .../org/apache/spark/sql/execution/Window.scala | 6 +- 7 files changed, 94 insertions(+), 27 deletions(-) (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9ab0a20a52..b654827b8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -421,7 +421,7 @@ class Analyzer( val newOutput = oldVersion.generatorOutput.map(_.newInstance()) (oldVersion, oldVersion.copy(generatorOutput = newOutput)) - case oldVersion @ Window(_, windowExpressions, _, _, child) + case oldVersion @ Window(windowExpressions, _, _, child) if 
AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes) .nonEmpty => (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))) @@ -658,10 +658,6 @@ class Analyzer( case p: Project => val missing = missingAttrs -- p.child.outputSet Project(p.projectList ++ missingAttrs, addMissingAttr(p.child, missing)) - case w: Window => - val missing = missingAttrs -- w.child.outputSet - w.copy(projectList = w.projectList ++ missingAttrs, - child = addMissingAttr(w.child, missing)) case a: Aggregate => // all the missing attributes should be grouping expressions // TODO: push down AggregateExpression @@ -1166,7 +1162,6 @@ class Analyzer( // Set currentChild to the newly created Window operator. currentChild = Window( - currentChild.output, windowExpressions, partitionSpec, orderSpec, @@ -1436,10 +1431,10 @@ object CleanupAliases extends Rule[LogicalPlan] { val cleanedAggs = aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) Aggregate(grouping.map(trimAliases), cleanedAggs, child) - case w @ Window(projectList, windowExprs, partitionSpec, orderSpec, child) => + case w @ Window(windowExprs, partitionSpec, orderSpec, child) => val cleanedWindowExprs = windowExprs.map(e => trimNonTopLevelAliases(e).asInstanceOf[NamedExpression]) - Window(projectList, cleanedWindowExprs, partitionSpec.map(trimAliases), + Window(cleanedWindowExprs, partitionSpec.map(trimAliases), orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child) // Operators that operate on objects should only have expressions from encoders, which should diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 63463265e3..dc5264e266 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -268,6 +268,12 @@ package object dsl { 
Aggregate(groupingExprs, aliasedExprs, logicalPlan) } + def window( + windowExpressions: Seq[NamedExpression], + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder]): LogicalPlan = + Window(windowExpressions, partitionSpec, orderSpec, logicalPlan) + def subquery(alias: Symbol): LogicalPlan = SubqueryAlias(alias.name, logicalPlan) def except(otherPlan: LogicalPlan): LogicalPlan = Except(logicalPlan, otherPlan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 650b4eef6e..85776670e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -315,21 +315,17 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { * - LeftSemiJoin */ object ColumnPruning extends Rule[LogicalPlan] { - def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean = + private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean = output1.size == output2.size && output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2)) def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Prunes the unused columns from project list of Project/Aggregate/Window/Expand + // Prunes the unused columns from project list of Project/Aggregate/Expand case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty => p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains))) case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty => p.copy( child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains))) - case p @ Project(_, w: Window) if (w.outputSet -- p.references).nonEmpty => - p.copy(child = w.copy( - projectList = w.projectList.filter(p.references.contains), - 
windowExpressions = w.windowExpressions.filter(p.references.contains))) case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty => val newOutput = e.output.filter(a.references.contains(_)) val newProjects = e.projections.map { proj => @@ -343,11 +339,9 @@ object ColumnPruning extends Rule[LogicalPlan] { case mp @ MapPartitions(_, _, _, child) if (child.outputSet -- mp.references).nonEmpty => mp.copy(child = prunedChild(child, mp.references)) - // Prunes the unused columns from child of Aggregate/Window/Expand/Generate + // Prunes the unused columns from child of Aggregate/Expand/Generate case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => a.copy(child = prunedChild(child, a.references)) - case w @ Window(_, _, _, _, child) if (child.outputSet -- w.references).nonEmpty => - w.copy(child = prunedChild(child, w.references)) case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty => e.copy(child = prunedChild(child, e.references)) case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty => @@ -381,6 +375,14 @@ object ColumnPruning extends Rule[LogicalPlan] { p } + // Prune unnecessary window expressions + case p @ Project(_, w: Window) if (w.windowOutputSet -- p.references).nonEmpty => + p.copy(child = w.copy( + windowExpressions = w.windowExpressions.filter(p.references.contains))) + + // Eliminate no-op Window + case w: Window if w.windowExpressions.isEmpty => w.child + // Eliminate no-op Projects case p @ Project(projectList, child) if sameOutput(child.output, p.output) => child diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 3bc246a32d..09ea3fea6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -434,14 +434,15 @@ case class Aggregate( } case class Window( - projectList: Seq[Attribute], windowExpressions: Seq[NamedExpression], partitionSpec: Seq[Expression], orderSpec: Seq[SortOrder], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = - projectList ++ windowExpressions.map(_.toAttribute) + child.output ++ windowExpressions.map(_.toAttribute) + + def windowOutputSet: AttributeSet = AttributeSet(windowExpressions.map(_.toAttribute)) } private[sql] object Expand { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 409e92238e..dd7d65ddc9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -23,7 +23,8 @@ import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Ascending, Explode, Literal, SortOrder} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count} import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -33,7 +34,8 @@ class ColumnPruningSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Column pruning", FixedPoint(100), - ColumnPruning) :: Nil + ColumnPruning, + CollapseProject) :: Nil } test("Column pruning for Generate when Generate.join = false") { @@ -258,6 +260,68 
@@ class ColumnPruningSuite extends PlanTest { comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1)) } + test("Column pruning on Window with useless aggregate functions") { + val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int) + + val originalQuery = + input.groupBy('a, 'c, 'd)('a, 'c, 'd, + WindowExpression( + AggregateExpression(Count('b), Complete, isDistinct = false), + WindowSpecDefinition( 'a :: Nil, + SortOrder('b, Ascending) :: Nil, + UnspecifiedFrame)).as('window)).select('a, 'c) + + val correctAnswer = input.select('a, 'c, 'd).groupBy('a, 'c, 'd)('a, 'c).analyze + + val optimized = Optimize.execute(originalQuery.analyze) + + comparePlans(optimized, correctAnswer) + } + + test("Column pruning on Window with selected agg expressions") { + val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int) + + val originalQuery = + input.select('a, 'b, 'c, 'd, + WindowExpression( + AggregateExpression(Count('b), Complete, isDistinct = false), + WindowSpecDefinition( 'a :: Nil, + SortOrder('b, Ascending) :: Nil, + UnspecifiedFrame)).as('window)).where('window > 1).select('a, 'c) + + val correctAnswer = + input.select('a, 'b, 'c) + .window(WindowExpression( + AggregateExpression(Count('b), Complete, isDistinct = false), + WindowSpecDefinition( 'a :: Nil, + SortOrder('b, Ascending) :: Nil, + UnspecifiedFrame)).as('window) :: Nil, + 'a :: Nil, 'b.asc :: Nil) + .select('a, 'c, 'window).where('window > 1).select('a, 'c).analyze + + val optimized = Optimize.execute(originalQuery.analyze) + + comparePlans(optimized, correctAnswer) + } + + test("Column pruning on Window in select") { + val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int) + + val originalQuery = + input.select('a, 'b, 'c, 'd, + WindowExpression( + AggregateExpression(Count('b), Complete, isDistinct = false), + WindowSpecDefinition( 'a :: Nil, + SortOrder('b, Ascending) :: Nil, + UnspecifiedFrame)).as('window)).select('a, 'c) + + val correctAnswer = 
input.select('a, 'c).analyze + + val optimized = Optimize.execute(originalQuery.analyze) + + comparePlans(optimized, correctAnswer) + } + test("Column pruning on Union") { val input1 = LocalRelation('a.int, 'b.string, 'c.double) val input2 = LocalRelation('c.int, 'd.string, 'e.double) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index debd04aa95..bae0750788 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -344,9 +344,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Filter(condition, planLater(child)) :: Nil case e @ logical.Expand(_, _, child) => execution.Expand(e.projections, e.output, planLater(child)) :: Nil - case logical.Window(projectList, windowExprs, partitionSpec, orderSpec, child) => - execution.Window( - projectList, windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil + case logical.Window(windowExprs, partitionSpec, orderSpec, child) => + execution.Window(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil case logical.Sample(lb, ub, withReplacement, seed, child) => execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 84154a47de..a4c0e1c9fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -81,14 +81,14 @@ import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, Unsaf * of specialized classes: [[RowBoundOrdering]] & [[RangeBoundOrdering]]. 
*/ case class Window( - projectList: Seq[Attribute], windowExpression: Seq[NamedExpression], partitionSpec: Seq[Expression], orderSpec: Seq[SortOrder], child: SparkPlan) extends UnaryNode { - override def output: Seq[Attribute] = projectList ++ windowExpression.map(_.toAttribute) + override def output: Seq[Attribute] = + child.output ++ windowExpression.map(_.toAttribute) override def requiredChildDistribution: Seq[Distribution] = { if (partitionSpec.isEmpty) { @@ -275,7 +275,7 @@ case class Window( val unboundToRefMap = expressions.zip(references).toMap val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) UnsafeProjection.create( - projectList ++ patchedWindowExpression, + child.output ++ patchedWindowExpression, child.output) } -- cgit v1.2.3