Diffstat (limited to 'python/pyspark/sql/dataframe.py')
-rw-r--r--  python/pyspark/sql/dataframe.py  74
1 file changed, 59 insertions(+), 15 deletions(-)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 1438fe5285..28a59e73a3 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -664,6 +664,18 @@ def dfapi(f):
return _api
+def df_varargs_api(f):
+ def _api(self, *args):
+ jargs = ListConverter().convert(args,
+ self.sql_ctx._sc._gateway._gateway_client)
+ name = f.__name__
+ jdf = getattr(self._jdf, name)(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jargs))
+ return DataFrame(jdf, self.sql_ctx)
+ _api.__name__ = f.__name__
+ _api.__doc__ = f.__doc__
+ return _api
+
+
class GroupedData(object):
"""
@@ -714,30 +726,60 @@ class GroupedData(object):
[Row(age=2, count=1), Row(age=5, count=1)]
"""
- @dfapi
- def mean(self):
+ @df_varargs_api
+ def mean(self, *cols):
"""Compute the average value for each numeric columns
- for each group. This is an alias for `avg`."""
+ for each group. This is an alias for `avg`.
- @dfapi
- def avg(self):
+ >>> df.groupBy().mean('age').collect()
+ [Row(AVG(age#0)=3.5)]
+ >>> df3.groupBy().mean('age', 'height').collect()
+ [Row(AVG(age#4)=3.5, AVG(height#5)=82.5)]
+ """
+
+ @df_varargs_api
+ def avg(self, *cols):
"""Compute the average value for each numeric columns
- for each group."""
+ for each group.
- @dfapi
- def max(self):
+ >>> df.groupBy().avg('age').collect()
+ [Row(AVG(age#0)=3.5)]
+ >>> df3.groupBy().avg('age', 'height').collect()
+ [Row(AVG(age#4)=3.5, AVG(height#5)=82.5)]
+ """
+
+ @df_varargs_api
+ def max(self, *cols):
"""Compute the max value for each numeric columns for
- each group. """
+ each group.
- @dfapi
- def min(self):
+ >>> df.groupBy().max('age').collect()
+ [Row(MAX(age#0)=5)]
+ >>> df3.groupBy().max('age', 'height').collect()
+ [Row(MAX(age#4)=5, MAX(height#5)=85)]
+ """
+
+ @df_varargs_api
+ def min(self, *cols):
"""Compute the min value for each numeric column for
- each group."""
+ each group.
- @dfapi
- def sum(self):
+ >>> df.groupBy().min('age').collect()
+ [Row(MIN(age#0)=2)]
+ >>> df3.groupBy().min('age', 'height').collect()
+ [Row(MIN(age#4)=2, MIN(height#5)=80)]
+ """
+
+ @df_varargs_api
+ def sum(self, *cols):
"""Compute the sum for each numeric columns for each
- group."""
+ group.
+
+ >>> df.groupBy().sum('age').collect()
+ [Row(SUM(age#0)=7)]
+ >>> df3.groupBy().sum('age', 'height').collect()
+ [Row(SUM(age#4)=7, SUM(height#5)=165)]
+ """
def _create_column_from_literal(literal):
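A usage sketch of the widened aggregates, mirroring the doctests above; it assumes an existing SparkContext `sc` with an SQLContext attached, as _test() below constructs, and presumably the old no-argument form (aggregate every numeric column) still works since the varargs may be empty:

# Sketch only; assumes `sc` and an SQLContext as set up in _test() below.
from pyspark.sql import Row
df3 = sc.parallelize([Row(name='Alice', age=2, height=80),
                      Row(name='Bob', age=5, height=85)]).toDF()
df3.groupBy().mean('age', 'height').collect()  # per-column averages, one pass
df3.groupBy().mean('age').collect()            # any subset of numeric columns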
@@ -945,6 +987,8 @@ def _test():
globs['sqlCtx'] = SQLContext(sc)
globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
+ globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
+ Row(name='Bob', age=5, height=85)]).toDF()
(failure_count, test_count) = doctest.testmod(
pyspark.sql.dataframe, globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
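The df3 fixture is added so the new multi-column doctests have two numeric columns to aggregate. For reference, a self-contained sketch of the same globs-injection doctest pattern that _test() follows, with illustrative names rather than Spark's:

import doctest

def _test():
    # Inject fixtures by name so doctests can reference them directly,
    # just as dataframe.py injects sc, sqlCtx, df, df2 and df3.
    globs = {'answer': 42}
    (failure_count, test_count) = doctest.testmod(
        globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
    if failure_count:
        exit(-1)

if __name__ == "__main__":
    _test()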