diff options
author | Wenchen Fan <wenchen@databricks.com> | 2016-08-03 11:15:09 -0700 |
---|---|---|
committer | Yin Huai <yhuai@databricks.com> | 2016-08-03 11:15:09 -0700 |
commit | b55f34370f695de355b72c1518b5f2a45c324af0 (patch) | |
tree | 37b2d3e984c8de7e1c2e1cf43d31eed716981077 /sql/catalyst/src/main/scala | |
parent | 639df046a250873c26446a037cb832ab28cb5272 (diff) | |
download | spark-b55f34370f695de355b72c1518b5f2a45c324af0.tar.gz spark-b55f34370f695de355b72c1518b5f2a45c324af0.tar.bz2 spark-b55f34370f695de355b72c1518b5f2a45c324af0.zip |
[SPARK-16714][SPARK-16735][SPARK-16646] array, map, greatest, least's type coercion should handle decimal type
## What changes were proposed in this pull request?
Here is a table about the behaviours of `array`/`map` and `greatest`/`least` in Hive, MySQL and Postgres:
| |Hive|MySQL|Postgres|
|---|---|---|---|
|`array`/`map`|can find a wider type with decimal type arguments, and will truncate the wider decimal type if necessary|can find a wider type with decimal type arguments, no truncation problem|can find a wider type with decimal type arguments, no truncation problem|
|`greatest`/`least`|can find a wider type with decimal type arguments, and truncate if necessary, but can't do string promotion|can find a wider type with decimal type arguments, no truncation problem, but can't do string promotion|can find a wider type with decimal type arguments, no truncation problem, but can't do string promotion|
I think these behaviours make sense and Spark SQL should follow them.
This PR fixes `array` and `map` by using `findWiderCommonType` to get the wider type.
This PR fixes `greatest` and `least` by adding `findWiderTypeWithoutStringPromotion`, which provides semantics similar to `findWiderCommonType`, but without string promotion.
## How was this patch tested?
new tests in `TypeCoercionSuite`
Author: Wenchen Fan <wenchen@databricks.com>
Author: Yin Huai <yhuai@databricks.com>
Closes #14439 from cloud-fan/bug.
Diffstat (limited to 'sql/catalyst/src/main/scala')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala | 47 |
1 files changed, 30 insertions, 17 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 8503b8dcf8..021952e716 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -109,18 +109,6 @@ object TypeCoercion { } /** - * Similar to [[findTightestCommonType]], if can not find the TightestCommonType, try to use - * [[findTightestCommonTypeToString]] to find the TightestCommonType. - */ - private def findTightestCommonTypeAndPromoteToString(types: Seq[DataType]): Option[DataType] = { - types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case None => None - case Some(d) => - findTightestCommonTypeToString(d, c) - }) - } - - /** * Find the tightest common type of a set of types by continuously applying * `findTightestCommonTypeOfTwo` on these types. */ @@ -157,6 +145,28 @@ object TypeCoercion { }) } + /** + * Similar to [[findWiderCommonType]], but can't promote to string. This is also similar to + * [[findTightestCommonType]], but can handle decimal types. If the wider decimal type exceeds + * system limitation, this rule will truncate the decimal type before return it. 
+ */ + private def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { + types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { + case Some(d) => findTightestCommonTypeOfTwo(d, c).orElse((d, c) match { + case (t1: DecimalType, t2: DecimalType) => + Some(DecimalPrecision.widerDecimalType(t1, t2)) + case (t: IntegralType, d: DecimalType) => + Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) + case (d: DecimalType, t: IntegralType) => + Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) + case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) => + Some(DoubleType) + case _ => None + }) + case None => None + }) + } + private def haveSameType(exprs: Seq[Expression]): Boolean = exprs.map(_.dataType).distinct.length == 1 @@ -440,7 +450,7 @@ object TypeCoercion { case a @ CreateArray(children) if !haveSameType(children) => val types = children.map(_.dataType) - findTightestCommonTypeAndPromoteToString(types) match { + findWiderCommonType(types) match { case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType))) case None => a } @@ -451,7 +461,7 @@ object TypeCoercion { m.keys } else { val types = m.keys.map(_.dataType) - findTightestCommonTypeAndPromoteToString(types) match { + findWiderCommonType(types) match { case Some(finalDataType) => m.keys.map(Cast(_, finalDataType)) case None => m.keys } @@ -461,7 +471,7 @@ object TypeCoercion { m.values } else { val types = m.values.map(_.dataType) - findTightestCommonTypeAndPromoteToString(types) match { + findWiderCommonType(types) match { case Some(finalDataType) => m.values.map(Cast(_, finalDataType)) case None => m.values } @@ -494,16 +504,19 @@ object TypeCoercion { case None => c } + // When finding wider type for `Greatest` and `Least`, we should handle decimal types even if + // we need to truncate, but we should not promote one side to string if the other side is + // string.g case g @ Greatest(children) 
if !haveSameType(children) => val types = children.map(_.dataType) - findTightestCommonType(types) match { + findWiderTypeWithoutStringPromotion(types) match { case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType))) case None => g } case l @ Least(children) if !haveSameType(children) => val types = children.map(_.dataType) - findTightestCommonType(types) match { + findWiderTypeWithoutStringPromotion(types) match { case Some(finalDataType) => Least(children.map(Cast(_, finalDataType))) case None => l } |