diff options
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/sql/readwriter.py | 6 | ||||
-rw-r--r-- | python/pyspark/sql/streaming.py | 6 | ||||
-rw-r--r-- | python/pyspark/sql/tests.py | 9 |
3 files changed, 16 insertions, 5 deletions
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index b5e5b18bcb..ec47618e73 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -308,7 +308,7 @@ class DataFrameReader(OptionUtils): ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None, - columnNameOfCorruptRecord=None): + columnNameOfCorruptRecord=None, wholeFile=None): """Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -385,6 +385,8 @@ class DataFrameReader(OptionUtils): ``spark.sql.columnNameOfCorruptRecord``. If None is set, it uses the value specified in ``spark.sql.columnNameOfCorruptRecord``. + :param wholeFile: parse records, which may span multiple lines. If None is + set, it uses the default value, ``false``. >>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes @@ -398,7 +400,7 @@ class DataFrameReader(OptionUtils): dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone, - columnNameOfCorruptRecord=columnNameOfCorruptRecord) + columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile) if isinstance(path, basestring): path = [path] return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path))) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index bd19fd4e38..7587875cb9 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -562,7 +562,7 @@ class DataStreamReader(OptionUtils): ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None, - columnNameOfCorruptRecord=None): + columnNameOfCorruptRecord=None, wholeFile=None): """Loads a CSV file stream and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -637,6 +637,8 @@ class DataStreamReader(OptionUtils): ``spark.sql.columnNameOfCorruptRecord``. If None is set, it uses the value specified in ``spark.sql.columnNameOfCorruptRecord``. + :param wholeFile: parse one record, which may span multiple lines. If None is + set, it uses the default value, ``false``. >>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema) >>> csv_sdf.isStreaming @@ -652,7 +654,7 @@ class DataStreamReader(OptionUtils): dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone, - columnNameOfCorruptRecord=columnNameOfCorruptRecord) + columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) else: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index fd083e4868..e943f8da3d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -437,12 +437,19 @@ class SQLTests(ReusedPySparkTestCase): self.assertEqual(res.collect(), [Row(id=0, copy=0)]) def test_wholefile_json(self): - from pyspark.sql.types import StringType people1 = self.spark.read.json("python/test_support/sql/people.json") people_array = self.spark.read.json("python/test_support/sql/people_array.json", wholeFile=True) self.assertEqual(people1.collect(), people_array.collect()) + def test_wholefile_csv(self): + ages_newlines = self.spark.read.csv( + "python/test_support/sql/ages_newlines.csv", wholeFile=True) + expected = [Row(_c0=u'Joe', _c1=u'20', _c2=u'Hi,\nI am Jeo'), + Row(_c0=u'Tom', _c1=u'30', _c2=u'My name is Tom'), + Row(_c0=u'Hyukjin', _c1=u'25', _c2=u'I am Hyukjin\n\nI love Spark!')] + self.assertEqual(ages_newlines.collect(), expected) + def test_udf_with_input_file_name(self): from pyspark.sql.functions import udf, input_file_name from pyspark.sql.types import StringType |