diff options
author | Liang-Chi Hsieh <viirya@gmail.com> | 2016-01-13 10:26:55 -0800 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2016-01-13 10:26:55 -0800 |
commit | 63eee86cc652c108ca7712c8c0a73db1ca89ae90 (patch) | |
tree | 341aa599d17ca0c723b6ac13d1f57ec512a249c6 /sql/hive | |
parent | d6fd9b376b7071aecef34dc82a33eba42b183bc9 (diff) | |
download | spark-63eee86cc652c108ca7712c8c0a73db1ca89ae90.tar.gz spark-63eee86cc652c108ca7712c8c0a73db1ca89ae90.tar.bz2 spark-63eee86cc652c108ca7712c8c0a73db1ca89ae90.zip |
[SPARK-9297] [SQL] Add covar_pop and covar_samp
JIRA: https://issues.apache.org/jira/browse/SPARK-9297
Add two aggregation functions: covar_pop and covar_samp.
Author: Liang-Chi Hsieh <viirya@gmail.com>
Author: Liang-Chi Hsieh <viirya@appier.com>
Closes #10029 from viirya/covar-funcs.
Diffstat (limited to 'sql/hive')
-rw-r--r-- | sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 5550198c02..76b36aa891 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -807,6 +807,38 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te assert(math.abs(corr7 - 0.6633880657639323) < 1e-12) } + test("covariance: covar_pop and covar_samp") { + // non-trivial example. To reproduce in python, use: + // >>> import numpy as np + // >>> a = np.array(range(20)) + // >>> b = np.array([x * x - 2 * x + 3.5 for x in range(20)]) + // >>> np.cov(a, b, bias = 0)[0][1] + // 595.0 + // >>> np.cov(a, b, bias = 1)[0][1] + // 565.25 + val df = Seq.tabulate(20)(x => (1.0 * x, x * x - 2 * x + 3.5)).toDF("a", "b") + val cov_samp = df.groupBy().agg(covar_samp("a", "b")).collect()(0).getDouble(0) + assert(math.abs(cov_samp - 595.0) < 1e-12) + + val cov_pop = df.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0) + assert(math.abs(cov_pop - 565.25) < 1e-12) + + val df2 = Seq.tabulate(20)(x => (1 * x, x * x * x - 2)).toDF("a", "b") + val cov_samp2 = df2.groupBy().agg(covar_samp("a", "b")).collect()(0).getDouble(0) + assert(math.abs(cov_samp2 - 11564.0) < 1e-12) + + val cov_pop2 = df2.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0) + assert(math.abs(cov_pop2 - 10985.799999999999) < 1e-12) + + // one row test + val df3 = Seq.tabulate(1)(x => (1 * x, x * x * x - 2)).toDF("a", "b") + val cov_samp3 = df3.groupBy().agg(covar_samp("a", "b")).collect()(0).get(0) + assert(cov_samp3 == null) + + val cov_pop3 = df3.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0) + assert(cov_pop3 == 0.0) + } + test("no aggregation function (SPARK-11486)") { val df = sqlContext.range(20).selectExpr("id", "repeat(id, 1) as s") .groupBy("s").count() |