From f4a22808e03fa12bfe1bfc82cf713cfda7e063a9 Mon Sep 17 00:00:00 2001
From: JihongMa
Date: Sat, 12 Sep 2015 10:17:15 -0700
Subject: [SPARK-6548] Adding stddev to DataFrame functions

Adding STDDEV support for DataFrame using a 1-pass online/parallel algorithm
to compute variance. Please review the code change.

Author: JihongMa
Author: Jihong MA

Closes #6297 from JihongMA/SPARK-SQL.
---
 .../sql/catalyst/analysis/FunctionRegistry.scala   |   3 +
 .../sql/catalyst/analysis/HiveTypeCoercion.scala   |   3 +
 .../apache/spark/sql/catalyst/dsl/package.scala    |   3 +
 .../catalyst/expressions/aggregate/functions.scala | 143 ++++++++++++
 .../sql/catalyst/expressions/aggregate/utils.scala |  18 ++
 .../sql/catalyst/expressions/aggregates.scala      | 245 +++++++++++++++++++++
 .../scala/org/apache/spark/sql/DataFrame.scala     |   6 +-
 .../scala/org/apache/spark/sql/GroupedData.scala   |  39 ++++
 .../scala/org/apache/spark/sql/functions.scala     |  27 +++
 .../org/apache/spark/sql/JavaDataFrameSuite.java   |   1 +
 .../apache/spark/sql/DataFrameAggregateSuite.scala |  33 +++
 .../org/apache/spark/sql/DataFrameSuite.scala      |   2 +-
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala |  42 +++-
 .../sql/hive/execution/AggregationQuerySuite.scala |  35 ---
 14 files changed, 555 insertions(+), 45 deletions(-)
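For reference, the single-pass ("online") update that both implementations in
this patch perform is Welford's algorithm: the aggregation buffer carries a
count, a running mean, and Mk, the running sum of squared differences from the
current mean. A minimal plain-Scala sketch of the update step (the names here
are illustrative only, not part of the patch):

    // One Welford update step over a buffer of (count, avg, Mk).
    case class MomentState(count: Long, avg: Double, mk: Double)

    def update(s: MomentState, value: Double): MomentState = {
      val count = s.count + 1
      val preAvg = s.avg
      // avg = avg + (value - avg) / count
      val avg = preAvg + (value - preAvg) / count
      // Mk = Mk + (value - preAvg) * (value - updatedAvg)
      val mk = s.mk + (value - preAvg) * (value - avg)
      MomentState(count, avg, mk)
    }

The same recurrences appear twice below: once as expression trees in the new
AlgebraicAggregate interface, and once in the legacy AggregateExpression1
code path.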
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index cd5a90d788..11b4866bf2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -168,6 +168,9 @@ object FunctionRegistry {
     expression[Last]("last"),
     expression[Max]("max"),
     expression[Min]("min"),
+    expression[Stddev]("stddev"),
+    expression[StddevPop]("stddev_pop"),
+    expression[StddevSamp]("stddev_samp"),
     expression[Sum]("sum"),
 
     // string functions

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 87c11abbad..87a3845b2d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -297,6 +297,9 @@ object HiveTypeCoercion {
       case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
       case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType))
       case Average(e @ StringType()) => Average(Cast(e, DoubleType))
+      case Stddev(e @ StringType()) => Stddev(Cast(e, DoubleType))
+      case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
+      case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType))
     }
   }
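Like sum and avg just above it, the coercion rule lets the new functions
accept string-typed input by casting it to double first. A hypothetical
example (strTable and its string column s are assumptions, not part of this
patch):

    // Resolved by the rule above as stddev(CAST(s AS DOUBLE)).
    sqlContext.sql("SELECT stddev(s) FROM strTable")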
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index a7e3a49327..699c4cc63d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -159,6 +159,9 @@ package object dsl {
     def lower(e: Expression): Expression = Lower(e)
     def sqrt(e: Expression): Expression = Sqrt(e)
     def abs(e: Expression): Expression = Abs(e)
+    def stddev(e: Expression): Expression = Stddev(e)
+    def stddev_pop(e: Expression): Expression = StddevPop(e)
+    def stddev_samp(e: Expression): Expression = StddevSamp(e)
 
     implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name }
     // TODO more implicit class for literal?

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
index a73024d6ad..02cd0ac0db 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -249,6 +249,149 @@ case class Min(child: Expression) extends AlgebraicAggregate {
   override val evaluateExpression = min
 }
 
+// Compute the sample standard deviation of a column
+case class Stddev(child: Expression) extends StddevAgg(child) {
+
+  override def isSample: Boolean = true
+  override def prettyName: String = "stddev"
+}
+
+// Compute the population standard deviation of a column
+case class StddevPop(child: Expression) extends StddevAgg(child) {
+
+  override def isSample: Boolean = false
+  override def prettyName: String = "stddev_pop"
+}
+
+// Compute the sample standard deviation of a column
+case class StddevSamp(child: Expression) extends StddevAgg(child) {
+
+  override def isSample: Boolean = true
+  override def prettyName: String = "stddev_samp"
+}
+
+// Compute the standard deviation based on the online algorithm specified here:
+// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
+abstract class StddevAgg(child: Expression) extends AlgebraicAggregate {
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = true
+
+  def isSample: Boolean
+
+  // Return data type.
+  override def dataType: DataType = resultType
+
+  // Expected input data type.
+  // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) with the
+  // new version at planning time (after the analysis phase). For now, NullType is added here
+  // to make cases like `select stddev(null)` resolve.
+  // We can use our analyzer to cast NullType to the default data type of the NumericType once
+  // we remove the old aggregate functions. Then we will not need NullType here.
+  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
+
+  private val resultType = DoubleType
+
+  private val preCount = AttributeReference("preCount", resultType)()
+  private val currentCount = AttributeReference("currentCount", resultType)()
+  private val preAvg = AttributeReference("preAvg", resultType)()
+  private val currentAvg = AttributeReference("currentAvg", resultType)()
+  private val currentMk = AttributeReference("currentMk", resultType)()
+
+  override val bufferAttributes = preCount :: currentCount :: preAvg ::
+    currentAvg :: currentMk :: Nil
+
+  override val initialValues = Seq(
+    /* preCount = */ Cast(Literal(0), resultType),
+    /* currentCount = */ Cast(Literal(0), resultType),
+    /* preAvg = */ Cast(Literal(0), resultType),
+    /* currentAvg = */ Cast(Literal(0), resultType),
+    /* currentMk = */ Cast(Literal(0), resultType)
+  )
+
+  override val updateExpressions = {
+
+    // update the average
+    // avg = avg + (value - avg) / count
+    def avgAdd: Expression = {
+      currentAvg + ((Cast(child, resultType) - currentAvg) / currentCount)
+    }
+
+    // update the sum of squares of differences from the mean
+    // Mk = Mk + (value - preAvg) * (value - updatedAvg)
+    def mkAdd: Expression = {
+      val delta1 = Cast(child, resultType) - preAvg
+      val delta2 = Cast(child, resultType) - currentAvg
+      currentMk + (delta1 * delta2)
+    }
+
+    Seq(
+      /* preCount = */ If(IsNull(child), preCount, currentCount),
+      /* currentCount = */ If(IsNull(child), currentCount,
+        Add(currentCount, Cast(Literal(1), resultType))),
+      /* preAvg = */ If(IsNull(child), preAvg, currentAvg),
+      /* currentAvg = */ If(IsNull(child), currentAvg, avgAdd),
+      /* currentMk = */ If(IsNull(child), currentMk, mkAdd)
+    )
+  }
+
+  override val mergeExpressions = {
+
+    // count merge
+    def countMerge: Expression = {
+      currentCount.left + currentCount.right
+    }
+
+    // average merge
+    def avgMerge: Expression = {
+      ((currentAvg.left * preCount) + (currentAvg.right * currentCount.right)) /
+        (preCount + currentCount.right)
+    }
+
+    // merge the sums of squared differences
+    def mkMerge: Expression = {
+      val avgDelta = currentAvg.right - preAvg
+      val mkDelta = (avgDelta * avgDelta) * (preCount * currentCount.right) /
+        (preCount + currentCount.right)
+
+      currentMk.left + currentMk.right + mkDelta
+    }
+
+    Seq(
+      /* preCount = */ If(IsNull(currentCount.left),
+        Cast(Literal(0), resultType), currentCount.left),
+      /* currentCount = */ If(IsNull(currentCount.left), currentCount.right,
+        If(IsNull(currentCount.right), currentCount.left, countMerge)),
+      /* preAvg = */ If(IsNull(currentAvg.left), Cast(Literal(0), resultType), currentAvg.left),
+      /* currentAvg = */ If(IsNull(currentAvg.left), currentAvg.right,
+        If(IsNull(currentAvg.right), currentAvg.left, avgMerge)),
+      /* currentMk = */ If(IsNull(currentMk.left), currentMk.right,
+        If(IsNull(currentMk.right), currentMk.left, mkMerge))
+    )
+  }
+
+  override val evaluateExpression = {
+    // when currentCount == 0, return null
+    // when currentCount == 1, return 0
+    // when currentCount > 1,
+    //   stddev_samp = sqrt(currentMk / (currentCount - 1))
+    //   stddev_pop = sqrt(currentMk / currentCount)
+    val varCol = {
+      if (isSample) {
+        currentMk / Cast((currentCount - Cast(Literal(1), resultType)), resultType)
+      } else {
+        currentMk / currentCount
+      }
+    }
+
+    If(EqualTo(currentCount, Cast(Literal(0), resultType)), Cast(Literal(null), resultType),
+      If(EqualTo(currentCount, Cast(Literal(1), resultType)), Cast(Literal(0), resultType),
+        Cast(Sqrt(varCol), resultType)))
+  }
+}
+
 case class Sum(child: Expression) extends AlgebraicAggregate {
 
   override def children: Seq[Expression] = child :: Nil
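The mergeExpressions above combine two partial (count, avg, Mk) buffers using
the standard parallel-variance combination. Continuing the illustrative
MomentState sketch from the top of this patch:

    // Merge two partial states, mirroring countMerge/avgMerge/mkMerge above.
    def merge(l: MomentState, r: MomentState): MomentState = {
      val count = l.count + r.count
      val avg = (l.avg * l.count + r.avg * r.count) / count
      val avgDelta = r.avg - l.avg
      val mk = l.mk + r.mk + avgDelta * avgDelta * l.count * r.count / count
      MomentState(count, avg, mk)
    }

The final answer is then sqrt(mk / (count - 1)) for the sample flavor and
sqrt(mk / count) for the population flavor, exactly as evaluateExpression
encodes.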
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
index 4a43318a95..ce3dddad87 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
@@ -85,6 +85,24 @@ object Utils {
           mode = aggregate.Complete,
           isDistinct = false)
 
+      case expressions.Stddev(child) =>
+        aggregate.AggregateExpression2(
+          aggregateFunction = aggregate.Stddev(child),
+          mode = aggregate.Complete,
+          isDistinct = false)
+
+      case expressions.StddevPop(child) =>
+        aggregate.AggregateExpression2(
+          aggregateFunction = aggregate.StddevPop(child),
+          mode = aggregate.Complete,
+          isDistinct = false)
+
+      case expressions.StddevSamp(child) =>
+        aggregate.AggregateExpression2(
+          aggregateFunction = aggregate.StddevSamp(child),
+          mode = aggregate.Complete,
+          isDistinct = false)
+
       case expressions.Sum(child) =>
         aggregate.AggregateExpression2(
           aggregateFunction = aggregate.Sum(child),

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 5e8298aaaa..f1c47f3904 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -691,3 +691,248 @@ case class LastFunction(expr: Expression, base: AggregateExpression1) extends Ag
     result
   }
 }
+
+// Compute the standard deviation based on the online algorithm specified here:
+// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
+abstract class StddevAgg1(child: Expression) extends UnaryExpression with PartialAggregate1 {
+  override def nullable: Boolean = true
+  override def dataType: DataType = DoubleType
+
+  def isSample: Boolean
+
+  override def asPartial: SplitEvaluation = {
+    val partialStd = Alias(ComputePartialStd(child), "PartialStddev")()
+    SplitEvaluation(MergePartialStd(partialStd.toAttribute, isSample), partialStd :: Nil)
+  }
+
+  override def newInstance(): StddevFunction = new StddevFunction(child, this, isSample)
+
+  override def checkInputDataTypes(): TypeCheckResult =
+    TypeUtils.checkForNumericExpr(child.dataType, "function stddev")
+}
+
+// Compute the sample standard deviation of a column
+case class Stddev(child: Expression) extends StddevAgg1(child) {
+
+  override def toString: String = s"STDDEV($child)"
+  override def isSample: Boolean = true
+}
+
+// Compute the population standard deviation of a column
+case class StddevPop(child: Expression) extends StddevAgg1(child) {
+
+  override def toString: String = s"STDDEV_POP($child)"
+  override def isSample: Boolean = false
+}
+
+// Compute the sample standard deviation of a column
+case class StddevSamp(child: Expression) extends StddevAgg1(child) {
+
+  override def toString: String = s"STDDEV_SAMP($child)"
+  override def isSample: Boolean = true
+}
+
+case class ComputePartialStd(child: Expression) extends UnaryExpression with AggregateExpression1 {
+  def this() = this(null)  // Required for serialization
+
+  override def children: Seq[Expression] = child :: Nil
+  override def nullable: Boolean = false
+  override def dataType: DataType = ArrayType(DoubleType)
+  override def toString: String = s"computePartialStddev($child)"
+  override def newInstance(): ComputePartialStdFunction =
+    new ComputePartialStdFunction(child, this)
+}
s"computePartialStddev($child)" + override def newInstance(): ComputePartialStdFunction = + new ComputePartialStdFunction(child, this) +} + +case class ComputePartialStdFunction ( + expr: Expression, + base: AggregateExpression1 +) extends AggregateFunction1 { + def this() = this(null, null) // Required for serialization + + private val computeType = DoubleType + private val zero = Cast(Literal(0), computeType) + private var partialCount: Long = 0L + + // the mean of data processed so far + private val partialAvg: MutableLiteral = MutableLiteral(zero.eval(null), computeType) + + // update average based on this formula: + // avg = avg + (value - avg)/count + private def avgAddFunction (value: Literal): Expression = { + val delta = Subtract(Cast(value, computeType), partialAvg) + Add(partialAvg, Divide(delta, Cast(Literal(partialCount), computeType))) + } + + // the sum of squares of difference from mean + private val partialMk: MutableLiteral = MutableLiteral(zero.eval(null), computeType) + + // update sum of square of difference from mean based on following formula: + // Mk = Mk + (value - preAvg) * (value - updatedAvg) + private def mkAddFunction(value: Literal, prePartialAvg: MutableLiteral): Expression = { + val delta1 = Subtract(Cast(value, computeType), prePartialAvg) + val delta2 = Subtract(Cast(value, computeType), partialAvg) + Add(partialMk, Multiply(delta1, delta2)) + } + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input) + if (evaluatedExpr != null) { + val exprValue = Literal.create(evaluatedExpr, expr.dataType) + val prePartialAvg = partialAvg.copy() + partialCount += 1 + partialAvg.update(avgAddFunction(exprValue), input) + partialMk.update(mkAddFunction(exprValue, prePartialAvg), input) + } + } + + override def eval(input: InternalRow): Any = { + new GenericArrayData(Array(Cast(Literal(partialCount), computeType).eval(null), + partialAvg.eval(null), + partialMk.eval(null))) + } +} + +case class MergePartialStd( + child: Expression, + isSample: Boolean +) extends UnaryExpression with AggregateExpression1 { + def this() = this(null, false) // required for serialization + + override def children: Seq[Expression] = child:: Nil + override def nullable: Boolean = false + override def dataType: DataType = DoubleType + override def toString: String = s"MergePartialStd($child)" + override def newInstance(): MergePartialStdFunction = { + new MergePartialStdFunction(child, this, isSample) + } +} + +case class MergePartialStdFunction( + expr: Expression, + base: AggregateExpression1, + isSample: Boolean +) extends AggregateFunction1 { + def this() = this (null, null, false) // Required for serialization + + private val computeType = DoubleType + private val zero = Cast(Literal(0), computeType) + private val combineCount = MutableLiteral(zero.eval(null), computeType) + private val combineAvg = MutableLiteral(zero.eval(null), computeType) + private val combineMk = MutableLiteral(zero.eval(null), computeType) + + private def avgUpdateFunction(preCount: Expression, + partialCount: Expression, + partialAvg: Expression): Expression = { + Divide(Add(Multiply(combineAvg, preCount), + Multiply(partialAvg, partialCount)), + Add(preCount, partialCount)) + } + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input).asInstanceOf[ArrayData] + + if (evaluatedExpr != null) { + val exprValue = evaluatedExpr.toArray(computeType) + val (partialCount, partialAvg, partialMk) = + (Literal.create(exprValue(0), computeType), + 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 791c10c3d7..1a687b2374 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1288,15 +1288,11 @@ class DataFrame private[sql](
   @scala.annotation.varargs
   def describe(cols: String*): DataFrame = {
 
-    // TODO: Add stddev as an expression, and remove it from here.
-    def stddevExpr(expr: Expression): Expression =
-      Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr))))
-
     // The list of summary statistics to compute, in the form of expressions.
     val statistics = List[(String, Expression => Expression)](
       "count" -> Count,
       "mean" -> Average,
-      "stddev" -> stddevExpr,
+      "stddev" -> Stddev,
       "min" -> Min,
       "max" -> Max)
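Note the behavior change hiding in this hunk: the old inline stddevExpr
computed sqrt(E[X^2] - E[X]^2), i.e. the population standard deviation, while
the new Stddev expression is the sample standard deviation (n - 1
denominator). That is why the expected describe() output in DataFrameSuite
changes later in this patch; for a column of n = 4 values whose squared
deviations from the mean sum to 400:

    val (m2, n) = (400.0, 4)
    math.sqrt(m2 / n)       // 10.0               (old population result)
    math.sqrt(m2 / (n - 1)) // 11.547005383792516 (new sample result)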
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index ee31d83cce..102b802ad0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -124,6 +124,9 @@ class GroupedData protected[sql](
       case "avg" | "average" | "mean" => Average
       case "max" => Max
       case "min" => Min
+      case "stddev" => Stddev
+      case "stddev_pop" => StddevPop
+      case "stddev_samp" => StddevSamp
       case "sum" => Sum
       case "count" | "size" =>
         // Turn count(*) into count(1)
@@ -283,6 +286,42 @@ class GroupedData protected[sql](
     aggregateNumericColumns(colNames : _*)(Min)
   }
 
+  /**
+   * Compute the sample standard deviation for each numeric column for each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
+   * When specified columns are given, only compute the stddev for them.
+   *
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def stddev(colNames: String*): DataFrame = {
+    aggregateNumericColumns(colNames : _*)(Stddev)
+  }
+
+  /**
+   * Compute the population standard deviation for each numeric column for each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
+   * When specified columns are given, only compute the stddev for them.
+   *
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def stddev_pop(colNames: String*): DataFrame = {
+    aggregateNumericColumns(colNames : _*)(StddevPop)
+  }
+
+  /**
+   * Compute the sample standard deviation for each numeric column for each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
+   * When specified columns are given, only compute the stddev for them.
+   *
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def stddev_samp(colNames: String*): DataFrame = {
+    aggregateNumericColumns(colNames : _*)(StddevSamp)
+  }
+
   /**
    * Compute the sum for each numeric columns for each group.
    * The resulting [[DataFrame]] will also contain the grouping columns.

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 435e6319a6..60d9c50910 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -294,6 +294,33 @@ object functions {
    */
   def min(columnName: String): Column = min(Column(columnName))
 
+  /**
+   * Aggregate function: returns the unbiased sample standard deviation
+   * of the expression in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def stddev(e: Column): Column = Stddev(e.expr)
+
+  /**
+   * Aggregate function: returns the population standard deviation of
+   * the expression in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def stddev_pop(e: Column): Column = StddevPop(e.expr)
+
+  /**
+   * Aggregate function: returns the unbiased sample standard deviation of
+   * the expression in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def stddev_samp(e: Column): Column = StddevSamp(e.expr)
+
   /**
    * Aggregate function: returns the sum of all values in the expression.
    *
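Taken together, the patch exposes the new aggregates through GroupedData
shortcuts, the functions object, and SQL. An illustrative usage sketch (df,
its key/value columns, and the registered table t are assumptions, not part
of the patch):

    import org.apache.spark.sql.functions._

    df.groupBy("key").stddev("value")               // GroupedData shortcut (sample stddev)
    df.groupBy("key").agg(stddev_pop(df("value")))  // population stddev
    df.agg(stddev(df("value")), stddev_samp(df("value")))
    df.sqlContext.sql("SELECT key, stddev_samp(value) FROM t GROUP BY key")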
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index d981ce947f..5f9abd4999 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -90,6 +90,7 @@ public class JavaDataFrameSuite {
     df.groupBy().mean("key");
     df.groupBy().max("key");
     df.groupBy().min("key");
+    df.groupBy().stddev("key");
     df.groupBy().sum("key");
 
     // Varargs in column expressions

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index c0950b09b1..f5ef9ffd7f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -175,6 +175,39 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       Row(0, null))
   }
 
+  test("stddev") {
+    val testData2ADev = math.sqrt(4/5.0)
+
+    checkAnswer(
+      testData2.agg(stddev('a)),
+      Row(testData2ADev))
+
+    checkAnswer(
+      testData2.agg(stddev_pop('a)),
+      Row(math.sqrt(4/6.0)))
+
+    checkAnswer(
+      testData2.agg(stddev_samp('a)),
+      Row(testData2ADev))
+  }
+
+  test("zero stddev") {
+    val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
+    assert(emptyTableData.count() == 0)
+
+    checkAnswer(
+      emptyTableData.agg(stddev('a)),
+      Row(null))
+
+    checkAnswer(
+      emptyTableData.agg(stddev_pop('a)),
+      Row(null))
+
+    checkAnswer(
+      emptyTableData.agg(stddev_samp('a)),
+      Row(null))
+  }
+
   test("zero sum") {
     val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
     checkAnswer(

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index dbed4fc247..c167999af5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -436,7 +436,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val describeResult = Seq(
       Row("count", "4", "4"),
       Row("mean", "33.0", "178.0"),
-      Row("stddev", "16.583123951777", "10.0"),
+      Row("stddev", "19.148542155126762", "11.547005383792516"),
       Row("min", "16", "164"),
       Row("max", "60", "192"))
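The SQLQuerySuite expectations below can be read off the same fixture: within
each group a, testData2's column b is (1, 2), and overall it is
(1, 2, 1, 2, 1, 2). A quick derivation of the constants (the fixture contents
are an assumption here):

    val group = Seq(1.0, 2.0)                             // b within one group
    val grpM2 = group.map(x => math.pow(x - 1.5, 2)).sum  // 0.5
    (grpM2 / (group.size - 1), grpM2 / group.size)        // (0.5, 0.25), sqrt'd in the tests

    val all = Seq.fill(3)(group).flatten                  // b overall
    val allM2 = all.map(x => math.pow(x - 1.5, 2)).sum    // 1.5
    (allM2 / (all.size - 1), allM2 / all.size)            // (1.5/5, 1.5/6)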
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 664b7a1512..962b100b53 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -328,6 +328,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       testCodeGen(
         "SELECT min(key) FROM testData3x",
         Row(1) :: Nil)
+      // STDDEV
+      testCodeGen(
+        "SELECT a, stddev(b), stddev_pop(b) FROM testData2 GROUP BY a",
+        (1 to 3).map(i => Row(i, math.sqrt(0.5), math.sqrt(0.25))))
+      testCodeGen(
+        "SELECT stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2",
+        Row(math.sqrt(1.5 / 5), math.sqrt(1.5 / 6), math.sqrt(1.5 / 5)) :: Nil)
       // Some combinations.
       testCodeGen(
         """
@@ -348,8 +355,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
         Row(100, 1, 50.5, 300, 100) :: Nil)
       // Aggregate with Code generation handling all null values
       testCodeGen(
-        "SELECT sum('a'), avg('a'), count(null) FROM testData",
-        Row(null, null, 0) :: Nil)
+        "SELECT sum('a'), avg('a'), stddev('a'), count(null) FROM testData",
+        Row(null, null, null, 0) :: Nil)
     } finally {
       sqlContext.dropTempTable("testData3x")
       sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue)
@@ -515,8 +522,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
 
   test("aggregates with nulls") {
     checkAnswer(
-      sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"),
-      Row(1, 3, 2, 6, 3)
+      sql("SELECT MIN(a), MAX(a), AVG(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"),
+      Row(1, 3, 2, 1, 6, 3)
     )
   }
 
@@ -722,6 +729,33 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("stddev") {
+    checkAnswer(
+      sql("SELECT STDDEV(a) FROM testData2"),
+      Row(math.sqrt(4/5.0))
+    )
+  }
+
+  test("stddev_pop") {
+    checkAnswer(
+      sql("SELECT STDDEV_POP(a) FROM testData2"),
+      Row(math.sqrt(4/6.0))
+    )
+  }
+
+  test("stddev_samp") {
+    checkAnswer(
+      sql("SELECT STDDEV_SAMP(a) FROM testData2"),
+      Row(math.sqrt(4/5.0))
+    )
+  }
+
+  test("stddev agg") {
+    checkAnswer(
+      sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"),
+      (1 to 3).map(i => Row(i, math.sqrt(1/2.0), math.sqrt(1/4.0), math.sqrt(1/2.0))))
+  }
+
   test("inner join where, one match per row") {
     checkAnswer(
       sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"),

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index b126ec455f..a73b1bd52c 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -507,41 +507,6 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
       }.getMessage
       assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
     }
-
-    // TODO: once we support Hive UDAF in the new interface,
-    // we can remove the following two tests.
-    withSQLConf("spark.sql.useAggregate2" -> "true") {
-      val errorMessage = intercept[AnalysisException] {
-        sqlContext.sql(
-          """
-            |SELECT
-            |  key,
-            |  mydoublesum(value + 1.5 * key),
-            |  stddev_samp(value)
-            |FROM agg1
-            |GROUP BY key
-          """.stripMargin).collect()
-      }.getMessage
-      assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
-
-      // This will fall back to the old aggregate
-      val newAggregateOperators = sqlContext.sql(
-        """
-          |SELECT
-          |  key,
-          |  sum(value + 1.5 * key),
-          |  stddev_samp(value)
-          |FROM agg1
-          |GROUP BY key
-        """.stripMargin).queryExecution.executedPlan.collect {
-        case agg: aggregate.SortBasedAggregate => agg
-        case agg: aggregate.TungstenAggregate => agg
-      }
-      val message =
-        "We should fallback to the old aggregation code path if " +
-        "there is any aggregate function that cannot be converted to the new interface."
-      assert(newAggregateOperators.isEmpty, message)
-    }
   }
 }
--
cgit v1.2.3