author    Davies Liu <davies@databricks.com>    2015-04-17 11:29:27 -0500
committer Reynold Xin <rxin@databricks.com>     2015-04-17 11:29:27 -0500
commit    c84d91692aa25c01882bcc3f9fd5de3cfa786195 (patch)
tree      7951bc6429ae21eb62de4ed6c6de658b379affde /python/pyspark
parent    dc48ba9f9f7449dd2f12cbad288b65c8119d9284 (diff)
[SPARK-6957] [SPARK-6958] [SQL] improve API compatibility to pandas
```
select(['cola', 'colb'])

groupby(['colA', 'colB'])
groupby([df.colA, df.colB])

df.sort('A', ascending=True)
df.sort(['A', 'B'], ascending=True)
df.sort(['A', 'B'], ascending=[1, 0])
```

cc rxin

Author: Davies Liu <davies@databricks.com>

Closes #5544 from davies/compatibility and squashes the following commits:

4944058 [Davies Liu] add docstrings
adb2816 [Davies Liu] Merge branch 'master' of github.com:apache/spark into compatibility
bcbbcab [Davies Liu] support ascending as list
8dabdf0 [Davies Liu] improve API compatibility to pandas
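For reference, a minimal usage sketch of the pandas-style calls this patch enables; the `sqlCtx` handle and the two sample rows are illustrative (mirroring the doctest data in dataframe.py), not part of the patch:

```
# Sketch of the pandas-compatible calls added by this patch.
# Assumes a running SparkContext/SQLContext; `sqlCtx` and the sample rows
# are illustrative, mirroring the doctest DataFrame in dataframe.py.
df = sqlCtx.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"])

df.select(["name", "age"]).collect()                  # select() now accepts a list
df.groupby(["name", df.age]).count().collect()        # new groupby alias, list of cols
df.sort("age", ascending=False).collect()             # pandas-style ascending kwarg
df.sort(["age", "name"], ascending=[0, 1]).collect()  # per-column ascending flags
```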
Diffstat (limited to 'python/pyspark')
-rw-r--r--   python/pyspark/sql/dataframe.py   96
-rw-r--r--   python/pyspark/sql/functions.py   11
-rw-r--r--   python/pyspark/sql/tests.py         2
3 files changed, 70 insertions, 39 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index b9a3e6cfe7..326d22e72f 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -485,13 +485,17 @@ class DataFrame(object):
return DataFrame(jdf, self.sql_ctx)
@ignore_unicode_prefix
- def sort(self, *cols):
+ def sort(self, *cols, **kwargs):
"""Returns a new :class:`DataFrame` sorted by the specified column(s).
- :param cols: list of :class:`Column` to sort by.
+ :param cols: list of :class:`Column` or column names to sort by.
+ :param ascending: sort by ascending order or not, could be bool, int
+ or list of bool, int (default: True).
>>> df.sort(df.age.desc()).collect()
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
+ >>> df.sort("age", ascending=False).collect()
+ [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
>>> df.orderBy(df.age.desc()).collect()
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
>>> from pyspark.sql.functions import *
@@ -499,16 +503,42 @@ class DataFrame(object):
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
>>> df.orderBy(desc("age"), "name").collect()
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
+ >>> df.orderBy(["age", "name"], ascending=[0, 1]).collect()
+ [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
"""
if not cols:
raise ValueError("should sort by at least one column")
- jcols = ListConverter().convert([_to_java_column(c) for c in cols],
- self._sc._gateway._gateway_client)
- jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols))
+ if len(cols) == 1 and isinstance(cols[0], list):
+ cols = cols[0]
+ jcols = [_to_java_column(c) for c in cols]
+ ascending = kwargs.get('ascending', True)
+ if isinstance(ascending, (bool, int)):
+ if not ascending:
+ jcols = [jc.desc() for jc in jcols]
+ elif isinstance(ascending, list):
+ jcols = [jc if asc else jc.desc()
+ for asc, jc in zip(ascending, jcols)]
+ else:
+ raise TypeError("ascending can only be bool or list, but got %s" % type(ascending))
+
+ jdf = self._jdf.sort(self._jseq(jcols))
return DataFrame(jdf, self.sql_ctx)
orderBy = sort
+ def _jseq(self, cols, converter=None):
+ """Return a JVM Seq of Columns from a list of Column or names"""
+ return _to_seq(self.sql_ctx._sc, cols, converter)
+
+ def _jcols(self, *cols):
+ """Return a JVM Seq of Columns from a list of Column or column names
+
+ If `cols` has only one list in it, cols[0] will be used as the list.
+ """
+ if len(cols) == 1 and isinstance(cols[0], list):
+ cols = cols[0]
+ return self._jseq(cols, _to_java_column)
+
def describe(self, *cols):
"""Computes statistics for numeric columns.
@@ -523,9 +553,7 @@ class DataFrame(object):
min 2
max 5
"""
- cols = ListConverter().convert(cols,
- self.sql_ctx._sc._gateway._gateway_client)
- jdf = self._jdf.describe(self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols))
+ jdf = self._jdf.describe(self._jseq(cols))
return DataFrame(jdf, self.sql_ctx)
@ignore_unicode_prefix
@@ -607,9 +635,7 @@ class DataFrame(object):
>>> df.select(df.name, (df.age + 10).alias('age')).collect()
[Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
"""
- jcols = ListConverter().convert([_to_java_column(c) for c in cols],
- self._sc._gateway._gateway_client)
- jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
+ jdf = self._jdf.select(self._jcols(*cols))
return DataFrame(jdf, self.sql_ctx)
def selectExpr(self, *expr):
@@ -620,8 +646,9 @@ class DataFrame(object):
>>> df.selectExpr("age * 2", "abs(age)").collect()
[Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)]
"""
- jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client)
- jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr))
+ if len(expr) == 1 and isinstance(expr[0], list):
+ expr = expr[0]
+ jdf = self._jdf.selectExpr(self._jseq(expr))
return DataFrame(jdf, self.sql_ctx)
@ignore_unicode_prefix
@@ -659,6 +686,8 @@ class DataFrame(object):
so we can run aggregation on them. See :class:`GroupedData`
for all the available aggregate functions.
+ :func:`groupby` is an alias for :func:`groupBy`.
+
:param cols: list of columns to group by.
Each element should be a column name (string) or an expression (:class:`Column`).
@@ -668,12 +697,14 @@ class DataFrame(object):
[Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
>>> df.groupBy(df.name).avg().collect()
[Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
+ >>> df.groupBy(['name', df.age]).count().collect()
+ [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)]
"""
- jcols = ListConverter().convert([_to_java_column(c) for c in cols],
- self._sc._gateway._gateway_client)
- jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
+ jdf = self._jdf.groupBy(self._jcols(*cols))
return GroupedData(jdf, self.sql_ctx)
+ groupby = groupBy
+
def agg(self, *exprs):
""" Aggregate on the entire :class:`DataFrame` without groups
(shorthand for ``df.groupBy.agg()``).
@@ -744,9 +775,7 @@ class DataFrame(object):
if thresh is None:
thresh = len(subset) if how == 'any' else 1
- cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client)
- cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)
- return DataFrame(self._jdf.na().drop(thresh, cols), self.sql_ctx)
+ return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sql_ctx)
def fillna(self, value, subset=None):
"""Replace null values, alias for ``na.fill()``.
@@ -799,9 +828,7 @@ class DataFrame(object):
elif not isinstance(subset, (list, tuple)):
raise ValueError("subset should be a list or tuple of column names")
- cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client)
- cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)
- return DataFrame(self._jdf.na().fill(value, cols), self.sql_ctx)
+ return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)
@ignore_unicode_prefix
def withColumn(self, colName, col):
@@ -862,10 +889,8 @@ def dfapi(f):
def df_varargs_api(f):
def _api(self, *args):
- jargs = ListConverter().convert(args,
- self.sql_ctx._sc._gateway._gateway_client)
name = f.__name__
- jdf = getattr(self._jdf, name)(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jargs))
+ jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args))
return DataFrame(jdf, self.sql_ctx)
_api.__name__ = f.__name__
_api.__doc__ = f.__doc__
@@ -912,9 +937,8 @@ class GroupedData(object):
else:
# Columns
assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
- jcols = ListConverter().convert([c._jc for c in exprs[1:]],
- self.sql_ctx._sc._gateway._gateway_client)
- jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
+ jdf = self._jdf.agg(exprs[0]._jc,
+ _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
return DataFrame(jdf, self.sql_ctx)
@dfapi
@@ -1006,6 +1030,19 @@ def _to_java_column(col):
return jcol
+def _to_seq(sc, cols, converter=None):
+ """
+ Convert a list of Column (or names) into a JVM Seq of Column.
+
+ An optional `converter` could be used to convert items in `cols`
+ into JVM Column objects.
+ """
+ if converter:
+ cols = [converter(c) for c in cols]
+ jcols = ListConverter().convert(cols, sc._gateway._gateway_client)
+ return sc._jvm.PythonUtils.toSeq(jcols)
+
+
def _unary_op(name, doc="unary operator"):
""" Create a method for given unary operator """
def _(self):
@@ -1177,8 +1214,7 @@ class Column(object):
cols = cols[0]
cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols]
sc = SparkContext._active_spark_context
- jcols = ListConverter().convert(cols, sc._gateway._gateway_client)
- jc = getattr(self._jc, "in")(sc._jvm.PythonUtils.toSeq(jcols))
+ jc = getattr(self._jc, "in")(_to_seq(sc, cols))
return Column(jc)
# order
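The ascending handling added to sort() can be read in isolation. Below is a standalone sketch of the same flag logic, with plain strings standing in for the py4j column handles (the real code calls .desc() on JVM Column objects rather than appending " DESC"):

```
# Standalone sketch of the ascending-flag logic added to DataFrame.sort().
# Plain strings stand in for JVM Column objects; the real code calls
# jc.desc() on py4j handles instead of appending " DESC".
def apply_ordering(cols, ascending=True):
    if isinstance(ascending, (bool, int)):
        return cols if ascending else [c + " DESC" for c in cols]
    elif isinstance(ascending, list):
        return [c if asc else c + " DESC" for asc, c in zip(ascending, cols)]
    raise TypeError("ascending can only be bool or list, but got %s" % type(ascending))

print(apply_ordering(["age", "name"], ascending=[0, 1]))  # ['age DESC', 'name']
print(apply_ordering(["age"], ascending=False))           # ['age DESC']
```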
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 1d65369528..bb47923f24 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -23,13 +23,11 @@ import sys
if sys.version < "3":
from itertools import imap as map
-from py4j.java_collections import ListConverter
-
from pyspark import SparkContext
from pyspark.rdd import _prepare_for_python_RDD
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
from pyspark.sql.types import StringType
-from pyspark.sql.dataframe import Column, _to_java_column
+from pyspark.sql.dataframe import Column, _to_java_column, _to_seq
__all__ = ['countDistinct', 'approxCountDistinct', 'udf']
@@ -87,8 +85,7 @@ def countDistinct(col, *cols):
[Row(c=2)]
"""
sc = SparkContext._active_spark_context
- jcols = ListConverter().convert([_to_java_column(c) for c in cols], sc._gateway._gateway_client)
- jc = sc._jvm.functions.countDistinct(_to_java_column(col), sc._jvm.PythonUtils.toSeq(jcols))
+ jc = sc._jvm.functions.countDistinct(_to_java_column(col), _to_seq(sc, cols, _to_java_column))
return Column(jc)
@@ -138,9 +135,7 @@ class UserDefinedFunction(object):
def __call__(self, *cols):
sc = SparkContext._active_spark_context
- jcols = ListConverter().convert([_to_java_column(c) for c in cols],
- sc._gateway._gateway_client)
- jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols))
+ jc = self._judf.apply(_to_seq(sc, cols, _to_java_column))
return Column(jc)
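A hedged usage sketch of the two call sites refactored above, countDistinct and a Python UDF, both of which now push their column arguments through _to_seq(sc, cols, _to_java_column); the `df` frame and its columns are the same illustrative data as in the earlier sketch:

```
# Usage sketch for the refactored call sites in functions.py.
# Assumes the illustrative `df` frame (age, name) from the first sketch.
from pyspark.sql.functions import countDistinct, udf
from pyspark.sql.types import StringType

df.agg(countDistinct(df.age, df.name)).collect()   # variadic column arguments

label = udf(lambda age: "minor" if age < 18 else "adult", StringType())
df.select(label(df.age).alias("label")).collect()  # UDF __call__ over columns
```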
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 6691e8c8dc..aa3aa1d164 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -282,7 +282,7 @@ class SQLTests(ReusedPySparkTestCase):
StructField("struct1", StructType([StructField("b", ShortType(), False)]), False),
StructField("list1", ArrayType(ByteType(), False), False),
StructField("null1", DoubleType(), True)])
- df = self.sqlCtx.applySchema(rdd, schema)
+ df = self.sqlCtx.createDataFrame(rdd, schema)
results = df.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1,
x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1),
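The one-line test change swaps the deprecated applySchema for createDataFrame. A small hedged sketch of the equivalent call with an explicit schema (simpler columns than the test's; a running `sc`/`sqlCtx` is assumed):

```
# Sketch of createDataFrame with an explicit StructType schema, the call that
# replaces the deprecated applySchema in the test above. Column names and the
# `sc`/`sqlCtx` handles are assumed for illustration.
from pyspark.sql.types import StructType, StructField, IntegerType, StringType

schema = StructType([
    StructField("age", IntegerType(), True),
    StructField("name", StringType(), True),
])
rdd = sc.parallelize([(2, "Alice"), (5, "Bob")])
df = sqlCtx.createDataFrame(rdd, schema)
df.collect()
```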