diff options
Diffstat (limited to 'sql/core/src/main')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala | 77 |
1 files changed, 47 insertions, 30 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index e15c30b437..fb632cf2bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -25,7 +25,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil import org.apache.spark.sql.catalyst.json.JSONOptions -import org.apache.spark.sql.catalyst.util.PermissiveMode +import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -41,7 +41,7 @@ private[sql] object JsonInferSchema { json: RDD[T], configOptions: JSONOptions, createParser: (JsonFactory, T) => JsonParser): StructType = { - val shouldHandleCorruptRecord = configOptions.parseMode == PermissiveMode + val parseMode = configOptions.parseMode val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord // perform schema inference on each row and merge afterwards @@ -55,20 +55,24 @@ private[sql] object JsonInferSchema { Some(inferField(parser, configOptions)) } } catch { - case _: JsonParseException if shouldHandleCorruptRecord => - Some(StructType(Seq(StructField(columnNameOfCorruptRecord, StringType)))) - case _: JsonParseException => - None + case e @ (_: RuntimeException | _: JsonProcessingException) => parseMode match { + case PermissiveMode => + Some(StructType(Seq(StructField(columnNameOfCorruptRecord, StringType)))) + case DropMalformedMode => + None + case FailFastMode => + throw e + } } } - }.fold(StructType(Seq()))( - compatibleRootType(columnNameOfCorruptRecord, shouldHandleCorruptRecord)) + }.fold(StructType(Nil))( + compatibleRootType(columnNameOfCorruptRecord, parseMode)) canonicalizeType(rootType) match { case Some(st: StructType) => st case _ => // canonicalizeType erases all empty structs, including the only one we want to keep - StructType(Seq()) + StructType(Nil) } } @@ -202,19 +206,33 @@ private[sql] object JsonInferSchema { private def withCorruptField( struct: StructType, - columnNameOfCorruptRecords: String): StructType = { - if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) { - // If this given struct does not have a column used for corrupt records, - // add this field. - val newFields: Array[StructField] = - StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields - // Note: other code relies on this sorting for correctness, so don't remove it! - java.util.Arrays.sort(newFields, structFieldComparator) - StructType(newFields) - } else { - // Otherwise, just return this struct. + other: DataType, + columnNameOfCorruptRecords: String, + parseMode: ParseMode) = parseMode match { + case PermissiveMode => + // If we see any other data type at the root level, we get records that cannot be + // parsed. So, we use the struct as the data type and add the corrupt field to the schema. + if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) { + // If this given struct does not have a column used for corrupt records, + // add this field. + val newFields: Array[StructField] = + StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields + // Note: other code relies on this sorting for correctness, so don't remove it! + java.util.Arrays.sort(newFields, structFieldComparator) + StructType(newFields) + } else { + // Otherwise, just return this struct. + struct + } + + case DropMalformedMode => + // If corrupt record handling is disabled we retain the valid schema and discard the other. struct - } + + case FailFastMode => + // If `other` is not struct type, consider it as malformed one and throws an exception. + throw new RuntimeException("Failed to infer a common schema. Struct types are expected" + + s" but ${other.catalogString} was found.") } /** @@ -222,21 +240,20 @@ private[sql] object JsonInferSchema { */ private def compatibleRootType( columnNameOfCorruptRecords: String, - shouldHandleCorruptRecord: Boolean): (DataType, DataType) => DataType = { + parseMode: ParseMode): (DataType, DataType) => DataType = { // Since we support array of json objects at the top level, // we need to check the element type and find the root level data type. case (ArrayType(ty1, _), ty2) => - compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2) + compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) case (ty1, ArrayType(ty2, _)) => - compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2) - // If we see any other data type at the root level, we get records that cannot be - // parsed. So, we use the struct as the data type and add the corrupt field to the schema. + compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) + // Discard null/empty documents case (struct: StructType, NullType) => struct case (NullType, struct: StructType) => struct - case (struct: StructType, o) if !o.isInstanceOf[StructType] && shouldHandleCorruptRecord => - withCorruptField(struct, columnNameOfCorruptRecords) - case (o, struct: StructType) if !o.isInstanceOf[StructType] && shouldHandleCorruptRecord => - withCorruptField(struct, columnNameOfCorruptRecords) + case (struct: StructType, o) if !o.isInstanceOf[StructType] => + withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode) + case (o, struct: StructType) if !o.isInstanceOf[StructType] => + withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode) // If we get anything else, we call compatibleType. // Usually, when we reach here, ty1 and ty2 are two StructTypes. case (ty1, ty2) => compatibleType(ty1, ty2) |