Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala | 22
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala        | 31
2 files changed, 46 insertions(+), 7 deletions(-)
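The FileFormatWriter hunk below switches from the two-argument `SparkContext.runJob` call (which returns all task results at once) to the overload that takes an explicit partition list and a driver-side result handler, so the committer can observe each task's commit message as soon as that task finishes. As a minimal standalone sketch of that overload (the RDD contents and the summing function here are illustrative, not from the patch):

import org.apache.spark.{SparkContext, TaskContext}

// Sketch of the runJob overload the patch uses. The last argument runs on
// the driver once per completed task; the patch uses that callback to invoke
// committer.onTaskCommit(res.commitMsg) and to store each WriteTaskResult.
def perPartitionSums(sc: SparkContext): Array[Int] = {
  val rdd = sc.parallelize(1 to 100, numSlices = 10)
  val sums = new Array[Int](rdd.partitions.length)
  sc.runJob(
    rdd,
    (_: TaskContext, iter: Iterator[Int]) => iter.sum,  // runs on executors
    0 until rdd.partitions.length,                      // which partitions to run
    (index: Int, sum: Int) => sums(index) = sum)        // runs on the driver
  sums
}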
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 7957224ce4..bda64d4b91 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -80,6 +80,9 @@ object FileFormatWriter extends Logging {
       """.stripMargin)
   }
 
+  /** The result of a successful write task. */
+  private case class WriteTaskResult(commitMsg: TaskCommitMessage, updatedPartitions: Set[String])
+
   /**
    * Basic work flow of this command is:
    * 1. Driver side setup, including output committer initialization and data source specific
@@ -172,8 +175,9 @@ object FileFormatWriter extends Logging {
           global = false,
           child = queryExecution.executedPlan).execute()
       }
-
-      val ret = sparkSession.sparkContext.runJob(rdd,
+      val ret = new Array[WriteTaskResult](rdd.partitions.length)
+      sparkSession.sparkContext.runJob(
+        rdd,
         (taskContext: TaskContext, iter: Iterator[InternalRow]) => {
           executeTask(
             description = description,
@@ -182,10 +186,16 @@
             sparkAttemptNumber = taskContext.attemptNumber(),
             committer,
             iterator = iter)
+        },
+        0 until rdd.partitions.length,
+        (index, res: WriteTaskResult) => {
+          committer.onTaskCommit(res.commitMsg)
+          ret(index) = res
         })
 
-      val commitMsgs = ret.map(_._1)
-      val updatedPartitions = ret.flatMap(_._2).distinct.map(PartitioningUtils.parsePathFragment)
+      val commitMsgs = ret.map(_.commitMsg)
+      val updatedPartitions = ret.flatMap(_.updatedPartitions)
+        .distinct.map(PartitioningUtils.parsePathFragment)
 
       committer.commitJob(job, commitMsgs)
       logInfo(s"Job ${job.getJobID} committed.")
@@ -205,7 +215,7 @@
       sparkPartitionId: Int,
       sparkAttemptNumber: Int,
       committer: FileCommitProtocol,
-      iterator: Iterator[InternalRow]): (TaskCommitMessage, Set[String]) = {
+      iterator: Iterator[InternalRow]): WriteTaskResult = {
 
     val jobId = SparkHadoopWriterUtils.createJobID(new Date, sparkStageId)
     val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId)
@@ -238,7 +248,7 @@
       // Execute the task to write rows out and commit the task.
       val outputPartitions = writeTask.execute(iterator)
       writeTask.releaseResources()
-      (committer.commitTask(taskAttemptContext), outputPartitions)
+      WriteTaskResult(committer.commitTask(taskAttemptContext), outputPartitions)
     })(catchBlock = {
       // If there is an error, release resource and then abort the task
       try {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
index 8287776f8f..7c71e7280c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
@@ -18,9 +18,12 @@
 package org.apache.spark.sql.test
 
 import java.io.File
+import java.util.concurrent.ConcurrentLinkedQueue
 
 import org.scalatest.BeforeAndAfter
 
+import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
+import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.sources._
@@ -41,7 +44,6 @@ object LastOptions {
   }
 }
 
-
 /** Dummy provider. */
 class DefaultSource
   extends RelationProvider
@@ -107,6 +109,20 @@ class DefaultSourceWithoutUserSpecifiedSchema
   }
 }
 
+object MessageCapturingCommitProtocol {
+  val commitMessages = new ConcurrentLinkedQueue[TaskCommitMessage]()
+}
+
+class MessageCapturingCommitProtocol(jobId: String, path: String)
+    extends HadoopMapReduceCommitProtocol(jobId, path) {
+
+  // captures commit messages for testing
+  override def onTaskCommit(msg: TaskCommitMessage): Unit = {
+    MessageCapturingCommitProtocol.commitMessages.offer(msg)
+  }
+}
+
+
 class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with BeforeAndAfter {
   import testImplicits._
 
@@ -291,6 +307,19 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with BeforeAndAfter {
     Option(dir).map(spark.read.format("org.apache.spark.sql.test").load)
   }
 
+  test("write path implements onTaskCommit API correctly") {
+    withSQLConf(
+        "spark.sql.sources.commitProtocolClass" ->
+          classOf[MessageCapturingCommitProtocol].getCanonicalName) {
+      withTempDir { dir =>
+        val path = dir.getCanonicalPath
+        MessageCapturingCommitProtocol.commitMessages.clear()
+        spark.range(10).repartition(10).write.mode("overwrite").parquet(path)
+        assert(MessageCapturingCommitProtocol.commitMessages.size() == 10)
+      }
+    }
+  }
+
   test("read a data source that does not extend SchemaRelationProvider") {
     val dfReader = spark.read
      .option("from", "1")
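For context on how the new hook is meant to be consumed outside the test suite: a user-supplied FileCommitProtocol subclass can override onTaskCommit and be wired in through the same spark.sql.sources.commitProtocolClass setting the test uses. A hedged sketch follows; AuditingCommitProtocol is a hypothetical name, and the two-argument HadoopMapReduceCommitProtocol constructor matches the revision in this diff but may differ in later releases.

import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol

// Hypothetical subscriber: reacts on the driver each time a write task commits.
class AuditingCommitProtocol(jobId: String, path: String)
    extends HadoopMapReduceCommitProtocol(jobId, path) {

  override def onTaskCommit(msg: TaskCommitMessage): Unit = {
    // Invoked through the driver-side result handler, in task-completion order.
    println(s"Task committed for job $jobId: $msg")
  }
}

// Wiring, mirroring the test:
//   spark.conf.set("spark.sql.sources.commitProtocolClass",
//     classOf[AuditingCommitProtocol].getCanonicalName)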