aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
author: Reynold Xin <rxin@databricks.com>, 2016-08-27 00:34:35 -0700
committer: Reynold Xin <rxin@databricks.com>, 2016-08-27 00:34:35 -0700
commit: 5aad4509c15e131948d387157ecf56af1a705e19 (patch)
tree: ebb800854b2d6287fdba6a1362e4224ddb4d7ee3 /sql
parent: 0243b328736f83faea5f83d18c4d331890ed8e81 (diff)
download: spark-5aad4509c15e131948d387157ecf56af1a705e19.tar.gz
spark-5aad4509c15e131948d387157ecf56af1a705e19.tar.bz2
spark-5aad4509c15e131948d387157ecf56af1a705e19.zip
[SPARK-17273][SQL] Move expression optimizer rules into a separate file
## What changes were proposed in this pull request?

As part of breaking Optimizer.scala apart, this patch moves various expression optimization rules into a single file.

## How was this patch tested?

This should be covered by existing tests.

Author: Reynold Xin <rxin@databricks.com>

Closes #14845 from rxin/SPARK-17273.
Diffstat (limited to 'sql')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 461
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala | 506
2 files changed, 507 insertions, 460 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 8a50368980..17cab18ff8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -534,176 +534,6 @@ object CollapseRepartition extends Rule[LogicalPlan] {
}
/**
- * Simplifies LIKE expressions that do not need full regular expressions to evaluate the condition.
- * For example, when the expression is just checking to see if a string starts with a given
- * pattern.
- */
-object LikeSimplification extends Rule[LogicalPlan] {
- // if guards below protect from escapes on trailing %.
- // Cases like "something\%" are not optimized, but this does not affect correctness.
- private val startsWith = "([^_%]+)%".r
- private val endsWith = "%([^_%]+)".r
- private val startsAndEndsWith = "([^_%]+)%([^_%]+)".r
- private val contains = "%([^_%]+)%".r
- private val equalTo = "([^_%]*)".r
-
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case Like(input, Literal(pattern, StringType)) =>
- pattern.toString match {
- case startsWith(prefix) if !prefix.endsWith("\\") =>
- StartsWith(input, Literal(prefix))
- case endsWith(postfix) =>
- EndsWith(input, Literal(postfix))
- // 'a%a' pattern is basically same with 'a%' && '%a'.
- // However, the additional `Length` condition is required to prevent 'a' match 'a%a'.
- case startsAndEndsWith(prefix, postfix) if !prefix.endsWith("\\") =>
- And(GreaterThanOrEqual(Length(input), Literal(prefix.size + postfix.size)),
- And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix))))
- case contains(infix) if !infix.endsWith("\\") =>
- Contains(input, Literal(infix))
- case equalTo(str) =>
- EqualTo(input, Literal(str))
- case _ =>
- Like(input, Literal.create(pattern, StringType))
- }
- }
-}
-
-/**
- * Replaces [[Expression Expressions]] that can be statically evaluated with
- * equivalent [[Literal]] values. This rule is more specific with
- * Null value propagation from bottom to top of the expression tree.
- */
-object NullPropagation extends Rule[LogicalPlan] {
- private def nonNullLiteral(e: Expression): Boolean = e match {
- case Literal(null, _) => false
- case _ => true
- }
-
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case q: LogicalPlan => q transformExpressionsUp {
- case e @ WindowExpression(Cast(Literal(0L, _), _), _) =>
- Cast(Literal(0L), e.dataType)
- case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) =>
- Cast(Literal(0L), e.dataType)
- case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
- case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
- case e @ GetArrayItem(Literal(null, _), _) => Literal.create(null, e.dataType)
- case e @ GetArrayItem(_, Literal(null, _)) => Literal.create(null, e.dataType)
- case e @ GetMapValue(Literal(null, _), _) => Literal.create(null, e.dataType)
- case e @ GetMapValue(_, Literal(null, _)) => Literal.create(null, e.dataType)
- case e @ GetStructField(Literal(null, _), _, _) => Literal.create(null, e.dataType)
- case e @ GetArrayStructFields(Literal(null, _), _, _, _, _) =>
- Literal.create(null, e.dataType)
- case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
- case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
- case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) =>
- // This rule should be only triggered when isDistinct field is false.
- ae.copy(aggregateFunction = Count(Literal(1)))
-
- // For Coalesce, remove null literals.
- case e @ Coalesce(children) =>
- val newChildren = children.filter(nonNullLiteral)
- if (newChildren.isEmpty) {
- Literal.create(null, e.dataType)
- } else if (newChildren.length == 1) {
- newChildren.head
- } else {
- Coalesce(newChildren)
- }
-
- case e @ Substring(Literal(null, _), _, _) => Literal.create(null, e.dataType)
- case e @ Substring(_, Literal(null, _), _) => Literal.create(null, e.dataType)
- case e @ Substring(_, _, Literal(null, _)) => Literal.create(null, e.dataType)
-
- // Put exceptional cases above if any
- case e @ BinaryArithmetic(Literal(null, _), _) => Literal.create(null, e.dataType)
- case e @ BinaryArithmetic(_, Literal(null, _)) => Literal.create(null, e.dataType)
-
- case e @ BinaryComparison(Literal(null, _), _) => Literal.create(null, e.dataType)
- case e @ BinaryComparison(_, Literal(null, _)) => Literal.create(null, e.dataType)
-
- case e: StringRegexExpression => e.children match {
- case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType)
- case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType)
- case _ => e
- }
-
- case e: StringPredicate => e.children match {
- case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType)
- case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType)
- case _ => e
- }
-
- // If the value expression is NULL then transform the In expression to
- // Literal(null)
- case In(Literal(null, _), list) => Literal.create(null, BooleanType)
-
- }
- }
-}
-
-/**
- * Propagate foldable expressions:
- * Replace attributes with aliases of the original foldable expressions if possible.
- * Other optimizations will take advantage of the propagated foldable expressions.
- *
- * {{{
- * SELECT 1.0 x, 'abc' y, Now() z ORDER BY x, y, 3
- * ==> SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now()
- * }}}
- */
-object FoldablePropagation extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = {
- val foldableMap = AttributeMap(plan.flatMap {
- case Project(projectList, _) => projectList.collect {
- case a: Alias if a.child.foldable => (a.toAttribute, a)
- }
- case _ => Nil
- })
-
- if (foldableMap.isEmpty) {
- plan
- } else {
- var stop = false
- CleanupAliases(plan.transformUp {
- case u: Union =>
- stop = true
- u
- case c: Command =>
- stop = true
- c
- // For outer join, although its output attributes are derived from its children, they are
- // actually different attributes: the output of outer join is not always picked from its
- // children, but can also be null.
- // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes
- // of outer join.
- case j @ Join(_, _, LeftOuter | RightOuter | FullOuter, _) =>
- stop = true
- j
-
- // These 3 operators take attributes as constructor parameters, and these attributes
- // can't be replaced by alias.
- case m: MapGroups =>
- stop = true
- m
- case f: FlatMapGroupsInR =>
- stop = true
- f
- case c: CoGroup =>
- stop = true
- c
-
- case p: LogicalPlan if !stop => p.transformExpressions {
- case a: AttributeReference if foldableMap.contains(a) =>
- foldableMap(a)
- }
- })
- }
- }
-}
-
-/**
* Generate a list of additional filters from an operator's existing constraint but remove those
* that are either already part of the operator's condition or are part of the operator's child
* constraints. These filters are currently inserted to the existing conditions in the Filter
@@ -743,261 +573,6 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe
}
/**
- * Reorder associative integral-type operators and fold all constants into one.
- */
-object ReorderAssociativeOperator extends Rule[LogicalPlan] {
- private def flattenAdd(e: Expression): Seq[Expression] = e match {
- case Add(l, r) => flattenAdd(l) ++ flattenAdd(r)
- case other => other :: Nil
- }
-
- private def flattenMultiply(e: Expression): Seq[Expression] = e match {
- case Multiply(l, r) => flattenMultiply(l) ++ flattenMultiply(r)
- case other => other :: Nil
- }
-
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case q: LogicalPlan => q transformExpressionsDown {
- case a: Add if a.deterministic && a.dataType.isInstanceOf[IntegralType] =>
- val (foldables, others) = flattenAdd(a).partition(_.foldable)
- if (foldables.size > 1) {
- val foldableExpr = foldables.reduce((x, y) => Add(x, y))
- val c = Literal.create(foldableExpr.eval(EmptyRow), a.dataType)
- if (others.isEmpty) c else Add(others.reduce((x, y) => Add(x, y)), c)
- } else {
- a
- }
- case m: Multiply if m.deterministic && m.dataType.isInstanceOf[IntegralType] =>
- val (foldables, others) = flattenMultiply(m).partition(_.foldable)
- if (foldables.size > 1) {
- val foldableExpr = foldables.reduce((x, y) => Multiply(x, y))
- val c = Literal.create(foldableExpr.eval(EmptyRow), m.dataType)
- if (others.isEmpty) c else Multiply(others.reduce((x, y) => Multiply(x, y)), c)
- } else {
- m
- }
- }
- }
-}
-
-/**
- * Replaces [[Expression Expressions]] that can be statically evaluated with
- * equivalent [[Literal]] values.
- */
-object ConstantFolding extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case q: LogicalPlan => q transformExpressionsDown {
- // Skip redundant folding of literals. This rule is technically not necessary. Placing this
- // here avoids running the next rule for Literal values, which would create a new Literal
- // object and running eval unnecessarily.
- case l: Literal => l
-
- // Fold expressions that are foldable.
- case e if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType)
- }
- }
-}
-
-/**
- * Optimize IN predicates:
- * 1. Removes literal repetitions.
- * 2. Replaces [[In (value, seq[Literal])]] with optimized version
- * [[InSet (value, HashSet[Literal])]] which is much faster.
- */
-case class OptimizeIn(conf: CatalystConf) extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case q: LogicalPlan => q transformExpressionsDown {
- case expr @ In(v, list) if expr.inSetConvertible =>
- val newList = ExpressionSet(list).toSeq
- if (newList.size > conf.optimizerInSetConversionThreshold) {
- val hSet = newList.map(e => e.eval(EmptyRow))
- InSet(v, HashSet() ++ hSet)
- } else if (newList.size < list.size) {
- expr.copy(list = newList)
- } else { // newList.length == list.length
- expr
- }
- }
- }
-}
-
-/**
- * Simplifies boolean expressions:
- * 1. Simplifies expressions whose answer can be determined without evaluating both sides.
- * 2. Eliminates / extracts common factors.
- * 3. Merge same expressions
- * 4. Removes `Not` operator.
- */
-object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case q: LogicalPlan => q transformExpressionsUp {
- case TrueLiteral And e => e
- case e And TrueLiteral => e
- case FalseLiteral Or e => e
- case e Or FalseLiteral => e
-
- case FalseLiteral And _ => FalseLiteral
- case _ And FalseLiteral => FalseLiteral
- case TrueLiteral Or _ => TrueLiteral
- case _ Or TrueLiteral => TrueLiteral
-
- case a And b if a.semanticEquals(b) => a
- case a Or b if a.semanticEquals(b) => a
-
- case a And (b Or c) if Not(a).semanticEquals(b) => And(a, c)
- case a And (b Or c) if Not(a).semanticEquals(c) => And(a, b)
- case (a Or b) And c if a.semanticEquals(Not(c)) => And(b, c)
- case (a Or b) And c if b.semanticEquals(Not(c)) => And(a, c)
-
- case a Or (b And c) if Not(a).semanticEquals(b) => Or(a, c)
- case a Or (b And c) if Not(a).semanticEquals(c) => Or(a, b)
- case (a And b) Or c if a.semanticEquals(Not(c)) => Or(b, c)
- case (a And b) Or c if b.semanticEquals(Not(c)) => Or(a, c)
-
- // Common factor elimination for conjunction
- case and @ (left And right) =>
- // 1. Split left and right to get the disjunctive predicates,
- // i.e. lhs = (a, b), rhs = (a, c)
- // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
- // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
- // 4. Apply the formula, get the optimized predicate: common || (ldiff && rdiff)
- val lhs = splitDisjunctivePredicates(left)
- val rhs = splitDisjunctivePredicates(right)
- val common = lhs.filter(e => rhs.exists(e.semanticEquals))
- if (common.isEmpty) {
- // No common factors, return the original predicate
- and
- } else {
- val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
- val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
- if (ldiff.isEmpty || rdiff.isEmpty) {
- // (a || b || c || ...) && (a || b) => (a || b)
- common.reduce(Or)
- } else {
- // (a || b || c || ...) && (a || b || d || ...) =>
- // ((c || ...) && (d || ...)) || a || b
- (common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or)
- }
- }
-
- // Common factor elimination for disjunction
- case or @ (left Or right) =>
- // 1. Split left and right to get the conjunctive predicates,
- // i.e. lhs = (a, b), rhs = (a, c)
- // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
- // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
- // 4. Apply the formula, get the optimized predicate: common && (ldiff || rdiff)
- val lhs = splitConjunctivePredicates(left)
- val rhs = splitConjunctivePredicates(right)
- val common = lhs.filter(e => rhs.exists(e.semanticEquals))
- if (common.isEmpty) {
- // No common factors, return the original predicate
- or
- } else {
- val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
- val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
- if (ldiff.isEmpty || rdiff.isEmpty) {
- // (a && b) || (a && b && c && ...) => a && b
- common.reduce(And)
- } else {
- // (a && b && c && ...) || (a && b && d && ...) =>
- // ((c && ...) || (d && ...)) && a && b
- (common :+ Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And)
- }
- }
-
- case Not(TrueLiteral) => FalseLiteral
- case Not(FalseLiteral) => TrueLiteral
-
- case Not(a GreaterThan b) => LessThanOrEqual(a, b)
- case Not(a GreaterThanOrEqual b) => LessThan(a, b)
-
- case Not(a LessThan b) => GreaterThanOrEqual(a, b)
- case Not(a LessThanOrEqual b) => GreaterThan(a, b)
-
- case Not(a Or b) => And(Not(a), Not(b))
- case Not(a And b) => Or(Not(a), Not(b))
-
- case Not(Not(e)) => e
- }
- }
-}
-
-/**
- * Simplifies binary comparisons with semantically-equal expressions:
- * 1) Replace '<=>' with 'true' literal.
- * 2) Replace '=', '<=', and '>=' with 'true' literal if both operands are non-nullable.
- * 3) Replace '<' and '>' with 'false' literal if both operands are non-nullable.
- */
-object SimplifyBinaryComparison extends Rule[LogicalPlan] with PredicateHelper {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case q: LogicalPlan => q transformExpressionsUp {
- // True with equality
- case a EqualNullSafe b if a.semanticEquals(b) => TrueLiteral
- case a EqualTo b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral
- case a GreaterThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) =>
- TrueLiteral
- case a LessThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral
-
- // False with inequality
- case a GreaterThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral
- case a LessThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral
- }
- }
-}
-
-/**
- * Simplifies conditional expressions (if / case).
- */
-object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
- private def falseOrNullLiteral(e: Expression): Boolean = e match {
- case FalseLiteral => true
- case Literal(null, _) => true
- case _ => false
- }
-
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case q: LogicalPlan => q transformExpressionsUp {
- case If(TrueLiteral, trueValue, _) => trueValue
- case If(FalseLiteral, _, falseValue) => falseValue
- case If(Literal(null, _), _, falseValue) => falseValue
-
- case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
- // If there are branches that are always false, remove them.
- // If there are no more branches left, just use the else value.
- // Note that these two are handled together here in a single case statement because
- // otherwise we cannot determine the data type for the elseValue if it is None (i.e. null).
- val newBranches = branches.filter(x => !falseOrNullLiteral(x._1))
- if (newBranches.isEmpty) {
- elseValue.getOrElse(Literal.create(null, e.dataType))
- } else {
- e.copy(branches = newBranches)
- }
-
- case e @ CaseWhen(branches, _) if branches.headOption.map(_._1) == Some(TrueLiteral) =>
- // If the first branch is a true literal, remove the entire CaseWhen and use the value
- // from that. Note that CaseWhen.branches should never be empty, and as a result the
- // headOption (rather than head) added above is just an extra (and unnecessary) safeguard.
- branches.head._2
- }
- }
-}
-
-/**
- * Optimizes expressions by replacing according to CodeGen configuration.
- */
-case class OptimizeCodegen(conf: CatalystConf) extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case e: CaseWhen if canCodegen(e) => e.toCodegen()
- }
-
- private def canCodegen(e: CaseWhen): Boolean = {
- val numBranches = e.branches.size + e.elseValue.size
- numBranches <= conf.maxCaseBranchesForCodegen
- }
-}
-
-/**
* Combines all adjacent [[Union]] operators into a single [[Union]].
*/
object CombineUnions extends Rule[LogicalPlan] {
@@ -1026,7 +601,7 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
/**
* Removes no-op SortOrder from Sort
*/
-object EliminateSorts extends Rule[LogicalPlan] {
+object EliminateSorts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) =>
val newOrders = orders.filterNot(_.child.foldable)
@@ -1449,25 +1024,6 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
}
/**
- * Removes [[Cast Casts]] that are unnecessary because the input is already the correct type.
- */
-object SimplifyCasts extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case Cast(e, dataType) if e.dataType == dataType => e
- }
-}
-
-/**
- * Removes nodes that are not necessary.
- */
-object RemoveDispensableExpressions extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case UnaryPositive(child) => child
- case PromotePrecision(child) => child
- }
-}
-
-/**
* Combines two adjacent [[Limit]] operators into one, merging the
* expressions into one single expression.
*/
@@ -1483,21 +1039,6 @@ object CombineLimits extends Rule[LogicalPlan] {
}
/**
- * Removes the inner case conversion expressions that are unnecessary because
- * the inner conversion is overwritten by the outer one.
- */
-object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case q: LogicalPlan => q transformExpressionsUp {
- case Upper(Upper(child)) => Upper(child)
- case Upper(Lower(child)) => Upper(child)
- case Lower(Upper(child)) => Lower(child)
- case Lower(Lower(child)) => Lower(child)
- }
- }
-}
-
-/**
* Speeds up aggregates on fixed-precision decimals by executing them on unscaled Long values.
*
* This uses the same rules for increasing the precision and scale of the output as
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
new file mode 100644
index 0000000000..74dfd10189
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -0,0 +1,506 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.collection.immutable.HashSet
+
+import org.apache.spark.sql.catalyst.CatalystConf
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.types._
+
+/*
+ * Optimization rules defined in this file should not affect the structure of the logical plan.
+ */
+
+
+/**
+ * Replaces [[Expression Expressions]] that can be statically evaluated with
+ * equivalent [[Literal]] values.
+ */
+object ConstantFolding extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case q: LogicalPlan => q transformExpressionsDown {
+ // Skip redundant folding of literals. This rule is technically not necessary. Placing this
+ // here avoids running the next rule for Literal values, which would create a new Literal
+ // object and running eval unnecessarily.
+ case l: Literal => l
+
+ // Fold expressions that are foldable.
+ case e if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType)
+ }
+ }
+}
+
+
+/**
+ * Reorder associative integral-type operators and fold all constants into one.
+ */
+object ReorderAssociativeOperator extends Rule[LogicalPlan] {
+ private def flattenAdd(e: Expression): Seq[Expression] = e match {
+ case Add(l, r) => flattenAdd(l) ++ flattenAdd(r)
+ case other => other :: Nil
+ }
+
+ private def flattenMultiply(e: Expression): Seq[Expression] = e match {
+ case Multiply(l, r) => flattenMultiply(l) ++ flattenMultiply(r)
+ case other => other :: Nil
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case q: LogicalPlan => q transformExpressionsDown {
+ case a: Add if a.deterministic && a.dataType.isInstanceOf[IntegralType] =>
+ val (foldables, others) = flattenAdd(a).partition(_.foldable)
+ if (foldables.size > 1) {
+ val foldableExpr = foldables.reduce((x, y) => Add(x, y))
+ val c = Literal.create(foldableExpr.eval(EmptyRow), a.dataType)
+ if (others.isEmpty) c else Add(others.reduce((x, y) => Add(x, y)), c)
+ } else {
+ a
+ }
+ case m: Multiply if m.deterministic && m.dataType.isInstanceOf[IntegralType] =>
+ val (foldables, others) = flattenMultiply(m).partition(_.foldable)
+ if (foldables.size > 1) {
+ val foldableExpr = foldables.reduce((x, y) => Multiply(x, y))
+ val c = Literal.create(foldableExpr.eval(EmptyRow), m.dataType)
+ if (others.isEmpty) c else Multiply(others.reduce((x, y) => Multiply(x, y)), c)
+ } else {
+ m
+ }
+ }
+ }
+}
+
+
+/**
+ * Optimize IN predicates:
+ * 1. Removes literal repetitions.
+ * 2. Replaces [[In (value, seq[Literal])]] with optimized version
+ * [[InSet (value, HashSet[Literal])]] which is much faster.
+ */
+case class OptimizeIn(conf: CatalystConf) extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case q: LogicalPlan => q transformExpressionsDown {
+ case expr @ In(v, list) if expr.inSetConvertible =>
+ val newList = ExpressionSet(list).toSeq
+ if (newList.size > conf.optimizerInSetConversionThreshold) {
+ val hSet = newList.map(e => e.eval(EmptyRow))
+ InSet(v, HashSet() ++ hSet)
+ } else if (newList.size < list.size) {
+ expr.copy(list = newList)
+ } else { // newList.length == list.length
+ expr
+ }
+ }
+ }
+}
+
+
+/**
+ * Simplifies boolean expressions:
+ * 1. Simplifies expressions whose answer can be determined without evaluating both sides.
+ * 2. Eliminates / extracts common factors.
+ * 3. Merge same expressions
+ * 4. Removes `Not` operator.
+ */
+object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case q: LogicalPlan => q transformExpressionsUp {
+ case TrueLiteral And e => e
+ case e And TrueLiteral => e
+ case FalseLiteral Or e => e
+ case e Or FalseLiteral => e
+
+ case FalseLiteral And _ => FalseLiteral
+ case _ And FalseLiteral => FalseLiteral
+ case TrueLiteral Or _ => TrueLiteral
+ case _ Or TrueLiteral => TrueLiteral
+
+ case a And b if a.semanticEquals(b) => a
+ case a Or b if a.semanticEquals(b) => a
+
+ case a And (b Or c) if Not(a).semanticEquals(b) => And(a, c)
+ case a And (b Or c) if Not(a).semanticEquals(c) => And(a, b)
+ case (a Or b) And c if a.semanticEquals(Not(c)) => And(b, c)
+ case (a Or b) And c if b.semanticEquals(Not(c)) => And(a, c)
+
+ case a Or (b And c) if Not(a).semanticEquals(b) => Or(a, c)
+ case a Or (b And c) if Not(a).semanticEquals(c) => Or(a, b)
+ case (a And b) Or c if a.semanticEquals(Not(c)) => Or(b, c)
+ case (a And b) Or c if b.semanticEquals(Not(c)) => Or(a, c)
+
+ // Common factor elimination for conjunction
+ case and @ (left And right) =>
+ // 1. Split left and right to get the disjunctive predicates,
+ // i.e. lhs = (a, b), rhs = (a, c)
+ // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
+ // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
+ // 4. Apply the formula, get the optimized predicate: common || (ldiff && rdiff)
+ val lhs = splitDisjunctivePredicates(left)
+ val rhs = splitDisjunctivePredicates(right)
+ val common = lhs.filter(e => rhs.exists(e.semanticEquals))
+ if (common.isEmpty) {
+ // No common factors, return the original predicate
+ and
+ } else {
+ val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
+ val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
+ if (ldiff.isEmpty || rdiff.isEmpty) {
+ // (a || b || c || ...) && (a || b) => (a || b)
+ common.reduce(Or)
+ } else {
+ // (a || b || c || ...) && (a || b || d || ...) =>
+ // ((c || ...) && (d || ...)) || a || b
+ (common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or)
+ }
+ }
+
+ // Common factor elimination for disjunction
+ case or @ (left Or right) =>
+ // 1. Split left and right to get the conjunctive predicates,
+ // i.e. lhs = (a, b), rhs = (a, c)
+ // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
+ // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
+ // 4. Apply the formula, get the optimized predicate: common && (ldiff || rdiff)
+ val lhs = splitConjunctivePredicates(left)
+ val rhs = splitConjunctivePredicates(right)
+ val common = lhs.filter(e => rhs.exists(e.semanticEquals))
+ if (common.isEmpty) {
+ // No common factors, return the original predicate
+ or
+ } else {
+ val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
+ val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
+ if (ldiff.isEmpty || rdiff.isEmpty) {
+ // (a && b) || (a && b && c && ...) => a && b
+ common.reduce(And)
+ } else {
+ // (a && b && c && ...) || (a && b && d && ...) =>
+ // ((c && ...) || (d && ...)) && a && b
+ (common :+ Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And)
+ }
+ }
+
+ case Not(TrueLiteral) => FalseLiteral
+ case Not(FalseLiteral) => TrueLiteral
+
+ case Not(a GreaterThan b) => LessThanOrEqual(a, b)
+ case Not(a GreaterThanOrEqual b) => LessThan(a, b)
+
+ case Not(a LessThan b) => GreaterThanOrEqual(a, b)
+ case Not(a LessThanOrEqual b) => GreaterThan(a, b)
+
+ case Not(a Or b) => And(Not(a), Not(b))
+ case Not(a And b) => Or(Not(a), Not(b))
+
+ case Not(Not(e)) => e
+ }
+ }
+}
+
+
+/**
+ * Simplifies binary comparisons with semantically-equal expressions:
+ * 1) Replace '<=>' with 'true' literal.
+ * 2) Replace '=', '<=', and '>=' with 'true' literal if both operands are non-nullable.
+ * 3) Replace '<' and '>' with 'false' literal if both operands are non-nullable.
+ */
+object SimplifyBinaryComparison extends Rule[LogicalPlan] with PredicateHelper {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case q: LogicalPlan => q transformExpressionsUp {
+ // True with equality
+ case a EqualNullSafe b if a.semanticEquals(b) => TrueLiteral
+ case a EqualTo b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral
+ case a GreaterThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) =>
+ TrueLiteral
+ case a LessThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral
+
+ // False with inequality
+ case a GreaterThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral
+ case a LessThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral
+ }
+ }
+}
+
+
+/**
+ * Simplifies conditional expressions (if / case).
+ */
+object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
+ private def falseOrNullLiteral(e: Expression): Boolean = e match {
+ case FalseLiteral => true
+ case Literal(null, _) => true
+ case _ => false
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case q: LogicalPlan => q transformExpressionsUp {
+ case If(TrueLiteral, trueValue, _) => trueValue
+ case If(FalseLiteral, _, falseValue) => falseValue
+ case If(Literal(null, _), _, falseValue) => falseValue
+
+ case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
+ // If there are branches that are always false, remove them.
+ // If there are no more branches left, just use the else value.
+ // Note that these two are handled together here in a single case statement because
+ // otherwise we cannot determine the data type for the elseValue if it is None (i.e. null).
+ val newBranches = branches.filter(x => !falseOrNullLiteral(x._1))
+ if (newBranches.isEmpty) {
+ elseValue.getOrElse(Literal.create(null, e.dataType))
+ } else {
+ e.copy(branches = newBranches)
+ }
+
+ case e @ CaseWhen(branches, _) if branches.headOption.map(_._1) == Some(TrueLiteral) =>
+ // If the first branch is a true literal, remove the entire CaseWhen and use the value
+ // from that. Note that CaseWhen.branches should never be empty, and as a result the
+ // headOption (rather than head) added above is just an extra (and unnecessary) safeguard.
+ branches.head._2
+ }
+ }
+}
+
+
+/**
+ * Simplifies LIKE expressions that do not need full regular expressions to evaluate the condition.
+ * For example, when the expression is just checking to see if a string starts with a given
+ * pattern.
+ *
+ * Only patterns given as string literals are considered. A pattern is left untouched when it is
+ * null or when it contains the escape character `\` anywhere: rewriting such a pattern would
+ * compare the input against the raw, still-escaped text and could change the result.
+ */
+object LikeSimplification extends Rule[LogicalPlan] {
+  // Each regex captures literal fragments that contain no LIKE wildcard (`_` or `%`).
+  private val startsWith = "([^_%]+)%".r
+  private val endsWith = "%([^_%]+)".r
+  private val startsAndEndsWith = "([^_%]+)%([^_%]+)".r
+  private val contains = "%([^_%]+)%".r
+  private val equalTo = "([^_%]*)".r
+
+  /**
+   * Number of Unicode code points in `s`. Used instead of `String.size`, which counts UTF-16
+   * code units, because the result is compared against `Length(input)`, which is expected to
+   * count characters; the two differ for supplementary characters.
+   */
+  private def numChars(s: String): Int = s.codePointCount(0, s.length)
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    case like @ Like(input, Literal(pattern, StringType)) =>
+      if (pattern == null || pattern.toString.contains("\\")) {
+        // Null pattern: nothing to parse (and calling toString on it would throw).
+        // Escaped pattern: skipped entirely for simplicity, since none of the rewrites below
+        // unescape the captured fragments. Skipping never affects correctness.
+        like
+      } else {
+        pattern.toString match {
+          case startsWith(prefix) =>
+            StartsWith(input, Literal(prefix))
+          case endsWith(postfix) =>
+            EndsWith(input, Literal(postfix))
+          // 'a%a' pattern is basically same with 'a%' && '%a'.
+          // However, the additional `Length` condition is required to prevent 'a' match 'a%a'.
+          case startsAndEndsWith(prefix, postfix) =>
+            And(GreaterThanOrEqual(Length(input), Literal(numChars(prefix) + numChars(postfix))),
+              And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix))))
+          case contains(infix) =>
+            Contains(input, Literal(infix))
+          case equalTo(str) =>
+            EqualTo(input, Literal(str))
+          case _ =>
+            // Not a recognized shape: keep the original expression as is.
+            like
+        }
+      }
+  }
+}
+
+
+/**
+ * Replaces [[Expression Expressions]] whose result is statically known because of null
+ * literals or non-nullable inputs with equivalent simpler expressions, typically a [[Literal]].
+ * Expressions are rewritten bottom-up, so a null produced by a child can propagate to its
+ * parents within the same pass.
+ */
+object NullPropagation extends Rule[LogicalPlan] {
+  /** Returns true when `e` is a literal holding null. */
+  private def isNullLiteral(e: Expression): Boolean = e match {
+    case Literal(null, _) => true
+    case _ => false
+  }
+
+  /** Builds a null literal carrying the data type of `e`. */
+  private def nullOf(e: Expression): Literal = Literal.create(null, e.dataType)
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
+    case node: LogicalPlan => node.transformExpressionsUp {
+      case w @ WindowExpression(Cast(Literal(0L, _), _), _) =>
+        Cast(Literal(0L), w.dataType)
+      // Counting nothing but null literals always yields zero.
+      case agg @ AggregateExpression(Count(exprs), _, _, _) if exprs.forall(isNullLiteral) =>
+        Cast(Literal(0L), agg.dataType)
+
+      // Null checks on inputs that can never be null are constants.
+      case IsNull(child) if !child.nullable => Literal.create(false, BooleanType)
+      case IsNotNull(child) if !child.nullable => Literal.create(true, BooleanType)
+
+      // Extracting from a null container, or with a null index/key, gives null.
+      case e @ (GetArrayItem(Literal(null, _), _) | GetArrayItem(_, Literal(null, _))) =>
+        nullOf(e)
+      case e @ (GetMapValue(Literal(null, _), _) | GetMapValue(_, Literal(null, _))) =>
+        nullOf(e)
+      case e @ GetStructField(Literal(null, _), _, _) => nullOf(e)
+      case e @ GetArrayStructFields(Literal(null, _), _, _, _, _) => nullOf(e)
+
+      // Null-safe equality against a null literal is just a null check on the other side.
+      case EqualNullSafe(Literal(null, _), right) => IsNull(right)
+      case EqualNullSafe(left, Literal(null, _)) => IsNull(left)
+
+      case agg @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) =>
+        // This rule should be only triggered when isDistinct field is false.
+        agg.copy(aggregateFunction = Count(Literal(1)))
+
+      // For Coalesce, remove null literals.
+      case c @ Coalesce(children) =>
+        children.filterNot(isNullLiteral) match {
+          case Seq() => nullOf(c)
+          case Seq(only) => only
+          case remaining => Coalesce(remaining)
+        }
+
+      case e @ (Substring(Literal(null, _), _, _) |
+                Substring(_, Literal(null, _), _) |
+                Substring(_, _, Literal(null, _))) =>
+        nullOf(e)
+
+      // Put exceptional cases above if any
+      case e @ (BinaryArithmetic(Literal(null, _), _) | BinaryArithmetic(_, Literal(null, _))) =>
+        nullOf(e)
+
+      case e @ (BinaryComparison(Literal(null, _), _) | BinaryComparison(_, Literal(null, _))) =>
+        nullOf(e)
+
+      // Two-argument string matching expressions are null when either argument is a null
+      // literal; anything else is left as it is.
+      case e @ (_: StringRegexExpression | _: StringPredicate) => e.children match {
+        case Literal(null, _) :: _ :: Nil | _ :: Literal(null, _) :: Nil => nullOf(e)
+        case _ => e
+      }
+
+      // If the value expression is NULL then transform the In expression to
+      // Literal(null)
+      case In(Literal(null, _), _) => Literal.create(null, BooleanType)
+
+    }
+  }
+}
+
+
+/**
+ * Propagate foldable expressions:
+ * Attributes that refer to an alias of a foldable expression are replaced by that alias where
+ * possible, so that other optimizations can take advantage of the propagated expressions.
+ *
+ * {{{
+ *   SELECT 1.0 x, 'abc' y, Now() z ORDER BY x, y, 3
+ *   ==>  SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now()
+ * }}}
+ */
+object FoldablePropagation extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    // Every foldable alias defined by any Project in the plan, keyed by its output attribute.
+    val foldableAliases = AttributeMap(plan.flatMap {
+      case project: Project =>
+        project.projectList.collect {
+          case alias: Alias if alias.child.foldable => (alias.toAttribute, alias)
+        }
+      case _ => Nil
+    })
+
+    if (foldableAliases.nonEmpty) {
+      // Once a barrier operator has been visited, no operator visited after it is rewritten.
+      var barrierSeen = false
+      val rewritten = plan.transformUp {
+        case barrier @ (_: Union | _: Command) =>
+          barrierSeen = true
+          barrier
+
+        // For outer join, although its output attributes are derived from its children, they are
+        // actually different attributes: the output of outer join is not always picked from its
+        // children, but can also be null.
+        // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes
+        // of outer join.
+        case outerJoin @ Join(_, _, LeftOuter | RightOuter | FullOuter, _) =>
+          barrierSeen = true
+          outerJoin
+
+        // These 3 operators take attributes as constructor parameters, and these attributes
+        // can't be replaced by alias.
+        case grouped @ (_: MapGroups | _: FlatMapGroupsInR | _: CoGroup) =>
+          barrierSeen = true
+          grouped
+
+        case node: LogicalPlan if !barrierSeen =>
+          node.transformExpressions {
+            case ref: AttributeReference if foldableAliases.contains(ref) => foldableAliases(ref)
+          }
+      }
+      CleanupAliases(rewritten)
+    } else {
+      plan
+    }
+  }
+}
+
+
+/**
+ * Rewrites [[CaseWhen]] expressions through `toCodegen()` when their number of branches,
+ * counting the else value if present, does not exceed `conf.maxCaseBranchesForCodegen`.
+ */
+case class OptimizeCodegen(conf: CatalystConf) extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
+    case caseWhen: CaseWhen if fitsCodegenLimit(caseWhen) => caseWhen.toCodegen()
+  }
+
+  /** Returns true when `caseWhen` has at most the configured maximum number of branches. */
+  private def fitsCodegenLimit(caseWhen: CaseWhen): Boolean = {
+    // The else value, when defined, counts as one additional branch.
+    val elseCount = if (caseWhen.elseValue.isDefined) 1 else 0
+    caseWhen.branches.size + elseCount <= conf.maxCaseBranchesForCodegen
+  }
+}
+
+
+/**
+ * Removes every [[Cast]] whose child already has the target data type, replacing the cast
+ * with the child itself.
+ */
+object SimplifyCasts extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
+    // Casting a value to the type it already has is a no-op.
+    case Cast(child, targetType) if child.dataType == targetType => child
+  }
+}
+
+
+/**
+ * Removes expression nodes that are not necessary: [[UnaryPositive]] and [[PromotePrecision]]
+ * wrappers are replaced by their child.
+ */
+object RemoveDispensableExpressions extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
+    case positive: UnaryPositive => positive.child
+    case promoted: PromotePrecision => promoted.child
+  }
+}
+
+
+/**
+ * Collapses directly nested [[Upper]] / [[Lower]] conversions into the outer one, since the
+ * inner conversion is overwritten by the outer conversion anyway.
+ */
+object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
+    case node: LogicalPlan => node.transformExpressionsUp {
+      // Keep the outer conversion and apply it straight to the inner conversion's child.
+      case Upper(inner: Upper) => Upper(inner.child)
+      case Upper(inner: Lower) => Upper(inner.child)
+      case Lower(inner: Upper) => Lower(inner.child)
+      case Lower(inner: Lower) => Lower(inner.child)
+    }
+  }
+}