about | summary | refs | log | tree | commit | diff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-02-18 13:07:41 -0800
committerJosh Rosen <joshrosen@databricks.com>2016-02-18 13:07:41 -0800
commit26f38bb83c423e512955ca25775914dae7e5bbf0 (patch)
tree3221f51a121ea8ddd50979c68b6b29751702d712 /sql/catalyst
parent78562535feb6e214520b29e0bbdd4b1302f01e93 (diff)
downloadspark-26f38bb83c423e512955ca25775914dae7e5bbf0.tar.gz
spark-26f38bb83c423e512955ca25775914dae7e5bbf0.tar.bz2
spark-26f38bb83c423e512955ca25775914dae7e5bbf0.zip
[SPARK-13351][SQL] fix column pruning on Expand
Currently, the columns in projects of Expand that are not used by Aggregate are not pruned; this PR fixes that. Author: Davies Liu <davies@databricks.com> Closes #11225 from davies/fix_pruning_expand.
Diffstat (limited to 'sql/catalyst')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 10
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala | 33
2 files changed, 41 insertions(+), 2 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 567010f23f..55c168d552 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -300,6 +300,16 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
*/
object ColumnPruning extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case a @ Aggregate(_, _, e @ Expand(projects, output, child))
+ if (e.outputSet -- a.references).nonEmpty =>
+ val newOutput = output.filter(a.references.contains(_))
+ val newProjects = projects.map { proj =>
+ proj.zip(output).filter { case (e, a) =>
+ newOutput.contains(a)
+ }.unzip._1
+ }
+ a.copy(child = Expand(newProjects, newOutput, child))
+
case a @ Aggregate(_, _, e @ Expand(_, _, child))
if (child.outputSet -- e.references -- a.references).nonEmpty =>
a.copy(child = e.copy(child = prunedChild(child, e.references ++ a.references)))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 81f3928035..c890fffc40 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -19,9 +19,9 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.Explode
+import org.apache.spark.sql.catalyst.expressions.{Explode, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.StringType
@@ -96,5 +96,34 @@ class ColumnPruningSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+ test("Column pruning for Expand") {
+ val input = LocalRelation('a.int, 'b.string, 'c.double)
+ val query =
+ Aggregate(
+ Seq('aa, 'gid),
+ Seq(sum('c).as("sum")),
+ Expand(
+ Seq(
+ Seq('a, 'b, 'c, Literal.create(null, StringType), 1),
+ Seq('a, 'b, 'c, 'a, 2)),
+ Seq('a, 'b, 'c, 'aa.int, 'gid.int),
+ input)).analyze
+ val optimized = Optimize.execute(query)
+
+ val expected =
+ Aggregate(
+ Seq('aa, 'gid),
+ Seq(sum('c).as("sum")),
+ Expand(
+ Seq(
+ Seq('c, Literal.create(null, StringType), 1),
+ Seq('c, 'a, 2)),
+ Seq('c, 'aa.int, 'gid.int),
+ Project(Seq('c, 'a),
+ input))).analyze
+
+ comparePlans(optimized, expected)
+ }
+
// todo: add more tests for column pruning
}