author    Wenchen Fan <wenchen@databricks.com>    2016-01-18 15:10:04 -0800
committer Reynold Xin <rxin@databricks.com>    2016-01-18 15:10:04 -0800
commit    404190221a788ebc3a0cbf5cb47cf532436ce965 (patch)
tree      ad7c211818a75098ae7ca31a5fb985ae7f9483f1
parent    4f11e3f2aa4f097ed66296fe72b5b5384924010c (diff)
[SPARK-12882][SQL] simplify bucket tests and add more comments
Right now the bucket tests are somewhat hard to understand; this PR simplifies them and adds more comments.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #10813 from cloud-fan/bucket-comment.
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala   | 56
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala  | 68
2 files changed, 78 insertions(+), 46 deletions(-)
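At its core, the refactoring replaces writer-transforming closures with a plain data description of the bucketing. Both forms below are lifted from the diff that follows; `None` in the new style means "not bucketed":

    // Before: bucketing passed as a DataFrameWriter => DataFrameWriter closure.
    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
    testBucketing(bucketing, identity, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)

    // After: bucketing passed as data. The last BucketSpec field (sort columns)
    // is Nil because these read tests never sort within buckets.
    val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
    testBucketing(bucketSpec, None, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)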
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 58ecdd3b80..150d0c7486 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.{Column, DataFrame, DataFrameWriter, QueryTest, SQLC
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.execution.Exchange
+import org.apache.spark.sql.execution.datasources.BucketSpec
import org.apache.spark.sql.execution.joins.SortMergeJoin
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHiveSingleton
@@ -61,15 +62,30 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
private val df1 = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1")
private val df2 = (0 until 50).map(i => (i % 7, i % 11, i.toString)).toDF("i", "j", "k").as("df2")
+ /**
+ * A helper method to test the bucketed read functionality using a join. It saves `df1` and
+ * `df2` to Hive tables, bucketed or not, according to the given bucket specs. It then joins
+ * these two tables, first making sure the answer is correct, and then checking whether
+ * shuffles exist as expected according to `shuffleLeft` and `shuffleRight`.
+ */
private def testBucketing(
- bucketing1: DataFrameWriter => DataFrameWriter,
- bucketing2: DataFrameWriter => DataFrameWriter,
+ bucketSpecLeft: Option[BucketSpec],
+ bucketSpecRight: Option[BucketSpec],
joinColumns: Seq[String],
shuffleLeft: Boolean,
shuffleRight: Boolean): Unit = {
withTable("bucketed_table1", "bucketed_table2") {
- bucketing1(df1.write.format("parquet")).saveAsTable("bucketed_table1")
- bucketing2(df2.write.format("parquet")).saveAsTable("bucketed_table2")
+ def withBucket(writer: DataFrameWriter, bucketSpec: Option[BucketSpec]): DataFrameWriter = {
+ bucketSpec.map { spec =>
+ writer.bucketBy(
+ spec.numBuckets,
+ spec.bucketColumnNames.head,
+ spec.bucketColumnNames.tail: _*)
+ }.getOrElse(writer)
+ }
+
+ withBucket(df1.write.format("parquet"), bucketSpecLeft).saveAsTable("bucketed_table1")
+ withBucket(df2.write.format("parquet"), bucketSpecRight).saveAsTable("bucketed_table2")
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
val t1 = hiveContext.table("bucketed_table1")
@@ -95,42 +111,42 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
}
test("avoid shuffle when join 2 bucketed tables") {
- val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
- testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+ val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
+ testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
}
// Enable it after fix https://issues.apache.org/jira/browse/SPARK-12704
ignore("avoid shuffle when join keys are a super-set of bucket keys") {
- val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
- testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+ val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
+ testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
}
test("only shuffle one side when join bucketed table and non-bucketed table") {
- val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
- testBucketing(bucketing, identity, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+ val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
+ testBucketing(bucketSpec, None, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
}
test("only shuffle one side when 2 bucketed tables have different bucket number") {
- val bucketing1 = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
- val bucketing2 = (writer: DataFrameWriter) => writer.bucketBy(5, "i", "j")
- testBucketing(bucketing1, bucketing2, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+ val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Nil))
+ val bucketSpec2 = Some(BucketSpec(5, Seq("i", "j"), Nil))
+ testBucketing(bucketSpec1, bucketSpec2, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
}
test("only shuffle one side when 2 bucketed tables have different bucket keys") {
- val bucketing1 = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
- val bucketing2 = (writer: DataFrameWriter) => writer.bucketBy(8, "j")
- testBucketing(bucketing1, bucketing2, Seq("i"), shuffleLeft = false, shuffleRight = true)
+ val bucketSpec1 = Some(BucketSpec(8, Seq("i"), Nil))
+ val bucketSpec2 = Some(BucketSpec(8, Seq("j"), Nil))
+ testBucketing(bucketSpec1, bucketSpec2, Seq("i"), shuffleLeft = false, shuffleRight = true)
}
test("shuffle when join keys are not equal to bucket keys") {
- val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
- testBucketing(bucketing, bucketing, Seq("j"), shuffleLeft = true, shuffleRight = true)
+ val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
+ testBucketing(bucketSpec, bucketSpec, Seq("j"), shuffleLeft = true, shuffleRight = true)
}
test("shuffle when join 2 bucketed tables with bucketing disabled") {
- val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
+ val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") {
- testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = true, shuffleRight = true)
+ testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = true, shuffleRight = true)
}
}
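The assertions that consume `shuffleLeft` and `shuffleRight` live in the unchanged part of the helper, outside these hunks. Given the imports of `Exchange` and `SortMergeJoin` and the broadcast threshold of 0, they plausibly look like the sketch below; `joined` and the join condition are assumptions, not code from this commit:

    // Hedged sketch of the shuffle check: with autoBroadcastJoinThreshold = 0
    // the join must be a SortMergeJoin, and an Exchange node on a side of the
    // join means that side was shuffled.
    val joined = t1.join(t2, joinCondition)  // joinCondition: assumed, built from joinColumns
    val smj = joined.queryExecution.executedPlan.collect { case j: SortMergeJoin => j }.head
    assert(smj.left.find(_.isInstanceOf[Exchange]).isDefined == shuffleLeft)
    assert(smj.right.find(_.isInstanceOf[Exchange]).isDefined == shuffleRight)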
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
index e812439bed..dad1fc1273 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
@@ -65,39 +65,55 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
+ /**
+ * A helper method to check the bucket write functionality at a low level, i.e. it inspects
+ * the written bucket files to verify that the data is correct. Callers pass in the data dir
+ * that the bucket files were written to, the format of the data (parquet, json, etc.), and
+ * the bucketing information.
+ */
private def testBucketing(
dataDir: File,
source: String,
+ numBuckets: Int,
bucketCols: Seq[String],
sortCols: Seq[String] = Nil): Unit = {
val allBucketFiles = dataDir.listFiles().filterNot(f =>
f.getName.startsWith(".") || f.getName.startsWith("_")
)
- val groupedBucketFiles = allBucketFiles.groupBy(f => BucketingUtils.getBucketId(f.getName).get)
- assert(groupedBucketFiles.size <= 8)
-
- for ((bucketId, bucketFiles) <- groupedBucketFiles) {
- for (bucketFilePath <- bucketFiles.map(_.getAbsolutePath)) {
- val types = df.select((bucketCols ++ sortCols).map(col): _*).schema.map(_.dataType)
- val columns = (bucketCols ++ sortCols).zip(types).map {
- case (colName, dt) => col(colName).cast(dt)
- }
- val readBack = sqlContext.read.format(source).load(bucketFilePath).select(columns: _*)
- if (sortCols.nonEmpty) {
- checkAnswer(readBack.sort(sortCols.map(col): _*), readBack.collect())
- }
+ for (bucketFile <- allBucketFiles) {
+ val bucketId = BucketingUtils.getBucketId(bucketFile.getName).get
+ assert(bucketId >= 0 && bucketId < numBuckets)
- val qe = readBack.select(bucketCols.map(col): _*).queryExecution
- val rows = qe.toRdd.map(_.copy()).collect()
- val getBucketId = UnsafeProjection.create(
- HashPartitioning(qe.analyzed.output, 8).partitionIdExpression :: Nil,
- qe.analyzed.output)
+ // We may lose the type information after writing (e.g. the json format doesn't keep schema
+ // information), so here we get the types from the original dataframe.
+ val types = df.select((bucketCols ++ sortCols).map(col): _*).schema.map(_.dataType)
+ val columns = (bucketCols ++ sortCols).zip(types).map {
+ case (colName, dt) => col(colName).cast(dt)
+ }
- for (row <- rows) {
- val actualBucketId = getBucketId(row).getInt(0)
- assert(actualBucketId == bucketId)
- }
+ // Read the bucket file into a dataframe, so that it's easier to test.
+ val readBack = sqlContext.read.format(source)
+ .load(bucketFile.getAbsolutePath)
+ .select(columns: _*)
+
+ // If we specified sort columns while writing the bucketed table, make sure the data in this
+ // bucket file is already sorted.
+ if (sortCols.nonEmpty) {
+ checkAnswer(readBack.sort(sortCols.map(col): _*), readBack.collect())
+ }
+
+ // Go through all rows in this bucket file, calculate the bucket id from the bucket column
+ // values, and make sure it equals the expected bucket id inferred from the file name.
+ val qe = readBack.select(bucketCols.map(col): _*).queryExecution
+ val rows = qe.toRdd.map(_.copy()).collect()
+ val getBucketId = UnsafeProjection.create(
+ HashPartitioning(qe.analyzed.output, numBuckets).partitionIdExpression :: Nil,
+ qe.analyzed.output)
+
+ for (row <- rows) {
+ val actualBucketId = getBucketId(row).getInt(0)
+ assert(actualBucketId == bucketId)
}
}
}
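For context on the projection above: `HashPartitioning.partitionIdExpression` reduces to a positive modulo (pmod) of a hash over the partitioning expressions, so `getBucketId` recomputes, per row, the bucket id the writer should have assigned. A minimal self-contained sketch of the pmod rule, assuming the hash value is already in hand:

    // pmod keeps the bucket id in [0, numBuckets), even for negative hashes,
    // which is what the `bucketId >= 0 && bucketId < numBuckets` assert relies on.
    def bucketId(hashOfBucketCols: Int, numBuckets: Int): Int = {
      val mod = hashOfBucketCols % numBuckets
      if (mod < 0) mod + numBuckets else mod
    }

    assert(bucketId(-13, 8) == 3)  // a plain % would give -5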
@@ -113,7 +129,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
for (i <- 0 until 5) {
- testBucketing(new File(tableDir, s"i=$i"), source, Seq("j", "k"))
+ testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j", "k"))
}
}
}
@@ -131,7 +147,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
for (i <- 0 until 5) {
- testBucketing(new File(tableDir, s"i=$i"), source, Seq("j"), Seq("k"))
+ testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j"), Seq("k"))
}
}
}
@@ -146,7 +162,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
.saveAsTable("bucketed_table")
val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
- testBucketing(tableDir, source, Seq("i", "j"))
+ testBucketing(tableDir, source, 8, Seq("i", "j"))
}
}
}
@@ -161,7 +177,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
.saveAsTable("bucketed_table")
val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
- testBucketing(tableDir, source, Seq("i", "j"), Seq("k"))
+ testBucketing(tableDir, source, 8, Seq("i", "j"), Seq("k"))
}
}
}
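The two partitioned tests above walk per-partition directories i=0 through i=4, so the table under test is written with both partitioning and bucketing. The actual write calls sit in elided context, but presumably resemble this hedged sketch, using the suite's `df` and the 8-bucket layout used throughout:

    // Hypothetical write matching the first partitioned test (bucket columns j, k):
    df.write
      .format(source)
      .partitionBy("i")       // yields tableDir/i=0 ... tableDir/i=4
      .bucketBy(8, "j", "k")  // 8 bucket files inside each partition directory
      .saveAsTable("bucketed_table")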