author    Davies Liu <davies@databricks.com>    2016-03-02 14:35:44 -0800
committer Davies Liu <davies.liu@gmail.com>    2016-03-02 14:35:44 -0800
commit    b5a59a0fe2ea703ce2712561e7b9044f772660a2 (patch)
tree      73d2355a4faefd6f91aff29fc5458e92b76b6092 /core
parent    9e01fe2ed1e834710f4ee6a02864ab0fcc528fef (diff)
[SPARK-13601] call failure callbacks before writer.close()
## What changes were proposed in this pull request?

In order to tell the OutputStream whether the task has failed or not, we should call the failure callbacks BEFORE calling writer.close().

## How was this patch tested?

Added new unit tests.

Author: Davies Liu <davies@databricks.com>

Closes #11450 from davies/callback.
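The gist of the fix, as a rough self-contained sketch (hypothetical names, not code from this patch):

```scala
// A toy stand-in for a Hadoop record writer.
class ToyWriter {
  def write(k: Int, v: Int): Unit = throw new java.io.IOException("failed to write")
  def close(): Unit = println("close()")
}

val records = Seq(1 -> 2, 3 -> 4)
val failureCallbacks = Seq[Throwable => Unit](e => println(s"callback saw: ${e.getMessage}"))
val writer = new ToyWriter

// Old ordering: writer.close() ran in the finally block before any failure
// listener fired, so the writer closed without knowing the attempt had failed.
// New ordering (sketched here): the catch branch runs the failure callbacks,
// standing in for TaskContextImpl.markTaskFailed(t), and only then does the
// finally block close the writer.
try {
  records.foreach { case (k, v) => writer.write(k, v) }
} catch {
  case t: Throwable =>
    failureCallbacks.foreach(_(t))
    throw t
} finally {
  writer.close()
}
```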
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContextImpl.scala           | 10
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala      |  4
-rw-r--r--  core/src/main/scala/org/apache/spark/util/Utils.scala                | 39
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala | 91
4 files changed, 138 insertions(+), 6 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index 65f6f741f7..7e96040bc4 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -53,6 +53,9 @@ private[spark] class TaskContextImpl(
// Whether the task has completed.
@volatile private var completed: Boolean = false
+ // Whether the task has failed.
+ @volatile private var failed: Boolean = false
+
override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
onCompleteCallbacks += listener
this
@@ -63,10 +66,13 @@ private[spark] class TaskContextImpl(
this
}
- /** Marks the task as completed and triggers the failure listeners. */
+ /** Marks the task as failed and triggers the failure listeners. */
private[spark] def markTaskFailed(error: Throwable): Unit = {
+ // failure callbacks should only be called once
+ if (failed) return
+ failed = true
val errorMsgs = new ArrayBuffer[String](2)
- // Process complete callbacks in the reverse order of registration
+ // Process failure callbacks in the reverse order of registration
onFailureCallbacks.reverse.foreach { listener =>
try {
listener.onTaskFailure(this, error)
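For context, the listeners that markTaskFailed invokes are registered from task code through TaskContext. A minimal sketch (assuming it runs inside a live Spark task, as in the tests added below):

```scala
import org.apache.spark.TaskContext

// Register a failure listener from inside a task (for example, from a record
// writer). With the new `failed` flag above, these listeners run at most once
// even if markTaskFailed ends up being called more than once.
TaskContext.get().addTaskFailureListener { (ctx: TaskContext, error: Throwable) =>
  println(s"task for partition ${ctx.partitionId()} failed: ${error.getMessage}")
}
```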
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index e00b9f6cfd..91460dc406 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -1101,7 +1101,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K, V]]
require(writer != null, "Unable to obtain RecordWriter")
var recordsWritten = 0L
- Utils.tryWithSafeFinally {
+ Utils.tryWithSafeFinallyAndFailureCallbacks {
while (iter.hasNext) {
val pair = iter.next()
writer.write(pair._1, pair._2)
@@ -1190,7 +1190,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
writer.open()
var recordsWritten = 0L
- Utils.tryWithSafeFinally {
+ Utils.tryWithSafeFinallyAndFailureCallbacks {
while (iter.hasNext) {
val record = iter.next()
writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
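Both save paths wrap their write loops the same way; a simplified sketch of the resulting shape (not the exact surrounding code, which the hunks above elide):

```scala
// Sketch only: `iter` is the task's record iterator and `writer` the Hadoop
// record writer obtained earlier in the save path.
Utils.tryWithSafeFinallyAndFailureCallbacks {
  while (iter.hasNext) {
    val pair = iter.next()
    writer.write(pair._1, pair._2)   // a failure here now reaches the task's
    recordsWritten += 1              // failure listeners before close() below
  }
} {
  writer.close(hadoopContext)        // still runs, but only after the callbacks
}
```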
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 6103a10ccc..cfe247c668 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1241,7 +1241,6 @@ private[spark] object Utils extends Logging {
* exception from the original `out.write` call.
*/
def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = {
- // It would be nice to find a method on Try that did this
var originalThrowable: Throwable = null
try {
block
@@ -1267,6 +1266,44 @@ private[spark] object Utils extends Logging {
}
}
+ /**
+ * Execute a block of code, calling the failure callbacks before the finally block if any
+ * exception occurs. Exceptions thrown in the finally block do not suppress the original
+ * exception.
+ *
+ * This is primarily an issue with `finally { out.close() }` blocks, where
+ * close needs to be called to clean up `out`, but if an exception happened
+ * in `out.write`, it's likely `out` may be corrupted and `out.close` will
+ * fail as well. This would then suppress the original/likely more meaningful
+ * exception from the original `out.write` call.
+ */
+ def tryWithSafeFinallyAndFailureCallbacks[T](block: => T)(finallyBlock: => Unit): T = {
+ var originalThrowable: Throwable = null
+ try {
+ block
+ } catch {
+ case t: Throwable =>
+ // Purposefully not using NonFatal, because even fatal exceptions
+ // we don't want to have our finallyBlock suppress
+ originalThrowable = t
+ TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(t)
+ throw originalThrowable
+ } finally {
+ try {
+ finallyBlock
+ } catch {
+ case t: Throwable =>
+ if (originalThrowable != null) {
+ originalThrowable.addSuppressed(t)
+ logWarning(s"Suppressing exception in finally: " + t.getMessage, t)
+ throw originalThrowable
+ } else {
+ throw t
+ }
+ }
+ }
+ }
+
/** Default filtering function for finding call sites using `getCallSite`. */
private def sparkInternalExclusionFunction(className: String): Boolean = {
// A regular expression to match classes of the internal Spark API's
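A minimal usage sketch for the new helper (hypothetical snippet; it must run inside a Spark task, since the cast to TaskContextImpl assumes TaskContext.get() is non-null):

```scala
// Hypothetical task-side code: `bytes` and `out` are placeholders.
val bytes = Array[Byte](1, 2, 3)
val out = new java.io.ByteArrayOutputStream()
Utils.tryWithSafeFinallyAndFailureCallbacks {
  out.write(bytes)   // if this throws, markTaskFailed(t) fires the registered
                     // TaskFailureListeners before the finally block below
} {
  out.close()        // still runs afterwards; if it also throws, that exception
                     // is attached to the original one via addSuppressed
}
```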
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 16e2d2e636..7d51538d92 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.rdd
+import java.io.IOException
+
import scala.collection.mutable.{ArrayBuffer, HashSet}
import scala.util.Random
@@ -29,7 +31,8 @@ import org.apache.hadoop.mapreduce.{JobContext => NewJobContext,
RecordWriter => NewRecordWriter, TaskAttemptContext => NewTaskAttempContext}
import org.apache.hadoop.util.Progressable
-import org.apache.spark.{Partitioner, SharedSparkContext, SparkFunSuite}
+import org.apache.spark._
+import org.apache.spark.Partitioner
import org.apache.spark.util.Utils
class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
@@ -533,6 +536,38 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
assert(FakeOutputCommitter.ran, "OutputCommitter was never called")
}
+ test("failure callbacks should be called before calling writer.close() in saveNewAPIHadoopFile") {
+ val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1)
+
+ FakeWriterWithCallback.calledBy = ""
+ FakeWriterWithCallback.exception = null
+ val e = intercept[SparkException] {
+ pairs.saveAsNewAPIHadoopFile[NewFakeFormatWithCallback]("ignored")
+ }
+ assert(e.getMessage contains "failed to write")
+
+ assert(FakeWriterWithCallback.calledBy === "write,callback,close")
+ assert(FakeWriterWithCallback.exception != null, "exception should be captured")
+ assert(FakeWriterWithCallback.exception.getMessage contains "failed to write")
+ }
+
+ test("failure callbacks should be called before calling writer.close() in saveAsHadoopFile") {
+ val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1)
+ val conf = new JobConf()
+
+ FakeWriterWithCallback.calledBy = ""
+ FakeWriterWithCallback.exception = null
+ val e = intercept[SparkException] {
+ pairs.saveAsHadoopFile(
+ "ignored", pairs.keyClass, pairs.valueClass, classOf[FakeFormatWithCallback], conf)
+ }
+ assert(e.getMessage contains "failed to write")
+
+ assert(FakeWriterWithCallback.calledBy === "write,callback,close")
+ assert(FakeWriterWithCallback.exception != null, "exception should be captured")
+ assert(FakeWriterWithCallback.exception.getMessage contains "failed to write")
+ }
+
test("lookup") {
val pairs = sc.parallelize(Array((1, 2), (3, 4), (5, 6), (5, 7)))
@@ -776,6 +811,60 @@ class NewFakeFormat() extends NewOutputFormat[Integer, Integer]() {
}
}
+object FakeWriterWithCallback {
+ var calledBy: String = ""
+ var exception: Throwable = _
+
+ def onFailure(ctx: TaskContext, e: Throwable): Unit = {
+ calledBy += "callback,"
+ exception = e
+ }
+}
+
+class FakeWriterWithCallback extends FakeWriter {
+
+ override def close(p1: Reporter): Unit = {
+ FakeWriterWithCallback.calledBy += "close"
+ }
+
+ override def write(p1: Integer, p2: Integer): Unit = {
+ FakeWriterWithCallback.calledBy += "write,"
+ TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
+ FakeWriterWithCallback.onFailure(t, e)
+ }
+ throw new IOException("failed to write")
+ }
+}
+
+class FakeFormatWithCallback() extends FakeOutputFormat {
+ override def getRecordWriter(
+ ignored: FileSystem,
+ job: JobConf, name: String,
+ progress: Progressable): RecordWriter[Integer, Integer] = {
+ new FakeWriterWithCallback()
+ }
+}
+
+class NewFakeWriterWithCallback extends NewFakeWriter {
+ override def close(p1: NewTaskAttempContext): Unit = {
+ FakeWriterWithCallback.calledBy += "close"
+ }
+
+ override def write(p1: Integer, p2: Integer): Unit = {
+ FakeWriterWithCallback.calledBy += "write,"
+ TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
+ FakeWriterWithCallback.onFailure(t, e)
+ }
+ throw new IOException("failed to write")
+ }
+}
+
+class NewFakeFormatWithCallback() extends NewFakeFormat {
+ override def getRecordWriter(p1: NewTaskAttempContext): NewRecordWriter[Integer, Integer] = {
+ new NewFakeWriterWithCallback()
+ }
+}
+
class ConfigTestFormat() extends NewFakeFormat() with Configurable {
var setConfCalled = false