aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorWenchen Fan <cloud0fan@outlook.com>2015-05-31 21:01:46 -0700
committerReynold Xin <rxin@databricks.com>2015-05-31 21:01:46 -0700
commita0e46a0d2ad23ce6a64e6ebdf2ccc776208696b6 (patch)
tree33979fe55fc3eaf9e172bdf488e911f8ef3a2dff /sql
parent91777a1c3ad3b3ec7b65d5a0413209a9baf6b36a (diff)
downloadspark-a0e46a0d2ad23ce6a64e6ebdf2ccc776208696b6.tar.gz
spark-a0e46a0d2ad23ce6a64e6ebdf2ccc776208696b6.tar.bz2
spark-a0e46a0d2ad23ce6a64e6ebdf2ccc776208696b6.zip
[SPARK-7952][SPARK-7984][SQL] equality check between boolean type and numeric type is broken.
The origin code has several problems: * `true <=> 1` will return false as we didn't set a rule to handle it. * `true = a` where `a` is not `Literal` and its value is 1, will return false as we only handle literal values. Author: Wenchen Fan <cloud0fan@outlook.com> Closes #6505 from cloud-fan/tmp1 and squashes the following commits: 77f0f39 [Wenchen Fan] minor fix b6401ba [Wenchen Fan] add type coercion for CaseKeyWhen and address comments ebc8c61 [Wenchen Fan] use SQLTestUtils and If 625973c [Wenchen Fan] improve 9ba2130 [Wenchen Fan] address comments fc0d741 [Wenchen Fan] fix style 2846a04 [Wenchen Fan] fix 7952
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala101
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala5
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala55
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala8
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala36
5 files changed, 158 insertions, 47 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 96d7b96e60..edcc918bfe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -76,7 +76,7 @@ trait HiveTypeCoercion {
WidenTypes ::
PromoteStrings ::
DecimalPrecision ::
- BooleanComparisons ::
+ BooleanEqualization ::
StringToIntegralCasts ::
FunctionArgumentConversion ::
CaseWhenCoercion ::
@@ -119,7 +119,7 @@ trait HiveTypeCoercion {
* the appropriate numeric equivalent.
*/
object ConvertNaNs extends Rule[LogicalPlan] {
- val stringNaN = Literal("NaN")
+ private val stringNaN = Literal("NaN")
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressions {
@@ -349,17 +349,17 @@ trait HiveTypeCoercion {
import scala.math.{max, min}
// Conversion rules for integer types into fixed-precision decimals
- val intTypeToFixed: Map[DataType, DecimalType] = Map(
+ private val intTypeToFixed: Map[DataType, DecimalType] = Map(
ByteType -> DecimalType(3, 0),
ShortType -> DecimalType(5, 0),
IntegerType -> DecimalType(10, 0),
LongType -> DecimalType(20, 0)
)
- def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
+ private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
// Conversion rules for float and double into fixed-precision decimals
- val floatTypeToFixed: Map[DataType, DecimalType] = Map(
+ private val floatTypeToFixed: Map[DataType, DecimalType] = Map(
FloatType -> DecimalType(7, 7),
DoubleType -> DecimalType(15, 15)
)
@@ -482,30 +482,66 @@ trait HiveTypeCoercion {
}
/**
- * Changes Boolean values to Bytes so that expressions like true < false can be Evaluated.
+ * Changes numeric values to booleans so that expressions like true = 1 can be evaluated.
*/
- object BooleanComparisons extends Rule[LogicalPlan] {
- val trueValues = Seq(1, 1L, 1.toByte, 1.toShort, new java.math.BigDecimal(1)).map(Literal(_))
- val falseValues = Seq(0, 0L, 0.toByte, 0.toShort, new java.math.BigDecimal(0)).map(Literal(_))
+ object BooleanEqualization extends Rule[LogicalPlan] {
+ private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, new java.math.BigDecimal(1))
+ private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, new java.math.BigDecimal(0))
+
+ private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = {
+ CaseKeyWhen(numericExpr, Seq(
+ Literal(trueValues.head), booleanExpr,
+ Literal(falseValues.head), Not(booleanExpr),
+ Literal(false)))
+ }
+
+ private def transform(booleanExpr: Expression, numericExpr: Expression) = {
+ If(Or(IsNull(booleanExpr), IsNull(numericExpr)),
+ Literal.create(null, BooleanType),
+ buildCaseKeyWhen(booleanExpr, numericExpr))
+ }
+
+ private def transformNullSafe(booleanExpr: Expression, numericExpr: Expression) = {
+ CaseWhen(Seq(
+ And(IsNull(booleanExpr), IsNull(numericExpr)), Literal(true),
+ Or(IsNull(booleanExpr), IsNull(numericExpr)), Literal(false),
+ buildCaseKeyWhen(booleanExpr, numericExpr)
+ ))
+ }
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
- // Hive treats (true = 1) as true and (false = 0) as true.
- case EqualTo(l @ BooleanType(), r) if trueValues.contains(r) => l
- case EqualTo(l, r @ BooleanType()) if trueValues.contains(l) => r
- case EqualTo(l @ BooleanType(), r) if falseValues.contains(r) => Not(l)
- case EqualTo(l, r @ BooleanType()) if falseValues.contains(l) => Not(r)
-
- // No need to change other EqualTo operators as that actually makes sense for boolean types.
- case e: EqualTo => e
- // No need to change the EqualNullSafe operators, too
- case e: EqualNullSafe => e
- // Otherwise turn them to Byte types so that there exists and ordering.
- case p: BinaryComparison if p.left.dataType == BooleanType &&
- p.right.dataType == BooleanType =>
- p.makeCopy(Array(Cast(p.left, ByteType), Cast(p.right, ByteType)))
+ // Hive treats (true = 1) as true and (false = 0) as true,
+ // all other cases are considered as false.
+
+ // We may simplify the expression if one side is literal numeric values
+ case EqualTo(l @ BooleanType(), Literal(value, _: NumericType))
+ if trueValues.contains(value) => l
+ case EqualTo(l @ BooleanType(), Literal(value, _: NumericType))
+ if falseValues.contains(value) => Not(l)
+ case EqualTo(Literal(value, _: NumericType), r @ BooleanType())
+ if trueValues.contains(value) => r
+ case EqualTo(Literal(value, _: NumericType), r @ BooleanType())
+ if falseValues.contains(value) => Not(r)
+ case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType))
+ if trueValues.contains(value) => And(IsNotNull(l), l)
+ case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType))
+ if falseValues.contains(value) => And(IsNotNull(l), Not(l))
+ case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType())
+ if trueValues.contains(value) => And(IsNotNull(r), r)
+ case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType())
+ if falseValues.contains(value) => And(IsNotNull(r), Not(r))
+
+ case EqualTo(l @ BooleanType(), r @ NumericType()) =>
+ transform(l , r)
+ case EqualTo(l @ NumericType(), r @ BooleanType()) =>
+ transform(r, l)
+ case EqualNullSafe(l @ BooleanType(), r @ NumericType()) =>
+ transformNullSafe(l, r)
+ case EqualNullSafe(l @ NumericType(), r @ BooleanType()) =>
+ transformNullSafe(r, l)
}
}
@@ -606,7 +642,7 @@ trait HiveTypeCoercion {
import HiveTypeCoercion._
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case cw: CaseWhenLike if !cw.resolved && cw.childrenResolved && !cw.valueTypesEqual =>
+ case cw: CaseWhenLike if cw.childrenResolved && !cw.valueTypesEqual =>
logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}")
val commonType = cw.valueTypes.reduce { (v1, v2) =>
findTightestCommonType(v1, v2).getOrElse(sys.error(
@@ -625,6 +661,23 @@ trait HiveTypeCoercion {
case CaseKeyWhen(key, _) =>
CaseKeyWhen(key, transformedBranches)
}
+
+ case ckw: CaseKeyWhen if ckw.childrenResolved && !ckw.resolved =>
+ val commonType = (ckw.key +: ckw.whenList).map(_.dataType).reduce { (v1, v2) =>
+ findTightestCommonType(v1, v2).getOrElse(sys.error(
+ s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
+ }
+ val transformedBranches = ckw.branches.sliding(2, 2).map {
+ case Seq(when, then) if when.dataType != commonType =>
+ Seq(Cast(when, commonType), then)
+ case s => s
+ }.reduce(_ ++ _)
+ val transformedKey = if (ckw.key.dataType != commonType) {
+ Cast(ckw.key, commonType)
+ } else {
+ ckw.key
+ }
+ CaseKeyWhen(transformedKey, transformedBranches)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index e2d1c8115e..4f422d69c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -366,7 +366,7 @@ trait CaseWhenLike extends Expression {
// both then and else val should be considered.
def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType)
- def valueTypesEqual: Boolean = valueTypes.distinct.size <= 1
+ def valueTypesEqual: Boolean = valueTypes.distinct.size == 1
override def dataType: DataType = {
if (!resolved) {
@@ -442,7 +442,8 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
override def children: Seq[Expression] = key +: branches
override lazy val resolved: Boolean =
- childrenResolved && valueTypesEqual
+ childrenResolved && valueTypesEqual &&
+ (key +: whenList).map(_.dataType).distinct.size == 1
/** Written in imperative fashion for performance considerations. */
override def eval(input: Row): Any = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index f0101f4a88..a0798428db 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LocalRelation, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types._
class HiveTypeCoercionSuite extends PlanTest {
@@ -104,15 +105,16 @@ class HiveTypeCoercionSuite extends PlanTest {
widenTest(ArrayType(IntegerType), StructType(Seq()), None)
}
+ private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) {
+ val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
+ comparePlans(
+ rule(Project(Seq(Alias(initial, "a")()), testRelation)),
+ Project(Seq(Alias(transformed, "a")()), testRelation))
+ }
+
test("coalesce casts") {
val fac = new HiveTypeCoercion { }.FunctionArgumentConversion
- def ruleTest(initial: Expression, transformed: Expression) {
- val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
- comparePlans(
- fac(Project(Seq(Alias(initial, "a")()), testRelation)),
- Project(Seq(Alias(transformed, "a")()), testRelation))
- }
- ruleTest(
+ ruleTest(fac,
Coalesce(Literal(1.0)
:: Literal(1)
:: Literal.create(1.0, FloatType)
@@ -121,7 +123,7 @@ class HiveTypeCoercionSuite extends PlanTest {
:: Cast(Literal(1), DoubleType)
:: Cast(Literal.create(1.0, FloatType), DoubleType)
:: Nil))
- ruleTest(
+ ruleTest(fac,
Coalesce(Literal(1L)
:: Literal(1)
:: Literal(new java.math.BigDecimal("1000000000000000000000"))
@@ -131,4 +133,39 @@ class HiveTypeCoercionSuite extends PlanTest {
:: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType())
:: Nil))
}
+
+ test("type coercion for CaseKeyWhen") {
+ val cwc = new HiveTypeCoercion {}.CaseWhenCoercion
+ ruleTest(cwc,
+ CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))),
+ CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a")))
+ )
+ // Will remove exception expectation in PR#6405
+ intercept[RuntimeException] {
+ ruleTest(cwc,
+ CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))),
+ CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a")))
+ )
+ }
+ }
+
+ test("type coercion simplification for equal to") {
+ val be = new HiveTypeCoercion {}.BooleanEqualization
+ ruleTest(be,
+ EqualTo(Literal(true), Literal(1)),
+ Literal(true)
+ )
+ ruleTest(be,
+ EqualTo(Literal(true), Literal(0)),
+ Not(Literal(true))
+ )
+ ruleTest(be,
+ EqualNullSafe(Literal(true), Literal(1)),
+ And(IsNotNull(Literal(true)), Literal(true))
+ )
+ ruleTest(be,
+ EqualNullSafe(Literal(true), Literal(0)),
+ And(IsNotNull(Literal(true)), Not(Literal(true)))
+ )
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 3f5a660f17..b6927485f4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -862,7 +862,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
val c5 = 'a.string.at(4)
val c6 = 'a.string.at(5)
- val literalNull = Literal.create(null, BooleanType)
+ val literalNull = Literal.create(null, IntegerType)
val literalInt = Literal(1)
val literalString = Literal("a")
@@ -871,12 +871,12 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
checkEvaluation(CaseKeyWhen(c2, Seq(literalInt, c4, c5)), "a", row)
checkEvaluation(CaseKeyWhen(c2, Seq(c1, c4, c5)), "b", row)
checkEvaluation(CaseKeyWhen(c4, Seq(literalString, c2, c3)), 1, row)
- checkEvaluation(CaseKeyWhen(c4, Seq(c1, c3, c5, c2, Literal(3))), 3, row)
+ checkEvaluation(CaseKeyWhen(c4, Seq(c6, c3, c5, c2, Literal(3))), 3, row)
checkEvaluation(CaseKeyWhen(literalInt, Seq(c2, c4, c5)), "a", row)
checkEvaluation(CaseKeyWhen(literalString, Seq(c5, c2, c4, c3)), 2, row)
- checkEvaluation(CaseKeyWhen(literalInt, Seq(c5, c2, c4, c3)), null, row)
- checkEvaluation(CaseKeyWhen(literalNull, Seq(c5, c2, c1, c3)), 2, row)
+ checkEvaluation(CaseKeyWhen(c6, Seq(c5, c2, c4, c3)), null, row)
+ checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), "c", row)
}
test("complex type") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index bf18bf854a..63f7d314fb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.execution.GeneratedAggregate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.TestData._
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
import org.apache.spark.sql.test.TestSQLContext.{udf => _, _}
import org.apache.spark.sql.types._
@@ -32,12 +32,12 @@ import org.apache.spark.sql.types._
/** A SQL Dialect for testing purpose, and it can not be nested type */
class MyDialect extends DefaultParserDialect
-class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
+class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
// Make sure the tables are loaded.
TestData
- import org.apache.spark.sql.test.TestSQLContext.implicits._
- val sqlCtx = TestSQLContext
+ val sqlContext = TestSQLContext
+ import sqlContext.implicits._
test("SPARK-6743: no columns from cache") {
Seq(
@@ -915,7 +915,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(values(0).toInt, values(1), values(2).toBoolean, v4)
}
- val df1 = sqlCtx.createDataFrame(rowRDD1, schema1)
+ val df1 = createDataFrame(rowRDD1, schema1)
df1.registerTempTable("applySchema1")
checkAnswer(
sql("SELECT * FROM applySchema1"),
@@ -945,7 +945,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
}
- val df2 = sqlCtx.createDataFrame(rowRDD2, schema2)
+ val df2 = createDataFrame(rowRDD2, schema2)
df2.registerTempTable("applySchema2")
checkAnswer(
sql("SELECT * FROM applySchema2"),
@@ -970,7 +970,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4))
}
- val df3 = sqlCtx.createDataFrame(rowRDD3, schema2)
+ val df3 = createDataFrame(rowRDD3, schema2)
df3.registerTempTable("applySchema3")
checkAnswer(
@@ -1015,7 +1015,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
.build()
val schemaWithMeta = new StructType(Array(
schema("id"), schema("name").copy(metadata = metadata), schema("age")))
- val personWithMeta = sqlCtx.createDataFrame(person.rdd, schemaWithMeta)
+ val personWithMeta = createDataFrame(person.rdd, schemaWithMeta)
def validateMetadata(rdd: DataFrame): Unit = {
assert(rdd.schema("name").metadata.getString(docKey) == docValue)
}
@@ -1331,4 +1331,24 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1))
}
+
+ test("SPARK-7952: fix the equality check between boolean and numeric types") {
+ withTempTable("t") {
+ // numeric field i, boolean field j, result of i = j, result of i <=> j
+ Seq[(Integer, java.lang.Boolean, java.lang.Boolean, java.lang.Boolean)](
+ (1, true, true, true),
+ (0, false, true, true),
+ (2, true, false, false),
+ (2, false, false, false),
+ (null, true, null, false),
+ (null, false, null, false),
+ (0, null, null, false),
+ (1, null, null, false),
+ (null, null, null, true)
+ ).toDF("i", "b", "r1", "r2").registerTempTable("t")
+
+ checkAnswer(sql("select i = b from t"), sql("select r1 from t"))
+ checkAnswer(sql("select i <=> b from t"), sql("select r2 from t"))
+ }
+ }
}