From b5761d150b66ee0ae5f1be897d9d7a1abb039884 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 10 Feb 2016 20:13:38 -0800
Subject: [SPARK-12706] [SQL] grouping() and grouping_id()

grouping() indicates whether a column in the GROUP BY list is aggregated or
not; grouping_id() returns the level of grouping.

grouping()/grouping_id() can be used with window functions, but do not yet
work in HAVING or ORDER BY clauses; that will be fixed in another PR.

Hive's GROUPING__ID/grouping_id() is wrong according to Hive's own docs, and
we implemented it the same wrong way; this PR changes the behavior to match
most databases (and the Hive docs).

Author: Davies Liu

Closes #10677 from davies/grouping.
---
 python/pyspark/sql/dataframe.py | 22 ++++++++++-----------
 python/pyspark/sql/functions.py | 44 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 55 insertions(+), 11 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 3a8c8305ee..3104e41407 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -887,8 +887,8 @@ class DataFrame(object):
         [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)]
         >>> sorted(df.groupBy(df.name).avg().collect())
         [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)]
-        >>> df.groupBy(['name', df.age]).count().collect()
-        [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)]
+        >>> sorted(df.groupBy(['name', df.age]).count().collect())
+        [Row(name=u'Alice', age=2, count=1), Row(name=u'Bob', age=5, count=1)]
         """
         jgd = self._jdf.groupBy(self._jcols(*cols))
         from pyspark.sql.group import GroupedData
@@ -900,15 +900,15 @@ class DataFrame(object):
         Create a multi-dimensional rollup for the current :class:`DataFrame` using
         the specified columns, so we can run aggregation on them.
 
-        >>> df.rollup('name', df.age).count().show()
+        >>> df.rollup("name", df.age).count().orderBy("name", "age").show()
         +-----+----+-----+
         | name| age|count|
         +-----+----+-----+
-        |Alice|   2|    1|
-        |  Bob|   5|    1|
-        |  Bob|null|    1|
         | null|null|    2|
         |Alice|null|    1|
+        |Alice|   2|    1|
+        |  Bob|null|    1|
+        |  Bob|   5|    1|
         +-----+----+-----+
         """
         jgd = self._jdf.rollup(self._jcols(*cols))
@@ -921,17 +921,17 @@ class DataFrame(object):
         Create a multi-dimensional cube for the current :class:`DataFrame` using
         the specified columns, so we can run aggregation on them.
 
-        >>> df.cube('name', df.age).count().show()
+        >>> df.cube("name", df.age).count().orderBy("name", "age").show()
         +-----+----+-----+
         | name| age|count|
         +-----+----+-----+
+        | null|null|    2|
         | null|   2|    1|
-        |Alice|   2|    1|
-        |  Bob|   5|    1|
         | null|   5|    1|
-        |  Bob|null|    1|
-        | null|null|    2|
         |Alice|null|    1|
+        |Alice|   2|    1|
+        |  Bob|null|    1|
+        |  Bob|   5|    1|
         +-----+----+-----+
         """
         jgd = self._jdf.cube(self._jcols(*cols))
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 0d57085267..680493e0e6 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -288,6 +288,50 @@ def first(col, ignorenulls=False):
     return Column(jc)
 
 
+@since(2.0)
+def grouping(col):
+    """
+    Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated
+    or not, returning 1 for aggregated or 0 for not aggregated in the result set.
+
+    >>> df.cube("name").agg(grouping("name"), sum("age")).orderBy("name").show()
+    +-----+--------------+--------+
+    | name|grouping(name)|sum(age)|
+    +-----+--------------+--------+
+    | null|             1|       7|
+    |Alice|             0|       2|
+    |  Bob|             0|       5|
+    +-----+--------------+--------+
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.grouping(_to_java_column(col))
+    return Column(jc)
+
+
+@since(2.0)
+def grouping_id(*cols):
+    """
+    Aggregate function: returns the level of grouping, equal to
+
+        (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn)
+
+    Note: the list of columns should match the grouping columns exactly, or be
+    empty (meaning all the grouping columns).
+
+    >>> df.cube("name").agg(grouping_id(), sum("age")).orderBy("name").show()
+    +-----+------------+--------+
+    | name|groupingid()|sum(age)|
+    +-----+------------+--------+
+    | null|           1|       7|
+    |Alice|           0|       2|
+    |  Bob|           0|       5|
+    +-----+------------+--------+
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.grouping_id(_to_seq(sc, cols, _to_java_column))
+    return Column(jc)
+
+
 @since(1.6)
 def input_file_name():
     """Creates a string column for the file name of the current Spark task.
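
Editor's note: for readers who want to try the new functions, the following
minimal standalone sketch (not part of the patch) drives grouping() and
grouping_id() against the same two-row frame the doctests use. It assumes
Spark >= 2.0, where these functions first appear; the SparkSession setup and
the "grouping-demo" app name are illustrative assumptions.

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.appName("grouping-demo").getOrCreate()
df = spark.createDataFrame([("Alice", 2), ("Bob", 5)], ["name", "age"])

# cube("name", "age") produces all four grouping sets of the two columns.
# grouping_id() packs the per-column grouping bits, so with two columns it
# equals (grouping("name") << 1) + grouping("age"): 0 for the fully grouped
# rows and 3 for the grand-total row.
(df.cube("name", "age")
   .agg(F.grouping("name"), F.grouping("age"), F.grouping_id(), F.sum("age"))
   .orderBy("name", "age")
   .show())

spark.stop()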