author     JihongMa <linlin200605@gmail.com>    2015-11-18 13:03:37 -0800
committer  Xiangrui Meng <meng@databricks.com>  2015-11-18 13:03:37 -0800
commit     09ad9533d5760652de59fa4830c24cb8667958ac (patch)
tree       6e6023e1d2df2ccf565f9df1bf26e82904a70363
parent     7c5b641808740ba5eed05ba8204cdbaf3fc579f5 (diff)
[SPARK-11720][SQL][ML] Handle edge cases when count = 0 or 1 for Stats function
Return Double.NaN for mean/average when count == 0 for all numeric types that are converted to Double; Decimal type continues to return null.

Author: JihongMa <linlin200605@gmail.com>

Closes #9705 from JihongMA/SPARK-11720.
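A quick illustration of the user-visible change (a minimal spark-shell sketch, assuming a `sqlContext` of this era with its implicits in scope; the variable names are hypothetical):

    import sqlContext.implicits._
    import org.apache.spark.sql.functions._

    // count == 0: these aggregates now return null instead of Double.NaN.
    val emptyDF = Seq.empty[(Int, Int)].toDF("a", "b")
    emptyDF.agg(stddev('a), variance('a), skewness('a), kurtosis('a)).show()

    // count == 1: sample statistics remain undefined and still return NaN,
    // while the population variants return 0.0.
    val oneRow = Seq((1, 2)).toDF("a", "b")
    oneRow.agg(stddev_samp('a), stddev_pop('a)).show()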
-rw-r--r--  python/pyspark/sql/dataframe.py |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala |  9
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala |  9
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala | 18
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala | 18
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | 18
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala |  2

8 files changed, 53 insertions(+), 25 deletions(-)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index ad6ad0235a..0dd75ba7ca 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -761,7 +761,7 @@ class DataFrame(object):
 +-------+------------------+-----+
 | count| 2| 2|
 | mean| 3.5| null|
-| stddev|2.1213203435596424| NaN|
+| stddev|2.1213203435596424| null|
 | min| 2|Alice|
 | max| 5| Bob|
 +-------+------------------+-----+
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
index de5872ab11..d07d4c338c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
@@ -206,7 +206,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w
* @param centralMoments Length `momentOrder + 1` array of central moments (un-normalized)
* needed to compute the aggregate stat.
*/
- def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Double
+ def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Any
override final def eval(buffer: InternalRow): Any = {
val n = buffer.getDouble(nOffset)
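Why the return type widens from Double to Any: a Scala Double is a primitive and cannot carry SQL NULL, so getStatistic previously had no way to signal "no rows". A two-line illustration (hypothetical helpers, not part of the patch):

    def asDouble: Double = null.asInstanceOf[Double] // unboxes to 0.0, silently losing the null
    def asAny: Any = null                            // preserves null for eval() to surface as SQL NULL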
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
index 8fa3aac9f1..c2bf2cb941 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
@@ -37,16 +37,17 @@ case class Kurtosis(child: Expression,
override protected val momentOrder = 4
// NOTE: this is the formula for excess kurtosis, which is default for R and SciPy
- override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+ override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
require(moments.length == momentOrder + 1,
s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
val m2 = moments(2)
val m4 = moments(4)
- if (n == 0.0 || m2 == 0.0) {
+ if (n == 0.0) {
+ null
+ } else if (m2 == 0.0) {
Double.NaN
- }
- else {
+ } else {
n * m4 / (m2 * m2) - 3.0
}
}
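The branching above separates two distinct failure modes. A standalone sketch of the same logic (the helper itself is hypothetical; m2/m4 are the un-normalized 2nd and 4th central moments):

    def excessKurtosis(n: Double, m2: Double, m4: Double): Any =
      if (n == 0.0) null               // no rows: undefined, surfaced as SQL NULL
      else if (m2 == 0.0) Double.NaN   // constant column: division by zero
      else n * m4 / (m2 * m2) - 3.0    // excess kurtosis, matching R and SciPy defaults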
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
index e1c01a5b82..9411bcea25 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
@@ -36,16 +36,17 @@ case class Skewness(child: Expression,
override protected val momentOrder = 3
- override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+ override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
require(moments.length == momentOrder + 1,
s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
val m2 = moments(2)
val m3 = moments(3)
- if (n == 0.0 || m2 == 0.0) {
+ if (n == 0.0) {
+ null
+ } else if (m2 == 0.0) {
Double.NaN
- }
- else {
+ } else {
math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
index 05dd5e3b22..eec79a9033 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
@@ -36,11 +36,17 @@ case class StddevSamp(child: Expression,
override protected val momentOrder = 2
- override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+ override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
require(moments.length == momentOrder + 1,
s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")
- if (n == 0.0 || n == 1.0) Double.NaN else math.sqrt(moments(2) / (n - 1.0))
+ if (n == 0.0) {
+ null
+ } else if (n == 1.0) {
+ Double.NaN
+ } else {
+ math.sqrt(moments(2) / (n - 1.0))
+ }
}
}
@@ -62,10 +68,14 @@ case class StddevPop(
override protected val momentOrder = 2
- override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+ override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
require(moments.length == momentOrder + 1,
s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")
- if (n == 0.0) Double.NaN else math.sqrt(moments(2) / n)
+ if (n == 0.0) {
+ null
+ } else {
+ math.sqrt(moments(2) / n)
+ }
}
}
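The sample/population split explains the asymmetry seen in the tests below: with a single row, stddev_samp divides by n - 1 == 0 and yields NaN, while stddev_pop divides by n and yields 0.0. A standalone sketch with hypothetical helper names:

    def stddevSamp(n: Double, m2: Double): Any =
      if (n == 0.0) null                        // no rows: SQL NULL
      else if (n == 1.0) Double.NaN             // n - 1 == 0: undefined
      else math.sqrt(m2 / (n - 1.0))

    def stddevPop(n: Double, m2: Double): Any =
      if (n == 0.0) null else math.sqrt(m2 / n) // defined for any n >= 1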
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
index ede2da2805..cf3a740305 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
@@ -36,11 +36,17 @@ case class VarianceSamp(child: Expression,
override protected val momentOrder = 2
- override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+ override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
require(moments.length == momentOrder + 1,
s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")
- if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / (n - 1.0)
+ if (n == 0.0) {
+ null
+ } else if (n == 1.0) {
+ Double.NaN
+ } else {
+ moments(2) / (n - 1.0)
+ }
}
}
@@ -62,10 +68,14 @@ case class VariancePop(
override protected val momentOrder = 2
- override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+ override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
require(moments.length == momentOrder + 1,
s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")
- if (n == 0.0) Double.NaN else moments(2) / n
+ if (n == 0.0) {
+ null
+ } else {
+ moments(2) / n
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 432e8d1762..71adf2148a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -205,7 +205,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
checkAnswer(
emptyTableData.agg(stddev('a), stddev_pop('a), stddev_samp('a)),
- Row(Double.NaN, Double.NaN, Double.NaN))
+ Row(null, null, null))
}
test("zero sum") {
@@ -244,17 +244,23 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
test("zero moments") {
val input = Seq((1, 2)).toDF("a", "b")
checkAnswer(
- input.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)),
- Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN))
+ input.agg(stddev('a), stddev_samp('a), stddev_pop('a), variance('a),
+ var_samp('a), var_pop('a), skewness('a), kurtosis('a)),
+ Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN, 0.0,
+ Double.NaN, Double.NaN))
checkAnswer(
input.agg(
+ expr("stddev(a)"),
+ expr("stddev_samp(a)"),
+ expr("stddev_pop(a)"),
expr("variance(a)"),
expr("var_samp(a)"),
expr("var_pop(a)"),
expr("skewness(a)"),
expr("kurtosis(a)")),
- Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN))
+ Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN, 0.0,
+ Double.NaN, Double.NaN))
}
test("null moments") {
@@ -262,7 +268,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
checkAnswer(
emptyTableData.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)),
- Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN))
+ Row(null, null, null, null, null))
checkAnswer(
emptyTableData.agg(
@@ -271,6 +277,6 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
expr("var_pop(a)"),
expr("skewness(a)"),
expr("kurtosis(a)")),
- Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN))
+ Row(null, null, null, null, null))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 5a7f24684d..6399b0165c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -459,7 +459,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val emptyDescribeResult = Seq(
Row("count", "0", "0"),
Row("mean", null, null),
- Row("stddev", "NaN", "NaN"),
+ Row("stddev", null, null),
Row("min", null, null),
Row("max", null, null))