-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala  1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala  6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala  28
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala  21
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala  1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala  55
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala  2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala  26
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala  7
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala  2
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala  178
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala  16
18 files changed, 314 insertions, 45 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 8f852e5216..634c1bd473 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -109,6 +109,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
sqlContext,
userSpecifiedSchema = userSpecifiedSchema,
partitionColumns = Array.empty[String],
+ bucketSpec = None,
provider = source,
options = extraOptions.toMap)
DataFrame(sqlContext, LogicalRelation(resolved.relation))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 7976795ff5..4e3662724c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -422,6 +422,10 @@ private[spark] object SQLConf {
doc = "The maximum number of concurrent files to open before falling back on sorting when " +
"writing out files using dynamic partitioning.")
+ val BUCKETING_ENABLED = booleanConf("spark.sql.sources.bucketing.enabled",
+ defaultValue = Some(true),
+ doc = "When false, we will treat bucketed table as normal table")
+
// The output committer class used by HadoopFsRelation. The specified class needs to be a
// subclass of org.apache.hadoop.mapreduce.OutputCommitter.
//
@@ -590,6 +594,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon
private[spark] def parallelPartitionDiscoveryThreshold: Int =
getConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD)
+ private[spark] def bucketingEnabled(): Boolean = getConf(SQLConf.BUCKETING_ENABLED)
+
// Do not use a value larger than 4000 as the default value of this property.
// See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information.
private[spark] def schemaStringLengthThreshold: Int = getConf(SCHEMA_STRING_LENGTH_THRESHOLD)
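
A minimal usage sketch of the new spark.sql.sources.bucketing.enabled flag (assuming an existing Spark 1.6-era SQLContext named sqlContext):

// Treat bucketed tables as normal tables for this session: the scan no longer
// reports a HashPartitioning, so joins/aggregations may reintroduce shuffles.
sqlContext.setConf("spark.sql.sources.bucketing.enabled", "false")

// The same toggle via SQL.
sqlContext.sql("SET spark.sql.sources.bucketing.enabled=true")
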
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 569a21feaa..92cfd5f841 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -18,11 +18,12 @@
package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, GenericMutableRow, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning}
import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation}
import org.apache.spark.sql.types.DataType
@@ -98,7 +99,8 @@ private[sql] case class PhysicalRDD(
rdd: RDD[InternalRow],
override val nodeName: String,
override val metadata: Map[String, String] = Map.empty,
- isUnsafeRow: Boolean = false)
+ isUnsafeRow: Boolean = false,
+ override val outputPartitioning: Partitioning = UnknownPartitioning(0))
extends LeafNode {
protected override def doExecute(): RDD[InternalRow] = {
@@ -130,6 +132,24 @@ private[sql] object PhysicalRDD {
metadata: Map[String, String] = Map.empty): PhysicalRDD = {
// All HadoopFsRelations output UnsafeRows
val outputUnsafeRows = relation.isInstanceOf[HadoopFsRelation]
- PhysicalRDD(output, rdd, relation.toString, metadata, outputUnsafeRows)
+
+ val bucketSpec = relation match {
+ case r: HadoopFsRelation => r.getBucketSpec
+ case _ => None
+ }
+
+ def toAttribute(colName: String): Attribute = output.find(_.name == colName).getOrElse {
+ throw new AnalysisException(s"bucket column $colName not found in existing columns " +
+ s"(${output.map(_.name).mkString(", ")})")
+ }
+
+ bucketSpec.map { spec =>
+ val numBuckets = spec.numBuckets
+ val bucketColumns = spec.bucketColumnNames.map(toAttribute)
+ val partitioning = HashPartitioning(bucketColumns, numBuckets)
+ PhysicalRDD(output, rdd, relation.toString, metadata, outputUnsafeRows, partitioning)
+ }.getOrElse {
+ PhysicalRDD(output, rdd, relation.toString, metadata, outputUnsafeRows)
+ }
}
}
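
A minimal standalone sketch (assuming Spark 1.6-era catalyst classes) of why reporting HashPartitioning here pays off: a scan that hashes its output by the bucket columns already satisfies the ClusteredDistribution that a join or aggregation on those columns requires, so EnsureRequirements can skip the Exchange above it.

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning}
import org.apache.spark.sql.types.IntegerType

val i = AttributeReference("i", IntegerType)()
val j = AttributeReference("j", IntegerType)()

// Partitioning reported by a scan over a table bucketed by (i, j) into 8 buckets.
val bucketedOutput = HashPartitioning(i :: j :: Nil, 8)

// Distribution required by an operator clustered on the same keys.
val required = ClusteredDistribution(i :: j :: Nil)

// true: no Exchange is needed between the scan and that operator.
println(bucketedOutput.satisfies(required))
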
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala
index 7a8691e7cb..314c957d57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala
@@ -125,7 +125,7 @@ private[sql] case class InsertIntoHadoopFsRelation(
|Actual: ${partitionColumns.mkString(", ")}
""".stripMargin)
- val writerContainer = if (partitionColumns.isEmpty && relation.bucketSpec.isEmpty) {
+ val writerContainer = if (partitionColumns.isEmpty && relation.getBucketSpec.isEmpty) {
new DefaultWriterContainer(relation, job, isAppend)
} else {
val output = df.queryExecution.executedPlan.output
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
index ece9b8a9a9..cc8dcf5930 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
@@ -97,6 +97,7 @@ object ResolvedDataSource extends Logging {
sqlContext: SQLContext,
userSpecifiedSchema: Option[StructType],
partitionColumns: Array[String],
+ bucketSpec: Option[BucketSpec],
provider: String,
options: Map[String, String]): ResolvedDataSource = {
val clazz: Class[_] = lookupDataSource(provider)
@@ -142,6 +143,7 @@ object ResolvedDataSource extends Logging {
paths,
Some(dataSchema),
maybePartitionsSchema,
+ bucketSpec,
caseInsensitiveOptions)
case dataSource: org.apache.spark.sql.sources.RelationProvider =>
throw new AnalysisException(s"$className does not allow user-specified schemas.")
@@ -173,7 +175,7 @@ object ResolvedDataSource extends Logging {
SparkHadoopUtil.get.globPathIfNecessary(qualified).map(_.toString)
}
}
- dataSource.createRelation(sqlContext, paths, None, None, caseInsensitiveOptions)
+ dataSource.createRelation(sqlContext, paths, None, None, None, caseInsensitiveOptions)
case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
throw new AnalysisException(
s"A schema needs to be specified when using $className.")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
index fc77529b7d..563fd9eefc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -311,7 +311,7 @@ private[sql] class DynamicPartitionWriterContainer(
isAppend: Boolean)
extends BaseWriterContainer(relation, job, isAppend) {
- private val bucketSpec = relation.bucketSpec
+ private val bucketSpec = relation.getBucketSpec
private val bucketColumns: Seq[Attribute] = bucketSpec.toSeq.flatMap {
spec => spec.bucketColumnNames.map(c => inputSchema.find(_.name == c).get)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala
index 9976829638..c7ecd6125d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala
@@ -44,9 +44,7 @@ private[sql] trait BucketedHadoopFsRelationProvider extends HadoopFsRelationProv
dataSchema: Option[StructType],
partitionColumns: Option[StructType],
parameters: Map[String, String]): HadoopFsRelation =
- // TODO: throw exception here as we won't call this method during execution, after bucketed read
- // support is finished.
- createRelation(sqlContext, paths, dataSchema, partitionColumns, bucketSpec = None, parameters)
+ throw new UnsupportedOperationException("use the overload version with bucketSpec parameter")
}
private[sql] abstract class BucketedOutputWriterFactory extends OutputWriterFactory {
@@ -54,5 +52,20 @@ private[sql] abstract class BucketedOutputWriterFactory extends OutputWriterFact
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter =
- throw new UnsupportedOperationException("use bucket version")
+ throw new UnsupportedOperationException("use the overload version with bucketSpec parameter")
+}
+
+private[sql] object BucketingUtils {
+ // The file name of bucketed data should have 3 parts:
+ // 1. some other information in the head of file name, ends with `-`
+ // 2. bucket id part, some numbers
+ // 3. optional file extension part, in the tail of file name, starts with `.`
+ // An example of bucketed parquet file name with bucket id 3:
+ // part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb-00003.gz.parquet
+ private val bucketedFileName = """.*-(\d+)(?:\..*)?$""".r
+
+ def getBucketId(fileName: String): Option[Int] = fileName match {
+ case bucketedFileName(bucketId) => Some(bucketId.toInt)
+ case other => None
+ }
}
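
A small self-contained sketch of the file-name convention above, using the same regex as BucketingUtils (plain Scala, no Spark dependency):

object BucketIdExample {
  // Same pattern as BucketingUtils.bucketedFileName: the trailing "-<digits>" before
  // an optional ".<extension>" suffix is interpreted as the bucket id.
  private val bucketedFileName = """.*-(\d+)(?:\..*)?$""".r

  def getBucketId(fileName: String): Option[Int] = fileName match {
    case bucketedFileName(bucketId) => Some(bucketId.toInt)
    case _ => None
  }

  def main(args: Array[String]): Unit = {
    // Some(3), as in the example file name in the comment above.
    println(getBucketId("part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb-00003.gz.parquet"))
    // None: no "-<digits>" component, so the scan falls back to non-bucketed mode.
    println(getBucketId("data.json"))
  }
}
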
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
index 0897fcadbc..c3603936df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
@@ -91,7 +91,7 @@ case class CreateTempTableUsing(
def run(sqlContext: SQLContext): Seq[Row] = {
val resolved = ResolvedDataSource(
- sqlContext, userSpecifiedSchema, Array.empty[String], provider, options)
+ sqlContext, userSpecifiedSchema, Array.empty[String], bucketSpec = None, provider, options)
sqlContext.catalog.registerTable(
tableIdent,
DataFrame(sqlContext, LogicalRelation(resolved.relation)).logicalPlan)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
index 8a6fa4aeeb..20c60b9c43 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
@@ -57,7 +57,7 @@ class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegi
maybeDataSchema = dataSchema,
maybePartitionSpec = None,
userDefinedPartitionColumns = partitionColumns,
- bucketSpec = bucketSpec,
+ maybeBucketSpec = bucketSpec,
paths = paths,
parameters = parameters)(sqlContext)
}
@@ -68,7 +68,7 @@ private[sql] class JSONRelation(
val maybeDataSchema: Option[StructType],
val maybePartitionSpec: Option[PartitionSpec],
override val userDefinedPartitionColumns: Option[StructType],
- override val bucketSpec: Option[BucketSpec] = None,
+ override val maybeBucketSpec: Option[BucketSpec] = None,
override val paths: Array[String] = Array.empty[String],
parameters: Map[String, String] = Map.empty[String, String])
(@transient val sqlContext: SQLContext)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
index 991a5d5aef..30ddec686c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
@@ -112,7 +112,7 @@ private[sql] class ParquetRelation(
// This is for metastore conversion.
private val maybePartitionSpec: Option[PartitionSpec],
override val userDefinedPartitionColumns: Option[StructType],
- override val bucketSpec: Option[BucketSpec],
+ override val maybeBucketSpec: Option[BucketSpec],
parameters: Map[String, String])(
val sqlContext: SQLContext)
extends HadoopFsRelation(maybePartitionSpec, parameters)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index dd3e66d8a9..9358c9c37b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -36,6 +36,7 @@ private[sql] class ResolveDataSource(sqlContext: SQLContext) extends Rule[Logica
sqlContext,
userSpecifiedSchema = None,
partitionColumns = Array(),
+ bucketSpec = None,
provider = u.tableIdentifier.database.get,
options = Map("path" -> u.tableIdentifier.table))
val plan = LogicalRelation(resolved.relation)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 9f3607369c..7800776fa1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -28,13 +28,13 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{RDD, UnionRDD}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
import org.apache.spark.sql.execution.{FileRelation, RDDConversions}
-import org.apache.spark.sql.execution.datasources.{BucketSpec, Partition, PartitioningUtils, PartitionSpec}
+import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.util.SerializableConfiguration
@@ -458,7 +458,12 @@ abstract class HadoopFsRelation private[sql](
private var _partitionSpec: PartitionSpec = _
- private[sql] def bucketSpec: Option[BucketSpec] = None
+ private[this] var malformedBucketFile = false
+
+ private[sql] def maybeBucketSpec: Option[BucketSpec] = None
+
+ final private[sql] def getBucketSpec: Option[BucketSpec] =
+ maybeBucketSpec.filter(_ => sqlContext.conf.bucketingEnabled() && !malformedBucketFile)
private class FileStatusCache {
var leafFiles = mutable.LinkedHashMap.empty[Path, FileStatus]
@@ -664,6 +669,35 @@ abstract class HadoopFsRelation private[sql](
})
}
+ /**
+ * Groups the input files by bucket id, if bucketing is enabled and this data source is bucketed.
+ * Returns None if there exists any malformed bucket files.
+ */
+ private def groupBucketFiles(
+ files: Array[FileStatus]): Option[scala.collection.Map[Int, Array[FileStatus]]] = {
+ malformedBucketFile = false
+ if (getBucketSpec.isDefined) {
+ val groupedBucketFiles = mutable.HashMap.empty[Int, mutable.ArrayBuffer[FileStatus]]
+ var i = 0
+ while (!malformedBucketFile && i < files.length) {
+ val bucketId = BucketingUtils.getBucketId(files(i).getPath.getName)
+ if (bucketId.isEmpty) {
+ logError(s"File ${files(i).getPath} is expected to be a bucket file, but there is no " +
+ "bucket id information in file name. Fall back to non-bucketing mode.")
+ malformedBucketFile = true
+ } else {
+ val bucketFiles =
+ groupedBucketFiles.getOrElseUpdate(bucketId.get, mutable.ArrayBuffer.empty)
+ bucketFiles += files(i)
+ }
+ i += 1
+ }
+ if (malformedBucketFile) None else Some(groupedBucketFiles.mapValues(_.toArray))
+ } else {
+ None
+ }
+ }
+
final private[sql] def buildInternalScan(
requiredColumns: Array[String],
filters: Array[Filter],
@@ -683,7 +717,20 @@ abstract class HadoopFsRelation private[sql](
}
}
- buildInternalScan(requiredColumns, filters, inputStatuses, broadcastedConf)
+ groupBucketFiles(inputStatuses).map { groupedBucketFiles =>
+ // For each bucket id, firstly we get all files belong to this bucket, by detecting bucket
+ // id from file name. Then read these files into a RDD(use one-partition empty RDD for empty
+ // bucket), and coalesce it to one partition. Finally union all bucket RDDs to one result.
+ val perBucketRows = (0 until maybeBucketSpec.get.numBuckets).map { bucketId =>
+ groupedBucketFiles.get(bucketId).map { inputStatuses =>
+ buildInternalScan(requiredColumns, filters, inputStatuses, broadcastedConf).coalesce(1)
+ }.getOrElse(sqlContext.emptyResult)
+ }
+
+ new UnionRDD(sqlContext.sparkContext, perBucketRows)
+ }.getOrElse {
+ buildInternalScan(requiredColumns, filters, inputStatuses, broadcastedConf)
+ }
}
/**
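
A minimal sketch of the bucket-aligned union built in buildInternalScan above, restated against plain RDDs; bucketAlignedScan and scanBucket are hypothetical names used only for illustration. Each bucket becomes exactly one partition of the result, with a one-partition empty RDD standing in for a missing bucket, so partition i of the scan always holds the rows of bucket i.

import scala.reflect.ClassTag

import org.apache.spark.SparkContext
import org.apache.spark.rdd.{RDD, UnionRDD}

// Hypothetical helper mirroring the per-bucket union in buildInternalScan.
def bucketAlignedScan[T: ClassTag](
    sc: SparkContext,
    numBuckets: Int,
    filesByBucket: Map[Int, Seq[String]],
    scanBucket: Seq[String] => RDD[T]): RDD[T] = {
  val perBucket = (0 until numBuckets).map { bucketId =>
    filesByBucket.get(bucketId) match {
      // Read all files of this bucket and squeeze them into a single partition.
      case Some(files) => scanBucket(files).coalesce(1)
      // Missing bucket: a one-partition empty RDD keeps partition indexes aligned.
      case None => sc.parallelize(Seq.empty[T], numSlices = 1)
    }
  }
  new UnionRDD(sc, perBucket)
}
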
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index e70eb2a060..8de8ba355e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -1223,6 +1223,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
sqlContext,
userSpecifiedSchema = None,
partitionColumns = Array.empty[String],
+ bucketSpec = None,
provider = classOf[DefaultSource].getCanonicalName,
options = Map("path" -> path))
@@ -1230,6 +1231,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
sqlContext,
userSpecifiedSchema = None,
partitionColumns = Array.empty[String],
+ bucketSpec = None,
provider = classOf[DefaultSource].getCanonicalName,
options = Map("path" -> path))
assert(d1 === d2)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 3d54048c24..0cfe03ba91 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -143,19 +143,16 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
}
}
- def partColsFromParts: Option[Seq[String]] = {
- table.properties.get("spark.sql.sources.schema.numPartCols").map { numPartCols =>
- (0 until numPartCols.toInt).map { index =>
- val partCol = table.properties.get(s"spark.sql.sources.schema.partCol.$index").orNull
- if (partCol == null) {
+ def getColumnNames(colType: String): Seq[String] = {
+ table.properties.get(s"spark.sql.sources.schema.num${colType.capitalize}Cols").map {
+ numCols => (0 until numCols.toInt).map { index =>
+ table.properties.get(s"spark.sql.sources.schema.${colType}Col.$index").getOrElse {
throw new AnalysisException(
- "Could not read partitioned columns from the metastore because it is corrupted " +
- s"(missing part $index of the it, $numPartCols parts are expected).")
+ s"Could not read $colType columns from the metastore because it is corrupted " +
+ s"(missing part $index of it, $numCols parts are expected).")
}
-
- partCol
}
- }
+ }.getOrElse(Nil)
}
// Originally, we used spark.sql.sources.schema to store the schema of a data source table.
@@ -170,7 +167,11 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
// We only need names at here since userSpecifiedSchema we loaded from the metastore
// contains partition columns. We can always get datatypes of partitioning columns
// from userSpecifiedSchema.
- val partitionColumns = partColsFromParts.getOrElse(Nil)
+ val partitionColumns = getColumnNames("part")
+
+ val bucketSpec = table.properties.get("spark.sql.sources.schema.numBuckets").map { n =>
+ BucketSpec(n.toInt, getColumnNames("bucket"), getColumnNames("sort"))
+ }
// It does not appear that the ql client for the metastore has a way to enumerate all the
// SerDe properties directly...
@@ -181,6 +182,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
hive,
userSpecifiedSchema,
partitionColumns.toArray,
+ bucketSpec,
table.properties("spark.sql.sources.provider"),
options)
@@ -282,7 +284,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
val maybeSerDe = HiveSerDe.sourceToSerDe(provider, hive.hiveconf)
val dataSource = ResolvedDataSource(
- hive, userSpecifiedSchema, partitionColumns, provider, options)
+ hive, userSpecifiedSchema, partitionColumns, bucketSpec, provider, options)
def newSparkSQLSpecificMetastoreTable(): HiveTable = {
HiveTable(
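
A minimal sketch of the table-property layout that getColumnNames and the bucketSpec lookup above read back from the metastore (property names follow the code; the concrete table and values are hypothetical):

// Hypothetical properties for a table bucketed by (i, j) into 8 buckets and sorted by k.
val properties = Map(
  "spark.sql.sources.schema.numBuckets" -> "8",
  "spark.sql.sources.schema.numBucketCols" -> "2",
  "spark.sql.sources.schema.bucketCol.0" -> "i",
  "spark.sql.sources.schema.bucketCol.1" -> "j",
  "spark.sql.sources.schema.numSortCols" -> "1",
  "spark.sql.sources.schema.sortCol.0" -> "k")

// Mirrors getColumnNames: read num<Type>Cols, then <type>Col.<index> for each index.
def columnNames(colType: String): Seq[String] =
  properties.get(s"spark.sql.sources.schema.num${colType.capitalize}Cols").map { n =>
    (0 until n.toInt).map(i => properties(s"spark.sql.sources.schema.${colType}Col.$i"))
  }.getOrElse(Nil)

// Yields Some((8, Seq("i", "j"), Seq("k"))) -- the three arguments passed to BucketSpec above.
val maybeBucketSpec = properties.get("spark.sql.sources.schema.numBuckets").map { n =>
  (n.toInt, columnNames("bucket"), columnNames("sort"))
}
println(maybeBucketSpec)
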
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
index 07a352873d..e703ac0164 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
@@ -213,7 +213,12 @@ case class CreateMetastoreDataSourceAsSelect(
case SaveMode.Append =>
// Check if the specified data source match the data source of the existing table.
val resolved = ResolvedDataSource(
- sqlContext, Some(query.schema.asNullable), partitionColumns, provider, optionsWithPath)
+ sqlContext,
+ Some(query.schema.asNullable),
+ partitionColumns,
+ bucketSpec,
+ provider,
+ optionsWithPath)
val createdRelation = LogicalRelation(resolved.relation)
EliminateSubQueries(sqlContext.catalog.lookupRelation(tableIdent)) match {
case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _) =>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
index 14fa152c23..40409169b0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
@@ -156,7 +156,7 @@ private[sql] class OrcRelation(
maybeDataSchema: Option[StructType],
maybePartitionSpec: Option[PartitionSpec],
override val userDefinedPartitionColumns: Option[StructType],
- override val bucketSpec: Option[BucketSpec],
+ override val maybeBucketSpec: Option[BucketSpec],
parameters: Map[String, String])(
@transient val sqlContext: SQLContext)
extends HadoopFsRelation(maybePartitionSpec, parameters)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
new file mode 100644
index 0000000000..58ecdd3b80
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources
+
+import java.io.File
+
+import org.apache.spark.sql.{Column, DataFrame, DataFrameWriter, QueryTest, SQLConf}
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.execution.Exchange
+import org.apache.spark.sql.execution.joins.SortMergeJoin
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.util.Utils
+
+class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
+ import testImplicits._
+
+ test("read bucketed data") {
+ val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
+ withTable("bucketed_table") {
+ df.write
+ .format("parquet")
+ .partitionBy("i")
+ .bucketBy(8, "j", "k")
+ .saveAsTable("bucketed_table")
+
+ for (i <- 0 until 5) {
+ val rdd = hiveContext.table("bucketed_table").filter($"i" === i).queryExecution.toRdd
+ assert(rdd.partitions.length == 8)
+
+ val attrs = df.select("j", "k").schema.toAttributes
+ val checkBucketId = rdd.mapPartitionsWithIndex((index, rows) => {
+ val getBucketId = UnsafeProjection.create(
+ HashPartitioning(attrs, 8).partitionIdExpression :: Nil,
+ attrs)
+ rows.map(row => getBucketId(row).getInt(0) == index)
+ })
+
+ assert(checkBucketId.collect().reduce(_ && _))
+ }
+ }
+ }
+
+ private val df1 = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1")
+ private val df2 = (0 until 50).map(i => (i % 7, i % 11, i.toString)).toDF("i", "j", "k").as("df2")
+
+ private def testBucketing(
+ bucketing1: DataFrameWriter => DataFrameWriter,
+ bucketing2: DataFrameWriter => DataFrameWriter,
+ joinColumns: Seq[String],
+ shuffleLeft: Boolean,
+ shuffleRight: Boolean): Unit = {
+ withTable("bucketed_table1", "bucketed_table2") {
+ bucketing1(df1.write.format("parquet")).saveAsTable("bucketed_table1")
+ bucketing2(df2.write.format("parquet")).saveAsTable("bucketed_table2")
+
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
+ val t1 = hiveContext.table("bucketed_table1")
+ val t2 = hiveContext.table("bucketed_table2")
+ val joined = t1.join(t2, joinCondition(t1, t2, joinColumns))
+
+ // First check the result is corrected.
+ checkAnswer(
+ joined.sort("bucketed_table1.k", "bucketed_table2.k"),
+ df1.join(df2, joinCondition(df1, df2, joinColumns)).sort("df1.k", "df2.k"))
+
+ assert(joined.queryExecution.executedPlan.isInstanceOf[SortMergeJoin])
+ val joinOperator = joined.queryExecution.executedPlan.asInstanceOf[SortMergeJoin]
+
+ assert(joinOperator.left.find(_.isInstanceOf[Exchange]).isDefined == shuffleLeft)
+ assert(joinOperator.right.find(_.isInstanceOf[Exchange]).isDefined == shuffleRight)
+ }
+ }
+ }
+
+ private def joinCondition(left: DataFrame, right: DataFrame, joinCols: Seq[String]): Column = {
+ joinCols.map(col => left(col) === right(col)).reduce(_ && _)
+ }
+
+ test("avoid shuffle when join 2 bucketed tables") {
+ val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
+ testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+ }
+
+ // Enable it after fix https://issues.apache.org/jira/browse/SPARK-12704
+ ignore("avoid shuffle when join keys are a super-set of bucket keys") {
+ val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
+ testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+ }
+
+ test("only shuffle one side when join bucketed table and non-bucketed table") {
+ val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
+ testBucketing(bucketing, identity, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+ }
+
+ test("only shuffle one side when 2 bucketed tables have different bucket number") {
+ val bucketing1 = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
+ val bucketing2 = (writer: DataFrameWriter) => writer.bucketBy(5, "i", "j")
+ testBucketing(bucketing1, bucketing2, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+ }
+
+ test("only shuffle one side when 2 bucketed tables have different bucket keys") {
+ val bucketing1 = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
+ val bucketing2 = (writer: DataFrameWriter) => writer.bucketBy(8, "j")
+ testBucketing(bucketing1, bucketing2, Seq("i"), shuffleLeft = false, shuffleRight = true)
+ }
+
+ test("shuffle when join keys are not equal to bucket keys") {
+ val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
+ testBucketing(bucketing, bucketing, Seq("j"), shuffleLeft = true, shuffleRight = true)
+ }
+
+ test("shuffle when join 2 bucketed tables with bucketing disabled") {
+ val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
+ withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") {
+ testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = true, shuffleRight = true)
+ }
+ }
+
+ test("avoid shuffle when grouping keys are equal to bucket keys") {
+ withTable("bucketed_table") {
+ df1.write.format("parquet").bucketBy(8, "i", "j").saveAsTable("bucketed_table")
+ val tbl = hiveContext.table("bucketed_table")
+ val agged = tbl.groupBy("i", "j").agg(max("k"))
+
+ checkAnswer(
+ agged.sort("i", "j"),
+ df1.groupBy("i", "j").agg(max("k")).sort("i", "j"))
+
+ assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isEmpty)
+ }
+ }
+
+ test("avoid shuffle when grouping keys are a super-set of bucket keys") {
+ withTable("bucketed_table") {
+ df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
+ val tbl = hiveContext.table("bucketed_table")
+ val agged = tbl.groupBy("i", "j").agg(max("k"))
+
+ checkAnswer(
+ agged.sort("i", "j"),
+ df1.groupBy("i", "j").agg(max("k")).sort("i", "j"))
+
+ assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isEmpty)
+ }
+ }
+
+ test("fallback to non-bucketing mode if there exists any malformed bucket files") {
+ withTable("bucketed_table") {
+ df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
+ val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
+ Utils.deleteRecursively(tableDir)
+ df1.write.parquet(tableDir.getAbsolutePath)
+
+ val agged = hiveContext.table("bucketed_table").groupBy("i").count()
+ // make sure we fall back to non-bucketing mode and can't avoid shuffle
+ assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isDefined)
+ checkAnswer(agged.sort("i"), df1.groupBy("i").count().sort("i"))
+ }
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
index 3ea9826544..e812439bed 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
@@ -22,6 +22,7 @@ import java.io.File
import org.apache.spark.sql.{AnalysisException, QueryTest}
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.execution.datasources.BucketingUtils
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
@@ -62,15 +63,6 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
intercept[IllegalArgumentException](df.write.bucketBy(2, "i").insertInto("tt"))
}
- private val testFileName = """.*-(\d+)$""".r
- private val otherFileName = """.*-(\d+)\..*""".r
- private def getBucketId(fileName: String): Int = {
- fileName match {
- case testFileName(bucketId) => bucketId.toInt
- case otherFileName(bucketId) => bucketId.toInt
- }
- }
-
private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
private def testBucketing(
@@ -81,7 +73,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
val allBucketFiles = dataDir.listFiles().filterNot(f =>
f.getName.startsWith(".") || f.getName.startsWith("_")
)
- val groupedBucketFiles = allBucketFiles.groupBy(f => getBucketId(f.getName))
+ val groupedBucketFiles = allBucketFiles.groupBy(f => BucketingUtils.getBucketId(f.getName).get)
assert(groupedBucketFiles.size <= 8)
for ((bucketId, bucketFiles) <- groupedBucketFiles) {
@@ -98,12 +90,12 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
val qe = readBack.select(bucketCols.map(col): _*).queryExecution
val rows = qe.toRdd.map(_.copy()).collect()
- val getHashCode = UnsafeProjection.create(
+ val getBucketId = UnsafeProjection.create(
HashPartitioning(qe.analyzed.output, 8).partitionIdExpression :: Nil,
qe.analyzed.output)
for (row <- rows) {
- val actualBucketId = getHashCode(row).getInt(0)
+ val actualBucketId = getBucketId(row).getInt(0)
assert(actualBucketId == bucketId)
}
}