aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-02-23 18:19:22 -0800
committerDavies Liu <davies.liu@gmail.com>2016-02-23 18:19:22 -0800
commite9533b419e3a87589313350310890ce0caf73dbb (patch)
tree0825206e354d8a196d195c442fc03e031a2275e8 /sql
parent230bbeaa614ed0ee87ecceece42355dd9a4bacb3 (diff)
downloadspark-e9533b419e3a87589313350310890ce0caf73dbb.tar.gz
spark-e9533b419e3a87589313350310890ce0caf73dbb.tar.bz2
spark-e9533b419e3a87589313350310890ce0caf73dbb.zip
[SPARK-13376] [SQL] improve column pruning
## What changes were proposed in this pull request? This PR mostly rewrite the ColumnPruning rule to support most of the SQL logical plans (except those for Dataset). ## How was the this patch tested? This is test by unit tests, also manually test with TPCDS Q78, which could prune all unused columns successfully, improved the performance by 78% (from 22s to 12s). Author: Davies Liu <davies@databricks.com> Closes #11256 from davies/fix_column_pruning.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala128
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala128
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala80
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala7
4 files changed, 187 insertions, 156 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 1f05f2065c..2b804976f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -313,97 +313,85 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
*/
object ColumnPruning extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case a @ Aggregate(_, _, e @ Expand(projects, output, child))
- if (e.outputSet -- a.references).nonEmpty =>
- val newOutput = output.filter(a.references.contains(_))
- val newProjects = projects.map { proj =>
- proj.zip(output).filter { case (e, a) =>
+ // Prunes the unused columns from project list of Project/Aggregate/Window/Expand
+ case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty =>
+ p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains)))
+ case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty =>
+ p.copy(
+ child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains)))
+ case p @ Project(_, w: Window) if (w.outputSet -- p.references).nonEmpty =>
+ p.copy(child = w.copy(
+ projectList = w.projectList.filter(p.references.contains),
+ windowExpressions = w.windowExpressions.filter(p.references.contains)))
+ case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty =>
+ val newOutput = e.output.filter(a.references.contains(_))
+ val newProjects = e.projections.map { proj =>
+ proj.zip(e.output).filter { case (e, a) =>
newOutput.contains(a)
}.unzip._1
}
- a.copy(child = Expand(newProjects, newOutput, child))
+ a.copy(child = Expand(newProjects, newOutput, grandChild))
+ // TODO: support some logical plan for Dataset
- case a @ Aggregate(_, _, e @ Expand(_, _, child))
- if (child.outputSet -- e.references -- a.references).nonEmpty =>
- a.copy(child = e.copy(child = prunedChild(child, e.references ++ a.references)))
-
- // Eliminate attributes that are not needed to calculate the specified aggregates.
+ // Prunes the unused columns from child of Aggregate/Window/Expand/Generate
case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
- a.copy(child = Project(a.references.toSeq, child))
-
- // Eliminate attributes that are not needed to calculate the Generate.
+ a.copy(child = prunedChild(child, a.references))
+ case w @ Window(_, _, _, _, child) if (child.outputSet -- w.references).nonEmpty =>
+ w.copy(child = prunedChild(child, w.references))
+ case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty =>
+ e.copy(child = prunedChild(child, e.references))
case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty =>
- g.copy(child = Project(g.references.toSeq, g.child))
+ g.copy(child = prunedChild(g.child, g.references))
+ // Turn off `join` for Generate if no column from it's child is used
case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) =>
p.copy(child = g.copy(join = false))
- case p @ Project(projectList, g: Generate) if g.join =>
- val neededChildOutput = p.references -- g.generatorOutput ++ g.references
- if (neededChildOutput == g.child.outputSet) {
- p
+ // Eliminate unneeded attributes from right side of a LeftSemiJoin.
+ case j @ Join(left, right, LeftSemi, condition) =>
+ j.copy(right = prunedChild(right, j.references))
+
+ // all the columns will be used to compare, so we can't prune them
+ case p @ Project(_, _: SetOperation) => p
+ case p @ Project(_, _: Distinct) => p
+ // Eliminate unneeded attributes from children of Union.
+ case p @ Project(_, u: Union) =>
+ if ((u.outputSet -- p.references).nonEmpty) {
+ val firstChild = u.children.head
+ val newOutput = prunedChild(firstChild, p.references).output
+ // pruning the columns of all children based on the pruned first child.
+ val newChildren = u.children.map { p =>
+ val selected = p.output.zipWithIndex.filter { case (a, i) =>
+ newOutput.contains(firstChild.output(i))
+ }.map(_._1)
+ Project(selected, p)
+ }
+ p.copy(child = u.withNewChildren(newChildren))
} else {
- Project(projectList, g.copy(child = Project(neededChildOutput.toSeq, g.child)))
+ p
}
- case p @ Project(projectList, a @ Aggregate(groupingExpressions, aggregateExpressions, child))
- if (a.outputSet -- p.references).nonEmpty =>
- Project(
- projectList,
- Aggregate(
- groupingExpressions,
- aggregateExpressions.filter(e => p.references.contains(e)),
- child))
-
- // Eliminate unneeded attributes from either side of a Join.
- case Project(projectList, Join(left, right, joinType, condition)) =>
- // Collect the list of all references required either above or to evaluate the condition.
- val allReferences: AttributeSet =
- AttributeSet(
- projectList.flatMap(_.references.iterator)) ++
- condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
-
- /** Applies a projection only when the child is producing unnecessary attributes */
- def pruneJoinChild(c: LogicalPlan): LogicalPlan = prunedChild(c, allReferences)
+ // Can't prune the columns on LeafNode
+ case p @ Project(_, l: LeafNode) => p
- Project(projectList, Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition))
-
- // Eliminate unneeded attributes from right side of a LeftSemiJoin.
- case Join(left, right, LeftSemi, condition) =>
- // Collect the list of all references required to evaluate the condition.
- val allReferences: AttributeSet =
- condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
-
- Join(left, prunedChild(right, allReferences), LeftSemi, condition)
-
- // Push down project through limit, so that we may have chance to push it further.
- case Project(projectList, Limit(exp, child)) =>
- Limit(exp, Project(projectList, child))
-
- // Push down project if possible when the child is sort.
- case p @ Project(projectList, s @ Sort(_, _, grandChild)) =>
- if (s.references.subsetOf(p.outputSet)) {
- s.copy(child = Project(projectList, grandChild))
+ // Eliminate no-op Projects
+ case p @ Project(projectList, child) if child.output == p.output => child
+
+ // for all other logical plans that inherits the output from it's children
+ case p @ Project(_, child) =>
+ val required = child.references ++ p.references
+ if ((child.inputSet -- required).nonEmpty) {
+ val newChildren = child.children.map(c => prunedChild(c, required))
+ p.copy(child = child.withNewChildren(newChildren))
} else {
- val neededReferences = s.references ++ p.references
- if (neededReferences == grandChild.outputSet) {
- // No column we can prune, return the original plan.
- p
- } else {
- // Do not use neededReferences.toSeq directly, should respect grandChild's output order.
- val newProjectList = grandChild.output.filter(neededReferences.contains)
- p.copy(child = s.copy(child = Project(newProjectList, grandChild)))
- }
+ p
}
-
- // Eliminate no-op Projects
- case Project(projectList, child) if child.output == projectList => child
}
/** Applies a projection only when the child is producing unnecessary attributes */
private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) {
- Project(allReferences.filter(c.outputSet.contains).toSeq, c)
+ Project(c.output.filter(allReferences.contains), c)
} else {
c
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index c890fffc40..715d01a3cd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -17,9 +17,10 @@
package org.apache.spark.sql.catalyst.optimizer
+import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Explode, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Explode, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -119,11 +120,134 @@ class ColumnPruningSuite extends PlanTest {
Seq('c, Literal.create(null, StringType), 1),
Seq('c, 'a, 2)),
Seq('c, 'aa.int, 'gid.int),
- Project(Seq('c, 'a),
+ Project(Seq('a, 'c),
input))).analyze
comparePlans(optimized, expected)
}
+ test("Column pruning on Filter") {
+ val input = LocalRelation('a.int, 'b.string, 'c.double)
+ val query = Project('a :: Nil, Filter('c > Literal(0.0), input)).analyze
+ val expected =
+ Project('a :: Nil,
+ Filter('c > Literal(0.0),
+ Project(Seq('a, 'c), input))).analyze
+ comparePlans(Optimize.execute(query), expected)
+ }
+
+ test("Column pruning on except/intersect/distinct") {
+ val input = LocalRelation('a.int, 'b.string, 'c.double)
+ val query = Project('a :: Nil, Except(input, input)).analyze
+ comparePlans(Optimize.execute(query), query)
+
+ val query2 = Project('a :: Nil, Intersect(input, input)).analyze
+ comparePlans(Optimize.execute(query2), query2)
+ val query3 = Project('a :: Nil, Distinct(input)).analyze
+ comparePlans(Optimize.execute(query3), query3)
+ }
+
+ test("Column pruning on Project") {
+ val input = LocalRelation('a.int, 'b.string, 'c.double)
+ val query = Project('a :: Nil, Project(Seq('a, 'b), input)).analyze
+ val expected = Project(Seq('a), input).analyze
+ comparePlans(Optimize.execute(query), expected)
+ }
+
+ test("column pruning for group") {
+ val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+ val originalQuery =
+ testRelation
+ .groupBy('a)('a, count('b))
+ .select('a)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .select('a)
+ .groupBy('a)('a).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("column pruning for group with alias") {
+ val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+ val originalQuery =
+ testRelation
+ .groupBy('a)('a as 'c, count('b))
+ .select('c)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .select('a)
+ .groupBy('a)('a as 'c).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("column pruning for Project(ne, Limit)") {
+ val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+ val originalQuery =
+ testRelation
+ .select('a, 'b)
+ .limit(2)
+ .select('a)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .select('a)
+ .limit(2).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("push down project past sort") {
+ val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+ val x = testRelation.subquery('x)
+
+ // push down valid
+ val originalQuery = {
+ x.select('a, 'b)
+ .sortBy(SortOrder('a, Ascending))
+ .select('a)
+ }
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer =
+ x.select('a)
+ .sortBy(SortOrder('a, Ascending)).analyze
+
+ comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer))
+
+ // push down invalid
+ val originalQuery1 = {
+ x.select('a, 'b)
+ .sortBy(SortOrder('a, Ascending))
+ .select('b)
+ }
+
+ val optimized1 = Optimize.execute(originalQuery1.analyze)
+ val correctAnswer1 =
+ x.select('a, 'b)
+ .sortBy(SortOrder('a, Ascending))
+ .select('b).analyze
+
+ comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1))
+ }
+
+ test("Column pruning on Union") {
+ val input1 = LocalRelation('a.int, 'b.string, 'c.double)
+ val input2 = LocalRelation('c.int, 'd.string, 'e.double)
+ val query = Project('b :: Nil,
+ Union(input1 :: input2 :: Nil)).analyze
+ val expected = Project('b :: Nil,
+ Union(Project('b :: Nil, input1) :: Project('d :: Nil, input2) :: Nil)).analyze
+ comparePlans(Optimize.execute(query), expected)
+ }
+
// todo: add more tests for column pruning
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 70b34cbb24..7d60862f5a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -41,7 +41,6 @@ class FilterPushdownSuite extends PlanTest {
PushPredicateThroughJoin,
PushPredicateThroughGenerate,
PushPredicateThroughAggregate,
- ColumnPruning,
CollapseProject) :: Nil
}
@@ -65,52 +64,6 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
- test("column pruning for group") {
- val originalQuery =
- testRelation
- .groupBy('a)('a, count('b))
- .select('a)
-
- val optimized = Optimize.execute(originalQuery.analyze)
- val correctAnswer =
- testRelation
- .select('a)
- .groupBy('a)('a).analyze
-
- comparePlans(optimized, correctAnswer)
- }
-
- test("column pruning for group with alias") {
- val originalQuery =
- testRelation
- .groupBy('a)('a as 'c, count('b))
- .select('c)
-
- val optimized = Optimize.execute(originalQuery.analyze)
- val correctAnswer =
- testRelation
- .select('a)
- .groupBy('a)('a as 'c).analyze
-
- comparePlans(optimized, correctAnswer)
- }
-
- test("column pruning for Project(ne, Limit)") {
- val originalQuery =
- testRelation
- .select('a, 'b)
- .limit(2)
- .select('a)
-
- val optimized = Optimize.execute(originalQuery.analyze)
- val correctAnswer =
- testRelation
- .select('a)
- .limit(2).analyze
-
- comparePlans(optimized, correctAnswer)
- }
-
// After this line is unimplemented.
test("simple push down") {
val originalQuery =
@@ -604,39 +557,6 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, originalQuery)
}
- test("push down project past sort") {
- val x = testRelation.subquery('x)
-
- // push down valid
- val originalQuery = {
- x.select('a, 'b)
- .sortBy(SortOrder('a, Ascending))
- .select('a)
- }
-
- val optimized = Optimize.execute(originalQuery.analyze)
- val correctAnswer =
- x.select('a)
- .sortBy(SortOrder('a, Ascending)).analyze
-
- comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer))
-
- // push down invalid
- val originalQuery1 = {
- x.select('a, 'b)
- .sortBy(SortOrder('a, Ascending))
- .select('b)
- }
-
- val optimized1 = Optimize.execute(originalQuery1.analyze)
- val correctAnswer1 =
- x.select('a, 'b)
- .sortBy(SortOrder('a, Ascending))
- .select('b).analyze
-
- comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1))
- }
-
test("push project and filter down into sample") {
val x = testRelation.subquery('x)
val originalQuery =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
index 4858140229..22d4278085 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
@@ -26,7 +26,8 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{LeafNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -63,7 +64,7 @@ private[sql] case class InMemoryRelation(
@transient private[sql] var _cachedColumnBuffers: RDD[CachedBatch] = null,
@transient private[sql] var _statistics: Statistics = null,
private[sql] var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null)
- extends LogicalPlan with MultiInstanceRelation {
+ extends logical.LeafNode with MultiInstanceRelation {
override def producedAttributes: AttributeSet = outputSet
@@ -184,8 +185,6 @@ private[sql] case class InMemoryRelation(
_cachedColumnBuffers, statisticsToBePropagated, batchStats)
}
- override def children: Seq[LogicalPlan] = Seq.empty
-
override def newInstance(): this.type = {
new InMemoryRelation(
output.map(_.newInstance()),