| author | Tathagata Das <tathagata.das1565@gmail.com> | 2016-06-16 13:17:41 -0700 |
|---|---|---|
| committer | Shixiong Zhu <shixiong@databricks.com> | 2016-06-16 13:17:41 -0700 |
| commit | 084dca770f5c26f906e7555707c7894cf05fb86b (patch) | |
| tree | 123f08366594b6806067cef9128cf19764effafb | |
| parent | a865f6e05297f6121bb2fde717860f9edeed263e (diff) | |
# [SPARK-15981][SQL][STREAMING] Fixed bug and added tests in DataStreamReader Python API
## What changes were proposed in this pull request?
- Fixed a bug in the Python API of DataStreamReader. Because a single path was being converted to an array before calling the Java DataStreamReader method (which takes only a string), it raised the following error.
```
File "/Users/tdas/Projects/Spark/spark/python/pyspark/sql/readwriter.py", line 947, in pyspark.sql.readwriter.DataStreamReader.json
Failed example:
json_sdf = spark.readStream.json(os.path.join(tempfile.mkdtemp(), 'data'), schema = sdf_schema)
Exception raised:
Traceback (most recent call last):
File "/System/Library/Frameworks/Python.framework/Versions/2.6/lib/python2.6/doctest.py", line 1253, in __run
compileflags, 1) in test.globs
File "<doctest pyspark.sql.readwriter.DataStreamReader.json[0]>", line 1, in <module>
json_sdf = spark.readStream.json(os.path.join(tempfile.mkdtemp(), 'data'), schema = sdf_schema)
File "/Users/tdas/Projects/Spark/spark/python/pyspark/sql/readwriter.py", line 963, in json
return self._df(self._jreader.json(path))
File "/Users/tdas/Projects/Spark/spark/python/lib/py4j-0.10.1-src.zip/py4j/java_gateway.py", line 933, in __call__
answer, self.gateway_client, self.target_id, self.name)
File "/Users/tdas/Projects/Spark/spark/python/pyspark/sql/utils.py", line 63, in deco
return f(*a, **kw)
File "/Users/tdas/Projects/Spark/spark/python/lib/py4j-0.10.1-src.zip/py4j/protocol.py", line 316, in get_return_value
format(target_id, ".", name, value))
Py4JError: An error occurred while calling o121.json. Trace:
py4j.Py4JException: Method json([class java.util.ArrayList]) does not exist
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318)
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:326)
at py4j.Gateway.invoke(Gateway.java:272)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:128)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:211)
at java.lang.Thread.run(Thread.java:744)
```
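In short, Py4J maps the Python `list` to a `java.util.ArrayList` on the JVM side, so wrapping the single path in a list made the gateway look for a `json(ArrayList)` overload that the Java `DataStreamReader` does not define, while a plain string maps to `java.lang.String` and matches `json(String)`. A condensed before/after sketch (method names `json_before_fix`/`json_after_fix` are illustrative labels; the bodies are simplified from `readwriter.py`):

```python
# Condensed from python/pyspark/sql/readwriter.py; bodies simplified to show
# only the path handling. `self._jreader` is the Py4J proxy for the Java
# DataStreamReader.

class DataStreamReader(object):
    def json_before_fix(self, path):
        if isinstance(path, basestring):
            path = [path]  # Python list -> java.util.ArrayList via Py4J
        # Py4J then looks for json(ArrayList), which does not exist:
        # py4j.Py4JException: Method json([class java.util.ArrayList]) does not exist
        return self._df(self._jreader.json(path))

    def json_after_fix(self, path):
        if isinstance(path, basestring):
            # A Python string maps to java.lang.String, matching json(String)
            return self._df(self._jreader.json(path))
        else:
            raise TypeError("path can be only a single string")
```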
- Reduced code duplication between DataFrameReader and DataStreamReader by moving the shared option-setting logic into a common base class (a condensed sketch follows this list).
- Added missing Python doctests.
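The deduplication hoists the repeated `if opt is not None: self.option(...)` chains into a small base class that both readers inherit. A condensed sketch of the pattern (the full `_set_json_opts`/`_set_csv_opts` helpers in the diff forward every JSON/CSV option the same way):

```python
# Condensed sketch of the deduplication pattern from this patch.

class ReaderUtils(object):
    # Relies on the subclass providing schema() and option(),
    # which both DataFrameReader and DataStreamReader define.
    def _set_json_opts(self, schema, primitivesAsString, mode):
        # Condensed: the real helper forwards ten JSON options this way.
        if schema is not None:
            self.schema(schema)
        if primitivesAsString is not None:
            self.option("primitivesAsString", primitivesAsString)
        if mode is not None:
            self.option("mode", mode)


class DataFrameReader(ReaderUtils):    # batch reader
    pass


class DataStreamReader(ReaderUtils):   # streaming reader, shares the helpers
    pass
```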
## How was this patch tested?
New doctests covering the DataStreamReader Python API; a minimal sketch of how the doctest harness wires them up follows.
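The new doctests run through the module's existing `_test()` harness, which injects the globals they reference. A minimal sketch of that mechanism (`sdf_schema` is the global added by this patch; `run_readwriter_doctests` is a hypothetical wrapper around the standard `doctest` pattern, given an existing SparkSession):

```python
import doctest

import pyspark.sql.readwriter
from pyspark.sql.types import StructType, StructField, StringType


def run_readwriter_doctests(spark):
    # The doctests reference names like `spark` and `sdf_schema`;
    # they are supplied via the globs dict, not the module namespace.
    globs = pyspark.sql.readwriter.__dict__.copy()
    globs['spark'] = spark
    globs['sdf_schema'] = StructType([StructField("data", StringType(), False)])
    return doctest.testmod(
        pyspark.sql.readwriter, globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
```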
Author: Tathagata Das <tathagata.das1565@gmail.com>
Closes #13703 from tdas/SPARK-15981.
python/pyspark/sql/readwriter.py | 258
1 file changed, 136 insertions(+), 122 deletions(-)
```diff
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index c982de6840..72fd184d58 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -44,7 +44,82 @@ def to_str(value):
         return str(value)
 
 
-class DataFrameReader(object):
+class ReaderUtils(object):
+
+    def _set_json_opts(self, schema, primitivesAsString, prefersDecimal,
+                       allowComments, allowUnquotedFieldNames, allowSingleQuotes,
+                       allowNumericLeadingZero, allowBackslashEscapingAnyCharacter,
+                       mode, columnNameOfCorruptRecord):
+        """
+        Set options based on the Json optional parameters
+        """
+        if schema is not None:
+            self.schema(schema)
+        if primitivesAsString is not None:
+            self.option("primitivesAsString", primitivesAsString)
+        if prefersDecimal is not None:
+            self.option("prefersDecimal", prefersDecimal)
+        if allowComments is not None:
+            self.option("allowComments", allowComments)
+        if allowUnquotedFieldNames is not None:
+            self.option("allowUnquotedFieldNames", allowUnquotedFieldNames)
+        if allowSingleQuotes is not None:
+            self.option("allowSingleQuotes", allowSingleQuotes)
+        if allowNumericLeadingZero is not None:
+            self.option("allowNumericLeadingZero", allowNumericLeadingZero)
+        if allowBackslashEscapingAnyCharacter is not None:
+            self.option("allowBackslashEscapingAnyCharacter", allowBackslashEscapingAnyCharacter)
+        if mode is not None:
+            self.option("mode", mode)
+        if columnNameOfCorruptRecord is not None:
+            self.option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+
+    def _set_csv_opts(self, schema, sep, encoding, quote, escape,
+                      comment, header, inferSchema, ignoreLeadingWhiteSpace,
+                      ignoreTrailingWhiteSpace, nullValue, nanValue, positiveInf, negativeInf,
+                      dateFormat, maxColumns, maxCharsPerColumn, mode):
+        """
+        Set options based on the CSV optional parameters
+        """
+        if schema is not None:
+            self.schema(schema)
+        if sep is not None:
+            self.option("sep", sep)
+        if encoding is not None:
+            self.option("encoding", encoding)
+        if quote is not None:
+            self.option("quote", quote)
+        if escape is not None:
+            self.option("escape", escape)
+        if comment is not None:
+            self.option("comment", comment)
+        if header is not None:
+            self.option("header", header)
+        if inferSchema is not None:
+            self.option("inferSchema", inferSchema)
+        if ignoreLeadingWhiteSpace is not None:
+            self.option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace)
+        if ignoreTrailingWhiteSpace is not None:
+            self.option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace)
+        if nullValue is not None:
+            self.option("nullValue", nullValue)
+        if nanValue is not None:
+            self.option("nanValue", nanValue)
+        if positiveInf is not None:
+            self.option("positiveInf", positiveInf)
+        if negativeInf is not None:
+            self.option("negativeInf", negativeInf)
+        if dateFormat is not None:
+            self.option("dateFormat", dateFormat)
+        if maxColumns is not None:
+            self.option("maxColumns", maxColumns)
+        if maxCharsPerColumn is not None:
+            self.option("maxCharsPerColumn", maxCharsPerColumn)
+        if mode is not None:
+            self.option("mode", mode)
+
+
+class DataFrameReader(ReaderUtils):
     """
     Interface used to load a :class:`DataFrame` from external storage systems
     (e.g. file systems, key-value stores, etc). Use :func:`spark.read`
@@ -193,26 +268,10 @@ class DataFrameReader(object):
         [('age', 'bigint'), ('name', 'string')]
 
         """
-        if schema is not None:
-            self.schema(schema)
-        if primitivesAsString is not None:
-            self.option("primitivesAsString", primitivesAsString)
-        if prefersDecimal is not None:
-            self.option("prefersDecimal", prefersDecimal)
-        if allowComments is not None:
-            self.option("allowComments", allowComments)
-        if allowUnquotedFieldNames is not None:
-            self.option("allowUnquotedFieldNames", allowUnquotedFieldNames)
-        if allowSingleQuotes is not None:
-            self.option("allowSingleQuotes", allowSingleQuotes)
-        if allowNumericLeadingZero is not None:
-            self.option("allowNumericLeadingZero", allowNumericLeadingZero)
-        if allowBackslashEscapingAnyCharacter is not None:
-            self.option("allowBackslashEscapingAnyCharacter", allowBackslashEscapingAnyCharacter)
-        if mode is not None:
-            self.option("mode", mode)
-        if columnNameOfCorruptRecord is not None:
-            self.option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+        self._set_json_opts(schema, primitivesAsString, prefersDecimal,
+                            allowComments, allowUnquotedFieldNames, allowSingleQuotes,
+                            allowNumericLeadingZero, allowBackslashEscapingAnyCharacter,
+                            mode, columnNameOfCorruptRecord)
         if isinstance(path, basestring):
             path = [path]
         if type(path) == list:
@@ -345,42 +404,11 @@ class DataFrameReader(object):
         >>> df.dtypes
         [('_c0', 'string'), ('_c1', 'string')]
         """
-        if schema is not None:
-            self.schema(schema)
-        if sep is not None:
-            self.option("sep", sep)
-        if encoding is not None:
-            self.option("encoding", encoding)
-        if quote is not None:
-            self.option("quote", quote)
-        if escape is not None:
-            self.option("escape", escape)
-        if comment is not None:
-            self.option("comment", comment)
-        if header is not None:
-            self.option("header", header)
-        if inferSchema is not None:
-            self.option("inferSchema", inferSchema)
-        if ignoreLeadingWhiteSpace is not None:
-            self.option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace)
-        if ignoreTrailingWhiteSpace is not None:
-            self.option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace)
-        if nullValue is not None:
-            self.option("nullValue", nullValue)
-        if nanValue is not None:
-            self.option("nanValue", nanValue)
-        if positiveInf is not None:
-            self.option("positiveInf", positiveInf)
-        if negativeInf is not None:
-            self.option("negativeInf", negativeInf)
-        if dateFormat is not None:
-            self.option("dateFormat", dateFormat)
-        if maxColumns is not None:
-            self.option("maxColumns", maxColumns)
-        if maxCharsPerColumn is not None:
-            self.option("maxCharsPerColumn", maxCharsPerColumn)
-        if mode is not None:
-            self.option("mode", mode)
+
+        self._set_csv_opts(schema, sep, encoding, quote, escape,
+                           comment, header, inferSchema, ignoreLeadingWhiteSpace,
+                           ignoreTrailingWhiteSpace, nullValue, nanValue, positiveInf, negativeInf,
+                           dateFormat, maxColumns, maxCharsPerColumn, mode)
         if isinstance(path, basestring):
             path = [path]
         return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path)))
@@ -764,7 +792,7 @@ class DataFrameWriter(object):
         self._jwrite.mode(mode).jdbc(url, table, jprop)
 
 
-class DataStreamReader(object):
+class DataStreamReader(ReaderUtils):
     """
     Interface used to load a streaming :class:`DataFrame` from external storage systems
     (e.g. file systems, key-value stores, etc). Use :func:`spark.readStream`
@@ -791,6 +819,7 @@ class DataStreamReader(object):
 
         :param source: string, name of the data source, e.g. 'json', 'parquet'.
 
+        >>> s = spark.readStream.format("text")
         """
         self._jreader = self._jreader.format(source)
         return self
@@ -806,6 +835,8 @@ class DataStreamReader(object):
         .. note:: Experimental.
 
         :param schema: a StructType object
+
+        >>> s = spark.readStream.schema(sdf_schema)
         """
         if not isinstance(schema, StructType):
             raise TypeError("schema should be StructType")
@@ -818,6 +849,8 @@ class DataStreamReader(object):
         """Adds an input option for the underlying data source.
 
         .. note:: Experimental.
+
+        >>> s = spark.readStream.option("x", 1)
         """
         self._jreader = self._jreader.option(key, to_str(value))
         return self
@@ -827,6 +860,8 @@ class DataStreamReader(object):
         """Adds input options for the underlying data source.
 
         .. note:: Experimental.
+
+        >>> s = spark.readStream.options(x="1", y=2)
         """
         for k in options:
             self._jreader = self._jreader.option(k, to_str(options[k]))
@@ -843,6 +878,13 @@ class DataStreamReader(object):
         :param schema: optional :class:`StructType` for the input schema.
         :param options: all other string options
 
+        >>> json_sdf = spark.readStream.format("json")\
+                                       .schema(sdf_schema)\
+                                       .load(os.path.join(tempfile.mkdtemp(),'data'))
+        >>> json_sdf.isStreaming
+        True
+        >>> json_sdf.schema == sdf_schema
+        True
         """
         if format is not None:
             self.format(format)
@@ -905,29 +947,18 @@ class DataStreamReader(object):
 
                 it uses the value specified in ``spark.sql.columnNameOfCorruptRecord``.
 
+        >>> json_sdf = spark.readStream.json(os.path.join(tempfile.mkdtemp(), 'data'), \
+                schema = sdf_schema)
+        >>> json_sdf.isStreaming
+        True
+        >>> json_sdf.schema == sdf_schema
+        True
         """
-        if schema is not None:
-            self.schema(schema)
-        if primitivesAsString is not None:
-            self.option("primitivesAsString", primitivesAsString)
-        if prefersDecimal is not None:
-            self.option("prefersDecimal", prefersDecimal)
-        if allowComments is not None:
-            self.option("allowComments", allowComments)
-        if allowUnquotedFieldNames is not None:
-            self.option("allowUnquotedFieldNames", allowUnquotedFieldNames)
-        if allowSingleQuotes is not None:
-            self.option("allowSingleQuotes", allowSingleQuotes)
-        if allowNumericLeadingZero is not None:
-            self.option("allowNumericLeadingZero", allowNumericLeadingZero)
-        if allowBackslashEscapingAnyCharacter is not None:
-            self.option("allowBackslashEscapingAnyCharacter", allowBackslashEscapingAnyCharacter)
-        if mode is not None:
-            self.option("mode", mode)
-        if columnNameOfCorruptRecord is not None:
-            self.option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+        self._set_json_opts(schema, primitivesAsString, prefersDecimal,
+                            allowComments, allowUnquotedFieldNames, allowSingleQuotes,
+                            allowNumericLeadingZero, allowBackslashEscapingAnyCharacter,
+                            mode, columnNameOfCorruptRecord)
         if isinstance(path, basestring):
-            path = [path]
             return self._df(self._jreader.json(path))
         else:
             raise TypeError("path can be only a single string")
@@ -943,10 +974,15 @@ class DataStreamReader(object):
 
         .. note:: Experimental.
 
+        >>> parquet_sdf = spark.readStream.schema(sdf_schema)\
+                                          .parquet(os.path.join(tempfile.mkdtemp()))
+        >>> parquet_sdf.isStreaming
+        True
+        >>> parquet_sdf.schema == sdf_schema
+        True
         """
         if isinstance(path, basestring):
-            path = [path]
-            return self._df(self._jreader.parquet(self._spark._sc._jvm.PythonUtils.toSeq(path)))
+            return self._df(self._jreader.parquet(path))
         else:
             raise TypeError("path can be only a single string")
@@ -964,10 +1000,14 @@ class DataStreamReader(object):
 
         :param paths: string, or list of strings, for input path(s).
 
+        >>> text_sdf = spark.readStream.text(os.path.join(tempfile.mkdtemp(), 'data'))
+        >>> text_sdf.isStreaming
+        True
+        >>> "value" in str(text_sdf.schema)
+        True
         """
         if isinstance(path, basestring):
-            path = [path]
-            return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(path)))
+            return self._df(self._jreader.text(path))
         else:
             raise TypeError("path can be only a single string")
@@ -1034,46 +1074,20 @@ class DataStreamReader(object):
                 * ``DROPMALFORMED`` : ignores the whole corrupted records.
                 * ``FAILFAST`` : throws an exception when it meets corrupted records.
 
+        >>> csv_sdf = spark.readStream.csv(os.path.join(tempfile.mkdtemp(), 'data'), \
+                schema = sdf_schema)
+        >>> csv_sdf.isStreaming
+        True
+        >>> csv_sdf.schema == sdf_schema
+        True
         """
-        if schema is not None:
-            self.schema(schema)
-        if sep is not None:
-            self.option("sep", sep)
-        if encoding is not None:
-            self.option("encoding", encoding)
-        if quote is not None:
-            self.option("quote", quote)
-        if escape is not None:
-            self.option("escape", escape)
-        if comment is not None:
-            self.option("comment", comment)
-        if header is not None:
-            self.option("header", header)
-        if inferSchema is not None:
-            self.option("inferSchema", inferSchema)
-        if ignoreLeadingWhiteSpace is not None:
-            self.option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace)
-        if ignoreTrailingWhiteSpace is not None:
-            self.option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace)
-        if nullValue is not None:
-            self.option("nullValue", nullValue)
-        if nanValue is not None:
-            self.option("nanValue", nanValue)
-        if positiveInf is not None:
-            self.option("positiveInf", positiveInf)
-        if negativeInf is not None:
-            self.option("negativeInf", negativeInf)
-        if dateFormat is not None:
-            self.option("dateFormat", dateFormat)
-        if maxColumns is not None:
-            self.option("maxColumns", maxColumns)
-        if maxCharsPerColumn is not None:
-            self.option("maxCharsPerColumn", maxCharsPerColumn)
-        if mode is not None:
-            self.option("mode", mode)
+
+        self._set_csv_opts(schema, sep, encoding, quote, escape,
+                           comment, header, inferSchema, ignoreLeadingWhiteSpace,
+                           ignoreTrailingWhiteSpace, nullValue, nanValue, positiveInf, negativeInf,
+                           dateFormat, maxColumns, maxCharsPerColumn, mode)
         if isinstance(path, basestring):
-            path = [path]
-            return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path)))
+            return self._df(self._jreader.csv(path))
         else:
             raise TypeError("path can be only a single string")
@@ -1286,7 +1300,7 @@ def _test():
     globs['df'] = spark.read.parquet('python/test_support/sql/parquet_partitioned')
     globs['sdf'] = \
         spark.readStream.format('text').load('python/test_support/sql/streaming')
-
+    globs['sdf_schema'] = StructType([StructField("data", StringType(), False)])
     (failure_count, test_count) = doctest.testmod(
         pyspark.sql.readwriter, globs=globs,
         optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
```