author:    Davies Liu <davies@databricks.com>    2016-12-09 15:44:22 -0800
committer: Shixiong Zhu <shixiong@databricks.com>    2016-12-09 15:44:22 -0800
commit:    cf33a86285629abe72c1acf235b8bfa6057220a8 (patch)
tree:      abb07697888303338b3c481dd15e82a2a573e495 /core/src/test/scala/org
parent:    d60ab5fd9b6af9aa5080a2d13b3589d8b79c5c5c (diff)
[SPARK-4105] retry the fetch or stage if shuffle block is corrupt
## What changes were proposed in this pull request?

There is an outstanding issue that has existed for a long time: sometimes shuffle blocks are corrupt and can't be decompressed. We recently hit this in three different workloads; sometimes we can reproduce it on every attempt, sometimes we can't. I also found that when the corruption happened, the beginning and end of the blocks were correct and the corruption was in the middle. In one case the string of the block id was corrupted by a single character. It seems very likely that the corruption is introduced by some weird machine/hardware, and the 16-bit checksum in TCP is not strong enough to detect all of it. Unfortunately, Spark has no checksums for shuffle blocks or broadcast data, so the job will fail if any corruption happens in a shuffle block read from disk, or in a broadcast block during network transfer.

This PR tries to detect corruption after fetching shuffle blocks by decompressing them, because most compression codecs already carry a checksum. On failure it will retry the block, or fail with a FetchFailedException so that the previous stage can be retried on different (still random) machines. Checksums for broadcast will be added in another PR.

## How was this patch tested?

Added unit tests.

Author: Davies Liu <davies@databricks.com>

Closes #15923 from davies/detect_corrupt.
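The mechanism described above — eagerly decompressing fetched blocks so the codec's built-in checksum exposes corruption at fetch time — can be sketched compactly. This is a minimal sketch under stated assumptions, not Spark's implementation: `fetchWithCorruptionCheck` is a hypothetical helper, it buffers the whole block in memory, and the real iterator reports the final failure as a FetchFailedException rather than rethrowing:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, IOException, InputStream}

// Hypothetical helper, not Spark code: eagerly read (and therefore decompress)
// a fetched block so a codec checksum failure surfaces as an IOException at
// fetch time, and re-fetch once before letting the failure propagate.
def fetchWithCorruptionCheck(open: () => InputStream, retriesLeft: Int = 1): InputStream = {
  try {
    val in = open()
    try {
      val buffered = new ByteArrayOutputStream()
      val buf = new Array[Byte](8192)
      var n = in.read(buf)
      while (n != -1) { buffered.write(buf, 0, n); n = in.read(buf) }
      // Fully decompressed without error: hand the buffered bytes to the task.
      new ByteArrayInputStream(buffered.toByteArray)
    } finally {
      in.close()
    }
  } catch {
    case _: IOException if retriesLeft > 0 =>
      // Corruption detected: retry the fetch once. If the retry also fails,
      // the IOException escapes (Spark would wrap it in a FetchFailedException
      // so the previous stage can be rerun on different machines).
      fetchWithCorruptionCheck(open, retriesLeft - 1)
  }
}
```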
Diffstat (limited to 'core/src/test/scala/org')
-rw-r--r-- core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala | 172
1 file changed, 163 insertions(+), 9 deletions(-)
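The test changes below also record an interface change: every construction of `ShuffleBlockFetcherIterator` now passes a stream-wrapping function and a trailing boolean. A hedged reading of the new constructor shape, inferred purely from the call sites in this suite (the parameter names here are guesses, not necessarily Spark's):

```scala
import java.io.InputStream
import org.apache.spark.TaskContext
import org.apache.spark.network.BlockTransferService
import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId}

// Shape sketch only, inferred from the call sites in this diff.
class AssumedShuffleBlockFetcherIterator(
    context: TaskContext,
    transfer: BlockTransferService,
    blockManager: BlockManager,
    blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
    streamWrapper: (BlockId, InputStream) => InputStream, // new: lets callers wrap/limit streams
    maxBytesInFlight: Long,                               // 48 * 1024 * 1024 in these tests
    maxReqsInFlight: Int,                                 // Int.MaxValue in these tests
    detectCorrupt: Boolean)                               // new: the flag the added tests toggle
```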
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index e3ec99685f..e56e440380 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.storage
-import java.io.InputStream
+import java.io.{File, InputStream, IOException}
import java.util.concurrent.Semaphore
import scala.concurrent.ExecutionContext.Implicits.global
@@ -31,8 +31,9 @@ import org.scalatest.PrivateMethodTester
import org.apache.spark.{SparkFunSuite, TaskContext}
import org.apache.spark.network._
-import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.shuffle.BlockFetchingListener
+import org.apache.spark.network.util.LimitedInputStream
import org.apache.spark.shuffle.FetchFailedException
@@ -63,7 +64,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
// Create a mock managed buffer for testing
def createMockManagedBuffer(): ManagedBuffer = {
val mockManagedBuffer = mock(classOf[ManagedBuffer])
- when(mockManagedBuffer.createInputStream()).thenReturn(mock(classOf[InputStream]))
+ val in = mock(classOf[InputStream])
+ when(in.read(any())).thenReturn(1)
+ when(in.read(any(), any(), any())).thenReturn(1)
+ when(mockManagedBuffer.createInputStream()).thenReturn(in)
mockManagedBuffer
}
@@ -99,8 +103,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
transfer,
blockManager,
blocksByAddress,
+ (_, in) => in,
48 * 1024 * 1024,
- Int.MaxValue)
+ Int.MaxValue,
+ true)
// 3 local blocks fetched in initialization
verify(blockManager, times(3)).getBlockData(any())
@@ -172,8 +178,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
transfer,
blockManager,
blocksByAddress,
+ (_, in) => in,
48 * 1024 * 1024,
- Int.MaxValue)
+ Int.MaxValue,
+ true)
verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release()
iterator.next()._2.close() // close() first block's input stream
@@ -201,9 +209,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
// Make sure remote blocks would return
val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
val blocks = Map[BlockId, ManagedBuffer](
- ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]),
- ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]),
- ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])
+ ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
+ ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
+ ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()
)
// Semaphore to coordinate event sequence in two different threads.
@@ -235,8 +243,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
transfer,
blockManager,
blocksByAddress,
+ (_, in) => in,
48 * 1024 * 1024,
- Int.MaxValue)
+ Int.MaxValue,
+ true)
// Continue only after the mock calls onBlockFetchFailure
sem.acquire()
@@ -247,4 +257,148 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
intercept[FetchFailedException] { iterator.next() }
intercept[FetchFailedException] { iterator.next() }
}
+
+ test("retry corrupt blocks") {
+ val blockManager = mock(classOf[BlockManager])
+ val localBmId = BlockManagerId("test-client", "test-client", 1)
+ doReturn(localBmId).when(blockManager).blockManagerId
+
+ // Make sure remote blocks are returned
+ val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+ val blocks = Map[BlockId, ManagedBuffer](
+ ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
+ ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
+ ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()
+ )
+
+ // Semaphore to coordinate event sequence in two different threads.
+ val sem = new Semaphore(0)
+
+ val corruptStream = mock(classOf[InputStream])
+ when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt"))
+ val corruptBuffer = mock(classOf[ManagedBuffer])
+ when(corruptBuffer.createInputStream()).thenReturn(corruptStream)
+ val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100)
+
+ val transfer = mock(classOf[BlockTransferService])
+ when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+ override def answer(invocation: InvocationOnMock): Unit = {
+ val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+ Future {
+ // Return the first block normally, then a corrupt remote buffer and a corrupt local buffer.
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 1, 0).toString, corruptBuffer)
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer)
+ sem.release()
+ }
+ }
+ })
+
+ val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+ (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
+
+ val taskContext = TaskContext.empty()
+ val iterator = new ShuffleBlockFetcherIterator(
+ taskContext,
+ transfer,
+ blockManager,
+ blocksByAddress,
+ (_, in) => new LimitedInputStream(in, 100),
+ 48 * 1024 * 1024,
+ Int.MaxValue,
+ true)
+
+ // Continue only after the mock has delivered all three blocks
+ sem.acquire()
+
+ // The first block should be returned without an exception
+ val (id1, _) = iterator.next()
+ assert(id1 === ShuffleBlockId(0, 0, 0))
+
+ when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+ override def answer(invocation: InvocationOnMock): Unit = {
+ val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+ Future {
+ // The re-fetched block is corrupt again.
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 1, 0).toString, corruptBuffer)
+ sem.release()
+ }
+ }
+ })
+
+ // The next result is the corrupt local block (the second block is corrupt and being retried)
+ intercept[FetchFailedException] { iterator.next() }
+
+ sem.acquire()
+ intercept[FetchFailedException] { iterator.next() }
+ }
+
+ test("retry corrupt blocks (disabled)") {
+ val blockManager = mock(classOf[BlockManager])
+ val localBmId = BlockManagerId("test-client", "test-client", 1)
+ doReturn(localBmId).when(blockManager).blockManagerId
+
+ // Make sure remote blocks are returned
+ val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+ val blocks = Map[BlockId, ManagedBuffer](
+ ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
+ ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
+ ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()
+ )
+
+ // Semaphore to coordinate event sequence in two different threads.
+ val sem = new Semaphore(0)
+
+ val corruptStream = mock(classOf[InputStream])
+ when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt"))
+ val corruptBuffer = mock(classOf[ManagedBuffer])
+ when(corruptBuffer.createInputStream()).thenReturn(corruptStream)
+
+ val transfer = mock(classOf[BlockTransferService])
+ when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+ override def answer(invocation: InvocationOnMock): Unit = {
+ val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+ Future {
+ // Return the first block normally, then two corrupt buffers.
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 1, 0).toString, corruptBuffer)
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 2, 0).toString, corruptBuffer)
+ sem.release()
+ }
+ }
+ })
+
+ val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+ (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
+
+ val taskContext = TaskContext.empty()
+ val iterator = new ShuffleBlockFetcherIterator(
+ taskContext,
+ transfer,
+ blockManager,
+ blocksByAddress,
+ (_, in) => new LimitedInputStream(in, 100),
+ 48 * 1024 * 1024,
+ Int.MaxValue,
+ false)
+
+ // Continue only after the mock has delivered all three blocks
+ sem.acquire()
+
+ // With corruption detection disabled, all three blocks are returned without an exception
+ val (id1, _) = iterator.next()
+ assert(id1 === ShuffleBlockId(0, 0, 0))
+ val (id2, _) = iterator.next()
+ assert(id2 === ShuffleBlockId(0, 1, 0))
+ val (id3, _) = iterator.next()
+ assert(id3 === ShuffleBlockId(0, 2, 0))
+ }
+
}
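Taken together, the two new tests pin down the flag's behavior: with the trailing constructor argument set to true, corrupt streams surface as FetchFailedException (after one re-fetch of the offending remote block), while with false the same corrupt buffers are handed back to the caller untouched. As a quick usage check of the hypothetical `fetchWithCorruptionCheck` sketch shown earlier (GZIP is chosen here only as a readily available checksummed codec; none of this is Spark code):

```scala
import java.io.{ByteArrayInputStream, IOException}
import java.util.zip.GZIPInputStream

// A 5-byte payload with a valid gzip magic number but a bogus compression
// method: opening it throws an IOException (ZipException) immediately.
val garbage = Array[Byte](0x1f, 0x8b.toByte, 1, 2, 3)

try {
  // Fails while parsing the gzip header, is retried once with the same
  // bytes, and then the IOException propagates to the caller.
  fetchWithCorruptionCheck(() => new GZIPInputStream(new ByteArrayInputStream(garbage)))
} catch {
  case e: IOException => println(s"detected corrupt block: ${e.getMessage}")
}
```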