path: root/sql/core
author     Zheng RuiFeng <ruifengz@foxmail.com>  2017-03-20 18:25:59 -0700
committer  Xiao Li <gatorsmile@gmail.com>        2017-03-20 18:25:59 -0700
commit     10691d36de902e3771af20aed40336b4f99de719 (patch)
tree       d823a3f7e499b99c1447d0bf6a4102d6ece64a2a /sql/core
parent     c2d1761a57f5d175913284533b3d0417e8718688 (diff)
download   spark-10691d36de902e3771af20aed40336b4f99de719.tar.gz
           spark-10691d36de902e3771af20aed40336b4f99de719.tar.bz2
           spark-10691d36de902e3771af20aed40336b4f99de719.zip
[SPARK-19573][SQL] Make NaN/null handling consistent in approxQuantile
## What changes were proposed in this pull request?

Update `StatFunctions.multipleApproxQuantiles` to handle NaN/null.

## How was this patch tested?

Existing tests and added tests.

Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes #16971 from zhengruifeng/quantiles_nan.
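As a quick orientation (not part of the commit), the sketch below shows the behavior this patch settles on: null and NaN values are ignored per column, and a column left with no usable values yields an empty array instead of null. It assumes a local SparkSession and a Spark build with this patch applied; the column names `a` and `b` are made up for the example.

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

object QuantileNaNDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("quantile-nan-demo").getOrCreate()

    // Column `a` keeps usable values after null/NaN are ignored; `b` does not.
    val rows = spark.sparkContext.parallelize(Seq(
      Row(1.0, Double.NaN), Row(2.0, null), Row(Double.NaN, null), Row(null, Double.NaN)))
    val schema = StructType(Seq(
      StructField("a", DoubleType, nullable = true),
      StructField("b", DoubleType, nullable = true)))
    val df = spark.createDataFrame(rows, schema)

    // 1.0 and 2.0 survive the filtering, so a non-empty result comes back.
    assert(df.stat.approxQuantile("a", Array(0.5), 0.1).nonEmpty)
    // Only null/NaN: the result is now an empty array rather than null.
    assert(df.stat.approxQuantile("b", Array(0.5), 0.1).isEmpty)

    spark.stop()
  }
}
```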
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala       | 21
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala | 10
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala           | 57
3 files changed, 55 insertions(+), 33 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index bdcdf0c61f..c856d3099f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -64,7 +64,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* @return the approximate quantiles at the given probabilities
*
* @note null and NaN values will be removed from the numerical column before calculation. If
- * the dataframe is empty or all rows contain null or NaN, null is returned.
+ * the dataframe is empty or the column only contains null or NaN, an empty array is returned.
*
* @since 2.0.0
*/
@@ -72,8 +72,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
col: String,
probabilities: Array[Double],
relativeError: Double): Array[Double] = {
- val res = approxQuantile(Array(col), probabilities, relativeError)
- Option(res).map(_.head).orNull
+ approxQuantile(Array(col), probabilities, relativeError).head
}
/**
@@ -89,8 +88,8 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* Note that values greater than 1 are accepted but give the same result as 1.
* @return the approximate quantiles at the given probabilities of each column
*
- * @note Rows containing any null or NaN values will be removed before calculation. If
- * the dataframe is empty or all rows contain null or NaN, null is returned.
+ * @note null and NaN values will be ignored in numerical columns before calculation. For
+ * columns only containing null or NaN values, an empty array is returned.
*
* @since 2.2.0
*/
@@ -98,13 +97,11 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
cols: Array[String],
probabilities: Array[Double],
relativeError: Double): Array[Array[Double]] = {
- // TODO: Update NaN/null handling to keep consistent with the single-column version
- try {
- StatFunctions.multipleApproxQuantiles(df.select(cols.map(col): _*).na.drop(), cols,
- probabilities, relativeError).map(_.toArray).toArray
- } catch {
- case e: NoSuchElementException => null
- }
+ StatFunctions.multipleApproxQuantiles(
+ df.select(cols.map(col): _*),
+ cols,
+ probabilities,
+ relativeError).map(_.toArray).toArray
}
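The single-column simplification above works because of the contract the multi-column version now guarantees: the outer array always contains exactly one inner array per requested column, so `.head` can never dereference null. A hedged continuation of the earlier sketch (same hypothetical `df` and columns `a` and `b`):

```scala
// Continuation of the demo above; `df`, `a`, and `b` are the hypothetical
// DataFrame and columns built in the previous sketch.
val res: Array[Array[Double]] =
  df.stat.approxQuantile(Array("a", "b"), Array(0.25, 0.75), 0.05)

assert(res.length == 2)   // one inner array per requested column, never null
assert(res(0).nonEmpty)   // `a` has real values
assert(res(1).isEmpty)    // `b` is all null/NaN: empty array, not null

// Because the outer array always has one entry per column, the single-column
// overload can simply take `.head` instead of guarding against null.
val single: Array[Double] =
  df.stat.approxQuantile(Array("a"), Array(0.25, 0.75), 0.05).head
```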
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index c3d8859cb7..1debad03c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -54,6 +54,9 @@ object StatFunctions extends Logging {
* Note that values greater than 1 are accepted but give the same result as 1.
*
* @return for each column, returns the requested approximations
+ *
+ * @note null and NaN values will be ignored in numerical columns before calculation. For
+ * a column only containing null or NaN values, an empty array is returned.
*/
def multipleApproxQuantiles(
df: DataFrame,
@@ -78,7 +81,10 @@ object StatFunctions extends Logging {
def apply(summaries: Array[QuantileSummaries], row: Row): Array[QuantileSummaries] = {
var i = 0
while (i < summaries.length) {
- summaries(i) = summaries(i).insert(row.getDouble(i))
+ if (!row.isNullAt(i)) {
+ val v = row.getDouble(i)
+ if (!v.isNaN) summaries(i) = summaries(i).insert(v)
+ }
i += 1
}
summaries
@@ -91,7 +97,7 @@ object StatFunctions extends Logging {
}
val summaries = df.select(columns: _*).rdd.aggregate(emptySummaries)(apply, merge)
- summaries.map { summary => probabilities.map(summary.query) }
+ summaries.map { summary => probabilities.flatMap(summary.query) }
}
/** Calculate the Pearson Correlation Coefficient for the given columns */
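The core of the fix is the per-row update loop, which now skips null and NaN before inserting into a column's QuantileSummaries, together with the switch from `map` to `flatMap` when querying; the latter relies on `query` returning an `Option` (that signature lives in the catalyst module, outside this diff, so treat it as an inference here). Below is a simplified, self-contained sketch of the same logic with a toy `Summary` standing in for Spark's QuantileSummaries; all names in it are illustrative:

```scala
// Toy stand-in for QuantileSummaries: it just collects values. An empty
// summary answers queries with None, mirroring the Option-based query.
final case class Summary(values: Vector[Double]) {
  def insert(v: Double): Summary = Summary(values :+ v)
  def query(p: Double): Option[Double] =
    if (values.isEmpty) None
    else Some(values.sorted.apply(((values.size - 1) * p).toInt))
}

// Per-row update mirroring the patched loop: null (None) and NaN are skipped.
def update(summaries: Array[Summary], row: Array[Option[Double]]): Array[Summary] = {
  var i = 0
  while (i < summaries.length) {
    row(i) match {
      case Some(v) if !v.isNaN => summaries(i) = summaries(i).insert(v)
      case _ => // null or NaN: leave this column's summary untouched
    }
    i += 1
  }
  summaries
}

val rows = Seq[Array[Option[Double]]](
  Array(Some(1.0), Some(Double.NaN)),
  Array(Some(Double.NaN), None),
  Array(Some(2.0), None))
val summaries = rows.foldLeft(Array(Summary(Vector.empty), Summary(Vector.empty)))(update)

// flatMap over Option keeps results for populated columns and yields an
// empty array for empty ones, instead of throwing.
val probabilities = Array(0.5)
val result = summaries.map(s => probabilities.flatMap(s.query))
assert(result(0).sameElements(Array(1.0)))  // median of {1.0, 2.0} with floor indexing
assert(result(1).isEmpty)                   // all null/NaN column
```

This is consistent with the removed try/catch in DataFrameStatFunctions: where the old path surfaced an exception on empty input and mapped it to null, the Option-based query simply contributes nothing for an empty summary.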
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index d0910e618a..97890a035a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -171,15 +171,6 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), -1.0)
}
assert(e2.getMessage.contains("Relative Error must be non-negative"))
-
- // return null if the dataset is empty
- val res1 = df.selectExpr("*").limit(0)
- .stat.approxQuantile("singles", Array(q1, q2), epsilons.head)
- assert(res1 === null)
-
- val res2 = df.selectExpr("*").limit(0)
- .stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), epsilons.head)
- assert(res2 === null)
}
test("approximate quantile 2: test relativeError greater than 1 return the same result as 1") {
@@ -214,20 +205,48 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
val q1 = 0.5
val q2 = 0.8
val epsilon = 0.1
- val rows = spark.sparkContext.parallelize(Seq(Row(Double.NaN, 1.0), Row(1.0, 1.0),
- Row(-1.0, Double.NaN), Row(Double.NaN, Double.NaN), Row(null, null), Row(null, 1.0),
- Row(-1.0, null), Row(Double.NaN, null)))
+ val rows = spark.sparkContext.parallelize(Seq(Row(Double.NaN, 1.0, Double.NaN),
+ Row(1.0, -1.0, null), Row(-1.0, Double.NaN, null), Row(Double.NaN, Double.NaN, null),
+ Row(null, null, Double.NaN), Row(null, 1.0, null), Row(-1.0, null, Double.NaN),
+ Row(Double.NaN, null, null)))
val schema = StructType(Seq(StructField("input1", DoubleType, nullable = true),
- StructField("input2", DoubleType, nullable = true)))
+ StructField("input2", DoubleType, nullable = true),
+ StructField("input3", DoubleType, nullable = true)))
val dfNaN = spark.createDataFrame(rows, schema)
- val resNaN = dfNaN.stat.approxQuantile("input1", Array(q1, q2), epsilon)
- assert(resNaN.count(_.isNaN) === 0)
- assert(resNaN.count(_ == null) === 0)
- val resNaN2 = dfNaN.stat.approxQuantile(Array("input1", "input2"),
+ val resNaN1 = dfNaN.stat.approxQuantile("input1", Array(q1, q2), epsilon)
+ assert(resNaN1.count(_.isNaN) === 0)
+ assert(resNaN1.count(_ == null) === 0)
+
+ val resNaN2 = dfNaN.stat.approxQuantile("input2", Array(q1, q2), epsilon)
+ assert(resNaN2.count(_.isNaN) === 0)
+ assert(resNaN2.count(_ == null) === 0)
+
+ val resNaN3 = dfNaN.stat.approxQuantile("input3", Array(q1, q2), epsilon)
+ assert(resNaN3.isEmpty)
+
+ val resNaNAll = dfNaN.stat.approxQuantile(Array("input1", "input2", "input3"),
Array(q1, q2), epsilon)
- assert(resNaN2.flatten.count(_.isNaN) === 0)
- assert(resNaN2.flatten.count(_ == null) === 0)
+ assert(resNaNAll.flatten.count(_.isNaN) === 0)
+ assert(resNaNAll.flatten.count(_ == null) === 0)
+
+ assert(resNaN1(0) === resNaNAll(0)(0))
+ assert(resNaN1(1) === resNaNAll(0)(1))
+ assert(resNaN2(0) === resNaNAll(1)(0))
+ assert(resNaN2(1) === resNaNAll(1)(1))
+
+ // return empty array for columns only containing null or NaN values
+ assert(resNaNAll(2).isEmpty)
+
+ // return empty array if the dataset is empty
+ val res1 = dfNaN.selectExpr("*").limit(0)
+ .stat.approxQuantile("input1", Array(q1, q2), epsilon)
+ assert(res1.isEmpty)
+
+ val res2 = dfNaN.selectExpr("*").limit(0)
+ .stat.approxQuantile(Array("input1", "input2"), Array(q1, q2), epsilon)
+ assert(res2(0).isEmpty)
+ assert(res2(1).isEmpty)
}
test("crosstab") {