aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
author Michael Armbrust <michael@databricks.com> 2014-08-05 11:17:50 -0700
committer Michael Armbrust <michael@databricks.com> 2014-08-05 11:17:50 -0700
commit 6e821e3d1ae1ed23459bc7f1098510b968130152 (patch)
tree 288ef784edd88ad66a5536ef272aafcdb1593f1e /sql/catalyst
parent 1c5555a23d3aa40423d658cfbf2c956ad415a6b1 (diff)
download spark-6e821e3d1ae1ed23459bc7f1098510b968130152.tar.gz
spark-6e821e3d1ae1ed23459bc7f1098510b968130152.tar.bz2
spark-6e821e3d1ae1ed23459bc7f1098510b968130152.zip
[SPARK-2860][SQL] Fix coercion of CASE WHEN.
Author: Michael Armbrust <michael@databricks.com>

Closes #1785 from marmbrus/caseNull and squashes the following commits:

126006d [Michael Armbrust] better error message
2fe357f [Michael Armbrust] Fix coercion of CASE WHEN.
Diffstat (limited to 'sql/catalyst')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala 56
1 files changed, 32 insertions, 24 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index e94f2a3bea..15eb5982a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -49,10 +49,21 @@ trait HiveTypeCoercion {
BooleanCasts ::
StringToIntegralCasts ::
FunctionArgumentConversion ::
- CastNulls ::
+ CaseWhenCoercion ::
Division ::
Nil
+ trait TypeWidening {
+ def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
+ // Try and find a promotion rule that contains both types in question.
+ val applicableConversion =
+ HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))
+
+ // If found return the widest common type, otherwise None
+ applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
+ }
+ }
+
/**
* Applies any changes to [[AttributeReference]] data types that are made by other rules to
* instances higher in the query tree.
@@ -133,16 +144,7 @@ trait HiveTypeCoercion {
* - LongType to FloatType
* - LongType to DoubleType
*/
- object WidenTypes extends Rule[LogicalPlan] {
-
- def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
- // Try and find a promotion rule that contains both types in question.
- val applicableConversion =
- HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))
-
- // If found return the widest common type, otherwise None
- applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
- }
+ object WidenTypes extends Rule[LogicalPlan] with TypeWidening {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
@@ -336,28 +338,34 @@ trait HiveTypeCoercion {
}
/**
- * Ensures that NullType gets casted to some other types under certain circumstances.
+ * Coerces the type of different branches of a CASE WHEN statement to a common type.
*/
- object CastNulls extends Rule[LogicalPlan] {
+ object CaseWhenCoercion extends Rule[LogicalPlan] with TypeWidening {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case cw @ CaseWhen(branches) =>
+ case cw @ CaseWhen(branches) if !cw.resolved && !branches.exists(!_.resolved) =>
val valueTypes = branches.sliding(2, 2).map {
- case Seq(_, value) if value.resolved => Some(value.dataType)
- case Seq(elseVal) if elseVal.resolved => Some(elseVal.dataType)
- case _ => None
+ case Seq(_, value) => value.dataType
+ case Seq(elseVal) => elseVal.dataType
}.toSeq
- if (valueTypes.distinct.size == 2 && valueTypes.exists(_ == Some(NullType))) {
- val otherType = valueTypes.filterNot(_ == Some(NullType))(0).get
+
+ logDebug(s"Input values for null casting ${valueTypes.mkString(",")}")
+
+ if (valueTypes.distinct.size > 1) {
+ val commonType = valueTypes.reduce { (v1, v2) =>
+ findTightestCommonType(v1, v2)
+ .getOrElse(sys.error(
+ s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
+ }
val transformedBranches = branches.sliding(2, 2).map {
- case Seq(cond, value) if value.resolved && value.dataType == NullType =>
- Seq(cond, Cast(value, otherType))
- case Seq(elseVal) if elseVal.resolved && elseVal.dataType == NullType =>
- Seq(Cast(elseVal, otherType))
+ case Seq(cond, value) if value.dataType != commonType =>
+ Seq(cond, Cast(value, commonType))
+ case Seq(elseVal) if elseVal.dataType != commonType =>
+ Seq(Cast(elseVal, commonType))
case s => s
}.reduce(_ ++ _)
CaseWhen(transformedBranches)
} else {
- // It is possible to have more types due to the possibility of short-circuiting.
+ // Types match up. Hopefully some other rule fixes whatever is wrong with resolution.
cw
}
}