-rw-r--r--  core/src/main/scala/spark/BlockStoreShuffleFetcher.scala |  14
-rw-r--r--  core/src/main/scala/spark/storage/BlockManager.scala     | 133
-rw-r--r--  core/src/test/scala/spark/DistributedSuite.scala         |  21
-rw-r--r--  core/src/test/scala/spark/MapOutputTrackerSuite.scala    |   2
-rw-r--r--  docs/configuration.md                                    |   8
5 files changed, 114 insertions, 64 deletions
diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
index 048d1788c2..4554db2249 100644
--- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
@@ -21,14 +21,14 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
shuffleId, reduceId, System.currentTimeMillis - startTime))
- val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[Int]]
+ val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
for (((address, size), index) <- statuses.zipWithIndex) {
- splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += index
+ splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
}
- val blocksByAddress: Seq[(BlockManagerId, Seq[String])] = splitsByAddress.toSeq.map {
+ val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = splitsByAddress.toSeq.map {
case (address, splits) =>
- (address, splits.map(i => "shuffle_%d_%d_%d".format(shuffleId, i, reduceId)))
+ (address, splits.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2)))
}
for ((blockId, blockOption) <- blockManager.getMultiple(blocksByAddress)) {
@@ -43,9 +43,9 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
case None => {
val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r
blockId match {
- case regex(shufId, mapId, reduceId) =>
- val addr = statuses(mapId.toInt)._1
- throw new FetchFailedException(addr, shufId.toInt, mapId.toInt, reduceId.toInt, null)
+ case regex(shufId, mapId, _) =>
+ val address = statuses(mapId.toInt)._1
+ throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null)
case _ =>
throw new SparkException(
"Failed to get block " + blockId + ", which is not a shuffle block")
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index c4b241bf5a..c9bcd26016 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -104,7 +104,11 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
// TODO: This will be removed after cacheTracker is removed from the code base.
var cacheTracker: CacheTracker = null
- val numParallelFetches = BlockManager.getNumParallelFetchesFromSystemProperties
+ // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory
+ // for receiving shuffle outputs)
+ val maxBytesInFlight =
+ System.getProperty("spark.reducer.maxMbInFlight", "48").toLong * 1024 * 1024
+
val compress = System.getProperty("spark.blockManager.compress", "false").toBoolean
val host = System.getProperty("spark.hostname", Utils.localHostName())
@@ -345,9 +349,10 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
/**
* Get multiple blocks from local and remote block manager using their BlockManagerIds. Returns
* an Iterator of (block ID, value) pairs so that clients may handle blocks in a pipelined
- * fashion as they're received.
+ * fashion as they're received. Expects a size in bytes to be provided for each block fetched,
+ * so that we can control the maxBytesInFlight for the fetch.
*/
- def getMultiple(blocksByAddress: Seq[(BlockManagerId, Seq[String])])
+ def getMultiple(blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])])
: Iterator[(String, Option[Iterator[Any]])] = {
if (blocksByAddress == null) {
@@ -359,17 +364,35 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
val localBlockIds = new ArrayBuffer[String]()
val remoteBlockIds = new HashSet[String]()
- // A queue to hold our results. Because we want all the deserializing the happen in the
- // caller's thread, this will actually hold functions to produce the Iterator for each block.
- // For local blocks we'll have an iterator already, while for remote ones we'll deserialize.
- val results = new LinkedBlockingQueue[(String, Option[() => Iterator[Any]])]
+ // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
+ // the block (since we want all deserialization to happen in the calling thread); can also
+ // represent a fetch failure if size == -1.
+ class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
+ def failed: Boolean = size == -1
+ }
+
+ // A queue to hold our results.
+ val results = new LinkedBlockingQueue[FetchResult]
+
+ // A request to fetch one or more blocks, complete with their sizes
+ class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
+ val size = blocks.map(_._2).sum
+ }
+
+ // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
+ // the number of bytes in flight is limited to maxBytesInFlight
+ val fetchRequests = new Queue[FetchRequest]
- // Bound the number and memory usage of fetched remote blocks.
- val blocksToRequest = new Queue[(BlockManagerId, BlockMessage)]
+ // Current bytes in flight from our requests
+ var bytesInFlight = 0L
- def sendRequest(bmId: BlockManagerId, blockMessages: Seq[BlockMessage]) {
- val cmId = new ConnectionManagerId(bmId.ip, bmId.port)
- val blockMessageArray = new BlockMessageArray(blockMessages)
+ def sendRequest(req: FetchRequest) {
+ val cmId = new ConnectionManagerId(req.address.ip, req.address.port)
+ val blockMessageArray = new BlockMessageArray(req.blocks.map{
+ case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
+ })
+ bytesInFlight += req.size
+ val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
future.onSuccess {
case Some(message) => {
@@ -381,58 +404,71 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
"Unexpected message " + blockMessage.getType + " received from " + cmId)
}
val blockId = blockMessage.getId
- results.put((blockId, Some(() => dataDeserialize(blockMessage.getData))))
+ results.put(new FetchResult(
+ blockId, sizeMap(blockId), () => dataDeserialize(blockMessage.getData)))
logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
}
case None => {
logError("Could not get block(s) from " + cmId)
- for (blockMessage <- blockMessages) {
- results.put((blockMessage.getId, None))
+ for ((blockId, size) <- req.blocks) {
+ results.put(new FetchResult(blockId, -1, null))
}
}
}
}
- // Split local and remote blocks. Remote blocks are further split into ones that will
- // be requested initially and ones that will be added to a queue of blocks to request.
- val initialRequestBlocks = new HashMap[BlockManagerId, ArrayBuffer[BlockMessage]]()
- var initialRequests = 0
- val blocksToGetLater = new ArrayBuffer[(BlockManagerId, BlockMessage)]
- for ((address, blockIds) <- Utils.randomize(blocksByAddress)) {
+ // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
+ // at most maxBytesInFlight in order to limit the amount of data in flight.
+ val remoteRequests = new ArrayBuffer[FetchRequest]
+ for ((address, blockInfos) <- blocksByAddress) {
if (address == blockManagerId) {
- localBlockIds ++= blockIds
+ localBlockIds ++= blockInfos.map(_._1)
} else {
- remoteBlockIds ++= blockIds
- for (blockId <- blockIds) {
- val blockMessage = BlockMessage.fromGetBlock(GetBlock(blockId))
- if (initialRequests < numParallelFetches) {
- initialRequestBlocks.getOrElseUpdate(address, new ArrayBuffer[BlockMessage])
- .append(blockMessage)
- initialRequests += 1
- } else {
- blocksToGetLater.append((address, blockMessage))
+ remoteBlockIds ++= blockInfos.map(_._1)
+ // Make our requests at least maxBytesInFlight / 4 in length; the reason to keep them
+ // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 4
+ // nodes, rather than blocking on reading output from one node.
+ val minRequestSize = math.max(maxBytesInFlight / 4, 1L)
+ val iterator = blockInfos.iterator
+ var curRequestSize = 0L
+ var curBlocks = new ArrayBuffer[(String, Long)]
+ while (iterator.hasNext) {
+ val (blockId, size) = iterator.next()
+ curBlocks += ((blockId, size))
+ curRequestSize += size
+ if (curRequestSize >= minRequestSize) {
+ // Add this FetchRequest
+ remoteRequests += new FetchRequest(address, curBlocks)
+ curRequestSize = 0
+ curBlocks = new ArrayBuffer[(String, Long)]
}
}
+ // Add in the final request
+ if (!curBlocks.isEmpty) {
+ remoteRequests += new FetchRequest(address, curBlocks)
+ }
}
}
- // Add the remaining blocks into a queue to pull later in a random order
- blocksToRequest ++= Utils.randomize(blocksToGetLater)
+ // Add the remote requests into our queue in a random order
+ fetchRequests ++= Utils.randomize(remoteRequests)
- // Send out initial request(s) for 'numParallelFetches' blocks.
- for ((bmId, blockMessages) <- initialRequestBlocks) {
- sendRequest(bmId, blockMessages)
+ // Send out initial requests for blocks, up to our maxBytesInFlight
+ while (!fetchRequests.isEmpty &&
+ (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+ sendRequest(fetchRequests.dequeue())
}
- logDebug("Started remote gets for " + numParallelFetches + " blocks in " +
- Utils.getUsedTimeMs(startTime) + " ms")
+ logDebug("Started remote gets in " + Utils.getUsedTimeMs(startTime) + " ms")
- // Get the local blocks while remote blocks are being fetched.
+ // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
+ // these all at once because they will just memory-map some files, so they won't consume
+ // any memory that might exceed our maxBytesInFlight
startTime = System.currentTimeMillis
for (id <- localBlockIds) {
getLocal(id) match {
- case Some(block) => {
- results.put((id, Some(() => block)))
+ case Some(iter) => {
+ results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight
logDebug("Got local block " + id)
}
case None => {
@@ -450,12 +486,13 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
def next(): (String, Option[Iterator[Any]]) = {
resultsGotten += 1
- val (blockId, functionOption) = results.take()
- if (remoteBlockIds.contains(blockId) && !blocksToRequest.isEmpty) {
- val (bmId, blockMessage) = blocksToRequest.dequeue()
- sendRequest(bmId, Seq(blockMessage))
+ val result = results.take()
+ bytesInFlight -= result.size
+ if (!fetchRequests.isEmpty &&
+ (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+ sendRequest(fetchRequests.dequeue())
}
- (blockId, functionOption.map(_.apply()))
+ (result.blockId, if (result.failed) None else Some(result.deserialize()))
}
}
}
@@ -765,10 +802,6 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
private[spark]
object BlockManager extends Logging {
- def getNumParallelFetchesFromSystemProperties: Int = {
- System.getProperty("spark.blockManager.parallelFetches", "4").toInt
- }
-
def getMaxMemoryFromSystemProperties: Long = {
val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble
(Runtime.getRuntime.maxMemory * memoryFraction).toLong
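Condensed out of the diff above, the new splitting policy reads on its own: blocks from one remote node are chunked into requests of roughly maxBytesInFlight / 4 bytes each, so the reducer can keep up to about four nodes busy in parallel while staying under the cap. A standalone sketch (the simplified FetchRequest below is illustrative, not the class defined inside getMultiple):

    import scala.collection.mutable.ArrayBuffer

    case class FetchRequest(address: String, blocks: Seq[(String, Long)]) {
      val size: Long = blocks.map(_._2).sum
    }

    def splitIntoRequests(address: String, blockInfos: Seq[(String, Long)],
                          maxBytesInFlight: Long): Seq[FetchRequest] = {
      val minRequestSize = math.max(maxBytesInFlight / 4, 1L)
      val requests = new ArrayBuffer[FetchRequest]
      var curBlocks = new ArrayBuffer[(String, Long)]
      var curRequestSize = 0L
      for ((blockId, size) <- blockInfos) {
        curBlocks += ((blockId, size))
        curRequestSize += size
        if (curRequestSize >= minRequestSize) {
          // This chunk is big enough; close it out and start a new one
          requests += FetchRequest(address, curBlocks)
          curBlocks = new ArrayBuffer[(String, Long)]
          curRequestSize = 0
        }
      }
      if (curBlocks.nonEmpty) {
        requests += FetchRequest(address, curBlocks)  // final, possibly undersized chunk
      }
      requests
    }

Note also the dispatch condition used both for the initial sends and in next(): bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight. The first clause lets a single oversized request go out when nothing else is outstanding, so the iterator can never stall on the cap.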
diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala
index 48c0a830e0..97cfa9dad1 100644
--- a/core/src/test/scala/spark/DistributedSuite.scala
+++ b/core/src/test/scala/spark/DistributedSuite.scala
@@ -26,20 +26,21 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
sc.stop()
sc = null
}
+ System.clearProperty("spark.reducer.maxMbInFlight")
}
test("local-cluster format") {
sc = new SparkContext("local-cluster[2,1,512]", "test")
- assert(sc.parallelize(1 to 2, 2).count == 2)
+ assert(sc.parallelize(1 to 2, 2).count() == 2)
sc.stop()
sc = new SparkContext("local-cluster[2 , 1 , 512]", "test")
- assert(sc.parallelize(1 to 2, 2).count == 2)
+ assert(sc.parallelize(1 to 2, 2).count() == 2)
sc.stop()
sc = new SparkContext("local-cluster[2, 1, 512]", "test")
- assert(sc.parallelize(1 to 2, 2).count == 2)
+ assert(sc.parallelize(1 to 2, 2).count() == 2)
sc.stop()
sc = new SparkContext("local-cluster[ 2, 1, 512 ]", "test")
- assert(sc.parallelize(1 to 2, 2).count == 2)
+ assert(sc.parallelize(1 to 2, 2).count() == 2)
sc.stop()
sc = null
}
@@ -55,6 +56,18 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
assert(valuesFor2.toList.sorted === List(1))
}
+ test("groupByKey where map output sizes exceed maxMbInFlight") {
+ System.setProperty("spark.reducer.maxMbInFlight", "1")
+ sc = new SparkContext(clusterUrl, "test")
+ // This data should be around 20 MB, so even with 4 mappers and 2 reducers, each map output
+ // file should be about 2.5 MB
+ val pairs = sc.parallelize(1 to 2000, 4).map(x => (x % 16, new Array[Byte](10000)))
+ val groups = pairs.groupByKey(2).map(x => (x._1, x._2.size)).collect()
+ assert(groups.length === 16)
+ assert(groups.map(_._2).sum === 2000)
+ // Note that spark.reducer.maxMbInFlight will be cleared in the test suite's after{} block
+ }
+
test("accumulators") {
sc = new SparkContext(clusterUrl, "test")
val accum = sc.accumulator(0)
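The sizes in the new test are easy to sanity-check: 2000 values of roughly 10 KB each is about 20 MB in total, spread over 4 map tasks and 2 reduce partitions, so each shuffle file is around 2.5 MB and exceeds the 1 MB cap set by the test, forcing each reducer to fetch in several waves. A rough worked check (approximate, ignoring serialization overhead):

    // Back-of-the-envelope arithmetic behind the test above
    val totalBytes   = 2000L * 10000          // ~20 MB of Array[Byte] payload
    val perMapOutput = totalBytes / (4 * 2)   // 4 map tasks x 2 reduce partitions => ~2.5 MB each
    val maxInFlight  = 1L * 1024 * 1024       // spark.reducer.maxMbInFlight = 1 => 1 MB cap
    assert(perMapOutput > maxInFlight)        // so fetches must be issued in multiple waves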
diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
index d28b06c013..4e9717d871 100644
--- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
@@ -10,6 +10,8 @@ class MapOutputTrackerSuite extends FunSuite {
assert(MapOutputTracker.compressSize(10L) === 25)
assert((MapOutputTracker.compressSize(1000000L) & 0xFF) === 145)
assert((MapOutputTracker.compressSize(1000000000L) & 0xFF) === 218)
+ // This last size is bigger than we can encode in a byte, so check that we just return 255
+ assert((MapOutputTracker.compressSize(1000000000000000000L) & 0xFF) === 255)
}
test("decompressSize") {
diff --git a/docs/configuration.md b/docs/configuration.md
index fa7123af1b..0987f7f7b1 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -139,10 +139,12 @@ Apart from these, the following properties are also available, and may be useful
</td>
</tr>
<tr>
- <td>spark.blockManager.parallelFetches</td>
- <td>4</td>
+ <td>spark.reducer.maxMbInFlight</td>
+ <td>48</td>
<td>
- Number of map output files to fetch concurrently from each reduce task.
+ Maximum size (in megabytes) of map outputs to fetch simultaneously from each reduce task. Since
+ each output requires us to create a buffer to receive it, this represents a fixed memory overhead
+ per reduce task, so keep it small unless you have a large amount of memory.
</td>
</tr>
<tr>
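As with the other knobs in this table, the new limit is read from a Java system property, so it can be set programmatically before the SparkContext is created, the same mechanism the test above uses. The 24 MB value below is only an example:

    import spark.SparkContext

    // Example only: cap in-flight shuffle data at 24 MB per reducer,
    // set before the SparkContext is constructed
    System.setProperty("spark.reducer.maxMbInFlight", "24")
    val sc = new SparkContext("local-cluster[2,1,512]", "maxMbInFlight-example")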