-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala             | 40
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala | 18
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala  | 15
3 files changed, 25 insertions(+), 48 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 81657d9e47..748ebba3e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -215,7 +215,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
df.sparkSession,
className = source,
partitionColumns = partitioningColumns.getOrElse(Nil),
- bucketSpec = getBucketSpec,
options = extraOptions.toMap)
dataSource.write(mode, df)
@@ -270,52 +269,17 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
ifNotExists = false)).toRdd
}
- private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols =>
- cols.map(normalize(_, "Partition"))
- }
-
- private def normalizedBucketColNames: Option[Seq[String]] = bucketColumnNames.map { cols =>
- cols.map(normalize(_, "Bucketing"))
- }
-
- private def normalizedSortColNames: Option[Seq[String]] = sortColumnNames.map { cols =>
- cols.map(normalize(_, "Sorting"))
- }
-
private def getBucketSpec: Option[BucketSpec] = {
if (sortColumnNames.isDefined) {
require(numBuckets.isDefined, "sortBy must be used together with bucketBy")
}
- for {
- n <- numBuckets
- } yield {
+ numBuckets.map { n =>
require(n > 0 && n < 100000, "Bucket number must be greater than 0 and less than 100000.")
-
- // partitionBy columns cannot be used in bucketBy
- if (normalizedParCols.nonEmpty &&
- normalizedBucketColNames.get.toSet.intersect(normalizedParCols.get.toSet).nonEmpty) {
- throw new AnalysisException(
- s"bucketBy columns '${bucketColumnNames.get.mkString(", ")}' should not be part of " +
- s"partitionBy columns '${partitioningColumns.get.mkString(", ")}'")
- }
-
- BucketSpec(n, normalizedBucketColNames.get, normalizedSortColNames.getOrElse(Nil))
+ BucketSpec(n, bucketColumnNames.get, sortColumnNames.getOrElse(Nil))
}
}
- /**
- * The given column name may not be equal to any of the existing column names if we were in
- * case-insensitive context. Normalize the given column name to the real one so that we don't
- * need to care about case sensitivity afterwards.
- */
- private def normalize(columnName: String, columnType: String): String = {
- val validColumnNames = df.logicalPlan.output.map(_.name)
- validColumnNames.find(df.sparkSession.sessionState.analyzer.resolver(_, columnName))
- .getOrElse(throw new AnalysisException(s"$columnType column $columnName not found in " +
- s"existing columns (${validColumnNames.mkString(", ")})"))
- }
-
private def assertNotBucketed(operation: String): Unit = {
if (numBuckets.isDefined || sortColumnNames.isDefined) {
throw new AnalysisException(s"'$operation' does not support bucketing right now")
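For context, a minimal self-contained sketch of what the simplified getBucketSpec above now does. This is illustrative only: BucketSpec here is a stand-in case class (not Spark's real one), and the Option parameters mimic the builder state held by DataFrameWriter.

object GetBucketSpecSketch {
  // Stand-in for Spark's BucketSpec class.
  case class BucketSpec(numBuckets: Int, bucketColumnNames: Seq[String], sortColumnNames: Seq[String])

  // Mirrors the simplified method: validate the bucket count, but pass the
  // column names through untouched; normalization and the partitionBy/bucketBy
  // overlap check now happen later, in the analyzer rule PreprocessTableCreation.
  def getBucketSpec(
      numBuckets: Option[Int],
      bucketColumnNames: Option[Seq[String]],
      sortColumnNames: Option[Seq[String]]): Option[BucketSpec] = {
    if (sortColumnNames.isDefined) {
      require(numBuckets.isDefined, "sortBy must be used together with bucketBy")
    }
    numBuckets.map { n =>
      require(n > 0 && n < 100000, "Bucket number must be greater than 0 and less than 100000.")
      BucketSpec(n, bucketColumnNames.get, sortColumnNames.getOrElse(Nil))
    }
  }

  def main(args: Array[String]): Unit = {
    // Case is preserved and no column validation happens at this point.
    println(getBucketSpec(Some(8), Some(Seq("J", "k")), Some(Seq("k"))))
    // => Some(BucketSpec(8,List(J, k),List(k)))
  }
}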
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index e053a0e9e2..1c3e7c6d52 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -226,9 +226,21 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi
}
checkDuplication(columnNames, "table definition of " + table.identifier)
- table.copy(
- partitionColumnNames = normalizePartitionColumns(schema, table),
- bucketSpec = normalizeBucketSpec(schema, table))
+ val normalizedPartCols = normalizePartitionColumns(schema, table)
+ val normalizedBucketSpec = normalizeBucketSpec(schema, table)
+
+ normalizedBucketSpec.foreach { spec =>
+ for (bucketCol <- spec.bucketColumnNames if normalizedPartCols.contains(bucketCol)) {
+ throw new AnalysisException(s"bucketing column '$bucketCol' should not be part of " +
+ s"partition columns '${normalizedPartCols.mkString(", ")}'")
+ }
+ for (sortCol <- spec.sortColumnNames if normalizedPartCols.contains(sortCol)) {
+ throw new AnalysisException(s"bucket sorting column '$sortCol' should not be part of " +
+ s"partition columns '${normalizedPartCols.mkString(", ")}'")
+ }
+ }
+
+ table.copy(partitionColumnNames = normalizedPartCols, bucketSpec = normalizedBucketSpec)
}
private def normalizePartitionColumns(schema: StructType, table: CatalogTable): Seq[String] = {
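The validation added to PreprocessTableCreation can be read in isolation as the small check below. This is a stand-alone sketch under simplified assumptions (plain string sequences instead of CatalogTable, a generic exception instead of AnalysisException), not the Spark source itself.

object BucketOverlapCheckSketch {
  // Stand-in for Spark's BucketSpec class.
  case class BucketSpec(numBuckets: Int, bucketColumnNames: Seq[String], sortColumnNames: Seq[String])

  // After partition and bucket columns have been normalized against the table
  // schema, reject any bucketing or sort column that is also a partition column.
  def checkOverlap(partCols: Seq[String], bucketSpec: Option[BucketSpec]): Unit = {
    bucketSpec.foreach { spec =>
      for (bucketCol <- spec.bucketColumnNames if partCols.contains(bucketCol)) {
        throw new IllegalArgumentException(
          s"bucketing column '$bucketCol' should not be part of " +
            s"partition columns '${partCols.mkString(", ")}'")
      }
      for (sortCol <- spec.sortColumnNames if partCols.contains(sortCol)) {
        throw new IllegalArgumentException(
          s"bucket sorting column '$sortCol' should not be part of " +
            s"partition columns '${partCols.mkString(", ")}'")
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // Throws: bucketing column 'j' should not be part of partition columns 'i, j'
    checkOverlap(Seq("i", "j"), Some(BucketSpec(8, Seq("j", "k"), Seq("k"))))
  }
}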
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
index 2eafe18b85..8528dfc4ce 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
@@ -169,19 +169,20 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
}
}
- test("write bucketed data with the overlapping bucketBy and partitionBy columns") {
- intercept[AnalysisException](df.write
+ test("write bucketed data with the overlapping bucketBy/sortBy and partitionBy columns") {
+ val e1 = intercept[AnalysisException](df.write
.partitionBy("i", "j")
.bucketBy(8, "j", "k")
.sortBy("k")
.saveAsTable("bucketed_table"))
- }
+    assert(e1.message.contains("bucketing column 'j' should not be part of partition columns"))
- test("write bucketed data with the identical bucketBy and partitionBy columns") {
- intercept[AnalysisException](df.write
- .partitionBy("i")
- .bucketBy(8, "i")
+ val e2 = intercept[AnalysisException](df.write
+ .partitionBy("i", "j")
+ .bucketBy(8, "k")
+ .sortBy("i")
.saveAsTable("bucketed_table"))
+ assert(e2.message.contains("bucket sorting column 'i' should not be part of partition columns"))
}
test("write bucketed data without partitionBy") {