aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala40
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala24
2 files changed, 43 insertions, 21 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 9b8a08a88d..a42ffce0d2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -87,7 +87,7 @@ trait HiveTypeCoercion {
WidenTypes ::
PromoteStrings ::
DecimalPrecision ::
- BooleanEqualization ::
+ BooleanEquality ::
StringToIntegralCasts ::
FunctionArgumentConversion ::
CaseWhenCoercion ::
@@ -479,9 +479,9 @@ trait HiveTypeCoercion {
/**
* Changes numeric values to booleans so that expressions like true = 1 can be evaluated.
*/
- object BooleanEqualization extends Rule[LogicalPlan] {
- private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, new java.math.BigDecimal(1))
- private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, new java.math.BigDecimal(0))
+ object BooleanEquality extends Rule[LogicalPlan] {
+ private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1))
+ private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal(0))
private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = {
CaseKeyWhen(numericExpr, Seq(
@@ -512,22 +512,22 @@ trait HiveTypeCoercion {
// all other cases are considered as false.
// We may simplify the expression if one side is literal numeric values
- case EqualTo(left @ BooleanType(), Literal(value, _: NumericType))
- if trueValues.contains(value) => left
- case EqualTo(left @ BooleanType(), Literal(value, _: NumericType))
- if falseValues.contains(value) => Not(left)
- case EqualTo(Literal(value, _: NumericType), right @ BooleanType())
- if trueValues.contains(value) => right
- case EqualTo(Literal(value, _: NumericType), right @ BooleanType())
- if falseValues.contains(value) => Not(right)
- case EqualNullSafe(left @ BooleanType(), Literal(value, _: NumericType))
- if trueValues.contains(value) => And(IsNotNull(left), left)
- case EqualNullSafe(left @ BooleanType(), Literal(value, _: NumericType))
- if falseValues.contains(value) => And(IsNotNull(left), Not(left))
- case EqualNullSafe(Literal(value, _: NumericType), right @ BooleanType())
- if trueValues.contains(value) => And(IsNotNull(right), right)
- case EqualNullSafe(Literal(value, _: NumericType), right @ BooleanType())
- if falseValues.contains(value) => And(IsNotNull(right), Not(right))
+ case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType))
+ if trueValues.contains(value) => bool
+ case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType))
+ if falseValues.contains(value) => Not(bool)
+ case EqualTo(Literal(value, _: NumericType), bool @ BooleanType())
+ if trueValues.contains(value) => bool
+ case EqualTo(Literal(value, _: NumericType), bool @ BooleanType())
+ if falseValues.contains(value) => Not(bool)
+ case EqualNullSafe(bool @ BooleanType(), Literal(value, _: NumericType))
+ if trueValues.contains(value) => And(IsNotNull(bool), bool)
+ case EqualNullSafe(bool @ BooleanType(), Literal(value, _: NumericType))
+ if falseValues.contains(value) => And(IsNotNull(bool), Not(bool))
+ case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanType())
+ if trueValues.contains(value) => And(IsNotNull(bool), bool)
+ case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanType())
+ if falseValues.contains(value) => And(IsNotNull(bool), Not(bool))
case EqualTo(left @ BooleanType(), right @ NumericType()) =>
transform(left , right)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 0df446636e..9977f7af00 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -147,7 +147,8 @@ class HiveTypeCoercionSuite extends PlanTest {
}
test("type coercion simplification for equal to") {
- val be = new HiveTypeCoercion {}.BooleanEqualization
+ val be = new HiveTypeCoercion {}.BooleanEquality
+
ruleTest(be,
EqualTo(Literal(true), Literal(1)),
Literal(true)
@@ -164,5 +165,26 @@ class HiveTypeCoercionSuite extends PlanTest {
EqualNullSafe(Literal(true), Literal(0)),
And(IsNotNull(Literal(true)), Not(Literal(true)))
)
+
+ ruleTest(be,
+ EqualTo(Literal(true), Literal(1L)),
+ Literal(true)
+ )
+ ruleTest(be,
+ EqualTo(Literal(new java.math.BigDecimal(1)), Literal(true)),
+ Literal(true)
+ )
+ ruleTest(be,
+ EqualTo(Literal(BigDecimal(0)), Literal(true)),
+ Not(Literal(true))
+ )
+ ruleTest(be,
+ EqualTo(Literal(Decimal(1)), Literal(true)),
+ Literal(true)
+ )
+ ruleTest(be,
+ EqualTo(Literal.create(Decimal(1), DecimalType(8, 0)), Literal(true)),
+ Literal(true)
+ )
}
}