| author | Wenchen Fan <wenchen@databricks.com> | 2016-01-18 15:10:04 -0800 |
|---|---|---|
| committer | Reynold Xin <rxin@databricks.com> | 2016-01-18 15:10:04 -0800 |
| commit | 404190221a788ebc3a0cbf5cb47cf532436ce965 (patch) | |
| tree | ad7c211818a75098ae7ca31a5fb985ae7f9483f1 /sql/hive | |
| parent | 4f11e3f2aa4f097ed66296fe72b5b5384924010c (diff) | |
[SPARK-12882][SQL] simplify bucket tests and add more comments
Right now, the bucket tests are kind of hard to understand. This PR simplifies them and adds more comments.
Author: Wenchen Fan <wenchen@databricks.com>
Closes #10813 from cloud-fan/bucket-comment.
Diffstat (limited to 'sql/hive')
| -rw-r--r-- | sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala | 56 |
| -rw-r--r-- | sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala | 68 |
2 files changed, 78 insertions, 46 deletions
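The central change in `BucketedReadSuite` replaces `DataFrameWriter => DataFrameWriter` closures with declarative `Option[BucketSpec]` arguments, where `None` means the table is written unbucketed. As a reading aid for the diff below, a minimal sketch of the new call pattern (identifiers taken from the patch itself):

```scala
// BucketSpec(numBuckets, bucketColumnNames, sortColumnNames); Nil means no sort columns.
val bucketSpecLeft = Some(BucketSpec(8, Seq("i", "j"), Nil))  // 8 buckets on (i, j)
val bucketSpecRight: Option[BucketSpec] = None                // right table is unbucketed

// Expect no shuffle on the bucketed side and a shuffle on the unbucketed side.
testBucketing(bucketSpecLeft, bucketSpecRight, Seq("i", "j"),
  shuffleLeft = false, shuffleRight = true)
```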
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 58ecdd3b80..150d0c7486 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.{Column, DataFrame, DataFrameWriter, QueryTest, SQLC
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.execution.Exchange
+import org.apache.spark.sql.execution.datasources.BucketSpec
 import org.apache.spark.sql.execution.joins.SortMergeJoin
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.hive.test.TestHiveSingleton
@@ -61,15 +62,30 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
   private val df1 = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1")
   private val df2 = (0 until 50).map(i => (i % 7, i % 11, i.toString)).toDF("i", "j", "k").as("df2")
 
+  /**
+   * A helper method to test the bucket read functionality using a join. It saves `df1` and `df2`
+   * to Hive tables, bucketed or not, according to the given bucket specs. It then joins these
+   * two tables, first making sure the join answer is correct, and then checking whether each
+   * side is shuffled as expected per `shuffleLeft` and `shuffleRight`.
+   */
   private def testBucketing(
-      bucketing1: DataFrameWriter => DataFrameWriter,
-      bucketing2: DataFrameWriter => DataFrameWriter,
+      bucketSpecLeft: Option[BucketSpec],
+      bucketSpecRight: Option[BucketSpec],
       joinColumns: Seq[String],
       shuffleLeft: Boolean,
       shuffleRight: Boolean): Unit = {
     withTable("bucketed_table1", "bucketed_table2") {
-      bucketing1(df1.write.format("parquet")).saveAsTable("bucketed_table1")
-      bucketing2(df2.write.format("parquet")).saveAsTable("bucketed_table2")
+      def withBucket(writer: DataFrameWriter, bucketSpec: Option[BucketSpec]): DataFrameWriter = {
+        bucketSpec.map { spec =>
+          writer.bucketBy(
+            spec.numBuckets,
+            spec.bucketColumnNames.head,
+            spec.bucketColumnNames.tail: _*)
+        }.getOrElse(writer)
+      }
+
+      withBucket(df1.write.format("parquet"), bucketSpecLeft).saveAsTable("bucketed_table1")
+      withBucket(df2.write.format("parquet"), bucketSpecRight).saveAsTable("bucketed_table2")
 
       withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
         val t1 = hiveContext.table("bucketed_table1")
@@ -95,42 +111,42 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
   }
 
   test("avoid shuffle when join 2 bucketed tables") {
-    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
-    testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+    val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
+    testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
   }
 
   // Enable it after fix https://issues.apache.org/jira/browse/SPARK-12704
   ignore("avoid shuffle when join keys are a super-set of bucket keys") {
-    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
-    testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+    val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
+    testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
   }
 
   test("only shuffle one side when join bucketed table and non-bucketed table") {
-    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
-    testBucketing(bucketing, identity, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+    val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
+    testBucketing(bucketSpec, None, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
   }
 
   test("only shuffle one side when 2 bucketed tables have different bucket number") {
-    val bucketing1 = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
-    val bucketing2 = (writer: DataFrameWriter) => writer.bucketBy(5, "i", "j")
-    testBucketing(bucketing1, bucketing2, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+    val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Nil))
+    val bucketSpec2 = Some(BucketSpec(5, Seq("i", "j"), Nil))
+    testBucketing(bucketSpec1, bucketSpec2, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
   }
 
   test("only shuffle one side when 2 bucketed tables have different bucket keys") {
-    val bucketing1 = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
-    val bucketing2 = (writer: DataFrameWriter) => writer.bucketBy(8, "j")
-    testBucketing(bucketing1, bucketing2, Seq("i"), shuffleLeft = false, shuffleRight = true)
+    val bucketSpec1 = Some(BucketSpec(8, Seq("i"), Nil))
+    val bucketSpec2 = Some(BucketSpec(8, Seq("j"), Nil))
+    testBucketing(bucketSpec1, bucketSpec2, Seq("i"), shuffleLeft = false, shuffleRight = true)
   }
 
   test("shuffle when join keys are not equal to bucket keys") {
-    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
-    testBucketing(bucketing, bucketing, Seq("j"), shuffleLeft = true, shuffleRight = true)
+    val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
+    testBucketing(bucketSpec, bucketSpec, Seq("j"), shuffleLeft = true, shuffleRight = true)
   }
 
   test("shuffle when join 2 bucketed tables with bucketing disabled") {
-    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
+    val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
     withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") {
-      testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = true, shuffleRight = true)
+      testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = true, shuffleRight = true)
     }
   }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
index e812439bed..dad1fc1273 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
@@ -65,39 +65,55 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
   private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
 
+  /**
+   * A helper method to check the bucket write functionality at a low level, i.e. check the
+   * written bucket files to see if the data is correct. Callers pass in the data dir that the
+   * bucket files were written to, the format of the data (parquet, json, etc.), and the
+   * bucketing information.
+   */
   private def testBucketing(
       dataDir: File,
       source: String,
+      numBuckets: Int,
       bucketCols: Seq[String],
       sortCols: Seq[String] = Nil): Unit = {
     val allBucketFiles = dataDir.listFiles().filterNot(f =>
       f.getName.startsWith(".") || f.getName.startsWith("_")
     )
-    val groupedBucketFiles = allBucketFiles.groupBy(f => BucketingUtils.getBucketId(f.getName).get)
-    assert(groupedBucketFiles.size <= 8)
-
-    for ((bucketId, bucketFiles) <- groupedBucketFiles) {
-      for (bucketFilePath <- bucketFiles.map(_.getAbsolutePath)) {
-        val types = df.select((bucketCols ++ sortCols).map(col): _*).schema.map(_.dataType)
-        val columns = (bucketCols ++ sortCols).zip(types).map {
-          case (colName, dt) => col(colName).cast(dt)
-        }
-        val readBack = sqlContext.read.format(source).load(bucketFilePath).select(columns: _*)
 
-        if (sortCols.nonEmpty) {
-          checkAnswer(readBack.sort(sortCols.map(col): _*), readBack.collect())
-        }
+    for (bucketFile <- allBucketFiles) {
+      val bucketId = BucketingUtils.getBucketId(bucketFile.getName).get
+      assert(bucketId >= 0 && bucketId < numBuckets)
 
-        val qe = readBack.select(bucketCols.map(col): _*).queryExecution
-        val rows = qe.toRdd.map(_.copy()).collect()
-        val getBucketId = UnsafeProjection.create(
-          HashPartitioning(qe.analyzed.output, 8).partitionIdExpression :: Nil,
-          qe.analyzed.output)
+      // We may lose type information after the write (e.g. the json format doesn't keep schema
+      // information), so here we get the types from the original dataframe.
+      val types = df.select((bucketCols ++ sortCols).map(col): _*).schema.map(_.dataType)
+      val columns = (bucketCols ++ sortCols).zip(types).map {
+        case (colName, dt) => col(colName).cast(dt)
+      }
 
-        for (row <- rows) {
-          val actualBucketId = getBucketId(row).getInt(0)
-          assert(actualBucketId == bucketId)
-        }
+      // Read the bucket file into a dataframe, so that it's easier to test.
+      val readBack = sqlContext.read.format(source)
+        .load(bucketFile.getAbsolutePath)
+        .select(columns: _*)
+
+      // If we specified sort columns while writing the bucketed table, make sure the data in
+      // this bucket file is already sorted.
+      if (sortCols.nonEmpty) {
+        checkAnswer(readBack.sort(sortCols.map(col): _*), readBack.collect())
+      }
+
+      // Go through all rows in this bucket file, compute each bucket id from the bucket column
+      // values, and make sure it equals the expected bucket id inferred from the file name.
+      val qe = readBack.select(bucketCols.map(col): _*).queryExecution
+      val rows = qe.toRdd.map(_.copy()).collect()
+      val getBucketId = UnsafeProjection.create(
+        HashPartitioning(qe.analyzed.output, numBuckets).partitionIdExpression :: Nil,
+        qe.analyzed.output)
+
+      for (row <- rows) {
+        val actualBucketId = getBucketId(row).getInt(0)
+        assert(actualBucketId == bucketId)
       }
     }
   }
@@ -113,7 +129,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
       val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
       for (i <- 0 until 5) {
-        testBucketing(new File(tableDir, s"i=$i"), source, Seq("j", "k"))
+        testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j", "k"))
       }
     }
   }
@@ -131,7 +147,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
       val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
       for (i <- 0 until 5) {
-        testBucketing(new File(tableDir, s"i=$i"), source, Seq("j"), Seq("k"))
+        testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j"), Seq("k"))
       }
     }
   }
@@ -146,7 +162,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
         .saveAsTable("bucketed_table")
 
       val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
-      testBucketing(tableDir, source, Seq("i", "j"))
+      testBucketing(tableDir, source, 8, Seq("i", "j"))
     }
   }
@@ -161,7 +177,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
         .saveAsTable("bucketed_table")
 
       val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
-      testBucketing(tableDir, source, Seq("i", "j"), Seq("k"))
+      testBucketing(tableDir, source, 8, Seq("i", "j"), Seq("k"))
     }
   }
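The heart of the write-side check is recomputing each row's bucket id with the same `HashPartitioning.partitionIdExpression` that Spark evaluates at write time, then comparing it against the id parsed from the bucket file name. A condensed sketch of that verification, with `readBack`, `bucketCols`, `bucketId`, and `numBuckets` as in the helper above:

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning

// Build a projection that maps a row of bucket-column values to its bucket id.
val qe = readBack.select(bucketCols.map(col): _*).queryExecution
val getBucketId = UnsafeProjection.create(
  HashPartitioning(qe.analyzed.output, numBuckets).partitionIdExpression :: Nil,
  qe.analyzed.output)

// Every row in the file must hash back to the bucket id encoded in its file name.
for (row <- qe.toRdd.map(_.copy()).collect()) {
  assert(getBucketId(row).getInt(0) == bucketId)
}
```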