diff options
author | Travis Crawford <travis@medium.com> | 2016-03-30 16:59:52 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-03-30 16:59:52 -0700 |
commit | da54abfd8730ef752eca921089bcf568773bd24a (patch) | |
tree | caffe6260f202b2ab1834630546891c5401c8b96 /sql | |
parent | 258a2434193aae62999102a8df73ca70bf0cb9f1 (diff) | |
download | spark-da54abfd8730ef752eca921089bcf568773bd24a.tar.gz spark-da54abfd8730ef752eca921089bcf568773bd24a.tar.bz2 spark-da54abfd8730ef752eca921089bcf568773bd24a.zip |
[SPARK-14081][SQL] - Preserve DataFrame column types when filling nulls.
## What changes were proposed in this pull request?
This change resolves an issue where `DataFrameNaFunctions.fill` changes a `FloatType` column to a `DoubleType`. It also clarifies the contract: replacement values are cast to the column's data type, so a replacement value may change when it is cast to a lower-precision type.
## How was this patch tested?
This patch has associated unit tests.
Author: Travis Crawford <travis@medium.com>
Closes #11967 from traviscrawford/SPARK-14081-dataframena.
Diffstat (limited to 'sql')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala | 18 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala | 50 |
2 files changed, 40 insertions, 28 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 33588ef72f..f0e16eefc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -200,6 +200,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * The key of the map is the column name, and the value of the map is the replacement value. * The value must be of the following type: * `Integer`, `Long`, `Float`, `Double`, `String`, `Boolean`. + * Replacement values are cast to the column data type. * * For example, the following replaces null values in column "A" with string "unknown", and * null values in column "B" with numeric value 1.0. @@ -217,6 +218,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * The key of the map is the column name, and the value of the map is the replacement value. * The value must be of the following type: `Int`, `Long`, `Float`, `Double`, `String`, `Boolean`. + * Replacement values are cast to the column data type. * * For example, the following replaces null values in column "A" with string "unknown", and * null values in column "B" with numeric value 1.0. 
@@ -386,10 +388,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { val projections = df.schema.fields.map { f => values.find { case (k, _) => columnEquals(k, f.name) }.map { case (_, v) => v match { - case v: jl.Float => fillCol[Double](f, v.toDouble) + case v: jl.Float => fillCol[Float](f, v) case v: jl.Double => fillCol[Double](f, v) - case v: jl.Long => fillCol[Double](f, v.toDouble) - case v: jl.Integer => fillCol[Double](f, v.toDouble) + case v: jl.Long => fillCol[Long](f, v) + case v: jl.Integer => fillCol[Integer](f, v) case v: jl.Boolean => fillCol[Boolean](f, v.booleanValue()) case v: String => fillCol[String](f, v) } @@ -402,13 +404,13 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * Returns a [[Column]] expression that replaces null value in `col` with `replacement`. */ private def fillCol[T](col: StructField, replacement: T): Column = { - col.dataType match { + val quotedColName = "`" + col.name + "`" + val colValue = col.dataType match { case DoubleType | FloatType => - coalesce(nanvl(df.col("`" + col.name + "`"), lit(null)), - lit(replacement).cast(col.dataType)).as(col.name) - case _ => - coalesce(df.col("`" + col.name + "`"), lit(replacement).cast(col.dataType)).as(col.name) + nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types + case _ => df.col(quotedColName) } + coalesce(colValue, lit(replacement)).cast(col.dataType).as(col.name) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index e34875471f..18e04c24a4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -141,26 +141,36 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { } test("fill with map") { - val df = Seq[(String, String, java.lang.Long, java.lang.Double, 
java.lang.Boolean)]( - (null, null, null, null, null)).toDF("a", "b", "c", "d", "e") - checkAnswer( - df.na.fill(Map( - "a" -> "test", - "c" -> 1, - "d" -> 2.2, - "e" -> false - )), - Row("test", null, 1, 2.2, false)) - - // Test Java version - checkAnswer( - df.na.fill(Map( - "a" -> "test", - "c" -> 1, - "d" -> 2.2, - "e" -> false - ).asJava), - Row("test", null, 1, 2.2, false)) + val df = Seq[(String, String, java.lang.Integer, java.lang.Long, + java.lang.Float, java.lang.Double, java.lang.Boolean)]( + (null, null, null, null, null, null, null)) + .toDF("stringFieldA", "stringFieldB", "integerField", "longField", + "floatField", "doubleField", "booleanField") + + val fillMap = Map( + "stringFieldA" -> "test", + "integerField" -> 1, + "longField" -> 2L, + "floatField" -> 3.3f, + "doubleField" -> 4.4d, + "booleanField" -> false) + + val expectedRow = Row("test", null, 1, 2L, 3.3f, 4.4d, false) + + checkAnswer(df.na.fill(fillMap), expectedRow) + checkAnswer(df.na.fill(fillMap.asJava), expectedRow) // Test Java version + + // Ensure replacement values are cast to the column data type. + checkAnswer(df.na.fill(Map( + "integerField" -> 1d, + "longField" -> 2d, + "floatField" -> 3d, + "doubleField" -> 4d)), + Row(null, null, 1, 2L, 3f, 4d, null)) + + // Ensure column types do not change. Columns that have null values replaced + // will no longer be flagged as nullable, so do not compare schemas directly. + assert(df.na.fill(fillMap).schema.fields.map(_.dataType) === df.schema.fields.map(_.dataType)) } test("replace") { |