-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala | 89
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala | 18
2 files changed, 80 insertions, 27 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 184c5a1129..28820681cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -128,6 +128,12 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
/**
* Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`.
*
+ * @since 2.2.0
+ */
+ def fill(value: Long): DataFrame = fill(value, df.columns)
+
+ /**
+ * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`.
* @since 1.3.1
*/
def fill(value: Double): DataFrame = fill(value, df.columns)
@@ -143,6 +149,14 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
* Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns.
* If a specified column is not a numeric column, it is ignored.
*
+ * @since 2.2.0
+ */
+ def fill(value: Long, cols: Array[String]): DataFrame = fill(value, cols.toSeq)
+
+ /**
+ * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns.
+ * If a specified column is not a numeric column, it is ignored.
+ *
* @since 1.3.1
*/
def fill(value: Double, cols: Array[String]): DataFrame = fill(value, cols.toSeq)
@@ -151,20 +165,18 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
* (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified
* numeric columns. If a specified column is not a numeric column, it is ignored.
*
+ * @since 2.2.0
+ */
+ def fill(value: Long, cols: Seq[String]): DataFrame = fillValue(value, cols)
+
+ /**
+ * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified
+ * numeric columns. If a specified column is not a numeric column, it is ignored.
+ *
* @since 1.3.1
*/
- def fill(value: Double, cols: Seq[String]): DataFrame = {
- val columnEquals = df.sparkSession.sessionState.analyzer.resolver
- val projections = df.schema.fields.map { f =>
- // Only fill if the column is part of the cols list.
- if (f.dataType.isInstanceOf[NumericType] && cols.exists(col => columnEquals(f.name, col))) {
- fillCol[Double](f, value)
- } else {
- df.col(f.name)
- }
- }
- df.select(projections : _*)
- }
+ def fill(value: Double, cols: Seq[String]): DataFrame = fillValue(value, cols)
+

/**
* Returns a new `DataFrame` that replaces null values in specified string columns.
@@ -180,18 +192,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* @since 1.3.1
*/
- def fill(value: String, cols: Seq[String]): DataFrame = {
- val columnEquals = df.sparkSession.sessionState.analyzer.resolver
- val projections = df.schema.fields.map { f =>
- // Only fill if the column is part of the cols list.
- if (f.dataType.isInstanceOf[StringType] && cols.exists(col => columnEquals(f.name, col))) {
- fillCol[String](f, value)
- } else {
- df.col(f.name)
- }
- }
- df.select(projections : _*)
- }
+ def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, cols)

/**
* Returns a new `DataFrame` that replaces null values.
@@ -210,7 +211,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* @since 1.3.1
*/
- def fill(valueMap: java.util.Map[String, Any]): DataFrame = fill0(valueMap.asScala.toSeq)
+ def fill(valueMap: java.util.Map[String, Any]): DataFrame = fillMap(valueMap.asScala.toSeq)

/**
* (Scala-specific) Returns a new `DataFrame` that replaces null values.
@@ -230,7 +231,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* @since 1.3.1
*/
- def fill(valueMap: Map[String, Any]): DataFrame = fill0(valueMap.toSeq)
+ def fill(valueMap: Map[String, Any]): DataFrame = fillMap(valueMap.toSeq)

/**
* Replaces values matching keys in `replacement` map with the corresponding values.
@@ -368,7 +369,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
df.select(projections : _*)
}

- private def fill0(values: Seq[(String, Any)]): DataFrame = {
+ private def fillMap(values: Seq[(String, Any)]): DataFrame = {
// Error handling
values.foreach { case (colName, replaceValue) =>
// Check column name exists
@@ -435,4 +436,38 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
case v => throw new IllegalArgumentException(
s"Unsupported value type ${v.getClass.getName} ($v).")
}
+
+ /**
+ * Returns a new `DataFrame` that replaces null or NaN values in the specified
+ * numeric or string columns. If a specified column is not a numeric or string column,
+ * it is ignored.
+ */
+ private def fillValue[T](value: T, cols: Seq[String]): DataFrame = {
+ // fill[T], where T is Long or Double, should apply to every NumericType column. For example:
+ //   val input = Seq[(java.lang.Integer, java.lang.Double)]((null, 164.3)).toDF("a", "b")
+ //   input.na.fill(3.1)
+ // should produce (3, 164.3), not (null, 164.3).
+ val targetType = value match {
+ case _: Double | _: Long => NumericType
+ case _: String => StringType
+ case _ => throw new IllegalArgumentException(
+ s"Unsupported value type ${value.getClass.getName} ($value).")
+ }
+
+ val columnEquals = df.sparkSession.sessionState.analyzer.resolver
+ val projections = df.schema.fields.map { f =>
+ val typeMatches = (targetType, f.dataType) match {
+ case (NumericType, dt) => dt.isInstanceOf[NumericType]
+ case (StringType, dt) => dt == StringType
+ }
+ // Only fill if the column is part of the cols list.
+ if (typeMatches && cols.exists(col => columnEquals(f.name, col))) {
+ fillCol[T](f, value)
+ } else {
+ df.col(f.name)
+ }
+ }
+ df.select(projections : _*)
+ }
}
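A usage sketch, separate from the patch itself: the snippet below is my own illustration (the object name and app name are placeholders, and it assumes a local SparkSession with spark.implicits._ in scope). It exercises the behaviour described in the fillValue comment above, and the expected outputs match the new test cases in the suite further down.

import org.apache.spark.sql.SparkSession

object NaFillSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("na-fill-sketch").getOrCreate()
    import spark.implicits._

    // Filling with a Double targets every NumericType column, so the null Integer
    // in column "a" becomes 3 rather than staying null.
    val input = Seq[(java.lang.Integer, java.lang.Double)]((null, 164.3)).toDF("a", "b")
    input.na.fill(3.1).show()   // expected row: (3, 164.3)

    // The integer literal 5 should resolve to the new fill(value: Long) overload (since 2.2.0),
    // which fills both the Long and the Double column.
    Seq[(java.lang.Long, java.lang.Double)]((null, 1.23), (3L, null), (4L, 3.45))
      .toDF("a", "b").na.fill(5).show()   // expected rows: (5, 1.23), (3, 5.0), (4, 3.45)

    spark.stop()
  }
}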
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
index 47b55e2547..fd829846ac 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
@@ -138,6 +138,24 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
checkAnswer(
Seq[(String, String)]((null, null)).toDF("col1", "col2").na.fill("test", "col1" :: Nil),
Row("test", null))
+
+ checkAnswer(
+ Seq[(Long, Long)]((1, 2), (-1, -2), (9123146099426677101L, 9123146560113991650L))
+ .toDF("a", "b").na.fill(0),
+ Row(1, 2) :: Row(-1, -2) :: Row(9123146099426677101L, 9123146560113991650L) :: Nil
+ )
+
+ checkAnswer(
+ Seq[(java.lang.Long, java.lang.Double)]((null, 1.23), (3L, null), (4L, 3.45))
+ .toDF("a", "b").na.fill(2.34),
+ Row(2, 1.23) :: Row(3, 2.34) :: Row(4, 3.45) :: Nil
+ )
+
+ checkAnswer(
+ Seq[(java.lang.Long, java.lang.Double)]((null, 1.23), (3L, null), (4L, 3.45))
+ .toDF("a", "b").na.fill(5),
+ Row(5, 1.23) :: Row(3, 5.0) :: Row(4, 3.45) :: Nil
+ )
}
test("fill with map") {
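A closing note on the large Long constants in the first new test case: values such as 9123146099426677101L lie far beyond 2^53, the largest range a Double can represent exactly, so they cannot round-trip through Double without corruption, which is presumably what these assertions guard against. A minimal plain-Scala illustration of that effect (my own, no Spark required):

object LongPrecisionSketch extends App {
  // 9123146099426677101 is larger than 2^53, so it has no exact Double representation.
  val big = 9123146099426677101L
  val roundTripped = big.toDouble.toLong
  println(s"original   = $big")
  println(s"via Double = $roundTripped")
  println(s"lossless?  = ${roundTripped == big}")   // prints: lossless?  = false
}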