-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala  |  12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala      | 168
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala    |  19
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala | 125
4 files changed, 33 insertions(+), 291 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index 12e19f955c..1bf57882ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -56,7 +56,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
val paths = files.filterNot(_.getPath.getName startsWith "_").map(_.getPath.toString)
val rdd = baseRdd(sparkSession, csvOptions, paths)
val firstLine = findFirstLine(csvOptions, rdd)
- val firstRow = new LineCsvReader(csvOptions).parseLine(firstLine)
+ val firstRow = new CsvReader(csvOptions).parseLine(firstLine)
val header = if (csvOptions.headerFlag) {
firstRow.zipWithIndex.map { case (value, index) =>
@@ -103,6 +103,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
options: Map[String, String],
hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
val csvOptions = new CSVOptions(options)
+ val commentPrefix = csvOptions.comment.toString
val headers = requiredSchema.fields.map(_.name)
val broadcastedHadoopConf =
@@ -118,7 +119,12 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
CSVRelation.dropHeaderLine(file, lineIterator, csvOptions)
- val tokenizedIterator = new BulkCsvReader(lineIterator, csvOptions, headers)
+ val csvParser = new CsvReader(csvOptions)
+ val tokenizedIterator = lineIterator.filter { line =>
+ line.trim.nonEmpty && !line.startsWith(commentPrefix)
+ }.map { line =>
+ csvParser.parseLine(line)
+ }
val parser = CSVRelation.csvParser(dataSchema, requiredSchema.fieldNames, csvOptions)
var numMalformedRecords = 0
tokenizedIterator.flatMap { recordTokens =>
@@ -146,7 +152,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
val rdd = baseRdd(sparkSession, options, inputPaths)
// Make sure firstLine is materialized before sending to executors
val firstLine = if (options.headerFlag) findFirstLine(options, rdd) else null
- CSVRelation.univocityTokenizer(rdd, header, firstLine, options)
+ CSVRelation.univocityTokenizer(rdd, firstLine, options)
}
/**
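The read path above drops the bulk Reader-based tokenizer in favor of straight per-line parsing: filter out blank and comment lines, then hand each surviving line to univocity's parseLine. A minimal standalone sketch of that shape, assuming univocity-parsers 2.x on the classpath (object and value names here are illustrative, not Spark's):

import com.univocity.parsers.csv.{CsvParser, CsvParserSettings}

object PerLineCsvSketch {
  def main(args: Array[String]): Unit = {
    val settings = new CsvParserSettings()
    settings.getFormat.setDelimiter(',')    // stands in for CSVOptions.delimiter
    val parser = new CsvParser(settings)

    val commentPrefix = "#"                 // stands in for csvOptions.comment.toString
    val lines = Iterator("a,b,c", "", "# skipped", "1,2,\"3,4\"")

    val tokenized = lines
      .filter(line => line.trim.nonEmpty && !line.startsWith(commentPrefix))
      .map(line => parser.parseLine(line))  // one Array[String] per record

    tokenized.foreach(tokens => println(tokens.mkString("|")))
    // prints: a|b|c
    //         1|2|3,4
  }
}

As in buildReader above, a single parser instance is built per task, so each record costs one parseLine call.
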
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
index 2103262580..bf62732dd4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
@@ -27,11 +27,10 @@ import org.apache.spark.internal.Logging
* Read and parse CSV-like input
*
* @param params Parameters object
- * @param headers headers for the columns
*/
-private[sql] abstract class CsvReader(params: CSVOptions, headers: Seq[String]) {
+private[sql] class CsvReader(params: CSVOptions) {
- protected lazy val parser: CsvParser = {
+ private val parser: CsvParser = {
val settings = new CsvParserSettings()
val format = settings.getFormat
format.setDelimiter(params.delimiter)
@@ -47,10 +46,17 @@ private[sql] abstract class CsvReader(params: CSVOptions, headers: Seq[String])
settings.setNullValue(params.nullValue)
settings.setMaxCharsPerColumn(params.maxCharsPerColumn)
settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER)
- if (headers != null) settings.setHeaders(headers: _*)
new CsvParser(settings)
}
+
+ /**
+ * parse a line
+ *
+ * @param line a String with no newline at the end
+ * @return array of strings where each string is a field in the CSV record
+ */
+ def parseLine(line: String): Array[String] = parser.parseLine(line)
}
/**
@@ -97,157 +103,3 @@ private[sql] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) exten
writer.close()
}
}
-
-/**
- * Parser for parsing a line at a time. Not efficient for bulk data.
- *
- * @param params Parameters object
- */
-private[sql] class LineCsvReader(params: CSVOptions)
- extends CsvReader(params, null) {
- /**
- * parse a line
- *
- * @param line a String with no newline at the end
- * @return array of strings where each string is a field in the CSV record
- */
- def parseLine(line: String): Array[String] = {
- parser.beginParsing(new StringReader(line))
- val parsed = parser.parseNext()
- parser.stopParsing()
- parsed
- }
-}
-
-/**
- * Parser for parsing lines in bulk. Use this when efficiency is desired.
- *
- * @param iter iterator over lines in the file
- * @param params Parameters object
- * @param headers headers for the columns
- */
-private[sql] class BulkCsvReader(
- iter: Iterator[String],
- params: CSVOptions,
- headers: Seq[String])
- extends CsvReader(params, headers) with Iterator[Array[String]] {
-
- private val reader = new StringIteratorReader(iter)
- parser.beginParsing(reader)
- private var nextRecord = parser.parseNext()
-
- /**
- * get the next parsed line.
- * @return array of strings where each string is a field in the CSV record
- */
- override def next(): Array[String] = {
- val curRecord = nextRecord
- if(curRecord != null) {
- nextRecord = parser.parseNext()
- } else {
- throw new NoSuchElementException("next record is null")
- }
- curRecord
- }
-
- override def hasNext: Boolean = nextRecord != null
-
-}
-
-/**
- * A Reader that "reads" from a sequence of lines. Spark's textFile method removes newlines at
- * the end of each line; the Univocity parser requires a Reader that provides access to the data
- * to be parsed and needs the newlines to be present.
- * @param iter iterator over RDD[String]
- */
-private class StringIteratorReader(val iter: Iterator[String]) extends java.io.Reader {
-
- private var next: Long = 0
- private var length: Long = 0 // length of input so far
- private var start: Long = 0
- private var str: String = null // current string from iter
-
- /**
- * fetch next string from iter, if done with current one
- * pretend there is a new line at the end of every string we get from iter
- */
- private def refill(): Unit = {
- if (length == next) {
- if (iter.hasNext) {
- str = iter.next()
- start = length
- length += (str.length + 1) // allowance for newline removed by SparkContext.textFile()
- } else {
- str = null
- }
- }
- }
-
- /**
- * read the next character, if at end of string pretend there is a new line
- */
- override def read(): Int = {
- refill()
- if (next >= length) {
- -1
- } else {
- val cur = next - start
- next += 1
- if (cur == str.length) '\n' else str.charAt(cur.toInt)
- }
- }
-
- /**
- * read from str into cbuf
- */
- override def read(cbuf: Array[Char], off: Int, len: Int): Int = {
- refill()
- var n = 0
- if ((off < 0) || (off > cbuf.length) || (len < 0) ||
- ((off + len) > cbuf.length) || ((off + len) < 0)) {
- throw new IndexOutOfBoundsException()
- } else if (len == 0) {
- n = 0
- } else {
- if (next >= length) { // end of input
- n = -1
- } else {
- n = Math.min(length - next, len).toInt // lesser of amount of input available or buf size
- if (n == length - next) {
- str.getChars((next - start).toInt, (next - start + n - 1).toInt, cbuf, off)
- cbuf(off + n - 1) = '\n'
- } else {
- str.getChars((next - start).toInt, (next - start + n).toInt, cbuf, off)
- }
- next += n
- if (n < len) {
- val m = read(cbuf, off + n, len - n) // have more space, fetch more input from iter
- if(m != -1) n += m
- }
- }
- }
-
- n
- }
-
- override def skip(ns: Long): Long = {
- throw new IllegalArgumentException("Skip not implemented")
- }
-
- override def ready: Boolean = {
- refill()
- true
- }
-
- override def markSupported: Boolean = false
-
- override def mark(readAheadLimit: Int): Unit = {
- throw new IllegalArgumentException("Mark not implemented")
- }
-
- override def reset(): Unit = {
- throw new IllegalArgumentException("Mark and hence reset not implemented")
- }
-
- override def close(): Unit = { }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index e8c0134d38..c6ba424d86 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -38,15 +38,24 @@ object CSVRelation extends Logging {
def univocityTokenizer(
file: RDD[String],
- header: Seq[String],
firstLine: String,
params: CSVOptions): RDD[Array[String]] = {
// If header is set, make sure firstLine is materialized before sending to executors.
+ val commentPrefix = params.comment.toString
file.mapPartitions { iter =>
- new BulkCsvReader(
- if (params.headerFlag) iter.filterNot(_ == firstLine) else iter,
- params,
- headers = header)
+ val parser = new CsvReader(params)
+ val filteredIter = iter.filter { line =>
+ line.trim.nonEmpty && !line.startsWith(commentPrefix)
+ }
+ if (params.headerFlag) {
+ filteredIter.filterNot(_ == firstLine).map { item =>
+ parser.parseLine(item)
+ }
+ } else {
+ filteredIter.map { item =>
+ parser.parseLine(item)
+ }
+ }
}
}
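The reshaped univocityTokenizer builds one CsvReader per partition and drops empty lines, comment lines, and, when headerFlag is set, any line equal to the materialized first line. The same logic, extracted into a plain-Scala sketch with a pluggable parseLine so it runs without Spark (tokenize and its parameters are illustrative, not part of the patch):

object TokenizerSketch {
  def tokenize(
      partition: Iterator[String],
      firstLine: String,
      headerFlag: Boolean,
      commentPrefix: String,
      parseLine: String => Array[String]): Iterator[Array[String]] = {
    val filtered = partition.filter { line =>
      line.trim.nonEmpty && !line.startsWith(commentPrefix)
    }
    val rows = if (headerFlag) filtered.filterNot(_ == firstLine) else filtered
    rows.map(parseLine)
  }

  def main(args: Array[String]): Unit = {
    val lines = Iterator("a,b", "# comment", "1,2", "", "3,4")
    tokenize(lines, "a,b", headerFlag = true, commentPrefix = "#",
      parseLine = line => line.split(",", -1))
      .foreach(tokens => println(tokens.mkString("|")))
    // prints: 1|2
    //         3|4
  }
}

Note that filterNot(_ == firstLine) removes every line matching the header text, not just the first line of each file; per-file header handling on the main read path is done separately by CSVRelation.dropHeaderLine.
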
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala
deleted file mode 100644
index aaeecef5f3..0000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala
+++ /dev/null
@@ -1,125 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.datasources.csv
-
-import org.apache.spark.SparkFunSuite
-
-/**
- * test cases for StringIteratorReader
- */
-class CSVParserSuite extends SparkFunSuite {
-
- private def readAll(iter: Iterator[String]) = {
- val reader = new StringIteratorReader(iter)
- var c: Int = -1
- val read = new scala.collection.mutable.StringBuilder()
- do {
- c = reader.read()
- read.append(c.toChar)
- } while (c != -1)
-
- read.dropRight(1).toString
- }
-
- private def readBufAll(iter: Iterator[String], bufSize: Int) = {
- val reader = new StringIteratorReader(iter)
- val cbuf = new Array[Char](bufSize)
- val read = new scala.collection.mutable.StringBuilder()
-
- var done = false
- do { // read all input one cbuf at a time
- var numRead = 0
- var n = 0
- do { // try to fill cbuf
- var off = 0
- var len = cbuf.length
- n = reader.read(cbuf, off, len)
-
- if (n != -1) {
- off += n
- len -= n
- }
-
- assert(len >= 0 && len <= cbuf.length)
- assert(off >= 0 && off <= cbuf.length)
- read.appendAll(cbuf.take(n))
- } while (n > 0)
- if(n != -1) {
- numRead += n
- } else {
- done = true
- }
- } while (!done)
-
- read.toString
- }
-
- test("Hygiene") {
- val reader = new StringIteratorReader(List("").toIterator)
- assert(reader.ready === true)
- assert(reader.markSupported === false)
- intercept[IllegalArgumentException] { reader.skip(1) }
- intercept[IllegalArgumentException] { reader.mark(1) }
- intercept[IllegalArgumentException] { reader.reset() }
- }
-
- test("Regular case") {
- val input = List("This is a string", "This is another string", "Small", "", "\"quoted\"")
- val read = readAll(input.toIterator)
- assert(read === input.mkString("\n") ++ "\n")
- }
-
- test("Empty iter") {
- val input = List[String]()
- val read = readAll(input.toIterator)
- assert(read === "")
- }
-
- test("Embedded new line") {
- val input = List("This is a string", "This is another string", "Small\n", "", "\"quoted\"")
- val read = readAll(input.toIterator)
- assert(read === input.mkString("\n") ++ "\n")
- }
-
- test("Buffer Regular case") {
- val input = List("This is a string", "This is another string", "Small", "", "\"quoted\"")
- val output = input.mkString("\n") ++ "\n"
- for(i <- 1 to output.length + 5) {
- val read = readBufAll(input.toIterator, i)
- assert(read === output)
- }
- }
-
- test("Buffer Empty iter") {
- val input = List[String]()
- val output = ""
- for(i <- 1 to output.length + 5) {
-      val read = readBufAll(input.toIterator, i)
- assert(read === "")
- }
- }
-
- test("Buffer Embedded new line") {
- val input = List("This is a string", "This is another string", "Small\n", "", "\"quoted\"")
- val output = input.mkString("\n") ++ "\n"
- for(i <- 1 to output.length + 5) {
-      val read = readBufAll(input.toIterator, i)
- assert(read === output)
- }
- }
-}
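
With StringIteratorReader gone, the whole suite goes with it; nothing else exercised the class. A hedged sketch of the kind of check that would cover the replacement path instead (not a test from the Spark tree; assumes univocity-parsers on the classpath with its default ',' delimiter and '"' quote):

import com.univocity.parsers.csv.{CsvParser, CsvParserSettings}

object ParseLineCheck {
  def main(args: Array[String]): Unit = {
    val parser = new CsvParser(new CsvParserSettings())
    // A quoted delimiter must survive as part of a single field.
    val parsed = parser.parseLine("a,b,\"c,d\"")
    assert(parsed.sameElements(Array("a", "b", "c,d")), parsed.mkString("|"))
    println("ok")
  }
}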