Diffstat (limited to 'python/pyspark/sql/readwriter.py')
-rw-r--r--  python/pyspark/sql/readwriter.py | 258
1 file changed, 136 insertions(+), 122 deletions(-)
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index c982de6840..72fd184d58 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -44,7 +44,82 @@ def to_str(value):
return str(value)
-class DataFrameReader(object):
+class ReaderUtils(object):
+
+ def _set_json_opts(self, schema, primitivesAsString, prefersDecimal,
+ allowComments, allowUnquotedFieldNames, allowSingleQuotes,
+ allowNumericLeadingZero, allowBackslashEscapingAnyCharacter,
+ mode, columnNameOfCorruptRecord):
+ """
+        Set options based on the JSON optional parameters
+ """
+ if schema is not None:
+ self.schema(schema)
+ if primitivesAsString is not None:
+ self.option("primitivesAsString", primitivesAsString)
+ if prefersDecimal is not None:
+ self.option("prefersDecimal", prefersDecimal)
+ if allowComments is not None:
+ self.option("allowComments", allowComments)
+ if allowUnquotedFieldNames is not None:
+ self.option("allowUnquotedFieldNames", allowUnquotedFieldNames)
+ if allowSingleQuotes is not None:
+ self.option("allowSingleQuotes", allowSingleQuotes)
+ if allowNumericLeadingZero is not None:
+ self.option("allowNumericLeadingZero", allowNumericLeadingZero)
+ if allowBackslashEscapingAnyCharacter is not None:
+ self.option("allowBackslashEscapingAnyCharacter", allowBackslashEscapingAnyCharacter)
+ if mode is not None:
+ self.option("mode", mode)
+ if columnNameOfCorruptRecord is not None:
+ self.option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+
+ def _set_csv_opts(self, schema, sep, encoding, quote, escape,
+ comment, header, inferSchema, ignoreLeadingWhiteSpace,
+ ignoreTrailingWhiteSpace, nullValue, nanValue, positiveInf, negativeInf,
+ dateFormat, maxColumns, maxCharsPerColumn, mode):
+ """
+ Set options based on the CSV optional parameters
+ """
+ if schema is not None:
+ self.schema(schema)
+ if sep is not None:
+ self.option("sep", sep)
+ if encoding is not None:
+ self.option("encoding", encoding)
+ if quote is not None:
+ self.option("quote", quote)
+ if escape is not None:
+ self.option("escape", escape)
+ if comment is not None:
+ self.option("comment", comment)
+ if header is not None:
+ self.option("header", header)
+ if inferSchema is not None:
+ self.option("inferSchema", inferSchema)
+ if ignoreLeadingWhiteSpace is not None:
+ self.option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace)
+ if ignoreTrailingWhiteSpace is not None:
+ self.option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace)
+ if nullValue is not None:
+ self.option("nullValue", nullValue)
+ if nanValue is not None:
+ self.option("nanValue", nanValue)
+ if positiveInf is not None:
+ self.option("positiveInf", positiveInf)
+ if negativeInf is not None:
+ self.option("negativeInf", negativeInf)
+ if dateFormat is not None:
+ self.option("dateFormat", dateFormat)
+ if maxColumns is not None:
+ self.option("maxColumns", maxColumns)
+ if maxCharsPerColumn is not None:
+ self.option("maxCharsPerColumn", maxCharsPerColumn)
+ if mode is not None:
+ self.option("mode", mode)
+
+
+class DataFrameReader(ReaderUtils):
"""
Interface used to load a :class:`DataFrame` from external storage systems
(e.g. file systems, key-value stores, etc). Use :func:`spark.read`
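The hunk above pulls the duplicated option-setting code into a ReaderUtils mixin that both DataFrameReader and (further down) DataStreamReader inherit. A minimal, self-contained sketch of the same mixin pattern; OptUtils and FakeReader are illustrative stand-ins, not Spark's API:

# Sketch of the mixin refactor above, using hypothetical stand-in classes.
class OptUtils(object):
    def _set_json_opts(self, schema=None, mode=None):
        # Forward each parameter only when explicitly set, so the
        # underlying data source keeps its own defaults.
        if schema is not None:
            self.schema(schema)
        if mode is not None:
            self.option("mode", mode)

class FakeReader(OptUtils):
    def __init__(self):
        self.opts = {}
    def schema(self, s):
        self.opts["schema"] = s
        return self
    def option(self, key, value):
        self.opts[key] = value
        return self

r = FakeReader()
r._set_json_opts(schema="a INT", mode="PERMISSIVE")
print(r.opts)  # {'schema': 'a INT', 'mode': 'PERMISSIVE'}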
@@ -193,26 +268,10 @@ class DataFrameReader(object):
[('age', 'bigint'), ('name', 'string')]
"""
- if schema is not None:
- self.schema(schema)
- if primitivesAsString is not None:
- self.option("primitivesAsString", primitivesAsString)
- if prefersDecimal is not None:
- self.option("prefersDecimal", prefersDecimal)
- if allowComments is not None:
- self.option("allowComments", allowComments)
- if allowUnquotedFieldNames is not None:
- self.option("allowUnquotedFieldNames", allowUnquotedFieldNames)
- if allowSingleQuotes is not None:
- self.option("allowSingleQuotes", allowSingleQuotes)
- if allowNumericLeadingZero is not None:
- self.option("allowNumericLeadingZero", allowNumericLeadingZero)
- if allowBackslashEscapingAnyCharacter is not None:
- self.option("allowBackslashEscapingAnyCharacter", allowBackslashEscapingAnyCharacter)
- if mode is not None:
- self.option("mode", mode)
- if columnNameOfCorruptRecord is not None:
- self.option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+ self._set_json_opts(schema, primitivesAsString, prefersDecimal,
+ allowComments, allowUnquotedFieldNames, allowSingleQuotes,
+ allowNumericLeadingZero, allowBackslashEscapingAnyCharacter,
+ mode, columnNameOfCorruptRecord)
if isinstance(path, basestring):
path = [path]
if type(path) == list:
@@ -345,42 +404,11 @@ class DataFrameReader(object):
>>> df.dtypes
[('_c0', 'string'), ('_c1', 'string')]
"""
- if schema is not None:
- self.schema(schema)
- if sep is not None:
- self.option("sep", sep)
- if encoding is not None:
- self.option("encoding", encoding)
- if quote is not None:
- self.option("quote", quote)
- if escape is not None:
- self.option("escape", escape)
- if comment is not None:
- self.option("comment", comment)
- if header is not None:
- self.option("header", header)
- if inferSchema is not None:
- self.option("inferSchema", inferSchema)
- if ignoreLeadingWhiteSpace is not None:
- self.option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace)
- if ignoreTrailingWhiteSpace is not None:
- self.option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace)
- if nullValue is not None:
- self.option("nullValue", nullValue)
- if nanValue is not None:
- self.option("nanValue", nanValue)
- if positiveInf is not None:
- self.option("positiveInf", positiveInf)
- if negativeInf is not None:
- self.option("negativeInf", negativeInf)
- if dateFormat is not None:
- self.option("dateFormat", dateFormat)
- if maxColumns is not None:
- self.option("maxColumns", maxColumns)
- if maxCharsPerColumn is not None:
- self.option("maxCharsPerColumn", maxCharsPerColumn)
- if mode is not None:
- self.option("mode", mode)
+ self._set_csv_opts(schema, sep, encoding, quote, escape,
+ comment, header, inferSchema, ignoreLeadingWhiteSpace,
+ ignoreTrailingWhiteSpace, nullValue, nanValue, positiveInf, negativeInf,
+ dateFormat, maxColumns, maxCharsPerColumn, mode)
if isinstance(path, basestring):
path = [path]
return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path)))
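Both batch call sites above keep the same path-normalization idiom: a single string is wrapped in a list so that one code path can hand a Seq of paths to the JVM reader. A rough sketch of that idiom, with to_seq standing in for PythonUtils.toSeq (which needs a live JVM gateway):

# Hypothetical stand-in for the batch readers' path handling above.
def load_paths(path, to_seq=list):  # to_seq stands in for PythonUtils.toSeq
    if isinstance(path, str):  # `basestring` on Python 2
        path = [path]
    if type(path) == list:
        return to_seq(path)
    raise TypeError("path should be a string or a list of strings")

print(load_paths("/data/a.csv"))           # ['/data/a.csv']
print(load_paths(["/data/a", "/data/b"]))  # ['/data/a', '/data/b']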
@@ -764,7 +792,7 @@ class DataFrameWriter(object):
self._jwrite.mode(mode).jdbc(url, table, jprop)
-class DataStreamReader(object):
+class DataStreamReader(ReaderUtils):
"""
Interface used to load a streaming :class:`DataFrame` from external storage systems
(e.g. file systems, key-value stores, etc). Use :func:`spark.readStream`
@@ -791,6 +819,7 @@ class DataStreamReader(object):
:param source: string, name of the data source, e.g. 'json', 'parquet'.
+ >>> s = spark.readStream.format("text")
"""
self._jreader = self._jreader.format(source)
return self
@@ -806,6 +835,8 @@ class DataStreamReader(object):
.. note:: Experimental.
:param schema: a StructType object
+
+ >>> s = spark.readStream.schema(sdf_schema)
"""
if not isinstance(schema, StructType):
raise TypeError("schema should be StructType")
@@ -818,6 +849,8 @@ class DataStreamReader(object):
"""Adds an input option for the underlying data source.
.. note:: Experimental.
+
+ >>> s = spark.readStream.option("x", 1)
"""
self._jreader = self._jreader.option(key, to_str(value))
return self
@@ -827,6 +860,8 @@ class DataStreamReader(object):
"""Adds input options for the underlying data source.
.. note:: Experimental.
+
+ >>> s = spark.readStream.options(x="1", y=2)
"""
for k in options:
self._jreader = self._jreader.option(k, to_str(options[k]))
@@ -843,6 +878,13 @@ class DataStreamReader(object):
:param schema: optional :class:`StructType` for the input schema.
:param options: all other string options
+ >>> json_sdf = spark.readStream.format("json")\
+ .schema(sdf_schema)\
+         .load(os.path.join(tempfile.mkdtemp(), 'data'))
+ >>> json_sdf.isStreaming
+ True
+ >>> json_sdf.schema == sdf_schema
+ True
"""
if format is not None:
self.format(format)
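The load() doctest above chains calls across lines with trailing backslashes. That works because inside a non-raw docstring a backslash-newline is a string escape, so doctest receives the whole chain as a single source line. A tiny demonstration, assuming nothing beyond the standard library:

import doctest

def demo():
    """
    >>> (1 + \
         2)
    3
    """

# Passes: the backslash joins the two physical lines into one doctest source line.
doctest.run_docstring_examples(demo, {}, verbose=False)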
@@ -905,29 +947,18 @@ class DataStreamReader(object):
it uses the value specified in
``spark.sql.columnNameOfCorruptRecord``.
+ >>> json_sdf = spark.readStream.json(os.path.join(tempfile.mkdtemp(), 'data'), \
+         schema=sdf_schema)
+ >>> json_sdf.isStreaming
+ True
+ >>> json_sdf.schema == sdf_schema
+ True
"""
- if schema is not None:
- self.schema(schema)
- if primitivesAsString is not None:
- self.option("primitivesAsString", primitivesAsString)
- if prefersDecimal is not None:
- self.option("prefersDecimal", prefersDecimal)
- if allowComments is not None:
- self.option("allowComments", allowComments)
- if allowUnquotedFieldNames is not None:
- self.option("allowUnquotedFieldNames", allowUnquotedFieldNames)
- if allowSingleQuotes is not None:
- self.option("allowSingleQuotes", allowSingleQuotes)
- if allowNumericLeadingZero is not None:
- self.option("allowNumericLeadingZero", allowNumericLeadingZero)
- if allowBackslashEscapingAnyCharacter is not None:
- self.option("allowBackslashEscapingAnyCharacter", allowBackslashEscapingAnyCharacter)
- if mode is not None:
- self.option("mode", mode)
- if columnNameOfCorruptRecord is not None:
- self.option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+ self._set_json_opts(schema, primitivesAsString, prefersDecimal,
+ allowComments, allowUnquotedFieldNames, allowSingleQuotes,
+ allowNumericLeadingZero, allowBackslashEscapingAnyCharacter,
+ mode, columnNameOfCorruptRecord)
if isinstance(path, basestring):
- path = [path]
return self._df(self._jreader.json(path))
else:
raise TypeError("path can be only a single string")
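The streaming json() above (and parquet/text/csv in the hunks below) drops the list-wrapping: a streaming source monitors exactly one directory, so anything other than a single string now raises. A minimal sketch of that guard:

# Hypothetical stand-in for the streaming readers' stricter path check.
def stream_path(path):
    if isinstance(path, str):  # `basestring` on Python 2
        return path            # handed to the JVM reader as-is
    raise TypeError("path can be only a single string")

print(stream_path("/tmp/data"))        # /tmp/data
try:
    stream_path(["/tmp/a", "/tmp/b"])  # lists are no longer accepted
except TypeError as e:
    print(e)                           # path can be only a single string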
@@ -943,10 +974,15 @@ class DataStreamReader(object):
.. note:: Experimental.
+ >>> parquet_sdf = spark.readStream.schema(sdf_schema)\
+         .parquet(tempfile.mkdtemp())
+ >>> parquet_sdf.isStreaming
+ True
+ >>> parquet_sdf.schema == sdf_schema
+ True
"""
if isinstance(path, basestring):
- path = [path]
- return self._df(self._jreader.parquet(self._spark._sc._jvm.PythonUtils.toSeq(path)))
+ return self._df(self._jreader.parquet(path))
else:
raise TypeError("path can be only a single string")
@@ -964,10 +1000,14 @@ class DataStreamReader(object):
:param paths: string, or list of strings, for input path(s).
+ >>> text_sdf = spark.readStream.text(os.path.join(tempfile.mkdtemp(), 'data'))
+ >>> text_sdf.isStreaming
+ True
+ >>> "value" in str(text_sdf.schema)
+ True
"""
if isinstance(path, basestring):
- path = [path]
- return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(path)))
+ return self._df(self._jreader.text(path))
else:
raise TypeError("path can be only a single string")
@@ -1034,46 +1074,20 @@ class DataStreamReader(object):
* ``DROPMALFORMED`` : ignores the whole corrupted records.
* ``FAILFAST`` : throws an exception when it meets corrupted records.
+ >>> csv_sdf = spark.readStream.csv(os.path.join(tempfile.mkdtemp(), 'data'), \
+         schema=sdf_schema)
+ >>> csv_sdf.isStreaming
+ True
+ >>> csv_sdf.schema == sdf_schema
+ True
"""
- if schema is not None:
- self.schema(schema)
- if sep is not None:
- self.option("sep", sep)
- if encoding is not None:
- self.option("encoding", encoding)
- if quote is not None:
- self.option("quote", quote)
- if escape is not None:
- self.option("escape", escape)
- if comment is not None:
- self.option("comment", comment)
- if header is not None:
- self.option("header", header)
- if inferSchema is not None:
- self.option("inferSchema", inferSchema)
- if ignoreLeadingWhiteSpace is not None:
- self.option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace)
- if ignoreTrailingWhiteSpace is not None:
- self.option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace)
- if nullValue is not None:
- self.option("nullValue", nullValue)
- if nanValue is not None:
- self.option("nanValue", nanValue)
- if positiveInf is not None:
- self.option("positiveInf", positiveInf)
- if negativeInf is not None:
- self.option("negativeInf", negativeInf)
- if dateFormat is not None:
- self.option("dateFormat", dateFormat)
- if maxColumns is not None:
- self.option("maxColumns", maxColumns)
- if maxCharsPerColumn is not None:
- self.option("maxCharsPerColumn", maxCharsPerColumn)
- if mode is not None:
- self.option("mode", mode)
+ self._set_csv_opts(schema, sep, encoding, quote, escape,
+ comment, header, inferSchema, ignoreLeadingWhiteSpace,
+ ignoreTrailingWhiteSpace, nullValue, nanValue, positiveInf, negativeInf,
+ dateFormat, maxColumns, maxCharsPerColumn, mode)
if isinstance(path, basestring):
- path = [path]
- return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path)))
+ return self._df(self._jreader.csv(path))
else:
raise TypeError("path can be only a single string")
@@ -1286,7 +1300,7 @@ def _test():
globs['df'] = spark.read.parquet('python/test_support/sql/parquet_partitioned')
globs['sdf'] = \
spark.readStream.format('text').load('python/test_support/sql/streaming')
-
+ globs['sdf_schema'] = StructType([StructField("data", StringType(), False)])
(failure_count, test_count) = doctest.testmod(
pyspark.sql.readwriter, globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
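The sdf_schema glob added in _test() above is how the new streaming doctests see a schema without defining one inline: names injected into the globals dict are visible to every doctest in the module. A small self-contained sketch of that mechanism:

import doctest

def f():
    """
    >>> greeting   # resolved from the injected globals, not an import
    'hello'
    """

# extraglobs merges extra names into each doctest's namespace,
# much like the globs dict built up in _test() above.
results = doctest.testmod(extraglobs={"greeting": "hello"}, verbose=False)
print(results.failed == 0)  # True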