-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala              |   1 +
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala        | 159 +++++++++++++++
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala            |   6 +
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala                 |  18 ++
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala                                           |  18 ++
-rw-r--r--  sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala |   7 +++++--
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala                | 104 ++++++++++
7 files changed, 311 insertions(+), 2 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index ed9fcfe014..5f3ec74ac0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -178,6 +178,7 @@ object FunctionRegistry {
// aggregate functions
expression[Average]("avg"),
+ expression[Corr]("corr"),
expression[Count]("count"),
expression[First]("first"),
expression[First]("first_value"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
index 281404f285..5d2eb7b017 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -23,6 +23,7 @@ import java.util
import com.clearspring.analytics.hash.MurmurHash
import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
@@ -524,6 +525,164 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
override val evaluateExpression = Cast(currentSum, resultType)
}
+/**
+ * Computes the Pearson correlation between two expressions.
+ * When applied to empty data (i.e., when the count is zero), it returns NULL.
+ *
+ * The definition of the Pearson correlation coefficient can be found at
+ * http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient
+ *
+ * @param left the first expression for which the correlation is computed.
+ * @param right the second expression for which the correlation is computed.
+ */
+case class Corr(
+ left: Expression,
+ right: Expression,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0)
+ extends ImperativeAggregate {
+
+ def children: Seq[Expression] = Seq(left, right)
+
+ def nullable: Boolean = true
+
+ def dataType: DataType = DoubleType
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
+
+ def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
+
+ def inputAggBufferAttributes: Seq[AttributeReference] = aggBufferAttributes.map(_.newInstance())
+
+ val aggBufferAttributes: Seq[AttributeReference] = Seq(
+ AttributeReference("xAvg", DoubleType)(),
+ AttributeReference("yAvg", DoubleType)(),
+ AttributeReference("Ck", DoubleType)(),
+ AttributeReference("MkX", DoubleType)(),
+ AttributeReference("MkY", DoubleType)(),
+ AttributeReference("count", LongType)())
+
+ // Local cache of mutableAggBufferOffset(s) that will be used in update and merge
+ private[this] val mutableAggBufferOffsetPlus1 = mutableAggBufferOffset + 1
+ private[this] val mutableAggBufferOffsetPlus2 = mutableAggBufferOffset + 2
+ private[this] val mutableAggBufferOffsetPlus3 = mutableAggBufferOffset + 3
+ private[this] val mutableAggBufferOffsetPlus4 = mutableAggBufferOffset + 4
+ private[this] val mutableAggBufferOffsetPlus5 = mutableAggBufferOffset + 5
+
+ // Local cache of inputAggBufferOffset(s) that will be used in update and merge
+ private[this] val inputAggBufferOffsetPlus1 = inputAggBufferOffset + 1
+ private[this] val inputAggBufferOffsetPlus2 = inputAggBufferOffset + 2
+ private[this] val inputAggBufferOffsetPlus3 = inputAggBufferOffset + 3
+ private[this] val inputAggBufferOffsetPlus4 = inputAggBufferOffset + 4
+ private[this] val inputAggBufferOffsetPlus5 = inputAggBufferOffset + 5
+
+ override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+ override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+ override def initialize(buffer: MutableRow): Unit = {
+ buffer.setDouble(mutableAggBufferOffset, 0.0)
+ buffer.setDouble(mutableAggBufferOffsetPlus1, 0.0)
+ buffer.setDouble(mutableAggBufferOffsetPlus2, 0.0)
+ buffer.setDouble(mutableAggBufferOffsetPlus3, 0.0)
+ buffer.setDouble(mutableAggBufferOffsetPlus4, 0.0)
+ buffer.setLong(mutableAggBufferOffsetPlus5, 0L)
+ }
+
+ override def update(buffer: MutableRow, input: InternalRow): Unit = {
+ val leftEval = left.eval(input)
+ val rightEval = right.eval(input)
+
+ if (leftEval != null && rightEval != null) {
+ val x = leftEval.asInstanceOf[Double]
+ val y = rightEval.asInstanceOf[Double]
+
+ var xAvg = buffer.getDouble(mutableAggBufferOffset)
+ var yAvg = buffer.getDouble(mutableAggBufferOffsetPlus1)
+ var Ck = buffer.getDouble(mutableAggBufferOffsetPlus2)
+ var MkX = buffer.getDouble(mutableAggBufferOffsetPlus3)
+ var MkY = buffer.getDouble(mutableAggBufferOffsetPlus4)
+ var count = buffer.getLong(mutableAggBufferOffsetPlus5)
+
+ val deltaX = x - xAvg
+ val deltaY = y - yAvg
+ count += 1
+ xAvg += deltaX / count
+ yAvg += deltaY / count
+ Ck += deltaX * (y - yAvg)
+ MkX += deltaX * (x - xAvg)
+ MkY += deltaY * (y - yAvg)
+
+ buffer.setDouble(mutableAggBufferOffset, xAvg)
+ buffer.setDouble(mutableAggBufferOffsetPlus1, yAvg)
+ buffer.setDouble(mutableAggBufferOffsetPlus2, Ck)
+ buffer.setDouble(mutableAggBufferOffsetPlus3, MkX)
+ buffer.setDouble(mutableAggBufferOffsetPlus4, MkY)
+ buffer.setLong(mutableAggBufferOffsetPlus5, count)
+ }
+ }
+
+ // Merge counters from another partition. The merge formulas can be found at:
+ // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
+ override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+ val count2 = buffer2.getLong(inputAggBufferOffsetPlus5)
+
+ // We only merge the two buffers if at least one record has been aggregated into buffer2.
+ // We don't need to check the count in buffer1: if count2 is greater than zero, then
+ // totalCount is greater than zero as well, so we will not divide by zero.
+ if (count2 > 0) {
+ var xAvg = buffer1.getDouble(mutableAggBufferOffset)
+ var yAvg = buffer1.getDouble(mutableAggBufferOffsetPlus1)
+ var Ck = buffer1.getDouble(mutableAggBufferOffsetPlus2)
+ var MkX = buffer1.getDouble(mutableAggBufferOffsetPlus3)
+ var MkY = buffer1.getDouble(mutableAggBufferOffsetPlus4)
+ var count = buffer1.getLong(mutableAggBufferOffsetPlus5)
+
+ val xAvg2 = buffer2.getDouble(inputAggBufferOffset)
+ val yAvg2 = buffer2.getDouble(inputAggBufferOffsetPlus1)
+ val Ck2 = buffer2.getDouble(inputAggBufferOffsetPlus2)
+ val MkX2 = buffer2.getDouble(inputAggBufferOffsetPlus3)
+ val MkY2 = buffer2.getDouble(inputAggBufferOffsetPlus4)
+
+ val totalCount = count + count2
+ val deltaX = xAvg - xAvg2
+ val deltaY = yAvg - yAvg2
+ Ck += Ck2 + deltaX * deltaY * count / totalCount * count2
+ xAvg = (xAvg * count + xAvg2 * count2) / totalCount
+ yAvg = (yAvg * count + yAvg2 * count2) / totalCount
+ MkX += MkX2 + deltaX * deltaX * count / totalCount * count2
+ MkY += MkY2 + deltaY * deltaY * count / totalCount * count2
+ count = totalCount
+
+ buffer1.setDouble(mutableAggBufferOffset, xAvg)
+ buffer1.setDouble(mutableAggBufferOffsetPlus1, yAvg)
+ buffer1.setDouble(mutableAggBufferOffsetPlus2, Ck)
+ buffer1.setDouble(mutableAggBufferOffsetPlus3, MkX)
+ buffer1.setDouble(mutableAggBufferOffsetPlus4, MkY)
+ buffer1.setLong(mutableAggBufferOffsetPlus5, count)
+ }
+ }
+
+ override def eval(buffer: InternalRow): Any = {
+ val count = buffer.getLong(mutableAggBufferOffsetPlus5)
+ if (count > 0) {
+ val Ck = buffer.getDouble(mutableAggBufferOffsetPlus2)
+ val MkX = buffer.getDouble(mutableAggBufferOffsetPlus3)
+ val MkY = buffer.getDouble(mutableAggBufferOffsetPlus4)
+ val corr = Ck / math.sqrt(MkX * MkY)
+ if (corr.isNaN) {
+ null
+ } else {
+ corr
+ }
+ } else {
+ null
+ }
+ }
+}
+
// scalastyle:off
/**
* HyperLogLog++ (HLL++) is a state of the art cardinality estimation algorithm. This class
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
index c911ec53f1..564174f9b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
@@ -127,6 +127,12 @@ object Utils {
mode = aggregate.Complete,
isDistinct = true)
+ case expressions.Corr(left, right) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Corr(left, right),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
case expressions.ApproxCountDistinct(child, rsd) =>
aggregate.AggregateExpression2(
aggregateFunction = aggregate.HyperLogLogPlusPlus(child, rsd),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index c1bab6d36a..bf59660c38 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -747,6 +747,24 @@ case class LastFunction(
}
}
+/**
+ * Calculates the Pearson correlation coefficient for the given columns.
+ * Only supported by the new AggregateExpression2 code path.
+ */
+case class Corr(left: Expression, right: Expression)
+ extends BinaryExpression with AggregateExpression1 with ImplicitCastInputTypes {
+ override def nullable: Boolean = true
+ override def dataType: DoubleType.type = DoubleType
+ override def toString: String = s"CORRELATION($left, $right)"
+ override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
+ override def newInstance(): AggregateFunction1 = {
+ throw new UnsupportedOperationException(
+ "Corr only supports the new AggregateExpression2 and can only be used " +
+ "when spark.sql.useAggregate2 = true")
+ }
+}
+
// Compute standard deviation based on online algorithm specified here:
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
abstract class StddevAgg1(child: Expression) extends UnaryExpression with PartialAggregate1 {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index c1737b1ef6..5a5c695e6a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -173,6 +173,24 @@ object functions {
def avg(columnName: String): Column = avg(Column(columnName))
/**
+ * Aggregate function: returns the Pearson correlation coefficient for two columns.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def corr(column1: Column, column2: Column): Column =
+ Corr(column1.expr, column2.expr)
+
+ /**
+ * Aggregate function: returns the Pearson correlation coefficient for two columns.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def corr(columnName1: String, columnName2: String): Column =
+ corr(Column(columnName1), Column(columnName2))
+
+ /**
* Aggregate function: returns the number of items in a group.
*
* @group agg_funcs
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 9e357bf348..6ed40b0397 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -304,7 +304,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
// classpath problems
"compute_stats.*",
- "udf_bitmap_.*"
+ "udf_bitmap_.*",
+
+ // The difference between the double values produced by Hive and Spark
+ // is negligible (e.g., 0.6633880657639323 vs. 0.6633880657639322)
+ "udaf_corr"
)
/**
@@ -857,7 +861,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"type_cast_1",
"type_widening",
"udaf_collect_set",
- "udaf_corr",
"udaf_covar_pop",
"udaf_covar_samp",
"udaf_histogram_numeric",
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index f38a3f63c3..0cf0e0aab9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution
import scala.collection.JavaConverters._
+import org.apache.spark.SparkException
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.aggregate
@@ -556,6 +557,109 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(0, null, 1, 1, null, 0) :: Nil)
}
+ test("pearson correlation") {
+ val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c")
+ val corr1 = df.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
+ assert(math.abs(corr1 - 1.0) < 1e-12)
+ val corr2 = df.groupBy().agg(corr("a", "c")).collect()(0).getDouble(0)
+ assert(math.abs(corr2 + 1.0) < 1e-12)
+ // A non-trivial example. To reproduce in Python, use:
+ // >>> from scipy.stats import pearsonr
+ // >>> import numpy as np
+ // >>> a = np.array(range(20))
+ // >>> b = np.array([x * x - 2 * x + 3.5 for x in range(20)])
+ // >>> pearsonr(a, b)
+ // (0.95723391394758572, 3.8902121417802199e-11)
+ // In R, use:
+ // > a <- 0:19
+ // > b <- mapply(function(x) x * x - 2 * x + 3.5, a)
+ // > cor(a, b)
+ // [1] 0.957233913947585835
+ val df2 = Seq.tabulate(20)(x => (1.0 * x, x * x - 2 * x + 3.5)).toDF("a", "b")
+ val corr3 = df2.groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
+ assert(math.abs(corr3 - 0.95723391394758572) < 1e-12)
+
+ val df3 = Seq.tabulate(0)(i => (1.0 * i, 2.0 * i)).toDF("a", "b")
+ val corr4 = df3.groupBy().agg(corr("a", "b")).collect()(0)
+ assert(corr4 == Row(null))
+
+ val df4 = Seq.tabulate(10)(i => (1 * i, 2 * i, i * -1)).toDF("a", "b", "c")
+ val corr5 = df4.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
+ assert(math.abs(corr5 - 1.0) < 1e-12)
+ val corr6 = df4.groupBy().agg(corr("a", "c")).collect()(0).getDouble(0)
+ assert(math.abs(corr6 + 1.0) < 1e-12)
+
+ // Test for udaf_corr in HiveCompatibilitySuite.
+ // udaf_corr has been blacklisted there because Hive and Spark produce slightly
+ // different double values. We run the same queries here:
+ // SELECT corr(b, c) FROM covar_tab WHERE a < 1; => NULL
+ // SELECT corr(b, c) FROM covar_tab WHERE a < 3; => NULL
+ // SELECT corr(b, c) FROM covar_tab WHERE a = 3; => NULL
+ // SELECT a, corr(b, c) FROM covar_tab GROUP BY a ORDER BY a; =>
+ // 1 NULL
+ // 2 NULL
+ // 3 NULL
+ // 4 NULL
+ // 5 NULL
+ // 6 NULL
+ // SELECT corr(b, c) FROM covar_tab; => 0.6633880657639323
+
+ val covar_tab = Seq[(Integer, Integer, Integer)](
+ (1, null, 15),
+ (2, 3, null),
+ (3, 7, 12),
+ (4, 4, 14),
+ (5, 8, 17),
+ (6, 2, 11)).toDF("a", "b", "c")
+
+ covar_tab.registerTempTable("covar_tab")
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT corr(b, c) FROM covar_tab WHERE a < 1
+ """.stripMargin),
+ Row(null) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT corr(b, c) FROM covar_tab WHERE a < 3
+ """.stripMargin),
+ Row(null) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT corr(b, c) FROM covar_tab WHERE a = 3
+ """.stripMargin),
+ Row(null) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT a, corr(b, c) FROM covar_tab GROUP BY a ORDER BY a
+ """.stripMargin),
+ Row(1, null) ::
+ Row(2, null) ::
+ Row(3, null) ::
+ Row(4, null) ::
+ Row(5, null) ::
+ Row(6, null) :: Nil)
+
+ val corr7 = sqlContext.sql("SELECT corr(b, c) FROM covar_tab").collect()(0).getDouble(0)
+ assert(math.abs(corr7 - 0.6633880657639323) < 1e-12)
+
+ withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") {
+ val errorMessage = intercept[SparkException] {
+ val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c")
+ val corr1 = df.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
+ }.getMessage
+ assert(errorMessage.contains("java.lang.UnsupportedOperationException: " +
+ "Corr only supports the new AggregateExpression2"))
+ }
+ }
+
test("test Last implemented based on AggregateExpression1") {
// TODO: Remove this test once we remove AggregateExpression1.
import org.apache.spark.sql.functions._