aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjiangxingbo <jiangxb1987@gmail.com>2016-12-19 21:20:47 +0100
committerHerman van Hovell <hvanhovell@databricks.com>2016-12-19 21:20:47 +0100
commit70d495dcecce8617b7099fc599fe7c43d7eae66e (patch)
tree38d7259c13c8459c12374bcd2f991784d542f0a1
parent7a75ee1c9224aa5c2e954fe2a71f9ad506f6782b (diff)
downloadspark-70d495dcecce8617b7099fc599fe7c43d7eae66e.tar.gz
spark-70d495dcecce8617b7099fc599fe7c43d7eae66e.tar.bz2
spark-70d495dcecce8617b7099fc599fe7c43d7eae66e.zip
[SPARK-18624][SQL] Implicit cast ArrayType(InternalType)
## What changes were proposed in this pull request? Currently `ImplicitTypeCasts` doesn't handle casts between `ArrayType`s; this is inconvenient, so we should add a rule to enable casting from `ArrayType(InternalType)` to `ArrayType(newInternalType)`. Goals: 1. Add a rule to `ImplicitTypeCasts` to enable casting between `ArrayType`s; 2. Simplify `Percentile` and `ApproximatePercentile`. ## How was this patch tested? Updated test cases in `TypeCoercionSuite`. Author: jiangxingbo <jiangxb1987@gmail.com> Closes #16057 from jiangxb1987/implicit-cast-complex-types.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala57
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala6
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala19
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala14
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala45
5 files changed, 92 insertions, 49 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 6662a9e974..cd73f9c897 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -673,48 +673,69 @@ object TypeCoercion {
* If the expression has an incompatible type that cannot be implicitly cast, return None.
*/
def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] = {
- val inType = e.dataType
+ implicitCast(e.dataType, expectedType).map { dt =>
+ if (dt == e.dataType) e else Cast(e, dt)
+ }
+ }
+ private def implicitCast(inType: DataType, expectedType: AbstractDataType): Option[DataType] = {
// Note that ret is nullable to avoid typing a lot of Some(...) in this local scope.
// We wrap immediately an Option after this.
- @Nullable val ret: Expression = (inType, expectedType) match {
-
+ @Nullable val ret: DataType = (inType, expectedType) match {
// If the expected type is already a parent of the input type, no need to cast.
- case _ if expectedType.acceptsType(inType) => e
+ case _ if expectedType.acceptsType(inType) => inType
// Cast null type (usually from null literals) into target types
- case (NullType, target) => Cast(e, target.defaultConcreteType)
+ case (NullType, target) => target.defaultConcreteType
// If the function accepts any numeric type and the input is a string, we follow the hive
// convention and cast that input into a double
- case (StringType, NumericType) => Cast(e, NumericType.defaultConcreteType)
+ case (StringType, NumericType) => NumericType.defaultConcreteType
// Implicit cast among numeric types. When we reach here, input type is not acceptable.
// If input is a numeric type but not decimal, and we expect a decimal type,
// cast the input to decimal.
- case (d: NumericType, DecimalType) => Cast(e, DecimalType.forType(d))
+ case (d: NumericType, DecimalType) => DecimalType.forType(d)
// For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long
- case (_: NumericType, target: NumericType) => Cast(e, target)
+ case (_: NumericType, target: NumericType) => target
// Implicit cast between date time types
- case (DateType, TimestampType) => Cast(e, TimestampType)
- case (TimestampType, DateType) => Cast(e, DateType)
+ case (DateType, TimestampType) => TimestampType
+ case (TimestampType, DateType) => DateType
// Implicit cast from/to string
- case (StringType, DecimalType) => Cast(e, DecimalType.SYSTEM_DEFAULT)
- case (StringType, target: NumericType) => Cast(e, target)
- case (StringType, DateType) => Cast(e, DateType)
- case (StringType, TimestampType) => Cast(e, TimestampType)
- case (StringType, BinaryType) => Cast(e, BinaryType)
+ case (StringType, DecimalType) => DecimalType.SYSTEM_DEFAULT
+ case (StringType, target: NumericType) => target
+ case (StringType, DateType) => DateType
+ case (StringType, TimestampType) => TimestampType
+ case (StringType, BinaryType) => BinaryType
// Cast any atomic type to string.
- case (any: AtomicType, StringType) if any != StringType => Cast(e, StringType)
+ case (any: AtomicType, StringType) if any != StringType => StringType
// When we reach here, input type is not acceptable for any types in this type collection,
// try to find the first one we can implicitly cast.
- case (_, TypeCollection(types)) => types.flatMap(implicitCast(e, _)).headOption.orNull
+ case (_, TypeCollection(types)) =>
+ types.flatMap(implicitCast(inType, _)).headOption.orNull
+
+ // Implicit cast between array types.
+ //
+ // Compare the nullabilities of the from type and the to type, check whether the cast of
+ // the nullability is resolvable by the following rules:
+ // 1. If the nullability of the to type is true, the cast is always allowed;
+ // 2. If the nullability of the to type is false, and the nullability of the from type is
+ // true, the cast is never allowed;
+ // 3. If the nullabilities of both the from type and the to type are false, the cast is
+ // allowed only when Cast.forceNullable(fromType, toType) is false.
+ case (ArrayType(fromType, fn), ArrayType(toType: DataType, true)) =>
+ implicitCast(fromType, toType).map(ArrayType(_, true)).orNull
+
+ case (ArrayType(fromType, true), ArrayType(toType: DataType, false)) => null
+
+ case (ArrayType(fromType, false), ArrayType(toType: DataType, false))
+ if !Cast.forceNullable(fromType, toType) =>
+ implicitCast(fromType, toType).map(ArrayType(_, false)).orNull
- // Else, just return the same input expression
case _ => null
}
Option(ret)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 4db1ae6faa..741730e3e0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -89,9 +89,7 @@ object Cast {
case _ => false
}
- private def resolvableNullability(from: Boolean, to: Boolean) = !from || to
-
- private def forceNullable(from: DataType, to: DataType) = (from, to) match {
+ def forceNullable(from: DataType, to: DataType): Boolean = (from, to) match {
case (NullType, _) => true
case (_, _) if from == to => false
@@ -110,6 +108,8 @@ object Cast {
case (_: FractionalType, _: IntegralType) => true // NaN, infinity
case _ => false
}
+
+ private def resolvableNullability(from: Boolean, to: Boolean) = !from || to
}
/** Cast the child expression to the target data type. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
index 01792aef37..0e71442f89 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
@@ -86,23 +86,16 @@ case class ApproximatePercentile(
private lazy val accuracy: Int = accuracyExpression.eval().asInstanceOf[Int]
override def inputTypes: Seq[AbstractDataType] = {
- Seq(DoubleType, TypeCollection(DoubleType, ArrayType), IntegerType)
+ Seq(DoubleType, TypeCollection(DoubleType, ArrayType(DoubleType)), IntegerType)
}
// Mark as lazy so that percentageExpression is not evaluated during tree transformation.
- private lazy val (returnPercentileArray: Boolean, percentages: Array[Double]) = {
- (percentageExpression.dataType, percentageExpression.eval()) match {
+ private lazy val (returnPercentileArray: Boolean, percentages: Array[Double]) =
+ percentageExpression.eval() match {
// Rule ImplicitTypeCasts can cast other numeric types to double
- case (_, num: Double) => (false, Array(num))
- case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) =>
- val numericArray = arrayData.toObjectArray(baseType)
- (true, numericArray.map { x =>
- baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType])
- })
- case other =>
- throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentage")
+ case num: Double => (false, Array(num))
+ case arrayData: ArrayData => (true, arrayData.toDoubleArray())
}
- }
override def checkInputDataTypes(): TypeCheckResult = {
val defaultCheck = super.checkInputDataTypes()
@@ -162,7 +155,7 @@ case class ApproximatePercentile(
override def nullable: Boolean = true
override def dataType: DataType = {
- if (returnPercentileArray) ArrayType(DoubleType) else DoubleType
+ if (returnPercentileArray) ArrayType(DoubleType, false) else DoubleType
}
override def prettyName: String = "percentile_approx"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index b51b55313e..2f68195dd1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -77,15 +77,9 @@ case class Percentile(
private lazy val returnPercentileArray = percentageExpression.dataType.isInstanceOf[ArrayType]
@transient
- private lazy val percentages =
- (percentageExpression.dataType, percentageExpression.eval()) match {
- case (_, num: Double) => Seq(num)
- case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) =>
- val numericArray = arrayData.toObjectArray(baseType)
- numericArray.map { x =>
- baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType])}.toSeq
- case other =>
- throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentages")
+ private lazy val percentages = percentageExpression.eval() match {
+ case num: Double => Seq(num)
+ case arrayData: ArrayData => arrayData.toDoubleArray().toSeq
}
override def children: Seq[Expression] = child :: percentageExpression :: Nil
@@ -99,7 +93,7 @@ case class Percentile(
}
override def inputTypes: Seq[AbstractDataType] = percentageExpression.dataType match {
- case _: ArrayType => Seq(NumericType, ArrayType)
+ case _: ArrayType => Seq(NumericType, ArrayType(DoubleType))
case _ => Seq(NumericType, DoubleType)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 590c9d5e84..dbb1e3e70f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -57,14 +57,43 @@ class TypeCoercionSuite extends PlanTest {
// scalastyle:on line.size.limit
private def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = {
- val got = TypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to)
- assert(got.map(_.dataType) == Option(expected),
+ // Check default value
+ val castDefault = TypeCoercion.ImplicitTypeCasts.implicitCast(default(from), to)
+ assert(DataType.equalsIgnoreCompatibleNullability(
+ castDefault.map(_.dataType).getOrElse(null), expected),
+ s"Failed to cast $from to $to")
+
+ // Check null value
+ val castNull = TypeCoercion.ImplicitTypeCasts.implicitCast(createNull(from), to)
+ assert(DataType.equalsIgnoreCaseAndNullability(
+ castNull.map(_.dataType).getOrElse(null), expected),
s"Failed to cast $from to $to")
}
private def shouldNotCast(from: DataType, to: AbstractDataType): Unit = {
- val got = TypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to)
- assert(got.isEmpty, s"Should not be able to cast $from to $to, but got $got")
+ // Check default value
+ val castDefault = TypeCoercion.ImplicitTypeCasts.implicitCast(default(from), to)
+ assert(castDefault.isEmpty, s"Should not be able to cast $from to $to, but got $castDefault")
+
+ // Check null value
+ val castNull = TypeCoercion.ImplicitTypeCasts.implicitCast(createNull(from), to)
+ assert(castNull.isEmpty, s"Should not be able to cast $from to $to, but got $castNull")
+ }
+
+ private def default(dataType: DataType): Expression = dataType match {
+ case ArrayType(internalType: DataType, _) =>
+ CreateArray(Seq(Literal.default(internalType)))
+ case MapType(keyDataType: DataType, valueDataType: DataType, _) =>
+ CreateMap(Seq(Literal.default(keyDataType), Literal.default(valueDataType)))
+ case _ => Literal.default(dataType)
+ }
+
+ private def createNull(dataType: DataType): Expression = dataType match {
+ case ArrayType(internalType: DataType, _) =>
+ CreateArray(Seq(Literal.create(null, internalType)))
+ case MapType(keyDataType: DataType, valueDataType: DataType, _) =>
+ CreateMap(Seq(Literal.create(null, keyDataType), Literal.create(null, valueDataType)))
+ case _ => Literal.create(null, dataType)
}
val integralTypes: Seq[DataType] =
@@ -196,7 +225,13 @@ class TypeCoercionSuite extends PlanTest {
test("implicit type cast - ArrayType(StringType)") {
val checkedType = ArrayType(StringType)
- checkTypeCasting(checkedType, castableTypes = Seq(checkedType))
+ val nonCastableTypes =
+ complexTypes ++ Seq(BooleanType, NullType, CalendarIntervalType)
+ checkTypeCasting(checkedType,
+ castableTypes = allTypes.filterNot(nonCastableTypes.contains).map(ArrayType(_)))
+ nonCastableTypes.map(ArrayType(_)).foreach(shouldNotCast(checkedType, _))
+ shouldNotCast(ArrayType(DoubleType, containsNull = false),
+ ArrayType(LongType, containsNull = false))
shouldNotCast(checkedType, DecimalType)
shouldNotCast(checkedType, NumericType)
shouldNotCast(checkedType, IntegralType)