aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-06-02 08:37:18 -0700
committerPatrick Wendell <patrick@databricks.com>2015-06-02 08:37:18 -0700
commit445647a1a36e1e24076a9fe506492fac462c66ad (patch)
tree1f466e981d060d00103d4e5f944d0de808942907 /python
parent0f80990bfac1e9969644952d1d8edaf7d26fb436 (diff)
downloadspark-445647a1a36e1e24076a9fe506492fac462c66ad.tar.gz
spark-445647a1a36e1e24076a9fe506492fac462c66ad.tar.bz2
spark-445647a1a36e1e24076a9fe506492fac462c66ad.zip
[SPARK-8021] [SQL] [PYSPARK] make Python read/write API consistent with Scala
add schema()/format()/options() for reader, add mode()/format()/options()/partitionBy() for writer cc rxin yhuai pwendell Author: Davies Liu <davies@databricks.com> Closes #6578 from davies/readwrite and squashes the following commits: 720d293 [Davies Liu] address comments b65dfa2 [Davies Liu] Update readwriter.py 1299ab6 [Davies Liu] make Python API consistent with Scala
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/sql/readwriter.py121
1 files changed, 94 insertions, 27 deletions
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index b6fd413bec..d17d87419f 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -44,6 +44,39 @@ class DataFrameReader(object):
return DataFrame(jdf, self._sqlContext)
@since(1.4)
+ def format(self, source):
+ """
+ Specifies the input data source format.
+ """
+ self._jreader = self._jreader.format(source)
+ return self
+
+ @since(1.4)
+ def schema(self, schema):
+ """
+ Specifies the input schema. Some data sources (e.g. JSON) can
+ infer the input schema automatically from data. By specifying
+ the schema here, the underlying data source can skip the schema
+ inference step, and thus speed up data loading.
+
+ :param schema: a StructType object
+ """
+ if not isinstance(schema, StructType):
+ raise TypeError("schema should be StructType")
+ jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
+ self._jreader = self._jreader.schema(jschema)
+ return self
+
+ @since(1.4)
+ def options(self, **options):
+ """
+ Adds input options for the underlying data source.
+ """
+ for k in options:
+ self._jreader = self._jreader.option(k, options[k])
+ return self
+
+ @since(1.4)
def load(self, path=None, format=None, schema=None, **options):
"""Loads data from a data source and returns it as a :class`DataFrame`.
@@ -52,20 +85,15 @@ class DataFrameReader(object):
:param schema: optional :class:`StructType` for the input schema.
:param options: all other string options
"""
- jreader = self._jreader
if format is not None:
- jreader = jreader.format(format)
+ self.format(format)
if schema is not None:
- if not isinstance(schema, StructType):
- raise TypeError("schema should be StructType")
- jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
- jreader = jreader.schema(jschema)
- for k in options:
- jreader = jreader.option(k, options[k])
+ self.schema(schema)
+ self.options(**options)
if path is not None:
- return self._df(jreader.load(path))
+ return self._df(self._jreader.load(path))
else:
- return self._df(jreader.load())
+ return self._df(self._jreader.load())
@since(1.4)
def json(self, path, schema=None):
@@ -105,12 +133,9 @@ class DataFrameReader(object):
| |-- field5: array (nullable = true)
| | |-- element: integer (containsNull = true)
"""
- if schema is None:
- jdf = self._jreader.json(path)
- else:
- jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
- jdf = self._jreader.schema(jschema).json(path)
- return self._df(jdf)
+ if schema is not None:
+ self.schema(schema)
+ return self._df(self._jreader.json(path))
@since(1.4)
def table(self, tableName):
@@ -195,6 +220,51 @@ class DataFrameWriter(object):
self._jwrite = df._jdf.write()
@since(1.4)
+ def mode(self, saveMode):
+ """
+ Specifies the behavior when data or table already exists. Options include:
+
+ * `append`: Append contents of this :class:`DataFrame` to existing data.
+ * `overwrite`: Overwrite existing data.
+ * `error`: Throw an exception if data already exists.
+ * `ignore`: Silently ignore this operation if data already exists.
+ """
+ self._jwrite = self._jwrite.mode(saveMode)
+ return self
+
+ @since(1.4)
+ def format(self, source):
+ """
+ Specifies the underlying output data source. Built-in options include
+ "parquet", "json", etc.
+ """
+ self._jwrite = self._jwrite.format(source)
+ return self
+
+ @since(1.4)
+ def options(self, **options):
+ """
+ Adds output options for the underlying data source.
+ """
+ for k in options:
+ self._jwrite = self._jwrite.option(k, options[k])
+ return self
+
+ @since(1.4)
+ def partitionBy(self, *cols):
+ """
+ Partitions the output by the given columns on the file system.
+ If specified, the output is laid out on the file system similar
+ to Hive's partitioning scheme.
+
+ :param cols: name of columns
+ """
+ if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
+ cols = cols[0]
+ self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
+ return self
+
+ @since(1.4)
def save(self, path=None, format=None, mode="error", **options):
"""
Saves the contents of the :class:`DataFrame` to a data source.
@@ -216,16 +286,15 @@ class DataFrameWriter(object):
:param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
:param options: all other string options
"""
- jwrite = self._jwrite.mode(mode)
+ self.mode(mode).options(**options)
if format is not None:
- jwrite = jwrite.format(format)
- for k in options:
- jwrite = jwrite.option(k, options[k])
+ self.format(format)
if path is None:
- jwrite.save()
+ self._jwrite.save()
else:
- jwrite.save(path)
+ self._jwrite.save(path)
+ @since(1.4)
def insertInto(self, tableName, overwrite=False):
"""
Inserts the content of the :class:`DataFrame` to the specified table.
@@ -256,12 +325,10 @@ class DataFrameWriter(object):
:param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
:param options: all other string options
"""
- jwrite = self._jwrite.mode(mode)
+ self.mode(mode).options(**options)
if format is not None:
- jwrite = jwrite.format(format)
- for k in options:
- jwrite = jwrite.option(k, options[k])
- return jwrite.saveAsTable(name)
+ self.format(format)
+ return self._jwrite.saveAsTable(name)
@since(1.4)
def json(self, path, mode="error"):