From 6e821e3d1ae1ed23459bc7f1098510b968130152 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 5 Aug 2014 11:17:50 -0700 Subject: [SPARK-2860][SQL] Fix coercion of CASE WHEN. Author: Michael Armbrust Closes #1785 from marmbrus/caseNull and squashes the following commits: 126006d [Michael Armbrust] better error message 2fe357f [Michael Armbrust] Fix coercion of CASE WHEN. --- .../sql/catalyst/analysis/HiveTypeCoercion.scala | 56 ++++++++++++---------- 1 file changed, 32 insertions(+), 24 deletions(-) (limited to 'sql/catalyst/src') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index e94f2a3bea..15eb5982a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -49,10 +49,21 @@ trait HiveTypeCoercion { BooleanCasts :: StringToIntegralCasts :: FunctionArgumentConversion :: - CastNulls :: + CaseWhenCoercion :: Division :: Nil + trait TypeWidening { + def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = { + // Try and find a promotion rule that contains both types in question. + val applicableConversion = + HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2)) + + // If found return the widest common type, otherwise None + applicableConversion.map(_.filter(t => t == t1 || t == t2).last) + } + } + /** * Applies any changes to [[AttributeReference]] data types that are made by other rules to * instances higher in the query tree. @@ -133,16 +144,7 @@ trait HiveTypeCoercion { * - LongType to FloatType * - LongType to DoubleType */ - object WidenTypes extends Rule[LogicalPlan] { - - def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = { - // Try and find a promotion rule that contains both types in question. - val applicableConversion = - HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2)) - - // If found return the widest common type, otherwise None - applicableConversion.map(_.filter(t => t == t1 || t == t2).last) - } + object WidenTypes extends Rule[LogicalPlan] with TypeWidening { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case u @ Union(left, right) if u.childrenResolved && !u.resolved => @@ -336,28 +338,34 @@ trait HiveTypeCoercion { } /** - * Ensures that NullType gets casted to some other types under certain circumstances. + * Coerces the type of different branches of a CASE WHEN statement to a common type. */ - object CastNulls extends Rule[LogicalPlan] { + object CaseWhenCoercion extends Rule[LogicalPlan] with TypeWidening { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case cw @ CaseWhen(branches) => + case cw @ CaseWhen(branches) if !cw.resolved && !branches.exists(!_.resolved) => val valueTypes = branches.sliding(2, 2).map { - case Seq(_, value) if value.resolved => Some(value.dataType) - case Seq(elseVal) if elseVal.resolved => Some(elseVal.dataType) - case _ => None + case Seq(_, value) => value.dataType + case Seq(elseVal) => elseVal.dataType }.toSeq - if (valueTypes.distinct.size == 2 && valueTypes.exists(_ == Some(NullType))) { - val otherType = valueTypes.filterNot(_ == Some(NullType))(0).get + + logDebug(s"Input values for null casting ${valueTypes.mkString(",")}") + + if (valueTypes.distinct.size > 1) { + val commonType = valueTypes.reduce { (v1, v2) => + findTightestCommonType(v1, v2) + .getOrElse(sys.error( + s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2")) + } val transformedBranches = branches.sliding(2, 2).map { - case Seq(cond, value) if value.resolved && value.dataType == NullType => - Seq(cond, Cast(value, otherType)) - case Seq(elseVal) if elseVal.resolved && elseVal.dataType == NullType => - Seq(Cast(elseVal, otherType)) + case Seq(cond, value) if value.dataType != commonType => + Seq(cond, Cast(value, commonType)) + case Seq(elseVal) if elseVal.dataType != commonType => + Seq(Cast(elseVal, commonType)) case s => s }.reduce(_ ++ _) CaseWhen(transformedBranches) } else { - // It is possible to have more types due to the possibility of short-circuiting. + // Types match up. Hopefully some other rule fixes whatever is wrong with resolution. cw } } -- cgit v1.2.3