aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDongjoon Hyun <dongjoon@apache.org>2016-04-02 17:48:53 -0700
committerReynold Xin <rxin@databricks.com>2016-04-02 17:48:53 -0700
commitf705037617d55bb479ec60bcb1e55c736224be94 (patch)
tree530cce0cd7b48a5a87e0f659c7a8cc3d0cb37126
parenta3e293542a6e7df9bcc7d9bbd22b3c93a81bcc38 (diff)
downloadspark-f705037617d55bb479ec60bcb1e55c736224be94.tar.gz
spark-f705037617d55bb479ec60bcb1e55c736224be94.tar.bz2
spark-f705037617d55bb479ec60bcb1e55c736224be94.zip
[SPARK-14338][SQL] Improve `SimplifyConditionals` rule to handle `null` in IF/CASEWHEN
## What changes were proposed in this pull request? Currently, `SimplifyConditionals` handles `true` and `false` to optimize branches. This PR improves `SimplifyConditionals` to take advantage of `null` conditions for `if` and `CaseWhen` expressions, too. **Before** ``` scala> sql("SELECT IF(null, 1, 0)").explain() == Physical Plan == WholeStageCodegen : +- Project [if (null) 1 else 0 AS (IF(CAST(NULL AS BOOLEAN), 1, 0))#4] : +- INPUT +- Scan OneRowRelation[] scala> sql("select case when cast(null as boolean) then 1 else 2 end").explain() == Physical Plan == WholeStageCodegen : +- Project [CASE WHEN null THEN 1 ELSE 2 END AS CASE WHEN CAST(NULL AS BOOLEAN) THEN 1 ELSE 2 END#14] : +- INPUT +- Scan OneRowRelation[] ``` **After** ``` scala> sql("SELECT IF(null, 1, 0)").explain() == Physical Plan == WholeStageCodegen : +- Project [0 AS (IF(CAST(NULL AS BOOLEAN), 1, 0))#4] : +- INPUT +- Scan OneRowRelation[] scala> sql("select case when cast(null as boolean) then 1 else 2 end").explain() == Physical Plan == WholeStageCodegen : +- Project [2 AS CASE WHEN CAST(NULL AS BOOLEAN) THEN 1 ELSE 2 END#4] : +- INPUT +- Scan OneRowRelation[] ``` **Hive** ``` hive> select if(null,1,2); OK 2 hive> select case when cast(null as boolean) then 1 else 2 end; OK 2 ``` ## How was this patch tested? Pass the Jenkins tests (including new extended test cases). Author: Dongjoon Hyun <dongjoon@apache.org> Closes #12122 from dongjoon-hyun/SPARK-14338.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala13
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala16
2 files changed, 21 insertions, 8 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 326933ec9e..a5ab390c76 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -527,7 +527,7 @@ object LikeSimplification extends Rule[LogicalPlan] {
* Null value propagation from bottom to top of the expression tree.
*/
object NullPropagation extends Rule[LogicalPlan] {
- def nonNullLiteral(e: Expression): Boolean = e match {
+ private def nonNullLiteral(e: Expression): Boolean = e match {
case Literal(null, _) => false
case _ => true
}
@@ -773,17 +773,24 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
* Simplifies conditional expressions (if / case).
*/
object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
+ private def falseOrNullLiteral(e: Expression): Boolean = e match {
+ case FalseLiteral => true
+ case Literal(null, _) => true
+ case _ => false
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
case If(TrueLiteral, trueValue, _) => trueValue
case If(FalseLiteral, _, falseValue) => falseValue
+ case If(Literal(null, _), _, falseValue) => falseValue
- case e @ CaseWhen(branches, elseValue) if branches.exists(_._1 == FalseLiteral) =>
+ case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
// If there are branches that are always false, remove them.
// If there are no more branches left, just use the else value.
// Note that these two are handled together here in a single case statement because
// otherwise we cannot determine the data type for the elseValue if it is None (i.e. null).
- val newBranches = branches.filter(_._1 != FalseLiteral)
+ val newBranches = branches.filter(x => !falseOrNullLiteral(x._1))
if (newBranches.isEmpty) {
elseValue.getOrElse(Literal.create(null, e.dataType))
} else {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
index d436b627f6..33239c0084 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{IntegerType, NullType}
class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
@@ -41,6 +41,7 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
private val trueBranch = (TrueLiteral, Literal(5))
private val normalBranch = (NonFoldableLiteral(true), Literal(10))
private val unreachableBranch = (FalseLiteral, Literal(20))
+ private val nullBranch = (Literal(null, NullType), Literal(30))
test("simplify if") {
assertEquivalent(
@@ -50,18 +51,22 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
assertEquivalent(
If(FalseLiteral, Literal(10), Literal(20)),
Literal(20))
+
+ assertEquivalent(
+ If(Literal(null, NullType), Literal(10), Literal(20)),
+ Literal(20))
}
test("remove unreachable branches") {
// i.e. removing branches whose conditions are always false
assertEquivalent(
- CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: Nil, None),
+ CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None),
CaseWhen(normalBranch :: Nil, None))
}
test("remove entire CaseWhen if only the else branch is reachable") {
assertEquivalent(
- CaseWhen(unreachableBranch :: unreachableBranch :: Nil, Some(Literal(30))),
+ CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: Nil, Some(Literal(30))),
Literal(30))
assertEquivalent(
@@ -71,12 +76,13 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
test("remove entire CaseWhen if the first branch is always true") {
assertEquivalent(
- CaseWhen(trueBranch :: normalBranch :: Nil, None),
+ CaseWhen(trueBranch :: normalBranch :: nullBranch :: Nil, None),
Literal(5))
// Test branch elimination and simplification in combination
assertEquivalent(
- CaseWhen(unreachableBranch :: unreachableBranch:: trueBranch :: normalBranch :: Nil, None),
+ CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch
+ :: Nil, None),
Literal(5))
// Make sure this doesn't trigger if there is a non-foldable branch before the true branch