aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala286
1 files changed, 199 insertions, 87 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index a7a948ef1b..f5172b213a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -31,9 +31,9 @@ import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types._
/**
- * Abstract class all optimizers should inherit of, contains the standard batches (extending
- * Optimizers can override this.
- */
+ * Abstract class all optimizers should inherit of, contains the standard batches (extending
+ * Optimizers can override this.
+ */
abstract class Optimizer extends RuleExecutor[LogicalPlan] {
def batches: Seq[Batch] = {
// Technically some of the rules in Finish Analysis are not optimizer rules and belong more
@@ -66,9 +66,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
ReorderJoin,
OuterJoinElimination,
PushPredicateThroughJoin,
- PushPredicateThroughProject,
- PushPredicateThroughGenerate,
- PushPredicateThroughAggregate,
+ PushDownPredicate,
LimitPushDown,
ColumnPruning,
InferFiltersFromConstraints,
@@ -86,6 +84,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
BooleanSimplification,
SimplifyConditionals,
RemoveDispensableExpressions,
+ BinaryComparisonSimplification,
PruneFilters,
EliminateSorts,
SimplifyCasts,
@@ -93,6 +92,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
EliminateSerialization) ::
Batch("Decimal Optimizations", FixedPoint(100),
DecimalAggregates) ::
+ Batch("Typed Filter Optimization", FixedPoint(100),
+ EmbedSerializerInFilter) ::
Batch("LocalRelation", FixedPoint(100),
ConvertToLocalRelation) ::
Batch("Subquery", Once,
@@ -111,11 +112,11 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
}
/**
- * Non-abstract representation of the standard Spark optimizing strategies
- *
- * To ensure extendability, we leave the standard rules in the abstract optimizer rules, while
- * specific rules go to the subclasses
- */
+ * Non-abstract representation of the standard Spark optimizing strategies
+ *
+ * To ensure extendability, we leave the standard rules in the abstract optimizer rules, while
+ * specific rules go to the subclasses
+ */
object DefaultOptimizer extends Optimizer
/**
@@ -136,6 +137,7 @@ object SamplePushDown extends Rule[LogicalPlan] {
* representation of data item. For example back to back map operations.
*/
object EliminateSerialization extends Rule[LogicalPlan] {
+ // TODO: find a more general way to do this optimization.
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case m @ MapPartitions(_, deserializer, _, child: ObjectOperator)
if !deserializer.isInstanceOf[Attribute] &&
@@ -144,6 +146,20 @@ object EliminateSerialization extends Rule[LogicalPlan] {
m.copy(
deserializer = childWithoutSerialization.output.head,
child = childWithoutSerialization)
+
+ case m @ MapElements(_, deserializer, _, child: ObjectOperator)
+ if !deserializer.isInstanceOf[Attribute] &&
+ deserializer.dataType == child.outputObject.dataType =>
+ val childWithoutSerialization = child.withObjectOutput
+ m.copy(
+ deserializer = childWithoutSerialization.output.head,
+ child = childWithoutSerialization)
+
+ case d @ DeserializeToObject(_, s: SerializeFromObject)
+ if d.outputObjectType == s.inputObjectType =>
+ // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
+ val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
+ Project(objAttr :: Nil, s.child)
}
}
@@ -270,10 +286,10 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
assert(children.nonEmpty)
if (projectList.forall(_.deterministic)) {
val newFirstChild = Project(projectList, children.head)
- val newOtherChildren = children.tail.map ( child => {
+ val newOtherChildren = children.tail.map { child =>
val rewrites = buildRewrites(children.head, child)
Project(projectList.map(pushToRight(_, rewrites)), child)
- } )
+ }
Union(newFirstChild +: newOtherChildren)
} else {
p
@@ -352,8 +368,8 @@ object ColumnPruning extends Rule[LogicalPlan] {
case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) =>
p.copy(child = g.copy(join = false))
- // Eliminate unneeded attributes from right side of a LeftSemiJoin.
- case j @ Join(left, right, LeftSemi, condition) =>
+ // Eliminate unneeded attributes from right side of a Left Existence Join.
+ case j @ Join(left, right, LeftExistence(_), condition) =>
j.copy(right = prunedChild(right, j.references))
// all the columns will be used to compare, so we can't prune them
@@ -501,22 +517,28 @@ object LikeSimplification extends Rule[LogicalPlan] {
// Cases like "something\%" are not optimized, but this does not affect correctness.
private val startsWith = "([^_%]+)%".r
private val endsWith = "%([^_%]+)".r
+ private val startsAndEndsWith = "([^_%]+)%([^_%]+)".r
private val contains = "%([^_%]+)%".r
private val equalTo = "([^_%]*)".r
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case Like(l, Literal(utf, StringType)) =>
- utf.toString match {
- case startsWith(pattern) if !pattern.endsWith("\\") =>
- StartsWith(l, Literal(pattern))
- case endsWith(pattern) =>
- EndsWith(l, Literal(pattern))
- case contains(pattern) if !pattern.endsWith("\\") =>
- Contains(l, Literal(pattern))
- case equalTo(pattern) =>
- EqualTo(l, Literal(pattern))
+ case Like(input, Literal(pattern, StringType)) =>
+ pattern.toString match {
+ case startsWith(prefix) if !prefix.endsWith("\\") =>
+ StartsWith(input, Literal(prefix))
+ case endsWith(postfix) =>
+ EndsWith(input, Literal(postfix))
+ // 'a%a' pattern is basically same with 'a%' && '%a'.
+ // However, the additional `Length` condition is required to prevent 'a' match 'a%a'.
+ case startsAndEndsWith(prefix, postfix) if !prefix.endsWith("\\") =>
+ And(GreaterThanOrEqual(Length(input), Literal(prefix.size + postfix.size)),
+ And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix))))
+ case contains(infix) if !infix.endsWith("\\") =>
+ Contains(input, Literal(infix))
+ case equalTo(str) =>
+ EqualTo(input, Literal(str))
case _ =>
- Like(l, Literal.create(utf, StringType))
+ Like(input, Literal.create(pattern, StringType))
}
}
}
@@ -527,14 +549,14 @@ object LikeSimplification extends Rule[LogicalPlan] {
* Null value propagation from bottom to top of the expression tree.
*/
object NullPropagation extends Rule[LogicalPlan] {
- def nonNullLiteral(e: Expression): Boolean = e match {
+ private def nonNullLiteral(e: Expression): Boolean = e match {
case Literal(null, _) => false
case _ => true
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
- case e @ AggregateExpression(Count(exprs), _, _) if !exprs.exists(nonNullLiteral) =>
+ case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) =>
Cast(Literal(0L), e.dataType)
case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
@@ -547,9 +569,9 @@ object NullPropagation extends Rule[LogicalPlan] {
Literal.create(null, e.dataType)
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
- case e @ AggregateExpression(Count(exprs), mode, false) if !exprs.exists(_.nullable) =>
+ case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) =>
// This rule should be only triggered when isDistinct field is false.
- AggregateExpression(Count(Literal(1)), mode, isDistinct = false)
+ ae.copy(aggregateFunction = Count(Literal(1)))
// For Coalesce, remove null literals.
case e @ Coalesce(children) =>
@@ -770,20 +792,50 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
}
/**
+ * Simplifies binary comparisons with semantically-equal expressions:
+ * 1) Replace '<=>' with 'true' literal.
+ * 2) Replace '=', '<=', and '>=' with 'true' literal if both operands are non-nullable.
+ * 3) Replace '<' and '>' with 'false' literal if both operands are non-nullable.
+ */
+object BinaryComparisonSimplification extends Rule[LogicalPlan] with PredicateHelper {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case q: LogicalPlan => q transformExpressionsUp {
+ // True with equality
+ case a EqualNullSafe b if a.semanticEquals(b) => TrueLiteral
+ case a EqualTo b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral
+ case a GreaterThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) =>
+ TrueLiteral
+ case a LessThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral
+
+ // False with inequality
+ case a GreaterThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral
+ case a LessThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral
+ }
+ }
+}
+
+/**
* Simplifies conditional expressions (if / case).
*/
object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
+ private def falseOrNullLiteral(e: Expression): Boolean = e match {
+ case FalseLiteral => true
+ case Literal(null, _) => true
+ case _ => false
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
case If(TrueLiteral, trueValue, _) => trueValue
case If(FalseLiteral, _, falseValue) => falseValue
+ case If(Literal(null, _), _, falseValue) => falseValue
- case e @ CaseWhen(branches, elseValue) if branches.exists(_._1 == FalseLiteral) =>
+ case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
// If there are branches that are always false, remove them.
// If there are no more branches left, just use the else value.
// Note that these two are handled together here in a single case statement because
// otherwise we cannot determine the data type for the elseValue if it is None (i.e. null).
- val newBranches = branches.filter(_._1 != FalseLiteral)
+ val newBranches = branches.filter(x => !falseOrNullLiteral(x._1))
if (newBranches.isEmpty) {
elseValue.getOrElse(Literal.create(null, e.dataType))
} else {
@@ -869,12 +921,13 @@ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper {
}
/**
- * Pushes [[Filter]] operators through [[Project]] operators, in-lining any [[Alias Aliases]]
- * that were defined in the projection.
+ * Pushes [[Filter]] operators through many operators iff:
+ * 1) the operator is deterministic
+ * 2) the predicate is deterministic and the operator will not change any of rows.
*
* This heuristic is valid assuming the expression evaluation cost is minimal.
*/
-object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelper {
+object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// SPARK-13473: We can't push the predicate down when the underlying projection output non-
// deterministic field(s). Non-deterministic expressions are essentially stateful. This
@@ -891,41 +944,7 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelpe
})
project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild))
- }
-
-}
-
-/**
- * Push [[Filter]] operators through [[Generate]] operators. Parts of the predicate that reference
- * attributes generated in [[Generate]] will remain above, and the rest should be pushed beneath.
- */
-object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper {
-
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case filter @ Filter(condition, g: Generate) =>
- // Predicates that reference attributes produced by the `Generate` operator cannot
- // be pushed below the operator.
- val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond =>
- cond.references.subsetOf(g.child.outputSet) && cond.deterministic
- }
- if (pushDown.nonEmpty) {
- val pushDownPredicate = pushDown.reduce(And)
- val newGenerate = Generate(g.generator, join = g.join, outer = g.outer,
- g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child))
- if (stayUp.isEmpty) newGenerate else Filter(stayUp.reduce(And), newGenerate)
- } else {
- filter
- }
- }
-}
-
-/**
- * Push [[Filter]] operators through [[Aggregate]] operators, iff the filters reference only
- * non-aggregate attributes (typically literals or grouping expressions).
- */
-object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHelper {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case filter @ Filter(condition, aggregate: Aggregate) =>
// Find all the aliased expressions in the aggregate list that don't include any actual
// AggregateExpression, and create a map from the alias to the expression
@@ -951,25 +970,91 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel
} else {
filter
}
+
+ case filter @ Filter(condition, child)
+ if child.isInstanceOf[Union] || child.isInstanceOf[Intersect] =>
+ // Union/Intersect could change the rows, so non-deterministic predicate can't be pushed down
+ val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond =>
+ cond.deterministic
+ }
+ if (pushDown.nonEmpty) {
+ val pushDownCond = pushDown.reduceLeft(And)
+ val output = child.output
+ val newGrandChildren = child.children.map { grandchild =>
+ val newCond = pushDownCond transform {
+ case e if output.exists(_.semanticEquals(e)) =>
+ grandchild.output(output.indexWhere(_.semanticEquals(e)))
+ }
+ assert(newCond.references.subsetOf(grandchild.outputSet))
+ Filter(newCond, grandchild)
+ }
+ val newChild = child.withNewChildren(newGrandChildren)
+ if (stayUp.nonEmpty) {
+ Filter(stayUp.reduceLeft(And), newChild)
+ } else {
+ newChild
+ }
+ } else {
+ filter
+ }
+
+ case filter @ Filter(condition, e @ Except(left, _)) =>
+ pushDownPredicate(filter, e.left) { predicate =>
+ e.copy(left = Filter(predicate, left))
+ }
+
+ // two filters should be combine together by other rules
+ case filter @ Filter(_, f: Filter) => filter
+ // should not push predicates through sample, or will generate different results.
+ case filter @ Filter(_, s: Sample) => filter
+ // TODO: push predicates through expand
+ case filter @ Filter(_, e: Expand) => filter
+
+ case filter @ Filter(condition, u: UnaryNode) if u.expressions.forall(_.deterministic) =>
+ pushDownPredicate(filter, u.child) { predicate =>
+ u.withNewChildren(Seq(Filter(predicate, u.child)))
+ }
+ }
+
+ private def pushDownPredicate(
+ filter: Filter,
+ grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = {
+ // Only push down the predicates that is deterministic and all the referenced attributes
+ // come from grandchild.
+ // TODO: non-deterministic predicates could be pushed through some operators that do not change
+ // the rows.
+ val (pushDown, stayUp) = splitConjunctivePredicates(filter.condition).partition { cond =>
+ cond.deterministic && cond.references.subsetOf(grandchild.outputSet)
+ }
+ if (pushDown.nonEmpty) {
+ val newChild = insertFilter(pushDown.reduceLeft(And))
+ if (stayUp.nonEmpty) {
+ Filter(stayUp.reduceLeft(And), newChild)
+ } else {
+ newChild
+ }
+ } else {
+ filter
+ }
}
}
/**
- * Reorder the joins and push all the conditions into join, so that the bottom ones have at least
- * one condition.
- *
- * The order of joins will not be changed if all of them already have at least one condition.
- */
+ * Reorder the joins and push all the conditions into join, so that the bottom ones have at least
+ * one condition.
+ *
+ * The order of joins will not be changed if all of them already have at least one condition.
+ */
object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
/**
- * Join a list of plans together and push down the conditions into them.
- *
- * The joined plan are picked from left to right, prefer those has at least one join condition.
- *
- * @param input a list of LogicalPlans to join.
- * @param conditions a list of condition for join.
- */
+ * Join a list of plans together and push down the conditions into them.
+ *
+ * The joined plan are picked from left to right, prefer those has at least one join condition.
+ *
+ * @param input a list of LogicalPlans to join.
+ * @param conditions a list of condition for join.
+ */
@tailrec
def createOrderedJoin(input: Seq[LogicalPlan], conditions: Seq[Expression]): LogicalPlan = {
assert(input.size >= 2)
@@ -1110,7 +1195,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
(leftFilterConditions ++ commonFilterCondition).
reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
- case _ @ (LeftOuter | LeftSemi) =>
+ case LeftOuter | LeftExistence(_) =>
// push down the left side only `where` condition
val newLeft = leftFilterConditions.
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
@@ -1131,7 +1216,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right)
joinType match {
- case _ @ (Inner | LeftSemi) =>
+ case Inner | LeftExistence(_) =>
// push down the single side only join filter for both sides sub queries
val newLeft = leftJoinConditions.
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
@@ -1225,13 +1310,13 @@ object DecimalAggregates extends Rule[LogicalPlan] {
private val MAX_DOUBLE_DIGITS = 15
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), mode, isDistinct)
+ case ae @ AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), _, _, _)
if prec + 10 <= MAX_LONG_DIGITS =>
- MakeDecimal(AggregateExpression(Sum(UnscaledValue(e)), mode, isDistinct), prec + 10, scale)
+ MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)
- case AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), mode, isDistinct)
+ case ae @ AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), _, _, _)
if prec + 4 <= MAX_DOUBLE_DIGITS =>
- val newAggExpr = AggregateExpression(Average(UnscaledValue(e)), mode, isDistinct)
+ val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
Cast(
Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
DecimalType(prec + 4, scale + 4))
@@ -1313,3 +1398,30 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
}
}
}
+
+/**
+ * Typed [[Filter]] is by default surrounded by a [[DeserializeToObject]] beneath it and a
+ * [[SerializeFromObject]] above it. If these serializations can't be eliminated, we should embed
+ * the deserializer in filter condition to save the extra serialization at last.
+ */
+object EmbedSerializerInFilter extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case s @ SerializeFromObject(_, Filter(condition, d: DeserializeToObject)) =>
+ val numObjects = condition.collect {
+ case a: Attribute if a == d.output.head => a
+ }.length
+
+ if (numObjects > 1) {
+ // If the filter condition references the object more than one times, we should not embed
+ // deserializer in it as the deserialization will happen many times and slow down the
+ // execution.
+ // TODO: we can still embed it if we can make sure subexpression elimination works here.
+ s
+ } else {
+ val newCondition = condition transform {
+ case a: Attribute if a == d.output.head => d.deserializer.child
+ }
+ Filter(newCondition, d.child)
+ }
+ }
+}