aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala16
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala5
2 files changed, 21 insertions, 0 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 1728b0b8c9..fae4bd0fd2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -247,6 +247,22 @@ object functions {
def last(columnName: String): Column = last(Column(columnName))
/**
+ * Aggregate function: returns the average of the values in a group.
+ * Alias for avg.
+ *
+ * @group agg_funcs
+ */
+ def mean(e: Column): Column = avg(e)
+
+ /**
+ * Aggregate function: returns the average of the values in a group.
+ * Alias for avg.
+ *
+ * @group agg_funcs
+ */
+ def mean(columnName: String): Column = avg(columnName)
+
+ /**
* Aggregate function: returns the minimum value of the expression in a group.
*
* @group agg_funcs
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index d2ca8dccae..cf590cbd52 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -308,6 +308,11 @@ class DataFrameSuite extends QueryTest {
testData2.agg(avg('a)),
Row(2.0))
+ // Also check mean
+ checkAnswer(
+ testData2.agg(mean('a)),
+ Row(2.0))
+
checkAnswer(
testData2.agg(avg('a), sumDistinct('a)), // non-partial
Row(2.0, 6.0) :: Nil)