9 files changed, 223 insertions, 53 deletions
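The change below lets CSV reads keep malformed rows instead of silently nulling or dropping them: in PERMISSIVE mode the raw input line is stored in the column named by the new columnNameOfCorruptRecord option (falling back to spark.sql.columnNameOfCorruptRecord), provided the user schema declares that column as a nullable string field. The following is a minimal sketch of the new behaviour, mirroring the test added to CSVSuite and the value-malformed.csv fixture in this commit; the local SparkSession setup, the object/main wrapper, and the file path are illustrative assumptions, not part of the change.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

object CsvCorruptRecordExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("csv-corrupt-record")
      .getOrCreate()

    // test-data/value-malformed.csv (added by this commit):
    //   0,2013-111-11 12:13:14   <- "2013-111-11 ..." is not a valid timestamp
    //   1,1983-08-04
    // The corrupt-record column must be declared by the user, and it must be a
    // nullable string field; otherwise the read fails with an AnalysisException
    // (see the check added to CSVFileFormat below).
    val schema = new StructType()
      .add("a", IntegerType)
      .add("b", TimestampType)
      .add("_unparsed", StringType)

    val df = spark.read
      .option("mode", "PERMISSIVE")
      .option("columnNameOfCorruptRecord", "_unparsed")
      .schema(schema)
      .csv("sql/core/src/test/resources/test-data/value-malformed.csv")

    // Per the new CSVSuite test: the malformed first row comes back as
    // (null, null, "0,2013-111-11 12:13:14"); the second row parses normally
    // and its "_unparsed" column is null.
    df.show(truncate = false)
    spark.stop()
  }
}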
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 6bed390e60..b5e5b18bcb 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -191,10 +191,13 @@ class DataFrameReader(OptionUtils): :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ - record and puts the malformed string into a new field configured by \ - ``columnNameOfCorruptRecord``. When a schema is set by user, it sets \ - ``null`` for extra fields. + * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ + record, and puts the malformed string into a field configured by \ + ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ + a string type field named ``columnNameOfCorruptRecord`` in an user-defined \ + schema. If a schema does not have the field, it drops corrupt records during \ + parsing. When inferring a schema, it implicitly adds a \ + ``columnNameOfCorruptRecord`` field in an output schema. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. @@ -304,7 +307,8 @@ class DataFrameReader(OptionUtils): comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, - maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None): + maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None, + columnNameOfCorruptRecord=None): """Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -366,11 +370,22 @@ class DataFrameReader(OptionUtils): :param timeZone: sets the string that indicates a timezone to be used to parse timestamps. If None is set, it uses the default value, session local timezone. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted record. - When a schema is set by user, it sets ``null`` for extra fields. + * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ + record, and puts the malformed string into a field configured by \ + ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ + a string type field named ``columnNameOfCorruptRecord`` in an \ + user-defined schema. If a schema does not have the field, it drops corrupt \ + records during parsing. When a length of parsed CSV tokens is shorter than \ + an expected length of a schema, it sets `null` for extra fields. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. + :param columnNameOfCorruptRecord: allows renaming the new field having malformed string + created by ``PERMISSIVE`` mode. This overrides + ``spark.sql.columnNameOfCorruptRecord``. If None is set, + it uses the value specified in + ``spark.sql.columnNameOfCorruptRecord``. 
+ >>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes [('_c0', 'string'), ('_c1', 'string')] @@ -382,7 +397,8 @@ class DataFrameReader(OptionUtils): nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf, dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, - maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone) + maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone, + columnNameOfCorruptRecord=columnNameOfCorruptRecord) if isinstance(path, basestring): path = [path] return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path))) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 965c8f6b26..bd19fd4e38 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -463,10 +463,13 @@ class DataStreamReader(OptionUtils): :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ - record and puts the malformed string into a new field configured by \ - ``columnNameOfCorruptRecord``. When a schema is set by user, it sets \ - ``null`` for extra fields. + * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ + record, and puts the malformed string into a field configured by \ + ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ + a string type field named ``columnNameOfCorruptRecord`` in an user-defined \ + schema. If a schema does not have the field, it drops corrupt records during \ + parsing. When inferring a schema, it implicitly adds a \ + ``columnNameOfCorruptRecord`` field in an output schema. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. @@ -558,7 +561,8 @@ class DataStreamReader(OptionUtils): comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, - maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None): + maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None, + columnNameOfCorruptRecord=None): """Loads a CSV file stream and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -618,11 +622,22 @@ class DataStreamReader(OptionUtils): :param timeZone: sets the string that indicates a timezone to be used to parse timestamps. If None is set, it uses the default value, session local timezone. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted record. - When a schema is set by user, it sets ``null`` for extra fields. + * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ + record, and puts the malformed string into a field configured by \ + ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ + a string type field named ``columnNameOfCorruptRecord`` in an \ + user-defined schema. If a schema does not have the field, it drops corrupt \ + records during parsing. When a length of parsed CSV tokens is shorter than \ + an expected length of a schema, it sets `null` for extra fields. 
* ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. + :param columnNameOfCorruptRecord: allows renaming the new field having malformed string + created by ``PERMISSIVE`` mode. This overrides + ``spark.sql.columnNameOfCorruptRecord``. If None is set, + it uses the value specified in + ``spark.sql.columnNameOfCorruptRecord``. + >>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema) >>> csv_sdf.isStreaming True @@ -636,7 +651,8 @@ class DataStreamReader(OptionUtils): nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf, dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, - maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone) + maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone, + columnNameOfCorruptRecord=columnNameOfCorruptRecord) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) else: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 2be22761e8..59baf6e567 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -286,8 +286,11 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * during parsing. * <ul> * <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts - * the malformed string into a new field configured by `columnNameOfCorruptRecord`. When - * a schema is set by user, it sets `null` for extra fields.</li> + * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep + * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord` + * in an user-defined schema. If a schema does not have the field, it drops corrupt records + * during parsing. When inferring a schema, it implicitly adds a `columnNameOfCorruptRecord` + * field in an output schema.</li> * <li>`DROPMALFORMED` : ignores the whole corrupted records.</li> * <li>`FAILFAST` : throws an exception when it meets corrupted records.</li> * </ul> @@ -447,12 +450,19 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records * during parsing. * <ul> - * <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record. When - * a schema is set by user, it sets `null` for extra fields.</li> + * <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts + * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep + * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord` + * in an user-defined schema. If a schema does not have the field, it drops corrupt records + * during parsing. 
When a length of parsed CSV tokens is shorter than an expected length + * of a schema, it sets `null` for extra fields.</li> * <li>`DROPMALFORMED` : ignores the whole corrupted records.</li> * <li>`FAILFAST` : throws an exception when it meets corrupted records.</li> * </ul> * </li> + * <li>`columnNameOfCorruptRecord` (default is the value specified in + * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string + * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li> * </ul> * @since 2.0.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 566f40f454..59f2919edf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -27,9 +27,9 @@ import org.apache.hadoop.mapreduce._ import org.apache.spark.TaskContext import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Dataset, Encoders, SparkSession} +import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs} +import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.sources._ @@ -96,31 +96,44 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { - val csvOptions = new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) - + CSVUtils.verifySchema(dataSchema) val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + val parsedOptions = new CSVOptions( + options, + sparkSession.sessionState.conf.sessionLocalTimeZone, + sparkSession.sessionState.conf.columnNameOfCorruptRecord) + + // Check a field requirement for corrupt records here to throw an exception in a driver side + dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex => + val f = dataSchema(corruptFieldIndex) + if (f.dataType != StringType || !f.nullable) { + throw new AnalysisException( + "The field for corrupt records must be string type and nullable") + } + } + (file: PartitionedFile) => { val lines = { val conf = broadcastedHadoopConf.value.value val linesReader = new HadoopFileLinesReader(file, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) linesReader.map { line => - new String(line.getBytes, 0, line.getLength, csvOptions.charset) + new String(line.getBytes, 0, line.getLength, parsedOptions.charset) } } - val linesWithoutHeader = if (csvOptions.headerFlag && file.start == 0) { + val linesWithoutHeader = if (parsedOptions.headerFlag && file.start == 0) { // Note that if there are only comments in the first block, the header would probably // be not dropped. 
- CSVUtils.dropHeaderLine(lines, csvOptions) + CSVUtils.dropHeaderLine(lines, parsedOptions) } else { lines } - val filteredLines = CSVUtils.filterCommentAndEmpty(linesWithoutHeader, csvOptions) - val parser = new UnivocityParser(dataSchema, requiredSchema, csvOptions) + val filteredLines = CSVUtils.filterCommentAndEmpty(linesWithoutHeader, parsedOptions) + val parser = new UnivocityParser(dataSchema, requiredSchema, parsedOptions) filteredLines.flatMap(parser.parse) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index b7fbaa4f44..1caeec7c63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -27,11 +27,20 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, ParseModes} private[csv] class CSVOptions( - @transient private val parameters: CaseInsensitiveMap[String], defaultTimeZoneId: String) + @transient private val parameters: CaseInsensitiveMap[String], + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String) extends Logging with Serializable { - def this(parameters: Map[String, String], defaultTimeZoneId: String) = - this(CaseInsensitiveMap(parameters), defaultTimeZoneId) + def this( + parameters: Map[String, String], + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String = "") = { + this( + CaseInsensitiveMap(parameters), + defaultTimeZoneId, + defaultColumnNameOfCorruptRecord) + } private def getChar(paramName: String, default: Char): Char = { val paramValue = parameters.get(paramName) @@ -95,6 +104,9 @@ private[csv] class CSVOptions( val dropMalformed = ParseModes.isDropMalformedMode(parseMode) val permissive = ParseModes.isPermissiveMode(parseMode) + val columnNameOfCorruptRecord = + parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) + val nullValue = parameters.getOrElse("nullValue", "") val nanValue = parameters.getOrElse("nanValue", "NaN") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 2e409b3f5f..eb471651db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -45,8 +45,16 @@ private[csv] class UnivocityParser( // A `ValueConverter` is responsible for converting the given value to a desired type. 
private type ValueConverter = String => Any + private val corruptFieldIndex = schema.getFieldIndex(options.columnNameOfCorruptRecord) + corruptFieldIndex.foreach { corrFieldIndex => + require(schema(corrFieldIndex).dataType == StringType) + require(schema(corrFieldIndex).nullable) + } + + private val dataSchema = StructType(schema.filter(_.name != options.columnNameOfCorruptRecord)) + private val valueConverters = - schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray + dataSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray private val parser = new CsvParser(options.asParserSettings) @@ -54,7 +62,9 @@ private[csv] class UnivocityParser( private val row = new GenericInternalRow(requiredSchema.length) - private val indexArr: Array[Int] = { + // This parser loads an `indexArr._1`-th position value in input tokens, + // then put the value in `row(indexArr._2)`. + private val indexArr: Array[(Int, Int)] = { val fields = if (options.dropMalformed) { // If `dropMalformed` is enabled, then it needs to parse all the values // so that we can decide which row is malformed. @@ -62,7 +72,17 @@ private[csv] class UnivocityParser( } else { requiredSchema } - fields.map(schema.indexOf(_: StructField)).toArray + // TODO: Revisit this; we need to clean up code here for readability. + // See an URL below for related discussions: + // https://github.com/apache/spark/pull/16928#discussion_r102636720 + val fieldsWithIndexes = fields.zipWithIndex + corruptFieldIndex.map { case corrFieldIndex => + fieldsWithIndexes.filter { case (_, i) => i != corrFieldIndex } + }.getOrElse { + fieldsWithIndexes + }.map { case (f, i) => + (dataSchema.indexOf(f), i) + }.toArray } /** @@ -148,6 +168,7 @@ private[csv] class UnivocityParser( case udt: UserDefinedType[_] => (datum: String) => makeConverter(name, udt.sqlType, nullable, options) + // We don't actually hit this exception though, we keep it for understandability case _ => throw new RuntimeException(s"Unsupported type: ${dataType.typeName}") } @@ -172,16 +193,16 @@ private[csv] class UnivocityParser( * the record is malformed). */ def parse(input: String): Option[InternalRow] = { - convertWithParseMode(parser.parseLine(input)) { tokens => + convertWithParseMode(input) { tokens => var i: Int = 0 while (i < indexArr.length) { - val pos = indexArr(i) + val (pos, rowIdx) = indexArr(i) // It anyway needs to try to parse since it decides if this row is malformed // or not after trying to cast in `DROPMALFORMED` mode even if the casted // value is not stored in the row. 
val value = valueConverters(pos).apply(tokens(pos)) if (i < requiredSchema.length) { - row(i) = value + row(rowIdx) = value } i += 1 } @@ -190,8 +211,9 @@ private[csv] class UnivocityParser( } private def convertWithParseMode( - tokens: Array[String])(convert: Array[String] => InternalRow): Option[InternalRow] = { - if (options.dropMalformed && schema.length != tokens.length) { + input: String)(convert: Array[String] => InternalRow): Option[InternalRow] = { + val tokens = parser.parseLine(input) + if (options.dropMalformed && dataSchema.length != tokens.length) { if (numMalformedRecords < options.maxMalformedLogPerPartition) { logWarning(s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}") } @@ -202,14 +224,24 @@ private[csv] class UnivocityParser( } numMalformedRecords += 1 None - } else if (options.failFast && schema.length != tokens.length) { + } else if (options.failFast && dataSchema.length != tokens.length) { throw new RuntimeException(s"Malformed line in FAILFAST mode: " + s"${tokens.mkString(options.delimiter.toString)}") } else { - val checkedTokens = if (options.permissive && schema.length > tokens.length) { - tokens ++ new Array[String](schema.length - tokens.length) - } else if (options.permissive && schema.length < tokens.length) { - tokens.take(schema.length) + // If a length of parsed tokens is not equal to expected one, it makes the length the same + // with the expected. If the length is shorter, it adds extra tokens in the tail. + // If longer, it drops extra tokens. + // + // TODO: Revisit this; if a length of tokens does not match an expected length in the schema, + // we probably need to treat it as a malformed record. + // See an URL below for related discussions: + // https://github.com/apache/spark/pull/16928#discussion_r102657214 + val checkedTokens = if (options.permissive && dataSchema.length != tokens.length) { + if (dataSchema.length > tokens.length) { + tokens ++ new Array[String](dataSchema.length - tokens.length) + } else { + tokens.take(dataSchema.length) + } } else { tokens } @@ -217,6 +249,10 @@ private[csv] class UnivocityParser( try { Some(convert(checkedTokens)) } catch { + case NonFatal(e) if options.permissive => + val row = new GenericInternalRow(requiredSchema.length) + corruptFieldIndex.foreach(row(_) = UTF8String.fromString(input)) + Some(row) case NonFatal(e) if options.dropMalformed => if (numMalformedRecords < options.maxMalformedLogPerPartition) { logWarning("Parse exception. " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 99943944f3..f78e73f319 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -168,8 +168,11 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * during parsing. * <ul> * <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts - * the malformed string into a new field configured by `columnNameOfCorruptRecord`. When - * a schema is set by user, it sets `null` for extra fields.</li> + * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep + * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord` + * in an user-defined schema. If a schema does not have the field, it drops corrupt records + * during parsing. 
When inferring a schema, it implicitly adds a `columnNameOfCorruptRecord` + * field in an output schema.</li> * <li>`DROPMALFORMED` : ignores the whole corrupted records.</li> * <li>`FAILFAST` : throws an exception when it meets corrupted records.</li> * </ul> @@ -245,12 +248,19 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records * during parsing. * <ul> - * <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record. When - * a schema is set by user, it sets `null` for extra fields.</li> + * <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts + * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep + * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord` + * in an user-defined schema. If a schema does not have the field, it drops corrupt records + * during parsing. When a length of parsed CSV tokens is shorter than an expected length + * of a schema, it sets `null` for extra fields.</li> * <li>`DROPMALFORMED` : ignores the whole corrupted records.</li> * <li>`FAILFAST` : throws an exception when it meets corrupted records.</li> * </ul> * </li> + * <li>`columnNameOfCorruptRecord` (default is the value specified in + * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string + * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li> * </ul> * * @since 2.0.0 diff --git a/sql/core/src/test/resources/test-data/value-malformed.csv b/sql/core/src/test/resources/test-data/value-malformed.csv new file mode 100644 index 0000000000..8945ed73d2 --- /dev/null +++ b/sql/core/src/test/resources/test-data/value-malformed.csv @@ -0,0 +1,2 @@ +0,2013-111-11 12:13:14 +1,1983-08-04 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 0c9a7298c3..371d4311ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.SparkException -import org.apache.spark.sql.{DataFrame, QueryTest, Row, UDT} +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ @@ -53,6 +53,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { private val numbersFile = "test-data/numbers.csv" private val datesFile = "test-data/dates.csv" private val unescapedQuotesFile = "test-data/unescaped-quotes.csv" + private val valueMalformedFile = "test-data/value-malformed.csv" private def testFile(fileName: String): String = { Thread.currentThread().getContextClassLoader.getResource(fileName).toString @@ -700,12 +701,12 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { }.getMessage assert(msg.contains("CSV data source does not support array<double> data type")) - msg = intercept[SparkException] { + msg = intercept[UnsupportedOperationException] { val schema = StructType(StructField("a", new 
UDT.MyDenseVectorUDT(), true) :: Nil) spark.range(1).write.csv(csvDir) spark.read.schema(schema).csv(csvDir).collect() - }.getCause.getMessage - assert(msg.contains("Unsupported type: array")) + }.getMessage + assert(msg.contains("CSV data source does not support array<double> data type.")) } } @@ -958,4 +959,58 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { checkAnswer(df, Row(1, null)) } } + + test("SPARK-18699 put malformed records in a `columnNameOfCorruptRecord` field") { + val schema = new StructType().add("a", IntegerType).add("b", TimestampType) + val df1 = spark + .read + .option("mode", "PERMISSIVE") + .schema(schema) + .csv(testFile(valueMalformedFile)) + checkAnswer(df1, + Row(null, null) :: + Row(1, java.sql.Date.valueOf("1983-08-04")) :: + Nil) + + // If `schema` has `columnNameOfCorruptRecord`, it should handle corrupt records + val columnNameOfCorruptRecord = "_unparsed" + val schemaWithCorrField1 = schema.add(columnNameOfCorruptRecord, StringType) + val df2 = spark + .read + .option("mode", "PERMISSIVE") + .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) + .schema(schemaWithCorrField1) + .csv(testFile(valueMalformedFile)) + checkAnswer(df2, + Row(null, null, "0,2013-111-11 12:13:14") :: + Row(1, java.sql.Date.valueOf("1983-08-04"), null) :: + Nil) + + // We put a `columnNameOfCorruptRecord` field in the middle of a schema + val schemaWithCorrField2 = new StructType() + .add("a", IntegerType) + .add(columnNameOfCorruptRecord, StringType) + .add("b", TimestampType) + val df3 = spark + .read + .option("mode", "PERMISSIVE") + .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) + .schema(schemaWithCorrField2) + .csv(testFile(valueMalformedFile)) + checkAnswer(df3, + Row(null, "0,2013-111-11 12:13:14", null) :: + Row(1, null, java.sql.Date.valueOf("1983-08-04")) :: + Nil) + + val errMsg = intercept[AnalysisException] { + spark + .read + .option("mode", "PERMISSIVE") + .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) + .schema(schema.add(columnNameOfCorruptRecord, IntegerType)) + .csv(testFile(valueMalformedFile)) + .collect + }.getMessage + assert(errMsg.startsWith("The field for corrupt records must be string type and nullable")) + } } |
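For readers skimming the UnivocityParser change above, the core of the new PERMISSIVE behaviour is a catch-all around the per-field conversion: if any token fails to convert, the row is emitted with every data field null and the raw input line placed in the corrupt-record slot, when the user schema declares one. Below is a deliberately simplified, hypothetical sketch of that control flow, not the actual Spark code; convert stands in for the field converters and corruptFieldIndex for the optional position of the corrupt-record column in the required schema.

import scala.util.control.NonFatal

object PermissiveFallbackSketch {
  // Simplified stand-in for UnivocityParser.convertWithParseMode in PERMISSIVE
  // mode: try the real conversion; on any non-fatal failure, return an all-null
  // row whose corrupt-record slot (if declared) carries the original input line.
  def parsePermissive(
      input: String,
      tokens: Array[String],
      requiredLength: Int,
      corruptFieldIndex: Option[Int])(
      convert: Array[String] => Array[Any]): Array[Any] = {
    try {
      convert(tokens)
    } catch {
      case NonFatal(_) =>
        val row = new Array[Any](requiredLength)        // all slots start as null
        corruptFieldIndex.foreach(i => row(i) = input)  // keep the raw line
        row
    }
  }

  def main(args: Array[String]): Unit = {
    // Toy converter: the second field must parse as an Int; the third slot is
    // the corrupt-record column and stays null on success.
    val convert = (ts: Array[String]) => Array[Any](ts(0), ts(1).toInt, null)
    val ok  = parsePermissive("a,1", Array("a", "1"), 3, Some(2))(convert)
    val bad = parsePermissive("a,oops", Array("a", "oops"), 3, Some(2))(convert)
    println(ok.mkString("[", ", ", "]"))   // [a, 1, null]
    println(bad.mkString("[", ", ", "]"))  // [null, null, a,oops]
  }
}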