From 373a376c04320aab228b5c385e2b788809877d3e Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 19 Aug 2015 14:31:51 -0700 Subject: [SPARK-10083] [SQL] CaseWhen should support type coercion of DecimalType and FractionalType create t1 (a decimal(7, 2), b long); select case when 1=1 then a else 1.0 end from t1; select case when 1=1 then a else b end from t1; Author: Daoyuan Wang Closes #8270 from adrian-wang/casewhenfractional. --- .../apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 4 ++-- .../spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala | 11 +++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 62c27ee0b9..f2f2ba2f96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -605,7 +605,7 @@ object HiveTypeCoercion { def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual => logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}") - val maybeCommonType = findTightestCommonTypeAndPromoteToString(c.valueTypes) + val maybeCommonType = findWiderCommonType(c.valueTypes) maybeCommonType.map { commonType => val castedBranches = c.branches.grouped(2).map { case Seq(when, value) if value.dataType != commonType => @@ -622,7 +622,7 @@ object HiveTypeCoercion { case c: CaseKeyWhen if c.childrenResolved && !c.resolved => val maybeCommonType = - findTightestCommonTypeAndPromoteToString((c.key +: c.whenList).map(_.dataType)) + findWiderCommonType((c.key +: c.whenList).map(_.dataType)) maybeCommonType.map { commonType => val castedBranches = c.branches.grouped(2).map { case Seq(whenExpr, thenExpr) if whenExpr.dataType != commonType => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index cbdf453f60..6f33ab733b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -285,6 +285,17 @@ class HiveTypeCoercionSuite extends PlanTest { CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) ) + ruleTest(HiveTypeCoercion.CaseWhenCoercion, + CaseWhen(Seq(Literal(true), Literal(1.2), Literal.create(1, DecimalType(7, 2)))), + CaseWhen(Seq( + Literal(true), Literal(1.2), Cast(Literal.create(1, DecimalType(7, 2)), DoubleType))) + ) + ruleTest(HiveTypeCoercion.CaseWhenCoercion, + CaseWhen(Seq(Literal(true), Literal(100L), Literal.create(1, DecimalType(7, 2)))), + CaseWhen(Seq( + Literal(true), Cast(Literal(100L), DecimalType(22, 2)), + Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2)))) + ) } test("type coercion simplification for equal to") { -- cgit v1.2.3