author     Andrew Ray <ray.andrew@gmail.com>    2015-11-13 10:31:17 -0800
committer  Yin Huai <yhuai@databricks.com>      2015-11-13 10:31:17 -0800
commit     a24477996e936b0861819ffb420f763f80f0b1da (patch)
tree       e72f8acb7418777e496b37598e398d0186aa246a
parent     99693fef0a30432d94556154b81872356d921c64 (diff)
[SPARK-11690][PYSPARK] Add pivot to python api
This PR adds pivot to the Python API of GroupedData, with the same syntax as Scala/Java.

Author: Andrew Ray <ray.andrew@gmail.com>

Closes #9653 from aray/sql-pivot-python.
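For context, here is a minimal usage sketch of the new method (an illustration, not part of the commit; it assumes a local SparkContext and the same course/year/earnings schema as the doctests below, and uses the varargs form of pivot introduced by this patch):

    from pyspark import SparkContext
    from pyspark.sql import SQLContext, Row

    sc = SparkContext("local", "pivot-example")
    sqlContext = SQLContext(sc)  # needed so that RDD.toDF() works

    df = sc.parallelize([
        Row(course="dotNET", year=2012, earnings=10000),
        Row(course="Java", year=2012, earnings=20000),
        Row(course="dotNET", year=2013, earnings=48000),
        Row(course="Java", year=2013, earnings=30000),
    ]).toDF()

    # Explicit pivot values: one output column per value, no extra pass over the data.
    df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings").show()

    # No values given: per the docstring, distinct() is called on the pivot column first.
    df.groupBy("year").pivot("course").sum("earnings").show()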
-rw-r--r--  python/pyspark/sql/group.py  24
1 file changed, 23 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index 71c0bccc5e..227f40bc3c 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -17,7 +17,7 @@
from pyspark import since
from pyspark.rdd import ignore_unicode_prefix
-from pyspark.sql.column import Column, _to_seq
+from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import *
@@ -167,6 +167,23 @@ class GroupedData(object):
        [Row(sum(age)=7, sum(height)=165)]
        """

+    @since(1.6)
+    def pivot(self, pivot_col, *values):
+        """Pivots a column of the current DataFrame and performs the specified aggregation.
+
+        :param pivot_col: Column to pivot.
+        :param values: Optional list of values of the pivot column that will be translated to
+            columns in the output DataFrame. If values are not provided, the method will do an
+            immediate call to .distinct() on the pivot column.
+
+        >>> df4.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings").collect()
+        [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)]
+
+        >>> df4.groupBy("year").pivot("course").sum("earnings").collect()
+        [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
+        """
+        jgd = self._jdf.pivot(_to_java_column(pivot_col),
+                              _to_seq(self.sql_ctx._sc, values, _create_column_from_literal))
+        return GroupedData(jgd, self.sql_ctx)
+
def _test():
    import doctest
@@ -182,6 +199,11 @@ def _test():
                                         StructField('name', StringType())]))
    globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
                                   Row(name='Bob', age=5, height=85)]).toDF()
+    globs['df4'] = sc.parallelize([Row(course="dotNET", year=2012, earnings=10000),
+                                   Row(course="Java", year=2012, earnings=20000),
+                                   Row(course="dotNET", year=2012, earnings=5000),
+                                   Row(course="dotNET", year=2013, earnings=48000),
+                                   Row(course="Java", year=2013, earnings=30000)]).toDF()
    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.group, globs=globs,
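The new df4 fixture above feeds the pivot doctests. Assuming the standard Spark source layout and a working PySpark environment on PYTHONPATH, the module's doctests can be run directly; a sketch:

    # Run from the Spark source root; exact environment setup may vary.
    python python/pyspark/sql/group.py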