aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWenchen Fan <cloud0fan@outlook.com>2015-08-19 15:04:56 -0700
committerMichael Armbrust <michael@databricks.com>2015-08-19 15:05:25 -0700
commit5c749c82cb3caa5a41fd3fd49c32ab23c6f738da (patch)
tree80100521c87f2d84a16846a0780ac8c28d33016b
parenta59475f5b4d21c12f863c26797fe5f1cea7a5954 (diff)
downloadspark-5c749c82cb3caa5a41fd3fd49c32ab23c6f738da.tar.gz
spark-5c749c82cb3caa5a41fd3fd49c32ab23c6f738da.tar.bz2
spark-5c749c82cb3caa5a41fd3fd49c32ab23c6f738da.zip
[SPARK-6489] [SQL] add column pruning for Generate
This PR takes over https://github.com/apache/spark/pull/5358 Author: Wenchen Fan <cloud0fan@outlook.com> Closes #8268 from cloud-fan/6489. (cherry picked from commit b0dbaec4f942a47afde3490b9339ad3bd187024d) Signed-off-by: Michael Armbrust <michael@databricks.com>
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala16
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala84
3 files changed, 100 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index d474853355..c0845e1a01 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.expressions
-import scala.collection.Map
-
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 47b06cae15..42457d5318 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -165,6 +165,7 @@ object SetOperationPushDown extends Rule[LogicalPlan] {
*
* - Inserting Projections beneath the following operators:
* - Aggregate
+ * - Generate
* - Project <- Join
* - LeftSemiJoin
*/
@@ -178,6 +179,21 @@ object ColumnPruning extends Rule[LogicalPlan] {
case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
a.copy(child = Project(a.references.toSeq, child))
+ // Eliminate attributes that are not needed to calculate the Generate.
+ case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty =>
+ g.copy(child = Project(g.references.toSeq, g.child))
+
+ case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) =>
+ p.copy(child = g.copy(join = false))
+
+ case p @ Project(projectList, g: Generate) if g.join =>
+ val neededChildOutput = p.references -- g.generatorOutput ++ g.references
+ if (neededChildOutput == g.child.outputSet) {
+ p
+ } else {
+ Project(projectList, g.copy(child = Project(neededChildOutput.toSeq, g.child)))
+ }
+
case p @ Project(projectList, a @ Aggregate(groupingExpressions, aggregateExpressions, child))
if (a.outputSet -- p.references).nonEmpty =>
Project(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
new file mode 100644
index 0000000000..dbebcb8680
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions.Explode
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation, Generate, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.types.StringType
+
+class ColumnPruningSuite extends PlanTest {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches = Batch("Column pruning", FixedPoint(100),
+ ColumnPruning) :: Nil
+ }
+
+ test("Column pruning for Generate when Generate.join = false") {
+ val input = LocalRelation('a.int, 'b.array(StringType))
+
+ val query = Generate(Explode('b), false, false, None, 's.string :: Nil, input).analyze
+ val optimized = Optimize.execute(query)
+
+ val correctAnswer =
+ Generate(Explode('b), false, false, None, 's.string :: Nil,
+ Project('b.attr :: Nil, input)).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("Column pruning for Generate when Generate.join = true") {
+ val input = LocalRelation('a.int, 'b.int, 'c.array(StringType))
+
+ val query =
+ Project(Seq('a, 's),
+ Generate(Explode('c), true, false, None, 's.string :: Nil,
+ input)).analyze
+ val optimized = Optimize.execute(query)
+
+ val correctAnswer =
+ Project(Seq('a, 's),
+ Generate(Explode('c), true, false, None, 's.string :: Nil,
+ Project(Seq('a, 'c),
+ input))).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("Turn Generate.join to false if possible") {
+ val input = LocalRelation('b.array(StringType))
+
+ val query =
+ Project(('s + 1).as("s+1") :: Nil,
+ Generate(Explode('b), true, false, None, 's.string :: Nil,
+ input)).analyze
+ val optimized = Optimize.execute(query)
+
+ val correctAnswer =
+ Project(('s + 1).as("s+1") :: Nil,
+ Generate(Explode('b), false, false, None, 's.string :: Nil,
+ input)).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ // todo: add more tests for column pruning
+}