author     Davies Liu <davies@databricks.com>   2015-05-15 20:09:15 -0700
committer  Reynold Xin <rxin@databricks.com>    2015-05-15 20:09:15 -0700
commit     d7b69946cb21cd2781c9ad3e691e54b28efbbf3d (patch)
tree       9c45d39c8d7b51f658a9bd8c7e79aab03385bb79 /python
parent     adfd366814499c0540a15dd6017091ba8c0f05da (diff)
[SPARK-7543] [SQL] [PySpark] split dataframe.py into multiple files
dataframe.py is split into column.py, group.py and dataframe.py:
```
 360 column.py
1223 dataframe.py
 183 group.py
```
Author: Davies Liu <davies@databricks.com>

Closes #6201 from davies/split_df and squashes the following commits:

fc8f5ab [Davies Liu] split dataframe.py into multiple files
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/sql/__init__.py      5
-rw-r--r--  python/pyspark/sql/column.py      360
-rw-r--r--  python/pyspark/sql/dataframe.py   449
-rw-r--r--  python/pyspark/sql/functions.py     2
-rw-r--r--  python/pyspark/sql/group.py       183
-rwxr-xr-x  python/run-tests                    2
6 files changed, 552 insertions, 449 deletions
diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
index 7192c89b3d..19805e291e 100644
--- a/python/pyspark/sql/__init__.py
+++ b/python/pyspark/sql/__init__.py
@@ -55,8 +55,9 @@ del modname, sys
from pyspark.sql.types import Row
from pyspark.sql.context import SQLContext, HiveContext
-from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD, DataFrameNaFunctions
-from pyspark.sql.dataframe import DataFrameStatFunctions
+from pyspark.sql.column import Column
+from pyspark.sql.dataframe import DataFrame, SchemaRDD, DataFrameNaFunctions, DataFrameStatFunctions
+from pyspark.sql.group import GroupedData
__all__ = [
'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row',
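The hunk above only re-wires the package-level re-exports, so user code that imports from `pyspark.sql` keeps working. As a quick compatibility check, here is a minimal sketch (not part of the patch; it assumes a local Spark build that includes this change and creates its own SparkContext) showing that the package re-exports and the new module paths resolve to the same classes:
```python
# Hypothetical compatibility check -- not part of this patch.
# Assumes a local Spark build that includes this change.
from pyspark import SparkContext
from pyspark.sql import SQLContext, DataFrame, Column, GroupedData   # package re-exports
from pyspark.sql.column import Column as ColumnModule                 # new module path
from pyspark.sql.group import GroupedData as GroupModule              # new module path

sc = SparkContext("local[2]", "import-check")
sqlContext = SQLContext(sc)
df = sqlContext.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"])

assert Column is ColumnModule and GroupedData is GroupModule
assert isinstance(df.age, Column)                    # Column now lives in column.py
assert isinstance(df.groupBy("name"), GroupedData)   # GroupedData now lives in group.py
sc.stop()
```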
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
new file mode 100644
index 0000000000..fc7ad674da
--- /dev/null
+++ b/python/pyspark/sql/column.py
@@ -0,0 +1,360 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+
+if sys.version >= '3':
+ basestring = str
+ long = int
+
+from pyspark.context import SparkContext
+from pyspark.rdd import ignore_unicode_prefix
+from pyspark.sql.types import *
+
+__all__ = ["DataFrame", "Column", "SchemaRDD", "DataFrameNaFunctions",
+ "DataFrameStatFunctions"]
+
+
+def _create_column_from_literal(literal):
+ sc = SparkContext._active_spark_context
+ return sc._jvm.functions.lit(literal)
+
+
+def _create_column_from_name(name):
+ sc = SparkContext._active_spark_context
+ return sc._jvm.functions.col(name)
+
+
+def _to_java_column(col):
+ if isinstance(col, Column):
+ jcol = col._jc
+ else:
+ jcol = _create_column_from_name(col)
+ return jcol
+
+
+def _to_seq(sc, cols, converter=None):
+ """
+ Convert a list of Column (or names) into a JVM Seq of Column.
+
+ An optional `converter` could be used to convert items in `cols`
+ into JVM Column objects.
+ """
+ if converter:
+ cols = [converter(c) for c in cols]
+ return sc._jvm.PythonUtils.toSeq(cols)
+
+
+def _unary_op(name, doc="unary operator"):
+ """ Create a method for given unary operator """
+ def _(self):
+ jc = getattr(self._jc, name)()
+ return Column(jc)
+ _.__doc__ = doc
+ return _
+
+
+def _func_op(name, doc=''):
+ def _(self):
+ sc = SparkContext._active_spark_context
+ jc = getattr(sc._jvm.functions, name)(self._jc)
+ return Column(jc)
+ _.__doc__ = doc
+ return _
+
+
+def _bin_op(name, doc="binary operator"):
+ """ Create a method for given binary operator
+ """
+ def _(self, other):
+ jc = other._jc if isinstance(other, Column) else other
+ njc = getattr(self._jc, name)(jc)
+ return Column(njc)
+ _.__doc__ = doc
+ return _
+
+
+def _reverse_op(name, doc="binary operator"):
+ """ Create a method for binary operator (this object is on right side)
+ """
+ def _(self, other):
+ jother = _create_column_from_literal(other)
+ jc = getattr(jother, name)(self._jc)
+ return Column(jc)
+ _.__doc__ = doc
+ return _
+
+
+class Column(object):
+
+ """
+ A column in a DataFrame.
+
+ :class:`Column` instances can be created by::
+
+ # 1. Select a column out of a DataFrame
+
+ df.colName
+ df["colName"]
+
+ # 2. Create from an expression
+ df.colName + 1
+ 1 / df.colName
+ """
+
+ def __init__(self, jc):
+ self._jc = jc
+
+ # arithmetic operators
+ __neg__ = _func_op("negate")
+ __add__ = _bin_op("plus")
+ __sub__ = _bin_op("minus")
+ __mul__ = _bin_op("multiply")
+ __div__ = _bin_op("divide")
+ __truediv__ = _bin_op("divide")
+ __mod__ = _bin_op("mod")
+ __radd__ = _bin_op("plus")
+ __rsub__ = _reverse_op("minus")
+ __rmul__ = _bin_op("multiply")
+ __rdiv__ = _reverse_op("divide")
+ __rtruediv__ = _reverse_op("divide")
+ __rmod__ = _reverse_op("mod")
+
+ # comparison operators
+ __eq__ = _bin_op("equalTo")
+ __ne__ = _bin_op("notEqual")
+ __lt__ = _bin_op("lt")
+ __le__ = _bin_op("leq")
+ __ge__ = _bin_op("geq")
+ __gt__ = _bin_op("gt")
+
+ # `and`, `or`, `not` cannot be overloaded in Python,
+ # so use bitwise operators as boolean operators
+ __and__ = _bin_op('and')
+ __or__ = _bin_op('or')
+ __invert__ = _func_op('not')
+ __rand__ = _bin_op("and")
+ __ror__ = _bin_op("or")
+
+ # container operators
+ __contains__ = _bin_op("contains")
+ __getitem__ = _bin_op("apply")
+
+ # bitwise operators
+ bitwiseOR = _bin_op("bitwiseOR")
+ bitwiseAND = _bin_op("bitwiseAND")
+ bitwiseXOR = _bin_op("bitwiseXOR")
+
+ def getItem(self, key):
+ """An expression that gets an item at position `ordinal` out of a list,
+ or gets an item by key out of a dict.
+
+ >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"])
+ >>> df.select(df.l.getItem(0), df.d.getItem("key")).show()
+ +----+------+
+ |l[0]|d[key]|
+ +----+------+
+ | 1| value|
+ +----+------+
+ >>> df.select(df.l[0], df.d["key"]).show()
+ +----+------+
+ |l[0]|d[key]|
+ +----+------+
+ | 1| value|
+ +----+------+
+ """
+ return self[key]
+
+ def getField(self, name):
+ """An expression that gets a field by name in a StructField.
+
+ >>> from pyspark.sql import Row
+ >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF()
+ >>> df.select(df.r.getField("b")).show()
+ +----+
+ |r[b]|
+ +----+
+ | b|
+ +----+
+ >>> df.select(df.r.a).show()
+ +----+
+ |r[a]|
+ +----+
+ | 1|
+ +----+
+ """
+ return self[name]
+
+ def __getattr__(self, item):
+ if item.startswith("__"):
+ raise AttributeError(item)
+ return self.getField(item)
+
+ # string methods
+ rlike = _bin_op("rlike")
+ like = _bin_op("like")
+ startswith = _bin_op("startsWith")
+ endswith = _bin_op("endsWith")
+
+ @ignore_unicode_prefix
+ def substr(self, startPos, length):
+ """
+ Return a :class:`Column` which is a substring of the column
+
+ :param startPos: start position (int or Column)
+ :param length: length of the substring (int or Column)
+
+ >>> df.select(df.name.substr(1, 3).alias("col")).collect()
+ [Row(col=u'Ali'), Row(col=u'Bob')]
+ """
+ if type(startPos) != type(length):
+ raise TypeError("Can not mix the type")
+ if isinstance(startPos, (int, long)):
+ jc = self._jc.substr(startPos, length)
+ elif isinstance(startPos, Column):
+ jc = self._jc.substr(startPos._jc, length._jc)
+ else:
+ raise TypeError("Unexpected type: %s" % type(startPos))
+ return Column(jc)
+
+ __getslice__ = substr
+
+ @ignore_unicode_prefix
+ def inSet(self, *cols):
+ """ A boolean expression that is evaluated to true if the value of this
+ expression is contained by the evaluated values of the arguments.
+
+ >>> df[df.name.inSet("Bob", "Mike")].collect()
+ [Row(age=5, name=u'Bob')]
+ >>> df[df.age.inSet([1, 2, 3])].collect()
+ [Row(age=2, name=u'Alice')]
+ """
+ if len(cols) == 1 and isinstance(cols[0], (list, set)):
+ cols = cols[0]
+ cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols]
+ sc = SparkContext._active_spark_context
+ jc = getattr(self._jc, "in")(_to_seq(sc, cols))
+ return Column(jc)
+
+ # order
+ asc = _unary_op("asc", "Returns a sort expression based on the"
+ " ascending order of the given column name.")
+ desc = _unary_op("desc", "Returns a sort expression based on the"
+ " descending order of the given column name.")
+
+ isNull = _unary_op("isNull", "True if the current expression is null.")
+ isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
+
+ def alias(self, *alias):
+ """Returns this column aliased with a new name or names (in the case of expressions that
+ return more than one column, such as explode).
+
+ >>> df.select(df.age.alias("age2")).collect()
+ [Row(age2=2), Row(age2=5)]
+ """
+
+ if len(alias) == 1:
+ return Column(getattr(self._jc, "as")(alias[0]))
+ else:
+ sc = SparkContext._active_spark_context
+ return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias))))
+
+ @ignore_unicode_prefix
+ def cast(self, dataType):
+ """ Convert the column into type `dataType`
+
+ >>> df.select(df.age.cast("string").alias('ages')).collect()
+ [Row(ages=u'2'), Row(ages=u'5')]
+ >>> df.select(df.age.cast(StringType()).alias('ages')).collect()
+ [Row(ages=u'2'), Row(ages=u'5')]
+ """
+ if isinstance(dataType, basestring):
+ jc = self._jc.cast(dataType)
+ elif isinstance(dataType, DataType):
+ sc = SparkContext._active_spark_context
+ ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
+ jdt = ssql_ctx.parseDataType(dataType.json())
+ jc = self._jc.cast(jdt)
+ else:
+ raise TypeError("unexpected type: %s" % type(dataType))
+ return Column(jc)
+
+ @ignore_unicode_prefix
+ def between(self, lowerBound, upperBound):
+ """ A boolean expression that is evaluated to true if the value of this
+ expression is between `lowerBound` and `upperBound` (inclusive).
+ """
+ return (self >= lowerBound) & (self <= upperBound)
+
+ @ignore_unicode_prefix
+ def when(self, condition, value):
+ """Evaluates a list of conditions and returns one of multiple possible result expressions.
+ If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
+
+ See :func:`pyspark.sql.functions.when` for example usage.
+
+ :param condition: a boolean :class:`Column` expression.
+ :param value: a literal value, or a :class:`Column` expression.
+
+ """
+ sc = SparkContext._active_spark_context
+ if not isinstance(condition, Column):
+ raise TypeError("condition should be a Column")
+ v = value._jc if isinstance(value, Column) else value
+ jc = sc._jvm.functions.when(condition._jc, v)
+ return Column(jc)
+
+ @ignore_unicode_prefix
+ def otherwise(self, value):
+ """Evaluates a list of conditions and returns one of multiple possible result expressions.
+ If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
+
+ See :func:`pyspark.sql.functions.when` for example usage.
+
+ :param value: a literal value, or a :class:`Column` expression.
+ """
+ v = value._jc if isinstance(value, Column) else value
+ jc = self._jc.otherwise(v)
+ return Column(jc)
+
+ def __repr__(self):
+ return 'Column<%s>' % self._jc.toString().encode('utf8')
+
+
+def _test():
+ import doctest
+ from pyspark.context import SparkContext
+ from pyspark.sql import SQLContext
+ import pyspark.sql.column
+ globs = pyspark.sql.column.__dict__.copy()
+ sc = SparkContext('local[4]', 'PythonTest')
+ globs['sc'] = sc
+ globs['sqlContext'] = SQLContext(sc)
+ globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \
+ .toDF(StructType([StructField('age', IntegerType()),
+ StructField('name', StringType())]))
+
+ (failure_count, test_count) = doctest.testmod(
+ pyspark.sql.column, globs=globs,
+ optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+
+if __name__ == "__main__":
+ _test()
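End of the new column.py. As a usage illustration (not part of the patch), the sketch below exercises a few of the operator overloads generated by `_bin_op`/`_func_op` and the methods defined above; it assumes a local Spark build with this change, and the names `sc`/`df` are created inside the snippet itself:
```python
# Hypothetical usage sketch for the relocated Column API -- not part of this patch.
from pyspark import SparkContext
from pyspark.sql import SQLContext, functions as F

sc = SparkContext("local[2]", "column-demo")
sqlContext = SQLContext(sc)
df = sqlContext.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"])

# __add__, __ge__ and __and__ come from _bin_op; startswith is a string method above.
expr = ((df.age + 1) >= 3) & df.name.startswith("B")
df.filter(expr).select(df.name.substr(1, 3).alias("prefix"),
                       df.age.cast("string").alias("age_str")).show()

# when/otherwise build a conditional Column via pyspark.sql.functions.when.
df.select(F.when(df.age > 3, "old").otherwise("young").alias("bucket")).show()
sc.stop()
```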
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 2ed95ac8e2..96d927b9ba 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -25,17 +25,15 @@ if sys.version >= '3':
else:
from itertools import imap as map
-from pyspark.context import SparkContext
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.sql.types import *
from pyspark.sql.types import _create_cls, _parse_datatype_json_string
+from pyspark.sql.column import Column, _to_seq, _to_java_column
-
-__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD", "DataFrameNaFunctions",
- "DataFrameStatFunctions"]
+__all__ = ["DataFrame", "SchemaRDD", "DataFrameNaFunctions", "DataFrameStatFunctions"]
class DataFrame(object):
@@ -757,6 +755,7 @@ class DataFrame(object):
[Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)]
"""
jdf = self._jdf.groupBy(self._jcols(*cols))
+ from pyspark.sql.group import GroupedData
return GroupedData(jdf, self.sql_ctx)
def agg(self, *exprs):
@@ -1141,169 +1140,6 @@ class SchemaRDD(DataFrame):
"""
-def dfapi(f):
- def _api(self):
- name = f.__name__
- jdf = getattr(self._jdf, name)()
- return DataFrame(jdf, self.sql_ctx)
- _api.__name__ = f.__name__
- _api.__doc__ = f.__doc__
- return _api
-
-
-def df_varargs_api(f):
- def _api(self, *args):
- name = f.__name__
- jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args))
- return DataFrame(jdf, self.sql_ctx)
- _api.__name__ = f.__name__
- _api.__doc__ = f.__doc__
- return _api
-
-
-class GroupedData(object):
- """
- A set of methods for aggregations on a :class:`DataFrame`,
- created by :func:`DataFrame.groupBy`.
- """
-
- def __init__(self, jdf, sql_ctx):
- self._jdf = jdf
- self.sql_ctx = sql_ctx
-
- @ignore_unicode_prefix
- def agg(self, *exprs):
- """Compute aggregates and returns the result as a :class:`DataFrame`.
-
- The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`.
-
- If ``exprs`` is a single :class:`dict` mapping from string to string, then the key
- is the column to perform aggregation on, and the value is the aggregate function.
-
- Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions.
-
- :param exprs: a dict mapping from column name (string) to aggregate functions (string),
- or a list of :class:`Column`.
-
- >>> gdf = df.groupBy(df.name)
- >>> gdf.agg({"*": "count"}).collect()
- [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)]
-
- >>> from pyspark.sql import functions as F
- >>> gdf.agg(F.min(df.age)).collect()
- [Row(name=u'Alice', MIN(age)=2), Row(name=u'Bob', MIN(age)=5)]
- """
- assert exprs, "exprs should not be empty"
- if len(exprs) == 1 and isinstance(exprs[0], dict):
- jdf = self._jdf.agg(exprs[0])
- else:
- # Columns
- assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
- jdf = self._jdf.agg(exprs[0]._jc,
- _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
- return DataFrame(jdf, self.sql_ctx)
-
- @dfapi
- def count(self):
- """Counts the number of records for each group.
-
- >>> df.groupBy(df.age).count().collect()
- [Row(age=2, count=1), Row(age=5, count=1)]
- """
-
- @df_varargs_api
- def mean(self, *cols):
- """Computes average values for each numeric columns for each group.
-
- :func:`mean` is an alias for :func:`avg`.
-
- :param cols: list of column names (string). Non-numeric columns are ignored.
-
- >>> df.groupBy().mean('age').collect()
- [Row(AVG(age)=3.5)]
- >>> df3.groupBy().mean('age', 'height').collect()
- [Row(AVG(age)=3.5, AVG(height)=82.5)]
- """
-
- @df_varargs_api
- def avg(self, *cols):
- """Computes average values for each numeric columns for each group.
-
- :func:`mean` is an alias for :func:`avg`.
-
- :param cols: list of column names (string). Non-numeric columns are ignored.
-
- >>> df.groupBy().avg('age').collect()
- [Row(AVG(age)=3.5)]
- >>> df3.groupBy().avg('age', 'height').collect()
- [Row(AVG(age)=3.5, AVG(height)=82.5)]
- """
-
- @df_varargs_api
- def max(self, *cols):
- """Computes the max value for each numeric columns for each group.
-
- >>> df.groupBy().max('age').collect()
- [Row(MAX(age)=5)]
- >>> df3.groupBy().max('age', 'height').collect()
- [Row(MAX(age)=5, MAX(height)=85)]
- """
-
- @df_varargs_api
- def min(self, *cols):
- """Computes the min value for each numeric column for each group.
-
- :param cols: list of column names (string). Non-numeric columns are ignored.
-
- >>> df.groupBy().min('age').collect()
- [Row(MIN(age)=2)]
- >>> df3.groupBy().min('age', 'height').collect()
- [Row(MIN(age)=2, MIN(height)=80)]
- """
-
- @df_varargs_api
- def sum(self, *cols):
- """Compute the sum for each numeric columns for each group.
-
- :param cols: list of column names (string). Non-numeric columns are ignored.
-
- >>> df.groupBy().sum('age').collect()
- [Row(SUM(age)=7)]
- >>> df3.groupBy().sum('age', 'height').collect()
- [Row(SUM(age)=7, SUM(height)=165)]
- """
-
-
-def _create_column_from_literal(literal):
- sc = SparkContext._active_spark_context
- return sc._jvm.functions.lit(literal)
-
-
-def _create_column_from_name(name):
- sc = SparkContext._active_spark_context
- return sc._jvm.functions.col(name)
-
-
-def _to_java_column(col):
- if isinstance(col, Column):
- jcol = col._jc
- else:
- jcol = _create_column_from_name(col)
- return jcol
-
-
-def _to_seq(sc, cols, converter=None):
- """
- Convert a list of Column (or names) into a JVM Seq of Column.
-
- An optional `converter` could be used to convert items in `cols`
- into JVM Column objects.
- """
- if converter:
- cols = [converter(c) for c in cols]
- return sc._jvm.PythonUtils.toSeq(cols)
-
-
def _to_scala_map(sc, jm):
"""
Convert a dict into a JVM Map.
@@ -1311,282 +1147,6 @@ def _to_scala_map(sc, jm):
return sc._jvm.PythonUtils.toScalaMap(jm)
-def _unary_op(name, doc="unary operator"):
- """ Create a method for given unary operator """
- def _(self):
- jc = getattr(self._jc, name)()
- return Column(jc)
- _.__doc__ = doc
- return _
-
-
-def _func_op(name, doc=''):
- def _(self):
- sc = SparkContext._active_spark_context
- jc = getattr(sc._jvm.functions, name)(self._jc)
- return Column(jc)
- _.__doc__ = doc
- return _
-
-
-def _bin_op(name, doc="binary operator"):
- """ Create a method for given binary operator
- """
- def _(self, other):
- jc = other._jc if isinstance(other, Column) else other
- njc = getattr(self._jc, name)(jc)
- return Column(njc)
- _.__doc__ = doc
- return _
-
-
-def _reverse_op(name, doc="binary operator"):
- """ Create a method for binary operator (this object is on right side)
- """
- def _(self, other):
- jother = _create_column_from_literal(other)
- jc = getattr(jother, name)(self._jc)
- return Column(jc)
- _.__doc__ = doc
- return _
-
-
-class Column(object):
-
- """
- A column in a DataFrame.
-
- :class:`Column` instances can be created by::
-
- # 1. Select a column out of a DataFrame
-
- df.colName
- df["colName"]
-
- # 2. Create from an expression
- df.colName + 1
- 1 / df.colName
- """
-
- def __init__(self, jc):
- self._jc = jc
-
- # arithmetic operators
- __neg__ = _func_op("negate")
- __add__ = _bin_op("plus")
- __sub__ = _bin_op("minus")
- __mul__ = _bin_op("multiply")
- __div__ = _bin_op("divide")
- __truediv__ = _bin_op("divide")
- __mod__ = _bin_op("mod")
- __radd__ = _bin_op("plus")
- __rsub__ = _reverse_op("minus")
- __rmul__ = _bin_op("multiply")
- __rdiv__ = _reverse_op("divide")
- __rtruediv__ = _reverse_op("divide")
- __rmod__ = _reverse_op("mod")
-
- # logistic operators
- __eq__ = _bin_op("equalTo")
- __ne__ = _bin_op("notEqual")
- __lt__ = _bin_op("lt")
- __le__ = _bin_op("leq")
- __ge__ = _bin_op("geq")
- __gt__ = _bin_op("gt")
-
- # `and`, `or`, `not` cannot be overloaded in Python,
- # so use bitwise operators as boolean operators
- __and__ = _bin_op('and')
- __or__ = _bin_op('or')
- __invert__ = _func_op('not')
- __rand__ = _bin_op("and")
- __ror__ = _bin_op("or")
-
- # container operators
- __contains__ = _bin_op("contains")
- __getitem__ = _bin_op("apply")
-
- # bitwise operators
- bitwiseOR = _bin_op("bitwiseOR")
- bitwiseAND = _bin_op("bitwiseAND")
- bitwiseXOR = _bin_op("bitwiseXOR")
-
- def getItem(self, key):
- """An expression that gets an item at position `ordinal` out of a list,
- or gets an item by key out of a dict.
-
- >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"])
- >>> df.select(df.l.getItem(0), df.d.getItem("key")).show()
- +----+------+
- |l[0]|d[key]|
- +----+------+
- | 1| value|
- +----+------+
- >>> df.select(df.l[0], df.d["key"]).show()
- +----+------+
- |l[0]|d[key]|
- +----+------+
- | 1| value|
- +----+------+
- """
- return self[key]
-
- def getField(self, name):
- """An expression that gets a field by name in a StructField.
-
- >>> from pyspark.sql import Row
- >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF()
- >>> df.select(df.r.getField("b")).show()
- +----+
- |r[b]|
- +----+
- | b|
- +----+
- >>> df.select(df.r.a).show()
- +----+
- |r[a]|
- +----+
- | 1|
- +----+
- """
- return self[name]
-
- def __getattr__(self, item):
- if item.startswith("__"):
- raise AttributeError(item)
- return self.getField(item)
-
- # string methods
- rlike = _bin_op("rlike")
- like = _bin_op("like")
- startswith = _bin_op("startsWith")
- endswith = _bin_op("endsWith")
-
- @ignore_unicode_prefix
- def substr(self, startPos, length):
- """
- Return a :class:`Column` which is a substring of the column
-
- :param startPos: start position (int or Column)
- :param length: length of the substring (int or Column)
-
- >>> df.select(df.name.substr(1, 3).alias("col")).collect()
- [Row(col=u'Ali'), Row(col=u'Bob')]
- """
- if type(startPos) != type(length):
- raise TypeError("Can not mix the type")
- if isinstance(startPos, (int, long)):
- jc = self._jc.substr(startPos, length)
- elif isinstance(startPos, Column):
- jc = self._jc.substr(startPos._jc, length._jc)
- else:
- raise TypeError("Unexpected type: %s" % type(startPos))
- return Column(jc)
-
- __getslice__ = substr
-
- @ignore_unicode_prefix
- def inSet(self, *cols):
- """ A boolean expression that is evaluated to true if the value of this
- expression is contained by the evaluated values of the arguments.
-
- >>> df[df.name.inSet("Bob", "Mike")].collect()
- [Row(age=5, name=u'Bob')]
- >>> df[df.age.inSet([1, 2, 3])].collect()
- [Row(age=2, name=u'Alice')]
- """
- if len(cols) == 1 and isinstance(cols[0], (list, set)):
- cols = cols[0]
- cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols]
- sc = SparkContext._active_spark_context
- jc = getattr(self._jc, "in")(_to_seq(sc, cols))
- return Column(jc)
-
- # order
- asc = _unary_op("asc", "Returns a sort expression based on the"
- " ascending order of the given column name.")
- desc = _unary_op("desc", "Returns a sort expression based on the"
- " descending order of the given column name.")
-
- isNull = _unary_op("isNull", "True if the current expression is null.")
- isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
-
- def alias(self, *alias):
- """Returns this column aliased with a new name or names (in the case of expressions that
- return more than one column, such as explode).
-
- >>> df.select(df.age.alias("age2")).collect()
- [Row(age2=2), Row(age2=5)]
- """
-
- if len(alias) == 1:
- return Column(getattr(self._jc, "as")(alias[0]))
- else:
- sc = SparkContext._active_spark_context
- return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias))))
-
- @ignore_unicode_prefix
- def cast(self, dataType):
- """ Convert the column into type `dataType`
-
- >>> df.select(df.age.cast("string").alias('ages')).collect()
- [Row(ages=u'2'), Row(ages=u'5')]
- >>> df.select(df.age.cast(StringType()).alias('ages')).collect()
- [Row(ages=u'2'), Row(ages=u'5')]
- """
- if isinstance(dataType, basestring):
- jc = self._jc.cast(dataType)
- elif isinstance(dataType, DataType):
- sc = SparkContext._active_spark_context
- ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
- jdt = ssql_ctx.parseDataType(dataType.json())
- jc = self._jc.cast(jdt)
- else:
- raise TypeError("unexpected type: %s" % type(dataType))
- return Column(jc)
-
- @ignore_unicode_prefix
- def between(self, lowerBound, upperBound):
- """ A boolean expression that is evaluated to true if the value of this
- expression is between the given columns.
- """
- return (self >= lowerBound) & (self <= upperBound)
-
- @ignore_unicode_prefix
- def when(self, condition, value):
- """Evaluates a list of conditions and returns one of multiple possible result expressions.
- If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
-
- See :func:`pyspark.sql.functions.when` for example usage.
-
- :param condition: a boolean :class:`Column` expression.
- :param value: a literal value, or a :class:`Column` expression.
-
- """
- sc = SparkContext._active_spark_context
- if not isinstance(condition, Column):
- raise TypeError("condition should be a Column")
- v = value._jc if isinstance(value, Column) else value
- jc = sc._jvm.functions.when(condition._jc, v)
- return Column(jc)
-
- @ignore_unicode_prefix
- def otherwise(self, value):
- """Evaluates a list of conditions and returns one of multiple possible result expressions.
- If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
-
- See :func:`pyspark.sql.functions.when` for example usage.
-
- :param value: a literal value, or a :class:`Column` expression.
- """
- v = value._jc if isinstance(value, Column) else value
- jc = self._jc.otherwise(value)
- return Column(jc)
-
- def __repr__(self):
- return 'Column<%s>' % self._jc.toString().encode('utf8')
-
-
class DataFrameNaFunctions(object):
"""Functionality for working with missing data in :class:`DataFrame`.
"""
@@ -1646,9 +1206,6 @@ def _test():
.toDF(StructType([StructField('age', IntegerType()),
StructField('name', StringType())]))
globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
- globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
- Row(name='Bob', age=5, height=85)]).toDF()
-
globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80),
Row(name='Bob', age=5, height=None),
Row(name='Tom', age=None, height=None),
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 6cd6974b0e..8d0e766ecd 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -27,7 +27,7 @@ from pyspark import SparkContext
from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
from pyspark.sql.types import StringType
-from pyspark.sql.dataframe import Column, _to_java_column, _to_seq
+from pyspark.sql.column import Column, _to_java_column, _to_seq
__all__ = [
diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
new file mode 100644
index 0000000000..9f7c743c05
--- /dev/null
+++ b/python/pyspark/sql/group.py
@@ -0,0 +1,183 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.rdd import ignore_unicode_prefix
+from pyspark.sql.column import Column, _to_seq
+from pyspark.sql.dataframe import DataFrame
+from pyspark.sql.types import *
+
+__all__ = ["GroupedData"]
+
+
+def dfapi(f):
+ def _api(self):
+ name = f.__name__
+ jdf = getattr(self._jdf, name)()
+ return DataFrame(jdf, self.sql_ctx)
+ _api.__name__ = f.__name__
+ _api.__doc__ = f.__doc__
+ return _api
+
+
+def df_varargs_api(f):
+ def _api(self, *args):
+ name = f.__name__
+ jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args))
+ return DataFrame(jdf, self.sql_ctx)
+ _api.__name__ = f.__name__
+ _api.__doc__ = f.__doc__
+ return _api
+
+
+class GroupedData(object):
+ """
+ A set of methods for aggregations on a :class:`DataFrame`,
+ created by :func:`DataFrame.groupBy`.
+ """
+
+ def __init__(self, jdf, sql_ctx):
+ self._jdf = jdf
+ self.sql_ctx = sql_ctx
+
+ @ignore_unicode_prefix
+ def agg(self, *exprs):
+ """Compute aggregates and returns the result as a :class:`DataFrame`.
+
+ The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`.
+
+ If ``exprs`` is a single :class:`dict` mapping from string to string, then the key
+ is the column to perform aggregation on, and the value is the aggregate function.
+
+ Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions.
+
+ :param exprs: a dict mapping from column name (string) to aggregate functions (string),
+ or a list of :class:`Column`.
+
+ >>> gdf = df.groupBy(df.name)
+ >>> gdf.agg({"*": "count"}).collect()
+ [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)]
+
+ >>> from pyspark.sql import functions as F
+ >>> gdf.agg(F.min(df.age)).collect()
+ [Row(name=u'Alice', MIN(age)=2), Row(name=u'Bob', MIN(age)=5)]
+ """
+ assert exprs, "exprs should not be empty"
+ if len(exprs) == 1 and isinstance(exprs[0], dict):
+ jdf = self._jdf.agg(exprs[0])
+ else:
+ # Columns
+ assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
+ jdf = self._jdf.agg(exprs[0]._jc,
+ _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
+ return DataFrame(jdf, self.sql_ctx)
+
+ @dfapi
+ def count(self):
+ """Counts the number of records for each group.
+
+ >>> df.groupBy(df.age).count().collect()
+ [Row(age=2, count=1), Row(age=5, count=1)]
+ """
+
+ @df_varargs_api
+ def mean(self, *cols):
+ """Computes average values for each numeric columns for each group.
+
+ :func:`mean` is an alias for :func:`avg`.
+
+ :param cols: list of column names (string). Non-numeric columns are ignored.
+
+ >>> df.groupBy().mean('age').collect()
+ [Row(AVG(age)=3.5)]
+ >>> df3.groupBy().mean('age', 'height').collect()
+ [Row(AVG(age)=3.5, AVG(height)=82.5)]
+ """
+
+ @df_varargs_api
+ def avg(self, *cols):
+ """Computes average values for each numeric columns for each group.
+
+ :func:`mean` is an alias for :func:`avg`.
+
+ :param cols: list of column names (string). Non-numeric columns are ignored.
+
+ >>> df.groupBy().avg('age').collect()
+ [Row(AVG(age)=3.5)]
+ >>> df3.groupBy().avg('age', 'height').collect()
+ [Row(AVG(age)=3.5, AVG(height)=82.5)]
+ """
+
+ @df_varargs_api
+ def max(self, *cols):
+ """Computes the max value for each numeric columns for each group.
+
+ >>> df.groupBy().max('age').collect()
+ [Row(MAX(age)=5)]
+ >>> df3.groupBy().max('age', 'height').collect()
+ [Row(MAX(age)=5, MAX(height)=85)]
+ """
+
+ @df_varargs_api
+ def min(self, *cols):
+ """Computes the min value for each numeric column for each group.
+
+ :param cols: list of column names (string). Non-numeric columns are ignored.
+
+ >>> df.groupBy().min('age').collect()
+ [Row(MIN(age)=2)]
+ >>> df3.groupBy().min('age', 'height').collect()
+ [Row(MIN(age)=2, MIN(height)=80)]
+ """
+
+ @df_varargs_api
+ def sum(self, *cols):
+ """Compute the sum for each numeric columns for each group.
+
+ :param cols: list of column names (string). Non-numeric columns are ignored.
+
+ >>> df.groupBy().sum('age').collect()
+ [Row(SUM(age)=7)]
+ >>> df3.groupBy().sum('age', 'height').collect()
+ [Row(SUM(age)=7, SUM(height)=165)]
+ """
+
+
+def _test():
+ import doctest
+ from pyspark.context import SparkContext
+ from pyspark.sql import Row, SQLContext
+ import pyspark.sql.group
+ globs = pyspark.sql.group.__dict__.copy()
+ sc = SparkContext('local[4]', 'PythonTest')
+ globs['sc'] = sc
+ globs['sqlContext'] = SQLContext(sc)
+ globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \
+ .toDF(StructType([StructField('age', IntegerType()),
+ StructField('name', StringType())]))
+ globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
+ Row(name='Bob', age=5, height=85)]).toDF()
+
+ (failure_count, test_count) = doctest.testmod(
+ pyspark.sql.group, globs=globs,
+ optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+
+if __name__ == "__main__":
+ _test()
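A matching sketch (again hypothetical, not part of the patch) for the GroupedData API that now lives in group.py, assuming the same kind of local setup:
```python
# Hypothetical usage sketch for the relocated GroupedData API -- not part of this patch.
from pyspark import SparkContext
from pyspark.sql import SQLContext, functions as F

sc = SparkContext("local[2]", "group-demo")
sqlContext = SQLContext(sc)
df = sqlContext.createDataFrame(
    [("Alice", 2, 80), ("Bob", 5, 85)], ["name", "age", "height"])

gdf = df.groupBy("name")
gdf.count().show()                                # dfapi-wrapped count()
gdf.agg({"age": "max"}).show()                    # dict form of agg()
gdf.agg(F.min(df.age), F.avg(df.height)).show()   # Column form of agg()
df.groupBy().sum("age", "height").show()          # df_varargs_api-wrapped sum()
sc.stop()
```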
diff --git a/python/run-tests b/python/run-tests
index f2757a3967..ffde2fb24b 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -72,7 +72,9 @@ function run_sql_tests() {
echo "Run sql tests ..."
run_test "pyspark/sql/_types.py"
run_test "pyspark/sql/context.py"
+ run_test "pyspark/sql/column.py"
run_test "pyspark/sql/dataframe.py"
+ run_test "pyspark/sql/group.py"
run_test "pyspark/sql/functions.py"
run_test "pyspark/sql/tests.py"
}