diff options
7 files changed, 290 insertions, 84 deletions
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 3a5caa3510..6bd950205f 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -18,13 +18,15 @@ package org.apache.spark import java.io._ -import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolExecutor} import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import scala.reflect.ClassTag +import scala.util.control.NonFatal +import org.apache.spark.broadcast.{Broadcast, BroadcastManager} import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.MapStatus @@ -37,31 +39,18 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage +private[spark] case class GetMapOutputMessage(shuffleId: Int, context: RpcCallContext) + /** RpcEndpoint class for MapOutputTrackerMaster */ private[spark] class MapOutputTrackerMasterEndpoint( override val rpcEnv: RpcEnv, tracker: MapOutputTrackerMaster, conf: SparkConf) extends RpcEndpoint with Logging { - val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case GetMapOutputStatuses(shuffleId: Int) => val hostPort = context.senderAddress.hostPort logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) - val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId) - val serializedSize = mapOutputStatuses.length - if (serializedSize > maxRpcMessageSize) { - - val msg = s"Map output statuses were $serializedSize bytes which " + - s"exceeds spark.rpc.message.maxSize ($maxRpcMessageSize bytes)." - - /* For SPARK-1244 we'll opt for just logging an error and then sending it to the sender. - * A bigger refactoring (SPARK-1239) will ultimately remove this entire code path. */ - val exception = new SparkException(msg) - logError(msg, exception) - context.sendFailure(exception) - } else { - context.reply(mapOutputStatuses) - } + val mapOutputStatuses = tracker.post(new GetMapOutputMessage(shuffleId, context)) case StopMapOutputTracker => logInfo("MapOutputTrackerMasterEndpoint stopped!") @@ -270,12 +259,17 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging /** * MapOutputTracker for the driver. */ -private[spark] class MapOutputTrackerMaster(conf: SparkConf) +private[spark] class MapOutputTrackerMaster(conf: SparkConf, + broadcastManager: BroadcastManager, isLocal: Boolean) extends MapOutputTracker(conf) { /** Cache a serialized version of the output statuses for each shuffle to send them out faster */ private var cacheEpoch = epoch + // The size at which we use Broadcast to send the map output statuses to the executors + private val minSizeForBroadcast = + conf.getSizeAsBytes("spark.shuffle.mapOutput.minSizeForBroadcast", "512k").toInt + /** Whether to compute locality preferences for reduce tasks */ private val shuffleLocalityEnabled = conf.getBoolean("spark.shuffle.reduceLocality.enabled", true) @@ -296,10 +290,86 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) protected val mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala private val cachedSerializedStatuses = new ConcurrentHashMap[Int, Array[Byte]]().asScala + private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) + + // Kept in sync with cachedSerializedStatuses explicitly + // This is required so that the Broadcast variable remains in scope until we remove + // the shuffleId explicitly or implicitly. + private val cachedSerializedBroadcast = new HashMap[Int, Broadcast[Array[Byte]]]() + + // This is to prevent multiple serializations of the same shuffle - which happens when + // there is a request storm when shuffle start. + private val shuffleIdLocks = new ConcurrentHashMap[Int, AnyRef]() + + // requests for map output statuses + private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage] + + // Thread pool used for handling map output status requests. This is a separate thread pool + // to ensure we don't block the normal dispatcher threads. + private val threadpool: ThreadPoolExecutor = { + val numThreads = conf.getInt("spark.shuffle.mapOutput.dispatcher.numThreads", 8) + val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "map-output-dispatcher") + for (i <- 0 until numThreads) { + pool.execute(new MessageLoop) + } + pool + } + + // Make sure that that we aren't going to exceed the max RPC message size by making sure + // we use broadcast to send large map output statuses. + if (minSizeForBroadcast > maxRpcMessageSize) { + val msg = s"spark.shuffle.mapOutput.minSizeForBroadcast ($minSizeForBroadcast bytes) must " + + s"be <= spark.rpc.message.maxSize ($maxRpcMessageSize bytes) to prevent sending an rpc " + + "message that is to large." + logError(msg) + throw new IllegalArgumentException(msg) + } + + def post(message: GetMapOutputMessage): Unit = { + mapOutputRequests.offer(message) + } + + /** Message loop used for dispatching messages. */ + private class MessageLoop extends Runnable { + override def run(): Unit = { + try { + while (true) { + try { + val data = mapOutputRequests.take() + if (data == PoisonPill) { + // Put PoisonPill back so that other MessageLoops can see it. + mapOutputRequests.offer(PoisonPill) + return + } + val context = data.context + val shuffleId = data.shuffleId + val hostPort = context.senderAddress.hostPort + logDebug("Handling request to send map output locations for shuffle " + shuffleId + + " to " + hostPort) + val mapOutputStatuses = getSerializedMapOutputStatuses(shuffleId) + context.reply(mapOutputStatuses) + } catch { + case NonFatal(e) => logError(e.getMessage, e) + } + } + } catch { + case ie: InterruptedException => // exit + } + } + } + + /** A poison endpoint that indicates MessageLoop should exit its message loop. */ + private val PoisonPill = new GetMapOutputMessage(-99, null) + + // Exposed for testing + private[spark] def getNumCachedSerializedBroadcast = cachedSerializedBroadcast.size + def registerShuffle(shuffleId: Int, numMaps: Int) { if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } + // add in advance + shuffleIdLocks.putIfAbsent(shuffleId, new Object()) } def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { @@ -337,6 +407,8 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) override def unregisterShuffle(shuffleId: Int) { mapStatuses.remove(shuffleId) cachedSerializedStatuses.remove(shuffleId) + cachedSerializedBroadcast.remove(shuffleId).foreach(v => removeBroadcast(v)) + shuffleIdLocks.remove(shuffleId) } /** Check if the given shuffle is being tracked */ @@ -428,40 +500,89 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) } } + private def removeBroadcast(bcast: Broadcast[_]): Unit = { + if (null != bcast) { + broadcastManager.unbroadcast(bcast.id, + removeFromDriver = true, blocking = false) + } + } + + private def clearCachedBroadcast(): Unit = { + for (cached <- cachedSerializedBroadcast) removeBroadcast(cached._2) + cachedSerializedBroadcast.clear() + } + def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = { var statuses: Array[MapStatus] = null + var retBytes: Array[Byte] = null var epochGotten: Long = -1 - epochLock.synchronized { - if (epoch > cacheEpoch) { - cachedSerializedStatuses.clear() - cacheEpoch = epoch - } - cachedSerializedStatuses.get(shuffleId) match { - case Some(bytes) => - return bytes - case None => - statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]()) - epochGotten = epoch + + // Check to see if we have a cached version, returns true if it does + // and has side effect of setting retBytes. If not returns false + // with side effect of setting statuses + def checkCachedStatuses(): Boolean = { + epochLock.synchronized { + if (epoch > cacheEpoch) { + cachedSerializedStatuses.clear() + clearCachedBroadcast() + cacheEpoch = epoch + } + cachedSerializedStatuses.get(shuffleId) match { + case Some(bytes) => + retBytes = bytes + true + case None => + logDebug("cached status not found for : " + shuffleId) + statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]()) + epochGotten = epoch + false + } } } - // If we got here, we failed to find the serialized locations in the cache, so we pulled - // out a snapshot of the locations as "statuses"; let's serialize and return that - val bytes = MapOutputTracker.serializeMapStatuses(statuses) - logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length)) - // Add them into the table only if the epoch hasn't changed while we were working - epochLock.synchronized { - if (epoch == epochGotten) { - cachedSerializedStatuses(shuffleId) = bytes + + if (checkCachedStatuses()) return retBytes + var shuffleIdLock = shuffleIdLocks.get(shuffleId) + if (null == shuffleIdLock) { + val newLock = new Object() + // in general, this condition should be false - but good to be paranoid + val prevLock = shuffleIdLocks.putIfAbsent(shuffleId, newLock) + shuffleIdLock = if (null != prevLock) prevLock else newLock + } + // synchronize so we only serialize/broadcast it once since multiple threads call + // in parallel + shuffleIdLock.synchronized { + // double check to make sure someone else didn't serialize and cache the same + // mapstatus while we were waiting on the synchronize + if (checkCachedStatuses()) return retBytes + + // If we got here, we failed to find the serialized locations in the cache, so we pulled + // out a snapshot of the locations as "statuses"; let's serialize and return that + val (bytes, bcast) = MapOutputTracker.serializeMapStatuses(statuses, broadcastManager, + isLocal, minSizeForBroadcast) + logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length)) + // Add them into the table only if the epoch hasn't changed while we were working + epochLock.synchronized { + if (epoch == epochGotten) { + cachedSerializedStatuses(shuffleId) = bytes + if (null != bcast) cachedSerializedBroadcast(shuffleId) = bcast + } else { + logInfo("Epoch changed, not caching!") + removeBroadcast(bcast) + } } + bytes } - bytes } override def stop() { + mapOutputRequests.offer(PoisonPill) + threadpool.shutdown() sendTracker(StopMapOutputTracker) mapStatuses.clear() trackerEndpoint = null cachedSerializedStatuses.clear() + clearCachedBroadcast() + shuffleIdLocks.clear() } } @@ -477,12 +598,16 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr private[spark] object MapOutputTracker extends Logging { val ENDPOINT_NAME = "MapOutputTracker" + private val DIRECT = 0 + private val BROADCAST = 1 // Serialize an array of map output locations into an efficient byte format so that we can send // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will // generally be pretty compressible because many map outputs will be on the same hostname. - def serializeMapStatuses(statuses: Array[MapStatus]): Array[Byte] = { + def serializeMapStatuses(statuses: Array[MapStatus], broadcastManager: BroadcastManager, + isLocal: Boolean, minBroadcastSize: Int): (Array[Byte], Broadcast[Array[Byte]]) = { val out = new ByteArrayOutputStream + out.write(DIRECT) val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) Utils.tryWithSafeFinally { // Since statuses can be modified in parallel, sync on it @@ -492,16 +617,51 @@ private[spark] object MapOutputTracker extends Logging { } { objOut.close() } - out.toByteArray + val arr = out.toByteArray + if (arr.length >= minBroadcastSize) { + // Use broadcast instead. + // Important arr(0) is the tag == DIRECT, ignore that while deserializing ! + val bcast = broadcastManager.newBroadcast(arr, isLocal) + // toByteArray creates copy, so we can reuse out + out.reset() + out.write(BROADCAST) + val oos = new ObjectOutputStream(new GZIPOutputStream(out)) + oos.writeObject(bcast) + oos.close() + val outArr = out.toByteArray + logInfo("Broadcast mapstatuses size = " + outArr.length + ", actual size = " + arr.length) + (outArr, bcast) + } else { + (arr, null) + } } // Opposite of serializeMapStatuses. def deserializeMapStatuses(bytes: Array[Byte]): Array[MapStatus] = { - val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes))) - Utils.tryWithSafeFinally { - objIn.readObject().asInstanceOf[Array[MapStatus]] - } { - objIn.close() + assert (bytes.length > 0) + + def deserializeObject(arr: Array[Byte], off: Int, len: Int): AnyRef = { + val objIn = new ObjectInputStream(new GZIPInputStream( + new ByteArrayInputStream(arr, off, len))) + Utils.tryWithSafeFinally { + objIn.readObject() + } { + objIn.close() + } + } + + bytes(0) match { + case DIRECT => + deserializeObject(bytes, 1, bytes.length - 1).asInstanceOf[Array[MapStatus]] + case BROADCAST => + // deserialize the Broadcast, pull .value array out of it, and then deserialize that + val bcast = deserializeObject(bytes, 1, bytes.length - 1). + asInstanceOf[Broadcast[Array[Byte]]] + logInfo("Broadcast mapstatuses size = " + bytes.length + + ", actual size = " + bcast.value.length) + // Important - ignore the DIRECT tag ! Start from offset 1 + deserializeObject(bcast.value, 1, bcast.value.length - 1).asInstanceOf[Array[MapStatus]] + case _ => throw new IllegalArgumentException("Unexpected byte tag = " + bytes(0)) } } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 27497e21b8..4bf8890c05 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -284,8 +284,10 @@ object SparkEnv extends Logging { } } + val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) + val mapOutputTracker = if (isDriver) { - new MapOutputTrackerMaster(conf) + new MapOutputTrackerMaster(conf, broadcastManager, isLocal) } else { new MapOutputTrackerWorker(conf) } @@ -325,8 +327,6 @@ object SparkEnv extends Logging { serializerManager, conf, memoryManager, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores) - val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) - val metricsSystem = if (isDriver) { // Don't start metrics system right now for Driver. // We need to wait for the task scheduler to give us an app ID. diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index ddf48765ec..c6aebc19fd 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.mockito.Matchers.{any, isA} import org.mockito.Mockito._ +import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException @@ -30,6 +31,12 @@ import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} class MapOutputTrackerSuite extends SparkFunSuite { private val conf = new SparkConf + private def newTrackerMaster(sparkConf: SparkConf = conf) = { + val broadcastManager = new BroadcastManager(true, sparkConf, + new SecurityManager(sparkConf)) + new MapOutputTrackerMaster(sparkConf, broadcastManager, true) + } + def createRpcEnv(name: String, host: String = "localhost", port: Int = 0, securityManager: SecurityManager = new SecurityManager(conf)): RpcEnv = { RpcEnv.create(name, host, port, conf, securityManager) @@ -37,7 +44,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { test("master start and stop") { val rpcEnv = createRpcEnv("test") - val tracker = new MapOutputTrackerMaster(conf) + val tracker = newTrackerMaster() tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.stop() @@ -46,7 +53,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { test("master register shuffle and fetch") { val rpcEnv = createRpcEnv("test") - val tracker = new MapOutputTrackerMaster(conf) + val tracker = newTrackerMaster() tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) @@ -62,13 +69,14 @@ class MapOutputTrackerSuite extends SparkFunSuite { Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), (BlockManagerId("b", "hostB", 1000), ArrayBuffer((ShuffleBlockId(10, 1, 0), size10000)))) .toSet) + assert(0 == tracker.getNumCachedSerializedBroadcast) tracker.stop() rpcEnv.shutdown() } test("master register and unregister shuffle") { val rpcEnv = createRpcEnv("test") - val tracker = new MapOutputTrackerMaster(conf) + val tracker = newTrackerMaster() tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) @@ -80,6 +88,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { Array(compressedSize10000, compressedSize1000))) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty) + assert(0 == tracker.getNumCachedSerializedBroadcast) tracker.unregisterShuffle(10) assert(!tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0).isEmpty) @@ -90,7 +99,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { test("master register shuffle and unregister map output and fetch") { val rpcEnv = createRpcEnv("test") - val tracker = new MapOutputTrackerMaster(conf) + val tracker = newTrackerMaster() tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) @@ -101,6 +110,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(compressedSize10000, compressedSize1000, compressedSize1000))) + assert(0 == tracker.getNumCachedSerializedBroadcast) // As if we had two simultaneous fetch failures tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) @@ -118,7 +128,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val hostname = "localhost" val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf)) - val masterTracker = new MapOutputTrackerMaster(conf) + val masterTracker = newTrackerMaster() masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) @@ -139,6 +149,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { slaveTracker.updateEpoch(masterTracker.getEpoch) assert(slaveTracker.getMapSizesByExecutorId(10, 0) === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) + assert(0 == masterTracker.getNumCachedSerializedBroadcast) masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) masterTracker.incrementEpoch() @@ -147,6 +158,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // failure should be cached intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } + assert(0 == masterTracker.getNumCachedSerializedBroadcast) masterTracker.stop() slaveTracker.stop() @@ -158,8 +170,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val newConf = new SparkConf newConf.set("spark.rpc.message.maxSize", "1") newConf.set("spark.rpc.askTimeout", "1") // Fail fast + newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", "1048576") - val masterTracker = new MapOutputTrackerMaster(conf) + val masterTracker = newTrackerMaster(newConf) val rpcEnv = createRpcEnv("spark") val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) @@ -172,45 +185,27 @@ class MapOutputTrackerSuite extends SparkFunSuite { val rpcCallContext = mock(classOf[RpcCallContext]) when(rpcCallContext.senderAddress).thenReturn(senderAddress) masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(10)) - verify(rpcCallContext).reply(any()) - verify(rpcCallContext, never()).sendFailure(any()) + // Default size for broadcast in this testsuite is set to -1 so should not cause broadcast + // to be used. + verify(rpcCallContext, timeout(30000)).reply(any()) + assert(0 == masterTracker.getNumCachedSerializedBroadcast) // masterTracker.stop() // this throws an exception rpcEnv.shutdown() } - test("remote fetch exceeds max RPC message size") { + test("min broadcast size exceeds max RPC message size") { val newConf = new SparkConf newConf.set("spark.rpc.message.maxSize", "1") newConf.set("spark.rpc.askTimeout", "1") // Fail fast + newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", Int.MaxValue.toString) - val masterTracker = new MapOutputTrackerMaster(conf) - val rpcEnv = createRpcEnv("test") - val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) - rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) - - // Message size should be ~1.1MB, and MapOutputTrackerMasterEndpoint should throw exception. - // Note that the size is hand-selected here because map output statuses are compressed before - // being sent. - masterTracker.registerShuffle(20, 100) - (0 until 100).foreach { i => - masterTracker.registerMapOutput(20, i, new CompressedMapStatus( - BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) - } - val senderAddress = RpcAddress("localhost", 12345) - val rpcCallContext = mock(classOf[RpcCallContext]) - when(rpcCallContext.senderAddress).thenReturn(senderAddress) - masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20)) - verify(rpcCallContext, never()).reply(any()) - verify(rpcCallContext).sendFailure(isA(classOf[SparkException])) - -// masterTracker.stop() // this throws an exception - rpcEnv.shutdown() + intercept[IllegalArgumentException] { newTrackerMaster(newConf) } } test("getLocationsWithLargestOutputs with multiple outputs in same machine") { val rpcEnv = createRpcEnv("test") - val tracker = new MapOutputTrackerMaster(conf) + val tracker = newTrackerMaster() tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) // Setup 3 map tasks @@ -242,4 +237,44 @@ class MapOutputTrackerSuite extends SparkFunSuite { tracker.stop() rpcEnv.shutdown() } + + test("remote fetch using broadcast") { + val newConf = new SparkConf + newConf.set("spark.rpc.message.maxSize", "1") + newConf.set("spark.rpc.askTimeout", "1") // Fail fast + newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", "10240") // 10 KB << 1MB framesize + + // needs TorrentBroadcast so need a SparkContext + val sc = new SparkContext("local", "MapOutputTrackerSuite", newConf) + try { + val masterTracker = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + val rpcEnv = sc.env.rpcEnv + val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) + rpcEnv.stop(masterTracker.trackerEndpoint) + rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) + + // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception. + // Note that the size is hand-selected here because map output statuses are compressed before + // being sent. + masterTracker.registerShuffle(20, 100) + (0 until 100).foreach { i => + masterTracker.registerMapOutput(20, i, new CompressedMapStatus( + BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) + } + val senderAddress = RpcAddress("localhost", 12345) + val rpcCallContext = mock(classOf[RpcCallContext]) + when(rpcCallContext.senderAddress).thenReturn(senderAddress) + masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20)) + // should succeed since majority of data is broadcast and actual serialized + // message size is small + verify(rpcCallContext, timeout(30000)).reply(any()) + assert(1 == masterTracker.getNumCachedSerializedBroadcast) + masterTracker.unregisterShuffle(20) + assert(0 == masterTracker.getNumCachedSerializedBroadcast) + + } finally { + LocalSparkContext.stop(sc) + } + } + } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 844c780a3f..e3ed079e4e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -28,6 +28,7 @@ import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} @@ -156,6 +157,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou } var mapOutputTracker: MapOutputTrackerMaster = null + var broadcastManager: BroadcastManager = null + var securityMgr: SecurityManager = null var scheduler: DAGScheduler = null var dagEventProcessLoopTester: DAGSchedulerEventProcessLoop = null @@ -207,7 +210,9 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou cancelledStages.clear() cacheLocations.clear() results.clear() - mapOutputTracker = new MapOutputTrackerMaster(conf) + securityMgr = new SecurityManager(conf) + broadcastManager = new BroadcastManager(true, conf, securityMgr) + mapOutputTracker = new MapOutputTrackerMaster(conf, broadcastManager, true) scheduler = new DAGScheduler( sc, taskScheduler, diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index d14728cb50..31687e6147 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -27,6 +27,7 @@ import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ +import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.memory.UnifiedMemoryManager import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService @@ -43,7 +44,8 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo private var rpcEnv: RpcEnv = null private var master: BlockManagerMaster = null private val securityMgr = new SecurityManager(conf) - private val mapOutputTracker = new MapOutputTrackerMaster(conf) + private val bcastManager = new BroadcastManager(true, conf, securityMgr) + private val mapOutputTracker = new MapOutputTrackerMaster(conf, bcastManager, true) private val shuffleManager = new SortShuffleManager(conf) // List of block manager created during an unit test, so that all of the them can be stopped diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index db1efaf2a2..a2580304c4 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -33,6 +33,7 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ import org.apache.spark._ +import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.executor.DataReadMethod import org.apache.spark.memory.UnifiedMemoryManager import org.apache.spark.network.{BlockDataManager, BlockTransferService} @@ -59,7 +60,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null val securityMgr = new SecurityManager(new SparkConf(false)) - val mapOutputTracker = new MapOutputTrackerMaster(new SparkConf(false)) + val bcastManager = new BroadcastManager(true, new SparkConf(false), securityMgr) + val mapOutputTracker = new MapOutputTrackerMaster(new SparkConf(false), bcastManager, true) val shuffleManager = new SortShuffleManager(new SparkConf(false)) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 4be4882938..e97427991b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -29,6 +29,7 @@ import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ +import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.internal.Logging import org.apache.spark.memory.StaticMemoryManager import org.apache.spark.network.netty.NettyBlockTransferService @@ -57,7 +58,8 @@ class ReceivedBlockHandlerSuite val hadoopConf = new Configuration() val streamId = 1 val securityMgr = new SecurityManager(conf) - val mapOutputTracker = new MapOutputTrackerMaster(conf) + val broadcastManager = new BroadcastManager(true, conf, securityMgr) + val mapOutputTracker = new MapOutputTrackerMaster(conf, broadcastManager, true) val shuffleManager = new SortShuffleManager(conf) val serializer = new KryoSerializer(conf) var serializerManager = new SerializerManager(serializer, conf) |