 core/src/main/scala/org/apache/spark/MapOutputTracker.scala                          | 62
 core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala     | 85
 core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala            | 19
 core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala       | 72
 core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala                     | 28
 core/src/test/scala/org/apache/spark/ShuffleSuite.scala                              | 12
 core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala               | 32
 core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala       | 14
 core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala  | 18
 core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala                       | 22
 10 files changed, 172 insertions(+), 192 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 862ffe868f..92218832d2 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -21,14 +21,14 @@ import java.io._
import java.util.concurrent.ConcurrentHashMap
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
-import scala.collection.mutable.{HashMap, HashSet, Map}
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import scala.collection.JavaConversions._
import scala.reflect.ClassTag
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, RpcEndpoint}
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.MetadataFetchFailedException
-import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
import org.apache.spark.util._
private[spark] sealed trait MapOutputTrackerMessage
@@ -124,10 +124,18 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
}
/**
- * Called from executors to get the server URIs and output sizes of the map outputs of
- * a given shuffle.
+ * Called from executors to get the server URIs and output sizes for each shuffle block that
+ * needs to be read by a given reduce task.
+ *
+ * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
+ * and the second item is a sequence of (shuffle block id, shuffle block size) tuples
+ * describing the shuffle blocks that are stored at that block manager.
*/
- def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
+ def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
+ : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
+ logDebug(s"Fetching outputs for shuffle $shuffleId, reduce $reduceId")
+ val startTime = System.currentTimeMillis
+
val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) {
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
@@ -167,6 +175,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
}
}
}
+ logDebug(s"Fetching map output location for shuffle $shuffleId, reduce $reduceId took " +
+ s"${System.currentTimeMillis - startTime} ms")
+
if (fetchedStatuses != null) {
fetchedStatuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
@@ -421,23 +432,38 @@ private[spark] object MapOutputTracker extends Logging {
}
}
- // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
- // any of the statuses is null (indicating a missing location due to a failed mapper),
- // throw a FetchFailedException.
+ /**
+ * Converts an array of MapStatuses for a given reduce ID to a sequence that, for each block
+ * manager ID, lists the shuffle block ids and corresponding shuffle block sizes stored at that
+ * block manager.
+ *
+ * If any of the statuses is null (indicating a missing location due to a failed mapper),
+ * throws a FetchFailedException.
+ *
+ * @param shuffleId Identifier for the shuffle
+ * @param reduceId Identifier for the reduce task
+ * @param statuses List of map statuses, indexed by map ID.
+ * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
+ * and the second item is a sequence of (shuffle block id, shuffle block size) tuples
+ * describing the shuffle blocks that are stored at that block manager.
+ */
private def convertMapStatuses(
shuffleId: Int,
reduceId: Int,
- statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
+ statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
assert (statuses != null)
- statuses.map {
- status =>
- if (status == null) {
- logError("Missing an output location for shuffle " + shuffleId)
- throw new MetadataFetchFailedException(
- shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId)
- } else {
- (status.location, status.getSizeForBlock(reduceId))
- }
+ val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]]
+ for ((status, mapId) <- statuses.zipWithIndex) {
+ if (status == null) {
+ val errorMessage = s"Missing an output location for shuffle $shuffleId"
+ logError(errorMessage)
+ throw new MetadataFetchFailedException(shuffleId, reduceId, errorMessage)
+ } else {
+ splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=
+ ((ShuffleBlockId(shuffleId, mapId, reduceId), status.getSizeForBlock(reduceId)))
+ }
}
+
+ splitsByAddress.toSeq
}
}
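
[Sketch, not part of the patch] A minimal example of consuming the new return shape; tracker, shuffleId, and reduceId are assumed to be in scope, and the printout is purely illustrative:

    import org.apache.spark.storage.{BlockId, BlockManagerId}

    val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] =
      tracker.getMapSizesByExecutorId(shuffleId, reduceId)
    // Each entry groups the (shuffle block id, size) pairs stored at one
    // block manager, which is exactly the shape ShuffleBlockFetcherIterator
    // takes as its blocksByAddress argument.
    blocksByAddress.foreach { case (bmId, blocks) =>
      val totalBytes = blocks.map(_._2).sum
      println(s"$bmId hosts ${blocks.size} shuffle blocks ($totalBytes bytes)")
    }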
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
deleted file mode 100644
index 9d8e7e9f03..0000000000
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
+++ /dev/null
@@ -1,85 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.hash
-
-import java.io.InputStream
-
-import scala.collection.mutable.{ArrayBuffer, HashMap}
-import scala.util.{Failure, Success}
-
-import org.apache.spark._
-import org.apache.spark.shuffle.FetchFailedException
-import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator,
- ShuffleBlockId}
-
-private[hash] object BlockStoreShuffleFetcher extends Logging {
- def fetchBlockStreams(
- shuffleId: Int,
- reduceId: Int,
- context: TaskContext,
- blockManager: BlockManager,
- mapOutputTracker: MapOutputTracker)
- : Iterator[(BlockId, InputStream)] =
- {
- logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
-
- val startTime = System.currentTimeMillis
- val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId)
- logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
- shuffleId, reduceId, System.currentTimeMillis - startTime))
-
- val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
- for (((address, size), index) <- statuses.zipWithIndex) {
- splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
- }
-
- val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
- case (address, splits) =>
- (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
- }
-
- val blockFetcherItr = new ShuffleBlockFetcherIterator(
- context,
- blockManager.shuffleClient,
- blockManager,
- blocksByAddress,
- // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
- SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
-
- // Make sure that fetch failures are wrapped inside a FetchFailedException for the scheduler
- blockFetcherItr.map { blockPair =>
- val blockId = blockPair._1
- val blockOption = blockPair._2
- blockOption match {
- case Success(inputStream) => {
- (blockId, inputStream)
- }
- case Failure(e) => {
- blockId match {
- case ShuffleBlockId(shufId, mapId, _) =>
- val address = statuses(mapId.toInt)._1
- throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
- case _ =>
- throw new SparkException(
- "Failed to get block " + blockId + ", which is not a shuffle block", e)
- }
- }
- }
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index d5c9880659..de79fa56f0 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -17,10 +17,10 @@
package org.apache.spark.shuffle.hash
-import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext}
+import org.apache.spark.{InterruptibleIterator, Logging, MapOutputTracker, SparkEnv, TaskContext}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
-import org.apache.spark.storage.BlockManager
+import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter
@@ -31,8 +31,8 @@ private[spark] class HashShuffleReader[K, C](
context: TaskContext,
blockManager: BlockManager = SparkEnv.get.blockManager,
mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
- extends ShuffleReader[K, C]
-{
+ extends ShuffleReader[K, C] with Logging {
+
require(endPartition == startPartition + 1,
"Hash shuffle currently only supports fetching one partition")
@@ -40,11 +40,16 @@ private[spark] class HashShuffleReader[K, C](
/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
- val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams(
- handle.shuffleId, startPartition, context, blockManager, mapOutputTracker)
+ val blockFetcherItr = new ShuffleBlockFetcherIterator(
+ context,
+ blockManager.shuffleClient,
+ blockManager,
+ mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition),
+ // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
+ SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
// Wrap the streams for compression based on configuration
- val wrappedStreams = blockStreams.map { case (blockId, inputStream) =>
+ val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
blockManager.wrapForCompression(blockId, inputStream)
}
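
[Sketch, not part of the patch] How the remainder of read() can consume these wrapped streams; serializerInstance is an assumed name for the instance obtained from dep.serializer, and asKeyValueIterator is assumed available on DeserializationStream:

    // No per-block Try handling is needed any more: a failed fetch now
    // surfaces as a FetchFailedException thrown from the iterator's next()
    // rather than as a Failure value to be pattern-matched.
    val recordIter = wrappedStreams.flatMap { wrappedStream =>
      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
    }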
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index e49e39679e..a759ceb96e 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -21,18 +21,19 @@ import java.io.InputStream
import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
-import scala.util.{Failure, Try}
+import scala.util.control.NonFatal
-import org.apache.spark.{Logging, TaskContext}
+import org.apache.spark.{Logging, SparkException, TaskContext}
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
+import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util.Utils
/**
* An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
* manager. For remote blocks, it fetches them using the provided BlockTransferService.
*
- * This creates an iterator of (BlockID, Try[InputStream]) tuples so the caller can handle blocks
+ * This creates an iterator of (BlockId, InputStream) tuples so the caller can handle blocks
* in a pipelined fashion as they are received.
*
* The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid
@@ -53,7 +54,7 @@ final class ShuffleBlockFetcherIterator(
blockManager: BlockManager,
blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
maxBytesInFlight: Long)
- extends Iterator[(BlockId, Try[InputStream])] with Logging {
+ extends Iterator[(BlockId, InputStream)] with Logging {
import ShuffleBlockFetcherIterator._
@@ -115,7 +116,7 @@ final class ShuffleBlockFetcherIterator(
private[storage] def releaseCurrentResultBuffer(): Unit = {
// Release the current buffer if necessary
currentResult match {
- case SuccessFetchResult(_, _, buf) => buf.release()
+ case SuccessFetchResult(_, _, _, buf) => buf.release()
case _ =>
}
currentResult = null
@@ -132,7 +133,7 @@ final class ShuffleBlockFetcherIterator(
while (iter.hasNext) {
val result = iter.next()
result match {
- case SuccessFetchResult(_, _, buf) => buf.release()
+ case SuccessFetchResult(_, _, _, buf) => buf.release()
case _ =>
}
}
@@ -157,7 +158,7 @@ final class ShuffleBlockFetcherIterator(
// Increment the ref count because we need to pass this to a different thread.
// This needs to be released after use.
buf.retain()
- results.put(new SuccessFetchResult(BlockId(blockId), sizeMap(blockId), buf))
+ results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf))
shuffleMetrics.incRemoteBytesRead(buf.size)
shuffleMetrics.incRemoteBlocksFetched(1)
}
@@ -166,7 +167,7 @@ final class ShuffleBlockFetcherIterator(
override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
- results.put(new FailureFetchResult(BlockId(blockId), e))
+ results.put(new FailureFetchResult(BlockId(blockId), address, e))
}
}
)
@@ -238,12 +239,12 @@ final class ShuffleBlockFetcherIterator(
shuffleMetrics.incLocalBlocksFetched(1)
shuffleMetrics.incLocalBytesRead(buf.size)
buf.retain()
- results.put(new SuccessFetchResult(blockId, 0, buf))
+ results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf))
} catch {
case e: Exception =>
// If we see an exception, stop immediately.
logError(s"Error occurred while fetching local blocks", e)
- results.put(new FailureFetchResult(blockId, e))
+ results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
return
}
}
@@ -275,12 +276,14 @@ final class ShuffleBlockFetcherIterator(
override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch
/**
- * Fetches the next (BlockId, Try[InputStream]). If a task fails, the ManagedBuffers
+ * Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers
* underlying each InputStream will be freed by the cleanup() method registered with the
* TaskCompletionListener. However, callers should close() these InputStreams
* as soon as they are no longer needed, in order to release memory as early as possible.
+ *
+ * Throws a FetchFailedException if the next block could not be fetched.
*/
- override def next(): (BlockId, Try[InputStream]) = {
+ override def next(): (BlockId, InputStream) = {
numBlocksProcessed += 1
val startFetchWait = System.currentTimeMillis()
currentResult = results.take()
@@ -289,7 +292,7 @@ final class ShuffleBlockFetcherIterator(
shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
result match {
- case SuccessFetchResult(_, size, _) => bytesInFlight -= size
+ case SuccessFetchResult(_, _, size, _) => bytesInFlight -= size
case _ =>
}
// Send fetch requests up to maxBytesInFlight
@@ -298,19 +301,28 @@ final class ShuffleBlockFetcherIterator(
sendRequest(fetchRequests.dequeue())
}
- val iteratorTry: Try[InputStream] = result match {
- case FailureFetchResult(_, e) =>
- Failure(e)
- case SuccessFetchResult(blockId, _, buf) =>
- // There is a chance that createInputStream can fail (e.g. fetching a local file that does
- // not exist, SPARK-4085). In that case, we should propagate the right exception so
- // the scheduler gets a FetchFailedException.
- Try(buf.createInputStream()).map { inputStream =>
- new BufferReleasingInputStream(inputStream, this)
+ result match {
+ case FailureFetchResult(blockId, address, e) =>
+ throwFetchFailedException(blockId, address, e)
+
+ case SuccessFetchResult(blockId, address, _, buf) =>
+ try {
+ (result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this))
+ } catch {
+ case NonFatal(t) =>
+ throwFetchFailedException(blockId, address, t)
}
}
+ }
- (result.blockId, iteratorTry)
+ private def throwFetchFailedException(blockId: BlockId, address: BlockManagerId, e: Throwable) = {
+ blockId match {
+ case ShuffleBlockId(shufId, mapId, reduceId) =>
+ throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
+ case _ =>
+ throw new SparkException(
+ "Failed to get block " + blockId + ", which is not a shuffle block", e)
+ }
}
}
@@ -366,16 +378,22 @@ object ShuffleBlockFetcherIterator {
*/
private[storage] sealed trait FetchResult {
val blockId: BlockId
+ val address: BlockManagerId
}
/**
* Result of a successful fetch from a remote block.
* @param blockId block id
+ * @param address BlockManager that the block was fetched from.
* @param size estimated size of the block, used to calculate bytesInFlight.
* Note that this is NOT the exact bytes.
* @param buf [[ManagedBuffer]] for the content.
*/
- private[storage] case class SuccessFetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer)
+ private[storage] case class SuccessFetchResult(
+ blockId: BlockId,
+ address: BlockManagerId,
+ size: Long,
+ buf: ManagedBuffer)
extends FetchResult {
require(buf != null)
require(size >= 0)
@@ -384,8 +402,12 @@ object ShuffleBlockFetcherIterator {
/**
* Result of a failed fetch from a remote block.
* @param blockId block id
+ * @param address BlockManager that the block was attempted to be fetched from.
* @param e the failure exception
*/
- private[storage] case class FailureFetchResult(blockId: BlockId, e: Throwable)
+ private[storage] case class FailureFetchResult(
+ blockId: BlockId,
+ address: BlockManagerId,
+ e: Throwable)
extends FetchResult
}
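
[Sketch, not part of the patch] An illustrative construction of the enriched failure result; the block ids, port, and exception are made-up values:

    import java.io.IOException

    // The new address field records which block manager the fetch targeted,
    // letting next() raise a FetchFailedException that the scheduler can
    // attribute to the correct executor.
    val failure = FailureFetchResult(
      ShuffleBlockId(0, 3, 1), // shuffleId = 0, mapId = 3, reduceId = 1
      BlockManagerId("exec-2", "hostB", 7337),
      new IOException("connection reset"))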
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 7a1961137c..af4e68950f 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -17,13 +17,15 @@
package org.apache.spark
+import scala.collection.mutable.ArrayBuffer
+
import org.mockito.Mockito._
import org.mockito.Matchers.{any, isA}
import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcCallContext, RpcEnv}
import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus}
import org.apache.spark.shuffle.FetchFailedException
-import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId}
class MapOutputTrackerSuite extends SparkFunSuite {
private val conf = new SparkConf
@@ -55,9 +57,11 @@ class MapOutputTrackerSuite extends SparkFunSuite {
Array(1000L, 10000L)))
tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
Array(10000L, 1000L)))
- val statuses = tracker.getServerStatuses(10, 0)
- assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000),
- (BlockManagerId("b", "hostB", 1000), size10000)))
+ val statuses = tracker.getMapSizesByExecutorId(10, 0)
+ assert(statuses.toSet ===
+ Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))),
+ (BlockManagerId("b", "hostB", 1000), ArrayBuffer((ShuffleBlockId(10, 1, 0), size10000))))
+ .toSet)
tracker.stop()
rpcEnv.shutdown()
}
@@ -75,10 +79,10 @@ class MapOutputTrackerSuite extends SparkFunSuite {
tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
Array(compressedSize10000, compressedSize1000)))
assert(tracker.containsShuffle(10))
- assert(tracker.getServerStatuses(10, 0).nonEmpty)
+ assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty)
tracker.unregisterShuffle(10)
assert(!tracker.containsShuffle(10))
- assert(tracker.getServerStatuses(10, 0).isEmpty)
+ assert(tracker.getMapSizesByExecutorId(10, 0).isEmpty)
tracker.stop()
rpcEnv.shutdown()
@@ -104,7 +108,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
// The remaining reduce task might try to grab the output despite the shuffle failure;
// this should cause it to fail, and the scheduler will ignore the failure due to the
// stage already being aborted.
- intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) }
+ intercept[FetchFailedException] { tracker.getMapSizesByExecutorId(10, 1) }
tracker.stop()
rpcEnv.shutdown()
@@ -126,23 +130,23 @@ class MapOutputTrackerSuite extends SparkFunSuite {
masterTracker.registerShuffle(10, 1)
masterTracker.incrementEpoch()
slaveTracker.updateEpoch(masterTracker.getEpoch)
- intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+ intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) }
val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
masterTracker.registerMapOutput(10, 0, MapStatus(
BlockManagerId("a", "hostA", 1000), Array(1000L)))
masterTracker.incrementEpoch()
slaveTracker.updateEpoch(masterTracker.getEpoch)
- assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
- Seq((BlockManagerId("a", "hostA", 1000), size1000)))
+ assert(slaveTracker.getMapSizesByExecutorId(10, 0) ===
+ Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000)))))
masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
masterTracker.incrementEpoch()
slaveTracker.updateEpoch(masterTracker.getEpoch)
- intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+ intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) }
// failure should be cached
- intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+ intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) }
masterTracker.stop()
slaveTracker.stop()
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index c3c2b1ffc1..b68102bfb9 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -66,8 +66,8 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
// All blocks must have non-zero size
(0 until NUM_BLOCKS).foreach { id =>
- val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
- assert(statuses.forall(s => s._2 > 0))
+ val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id)
+ assert(statuses.forall(_._2.forall(blockIdSizePair => blockIdSizePair._2 > 0)))
}
}
@@ -105,8 +105,8 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
assert(c.count === 4)
val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
- val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
- statuses.map(x => x._2)
+ val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id)
+ statuses.flatMap(_._2.map(_._2))
}
val nonEmptyBlocks = blockSizes.filter(x => x > 0)
@@ -130,8 +130,8 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
assert(c.count === 4)
val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
- val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
- statuses.map(x => x._2)
+ val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id)
+ statuses.flatMap(_._2.map(_._2))
}
val nonEmptyBlocks = blockSizes.filter(x => x > 0)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 86728cb2b6..3462a82c9c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -483,8 +483,8 @@ class DAGSchedulerSuite
complete(taskSets(0), Seq(
(Success, makeMapStatus("hostA", 1)),
(Success, makeMapStatus("hostB", 1))))
- assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
- Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
+ assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet ===
+ HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
complete(taskSets(1), Seq((Success, 42)))
assert(results === Map(0 -> 42))
assertDataStructuresEmpty()
@@ -510,8 +510,8 @@ class DAGSchedulerSuite
// have the 2nd attempt pass
complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.size))))
// we can see both result blocks now
- assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) ===
- Array("hostA", "hostB"))
+ assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet ===
+ HashSet("hostA", "hostB"))
complete(taskSets(3), Seq((Success, 43)))
assert(results === Map(0 -> 42, 1 -> 43))
assertDataStructuresEmpty()
@@ -527,8 +527,8 @@ class DAGSchedulerSuite
(Success, makeMapStatus("hostA", reduceRdd.partitions.size)),
(Success, makeMapStatus("hostB", reduceRdd.partitions.size))))
// The MapOutputTracker should know about both map output locations.
- assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) ===
- Array("hostA", "hostB"))
+ assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet ===
+ HashSet("hostA", "hostB"))
// The first result task fails, with a fetch failure for the output from the first mapper.
runEvent(CompletionEvent(
@@ -577,10 +577,10 @@ class DAGSchedulerSuite
(Success, makeMapStatus("hostA", 2)),
(Success, makeMapStatus("hostB", 2))))
// The MapOutputTracker should know about both map output locations.
- assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) ===
- Array("hostA", "hostB"))
- assert(mapOutputTracker.getServerStatuses(shuffleId, 1).map(_._1.host) ===
- Array("hostA", "hostB"))
+ assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet ===
+ HashSet("hostA", "hostB"))
+ assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 1).map(_._1.host).toSet ===
+ HashSet("hostA", "hostB"))
// The first result task fails, with a fetch failure for the output from the first mapper.
runEvent(CompletionEvent(
@@ -713,8 +713,8 @@ class DAGSchedulerSuite
taskSet.tasks(1).epoch = newEpoch
runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA",
reduceRdd.partitions.size), null, createFakeTaskInfo(), null))
- assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
- Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
+ assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet ===
+ HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
complete(taskSets(1), Seq((Success, 42), (Success, 43)))
assert(results === Map(0 -> 42, 1 -> 43))
assertDataStructuresEmpty()
@@ -809,8 +809,8 @@ class DAGSchedulerSuite
(Success, makeMapStatus("hostB", 1))))
// have hostC complete the resubmitted task
complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1))))
- assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
- Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
+ assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet ===
+ HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
complete(taskSets(2), Seq((Success, 42)))
assert(results === Map(0 -> 42))
assertDataStructuresEmpty()
@@ -981,8 +981,8 @@ class DAGSchedulerSuite
submit(reduceRdd, Array(0))
complete(taskSets(0), Seq(
(Success, makeMapStatus("hostA", 1))))
- assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
- Array(makeBlockManagerId("hostA")))
+ assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet ===
+ HashSet(makeBlockManagerId("hostA")))
// Reducer should run on the same host that map task ran
val reduceTaskSet = taskSets(1)
diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
index 28ca68698e..6c9cb448e7 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
@@ -115,11 +115,15 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext {
// Make a mocked MapOutputTracker for the shuffle reader to use to determine what
// shuffle data to read.
val mapOutputTracker = mock(classOf[MapOutputTracker])
- // Test a scenario where all data is local, just to avoid creating a bunch of additional mocks
- // for the code to read data over the network.
- val statuses: Array[(BlockManagerId, Long)] =
- Array.fill(numMaps)((localBlockManagerId, byteOutputStream.size().toLong))
- when(mapOutputTracker.getServerStatuses(shuffleId, reduceId)).thenReturn(statuses)
+ when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId)).thenReturn {
+ // Test a scenario where all data is local, to avoid creating a bunch of additional mocks
+ // for the code to read data over the network.
+ val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId =>
+ val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
+ (shuffleBlockId, byteOutputStream.size().toLong)
+ }
+ Seq((localBlockManagerId, shuffleBlockIdsAndSizes))
+ }
// Create a mocked shuffle handle to pass into HashShuffleReader.
val shuffleHandle = {
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 9ced4148d7..64f3fbdceb 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.{SparkFunSuite, TaskContextImpl}
import org.apache.spark.network._
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.shuffle.BlockFetchingListener
+import org.apache.spark.shuffle.FetchFailedException
class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester {
@@ -106,13 +107,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
for (i <- 0 until 5) {
assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements")
val (blockId, inputStream) = iterator.next()
- assert(inputStream.isSuccess,
- s"iterator should have 5 elements defined but actually has $i elements")
// Make sure we release buffers when a wrapped input stream is closed.
val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId))
// Note: ShuffleBlockFetcherIterator wraps input streams in a BufferReleasingInputStream
- val wrappedInputStream = inputStream.get.asInstanceOf[BufferReleasingInputStream]
+ val wrappedInputStream = inputStream.asInstanceOf[BufferReleasingInputStream]
verify(mockBuf, times(0)).release()
val delegateAccess = PrivateMethod[InputStream]('delegate)
@@ -175,11 +174,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
48 * 1024 * 1024)
verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release()
- iterator.next()._2.get.close() // close() first block's input stream
+ iterator.next()._2.close() // close() first block's input stream
verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release()
// Get the 2nd block but do not exhaust the iterator
- val subIter = iterator.next()._2.get
+ val subIter = iterator.next()._2
// Complete the task; then the 2nd block buffer should be exhausted
verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release()
@@ -239,9 +238,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
// Continue only after the mock calls onBlockFetchFailure
sem.acquire()
- // The first block should be defined, and the last two are not defined (due to failure)
- assert(iterator.next()._2.isSuccess)
- assert(iterator.next()._2.isFailure)
- assert(iterator.next()._2.isFailure)
+ // The first block should be returned without an exception, and the last two should throw
+ // FetchFailedExceptions (due to failure)
+ iterator.next()
+ intercept[FetchFailedException] { iterator.next() }
+ intercept[FetchFailedException] { iterator.next() }
}
}
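
[Sketch, not part of the patch] A possible extension of this check: intercept returns the caught exception, and FetchFailedException.toTaskEndReason converts it into the FetchFailed reason the scheduler ultimately sees:

    import org.apache.spark.FetchFailed

    val e = intercept[FetchFailedException] { iterator.next() }
    assert(e.toTaskEndReason.isInstanceOf[FetchFailed])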
diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala
index 6c40685484..61601016e0 100644
--- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.util
+import scala.collection.mutable.ArrayBuffer
+
import java.util.concurrent.TimeoutException
import akka.actor.ActorNotFound
@@ -24,7 +26,7 @@ import akka.actor.ActorNotFound
import org.apache.spark._
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.scheduler.MapStatus
-import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId}
import org.apache.spark.SSLSampleConfigs._
@@ -107,8 +109,9 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst
slaveTracker.updateEpoch(masterTracker.getEpoch)
// this should succeed since security off
- assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
- Seq((BlockManagerId("a", "hostA", 1000), size1000)))
+ assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq ===
+ Seq((BlockManagerId("a", "hostA", 1000),
+ ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000)))))
rpcEnv.shutdown()
slaveRpcEnv.shutdown()
@@ -153,8 +156,9 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst
slaveTracker.updateEpoch(masterTracker.getEpoch)
// this should succeed since security on and passwords match
- assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
- Seq((BlockManagerId("a", "hostA", 1000), size1000)))
+ assert(slaveTracker.getMapSizesByExecutorId(10, 0) ===
+ Seq((BlockManagerId("a", "hostA", 1000),
+ ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000)))))
rpcEnv.shutdown()
slaveRpcEnv.shutdown()
@@ -232,8 +236,8 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst
slaveTracker.updateEpoch(masterTracker.getEpoch)
// this should succeed since security off
- assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
- Seq((BlockManagerId("a", "hostA", 1000), size1000)))
+ assert(slaveTracker.getMapSizesByExecutorId(10, 0) ===
+ Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000)))))
rpcEnv.shutdown()
slaveRpcEnv.shutdown()
@@ -278,8 +282,8 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst
masterTracker.incrementEpoch()
slaveTracker.updateEpoch(masterTracker.getEpoch)
- assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
- Seq((BlockManagerId("a", "hostA", 1000), size1000)))
+ assert(slaveTracker.getMapSizesByExecutorId(10, 0) ===
+ Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000)))))
rpcEnv.shutdown()
slaveRpcEnv.shutdown()