From d292f74831de7e69c852ed26d9c15df85b4fb568 Mon Sep 17 00:00:00 2001
From: JihongMa
Date: Thu, 12 Nov 2015 13:47:34 -0800
Subject: [SPARK-11420] Updating Stddev support via Imperative Aggregate

switched stddev support from DeclarativeAggregate to ImperativeAggregate.

Author: JihongMa

Closes #9380 from JihongMA/SPARK-11420.
---
 .../sql/catalyst/analysis/HiveTypeCoercion.scala   |   6 +-
 .../catalyst/expressions/aggregate/Kurtosis.scala  |   4 +-
 .../catalyst/expressions/aggregate/Skewness.scala  |   4 +-
 .../catalyst/expressions/aggregate/Stddev.scala    | 128 ++++++---------------
 .../scala/org/apache/spark/sql/functions.scala     |   2 +-
 .../apache/spark/sql/DataFrameAggregateSuite.scala |   4 +-
 .../org/apache/spark/sql/DataFrameSuite.scala      |   2 +-
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala |  11 +-
 8 files changed, 49 insertions(+), 112 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index bf2bff0243..92188ee54f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -297,8 +297,10 @@ object HiveTypeCoercion {
       case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
       case Average(e @ StringType()) => Average(Cast(e, DoubleType))
-      case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
-      case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType))
+      case StddevPop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
+        StddevPop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
+      case StddevSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
+        StddevSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
       case VariancePop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
         VariancePop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
       case VarianceSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
         VarianceSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
index bae78d9849..8fa3aac9f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
@@ -42,9 +42,11 @@ case class Kurtosis(child: Expression,
       s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
     val m2 = moments(2)
     val m4 = moments(4)
+
     if (n == 0.0 || m2 == 0.0) {
       Double.NaN
-    } else {
+    }
+    else {
       n * m4 / (m2 * m2) - 3.0
     }
   }

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
index c593074fa2..e1c01a5b82 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
@@ -41,9 +41,11 @@ case class Skewness(child: Expression,
       s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
     val m2 = moments(2)
     val m3 = moments(3)
+
     if (n == 0.0 || m2 == 0.0) {
       Double.NaN
-    } else {
+    }
+    else {
       math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2)
     }
   }
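Context for the two hunks above: both statistics are computed from the central moments accumulated by CentralMomentAgg, and the guards return NaN for empty or constant input. A standalone sketch of the same formulas in plain Scala (the helper names and signatures here are illustrative, not part of the patch):

    // mK is the K-th central moment sum((x - mean)^K) over the n input values.
    def kurtosis(n: Double, m2: Double, m4: Double): Double =
      if (n == 0.0 || m2 == 0.0) Double.NaN   // empty or constant input
      else n * m4 / (m2 * m2) - 3.0           // excess kurtosis, as above

    def skewness(n: Double, m2: Double, m3: Double): Double =
      if (n == 0.0 || m2 == 0.0) Double.NaN
      else math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2)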
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
index 2748009623..05dd5e3b22 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
@@ -17,117 +17,55 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.TypeUtils
-import org.apache.spark.sql.types._
 
+case class StddevSamp(child: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends CentralMomentAgg(child) {
 
-// Compute the population standard deviation of a column
-case class StddevPop(child: Expression) extends StddevAgg(child) {
-  override def isSample: Boolean = false
-  override def prettyName: String = "stddev_pop"
-}
-
-
-// Compute the sample standard deviation of a column
-case class StddevSamp(child: Expression) extends StddevAgg(child) {
-  override def isSample: Boolean = true
-  override def prettyName: String = "stddev_samp"
-}
-
-
-// Compute standard deviation based on online algorithm specified here:
-// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
-abstract class StddevAgg(child: Expression) extends DeclarativeAggregate {
+  def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
 
-  def isSample: Boolean
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
 
-  override def children: Seq[Expression] = child :: Nil
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
 
-  override def nullable: Boolean = true
-
-  override def dataType: DataType = resultType
-
-  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+  override def prettyName: String = "stddev_samp"
 
-  override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForNumericExpr(child.dataType, "function stddev")
+  override protected val momentOrder = 2
 
-  private lazy val resultType = DoubleType
+  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+    require(moments.length == momentOrder + 1,
+      s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
 
-  private lazy val count = AttributeReference("count", resultType)()
-  private lazy val avg = AttributeReference("avg", resultType)()
-  private lazy val mk = AttributeReference("mk", resultType)()
+    if (n == 0.0 || n == 1.0) Double.NaN else math.sqrt(moments(2) / (n - 1.0))
+  }
+}
 
-  override lazy val aggBufferAttributes = count :: avg :: mk :: Nil
+case class StddevPop(
+    child: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends CentralMomentAgg(child) {
 
-  override lazy val initialValues: Seq[Expression] = Seq(
-    /* count = */ Cast(Literal(0), resultType),
-    /* avg = */ Cast(Literal(0), resultType),
-    /* mk = */ Cast(Literal(0), resultType)
-  )
+  def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
 
-  override lazy val updateExpressions: Seq[Expression] = {
-    val value = Cast(child, resultType)
-    val newCount = count + Cast(Literal(1), resultType)
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
 
-    // update average
-    // avg = avg + (value - avg)/count
-    val newAvg = avg + (value - avg) / newCount
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
 
-    // update sum ofference from mean
-    // Mk = Mk + (value - preAvg) * (value - updatedAvg)
-    val newMk = mk + (value - avg) * (value - newAvg)
+  override def prettyName: String = "stddev_pop"
 
-    Seq(
-      /* count = */ If(IsNull(child), count, newCount),
-      /* avg = */ If(IsNull(child), avg, newAvg),
-      /* mk = */ If(IsNull(child), mk, newMk)
-    )
-  }
+  override protected val momentOrder = 2
 
-  override lazy val mergeExpressions: Seq[Expression] = {
-
-    // count merge
-    val newCount = count.left + count.right
-
-    // average merge
-    val newAvg = ((avg.left * count.left) + (avg.right * count.right)) / newCount
-
-    // update sum of square differences
-    val newMk = {
-      val avgDelta = avg.right - avg.left
-      val mkDelta = (avgDelta * avgDelta) * (count.left * count.right) / newCount
-      mk.left + mk.right + mkDelta
-    }
-
-    Seq(
-      /* count = */ If(IsNull(count.left), count.right,
-                       If(IsNull(count.right), count.left, newCount)),
-      /* avg = */ If(IsNull(avg.left), avg.right,
-                     If(IsNull(avg.right), avg.left, newAvg)),
-      /* mk = */ If(IsNull(mk.left), mk.right,
-                    If(IsNull(mk.right), mk.left, newMk))
-    )
-  }
+  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+    require(moments.length == momentOrder + 1,
+      s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
 
-  override lazy val evaluateExpression: Expression = {
-    // when count == 0, return null
-    // when count == 1, return 0
-    // when count >1
-    // stddev_samp = sqrt (mk/(count -1))
-    // stddev_pop = sqrt (mk/count)
-    val varCol =
-      if (isSample) {
-        mk / Cast(count - Cast(Literal(1), resultType), resultType)
-      } else {
-        mk / count
-      }
-
-    If(EqualTo(count, Cast(Literal(0), resultType)), Cast(Literal(null), resultType),
-      If(EqualTo(count, Cast(Literal(1), resultType)), Cast(Literal(0), resultType),
-        Cast(Sqrt(varCol), resultType)))
+    if (n == 0.0) Double.NaN else math.sqrt(moments(2) / n)
   }
 }
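The rewritten classes delegate buffer handling to CentralMomentAgg and differ only in the final statistic, which both derive from the second central moment m2 = sum((x - mean)^2). A minimal sketch of that final step in plain Scala (illustrative names, not part of the patch):

    // n is the number of non-null values, m2 the second central moment.
    def stddevSamp(n: Double, m2: Double): Double =
      if (n == 0.0 || n == 1.0) Double.NaN else math.sqrt(m2 / (n - 1.0))

    def stddevPop(n: Double, m2: Double): Double =
      if (n == 0.0) Double.NaN else math.sqrt(m2 / n)

    // For inputs 1, 2, 3: n = 3, mean = 2, m2 = 2, so
    // stddevSamp = sqrt(2 / 2) = 1.0 and stddevPop = sqrt(2 / 3) ~= 0.8165.

Note the behavioral change the test updates below reflect: the old DeclarativeAggregate evaluated to null on an empty group, while the moment-based version yields NaN.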
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index b6330e230a..53cc6e0cda 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -397,7 +397,7 @@ object functions extends LegacyFunctions {
   def stddev(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) }
 
   /**
-   * Aggregate function: returns the unbiased sample standard deviation of
+   * Aggregate function: returns the sample standard deviation of
    * the expression in a group.
    *
    * @group agg_funcs

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index eb1ee266c5..432e8d1762 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -195,7 +195,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
   }
 
   test("stddev") {
-    val testData2ADev = math.sqrt(4 / 5.0)
+    val testData2ADev = math.sqrt(4.0 / 5.0)
     checkAnswer(
       testData2.agg(stddev('a), stddev_pop('a), stddev_samp('a)),
       Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev))
@@ -205,7 +205,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
     checkAnswer(
       emptyTableData.agg(stddev('a), stddev_pop('a), stddev_samp('a)),
-      Row(null, null, null))
+      Row(Double.NaN, Double.NaN, Double.NaN))
   }
 
   test("zero sum") {

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index e4f23fe17b..35cdab50bd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -459,7 +459,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val emptyDescribeResult = Seq(
       Row("count", "0", "0"),
       Row("mean", null, null),
-      Row("stddev", null, null),
+      Row("stddev", "NaN", "NaN"),
       Row("min", null, null),
       Row("max", null, null))
 

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 52a561d2e5..167aea87de 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -314,13 +314,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       testCodeGen(
         "SELECT min(key) FROM testData3x",
         Row(1) :: Nil)
-      // STDDEV
-      testCodeGen(
-        "SELECT a, stddev(b), stddev_pop(b) FROM testData2 GROUP BY a",
-        (1 to 3).map(i => Row(i, math.sqrt(0.5), math.sqrt(0.25))))
-      testCodeGen(
-        "SELECT stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2",
-        Row(math.sqrt(1.5 / 5), math.sqrt(1.5 / 6), math.sqrt(1.5 / 5)) :: Nil)
       // Some combinations.
       testCodeGen(
         """
@@ -341,8 +334,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
         Row(100, 1, 50.5, 300, 100) :: Nil)
       // Aggregate with Code generation handling all null values
       testCodeGen(
-        "SELECT sum('a'), avg('a'), stddev('a'), count(null) FROM testData",
-        Row(null, null, null, 0) :: Nil)
+        "SELECT sum('a'), avg('a'), count(null) FROM testData",
+        Row(null, null, 0) :: Nil)
     } finally {
       sqlContext.dropTempTable("testData3x")
     }
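For reference, the removed StddevAgg maintained a (count, avg, mk) buffer using the online algorithm its comment cited (http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance). The heart of it is the merge of two partial buffers; a plain-Scala sketch with illustrative names (the removed Catalyst expressions additionally guarded against null buffers):

    case class Buffer(count: Double, avg: Double, mk: Double)

    // Merge two non-empty partial buffers. mk is sum((x - mean)^2) within each
    // partition; the delta term corrects for the difference between the means.
    def merge(l: Buffer, r: Buffer): Buffer = {
      val count = l.count + r.count
      val avg = (l.avg * l.count + r.avg * r.count) / count
      val delta = r.avg - l.avg
      val mk = l.mk + r.mk + delta * delta * l.count * r.count / count
      Buffer(count, avg, mk)
    }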