author     Cheng Lian <lian@databricks.com>  2014-12-30 13:38:27 -0800
committer  Michael Armbrust <michael@databricks.com>  2014-12-30 13:38:27 -0800
commit     61a99f6a11d85e931e7d60f9ab4370b3b40a52ef (patch)
tree       d80690362a7d4b65e85f40da279ee88d092d3094
parent     a75dd83b72586695768c89ed32b240aa8f48f32c (diff)
[SPARK-4937][SQL] Normalizes conjunctions and disjunctions to eliminate common predicates
This PR is a simplified version of several filter optimization rules introduced in #3778 authored by scwf. Newly introduced optimizations include:

1. `a && a` => `a`
2. `a || a` => `a`
3. `(a && b && c && ...) || (a && b && d && ...)` => `a && b && (c || d || ...)`

The 3rd rule is particularly useful for rewriting the following query, which is planned into a cartesian product:

```sql
SELECT * FROM t1, t2
WHERE (t1.key = t2.key AND t1.value > 10)
   OR (t1.key = t2.key AND t2.value < 20)
```

into the following one, which is planned into an equi-join:

```sql
SELECT * FROM t1, t2
WHERE t1.key = t2.key
  AND (t1.value > 10 OR t2.value < 20)
```

The example above is quite artificial, but common predicates are likely to appear in real-life complex queries (like the one mentioned in #3778).

A difference between this PR and #3778 is that these optimizations are not limited to `Filter`, but are generalized to all logical plan nodes. Thanks to scwf for bringing up these optimizations, and to chenghao-intel for the generalization suggestion.

Author: Cheng Lian <lian@databricks.com>

Closes #3784 from liancheng/normalize-filters and squashes the following commits:

caca560 [Cheng Lian] Moves filter normalization into BooleanSimplification rule
4ab3a58 [Cheng Lian] Fixes test failure, adds more tests
5d54349 [Cheng Lian] Fixes typo in comment
2abbf8e [Cheng Lian] Forgot our sacred Apache licence header...
cf95639 [Cheng Lian] Adds an optimization rule for filter normalization
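As a quick sanity check on rule 3 (illustrative only, not part of the patch), a 16-row truth table in plain Scala confirms that the rewrite preserves semantics:

```scala
// Illustrative only: exhaustively verify that
// (a && b && c) || (a && b && d) is equivalent to a && b && (c || d)
// for every boolean assignment.
object Rule3TruthTable extends App {
  val bools = Seq(true, false)
  for (a <- bools; b <- bools; c <- bools; d <- bools) {
    val original  = (a && b && c) || (a && b && d)
    val rewritten = a && b && (c || d)
    assert(original == rewritten, s"mismatch at ($a, $b, $c, $d)")
  }
  println("rewrite is semantics-preserving for all 16 assignments")
}
```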
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala        9
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala          27
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala  72
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala       10
4 files changed, 110 insertions(+), 8 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 94b6fb084d..cb5ff67959 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -17,7 +17,6 @@
 package org.apache.spark.sql.catalyst.expressions
 
-import scala.collection.immutable.HashSet
 import org.apache.spark.sql.catalyst.analysis.UnresolvedException
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.types.BooleanType
@@ -48,6 +47,14 @@ trait PredicateHelper {
     }
   }
 
+  protected def splitDisjunctivePredicates(condition: Expression): Seq[Expression] = {
+    condition match {
+      case Or(cond1, cond2) =>
+        splitDisjunctivePredicates(cond1) ++ splitDisjunctivePredicates(cond2)
+      case other => other :: Nil
+    }
+  }
+
   /**
    * Returns true if `expr` can be evaluated using only the output of `plan`. This method
    * can be used to determine when it is acceptable to move expression evaluation within a query
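The new `splitDisjunctivePredicates` mirrors the existing `splitConjunctivePredicates`: it flattens an arbitrarily nested `Or` tree into a flat sequence of disjuncts. A minimal standalone sketch over a hypothetical toy ADT (not Catalyst's actual `Expression` hierarchy) illustrates the recursion:

```scala
// Toy stand-ins for Catalyst expressions -- illustrative only.
sealed trait Expr
case class Or(left: Expr, right: Expr) extends Expr
case class Pred(name: String) extends Expr

def splitDisjunctivePredicates(condition: Expr): Seq[Expr] = condition match {
  // Recurse into both sides of an Or and concatenate the flattened halves.
  case Or(left, right) =>
    splitDisjunctivePredicates(left) ++ splitDisjunctivePredicates(right)
  // Anything that is not an Or is itself a single disjunct.
  case other => other :: Nil
}

val a = Pred("a"); val b = Pred("b"); val c = Pred("c")
// The nesting shape is irrelevant: both trees flatten to Seq(a, b, c).
assert(splitDisjunctivePredicates(Or(Or(a, b), c)) == Seq(a, b, c))
assert(splitDisjunctivePredicates(Or(a, Or(b, c))) == Seq(a, b, c))
```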
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 0f2eae6400..cd3137980c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -294,11 +294,16 @@ object OptimizeIn extends Rule[LogicalPlan] {
 }
 
 /**
- * Simplifies boolean expressions where the answer can be determined without evaluating both sides.
+ * Simplifies boolean expressions:
+ *
+ * 1. Simplifies expressions whose answer can be determined without evaluating both sides.
+ * 2. Eliminates / extracts common factors.
+ * 3. Removes `Not` operator.
+ *
  * Note that this rule can eliminate expressions that might otherwise have been evaluated and thus
  * is only safe when evaluation of expressions does not result in side effects.
  */
-object BooleanSimplification extends Rule[LogicalPlan] {
+object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case q: LogicalPlan => q transformExpressionsUp {
       case and @ And(left, right) =>
@@ -307,7 +312,9 @@ object BooleanSimplification extends Rule[LogicalPlan] {
           case (l, Literal(true, BooleanType)) => l
           case (Literal(false, BooleanType), _) => Literal(false)
           case (_, Literal(false, BooleanType)) => Literal(false)
-          case (_, _) => and
+          // a && a && a ... => a
+          case _ if splitConjunctivePredicates(and).distinct.size == 1 => left
+          case _ => and
         }
 
       case or @ Or(left, right) =>
@@ -316,7 +323,19 @@ object BooleanSimplification extends Rule[LogicalPlan] {
           case (_, Literal(true, BooleanType)) => Literal(true)
           case (Literal(false, BooleanType), r) => r
           case (l, Literal(false, BooleanType)) => l
-          case (_, _) => or
+          // a || a || a ... => a
+          case _ if splitDisjunctivePredicates(or).distinct.size == 1 => left
+          // (a && b && c && ...) || (a && b && d && ...) => a && b && (c || d || ...)
+          case _ =>
+            val lhsSet = splitConjunctivePredicates(left).toSet
+            val rhsSet = splitConjunctivePredicates(right).toSet
+            val common = lhsSet.intersect(rhsSet)
+
+            (lhsSet.diff(common).reduceOption(And) ++ rhsSet.diff(common).reduceOption(And))
+              .reduceOption(Or)
+              .map(_ :: common.toList)
+              .getOrElse(common.toList)
+              .reduce(And)
         }
 
       case not @ Not(exp) =>
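The heart of the change is the final `case _` in the `Or` branch: both sides are split into conjunct sets, the intersection is hoisted out of the disjunction, and only the per-side leftovers remain under the `Or`. A self-contained sketch of the same logic over a toy ADT (again, not the real Catalyst types) may make the `reduceOption` plumbing easier to follow:

```scala
sealed trait Expr
case class And(left: Expr, right: Expr) extends Expr
case class Or(left: Expr, right: Expr) extends Expr
case class Pred(name: String) extends Expr

// Flatten a nested And tree into its conjuncts (same shape as the Or splitter).
def splitConjuncts(e: Expr): Seq[Expr] = e match {
  case And(l, r) => splitConjuncts(l) ++ splitConjuncts(r)
  case other     => other :: Nil
}

def extractCommonFactors(left: Expr, right: Expr): Expr = {
  val lhsSet = splitConjuncts(left).toSet
  val rhsSet = splitConjuncts(right).toSet
  val common = lhsSet.intersect(rhsSet)

  // Conjuncts unique to each side are re-joined with And (None if a side has
  // none); the surviving leftovers are joined with Or; the common factors are
  // then prepended and the whole list folded back into a single And chain.
  (lhsSet.diff(common).reduceOption(And) ++ rhsSet.diff(common).reduceOption(And))
    .reduceOption(Or)
    .map(_ :: common.toList)
    .getOrElse(common.toList)
    .reduce(And)
}

val a = Pred("a"); val b = Pred("b"); val c = Pred("c"); val d = Pred("d")
// (a && b && c) || (a && b && d) collapses to a tree equivalent to
// a && b && (c || d); operand order may vary because sets are unordered.
println(extractCommonFactors(And(And(a, b), c), And(And(a, b), d)))
```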
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala
new file mode 100644
index 0000000000..906300d833
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
+import org.apache.spark.sql.catalyst.expressions.{And, Expression, Or}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+// For implicit conversions
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+
+class NormalizeFiltersSuite extends PlanTest {
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches = Seq(
+      Batch("AnalysisNodes", Once,
+        EliminateAnalysisOperators),
+      Batch("NormalizeFilters", FixedPoint(100),
+        BooleanSimplification,
+        SimplifyFilters))
+  }
+
+  val relation = LocalRelation('a.int, 'b.int, 'c.string)
+
+  def checkExpression(original: Expression, expected: Expression): Unit = {
+    val actual = Optimize(relation.where(original)).collect { case f: Filter => f.condition }.head
+    val result = (actual, expected) match {
+      case (And(l1, r1), And(l2, r2)) => (l1 == l2 && r1 == r2) || (l1 == r2 && l2 == r1)
+      case (Or (l1, r1), Or (l2, r2)) => (l1 == l2 && r1 == r2) || (l1 == r2 && l2 == r1)
+      case (lhs, rhs) => lhs fastEquals rhs
+    }
+
+    assert(result, s"$actual isn't equivalent to $expected")
+  }
+
+  test("a && a => a") {
+    checkExpression('a === 1 && 'a === 1, 'a === 1)
+    checkExpression('a === 1 && 'a === 1 && 'a === 1, 'a === 1)
+  }
+
+  test("a || a => a") {
+    checkExpression('a === 1 || 'a === 1, 'a === 1)
+    checkExpression('a === 1 || 'a === 1 || 'a === 1, 'a === 1)
+  }
+
+  test("(a && b) || (a && c) => a && (b || c)") {
+    checkExpression(
+      ('a === 1 && 'a < 10) || ('a > 2 && 'a === 1),
+      ('a === 1) && ('a < 10 || 'a > 2))
+
+    checkExpression(
+      ('a < 1 && 'b > 2 && 'c.isNull) || ('a < 1 && 'c === "hello" && 'b > 2),
+      ('c.isNull || 'c === "hello") && 'a < 1 && 'b > 2)
+  }
+}
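One detail worth noting in `checkExpression` above: the top-level `And`/`Or` operands are compared in either order because the rewrite rebuilds conjunctions and disjunctions from unordered `Set`s, so the optimizer may legitimately produce `b && a` where the test wrote `a && b`.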
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
index 82afa31a99..1915c25392 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
@@ -105,7 +105,9 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
     test(query) {
       val schemaRdd = sql(query)
-      assertResult(expectedQueryResult.toArray, "Wrong query result") {
+      val queryExecution = schemaRdd.queryExecution
+
+      assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") {
         schemaRdd.collect().map(_.head).toArray
       }
@@ -113,8 +115,10 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
         case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value)
       }.head
 
-      assert(readBatches === expectedReadBatches, "Wrong number of read batches")
-      assert(readPartitions === expectedReadPartitions, "Wrong number of read partitions")
+      assert(readBatches === expectedReadBatches, s"Wrong number of read batches: $queryExecution")
+      assert(
+        readPartitions === expectedReadPartitions,
+        s"Wrong number of read partitions: $queryExecution")
     }
   }
 }