author    hyukjinkwon <gurwls223@gmail.com>        2017-02-07 21:02:20 +0800
committer Wenchen Fan <wenchen@databricks.com>     2017-02-07 21:02:20 +0800
commit    3d314d08c9420e74b4bb687603cdd11394eccab5 (patch)
tree      3630a9fdc2b44145102f46bb652d0b0d5c802c29
parent    8fd178d2151da53c0edc7ed3a92ebd01780d7702 (diff)
[SPARK-16101][SQL] Refactoring CSV schema inference path to be consistent with JSON
## What changes were proposed in this pull request?

This PR refactors the CSV schema inference path to be consistent with the JSON data source and moves some filtering code with similar or identical logic into `CSVUtils`. It makes the methods in these classes take arguments consistent with the JSON ones. (This PR also renames `.../json/InferSchema.scala` to `.../json/JsonInferSchema.scala`.)

`CSVInferSchema` and `JsonInferSchema`:

```scala
private[csv] object CSVInferSchema {
  ...

  def infer(
      csv: Dataset[String],
      caseSensitive: Boolean,
      options: CSVOptions): StructType = {
  ...
```

```scala
private[sql] object JsonInferSchema {
  ...

  def infer(
      json: RDD[String],
      columnNameOfCorruptRecord: String,
      configOptions: JSONOptions): StructType = {
  ...
```

These allow schema inference from a `Dataset[String]` directly, meaning the functionality that uses `JacksonParser`/`JsonInferSchema` for JSON can be implemented just as easily with `UnivocityParser`/`CSVInferSchema` for CSV. This completes the refactoring of the CSV data source, and the two are now largely consistent.

## How was this patch tested?

Existing tests should cover this, along with a manual build against Scala 2.10:

```
./dev/change-scala-version.sh 2.10
./build/mvn -Pyarn -Phadoop-2.4 -Dscala-2.10 -DskipTests clean package
```

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #16680 from HyukjinKwon/SPARK-16101-schema-inference.
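To make the new entry point concrete, here is a minimal sketch (not part of the patch) of calling the refactored CSV inference directly. It assumes it compiles inside the `org.apache.spark.sql.execution.datasources.csv` package, since `CSVInferSchema` and `CSVOptions` are `private[csv]`; the input path and option values are made up.

```scala
package org.apache.spark.sql.execution.datasources.csv

import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.types.StructType

// Illustrative only: exercises the Dataset[String]-based infer() added by this PR.
object CSVInferSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("csv-infer-sketch").getOrCreate()

    // Raw CSV lines, e.g. "name,age" / "alice,30" / "bob,25" (hypothetical input file).
    val lines: Dataset[String] = spark.read.textFile("/tmp/people.csv")

    // The option map and caseSensitive flag mirror what CSVFileFormat.inferSchema now passes in.
    val options = new CSVOptions(Map("header" -> "true", "inferSchema" -> "true"))
    val schema: StructType = CSVInferSchema.infer(lines, caseSensitive = false, options)
    schema.printTreeString()

    spark.stop()
  }
}
```

On the JSON side the analogous entry point is `JsonInferSchema.infer(json, columnNameOfCorruptRecord, configOptions)`, which is what keeps the two paths symmetric.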
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala | 112
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala | 115
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala | 69
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala | 134
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala (renamed from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala) | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtilsSuite.scala | 47
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala | 24
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala | 6
11 files changed, 271 insertions, 246 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index a787d5a9a9..1830839aee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.jdbc._
-import org.apache.spark.sql.execution.datasources.json.InferSchema
+import org.apache.spark.sql.execution.datasources.json.JsonInferSchema
import org.apache.spark.sql.types.StructType
/**
@@ -334,7 +334,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
parsedOptions.columnNameOfCorruptRecord
.getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord)
val schema = userSpecifiedSchema.getOrElse {
- InferSchema.infer(
+ JsonInferSchema.infer(
jsonRDD,
columnNameOfCorruptRecord,
parsedOptions)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index 38970160d5..1d2bf07047 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.datasources.csv
import java.nio.charset.{Charset, StandardCharsets}
-import com.univocity.parsers.csv.CsvParser
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.io.{LongWritable, Text}
@@ -28,13 +27,11 @@ import org.apache.hadoop.mapreduce._
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
-import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.CompressionCodecs
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
-import org.apache.spark.sql.functions.{length, trim}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.SerializableConfiguration
@@ -60,64 +57,9 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
val csvOptions = new CSVOptions(options)
val paths = files.map(_.getPath.toString)
- val lines: Dataset[String] = readText(sparkSession, csvOptions, paths)
- val firstLine: String = findFirstLine(csvOptions, lines)
- val firstRow = new CsvParser(csvOptions.asParserSettings).parseLine(firstLine)
+ val lines: Dataset[String] = createBaseDataset(sparkSession, csvOptions, paths)
val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
- val header = makeSafeHeader(firstRow, csvOptions, caseSensitive)
-
- val parsedRdd: RDD[Array[String]] = CSVRelation.univocityTokenizer(
- lines,
- firstLine = if (csvOptions.headerFlag) firstLine else null,
- params = csvOptions)
- val schema = if (csvOptions.inferSchemaFlag) {
- CSVInferSchema.infer(parsedRdd, header, csvOptions)
- } else {
- // By default fields are assumed to be StringType
- val schemaFields = header.map { fieldName =>
- StructField(fieldName, StringType, nullable = true)
- }
- StructType(schemaFields)
- }
- Some(schema)
- }
-
- /**
- * Generates a header from the given row which is null-safe and duplicate-safe.
- */
- private def makeSafeHeader(
- row: Array[String],
- options: CSVOptions,
- caseSensitive: Boolean): Array[String] = {
- if (options.headerFlag) {
- val duplicates = {
- val headerNames = row.filter(_ != null)
- .map(name => if (caseSensitive) name else name.toLowerCase)
- headerNames.diff(headerNames.distinct).distinct
- }
-
- row.zipWithIndex.map { case (value, index) =>
- if (value == null || value.isEmpty || value == options.nullValue) {
- // When there are empty strings or the values set in `nullValue`, put the
- // index as the suffix.
- s"_c$index"
- } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) {
- // When there are case-insensitive duplicates, put the index as the suffix.
- s"$value$index"
- } else if (duplicates.contains(value)) {
- // When there are duplicates, put the index as the suffix.
- s"$value$index"
- } else {
- value
- }
- }
- } else {
- row.zipWithIndex.map { case (_, index) =>
- // Uses default column names, "_c#" where # is its position of fields
- // when header option is disabled.
- s"_c$index"
- }
- }
+ Some(CSVInferSchema.infer(lines, caseSensitive, csvOptions))
}
override def prepareWrite(
@@ -125,7 +67,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
- verifySchema(dataSchema)
+ CSVUtils.verifySchema(dataSchema)
val conf = job.getConfiguration
val csvOptions = new CSVOptions(options)
csvOptions.compressionCodec.foreach { codec =>
@@ -155,13 +97,12 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
options: Map[String, String],
hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
val csvOptions = new CSVOptions(options)
- val commentPrefix = csvOptions.comment.toString
val broadcastedHadoopConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
(file: PartitionedFile) => {
- val lineIterator = {
+ val lines = {
val conf = broadcastedHadoopConf.value.value
val linesReader = new HadoopFileLinesReader(file, conf)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
@@ -170,32 +111,21 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
}
}
- // Consumes the header in the iterator.
- CSVRelation.dropHeaderLine(file, lineIterator, csvOptions)
-
- val filteredIter = lineIterator.filter { line =>
- line.trim.nonEmpty && !line.startsWith(commentPrefix)
+ val linesWithoutHeader = if (csvOptions.headerFlag && file.start == 0) {
+ // Note that if there are only comments in the first block, the header would probably
+ // not be dropped.
+ CSVUtils.dropHeaderLine(lines, csvOptions)
+ } else {
+ lines
}
+ val filteredLines = CSVUtils.filterCommentAndEmpty(linesWithoutHeader, csvOptions)
val parser = new UnivocityParser(dataSchema, requiredSchema, csvOptions)
- filteredIter.flatMap(parser.parse)
- }
- }
-
- /**
- * Returns the first line of the first non-empty file in path
- */
- private def findFirstLine(options: CSVOptions, lines: Dataset[String]): String = {
- import lines.sqlContext.implicits._
- val nonEmptyLines = lines.filter(length(trim($"value")) > 0)
- if (options.isCommentSet) {
- nonEmptyLines.filter(!$"value".startsWith(options.comment.toString)).first()
- } else {
- nonEmptyLines.first()
+ filteredLines.flatMap(parser.parse)
}
}
- private def readText(
+ private def createBaseDataset(
sparkSession: SparkSession,
options: CSVOptions,
inputPaths: Seq[String]): Dataset[String] = {
@@ -215,22 +145,6 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
sparkSession.createDataset(rdd)(Encoders.STRING)
}
}
-
- private def verifySchema(schema: StructType): Unit = {
- def verifyType(dataType: DataType): Unit = dataType match {
- case ByteType | ShortType | IntegerType | LongType | FloatType |
- DoubleType | BooleanType | _: DecimalType | TimestampType |
- DateType | StringType =>
-
- case udt: UserDefinedType[_] => verifyType(udt.sqlType)
-
- case _ =>
- throw new UnsupportedOperationException(
- s"CSV data source does not support ${dataType.simpleString} data type.")
- }
-
- schema.foreach(field => verifyType(field.dataType))
- }
}
private[csv] class CsvOutputWriter(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index 065bf53574..485b186c7c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -18,17 +18,15 @@
package org.apache.spark.sql.execution.datasources.csv
import java.math.BigDecimal
-import java.text.NumberFormat
-import java.util.Locale
import scala.util.control.Exception._
-import scala.util.Try
-import org.apache.spark.rdd.RDD
+import com.univocity.parsers.csv.CsvParser
+
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
private[csv] object CSVInferSchema {
@@ -39,22 +37,76 @@ private[csv] object CSVInferSchema {
* 3. Replace any null types with string type
*/
def infer(
- tokenRdd: RDD[Array[String]],
- header: Array[String],
+ csv: Dataset[String],
+ caseSensitive: Boolean,
options: CSVOptions): StructType = {
- val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
- val rootTypes: Array[DataType] =
- tokenRdd.aggregate(startType)(inferRowType(options), mergeRowTypes)
-
- val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
- val dType = rootType match {
- case _: NullType => StringType
- case other => other
+ val firstLine: String = CSVUtils.filterCommentAndEmpty(csv, options).first()
+ val firstRow = new CsvParser(options.asParserSettings).parseLine(firstLine)
+ val header = makeSafeHeader(firstRow, caseSensitive, options)
+
+ val fields = if (options.inferSchemaFlag) {
+ val tokenRdd = csv.rdd.mapPartitions { iter =>
+ val filteredLines = CSVUtils.filterCommentAndEmpty(iter, options)
+ val linesWithoutHeader = CSVUtils.filterHeaderLine(filteredLines, firstLine, options)
+ val parser = new CsvParser(options.asParserSettings)
+ linesWithoutHeader.map(parser.parseLine)
+ }
+
+ val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
+ val rootTypes: Array[DataType] =
+ tokenRdd.aggregate(startType)(inferRowType(options), mergeRowTypes)
+
+ header.zip(rootTypes).map { case (thisHeader, rootType) =>
+ val dType = rootType match {
+ case _: NullType => StringType
+ case other => other
+ }
+ StructField(thisHeader, dType, nullable = true)
}
- StructField(thisHeader, dType, nullable = true)
+ } else {
+ // By default fields are assumed to be StringType
+ header.map(fieldName => StructField(fieldName, StringType, nullable = true))
}
- StructType(structFields)
+ StructType(fields)
+ }
+
+ /**
+ * Generates a header from the given row which is null-safe and duplicate-safe.
+ */
+ private def makeSafeHeader(
+ row: Array[String],
+ caseSensitive: Boolean,
+ options: CSVOptions): Array[String] = {
+ if (options.headerFlag) {
+ val duplicates = {
+ val headerNames = row.filter(_ != null)
+ .map(name => if (caseSensitive) name else name.toLowerCase)
+ headerNames.diff(headerNames.distinct).distinct
+ }
+
+ row.zipWithIndex.map { case (value, index) =>
+ if (value == null || value.isEmpty || value == options.nullValue) {
+ // When there are empty strings or the values set in `nullValue`, put the
+ // index as the suffix.
+ s"_c$index"
+ } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) {
+ // When there are case-insensitive duplicates, put the index as the suffix.
+ s"$value$index"
+ } else if (duplicates.contains(value)) {
+ // When there are duplicates, put the index as the suffix.
+ s"$value$index"
+ } else {
+ value
+ }
+ }
+ } else {
+ row.zipWithIndex.map { case (_, index) =>
+ // Uses default column names, "_c#" where # is the field's position,
+ // when the header option is disabled.
+ s"_c$index"
+ }
+ }
}
private def inferRowType(options: CSVOptions)
@@ -215,32 +267,3 @@ private[csv] object CSVInferSchema {
case _ => None
}
}
-
-private[csv] object CSVTypeCast {
- /**
- * Helper method that converts string representation of a character to actual character.
- * It handles some Java escaped strings and throws exception if given string is longer than one
- * character.
- */
- @throws[IllegalArgumentException]
- def toChar(str: String): Char = {
- if (str.charAt(0) == '\\') {
- str.charAt(1)
- match {
- case 't' => '\t'
- case 'r' => '\r'
- case 'b' => '\b'
- case 'f' => '\f'
- case '\"' => '\"' // In case user changes quote char and uses \" as delimiter in options
- case '\'' => '\''
- case 'u' if str == """\u0000""" => '\u0000'
- case _ =>
- throw new IllegalArgumentException(s"Unsupported special character for delimiter: $str")
- }
- } else if (str.length == 1) {
- str.charAt(0)
- } else {
- throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str")
- }
- }
-}
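As a side note, here is a self-contained sketch (not part of the patch) of the duplicate/null handling that `makeSafeHeader` above applies: null or empty names become `_c<index>`, and duplicate names get the index appended. It omits the `options.nullValue` check of the real method, and the sample row is invented.

```scala
// Standalone re-implementation of the makeSafeHeader rules, for illustration only.
def safeHeader(row: Array[String], caseSensitive: Boolean): Array[String] = {
  val names = row.filter(_ != null).map(n => if (caseSensitive) n else n.toLowerCase)
  val duplicates = names.diff(names.distinct).distinct
  row.zipWithIndex.map {
    case (value, index) if value == null || value.isEmpty => s"_c$index"
    case (value, index) if !caseSensitive && duplicates.contains(value.toLowerCase) => s"$value$index"
    case (value, index) if duplicates.contains(value) => s"$value$index"
    case (value, _) => value
  }
}

// safeHeader(Array("id", "name", "NAME", null), caseSensitive = false)
//   => Array("id", "name1", "NAME2", "_c3")
```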
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 140ce23958..af456c8d71 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -69,7 +69,7 @@ private[csv] class CSVOptions(@transient private val parameters: CaseInsensitive
}
}
- val delimiter = CSVTypeCast.toChar(
+ val delimiter = CSVUtils.toChar(
parameters.getOrElse("sep", parameters.getOrElse("delimiter", ",")))
private val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
val charset = parameters.getOrElse("encoding",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
deleted file mode 100644
index 19058c23ab..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.datasources.csv
-
-import com.univocity.parsers.csv.CsvParser
-
-import org.apache.spark.internal.Logging
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql._
-import org.apache.spark.sql.execution.datasources.PartitionedFile
-
-object CSVRelation extends Logging {
-
- def univocityTokenizer(
- file: Dataset[String],
- firstLine: String,
- params: CSVOptions): RDD[Array[String]] = {
- // If header is set, make sure firstLine is materialized before sending to executors.
- val commentPrefix = params.comment.toString
- file.rdd.mapPartitions { iter =>
- val parser = new CsvParser(params.asParserSettings)
- val filteredIter = iter.filter { line =>
- line.trim.nonEmpty && !line.startsWith(commentPrefix)
- }
- if (params.headerFlag) {
- filteredIter.filterNot(_ == firstLine).map { item =>
- parser.parseLine(item)
- }
- } else {
- filteredIter.map { item =>
- parser.parseLine(item)
- }
- }
- }
- }
-
- // Skips the header line of each file if the `header` option is set to true.
- def dropHeaderLine(
- file: PartitionedFile, lines: Iterator[String], csvOptions: CSVOptions): Unit = {
- // TODO What if the first partitioned file consists of only comments and empty lines?
- if (csvOptions.headerFlag && file.start == 0) {
- val nonEmptyLines = if (csvOptions.isCommentSet) {
- val commentPrefix = csvOptions.comment.toString
- lines.dropWhile { line =>
- line.trim.isEmpty || line.trim.startsWith(commentPrefix)
- }
- } else {
- lines.dropWhile(_.trim.isEmpty)
- }
-
- if (nonEmptyLines.hasNext) nonEmptyLines.drop(1)
- }
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
new file mode 100644
index 0000000000..72b053d209
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+
+object CSVUtils {
+ /**
+ * Filters ignorable rows for a CSV dataset (empty lines and lines starting with `comment`).
+ * This is currently used in CSV schema inference.
+ */
+ def filterCommentAndEmpty(lines: Dataset[String], options: CSVOptions): Dataset[String] = {
+ // Note that this was added separately in SPARK-18362. Logically, this should be the same
+ // as the iterator-based `filterCommentAndEmpty` below, but the execution path is different.
+ // One of them might have to be removed in the near future if possible.
+ import lines.sqlContext.implicits._
+ val nonEmptyLines = lines.filter(length(trim($"value")) > 0)
+ if (options.isCommentSet) {
+ nonEmptyLines.filter(!$"value".startsWith(options.comment.toString))
+ } else {
+ nonEmptyLines
+ }
+ }
+
+ /**
+ * Filters ignorable rows for a CSV iterator (empty lines and lines starting with `comment`).
+ * This is currently used in the CSV reading path and CSV schema inference.
+ */
+ def filterCommentAndEmpty(iter: Iterator[String], options: CSVOptions): Iterator[String] = {
+ iter.filter { line =>
+ line.trim.nonEmpty && !line.startsWith(options.comment.toString)
+ }
+ }
+
+ /**
+ * Skips the given first line so that only data remains.
+ * This is similar to `dropHeaderLine` below and is currently used in CSV schema inference.
+ */
+ def filterHeaderLine(
+ iter: Iterator[String],
+ firstLine: String,
+ options: CSVOptions): Iterator[String] = {
+ // Note that unlike the actual CSV reading path, this simply filters out the given first line.
+ // Therefore, it skips any line identical to the header, if one exists. One of them might have
+ // to be removed in the near future if possible.
+ if (options.headerFlag) {
+ iter.filterNot(_ == firstLine)
+ } else {
+ iter
+ }
+ }
+
+ /**
+ * Drops the header line so that only data remains.
+ * This is similar to `filterHeaderLine` above and is currently used in the CSV reading path.
+ */
+ def dropHeaderLine(iter: Iterator[String], options: CSVOptions): Iterator[String] = {
+ val nonEmptyLines = if (options.isCommentSet) {
+ val commentPrefix = options.comment.toString
+ iter.dropWhile { line =>
+ line.trim.isEmpty || line.trim.startsWith(commentPrefix)
+ }
+ } else {
+ iter.dropWhile(_.trim.isEmpty)
+ }
+
+ if (nonEmptyLines.hasNext) nonEmptyLines.drop(1)
+ iter
+ }
+
+ /**
+ * Helper method that converts string representation of a character to actual character.
+ * It handles some Java escaped strings and throws exception if given string is longer than one
+ * character.
+ */
+ @throws[IllegalArgumentException]
+ def toChar(str: String): Char = {
+ if (str.charAt(0) == '\\') {
+ str.charAt(1)
+ match {
+ case 't' => '\t'
+ case 'r' => '\r'
+ case 'b' => '\b'
+ case 'f' => '\f'
+ case '\"' => '\"' // In case user changes quote char and uses \" as delimiter in options
+ case '\'' => '\''
+ case 'u' if str == """\u0000""" => '\u0000'
+ case _ =>
+ throw new IllegalArgumentException(s"Unsupported special character for delimiter: $str")
+ }
+ } else if (str.length == 1) {
+ str.charAt(0)
+ } else {
+ throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str")
+ }
+ }
+
+ /**
+ * Verify if the schema is supported in CSV datasource.
+ */
+ def verifySchema(schema: StructType): Unit = {
+ def verifyType(dataType: DataType): Unit = dataType match {
+ case ByteType | ShortType | IntegerType | LongType | FloatType |
+ DoubleType | BooleanType | _: DecimalType | TimestampType |
+ DateType | StringType =>
+
+ case udt: UserDefinedType[_] => verifyType(udt.sqlType)
+
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"CSV data source does not support ${dataType.simpleString} data type.")
+ }
+
+ schema.foreach(field => verifyType(field.dataType))
+ }
+
+}
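For a sense of how the iterator-based helpers above compose in the read path, a small hypothetical walk-through follows. It mirrors the order used in `CSVFileFormat.readFile` for the first block of a file, assumes it runs inside the `org.apache.spark.sql.execution.datasources.csv` package (because `CSVOptions` is `private[csv]`), and the option values and sample lines are made up.

```scala
// Drop leading comments/blank lines plus the header, then filter the remaining lines.
val opts = new CSVOptions(Map("header" -> "true", "comment" -> "#"))
val raw = Iterator("# generated file", "", "name,age", "alice,30", "", "bob,25")
val data = CSVUtils.filterCommentAndEmpty(CSVUtils.dropHeaderLine(raw, opts), opts)
data.foreach(println)  // prints "alice,30" then "bob,25"
```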
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index be1f94dbad..98ab9d2850 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -51,7 +51,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
val columnNameOfCorruptRecord =
parsedOptions.columnNameOfCorruptRecord
.getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord)
- val jsonSchema = InferSchema.infer(
+ val jsonSchema = JsonInferSchema.infer(
createBaseRdd(sparkSession, files),
columnNameOfCorruptRecord,
parsedOptions)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
index 330d04de66..f51c18d46f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.json.JSONOptions
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
-private[sql] object InferSchema {
+private[sql] object JsonInferSchema {
/**
* Infer the type of a collection of json records in three stages:
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtilsSuite.scala
new file mode 100644
index 0000000000..221e44ce2c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtilsSuite.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import org.apache.spark.SparkFunSuite
+
+class CSVUtilsSuite extends SparkFunSuite {
+ test("Can parse escaped characters") {
+ assert(CSVUtils.toChar("""\t""") === '\t')
+ assert(CSVUtils.toChar("""\r""") === '\r')
+ assert(CSVUtils.toChar("""\b""") === '\b')
+ assert(CSVUtils.toChar("""\f""") === '\f')
+ assert(CSVUtils.toChar("""\"""") === '\"')
+ assert(CSVUtils.toChar("""\'""") === '\'')
+ assert(CSVUtils.toChar("""\u0000""") === '\u0000')
+ }
+
+ test("Does not accept delimiter larger than one character") {
+ val exception = intercept[IllegalArgumentException]{
+ CSVUtils.toChar("ab")
+ }
+ assert(exception.getMessage.contains("cannot be more than one character"))
+ }
+
+ test("Throws exception for unsupported escaped characters") {
+ val exception = intercept[IllegalArgumentException]{
+ CSVUtils.toChar("""\1""")
+ }
+ assert(exception.getMessage.contains("Unsupported special character for delimiter"))
+ }
+
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala
index 2ca6308852..62dae08861 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala
@@ -43,30 +43,6 @@ class UnivocityParserSuite extends SparkFunSuite {
}
}
- test("Can parse escaped characters") {
- assert(CSVTypeCast.toChar("""\t""") === '\t')
- assert(CSVTypeCast.toChar("""\r""") === '\r')
- assert(CSVTypeCast.toChar("""\b""") === '\b')
- assert(CSVTypeCast.toChar("""\f""") === '\f')
- assert(CSVTypeCast.toChar("""\"""") === '\"')
- assert(CSVTypeCast.toChar("""\'""") === '\'')
- assert(CSVTypeCast.toChar("""\u0000""") === '\u0000')
- }
-
- test("Does not accept delimiter larger than one character") {
- val exception = intercept[IllegalArgumentException]{
- CSVTypeCast.toChar("ab")
- }
- assert(exception.getMessage.contains("cannot be more than one character"))
- }
-
- test("Throws exception for unsupported escaped characters") {
- val exception = intercept[IllegalArgumentException]{
- CSVTypeCast.toChar("""\1""")
- }
- assert(exception.getMessage.contains("Unsupported special character for delimiter"))
- }
-
test("Nullable types are handled") {
val types = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType,
BooleanType, DecimalType.DoubleDecimal, TimestampType, DateType, StringType)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 161a409d83..156fd965b4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.DataSource
-import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType
+import org.apache.spark.sql.execution.datasources.json.JsonInferSchema.compatibleType
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -1366,7 +1366,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
test("SPARK-6245 JsonRDD.inferSchema on empty RDD") {
// This is really a test that it doesn't throw an exception
- val emptySchema = InferSchema.infer(empty, "", new JSONOptions(Map.empty[String, String]))
+ val emptySchema = JsonInferSchema.infer(empty, "", new JSONOptions(Map.empty[String, String]))
assert(StructType(Seq()) === emptySchema)
}
@@ -1390,7 +1390,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("SPARK-8093 Erase empty structs") {
- val emptySchema = InferSchema.infer(
+ val emptySchema = JsonInferSchema.infer(
emptyRecords, "", new JSONOptions(Map.empty[String, String]))
assert(StructType(Seq()) === emptySchema)
}