author     Wenchen Fan <wenchen@databricks.com>    2016-01-11 00:44:33 -0800
committer  Reynold Xin <rxin@databricks.com>       2016-01-11 00:44:33 -0800
commit     f253feff62f3eb3cce22bbec0874f317a61b0092 (patch)
tree       986f2303bd5e9d24a74857dd4dddeea942183a49
parent     f13c7f8f7dc8766b0a42406b5c3639d6be55cf33 (diff)
[SPARK-12539][FOLLOW-UP] always sort in partitioning writer
address comments in #10498, especially https://github.com/apache/spark/pull/10498#discussion_r49021259

Author: Wenchen Fan <wenchen@databricks.com>

This patch had conflicts when merged, resolved by
Committer: Reynold Xin <rxin@databricks.com>

Closes #10638 from cloud-fan/bucket-write.
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala  192
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala  3
2 files changed, 48 insertions(+), 147 deletions(-)
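This follow-up removes the hash-based write path (many OutputWriters open at once, falling back to an external sort only when too many files were open) and instead always sorts the incoming rows by (partition columns, bucket id, sort columns), so that at most one writer is open at any time. Below is a minimal standalone sketch of that sort-then-write pattern, using hypothetical Row and Writer stand-ins rather than Spark's InternalRow, OutputWriter and UnsafeKVExternalSorter:

object SortThenWriteSketch {
  final case class Row(partition: String, bucketId: Int, data: String)

  // Hypothetical stand-in for Spark's OutputWriter: one open "file" per key.
  final class Writer(key: (String, Int)) {
    def write(data: String): Unit = println(s"$key <- $data")
    def close(): Unit = println(s"close $key")
  }

  def writeRows(rows: Iterator[Row]): Unit = {
    // Sort step: the patch feeds an UnsafeKVExternalSorter that can spill to disk;
    // an in-memory sortBy is enough to show the idea here.
    val sorted = rows.toSeq.sortBy(r => (r.partition, r.bucketId))

    var currentKey: (String, Int) = null
    var currentWriter: Writer = null
    try {
      sorted.foreach { row =>
        val nextKey = (row.partition, row.bucketId)
        if (nextKey != currentKey) {              // key changed: rotate the writer
          if (currentWriter != null) currentWriter.close()
          currentKey = nextKey
          currentWriter = new Writer(currentKey)
        }
        currentWriter.write(row.data)
      }
    } finally {
      if (currentWriter != null) currentWriter.close()
    }
  }

  def main(args: Array[String]): Unit = {
    writeRows(Iterator(
      Row("p=2", 0, "b"), Row("p=1", 1, "a"), Row("p=1", 1, "c"), Row("p=2", 0, "d")))
  }
}

Because the data arrives grouped after the sort, the per-key HashMap of writers (and the maxOpenFiles fallback) in the old code becomes unnecessary, which is exactly what the diff below deletes.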
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
index 40ecdb8e44..fff72872c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.UnsafeKVExternalSorter
import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory}
-import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.util.SerializableConfiguration
@@ -349,67 +349,6 @@ private[sql] class DynamicPartitionWriterContainer(
}
}
- private def sameBucket(key1: UnsafeRow, key2: UnsafeRow): Boolean = {
- val bucketIdIndex = partitionColumns.length
- if (key1.getInt(bucketIdIndex) != key2.getInt(bucketIdIndex)) {
- false
- } else {
- var i = partitionColumns.length - 1
- while (i >= 0) {
- val dt = partitionColumns(i).dataType
- if (key1.get(i, dt) != key2.get(i, dt)) return false
- i -= 1
- }
- true
- }
- }
-
- private def sortBasedWrite(
- sorter: UnsafeKVExternalSorter,
- iterator: Iterator[InternalRow],
- getSortingKey: UnsafeProjection,
- getOutputRow: UnsafeProjection,
- getPartitionString: UnsafeProjection,
- outputWriters: java.util.HashMap[InternalRow, OutputWriter]): Unit = {
- while (iterator.hasNext) {
- val currentRow = iterator.next()
- sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
- }
-
- logInfo(s"Sorting complete. Writing out partition files one at a time.")
-
- val needNewWriter: (UnsafeRow, UnsafeRow) => Boolean = if (sortColumns.isEmpty) {
- (key1, key2) => key1 != key2
- } else {
- (key1, key2) => key1 == null || !sameBucket(key1, key2)
- }
-
- val sortedIterator = sorter.sortedIterator()
- var currentKey: UnsafeRow = null
- var currentWriter: OutputWriter = null
- try {
- while (sortedIterator.next()) {
- if (needNewWriter(currentKey, sortedIterator.getKey)) {
- if (currentWriter != null) {
- currentWriter.close()
- }
- currentKey = sortedIterator.getKey.copy()
- logDebug(s"Writing partition: $currentKey")
-
- // Either use an existing file from before, or open a new one.
- currentWriter = outputWriters.remove(currentKey)
- if (currentWriter == null) {
- currentWriter = newOutputWriter(currentKey, getPartitionString)
- }
- }
-
- currentWriter.writeInternal(sortedIterator.getValue)
- }
- } finally {
- if (currentWriter != null) { currentWriter.close() }
- }
- }
-
/**
* Open and returns a new OutputWriter given a partition key and optional bucket id.
* If bucket id is specified, we will append it to the end of the file name, but before the
@@ -435,22 +374,18 @@ private[sql] class DynamicPartitionWriterContainer(
}
def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
- val outputWriters = new java.util.HashMap[InternalRow, OutputWriter]
executorSideSetup(taskContext)
- var outputWritersCleared = false
-
// We should first sort by partition columns, then bucket id, and finally sorting columns.
- val getSortingKey =
- UnsafeProjection.create(partitionColumns ++ bucketIdExpression ++ sortColumns, inputSchema)
-
- val sortingKeySchema = if (bucketSpec.isEmpty) {
- StructType.fromAttributes(partitionColumns)
- } else { // If it's bucketed, we should also consider bucket id as part of the key.
- val fields = StructType.fromAttributes(partitionColumns)
- .add("bucketId", IntegerType, nullable = false) ++ StructType.fromAttributes(sortColumns)
- StructType(fields)
- }
+ val sortingExpressions: Seq[Expression] = partitionColumns ++ bucketIdExpression ++ sortColumns
+
+ val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema)
+
+ val sortingKeySchema = StructType(sortingExpressions.map {
+ case a: Attribute => StructField(a.name, a.dataType, a.nullable)
+ // The sorting expressions are all `Attribute` except bucket id.
+ case _ => StructField("bucketId", IntegerType, nullable = false)
+ })
// Returns the data columns to be written given an input row
val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)
@@ -461,54 +396,49 @@ private[sql] class DynamicPartitionWriterContainer(
// If anything below fails, we should abort the task.
try {
- // If there is no sorting columns, we set sorter to null and try the hash-based writing first,
- // and fill the sorter if there are too many writers and we need to fall back on sorting.
- // If there are sorting columns, then we have to sort the data anyway, and no need to try the
- // hash-based writing first.
- var sorter: UnsafeKVExternalSorter = if (sortColumns.nonEmpty) {
- new UnsafeKVExternalSorter(
- sortingKeySchema,
- StructType.fromAttributes(dataColumns),
- SparkEnv.get.blockManager,
- TaskContext.get().taskMemoryManager().pageSizeBytes)
+ // Sorts the data before write, so that we only need one writer at the same time.
+ // TODO: inject a local sort operator in planning.
+ val sorter = new UnsafeKVExternalSorter(
+ sortingKeySchema,
+ StructType.fromAttributes(dataColumns),
+ SparkEnv.get.blockManager,
+ TaskContext.get().taskMemoryManager().pageSizeBytes)
+
+ while (iterator.hasNext) {
+ val currentRow = iterator.next()
+ sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
+ }
+
+ logInfo(s"Sorting complete. Writing out partition files one at a time.")
+
+ val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
+ identity
} else {
- null
+ UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
+ case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
+ })
}
- while (iterator.hasNext && sorter == null) {
- val inputRow = iterator.next()
- // When we reach here, the `sortColumns` must be empty, so the sorting key is hashing key.
- val currentKey = getSortingKey(inputRow)
- var currentWriter = outputWriters.get(currentKey)
-
- if (currentWriter == null) {
- if (outputWriters.size < maxOpenFiles) {
+
+ val sortedIterator = sorter.sortedIterator()
+ var currentKey: UnsafeRow = null
+ var currentWriter: OutputWriter = null
+ try {
+ while (sortedIterator.next()) {
+ val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
+ if (currentKey != nextKey) {
+ if (currentWriter != null) {
+ currentWriter.close()
+ }
+ currentKey = nextKey.copy()
+ logDebug(s"Writing partition: $currentKey")
+
currentWriter = newOutputWriter(currentKey, getPartitionString)
- outputWriters.put(currentKey.copy(), currentWriter)
- currentWriter.writeInternal(getOutputRow(inputRow))
- } else {
- logInfo(s"Maximum partitions reached, falling back on sorting.")
- sorter = new UnsafeKVExternalSorter(
- sortingKeySchema,
- StructType.fromAttributes(dataColumns),
- SparkEnv.get.blockManager,
- TaskContext.get().taskMemoryManager().pageSizeBytes)
- sorter.insertKV(currentKey, getOutputRow(inputRow))
}
- } else {
- currentWriter.writeInternal(getOutputRow(inputRow))
- }
- }
- // If the sorter is not null that means that we reached the maxFiles above and need to finish
- // using external sort, or there are sorting columns and we need to sort the whole data set.
- if (sorter != null) {
- sortBasedWrite(
- sorter,
- iterator,
- getSortingKey,
- getOutputRow,
- getPartitionString,
- outputWriters)
+ currentWriter.writeInternal(sortedIterator.getValue)
+ }
+ } finally {
+ if (currentWriter != null) { currentWriter.close() }
}
commitTask()
@@ -518,31 +448,5 @@ private[sql] class DynamicPartitionWriterContainer(
abortTask()
throw new SparkException("Task failed while writing rows.", cause)
}
-
- def clearOutputWriters(): Unit = {
- if (!outputWritersCleared) {
- outputWriters.asScala.values.foreach(_.close())
- outputWriters.clear()
- outputWritersCleared = true
- }
- }
-
- def commitTask(): Unit = {
- try {
- clearOutputWriters()
- super.commitTask()
- } catch {
- case cause: Throwable =>
- throw new RuntimeException("Failed to commit task", cause)
- }
- }
-
- def abortTask(): Unit = {
- try {
- clearOutputWriters()
- } finally {
- super.abortTask()
- }
- }
}
}
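In the rewritten writeRows above, the sorter key includes the sort columns, but a new OutputWriter is only needed when the (partition columns, bucket id) prefix changes; getBucketingKey drops the trailing sort columns with a BoundReference projection. A plain-Scala stand-in for that prefix extraction (hypothetical names, not catalyst's UnsafeProjection):

object BucketingKeySketch {
  // Sorter key layout: partition column values ++ bucket id ++ sort column values.
  def bucketingKey(sortingKey: Vector[Any], numSortColumns: Int): Vector[Any] =
    if (numSortColumns == 0) sortingKey           // identity, as in the patch
    else sortingKey.dropRight(numSortColumns)     // keep partition values + bucket id

  def main(args: Array[String]): Unit = {
    // Two rows in the same partition and bucket, differing only in the sort column:
    // they share a bucketing key, so the same writer is reused for both.
    val k1 = Vector[Any](2016, 1, 7, "alpha")     // (year, month, bucketId, sortCol)
    val k2 = Vector[Any](2016, 1, 7, "omega")
    assert(bucketingKey(k1, 1) == bucketingKey(k2, 1))
    println(bucketingKey(k1, 1))                  // Vector(2016, 1, 7)
  }
}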
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index c35f33132f..9f3607369c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -162,7 +162,6 @@ trait HadoopFsRelationProvider {
partitionColumns: Option[StructType],
parameters: Map[String, String]): HadoopFsRelation
- // TODO: expose bucket API to users.
private[sql] def createRelation(
sqlContext: SQLContext,
paths: Array[String],
@@ -370,7 +369,6 @@ abstract class OutputWriterFactory extends Serializable {
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter
- // TODO: expose bucket API to users.
private[sql] def newInstance(
path: String,
bucketId: Option[Int],
@@ -460,7 +458,6 @@ abstract class HadoopFsRelation private[sql](
private var _partitionSpec: PartitionSpec = _
- // TODO: expose bucket API to users.
private[sql] def bucketSpec: Option[BucketSpec] = None
private class FileStatusCache {