diff options
author | JihongMa <linlin200605@gmail.com> | 2015-11-18 13:03:37 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-11-18 13:03:37 -0800 |
commit | 09ad9533d5760652de59fa4830c24cb8667958ac (patch) | |
tree | 6e6023e1d2df2ccf565f9df1bf26e82904a70363 | |
parent | 7c5b641808740ba5eed05ba8204cdbaf3fc579f5 (diff) | |
download | spark-09ad9533d5760652de59fa4830c24cb8667958ac.tar.gz spark-09ad9533d5760652de59fa4830c24cb8667958ac.tar.bz2 spark-09ad9533d5760652de59fa4830c24cb8667958ac.zip |
[SPARK-11720][SQL][ML] Handle edge cases when count = 0 or 1 for Stats function
return Double.NaN for mean/average when count == 0 for all numeric types that is converted to Double, Decimal type continue to return null.
Author: JihongMa <linlin200605@gmail.com>
Closes #9705 from JihongMA/SPARK-11720.
8 files changed, 53 insertions, 25 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ad6ad0235a..0dd75ba7ca 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -761,7 +761,7 @@ class DataFrame(object): +-------+------------------+-----+ | count| 2| 2| | mean| 3.5| null| - | stddev|2.1213203435596424| NaN| + | stddev|2.1213203435596424| null| | min| 2|Alice| | max| 5| Bob| +-------+------------------+-----+ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index de5872ab11..d07d4c338c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -206,7 +206,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w * @param centralMoments Length `momentOrder + 1` array of central moments (un-normalized) * needed to compute the aggregate stat. */ - def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Double + def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Any override final def eval(buffer: InternalRow): Any = { val n = buffer.getDouble(nOffset) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala index 8fa3aac9f1..c2bf2cb941 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala @@ -37,16 +37,17 @@ case class Kurtosis(child: Expression, override protected val momentOrder = 4 // NOTE: this is the formula for excess kurtosis, which is default for R and SciPy - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") val m2 = moments(2) val m4 = moments(4) - if (n == 0.0 || m2 == 0.0) { + if (n == 0.0) { + null + } else if (m2 == 0.0) { Double.NaN - } - else { + } else { n * m4 / (m2 * m2) - 3.0 } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala index e1c01a5b82..9411bcea25 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala @@ -36,16 +36,17 @@ case class Skewness(child: Expression, override protected val momentOrder = 3 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") val m2 = moments(2) val m3 = moments(3) - if (n == 0.0 || m2 == 0.0) { + if (n == 0.0) { + null + } else if (m2 == 0.0) { Double.NaN - } - else { + } else { math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala index 05dd5e3b22..eec79a9033 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala @@ -36,11 +36,17 @@ case class StddevSamp(child: Expression, override protected val momentOrder = 2 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - if (n == 0.0 || n == 1.0) Double.NaN else math.sqrt(moments(2) / (n - 1.0)) + if (n == 0.0) { + null + } else if (n == 1.0) { + Double.NaN + } else { + math.sqrt(moments(2) / (n - 1.0)) + } } } @@ -62,10 +68,14 @@ case class StddevPop( override protected val momentOrder = 2 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - if (n == 0.0) Double.NaN else math.sqrt(moments(2) / n) + if (n == 0.0) { + null + } else { + math.sqrt(moments(2) / n) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala index ede2da2805..cf3a740305 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala @@ -36,11 +36,17 @@ case class VarianceSamp(child: Expression, override protected val momentOrder = 2 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / (n - 1.0) + if (n == 0.0) { + null + } else if (n == 1.0) { + Double.NaN + } else { + moments(2) / (n - 1.0) + } } } @@ -62,10 +68,14 @@ case class VariancePop( override protected val momentOrder = 2 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - if (n == 0.0) Double.NaN else moments(2) / n + if (n == 0.0) { + null + } else { + moments(2) / n + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 432e8d1762..71adf2148a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -205,7 +205,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( emptyTableData.agg(stddev('a), stddev_pop('a), stddev_samp('a)), - Row(Double.NaN, Double.NaN, Double.NaN)) + Row(null, null, null)) } test("zero sum") { @@ -244,17 +244,23 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("zero moments") { val input = Seq((1, 2)).toDF("a", "b") checkAnswer( - input.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)), - Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN)) + input.agg(stddev('a), stddev_samp('a), stddev_pop('a), variance('a), + var_samp('a), var_pop('a), skewness('a), kurtosis('a)), + Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN, 0.0, + Double.NaN, Double.NaN)) checkAnswer( input.agg( + expr("stddev(a)"), + expr("stddev_samp(a)"), + expr("stddev_pop(a)"), expr("variance(a)"), expr("var_samp(a)"), expr("var_pop(a)"), expr("skewness(a)"), expr("kurtosis(a)")), - Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN)) + Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN, 0.0, + Double.NaN, Double.NaN)) } test("null moments") { @@ -262,7 +268,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer( emptyTableData.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)), - Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)) + Row(null, null, null, null, null)) checkAnswer( emptyTableData.agg( @@ -271,6 +277,6 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { expr("var_pop(a)"), expr("skewness(a)"), expr("kurtosis(a)")), - Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)) + Row(null, null, null, null, null)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5a7f24684d..6399b0165c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -459,7 +459,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val emptyDescribeResult = Seq( Row("count", "0", "0"), Row("mean", null, null), - Row("stddev", "NaN", "NaN"), + Row("stddev", null, null), Row("min", null, null), Row("max", null, null)) |