From db5832708267f4a8413b0ad19c6a454c93f7800e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 3 Aug 2015 14:51:36 -0700 Subject: Revert "[SPARK-9372] [SQL] Filter nulls in join keys" This reverts commit 687c8c37150f4c93f8e57d86bb56321a4891286b. --- .../sql/catalyst/expressions/nullFunctions.scala | 48 +---- .../spark/sql/catalyst/optimizer/Optimizer.scala | 64 ++---- .../catalyst/plans/logical/basicOperators.scala | 32 +-- .../expressions/ExpressionEvalHelper.scala | 4 +- .../catalyst/expressions/MathFunctionsSuite.scala | 3 +- .../catalyst/expressions/NullFunctionsSuite.scala | 49 +---- .../apache/spark/sql/DataFrameNaFunctions.scala | 2 +- .../main/scala/org/apache/spark/sql/SQLConf.scala | 6 - .../scala/org/apache/spark/sql/SQLContext.scala | 5 +- .../optimizer/extendedOperatorOptimizations.scala | 160 -------------- .../sql/optimizer/FilterNullsInJoinKeySuite.scala | 236 --------------------- 11 files changed, 37 insertions(+), 572 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index d58c475693..287718fab7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -210,58 +210,14 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { } } -/** - * A predicate that is evaluated to be true if there are at least `n` null values. - */ -case class AtLeastNNulls(n: Int, children: Seq[Expression]) extends Predicate { - override def nullable: Boolean = false - override def foldable: Boolean = children.forall(_.foldable) - override def toString: String = s"AtLeastNNulls($n, ${children.mkString(",")})" - - private[this] val childrenArray = children.toArray - - override def eval(input: InternalRow): Boolean = { - var numNulls = 0 - var i = 0 - while (i < childrenArray.length && numNulls < n) { - val evalC = childrenArray(i).eval(input) - if (evalC == null) { - numNulls += 1 - } - i += 1 - } - numNulls >= n - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val numNulls = ctx.freshName("numNulls") - val code = children.map { e => - val eval = e.gen(ctx) - s""" - if ($numNulls < $n) { - ${eval.code} - if (${eval.isNull}) { - $numNulls += 1; - } - } - """ - }.mkString("\n") - s""" - int $numNulls = 0; - $code - boolean ${ev.isNull} = false; - boolean ${ev.primitive} = $numNulls >= $n; - """ - } -} /** * A predicate that is evaluated to be true if there are at least `n` non-null and non-NaN values. */ -case class AtLeastNNonNullNans(n: Int, children: Seq[Expression]) extends Predicate { +case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate { override def nullable: Boolean = false override def foldable: Boolean = children.forall(_.foldable) - override def toString: String = s"AtLeastNNonNullNans($n, ${children.mkString(",")})" + override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})" private[this] val childrenArray = children.toArray diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e4b6294dc7..29d706dcb3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -31,14 +31,8 @@ import org.apache.spark.sql.types._ abstract class Optimizer extends RuleExecutor[LogicalPlan] -class DefaultOptimizer extends Optimizer { - - /** - * Override to provide additional rules for the "Operator Optimizations" batch. - */ - val extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil - - lazy val batches = +object DefaultOptimizer extends Optimizer { + val batches = // SubQueries are only needed for analysis and can be removed before execution. Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: @@ -47,27 +41,26 @@ class DefaultOptimizer extends Optimizer { RemoveLiteralFromGroupExpressions) :: Batch("Operator Optimizations", FixedPoint(100), // Operator push down - SetOperationPushDown :: - SamplePushDown :: - PushPredicateThroughJoin :: - PushPredicateThroughProject :: - PushPredicateThroughGenerate :: - ColumnPruning :: + SetOperationPushDown, + SamplePushDown, + PushPredicateThroughJoin, + PushPredicateThroughProject, + PushPredicateThroughGenerate, + ColumnPruning, // Operator combine - ProjectCollapsing :: - CombineFilters :: - CombineLimits :: + ProjectCollapsing, + CombineFilters, + CombineLimits, // Constant folding - NullPropagation :: - OptimizeIn :: - ConstantFolding :: - LikeSimplification :: - BooleanSimplification :: - RemovePositive :: - SimplifyFilters :: - SimplifyCasts :: - SimplifyCaseConversionExpressions :: - extendedOperatorOptimizationRules.toList : _*) :: + NullPropagation, + OptimizeIn, + ConstantFolding, + LikeSimplification, + BooleanSimplification, + RemovePositive, + SimplifyFilters, + SimplifyCasts, + SimplifyCaseConversionExpressions) :: Batch("Decimal Optimizations", FixedPoint(100), DecimalAggregates) :: Batch("LocalRelation", FixedPoint(100), @@ -229,18 +222,12 @@ object ColumnPruning extends Rule[LogicalPlan] { } /** Applies a projection only when the child is producing unnecessary attributes */ - private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = { + private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) { - // We need to preserve the nullability of c's output. - // So, we first create a outputMap and if a reference is from the output of - // c, we use that output attribute from c. - val outputMap = AttributeMap(c.output.map(attr => (attr, attr))) - val projectList = allReferences.filter(outputMap.contains).map(outputMap).toSeq - Project(projectList, c) + Project(allReferences.filter(c.outputSet.contains).toSeq, c) } else { c } - } } /** @@ -530,13 +517,6 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { */ object CombineFilters extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Filter(Not(AtLeastNNulls(1, e1)), Filter(Not(AtLeastNNulls(1, e2)), grandChild)) => - // If we are combining two expressions Not(AtLeastNNulls(1, e1)) and - // Not(AtLeastNNulls(1, e2)) - // (this is used to make sure there is no null in the result of e1 and e2 and - // they are added by FilterNullsInJoinKey optimziation rule), we can - // just create a Not(AtLeastNNulls(1, (e1 ++ e2).distinct)). - Filter(Not(AtLeastNNulls(1, (e1 ++ e2).distinct)), grandChild) case ff @ Filter(fc, nf @ Filter(nc, grandChild)) => Filter(And(nc, fc), grandChild) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 54b5f49772..aacfc86ab0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -86,37 +86,7 @@ case class Generate( } case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { - /** - * Indicates if `atLeastNNulls` is used to check if atLeastNNulls.children - * have at least one null value and atLeastNNulls.children are all attributes. - */ - private def isAtLeastOneNullOutputAttributes(atLeastNNulls: AtLeastNNulls): Boolean = { - val expressions = atLeastNNulls.children - val n = atLeastNNulls.n - if (n != 1) { - // AtLeastNNulls is not used to check if atLeastNNulls.children have - // at least one null value. - false - } else { - // AtLeastNNulls is used to check if atLeastNNulls.children have - // at least one null value. We need to make sure all atLeastNNulls.children - // are attributes. - expressions.forall(_.isInstanceOf[Attribute]) - } - } - - override def output: Seq[Attribute] = condition match { - case Not(a: AtLeastNNulls) if isAtLeastOneNullOutputAttributes(a) => - // The condition is used to make sure that there is no null value in - // a.children. - val nonNullableAttributes = AttributeSet(a.children.asInstanceOf[Seq[Attribute]]) - child.output.map { - case attr if nonNullableAttributes.contains(attr) => - attr.withNullability(false) - case attr => attr - } - case _ => child.output - } + override def output: Seq[Attribute] = child.output } case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 3e55151298..a41185b4d8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -31,8 +31,6 @@ import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} trait ExpressionEvalHelper { self: SparkFunSuite => - protected val defaultOptimizer = new DefaultOptimizer - protected def create_row(values: Any*): InternalRow = { InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst)) } @@ -188,7 +186,7 @@ trait ExpressionEvalHelper { expected: Any, inputRow: InternalRow = EmptyRow): Unit = { val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) - val optimizedPlan = defaultOptimizer.execute(plan) + val optimizedPlan = DefaultOptimizer.execute(plan) checkEvaluationWithoutCodegen(optimizedPlan.expressions.head, expected, inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 649a5b44dc..9fcb548af6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.types._ @@ -148,7 +149,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) - val optimizedPlan = defaultOptimizer.execute(plan) + val optimizedPlan = DefaultOptimizer.execute(plan) checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala index bf197124d8..ace6c15dc8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala @@ -77,7 +77,7 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } - test("AtLeastNNonNullNans") { + test("AtLeastNNonNulls") { val mix = Seq(Literal("x"), Literal.create(null, StringType), Literal.create(null, DoubleType), @@ -96,46 +96,11 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { Literal(Float.MaxValue), Literal(false)) - checkEvaluation(AtLeastNNonNullNans(0, mix), true, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(2, mix), true, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(3, mix), false, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(0, nanOnly), true, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(3, nanOnly), true, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(4, nanOnly), false, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(0, nullOnly), true, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(3, nullOnly), true, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(4, nullOnly), false, EmptyRow) - } - - test("AtLeastNNull") { - val mix = Seq(Literal("x"), - Literal.create(null, StringType), - Literal.create(null, DoubleType), - Literal(Double.NaN), - Literal(5f)) - - val nanOnly = Seq(Literal("x"), - Literal(10.0), - Literal(Float.NaN), - Literal(math.log(-2)), - Literal(Double.MaxValue)) - - val nullOnly = Seq(Literal("x"), - Literal.create(null, DoubleType), - Literal.create(null, DecimalType.USER_DEFAULT), - Literal(Float.MaxValue), - Literal(false)) - - checkEvaluation(AtLeastNNulls(0, mix), true, EmptyRow) - checkEvaluation(AtLeastNNulls(1, mix), true, EmptyRow) - checkEvaluation(AtLeastNNulls(2, mix), true, EmptyRow) - checkEvaluation(AtLeastNNulls(3, mix), false, EmptyRow) - checkEvaluation(AtLeastNNulls(0, nanOnly), true, EmptyRow) - checkEvaluation(AtLeastNNulls(1, nanOnly), false, EmptyRow) - checkEvaluation(AtLeastNNulls(2, nanOnly), false, EmptyRow) - checkEvaluation(AtLeastNNulls(0, nullOnly), true, EmptyRow) - checkEvaluation(AtLeastNNulls(1, nullOnly), true, EmptyRow) - checkEvaluation(AtLeastNNulls(2, nullOnly), true, EmptyRow) - checkEvaluation(AtLeastNNulls(3, nullOnly), false, EmptyRow) + checkEvaluation(AtLeastNNonNulls(2, mix), true, EmptyRow) + checkEvaluation(AtLeastNNonNulls(3, mix), false, EmptyRow) + checkEvaluation(AtLeastNNonNulls(3, nanOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNulls(4, nanOnly), false, EmptyRow) + checkEvaluation(AtLeastNNonNulls(3, nullOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNulls(4, nullOnly), false, EmptyRow) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index ea85f0657a..a4fd4cf3b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -122,7 +122,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = { // Filtering condition: // only keep the row if it has at least `minNonNulls` non-null and non-NaN values. - val predicate = AtLeastNNonNullNans(minNonNulls, cols.map(name => df.resolve(name))) + val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name))) df.filter(Column(predicate)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 41ba1c7fe0..f836122b3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -413,10 +413,6 @@ private[spark] object SQLConf { "spark.sql.useSerializer2", defaultValue = Some(true), isPublic = false) - val ADVANCED_SQL_OPTIMIZATION = booleanConf( - "spark.sql.advancedOptimization", - defaultValue = Some(true), isPublic = false) - object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -488,8 +484,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2) - private[spark] def advancedSqlOptimizations: Boolean = getConf(ADVANCED_SQL_OPTIMIZATION) - private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) private[spark] def defaultSizeInBytes: Long = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 31e2b508d4..dbb2a09846 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -41,7 +41,6 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.optimizer.FilterNullsInJoinKey import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -157,9 +156,7 @@ class SQLContext(@transient val sparkContext: SparkContext) } @transient - protected[sql] lazy val optimizer: Optimizer = new DefaultOptimizer { - override val extendedOperatorOptimizationRules = FilterNullsInJoinKey(self) :: Nil - } + protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer @transient protected[sql] val ddlParser = new DDLParser(sqlParser.parse(_)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala b/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala deleted file mode 100644 index 5a4dde5756..0000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.optimizer - -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys -import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter, LeftSemi} -import org.apache.spark.sql.catalyst.plans.logical.{Project, Filter, Join, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.Rule - -/** - * An optimization rule used to insert Filters to filter out rows whose equal join keys - * have at least one null values. For this kind of rows, they will not contribute to - * the join results of equal joins because a null does not equal another null. We can - * filter them out before shuffling join input rows. For example, we have two tables - * - * table1(key String, value Int) - * "str1"|1 - * null |2 - * - * table2(key String, value Int) - * "str1"|3 - * null |4 - * - * For a inner equal join, the result will be - * "str1"|1|"str1"|3 - * - * those two rows having null as the value of key will not contribute to the result. - * So, we can filter them out early. - * - * This optimization rule can be disabled by setting spark.sql.advancedOptimization to false. - * - */ -case class FilterNullsInJoinKey( - sqlContext: SQLContext) - extends Rule[LogicalPlan] { - - /** - * Checks if we need to add a Filter operator. We will add a Filter when - * there is any attribute in `keys` whose corresponding attribute of `keys` - * in `plan.output` is still nullable (`nullable` field is `true`). - */ - private def needsFilter(keys: Seq[Expression], plan: LogicalPlan): Boolean = { - val keyAttributeSet = AttributeSet(keys.filter(_.isInstanceOf[Attribute])) - plan.output.filter(keyAttributeSet.contains).exists(_.nullable) - } - - /** - * Adds a Filter operator to make sure that every attribute in `keys` is non-nullable. - */ - private def addFilterIfNecessary( - keys: Seq[Expression], - child: LogicalPlan): LogicalPlan = { - // We get all attributes from keys. - val attributes = keys.filter(_.isInstanceOf[Attribute]) - - // Then, we create a Filter to make sure these attributes are non-nullable. - val filter = - if (attributes.nonEmpty) { - Filter(Not(AtLeastNNulls(1, attributes)), child) - } else { - child - } - - filter - } - - /** - * We reconstruct the join condition. - */ - private def reconstructJoinCondition( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - otherPredicate: Option[Expression]): Expression = { - // First, we rewrite the equal condition part. When we extract those keys, - // we use splitConjunctivePredicates. So, it is safe to use .reduce(And). - val rewrittenEqualJoinCondition = leftKeys.zip(rightKeys).map { - case (l, r) => EqualTo(l, r) - }.reduce(And) - - // Then, we add otherPredicate. When we extract those equal condition part, - // we use splitConjunctivePredicates. So, it is safe to use - // And(rewrittenEqualJoinCondition, c). - val rewrittenJoinCondition = otherPredicate - .map(c => And(rewrittenEqualJoinCondition, c)) - .getOrElse(rewrittenEqualJoinCondition) - - rewrittenJoinCondition - } - - def apply(plan: LogicalPlan): LogicalPlan = { - if (!sqlContext.conf.advancedSqlOptimizations) { - plan - } else { - plan transform { - case join: Join => join match { - // For a inner join having equal join condition part, we can add filters - // to both sides of the join operator. - case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) - if needsFilter(leftKeys, left) || needsFilter(rightKeys, right) => - val withLeftFilter = addFilterIfNecessary(leftKeys, left) - val withRightFilter = addFilterIfNecessary(rightKeys, right) - val rewrittenJoinCondition = - reconstructJoinCondition(leftKeys, rightKeys, condition) - - Join(withLeftFilter, withRightFilter, Inner, Some(rewrittenJoinCondition)) - - // For a left outer join having equal join condition part, we can add a filter - // to the right side of the join operator. - case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right) - if needsFilter(rightKeys, right) => - val withRightFilter = addFilterIfNecessary(rightKeys, right) - val rewrittenJoinCondition = - reconstructJoinCondition(leftKeys, rightKeys, condition) - - Join(left, withRightFilter, LeftOuter, Some(rewrittenJoinCondition)) - - // For a right outer join having equal join condition part, we can add a filter - // to the left side of the join operator. - case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right) - if needsFilter(leftKeys, left) => - val withLeftFilter = addFilterIfNecessary(leftKeys, left) - val rewrittenJoinCondition = - reconstructJoinCondition(leftKeys, rightKeys, condition) - - Join(withLeftFilter, right, RightOuter, Some(rewrittenJoinCondition)) - - // For a left semi join having equal join condition part, we can add filters - // to both sides of the join operator. - case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) - if needsFilter(leftKeys, left) || needsFilter(rightKeys, right) => - val withLeftFilter = addFilterIfNecessary(leftKeys, left) - val withRightFilter = addFilterIfNecessary(rightKeys, right) - val rewrittenJoinCondition = - reconstructJoinCondition(leftKeys, rightKeys, condition) - - Join(withLeftFilter, withRightFilter, LeftSemi, Some(rewrittenJoinCondition)) - - case other => other - } - } - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala deleted file mode 100644 index f98e4acafb..0000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala +++ /dev/null @@ -1,236 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.optimizer - -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Not, AtLeastNNulls} -import org.apache.spark.sql.catalyst.optimizer._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.test.TestSQLContext - -/** This is the test suite for FilterNullsInJoinKey optimization rule. */ -class FilterNullsInJoinKeySuite extends PlanTest { - - // We add predicate pushdown rules at here to make sure we do not - // create redundant Filter operators. Also, because the attribute ordering of - // the Project operator added by ColumnPruning may be not deterministic - // (the ordering may depend on the testing environment), - // we first construct the plan with expected Filter operators and then - // run the optimizer to add the the Project for column pruning. - object Optimize extends RuleExecutor[LogicalPlan] { - val batches = - Batch("Subqueries", Once, - EliminateSubQueries) :: - Batch("Operator Optimizations", FixedPoint(100), - FilterNullsInJoinKey(TestSQLContext), // This is the rule we test in this suite. - CombineFilters, - PushPredicateThroughProject, - BooleanSimplification, - PushPredicateThroughJoin, - PushPredicateThroughGenerate, - ColumnPruning, - ProjectCollapsing) :: Nil - } - - val leftRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.int) - - val rightRelation = LocalRelation('e.int, 'f.int, 'g.int, 'h.int) - - test("inner join") { - val joinCondition = - ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) - - val joinedPlan = - leftRelation - .join(rightRelation, Inner, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - val optimized = Optimize.execute(joinedPlan.analyze) - - // For an inner join, FilterNullsInJoinKey add filter to both side. - val correctLeft = - leftRelation - .where(!(AtLeastNNulls(1, 'a.expr :: Nil))) - - val correctRight = - rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) - - val correctAnswer = - correctLeft - .join(correctRight, Inner, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) - } - - test("make sure we do not keep adding filters") { - val thirdRelation = LocalRelation('i.int, 'j.int, 'k.int, 'l.int) - val joinedPlan = - leftRelation - .join(rightRelation, Inner, Some('a === 'e)) - .join(thirdRelation, Inner, Some('b === 'i && 'a === 'j)) - - val optimized = Optimize.execute(joinedPlan.analyze) - val conditions = optimized.collect { - case Filter(condition @ Not(AtLeastNNulls(1, exprs)), _) => exprs - } - - // Make sure that we have three Not(AtLeastNNulls(1, exprs)) for those three tables. - assert(conditions.length === 3) - - // Make sure attribtues are indeed a, b, e, i, and j. - assert( - conditions.flatMap(exprs => exprs).toSet === - joinedPlan.select('a, 'b, 'e, 'i, 'j).analyze.output.toSet) - } - - test("inner join (partially optimized)") { - val joinCondition = - ('a + 2 === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) - - val joinedPlan = - leftRelation - .join(rightRelation, Inner, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - val optimized = Optimize.execute(joinedPlan.analyze) - - // We cannot extract attribute from the left join key. - val correctRight = - rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) - - val correctAnswer = - leftRelation - .join(correctRight, Inner, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) - } - - test("inner join (not optimized)") { - val nonOptimizedJoinConditions = - Some('c - 100 + 'd === 'g + 1 - 'h) :: - Some('d > 'h || 'c === 'g) :: - Some('d + 'g + 'c > 'd - 'h) :: Nil - - nonOptimizedJoinConditions.foreach { joinCondition => - val joinedPlan = - leftRelation - .join(rightRelation.select('f, 'g, 'h), Inner, joinCondition) - .select('a, 'c, 'f, 'd, 'h, 'g) - - val optimized = Optimize.execute(joinedPlan.analyze) - - comparePlans(optimized, Optimize.execute(joinedPlan.analyze)) - } - } - - test("left outer join") { - val joinCondition = - ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) - - val joinedPlan = - leftRelation - .join(rightRelation, LeftOuter, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - val optimized = Optimize.execute(joinedPlan.analyze) - - // For a left outer join, FilterNullsInJoinKey add filter to the right side. - val correctRight = - rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) - - val correctAnswer = - leftRelation - .join(correctRight, LeftOuter, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) - } - - test("right outer join") { - val joinCondition = - ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) - - val joinedPlan = - leftRelation - .join(rightRelation, RightOuter, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - val optimized = Optimize.execute(joinedPlan.analyze) - - // For a right outer join, FilterNullsInJoinKey add filter to the left side. - val correctLeft = - leftRelation - .where(!(AtLeastNNulls(1, 'a.expr :: Nil))) - - val correctAnswer = - correctLeft - .join(rightRelation, RightOuter, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - - comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) - } - - test("full outer join") { - val joinCondition = - ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) - - val joinedPlan = - leftRelation - .join(rightRelation, FullOuter, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - // FilterNullsInJoinKey does not fire for a full outer join. - val optimized = Optimize.execute(joinedPlan.analyze) - - comparePlans(optimized, Optimize.execute(joinedPlan.analyze)) - } - - test("left semi join") { - val joinCondition = - ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) - - val joinedPlan = - leftRelation - .join(rightRelation, LeftSemi, Some(joinCondition)) - .select('a, 'd) - - val optimized = Optimize.execute(joinedPlan.analyze) - - // For a left semi join, FilterNullsInJoinKey add filter to both side. - val correctLeft = - leftRelation - .where(!(AtLeastNNulls(1, 'a.expr :: Nil))) - - val correctRight = - rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) - - val correctAnswer = - correctLeft - .join(correctRight, LeftSemi, Some(joinCondition)) - .select('a, 'd) - - comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) - } -} -- cgit v1.2.3