diff options
author | Zheng RuiFeng <ruifengz@foxmail.com> | 2017-02-16 09:42:13 -0800 |
---|---|---|
committer | Xiao Li <gatorsmile@gmail.com> | 2017-02-16 09:42:13 -0800 |
commit | 54a30c8a70c86294059e6eb6b30cb81978b47b54 (patch) | |
tree | 487ac72cd69144443ce55ca433fac2c40b69e134 /sql | |
parent | 3b4376876fabf7df4bd245dcf755222f4fe5f190 (diff) | |
download | spark-54a30c8a70c86294059e6eb6b30cb81978b47b54.tar.gz spark-54a30c8a70c86294059e6eb6b30cb81978b47b54.tar.bz2 spark-54a30c8a70c86294059e6eb6b30cb81978b47b54.zip |
[SPARK-19436][SQL] Add missing tests for approxQuantile
## What changes were proposed in this pull request?
1. Check the behavior with illegal `quantiles` and `relativeError` values.
2. Add tests for `relativeError` > 1.
3. Update tests for `null` data.
4. Update some docs for javadoc8 compatibility.
## How was this patch tested?
Local testing in spark-shell.
Author: Zheng RuiFeng <ruifengz@foxmail.com>
Author: Ruifeng Zheng <ruifengz@foxmail.com>
Closes #16776 from zhengruifeng/fix_approxQuantile.
Diffstat (limited to 'sql')
3 files changed, 88 insertions, 23 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 2b782fd75c..bdcdf0c61f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -58,12 +58,13 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @param probabilities a list of quantile probabilities * Each number must belong to [0, 1]. * For example 0 is the minimum, 0.5 is the median, 1 is the maximum. - * @param relativeError The relative target precision to achieve (greater or equal to 0). + * @param relativeError The relative target precision to achieve (greater than or equal to 0). * If set to zero, the exact quantiles are computed, which could be very expensive. * Note that values greater than 1 are accepted but give the same result as 1. * @return the approximate quantiles at the given probabilities * - * @note NaN values will be removed from the numerical column before calculation + * @note null and NaN values will be removed from the numerical column before calculation. If + * the dataframe is empty or all rows contain null or NaN, null is returned. * * @since 2.0.0 */ @@ -71,27 +72,25 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { col: String, probabilities: Array[Double], relativeError: Double): Array[Double] = { - StatFunctions.multipleApproxQuantiles(df.select(col).na.drop(), - Seq(col), probabilities, relativeError).head.toArray + val res = approxQuantile(Array(col), probabilities, relativeError) + Option(res).map(_.head).orNull } /** * Calculates the approximate quantiles of numerical columns of a DataFrame. - * @see [[DataFrameStatsFunctions.approxQuantile(col:Str* approxQuantile]] for - * detailed description. + * @see `approxQuantile(col:Str* approxQuantile)` for detailed description. 
* - * Note that rows containing any null or NaN values values will be removed before - * calculation. * @param cols the names of the numerical columns * @param probabilities a list of quantile probabilities * Each number must belong to [0, 1]. * For example 0 is the minimum, 0.5 is the median, 1 is the maximum. - * @param relativeError The relative target precision to achieve (>= 0). + * @param relativeError The relative target precision to achieve (greater than or equal to 0). * If set to zero, the exact quantiles are computed, which could be very expensive. * Note that values greater than 1 are accepted but give the same result as 1. * @return the approximate quantiles at the given probabilities of each column * - * @note Rows containing any NaN values will be removed before calculation + * @note Rows containing any null or NaN values will be removed before calculation. If + * the dataframe is empty or all rows contain null or NaN, null is returned. * * @since 2.2.0 */ @@ -99,8 +98,13 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { cols: Array[String], probabilities: Array[Double], relativeError: Double): Array[Array[Double]] = { - StatFunctions.multipleApproxQuantiles(df.select(cols.map(col): _*).na.drop(), cols, - probabilities, relativeError).map(_.toArray).toArray + // TODO: Update NaN/null handling to keep consistent with the single-column version + try { + StatFunctions.multipleApproxQuantiles(df.select(cols.map(col): _*).na.drop(), cols, + probabilities, relativeError).map(_.toArray).toArray + } catch { + case e: NoSuchElementException => null + } } @@ -112,7 +116,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { probabilities: List[Double], relativeError: Double): java.util.List[java.util.List[Double]] = { approxQuantile(cols.toArray, probabilities.toArray, relativeError) - .map(_.toList.asJava).toList.asJava + .map(_.toList.asJava).toList.asJava } /** diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 2b2e706125..c3d8859cb7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -49,7 +49,7 @@ object StatFunctions extends Logging { * @param probabilities a list of quantile probabilities * Each number must belong to [0, 1]. * For example 0 is the minimum, 0.5 is the median, 1 is the maximum. - * @param relativeError The relative target precision to achieve (>= 0). + * @param relativeError The relative target precision to achieve (greater than or equal 0). * If set to zero, the exact quantiles are computed, which could be very expensive. * Note that values greater than 1 are accepted but give the same result as 1. * @@ -60,6 +60,8 @@ object StatFunctions extends Logging { cols: Seq[String], probabilities: Seq[Double], relativeError: Double): Seq[Seq[Double]] = { + require(relativeError >= 0, + s"Relative Error must be non-negative but got $relativeError") val columns: Seq[Column] = cols.map { colName => val field = df.schema(colName) require(field.dataType.isInstanceOf[NumericType], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index f52b18e27b..d0910e618a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.stat.StatFunctions import org.apache.spark.sql.functions.col import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.DoubleType +import org.apache.spark.sql.types.{DoubleType, StructField, StructType} class DataFrameStatSuite extends 
QueryTest with SharedSQLContext { import testImplicits._ @@ -159,16 +159,75 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { assert(math.abs(md1 - 2 * q1 * n) < error_double) assert(math.abs(md2 - 2 * q2 * n) < error_double) } - // test approxQuantile on NaN values - val dfNaN = Seq(Double.NaN, 1.0, Double.NaN, Double.NaN).toDF("input") - val resNaN = dfNaN.stat.approxQuantile("input", Array(q1, q2), epsilons.head) + + // quantile should be in the range [0.0, 1.0] + val e = intercept[IllegalArgumentException] { + df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2, -0.1), epsilons.head) + } + assert(e.getMessage.contains("quantile should be in the range [0.0, 1.0]")) + + // relativeError should be non-negative + val e2 = intercept[IllegalArgumentException] { + df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), -1.0) + } + assert(e2.getMessage.contains("Relative Error must be non-negative")) + + // return null if the dataset is empty + val res1 = df.selectExpr("*").limit(0) + .stat.approxQuantile("singles", Array(q1, q2), epsilons.head) + assert(res1 === null) + + val res2 = df.selectExpr("*").limit(0) + .stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), epsilons.head) + assert(res2 === null) + } + + test("approximate quantile 2: test relativeError greater than 1 return the same result as 1") { + val n = 1000 + val df = Seq.tabulate(n)(i => (i, 2.0 * i)).toDF("singles", "doubles") + + val q1 = 0.5 + val q2 = 0.8 + val epsilons = List(2.0, 5.0, 100.0) + + val Array(single1_1) = df.stat.approxQuantile("singles", Array(q1), 1.0) + val Array(s1_1, s2_1) = df.stat.approxQuantile("singles", Array(q1, q2), 1.0) + val Array(Array(ms1_1, ms2_1), Array(md1_1, md2_1)) = + df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), 1.0) + + for (epsilon <- epsilons) { + val Array(single1) = df.stat.approxQuantile("singles", Array(q1), epsilon) + val Array(s1, s2) = df.stat.approxQuantile("singles", Array(q1, 
q2), epsilon) + val Array(Array(ms1, ms2), Array(md1, md2)) = + df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), epsilon) + assert(single1_1 === single1) + assert(s1_1 === s1) + assert(s2_1 === s2) + assert(ms1_1 === ms1) + assert(ms2_1 === ms2) + assert(md1_1 === md1) + assert(md2_1 === md2) + } + } + + test("approximate quantile 3: test on NaN and null values") { + val q1 = 0.5 + val q2 = 0.8 + val epsilon = 0.1 + val rows = spark.sparkContext.parallelize(Seq(Row(Double.NaN, 1.0), Row(1.0, 1.0), + Row(-1.0, Double.NaN), Row(Double.NaN, Double.NaN), Row(null, null), Row(null, 1.0), + Row(-1.0, null), Row(Double.NaN, null))) + val schema = StructType(Seq(StructField("input1", DoubleType, nullable = true), + StructField("input2", DoubleType, nullable = true))) + val dfNaN = spark.createDataFrame(rows, schema) + val resNaN = dfNaN.stat.approxQuantile("input1", Array(q1, q2), epsilon) assert(resNaN.count(_.isNaN) === 0) - // test approxQuantile on multi-column NaN values - val dfNaN2 = Seq((Double.NaN, 1.0), (1.0, 1.0), (-1.0, Double.NaN), (Double.NaN, Double.NaN)) - .toDF("input1", "input2") - val resNaN2 = dfNaN2.stat.approxQuantile(Array("input1", "input2"), - Array(q1, q2), epsilons.head) + assert(resNaN.count(_ == null) === 0) + + val resNaN2 = dfNaN.stat.approxQuantile(Array("input1", "input2"), + Array(q1, q2), epsilon) assert(resNaN2.flatten.count(_.isNaN) === 0) + assert(resNaN2.flatten.count(_ == null) === 0) } test("crosstab") { |