author    | Dongjoon Hyun <dongjoon@apache.org>      | 2016-11-16 17:12:18 +0800
committer | Wenchen Fan <wenchen@databricks.com>     | 2016-11-16 17:12:18 +0800
commit    | 74f5c2176d8449e41f520febd38109edaf3f4172 (patch)
tree      | 6c21dd0924f8c83ee4b12df9bb92a5e822f2f5c0 /sql/core
parent    | 95eb06bd7d0f7110ef62c8d1cb6337c72b10d99f (diff)
[SPARK-18433][SQL] Improve DataSource option keys to be more case-insensitive
## What changes were proposed in this pull request?
This PR aims to make `DataSource` option keys more case-insensitive.
`DataSource` currently uses `CaseInsensitiveMap` in only part of its code path, so option keys are matched case-sensitively elsewhere. For example, the following fails to find the `url` option:
```scala
val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
df.write.format("jdbc")
.option("UrL", url1)
.option("dbtable", "TEST.SAVETEST")
.options(properties.asScala)
.save()
```
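The failure comes down to an exact-match `Map` lookup: the JDBC path looks the option up under the key `url`, so a key spelled `UrL` is never found. Below is a minimal, self-contained sketch of that mismatch; the object name and the URL value are placeholders, and the lookup is simplified to a plain `Map`:
```scala
object CaseSensitivityDemo {
  def main(args: Array[String]): Unit = {
    // Options as a user might spell them.
    val options = Map("UrL" -> "jdbc:h2:mem:testdb0", "dbtable" -> "TEST.SAVETEST")

    // A case-sensitive lookup misses the option entirely...
    assert(options.get("url").isEmpty)

    // ...while lower-casing keys on the way in (what CaseInsensitiveMap does) finds it.
    val normalized = options.map { case (k, v) => (k.toLowerCase, v) }
    assert(normalized.get("url").contains("jdbc:h2:mem:testdb0"))
  }
}
```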
This PR makes `DataSource` options use `CaseInsensitiveMap` internally and passes that map through `DataSource` generally, except to `InMemoryFileIndex` and `InsertIntoHadoopFsRelationCommand`. We cannot pass a `CaseInsensitiveMap` to those two because they create new case-sensitive Hadoop configurations by calling `newHadoopConfWithOptions(options)` internally.
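For reference, the map doing the work is small. The sketch below mirrors the `CaseInsensitiveMap` implementation this patch removes from `execution/datasources/ddl.scala` (the class now lives in `org.apache.spark.sql.catalyst.util`); treat it as a simplified reference against the Scala 2.11-era `Map` API, not the exact source:
```scala
/**
 * Builds a map in which keys are case insensitive: keys are lower-cased on the
 * way in, and lookups lower-case the requested key, so "UrL", "url" and "URL"
 * all resolve to the same entry.
 */
class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String]
  with Serializable {

  // Normalize every key once at construction time.
  val baseMap: Map[String, String] = map.map(kv => kv.copy(_1 = kv._1.toLowerCase))

  override def get(k: String): Option[String] = baseMap.get(k.toLowerCase)

  override def +[B1 >: String](kv: (String, B1)): Map[String, B1] =
    baseMap + kv.copy(_1 = kv._1.toLowerCase)

  override def iterator: Iterator[(String, String)] = baseMap.iterator

  override def -(key: String): Map[String, String] = baseMap - key.toLowerCase
}
```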
## How was this patch tested?
Pass the Jenkins tests with the newly added test cases.
Author: Dongjoon Hyun <dongjoon@apache.org>
Closes #15884 from dongjoon-hyun/SPARK-18433.
Diffstat (limited to 'sql/core')
12 files changed, 82 insertions, 45 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index 6c1c398940..588aa05c37 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryComparison}
 import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, PredicateHelper}
-import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, PartitioningUtils}
+import org.apache.spark.sql.execution.datasources.PartitioningUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.util.SerializableConfiguration
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 65422f1495..cfee7be1e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -31,6 +31,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable}
 import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
 import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider
 import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
@@ -80,13 +81,13 @@ case class DataSource(
 
   lazy val providingClass: Class[_] = DataSource.lookupDataSource(className)
   lazy val sourceInfo = sourceSchema()
+  private val caseInsensitiveOptions = new CaseInsensitiveMap(options)
 
   /**
    * Infer the schema of the given FileFormat, returns a pair of schema and partition column names.
    */
   private def inferFileFormatSchema(format: FileFormat): (StructType, Seq[String]) = {
     userSpecifiedSchema.map(_ -> partitionColumns).orElse {
-      val caseInsensitiveOptions = new CaseInsensitiveMap(options)
       val allPaths = caseInsensitiveOptions.get("path")
       val globbedPaths = allPaths.toSeq.flatMap { path =>
         val hdfsPath = new Path(path)
@@ -114,11 +115,10 @@ case class DataSource(
     providingClass.newInstance() match {
       case s: StreamSourceProvider =>
         val (name, schema) = s.sourceSchema(
-          sparkSession.sqlContext, userSpecifiedSchema, className, options)
+          sparkSession.sqlContext, userSpecifiedSchema, className, caseInsensitiveOptions)
         SourceInfo(name, schema, Nil)
 
       case format: FileFormat =>
-        val caseInsensitiveOptions = new CaseInsensitiveMap(options)
         val path = caseInsensitiveOptions.getOrElse("path", {
           throw new IllegalArgumentException("'path' is not specified")
         })
@@ -158,10 +158,14 @@ case class DataSource(
     providingClass.newInstance() match {
       case s: StreamSourceProvider =>
         s.createSource(
-          sparkSession.sqlContext, metadataPath, userSpecifiedSchema, className, options)
+          sparkSession.sqlContext,
+          metadataPath,
+          userSpecifiedSchema,
+          className,
+          caseInsensitiveOptions)
 
       case format: FileFormat =>
-        val path = new CaseInsensitiveMap(options).getOrElse("path", {
+        val path = caseInsensitiveOptions.getOrElse("path", {
           throw new IllegalArgumentException("'path' is not specified")
         })
         new FileStreamSource(
@@ -171,7 +175,7 @@ case class DataSource(
           schema = sourceInfo.schema,
           partitionColumns = sourceInfo.partitionColumns,
           metadataPath = metadataPath,
-          options = options)
+          options = caseInsensitiveOptions)
       case _ =>
         throw new UnsupportedOperationException(
           s"Data source $className does not support streamed reading")
@@ -182,10 +186,9 @@ case class DataSource(
   def createSink(outputMode: OutputMode): Sink = {
     providingClass.newInstance() match {
       case s: StreamSinkProvider =>
-        s.createSink(sparkSession.sqlContext, options, partitionColumns, outputMode)
+        s.createSink(sparkSession.sqlContext, caseInsensitiveOptions, partitionColumns, outputMode)
 
       case fileFormat: FileFormat =>
-        val caseInsensitiveOptions = new CaseInsensitiveMap(options)
         val path = caseInsensitiveOptions.getOrElse("path", {
           throw new IllegalArgumentException("'path' is not specified")
         })
@@ -193,7 +196,7 @@ case class DataSource(
           throw new IllegalArgumentException(
             s"Data source $className does not support $outputMode output mode")
         }
-        new FileStreamSink(sparkSession, path, fileFormat, partitionColumns, options)
+        new FileStreamSink(sparkSession, path, fileFormat, partitionColumns, caseInsensitiveOptions)
 
       case _ =>
         throw new UnsupportedOperationException(
@@ -234,7 +237,6 @@ case class DataSource(
    * that files already exist, we don't need to check them again.
    */
   def resolveRelation(checkFilesExist: Boolean = true): BaseRelation = {
-    val caseInsensitiveOptions = new CaseInsensitiveMap(options)
     val relation = (providingClass.newInstance(), userSpecifiedSchema) match {
       // TODO: Throw when too much is given.
       case (dataSource: SchemaRelationProvider, Some(schema)) =>
@@ -274,7 +276,7 @@ case class DataSource(
           dataSchema = dataSchema,
           bucketSpec = None,
           format,
-          options)(sparkSession)
+          caseInsensitiveOptions)(sparkSession)
 
       // This is a non-streaming file based datasource.
       case (format: FileFormat, _) =>
@@ -358,13 +360,13 @@ case class DataSource(
     providingClass.newInstance() match {
       case dataSource: CreatableRelationProvider =>
-        dataSource.createRelation(sparkSession.sqlContext, mode, options, data)
+        dataSource.createRelation(sparkSession.sqlContext, mode, caseInsensitiveOptions, data)
 
       case format: FileFormat =>
         // Don't glob path for the write path. The contracts here are:
         //  1. Only one output path can be specified on the write path;
         //  2. Output path must be a legal HDFS style file system path;
         //  3. It's OK that the output path doesn't exist yet;
-        val allPaths = paths ++ new CaseInsensitiveMap(options).get("path")
+        val allPaths = paths ++ caseInsensitiveOptions.get("path")
         val outputPath = if (allPaths.length == 1) {
           val path = new Path(allPaths.head)
           val fs = path.getFileSystem(sparkSession.sessionState.newHadoopConf())
@@ -391,7 +393,7 @@ case class DataSource(
           // TODO: Case sensitivity.
           val sameColumns =
             existingPartitionColumns.map(_.toLowerCase()) == partitionColumns.map(_.toLowerCase())
-          if (existingPartitionColumns.size > 0 && !sameColumns) {
+          if (existingPartitionColumns.nonEmpty && !sameColumns) {
            throw new AnalysisException(
              s"""Requested partitioning does not match existing partitioning.
                 |Existing partitioning columns:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 5903729c11..21e50307b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -23,11 +23,13 @@ import java.util.Locale
 import org.apache.commons.lang3.time.FastDateFormat
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.util.{CompressionCodecs, ParseModes}
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, ParseModes}
 
-private[csv] class CSVOptions(@transient private val parameters: Map[String, String])
+private[csv] class CSVOptions(@transient private val parameters: CaseInsensitiveMap)
   extends Logging with Serializable {
 
+  def this(parameters: Map[String, String]) = this(new CaseInsensitiveMap(parameters))
+
   private def getChar(paramName: String, default: Char): Char = {
     val paramValue = parameters.get(paramName)
     paramValue match {
@@ -128,7 +130,7 @@ private[csv] class CSVOptions(@transient private val parameters: Map[String, Str
 
 object CSVOptions {
 
-  def apply(): CSVOptions = new CSVOptions(Map.empty)
+  def apply(): CSVOptions = new CSVOptions(new CaseInsensitiveMap(Map.empty))
 
   def apply(paramName: String, paramValue: String): CSVOptions = {
     new CSVOptions(Map(paramName -> paramValue))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
index 59fb48ffea..fa8dfa9640 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
@@ -96,21 +96,3 @@ case class RefreshResource(path: String)
     Seq.empty[Row]
   }
 }
-
-/**
- * Builds a map in which keys are case insensitive
- */
-class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String]
-  with Serializable {
-
-  val baseMap = map.map(kv => kv.copy(_1 = kv._1.toLowerCase))
-
-  override def get(k: String): Option[String] = baseMap.get(k.toLowerCase)
-
-  override def + [B1 >: String](kv: (String, B1)): Map[String, B1] =
-    baseMap + kv.copy(_1 = kv._1.toLowerCase)
-
-  override def iterator: Iterator[(String, String)] = baseMap.iterator
-
-  override def -(key: String): Map[String, String] = baseMap - key.toLowerCase
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index fcd7409159..7f419b5788 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -22,19 +22,23 @@ import java.util.Properties
 
 import scala.collection.mutable.ArrayBuffer
 
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+
 /**
  * Options for the JDBC data source.
  */
 class JDBCOptions(
-    @transient private val parameters: Map[String, String])
+    @transient private val parameters: CaseInsensitiveMap)
   extends Serializable {
 
   import JDBCOptions._
 
+  def this(parameters: Map[String, String]) = this(new CaseInsensitiveMap(parameters))
+
   def this(url: String, table: String, parameters: Map[String, String]) = {
-    this(parameters ++ Map(
+    this(new CaseInsensitiveMap(parameters ++ Map(
       JDBCOptions.JDBC_URL -> url,
-      JDBCOptions.JDBC_TABLE_NAME -> table))
+      JDBCOptions.JDBC_TABLE_NAME -> table)))
   }
 
   val asConnectionProperties: Properties = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
index d0fd23605b..a81a95d510 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
@@ -19,18 +19,22 @@ package org.apache.spark.sql.execution.datasources.parquet
 
 import org.apache.parquet.hadoop.metadata.CompressionCodecName
 
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.internal.SQLConf
 
 /**
  * Options for the Parquet data source.
  */
 private[parquet] class ParquetOptions(
-    @transient private val parameters: Map[String, String],
+    @transient private val parameters: CaseInsensitiveMap,
     @transient private val sqlConf: SQLConf)
   extends Serializable {
 
   import ParquetOptions._
 
+  def this(parameters: Map[String, String], sqlConf: SQLConf) =
+    this(new CaseInsensitiveMap(parameters), sqlConf)
+
   /**
    * Compression codec to use. By default use the value specified in SQLConf.
    * Acceptable values are defined in [[shortParquetCompressionCodecNames]].
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala
index 3efc20c1d6..fdea65cb10 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala
@@ -20,13 +20,15 @@ package org.apache.spark.sql.execution.streaming
 import scala.util.Try
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.execution.datasources.CaseInsensitiveMap
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.util.Utils
 
 /**
  * User specified options for file streams.
  */
-class FileStreamOptions(parameters: Map[String, String]) extends Logging {
+class FileStreamOptions(parameters: CaseInsensitiveMap) extends Logging {
+
+  def this(parameters: Map[String, String]) = this(new CaseInsensitiveMap(parameters))
 
   val maxFilesPerTrigger: Option[Int] = parameters.get("maxFilesPerTrigger").map { str =>
     Try(str.toInt).toOption.filter(_ > 0).getOrElse {
@@ -50,5 +52,5 @@ class FileStreamOptions(parameters: Map[String, String]) extends Logging {
 
   /** Options as specified by the user, in a case-insensitive map, without "path" set. */
   val optionMapWithoutPath: Map[String, String] =
-    new CaseInsensitiveMap(parameters).filterKeys(_ != "path")
+    parameters.filterKeys(_ != "path")
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
index 5e00f669b8..93f752d107 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
@@ -109,4 +109,9 @@ class CSVInferSchemaSuite extends SparkFunSuite {
     val mergedNullTypes = CSVInferSchema.mergeRowTypes(Array(NullType), Array(NullType))
     assert(mergedNullTypes.deep == Array(NullType).deep)
   }
+
+  test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") {
+    val options = new CSVOptions(Map("TiMeStampFormat" -> "yyyy-mm"))
+    assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 456052f79a..598e44ec8c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -1366,7 +1366,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
 
   test("SPARK-6245 JsonRDD.inferSchema on empty RDD") {
     // This is really a test that it doesn't throw an exception
-    val emptySchema = InferSchema.infer(empty, "", new JSONOptions(Map()))
+    val emptySchema = InferSchema.infer(empty, "", new JSONOptions(Map.empty[String, String]))
     assert(StructType(Seq()) === emptySchema)
   }
 
@@ -1390,7 +1390,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
   }
 
   test("SPARK-8093 Erase empty structs") {
-    val emptySchema = InferSchema.infer(emptyRecords, "", new JSONOptions(Map()))
+    val emptySchema = InferSchema.infer(
+      emptyRecords, "", new JSONOptions(Map.empty[String, String]))
     assert(StructType(Seq()) === emptySchema)
   }
 
@@ -1749,4 +1750,18 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
       checkAnswer(stringTimestampsWithFormat, expectedStringDatesWithFormat)
     }
   }
+
+  test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") {
+    val records = sparkContext
+      .parallelize("""{"a": 3, "b": 1.1}""" :: """{"a": 3.1, "b": 0.000001}""" :: Nil)
+
+    val schema = StructType(
+      StructField("a", DecimalType(21, 1), true) ::
+      StructField("b", DecimalType(7, 6), true) :: Nil)
+
+    val df1 = spark.read.option("prefersDecimal", "true").json(records)
+    assert(df1.schema == schema)
+    val df2 = spark.read.option("PREfersdecimaL", "true").json(records)
+    assert(df2.schema == schema)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index 580eade4b1..acdadb3103 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -736,6 +736,13 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
       }
     }
   }
+
+  test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") {
+    withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") {
+      val option = new ParquetOptions(Map("Compression" -> "uncompressed"), spark.sessionState.conf)
+      assert(option.compressionCodecClassName == "UNCOMPRESSED")
+    }
+  }
 }
 
 class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index 96540ec92d..e3d3c6c3a8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -303,4 +303,13 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
     assert(e.contains("If 'partitionColumn' is specified then 'lowerBound', 'upperBound'," +
       " and 'numPartitions' are required."))
   }
+
+  test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") {
+    val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+    df.write.format("jdbc")
+      .option("Url", url1)
+      .option("dbtable", "TEST.SAVETEST")
+      .options(properties.asScala)
+      .save()
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
index fab7642994..b365af76c3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
@@ -1004,6 +1004,11 @@ class FileStreamSourceSuite extends FileStreamSourceTest {
       )
     }
   }
+
+  test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") {
+    val options = new FileStreamOptions(Map("maxfilespertrigger" -> "1"))
+    assert(options.maxFilesPerTrigger == Some(1))
+  }
 }
 
 class FileStreamSourceStressTestSuite extends FileStreamSourceTest {
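Stepping back, each options class touched by the diff follows the same shape: the primary constructor takes a `CaseInsensitiveMap`, and an auxiliary constructor wraps a plain `Map[String, String]` so existing call sites keep compiling. The sketch below illustrates that pattern with a hypothetical `MyOptions` class (not part of the patch), reusing the `CaseInsensitiveMap` sketch shown earlier:
```scala
// Hypothetical options class, for illustration only; it mirrors the constructor
// pattern the diff applies to CSVOptions, JDBCOptions, ParquetOptions and
// FileStreamOptions. CaseInsensitiveMap here is the sketch shown above.
class MyOptions(@transient private val parameters: CaseInsensitiveMap) extends Serializable {

  // Auxiliary constructor: callers can keep passing ordinary maps.
  def this(parameters: Map[String, String]) = this(new CaseInsensitiveMap(parameters))

  // Lookups are case-insensitive, so "PATH", "Path" and "path" all hit the same entry.
  val path: Option[String] = parameters.get("path")
}

object MyOptionsExample {
  def main(args: Array[String]): Unit = {
    val opts = new MyOptions(Map("PaTh" -> "/tmp/data"))
    assert(opts.path.contains("/tmp/data"))
  }
}
```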