From 445647a1a36e1e24076a9fe506492fac462c66ad Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Tue, 2 Jun 2015 08:37:18 -0700
Subject: [SPARK-8021] [SQL] [PYSPARK] make Python read/write API consistent
 with Scala

add schema()/format()/options() for reader, add mode()/format()/options()/partitionBy() for writer

cc rxin yhuai pwendell

Author: Davies Liu

Closes #6578 from davies/readwrite and squashes the following commits:

720d293 [Davies Liu] address comments
b65dfa2 [Davies Liu] Update readwriter.py
1299ab6 [Davies Liu] make Python API consistent with Scala
---
 python/pyspark/sql/readwriter.py | 121 ++++++++++++++++++++++++++++++---------
 1 file changed, 94 insertions(+), 27 deletions(-)

(limited to 'python')

diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index b6fd413bec..d17d87419f 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -43,6 +43,39 @@ class DataFrameReader(object):
         from pyspark.sql.dataframe import DataFrame
         return DataFrame(jdf, self._sqlContext)
 
+    @since(1.4)
+    def format(self, source):
+        """
+        Specifies the input data source format.
+        """
+        self._jreader = self._jreader.format(source)
+        return self
+
+    @since(1.4)
+    def schema(self, schema):
+        """
+        Specifies the input schema. Some data sources (e.g. JSON) can
+        infer the input schema automatically from data. By specifying
+        the schema here, the underlying data source can skip the schema
+        inference step, and thus speed up data loading.
+
+        :param schema: a StructType object
+        """
+        if not isinstance(schema, StructType):
+            raise TypeError("schema should be StructType")
+        jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
+        self._jreader = self._jreader.schema(jschema)
+        return self
+
+    @since(1.4)
+    def options(self, **options):
+        """
+        Adds input options for the underlying data source.
+        """
+        for k in options:
+            self._jreader = self._jreader.option(k, options[k])
+        return self
+
     @since(1.4)
     def load(self, path=None, format=None, schema=None, **options):
         """Loads data from a data source and returns it as a :class`DataFrame`.
@@ -52,20 +85,15 @@ class DataFrameReader(object):
         :param schema: optional :class:`StructType` for the input schema.
         :param options: all other string options
         """
-        jreader = self._jreader
         if format is not None:
-            jreader = jreader.format(format)
+            self.format(format)
         if schema is not None:
-            if not isinstance(schema, StructType):
-                raise TypeError("schema should be StructType")
-            jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
-            jreader = jreader.schema(jschema)
-        for k in options:
-            jreader = jreader.option(k, options[k])
+            self.schema(schema)
+        self.options(**options)
         if path is not None:
-            return self._df(jreader.load(path))
+            return self._df(self._jreader.load(path))
         else:
-            return self._df(jreader.load())
+            return self._df(self._jreader.load())
 
     @since(1.4)
     def json(self, path, schema=None):
@@ -105,12 +133,9 @@ class DataFrameReader(object):
          |    |-- field5: array (nullable = true)
          |    |    |-- element: integer (containsNull = true)
         """
-        if schema is None:
-            jdf = self._jreader.json(path)
-        else:
-            jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
-            jdf = self._jreader.schema(jschema).json(path)
-        return self._df(jdf)
+        if schema is not None:
+            self.schema(schema)
+        return self._df(self._jreader.json(path))
 
     @since(1.4)
     def table(self, tableName):
@@ -194,6 +219,51 @@ class DataFrameWriter(object):
         self._sqlContext = df.sql_ctx
         self._jwrite = df._jdf.write()
 
+    @since(1.4)
+    def mode(self, saveMode):
+        """
+        Specifies the behavior when data or table already exists. Options include:
+
+        * `append`: Append contents of this :class:`DataFrame` to existing data.
+        * `overwrite`: Overwrite existing data.
+        * `error`: Throw an exception if data already exists.
+        * `ignore`: Silently ignore this operation if data already exists.
+        """
+        self._jwrite = self._jwrite.mode(saveMode)
+        return self
+
+    @since(1.4)
+    def format(self, source):
+        """
+        Specifies the underlying output data source. Built-in options include
+        "parquet", "json", etc.
+        """
+        self._jwrite = self._jwrite.format(source)
+        return self
+
+    @since(1.4)
+    def options(self, **options):
+        """
+        Adds output options for the underlying data source.
+        """
+        for k in options:
+            self._jwrite = self._jwrite.option(k, options[k])
+        return self
+
+    @since(1.4)
+    def partitionBy(self, *cols):
+        """
+        Partitions the output by the given columns on the file system.
+        If specified, the output is laid out on the file system similar
+        to Hive's partitioning scheme.
+
+        :param cols: name of columns
+        """
+        if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
+            cols = cols[0]
+        self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
+        return self
+
     @since(1.4)
     def save(self, path=None, format=None, mode="error", **options):
         """
@@ -216,16 +286,15 @@ class DataFrameWriter(object):
         :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
         :param options: all other string options
         """
-        jwrite = self._jwrite.mode(mode)
+        self.mode(mode).options(**options)
         if format is not None:
-            jwrite = jwrite.format(format)
-        for k in options:
-            jwrite = jwrite.option(k, options[k])
+            self.format(format)
         if path is None:
-            jwrite.save()
+            self._jwrite.save()
         else:
-            jwrite.save(path)
+            self._jwrite.save(path)
 
+    @since(1.4)
     def insertInto(self, tableName, overwrite=False):
         """
         Inserts the content of the :class:`DataFrame` to the specified table.
@@ -256,12 +325,10 @@ class DataFrameWriter(object):
         :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
         :param options: all other string options
         """
-        jwrite = self._jwrite.mode(mode)
+        self.mode(mode).options(**options)
         if format is not None:
-            jwrite = jwrite.format(format)
-        for k in options:
-            jwrite = jwrite.option(k, options[k])
-        return jwrite.saveAsTable(name)
+            self.format(format)
+        return self._jwrite.saveAsTable(name)
 
     @since(1.4)
     def json(self, path, mode="error"):
--
cgit v1.2.3
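
For reference, a minimal usage sketch of the fluent builder API this patch aligns with Scala, assuming a Spark 1.4 SQLContext; the file paths, option value, and example schema below are illustrative placeholders, not part of the commit:

    # Usage sketch (illustrative only; paths, schema, and options are placeholders).
    from pyspark import SparkContext
    from pyspark.sql import SQLContext
    from pyspark.sql.types import StructType, StructField, StringType, IntegerType

    sc = SparkContext(appName="readwrite-example")
    sqlContext = SQLContext(sc)

    # Reader: format()/schema()/options() each return the reader itself,
    # so they chain before the terminal load(), as in Scala's DataFrameReader.
    schema = StructType([
        StructField("name", StringType(), True),
        StructField("age", IntegerType(), True),
    ])
    df = sqlContext.read.format("json").schema(schema) \
        .options(samplingRatio="1.0").load("/tmp/people.json")

    # Writer: mode()/format()/options()/partitionBy() chain before save().
    df.write.format("parquet").mode("overwrite") \
        .partitionBy("age").save("/tmp/people.parquet")

Passing a schema skips JSON schema inference, and each setter mutates the wrapped Java reader/writer and returns self, which is what makes the Python call chain read the same as the Scala one.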