From b5761d150b66ee0ae5f1be897d9d7a1abb039884 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 10 Feb 2016 20:13:38 -0800 Subject: [SPARK-12706] [SQL] grouping() and grouping_id() Grouping() returns a column is aggregated or not, grouping_id() returns the aggregation levels. grouping()/grouping_id() could be used with window function, but does not work in having/sort clause, will be fixed by another PR. The GROUPING__ID/grouping_id() in Hive is wrong (according to docs), we also did it wrongly, this PR change that to match the behavior in most databases (also the docs of Hive). Author: Davies Liu Closes #10677 from davies/grouping. --- python/pyspark/sql/functions.py | 44 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) (limited to 'python/pyspark/sql/functions.py') diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 0d57085267..680493e0e6 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -288,6 +288,50 @@ def first(col, ignorenulls=False): return Column(jc) +@since(2.0) +def grouping(col): + """ + Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated + or not, returns 1 for aggregated or 0 for not aggregated in the result set. + + >>> df.cube("name").agg(grouping("name"), sum("age")).orderBy("name").show() + +-----+--------------+--------+ + | name|grouping(name)|sum(age)| + +-----+--------------+--------+ + | null| 1| 7| + |Alice| 0| 2| + | Bob| 0| 5| + +-----+--------------+--------+ + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.grouping(_to_java_column(col)) + return Column(jc) + + +@since(2.0) +def grouping_id(*cols): + """ + Aggregate function: returns the level of grouping, equals to + + (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn) + + Note: the list of columns should match with grouping columns exactly, or empty (means all the + grouping columns). + + >>> df.cube("name").agg(grouping_id(), sum("age")).orderBy("name").show() + +-----+------------+--------+ + | name|groupingid()|sum(age)| + +-----+------------+--------+ + | null| 1| 7| + |Alice| 0| 2| + | Bob| 0| 5| + +-----+------------+--------+ + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.grouping_id(_to_seq(sc, cols, _to_java_column)) + return Column(jc) + + @since(1.6) def input_file_name(): """Creates a string column for the file name of the current Spark task. -- cgit v1.2.3