diff options
author | Reynold Xin <rxin@databricks.com> | 2015-11-04 16:49:25 -0800 |
---|---|---|
committer | Yin Huai <yhuai@databricks.com> | 2015-11-04 16:49:25 -0800 |
commit | b6e0a5ae6f243139f11c9cbbf18cddd3f25db208 (patch) | |
tree | 24b86a44d17780e89448928bd27992187362758a | |
parent | 411ff6afb485c9d8cfc667c9346f836f2529ea9f (diff) | |
download | spark-b6e0a5ae6f243139f11c9cbbf18cddd3f25db208.tar.gz spark-b6e0a5ae6f243139f11c9cbbf18cddd3f25db208.tar.bz2 spark-b6e0a5ae6f243139f11c9cbbf18cddd3f25db208.zip |
[SPARK-11510][SQL] Remove SQL aggregation tests for higher order statistics
We have some aggregate function tests in both DataFrameAggregateSuite and SQLQuerySuite. The two have almost the same coverage and we should just remove the SQL one.
Author: Reynold Xin <rxin@databricks.com>
Closes #9475 from rxin/SPARK-11510.
3 files changed, 28 insertions, 147 deletions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index b0e2ffaa60..2e679e7bc4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -83,13 +83,8 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("average") { checkAnswer( - testData2.agg(avg('a)), - Row(2.0)) - - // Also check mean - checkAnswer( - testData2.agg(mean('a)), - Row(2.0)) + testData2.agg(avg('a), mean('a)), + Row(2.0, 2.0)) checkAnswer( testData2.agg(avg('a), sumDistinct('a)), // non-partial @@ -98,6 +93,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer( decimalData.agg(avg('a)), Row(new java.math.BigDecimal(2.0))) + checkAnswer( decimalData.agg(avg('a), sumDistinct('a)), // non-partial Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) @@ -168,44 +164,23 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("zero count") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") - assert(emptyTableData.count() === 0) - checkAnswer( emptyTableData.agg(count('a), sumDistinct('a)), // non-partial Row(0, null)) } test("stddev") { - val testData2ADev = math.sqrt(4/5.0) - + val testData2ADev = math.sqrt(4 / 5.0) checkAnswer( - testData2.agg(stddev('a)), - Row(testData2ADev)) - - checkAnswer( - testData2.agg(stddev_pop('a)), - Row(math.sqrt(4/6.0))) - - checkAnswer( - testData2.agg(stddev_samp('a)), - Row(testData2ADev)) + testData2.agg(stddev('a), stddev_pop('a), stddev_samp('a)), + Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev)) } test("zero stddev") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") - assert(emptyTableData.count() == 0) - - checkAnswer( - emptyTableData.agg(stddev('a)), - Row(null)) - checkAnswer( - emptyTableData.agg(stddev_pop('a)), - Row(null)) - - checkAnswer( - emptyTableData.agg(stddev_samp('a)), - Row(null)) + emptyTableData.agg(stddev('a), stddev_pop('a), stddev_samp('a)), + Row(null, null, null)) } test("zero sum") { @@ -227,6 +202,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { val sparkVariance = testData2.agg(variance('a)) checkAggregatesWithTol(sparkVariance, Row(4.0 / 5.0), absTol) + val sparkVariancePop = testData2.agg(var_pop('a)) checkAggregatesWithTol(sparkVariancePop, Row(4.0 / 6.0), absTol) @@ -241,52 +217,35 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("zero moments") { - val emptyTableData = Seq((1, 2)).toDF("a", "b") - assert(emptyTableData.count() === 1) - - checkAnswer( - emptyTableData.agg(variance('a)), - Row(Double.NaN)) - - checkAnswer( - emptyTableData.agg(var_samp('a)), - Row(Double.NaN)) - + val input = Seq((1, 2)).toDF("a", "b") checkAnswer( - emptyTableData.agg(var_pop('a)), - Row(0.0)) + input.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)), + Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN)) checkAnswer( - emptyTableData.agg(skewness('a)), - Row(Double.NaN)) - - checkAnswer( - emptyTableData.agg(kurtosis('a)), - Row(Double.NaN)) + input.agg( + expr("variance(a)"), + expr("var_samp(a)"), + expr("var_pop(a)"), + expr("skewness(a)"), + expr("kurtosis(a)")), + Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN)) } test("null moments") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") - assert(emptyTableData.count() === 0) - - checkAnswer( - emptyTableData.agg(variance('a)), - Row(Double.NaN)) - - checkAnswer( - emptyTableData.agg(var_samp('a)), - Row(Double.NaN)) - - checkAnswer( - emptyTableData.agg(var_pop('a)), - Row(Double.NaN)) checkAnswer( - emptyTableData.agg(skewness('a)), - Row(Double.NaN)) + emptyTableData.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)), + Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)) checkAnswer( - emptyTableData.agg(kurtosis('a)), - Row(Double.NaN)) + emptyTableData.agg( + expr("variance(a)"), + expr("var_samp(a)"), + expr("var_pop(a)"), + expr("skewness(a)"), + expr("kurtosis(a)")), + Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5731a35624..3de277a79a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -726,83 +726,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } - test("stddev") { - checkAnswer( - sql("SELECT STDDEV(a) FROM testData2"), - Row(math.sqrt(4.0 / 5.0)) - ) - } - - test("stddev_pop") { - checkAnswer( - sql("SELECT STDDEV_POP(a) FROM testData2"), - Row(math.sqrt(4.0 / 6.0)) - ) - } - - test("stddev_samp") { - checkAnswer( - sql("SELECT STDDEV_SAMP(a) FROM testData2"), - Row(math.sqrt(4/5.0)) - ) - } - - test("var_samp") { - val absTol = 1e-8 - val sparkAnswer = sql("SELECT VAR_SAMP(a) FROM testData2") - val expectedAnswer = Row(4.0 / 5.0) - checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) - } - - test("variance") { - val absTol = 1e-8 - val sparkAnswer = sql("SELECT VARIANCE(a) FROM testData2") - val expectedAnswer = Row(0.8) - checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) - } - - test("var_pop") { - val absTol = 1e-8 - val sparkAnswer = sql("SELECT VAR_POP(a) FROM testData2") - val expectedAnswer = Row(4.0 / 6.0) - checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) - } - - test("skewness") { - val absTol = 1e-8 - val sparkAnswer = sql("SELECT skewness(a) FROM testData2") - val expectedAnswer = Row(0.0) - checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) - } - - test("kurtosis") { - val absTol = 1e-8 - val sparkAnswer = sql("SELECT kurtosis(a) FROM testData2") - val expectedAnswer = Row(-1.5) - checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) - } - - test("stddev agg") { - checkAnswer( - sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"), - (1 to 3).map(i => Row(i, math.sqrt(1.0 / 2.0), math.sqrt(1.0 / 4.0), math.sqrt(1.0 / 2.0)))) - } - - test("variance agg") { - val absTol = 1e-8 - checkAggregatesWithTol( - sql("SELECT a, variance(b), var_samp(b), var_pop(b) FROM testData2 GROUP BY a"), - (1 to 3).map(i => Row(i, 1.0 / 2.0, 1.0 / 2.0, 1.0 / 4.0)), - absTol) - } - - test("skewness and kurtosis agg") { - val absTol = 1e-8 - val sparkAnswer = sql("SELECT a, skewness(b), kurtosis(b) FROM testData2 GROUP BY a") - val expectedAnswer = (1 to 3).map(i => Row(i, 0.0, -2.0)) - checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) - } - test("inner join where, one match per row") { checkAnswer( sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index e12e6bea30..e2090b0a83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.Decimal class StringFunctionsSuite extends QueryTest with SharedSQLContext { |