author    Yin Huai <yhuai@databricks.com>  2015-08-02 23:32:09 -0700
committer Josh Rosen <joshrosen@databricks.com>  2015-08-02 23:32:09 -0700
commit    687c8c37150f4c93f8e57d86bb56321a4891286b (patch)
tree      5fc768cdf7b01dae261706c148c7fcd3cf622b9d /sql/core
parent    4cdd8ecd66769316e8593da7790b84cd867968cd (diff)
[SPARK-9372] [SQL] Filter nulls in join keys
This PR adds an optimization rule, `FilterNullsInJoinKey`, which inserts a `Filter` before join operators to filter out rows whose join keys contain null values. The optimization is guarded by a new SQL conf, `spark.sql.advancedOptimization`. The code in this PR was authored by yhuai; I'm opening this PR to factor out this change from #7685, a larger pull request which contains two other optimizations.

Author: Yin Huai <yhuai@databricks.com>
Author: Josh Rosen <joshrosen@databricks.com>

Closes #7768 from JoshRosen/filter-nulls-in-join-key and squashes the following commits:

c02fc3f [Yin Huai] Address Josh's comments.
0a8e096 [Yin Huai] Update comments.
ea7d5a6 [Yin Huai] Make sure we do not keep adding filters.
be88760 [Yin Huai] Make it clear that FilterNullsInJoinKeySuite.scala is used to test FilterNullsInJoinKey.
8bb39ad [Yin Huai] Fix non-deterministic tests.
303236b [Josh Rosen] Revert changes that are unrelated to null join key filtering
40eeece [Josh Rosen] Merge remote-tracking branch 'origin/master' into filter-nulls-in-join-key
c57a954 [Yin Huai] Bug fix.
d3d2e64 [Yin Huai] First round of cleanup.
f9516b0 [Yin Huai] Style
c6667e7 [Yin Huai] Add PartitioningCollection.
e616d3b [Yin Huai] wip
7c2d2d8 [Yin Huai] Bug fix and refactoring.
69bb072 [Yin Huai] Introduce NullSafeHashPartitioning and NullUnsafePartitioning.
d5b84c3 [Yin Huai] Do not add unnessary filters.
2201129 [Yin Huai] Filter out rows that will not be joined in equal joins early.
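At the DataFrame level, the rewrite is equivalent to adding IS NOT NULL filters on the join keys of both inputs. A minimal sketch of that equivalence, not part of this patch (df1, df2, and the column name "key" are hypothetical):

    // With spark.sql.advancedOptimization enabled (the default), an inner
    // equi-join such as
    //   df1.join(df2, df1("key") === df2("key"))
    // is planned as if both inputs were pre-filtered on their join keys:
    val filteredLeft = df1.filter(df1("key").isNotNull)
    val filteredRight = df2.filter(df2("key").isNotNull)
    filteredLeft.join(filteredRight, filteredLeft("key") === filteredRight("key"))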
Diffstat (limited to 'sql/core')
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala | 2
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 6
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 5
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala | 160
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala | 236
5 files changed, 407 insertions(+), 2 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index a4fd4cf3b3..ea85f0657a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -122,7 +122,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = {
// Filtering condition:
// only keep the row if it has at least `minNonNulls` non-null and non-NaN values.
- val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name)))
+ val predicate = AtLeastNNonNullNans(minNonNulls, cols.map(name => df.resolve(name)))
df.filter(Column(predicate))
}
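For context, the renamed predicate backs DataFrameNaFunctions.drop, reachable through df.na; a brief usage sketch (df is a hypothetical DataFrame):

    // Keep only rows that have at least 2 non-null and non-NaN values
    // among the columns "a", "b", and "c".
    df.na.drop(2, Seq("a", "b", "c"))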
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 6644e85d4a..387960c4b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -413,6 +413,10 @@ private[spark] object SQLConf {
"spark.sql.useSerializer2",
defaultValue = Some(true), isPublic = false)
+ val ADVANCED_SQL_OPTIMIZATION = booleanConf(
+ "spark.sql.advancedOptimization",
+ defaultValue = Some(true), isPublic = false)
+
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
}
@@ -484,6 +488,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2)
+ private[spark] def advancedSqlOptimizations: Boolean = getConf(ADVANCED_SQL_OPTIMIZATION)
+
private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
private[spark] def defaultSizeInBytes: Long =
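The new conf is internal (isPublic = false), but it can still be toggled programmatically. A sketch, assuming an existing SQLContext named sqlContext:

    // Disable the optimizations guarded by this flag, including FilterNullsInJoinKey.
    sqlContext.setConf("spark.sql.advancedOptimization", "false")

    // Restore the default.
    sqlContext.setConf("spark.sql.advancedOptimization", "true")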
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index dbb2a09846..31e2b508d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.optimizer.FilterNullsInJoinKey
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -156,7 +157,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
}
@transient
- protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer
+ protected[sql] lazy val optimizer: Optimizer = new DefaultOptimizer {
+ override val extendedOperatorOptimizationRules = FilterNullsInJoinKey(self) :: Nil
+ }
@transient
protected[sql] val ddlParser = new DDLParser(sqlParser.parse(_))
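Since extendedOperatorOptimizationRules is a sequence of Rule[LogicalPlan], further rules can be chained in the same way; a hedged sketch (MyCustomRule is hypothetical, not part of this patch):

    @transient
    protected[sql] lazy val optimizer: Optimizer = new DefaultOptimizer {
      // Extension point for operator optimizations; these rules run in the
      // optimizer's fixed-point batches alongside the built-in ones.
      override val extendedOperatorOptimizationRules =
        FilterNullsInJoinKey(self) :: MyCustomRule :: Nil
    }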
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala b/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala
new file mode 100644
index 0000000000..5a4dde5756
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.optimizer
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
+import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter, LeftSemi}
+import org.apache.spark.sql.catalyst.plans.logical.{Project, Filter, Join, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * An optimization rule that inserts Filters to remove rows whose equi-join keys
+ * contain at least one null value. Such rows cannot contribute to the result of
+ * an equi-join, because a null never equals another null, so we can filter them
+ * out before shuffling the join's input rows. For example, given two tables
+ *
+ * table1(key String, value Int)
+ * "str1"|1
+ * null |2
+ *
+ * table2(key String, value Int)
+ * "str1"|3
+ * null |4
+ *
+ * the result of an inner equi-join is
+ * "str1"|1|"str1"|3
+ *
+ * The two rows with a null key do not contribute to the result, so we can
+ * filter them out early.
+ *
+ * This optimization rule can be disabled by setting spark.sql.advancedOptimization
+ * to false.
+ */
+case class FilterNullsInJoinKey(
+ sqlContext: SQLContext)
+ extends Rule[LogicalPlan] {
+
+ /**
+ * Checks whether we need to add a Filter operator: we add one when any
+ * attribute in `keys` is still nullable in `plan.output` (i.e., its
+ * `nullable` field is `true`).
+ */
+ private def needsFilter(keys: Seq[Expression], plan: LogicalPlan): Boolean = {
+ val keyAttributeSet = AttributeSet(keys.filter(_.isInstanceOf[Attribute]))
+ plan.output.filter(keyAttributeSet.contains).exists(_.nullable)
+ }
+
+ /**
+ * Adds a Filter operator to make sure that every attribute in `keys` is non-nullable.
+ */
+ private def addFilterIfNecessary(
+ keys: Seq[Expression],
+ child: LogicalPlan): LogicalPlan = {
+ // We get all attributes from keys.
+ val attributes = keys.filter(_.isInstanceOf[Attribute])
+
+ // Then, we create a Filter to make sure these attributes are non-nullable.
+ val filter =
+ if (attributes.nonEmpty) {
+ Filter(Not(AtLeastNNulls(1, attributes)), child)
+ } else {
+ child
+ }
+
+ filter
+ }
+
+ /**
+ * Reconstructs the join condition from the extracted join keys and the
+ * remaining predicate.
+ */
+ private def reconstructJoinCondition(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ otherPredicate: Option[Expression]): Expression = {
+ // First, rebuild the equality part of the condition. The keys were extracted
+ // with splitConjunctivePredicates, so it is safe to use .reduce(And).
+ val rewrittenEqualJoinCondition = leftKeys.zip(rightKeys).map {
+ case (l, r) => EqualTo(l, r)
+ }.reduce(And)
+
+ // Then, add otherPredicate. Because the equality conditions were extracted
+ // with splitConjunctivePredicates, it is safe to use
+ // And(rewrittenEqualJoinCondition, c).
+ val rewrittenJoinCondition = otherPredicate
+ .map(c => And(rewrittenEqualJoinCondition, c))
+ .getOrElse(rewrittenEqualJoinCondition)
+
+ rewrittenJoinCondition
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = {
+ if (!sqlContext.conf.advancedSqlOptimizations) {
+ plan
+ } else {
+ plan transform {
+ case join: Join => join match {
+ // For an inner join with an equi-join condition, we can add filters
+ // to both sides of the join operator.
+ case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
+ if needsFilter(leftKeys, left) || needsFilter(rightKeys, right) =>
+ val withLeftFilter = addFilterIfNecessary(leftKeys, left)
+ val withRightFilter = addFilterIfNecessary(rightKeys, right)
+ val rewrittenJoinCondition =
+ reconstructJoinCondition(leftKeys, rightKeys, condition)
+
+ Join(withLeftFilter, withRightFilter, Inner, Some(rewrittenJoinCondition))
+
+ // For a left outer join with an equi-join condition, we can add a filter
+ // to the right side of the join operator.
+ case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right)
+ if needsFilter(rightKeys, right) =>
+ val withRightFilter = addFilterIfNecessary(rightKeys, right)
+ val rewrittenJoinCondition =
+ reconstructJoinCondition(leftKeys, rightKeys, condition)
+
+ Join(left, withRightFilter, LeftOuter, Some(rewrittenJoinCondition))
+
+ // For a right outer join with an equi-join condition, we can add a filter
+ // to the left side of the join operator.
+ case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right)
+ if needsFilter(leftKeys, left) =>
+ val withLeftFilter = addFilterIfNecessary(leftKeys, left)
+ val rewrittenJoinCondition =
+ reconstructJoinCondition(leftKeys, rightKeys, condition)
+
+ Join(withLeftFilter, right, RightOuter, Some(rewrittenJoinCondition))
+
+ // For a left semi join with an equi-join condition, we can add filters
+ // to both sides of the join operator.
+ case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right)
+ if needsFilter(leftKeys, left) || needsFilter(rightKeys, right) =>
+ val withLeftFilter = addFilterIfNecessary(leftKeys, left)
+ val withRightFilter = addFilterIfNecessary(rightKeys, right)
+ val rewrittenJoinCondition =
+ reconstructJoinCondition(leftKeys, rightKeys, condition)
+
+ Join(withLeftFilter, withRightFilter, LeftSemi, Some(rewrittenJoinCondition))
+
+ case other => other
+ }
+ }
+ }
+ }
+}
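To make the rewrite concrete, here is a sketch of the before/after logical plans using the Catalyst test DSL (the same constructs the suite below uses; the relation shapes are made up for illustration):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.expressions.{AtLeastNNulls, Not}
    import org.apache.spark.sql.catalyst.plans.Inner
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val left = LocalRelation('a.int, 'b.int)
    val right = LocalRelation('c.int, 'd.int)

    // Input: an inner equi-join on nullable keys.
    val before = left.join(right, Inner, Some('a === 'c))

    // After FilterNullsInJoinKey fires, each side is wrapped in a filter
    // that rejects rows with a null join key.
    val after =
      left.where(Not(AtLeastNNulls(1, 'a.expr :: Nil)))
        .join(right.where(Not(AtLeastNNulls(1, 'c.expr :: Nil))),
          Inner, Some('a === 'c))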
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala
new file mode 100644
index 0000000000..f98e4acafb
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.optimizer
+
+import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.{Not, AtLeastNNulls}
+import org.apache.spark.sql.catalyst.optimizer._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.test.TestSQLContext
+
+/** This is the test suite for FilterNullsInJoinKey optimization rule. */
+class FilterNullsInJoinKeySuite extends PlanTest {
+
+ // We add the predicate pushdown rules here to make sure we do not
+ // create redundant Filter operators. Also, because the attribute ordering of
+ // the Project operator added by ColumnPruning may not be deterministic
+ // (the ordering may depend on the testing environment),
+ // we first construct the plan with the expected Filter operators and then
+ // run the optimizer to add the Project for column pruning.
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Subqueries", Once,
+ EliminateSubQueries) ::
+ Batch("Operator Optimizations", FixedPoint(100),
+ FilterNullsInJoinKey(TestSQLContext), // This is the rule we test in this suite.
+ CombineFilters,
+ PushPredicateThroughProject,
+ BooleanSimplification,
+ PushPredicateThroughJoin,
+ PushPredicateThroughGenerate,
+ ColumnPruning,
+ ProjectCollapsing) :: Nil
+ }
+
+ val leftRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.int)
+
+ val rightRelation = LocalRelation('e.int, 'f.int, 'g.int, 'h.int)
+
+ test("inner join") {
+ val joinCondition =
+ ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g)
+
+ val joinedPlan =
+ leftRelation
+ .join(rightRelation, Inner, Some(joinCondition))
+ .select('a, 'f, 'd, 'h)
+
+ val optimized = Optimize.execute(joinedPlan.analyze)
+
+ // For an inner join, FilterNullsInJoinKey adds filters to both sides.
+ val correctLeft =
+ leftRelation
+ .where(!(AtLeastNNulls(1, 'a.expr :: Nil)))
+
+ val correctRight =
+ rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil)))
+
+ val correctAnswer =
+ correctLeft
+ .join(correctRight, Inner, Some(joinCondition))
+ .select('a, 'f, 'd, 'h)
+
+ comparePlans(optimized, Optimize.execute(correctAnswer.analyze))
+ }
+
+ test("make sure we do not keep adding filters") {
+ val thirdRelation = LocalRelation('i.int, 'j.int, 'k.int, 'l.int)
+ val joinedPlan =
+ leftRelation
+ .join(rightRelation, Inner, Some('a === 'e))
+ .join(thirdRelation, Inner, Some('b === 'i && 'a === 'j))
+
+ val optimized = Optimize.execute(joinedPlan.analyze)
+ val conditions = optimized.collect {
+ case Filter(condition @ Not(AtLeastNNulls(1, exprs)), _) => exprs
+ }
+
+ // Make sure we have three Not(AtLeastNNulls(1, exprs)) filters, one for each of the three tables.
+ assert(conditions.length === 3)
+
+ // Make sure the attributes are indeed a, b, e, i, and j.
+ assert(
+ conditions.flatMap(exprs => exprs).toSet ===
+ joinedPlan.select('a, 'b, 'e, 'i, 'j).analyze.output.toSet)
+ }
+
+ test("inner join (partially optimized)") {
+ val joinCondition =
+ ('a + 2 === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g)
+
+ val joinedPlan =
+ leftRelation
+ .join(rightRelation, Inner, Some(joinCondition))
+ .select('a, 'f, 'd, 'h)
+
+ val optimized = Optimize.execute(joinedPlan.analyze)
+
+ // We cannot extract attributes from the left join keys, so only the right side is filtered.
+ val correctRight =
+ rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil)))
+
+ val correctAnswer =
+ leftRelation
+ .join(correctRight, Inner, Some(joinCondition))
+ .select('a, 'f, 'd, 'h)
+
+ comparePlans(optimized, Optimize.execute(correctAnswer.analyze))
+ }
+
+ test("inner join (not optimized)") {
+ val nonOptimizedJoinConditions =
+ Some('c - 100 + 'd === 'g + 1 - 'h) ::
+ Some('d > 'h || 'c === 'g) ::
+ Some('d + 'g + 'c > 'd - 'h) :: Nil
+
+ nonOptimizedJoinConditions.foreach { joinCondition =>
+ val joinedPlan =
+ leftRelation
+ .join(rightRelation.select('f, 'g, 'h), Inner, joinCondition)
+ .select('a, 'c, 'f, 'd, 'h, 'g)
+
+ val optimized = Optimize.execute(joinedPlan.analyze)
+
+ comparePlans(optimized, Optimize.execute(joinedPlan.analyze))
+ }
+ }
+
+ test("left outer join") {
+ val joinCondition =
+ ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g)
+
+ val joinedPlan =
+ leftRelation
+ .join(rightRelation, LeftOuter, Some(joinCondition))
+ .select('a, 'f, 'd, 'h)
+
+ val optimized = Optimize.execute(joinedPlan.analyze)
+
+ // For a left outer join, FilterNullsInJoinKey adds a filter to the right side.
+ val correctRight =
+ rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil)))
+
+ val correctAnswer =
+ leftRelation
+ .join(correctRight, LeftOuter, Some(joinCondition))
+ .select('a, 'f, 'd, 'h)
+
+ comparePlans(optimized, Optimize.execute(correctAnswer.analyze))
+ }
+
+ test("right outer join") {
+ val joinCondition =
+ ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g)
+
+ val joinedPlan =
+ leftRelation
+ .join(rightRelation, RightOuter, Some(joinCondition))
+ .select('a, 'f, 'd, 'h)
+
+ val optimized = Optimize.execute(joinedPlan.analyze)
+
+ // For a right outer join, FilterNullsInJoinKey adds a filter to the left side.
+ val correctLeft =
+ leftRelation
+ .where(!(AtLeastNNulls(1, 'a.expr :: Nil)))
+
+ val correctAnswer =
+ correctLeft
+ .join(rightRelation, RightOuter, Some(joinCondition))
+ .select('a, 'f, 'd, 'h)
+
+ comparePlans(optimized, Optimize.execute(correctAnswer.analyze))
+ }
+
+ test("full outer join") {
+ val joinCondition =
+ ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g)
+
+ val joinedPlan =
+ leftRelation
+ .join(rightRelation, FullOuter, Some(joinCondition))
+ .select('a, 'f, 'd, 'h)
+
+ // FilterNullsInJoinKey does not fire for a full outer join.
+ val optimized = Optimize.execute(joinedPlan.analyze)
+
+ comparePlans(optimized, Optimize.execute(joinedPlan.analyze))
+ }
+
+ test("left semi join") {
+ val joinCondition =
+ ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g)
+
+ val joinedPlan =
+ leftRelation
+ .join(rightRelation, LeftSemi, Some(joinCondition))
+ .select('a, 'd)
+
+ val optimized = Optimize.execute(joinedPlan.analyze)
+
+ // For a left semi join, FilterNullsInJoinKey adds filters to both sides.
+ val correctLeft =
+ leftRelation
+ .where(!(AtLeastNNulls(1, 'a.expr :: Nil)))
+
+ val correctRight =
+ rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil)))
+
+ val correctAnswer =
+ correctLeft
+ .join(correctRight, LeftSemi, Some(joinCondition))
+ .select('a, 'd)
+
+ comparePlans(optimized, Optimize.execute(correctAnswer.analyze))
+ }
+}