aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala2
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala10
2 files changed, 11 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index c452765af2..70fff51956 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -416,7 +416,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
}
private[this] def cast(from: DataType, to: DataType): Any => Any = to match {
- case dt if dt == child.dataType => identity[Any]
+ case dt if dt == from => identity[Any]
case StringType => castToString(from)
case BinaryType => castToBinary(from)
case DateType => castToDate(from)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 5ae0527a9c..5c35baacef 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -727,6 +727,16 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
+ test("cast struct with a timestamp field") {
+ val originalSchema = new StructType().add("tsField", TimestampType, nullable = false)
+ // nine out of ten times I'm casting a struct, it's to normalize its fields nullability
+ val targetSchema = new StructType().add("tsField", TimestampType, nullable = true)
+
+ val inp = Literal.create(InternalRow(0L), originalSchema)
+ val expected = InternalRow(0L)
+ checkEvaluation(cast(inp, targetSchema), expected)
+ }
+
test("complex casting") {
val complex = Literal.create(
Row(