diff options
author | Reynold Xin <rxin@databricks.com> | 2016-01-12 10:58:57 -0800 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-01-12 10:58:57 -0800 |
commit | 1d8887953018b2e12b6ee47a76e50e542c836b80 (patch) | |
tree | 9ca1109fa245116ffb3e359196286b2671a75a2c /sql | |
parent | 7e15044d9d9f9839c8d422bae71f27e855d559b4 (diff) | |
download | spark-1d8887953018b2e12b6ee47a76e50e542c836b80.tar.gz spark-1d8887953018b2e12b6ee47a76e50e542c836b80.tar.bz2 spark-1d8887953018b2e12b6ee47a76e50e542c836b80.zip |
[SPARK-12762][SQL] Add unit test for SimplifyConditionals optimization rule
This pull request does a few small things:
1. Separated if simplification from BooleanSimplification and created a new rule SimplifyConditionals. In the future we can also simplify other conditional expressions here.
2. Added unit test for SimplifyConditionals.
3. Renamed SimplifyCaseConversionExpressionsSuite to SimplifyStringCaseConversionSuite
Author: Reynold Xin <rxin@databricks.com>
Closes #10716 from rxin/SPARK-12762.
Diffstat (limited to 'sql')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala | 10 | ||||
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 10 | ||||
-rw-r--r-- | sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala | 3 | ||||
-rw-r--r-- | sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala | 50 | ||||
-rw-r--r-- | sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyStringCaseConversionSuite.scala (renamed from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala) | 3 |
5 files changed, 69 insertions, 7 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 19da849d2b..379e62a26e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -45,7 +45,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi override def dataType: DataType = trueValue.dataType override def eval(input: InternalRow): Any = { - if (true == predicate.eval(input)) { + if (java.lang.Boolean.TRUE.equals(predicate.eval(input))) { trueValue.eval(input) } else { falseValue.eval(input) @@ -141,8 +141,8 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { } } - /** Written in imperative fashion for performance considerations. */ override def eval(input: InternalRow): Any = { + // Written in imperative fashion for performance considerations val len = branchesArr.length var i = 0 // If all branches fail and an elseVal is not provided, the whole statement @@ -389,7 +389,7 @@ case class Least(children: Seq[Expression]) extends Expression { val evalChildren = children.map(_.gen(ctx)) val first = evalChildren(0) val rest = evalChildren.drop(1) - def updateEval(eval: GeneratedExpressionCode): String = + def updateEval(eval: GeneratedExpressionCode): String = { s""" ${eval.code} if (!${eval.isNull} && (${ev.isNull} || @@ -398,6 +398,7 @@ case class Least(children: Seq[Expression]) extends Expression { ${ev.value} = ${eval.value}; } """ + } s""" ${first.code} boolean ${ev.isNull} = ${first.isNull}; @@ -447,7 +448,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { val evalChildren = children.map(_.gen(ctx)) val first = evalChildren(0) val rest = evalChildren.drop(1) - def updateEval(eval: GeneratedExpressionCode): String = + def updateEval(eval: GeneratedExpressionCode): String = { s""" ${eval.code} if (!${eval.isNull} && (${ev.isNull} || @@ -456,6 +457,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { ${ev.value} = ${eval.value}; } """ + } s""" ${first.code} boolean ${ev.isNull} = ${first.isNull}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b70bc184d0..487431f892 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -63,6 +63,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { ConstantFolding, LikeSimplification, BooleanSimplification, + SimplifyConditionals, RemoveDispensableExpressions, SimplifyFilters, SimplifyCasts, @@ -608,7 +609,16 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { case Not(a And b) => Or(Not(a), Not(b)) case Not(Not(e)) => e + } + } +} +/** + * Simplifies conditional expressions (if / case). + */ +object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressionsUp { case If(TrueLiteral, trueValue, _) => trueValue case If(FalseLiteral, _, falseValue) => falseValue } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index 9fe2b2d1f4..87ad81db11 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -34,7 +34,8 @@ class CombiningLimitsSuite extends PlanTest { Batch("Constant Folding", FixedPoint(10), NullPropagation, ConstantFolding, - BooleanSimplification) :: Nil + BooleanSimplification, + SimplifyConditionals) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala new file mode 100644 index 0000000000..8e5d7ef3c9 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + + +class SimplifyConditionalSuite extends PlanTest with PredicateHelper { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("SimplifyConditionals", FixedPoint(50), SimplifyConditionals) :: Nil + } + + protected def assertEquivalent(e1: Expression, e2: Expression): Unit = { + val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze + val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze) + comparePlans(actual, correctAnswer) + } + + test("simplify if") { + assertEquivalent( + If(TrueLiteral, Literal(10), Literal(20)), + Literal(10)) + + assertEquivalent( + If(FalseLiteral, Literal(10), Literal(20)), + Literal(20)) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyStringCaseConversionSuite.scala index 41455221cf..24413e7a2a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyStringCaseConversionSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -/* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ @@ -25,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules._ -class SimplifyCaseConversionExpressionsSuite extends PlanTest { +class SimplifyStringCaseConversionSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = |