aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala1
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala14
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala3
3 files changed, 12 insertions, 6 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 768897dc07..e1dd010d37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -571,6 +571,7 @@ object TypeCoercion {
NaNvl(l, Cast(r, DoubleType))
case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType =>
NaNvl(Cast(l, DoubleType), r)
+ case NaNvl(l, r) if r.dataType == NullType => NaNvl(l, Cast(r, l.dataType))
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 3e0c357b6d..011d09ff60 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -656,14 +656,20 @@ class TypeCoercionSuite extends PlanTest {
test("nanvl casts") {
ruleTest(TypeCoercion.FunctionArgumentConversion,
- NaNvl(Literal.create(1.0, FloatType), Literal.create(1.0, DoubleType)),
- NaNvl(Cast(Literal.create(1.0, FloatType), DoubleType), Literal.create(1.0, DoubleType)))
+ NaNvl(Literal.create(1.0f, FloatType), Literal.create(1.0, DoubleType)),
+ NaNvl(Cast(Literal.create(1.0f, FloatType), DoubleType), Literal.create(1.0, DoubleType)))
ruleTest(TypeCoercion.FunctionArgumentConversion,
- NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, FloatType)),
- NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0, FloatType), DoubleType)))
+ NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0f, FloatType)),
+ NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0f, FloatType), DoubleType)))
ruleTest(TypeCoercion.FunctionArgumentConversion,
NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)),
NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)))
+ ruleTest(TypeCoercion.FunctionArgumentConversion,
+ NaNvl(Literal.create(1.0f, FloatType), Literal.create(null, NullType)),
+ NaNvl(Literal.create(1.0f, FloatType), Cast(Literal.create(null, NullType), FloatType)))
+ ruleTest(TypeCoercion.FunctionArgumentConversion,
+ NaNvl(Literal.create(1.0, DoubleType), Literal.create(null, NullType)),
+ NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(null, NullType), DoubleType)))
}
test("type coercion for If") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 93d565d9fe..052d85ad33 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -408,8 +408,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
val quotedColName = "`" + col.name + "`"
val colValue = col.dataType match {
case DoubleType | FloatType =>
- // nanvl only supports these types
- nanvl(df.col(quotedColName), lit(null).cast(col.dataType))
+ nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types
case _ => df.col(quotedColName)
}
coalesce(colValue, lit(replacement).cast(col.dataType)).as(col.name)