diff options
Diffstat (limited to 'sql/core')
3 files changed, 55 insertions, 33 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index bdcdf0c61f..c856d3099f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -64,7 +64,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @return the approximate quantiles at the given probabilities * * @note null and NaN values will be removed from the numerical column before calculation. If - * the dataframe is empty or all rows contain null or NaN, null is returned. + * the dataframe is empty or the column only contains null or NaN, an empty array is returned. * * @since 2.0.0 */ @@ -72,8 +72,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { col: String, probabilities: Array[Double], relativeError: Double): Array[Double] = { - val res = approxQuantile(Array(col), probabilities, relativeError) - Option(res).map(_.head).orNull + approxQuantile(Array(col), probabilities, relativeError).head } /** @@ -89,8 +88,8 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * Note that values greater than 1 are accepted but give the same result as 1. * @return the approximate quantiles at the given probabilities of each column * - * @note Rows containing any null or NaN values will be removed before calculation. If - * the dataframe is empty or all rows contain null or NaN, null is returned. + * @note null and NaN values will be ignored in numerical columns before calculation. For + * columns only containing null or NaN values, an empty array is returned. * * @since 2.2.0 */ @@ -98,13 +97,11 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { cols: Array[String], probabilities: Array[Double], relativeError: Double): Array[Array[Double]] = { - // TODO: Update NaN/null handling to keep consistent with the single-column version - try { - StatFunctions.multipleApproxQuantiles(df.select(cols.map(col): _*).na.drop(), cols, - probabilities, relativeError).map(_.toArray).toArray - } catch { - case e: NoSuchElementException => null - } + StatFunctions.multipleApproxQuantiles( + df.select(cols.map(col): _*), + cols, + probabilities, + relativeError).map(_.toArray).toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index c3d8859cb7..1debad03c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -54,6 +54,9 @@ object StatFunctions extends Logging { * Note that values greater than 1 are accepted but give the same result as 1. * * @return for each column, returns the requested approximations + * + * @note null and NaN values will be ignored in numerical columns before calculation. For + * a column only containing null or NaN values, an empty array is returned. */ def multipleApproxQuantiles( df: DataFrame, @@ -78,7 +81,10 @@ object StatFunctions extends Logging { def apply(summaries: Array[QuantileSummaries], row: Row): Array[QuantileSummaries] = { var i = 0 while (i < summaries.length) { - summaries(i) = summaries(i).insert(row.getDouble(i)) + if (!row.isNullAt(i)) { + val v = row.getDouble(i) + if (!v.isNaN) summaries(i) = summaries(i).insert(v) + } i += 1 } summaries @@ -91,7 +97,7 @@ object StatFunctions extends Logging { } val summaries = df.select(columns: _*).rdd.aggregate(emptySummaries)(apply, merge) - summaries.map { summary => probabilities.map(summary.query) } + summaries.map { summary => probabilities.flatMap(summary.query) } } /** Calculate the Pearson Correlation Coefficient for the given columns */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index d0910e618a..97890a035a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -171,15 +171,6 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), -1.0) } assert(e2.getMessage.contains("Relative Error must be non-negative")) - - // return null if the dataset is empty - val res1 = df.selectExpr("*").limit(0) - .stat.approxQuantile("singles", Array(q1, q2), epsilons.head) - assert(res1 === null) - - val res2 = df.selectExpr("*").limit(0) - .stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), epsilons.head) - assert(res2 === null) } test("approximate quantile 2: test relativeError greater than 1 return the same result as 1") { @@ -214,20 +205,48 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { val q1 = 0.5 val q2 = 0.8 val epsilon = 0.1 - val rows = spark.sparkContext.parallelize(Seq(Row(Double.NaN, 1.0), Row(1.0, 1.0), - Row(-1.0, Double.NaN), Row(Double.NaN, Double.NaN), Row(null, null), Row(null, 1.0), - Row(-1.0, null), Row(Double.NaN, null))) + val rows = spark.sparkContext.parallelize(Seq(Row(Double.NaN, 1.0, Double.NaN), + Row(1.0, -1.0, null), Row(-1.0, Double.NaN, null), Row(Double.NaN, Double.NaN, null), + Row(null, null, Double.NaN), Row(null, 1.0, null), Row(-1.0, null, Double.NaN), + Row(Double.NaN, null, null))) val schema = StructType(Seq(StructField("input1", DoubleType, nullable = true), - StructField("input2", DoubleType, nullable = true))) + StructField("input2", DoubleType, nullable = true), + StructField("input3", DoubleType, nullable = true))) val dfNaN = spark.createDataFrame(rows, schema) - val resNaN = dfNaN.stat.approxQuantile("input1", Array(q1, q2), epsilon) - assert(resNaN.count(_.isNaN) === 0) - assert(resNaN.count(_ == null) === 0) - val resNaN2 = dfNaN.stat.approxQuantile(Array("input1", "input2"), + val resNaN1 = dfNaN.stat.approxQuantile("input1", Array(q1, q2), epsilon) + assert(resNaN1.count(_.isNaN) === 0) + assert(resNaN1.count(_ == null) === 0) + + val resNaN2 = dfNaN.stat.approxQuantile("input2", Array(q1, q2), epsilon) + assert(resNaN2.count(_.isNaN) === 0) + assert(resNaN2.count(_ == null) === 0) + + val resNaN3 = dfNaN.stat.approxQuantile("input3", Array(q1, q2), epsilon) + assert(resNaN3.isEmpty) + + val resNaNAll = dfNaN.stat.approxQuantile(Array("input1", "input2", "input3"), Array(q1, q2), epsilon) - assert(resNaN2.flatten.count(_.isNaN) === 0) - assert(resNaN2.flatten.count(_ == null) === 0) + assert(resNaNAll.flatten.count(_.isNaN) === 0) + assert(resNaNAll.flatten.count(_ == null) === 0) + + assert(resNaN1(0) === resNaNAll(0)(0)) + assert(resNaN1(1) === resNaNAll(0)(1)) + assert(resNaN2(0) === resNaNAll(1)(0)) + assert(resNaN2(1) === resNaNAll(1)(1)) + + // return empty array for columns only containing null or NaN values + assert(resNaNAll(2).isEmpty) + + // return empty array if the dataset is empty + val res1 = dfNaN.selectExpr("*").limit(0) + .stat.approxQuantile("input1", Array(q1, q2), epsilon) + assert(res1.isEmpty) + + val res2 = dfNaN.selectExpr("*").limit(0) + .stat.approxQuantile(Array("input1", "input2"), Array(q1, q2), epsilon) + assert(res2(0).isEmpty) + assert(res2(1).isEmpty) } test("crosstab") { |