-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala      47
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala                  50
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala                          6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala                           4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala                        4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala                            4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala                         8
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala 55
8 files changed, 140 insertions, 38 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 0bc8932240..6006e7bf00 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.analysis
+import javax.annotation.Nullable
+
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
import org.apache.spark.sql.catalyst.rules.Rule
@@ -713,39 +715,68 @@ object HiveTypeCoercion {
case e: ExpectsInputTypes =>
val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
- implicitCast(in, expected)
+ // If we cannot do the implicit cast, just use the original input.
+ implicitCast(in, expected).getOrElse(in)
}
e.withNewChildren(children)
}
/**
- * If needed, cast the expression into the expected type.
- * If the implicit cast is not allowed, return the expression itself.
+ * Given an expected data type, try to cast the expression and return the cast expression.
+ *
+ * If the expression already fits the expected type, we simply return the expression itself.
+ * If the expression has an incompatible type that cannot be implicitly cast, return None.
*/
- def implicitCast(e: Expression, expectedType: AbstractDataType): Expression = {
+ def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] = {
val inType = e.dataType
- (inType, expectedType) match {
+
+ // Note that ret is nullable to avoid typing a lot of Some(...) in this local scope.
+ // We wrap it in an Option immediately after this match.
+ @Nullable val ret: Expression = (inType, expectedType) match {
+
+ // If the expected type is already a parent of the input type, no need to cast.
+ case _ if expectedType.isParentOf(inType) => e
+
// Cast null type (usually from null literals) into target types
- case (NullType, target: DataType) => Cast(e, target.defaultConcreteType)
+ case (NullType, target) => Cast(e, target.defaultConcreteType)
// Implicit cast among numeric types
+ // If input is decimal, and we expect a decimal type, just use the input.
+ case (_: DecimalType, DecimalType) => e
+ // If input is a numeric type but not decimal, and we expect a decimal type,
+ // cast the input to unlimited precision decimal.
+ case (_: NumericType, DecimalType) if !inType.isInstanceOf[DecimalType] =>
+ Cast(e, DecimalType.Unlimited)
+ // For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long
case (_: NumericType, target: NumericType) if e.dataType != target => Cast(e, target)
+ case (_: NumericType, target: NumericType) => e
// Implicit cast between date time types
case (DateType, TimestampType) => Cast(e, TimestampType)
case (TimestampType, DateType) => Cast(e, DateType)
// Implicit cast from/to string
- case (StringType, NumericType) => Cast(e, DoubleType)
+ case (StringType, DecimalType) => Cast(e, DecimalType.Unlimited)
case (StringType, target: NumericType) => Cast(e, target)
case (StringType, DateType) => Cast(e, DateType)
case (StringType, TimestampType) => Cast(e, TimestampType)
case (StringType, BinaryType) => Cast(e, BinaryType)
case (any, StringType) if any != StringType => Cast(e, StringType)
+ // Type collection.
+ // First see if we can find our input type in the type collection. If we can, then just
+ // use the current expression; otherwise, pick the first type we can implicitly cast to.
+ case (_, TypeCollection(types)) =>
+ if (types.exists(_.isParentOf(inType))) {
+ e
+ } else {
+ types.flatMap(implicitCast(e, _)).headOption.orNull
+ }
+
// Else, there is no acceptable implicit cast: return null, which becomes None below
- case _ => e
+ case _ => null
}
+ Option(ret)
}
}
}
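
For intuition, here is a minimal sketch of the new contract, assuming test-scope access (implicitCast lives inside the org.apache.spark.sql packages and is exercised the same way by the suite at the bottom of this diff); the literal values are illustrative:

import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types._

// StringType -> IntegerType is covered by an explicit rule above,
// so the result is Some(Cast(...)).
val ok = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(
  Literal.create("1", StringType), IntegerType)
assert(ok.exists(_.isInstanceOf[Cast]))

// No rule casts MapType to IntegerType, so the result is None and the
// caller above falls back to the original child via .getOrElse(in).
val none = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(
  Literal.create(null, MapType(StringType, StringType)), IntegerType)
assert(none.isEmpty)
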
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index 43e2f8a46e..e5dc99fb62 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -28,7 +28,45 @@ import org.apache.spark.util.Utils
* A non-concrete data type, reserved for internal uses.
*/
private[sql] abstract class AbstractDataType {
+ /**
+ * The default concrete type to use if we want to cast a null literal into this type.
+ */
private[sql] def defaultConcreteType: DataType
+
+ /**
+ * Returns true if this data type is a parent of the `childCandidate`.
+ */
+ private[sql] def isParentOf(childCandidate: DataType): Boolean
+}
+
+
+/**
+ * A collection of types that can be used to specify type constraints. The sequence also specifies
+ * precedence: an earlier type takes precedence over a later one.
+ *
+ * {{{
+ * TypeCollection(StringType, BinaryType)
+ * }}}
+ *
+ * This means that we prefer StringType over BinaryType if it is possible to cast to StringType.
+ */
+private[sql] class TypeCollection(private val types: Seq[DataType]) extends AbstractDataType {
+ require(types.nonEmpty, s"TypeCollection ($types) cannot be empty")
+
+ private[sql] override def defaultConcreteType: DataType = types.head
+
+ private[sql] override def isParentOf(childCandidate: DataType): Boolean = false
+}
+
+
+private[sql] object TypeCollection {
+
+ def apply(types: DataType*): TypeCollection = new TypeCollection(types)
+
+ def unapply(typ: AbstractDataType): Option[Seq[DataType]] = typ match {
+ case typ: TypeCollection => Some(typ.types)
+ case _ => None
+ }
}
@@ -61,7 +99,7 @@ abstract class NumericType extends AtomicType {
}
-private[sql] object NumericType extends AbstractDataType {
+private[sql] object NumericType {
/**
* Enables matching against NumericType for expressions:
* {{{
@@ -70,12 +108,10 @@ private[sql] object NumericType extends AbstractDataType {
* }}}
*/
def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType]
-
- private[sql] override def defaultConcreteType: DataType = IntegerType
}
-private[sql] object IntegralType extends AbstractDataType {
+private[sql] object IntegralType {
/**
* Enables matching against IntegralType for expressions:
* {{{
@@ -84,8 +120,6 @@ private[sql] object IntegralType extends AbstractDataType {
* }}}
*/
def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType]
-
- private[sql] override def defaultConcreteType: DataType = IntegerType
}
@@ -94,7 +128,7 @@ private[sql] abstract class IntegralType extends NumericType {
}
-private[sql] object FractionalType extends AbstractDataType {
+private[sql] object FractionalType {
/**
* Enables matching against FractionalType for expressions:
* {{{
@@ -103,8 +137,6 @@ private[sql] object FractionalType extends AbstractDataType {
* }}}
*/
def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[FractionalType]
-
- private[sql] override def defaultConcreteType: DataType = DoubleType
}
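
To make the precedence semantics concrete, a hedged sketch that mirrors the TypeCollection test cases added at the bottom of this diff (again assuming sql-package visibility for the private[sql] members):

import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion.ImplicitTypeCasts.implicitCast
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types._

val stringOrBinary = TypeCollection(StringType, BinaryType)

// A null literal can be cast to anything, so the first (preferred)
// member of the collection wins via defaultConcreteType.
assert(implicitCast(Literal.create(null, NullType), stringOrBinary)
  .map(_.dataType) == Some(StringType))

// An input type already in the collection is kept as-is, regardless of
// its position in the sequence.
assert(implicitCast(Literal.create(Array[Byte](1, 2), BinaryType), stringOrBinary)
  .map(_.dataType) == Some(BinaryType))
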
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
index 81553e7fc9..8ea6cb14c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
@@ -26,7 +26,11 @@ object ArrayType extends AbstractDataType {
/** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */
def apply(elementType: DataType): ArrayType = ArrayType(elementType, containsNull = true)
- override def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true)
+ private[sql] override def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true)
+
+ private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
+ childCandidate.isInstanceOf[ArrayType]
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index c333fa70d1..7d00047d08 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -75,7 +75,9 @@ abstract class DataType extends AbstractDataType {
*/
private[spark] def asNullable: DataType
- override def defaultConcreteType: DataType = this
+ private[sql] override def defaultConcreteType: DataType = this
+
+ private[sql] override def isParentOf(childCandidate: DataType): Boolean = this == childCandidate
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 06373a095b..434fc037aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -86,6 +86,10 @@ object DecimalType extends AbstractDataType {
private[sql] override def defaultConcreteType: DataType = Unlimited
+ private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
+ childCandidate.isInstanceOf[DecimalType]
+ }
+
val Unlimited: DecimalType = DecimalType(None)
private[sql] object Fixed {
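
The isParentOf implementations added across these files split into two kinds; a small illustration, under the same sql-package visibility assumption:

import org.apache.spark.sql.types._

// A concrete DataType is a "parent" only of itself (plain equality)...
assert(IntegerType.isParentOf(IntegerType))
assert(!IntegerType.isParentOf(LongType))

// ...while abstract companions such as DecimalType and ArrayType accept
// any concrete member of their family, which is what lets the
// `expectedType.isParentOf(inType)` guard in implicitCast skip the cast.
assert(DecimalType.isParentOf(DecimalType.Unlimited))
assert(ArrayType.isParentOf(ArrayType(IntegerType)))
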
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
index 69c2119e23..2b25617ec6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
@@ -71,6 +71,10 @@ object MapType extends AbstractDataType {
private[sql] override def defaultConcreteType: DataType = apply(NullType, NullType)
+ private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
+ childCandidate.isInstanceOf[MapType]
+ }
+
/**
* Construct a [[MapType]] object with the given key type and value type.
* The `valueContainsNull` is true.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 6fedeabf23..7e77b77e73 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -301,7 +301,13 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
}
-object StructType {
+object StructType extends AbstractDataType {
+
+ private[sql] override def defaultConcreteType: DataType = new StructType
+
+ private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
+ childCandidate.isInstanceOf[StructType]
+ }
def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 498fd86a06..60e727c6c7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -27,28 +27,47 @@ import org.apache.spark.sql.types._
class HiveTypeCoercionSuite extends PlanTest {
test("implicit type cast") {
- def shouldCast(from: DataType, to: AbstractDataType): Unit = {
+ def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = {
val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to)
- assert(got.dataType === to.defaultConcreteType)
+ assert(got.map(_.dataType) == Option(expected),
+ s"Failed to cast $from to $to")
}
+ shouldCast(NullType, NullType, NullType)
+ shouldCast(NullType, IntegerType, IntegerType)
+ shouldCast(NullType, DecimalType, DecimalType.Unlimited)
+
// TODO: write the entire implicit cast table out for test cases.
- shouldCast(ByteType, IntegerType)
- shouldCast(IntegerType, IntegerType)
- shouldCast(IntegerType, LongType)
- shouldCast(IntegerType, DecimalType.Unlimited)
- shouldCast(LongType, IntegerType)
- shouldCast(LongType, DecimalType.Unlimited)
-
- shouldCast(DateType, TimestampType)
- shouldCast(TimestampType, DateType)
-
- shouldCast(StringType, IntegerType)
- shouldCast(StringType, DateType)
- shouldCast(StringType, TimestampType)
- shouldCast(IntegerType, StringType)
- shouldCast(DateType, StringType)
- shouldCast(TimestampType, StringType)
+ shouldCast(ByteType, IntegerType, IntegerType)
+ shouldCast(IntegerType, IntegerType, IntegerType)
+ shouldCast(IntegerType, LongType, LongType)
+ shouldCast(IntegerType, DecimalType, DecimalType.Unlimited)
+ shouldCast(LongType, IntegerType, IntegerType)
+ shouldCast(LongType, DecimalType, DecimalType.Unlimited)
+
+ shouldCast(DateType, TimestampType, TimestampType)
+ shouldCast(TimestampType, DateType, DateType)
+
+ shouldCast(StringType, IntegerType, IntegerType)
+ shouldCast(StringType, DateType, DateType)
+ shouldCast(StringType, TimestampType, TimestampType)
+ shouldCast(IntegerType, StringType, StringType)
+ shouldCast(DateType, StringType, StringType)
+ shouldCast(TimestampType, StringType, StringType)
+
+ shouldCast(StringType, BinaryType, BinaryType)
+ shouldCast(BinaryType, StringType, StringType)
+
+ shouldCast(NullType, TypeCollection(StringType, BinaryType), StringType)
+
+ shouldCast(StringType, TypeCollection(StringType, BinaryType), StringType)
+ shouldCast(BinaryType, TypeCollection(StringType, BinaryType), BinaryType)
+ shouldCast(StringType, TypeCollection(BinaryType, StringType), StringType)
+
+ shouldCast(IntegerType, TypeCollection(IntegerType, BinaryType), IntegerType)
+ shouldCast(IntegerType, TypeCollection(BinaryType, IntegerType), IntegerType)
+ shouldCast(BinaryType, TypeCollection(BinaryType, IntegerType), BinaryType)
+ shouldCast(BinaryType, TypeCollection(IntegerType, BinaryType), BinaryType)
}
test("tightest common bound for types") {