-rw-r--r--  core/src/main/scala/org/apache/spark/SecurityManager.scala                          |  2
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala          |  3
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala      | 40
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala |  9
-rw-r--r--  docs/configuration.md                                                               | 10
5 files changed, 49 insertions, 15 deletions
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index 0675957e16..6132fa349e 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -69,7 +69,7 @@ import org.apache.spark.util.Utils
*
* - HTTP for broadcast and file server (via HttpServer) -> Spark currently uses Jetty
* for the HttpServer. Jetty supports multiple authentication mechanisms -
- * Basic, Digest, Form, Spengo, etc. It also supports multiple different login
+ * Basic, Digest, Form, Spnego, etc. It also supports multiple different login
* services - Hash, JAAS, Spnego, JDBC, etc. Spark currently uses the HashLoginService
* to authenticate using DIGEST-MD5 via a single user and the shared secret.
* Since we are using DIGEST-MD5, the shared secret is not passed on the wire
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index acbe16001f..dc182f5963 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -46,7 +46,8 @@ private[spark] class BlockStoreShuffleReader[K, C](
blockManager,
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
- SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
+ SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
+ SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))
// Wrap the streams for compression based on configuration
val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
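
For context, both limits that the reader now forwards to ShuffleBlockFetcherIterator come from SparkConf. The following is a minimal sketch (the ReducerFetchLimits object is hypothetical, not part of the patch) of how the two values would be read; the key names and defaults match the code above.

```scala
import org.apache.spark.SparkConf

// Hypothetical helper mirroring the two values BlockStoreShuffleReader now passes
// to ShuffleBlockFetcherIterator.
object ReducerFetchLimits {
  def fromConf(conf: SparkConf): (Long, Int) = {
    // getSizeAsMb keeps backwards compatibility: a suffix-less value is read as megabytes.
    val maxBytesInFlight =
      conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024
    // New in this patch: cap on concurrent fetch requests, unlimited by default.
    val maxReqsInFlight = conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue)
    (maxBytesInFlight, maxReqsInFlight)
  }
}
```
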
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index c368a39e62..478a928acd 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -47,6 +47,7 @@ import org.apache.spark.util.Utils
* For each block we also require the size (in bytes as a long field) in
* order to throttle the memory usage.
* @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
+ * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point.
*/
private[spark]
final class ShuffleBlockFetcherIterator(
@@ -54,7 +55,8 @@ final class ShuffleBlockFetcherIterator(
shuffleClient: ShuffleClient,
blockManager: BlockManager,
blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
- maxBytesInFlight: Long)
+ maxBytesInFlight: Long,
+ maxReqsInFlight: Int)
extends Iterator[(BlockId, InputStream)] with Logging {
import ShuffleBlockFetcherIterator._
@@ -102,6 +104,9 @@ final class ShuffleBlockFetcherIterator(
/** Current bytes in flight from our requests */
private[this] var bytesInFlight = 0L
+ /** Current number of requests in flight */
+ private[this] var reqsInFlight = 0
+
private[this] val shuffleMetrics = context.taskMetrics().registerTempShuffleReadMetrics()
/**
@@ -118,7 +123,7 @@ final class ShuffleBlockFetcherIterator(
private[storage] def releaseCurrentResultBuffer(): Unit = {
// Release the current buffer if necessary
currentResult match {
- case SuccessFetchResult(_, _, _, buf) => buf.release()
+ case SuccessFetchResult(_, _, _, buf, _) => buf.release()
case _ =>
}
currentResult = null
@@ -137,7 +142,7 @@ final class ShuffleBlockFetcherIterator(
while (iter.hasNext) {
val result = iter.next()
result match {
- case SuccessFetchResult(_, address, _, buf) => {
+ case SuccessFetchResult(_, address, _, buf, _) => {
if (address != blockManager.blockManagerId) {
shuffleMetrics.incRemoteBytesRead(buf.size)
shuffleMetrics.incRemoteBlocksFetched(1)
@@ -153,9 +158,11 @@ final class ShuffleBlockFetcherIterator(
logDebug("Sending request for %d blocks (%s) from %s".format(
req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
bytesInFlight += req.size
+ reqsInFlight += 1
// so we can look up the size of each blockID
val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
+ val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
val blockIds = req.blocks.map(_._1.toString)
val address = req.address
@@ -169,7 +176,10 @@ final class ShuffleBlockFetcherIterator(
// Increment the ref count because we need to pass this to a different thread.
// This needs to be released after use.
buf.retain()
- results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf))
+ remainingBlocks -= blockId
+ results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
+ remainingBlocks.isEmpty))
+ logDebug("remainingBlocks: " + remainingBlocks)
}
}
logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
@@ -249,7 +259,7 @@ final class ShuffleBlockFetcherIterator(
shuffleMetrics.incLocalBlocksFetched(1)
shuffleMetrics.incLocalBytesRead(buf.size)
buf.retain()
- results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf))
+ results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, false))
} catch {
case e: Exception =>
// If we see an exception, stop immediately.
@@ -268,6 +278,9 @@ final class ShuffleBlockFetcherIterator(
val remoteRequests = splitLocalRemoteBlocks()
// Add the remote requests into our queue in a random order
fetchRequests ++= Utils.randomize(remoteRequests)
+ assert ((0 == reqsInFlight) == (0 == bytesInFlight),
+ "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +
+ ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)
// Send out initial requests for blocks, up to our maxBytesInFlight
fetchUpToMaxBytes()
@@ -299,12 +312,16 @@ final class ShuffleBlockFetcherIterator(
shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
result match {
- case SuccessFetchResult(_, address, size, buf) => {
+ case SuccessFetchResult(_, address, size, buf, isNetworkReqDone) => {
if (address != blockManager.blockManagerId) {
shuffleMetrics.incRemoteBytesRead(buf.size)
shuffleMetrics.incRemoteBlocksFetched(1)
}
bytesInFlight -= size
+ if (isNetworkReqDone) {
+ reqsInFlight -= 1
+ logDebug("Number of requests in flight " + reqsInFlight)
+ }
}
case _ =>
}
@@ -315,7 +332,7 @@ final class ShuffleBlockFetcherIterator(
case FailureFetchResult(blockId, address, e) =>
throwFetchFailedException(blockId, address, e)
- case SuccessFetchResult(blockId, address, _, buf) =>
+ case SuccessFetchResult(blockId, address, _, buf, _) =>
try {
(result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this))
} catch {
@@ -328,7 +345,9 @@ final class ShuffleBlockFetcherIterator(
private def fetchUpToMaxBytes(): Unit = {
// Send fetch requests up to maxBytesInFlight
while (fetchRequests.nonEmpty &&
- (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+ (bytesInFlight == 0 ||
+ (reqsInFlight + 1 <= maxReqsInFlight &&
+ bytesInFlight + fetchRequests.front.size <= maxBytesInFlight))) {
sendRequest(fetchRequests.dequeue())
}
}
@@ -406,13 +425,14 @@ object ShuffleBlockFetcherIterator {
* @param size estimated size of the block, used to calculate bytesInFlight.
* Note that this is NOT the exact bytes.
* @param buf [[ManagedBuffer]] for the content.
+ * @param isNetworkReqDone Is this the last network request for this host in this fetch request.
*/
private[storage] case class SuccessFetchResult(
blockId: BlockId,
address: BlockManagerId,
size: Long,
- buf: ManagedBuffer)
- extends FetchResult {
+ buf: ManagedBuffer,
+ isNetworkReqDone: Boolean) extends FetchResult {
require(buf != null)
require(size >= 0)
}
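
To see the dual throttle in isolation, here is a simplified, self-contained model of the bookkeeping the patch adds. It is a sketch, not the iterator itself: the class and method names are invented, and the real code only decrements reqsInFlight when the result flagged with isNetworkReqDone is consumed.

```scala
import scala.collection.mutable

// Simplified model of the dual throttle: a request is released only while either
// nothing is in flight, or both the request count and the byte total stay under
// their caps (mirroring the condition in fetchUpToMaxBytes).
class FetchThrottleModel(maxBytesInFlight: Long, maxReqsInFlight: Int) {
  private var bytesInFlight = 0L
  private var reqsInFlight = 0
  private val pending = mutable.Queue[Long]()  // pending request sizes, in bytes

  def enqueue(size: Long): Unit = pending.enqueue(size)

  /** Release queued requests until one of the limits would be exceeded. */
  def sendUpToLimits(): Seq[Long] = {
    val sent = mutable.ArrayBuffer[Long]()
    while (pending.nonEmpty &&
        (bytesInFlight == 0 ||
          (reqsInFlight + 1 <= maxReqsInFlight &&
            bytesInFlight + pending.front <= maxBytesInFlight))) {
      val size = pending.dequeue()
      bytesInFlight += size
      reqsInFlight += 1
      sent += size
    }
    sent.toSeq
  }

  /** Called when the last block of a request has been consumed. */
  def onRequestCompleted(size: Long): Unit = {
    bytesInFlight -= size
    reqsInFlight -= 1
  }
}
```

The first request is always sent (bytesInFlight == 0); after that, both reqsInFlight and bytesInFlight must stay within their caps before the next queued request goes out.
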
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index c9c2fb2691..e3ec99685f 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -99,7 +99,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
transfer,
blockManager,
blocksByAddress,
- 48 * 1024 * 1024)
+ 48 * 1024 * 1024,
+ Int.MaxValue)
// 3 local blocks fetched in initialization
verify(blockManager, times(3)).getBlockData(any())
@@ -171,7 +172,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
transfer,
blockManager,
blocksByAddress,
- 48 * 1024 * 1024)
+ 48 * 1024 * 1024,
+ Int.MaxValue)
verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release()
iterator.next()._2.close() // close() first block's input stream
@@ -233,7 +235,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
transfer,
blockManager,
blocksByAddress,
- 48 * 1024 * 1024)
+ 48 * 1024 * 1024,
+ Int.MaxValue)
// Continue only after the mock calls onBlockFetchFailure
sem.acquire()
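
The existing tests pass Int.MaxValue, so their behaviour is unchanged. A hypothetical unit test for the simplified FetchThrottleModel sketched above (not for the Spark suite itself) could exercise a small cap like this:

```scala
import org.scalatest.funsuite.AnyFunSuite

// Assumes the FetchThrottleModel sketch above is on the classpath.
class FetchThrottleModelSuite extends AnyFunSuite {
  test("request-count cap limits concurrent fetches") {
    val model = new FetchThrottleModel(maxBytesInFlight = 48L * 1024 * 1024, maxReqsInFlight = 2)
    (1 to 5).foreach(_ => model.enqueue(1024L))
    // Only two of the five queued requests may be in flight at once.
    assert(model.sendUpToLimits().size == 2)
    // Completing one request frees a slot for exactly one more.
    model.onRequestCompleted(1024L)
    assert(model.sendUpToLimits().size == 1)
  }
}
```
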
diff --git a/docs/configuration.md b/docs/configuration.md
index dd2cde8194..0dbfe3b079 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -392,6 +392,16 @@ Apart from these, the following properties are also available, and may be useful
</td>
</tr>
<tr>
+ <td><code>spark.reducer.maxReqsInFlight</code></td>
+ <td>Int.MaxValue</td>
+ <td>
+ This configuration limits the number of remote requests to fetch blocks at any given point.
+ As the number of hosts in the cluster increases, shuffle fetches can create a very large
+ number of inbound connections to one or more nodes, causing the workers to fail under load.
+ Limiting the number of fetch requests mitigates this scenario.
+ </td>
+</tr>
+<tr>
<td><code>spark.shuffle.compress</code></td>
<td>true</td>
<td>
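
As a usage illustration (the values are arbitrary, not recommendations), a driver could set the new key alongside the existing byte limit:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Hypothetical driver; spark.reducer.maxReqsInFlight defaults to Int.MaxValue (unlimited).
object MaxReqsInFlightExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("shuffle-fetch-throttling-example")
      .set("spark.reducer.maxSizeInFlight", "48m") // existing cap on bytes in flight
      .set("spark.reducer.maxReqsInFlight", "64")  // new cap on concurrent fetch requests
    val sc = new SparkContext(conf)
    try {
      // A wide shuffle; each reducer now issues at most 64 concurrent fetch requests.
      sc.parallelize(1 to 1000000, 100).map(i => (i % 1000, 1)).reduceByKey(_ + _).count()
    } finally {
      sc.stop()
    }
  }
}
```

The same setting can also be supplied at submit time, e.g. with --conf spark.reducer.maxReqsInFlight=64 on spark-submit.
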