| author | sethah <seth.hendrickson16@gmail.com> | 2015-10-29 11:58:39 -0700 |
|---|---|---|
| committer | Xiangrui Meng <meng@databricks.com> | 2015-10-29 11:58:39 -0700 |
| commit | a01cbf5daac148f39cd97299780f542abc41d1e9 (patch) | |
| tree | 357dfc7f8e7784dc36cbb4f77212e84d0809d1df | |
| parent | 8185f038c13c72e1bea7b0921b84125b7a352139 (diff) | |
[SPARK-10641][SQL] Add Skewness and Kurtosis Support
Implementing skewness and kurtosis support based on the following algorithm:
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics
Author: sethah <seth.hendrickson16@gmail.com>
Closes #9003 from sethah/SPARK-10641.
12 files changed, 823 insertions, 11 deletions
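
For intuition, here is a minimal self-contained Scala sketch of the one-pass central-moment updates this patch implements in CentralMomentAgg.update, in the same order (each higher moment reuses the already-updated lower moments, following Meng 2015). The MomentState class and its names are illustrative only, not part of the patch:

    // Hypothetical standalone model of the streaming central-moment updates
    // (orders up to 4) used by CentralMomentAgg; not a Spark API.
    case class MomentState(n: Double, mean: Double, m2: Double, m3: Double, m4: Double) {

      // Fold in one observation x. Note that newM3 uses the already-updated
      // newM2, and newM4 uses newM3 and newM2, exactly as in the patch.
      def update(x: Double): MomentState = {
        val n1 = n + 1.0
        val delta = x - mean
        val deltaN = delta / n1
        val newM2 = m2 + delta * (delta - deltaN)
        val newM3 = m3 - 3.0 * deltaN * newM2 + delta * (delta * delta - deltaN * deltaN)
        val newM4 = m4 - 4.0 * deltaN * newM3 - 6.0 * deltaN * deltaN * newM2 +
          delta * (delta * delta * delta - deltaN * deltaN * deltaN)
        MomentState(n1, mean + deltaN, newM2, newM3, newM4)
      }

      // Final statistics, matching the getStatistic implementations below.
      def variancePop: Double = if (n == 0.0) Double.NaN else m2 / n
      def skewness: Double =
        if (n == 0.0 || m2 == 0.0) Double.NaN else math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2)
      def kurtosis: Double = // excess kurtosis
        if (n == 0.0 || m2 == 0.0) Double.NaN else n * m4 / (m2 * m2) - 3.0
    }

    object MomentState {
      val zero = MomentState(0.0, 0.0, 0.0, 0.0, 0.0)
    }

For example, Seq(1.0, 2.0, 3.0).foldLeft(MomentState.zero)(_ update _) yields skewness 0.0 and kurtosis -1.5, the same values the new tests below expect for column a of testData2.
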
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 3dce6c1a27..ed9fcfe014 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -189,6 +189,11 @@ object FunctionRegistry {
     expression[StddevPop]("stddev_pop"),
     expression[StddevSamp]("stddev_samp"),
     expression[Sum]("sum"),
+    expression[Variance]("variance"),
+    expression[VariancePop]("var_pop"),
+    expression[VarianceSamp]("var_samp"),
+    expression[Skewness]("skewness"),
+    expression[Kurtosis]("kurtosis"),

     // string functions
     expression[Ascii]("ascii"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 1140150f66..3c675672da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -300,6 +300,11 @@ object HiveTypeCoercion {
       case Stddev(e @ StringType()) => Stddev(Cast(e, DoubleType))
       case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
       case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType))
+      case Variance(e @ StringType()) => Variance(Cast(e, DoubleType))
+      case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType))
+      case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType))
+      case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType))
+      case Kurtosis(e @ StringType()) => Kurtosis(Cast(e, DoubleType))
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 27b3cd84b3..787f67a297 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -162,6 +162,11 @@ package object dsl {
     def stddev(e: Expression): Expression = Stddev(e)
     def stddev_pop(e: Expression): Expression = StddevPop(e)
     def stddev_samp(e: Expression): Expression = StddevSamp(e)
+    def variance(e: Expression): Expression = Variance(e)
+    def var_pop(e: Expression): Expression = VariancePop(e)
+    def var_samp(e: Expression): Expression = VarianceSamp(e)
+    def skewness(e: Expression): Expression = Skewness(e)
+    def kurtosis(e: Expression): Expression = Kurtosis(e)

     implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name }
     // TODO more implicit class for literal?
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
index 515246d344..281404f285 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -930,3 +930,332 @@ object HyperLogLogPlusPlus {
   )
   // scalastyle:on
 }
+
+/**
+ * A central moment is the expected value of a specified power of the deviation of a random
+ * variable from the mean.
+ * Central moments are often used to characterize the shape of a distribution.
+ *
+ * This class implements online, one-pass algorithms for computing the central moments of a set of
+ * points.
+ *
+ * Behavior:
+ *  - null values are ignored
+ *  - returns `Double.NaN` when the column contains `Double.NaN` values
+ *
+ * References:
+ *  - Xiangrui Meng. "Simpler Online Updates for Arbitrary-Order Central Moments."
+ *    2015. http://arxiv.org/abs/1510.04923
+ *
+ * @see [[https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
+ *     Algorithms for calculating variance (Wikipedia)]]
+ *
+ * @param child the expression to compute central moments of
+ */
+abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate with Serializable {
+
+  /**
+   * The central moment order to be computed.
+   */
+  protected def momentOrder: Int
+
+  override def children: Seq[Expression] = Seq(child)
+
+  override def nullable: Boolean = false
+
+  override def dataType: DataType = DoubleType
+
+  // Expected input data type.
+  // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) with the
+  // new version at planning time (after the analysis phase). For now, NullType is included here
+  // so that cases like `select avg(null)` resolve.
+  // Once the old aggregate functions are removed, the analyzer can cast NullType to the default
+  // data type of NumericType, and NullType will no longer be needed here.
+  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
+
+  override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
+
+  /**
+   * Size of aggregation buffer.
+   */
+  private[this] val bufferSize = 5
+
+  override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(bufferSize) { i =>
+    AttributeReference(s"M$i", DoubleType)()
+  }
+
+  // Note: although this simply copies aggBufferAttributes, this common code cannot be placed
+  // in the superclass because that will lead to initialization ordering issues.
+  override val inputAggBufferAttributes: Seq[AttributeReference] =
+    aggBufferAttributes.map(_.newInstance())
+
+  // buffer offsets
+  private[this] val nOffset = mutableAggBufferOffset
+  private[this] val meanOffset = mutableAggBufferOffset + 1
+  private[this] val secondMomentOffset = mutableAggBufferOffset + 2
+  private[this] val thirdMomentOffset = mutableAggBufferOffset + 3
+  private[this] val fourthMomentOffset = mutableAggBufferOffset + 4
+
+  // frequently used values for online updates
+  private[this] var delta = 0.0
+  private[this] var deltaN = 0.0
+  private[this] var delta2 = 0.0
+  private[this] var deltaN2 = 0.0
+  private[this] var n = 0.0
+  private[this] var mean = 0.0
+  private[this] var m2 = 0.0
+  private[this] var m3 = 0.0
+  private[this] var m4 = 0.0
+
+  /**
+   * Initialize all moments to zero.
+   */
+  override def initialize(buffer: MutableRow): Unit = {
+    for (aggIndex <- 0 until bufferSize) {
+      buffer.setDouble(mutableAggBufferOffset + aggIndex, 0.0)
+    }
+  }
+
+  /**
+   * Update the central moments buffer.
+   */
+  override def update(buffer: MutableRow, input: InternalRow): Unit = {
+    val v = Cast(child, DoubleType).eval(input)
+    if (v != null) {
+      val updateValue = v match {
+        case d: Double => d
+      }
+
+      n = buffer.getDouble(nOffset)
+      mean = buffer.getDouble(meanOffset)
+
+      n += 1.0
+      buffer.setDouble(nOffset, n)
+      delta = updateValue - mean
+      deltaN = delta / n
+      mean += deltaN
+      buffer.setDouble(meanOffset, mean)
+
+      if (momentOrder >= 2) {
+        m2 = buffer.getDouble(secondMomentOffset)
+        m2 += delta * (delta - deltaN)
+        buffer.setDouble(secondMomentOffset, m2)
+      }
+
+      if (momentOrder >= 3) {
+        delta2 = delta * delta
+        deltaN2 = deltaN * deltaN
+        m3 = buffer.getDouble(thirdMomentOffset)
+        m3 += -3.0 * deltaN * m2 + delta * (delta2 - deltaN2)
+        buffer.setDouble(thirdMomentOffset, m3)
+      }
+
+      if (momentOrder >= 4) {
+        m4 = buffer.getDouble(fourthMomentOffset)
+        m4 += -4.0 * deltaN * m3 - 6.0 * deltaN2 * m2 +
+          delta * (delta * delta2 - deltaN * deltaN2)
+        buffer.setDouble(fourthMomentOffset, m4)
+      }
+    }
+  }
+
+  /**
+   * Merge two central moment buffers.
+   */
+  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+    val n1 = buffer1.getDouble(nOffset)
+    val n2 = buffer2.getDouble(inputAggBufferOffset)
+    val mean1 = buffer1.getDouble(meanOffset)
+    val mean2 = buffer2.getDouble(inputAggBufferOffset + 1)
+
+    var secondMoment1 = 0.0
+    var secondMoment2 = 0.0
+
+    var thirdMoment1 = 0.0
+    var thirdMoment2 = 0.0
+
+    var fourthMoment1 = 0.0
+    var fourthMoment2 = 0.0
+
+    n = n1 + n2
+    buffer1.setDouble(nOffset, n)
+    delta = mean2 - mean1
+    deltaN = if (n == 0.0) 0.0 else delta / n
+    mean = mean1 + deltaN * n2
+    buffer1.setDouble(mutableAggBufferOffset + 1, mean)
+
+    // higher order moments computed according to:
+    // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics
+    if (momentOrder >= 2) {
+      secondMoment1 = buffer1.getDouble(secondMomentOffset)
+      secondMoment2 = buffer2.getDouble(inputAggBufferOffset + 2)
+      m2 = secondMoment1 + secondMoment2 + delta * deltaN * n1 * n2
+      buffer1.setDouble(secondMomentOffset, m2)
+    }
+
+    if (momentOrder >= 3) {
+      thirdMoment1 = buffer1.getDouble(thirdMomentOffset)
+      thirdMoment2 = buffer2.getDouble(inputAggBufferOffset + 3)
+      m3 = thirdMoment1 + thirdMoment2 + deltaN * deltaN * delta * n1 * n2 *
+        (n1 - n2) + 3.0 * deltaN * (n1 * secondMoment2 - n2 * secondMoment1)
+      buffer1.setDouble(thirdMomentOffset, m3)
+    }
+
+    if (momentOrder >= 4) {
+      fourthMoment1 = buffer1.getDouble(fourthMomentOffset)
+      fourthMoment2 = buffer2.getDouble(inputAggBufferOffset + 4)
+      m4 = fourthMoment1 + fourthMoment2 + deltaN * deltaN * deltaN * delta * n1 *
+        n2 * (n1 * n1 - n1 * n2 + n2 * n2) + deltaN * deltaN * 6.0 *
+        (n1 * n1 * secondMoment2 + n2 * n2 * secondMoment1) +
+        4.0 * deltaN * (n1 * thirdMoment2 - n2 * thirdMoment1)
+      buffer1.setDouble(fourthMomentOffset, m4)
+    }
+  }
+
+  /**
+   * Compute aggregate statistic from sufficient moments.
+   * @param centralMoments Length `momentOrder + 1` array of central moments (un-normalized)
+   *                       needed to compute the aggregate stat.
+   */
+  def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Double
+
+  override final def eval(buffer: InternalRow): Any = {
+    val n = buffer.getDouble(nOffset)
+    val mean = buffer.getDouble(meanOffset)
+    val moments = Array.ofDim[Double](momentOrder + 1)
+    moments(0) = 1.0
+    moments(1) = 0.0
+    if (momentOrder >= 2) {
+      moments(2) = buffer.getDouble(secondMomentOffset)
+    }
+    if (momentOrder >= 3) {
+      moments(3) = buffer.getDouble(thirdMomentOffset)
+    }
+    if (momentOrder >= 4) {
+      moments(4) = buffer.getDouble(fourthMomentOffset)
+    }
+
+    getStatistic(n, mean, moments)
+  }
+}
+
+case class Variance(child: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) {
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def prettyName: String = "variance"
+
+  override protected val momentOrder = 2
+
+  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+    require(moments.length == momentOrder + 1,
+      s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
+
+    if (n == 0.0) Double.NaN else moments(2) / n
+  }
+}
+
+case class VarianceSamp(child: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) {
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def prettyName: String = "variance_samp"
+
+  override protected val momentOrder = 2
+
+  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+    require(moments.length == momentOrder + 1,
+      s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
+
+    if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / (n - 1.0)
+  }
+}
+
+case class VariancePop(child: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) {
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def prettyName: String = "variance_pop"
+
+  override protected val momentOrder = 2
+
+  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+    require(moments.length == momentOrder + 1,
+      s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
+
+    if (n == 0.0) Double.NaN else moments(2) / n
+  }
+}
+
+case class Skewness(child: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) {
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def prettyName: String = "skewness"
+
+  override protected val momentOrder = 3
+
+  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+    require(moments.length == momentOrder + 1,
+      s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
+    val m2 = moments(2)
+    val m3 = moments(3)
+    if (n == 0.0 || m2 == 0.0) {
+      Double.NaN
+    } else {
+      math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2)
+    }
+  }
+}
+
+case class Kurtosis(child: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) {
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def prettyName: String = "kurtosis"
+
+  override protected val momentOrder = 4
+
+  // NOTE: this is the formula for excess kurtosis, which is the default for R and SciPy
+  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+    require(moments.length == momentOrder + 1,
+      s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
+    val m2 = moments(2)
+    val m4 = moments(4)
+    if (n == 0.0 || m2 == 0.0) {
+      Double.NaN
+    } else {
+      n * m4 / (m2 * m2) - 3.0
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
index 12bdab0915..c911ec53f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
@@ -67,6 +67,12 @@ object Utils {
           mode = aggregate.Complete,
           isDistinct = false)

+      case expressions.Kurtosis(child) =>
+        aggregate.AggregateExpression2(
+          aggregateFunction = aggregate.Kurtosis(child),
+          mode = aggregate.Complete,
+          isDistinct = false)
+
       case expressions.Last(child, ignoreNulls) =>
         aggregate.AggregateExpression2(
           aggregateFunction = aggregate.Last(child, ignoreNulls),
@@ -85,6 +91,12 @@ object Utils {
           mode = aggregate.Complete,
           isDistinct = false)

+      case expressions.Skewness(child) =>
+        aggregate.AggregateExpression2(
+          aggregateFunction = aggregate.Skewness(child),
+          mode = aggregate.Complete,
+          isDistinct = false)
+
       case expressions.Stddev(child) =>
         aggregate.AggregateExpression2(
           aggregateFunction = aggregate.Stddev(child),
@@ -120,6 +132,24 @@ object Utils {
           aggregateFunction = aggregate.HyperLogLogPlusPlus(child, rsd),
           mode = aggregate.Complete,
           isDistinct = false)
+
+      case expressions.Variance(child) =>
+        aggregate.AggregateExpression2(
+          aggregateFunction = aggregate.Variance(child),
+          mode = aggregate.Complete,
+          isDistinct = false)
+
+      case expressions.VariancePop(child) =>
+        aggregate.AggregateExpression2(
+          aggregateFunction = aggregate.VariancePop(child),
+          mode = aggregate.Complete,
+          isDistinct = false)
+
+      case expressions.VarianceSamp(child) =>
+        aggregate.AggregateExpression2(
+          aggregateFunction = aggregate.VarianceSamp(child),
+          mode = aggregate.Complete,
+          isDistinct = false)
     }

     // Check if there is any expressions.AggregateExpression1 left.
     // If so, we cannot convert this plan.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 70819be5af..c1bab6d36a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -991,3 +991,98 @@ case class StddevFunction(
     }
   }
 }
+
+// placeholder
+case class Kurtosis(child: Expression) extends UnaryExpression with AggregateExpression1 {
+
+  override def newInstance(): AggregateFunction1 = {
+    throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " +
+      "please set spark.sql.useAggregate2 = true")
+  }
+
+  override def nullable: Boolean = false
+
+  override def dataType: DoubleType.type = DoubleType
+
+  override def foldable: Boolean = false
+
+  override def prettyName: String = "kurtosis"
+
+  override def toString: String = s"KURTOSIS($child)"
+}
+
+// placeholder
+case class Skewness(child: Expression) extends UnaryExpression with AggregateExpression1 {
+
+  override def newInstance(): AggregateFunction1 = {
+    throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " +
+      "please set spark.sql.useAggregate2 = true")
+  }
+
+  override def nullable: Boolean = false
+
+  override def dataType: DoubleType.type = DoubleType
+
+  override def foldable: Boolean = false
+
+  override def prettyName: String = "skewness"
+
+  override def toString: String = s"SKEWNESS($child)"
+}
+
+// placeholder
+case class Variance(child: Expression) extends UnaryExpression with AggregateExpression1 {
+
+  override def newInstance(): AggregateFunction1 = {
+    throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " +
+      "please set spark.sql.useAggregate2 = true")
+  }
+
+  override def nullable: Boolean = false
+
+  override def dataType: DoubleType.type = DoubleType
+
+  override def foldable: Boolean = false
+
+  override def prettyName: String = "variance"
+
+  override def toString: String = s"VARIANCE($child)"
+}
+
+// placeholder
+case class VariancePop(child: Expression) extends UnaryExpression with AggregateExpression1 {
+
+  override def newInstance(): AggregateFunction1 = {
+    throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " +
+      "please set spark.sql.useAggregate2 = true")
+  }
+
+  override def nullable: Boolean = false
+
+  override def dataType: DoubleType.type = DoubleType
+
+  override def foldable: Boolean = false
+
+  override def prettyName: String = "variance_pop"
+
+  override def toString: String = s"VAR_POP($child)"
+}
+
+// placeholder
+case class VarianceSamp(child: Expression) extends UnaryExpression with AggregateExpression1 {
+
+  override def newInstance(): AggregateFunction1 = {
+    throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " +
+      "please set spark.sql.useAggregate2 = true")
+  }
+
+  override def nullable: Boolean = false
+
+  override def dataType: DoubleType.type = DoubleType
+
+  override def foldable: Boolean = false
+
+  override def prettyName: String = "variance_samp"
+
+  override def toString: String = s"VAR_SAMP($child)"
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 102b802ad0..dc96384a4d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -127,7 +127,12 @@ class GroupedData protected[sql](
       case "stddev" => Stddev
       case "stddev_pop" => StddevPop
       case "stddev_samp" => StddevSamp
+      case "variance" => Variance
+      case "var_pop" => VariancePop
+      case "var_samp" => VarianceSamp
       case "sum" => Sum
+      case "skewness" => Skewness
+      case "kurtosis" => Kurtosis
       case "count" | "size" =>
         // Turn count(*) into count(1)
         (inputExpr: Expression) => inputExpr match {
@@ -251,6 +256,30 @@ class GroupedData protected[sql](
   }

   /**
+   * Compute the skewness for each numeric column for each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
+   * When specified columns are given, only compute the skewness values for them.
+   *
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def skewness(colNames: String*): DataFrame = {
+    aggregateNumericColumns(colNames : _*)(Skewness)
+  }
+
+  /**
+   * Compute the kurtosis for each numeric column for each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
+   * When specified columns are given, only compute the kurtosis values for them.
+   *
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def kurtosis(colNames: String*): DataFrame = {
+    aggregateNumericColumns(colNames : _*)(Kurtosis)
+  }
+
+  /**
    * Compute the max value for each numeric columns for each group.
    * The resulting [[DataFrame]] will also contain the grouping columns.
    * When specified columns are given, only compute the max values for them.
@@ -333,4 +362,40 @@ class GroupedData protected[sql](
   def sum(colNames: String*): DataFrame = {
     aggregateNumericColumns(colNames : _*)(Sum)
   }
+
+  /**
+   * Compute the population variance for each numeric column for each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
+   * When specified columns are given, only compute the variance for them.
+   *
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def variance(colNames: String*): DataFrame = {
+    aggregateNumericColumns(colNames : _*)(Variance)
+  }
+
+  /**
+   * Compute the population variance for each numeric column for each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
+   * When specified columns are given, only compute the variance for them.
+   *
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def var_pop(colNames: String*): DataFrame = {
+    aggregateNumericColumns(colNames : _*)(VariancePop)
+  }
+
+  /**
+   * Compute the sample variance for each numeric column for each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
+   * When specified columns are given, only compute the variance for them.
+   *
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def var_samp(colNames: String*): DataFrame = {
+    aggregateNumericColumns(colNames : _*)(VarianceSamp)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 15c864a8ab..c1737b1ef6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -229,6 +229,22 @@ object functions {
   def first(columnName: String): Column = first(Column(columnName))

   /**
+   * Aggregate function: returns the kurtosis of the values in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def kurtosis(e: Column): Column = Kurtosis(e.expr)
+
+  /**
+   * Aggregate function: returns the kurtosis of the values in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def kurtosis(columnName: String): Column = kurtosis(Column(columnName))
+
+  /**
    * Aggregate function: returns the last value in a group.
    *
    * @group agg_funcs
@@ -295,8 +311,24 @@ object functions {
   def min(columnName: String): Column = min(Column(columnName))

   /**
-   * Aggregate function: returns the unbiased sample standard deviation
-   * of the expression in a group.
+   * Aggregate function: returns the skewness of the values in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def skewness(e: Column): Column = Skewness(e.expr)
+
+  /**
+   * Aggregate function: returns the skewness of the values in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def skewness(columnName: String): Column = skewness(Column(columnName))
+
+  /**
+   * Aggregate function: returns the unbiased sample standard deviation of
+   * the expression in a group.
    *
    * @group agg_funcs
    * @since 1.6.0
@@ -304,13 +336,13 @@ object functions {
   def stddev(e: Column): Column = Stddev(e.expr)

   /**
-   * Aggregate function: returns the population standard deviation of
+   * Aggregate function: returns the unbiased sample standard deviation of
    * the expression in a group.
    *
    * @group agg_funcs
    * @since 1.6.0
    */
-  def stddev_pop(e: Column): Column = StddevPop(e.expr)
+  def stddev(columnName: String): Column = stddev(Column(columnName))

   /**
    * Aggregate function: returns the unbiased sample standard deviation of
@@ -322,6 +354,33 @@ object functions {
   def stddev_samp(e: Column): Column = StddevSamp(e.expr)

   /**
+   * Aggregate function: returns the unbiased sample standard deviation of
+   * the expression in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def stddev_samp(columnName: String): Column = stddev_samp(Column(columnName))
+
+  /**
+   * Aggregate function: returns the population standard deviation of
+   * the expression in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def stddev_pop(e: Column): Column = StddevPop(e.expr)
+
+  /**
+   * Aggregate function: returns the population standard deviation of
+   * the expression in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def stddev_pop(columnName: String): Column = stddev_pop(Column(columnName))
+
+  /**
    * Aggregate function: returns the sum of all values in the expression.
    *
    * @group agg_funcs
@@ -353,6 +412,54 @@ object functions {
    */
   def sumDistinct(columnName: String): Column = sumDistinct(Column(columnName))

+  /**
+   * Aggregate function: returns the population variance of the values in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def variance(e: Column): Column = Variance(e.expr)
+
+  /**
+   * Aggregate function: returns the population variance of the values in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def variance(columnName: String): Column = variance(Column(columnName))
+
+  /**
+   * Aggregate function: returns the unbiased variance of the values in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def var_samp(e: Column): Column = VarianceSamp(e.expr)
+
+  /**
+   * Aggregate function: returns the unbiased variance of the values in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def var_samp(columnName: String): Column = var_samp(Column(columnName))
+
+  /**
+   * Aggregate function: returns the population variance of the values in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def var_pop(e: Column): Column = VariancePop(e.expr)
+
+  /**
+   * Aggregate function: returns the population variance of the values in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def var_pop(columnName: String): Column = var_pop(Column(columnName))
+
   //////////////////////////////////////////////////////////////////////////////////////////////
   // Window functions
   //////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index f5ef9ffd7f..9b23977c76 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -221,4 +221,77 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       emptyTableData.agg(sumDistinct('a)),
       Row(null))
   }
+
+  test("moments") {
+    val absTol = 1e-8
+
+    val sparkVariance = testData2.agg(variance('a))
+    val expectedVariance = Row(4.0 / 6.0)
+    checkAggregatesWithTol(sparkVariance, expectedVariance, absTol)
+    val sparkVariancePop = testData2.agg(var_pop('a))
+    checkAggregatesWithTol(sparkVariancePop, expectedVariance, absTol)
+
+    val sparkVarianceSamp = testData2.agg(var_samp('a))
+    val expectedVarianceSamp = Row(4.0 / 5.0)
+    checkAggregatesWithTol(sparkVarianceSamp, expectedVarianceSamp, absTol)
+
+    val sparkSkewness = testData2.agg(skewness('a))
+    val expectedSkewness = Row(0.0)
+    checkAggregatesWithTol(sparkSkewness, expectedSkewness, absTol)
+
+    val sparkKurtosis = testData2.agg(kurtosis('a))
+    val expectedKurtosis = Row(-1.5)
+    checkAggregatesWithTol(sparkKurtosis, expectedKurtosis, absTol)
+  }
+
+  test("zero moments") {
+    val oneRowData = Seq((1, 2)).toDF("a", "b")
+    assert(oneRowData.count() === 1)
+
+    checkAnswer(
+      oneRowData.agg(variance('a)),
+      Row(0.0))
+
+    checkAnswer(
+      oneRowData.agg(var_samp('a)),
+      Row(Double.NaN))
+
+    checkAnswer(
+      oneRowData.agg(var_pop('a)),
+      Row(0.0))
+
+    checkAnswer(
+      oneRowData.agg(skewness('a)),
+      Row(Double.NaN))
+
+    checkAnswer(
+      oneRowData.agg(kurtosis('a)),
+      Row(Double.NaN))
+  }
+
+  test("null moments") {
+    val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
+    assert(emptyTableData.count() === 0)
+
+    checkAnswer(
+      emptyTableData.agg(variance('a)),
+      Row(Double.NaN))
+
+    checkAnswer(
+      emptyTableData.agg(var_samp('a)),
+      Row(Double.NaN))
+
+    checkAnswer(
+      emptyTableData.agg(var_pop('a)),
+      Row(Double.NaN))
+
+    checkAnswer(
+      emptyTableData.agg(skewness('a)),
+      Row(Double.NaN))
+
+    checkAnswer(
+      emptyTableData.agg(kurtosis('a)),
+      Row(Double.NaN))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 73e02eb0d9..3c174efe73 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -135,6 +135,32 @@ abstract class QueryTest extends PlanTest {
   }

   /**
+   * Runs the plan and makes sure the answer is within absTol of the expected result.
+   * @param dataFrame the [[DataFrame]] to be executed
+   * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+   * @param absTol the absolute tolerance between actual and expected answers.
+   */
+  protected def checkAggregatesWithTol(dataFrame: DataFrame,
+      expectedAnswer: Seq[Row],
+      absTol: Double): Unit = {
+    // TODO: catch exceptions in data frame execution
+    val actualAnswer = dataFrame.collect()
+    require(actualAnswer.length == expectedAnswer.length,
+      s"actual num rows ${actualAnswer.length} != expected num of rows ${expectedAnswer.length}")
+
+    actualAnswer.zip(expectedAnswer).foreach {
+      case (actualRow, expectedRow) =>
+        QueryTest.checkAggregatesWithTol(actualRow, expectedRow, absTol)
+    }
+  }
+
+  protected def checkAggregatesWithTol(dataFrame: DataFrame,
+      expectedAnswer: Row,
+      absTol: Double): Unit = {
+    checkAggregatesWithTol(dataFrame, Seq(expectedAnswer), absTol)
+  }
+
+  /**
    * Asserts that a given [[DataFrame]] will be executed using the given number of cached results.
    */
   def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = {
@@ -214,6 +240,28 @@ object QueryTest {
     return None
   }

+  /**
+   * Checks that the actual answer is within absTol of the expected result.
+   * @param actualAnswer the actual result in a [[Row]].
+   * @param expectedAnswer the expected result in a [[Row]].
+   * @param absTol the absolute tolerance between actual and expected answers.
+   */
+  protected def checkAggregatesWithTol(actualAnswer: Row, expectedAnswer: Row, absTol: Double) = {
+    require(actualAnswer.length == expectedAnswer.length,
+      s"actual answer length ${actualAnswer.length} != " +
+        s"expected answer length ${expectedAnswer.length}")
+
+    // TODO: support other numeric types besides Double
+    // TODO: support struct types?
+    actualAnswer.toSeq.zip(expectedAnswer.toSeq).foreach {
+      case (actual: Double, expected: Double) =>
+        assert(math.abs(actual - expected) < absTol,
+          s"actual answer $actual not within $absTol of correct answer $expected")
+      case (actual, expected) =>
+        assert(actual == expected, s"$actual did not equal $expected")
+    }
+  }
+
   def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): String = {
     checkAnswer(df, expectedAnswer.asScala) match {
       case Some(errorMessage) => errorMessage
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index f5ae3ae49b..5a616fac0b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -523,8 +523,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {

   test("aggregates with nulls") {
     checkAnswer(
-      sql("SELECT MIN(a), MAX(a), AVG(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"),
-      Row(1, 3, 2, 1, 6, 3)
+      sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," +
+        " AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"),
+      Row(0, -1.5, 1, 3, 2, 2.0 / 3.0, 1, 6, 3)
     )
   }

@@ -717,14 +718,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   test("stddev") {
     checkAnswer(
       sql("SELECT STDDEV(a) FROM testData2"),
-      Row(math.sqrt(4/5.0))
+      Row(math.sqrt(4.0 / 5.0))
     )
   }

   test("stddev_pop") {
     checkAnswer(
       sql("SELECT STDDEV_POP(a) FROM testData2"),
-      Row(math.sqrt(4/6.0))
+      Row(math.sqrt(4.0 / 6.0))
     )
   }

@@ -735,10 +736,60 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     )
   }

+  test("var_samp") {
+    val absTol = 1e-8
+    val sparkAnswer = sql("SELECT VAR_SAMP(a) FROM testData2")
+    val expectedAnswer = Row(4.0 / 5.0)
+    checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
+  }
+
+  test("variance") {
+    val absTol = 1e-8
+    val sparkAnswer = sql("SELECT VARIANCE(a) FROM testData2")
+    val expectedAnswer = Row(4.0 / 6.0)
+    checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
+  }
+
+  test("var_pop") {
+    val absTol = 1e-8
+    val sparkAnswer = sql("SELECT VAR_POP(a) FROM testData2")
+    val expectedAnswer = Row(4.0 / 6.0)
+    checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
+  }
+
+  test("skewness") {
+    val absTol = 1e-8
+    val sparkAnswer = sql("SELECT skewness(a) FROM testData2")
+    val expectedAnswer = Row(0.0)
+    checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
+  }
+
+  test("kurtosis") {
+    val absTol = 1e-8
+    val sparkAnswer = sql("SELECT kurtosis(a) FROM testData2")
+    val expectedAnswer = Row(-1.5)
+    checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
+  }
+
   test("stddev agg") {
     checkAnswer(
-      sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"),
-      (1 to 3).map(i => Row(i, math.sqrt(1/2.0), math.sqrt(1/4.0), math.sqrt(1/2.0))))
+      sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"),
+      (1 to 3).map(i => Row(i, math.sqrt(1.0 / 2.0), math.sqrt(1.0 / 4.0), math.sqrt(1.0 / 2.0))))
+  }
+
+  test("variance agg") {
+    val absTol = 1e-8
+    val sparkAnswer = sql("SELECT a, variance(b), var_samp(b), var_pop(b)" +
+      " FROM testData2 GROUP BY a")
+    val expectedAnswer = (1 to 3).map(i => Row(i, 1.0 / 4.0, 1.0 / 2.0, 1.0 / 4.0))
+    checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
+  }
+
+  test("skewness and kurtosis agg") {
+    val absTol = 1e-8
+    val sparkAnswer = sql("SELECT a, skewness(b), kurtosis(b) FROM testData2 GROUP BY a")
+    val expectedAnswer = (1 to 3).map(i => Row(i, 0.0, -2.0))
+    checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
   }

   test("inner join where, one match per row") {
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index eed9e436f9..9e357bf348 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -467,7 +467,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     "escape_orderby1",
     "escape_sortby1",
     "explain_rearrange",
-    "fetch_aggregation",
     "fileformat_mix",
     "fileformat_sequencefile",
     "fileformat_text",
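
For reference, a usage sketch of the aggregates this patch adds. The DataFrame df, the column names key and value, the table t, and sqlContext are placeholders assumed for illustration, not part of the diff:

    import org.apache.spark.sql.functions.{kurtosis, skewness, var_pop, var_samp, variance}

    // DataFrame API, via the new helpers in functions.scala:
    val stats = df.groupBy("key").agg(
      variance("value"),   // population variance in this patch: m2 / n
      var_samp("value"),   // unbiased sample variance: m2 / (n - 1)
      var_pop("value"),    // population variance: m2 / n
      skewness("value"),
      kurtosis("value"))   // excess kurtosis: n * m4 / m2^2 - 3

    // SQL, via the new FunctionRegistry entries:
    val sqlStats = sqlContext.sql(
      "SELECT key, VARIANCE(value), SKEWNESS(value), KURTOSIS(value) FROM t GROUP BY key")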