author    Liang-Chi Hsieh <viirya@appier.com>    2015-11-01 18:37:27 -0800
committer Yin Huai <yhuai@databricks.com>        2015-11-01 18:37:27 -0800
commit    3e770a64a48c271c5829d2bcbdc1d6430cda2ac9 (patch)
tree      ec0b281b7fafda0bd896f9468c2cb47378e088c0 /sql/hive
parent    f8d93edec82eedab59d50aec06ca2de7e4cf14f6 (diff)
[SPARK-9298][SQL] Add Pearson correlation aggregation function
JIRA: https://issues.apache.org/jira/browse/SPARK-9298

This patch adds a Pearson correlation aggregation function based on `AggregateExpression2`.

Author: Liang-Chi Hsieh <viirya@appier.com>

Closes #8587 from viirya/corr_aggregation.
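For context, Pearson's r for columns x and y is sum((x - mean(x)) * (y - mean(y))) / sqrt(sum((x - mean(x))^2) * sum((y - mean(y))^2)). The sketch below is a minimal single-pass accumulator in the spirit of the update buffer that an AggregateExpression2-based aggregate maintains; it is illustrative only, and not the code added by this patch.

    // Minimal sketch: single-pass (Welford-style) Pearson correlation.
    // Not the patch's implementation; shown only to make the math concrete.
    final class CorrAcc {
      private var n = 0L
      private var xAvg, yAvg = 0.0   // running means of x and y
      private var ck = 0.0           // co-moment: sum (x - xAvg)(y - yAvg)
      private var xMk, yMk = 0.0     // sum (x - xAvg)^2 and sum (y - yAvg)^2

      def update(x: Double, y: Double): Unit = {
        n += 1
        val dx = x - xAvg
        val dy = y - yAvg
        xAvg += dx / n
        yAvg += dy / n
        ck  += dx * (y - yAvg)       // uses the already-updated yAvg
        xMk += dx * (x - xAvg)       // uses the already-updated xAvg
        yMk += dy * (y - yAvg)
      }

      // Pearson's r; undefined (NaN here) for fewer than two rows, which
      // mirrors the NULL results asserted in the tests below.
      def result: Double =
        if (n < 2) Double.NaN else ck / math.sqrt(xMk * yMk)
    }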
Diffstat (limited to 'sql/hive')
-rw-r--r--  sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala |   7
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala               | 104
2 files changed, 109 insertions, 2 deletions
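The tests in the second file exercise the new function through both the DataFrame API and SQL. A minimal usage sketch follows (Spark 1.6-era API, matching the tests below; it assumes a SQLContext named sqlContext with its implicits in scope, as in the test suite):

    import sqlContext.implicits._
    import org.apache.spark.sql.functions.corr

    val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i)).toDF("a", "b")

    // DataFrame API: perfectly linear columns give r = 1.0 (up to
    // floating-point error), as the first assertion below checks.
    val r = df.groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
    println(r)

    // SQL: the same aggregate is available in queries.
    df.registerTempTable("t")
    sqlContext.sql("SELECT corr(a, b) FROM t").show()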
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 9e357bf348..6ed40b0397 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -304,7 +304,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
// classpath problems
"compute_stats.*",
- "udf_bitmap_.*"
+ "udf_bitmap_.*",
+
+ // The difference between the double numbers generated by Hive and Spark
+ // can be ignored (e.g., 0.6633880657639323 and 0.6633880657639322)
+ "udaf_corr"
)
/**
@@ -857,7 +861,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"type_cast_1",
"type_widening",
"udaf_collect_set",
- "udaf_corr",
"udaf_covar_pop",
"udaf_covar_samp",
"udaf_histogram_numeric",
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index f38a3f63c3..0cf0e0aab9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution
import scala.collection.JavaConverters._
+import org.apache.spark.SparkException
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.aggregate
@@ -556,6 +557,109 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(0, null, 1, 1, null, 0) :: Nil)
}
+ test("pearson correlation") {
+ val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c")
+ val corr1 = df.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
+ assert(math.abs(corr1 - 1.0) < 1e-12)
+ val corr2 = df.groupBy().agg(corr("a", "c")).collect()(0).getDouble(0)
+ assert(math.abs(corr2 + 1.0) < 1e-12)
+ // non-trivial example. To reproduce in python, use:
+ // >>> from scipy.stats import pearsonr
+ // >>> import numpy as np
+ // >>> a = np.array(range(20))
+ // >>> b = np.array([x * x - 2 * x + 3.5 for x in range(20)])
+ // >>> pearsonr(a, b)
+ // (0.95723391394758572, 3.8902121417802199e-11)
+ // In R, use:
+ // > a <- 0:19
+ // > b <- mapply(function(x) x * x - 2 * x + 3.5, a)
+ // > cor(a, b)
+ // [1] 0.957233913947585835
+ val df2 = Seq.tabulate(20)(x => (1.0 * x, x * x - 2 * x + 3.5)).toDF("a", "b")
+ val corr3 = df2.groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
+ assert(math.abs(corr3 - 0.95723391394758572) < 1e-12)
+
+ val df3 = Seq.tabulate(0)(i => (1.0 * i, 2.0 * i)).toDF("a", "b")
+ val corr4 = df3.groupBy().agg(corr("a", "b")).collect()(0)
+ assert(corr4 == Row(null))
+
+ val df4 = Seq.tabulate(10)(i => (1 * i, 2 * i, i * -1)).toDF("a", "b", "c")
+ val corr5 = df4.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
+ assert(math.abs(corr5 - 1.0) < 1e-12)
+ val corr6 = df4.groupBy().agg(corr("a", "c")).collect()(0).getDouble(0)
+ assert(math.abs(corr6 + 1.0) < 1e-12)
+
+ // Test for udaf_corr in HiveCompatibilitySuite
+ // udaf_corr has been blacklisted due to numerical errors
+ // We test it here:
+ // SELECT corr(b, c) FROM covar_tab WHERE a < 1; => NULL
+ // SELECT corr(b, c) FROM covar_tab WHERE a < 3; => NULL
+ // SELECT corr(b, c) FROM covar_tab WHERE a = 3; => NULL
+ // SELECT a, corr(b, c) FROM covar_tab GROUP BY a ORDER BY a; =>
+ // 1 NULL
+ // 2 NULL
+ // 3 NULL
+ // 4 NULL
+ // 5 NULL
+ // 6 NULL
+ // SELECT corr(b, c) FROM covar_tab; => 0.6633880657639323
+
+ val covar_tab = Seq[(Integer, Integer, Integer)](
+ (1, null, 15),
+ (2, 3, null),
+ (3, 7, 12),
+ (4, 4, 14),
+ (5, 8, 17),
+ (6, 2, 11)).toDF("a", "b", "c")
+
+ covar_tab.registerTempTable("covar_tab")
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT corr(b, c) FROM covar_tab WHERE a < 1
+ """.stripMargin),
+ Row(null) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT corr(b, c) FROM covar_tab WHERE a < 3
+ """.stripMargin),
+ Row(null) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT corr(b, c) FROM covar_tab WHERE a = 3
+ """.stripMargin),
+ Row(null) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT a, corr(b, c) FROM covar_tab GROUP BY a ORDER BY a
+ """.stripMargin),
+ Row(1, null) ::
+ Row(2, null) ::
+ Row(3, null) ::
+ Row(4, null) ::
+ Row(5, null) ::
+ Row(6, null) :: Nil)
+
+ val corr7 = sqlContext.sql("SELECT corr(b, c) FROM covar_tab").collect()(0).getDouble(0)
+ assert(math.abs(corr7 - 0.6633880657639323) < 1e-12)
+
+ withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") {
+ val errorMessage = intercept[SparkException] {
+ val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c")
+ val corr1 = df.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
+ }.getMessage
+ assert(errorMessage.contains("java.lang.UnsupportedOperationException: " +
+ "Corr only supports the new AggregateExpression2"))
+ }
+ }
+
test("test Last implemented based on AggregateExpression1") {
// TODO: Remove this test once we remove AggregateExpression1.
import org.apache.spark.sql.functions._