aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/pyspark/sql/readwriter.py32
-rw-r--r--python/pyspark/sql/streaming.py32
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala18
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala31
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala18
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala62
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala18
-rw-r--r--sql/core/src/test/resources/test-data/value-malformed.csv2
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala63
9 files changed, 223 insertions, 53 deletions
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 6bed390e60..b5e5b18bcb 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -191,10 +191,13 @@ class DataFrameReader(OptionUtils):
:param mode: allows a mode for dealing with corrupt records during parsing. If None is
set, it uses the default value, ``PERMISSIVE``.
- * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \
- record and puts the malformed string into a new field configured by \
- ``columnNameOfCorruptRecord``. When a schema is set by user, it sets \
- ``null`` for extra fields.
+ * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \
+ record, and puts the malformed string into a field configured by \
+ ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \
+ a string type field named ``columnNameOfCorruptRecord`` in an user-defined \
+ schema. If a schema does not have the field, it drops corrupt records during \
+ parsing. When inferring a schema, it implicitly adds a \
+ ``columnNameOfCorruptRecord`` field in an output schema.
* ``DROPMALFORMED`` : ignores the whole corrupted records.
* ``FAILFAST`` : throws an exception when it meets corrupted records.
@@ -304,7 +307,8 @@ class DataFrameReader(OptionUtils):
comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None,
ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
- maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None):
+ maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None,
+ columnNameOfCorruptRecord=None):
"""Loads a CSV file and returns the result as a :class:`DataFrame`.
This function will go through the input once to determine the input schema if
@@ -366,11 +370,22 @@ class DataFrameReader(OptionUtils):
:param timeZone: sets the string that indicates a timezone to be used to parse timestamps.
If None is set, it uses the default value, session local timezone.
- * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted record.
- When a schema is set by user, it sets ``null`` for extra fields.
+ * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \
+ record, and puts the malformed string into a field configured by \
+ ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \
+ a string type field named ``columnNameOfCorruptRecord`` in an \
+ user-defined schema. If a schema does not have the field, it drops corrupt \
+ records during parsing. When a length of parsed CSV tokens is shorter than \
+ an expected length of a schema, it sets `null` for extra fields.
* ``DROPMALFORMED`` : ignores the whole corrupted records.
* ``FAILFAST`` : throws an exception when it meets corrupted records.
+ :param columnNameOfCorruptRecord: allows renaming the new field having malformed string
+ created by ``PERMISSIVE`` mode. This overrides
+ ``spark.sql.columnNameOfCorruptRecord``. If None is set,
+ it uses the value specified in
+ ``spark.sql.columnNameOfCorruptRecord``.
+
>>> df = spark.read.csv('python/test_support/sql/ages.csv')
>>> df.dtypes
[('_c0', 'string'), ('_c1', 'string')]
@@ -382,7 +397,8 @@ class DataFrameReader(OptionUtils):
nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf,
dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns,
maxCharsPerColumn=maxCharsPerColumn,
- maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone)
+ maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone,
+ columnNameOfCorruptRecord=columnNameOfCorruptRecord)
if isinstance(path, basestring):
path = [path]
return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path)))
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index 965c8f6b26..bd19fd4e38 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -463,10 +463,13 @@ class DataStreamReader(OptionUtils):
:param mode: allows a mode for dealing with corrupt records during parsing. If None is
set, it uses the default value, ``PERMISSIVE``.
- * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \
- record and puts the malformed string into a new field configured by \
- ``columnNameOfCorruptRecord``. When a schema is set by user, it sets \
- ``null`` for extra fields.
+ * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \
+ record, and puts the malformed string into a field configured by \
+ ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \
+ a string type field named ``columnNameOfCorruptRecord`` in an user-defined \
+ schema. If a schema does not have the field, it drops corrupt records during \
+ parsing. When inferring a schema, it implicitly adds a \
+ ``columnNameOfCorruptRecord`` field in an output schema.
* ``DROPMALFORMED`` : ignores the whole corrupted records.
* ``FAILFAST`` : throws an exception when it meets corrupted records.
@@ -558,7 +561,8 @@ class DataStreamReader(OptionUtils):
comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None,
ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
- maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None):
+ maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None,
+ columnNameOfCorruptRecord=None):
"""Loads a CSV file stream and returns the result as a :class:`DataFrame`.
This function will go through the input once to determine the input schema if
@@ -618,11 +622,22 @@ class DataStreamReader(OptionUtils):
:param timeZone: sets the string that indicates a timezone to be used to parse timestamps.
If None is set, it uses the default value, session local timezone.
- * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted record.
- When a schema is set by user, it sets ``null`` for extra fields.
+ * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \
+ record, and puts the malformed string into a field configured by \
+ ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \
+ a string type field named ``columnNameOfCorruptRecord`` in an \
+ user-defined schema. If a schema does not have the field, it drops corrupt \
+ records during parsing. When a length of parsed CSV tokens is shorter than \
+ an expected length of a schema, it sets `null` for extra fields.
* ``DROPMALFORMED`` : ignores the whole corrupted records.
* ``FAILFAST`` : throws an exception when it meets corrupted records.
+ :param columnNameOfCorruptRecord: allows renaming the new field having malformed string
+ created by ``PERMISSIVE`` mode. This overrides
+ ``spark.sql.columnNameOfCorruptRecord``. If None is set,
+ it uses the value specified in
+ ``spark.sql.columnNameOfCorruptRecord``.
+
>>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema)
>>> csv_sdf.isStreaming
True
@@ -636,7 +651,8 @@ class DataStreamReader(OptionUtils):
nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf,
dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns,
maxCharsPerColumn=maxCharsPerColumn,
- maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone)
+ maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone,
+ columnNameOfCorruptRecord=columnNameOfCorruptRecord)
if isinstance(path, basestring):
return self._df(self._jreader.csv(path))
else:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 2be22761e8..59baf6e567 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -286,8 +286,11 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* during parsing.
* <ul>
* <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts
- * the malformed string into a new field configured by `columnNameOfCorruptRecord`. When
- * a schema is set by user, it sets `null` for extra fields.</li>
+ * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep
+ * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord`
+ * in an user-defined schema. If a schema does not have the field, it drops corrupt records
+ * during parsing. When inferring a schema, it implicitly adds a `columnNameOfCorruptRecord`
+ * field in an output schema.</li>
* <li>`DROPMALFORMED` : ignores the whole corrupted records.</li>
* <li>`FAILFAST` : throws an exception when it meets corrupted records.</li>
* </ul>
@@ -447,12 +450,19 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
* during parsing.
* <ul>
- * <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record. When
- * a schema is set by user, it sets `null` for extra fields.</li>
+ * <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts
+ * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep
+ * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord`
+ * in an user-defined schema. If a schema does not have the field, it drops corrupt records
+ * during parsing. When a length of parsed CSV tokens is shorter than an expected length
+ * of a schema, it sets `null` for extra fields.</li>
* <li>`DROPMALFORMED` : ignores the whole corrupted records.</li>
* <li>`FAILFAST` : throws an exception when it meets corrupted records.</li>
* </ul>
* </li>
+ * <li>`columnNameOfCorruptRecord` (default is the value specified in
+ * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
+ * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
* </ul>
* @since 2.0.0
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index 566f40f454..59f2919edf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -27,9 +27,9 @@ import org.apache.hadoop.mapreduce._
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
+import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs}
+import org.apache.spark.sql.catalyst.util.CompressionCodecs
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
import org.apache.spark.sql.sources._
@@ -96,31 +96,44 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
filters: Seq[Filter],
options: Map[String, String],
hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
- val csvOptions = new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone)
-
+ CSVUtils.verifySchema(dataSchema)
val broadcastedHadoopConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
+ val parsedOptions = new CSVOptions(
+ options,
+ sparkSession.sessionState.conf.sessionLocalTimeZone,
+ sparkSession.sessionState.conf.columnNameOfCorruptRecord)
+
+ // Check a field requirement for corrupt records here to throw an exception in a driver side
+ dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
+ val f = dataSchema(corruptFieldIndex)
+ if (f.dataType != StringType || !f.nullable) {
+ throw new AnalysisException(
+ "The field for corrupt records must be string type and nullable")
+ }
+ }
+
(file: PartitionedFile) => {
val lines = {
val conf = broadcastedHadoopConf.value.value
val linesReader = new HadoopFileLinesReader(file, conf)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
linesReader.map { line =>
- new String(line.getBytes, 0, line.getLength, csvOptions.charset)
+ new String(line.getBytes, 0, line.getLength, parsedOptions.charset)
}
}
- val linesWithoutHeader = if (csvOptions.headerFlag && file.start == 0) {
+ val linesWithoutHeader = if (parsedOptions.headerFlag && file.start == 0) {
// Note that if there are only comments in the first block, the header would probably
// be not dropped.
- CSVUtils.dropHeaderLine(lines, csvOptions)
+ CSVUtils.dropHeaderLine(lines, parsedOptions)
} else {
lines
}
- val filteredLines = CSVUtils.filterCommentAndEmpty(linesWithoutHeader, csvOptions)
- val parser = new UnivocityParser(dataSchema, requiredSchema, csvOptions)
+ val filteredLines = CSVUtils.filterCommentAndEmpty(linesWithoutHeader, parsedOptions)
+ val parser = new UnivocityParser(dataSchema, requiredSchema, parsedOptions)
filteredLines.flatMap(parser.parse)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index b7fbaa4f44..1caeec7c63 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -27,11 +27,20 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, ParseModes}
private[csv] class CSVOptions(
- @transient private val parameters: CaseInsensitiveMap[String], defaultTimeZoneId: String)
+ @transient private val parameters: CaseInsensitiveMap[String],
+ defaultTimeZoneId: String,
+ defaultColumnNameOfCorruptRecord: String)
extends Logging with Serializable {
- def this(parameters: Map[String, String], defaultTimeZoneId: String) =
- this(CaseInsensitiveMap(parameters), defaultTimeZoneId)
+ def this(
+ parameters: Map[String, String],
+ defaultTimeZoneId: String,
+ defaultColumnNameOfCorruptRecord: String = "") = {
+ this(
+ CaseInsensitiveMap(parameters),
+ defaultTimeZoneId,
+ defaultColumnNameOfCorruptRecord)
+ }
private def getChar(paramName: String, default: Char): Char = {
val paramValue = parameters.get(paramName)
@@ -95,6 +104,9 @@ private[csv] class CSVOptions(
val dropMalformed = ParseModes.isDropMalformedMode(parseMode)
val permissive = ParseModes.isPermissiveMode(parseMode)
+ val columnNameOfCorruptRecord =
+ parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord)
+
val nullValue = parameters.getOrElse("nullValue", "")
val nanValue = parameters.getOrElse("nanValue", "NaN")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
index 2e409b3f5f..eb471651db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
@@ -45,8 +45,16 @@ private[csv] class UnivocityParser(
// A `ValueConverter` is responsible for converting the given value to a desired type.
private type ValueConverter = String => Any
+ private val corruptFieldIndex = schema.getFieldIndex(options.columnNameOfCorruptRecord)
+ corruptFieldIndex.foreach { corrFieldIndex =>
+ require(schema(corrFieldIndex).dataType == StringType)
+ require(schema(corrFieldIndex).nullable)
+ }
+
+ private val dataSchema = StructType(schema.filter(_.name != options.columnNameOfCorruptRecord))
+
private val valueConverters =
- schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray
+ dataSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray
private val parser = new CsvParser(options.asParserSettings)
@@ -54,7 +62,9 @@ private[csv] class UnivocityParser(
private val row = new GenericInternalRow(requiredSchema.length)
- private val indexArr: Array[Int] = {
+ // This parser loads an `indexArr._1`-th position value in input tokens,
+ // then put the value in `row(indexArr._2)`.
+ private val indexArr: Array[(Int, Int)] = {
val fields = if (options.dropMalformed) {
// If `dropMalformed` is enabled, then it needs to parse all the values
// so that we can decide which row is malformed.
@@ -62,7 +72,17 @@ private[csv] class UnivocityParser(
} else {
requiredSchema
}
- fields.map(schema.indexOf(_: StructField)).toArray
+ // TODO: Revisit this; we need to clean up code here for readability.
+ // See an URL below for related discussions:
+ // https://github.com/apache/spark/pull/16928#discussion_r102636720
+ val fieldsWithIndexes = fields.zipWithIndex
+ corruptFieldIndex.map { case corrFieldIndex =>
+ fieldsWithIndexes.filter { case (_, i) => i != corrFieldIndex }
+ }.getOrElse {
+ fieldsWithIndexes
+ }.map { case (f, i) =>
+ (dataSchema.indexOf(f), i)
+ }.toArray
}
/**
@@ -148,6 +168,7 @@ private[csv] class UnivocityParser(
case udt: UserDefinedType[_] => (datum: String) =>
makeConverter(name, udt.sqlType, nullable, options)
+ // We don't actually hit this exception though, we keep it for understandability
case _ => throw new RuntimeException(s"Unsupported type: ${dataType.typeName}")
}
@@ -172,16 +193,16 @@ private[csv] class UnivocityParser(
* the record is malformed).
*/
def parse(input: String): Option[InternalRow] = {
- convertWithParseMode(parser.parseLine(input)) { tokens =>
+ convertWithParseMode(input) { tokens =>
var i: Int = 0
while (i < indexArr.length) {
- val pos = indexArr(i)
+ val (pos, rowIdx) = indexArr(i)
// It anyway needs to try to parse since it decides if this row is malformed
// or not after trying to cast in `DROPMALFORMED` mode even if the casted
// value is not stored in the row.
val value = valueConverters(pos).apply(tokens(pos))
if (i < requiredSchema.length) {
- row(i) = value
+ row(rowIdx) = value
}
i += 1
}
@@ -190,8 +211,9 @@ private[csv] class UnivocityParser(
}
private def convertWithParseMode(
- tokens: Array[String])(convert: Array[String] => InternalRow): Option[InternalRow] = {
- if (options.dropMalformed && schema.length != tokens.length) {
+ input: String)(convert: Array[String] => InternalRow): Option[InternalRow] = {
+ val tokens = parser.parseLine(input)
+ if (options.dropMalformed && dataSchema.length != tokens.length) {
if (numMalformedRecords < options.maxMalformedLogPerPartition) {
logWarning(s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}")
}
@@ -202,14 +224,24 @@ private[csv] class UnivocityParser(
}
numMalformedRecords += 1
None
- } else if (options.failFast && schema.length != tokens.length) {
+ } else if (options.failFast && dataSchema.length != tokens.length) {
throw new RuntimeException(s"Malformed line in FAILFAST mode: " +
s"${tokens.mkString(options.delimiter.toString)}")
} else {
- val checkedTokens = if (options.permissive && schema.length > tokens.length) {
- tokens ++ new Array[String](schema.length - tokens.length)
- } else if (options.permissive && schema.length < tokens.length) {
- tokens.take(schema.length)
+ // If a length of parsed tokens is not equal to expected one, it makes the length the same
+ // with the expected. If the length is shorter, it adds extra tokens in the tail.
+ // If longer, it drops extra tokens.
+ //
+ // TODO: Revisit this; if a length of tokens does not match an expected length in the schema,
+ // we probably need to treat it as a malformed record.
+ // See an URL below for related discussions:
+ // https://github.com/apache/spark/pull/16928#discussion_r102657214
+ val checkedTokens = if (options.permissive && dataSchema.length != tokens.length) {
+ if (dataSchema.length > tokens.length) {
+ tokens ++ new Array[String](dataSchema.length - tokens.length)
+ } else {
+ tokens.take(dataSchema.length)
+ }
} else {
tokens
}
@@ -217,6 +249,10 @@ private[csv] class UnivocityParser(
try {
Some(convert(checkedTokens))
} catch {
+ case NonFatal(e) if options.permissive =>
+ val row = new GenericInternalRow(requiredSchema.length)
+ corruptFieldIndex.foreach(row(_) = UTF8String.fromString(input))
+ Some(row)
case NonFatal(e) if options.dropMalformed =>
if (numMalformedRecords < options.maxMalformedLogPerPartition) {
logWarning("Parse exception. " +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index 99943944f3..f78e73f319 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -168,8 +168,11 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* during parsing.
* <ul>
* <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts
- * the malformed string into a new field configured by `columnNameOfCorruptRecord`. When
- * a schema is set by user, it sets `null` for extra fields.</li>
+ * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep
+ * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord`
+ * in an user-defined schema. If a schema does not have the field, it drops corrupt records
+ * during parsing. When inferring a schema, it implicitly adds a `columnNameOfCorruptRecord`
+ * field in an output schema.</li>
* <li>`DROPMALFORMED` : ignores the whole corrupted records.</li>
* <li>`FAILFAST` : throws an exception when it meets corrupted records.</li>
* </ul>
@@ -245,12 +248,19 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
* during parsing.
* <ul>
- * <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record. When
- * a schema is set by user, it sets `null` for extra fields.</li>
+ * <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts
+ * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep
+ * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord`
+ * in an user-defined schema. If a schema does not have the field, it drops corrupt records
+ * during parsing. When a length of parsed CSV tokens is shorter than an expected length
+ * of a schema, it sets `null` for extra fields.</li>
* <li>`DROPMALFORMED` : ignores the whole corrupted records.</li>
* <li>`FAILFAST` : throws an exception when it meets corrupted records.</li>
* </ul>
* </li>
+ * <li>`columnNameOfCorruptRecord` (default is the value specified in
+ * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
+ * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
* </ul>
*
* @since 2.0.0
diff --git a/sql/core/src/test/resources/test-data/value-malformed.csv b/sql/core/src/test/resources/test-data/value-malformed.csv
new file mode 100644
index 0000000000..8945ed73d2
--- /dev/null
+++ b/sql/core/src/test/resources/test-data/value-malformed.csv
@@ -0,0 +1,2 @@
+0,2013-111-11 12:13:14
+1,1983-08-04
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 0c9a7298c3..371d4311ba 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -28,7 +28,7 @@ import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.io.compress.GzipCodec
import org.apache.spark.SparkException
-import org.apache.spark.sql.{DataFrame, QueryTest, Row, UDT}
+import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
import org.apache.spark.sql.types._
@@ -53,6 +53,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
private val numbersFile = "test-data/numbers.csv"
private val datesFile = "test-data/dates.csv"
private val unescapedQuotesFile = "test-data/unescaped-quotes.csv"
+ private val valueMalformedFile = "test-data/value-malformed.csv"
private def testFile(fileName: String): String = {
Thread.currentThread().getContextClassLoader.getResource(fileName).toString
@@ -700,12 +701,12 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}.getMessage
assert(msg.contains("CSV data source does not support array<double> data type"))
- msg = intercept[SparkException] {
+ msg = intercept[UnsupportedOperationException] {
val schema = StructType(StructField("a", new UDT.MyDenseVectorUDT(), true) :: Nil)
spark.range(1).write.csv(csvDir)
spark.read.schema(schema).csv(csvDir).collect()
- }.getCause.getMessage
- assert(msg.contains("Unsupported type: array"))
+ }.getMessage
+ assert(msg.contains("CSV data source does not support array<double> data type."))
}
}
@@ -958,4 +959,58 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
checkAnswer(df, Row(1, null))
}
}
+
+ test("SPARK-18699 put malformed records in a `columnNameOfCorruptRecord` field") {
+ val schema = new StructType().add("a", IntegerType).add("b", TimestampType)
+ val df1 = spark
+ .read
+ .option("mode", "PERMISSIVE")
+ .schema(schema)
+ .csv(testFile(valueMalformedFile))
+ checkAnswer(df1,
+ Row(null, null) ::
+ Row(1, java.sql.Date.valueOf("1983-08-04")) ::
+ Nil)
+
+ // If `schema` has `columnNameOfCorruptRecord`, it should handle corrupt records
+ val columnNameOfCorruptRecord = "_unparsed"
+ val schemaWithCorrField1 = schema.add(columnNameOfCorruptRecord, StringType)
+ val df2 = spark
+ .read
+ .option("mode", "PERMISSIVE")
+ .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+ .schema(schemaWithCorrField1)
+ .csv(testFile(valueMalformedFile))
+ checkAnswer(df2,
+ Row(null, null, "0,2013-111-11 12:13:14") ::
+ Row(1, java.sql.Date.valueOf("1983-08-04"), null) ::
+ Nil)
+
+ // We put a `columnNameOfCorruptRecord` field in the middle of a schema
+ val schemaWithCorrField2 = new StructType()
+ .add("a", IntegerType)
+ .add(columnNameOfCorruptRecord, StringType)
+ .add("b", TimestampType)
+ val df3 = spark
+ .read
+ .option("mode", "PERMISSIVE")
+ .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+ .schema(schemaWithCorrField2)
+ .csv(testFile(valueMalformedFile))
+ checkAnswer(df3,
+ Row(null, "0,2013-111-11 12:13:14", null) ::
+ Row(1, null, java.sql.Date.valueOf("1983-08-04")) ::
+ Nil)
+
+ val errMsg = intercept[AnalysisException] {
+ spark
+ .read
+ .option("mode", "PERMISSIVE")
+ .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+ .schema(schema.add(columnNameOfCorruptRecord, IntegerType))
+ .csv(testFile(valueMalformedFile))
+ .collect
+ }.getMessage
+ assert(errMsg.startsWith("The field for corrupt records must be string type and nullable"))
+ }
}