author     hyukjinkwon <gurwls223@gmail.com>        2016-10-11 10:21:22 +0800
committer  Wenchen Fan <wenchen@databricks.com>     2016-10-11 10:21:22 +0800
commit     90217f9deed01ae187e28ef1531491aac8ee50c9 (patch)
tree       d9a689ce213536b8e2bfe2aa0ef3cfb483a2261e /sql/core
parent     d5ec4a3e014494a3d991a6350caffbc3b17be0fd (diff)
[SPARK-16896][SQL] Handle duplicated field names in header consistently with null or empty strings in CSV
## What changes were proposed in this pull request?

Currently, the CSV datasource allows loading duplicated empty-string fields, or fields equal to `nullValue`, in the header. It would be great if it could deal with normal field names as well. This PR proposes handling duplicates consistently with the existing behaviour, taking case sensitivity (`spark.sql.caseSensitive`) into account, as below.

The data below:

```
fieldA,fieldB,,FIELDA,fielda,,
1,2,3,4,5,6,7
```

is parsed as below:

```scala
spark.read.format("csv").option("header", "true").load("test.csv").show()
```

- when `spark.sql.caseSensitive` is `false` (the default):

```
+-------+------+---+-------+-------+---+---+
|fieldA0|fieldB|_c2|FIELDA3|fieldA4|_c5|_c6|
+-------+------+---+-------+-------+---+---+
|      1|     2|  3|      4|      5|  6|  7|
+-------+------+---+-------+-------+---+---+
```

- when `spark.sql.caseSensitive` is `true`:

```
+-------+------+---+-------+-------+---+---+
|fieldA0|fieldB|_c2| FIELDA|fieldA4|_c5|_c6|
+-------+------+---+-------+-------+---+---+
|      1|     2|  3|      4|      5|  6|  7|
+-------+------+---+-------+-------+---+---+
```

**In more detail**, there is a good reference for this problem: `read.csv()` in R. So, I initially wanted to propose similar behaviour. In R, the CSV data below:

```
fieldA,fieldB,,fieldA,fieldA,,
1,2,3,4,5,6,7
```

is parsed as below:

```r
test <- read.csv(file="test.csv", header=TRUE, sep=",")
> test
  fieldA fieldB X fieldA.1 fieldA.2 X.1 X.2
1      1      2 3        4        5   6   7
```

However, the Spark CSV datasource already handles duplicated empty strings and `nullValue` as field names. So the data below:

```
,,,fieldA,,fieldB,
1,2,3,4,5,6,7
```

is parsed as below:

```scala
spark.read.format("csv").option("header", "true").load("test.csv").show()
```

```
+---+---+---+------+---+------+---+
|_c0|_c1|_c2|fieldA|_c4|fieldB|_c6|
+---+---+---+------+---+------+---+
|  1|  2|  3|     4|  5|     6|  7|
+---+---+---+------+---+------+---+
```

R numbers each duplicated name starting from 1 per name, whereas Spark appends the field's position, and it currently does so only for `nullValue` and empty strings. In terms of case sensitivity, R appears to be case-sensitive, and this does not seem configurable:

```
a,a,a,A,A
1,2,3,4,5
```

is parsed as below:

```r
test <- read.csv(file="test.csv", header=TRUE, sep=",")
> test
  a a.1 a.2 A A.1
1 1   2   3 4   5
```

## How was this patch tested?

Unit tests in `CSVSuite`.

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #14745 from HyukjinKwon/SPARK-16896.
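To summarize the renaming rule in code, here is a minimal standalone sketch. It is an illustration only, not the patch itself; the actual change is the `makeSafeHeader` method added to `CSVFileFormat` in the diff below, and the object name, the explicit `nullValue` parameter, and the `main` driver here are assumptions made for the sketch:

```scala
// Standalone sketch of the renaming rule (illustrative only; see CSVFileFormat.makeSafeHeader below).
object SafeHeaderSketch {
  def makeSafeHeader(row: Array[String], nullValue: String, caseSensitive: Boolean): Array[String] = {
    // Names occurring more than once, compared case-insensitively unless caseSensitive is set.
    val names = row.filter(_ != null).map(n => if (caseSensitive) n else n.toLowerCase)
    val duplicates = names.diff(names.distinct).distinct

    row.zipWithIndex.map { case (value, index) =>
      if (value == null || value.isEmpty || value == nullValue) {
        // Empty or null-like header cells get the positional default name.
        s"_c$index"
      } else {
        val key = if (caseSensitive) value else value.toLowerCase
        // Duplicated names get their column index appended; unique names are kept as-is.
        if (duplicates.contains(key)) s"$value$index" else value
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // Mirrors the case-insensitive test added to CSVSuite below: prints a0,A1,c,A3,b4,B5
    println(makeSafeHeader(Array("a", "A", "c", "A", "b", "B"), nullValue = "", caseSensitive = false)
      .mkString(","))
  }
}
```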
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala    50
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala         33
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala   2
3 files changed, 74 insertions, 11 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index 4e662a52a7..a3691158ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -59,14 +59,8 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
val rdd = baseRdd(sparkSession, csvOptions, paths)
val firstLine = findFirstLine(csvOptions, rdd)
val firstRow = new CsvReader(csvOptions).parseLine(firstLine)
-
- val header = if (csvOptions.headerFlag) {
- firstRow.zipWithIndex.map { case (value, index) =>
- if (value == null || value.isEmpty || value == csvOptions.nullValue) s"_c$index" else value
- }
- } else {
- firstRow.zipWithIndex.map { case (value, index) => s"_c$index" }
- }
+ val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+ val header = makeSafeHeader(firstRow, csvOptions, caseSensitive)
val parsedRdd = tokenRdd(sparkSession, csvOptions, header, paths)
val schema = if (csvOptions.inferSchemaFlag) {
@@ -74,13 +68,51 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
} else {
// By default fields are assumed to be StringType
val schemaFields = header.map { fieldName =>
- StructField(fieldName.toString, StringType, nullable = true)
+ StructField(fieldName, StringType, nullable = true)
}
StructType(schemaFields)
}
Some(schema)
}
+ /**
+ * Generates a header from the given row which is null-safe and duplicate-safe.
+ */
+ private def makeSafeHeader(
+ row: Array[String],
+ options: CSVOptions,
+ caseSensitive: Boolean): Array[String] = {
+ if (options.headerFlag) {
+ val duplicates = {
+ val headerNames = row.filter(_ != null)
+ .map(name => if (caseSensitive) name else name.toLowerCase)
+ headerNames.diff(headerNames.distinct).distinct
+ }
+
+ row.zipWithIndex.map { case (value, index) =>
+ if (value == null || value.isEmpty || value == options.nullValue) {
+ // When there are empty strings or the values set in `nullValue`, put the
+ // index as the suffix.
+ s"_c$index"
+ } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) {
+ // When there are case-insensitive duplicates, put the index as the suffix.
+ s"$value$index"
+ } else if (duplicates.contains(value)) {
+ // When there are duplicates, put the index as the suffix.
+ s"$value$index"
+ } else {
+ value
+ }
+ }
+ } else {
+ row.zipWithIndex.map { case (_, index) =>
+ // Uses default column names, "_c#" where # is its position of fields
+ // when header option is disabled.
+ s"_c$index"
+ }
+ }
+ }
+
override def prepareWrite(
sparkSession: SparkSession,
job: Job,
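A short aside on the duplicate detection in `makeSafeHeader` above: `headerNames.diff(headerNames.distinct).distinct` relies on Scala's multiset difference, which removes one occurrence of each distinct name, so only names that appear more than once leave anything behind. A small illustrative snippet (the sample names are arbitrary, not from the patch):

```scala
// Multiset difference: `diff` removes one occurrence of every element of its argument,
// so only names that appeared at least twice in the header survive.
val headerNames = Seq("a", "b", "a", "c", "b", "a")
val duplicates = headerNames.diff(headerNames.distinct).distinct
// headerNames.distinct                   == Seq("a", "b", "c")
// headerNames.diff(headerNames.distinct) == Seq("a", "b", "a")   (the leftover occurrences)
// duplicates                             == Seq("a", "b")        (names that need an index suffix)
```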
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 29aac9def6..f7c22c6c93 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -28,6 +28,7 @@ import org.apache.hadoop.io.compress.GzipCodec
import org.apache.spark.SparkException
import org.apache.spark.sql.{DataFrame, QueryTest, Row, UDT}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
import org.apache.spark.sql.types._
@@ -856,4 +857,36 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
checkAnswer(stringTimestampsWithFormat, expectedStringTimestampsWithFormat)
}
}
+
+ test("load duplicated field names consistently with null or empty strings - case sensitive") {
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+ withTempPath { path =>
+ Seq("a,a,c,A,b,B").toDF().write.text(path.getAbsolutePath)
+ val actualSchema = spark.read
+ .format("csv")
+ .option("header", true)
+ .load(path.getAbsolutePath)
+ .schema
+ val fields = Seq("a0", "a1", "c", "A", "b", "B").map(StructField(_, StringType, true))
+ val expectedSchema = StructType(fields)
+ assert(actualSchema == expectedSchema)
+ }
+ }
+ }
+
+ test("load duplicated field names consistently with null or empty strings - case insensitive") {
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+ withTempPath { path =>
+ Seq("a,A,c,A,b,B").toDF().write.text(path.getAbsolutePath)
+ val actualSchema = spark.read
+ .format("csv")
+ .option("header", true)
+ .load(path.getAbsolutePath)
+ .schema
+ val fields = Seq("a0", "A1", "c", "A3", "b4", "B5").map(StructField(_, StringType, true))
+ val expectedSchema = StructType(fields)
+ assert(actualSchema == expectedSchema)
+ }
+ }
+ }
}
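For reference, the behaviour pinned down by these two tests can be reproduced from the user-facing reader API. The sketch below is illustrative only: the file paths and the local `SparkSession` setup are assumptions, not part of the patch.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// Case-insensitive (the default): a header line "a,A,c,A,b,B" yields columns
// a0, A1, c, A3, b4, B5, since a/A and b/B collide case-insensitively.
spark.conf.set("spark.sql.caseSensitive", "false")
spark.read.option("header", "true").csv("/tmp/insensitive.csv").printSchema()

// Case-sensitive: a header line "a,a,c,A,b,B" yields columns
// a0, a1, c, A, b, B, since only the exact duplicates are renamed.
spark.conf.set("spark.sql.caseSensitive", "true")
spark.read.option("header", "true").csv("/tmp/sensitive.csv").printSchema()
```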
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
index dae92f626c..51832a13cf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
@@ -18,8 +18,6 @@
package org.apache.spark.sql.execution.datasources.csv
import java.math.BigDecimal
-import java.sql.{Date, Timestamp}
-import java.text.SimpleDateFormat
import java.util.Locale
import org.apache.spark.SparkFunSuite