aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYijie Shen <henry.yijieshen@gmail.com>2015-07-21 08:25:50 -0700
committerDavies Liu <davies.liu@gmail.com>2015-07-21 08:25:50 -0700
commitbe5c5d3741256697cc76938a8ed6f609eb2d4b11 (patch)
tree03366bfb866d93023dfc0a3f2069b388d4a8f7c8
parentf5b6dc5e3e7e3b586096b71164f052318b840e8a (diff)
downloadspark-be5c5d3741256697cc76938a8ed6f609eb2d4b11.tar.gz
spark-be5c5d3741256697cc76938a8ed6f609eb2d4b11.tar.bz2
spark-be5c5d3741256697cc76938a8ed6f609eb2d4b11.zip
[SPARK-9081] [SPARK-9168] [SQL] nanvl & dropna/fillna supporting nan as well
JIRA: https://issues.apache.org/jira/browse/SPARK-9081 https://issues.apache.org/jira/browse/SPARK-9168 This PR target at two modifications: 1. Change `isNaN` to return `false` on `null` input 2. Make `dropna` and `fillna` to fill/drop NaN values as well 3. Implement `nanvl` Author: Yijie Shen <henry.yijieshen@gmail.com> Closes #7523 from yjshen/fillna_dropna and squashes the following commits: f0a51db [Yijie Shen] make coalesce untouched and implement nanvl 1d3e35f [Yijie Shen] make Coalesce aware of NaN in order to support fillna 2760cbc [Yijie Shen] change isNaN(null) to false as well as implement dropna
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala104
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala5
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala39
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Column.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala52
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala13
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala25
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala69
9 files changed, 222 insertions, 88 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 13523720da..e3d8d2adf2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -89,6 +89,7 @@ object FunctionRegistry {
expression[CreateStruct]("struct"),
expression[CreateNamedStruct]("named_struct"),
expression[Sqrt]("sqrt"),
+ expression[NaNvl]("nanvl"),
// math functions
expression[Acos]("acos"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
index 98c6708464..287718fab7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
@@ -83,7 +83,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
/**
- * Evaluates to `true` if it's NaN or null
+ * Evaluates to `true` iff it's NaN.
*/
case class IsNaN(child: Expression) extends UnaryExpression
with Predicate with ImplicitCastInputTypes {
@@ -95,7 +95,7 @@ case class IsNaN(child: Expression) extends UnaryExpression
override def eval(input: InternalRow): Any = {
val value = child.eval(input)
if (value == null) {
- true
+ false
} else {
child.dataType match {
case DoubleType => value.asInstanceOf[Double].isNaN
@@ -107,26 +107,65 @@ case class IsNaN(child: Expression) extends UnaryExpression
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val eval = child.gen(ctx)
child.dataType match {
- case FloatType =>
+ case DoubleType | FloatType =>
s"""
${eval.code}
boolean ${ev.isNull} = false;
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
- if (${eval.isNull}) {
- ${ev.primitive} = true;
- } else {
- ${ev.primitive} = Float.isNaN(${eval.primitive});
- }
+ ${ev.primitive} = !${eval.isNull} && Double.isNaN(${eval.primitive});
"""
- case DoubleType =>
+ }
+ }
+}
+
+/**
+ * An Expression evaluates to `left` iff it's not NaN, or evaluates to `right` otherwise.
+ * This Expression is useful for mapping NaN values to null.
+ */
+case class NaNvl(left: Expression, right: Expression)
+ extends BinaryExpression with ImplicitCastInputTypes {
+
+ override def dataType: DataType = left.dataType
+
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(TypeCollection(DoubleType, FloatType), TypeCollection(DoubleType, FloatType))
+
+ override def eval(input: InternalRow): Any = {
+ val value = left.eval(input)
+ if (value == null) {
+ null
+ } else {
+ left.dataType match {
+ case DoubleType =>
+ if (!value.asInstanceOf[Double].isNaN) value else right.eval(input)
+ case FloatType =>
+ if (!value.asInstanceOf[Float].isNaN) value else right.eval(input)
+ }
+ }
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val leftGen = left.gen(ctx)
+ val rightGen = right.gen(ctx)
+ left.dataType match {
+ case DoubleType | FloatType =>
s"""
- ${eval.code}
+ ${leftGen.code}
boolean ${ev.isNull} = false;
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
- if (${eval.isNull}) {
- ${ev.primitive} = true;
+ if (${leftGen.isNull}) {
+ ${ev.isNull} = true;
} else {
- ${ev.primitive} = Double.isNaN(${eval.primitive});
+ if (!Double.isNaN(${leftGen.primitive})) {
+ ${ev.primitive} = ${leftGen.primitive};
+ } else {
+ ${rightGen.code}
+ if (${rightGen.isNull}) {
+ ${ev.isNull} = true;
+ } else {
+ ${ev.primitive} = ${rightGen.primitive};
+ }
+ }
}
"""
}
@@ -186,8 +225,15 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
var numNonNulls = 0
var i = 0
while (i < childrenArray.length && numNonNulls < n) {
- if (childrenArray(i).eval(input) != null) {
- numNonNulls += 1
+ val evalC = childrenArray(i).eval(input)
+ if (evalC != null) {
+ childrenArray(i).dataType match {
+ case DoubleType =>
+ if (!evalC.asInstanceOf[Double].isNaN) numNonNulls += 1
+ case FloatType =>
+ if (!evalC.asInstanceOf[Float].isNaN) numNonNulls += 1
+ case _ => numNonNulls += 1
+ }
}
i += 1
}
@@ -198,14 +244,26 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
val nonnull = ctx.freshName("nonnull")
val code = children.map { e =>
val eval = e.gen(ctx)
- s"""
- if ($nonnull < $n) {
- ${eval.code}
- if (!${eval.isNull}) {
- $nonnull += 1;
- }
- }
- """
+ e.dataType match {
+ case DoubleType | FloatType =>
+ s"""
+ if ($nonnull < $n) {
+ ${eval.code}
+ if (!${eval.isNull} && !Double.isNaN(${eval.primitive})) {
+ $nonnull += 1;
+ }
+ }
+ """
+ case _ =>
+ s"""
+ if ($nonnull < $n) {
+ ${eval.code}
+ if (!${eval.isNull}) {
+ $nonnull += 1;
+ }
+ }
+ """
+ }
}.mkString("\n")
s"""
int $nonnull = 0;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index a53ec31ee6..3f1bd2a925 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -17,10 +17,10 @@
package org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.TypeUtils
-import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -121,7 +121,6 @@ case class InSet(child: Expression, hset: Set[Any])
}
}
-
case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate {
override def inputType: AbstractDataType = BooleanType
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
index 765cc7a969..0728f6695c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
@@ -49,12 +49,22 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(IsNaN(Literal(Double.NaN)), true)
checkEvaluation(IsNaN(Literal(Float.NaN)), true)
checkEvaluation(IsNaN(Literal(math.log(-3))), true)
- checkEvaluation(IsNaN(Literal.create(null, DoubleType)), true)
+ checkEvaluation(IsNaN(Literal.create(null, DoubleType)), false)
checkEvaluation(IsNaN(Literal(Double.PositiveInfinity)), false)
checkEvaluation(IsNaN(Literal(Float.MaxValue)), false)
checkEvaluation(IsNaN(Literal(5.5f)), false)
}
+ test("nanvl") {
+ checkEvaluation(NaNvl(Literal(5.0), Literal.create(null, DoubleType)), 5.0)
+ checkEvaluation(NaNvl(Literal.create(null, DoubleType), Literal(5.0)), null)
+ checkEvaluation(NaNvl(Literal.create(null, DoubleType), Literal(Double.NaN)), null)
+ checkEvaluation(NaNvl(Literal(Double.NaN), Literal(5.0)), 5.0)
+ checkEvaluation(NaNvl(Literal(Double.NaN), Literal.create(null, DoubleType)), null)
+ assert(NaNvl(Literal(Double.NaN), Literal(Double.NaN)).
+ eval(EmptyRow).asInstanceOf[Double].isNaN)
+ }
+
test("coalesce") {
testAllTypes { (value: Any, tpe: DataType) =>
val lit = Literal.create(value, tpe)
@@ -66,4 +76,31 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Coalesce(Seq(nullLit, nullLit, lit)), value)
}
}
+
+ test("AtLeastNNonNulls") {
+ val mix = Seq(Literal("x"),
+ Literal.create(null, StringType),
+ Literal.create(null, DoubleType),
+ Literal(Double.NaN),
+ Literal(5f))
+
+ val nanOnly = Seq(Literal("x"),
+ Literal(10.0),
+ Literal(Float.NaN),
+ Literal(math.log(-2)),
+ Literal(Double.MaxValue))
+
+ val nullOnly = Seq(Literal("x"),
+ Literal.create(null, DoubleType),
+ Literal.create(null, DecimalType.Unlimited),
+ Literal(Float.MaxValue),
+ Literal(false))
+
+ checkEvaluation(AtLeastNNonNulls(2, mix), true, EmptyRow)
+ checkEvaluation(AtLeastNNonNulls(3, mix), false, EmptyRow)
+ checkEvaluation(AtLeastNNonNulls(3, nanOnly), true, EmptyRow)
+ checkEvaluation(AtLeastNNonNulls(4, nanOnly), false, EmptyRow)
+ checkEvaluation(AtLeastNNonNulls(3, nullOnly), true, EmptyRow)
+ checkEvaluation(AtLeastNNonNulls(4, nullOnly), false, EmptyRow)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 221cd04c6d..6e2a6525bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -401,7 +401,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
}
/**
- * True if the current expression is NaN or null
+ * True if the current expression is NaN.
*
* @group expr_ops
* @since 1.5.0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 8681a56c82..a4fd4cf3b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -37,24 +37,24 @@ import org.apache.spark.sql.types._
final class DataFrameNaFunctions private[sql](df: DataFrame) {
/**
- * Returns a new [[DataFrame]] that drops rows containing any null values.
+ * Returns a new [[DataFrame]] that drops rows containing any null or NaN values.
*
* @since 1.3.1
*/
def drop(): DataFrame = drop("any", df.columns)
/**
- * Returns a new [[DataFrame]] that drops rows containing null values.
+ * Returns a new [[DataFrame]] that drops rows containing null or NaN values.
*
- * If `how` is "any", then drop rows containing any null values.
- * If `how` is "all", then drop rows only if every column is null for that row.
+ * If `how` is "any", then drop rows containing any null or NaN values.
+ * If `how` is "all", then drop rows only if every column is null or NaN for that row.
*
* @since 1.3.1
*/
def drop(how: String): DataFrame = drop(how, df.columns)
/**
- * Returns a new [[DataFrame]] that drops rows containing any null values
+ * Returns a new [[DataFrame]] that drops rows containing any null or NaN values
* in the specified columns.
*
* @since 1.3.1
@@ -62,7 +62,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
def drop(cols: Array[String]): DataFrame = drop(cols.toSeq)
/**
- * (Scala-specific) Returns a new [[DataFrame ]] that drops rows containing any null values
+ * (Scala-specific) Returns a new [[DataFrame]] that drops rows containing any null or NaN values
* in the specified columns.
*
* @since 1.3.1
@@ -70,22 +70,22 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
def drop(cols: Seq[String]): DataFrame = drop(cols.size, cols)
/**
- * Returns a new [[DataFrame]] that drops rows containing null values
+ * Returns a new [[DataFrame]] that drops rows containing null or NaN values
* in the specified columns.
*
- * If `how` is "any", then drop rows containing any null values in the specified columns.
- * If `how` is "all", then drop rows only if every specified column is null for that row.
+ * If `how` is "any", then drop rows containing any null or NaN values in the specified columns.
+ * If `how` is "all", then drop rows only if every specified column is null or NaN for that row.
*
* @since 1.3.1
*/
def drop(how: String, cols: Array[String]): DataFrame = drop(how, cols.toSeq)
/**
- * (Scala-specific) Returns a new [[DataFrame]] that drops rows containing null values
+ * (Scala-specific) Returns a new [[DataFrame]] that drops rows containing null or NaN values
* in the specified columns.
*
- * If `how` is "any", then drop rows containing any null values in the specified columns.
- * If `how` is "all", then drop rows only if every specified column is null for that row.
+ * If `how` is "any", then drop rows containing any null or NaN values in the specified columns.
+ * If `how` is "all", then drop rows only if every specified column is null or NaN for that row.
*
* @since 1.3.1
*/
@@ -98,15 +98,16 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
}
/**
- * Returns a new [[DataFrame]] that drops rows containing less than `minNonNulls` non-null values.
+ * Returns a new [[DataFrame]] that drops rows containing
+ * less than `minNonNulls` non-null and non-NaN values.
*
* @since 1.3.1
*/
def drop(minNonNulls: Int): DataFrame = drop(minNonNulls, df.columns)
/**
- * Returns a new [[DataFrame]] that drops rows containing less than `minNonNulls` non-null
- * values in the specified columns.
+ * Returns a new [[DataFrame]] that drops rows containing
+ * less than `minNonNulls` non-null and non-NaN values in the specified columns.
*
* @since 1.3.1
*/
@@ -114,32 +115,33 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
/**
* (Scala-specific) Returns a new [[DataFrame]] that drops rows containing less than
- * `minNonNulls` non-null values in the specified columns.
+ * `minNonNulls` non-null and non-NaN values in the specified columns.
*
* @since 1.3.1
*/
def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = {
- // Filtering condition -- only keep the row if it has at least `minNonNulls` non-null values.
+ // Filtering condition:
+ // only keep the row if it has at least `minNonNulls` non-null and non-NaN values.
val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name)))
df.filter(Column(predicate))
}
/**
- * Returns a new [[DataFrame]] that replaces null values in numeric columns with `value`.
+ * Returns a new [[DataFrame]] that replaces null or NaN values in numeric columns with `value`.
*
* @since 1.3.1
*/
def fill(value: Double): DataFrame = fill(value, df.columns)
/**
- * Returns a new [[DataFrame ]] that replaces null values in string columns with `value`.
+ * Returns a new [[DataFrame]] that replaces null values in string columns with `value`.
*
* @since 1.3.1
*/
def fill(value: String): DataFrame = fill(value, df.columns)
/**
- * Returns a new [[DataFrame]] that replaces null values in specified numeric columns.
+ * Returns a new [[DataFrame]] that replaces null or NaN values in specified numeric columns.
* If a specified column is not a numeric column, it is ignored.
*
* @since 1.3.1
@@ -147,7 +149,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
def fill(value: Double, cols: Array[String]): DataFrame = fill(value, cols.toSeq)
/**
- * (Scala-specific) Returns a new [[DataFrame]] that replaces null values in specified
+ * (Scala-specific) Returns a new [[DataFrame]] that replaces null or NaN values in specified
* numeric columns. If a specified column is not a numeric column, it is ignored.
*
* @since 1.3.1
@@ -391,7 +393,13 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
* Returns a [[Column]] expression that replaces null value in `col` with `replacement`.
*/
private def fillCol[T](col: StructField, replacement: T): Column = {
- coalesce(df.col("`" + col.name + "`"), lit(replacement).cast(col.dataType)).as(col.name)
+ col.dataType match {
+ case DoubleType | FloatType =>
+ coalesce(nanvl(df.col("`" + col.name + "`"), lit(null)),
+ lit(replacement).cast(col.dataType)).as(col.name)
+ case _ =>
+ coalesce(df.col("`" + col.name + "`"), lit(replacement).cast(col.dataType)).as(col.name)
+ }
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 60b089180c..d94d733582 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -595,7 +595,7 @@ object functions {
}
/**
- * Returns the first column that is not null.
+ * Returns the first column that is not null and not NaN.
* {{{
* df.select(coalesce(df("a"), df("b")))
* }}}
@@ -612,7 +612,7 @@ object functions {
def explode(e: Column): Column = Explode(e.expr)
/**
- * Return true if the column is NaN or null
+ * Return true iff the column is NaN.
*
* @group normal_funcs
* @since 1.5.0
@@ -637,6 +637,15 @@ object functions {
def monotonicallyIncreasingId(): Column = execution.expressions.MonotonicallyIncreasingID()
/**
+ * Return an alternative value `r` if `l` is NaN.
+ * This function is useful for mapping NaN values to null.
+ *
+ * @group normal_funcs
+ * @since 1.5.0
+ */
+ def nanvl(l: Column, r: Column): Column = NaNvl(l.expr, r.expr)
+
+ /**
* Unary minus, i.e. negate the expression.
* {{{
* // Select the amount column and negates all values.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 6bd5804196..1f9f7118c3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -211,15 +211,34 @@ class ColumnExpressionSuite extends QueryTest {
checkAnswer(
testData.select($"a".isNaN, $"b".isNaN),
- Row(true, true) :: Row(true, true) :: Row(true, true) :: Row(false, false) :: Nil)
+ Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil)
checkAnswer(
testData.select(isNaN($"a"), isNaN($"b")),
- Row(true, true) :: Row(true, true) :: Row(true, true) :: Row(false, false) :: Nil)
+ Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil)
checkAnswer(
ctx.sql("select isnan(15), isnan('invalid')"),
- Row(false, true))
+ Row(false, false))
+ }
+
+ test("nanvl") {
+ val testData = ctx.createDataFrame(ctx.sparkContext.parallelize(
+ Row(null, 3.0, Double.NaN, Double.PositiveInfinity) :: Nil),
+ StructType(Seq(StructField("a", DoubleType), StructField("b", DoubleType),
+ StructField("c", DoubleType), StructField("d", DoubleType))))
+
+ checkAnswer(
+ testData.select(
+ nanvl($"a", lit(5)), nanvl($"b", lit(10)),
+ nanvl($"c", lit(null).cast(DoubleType)), nanvl($"d", lit(10))),
+ Row(null, 3.0, null, Double.PositiveInfinity)
+ )
+ testData.registerTempTable("t")
+ checkAnswer(
+ ctx.sql("select nanvl(a, 5), nanvl(b, 10), nanvl(c, null), nanvl(d, 10) from t"),
+ Row(null, 3.0, null, Double.PositiveInfinity)
+ )
}
test("===") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
index 495701d4f6..dbe3b44ee2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
@@ -30,8 +30,10 @@ class DataFrameNaFunctionsSuite extends QueryTest {
("Bob", 16, 176.5),
("Alice", null, 164.3),
("David", 60, null),
+ ("Nina", 25, Double.NaN),
("Amy", null, null),
- (null, null, null)).toDF("name", "age", "height")
+ (null, null, null)
+ ).toDF("name", "age", "height")
}
test("drop") {
@@ -39,12 +41,12 @@ class DataFrameNaFunctionsSuite extends QueryTest {
val rows = input.collect()
checkAnswer(
- input.na.drop("name" :: Nil),
- rows(0) :: rows(1) :: rows(2) :: rows(3) :: Nil)
+ input.na.drop("name" :: Nil).select("name"),
+ Row("Bob") :: Row("Alice") :: Row("David") :: Row("Nina") :: Row("Amy") :: Nil)
checkAnswer(
- input.na.drop("age" :: Nil),
- rows(0) :: rows(2) :: Nil)
+ input.na.drop("age" :: Nil).select("name"),
+ Row("Bob") :: Row("David") :: Row("Nina") :: Nil)
checkAnswer(
input.na.drop("age" :: "height" :: Nil),
@@ -67,8 +69,8 @@ class DataFrameNaFunctionsSuite extends QueryTest {
val rows = input.collect()
checkAnswer(
- input.na.drop("all"),
- rows(0) :: rows(1) :: rows(2) :: rows(3) :: Nil)
+ input.na.drop("all").select("name"),
+ Row("Bob") :: Row("Alice") :: Row("David") :: Row("Nina") :: Row("Amy") :: Nil)
checkAnswer(
input.na.drop("any"),
@@ -79,8 +81,8 @@ class DataFrameNaFunctionsSuite extends QueryTest {
rows(0) :: Nil)
checkAnswer(
- input.na.drop("all", Seq("age", "height")),
- rows(0) :: rows(1) :: rows(2) :: Nil)
+ input.na.drop("all", Seq("age", "height")).select("name"),
+ Row("Bob") :: Row("Alice") :: Row("David") :: Row("Nina") :: Nil)
}
test("drop with threshold") {
@@ -108,6 +110,7 @@ class DataFrameNaFunctionsSuite extends QueryTest {
Row("Bob", 16, 176.5) ::
Row("Alice", 50, 164.3) ::
Row("David", 60, 50.6) ::
+ Row("Nina", 25, 50.6) ::
Row("Amy", 50, 50.6) ::
Row(null, 50, 50.6) :: Nil)
@@ -117,17 +120,19 @@ class DataFrameNaFunctionsSuite extends QueryTest {
// string
checkAnswer(
input.na.fill("unknown").select("name"),
- Row("Bob") :: Row("Alice") :: Row("David") :: Row("Amy") :: Row("unknown") :: Nil)
+ Row("Bob") :: Row("Alice") :: Row("David") ::
+ Row("Nina") :: Row("Amy") :: Row("unknown") :: Nil)
assert(input.na.fill("unknown").columns.toSeq === input.columns.toSeq)
// fill double with subset columns
checkAnswer(
- input.na.fill(50.6, "age" :: Nil),
- Row("Bob", 16, 176.5) ::
- Row("Alice", 50, 164.3) ::
- Row("David", 60, null) ::
- Row("Amy", 50, null) ::
- Row(null, 50, null) :: Nil)
+ input.na.fill(50.6, "age" :: Nil).select("name", "age"),
+ Row("Bob", 16) ::
+ Row("Alice", 50) ::
+ Row("David", 60) ::
+ Row("Nina", 25) ::
+ Row("Amy", 50) ::
+ Row(null, 50) :: Nil)
// fill string with subset columns
checkAnswer(
@@ -164,29 +169,27 @@ class DataFrameNaFunctionsSuite extends QueryTest {
16 -> 61,
60 -> 6,
164.3 -> 461.3 // Alice is really tall
- ))
+ )).collect()
- checkAnswer(
- out,
- Row("Bob", 61, 176.5) ::
- Row("Alice", null, 461.3) ::
- Row("David", 6, null) ::
- Row("Amy", null, null) ::
- Row(null, null, null) :: Nil)
+ assert(out(0) === Row("Bob", 61, 176.5))
+ assert(out(1) === Row("Alice", null, 461.3))
+ assert(out(2) === Row("David", 6, null))
+ assert(out(3).get(2).asInstanceOf[Double].isNaN)
+ assert(out(4) === Row("Amy", null, null))
+ assert(out(5) === Row(null, null, null))
// Replace only the age column
val out1 = input.na.replace("age", Map(
16 -> 61,
60 -> 6,
164.3 -> 461.3 // Alice is really tall
- ))
-
- checkAnswer(
- out1,
- Row("Bob", 61, 176.5) ::
- Row("Alice", null, 164.3) ::
- Row("David", 6, null) ::
- Row("Amy", null, null) ::
- Row(null, null, null) :: Nil)
+ )).collect()
+
+ assert(out1(0) === Row("Bob", 61, 176.5))
+ assert(out1(1) === Row("Alice", null, 164.3))
+ assert(out1(2) === Row("David", 6, null))
+ assert(out1(3).get(2).asInstanceOf[Double].isNaN)
+ assert(out1(4) === Row("Amy", null, null))
+ assert(out1(5) === Row(null, null, null))
}
}