aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorJihongMa <linlin200605@gmail.com>2015-09-12 10:17:15 -0700
committerDavies Liu <davies.liu@gmail.com>2015-09-12 10:17:15 -0700
commitf4a22808e03fa12bfe1bfc82cf713cfda7e063a9 (patch)
tree49d22700542e44203793940eb28341e8df573cd5 /sql
parent22730ad54d681ad30e63fe910e8d89360853177d (diff)
downloadspark-f4a22808e03fa12bfe1bfc82cf713cfda7e063a9.tar.gz
spark-f4a22808e03fa12bfe1bfc82cf713cfda7e063a9.tar.bz2
spark-f4a22808e03fa12bfe1bfc82cf713cfda7e063a9.zip
[SPARK-6548] Adding stddev to DataFrame functions
Adding STDDEV support for DataFrame using 1-pass online /parallel algorithm to compute variance. Please review the code change. Author: JihongMa <linlin200605@gmail.com> Author: Jihong MA <linlin200605@gmail.com> Author: Jihong MA <jihongma@jihongs-mbp.usca.ibm.com> Author: Jihong MA <jihongma@Jihongs-MacBook-Pro.local> Closes #6297 from JihongMA/SPARK-SQL.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala3
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala3
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala3
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala143
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala18
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala245
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala6
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala39
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala27
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java1
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala33
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala2
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala42
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala35
14 files changed, 555 insertions, 45 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index cd5a90d788..11b4866bf2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -168,6 +168,9 @@ object FunctionRegistry {
expression[Last]("last"),
expression[Max]("max"),
expression[Min]("min"),
+ expression[Stddev]("stddev"),
+ expression[StddevPop]("stddev_pop"),
+ expression[StddevSamp]("stddev_samp"),
expression[Sum]("sum"),
// string functions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 87c11abbad..87a3845b2d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -297,6 +297,9 @@ object HiveTypeCoercion {
case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType))
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
+ case Stddev(e @ StringType()) => Stddev(Cast(e, DoubleType))
+ case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
+ case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType))
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index a7e3a49327..699c4cc63d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -159,6 +159,9 @@ package object dsl {
def lower(e: Expression): Expression = Lower(e)
def sqrt(e: Expression): Expression = Sqrt(e)
def abs(e: Expression): Expression = Abs(e)
+ def stddev(e: Expression): Expression = Stddev(e)
+ def stddev_pop(e: Expression): Expression = StddevPop(e)
+ def stddev_samp(e: Expression): Expression = StddevSamp(e)
implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name }
// TODO more implicit class for literal?
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
index a73024d6ad..02cd0ac0db 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -249,6 +249,149 @@ case class Min(child: Expression) extends AlgebraicAggregate {
override val evaluateExpression = min
}
+// Compute the sample standard deviation of a column
+case class Stddev(child: Expression) extends StddevAgg(child) {
+
+ override def isSample: Boolean = true
+ override def prettyName: String = "stddev"
+}
+
+// Compute the population standard deviation of a column
+case class StddevPop(child: Expression) extends StddevAgg(child) {
+
+ override def isSample: Boolean = false
+ override def prettyName: String = "stddev_pop"
+}
+
+// Compute the sample standard deviation of a column
+case class StddevSamp(child: Expression) extends StddevAgg(child) {
+
+ override def isSample: Boolean = true
+ override def prettyName: String = "stddev_samp"
+}
+
+// Compute standard deviation based on online algorithm specified here:
+// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
+abstract class StddevAgg(child: Expression) extends AlgebraicAggregate {
+
+ override def children: Seq[Expression] = child :: Nil
+
+ override def nullable: Boolean = true
+
+ def isSample: Boolean
+
+ // Return data type.
+ override def dataType: DataType = resultType
+
+ // Expected input data type.
+ // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the
+ // new version at planning time (after analysis phase). For now, NullType is added at here
+ // to make it resolved when we have cases like `select stddev(null)`.
+ // We can use our analyzer to cast NullType to the default data type of the NumericType once
+ // we remove the old aggregate functions. Then, we will not need NullType at here.
+ override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
+
+ private val resultType = DoubleType
+
+ private val preCount = AttributeReference("preCount", resultType)()
+ private val currentCount = AttributeReference("currentCount", resultType)()
+ private val preAvg = AttributeReference("preAvg", resultType)()
+ private val currentAvg = AttributeReference("currentAvg", resultType)()
+ private val currentMk = AttributeReference("currentMk", resultType)()
+
+ override val bufferAttributes = preCount :: currentCount :: preAvg ::
+ currentAvg :: currentMk :: Nil
+
+ override val initialValues = Seq(
+ /* preCount = */ Cast(Literal(0), resultType),
+ /* currentCount = */ Cast(Literal(0), resultType),
+ /* preAvg = */ Cast(Literal(0), resultType),
+ /* currentAvg = */ Cast(Literal(0), resultType),
+ /* currentMk = */ Cast(Literal(0), resultType)
+ )
+
+ override val updateExpressions = {
+
+ // update average
+ // avg = avg + (value - avg)/count
+ def avgAdd: Expression = {
+ currentAvg + ((Cast(child, resultType) - currentAvg) / currentCount)
+ }
+
+ // update sum of square of difference from mean
+ // Mk = Mk + (value - preAvg) * (value - updatedAvg)
+ def mkAdd: Expression = {
+ val delta1 = Cast(child, resultType) - preAvg
+ val delta2 = Cast(child, resultType) - currentAvg
+ currentMk + (delta1 * delta2)
+ }
+
+ Seq(
+ /* preCount = */ If(IsNull(child), preCount, currentCount),
+ /* currentCount = */ If(IsNull(child), currentCount,
+ Add(currentCount, Cast(Literal(1), resultType))),
+ /* preAvg = */ If(IsNull(child), preAvg, currentAvg),
+ /* currentAvg = */ If(IsNull(child), currentAvg, avgAdd),
+ /* currentMk = */ If(IsNull(child), currentMk, mkAdd)
+ )
+ }
+
+ override val mergeExpressions = {
+
+ // count merge
+ def countMerge: Expression = {
+ currentCount.left + currentCount.right
+ }
+
+ // average merge
+ def avgMerge: Expression = {
+ ((currentAvg.left * preCount) + (currentAvg.right * currentCount.right)) /
+ (preCount + currentCount.right)
+ }
+
+ // update sum of square differences
+ def mkMerge: Expression = {
+ val avgDelta = currentAvg.right - preAvg
+ val mkDelta = (avgDelta * avgDelta) * (preCount * currentCount.right) /
+ (preCount + currentCount.right)
+
+ currentMk.left + currentMk.right + mkDelta
+ }
+
+ Seq(
+ /* preCount = */ If(IsNull(currentCount.left),
+ Cast(Literal(0), resultType), currentCount.left),
+ /* currentCount = */ If(IsNull(currentCount.left), currentCount.right,
+ If(IsNull(currentCount.right), currentCount.left, countMerge)),
+ /* preAvg = */ If(IsNull(currentAvg.left), Cast(Literal(0), resultType), currentAvg.left),
+ /* currentAvg = */ If(IsNull(currentAvg.left), currentAvg.right,
+ If(IsNull(currentAvg.right), currentAvg.left, avgMerge)),
+ /* currentMk = */ If(IsNull(currentMk.left), currentMk.right,
+ If(IsNull(currentMk.right), currentMk.left, mkMerge))
+ )
+ }
+
+ override val evaluateExpression = {
+ // when currentCount == 0, return null
+ // when currentCount == 1, return 0
+  // when currentCount > 1
+ // stddev_samp = sqrt (currentMk/(currentCount -1))
+ // stddev_pop = sqrt (currentMk/currentCount)
+ val varCol = {
+ if (isSample) {
+ currentMk / Cast((currentCount - Cast(Literal(1), resultType)), resultType)
+ }
+ else {
+ currentMk / currentCount
+ }
+ }
+
+ If(EqualTo(currentCount, Cast(Literal(0), resultType)), Cast(Literal(null), resultType),
+ If(EqualTo(currentCount, Cast(Literal(1), resultType)), Cast(Literal(0), resultType),
+ Cast(Sqrt(varCol), resultType)))
+ }
+}
+
case class Sum(child: Expression) extends AlgebraicAggregate {
override def children: Seq[Expression] = child :: Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
index 4a43318a95..ce3dddad87 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
@@ -85,6 +85,24 @@ object Utils {
mode = aggregate.Complete,
isDistinct = false)
+ case expressions.Stddev(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Stddev(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ case expressions.StddevPop(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.StddevPop(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ case expressions.StddevSamp(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.StddevSamp(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
case expressions.Sum(child) =>
aggregate.AggregateExpression2(
aggregateFunction = aggregate.Sum(child),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 5e8298aaaa..f1c47f3904 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -691,3 +691,248 @@ case class LastFunction(expr: Expression, base: AggregateExpression1) extends Ag
result
}
}
+
+// Compute standard deviation based on online algorithm specified here:
+// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
+abstract class StddevAgg1(child: Expression) extends UnaryExpression with PartialAggregate1 {
+ override def nullable: Boolean = true
+ override def dataType: DataType = DoubleType
+
+ def isSample: Boolean
+
+ override def asPartial: SplitEvaluation = {
+ val partialStd = Alias(ComputePartialStd(child), "PartialStddev")()
+ SplitEvaluation(MergePartialStd(partialStd.toAttribute, isSample), partialStd :: Nil)
+ }
+
+ override def newInstance(): StddevFunction = new StddevFunction(child, this, isSample)
+
+ override def checkInputDataTypes(): TypeCheckResult =
+ TypeUtils.checkForNumericExpr(child.dataType, "function stddev")
+
+}
+
+// Compute the sample standard deviation of a column
+case class Stddev(child: Expression) extends StddevAgg1(child) {
+
+ override def toString: String = s"STDDEV($child)"
+ override def isSample: Boolean = true
+}
+
+// Compute the population standard deviation of a column
+case class StddevPop(child: Expression) extends StddevAgg1(child) {
+
+ override def toString: String = s"STDDEV_POP($child)"
+ override def isSample: Boolean = false
+}
+
+// Compute the sample standard deviation of a column
+case class StddevSamp(child: Expression) extends StddevAgg1(child) {
+
+ override def toString: String = s"STDDEV_SAMP($child)"
+ override def isSample: Boolean = true
+}
+
+case class ComputePartialStd(child: Expression) extends UnaryExpression with AggregateExpression1 {
+ def this() = this(null)
+
+ override def children: Seq[Expression] = child :: Nil
+ override def nullable: Boolean = false
+ override def dataType: DataType = ArrayType(DoubleType)
+ override def toString: String = s"computePartialStddev($child)"
+ override def newInstance(): ComputePartialStdFunction =
+ new ComputePartialStdFunction(child, this)
+}
+
+case class ComputePartialStdFunction (
+ expr: Expression,
+ base: AggregateExpression1
+) extends AggregateFunction1 {
+ def this() = this(null, null) // Required for serialization
+
+ private val computeType = DoubleType
+ private val zero = Cast(Literal(0), computeType)
+ private var partialCount: Long = 0L
+
+ // the mean of data processed so far
+ private val partialAvg: MutableLiteral = MutableLiteral(zero.eval(null), computeType)
+
+ // update average based on this formula:
+ // avg = avg + (value - avg)/count
+ private def avgAddFunction (value: Literal): Expression = {
+ val delta = Subtract(Cast(value, computeType), partialAvg)
+ Add(partialAvg, Divide(delta, Cast(Literal(partialCount), computeType)))
+ }
+
+ // the sum of squares of difference from mean
+ private val partialMk: MutableLiteral = MutableLiteral(zero.eval(null), computeType)
+
+ // update sum of square of difference from mean based on following formula:
+ // Mk = Mk + (value - preAvg) * (value - updatedAvg)
+ private def mkAddFunction(value: Literal, prePartialAvg: MutableLiteral): Expression = {
+ val delta1 = Subtract(Cast(value, computeType), prePartialAvg)
+ val delta2 = Subtract(Cast(value, computeType), partialAvg)
+ Add(partialMk, Multiply(delta1, delta2))
+ }
+
+ override def update(input: InternalRow): Unit = {
+ val evaluatedExpr = expr.eval(input)
+ if (evaluatedExpr != null) {
+ val exprValue = Literal.create(evaluatedExpr, expr.dataType)
+ val prePartialAvg = partialAvg.copy()
+ partialCount += 1
+ partialAvg.update(avgAddFunction(exprValue), input)
+ partialMk.update(mkAddFunction(exprValue, prePartialAvg), input)
+ }
+ }
+
+ override def eval(input: InternalRow): Any = {
+ new GenericArrayData(Array(Cast(Literal(partialCount), computeType).eval(null),
+ partialAvg.eval(null),
+ partialMk.eval(null)))
+ }
+}
+
+case class MergePartialStd(
+ child: Expression,
+ isSample: Boolean
+) extends UnaryExpression with AggregateExpression1 {
+ def this() = this(null, false) // required for serialization
+
+ override def children: Seq[Expression] = child:: Nil
+ override def nullable: Boolean = false
+ override def dataType: DataType = DoubleType
+ override def toString: String = s"MergePartialStd($child)"
+ override def newInstance(): MergePartialStdFunction = {
+ new MergePartialStdFunction(child, this, isSample)
+ }
+}
+
+case class MergePartialStdFunction(
+ expr: Expression,
+ base: AggregateExpression1,
+ isSample: Boolean
+) extends AggregateFunction1 {
+ def this() = this (null, null, false) // Required for serialization
+
+ private val computeType = DoubleType
+ private val zero = Cast(Literal(0), computeType)
+ private val combineCount = MutableLiteral(zero.eval(null), computeType)
+ private val combineAvg = MutableLiteral(zero.eval(null), computeType)
+ private val combineMk = MutableLiteral(zero.eval(null), computeType)
+
+ private def avgUpdateFunction(preCount: Expression,
+ partialCount: Expression,
+ partialAvg: Expression): Expression = {
+ Divide(Add(Multiply(combineAvg, preCount),
+ Multiply(partialAvg, partialCount)),
+ Add(preCount, partialCount))
+ }
+
+ override def update(input: InternalRow): Unit = {
+ val evaluatedExpr = expr.eval(input).asInstanceOf[ArrayData]
+
+ if (evaluatedExpr != null) {
+ val exprValue = evaluatedExpr.toArray(computeType)
+ val (partialCount, partialAvg, partialMk) =
+ (Literal.create(exprValue(0), computeType),
+ Literal.create(exprValue(1), computeType),
+ Literal.create(exprValue(2), computeType))
+
+ if (Cast(partialCount, LongType).eval(null).asInstanceOf[Long] > 0) {
+ val preCount = combineCount.copy()
+ combineCount.update(Add(combineCount, partialCount), input)
+
+ val preAvg = combineAvg.copy()
+ val avgDelta = Subtract(partialAvg, preAvg)
+ val mkDelta = Multiply(Multiply(avgDelta, avgDelta),
+ Divide(Multiply(preCount, partialCount),
+ combineCount))
+
+ // update average based on following formula
+ // (combineAvg * preCount + partialAvg * partialCount) / (preCount + partialCount)
+ combineAvg.update(avgUpdateFunction(preCount, partialCount, partialAvg), input)
+
+ // update sum of square differences from mean based on following formula
+ // (combineMk + partialMk + (avgDelta * avgDelta) * (preCount * partialCount/combineCount)
+ combineMk.update(Add(combineMk, Add(partialMk, mkDelta)), input)
+ }
+ }
+ }
+
+ override def eval(input: InternalRow): Any = {
+ val count: Long = Cast(combineCount, LongType).eval(null).asInstanceOf[Long]
+
+ if (count == 0) null
+ else if (count < 2) zero.eval(null)
+ else {
+      // when total count >= 2
+ // stddev_samp = sqrt (combineMk/(combineCount -1))
+ // stddev_pop = sqrt (combineMk/combineCount)
+ val varCol = {
+ if (isSample) {
+ Divide(combineMk, Cast(Literal(count - 1), computeType))
+ }
+ else {
+ Divide(combineMk, Cast(Literal(count), computeType))
+ }
+ }
+ Sqrt(varCol).eval(null)
+ }
+ }
+}
+
+case class StddevFunction(
+ expr: Expression,
+ base: AggregateExpression1,
+ isSample: Boolean
+) extends AggregateFunction1 {
+
+ def this() = this(null, null, false) // Required for serialization
+
+ private val computeType = DoubleType
+ private var curCount: Long = 0L
+ private val zero = Cast(Literal(0), computeType)
+ private val curAvg = MutableLiteral(zero.eval(null), computeType)
+ private val curMk = MutableLiteral(zero.eval(null), computeType)
+
+ private def curAvgAddFunction(value: Literal): Expression = {
+ val delta = Subtract(Cast(value, computeType), curAvg)
+ Add(curAvg, Divide(delta, Cast(Literal(curCount), computeType)))
+ }
+ private def curMkAddFunction(value: Literal, preAvg: MutableLiteral): Expression = {
+ val delta1 = Subtract(Cast(value, computeType), preAvg)
+ val delta2 = Subtract(Cast(value, computeType), curAvg)
+ Add(curMk, Multiply(delta1, delta2))
+ }
+
+ override def update(input: InternalRow): Unit = {
+ val evaluatedExpr = expr.eval(input)
+ if (evaluatedExpr != null) {
+ val preAvg: MutableLiteral = curAvg.copy()
+ val exprValue = Literal.create(evaluatedExpr, expr.dataType)
+ curCount += 1L
+ curAvg.update(curAvgAddFunction(exprValue), input)
+ curMk.update(curMkAddFunction(exprValue, preAvg), input)
+ }
+ }
+
+ override def eval(input: InternalRow): Any = {
+ if (curCount == 0) null
+ else if (curCount < 2) zero.eval(null)
+ else {
+    // when total count >= 2,
+ // stddev_samp = sqrt(curMk/(curCount - 1))
+ // stddev_pop = sqrt(curMk/curCount)
+ val varCol = {
+ if (isSample) {
+ Divide(curMk, Cast(Literal(curCount - 1), computeType))
+ }
+ else {
+ Divide(curMk, Cast(Literal(curCount), computeType))
+ }
+ }
+ Sqrt(varCol).eval(null)
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 791c10c3d7..1a687b2374 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1288,15 +1288,11 @@ class DataFrame private[sql](
@scala.annotation.varargs
def describe(cols: String*): DataFrame = {
- // TODO: Add stddev as an expression, and remove it from here.
- def stddevExpr(expr: Expression): Expression =
- Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr))))
-
// The list of summary statistics to compute, in the form of expressions.
val statistics = List[(String, Expression => Expression)](
"count" -> Count,
"mean" -> Average,
- "stddev" -> stddevExpr,
+ "stddev" -> Stddev,
"min" -> Min,
"max" -> Max)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index ee31d83cce..102b802ad0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -124,6 +124,9 @@ class GroupedData protected[sql](
case "avg" | "average" | "mean" => Average
case "max" => Max
case "min" => Min
+ case "stddev" => Stddev
+ case "stddev_pop" => StddevPop
+ case "stddev_samp" => StddevSamp
case "sum" => Sum
case "count" | "size" =>
// Turn count(*) into count(1)
@@ -284,6 +287,42 @@ class GroupedData protected[sql](
}
/**
+   * Compute the sample standard deviation for each numeric column for each group.
+ * The resulting [[DataFrame]] will also contain the grouping columns.
+ * When specified columns are given, only compute the stddev for them.
+ *
+ * @since 1.6.0
+ */
+ @scala.annotation.varargs
+ def stddev(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames : _*)(Stddev)
+ }
+
+ /**
+   * Compute the population standard deviation for each numeric column for each group.
+ * The resulting [[DataFrame]] will also contain the grouping columns.
+ * When specified columns are given, only compute the stddev for them.
+ *
+ * @since 1.6.0
+ */
+ @scala.annotation.varargs
+ def stddev_pop(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames : _*)(StddevPop)
+ }
+
+ /**
+   * Compute the sample standard deviation for each numeric column for each group.
+ * The resulting [[DataFrame]] will also contain the grouping columns.
+ * When specified columns are given, only compute the stddev for them.
+ *
+ * @since 1.6.0
+ */
+ @scala.annotation.varargs
+ def stddev_samp(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames : _*)(StddevSamp)
+ }
+
+ /**
* Compute the sum for each numeric columns for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
* When specified columns are given, only compute the sum for them.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 435e6319a6..60d9c50910 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -295,6 +295,33 @@ object functions {
def min(columnName: String): Column = min(Column(columnName))
/**
+ * Aggregate function: returns the unbiased sample standard deviation
+ * of the expression in a group.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def stddev(e: Column): Column = Stddev(e.expr)
+
+ /**
+ * Aggregate function: returns the population standard deviation of
+ * the expression in a group.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def stddev_pop(e: Column): Column = StddevPop(e.expr)
+
+ /**
+ * Aggregate function: returns the unbiased sample standard deviation of
+ * the expression in a group.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def stddev_samp(e: Column): Column = StddevSamp(e.expr)
+
+ /**
* Aggregate function: returns the sum of all values in the expression.
*
* @group agg_funcs
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index d981ce947f..5f9abd4999 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -90,6 +90,7 @@ public class JavaDataFrameSuite {
df.groupBy().mean("key");
df.groupBy().max("key");
df.groupBy().min("key");
+ df.groupBy().stddev("key");
df.groupBy().sum("key");
// Varargs in column expressions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index c0950b09b1..f5ef9ffd7f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -175,6 +175,39 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
Row(0, null))
}
+ test("stddev") {
+ val testData2ADev = math.sqrt(4/5.0)
+
+ checkAnswer(
+ testData2.agg(stddev('a)),
+ Row(testData2ADev))
+
+ checkAnswer(
+ testData2.agg(stddev_pop('a)),
+ Row(math.sqrt(4/6.0)))
+
+ checkAnswer(
+ testData2.agg(stddev_samp('a)),
+ Row(testData2ADev))
+ }
+
+ test("zero stddev") {
+ val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
+ assert(emptyTableData.count() == 0)
+
+ checkAnswer(
+ emptyTableData.agg(stddev('a)),
+ Row(null))
+
+ checkAnswer(
+ emptyTableData.agg(stddev_pop('a)),
+ Row(null))
+
+ checkAnswer(
+ emptyTableData.agg(stddev_samp('a)),
+ Row(null))
+ }
+
test("zero sum") {
val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
checkAnswer(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index dbed4fc247..c167999af5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -436,7 +436,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val describeResult = Seq(
Row("count", "4", "4"),
Row("mean", "33.0", "178.0"),
- Row("stddev", "16.583123951777", "10.0"),
+ Row("stddev", "19.148542155126762", "11.547005383792516"),
Row("min", "16", "164"),
Row("max", "60", "192"))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 664b7a1512..962b100b53 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -328,6 +328,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
testCodeGen(
"SELECT min(key) FROM testData3x",
Row(1) :: Nil)
+ // STDDEV
+ testCodeGen(
+ "SELECT a, stddev(b), stddev_pop(b) FROM testData2 GROUP BY a",
+ (1 to 3).map(i => Row(i, math.sqrt(0.5), math.sqrt(0.25))))
+ testCodeGen(
+ "SELECT stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2",
+ Row(math.sqrt(1.5 / 5), math.sqrt(1.5 / 6), math.sqrt(1.5 / 5)) :: Nil)
// Some combinations.
testCodeGen(
"""
@@ -348,8 +355,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
Row(100, 1, 50.5, 300, 100) :: Nil)
// Aggregate with Code generation handling all null values
testCodeGen(
- "SELECT sum('a'), avg('a'), count(null) FROM testData",
- Row(null, null, 0) :: Nil)
+ "SELECT sum('a'), avg('a'), stddev('a'), count(null) FROM testData",
+ Row(null, null, null, 0) :: Nil)
} finally {
sqlContext.dropTempTable("testData3x")
sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue)
@@ -515,8 +522,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("aggregates with nulls") {
checkAnswer(
- sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"),
- Row(1, 3, 2, 6, 3)
+ sql("SELECT MIN(a), MAX(a), AVG(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"),
+ Row(1, 3, 2, 1, 6, 3)
)
}
@@ -722,6 +729,33 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
}
+ test("stddev") {
+ checkAnswer(
+ sql("SELECT STDDEV(a) FROM testData2"),
+ Row(math.sqrt(4/5.0))
+ )
+ }
+
+ test("stddev_pop") {
+ checkAnswer(
+ sql("SELECT STDDEV_POP(a) FROM testData2"),
+ Row(math.sqrt(4/6.0))
+ )
+ }
+
+ test("stddev_samp") {
+ checkAnswer(
+ sql("SELECT STDDEV_SAMP(a) FROM testData2"),
+ Row(math.sqrt(4/5.0))
+ )
+ }
+
+ test("stddev agg") {
+ checkAnswer(
+ sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"),
+ (1 to 3).map(i => Row(i, math.sqrt(1/2.0), math.sqrt(1/4.0), math.sqrt(1/2.0))))
+ }
+
test("inner join where, one match per row") {
checkAnswer(
sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"),
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index b126ec455f..a73b1bd52c 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -507,41 +507,6 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
}.getMessage
assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
}
-
- // TODO: once we support Hive UDAF in the new interface,
- // we can remove the following two tests.
- withSQLConf("spark.sql.useAggregate2" -> "true") {
- val errorMessage = intercept[AnalysisException] {
- sqlContext.sql(
- """
- |SELECT
- | key,
- | mydoublesum(value + 1.5 * key),
- | stddev_samp(value)
- |FROM agg1
- |GROUP BY key
- """.stripMargin).collect()
- }.getMessage
- assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
-
- // This will fall back to the old aggregate
- val newAggregateOperators = sqlContext.sql(
- """
- |SELECT
- | key,
- | sum(value + 1.5 * key),
- | stddev_samp(value)
- |FROM agg1
- |GROUP BY key
- """.stripMargin).queryExecution.executedPlan.collect {
- case agg: aggregate.SortBasedAggregate => agg
- case agg: aggregate.TungstenAggregate => agg
- }
- val message =
- "We should fallback to the old aggregation code path if " +
- "there is any aggregate function that cannot be converted to the new interface."
- assert(newAggregateOperators.isEmpty, message)
- }
}
}