aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala6
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala12
2 files changed, 17 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 8b79b0cd65..40d2b42a0c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -24,7 +24,11 @@ import org.apache.spark.sql.catalyst.types._
/** Cast the child expression to the target data type. */
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
override def foldable = child.foldable
- def nullable = child.nullable
+ def nullable = (child.dataType, dataType) match {
+ case (StringType, _: NumericType) => true
+ case (StringType, TimestampType) => true
+ case _ => child.nullable
+ }
override def toString = s"CAST($child, $dataType)"
type EvaluatedType = Any
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 4ce0dff9e1..d287ad73b9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -245,6 +245,18 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24)
intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)}
+
+ assert(("abcdef" cast StringType).nullable === false)
+ assert(("abcdef" cast BinaryType).nullable === false)
+ assert(("abcdef" cast BooleanType).nullable === false)
+ assert(("abcdef" cast TimestampType).nullable === true)
+ assert(("abcdef" cast LongType).nullable === true)
+ assert(("abcdef" cast IntegerType).nullable === true)
+ assert(("abcdef" cast ShortType).nullable === true)
+ assert(("abcdef" cast ByteType).nullable === true)
+ assert(("abcdef" cast DecimalType).nullable === true)
+ assert(("abcdef" cast DoubleType).nullable === true)
+ assert(("abcdef" cast FloatType).nullable === true)
}
test("timestamp") {