aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala5
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala14
2 files changed, 17 insertions, 2 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 28820681cd..d8f953fba5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -407,10 +407,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
val quotedColName = "`" + col.name + "`"
val colValue = col.dataType match {
case DoubleType | FloatType =>
- nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types
+ // nanvl only supports these types
+ nanvl(df.col(quotedColName), lit(null).cast(col.dataType))
case _ => df.col(quotedColName)
}
- coalesce(colValue, lit(replacement)).cast(col.dataType).as(col.name)
+ coalesce(colValue, lit(replacement).cast(col.dataType)).as(col.name)
}
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
index fd829846ac..aa237d0619 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
@@ -146,6 +146,20 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
)
checkAnswer(
+ Seq[(java.lang.Long, java.lang.Double)]((null, 3.14), (9123146099426677101L, null),
+ (9123146560113991650L, 1.6), (null, null)).toDF("a", "b").na.fill(0.2),
+ Row(0, 3.14) :: Row(9123146099426677101L, 0.2) :: Row(9123146560113991650L, 1.6)
+ :: Row(0, 0.2) :: Nil
+ )
+
+ checkAnswer(
+ Seq[(java.lang.Long, java.lang.Float)]((null, 3.14f), (9123146099426677101L, null),
+ (9123146560113991650L, 1.6f), (null, null)).toDF("a", "b").na.fill(0.2),
+ Row(0, 3.14f) :: Row(9123146099426677101L, 0.2f) :: Row(9123146560113991650L, 1.6f)
+ :: Row(0, 0.2f) :: Nil
+ )
+
+ checkAnswer(
Seq[(java.lang.Long, java.lang.Double)]((null, 1.23), (3L, null), (4L, 3.45))
.toDF("a", "b").na.fill(2.34),
Row(2, 1.23) :: Row(3, 2.34) :: Row(4, 3.45) :: Nil