author     Davies Liu <davies@databricks.com>       2016-03-02 14:35:44 -0800
committer  Davies Liu <davies.liu@gmail.com>        2016-03-02 14:35:44 -0800
commit     b5a59a0fe2ea703ce2712561e7b9044f772660a2 (patch)
tree       73d2355a4faefd6f91aff29fc5458e92b76b6092 /sql
parent     9e01fe2ed1e834710f4ee6a02864ab0fcc528fef (diff)
[SPARK-13601] call failure callbacks before writer.close()
## What changes were proposed in this pull request?
In order to tell the OutputWriter whether the task has failed, we should call the failure callbacks BEFORE calling writer.close().
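For context, here is a minimal sketch of the pattern this ordering enables on the writer side. Only the `TaskContext.get()` / `addTaskFailureListener` calls are real Spark API (the same ones used by the test writer in this patch); the class and its `commit`/`discard` parameters are made up for illustration:

```scala
import org.apache.spark.TaskContext

// Hypothetical output writer, shown only to illustrate why the callback ordering matters.
class FailureAwareWriter(commit: () => Unit, discard: () => Unit) {
  @volatile private var failed = false

  // The failure callback now fires BEFORE close(), so `failed` is accurate inside close().
  TaskContext.get().addTaskFailureListener { (_: TaskContext, _: Throwable) =>
    failed = true
  }

  def close(): Unit = {
    if (failed) discard()  // task failed: drop the partial output instead of committing it
    else commit()          // normal path: finalize the output file
  }
}
```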
## How was this patch tested?
Added new unit tests.
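Concretely, the new tests fail a job in the middle of writing by applying a UDF that throws for one of the rows, then check that the writer's failure callback ran. A condensed sketch (`dataSourceName`, `path`, and the `SimpleTextRelation` flags come from the test harness in the diff below):

```scala
import org.apache.spark.SparkException
import org.apache.spark.sql.functions._

// x / (x - 1) throws ArithmeticException when x == 1, so the task fails
// only after some rows have already been handed to the OutputWriter.
val divideByZero = udf((x: Int) => x / (x - 1))
val df = sqlContext.range(0, 10).select(divideByZero(col("id")))

SimpleTextRelation.callbackCalled = false
intercept[SparkException] {
  df.write.format(dataSourceName).save(path)
}
assert(SimpleTextRelation.callbackCalled, "failure callback should be called")
```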
Author: Davies Liu <davies@databricks.com>
Closes #11450 from davies/callback.
Diffstat (limited to 'sql')
3 files changed, 133 insertions, 47 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
index c3db2a0af4..097e9c912b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -247,11 +247,9 @@ private[sql] class DefaultWriterContainer(
     executorSideSetup(taskContext)
     val configuration = taskAttemptContext.getConfiguration
     configuration.set("spark.sql.sources.output.path", outputPath)
-    val writer = newOutputWriter(getWorkPath)
+    var writer = newOutputWriter(getWorkPath)
     writer.initConverter(dataSchema)
-    var writerClosed = false
-
     // If anything below fails, we should abort the task.
     try {
       while (iterator.hasNext) {
@@ -263,16 +261,17 @@ private[sql] class DefaultWriterContainer(
     } catch {
       case cause: Throwable =>
         logError("Aborting task.", cause)
+        // call failure callbacks first, so we could have a chance to cleanup the writer.
+        TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause)
         abortTask()
         throw new SparkException("Task failed while writing rows.", cause)
     }
 
     def commitTask(): Unit = {
       try {
-        assert(writer != null, "OutputWriter instance should have been initialized")
-        if (!writerClosed) {
+        if (writer != null) {
           writer.close()
-          writerClosed = true
+          writer = null
         }
         super.commitTask()
       } catch {
@@ -285,9 +284,8 @@ private[sql] class DefaultWriterContainer(
 
     def abortTask(): Unit = {
       try {
-        if (!writerClosed) {
+        if (writer != null) {
           writer.close()
-          writerClosed = true
         }
       } finally {
         super.abortTask()
@@ -393,57 +391,62 @@ private[sql] class DynamicPartitionWriterContainer(
     val getPartitionString =
       UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns)
-    // If anything below fails, we should abort the task.
-    try {
-      // Sorts the data before write, so that we only need one writer at the same time.
-      // TODO: inject a local sort operator in planning.
-      val sorter = new UnsafeKVExternalSorter(
-        sortingKeySchema,
-        StructType.fromAttributes(dataColumns),
-        SparkEnv.get.blockManager,
-        TaskContext.get().taskMemoryManager().pageSizeBytes)
-
-      while (iterator.hasNext) {
-        val currentRow = iterator.next()
-        sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
-      }
+    // Sorts the data before write, so that we only need one writer at the same time.
+    // TODO: inject a local sort operator in planning.
+    val sorter = new UnsafeKVExternalSorter(
+      sortingKeySchema,
+      StructType.fromAttributes(dataColumns),
+      SparkEnv.get.blockManager,
+      TaskContext.get().taskMemoryManager().pageSizeBytes)
+
+    while (iterator.hasNext) {
+      val currentRow = iterator.next()
+      sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
+    }
+    logInfo(s"Sorting complete. Writing out partition files one at a time.")
-    logInfo(s"Sorting complete. Writing out partition files one at a time.")
+    val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
+      identity
+    } else {
+      UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
+        case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
+      })
+    }
-    val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
-      identity
-    } else {
-      UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
-        case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
-      })
-    }
+    val sortedIterator = sorter.sortedIterator()
-    val sortedIterator = sorter.sortedIterator()
+    // If anything below fails, we should abort the task.
+    var currentWriter: OutputWriter = null
+    try {
       var currentKey: UnsafeRow = null
-      var currentWriter: OutputWriter = null
-      try {
-        while (sortedIterator.next()) {
-          val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
-          if (currentKey != nextKey) {
-            if (currentWriter != null) {
-              currentWriter.close()
-            }
-            currentKey = nextKey.copy()
-            logDebug(s"Writing partition: $currentKey")
-
-            currentWriter = newOutputWriter(currentKey, getPartitionString)
+      while (sortedIterator.next()) {
+        val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
+        if (currentKey != nextKey) {
+          if (currentWriter != null) {
+            currentWriter.close()
+            currentWriter = null
           }
+          currentKey = nextKey.copy()
+          logDebug(s"Writing partition: $currentKey")
-          currentWriter.writeInternal(sortedIterator.getValue)
+          currentWriter = newOutputWriter(currentKey, getPartitionString)
         }
-      } finally {
-        if (currentWriter != null) { currentWriter.close() }
+        currentWriter.writeInternal(sortedIterator.getValue)
+      }
+      if (currentWriter != null) {
+        currentWriter.close()
+        currentWriter = null
       }
       commitTask()
     } catch {
       case cause: Throwable =>
         logError("Aborting task.", cause)
+        // call failure callbacks first, so we could have a chance to cleanup the writer.
+        TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause)
+        if (currentWriter != null) {
+          currentWriter.close()
+        }
         abortTask()
         throw new SparkException("Task failed while writing rows.", cause)
     }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
index 64c61a5092..64c27da475 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
@@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.SparkException
 import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.test.SQLTestUtils
 
@@ -30,6 +31,7 @@ class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton
   val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName
 
   test("SPARK-7684: commitTask() failure should fallback to abortTask()") {
+    SimpleTextRelation.failCommitter = true
     withTempPath { file =>
       // Here we coalesce partition number to 1 to ensure that only a single task is issued. This
       // prevents race condition happened when FileOutputCommitter tries to remove the `_temporary`
@@ -43,4 +45,59 @@ class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton
       assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary")))
     }
   }
+
+  test("call failure callbacks before close writer - default") {
+    SimpleTextRelation.failCommitter = false
+    withTempPath { file =>
+      // fail the job in the middle of writing
+      val divideByZero = udf((x: Int) => { x / (x - 1)})
+      val df = sqlContext.range(0, 10).select(divideByZero(col("id")))
+
+      SimpleTextRelation.callbackCalled = false
+      intercept[SparkException] {
+        df.write.format(dataSourceName).save(file.getCanonicalPath)
+      }
+      assert(SimpleTextRelation.callbackCalled, "failure callback should be called")
+
+      val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf)
+      assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary")))
+    }
+  }
+
+  test("failure callback of writer should not be called if failed before writing") {
+    SimpleTextRelation.failCommitter = false
+    withTempPath { file =>
+      // fail the job in the middle of writing
+      val divideByZero = udf((x: Int) => { x / (x - 1)})
+      val df = sqlContext.range(0, 10).select(col("id").mod(2).as("key"), divideByZero(col("id")))
+
+      SimpleTextRelation.callbackCalled = false
+      intercept[SparkException] {
+        df.write.format(dataSourceName).partitionBy("key").save(file.getCanonicalPath)
+      }
+      assert(!SimpleTextRelation.callbackCalled,
+        "the callback of writer should not be called if job failed before writing")
+
+      val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf)
+      assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary")))
+    }
+  }
+
+  test("call failure callbacks before close writer - partitioned") {
+    SimpleTextRelation.failCommitter = false
+    withTempPath { file =>
+      // fail the job in the middle of writing
+      val df = sqlContext.range(0, 10).select(col("id").mod(2).as("key"), col("id"))
+
+      SimpleTextRelation.callbackCalled = false
+      SimpleTextRelation.failWriter = true
+      intercept[SparkException] {
+        df.write.format(dataSourceName).partitionBy("key").save(file.getCanonicalPath)
+      }
+      assert(SimpleTextRelation.callbackCalled, "failure callback should be called")
+
+      val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf)
+      assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary")))
+    }
+  }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index 9fc437bf88..9cdf1fc585 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -25,6 +25,7 @@ import org.apache.hadoop.io.{NullWritable, Text}
 import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
 import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
 
+import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{sources, Row, SQLContext}
 import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters}
@@ -199,6 +200,15 @@ object SimpleTextRelation {
 
   // Used to test filter push-down
   var pushedFilters: Set[Filter] = Set.empty
+
+  // Used to test failed committer
+  var failCommitter = false
+
+  // Used to test failed writer
+  var failWriter = false
+
+  // Used to test failure callback
+  var callbackCalled = false
 }
 
 /**
@@ -229,9 +239,25 @@ class CommitFailureTestRelation(
       dataSchema: StructType,
       context: TaskAttemptContext): OutputWriter = {
     new SimpleTextOutputWriter(path, context) {
+      var failed = false
+      TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
+        failed = true
+        SimpleTextRelation.callbackCalled = true
+      }
+
+      override def write(row: Row): Unit = {
+        if (SimpleTextRelation.failWriter) {
+          sys.error("Intentional task writer failure for testing purpose.")
+
+        }
+        super.write(row)
+      }
+
       override def close(): Unit = {
+        if (SimpleTextRelation.failCommitter) {
+          sys.error("Intentional task commitment failure for testing purpose.")
+        }
         super.close()
-        sys.error("Intentional task commitment failure for testing purpose.")
       }
     }
   }