From f705037617d55bb479ec60bcb1e55c736224be94 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 2 Apr 2016 17:48:53 -0700 Subject: [SPARK-14338][SQL] Improve `SimplifyConditionals` rule to handle `null` in IF/CASEWHEN ## What changes were proposed in this pull request? Currently, `SimplifyConditionals` only uses literal `true` and `false` conditions to prune branches. This PR extends `SimplifyConditionals` to also prune on literal `null` conditions in `If` and `CaseWhen` expressions: a `null` condition is never satisfied, so it is treated like `false` (an `If` collapses to its false value, and a `CaseWhen` branch with a `null` condition is removed). **Before** ``` scala> sql("SELECT IF(null, 1, 0)").explain() == Physical Plan == WholeStageCodegen : +- Project [if (null) 1 else 0 AS (IF(CAST(NULL AS BOOLEAN), 1, 0))#4] : +- INPUT +- Scan OneRowRelation[] scala> sql("select case when cast(null as boolean) then 1 else 2 end").explain() == Physical Plan == WholeStageCodegen : +- Project [CASE WHEN null THEN 1 ELSE 2 END AS CASE WHEN CAST(NULL AS BOOLEAN) THEN 1 ELSE 2 END#14] : +- INPUT +- Scan OneRowRelation[] ``` **After** ``` scala> sql("SELECT IF(null, 1, 0)").explain() == Physical Plan == WholeStageCodegen : +- Project [0 AS (IF(CAST(NULL AS BOOLEAN), 1, 0))#4] : +- INPUT +- Scan OneRowRelation[] scala> sql("select case when cast(null as boolean) then 1 else 2 end").explain() == Physical Plan == WholeStageCodegen : +- Project [2 AS CASE WHEN CAST(NULL AS BOOLEAN) THEN 1 ELSE 2 END#4] : +- INPUT +- Scan OneRowRelation[] ``` **Hive** ``` hive> select if(null,1,2); OK 2 hive> select case when cast(null as boolean) then 1 else 2 end; OK 2 ``` ## How was this patch tested? Pass the Jenkins tests (including new extended test cases). Author: Dongjoon Hyun Closes #12122 from dongjoon-hyun/SPARK-14338. 
--- .../apache/spark/sql/catalyst/optimizer/Optimizer.scala | 13 ++++++++++--- .../catalyst/optimizer/SimplifyConditionalSuite.scala | 16 +++++++++++----- 2 files changed, 21 insertions(+), 8 deletions(-) (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 326933ec9e..a5ab390c76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -527,7 +527,7 @@ object LikeSimplification extends Rule[LogicalPlan] { * Null value propagation from bottom to top of the expression tree. */ object NullPropagation extends Rule[LogicalPlan] { - def nonNullLiteral(e: Expression): Boolean = e match { + private def nonNullLiteral(e: Expression): Boolean = e match { case Literal(null, _) => false case _ => true } @@ -773,17 +773,24 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { * Simplifies conditional expressions (if / case). */ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { + private def falseOrNullLiteral(e: Expression): Boolean = e match { + case FalseLiteral => true + case Literal(null, _) => true + case _ => false + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { case If(TrueLiteral, trueValue, _) => trueValue case If(FalseLiteral, _, falseValue) => falseValue + case If(Literal(null, _), _, falseValue) => falseValue - case e @ CaseWhen(branches, elseValue) if branches.exists(_._1 == FalseLiteral) => + case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) => // If there are branches that are always false, remove them. // If there are no more branches left, just use the else value. 
// Note that these two are handled together here in a single case statement because // otherwise we cannot determine the data type for the elseValue if it is None (i.e. null). - val newBranches = branches.filter(_._1 != FalseLiteral) + val newBranches = branches.filter(x => !falseOrNullLiteral(x._1)) if (newBranches.isEmpty) { elseValue.getOrElse(Literal.create(null, e.dataType)) } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index d436b627f6..33239c0084 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{IntegerType, NullType} class SimplifyConditionalSuite extends PlanTest with PredicateHelper { @@ -41,6 +41,7 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { private val trueBranch = (TrueLiteral, Literal(5)) private val normalBranch = (NonFoldableLiteral(true), Literal(10)) private val unreachableBranch = (FalseLiteral, Literal(20)) + private val nullBranch = (Literal(null, NullType), Literal(30)) test("simplify if") { assertEquivalent( @@ -50,18 +51,22 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { assertEquivalent( If(FalseLiteral, Literal(10), Literal(20)), Literal(20)) + + assertEquivalent( + If(Literal(null, NullType), Literal(10), Literal(20)), + Literal(20)) } test("remove unreachable branches") { // i.e. 
removing branches whose conditions are always false assertEquivalent( - CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: Nil, None), + CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None), CaseWhen(normalBranch :: Nil, None)) } test("remove entire CaseWhen if only the else branch is reachable") { assertEquivalent( - CaseWhen(unreachableBranch :: unreachableBranch :: Nil, Some(Literal(30))), + CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: Nil, Some(Literal(30))), Literal(30)) assertEquivalent( @@ -71,12 +76,13 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { test("remove entire CaseWhen if the first branch is always true") { assertEquivalent( - CaseWhen(trueBranch :: normalBranch :: Nil, None), + CaseWhen(trueBranch :: normalBranch :: nullBranch :: Nil, None), Literal(5)) // Test branch elimination and simplification in combination assertEquivalent( - CaseWhen(unreachableBranch :: unreachableBranch:: trueBranch :: normalBranch :: Nil, None), + CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch + :: Nil, None), Literal(5)) // Make sure this doesn't trigger if there is a non-foldable branch before the true branch -- cgit v1.2.3