aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2017-02-19 18:13:12 -0800
committerWenchen Fan <wenchen@databricks.com>2017-02-19 18:13:12 -0800
commit776b8f17cfc687a57c005a421a81e591c8d44a3f (patch)
tree7b034741adc5f765674e7ff6d3f303950a20c2cd /sql
parent65fe902e13153ad73a3026a66e73c93393df1abb (diff)
downloadspark-776b8f17cfc687a57c005a421a81e591c8d44a3f.tar.gz
spark-776b8f17cfc687a57c005a421a81e591c8d44a3f.tar.bz2
spark-776b8f17cfc687a57c005a421a81e591c8d44a3f.zip
[SPARK-19563][SQL] avoid unnecessary sort in FileFormatWriter
## What changes were proposed in this pull request? In `FileFormatWriter`, we will sort the input rows by partition columns and bucket id and sort columns, if we want to write data out partitioned or bucketed. However, if the data is already sorted, we will sort it again, which is unnecssary. This PR removes the sorting logic in `FileFormatWriter` and use `SortExec` instead. We will not add `SortExec` if the data is already sorted. ## How was this patch tested? I did a micro benchmark manually ``` val df = spark.range(10000000).select($"id", $"id" % 10 as "part").sort("part") spark.time(df.write.partitionBy("part").parquet("/tmp/test")) ``` The result was about 6.4 seconds before this PR, and is 5.7 seconds afterwards. close https://github.com/apache/spark/pull/16724 Author: Wenchen Fan <wenchen@databricks.com> Closes #16898 from cloud-fan/writer.
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala189
1 files changed, 90 insertions, 99 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index be13cbc51a..644358493e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -38,10 +38,9 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.{QueryExecution, SQLExecution, UnsafeKVExternalSorter}
-import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution}
+import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.util.{SerializableConfiguration, Utils}
-import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
/** A helper object for writing FileFormat data out to a location. */
@@ -64,9 +63,9 @@ object FileFormatWriter extends Logging {
val serializableHadoopConf: SerializableConfiguration,
val outputWriterFactory: OutputWriterFactory,
val allColumns: Seq[Attribute],
- val partitionColumns: Seq[Attribute],
val dataColumns: Seq[Attribute],
- val bucketSpec: Option[BucketSpec],
+ val partitionColumns: Seq[Attribute],
+ val bucketIdExpression: Option[Expression],
val path: String,
val customPartitionLocations: Map[TablePartitionSpec, String],
val maxRecordsPerFile: Long)
@@ -108,9 +107,21 @@ object FileFormatWriter extends Logging {
job.setOutputValueClass(classOf[InternalRow])
FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath))
+ val allColumns = queryExecution.logical.output
val partitionSet = AttributeSet(partitionColumns)
val dataColumns = queryExecution.logical.output.filterNot(partitionSet.contains)
+ val bucketIdExpression = bucketSpec.map { spec =>
+ val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
+ // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
+ // guarantee the data distribution is same between shuffle and bucketed data source, which
+ // enables us to only shuffle one side when join a bucketed table and a normal one.
+ HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
+ }
+ val sortColumns = bucketSpec.toSeq.flatMap {
+ spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
+ }
+
// Note: prepareWrite has side effect. It sets "job".
val outputWriterFactory =
fileFormat.prepareWrite(sparkSession, job, options, dataColumns.toStructType)
@@ -119,23 +130,45 @@ object FileFormatWriter extends Logging {
uuid = UUID.randomUUID().toString,
serializableHadoopConf = new SerializableConfiguration(job.getConfiguration),
outputWriterFactory = outputWriterFactory,
- allColumns = queryExecution.logical.output,
- partitionColumns = partitionColumns,
+ allColumns = allColumns,
dataColumns = dataColumns,
- bucketSpec = bucketSpec,
+ partitionColumns = partitionColumns,
+ bucketIdExpression = bucketIdExpression,
path = outputSpec.outputPath,
customPartitionLocations = outputSpec.customPartitionLocations,
maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong)
.getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile)
)
+ // We should first sort by partition columns, then bucket id, and finally sorting columns.
+ val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
+ // the sort order doesn't matter
+ val actualOrdering = queryExecution.executedPlan.outputOrdering.map(_.child)
+ val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
+ false
+ } else {
+ requiredOrdering.zip(actualOrdering).forall {
+ case (requiredOrder, childOutputOrder) =>
+ requiredOrder.semanticEquals(childOutputOrder)
+ }
+ }
+
SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
// This call shouldn't be put into the `try` block below because it only initializes and
// prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
committer.setupJob(job)
try {
- val ret = sparkSession.sparkContext.runJob(queryExecution.toRdd,
+ val rdd = if (orderingMatched) {
+ queryExecution.toRdd
+ } else {
+ SortExec(
+ requiredOrdering.map(SortOrder(_, Ascending)),
+ global = false,
+ child = queryExecution.executedPlan).execute()
+ }
+
+ val ret = sparkSession.sparkContext.runJob(rdd,
(taskContext: TaskContext, iter: Iterator[InternalRow]) => {
executeTask(
description = description,
@@ -189,7 +222,7 @@ object FileFormatWriter extends Logging {
committer.setupTask(taskAttemptContext)
val writeTask =
- if (description.partitionColumns.isEmpty && description.bucketSpec.isEmpty) {
+ if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) {
new SingleDirectoryWriteTask(description, taskAttemptContext, committer)
} else {
new DynamicPartitionWriteTask(description, taskAttemptContext, committer)
@@ -287,31 +320,16 @@ object FileFormatWriter extends Logging {
* multiple directories (partitions) or files (bucketing).
*/
private class DynamicPartitionWriteTask(
- description: WriteJobDescription,
+ desc: WriteJobDescription,
taskAttemptContext: TaskAttemptContext,
committer: FileCommitProtocol) extends ExecuteWriteTask {
// currentWriter is initialized whenever we see a new key
private var currentWriter: OutputWriter = _
- private val bucketColumns: Seq[Attribute] = description.bucketSpec.toSeq.flatMap {
- spec => spec.bucketColumnNames.map(c => description.allColumns.find(_.name == c).get)
- }
-
- private val sortColumns: Seq[Attribute] = description.bucketSpec.toSeq.flatMap {
- spec => spec.sortColumnNames.map(c => description.allColumns.find(_.name == c).get)
- }
-
- private def bucketIdExpression: Option[Expression] = description.bucketSpec.map { spec =>
- // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
- // guarantee the data distribution is same between shuffle and bucketed data source, which
- // enables us to only shuffle one side when join a bucketed table and a normal one.
- HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
- }
-
- /** Expressions that given a partition key build a string like: col1=val/col2=val/... */
- private def partitionStringExpression: Seq[Expression] = {
- description.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
+ /** Expressions that given partition columns build a path string like: col1=val/col2=val/... */
+ private def partitionPathExpression: Seq[Expression] = {
+ desc.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
// TODO: use correct timezone for partition values.
val escaped = ScalaUDF(
ExternalCatalogUtils.escapePathName _,
@@ -325,35 +343,46 @@ object FileFormatWriter extends Logging {
}
/**
- * Open and returns a new OutputWriter given a partition key and optional bucket id.
+ * Opens a new OutputWriter given a partition key and optional bucket id.
* If bucket id is specified, we will append it to the end of the file name, but before the
* file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
*
- * @param key vaues for fields consisting of partition keys for the current row
- * @param partString a function that projects the partition values into a string
+ * @param partColsAndBucketId a row consisting of partition columns and a bucket id for the
+ * current row.
+ * @param getPartitionPath a function that projects the partition values into a path string.
* @param fileCounter the number of files that have been written in the past for this specific
* partition. This is used to limit the max number of records written for a
* single file. The value should start from 0.
+ * @param updatedPartitions the set of updated partition paths, we should add the new partition
+ * path of this writer to it.
*/
private def newOutputWriter(
- key: InternalRow, partString: UnsafeProjection, fileCounter: Int): Unit = {
- val partDir =
- if (description.partitionColumns.isEmpty) None else Option(partString(key).getString(0))
+ partColsAndBucketId: InternalRow,
+ getPartitionPath: UnsafeProjection,
+ fileCounter: Int,
+ updatedPartitions: mutable.Set[String]): Unit = {
+ val partDir = if (desc.partitionColumns.isEmpty) {
+ None
+ } else {
+ Option(getPartitionPath(partColsAndBucketId).getString(0))
+ }
+ partDir.foreach(updatedPartitions.add)
- // If the bucket spec is defined, the bucket column is right after the partition columns
- val bucketId = if (description.bucketSpec.isDefined) {
- BucketingUtils.bucketIdToString(key.getInt(description.partitionColumns.length))
+ // If the bucketId expression is defined, the bucketId column is right after the partition
+ // columns.
+ val bucketId = if (desc.bucketIdExpression.isDefined) {
+ BucketingUtils.bucketIdToString(partColsAndBucketId.getInt(desc.partitionColumns.length))
} else {
""
}
// This must be in a form that matches our bucketing format. See BucketingUtils.
val ext = f"$bucketId.c$fileCounter%03d" +
- description.outputWriterFactory.getFileExtension(taskAttemptContext)
+ desc.outputWriterFactory.getFileExtension(taskAttemptContext)
val customPath = partDir match {
case Some(dir) =>
- description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
+ desc.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
case _ =>
None
}
@@ -363,80 +392,42 @@ object FileFormatWriter extends Logging {
committer.newTaskTempFile(taskAttemptContext, partDir, ext)
}
- currentWriter = description.outputWriterFactory.newInstance(
+ currentWriter = desc.outputWriterFactory.newInstance(
path = path,
- dataSchema = description.dataColumns.toStructType,
+ dataSchema = desc.dataColumns.toStructType,
context = taskAttemptContext)
}
override def execute(iter: Iterator[InternalRow]): Set[String] = {
- // We should first sort by partition columns, then bucket id, and finally sorting columns.
- val sortingExpressions: Seq[Expression] =
- description.partitionColumns ++ bucketIdExpression ++ sortColumns
- val getSortingKey = UnsafeProjection.create(sortingExpressions, description.allColumns)
-
- val sortingKeySchema = StructType(sortingExpressions.map {
- case a: Attribute => StructField(a.name, a.dataType, a.nullable)
- // The sorting expressions are all `Attribute` except bucket id.
- case _ => StructField("bucketId", IntegerType, nullable = false)
- })
-
- // Returns the data columns to be written given an input row
- val getOutputRow = UnsafeProjection.create(
- description.dataColumns, description.allColumns)
-
- // Returns the partition path given a partition key.
- val getPartitionStringFunc = UnsafeProjection.create(
- Seq(Concat(partitionStringExpression)), description.partitionColumns)
-
- // Sorts the data before write, so that we only need one writer at the same time.
- val sorter = new UnsafeKVExternalSorter(
- sortingKeySchema,
- StructType.fromAttributes(description.dataColumns),
- SparkEnv.get.blockManager,
- SparkEnv.get.serializerManager,
- TaskContext.get().taskMemoryManager().pageSizeBytes,
- SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
- UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD))
-
- while (iter.hasNext) {
- val currentRow = iter.next()
- sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
- }
+ val getPartitionColsAndBucketId = UnsafeProjection.create(
+ desc.partitionColumns ++ desc.bucketIdExpression, desc.allColumns)
- val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
- identity
- } else {
- UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
- case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
- })
- }
+ // Generates the partition path given the row generated by `getPartitionColsAndBucketId`.
+ val getPartPath = UnsafeProjection.create(
+ Seq(Concat(partitionPathExpression)), desc.partitionColumns)
- val sortedIterator = sorter.sortedIterator()
+ // Returns the data columns to be written given an input row
+ val getOutputRow = UnsafeProjection.create(desc.dataColumns, desc.allColumns)
// If anything below fails, we should abort the task.
var recordsInFile: Long = 0L
var fileCounter = 0
- var currentKey: UnsafeRow = null
+ var currentPartColsAndBucketId: UnsafeRow = null
val updatedPartitions = mutable.Set[String]()
- while (sortedIterator.next()) {
- val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
- if (currentKey != nextKey) {
- // See a new key - write to a new partition (new file).
- currentKey = nextKey.copy()
- logDebug(s"Writing partition: $currentKey")
+ for (row <- iter) {
+ val nextPartColsAndBucketId = getPartitionColsAndBucketId(row)
+ if (currentPartColsAndBucketId != nextPartColsAndBucketId) {
+ // See a new partition or bucket - write to a new partition dir (or a new bucket file).
+ currentPartColsAndBucketId = nextPartColsAndBucketId.copy()
+ logDebug(s"Writing partition: $currentPartColsAndBucketId")
recordsInFile = 0
fileCounter = 0
releaseResources()
- newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
- val partitionPath = getPartitionStringFunc(currentKey).getString(0)
- if (partitionPath.nonEmpty) {
- updatedPartitions.add(partitionPath)
- }
- } else if (description.maxRecordsPerFile > 0 &&
- recordsInFile >= description.maxRecordsPerFile) {
+ newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions)
+ } else if (desc.maxRecordsPerFile > 0 &&
+ recordsInFile >= desc.maxRecordsPerFile) {
// Exceeded the threshold in terms of the number of records per file.
// Create a new file by increasing the file counter.
recordsInFile = 0
@@ -445,10 +436,10 @@ object FileFormatWriter extends Logging {
s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
releaseResources()
- newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
+ newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions)
}
- currentWriter.write(sortedIterator.getValue)
+ currentWriter.write(getOutputRow(row))
recordsInFile += 1
}
releaseResources()