aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYin Huai <yhuai@databricks.com>2015-08-02 23:32:09 -0700
committerJosh Rosen <joshrosen@databricks.com>2015-08-02 23:32:09 -0700
commit687c8c37150f4c93f8e57d86bb56321a4891286b (patch)
tree5fc768cdf7b01dae261706c148c7fcd3cf622b9d
parent4cdd8ecd66769316e8593da7790b84cd867968cd (diff)
downloadspark-687c8c37150f4c93f8e57d86bb56321a4891286b.tar.gz
spark-687c8c37150f4c93f8e57d86bb56321a4891286b.tar.bz2
spark-687c8c37150f4c93f8e57d86bb56321a4891286b.zip
[SPARK-9372] [SQL] Filter nulls in join keys
This PR adds an optimization rule, `FilterNullsInJoinKey`, to add `Filter` before join operators to filter out rows having null values for join keys. This optimization is guarded by a new SQL conf, `spark.sql.advancedOptimization`. The code in this PR was authored by yhuai; I'm opening this PR to factor out this change from #7685, a larger pull request which contains two other optimizations. Author: Yin Huai <yhuai@databricks.com> Author: Josh Rosen <joshrosen@databricks.com> Closes #7768 from JoshRosen/filter-nulls-in-join-key and squashes the following commits: c02fc3f [Yin Huai] Address Josh's comments. 0a8e096 [Yin Huai] Update comments. ea7d5a6 [Yin Huai] Make sure we do not keep adding filters. be88760 [Yin Huai] Make it clear that FilterNullsInJoinKeySuite.scala is used to test FilterNullsInJoinKey. 8bb39ad [Yin Huai] Fix non-deterministic tests. 303236b [Josh Rosen] Revert changes that are unrelated to null join key filtering 40eeece [Josh Rosen] Merge remote-tracking branch 'origin/master' into filter-nulls-in-join-key c57a954 [Yin Huai] Bug fix. d3d2e64 [Yin Huai] First round of cleanup. f9516b0 [Yin Huai] Style c6667e7 [Yin Huai] Add PartitioningCollection. e616d3b [Yin Huai] wip 7c2d2d8 [Yin Huai] Bug fix and refactoring. 69bb072 [Yin Huai] Introduce NullSafeHashPartitioning and NullUnsafePartitioning. d5b84c3 [Yin Huai] Do not add unnessary filters. 2201129 [Yin Huai] Filter out rows that will not be joined in equal joins early.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala48
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala64
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala32
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala4
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala3
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala49
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala6
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala5
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala160
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala236
11 files changed, 572 insertions, 37 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
index 287718fab7..d58c475693 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
@@ -210,14 +210,58 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
}
}
+/**
+ * A predicate that is evaluated to be true if there are at least `n` null values.
+ */
+case class AtLeastNNulls(n: Int, children: Seq[Expression]) extends Predicate {
+ override def nullable: Boolean = false
+ override def foldable: Boolean = children.forall(_.foldable)
+ override def toString: String = s"AtLeastNNulls($n, ${children.mkString(",")})"
+
+ private[this] val childrenArray = children.toArray
+
+ override def eval(input: InternalRow): Boolean = {
+ var numNulls = 0
+ var i = 0
+ while (i < childrenArray.length && numNulls < n) {
+ val evalC = childrenArray(i).eval(input)
+ if (evalC == null) {
+ numNulls += 1
+ }
+ i += 1
+ }
+ numNulls >= n
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val numNulls = ctx.freshName("numNulls")
+ val code = children.map { e =>
+ val eval = e.gen(ctx)
+ s"""
+ if ($numNulls < $n) {
+ ${eval.code}
+ if (${eval.isNull}) {
+ $numNulls += 1;
+ }
+ }
+ """
+ }.mkString("\n")
+ s"""
+ int $numNulls = 0;
+ $code
+ boolean ${ev.isNull} = false;
+ boolean ${ev.primitive} = $numNulls >= $n;
+ """
+ }
+}
/**
* A predicate that is evaluated to be true if there are at least `n` non-null and non-NaN values.
*/
-case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate {
+case class AtLeastNNonNullNans(n: Int, children: Seq[Expression]) extends Predicate {
override def nullable: Boolean = false
override def foldable: Boolean = children.forall(_.foldable)
- override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})"
+ override def toString: String = s"AtLeastNNonNullNans($n, ${children.mkString(",")})"
private[this] val childrenArray = children.toArray
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 29d706dcb3..e4b6294dc7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -31,8 +31,14 @@ import org.apache.spark.sql.types._
abstract class Optimizer extends RuleExecutor[LogicalPlan]
-object DefaultOptimizer extends Optimizer {
- val batches =
+class DefaultOptimizer extends Optimizer {
+
+ /**
+ * Override to provide additional rules for the "Operator Optimizations" batch.
+ */
+ val extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil
+
+ lazy val batches =
// SubQueries are only needed for analysis and can be removed before execution.
Batch("Remove SubQueries", FixedPoint(100),
EliminateSubQueries) ::
@@ -41,26 +47,27 @@ object DefaultOptimizer extends Optimizer {
RemoveLiteralFromGroupExpressions) ::
Batch("Operator Optimizations", FixedPoint(100),
// Operator push down
- SetOperationPushDown,
- SamplePushDown,
- PushPredicateThroughJoin,
- PushPredicateThroughProject,
- PushPredicateThroughGenerate,
- ColumnPruning,
+ SetOperationPushDown ::
+ SamplePushDown ::
+ PushPredicateThroughJoin ::
+ PushPredicateThroughProject ::
+ PushPredicateThroughGenerate ::
+ ColumnPruning ::
// Operator combine
- ProjectCollapsing,
- CombineFilters,
- CombineLimits,
+ ProjectCollapsing ::
+ CombineFilters ::
+ CombineLimits ::
// Constant folding
- NullPropagation,
- OptimizeIn,
- ConstantFolding,
- LikeSimplification,
- BooleanSimplification,
- RemovePositive,
- SimplifyFilters,
- SimplifyCasts,
- SimplifyCaseConversionExpressions) ::
+ NullPropagation ::
+ OptimizeIn ::
+ ConstantFolding ::
+ LikeSimplification ::
+ BooleanSimplification ::
+ RemovePositive ::
+ SimplifyFilters ::
+ SimplifyCasts ::
+ SimplifyCaseConversionExpressions ::
+ extendedOperatorOptimizationRules.toList : _*) ::
Batch("Decimal Optimizations", FixedPoint(100),
DecimalAggregates) ::
Batch("LocalRelation", FixedPoint(100),
@@ -222,12 +229,18 @@ object ColumnPruning extends Rule[LogicalPlan] {
}
/** Applies a projection only when the child is producing unnecessary attributes */
- private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
+ private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = {
if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) {
- Project(allReferences.filter(c.outputSet.contains).toSeq, c)
+ // We need to preserve the nullability of c's output.
+ // So, we first create a outputMap and if a reference is from the output of
+ // c, we use that output attribute from c.
+ val outputMap = AttributeMap(c.output.map(attr => (attr, attr)))
+ val projectList = allReferences.filter(outputMap.contains).map(outputMap).toSeq
+ Project(projectList, c)
} else {
c
}
+ }
}
/**
@@ -517,6 +530,13 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
*/
object CombineFilters extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case Filter(Not(AtLeastNNulls(1, e1)), Filter(Not(AtLeastNNulls(1, e2)), grandChild)) =>
+ // If we are combining two expressions Not(AtLeastNNulls(1, e1)) and
+ // Not(AtLeastNNulls(1, e2))
+ // (this is used to make sure there is no null in the result of e1 and e2 and
+ // they are added by FilterNullsInJoinKey optimziation rule), we can
+ // just create a Not(AtLeastNNulls(1, (e1 ++ e2).distinct)).
+ Filter(Not(AtLeastNNulls(1, (e1 ++ e2).distinct)), grandChild)
case ff @ Filter(fc, nf @ Filter(nc, grandChild)) => Filter(And(nc, fc), grandChild)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index aacfc86ab0..54b5f49772 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -86,7 +86,37 @@ case class Generate(
}
case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
- override def output: Seq[Attribute] = child.output
+ /**
+ * Indicates if `atLeastNNulls` is used to check if atLeastNNulls.children
+ * have at least one null value and atLeastNNulls.children are all attributes.
+ */
+ private def isAtLeastOneNullOutputAttributes(atLeastNNulls: AtLeastNNulls): Boolean = {
+ val expressions = atLeastNNulls.children
+ val n = atLeastNNulls.n
+ if (n != 1) {
+ // AtLeastNNulls is not used to check if atLeastNNulls.children have
+ // at least one null value.
+ false
+ } else {
+ // AtLeastNNulls is used to check if atLeastNNulls.children have
+ // at least one null value. We need to make sure all atLeastNNulls.children
+ // are attributes.
+ expressions.forall(_.isInstanceOf[Attribute])
+ }
+ }
+
+ override def output: Seq[Attribute] = condition match {
+ case Not(a: AtLeastNNulls) if isAtLeastOneNullOutputAttributes(a) =>
+ // The condition is used to make sure that there is no null value in
+ // a.children.
+ val nonNullableAttributes = AttributeSet(a.children.asInstanceOf[Seq[Attribute]])
+ child.output.map {
+ case attr if nonNullableAttributes.contains(attr) =>
+ attr.withNullability(false)
+ case attr => attr
+ }
+ case _ => child.output
+ }
}
case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index a41185b4d8..3e55151298 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -31,6 +31,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
trait ExpressionEvalHelper {
self: SparkFunSuite =>
+ protected val defaultOptimizer = new DefaultOptimizer
+
protected def create_row(values: Any*): InternalRow = {
InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst))
}
@@ -186,7 +188,7 @@ trait ExpressionEvalHelper {
expected: Any,
inputRow: InternalRow = EmptyRow): Unit = {
val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation)
- val optimizedPlan = DefaultOptimizer.execute(plan)
+ val optimizedPlan = defaultOptimizer.execute(plan)
checkEvaluationWithoutCodegen(optimizedPlan.expressions.head, expected, inputRow)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 9fcb548af6..649a5b44dc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -23,7 +23,6 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
import org.apache.spark.sql.types._
@@ -149,7 +148,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
expression: Expression,
inputRow: InternalRow = EmptyRow): Unit = {
val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation)
- val optimizedPlan = DefaultOptimizer.execute(plan)
+ val optimizedPlan = defaultOptimizer.execute(plan)
checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
index ace6c15dc8..bf197124d8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
@@ -77,7 +77,7 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
- test("AtLeastNNonNulls") {
+ test("AtLeastNNonNullNans") {
val mix = Seq(Literal("x"),
Literal.create(null, StringType),
Literal.create(null, DoubleType),
@@ -96,11 +96,46 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
Literal(Float.MaxValue),
Literal(false))
- checkEvaluation(AtLeastNNonNulls(2, mix), true, EmptyRow)
- checkEvaluation(AtLeastNNonNulls(3, mix), false, EmptyRow)
- checkEvaluation(AtLeastNNonNulls(3, nanOnly), true, EmptyRow)
- checkEvaluation(AtLeastNNonNulls(4, nanOnly), false, EmptyRow)
- checkEvaluation(AtLeastNNonNulls(3, nullOnly), true, EmptyRow)
- checkEvaluation(AtLeastNNonNulls(4, nullOnly), false, EmptyRow)
+ checkEvaluation(AtLeastNNonNullNans(0, mix), true, EmptyRow)
+ checkEvaluation(AtLeastNNonNullNans(2, mix), true, EmptyRow)
+ checkEvaluation(AtLeastNNonNullNans(3, mix), false, EmptyRow)
+ checkEvaluation(AtLeastNNonNullNans(0, nanOnly), true, EmptyRow)
+ checkEvaluation(AtLeastNNonNullNans(3, nanOnly), true, EmptyRow)
+ checkEvaluation(AtLeastNNonNullNans(4, nanOnly), false, EmptyRow)
+ checkEvaluation(AtLeastNNonNullNans(0, nullOnly), true, EmptyRow)
+ checkEvaluation(AtLeastNNonNullNans(3, nullOnly), true, EmptyRow)
+ checkEvaluation(AtLeastNNonNullNans(4, nullOnly), false, EmptyRow)
+ }
+
+ test("AtLeastNNull") {
+ val mix = Seq(Literal("x"),
+ Literal.create(null, StringType),
+ Literal.create(null, DoubleType),
+ Literal(Double.NaN),
+ Literal(5f))
+
+ val nanOnly = Seq(Literal("x"),
+ Literal(10.0),
+ Literal(Float.NaN),
+ Literal(math.log(-2)),
+ Literal(Double.MaxValue))
+
+ val nullOnly = Seq(Literal("x"),
+ Literal.create(null, DoubleType),
+ Literal.create(null, DecimalType.USER_DEFAULT),
+ Literal(Float.MaxValue),
+ Literal(false))
+
+ checkEvaluation(AtLeastNNulls(0, mix), true, EmptyRow)
+ checkEvaluation(AtLeastNNulls(1, mix), true, EmptyRow)
+ checkEvaluation(AtLeastNNulls(2, mix), true, EmptyRow)
+ checkEvaluation(AtLeastNNulls(3, mix), false, EmptyRow)
+ checkEvaluation(AtLeastNNulls(0, nanOnly), true, EmptyRow)
+ checkEvaluation(AtLeastNNulls(1, nanOnly), false, EmptyRow)
+ checkEvaluation(AtLeastNNulls(2, nanOnly), false, EmptyRow)
+ checkEvaluation(AtLeastNNulls(0, nullOnly), true, EmptyRow)
+ checkEvaluation(AtLeastNNulls(1, nullOnly), true, EmptyRow)
+ checkEvaluation(AtLeastNNulls(2, nullOnly), true, EmptyRow)
+ checkEvaluation(AtLeastNNulls(3, nullOnly), false, EmptyRow)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index a4fd4cf3b3..ea85f0657a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -122,7 +122,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = {
// Filtering condition:
// only keep the row if it has at least `minNonNulls` non-null and non-NaN values.
- val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name)))
+ val predicate = AtLeastNNonNullNans(minNonNulls, cols.map(name => df.resolve(name)))
df.filter(Column(predicate))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 6644e85d4a..387960c4b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -413,6 +413,10 @@ private[spark] object SQLConf {
"spark.sql.useSerializer2",
defaultValue = Some(true), isPublic = false)
+ val ADVANCED_SQL_OPTIMIZATION = booleanConf(
+ "spark.sql.advancedOptimization",
+ defaultValue = Some(true), isPublic = false)
+
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
}
@@ -484,6 +488,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2)
+ private[spark] def advancedSqlOptimizations: Boolean = getConf(ADVANCED_SQL_OPTIMIZATION)
+
private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
private[spark] def defaultSizeInBytes: Long =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index dbb2a09846..31e2b508d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.optimizer.FilterNullsInJoinKey
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -156,7 +157,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
}
@transient
- protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer
+ protected[sql] lazy val optimizer: Optimizer = new DefaultOptimizer {
+ override val extendedOperatorOptimizationRules = FilterNullsInJoinKey(self) :: Nil
+ }
@transient
protected[sql] val ddlParser = new DDLParser(sqlParser.parse(_))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala b/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala
new file mode 100644
index 0000000000..5a4dde5756
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.optimizer
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
+import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter, LeftSemi}
+import org.apache.spark.sql.catalyst.plans.logical.{Project, Filter, Join, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * An optimization rule used to insert Filters to filter out rows whose equal join keys
+ * have at least one null values. For this kind of rows, they will not contribute to
+ * the join results of equal joins because a null does not equal another null. We can
+ * filter them out before shuffling join input rows. For example, we have two tables
+ *
+ * table1(key String, value Int)
+ * "str1"|1
+ * null |2
+ *
+ * table2(key String, value Int)
+ * "str1"|3
+ * null |4
+ *
+ * For a inner equal join, the result will be
+ * "str1"|1|"str1"|3
+ *
+ * those two rows having null as the value of key will not contribute to the result.
+ * So, we can filter them out early.
+ *
+ * This optimization rule can be disabled by setting spark.sql.advancedOptimization to false.
+ *
+ */
+case class FilterNullsInJoinKey(
+ sqlContext: SQLContext)
+ extends Rule[LogicalPlan] {
+
+ /**
+ * Checks if we need to add a Filter operator. We will add a Filter when
+ * there is any attribute in `keys` whose corresponding attribute of `keys`
+ * in `plan.output` is still nullable (`nullable` field is `true`).
+ */
+ private def needsFilter(keys: Seq[Expression], plan: LogicalPlan): Boolean = {
+ val keyAttributeSet = AttributeSet(keys.filter(_.isInstanceOf[Attribute]))
+ plan.output.filter(keyAttributeSet.contains).exists(_.nullable)
+ }
+
+ /**
+ * Adds a Filter operator to make sure that every attribute in `keys` is non-nullable.
+ */
+ private def addFilterIfNecessary(
+ keys: Seq[Expression],
+ child: LogicalPlan): LogicalPlan = {
+ // We get all attributes from keys.
+ val attributes = keys.filter(_.isInstanceOf[Attribute])
+
+ // Then, we create a Filter to make sure these attributes are non-nullable.
+ val filter =
+ if (attributes.nonEmpty) {
+ Filter(Not(AtLeastNNulls(1, attributes)), child)
+ } else {
+ child
+ }
+
+ filter
+ }
+
+ /**
+ * We reconstruct the join condition.
+ */
+ private def reconstructJoinCondition(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ otherPredicate: Option[Expression]): Expression = {
+ // First, we rewrite the equal condition part. When we extract those keys,
+ // we use splitConjunctivePredicates. So, it is safe to use .reduce(And).
+ val rewrittenEqualJoinCondition = leftKeys.zip(rightKeys).map {
+ case (l, r) => EqualTo(l, r)
+ }.reduce(And)
+
+ // Then, we add otherPredicate. When we extract those equal condition part,
+ // we use splitConjunctivePredicates. So, it is safe to use
+ // And(rewrittenEqualJoinCondition, c).
+ val rewrittenJoinCondition = otherPredicate
+ .map(c => And(rewrittenEqualJoinCondition, c))
+ .getOrElse(rewrittenEqualJoinCondition)
+
+ rewrittenJoinCondition
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = {
+ if (!sqlContext.conf.advancedSqlOptimizations) {
+ plan
+ } else {
+ plan transform {
+ case join: Join => join match {
+ // For a inner join having equal join condition part, we can add filters
+ // to both sides of the join operator.
+ case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
+ if needsFilter(leftKeys, left) || needsFilter(rightKeys, right) =>
+ val withLeftFilter = addFilterIfNecessary(leftKeys, left)
+ val withRightFilter = addFilterIfNecessary(rightKeys, right)
+ val rewrittenJoinCondition =
+ reconstructJoinCondition(leftKeys, rightKeys, condition)
+
+ Join(withLeftFilter, withRightFilter, Inner, Some(rewrittenJoinCondition))
+
+ // For a left outer join having equal join condition part, we can add a filter
+ // to the right side of the join operator.
+ case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right)
+ if needsFilter(rightKeys, right) =>
+ val withRightFilter = addFilterIfNecessary(rightKeys, right)
+ val rewrittenJoinCondition =
+ reconstructJoinCondition(leftKeys, rightKeys, condition)
+
+ Join(left, withRightFilter, LeftOuter, Some(rewrittenJoinCondition))
+
+ // For a right outer join having equal join condition part, we can add a filter
+ // to the left side of the join operator.
+ case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right)
+ if needsFilter(leftKeys, left) =>
+ val withLeftFilter = addFilterIfNecessary(leftKeys, left)
+ val rewrittenJoinCondition =
+ reconstructJoinCondition(leftKeys, rightKeys, condition)
+
+ Join(withLeftFilter, right, RightOuter, Some(rewrittenJoinCondition))
+
+ // For a left semi join having equal join condition part, we can add filters
+ // to both sides of the join operator.
+ case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right)
+ if needsFilter(leftKeys, left) || needsFilter(rightKeys, right) =>
+ val withLeftFilter = addFilterIfNecessary(leftKeys, left)
+ val withRightFilter = addFilterIfNecessary(rightKeys, right)
+ val rewrittenJoinCondition =
+ reconstructJoinCondition(leftKeys, rightKeys, condition)
+
+ Join(withLeftFilter, withRightFilter, LeftSemi, Some(rewrittenJoinCondition))
+
+ case other => other
+ }
+ }
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala
new file mode 100644
index 0000000000..f98e4acafb
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.optimizer
+
+import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.{Not, AtLeastNNulls}
+import org.apache.spark.sql.catalyst.optimizer._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.test.TestSQLContext
+
+/** This is the test suite for FilterNullsInJoinKey optimization rule. */
+class FilterNullsInJoinKeySuite extends PlanTest {
+
+ // We add predicate pushdown rules at here to make sure we do not
+ // create redundant Filter operators. Also, because the attribute ordering of
+ // the Project operator added by ColumnPruning may be not deterministic
+ // (the ordering may depend on the testing environment),
+ // we first construct the plan with expected Filter operators and then
+ // run the optimizer to add the the Project for column pruning.
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Subqueries", Once,
+ EliminateSubQueries) ::
+ Batch("Operator Optimizations", FixedPoint(100),
+ FilterNullsInJoinKey(TestSQLContext), // This is the rule we test in this suite.
+ CombineFilters,
+ PushPredicateThroughProject,
+ BooleanSimplification,
+ PushPredicateThroughJoin,
+ PushPredicateThroughGenerate,
+ ColumnPruning,
+ ProjectCollapsing) :: Nil
+ }
+
+ val leftRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.int)
+
+ val rightRelation = LocalRelation('e.int, 'f.int, 'g.int, 'h.int)
+
+ test("inner join") {
+ val joinCondition =
+ ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g)
+
+ val joinedPlan =
+ leftRelation
+ .join(rightRelation, Inner, Some(joinCondition))
+ .select('a, 'f, 'd, 'h)
+
+ val optimized = Optimize.execute(joinedPlan.analyze)
+
+ // For an inner join, FilterNullsInJoinKey add filter to both side.
+ val correctLeft =
+ leftRelation
+ .where(!(AtLeastNNulls(1, 'a.expr :: Nil)))
+
+ val correctRight =
+ rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil)))
+
+ val correctAnswer =
+ correctLeft
+ .join(correctRight, Inner, Some(joinCondition))
+ .select('a, 'f, 'd, 'h)
+
+ comparePlans(optimized, Optimize.execute(correctAnswer.analyze))
+ }
+
+ test("make sure we do not keep adding filters") {
+ val thirdRelation = LocalRelation('i.int, 'j.int, 'k.int, 'l.int)
+ val joinedPlan =
+ leftRelation
+ .join(rightRelation, Inner, Some('a === 'e))
+ .join(thirdRelation, Inner, Some('b === 'i && 'a === 'j))
+
+ val optimized = Optimize.execute(joinedPlan.analyze)
+ val conditions = optimized.collect {
+ case Filter(condition @ Not(AtLeastNNulls(1, exprs)), _) => exprs
+ }
+
+ // Make sure that we have three Not(AtLeastNNulls(1, exprs)) for those three tables.
+ assert(conditions.length === 3)
+
+ // Make sure attribtues are indeed a, b, e, i, and j.
+ assert(
+ conditions.flatMap(exprs => exprs).toSet ===
+ joinedPlan.select('a, 'b, 'e, 'i, 'j).analyze.output.toSet)
+ }
+
+ test("inner join (partially optimized)") {
+ val joinCondition =
+ ('a + 2 === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g)
+
+ val joinedPlan =
+ leftRelation
+ .join(rightRelation, Inner, Some(joinCondition))
+ .select('a, 'f, 'd, 'h)
+
+ val optimized = Optimize.execute(joinedPlan.analyze)
+
+ // We cannot extract attribute from the left join key.
+ val correctRight =
+ rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil)))
+
+ val correctAnswer =
+ leftRelation
+ .join(correctRight, Inner, Some(joinCondition))
+ .select('a, 'f, 'd, 'h)
+
+ comparePlans(optimized, Optimize.execute(correctAnswer.analyze))
+ }
+
+ test("inner join (not optimized)") {
+ val nonOptimizedJoinConditions =
+ Some('c - 100 + 'd === 'g + 1 - 'h) ::
+ Some('d > 'h || 'c === 'g) ::
+ Some('d + 'g + 'c > 'd - 'h) :: Nil
+
+ nonOptimizedJoinConditions.foreach { joinCondition =>
+ val joinedPlan =
+ leftRelation
+ .join(rightRelation.select('f, 'g, 'h), Inner, joinCondition)
+ .select('a, 'c, 'f, 'd, 'h, 'g)
+
+ val optimized = Optimize.execute(joinedPlan.analyze)
+
+ comparePlans(optimized, Optimize.execute(joinedPlan.analyze))
+ }
+ }
+
+ test("left outer join") {
+ val joinCondition =
+ ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g)
+
+ val joinedPlan =
+ leftRelation
+ .join(rightRelation, LeftOuter, Some(joinCondition))
+ .select('a, 'f, 'd, 'h)
+
+ val optimized = Optimize.execute(joinedPlan.analyze)
+
+ // For a left outer join, FilterNullsInJoinKey add filter to the right side.
+ val correctRight =
+ rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil)))
+
+ val correctAnswer =
+ leftRelation
+ .join(correctRight, LeftOuter, Some(joinCondition))
+ .select('a, 'f, 'd, 'h)
+
+ comparePlans(optimized, Optimize.execute(correctAnswer.analyze))
+ }
+
+ test("right outer join") {
+ val joinCondition =
+ ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g)
+
+ val joinedPlan =
+ leftRelation
+ .join(rightRelation, RightOuter, Some(joinCondition))
+ .select('a, 'f, 'd, 'h)
+
+ val optimized = Optimize.execute(joinedPlan.analyze)
+
+ // For a right outer join, FilterNullsInJoinKey add filter to the left side.
+ val correctLeft =
+ leftRelation
+ .where(!(AtLeastNNulls(1, 'a.expr :: Nil)))
+
+ val correctAnswer =
+ correctLeft
+ .join(rightRelation, RightOuter, Some(joinCondition))
+ .select('a, 'f, 'd, 'h)
+
+
+ comparePlans(optimized, Optimize.execute(correctAnswer.analyze))
+ }
+
+ test("full outer join") {
+ val joinCondition =
+ ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g)
+
+ val joinedPlan =
+ leftRelation
+ .join(rightRelation, FullOuter, Some(joinCondition))
+ .select('a, 'f, 'd, 'h)
+
+ // FilterNullsInJoinKey does not fire for a full outer join.
+ val optimized = Optimize.execute(joinedPlan.analyze)
+
+ comparePlans(optimized, Optimize.execute(joinedPlan.analyze))
+ }
+
+ test("left semi join") {
+ val joinCondition =
+ ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g)
+
+ val joinedPlan =
+ leftRelation
+ .join(rightRelation, LeftSemi, Some(joinCondition))
+ .select('a, 'd)
+
+ val optimized = Optimize.execute(joinedPlan.analyze)
+
+ // For a left semi join, FilterNullsInJoinKey add filter to both side.
+ val correctLeft =
+ leftRelation
+ .where(!(AtLeastNNulls(1, 'a.expr :: Nil)))
+
+ val correctRight =
+ rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil)))
+
+ val correctAnswer =
+ correctLeft
+ .join(correctRight, LeftSemi, Some(joinCondition))
+ .select('a, 'd)
+
+ comparePlans(optimized, Optimize.execute(correctAnswer.analyze))
+ }
+}