From 63eee86cc652c108ca7712c8c0a73db1ca89ae90 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 13 Jan 2016 10:26:55 -0800 Subject: [SPARK-9297] [SQL] Add covar_pop and covar_samp JIRA: https://issues.apache.org/jira/browse/SPARK-9297 Add two aggregation functions: covar_pop and covar_samp. Author: Liang-Chi Hsieh Author: Liang-Chi Hsieh Closes #10029 from viirya/covar-funcs. --- .../sql/hive/execution/AggregationQuerySuite.scala | 32 ++++++++++++++++++++++ 1 file changed, 32 insertions(+) (limited to 'sql/hive') diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 5550198c02..76b36aa891 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -807,6 +807,38 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te assert(math.abs(corr7 - 0.6633880657639323) < 1e-12) } + test("covariance: covar_pop and covar_samp") { + // non-trivial example. To reproduce in python, use: + // >>> import numpy as np + // >>> a = np.array(range(20)) + // >>> b = np.array([x * x - 2 * x + 3.5 for x in range(20)]) + // >>> np.cov(a, b, bias = 0)[0][1] + // 595.0 + // >>> np.cov(a, b, bias = 1)[0][1] + // 565.25 + val df = Seq.tabulate(20)(x => (1.0 * x, x * x - 2 * x + 3.5)).toDF("a", "b") + val cov_samp = df.groupBy().agg(covar_samp("a", "b")).collect()(0).getDouble(0) + assert(math.abs(cov_samp - 595.0) < 1e-12) + + val cov_pop = df.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0) + assert(math.abs(cov_pop - 565.25) < 1e-12) + + val df2 = Seq.tabulate(20)(x => (1 * x, x * x * x - 2)).toDF("a", "b") + val cov_samp2 = df2.groupBy().agg(covar_samp("a", "b")).collect()(0).getDouble(0) + assert(math.abs(cov_samp2 - 11564.0) < 1e-12) + + val cov_pop2 = df2.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0) + assert(math.abs(cov_pop2 - 10985.799999999999) < 1e-12) + + // one row test + val df3 = Seq.tabulate(1)(x => (1 * x, x * x * x - 2)).toDF("a", "b") + val cov_samp3 = df3.groupBy().agg(covar_samp("a", "b")).collect()(0).get(0) + assert(cov_samp3 == null) + + val cov_pop3 = df3.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0) + assert(cov_pop3 == 0.0) + } + test("no aggregation function (SPARK-11486)") { val df = sqlContext.range(20).selectExpr("id", "repeat(id, 1) as s") .groupBy("s").count() -- cgit v1.2.3