aboutsummaryrefslogtreecommitdiff
path: root/sql/hive
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2016-01-15 17:20:01 -0800
committerReynold Xin <rxin@databricks.com>2016-01-15 17:20:01 -0800
commit3b5ccb12b8d33d99df0f206fecf00f51c2b88fdb (patch)
tree9be77df8147a125cadc46e9bc6da4641669b58da /sql/hive
parent8dbbf3e75e70e98391b4a1705472caddd129945a (diff)
downloadspark-3b5ccb12b8d33d99df0f206fecf00f51c2b88fdb.tar.gz
spark-3b5ccb12b8d33d99df0f206fecf00f51c2b88fdb.tar.bz2
spark-3b5ccb12b8d33d99df0f206fecf00f51c2b88fdb.zip
[SPARK-12649][SQL] support reading bucketed table
This PR adds the support to read bucketed tables, and correctly populate `outputPartitioning`, so that we can avoid shuffle for some cases. TODO(follow-up PRs): * bucket pruning * avoid shuffle for bucketed table join when use any super-set of the bucketing key. (we should re-visit it after https://issues.apache.org/jira/browse/SPARK-12704 is fixed) * recognize hive bucketed table Author: Wenchen Fan <wenchen@databricks.com> Closes #10604 from cloud-fan/bucket-read.
Diffstat (limited to 'sql/hive')
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala26
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala7
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala2
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala178
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala16
5 files changed, 203 insertions, 26 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 3d54048c24..0cfe03ba91 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -143,19 +143,16 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
}
}
- def partColsFromParts: Option[Seq[String]] = {
- table.properties.get("spark.sql.sources.schema.numPartCols").map { numPartCols =>
- (0 until numPartCols.toInt).map { index =>
- val partCol = table.properties.get(s"spark.sql.sources.schema.partCol.$index").orNull
- if (partCol == null) {
+ def getColumnNames(colType: String): Seq[String] = {
+ table.properties.get(s"spark.sql.sources.schema.num${colType.capitalize}Cols").map {
+ numCols => (0 until numCols.toInt).map { index =>
+ table.properties.get(s"spark.sql.sources.schema.${colType}Col.$index").getOrElse {
throw new AnalysisException(
- "Could not read partitioned columns from the metastore because it is corrupted " +
- s"(missing part $index of the it, $numPartCols parts are expected).")
+ s"Could not read $colType columns from the metastore because it is corrupted " +
+ s"(missing part $index of it, $numCols parts are expected).")
}
-
- partCol
}
- }
+ }.getOrElse(Nil)
}
// Originally, we used spark.sql.sources.schema to store the schema of a data source table.
@@ -170,7 +167,11 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
// We only need names at here since userSpecifiedSchema we loaded from the metastore
// contains partition columns. We can always get datatypes of partitioning columns
// from userSpecifiedSchema.
- val partitionColumns = partColsFromParts.getOrElse(Nil)
+ val partitionColumns = getColumnNames("part")
+
+ val bucketSpec = table.properties.get("spark.sql.sources.schema.numBuckets").map { n =>
+ BucketSpec(n.toInt, getColumnNames("bucket"), getColumnNames("sort"))
+ }
// It does not appear that the ql client for the metastore has a way to enumerate all the
// SerDe properties directly...
@@ -181,6 +182,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
hive,
userSpecifiedSchema,
partitionColumns.toArray,
+ bucketSpec,
table.properties("spark.sql.sources.provider"),
options)
@@ -282,7 +284,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
val maybeSerDe = HiveSerDe.sourceToSerDe(provider, hive.hiveconf)
val dataSource = ResolvedDataSource(
- hive, userSpecifiedSchema, partitionColumns, provider, options)
+ hive, userSpecifiedSchema, partitionColumns, bucketSpec, provider, options)
def newSparkSQLSpecificMetastoreTable(): HiveTable = {
HiveTable(
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
index 07a352873d..e703ac0164 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
@@ -213,7 +213,12 @@ case class CreateMetastoreDataSourceAsSelect(
case SaveMode.Append =>
// Check if the specified data source match the data source of the existing table.
val resolved = ResolvedDataSource(
- sqlContext, Some(query.schema.asNullable), partitionColumns, provider, optionsWithPath)
+ sqlContext,
+ Some(query.schema.asNullable),
+ partitionColumns,
+ bucketSpec,
+ provider,
+ optionsWithPath)
val createdRelation = LogicalRelation(resolved.relation)
EliminateSubQueries(sqlContext.catalog.lookupRelation(tableIdent)) match {
case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _) =>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
index 14fa152c23..40409169b0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
@@ -156,7 +156,7 @@ private[sql] class OrcRelation(
maybeDataSchema: Option[StructType],
maybePartitionSpec: Option[PartitionSpec],
override val userDefinedPartitionColumns: Option[StructType],
- override val bucketSpec: Option[BucketSpec],
+ override val maybeBucketSpec: Option[BucketSpec],
parameters: Map[String, String])(
@transient val sqlContext: SQLContext)
extends HadoopFsRelation(maybePartitionSpec, parameters)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
new file mode 100644
index 0000000000..58ecdd3b80
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources
+
+import java.io.File
+
+import org.apache.spark.sql.{Column, DataFrame, DataFrameWriter, QueryTest, SQLConf}
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.execution.Exchange
+import org.apache.spark.sql.execution.joins.SortMergeJoin
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.util.Utils
+
+class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
+ import testImplicits._
+
+ test("read bucketed data") {
+ val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
+ withTable("bucketed_table") {
+ df.write
+ .format("parquet")
+ .partitionBy("i")
+ .bucketBy(8, "j", "k")
+ .saveAsTable("bucketed_table")
+
+ for (i <- 0 until 5) {
+ val rdd = hiveContext.table("bucketed_table").filter($"i" === i).queryExecution.toRdd
+ assert(rdd.partitions.length == 8)
+
+ val attrs = df.select("j", "k").schema.toAttributes
+ val checkBucketId = rdd.mapPartitionsWithIndex((index, rows) => {
+ val getBucketId = UnsafeProjection.create(
+ HashPartitioning(attrs, 8).partitionIdExpression :: Nil,
+ attrs)
+ rows.map(row => getBucketId(row).getInt(0) == index)
+ })
+
+ assert(checkBucketId.collect().reduce(_ && _))
+ }
+ }
+ }
+
+ private val df1 = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1")
+ private val df2 = (0 until 50).map(i => (i % 7, i % 11, i.toString)).toDF("i", "j", "k").as("df2")
+
+ private def testBucketing(
+ bucketing1: DataFrameWriter => DataFrameWriter,
+ bucketing2: DataFrameWriter => DataFrameWriter,
+ joinColumns: Seq[String],
+ shuffleLeft: Boolean,
+ shuffleRight: Boolean): Unit = {
+ withTable("bucketed_table1", "bucketed_table2") {
+ bucketing1(df1.write.format("parquet")).saveAsTable("bucketed_table1")
+ bucketing2(df2.write.format("parquet")).saveAsTable("bucketed_table2")
+
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
+ val t1 = hiveContext.table("bucketed_table1")
+ val t2 = hiveContext.table("bucketed_table2")
+ val joined = t1.join(t2, joinCondition(t1, t2, joinColumns))
+
+ // First check the result is corrected.
+ checkAnswer(
+ joined.sort("bucketed_table1.k", "bucketed_table2.k"),
+ df1.join(df2, joinCondition(df1, df2, joinColumns)).sort("df1.k", "df2.k"))
+
+ assert(joined.queryExecution.executedPlan.isInstanceOf[SortMergeJoin])
+ val joinOperator = joined.queryExecution.executedPlan.asInstanceOf[SortMergeJoin]
+
+ assert(joinOperator.left.find(_.isInstanceOf[Exchange]).isDefined == shuffleLeft)
+ assert(joinOperator.right.find(_.isInstanceOf[Exchange]).isDefined == shuffleRight)
+ }
+ }
+ }
+
+ private def joinCondition(left: DataFrame, right: DataFrame, joinCols: Seq[String]): Column = {
+ joinCols.map(col => left(col) === right(col)).reduce(_ && _)
+ }
+
+ test("avoid shuffle when join 2 bucketed tables") {
+ val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
+ testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+ }
+
+ // Enable it after fix https://issues.apache.org/jira/browse/SPARK-12704
+ ignore("avoid shuffle when join keys are a super-set of bucket keys") {
+ val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
+ testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+ }
+
+ test("only shuffle one side when join bucketed table and non-bucketed table") {
+ val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
+ testBucketing(bucketing, identity, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+ }
+
+ test("only shuffle one side when 2 bucketed tables have different bucket number") {
+ val bucketing1 = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
+ val bucketing2 = (writer: DataFrameWriter) => writer.bucketBy(5, "i", "j")
+ testBucketing(bucketing1, bucketing2, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+ }
+
+ test("only shuffle one side when 2 bucketed tables have different bucket keys") {
+ val bucketing1 = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
+ val bucketing2 = (writer: DataFrameWriter) => writer.bucketBy(8, "j")
+ testBucketing(bucketing1, bucketing2, Seq("i"), shuffleLeft = false, shuffleRight = true)
+ }
+
+ test("shuffle when join keys are not equal to bucket keys") {
+ val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
+ testBucketing(bucketing, bucketing, Seq("j"), shuffleLeft = true, shuffleRight = true)
+ }
+
+ test("shuffle when join 2 bucketed tables with bucketing disabled") {
+ val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
+ withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") {
+ testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = true, shuffleRight = true)
+ }
+ }
+
+ test("avoid shuffle when grouping keys are equal to bucket keys") {
+ withTable("bucketed_table") {
+ df1.write.format("parquet").bucketBy(8, "i", "j").saveAsTable("bucketed_table")
+ val tbl = hiveContext.table("bucketed_table")
+ val agged = tbl.groupBy("i", "j").agg(max("k"))
+
+ checkAnswer(
+ agged.sort("i", "j"),
+ df1.groupBy("i", "j").agg(max("k")).sort("i", "j"))
+
+ assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isEmpty)
+ }
+ }
+
+ test("avoid shuffle when grouping keys are a super-set of bucket keys") {
+ withTable("bucketed_table") {
+ df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
+ val tbl = hiveContext.table("bucketed_table")
+ val agged = tbl.groupBy("i", "j").agg(max("k"))
+
+ checkAnswer(
+ agged.sort("i", "j"),
+ df1.groupBy("i", "j").agg(max("k")).sort("i", "j"))
+
+ assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isEmpty)
+ }
+ }
+
+ test("fallback to non-bucketing mode if there exists any malformed bucket files") {
+ withTable("bucketed_table") {
+ df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
+ val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
+ Utils.deleteRecursively(tableDir)
+ df1.write.parquet(tableDir.getAbsolutePath)
+
+ val agged = hiveContext.table("bucketed_table").groupBy("i").count()
+ // make sure we fall back to non-bucketing mode and can't avoid shuffle
+ assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isDefined)
+ checkAnswer(agged.sort("i"), df1.groupBy("i").count().sort("i"))
+ }
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
index 3ea9826544..e812439bed 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
@@ -22,6 +22,7 @@ import java.io.File
import org.apache.spark.sql.{AnalysisException, QueryTest}
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.execution.datasources.BucketingUtils
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
@@ -62,15 +63,6 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
intercept[IllegalArgumentException](df.write.bucketBy(2, "i").insertInto("tt"))
}
- private val testFileName = """.*-(\d+)$""".r
- private val otherFileName = """.*-(\d+)\..*""".r
- private def getBucketId(fileName: String): Int = {
- fileName match {
- case testFileName(bucketId) => bucketId.toInt
- case otherFileName(bucketId) => bucketId.toInt
- }
- }
-
private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
private def testBucketing(
@@ -81,7 +73,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
val allBucketFiles = dataDir.listFiles().filterNot(f =>
f.getName.startsWith(".") || f.getName.startsWith("_")
)
- val groupedBucketFiles = allBucketFiles.groupBy(f => getBucketId(f.getName))
+ val groupedBucketFiles = allBucketFiles.groupBy(f => BucketingUtils.getBucketId(f.getName).get)
assert(groupedBucketFiles.size <= 8)
for ((bucketId, bucketFiles) <- groupedBucketFiles) {
@@ -98,12 +90,12 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
val qe = readBack.select(bucketCols.map(col): _*).queryExecution
val rows = qe.toRdd.map(_.copy()).collect()
- val getHashCode = UnsafeProjection.create(
+ val getBucketId = UnsafeProjection.create(
HashPartitioning(qe.analyzed.output, 8).partitionIdExpression :: Nil,
qe.analyzed.output)
for (row <- rows) {
- val actualBucketId = getHashCode(row).getInt(0)
+ val actualBucketId = getBucketId(row).getInt(0)
assert(actualBucketId == bucketId)
}
}