From 1aad8c6e59c1e8b18a3eaa8ded93ff6ad05d83df Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Tue, 28 Jun 2016 13:43:59 -0700
Subject: [SPARK-16259][PYSPARK] cleanup options in DataFrame read/write API

## What changes were proposed in this pull request?

There is some duplicated option-handling code in the DataFrame reader/writer API; this PR cleans it up. It also fixes a bug in the `escapeQuotes` option of csv().

## How was this patch tested?

Existing tests.

Author: Davies Liu

Closes #13948 from davies/csv_options.
---
 python/pyspark/sql/readwriter.py | 119 +++++++--------------------------------
 1 file changed, 20 insertions(+), 99 deletions(-)

diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index ccbf895c2d..3f28d7ad50 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -44,84 +44,20 @@ def to_str(value):
         return str(value)


-class ReaderUtils(object):
+class OptionUtils(object):

-    def _set_json_opts(self, schema, primitivesAsString, prefersDecimal,
-                       allowComments, allowUnquotedFieldNames, allowSingleQuotes,
-                       allowNumericLeadingZero, allowBackslashEscapingAnyCharacter,
-                       mode, columnNameOfCorruptRecord):
+    def _set_opts(self, schema=None, **options):
         """
-        Set options based on the Json optional parameters
+        Set named options (filter out those the value is None)
         """
         if schema is not None:
             self.schema(schema)
-        if primitivesAsString is not None:
-            self.option("primitivesAsString", primitivesAsString)
-        if prefersDecimal is not None:
-            self.option("prefersDecimal", prefersDecimal)
-        if allowComments is not None:
-            self.option("allowComments", allowComments)
-        if allowUnquotedFieldNames is not None:
-            self.option("allowUnquotedFieldNames", allowUnquotedFieldNames)
-        if allowSingleQuotes is not None:
-            self.option("allowSingleQuotes", allowSingleQuotes)
-        if allowNumericLeadingZero is not None:
-            self.option("allowNumericLeadingZero", allowNumericLeadingZero)
-        if allowBackslashEscapingAnyCharacter is not None:
-            self.option("allowBackslashEscapingAnyCharacter", allowBackslashEscapingAnyCharacter)
-        if mode is not None:
-            self.option("mode", mode)
-        if columnNameOfCorruptRecord is not None:
-            self.option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
-
-    def _set_csv_opts(self, schema, sep, encoding, quote, escape,
-                      comment, header, inferSchema, ignoreLeadingWhiteSpace,
-                      ignoreTrailingWhiteSpace, nullValue, nanValue, positiveInf, negativeInf,
-                      dateFormat, maxColumns, maxCharsPerColumn, maxMalformedLogPerPartition, mode):
-        """
-        Set options based on the CSV optional parameters
-        """
-        if schema is not None:
-            self.schema(schema)
-        if sep is not None:
-            self.option("sep", sep)
-        if encoding is not None:
-            self.option("encoding", encoding)
-        if quote is not None:
-            self.option("quote", quote)
-        if escape is not None:
-            self.option("escape", escape)
-        if comment is not None:
-            self.option("comment", comment)
-        if header is not None:
-            self.option("header", header)
-        if inferSchema is not None:
-            self.option("inferSchema", inferSchema)
-        if ignoreLeadingWhiteSpace is not None:
-            self.option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace)
-        if ignoreTrailingWhiteSpace is not None:
-            self.option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace)
-        if nullValue is not None:
-            self.option("nullValue", nullValue)
-        if nanValue is not None:
-            self.option("nanValue", nanValue)
-        if positiveInf is not None:
-            self.option("positiveInf", positiveInf)
-        if negativeInf is not None:
-            self.option("negativeInf", negativeInf)
-        if dateFormat is not None:
-            self.option("dateFormat", dateFormat)
-        if maxColumns is not None:
-            self.option("maxColumns", maxColumns)
-        if maxCharsPerColumn is not None:
-            self.option("maxCharsPerColumn", maxCharsPerColumn)
-        if maxMalformedLogPerPartition is not None:
-            self.option("maxMalformedLogPerPartition", maxMalformedLogPerPartition)
-        if mode is not None:
-            self.option("mode", mode)
-
-
-class DataFrameReader(ReaderUtils):
+        for k, v in options.items():
+            if v is not None:
+                self.option(k, v)
+
+
+class DataFrameReader(OptionUtils):
     """
     Interface used to load a :class:`DataFrame` from external storage systems
     (e.g. file systems, key-value stores, etc). Use :func:`spark.read`
@@ -270,7 +206,7 @@ class DataFrameReader(ReaderUtils):
         [('age', 'bigint'), ('name', 'string')]

         """
-        self._set_json_opts(
+        self._set_opts(
             schema=schema, primitivesAsString=primitivesAsString, prefersDecimal=prefersDecimal,
             allowComments=allowComments, allowUnquotedFieldNames=allowUnquotedFieldNames,
             allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero,
@@ -413,7 +349,7 @@
         >>> df.dtypes
         [('_c0', 'string'), ('_c1', 'string')]
         """
-        self._set_csv_opts(
+        self._set_opts(
             schema=schema, sep=sep, encoding=encoding, quote=quote, escape=escape, comment=comment,
             header=header, inferSchema=inferSchema, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
             ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, nullValue=nullValue,
@@ -484,7 +420,7 @@
         return self._df(self._jreader.jdbc(url, table, jprop))


-class DataFrameWriter(object):
+class DataFrameWriter(OptionUtils):
     """
     Interface used to write a :class:`DataFrame` to external storage systems
     (e.g. file systems, key-value stores, etc). Use :func:`DataFrame.write`
@@ -649,8 +585,7 @@
         >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))
         """
         self.mode(mode)
-        if compression is not None:
-            self.option("compression", compression)
+        self._set_opts(compression=compression)
         self._jwrite.json(path)

     @since(1.4)
@@ -676,8 +611,7 @@
         self.mode(mode)
         if partitionBy is not None:
             self.partitionBy(partitionBy)
-        if compression is not None:
-            self.option("compression", compression)
+        self._set_opts(compression=compression)
         self._jwrite.parquet(path)

     @since(1.6)
@@ -692,8 +626,7 @@
         The DataFrame must have only one column that is of string type.
         Each row becomes a new line in the output file.
""" - if compression is not None: - self.option("compression", compression) + self._set_opts(compression=compression) self._jwrite.text(path) @since(2.0) @@ -731,20 +664,8 @@ class DataFrameWriter(object): >>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) - if compression is not None: - self.option("compression", compression) - if sep is not None: - self.option("sep", sep) - if quote is not None: - self.option("quote", quote) - if escape is not None: - self.option("escape", escape) - if header is not None: - self.option("header", header) - if nullValue is not None: - self.option("nullValue", nullValue) - if escapeQuotes is not None: - self.option("escapeQuotes", nullValue) + self._set_opts(compression=compression, sep=sep, quote=quote, escape=escape, header=header, + nullValue=nullValue, escapeQuotes=escapeQuotes) self._jwrite.csv(path) @since(1.5) @@ -803,7 +724,7 @@ class DataFrameWriter(object): self._jwrite.mode(mode).jdbc(url, table, jprop) -class DataStreamReader(ReaderUtils): +class DataStreamReader(OptionUtils): """ Interface used to load a streaming :class:`DataFrame` from external storage systems (e.g. file systems, key-value stores, etc). Use :func:`spark.readStream` @@ -965,7 +886,7 @@ class DataStreamReader(ReaderUtils): >>> json_sdf.schema == sdf_schema True """ - self._set_json_opts( + self._set_opts( schema=schema, primitivesAsString=primitivesAsString, prefersDecimal=prefersDecimal, allowComments=allowComments, allowUnquotedFieldNames=allowUnquotedFieldNames, allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero, @@ -1095,7 +1016,7 @@ class DataStreamReader(ReaderUtils): >>> csv_sdf.schema == sdf_schema True """ - self._set_csv_opts( + self._set_opts( schema=schema, sep=sep, encoding=encoding, quote=quote, escape=escape, comment=comment, header=header, inferSchema=inferSchema, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace, ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, nullValue=nullValue, -- cgit v1.2.3