aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorTathagata Das <tathagata.das1565@gmail.com>2016-06-16 13:17:41 -0700
committerShixiong Zhu <shixiong@databricks.com>2016-06-16 13:17:41 -0700
commit084dca770f5c26f906e7555707c7894cf05fb86b (patch)
tree123f08366594b6806067cef9128cf19764effafb /python
parenta865f6e05297f6121bb2fde717860f9edeed263e (diff)
downloadspark-084dca770f5c26f906e7555707c7894cf05fb86b.tar.gz
spark-084dca770f5c26f906e7555707c7894cf05fb86b.tar.bz2
spark-084dca770f5c26f906e7555707c7894cf05fb86b.zip
[SPARK-15981][SQL][STREAMING] Fixed bug and added tests in DataStreamReader Python API
## What changes were proposed in this pull request? - Fixed bug in Python API of DataStreamReader. Because a single path was being converted to a array before calling Java DataStreamReader method (which takes a string only), it gave the following error. ``` File "/Users/tdas/Projects/Spark/spark/python/pyspark/sql/readwriter.py", line 947, in pyspark.sql.readwriter.DataStreamReader.json Failed example: json_sdf = spark.readStream.json(os.path.join(tempfile.mkdtemp(), 'data'), schema = sdf_schema) Exception raised: Traceback (most recent call last): File "/System/Library/Frameworks/Python.framework/Versions/2.6/lib/python2.6/doctest.py", line 1253, in __run compileflags, 1) in test.globs File "<doctest pyspark.sql.readwriter.DataStreamReader.json[0]>", line 1, in <module> json_sdf = spark.readStream.json(os.path.join(tempfile.mkdtemp(), 'data'), schema = sdf_schema) File "/Users/tdas/Projects/Spark/spark/python/pyspark/sql/readwriter.py", line 963, in json return self._df(self._jreader.json(path)) File "/Users/tdas/Projects/Spark/spark/python/lib/py4j-0.10.1-src.zip/py4j/java_gateway.py", line 933, in __call__ answer, self.gateway_client, self.target_id, self.name) File "/Users/tdas/Projects/Spark/spark/python/pyspark/sql/utils.py", line 63, in deco return f(*a, **kw) File "/Users/tdas/Projects/Spark/spark/python/lib/py4j-0.10.1-src.zip/py4j/protocol.py", line 316, in get_return_value format(target_id, ".", name, value)) Py4JError: An error occurred while calling o121.json. Trace: py4j.Py4JException: Method json([class java.util.ArrayList]) does not exist at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318) at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:326) at py4j.Gateway.invoke(Gateway.java:272) at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:128) at py4j.commands.CallCommand.execute(CallCommand.java:79) at py4j.GatewayConnection.run(GatewayConnection.java:211) at java.lang.Thread.run(Thread.java:744) ``` - Reduced code duplication between DataStreamReader and DataFrameWriter - Added missing Python doctests ## How was this patch tested? New tests Author: Tathagata Das <tathagata.das1565@gmail.com> Closes #13703 from tdas/SPARK-15981.
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/sql/readwriter.py258
1 files changed, 136 insertions, 122 deletions
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index c982de6840..72fd184d58 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -44,7 +44,82 @@ def to_str(value):
return str(value)
-class DataFrameReader(object):
+class ReaderUtils(object):
+
+ def _set_json_opts(self, schema, primitivesAsString, prefersDecimal,
+ allowComments, allowUnquotedFieldNames, allowSingleQuotes,
+ allowNumericLeadingZero, allowBackslashEscapingAnyCharacter,
+ mode, columnNameOfCorruptRecord):
+ """
+ Set options based on the Json optional parameters
+ """
+ if schema is not None:
+ self.schema(schema)
+ if primitivesAsString is not None:
+ self.option("primitivesAsString", primitivesAsString)
+ if prefersDecimal is not None:
+ self.option("prefersDecimal", prefersDecimal)
+ if allowComments is not None:
+ self.option("allowComments", allowComments)
+ if allowUnquotedFieldNames is not None:
+ self.option("allowUnquotedFieldNames", allowUnquotedFieldNames)
+ if allowSingleQuotes is not None:
+ self.option("allowSingleQuotes", allowSingleQuotes)
+ if allowNumericLeadingZero is not None:
+ self.option("allowNumericLeadingZero", allowNumericLeadingZero)
+ if allowBackslashEscapingAnyCharacter is not None:
+ self.option("allowBackslashEscapingAnyCharacter", allowBackslashEscapingAnyCharacter)
+ if mode is not None:
+ self.option("mode", mode)
+ if columnNameOfCorruptRecord is not None:
+ self.option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+
+ def _set_csv_opts(self, schema, sep, encoding, quote, escape,
+ comment, header, inferSchema, ignoreLeadingWhiteSpace,
+ ignoreTrailingWhiteSpace, nullValue, nanValue, positiveInf, negativeInf,
+ dateFormat, maxColumns, maxCharsPerColumn, mode):
+ """
+ Set options based on the CSV optional parameters
+ """
+ if schema is not None:
+ self.schema(schema)
+ if sep is not None:
+ self.option("sep", sep)
+ if encoding is not None:
+ self.option("encoding", encoding)
+ if quote is not None:
+ self.option("quote", quote)
+ if escape is not None:
+ self.option("escape", escape)
+ if comment is not None:
+ self.option("comment", comment)
+ if header is not None:
+ self.option("header", header)
+ if inferSchema is not None:
+ self.option("inferSchema", inferSchema)
+ if ignoreLeadingWhiteSpace is not None:
+ self.option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace)
+ if ignoreTrailingWhiteSpace is not None:
+ self.option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace)
+ if nullValue is not None:
+ self.option("nullValue", nullValue)
+ if nanValue is not None:
+ self.option("nanValue", nanValue)
+ if positiveInf is not None:
+ self.option("positiveInf", positiveInf)
+ if negativeInf is not None:
+ self.option("negativeInf", negativeInf)
+ if dateFormat is not None:
+ self.option("dateFormat", dateFormat)
+ if maxColumns is not None:
+ self.option("maxColumns", maxColumns)
+ if maxCharsPerColumn is not None:
+ self.option("maxCharsPerColumn", maxCharsPerColumn)
+ if mode is not None:
+ self.option("mode", mode)
+
+
+class DataFrameReader(ReaderUtils):
"""
Interface used to load a :class:`DataFrame` from external storage systems
(e.g. file systems, key-value stores, etc). Use :func:`spark.read`
@@ -193,26 +268,10 @@ class DataFrameReader(object):
[('age', 'bigint'), ('name', 'string')]
"""
- if schema is not None:
- self.schema(schema)
- if primitivesAsString is not None:
- self.option("primitivesAsString", primitivesAsString)
- if prefersDecimal is not None:
- self.option("prefersDecimal", prefersDecimal)
- if allowComments is not None:
- self.option("allowComments", allowComments)
- if allowUnquotedFieldNames is not None:
- self.option("allowUnquotedFieldNames", allowUnquotedFieldNames)
- if allowSingleQuotes is not None:
- self.option("allowSingleQuotes", allowSingleQuotes)
- if allowNumericLeadingZero is not None:
- self.option("allowNumericLeadingZero", allowNumericLeadingZero)
- if allowBackslashEscapingAnyCharacter is not None:
- self.option("allowBackslashEscapingAnyCharacter", allowBackslashEscapingAnyCharacter)
- if mode is not None:
- self.option("mode", mode)
- if columnNameOfCorruptRecord is not None:
- self.option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+ self._set_json_opts(schema, primitivesAsString, prefersDecimal,
+ allowComments, allowUnquotedFieldNames, allowSingleQuotes,
+ allowNumericLeadingZero, allowBackslashEscapingAnyCharacter,
+ mode, columnNameOfCorruptRecord)
if isinstance(path, basestring):
path = [path]
if type(path) == list:
@@ -345,42 +404,11 @@ class DataFrameReader(object):
>>> df.dtypes
[('_c0', 'string'), ('_c1', 'string')]
"""
- if schema is not None:
- self.schema(schema)
- if sep is not None:
- self.option("sep", sep)
- if encoding is not None:
- self.option("encoding", encoding)
- if quote is not None:
- self.option("quote", quote)
- if escape is not None:
- self.option("escape", escape)
- if comment is not None:
- self.option("comment", comment)
- if header is not None:
- self.option("header", header)
- if inferSchema is not None:
- self.option("inferSchema", inferSchema)
- if ignoreLeadingWhiteSpace is not None:
- self.option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace)
- if ignoreTrailingWhiteSpace is not None:
- self.option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace)
- if nullValue is not None:
- self.option("nullValue", nullValue)
- if nanValue is not None:
- self.option("nanValue", nanValue)
- if positiveInf is not None:
- self.option("positiveInf", positiveInf)
- if negativeInf is not None:
- self.option("negativeInf", negativeInf)
- if dateFormat is not None:
- self.option("dateFormat", dateFormat)
- if maxColumns is not None:
- self.option("maxColumns", maxColumns)
- if maxCharsPerColumn is not None:
- self.option("maxCharsPerColumn", maxCharsPerColumn)
- if mode is not None:
- self.option("mode", mode)
+
+ self._set_csv_opts(schema, sep, encoding, quote, escape,
+ comment, header, inferSchema, ignoreLeadingWhiteSpace,
+ ignoreTrailingWhiteSpace, nullValue, nanValue, positiveInf, negativeInf,
+ dateFormat, maxColumns, maxCharsPerColumn, mode)
if isinstance(path, basestring):
path = [path]
return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path)))
@@ -764,7 +792,7 @@ class DataFrameWriter(object):
self._jwrite.mode(mode).jdbc(url, table, jprop)
-class DataStreamReader(object):
+class DataStreamReader(ReaderUtils):
"""
Interface used to load a streaming :class:`DataFrame` from external storage systems
(e.g. file systems, key-value stores, etc). Use :func:`spark.readStream`
@@ -791,6 +819,7 @@ class DataStreamReader(object):
:param source: string, name of the data source, e.g. 'json', 'parquet'.
+ >>> s = spark.readStream.format("text")
"""
self._jreader = self._jreader.format(source)
return self
@@ -806,6 +835,8 @@ class DataStreamReader(object):
.. note:: Experimental.
:param schema: a StructType object
+
+ >>> s = spark.readStream.schema(sdf_schema)
"""
if not isinstance(schema, StructType):
raise TypeError("schema should be StructType")
@@ -818,6 +849,8 @@ class DataStreamReader(object):
"""Adds an input option for the underlying data source.
.. note:: Experimental.
+
+ >>> s = spark.readStream.option("x", 1)
"""
self._jreader = self._jreader.option(key, to_str(value))
return self
@@ -827,6 +860,8 @@ class DataStreamReader(object):
"""Adds input options for the underlying data source.
.. note:: Experimental.
+
+ >>> s = spark.readStream.options(x="1", y=2)
"""
for k in options:
self._jreader = self._jreader.option(k, to_str(options[k]))
@@ -843,6 +878,13 @@ class DataStreamReader(object):
:param schema: optional :class:`StructType` for the input schema.
:param options: all other string options
+ >>> json_sdf = spark.readStream.format("json")\
+ .schema(sdf_schema)\
+ .load(os.path.join(tempfile.mkdtemp(),'data'))
+ >>> json_sdf.isStreaming
+ True
+ >>> json_sdf.schema == sdf_schema
+ True
"""
if format is not None:
self.format(format)
@@ -905,29 +947,18 @@ class DataStreamReader(object):
it uses the value specified in
``spark.sql.columnNameOfCorruptRecord``.
+ >>> json_sdf = spark.readStream.json(os.path.join(tempfile.mkdtemp(), 'data'), \
+ schema = sdf_schema)
+ >>> json_sdf.isStreaming
+ True
+ >>> json_sdf.schema == sdf_schema
+ True
"""
- if schema is not None:
- self.schema(schema)
- if primitivesAsString is not None:
- self.option("primitivesAsString", primitivesAsString)
- if prefersDecimal is not None:
- self.option("prefersDecimal", prefersDecimal)
- if allowComments is not None:
- self.option("allowComments", allowComments)
- if allowUnquotedFieldNames is not None:
- self.option("allowUnquotedFieldNames", allowUnquotedFieldNames)
- if allowSingleQuotes is not None:
- self.option("allowSingleQuotes", allowSingleQuotes)
- if allowNumericLeadingZero is not None:
- self.option("allowNumericLeadingZero", allowNumericLeadingZero)
- if allowBackslashEscapingAnyCharacter is not None:
- self.option("allowBackslashEscapingAnyCharacter", allowBackslashEscapingAnyCharacter)
- if mode is not None:
- self.option("mode", mode)
- if columnNameOfCorruptRecord is not None:
- self.option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+ self._set_json_opts(schema, primitivesAsString, prefersDecimal,
+ allowComments, allowUnquotedFieldNames, allowSingleQuotes,
+ allowNumericLeadingZero, allowBackslashEscapingAnyCharacter,
+ mode, columnNameOfCorruptRecord)
if isinstance(path, basestring):
- path = [path]
return self._df(self._jreader.json(path))
else:
raise TypeError("path can be only a single string")
@@ -943,10 +974,15 @@ class DataStreamReader(object):
.. note:: Experimental.
+ >>> parquet_sdf = spark.readStream.schema(sdf_schema)\
+ .parquet(os.path.join(tempfile.mkdtemp()))
+ >>> parquet_sdf.isStreaming
+ True
+ >>> parquet_sdf.schema == sdf_schema
+ True
"""
if isinstance(path, basestring):
- path = [path]
- return self._df(self._jreader.parquet(self._spark._sc._jvm.PythonUtils.toSeq(path)))
+ return self._df(self._jreader.parquet(path))
else:
raise TypeError("path can be only a single string")
@@ -964,10 +1000,14 @@ class DataStreamReader(object):
:param paths: string, or list of strings, for input path(s).
+ >>> text_sdf = spark.readStream.text(os.path.join(tempfile.mkdtemp(), 'data'))
+ >>> text_sdf.isStreaming
+ True
+ >>> "value" in str(text_sdf.schema)
+ True
"""
if isinstance(path, basestring):
- path = [path]
- return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(path)))
+ return self._df(self._jreader.text(path))
else:
raise TypeError("path can be only a single string")
@@ -1034,46 +1074,20 @@ class DataStreamReader(object):
* ``DROPMALFORMED`` : ignores the whole corrupted records.
* ``FAILFAST`` : throws an exception when it meets corrupted records.
+ >>> csv_sdf = spark.readStream.csv(os.path.join(tempfile.mkdtemp(), 'data'), \
+ schema = sdf_schema)
+ >>> csv_sdf.isStreaming
+ True
+ >>> csv_sdf.schema == sdf_schema
+ True
"""
- if schema is not None:
- self.schema(schema)
- if sep is not None:
- self.option("sep", sep)
- if encoding is not None:
- self.option("encoding", encoding)
- if quote is not None:
- self.option("quote", quote)
- if escape is not None:
- self.option("escape", escape)
- if comment is not None:
- self.option("comment", comment)
- if header is not None:
- self.option("header", header)
- if inferSchema is not None:
- self.option("inferSchema", inferSchema)
- if ignoreLeadingWhiteSpace is not None:
- self.option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace)
- if ignoreTrailingWhiteSpace is not None:
- self.option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace)
- if nullValue is not None:
- self.option("nullValue", nullValue)
- if nanValue is not None:
- self.option("nanValue", nanValue)
- if positiveInf is not None:
- self.option("positiveInf", positiveInf)
- if negativeInf is not None:
- self.option("negativeInf", negativeInf)
- if dateFormat is not None:
- self.option("dateFormat", dateFormat)
- if maxColumns is not None:
- self.option("maxColumns", maxColumns)
- if maxCharsPerColumn is not None:
- self.option("maxCharsPerColumn", maxCharsPerColumn)
- if mode is not None:
- self.option("mode", mode)
+
+ self._set_csv_opts(schema, sep, encoding, quote, escape,
+ comment, header, inferSchema, ignoreLeadingWhiteSpace,
+ ignoreTrailingWhiteSpace, nullValue, nanValue, positiveInf, negativeInf,
+ dateFormat, maxColumns, maxCharsPerColumn, mode)
if isinstance(path, basestring):
- path = [path]
- return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path)))
+ return self._df(self._jreader.csv(path))
else:
raise TypeError("path can be only a single string")
@@ -1286,7 +1300,7 @@ def _test():
globs['df'] = spark.read.parquet('python/test_support/sql/parquet_partitioned')
globs['sdf'] = \
spark.readStream.format('text').load('python/test_support/sql/streaming')
-
+ globs['sdf_schema'] = StructType([StructField("data", StringType(), False)])
(failure_count, test_count) = doctest.testmod(
pyspark.sql.readwriter, globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)