about summary refs log tree commit diff
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2015-11-03 20:25:58 -0800
committerYin Huai <yhuai@databricks.com>2015-11-03 20:25:58 -0800
commit2692bdb7dbf36d6247f595d5fd0cb9cda89e1fdd (patch)
treeadf2e33cd0088c229d7acc6f918c8854a1a758fd
parente352de0db2789919e1e0385b79f29b508a6b2b77 (diff)
downloadspark-2692bdb7dbf36d6247f595d5fd0cb9cda89e1fdd.tar.gz
spark-2692bdb7dbf36d6247f595d5fd0cb9cda89e1fdd.tar.bz2
spark-2692bdb7dbf36d6247f595d5fd0cb9cda89e1fdd.zip
[SPARK-11455][SQL] fix case sensitivity of partition by
Depend on `caseSensitive` to do column name equality checks, instead of just `==`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9410 from cloud-fan/partition.
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala | 7
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala | 27
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala | 6
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 10
4 files changed, 39 insertions, 11 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
index 628c5e1893..16dc23661c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
@@ -287,10 +287,11 @@ private[sql] object PartitioningUtils {
def validatePartitionColumnDataTypes(
schema: StructType,
- partitionColumns: Array[String]): Unit = {
+ partitionColumns: Array[String],
+ caseSensitive: Boolean): Unit = {
- ResolvedDataSource.partitionColumnsSchema(schema, partitionColumns).foreach { field =>
- field.dataType match {
+ ResolvedDataSource.partitionColumnsSchema(schema, partitionColumns, caseSensitive).foreach {
+ field => field.dataType match {
case _: AtomicType => // OK
case _ => throw new AnalysisException(s"Cannot use ${field.dataType} for partition column")
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
index 54beabbf63..86a306b8f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
@@ -99,7 +99,8 @@ object ResolvedDataSource extends Logging {
val maybePartitionsSchema = if (partitionColumns.isEmpty) {
None
} else {
- Some(partitionColumnsSchema(schema, partitionColumns))
+ Some(partitionColumnsSchema(
+ schema, partitionColumns, sqlContext.conf.caseSensitiveAnalysis))
}
val caseInsensitiveOptions = new CaseInsensitiveMap(options)
@@ -172,14 +173,24 @@ object ResolvedDataSource extends Logging {
def partitionColumnsSchema(
schema: StructType,
- partitionColumns: Array[String]): StructType = {
+ partitionColumns: Array[String],
+ caseSensitive: Boolean): StructType = {
+ val equality = columnNameEquality(caseSensitive)
StructType(partitionColumns.map { col =>
- schema.find(_.name == col).getOrElse {
+ schema.find(f => equality(f.name, col)).getOrElse {
throw new RuntimeException(s"Partition column $col not found in schema $schema")
}
}).asNullable
}
+ private def columnNameEquality(caseSensitive: Boolean): (String, String) => Boolean = {
+ if (caseSensitive) {
+ org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
+ } else {
+ org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
+ }
+ }
+
/** Create a [[ResolvedDataSource]] for saving the content of the given DataFrame. */
def apply(
sqlContext: SQLContext,
@@ -207,14 +218,18 @@ object ResolvedDataSource extends Logging {
path.makeQualified(fs.getUri, fs.getWorkingDirectory)
}
- PartitioningUtils.validatePartitionColumnDataTypes(data.schema, partitionColumns)
+ val caseSensitive = sqlContext.conf.caseSensitiveAnalysis
+ PartitioningUtils.validatePartitionColumnDataTypes(
+ data.schema, partitionColumns, caseSensitive)
- val dataSchema = StructType(data.schema.filterNot(f => partitionColumns.contains(f.name)))
+ val equality = columnNameEquality(caseSensitive)
+ val dataSchema = StructType(
+ data.schema.filterNot(f => partitionColumns.exists(equality(_, f.name))))
val r = dataSource.createRelation(
sqlContext,
Array(outputPath.toString),
Some(dataSchema.asNullable),
- Some(partitionColumnsSchema(data.schema, partitionColumns)),
+ Some(partitionColumnsSchema(data.schema, partitionColumns, caseSensitive)),
caseInsensitiveOptions)
// For partitioned relation r, r.schema's column ordering can be different from the column
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index abc016bf02..1a8e7ab202 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -140,7 +140,8 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan =>
// OK
}
- PartitioningUtils.validatePartitionColumnDataTypes(r.schema, part.keySet.toArray)
+ PartitioningUtils.validatePartitionColumnDataTypes(
+ r.schema, part.keySet.toArray, catalog.conf.caseSensitiveAnalysis)
// Get all input data source relations of the query.
val srcRelations = query.collect {
@@ -190,7 +191,8 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan =>
// OK
}
- PartitioningUtils.validatePartitionColumnDataTypes(query.schema, partitionColumns)
+ PartitioningUtils.validatePartitionColumnDataTypes(
+ query.schema, partitionColumns, catalog.conf.caseSensitiveAnalysis)
case _ => // OK
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index a883bcb7b1..a9e6413423 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1118,4 +1118,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
if (!allSequential) throw new SparkException("Partition should contain all sequential values")
})
}
+
+ test("fix case sensitivity of partition by") {
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+ withTempPath { path =>
+ val p = path.getAbsolutePath
+ Seq(2012 -> "a").toDF("year", "val").write.partitionBy("yEAr").parquet(p)
+ checkAnswer(sqlContext.read.parquet(p).select("YeaR"), Row(2012))
+ }
+ }
+ }
}