diff options
author | hyukjinkwon <gurwls223@gmail.com> | 2016-10-11 10:21:22 +0800 |
---|---|---|
committer | Wenchen Fan <wenchen@databricks.com> | 2016-10-11 10:21:22 +0800 |
commit | 90217f9deed01ae187e28ef1531491aac8ee50c9 (patch) | |
tree | d9a689ce213536b8e2bfe2aa0ef3cfb483a2261e /sql/core | |
parent | d5ec4a3e014494a3d991a6350caffbc3b17be0fd (diff) | |
download | spark-90217f9deed01ae187e28ef1531491aac8ee50c9.tar.gz spark-90217f9deed01ae187e28ef1531491aac8ee50c9.tar.bz2 spark-90217f9deed01ae187e28ef1531491aac8ee50c9.zip |
[SPARK-16896][SQL] Handle duplicated field names in header consistently with null or empty strings in CSV
## What changes were proposed in this pull request?
Currently, CSV datasource allows to load duplicated empty string fields or fields having `nullValue` in the header. It'd be great if this can deal with normal fields as well.
This PR proposes handling the duplicates consistently with the existing behaviour with considering case-sensitivity (`spark.sql.caseSensitive`) as below:
data below:
```
fieldA,fieldB,,FIELDA,fielda,,
1,2,3,4,5,6,7
```
is parsed as below:
```scala
spark.read.format("csv").option("header", "true").load("test.csv").show()
```
- when `spark.sql.caseSensitive` is `false` (by default).
```
+-------+------+---+-------+-------+---+---+
|fieldA0|fieldB|_c2|FIELDA3|fieldA4|_c5|_c6|
+-------+------+---+-------+-------+---+---+
| 1| 2| 3| 4| 5| 6| 7|
+-------+------+---+-------+-------+---+---+
```
- when `spark.sql.caseSensitive` is `true`.
```
+-------+------+---+-------+-------+---+---+
|fieldA0|fieldB|_c2| FIELDA|fieldA4|_c5|_c6|
+-------+------+---+-------+-------+---+---+
| 1| 2| 3| 4| 5| 6| 7|
+-------+------+---+-------+-------+---+---+
```
**In more details**,
There is a good reference about this problem, `read.csv()` in R. So, I initially wanted to propose the similar behaviour.
In case of R, the CSV data below:
```
fieldA,fieldB,,fieldA,fieldA,,
1,2,3,4,5,6,7
```
is parsed as below:
```r
test <- read.csv(file="test.csv",header=TRUE,sep=",")
> test
fieldA fieldB X fieldA.1 fieldA.2 X.1 X.2
1 1 2 3 4 5 6 7
```
However, Spark CSV datasource already is handling duplicated empty strings and `nullValue` as field names. So the data below:
```
,,,fieldA,,fieldB,
1,2,3,4,5,6,7
```
is parsed as below:
```scala
spark.read.format("csv").option("header", "true").load("test.csv").show()
```
```
+---+---+---+------+---+------+---+
|_c0|_c1|_c2|fieldA|_c4|fieldB|_c6|
+---+---+---+------+---+------+---+
| 1| 2| 3| 4| 5| 6| 7|
+---+---+---+------+---+------+---+
```
R starts the number for each duplicate but Spark adds the number for its position for all fields for `nullValue` and empty strings.
In terms of case-sensitivity, it seems R is case-sensitive as below: (it seems it is not configurable).
```
a,a,a,A,A
1,2,3,4,5
```
is parsed as below:
```r
test <- read.csv(file="test.csv",header=TRUE,sep=",")
> test
a a.1 a.2 A A.1
1 1 2 3 4 5
```
## How was this patch tested?
Unit test in `CSVSuite`.
Author: hyukjinkwon <gurwls223@gmail.com>
Closes #14745 from HyukjinKwon/SPARK-16896.
Diffstat (limited to 'sql/core')
3 files changed, 74 insertions, 11 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 4e662a52a7..a3691158ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -59,14 +59,8 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { val rdd = baseRdd(sparkSession, csvOptions, paths) val firstLine = findFirstLine(csvOptions, rdd) val firstRow = new CsvReader(csvOptions).parseLine(firstLine) - - val header = if (csvOptions.headerFlag) { - firstRow.zipWithIndex.map { case (value, index) => - if (value == null || value.isEmpty || value == csvOptions.nullValue) s"_c$index" else value - } - } else { - firstRow.zipWithIndex.map { case (value, index) => s"_c$index" } - } + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val header = makeSafeHeader(firstRow, csvOptions, caseSensitive) val parsedRdd = tokenRdd(sparkSession, csvOptions, header, paths) val schema = if (csvOptions.inferSchemaFlag) { @@ -74,13 +68,51 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { } else { // By default fields are assumed to be StringType val schemaFields = header.map { fieldName => - StructField(fieldName.toString, StringType, nullable = true) + StructField(fieldName, StringType, nullable = true) } StructType(schemaFields) } Some(schema) } + /** + * Generates a header from the given row which is null-safe and duplicate-safe. + */ + private def makeSafeHeader( + row: Array[String], + options: CSVOptions, + caseSensitive: Boolean): Array[String] = { + if (options.headerFlag) { + val duplicates = { + val headerNames = row.filter(_ != null) + .map(name => if (caseSensitive) name else name.toLowerCase) + headerNames.diff(headerNames.distinct).distinct + } + + row.zipWithIndex.map { case (value, index) => + if (value == null || value.isEmpty || value == options.nullValue) { + // When there are empty strings or the values set in `nullValue`, put the + // index as the suffix. + s"_c$index" + } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) { + // When there are case-insensitive duplicates, put the index as the suffix. + s"$value$index" + } else if (duplicates.contains(value)) { + // When there are duplicates, put the index as the suffix. + s"$value$index" + } else { + value + } + } + } else { + row.zipWithIndex.map { case (_, index) => + // Uses default column names, "_c#" where # is its position of fields + // when header option is disabled. + s"_c$index" + } + } + } + override def prepareWrite( sparkSession: SparkSession, job: Job, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 29aac9def6..f7c22c6c93 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.SparkException import org.apache.spark.sql.{DataFrame, QueryTest, Row, UDT} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ @@ -856,4 +857,36 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { checkAnswer(stringTimestampsWithFormat, expectedStringTimestampsWithFormat) } } + + test("load duplicated field names consistently with null or empty strings - case sensitive") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + withTempPath { path => + Seq("a,a,c,A,b,B").toDF().write.text(path.getAbsolutePath) + val actualSchema = spark.read + .format("csv") + .option("header", true) + .load(path.getAbsolutePath) + .schema + val fields = Seq("a0", "a1", "c", "A", "b", "B").map(StructField(_, StringType, true)) + val expectedSchema = StructType(fields) + assert(actualSchema == expectedSchema) + } + } + } + + test("load duplicated field names consistently with null or empty strings - case insensitive") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + withTempPath { path => + Seq("a,A,c,A,b,B").toDF().write.text(path.getAbsolutePath) + val actualSchema = spark.read + .format("csv") + .option("header", true) + .load(path.getAbsolutePath) + .schema + val fields = Seq("a0", "A1", "c", "A3", "b4", "B5").map(StructField(_, StringType, true)) + val expectedSchema = StructType(fields) + assert(actualSchema == expectedSchema) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala index dae92f626c..51832a13cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala @@ -18,8 +18,6 @@ package org.apache.spark.sql.execution.datasources.csv import java.math.BigDecimal -import java.sql.{Date, Timestamp} -import java.text.SimpleDateFormat import java.util.Locale import org.apache.spark.SparkFunSuite |