author | Wenchen Fan <wenchen@databricks.com> | 2016-01-15 17:20:01 -0800
committer | Reynold Xin <rxin@databricks.com> | 2016-01-15 17:20:01 -0800
commit | 3b5ccb12b8d33d99df0f206fecf00f51c2b88fdb (patch)
tree | 9be77df8147a125cadc46e9bc6da4641669b58da /sql/hive
parent | 8dbbf3e75e70e98391b4a1705472caddd129945a (diff)
[SPARK-12649][SQL] support reading bucketed table
This PR adds support for reading bucketed tables and correctly populates `outputPartitioning`, so that we can avoid shuffles in some cases.
TODO(follow-up PRs):
* bucket pruning
* avoid shuffle for bucketed table joins when the join keys are any super-set of the bucketing keys.
(we should re-visit it after https://issues.apache.org/jira/browse/SPARK-12704 is fixed)
* recognize Hive bucketed tables
Author: Wenchen Fan <wenchen@databricks.com>
Closes #10604 from cloud-fan/bucket-read.
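For orientation, here is a minimal sketch of the behavior this change enables, modeled on the tests added in `BucketedReadSuite` below. The table names, the sample data, and the in-scope `sqlContext` (with its implicits imported) are assumptions of the example, not part of the patch:

```scala
import sqlContext.implicits._  // assumed: a SQLContext/HiveContext named sqlContext is in scope

val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")

// Write two tables bucketed into 8 buckets by the same keys (hypothetical table names).
df.write.format("parquet").bucketBy(8, "i", "j").saveAsTable("t1")
df.write.format("parquet").bucketBy(8, "i", "j").saveAsTable("t2")

// With broadcast joins disabled, a join on the bucketing keys should now plan a
// sort-merge join whose children contain no Exchange, i.e. no shuffle on either side.
sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", "0")
val joined = sqlContext.table("t1").join(sqlContext.table("t2"), Seq("i", "j"))
joined.explain()  // inspect the physical plan for the absence of Exchange operators
```

The tests added below assert exactly this by checking that neither child of the `SortMergeJoin` contains an `Exchange` operator.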
Diffstat (limited to 'sql/hive')
5 files changed, 203 insertions, 26 deletions
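The metastore change below persists the bucketing information as plain table properties (`spark.sql.sources.schema.numBuckets`, `...bucketCol.$i`, `...sortCol.$i`) and rebuilds a `BucketSpec` when the table is loaded. Here is a minimal, self-contained sketch of that round trip, using an ordinary `Map[String, String]` in place of the Hive table properties; the `BucketSpec` case class and helper names are stand-ins for illustration, not the patch's actual code:

```scala
// Property layout used by the patch:
//   spark.sql.sources.schema.numBuckets  -> number of buckets
//   spark.sql.sources.schema.bucketCol.0 -> first bucketing column, and so on
//   spark.sql.sources.schema.sortCol.0   -> first sort column, and so on
case class BucketSpec(numBuckets: Int, bucketColumnNames: Seq[String], sortColumnNames: Seq[String])

def getColumnNames(props: Map[String, String], colType: String): Seq[String] = {
  props.get(s"spark.sql.sources.schema.num${colType.capitalize}Cols").map { numCols =>
    (0 until numCols.toInt).map { index =>
      props.getOrElse(
        s"spark.sql.sources.schema.${colType}Col.$index",
        sys.error(s"missing part $index of the $colType columns ($numCols parts expected)"))
    }
  }.getOrElse(Nil)
}

def readBucketSpec(props: Map[String, String]): Option[BucketSpec] =
  props.get("spark.sql.sources.schema.numBuckets").map { n =>
    BucketSpec(n.toInt, getColumnNames(props, "bucket"), getColumnNames(props, "sort"))
  }

// Example: a table bucketed into 8 buckets by (j, k), sorted by k within each bucket.
val props = Map(
  "spark.sql.sources.schema.numBuckets" -> "8",
  "spark.sql.sources.schema.numBucketCols" -> "2",
  "spark.sql.sources.schema.bucketCol.0" -> "j",
  "spark.sql.sources.schema.bucketCol.1" -> "k",
  "spark.sql.sources.schema.numSortCols" -> "1",
  "spark.sql.sources.schema.sortCol.0" -> "k")
println(readBucketSpec(props))  // Some(BucketSpec(8, Vector(j, k), Vector(k)))
```

A missing `bucketCol.N` entry fails fast, mirroring the `AnalysisException` the patch throws when the schema stored in the metastore is corrupted.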
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 3d54048c24..0cfe03ba91 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -143,19 +143,16 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
       }
     }

-    def partColsFromParts: Option[Seq[String]] = {
-      table.properties.get("spark.sql.sources.schema.numPartCols").map { numPartCols =>
-        (0 until numPartCols.toInt).map { index =>
-          val partCol = table.properties.get(s"spark.sql.sources.schema.partCol.$index").orNull
-          if (partCol == null) {
+    def getColumnNames(colType: String): Seq[String] = {
+      table.properties.get(s"spark.sql.sources.schema.num${colType.capitalize}Cols").map {
+        numCols => (0 until numCols.toInt).map { index =>
+          table.properties.get(s"spark.sql.sources.schema.${colType}Col.$index").getOrElse {
             throw new AnalysisException(
-              "Could not read partitioned columns from the metastore because it is corrupted " +
-                s"(missing part $index of the it, $numPartCols parts are expected).")
+              s"Could not read $colType columns from the metastore because it is corrupted " +
+                s"(missing part $index of it, $numCols parts are expected).")
           }
-
-          partCol
         }
-      }
+      }.getOrElse(Nil)
     }

     // Originally, we used spark.sql.sources.schema to store the schema of a data source table.
@@ -170,7 +167,11 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
     // We only need names at here since userSpecifiedSchema we loaded from the metastore
     // contains partition columns. We can always get datatypes of partitioning columns
     // from userSpecifiedSchema.
-    val partitionColumns = partColsFromParts.getOrElse(Nil)
+    val partitionColumns = getColumnNames("part")
+
+    val bucketSpec = table.properties.get("spark.sql.sources.schema.numBuckets").map { n =>
+      BucketSpec(n.toInt, getColumnNames("bucket"), getColumnNames("sort"))
+    }

     // It does not appear that the ql client for the metastore has a way to enumerate all the
     // SerDe properties directly...
@@ -181,6 +182,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
         hive,
         userSpecifiedSchema,
         partitionColumns.toArray,
+        bucketSpec,
         table.properties("spark.sql.sources.provider"),
         options)

@@ -282,7 +284,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
         val maybeSerDe = HiveSerDe.sourceToSerDe(provider, hive.hiveconf)
         val dataSource = ResolvedDataSource(
-          hive, userSpecifiedSchema, partitionColumns, provider, options)
+          hive, userSpecifiedSchema, partitionColumns, bucketSpec, provider, options)

         def newSparkSQLSpecificMetastoreTable(): HiveTable = {
           HiveTable(
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
index 07a352873d..e703ac0164 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
@@ -213,7 +213,12 @@ case class CreateMetastoreDataSourceAsSelect(
       case SaveMode.Append =>
         // Check if the specified data source match the data source of the existing table.
         val resolved = ResolvedDataSource(
-          sqlContext, Some(query.schema.asNullable), partitionColumns, provider, optionsWithPath)
+          sqlContext,
+          Some(query.schema.asNullable),
+          partitionColumns,
+          bucketSpec,
+          provider,
+          optionsWithPath)
         val createdRelation = LogicalRelation(resolved.relation)
         EliminateSubQueries(sqlContext.catalog.lookupRelation(tableIdent)) match {
           case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _) =>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
index 14fa152c23..40409169b0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
@@ -156,7 +156,7 @@ private[sql] class OrcRelation(
     maybeDataSchema: Option[StructType],
     maybePartitionSpec: Option[PartitionSpec],
     override val userDefinedPartitionColumns: Option[StructType],
-    override val bucketSpec: Option[BucketSpec],
+    override val maybeBucketSpec: Option[BucketSpec],
     parameters: Map[String, String])(
     @transient val sqlContext: SQLContext)
   extends HadoopFsRelation(maybePartitionSpec, parameters)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
new file mode 100644
index 0000000000..58ecdd3b80
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources
+
+import java.io.File
+
+import org.apache.spark.sql.{Column, DataFrame, DataFrameWriter, QueryTest, SQLConf}
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.execution.Exchange
+import org.apache.spark.sql.execution.joins.SortMergeJoin
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.util.Utils
+
+class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
+  import testImplicits._
+
+  test("read bucketed data") {
+    val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
+    withTable("bucketed_table") {
+      df.write
+        .format("parquet")
+        .partitionBy("i")
+        .bucketBy(8, "j", "k")
+        .saveAsTable("bucketed_table")
+
+      for (i <- 0 until 5) {
+        val rdd = hiveContext.table("bucketed_table").filter($"i" === i).queryExecution.toRdd
+        assert(rdd.partitions.length == 8)
+
+        val attrs = df.select("j", "k").schema.toAttributes
+        val checkBucketId = rdd.mapPartitionsWithIndex((index, rows) => {
+          val getBucketId = UnsafeProjection.create(
+            HashPartitioning(attrs, 8).partitionIdExpression :: Nil,
+            attrs)
+          rows.map(row => getBucketId(row).getInt(0) == index)
+        })
+
+        assert(checkBucketId.collect().reduce(_ && _))
+      }
+    }
+  }
+
+  private val df1 = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1")
+  private val df2 = (0 until 50).map(i => (i % 7, i % 11, i.toString)).toDF("i", "j", "k").as("df2")
+
+  private def testBucketing(
+      bucketing1: DataFrameWriter => DataFrameWriter,
+      bucketing2: DataFrameWriter => DataFrameWriter,
+      joinColumns: Seq[String],
+      shuffleLeft: Boolean,
+      shuffleRight: Boolean): Unit = {
+    withTable("bucketed_table1", "bucketed_table2") {
+      bucketing1(df1.write.format("parquet")).saveAsTable("bucketed_table1")
+      bucketing2(df2.write.format("parquet")).saveAsTable("bucketed_table2")
+
+      withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
+        val t1 = hiveContext.table("bucketed_table1")
+        val t2 = hiveContext.table("bucketed_table2")
+        val joined = t1.join(t2, joinCondition(t1, t2, joinColumns))
+
+        // First check the result is corrected.
+        checkAnswer(
+          joined.sort("bucketed_table1.k", "bucketed_table2.k"),
+          df1.join(df2, joinCondition(df1, df2, joinColumns)).sort("df1.k", "df2.k"))
+
+        assert(joined.queryExecution.executedPlan.isInstanceOf[SortMergeJoin])
+        val joinOperator = joined.queryExecution.executedPlan.asInstanceOf[SortMergeJoin]
+
+        assert(joinOperator.left.find(_.isInstanceOf[Exchange]).isDefined == shuffleLeft)
+        assert(joinOperator.right.find(_.isInstanceOf[Exchange]).isDefined == shuffleRight)
+      }
+    }
+  }
+
+  private def joinCondition(left: DataFrame, right: DataFrame, joinCols: Seq[String]): Column = {
+    joinCols.map(col => left(col) === right(col)).reduce(_ && _)
+  }
+
+  test("avoid shuffle when join 2 bucketed tables") {
+    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
+    testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+  }
+
+  // Enable it after fix https://issues.apache.org/jira/browse/SPARK-12704
+  ignore("avoid shuffle when join keys are a super-set of bucket keys") {
+    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
+    testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+  }
+
+  test("only shuffle one side when join bucketed table and non-bucketed table") {
+    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
+    testBucketing(bucketing, identity, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+  }
+
+  test("only shuffle one side when 2 bucketed tables have different bucket number") {
+    val bucketing1 = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
+    val bucketing2 = (writer: DataFrameWriter) => writer.bucketBy(5, "i", "j")
+    testBucketing(bucketing1, bucketing2, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+  }
+
+  test("only shuffle one side when 2 bucketed tables have different bucket keys") {
+    val bucketing1 = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
+    val bucketing2 = (writer: DataFrameWriter) => writer.bucketBy(8, "j")
+    testBucketing(bucketing1, bucketing2, Seq("i"), shuffleLeft = false, shuffleRight = true)
+  }
+
+  test("shuffle when join keys are not equal to bucket keys") {
+    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
+    testBucketing(bucketing, bucketing, Seq("j"), shuffleLeft = true, shuffleRight = true)
+  }
+
+  test("shuffle when join 2 bucketed tables with bucketing disabled") {
+    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
+    withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") {
+      testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = true, shuffleRight = true)
+    }
+  }
+
+  test("avoid shuffle when grouping keys are equal to bucket keys") {
+    withTable("bucketed_table") {
+      df1.write.format("parquet").bucketBy(8, "i", "j").saveAsTable("bucketed_table")
+      val tbl = hiveContext.table("bucketed_table")
+      val agged = tbl.groupBy("i", "j").agg(max("k"))
+
+      checkAnswer(
+        agged.sort("i", "j"),
+        df1.groupBy("i", "j").agg(max("k")).sort("i", "j"))
+
+      assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isEmpty)
+    }
+  }
+
+  test("avoid shuffle when grouping keys are a super-set of bucket keys") {
+    withTable("bucketed_table") {
+      df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
+      val tbl = hiveContext.table("bucketed_table")
+      val agged = tbl.groupBy("i", "j").agg(max("k"))
+
+      checkAnswer(
+        agged.sort("i", "j"),
+        df1.groupBy("i", "j").agg(max("k")).sort("i", "j"))
+
+      assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isEmpty)
+    }
+  }
+
+  test("fallback to non-bucketing mode if there exists any malformed bucket files") {
+    withTable("bucketed_table") {
+      df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
+      val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
+      Utils.deleteRecursively(tableDir)
+      df1.write.parquet(tableDir.getAbsolutePath)
+
+      val agged = hiveContext.table("bucketed_table").groupBy("i").count()
+      // make sure we fall back to non-bucketing mode and can't avoid shuffle
+      assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isDefined)
+      checkAnswer(agged.sort("i"), df1.groupBy("i").count().sort("i"))
+    }
+  }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
index 3ea9826544..e812439bed 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
@@ -22,6 +22,7 @@ import java.io.File
 import org.apache.spark.sql.{AnalysisException, QueryTest}
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.execution.datasources.BucketingUtils
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.test.SQLTestUtils
@@ -62,15 +63,6 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
     intercept[IllegalArgumentException](df.write.bucketBy(2, "i").insertInto("tt"))
   }

-  private val testFileName = """.*-(\d+)$""".r
-  private val otherFileName = """.*-(\d+)\..*""".r
-  private def getBucketId(fileName: String): Int = {
-    fileName match {
-      case testFileName(bucketId) => bucketId.toInt
-      case otherFileName(bucketId) => bucketId.toInt
-    }
-  }
-
   private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")

   private def testBucketing(
@@ -81,7 +73,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
     val allBucketFiles = dataDir.listFiles().filterNot(f =>
       f.getName.startsWith(".") || f.getName.startsWith("_")
     )
-    val groupedBucketFiles = allBucketFiles.groupBy(f => getBucketId(f.getName))
+    val groupedBucketFiles = allBucketFiles.groupBy(f => BucketingUtils.getBucketId(f.getName).get)
     assert(groupedBucketFiles.size <= 8)

     for ((bucketId, bucketFiles) <- groupedBucketFiles) {
@@ -98,12 +90,12 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
       val qe = readBack.select(bucketCols.map(col): _*).queryExecution
       val rows = qe.toRdd.map(_.copy()).collect()
-      val getHashCode = UnsafeProjection.create(
+      val getBucketId = UnsafeProjection.create(
         HashPartitioning(qe.analyzed.output, 8).partitionIdExpression :: Nil,
         qe.analyzed.output)

       for (row <- rows) {
-        val actualBucketId = getHashCode(row).getInt(0)
+        val actualBucketId = getBucketId(row).getInt(0)
         assert(actualBucketId == bucketId)
       }
     }