Diffstat (limited to 'python/pyspark/sql/dataframe.py')
-rw-r--r--  python/pyspark/sql/dataframe.py  249
1 file changed, 125 insertions(+), 124 deletions(-)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 1550802332..c30326ebd1 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -35,8 +35,7 @@ __all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD", "DataFrameNaFuncti
class DataFrame(object):
-
- """A collection of rows that have the same columns.
+ """A distributed collection of data grouped into named columns.
A :class:`DataFrame` is equivalent to a relational table in Spark SQL,
and can be created using various functions in :class:`SQLContext`::
@@ -69,9 +68,7 @@ class DataFrame(object):
@property
def rdd(self):
- """
- Return the content of the :class:`DataFrame` as an :class:`pyspark.RDD`
- of :class:`Row` s.
+ """Returns the content as an :class:`pyspark.RDD` of :class:`Row`.
"""
if not hasattr(self, '_lazy_rdd'):
jrdd = self._jdf.javaToPython()
@@ -93,7 +90,9 @@ class DataFrame(object):
return DataFrameNaFunctions(self)
def toJSON(self, use_unicode=False):
- """Convert a :class:`DataFrame` into a MappedRDD of JSON documents; one document per row.
+ """Converts a :class:`DataFrame` into a :class:`RDD` of string.
+
+ Each row is turned into a JSON document as one element in the returned RDD.
>>> df.toJSON().first()
'{"age":2,"name":"Alice"}'
@@ -102,10 +101,10 @@ class DataFrame(object):
return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))
def saveAsParquetFile(self, path):
- """Save the contents as a Parquet file, preserving the schema.
+ """Saves the contents as a Parquet file, preserving the schema.
Files that are written out using this method can be read back in as
- a :class:`DataFrame` using the L{SQLContext.parquetFile} method.
+ a :class:`DataFrame` using :func:`SQLContext.parquetFile`.
>>> import tempfile, shutil
>>> parquetFile = tempfile.mkdtemp()
@@ -120,8 +119,8 @@ class DataFrame(object):
def registerTempTable(self, name):
"""Registers this RDD as a temporary table using the given name.
- The lifetime of this temporary table is tied to the L{SQLContext}
- that was used to create this DataFrame.
+ The lifetime of this temporary table is tied to the :class:`SQLContext`
+ that was used to create this :class:`DataFrame`.
>>> df.registerTempTable("people")
>>> df2 = sqlCtx.sql("select * from people")
@@ -131,7 +130,7 @@ class DataFrame(object):
self._jdf.registerTempTable(name)
def registerAsTable(self, name):
- """DEPRECATED: use registerTempTable() instead"""
+ """DEPRECATED: use :func:`registerTempTable` instead"""
warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning)
self.registerTempTable(name)
@@ -162,22 +161,19 @@ class DataFrame(object):
return jmode
def saveAsTable(self, tableName, source=None, mode="error", **options):
- """Saves the contents of the :class:`DataFrame` to a data source as a table.
+ """Saves the contents of this :class:`DataFrame` to a data source as a table.
- The data source is specified by the `source` and a set of `options`.
- If `source` is not specified, the default data source configured by
- spark.sql.sources.default will be used.
+ The data source is specified by the ``source`` and a set of ``options``.
+ If ``source`` is not specified, the default data source configured by
+ ``spark.sql.sources.default`` will be used.
Additionally, mode is used to specify the behavior of the saveAsTable operation when
table already exists in the data source. There are four modes:
- * append: Contents of this :class:`DataFrame` are expected to be appended \
- to existing table.
- * overwrite: Data in the existing table is expected to be overwritten by \
- the contents of this DataFrame.
- * error: An exception is expected to be thrown.
- * ignore: The save operation is expected to not save the contents of the \
- :class:`DataFrame` and to not change the existing table.
+ * `append`: Append contents of this :class:`DataFrame` to existing data.
+ * `overwrite`: Overwrite existing data.
+ * `error`: Throw an exception if data already exists.
+ * `ignore`: Silently ignore this operation if data already exists.
"""
if source is None:
source = self.sql_ctx.getConf("spark.sql.sources.default",
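The four modes listed above behave as follows; a minimal sketch, assuming the ``df`` DataFrame from the doctests and an illustrative table name:

    # "error" (the default) raises if the table already exists.
    df.saveAsTable("people", mode="error")
    # "append" adds these rows to the existing table.
    df.saveAsTable("people", mode="append")
    # "overwrite" replaces the existing contents.
    df.saveAsTable("people", mode="overwrite")
    # "ignore" is a no-op when the table already exists.
    df.saveAsTable("people", mode="ignore")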
@@ -190,18 +186,17 @@ class DataFrame(object):
def save(self, path=None, source=None, mode="error", **options):
"""Saves the contents of the :class:`DataFrame` to a data source.
- The data source is specified by the `source` and a set of `options`.
- If `source` is not specified, the default data source configured by
- spark.sql.sources.default will be used.
+ The data source is specified by the ``source`` and a set of ``options``.
+ If ``source`` is not specified, the default data source configured by
+ ``spark.sql.sources.default`` will be used.
Additionally, mode is used to specify the behavior of the save operation when
data already exists in the data source. There are four modes:
- * append: Contents of this :class:`DataFrame` are expected to be appended to existing data.
- * overwrite: Existing data is expected to be overwritten by the contents of this DataFrame.
- * error: An exception is expected to be thrown.
- * ignore: The save operation is expected to not save the contents of \
- the :class:`DataFrame` and to not change the existing data.
+ * `append`: Append contents of this :class:`DataFrame` to existing data.
+ * `overwrite`: Overwrite existing data.
+ * `error`: Throw an exception if data already exists.
+ * `ignore`: Silently ignore this operation if data already exists.
"""
if path is not None:
options["path"] = path
@@ -215,8 +210,7 @@ class DataFrame(object):
@property
def schema(self):
- """Returns the schema of this :class:`DataFrame` (represented by
- a L{StructType}).
+ """Returns the schema of this :class:`DataFrame` as a :class:`types.StructType`.
>>> df.schema
StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true)))
@@ -237,11 +231,9 @@ class DataFrame(object):
print (self._jdf.schema().treeString())
def explain(self, extended=False):
- """
- Prints the plans (logical and physical) to the console for
- debugging purpose.
+ """Prints the (logical and physical) plans to the console for debugging purpose.
- If extended is False, only prints the physical plan.
+ :param extended: boolean, default ``False``. If ``False``, prints only the physical plan.
>>> df.explain()
PhysicalRDD [age#0,name#1], MapPartitionsRDD[...] at mapPartitions at SQLContext.scala:...
@@ -263,15 +255,13 @@ class DataFrame(object):
print self._jdf.queryExecution().executedPlan().toString()
def isLocal(self):
- """
- Returns True if the `collect` and `take` methods can be run locally
+ """Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally
(without any Spark executors).
"""
return self._jdf.isLocal()
def show(self, n=20):
- """
- Print the first n rows.
+ """Prints the first ``n`` rows to the console.
>>> df
DataFrame[age: int, name: string]
@@ -286,11 +276,7 @@ class DataFrame(object):
return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
def count(self):
- """Return the number of elements in this RDD.
-
- Unlike the base RDD implementation of count, this implementation
- leverages the query optimizer to compute the count on the DataFrame,
- which supports features such as filter pushdown.
+ """Returns the number of rows in this :class:`DataFrame`.
>>> df.count()
2L
@@ -298,10 +284,7 @@ class DataFrame(object):
return self._jdf.count()
def collect(self):
- """Return a list that contains all of the rows.
-
- Each object in the list is a Row, the fields can be accessed as
- attributes.
+ """Returns all the records as a list of :class:`Row`.
>>> df.collect()
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
@@ -313,7 +296,7 @@ class DataFrame(object):
return [cls(r) for r in rs]
def limit(self, num):
- """Limit the result count to the number specified.
+ """Limits the result count to the number specified.
>>> df.limit(1).collect()
[Row(age=2, name=u'Alice')]
@@ -324,10 +307,7 @@ class DataFrame(object):
return DataFrame(jdf, self.sql_ctx)
def take(self, num):
- """Take the first num rows of the RDD.
-
- Each object in the list is a Row, the fields can be accessed as
- attributes.
+ """Returns the first ``num`` rows as a :class:`list` of :class:`Row`.
>>> df.take(2)
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
@@ -335,9 +315,9 @@ class DataFrame(object):
return self.limit(num).collect()
def map(self, f):
- """ Return a new RDD by applying a function to each Row
+ """ Returns a new :class:`RDD` by applying a the ``f`` function to each :class:`Row`.
- It's a shorthand for df.rdd.map()
+ This is a shorthand for ``df.rdd.map()``.
>>> df.map(lambda p: p.name).collect()
[u'Alice', u'Bob']
@@ -345,10 +325,10 @@ class DataFrame(object):
return self.rdd.map(f)
def flatMap(self, f):
- """ Return a new RDD by first applying a function to all elements of this,
+ """ Returns a new :class:`RDD` by first applying the ``f`` function to each :class:`Row`,
and then flattening the results.
- It's a shorthand for df.rdd.flatMap()
+ This is a shorthand for ``df.rdd.flatMap()``.
>>> df.flatMap(lambda p: p.name).collect()
[u'A', u'l', u'i', u'c', u'e', u'B', u'o', u'b']
@@ -356,10 +336,9 @@ class DataFrame(object):
return self.rdd.flatMap(f)
def mapPartitions(self, f, preservesPartitioning=False):
- """
- Return a new RDD by applying a function to each partition.
+ """Returns a new :class:`RDD` by applying the ``f`` function to each partition.
- It's a shorthand for df.rdd.mapPartitions()
+ This is a shorthand for ``df.rdd.mapPartitions()``.
>>> rdd = sc.parallelize([1, 2, 3, 4], 4)
>>> def f(iterator): yield 1
@@ -369,10 +348,9 @@ class DataFrame(object):
return self.rdd.mapPartitions(f, preservesPartitioning)
def foreach(self, f):
- """
- Applies a function to all rows of this DataFrame.
+ """Applies the ``f`` function to all :class:`Row` of this :class:`DataFrame`.
- It's a shorthand for df.rdd.foreach()
+ This is a shorthand for ``df.rdd.foreach()``.
>>> def f(person):
... print person.name
@@ -381,10 +359,9 @@ class DataFrame(object):
return self.rdd.foreach(f)
def foreachPartition(self, f):
- """
- Applies a function to each partition of this DataFrame.
+ """Applies the ``f`` function to each partition of this :class:`DataFrame`.
- It's a shorthand for df.rdd.foreachPartition()
+ This is a shorthand for ``df.rdd.foreachPartition()``.
>>> def f(people):
... for person in people:
@@ -394,14 +371,14 @@ class DataFrame(object):
return self.rdd.foreachPartition(f)
def cache(self):
- """ Persist with the default storage level (C{MEMORY_ONLY_SER}).
+ """ Persists with the default storage level (C{MEMORY_ONLY_SER}).
"""
self.is_cached = True
self._jdf.cache()
return self
def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER):
- """ Set the storage level to persist its values across operations
+ """Sets the storage level to persist its values across operations
after the first time it is computed. This can only be used to assign
a new storage level if the RDD does not have a storage level set yet.
If no storage level is specified defaults to (C{MEMORY_ONLY_SER}).
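For example, a non-default level can be passed explicitly; a minimal sketch:

    from pyspark import StorageLevel
    # Keep rows in memory, spilling to disk when they no longer fit.
    df.persist(StorageLevel.MEMORY_AND_DISK)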
@@ -412,7 +389,7 @@ class DataFrame(object):
return self
def unpersist(self, blocking=True):
- """ Mark it as non-persistent, and remove all blocks for it from
+ """Marks the :class:`DataFrame` as non-persistent, and remove all blocks for it from
memory and disk.
"""
self.is_cached = False
@@ -424,8 +401,7 @@ class DataFrame(object):
# return DataFrame(rdd, self.sql_ctx)
def repartition(self, numPartitions):
- """ Return a new :class:`DataFrame` that has exactly `numPartitions`
- partitions.
+ """Returns a new :class:`DataFrame` that has exactly ``numPartitions`` partitions.
>>> df.repartition(10).rdd.getNumPartitions()
10
@@ -433,8 +409,7 @@ class DataFrame(object):
return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx)
def distinct(self):
- """
- Return a new :class:`DataFrame` containing the distinct rows in this DataFrame.
+ """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`.
>>> df.distinct().count()
2L
@@ -442,8 +417,7 @@ class DataFrame(object):
return DataFrame(self._jdf.distinct(), self.sql_ctx)
def sample(self, withReplacement, fraction, seed=None):
- """
- Return a sampled subset of this DataFrame.
+ """Returns a sampled subset of this :class:`DataFrame`.
>>> df.sample(False, 0.5, 97).count()
1L
@@ -455,7 +429,7 @@ class DataFrame(object):
@property
def dtypes(self):
- """Return all column names and their data types as a list.
+ """Returns all column names and their data types as a list.
>>> df.dtypes
[('age', 'int'), ('name', 'string')]
@@ -464,7 +438,7 @@ class DataFrame(object):
@property
def columns(self):
- """ Return all column names as a list.
+ """Returns all column names as a list.
>>> df.columns
[u'age', u'name']
@@ -472,13 +446,14 @@ class DataFrame(object):
return [f.name for f in self.schema.fields]
def join(self, other, joinExprs=None, joinType=None):
- """
- Join with another :class:`DataFrame`, using the given join expression.
- The following performs a full outer join between `df1` and `df2`.
+ """Joins with another :class:`DataFrame`, using the given join expression.
+
+ The following performs a full outer join between ``df`` and ``df2``.
:param other: Right side of the join
:param joinExprs: Join expression
- :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`.
+ :param joinType: str, default 'inner'.
+ One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`.
>>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect()
[Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)]
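To make the ``joinType`` values concrete, a short sketch reusing the ``df``/``df2`` pair from the doctest:

    # Keep every row of df, padding unmatched df2 columns with nulls.
    df.join(df2, df.name == df2.name, 'left_outer').collect()
    # With joinType omitted, an inner join (the default) is performed.
    df.join(df2, df.name == df2.name).collect()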
@@ -496,9 +471,9 @@ class DataFrame(object):
return DataFrame(jdf, self.sql_ctx)
def sort(self, *cols):
- """ Return a new :class:`DataFrame` sorted by the specified column(s).
+ """Returns a new :class:`DataFrame` sorted by the specified column(s).
- :param cols: The columns or expressions used for sorting
+ :param cols: list of :class:`Column` to sort by.
>>> df.sort(df.age.desc()).collect()
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
@@ -539,7 +514,9 @@ class DataFrame(object):
return DataFrame(jdf, self.sql_ctx)
def head(self, n=None):
- """ Return the first `n` rows or the first row if n is None.
+ """
+ Returns the first ``n`` rows as a list of :class:`Row`,
+ or the first :class:`Row` if ``n`` is ``None``.
>>> df.head()
Row(age=2, name=u'Alice')
@@ -552,7 +529,7 @@ class DataFrame(object):
return self.take(n)
def first(self):
- """ Return the first row.
+ """Returns the first row as a :class:`Row`.
>>> df.first()
Row(age=2, name=u'Alice')
@@ -560,7 +537,7 @@ class DataFrame(object):
return self.head()
def __getitem__(self, item):
- """ Return the column by given name
+ """Returns the column as a :class:`Column`.
>>> df.select(df['age']).collect()
[Row(age=2), Row(age=5)]
@@ -580,7 +557,7 @@ class DataFrame(object):
raise IndexError("unexpected index: %s" % item)
def __getattr__(self, name):
- """ Return the column by given name
+ """Returns the :class:`Column` denoted by ``name``.
>>> df.select(df.age).collect()
[Row(age=2), Row(age=5)]
@@ -591,7 +568,11 @@ class DataFrame(object):
return Column(jc)
def select(self, *cols):
- """ Selecting a set of expressions.
+ """Projects a set of expressions and returns a new :class:`DataFrame`.
+
+ :param cols: list of column names (string) or expressions (:class:`Column`).
+ If one of the column names is '*', that column is expanded to include all columns
+ in the current DataFrame.
>>> df.select('*').collect()
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
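Column names and :class:`Column` expressions can also be mixed in one call; an illustrative sketch:

    # Select a named column alongside a derived, aliased expression.
    df.select('name', (df.age + 10).alias('age_plus_10')).collect()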
@@ -606,9 +587,9 @@ class DataFrame(object):
return DataFrame(jdf, self.sql_ctx)
def selectExpr(self, *expr):
- """
- Selects a set of SQL expressions. This is a variant of
- `select` that accepts SQL expressions.
+ """Projects a set of SQL expressions and returns a new :class:`DataFrame`.
+
+ This is a variant of :func:`select` that accepts SQL expressions.
>>> df.selectExpr("age * 2", "abs(age)").collect()
[Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)]
@@ -618,10 +599,12 @@ class DataFrame(object):
return DataFrame(jdf, self.sql_ctx)
def filter(self, condition):
- """ Filtering rows using the given condition, which could be
- :class:`Column` expression or string of SQL expression.
+ """Filters rows using the given condition.
+
+ :func:`where` is an alias for :func:`filter`.
- where() is an alias for filter().
+ :param condition: a :class:`Column` of :class:`types.BooleanType`
+ or a string of SQL expression.
>>> df.filter(df.age > 3).collect()
[Row(age=5, name=u'Bob')]
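Both accepted forms of ``condition`` are equivalent; a minimal sketch:

    # Column expression form, as in the doctest above.
    df.filter(df.age > 3).collect()
    # The same predicate as a SQL expression string.
    df.filter("age > 3").collect()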
@@ -644,10 +627,13 @@ class DataFrame(object):
where = filter
def groupBy(self, *cols):
- """ Group the :class:`DataFrame` using the specified columns,
+ """Groups the :class:`DataFrame` using the specified columns,
so we can run aggregation on them. See :class:`GroupedData`
for all the available aggregate functions.
+ :param cols: list of columns to group by.
+ Each element should be a column name (string) or an expression (:class:`Column`).
+
>>> df.groupBy().avg().collect()
[Row(AVG(age)=3.5)]
>>> df.groupBy('name').agg({'age': 'mean'}).collect()
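Grouping columns may likewise be given as :class:`Column` expressions rather than names; for example:

    df.groupBy(df.name).count().collect()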
@@ -662,7 +648,7 @@ class DataFrame(object):
def agg(self, *exprs):
""" Aggregate on the entire :class:`DataFrame` without groups
- (shorthand for df.groupBy.agg()).
+ (shorthand for ``df.groupBy().agg()``).
>>> df.agg({"age": "max"}).collect()
[Row(MAX(age)=5)]
@@ -699,7 +685,7 @@ class DataFrame(object):
def dropna(self, how='any', thresh=None, subset=None):
"""Returns a new :class:`DataFrame` omitting rows with null values.
- This is an alias for `na.drop`.
+ This is an alias for ``na.drop()``.
:param how: 'any' or 'all'.
If 'any', drop a row if it contains any nulls.
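A short sketch of the parameters described above (values are illustrative):

    # Drop rows only if every column in them is null.
    df.dropna(how='all')
    # Consider only the 'age' column when deciding what to drop.
    df.dropna(subset=['age'])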
@@ -735,7 +721,7 @@ class DataFrame(object):
return DataFrame(self._jdf.na().drop(thresh, cols), self.sql_ctx)
def fillna(self, value, subset=None):
- """Replace null values, alias for `na.fill`.
+ """Replace null values, alias for ``na.fill()``.
:param value: int, long, float, string, or dict.
Value to replace null values with.
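The two main forms of ``value`` can be sketched as follows (replacement values are illustrative):

    # One value applied to all compatible columns.
    df.fillna(50)
    # A dict gives per-column replacement values.
    df.fillna({'age': 50, 'name': 'unknown'})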
@@ -790,7 +776,10 @@ class DataFrame(object):
return DataFrame(self._jdf.na().fill(value, cols), self.sql_ctx)
def withColumn(self, colName, col):
- """ Return a new :class:`DataFrame` by adding a column.
+ """Returns a new :class:`DataFrame` by adding a column.
+
+ :param colName: string, name of the new column.
+ :param col: a :class:`Column` expression for the new column.
>>> df.withColumn('age2', df.age + 2).collect()
[Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
@@ -798,7 +787,10 @@ class DataFrame(object):
return self.select('*', col.alias(colName))
def withColumnRenamed(self, existing, new):
- """ Rename an existing column to a new name
+ """REturns a new :class:`DataFrame` by renaming an existing column.
+
+ :param existing: string, name of the existing column to rename.
+ :param new: string, new name of the column.
>>> df.withColumnRenamed('age', 'age2').collect()
[Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')]
@@ -809,8 +801,9 @@ class DataFrame(object):
return self.select(*cols)
def toPandas(self):
- """
- Collect all the rows and return a `pandas.DataFrame`.
+ """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``.
+
+ This is only available if Pandas is installed and available.
>>> df.toPandas() # doctest: +SKIP
age name
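Because Pandas is an optional dependency, a guarded call is a reasonable pattern; note that this collects all rows to the driver:

    try:
        pdf = df.toPandas()
    except ImportError:
        pdf = None  # pandas is not installed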
@@ -823,8 +816,7 @@ class DataFrame(object):
# Having SchemaRDD for backward compatibility (for docs)
class SchemaRDD(DataFrame):
- """
- SchemaRDD is deprecated, please use DataFrame
+ """SchemaRDD is deprecated, please use :class:`DataFrame`.
"""
@@ -851,10 +843,9 @@ def df_varargs_api(f):
class GroupedData(object):
-
"""
A set of methods for aggregations on a :class:`DataFrame`,
- created by DataFrame.groupBy().
+ created by :func:`DataFrame.groupBy`.
"""
def __init__(self, jdf, sql_ctx):
@@ -862,14 +853,17 @@ class GroupedData(object):
self.sql_ctx = sql_ctx
def agg(self, *exprs):
- """ Compute aggregates by specifying a map from column name
- to aggregate methods.
+ """Compute aggregates and returns the result as a :class:`DataFrame`.
+
+ The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`.
+
+ If ``exprs`` is a single :class:`dict` mapping from string to string, then the key
+ is the column to perform aggregation on, and the value is the aggregate function.
- The available aggregate methods are `avg`, `max`, `min`,
- `sum`, `count`.
+ Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions.
- :param exprs: list or aggregate columns or a map from column
- name to aggregate methods.
+ :param exprs: a dict mapping from column name (string) to aggregate functions (string),
+ or a list of :class:`Column`.
>>> gdf = df.groupBy(df.name)
>>> gdf.agg({"*": "count"}).collect()
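The list-of-:class:`Column` form mentioned above pairs naturally with the built-in aggregate helpers; a minimal sketch, assuming ``pyspark.sql.functions`` from this release:

    from pyspark.sql import functions as F
    # Aggregate with a Column expression instead of a dict.
    df.groupBy(df.name).agg(F.min(df.age)).collect()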
@@ -894,7 +888,7 @@ class GroupedData(object):
@dfapi
def count(self):
- """ Count the number of rows for each group.
+ """Counts the number of records for each group.
>>> df.groupBy(df.age).count().collect()
[Row(age=2, count=1), Row(age=5, count=1)]
@@ -902,8 +896,11 @@ class GroupedData(object):
@df_varargs_api
def mean(self, *cols):
- """Compute the average value for each numeric columns
- for each group. This is an alias for `avg`.
+ """Computes average values for each numeric columns for each group.
+
+ :func:`mean` is an alias for :func:`avg`.
+
+ :param cols: list of column names (string). Non-numeric columns are ignored.
>>> df.groupBy().mean('age').collect()
[Row(AVG(age)=3.5)]
@@ -913,8 +910,11 @@ class GroupedData(object):
@df_varargs_api
def avg(self, *cols):
- """Compute the average value for each numeric columns
- for each group.
+ """Computes average values for each numeric columns for each group.
+
+ :func:`mean` is an alias for :func:`avg`.
+
+ :param cols: list of column names (string). Non-numeric columns are ignored.
>>> df.groupBy().avg('age').collect()
[Row(AVG(age)=3.5)]
@@ -924,8 +924,7 @@ class GroupedData(object):
@df_varargs_api
def max(self, *cols):
- """Compute the max value for each numeric columns for
- each group.
+ """Computes the max value for each numeric columns for each group.
>>> df.groupBy().max('age').collect()
[Row(MAX(age)=5)]
@@ -935,8 +934,9 @@ class GroupedData(object):
@df_varargs_api
def min(self, *cols):
- """Compute the min value for each numeric column for
- each group.
+ """Computes the min value for each numeric column for each group.
+
+ :param cols: list of column names (string). Non-numeric columns are ignored.
>>> df.groupBy().min('age').collect()
[Row(MIN(age)=2)]
@@ -946,8 +946,9 @@ class GroupedData(object):
@df_varargs_api
def sum(self, *cols):
- """Compute the sum for each numeric columns for each
- group.
+ """Compute the sum for each numeric columns for each group.
+
+ :param cols: list of column names (string). Non-numeric columns are ignored.
>>> df.groupBy().sum('age').collect()
[Row(SUM(age)=7)]