-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala      | 47
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala | 84
2 files changed, 116 insertions(+), 15 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 873330e136..bcf65e9d7e 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -133,6 +133,9 @@ private[spark] class BlockManager(
private val compressRdds = conf.getBoolean("spark.rdd.compress", false)
// Whether to compress shuffle output temporarily spilled to disk
private val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true)
+ // Max number of failures before this block manager refreshes the block locations from the driver
+ private val maxFailuresBeforeLocationRefresh =
+ conf.getInt("spark.block.failures.beforeLocationRefresh", 5)

private val slaveEndpoint = rpcEnv.setupEndpoint(
"BlockManagerEndpoint" + BlockManager.ID_GENERATOR.next,
@@ -568,26 +571,46 @@ private[spark] class BlockManager(
def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = {
logDebug(s"Getting remote block $blockId")
require(blockId != null, "BlockId is null")
+ var runningFailureCount = 0
+ var totalFailureCount = 0
val locations = getLocations(blockId)
- var numFetchFailures = 0
- for (loc <- locations) {
+ val maxFetchFailures = locations.size
+ var locationIterator = locations.iterator
+ while (locationIterator.hasNext) {
+ val loc = locationIterator.next()
logDebug(s"Getting remote block $blockId from $loc")
val data = try {
blockTransferService.fetchBlockSync(
loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer()
} catch {
case NonFatal(e) =>
- numFetchFailures += 1
- if (numFetchFailures == locations.size) {
- // An exception is thrown while fetching this block from all locations
- throw new BlockFetchException(s"Failed to fetch block from" +
- s" ${locations.size} locations. Most recent failure cause:", e)
- } else {
- // This location failed, so we retry fetch from a different one by returning null here
- logWarning(s"Failed to fetch remote block $blockId " +
- s"from $loc (failed attempt $numFetchFailures)", e)
- null
+ runningFailureCount += 1
+ totalFailureCount += 1
+
+ if (totalFailureCount >= maxFetchFailures) {
+ // Give up trying any more locations. Either we've tried all of the original locations,
+ // or we've refreshed the list of locations from the master, and have still
+ // hit failures after trying locations from the refreshed list.
+ throw new BlockFetchException(s"Failed to fetch block after" +
+ s" ${totalFailureCount} fetch failures. Most recent failure cause:", e)
+ }
+
+ logWarning(s"Failed to fetch remote block $blockId " +
+ s"from $loc (failed attempt $runningFailureCount)", e)
+
+ // If there is a large number of executors, the locations list can contain many
+ // stale entries, causing a large number of retries that may take a significant
+ // amount of time. To get rid of these stale entries we refresh the block
+ // locations from the driver after a certain number of fetch failures.
+ if (runningFailureCount >= maxFailuresBeforeLocationRefresh) {
+ locationIterator = getLocations(blockId).iterator
+ logDebug(s"Refreshed locations from the driver " +
+ s"after ${runningFailureCount} fetch failures.")
+ runningFailureCount = 0
}
+
+ // This location failed, so we retry fetch from a different one by returning null here
+ null
}
if (data != null) {
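
For readers skimming the patch, here is a minimal, self-contained sketch of the retry-and-refresh control flow that the hunk above adds to getRemoteBytes. The helper name fetchWithRefresh and its parameters are hypothetical; fetch and currentLocations stand in for blockTransferService.fetchBlockSync and getLocations, and only the failure counting and refresh logic mirrors the patch.

import scala.util.control.NonFatal

// Hedged sketch, not Spark code: retry across locations, refresh the (possibly
// stale) location list after `refreshThreshold` consecutive failures, and give up
// once the total number of failures reaches the size of the original list.
def fetchWithRefresh[Loc, Block](
    initialLocations: Seq[Loc],
    refreshThreshold: Int,
    currentLocations: () => Seq[Loc],
    fetch: Loc => Block): Option[Block] = {
  var runningFailureCount = 0
  var totalFailureCount = 0
  val maxFetchFailures = initialLocations.size
  var locationIterator = initialLocations.iterator
  while (locationIterator.hasNext) {
    val loc = locationIterator.next()
    try {
      return Some(fetch(loc))
    } catch {
      case NonFatal(e) =>
        runningFailureCount += 1
        totalFailureCount += 1
        if (totalFailureCount >= maxFetchFailures) {
          throw new RuntimeException(
            s"Failed to fetch block after $totalFailureCount fetch failures", e)
        }
        // Ask for a fresh location list once the running failure count hits the threshold.
        if (runningFailureCount >= refreshThreshold) {
          locationIterator = currentLocations().iterator
          runningFailureCount = 0
        }
    }
  }
  None
}

Note that, as in the patch, the total failure budget stays fixed at the size of the original location list, so a refresh changes which locations are tried but does not extend the number of retries indefinitely.
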
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 42595c8cf2..dc4be14677 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -21,11 +21,12 @@ import java.nio.ByteBuffer

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration._
+import scala.concurrent.Future
import scala.language.implicitConversions
import scala.language.postfixOps

import org.mockito.{Matchers => mc}
-import org.mockito.Mockito.{mock, when}
+import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest._
import org.scalatest.concurrent.Eventually._
import org.scalatest.concurrent.Timeouts._
@@ -33,7 +34,10 @@ import org.scalatest.concurrent.Timeouts._
import org.apache.spark._
import org.apache.spark.executor.DataReadMethod
import org.apache.spark.memory.StaticMemoryManager
+import org.apache.spark.network.{BlockDataManager, BlockTransferService}
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.netty.NettyBlockTransferService
+import org.apache.spark.network.shuffle.BlockFetchingListener
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
@@ -66,9 +70,11 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
private def makeBlockManager(
maxMem: Long,
name: String = SparkContext.DRIVER_IDENTIFIER,
- master: BlockManagerMaster = this.master): BlockManager = {
+ master: BlockManagerMaster = this.master,
+ transferService: Option[BlockTransferService] = Option.empty): BlockManager = {
val serializer = new KryoSerializer(conf)
- val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1)
+ val transfer = transferService
+ .getOrElse(new NettyBlockTransferService(conf, securityMgr, numCores = 1))
val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1)
val blockManager = new BlockManager(name, rpcEnv, master, serializer, conf,
memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
@@ -1287,6 +1293,78 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
assert(store.getSingle("a1").isDefined, "a1 was not in store")
assert(store.getSingle("a3").isDefined, "a3 was not in store")
}
+
+ test("SPARK-13328: refresh block locations (fetch should fail after hitting a threshold)") {
+ val mockBlockTransferService =
+ new MockBlockTransferService(conf.getInt("spark.block.failures.beforeLocationRefresh", 5))
+ store = makeBlockManager(8000, "executor1", transferService = Option(mockBlockTransferService))
+ store.putSingle("item", 999L, StorageLevel.MEMORY_ONLY, tellMaster = true)
+ intercept[BlockFetchException] {
+ store.getRemoteBytes("item")
+ }
+ }
+
+ test("SPARK-13328: refresh block locations (fetch should succeed after location refresh)") {
+ val maxFailuresBeforeLocationRefresh =
+ conf.getInt("spark.block.failures.beforeLocationRefresh", 5)
+ val mockBlockManagerMaster = mock(classOf[BlockManagerMaster])
+ val mockBlockTransferService =
+ new MockBlockTransferService(maxFailuresBeforeLocationRefresh)
+ // Make sure there are more than maxFailuresBeforeLocationRefresh locations
+ // so that a location refresh is actually triggered
+ val blockManagerIds = (0 to maxFailuresBeforeLocationRefresh)
+ .map { i => BlockManagerId(s"id-$i", s"host-$i", i + 1) }
+ when(mockBlockManagerMaster.getLocations(mc.any[BlockId])).thenReturn(blockManagerIds)
+ store = makeBlockManager(8000, "executor1", mockBlockManagerMaster,
+ transferService = Option(mockBlockTransferService))
+ val block = store.getRemoteBytes("item")
+ .asInstanceOf[Option[ByteBuffer]]
+ assert(block.isDefined)
+ verify(mockBlockManagerMaster, times(2)).getLocations("item")
+ }
+
+ class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService {
+ var numCalls = 0
+
+ override def init(blockDataManager: BlockDataManager): Unit = {}
+
+ override def fetchBlocks(
+ host: String,
+ port: Int,
+ execId: String,
+ blockIds: Array[String],
+ listener: BlockFetchingListener): Unit = {
+ listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1)))
+ }
+
+ override def close(): Unit = {}
+
+ override def hostName: String = { "MockBlockTransferServiceHost" }
+
+ override def port: Int = { 63332 }
+
+ override def uploadBlock(
+ hostname: String,
+ port: Int, execId: String,
+ blockId: BlockId,
+ blockData: ManagedBuffer,
+ level: StorageLevel): Future[Unit] = {
+ import scala.concurrent.ExecutionContext.Implicits.global
+ Future {}
+ }
+
+ override def fetchBlockSync(
+ host: String,
+ port: Int,
+ execId: String,
+ blockId: String): ManagedBuffer = {
+ numCalls += 1
+ if (numCalls <= maxFailures) {
+ throw new RuntimeException("Failing block fetch in the mock block transfer service")
+ }
+ super.fetchBlockSync(host, port, execId, blockId)
+ }
+ }
}

private object BlockManagerSuite {
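
The new threshold is an ordinary Spark configuration entry, so it can be tuned per application. A hedged usage sketch follows; the key spark.block.failures.beforeLocationRefresh and its default of 5 come from the patch above, while the value 3 and the application name are purely illustrative.

import org.apache.spark.SparkConf

// Illustrative only: lower the refresh threshold from the default of 5 to 3,
// so that stale block locations are re-fetched from the driver sooner.
val conf = new SparkConf()
  .setAppName("block-location-refresh-example")
  .set("spark.block.failures.beforeLocationRefresh", "3")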