author     JihongMa <linlin200605@gmail.com>    2015-11-18 13:03:37 -0800
committer  Xiangrui Meng <meng@databricks.com>  2015-11-18 13:03:37 -0800
commit     09ad9533d5760652de59fa4830c24cb8667958ac (patch)
tree       6e6023e1d2df2ccf565f9df1bf26e82904a70363 /sql/catalyst
parent     7c5b641808740ba5eed05ba8204cdbaf3fc579f5 (diff)
[SPARK-11720][SQL][ML] Handle edge cases when count = 0 or 1 for Stats function
Return Double.NaN for mean/average when count == 0 for all numeric types that are converted to Double; the Decimal type continues to return null.

Author: JihongMa <linlin200605@gmail.com>

Closes #9705 from JihongMA/SPARK-11720.
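The net effect across all five files below: `getStatistic` is widened from `Double` to `Any` so an empty input can surface as SQL null, while a non-empty but degenerate input still surfaces as Double.NaN. A minimal standalone sketch of that contract, with illustrative names rather than the Spark API:

// Hypothetical mirror of the new contract: Any admits both null (count == 0)
// and a boxed Double (NaN when undefined, a finite value otherwise).
def sampleVariance(n: Double, m2: Double): Any = {
  if (n == 0.0) null             // no rows seen: SQL null
  else if (n == 1.0) Double.NaN  // sample variance undefined for one row
  else m2 / (n - 1.0)            // unbiased estimator
}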
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala         |  9
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala         |  9
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala           | 18
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala         | 18
5 files changed, 39 insertions(+), 17 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
index de5872ab11..d07d4c338c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
@@ -206,7 +206,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w
* @param centralMoments Length `momentOrder + 1` array of central moments (un-normalized)
* needed to compute the aggregate stat.
*/
- def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Double
+ def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Any
override final def eval(buffer: InternalRow): Any = {
val n = buffer.getDouble(nOffset)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
index 8fa3aac9f1..c2bf2cb941 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
@@ -37,16 +37,17 @@ case class Kurtosis(child: Expression,
override protected val momentOrder = 4
// NOTE: this is the formula for excess kurtosis, which is default for R and SciPy
- override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+ override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
require(moments.length == momentOrder + 1,
s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
val m2 = moments(2)
val m4 = moments(4)
- if (n == 0.0 || m2 == 0.0) {
+ if (n == 0.0) {
+ null
+ } else if (m2 == 0.0) {
Double.NaN
- }
- else {
+ } else {
n * m4 / (m2 * m2) - 3.0
}
}
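For intuition on the `n * m4 / (m2 * m2) - 3.0` branch (excess kurtosis, matching the R/SciPy default noted in the comment), here is a self-contained check; the `centralMoment` helper is illustrative and not part of the patch:

// Illustrative only: un-normalized central moments computed directly,
// then the same excess-kurtosis formula as the patched code.
def centralMoment(xs: Seq[Double], k: Int): Double = {
  val mean = xs.sum / xs.size
  xs.map(x => math.pow(x - mean, k)).sum
}
val xs = Seq(1.0, 2.0, 3.0, 4.0)
val (n, m2, m4) = (xs.size.toDouble, centralMoment(xs, 2), centralMoment(xs, 4))
println(n * m4 / (m2 * m2) - 3.0) // ~ -1.36: flatter tails than a normal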
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
index e1c01a5b82..9411bcea25 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
@@ -36,16 +36,17 @@ case class Skewness(child: Expression,
override protected val momentOrder = 3
- override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+ override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
require(moments.length == momentOrder + 1,
s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
val m2 = moments(2)
val m3 = moments(3)
- if (n == 0.0 || m2 == 0.0) {
+ if (n == 0.0) {
+ null
+ } else if (m2 == 0.0) {
Double.NaN
- }
- else {
+ } else {
math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2)
}
}
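Note that with the reordered branches above, a non-empty constant column still hits the `m2 == 0.0` case and returns NaN; only a truly empty input becomes null. A quick standalone illustration (not Spark code):

// All-equal values: m2 == 0.0, so skewness is undefined and stays NaN.
val xs   = Seq(5.0, 5.0, 5.0)
val mean = xs.sum / xs.size
val m2   = xs.map(x => (x - mean) * (x - mean)).sum
val m3   = xs.map(x => math.pow(x - mean, 3)).sum
val skew: Any =
  if (xs.isEmpty) null
  else if (m2 == 0.0) Double.NaN
  else math.sqrt(xs.size.toDouble) * m3 / math.sqrt(m2 * m2 * m2)
println(skew) // NaN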
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
index 05dd5e3b22..eec79a9033 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
@@ -36,11 +36,17 @@ case class StddevSamp(child: Expression,
override protected val momentOrder = 2
- override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+ override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
require(moments.length == momentOrder + 1,
s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")
- if (n == 0.0 || n == 1.0) Double.NaN else math.sqrt(moments(2) / (n - 1.0))
+ if (n == 0.0) {
+ null
+ } else if (n == 1.0) {
+ Double.NaN
+ } else {
+ math.sqrt(moments(2) / (n - 1.0))
+ }
}
}
@@ -62,10 +68,14 @@ case class StddevPop(
override protected val momentOrder = 2
- override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+ override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
require(moments.length == momentOrder + 1,
s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")
- if (n == 0.0) Double.NaN else math.sqrt(moments(2) / n)
+ if (n == 0.0) {
+ null
+ } else {
+ math.sqrt(moments(2) / n)
+ }
}
}
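The sample/population split above means a single row yields NaN under StddevSamp but 0.0 under StddevPop. A standalone mirror of the two patched branches (illustrative names, not the Spark API):

// Hypothetical mirror of the two patched eval paths.
def stddevSamp(n: Double, m2: Double): Any =
  if (n == 0.0) null
  else if (n == 1.0) Double.NaN     // n - 1 == 0: sample stddev undefined
  else math.sqrt(m2 / (n - 1.0))
def stddevPop(n: Double, m2: Double): Any =
  if (n == 0.0) null
  else math.sqrt(m2 / n)            // defined for any non-empty input
println(stddevSamp(1.0, 0.0)) // NaN
println(stddevPop(1.0, 0.0))  // 0.0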
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
index ede2da2805..cf3a740305 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
@@ -36,11 +36,17 @@ case class VarianceSamp(child: Expression,
override protected val momentOrder = 2
- override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+ override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
require(moments.length == momentOrder + 1,
s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")
- if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / (n - 1.0)
+ if (n == 0.0) {
+ null
+ } else if (n == 1.0) {
+ Double.NaN
+ } else {
+ moments(2) / (n - 1.0)
+ }
}
}
@@ -62,10 +68,14 @@ case class VariancePop(
override protected val momentOrder = 2
- override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+ override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
require(moments.length == momentOrder + 1,
s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")
- if (n == 0.0) Double.NaN else moments(2) / n
+ if (n == 0.0) {
+ null
+ } else {
+ moments(2) / n
+ }
}
}
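As context for the `n - 1.0` vs `n` divisors above: the sample estimator rescales the population one by n / (n - 1), which is exactly why it is undefined (NaN) at n == 1. A quick numeric check, standalone rather than Spark code:

// Un-normalized second central moment, then both variance flavors.
val xs   = Seq(2.0, 4.0, 6.0)
val n    = xs.size.toDouble
val mean = xs.sum / n
val m2   = xs.map(x => (x - mean) * (x - mean)).sum
println(m2 / n)         // population variance: 8/3 ~ 2.667
println(m2 / (n - 1.0)) // sample variance:     8/2 = 4.0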