author     Davies Liu <davies@databricks.com>    2016-03-02 14:35:44 -0800
committer  Davies Liu <davies.liu@gmail.com>     2016-03-02 14:35:44 -0800
commit     b5a59a0fe2ea703ce2712561e7b9044f772660a2 (patch)
tree       73d2355a4faefd6f91aff29fc5458e92b76b6092 /sql
parent     9e01fe2ed1e834710f4ee6a02864ab0fcc528fef (diff)
[SPARK-13601] call failure callbacks before writer.close()
## What changes were proposed in this pull request?

In order to tell the OutputStream whether the task has failed, we should call the failure callbacks BEFORE calling writer.close().

## How was this patch tested?

Added new unit tests.

Author: Davies Liu <davies@databricks.com>

Closes #11450 from davies/callback.
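For context, the ordering this patch enforces can be illustrated outside of Spark. The sketch below is a minimal, Spark-independent illustration; `FailureListener`, `SimpleWriter`, and `WriteTaskSketch` are hypothetical names invented here, not Spark APIs (the real change uses `TaskContextImpl.markTaskFailed` and listeners registered via `TaskContext.addTaskFailureListener`). The point is that failure listeners run before `writer.close()`, so `close()` can observe that the task failed and avoid committing partial output.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical stand-ins for Spark's OutputWriter and task failure listeners.
trait FailureListener { def onTaskFailure(cause: Throwable): Unit }

class SimpleWriter {
  var failed = false // a failure listener flips this before close() runs
  def write(row: String): Unit = if (row.isEmpty) sys.error("bad row")
  def close(): Unit = {
    if (failed) println("closing after failure: skip commit")
    else println("closing normally: commit output")
  }
}

object WriteTaskSketch {
  def runTask(rows: Seq[String]): Unit = {
    val listeners = ArrayBuffer.empty[FailureListener]
    val writer = new SimpleWriter
    // The writer registers a failure listener, much like CommitFailureTestRelation
    // does with TaskContext.get().addTaskFailureListener in the new tests.
    listeners += new FailureListener {
      def onTaskFailure(cause: Throwable): Unit = writer.failed = true
    }
    try {
      rows.foreach(r => writer.write(r))
      writer.close() // normal path: close commits the output
    } catch {
      case cause: Throwable =>
        // Key ordering from SPARK-13601: notify failure listeners BEFORE close(),
        // so close() sees the failure and does not commit bad output.
        listeners.foreach(_.onTaskFailure(cause))
        writer.close()
        throw cause
    }
  }

  def main(args: Array[String]): Unit = {
    runTask(Seq("a", "b"))                               // closes normally
    try runTask(Seq("a", "")) catch { case _: Throwable => () } // closes after failure
  }
}
```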
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala    95
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala   57
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala               28
3 files changed, 133 insertions(+), 47 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
index c3db2a0af4..097e9c912b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -247,11 +247,9 @@ private[sql] class DefaultWriterContainer(
executorSideSetup(taskContext)
val configuration = taskAttemptContext.getConfiguration
configuration.set("spark.sql.sources.output.path", outputPath)
- val writer = newOutputWriter(getWorkPath)
+ var writer = newOutputWriter(getWorkPath)
writer.initConverter(dataSchema)
- var writerClosed = false
-
// If anything below fails, we should abort the task.
try {
while (iterator.hasNext) {
@@ -263,16 +261,17 @@ private[sql] class DefaultWriterContainer(
} catch {
case cause: Throwable =>
logError("Aborting task.", cause)
+ // call failure callbacks first, so we could have a chance to cleanup the writer.
+ TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause)
abortTask()
throw new SparkException("Task failed while writing rows.", cause)
}
def commitTask(): Unit = {
try {
- assert(writer != null, "OutputWriter instance should have been initialized")
- if (!writerClosed) {
+ if (writer != null) {
writer.close()
- writerClosed = true
+ writer = null
}
super.commitTask()
} catch {
@@ -285,9 +284,8 @@ private[sql] class DefaultWriterContainer(
def abortTask(): Unit = {
try {
- if (!writerClosed) {
+ if (writer != null) {
writer.close()
- writerClosed = true
}
} finally {
super.abortTask()
@@ -393,57 +391,62 @@ private[sql] class DynamicPartitionWriterContainer(
val getPartitionString =
UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns)
- // If anything below fails, we should abort the task.
- try {
- // Sorts the data before write, so that we only need one writer at the same time.
- // TODO: inject a local sort operator in planning.
- val sorter = new UnsafeKVExternalSorter(
- sortingKeySchema,
- StructType.fromAttributes(dataColumns),
- SparkEnv.get.blockManager,
- TaskContext.get().taskMemoryManager().pageSizeBytes)
-
- while (iterator.hasNext) {
- val currentRow = iterator.next()
- sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
- }
+ // Sorts the data before write, so that we only need one writer at the same time.
+ // TODO: inject a local sort operator in planning.
+ val sorter = new UnsafeKVExternalSorter(
+ sortingKeySchema,
+ StructType.fromAttributes(dataColumns),
+ SparkEnv.get.blockManager,
+ TaskContext.get().taskMemoryManager().pageSizeBytes)
+
+ while (iterator.hasNext) {
+ val currentRow = iterator.next()
+ sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
+ }
+ logInfo(s"Sorting complete. Writing out partition files one at a time.")
- logInfo(s"Sorting complete. Writing out partition files one at a time.")
+ val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
+ identity
+ } else {
+ UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
+ case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
+ })
+ }
- val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
- identity
- } else {
- UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
- case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
- })
- }
+ val sortedIterator = sorter.sortedIterator()
- val sortedIterator = sorter.sortedIterator()
+ // If anything below fails, we should abort the task.
+ var currentWriter: OutputWriter = null
+ try {
var currentKey: UnsafeRow = null
- var currentWriter: OutputWriter = null
- try {
- while (sortedIterator.next()) {
- val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
- if (currentKey != nextKey) {
- if (currentWriter != null) {
- currentWriter.close()
- }
- currentKey = nextKey.copy()
- logDebug(s"Writing partition: $currentKey")
-
- currentWriter = newOutputWriter(currentKey, getPartitionString)
+ while (sortedIterator.next()) {
+ val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
+ if (currentKey != nextKey) {
+ if (currentWriter != null) {
+ currentWriter.close()
+ currentWriter = null
}
+ currentKey = nextKey.copy()
+ logDebug(s"Writing partition: $currentKey")
- currentWriter.writeInternal(sortedIterator.getValue)
+ currentWriter = newOutputWriter(currentKey, getPartitionString)
}
- } finally {
- if (currentWriter != null) { currentWriter.close() }
+ currentWriter.writeInternal(sortedIterator.getValue)
+ }
+ if (currentWriter != null) {
+ currentWriter.close()
+ currentWriter = null
}
commitTask()
} catch {
case cause: Throwable =>
logError("Aborting task.", cause)
+ // call failure callbacks first, so we could have a chance to cleanup the writer.
+ TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause)
+ if (currentWriter != null) {
+ currentWriter.close()
+ }
abortTask()
throw new SparkException("Task failed while writing rows.", cause)
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
index 64c61a5092..64c27da475 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
@@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
@@ -30,6 +31,7 @@ class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton
val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName
test("SPARK-7684: commitTask() failure should fallback to abortTask()") {
+ SimpleTextRelation.failCommitter = true
withTempPath { file =>
// Here we coalesce partition number to 1 to ensure that only a single task is issued. This
// prevents race condition happened when FileOutputCommitter tries to remove the `_temporary`
@@ -43,4 +45,59 @@ class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton
assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary")))
}
}
+
+ test("call failure callbacks before close writer - default") {
+ SimpleTextRelation.failCommitter = false
+ withTempPath { file =>
+ // fail the job in the middle of writing
+ val divideByZero = udf((x: Int) => { x / (x - 1)})
+ val df = sqlContext.range(0, 10).select(divideByZero(col("id")))
+
+ SimpleTextRelation.callbackCalled = false
+ intercept[SparkException] {
+ df.write.format(dataSourceName).save(file.getCanonicalPath)
+ }
+ assert(SimpleTextRelation.callbackCalled, "failure callback should be called")
+
+ val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf)
+ assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary")))
+ }
+ }
+
+ test("failure callback of writer should not be called if failed before writing") {
+ SimpleTextRelation.failCommitter = false
+ withTempPath { file =>
+ // fail the job in the middle of writing
+ val divideByZero = udf((x: Int) => { x / (x - 1)})
+ val df = sqlContext.range(0, 10).select(col("id").mod(2).as("key"), divideByZero(col("id")))
+
+ SimpleTextRelation.callbackCalled = false
+ intercept[SparkException] {
+ df.write.format(dataSourceName).partitionBy("key").save(file.getCanonicalPath)
+ }
+ assert(!SimpleTextRelation.callbackCalled,
+ "the callback of writer should not be called if job failed before writing")
+
+ val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf)
+ assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary")))
+ }
+ }
+
+ test("call failure callbacks before close writer - partitioned") {
+ SimpleTextRelation.failCommitter = false
+ withTempPath { file =>
+ // fail the job in the middle of writing
+ val df = sqlContext.range(0, 10).select(col("id").mod(2).as("key"), col("id"))
+
+ SimpleTextRelation.callbackCalled = false
+ SimpleTextRelation.failWriter = true
+ intercept[SparkException] {
+ df.write.format(dataSourceName).partitionBy("key").save(file.getCanonicalPath)
+ }
+ assert(SimpleTextRelation.callbackCalled, "failure callback should be called")
+
+ val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf)
+ assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary")))
+ }
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index 9fc437bf88..9cdf1fc585 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -25,6 +25,7 @@ import org.apache.hadoop.io.{NullWritable, Text}
import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
+import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{sources, Row, SQLContext}
import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters}
@@ -199,6 +200,15 @@ object SimpleTextRelation {
// Used to test filter push-down
var pushedFilters: Set[Filter] = Set.empty
+
+ // Used to test failed committer
+ var failCommitter = false
+
+ // Used to test failed writer
+ var failWriter = false
+
+ // Used to test failure callback
+ var callbackCalled = false
}
/**
@@ -229,9 +239,25 @@ class CommitFailureTestRelation(
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
new SimpleTextOutputWriter(path, context) {
+ var failed = false
+ TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
+ failed = true
+ SimpleTextRelation.callbackCalled = true
+ }
+
+ override def write(row: Row): Unit = {
+ if (SimpleTextRelation.failWriter) {
+ sys.error("Intentional task writer failure for testing purpose.")
+
+ }
+ super.write(row)
+ }
+
override def close(): Unit = {
+ if (SimpleTextRelation.failCommitter) {
+ sys.error("Intentional task commitment failure for testing purpose.")
+ }
super.close()
- sys.error("Intentional task commitment failure for testing purpose.")
}
}
}