aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorTakuya UESHIN <ueshin@happy-camper.st>2014-12-11 22:45:25 -0800
committerMichael Armbrust <michael@databricks.com>2014-12-11 22:45:25 -0800
commit334480362b3a133c2fb1e9af898930fe76d7a163 (patch)
tree240aad85d727d7604a56560b8ec8bffd565419a3 /sql/catalyst
parentc152dde78f73d5ce3a483fd60a47e7de1f1916da (diff)
downloadspark-334480362b3a133c2fb1e9af898930fe76d7a163.tar.gz
spark-334480362b3a133c2fb1e9af898930fe76d7a163.tar.bz2
spark-334480362b3a133c2fb1e9af898930fe76d7a163.zip
[SPARK-4293][SQL] Make Cast be able to handle complex types.
Inserting data of type including `ArrayType.containsNull == false` or `MapType.valueContainsNull == false` or `StructType.fields.exists(_.nullable == false)` into Hive table will fail because `Cast` inserted by `HiveMetastoreCatalog.PreInsertionCasts` rule of `Analyzer` can't handle these types correctly. Complex type cast rule proposal: - Cast for non-complex types should be able to cast the same as before. - Cast for `ArrayType` can evaluate if - Element type can cast - Nullability rule doesn't break - Cast for `MapType` can evaluate if - Key type can cast - Nullability for casted key type is `false` - Value type can cast - Nullability rule for value type doesn't break - Cast for `StructType` can evaluate if - The field size is the same - Each field can cast - Nullability rule for each field doesn't break - The nested structure should be the same. Nullability rule: - If the casted type is `nullable == true`, the target nullability should be `true` Author: Takuya UESHIN <ueshin@happy-camper.st> Closes #3150 from ueshin/issues/SPARK-4293 and squashes the following commits: e935939 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-4293 ba14003 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-4293 8999868 [Takuya UESHIN] Fix a test title. f677c30 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-4293 287f410 [Takuya UESHIN] Add tests to insert data of types ArrayType / MapType / StructType with nullability is false into Hive table. 4f71bb8 [Takuya UESHIN] Make Cast be able to handle complex types.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala161
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala236
2 files changed, 353 insertions, 44 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index b47865f87a..4ede0b4821 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -27,9 +27,14 @@ import org.apache.spark.sql.catalyst.types.decimal.Decimal
/** Cast the child expression to the target data type. */
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging {
+
+ override lazy val resolved = childrenResolved && resolve(child.dataType, dataType)
+
override def foldable = child.foldable
- override def nullable = (child.dataType, dataType) match {
+ override def nullable = forceNullable(child.dataType, dataType) || child.nullable
+
+ private[this] def forceNullable(from: DataType, to: DataType) = (from, to) match {
case (StringType, _: NumericType) => true
case (StringType, TimestampType) => true
case (DoubleType, TimestampType) => true
@@ -41,8 +46,62 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case (DateType, BooleanType) => true
case (DoubleType, _: DecimalType) => true
case (FloatType, _: DecimalType) => true
- case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null
- case _ => child.nullable
+ case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null
+ case _ => false
+ }
+
+ private[this] def resolvableNullability(from: Boolean, to: Boolean) = !from || to
+
+ private[this] def resolve(from: DataType, to: DataType): Boolean = {
+ (from, to) match {
+ case (from, to) if from == to => true
+
+ case (NullType, _) => true
+
+ case (_, StringType) => true
+
+ case (StringType, BinaryType) => true
+
+ case (StringType, BooleanType) => true
+ case (DateType, BooleanType) => true
+ case (TimestampType, BooleanType) => true
+ case (_: NumericType, BooleanType) => true
+
+ case (StringType, TimestampType) => true
+ case (BooleanType, TimestampType) => true
+ case (DateType, TimestampType) => true
+ case (_: NumericType, TimestampType) => true
+
+ case (_, DateType) => true
+
+ case (StringType, _: NumericType) => true
+ case (BooleanType, _: NumericType) => true
+ case (DateType, _: NumericType) => true
+ case (TimestampType, _: NumericType) => true
+ case (_: NumericType, _: NumericType) => true
+
+ case (ArrayType(from, fn), ArrayType(to, tn)) =>
+ resolve(from, to) &&
+ resolvableNullability(fn || forceNullable(from, to), tn)
+
+ case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
+ resolve(fromKey, toKey) &&
+ (!forceNullable(fromKey, toKey)) &&
+ resolve(fromValue, toValue) &&
+ resolvableNullability(fn || forceNullable(fromValue, toValue), tn)
+
+ case (StructType(fromFields), StructType(toFields)) =>
+ fromFields.size == toFields.size &&
+ fromFields.zip(toFields).forall {
+ case (fromField, toField) =>
+ resolve(fromField.dataType, toField.dataType) &&
+ resolvableNullability(
+ fromField.nullable || forceNullable(fromField.dataType, toField.dataType),
+ toField.nullable)
+ }
+
+ case _ => false
+ }
}
override def toString = s"CAST($child, $dataType)"
@@ -53,7 +112,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
@inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T])
// UDFToString
- private[this] def castToString: Any => Any = child.dataType match {
+ private[this] def castToString(from: DataType): Any => Any = from match {
case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8"))
case DateType => buildCast[Date](_, dateToString)
case TimestampType => buildCast[Timestamp](_, timestampToString)
@@ -61,12 +120,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
}
// BinaryConverter
- private[this] def castToBinary: Any => Any = child.dataType match {
+ private[this] def castToBinary(from: DataType): Any => Any = from match {
case StringType => buildCast[String](_, _.getBytes("UTF-8"))
}
// UDFToBoolean
- private[this] def castToBoolean: Any => Any = child.dataType match {
+ private[this] def castToBoolean(from: DataType): Any => Any = from match {
case StringType =>
buildCast[String](_, _.length() != 0)
case TimestampType =>
@@ -91,7 +150,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
}
// TimestampConverter
- private[this] def castToTimestamp: Any => Any = child.dataType match {
+ private[this] def castToTimestamp(from: DataType): Any => Any = from match {
case StringType =>
buildCast[String](_, s => {
// Throw away extra if more than 9 decimal places
@@ -133,7 +192,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
})
}
- private[this] def decimalToTimestamp(d: Decimal) = {
+ private[this] def decimalToTimestamp(d: Decimal) = {
val seconds = Math.floor(d.toDouble).toLong
val bd = (d.toBigDecimal - seconds) * 1000000000
val nanos = bd.intValue()
@@ -172,11 +231,10 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
}
// DateConverter
- private[this] def castToDate: Any => Any = child.dataType match {
+ private[this] def castToDate(from: DataType): Any => Any = from match {
case StringType =>
buildCast[String](_, s =>
- try Date.valueOf(s) catch { case _: java.lang.IllegalArgumentException => null }
- )
+ try Date.valueOf(s) catch { case _: java.lang.IllegalArgumentException => null })
case TimestampType =>
// throw valid precision more than seconds, according to Hive.
// Timestamp.nanos is in 0 to 999,999,999, no more than a second.
@@ -199,7 +257,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
}
// LongConverter
- private[this] def castToLong: Any => Any = child.dataType match {
+ private[this] def castToLong(from: DataType): Any => Any = from match {
case StringType =>
buildCast[String](_, s => try s.toLong catch {
case _: NumberFormatException => null
@@ -210,14 +268,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
buildCast[Date](_, d => dateToLong(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToLong(t))
- case DecimalType() =>
- buildCast[Decimal](_, _.toLong)
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
}
// IntConverter
- private[this] def castToInt: Any => Any = child.dataType match {
+ private[this] def castToInt(from: DataType): Any => Any = from match {
case StringType =>
buildCast[String](_, s => try s.toInt catch {
case _: NumberFormatException => null
@@ -228,14 +284,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
buildCast[Date](_, d => dateToLong(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToLong(t).toInt)
- case DecimalType() =>
- buildCast[Decimal](_, _.toInt)
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
}
// ShortConverter
- private[this] def castToShort: Any => Any = child.dataType match {
+ private[this] def castToShort(from: DataType): Any => Any = from match {
case StringType =>
buildCast[String](_, s => try s.toShort catch {
case _: NumberFormatException => null
@@ -246,14 +300,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
buildCast[Date](_, d => dateToLong(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToLong(t).toShort)
- case DecimalType() =>
- buildCast[Decimal](_, _.toShort)
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
}
// ByteConverter
- private[this] def castToByte: Any => Any = child.dataType match {
+ private[this] def castToByte(from: DataType): Any => Any = from match {
case StringType =>
buildCast[String](_, s => try s.toByte catch {
case _: NumberFormatException => null
@@ -264,8 +316,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
buildCast[Date](_, d => dateToLong(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToLong(t).toByte)
- case DecimalType() =>
- buildCast[Decimal](_, _.toByte)
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
}
@@ -285,7 +335,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
}
}
- private[this] def castToDecimal(target: DecimalType): Any => Any = child.dataType match {
+ private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
case StringType =>
buildCast[String](_, s => try changePrecision(Decimal(s.toDouble), target) catch {
case _: NumberFormatException => null
@@ -301,7 +351,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
b => changePrecision(b.asInstanceOf[Decimal].clone(), target)
case LongType =>
b => changePrecision(Decimal(b.asInstanceOf[Long]), target)
- case x: NumericType => // All other numeric types can be represented precisely as Doubles
+ case x: NumericType => // All other numeric types can be represented precisely as Doubles
b => try {
changePrecision(Decimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)), target)
} catch {
@@ -310,7 +360,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
}
// DoubleConverter
- private[this] def castToDouble: Any => Any = child.dataType match {
+ private[this] def castToDouble(from: DataType): Any => Any = from match {
case StringType =>
buildCast[String](_, s => try s.toDouble catch {
case _: NumberFormatException => null
@@ -321,14 +371,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
buildCast[Date](_, d => dateToDouble(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToDouble(t))
- case DecimalType() =>
- buildCast[Decimal](_, _.toDouble)
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
}
// FloatConverter
- private[this] def castToFloat: Any => Any = child.dataType match {
+ private[this] def castToFloat(from: DataType): Any => Any = from match {
case StringType =>
buildCast[String](_, s => try s.toFloat catch {
case _: NumberFormatException => null
@@ -339,28 +387,53 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
buildCast[Date](_, d => dateToDouble(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToDouble(t).toFloat)
- case DecimalType() =>
- buildCast[Decimal](_, _.toFloat)
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b)
}
- private[this] lazy val cast: Any => Any = dataType match {
+ private[this] def castArray(from: ArrayType, to: ArrayType): Any => Any = {
+ val elementCast = cast(from.elementType, to.elementType)
+ buildCast[Seq[Any]](_, _.map(v => if (v == null) null else elementCast(v)))
+ }
+
+ private[this] def castMap(from: MapType, to: MapType): Any => Any = {
+ val keyCast = cast(from.keyType, to.keyType)
+ val valueCast = cast(from.valueType, to.valueType)
+ buildCast[Map[Any, Any]](_, _.map {
+ case (key, value) => (keyCast(key), if (value == null) null else valueCast(value))
+ })
+ }
+
+ private[this] def castStruct(from: StructType, to: StructType): Any => Any = {
+ val casts = from.fields.zip(to.fields).map {
+ case (fromField, toField) => cast(fromField.dataType, toField.dataType)
+ }
+ buildCast[Row](_, row => Row(row.zip(casts).map {
+ case (v, cast) => if (v == null) null else cast(v)
+ }: _*))
+ }
+
+ private[this] def cast(from: DataType, to: DataType): Any => Any = to match {
case dt if dt == child.dataType => identity[Any]
- case StringType => castToString
- case BinaryType => castToBinary
- case DateType => castToDate
- case decimal: DecimalType => castToDecimal(decimal)
- case TimestampType => castToTimestamp
- case BooleanType => castToBoolean
- case ByteType => castToByte
- case ShortType => castToShort
- case IntegerType => castToInt
- case FloatType => castToFloat
- case LongType => castToLong
- case DoubleType => castToDouble
+ case StringType => castToString(from)
+ case BinaryType => castToBinary(from)
+ case DateType => castToDate(from)
+ case decimal: DecimalType => castToDecimal(from, decimal)
+ case TimestampType => castToTimestamp(from)
+ case BooleanType => castToBoolean(from)
+ case ByteType => castToByte(from)
+ case ShortType => castToShort(from)
+ case IntegerType => castToInt(from)
+ case FloatType => castToFloat(from)
+ case LongType => castToLong(from)
+ case DoubleType => castToDouble(from)
+ case array: ArrayType => castArray(from.asInstanceOf[ArrayType], array)
+ case map: MapType => castMap(from.asInstanceOf[MapType], map)
+ case struct: StructType => castStruct(from.asInstanceOf[StructType], struct)
}
+ private[this] lazy val cast: Any => Any = cast(child.dataType, dataType)
+
override def eval(input: Row): Any = {
val evaluated = child.eval(input)
if (evaluated == null) null else cast(evaluated)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index cd2f67f448..b030483223 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -487,6 +487,242 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(Cast(Literal(1.0f / 0.0f), TimestampType), null)
}
+ test("array casting") {
+ val array = Literal(Seq("123", "abc", "", null), ArrayType(StringType, containsNull = true))
+ val array_notNull = Literal(Seq("123", "abc", ""), ArrayType(StringType, containsNull = false))
+
+ {
+ val cast = Cast(array, ArrayType(IntegerType, containsNull = true))
+ assert(cast.resolved === true)
+ checkEvaluation(cast, Seq(123, null, null, null))
+ }
+ {
+ val cast = Cast(array, ArrayType(IntegerType, containsNull = false))
+ assert(cast.resolved === false)
+ }
+ {
+ val cast = Cast(array, ArrayType(BooleanType, containsNull = true))
+ assert(cast.resolved === true)
+ checkEvaluation(cast, Seq(true, true, false, null))
+ }
+ {
+ val cast = Cast(array, ArrayType(BooleanType, containsNull = false))
+ assert(cast.resolved === false)
+ }
+
+ {
+ val cast = Cast(array_notNull, ArrayType(IntegerType, containsNull = true))
+ assert(cast.resolved === true)
+ checkEvaluation(cast, Seq(123, null, null))
+ }
+ {
+ val cast = Cast(array_notNull, ArrayType(IntegerType, containsNull = false))
+ assert(cast.resolved === false)
+ }
+ {
+ val cast = Cast(array_notNull, ArrayType(BooleanType, containsNull = true))
+ assert(cast.resolved === true)
+ checkEvaluation(cast, Seq(true, true, false))
+ }
+ {
+ val cast = Cast(array_notNull, ArrayType(BooleanType, containsNull = false))
+ assert(cast.resolved === true)
+ checkEvaluation(cast, Seq(true, true, false))
+ }
+
+ {
+ val cast = Cast(array, IntegerType)
+ assert(cast.resolved === false)
+ }
+ }
+
+ test("map casting") {
+ val map = Literal(
+ Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null),
+ MapType(StringType, StringType, valueContainsNull = true))
+ val map_notNull = Literal(
+ Map("a" -> "123", "b" -> "abc", "c" -> ""),
+ MapType(StringType, StringType, valueContainsNull = false))
+
+ {
+ val cast = Cast(map, MapType(StringType, IntegerType, valueContainsNull = true))
+ assert(cast.resolved === true)
+ checkEvaluation(cast, Map("a" -> 123, "b" -> null, "c" -> null, "d" -> null))
+ }
+ {
+ val cast = Cast(map, MapType(StringType, IntegerType, valueContainsNull = false))
+ assert(cast.resolved === false)
+ }
+ {
+ val cast = Cast(map, MapType(StringType, BooleanType, valueContainsNull = true))
+ assert(cast.resolved === true)
+ checkEvaluation(cast, Map("a" -> true, "b" -> true, "c" -> false, "d" -> null))
+ }
+ {
+ val cast = Cast(map, MapType(StringType, BooleanType, valueContainsNull = false))
+ assert(cast.resolved === false)
+ }
+ {
+ val cast = Cast(map, MapType(IntegerType, StringType, valueContainsNull = true))
+ assert(cast.resolved === false)
+ }
+
+ {
+ val cast = Cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = true))
+ assert(cast.resolved === true)
+ checkEvaluation(cast, Map("a" -> 123, "b" -> null, "c" -> null))
+ }
+ {
+ val cast = Cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = false))
+ assert(cast.resolved === false)
+ }
+ {
+ val cast = Cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = true))
+ assert(cast.resolved === true)
+ checkEvaluation(cast, Map("a" -> true, "b" -> true, "c" -> false))
+ }
+ {
+ val cast = Cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false))
+ assert(cast.resolved === true)
+ checkEvaluation(cast, Map("a" -> true, "b" -> true, "c" -> false))
+ }
+ {
+ val cast = Cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true))
+ assert(cast.resolved === false)
+ }
+
+ {
+ val cast = Cast(map, IntegerType)
+ assert(cast.resolved === false)
+ }
+ }
+
+ test("struct casting") {
+ val struct = Literal(
+ Row("123", "abc", "", null),
+ StructType(Seq(
+ StructField("a", StringType, nullable = true),
+ StructField("b", StringType, nullable = true),
+ StructField("c", StringType, nullable = true),
+ StructField("d", StringType, nullable = true))))
+ val struct_notNull = Literal(
+ Row("123", "abc", ""),
+ StructType(Seq(
+ StructField("a", StringType, nullable = false),
+ StructField("b", StringType, nullable = false),
+ StructField("c", StringType, nullable = false))))
+
+ {
+ val cast = Cast(struct, StructType(Seq(
+ StructField("a", IntegerType, nullable = true),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = true),
+ StructField("d", IntegerType, nullable = true))))
+ assert(cast.resolved === true)
+ checkEvaluation(cast, Row(123, null, null, null))
+ }
+ {
+ val cast = Cast(struct, StructType(Seq(
+ StructField("a", IntegerType, nullable = true),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("d", IntegerType, nullable = true))))
+ assert(cast.resolved === false)
+ }
+ {
+ val cast = Cast(struct, StructType(Seq(
+ StructField("a", BooleanType, nullable = true),
+ StructField("b", BooleanType, nullable = true),
+ StructField("c", BooleanType, nullable = true),
+ StructField("d", BooleanType, nullable = true))))
+ assert(cast.resolved === true)
+ checkEvaluation(cast, Row(true, true, false, null))
+ }
+ {
+ val cast = Cast(struct, StructType(Seq(
+ StructField("a", BooleanType, nullable = true),
+ StructField("b", BooleanType, nullable = true),
+ StructField("c", BooleanType, nullable = false),
+ StructField("d", BooleanType, nullable = true))))
+ assert(cast.resolved === false)
+ }
+
+ {
+ val cast = Cast(struct_notNull, StructType(Seq(
+ StructField("a", IntegerType, nullable = true),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = true))))
+ assert(cast.resolved === true)
+ checkEvaluation(cast, Row(123, null, null))
+ }
+ {
+ val cast = Cast(struct_notNull, StructType(Seq(
+ StructField("a", IntegerType, nullable = true),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false))))
+ assert(cast.resolved === false)
+ }
+ {
+ val cast = Cast(struct_notNull, StructType(Seq(
+ StructField("a", BooleanType, nullable = true),
+ StructField("b", BooleanType, nullable = true),
+ StructField("c", BooleanType, nullable = true))))
+ assert(cast.resolved === true)
+ checkEvaluation(cast, Row(true, true, false))
+ }
+ {
+ val cast = Cast(struct_notNull, StructType(Seq(
+ StructField("a", BooleanType, nullable = true),
+ StructField("b", BooleanType, nullable = true),
+ StructField("c", BooleanType, nullable = false))))
+ assert(cast.resolved === true)
+ checkEvaluation(cast, Row(true, true, false))
+ }
+
+ {
+ val cast = Cast(struct, StructType(Seq(
+ StructField("a", StringType, nullable = true),
+ StructField("b", StringType, nullable = true),
+ StructField("c", StringType, nullable = true))))
+ assert(cast.resolved === false)
+ }
+ {
+ val cast = Cast(struct, IntegerType)
+ assert(cast.resolved === false)
+ }
+ }
+
+ test("complex casting") {
+ val complex = Literal(
+ Row(
+ Seq("123", "abc", ""),
+ Map("a" -> "123", "b" -> "abc", "c" -> ""),
+ Row(0)),
+ StructType(Seq(
+ StructField("a",
+ ArrayType(StringType, containsNull = false), nullable = true),
+ StructField("m",
+ MapType(StringType, StringType, valueContainsNull = false), nullable = true),
+ StructField("s",
+ StructType(Seq(
+ StructField("i", IntegerType, nullable = true)))))))
+
+ val cast = Cast(complex, StructType(Seq(
+ StructField("a",
+ ArrayType(IntegerType, containsNull = true), nullable = true),
+ StructField("m",
+ MapType(StringType, BooleanType, valueContainsNull = false), nullable = true),
+ StructField("s",
+ StructType(Seq(
+ StructField("l", LongType, nullable = true)))))))
+
+ assert(cast.resolved === true)
+ checkEvaluation(cast, Row(
+ Seq(123, null, null),
+ Map("a" -> true, "b" -> true, "c" -> false),
+ Row(0L)))
+ }
+
test("null checking") {
val row = new GenericRow(Array[Any]("^Ba*n", null, true, null))
val c1 = 'a.string.at(0)