aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala56
1 files changed, 32 insertions, 24 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index e94f2a3bea..15eb5982a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -49,10 +49,21 @@ trait HiveTypeCoercion {
BooleanCasts ::
StringToIntegralCasts ::
FunctionArgumentConversion ::
- CastNulls ::
+ CaseWhenCoercion ::
Division ::
Nil
+ trait TypeWidening {
+ def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
+ // Try and find a promotion rule that contains both types in question.
+ val applicableConversion =
+ HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))
+
+ // If found return the widest common type, otherwise None
+ applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
+ }
+ }
+
/**
* Applies any changes to [[AttributeReference]] data types that are made by other rules to
* instances higher in the query tree.
@@ -133,16 +144,7 @@ trait HiveTypeCoercion {
* - LongType to FloatType
* - LongType to DoubleType
*/
- object WidenTypes extends Rule[LogicalPlan] {
-
- def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
- // Try and find a promotion rule that contains both types in question.
- val applicableConversion =
- HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))
-
- // If found return the widest common type, otherwise None
- applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
- }
+ object WidenTypes extends Rule[LogicalPlan] with TypeWidening {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
@@ -336,28 +338,34 @@ trait HiveTypeCoercion {
}
/**
- * Ensures that NullType gets casted to some other types under certain circumstances.
+ * Coerces the type of different branches of a CASE WHEN statement to a common type.
*/
- object CastNulls extends Rule[LogicalPlan] {
+ object CaseWhenCoercion extends Rule[LogicalPlan] with TypeWidening {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case cw @ CaseWhen(branches) =>
+ case cw @ CaseWhen(branches) if !cw.resolved && !branches.exists(!_.resolved) =>
val valueTypes = branches.sliding(2, 2).map {
- case Seq(_, value) if value.resolved => Some(value.dataType)
- case Seq(elseVal) if elseVal.resolved => Some(elseVal.dataType)
- case _ => None
+ case Seq(_, value) => value.dataType
+ case Seq(elseVal) => elseVal.dataType
}.toSeq
- if (valueTypes.distinct.size == 2 && valueTypes.exists(_ == Some(NullType))) {
- val otherType = valueTypes.filterNot(_ == Some(NullType))(0).get
+
+ logDebug(s"Input values for null casting ${valueTypes.mkString(",")}")
+
+ if (valueTypes.distinct.size > 1) {
+ val commonType = valueTypes.reduce { (v1, v2) =>
+ findTightestCommonType(v1, v2)
+ .getOrElse(sys.error(
+ s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
+ }
val transformedBranches = branches.sliding(2, 2).map {
- case Seq(cond, value) if value.resolved && value.dataType == NullType =>
- Seq(cond, Cast(value, otherType))
- case Seq(elseVal) if elseVal.resolved && elseVal.dataType == NullType =>
- Seq(Cast(elseVal, otherType))
+ case Seq(cond, value) if value.dataType != commonType =>
+ Seq(cond, Cast(value, commonType))
+ case Seq(elseVal) if elseVal.dataType != commonType =>
+ Seq(Cast(elseVal, commonType))
case s => s
}.reduce(_ ++ _)
CaseWhen(transformedBranches)
} else {
- // It is possible to have more types due to the possibility of short-circuiting.
+ // Types match up. Hopefully some other rule fixes whatever is wrong with resolution.
cw
}
}