aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-02-25 00:13:07 -0800
committerDavies Liu <davies.liu@gmail.com>2016-02-25 00:13:07 -0800
commit07f92ef1fa090821bef9c60689bf41909d781ee7 (patch)
tree10d8f563352dffea1a2b3363bb2659174187562a /sql/catalyst
parent264533b553be806b6c45457201952e83c028ec78 (diff)
downloadspark-07f92ef1fa090821bef9c60689bf41909d781ee7.tar.gz
spark-07f92ef1fa090821bef9c60689bf41909d781ee7.tar.bz2
spark-07f92ef1fa090821bef9c60689bf41909d781ee7.zip
[SPARK-13376] [SPARK-13476] [SQL] improve column pruning
## What changes were proposed in this pull request? This PR mostly rewrite the ColumnPruning rule to support most of the SQL logical plans (except those for Dataset). This PR also fix a bug in Generate, it should always output UnsafeRow, added an regression test for that. ## How was this patch tested? This is test by unit tests, also manually test with TPCDS Q78, which could prune all unused columns successfully, improved the performance by 78% (from 22s to 12s). Author: Davies Liu <davies@databricks.com> Closes #11354 from davies/fix_column_pruning.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala128
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala128
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala80
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala2
4 files changed, 185 insertions, 153 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 1f05f2065c..2b804976f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -313,97 +313,85 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
*/
object ColumnPruning extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case a @ Aggregate(_, _, e @ Expand(projects, output, child))
- if (e.outputSet -- a.references).nonEmpty =>
- val newOutput = output.filter(a.references.contains(_))
- val newProjects = projects.map { proj =>
- proj.zip(output).filter { case (e, a) =>
+ // Prunes the unused columns from project list of Project/Aggregate/Window/Expand
+ case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty =>
+ p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains)))
+ case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty =>
+ p.copy(
+ child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains)))
+ case p @ Project(_, w: Window) if (w.outputSet -- p.references).nonEmpty =>
+ p.copy(child = w.copy(
+ projectList = w.projectList.filter(p.references.contains),
+ windowExpressions = w.windowExpressions.filter(p.references.contains)))
+ case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty =>
+ val newOutput = e.output.filter(a.references.contains(_))
+ val newProjects = e.projections.map { proj =>
+ proj.zip(e.output).filter { case (e, a) =>
newOutput.contains(a)
}.unzip._1
}
- a.copy(child = Expand(newProjects, newOutput, child))
+ a.copy(child = Expand(newProjects, newOutput, grandChild))
+ // TODO: support some logical plan for Dataset
- case a @ Aggregate(_, _, e @ Expand(_, _, child))
- if (child.outputSet -- e.references -- a.references).nonEmpty =>
- a.copy(child = e.copy(child = prunedChild(child, e.references ++ a.references)))
-
- // Eliminate attributes that are not needed to calculate the specified aggregates.
+ // Prunes the unused columns from child of Aggregate/Window/Expand/Generate
case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
- a.copy(child = Project(a.references.toSeq, child))
-
- // Eliminate attributes that are not needed to calculate the Generate.
+ a.copy(child = prunedChild(child, a.references))
+ case w @ Window(_, _, _, _, child) if (child.outputSet -- w.references).nonEmpty =>
+ w.copy(child = prunedChild(child, w.references))
+ case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty =>
+ e.copy(child = prunedChild(child, e.references))
case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty =>
- g.copy(child = Project(g.references.toSeq, g.child))
+ g.copy(child = prunedChild(g.child, g.references))
+ // Turn off `join` for Generate if no column from it's child is used
case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) =>
p.copy(child = g.copy(join = false))
- case p @ Project(projectList, g: Generate) if g.join =>
- val neededChildOutput = p.references -- g.generatorOutput ++ g.references
- if (neededChildOutput == g.child.outputSet) {
- p
+ // Eliminate unneeded attributes from right side of a LeftSemiJoin.
+ case j @ Join(left, right, LeftSemi, condition) =>
+ j.copy(right = prunedChild(right, j.references))
+
+ // all the columns will be used to compare, so we can't prune them
+ case p @ Project(_, _: SetOperation) => p
+ case p @ Project(_, _: Distinct) => p
+ // Eliminate unneeded attributes from children of Union.
+ case p @ Project(_, u: Union) =>
+ if ((u.outputSet -- p.references).nonEmpty) {
+ val firstChild = u.children.head
+ val newOutput = prunedChild(firstChild, p.references).output
+ // pruning the columns of all children based on the pruned first child.
+ val newChildren = u.children.map { p =>
+ val selected = p.output.zipWithIndex.filter { case (a, i) =>
+ newOutput.contains(firstChild.output(i))
+ }.map(_._1)
+ Project(selected, p)
+ }
+ p.copy(child = u.withNewChildren(newChildren))
} else {
- Project(projectList, g.copy(child = Project(neededChildOutput.toSeq, g.child)))
+ p
}
- case p @ Project(projectList, a @ Aggregate(groupingExpressions, aggregateExpressions, child))
- if (a.outputSet -- p.references).nonEmpty =>
- Project(
- projectList,
- Aggregate(
- groupingExpressions,
- aggregateExpressions.filter(e => p.references.contains(e)),
- child))
-
- // Eliminate unneeded attributes from either side of a Join.
- case Project(projectList, Join(left, right, joinType, condition)) =>
- // Collect the list of all references required either above or to evaluate the condition.
- val allReferences: AttributeSet =
- AttributeSet(
- projectList.flatMap(_.references.iterator)) ++
- condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
-
- /** Applies a projection only when the child is producing unnecessary attributes */
- def pruneJoinChild(c: LogicalPlan): LogicalPlan = prunedChild(c, allReferences)
+ // Can't prune the columns on LeafNode
+ case p @ Project(_, l: LeafNode) => p
- Project(projectList, Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition))
-
- // Eliminate unneeded attributes from right side of a LeftSemiJoin.
- case Join(left, right, LeftSemi, condition) =>
- // Collect the list of all references required to evaluate the condition.
- val allReferences: AttributeSet =
- condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
-
- Join(left, prunedChild(right, allReferences), LeftSemi, condition)
-
- // Push down project through limit, so that we may have chance to push it further.
- case Project(projectList, Limit(exp, child)) =>
- Limit(exp, Project(projectList, child))
-
- // Push down project if possible when the child is sort.
- case p @ Project(projectList, s @ Sort(_, _, grandChild)) =>
- if (s.references.subsetOf(p.outputSet)) {
- s.copy(child = Project(projectList, grandChild))
+ // Eliminate no-op Projects
+ case p @ Project(projectList, child) if child.output == p.output => child
+
+ // for all other logical plans that inherits the output from it's children
+ case p @ Project(_, child) =>
+ val required = child.references ++ p.references
+ if ((child.inputSet -- required).nonEmpty) {
+ val newChildren = child.children.map(c => prunedChild(c, required))
+ p.copy(child = child.withNewChildren(newChildren))
} else {
- val neededReferences = s.references ++ p.references
- if (neededReferences == grandChild.outputSet) {
- // No column we can prune, return the original plan.
- p
- } else {
- // Do not use neededReferences.toSeq directly, should respect grandChild's output order.
- val newProjectList = grandChild.output.filter(neededReferences.contains)
- p.copy(child = s.copy(child = Project(newProjectList, grandChild)))
- }
+ p
}
-
- // Eliminate no-op Projects
- case Project(projectList, child) if child.output == projectList => child
}
/** Applies a projection only when the child is producing unnecessary attributes */
private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) {
- Project(allReferences.filter(c.outputSet.contains).toSeq, c)
+ Project(c.output.filter(allReferences.contains), c)
} else {
c
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index c890fffc40..715d01a3cd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -17,9 +17,10 @@
package org.apache.spark.sql.catalyst.optimizer
+import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Explode, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Explode, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -119,11 +120,134 @@ class ColumnPruningSuite extends PlanTest {
Seq('c, Literal.create(null, StringType), 1),
Seq('c, 'a, 2)),
Seq('c, 'aa.int, 'gid.int),
- Project(Seq('c, 'a),
+ Project(Seq('a, 'c),
input))).analyze
comparePlans(optimized, expected)
}
+ test("Column pruning on Filter") {
+ val input = LocalRelation('a.int, 'b.string, 'c.double)
+ val query = Project('a :: Nil, Filter('c > Literal(0.0), input)).analyze
+ val expected =
+ Project('a :: Nil,
+ Filter('c > Literal(0.0),
+ Project(Seq('a, 'c), input))).analyze
+ comparePlans(Optimize.execute(query), expected)
+ }
+
+ test("Column pruning on except/intersect/distinct") {
+ val input = LocalRelation('a.int, 'b.string, 'c.double)
+ val query = Project('a :: Nil, Except(input, input)).analyze
+ comparePlans(Optimize.execute(query), query)
+
+ val query2 = Project('a :: Nil, Intersect(input, input)).analyze
+ comparePlans(Optimize.execute(query2), query2)
+ val query3 = Project('a :: Nil, Distinct(input)).analyze
+ comparePlans(Optimize.execute(query3), query3)
+ }
+
+ test("Column pruning on Project") {
+ val input = LocalRelation('a.int, 'b.string, 'c.double)
+ val query = Project('a :: Nil, Project(Seq('a, 'b), input)).analyze
+ val expected = Project(Seq('a), input).analyze
+ comparePlans(Optimize.execute(query), expected)
+ }
+
+ test("column pruning for group") {
+ val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+ val originalQuery =
+ testRelation
+ .groupBy('a)('a, count('b))
+ .select('a)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .select('a)
+ .groupBy('a)('a).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("column pruning for group with alias") {
+ val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+ val originalQuery =
+ testRelation
+ .groupBy('a)('a as 'c, count('b))
+ .select('c)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .select('a)
+ .groupBy('a)('a as 'c).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("column pruning for Project(ne, Limit)") {
+ val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+ val originalQuery =
+ testRelation
+ .select('a, 'b)
+ .limit(2)
+ .select('a)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .select('a)
+ .limit(2).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("push down project past sort") {
+ val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+ val x = testRelation.subquery('x)
+
+ // push down valid
+ val originalQuery = {
+ x.select('a, 'b)
+ .sortBy(SortOrder('a, Ascending))
+ .select('a)
+ }
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer =
+ x.select('a)
+ .sortBy(SortOrder('a, Ascending)).analyze
+
+ comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer))
+
+ // push down invalid
+ val originalQuery1 = {
+ x.select('a, 'b)
+ .sortBy(SortOrder('a, Ascending))
+ .select('b)
+ }
+
+ val optimized1 = Optimize.execute(originalQuery1.analyze)
+ val correctAnswer1 =
+ x.select('a, 'b)
+ .sortBy(SortOrder('a, Ascending))
+ .select('b).analyze
+
+ comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1))
+ }
+
+ test("Column pruning on Union") {
+ val input1 = LocalRelation('a.int, 'b.string, 'c.double)
+ val input2 = LocalRelation('c.int, 'd.string, 'e.double)
+ val query = Project('b :: Nil,
+ Union(input1 :: input2 :: Nil)).analyze
+ val expected = Project('b :: Nil,
+ Union(Project('b :: Nil, input1) :: Project('d :: Nil, input2) :: Nil)).analyze
+ comparePlans(Optimize.execute(query), expected)
+ }
+
// todo: add more tests for column pruning
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 70b34cbb24..7d60862f5a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -41,7 +41,6 @@ class FilterPushdownSuite extends PlanTest {
PushPredicateThroughJoin,
PushPredicateThroughGenerate,
PushPredicateThroughAggregate,
- ColumnPruning,
CollapseProject) :: Nil
}
@@ -65,52 +64,6 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
- test("column pruning for group") {
- val originalQuery =
- testRelation
- .groupBy('a)('a, count('b))
- .select('a)
-
- val optimized = Optimize.execute(originalQuery.analyze)
- val correctAnswer =
- testRelation
- .select('a)
- .groupBy('a)('a).analyze
-
- comparePlans(optimized, correctAnswer)
- }
-
- test("column pruning for group with alias") {
- val originalQuery =
- testRelation
- .groupBy('a)('a as 'c, count('b))
- .select('c)
-
- val optimized = Optimize.execute(originalQuery.analyze)
- val correctAnswer =
- testRelation
- .select('a)
- .groupBy('a)('a as 'c).analyze
-
- comparePlans(optimized, correctAnswer)
- }
-
- test("column pruning for Project(ne, Limit)") {
- val originalQuery =
- testRelation
- .select('a, 'b)
- .limit(2)
- .select('a)
-
- val optimized = Optimize.execute(originalQuery.analyze)
- val correctAnswer =
- testRelation
- .select('a)
- .limit(2).analyze
-
- comparePlans(optimized, correctAnswer)
- }
-
// After this line is unimplemented.
test("simple push down") {
val originalQuery =
@@ -604,39 +557,6 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, originalQuery)
}
- test("push down project past sort") {
- val x = testRelation.subquery('x)
-
- // push down valid
- val originalQuery = {
- x.select('a, 'b)
- .sortBy(SortOrder('a, Ascending))
- .select('a)
- }
-
- val optimized = Optimize.execute(originalQuery.analyze)
- val correctAnswer =
- x.select('a)
- .sortBy(SortOrder('a, Ascending)).analyze
-
- comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer))
-
- // push down invalid
- val originalQuery1 = {
- x.select('a, 'b)
- .sortBy(SortOrder('a, Ascending))
- .select('b)
- }
-
- val optimized1 = Optimize.execute(originalQuery1.analyze)
- val correctAnswer1 =
- x.select('a, 'b)
- .sortBy(SortOrder('a, Ascending))
- .select('b).analyze
-
- comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1))
- }
-
test("push project and filter down into sample") {
val x = testRelation.subquery('x)
val originalQuery =
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
index 1ab53a1257..2f382bbda0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
@@ -108,7 +108,7 @@ class JoinOptimizationSuite extends PlanTest {
Project(Seq($"x.key", $"y.key"),
Join(
Project(Seq($"x.key"), SubqueryAlias("x", input)),
- Project(Seq($"y.key"), BroadcastHint(SubqueryAlias("y", input))),
+ BroadcastHint(Project(Seq($"y.key"), SubqueryAlias("y", input))),
Inner, None)).analyze
comparePlans(optimized, expected)