From e3955643d6f838146e8b2e0463b27612d8e48d02 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 11 Jun 2014 17:58:35 -0700 Subject: [SPARK-2052] [SQL] Add optimization for CaseConversionExpression's. Add optimization for `CaseConversionExpression`'s. Author: Takuya UESHIN Closes #990 from ueshin/issues/SPARK-2052 and squashes the following commits: 2568666 [Takuya UESHIN] Move some rules back. dde7ede [Takuya UESHIN] Add tests to check if ConstantFolding can handle null literals and remove the unneeded rules from NullPropagation. c4eea67 [Takuya UESHIN] Fix toString methods. 23e2363 [Takuya UESHIN] Make CaseConversionExpressions foldable if the child is foldable. 0ff7568 [Takuya UESHIN] Add tests for collapsing case statements. 3977d80 [Takuya UESHIN] Add optimization for CaseConversionExpression's. (cherry picked from commit 9a2448daf984d5bb550dfe0d9e28cbb80ef5cb51) Signed-off-by: Michael Armbrust --- .../catalyst/expressions/stringOperations.scala | 7 +- .../spark/sql/catalyst/optimizer/Optimizer.scala | 30 +++---- .../catalyst/optimizer/ConstantFoldingSuite.scala | 61 ++++++++++++++- .../SimplifyCaseConversionExpressionsSuite.scala | 91 ++++++++++++++++++++++ 4 files changed, 174 insertions(+), 15 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 4203034084..c074b7bb01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -76,7 +76,8 @@ trait CaseConversionExpression { type EvaluatedType = Any def convert(v: String): String - + + override def foldable: Boolean = child.foldable def 
nullable: Boolean = child.nullable def dataType: DataType = StringType @@ -142,6 +143,8 @@ case class RLike(left: Expression, right: Expression) case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { override def convert(v: String): String = v.toUpperCase() + + override def toString() = s"Upper($child)" } /** @@ -150,4 +153,6 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { override def convert(v: String): String = v.toLowerCase() + + override def toString() = s"Lower($child)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 28d1aa2e3a..25a347bec0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -36,7 +36,8 @@ object Optimizer extends RuleExecutor[LogicalPlan] { ConstantFolding, BooleanSimplification, SimplifyFilters, - SimplifyCasts) :: + SimplifyCasts, + SimplifyCaseConversionExpressions) :: Batch("Filter Pushdown", FixedPoint(100), CombineFilters, PushPredicateThroughProject, @@ -132,18 +133,6 @@ object NullPropagation extends Rule[LogicalPlan] { case Literal(candidate, _) if candidate == v => true case _ => false })) => Literal(true, BooleanType) - case e: UnaryMinus => e.child match { - case Literal(null, _) => Literal(null, e.dataType) - case _ => e - } - case e: Cast => e.child match { - case Literal(null, _) => Literal(null, e.dataType) - case _ => e - } - case e: Not => e.child match { - case Literal(null, _) => Literal(null, e.dataType) - case _ => e - } // Put exceptional cases above if any case e: BinaryArithmetic => e.children match { case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) @@ -375,3 
+364,18 @@ object CombineLimits extends Rule[LogicalPlan] { Limit(If(LessThan(ne, le), ne, le), grandChild) } } + +/** + * Removes inner [[catalyst.expressions.CaseConversionExpression]]s that are unnecessary because + * the outer conversion overwrites their result, e.g. `Upper(Lower(x))` becomes `Upper(x)`. + */ +object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressionsUp { + case Upper(Upper(child)) => Upper(child) + case Upper(Lower(child)) => Upper(child) + case Lower(Upper(child)) => Lower(child) + case Lower(Lower(child)) => Lower(child) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 20dfba8477..6efc0e211e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.types.{DoubleType, IntegerType} +import org.apache.spark.sql.catalyst.types._ // For implicit conversions import org.apache.spark.sql.catalyst.dsl.plans._ @@ -173,4 +173,63 @@ class ConstantFoldingSuite extends OptimizerTest { comparePlans(optimized, correctAnswer) } + + test("Constant folding test: expressions have null literals") { + val originalQuery = + testRelation + .select( + IsNull(Literal(null)) as 'c1, + IsNotNull(Literal(null)) as 'c2, + + GetItem(Literal(null, ArrayType(IntegerType)), 1) as 'c3, + GetItem(Literal(Seq(1), ArrayType(IntegerType)), Literal(null, 
IntegerType)) as 'c4, + GetField( + Literal(null, StructType(Seq(StructField("a", IntegerType, true)))), + "a") as 'c5, + + UnaryMinus(Literal(null, IntegerType)) as 'c6, + Cast(Literal(null), IntegerType) as 'c7, + Not(Literal(null, BooleanType)) as 'c8, + + Add(Literal(null, IntegerType), 1) as 'c9, + Add(1, Literal(null, IntegerType)) as 'c10, + + Equals(Literal(null, IntegerType), 1) as 'c11, + Equals(1, Literal(null, IntegerType)) as 'c12, + + Like(Literal(null, StringType), "abc") as 'c13, + Like("abc", Literal(null, StringType)) as 'c14, + + Upper(Literal(null, StringType)) as 'c15) + + val optimized = Optimize(originalQuery.analyze) + + val correctAnswer = + testRelation + .select( + Literal(true) as 'c1, + Literal(false) as 'c2, + + Literal(null, IntegerType) as 'c3, + Literal(null, IntegerType) as 'c4, + Literal(null, IntegerType) as 'c5, + + Literal(null, IntegerType) as 'c6, + Literal(null, IntegerType) as 'c7, + Literal(null, BooleanType) as 'c8, + + Literal(null, IntegerType) as 'c9, + Literal(null, IntegerType) as 'c10, + + Literal(null, BooleanType) as 'c11, + Literal(null, BooleanType) as 'c12, + + Literal(null, BooleanType) as 'c13, + Literal(null, BooleanType) as 'c14, + + Literal(null, StringType) as 'c15) + .analyze + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala new file mode 100644 index 0000000000..df1409fe7b --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + +/* Implicit conversions */ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ + +class SimplifyCaseConversionExpressionsSuite extends OptimizerTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Simplify CaseConversionExpressions", Once, + SimplifyCaseConversionExpressions) :: Nil + } + + val testRelation = LocalRelation('a.string) + + test("simplify UPPER(UPPER(str))") { + val originalQuery = + testRelation + .select(Upper(Upper('a)) as 'u) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .select(Upper('a) as 'u) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("simplify UPPER(LOWER(str))") { + val originalQuery = + testRelation + .select(Upper(Lower('a)) as 'u) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .select(Upper('a) as 'u) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("simplify LOWER(UPPER(str))") { + val originalQuery = + testRelation + .select(Lower(Upper('a)) as 'l) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = testRelation + 
.select(Lower('a) as 'l) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("simplify LOWER(LOWER(str))") { + val originalQuery = + testRelation + .select(Lower(Lower('a)) as 'l) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = testRelation + .select(Lower('a) as 'l) + .analyze + + comparePlans(optimized, correctAnswer) + } +} -- cgit v1.2.3