author    Liang-Chi Hsieh <viirya@gmail.com>    2016-01-13 10:26:55 -0800
committer Davies Liu <davies.liu@gmail.com>     2016-01-13 10:26:55 -0800
commit    63eee86cc652c108ca7712c8c0a73db1ca89ae90 (patch)
tree      341aa599d17ca0c723b6ac13d1f57ec512a249c6
parent    d6fd9b376b7071aecef34dc82a33eba42b183bc9 (diff)
[SPARK-9297] [SQL] Add covar_pop and covar_samp
JIRA: https://issues.apache.org/jira/browse/SPARK-9297

Add two aggregation functions: covar_pop and covar_samp.

Author: Liang-Chi Hsieh <viirya@gmail.com>
Author: Liang-Chi Hsieh <viirya@appier.com>

Closes #10029 from viirya/covar-funcs.
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala        |   2 +
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala | 198 +++++++++++++
 sql/core/src/main/scala/org/apache/spark/sql/functions.scala                                     |  40 ++++
 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala          |  32 ++++
 4 files changed, 272 insertions(+), 0 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 5c2aa3c06b..d9009e3848 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -182,6 +182,8 @@ object FunctionRegistry {
expression[Average]("avg"),
expression[Corr]("corr"),
expression[Count]("count"),
+ expression[CovPopulation]("covar_pop"),
+ expression[CovSample]("covar_samp"),
expression[First]("first"),
expression[First]("first_value"),
expression[Last]("last"),
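
These two registry entries are what make the new aggregates resolvable by name from SQL. A minimal sketch of the effect; the table "people" and its columns are hypothetical, not part of this patch:

    // Hypothetical: "people" is a registered temp table with double columns.
    sqlContext.sql(
      "SELECT covar_pop(height, weight), covar_samp(height, weight) FROM people")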
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
new file mode 100644
index 0000000000..f53b01be2a
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.types._
+
+/**
+ * Compute the covariance between two expressions.
+ * When applied to empty data (i.e., count is zero), it returns NULL.
+ *
+ */
+abstract class Covariance(left: Expression, right: Expression) extends ImperativeAggregate
+ with Serializable {
+ override def children: Seq[Expression] = Seq(left, right)
+
+ override def nullable: Boolean = true
+
+ override def dataType: DataType = DoubleType
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (left.dataType.isInstanceOf[DoubleType] && right.dataType.isInstanceOf[DoubleType]) {
+ TypeCheckResult.TypeCheckSuccess
+ } else {
+ TypeCheckResult.TypeCheckFailure(
+ s"covariance requires that both arguments are double type, " +
+ s"not (${left.dataType}, ${right.dataType}).")
+ }
+ }
+
+ override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
+
+ override def inputAggBufferAttributes: Seq[AttributeReference] = {
+ aggBufferAttributes.map(_.newInstance())
+ }
+
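+ // Buffer layout: the running mean of each input, the co-moment
+ // Ck = sum((x - xAvg) * (y - yAvg)), and the number of rows seen so far.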
+ override val aggBufferAttributes: Seq[AttributeReference] = Seq(
+ AttributeReference("xAvg", DoubleType)(),
+ AttributeReference("yAvg", DoubleType)(),
+ AttributeReference("Ck", DoubleType)(),
+ AttributeReference("count", LongType)())
+
+ // Local cache of mutableAggBufferOffset(s) that will be used in update and merge
+ val xAvgOffset = mutableAggBufferOffset
+ val yAvgOffset = mutableAggBufferOffset + 1
+ val CkOffset = mutableAggBufferOffset + 2
+ val countOffset = mutableAggBufferOffset + 3
+
+ // Local cache of inputAggBufferOffset(s) that will be used in update and merge
+ val inputXAvgOffset = inputAggBufferOffset
+ val inputYAvgOffset = inputAggBufferOffset + 1
+ val inputCkOffset = inputAggBufferOffset + 2
+ val inputCountOffset = inputAggBufferOffset + 3
+
+ override def initialize(buffer: MutableRow): Unit = {
+ buffer.setDouble(xAvgOffset, 0.0)
+ buffer.setDouble(yAvgOffset, 0.0)
+ buffer.setDouble(CkOffset, 0.0)
+ buffer.setLong(countOffset, 0L)
+ }
+
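+ // One-pass (Welford-style) update: adjust both running means, then fold the
+ // new pair into the co-moment using the pre-update deltaX and the updated yAvg.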
+ override def update(buffer: MutableRow, input: InternalRow): Unit = {
+ val leftEval = left.eval(input)
+ val rightEval = right.eval(input)
+
+ if (leftEval != null && rightEval != null) {
+ val x = leftEval.asInstanceOf[Double]
+ val y = rightEval.asInstanceOf[Double]
+
+ var xAvg = buffer.getDouble(xAvgOffset)
+ var yAvg = buffer.getDouble(yAvgOffset)
+ var Ck = buffer.getDouble(CkOffset)
+ var count = buffer.getLong(countOffset)
+
+ val deltaX = x - xAvg
+ val deltaY = y - yAvg
+ count += 1
+ xAvg += deltaX / count
+ yAvg += deltaY / count
+ Ck += deltaX * (y - yAvg)
+
+ buffer.setDouble(xAvgOffset, xAvg)
+ buffer.setDouble(yAvgOffset, yAvg)
+ buffer.setDouble(CkOffset, Ck)
+ buffer.setLong(countOffset, count)
+ }
+ }
+
+ // Merge counters from other partitions. Formula can be found at:
+ // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
+ override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+ val count2 = buffer2.getLong(inputCountOffset)
+
+ // We only merge the two buffers if at least one record has been aggregated into buffer2.
+ // We don't need to check the count in buffer1: if count2 is greater than zero, totalCount
+ // is greater than zero too, so we won't get a divide-by-zero exception.
+ if (count2 > 0) {
+ var xAvg = buffer1.getDouble(xAvgOffset)
+ var yAvg = buffer1.getDouble(yAvgOffset)
+ var Ck = buffer1.getDouble(CkOffset)
+ var count = buffer1.getLong(countOffset)
+
+ val xAvg2 = buffer2.getDouble(inputXAvgOffset)
+ val yAvg2 = buffer2.getDouble(inputYAvgOffset)
+ val Ck2 = buffer2.getDouble(inputCkOffset)
+
+ val totalCount = count + count2
+ val deltaX = xAvg - xAvg2
+ val deltaY = yAvg - yAvg2
+ Ck += Ck2 + deltaX * deltaY * count / totalCount * count2
+ xAvg = (xAvg * count + xAvg2 * count2) / totalCount
+ yAvg = (yAvg * count + yAvg2 * count2) / totalCount
+ count = totalCount
+
+ buffer1.setDouble(xAvgOffset, xAvg)
+ buffer1.setDouble(yAvgOffset, yAvg)
+ buffer1.setDouble(CkOffset, Ck)
+ buffer1.setLong(countOffset, count)
+ }
+ }
+}
+
+case class CovSample(
+ left: Expression,
+ right: Expression,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0)
+ extends Covariance(left, right) {
+
+ override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+ override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
+
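+ // Sample covariance divides the co-moment by (count - 1), so it needs at
+ // least two rows and returns NULL otherwise.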
+ override def eval(buffer: InternalRow): Any = {
+ val count = buffer.getLong(countOffset)
+ if (count > 1) {
+ val Ck = buffer.getDouble(CkOffset)
+ val cov = Ck / (count - 1)
+ if (cov.isNaN) {
+ null
+ } else {
+ cov
+ }
+ } else {
+ null
+ }
+ }
+}
+
+case class CovPopulation(
+ left: Expression,
+ right: Expression,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0)
+ extends Covariance(left, right) {
+
+ override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+ override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
+
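+ // Population covariance divides the co-moment by count, so any non-empty
+ // input yields a value.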
+ override def eval(buffer: InternalRow): Any = {
+ val count = buffer.getLong(countOffset)
+ if (count > 0) {
+ val Ck = buffer.getDouble(CkOffset)
+ if (Ck.isNaN) {
+ null
+ } else {
+ Ck / count
+ }
+ } else {
+ null
+ }
+ }
+}
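
The update and merge methods above implement the standard one-pass covariance algorithm (the merge step follows the pairwise-combination formula cited in the code comment). Below is a self-contained plain-Scala sketch, not part of the patch and with all names illustrative, that mirrors the same formulas and cross-checks them against the values asserted in the test suite further down:

    // Plain-Scala mirror of Covariance.update/merge; illustrative only.
    case class CovBuffer(xAvg: Double, yAvg: Double, ck: Double, count: Long)

    def update(b: CovBuffer, x: Double, y: Double): CovBuffer = {
      val deltaX = x - b.xAvg                  // delta against the pre-update mean
      val deltaY = y - b.yAvg
      val count  = b.count + 1
      val xAvg   = b.xAvg + deltaX / count
      val yAvg   = b.yAvg + deltaY / count
      val ck     = b.ck + deltaX * (y - yAvg)  // uses the post-update yAvg
      CovBuffer(xAvg, yAvg, ck, count)
    }

    def merge(a: CovBuffer, b: CovBuffer): CovBuffer =
      if (b.count == 0) a
      else {
        val n = a.count + b.count
        CovBuffer(
          (a.xAvg * a.count + b.xAvg * b.count) / n,
          (a.yAvg * a.count + b.yAvg * b.count) / n,
          a.ck + b.ck + (a.xAvg - b.xAvg) * (a.yAvg - b.yAvg) * a.count / n * b.count,
          n)
      }

    // Same data as the first test case below, split across two "partitions".
    val xs = (0 until 20).map(_.toDouble)
    val ys = xs.map(x => x * x - 2 * x + 3.5)
    val empty = CovBuffer(0.0, 0.0, 0.0, 0L)
    val (p1, p2) = (xs zip ys).splitAt(7)
    val m = merge(
      p1.foldLeft(empty) { case (b, (x, y)) => update(b, x, y) },
      p2.foldLeft(empty) { case (b, (x, y)) => update(b, x, y) })
    println(m.ck / m.count)        // covar_pop  -> 565.25
    println(m.ck / (m.count - 1))  // covar_samp -> 595.0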
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 592d79df31..71fea2716b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -309,6 +309,46 @@ object functions extends LegacyFunctions {
countDistinct(Column(columnName), columnNames.map(Column.apply) : _*)
/**
+ * Aggregate function: returns the population covariance for two columns.
+ *
+ * @group agg_funcs
+ * @since 2.0.0
+ */
+ def covar_pop(column1: Column, column2: Column): Column = withAggregateFunction {
+ CovPopulation(column1.expr, column2.expr)
+ }
+
+ /**
+ * Aggregate function: returns the population covariance for two columns.
+ *
+ * @group agg_funcs
+ * @since 2.0.0
+ */
+ def covar_pop(columnName1: String, columnName2: String): Column = {
+ covar_pop(Column(columnName1), Column(columnName2))
+ }
+
+ /**
+ * Aggregate function: returns the sample covariance for two columns.
+ *
+ * @group agg_funcs
+ * @since 2.0.0
+ */
+ def covar_samp(column1: Column, column2: Column): Column = withAggregateFunction {
+ CovSample(column1.expr, column2.expr)
+ }
+
+ /**
+ * Aggregate function: returns the sample covariance for two columns.
+ *
+ * @group agg_funcs
+ * @since 2.0.0
+ */
+ def covar_samp(columnName1: String, columnName2: String): Column = {
+ covar_samp(Column(columnName1), Column(columnName2))
+ }
+
+ /**
* Aggregate function: returns the first value in a group.
*
* @group agg_funcs
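
The functions.scala additions expose the same aggregates through the DataFrame API, with both Column and column-name overloads. A usage sketch mirroring the first test case below; the DataFrame construction here is an assumption for illustration:

    import org.apache.spark.sql.functions.{covar_pop, covar_samp}

    // Assumed input: twenty rows of the (a, b) pairs used in the test below.
    val df = sqlContext.range(0, 20)
      .selectExpr("id * 1.0 AS a", "id * id - 2 * id + 3.5 AS b")
    df.agg(covar_pop("a", "b"), covar_samp("a", "b")).show()
    // Expected (per the test below): 565.25 (population) and 595.0 (sample)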
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 5550198c02..76b36aa891 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -807,6 +807,38 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
assert(math.abs(corr7 - 0.6633880657639323) < 1e-12)
}
+ test("covariance: covar_pop and covar_samp") {
+ // Non-trivial example. To reproduce in Python, use:
+ // >>> import numpy as np
+ // >>> a = np.array(range(20))
+ // >>> b = np.array([x * x - 2 * x + 3.5 for x in range(20)])
+ // >>> np.cov(a, b, bias = 0)[0][1]
+ // 595.0
+ // >>> np.cov(a, b, bias = 1)[0][1]
+ // 565.25
+ val df = Seq.tabulate(20)(x => (1.0 * x, x * x - 2 * x + 3.5)).toDF("a", "b")
+ val cov_samp = df.groupBy().agg(covar_samp("a", "b")).collect()(0).getDouble(0)
+ assert(math.abs(cov_samp - 595.0) < 1e-12)
+
+ val cov_pop = df.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0)
+ assert(math.abs(cov_pop - 565.25) < 1e-12)
+
+ val df2 = Seq.tabulate(20)(x => (1 * x, x * x * x - 2)).toDF("a", "b")
+ val cov_samp2 = df2.groupBy().agg(covar_samp("a", "b")).collect()(0).getDouble(0)
+ assert(math.abs(cov_samp2 - 11564.0) < 1e-12)
+
+ val cov_pop2 = df2.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0)
+ assert(math.abs(cov_pop2 - 10985.799999999999) < 1e-12)
+
+ // One-row test: sample covariance is undefined (NULL), population covariance is 0.0.
+ val df3 = Seq.tabulate(1)(x => (1 * x, x * x * x - 2)).toDF("a", "b")
+ val cov_samp3 = df3.groupBy().agg(covar_samp("a", "b")).collect()(0).get(0)
+ assert(cov_samp3 == null)
+
+ val cov_pop3 = df3.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0)
+ assert(cov_pop3 == 0.0)
+ }
+
test("no aggregation function (SPARK-11486)") {
val df = sqlContext.range(20).selectExpr("id", "repeat(id, 1) as s")
.groupBy("s").count()