aboutsummaryrefslogtreecommitdiff
path: root/sql/hive
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <viirya@gmail.com>2016-01-13 10:26:55 -0800
committerDavies Liu <davies.liu@gmail.com>2016-01-13 10:26:55 -0800
commit63eee86cc652c108ca7712c8c0a73db1ca89ae90 (patch)
tree341aa599d17ca0c723b6ac13d1f57ec512a249c6 /sql/hive
parentd6fd9b376b7071aecef34dc82a33eba42b183bc9 (diff)
downloadspark-63eee86cc652c108ca7712c8c0a73db1ca89ae90.tar.gz
spark-63eee86cc652c108ca7712c8c0a73db1ca89ae90.tar.bz2
spark-63eee86cc652c108ca7712c8c0a73db1ca89ae90.zip
[SPARK-9297] [SQL] Add covar_pop and covar_samp
JIRA: https://issues.apache.org/jira/browse/SPARK-9297 Add two aggregation functions: covar_pop and covar_samp. Author: Liang-Chi Hsieh <viirya@gmail.com> Author: Liang-Chi Hsieh <viirya@appier.com> Closes #10029 from viirya/covar-funcs.
Diffstat (limited to 'sql/hive')
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala32
1 files changed, 32 insertions, 0 deletions
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 5550198c02..76b36aa891 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -807,6 +807,38 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
assert(math.abs(corr7 - 0.6633880657639323) < 1e-12)
}
+ test("covariance: covar_pop and covar_samp") {
+ // non-trivial example. To reproduce in python, use:
+ // >>> import numpy as np
+ // >>> a = np.array(range(20))
+ // >>> b = np.array([x * x - 2 * x + 3.5 for x in range(20)])
+ // >>> np.cov(a, b, bias = 0)[0][1]
+ // 595.0
+ // >>> np.cov(a, b, bias = 1)[0][1]
+ // 565.25
+ val df = Seq.tabulate(20)(x => (1.0 * x, x * x - 2 * x + 3.5)).toDF("a", "b")
+ val cov_samp = df.groupBy().agg(covar_samp("a", "b")).collect()(0).getDouble(0)
+ assert(math.abs(cov_samp - 595.0) < 1e-12)
+
+ val cov_pop = df.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0)
+ assert(math.abs(cov_pop - 565.25) < 1e-12)
+
+ val df2 = Seq.tabulate(20)(x => (1 * x, x * x * x - 2)).toDF("a", "b")
+ val cov_samp2 = df2.groupBy().agg(covar_samp("a", "b")).collect()(0).getDouble(0)
+ assert(math.abs(cov_samp2 - 11564.0) < 1e-12)
+
+ val cov_pop2 = df2.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0)
+ assert(math.abs(cov_pop2 - 10985.799999999999) < 1e-12)
+
+ // one row test
+ val df3 = Seq.tabulate(1)(x => (1 * x, x * x * x - 2)).toDF("a", "b")
+ val cov_samp3 = df3.groupBy().agg(covar_samp("a", "b")).collect()(0).get(0)
+ assert(cov_samp3 == null)
+
+ val cov_pop3 = df3.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0)
+ assert(cov_pop3 == 0.0)
+ }
+
test("no aggregation function (SPARK-11486)") {
val df = sqlContext.range(20).selectExpr("id", "repeat(id, 1) as s")
.groupBy("s").count()