aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-03-02 14:35:44 -0800
committerDavies Liu <davies.liu@gmail.com>2016-03-02 14:35:44 -0800
commitb5a59a0fe2ea703ce2712561e7b9044f772660a2 (patch)
tree73d2355a4faefd6f91aff29fc5458e92b76b6092
parent9e01fe2ed1e834710f4ee6a02864ab0fcc528fef (diff)
downloadspark-b5a59a0fe2ea703ce2712561e7b9044f772660a2.tar.gz
spark-b5a59a0fe2ea703ce2712561e7b9044f772660a2.tar.bz2
spark-b5a59a0fe2ea703ce2712561e7b9044f772660a2.zip
[SPARK-13601] call failure callbacks before writer.close()
## What changes were proposed in this pull request? In order to tell OutputStream that the task has failed or not, we should call the failure callbacks BEFORE calling writer.close(). ## How was this patch tested? Added new unit tests. Author: Davies Liu <davies@databricks.com> Closes #11450 from davies/callback.
-rw-r--r--core/src/main/scala/org/apache/spark/TaskContextImpl.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/util/Utils.scala39
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala91
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala95
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala57
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala28
7 files changed, 271 insertions, 53 deletions
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index 65f6f741f7..7e96040bc4 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -53,6 +53,9 @@ private[spark] class TaskContextImpl(
// Whether the task has completed.
@volatile private var completed: Boolean = false
+ // Whether the task has failed.
+ @volatile private var failed: Boolean = false
+
override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
onCompleteCallbacks += listener
this
@@ -63,10 +66,13 @@ private[spark] class TaskContextImpl(
this
}
- /** Marks the task as completed and triggers the failure listeners. */
+ /** Marks the task as failed and triggers the failure listeners. */
private[spark] def markTaskFailed(error: Throwable): Unit = {
+ // failure callbacks should only be called once
+ if (failed) return
+ failed = true
val errorMsgs = new ArrayBuffer[String](2)
- // Process complete callbacks in the reverse order of registration
+ // Process failure callbacks in the reverse order of registration
onFailureCallbacks.reverse.foreach { listener =>
try {
listener.onTaskFailure(this, error)
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index e00b9f6cfd..91460dc406 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -1101,7 +1101,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K, V]]
require(writer != null, "Unable to obtain RecordWriter")
var recordsWritten = 0L
- Utils.tryWithSafeFinally {
+ Utils.tryWithSafeFinallyAndFailureCallbacks {
while (iter.hasNext) {
val pair = iter.next()
writer.write(pair._1, pair._2)
@@ -1190,7 +1190,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
writer.open()
var recordsWritten = 0L
- Utils.tryWithSafeFinally {
+ Utils.tryWithSafeFinallyAndFailureCallbacks {
while (iter.hasNext) {
val record = iter.next()
writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 6103a10ccc..cfe247c668 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1241,7 +1241,6 @@ private[spark] object Utils extends Logging {
* exception from the original `out.write` call.
*/
def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = {
- // It would be nice to find a method on Try that did this
var originalThrowable: Throwable = null
try {
block
@@ -1267,6 +1266,44 @@ private[spark] object Utils extends Logging {
}
}
+ /**
+ * Execute a block of code, call the failure callbacks before finally block if there is any
+ * exceptions happen. But if exceptions happen in the finally block, do not suppress the original
+ * exception.
+ *
+ * This is primarily an issue with `finally { out.close() }` blocks, where
+ * close needs to be called to clean up `out`, but if an exception happened
+ * in `out.write`, it's likely `out` may be corrupted and `out.close` will
+ * fail as well. This would then suppress the original/likely more meaningful
+ * exception from the original `out.write` call.
+ */
+ def tryWithSafeFinallyAndFailureCallbacks[T](block: => T)(finallyBlock: => Unit): T = {
+ var originalThrowable: Throwable = null
+ try {
+ block
+ } catch {
+ case t: Throwable =>
+ // Purposefully not using NonFatal, because even fatal exceptions
+ // we don't want to have our finallyBlock suppress
+ originalThrowable = t
+ TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(t)
+ throw originalThrowable
+ } finally {
+ try {
+ finallyBlock
+ } catch {
+ case t: Throwable =>
+ if (originalThrowable != null) {
+ originalThrowable.addSuppressed(t)
+ logWarning(s"Suppressing exception in finally: " + t.getMessage, t)
+ throw originalThrowable
+ } else {
+ throw t
+ }
+ }
+ }
+ }
+
/** Default filtering function for finding call sites using `getCallSite`. */
private def sparkInternalExclusionFunction(className: String): Boolean = {
// A regular expression to match classes of the internal Spark API's
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 16e2d2e636..7d51538d92 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.rdd
+import java.io.IOException
+
import scala.collection.mutable.{ArrayBuffer, HashSet}
import scala.util.Random
@@ -29,7 +31,8 @@ import org.apache.hadoop.mapreduce.{JobContext => NewJobContext,
RecordWriter => NewRecordWriter, TaskAttemptContext => NewTaskAttempContext}
import org.apache.hadoop.util.Progressable
-import org.apache.spark.{Partitioner, SharedSparkContext, SparkFunSuite}
+import org.apache.spark._
+import org.apache.spark.Partitioner
import org.apache.spark.util.Utils
class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
@@ -533,6 +536,38 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
assert(FakeOutputCommitter.ran, "OutputCommitter was never called")
}
+ test("failure callbacks should be called before calling writer.close() in saveNewAPIHadoopFile") {
+ val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1)
+
+ FakeWriterWithCallback.calledBy = ""
+ FakeWriterWithCallback.exception = null
+ val e = intercept[SparkException] {
+ pairs.saveAsNewAPIHadoopFile[NewFakeFormatWithCallback]("ignored")
+ }
+ assert(e.getMessage contains "failed to write")
+
+ assert(FakeWriterWithCallback.calledBy === "write,callback,close")
+ assert(FakeWriterWithCallback.exception != null, "exception should be captured")
+ assert(FakeWriterWithCallback.exception.getMessage contains "failed to write")
+ }
+
+ test("failure callbacks should be called before calling writer.close() in saveAsHadoopFile") {
+ val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1)
+ val conf = new JobConf()
+
+ FakeWriterWithCallback.calledBy = ""
+ FakeWriterWithCallback.exception = null
+ val e = intercept[SparkException] {
+ pairs.saveAsHadoopFile(
+ "ignored", pairs.keyClass, pairs.valueClass, classOf[FakeFormatWithCallback], conf)
+ }
+ assert(e.getMessage contains "failed to write")
+
+ assert(FakeWriterWithCallback.calledBy === "write,callback,close")
+ assert(FakeWriterWithCallback.exception != null, "exception should be captured")
+ assert(FakeWriterWithCallback.exception.getMessage contains "failed to write")
+ }
+
test("lookup") {
val pairs = sc.parallelize(Array((1, 2), (3, 4), (5, 6), (5, 7)))
@@ -776,6 +811,60 @@ class NewFakeFormat() extends NewOutputFormat[Integer, Integer]() {
}
}
+object FakeWriterWithCallback {
+ var calledBy: String = ""
+ var exception: Throwable = _
+
+ def onFailure(ctx: TaskContext, e: Throwable): Unit = {
+ calledBy += "callback,"
+ exception = e
+ }
+}
+
+class FakeWriterWithCallback extends FakeWriter {
+
+ override def close(p1: Reporter): Unit = {
+ FakeWriterWithCallback.calledBy += "close"
+ }
+
+ override def write(p1: Integer, p2: Integer): Unit = {
+ FakeWriterWithCallback.calledBy += "write,"
+ TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
+ FakeWriterWithCallback.onFailure(t, e)
+ }
+ throw new IOException("failed to write")
+ }
+}
+
+class FakeFormatWithCallback() extends FakeOutputFormat {
+ override def getRecordWriter(
+ ignored: FileSystem,
+ job: JobConf, name: String,
+ progress: Progressable): RecordWriter[Integer, Integer] = {
+ new FakeWriterWithCallback()
+ }
+}
+
+class NewFakeWriterWithCallback extends NewFakeWriter {
+ override def close(p1: NewTaskAttempContext): Unit = {
+ FakeWriterWithCallback.calledBy += "close"
+ }
+
+ override def write(p1: Integer, p2: Integer): Unit = {
+ FakeWriterWithCallback.calledBy += "write,"
+ TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
+ FakeWriterWithCallback.onFailure(t, e)
+ }
+ throw new IOException("failed to write")
+ }
+}
+
+class NewFakeFormatWithCallback() extends NewFakeFormat {
+ override def getRecordWriter(p1: NewTaskAttempContext): NewRecordWriter[Integer, Integer] = {
+ new NewFakeWriterWithCallback()
+ }
+}
+
class ConfigTestFormat() extends NewFakeFormat() with Configurable {
var setConfCalled = false
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
index c3db2a0af4..097e9c912b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -247,11 +247,9 @@ private[sql] class DefaultWriterContainer(
executorSideSetup(taskContext)
val configuration = taskAttemptContext.getConfiguration
configuration.set("spark.sql.sources.output.path", outputPath)
- val writer = newOutputWriter(getWorkPath)
+ var writer = newOutputWriter(getWorkPath)
writer.initConverter(dataSchema)
- var writerClosed = false
-
// If anything below fails, we should abort the task.
try {
while (iterator.hasNext) {
@@ -263,16 +261,17 @@ private[sql] class DefaultWriterContainer(
} catch {
case cause: Throwable =>
logError("Aborting task.", cause)
+ // call failure callbacks first, so we could have a chance to cleanup the writer.
+ TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause)
abortTask()
throw new SparkException("Task failed while writing rows.", cause)
}
def commitTask(): Unit = {
try {
- assert(writer != null, "OutputWriter instance should have been initialized")
- if (!writerClosed) {
+ if (writer != null) {
writer.close()
- writerClosed = true
+ writer = null
}
super.commitTask()
} catch {
@@ -285,9 +284,8 @@ private[sql] class DefaultWriterContainer(
def abortTask(): Unit = {
try {
- if (!writerClosed) {
+ if (writer != null) {
writer.close()
- writerClosed = true
}
} finally {
super.abortTask()
@@ -393,57 +391,62 @@ private[sql] class DynamicPartitionWriterContainer(
val getPartitionString =
UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns)
- // If anything below fails, we should abort the task.
- try {
- // Sorts the data before write, so that we only need one writer at the same time.
- // TODO: inject a local sort operator in planning.
- val sorter = new UnsafeKVExternalSorter(
- sortingKeySchema,
- StructType.fromAttributes(dataColumns),
- SparkEnv.get.blockManager,
- TaskContext.get().taskMemoryManager().pageSizeBytes)
-
- while (iterator.hasNext) {
- val currentRow = iterator.next()
- sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
- }
+ // Sorts the data before write, so that we only need one writer at the same time.
+ // TODO: inject a local sort operator in planning.
+ val sorter = new UnsafeKVExternalSorter(
+ sortingKeySchema,
+ StructType.fromAttributes(dataColumns),
+ SparkEnv.get.blockManager,
+ TaskContext.get().taskMemoryManager().pageSizeBytes)
+
+ while (iterator.hasNext) {
+ val currentRow = iterator.next()
+ sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
+ }
+ logInfo(s"Sorting complete. Writing out partition files one at a time.")
- logInfo(s"Sorting complete. Writing out partition files one at a time.")
+ val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
+ identity
+ } else {
+ UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
+ case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
+ })
+ }
- val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
- identity
- } else {
- UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
- case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
- })
- }
+ val sortedIterator = sorter.sortedIterator()
- val sortedIterator = sorter.sortedIterator()
+ // If anything below fails, we should abort the task.
+ var currentWriter: OutputWriter = null
+ try {
var currentKey: UnsafeRow = null
- var currentWriter: OutputWriter = null
- try {
- while (sortedIterator.next()) {
- val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
- if (currentKey != nextKey) {
- if (currentWriter != null) {
- currentWriter.close()
- }
- currentKey = nextKey.copy()
- logDebug(s"Writing partition: $currentKey")
-
- currentWriter = newOutputWriter(currentKey, getPartitionString)
+ while (sortedIterator.next()) {
+ val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
+ if (currentKey != nextKey) {
+ if (currentWriter != null) {
+ currentWriter.close()
+ currentWriter = null
}
+ currentKey = nextKey.copy()
+ logDebug(s"Writing partition: $currentKey")
- currentWriter.writeInternal(sortedIterator.getValue)
+ currentWriter = newOutputWriter(currentKey, getPartitionString)
}
- } finally {
- if (currentWriter != null) { currentWriter.close() }
+ currentWriter.writeInternal(sortedIterator.getValue)
+ }
+ if (currentWriter != null) {
+ currentWriter.close()
+ currentWriter = null
}
commitTask()
} catch {
case cause: Throwable =>
logError("Aborting task.", cause)
+ // call failure callbacks first, so we could have a chance to cleanup the writer.
+ TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause)
+ if (currentWriter != null) {
+ currentWriter.close()
+ }
abortTask()
throw new SparkException("Task failed while writing rows.", cause)
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
index 64c61a5092..64c27da475 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
@@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
@@ -30,6 +31,7 @@ class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton
val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName
test("SPARK-7684: commitTask() failure should fallback to abortTask()") {
+ SimpleTextRelation.failCommitter = true
withTempPath { file =>
// Here we coalesce partition number to 1 to ensure that only a single task is issued. This
// prevents race condition happened when FileOutputCommitter tries to remove the `_temporary`
@@ -43,4 +45,59 @@ class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton
assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary")))
}
}
+
+ test("call failure callbacks before close writer - default") {
+ SimpleTextRelation.failCommitter = false
+ withTempPath { file =>
+ // fail the job in the middle of writing
+ val divideByZero = udf((x: Int) => { x / (x - 1)})
+ val df = sqlContext.range(0, 10).select(divideByZero(col("id")))
+
+ SimpleTextRelation.callbackCalled = false
+ intercept[SparkException] {
+ df.write.format(dataSourceName).save(file.getCanonicalPath)
+ }
+ assert(SimpleTextRelation.callbackCalled, "failure callback should be called")
+
+ val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf)
+ assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary")))
+ }
+ }
+
+ test("failure callback of writer should not be called if failed before writing") {
+ SimpleTextRelation.failCommitter = false
+ withTempPath { file =>
+ // fail the job in the middle of writing
+ val divideByZero = udf((x: Int) => { x / (x - 1)})
+ val df = sqlContext.range(0, 10).select(col("id").mod(2).as("key"), divideByZero(col("id")))
+
+ SimpleTextRelation.callbackCalled = false
+ intercept[SparkException] {
+ df.write.format(dataSourceName).partitionBy("key").save(file.getCanonicalPath)
+ }
+ assert(!SimpleTextRelation.callbackCalled,
+ "the callback of writer should not be called if job failed before writing")
+
+ val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf)
+ assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary")))
+ }
+ }
+
+ test("call failure callbacks before close writer - partitioned") {
+ SimpleTextRelation.failCommitter = false
+ withTempPath { file =>
+ // fail the job in the middle of writing
+ val df = sqlContext.range(0, 10).select(col("id").mod(2).as("key"), col("id"))
+
+ SimpleTextRelation.callbackCalled = false
+ SimpleTextRelation.failWriter = true
+ intercept[SparkException] {
+ df.write.format(dataSourceName).partitionBy("key").save(file.getCanonicalPath)
+ }
+ assert(SimpleTextRelation.callbackCalled, "failure callback should be called")
+
+ val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf)
+ assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary")))
+ }
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index 9fc437bf88..9cdf1fc585 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -25,6 +25,7 @@ import org.apache.hadoop.io.{NullWritable, Text}
import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
+import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{sources, Row, SQLContext}
import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters}
@@ -199,6 +200,15 @@ object SimpleTextRelation {
// Used to test filter push-down
var pushedFilters: Set[Filter] = Set.empty
+
+ // Used to test failed committer
+ var failCommitter = false
+
+ // Used to test failed writer
+ var failWriter = false
+
+ // Used to test failure callback
+ var callbackCalled = false
}
/**
@@ -229,9 +239,25 @@ class CommitFailureTestRelation(
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
new SimpleTextOutputWriter(path, context) {
+ var failed = false
+ TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
+ failed = true
+ SimpleTextRelation.callbackCalled = true
+ }
+
+ override def write(row: Row): Unit = {
+ if (SimpleTextRelation.failWriter) {
+ sys.error("Intentional task writer failure for testing purpose.")
+
+ }
+ super.write(row)
+ }
+
override def close(): Unit = {
+ if (SimpleTextRelation.failCommitter) {
+ sys.error("Intentional task commitment failure for testing purpose.")
+ }
super.close()
- sys.error("Intentional task commitment failure for testing purpose.")
}
}
}