author	hyukjinkwon <gurwls223@gmail.com>	2017-03-15 10:19:19 +0800
committer	Wenchen Fan <wenchen@databricks.com>	2017-03-15 10:19:19 +0800
commit	8fb2a02e2ce6832e3d9338a7d0148dfac9fa24c2 (patch)
tree	bd542bdc9238fdc608e67dc0a32658135a3f69aa
parent	dacc382f0c918f1ca808228484305ce0e21c705e (diff)
[SPARK-19918][SQL] Use TextFileFormat in implementation of TextInputJsonDataSource
## What changes were proposed in this pull request?

This PR proposes to use the text datasource during JSON schema inference. This basically follows the same approach as https://github.com/apache/spark/pull/15813.

Using a Dataset for the initial load when inferring the schema has several advantages; please refer to SPARK-18362. It seems the JSON case was supposed to be fixed together but was taken out, according to https://github.com/apache/spark/pull/15813:

> A similar problem also affects the JSON file format and this patch originally fixed that as well, but I've decided to split that change into a separate patch so as not to conflict with changes in another JSON PR.

Also, not going through `FileScanRDD` seems to affect some functionality. This problem is described in SPARK-19885 (reported for the CSV case).

## How was this patch tested?

Existing tests should cover this; also tested manually with `spark.read.json(path)` and by checking the UI.

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #17255 from HyukjinKwon/json-filescanrdd.
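For reference, a minimal sketch of the two read paths this change touches (the session settings and input path are placeholders, not part of the patch); both now infer the schema through the text datasource rather than a hand-rolled Hadoop RDD:

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

// Minimal sketch; master, app name, and the input path are placeholders.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("json-schema-inference-sketch")
  .getOrCreate()
import spark.implicits._

// File-based path: JsonFileFormat.inferSchema -> JsonDataSource.inferSchema.
// For line-delimited JSON this now builds a Dataset[String] on top of
// TextFileFormat before handing sampled rows to JsonInferSchema.
val df = spark.read.json("/tmp/people.jsonl")
df.printSchema()

// Dataset[String] path: DataFrameReader.json(Dataset[String]) now calls
// TextInputJsonDataSource.inferFromDataset instead of running
// JsonInferSchema.infer over the Dataset's RDD directly.
val ds: Dataset[String] = Seq("""{"name":"a","age":1}""").toDS()
val df2 = spark.read.json(ds)
df2.show()
```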
-rw-r--r--	sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala	9
-rw-r--r--	sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala	145
-rw-r--r--	sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala	2
-rw-r--r--	sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala	9
-rw-r--r--	sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonUtils.scala	51
5 files changed, 122 insertions(+), 94 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index f1bce1aa41..309654c804 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.csv._
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.jdbc._
-import org.apache.spark.sql.execution.datasources.json.JsonInferSchema
+import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String
@@ -376,17 +376,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
extraOptions.toMap,
sparkSession.sessionState.conf.sessionLocalTimeZone,
sparkSession.sessionState.conf.columnNameOfCorruptRecord)
- val createParser = CreateJacksonParser.string _
val schema = userSpecifiedSchema.getOrElse {
- JsonInferSchema.infer(
- jsonDataset.rdd,
- parsedOptions,
- createParser)
+ TextInputJsonDataSource.inferFromDataset(jsonDataset, parsedOptions)
}
verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
+ val createParser = CreateJacksonParser.string _
val parsed = jsonDataset.rdd.mapPartitions { iter =>
val parser = new JacksonParser(schema, parsedOptions)
iter.flatMap(parser.parse(_, createParser, UTF8String.fromString))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
index 18843bfc30..84f026620d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
@@ -17,32 +17,30 @@
package org.apache.spark.sql.execution.datasources.json
-import scala.reflect.ClassTag
-
import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
import com.google.common.io.ByteStreams
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FileStatus
-import org.apache.hadoop.io.{LongWritable, Text}
+import org.apache.hadoop.io.Text
import org.apache.hadoop.mapreduce.Job
-import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, TextInputFormat}
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.spark.TaskContext
import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
import org.apache.spark.rdd.{BinaryFileRDD, RDD}
-import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
-import org.apache.spark.sql.execution.datasources.{CodecStreams, HadoopFileLinesReader, PartitionedFile}
+import org.apache.spark.sql.execution.datasources.{CodecStreams, DataSource, HadoopFileLinesReader, PartitionedFile}
+import org.apache.spark.sql.execution.datasources.text.TextFileFormat
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
/**
* Common functions for parsing JSON files
- * @tparam T A datatype containing the unparsed JSON, such as [[Text]] or [[String]]
*/
-abstract class JsonDataSource[T] extends Serializable {
+abstract class JsonDataSource extends Serializable {
def isSplitable: Boolean
/**
@@ -53,28 +51,12 @@ abstract class JsonDataSource[T] extends Serializable {
file: PartitionedFile,
parser: JacksonParser): Iterator[InternalRow]
- /**
- * Create an [[RDD]] that handles the preliminary parsing of [[T]] records
- */
- protected def createBaseRdd(
- sparkSession: SparkSession,
- inputPaths: Seq[FileStatus]): RDD[T]
-
- /**
- * A generic wrapper to invoke the correct [[JsonFactory]] method to allocate a [[JsonParser]]
- * for an instance of [[T]]
- */
- def createParser(jsonFactory: JsonFactory, value: T): JsonParser
-
- final def infer(
+ final def inferSchema(
sparkSession: SparkSession,
inputPaths: Seq[FileStatus],
parsedOptions: JSONOptions): Option[StructType] = {
if (inputPaths.nonEmpty) {
- val jsonSchema = JsonInferSchema.infer(
- createBaseRdd(sparkSession, inputPaths),
- parsedOptions,
- createParser)
+ val jsonSchema = infer(sparkSession, inputPaths, parsedOptions)
checkConstraints(jsonSchema)
Some(jsonSchema)
} else {
@@ -82,6 +64,11 @@ abstract class JsonDataSource[T] extends Serializable {
}
}
+ protected def infer(
+ sparkSession: SparkSession,
+ inputPaths: Seq[FileStatus],
+ parsedOptions: JSONOptions): StructType
+
/** Constraints to be imposed on schema to be stored. */
private def checkConstraints(schema: StructType): Unit = {
if (schema.fieldNames.length != schema.fieldNames.distinct.length) {
@@ -95,53 +82,46 @@ abstract class JsonDataSource[T] extends Serializable {
}
object JsonDataSource {
- def apply(options: JSONOptions): JsonDataSource[_] = {
+ def apply(options: JSONOptions): JsonDataSource = {
if (options.wholeFile) {
WholeFileJsonDataSource
} else {
TextInputJsonDataSource
}
}
-
- /**
- * Create a new [[RDD]] via the supplied callback if there is at least one file to process,
- * otherwise an [[org.apache.spark.rdd.EmptyRDD]] will be returned.
- */
- def createBaseRdd[T : ClassTag](
- sparkSession: SparkSession,
- inputPaths: Seq[FileStatus])(
- fn: (Configuration, String) => RDD[T]): RDD[T] = {
- val paths = inputPaths.map(_.getPath)
-
- if (paths.nonEmpty) {
- val job = Job.getInstance(sparkSession.sessionState.newHadoopConf())
- FileInputFormat.setInputPaths(job, paths: _*)
- fn(job.getConfiguration, paths.mkString(","))
- } else {
- sparkSession.sparkContext.emptyRDD[T]
- }
- }
}
-object TextInputJsonDataSource extends JsonDataSource[Text] {
+object TextInputJsonDataSource extends JsonDataSource {
override val isSplitable: Boolean = {
// splittable if the underlying source is
true
}
- override protected def createBaseRdd(
+ override def infer(
sparkSession: SparkSession,
- inputPaths: Seq[FileStatus]): RDD[Text] = {
- JsonDataSource.createBaseRdd(sparkSession, inputPaths) {
- case (conf, name) =>
- sparkSession.sparkContext.newAPIHadoopRDD(
- conf,
- classOf[TextInputFormat],
- classOf[LongWritable],
- classOf[Text])
- .setName(s"JsonLines: $name")
- .values // get the text column
- }
+ inputPaths: Seq[FileStatus],
+ parsedOptions: JSONOptions): StructType = {
+ val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths)
+ inferFromDataset(json, parsedOptions)
+ }
+
+ def inferFromDataset(json: Dataset[String], parsedOptions: JSONOptions): StructType = {
+ val sampled: Dataset[String] = JsonUtils.sample(json, parsedOptions)
+ val rdd: RDD[UTF8String] = sampled.queryExecution.toRdd.map(_.getUTF8String(0))
+ JsonInferSchema.infer(rdd, parsedOptions, CreateJacksonParser.utf8String)
+ }
+
+ private def createBaseDataset(
+ sparkSession: SparkSession,
+ inputPaths: Seq[FileStatus]): Dataset[String] = {
+ val paths = inputPaths.map(_.getPath.toString)
+ sparkSession.baseRelationToDataFrame(
+ DataSource.apply(
+ sparkSession,
+ paths = paths,
+ className = classOf[TextFileFormat].getName
+ ).resolveRelation(checkFilesExist = false))
+ .select("value").as(Encoders.STRING)
}
override def readFile(
@@ -150,41 +130,48 @@ object TextInputJsonDataSource extends JsonDataSource[Text] {
parser: JacksonParser): Iterator[InternalRow] = {
val linesReader = new HadoopFileLinesReader(file, conf)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
- linesReader.flatMap(parser.parse(_, createParser, textToUTF8String))
+ linesReader.flatMap(parser.parse(_, CreateJacksonParser.text, textToUTF8String))
}
private def textToUTF8String(value: Text): UTF8String = {
UTF8String.fromBytes(value.getBytes, 0, value.getLength)
}
-
- override def createParser(jsonFactory: JsonFactory, value: Text): JsonParser = {
- CreateJacksonParser.text(jsonFactory, value)
- }
}
-object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] {
+object WholeFileJsonDataSource extends JsonDataSource {
override val isSplitable: Boolean = {
false
}
- override protected def createBaseRdd(
+ override def infer(
+ sparkSession: SparkSession,
+ inputPaths: Seq[FileStatus],
+ parsedOptions: JSONOptions): StructType = {
+ val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths)
+ val sampled: RDD[PortableDataStream] = JsonUtils.sample(json, parsedOptions)
+ JsonInferSchema.infer(sampled, parsedOptions, createParser)
+ }
+
+ private def createBaseRdd(
sparkSession: SparkSession,
inputPaths: Seq[FileStatus]): RDD[PortableDataStream] = {
- JsonDataSource.createBaseRdd(sparkSession, inputPaths) {
- case (conf, name) =>
- new BinaryFileRDD(
- sparkSession.sparkContext,
- classOf[StreamInputFormat],
- classOf[String],
- classOf[PortableDataStream],
- conf,
- sparkSession.sparkContext.defaultMinPartitions)
- .setName(s"JsonFile: $name")
- .values
- }
+ val paths = inputPaths.map(_.getPath)
+ val job = Job.getInstance(sparkSession.sessionState.newHadoopConf())
+ val conf = job.getConfiguration
+ val name = paths.mkString(",")
+ FileInputFormat.setInputPaths(job, paths: _*)
+ new BinaryFileRDD(
+ sparkSession.sparkContext,
+ classOf[StreamInputFormat],
+ classOf[String],
+ classOf[PortableDataStream],
+ conf,
+ sparkSession.sparkContext.defaultMinPartitions)
+ .setName(s"JsonFile: $name")
+ .values
}
- override def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = {
+ private def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = {
CreateJacksonParser.inputStream(
jsonFactory,
CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, record.getPath()))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index 902fee5a7e..a9dd91eba6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -54,7 +54,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
options,
sparkSession.sessionState.conf.sessionLocalTimeZone,
sparkSession.sessionState.conf.columnNameOfCorruptRecord)
- JsonDataSource(parsedOptions).infer(
+ JsonDataSource(parsedOptions).inferSchema(
sparkSession, files, parsedOptions)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
index ab09358115..7475f8ec79 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
@@ -40,18 +40,11 @@ private[sql] object JsonInferSchema {
json: RDD[T],
configOptions: JSONOptions,
createParser: (JsonFactory, T) => JsonParser): StructType = {
- require(configOptions.samplingRatio > 0,
- s"samplingRatio (${configOptions.samplingRatio}) should be greater than 0")
val shouldHandleCorruptRecord = configOptions.permissive
val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord
- val schemaData = if (configOptions.samplingRatio > 0.99) {
- json
- } else {
- json.sample(withReplacement = false, configOptions.samplingRatio, 1)
- }
// perform schema inference on each row and merge afterwards
- val rootType = schemaData.mapPartitions { iter =>
+ val rootType = json.mapPartitions { iter =>
val factory = new JsonFactory()
configOptions.setJacksonOptions(factory)
iter.flatMap { row =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonUtils.scala
new file mode 100644
index 0000000000..d511594c5d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonUtils.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.json
+
+import org.apache.spark.input.PortableDataStream
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.catalyst.json.JSONOptions
+
+object JsonUtils {
+ /**
+ * Sample JSON dataset as configured by `samplingRatio`.
+ */
+ def sample(json: Dataset[String], options: JSONOptions): Dataset[String] = {
+ require(options.samplingRatio > 0,
+ s"samplingRatio (${options.samplingRatio}) should be greater than 0")
+ if (options.samplingRatio > 0.99) {
+ json
+ } else {
+ json.sample(withReplacement = false, options.samplingRatio, 1)
+ }
+ }
+
+ /**
+ * Sample JSON RDD as configured by `samplingRatio`.
+ */
+ def sample(json: RDD[PortableDataStream], options: JSONOptions): RDD[PortableDataStream] = {
+ require(options.samplingRatio > 0,
+ s"samplingRatio (${options.samplingRatio}) should be greater than 0")
+ if (options.samplingRatio > 0.99) {
+ json
+ } else {
+ json.sample(withReplacement = false, options.samplingRatio, 1)
+ }
+ }
+}
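As a usage note for the new `JsonUtils.sample` helpers above: they only subsample when `samplingRatio` is set below roughly 1.0, driven through the existing reader option. A minimal sketch, with a placeholder path and assuming the `spark` session from the sketch above:

```scala
// With samplingRatio <= 0.99, JsonUtils.sample subsamples the Dataset[String]
// (or the RDD[PortableDataStream] in wholeFile mode) before schema inference.
val sampledSchemaDf = spark.read
  .option("samplingRatio", "0.1")
  .json("/tmp/large.jsonl")
sampledSchemaDf.printSchema()
```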