aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorReynold Xin <rxin@databricks.com>2016-01-19 16:14:41 -0800
committerReynold Xin <rxin@databricks.com>2016-01-19 16:14:41 -0800
commit3e84ef0a54c53c45d7802cd2fecfa1c223580aee (patch)
tree9e8fe9e481d64ce42a438e5b7c43c01679c4bf85
parentf6f7ca9d2ef65da15f42085993e58e618637fad5 (diff)
downloadspark-3e84ef0a54c53c45d7802cd2fecfa1c223580aee.tar.gz
spark-3e84ef0a54c53c45d7802cd2fecfa1c223580aee.tar.bz2
spark-3e84ef0a54c53c45d7802cd2fecfa1c223580aee.zip
[SPARK-12770][SQL] Implement rules for branch elimination for CaseWhen
The three optimization cases are:

1. If the first branch's condition is a true literal, remove the CaseWhen and use the value from that branch.
2. If a branch's condition is a false or null literal, remove that branch.
3. If only the else branch is left, remove the CaseWhen and use the value from the else branch.

Author: Reynold Xin <rxin@databricks.com>

Closes #10827 from rxin/SPARK-12770.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala18
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala37
2 files changed, 55 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index cc3371c08f..b7caa49f50 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -635,6 +635,24 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
case q: LogicalPlan => q transformExpressionsUp {
case If(TrueLiteral, trueValue, _) => trueValue
case If(FalseLiteral, _, falseValue) => falseValue
+
+ case e @ CaseWhen(branches, elseValue) if branches.exists(_._1 == FalseLiteral) =>
+ // If there are branches that are always false, remove them.
+ // If there are no more branches left, just use the else value.
+ // Note that these two are handled together here in a single case statement because
+ // otherwise we cannot determine the data type for the elseValue if it is None (i.e. null).
+ val newBranches = branches.filter(_._1 != FalseLiteral)
+ if (newBranches.isEmpty) {
+ elseValue.getOrElse(Literal.create(null, e.dataType))
+ } else {
+ e.copy(branches = newBranches)
+ }
+
+ case e @ CaseWhen(branches, _) if branches.headOption.map(_._1) == Some(TrueLiteral) =>
+ // If the first branch is a true literal, remove the entire CaseWhen and use the value
+ // from that. Note that CaseWhen.branches should never be empty, and as a result the
+ // headOption (rather than head) added above is just an extra (and unnecessary) safeguard.
+ branches.head._2
}
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
index 8e5d7ef3c9..d436b627f6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.types.IntegerType
class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
@@ -37,6 +38,10 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
comparePlans(actual, correctAnswer)
}
+ private val trueBranch = (TrueLiteral, Literal(5))
+ private val normalBranch = (NonFoldableLiteral(true), Literal(10))
+ private val unreachableBranch = (FalseLiteral, Literal(20))
+
test("simplify if") {
assertEquivalent(
If(TrueLiteral, Literal(10), Literal(20)),
@@ -47,4 +52,36 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
Literal(20))
}
+ test("remove unreachable branches") {
+ // i.e. removing branches whose conditions are always false
+ assertEquivalent(
+ CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: Nil, None),
+ CaseWhen(normalBranch :: Nil, None))
+ }
+
+ test("remove entire CaseWhen if only the else branch is reachable") {
+ assertEquivalent(
+ CaseWhen(unreachableBranch :: unreachableBranch :: Nil, Some(Literal(30))),
+ Literal(30))
+
+ assertEquivalent(
+ CaseWhen(unreachableBranch :: unreachableBranch :: Nil, None),
+ Literal.create(null, IntegerType))
+ }
+
+ test("remove entire CaseWhen if the first branch is always true") {
+ assertEquivalent(
+ CaseWhen(trueBranch :: normalBranch :: Nil, None),
+ Literal(5))
+
+ // Test branch elimination and simplification in combination
+ assertEquivalent(
+ CaseWhen(unreachableBranch :: unreachableBranch:: trueBranch :: normalBranch :: Nil, None),
+ Literal(5))
+
+ // Make sure this doesn't trigger if there is a non-foldable branch before the true branch
+ assertEquivalent(
+ CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None),
+ CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None))
+ }
}