From b5a59a0fe2ea703ce2712561e7b9044f772660a2 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 2 Mar 2016 14:35:44 -0800
Subject: [SPARK-13601] call failure callbacks before writer.close()

## What changes were proposed in this pull request?

In order to tell the OutputStream whether the task has failed or not, we should
call the failure callbacks BEFORE calling writer.close().

## How was this patch tested?

Added new unit tests.

Author: Davies Liu

Closes #11450 from davies/callback.
---
 .../scala/org/apache/spark/TaskContextImpl.scala   | 10 ++-
 .../org/apache/spark/rdd/PairRDDFunctions.scala    |  4 +-
 .../main/scala/org/apache/spark/util/Utils.scala   | 39 +++++++++-
 .../apache/spark/rdd/PairRDDFunctionsSuite.scala   | 91 +++++++++++++++++++++-
 4 files changed, 138 insertions(+), 6 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index 65f6f741f7..7e96040bc4 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -53,6 +53,9 @@ private[spark] class TaskContextImpl(
   // Whether the task has completed.
   @volatile private var completed: Boolean = false
 
+  // Whether the task has failed.
+  @volatile private var failed: Boolean = false
+
   override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
     onCompleteCallbacks += listener
     this
@@ -63,10 +66,13 @@ private[spark] class TaskContextImpl(
     this
   }
 
-  /** Marks the task as completed and triggers the failure listeners. */
+  /** Marks the task as failed and triggers the failure listeners. */
   private[spark] def markTaskFailed(error: Throwable): Unit = {
+    // failure callbacks should only be called once
+    if (failed) return
+    failed = true
     val errorMsgs = new ArrayBuffer[String](2)
-    // Process complete callbacks in the reverse order of registration
+    // Process failure callbacks in the reverse order of registration
     onFailureCallbacks.reverse.foreach { listener =>
       try {
         listener.onTaskFailure(this, error)
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index e00b9f6cfd..91460dc406 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -1101,7 +1101,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
       val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K, V]]
       require(writer != null, "Unable to obtain RecordWriter")
       var recordsWritten = 0L
-      Utils.tryWithSafeFinally {
+      Utils.tryWithSafeFinallyAndFailureCallbacks {
         while (iter.hasNext) {
           val pair = iter.next()
           writer.write(pair._1, pair._2)
@@ -1190,7 +1190,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
 
       writer.open()
       var recordsWritten = 0L
-      Utils.tryWithSafeFinally {
+      Utils.tryWithSafeFinallyAndFailureCallbacks {
         while (iter.hasNext) {
           val record = iter.next()
           writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 6103a10ccc..cfe247c668 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1241,7 +1241,6 @@ private[spark] object Utils extends Logging {
    * exception from the original `out.write` call.
   */
  def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = {
-    // It would be nice to find a method on Try that did this
     var originalThrowable: Throwable = null
     try {
       block
@@ -1267,6 +1266,44 @@ private[spark] object Utils extends Logging {
     }
   }
 
+  /**
+   * Execute a block of code and call the failure callbacks before the finally block if any
+   * exception happens. If an exception also happens in the finally block, do not suppress the
+   * original exception.
+   *
+   * This is primarily an issue with `finally { out.close() }` blocks, where
+   * close needs to be called to clean up `out`, but if an exception happened
+   * in `out.write`, it's likely `out` may be corrupted and `out.close` will
+   * fail as well. This would then suppress the original/likely more meaningful
+   * exception from the original `out.write` call.
+   */
+  def tryWithSafeFinallyAndFailureCallbacks[T](block: => T)(finallyBlock: => Unit): T = {
+    var originalThrowable: Throwable = null
+    try {
+      block
+    } catch {
+      case t: Throwable =>
+        // Purposefully not using NonFatal, because we don't want
+        // finallyBlock to suppress even fatal exceptions
+        originalThrowable = t
+        TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(t)
+        throw originalThrowable
+    } finally {
+      try {
+        finallyBlock
+      } catch {
+        case t: Throwable =>
+          if (originalThrowable != null) {
+            originalThrowable.addSuppressed(t)
+            logWarning("Suppressing exception in finally: " + t.getMessage, t)
+            throw originalThrowable
+          } else {
+            throw t
+          }
+      }
+    }
+  }
+
   /** Default filtering function for finding call sites using `getCallSite`. */
   private def sparkInternalExclusionFunction(className: String): Boolean = {
     // A regular expression to match classes of the internal Spark API's
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 16e2d2e636..7d51538d92 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.rdd
 
+import java.io.IOException
+
 import scala.collection.mutable.{ArrayBuffer, HashSet}
 import scala.util.Random
 
@@ -29,7 +31,8 @@ import org.apache.hadoop.mapreduce.{JobContext => NewJobContext,
   RecordWriter => NewRecordWriter, TaskAttemptContext => NewTaskAttempContext}
 import org.apache.hadoop.util.Progressable
 
-import org.apache.spark.{Partitioner, SharedSparkContext, SparkFunSuite}
+import org.apache.spark._
+import org.apache.spark.Partitioner
 import org.apache.spark.util.Utils
 
 class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
@@ -533,6 +536,38 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
     assert(FakeOutputCommitter.ran, "OutputCommitter was never called")
   }
 
+  test("failure callbacks should be called before calling writer.close() in saveAsNewAPIHadoopFile") {
+    val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1)
+
+    FakeWriterWithCallback.calledBy = ""
+    FakeWriterWithCallback.exception = null
+    val e = intercept[SparkException] {
+      pairs.saveAsNewAPIHadoopFile[NewFakeFormatWithCallback]("ignored")
+    }
+    assert(e.getMessage contains "failed to write")
+
+    assert(FakeWriterWithCallback.calledBy === "write,callback,close")
+    assert(FakeWriterWithCallback.exception != null, "exception should be captured")
+    assert(FakeWriterWithCallback.exception.getMessage contains "failed to write")
+  }
+
+  test("failure callbacks should be called before calling writer.close() in saveAsHadoopFile") {
+    val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1)
+    val conf = new JobConf()
+
+    FakeWriterWithCallback.calledBy = ""
+    FakeWriterWithCallback.exception = null
+    val e = intercept[SparkException] {
+      pairs.saveAsHadoopFile(
+        "ignored", pairs.keyClass, pairs.valueClass, classOf[FakeFormatWithCallback], conf)
+    }
+    assert(e.getMessage contains "failed to write")
+
+    assert(FakeWriterWithCallback.calledBy === "write,callback,close")
+    assert(FakeWriterWithCallback.exception != null, "exception should be captured")
+    assert(FakeWriterWithCallback.exception.getMessage contains "failed to write")
+  }
+
   test("lookup") {
     val pairs = sc.parallelize(Array((1, 2), (3, 4), (5, 6), (5, 7)))
 
@@ -776,6 +811,60 @@ class NewFakeFormat() extends NewOutputFormat[Integer, Integer]() {
   }
 }
 
+object FakeWriterWithCallback {
+  var calledBy: String = ""
+  var exception: Throwable = _
+
+  def onFailure(ctx: TaskContext, e: Throwable): Unit = {
+    calledBy += "callback,"
+    exception = e
+  }
+}
+
+class FakeWriterWithCallback extends FakeWriter {
+
+  override def close(p1: Reporter): Unit = {
+    FakeWriterWithCallback.calledBy += "close"
+  }
+
+  override def write(p1: Integer, p2: Integer): Unit = {
+    FakeWriterWithCallback.calledBy += "write,"
+    TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
+      FakeWriterWithCallback.onFailure(t, e)
+    }
+    throw new IOException("failed to write")
+  }
+}
+
+class FakeFormatWithCallback() extends FakeOutputFormat {
+  override def getRecordWriter(
+    ignored: FileSystem,
+    job: JobConf, name: String,
+    progress: Progressable): RecordWriter[Integer, Integer] = {
+    new FakeWriterWithCallback()
+  }
+}
+
+class NewFakeWriterWithCallback extends NewFakeWriter {
+  override def close(p1: NewTaskAttempContext): Unit = {
+    FakeWriterWithCallback.calledBy += "close"
+  }
+
+  override def write(p1: Integer, p2: Integer): Unit = {
+    FakeWriterWithCallback.calledBy += "write,"
+    TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
+      FakeWriterWithCallback.onFailure(t, e)
+    }
+    throw new IOException("failed to write")
+  }
+}
+
+class NewFakeFormatWithCallback() extends NewFakeFormat {
+  override def getRecordWriter(p1: NewTaskAttempContext): NewRecordWriter[Integer, Integer] = {
+    new NewFakeWriterWithCallback()
+  }
+}
+
 class ConfigTestFormat() extends NewFakeFormat() with Configurable {
 
   var setConfCalled = false
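For readers following the control flow, here is a hedged usage sketch of the new helper, mirroring the call sites in PairRDDFunctions above. `writeAll`, `out`, and `records` are illustrative names only, and `Utils.tryWithSafeFinallyAndFailureCallbacks` is `private[spark]`, so this compiles only inside Spark itself:

```scala
import java.io.OutputStream

import org.apache.spark.util.Utils

// Illustrative only: a write loop wrapped in the new helper.
def writeAll(out: OutputStream, records: Iterator[Array[Byte]]): Unit = {
  Utils.tryWithSafeFinallyAndFailureCallbacks {
    // A failure in here calls markTaskFailed(t) first, which fires the
    // registered TaskFailureListeners, and then rethrows t.
    records.foreach(out.write)
  } {
    // A failure in here is attached to the original exception via
    // addSuppressed instead of masking it.
    out.close()
  }
}
```

The design point is ordering: the failure listeners run before the finally block, so by the time `out.close()` executes, `out` already knows the task failed.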
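For user code that cannot reach the private helper, the same guarantee is available through `TaskContext.addTaskFailureListener`, the API exercised by the new tests. A minimal sketch, assuming a hypothetical `AbortableWriter` (not part of this patch) that commits on success and aborts on failure:

```scala
import org.apache.spark.TaskContext

// Hypothetical writer showing the pattern this change makes reliable:
// register a failure listener when the writer is opened, so that close()
// can tell a failed task from a successful one.
class AbortableWriter(commit: () => Unit, abort: () => Unit) {
  @volatile private var failed = false

  def open(): Unit = {
    // After this patch, the listener is guaranteed to run BEFORE close()
    // when the task body throws.
    TaskContext.get().addTaskFailureListener { (_: TaskContext, _: Throwable) =>
      failed = true
    }
  }

  def write(record: String): Unit = {
    // buffer or stream the record
  }

  def close(): Unit = {
    // Discard partial output on failure; flush and commit otherwise.
    if (failed) abort() else commit()
  }
}
```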