author     Davies Liu <davies@databricks.com>    2016-02-10 20:13:38 -0800
committer  Davies Liu <davies.liu@gmail.com>     2016-02-10 20:13:38 -0800
commit     b5761d150b66ee0ae5f1be897d9d7a1abb039884 (patch)
tree       4d2f839c621b844f09d7e5045c23156cec3a12a6 /python/pyspark/sql
parent     0f09f0226983cdc409ef504dff48395787dc844f (diff)
[SPARK-12706] [SQL] grouping() and grouping_id()
grouping() returns whether a column is aggregated or not; grouping_id() returns the
aggregation level. grouping()/grouping_id() can be used with window functions, but do not
yet work in HAVING/sort clauses; that will be fixed by another PR.

Hive's GROUPING__ID/grouping_id() is wrong according to Hive's own docs, and we had
implemented it the same wrong way; this PR changes our behavior to match most databases
(and the Hive docs).

Author: Davies Liu <davies@databricks.com>

Closes #10677 from davies/grouping.
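For orientation, the two new functions compose in a single aggregation. A minimal sketch
in the style of the doctests below (assuming the 1.6-era sqlContext and the same two-row
df those doctests use; the expected rows follow from the doctests in this diff):

>>> from pyspark.sql import Row
>>> from pyspark.sql.functions import grouping, grouping_id, sum
>>> df = sqlContext.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)])
>>> df.cube("name").agg(grouping("name"), grouping_id(), sum("age")).orderBy("name").show()
+-----+--------------+------------+--------+
| name|grouping(name)|groupingid()|sum(age)|
+-----+--------------+------------+--------+
| null|             1|           1|       7|
|Alice|             0|           0|       2|
|  Bob|             0|           0|       5|
+-----+--------------+------------+--------+

In the grand-total row, grouping("name") is 1 because name was aggregated away; with a
single grouping column, grouping_id() carries that same single bit.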
Diffstat (limited to 'python/pyspark/sql')
-rw-r--r--  python/pyspark/sql/dataframe.py | 22
-rw-r--r--  python/pyspark/sql/functions.py | 44
2 files changed, 55 insertions(+), 11 deletions(-)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 3a8c8305ee..3104e41407 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -887,8 +887,8 @@ class DataFrame(object):
[Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)]
>>> sorted(df.groupBy(df.name).avg().collect())
[Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)]
- >>> df.groupBy(['name', df.age]).count().collect()
- [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)]
+ >>> sorted(df.groupBy(['name', df.age]).count().collect())
+ [Row(name=u'Alice', age=2, count=1), Row(name=u'Bob', age=5, count=1)]
"""
jgd = self._jdf.groupBy(self._jcols(*cols))
from pyspark.sql.group import GroupedData
@@ -900,15 +900,15 @@ class DataFrame(object):
Create a multi-dimensional rollup for the current :class:`DataFrame` using
the specified columns, so we can run aggregation on them.
- >>> df.rollup('name', df.age).count().show()
+ >>> df.rollup("name", df.age).count().orderBy("name", "age").show()
+-----+----+-----+
| name| age|count|
+-----+----+-----+
- |Alice| 2| 1|
- | Bob| 5| 1|
- | Bob|null| 1|
| null|null| 2|
|Alice|null| 1|
+ |Alice| 2| 1|
+ | Bob|null| 1|
+ | Bob| 5| 1|
+-----+----+-----+
"""
jgd = self._jdf.rollup(self._jcols(*cols))
@@ -921,17 +921,17 @@ class DataFrame(object):
Create a multi-dimensional cube for the current :class:`DataFrame` using
the specified columns, so we can run aggregation on them.
- >>> df.cube('name', df.age).count().show()
+ >>> df.cube("name", df.age).count().orderBy("name", "age").show()
+-----+----+-----+
| name| age|count|
+-----+----+-----+
+ | null|null| 2|
| null| 2| 1|
- |Alice| 2| 1|
- | Bob| 5| 1|
| null| 5| 1|
- | Bob|null| 1|
- | null|null| 2|
|Alice|null| 1|
+ |Alice| 2| 1|
+ | Bob|null| 1|
+ | Bob| 5| 1|
+-----+----+-----+
"""
jgd = self._jdf.cube(self._jcols(*cols))
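Aside on the doctests above: the cube output differs from the rollup output only by the
two (null, age) rows, because rollup aggregates over the prefixes of its column list
while cube aggregates over every subset. An illustrative plain-Python sketch of that
expansion (names are the doctest's, the helper lists are hypothetical):

from itertools import combinations

cols = ('name', 'age')

# rollup('name', 'age') expands to the prefix grouping sets:
rollup_sets = [cols[:i] for i in range(len(cols), -1, -1)]
# -> [('name', 'age'), ('name',), ()]

# cube('name', 'age') expands to every subset of the columns:
cube_sets = [s for r in range(len(cols), -1, -1) for s in combinations(cols, r)]
# -> [('name', 'age'), ('name',), ('age',), ()]

The extra ('age',) grouping set is what produces the (null, 2, 1) and (null, 5, 1) rows
in the cube doctest.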
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 0d57085267..680493e0e6 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -288,6 +288,50 @@ def first(col, ignorenulls=False):
return Column(jc)
+@since(2.0)
+def grouping(col):
+ """
+ Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated
+ or not; returns 1 for aggregated or 0 for not aggregated in the result set.
+
+ >>> df.cube("name").agg(grouping("name"), sum("age")).orderBy("name").show()
+ +-----+--------------+--------+
+ | name|grouping(name)|sum(age)|
+ +-----+--------------+--------+
+ | null| 1| 7|
+ |Alice| 0| 2|
+ | Bob| 0| 5|
+ +-----+--------------+--------+
+ """
+ sc = SparkContext._active_spark_context
+ jc = sc._jvm.functions.grouping(_to_java_column(col))
+ return Column(jc)
+
+
+@since(2.0)
+def grouping_id(*cols):
+ """
+ Aggregate function: returns the level of grouping, equal to
+
+ (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn)
+
+ Note: the list of columns should match the grouping columns exactly, or be empty
+ (meaning all the grouping columns).
+
+ >>> df.cube("name").agg(grouping_id(), sum("age")).orderBy("name").show()
+ +-----+------------+--------+
+ | name|groupingid()|sum(age)|
+ +-----+------------+--------+
+ | null| 1| 7|
+ |Alice| 0| 2|
+ | Bob| 0| 5|
+ +-----+------------+--------+
+ """
+ sc = SparkContext._active_spark_context
+ jc = sc._jvm.functions.grouping_id(_to_seq(sc, cols, _to_java_column))
+ return Column(jc)
+
+
@since(1.6)
def input_file_name():
"""Creates a string column for the file name of the current Spark task.