diff options
Diffstat (limited to 'python/pyspark/sql/readwriter.py')
-rw-r--r-- | python/pyspark/sql/readwriter.py | 258 |
1 file changed, 136 insertions, 122 deletions
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index c982de6840..72fd184d58 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -44,7 +44,82 @@ def to_str(value): return str(value) -class DataFrameReader(object): +class ReaderUtils(object): + + def _set_json_opts(self, schema, primitivesAsString, prefersDecimal, + allowComments, allowUnquotedFieldNames, allowSingleQuotes, + allowNumericLeadingZero, allowBackslashEscapingAnyCharacter, + mode, columnNameOfCorruptRecord): + """ + Set options based on the Json optional parameters + """ + if schema is not None: + self.schema(schema) + if primitivesAsString is not None: + self.option("primitivesAsString", primitivesAsString) + if prefersDecimal is not None: + self.option("prefersDecimal", prefersDecimal) + if allowComments is not None: + self.option("allowComments", allowComments) + if allowUnquotedFieldNames is not None: + self.option("allowUnquotedFieldNames", allowUnquotedFieldNames) + if allowSingleQuotes is not None: + self.option("allowSingleQuotes", allowSingleQuotes) + if allowNumericLeadingZero is not None: + self.option("allowNumericLeadingZero", allowNumericLeadingZero) + if allowBackslashEscapingAnyCharacter is not None: + self.option("allowBackslashEscapingAnyCharacter", allowBackslashEscapingAnyCharacter) + if mode is not None: + self.option("mode", mode) + if columnNameOfCorruptRecord is not None: + self.option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) + + def _set_csv_opts(self, schema, sep, encoding, quote, escape, + comment, header, inferSchema, ignoreLeadingWhiteSpace, + ignoreTrailingWhiteSpace, nullValue, nanValue, positiveInf, negativeInf, + dateFormat, maxColumns, maxCharsPerColumn, mode): + """ + Set options based on the CSV optional parameters + """ + if schema is not None: + self.schema(schema) + if sep is not None: + self.option("sep", sep) + if encoding is not None: + self.option("encoding", encoding) + if 
quote is not None: + self.option("quote", quote) + if escape is not None: + self.option("escape", escape) + if comment is not None: + self.option("comment", comment) + if header is not None: + self.option("header", header) + if inferSchema is not None: + self.option("inferSchema", inferSchema) + if ignoreLeadingWhiteSpace is not None: + self.option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace) + if ignoreTrailingWhiteSpace is not None: + self.option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace) + if nullValue is not None: + self.option("nullValue", nullValue) + if nanValue is not None: + self.option("nanValue", nanValue) + if positiveInf is not None: + self.option("positiveInf", positiveInf) + if negativeInf is not None: + self.option("negativeInf", negativeInf) + if dateFormat is not None: + self.option("dateFormat", dateFormat) + if maxColumns is not None: + self.option("maxColumns", maxColumns) + if maxCharsPerColumn is not None: + self.option("maxCharsPerColumn", maxCharsPerColumn) + if mode is not None: + self.option("mode", mode) + + +class DataFrameReader(ReaderUtils): """ Interface used to load a :class:`DataFrame` from external storage systems (e.g. file systems, key-value stores, etc). 
Use :func:`spark.read` @@ -193,26 +268,10 @@ class DataFrameReader(object): [('age', 'bigint'), ('name', 'string')] """ - if schema is not None: - self.schema(schema) - if primitivesAsString is not None: - self.option("primitivesAsString", primitivesAsString) - if prefersDecimal is not None: - self.option("prefersDecimal", prefersDecimal) - if allowComments is not None: - self.option("allowComments", allowComments) - if allowUnquotedFieldNames is not None: - self.option("allowUnquotedFieldNames", allowUnquotedFieldNames) - if allowSingleQuotes is not None: - self.option("allowSingleQuotes", allowSingleQuotes) - if allowNumericLeadingZero is not None: - self.option("allowNumericLeadingZero", allowNumericLeadingZero) - if allowBackslashEscapingAnyCharacter is not None: - self.option("allowBackslashEscapingAnyCharacter", allowBackslashEscapingAnyCharacter) - if mode is not None: - self.option("mode", mode) - if columnNameOfCorruptRecord is not None: - self.option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) + self._set_json_opts(schema, primitivesAsString, prefersDecimal, + allowComments, allowUnquotedFieldNames, allowSingleQuotes, + allowNumericLeadingZero, allowBackslashEscapingAnyCharacter, + mode, columnNameOfCorruptRecord) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -345,42 +404,11 @@ class DataFrameReader(object): >>> df.dtypes [('_c0', 'string'), ('_c1', 'string')] """ - if schema is not None: - self.schema(schema) - if sep is not None: - self.option("sep", sep) - if encoding is not None: - self.option("encoding", encoding) - if quote is not None: - self.option("quote", quote) - if escape is not None: - self.option("escape", escape) - if comment is not None: - self.option("comment", comment) - if header is not None: - self.option("header", header) - if inferSchema is not None: - self.option("inferSchema", inferSchema) - if ignoreLeadingWhiteSpace is not None: - self.option("ignoreLeadingWhiteSpace", 
ignoreLeadingWhiteSpace) - if ignoreTrailingWhiteSpace is not None: - self.option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace) - if nullValue is not None: - self.option("nullValue", nullValue) - if nanValue is not None: - self.option("nanValue", nanValue) - if positiveInf is not None: - self.option("positiveInf", positiveInf) - if negativeInf is not None: - self.option("negativeInf", negativeInf) - if dateFormat is not None: - self.option("dateFormat", dateFormat) - if maxColumns is not None: - self.option("maxColumns", maxColumns) - if maxCharsPerColumn is not None: - self.option("maxCharsPerColumn", maxCharsPerColumn) - if mode is not None: - self.option("mode", mode) + + self._set_csv_opts(schema, sep, encoding, quote, escape, + comment, header, inferSchema, ignoreLeadingWhiteSpace, + ignoreTrailingWhiteSpace, nullValue, nanValue, positiveInf, negativeInf, + dateFormat, maxColumns, maxCharsPerColumn, mode) if isinstance(path, basestring): path = [path] return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path))) @@ -764,7 +792,7 @@ class DataFrameWriter(object): self._jwrite.mode(mode).jdbc(url, table, jprop) -class DataStreamReader(object): +class DataStreamReader(ReaderUtils): """ Interface used to load a streaming :class:`DataFrame` from external storage systems (e.g. file systems, key-value stores, etc). Use :func:`spark.readStream` @@ -791,6 +819,7 @@ class DataStreamReader(object): :param source: string, name of the data source, e.g. 'json', 'parquet'. + >>> s = spark.readStream.format("text") """ self._jreader = self._jreader.format(source) return self @@ -806,6 +835,8 @@ class DataStreamReader(object): .. note:: Experimental. :param schema: a StructType object + + >>> s = spark.readStream.schema(sdf_schema) """ if not isinstance(schema, StructType): raise TypeError("schema should be StructType") @@ -818,6 +849,8 @@ class DataStreamReader(object): """Adds an input option for the underlying data source. .. note:: Experimental. 
+ + >>> s = spark.readStream.option("x", 1) """ self._jreader = self._jreader.option(key, to_str(value)) return self @@ -827,6 +860,8 @@ class DataStreamReader(object): """Adds input options for the underlying data source. .. note:: Experimental. + + >>> s = spark.readStream.options(x="1", y=2) """ for k in options: self._jreader = self._jreader.option(k, to_str(options[k])) @@ -843,6 +878,13 @@ class DataStreamReader(object): :param schema: optional :class:`StructType` for the input schema. :param options: all other string options + >>> json_sdf = spark.readStream.format("json")\ + .schema(sdf_schema)\ + .load(os.path.join(tempfile.mkdtemp(),'data')) + >>> json_sdf.isStreaming + True + >>> json_sdf.schema == sdf_schema + True """ if format is not None: self.format(format) @@ -905,29 +947,18 @@ class DataStreamReader(object): it uses the value specified in ``spark.sql.columnNameOfCorruptRecord``. + >>> json_sdf = spark.readStream.json(os.path.join(tempfile.mkdtemp(), 'data'), \ + schema = sdf_schema) + >>> json_sdf.isStreaming + True + >>> json_sdf.schema == sdf_schema + True """ - if schema is not None: - self.schema(schema) - if primitivesAsString is not None: - self.option("primitivesAsString", primitivesAsString) - if prefersDecimal is not None: - self.option("prefersDecimal", prefersDecimal) - if allowComments is not None: - self.option("allowComments", allowComments) - if allowUnquotedFieldNames is not None: - self.option("allowUnquotedFieldNames", allowUnquotedFieldNames) - if allowSingleQuotes is not None: - self.option("allowSingleQuotes", allowSingleQuotes) - if allowNumericLeadingZero is not None: - self.option("allowNumericLeadingZero", allowNumericLeadingZero) - if allowBackslashEscapingAnyCharacter is not None: - self.option("allowBackslashEscapingAnyCharacter", allowBackslashEscapingAnyCharacter) - if mode is not None: - self.option("mode", mode) - if columnNameOfCorruptRecord is not None: - self.option("columnNameOfCorruptRecord", 
columnNameOfCorruptRecord) + self._set_json_opts(schema, primitivesAsString, prefersDecimal, + allowComments, allowUnquotedFieldNames, allowSingleQuotes, + allowNumericLeadingZero, allowBackslashEscapingAnyCharacter, + mode, columnNameOfCorruptRecord) if isinstance(path, basestring): - path = [path] return self._df(self._jreader.json(path)) else: raise TypeError("path can be only a single string") @@ -943,10 +974,15 @@ class DataStreamReader(object): .. note:: Experimental. + >>> parquet_sdf = spark.readStream.schema(sdf_schema)\ + .parquet(os.path.join(tempfile.mkdtemp())) + >>> parquet_sdf.isStreaming + True + >>> parquet_sdf.schema == sdf_schema + True """ if isinstance(path, basestring): - path = [path] - return self._df(self._jreader.parquet(self._spark._sc._jvm.PythonUtils.toSeq(path))) + return self._df(self._jreader.parquet(path)) else: raise TypeError("path can be only a single string") @@ -964,10 +1000,14 @@ class DataStreamReader(object): :param paths: string, or list of strings, for input path(s). + >>> text_sdf = spark.readStream.text(os.path.join(tempfile.mkdtemp(), 'data')) + >>> text_sdf.isStreaming + True + >>> "value" in str(text_sdf.schema) + True """ if isinstance(path, basestring): - path = [path] - return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(path))) + return self._df(self._jreader.text(path)) else: raise TypeError("path can be only a single string") @@ -1034,46 +1074,20 @@ class DataStreamReader(object): * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. 
+ >>> csv_sdf = spark.readStream.csv(os.path.join(tempfile.mkdtemp(), 'data'), \ + schema = sdf_schema) + >>> csv_sdf.isStreaming + True + >>> csv_sdf.schema == sdf_schema + True """ - if schema is not None: - self.schema(schema) - if sep is not None: - self.option("sep", sep) - if encoding is not None: - self.option("encoding", encoding) - if quote is not None: - self.option("quote", quote) - if escape is not None: - self.option("escape", escape) - if comment is not None: - self.option("comment", comment) - if header is not None: - self.option("header", header) - if inferSchema is not None: - self.option("inferSchema", inferSchema) - if ignoreLeadingWhiteSpace is not None: - self.option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace) - if ignoreTrailingWhiteSpace is not None: - self.option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace) - if nullValue is not None: - self.option("nullValue", nullValue) - if nanValue is not None: - self.option("nanValue", nanValue) - if positiveInf is not None: - self.option("positiveInf", positiveInf) - if negativeInf is not None: - self.option("negativeInf", negativeInf) - if dateFormat is not None: - self.option("dateFormat", dateFormat) - if maxColumns is not None: - self.option("maxColumns", maxColumns) - if maxCharsPerColumn is not None: - self.option("maxCharsPerColumn", maxCharsPerColumn) - if mode is not None: - self.option("mode", mode) + + self._set_csv_opts(schema, sep, encoding, quote, escape, + comment, header, inferSchema, ignoreLeadingWhiteSpace, + ignoreTrailingWhiteSpace, nullValue, nanValue, positiveInf, negativeInf, + dateFormat, maxColumns, maxCharsPerColumn, mode) if isinstance(path, basestring): - path = [path] - return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path))) + return self._df(self._jreader.csv(path)) else: raise TypeError("path can be only a single string") @@ -1286,7 +1300,7 @@ def _test(): globs['df'] = 
spark.read.parquet('python/test_support/sql/parquet_partitioned') globs['sdf'] = \ spark.readStream.format('text').load('python/test_support/sql/streaming') - + globs['sdf_schema'] = StructType([StructField("data", StringType(), False)]) (failure_count, test_count) = doctest.testmod( pyspark.sql.readwriter, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) |