diff options
2 files changed, 78 insertions, 33 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index e15c30b437..fb632cf2bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -25,7 +25,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil import org.apache.spark.sql.catalyst.json.JSONOptions -import org.apache.spark.sql.catalyst.util.PermissiveMode +import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -41,7 +41,7 @@ private[sql] object JsonInferSchema { json: RDD[T], configOptions: JSONOptions, createParser: (JsonFactory, T) => JsonParser): StructType = { - val shouldHandleCorruptRecord = configOptions.parseMode == PermissiveMode + val parseMode = configOptions.parseMode val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord // perform schema inference on each row and merge afterwards @@ -55,20 +55,24 @@ private[sql] object JsonInferSchema { Some(inferField(parser, configOptions)) } } catch { - case _: JsonParseException if shouldHandleCorruptRecord => - Some(StructType(Seq(StructField(columnNameOfCorruptRecord, StringType)))) - case _: JsonParseException => - None + case e @ (_: RuntimeException | _: JsonProcessingException) => parseMode match { + case PermissiveMode => + Some(StructType(Seq(StructField(columnNameOfCorruptRecord, StringType)))) + case DropMalformedMode => + None + case FailFastMode => + throw e + } } } - }.fold(StructType(Seq()))( - compatibleRootType(columnNameOfCorruptRecord, shouldHandleCorruptRecord)) + }.fold(StructType(Nil))( + compatibleRootType(columnNameOfCorruptRecord, parseMode)) canonicalizeType(rootType) match { case Some(st: StructType) => st case _ => // canonicalizeType erases all empty structs, including the only one we want to keep - StructType(Seq()) + StructType(Nil) } } @@ -202,19 +206,33 @@ private[sql] object JsonInferSchema { private def withCorruptField( struct: StructType, - columnNameOfCorruptRecords: String): StructType = { - if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) { - // If this given struct does not have a column used for corrupt records, - // add this field. - val newFields: Array[StructField] = - StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields - // Note: other code relies on this sorting for correctness, so don't remove it! - java.util.Arrays.sort(newFields, structFieldComparator) - StructType(newFields) - } else { - // Otherwise, just return this struct. + other: DataType, + columnNameOfCorruptRecords: String, + parseMode: ParseMode) = parseMode match { + case PermissiveMode => + // If we see any other data type at the root level, we get records that cannot be + // parsed. So, we use the struct as the data type and add the corrupt field to the schema. + if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) { + // If this given struct does not have a column used for corrupt records, + // add this field. + val newFields: Array[StructField] = + StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields + // Note: other code relies on this sorting for correctness, so don't remove it! + java.util.Arrays.sort(newFields, structFieldComparator) + StructType(newFields) + } else { + // Otherwise, just return this struct. + struct + } + + case DropMalformedMode => + // If corrupt record handling is disabled we retain the valid schema and discard the other. struct - } + + case FailFastMode => + // If `other` is not struct type, consider it as malformed one and throws an exception. + throw new RuntimeException("Failed to infer a common schema. Struct types are expected" + + s" but ${other.catalogString} was found.") } /** @@ -222,21 +240,20 @@ private[sql] object JsonInferSchema { */ private def compatibleRootType( columnNameOfCorruptRecords: String, - shouldHandleCorruptRecord: Boolean): (DataType, DataType) => DataType = { + parseMode: ParseMode): (DataType, DataType) => DataType = { // Since we support array of json objects at the top level, // we need to check the element type and find the root level data type. case (ArrayType(ty1, _), ty2) => - compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2) + compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) case (ty1, ArrayType(ty2, _)) => - compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2) - // If we see any other data type at the root level, we get records that cannot be - // parsed. So, we use the struct as the data type and add the corrupt field to the schema. + compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) + // Discard null/empty documents case (struct: StructType, NullType) => struct case (NullType, struct: StructType) => struct - case (struct: StructType, o) if !o.isInstanceOf[StructType] && shouldHandleCorruptRecord => - withCorruptField(struct, columnNameOfCorruptRecords) - case (o, struct: StructType) if !o.isInstanceOf[StructType] && shouldHandleCorruptRecord => - withCorruptField(struct, columnNameOfCorruptRecords) + case (struct: StructType, o) if !o.isInstanceOf[StructType] => + withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode) + case (o, struct: StructType) if !o.isInstanceOf[StructType] => + withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode) // If we get anything else, we call compatibleType. // Usually, when we reach here, ty1 and ty2 are two StructTypes. case (ty1, ty2) => compatibleType(ty1, ty2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index b09cef76d2..2ab0381996 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1041,7 +1041,6 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { spark.read .option("mode", "FAILFAST") .json(corruptRecords) - .collect() } assert(exceptionOne.getMessage.contains("JsonParseException")) @@ -1082,6 +1081,18 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(jsonDFTwo.schema === schemaTwo) } + test("SPARK-19641: Additional corrupt records: DROPMALFORMED mode") { + val schema = new StructType().add("dummy", StringType) + // `DROPMALFORMED` mode should skip corrupt records + val jsonDF = spark.read + .option("mode", "DROPMALFORMED") + .json(additionalCorruptRecords) + checkAnswer( + jsonDF, + Row("test")) + assert(jsonDF.schema === schema) + } + test("Corrupt records: PERMISSIVE mode, without designated column for malformed records") { val schema = StructType( StructField("a", StringType, true) :: @@ -1882,6 +1893,24 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } } + test("SPARK-19641: Handle multi-line corrupt documents (DROPMALFORMED)") { + withTempPath { dir => + val path = dir.getCanonicalPath + val corruptRecordCount = additionalCorruptRecords.count().toInt + assert(corruptRecordCount === 5) + + additionalCorruptRecords + .toDF("value") + // this is the minimum partition count that avoids hash collisions + .repartition(corruptRecordCount * 4, F.hash($"value")) + .write + .text(path) + + val jsonDF = spark.read.option("wholeFile", true).option("mode", "DROPMALFORMED").json(path) + checkAnswer(jsonDF, Seq(Row("test"))) + } + } + test("SPARK-18352: Handle multi-line corrupt documents (FAILFAST)") { withTempPath { dir => val path = dir.getCanonicalPath @@ -1903,9 +1932,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .option("wholeFile", true) .option("mode", "FAILFAST") .json(path) - .collect() } - assert(exceptionOne.getMessage.contains("Failed to parse a value")) + assert(exceptionOne.getMessage.contains("Failed to infer a common schema")) val exceptionTwo = intercept[SparkException] { spark.read |