author     Eric Liang <ekl@databricks.com>     2017-03-29 20:59:48 -0700
committer  Reynold Xin <rxin@databricks.com>   2017-03-29 20:59:48 -0700
commit     79636054f60dd639e9d326e1328717e97df13304 (patch)
tree       4d0ab1db8bbbaf4864dfe10904645d7fe94ba339
parent     60977889eaecdf28adc6164310eaa5afed488fa1 (diff)
[SPARK-20148][SQL] Extend the file commit API to allow subscribing to task commit messages
## What changes were proposed in this pull request?

The internal FileCommitProtocol interface returns all task commit messages in bulk to the implementation only when a job finishes. However, it is sometimes useful to access those messages before the job completes, so that the driver can report incremental progress. This adds an `onTaskCommit` listener to the internal API.

## How was this patch tested?

Unit tests.

cc rxin

Author: Eric Liang <ekl@databricks.com>

Closes #17475 from ericl/file-commit-api-ext.
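For illustration, the snippet below is a minimal, hypothetical sketch of a protocol implementation that uses the new hook to report progress as tasks commit; the class name ProgressReportingCommitProtocol and its counter are not part of this patch.

import java.util.concurrent.atomic.AtomicInteger

import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol

// Hypothetical subclass: counts committed tasks as their messages arrive on the driver.
class ProgressReportingCommitProtocol(jobId: String, path: String)
  extends HadoopMapReduceCommitProtocol(jobId, path) {

  private val committedTasks = new AtomicInteger(0)

  override def onTaskCommit(taskCommit: TaskCommitMessage): Unit = {
    // Called on the driver once per committed task, before commitJob() runs.
    val done = committedTasks.incrementAndGet()
    logInfo(s"$done write tasks have committed so far")
  }
}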
-rw-r--r--  core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala                    7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala   22
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala          31
3 files changed, 53 insertions(+), 7 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
index 2394cf361c..7efa941636 100644
--- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
@@ -121,6 +121,13 @@ abstract class FileCommitProtocol {
def deleteWithJob(fs: FileSystem, path: Path, recursive: Boolean): Boolean = {
fs.delete(path, recursive)
}
+
+ /**
+ * Called on the driver after a task commits. This can be used to access task commit messages
+ * before the job has finished. These same task commit messages will be passed to commitJob()
+ * if the entire job succeeds.
+ */
+ def onTaskCommit(taskCommit: TaskCommitMessage): Unit = {}
}
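The doc comment above states that the messages seen by onTaskCommit are the same ones later handed to commitJob() when the whole job succeeds. The following is an illustrative sketch of a subclass that records both sides of that contract; the class name and log message are assumptions, not part of this change.

import scala.collection.mutable.ArrayBuffer

import org.apache.hadoop.mapreduce.JobContext

import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol

// Hypothetical subclass: tracks incremental messages and compares them with the bulk
// list that commitJob() receives at the end of a successful job.
class ContractCheckingCommitProtocol(jobId: String, path: String)
  extends HadoopMapReduceCommitProtocol(jobId, path) {

  private val seen = new ArrayBuffer[TaskCommitMessage]()

  override def onTaskCommit(taskCommit: TaskCommitMessage): Unit = {
    seen += taskCommit  // driver-side, one call per committed task
  }

  override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = {
    // For this write path the bulk list should match what was seen incrementally.
    logInfo(s"commitJob got ${taskCommits.size} messages; onTaskCommit saw ${seen.size}")
    super.commitJob(jobContext, taskCommits)
  }
}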
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 7957224ce4..bda64d4b91 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -80,6 +80,9 @@ object FileFormatWriter extends Logging {
""".stripMargin)
}
+ /** The result of a successful write task. */
+ private case class WriteTaskResult(commitMsg: TaskCommitMessage, updatedPartitions: Set[String])
+
/**
* Basic work flow of this command is:
* 1. Driver side setup, including output committer initialization and data source specific
@@ -172,8 +175,9 @@ object FileFormatWriter extends Logging {
global = false,
child = queryExecution.executedPlan).execute()
}
-
- val ret = sparkSession.sparkContext.runJob(rdd,
+ val ret = new Array[WriteTaskResult](rdd.partitions.length)
+ sparkSession.sparkContext.runJob(
+ rdd,
(taskContext: TaskContext, iter: Iterator[InternalRow]) => {
executeTask(
description = description,
@@ -182,10 +186,16 @@ object FileFormatWriter extends Logging {
sparkAttemptNumber = taskContext.attemptNumber(),
committer,
iterator = iter)
+ },
+ 0 until rdd.partitions.length,
+ (index, res: WriteTaskResult) => {
+ committer.onTaskCommit(res.commitMsg)
+ ret(index) = res
})
- val commitMsgs = ret.map(_._1)
- val updatedPartitions = ret.flatMap(_._2).distinct.map(PartitioningUtils.parsePathFragment)
+ val commitMsgs = ret.map(_.commitMsg)
+ val updatedPartitions = ret.flatMap(_.updatedPartitions)
+ .distinct.map(PartitioningUtils.parsePathFragment)
committer.commitJob(job, commitMsgs)
logInfo(s"Job ${job.getJobID} committed.")
@@ -205,7 +215,7 @@ object FileFormatWriter extends Logging {
sparkPartitionId: Int,
sparkAttemptNumber: Int,
committer: FileCommitProtocol,
- iterator: Iterator[InternalRow]): (TaskCommitMessage, Set[String]) = {
+ iterator: Iterator[InternalRow]): WriteTaskResult = {
val jobId = SparkHadoopWriterUtils.createJobID(new Date, sparkStageId)
val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId)
@@ -238,7 +248,7 @@ object FileFormatWriter extends Logging {
// Execute the task to write rows out and commit the task.
val outputPartitions = writeTask.execute(iterator)
writeTask.releaseResources()
- (committer.commitTask(taskAttemptContext), outputPartitions)
+ WriteTaskResult(committer.commitTask(taskAttemptContext), outputPartitions)
})(catchBlock = {
// If there is an error, release resource and then abort the task
try {
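The FileFormatWriter change above leans on the SparkContext.runJob overload that takes an explicit result handler, which runs on the driver as each partition's result arrives; that is what makes the per-task onTaskCommit call possible. Below is a self-contained, hypothetical sketch of that pattern (the object name, data, and println output are illustrative only, not from this patch):

import org.apache.spark.TaskContext
import org.apache.spark.sql.SparkSession

object RunJobResultHandlerSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("runJob-sketch").getOrCreate()
    val sc = spark.sparkContext
    val rdd = sc.parallelize(1 to 100, numSlices = 4)

    val perPartitionSums = new Array[Int](rdd.partitions.length)
    sc.runJob(
      rdd,
      // Runs on the executors, one invocation per partition.
      (_: TaskContext, iter: Iterator[Int]) => iter.sum,
      0 until rdd.partitions.length,
      // Runs on the driver as soon as each partition's result comes back,
      // analogous to the committer.onTaskCommit(res.commitMsg) call above.
      (index: Int, partitionSum: Int) => {
        println(s"partition $index finished with sum $partitionSum")
        perPartitionSums(index) = partitionSum
      })

    println(s"total = ${perPartitionSums.sum}")
    spark.stop()
  }
}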
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
index 8287776f8f..7c71e7280c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
@@ -18,9 +18,12 @@
package org.apache.spark.sql.test
import java.io.File
+import java.util.concurrent.ConcurrentLinkedQueue
import org.scalatest.BeforeAndAfter
+import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
+import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.sources._
@@ -41,7 +44,6 @@ object LastOptions {
}
}
-
/** Dummy provider. */
class DefaultSource
extends RelationProvider
@@ -107,6 +109,20 @@ class DefaultSourceWithoutUserSpecifiedSchema
}
}
+object MessageCapturingCommitProtocol {
+ val commitMessages = new ConcurrentLinkedQueue[TaskCommitMessage]()
+}
+
+class MessageCapturingCommitProtocol(jobId: String, path: String)
+ extends HadoopMapReduceCommitProtocol(jobId, path) {
+
+ // captures commit messages for testing
+ override def onTaskCommit(msg: TaskCommitMessage): Unit = {
+ MessageCapturingCommitProtocol.commitMessages.offer(msg)
+ }
+}
+
+
class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with BeforeAndAfter {
import testImplicits._
@@ -291,6 +307,19 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
Option(dir).map(spark.read.format("org.apache.spark.sql.test").load)
}
+ test("write path implements onTaskCommit API correctly") {
+ withSQLConf(
+ "spark.sql.sources.commitProtocolClass" ->
+ classOf[MessageCapturingCommitProtocol].getCanonicalName) {
+ withTempDir { dir =>
+ val path = dir.getCanonicalPath
+ MessageCapturingCommitProtocol.commitMessages.clear()
+ spark.range(10).repartition(10).write.mode("overwrite").parquet(path)
+ assert(MessageCapturingCommitProtocol.commitMessages.size() == 10)
+ }
+ }
+ }
+
test("read a data source that does not extend SchemaRelationProvider") {
val dfReader = spark.read
.option("from", "1")
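Beyond the onTaskCommit test above, the same protocol can be wired into a live session through the existing spark.sql.sources.commitProtocolClass conf. A minimal usage sketch, assuming MessageCapturingCommitProtocol is on the driver classpath and /tmp/onTaskCommit-demo is a writable scratch path (both are assumptions, not part of this patch):

// Point the SQL write path at the message-capturing protocol defined in the test above.
spark.conf.set(
  "spark.sql.sources.commitProtocolClass",
  classOf[MessageCapturingCommitProtocol].getCanonicalName)

MessageCapturingCommitProtocol.commitMessages.clear()
spark.range(10).repartition(10).write.mode("overwrite").parquet("/tmp/onTaskCommit-demo")

// One TaskCommitMessage per write task should now have been observed on the driver.
assert(MessageCapturingCommitProtocol.commitMessages.size() == 10)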