about summary refs log tree commit diff
path: root/sql
diff options
context:
space:
mode:
authorReynold Xin <rxin@databricks.com>2016-01-12 10:58:57 -0800
committerReynold Xin <rxin@databricks.com>2016-01-12 10:58:57 -0800
commit1d8887953018b2e12b6ee47a76e50e542c836b80 (patch)
tree9ca1109fa245116ffb3e359196286b2671a75a2c /sql
parent7e15044d9d9f9839c8d422bae71f27e855d559b4 (diff)
downloadspark-1d8887953018b2e12b6ee47a76e50e542c836b80.tar.gz
spark-1d8887953018b2e12b6ee47a76e50e542c836b80.tar.bz2
spark-1d8887953018b2e12b6ee47a76e50e542c836b80.zip
[SPARK-12762][SQL] Add unit test for SimplifyConditionals optimization rule
This pull request does a few small things: 1. Separated if simplification from BooleanSimplification and created a new rule SimplifyConditionals. In the future we can also simplify other conditional expressions here. 2. Added unit test for SimplifyConditionals. 3. Renamed SimplifyCaseConversionExpressionsSuite to SimplifyStringCaseConversionSuite Author: Reynold Xin <rxin@databricks.com> Closes #10716 from rxin/SPARK-12762.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala10
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala10
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala3
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala50
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyStringCaseConversionSuite.scala (renamed from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala)3
5 files changed, 69 insertions, 7 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 19da849d2b..379e62a26e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -45,7 +45,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
override def dataType: DataType = trueValue.dataType
override def eval(input: InternalRow): Any = {
- if (true == predicate.eval(input)) {
+ if (java.lang.Boolean.TRUE.equals(predicate.eval(input))) {
trueValue.eval(input)
} else {
falseValue.eval(input)
@@ -141,8 +141,8 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {
}
}
- /** Written in imperative fashion for performance considerations. */
override def eval(input: InternalRow): Any = {
+ // Written in imperative fashion for performance considerations
val len = branchesArr.length
var i = 0
// If all branches fail and an elseVal is not provided, the whole statement
@@ -389,7 +389,7 @@ case class Least(children: Seq[Expression]) extends Expression {
val evalChildren = children.map(_.gen(ctx))
val first = evalChildren(0)
val rest = evalChildren.drop(1)
- def updateEval(eval: GeneratedExpressionCode): String =
+ def updateEval(eval: GeneratedExpressionCode): String = {
s"""
${eval.code}
if (!${eval.isNull} && (${ev.isNull} ||
@@ -398,6 +398,7 @@ case class Least(children: Seq[Expression]) extends Expression {
${ev.value} = ${eval.value};
}
"""
+ }
s"""
${first.code}
boolean ${ev.isNull} = ${first.isNull};
@@ -447,7 +448,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
val evalChildren = children.map(_.gen(ctx))
val first = evalChildren(0)
val rest = evalChildren.drop(1)
- def updateEval(eval: GeneratedExpressionCode): String =
+ def updateEval(eval: GeneratedExpressionCode): String = {
s"""
${eval.code}
if (!${eval.isNull} && (${ev.isNull} ||
@@ -456,6 +457,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
${ev.value} = ${eval.value};
}
"""
+ }
s"""
${first.code}
boolean ${ev.isNull} = ${first.isNull};
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index b70bc184d0..487431f892 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -63,6 +63,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
ConstantFolding,
LikeSimplification,
BooleanSimplification,
+ SimplifyConditionals,
RemoveDispensableExpressions,
SimplifyFilters,
SimplifyCasts,
@@ -608,7 +609,16 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
case Not(a And b) => Or(Not(a), Not(b))
case Not(Not(e)) => e
+ }
+ }
+}
+/**
+ * Simplifies conditional expressions (if / case).
+ */
+object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case q: LogicalPlan => q transformExpressionsUp {
case If(TrueLiteral, trueValue, _) => trueValue
case If(FalseLiteral, _, falseValue) => falseValue
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
index 9fe2b2d1f4..87ad81db11 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
@@ -34,7 +34,8 @@ class CombiningLimitsSuite extends PlanTest {
Batch("Constant Folding", FixedPoint(10),
NullPropagation,
ConstantFolding,
- BooleanSimplification) :: Nil
+ BooleanSimplification,
+ SimplifyConditionals) :: Nil
}
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
new file mode 100644
index 0000000000..8e5d7ef3c9
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+
+
+class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches = Batch("SimplifyConditionals", FixedPoint(50), SimplifyConditionals) :: Nil
+ }
+
+ protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
+ val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze
+ val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze)
+ comparePlans(actual, correctAnswer)
+ }
+
+ test("simplify if") {
+ assertEquivalent(
+ If(TrueLiteral, Literal(10), Literal(20)),
+ Literal(10))
+
+ assertEquivalent(
+ If(FalseLiteral, Literal(10), Literal(20)),
+ Literal(20))
+ }
+
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyStringCaseConversionSuite.scala
index 41455221cf..24413e7a2a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyStringCaseConversionSuite.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.optimizer
-/* Implicit conversions */
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
@@ -25,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules._
-class SimplifyCaseConversionExpressionsSuite extends PlanTest {
+class SimplifyStringCaseConversionSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =