aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-11-26 19:00:36 -0800
committerReynold Xin <rxin@databricks.com>2015-11-26 19:00:36 -0800
commit6f6bb0e893c8370cbab4d63a56d74e00cb7f3cf6 (patch)
tree875ec2024c94dc2cd35fb080131f6e4757705340
parent0c1e72e7f79231e537299b57a1ab7cd843171923 (diff)
downloadspark-6f6bb0e893c8370cbab4d63a56d74e00cb7f3cf6.tar.gz
spark-6f6bb0e893c8370cbab4d63a56d74e00cb7f3cf6.tar.bz2
spark-6f6bb0e893c8370cbab4d63a56d74e00cb7f3cf6.zip
[SPARK-12011][SQL] Stddev/Variance etc should support columnName as arguments
Spark SQL aggregate function: ```Java stddev stddev_pop stddev_samp variance var_pop var_samp skewness kurtosis collect_list collect_set ``` should support ```columnName``` as arguments like other aggregate function(max/min/count/sum). Author: Yanbo Liang <ybliang8@gmail.com> Closes #9994 from yanboliang/SPARK-12011.
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala86
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala3
2 files changed, 89 insertions, 0 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 276c5dfc8b..e79defbbbd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -215,6 +215,16 @@ object functions extends LegacyFunctions {
def collect_list(e: Column): Column = callUDF("collect_list", e)
/**
+ * Aggregate function: returns a list of objects with duplicates.
+ *
+ * For now this is an alias for the collect_list Hive UDAF.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def collect_list(columnName: String): Column = collect_list(Column(columnName))
+
+ /**
* Aggregate function: returns a set of objects with duplicate elements eliminated.
*
* For now this is an alias for the collect_set Hive UDAF.
@@ -225,6 +235,16 @@ object functions extends LegacyFunctions {
def collect_set(e: Column): Column = callUDF("collect_set", e)
/**
+ * Aggregate function: returns a set of objects with duplicate elements eliminated.
+ *
+ * For now this is an alias for the collect_set Hive UDAF.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def collect_set(columnName: String): Column = collect_set(Column(columnName))
+
+ /**
* Aggregate function: returns the Pearson Correlation Coefficient for two columns.
*
* @group agg_funcs
@@ -313,6 +333,14 @@ object functions extends LegacyFunctions {
def kurtosis(e: Column): Column = withAggregateFunction { Kurtosis(e.expr) }
/**
+ * Aggregate function: returns the kurtosis of the values in a group.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def kurtosis(columnName: String): Column = kurtosis(Column(columnName))
+
+ /**
* Aggregate function: returns the last value in a group.
*
* @group agg_funcs
@@ -387,6 +415,14 @@ object functions extends LegacyFunctions {
def skewness(e: Column): Column = withAggregateFunction { Skewness(e.expr) }
/**
+ * Aggregate function: returns the skewness of the values in a group.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def skewness(columnName: String): Column = skewness(Column(columnName))
+
+ /**
* Aggregate function: alias for [[stddev_samp]].
*
* @group agg_funcs
@@ -395,6 +431,14 @@ object functions extends LegacyFunctions {
def stddev(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) }
/**
+ * Aggregate function: alias for [[stddev_samp]].
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def stddev(columnName: String): Column = stddev(Column(columnName))
+
+ /**
* Aggregate function: returns the sample standard deviation of
* the expression in a group.
*
@@ -404,6 +448,15 @@ object functions extends LegacyFunctions {
def stddev_samp(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) }
/**
+ * Aggregate function: returns the sample standard deviation of
+ * the expression in a group.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def stddev_samp(columnName: String): Column = stddev_samp(Column(columnName))
+
+ /**
* Aggregate function: returns the population standard deviation of
* the expression in a group.
*
@@ -413,6 +466,15 @@ object functions extends LegacyFunctions {
def stddev_pop(e: Column): Column = withAggregateFunction { StddevPop(e.expr) }
/**
+ * Aggregate function: returns the population standard deviation of
+ * the expression in a group.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def stddev_pop(columnName: String): Column = stddev_pop(Column(columnName))
+
+ /**
* Aggregate function: returns the sum of all values in the expression.
*
* @group agg_funcs
@@ -453,6 +515,14 @@ object functions extends LegacyFunctions {
def variance(e: Column): Column = withAggregateFunction { VarianceSamp(e.expr) }
/**
+ * Aggregate function: alias for [[var_samp]].
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def variance(columnName: String): Column = variance(Column(columnName))
+
+ /**
* Aggregate function: returns the unbiased variance of the values in a group.
*
* @group agg_funcs
@@ -461,6 +531,14 @@ object functions extends LegacyFunctions {
def var_samp(e: Column): Column = withAggregateFunction { VarianceSamp(e.expr) }
/**
+ * Aggregate function: returns the unbiased variance of the values in a group.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def var_samp(columnName: String): Column = var_samp(Column(columnName))
+
+ /**
* Aggregate function: returns the population variance of the values in a group.
*
* @group agg_funcs
@@ -468,6 +546,14 @@ object functions extends LegacyFunctions {
*/
def var_pop(e: Column): Column = withAggregateFunction { VariancePop(e.expr) }
+ /**
+ * Aggregate function: returns the population variance of the values in a group.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def var_pop(columnName: String): Column = var_pop(Column(columnName))
+
//////////////////////////////////////////////////////////////////////////////////////////////
// Window functions
//////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 9c42f65bb6..b5c636d0de 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -261,6 +261,9 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
checkAnswer(
testData2.agg(stddev('a), stddev_pop('a), stddev_samp('a)),
Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev))
+ checkAnswer(
+ testData2.agg(stddev("a"), stddev_pop("a"), stddev_samp("a")),
+ Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev))
}
test("zero stddev") {