path: root/sql/core
author     sethah <seth.hendrickson16@gmail.com>  2015-10-29 11:58:39 -0700
committer  Xiangrui Meng <meng@databricks.com>    2015-10-29 11:58:39 -0700
commit     a01cbf5daac148f39cd97299780f542abc41d1e9 (patch)
tree       357dfc7f8e7784dc36cbb4f77212e84d0809d1df /sql/core
parent     8185f038c13c72e1bea7b0921b84125b7a352139 (diff)
download   spark-a01cbf5daac148f39cd97299780f542abc41d1e9.tar.gz
           spark-a01cbf5daac148f39cd97299780f542abc41d1e9.tar.bz2
           spark-a01cbf5daac148f39cd97299780f542abc41d1e9.zip
[SPARK-10641][SQL] Add Skewness and Kurtosis Support
Implementing skewness and kurtosis support based on the following algorithm:
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics

Author: sethah <seth.hendrickson16@gmail.com>

Closes #9003 from sethah/SPARK-10641.
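For reference, a minimal one-pass Scala sketch of the higher-order moment updates from that Wikipedia page; `MomentState`, `update`, `skewnessOf`, and `kurtosisOf` are illustrative names, not identifiers from this patch:

```scala
// One-pass update of count, mean and the 2nd-4th central moments, following
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics
case class MomentState(n: Long, mean: Double, m2: Double, m3: Double, m4: Double)

def update(s: MomentState, x: Double): MomentState = {
  val n = s.n + 1
  val delta = x - s.mean
  val deltaN = delta / n
  val deltaN2 = deltaN * deltaN
  val term1 = delta * deltaN * s.n
  MomentState(
    n,
    s.mean + deltaN,
    s.m2 + term1,
    // the m3 and m4 updates deliberately read the *old* m2 and m3 from s
    s.m3 + term1 * deltaN * (n - 2) - 3 * deltaN * s.m2,
    s.m4 + term1 * deltaN2 * (n * n - 3 * n + 3) + 6 * deltaN2 * s.m2 - 4 * deltaN * s.m3)
}

// Skewness and excess kurtosis (the convention the tests below assert) from the moments:
def skewnessOf(s: MomentState): Double = math.sqrt(s.n.toDouble) * s.m3 / math.pow(s.m2, 1.5)
def kurtosisOf(s: MomentState): Double = s.n * s.m4 / (s.m2 * s.m2) - 3.0
```

Folding column a of testData2 (values 1, 1, 2, 2, 3, 3) through `update` yields m2 = 4.0, m3 = 0.0 and m4 = 4.0, giving skewness 0.0 and kurtosis -1.5, the values asserted in the test suites below.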
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala              |  65
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala                | 115
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala  |  73
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala                |  48
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala            |  63
5 files changed, 354 insertions, 10 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 102b802ad0..dc96384a4d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -127,7 +127,12 @@ class GroupedData protected[sql](
case "stddev" => Stddev
case "stddev_pop" => StddevPop
case "stddev_samp" => StddevSamp
+ case "variance" => Variance
+ case "var_pop" => VariancePop
+ case "var_samp" => VarianceSamp
case "sum" => Sum
+ case "skewness" => Skewness
+ case "kurtosis" => Kurtosis
case "count" | "size" =>
// Turn count(*) into count(1)
(inputExpr: Expression) => inputExpr match {
@@ -251,6 +256,30 @@ class GroupedData protected[sql](
}
/**
+ * Compute the skewness for each numeric column for each group.
+ * The resulting [[DataFrame]] will also contain the grouping columns.
+ * When specified columns are given, only compute the skewness values for them.
+ *
+ * @since 1.6.0
+ */
+ @scala.annotation.varargs
+ def skewness(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames : _*)(Skewness)
+ }
+
+ /**
+ * Compute the kurtosis for each numeric column for each group.
+ * The resulting [[DataFrame]] will also contain the grouping columns.
+ * When specified columns are given, only compute the kurtosis values for them.
+ *
+ * @since 1.6.0
+ */
+ @scala.annotation.varargs
+ def kurtosis(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames : _*)(Kurtosis)
+ }
+
+ /**
* Compute the max value for each numeric columns for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
* When specified columns are given, only compute the max values for them.
@@ -333,4 +362,40 @@ class GroupedData protected[sql](
def sum(colNames: String*): DataFrame = {
aggregateNumericColumns(colNames : _*)(Sum)
}
+
+ /**
+ * Compute the population variance for each numeric column for each group.
+ * The resulting [[DataFrame]] will also contain the grouping columns.
+ * When specified columns are given, only compute the variance for them.
+ *
+ * @since 1.6.0
+ */
+ @scala.annotation.varargs
+ def variance(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames : _*)(Variance)
+ }
+
+ /**
+ * Compute the population variance for each numeric column for each group.
+ * The resulting [[DataFrame]] will also contain the grouping columns.
+ * When specified columns are given, only compute the variance for them.
+ *
+ * @since 1.6.0
+ */
+ @scala.annotation.varargs
+ def var_pop(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames : _*)(VariancePop)
+ }
+
+ /**
+ * Compute the sample variance for each numeric column for each group.
+ * The resulting [[DataFrame]] will also contain the grouping columns.
+ * When specified columns are given, only compute the variance for them.
+ *
+ * @since 1.6.0
+ */
+ @scala.annotation.varargs
+ def var_samp(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames : _*)(VarianceSamp)
+ }
}
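As a usage sketch (the SQLContext and DataFrame below are hypothetical, not part of this patch), the new grouped aggregates can be called either directly or through the string-based agg API, thanks to the strToExpr additions above:

```scala
import sqlContext.implicits._  // assumes a SQLContext in scope, as in the Spark 1.6 shell

val df = Seq(("x", 1.0), ("x", 2.0), ("y", 3.0), ("y", 5.0)).toDF("key", "value")

// Variadic, column-name-based methods added in this patch:
df.groupBy("key").skewness("value")
df.groupBy("key").kurtosis("value")
df.groupBy("key").variance("value")

// The extended strToExpr map also resolves the new names in string form:
df.groupBy("key").agg("value" -> "skewness", "value" -> "var_samp")
```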
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 15c864a8ab..c1737b1ef6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -229,6 +229,22 @@ object functions {
def first(columnName: String): Column = first(Column(columnName))
/**
+ * Aggregate function: returns the kurtosis of the values in a group.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def kurtosis(e: Column): Column = Kurtosis(e.expr)
+
+ /**
+ * Aggregate function: returns the kurtosis of the values in a group.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def kurtosis(columnName: String): Column = kurtosis(Column(columnName))
+
+ /**
* Aggregate function: returns the last value in a group.
*
* @group agg_funcs
@@ -295,8 +311,24 @@ object functions {
def min(columnName: String): Column = min(Column(columnName))
/**
- * Aggregate function: returns the unbiased sample standard deviation
- * of the expression in a group.
+ * Aggregate function: returns the skewness of the values in a group.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def skewness(e: Column): Column = Skewness(e.expr)
+
+ /**
+ * Aggregate function: returns the skewness of the values in a group.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def skewness(columnName: String): Column = skewness(Column(columnName))
+
+ /**
+ * Aggregate function: returns the unbiased sample standard deviation of
+ * the expression in a group.
*
* @group agg_funcs
* @since 1.6.0
@@ -304,13 +336,13 @@ object functions {
def stddev(e: Column): Column = Stddev(e.expr)
/**
- * Aggregate function: returns the population standard deviation of
+ * Aggregate function: returns the unbiased sample standard deviation of
* the expression in a group.
*
* @group agg_funcs
* @since 1.6.0
*/
- def stddev_pop(e: Column): Column = StddevPop(e.expr)
+ def stddev(columnName: String): Column = stddev(Column(columnName))
/**
* Aggregate function: returns the unbiased sample standard deviation of
@@ -322,6 +354,33 @@ object functions {
def stddev_samp(e: Column): Column = StddevSamp(e.expr)
/**
+ * Aggregate function: returns the unbiased sample standard deviation of
+ * the expression in a group.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def stddev_samp(columnName: String): Column = stddev_samp(Column(columnName))
+
+ /**
+ * Aggregate function: returns the population standard deviation of
+ * the expression in a group.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def stddev_pop(e: Column): Column = StddevPop(e.expr)
+
+ /**
+ * Aggregate function: returns the population standard deviation of
+ * the expression in a group.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def stddev_pop(columnName: String): Column = stddev_pop(Column(columnName))
+
+ /**
* Aggregate function: returns the sum of all values in the expression.
*
* @group agg_funcs
@@ -353,6 +412,54 @@ object functions {
*/
def sumDistinct(columnName: String): Column = sumDistinct(Column(columnName))
+ /**
+ * Aggregate function: returns the population variance of the values in a group.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def variance(e: Column): Column = Variance(e.expr)
+
+ /**
+ * Aggregate function: returns the population variance of the values in a group.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def variance(columnName: String): Column = variance(Column(columnName))
+
+ /**
+ * Aggregate function: returns the unbiased variance of the values in a group.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def var_samp(e: Column): Column = VarianceSamp(e.expr)
+
+ /**
+ * Aggregate function: returns the unbiased variance of the values in a group.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def var_samp(columnName: String): Column = var_samp(Column(columnName))
+
+ /**
+ * Aggregate function: returns the population variance of the values in a group.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def var_pop(e: Column): Column = VariancePop(e.expr)
+
+ /**
+ * Aggregate function: returns the population variance of the values in a group.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def var_pop(columnName: String): Column = var_pop(Column(columnName))
+
//////////////////////////////////////////////////////////////////////////////////////////////
// Window functions
//////////////////////////////////////////////////////////////////////////////////////////////
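A usage sketch for the new DSL functions, where df is a hypothetical DataFrame; each aggregate comes in a Column-based and a column-name-based overload:

```scala
import org.apache.spark.sql.functions._

df.agg(
  skewness(col("a")), kurtosis("a"),
  variance(col("a")), var_samp("a"), var_pop(col("a")),
  stddev("a"), stddev_samp(col("a")), stddev_pop("a"))
```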
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index f5ef9ffd7f..9b23977c76 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -221,4 +221,77 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
emptyTableData.agg(sumDistinct('a)),
Row(null))
}
+
+ test("moments") {
+ val absTol = 1e-8
+
+ val sparkVariance = testData2.agg(variance('a))
+ val expectedVariance = Row(4.0 / 6.0)
+ checkAggregatesWithTol(sparkVariance, expectedVariance, absTol)
+ val sparkVariancePop = testData2.agg(var_pop('a))
+ checkAggregatesWithTol(sparkVariancePop, expectedVariance, absTol)
+
+ val sparkVarianceSamp = testData2.agg(var_samp('a))
+ val expectedVarianceSamp = Row(4.0 / 5.0)
+ checkAggregatesWithTol(sparkVarianceSamp, expectedVarianceSamp, absTol)
+
+ val sparkSkewness = testData2.agg(skewness('a))
+ val expectedSkewness = Row(0.0)
+ checkAggregatesWithTol(sparkSkewness, expectedSkewness, absTol)
+
+ val sparkKurtosis = testData2.agg(kurtosis('a))
+ val expectedKurtosis = Row(-1.5)
+ checkAggregatesWithTol(sparkKurtosis, expectedKurtosis, absTol)
+ }
+
+ test("zero moments") {
+ val singleRowData = Seq((1, 2)).toDF("a", "b")
+ assert(singleRowData.count() === 1)
+
+ checkAnswer(
+ singleRowData.agg(variance('a)),
+ Row(0.0))
+
+ checkAnswer(
+ singleRowData.agg(var_samp('a)),
+ Row(Double.NaN))
+
+ checkAnswer(
+ singleRowData.agg(var_pop('a)),
+ Row(0.0))
+
+ checkAnswer(
+ singleRowData.agg(skewness('a)),
+ Row(Double.NaN))
+
+ checkAnswer(
+ singleRowData.agg(kurtosis('a)),
+ Row(Double.NaN))
+ }
+
+ test("null moments") {
+ val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
+ assert(emptyTableData.count() === 0)
+
+ checkAnswer(
+ emptyTableData.agg(variance('a)),
+ Row(Double.NaN))
+
+ checkAnswer(
+ emptyTableData.agg(var_samp('a)),
+ Row(Double.NaN))
+
+ checkAnswer(
+ emptyTableData.agg(var_pop('a)),
+ Row(Double.NaN))
+
+ checkAnswer(
+ emptyTableData.agg(skewness('a)),
+ Row(Double.NaN))
+
+ checkAnswer(
+ emptyTableData.agg(kurtosis('a)),
+ Row(Double.NaN))
+ }
}
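The expected values in the zero and null moment tests follow directly from the moment formulas: with a single row, m2 = 0, so the population variance m2 / n is 0.0 while the sample variance m2 / (n - 1) divides by zero; skewness and kurtosis divide by powers of m2 and are NaN in both cases. A minimal check, reusing the illustrative sketch from the top of this page:

```scala
val s = update(MomentState(0, 0.0, 0.0, 0.0, 0.0), 1.0)  // a single observation
assert(s.m2 / s.n == 0.0)              // var_pop of one value is 0.0
assert((s.m2 / (s.n - 1)).isNaN)       // var_samp: 0.0 / 0 is NaN
assert(skewnessOf(s).isNaN)            // 0.0 / 0.0^1.5 is NaN
assert(kurtosisOf(s).isNaN)            // 0.0 / 0.0 is NaN
```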
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 73e02eb0d9..3c174efe73 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -135,6 +135,32 @@ abstract class QueryTest extends PlanTest {
}
/**
+ * Runs the plan and makes sure the answer is within absTol of the expected result.
+ * @param dataFrame the [[DataFrame]] to be executed
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ * @param absTol the absolute tolerance between actual and expected answers.
+ */
+ protected def checkAggregatesWithTol(dataFrame: DataFrame,
+ expectedAnswer: Seq[Row],
+ absTol: Double): Unit = {
+ // TODO: catch exceptions in data frame execution
+ val actualAnswer = dataFrame.collect()
+ require(actualAnswer.length == expectedAnswer.length,
+ s"actual num rows ${actualAnswer.length} != expected num of rows ${expectedAnswer.length}")
+
+ actualAnswer.zip(expectedAnswer).foreach {
+ case (actualRow, expectedRow) =>
+ QueryTest.checkAggregatesWithTol(actualRow, expectedRow, absTol)
+ }
+ }
+
+ protected def checkAggregatesWithTol(dataFrame: DataFrame,
+ expectedAnswer: Row,
+ absTol: Double): Unit = {
+ checkAggregatesWithTol(dataFrame, Seq(expectedAnswer), absTol)
+ }
+
+ /**
* Asserts that a given [[DataFrame]] will be executed using the given number of cached results.
*/
def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = {
@@ -214,6 +240,28 @@ object QueryTest {
return None
}
+ /**
+ * Compares two answer rows, making sure each value is within absTol of the expected result.
+ * @param actualAnswer the actual result in a [[Row]].
+ * @param expectedAnswer the expected result in a [[Row]].
+ * @param absTol the absolute tolerance between actual and expected answers.
+ */
+ protected def checkAggregatesWithTol(actualAnswer: Row, expectedAnswer: Row, absTol: Double) = {
+ require(actualAnswer.length == expectedAnswer.length,
+ s"actual answer length ${actualAnswer.length} != " +
+ s"expected answer length ${expectedAnswer.length}")
+
+ // TODO: support other numeric types besides Double
+ // TODO: support struct types?
+ actualAnswer.toSeq.zip(expectedAnswer.toSeq).foreach {
+ case (actual: Double, expected: Double) =>
+ assert(math.abs(actual - expected) < absTol,
+ s"actual answer $actual not within $absTol of correct answer $expected")
+ case (actual, expected) =>
+ assert(actual == expected, s"$actual did not equal $expected")
+ }
+ }
+
def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): String = {
checkAnswer(df, expectedAnswer.asScala) match {
case Some(errorMessage) => errorMessage
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index f5ae3ae49b..5a616fac0b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -523,8 +523,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("aggregates with nulls") {
checkAnswer(
- sql("SELECT MIN(a), MAX(a), AVG(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"),
- Row(1, 3, 2, 1, 6, 3)
+ sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," +
+ "AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"),
+ Row(0, -1.5, 1, 3, 2, 2.0 / 3.0, 1, 6, 3)
)
}
@@ -717,14 +718,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("stddev") {
checkAnswer(
sql("SELECT STDDEV(a) FROM testData2"),
- Row(math.sqrt(4/5.0))
+ Row(math.sqrt(4.0 / 5.0))
)
}
test("stddev_pop") {
checkAnswer(
sql("SELECT STDDEV_POP(a) FROM testData2"),
- Row(math.sqrt(4/6.0))
+ Row(math.sqrt(4.0 / 6.0))
)
}
@@ -735,10 +736,60 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
)
}
+ test("var_samp") {
+ val absTol = 1e-8
+ val sparkAnswer = sql("SELECT VAR_SAMP(a) FROM testData2")
+ val expectedAnswer = Row(4.0 / 5.0)
+ checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
+ }
+
+ test("variance") {
+ val absTol = 1e-8
+ val sparkAnswer = sql("SELECT VARIANCE(a) FROM testData2")
+ val expectedAnswer = Row(4.0 / 6.0)
+ checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
+ }
+
+ test("var_pop") {
+ val absTol = 1e-8
+ val sparkAnswer = sql("SELECT VAR_POP(a) FROM testData2")
+ val expectedAnswer = Row(4.0 / 6.0)
+ checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
+ }
+
+ test("skewness") {
+ val absTol = 1e-8
+ val sparkAnswer = sql("SELECT skewness(a) FROM testData2")
+ val expectedAnswer = Row(0.0)
+ checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
+ }
+
+ test("kurtosis") {
+ val absTol = 1e-8
+ val sparkAnswer = sql("SELECT kurtosis(a) FROM testData2")
+ val expectedAnswer = Row(-1.5)
+ checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
+ }
+
test("stddev agg") {
checkAnswer(
- sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"),
- (1 to 3).map(i => Row(i, math.sqrt(1/2.0), math.sqrt(1/4.0), math.sqrt(1/2.0))))
+ sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"),
+ (1 to 3).map(i => Row(i, math.sqrt(1.0 / 2.0), math.sqrt(1.0 / 4.0), math.sqrt(1.0 / 2.0))))
+ }
+
+ test("variance agg") {
+ val absTol = 1e-8
+ val sparkAnswer = sql("SELECT a, variance(b), var_samp(b), var_pop(b)" +
+ "FROM testData2 GROUP BY a")
+ val expectedAnswer = (1 to 3).map(i => Row(i, 1.0 / 4.0, 1.0 / 2.0, 1.0 / 4.0))
+ checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
+ }
+
+ test("skewness and kurtosis agg") {
+ val absTol = 1e-8
+ val sparkAnswer = sql("SELECT a, skewness(b), kurtosis(b) FROM testData2 GROUP BY a")
+ val expectedAnswer = (1 to 3).map(i => Row(i, 0.0, -2.0))
+ checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
}
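The per-group expectations in the two tests above can be verified by hand: within each group of testData2, b takes the values {1, 2}, which are symmetric about their mean, so the skewness is 0.0. A small standalone check (plain Scala, no Spark required):

```scala
val xs = Seq(1.0, 2.0)                            // the b values within each group
val mean = xs.sum / xs.size                       // 1.5
val m2 = xs.map(x => math.pow(x - mean, 2)).sum   // 0.5
val m4 = xs.map(x => math.pow(x - mean, 4)).sum   // 0.125
val absTol = 1e-8
assert(math.abs(m2 / (xs.size - 1) - 1.0 / 2.0) < absTol)        // var_samp
assert(math.abs(m2 / xs.size - 1.0 / 4.0) < absTol)              // variance, var_pop
assert(math.abs(xs.size * m4 / (m2 * m2) - 3.0 + 2.0) < absTol)  // excess kurtosis = -2.0
```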
test("inner join where, one match per row") {