path: root/sql/core/src/main/scala/org
author: Zheng RuiFeng <ruifengz@foxmail.com>  2017-02-16 09:42:13 -0800
committer: Xiao Li <gatorsmile@gmail.com>  2017-02-16 09:42:13 -0800
commit: 54a30c8a70c86294059e6eb6b30cb81978b47b54 (patch)
tree: 487ac72cd69144443ce55ca433fac2c40b69e134 /sql/core/src/main/scala/org
parent: 3b4376876fabf7df4bd245dcf755222f4fe5f190 (diff)
[SPARK-19436][SQL] Add missing tests for approxQuantile
## What changes were proposed in this pull request?

1. Check the behavior with illegal `quantiles` and `relativeError`.
2. Add tests for `relativeError` greater than 1.
3. Update tests for `null` data.
4. Update some docs for javadoc8.

## How was this patch tested?

Local tests in spark-shell.

Author: Zheng RuiFeng <ruifengz@foxmail.com>
Author: Ruifeng Zheng <ruifengz@foxmail.com>

Closes #16776 from zhengruifeng/fix_approxQuantile.
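Since the change is mostly about edge-case behavior, a minimal spark-shell style sketch helps to see what the new tests pin down. The SparkSession, the data, and the column name `x` below are illustrative and not part of the patch:

```scala
import org.apache.spark.sql.SparkSession

// Illustrative setup; in spark-shell, `spark` and its implicits already exist.
val spark = SparkSession.builder().master("local[2]").appName("approxQuantileSketch").getOrCreate()
import spark.implicits._

// A numeric column containing ordinary values plus a null and a NaN.
val df = Seq(Some(1.0), Some(2.0), Some(3.0), None, Some(Double.NaN)).toDF("x")

// null and NaN values are dropped before the quantile sketch is built.
val median = df.stat.approxQuantile("x", Array(0.5), 0.01)   // e.g. Array(2.0)

// A relativeError greater than 1 is accepted and behaves like relativeError = 1.
val coarse = df.stat.approxQuantile("x", Array(0.5), 10.0)

// If the column is empty after dropping null/NaN, null is returned rather than an exception.
val allNull = df.filter($"x".isNull).stat.approxQuantile("x", Array(0.5), 0.01)
assert(allNull == null)
```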
Diffstat (limited to 'sql/core/src/main/scala/org')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala | 30
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala | 4
2 files changed, 20 insertions, 14 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 2b782fd75c..bdcdf0c61f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -58,12 +58,13 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* @param probabilities a list of quantile probabilities
* Each number must belong to [0, 1].
* For example 0 is the minimum, 0.5 is the median, 1 is the maximum.
- * @param relativeError The relative target precision to achieve (greater or equal to 0).
+ * @param relativeError The relative target precision to achieve (greater than or equal to 0).
* If set to zero, the exact quantiles are computed, which could be very expensive.
* Note that values greater than 1 are accepted but give the same result as 1.
* @return the approximate quantiles at the given probabilities
*
- * @note NaN values will be removed from the numerical column before calculation
+ * @note null and NaN values will be removed from the numerical column before calculation. If
+ * the dataframe is empty or all rows contain null or NaN, null is returned.
*
* @since 2.0.0
*/
@@ -71,27 +72,25 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
col: String,
probabilities: Array[Double],
relativeError: Double): Array[Double] = {
- StatFunctions.multipleApproxQuantiles(df.select(col).na.drop(),
- Seq(col), probabilities, relativeError).head.toArray
+ val res = approxQuantile(Array(col), probabilities, relativeError)
+ Option(res).map(_.head).orNull
}
/**
* Calculates the approximate quantiles of numerical columns of a DataFrame.
- * @see [[DataFrameStatsFunctions.approxQuantile(col:Str* approxQuantile]] for
- * detailed description.
+ * @see `approxQuantile(col:Str* approxQuantile)` for detailed description.
*
- * Note that rows containing any null or NaN values values will be removed before
- * calculation.
* @param cols the names of the numerical columns
* @param probabilities a list of quantile probabilities
* Each number must belong to [0, 1].
* For example 0 is the minimum, 0.5 is the median, 1 is the maximum.
- * @param relativeError The relative target precision to achieve (>= 0).
+ * @param relativeError The relative target precision to achieve (greater than or equal to 0).
* If set to zero, the exact quantiles are computed, which could be very expensive.
* Note that values greater than 1 are accepted but give the same result as 1.
* @return the approximate quantiles at the given probabilities of each column
*
- * @note Rows containing any NaN values will be removed before calculation
+ * @note Rows containing any null or NaN values will be removed before calculation. If
+ * the dataframe is empty or all rows contain null or NaN, null is returned.
*
* @since 2.2.0
*/
@@ -99,8 +98,13 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
cols: Array[String],
probabilities: Array[Double],
relativeError: Double): Array[Array[Double]] = {
- StatFunctions.multipleApproxQuantiles(df.select(cols.map(col): _*).na.drop(), cols,
- probabilities, relativeError).map(_.toArray).toArray
+ // TODO: Update NaN/null handling to keep consistent with the single-column version
+ try {
+ StatFunctions.multipleApproxQuantiles(df.select(cols.map(col): _*).na.drop(), cols,
+ probabilities, relativeError).map(_.toArray).toArray
+ } catch {
+ case e: NoSuchElementException => null
+ }
}
@@ -112,7 +116,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
probabilities: List[Double],
relativeError: Double): java.util.List[java.util.List[Double]] = {
approxQuantile(cols.toArray, probabilities.toArray, relativeError)
- .map(_.toList.asJava).toList.asJava
+ .map(_.toList.asJava).toList.asJava
}
/**
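The hunks above also reroute the single-column call through the multi-column overload added in 2.2.0. A brief sketch of that overload under the same assumptions (spark-shell, with `spark` and its implicits in scope; the two-column DataFrame is hypothetical):

```scala
// Hypothetical two-column DataFrame.
val df2 = Seq((1.0, 10.0), (2.0, 20.0), (3.0, 30.0)).toDF("a", "b")

// One Array[Double] of quantiles per requested column, in the order the columns were given;
// relativeError = 0.0 asks for exact quantiles.
val qs: Array[Array[Double]] = df2.stat.approxQuantile(Array("a", "b"), Array(0.25, 0.75), 0.0)

// Rows with any null or NaN in the selected columns are dropped first; if nothing is left,
// null is returned, mirroring the single-column behavior noted above.
```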
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 2b2e706125..c3d8859cb7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -49,7 +49,7 @@ object StatFunctions extends Logging {
* @param probabilities a list of quantile probabilities
* Each number must belong to [0, 1].
* For example 0 is the minimum, 0.5 is the median, 1 is the maximum.
- * @param relativeError The relative target precision to achieve (>= 0).
+ * @param relativeError The relative target precision to achieve (greater than or equal to 0).
* If set to zero, the exact quantiles are computed, which could be very expensive.
* Note that values greater than 1 are accepted but give the same result as 1.
*
@@ -60,6 +60,8 @@ object StatFunctions extends Logging {
cols: Seq[String],
probabilities: Seq[Double],
relativeError: Double): Seq[Seq[Double]] = {
+ require(relativeError >= 0,
+ s"Relative Error must be non-negative but got $relativeError")
val columns: Seq[Column] = cols.map { colName =>
val field = df.schema(colName)
require(field.dataType.isInstanceOf[NumericType],
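Finally, a sketch of what the new `require` guard above means for callers, reusing the illustrative `df` and column `x` from the first snippet: a negative `relativeError` is now rejected up front instead of being passed on to the quantile computation.

```scala
import scala.util.Try

// A negative relativeError fails fast with an IllegalArgumentException
// ("Relative Error must be non-negative ...").
val rejected = Try(df.stat.approxQuantile("x", Array(0.5), -0.1))
assert(rejected.isFailure)
```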