-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala  48
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala  64
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala  32
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala  4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala  3
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala  49
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala  6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala  5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala  160
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala  236
11 files changed, 37 insertions, 572 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
index d58c475693..287718fab7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
@@ -210,58 +210,14 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
}
}
-/**
- * A predicate that evaluates to true if at least `n` of its children evaluate to null.
- */
-case class AtLeastNNulls(n: Int, children: Seq[Expression]) extends Predicate {
- override def nullable: Boolean = false
- override def foldable: Boolean = children.forall(_.foldable)
- override def toString: String = s"AtLeastNNulls($n, ${children.mkString(",")})"
-
- private[this] val childrenArray = children.toArray
-
- override def eval(input: InternalRow): Boolean = {
- var numNulls = 0
- var i = 0
- while (i < childrenArray.length && numNulls < n) {
- val evalC = childrenArray(i).eval(input)
- if (evalC == null) {
- numNulls += 1
- }
- i += 1
- }
- numNulls >= n
- }
-
- override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val numNulls = ctx.freshName("numNulls")
- val code = children.map { e =>
- val eval = e.gen(ctx)
- s"""
- if ($numNulls < $n) {
- ${eval.code}
- if (${eval.isNull}) {
- $numNulls += 1;
- }
- }
- """
- }.mkString("\n")
- s"""
- int $numNulls = 0;
- $code
- boolean ${ev.isNull} = false;
- boolean ${ev.primitive} = $numNulls >= $n;
- """
- }
-}
/**
* A predicate that evaluates to true if at least `n` of its children evaluate to non-null, non-NaN values.
*/
-case class AtLeastNNonNullNans(n: Int, children: Seq[Expression]) extends Predicate {
+case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate {
override def nullable: Boolean = false
override def foldable: Boolean = children.forall(_.foldable)
- override def toString: String = s"AtLeastNNonNullNans($n, ${children.mkString(",")})"
+ override def toString: String = s"AtLeastNNonNulls($n, ${children.mkString(",")})"
private[this] val childrenArray = children.toArray
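[Editor's note] The renamed predicate counts children that evaluate to something other than null or NaN. A minimal model of its semantics in plain Scala (illustrative only; the real implementation walks an array of child expressions and short-circuits once `n` qualifying values are found):

    // Simplified stand-in for AtLeastNNonNulls.eval: count values that are
    // neither null nor a NaN double/float.
    def atLeastNNonNulls(n: Int, values: Seq[Any]): Boolean = {
      val qualifying = values.count {
        case null => false
        case d: Double => !d.isNaN  // NaN doubles do not count
        case f: Float => !f.isNaN   // NaN floats do not count
        case _ => true
      }
      qualifying >= n
    }

    // Mirrors the updated test expectations later in this diff:
    assert(atLeastNNonNulls(2, Seq("x", null, Double.NaN, 5f)))   // "x" and 5f qualify
    assert(!atLeastNNonNulls(3, Seq("x", null, Double.NaN, 5f)))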
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e4b6294dc7..29d706dcb3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -31,14 +31,8 @@ import org.apache.spark.sql.types._
abstract class Optimizer extends RuleExecutor[LogicalPlan]
-class DefaultOptimizer extends Optimizer {
-
- /**
- * Override to provide additional rules for the "Operator Optimizations" batch.
- */
- val extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil
-
- lazy val batches =
+object DefaultOptimizer extends Optimizer {
+ val batches =
// SubQueries are only needed for analysis and can be removed before execution.
Batch("Remove SubQueries", FixedPoint(100),
EliminateSubQueries) ::
@@ -47,27 +41,26 @@ class DefaultOptimizer extends Optimizer {
RemoveLiteralFromGroupExpressions) ::
Batch("Operator Optimizations", FixedPoint(100),
// Operator push down
- SetOperationPushDown ::
- SamplePushDown ::
- PushPredicateThroughJoin ::
- PushPredicateThroughProject ::
- PushPredicateThroughGenerate ::
- ColumnPruning ::
+ SetOperationPushDown,
+ SamplePushDown,
+ PushPredicateThroughJoin,
+ PushPredicateThroughProject,
+ PushPredicateThroughGenerate,
+ ColumnPruning,
// Operator combine
- ProjectCollapsing ::
- CombineFilters ::
- CombineLimits ::
+ ProjectCollapsing,
+ CombineFilters,
+ CombineLimits,
// Constant folding
- NullPropagation ::
- OptimizeIn ::
- ConstantFolding ::
- LikeSimplification ::
- BooleanSimplification ::
- RemovePositive ::
- SimplifyFilters ::
- SimplifyCasts ::
- SimplifyCaseConversionExpressions ::
- extendedOperatorOptimizationRules.toList : _*) ::
+ NullPropagation,
+ OptimizeIn,
+ ConstantFolding,
+ LikeSimplification,
+ BooleanSimplification,
+ RemovePositive,
+ SimplifyFilters,
+ SimplifyCasts,
+ SimplifyCaseConversionExpressions) ::
Batch("Decimal Optimizations", FixedPoint(100),
DecimalAggregates) ::
Batch("LocalRelation", FixedPoint(100),
@@ -229,18 +222,12 @@ object ColumnPruning extends Rule[LogicalPlan] {
}
/** Applies a projection only when the child is producing unnecessary attributes */
- private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = {
+ private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) {
- // We need to preserve the nullability of c's output.
- // So, we first create an outputMap and if a reference is from the output of
- // c, we use that output attribute from c.
- val outputMap = AttributeMap(c.output.map(attr => (attr, attr)))
- val projectList = allReferences.filter(outputMap.contains).map(outputMap).toSeq
- Project(projectList, c)
+ Project(allReferences.filter(c.outputSet.contains).toSeq, c)
} else {
c
}
- }
}
/**
@@ -530,13 +517,6 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
*/
object CombineFilters extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case Filter(Not(AtLeastNNulls(1, e1)), Filter(Not(AtLeastNNulls(1, e2)), grandChild)) =>
- // If we are combining two expressions Not(AtLeastNNulls(1, e1)) and
- // Not(AtLeastNNulls(1, e2))
- // (this is used to make sure there is no null in the result of e1 and e2 and
- // they are added by the FilterNullsInJoinKey optimization rule), we can
- // just create a Not(AtLeastNNulls(1, (e1 ++ e2).distinct)).
- Filter(Not(AtLeastNNulls(1, (e1 ++ e2).distinct)), grandChild)
case ff @ Filter(fc, nf @ Filter(nc, grandChild)) => Filter(And(nc, fc), grandChild)
}
}
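[Editor's note] With the `extendedOperatorOptimizationRules` hook removed, the rule list no longer needs to be built as a `List` and spliced into `Batch` with `: _*`; the rules become plain varargs. A sketch of the two styles using simplified stand-in types (not the actual Catalyst classes):

    case class Rule(name: String)
    case class Batch(name: String, rules: Rule*)

    val builtIn  = List(Rule("NullPropagation"), Rule("ConstantFolding"))
    val extended = List.empty[Rule]  // the removed extension point

    // Extensible style: concatenate lists, then splice into the varargs parameter.
    val before = Batch("Operator Optimizations", (builtIn ++ extended): _*)

    // Fixed style: enumerate the rules directly.
    val after = Batch("Operator Optimizations", Rule("NullPropagation"), Rule("ConstantFolding"))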
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 54b5f49772..aacfc86ab0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -86,37 +86,7 @@ case class Generate(
}
case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
- /**
- * Indicates whether `atLeastNNulls` is used to check that atLeastNNulls.children
- * have at least one null value, and that all of atLeastNNulls.children are attributes.
- */
- private def isAtLeastOneNullOutputAttributes(atLeastNNulls: AtLeastNNulls): Boolean = {
- val expressions = atLeastNNulls.children
- val n = atLeastNNulls.n
- if (n != 1) {
- // AtLeastNNulls is not used to check if atLeastNNulls.children have
- // at least one null value.
- false
- } else {
- // AtLeastNNulls is used to check if atLeastNNulls.children have
- // at least one null value. We need to make sure all atLeastNNulls.children
- // are attributes.
- expressions.forall(_.isInstanceOf[Attribute])
- }
- }
-
- override def output: Seq[Attribute] = condition match {
- case Not(a: AtLeastNNulls) if isAtLeastOneNullOutputAttributes(a) =>
- // The condition is used to make sure that there is no null value in
- // a.children.
- val nonNullableAttributes = AttributeSet(a.children.asInstanceOf[Seq[Attribute]])
- child.output.map {
- case attr if nonNullableAttributes.contains(attr) =>
- attr.withNullability(false)
- case attr => attr
- }
- case _ => child.output
- }
+ override def output: Seq[Attribute] = child.output
}
case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
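[Editor's note] The removed logic let a Filter narrow the nullability of its output when its condition (a `Not(AtLeastNNulls(1, attrs))` over bare attributes) guaranteed those attributes were never null; after the revert, Filter simply forwards `child.output`. A simplified model of the reverted behavior (illustrative types, not Catalyst's `Attribute`):

    case class Attr(name: String, nullable: Boolean)

    // If the filter condition guarantees some attributes are never null,
    // those attributes can be reported as non-nullable downstream.
    def narrowedOutput(guaranteedNonNull: Set[String], childOutput: Seq[Attr]): Seq[Attr] =
      childOutput.map { a =>
        if (guaranteedNonNull.contains(a.name)) a.copy(nullable = false) else a
      }

    narrowedOutput(Set("a"), Seq(Attr("a", true), Attr("b", true)))
    // -> List(Attr(a,false), Attr(b,true))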
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 3e55151298..a41185b4d8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -31,8 +31,6 @@ import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
trait ExpressionEvalHelper {
self: SparkFunSuite =>
- protected val defaultOptimizer = new DefaultOptimizer
-
protected def create_row(values: Any*): InternalRow = {
InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst))
}
@@ -188,7 +186,7 @@ trait ExpressionEvalHelper {
expected: Any,
inputRow: InternalRow = EmptyRow): Unit = {
val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation)
- val optimizedPlan = defaultOptimizer.execute(plan)
+ val optimizedPlan = DefaultOptimizer.execute(plan)
checkEvaluationWithoutCodegen(optimizedPlan.expressions.head, expected, inputRow)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 649a5b44dc..9fcb548af6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection}
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
import org.apache.spark.sql.types._
@@ -148,7 +149,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
expression: Expression,
inputRow: InternalRow = EmptyRow): Unit = {
val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation)
- val optimizedPlan = defaultOptimizer.execute(plan)
+ val optimizedPlan = DefaultOptimizer.execute(plan)
checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
index bf197124d8..ace6c15dc8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
@@ -77,7 +77,7 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
- test("AtLeastNNonNullNans") {
+ test("AtLeastNNonNulls") {
val mix = Seq(Literal("x"),
Literal.create(null, StringType),
Literal.create(null, DoubleType),
@@ -96,46 +96,11 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
Literal(Float.MaxValue),
Literal(false))
- checkEvaluation(AtLeastNNonNullNans(0, mix), true, EmptyRow)
- checkEvaluation(AtLeastNNonNullNans(2, mix), true, EmptyRow)
- checkEvaluation(AtLeastNNonNullNans(3, mix), false, EmptyRow)
- checkEvaluation(AtLeastNNonNullNans(0, nanOnly), true, EmptyRow)
- checkEvaluation(AtLeastNNonNullNans(3, nanOnly), true, EmptyRow)
- checkEvaluation(AtLeastNNonNullNans(4, nanOnly), false, EmptyRow)
- checkEvaluation(AtLeastNNonNullNans(0, nullOnly), true, EmptyRow)
- checkEvaluation(AtLeastNNonNullNans(3, nullOnly), true, EmptyRow)
- checkEvaluation(AtLeastNNonNullNans(4, nullOnly), false, EmptyRow)
- }
-
- test("AtLeastNNull") {
- val mix = Seq(Literal("x"),
- Literal.create(null, StringType),
- Literal.create(null, DoubleType),
- Literal(Double.NaN),
- Literal(5f))
-
- val nanOnly = Seq(Literal("x"),
- Literal(10.0),
- Literal(Float.NaN),
- Literal(math.log(-2)),
- Literal(Double.MaxValue))
-
- val nullOnly = Seq(Literal("x"),
- Literal.create(null, DoubleType),
- Literal.create(null, DecimalType.USER_DEFAULT),
- Literal(Float.MaxValue),
- Literal(false))
-
- checkEvaluation(AtLeastNNulls(0, mix), true, EmptyRow)
- checkEvaluation(AtLeastNNulls(1, mix), true, EmptyRow)
- checkEvaluation(AtLeastNNulls(2, mix), true, EmptyRow)
- checkEvaluation(AtLeastNNulls(3, mix), false, EmptyRow)
- checkEvaluation(AtLeastNNulls(0, nanOnly), true, EmptyRow)
- checkEvaluation(AtLeastNNulls(1, nanOnly), false, EmptyRow)
- checkEvaluation(AtLeastNNulls(2, nanOnly), false, EmptyRow)
- checkEvaluation(AtLeastNNulls(0, nullOnly), true, EmptyRow)
- checkEvaluation(AtLeastNNulls(1, nullOnly), true, EmptyRow)
- checkEvaluation(AtLeastNNulls(2, nullOnly), true, EmptyRow)
- checkEvaluation(AtLeastNNulls(3, nullOnly), false, EmptyRow)
+ checkEvaluation(AtLeastNNonNulls(2, mix), true, EmptyRow)
+ checkEvaluation(AtLeastNNonNulls(3, mix), false, EmptyRow)
+ checkEvaluation(AtLeastNNonNulls(3, nanOnly), true, EmptyRow)
+ checkEvaluation(AtLeastNNonNulls(4, nanOnly), false, EmptyRow)
+ checkEvaluation(AtLeastNNonNulls(3, nullOnly), true, EmptyRow)
+ checkEvaluation(AtLeastNNonNulls(4, nullOnly), false, EmptyRow)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index ea85f0657a..a4fd4cf3b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -122,7 +122,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = {
// Filtering condition:
// only keep the row if it has at least `minNonNulls` non-null and non-NaN values.
- val predicate = AtLeastNNonNullNans(minNonNulls, cols.map(name => df.resolve(name)))
+ val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name)))
df.filter(Column(predicate))
}
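[Editor's note] This predicate backs the public `DataFrame.na.drop(minNonNulls, cols)` API. A usage sketch, where `df` is an existing DataFrame and the column names are illustrative:

    // Keep only rows that have at least 2 non-null, non-NaN values among
    // the listed columns; rows falling short are dropped.
    val cleaned = df.na.drop(2, Seq("age", "height", "weight"))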
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 41ba1c7fe0..f836122b3e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -413,10 +413,6 @@ private[spark] object SQLConf {
"spark.sql.useSerializer2",
defaultValue = Some(true), isPublic = false)
- val ADVANCED_SQL_OPTIMIZATION = booleanConf(
- "spark.sql.advancedOptimization",
- defaultValue = Some(true), isPublic = false)
-
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
}
@@ -488,8 +484,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2)
- private[spark] def advancedSqlOptimizations: Boolean = getConf(ADVANCED_SQL_OPTIMIZATION)
-
private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
private[spark] def defaultSizeInBytes: Long =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 31e2b508d4..dbb2a09846 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -41,7 +41,6 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.optimizer.FilterNullsInJoinKey
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -157,9 +156,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
}
@transient
- protected[sql] lazy val optimizer: Optimizer = new DefaultOptimizer {
- override val extendedOperatorOptimizationRules = FilterNullsInJoinKey(self) :: Nil
- }
+ protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer
@transient
protected[sql] val ddlParser = new DDLParser(sqlParser.parse(_))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala b/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala
deleted file mode 100644
index 5a4dde5756..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala
+++ /dev/null
@@ -1,160 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.optimizer
-
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
-import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter, LeftSemi}
-import org.apache.spark.sql.catalyst.plans.logical.{Project, Filter, Join, LogicalPlan}
-import org.apache.spark.sql.catalyst.rules.Rule
-
-/**
- * An optimization rule that inserts Filters to remove rows whose equi-join keys
- * contain at least one null value. Such rows cannot contribute to the result of an
- * equi-join, because a null does not equal another null. We can therefore
- * filter them out before shuffling the join's input rows. For example, given two tables
- *
- * table1(key String, value Int)
- * "str1"|1
- * null |2
- *
- * table2(key String, value Int)
- * "str1"|3
- * null |4
- *
- * For an inner equi-join, the result will be
- * "str1"|1|"str1"|3
- *
- * The two rows with a null key do not contribute to the result.
- * So, we can filter them out early.
- *
- * This optimization rule can be disabled by setting spark.sql.advancedOptimization to false.
- *
- */
-case class FilterNullsInJoinKey(
- sqlContext: SQLContext)
- extends Rule[LogicalPlan] {
-
- /**
- * Checks if we need to add a Filter operator. We add a Filter when any
- * attribute in `keys` is still nullable in `plan.output` (its `nullable`
- * field is `true`).
- */
- private def needsFilter(keys: Seq[Expression], plan: LogicalPlan): Boolean = {
- val keyAttributeSet = AttributeSet(keys.filter(_.isInstanceOf[Attribute]))
- plan.output.filter(keyAttributeSet.contains).exists(_.nullable)
- }
-
- /**
- * Adds a Filter operator to make sure that every attribute in `keys` is non-nullable.
- */
- private def addFilterIfNecessary(
- keys: Seq[Expression],
- child: LogicalPlan): LogicalPlan = {
- // We get all attributes from keys.
- val attributes = keys.filter(_.isInstanceOf[Attribute])
-
- // Then, we create a Filter to make sure these attributes are non-nullable.
- val filter =
- if (attributes.nonEmpty) {
- Filter(Not(AtLeastNNulls(1, attributes)), child)
- } else {
- child
- }
-
- filter
- }
-
- /**
- * We reconstruct the join condition.
- */
- private def reconstructJoinCondition(
- leftKeys: Seq[Expression],
- rightKeys: Seq[Expression],
- otherPredicate: Option[Expression]): Expression = {
- // First, we rewrite the equal condition part. When we extract those keys,
- // we use splitConjunctivePredicates. So, it is safe to use .reduce(And).
- val rewrittenEqualJoinCondition = leftKeys.zip(rightKeys).map {
- case (l, r) => EqualTo(l, r)
- }.reduce(And)
-
- // Then, we add otherPredicate. When we extract the equal-condition part,
- // we use splitConjunctivePredicates. So, it is safe to use
- // And(rewrittenEqualJoinCondition, c).
- val rewrittenJoinCondition = otherPredicate
- .map(c => And(rewrittenEqualJoinCondition, c))
- .getOrElse(rewrittenEqualJoinCondition)
-
- rewrittenJoinCondition
- }
-
- def apply(plan: LogicalPlan): LogicalPlan = {
- if (!sqlContext.conf.advancedSqlOptimizations) {
- plan
- } else {
- plan transform {
- case join: Join => join match {
- // For an inner join with an equi-join condition, we can add filters
- // to both sides of the join operator.
- case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
- if needsFilter(leftKeys, left) || needsFilter(rightKeys, right) =>
- val withLeftFilter = addFilterIfNecessary(leftKeys, left)
- val withRightFilter = addFilterIfNecessary(rightKeys, right)
- val rewrittenJoinCondition =
- reconstructJoinCondition(leftKeys, rightKeys, condition)
-
- Join(withLeftFilter, withRightFilter, Inner, Some(rewrittenJoinCondition))
-
- // For a left outer join with an equi-join condition, we can add a filter
- // to the right side of the join operator.
- case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right)
- if needsFilter(rightKeys, right) =>
- val withRightFilter = addFilterIfNecessary(rightKeys, right)
- val rewrittenJoinCondition =
- reconstructJoinCondition(leftKeys, rightKeys, condition)
-
- Join(left, withRightFilter, LeftOuter, Some(rewrittenJoinCondition))
-
- // For a right outer join with an equi-join condition, we can add a filter
- // to the left side of the join operator.
- case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right)
- if needsFilter(leftKeys, left) =>
- val withLeftFilter = addFilterIfNecessary(leftKeys, left)
- val rewrittenJoinCondition =
- reconstructJoinCondition(leftKeys, rightKeys, condition)
-
- Join(withLeftFilter, right, RightOuter, Some(rewrittenJoinCondition))
-
- // For a left semi join with an equi-join condition, we can add filters
- // to both sides of the join operator.
- case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right)
- if needsFilter(leftKeys, left) || needsFilter(rightKeys, right) =>
- val withLeftFilter = addFilterIfNecessary(leftKeys, left)
- val withRightFilter = addFilterIfNecessary(rightKeys, right)
- val rewrittenJoinCondition =
- reconstructJoinCondition(leftKeys, rightKeys, condition)
-
- Join(withLeftFilter, withRightFilter, LeftSemi, Some(rewrittenJoinCondition))
-
- case other => other
- }
- }
- }
- }
-}
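[Editor's note] The safety argument in the deleted rule's doc comment can be checked with a plain-Scala analogue of an inner equi-join: under SQL semantics, NULL = NULL is not true, so null-keyed rows can never match, and filtering them from either side before the join changes nothing (illustrative code, not Catalyst):

    val table1 = Seq(("str1", 1), (null, 2))
    val table2 = Seq(("str1", 3), (null, 4))

    // Inner equi-join with SQL-style null semantics: a null key matches nothing.
    val joined = for {
      (lk, lv) <- table1
      (rk, rv) <- table2
      if lk != null && lk == rk
    } yield (lk, lv, rk, rv)
    // joined == Seq(("str1", 1, "str1", 3))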
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala
deleted file mode 100644
index f98e4acafb..0000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala
+++ /dev/null
@@ -1,236 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.optimizer
-
-import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
-import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions.{Not, AtLeastNNulls}
-import org.apache.spark.sql.catalyst.optimizer._
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan}
-import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.test.TestSQLContext
-
-/** This is the test suite for FilterNullsInJoinKey optimization rule. */
-class FilterNullsInJoinKeySuite extends PlanTest {
-
- // We add predicate pushdown rules here to make sure we do not
- // create redundant Filter operators. Also, because the attribute ordering of
- // the Project operator added by ColumnPruning may not be deterministic
- // (the ordering may depend on the testing environment),
- // we first construct the plan with the expected Filter operators and then
- // run the optimizer to add the Project for column pruning.
- object Optimize extends RuleExecutor[LogicalPlan] {
- val batches =
- Batch("Subqueries", Once,
- EliminateSubQueries) ::
- Batch("Operator Optimizations", FixedPoint(100),
- FilterNullsInJoinKey(TestSQLContext), // This is the rule we test in this suite.
- CombineFilters,
- PushPredicateThroughProject,
- BooleanSimplification,
- PushPredicateThroughJoin,
- PushPredicateThroughGenerate,
- ColumnPruning,
- ProjectCollapsing) :: Nil
- }
-
- val leftRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.int)
-
- val rightRelation = LocalRelation('e.int, 'f.int, 'g.int, 'h.int)
-
- test("inner join") {
- val joinCondition =
- ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g)
-
- val joinedPlan =
- leftRelation
- .join(rightRelation, Inner, Some(joinCondition))
- .select('a, 'f, 'd, 'h)
-
- val optimized = Optimize.execute(joinedPlan.analyze)
-
- // For an inner join, FilterNullsInJoinKey adds filters to both sides.
- val correctLeft =
- leftRelation
- .where(!(AtLeastNNulls(1, 'a.expr :: Nil)))
-
- val correctRight =
- rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil)))
-
- val correctAnswer =
- correctLeft
- .join(correctRight, Inner, Some(joinCondition))
- .select('a, 'f, 'd, 'h)
-
- comparePlans(optimized, Optimize.execute(correctAnswer.analyze))
- }
-
- test("make sure we do not keep adding filters") {
- val thirdRelation = LocalRelation('i.int, 'j.int, 'k.int, 'l.int)
- val joinedPlan =
- leftRelation
- .join(rightRelation, Inner, Some('a === 'e))
- .join(thirdRelation, Inner, Some('b === 'i && 'a === 'j))
-
- val optimized = Optimize.execute(joinedPlan.analyze)
- val conditions = optimized.collect {
- case Filter(condition @ Not(AtLeastNNulls(1, exprs)), _) => exprs
- }
-
- // Make sure that we have three Not(AtLeastNNulls(1, exprs)) predicates, one for each of the three tables.
- assert(conditions.length === 3)
-
- // Make sure the attributes are indeed a, b, e, i, and j.
- assert(
- conditions.flatMap(exprs => exprs).toSet ===
- joinedPlan.select('a, 'b, 'e, 'i, 'j).analyze.output.toSet)
- }
-
- test("inner join (partially optimized)") {
- val joinCondition =
- ('a + 2 === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g)
-
- val joinedPlan =
- leftRelation
- .join(rightRelation, Inner, Some(joinCondition))
- .select('a, 'f, 'd, 'h)
-
- val optimized = Optimize.execute(joinedPlan.analyze)
-
- // We cannot extract an attribute from the left join key, since 'a + 2 is not a bare attribute.
- val correctRight =
- rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil)))
-
- val correctAnswer =
- leftRelation
- .join(correctRight, Inner, Some(joinCondition))
- .select('a, 'f, 'd, 'h)
-
- comparePlans(optimized, Optimize.execute(correctAnswer.analyze))
- }
-
- test("inner join (not optimized)") {
- val nonOptimizedJoinConditions =
- Some('c - 100 + 'd === 'g + 1 - 'h) ::
- Some('d > 'h || 'c === 'g) ::
- Some('d + 'g + 'c > 'd - 'h) :: Nil
-
- nonOptimizedJoinConditions.foreach { joinCondition =>
- val joinedPlan =
- leftRelation
- .join(rightRelation.select('f, 'g, 'h), Inner, joinCondition)
- .select('a, 'c, 'f, 'd, 'h, 'g)
-
- val optimized = Optimize.execute(joinedPlan.analyze)
-
- comparePlans(optimized, Optimize.execute(joinedPlan.analyze))
- }
- }
-
- test("left outer join") {
- val joinCondition =
- ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g)
-
- val joinedPlan =
- leftRelation
- .join(rightRelation, LeftOuter, Some(joinCondition))
- .select('a, 'f, 'd, 'h)
-
- val optimized = Optimize.execute(joinedPlan.analyze)
-
- // For a left outer join, FilterNullsInJoinKey adds a filter to the right side.
- val correctRight =
- rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil)))
-
- val correctAnswer =
- leftRelation
- .join(correctRight, LeftOuter, Some(joinCondition))
- .select('a, 'f, 'd, 'h)
-
- comparePlans(optimized, Optimize.execute(correctAnswer.analyze))
- }
-
- test("right outer join") {
- val joinCondition =
- ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g)
-
- val joinedPlan =
- leftRelation
- .join(rightRelation, RightOuter, Some(joinCondition))
- .select('a, 'f, 'd, 'h)
-
- val optimized = Optimize.execute(joinedPlan.analyze)
-
- // For a right outer join, FilterNullsInJoinKey adds a filter to the left side.
- val correctLeft =
- leftRelation
- .where(!(AtLeastNNulls(1, 'a.expr :: Nil)))
-
- val correctAnswer =
- correctLeft
- .join(rightRelation, RightOuter, Some(joinCondition))
- .select('a, 'f, 'd, 'h)
-
-
- comparePlans(optimized, Optimize.execute(correctAnswer.analyze))
- }
-
- test("full outer join") {
- val joinCondition =
- ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g)
-
- val joinedPlan =
- leftRelation
- .join(rightRelation, FullOuter, Some(joinCondition))
- .select('a, 'f, 'd, 'h)
-
- // FilterNullsInJoinKey does not fire for a full outer join.
- val optimized = Optimize.execute(joinedPlan.analyze)
-
- comparePlans(optimized, Optimize.execute(joinedPlan.analyze))
- }
-
- test("left semi join") {
- val joinCondition =
- ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g)
-
- val joinedPlan =
- leftRelation
- .join(rightRelation, LeftSemi, Some(joinCondition))
- .select('a, 'd)
-
- val optimized = Optimize.execute(joinedPlan.analyze)
-
- // For a left semi join, FilterNullsInJoinKey adds filters to both sides.
- val correctLeft =
- leftRelation
- .where(!(AtLeastNNulls(1, 'a.expr :: Nil)))
-
- val correctRight =
- rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil)))
-
- val correctAnswer =
- correctLeft
- .join(correctRight, LeftSemi, Some(joinCondition))
- .select('a, 'd)
-
- comparePlans(optimized, Optimize.execute(correctAnswer.analyze))
- }
-}
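[Editor's note] The deleted suite follows Catalyst's standard optimizer-testing pattern: wrap the rule under test in a small RuleExecutor, run it over an analyzed plan built with the test DSL, and compare against a hand-built expected plan. A minimal sketch of that pattern, assuming Catalyst's test helpers (PlanTest, the plan/expression DSLs) are in scope; `MyRule` is a hypothetical Rule[LogicalPlan]:

    class MyRuleSuite extends PlanTest {
      object Optimize extends RuleExecutor[LogicalPlan] {
        val batches = Batch("MyRule", Once, MyRule) :: Nil  // MyRule: hypothetical rule under test
      }

      test("rule rewrites the plan as expected") {
        val input = LocalRelation('a.int, 'b.int).where('a > 1)
        val expected = input  // hand-construct the post-rule plan here
        comparePlans(Optimize.execute(input.analyze), Optimize.execute(expected.analyze))
      }
    }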