-rw-r--r--   sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala | 192
-rw-r--r--   sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala                    |   3
2 files changed, 48 insertions, 147 deletions
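In short, this change removes the hash-based write path in DynamicPartitionWriterContainer (a HashMap of open OutputWriters, with a sort-based fallback once maxOpenFiles was reached) and always sorts the incoming rows by partition columns, bucket id, and sort columns before writing, so that at most one writer is open at a time. The following is a minimal, self-contained sketch of that sort-then-stream pattern in plain Scala; it is not Spark code, and every name in it (Row, SortBasedWriteSketch, the println "writer") is illustrative only.

// Plain-Scala sketch of the sort-based write pattern adopted by this change:
// sort the rows so that everything destined for one output file is adjacent,
// then stream them out, switching "writers" only when the grouping key changes.
object SortBasedWriteSketch {
  final case class Row(partition: String, bucketId: Int, value: String)

  def main(args: Array[String]): Unit = {
    val rows = Seq(
      Row("date=2016-01-02", 1, "b"),
      Row("date=2016-01-01", 0, "a"),
      Row("date=2016-01-01", 0, "c"),
      Row("date=2016-01-02", 0, "d"))

    // Sort by (partition, bucketId); with bucketing plus sort columns, the sort key
    // would also include the sort columns, while the writer-switch key would not.
    val sorted = rows.sortBy(r => (r.partition, r.bucketId))

    var currentKey: Option[(String, Int)] = None
    sorted.foreach { r =>
      val key = (r.partition, r.bucketId)
      if (!currentKey.contains(key)) {
        // In the real code this is where the previous OutputWriter is closed
        // and a new one is opened for the next partition/bucket.
        currentKey = Some(key)
        println(s"-- open writer for ${r.partition}, bucket ${r.bucketId}")
      }
      println(s"   write ${r.value}")
    }
  }
}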
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
index 40ecdb8e44..fff72872c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.UnsafeKVExternalSorter
 import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory}
-import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 import org.apache.spark.util.SerializableConfiguration
 
@@ -349,67 +349,6 @@ private[sql] class DynamicPartitionWriterContainer(
     }
   }
 
-  private def sameBucket(key1: UnsafeRow, key2: UnsafeRow): Boolean = {
-    val bucketIdIndex = partitionColumns.length
-    if (key1.getInt(bucketIdIndex) != key2.getInt(bucketIdIndex)) {
-      false
-    } else {
-      var i = partitionColumns.length - 1
-      while (i >= 0) {
-        val dt = partitionColumns(i).dataType
-        if (key1.get(i, dt) != key2.get(i, dt)) return false
-        i -= 1
-      }
-      true
-    }
-  }
-
-  private def sortBasedWrite(
-      sorter: UnsafeKVExternalSorter,
-      iterator: Iterator[InternalRow],
-      getSortingKey: UnsafeProjection,
-      getOutputRow: UnsafeProjection,
-      getPartitionString: UnsafeProjection,
-      outputWriters: java.util.HashMap[InternalRow, OutputWriter]): Unit = {
-    while (iterator.hasNext) {
-      val currentRow = iterator.next()
-      sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
-    }
-
-    logInfo(s"Sorting complete. Writing out partition files one at a time.")
-
-    val needNewWriter: (UnsafeRow, UnsafeRow) => Boolean = if (sortColumns.isEmpty) {
-      (key1, key2) => key1 != key2
-    } else {
-      (key1, key2) => key1 == null || !sameBucket(key1, key2)
-    }
-
-    val sortedIterator = sorter.sortedIterator()
-    var currentKey: UnsafeRow = null
-    var currentWriter: OutputWriter = null
-    try {
-      while (sortedIterator.next()) {
-        if (needNewWriter(currentKey, sortedIterator.getKey)) {
-          if (currentWriter != null) {
-            currentWriter.close()
-          }
-          currentKey = sortedIterator.getKey.copy()
-          logDebug(s"Writing partition: $currentKey")
-
-          // Either use an existing file from before, or open a new one.
-          currentWriter = outputWriters.remove(currentKey)
-          if (currentWriter == null) {
-            currentWriter = newOutputWriter(currentKey, getPartitionString)
-          }
-        }
-
-        currentWriter.writeInternal(sortedIterator.getValue)
-      }
-    } finally {
-      if (currentWriter != null) { currentWriter.close() }
-    }
-  }
-
   /**
    * Open and returns a new OutputWriter given a partition key and optional bucket id.
    * If bucket id is specified, we will append it to the end of the file name, but before the
@@ -435,22 +374,18 @@ private[sql] class DynamicPartitionWriterContainer(
   }
 
   def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
-    val outputWriters = new java.util.HashMap[InternalRow, OutputWriter]
     executorSideSetup(taskContext)
 
-    var outputWritersCleared = false
-
     // We should first sort by partition columns, then bucket id, and finally sorting columns.
-    val getSortingKey =
-      UnsafeProjection.create(partitionColumns ++ bucketIdExpression ++ sortColumns, inputSchema)
-
-    val sortingKeySchema = if (bucketSpec.isEmpty) {
-      StructType.fromAttributes(partitionColumns)
-    } else { // If it's bucketed, we should also consider bucket id as part of the key.
-      val fields = StructType.fromAttributes(partitionColumns)
-        .add("bucketId", IntegerType, nullable = false) ++ StructType.fromAttributes(sortColumns)
-      StructType(fields)
-    }
+    val sortingExpressions: Seq[Expression] = partitionColumns ++ bucketIdExpression ++ sortColumns
+
+    val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema)
+
+    val sortingKeySchema = StructType(sortingExpressions.map {
+      case a: Attribute => StructField(a.name, a.dataType, a.nullable)
+      // The sorting expressions are all `Attribute` except bucket id.
+      case _ => StructField("bucketId", IntegerType, nullable = false)
+    })
 
     // Returns the data columns to be written given an input row
     val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)
@@ -461,54 +396,49 @@ private[sql] class DynamicPartitionWriterContainer(
 
     // If anything below fails, we should abort the task.
     try {
-      // If there is no sorting columns, we set sorter to null and try the hash-based writing first,
-      // and fill the sorter if there are too many writers and we need to fall back on sorting.
-      // If there are sorting columns, then we have to sort the data anyway, and no need to try the
-      // hash-based writing first.
-      var sorter: UnsafeKVExternalSorter = if (sortColumns.nonEmpty) {
-        new UnsafeKVExternalSorter(
-          sortingKeySchema,
-          StructType.fromAttributes(dataColumns),
-          SparkEnv.get.blockManager,
-          TaskContext.get().taskMemoryManager().pageSizeBytes)
+      // Sorts the data before write, so that we only need one writer at the same time.
+      // TODO: inject a local sort operator in planning.
+      val sorter = new UnsafeKVExternalSorter(
+        sortingKeySchema,
+        StructType.fromAttributes(dataColumns),
+        SparkEnv.get.blockManager,
+        TaskContext.get().taskMemoryManager().pageSizeBytes)
+
+      while (iterator.hasNext) {
+        val currentRow = iterator.next()
+        sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
+      }
+
+      logInfo(s"Sorting complete. Writing out partition files one at a time.")
+
+      val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
+        identity
       } else {
-        null
+        UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
+          case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
+        })
       }
 
-      while (iterator.hasNext && sorter == null) {
-        val inputRow = iterator.next()
-        // When we reach here, the `sortColumns` must be empty, so the sorting key is hashing key.
-        val currentKey = getSortingKey(inputRow)
-        var currentWriter = outputWriters.get(currentKey)
-
-        if (currentWriter == null) {
-          if (outputWriters.size < maxOpenFiles) {
+
+      val sortedIterator = sorter.sortedIterator()
+      var currentKey: UnsafeRow = null
+      var currentWriter: OutputWriter = null
+      try {
+        while (sortedIterator.next()) {
+          val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
+          if (currentKey != nextKey) {
+            if (currentWriter != null) {
+              currentWriter.close()
+            }
+            currentKey = nextKey.copy()
+            logDebug(s"Writing partition: $currentKey")
+
             currentWriter = newOutputWriter(currentKey, getPartitionString)
-            outputWriters.put(currentKey.copy(), currentWriter)
-            currentWriter.writeInternal(getOutputRow(inputRow))
-          } else {
-            logInfo(s"Maximum partitions reached, falling back on sorting.")
-            sorter = new UnsafeKVExternalSorter(
-              sortingKeySchema,
-              StructType.fromAttributes(dataColumns),
-              SparkEnv.get.blockManager,
-              TaskContext.get().taskMemoryManager().pageSizeBytes)
-            sorter.insertKV(currentKey, getOutputRow(inputRow))
           }
-        } else {
-          currentWriter.writeInternal(getOutputRow(inputRow))
-        }
-      }
 
-      // If the sorter is not null that means that we reached the maxFiles above and need to finish
-      // using external sort, or there are sorting columns and we need to sort the whole data set.
-      if (sorter != null) {
-        sortBasedWrite(
-          sorter,
-          iterator,
-          getSortingKey,
-          getOutputRow,
-          getPartitionString,
-          outputWriters)
+          currentWriter.writeInternal(sortedIterator.getValue)
+        }
+      } finally {
+        if (currentWriter != null) { currentWriter.close() }
       }
 
       commitTask()
@@ -518,31 +448,5 @@ private[sql] class DynamicPartitionWriterContainer(
         abortTask()
         throw new SparkException("Task failed while writing rows.", cause)
     }
-
-    def clearOutputWriters(): Unit = {
-      if (!outputWritersCleared) {
-        outputWriters.asScala.values.foreach(_.close())
-        outputWriters.clear()
-        outputWritersCleared = true
-      }
-    }
-
-    def commitTask(): Unit = {
-      try {
-        clearOutputWriters()
-        super.commitTask()
-      } catch {
-        case cause: Throwable =>
-          throw new RuntimeException("Failed to commit task", cause)
-      }
-    }
-
-    def abortTask(): Unit = {
-      try {
-        clearOutputWriters()
-      } finally {
-        super.abortTask()
-      }
-    }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index c35f33132f..9f3607369c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -162,7 +162,6 @@ trait HadoopFsRelationProvider {
       partitionColumns: Option[StructType],
       parameters: Map[String, String]): HadoopFsRelation
 
-  // TODO: expose bucket API to users.
   private[sql] def createRelation(
       sqlContext: SQLContext,
       paths: Array[String],
@@ -370,7 +369,6 @@ abstract class OutputWriterFactory extends Serializable {
       dataSchema: StructType,
       context: TaskAttemptContext): OutputWriter
 
-  // TODO: expose bucket API to users.
   private[sql] def newInstance(
       path: String,
       bucketId: Option[Int],
@@ -460,7 +458,6 @@ abstract class HadoopFsRelation private[sql](
 
   private var _partitionSpec: PartitionSpec = _
 
-  // TODO: expose bucket API to users.
   private[sql] def bucketSpec: Option[BucketSpec] = None
 
   private class FileStatusCache {
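One subtlety worth calling out in the new writeRows: rows are sorted by (partition columns ++ bucketIdExpression ++ sortColumns), but a new output file is only needed when the partition/bucket prefix of the key changes, which is why getBucketingKey projects away the trailing sort columns via sortingExpressions.dropRight(sortColumns.length). A tiny plain-Scala illustration of that idea follows; the helper and values are hypothetical stand-ins, not Spark's API.

object BucketingKeySketch extends App {
  // Why the writer-switch key drops the trailing sort columns: two sorted keys that
  // differ only in their sort-column suffix still belong to the same output file.
  def bucketingKey(sortingKey: Seq[Any], numSortColumns: Int): Seq[Any] =
    sortingKey.dropRight(numSortColumns)

  val k1 = Seq("date=2016-01-01", 3, "alice")  // partition value, bucket id, sort column
  val k2 = Seq("date=2016-01-01", 3, "bob")
  val k3 = Seq("date=2016-01-01", 4, "carol")

  assert(bucketingKey(k1, 1) == bucketingKey(k2, 1))  // same bucket: keep the current writer
  assert(bucketingKey(k1, 1) != bucketingKey(k3, 1))  // bucket changed: open a new writer
}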