path: root/sql/core
author  Xiao Li <gatorsmile@gmail.com>  2017-02-21 19:30:36 -0800
committer  gatorsmile <gatorsmile@gmail.com>  2017-02-21 19:30:36 -0800
commit  1a45d2b2cc6466841fb73da21a61b61f14a5d5fb (patch)
tree  d14cd627686ba2d06da8a468f8541a5084be6303 /sql/core
parent  17d83e1ee5f14e759c6e3bf0a4cba3346f00fc48 (diff)
download  spark-1a45d2b2cc6466841fb73da21a61b61f14a5d5fb.tar.gz
spark-1a45d2b2cc6466841fb73da21a61b61f14a5d5fb.tar.bz2
spark-1a45d2b2cc6466841fb73da21a61b61f14a5d5fb.zip
[SPARK-19670][SQL][TEST] Enable Bucketed Table Reading and Writing Testing Without Hive Support
### What changes were proposed in this pull request?

Bucketed table reading and writing does not need Hive support, so we can move the test cases from `sql/hive` to `sql/core`. After this PR, test coverage improves: bucketed table reading and writing can be tested both with and without Hive support.

### How was this patch tested?

N/A

Author: Xiao Li <gatorsmile@gmail.com>

Closes #17004 from gatorsmile/mvTestCaseForBuckets.
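For context, here is a minimal sketch (not part of this commit) of bucketed writing and reading with the default in-memory catalog, i.e. a `SparkSession` created without `enableHiveSupport()`. The object and table names are illustrative only.

```scala
import org.apache.spark.sql.SparkSession

// Minimal sketch: bucketed write + read without Hive support.
object BucketedWithoutHiveSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("bucketed-without-hive")
      .getOrCreate()  // default catalog implementation is "in-memory"
    import spark.implicits._

    val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")

    // bucketBy only works with saveAsTable, not with save() or insertInto().
    df.write
      .format("parquet")
      .bucketBy(8, "j", "k")
      .sortBy("k")
      .saveAsTable("bucketed_example")

    // Grouping on the bucket columns can be planned without a shuffle.
    spark.table("bucketed_example").groupBy("j", "k").count().show()

    spark.stop()
  }
}
```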
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala  572
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala  249
2 files changed, 821 insertions, 0 deletions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
new file mode 100644
index 0000000000..9b65419dba
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -0,0 +1,572 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources
+
+import java.io.File
+import java.net.URI
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.execution.{DataSourceScanExec, SortExec}
+import org.apache.spark.sql.execution.datasources.DataSourceStrategy
+import org.apache.spark.sql.execution.exchange.ShuffleExchange
+import org.apache.spark.sql.execution.joins.SortMergeJoinExec
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
+import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
+import org.apache.spark.util.Utils
+import org.apache.spark.util.collection.BitSet
+
+class BucketedReadWithoutHiveSupportSuite extends BucketedReadSuite with SharedSQLContext {
+ protected override def beforeAll(): Unit = {
+ super.beforeAll()
+ assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory")
+ }
+}
+
+
+abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
+ import testImplicits._
+
+ private lazy val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
+ private lazy val nullDF = (for {
+ i <- 0 to 50
+ s <- Seq(null, "a", "b", "c", "d", "e", "f", null, "g")
+ } yield (i % 5, s, i % 13)).toDF("i", "j", "k")
+
+ test("read bucketed data") {
+ withTable("bucketed_table") {
+ df.write
+ .format("parquet")
+ .partitionBy("i")
+ .bucketBy(8, "j", "k")
+ .saveAsTable("bucketed_table")
+
+ for (i <- 0 until 5) {
+ val table = spark.table("bucketed_table").filter($"i" === i)
+ val query = table.queryExecution
+ val output = query.analyzed.output
+ val rdd = query.toRdd
+
+ assert(rdd.partitions.length == 8)
+
+ val attrs = table.select("j", "k").queryExecution.analyzed.output
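+ // Every row in RDD partition `index` should hash, via HashPartitioning on (j, k),
+ // to the bucket with the same id `index`.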
+ val checkBucketId = rdd.mapPartitionsWithIndex((index, rows) => {
+ val getBucketId = UnsafeProjection.create(
+ HashPartitioning(attrs, 8).partitionIdExpression :: Nil,
+ output)
+ rows.map(row => getBucketId(row).getInt(0) -> index)
+ })
+ checkBucketId.collect().foreach(r => assert(r._1 == r._2))
+ }
+ }
+ }
+
+ // To verify that bucket pruning works, this function checks two conditions:
+ // 1) the pruned buckets (before filtering) are empty;
+ // 2) the final result is the same as the expected one.
+ private def checkPrunedAnswers(
+ bucketSpec: BucketSpec,
+ bucketValues: Seq[Integer],
+ filterCondition: Column,
+ originalDataFrame: DataFrame): Unit = {
+ // This test verifies parts of the plan. Disable whole stage codegen.
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
+ val bucketedDataFrame = spark.table("bucketed_table").select("i", "j", "k")
+ val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec
+ // Limitation: bucket pruning only works when there is exactly one bucket column
+ assert(bucketColumnNames.length == 1)
+ val bucketColumnIndex = bucketedDataFrame.schema.fieldIndex(bucketColumnNames.head)
+ val bucketColumn = bucketedDataFrame.schema.toAttributes(bucketColumnIndex)
+ val matchedBuckets = new BitSet(numBuckets)
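+ // Mark the buckets that the given bucket values hash into; all other buckets should be pruned.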
+ bucketValues.foreach { value =>
+ matchedBuckets.set(DataSourceStrategy.getBucketId(bucketColumn, numBuckets, value))
+ }
+
+ // A Filter on top of the scan could hide a bucket pruning bug,
+ // so inspect the scan's output directly (before filtering).
+ val plan = bucketedDataFrame.filter(filterCondition).queryExecution.executedPlan
+ val rdd = plan.find(_.isInstanceOf[DataSourceScanExec])
+ assert(rdd.isDefined, plan)
+
+ val checkedResult = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) =>
+ if (matchedBuckets.get(index % numBuckets) && iter.nonEmpty) Iterator(index) else Iterator()
+ }
+ // TODO: These tests are not testing the right columns.
+// // checking if all the pruned buckets are empty
+// val invalidBuckets = checkedResult.collect().toList
+// if (invalidBuckets.nonEmpty) {
+// fail(s"Buckets $invalidBuckets should have been pruned from:\n$plan")
+// }
+
+ checkAnswer(
+ bucketedDataFrame.filter(filterCondition).orderBy("i", "j", "k"),
+ originalDataFrame.filter(filterCondition).orderBy("i", "j", "k"))
+ }
+ }
+
+ test("read partitioning bucketed tables with bucket pruning filters") {
+ withTable("bucketed_table") {
+ val numBuckets = 8
+ val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
+ // json does not support predicate push-down, so it is used here
+ df.write
+ .format("json")
+ .partitionBy("i")
+ .bucketBy(numBuckets, "j")
+ .saveAsTable("bucketed_table")
+
+ for (j <- 0 until 13) {
+ // Case 1: EqualTo
+ checkPrunedAnswers(
+ bucketSpec,
+ bucketValues = j :: Nil,
+ filterCondition = $"j" === j,
+ df)
+
+ // Case 2: EqualNullSafe
+ checkPrunedAnswers(
+ bucketSpec,
+ bucketValues = j :: Nil,
+ filterCondition = $"j" <=> j,
+ df)
+
+ // Case 3: In
+ checkPrunedAnswers(
+ bucketSpec,
+ bucketValues = Seq(j, j + 1, j + 2, j + 3),
+ filterCondition = $"j".isin(j, j + 1, j + 2, j + 3),
+ df)
+ }
+ }
+ }
+
+ test("read non-partitioning bucketed tables with bucket pruning filters") {
+ withTable("bucketed_table") {
+ val numBuckets = 8
+ val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
+ // json does not support predicate push-down, so it is used here
+ df.write
+ .format("json")
+ .bucketBy(numBuckets, "j")
+ .saveAsTable("bucketed_table")
+
+ for (j <- 0 until 13) {
+ checkPrunedAnswers(
+ bucketSpec,
+ bucketValues = j :: Nil,
+ filterCondition = $"j" === j,
+ df)
+ }
+ }
+ }
+
+ test("read partitioning bucketed tables having null in bucketing key") {
+ withTable("bucketed_table") {
+ val numBuckets = 8
+ val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
+ // json does not support predicate push-down, so it is used here
+ nullDF.write
+ .format("json")
+ .partitionBy("i")
+ .bucketBy(numBuckets, "j")
+ .saveAsTable("bucketed_table")
+
+ // Case 1: isNull
+ checkPrunedAnswers(
+ bucketSpec,
+ bucketValues = null :: Nil,
+ filterCondition = $"j".isNull,
+ nullDF)
+
+ // Case 2: <=> null
+ checkPrunedAnswers(
+ bucketSpec,
+ bucketValues = null :: Nil,
+ filterCondition = $"j" <=> null,
+ nullDF)
+ }
+ }
+
+ test("read partitioning bucketed tables having composite filters") {
+ withTable("bucketed_table") {
+ val numBuckets = 8
+ val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
+ // json does not support predicate push-down, so it is used here
+ df.write
+ .format("json")
+ .partitionBy("i")
+ .bucketBy(numBuckets, "j")
+ .saveAsTable("bucketed_table")
+
+ for (j <- 0 until 13) {
+ checkPrunedAnswers(
+ bucketSpec,
+ bucketValues = j :: Nil,
+ filterCondition = $"j" === j && $"k" > $"j",
+ df)
+
+ checkPrunedAnswers(
+ bucketSpec,
+ bucketValues = j :: Nil,
+ filterCondition = $"j" === j && $"i" > j % 5,
+ df)
+ }
+ }
+ }
+
+ private lazy val df1 =
+ (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1")
+ private lazy val df2 =
+ (0 until 50).map(i => (i % 7, i % 11, i.toString)).toDF("i", "j", "k").as("df2")
+
+ case class BucketedTableTestSpec(
+ bucketSpec: Option[BucketSpec],
+ numPartitions: Int = 10,
+ expectedShuffle: Boolean = true,
+ expectedSort: Boolean = true)
+
+ /**
+ * A helper method to test bucketed reads through a join. It saves `df1` and `df2` as tables,
+ * bucketed or not according to the given bucket specs, then joins the two tables. It first
+ * verifies that the join result is correct, and then checks whether shuffles and sorts appear
+ * in the plan as expected, according to `shuffleLeft`/`shuffleRight` and `sortLeft`/`sortRight`.
+ */
+ private def testBucketing(
+ bucketedTableTestSpecLeft: BucketedTableTestSpec,
+ bucketedTableTestSpecRight: BucketedTableTestSpec,
+ joinType: String = "inner",
+ joinCondition: (DataFrame, DataFrame) => Column): Unit = {
+ val BucketedTableTestSpec(bucketSpecLeft, numPartitionsLeft, shuffleLeft, sortLeft) =
+ bucketedTableTestSpecLeft
+ val BucketedTableTestSpec(bucketSpecRight, numPartitionsRight, shuffleRight, sortRight) =
+ bucketedTableTestSpecRight
+
+ withTable("bucketed_table1", "bucketed_table2") {
+ def withBucket(
+ writer: DataFrameWriter[Row],
+ bucketSpec: Option[BucketSpec]): DataFrameWriter[Row] = {
+ bucketSpec.map { spec =>
+ writer.bucketBy(
+ spec.numBuckets,
+ spec.bucketColumnNames.head,
+ spec.bucketColumnNames.tail: _*)
+
+ if (spec.sortColumnNames.nonEmpty) {
+ writer.sortBy(
+ spec.sortColumnNames.head,
+ spec.sortColumnNames.tail: _*
+ )
+ } else {
+ writer
+ }
+ }.getOrElse(writer)
+ }
+
+ withBucket(df1.repartition(numPartitionsLeft).write.format("parquet"), bucketSpecLeft)
+ .saveAsTable("bucketed_table1")
+ withBucket(df2.repartition(numPartitionsRight).write.format("parquet"), bucketSpecRight)
+ .saveAsTable("bucketed_table2")
+
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0",
+ SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
+ val t1 = spark.table("bucketed_table1")
+ val t2 = spark.table("bucketed_table2")
+ val joined = t1.join(t2, joinCondition(t1, t2), joinType)
+
+ // First, check that the join result is correct.
+ checkAnswer(
+ joined.sort("bucketed_table1.k", "bucketed_table2.k"),
+ df1.join(df2, joinCondition(df1, df2), joinType).sort("df1.k", "df2.k"))
+
+ assert(joined.queryExecution.executedPlan.isInstanceOf[SortMergeJoinExec])
+ val joinOperator = joined.queryExecution.executedPlan.asInstanceOf[SortMergeJoinExec]
+
+ // check existence of shuffle
+ assert(
+ joinOperator.left.find(_.isInstanceOf[ShuffleExchange]).isDefined == shuffleLeft,
+ s"expected shuffle in plan to be $shuffleLeft but found\n${joinOperator.left}")
+ assert(
+ joinOperator.right.find(_.isInstanceOf[ShuffleExchange]).isDefined == shuffleRight,
+ s"expected shuffle in plan to be $shuffleRight but found\n${joinOperator.right}")
+
+ // check existence of sort
+ assert(
+ joinOperator.left.find(_.isInstanceOf[SortExec]).isDefined == sortLeft,
+ s"expected sort in the left child to be $sortLeft but found\n${joinOperator.left}")
+ assert(
+ joinOperator.right.find(_.isInstanceOf[SortExec]).isDefined == sortRight,
+ s"expected sort in the right child to be $sortRight but found\n${joinOperator.right}")
+ }
+ }
+ }
+
+ private def joinCondition(joinCols: Seq[String])(left: DataFrame, right: DataFrame): Column = {
+ joinCols.map(col => left(col) === right(col)).reduce(_ && _)
+ }
+
+ test("avoid shuffle when join 2 bucketed tables") {
+ val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
+ val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpec, expectedShuffle = false)
+ val bucketedTableTestSpecRight = BucketedTableTestSpec(bucketSpec, expectedShuffle = false)
+ testBucketing(
+ bucketedTableTestSpecLeft = bucketedTableTestSpecLeft,
+ bucketedTableTestSpecRight = bucketedTableTestSpecRight,
+ joinCondition = joinCondition(Seq("i", "j"))
+ )
+ }
+
+ // Enable this test after fixing https://issues.apache.org/jira/browse/SPARK-12704
+ ignore("avoid shuffle when join keys are a super-set of bucket keys") {
+ val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
+ val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpec, expectedShuffle = false)
+ val bucketedTableTestSpecRight = BucketedTableTestSpec(bucketSpec, expectedShuffle = false)
+ testBucketing(
+ bucketedTableTestSpecLeft = bucketedTableTestSpecLeft,
+ bucketedTableTestSpecRight = bucketedTableTestSpecRight,
+ joinCondition = joinCondition(Seq("i", "j"))
+ )
+ }
+
+ test("only shuffle one side when join bucketed table and non-bucketed table") {
+ val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
+ val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpec, expectedShuffle = false)
+ val bucketedTableTestSpecRight = BucketedTableTestSpec(None, expectedShuffle = true)
+ testBucketing(
+ bucketedTableTestSpecLeft = bucketedTableTestSpecLeft,
+ bucketedTableTestSpecRight = bucketedTableTestSpecRight,
+ joinCondition = joinCondition(Seq("i", "j"))
+ )
+ }
+
+ test("only shuffle one side when 2 bucketed tables have different bucket number") {
+ val bucketSpecLeft = Some(BucketSpec(8, Seq("i", "j"), Nil))
+ val bucketSpecRight = Some(BucketSpec(5, Seq("i", "j"), Nil))
+ val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpecLeft, expectedShuffle = false)
+ val bucketedTableTestSpecRight = BucketedTableTestSpec(bucketSpecRight, expectedShuffle = true)
+ testBucketing(
+ bucketedTableTestSpecLeft = bucketedTableTestSpecLeft,
+ bucketedTableTestSpecRight = bucketedTableTestSpecRight,
+ joinCondition = joinCondition(Seq("i", "j"))
+ )
+ }
+
+ test("only shuffle one side when 2 bucketed tables have different bucket keys") {
+ val bucketSpecLeft = Some(BucketSpec(8, Seq("i"), Nil))
+ val bucketSpecRight = Some(BucketSpec(8, Seq("j"), Nil))
+ val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpecLeft, expectedShuffle = false)
+ val bucketedTableTestSpecRight = BucketedTableTestSpec(bucketSpecRight, expectedShuffle = true)
+ testBucketing(
+ bucketedTableTestSpecLeft = bucketedTableTestSpecLeft,
+ bucketedTableTestSpecRight = bucketedTableTestSpecRight,
+ joinCondition = joinCondition(Seq("i"))
+ )
+ }
+
+ test("shuffle when join keys are not equal to bucket keys") {
+ val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
+ val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpec, expectedShuffle = true)
+ val bucketedTableTestSpecRight = BucketedTableTestSpec(bucketSpec, expectedShuffle = true)
+ testBucketing(
+ bucketedTableTestSpecLeft = bucketedTableTestSpecLeft,
+ bucketedTableTestSpecRight = bucketedTableTestSpecRight,
+ joinCondition = joinCondition(Seq("j"))
+ )
+ }
+
+ test("shuffle when join 2 bucketed tables with bucketing disabled") {
+ val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
+ val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpec, expectedShuffle = true)
+ val bucketedTableTestSpecRight = BucketedTableTestSpec(bucketSpec, expectedShuffle = true)
+ withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") {
+ testBucketing(
+ bucketedTableTestSpecLeft = bucketedTableTestSpecLeft,
+ bucketedTableTestSpecRight = bucketedTableTestSpecRight,
+ joinCondition = joinCondition(Seq("i", "j"))
+ )
+ }
+ }
+
+ test("check sort and shuffle when bucket and sort columns are join keys") {
+ // With bucketing, it is possible to have multiple files belonging to the
+ // same bucket in a given relation. Each of these files is locally sorted,
+ // but the files combined together are not globally sorted. Given that,
+ // the RDD partition will not be sorted even if the relation has sort columns set.
+ // Therefore, we still need to keep the Sort on both sides.
+ val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j")))
+
+ val bucketedTableTestSpecLeft1 = BucketedTableTestSpec(
+ bucketSpec, numPartitions = 50, expectedShuffle = false, expectedSort = true)
+ val bucketedTableTestSpecRight1 = BucketedTableTestSpec(
+ bucketSpec, numPartitions = 1, expectedShuffle = false, expectedSort = false)
+ testBucketing(
+ bucketedTableTestSpecLeft = bucketedTableTestSpecLeft1,
+ bucketedTableTestSpecRight = bucketedTableTestSpecRight1,
+ joinCondition = joinCondition(Seq("i", "j"))
+ )
+
+ val bucketedTableTestSpecLeft2 = BucketedTableTestSpec(
+ bucketSpec, numPartitions = 1, expectedShuffle = false, expectedSort = false)
+ val bucketedTableTestSpecRight2 = BucketedTableTestSpec(
+ bucketSpec, numPartitions = 50, expectedShuffle = false, expectedSort = true)
+ testBucketing(
+ bucketedTableTestSpecLeft = bucketedTableTestSpecLeft2,
+ bucketedTableTestSpecRight = bucketedTableTestSpecRight2,
+ joinCondition = joinCondition(Seq("i", "j"))
+ )
+
+ val bucketedTableTestSpecLeft3 = BucketedTableTestSpec(
+ bucketSpec, numPartitions = 50, expectedShuffle = false, expectedSort = true)
+ val bucketedTableTestSpecRight3 = BucketedTableTestSpec(
+ bucketSpec, numPartitions = 50, expectedShuffle = false, expectedSort = true)
+ testBucketing(
+ bucketedTableTestSpecLeft = bucketedTableTestSpecLeft3,
+ bucketedTableTestSpecRight = bucketedTableTestSpecRight3,
+ joinCondition = joinCondition(Seq("i", "j"))
+ )
+
+ val bucketedTableTestSpecLeft4 = BucketedTableTestSpec(
+ bucketSpec, numPartitions = 1, expectedShuffle = false, expectedSort = false)
+ val bucketedTableTestSpecRight4 = BucketedTableTestSpec(
+ bucketSpec, numPartitions = 1, expectedShuffle = false, expectedSort = false)
+ testBucketing(
+ bucketedTableTestSpecLeft = bucketedTableTestSpecLeft4,
+ bucketedTableTestSpecRight = bucketedTableTestSpecRight4,
+ joinCondition = joinCondition(Seq("i", "j"))
+ )
+ }
+
+ test("avoid shuffle and sort when sort columns are a super set of join keys") {
+ val bucketSpecLeft = Some(BucketSpec(8, Seq("i"), Seq("i", "j")))
+ val bucketSpecRight = Some(BucketSpec(8, Seq("i"), Seq("i", "k")))
+ val bucketedTableTestSpecLeft = BucketedTableTestSpec(
+ bucketSpecLeft, numPartitions = 1, expectedShuffle = false, expectedSort = false)
+ val bucketedTableTestSpecRight = BucketedTableTestSpec(
+ bucketSpecRight, numPartitions = 1, expectedShuffle = false, expectedSort = false)
+ testBucketing(
+ bucketedTableTestSpecLeft = bucketedTableTestSpecLeft,
+ bucketedTableTestSpecRight = bucketedTableTestSpecRight,
+ joinCondition = joinCondition(Seq("i"))
+ )
+ }
+
+ test("only sort one side when sort columns are different") {
+ val bucketSpecLeft = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j")))
+ val bucketSpecRight = Some(BucketSpec(8, Seq("i", "j"), Seq("k")))
+ val bucketedTableTestSpecLeft = BucketedTableTestSpec(
+ bucketSpecLeft, numPartitions = 1, expectedShuffle = false, expectedSort = false)
+ val bucketedTableTestSpecRight = BucketedTableTestSpec(
+ bucketSpecRight, numPartitions = 1, expectedShuffle = false, expectedSort = true)
+ testBucketing(
+ bucketedTableTestSpecLeft = bucketedTableTestSpecLeft,
+ bucketedTableTestSpecRight = bucketedTableTestSpecRight,
+ joinCondition = joinCondition(Seq("i", "j"))
+ )
+ }
+
+ test("only sort one side when sort columns are same but their ordering is different") {
+ val bucketSpecLeft = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j")))
+ val bucketSpecRight = Some(BucketSpec(8, Seq("i", "j"), Seq("j", "i")))
+ val bucketedTableTestSpecLeft = BucketedTableTestSpec(
+ bucketSpecLeft, numPartitions = 1, expectedShuffle = false, expectedSort = false)
+ val bucketedTableTestSpecRight = BucketedTableTestSpec(
+ bucketSpecRight, numPartitions = 1, expectedShuffle = false, expectedSort = true)
+ testBucketing(
+ bucketedTableTestSpecLeft = bucketedTableTestSpecLeft,
+ bucketedTableTestSpecRight = bucketedTableTestSpecRight,
+ joinCondition = joinCondition(Seq("i", "j"))
+ )
+ }
+
+ test("avoid shuffle when grouping keys are equal to bucket keys") {
+ withTable("bucketed_table") {
+ df1.write.format("parquet").bucketBy(8, "i", "j").saveAsTable("bucketed_table")
+ val tbl = spark.table("bucketed_table")
+ val agged = tbl.groupBy("i", "j").agg(max("k"))
+
+ checkAnswer(
+ agged.sort("i", "j"),
+ df1.groupBy("i", "j").agg(max("k")).sort("i", "j"))
+
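+ // The grouping keys match the bucket keys, so the plan should contain no exchange.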
+ assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isEmpty)
+ }
+ }
+
+ test("avoid shuffle when grouping keys are a super-set of bucket keys") {
+ withTable("bucketed_table") {
+ df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
+ val tbl = spark.table("bucketed_table")
+ val agged = tbl.groupBy("i", "j").agg(max("k"))
+
+ checkAnswer(
+ agged.sort("i", "j"),
+ df1.groupBy("i", "j").agg(max("k")).sort("i", "j"))
+
+ assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isEmpty)
+ }
+ }
+
+ test("SPARK-17698 Join predicates should not contain filter clauses") {
+ val bucketSpec = Some(BucketSpec(8, Seq("i"), Seq("i")))
+ val bucketedTableTestSpecLeft = BucketedTableTestSpec(
+ bucketSpec, numPartitions = 1, expectedShuffle = false, expectedSort = false)
+ val bucketedTableTestSpecRight = BucketedTableTestSpec(
+ bucketSpec, numPartitions = 1, expectedShuffle = false, expectedSort = false)
+ testBucketing(
+ bucketedTableTestSpecLeft = bucketedTableTestSpecLeft,
+ bucketedTableTestSpecRight = bucketedTableTestSpecRight,
+ joinType = "fullouter",
+ joinCondition = (left: DataFrame, right: DataFrame) => {
+ val joinPredicates = Seq("i").map(col => left(col) === right(col)).reduce(_ && _)
+ val filterLeft = left("i") === Literal("1")
+ val filterRight = right("i") === Literal("1")
+ joinPredicates && filterLeft && filterRight
+ }
+ )
+ }
+
+ test("error if there exists any malformed bucket files") {
+ withTable("bucketed_table") {
+ df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
+ val warehouseFilePath = new URI(spark.sessionState.conf.warehousePath).getPath
+ val tableDir = new File(warehouseFilePath, "bucketed_table")
+ Utils.deleteRecursively(tableDir)
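+ // Overwrite the table directory with plain (non-bucketed) parquet files to simulate
+ // malformed bucket files.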
+ df1.write.parquet(tableDir.getAbsolutePath)
+
+ val agged = spark.table("bucketed_table").groupBy("i").count()
+ val error = intercept[Exception] {
+ agged.count()
+ }
+
+ assert(error.getCause().toString contains "Invalid bucket file")
+ }
+ }
+
+ test("disable bucketing when the output doesn't contain all bucketing columns") {
+ withTable("bucketed_table") {
+ df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
+
+ checkAnswer(spark.table("bucketed_table").select("j"), df1.select("j"))
+
+ checkAnswer(spark.table("bucketed_table").groupBy("j").agg(max("k")),
+ df1.groupBy("j").agg(max("k")))
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
new file mode 100644
index 0000000000..9082261af7
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources
+
+import java.io.File
+import java.net.URI
+
+import org.apache.spark.sql.{AnalysisException, QueryTest}
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.execution.datasources.BucketingUtils
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
+import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
+
+class BucketedWriteWithoutHiveSupportSuite extends BucketedWriteSuite with SharedSQLContext {
+ protected override def beforeAll(): Unit = {
+ super.beforeAll()
+ assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory")
+ }
+
+ override protected def fileFormatsToTest: Seq[String] = Seq("parquet", "json")
+}
+
+abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils {
+ import testImplicits._
+
+ protected def fileFormatsToTest: Seq[String]
+
+ test("bucketed by non-existing column") {
+ val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j")
+ intercept[AnalysisException](df.write.bucketBy(2, "k").saveAsTable("tt"))
+ }
+
+ test("numBuckets be greater than 0 but less than 100000") {
+ val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j")
+
+ Seq(-1, 0, 100000).foreach(numBuckets => {
+ val e = intercept[AnalysisException](df.write.bucketBy(numBuckets, "i").saveAsTable("tt"))
+ assert(
+ e.getMessage.contains("Number of buckets should be greater than 0 but less than 100000"))
+ })
+ }
+
+ test("specify sorting columns without bucketing columns") {
+ val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j")
+ intercept[IllegalArgumentException](df.write.sortBy("j").saveAsTable("tt"))
+ }
+
+ test("sorting by non-orderable column") {
+ val df = Seq("a" -> Map(1 -> 1), "b" -> Map(2 -> 2)).toDF("i", "j")
+ intercept[AnalysisException](df.write.bucketBy(2, "i").sortBy("j").saveAsTable("tt"))
+ }
+
+ test("write bucketed data using save()") {
+ val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j")
+
+ val e = intercept[AnalysisException] {
+ df.write.bucketBy(2, "i").parquet("/tmp/path")
+ }
+ assert(e.getMessage == "'save' does not support bucketing right now;")
+ }
+
+ test("write bucketed data using insertInto()") {
+ val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j")
+
+ val e = intercept[AnalysisException] {
+ df.write.bucketBy(2, "i").insertInto("tt")
+ }
+ assert(e.getMessage == "'insertInto' does not support bucketing right now;")
+ }
+
+ private lazy val df = {
+ (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
+ }
+
+ def tableDir: File = {
+ val identifier = spark.sessionState.sqlParser.parseTableIdentifier("bucketed_table")
+ new File(URI.create(spark.sessionState.catalog.defaultTablePath(identifier)))
+ }
+
+ /**
+ * A helper method to check bucketed writes at a low level, i.e. it inspects the written bucket
+ * files to verify the data are correct. Callers pass in the data directory that the bucket
+ * files are written to, the data format (parquet, json, etc.), and the bucketing
+ * information.
+ */
+ private def testBucketing(
+ dataDir: File,
+ source: String,
+ numBuckets: Int,
+ bucketCols: Seq[String],
+ sortCols: Seq[String] = Nil): Unit = {
+ val allBucketFiles = dataDir.listFiles().filterNot(f =>
+ f.getName.startsWith(".") || f.getName.startsWith("_")
+ )
+
+ for (bucketFile <- allBucketFiles) {
+ val bucketId = BucketingUtils.getBucketId(bucketFile.getName).getOrElse {
+ fail(s"Unable to find the related bucket files.")
+ }
+
+ // Remove duplicate columns from bucketCols and sortCols;
+ // otherwise we would get analysis errors due to duplicate column names.
+ val selectedColumns = (bucketCols ++ sortCols).distinct
+ // We may lose type information after the write (e.g. the json format doesn't keep schema
+ // information), so we get the types from the original dataframe.
+ val types = df.select(selectedColumns.map(col): _*).schema.map(_.dataType)
+ val columns = selectedColumns.zip(types).map {
+ case (colName, dt) => col(colName).cast(dt)
+ }
+
+ // Read the bucket file into a dataframe, so that it's easier to test.
+ val readBack = spark.read.format(source)
+ .load(bucketFile.getAbsolutePath)
+ .select(columns: _*)
+
+ // If sort columns were specified when writing the bucketed table, make sure the data in
+ // this bucket file is already sorted.
+ if (sortCols.nonEmpty) {
+ checkAnswer(readBack.sort(sortCols.map(col): _*), readBack.collect())
+ }
+
+ // Go through all rows in this bucket file, compute the bucket id from the bucket column
+ // values, and make sure it equals the expected bucket id inferred from the file name.
+ val qe = readBack.select(bucketCols.map(col): _*).queryExecution
+ val rows = qe.toRdd.map(_.copy()).collect()
+ val getBucketId = UnsafeProjection.create(
+ HashPartitioning(qe.analyzed.output, numBuckets).partitionIdExpression :: Nil,
+ qe.analyzed.output)
+
+ for (row <- rows) {
+ val actualBucketId = getBucketId(row).getInt(0)
+ assert(actualBucketId == bucketId)
+ }
+ }
+ }
+
+ test("write bucketed data") {
+ for (source <- fileFormatsToTest) {
+ withTable("bucketed_table") {
+ df.write
+ .format(source)
+ .partitionBy("i")
+ .bucketBy(8, "j", "k")
+ .saveAsTable("bucketed_table")
+
+ for (i <- 0 until 5) {
+ testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j", "k"))
+ }
+ }
+ }
+ }
+
+ test("write bucketed data with sortBy") {
+ for (source <- fileFormatsToTest) {
+ withTable("bucketed_table") {
+ df.write
+ .format(source)
+ .partitionBy("i")
+ .bucketBy(8, "j")
+ .sortBy("k")
+ .saveAsTable("bucketed_table")
+
+ for (i <- 0 until 5) {
+ testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j"), Seq("k"))
+ }
+ }
+ }
+ }
+
+ test("write bucketed data with the overlapping bucketBy/sortBy and partitionBy columns") {
+ val e1 = intercept[AnalysisException](df.write
+ .partitionBy("i", "j")
+ .bucketBy(8, "j", "k")
+ .sortBy("k")
+ .saveAsTable("bucketed_table"))
+ assert(e1.message.contains("bucketing column 'j' should not be part of partition columns"))
+
+ val e2 = intercept[AnalysisException](df.write
+ .partitionBy("i", "j")
+ .bucketBy(8, "k")
+ .sortBy("i")
+ .saveAsTable("bucketed_table"))
+ assert(e2.message.contains("bucket sorting column 'i' should not be part of partition columns"))
+ }
+
+ test("write bucketed data without partitionBy") {
+ for (source <- fileFormatsToTest) {
+ withTable("bucketed_table") {
+ df.write
+ .format(source)
+ .bucketBy(8, "i", "j")
+ .saveAsTable("bucketed_table")
+
+ testBucketing(tableDir, source, 8, Seq("i", "j"))
+ }
+ }
+ }
+
+ test("write bucketed data without partitionBy with sortBy") {
+ for (source <- fileFormatsToTest) {
+ withTable("bucketed_table") {
+ df.write
+ .format(source)
+ .bucketBy(8, "i", "j")
+ .sortBy("k")
+ .saveAsTable("bucketed_table")
+
+ testBucketing(tableDir, source, 8, Seq("i", "j"), Seq("k"))
+ }
+ }
+ }
+
+ test("write bucketed data with bucketing disabled") {
+ // The BUCKETING_ENABLED configuration does not affect the write path
+ withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") {
+ for (source <- fileFormatsToTest) {
+ withTable("bucketed_table") {
+ df.write
+ .format(source)
+ .partitionBy("i")
+ .bucketBy(8, "j", "k")
+ .saveAsTable("bucketed_table")
+
+ for (i <- 0 until 5) {
+ testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j", "k"))
+ }
+ }
+ }
+ }
+ }
+}