aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala25
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala63
2 files changed, 87 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 3bc48c95c5..fd58b9681e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -50,6 +50,7 @@ object DefaultOptimizer extends Optimizer {
CombineFilters,
PushPredicateThroughProject,
PushPredicateThroughJoin,
+ PushPredicateThroughGenerate,
ColumnPruning) ::
Batch("LocalRelation", FixedPoint(100),
ConvertToLocalRelation) :: Nil
@@ -456,6 +457,30 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] {
}
/**
+ * Push [[Filter]] operators through [[Generate]] operators. Parts of the predicate that reference
+ * attributes generated in [[Generate]] will remain above, and the rest should be pushed beneath.
+ */
+object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper {
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case filter @ Filter(condition,
+ generate @ Generate(generator, join, outer, alias, grandChild)) =>
+ // Predicates that reference attributes produced by the `Generate` operator cannot
+ // be pushed below the operator.
+ val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition {
+ conjunct => conjunct.references subsetOf grandChild.outputSet
+ }
+ if (pushDown.nonEmpty) {
+ val pushDownPredicate = pushDown.reduce(And)
+ val withPushdown = generate.copy(child = Filter(pushDownPredicate, grandChild))
+ stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown)
+ } else {
+ filter
+ }
+ }
+}
+
+/**
* Pushes down [[Filter]] operators where the `condition` can be
* evaluated using only the attributes of the left or right side of a join. Other
* [[Filter]] conditions are moved into the `condition` of the [[Join]].
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index ebb123c1f9..1158b5dfc6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -19,11 +19,13 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
+import org.apache.spark.sql.catalyst.expressions.Explode
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{PlanTest, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.types.IntegerType
class FilterPushdownSuite extends PlanTest {
@@ -34,7 +36,8 @@ class FilterPushdownSuite extends PlanTest {
Batch("Filter Pushdown", Once,
CombineFilters,
PushPredicateThroughProject,
- PushPredicateThroughJoin) :: Nil
+ PushPredicateThroughJoin,
+ PushPredicateThroughGenerate) :: Nil
}
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
@@ -411,4 +414,62 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, analysis.EliminateAnalysisOperators(correctAnswer))
}
+
+ val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
+
+ test("generate: predicate referenced no generated column") {
+ val originalQuery = {
+ testRelationWithArrayType
+ .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr"))
+ .where(('b >= 5) && ('a > 6))
+ }
+ val optimized = Optimize(originalQuery.analyze)
+ val correctAnswer = {
+ testRelationWithArrayType
+ .where(('b >= 5) && ('a > 6))
+ .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr")).analyze
+ }
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("generate: part of conjuncts referenced generated column") {
+ val generator = Explode(Seq("c"), 'c_arr)
+ val originalQuery = {
+ testRelationWithArrayType
+ .generate(generator, true, false, Some("arr"))
+ .where(('b >= 5) && ('c > 6))
+ }
+ val optimized = Optimize(originalQuery.analyze)
+ val referenceResult = {
+ testRelationWithArrayType
+ .where('b >= 5)
+ .generate(generator, true, false, Some("arr"))
+ .where('c > 6).analyze
+ }
+
+ // Newly generated columns get different expression ids each time the plan is analyzed,
+ // so e.g. comparePlans(originalQuery.analyze, originalQuery.analyze) fails.
+ // So we check operators manually here.
+ // Filter("c" > 6)
+ assertResult(classOf[Filter])(optimized.getClass)
+ assertResult(1)(optimized.asInstanceOf[Filter].condition.references.size)
+ assertResult("c"){
+ optimized.asInstanceOf[Filter].condition.references.toSeq(0).name
+ }
+
+ // the rest of the plan beneath the top-level Filter
+ comparePlans(optimized.children(0), referenceResult.children(0))
+ }
+
+ test("generate: all conjuncts referenced generated column") {
+ val originalQuery = {
+ testRelationWithArrayType
+ .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr"))
+ .where(('c > 6) || ('b > 5)).analyze
+ }
+ val optimized = Optimize(originalQuery)
+
+ comparePlans(optimized, originalQuery)
+ }
}