diff options
author | hyukjinkwon <gurwls223@gmail.com> | 2017-02-13 13:10:57 -0800 |
---|---|---|
committer | Xiao Li <gatorsmile@gmail.com> | 2017-02-13 13:10:57 -0800 |
commit | 9af8f743b00001f9fdf8813481464c3837331ad9 (patch) | |
tree | c7957cdc679e748eeed6c2e1e9052675a1c7ffc7 /sql | |
parent | 905fdf0c243e1776c54c01a25b17878361400225 (diff) | |
download | spark-9af8f743b00001f9fdf8813481464c3837331ad9.tar.gz spark-9af8f743b00001f9fdf8813481464c3837331ad9.tar.bz2 spark-9af8f743b00001f9fdf8813481464c3837331ad9.zip |
[SPARK-19435][SQL] Type coercion between ArrayTypes
## What changes were proposed in this pull request?
This PR proposes to support type coercion between `ArrayType`s where the element types are compatible.
**Before**
```
Seq(Array(1)).toDF("a").selectExpr("greatest(a, array(1D))")
org.apache.spark.sql.AnalysisException: cannot resolve 'greatest(`a`, array(1.0D))' due to data type mismatch: The expressions should all have the same type, got GREATEST(array<int>, array<double>).; line 1 pos 0;
Seq(Array(1)).toDF("a").selectExpr("least(a, array(1D))")
org.apache.spark.sql.AnalysisException: cannot resolve 'least(`a`, array(1.0D))' due to data type mismatch: The expressions should all have the same type, got LEAST(array<int>, array<double>).; line 1 pos 0;
sql("SELECT * FROM values (array(0)), (array(1D)) as data(a)")
org.apache.spark.sql.AnalysisException: incompatible types found in column a for inline table; line 1 pos 14
Seq(Array(1)).toDF("a").union(Seq(Array(1D)).toDF("b"))
org.apache.spark.sql.AnalysisException: Union can only be performed on tables with the compatible column types. ArrayType(DoubleType,false) <> ArrayType(IntegerType,false) at the first column of the second table;;
sql("SELECT IF(1=1, array(1), array(1D))")
org.apache.spark.sql.AnalysisException: cannot resolve '(IF((1 = 1), array(1), array(1.0D)))' due to data type mismatch: differing types in '(IF((1 = 1), array(1), array(1.0D)))' (array<int> and array<double>).; line 1 pos 7;
```
**After**
```scala
Seq(Array(1)).toDF("a").selectExpr("greatest(a, array(1D))")
res5: org.apache.spark.sql.DataFrame = [greatest(a, array(1.0)): array<double>]
Seq(Array(1)).toDF("a").selectExpr("least(a, array(1D))")
res6: org.apache.spark.sql.DataFrame = [least(a, array(1.0)): array<double>]
sql("SELECT * FROM values (array(0)), (array(1D)) as data(a)")
res8: org.apache.spark.sql.DataFrame = [a: array<double>]
Seq(Array(1)).toDF("a").union(Seq(Array(1D)).toDF("b"))
res10: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [a: array<double>]
sql("SELECT IF(1=1, array(1), array(1D))")
res15: org.apache.spark.sql.DataFrame = [(IF((1 = 1), array(1), array(1.0))): array<double>]
```
## How was this patch tested?
Unit tests in `TypeCoercionSuite`, Jenkins tests, and
building with Scala 2.10:
```scala
./dev/change-scala-version.sh 2.10
./build/mvn -Pyarn -Phadoop-2.4 -Dscala-2.10 -DskipTests clean package
```
Author: hyukjinkwon <gurwls223@gmail.com>
Closes #16777 from HyukjinKwon/SPARK-19435.
Diffstat (limited to 'sql')
2 files changed, 120 insertions, 43 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index dfaac92e04..2c00957bd6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -101,13 +101,11 @@ object TypeCoercion { case _ => None } - /** Similar to [[findTightestCommonType]], but can promote all the way to StringType. */ - def findTightestCommonTypeToString(left: DataType, right: DataType): Option[DataType] = { - findTightestCommonType(left, right).orElse((left, right) match { - case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType) - case (t1: AtomicType, StringType) if t1 != BinaryType && t1 != BooleanType => Some(StringType) - case _ => None - }) + /** Promotes all the way to StringType. */ + private def stringPromotion(dt1: DataType, dt2: DataType): Option[DataType] = (dt1, dt2) match { + case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType) + case (t1: AtomicType, StringType) if t1 != BinaryType && t1 != BooleanType => Some(StringType) + case _ => None } /** @@ -117,21 +115,17 @@ object TypeCoercion { * loss of precision when widening decimal and double, and promotion to string. 
*/ private[analysis] def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = { - (t1, t2) match { - case (t1: DecimalType, t2: DecimalType) => - Some(DecimalPrecision.widerDecimalType(t1, t2)) - case (t: IntegralType, d: DecimalType) => - Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) - case (d: DecimalType, t: IntegralType) => - Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) - case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) => - Some(DoubleType) - case _ => - findTightestCommonTypeToString(t1, t2) - } + findTightestCommonType(t1, t2) + .orElse(findWiderTypeForDecimal(t1, t2)) + .orElse(stringPromotion(t1, t2)) + .orElse((t1, t2) match { + case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => + findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) + case _ => None + }) } - private def findWiderCommonType(types: Seq[DataType]) = { + private def findWiderCommonType(types: Seq[DataType]): Option[DataType] = { types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { case Some(d) => findWiderTypeForTwo(d, c) case None => None @@ -139,27 +133,49 @@ object TypeCoercion { } /** - * Similar to [[findWiderCommonType]] that can handle decimal types, but can't promote to + * Similar to [[findWiderTypeForTwo]] that can handle decimal types, but can't promote to * string. If the wider decimal type exceeds system limitation, this rule will truncate * the decimal type before return it. 
*/ - def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { - types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case Some(d) => findTightestCommonType(d, c).orElse((d, c) match { - case (t1: DecimalType, t2: DecimalType) => - Some(DecimalPrecision.widerDecimalType(t1, t2)) - case (t: IntegralType, d: DecimalType) => - Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) - case (d: DecimalType, t: IntegralType) => - Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) - case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) => - Some(DoubleType) + private[analysis] def findWiderTypeWithoutStringPromotionForTwo( + t1: DataType, + t2: DataType): Option[DataType] = { + findTightestCommonType(t1, t2) + .orElse(findWiderTypeForDecimal(t1, t2)) + .orElse((t1, t2) match { + case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => + findWiderTypeWithoutStringPromotionForTwo(et1, et2) + .map(ArrayType(_, containsNull1 || containsNull2)) case _ => None }) + } + + def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { + types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { + case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c) case None => None }) } + /** + * Finds a wider type when one or both types are decimals. If the wider decimal type exceeds + * system limitation, this rule will truncate the decimal type. If a decimal and other fractional + * types are compared, returns a double type. 
+ */ + private def findWiderTypeForDecimal(dt1: DataType, dt2: DataType): Option[DataType] = { + (dt1, dt2) match { + case (t1: DecimalType, t2: DecimalType) => + Some(DecimalPrecision.widerDecimalType(t1, t2)) + case (t: IntegralType, d: DecimalType) => + Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) + case (d: DecimalType, t: IntegralType) => + Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) + case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) => + Some(DoubleType) + case _ => None + } + } + private def haveSameType(exprs: Seq[Expression]): Boolean = exprs.map(_.dataType).distinct.length == 1 diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index ceb5b53e08..3e0c357b6d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -53,7 +53,8 @@ class TypeCoercionSuite extends PlanTest { // | NullType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType(38, 18) | DoubleType | IntegerType | // | CalendarIntervalType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | CalendarIntervalType | X | X | X | // +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+ - // Note: ArrayType*, MapType*, StructType* are castable only when the internal child types also match; otherwise, not 
castable + // Note: MapType*, StructType* are castable only when the internal child types also match; otherwise, not castable. + // Note: ArrayType* is castable when the element type is castable according to the table. // scalastyle:on line.size.limit private def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = { @@ -125,6 +126,20 @@ class TypeCoercionSuite extends PlanTest { } } + private def checkWidenType( + widenFunc: (DataType, DataType) => Option[DataType], + t1: DataType, + t2: DataType, + expected: Option[DataType]): Unit = { + var found = widenFunc(t1, t2) + assert(found == expected, + s"Expected $expected as wider common type for $t1 and $t2, found $found") + // Test both directions to make sure the widening is symmetric. + found = widenFunc(t2, t1) + assert(found == expected, + s"Expected $expected as wider common type for $t2 and $t1, found $found") + } + test("implicit type cast - ByteType") { val checkedType = ByteType checkTypeCasting(checkedType, castableTypes = numericTypes ++ Seq(StringType)) @@ -308,15 +323,8 @@ class TypeCoercionSuite extends PlanTest { } test("tightest common bound for types") { - def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) { - var found = TypeCoercion.findTightestCommonType(t1, t2) - assert(found == tightestCommon, - s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found") - // Test both directions to make sure the widening is symmetric. 
- found = TypeCoercion.findTightestCommonType(t2, t1) - assert(found == tightestCommon, - s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found") - } + def widenTest(t1: DataType, t2: DataType, expected: Option[DataType]): Unit = + checkWidenType(TypeCoercion.findTightestCommonType, t1, t2, expected) // Null widenTest(NullType, NullType, Some(NullType)) @@ -355,7 +363,6 @@ class TypeCoercionSuite extends PlanTest { widenTest(DecimalType(2, 1), DoubleType, None) widenTest(DecimalType(2, 1), IntegerType, None) widenTest(DoubleType, DecimalType(2, 1), None) - widenTest(IntegerType, DecimalType(2, 1), None) // StringType widenTest(NullType, StringType, Some(StringType)) @@ -379,6 +386,60 @@ class TypeCoercionSuite extends PlanTest { widenTest(ArrayType(IntegerType), StructType(Seq()), None) } + test("wider common type for decimal and array") { + def widenTestWithStringPromotion( + t1: DataType, + t2: DataType, + expected: Option[DataType]): Unit = { + checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected) + } + + def widenTestWithoutStringPromotion( + t1: DataType, + t2: DataType, + expected: Option[DataType]): Unit = { + checkWidenType(TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected) + } + + // Decimal + widenTestWithStringPromotion( + DecimalType(2, 1), DecimalType(3, 2), Some(DecimalType(3, 2))) + widenTestWithStringPromotion( + DecimalType(2, 1), DoubleType, Some(DoubleType)) + widenTestWithStringPromotion( + DecimalType(2, 1), IntegerType, Some(DecimalType(11, 1))) + widenTestWithStringPromotion( + DecimalType(2, 1), LongType, Some(DecimalType(21, 1))) + + // ArrayType + widenTestWithStringPromotion( + ArrayType(ShortType, containsNull = true), + ArrayType(DoubleType, containsNull = false), + Some(ArrayType(DoubleType, containsNull = true))) + widenTestWithStringPromotion( + ArrayType(TimestampType, containsNull = false), + ArrayType(StringType, containsNull = true), + Some(ArrayType(StringType, 
containsNull = true))) + widenTestWithStringPromotion( + ArrayType(ArrayType(IntegerType), containsNull = false), + ArrayType(ArrayType(LongType), containsNull = false), + Some(ArrayType(ArrayType(LongType), containsNull = false))) + + // Without string promotion + widenTestWithoutStringPromotion(IntegerType, StringType, None) + widenTestWithoutStringPromotion(StringType, TimestampType, None) + widenTestWithoutStringPromotion(ArrayType(LongType), ArrayType(StringType), None) + widenTestWithoutStringPromotion(ArrayType(StringType), ArrayType(TimestampType), None) + + // String promotion + widenTestWithStringPromotion(IntegerType, StringType, Some(StringType)) + widenTestWithStringPromotion(StringType, TimestampType, Some(StringType)) + widenTestWithStringPromotion( + ArrayType(LongType), ArrayType(StringType), Some(ArrayType(StringType))) + widenTestWithStringPromotion( + ArrayType(StringType), ArrayType(TimestampType), Some(ArrayType(StringType))) + } + private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) { ruleTest(Seq(rule), initial, transformed) } |