about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r--  R/pkg/inst/tests/test_sparkSQL.R | 4
-rw-r--r--  python/pyspark/sql/dataframe.py | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala | 128
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 11
10 files changed, 52 insertions, 115 deletions
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 9e453a1e7c..af024e6183 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -1007,7 +1007,7 @@ test_that("group by, agg functions", {
df3 <- agg(gd, age = "stddev")
expect_is(df3, "DataFrame")
df3_local <- collect(df3)
- expect_equal(0, df3_local[df3_local$name == "Andy",][1, 2])
+ expect_true(is.nan(df3_local[df3_local$name == "Andy",][1, 2]))
df4 <- agg(gd, sumAge = sum(df$age))
expect_is(df4, "DataFrame")
@@ -1038,7 +1038,7 @@ test_that("group by, agg functions", {
df7 <- agg(gd2, value = "stddev")
df7_local <- collect(df7)
expect_true(abs(df7_local[df7_local$name == "ID1",][1, 2] - 6.928203) < 1e-6)
- expect_equal(0, df7_local[df7_local$name == "ID2",][1, 2])
+ expect_true(is.nan(df7_local[df7_local$name == "ID2",][1, 2]))
mockLines3 <- c("{\"name\":\"Andy\", \"age\":30}",
"{\"name\":\"Andy\", \"age\":30}",
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 0dd75ba7ca..ad6ad0235a 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -761,7 +761,7 @@ class DataFrame(object):
+-------+------------------+-----+
| count| 2| 2|
| mean| 3.5| null|
- | stddev|2.1213203435596424| null|
+ | stddev|2.1213203435596424| NaN|
| min| 2|Alice|
| max| 5| Bob|
+-------+------------------+-----+
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index bf2bff0243..92188ee54f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -297,8 +297,10 @@ object HiveTypeCoercion {
case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
- case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
- case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType))
+ case StddevPop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
+ StddevPop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
+ case StddevSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
+ StddevSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
case VariancePop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
VariancePop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
case VarianceSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
index bae78d9849..8fa3aac9f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
@@ -42,9 +42,11 @@ case class Kurtosis(child: Expression,
s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
val m2 = moments(2)
val m4 = moments(4)
+
if (n == 0.0 || m2 == 0.0) {
Double.NaN
- } else {
+ }
+ else {
n * m4 / (m2 * m2) - 3.0
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
index c593074fa2..e1c01a5b82 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
@@ -41,9 +41,11 @@ case class Skewness(child: Expression,
s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
val m2 = moments(2)
val m3 = moments(3)
+
if (n == 0.0 || m2 == 0.0) {
Double.NaN
- } else {
+ }
+ else {
math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
index 2748009623..05dd5e3b22 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
@@ -17,117 +17,55 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.TypeUtils
-import org.apache.spark.sql.types._
+case class StddevSamp(child: Expression,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0)
+ extends CentralMomentAgg(child) {
-// Compute the population standard deviation of a column
-case class StddevPop(child: Expression) extends StddevAgg(child) {
- override def isSample: Boolean = false
- override def prettyName: String = "stddev_pop"
-}
-
-
-// Compute the sample standard deviation of a column
-case class StddevSamp(child: Expression) extends StddevAgg(child) {
- override def isSample: Boolean = true
- override def prettyName: String = "stddev_samp"
-}
-
-
-// Compute standard deviation based on online algorithm specified here:
-// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
-abstract class StddevAgg(child: Expression) extends DeclarativeAggregate {
+ def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
- def isSample: Boolean
+ override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
- override def children: Seq[Expression] = child :: Nil
+ override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
- override def nullable: Boolean = true
-
- override def dataType: DataType = resultType
-
- override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+ override def prettyName: String = "stddev_samp"
- override def checkInputDataTypes(): TypeCheckResult =
- TypeUtils.checkForNumericExpr(child.dataType, "function stddev")
+ override protected val momentOrder = 2
- private lazy val resultType = DoubleType
+ override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+ require(moments.length == momentOrder + 1,
+ s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")
- private lazy val count = AttributeReference("count", resultType)()
- private lazy val avg = AttributeReference("avg", resultType)()
- private lazy val mk = AttributeReference("mk", resultType)()
+ if (n == 0.0 || n == 1.0) Double.NaN else math.sqrt(moments(2) / (n - 1.0))
+ }
+}
- override lazy val aggBufferAttributes = count :: avg :: mk :: Nil
+case class StddevPop(
+ child: Expression,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0)
+ extends CentralMomentAgg(child) {
- override lazy val initialValues: Seq[Expression] = Seq(
- /* count = */ Cast(Literal(0), resultType),
- /* avg = */ Cast(Literal(0), resultType),
- /* mk = */ Cast(Literal(0), resultType)
- )
+ def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
- override lazy val updateExpressions: Seq[Expression] = {
- val value = Cast(child, resultType)
- val newCount = count + Cast(Literal(1), resultType)
+ override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
- // update average
- // avg = avg + (value - avg)/count
- val newAvg = avg + (value - avg) / newCount
+ override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
- // update sum ofference from mean
- // Mk = Mk + (value - preAvg) * (value - updatedAvg)
- val newMk = mk + (value - avg) * (value - newAvg)
+ override def prettyName: String = "stddev_pop"
- Seq(
- /* count = */ If(IsNull(child), count, newCount),
- /* avg = */ If(IsNull(child), avg, newAvg),
- /* mk = */ If(IsNull(child), mk, newMk)
- )
- }
+ override protected val momentOrder = 2
- override lazy val mergeExpressions: Seq[Expression] = {
-
- // count merge
- val newCount = count.left + count.right
-
- // average merge
- val newAvg = ((avg.left * count.left) + (avg.right * count.right)) / newCount
-
- // update sum of square differences
- val newMk = {
- val avgDelta = avg.right - avg.left
- val mkDelta = (avgDelta * avgDelta) * (count.left * count.right) / newCount
- mk.left + mk.right + mkDelta
- }
-
- Seq(
- /* count = */ If(IsNull(count.left), count.right,
- If(IsNull(count.right), count.left, newCount)),
- /* avg = */ If(IsNull(avg.left), avg.right,
- If(IsNull(avg.right), avg.left, newAvg)),
- /* mk = */ If(IsNull(mk.left), mk.right,
- If(IsNull(mk.right), mk.left, newMk))
- )
- }
+ override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+ require(moments.length == momentOrder + 1,
+ s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")
- override lazy val evaluateExpression: Expression = {
- // when count == 0, return null
- // when count == 1, return 0
- // when count >1
- // stddev_samp = sqrt (mk/(count -1))
- // stddev_pop = sqrt (mk/count)
- val varCol =
- if (isSample) {
- mk / Cast(count - Cast(Literal(1), resultType), resultType)
- } else {
- mk / count
- }
-
- If(EqualTo(count, Cast(Literal(0), resultType)), Cast(Literal(null), resultType),
- If(EqualTo(count, Cast(Literal(1), resultType)), Cast(Literal(0), resultType),
- Cast(Sqrt(varCol), resultType)))
+ if (n == 0.0) Double.NaN else math.sqrt(moments(2) / n)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index b6330e230a..53cc6e0cda 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -397,7 +397,7 @@ object functions extends LegacyFunctions {
def stddev(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) }
/**
- * Aggregate function: returns the unbiased sample standard deviation of
+ * Aggregate function: returns the sample standard deviation of
* the expression in a group.
*
* @group agg_funcs
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index eb1ee266c5..432e8d1762 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -195,7 +195,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
}
test("stddev") {
- val testData2ADev = math.sqrt(4 / 5.0)
+ val testData2ADev = math.sqrt(4.0 / 5.0)
checkAnswer(
testData2.agg(stddev('a), stddev_pop('a), stddev_samp('a)),
Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev))
@@ -205,7 +205,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
checkAnswer(
emptyTableData.agg(stddev('a), stddev_pop('a), stddev_samp('a)),
- Row(null, null, null))
+ Row(Double.NaN, Double.NaN, Double.NaN))
}
test("zero sum") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index e4f23fe17b..35cdab50bd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -459,7 +459,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val emptyDescribeResult = Seq(
Row("count", "0", "0"),
Row("mean", null, null),
- Row("stddev", null, null),
+ Row("stddev", "NaN", "NaN"),
Row("min", null, null),
Row("max", null, null))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 52a561d2e5..167aea87de 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -314,13 +314,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
testCodeGen(
"SELECT min(key) FROM testData3x",
Row(1) :: Nil)
- // STDDEV
- testCodeGen(
- "SELECT a, stddev(b), stddev_pop(b) FROM testData2 GROUP BY a",
- (1 to 3).map(i => Row(i, math.sqrt(0.5), math.sqrt(0.25))))
- testCodeGen(
- "SELECT stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2",
- Row(math.sqrt(1.5 / 5), math.sqrt(1.5 / 6), math.sqrt(1.5 / 5)) :: Nil)
// Some combinations.
testCodeGen(
"""
@@ -341,8 +334,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
Row(100, 1, 50.5, 300, 100) :: Nil)
// Aggregate with Code generation handling all null values
testCodeGen(
- "SELECT sum('a'), avg('a'), stddev('a'), count(null) FROM testData",
- Row(null, null, null, 0) :: Nil)
+ "SELECT sum('a'), avg('a'), count(null) FROM testData",
+ Row(null, null, 0) :: Nil)
} finally {
sqlContext.dropTempTable("testData3x")
}