path: root/python/pyspark/sql/dataframe.py
author	Liang-Chi Hsieh <viirya@gmail.com>	2015-02-16 10:06:11 -0800
committer	Reynold Xin <rxin@databricks.com>	2015-02-16 10:06:11 -0800
commit	5c78be7a515fc2fc92cda0517318e7b5d85762f4 (patch)
tree	b3685d0c4946bf4a005944a465753d8e308ca75c /python/pyspark/sql/dataframe.py
parent	a3afa4a1bff88c4d8a5228fcf1e0cfc132541a22 (diff)
[SPARK-5799][SQL] Compute aggregation function on specified numeric columns
Compute aggregation function on specified numeric columns. For example:

    val df = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, "d")).toDataFrame("key", "value1", "value2", "rest")
    df.groupBy("key").min("value2")

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #4592 from viirya/specific_cols_agg and squashes the following commits:

9446896 [Liang-Chi Hsieh] For comments.
314c4cd [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into specific_cols_agg
353fad7 [Liang-Chi Hsieh] For python unit tests.
54ed0c4 [Liang-Chi Hsieh] Address comments.
b079e6b [Liang-Chi Hsieh] Remove duplicate codes.
55100fb [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into specific_cols_agg
880c2ac [Liang-Chi Hsieh] Fix Python style checks.
4c63a01 [Liang-Chi Hsieh] Fix pyspark.
b1a24fc [Liang-Chi Hsieh] Address comments.
2592f29 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into specific_cols_agg
27069c3 [Liang-Chi Hsieh] Combine functions and add varargs annotation.
371a3f7 [Liang-Chi Hsieh] Compute aggregation function on specified numeric columns.
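For reference, a rough PySpark equivalent of the Scala snippet above, written against the varargs API this patch adds. The data and column names mirror the Scala example; the `sc` SparkContext and `toDF()` setup are assumed to match the doctest globals created in _test() below.

    # Hypothetical session mirroring the Scala example in the commit
    # message; assumes a SparkContext `sc` with the SQL glue from _test().
    from pyspark.sql import Row

    df = sc.parallelize([Row(key='a', value1=1, value2=0),
                         Row(key='b', value1=2, value2=4),
                         Row(key='a', value1=2, value2=3)]).toDF()

    # Aggregate only the named numeric column instead of every numeric one.
    df.groupBy('key').min('value2').collect()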
Diffstat (limited to 'python/pyspark/sql/dataframe.py')
-rw-r--r--	python/pyspark/sql/dataframe.py	74
1 file changed, 59 insertions, 15 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 1438fe5285..28a59e73a3 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -664,6 +664,18 @@ def dfapi(f):
    return _api


+def df_varargs_api(f):
+    def _api(self, *args):
+        jargs = ListConverter().convert(args,
+                                        self.sql_ctx._sc._gateway._gateway_client)
+        name = f.__name__
+        jdf = getattr(self._jdf, name)(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jargs))
+        return DataFrame(jdf, self.sql_ctx)
+    _api.__name__ = f.__name__
+    _api.__doc__ = f.__doc__
+    return _api
+
+
class GroupedData(object):

    """
@@ -714,30 +726,60 @@ class GroupedData(object):
        [Row(age=2, count=1), Row(age=5, count=1)]
        """

-    @dfapi
-    def mean(self):
+    @df_varargs_api
+    def mean(self, *cols):
        """Compute the average value for each numeric columns
-        for each group. This is an alias for `avg`."""
+        for each group. This is an alias for `avg`.

-    @dfapi
-    def avg(self):
+        >>> df.groupBy().mean('age').collect()
+        [Row(AVG(age#0)=3.5)]
+        >>> df3.groupBy().mean('age', 'height').collect()
+        [Row(AVG(age#4)=3.5, AVG(height#5)=82.5)]
+        """
+
+    @df_varargs_api
+    def avg(self, *cols):
        """Compute the average value for each numeric columns
-        for each group."""
+        for each group.

-    @dfapi
-    def max(self):
+        >>> df.groupBy().avg('age').collect()
+        [Row(AVG(age#0)=3.5)]
+        >>> df3.groupBy().avg('age', 'height').collect()
+        [Row(AVG(age#4)=3.5, AVG(height#5)=82.5)]
+        """
+
+    @df_varargs_api
+    def max(self, *cols):
        """Compute the max value for each numeric columns for
-        each group. """
+        each group.

-    @dfapi
-    def min(self):
+        >>> df.groupBy().max('age').collect()
+        [Row(MAX(age#0)=5)]
+        >>> df3.groupBy().max('age', 'height').collect()
+        [Row(MAX(age#4)=5, MAX(height#5)=85)]
+        """
+
+    @df_varargs_api
+    def min(self, *cols):
        """Compute the min value for each numeric column for
-        each group."""
+        each group.

-    @dfapi
-    def sum(self):
+        >>> df.groupBy().min('age').collect()
+        [Row(MIN(age#0)=2)]
+        >>> df3.groupBy().min('age', 'height').collect()
+        [Row(MIN(age#4)=2, MIN(height#5)=80)]
+        """
+
+    @df_varargs_api
+    def sum(self, *cols):
        """Compute the sum for each numeric columns for each
-        group."""
+        group.
+
+        >>> df.groupBy().sum('age').collect()
+        [Row(SUM(age#0)=7)]
+        >>> df3.groupBy().sum('age', 'height').collect()
+        [Row(SUM(age#4)=7, SUM(height#5)=165)]
+        """


def _create_column_from_literal(literal):
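Worth noting: GroupedData in this same file already exposes a dict-based agg() (e.g. agg({'age': 'min'})), so the varargs methods above are essentially shorthand for applying one function to named columns. A sketch of the equivalence, using the `df` doctest global; output column naming may differ between the two forms:

    df.groupBy().min('age').collect()           # new varargs form from this patch
    df.groupBy().agg({'age': 'min'}).collect()  # pre-existing dict-based form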
@@ -945,6 +987,8 @@ def _test():
    globs['sqlCtx'] = SQLContext(sc)
    globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
    globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
+    globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
+                                   Row(name='Bob', age=5, height=85)]).toDF()
    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.dataframe, globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)