diff options
author | Yin Huai <yhuai@databricks.com> | 2015-08-02 23:32:09 -0700 |
---|---|---|
committer | Josh Rosen <joshrosen@databricks.com> | 2015-08-02 23:32:09 -0700 |
commit | 687c8c37150f4c93f8e57d86bb56321a4891286b (patch) | |
tree | 5fc768cdf7b01dae261706c148c7fcd3cf622b9d /sql/catalyst/src/main | |
parent | 4cdd8ecd66769316e8593da7790b84cd867968cd (diff) | |
download | spark-687c8c37150f4c93f8e57d86bb56321a4891286b.tar.gz spark-687c8c37150f4c93f8e57d86bb56321a4891286b.tar.bz2 spark-687c8c37150f4c93f8e57d86bb56321a4891286b.zip |
[SPARK-9372] [SQL] Filter nulls in join keys
This PR adds an optimization rule, `FilterNullsInJoinKey`, to add `Filter` before join operators to filter out rows having null values for join keys.
This optimization is guarded by a new SQL conf, `spark.sql.advancedOptimization`.
The code in this PR was authored by yhuai; I'm opening this PR to factor out this change from #7685, a larger pull request which contains two other optimizations.
Author: Yin Huai <yhuai@databricks.com>
Author: Josh Rosen <joshrosen@databricks.com>
Closes #7768 from JoshRosen/filter-nulls-in-join-key and squashes the following commits:
c02fc3f [Yin Huai] Address Josh's comments.
0a8e096 [Yin Huai] Update comments.
ea7d5a6 [Yin Huai] Make sure we do not keep adding filters.
be88760 [Yin Huai] Make it clear that FilterNullsInJoinKeySuite.scala is used to test FilterNullsInJoinKey.
8bb39ad [Yin Huai] Fix non-deterministic tests.
303236b [Josh Rosen] Revert changes that are unrelated to null join key filtering
40eeece [Josh Rosen] Merge remote-tracking branch 'origin/master' into filter-nulls-in-join-key
c57a954 [Yin Huai] Bug fix.
d3d2e64 [Yin Huai] First round of cleanup.
f9516b0 [Yin Huai] Style
c6667e7 [Yin Huai] Add PartitioningCollection.
e616d3b [Yin Huai] wip
7c2d2d8 [Yin Huai] Bug fix and refactoring.
69bb072 [Yin Huai] Introduce NullSafeHashPartitioning and NullUnsafePartitioning.
d5b84c3 [Yin Huai] Do not add unnessary filters.
2201129 [Yin Huai] Filter out rows that will not be joined in equal joins early.
Diffstat (limited to 'sql/catalyst/src/main')
3 files changed, 119 insertions, 25 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 287718fab7..d58c475693 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -210,14 +210,58 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { } } +/** + * A predicate that is evaluated to be true if there are at least `n` null values. + */ +case class AtLeastNNulls(n: Int, children: Seq[Expression]) extends Predicate { + override def nullable: Boolean = false + override def foldable: Boolean = children.forall(_.foldable) + override def toString: String = s"AtLeastNNulls($n, ${children.mkString(",")})" + + private[this] val childrenArray = children.toArray + + override def eval(input: InternalRow): Boolean = { + var numNulls = 0 + var i = 0 + while (i < childrenArray.length && numNulls < n) { + val evalC = childrenArray(i).eval(input) + if (evalC == null) { + numNulls += 1 + } + i += 1 + } + numNulls >= n + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val numNulls = ctx.freshName("numNulls") + val code = children.map { e => + val eval = e.gen(ctx) + s""" + if ($numNulls < $n) { + ${eval.code} + if (${eval.isNull}) { + $numNulls += 1; + } + } + """ + }.mkString("\n") + s""" + int $numNulls = 0; + $code + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = $numNulls >= $n; + """ + } +} /** * A predicate that is evaluated to be true if there are at least `n` non-null and non-NaN values. */ -case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate { +case class AtLeastNNonNullNans(n: Int, children: Seq[Expression]) extends Predicate { override def nullable: Boolean = false override def foldable: Boolean = children.forall(_.foldable) - override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})" + override def toString: String = s"AtLeastNNonNullNans($n, ${children.mkString(",")})" private[this] val childrenArray = children.toArray diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 29d706dcb3..e4b6294dc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -31,8 +31,14 @@ import org.apache.spark.sql.types._ abstract class Optimizer extends RuleExecutor[LogicalPlan] -object DefaultOptimizer extends Optimizer { - val batches = +class DefaultOptimizer extends Optimizer { + + /** + * Override to provide additional rules for the "Operator Optimizations" batch. + */ + val extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil + + lazy val batches = // SubQueries are only needed for analysis and can be removed before execution. Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: @@ -41,26 +47,27 @@ object DefaultOptimizer extends Optimizer { RemoveLiteralFromGroupExpressions) :: Batch("Operator Optimizations", FixedPoint(100), // Operator push down - SetOperationPushDown, - SamplePushDown, - PushPredicateThroughJoin, - PushPredicateThroughProject, - PushPredicateThroughGenerate, - ColumnPruning, + SetOperationPushDown :: + SamplePushDown :: + PushPredicateThroughJoin :: + PushPredicateThroughProject :: + PushPredicateThroughGenerate :: + ColumnPruning :: // Operator combine - ProjectCollapsing, - CombineFilters, - CombineLimits, + ProjectCollapsing :: + CombineFilters :: + CombineLimits :: // Constant folding - NullPropagation, - OptimizeIn, - ConstantFolding, - LikeSimplification, - BooleanSimplification, - RemovePositive, - SimplifyFilters, - SimplifyCasts, - SimplifyCaseConversionExpressions) :: + NullPropagation :: + OptimizeIn :: + ConstantFolding :: + LikeSimplification :: + BooleanSimplification :: + RemovePositive :: + SimplifyFilters :: + SimplifyCasts :: + SimplifyCaseConversionExpressions :: + extendedOperatorOptimizationRules.toList : _*) :: Batch("Decimal Optimizations", FixedPoint(100), DecimalAggregates) :: Batch("LocalRelation", FixedPoint(100), @@ -222,12 +229,18 @@ object ColumnPruning extends Rule[LogicalPlan] { } /** Applies a projection only when the child is producing unnecessary attributes */ - private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = + private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = { if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) { - Project(allReferences.filter(c.outputSet.contains).toSeq, c) + // We need to preserve the nullability of c's output. + // So, we first create a outputMap and if a reference is from the output of + // c, we use that output attribute from c. + val outputMap = AttributeMap(c.output.map(attr => (attr, attr))) + val projectList = allReferences.filter(outputMap.contains).map(outputMap).toSeq + Project(projectList, c) } else { c } + } } /** @@ -517,6 +530,13 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { */ object CombineFilters extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Filter(Not(AtLeastNNulls(1, e1)), Filter(Not(AtLeastNNulls(1, e2)), grandChild)) => + // If we are combining two expressions Not(AtLeastNNulls(1, e1)) and + // Not(AtLeastNNulls(1, e2)) + // (this is used to make sure there is no null in the result of e1 and e2 and + // they are added by FilterNullsInJoinKey optimziation rule), we can + // just create a Not(AtLeastNNulls(1, (e1 ++ e2).distinct)). + Filter(Not(AtLeastNNulls(1, (e1 ++ e2).distinct)), grandChild) case ff @ Filter(fc, nf @ Filter(nc, grandChild)) => Filter(And(nc, fc), grandChild) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index aacfc86ab0..54b5f49772 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -86,7 +86,37 @@ case class Generate( } case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output + /** + * Indicates if `atLeastNNulls` is used to check if atLeastNNulls.children + * have at least one null value and atLeastNNulls.children are all attributes. + */ + private def isAtLeastOneNullOutputAttributes(atLeastNNulls: AtLeastNNulls): Boolean = { + val expressions = atLeastNNulls.children + val n = atLeastNNulls.n + if (n != 1) { + // AtLeastNNulls is not used to check if atLeastNNulls.children have + // at least one null value. + false + } else { + // AtLeastNNulls is used to check if atLeastNNulls.children have + // at least one null value. We need to make sure all atLeastNNulls.children + // are attributes. + expressions.forall(_.isInstanceOf[Attribute]) + } + } + + override def output: Seq[Attribute] = condition match { + case Not(a: AtLeastNNulls) if isAtLeastOneNullOutputAttributes(a) => + // The condition is used to make sure that there is no null value in + // a.children. + val nonNullableAttributes = AttributeSet(a.children.asInstanceOf[Seq[Attribute]]) + child.output.map { + case attr if nonNullableAttributes.contains(attr) => + attr.withNullability(false) + case attr => attr + } + case _ => child.output + } } case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { |