From b3b9ad23cffc1c6d83168487093e4c03d49e1c2c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 12 Jan 2016 18:45:55 -0800 Subject: [SPARK-12788][SQL] Simplify BooleanEquality by using casts. Author: Reynold Xin Closes #10730 from rxin/SPARK-12788. --- .../sql/catalyst/analysis/HiveTypeCoercion.scala | 30 ++++------------------ .../catalyst/analysis/HiveTypeCoercionSuite.scala | 28 +++++++++++++++++++- 2 files changed, 32 insertions(+), 26 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index e9e2067081..980b5d52fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -482,27 +482,6 @@ object HiveTypeCoercion { private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal.ONE) private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal.ZERO) - private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = { - CaseKeyWhen(numericExpr, Seq( - Literal(trueValues.head), booleanExpr, - Literal(falseValues.head), Not(booleanExpr), - Literal(false))) - } - - private def transform(booleanExpr: Expression, numericExpr: Expression) = { - If(Or(IsNull(booleanExpr), IsNull(numericExpr)), - Literal.create(null, BooleanType), - buildCaseKeyWhen(booleanExpr, numericExpr)) - } - - private def transformNullSafe(booleanExpr: Expression, numericExpr: Expression) = { - CaseWhen(Seq( - And(IsNull(booleanExpr), IsNull(numericExpr)), Literal(true), - Or(IsNull(booleanExpr), IsNull(numericExpr)), Literal(false), - buildCaseKeyWhen(booleanExpr, numericExpr) - )) - } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -511,6 +490,7 @@ object HiveTypeCoercion { // all other cases are considered as false. // We may simplify the expression if one side is literal numeric values + // TODO: Maybe these rules should go into the optimizer. case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType)) if trueValues.contains(value) => bool case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType)) @@ -529,13 +509,13 @@ object HiveTypeCoercion { if falseValues.contains(value) => And(IsNotNull(bool), Not(bool)) case EqualTo(left @ BooleanType(), right @ NumericType()) => - transform(left , right) + EqualTo(Cast(left, right.dataType), right) case EqualTo(left @ NumericType(), right @ BooleanType()) => - transform(right, left) + EqualTo(left, Cast(right, left.dataType)) case EqualNullSafe(left @ BooleanType(), right @ NumericType()) => - transformNullSafe(left, right) + EqualNullSafe(Cast(left, right.dataType), right) case EqualNullSafe(left @ NumericType(), right @ BooleanType()) => - transformNullSafe(right, left) + EqualNullSafe(left, Cast(right, left.dataType)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 23b11af9ac..40378c6727 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -320,7 +320,33 @@ class HiveTypeCoercionSuite extends PlanTest { ) } - test("type coercion simplification for equal to") { + test("BooleanEquality type cast") { + val be = HiveTypeCoercion.BooleanEquality + // Use something more than a literal to avoid triggering the simplification rules. + val one = Add(Literal(Decimal(1)), Literal(Decimal(0))) + + ruleTest(be, + EqualTo(Literal(true), one), + EqualTo(Cast(Literal(true), one.dataType), one) + ) + + ruleTest(be, + EqualTo(one, Literal(true)), + EqualTo(one, Cast(Literal(true), one.dataType)) + ) + + ruleTest(be, + EqualNullSafe(Literal(true), one), + EqualNullSafe(Cast(Literal(true), one.dataType), one) + ) + + ruleTest(be, + EqualNullSafe(one, Literal(true)), + EqualNullSafe(one, Cast(Literal(true), one.dataType)) + ) + } + + test("BooleanEquality simplification") { val be = HiveTypeCoercion.BooleanEquality ruleTest(be, -- cgit v1.2.3