From 3e84ef0a54c53c45d7802cd2fecfa1c223580aee Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 19 Jan 2016 16:14:41 -0800 Subject: [SPARK-12770][SQL] Implement rules for branch elimination for CaseWhen The three optimization cases are: 1. If the first branch's condition is a true literal, remove the CaseWhen and use the value from that branch. 2. If a branch's condition is a false or null literal, remove that branch. 3. If only the else branch is left, remove the CaseWhen and use the value from the else branch. Author: Reynold Xin Closes #10827 from rxin/SPARK-12770. --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 18 +++++++++++ .../optimizer/SimplifyConditionalSuite.scala | 37 ++++++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index cc3371c08f..b7caa49f50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -635,6 +635,24 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { case q: LogicalPlan => q transformExpressionsUp { case If(TrueLiteral, trueValue, _) => trueValue case If(FalseLiteral, _, falseValue) => falseValue + + case e @ CaseWhen(branches, elseValue) if branches.exists(_._1 == FalseLiteral) => + // If there are branches that are always false, remove them. + // If there are no more branches left, just use the else value. + // Note that these two are handled together here in a single case statement because + // otherwise we cannot determine the data type for the elseValue if it is None (i.e. null). 
+ val newBranches = branches.filter(_._1 != FalseLiteral) + if (newBranches.isEmpty) { + elseValue.getOrElse(Literal.create(null, e.dataType)) + } else { + e.copy(branches = newBranches) + } + + case e @ CaseWhen(branches, _) if branches.headOption.map(_._1) == Some(TrueLiteral) => + // If the first branch is a true literal, remove the entire CaseWhen and use the value + // from that. Note that CaseWhen.branches should never be empty, and as a result the + // headOption (rather than head) added above is just an extra (and unnecessary) safeguard. + branches.head._2 } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index 8e5d7ef3c9..d436b627f6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.types.IntegerType class SimplifyConditionalSuite extends PlanTest with PredicateHelper { @@ -37,6 +38,10 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { comparePlans(actual, correctAnswer) } + private val trueBranch = (TrueLiteral, Literal(5)) + private val normalBranch = (NonFoldableLiteral(true), Literal(10)) + private val unreachableBranch = (FalseLiteral, Literal(20)) + test("simplify if") { assertEquivalent( If(TrueLiteral, Literal(10), Literal(20)), @@ -47,4 +52,36 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { Literal(20)) } + test("remove unreachable branches") { + // i.e. 
removing branches whose conditions are always false + assertEquivalent( + CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: Nil, None), + CaseWhen(normalBranch :: Nil, None)) + } + + test("remove entire CaseWhen if only the else branch is reachable") { + assertEquivalent( + CaseWhen(unreachableBranch :: unreachableBranch :: Nil, Some(Literal(30))), + Literal(30)) + + assertEquivalent( + CaseWhen(unreachableBranch :: unreachableBranch :: Nil, None), + Literal.create(null, IntegerType)) + } + + test("remove entire CaseWhen if the first branch is always true") { + assertEquivalent( + CaseWhen(trueBranch :: normalBranch :: Nil, None), + Literal(5)) + + // Test branch elimination and simplification in combination + assertEquivalent( + CaseWhen(unreachableBranch :: unreachableBranch:: trueBranch :: normalBranch :: Nil, None), + Literal(5)) + + // Make sure this doesn't trigger if there is a non-foldable branch before the true branch + assertEquivalent( + CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None), + CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None)) + } } -- cgit v1.2.3