Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala                             32
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala          17
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala  216
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala   96
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala  13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala                   8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala        152
7 files changed, 443 insertions, 91 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 780fe51ac6..cb9493a575 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -26,14 +26,14 @@ import org.apache.spark.internal.Logging
import org.apache.spark.Partition
import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions}
-import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.jdbc._
import org.apache.spark.sql.execution.datasources.json.JsonInferSchema
import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.types.UTF8String
/**
* Interface used to load a [[Dataset]] from external storage systems (e.g. file systems,
@@ -261,8 +261,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
}
/**
- * Loads a JSON file (<a href="http://jsonlines.org/">JSON Lines text format or
- * newline-delimited JSON</a>) and returns the result as a `DataFrame`.
+ * Loads a JSON file and returns the results as a `DataFrame`.
+ *
+ * Both JSON (one record per file) and <a href="http://jsonlines.org/">JSON Lines</a>
+ * (newline-delimited JSON) are supported and can be selected with the `wholeFile` option.
*
* This function goes through the input once to determine the input schema. If you know the
* schema in advance, use the version that specifies the schema to avoid the extra scan.
@@ -301,6 +303,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* `java.text.SimpleDateFormat`. This applies to timestamp type.</li>
* <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
* to be used to parse timestamps.</li>
+ * <li>`wholeFile` (default `false`): parse one record, which may span multiple lines,
+ * per file</li>
* </ul>
*
* @since 2.0.0
@@ -332,20 +336,22 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* @since 1.4.0
*/
def json(jsonRDD: RDD[String]): DataFrame = {
- val parsedOptions: JSONOptions =
- new JSONOptions(extraOptions.toMap, sparkSession.sessionState.conf.sessionLocalTimeZone)
- val columnNameOfCorruptRecord =
- parsedOptions.columnNameOfCorruptRecord
- .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord)
+ val parsedOptions = new JSONOptions(
+ extraOptions.toMap,
+ sparkSession.sessionState.conf.sessionLocalTimeZone,
+ sparkSession.sessionState.conf.columnNameOfCorruptRecord)
+ val createParser = CreateJacksonParser.string _
+
val schema = userSpecifiedSchema.getOrElse {
JsonInferSchema.infer(
jsonRDD,
- columnNameOfCorruptRecord,
- parsedOptions)
+ parsedOptions,
+ createParser)
}
+
val parsed = jsonRDD.mapPartitions { iter =>
- val parser = new JacksonParser(schema, columnNameOfCorruptRecord, parsedOptions)
- iter.flatMap(parser.parse)
+ val parser = new JacksonParser(schema, parsedOptions)
+ iter.flatMap(parser.parse(_, createParser, UTF8String.fromString))
}
Dataset.ofRows(
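
For reference, a minimal usage sketch of the new `wholeFile` option against the public DataFrameReader API (the input paths are hypothetical and not part of this patch):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().getOrCreate()

    // wholeFile = true: each file holds a single JSON document, possibly
    // spanning multiple lines.
    val multiLineDF = spark.read
      .option("wholeFile", true)
      .json("/path/to/multi-line-json")

    // Default (wholeFile = false): JSON Lines, one record per line.
    val jsonLinesDF = spark.read.json("/path/to/json-lines")
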
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala
index 900263aeb2..0762d1b7da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala
@@ -17,9 +17,10 @@
package org.apache.spark.sql.execution.datasources
-import java.io.{OutputStream, OutputStreamWriter}
+import java.io.{InputStream, OutputStream, OutputStreamWriter}
import java.nio.charset.{Charset, StandardCharsets}
+import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.compress._
import org.apache.hadoop.mapreduce.JobContext
@@ -27,6 +28,20 @@ import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
import org.apache.hadoop.util.ReflectionUtils
object CodecStreams {
+ private def getDecompressionCodec(config: Configuration, file: Path): Option[CompressionCodec] = {
+ val compressionCodecs = new CompressionCodecFactory(config)
+ Option(compressionCodecs.getCodec(file))
+ }
+
+ def createInputStream(config: Configuration, file: Path): InputStream = {
+ val fs = file.getFileSystem(config)
+ val inputStream: InputStream = fs.open(file)
+
+ getDecompressionCodec(config, file)
+ .map(codec => codec.createInputStream(inputStream))
+ .getOrElse(inputStream)
+ }
+
private def getCompressionCodec(
context: JobContext,
file: Option[Path] = None): Option[CompressionCodec] = {
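
A minimal sketch of the new helper, assuming a Hadoop `Configuration` and a hypothetical file path: `createInputStream` opens the file and, when the file name maps to a registered compression codec (e.g. a `.gz` suffix), wraps it in the matching decompression stream.

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.Path
    import org.apache.spark.sql.execution.datasources.CodecStreams

    val conf = new Configuration()
    // Decompression is selected from the file extension; plain files pass through.
    val in = CodecStreams.createInputStream(conf, new Path("/data/records.json.gz"))
    try {
      // consume decompressed bytes from `in` ...
    } finally {
      in.close()
    }
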
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
new file mode 100644
index 0000000000..3e984effcb
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.json
+
+import java.io.InputStream
+
+import scala.reflect.ClassTag
+
+import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
+import com.google.common.io.ByteStreams
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.io.{LongWritable, Text}
+import org.apache.hadoop.mapreduce.Job
+import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, TextInputFormat}
+
+import org.apache.spark.TaskContext
+import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
+import org.apache.spark.rdd.{BinaryFileRDD, RDD}
+import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
+import org.apache.spark.sql.execution.datasources.{CodecStreams, HadoopFileLinesReader, PartitionedFile}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.Utils
+
+/**
+ * Common functions for parsing JSON files
+ * @tparam T A datatype containing the unparsed JSON, such as [[Text]] or [[String]]
+ */
+abstract class JsonDataSource[T] extends Serializable {
+ def isSplitable: Boolean
+
+ /**
+ * Parse a [[PartitionedFile]] into 0 or more [[InternalRow]] instances
+ */
+ def readFile(
+ conf: Configuration,
+ file: PartitionedFile,
+ parser: JacksonParser): Iterator[InternalRow]
+
+ /**
+ * Create an [[RDD]] that handles the preliminary parsing of [[T]] records
+ */
+ protected def createBaseRdd(
+ sparkSession: SparkSession,
+ inputPaths: Seq[FileStatus]): RDD[T]
+
+ /**
+ * A generic wrapper to invoke the correct [[JsonFactory]] method to allocate a [[JsonParser]]
+ * for an instance of [[T]]
+ */
+ def createParser(jsonFactory: JsonFactory, value: T): JsonParser
+
+ final def infer(
+ sparkSession: SparkSession,
+ inputPaths: Seq[FileStatus],
+ parsedOptions: JSONOptions): Option[StructType] = {
+ if (inputPaths.nonEmpty) {
+ val jsonSchema = JsonInferSchema.infer(
+ createBaseRdd(sparkSession, inputPaths),
+ parsedOptions,
+ createParser)
+ checkConstraints(jsonSchema)
+ Some(jsonSchema)
+ } else {
+ None
+ }
+ }
+
+ /** Constraints to be imposed on schema to be stored. */
+ private def checkConstraints(schema: StructType): Unit = {
+ if (schema.fieldNames.length != schema.fieldNames.distinct.length) {
+ val duplicateColumns = schema.fieldNames.groupBy(identity).collect {
+ case (x, ys) if ys.length > 1 => "\"" + x + "\""
+ }.mkString(", ")
+ throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " +
+ s"cannot save to JSON format")
+ }
+ }
+}
+
+object JsonDataSource {
+ def apply(options: JSONOptions): JsonDataSource[_] = {
+ if (options.wholeFile) {
+ WholeFileJsonDataSource
+ } else {
+ TextInputJsonDataSource
+ }
+ }
+
+ /**
+ * Create a new [[RDD]] via the supplied callback if there is at least one file to process,
+ * otherwise an [[org.apache.spark.rdd.EmptyRDD]] will be returned.
+ */
+ def createBaseRdd[T : ClassTag](
+ sparkSession: SparkSession,
+ inputPaths: Seq[FileStatus])(
+ fn: (Configuration, String) => RDD[T]): RDD[T] = {
+ val paths = inputPaths.map(_.getPath)
+
+ if (paths.nonEmpty) {
+ val job = Job.getInstance(sparkSession.sessionState.newHadoopConf())
+ FileInputFormat.setInputPaths(job, paths: _*)
+ fn(job.getConfiguration, paths.mkString(","))
+ } else {
+ sparkSession.sparkContext.emptyRDD[T]
+ }
+ }
+}
+
+object TextInputJsonDataSource extends JsonDataSource[Text] {
+ override val isSplitable: Boolean = {
+ // splittable if the underlying source is
+ true
+ }
+
+ override protected def createBaseRdd(
+ sparkSession: SparkSession,
+ inputPaths: Seq[FileStatus]): RDD[Text] = {
+ JsonDataSource.createBaseRdd(sparkSession, inputPaths) {
+ case (conf, name) =>
+ sparkSession.sparkContext.newAPIHadoopRDD(
+ conf,
+ classOf[TextInputFormat],
+ classOf[LongWritable],
+ classOf[Text])
+ .setName(s"JsonLines: $name")
+ .values // get the text column
+ }
+ }
+
+ override def readFile(
+ conf: Configuration,
+ file: PartitionedFile,
+ parser: JacksonParser): Iterator[InternalRow] = {
+ val linesReader = new HadoopFileLinesReader(file, conf)
+ Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
+ linesReader.flatMap(parser.parse(_, createParser, textToUTF8String))
+ }
+
+ private def textToUTF8String(value: Text): UTF8String = {
+ UTF8String.fromBytes(value.getBytes, 0, value.getLength)
+ }
+
+ override def createParser(jsonFactory: JsonFactory, value: Text): JsonParser = {
+ CreateJacksonParser.text(jsonFactory, value)
+ }
+}
+
+object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] {
+ override val isSplitable: Boolean = {
+ false
+ }
+
+ override protected def createBaseRdd(
+ sparkSession: SparkSession,
+ inputPaths: Seq[FileStatus]): RDD[PortableDataStream] = {
+ JsonDataSource.createBaseRdd(sparkSession, inputPaths) {
+ case (conf, name) =>
+ new BinaryFileRDD(
+ sparkSession.sparkContext,
+ classOf[StreamInputFormat],
+ classOf[String],
+ classOf[PortableDataStream],
+ conf,
+ sparkSession.sparkContext.defaultMinPartitions)
+ .setName(s"JsonFile: $name")
+ .values
+ }
+ }
+
+ private def createInputStream(config: Configuration, path: String): InputStream = {
+ val inputStream = CodecStreams.createInputStream(config, new Path(path))
+ Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => inputStream.close()))
+ inputStream
+ }
+
+ override def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = {
+ CreateJacksonParser.inputStream(
+ jsonFactory,
+ createInputStream(record.getConfiguration, record.getPath()))
+ }
+
+ override def readFile(
+ conf: Configuration,
+ file: PartitionedFile,
+ parser: JacksonParser): Iterator[InternalRow] = {
+ def partitionedFileString(ignored: Any): UTF8String = {
+ Utils.tryWithResource(createInputStream(conf, file.filePath)) { inputStream =>
+ UTF8String.fromBytes(ByteStreams.toByteArray(inputStream))
+ }
+ }
+
+ parser.parse(
+ createInputStream(conf, file.filePath),
+ CreateJacksonParser.inputStream,
+ partitionedFileString).toIterator
+ }
+}
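
A minimal sketch of the dispatch implemented above (the option maps are illustrative): `JsonDataSource(options)` picks the per-line strategy unless `wholeFile` is set.

    import org.apache.spark.sql.catalyst.json.JSONOptions

    val lineOptions = new JSONOptions(Map.empty[String, String], "GMT")
    assert(JsonDataSource(lineOptions) == TextInputJsonDataSource)

    val wholeFileOptions = new JSONOptions(Map("wholeFile" -> "true"), "GMT")
    assert(JsonDataSource(wholeFileOptions) == WholeFileJsonDataSource)
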
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index b4a8ff2cf0..2cbf4ea7be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -19,15 +19,10 @@ package org.apache.spark.sql.execution.datasources.json
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
-import org.apache.hadoop.io.{LongWritable, Text}
-import org.apache.hadoop.mapred.{JobConf, TextInputFormat}
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
-import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
-import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
+import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions}
import org.apache.spark.sql.catalyst.util.CompressionCodecs
@@ -37,29 +32,30 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration
class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
+ override val shortName: String = "json"
- override def shortName(): String = "json"
+ override def isSplitable(
+ sparkSession: SparkSession,
+ options: Map[String, String],
+ path: Path): Boolean = {
+ val parsedOptions = new JSONOptions(
+ options,
+ sparkSession.sessionState.conf.sessionLocalTimeZone,
+ sparkSession.sessionState.conf.columnNameOfCorruptRecord)
+ val jsonDataSource = JsonDataSource(parsedOptions)
+ jsonDataSource.isSplitable && super.isSplitable(sparkSession, options, path)
+ }
override def inferSchema(
sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] = {
- if (files.isEmpty) {
- None
- } else {
- val parsedOptions: JSONOptions =
- new JSONOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone)
- val columnNameOfCorruptRecord =
- parsedOptions.columnNameOfCorruptRecord
- .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord)
- val jsonSchema = JsonInferSchema.infer(
- createBaseRdd(sparkSession, files),
- columnNameOfCorruptRecord,
- parsedOptions)
- checkConstraints(jsonSchema)
-
- Some(jsonSchema)
- }
+ val parsedOptions = new JSONOptions(
+ options,
+ sparkSession.sessionState.conf.sessionLocalTimeZone,
+ sparkSession.sessionState.conf.columnNameOfCorruptRecord)
+ JsonDataSource(parsedOptions).infer(
+ sparkSession, files, parsedOptions)
}
override def prepareWrite(
@@ -68,8 +64,10 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
val conf = job.getConfiguration
- val parsedOptions: JSONOptions =
- new JSONOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone)
+ val parsedOptions = new JSONOptions(
+ options,
+ sparkSession.sessionState.conf.sessionLocalTimeZone,
+ sparkSession.sessionState.conf.columnNameOfCorruptRecord)
parsedOptions.compressionCodec.foreach { codec =>
CompressionCodecs.setCodecConfiguration(conf, codec)
}
@@ -99,47 +97,17 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
val broadcastedHadoopConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
- val parsedOptions: JSONOptions =
- new JSONOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone)
- val columnNameOfCorruptRecord = parsedOptions.columnNameOfCorruptRecord
- .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord)
+ val parsedOptions = new JSONOptions(
+ options,
+ sparkSession.sessionState.conf.sessionLocalTimeZone,
+ sparkSession.sessionState.conf.columnNameOfCorruptRecord)
(file: PartitionedFile) => {
- val linesReader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value)
- Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
- val lines = linesReader.map(_.toString)
- val parser = new JacksonParser(requiredSchema, columnNameOfCorruptRecord, parsedOptions)
- lines.flatMap(parser.parse)
- }
- }
-
- private def createBaseRdd(
- sparkSession: SparkSession,
- inputPaths: Seq[FileStatus]): RDD[String] = {
- val job = Job.getInstance(sparkSession.sessionState.newHadoopConf())
- val conf = job.getConfiguration
-
- val paths = inputPaths.map(_.getPath)
-
- if (paths.nonEmpty) {
- FileInputFormat.setInputPaths(job, paths: _*)
- }
-
- sparkSession.sparkContext.hadoopRDD(
- conf.asInstanceOf[JobConf],
- classOf[TextInputFormat],
- classOf[LongWritable],
- classOf[Text]).map(_._2.toString) // get the text line
- }
-
- /** Constraints to be imposed on schema to be stored. */
- private def checkConstraints(schema: StructType): Unit = {
- if (schema.fieldNames.length != schema.fieldNames.distinct.length) {
- val duplicateColumns = schema.fieldNames.groupBy(identity).collect {
- case (x, ys) if ys.length > 1 => "\"" + x + "\""
- }.mkString(", ")
- throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " +
- s"cannot save to JSON format")
+ val parser = new JacksonParser(requiredSchema, parsedOptions)
+ JsonDataSource(parsedOptions).readFile(
+ broadcastedHadoopConf.value.value,
+ file,
+ parser)
}
}
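
A minimal sketch of the splittability change (assumes a `spark: SparkSession` in scope and a hypothetical path; `org.apache.hadoop.fs.Path` imported): with `wholeFile` enabled the format delegates to `WholeFileJsonDataSource`, which reports `isSplitable = false`, so each JSON file is handed to a single task.

    val wholeFileOpts = Map("wholeFile" -> "true")
    val format = new JsonFileFormat
    // Expected to be false, because WholeFileJsonDataSource.isSplitable is false.
    format.isSplitable(spark, wholeFileOpts, new Path("/data/one-record.json"))
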
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
index f51c18d46f..ab09358115 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
@@ -36,13 +36,14 @@ private[sql] object JsonInferSchema {
* 2. Merge types by choosing the lowest type necessary to cover equal keys
* 3. Replace any remaining null fields with string, the top type
*/
- def infer(
- json: RDD[String],
- columnNameOfCorruptRecord: String,
- configOptions: JSONOptions): StructType = {
+ def infer[T](
+ json: RDD[T],
+ configOptions: JSONOptions,
+ createParser: (JsonFactory, T) => JsonParser): StructType = {
require(configOptions.samplingRatio > 0,
s"samplingRatio (${configOptions.samplingRatio}) should be greater than 0")
val shouldHandleCorruptRecord = configOptions.permissive
+ val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord
val schemaData = if (configOptions.samplingRatio > 0.99) {
json
} else {
@@ -55,7 +56,7 @@ private[sql] object JsonInferSchema {
configOptions.setJacksonOptions(factory)
iter.flatMap { row =>
try {
- Utils.tryWithResource(factory.createParser(row)) { parser =>
+ Utils.tryWithResource(createParser(factory, row)) { parser =>
parser.nextToken()
Some(inferField(parser, configOptions))
}
@@ -79,7 +80,7 @@ private[sql] object JsonInferSchema {
private[this] val structFieldComparator = new Comparator[StructField] {
override def compare(o1: StructField, o2: StructField): Int = {
- o1.name.compare(o2.name)
+ o1.name.compareTo(o2.name)
}
}
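
A minimal sketch of the generalized signature (assumes an existing `jsonRDD: RDD[String]`): callers now pass the factory callback that turns a record of type `T` into a Jackson `JsonParser`.

    import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JSONOptions}

    val inferred = JsonInferSchema.infer(
      jsonRDD,
      new JSONOptions(Map.empty[String, String], "GMT"),
      CreateJacksonParser.string)
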
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index 4e706da184..99943944f3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -141,8 +141,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
}
/**
- * Loads a JSON file stream (<a href="http://jsonlines.org/">JSON Lines text format or
- * newline-delimited JSON</a>) and returns the result as a `DataFrame`.
+ * Loads a JSON file stream and returns the results as a `DataFrame`.
+ *
+ * Both JSON (one record per file) and <a href="http://jsonlines.org/">JSON Lines</a>
+ * (newline-delimited JSON) are supported and can be selected with the `wholeFile` option.
*
* This function goes through the input once to determine the input schema. If you know the
* schema in advance, use the version that specifies the schema to avoid the extra scan.
@@ -183,6 +185,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* `java.text.SimpleDateFormat`. This applies to timestamp type.</li>
* <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
* to be used to parse timestamps.</li>
+ * <li>`wholeFile` (default `false`): parse one record, which may span multiple lines,
+ * per file</li>
* </ul>
*
* @since 2.0.0
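
A minimal streaming usage sketch (schema, input directory, and `spark` session are hypothetical): file-based streaming sources normally require a user-specified schema, and the same `wholeFile` option applies.

    import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

    val userSchema = new StructType().add("name", StringType).add("age", IntegerType)
    val streamingDF = spark.readStream
      .schema(userSchema)
      .option("wholeFile", true)
      .json("/path/to/streaming/json")
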
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 9344aeda00..05aa2ab2ce 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -28,8 +28,8 @@ import org.apache.hadoop.io.compress.GzipCodec
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkException
-import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions}
+import org.apache.spark.sql.{functions => F, _}
+import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.json.JsonInferSchema.compatibleType
@@ -64,7 +64,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
val dummyOption = new JSONOptions(Map.empty[String, String], "GMT")
val dummySchema = StructType(Seq.empty)
- val parser = new JacksonParser(dummySchema, "", dummyOption)
+ val parser = new JacksonParser(dummySchema, dummyOption)
Utils.tryWithResource(factory.createParser(writer.toString)) { jsonParser =>
jsonParser.nextToken()
@@ -1367,7 +1367,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
test("SPARK-6245 JsonRDD.inferSchema on empty RDD") {
// This is really a test that it doesn't throw an exception
val emptySchema = JsonInferSchema.infer(
- empty, "", new JSONOptions(Map.empty[String, String], "GMT"))
+ empty,
+ new JSONOptions(Map.empty[String, String], "GMT"),
+ CreateJacksonParser.string)
assert(StructType(Seq()) === emptySchema)
}
@@ -1392,7 +1394,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
test("SPARK-8093 Erase empty structs") {
val emptySchema = JsonInferSchema.infer(
- emptyRecords, "", new JSONOptions(Map.empty[String, String], "GMT"))
+ emptyRecords,
+ new JSONOptions(Map.empty[String, String], "GMT"),
+ CreateJacksonParser.string)
assert(StructType(Seq()) === emptySchema)
}
@@ -1802,4 +1806,142 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
val df2 = spark.read.option("PREfersdecimaL", "true").json(records)
assert(df2.schema == schema)
}
+
+ test("SPARK-18352: Parse normal multi-line JSON files (compressed)") {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ primitiveFieldAndType
+ .toDF("value")
+ .write
+ .option("compression", "GzIp")
+ .text(path)
+
+ assert(new File(path).listFiles().exists(_.getName.endsWith(".gz")))
+
+ val jsonDF = spark.read.option("wholeFile", true).json(path)
+ val jsonDir = new File(dir, "json").getCanonicalPath
+ jsonDF.coalesce(1).write
+ .option("compression", "gZiP")
+ .json(jsonDir)
+
+ assert(new File(jsonDir).listFiles().exists(_.getName.endsWith(".json.gz")))
+
+ val originalData = spark.read.json(primitiveFieldAndType)
+ checkAnswer(jsonDF, originalData)
+ checkAnswer(spark.read.schema(originalData.schema).json(jsonDir), originalData)
+ }
+ }
+
+ test("SPARK-18352: Parse normal multi-line JSON files (uncompressed)") {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ primitiveFieldAndType
+ .toDF("value")
+ .write
+ .text(path)
+
+ val jsonDF = spark.read.option("wholeFile", true).json(path)
+ val jsonDir = new File(dir, "json").getCanonicalPath
+ jsonDF.coalesce(1).write.json(jsonDir)
+
+ val compressedFiles = new File(jsonDir).listFiles()
+ assert(compressedFiles.exists(_.getName.endsWith(".json")))
+
+ val originalData = spark.read.json(primitiveFieldAndType)
+ checkAnswer(jsonDF, originalData)
+ checkAnswer(spark.read.schema(originalData.schema).json(jsonDir), originalData)
+ }
+ }
+
+ test("SPARK-18352: Expect one JSON document per file") {
+ // the json parser terminates as soon as it sees a matching END_OBJECT or END_ARRAY token.
+ // this might not be the optimal behavior but this test verifies that only the first value
+ // is parsed and the rest are discarded.
+
+ // alternatively the parser could continue parsing following objects, which may further reduce
+ // allocations by skipping the line reader entirely
+
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ spark
+ .createDataFrame(Seq(Tuple1("{}{invalid}")))
+ .coalesce(1)
+ .write
+ .text(path)
+
+ val jsonDF = spark.read.option("wholeFile", true).json(path)
+ // no corrupt record column should be created
+ assert(jsonDF.schema === StructType(Seq()))
+ // only the first object should be read
+ assert(jsonDF.count() === 1)
+ }
+ }
+
+ test("SPARK-18352: Handle multi-line corrupt documents (PERMISSIVE)") {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ val corruptRecordCount = additionalCorruptRecords.count().toInt
+ assert(corruptRecordCount === 5)
+
+ additionalCorruptRecords
+ .toDF("value")
+ // this is the minimum partition count that avoids hash collisions
+ .repartition(corruptRecordCount * 4, F.hash($"value"))
+ .write
+ .text(path)
+
+ val jsonDF = spark.read.option("wholeFile", true).option("mode", "PERMISSIVE").json(path)
+ assert(jsonDF.count() === corruptRecordCount)
+ assert(jsonDF.schema === new StructType()
+ .add("_corrupt_record", StringType)
+ .add("dummy", StringType))
+ val counts = jsonDF
+ .join(
+ additionalCorruptRecords.toDF("value"),
+ F.regexp_replace($"_corrupt_record", "(^\\s+|\\s+$)", "") === F.trim($"value"),
+ "outer")
+ .agg(
+ F.count($"dummy").as("valid"),
+ F.count($"_corrupt_record").as("corrupt"),
+ F.count("*").as("count"))
+ checkAnswer(counts, Row(1, 4, 6))
+ }
+ }
+
+ test("SPARK-18352: Handle multi-line corrupt documents (FAILFAST)") {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ val corruptRecordCount = additionalCorruptRecords.count().toInt
+ assert(corruptRecordCount === 5)
+
+ additionalCorruptRecords
+ .toDF("value")
+ // this is the minimum partition count that avoids hash collisions
+ .repartition(corruptRecordCount * 4, F.hash($"value"))
+ .write
+ .text(path)
+
+ val schema = new StructType().add("dummy", StringType)
+
+ // `FAILFAST` mode should throw an exception for corrupt records.
+ val exceptionOne = intercept[SparkException] {
+ spark.read
+ .option("wholeFile", true)
+ .option("mode", "FAILFAST")
+ .json(path)
+ .collect()
+ }
+ assert(exceptionOne.getMessage.contains("Malformed line in FAILFAST mode"))
+
+ val exceptionTwo = intercept[SparkException] {
+ spark.read
+ .option("wholeFile", true)
+ .option("mode", "FAILFAST")
+ .schema(schema)
+ .json(path)
+ .collect()
+ }
+ assert(exceptionTwo.getMessage.contains("Malformed line in FAILFAST mode"))
+ }
+ }
}