aboutsummaryrefslogtreecommitdiff
path: root/core/src/main
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-12-09 15:44:22 -0800
committerShixiong Zhu <shixiong@databricks.com>2016-12-09 15:44:22 -0800
commitcf33a86285629abe72c1acf235b8bfa6057220a8 (patch)
treeabb07697888303338b3c481dd15e82a2a573e495 /core/src/main
parentd60ab5fd9b6af9aa5080a2d13b3589d8b79c5c5c (diff)
downloadspark-cf33a86285629abe72c1acf235b8bfa6057220a8.tar.gz
spark-cf33a86285629abe72c1acf235b8bfa6057220a8.tar.bz2
spark-cf33a86285629abe72c1acf235b8bfa6057220a8.zip
[SPARK-4105] retry the fetch or stage if shuffle block is corrupt
## What changes were proposed in this pull request? There is an outstanding issue that existed for a long time: Sometimes the shuffle blocks are corrupt and can't be decompressed. We recently hit this in three different workloads, sometimes we can reproduce it by every try, sometimes can't. I also found that when the corruption happened, the beginning and end of the blocks are correct, the corruption happen in the middle. There was one case that the string of block id is corrupt by one character. It seems that it's very likely the corruption is introduced by some weird machine/hardware, also the checksum (16 bits) in TCP is not strong enough to identify all the corruption. Unfortunately, Spark does not have checksum for shuffle blocks or broadcast, the job will fail if any corruption happen in the shuffle block from disk, or broadcast blocks during network. This PR try to detect the corruption after fetching shuffle blocks by decompressing them, because most of the compression already have checksum in them. It will retry the block, or failed with FetchFailure, so the previous stage could be retried on different (still random) machines. Checksum for broadcast will be added by another PR. ## How was this patch tested? Added unit tests Author: Davies Liu <davies@databricks.com> Closes #15923 from davies/detect_corrupt.
Diffstat (limited to 'core/src/main')
-rw-r--r--core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala13
-rw-r--r--core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala133
-rw-r--r--core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala2
3 files changed, 100 insertions, 48 deletions
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index b9d83495d2..8b2e26cdd9 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -42,24 +42,21 @@ private[spark] class BlockStoreShuffleReader[K, C](
/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
- val blockFetcherItr = new ShuffleBlockFetcherIterator(
+ val wrappedStreams = new ShuffleBlockFetcherIterator(
context,
blockManager.shuffleClient,
blockManager,
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
+ serializerManager.wrapStream,
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
- SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))
-
- // Wrap the streams for compression and encryption based on configuration
- val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
- serializerManager.wrapStream(blockId, inputStream)
- }
+ SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
+ SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
val serializerInstance = dep.serializer.newInstance()
// Create a key/value iterator for each stream
- val recordIter = wrappedStreams.flatMap { wrappedStream =>
+ val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
// Note: the asKeyValueIterator below wraps a key/value iterator inside of a
// NextIterator. The NextIterator makes sure that close() is called on the
// underlying InputStream when all records have been read.
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 269c12d6da..b720aaee7c 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -17,19 +17,21 @@
package org.apache.spark.storage
-import java.io.InputStream
+import java.io.{InputStream, IOException}
+import java.nio.ByteBuffer
import java.util.concurrent.LinkedBlockingQueue
import javax.annotation.concurrent.GuardedBy
+import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
-import scala.util.control.NonFatal
import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.internal.Logging
-import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util.Utils
+import org.apache.spark.util.io.ChunkedByteBufferOutputStream
/**
* An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
@@ -47,8 +49,10 @@ import org.apache.spark.util.Utils
* @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
* For each block we also require the size (in bytes as a long field) in
* order to throttle the memory usage.
+ * @param streamWrapper A function to wrap the returned input stream.
* @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
* @param maxReqsInFlight max number of remote requests to fetch blocks at any given point.
+ * @param detectCorrupt whether to detect any corruption in fetched blocks.
*/
private[spark]
final class ShuffleBlockFetcherIterator(
@@ -56,8 +60,10 @@ final class ShuffleBlockFetcherIterator(
shuffleClient: ShuffleClient,
blockManager: BlockManager,
blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
+ streamWrapper: (BlockId, InputStream) => InputStream,
maxBytesInFlight: Long,
- maxReqsInFlight: Int)
+ maxReqsInFlight: Int,
+ detectCorrupt: Boolean)
extends Iterator[(BlockId, InputStream)] with Logging {
import ShuffleBlockFetcherIterator._
@@ -94,7 +100,7 @@ final class ShuffleBlockFetcherIterator(
* Current [[FetchResult]] being processed. We track this so we can release the current buffer
* in case of a runtime exception when processing the current buffer.
*/
- @volatile private[this] var currentResult: FetchResult = null
+ @volatile private[this] var currentResult: SuccessFetchResult = null
/**
* Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
@@ -108,6 +114,12 @@ final class ShuffleBlockFetcherIterator(
/** Current number of requests in flight */
private[this] var reqsInFlight = 0
+ /**
+ * The blocks that can't be decompressed successfully, it is used to guarantee that we retry
+ * at most once for those corrupted blocks.
+ */
+ private[this] val corruptedBlocks = mutable.HashSet[BlockId]()
+
private[this] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()
/**
@@ -123,9 +135,8 @@ final class ShuffleBlockFetcherIterator(
// The currentResult is set to null to prevent releasing the buffer again on cleanup()
private[storage] def releaseCurrentResultBuffer(): Unit = {
// Release the current buffer if necessary
- currentResult match {
- case SuccessFetchResult(_, _, _, buf, _) => buf.release()
- case _ =>
+ if (currentResult != null) {
+ currentResult.buf.release()
}
currentResult = null
}
@@ -305,40 +316,84 @@ final class ShuffleBlockFetcherIterator(
*/
override def next(): (BlockId, InputStream) = {
numBlocksProcessed += 1
- val startFetchWait = System.currentTimeMillis()
- currentResult = results.take()
- val result = currentResult
- val stopFetchWait = System.currentTimeMillis()
- shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
-
- result match {
- case SuccessFetchResult(_, address, size, buf, isNetworkReqDone) =>
- if (address != blockManager.blockManagerId) {
- shuffleMetrics.incRemoteBytesRead(buf.size)
- shuffleMetrics.incRemoteBlocksFetched(1)
- }
- bytesInFlight -= size
- if (isNetworkReqDone) {
- reqsInFlight -= 1
- logDebug("Number of requests in flight " + reqsInFlight)
- }
- case _ =>
- }
- // Send fetch requests up to maxBytesInFlight
- fetchUpToMaxBytes()
- result match {
- case FailureFetchResult(blockId, address, e) =>
- throwFetchFailedException(blockId, address, e)
+ var result: FetchResult = null
+ var input: InputStream = null
+ // Take the next fetched result and try to decompress it to detect data corruption,
+ // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch
+ // is also corrupt, so the previous stage could be retried.
+ // For local shuffle block, throw FailureFetchResult for the first IOException.
+ while (result == null) {
+ val startFetchWait = System.currentTimeMillis()
+ result = results.take()
+ val stopFetchWait = System.currentTimeMillis()
+ shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
- case SuccessFetchResult(blockId, address, _, buf, _) =>
- try {
- (result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this))
- } catch {
- case NonFatal(t) =>
- throwFetchFailedException(blockId, address, t)
- }
+ result match {
+ case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) =>
+ if (address != blockManager.blockManagerId) {
+ shuffleMetrics.incRemoteBytesRead(buf.size)
+ shuffleMetrics.incRemoteBlocksFetched(1)
+ }
+ bytesInFlight -= size
+ if (isNetworkReqDone) {
+ reqsInFlight -= 1
+ logDebug("Number of requests in flight " + reqsInFlight)
+ }
+
+ val in = try {
+ buf.createInputStream()
+ } catch {
+ // The exception could only be throwed by local shuffle block
+ case e: IOException =>
+ assert(buf.isInstanceOf[FileSegmentManagedBuffer])
+ logError("Failed to create input stream from local block", e)
+ buf.release()
+ throwFetchFailedException(blockId, address, e)
+ }
+
+ input = streamWrapper(blockId, in)
+ // Only copy the stream if it's wrapped by compression or encryption, also the size of
+ // block is small (the decompressed block is smaller than maxBytesInFlight)
+ if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) {
+ val originalInput = input
+ val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
+ try {
+ // Decompress the whole block at once to detect any corruption, which could increase
+ // the memory usage tne potential increase the chance of OOM.
+ // TODO: manage the memory used here, and spill it into disk in case of OOM.
+ Utils.copyStream(input, out)
+ out.close()
+ input = out.toChunkedByteBuffer.toInputStream(dispose = true)
+ } catch {
+ case e: IOException =>
+ buf.release()
+ if (buf.isInstanceOf[FileSegmentManagedBuffer]
+ || corruptedBlocks.contains(blockId)) {
+ throwFetchFailedException(blockId, address, e)
+ } else {
+ logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
+ corruptedBlocks += blockId
+ fetchRequests += FetchRequest(address, Array((blockId, size)))
+ result = null
+ }
+ } finally {
+ // TODO: release the buf here to free memory earlier
+ originalInput.close()
+ in.close()
+ }
+ }
+
+ case FailureFetchResult(blockId, address, e) =>
+ throwFetchFailedException(blockId, address, e)
+ }
+
+ // Send fetch requests up to maxBytesInFlight
+ fetchUpToMaxBytes()
}
+
+ currentResult = result.asInstanceOf[SuccessFetchResult]
+ (currentResult.blockId, new BufferReleasingInputStream(input, this))
}
private def fetchUpToMaxBytes(): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
index da08661d13..7572cac393 100644
--- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
@@ -151,7 +151,7 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
* @param dispose if true, `ChunkedByteBuffer.dispose()` will be called at the end of the stream
* in order to close any memory-mapped files which back the buffer.
*/
-private class ChunkedByteBufferInputStream(
+private[spark] class ChunkedByteBufferInputStream(
var chunkedByteBuffer: ChunkedByteBuffer,
dispose: Boolean)
extends InputStream {