From 52508beb650a863ed5c89384414b3b7675cac11e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 2 Jul 2015 14:16:14 -0700 Subject: [SPARK-8772][SQL] Implement implicit type cast for expressions that define input types. Author: Reynold Xin Closes #7175 from rxin/implicitCast and squashes the following commits: 88080a2 [Reynold Xin] Clearer definition of implicit type cast. f0ff97f [Reynold Xin] Added missing file. c65e532 [Reynold Xin] [SPARK-8772][SQL] Implement implicit type cast for expressions that defines input types. --- .../sql/catalyst/analysis/HiveTypeCoercion.scala | 41 ++++++-- .../catalyst/expressions/ExpectsInputTypes.scala | 24 +---- .../spark/sql/catalyst/expressions/math.scala | 7 +- .../spark/sql/catalyst/expressions/misc.scala | 12 +-- .../sql/catalyst/expressions/predicates.scala | 14 +-- .../catalyst/expressions/stringOperations.scala | 10 +- .../apache/spark/sql/types/AbstractDataType.scala | 114 +++++++++++++++++++++ .../org/apache/spark/sql/types/ArrayType.scala | 6 +- .../org/apache/spark/sql/types/BinaryType.scala | 2 - .../org/apache/spark/sql/types/BooleanType.scala | 2 - .../org/apache/spark/sql/types/ByteType.scala | 2 - .../org/apache/spark/sql/types/DataType.scala | 86 +--------------- .../org/apache/spark/sql/types/DateType.scala | 5 +- .../org/apache/spark/sql/types/DecimalType.scala | 7 +- .../org/apache/spark/sql/types/DoubleType.scala | 2 - .../org/apache/spark/sql/types/FloatType.scala | 2 - .../org/apache/spark/sql/types/IntegerType.scala | 2 - .../org/apache/spark/sql/types/LongType.scala | 2 - .../scala/org/apache/spark/sql/types/MapType.scala | 7 +- .../org/apache/spark/sql/types/NullType.scala | 2 - .../org/apache/spark/sql/types/ShortType.scala | 2 - .../org/apache/spark/sql/types/StringType.scala | 2 - .../org/apache/spark/sql/types/StructType.scala | 2 - .../org/apache/spark/sql/types/TimestampType.scala | 2 - .../catalyst/analysis/HiveTypeCoercionSuite.scala | 25 +++++ 25 files changed, 213 insertions(+), 169 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 8420c54f7c..0bc8932240 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -704,19 +704,48 @@ object HiveTypeCoercion { /** * Casts types according to the expected input types for Expressions that have the trait - * [[AutoCastInputTypes]]. + * [[ExpectsInputTypes]]. */ object ImplicitTypeCasts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case e: AutoCastInputTypes if e.children.map(_.dataType) != e.inputTypes => - val newC = (e.children, e.children.map(_.dataType), e.inputTypes).zipped.map { - case (child, actual, expected) => - if (actual == expected) child else Cast(child, expected) + case e: ExpectsInputTypes => + val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => + implicitCast(in, expected) } - e.withNewChildren(newC) + e.withNewChildren(children) + } + + /** + * If needed, cast the expression into the expected type. + * If the implicit cast is not allowed, return the expression itself. 
+ */ + def implicitCast(e: Expression, expectedType: AbstractDataType): Expression = { + val inType = e.dataType + (inType, expectedType) match { + // Cast null type (usually from null literals) into target types + case (NullType, target: DataType) => Cast(e, target.defaultConcreteType) + + // Implicit cast among numeric types + case (_: NumericType, target: NumericType) if e.dataType != target => Cast(e, target) + + // Implicit cast between date time types + case (DateType, TimestampType) => Cast(e, TimestampType) + case (TimestampType, DateType) => Cast(e, DateType) + + // Implicit cast from/to string + case (StringType, NumericType) => Cast(e, DoubleType) + case (StringType, target: NumericType) => Cast(e, target) + case (StringType, DateType) => Cast(e, DateType) + case (StringType, TimestampType) => Cast(e, TimestampType) + case (StringType, BinaryType) => Cast(e, BinaryType) + case (any, StringType) if any != StringType => Cast(e, StringType) + + // Else, just return the same input expression + case _ => e + } } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index 450fc4165f..916e30154d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.AbstractDataType /** @@ -32,28 +32,12 @@ trait ExpectsInputTypes { self: Expression => * * The possible values at each position are: * 1. a specific data type, e.g. LongType, StringType. - * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType. - * 3. a list of specific data types, e.g. Seq(StringType, BinaryType). + * 2. a non-leaf abstract data type, e.g. NumericType, IntegralType, FractionalType. */ - def inputTypes: Seq[Any] + def inputTypes: Seq[AbstractDataType] override def checkInputDataTypes(): TypeCheckResult = { - // We will do the type checking in `HiveTypeCoercion`, so always returning success here. - TypeCheckResult.TypeCheckSuccess - } -} - -/** - * Expressions that require a specific `DataType` as input should implement this trait - * so that the proper type conversions can be performed in the analyzer. - */ -trait AutoCastInputTypes { self: Expression => - - def inputTypes: Seq[DataType] - - override def checkInputDataTypes(): TypeCheckResult = { - // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`, - // so type mismatch error won't be reported here, but for underling `Cast`s. + // TODO: implement proper type checking. 
TypeCheckResult.TypeCheckSuccess } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 7504c6a066..035980da56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -56,8 +56,7 @@ abstract class LeafMathExpression(c: Double, name: String) * @param name The short name of the function */ abstract class UnaryMathExpression(f: Double => Double, name: String) - extends UnaryExpression with Serializable with AutoCastInputTypes { - self: Product => + extends UnaryExpression with Serializable with ExpectsInputTypes { self: Product => override def inputTypes: Seq[DataType] = Seq(DoubleType) override def dataType: DataType = DoubleType @@ -96,7 +95,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) * @param name The short name of the function */ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) - extends BinaryExpression with Serializable with AutoCastInputTypes { self: Product => + extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType) @@ -208,7 +207,7 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia } case class Bin(child: Expression) - extends UnaryExpression with Serializable with AutoCastInputTypes { + extends UnaryExpression with Serializable with ExpectsInputTypes { override def inputTypes: Seq[DataType] = Seq(LongType) override def dataType: DataType = StringType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 407023e472..e008af3966 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -31,8 +31,7 @@ import org.apache.spark.unsafe.types.UTF8String * A function that calculates an MD5 128-bit checksum and returns it as a hex string * For input of type [[BinaryType]] */ -case class Md5(child: Expression) - extends UnaryExpression with AutoCastInputTypes { +case class Md5(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = StringType @@ -62,12 +61,10 @@ case class Md5(child: Expression) * the hash length is not one of the permitted values, the return value is NULL. 
*/ case class Sha2(left: Expression, right: Expression) - extends BinaryExpression with Serializable with AutoCastInputTypes { + extends BinaryExpression with Serializable with ExpectsInputTypes { override def dataType: DataType = StringType - override def toString: String = s"SHA2($left, $right)" - override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType) override def eval(input: InternalRow): Any = { @@ -147,7 +144,7 @@ case class Sha2(left: Expression, right: Expression) * A function that calculates a sha1 hash value and returns it as a hex string * For input of type [[BinaryType]] or [[StringType]] */ -case class Sha1(child: Expression) extends UnaryExpression with AutoCastInputTypes { +case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = StringType @@ -174,8 +171,7 @@ case class Sha1(child: Expression) extends UnaryExpression with AutoCastInputTyp * A function that computes a cyclic redundancy check value and returns it as a bigint * For input of type [[BinaryType]] */ -case class Crc32(child: Expression) - extends UnaryExpression with AutoCastInputTypes { +case class Crc32(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = LongType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index d4569241e7..0b479f466c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -69,7 +69,7 @@ trait PredicateHelper { expr.references.subsetOf(plan.outputSet) } -case class Not(child: Expression) extends UnaryExpression with Predicate with AutoCastInputTypes { +case class Not(child: Expression) extends UnaryExpression with Predicate with ExpectsInputTypes { override def toString: String = s"NOT $child" override def inputTypes: Seq[DataType] = Seq(BooleanType) @@ -120,11 +120,11 @@ case class InSet(value: Expression, hset: Set[Any]) } case class And(left: Expression, right: Expression) - extends BinaryOperator with Predicate with AutoCastInputTypes { + extends BinaryExpression with Predicate with ExpectsInputTypes { - override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) + override def toString: String = s"($left && $right)" - override def symbol: String = "&&" + override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) override def eval(input: InternalRow): Any = { val l = left.eval(input) @@ -169,11 +169,11 @@ case class And(left: Expression, right: Expression) } case class Or(left: Expression, right: Expression) - extends BinaryOperator with Predicate with AutoCastInputTypes { + extends BinaryExpression with Predicate with ExpectsInputTypes { - override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) + override def toString: String = s"($left || $right)" - override def symbol: String = "||" + override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) override def eval(input: InternalRow): Any = { val l = left.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index b020f2bbc5..57918b32f8 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -trait StringRegexExpression extends AutoCastInputTypes { +trait StringRegexExpression extends ExpectsInputTypes { self: BinaryExpression => def escape(v: String): String @@ -111,7 +111,7 @@ case class RLike(left: Expression, right: Expression) override def toString: String = s"$left RLIKE $right" } -trait CaseConversionExpression extends AutoCastInputTypes { +trait CaseConversionExpression extends ExpectsInputTypes { self: UnaryExpression => def convert(v: UTF8String): UTF8String @@ -154,7 +154,7 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE } /** A base trait for functions that compare two strings, returning a boolean. */ -trait StringComparison extends AutoCastInputTypes { +trait StringComparison extends ExpectsInputTypes { self: BinaryExpression => def compare(l: UTF8String, r: UTF8String): Boolean @@ -215,7 +215,7 @@ case class EndsWith(left: Expression, right: Expression) * Defined for String and Binary types. */ case class Substring(str: Expression, pos: Expression, len: Expression) - extends Expression with AutoCastInputTypes { + extends Expression with ExpectsInputTypes { def this(str: Expression, pos: Expression) = { this(str, pos, Literal(Integer.MAX_VALUE)) @@ -283,7 +283,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) /** * A function that return the length of the given string expression. */ -case class StringLength(child: Expression) extends UnaryExpression with AutoCastInputTypes { +case class StringLength(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[DataType] = Seq(StringType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala new file mode 100644 index 0000000000..43e2f8a46e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.{TypeTag, runtimeMirror} + +import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.util.Utils + +/** + * A non-concrete data type, reserved for internal uses. 
+ */ +private[sql] abstract class AbstractDataType { + private[sql] def defaultConcreteType: DataType +} + + +/** + * An internal type used to represent everything that is not null, UDTs, arrays, structs, and maps. + */ +protected[sql] abstract class AtomicType extends DataType { + private[sql] type InternalType + @transient private[sql] val tag: TypeTag[InternalType] + private[sql] val ordering: Ordering[InternalType] + + @transient private[sql] val classTag = ScalaReflectionLock.synchronized { + val mirror = runtimeMirror(Utils.getSparkClassLoader) + ClassTag[InternalType](mirror.runtimeClass(tag.tpe)) + } +} + + +/** + * :: DeveloperApi :: + * Numeric data types. + */ +abstract class NumericType extends AtomicType { + // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for + // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a + // type parameter and add a numeric annotation (i.e., [JvmType : Numeric]). This gets + // desugared by the compiler into an argument to the objects constructor. This means there is no + // longer an no argument constructor and thus the JVM cannot serialize the object anymore. + private[sql] val numeric: Numeric[InternalType] +} + + +private[sql] object NumericType extends AbstractDataType { + /** + * Enables matching against NumericType for expressions: + * {{{ + * case Cast(child @ NumericType(), StringType) => + * ... + * }}} + */ + def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] + + private[sql] override def defaultConcreteType: DataType = IntegerType +} + + +private[sql] object IntegralType extends AbstractDataType { + /** + * Enables matching against IntegralType for expressions: + * {{{ + * case Cast(child @ IntegralType(), StringType) => + * ... + * }}} + */ + def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType] + + private[sql] override def defaultConcreteType: DataType = IntegerType +} + + +private[sql] abstract class IntegralType extends NumericType { + private[sql] val integral: Integral[InternalType] +} + + +private[sql] object FractionalType extends AbstractDataType { + /** + * Enables matching against FractionalType for expressions: + * {{{ + * case Cast(child @ FractionalType(), StringType) => + * ... + * }}} + */ + def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[FractionalType] + + private[sql] override def defaultConcreteType: DataType = DoubleType +} + + +private[sql] abstract class FractionalType extends NumericType { + private[sql] val fractional: Fractional[InternalType] + private[sql] val asIntegral: Integral[InternalType] +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index b116163fac..81553e7fc9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -22,9 +22,11 @@ import org.json4s.JsonDSL._ import org.apache.spark.annotation.DeveloperApi -object ArrayType { +object ArrayType extends AbstractDataType { /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */ def apply(elementType: DataType): ArrayType = ArrayType(elementType, containsNull = true) + + override def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true) } @@ -41,8 +43,6 @@ object ArrayType { * * @param elementType The data type of values. 
* @param containsNull Indicates if values have `null` values - * - * @group dataType */ @DeveloperApi case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala index 9b58601e5e..f2c6f34ea5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala @@ -29,8 +29,6 @@ import org.apache.spark.sql.catalyst.util.TypeUtils * :: DeveloperApi :: * The data type representing `Array[Byte]` values. * Please use the singleton [[DataTypes.BinaryType]]. - * - * @group dataType */ @DeveloperApi class BinaryType private() extends AtomicType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala index a7f228cefa..2d8ee3d9bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock /** * :: DeveloperApi :: * The data type representing `Boolean` values. Please use the singleton [[DataTypes.BooleanType]]. - * - *@group dataType */ @DeveloperApi class BooleanType private() extends AtomicType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala index 4d8685796e..2ca427975a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock /** * :: DeveloperApi :: * The data type representing `Byte` values. Please use the singleton [[DataTypes.ByteType]]. - * - * @group dataType */ @DeveloperApi class ByteType private() extends IntegralType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 74677ddfca..c333fa70d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.types -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.{TypeTag, runtimeMirror} import scala.util.parsing.combinator.RegexParsers import org.json4s._ @@ -27,19 +25,15 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.util.Utils /** * :: DeveloperApi :: * The base type of all Spark SQL data types. - * - * @group dataType */ @DeveloperApi -abstract class DataType { +abstract class DataType extends AbstractDataType { /** * Enables matching against DataType for expressions: * {{{ @@ -80,84 +74,8 @@ abstract class DataType { * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). */ private[spark] def asNullable: DataType -} - - -/** - * An internal type used to represent everything that is not null, UDTs, arrays, structs, and maps. 
- */ -protected[sql] abstract class AtomicType extends DataType { - private[sql] type InternalType - @transient private[sql] val tag: TypeTag[InternalType] - private[sql] val ordering: Ordering[InternalType] - - @transient private[sql] val classTag = ScalaReflectionLock.synchronized { - val mirror = runtimeMirror(Utils.getSparkClassLoader) - ClassTag[InternalType](mirror.runtimeClass(tag.tpe)) - } -} - - -/** - * :: DeveloperApi :: - * Numeric data types. - * - * @group dataType - */ -abstract class NumericType extends AtomicType { - // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for - // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a - // type parameter and add a numeric annotation (i.e., [JvmType : Numeric]). This gets - // desugared by the compiler into an argument to the objects constructor. This means there is no - // longer an no argument constructor and thus the JVM cannot serialize the object anymore. - private[sql] val numeric: Numeric[InternalType] -} - - -private[sql] object NumericType { - /** - * Enables matching against NumericType for expressions: - * {{{ - * case Cast(child @ NumericType(), StringType) => - * ... - * }}} - */ - def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] -} - - -private[sql] object IntegralType { - /** - * Enables matching against IntegralType for expressions: - * {{{ - * case Cast(child @ IntegralType(), StringType) => - * ... - * }}} - */ - def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType] -} - - -private[sql] abstract class IntegralType extends NumericType { - private[sql] val integral: Integral[InternalType] -} - - -private[sql] object FractionalType { - /** - * Enables matching against FractionalType for expressions: - * {{{ - * case Cast(child @ FractionalType(), StringType) => - * ... - * }}} - */ - def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[FractionalType] -} - -private[sql] abstract class FractionalType extends NumericType { - private[sql] val fractional: Fractional[InternalType] - private[sql] val asIntegral: Integral[InternalType] + override def defaultConcreteType: DataType = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala index 03f0644bc7..1d73e40ffc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala @@ -26,10 +26,11 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock /** * :: DeveloperApi :: - * The data type representing `java.sql.Date` values. + * A date type, supporting "0001-01-01" through "9999-12-31". + * * Please use the singleton [[DataTypes.DateType]]. * - * @group dataType + * Internally, this is represented as the number of days from epoch (1970-01-01 00:00:00 UTC). */ @DeveloperApi class DateType private() extends AtomicType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 18cdfa7238..06373a095b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -39,8 +39,6 @@ case class PrecisionInfo(precision: Int, scale: Int) { * A Decimal that might have fixed precision and scale, or unlimited values for these. 
* * Please use [[DataTypes.createDecimalType()]] to create a specific instance. - * - * @group dataType */ @DeveloperApi case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType { @@ -84,7 +82,10 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT /** Extra factory methods and pattern matchers for Decimals */ -object DecimalType { +object DecimalType extends AbstractDataType { + + private[sql] override def defaultConcreteType: DataType = Unlimited + val Unlimited: DecimalType = DecimalType(None) private[sql] object Fixed { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala index 6676662321..986c2ab055 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock /** * :: DeveloperApi :: * The data type representing `Double` values. Please use the singleton [[DataTypes.DoubleType]]. - * - * @group dataType */ @DeveloperApi class DoubleType private() extends FractionalType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala index 1d5a2f4f6f..9bd48ece83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock /** * :: DeveloperApi :: * The data type representing `Float` values. Please use the singleton [[DataTypes.FloatType]]. - * - * @group dataType */ @DeveloperApi class FloatType private() extends FractionalType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala index 74e464c082..a2c6e19b05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock /** * :: DeveloperApi :: * The data type representing `Int` values. Please use the singleton [[DataTypes.IntegerType]]. - * - * @group dataType */ @DeveloperApi class IntegerType private() extends IntegralType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala index 390675782e..2b3adf6ade 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala @@ -26,8 +26,6 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock /** * :: DeveloperApi :: * The data type representing `Long` values. Please use the singleton [[DataTypes.LongType]]. 
- * - * @group dataType */ @DeveloperApi class LongType private() extends IntegralType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index cfdf493074..69c2119e23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -30,8 +30,6 @@ import org.json4s.JsonDSL._ * @param keyType The data type of map keys. * @param valueType The data type of map values. * @param valueContainsNull Indicates if map values have `null` values. - * - * @group dataType */ case class MapType( keyType: DataType, @@ -69,7 +67,10 @@ case class MapType( } -object MapType { +object MapType extends AbstractDataType { + + private[sql] override def defaultConcreteType: DataType = apply(NullType, NullType) + /** * Construct a [[MapType]] object with the given key type and value type. * The `valueContainsNull` is true. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala index b64b07431f..aa84115c2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala @@ -23,8 +23,6 @@ import org.apache.spark.annotation.DeveloperApi /** * :: DeveloperApi :: * The data type representing `NULL` values. Please use the singleton [[DataTypes.NullType]]. - * - * @group dataType */ @DeveloperApi class NullType private() extends DataType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala index 73e9ec780b..a13119e659 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala @@ -26,8 +26,6 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock /** * :: DeveloperApi :: * The data type representing `Short` values. Please use the singleton [[DataTypes.ShortType]]. - * - * @group dataType */ @DeveloperApi class ShortType private() extends IntegralType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala index 1e9476ad06..a7627a2de1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -27,8 +27,6 @@ import org.apache.spark.unsafe.types.UTF8String /** * :: DeveloperApi :: * The data type representing `String` values. Please use the singleton [[DataTypes.StringType]]. 
- * - * @group dataType */ @DeveloperApi class StringType private() extends AtomicType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 2db0a359e9..6fedeabf23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -87,8 +87,6 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} * val row = Row(Row(1, 2, true)) * // row: Row = [[1,2,true]] * }}} - * - * @group dataType */ @DeveloperApi case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala index a558641fcf..de4b511edc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala @@ -28,8 +28,6 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock * :: DeveloperApi :: * The data type representing `java.sql.Timestamp` values. * Please use the singleton [[DataTypes.TimestampType]]. - * - * @group dataType */ @DeveloperApi class TimestampType private() extends AtomicType { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index eae3666595..498fd86a06 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -26,6 +26,31 @@ import org.apache.spark.sql.types._ class HiveTypeCoercionSuite extends PlanTest { + test("implicit type cast") { + def shouldCast(from: DataType, to: AbstractDataType): Unit = { + val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to) + assert(got.dataType === to.defaultConcreteType) + } + + // TODO: write the entire implicit cast table out for test cases. + shouldCast(ByteType, IntegerType) + shouldCast(IntegerType, IntegerType) + shouldCast(IntegerType, LongType) + shouldCast(IntegerType, DecimalType.Unlimited) + shouldCast(LongType, IntegerType) + shouldCast(LongType, DecimalType.Unlimited) + + shouldCast(DateType, TimestampType) + shouldCast(TimestampType, DateType) + + shouldCast(StringType, IntegerType) + shouldCast(StringType, DateType) + shouldCast(StringType, TimestampType) + shouldCast(IntegerType, StringType) + shouldCast(DateType, StringType) + shouldCast(TimestampType, StringType) + } + test("tightest common bound for types") { def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) { var found = HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2) -- cgit v1.2.3
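The behavior of the new HiveTypeCoercion.ImplicitTypeCasts.implicitCast helper introduced by this patch can be illustrated with a small standalone sketch (not part of the patch itself). It assumes the snippet is compiled on the catalyst test classpath and placed in package org.apache.spark.sql.catalyst.analysis so that the private[sql] NumericType object is visible; the ImplicitCastDemo object name is made up for illustration.

{{{
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types._

object ImplicitCastDemo {
  def main(args: Array[String]): Unit = {
    // A string child offered where BinaryType is expected (e.g. Md5.inputTypes)
    // is wrapped in a Cast to BinaryType.
    val str = Literal.create(null, StringType)
    val toBinary = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(str, BinaryType)
    assert(toBinary.isInstanceOf[Cast] && toBinary.dataType == BinaryType)

    // A string offered where the abstract NumericType is expected is cast to
    // DoubleType, the fallback for strings used as numbers.
    val toNumeric = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(str, NumericType)
    assert(toNumeric.dataType == DoubleType)

    // A null literal adopts the expected concrete type, e.g. IntegerType.
    val nul = Literal.create(null, NullType)
    val toInt = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(nul, IntegerType)
    assert(toInt.dataType == IntegerType)

    // When no rule applies (Boolean -> Binary), the expression is returned
    // unchanged; the mismatch is left for later type checking.
    val bool = Literal.create(null, BooleanType)
    assert(HiveTypeCoercion.ImplicitTypeCasts.implicitCast(bool, BinaryType) eq bool)
  }
}
}}}

Note that combinations without a matching case are returned unchanged rather than rejected by this rule; reporting the mismatch is deferred to later type checking (still a TODO in ExpectsInputTypes.checkInputDataTypes), which keeps the coercion rule purely additive.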