diff options
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala     | 13 ++++++-------
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala | 11 +++++++++++
2 files changed, 17 insertions(+), 7 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 412f5fa87e..fef3255c73 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -347,15 +347,14 @@ private[sql] object DataSourceScanExec {
       case _ => None
     }
 
-    def toAttribute(colName: String): Attribute = output.find(_.name == colName).getOrElse {
-      throw new AnalysisException(s"bucket column $colName not found in existing columns " +
-        s"(${output.map(_.name).mkString(", ")})")
-    }
-
     bucketSpec.map { spec =>
       val numBuckets = spec.numBuckets
-      val bucketColumns = spec.bucketColumnNames.map(toAttribute)
-      HashPartitioning(bucketColumns, numBuckets)
+      val bucketColumns = spec.bucketColumnNames.flatMap { n => output.find(_.name == n) }
+      if (bucketColumns.size == spec.bucketColumnNames.size) {
+        HashPartitioning(bucketColumns, numBuckets)
+      } else {
+        UnknownPartitioning(0)
+      }
     }.getOrElse {
       UnknownPartitioning(0)
     }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index f9891ac571..bab0092c37 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -362,4 +362,15 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSingleton
       assert(error.toString contains "Invalid bucket file")
     }
   }
+
+  test("disable bucketing when the output doesn't contain all bucketing columns") {
+    withTable("bucketed_table") {
+      df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
+
+      checkAnswer(hiveContext.table("bucketed_table").select("j"), df1.select("j"))
+
+      checkAnswer(hiveContext.table("bucketed_table").groupBy("j").agg(max("k")),
+        df1.groupBy("j").agg(max("k")))
+    }
+  }
 }