author    Davies Liu <davies@databricks.com>    2016-02-02 11:50:14 -0800
committer Davies Liu <davies.liu@gmail.com>    2016-02-02 11:50:14 -0800
commit    be5dd881f1eff248224a92d57cfd1309cb3acf38 (patch)
tree      7fdf890c80dc6a7e63028b0829f1020ca0c65a54
parent    7f6e3ec79b77400f558ceffa10b2af011962115f (diff)
[SPARK-12913] [SQL] Improve performance of stat functions
As benchmarked and discussed here: https://github.com/apache/spark/pull/10786/files#r50038294, with the benefit of codegen, a declarative aggregate function can be much faster than an imperative one.

Author: Davies Liu <davies@databricks.com>

Closes #10960 from davies/stddev.
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala              |  18
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala | 285
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala             | 208
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala       | 205
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala         |  54
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala         |  53
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala           |  81
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala         |  81
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala                       |  18
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala                                    |   6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala               |   1
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala                |  55
-rw-r--r--  sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala |   4
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala                |  17
14 files changed, 331 insertions(+), 755 deletions(-)
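
The core of the change: each statistical aggregate moves from ImperativeAggregate (hand-written update/merge against a mutable buffer) to DeclarativeAggregate, whose initialize, update, merge, and evaluate steps are plain Catalyst expressions that whole-stage codegen can compile. A minimal sketch of the pattern, modeled on the API exactly as this patch uses it (SimpleSum is a hypothetical example, not part of the change):

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.types._

// Hypothetical: a null-skipping sum of doubles, written declaratively.
case class SimpleSum(child: Expression) extends DeclarativeAggregate {
  override def children: Seq[Expression] = Seq(child)
  override def nullable: Boolean = true
  override def dataType: DataType = DoubleType
  override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType)

  // One slot in the aggregation buffer.
  private val sum = AttributeReference("sum", DoubleType, nullable = false)()
  override val aggBufferAttributes: Seq[AttributeReference] = Seq(sum)

  override val initialValues: Seq[Expression] = Seq(Literal(0.0))
  // Per-row update: keep the old value for null input, otherwise accumulate.
  override val updateExpressions: Seq[Expression] =
    Seq(If(IsNull(child), sum, sum + child))
  // Partition merge: `.left`/`.right` address the two buffers being combined.
  override val mergeExpressions: Seq[Expression] = Seq(sum.left + sum.right)
  override val evaluateExpression: Expression = sum
}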
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 957ac89fa5..57bdb164e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -347,18 +347,12 @@ object HiveTypeCoercion {
case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
- case StddevPop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
- StddevPop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
- case StddevSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
- StddevSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
- case VariancePop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
- VariancePop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
- case VarianceSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
- VarianceSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
- case Skewness(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
- Skewness(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
- case Kurtosis(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
- Kurtosis(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
+ case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
+ case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType))
+ case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType))
+ case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType))
+ case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType))
+ case Kurtosis(e @ StringType()) => Kurtosis(Cast(e, DoubleType))
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
index 30f602227b..9d2db45144 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
@@ -17,10 +17,8 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
/**
@@ -44,7 +42,7 @@ import org.apache.spark.sql.types._
*
* @param child to compute central moments of.
*/
-abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate with Serializable {
+abstract class CentralMomentAgg(child: Expression) extends DeclarativeAggregate {
/**
* The central moment order to be computed.
@@ -52,178 +50,161 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w
protected def momentOrder: Int
override def children: Seq[Expression] = Seq(child)
-
override def nullable: Boolean = true
-
override def dataType: DataType = DoubleType
+ override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType)
- override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+ protected val n = AttributeReference("n", DoubleType, nullable = false)()
+ protected val avg = AttributeReference("avg", DoubleType, nullable = false)()
+ protected val m2 = AttributeReference("m2", DoubleType, nullable = false)()
+ protected val m3 = AttributeReference("m3", DoubleType, nullable = false)()
+ protected val m4 = AttributeReference("m4", DoubleType, nullable = false)()
- override def checkInputDataTypes(): TypeCheckResult =
- TypeUtils.checkForNumericExpr(child.dataType, s"function $prettyName")
+ private def trimHigherOrder[T](expressions: Seq[T]) = expressions.take(momentOrder + 1)
- override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
+ override val aggBufferAttributes = trimHigherOrder(Seq(n, avg, m2, m3, m4))
- /**
- * Size of aggregation buffer.
- */
- private[this] val bufferSize = 5
+ override val initialValues: Seq[Expression] = Array.fill(momentOrder + 1)(Literal(0.0))
- override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(bufferSize) { i =>
- AttributeReference(s"M$i", DoubleType)()
+ override val updateExpressions: Seq[Expression] = {
+ val newN = n + Literal(1.0)
+ val delta = child - avg
+ val deltaN = delta / newN
+ val newAvg = avg + deltaN
+ val newM2 = m2 + delta * (delta - deltaN)
+
+ val delta2 = delta * delta
+ val deltaN2 = deltaN * deltaN
+ val newM3 = if (momentOrder >= 3) {
+ m3 - Literal(3.0) * deltaN * newM2 + delta * (delta2 - deltaN2)
+ } else {
+ Literal(0.0)
+ }
+ val newM4 = if (momentOrder >= 4) {
+ m4 - Literal(4.0) * deltaN * newM3 - Literal(6.0) * deltaN2 * newM2 +
+ delta * (delta * delta2 - deltaN * deltaN2)
+ } else {
+ Literal(0.0)
+ }
+
+ trimHigherOrder(Seq(
+ If(IsNull(child), n, newN),
+ If(IsNull(child), avg, newAvg),
+ If(IsNull(child), m2, newM2),
+ If(IsNull(child), m3, newM3),
+ If(IsNull(child), m4, newM4)
+ ))
}
- // Note: although this simply copies aggBufferAttributes, this common code can not be placed
- // in the superclass because that will lead to initialization ordering issues.
- override val inputAggBufferAttributes: Seq[AttributeReference] =
- aggBufferAttributes.map(_.newInstance())
-
- // buffer offsets
- private[this] val nOffset = mutableAggBufferOffset
- private[this] val meanOffset = mutableAggBufferOffset + 1
- private[this] val secondMomentOffset = mutableAggBufferOffset + 2
- private[this] val thirdMomentOffset = mutableAggBufferOffset + 3
- private[this] val fourthMomentOffset = mutableAggBufferOffset + 4
-
- // frequently used values for online updates
- private[this] var delta = 0.0
- private[this] var deltaN = 0.0
- private[this] var delta2 = 0.0
- private[this] var deltaN2 = 0.0
- private[this] var n = 0.0
- private[this] var mean = 0.0
- private[this] var m2 = 0.0
- private[this] var m3 = 0.0
- private[this] var m4 = 0.0
+ override val mergeExpressions: Seq[Expression] = {
- /**
- * Initialize all moments to zero.
- */
- override def initialize(buffer: MutableRow): Unit = {
- for (aggIndex <- 0 until bufferSize) {
- buffer.setDouble(mutableAggBufferOffset + aggIndex, 0.0)
+ val n1 = n.left
+ val n2 = n.right
+ val newN = n1 + n2
+ val delta = avg.right - avg.left
+ val deltaN = If(newN === Literal(0.0), Literal(0.0), delta / newN)
+ val newAvg = avg.left + deltaN * n2
+
+ // higher order moments computed according to:
+ // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics
+ val newM2 = m2.left + m2.right + delta * deltaN * n1 * n2
+ // `m3.right` is not available if momentOrder < 3
+ val newM3 = if (momentOrder >= 3) {
+ m3.left + m3.right + deltaN * deltaN * delta * n1 * n2 * (n1 - n2) +
+ Literal(3.0) * deltaN * (n1 * m2.right - n2 * m2.left)
+ } else {
+ Literal(0.0)
}
+ // `m4.right` is not available if momentOrder < 4
+ val newM4 = if (momentOrder >= 4) {
+ m4.left + m4.right +
+ deltaN * deltaN * deltaN * delta * n1 * n2 * (n1 * n1 - n1 * n2 + n2 * n2) +
+ Literal(6.0) * deltaN * deltaN * (n1 * n1 * m2.right + n2 * n2 * m2.left) +
+ Literal(4.0) * deltaN * (n1 * m3.right - n2 * m3.left)
+ } else {
+ Literal(0.0)
+ }
+
+ trimHigherOrder(Seq(newN, newAvg, newM2, newM3, newM4))
}
+}
- /**
- * Update the central moments buffer.
- */
- override def update(buffer: MutableRow, input: InternalRow): Unit = {
- val v = Cast(child, DoubleType).eval(input)
- if (v != null) {
- val updateValue = v match {
- case d: Double => d
- }
-
- n = buffer.getDouble(nOffset)
- mean = buffer.getDouble(meanOffset)
-
- n += 1.0
- buffer.setDouble(nOffset, n)
- delta = updateValue - mean
- deltaN = delta / n
- mean += deltaN
- buffer.setDouble(meanOffset, mean)
-
- if (momentOrder >= 2) {
- m2 = buffer.getDouble(secondMomentOffset)
- m2 += delta * (delta - deltaN)
- buffer.setDouble(secondMomentOffset, m2)
- }
-
- if (momentOrder >= 3) {
- delta2 = delta * delta
- deltaN2 = deltaN * deltaN
- m3 = buffer.getDouble(thirdMomentOffset)
- m3 += -3.0 * deltaN * m2 + delta * (delta2 - deltaN2)
- buffer.setDouble(thirdMomentOffset, m3)
- }
-
- if (momentOrder >= 4) {
- m4 = buffer.getDouble(fourthMomentOffset)
- m4 += -4.0 * deltaN * m3 - 6.0 * deltaN2 * m2 +
- delta * (delta * delta2 - deltaN * deltaN2)
- buffer.setDouble(fourthMomentOffset, m4)
- }
- }
+// Compute the population standard deviation of a column
+case class StddevPop(child: Expression) extends CentralMomentAgg(child) {
+
+ override protected def momentOrder = 2
+
+ override val evaluateExpression: Expression = {
+ If(n === Literal(0.0), Literal.create(null, DoubleType),
+ Sqrt(m2 / n))
}
- /**
- * Merge two central moment buffers.
- */
- override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
- val n1 = buffer1.getDouble(nOffset)
- val n2 = buffer2.getDouble(inputAggBufferOffset)
- val mean1 = buffer1.getDouble(meanOffset)
- val mean2 = buffer2.getDouble(inputAggBufferOffset + 1)
+ override def prettyName: String = "stddev_pop"
+}
+
+// Compute the sample standard deviation of a column
+case class StddevSamp(child: Expression) extends CentralMomentAgg(child) {
+
+ override protected def momentOrder = 2
- var secondMoment1 = 0.0
- var secondMoment2 = 0.0
+ override val evaluateExpression: Expression = {
+ If(n === Literal(0.0), Literal.create(null, DoubleType),
+ If(n === Literal(1.0), Literal(Double.NaN),
+ Sqrt(m2 / (n - Literal(1.0)))))
+ }
- var thirdMoment1 = 0.0
- var thirdMoment2 = 0.0
+ override def prettyName: String = "stddev_samp"
+}
- var fourthMoment1 = 0.0
- var fourthMoment2 = 0.0
+// Compute the population variance of a column
+case class VariancePop(child: Expression) extends CentralMomentAgg(child) {
- n = n1 + n2
- buffer1.setDouble(nOffset, n)
- delta = mean2 - mean1
- deltaN = if (n == 0.0) 0.0 else delta / n
- mean = mean1 + deltaN * n2
- buffer1.setDouble(mutableAggBufferOffset + 1, mean)
+ override protected def momentOrder = 2
- // higher order moments computed according to:
- // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics
- if (momentOrder >= 2) {
- secondMoment1 = buffer1.getDouble(secondMomentOffset)
- secondMoment2 = buffer2.getDouble(inputAggBufferOffset + 2)
- m2 = secondMoment1 + secondMoment2 + delta * deltaN * n1 * n2
- buffer1.setDouble(secondMomentOffset, m2)
- }
+ override val evaluateExpression: Expression = {
+ If(n === Literal(0.0), Literal.create(null, DoubleType),
+ m2 / n)
+ }
- if (momentOrder >= 3) {
- thirdMoment1 = buffer1.getDouble(thirdMomentOffset)
- thirdMoment2 = buffer2.getDouble(inputAggBufferOffset + 3)
- m3 = thirdMoment1 + thirdMoment2 + deltaN * deltaN * delta * n1 * n2 *
- (n1 - n2) + 3.0 * deltaN * (n1 * secondMoment2 - n2 * secondMoment1)
- buffer1.setDouble(thirdMomentOffset, m3)
- }
+ override def prettyName: String = "var_pop"
+}
- if (momentOrder >= 4) {
- fourthMoment1 = buffer1.getDouble(fourthMomentOffset)
- fourthMoment2 = buffer2.getDouble(inputAggBufferOffset + 4)
- m4 = fourthMoment1 + fourthMoment2 + deltaN * deltaN * deltaN * delta * n1 *
- n2 * (n1 * n1 - n1 * n2 + n2 * n2) + deltaN * deltaN * 6.0 *
- (n1 * n1 * secondMoment2 + n2 * n2 * secondMoment1) +
- 4.0 * deltaN * (n1 * thirdMoment2 - n2 * thirdMoment1)
- buffer1.setDouble(fourthMomentOffset, m4)
- }
+// Compute the sample variance of a column
+case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) {
+
+ override protected def momentOrder = 2
+
+ override val evaluateExpression: Expression = {
+ If(n === Literal(0.0), Literal.create(null, DoubleType),
+ If(n === Literal(1.0), Literal(Double.NaN),
+ m2 / (n - Literal(1.0))))
}
- /**
- * Compute aggregate statistic from sufficient moments.
- * @param centralMoments Length `momentOrder + 1` array of central moments (un-normalized)
- * needed to compute the aggregate stat.
- */
- def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Any
-
- override final def eval(buffer: InternalRow): Any = {
- val n = buffer.getDouble(nOffset)
- val mean = buffer.getDouble(meanOffset)
- val moments = Array.ofDim[Double](momentOrder + 1)
- moments(0) = 1.0
- moments(1) = 0.0
- if (momentOrder >= 2) {
- moments(2) = buffer.getDouble(secondMomentOffset)
- }
- if (momentOrder >= 3) {
- moments(3) = buffer.getDouble(thirdMomentOffset)
- }
- if (momentOrder >= 4) {
- moments(4) = buffer.getDouble(fourthMomentOffset)
- }
+ override def prettyName: String = "var_samp"
+}
+
+case class Skewness(child: Expression) extends CentralMomentAgg(child) {
+
+ override def prettyName: String = "skewness"
+
+ override protected def momentOrder = 3
- getStatistic(n, mean, moments)
+ override val evaluateExpression: Expression = {
+ If(n === Literal(0.0), Literal.create(null, DoubleType),
+ If(m2 === Literal(0.0), Literal(Double.NaN),
+ Sqrt(n) * m3 / Sqrt(m2 * m2 * m2)))
}
}
+
+case class Kurtosis(child: Expression) extends CentralMomentAgg(child) {
+
+ override protected def momentOrder = 4
+
+ override val evaluateExpression: Expression = {
+ If(n === Literal(0.0), Literal.create(null, DoubleType),
+ If(m2 === Literal(0.0), Literal(Double.NaN),
+ n * m4 / (m2 * m2) - Literal(3.0)))
+ }
+
+ override def prettyName: String = "kurtosis"
+}
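
The update and merge expressions above encode the standard one-pass (Welford) recurrence and the pairwise-merge formulas from the Wikipedia article cited in the comments. A plain-Scala sketch of the same arithmetic for the first two moments, handy for checking the algebra outside Catalyst (names are illustrative):

// Aggregation state: count, running mean, second central moment.
case class Moments(n: Double, avg: Double, m2: Double)

// Single-row update; mirrors newN/newAvg/newM2 in updateExpressions.
def update(s: Moments, x: Double): Moments = {
  val n = s.n + 1.0
  val delta = x - s.avg
  val deltaN = delta / n
  Moments(n, s.avg + deltaN, s.m2 + delta * (delta - deltaN))
}

// Pairwise merge of two partitions; mirrors mergeExpressions.
def merge(a: Moments, b: Moments): Moments = {
  val n = a.n + b.n
  val delta = b.avg - a.avg
  val deltaN = if (n == 0.0) 0.0 else delta / n
  Moments(n, a.avg + deltaN * b.n, a.m2 + b.m2 + delta * deltaN * a.n * b.n)
}

// Evaluation, e.g. population variance as in VariancePop: NULL when n == 0.
def varPop(s: Moments): Option[Double] =
  if (s.n == 0.0) None else Some(s.m2 / s.n)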
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
index d25f3335ff..e6b8214ef2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
@@ -17,8 +17,7 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
@@ -29,165 +28,70 @@ import org.apache.spark.sql.types._
* Definition of Pearson correlation can be found at
* http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient
*/
-case class Corr(
- left: Expression,
- right: Expression,
- mutableAggBufferOffset: Int = 0,
- inputAggBufferOffset: Int = 0)
- extends ImperativeAggregate {
-
- def this(left: Expression, right: Expression) =
- this(left, right, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
-
- override def children: Seq[Expression] = Seq(left, right)
+case class Corr(x: Expression, y: Expression) extends DeclarativeAggregate {
+ override def children: Seq[Expression] = Seq(x, y)
override def nullable: Boolean = true
-
override def dataType: DataType = DoubleType
-
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
- override def checkInputDataTypes(): TypeCheckResult = {
- if (left.dataType.isInstanceOf[DoubleType] && right.dataType.isInstanceOf[DoubleType]) {
- TypeCheckResult.TypeCheckSuccess
- } else {
- TypeCheckResult.TypeCheckFailure(
- s"corr requires that both arguments are double type, " +
- s"not (${left.dataType}, ${right.dataType}).")
- }
+ protected val n = AttributeReference("n", DoubleType, nullable = false)()
+ protected val xAvg = AttributeReference("xAvg", DoubleType, nullable = false)()
+ protected val yAvg = AttributeReference("yAvg", DoubleType, nullable = false)()
+ protected val ck = AttributeReference("ck", DoubleType, nullable = false)()
+ protected val xMk = AttributeReference("xMk", DoubleType, nullable = false)()
+ protected val yMk = AttributeReference("yMk", DoubleType, nullable = false)()
+
+ override val aggBufferAttributes: Seq[AttributeReference] = Seq(n, xAvg, yAvg, ck, xMk, yMk)
+
+ override val initialValues: Seq[Expression] = Array.fill(6)(Literal(0.0))
+
+ override val updateExpressions: Seq[Expression] = {
+ val newN = n + Literal(1.0)
+ val dx = x - xAvg
+ val dxN = dx / newN
+ val dy = y - yAvg
+ val dyN = dy / newN
+ val newXAvg = xAvg + dxN
+ val newYAvg = yAvg + dyN
+ val newCk = ck + dx * (y - newYAvg)
+ val newXMk = xMk + dx * (x - newXAvg)
+ val newYMk = yMk + dy * (y - newYAvg)
+
+ val isNull = IsNull(x) || IsNull(y)
+ Seq(
+ If(isNull, n, newN),
+ If(isNull, xAvg, newXAvg),
+ If(isNull, yAvg, newYAvg),
+ If(isNull, ck, newCk),
+ If(isNull, xMk, newXMk),
+ If(isNull, yMk, newYMk)
+ )
}
- override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
-
- override def inputAggBufferAttributes: Seq[AttributeReference] = {
- aggBufferAttributes.map(_.newInstance())
+ override val mergeExpressions: Seq[Expression] = {
+
+ val n1 = n.left
+ val n2 = n.right
+ val newN = n1 + n2
+ val dx = xAvg.right - xAvg.left
+ val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN)
+ val dy = yAvg.right - yAvg.left
+ val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN)
+ val newXAvg = xAvg.left + dxN * n2
+ val newYAvg = yAvg.left + dyN * n2
+ val newCk = ck.left + ck.right + dx * dyN * n1 * n2
+ val newXMk = xMk.left + xMk.right + dx * dxN * n1 * n2
+ val newYMk = yMk.left + yMk.right + dy * dyN * n1 * n2
+
+ Seq(newN, newXAvg, newYAvg, newCk, newXMk, newYMk)
}
- override val aggBufferAttributes: Seq[AttributeReference] = Seq(
- AttributeReference("xAvg", DoubleType)(),
- AttributeReference("yAvg", DoubleType)(),
- AttributeReference("Ck", DoubleType)(),
- AttributeReference("MkX", DoubleType)(),
- AttributeReference("MkY", DoubleType)(),
- AttributeReference("count", LongType)())
-
- // Local cache of mutableAggBufferOffset(s) that will be used in update and merge
- private[this] val mutableAggBufferOffsetPlus1 = mutableAggBufferOffset + 1
- private[this] val mutableAggBufferOffsetPlus2 = mutableAggBufferOffset + 2
- private[this] val mutableAggBufferOffsetPlus3 = mutableAggBufferOffset + 3
- private[this] val mutableAggBufferOffsetPlus4 = mutableAggBufferOffset + 4
- private[this] val mutableAggBufferOffsetPlus5 = mutableAggBufferOffset + 5
-
- // Local cache of inputAggBufferOffset(s) that will be used in update and merge
- private[this] val inputAggBufferOffsetPlus1 = inputAggBufferOffset + 1
- private[this] val inputAggBufferOffsetPlus2 = inputAggBufferOffset + 2
- private[this] val inputAggBufferOffsetPlus3 = inputAggBufferOffset + 3
- private[this] val inputAggBufferOffsetPlus4 = inputAggBufferOffset + 4
- private[this] val inputAggBufferOffsetPlus5 = inputAggBufferOffset + 5
-
- override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
- copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
- override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
- copy(inputAggBufferOffset = newInputAggBufferOffset)
-
- override def initialize(buffer: MutableRow): Unit = {
- buffer.setDouble(mutableAggBufferOffset, 0.0)
- buffer.setDouble(mutableAggBufferOffsetPlus1, 0.0)
- buffer.setDouble(mutableAggBufferOffsetPlus2, 0.0)
- buffer.setDouble(mutableAggBufferOffsetPlus3, 0.0)
- buffer.setDouble(mutableAggBufferOffsetPlus4, 0.0)
- buffer.setLong(mutableAggBufferOffsetPlus5, 0L)
+ override val evaluateExpression: Expression = {
+ If(n === Literal(0.0), Literal.create(null, DoubleType),
+ If(n === Literal(1.0), Literal(Double.NaN),
+ ck / Sqrt(xMk * yMk)))
}
- override def update(buffer: MutableRow, input: InternalRow): Unit = {
- val leftEval = left.eval(input)
- val rightEval = right.eval(input)
-
- if (leftEval != null && rightEval != null) {
- val x = leftEval.asInstanceOf[Double]
- val y = rightEval.asInstanceOf[Double]
-
- var xAvg = buffer.getDouble(mutableAggBufferOffset)
- var yAvg = buffer.getDouble(mutableAggBufferOffsetPlus1)
- var Ck = buffer.getDouble(mutableAggBufferOffsetPlus2)
- var MkX = buffer.getDouble(mutableAggBufferOffsetPlus3)
- var MkY = buffer.getDouble(mutableAggBufferOffsetPlus4)
- var count = buffer.getLong(mutableAggBufferOffsetPlus5)
-
- val deltaX = x - xAvg
- val deltaY = y - yAvg
- count += 1
- xAvg += deltaX / count
- yAvg += deltaY / count
- Ck += deltaX * (y - yAvg)
- MkX += deltaX * (x - xAvg)
- MkY += deltaY * (y - yAvg)
-
- buffer.setDouble(mutableAggBufferOffset, xAvg)
- buffer.setDouble(mutableAggBufferOffsetPlus1, yAvg)
- buffer.setDouble(mutableAggBufferOffsetPlus2, Ck)
- buffer.setDouble(mutableAggBufferOffsetPlus3, MkX)
- buffer.setDouble(mutableAggBufferOffsetPlus4, MkY)
- buffer.setLong(mutableAggBufferOffsetPlus5, count)
- }
- }
-
- // Merge counters from other partitions. Formula can be found at:
- // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
- override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
- val count2 = buffer2.getLong(inputAggBufferOffsetPlus5)
-
- // We only go to merge two buffers if there is at least one record aggregated in buffer2.
- // We don't need to check count in buffer1 because if count2 is more than zero, totalCount
- // is more than zero too, then we won't get a divide by zero exception.
- if (count2 > 0) {
- var xAvg = buffer1.getDouble(mutableAggBufferOffset)
- var yAvg = buffer1.getDouble(mutableAggBufferOffsetPlus1)
- var Ck = buffer1.getDouble(mutableAggBufferOffsetPlus2)
- var MkX = buffer1.getDouble(mutableAggBufferOffsetPlus3)
- var MkY = buffer1.getDouble(mutableAggBufferOffsetPlus4)
- var count = buffer1.getLong(mutableAggBufferOffsetPlus5)
-
- val xAvg2 = buffer2.getDouble(inputAggBufferOffset)
- val yAvg2 = buffer2.getDouble(inputAggBufferOffsetPlus1)
- val Ck2 = buffer2.getDouble(inputAggBufferOffsetPlus2)
- val MkX2 = buffer2.getDouble(inputAggBufferOffsetPlus3)
- val MkY2 = buffer2.getDouble(inputAggBufferOffsetPlus4)
-
- val totalCount = count + count2
- val deltaX = xAvg - xAvg2
- val deltaY = yAvg - yAvg2
- Ck += Ck2 + deltaX * deltaY * count / totalCount * count2
- xAvg = (xAvg * count + xAvg2 * count2) / totalCount
- yAvg = (yAvg * count + yAvg2 * count2) / totalCount
- MkX += MkX2 + deltaX * deltaX * count / totalCount * count2
- MkY += MkY2 + deltaY * deltaY * count / totalCount * count2
- count = totalCount
-
- buffer1.setDouble(mutableAggBufferOffset, xAvg)
- buffer1.setDouble(mutableAggBufferOffsetPlus1, yAvg)
- buffer1.setDouble(mutableAggBufferOffsetPlus2, Ck)
- buffer1.setDouble(mutableAggBufferOffsetPlus3, MkX)
- buffer1.setDouble(mutableAggBufferOffsetPlus4, MkY)
- buffer1.setLong(mutableAggBufferOffsetPlus5, count)
- }
- }
-
- override def eval(buffer: InternalRow): Any = {
- val count = buffer.getLong(mutableAggBufferOffsetPlus5)
- if (count > 0) {
- val Ck = buffer.getDouble(mutableAggBufferOffsetPlus2)
- val MkX = buffer.getDouble(mutableAggBufferOffsetPlus3)
- val MkY = buffer.getDouble(mutableAggBufferOffsetPlus4)
- val corr = Ck / math.sqrt(MkX * MkY)
- if (corr.isNaN) {
- null
- } else {
- corr
- }
- } else {
- null
- }
- }
+ override def prettyName: String = "corr"
}
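
Corr tracks six running values: the count, the two means, the co-moment ck, and the second moments xMk and yMk; Pearson's r is then ck / sqrt(xMk * yMk). A plain-Scala sketch of the recurrence the update and evaluate expressions implement (names are illustrative):

case class CorrState(n: Double, xAvg: Double, yAvg: Double,
                     ck: Double, xMk: Double, yMk: Double)

def update(s: CorrState, x: Double, y: Double): CorrState = {
  val n = s.n + 1.0
  val dx = x - s.xAvg
  val dy = y - s.yAvg
  val xAvg = s.xAvg + dx / n
  val yAvg = s.yAvg + dy / n
  CorrState(n, xAvg, yAvg,
    s.ck + dx * (y - yAvg),   // old x-delta times new y-deviation
    s.xMk + dx * (x - xAvg),
    s.yMk + dy * (y - yAvg))
}

// Evaluation: NULL for no rows, NaN for a single row, otherwise Pearson's r.
def corr(s: CorrState): Option[Double] =
  if (s.n == 0.0) None
  else Some(if (s.n == 1.0) Double.NaN else s.ck / math.sqrt(s.xMk * s.yMk))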
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
index f53b01be2a..c175a8c4c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
@@ -17,182 +17,79 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
/**
* Compute the covariance between two expressions.
* When applied on empty data (i.e., count is zero), it returns NULL.
- *
*/
-abstract class Covariance(left: Expression, right: Expression) extends ImperativeAggregate
- with Serializable {
- override def children: Seq[Expression] = Seq(left, right)
+abstract class Covariance(x: Expression, y: Expression) extends DeclarativeAggregate {
+ override def children: Seq[Expression] = Seq(x, y)
override def nullable: Boolean = true
-
override def dataType: DataType = DoubleType
-
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
- override def checkInputDataTypes(): TypeCheckResult = {
- if (left.dataType.isInstanceOf[DoubleType] && right.dataType.isInstanceOf[DoubleType]) {
- TypeCheckResult.TypeCheckSuccess
- } else {
- TypeCheckResult.TypeCheckFailure(
- s"covariance requires that both arguments are double type, " +
- s"not (${left.dataType}, ${right.dataType}).")
- }
- }
-
- override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
-
- override def inputAggBufferAttributes: Seq[AttributeReference] = {
- aggBufferAttributes.map(_.newInstance())
- }
-
- override val aggBufferAttributes: Seq[AttributeReference] = Seq(
- AttributeReference("xAvg", DoubleType)(),
- AttributeReference("yAvg", DoubleType)(),
- AttributeReference("Ck", DoubleType)(),
- AttributeReference("count", LongType)())
-
- // Local cache of mutableAggBufferOffset(s) that will be used in update and merge
- val xAvgOffset = mutableAggBufferOffset
- val yAvgOffset = mutableAggBufferOffset + 1
- val CkOffset = mutableAggBufferOffset + 2
- val countOffset = mutableAggBufferOffset + 3
-
- // Local cache of inputAggBufferOffset(s) that will be used in update and merge
- val inputXAvgOffset = inputAggBufferOffset
- val inputYAvgOffset = inputAggBufferOffset + 1
- val inputCkOffset = inputAggBufferOffset + 2
- val inputCountOffset = inputAggBufferOffset + 3
-
- override def initialize(buffer: MutableRow): Unit = {
- buffer.setDouble(xAvgOffset, 0.0)
- buffer.setDouble(yAvgOffset, 0.0)
- buffer.setDouble(CkOffset, 0.0)
- buffer.setLong(countOffset, 0L)
- }
-
- override def update(buffer: MutableRow, input: InternalRow): Unit = {
- val leftEval = left.eval(input)
- val rightEval = right.eval(input)
-
- if (leftEval != null && rightEval != null) {
- val x = leftEval.asInstanceOf[Double]
- val y = rightEval.asInstanceOf[Double]
-
- var xAvg = buffer.getDouble(xAvgOffset)
- var yAvg = buffer.getDouble(yAvgOffset)
- var Ck = buffer.getDouble(CkOffset)
- var count = buffer.getLong(countOffset)
-
- val deltaX = x - xAvg
- val deltaY = y - yAvg
- count += 1
- xAvg += deltaX / count
- yAvg += deltaY / count
- Ck += deltaX * (y - yAvg)
-
- buffer.setDouble(xAvgOffset, xAvg)
- buffer.setDouble(yAvgOffset, yAvg)
- buffer.setDouble(CkOffset, Ck)
- buffer.setLong(countOffset, count)
- }
+ protected val n = AttributeReference("n", DoubleType, nullable = false)()
+ protected val xAvg = AttributeReference("xAvg", DoubleType, nullable = false)()
+ protected val yAvg = AttributeReference("yAvg", DoubleType, nullable = false)()
+ protected val ck = AttributeReference("ck", DoubleType, nullable = false)()
+
+ override val aggBufferAttributes: Seq[AttributeReference] = Seq(n, xAvg, yAvg, ck)
+
+ override val initialValues: Seq[Expression] = Array.fill(4)(Literal(0.0))
+
+ override lazy val updateExpressions: Seq[Expression] = {
+ val newN = n + Literal(1.0)
+ val dx = x - xAvg
+ val dy = y - yAvg
+ val dyN = dy / newN
+ val newXAvg = xAvg + dx / newN
+ val newYAvg = yAvg + dyN
+ val newCk = ck + dx * (y - newYAvg)
+
+ val isNull = IsNull(x) || IsNull(y)
+ Seq(
+ If(isNull, n, newN),
+ If(isNull, xAvg, newXAvg),
+ If(isNull, yAvg, newYAvg),
+ If(isNull, ck, newCk)
+ )
}
- // Merge counters from other partitions. Formula can be found at:
- // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
- override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
- val count2 = buffer2.getLong(inputCountOffset)
-
- // We only go to merge two buffers if there is at least one record aggregated in buffer2.
- // We don't need to check count in buffer1 because if count2 is more than zero, totalCount
- // is more than zero too, then we won't get a divide by zero exception.
- if (count2 > 0) {
- var xAvg = buffer1.getDouble(xAvgOffset)
- var yAvg = buffer1.getDouble(yAvgOffset)
- var Ck = buffer1.getDouble(CkOffset)
- var count = buffer1.getLong(countOffset)
+ override val mergeExpressions: Seq[Expression] = {
- val xAvg2 = buffer2.getDouble(inputXAvgOffset)
- val yAvg2 = buffer2.getDouble(inputYAvgOffset)
- val Ck2 = buffer2.getDouble(inputCkOffset)
+ val n1 = n.left
+ val n2 = n.right
+ val newN = n1 + n2
+ val dx = xAvg.right - xAvg.left
+ val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN)
+ val dy = yAvg.right - yAvg.left
+ val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN)
+ val newXAvg = xAvg.left + dxN * n2
+ val newYAvg = yAvg.left + dyN * n2
+ val newCk = ck.left + ck.right + dx * dyN * n1 * n2
- val totalCount = count + count2
- val deltaX = xAvg - xAvg2
- val deltaY = yAvg - yAvg2
- Ck += Ck2 + deltaX * deltaY * count / totalCount * count2
- xAvg = (xAvg * count + xAvg2 * count2) / totalCount
- yAvg = (yAvg * count + yAvg2 * count2) / totalCount
- count = totalCount
-
- buffer1.setDouble(xAvgOffset, xAvg)
- buffer1.setDouble(yAvgOffset, yAvg)
- buffer1.setDouble(CkOffset, Ck)
- buffer1.setLong(countOffset, count)
- }
+ Seq(newN, newXAvg, newYAvg, newCk)
}
}
-case class CovSample(
- left: Expression,
- right: Expression,
- mutableAggBufferOffset: Int = 0,
- inputAggBufferOffset: Int = 0)
- extends Covariance(left, right) {
-
- override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
- copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
- override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
- copy(inputAggBufferOffset = newInputAggBufferOffset)
-
- override def eval(buffer: InternalRow): Any = {
- val count = buffer.getLong(countOffset)
- if (count > 1) {
- val Ck = buffer.getDouble(CkOffset)
- val cov = Ck / (count - 1)
- if (cov.isNaN) {
- null
- } else {
- cov
- }
- } else {
- null
- }
+case class CovPopulation(left: Expression, right: Expression) extends Covariance(left, right) {
+ override val evaluateExpression: Expression = {
+ If(n === Literal(0.0), Literal.create(null, DoubleType),
+ ck / n)
}
+ override def prettyName: String = "covar_pop"
}
-case class CovPopulation(
- left: Expression,
- right: Expression,
- mutableAggBufferOffset: Int = 0,
- inputAggBufferOffset: Int = 0)
- extends Covariance(left, right) {
-
- override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
- copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
- override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
- copy(inputAggBufferOffset = newInputAggBufferOffset)
- override def eval(buffer: InternalRow): Any = {
- val count = buffer.getLong(countOffset)
- if (count > 0) {
- val Ck = buffer.getDouble(CkOffset)
- if (Ck.isNaN) {
- null
- } else {
- Ck / count
- }
- } else {
- null
- }
+case class CovSample(left: Expression, right: Expression) extends Covariance(left, right) {
+ override val evaluateExpression: Expression = {
+ If(n === Literal(0.0), Literal.create(null, DoubleType),
+ If(n === Literal(1.0), Literal(Double.NaN),
+ ck / (n - Literal(1.0))))
}
+ override def prettyName: String = "covar_samp"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
deleted file mode 100644
index c2bf2cb941..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions.aggregate
-
-import org.apache.spark.sql.catalyst.expressions._
-
-case class Kurtosis(child: Expression,
- mutableAggBufferOffset: Int = 0,
- inputAggBufferOffset: Int = 0)
- extends CentralMomentAgg(child) {
-
- def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
-
- override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
- copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
- override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
- copy(inputAggBufferOffset = newInputAggBufferOffset)
-
- override def prettyName: String = "kurtosis"
-
- override protected val momentOrder = 4
-
- // NOTE: this is the formula for excess kurtosis, which is default for R and SciPy
- override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
- require(moments.length == momentOrder + 1,
- s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
- val m2 = moments(2)
- val m4 = moments(4)
-
- if (n == 0.0) {
- null
- } else if (m2 == 0.0) {
- Double.NaN
- } else {
- n * m4 / (m2 * m2) - 3.0
- }
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
deleted file mode 100644
index 9411bcea25..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions.aggregate
-
-import org.apache.spark.sql.catalyst.expressions._
-
-case class Skewness(child: Expression,
- mutableAggBufferOffset: Int = 0,
- inputAggBufferOffset: Int = 0)
- extends CentralMomentAgg(child) {
-
- def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
-
- override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
- copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
- override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
- copy(inputAggBufferOffset = newInputAggBufferOffset)
-
- override def prettyName: String = "skewness"
-
- override protected val momentOrder = 3
-
- override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
- require(moments.length == momentOrder + 1,
- s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
- val m2 = moments(2)
- val m3 = moments(3)
-
- if (n == 0.0) {
- null
- } else if (m2 == 0.0) {
- Double.NaN
- } else {
- math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2)
- }
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
deleted file mode 100644
index eec79a9033..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
+++ /dev/null
@@ -1,81 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions.aggregate
-
-import org.apache.spark.sql.catalyst.expressions._
-
-case class StddevSamp(child: Expression,
- mutableAggBufferOffset: Int = 0,
- inputAggBufferOffset: Int = 0)
- extends CentralMomentAgg(child) {
-
- def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
-
- override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
- copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
- override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
- copy(inputAggBufferOffset = newInputAggBufferOffset)
-
- override def prettyName: String = "stddev_samp"
-
- override protected val momentOrder = 2
-
- override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
- require(moments.length == momentOrder + 1,
- s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")
-
- if (n == 0.0) {
- null
- } else if (n == 1.0) {
- Double.NaN
- } else {
- math.sqrt(moments(2) / (n - 1.0))
- }
- }
-}
-
-case class StddevPop(
- child: Expression,
- mutableAggBufferOffset: Int = 0,
- inputAggBufferOffset: Int = 0)
- extends CentralMomentAgg(child) {
-
- def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
-
- override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
- copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
- override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
- copy(inputAggBufferOffset = newInputAggBufferOffset)
-
- override def prettyName: String = "stddev_pop"
-
- override protected val momentOrder = 2
-
- override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
- require(moments.length == momentOrder + 1,
- s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")
-
- if (n == 0.0) {
- null
- } else {
- math.sqrt(moments(2) / n)
- }
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
deleted file mode 100644
index cf3a740305..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
+++ /dev/null
@@ -1,81 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions.aggregate
-
-import org.apache.spark.sql.catalyst.expressions._
-
-case class VarianceSamp(child: Expression,
- mutableAggBufferOffset: Int = 0,
- inputAggBufferOffset: Int = 0)
- extends CentralMomentAgg(child) {
-
- def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
-
- override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
- copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
- override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
- copy(inputAggBufferOffset = newInputAggBufferOffset)
-
- override def prettyName: String = "var_samp"
-
- override protected val momentOrder = 2
-
- override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
- require(moments.length == momentOrder + 1,
- s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")
-
- if (n == 0.0) {
- null
- } else if (n == 1.0) {
- Double.NaN
- } else {
- moments(2) / (n - 1.0)
- }
- }
-}
-
-case class VariancePop(
- child: Expression,
- mutableAggBufferOffset: Int = 0,
- inputAggBufferOffset: Int = 0)
- extends CentralMomentAgg(child) {
-
- def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
-
- override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
- copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
- override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
- copy(inputAggBufferOffset = newInputAggBufferOffset)
-
- override def prettyName: String = "var_pop"
-
- override protected val momentOrder = 2
-
- override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
- require(moments.length == momentOrder + 1,
- s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")
-
- if (n == 0.0) {
- null
- } else {
- moments(2) / n
- }
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 36e1fa1176..f4ccadd9c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -424,3 +424,21 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
}
}
}
+
+/**
+ * Print the result of an expression to stderr (used for debugging codegen).
+ */
+case class PrintToStderr(child: Expression) extends UnaryExpression {
+
+ override def dataType: DataType = child.dataType
+
+ protected override def nullSafeEval(input: Any): Any = input
+
+ override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+ nullSafeCodeGen(ctx, ev, c =>
+ s"""
+ | System.err.println("Result of ${child.simpleString} is " + $c);
+ | ${ev.value} = $c;
+ """.stripMargin)
+ }
+}
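
PrintToStderr is a pass-through expression: interpreted evaluation simply returns its input, while the generated Java additionally prints the value to stderr before assigning it, which is useful when debugging codegen. A usage sketch (the assertion-style scaffolding is assumed, not from the patch):

import org.apache.spark.sql.catalyst.expressions.{Literal, PrintToStderr}

val debugged = PrintToStderr(Literal(1.5))
// Interpreted path: the value flows through unchanged, nothing is printed.
assert(debugged.eval() == 1.5)
// Codegen path: the generated code also writes a line like
// "Result of 1.5 is 1.5" to stderr.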
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index 26a7340f1a..84154a47de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -198,7 +198,8 @@ case class Window(
functions,
ordinal,
child.output,
- (expressions, schema) => newMutableProjection(expressions, schema))
+ (expressions, schema) =>
+ newMutableProjection(expressions, schema, subexpressionEliminationEnabled))
// Create the factory
val factory = key match {
@@ -210,7 +211,8 @@ case class Window(
ordinal,
functions,
child.output,
- (expressions, schema) => newMutableProjection(expressions, schema),
+ (expressions, schema) =>
+ newMutableProjection(expressions, schema, subexpressionEliminationEnabled),
offset)
// Growing Frame.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 57db7262fd..a8a81d6d65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -240,7 +240,6 @@ case class TungstenAggregate(
| ${bufVars(i).value} = ${ev.value};
""".stripMargin
}
-
s"""
| // do aggregate
| ${aggVals.map(_.code).mkString("\n")}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 2f09c8a114..1ccf0e3d06 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -59,6 +59,55 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
benchmark.run()
}
+ def testStatFunctions(values: Int): Unit = {
+
+ val benchmark = new Benchmark("stat functions", values)
+
+ benchmark.addCase("stddev w/o codegen") { iter =>
+ sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
+ sqlContext.range(values).groupBy().agg("id" -> "stddev").collect()
+ }
+
+ benchmark.addCase("stddev w codegen") { iter =>
+ sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
+ sqlContext.range(values).groupBy().agg("id" -> "stddev").collect()
+ }
+
+ benchmark.addCase("kurtosis w/o codegen") { iter =>
+ sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
+ sqlContext.range(values).groupBy().agg("id" -> "kurtosis").collect()
+ }
+
+ benchmark.addCase("kurtosis w codegen") { iter =>
+ sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
+ sqlContext.range(values).groupBy().agg("id" -> "kurtosis").collect()
+ }
+
+
+ /**
+ Using ImperativeAggregate (as implemented in Spark 1.6):
+
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ stddev: Avg Time(ms) Avg Rate(M/s) Relative Rate
+ -------------------------------------------------------------------------------
+ stddev w/o codegen 2019.04 10.39 1.00 X
+ stddev w codegen 2097.29 10.00 0.96 X
+ kurtosis w/o codegen 2108.99 9.94 0.96 X
+ kurtosis w codegen 2090.69 10.03 0.97 X
+
+ Using DeclarativeAggregate:
+
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ stddev: Avg Time(ms) Avg Rate(M/s) Relative Rate
+ -------------------------------------------------------------------------------
+ stddev w/o codegen 989.22 21.20 1.00 X
+ stddev w codegen 352.35 59.52 2.81 X
+ kurtosis w/o codegen 3636.91 5.77 0.27 X
+ kurtosis w codegen 369.25 56.79 2.68 X
+ */
+ benchmark.run()
+ }
+
def testAggregateWithKey(values: Int): Unit = {
val benchmark = new Benchmark("Aggregate with keys", values)
@@ -147,8 +196,10 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
benchmark.run()
}
- test("benchmark") {
- // testWholeStage(1024 * 1024 * 200)
+ // These benchmarks are skipped in normal builds
+ ignore("benchmark") {
+ // testWholeStage(200 << 20)
+ // testStatFunctions(20 << 20)
// testAggregateWithKey(20 << 20)
// testBytesToBytesMap(1024 * 1024 * 50)
}
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 554d47d651..61b73fa557 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -325,6 +325,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"drop_partitions_ignore_protection",
"protectmode",
+ // Hive returns null rather than NaN when n = 1
+ "udaf_covar_samp",
+
// Spark parser treats numerical literals differently: it creates decimals instead of doubles.
"udf_abs",
"udf_format_number",
@@ -881,7 +884,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"type_widening",
"udaf_collect_set",
"udaf_covar_pop",
- "udaf_covar_samp",
"udaf_histogram_numeric",
"udf2",
"udf5",
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 7a9ed1eaf3..caf1db9ad0 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -798,7 +798,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
"""
|SELECT corr(b, c) FROM covar_tab WHERE a = 3
""".stripMargin),
- Row(null) :: Nil)
+ Row(Double.NaN) :: Nil)
checkAnswer(
sqlContext.sql(
@@ -807,10 +807,10 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
""".stripMargin),
Row(1, null) ::
Row(2, null) ::
- Row(3, null) ::
- Row(4, null) ::
- Row(5, null) ::
- Row(6, null) :: Nil)
+ Row(3, Double.NaN) ::
+ Row(4, Double.NaN) ::
+ Row(5, Double.NaN) ::
+ Row(6, Double.NaN) :: Nil)
val corr7 = sqlContext.sql("SELECT corr(b, c) FROM covar_tab").collect()(0).getDouble(0)
assert(math.abs(corr7 - 0.6633880657639323) < 1e-12)
@@ -841,11 +841,8 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
// one row test
val df3 = Seq.tabulate(1)(x => (1 * x, x * x * x - 2)).toDF("a", "b")
- val cov_samp3 = df3.groupBy().agg(covar_samp("a", "b")).collect()(0).get(0)
- assert(cov_samp3 == null)
-
- val cov_pop3 = df3.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0)
- assert(cov_pop3 == 0.0)
+ checkAnswer(df3.groupBy().agg(covar_samp("a", "b")), Row(Double.NaN))
+ checkAnswer(df3.groupBy().agg(covar_pop("a", "b")), Row(0.0))
}
test("no aggregation function (SPARK-11486)") {