diff options
Diffstat (limited to 'sql')
4 files changed, 187 insertions, 156 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 1f05f2065c..2b804976f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -313,97 +313,85 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { */ object ColumnPruning extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case a @ Aggregate(_, _, e @ Expand(projects, output, child)) - if (e.outputSet -- a.references).nonEmpty => - val newOutput = output.filter(a.references.contains(_)) - val newProjects = projects.map { proj => - proj.zip(output).filter { case (e, a) => + // Prunes the unused columns from project list of Project/Aggregate/Window/Expand + case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty => + p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains))) + case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty => + p.copy( + child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains))) + case p @ Project(_, w: Window) if (w.outputSet -- p.references).nonEmpty => + p.copy(child = w.copy( + projectList = w.projectList.filter(p.references.contains), + windowExpressions = w.windowExpressions.filter(p.references.contains))) + case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty => + val newOutput = e.output.filter(a.references.contains(_)) + val newProjects = e.projections.map { proj => + proj.zip(e.output).filter { case (e, a) => newOutput.contains(a) }.unzip._1 } - a.copy(child = Expand(newProjects, newOutput, child)) + a.copy(child = Expand(newProjects, newOutput, grandChild)) + // TODO: support some logical plan for Dataset - case a @ Aggregate(_, _, e @ Expand(_, _, child)) - if (child.outputSet -- e.references -- a.references).nonEmpty => - a.copy(child = e.copy(child = prunedChild(child, e.references ++ a.references))) - - // Eliminate attributes that are not needed to calculate the specified aggregates. + // Prunes the unused columns from child of Aggregate/Window/Expand/Generate case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => - a.copy(child = Project(a.references.toSeq, child)) - - // Eliminate attributes that are not needed to calculate the Generate. + a.copy(child = prunedChild(child, a.references)) + case w @ Window(_, _, _, _, child) if (child.outputSet -- w.references).nonEmpty => + w.copy(child = prunedChild(child, w.references)) + case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty => + e.copy(child = prunedChild(child, e.references)) case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty => - g.copy(child = Project(g.references.toSeq, g.child)) + g.copy(child = prunedChild(g.child, g.references)) + // Turn off `join` for Generate if no column from it's child is used case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) => p.copy(child = g.copy(join = false)) - case p @ Project(projectList, g: Generate) if g.join => - val neededChildOutput = p.references -- g.generatorOutput ++ g.references - if (neededChildOutput == g.child.outputSet) { - p + // Eliminate unneeded attributes from right side of a LeftSemiJoin. + case j @ Join(left, right, LeftSemi, condition) => + j.copy(right = prunedChild(right, j.references)) + + // all the columns will be used to compare, so we can't prune them + case p @ Project(_, _: SetOperation) => p + case p @ Project(_, _: Distinct) => p + // Eliminate unneeded attributes from children of Union. + case p @ Project(_, u: Union) => + if ((u.outputSet -- p.references).nonEmpty) { + val firstChild = u.children.head + val newOutput = prunedChild(firstChild, p.references).output + // pruning the columns of all children based on the pruned first child. + val newChildren = u.children.map { p => + val selected = p.output.zipWithIndex.filter { case (a, i) => + newOutput.contains(firstChild.output(i)) + }.map(_._1) + Project(selected, p) + } + p.copy(child = u.withNewChildren(newChildren)) } else { - Project(projectList, g.copy(child = Project(neededChildOutput.toSeq, g.child))) + p } - case p @ Project(projectList, a @ Aggregate(groupingExpressions, aggregateExpressions, child)) - if (a.outputSet -- p.references).nonEmpty => - Project( - projectList, - Aggregate( - groupingExpressions, - aggregateExpressions.filter(e => p.references.contains(e)), - child)) - - // Eliminate unneeded attributes from either side of a Join. - case Project(projectList, Join(left, right, joinType, condition)) => - // Collect the list of all references required either above or to evaluate the condition. - val allReferences: AttributeSet = - AttributeSet( - projectList.flatMap(_.references.iterator)) ++ - condition.map(_.references).getOrElse(AttributeSet(Seq.empty)) - - /** Applies a projection only when the child is producing unnecessary attributes */ - def pruneJoinChild(c: LogicalPlan): LogicalPlan = prunedChild(c, allReferences) + // Can't prune the columns on LeafNode + case p @ Project(_, l: LeafNode) => p - Project(projectList, Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition)) - - // Eliminate unneeded attributes from right side of a LeftSemiJoin. - case Join(left, right, LeftSemi, condition) => - // Collect the list of all references required to evaluate the condition. - val allReferences: AttributeSet = - condition.map(_.references).getOrElse(AttributeSet(Seq.empty)) - - Join(left, prunedChild(right, allReferences), LeftSemi, condition) - - // Push down project through limit, so that we may have chance to push it further. - case Project(projectList, Limit(exp, child)) => - Limit(exp, Project(projectList, child)) - - // Push down project if possible when the child is sort. - case p @ Project(projectList, s @ Sort(_, _, grandChild)) => - if (s.references.subsetOf(p.outputSet)) { - s.copy(child = Project(projectList, grandChild)) + // Eliminate no-op Projects + case p @ Project(projectList, child) if child.output == p.output => child + + // for all other logical plans that inherits the output from it's children + case p @ Project(_, child) => + val required = child.references ++ p.references + if ((child.inputSet -- required).nonEmpty) { + val newChildren = child.children.map(c => prunedChild(c, required)) + p.copy(child = child.withNewChildren(newChildren)) } else { - val neededReferences = s.references ++ p.references - if (neededReferences == grandChild.outputSet) { - // No column we can prune, return the original plan. - p - } else { - // Do not use neededReferences.toSeq directly, should respect grandChild's output order. - val newProjectList = grandChild.output.filter(neededReferences.contains) - p.copy(child = s.copy(child = Project(newProjectList, grandChild))) - } + p } - - // Eliminate no-op Projects - case Project(projectList, child) if child.output == projectList => child } /** Applies a projection only when the child is producing unnecessary attributes */ private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) { - Project(allReferences.filter(c.outputSet.contains).toSeq, c) + Project(c.output.filter(allReferences.contains), c) } else { c } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index c890fffc40..715d01a3cd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Explode, Literal} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Explode, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -119,11 +120,134 @@ class ColumnPruningSuite extends PlanTest { Seq('c, Literal.create(null, StringType), 1), Seq('c, 'a, 2)), Seq('c, 'aa.int, 'gid.int), - Project(Seq('c, 'a), + Project(Seq('a, 'c), input))).analyze comparePlans(optimized, expected) } + test("Column pruning on Filter") { + val input = LocalRelation('a.int, 'b.string, 'c.double) + val query = Project('a :: Nil, Filter('c > Literal(0.0), input)).analyze + val expected = + Project('a :: Nil, + Filter('c > Literal(0.0), + Project(Seq('a, 'c), input))).analyze + comparePlans(Optimize.execute(query), expected) + } + + test("Column pruning on except/intersect/distinct") { + val input = LocalRelation('a.int, 'b.string, 'c.double) + val query = Project('a :: Nil, Except(input, input)).analyze + comparePlans(Optimize.execute(query), query) + + val query2 = Project('a :: Nil, Intersect(input, input)).analyze + comparePlans(Optimize.execute(query2), query2) + val query3 = Project('a :: Nil, Distinct(input)).analyze + comparePlans(Optimize.execute(query3), query3) + } + + test("Column pruning on Project") { + val input = LocalRelation('a.int, 'b.string, 'c.double) + val query = Project('a :: Nil, Project(Seq('a, 'b), input)).analyze + val expected = Project(Seq('a), input).analyze + comparePlans(Optimize.execute(query), expected) + } + + test("column pruning for group") { + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val originalQuery = + testRelation + .groupBy('a)('a, count('b)) + .select('a) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .groupBy('a)('a).analyze + + comparePlans(optimized, correctAnswer) + } + + test("column pruning for group with alias") { + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + val originalQuery = + testRelation + .groupBy('a)('a as 'c, count('b)) + .select('c) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .groupBy('a)('a as 'c).analyze + + comparePlans(optimized, correctAnswer) + } + + test("column pruning for Project(ne, Limit)") { + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + val originalQuery = + testRelation + .select('a, 'b) + .limit(2) + .select('a) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .limit(2).analyze + + comparePlans(optimized, correctAnswer) + } + + test("push down project past sort") { + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val x = testRelation.subquery('x) + + // push down valid + val originalQuery = { + x.select('a, 'b) + .sortBy(SortOrder('a, Ascending)) + .select('a) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + x.select('a) + .sortBy(SortOrder('a, Ascending)).analyze + + comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) + + // push down invalid + val originalQuery1 = { + x.select('a, 'b) + .sortBy(SortOrder('a, Ascending)) + .select('b) + } + + val optimized1 = Optimize.execute(originalQuery1.analyze) + val correctAnswer1 = + x.select('a, 'b) + .sortBy(SortOrder('a, Ascending)) + .select('b).analyze + + comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1)) + } + + test("Column pruning on Union") { + val input1 = LocalRelation('a.int, 'b.string, 'c.double) + val input2 = LocalRelation('c.int, 'd.string, 'e.double) + val query = Project('b :: Nil, + Union(input1 :: input2 :: Nil)).analyze + val expected = Project('b :: Nil, + Union(Project('b :: Nil, input1) :: Project('d :: Nil, input2) :: Nil)).analyze + comparePlans(Optimize.execute(query), expected) + } + // todo: add more tests for column pruning } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 70b34cbb24..7d60862f5a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -41,7 +41,6 @@ class FilterPushdownSuite extends PlanTest { PushPredicateThroughJoin, PushPredicateThroughGenerate, PushPredicateThroughAggregate, - ColumnPruning, CollapseProject) :: Nil } @@ -65,52 +64,6 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("column pruning for group") { - val originalQuery = - testRelation - .groupBy('a)('a, count('b)) - .select('a) - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - testRelation - .select('a) - .groupBy('a)('a).analyze - - comparePlans(optimized, correctAnswer) - } - - test("column pruning for group with alias") { - val originalQuery = - testRelation - .groupBy('a)('a as 'c, count('b)) - .select('c) - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - testRelation - .select('a) - .groupBy('a)('a as 'c).analyze - - comparePlans(optimized, correctAnswer) - } - - test("column pruning for Project(ne, Limit)") { - val originalQuery = - testRelation - .select('a, 'b) - .limit(2) - .select('a) - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - testRelation - .select('a) - .limit(2).analyze - - comparePlans(optimized, correctAnswer) - } - // After this line is unimplemented. test("simple push down") { val originalQuery = @@ -604,39 +557,6 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, originalQuery) } - test("push down project past sort") { - val x = testRelation.subquery('x) - - // push down valid - val originalQuery = { - x.select('a, 'b) - .sortBy(SortOrder('a, Ascending)) - .select('a) - } - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - x.select('a) - .sortBy(SortOrder('a, Ascending)).analyze - - comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) - - // push down invalid - val originalQuery1 = { - x.select('a, 'b) - .sortBy(SortOrder('a, Ascending)) - .select('b) - } - - val optimized1 = Optimize.execute(originalQuery1.analyze) - val correctAnswer1 = - x.select('a, 'b) - .sortBy(SortOrder('a, Ascending)) - .select('b).analyze - - comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1)) - } - test("push project and filter down into sample") { val x = testRelation.subquery('x) val originalQuery = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala index 4858140229..22d4278085 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala @@ -26,7 +26,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{LeafNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -63,7 +64,7 @@ private[sql] case class InMemoryRelation( @transient private[sql] var _cachedColumnBuffers: RDD[CachedBatch] = null, @transient private[sql] var _statistics: Statistics = null, private[sql] var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null) - extends LogicalPlan with MultiInstanceRelation { + extends logical.LeafNode with MultiInstanceRelation { override def producedAttributes: AttributeSet = outputSet @@ -184,8 +185,6 @@ private[sql] case class InMemoryRelation( _cachedColumnBuffers, statisticsToBePropagated, batchStats) } - override def children: Seq[LogicalPlan] = Seq.empty - override def newInstance(): this.type = { new InMemoryRelation( output.map(_.newInstance()), |