aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorhyukjinkwon <gurwls223@gmail.com>2017-02-13 13:10:57 -0800
committerXiao Li <gatorsmile@gmail.com>2017-02-13 13:10:57 -0800
commit9af8f743b00001f9fdf8813481464c3837331ad9 (patch)
treec7957cdc679e748eeed6c2e1e9052675a1c7ffc7
parent905fdf0c243e1776c54c01a25b17878361400225 (diff)
downloadspark-9af8f743b00001f9fdf8813481464c3837331ad9.tar.gz
spark-9af8f743b00001f9fdf8813481464c3837331ad9.tar.bz2
spark-9af8f743b00001f9fdf8813481464c3837331ad9.zip
[SPARK-19435][SQL] Type coercion between ArrayTypes
## What changes were proposed in this pull request? This PR proposes to support type coercion between `ArrayType`s where the element types are compatible. **Before** ``` Seq(Array(1)).toDF("a").selectExpr("greatest(a, array(1D))") org.apache.spark.sql.AnalysisException: cannot resolve 'greatest(`a`, array(1.0D))' due to data type mismatch: The expressions should all have the same type, got GREATEST(array<int>, array<double>).; line 1 pos 0; Seq(Array(1)).toDF("a").selectExpr("least(a, array(1D))") org.apache.spark.sql.AnalysisException: cannot resolve 'least(`a`, array(1.0D))' due to data type mismatch: The expressions should all have the same type, got LEAST(array<int>, array<double>).; line 1 pos 0; sql("SELECT * FROM values (array(0)), (array(1D)) as data(a)") org.apache.spark.sql.AnalysisException: incompatible types found in column a for inline table; line 1 pos 14 Seq(Array(1)).toDF("a").union(Seq(Array(1D)).toDF("b")) org.apache.spark.sql.AnalysisException: Union can only be performed on tables with the compatible column types. ArrayType(DoubleType,false) <> ArrayType(IntegerType,false) at the first column of the second table;; sql("SELECT IF(1=1, array(1), array(1D))") org.apache.spark.sql.AnalysisException: cannot resolve '(IF((1 = 1), array(1), array(1.0D)))' due to data type mismatch: differing types in '(IF((1 = 1), array(1), array(1.0D)))' (array<int> and array<double>).; line 1 pos 7; ``` **After** ```scala Seq(Array(1)).toDF("a").selectExpr("greatest(a, array(1D))") res5: org.apache.spark.sql.DataFrame = [greatest(a, array(1.0)): array<double>] Seq(Array(1)).toDF("a").selectExpr("least(a, array(1D))") res6: org.apache.spark.sql.DataFrame = [least(a, array(1.0)): array<double>] sql("SELECT * FROM values (array(0)), (array(1D)) as data(a)") res8: org.apache.spark.sql.DataFrame = [a: array<double>] Seq(Array(1)).toDF("a").union(Seq(Array(1D)).toDF("b")) res10: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [a: array<double>] sql("SELECT IF(1=1, array(1), array(1D))") res15: org.apache.spark.sql.DataFrame = [(IF((1 = 1), array(1), array(1.0))): array<double>] ``` ## How was this patch tested? Unit tests in `TypeCoercion` and Jenkins tests and building with scala 2.10 ```scala ./dev/change-scala-version.sh 2.10 ./build/mvn -Pyarn -Phadoop-2.4 -Dscala-2.10 -DskipTests clean package ``` Author: hyukjinkwon <gurwls223@gmail.com> Closes #16777 from HyukjinKwon/SPARK-19435.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala80
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala83
2 files changed, 120 insertions, 43 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index dfaac92e04..2c00957bd6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -101,13 +101,11 @@ object TypeCoercion {
case _ => None
}
- /** Similar to [[findTightestCommonType]], but can promote all the way to StringType. */
- def findTightestCommonTypeToString(left: DataType, right: DataType): Option[DataType] = {
- findTightestCommonType(left, right).orElse((left, right) match {
- case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType)
- case (t1: AtomicType, StringType) if t1 != BinaryType && t1 != BooleanType => Some(StringType)
- case _ => None
- })
+ /** Promotes all the way to StringType. */
+ private def stringPromotion(dt1: DataType, dt2: DataType): Option[DataType] = (dt1, dt2) match {
+ case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType)
+ case (t1: AtomicType, StringType) if t1 != BinaryType && t1 != BooleanType => Some(StringType)
+ case _ => None
}
/**
@@ -117,21 +115,17 @@ object TypeCoercion {
* loss of precision when widening decimal and double, and promotion to string.
*/
private[analysis] def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = {
- (t1, t2) match {
- case (t1: DecimalType, t2: DecimalType) =>
- Some(DecimalPrecision.widerDecimalType(t1, t2))
- case (t: IntegralType, d: DecimalType) =>
- Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
- case (d: DecimalType, t: IntegralType) =>
- Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
- case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) =>
- Some(DoubleType)
- case _ =>
- findTightestCommonTypeToString(t1, t2)
- }
+ findTightestCommonType(t1, t2)
+ .orElse(findWiderTypeForDecimal(t1, t2))
+ .orElse(stringPromotion(t1, t2))
+ .orElse((t1, t2) match {
+ case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) =>
+ findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2))
+ case _ => None
+ })
}
- private def findWiderCommonType(types: Seq[DataType]) = {
+ private def findWiderCommonType(types: Seq[DataType]): Option[DataType] = {
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
case Some(d) => findWiderTypeForTwo(d, c)
case None => None
@@ -139,27 +133,49 @@ object TypeCoercion {
}
/**
- * Similar to [[findWiderCommonType]] that can handle decimal types, but can't promote to
+ * Similar to [[findWiderTypeForTwo]] that can handle decimal types, but can't promote to
* string. If the wider decimal type exceeds system limitation, this rule will truncate
* the decimal type before return it.
*/
- def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = {
- types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
- case Some(d) => findTightestCommonType(d, c).orElse((d, c) match {
- case (t1: DecimalType, t2: DecimalType) =>
- Some(DecimalPrecision.widerDecimalType(t1, t2))
- case (t: IntegralType, d: DecimalType) =>
- Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
- case (d: DecimalType, t: IntegralType) =>
- Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
- case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) =>
- Some(DoubleType)
+ private[analysis] def findWiderTypeWithoutStringPromotionForTwo(
+ t1: DataType,
+ t2: DataType): Option[DataType] = {
+ findTightestCommonType(t1, t2)
+ .orElse(findWiderTypeForDecimal(t1, t2))
+ .orElse((t1, t2) match {
+ case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) =>
+ findWiderTypeWithoutStringPromotionForTwo(et1, et2)
+ .map(ArrayType(_, containsNull1 || containsNull2))
case _ => None
})
+ }
+
+ def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = {
+ types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
+ case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c)
case None => None
})
}
+ /**
+ * Finds a wider type when one or both types are decimals. If the wider decimal type exceeds
+ * system limitation, this rule will truncate the decimal type. If a decimal and other fractional
+ * types are compared, returns a double type.
+ */
+ private def findWiderTypeForDecimal(dt1: DataType, dt2: DataType): Option[DataType] = {
+ (dt1, dt2) match {
+ case (t1: DecimalType, t2: DecimalType) =>
+ Some(DecimalPrecision.widerDecimalType(t1, t2))
+ case (t: IntegralType, d: DecimalType) =>
+ Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
+ case (d: DecimalType, t: IntegralType) =>
+ Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
+ case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) =>
+ Some(DoubleType)
+ case _ => None
+ }
+ }
+
private def haveSameType(exprs: Seq[Expression]): Boolean =
exprs.map(_.dataType).distinct.length == 1
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index ceb5b53e08..3e0c357b6d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -53,7 +53,8 @@ class TypeCoercionSuite extends PlanTest {
// | NullType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType(38, 18) | DoubleType | IntegerType |
// | CalendarIntervalType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | CalendarIntervalType | X | X | X |
// +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+
- // Note: ArrayType*, MapType*, StructType* are castable only when the internal child types also match; otherwise, not castable
+ // Note: MapType*, StructType* are castable only when the internal child types also match; otherwise, not castable.
+ // Note: ArrayType* is castable when the element type is castable according to the table.
// scalastyle:on line.size.limit
private def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = {
@@ -125,6 +126,20 @@ class TypeCoercionSuite extends PlanTest {
}
}
+ private def checkWidenType(
+ widenFunc: (DataType, DataType) => Option[DataType],
+ t1: DataType,
+ t2: DataType,
+ expected: Option[DataType]): Unit = {
+ var found = widenFunc(t1, t2)
+ assert(found == expected,
+ s"Expected $expected as wider common type for $t1 and $t2, found $found")
+ // Test both directions to make sure the widening is symmetric.
+ found = widenFunc(t2, t1)
+ assert(found == expected,
+ s"Expected $expected as wider common type for $t2 and $t1, found $found")
+ }
+
test("implicit type cast - ByteType") {
val checkedType = ByteType
checkTypeCasting(checkedType, castableTypes = numericTypes ++ Seq(StringType))
@@ -308,15 +323,8 @@ class TypeCoercionSuite extends PlanTest {
}
test("tightest common bound for types") {
- def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) {
- var found = TypeCoercion.findTightestCommonType(t1, t2)
- assert(found == tightestCommon,
- s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found")
- // Test both directions to make sure the widening is symmetric.
- found = TypeCoercion.findTightestCommonType(t2, t1)
- assert(found == tightestCommon,
- s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found")
- }
+ def widenTest(t1: DataType, t2: DataType, expected: Option[DataType]): Unit =
+ checkWidenType(TypeCoercion.findTightestCommonType, t1, t2, expected)
// Null
widenTest(NullType, NullType, Some(NullType))
@@ -355,7 +363,6 @@ class TypeCoercionSuite extends PlanTest {
widenTest(DecimalType(2, 1), DoubleType, None)
widenTest(DecimalType(2, 1), IntegerType, None)
widenTest(DoubleType, DecimalType(2, 1), None)
- widenTest(IntegerType, DecimalType(2, 1), None)
// StringType
widenTest(NullType, StringType, Some(StringType))
@@ -379,6 +386,60 @@ class TypeCoercionSuite extends PlanTest {
widenTest(ArrayType(IntegerType), StructType(Seq()), None)
}
+ test("wider common type for decimal and array") {
+ def widenTestWithStringPromotion(
+ t1: DataType,
+ t2: DataType,
+ expected: Option[DataType]): Unit = {
+ checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected)
+ }
+
+ def widenTestWithoutStringPromotion(
+ t1: DataType,
+ t2: DataType,
+ expected: Option[DataType]): Unit = {
+ checkWidenType(TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected)
+ }
+
+ // Decimal
+ widenTestWithStringPromotion(
+ DecimalType(2, 1), DecimalType(3, 2), Some(DecimalType(3, 2)))
+ widenTestWithStringPromotion(
+ DecimalType(2, 1), DoubleType, Some(DoubleType))
+ widenTestWithStringPromotion(
+ DecimalType(2, 1), IntegerType, Some(DecimalType(11, 1)))
+ widenTestWithStringPromotion(
+ DecimalType(2, 1), LongType, Some(DecimalType(21, 1)))
+
+ // ArrayType
+ widenTestWithStringPromotion(
+ ArrayType(ShortType, containsNull = true),
+ ArrayType(DoubleType, containsNull = false),
+ Some(ArrayType(DoubleType, containsNull = true)))
+ widenTestWithStringPromotion(
+ ArrayType(TimestampType, containsNull = false),
+ ArrayType(StringType, containsNull = true),
+ Some(ArrayType(StringType, containsNull = true)))
+ widenTestWithStringPromotion(
+ ArrayType(ArrayType(IntegerType), containsNull = false),
+ ArrayType(ArrayType(LongType), containsNull = false),
+ Some(ArrayType(ArrayType(LongType), containsNull = false)))
+
+ // Without string promotion
+ widenTestWithoutStringPromotion(IntegerType, StringType, None)
+ widenTestWithoutStringPromotion(StringType, TimestampType, None)
+ widenTestWithoutStringPromotion(ArrayType(LongType), ArrayType(StringType), None)
+ widenTestWithoutStringPromotion(ArrayType(StringType), ArrayType(TimestampType), None)
+
+ // String promotion
+ widenTestWithStringPromotion(IntegerType, StringType, Some(StringType))
+ widenTestWithStringPromotion(StringType, TimestampType, Some(StringType))
+ widenTestWithStringPromotion(
+ ArrayType(LongType), ArrayType(StringType), Some(ArrayType(StringType)))
+ widenTestWithStringPromotion(
+ ArrayType(StringType), ArrayType(TimestampType), Some(ArrayType(StringType)))
+ }
+
private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) {
ruleTest(Seq(rule), initial, transformed)
}