diff options
author | Liang-Chi Hsieh <viirya@appier.com> | 2015-11-01 18:37:27 -0800 |
---|---|---|
committer | Yin Huai <yhuai@databricks.com> | 2015-11-01 18:37:27 -0800 |
commit | 3e770a64a48c271c5829d2bcbdc1d6430cda2ac9 (patch) | |
tree | ec0b281b7fafda0bd896f9468c2cb47378e088c0 /sql/hive/src | |
parent | f8d93edec82eedab59d50aec06ca2de7e4cf14f6 (diff) | |
download | spark-3e770a64a48c271c5829d2bcbdc1d6430cda2ac9.tar.gz spark-3e770a64a48c271c5829d2bcbdc1d6430cda2ac9.tar.bz2 spark-3e770a64a48c271c5829d2bcbdc1d6430cda2ac9.zip |
[SPARK-9298][SQL] Add pearson correlation aggregation function
JIRA: https://issues.apache.org/jira/browse/SPARK-9298
This patch adds pearson correlation aggregation function based on `AggregateExpression2`.
Author: Liang-Chi Hsieh <viirya@appier.com>
Closes #8587 from viirya/corr_aggregation.
Diffstat (limited to 'sql/hive/src')
-rw-r--r-- | sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala | 104 |
1 files changed, 104 insertions, 0 deletions
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index f38a3f63c3..0cf0e0aab9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution import scala.collection.JavaConverters._ +import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.aggregate @@ -556,6 +557,109 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(0, null, 1, 1, null, 0) :: Nil) } + test("pearson correlation") { + val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c") + val corr1 = df.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) + assert(math.abs(corr1 - 1.0) < 1e-12) + val corr2 = df.groupBy().agg(corr("a", "c")).collect()(0).getDouble(0) + assert(math.abs(corr2 + 1.0) < 1e-12) + // non-trivial example. To reproduce in python, use: + // >>> from scipy.stats import pearsonr + // >>> import numpy as np + // >>> a = np.array(range(20)) + // >>> b = np.array([x * x - 2 * x + 3.5 for x in range(20)]) + // >>> pearsonr(a, b) + // (0.95723391394758572, 3.8902121417802199e-11) + // In R, use: + // > a <- 0:19 + // > b <- mapply(function(x) x * x - 2 * x + 3.5, a) + // > cor(a, b) + // [1] 0.957233913947585835 + val df2 = Seq.tabulate(20)(x => (1.0 * x, x * x - 2 * x + 3.5)).toDF("a", "b") + val corr3 = df2.groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) + assert(math.abs(corr3 - 0.95723391394758572) < 1e-12) + + val df3 = Seq.tabulate(0)(i => (1.0 * i, 2.0 * i)).toDF("a", "b") + val corr4 = df3.groupBy().agg(corr("a", "b")).collect()(0) + assert(corr4 == Row(null)) + + val df4 = Seq.tabulate(10)(i => (1 * i, 2 * i, i * -1)).toDF("a", "b", "c") + val corr5 = df4.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) + assert(math.abs(corr5 - 1.0) < 1e-12) + val corr6 = df4.groupBy().agg(corr("a", "c")).collect()(0).getDouble(0) + assert(math.abs(corr6 + 1.0) < 1e-12) + + // Test for udaf_corr in HiveCompatibilitySuite + // udaf_corr has been blacklisted due to numerical errors + // We test it here: + // SELECT corr(b, c) FROM covar_tab WHERE a < 1; => NULL + // SELECT corr(b, c) FROM covar_tab WHERE a < 3; => NULL + // SELECT corr(b, c) FROM covar_tab WHERE a = 3; => NULL + // SELECT a, corr(b, c) FROM covar_tab GROUP BY a ORDER BY a; => + // 1 NULL + // 2 NULL + // 3 NULL + // 4 NULL + // 5 NULL + // 6 NULL + // SELECT corr(b, c) FROM covar_tab; => 0.6633880657639323 + + val covar_tab = Seq[(Integer, Integer, Integer)]( + (1, null, 15), + (2, 3, null), + (3, 7, 12), + (4, 4, 14), + (5, 8, 17), + (6, 2, 11)).toDF("a", "b", "c") + + covar_tab.registerTempTable("covar_tab") + + checkAnswer( + sqlContext.sql( + """ + |SELECT corr(b, c) FROM covar_tab WHERE a < 1 + """.stripMargin), + Row(null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT corr(b, c) FROM covar_tab WHERE a < 3 + """.stripMargin), + Row(null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT corr(b, c) FROM covar_tab WHERE a = 3 + """.stripMargin), + Row(null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT a, corr(b, c) FROM covar_tab GROUP BY a ORDER BY a + """.stripMargin), + Row(1, null) :: + Row(2, null) :: + Row(3, null) :: + Row(4, null) :: + Row(5, null) :: + Row(6, null) :: Nil) + + val corr7 = sqlContext.sql("SELECT corr(b, c) FROM covar_tab").collect()(0).getDouble(0) + assert(math.abs(corr7 - 0.6633880657639323) < 1e-12) + + withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { + val errorMessage = intercept[SparkException] { + val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c") + val corr1 = df.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) + }.getMessage + assert(errorMessage.contains("java.lang.UnsupportedOperationException: " + + "Corr only supports the new AggregateExpression2")) + } + } + test("test Last implemented based on AggregateExpression1") { // TODO: Remove this test once we remove AggregateExpression1. import org.apache.spark.sql.functions._ |