aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorNong Li <nong@databricks.com>2015-11-06 15:48:20 -0800
committerDavies Liu <davies.liu@gmail.com>2015-11-06 15:48:20 -0800
commit1ab72b08601a1c8a674bdd3fab84d9804899b2c7 (patch)
treef27feef4ad014c54b27fa62b4b902a1c427cd674 /python
parent7e9a9e603abce8689938bdd62d04b29299644aa4 (diff)
downloadspark-1ab72b08601a1c8a674bdd3fab84d9804899b2c7.tar.gz
spark-1ab72b08601a1c8a674bdd3fab84d9804899b2c7.tar.bz2
spark-1ab72b08601a1c8a674bdd3fab84d9804899b2c7.zip
[SPARK-11410] [PYSPARK] Add python bindings for repartition and sortWithinPartitions.
Author: Nong Li <nong@databricks.com> Closes #9504 from nongli/spark-11410.
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/sql/dataframe.py117
1 files changed, 101 insertions, 16 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 765a4511b6..b97c94dad8 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -423,6 +423,67 @@ class DataFrame(object):
return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx)
@since(1.3)
+ def repartition(self, numPartitions, *cols):
+ """
+ Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The
+ resulting DataFrame is hash partitioned.
+
+ ``numPartitions`` can be an int to specify the target number of partitions or a Column.
+ If it is a Column, it will be used as the first partitioning column. If not specified,
+ the default number of partitions is used.
+
+ .. versionchanged:: 1.6
+ Added optional arguments to specify the partitioning columns. Also made numPartitions
+ optional if partitioning columns are specified.
+
+ >>> df.repartition(10).rdd.getNumPartitions()
+ 10
+ >>> data = df.unionAll(df).repartition("age")
+ >>> data.show()
+ +---+-----+
+ |age| name|
+ +---+-----+
+ | 2|Alice|
+ | 2|Alice|
+ | 5| Bob|
+ | 5| Bob|
+ +---+-----+
+ >>> data = data.repartition(7, "age")
+ >>> data.show()
+ +---+-----+
+ |age| name|
+ +---+-----+
+ | 5| Bob|
+ | 5| Bob|
+ | 2|Alice|
+ | 2|Alice|
+ +---+-----+
+ >>> data.rdd.getNumPartitions()
+ 7
+ >>> data = data.repartition("name", "age")
+ >>> data.show()
+ +---+-----+
+ |age| name|
+ +---+-----+
+ | 5| Bob|
+ | 5| Bob|
+ | 2|Alice|
+ | 2|Alice|
+ +---+-----+
+ """
+ if isinstance(numPartitions, int):
+ if len(cols) == 0:
+ return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx)
+ else:
+ return DataFrame(
+ self._jdf.repartition(numPartitions, self._jcols(*cols)), self.sql_ctx)
+ elif isinstance(numPartitions, (basestring, Column)):
+ cols = (numPartitions, ) + cols
+ return DataFrame(self._jdf.repartition(self._jcols(*cols)), self.sql_ctx)
+ else:
+ raise TypeError("numPartitions should be an int or Column")
+
+ @since(1.3)
def distinct(self):
"""Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`.
@@ -589,6 +650,26 @@ class DataFrame(object):
jdf = self._jdf.join(other._jdf, on._jc, how)
return DataFrame(jdf, self.sql_ctx)
+ @since(1.6)
+ def sortWithinPartitions(self, *cols, **kwargs):
+ """Returns a new :class:`DataFrame` with each partition sorted by the specified column(s).
+
+ :param cols: list of :class:`Column` or column names to sort by.
+ :param ascending: boolean or list of boolean (default True).
+ Sort ascending vs. descending. Specify list for multiple sort orders.
+ If a list is specified, length of the list must equal length of the `cols`.
+
+ >>> df.sortWithinPartitions("age", ascending=False).show()
+ +---+-----+
+ |age| name|
+ +---+-----+
+ | 2|Alice|
+ | 5| Bob|
+ +---+-----+
+ """
+ jdf = self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs))
+ return DataFrame(jdf, self.sql_ctx)
+
@ignore_unicode_prefix
@since(1.3)
def sort(self, *cols, **kwargs):
@@ -613,22 +694,7 @@ class DataFrame(object):
>>> df.orderBy(["age", "name"], ascending=[0, 1]).collect()
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
"""
- if not cols:
- raise ValueError("should sort by at least one column")
- if len(cols) == 1 and isinstance(cols[0], list):
- cols = cols[0]
- jcols = [_to_java_column(c) for c in cols]
- ascending = kwargs.get('ascending', True)
- if isinstance(ascending, (bool, int)):
- if not ascending:
- jcols = [jc.desc() for jc in jcols]
- elif isinstance(ascending, list):
- jcols = [jc if asc else jc.desc()
- for asc, jc in zip(ascending, jcols)]
- else:
- raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending))
-
- jdf = self._jdf.sort(self._jseq(jcols))
+ jdf = self._jdf.sort(self._sort_cols(cols, kwargs))
return DataFrame(jdf, self.sql_ctx)
orderBy = sort
@@ -650,6 +716,25 @@ class DataFrame(object):
cols = cols[0]
return self._jseq(cols, _to_java_column)
+ def _sort_cols(self, cols, kwargs):
+ """ Return a JVM Seq of Columns that describes the sort order
+ """
+ if not cols:
+ raise ValueError("should sort by at least one column")
+ if len(cols) == 1 and isinstance(cols[0], list):
+ cols = cols[0]
+ jcols = [_to_java_column(c) for c in cols]
+ ascending = kwargs.get('ascending', True)
+ if isinstance(ascending, (bool, int)):
+ if not ascending:
+ jcols = [jc.desc() for jc in jcols]
+ elif isinstance(ascending, list):
+ jcols = [jc if asc else jc.desc()
+ for asc, jc in zip(ascending, jcols)]
+ else:
+ raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending))
+ return self._jseq(jcols)
+
@since("1.3.1")
def describe(self, *cols):
"""Computes statistics for numeric columns.