author     Yin Huai <yhuai@databricks.com>        2015-08-02 23:32:09 -0700
committer  Josh Rosen <joshrosen@databricks.com>  2015-08-02 23:32:09 -0700
commit     687c8c37150f4c93f8e57d86bb56321a4891286b (patch)
tree       5fc768cdf7b01dae261706c148c7fcd3cf622b9d /sql/catalyst/src
parent     4cdd8ecd66769316e8593da7790b84cd867968cd (diff)
[SPARK-9372] [SQL] Filter nulls in join keys
This PR adds an optimization rule, `FilterNullsInJoinKey`, to add `Filter` before join operators to filter out rows having null values for join keys. This optimization is guarded by a new SQL conf, `spark.sql.advancedOptimization`. The code in this PR was authored by yhuai; I'm opening this PR to factor out this change from #7685, a larger pull request which contains two other optimizations.

Author: Yin Huai <yhuai@databricks.com>
Author: Josh Rosen <joshrosen@databricks.com>

Closes #7768 from JoshRosen/filter-nulls-in-join-key and squashes the following commits:

c02fc3f [Yin Huai] Address Josh's comments.
0a8e096 [Yin Huai] Update comments.
ea7d5a6 [Yin Huai] Make sure we do not keep adding filters.
be88760 [Yin Huai] Make it clear that FilterNullsInJoinKeySuite.scala is used to test FilterNullsInJoinKey.
8bb39ad [Yin Huai] Fix non-deterministic tests.
303236b [Josh Rosen] Revert changes that are unrelated to null join key filtering
40eeece [Josh Rosen] Merge remote-tracking branch 'origin/master' into filter-nulls-in-join-key
c57a954 [Yin Huai] Bug fix.
d3d2e64 [Yin Huai] First round of cleanup.
f9516b0 [Yin Huai] Style
c6667e7 [Yin Huai] Add PartitioningCollection.
e616d3b [Yin Huai] wip
7c2d2d8 [Yin Huai] Bug fix and refactoring.
69bb072 [Yin Huai] Introduce NullSafeHashPartitioning and NullUnsafePartitioning.
d5b84c3 [Yin Huai] Do not add unnecessary filters.
2201129 [Yin Huai] Filter out rows that will not be joined in equal joins early.
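Editor's note: for intuition, the core rewrite can be sketched as below. This is a minimal, hypothetical reduction of `FilterNullsInJoinKey` (the rule name is from this PR, but the body here is illustrative, not the committed implementation; the real rule also handles other join types and makes sure filters are not re-added across optimizer passes). It uses the `AtLeastNNulls` predicate introduced in this diff: for an inner equi-join, a row whose join key is null can never match, so such rows can be dropped below the join.

    import org.apache.spark.sql.catalyst.expressions.{AtLeastNNulls, Attribute, EqualTo, Not}
    import org.apache.spark.sql.catalyst.plans.Inner
    import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan}
    import org.apache.spark.sql.catalyst.rules.Rule

    // Illustrative sketch only. Not(AtLeastNNulls(1, keys)) is true exactly
    // when none of the keys evaluates to null.
    object FilterNullsInJoinKeySketch extends Rule[LogicalPlan] {
      def apply(plan: LogicalPlan): LogicalPlan = plan transform {
        case j @ Join(left, right, Inner, Some(EqualTo(l: Attribute, r: Attribute)))
            if left.outputSet.contains(l) && right.outputSet.contains(r) &&
              (l.nullable || r.nullable) =>
          // Drop null-keyed rows on each side before the join executes.
          // (The committed rule additionally guards against adding the same
          // filter again on a later fixed-point iteration.)
          j.copy(
            left = Filter(Not(AtLeastNNulls(1, l :: Nil)), left),
            right = Filter(Not(AtLeastNNulls(1, r :: Nil)), right))
      }
    }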
Diffstat (limited to 'sql/catalyst/src')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala        | 48
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala              | 64
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala     | 32
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala |  4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala   |  3
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala   | 49
6 files changed, 165 insertions(+), 35 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
index 287718fab7..d58c475693 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
@@ -210,14 +210,58 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
}
}
+/**
+ * A predicate that evaluates to true if at least `n` of its children evaluate to null.
+ */
+case class AtLeastNNulls(n: Int, children: Seq[Expression]) extends Predicate {
+ override def nullable: Boolean = false
+ override def foldable: Boolean = children.forall(_.foldable)
+ override def toString: String = s"AtLeastNNulls($n, ${children.mkString(",")})"
+
+ private[this] val childrenArray = children.toArray
+
+ override def eval(input: InternalRow): Boolean = {
+ var numNulls = 0
+ var i = 0
+ while (i < childrenArray.length && numNulls < n) {
+ val evalC = childrenArray(i).eval(input)
+ if (evalC == null) {
+ numNulls += 1
+ }
+ i += 1
+ }
+ numNulls >= n
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val numNulls = ctx.freshName("numNulls")
+ val code = children.map { e =>
+ val eval = e.gen(ctx)
+ s"""
+ if ($numNulls < $n) {
+ ${eval.code}
+ if (${eval.isNull}) {
+ $numNulls += 1;
+ }
+ }
+ """
+ }.mkString("\n")
+ s"""
+ int $numNulls = 0;
+ $code
+ boolean ${ev.isNull} = false;
+ boolean ${ev.primitive} = $numNulls >= $n;
+ """
+ }
+}
/**
* A predicate that is evaluated to be true if there are at least `n` non-null and non-NaN values.
*/
-case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate {
+case class AtLeastNNonNullNans(n: Int, children: Seq[Expression]) extends Predicate {
override def nullable: Boolean = false
override def foldable: Boolean = children.forall(_.foldable)
- override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})"
+ override def toString: String = s"AtLeastNNonNullNans($n, ${children.mkString(",")})"
private[this] val childrenArray = children.toArray
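Editor's note: as a quick illustration of the `eval` contract above (the loop short-circuits once `n` nulls have been seen and returns whether the threshold was reached), here is a hypothetical snippet mirroring the new NullFunctionsSuite cases further down:

    import org.apache.spark.sql.catalyst.expressions.{AtLeastNNulls, EmptyRow, Literal}
    import org.apache.spark.sql.types.StringType

    val children = Seq(Literal("x"), Literal.create(null, StringType))
    // Exactly one child is null, so the predicate holds for n = 1 ...
    assert(AtLeastNNulls(1, children).eval(EmptyRow))
    // ... but not for n = 2.
    assert(!AtLeastNNulls(2, children).eval(EmptyRow))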
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 29d706dcb3..e4b6294dc7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -31,8 +31,14 @@ import org.apache.spark.sql.types._
abstract class Optimizer extends RuleExecutor[LogicalPlan]
-object DefaultOptimizer extends Optimizer {
- val batches =
+class DefaultOptimizer extends Optimizer {
+
+ /**
+ * Override to provide additional rules for the "Operator Optimizations" batch.
+ */
+ val extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil
+
+ lazy val batches =
// SubQueries are only needed for analysis and can be removed before execution.
Batch("Remove SubQueries", FixedPoint(100),
EliminateSubQueries) ::
@@ -41,26 +47,27 @@ object DefaultOptimizer extends Optimizer {
RemoveLiteralFromGroupExpressions) ::
Batch("Operator Optimizations", FixedPoint(100),
// Operator push down
- SetOperationPushDown,
- SamplePushDown,
- PushPredicateThroughJoin,
- PushPredicateThroughProject,
- PushPredicateThroughGenerate,
- ColumnPruning,
+ SetOperationPushDown ::
+ SamplePushDown ::
+ PushPredicateThroughJoin ::
+ PushPredicateThroughProject ::
+ PushPredicateThroughGenerate ::
+ ColumnPruning ::
// Operator combine
- ProjectCollapsing,
- CombineFilters,
- CombineLimits,
+ ProjectCollapsing ::
+ CombineFilters ::
+ CombineLimits ::
// Constant folding
- NullPropagation,
- OptimizeIn,
- ConstantFolding,
- LikeSimplification,
- BooleanSimplification,
- RemovePositive,
- SimplifyFilters,
- SimplifyCasts,
- SimplifyCaseConversionExpressions) ::
+ NullPropagation ::
+ OptimizeIn ::
+ ConstantFolding ::
+ LikeSimplification ::
+ BooleanSimplification ::
+ RemovePositive ::
+ SimplifyFilters ::
+ SimplifyCasts ::
+ SimplifyCaseConversionExpressions ::
+ extendedOperatorOptimizationRules.toList : _*) ::
Batch("Decimal Optimizations", FixedPoint(100),
DecimalAggregates) ::
Batch("LocalRelation", FixedPoint(100),
@@ -222,12 +229,18 @@ object ColumnPruning extends Rule[LogicalPlan] {
}
/** Applies a projection only when the child is producing unnecessary attributes */
- private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
+ private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = {
if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) {
- Project(allReferences.filter(c.outputSet.contains).toSeq, c)
+ // We need to preserve the nullability of c's output.
+ // So, we first create an outputMap and, if a reference comes from the
+ // output of c, we use that output attribute from c.
+ val outputMap = AttributeMap(c.output.map(attr => (attr, attr)))
+ val projectList = allReferences.filter(outputMap.contains).map(outputMap).toSeq
+ Project(projectList, c)
} else {
c
}
+ }
}
/**
@@ -517,6 +530,13 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
*/
object CombineFilters extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case Filter(Not(AtLeastNNulls(1, e1)), Filter(Not(AtLeastNNulls(1, e2)), grandChild)) =>
+ // If we are combining two expressions Not(AtLeastNNulls(1, e1)) and
+ // Not(AtLeastNNulls(1, e2)) (these ensure that there is no null in the
+ // results of e1 and e2; they are added by the FilterNullsInJoinKey
+ // optimization rule), we can just create a single
+ // Not(AtLeastNNulls(1, (e1 ++ e2).distinct)).
+ Filter(Not(AtLeastNNulls(1, (e1 ++ e2).distinct)), grandChild)
case ff @ Filter(fc, nf @ Filter(nc, grandChild)) => Filter(And(nc, fc), grandChild)
}
}
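Editor's note: the new case collapses two stacked no-null filters into one over the distinct union of their key sets. A hypothetical illustration, assuming the catalyst DSL is in scope to build resolved attributes:

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.expressions.{AtLeastNNulls, Not}
    import org.apache.spark.sql.catalyst.optimizer.CombineFilters
    import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation}

    val a = 'a.int
    val b = 'b.int

    // Two no-null filters of the shape added by FilterNullsInJoinKey, stacked.
    val stacked =
      Filter(Not(AtLeastNNulls(1, Seq(a))),
        Filter(Not(AtLeastNNulls(1, Seq(a, b))), LocalRelation(a, b)))

    // The new case merges them into a single check over (e1 ++ e2).distinct:
    // Filter(Not(AtLeastNNulls(1, Seq(a, b))), LocalRelation(a, b)).
    val combined = CombineFilters(stacked)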
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index aacfc86ab0..54b5f49772 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -86,7 +86,37 @@ case class Generate(
}
case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
- override def output: Seq[Attribute] = child.output
+ /**
+ * Indicates whether `atLeastNNulls` is used to check that
+ * `atLeastNNulls.children` contain at least one null value and that all of
+ * `atLeastNNulls.children` are attributes.
+ */
+ private def isAtLeastOneNullOutputAttributes(atLeastNNulls: AtLeastNNulls): Boolean = {
+ val expressions = atLeastNNulls.children
+ val n = atLeastNNulls.n
+ if (n != 1) {
+ // AtLeastNNulls is not used to check if atLeastNNulls.children have
+ // at least one null value.
+ false
+ } else {
+ // AtLeastNNulls is used to check if atLeastNNulls.children have
+ // at least one null value. We need to make sure all atLeastNNulls.children
+ // are attributes.
+ expressions.forall(_.isInstanceOf[Attribute])
+ }
+ }
+
+ override def output: Seq[Attribute] = condition match {
+ case Not(a: AtLeastNNulls) if isAtLeastOneNullOutputAttributes(a) =>
+ // The condition is used to make sure that there is no null value in
+ // a.children.
+ val nonNullableAttributes = AttributeSet(a.children.asInstanceOf[Seq[Attribute]])
+ child.output.map {
+ case attr if nonNullableAttributes.contains(attr) =>
+ attr.withNullability(false)
+ case attr => attr
+ }
+ case _ => child.output
+ }
}
case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
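Editor's note: the effect of the new `output` method is that a no-null filter tightens the nullability reported downstream. A hypothetical sketch using the catalyst DSL (attributes built via `'a.int` are nullable by default):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.expressions.{AtLeastNNulls, Not}
    import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation}

    // 'a is nullable in the relation ...
    val a = 'a.int
    val relation = LocalRelation(a)

    // ... but a filter guaranteeing "no null among Seq(a)" reports it as
    // non-nullable, so later rules can exploit the guarantee.
    val filtered = Filter(Not(AtLeastNNulls(1, Seq(a))), relation)
    assert(!filtered.output.head.nullable)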
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index a41185b4d8..3e55151298 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -31,6 +31,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
trait ExpressionEvalHelper {
self: SparkFunSuite =>
+ protected val defaultOptimizer = new DefaultOptimizer
+
protected def create_row(values: Any*): InternalRow = {
InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst))
}
@@ -186,7 +188,7 @@ trait ExpressionEvalHelper {
expected: Any,
inputRow: InternalRow = EmptyRow): Unit = {
val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation)
- val optimizedPlan = DefaultOptimizer.execute(plan)
+ val optimizedPlan = defaultOptimizer.execute(plan)
checkEvaluationWithoutCodegen(optimizedPlan.expressions.head, expected, inputRow)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 9fcb548af6..649a5b44dc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -23,7 +23,6 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
import org.apache.spark.sql.types._
@@ -149,7 +148,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
expression: Expression,
inputRow: InternalRow = EmptyRow): Unit = {
val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation)
- val optimizedPlan = DefaultOptimizer.execute(plan)
+ val optimizedPlan = defaultOptimizer.execute(plan)
checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
index ace6c15dc8..bf197124d8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
@@ -77,7 +77,7 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
- test("AtLeastNNonNulls") {
+ test("AtLeastNNonNullNans") {
val mix = Seq(Literal("x"),
Literal.create(null, StringType),
Literal.create(null, DoubleType),
@@ -96,11 +96,46 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
Literal(Float.MaxValue),
Literal(false))
- checkEvaluation(AtLeastNNonNulls(2, mix), true, EmptyRow)
- checkEvaluation(AtLeastNNonNulls(3, mix), false, EmptyRow)
- checkEvaluation(AtLeastNNonNulls(3, nanOnly), true, EmptyRow)
- checkEvaluation(AtLeastNNonNulls(4, nanOnly), false, EmptyRow)
- checkEvaluation(AtLeastNNonNulls(3, nullOnly), true, EmptyRow)
- checkEvaluation(AtLeastNNonNulls(4, nullOnly), false, EmptyRow)
+ checkEvaluation(AtLeastNNonNullNans(0, mix), true, EmptyRow)
+ checkEvaluation(AtLeastNNonNullNans(2, mix), true, EmptyRow)
+ checkEvaluation(AtLeastNNonNullNans(3, mix), false, EmptyRow)
+ checkEvaluation(AtLeastNNonNullNans(0, nanOnly), true, EmptyRow)
+ checkEvaluation(AtLeastNNonNullNans(3, nanOnly), true, EmptyRow)
+ checkEvaluation(AtLeastNNonNullNans(4, nanOnly), false, EmptyRow)
+ checkEvaluation(AtLeastNNonNullNans(0, nullOnly), true, EmptyRow)
+ checkEvaluation(AtLeastNNonNullNans(3, nullOnly), true, EmptyRow)
+ checkEvaluation(AtLeastNNonNullNans(4, nullOnly), false, EmptyRow)
+ }
+
+ test("AtLeastNNull") {
+ val mix = Seq(Literal("x"),
+ Literal.create(null, StringType),
+ Literal.create(null, DoubleType),
+ Literal(Double.NaN),
+ Literal(5f))
+
+ val nanOnly = Seq(Literal("x"),
+ Literal(10.0),
+ Literal(Float.NaN),
+ Literal(math.log(-2)),
+ Literal(Double.MaxValue))
+
+ val nullOnly = Seq(Literal("x"),
+ Literal.create(null, DoubleType),
+ Literal.create(null, DecimalType.USER_DEFAULT),
+ Literal(Float.MaxValue),
+ Literal(false))
+
+ checkEvaluation(AtLeastNNulls(0, mix), true, EmptyRow)
+ checkEvaluation(AtLeastNNulls(1, mix), true, EmptyRow)
+ checkEvaluation(AtLeastNNulls(2, mix), true, EmptyRow)
+ checkEvaluation(AtLeastNNulls(3, mix), false, EmptyRow)
+ checkEvaluation(AtLeastNNulls(0, nanOnly), true, EmptyRow)
+ checkEvaluation(AtLeastNNulls(1, nanOnly), false, EmptyRow)
+ checkEvaluation(AtLeastNNulls(2, nanOnly), false, EmptyRow)
+ checkEvaluation(AtLeastNNulls(0, nullOnly), true, EmptyRow)
+ checkEvaluation(AtLeastNNulls(1, nullOnly), true, EmptyRow)
+ checkEvaluation(AtLeastNNulls(2, nullOnly), true, EmptyRow)
+ checkEvaluation(AtLeastNNulls(3, nullOnly), false, EmptyRow)
}
}