aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala20
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala10
2 files changed, 27 insertions, 3 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index e7bf7cc1f1..189451d0d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -68,6 +68,19 @@ object HiveTypeCoercion {
}
/**
+ * Similar to [[findTightestCommonType]], but if no tightest common type can be
+ * found, falls back to [[findTightestCommonTypeToString]] to promote to string type.
+ */
+ private def findTightestCommonTypeAndPromoteToString(types: Seq[DataType]): Option[DataType] = {
+ types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
+ case None => None
+ case Some(d) =>
+ findTightestCommonTypeOfTwo(d, c).orElse(findTightestCommonTypeToString(d, c))
+ })
+ }
+
+
+ /**
* Find the tightest common type of a set of types by continuously applying
* `findTightestCommonTypeOfTwo` on these types.
*/
@@ -599,7 +612,7 @@ trait HiveTypeCoercion {
// compatible with every child column.
case Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
val types = es.map(_.dataType)
- findTightestCommonType(types) match {
+ findTightestCommonTypeAndPromoteToString(types) match {
case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
case None =>
sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}")
@@ -634,7 +647,7 @@ trait HiveTypeCoercion {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual =>
logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}")
- val maybeCommonType = findTightestCommonType(c.valueTypes)
+ val maybeCommonType = findTightestCommonTypeAndPromoteToString(c.valueTypes)
maybeCommonType.map { commonType =>
val castedBranches = c.branches.grouped(2).map {
case Seq(when, value) if value.dataType != commonType =>
@@ -650,7 +663,8 @@ trait HiveTypeCoercion {
}.getOrElse(c)
case c: CaseKeyWhen if c.childrenResolved && !c.resolved =>
- val maybeCommonType = findTightestCommonType((c.key +: c.whenList).map(_.dataType))
+ val maybeCommonType =
+ findTightestCommonTypeAndPromoteToString((c.key +: c.whenList).map(_.dataType))
maybeCommonType.map { commonType =>
val castedBranches = c.branches.grouped(2).map {
case Seq(when, then) if when.dataType != commonType =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index a47cc30e92..1a6ee8169c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -45,6 +45,16 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
Row("one", 6) :: Row("three", 3) :: Nil)
}
+ test("SPARK-8010: promote numeric to string") {
+ val df = Seq((1, 1)).toDF("key", "value")
+ df.registerTempTable("src")
+ val queryCaseWhen = sql("select case when true then 1.0 else '1' end from src ")
+ val queryCoalesce = sql("select coalesce(null, 1, '1') from src ")
+
+ checkAnswer(queryCaseWhen, Row("1.0") :: Nil)
+ checkAnswer(queryCoalesce, Row("1") :: Nil)
+ }
+
test("SPARK-6743: no columns from cache") {
Seq(
(83, 0, 38),