From 334480362b3a133c2fb1e9af898930fe76d7a163 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 11 Dec 2014 22:45:25 -0800 Subject: [SPARK-4293][SQL] Make Cast be able to handle complex types. Inserting data of type including `ArrayType.containsNull == false` or `MapType.valueContainsNull == false` or `StructType.fields.exists(_.nullable == false)` into Hive table will fail because `Cast` inserted by `HiveMetastoreCatalog.PreInsertionCasts` rule of `Analyzer` can't handle these types correctly. Complex type cast rule proposal: - Cast for non-complex types should be able to cast the same as before. - Cast for `ArrayType` can evaluate if - Element type can cast - Nullability rule doesn't break - Cast for `MapType` can evaluate if - Key type can cast - Nullability for casted key type is `false` - Value type can cast - Nullability rule for value type doesn't break - Cast for `StructType` can evaluate if - The field size is the same - Each field can cast - Nullability rule for each field doesn't break - The nested structure should be the same. Nullability rule: - If the casted type is `nullable == true`, the target nullability should be `true` Author: Takuya UESHIN Closes #3150 from ueshin/issues/SPARK-4293 and squashes the following commits: e935939 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-4293 ba14003 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-4293 8999868 [Takuya UESHIN] Fix a test title. f677c30 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-4293 287f410 [Takuya UESHIN] Add tests to insert data of types ArrayType / MapType / StructType with nullability is false into Hive table. 4f71bb8 [Takuya UESHIN] Make Cast be able to handle complex types. 
--- .../spark/sql/catalyst/expressions/Cast.scala | 161 ++++++++++---- .../expressions/ExpressionEvaluationSuite.scala | 236 +++++++++++++++++++++ 2 files changed, 353 insertions(+), 44 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index b47865f87a..4ede0b4821 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -27,9 +27,14 @@ import org.apache.spark.sql.catalyst.types.decimal.Decimal /** Cast the child expression to the target data type. */ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging { + + override lazy val resolved = childrenResolved && resolve(child.dataType, dataType) + override def foldable = child.foldable - override def nullable = (child.dataType, dataType) match { + override def nullable = forceNullable(child.dataType, dataType) || child.nullable + + private[this] def forceNullable(from: DataType, to: DataType) = (from, to) match { case (StringType, _: NumericType) => true case (StringType, TimestampType) => true case (DoubleType, TimestampType) => true @@ -41,8 +46,62 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case (DateType, BooleanType) => true case (DoubleType, _: DecimalType) => true case (FloatType, _: DecimalType) => true - case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null - case _ => child.nullable + case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null + case _ => false + } + + private[this] def resolvableNullability(from: Boolean, to: Boolean) = !from || to + + private[this] def resolve(from: DataType, to: DataType): Boolean = { + (from, to) match { + case (from, to) if from == to 
=> true + + case (NullType, _) => true + + case (_, StringType) => true + + case (StringType, BinaryType) => true + + case (StringType, BooleanType) => true + case (DateType, BooleanType) => true + case (TimestampType, BooleanType) => true + case (_: NumericType, BooleanType) => true + + case (StringType, TimestampType) => true + case (BooleanType, TimestampType) => true + case (DateType, TimestampType) => true + case (_: NumericType, TimestampType) => true + + case (_, DateType) => true + + case (StringType, _: NumericType) => true + case (BooleanType, _: NumericType) => true + case (DateType, _: NumericType) => true + case (TimestampType, _: NumericType) => true + case (_: NumericType, _: NumericType) => true + + case (ArrayType(from, fn), ArrayType(to, tn)) => + resolve(from, to) && + resolvableNullability(fn || forceNullable(from, to), tn) + + case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => + resolve(fromKey, toKey) && + (!forceNullable(fromKey, toKey)) && + resolve(fromValue, toValue) && + resolvableNullability(fn || forceNullable(fromValue, toValue), tn) + + case (StructType(fromFields), StructType(toFields)) => + fromFields.size == toFields.size && + fromFields.zip(toFields).forall { + case (fromField, toField) => + resolve(fromField.dataType, toField.dataType) && + resolvableNullability( + fromField.nullable || forceNullable(fromField.dataType, toField.dataType), + toField.nullable) + } + + case _ => false + } } override def toString = s"CAST($child, $dataType)" @@ -53,7 +112,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T]) // UDFToString - private[this] def castToString: Any => Any = child.dataType match { + private[this] def castToString(from: DataType): Any => Any = from match { case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8")) case DateType => buildCast[Date](_, dateToString) case 
TimestampType => buildCast[Timestamp](_, timestampToString) @@ -61,12 +120,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } // BinaryConverter - private[this] def castToBinary: Any => Any = child.dataType match { + private[this] def castToBinary(from: DataType): Any => Any = from match { case StringType => buildCast[String](_, _.getBytes("UTF-8")) } // UDFToBoolean - private[this] def castToBoolean: Any => Any = child.dataType match { + private[this] def castToBoolean(from: DataType): Any => Any = from match { case StringType => buildCast[String](_, _.length() != 0) case TimestampType => @@ -91,7 +150,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } // TimestampConverter - private[this] def castToTimestamp: Any => Any = child.dataType match { + private[this] def castToTimestamp(from: DataType): Any => Any = from match { case StringType => buildCast[String](_, s => { // Throw away extra if more than 9 decimal places @@ -133,7 +192,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w }) } - private[this] def decimalToTimestamp(d: Decimal) = { + private[this] def decimalToTimestamp(d: Decimal) = { val seconds = Math.floor(d.toDouble).toLong val bd = (d.toBigDecimal - seconds) * 1000000000 val nanos = bd.intValue() @@ -172,11 +231,10 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } // DateConverter - private[this] def castToDate: Any => Any = child.dataType match { + private[this] def castToDate(from: DataType): Any => Any = from match { case StringType => buildCast[String](_, s => - try Date.valueOf(s) catch { case _: java.lang.IllegalArgumentException => null } - ) + try Date.valueOf(s) catch { case _: java.lang.IllegalArgumentException => null }) case TimestampType => // throw valid precision more than seconds, according to Hive. // Timestamp.nanos is in 0 to 999,999,999, no more than a second. 
@@ -199,7 +257,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } // LongConverter - private[this] def castToLong: Any => Any = child.dataType match { + private[this] def castToLong(from: DataType): Any => Any = from match { case StringType => buildCast[String](_, s => try s.toLong catch { case _: NumberFormatException => null @@ -210,14 +268,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t)) - case DecimalType() => - buildCast[Decimal](_, _.toLong) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b) } // IntConverter - private[this] def castToInt: Any => Any = child.dataType match { + private[this] def castToInt(from: DataType): Any => Any = from match { case StringType => buildCast[String](_, s => try s.toInt catch { case _: NumberFormatException => null @@ -228,14 +284,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toInt) - case DecimalType() => - buildCast[Decimal](_, _.toInt) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b) } // ShortConverter - private[this] def castToShort: Any => Any = child.dataType match { + private[this] def castToShort(from: DataType): Any => Any = from match { case StringType => buildCast[String](_, s => try s.toShort catch { case _: NumberFormatException => null @@ -246,14 +300,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toShort) - case DecimalType() => - buildCast[Decimal](_, _.toShort) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort } // ByteConverter - private[this] def 
castToByte: Any => Any = child.dataType match { + private[this] def castToByte(from: DataType): Any => Any = from match { case StringType => buildCast[String](_, s => try s.toByte catch { case _: NumberFormatException => null @@ -264,8 +316,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toByte) - case DecimalType() => - buildCast[Decimal](_, _.toByte) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte } @@ -285,7 +335,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } } - private[this] def castToDecimal(target: DecimalType): Any => Any = child.dataType match { + private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { case StringType => buildCast[String](_, s => try changePrecision(Decimal(s.toDouble), target) catch { case _: NumberFormatException => null @@ -301,7 +351,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w b => changePrecision(b.asInstanceOf[Decimal].clone(), target) case LongType => b => changePrecision(Decimal(b.asInstanceOf[Long]), target) - case x: NumericType => // All other numeric types can be represented precisely as Doubles + case x: NumericType => // All other numeric types can be represented precisely as Doubles b => try { changePrecision(Decimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)), target) } catch { @@ -310,7 +360,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } // DoubleConverter - private[this] def castToDouble: Any => Any = child.dataType match { + private[this] def castToDouble(from: DataType): Any => Any = from match { case StringType => buildCast[String](_, s => try s.toDouble catch { case _: NumberFormatException => null @@ -321,14 +371,12 @@ case class Cast(child: Expression, dataType: DataType) 
extends UnaryExpression w buildCast[Date](_, d => dateToDouble(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToDouble(t)) - case DecimalType() => - buildCast[Decimal](_, _.toDouble) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b) } // FloatConverter - private[this] def castToFloat: Any => Any = child.dataType match { + private[this] def castToFloat(from: DataType): Any => Any = from match { case StringType => buildCast[String](_, s => try s.toFloat catch { case _: NumberFormatException => null @@ -339,28 +387,53 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Date](_, d => dateToDouble(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToDouble(t).toFloat) - case DecimalType() => - buildCast[Decimal](_, _.toFloat) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b) } - private[this] lazy val cast: Any => Any = dataType match { + private[this] def castArray(from: ArrayType, to: ArrayType): Any => Any = { + val elementCast = cast(from.elementType, to.elementType) + buildCast[Seq[Any]](_, _.map(v => if (v == null) null else elementCast(v))) + } + + private[this] def castMap(from: MapType, to: MapType): Any => Any = { + val keyCast = cast(from.keyType, to.keyType) + val valueCast = cast(from.valueType, to.valueType) + buildCast[Map[Any, Any]](_, _.map { + case (key, value) => (keyCast(key), if (value == null) null else valueCast(value)) + }) + } + + private[this] def castStruct(from: StructType, to: StructType): Any => Any = { + val casts = from.fields.zip(to.fields).map { + case (fromField, toField) => cast(fromField.dataType, toField.dataType) + } + buildCast[Row](_, row => Row(row.zip(casts).map { + case (v, cast) => if (v == null) null else cast(v) + }: _*)) + } + + private[this] def cast(from: DataType, to: DataType): Any => Any = to match { - case dt if dt == child.dataType => identity[Any] + case dt if dt == from => identity[Any] // compare with this level's source type, not the root child's, so nested identical types short-circuit - case StringType => castToString - case
BinaryType => castToBinary - case DateType => castToDate - case decimal: DecimalType => castToDecimal(decimal) - case TimestampType => castToTimestamp - case BooleanType => castToBoolean - case ByteType => castToByte - case ShortType => castToShort - case IntegerType => castToInt - case FloatType => castToFloat - case LongType => castToLong - case DoubleType => castToDouble + case StringType => castToString(from) + case BinaryType => castToBinary(from) + case DateType => castToDate(from) + case decimal: DecimalType => castToDecimal(from, decimal) + case TimestampType => castToTimestamp(from) + case BooleanType => castToBoolean(from) + case ByteType => castToByte(from) + case ShortType => castToShort(from) + case IntegerType => castToInt(from) + case FloatType => castToFloat(from) + case LongType => castToLong(from) + case DoubleType => castToDouble(from) + case array: ArrayType => castArray(from.asInstanceOf[ArrayType], array) + case map: MapType => castMap(from.asInstanceOf[MapType], map) + case struct: StructType => castStruct(from.asInstanceOf[StructType], struct) } + private[this] lazy val cast: Any => Any = cast(child.dataType, dataType) + override def eval(input: Row): Any = { val evaluated = child.eval(input) if (evaluated == null) null else cast(evaluated) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index cd2f67f448..b030483223 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -487,6 +487,242 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Cast(Literal(1.0f / 0.0f), TimestampType), null) } + test("array casting") { + val array = Literal(Seq("123", "abc", "", null), ArrayType(StringType,
containsNull = true)) + val array_notNull = Literal(Seq("123", "abc", ""), ArrayType(StringType, containsNull = false)) + + { + val cast = Cast(array, ArrayType(IntegerType, containsNull = true)) + assert(cast.resolved === true) + checkEvaluation(cast, Seq(123, null, null, null)) + } + { + val cast = Cast(array, ArrayType(IntegerType, containsNull = false)) + assert(cast.resolved === false) + } + { + val cast = Cast(array, ArrayType(BooleanType, containsNull = true)) + assert(cast.resolved === true) + checkEvaluation(cast, Seq(true, true, false, null)) + } + { + val cast = Cast(array, ArrayType(BooleanType, containsNull = false)) + assert(cast.resolved === false) + } + + { + val cast = Cast(array_notNull, ArrayType(IntegerType, containsNull = true)) + assert(cast.resolved === true) + checkEvaluation(cast, Seq(123, null, null)) + } + { + val cast = Cast(array_notNull, ArrayType(IntegerType, containsNull = false)) + assert(cast.resolved === false) + } + { + val cast = Cast(array_notNull, ArrayType(BooleanType, containsNull = true)) + assert(cast.resolved === true) + checkEvaluation(cast, Seq(true, true, false)) + } + { + val cast = Cast(array_notNull, ArrayType(BooleanType, containsNull = false)) + assert(cast.resolved === true) + checkEvaluation(cast, Seq(true, true, false)) + } + + { + val cast = Cast(array, IntegerType) + assert(cast.resolved === false) + } + } + + test("map casting") { + val map = Literal( + Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null), + MapType(StringType, StringType, valueContainsNull = true)) + val map_notNull = Literal( + Map("a" -> "123", "b" -> "abc", "c" -> ""), + MapType(StringType, StringType, valueContainsNull = false)) + + { + val cast = Cast(map, MapType(StringType, IntegerType, valueContainsNull = true)) + assert(cast.resolved === true) + checkEvaluation(cast, Map("a" -> 123, "b" -> null, "c" -> null, "d" -> null)) + } + { + val cast = Cast(map, MapType(StringType, IntegerType, valueContainsNull = false)) + 
assert(cast.resolved === false) + } + { + val cast = Cast(map, MapType(StringType, BooleanType, valueContainsNull = true)) + assert(cast.resolved === true) + checkEvaluation(cast, Map("a" -> true, "b" -> true, "c" -> false, "d" -> null)) + } + { + val cast = Cast(map, MapType(StringType, BooleanType, valueContainsNull = false)) + assert(cast.resolved === false) + } + { + val cast = Cast(map, MapType(IntegerType, StringType, valueContainsNull = true)) + assert(cast.resolved === false) + } + + { + val cast = Cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = true)) + assert(cast.resolved === true) + checkEvaluation(cast, Map("a" -> 123, "b" -> null, "c" -> null)) + } + { + val cast = Cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = false)) + assert(cast.resolved === false) + } + { + val cast = Cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = true)) + assert(cast.resolved === true) + checkEvaluation(cast, Map("a" -> true, "b" -> true, "c" -> false)) + } + { + val cast = Cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false)) + assert(cast.resolved === true) + checkEvaluation(cast, Map("a" -> true, "b" -> true, "c" -> false)) + } + { + val cast = Cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true)) + assert(cast.resolved === false) + } + + { + val cast = Cast(map, IntegerType) + assert(cast.resolved === false) + } + } + + test("struct casting") { + val struct = Literal( + Row("123", "abc", "", null), + StructType(Seq( + StructField("a", StringType, nullable = true), + StructField("b", StringType, nullable = true), + StructField("c", StringType, nullable = true), + StructField("d", StringType, nullable = true)))) + val struct_notNull = Literal( + Row("123", "abc", ""), + StructType(Seq( + StructField("a", StringType, nullable = false), + StructField("b", StringType, nullable = false), + StructField("c", StringType, nullable = false)))) + + { + val cast = 
Cast(struct, StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = true), + StructField("d", IntegerType, nullable = true)))) + assert(cast.resolved === true) + checkEvaluation(cast, Row(123, null, null, null)) + } + { + val cast = Cast(struct, StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = true)))) + assert(cast.resolved === false) + } + { + val cast = Cast(struct, StructType(Seq( + StructField("a", BooleanType, nullable = true), + StructField("b", BooleanType, nullable = true), + StructField("c", BooleanType, nullable = true), + StructField("d", BooleanType, nullable = true)))) + assert(cast.resolved === true) + checkEvaluation(cast, Row(true, true, false, null)) + } + { + val cast = Cast(struct, StructType(Seq( + StructField("a", BooleanType, nullable = true), + StructField("b", BooleanType, nullable = true), + StructField("c", BooleanType, nullable = false), + StructField("d", BooleanType, nullable = true)))) + assert(cast.resolved === false) + } + + { + val cast = Cast(struct_notNull, StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = true)))) + assert(cast.resolved === true) + checkEvaluation(cast, Row(123, null, null)) + } + { + val cast = Cast(struct_notNull, StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false)))) + assert(cast.resolved === false) + } + { + val cast = Cast(struct_notNull, StructType(Seq( + StructField("a", BooleanType, nullable = true), + StructField("b", BooleanType, nullable = true), + StructField("c", BooleanType, nullable = true)))) + assert(cast.resolved 
=== true) + checkEvaluation(cast, Row(true, true, false)) + } + { + val cast = Cast(struct_notNull, StructType(Seq( + StructField("a", BooleanType, nullable = true), + StructField("b", BooleanType, nullable = true), + StructField("c", BooleanType, nullable = false)))) + assert(cast.resolved === true) + checkEvaluation(cast, Row(true, true, false)) + } + + { + val cast = Cast(struct, StructType(Seq( + StructField("a", StringType, nullable = true), + StructField("b", StringType, nullable = true), + StructField("c", StringType, nullable = true)))) + assert(cast.resolved === false) + } + { + val cast = Cast(struct, IntegerType) + assert(cast.resolved === false) + } + } + + test("complex casting") { + val complex = Literal( + Row( + Seq("123", "abc", ""), + Map("a" -> "123", "b" -> "abc", "c" -> ""), + Row(0)), + StructType(Seq( + StructField("a", + ArrayType(StringType, containsNull = false), nullable = true), + StructField("m", + MapType(StringType, StringType, valueContainsNull = false), nullable = true), + StructField("s", + StructType(Seq( + StructField("i", IntegerType, nullable = true))))))) + + val cast = Cast(complex, StructType(Seq( + StructField("a", + ArrayType(IntegerType, containsNull = true), nullable = true), + StructField("m", + MapType(StringType, BooleanType, valueContainsNull = false), nullable = true), + StructField("s", + StructType(Seq( + StructField("l", LongType, nullable = true))))))) + + assert(cast.resolved === true) + checkEvaluation(cast, Row( + Seq(123, null, null), + Map("a" -> true, "b" -> true, "c" -> false), + Row(0L))) + } + test("null checking") { val row = new GenericRow(Array[Any]("^Ba*n", null, true, null)) val c1 = 'a.string.at(0) -- cgit v1.2.3