aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorReynold Xin <rxin@databricks.com>2016-01-12 18:45:55 -0800
committerReynold Xin <rxin@databricks.com>2016-01-12 18:45:55 -0800
commitb3b9ad23cffc1c6d83168487093e4c03d49e1c2c (patch)
tree543e72a79b853877ee1e8aa7dad02e9983f10631 /sql
parent9247084962259ebbbac4c5a80a6ccb271776f019 (diff)
downloadspark-b3b9ad23cffc1c6d83168487093e4c03d49e1c2c.tar.gz
spark-b3b9ad23cffc1c6d83168487093e4c03d49e1c2c.tar.bz2
spark-b3b9ad23cffc1c6d83168487093e4c03d49e1c2c.zip
[SPARK-12788][SQL] Simplify BooleanEquality by using casts.
Author: Reynold Xin <rxin@databricks.com> Closes #10730 from rxin/SPARK-12788.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala30
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala28
2 files changed, 32 insertions, 26 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index e9e2067081..980b5d52fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -482,27 +482,6 @@ object HiveTypeCoercion {
private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal.ONE)
private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal.ZERO)
- private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = {
- CaseKeyWhen(numericExpr, Seq(
- Literal(trueValues.head), booleanExpr,
- Literal(falseValues.head), Not(booleanExpr),
- Literal(false)))
- }
-
- private def transform(booleanExpr: Expression, numericExpr: Expression) = {
- If(Or(IsNull(booleanExpr), IsNull(numericExpr)),
- Literal.create(null, BooleanType),
- buildCaseKeyWhen(booleanExpr, numericExpr))
- }
-
- private def transformNullSafe(booleanExpr: Expression, numericExpr: Expression) = {
- CaseWhen(Seq(
- And(IsNull(booleanExpr), IsNull(numericExpr)), Literal(true),
- Or(IsNull(booleanExpr), IsNull(numericExpr)), Literal(false),
- buildCaseKeyWhen(booleanExpr, numericExpr)
- ))
- }
-
def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
@@ -511,6 +490,7 @@ object HiveTypeCoercion {
// all other cases are considered as false.
// We may simplify the expression if one side is literal numeric values
+ // TODO: Maybe these rules should go into the optimizer.
case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType))
if trueValues.contains(value) => bool
case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType))
@@ -529,13 +509,13 @@ object HiveTypeCoercion {
if falseValues.contains(value) => And(IsNotNull(bool), Not(bool))
case EqualTo(left @ BooleanType(), right @ NumericType()) =>
- transform(left , right)
+ EqualTo(Cast(left, right.dataType), right)
case EqualTo(left @ NumericType(), right @ BooleanType()) =>
- transform(right, left)
+ EqualTo(left, Cast(right, left.dataType))
case EqualNullSafe(left @ BooleanType(), right @ NumericType()) =>
- transformNullSafe(left, right)
+ EqualNullSafe(Cast(left, right.dataType), right)
case EqualNullSafe(left @ NumericType(), right @ BooleanType()) =>
- transformNullSafe(right, left)
+ EqualNullSafe(left, Cast(right, left.dataType))
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 23b11af9ac..40378c6727 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -320,7 +320,33 @@ class HiveTypeCoercionSuite extends PlanTest {
)
}
- test("type coercion simplification for equal to") {
+ test("BooleanEquality type cast") {
+ val be = HiveTypeCoercion.BooleanEquality
+ // Use something more than a literal to avoid triggering the simplification rules.
+ val one = Add(Literal(Decimal(1)), Literal(Decimal(0)))
+
+ ruleTest(be,
+ EqualTo(Literal(true), one),
+ EqualTo(Cast(Literal(true), one.dataType), one)
+ )
+
+ ruleTest(be,
+ EqualTo(one, Literal(true)),
+ EqualTo(one, Cast(Literal(true), one.dataType))
+ )
+
+ ruleTest(be,
+ EqualNullSafe(Literal(true), one),
+ EqualNullSafe(Cast(Literal(true), one.dataType), one)
+ )
+
+ ruleTest(be,
+ EqualNullSafe(one, Literal(true)),
+ EqualNullSafe(one, Cast(Literal(true), one.dataType))
+ )
+ }
+
+ test("BooleanEquality simplification") {
val be = HiveTypeCoercion.BooleanEquality
ruleTest(be,