aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2016-01-06 16:58:10 -0800
committerReynold Xin <rxin@databricks.com>2016-01-06 16:58:10 -0800
commit917d3fc069fb9ea1c1487119c9c12b373f4f9b77 (patch)
tree44041146c82d7be4270da08ac88afbba0d7c18c3 /sql
parent6f7ba6409a39fd2e34865e3e7a84a3dd0b00d6a4 (diff)
downloadspark-917d3fc069fb9ea1c1487119c9c12b373f4f9b77.tar.gz
spark-917d3fc069fb9ea1c1487119c9c12b373f4f9b77.tar.bz2
spark-917d3fc069fb9ea1c1487119c9c12b373f4f9b77.zip
[SPARK-12539][SQL] support writing bucketed table
This PR adds bucket write support to Spark SQL. User can specify bucketing columns, numBuckets and sorting columns with or without partition columns. For example: ``` df.write.partitionBy("year").bucketBy(8, "country").sortBy("amount").saveAsTable("sales") ``` When bucketing is used, we will calculate bucket id for each record, and group the records by bucket id. For each group, we will create a file with bucket id in its name, and write data into it. For each bucket file, if sorting columns are specified, the data will be sorted before write. Note that there may be multiply files for one bucket, as the data is distributed. Currently we store the bucket metadata at hive metastore in a non-hive-compatible way. We use different bucketing hash function compared to hive, so we can't be compatible anyway. Limitations: * Can't write bucketed data without hive metastore. * Can't insert bucketed data into existing hive tables. Author: Wenchen Fan <wenchen@databricks.com> Closes #10498 from cloud-fan/bucket-write.
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala89
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala7
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala1
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala219
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala57
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala10
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala35
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala28
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala24
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala34
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala23
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala7
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala15
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala20
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala1
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala169
18 files changed, 626 insertions, 117 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index e2d72a549e..00f9817b53 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -23,9 +23,9 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project}
-import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, ResolvedDataSource}
+import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, ResolvedDataSource}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.sources.HadoopFsRelation
@@ -129,6 +129,34 @@ final class DataFrameWriter private[sql](df: DataFrame) {
}
/**
+ * Buckets the output by the given columns. If specified, the output is laid out on the file
+ * system similar to Hive's bucketing scheme.
+ *
+ * This is applicable for Parquet, JSON and ORC.
+ *
+ * @since 2.0
+ */
+ @scala.annotation.varargs
+ def bucketBy(numBuckets: Int, colName: String, colNames: String*): DataFrameWriter = {
+ this.numBuckets = Option(numBuckets)
+ this.bucketColumnNames = Option(colName +: colNames)
+ this
+ }
+
+ /**
+ * Sorts the output in each bucket by the given columns.
+ *
+ * This is applicable for Parquet, JSON and ORC.
+ *
+ * @since 2.0
+ */
+ @scala.annotation.varargs
+ def sortBy(colName: String, colNames: String*): DataFrameWriter = {
+ this.sortColumnNames = Option(colName +: colNames)
+ this
+ }
+
+ /**
* Saves the content of the [[DataFrame]] at the specified path.
*
* @since 1.4.0
@@ -144,10 +172,12 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 1.4.0
*/
def save(): Unit = {
+ assertNotBucketed()
ResolvedDataSource(
df.sqlContext,
source,
partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]),
+ getBucketSpec,
mode,
extraOptions.toMap,
df)
@@ -166,6 +196,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
}
private def insertInto(tableIdent: TableIdentifier): Unit = {
+ assertNotBucketed()
val partitions = normalizedParCols.map(_.map(col => col -> (None: Option[String])).toMap)
val overwrite = mode == SaveMode.Overwrite
@@ -188,13 +219,47 @@ final class DataFrameWriter private[sql](df: DataFrame) {
ifNotExists = false)).toRdd
}
- private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { parCols =>
- parCols.map { col =>
- df.logicalPlan.output
- .map(_.name)
- .find(df.sqlContext.analyzer.resolver(_, col))
- .getOrElse(throw new AnalysisException(s"Partition column $col not found in existing " +
- s"columns (${df.logicalPlan.output.map(_.name).mkString(", ")})"))
+ private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols =>
+ cols.map(normalize(_, "Partition"))
+ }
+
+ private def normalizedBucketColNames: Option[Seq[String]] = bucketColumnNames.map { cols =>
+ cols.map(normalize(_, "Bucketing"))
+ }
+
+ private def normalizedSortColNames: Option[Seq[String]] = sortColumnNames.map { cols =>
+ cols.map(normalize(_, "Sorting"))
+ }
+
+ private def getBucketSpec: Option[BucketSpec] = {
+ if (sortColumnNames.isDefined) {
+ require(numBuckets.isDefined, "sortBy must be used together with bucketBy")
+ }
+
+ for {
+ n <- numBuckets
+ } yield {
+ require(n > 0 && n < 100000, "Bucket number must be greater than 0 and less than 100000.")
+ BucketSpec(n, normalizedBucketColNames.get, normalizedSortColNames.getOrElse(Nil))
+ }
+ }
+
+ /**
+ * The given column name may not be equal to any of the existing column names if we were in
+ * case-insensitive context. Normalize the given column name to the real one so that we don't
+ * need to care about case sensitivity afterwards.
+ */
+ private def normalize(columnName: String, columnType: String): String = {
+ val validColumnNames = df.logicalPlan.output.map(_.name)
+ validColumnNames.find(df.sqlContext.analyzer.resolver(_, columnName))
+ .getOrElse(throw new AnalysisException(s"$columnType column $columnName not found in " +
+ s"existing columns (${validColumnNames.mkString(", ")})"))
+ }
+
+ private def assertNotBucketed(): Unit = {
+ if (numBuckets.isDefined || sortColumnNames.isDefined) {
+ throw new IllegalArgumentException(
+ "Currently we don't support writing bucketed data to this data source.")
}
}
@@ -244,6 +309,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
source,
temporary = false,
partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]),
+ getBucketSpec,
mode,
extraOptions.toMap,
df.logicalPlan)
@@ -372,4 +438,9 @@ final class DataFrameWriter private[sql](df: DataFrame) {
private var partitioningColumns: Option[Seq[String]] = None
+ private var bucketColumnNames: Option[Seq[String]] = None
+
+ private var numBuckets: Option[Int] = None
+
+ private var sortColumnNames: Option[Seq[String]] = None
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 6cf75bc170..482130a18d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -382,13 +382,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case c: CreateTableUsing if c.temporary && c.allowExisting =>
sys.error("allowExisting should be set to false when creating a temporary table.")
- case CreateTableUsingAsSelect(tableIdent, provider, true, partitionsCols, mode, opts, query)
- if partitionsCols.nonEmpty =>
+ case c: CreateTableUsingAsSelect if c.temporary && c.partitionColumns.nonEmpty =>
sys.error("Cannot create temporary partitioned table.")
- case CreateTableUsingAsSelect(tableIdent, provider, true, _, mode, opts, query) =>
+ case c: CreateTableUsingAsSelect if c.temporary =>
val cmd = CreateTempTableUsingAsSelect(
- tableIdent, provider, Array.empty[String], mode, opts, query)
+ c.tableIdent, c.provider, Array.empty[String], c.mode, c.options, c.child)
ExecutedCommand(cmd) :: Nil
case c: CreateTableUsingAsSelect if !c.temporary =>
sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala
index 48eff62b29..d8d21b06b8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala
@@ -109,6 +109,7 @@ class DDLParser(parseQuery: String => LogicalPlan)
provider,
temp.isDefined,
Array.empty[String],
+ bucketSpec = None,
mode,
options,
queryPlan)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala
index 38152d0cf1..7a8691e7cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala
@@ -125,7 +125,7 @@ private[sql] case class InsertIntoHadoopFsRelation(
|Actual: ${partitionColumns.mkString(", ")}
""".stripMargin)
- val writerContainer = if (partitionColumns.isEmpty) {
+ val writerContainer = if (partitionColumns.isEmpty && relation.bucketSpec.isEmpty) {
new DefaultWriterContainer(relation, job, isAppend)
} else {
val output = df.queryExecution.executedPlan.output
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
index 0ca0a38f71..ece9b8a9a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
@@ -210,6 +210,7 @@ object ResolvedDataSource extends Logging {
sqlContext: SQLContext,
provider: String,
partitionColumns: Array[String],
+ bucketSpec: Option[BucketSpec],
mode: SaveMode,
options: Map[String, String],
data: DataFrame): ResolvedDataSource = {
@@ -244,6 +245,7 @@ object ResolvedDataSource extends Logging {
Array(outputPath.toString),
Some(dataSchema.asNullable),
Some(partitionColumnsSchema(data.schema, partitionColumns, caseSensitive)),
+ bucketSpec,
caseInsensitiveOptions)
// For partitioned relation r, r.schema's column ordering can be different from the column
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
index 9f23d53107..4f8524f4b9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.UnsafeKVExternalSorter
import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory}
-import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.sql.types.{IntegerType, StructType, StringType}
import org.apache.spark.util.SerializableConfiguration
@@ -121,9 +121,9 @@ private[sql] abstract class BaseWriterContainer(
}
}
- protected def newOutputWriter(path: String): OutputWriter = {
+ protected def newOutputWriter(path: String, bucketId: Option[Int] = None): OutputWriter = {
try {
- outputWriterFactory.newInstance(path, dataSchema, taskAttemptContext)
+ outputWriterFactory.newInstance(path, bucketId, dataSchema, taskAttemptContext)
} catch {
case e: org.apache.hadoop.fs.FileAlreadyExistsException =>
if (outputCommitter.isInstanceOf[parquet.DirectParquetOutputCommitter]) {
@@ -312,19 +312,23 @@ private[sql] class DynamicPartitionWriterContainer(
isAppend: Boolean)
extends BaseWriterContainer(relation, job, isAppend) {
- def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
- val outputWriters = new java.util.HashMap[InternalRow, OutputWriter]
- executorSideSetup(taskContext)
+ private val bucketSpec = relation.bucketSpec
- var outputWritersCleared = false
+ private val bucketColumns: Seq[Attribute] = bucketSpec.toSeq.flatMap {
+ spec => spec.bucketColumnNames.map(c => inputSchema.find(_.name == c).get)
+ }
- // Returns the partition key given an input row
- val getPartitionKey = UnsafeProjection.create(partitionColumns, inputSchema)
- // Returns the data columns to be written given an input row
- val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)
+ private val sortColumns: Seq[Attribute] = bucketSpec.toSeq.flatMap {
+ spec => spec.sortColumnNames.map(c => inputSchema.find(_.name == c).get)
+ }
+
+ private def bucketIdExpression: Option[Expression] = for {
+ BucketSpec(numBuckets, _, _) <- bucketSpec
+ } yield Pmod(new Murmur3Hash(bucketColumns), Literal(numBuckets))
- // Expressions that given a partition key build a string like: col1=val/col2=val/...
- val partitionStringExpression = partitionColumns.zipWithIndex.flatMap { case (c, i) =>
+ // Expressions that given a partition key build a string like: col1=val/col2=val/...
+ private def partitionStringExpression: Seq[Expression] = {
+ partitionColumns.zipWithIndex.flatMap { case (c, i) =>
val escaped =
ScalaUDF(
PartitioningUtils.escapePathName _,
@@ -335,6 +339,121 @@ private[sql] class DynamicPartitionWriterContainer(
val partitionName = Literal(c.name + "=") :: str :: Nil
if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName
}
+ }
+
+ private def getBucketIdFromKey(key: InternalRow): Option[Int] = {
+ if (bucketSpec.isDefined) {
+ Some(key.getInt(partitionColumns.length))
+ } else {
+ None
+ }
+ }
+
+ private def sameBucket(key1: UnsafeRow, key2: UnsafeRow): Boolean = {
+ val bucketIdIndex = partitionColumns.length
+ if (key1.getInt(bucketIdIndex) != key2.getInt(bucketIdIndex)) {
+ false
+ } else {
+ var i = partitionColumns.length - 1
+ while (i >= 0) {
+ val dt = partitionColumns(i).dataType
+ if (key1.get(i, dt) != key2.get(i, dt)) return false
+ i -= 1
+ }
+ true
+ }
+ }
+
+ private def sortBasedWrite(
+ sorter: UnsafeKVExternalSorter,
+ iterator: Iterator[InternalRow],
+ getSortingKey: UnsafeProjection,
+ getOutputRow: UnsafeProjection,
+ getPartitionString: UnsafeProjection,
+ outputWriters: java.util.HashMap[InternalRow, OutputWriter]): Unit = {
+ while (iterator.hasNext) {
+ val currentRow = iterator.next()
+ sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
+ }
+
+ logInfo(s"Sorting complete. Writing out partition files one at a time.")
+
+ val needNewWriter: (UnsafeRow, UnsafeRow) => Boolean = if (sortColumns.isEmpty) {
+ (key1, key2) => key1 != key2
+ } else {
+ (key1, key2) => key1 == null || !sameBucket(key1, key2)
+ }
+
+ val sortedIterator = sorter.sortedIterator()
+ var currentKey: UnsafeRow = null
+ var currentWriter: OutputWriter = null
+ try {
+ while (sortedIterator.next()) {
+ if (needNewWriter(currentKey, sortedIterator.getKey)) {
+ if (currentWriter != null) {
+ currentWriter.close()
+ }
+ currentKey = sortedIterator.getKey.copy()
+ logDebug(s"Writing partition: $currentKey")
+
+ // Either use an existing file from before, or open a new one.
+ currentWriter = outputWriters.remove(currentKey)
+ if (currentWriter == null) {
+ currentWriter = newOutputWriter(currentKey, getPartitionString)
+ }
+ }
+
+ currentWriter.writeInternal(sortedIterator.getValue)
+ }
+ } finally {
+ if (currentWriter != null) { currentWriter.close() }
+ }
+ }
+
+ /**
+ * Open and returns a new OutputWriter given a partition key and optional bucket id.
+ * If bucket id is specified, we will append it to the end of the file name, but before the
+ * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
+ */
+ private def newOutputWriter(
+ key: InternalRow,
+ getPartitionString: UnsafeProjection): OutputWriter = {
+ val configuration = taskAttemptContext.getConfiguration
+ val path = if (partitionColumns.nonEmpty) {
+ val partitionPath = getPartitionString(key).getString(0)
+ configuration.set(
+ "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString)
+ new Path(getWorkPath, partitionPath).toString
+ } else {
+ configuration.set("spark.sql.sources.output.path", outputPath)
+ getWorkPath
+ }
+ val bucketId = getBucketIdFromKey(key)
+ val newWriter = super.newOutputWriter(path, bucketId)
+ newWriter.initConverter(dataSchema)
+ newWriter
+ }
+
+ def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
+ val outputWriters = new java.util.HashMap[InternalRow, OutputWriter]
+ executorSideSetup(taskContext)
+
+ var outputWritersCleared = false
+
+ // We should first sort by partition columns, then bucket id, and finally sorting columns.
+ val getSortingKey =
+ UnsafeProjection.create(partitionColumns ++ bucketIdExpression ++ sortColumns, inputSchema)
+
+ val sortingKeySchema = if (bucketSpec.isEmpty) {
+ StructType.fromAttributes(partitionColumns)
+ } else { // If it's bucketed, we should also consider bucket id as part of the key.
+ val fields = StructType.fromAttributes(partitionColumns)
+ .add("bucketId", IntegerType, nullable = false) ++ StructType.fromAttributes(sortColumns)
+ StructType(fields)
+ }
+
+ // Returns the data columns to be written given an input row
+ val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)
// Returns the partition path given a partition key.
val getPartitionString =
@@ -342,22 +461,34 @@ private[sql] class DynamicPartitionWriterContainer(
// If anything below fails, we should abort the task.
try {
- // This will be filled in if we have to fall back on sorting.
- var sorter: UnsafeKVExternalSorter = null
+ // If there is no sorting columns, we set sorter to null and try the hash-based writing first,
+ // and fill the sorter if there are too many writers and we need to fall back on sorting.
+ // If there are sorting columns, then we have to sort the data anyway, and no need to try the
+ // hash-based writing first.
+ var sorter: UnsafeKVExternalSorter = if (sortColumns.nonEmpty) {
+ new UnsafeKVExternalSorter(
+ sortingKeySchema,
+ StructType.fromAttributes(dataColumns),
+ SparkEnv.get.blockManager,
+ TaskContext.get().taskMemoryManager().pageSizeBytes)
+ } else {
+ null
+ }
while (iterator.hasNext && sorter == null) {
val inputRow = iterator.next()
- val currentKey = getPartitionKey(inputRow)
+ // When we reach here, the `sortColumns` must be empty, so the sorting key is hashing key.
+ val currentKey = getSortingKey(inputRow)
var currentWriter = outputWriters.get(currentKey)
if (currentWriter == null) {
if (outputWriters.size < maxOpenFiles) {
- currentWriter = newOutputWriter(currentKey)
+ currentWriter = newOutputWriter(currentKey, getPartitionString)
outputWriters.put(currentKey.copy(), currentWriter)
currentWriter.writeInternal(getOutputRow(inputRow))
} else {
logInfo(s"Maximum partitions reached, falling back on sorting.")
sorter = new UnsafeKVExternalSorter(
- StructType.fromAttributes(partitionColumns),
+ sortingKeySchema,
StructType.fromAttributes(dataColumns),
SparkEnv.get.blockManager,
TaskContext.get().taskMemoryManager().pageSizeBytes)
@@ -369,39 +500,15 @@ private[sql] class DynamicPartitionWriterContainer(
}
// If the sorter is not null that means that we reached the maxFiles above and need to finish
- // using external sort.
+ // using external sort, or there are sorting columns and we need to sort the whole data set.
if (sorter != null) {
- while (iterator.hasNext) {
- val currentRow = iterator.next()
- sorter.insertKV(getPartitionKey(currentRow), getOutputRow(currentRow))
- }
-
- logInfo(s"Sorting complete. Writing out partition files one at a time.")
-
- val sortedIterator = sorter.sortedIterator()
- var currentKey: InternalRow = null
- var currentWriter: OutputWriter = null
- try {
- while (sortedIterator.next()) {
- if (currentKey != sortedIterator.getKey) {
- if (currentWriter != null) {
- currentWriter.close()
- }
- currentKey = sortedIterator.getKey.copy()
- logDebug(s"Writing partition: $currentKey")
-
- // Either use an existing file from before, or open a new one.
- currentWriter = outputWriters.remove(currentKey)
- if (currentWriter == null) {
- currentWriter = newOutputWriter(currentKey)
- }
- }
-
- currentWriter.writeInternal(sortedIterator.getValue)
- }
- } finally {
- if (currentWriter != null) { currentWriter.close() }
- }
+ sortBasedWrite(
+ sorter,
+ iterator,
+ getSortingKey,
+ getOutputRow,
+ getPartitionString,
+ outputWriters)
}
commitTask()
@@ -412,18 +519,6 @@ private[sql] class DynamicPartitionWriterContainer(
throw new SparkException("Task failed while writing rows.", cause)
}
- /** Open and returns a new OutputWriter given a partition key. */
- def newOutputWriter(key: InternalRow): OutputWriter = {
- val partitionPath = getPartitionString(key).getString(0)
- val path = new Path(getWorkPath, partitionPath)
- val configuration = taskAttemptContext.getConfiguration
- configuration.set(
- "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString)
- val newWriter = super.newOutputWriter(path.toString)
- newWriter.initConverter(dataSchema)
- newWriter
- }
-
def clearOutputWriters(): Unit = {
if (!outputWritersCleared) {
outputWriters.asScala.values.foreach(_.close())
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala
new file mode 100644
index 0000000000..82287c8967
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.hadoop.mapreduce.TaskAttemptContext
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.sources.{OutputWriter, OutputWriterFactory, HadoopFsRelationProvider, HadoopFsRelation}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * A container for bucketing information.
+ * Bucketing is a technology for decomposing data sets into more manageable parts, and the number
+ * of buckets is fixed so it does not fluctuate with data.
+ *
+ * @param numBuckets number of buckets.
+ * @param bucketColumnNames the names of the columns that used to generate the bucket id.
+ * @param sortColumnNames the names of the columns that used to sort data in each bucket.
+ */
+private[sql] case class BucketSpec(
+ numBuckets: Int,
+ bucketColumnNames: Seq[String],
+ sortColumnNames: Seq[String])
+
+private[sql] trait BucketedHadoopFsRelationProvider extends HadoopFsRelationProvider {
+ final override def createRelation(
+ sqlContext: SQLContext,
+ paths: Array[String],
+ dataSchema: Option[StructType],
+ partitionColumns: Option[StructType],
+ parameters: Map[String, String]): HadoopFsRelation =
+ // TODO: throw exception here as we won't call this method during execution, after bucketed read
+ // support is finished.
+ createRelation(sqlContext, paths, dataSchema, partitionColumns, bucketSpec = None, parameters)
+}
+
+private[sql] abstract class BucketedOutputWriterFactory extends OutputWriterFactory {
+ final override def newInstance(
+ path: String,
+ dataSchema: StructType,
+ context: TaskAttemptContext): OutputWriter =
+ throw new UnsupportedOperationException("use bucket version")
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
index aed5d0dcf2..0897fcadbc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
@@ -76,6 +76,7 @@ case class CreateTableUsingAsSelect(
provider: String,
temporary: Boolean,
partitionColumns: Array[String],
+ bucketSpec: Option[BucketSpec],
mode: SaveMode,
options: Map[String, String],
child: LogicalPlan) extends UnaryNode {
@@ -109,7 +110,14 @@ case class CreateTempTableUsingAsSelect(
override def run(sqlContext: SQLContext): Seq[Row] = {
val df = DataFrame(sqlContext, query)
- val resolved = ResolvedDataSource(sqlContext, provider, partitionColumns, mode, options, df)
+ val resolved = ResolvedDataSource(
+ sqlContext,
+ provider,
+ partitionColumns,
+ bucketSpec = None,
+ mode,
+ options,
+ df)
sqlContext.catalog.registerTable(
tableIdent,
DataFrame(sqlContext, LogicalRelation(resolved.relation)).logicalPlan)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
index 8bf538178b..b92edf65bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
@@ -34,13 +34,13 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
-import org.apache.spark.sql.execution.datasources.PartitionSpec
+import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration
-class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {
+class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister {
override def shortName(): String = "json"
@@ -49,6 +49,7 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {
paths: Array[String],
dataSchema: Option[StructType],
partitionColumns: Option[StructType],
+ bucketSpec: Option[BucketSpec],
parameters: Map[String, String]): HadoopFsRelation = {
new JSONRelation(
@@ -56,6 +57,7 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {
maybeDataSchema = dataSchema,
maybePartitionSpec = None,
userDefinedPartitionColumns = partitionColumns,
+ bucketSpec = bucketSpec,
paths = paths,
parameters = parameters)(sqlContext)
}
@@ -66,11 +68,29 @@ private[sql] class JSONRelation(
val maybeDataSchema: Option[StructType],
val maybePartitionSpec: Option[PartitionSpec],
override val userDefinedPartitionColumns: Option[StructType],
+ override val bucketSpec: Option[BucketSpec],
override val paths: Array[String] = Array.empty[String],
parameters: Map[String, String] = Map.empty[String, String])
(@transient val sqlContext: SQLContext)
extends HadoopFsRelation(maybePartitionSpec, parameters) {
+ def this(
+ inputRDD: Option[RDD[String]],
+ maybeDataSchema: Option[StructType],
+ maybePartitionSpec: Option[PartitionSpec],
+ userDefinedPartitionColumns: Option[StructType],
+ paths: Array[String] = Array.empty[String],
+ parameters: Map[String, String] = Map.empty[String, String])(sqlContext: SQLContext) = {
+ this(
+ inputRDD,
+ maybeDataSchema,
+ maybePartitionSpec,
+ userDefinedPartitionColumns,
+ None,
+ paths,
+ parameters)(sqlContext)
+ }
+
val options: JSONOptions = JSONOptions.createFromConfigMap(parameters)
/** Constraints to be imposed on schema to be stored. */
@@ -158,13 +178,14 @@ private[sql] class JSONRelation(
partitionColumns)
}
- override def prepareJobForWrite(job: Job): OutputWriterFactory = {
- new OutputWriterFactory {
+ override def prepareJobForWrite(job: Job): BucketedOutputWriterFactory = {
+ new BucketedOutputWriterFactory {
override def newInstance(
path: String,
+ bucketId: Option[Int],
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
- new JsonOutputWriter(path, dataSchema, context)
+ new JsonOutputWriter(path, bucketId, dataSchema, context)
}
}
}
@@ -172,6 +193,7 @@ private[sql] class JSONRelation(
private[json] class JsonOutputWriter(
path: String,
+ bucketId: Option[Int],
dataSchema: StructType,
context: TaskAttemptContext)
extends OutputWriter with Logging {
@@ -188,7 +210,8 @@ private[json] class JsonOutputWriter(
val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID")
val taskAttemptId = context.getTaskAttemptID
val split = taskAttemptId.getTaskID.getId
- new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension")
+ val bucketString = bucketId.map(id => f"-$id%05d").getOrElse("")
+ new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$bucketString$extension")
}
}.getRecordWriter(context)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
index 45f1dff96d..4b375de05e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
@@ -45,13 +45,13 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.catalyst.util.LegacyTypeStringParser
-import org.apache.spark.sql.execution.datasources.PartitionSpec
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.{SerializableConfiguration, Utils}
-private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {
+private[sql] class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister {
override def shortName(): String = "parquet"
@@ -60,13 +60,17 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourc
paths: Array[String],
schema: Option[StructType],
partitionColumns: Option[StructType],
+ bucketSpec: Option[BucketSpec],
parameters: Map[String, String]): HadoopFsRelation = {
- new ParquetRelation(paths, schema, None, partitionColumns, parameters)(sqlContext)
+ new ParquetRelation(paths, schema, None, partitionColumns, bucketSpec, parameters)(sqlContext)
}
}
// NOTE: This class is instantiated and used on executor side only, no need to be serializable.
-private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext)
+private[sql] class ParquetOutputWriter(
+ path: String,
+ bucketId: Option[Int],
+ context: TaskAttemptContext)
extends OutputWriter {
private val recordWriter: RecordWriter[Void, InternalRow] = {
@@ -86,7 +90,8 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext
val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID")
val taskAttemptId = context.getTaskAttemptID
val split = taskAttemptId.getTaskID.getId
- new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension")
+ val bucketString = bucketId.map(id => f"-$id%05d").getOrElse("")
+ new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$bucketString$extension")
}
}
}
@@ -107,6 +112,7 @@ private[sql] class ParquetRelation(
// This is for metastore conversion.
private val maybePartitionSpec: Option[PartitionSpec],
override val userDefinedPartitionColumns: Option[StructType],
+ override val bucketSpec: Option[BucketSpec],
parameters: Map[String, String])(
val sqlContext: SQLContext)
extends HadoopFsRelation(maybePartitionSpec, parameters)
@@ -123,6 +129,7 @@ private[sql] class ParquetRelation(
maybeDataSchema,
maybePartitionSpec,
maybePartitionSpec.map(_.partitionColumns),
+ None,
parameters)(sqlContext)
}
@@ -216,7 +223,7 @@ private[sql] class ParquetRelation(
override def sizeInBytes: Long = metadataCache.dataStatuses.map(_.getLen).sum
- override def prepareJobForWrite(job: Job): OutputWriterFactory = {
+ override def prepareJobForWrite(job: Job): BucketedOutputWriterFactory = {
val conf = ContextUtil.getConfiguration(job)
// SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible
@@ -276,10 +283,13 @@ private[sql] class ParquetRelation(
sqlContext.conf.parquetCompressionCodec.toUpperCase,
CompressionCodecName.UNCOMPRESSED).name())
- new OutputWriterFactory {
+ new BucketedOutputWriterFactory {
override def newInstance(
- path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = {
- new ParquetOutputWriter(path, context)
+ path: String,
+ bucketId: Option[Int],
+ dataSchema: StructType,
+ context: TaskAttemptContext): OutputWriter = {
+ new ParquetOutputWriter(path, bucketId, context)
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 50ecbd3576..d484403d1c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources
import org.apache.spark.sql.{AnalysisException, SaveMode, SQLContext}
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast}
+import org.apache.spark.sql.catalyst.expressions.{RowOrdering, Alias, Attribute, Cast}
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
@@ -165,22 +165,22 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan =>
// OK
}
- case CreateTableUsingAsSelect(tableIdent, _, _, partitionColumns, mode, _, query) =>
+ case c: CreateTableUsingAsSelect =>
// When the SaveMode is Overwrite, we need to check if the table is an input table of
// the query. If so, we will throw an AnalysisException to let users know it is not allowed.
- if (mode == SaveMode.Overwrite && catalog.tableExists(tableIdent)) {
+ if (c.mode == SaveMode.Overwrite && catalog.tableExists(c.tableIdent)) {
// Need to remove SubQuery operator.
- EliminateSubQueries(catalog.lookupRelation(tableIdent)) match {
+ EliminateSubQueries(catalog.lookupRelation(c.tableIdent)) match {
// Only do the check if the table is a data source table
// (the relation is a BaseRelation).
case l @ LogicalRelation(dest: BaseRelation, _) =>
// Get all input data source relations of the query.
- val srcRelations = query.collect {
+ val srcRelations = c.child.collect {
case LogicalRelation(src: BaseRelation, _) => src
}
if (srcRelations.contains(dest)) {
failAnalysis(
- s"Cannot overwrite table $tableIdent that is also being read from.")
+ s"Cannot overwrite table ${c.tableIdent} that is also being read from.")
} else {
// OK
}
@@ -192,7 +192,17 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan =>
}
PartitioningUtils.validatePartitionColumnDataTypes(
- query.schema, partitionColumns, catalog.conf.caseSensitiveAnalysis)
+ c.child.schema, c.partitionColumns, catalog.conf.caseSensitiveAnalysis)
+
+ for {
+ spec <- c.bucketSpec
+ sortColumnName <- spec.sortColumnNames
+ sortColumn <- c.child.schema.find(_.name == sortColumnName)
+ } {
+ if (!RowOrdering.isOrderable(sortColumn.dataType)) {
+ failAnalysis(s"Cannot use ${sortColumn.dataType.simpleString} for sorting column.")
+ }
+ }
case _ => // OK
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index f4c7f0a269..c35f33132f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
import scala.util.Try
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter}
+import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
import org.apache.spark.sql.execution.{FileRelation, RDDConversions}
-import org.apache.spark.sql.execution.datasources.{Partition, PartitioningUtils, PartitionSpec}
+import org.apache.spark.sql.execution.datasources.{BucketSpec, Partition, PartitioningUtils, PartitionSpec}
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.util.SerializableConfiguration
@@ -161,6 +161,20 @@ trait HadoopFsRelationProvider {
dataSchema: Option[StructType],
partitionColumns: Option[StructType],
parameters: Map[String, String]): HadoopFsRelation
+
+ // TODO: expose bucket API to users.
+ private[sql] def createRelation(
+ sqlContext: SQLContext,
+ paths: Array[String],
+ dataSchema: Option[StructType],
+ partitionColumns: Option[StructType],
+ bucketSpec: Option[BucketSpec],
+ parameters: Map[String, String]): HadoopFsRelation = {
+ if (bucketSpec.isDefined) {
+ throw new AnalysisException("Currently we don't support bucketing for this data source.")
+ }
+ createRelation(sqlContext, paths, dataSchema, partitionColumns, parameters)
+ }
}
/**
@@ -351,7 +365,18 @@ abstract class OutputWriterFactory extends Serializable {
*
* @since 1.4.0
*/
- def newInstance(path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter
+ def newInstance(
+ path: String,
+ dataSchema: StructType,
+ context: TaskAttemptContext): OutputWriter
+
+ // TODO: expose bucket API to users.
+ private[sql] def newInstance(
+ path: String,
+ bucketId: Option[Int],
+ dataSchema: StructType,
+ context: TaskAttemptContext): OutputWriter =
+ newInstance(path, dataSchema, context)
}
/**
@@ -435,6 +460,9 @@ abstract class HadoopFsRelation private[sql](
private var _partitionSpec: PartitionSpec = _
+ // TODO: expose bucket API to users.
+ private[sql] def bucketSpec: Option[BucketSpec] = None
+
private class FileStatusCache {
var leafFiles = mutable.LinkedHashMap.empty[Path, FileStatus]
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 1616c45952..43d84d507b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.util.DataTypeParser
import org.apache.spark.sql.execution.{datasources, FileRelation}
-import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource}
+import org.apache.spark.sql.execution.datasources.{Partition => ParquetPartition, _}
import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
import org.apache.spark.sql.hive.client._
import org.apache.spark.sql.hive.execution.HiveNativeCommand
@@ -211,6 +211,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
tableIdent: TableIdentifier,
userSpecifiedSchema: Option[StructType],
partitionColumns: Array[String],
+ bucketSpec: Option[BucketSpec],
provider: String,
options: Map[String, String],
isExternal: Boolean): Unit = {
@@ -240,6 +241,25 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
}
}
+ if (userSpecifiedSchema.isDefined && bucketSpec.isDefined) {
+ val BucketSpec(numBuckets, bucketColumnNames, sortColumnNames) = bucketSpec.get
+
+ tableProperties.put("spark.sql.sources.schema.numBuckets", numBuckets.toString)
+ tableProperties.put("spark.sql.sources.schema.numBucketCols",
+ bucketColumnNames.length.toString)
+ bucketColumnNames.zipWithIndex.foreach { case (bucketCol, index) =>
+ tableProperties.put(s"spark.sql.sources.schema.bucketCol.$index", bucketCol)
+ }
+
+ if (sortColumnNames.nonEmpty) {
+ tableProperties.put("spark.sql.sources.schema.numSortCols",
+ sortColumnNames.length.toString)
+ sortColumnNames.zipWithIndex.foreach { case (sortCol, index) =>
+ tableProperties.put(s"spark.sql.sources.schema.sortCol.$index", sortCol)
+ }
+ }
+ }
+
if (userSpecifiedSchema.isEmpty && partitionColumns.length > 0) {
// The table does not have a specified schema, which means that the schema will be inferred
// when we load the table. So, we are not expecting partition columns and we will discover
@@ -596,6 +616,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
conf.defaultDataSourceName,
temporary = false,
Array.empty[String],
+ bucketSpec = None,
mode,
options = Map.empty[String, String],
child
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 0b4f5a0fd6..3687dd6f5a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -88,10 +88,9 @@ private[hive] trait HiveStrategies {
tableIdent, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath)
ExecutedCommand(cmd) :: Nil
- case CreateTableUsingAsSelect(
- tableIdent, provider, false, partitionCols, mode, opts, query) =>
- val cmd =
- CreateMetastoreDataSourceAsSelect(tableIdent, provider, partitionCols, mode, opts, query)
+ case c: CreateTableUsingAsSelect =>
+ val cmd = CreateMetastoreDataSourceAsSelect(c.tableIdent, c.provider, c.partitionColumns,
+ c.bucketSpec, c.mode, c.options, c.child)
ExecutedCommand(cmd) :: Nil
case _ => Nil
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
index 94210a5394..612f01cda8 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution.RunnableCommand
-import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource}
+import org.apache.spark.sql.execution.datasources.{BucketSpec, LogicalRelation, ResolvedDataSource}
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -151,6 +151,7 @@ case class CreateMetastoreDataSource(
tableIdent,
userSpecifiedSchema,
Array.empty[String],
+ bucketSpec = None,
provider,
optionsWithPath,
isExternal)
@@ -164,6 +165,7 @@ case class CreateMetastoreDataSourceAsSelect(
tableIdent: TableIdentifier,
provider: String,
partitionColumns: Array[String],
+ bucketSpec: Option[BucketSpec],
mode: SaveMode,
options: Map[String, String],
query: LogicalPlan) extends RunnableCommand {
@@ -254,8 +256,14 @@ case class CreateMetastoreDataSourceAsSelect(
}
// Create the relation based on the data of df.
- val resolved =
- ResolvedDataSource(sqlContext, provider, partitionColumns, mode, optionsWithPath, df)
+ val resolved = ResolvedDataSource(
+ sqlContext,
+ provider,
+ partitionColumns,
+ bucketSpec,
+ mode,
+ optionsWithPath,
+ df)
if (createMetastoreTable) {
// We will use the schema of resolved.relation as the schema of the table (instead of
@@ -265,6 +273,7 @@ case class CreateMetastoreDataSourceAsSelect(
tableIdent,
Some(resolved.relation.schema),
partitionColumns,
+ bucketSpec,
provider,
optionsWithPath,
isExternal)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
index 3538d642d5..14fa152c23 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
@@ -37,13 +37,13 @@ import org.apache.spark.rdd.{HadoopRDD, RDD}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.datasources.PartitionSpec
+import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.hive.{HiveContext, HiveInspectors, HiveMetastoreTypes, HiveShim}
import org.apache.spark.sql.sources.{Filter, _}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration
-private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {
+private[sql] class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister {
override def shortName(): String = "orc"
@@ -52,17 +52,19 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourc
paths: Array[String],
dataSchema: Option[StructType],
partitionColumns: Option[StructType],
+ bucketSpec: Option[BucketSpec],
parameters: Map[String, String]): HadoopFsRelation = {
assert(
sqlContext.isInstanceOf[HiveContext],
"The ORC data source can only be used with HiveContext.")
- new OrcRelation(paths, dataSchema, None, partitionColumns, parameters)(sqlContext)
+ new OrcRelation(paths, dataSchema, None, partitionColumns, bucketSpec, parameters)(sqlContext)
}
}
private[orc] class OrcOutputWriter(
path: String,
+ bucketId: Option[Int],
dataSchema: StructType,
context: TaskAttemptContext)
extends OutputWriter with HiveInspectors {
@@ -101,7 +103,8 @@ private[orc] class OrcOutputWriter(
val uniqueWriteJobId = conf.get("spark.sql.sources.writeJobUUID")
val taskAttemptId = context.getTaskAttemptID
val partition = taskAttemptId.getTaskID.getId
- val filename = f"part-r-$partition%05d-$uniqueWriteJobId.orc"
+ val bucketString = bucketId.map(id => f"-$id%05d").getOrElse("")
+ val filename = f"part-r-$partition%05d-$uniqueWriteJobId$bucketString.orc"
new OrcOutputFormat().getRecordWriter(
new Path(path, filename).getFileSystem(conf),
@@ -153,6 +156,7 @@ private[sql] class OrcRelation(
maybeDataSchema: Option[StructType],
maybePartitionSpec: Option[PartitionSpec],
override val userDefinedPartitionColumns: Option[StructType],
+ override val bucketSpec: Option[BucketSpec],
parameters: Map[String, String])(
@transient val sqlContext: SQLContext)
extends HadoopFsRelation(maybePartitionSpec, parameters)
@@ -169,6 +173,7 @@ private[sql] class OrcRelation(
maybeDataSchema,
maybePartitionSpec,
maybePartitionSpec.map(_.partitionColumns),
+ None,
parameters)(sqlContext)
}
@@ -205,7 +210,7 @@ private[sql] class OrcRelation(
OrcTableScan(output, this, filters, inputPaths).execute()
}
- override def prepareJobForWrite(job: Job): OutputWriterFactory = {
+ override def prepareJobForWrite(job: Job): BucketedOutputWriterFactory = {
job.getConfiguration match {
case conf: JobConf =>
conf.setOutputFormat(classOf[OrcOutputFormat])
@@ -216,12 +221,13 @@ private[sql] class OrcRelation(
classOf[MapRedOutputFormat[_, _]])
}
- new OutputWriterFactory {
+ new BucketedOutputWriterFactory {
override def newInstance(
path: String,
+ bucketId: Option[Int],
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
- new OrcOutputWriter(path, dataSchema, context)
+ new OrcOutputWriter(path, bucketId, dataSchema, context)
}
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index e22dac3bc9..202851ae13 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -707,6 +707,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
tableIdent = TableIdentifier("wide_schema"),
userSpecifiedSchema = Some(schema),
partitionColumns = Array.empty[String],
+ bucketSpec = None,
provider = "json",
options = Map("path" -> "just a dummy path"),
isExternal = false)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
new file mode 100644
index 0000000000..579da0291f
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources
+
+import java.io.File
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.{AnalysisException, QueryTest}
+
+class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
+ import testImplicits._
+
+ test("bucketed by non-existing column") {
+ val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j")
+ intercept[AnalysisException](df.write.bucketBy(2, "k").saveAsTable("tt"))
+ }
+
+ test("numBuckets not greater than 0 or less than 100000") {
+ val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j")
+ intercept[IllegalArgumentException](df.write.bucketBy(0, "i").saveAsTable("tt"))
+ intercept[IllegalArgumentException](df.write.bucketBy(100000, "i").saveAsTable("tt"))
+ }
+
+ test("specify sorting columns without bucketing columns") {
+ val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j")
+ intercept[IllegalArgumentException](df.write.sortBy("j").saveAsTable("tt"))
+ }
+
+ test("sorting by non-orderable column") {
+ val df = Seq("a" -> Map(1 -> 1), "b" -> Map(2 -> 2)).toDF("i", "j")
+ intercept[AnalysisException](df.write.bucketBy(2, "i").sortBy("j").saveAsTable("tt"))
+ }
+
+ test("write bucketed data to unsupported data source") {
+ val df = Seq(Tuple1("a"), Tuple1("b")).toDF("i")
+ intercept[AnalysisException](df.write.bucketBy(3, "i").format("text").saveAsTable("tt"))
+ }
+
+ test("write bucketed data to non-hive-table or existing hive table") {
+ val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j")
+ intercept[IllegalArgumentException](df.write.bucketBy(2, "i").parquet("/tmp/path"))
+ intercept[IllegalArgumentException](df.write.bucketBy(2, "i").json("/tmp/path"))
+ intercept[IllegalArgumentException](df.write.bucketBy(2, "i").insertInto("tt"))
+ }
+
+ private val testFileName = """.*-(\d+)$""".r
+ private val otherFileName = """.*-(\d+)\..*""".r
+ private def getBucketId(fileName: String): Int = {
+ fileName match {
+ case testFileName(bucketId) => bucketId.toInt
+ case otherFileName(bucketId) => bucketId.toInt
+ }
+ }
+
+ private def testBucketing(
+ dataDir: File,
+ source: String,
+ bucketCols: Seq[String],
+ sortCols: Seq[String] = Nil): Unit = {
+ val allBucketFiles = dataDir.listFiles().filterNot(f =>
+ f.getName.startsWith(".") || f.getName.startsWith("_")
+ )
+ val groupedBucketFiles = allBucketFiles.groupBy(f => getBucketId(f.getName))
+ assert(groupedBucketFiles.size <= 8)
+
+ for ((bucketId, bucketFiles) <- groupedBucketFiles) {
+ for (bucketFile <- bucketFiles) {
+ val df = sqlContext.read.format(source).load(bucketFile.getAbsolutePath)
+ .select((bucketCols ++ sortCols).map(col): _*)
+
+ if (sortCols.nonEmpty) {
+ checkAnswer(df.sort(sortCols.map(col): _*), df.collect())
+ }
+
+ val rows = df.select(bucketCols.map(col): _*).queryExecution.toRdd.map(_.copy()).collect()
+
+ for (row <- rows) {
+ assert(row.isInstanceOf[UnsafeRow])
+ val actualBucketId = (row.hashCode() % 8 + 8) % 8
+ assert(actualBucketId == bucketId)
+ }
+ }
+ }
+ }
+
+ private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
+
+ test("write bucketed data") {
+ for (source <- Seq("parquet", "json", "orc")) {
+ withTable("bucketed_table") {
+ df.write
+ .format(source)
+ .partitionBy("i")
+ .bucketBy(8, "j", "k")
+ .saveAsTable("bucketed_table")
+
+ val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
+ for (i <- 0 until 5) {
+ testBucketing(new File(tableDir, s"i=$i"), source, Seq("j", "k"))
+ }
+ }
+ }
+ }
+
+ test("write bucketed data with sortBy") {
+ for (source <- Seq("parquet", "json", "orc")) {
+ withTable("bucketed_table") {
+ df.write
+ .format(source)
+ .partitionBy("i")
+ .bucketBy(8, "j")
+ .sortBy("k")
+ .saveAsTable("bucketed_table")
+
+ val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
+ for (i <- 0 until 5) {
+ testBucketing(new File(tableDir, s"i=$i"), source, Seq("j"), Seq("k"))
+ }
+ }
+ }
+ }
+
+ test("write bucketed data without partitionBy") {
+ for (source <- Seq("parquet", "json", "orc")) {
+ withTable("bucketed_table") {
+ df.write
+ .format(source)
+ .bucketBy(8, "i", "j")
+ .saveAsTable("bucketed_table")
+
+ val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
+ testBucketing(tableDir, source, Seq("i", "j"))
+ }
+ }
+ }
+
+ test("write bucketed data without partitionBy with sortBy") {
+ for (source <- Seq("parquet", "json", "orc")) {
+ withTable("bucketed_table") {
+ df.write
+ .format(source)
+ .bucketBy(8, "i", "j")
+ .sortBy("k")
+ .saveAsTable("bucketed_table")
+
+ val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
+ testBucketing(tableDir, source, Seq("i", "j"), Seq("k"))
+ }
+ }
+ }
+}