From 2692bdb7dbf36d6247f595d5fd0cb9cda89e1fdd Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Tue, 3 Nov 2015 20:25:58 -0800
Subject: [SPARK-11455][SQL] fix case sensitivity of partition by

depend on `caseSensitive` to do column name equality check, instead of just `==`

Author: Wenchen Fan

Closes #9410 from cloud-fan/partition.
---
 .../execution/datasources/PartitioningUtils.scala  |  7 +++---
 .../execution/datasources/ResolvedDataSource.scala | 27 +++++++++++++++++-----
 .../spark/sql/execution/datasources/rules.scala    |  6 +++--
 .../org/apache/spark/sql/DataFrameSuite.scala      | 10 ++++++++
 4 files changed, 39 insertions(+), 11 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
index 628c5e1893..16dc23661c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
@@ -287,10 +287,11 @@ private[sql] object PartitioningUtils {
 
   def validatePartitionColumnDataTypes(
       schema: StructType,
-      partitionColumns: Array[String]): Unit = {
+      partitionColumns: Array[String],
+      caseSensitive: Boolean): Unit = {
 
-    ResolvedDataSource.partitionColumnsSchema(schema, partitionColumns).foreach { field =>
-      field.dataType match {
+    ResolvedDataSource.partitionColumnsSchema(schema, partitionColumns, caseSensitive).foreach {
+      field => field.dataType match {
         case _: AtomicType => // OK
         case _ => throw new AnalysisException(s"Cannot use ${field.dataType} for partition column")
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
index 54beabbf63..86a306b8f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
@@ -99,7 +99,8 @@ object ResolvedDataSource extends Logging {
           val maybePartitionsSchema = if (partitionColumns.isEmpty) {
             None
           } else {
-            Some(partitionColumnsSchema(schema, partitionColumns))
+            Some(partitionColumnsSchema(
+              schema, partitionColumns, sqlContext.conf.caseSensitiveAnalysis))
           }
 
           val caseInsensitiveOptions = new CaseInsensitiveMap(options)
@@ -172,14 +173,24 @@ object ResolvedDataSource extends Logging {
 
   def partitionColumnsSchema(
       schema: StructType,
-      partitionColumns: Array[String]): StructType = {
+      partitionColumns: Array[String],
+      caseSensitive: Boolean): StructType = {
+    val equality = columnNameEquality(caseSensitive)
     StructType(partitionColumns.map { col =>
-      schema.find(_.name == col).getOrElse {
+      schema.find(f => equality(f.name, col)).getOrElse {
         throw new RuntimeException(s"Partition column $col not found in schema $schema")
       }
     }).asNullable
   }
 
+  private def columnNameEquality(caseSensitive: Boolean): (String, String) => Boolean = {
+    if (caseSensitive) {
+      org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
+    } else {
+      org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
+    }
+  }
+
   /** Create a [[ResolvedDataSource]] for saving the content of the given DataFrame. */
   def apply(
       sqlContext: SQLContext,
@@ -207,14 +218,18 @@ object ResolvedDataSource extends Logging {
           path.makeQualified(fs.getUri, fs.getWorkingDirectory)
         }
 
-        PartitioningUtils.validatePartitionColumnDataTypes(data.schema, partitionColumns)
+        val caseSensitive = sqlContext.conf.caseSensitiveAnalysis
+        PartitioningUtils.validatePartitionColumnDataTypes(
+          data.schema, partitionColumns, caseSensitive)
 
-        val dataSchema = StructType(data.schema.filterNot(f => partitionColumns.contains(f.name)))
+        val equality = columnNameEquality(caseSensitive)
+        val dataSchema = StructType(
+          data.schema.filterNot(f => partitionColumns.exists(equality(_, f.name))))
         val r = dataSource.createRelation(
           sqlContext,
           Array(outputPath.toString),
           Some(dataSchema.asNullable),
-          Some(partitionColumnsSchema(data.schema, partitionColumns)),
+          Some(partitionColumnsSchema(data.schema, partitionColumns, caseSensitive)),
           caseInsensitiveOptions)
 
         // For partitioned relation r, r.schema's column ordering can be different from the column
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index abc016bf02..1a8e7ab202 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -140,7 +140,8 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan =>
           // OK
         }
 
-        PartitioningUtils.validatePartitionColumnDataTypes(r.schema, part.keySet.toArray)
+        PartitioningUtils.validatePartitionColumnDataTypes(
+          r.schema, part.keySet.toArray, catalog.conf.caseSensitiveAnalysis)
 
         // Get all input data source relations of the query.
         val srcRelations = query.collect {
@@ -190,7 +191,8 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan =>
           // OK
         }
 
-        PartitioningUtils.validatePartitionColumnDataTypes(query.schema, partitionColumns)
+        PartitioningUtils.validatePartitionColumnDataTypes(
+          query.schema, partitionColumns, catalog.conf.caseSensitiveAnalysis)
 
       case _ => // OK
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index a883bcb7b1..a9e6413423 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1118,4 +1118,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       if (!allSequential) throw new SparkException("Partition should contain all sequential values")
     })
   }
+
+  test("fix case sensitivity of partition by") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+      withTempPath { path =>
+        val p = path.getAbsolutePath
+        Seq(2012 -> "a").toDF("year", "val").write.partitionBy("yEAr").parquet(p)
+        checkAnswer(sqlContext.read.parquet(p).select("YeaR"), Row(2012))
+      }
+    }
+  }
 }
-- 
cgit v1.2.3
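
For reference, a minimal self-contained Scala sketch of the idea behind this change: pick the column-name equality function from the `caseSensitive` flag instead of hard-coding `==`, then use it both to resolve partition columns against the schema and to split off the data schema. The object and helper names below are illustrative, not part of the patch; Spark's real resolvers are `caseSensitiveResolution` / `caseInsensitiveResolution` in the `org.apache.spark.sql.catalyst.analysis` package object.

object ColumnEqualitySketch {
  // Mirrors Spark's Resolver type: a name-equality predicate.
  type Resolver = (String, String) => Boolean

  // Stand-ins for the catalyst resolvers referenced by the patch.
  val caseSensitiveResolution: Resolver = (a, b) => a == b
  val caseInsensitiveResolution: Resolver = (a, b) => a.equalsIgnoreCase(b)

  // Same shape as the patch's columnNameEquality helper.
  def columnNameEquality(caseSensitive: Boolean): Resolver =
    if (caseSensitive) caseSensitiveResolution else caseInsensitiveResolution

  def main(args: Array[String]): Unit = {
    val schemaFields = Seq("year", "val")   // DataFrame column names
    val partitionColumns = Seq("yEAr")      // user-specified partitionBy column
    val equality = columnNameEquality(caseSensitive = false)

    // Resolve each partition column against the schema with the chosen equality,
    // as partitionColumnsSchema does; "yEAr" matches "year" when case-insensitive.
    val resolved = partitionColumns.map { col =>
      schemaFields.find(f => equality(f, col))
        .getOrElse(sys.error(s"Partition column $col not found"))
    }
    println(resolved)                       // List(year)

    // The remaining fields form the data schema, as in the write path of apply().
    val dataColumns = schemaFields.filterNot(f => partitionColumns.exists(equality(_, f)))
    println(dataColumns)                    // List(val)
  }
}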