3 files changed, 25 insertions, 48 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 81657d9e47..748ebba3e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -215,7 +215,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       df.sparkSession,
       className = source,
       partitionColumns = partitioningColumns.getOrElse(Nil),
-      bucketSpec = getBucketSpec,
       options = extraOptions.toMap)

     dataSource.write(mode, df)
@@ -270,52 +269,17 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
         ifNotExists = false)).toRdd
   }

-  private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols =>
-    cols.map(normalize(_, "Partition"))
-  }
-
-  private def normalizedBucketColNames: Option[Seq[String]] = bucketColumnNames.map { cols =>
-    cols.map(normalize(_, "Bucketing"))
-  }
-
-  private def normalizedSortColNames: Option[Seq[String]] = sortColumnNames.map { cols =>
-    cols.map(normalize(_, "Sorting"))
-  }
-
   private def getBucketSpec: Option[BucketSpec] = {
     if (sortColumnNames.isDefined) {
       require(numBuckets.isDefined, "sortBy must be used together with bucketBy")
     }

-    for {
-      n <- numBuckets
-    } yield {
+    numBuckets.map { n =>
       require(n > 0 && n < 100000, "Bucket number must be greater than 0 and less than 100000.")
-
-      // partitionBy columns cannot be used in bucketBy
-      if (normalizedParCols.nonEmpty &&
-        normalizedBucketColNames.get.toSet.intersect(normalizedParCols.get.toSet).nonEmpty) {
-        throw new AnalysisException(
-          s"bucketBy columns '${bucketColumnNames.get.mkString(", ")}' should not be part of " +
-            s"partitionBy columns '${partitioningColumns.get.mkString(", ")}'")
-      }
-
-      BucketSpec(n, normalizedBucketColNames.get, normalizedSortColNames.getOrElse(Nil))
+      BucketSpec(n, bucketColumnNames.get, sortColumnNames.getOrElse(Nil))
     }
   }

-  /**
-   * The given column name may not be equal to any of the existing column names if we were in
-   * case-insensitive context. Normalize the given column name to the real one so that we don't
-   * need to care about case sensitivity afterwards.
-   */
-  private def normalize(columnName: String, columnType: String): String = {
-    val validColumnNames = df.logicalPlan.output.map(_.name)
-    validColumnNames.find(df.sparkSession.sessionState.analyzer.resolver(_, columnName))
-      .getOrElse(throw new AnalysisException(s"$columnType column $columnName not found in " +
-        s"existing columns (${validColumnNames.mkString(", ")})"))
-  }
-
   private def assertNotBucketed(operation: String): Unit = {
     if (numBuckets.isDefined || sortColumnNames.isDefined) {
       throw new AnalysisException(s"'$operation' does not support bucketing right now")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index e053a0e9e2..1c3e7c6d52 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -226,9 +226,21 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi
     }
     checkDuplication(columnNames, "table definition of " + table.identifier)

-    table.copy(
-      partitionColumnNames = normalizePartitionColumns(schema, table),
-      bucketSpec = normalizeBucketSpec(schema, table))
+    val normalizedPartCols = normalizePartitionColumns(schema, table)
+    val normalizedBucketSpec = normalizeBucketSpec(schema, table)
+
+    normalizedBucketSpec.foreach { spec =>
+      for (bucketCol <- spec.bucketColumnNames if normalizedPartCols.contains(bucketCol)) {
+        throw new AnalysisException(s"bucketing column '$bucketCol' should not be part of " +
+          s"partition columns '${normalizedPartCols.mkString(", ")}'")
+      }
+      for (sortCol <- spec.sortColumnNames if normalizedPartCols.contains(sortCol)) {
+        throw new AnalysisException(s"bucket sorting column '$sortCol' should not be part of " +
+          s"partition columns '${normalizedPartCols.mkString(", ")}'")
+      }
+    }
+
+    table.copy(partitionColumnNames = normalizedPartCols, bucketSpec = normalizedBucketSpec)
   }

   private def normalizePartitionColumns(schema: StructType, table: CatalogTable): Seq[String] = {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
index 2eafe18b85..8528dfc4ce 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
@@ -169,19 +169,20 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
     }
   }

-  test("write bucketed data with the overlapping bucketBy and partitionBy columns") {
-    intercept[AnalysisException](df.write
+  test("write bucketed data with the overlapping bucketBy/sortBy and partitionBy columns") {
+    val e1 = intercept[AnalysisException](df.write
       .partitionBy("i", "j")
       .bucketBy(8, "j", "k")
       .sortBy("k")
       .saveAsTable("bucketed_table"))
-  }
+    assert(e1.message.contains("bucketing column 'j' should not be part of partition columns"))

-  test("write bucketed data with the identical bucketBy and partitionBy columns") {
-    intercept[AnalysisException](df.write
-      .partitionBy("i")
-      .bucketBy(8, "i")
+    val e2 = intercept[AnalysisException](df.write
+      .partitionBy("i", "j")
+      .bucketBy(8, "k")
+      .sortBy("i")
       .saveAsTable("bucketed_table"))
+    assert(e2.message.contains("bucket sorting column 'i' should not be part of partition columns"))
   }

   test("write bucketed data without partitionBy") {
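For context, a minimal sketch (not part of the commit) of the writer call that the updated test exercises. The SparkSession setup and the table name "t" are hypothetical; with this change, the overlap between bucketBy/sortBy columns and partitionBy columns is rejected by the PreprocessTableCreation analysis rule instead of inside DataFrameWriter, so the error surfaces when the table is created:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("bucketing-overlap-sketch")  // hypothetical app name
  .enableHiveSupport()
  .getOrCreate()
import spark.implicits._

val df = Seq((1, 2, 3)).toDF("i", "j", "k")

// Expected to fail during analysis with a message like:
// "bucketing column 'j' should not be part of partition columns 'i, j'"
df.write
  .partitionBy("i", "j")
  .bucketBy(8, "j", "k")
  .sortBy("k")
  .saveAsTable("t")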