Diffstat (limited to 'python/pyspark')
-rw-r--r--  python/pyspark/sql/readwriter.py  6
-rw-r--r--  python/pyspark/sql/streaming.py   6
-rw-r--r--  python/pyspark/sql/tests.py       9
3 files changed, 16 insertions(+), 5 deletions(-)
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index b5e5b18bcb..ec47618e73 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -308,7 +308,7 @@ class DataFrameReader(OptionUtils):
ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None,
- columnNameOfCorruptRecord=None):
+ columnNameOfCorruptRecord=None, wholeFile=None):
"""Loads a CSV file and returns the result as a :class:`DataFrame`.
This function will go through the input once to determine the input schema if
@@ -385,6 +385,8 @@ class DataFrameReader(OptionUtils):
``spark.sql.columnNameOfCorruptRecord``. If None is set,
it uses the value specified in
``spark.sql.columnNameOfCorruptRecord``.
+ :param wholeFile: parse records, which may span multiple lines. If None is
+ set, it uses the default value, ``false``.

>>> df = spark.read.csv('python/test_support/sql/ages.csv')
>>> df.dtypes
@@ -398,7 +400,7 @@ class DataFrameReader(OptionUtils):
dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns,
maxCharsPerColumn=maxCharsPerColumn,
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone,
- columnNameOfCorruptRecord=columnNameOfCorruptRecord)
+ columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile)
if isinstance(path, basestring):
path = [path]
return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path)))
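For context, a minimal batch usage sketch of the new option (the path and file contents below are hypothetical): with ``wholeFile=True`` a quoted CSV field may contain embedded newlines, whereas the default behavior treats every physical line as one record.

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    # Hypothetical input: a CSV whose quoted fields contain embedded newlines.
    # wholeFile=True lets a single record span multiple physical lines.
    df = spark.read.csv("/tmp/multiline.csv", wholeFile=True)
    df.show(truncate=False)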
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index bd19fd4e38..7587875cb9 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -562,7 +562,7 @@ class DataStreamReader(OptionUtils):
ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None,
- columnNameOfCorruptRecord=None):
+ columnNameOfCorruptRecord=None, wholeFile=None):
"""Loads a CSV file stream and returns the result as a :class:`DataFrame`.
This function will go through the input once to determine the input schema if
@@ -637,6 +637,8 @@ class DataStreamReader(OptionUtils):
``spark.sql.columnNameOfCorruptRecord``. If None is set,
it uses the value specified in
``spark.sql.columnNameOfCorruptRecord``.
+ :param wholeFile: parse records, which may span multiple lines. If None is
+ set, it uses the default value, ``false``.

>>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema)
>>> csv_sdf.isStreaming
@@ -652,7 +654,7 @@ class DataStreamReader(OptionUtils):
dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns,
maxCharsPerColumn=maxCharsPerColumn,
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone,
- columnNameOfCorruptRecord=columnNameOfCorruptRecord)
+ columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile)
if isinstance(path, basestring):
return self._df(self._jreader.csv(path))
else:
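A corresponding streaming sketch; unlike the batch reader, a streaming CSV source requires a caller-supplied schema, and the schema and source directory here are assumptions:

    from pyspark.sql import SparkSession
    from pyspark.sql.types import StructType, StructField, StringType

    spark = SparkSession.builder.getOrCreate()
    # Hypothetical schema and directory; streaming sources cannot infer a
    # schema at runtime, so one must be provided up front.
    schema = StructType([StructField("name", StringType(), True),
                         StructField("comment", StringType(), True)])
    csv_sdf = spark.readStream.csv("/tmp/csv_dir", schema=schema, wholeFile=True)
    assert csv_sdf.isStreaming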
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index fd083e4868..e943f8da3d 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -437,12 +437,19 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual(res.collect(), [Row(id=0, copy=0)])

def test_wholefile_json(self):
- from pyspark.sql.types import StringType
people1 = self.spark.read.json("python/test_support/sql/people.json")
people_array = self.spark.read.json("python/test_support/sql/people_array.json",
wholeFile=True)
self.assertEqual(people1.collect(), people_array.collect())

+ def test_wholefile_csv(self):
+ ages_newlines = self.spark.read.csv(
+ "python/test_support/sql/ages_newlines.csv", wholeFile=True)
+ expected = [Row(_c0=u'Joe', _c1=u'20', _c2=u'Hi,\nI am Jeo'),
+ Row(_c0=u'Tom', _c1=u'30', _c2=u'My name is Tom'),
+ Row(_c0=u'Hyukjin', _c1=u'25', _c2=u'I am Hyukjin\n\nI love Spark!')]
+ self.assertEqual(ages_newlines.collect(), expected)
+
def test_udf_with_input_file_name(self):
from pyspark.sql.functions import udf, input_file_name
from pyspark.sql.types import StringType
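The fixture ``python/test_support/sql/ages_newlines.csv`` read by ``test_wholefile_csv`` is not included in this diff. Working backwards from the expected rows, it would have to look roughly like the following (the exact quoting is an assumption; note the expected rows spell ``Jeo``):

    Joe,20,"Hi,
    I am Jeo"
    Tom,30,My name is Tom
    Hyukjin,25,"I am Hyukjin

    I love Spark!"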