From 7e5359be5ca038fdb579712b18e7f226d705c276 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 28 Feb 2017 13:34:33 -0800 Subject: [SPARK-19610][SQL] Support parsing multiline CSV files ## What changes were proposed in this pull request? This PR proposes support for multi-line records in CSV, resembling the multi-line support in the JSON datasource (in case of JSON, per file). So, this PR introduces a `wholeFile` option which makes the format not splittable and reads each whole file. Since the Univocity parser can produce each row from a stream, it should be capable of parsing very large documents as long as the individual rows fit in memory. ## How was this patch tested? Unit tests in `CSVSuite` and `tests.py` Manual tests with a single 9GB CSV file in local file system, for example, ```scala spark.read.option("wholeFile", true).option("inferSchema", true).csv("tmp.csv").count() ``` Author: hyukjinkwon Closes #16976 from HyukjinKwon/SPARK-19610. --- python/pyspark/sql/readwriter.py | 6 ++++-- python/pyspark/sql/streaming.py | 6 ++++-- python/pyspark/sql/tests.py | 9 ++++++++- 3 files changed, 16 insertions(+), 5 deletions(-) (limited to 'python/pyspark') diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index b5e5b18bcb..ec47618e73 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -308,7 +308,7 @@ class DataFrameReader(OptionUtils): ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None, - columnNameOfCorruptRecord=None): + columnNameOfCorruptRecord=None, wholeFile=None): """Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -385,6 +385,8 @@ class DataFrameReader(OptionUtils): ``spark.sql.columnNameOfCorruptRecord``. 
If None is set, it uses the value specified in ``spark.sql.columnNameOfCorruptRecord``. + :param wholeFile: parse records, which may span multiple lines. If None is + set, it uses the default value, ``false``. >>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes @@ -398,7 +400,7 @@ class DataFrameReader(OptionUtils): dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone, - columnNameOfCorruptRecord=columnNameOfCorruptRecord) + columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile) if isinstance(path, basestring): path = [path] return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path))) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index bd19fd4e38..7587875cb9 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -562,7 +562,7 @@ class DataStreamReader(OptionUtils): ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None, - columnNameOfCorruptRecord=None): + columnNameOfCorruptRecord=None, wholeFile=None): """Loads a CSV file stream and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -637,6 +637,8 @@ class DataStreamReader(OptionUtils): ``spark.sql.columnNameOfCorruptRecord``. If None is set, it uses the value specified in ``spark.sql.columnNameOfCorruptRecord``. + :param wholeFile: parse one record, which may span multiple lines. If None is + set, it uses the default value, ``false``. 
>>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema) >>> csv_sdf.isStreaming @@ -652,7 +654,7 @@ class DataStreamReader(OptionUtils): dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone, - columnNameOfCorruptRecord=columnNameOfCorruptRecord) + columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) else: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index fd083e4868..e943f8da3d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -437,12 +437,19 @@ class SQLTests(ReusedPySparkTestCase): self.assertEqual(res.collect(), [Row(id=0, copy=0)]) def test_wholefile_json(self): - from pyspark.sql.types import StringType people1 = self.spark.read.json("python/test_support/sql/people.json") people_array = self.spark.read.json("python/test_support/sql/people_array.json", wholeFile=True) self.assertEqual(people1.collect(), people_array.collect()) + def test_wholefile_csv(self): + ages_newlines = self.spark.read.csv( + "python/test_support/sql/ages_newlines.csv", wholeFile=True) + expected = [Row(_c0=u'Joe', _c1=u'20', _c2=u'Hi,\nI am Jeo'), + Row(_c0=u'Tom', _c1=u'30', _c2=u'My name is Tom'), + Row(_c0=u'Hyukjin', _c1=u'25', _c2=u'I am Hyukjin\n\nI love Spark!')] + self.assertEqual(ages_newlines.collect(), expected) + def test_udf_with_input_file_name(self): from pyspark.sql.functions import udf, input_file_name from pyspark.sql.types import StringType -- cgit v1.2.3