about summary refs log tree commit diff
diff options
context:
space:
mode:
author Travis Crawford <travis@medium.com> 2016-03-30 16:59:52 -0700
committer Reynold Xin <rxin@databricks.com> 2016-03-30 16:59:52 -0700
commit da54abfd8730ef752eca921089bcf568773bd24a (patch)
tree caffe6260f202b2ab1834630546891c5401c8b96
parent 258a2434193aae62999102a8df73ca70bf0cb9f1 (diff)
download spark-da54abfd8730ef752eca921089bcf568773bd24a.tar.gz
spark-da54abfd8730ef752eca921089bcf568773bd24a.tar.bz2
spark-da54abfd8730ef752eca921089bcf568773bd24a.zip
[SPARK-14081][SQL] - Preserve DataFrame column types when filling nulls.
## What changes were proposed in this pull request?

This change resolves an issue where `DataFrameNaFunctions.fill` changes a `FloatType` column to a `DoubleType`. We also clarify the contract that replacement values will be cast to the column data type, which may change the replacement value when casting to a lower precision type.

## How was this patch tested?

This patch has associated unit tests.

Author: Travis Crawford <travis@medium.com>

Closes #11967 from traviscrawford/SPARK-14081-dataframena.
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala 18
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala 50
2 files changed, 40 insertions, 28 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 33588ef72f..f0e16eefc7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -200,6 +200,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
* The key of the map is the column name, and the value of the map is the replacement value.
* The value must be of the following type:
* `Integer`, `Long`, `Float`, `Double`, `String`, `Boolean`.
+ * Replacement values are cast to the column data type.
*
* For example, the following replaces null values in column "A" with string "unknown", and
* null values in column "B" with numeric value 1.0.
@@ -217,6 +218,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* The key of the map is the column name, and the value of the map is the replacement value.
* The value must be of the following type: `Int`, `Long`, `Float`, `Double`, `String`, `Boolean`.
+ * Replacement values are cast to the column data type.
*
* For example, the following replaces null values in column "A" with string "unknown", and
* null values in column "B" with numeric value 1.0.
@@ -386,10 +388,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
val projections = df.schema.fields.map { f =>
values.find { case (k, _) => columnEquals(k, f.name) }.map { case (_, v) =>
v match {
- case v: jl.Float => fillCol[Double](f, v.toDouble)
+ case v: jl.Float => fillCol[Float](f, v)
case v: jl.Double => fillCol[Double](f, v)
- case v: jl.Long => fillCol[Double](f, v.toDouble)
- case v: jl.Integer => fillCol[Double](f, v.toDouble)
+ case v: jl.Long => fillCol[Long](f, v)
+ case v: jl.Integer => fillCol[Integer](f, v)
case v: jl.Boolean => fillCol[Boolean](f, v.booleanValue())
case v: String => fillCol[String](f, v)
}
@@ -402,13 +404,13 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
* Returns a [[Column]] expression that replaces null value in `col` with `replacement`.
*/
private def fillCol[T](col: StructField, replacement: T): Column = {
- col.dataType match {
+ val quotedColName = "`" + col.name + "`"
+ val colValue = col.dataType match {
case DoubleType | FloatType =>
- coalesce(nanvl(df.col("`" + col.name + "`"), lit(null)),
- lit(replacement).cast(col.dataType)).as(col.name)
- case _ =>
- coalesce(df.col("`" + col.name + "`"), lit(replacement).cast(col.dataType)).as(col.name)
+ nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types
+ case _ => df.col(quotedColName)
}
+ coalesce(colValue, lit(replacement)).cast(col.dataType).as(col.name)
}
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
index e34875471f..18e04c24a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
@@ -141,26 +141,36 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
}
test("fill with map") {
- val df = Seq[(String, String, java.lang.Long, java.lang.Double, java.lang.Boolean)](
- (null, null, null, null, null)).toDF("a", "b", "c", "d", "e")
- checkAnswer(
- df.na.fill(Map(
- "a" -> "test",
- "c" -> 1,
- "d" -> 2.2,
- "e" -> false
- )),
- Row("test", null, 1, 2.2, false))
-
- // Test Java version
- checkAnswer(
- df.na.fill(Map(
- "a" -> "test",
- "c" -> 1,
- "d" -> 2.2,
- "e" -> false
- ).asJava),
- Row("test", null, 1, 2.2, false))
+ val df = Seq[(String, String, java.lang.Integer, java.lang.Long,
+ java.lang.Float, java.lang.Double, java.lang.Boolean)](
+ (null, null, null, null, null, null, null))
+ .toDF("stringFieldA", "stringFieldB", "integerField", "longField",
+ "floatField", "doubleField", "booleanField")
+
+ val fillMap = Map(
+ "stringFieldA" -> "test",
+ "integerField" -> 1,
+ "longField" -> 2L,
+ "floatField" -> 3.3f,
+ "doubleField" -> 4.4d,
+ "booleanField" -> false)
+
+ val expectedRow = Row("test", null, 1, 2L, 3.3f, 4.4d, false)
+
+ checkAnswer(df.na.fill(fillMap), expectedRow)
+ checkAnswer(df.na.fill(fillMap.asJava), expectedRow) // Test Java version
+
+ // Ensure replacement values are cast to the column data type.
+ checkAnswer(df.na.fill(Map(
+ "integerField" -> 1d,
+ "longField" -> 2d,
+ "floatField" -> 3d,
+ "doubleField" -> 4d)),
+ Row(null, null, 1, 2L, 3f, 4d, null))
+
+ // Ensure column types do not change. Columns that have null values replaced
+ // will no longer be flagged as nullable, so do not compare schemas directly.
+ assert(df.na.fill(fillMap).schema.fields.map(_.dataType) === df.schema.fields.map(_.dataType))
}
test("replace") {