diff options
Diffstat (limited to 'sql/core')
7 files changed, 443 insertions, 91 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 780fe51ac6..cb9493a575 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -26,14 +26,14 @@ import org.apache.spark.internal.Logging import org.apache.spark.Partition import org.apache.spark.annotation.InterfaceStability import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions} -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.JsonInferSchema import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String /** * Interface used to load a [[Dataset]] from external storage systems (e.g. file systems, @@ -261,8 +261,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } /** - * Loads a JSON file (<a href="http://jsonlines.org/">JSON Lines text format or - * newline-delimited JSON</a>) and returns the result as a `DataFrame`. + * Loads a JSON file and returns the results as a `DataFrame`. + * + * Both JSON (one record per file) and <a href="http://jsonlines.org/">JSON Lines</a> + * (newline-delimited JSON) are supported and can be selected with the `wholeFile` option. * * This function goes through the input once to determine the input schema. If you know the * schema in advance, use the version that specifies the schema to avoid the extra scan. @@ -301,6 +303,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * `java.text.SimpleDateFormat`. This applies to timestamp type.</li> * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone * to be used to parse timestamps.</li> + * <li>`wholeFile` (default `false`): parse one record, which may span multiple lines, + * per file</li> * </ul> * * @since 2.0.0 @@ -332,20 +336,22 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * @since 1.4.0 */ def json(jsonRDD: RDD[String]): DataFrame = { - val parsedOptions: JSONOptions = - new JSONOptions(extraOptions.toMap, sparkSession.sessionState.conf.sessionLocalTimeZone) - val columnNameOfCorruptRecord = - parsedOptions.columnNameOfCorruptRecord - .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord) + val parsedOptions = new JSONOptions( + extraOptions.toMap, + sparkSession.sessionState.conf.sessionLocalTimeZone, + sparkSession.sessionState.conf.columnNameOfCorruptRecord) + val createParser = CreateJacksonParser.string _ + val schema = userSpecifiedSchema.getOrElse { JsonInferSchema.infer( jsonRDD, - columnNameOfCorruptRecord, - parsedOptions) + parsedOptions, + createParser) } + val parsed = jsonRDD.mapPartitions { iter => - val parser = new JacksonParser(schema, columnNameOfCorruptRecord, parsedOptions) - iter.flatMap(parser.parse) + val parser = new JacksonParser(schema, parsedOptions) + iter.flatMap(parser.parse(_, createParser, UTF8String.fromString)) } Dataset.ofRows( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala index 900263aeb2..0762d1b7da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.execution.datasources -import java.io.{OutputStream, OutputStreamWriter} +import java.io.{InputStream, OutputStream, OutputStreamWriter} import java.nio.charset.{Charset, StandardCharsets} +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.compress._ import org.apache.hadoop.mapreduce.JobContext @@ -27,6 +28,20 @@ import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat import org.apache.hadoop.util.ReflectionUtils object CodecStreams { + private def getDecompressionCodec(config: Configuration, file: Path): Option[CompressionCodec] = { + val compressionCodecs = new CompressionCodecFactory(config) + Option(compressionCodecs.getCodec(file)) + } + + def createInputStream(config: Configuration, file: Path): InputStream = { + val fs = file.getFileSystem(config) + val inputStream: InputStream = fs.open(file) + + getDecompressionCodec(config, file) + .map(codec => codec.createInputStream(inputStream)) + .getOrElse(inputStream) + } + private def getCompressionCodec( context: JobContext, file: Option[Path] = None): Option[CompressionCodec] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala new file mode 100644 index 0000000000..3e984effcb --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.json + +import java.io.InputStream + +import scala.reflect.ClassTag + +import com.fasterxml.jackson.core.{JsonFactory, JsonParser} +import com.google.common.io.ByteStreams +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, TextInputFormat} + +import org.apache.spark.TaskContext +import org.apache.spark.input.{PortableDataStream, StreamInputFormat} +import org.apache.spark.rdd.{BinaryFileRDD, RDD} +import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.execution.datasources.{CodecStreams, HadoopFileLinesReader, PartitionedFile} +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils + +/** + * Common functions for parsing JSON files + * @tparam T A datatype containing the unparsed JSON, such as [[Text]] or [[String]] + */ +abstract class JsonDataSource[T] extends Serializable { + def isSplitable: Boolean + + /** + * Parse a [[PartitionedFile]] into 0 or more [[InternalRow]] instances + */ + def readFile( + conf: Configuration, + file: PartitionedFile, + parser: JacksonParser): Iterator[InternalRow] + + /** + * Create an [[RDD]] that handles the preliminary parsing of [[T]] records + */ + protected def createBaseRdd( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus]): RDD[T] + + /** + * A generic wrapper to invoke the correct [[JsonFactory]] method to allocate a [[JsonParser]] + * for an instance of [[T]] + */ + def createParser(jsonFactory: JsonFactory, value: T): JsonParser + + final def infer( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + parsedOptions: JSONOptions): Option[StructType] = { + if (inputPaths.nonEmpty) { + val jsonSchema = JsonInferSchema.infer( + createBaseRdd(sparkSession, inputPaths), + parsedOptions, + createParser) + checkConstraints(jsonSchema) + Some(jsonSchema) + } else { + None + } + } + + /** Constraints to be imposed on schema to be stored. */ + private def checkConstraints(schema: StructType): Unit = { + if (schema.fieldNames.length != schema.fieldNames.distinct.length) { + val duplicateColumns = schema.fieldNames.groupBy(identity).collect { + case (x, ys) if ys.length > 1 => "\"" + x + "\"" + }.mkString(", ") + throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + + s"cannot save to JSON format") + } + } +} + +object JsonDataSource { + def apply(options: JSONOptions): JsonDataSource[_] = { + if (options.wholeFile) { + WholeFileJsonDataSource + } else { + TextInputJsonDataSource + } + } + + /** + * Create a new [[RDD]] via the supplied callback if there is at least one file to process, + * otherwise an [[org.apache.spark.rdd.EmptyRDD]] will be returned. + */ + def createBaseRdd[T : ClassTag]( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus])( + fn: (Configuration, String) => RDD[T]): RDD[T] = { + val paths = inputPaths.map(_.getPath) + + if (paths.nonEmpty) { + val job = Job.getInstance(sparkSession.sessionState.newHadoopConf()) + FileInputFormat.setInputPaths(job, paths: _*) + fn(job.getConfiguration, paths.mkString(",")) + } else { + sparkSession.sparkContext.emptyRDD[T] + } + } +} + +object TextInputJsonDataSource extends JsonDataSource[Text] { + override val isSplitable: Boolean = { + // splittable if the underlying source is + true + } + + override protected def createBaseRdd( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus]): RDD[Text] = { + JsonDataSource.createBaseRdd(sparkSession, inputPaths) { + case (conf, name) => + sparkSession.sparkContext.newAPIHadoopRDD( + conf, + classOf[TextInputFormat], + classOf[LongWritable], + classOf[Text]) + .setName(s"JsonLines: $name") + .values // get the text column + } + } + + override def readFile( + conf: Configuration, + file: PartitionedFile, + parser: JacksonParser): Iterator[InternalRow] = { + val linesReader = new HadoopFileLinesReader(file, conf) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + linesReader.flatMap(parser.parse(_, createParser, textToUTF8String)) + } + + private def textToUTF8String(value: Text): UTF8String = { + UTF8String.fromBytes(value.getBytes, 0, value.getLength) + } + + override def createParser(jsonFactory: JsonFactory, value: Text): JsonParser = { + CreateJacksonParser.text(jsonFactory, value) + } +} + +object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] { + override val isSplitable: Boolean = { + false + } + + override protected def createBaseRdd( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus]): RDD[PortableDataStream] = { + JsonDataSource.createBaseRdd(sparkSession, inputPaths) { + case (conf, name) => + new BinaryFileRDD( + sparkSession.sparkContext, + classOf[StreamInputFormat], + classOf[String], + classOf[PortableDataStream], + conf, + sparkSession.sparkContext.defaultMinPartitions) + .setName(s"JsonFile: $name") + .values + } + } + + private def createInputStream(config: Configuration, path: String): InputStream = { + val inputStream = CodecStreams.createInputStream(config, new Path(path)) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => inputStream.close())) + inputStream + } + + override def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = { + CreateJacksonParser.inputStream( + jsonFactory, + createInputStream(record.getConfiguration, record.getPath())) + } + + override def readFile( + conf: Configuration, + file: PartitionedFile, + parser: JacksonParser): Iterator[InternalRow] = { + def partitionedFileString(ignored: Any): UTF8String = { + Utils.tryWithResource(createInputStream(conf, file.filePath)) { inputStream => + UTF8String.fromBytes(ByteStreams.toByteArray(inputStream)) + } + } + + parser.parse( + createInputStream(conf, file.filePath), + CreateJacksonParser.inputStream, + partitionedFileString).toIterator + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index b4a8ff2cf0..2cbf4ea7be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -19,15 +19,10 @@ package org.apache.spark.sql.execution.datasources.json import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.{LongWritable, Text} -import org.apache.hadoop.mapred.{JobConf, TextInputFormat} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} -import org.apache.hadoop.mapreduce.lib.input.FileInputFormat -import org.apache.spark.TaskContext import org.apache.spark.internal.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions} import org.apache.spark.sql.catalyst.util.CompressionCodecs @@ -37,29 +32,30 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { + override val shortName: String = "json" - override def shortName(): String = "json" + override def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = { + val parsedOptions = new JSONOptions( + options, + sparkSession.sessionState.conf.sessionLocalTimeZone, + sparkSession.sessionState.conf.columnNameOfCorruptRecord) + val jsonDataSource = JsonDataSource(parsedOptions) + jsonDataSource.isSplitable && super.isSplitable(sparkSession, options, path) + } override def inferSchema( sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - if (files.isEmpty) { - None - } else { - val parsedOptions: JSONOptions = - new JSONOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) - val columnNameOfCorruptRecord = - parsedOptions.columnNameOfCorruptRecord - .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord) - val jsonSchema = JsonInferSchema.infer( - createBaseRdd(sparkSession, files), - columnNameOfCorruptRecord, - parsedOptions) - checkConstraints(jsonSchema) - - Some(jsonSchema) - } + val parsedOptions = new JSONOptions( + options, + sparkSession.sessionState.conf.sessionLocalTimeZone, + sparkSession.sessionState.conf.columnNameOfCorruptRecord) + JsonDataSource(parsedOptions).infer( + sparkSession, files, parsedOptions) } override def prepareWrite( @@ -68,8 +64,10 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { val conf = job.getConfiguration - val parsedOptions: JSONOptions = - new JSONOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) + val parsedOptions = new JSONOptions( + options, + sparkSession.sessionState.conf.sessionLocalTimeZone, + sparkSession.sessionState.conf.columnNameOfCorruptRecord) parsedOptions.compressionCodec.foreach { codec => CompressionCodecs.setCodecConfiguration(conf, codec) } @@ -99,47 +97,17 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - val parsedOptions: JSONOptions = - new JSONOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) - val columnNameOfCorruptRecord = parsedOptions.columnNameOfCorruptRecord - .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord) + val parsedOptions = new JSONOptions( + options, + sparkSession.sessionState.conf.sessionLocalTimeZone, + sparkSession.sessionState.conf.columnNameOfCorruptRecord) (file: PartitionedFile) => { - val linesReader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) - val lines = linesReader.map(_.toString) - val parser = new JacksonParser(requiredSchema, columnNameOfCorruptRecord, parsedOptions) - lines.flatMap(parser.parse) - } - } - - private def createBaseRdd( - sparkSession: SparkSession, - inputPaths: Seq[FileStatus]): RDD[String] = { - val job = Job.getInstance(sparkSession.sessionState.newHadoopConf()) - val conf = job.getConfiguration - - val paths = inputPaths.map(_.getPath) - - if (paths.nonEmpty) { - FileInputFormat.setInputPaths(job, paths: _*) - } - - sparkSession.sparkContext.hadoopRDD( - conf.asInstanceOf[JobConf], - classOf[TextInputFormat], - classOf[LongWritable], - classOf[Text]).map(_._2.toString) // get the text line - } - - /** Constraints to be imposed on schema to be stored. */ - private def checkConstraints(schema: StructType): Unit = { - if (schema.fieldNames.length != schema.fieldNames.distinct.length) { - val duplicateColumns = schema.fieldNames.groupBy(identity).collect { - case (x, ys) if ys.length > 1 => "\"" + x + "\"" - }.mkString(", ") - throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + - s"cannot save to JSON format") + val parser = new JacksonParser(requiredSchema, parsedOptions) + JsonDataSource(parsedOptions).readFile( + broadcastedHadoopConf.value.value, + file, + parser) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index f51c18d46f..ab09358115 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -36,13 +36,14 @@ private[sql] object JsonInferSchema { * 2. Merge types by choosing the lowest type necessary to cover equal keys * 3. Replace any remaining null fields with string, the top type */ - def infer( - json: RDD[String], - columnNameOfCorruptRecord: String, - configOptions: JSONOptions): StructType = { + def infer[T]( + json: RDD[T], + configOptions: JSONOptions, + createParser: (JsonFactory, T) => JsonParser): StructType = { require(configOptions.samplingRatio > 0, s"samplingRatio (${configOptions.samplingRatio}) should be greater than 0") val shouldHandleCorruptRecord = configOptions.permissive + val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord val schemaData = if (configOptions.samplingRatio > 0.99) { json } else { @@ -55,7 +56,7 @@ private[sql] object JsonInferSchema { configOptions.setJacksonOptions(factory) iter.flatMap { row => try { - Utils.tryWithResource(factory.createParser(row)) { parser => + Utils.tryWithResource(createParser(factory, row)) { parser => parser.nextToken() Some(inferField(parser, configOptions)) } @@ -79,7 +80,7 @@ private[sql] object JsonInferSchema { private[this] val structFieldComparator = new Comparator[StructField] { override def compare(o1: StructField, o2: StructField): Int = { - o1.name.compare(o2.name) + o1.name.compareTo(o2.name) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 4e706da184..99943944f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -141,8 +141,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo } /** - * Loads a JSON file stream (<a href="http://jsonlines.org/">JSON Lines text format or - * newline-delimited JSON</a>) and returns the result as a `DataFrame`. + * Loads a JSON file stream and returns the results as a `DataFrame`. + * + * Both JSON (one record per file) and <a href="http://jsonlines.org/">JSON Lines</a> + * (newline-delimited JSON) are supported and can be selected with the `wholeFile` option. * * This function goes through the input once to determine the input schema. If you know the * schema in advance, use the version that specifies the schema to avoid the extra scan. @@ -183,6 +185,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * `java.text.SimpleDateFormat`. This applies to timestamp type.</li> * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone * to be used to parse timestamps.</li> + * <li>`wholeFile` (default `false`): parse one record, which may span multiple lines, + * per file</li> * </ul> * * @since 2.0.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 9344aeda00..05aa2ab2ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -28,8 +28,8 @@ import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.rdd.RDD import org.apache.spark.SparkException -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions} +import org.apache.spark.sql.{functions => F, _} +import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.json.JsonInferSchema.compatibleType @@ -64,7 +64,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val dummyOption = new JSONOptions(Map.empty[String, String], "GMT") val dummySchema = StructType(Seq.empty) - val parser = new JacksonParser(dummySchema, "", dummyOption) + val parser = new JacksonParser(dummySchema, dummyOption) Utils.tryWithResource(factory.createParser(writer.toString)) { jsonParser => jsonParser.nextToken() @@ -1367,7 +1367,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-6245 JsonRDD.inferSchema on empty RDD") { // This is really a test that it doesn't throw an exception val emptySchema = JsonInferSchema.infer( - empty, "", new JSONOptions(Map.empty[String, String], "GMT")) + empty, + new JSONOptions(Map.empty[String, String], "GMT"), + CreateJacksonParser.string) assert(StructType(Seq()) === emptySchema) } @@ -1392,7 +1394,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-8093 Erase empty structs") { val emptySchema = JsonInferSchema.infer( - emptyRecords, "", new JSONOptions(Map.empty[String, String], "GMT")) + emptyRecords, + new JSONOptions(Map.empty[String, String], "GMT"), + CreateJacksonParser.string) assert(StructType(Seq()) === emptySchema) } @@ -1802,4 +1806,142 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val df2 = spark.read.option("PREfersdecimaL", "true").json(records) assert(df2.schema == schema) } + + test("SPARK-18352: Parse normal multi-line JSON files (compressed)") { + withTempPath { dir => + val path = dir.getCanonicalPath + primitiveFieldAndType + .toDF("value") + .write + .option("compression", "GzIp") + .text(path) + + assert(new File(path).listFiles().exists(_.getName.endsWith(".gz"))) + + val jsonDF = spark.read.option("wholeFile", true).json(path) + val jsonDir = new File(dir, "json").getCanonicalPath + jsonDF.coalesce(1).write + .option("compression", "gZiP") + .json(jsonDir) + + assert(new File(jsonDir).listFiles().exists(_.getName.endsWith(".json.gz"))) + + val originalData = spark.read.json(primitiveFieldAndType) + checkAnswer(jsonDF, originalData) + checkAnswer(spark.read.schema(originalData.schema).json(jsonDir), originalData) + } + } + + test("SPARK-18352: Parse normal multi-line JSON files (uncompressed)") { + withTempPath { dir => + val path = dir.getCanonicalPath + primitiveFieldAndType + .toDF("value") + .write + .text(path) + + val jsonDF = spark.read.option("wholeFile", true).json(path) + val jsonDir = new File(dir, "json").getCanonicalPath + jsonDF.coalesce(1).write.json(jsonDir) + + val compressedFiles = new File(jsonDir).listFiles() + assert(compressedFiles.exists(_.getName.endsWith(".json"))) + + val originalData = spark.read.json(primitiveFieldAndType) + checkAnswer(jsonDF, originalData) + checkAnswer(spark.read.schema(originalData.schema).json(jsonDir), originalData) + } + } + + test("SPARK-18352: Expect one JSON document per file") { + // the json parser terminates as soon as it sees a matching END_OBJECT or END_ARRAY token. + // this might not be the optimal behavior but this test verifies that only the first value + // is parsed and the rest are discarded. + + // alternatively the parser could continue parsing following objects, which may further reduce + // allocations by skipping the line reader entirely + + withTempPath { dir => + val path = dir.getCanonicalPath + spark + .createDataFrame(Seq(Tuple1("{}{invalid}"))) + .coalesce(1) + .write + .text(path) + + val jsonDF = spark.read.option("wholeFile", true).json(path) + // no corrupt record column should be created + assert(jsonDF.schema === StructType(Seq())) + // only the first object should be read + assert(jsonDF.count() === 1) + } + } + + test("SPARK-18352: Handle multi-line corrupt documents (PERMISSIVE)") { + withTempPath { dir => + val path = dir.getCanonicalPath + val corruptRecordCount = additionalCorruptRecords.count().toInt + assert(corruptRecordCount === 5) + + additionalCorruptRecords + .toDF("value") + // this is the minimum partition count that avoids hash collisions + .repartition(corruptRecordCount * 4, F.hash($"value")) + .write + .text(path) + + val jsonDF = spark.read.option("wholeFile", true).option("mode", "PERMISSIVE").json(path) + assert(jsonDF.count() === corruptRecordCount) + assert(jsonDF.schema === new StructType() + .add("_corrupt_record", StringType) + .add("dummy", StringType)) + val counts = jsonDF + .join( + additionalCorruptRecords.toDF("value"), + F.regexp_replace($"_corrupt_record", "(^\\s+|\\s+$)", "") === F.trim($"value"), + "outer") + .agg( + F.count($"dummy").as("valid"), + F.count($"_corrupt_record").as("corrupt"), + F.count("*").as("count")) + checkAnswer(counts, Row(1, 4, 6)) + } + } + + test("SPARK-18352: Handle multi-line corrupt documents (FAILFAST)") { + withTempPath { dir => + val path = dir.getCanonicalPath + val corruptRecordCount = additionalCorruptRecords.count().toInt + assert(corruptRecordCount === 5) + + additionalCorruptRecords + .toDF("value") + // this is the minimum partition count that avoids hash collisions + .repartition(corruptRecordCount * 4, F.hash($"value")) + .write + .text(path) + + val schema = new StructType().add("dummy", StringType) + + // `FAILFAST` mode should throw an exception for corrupt records. + val exceptionOne = intercept[SparkException] { + spark.read + .option("wholeFile", true) + .option("mode", "FAILFAST") + .json(path) + .collect() + } + assert(exceptionOne.getMessage.contains("Malformed line in FAILFAST mode")) + + val exceptionTwo = intercept[SparkException] { + spark.read + .option("wholeFile", true) + .option("mode", "FAILFAST") + .schema(schema) + .json(path) + .collect() + } + assert(exceptionTwo.getMessage.contains("Malformed line in FAILFAST mode")) + } + } } |