From b5a59a0fe2ea703ce2712561e7b9044f772660a2 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 2 Mar 2016 14:35:44 -0800
Subject: [SPARK-13601] call failure callbacks before writer.close()

## What changes were proposed in this pull request?

In order to tell the OutputStream whether the task has failed, we should call the failure callbacks BEFORE calling writer.close().

## How was this patch tested?

Added new unit tests.

Author: Davies Liu

Closes #11450 from davies/callback.
---
 .../execution/datasources/WriterContainer.scala | 95 +++++++++++-----------
 1 file changed, 49 insertions(+), 46 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
index c3db2a0af4..097e9c912b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -247,11 +247,9 @@ private[sql] class DefaultWriterContainer(
     executorSideSetup(taskContext)
     val configuration = taskAttemptContext.getConfiguration
     configuration.set("spark.sql.sources.output.path", outputPath)
-    val writer = newOutputWriter(getWorkPath)
+    var writer = newOutputWriter(getWorkPath)
     writer.initConverter(dataSchema)
-    var writerClosed = false
-
     // If anything below fails, we should abort the task.
     try {
       while (iterator.hasNext) {
@@ -263,16 +261,17 @@
     } catch {
       case cause: Throwable =>
         logError("Aborting task.", cause)
+        // call failure callbacks first, so we could have a chance to cleanup the writer.
+        TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause)
         abortTask()
         throw new SparkException("Task failed while writing rows.", cause)
     }

     def commitTask(): Unit = {
       try {
-        assert(writer != null, "OutputWriter instance should have been initialized")
-        if (!writerClosed) {
+        if (writer != null) {
           writer.close()
-          writerClosed = true
+          writer = null
         }
         super.commitTask()
       } catch {
@@ -285,9 +284,8 @@

     def abortTask(): Unit = {
       try {
-        if (!writerClosed) {
+        if (writer != null) {
           writer.close()
-          writerClosed = true
         }
       } finally {
         super.abortTask()
@@ -393,57 +391,62 @@ private[sql] class DynamicPartitionWriterContainer(
     val getPartitionString =
       UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns)
-    // If anything below fails, we should abort the task.
-    try {
-      // Sorts the data before write, so that we only need one writer at the same time.
-      // TODO: inject a local sort operator in planning.
-      val sorter = new UnsafeKVExternalSorter(
-        sortingKeySchema,
-        StructType.fromAttributes(dataColumns),
-        SparkEnv.get.blockManager,
-        TaskContext.get().taskMemoryManager().pageSizeBytes)
-
-      while (iterator.hasNext) {
-        val currentRow = iterator.next()
-        sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
-      }
+    // Sorts the data before write, so that we only need one writer at the same time.
+    // TODO: inject a local sort operator in planning.
+    val sorter = new UnsafeKVExternalSorter(
+      sortingKeySchema,
+      StructType.fromAttributes(dataColumns),
+      SparkEnv.get.blockManager,
+      TaskContext.get().taskMemoryManager().pageSizeBytes)
+
+    while (iterator.hasNext) {
+      val currentRow = iterator.next()
+      sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
+    }
+    logInfo(s"Sorting complete. Writing out partition files one at a time.")

-      logInfo(s"Sorting complete. Writing out partition files one at a time.")
+    val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
+      identity
+    } else {
+      UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
+        case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
+      })
+    }

-      val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
-        identity
-      } else {
-        UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
-          case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
-        })
-      }
+    val sortedIterator = sorter.sortedIterator()

-      val sortedIterator = sorter.sortedIterator()
+    // If anything below fails, we should abort the task.
+    var currentWriter: OutputWriter = null
+    try {
       var currentKey: UnsafeRow = null
-      var currentWriter: OutputWriter = null
-      try {
-        while (sortedIterator.next()) {
-          val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
-          if (currentKey != nextKey) {
-            if (currentWriter != null) {
-              currentWriter.close()
-            }
-            currentKey = nextKey.copy()
-            logDebug(s"Writing partition: $currentKey")
-
-            currentWriter = newOutputWriter(currentKey, getPartitionString)
+      while (sortedIterator.next()) {
+        val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
+        if (currentKey != nextKey) {
+          if (currentWriter != null) {
+            currentWriter.close()
+            currentWriter = null
           }
+          currentKey = nextKey.copy()
+          logDebug(s"Writing partition: $currentKey")

-          currentWriter.writeInternal(sortedIterator.getValue)
+          currentWriter = newOutputWriter(currentKey, getPartitionString)
         }
-      } finally {
-        if (currentWriter != null) { currentWriter.close() }
+        currentWriter.writeInternal(sortedIterator.getValue)
+      }
+      if (currentWriter != null) {
+        currentWriter.close()
+        currentWriter = null
       }

       commitTask()
     } catch {
       case cause: Throwable =>
         logError("Aborting task.", cause)
+        // call failure callbacks first, so we could have a chance to cleanup the writer.
+        TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause)
+        if (currentWriter != null) {
+          currentWriter.close()
+        }
         abortTask()
         throw new SparkException("Task failed while writing rows.", cause)
     }
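
The sketch below is not part of the patch; it illustrates why the ordering change matters. Because `markTaskFailed(cause)` now runs before `writer.close()` in the abort path of both containers, a writer that registered a failure callback through `TaskContext.addTaskFailureListener` has already been notified when `close()` runs, so it can discard its buffered output instead of finalizing it. The class `FailureAwareWriter` and its methods are hypothetical names used only for illustration.

```scala
import org.apache.spark.TaskContext
import org.apache.spark.util.TaskFailureListener

// Hypothetical writer that observes task failure via the failure-callback hook.
class FailureAwareWriter(taskContext: TaskContext) {
  @volatile private var taskFailed = false

  // The listener fires when markTaskFailed(cause) is invoked; with this patch
  // that happens BEFORE close() is called on the abort path.
  taskContext.addTaskFailureListener(new TaskFailureListener {
    override def onTaskFailure(context: TaskContext, error: Throwable): Unit = {
      taskFailed = true
    }
  })

  def write(record: String): Unit = {
    // Buffer or stream the record (omitted in this sketch).
  }

  def close(): Unit = {
    if (taskFailed) {
      // Task is being aborted: drop buffered data, do not finalize the output file.
    } else {
      // Normal path: flush buffers and finalize the output file.
    }
  }
}
```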