author     hyukjinkwon <gurwls223@gmail.com>        2017-02-28 13:34:33 -0800
committer  Wenchen Fan <wenchen@databricks.com>     2017-02-28 13:34:33 -0800
commit     7e5359be5ca038fdb579712b18e7f226d705c276 (patch)
tree       6fca55568b53c2ded63bcbf846a8463ffcafc92a /sql/core/src/main/scala/org/apache
parent     ce233f18e381fa1ea00be74ca26e97d35baa6c9c (diff)
[SPARK-19610][SQL] Support parsing multiline CSV files
## What changes were proposed in this pull request?

This PR proposes support for multi-line records in CSV, mirroring the multi-line support in the JSON datasource (in the case of JSON, per file). It introduces a `wholeFile` option, which makes the format non-splittable and reads each file as a whole. Since the Univocity parser can produce rows one at a time from a stream, it should be capable of parsing very large documents as long as the individual rows fit in memory.

## How was this patch tested?

Unit tests in `CSVSuite` and `tests.py`.

Manual tests with a single 9GB CSV file on a local file system, for example:

```scala
spark.read.option("wholeFile", true).option("inferSchema", true).csv("tmp.csv").count()
```

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #16976 from HyukjinKwon/SPARK-19610.
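A minimal end-to-end sketch of the new option, assuming a spark-shell session (`spark` in scope); the path and file contents are illustrative only:

```scala
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Paths}

// A CSV file whose second record contains an embedded newline inside a quoted field.
val csvPath = Paths.get("/tmp/multiline-example.csv")
val content =
  """id,comment
    |1,"single line"
    |2,"first line
    |second line"
    |""".stripMargin
Files.write(csvPath, content.getBytes(StandardCharsets.UTF_8))

// Without `wholeFile`, input is split line by line and the quoted newline breaks the record.
// With `wholeFile`, each file is fed to the parser as a single stream, so the record survives.
val df = spark.read
  .option("header", true)
  .option("wholeFile", true)
  .csv(csvPath.toString)

df.show(truncate = false)   // expected: 2 rows, the second with a two-line comment
```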
Diffstat (limited to 'sql/core/src/main/scala/org/apache')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala  1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala  12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala  239
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala  77
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala  59
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala  94
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala  18
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala  1
9 files changed, 371 insertions, 132 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 59baf6e567..63be1e5302 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -463,6 +463,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* <li>`columnNameOfCorruptRecord` (default is the value specified in
* `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
* created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
+ * <li>`wholeFile` (default `false`): parse one record, which may span multiple lines.</li>
* </ul>
* @since 2.0.0
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala
index 0762d1b7da..54549f698a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala
@@ -27,6 +27,8 @@ import org.apache.hadoop.mapreduce.JobContext
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
import org.apache.hadoop.util.ReflectionUtils
+import org.apache.spark.TaskContext
+
object CodecStreams {
private def getDecompressionCodec(config: Configuration, file: Path): Option[CompressionCodec] = {
val compressionCodecs = new CompressionCodecFactory(config)
@@ -42,6 +44,16 @@ object CodecStreams {
.getOrElse(inputStream)
}
+ /**
+ * Creates an input stream from the string path and adds a task-completion listener so that
+ * the input stream is closed when the task finishes.
+ */
+ def createInputStreamWithCloseResource(config: Configuration, path: String): InputStream = {
+ val inputStream = createInputStream(config, new Path(path))
+ Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => inputStream.close()))
+ inputStream
+ }
+
private def getCompressionCodec(
context: JobContext,
file: Option[Path] = None): Option[CompressionCodec] = {
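The helper added above ties the stream's lifetime to the running task. A standalone sketch of the same pattern, not the private `CodecStreams` API itself; `openWithTaskScopedClose` is a hypothetical name and the code assumes it runs inside a Spark task (e.g. within `mapPartitions`):

```scala
import java.io.InputStream

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.TaskContext

// Opens a Hadoop input stream and closes it when the enclosing task completes,
// even if the consuming iterator is never fully drained. (Hypothetical helper name.)
def openWithTaskScopedClose(conf: Configuration, pathString: String): InputStream = {
  val path = new Path(pathString)
  val in = path.getFileSystem(conf).open(path)
  // TaskContext.get() returns null on the driver; only register the listener inside a task.
  Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => in.close()))
  in
}
```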
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
new file mode 100644
index 0000000000..73e6abc6da
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import java.io.InputStream
+import java.nio.charset.{Charset, StandardCharsets}
+
+import com.univocity.parsers.csv.{CsvParser, CsvParserSettings}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.io.{LongWritable, Text}
+import org.apache.hadoop.mapred.TextInputFormat
+import org.apache.hadoop.mapreduce.Job
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
+
+import org.apache.spark.TaskContext
+import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
+import org.apache.spark.rdd.{BinaryFileRDD, RDD}
+import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.datasources.text.TextFileFormat
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Common functions for parsing CSV files
+ */
+abstract class CSVDataSource extends Serializable {
+ def isSplitable: Boolean
+
+ /**
+ * Parse a [[PartitionedFile]] into [[InternalRow]] instances.
+ */
+ def readFile(
+ conf: Configuration,
+ file: PartitionedFile,
+ parser: UnivocityParser,
+ parsedOptions: CSVOptions): Iterator[InternalRow]
+
+ /**
+ * Infers the schema from `inputPaths` files.
+ */
+ def infer(
+ sparkSession: SparkSession,
+ inputPaths: Seq[FileStatus],
+ parsedOptions: CSVOptions): Option[StructType]
+
+ /**
+ * Generates a header from the given row which is null-safe and duplicate-safe.
+ */
+ protected def makeSafeHeader(
+ row: Array[String],
+ caseSensitive: Boolean,
+ options: CSVOptions): Array[String] = {
+ if (options.headerFlag) {
+ val duplicates = {
+ val headerNames = row.filter(_ != null)
+ .map(name => if (caseSensitive) name else name.toLowerCase)
+ headerNames.diff(headerNames.distinct).distinct
+ }
+
+ row.zipWithIndex.map { case (value, index) =>
+ if (value == null || value.isEmpty || value == options.nullValue) {
+ // When there are empty strings or the values set in `nullValue`, put the
+ // index as the suffix.
+ s"_c$index"
+ } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) {
+ // When there are case-insensitive duplicates, put the index as the suffix.
+ s"$value$index"
+ } else if (duplicates.contains(value)) {
+ // When there are duplicates, put the index as the suffix.
+ s"$value$index"
+ } else {
+ value
+ }
+ }
+ } else {
+ row.zipWithIndex.map { case (_, index) =>
+ // Uses default column names, "_c#" where # is its position of fields
+ // when header option is disabled.
+ s"_c$index"
+ }
+ }
+ }
+}
+
+object CSVDataSource {
+ def apply(options: CSVOptions): CSVDataSource = {
+ if (options.wholeFile) {
+ WholeFileCSVDataSource
+ } else {
+ TextInputCSVDataSource
+ }
+ }
+}
+
+object TextInputCSVDataSource extends CSVDataSource {
+ override val isSplitable: Boolean = true
+
+ override def readFile(
+ conf: Configuration,
+ file: PartitionedFile,
+ parser: UnivocityParser,
+ parsedOptions: CSVOptions): Iterator[InternalRow] = {
+ val lines = {
+ val linesReader = new HadoopFileLinesReader(file, conf)
+ Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
+ linesReader.map { line =>
+ new String(line.getBytes, 0, line.getLength, parsedOptions.charset)
+ }
+ }
+
+ val shouldDropHeader = parsedOptions.headerFlag && file.start == 0
+ UnivocityParser.parseIterator(lines, shouldDropHeader, parser)
+ }
+
+ override def infer(
+ sparkSession: SparkSession,
+ inputPaths: Seq[FileStatus],
+ parsedOptions: CSVOptions): Option[StructType] = {
+ val csv: Dataset[String] = createBaseDataset(sparkSession, inputPaths, parsedOptions)
+ val firstLine: String = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).first()
+ val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine)
+ val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+ val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
+ val tokenRDD = csv.rdd.mapPartitions { iter =>
+ val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
+ val linesWithoutHeader =
+ CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions)
+ val parser = new CsvParser(parsedOptions.asParserSettings)
+ linesWithoutHeader.map(parser.parseLine)
+ }
+
+ Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions))
+ }
+
+ private def createBaseDataset(
+ sparkSession: SparkSession,
+ inputPaths: Seq[FileStatus],
+ options: CSVOptions): Dataset[String] = {
+ val paths = inputPaths.map(_.getPath.toString)
+ if (Charset.forName(options.charset) == StandardCharsets.UTF_8) {
+ sparkSession.baseRelationToDataFrame(
+ DataSource.apply(
+ sparkSession,
+ paths = paths,
+ className = classOf[TextFileFormat].getName
+ ).resolveRelation(checkFilesExist = false))
+ .select("value").as[String](Encoders.STRING)
+ } else {
+ val charset = options.charset
+ val rdd = sparkSession.sparkContext
+ .hadoopFile[LongWritable, Text, TextInputFormat](paths.mkString(","))
+ .mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset)))
+ sparkSession.createDataset(rdd)(Encoders.STRING)
+ }
+ }
+}
+
+object WholeFileCSVDataSource extends CSVDataSource {
+ override val isSplitable: Boolean = false
+
+ override def readFile(
+ conf: Configuration,
+ file: PartitionedFile,
+ parser: UnivocityParser,
+ parsedOptions: CSVOptions): Iterator[InternalRow] = {
+ UnivocityParser.parseStream(
+ CodecStreams.createInputStreamWithCloseResource(conf, file.filePath),
+ parsedOptions.headerFlag,
+ parser)
+ }
+
+ override def infer(
+ sparkSession: SparkSession,
+ inputPaths: Seq[FileStatus],
+ parsedOptions: CSVOptions): Option[StructType] = {
+ val csv: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths, parsedOptions)
+ val maybeFirstRow: Option[Array[String]] = csv.flatMap { lines =>
+ UnivocityParser.tokenizeStream(
+ CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()),
+ false,
+ new CsvParser(parsedOptions.asParserSettings))
+ }.take(1).headOption
+
+ if (maybeFirstRow.isDefined) {
+ val firstRow = maybeFirstRow.get
+ val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+ val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
+ val tokenRDD = csv.flatMap { lines =>
+ UnivocityParser.tokenizeStream(
+ CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()),
+ parsedOptions.headerFlag,
+ new CsvParser(parsedOptions.asParserSettings))
+ }
+ Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions))
+ } else {
+ // If the first row could not be read, just return the empty schema.
+ Some(StructType(Nil))
+ }
+ }
+
+ private def createBaseRdd(
+ sparkSession: SparkSession,
+ inputPaths: Seq[FileStatus],
+ options: CSVOptions): RDD[PortableDataStream] = {
+ val paths = inputPaths.map(_.getPath)
+ val name = paths.mkString(",")
+ val job = Job.getInstance(sparkSession.sessionState.newHadoopConf())
+ FileInputFormat.setInputPaths(job, paths: _*)
+ val conf = job.getConfiguration
+
+ val rdd = new BinaryFileRDD(
+ sparkSession.sparkContext,
+ classOf[StreamInputFormat],
+ classOf[String],
+ classOf[PortableDataStream],
+ conf,
+ sparkSession.sparkContext.defaultMinPartitions)
+
+ // Only returns `PortableDataStream`s without paths.
+ rdd.setName(s"CSVFile: $name").values
+ }
+}
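The `makeSafeHeader` rules above can be shown in isolation. A minimal sketch under simplified assumptions (always case-insensitive; `safeHeader` is a hypothetical helper, not the internal API):

```scala
// Simplified re-statement of the header rules: null/empty/nullValue columns become "_c<index>",
// and duplicate names get the index appended.
def safeHeader(row: Array[String], nullValue: String = ""): Array[String] = {
  val names = row.filter(_ != null).map(_.toLowerCase)
  val duplicates = names.diff(names.distinct).distinct
  row.zipWithIndex.map {
    case (v, i) if v == null || v.isEmpty || v == nullValue => s"_c$i"
    case (v, i) if duplicates.contains(v.toLowerCase)       => s"$v$i"
    case (v, _)                                             => v
  }
}

safeHeader(Array("id", "ID", null, "", "name"))
// -> Array("id0", "ID1", "_c2", "_c3", "name")
```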
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index 59f2919edf..29c4145527 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -17,21 +17,15 @@
package org.apache.spark.sql.execution.datasources.csv
-import java.nio.charset.{Charset, StandardCharsets}
-
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
-import org.apache.hadoop.io.{LongWritable, Text}
-import org.apache.hadoop.mapred.TextInputFormat
import org.apache.hadoop.mapreduce._
-import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession}
+import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.CompressionCodecs
import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.execution.datasources.text.TextFileFormat
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.SerializableConfiguration
@@ -43,11 +37,15 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
override def shortName(): String = "csv"
- override def toString: String = "CSV"
-
- override def hashCode(): Int = getClass.hashCode()
-
- override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat]
+ override def isSplitable(
+ sparkSession: SparkSession,
+ options: Map[String, String],
+ path: Path): Boolean = {
+ val parsedOptions =
+ new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone)
+ val csvDataSource = CSVDataSource(parsedOptions)
+ csvDataSource.isSplitable && super.isSplitable(sparkSession, options, path)
+ }
override def inferSchema(
sparkSession: SparkSession,
@@ -55,11 +53,10 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
files: Seq[FileStatus]): Option[StructType] = {
require(files.nonEmpty, "Cannot infer schema from an empty set of files")
- val csvOptions = new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone)
- val paths = files.map(_.getPath.toString)
- val lines: Dataset[String] = createBaseDataset(sparkSession, csvOptions, paths)
- val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
- Some(CSVInferSchema.infer(lines, caseSensitive, csvOptions))
+ val parsedOptions =
+ new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone)
+
+ CSVDataSource(parsedOptions).infer(sparkSession, files, parsedOptions)
}
override def prepareWrite(
@@ -115,49 +112,17 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
}
(file: PartitionedFile) => {
- val lines = {
- val conf = broadcastedHadoopConf.value.value
- val linesReader = new HadoopFileLinesReader(file, conf)
- Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
- linesReader.map { line =>
- new String(line.getBytes, 0, line.getLength, parsedOptions.charset)
- }
- }
-
- val linesWithoutHeader = if (parsedOptions.headerFlag && file.start == 0) {
- // Note that if there are only comments in the first block, the header would probably
- // be not dropped.
- CSVUtils.dropHeaderLine(lines, parsedOptions)
- } else {
- lines
- }
-
- val filteredLines = CSVUtils.filterCommentAndEmpty(linesWithoutHeader, parsedOptions)
+ val conf = broadcastedHadoopConf.value.value
val parser = new UnivocityParser(dataSchema, requiredSchema, parsedOptions)
- filteredLines.flatMap(parser.parse)
+ CSVDataSource(parsedOptions).readFile(conf, file, parser, parsedOptions)
}
}
- private def createBaseDataset(
- sparkSession: SparkSession,
- options: CSVOptions,
- inputPaths: Seq[String]): Dataset[String] = {
- if (Charset.forName(options.charset) == StandardCharsets.UTF_8) {
- sparkSession.baseRelationToDataFrame(
- DataSource.apply(
- sparkSession,
- paths = inputPaths,
- className = classOf[TextFileFormat].getName
- ).resolveRelation(checkFilesExist = false))
- .select("value").as[String](Encoders.STRING)
- } else {
- val charset = options.charset
- val rdd = sparkSession.sparkContext
- .hadoopFile[LongWritable, Text, TextInputFormat](inputPaths.mkString(","))
- .mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset)))
- sparkSession.createDataset(rdd)(Encoders.STRING)
- }
- }
+ override def toString: String = "CSV"
+
+ override def hashCode(): Int = getClass.hashCode()
+
+ override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat]
}
private[csv] class CsvOutputWriter(
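Because `CSVDataSource.isSplitable` is false in whole-file mode, each input file is handed to the parser as a single unit. A rough way to observe this from the user API, assuming a spark-shell session and a large uncompressed CSV at an illustrative path; exact partition counts depend on file size and `spark.sql.files.maxPartitionBytes`:

```scala
val lineByLine = spark.read.option("header", true).csv("/tmp/big.csv")
val wholeFile  = spark.read.option("header", true).option("wholeFile", true).csv("/tmp/big.csv")

println(lineByLine.rdd.getNumPartitions) // may be > 1: the file can be split into byte ranges
println(wholeFile.rdd.getNumPartitions)  // 1 for a single file: the format is not splittable
```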
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index 3fa30fe240..b64d71bb4e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -21,11 +21,9 @@ import java.math.BigDecimal
import scala.util.control.Exception._
-import com.univocity.parsers.csv.CsvParser
-
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types._
private[csv] object CSVInferSchema {
@@ -37,24 +35,13 @@ private[csv] object CSVInferSchema {
* 3. Replace any null types with string type
*/
def infer(
- csv: Dataset[String],
- caseSensitive: Boolean,
+ tokenRDD: RDD[Array[String]],
+ header: Array[String],
options: CSVOptions): StructType = {
- val firstLine: String = CSVUtils.filterCommentAndEmpty(csv, options).first()
- val firstRow = new CsvParser(options.asParserSettings).parseLine(firstLine)
- val header = makeSafeHeader(firstRow, caseSensitive, options)
-
val fields = if (options.inferSchemaFlag) {
- val tokenRdd = csv.rdd.mapPartitions { iter =>
- val filteredLines = CSVUtils.filterCommentAndEmpty(iter, options)
- val linesWithoutHeader = CSVUtils.filterHeaderLine(filteredLines, firstLine, options)
- val parser = new CsvParser(options.asParserSettings)
- linesWithoutHeader.map(parser.parseLine)
- }
-
val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
val rootTypes: Array[DataType] =
- tokenRdd.aggregate(startType)(inferRowType(options), mergeRowTypes)
+ tokenRDD.aggregate(startType)(inferRowType(options), mergeRowTypes)
header.zip(rootTypes).map { case (thisHeader, rootType) =>
val dType = rootType match {
@@ -71,44 +58,6 @@ private[csv] object CSVInferSchema {
StructType(fields)
}
- /**
- * Generates a header from the given row which is null-safe and duplicate-safe.
- */
- private def makeSafeHeader(
- row: Array[String],
- caseSensitive: Boolean,
- options: CSVOptions): Array[String] = {
- if (options.headerFlag) {
- val duplicates = {
- val headerNames = row.filter(_ != null)
- .map(name => if (caseSensitive) name else name.toLowerCase)
- headerNames.diff(headerNames.distinct).distinct
- }
-
- row.zipWithIndex.map { case (value, index) =>
- if (value == null || value.isEmpty || value == options.nullValue) {
- // When there are empty strings or the values set in `nullValue`, put the
- // index as the suffix.
- s"_c$index"
- } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) {
- // When there are case-insensitive duplicates, put the index as the suffix.
- s"$value$index"
- } else if (duplicates.contains(value)) {
- // When there are duplicates, put the index as the suffix.
- s"$value$index"
- } else {
- value
- }
- }
- } else {
- row.zipWithIndex.map { case (_, index) =>
- // Uses default column names, "_c#" where # is its position of fields
- // when header option is disabled.
- s"_c$index"
- }
- }
- }
-
private def inferRowType(options: CSVOptions)
(rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
var i = 0
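After this refactoring, `infer` just folds token arrays into per-column types with `RDD.aggregate`. A self-contained sketch of that fold shape, assuming a spark-shell session and using toy `widen`/`typeOfToken` functions in place of the real inference and coercion rules:

```scala
import org.apache.spark.sql.types._

// Toy stand-ins for inferRowType/mergeRowTypes: NullType widens to any type,
// and conflicting types fall back to StringType.
def widen(a: DataType, b: DataType): DataType = (a, b) match {
  case (NullType, t)    => t
  case (t, NullType)    => t
  case (x, y) if x == y => x
  case _                => StringType
}

def typeOfToken(s: String): DataType =
  if (s == null || s.isEmpty) NullType
  else if (scala.util.Try(s.toInt).isSuccess) IntegerType
  else StringType

val header = Array("id", "name")
val tokenRDD = spark.sparkContext.parallelize(Seq(Array("1", "a"), Array("2", null)))

val startType: Array[DataType] = Array.fill(header.length)(NullType)
val rootTypes = tokenRDD.aggregate(startType)(
  (acc, row) => acc.zip(row).map { case (t, token) => widen(t, typeOfToken(token)) },
  (l, r) => l.zip(r).map { case (a, b) => widen(a, b) })

val fields = header.zip(rootTypes).map {
  case (name, NullType) => StructField(name, StringType)
  case (name, t)        => StructField(name, t)
}
StructType(fields)
// -> StructType(StructField(id,IntegerType,true), StructField(name,StringType,true))
```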
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 1caeec7c63..50503385ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -130,6 +130,8 @@ private[csv] class CSVOptions(
FastDateFormat.getInstance(
parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), timeZone, Locale.US)
+ val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false)
+
val maxColumns = getInt("maxColumns", 20480)
val maxCharsPerColumn = getInt("maxCharsPerColumn", -1)
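`wholeFile` is read with Scala's `String.toBoolean`, so only the literals `true`/`false` (case-insensitive) are accepted. A tiny sketch of that behaviour; `parameters` is a stand-in map, not the internal options map:

```scala
import scala.util.Try

val parameters = Map("wholeFile" -> "TRUE")

// Mirrors `parameters.get("wholeFile").map(_.toBoolean).getOrElse(false)` above.
val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false)   // true

// Values like "yes" or "1" are not valid booleans for String.toBoolean and would throw.
Try(Map("wholeFile" -> "yes")("wholeFile").toBoolean).isFailure                 // true
```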
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
index eb471651db..804031a5bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution.datasources.csv
+import java.io.InputStream
import java.math.BigDecimal
import java.text.NumberFormat
import java.util.Locale
@@ -36,7 +37,7 @@ import org.apache.spark.unsafe.types.UTF8String
private[csv] class UnivocityParser(
schema: StructType,
requiredSchema: StructType,
- options: CSVOptions) extends Logging {
+ private val options: CSVOptions) extends Logging {
require(requiredSchema.toSet.subsetOf(schema.toSet),
"requiredSchema should be the subset of schema.")
@@ -56,12 +57,15 @@ private[csv] class UnivocityParser(
private val valueConverters =
dataSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray
- private val parser = new CsvParser(options.asParserSettings)
+ private val tokenizer = new CsvParser(options.asParserSettings)
private var numMalformedRecords = 0
private val row = new GenericInternalRow(requiredSchema.length)
+ // This gets the raw input that was most recently parsed.
+ private def getCurrentInput(): String = tokenizer.getContext.currentParsedContent().stripLineEnd
+
// This parser loads an `indexArr._1`-th position value in input tokens,
// then put the value in `row(indexArr._2)`.
private val indexArr: Array[(Int, Int)] = {
@@ -188,12 +192,13 @@ private[csv] class UnivocityParser(
}
/**
- * Parses a single CSV record (in the form of an array of strings in which
- * each element represents a column) and turns it into either one resulting row or no row (if the
+ * Parses a single CSV string and turns it into either one resulting row or no row (if
* the record is malformed).
*/
- def parse(input: String): Option[InternalRow] = {
- convertWithParseMode(input) { tokens =>
+ def parse(input: String): Option[InternalRow] = convert(tokenizer.parseLine(input))
+
+ private def convert(tokens: Array[String]): Option[InternalRow] = {
+ convertWithParseMode(tokens) { tokens =>
var i: Int = 0
while (i < indexArr.length) {
val (pos, rowIdx) = indexArr(i)
@@ -211,8 +216,7 @@ private[csv] class UnivocityParser(
}
private def convertWithParseMode(
- input: String)(convert: Array[String] => InternalRow): Option[InternalRow] = {
- val tokens = parser.parseLine(input)
+ tokens: Array[String])(convert: Array[String] => InternalRow): Option[InternalRow] = {
if (options.dropMalformed && dataSchema.length != tokens.length) {
if (numMalformedRecords < options.maxMalformedLogPerPartition) {
logWarning(s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}")
@@ -251,7 +255,7 @@ private[csv] class UnivocityParser(
} catch {
case NonFatal(e) if options.permissive =>
val row = new GenericInternalRow(requiredSchema.length)
- corruptFieldIndex.foreach(row(_) = UTF8String.fromString(input))
+ corruptFieldIndex.foreach(row(_) = UTF8String.fromString(getCurrentInput()))
Some(row)
case NonFatal(e) if options.dropMalformed =>
if (numMalformedRecords < options.maxMalformedLogPerPartition) {
@@ -269,3 +273,75 @@ private[csv] class UnivocityParser(
}
}
}
+
+private[csv] object UnivocityParser {
+
+ /**
+ * Parses a stream that contains CSV strings and turns it into an iterator of tokens.
+ */
+ def tokenizeStream(
+ inputStream: InputStream,
+ shouldDropHeader: Boolean,
+ tokenizer: CsvParser): Iterator[Array[String]] = {
+ convertStream(inputStream, shouldDropHeader, tokenizer)(tokens => tokens)
+ }
+
+ /**
+ * Parses a stream that contains CSV strings and turns it into an iterator of rows.
+ */
+ def parseStream(
+ inputStream: InputStream,
+ shouldDropHeader: Boolean,
+ parser: UnivocityParser): Iterator[InternalRow] = {
+ val tokenizer = parser.tokenizer
+ convertStream(inputStream, shouldDropHeader, tokenizer) { tokens =>
+ parser.convert(tokens)
+ }.flatten
+ }
+
+ private def convertStream[T](
+ inputStream: InputStream,
+ shouldDropHeader: Boolean,
+ tokenizer: CsvParser)(convert: Array[String] => T) = new Iterator[T] {
+ tokenizer.beginParsing(inputStream)
+ private var nextRecord = {
+ if (shouldDropHeader) {
+ tokenizer.parseNext()
+ }
+ tokenizer.parseNext()
+ }
+
+ override def hasNext: Boolean = nextRecord != null
+
+ override def next(): T = {
+ if (!hasNext) {
+ throw new NoSuchElementException("End of stream")
+ }
+ val curRecord = convert(nextRecord)
+ nextRecord = tokenizer.parseNext()
+ curRecord
+ }
+ }
+
+ /**
+ * Parses an iterator that contains CSV strings and turns it into an iterator of rows.
+ */
+ def parseIterator(
+ lines: Iterator[String],
+ shouldDropHeader: Boolean,
+ parser: UnivocityParser): Iterator[InternalRow] = {
+ val options = parser.options
+
+ val linesWithoutHeader = if (shouldDropHeader) {
+ // Note that if there are only comments in the first block, the header would probably
+ // not be dropped.
+ CSVUtils.dropHeaderLine(lines, options)
+ } else {
+ lines
+ }
+
+ val filteredLines: Iterator[String] =
+ CSVUtils.filterCommentAndEmpty(linesWithoutHeader, options)
+ filteredLines.flatMap(line => parser.parse(line))
+ }
+}
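`convertStream` wraps Univocity's pull API (`beginParsing`/`parseNext`) in a pre-reading iterator. A standalone sketch of the same shape against an in-memory stream, assuming `univocity-parsers` is on the classpath as it is for Spark's CSV support:

```scala
import java.io.ByteArrayInputStream
import java.nio.charset.StandardCharsets

import com.univocity.parsers.csv.{CsvParser, CsvParserSettings}

val settings = new CsvParserSettings()
settings.setLineSeparatorDetectionEnabled(true)
val tokenizer = new CsvParser(settings)

val input = new ByteArrayInputStream(
  "a,b\n1,\"line one\nline two\"\n2,x\n".getBytes(StandardCharsets.UTF_8))
tokenizer.beginParsing(input)

// Same shape as convertStream: pre-read one record; hasNext means "last read was not null".
val rows = new Iterator[Array[String]] {
  private var nextRecord = tokenizer.parseNext()
  override def hasNext: Boolean = nextRecord != null
  override def next(): Array[String] = {
    val current = nextRecord
    nextRecord = tokenizer.parseNext()
    current
  }
}

rows.foreach(r => println(r.mkString(" | ")))
// a | b
// 1 | line one
// line two        <- the quoted newline stays inside a single token
// 2 | x
```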
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
index 3e984effcb..18843bfc30 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
@@ -17,14 +17,12 @@
package org.apache.spark.sql.execution.datasources.json
-import java.io.InputStream
-
import scala.reflect.ClassTag
import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
import com.google.common.io.ByteStreams
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.fs.FileStatus
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, TextInputFormat}
@@ -186,16 +184,10 @@ object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] {
}
}
- private def createInputStream(config: Configuration, path: String): InputStream = {
- val inputStream = CodecStreams.createInputStream(config, new Path(path))
- Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => inputStream.close()))
- inputStream
- }
-
override def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = {
CreateJacksonParser.inputStream(
jsonFactory,
- createInputStream(record.getConfiguration, record.getPath()))
+ CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, record.getPath()))
}
override def readFile(
@@ -203,13 +195,15 @@ object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] {
file: PartitionedFile,
parser: JacksonParser): Iterator[InternalRow] = {
def partitionedFileString(ignored: Any): UTF8String = {
- Utils.tryWithResource(createInputStream(conf, file.filePath)) { inputStream =>
+ Utils.tryWithResource {
+ CodecStreams.createInputStreamWithCloseResource(conf, file.filePath)
+ } { inputStream =>
UTF8String.fromBytes(ByteStreams.toByteArray(inputStream))
}
}
parser.parse(
- createInputStream(conf, file.filePath),
+ CodecStreams.createInputStreamWithCloseResource(conf, file.filePath),
CreateJacksonParser.inputStream,
partitionedFileString).toIterator
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index f78e73f319..6a275281d8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -261,6 +261,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* <li>`columnNameOfCorruptRecord` (default is the value specified in
* `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
* created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
+ * <li>`wholeFile` (default `false`): parse one record, which may span multiple lines.</li>
* </ul>
*
* @since 2.0.0
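The same option is exposed on the streaming reader; file-based streaming sources require a user-supplied schema, so inference is not shown here. A hedged sketch with placeholder paths, assuming a spark-shell session:

```scala
import org.apache.spark.sql.types._

val schema = new StructType()
  .add("id", IntegerType)
  .add("comment", StringType)

val stream = spark.readStream
  .schema(schema)                 // file-based streaming sources need an explicit schema
  .option("header", true)
  .option("wholeFile", true)      // each newly arriving CSV file is parsed as one stream
  .csv("/tmp/incoming-csv/")      // placeholder input directory

val query = stream.writeStream
  .format("console")
  .start()
// query.stop() when finished
```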