aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- core/src/main/scala/org/apache/spark/MapOutputTracker.scala 250
-rw-r--r-- core/src/main/scala/org/apache/spark/SparkEnv.scala 6
-rw-r--r-- core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala 99
-rw-r--r-- core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala 7
-rw-r--r-- core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala 4
-rw-r--r-- core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala 4
-rw-r--r-- streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala 4
7 files changed, 290 insertions, 84 deletions
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 3a5caa3510..6bd950205f 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -18,13 +18,15 @@
package org.apache.spark
import java.io._
-import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolExecutor}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import scala.reflect.ClassTag
+import scala.util.control.NonFatal
+import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
import org.apache.spark.scheduler.MapStatus
@@ -37,31 +39,18 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int)
extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
+private[spark] case class GetMapOutputMessage(shuffleId: Int, context: RpcCallContext)
+
/** RpcEndpoint class for MapOutputTrackerMaster */
private[spark] class MapOutputTrackerMasterEndpoint(
override val rpcEnv: RpcEnv, tracker: MapOutputTrackerMaster, conf: SparkConf)
extends RpcEndpoint with Logging {
- val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case GetMapOutputStatuses(shuffleId: Int) =>
val hostPort = context.senderAddress.hostPort
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
- val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)
- val serializedSize = mapOutputStatuses.length
- if (serializedSize > maxRpcMessageSize) {
-
- val msg = s"Map output statuses were $serializedSize bytes which " +
- s"exceeds spark.rpc.message.maxSize ($maxRpcMessageSize bytes)."
-
- /* For SPARK-1244 we'll opt for just logging an error and then sending it to the sender.
- * A bigger refactoring (SPARK-1239) will ultimately remove this entire code path. */
- val exception = new SparkException(msg)
- logError(msg, exception)
- context.sendFailure(exception)
- } else {
- context.reply(mapOutputStatuses)
- }
+ val mapOutputStatuses = tracker.post(new GetMapOutputMessage(shuffleId, context))
case StopMapOutputTracker =>
logInfo("MapOutputTrackerMasterEndpoint stopped!")
@@ -270,12 +259,17 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
/**
* MapOutputTracker for the driver.
*/
-private[spark] class MapOutputTrackerMaster(conf: SparkConf)
+private[spark] class MapOutputTrackerMaster(conf: SparkConf,
+ broadcastManager: BroadcastManager, isLocal: Boolean)
extends MapOutputTracker(conf) {
/** Cache a serialized version of the output statuses for each shuffle to send them out faster */
private var cacheEpoch = epoch
+ // The size at which we use Broadcast to send the map output statuses to the executors
+ private val minSizeForBroadcast =
+ conf.getSizeAsBytes("spark.shuffle.mapOutput.minSizeForBroadcast", "512k").toInt
+
/** Whether to compute locality preferences for reduce tasks */
private val shuffleLocalityEnabled = conf.getBoolean("spark.shuffle.reduceLocality.enabled", true)
@@ -296,10 +290,86 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
protected val mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala
private val cachedSerializedStatuses = new ConcurrentHashMap[Int, Array[Byte]]().asScala
+ private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)
+
+ // Kept in sync with cachedSerializedStatuses explicitly
+ // This is required so that the Broadcast variable remains in scope until we remove
+ // the shuffleId explicitly or implicitly.
+ private val cachedSerializedBroadcast = new HashMap[Int, Broadcast[Array[Byte]]]()
+
+ // This is to prevent multiple serializations of the same shuffle - which happens when
+ // there is a request storm when a shuffle starts.
+ private val shuffleIdLocks = new ConcurrentHashMap[Int, AnyRef]()
+
+ // requests for map output statuses
+ private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage]
+
+ // Thread pool used for handling map output status requests. This is a separate thread pool
+ // to ensure we don't block the normal dispatcher threads.
+ private val threadpool: ThreadPoolExecutor = {
+ val numThreads = conf.getInt("spark.shuffle.mapOutput.dispatcher.numThreads", 8)
+ val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "map-output-dispatcher")
+ for (i <- 0 until numThreads) {
+ pool.execute(new MessageLoop)
+ }
+ pool
+ }
+
+ // Make sure that we aren't going to exceed the max RPC message size by making sure
+ // we use broadcast to send large map output statuses.
+ if (minSizeForBroadcast > maxRpcMessageSize) {
+ val msg = s"spark.shuffle.mapOutput.minSizeForBroadcast ($minSizeForBroadcast bytes) must " +
+ s"be <= spark.rpc.message.maxSize ($maxRpcMessageSize bytes) to prevent sending an rpc " +
+ "message that is to large."
+ logError(msg)
+ throw new IllegalArgumentException(msg)
+ }
+
+ def post(message: GetMapOutputMessage): Unit = {
+ mapOutputRequests.offer(message)
+ }
+
+ /** Message loop used for dispatching messages. */
+ private class MessageLoop extends Runnable {
+ override def run(): Unit = {
+ try {
+ while (true) {
+ try {
+ val data = mapOutputRequests.take()
+ if (data == PoisonPill) {
+ // Put PoisonPill back so that other MessageLoops can see it.
+ mapOutputRequests.offer(PoisonPill)
+ return
+ }
+ val context = data.context
+ val shuffleId = data.shuffleId
+ val hostPort = context.senderAddress.hostPort
+ logDebug("Handling request to send map output locations for shuffle " + shuffleId +
+ " to " + hostPort)
+ val mapOutputStatuses = getSerializedMapOutputStatuses(shuffleId)
+ context.reply(mapOutputStatuses)
+ } catch {
+ case NonFatal(e) => logError(e.getMessage, e)
+ }
+ }
+ } catch {
+ case ie: InterruptedException => // exit
+ }
+ }
+ }
+
+ /** A poison message that indicates MessageLoop should exit its message loop. */
+ private val PoisonPill = new GetMapOutputMessage(-99, null)
+
+ // Exposed for testing
+ private[spark] def getNumCachedSerializedBroadcast = cachedSerializedBroadcast.size
+
def registerShuffle(shuffleId: Int, numMaps: Int) {
if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
+ // add in advance
+ shuffleIdLocks.putIfAbsent(shuffleId, new Object())
}
def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
@@ -337,6 +407,8 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
override def unregisterShuffle(shuffleId: Int) {
mapStatuses.remove(shuffleId)
cachedSerializedStatuses.remove(shuffleId)
+ cachedSerializedBroadcast.remove(shuffleId).foreach(v => removeBroadcast(v))
+ shuffleIdLocks.remove(shuffleId)
}
/** Check if the given shuffle is being tracked */
@@ -428,40 +500,89 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
}
}
+ private def removeBroadcast(bcast: Broadcast[_]): Unit = {
+ if (null != bcast) {
+ broadcastManager.unbroadcast(bcast.id,
+ removeFromDriver = true, blocking = false)
+ }
+ }
+
+ private def clearCachedBroadcast(): Unit = {
+ for (cached <- cachedSerializedBroadcast) removeBroadcast(cached._2)
+ cachedSerializedBroadcast.clear()
+ }
+
def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
var statuses: Array[MapStatus] = null
+ var retBytes: Array[Byte] = null
var epochGotten: Long = -1
- epochLock.synchronized {
- if (epoch > cacheEpoch) {
- cachedSerializedStatuses.clear()
- cacheEpoch = epoch
- }
- cachedSerializedStatuses.get(shuffleId) match {
- case Some(bytes) =>
- return bytes
- case None =>
- statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]())
- epochGotten = epoch
+
+ // Check to see if we have a cached version, returns true if it does
+ // and has side effect of setting retBytes. If not returns false
+ // with side effect of setting statuses
+ def checkCachedStatuses(): Boolean = {
+ epochLock.synchronized {
+ if (epoch > cacheEpoch) {
+ cachedSerializedStatuses.clear()
+ clearCachedBroadcast()
+ cacheEpoch = epoch
+ }
+ cachedSerializedStatuses.get(shuffleId) match {
+ case Some(bytes) =>
+ retBytes = bytes
+ true
+ case None =>
+ logDebug("cached status not found for : " + shuffleId)
+ statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]())
+ epochGotten = epoch
+ false
+ }
}
}
- // If we got here, we failed to find the serialized locations in the cache, so we pulled
- // out a snapshot of the locations as "statuses"; let's serialize and return that
- val bytes = MapOutputTracker.serializeMapStatuses(statuses)
- logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
- // Add them into the table only if the epoch hasn't changed while we were working
- epochLock.synchronized {
- if (epoch == epochGotten) {
- cachedSerializedStatuses(shuffleId) = bytes
+
+ if (checkCachedStatuses()) return retBytes
+ var shuffleIdLock = shuffleIdLocks.get(shuffleId)
+ if (null == shuffleIdLock) {
+ val newLock = new Object()
+ // in general, this condition should be false - but good to be paranoid
+ val prevLock = shuffleIdLocks.putIfAbsent(shuffleId, newLock)
+ shuffleIdLock = if (null != prevLock) prevLock else newLock
+ }
+ // synchronize so we only serialize/broadcast it once since multiple threads call
+ // in parallel
+ shuffleIdLock.synchronized {
+ // double check to make sure someone else didn't serialize and cache the same
+ // mapstatus while we were waiting on the synchronize
+ if (checkCachedStatuses()) return retBytes
+
+ // If we got here, we failed to find the serialized locations in the cache, so we pulled
+ // out a snapshot of the locations as "statuses"; let's serialize and return that
+ val (bytes, bcast) = MapOutputTracker.serializeMapStatuses(statuses, broadcastManager,
+ isLocal, minSizeForBroadcast)
+ logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
+ // Add them into the table only if the epoch hasn't changed while we were working
+ epochLock.synchronized {
+ if (epoch == epochGotten) {
+ cachedSerializedStatuses(shuffleId) = bytes
+ if (null != bcast) cachedSerializedBroadcast(shuffleId) = bcast
+ } else {
+ logInfo("Epoch changed, not caching!")
+ removeBroadcast(bcast)
+ }
}
+ bytes
}
- bytes
}
override def stop() {
+ mapOutputRequests.offer(PoisonPill)
+ threadpool.shutdown()
sendTracker(StopMapOutputTracker)
mapStatuses.clear()
trackerEndpoint = null
cachedSerializedStatuses.clear()
+ clearCachedBroadcast()
+ shuffleIdLocks.clear()
}
}
@@ -477,12 +598,16 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
private[spark] object MapOutputTracker extends Logging {
val ENDPOINT_NAME = "MapOutputTracker"
+ private val DIRECT = 0
+ private val BROADCAST = 1
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
// generally be pretty compressible because many map outputs will be on the same hostname.
- def serializeMapStatuses(statuses: Array[MapStatus]): Array[Byte] = {
+ def serializeMapStatuses(statuses: Array[MapStatus], broadcastManager: BroadcastManager,
+ isLocal: Boolean, minBroadcastSize: Int): (Array[Byte], Broadcast[Array[Byte]]) = {
val out = new ByteArrayOutputStream
+ out.write(DIRECT)
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
Utils.tryWithSafeFinally {
// Since statuses can be modified in parallel, sync on it
@@ -492,16 +617,51 @@ private[spark] object MapOutputTracker extends Logging {
} {
objOut.close()
}
- out.toByteArray
+ val arr = out.toByteArray
+ if (arr.length >= minBroadcastSize) {
+ // Use broadcast instead.
+ // Important: arr(0) is the tag == DIRECT; ignore it while deserializing!
+ val bcast = broadcastManager.newBroadcast(arr, isLocal)
+ // toByteArray creates copy, so we can reuse out
+ out.reset()
+ out.write(BROADCAST)
+ val oos = new ObjectOutputStream(new GZIPOutputStream(out))
+ oos.writeObject(bcast)
+ oos.close()
+ val outArr = out.toByteArray
+ logInfo("Broadcast mapstatuses size = " + outArr.length + ", actual size = " + arr.length)
+ (outArr, bcast)
+ } else {
+ (arr, null)
+ }
}
// Opposite of serializeMapStatuses.
def deserializeMapStatuses(bytes: Array[Byte]): Array[MapStatus] = {
- val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
- Utils.tryWithSafeFinally {
- objIn.readObject().asInstanceOf[Array[MapStatus]]
- } {
- objIn.close()
+ assert (bytes.length > 0)
+
+ def deserializeObject(arr: Array[Byte], off: Int, len: Int): AnyRef = {
+ val objIn = new ObjectInputStream(new GZIPInputStream(
+ new ByteArrayInputStream(arr, off, len)))
+ Utils.tryWithSafeFinally {
+ objIn.readObject()
+ } {
+ objIn.close()
+ }
+ }
+
+ bytes(0) match {
+ case DIRECT =>
+ deserializeObject(bytes, 1, bytes.length - 1).asInstanceOf[Array[MapStatus]]
+ case BROADCAST =>
+ // deserialize the Broadcast, pull .value array out of it, and then deserialize that
+ val bcast = deserializeObject(bytes, 1, bytes.length - 1).
+ asInstanceOf[Broadcast[Array[Byte]]]
+ logInfo("Broadcast mapstatuses size = " + bytes.length +
+ ", actual size = " + bcast.value.length)
+ // Important - ignore the DIRECT tag ! Start from offset 1
+ deserializeObject(bcast.value, 1, bcast.value.length - 1).asInstanceOf[Array[MapStatus]]
+ case _ => throw new IllegalArgumentException("Unexpected byte tag = " + bytes(0))
}
}
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 27497e21b8..4bf8890c05 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -284,8 +284,10 @@ object SparkEnv extends Logging {
}
}
+ val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
+
val mapOutputTracker = if (isDriver) {
- new MapOutputTrackerMaster(conf)
+ new MapOutputTrackerMaster(conf, broadcastManager, isLocal)
} else {
new MapOutputTrackerWorker(conf)
}
@@ -325,8 +327,6 @@ object SparkEnv extends Logging {
serializerManager, conf, memoryManager, mapOutputTracker, shuffleManager,
blockTransferService, securityManager, numUsableCores)
- val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
-
val metricsSystem = if (isDriver) {
// Don't start metrics system right now for Driver.
// We need to wait for the task scheduler to give us an app ID.
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index ddf48765ec..c6aebc19fd 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer
import org.mockito.Matchers.{any, isA}
import org.mockito.Mockito._
+import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv}
import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus}
import org.apache.spark.shuffle.FetchFailedException
@@ -30,6 +31,12 @@ import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId}
class MapOutputTrackerSuite extends SparkFunSuite {
private val conf = new SparkConf
+ private def newTrackerMaster(sparkConf: SparkConf = conf) = {
+ val broadcastManager = new BroadcastManager(true, sparkConf,
+ new SecurityManager(sparkConf))
+ new MapOutputTrackerMaster(sparkConf, broadcastManager, true)
+ }
+
def createRpcEnv(name: String, host: String = "localhost", port: Int = 0,
securityManager: SecurityManager = new SecurityManager(conf)): RpcEnv = {
RpcEnv.create(name, host, port, conf, securityManager)
@@ -37,7 +44,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
test("master start and stop") {
val rpcEnv = createRpcEnv("test")
- val tracker = new MapOutputTrackerMaster(conf)
+ val tracker = newTrackerMaster()
tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
tracker.stop()
@@ -46,7 +53,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
test("master register shuffle and fetch") {
val rpcEnv = createRpcEnv("test")
- val tracker = new MapOutputTrackerMaster(conf)
+ val tracker = newTrackerMaster()
tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
tracker.registerShuffle(10, 2)
@@ -62,13 +69,14 @@ class MapOutputTrackerSuite extends SparkFunSuite {
Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))),
(BlockManagerId("b", "hostB", 1000), ArrayBuffer((ShuffleBlockId(10, 1, 0), size10000))))
.toSet)
+ assert(0 == tracker.getNumCachedSerializedBroadcast)
tracker.stop()
rpcEnv.shutdown()
}
test("master register and unregister shuffle") {
val rpcEnv = createRpcEnv("test")
- val tracker = new MapOutputTrackerMaster(conf)
+ val tracker = newTrackerMaster()
tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
tracker.registerShuffle(10, 2)
@@ -80,6 +88,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
Array(compressedSize10000, compressedSize1000)))
assert(tracker.containsShuffle(10))
assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty)
+ assert(0 == tracker.getNumCachedSerializedBroadcast)
tracker.unregisterShuffle(10)
assert(!tracker.containsShuffle(10))
assert(tracker.getMapSizesByExecutorId(10, 0).isEmpty)
@@ -90,7 +99,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
test("master register shuffle and unregister map output and fetch") {
val rpcEnv = createRpcEnv("test")
- val tracker = new MapOutputTrackerMaster(conf)
+ val tracker = newTrackerMaster()
tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
tracker.registerShuffle(10, 2)
@@ -101,6 +110,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
Array(compressedSize10000, compressedSize1000, compressedSize1000)))
+ assert(0 == tracker.getNumCachedSerializedBroadcast)
// As if we had two simultaneous fetch failures
tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
@@ -118,7 +128,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
val hostname = "localhost"
val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf))
- val masterTracker = new MapOutputTrackerMaster(conf)
+ val masterTracker = newTrackerMaster()
masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf))
@@ -139,6 +149,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
slaveTracker.updateEpoch(masterTracker.getEpoch)
assert(slaveTracker.getMapSizesByExecutorId(10, 0) ===
Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000)))))
+ assert(0 == masterTracker.getNumCachedSerializedBroadcast)
masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
masterTracker.incrementEpoch()
@@ -147,6 +158,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
// failure should be cached
intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) }
+ assert(0 == masterTracker.getNumCachedSerializedBroadcast)
masterTracker.stop()
slaveTracker.stop()
@@ -158,8 +170,9 @@ class MapOutputTrackerSuite extends SparkFunSuite {
val newConf = new SparkConf
newConf.set("spark.rpc.message.maxSize", "1")
newConf.set("spark.rpc.askTimeout", "1") // Fail fast
+ newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", "1048576")
- val masterTracker = new MapOutputTrackerMaster(conf)
+ val masterTracker = newTrackerMaster(newConf)
val rpcEnv = createRpcEnv("spark")
val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf)
rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint)
@@ -172,45 +185,27 @@ class MapOutputTrackerSuite extends SparkFunSuite {
val rpcCallContext = mock(classOf[RpcCallContext])
when(rpcCallContext.senderAddress).thenReturn(senderAddress)
masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(10))
- verify(rpcCallContext).reply(any())
- verify(rpcCallContext, never()).sendFailure(any())
+ // minSizeForBroadcast is set to 1048576 bytes above, so this request should not cause
+ // broadcast to be used.
+ verify(rpcCallContext, timeout(30000)).reply(any())
+ assert(0 == masterTracker.getNumCachedSerializedBroadcast)
// masterTracker.stop() // this throws an exception
rpcEnv.shutdown()
}
- test("remote fetch exceeds max RPC message size") {
+ test("min broadcast size exceeds max RPC message size") {
val newConf = new SparkConf
newConf.set("spark.rpc.message.maxSize", "1")
newConf.set("spark.rpc.askTimeout", "1") // Fail fast
+ newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", Int.MaxValue.toString)
- val masterTracker = new MapOutputTrackerMaster(conf)
- val rpcEnv = createRpcEnv("test")
- val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf)
- rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint)
-
- // Message size should be ~1.1MB, and MapOutputTrackerMasterEndpoint should throw exception.
- // Note that the size is hand-selected here because map output statuses are compressed before
- // being sent.
- masterTracker.registerShuffle(20, 100)
- (0 until 100).foreach { i =>
- masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
- BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0)))
- }
- val senderAddress = RpcAddress("localhost", 12345)
- val rpcCallContext = mock(classOf[RpcCallContext])
- when(rpcCallContext.senderAddress).thenReturn(senderAddress)
- masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20))
- verify(rpcCallContext, never()).reply(any())
- verify(rpcCallContext).sendFailure(isA(classOf[SparkException]))
-
-// masterTracker.stop() // this throws an exception
- rpcEnv.shutdown()
+ intercept[IllegalArgumentException] { newTrackerMaster(newConf) }
}
test("getLocationsWithLargestOutputs with multiple outputs in same machine") {
val rpcEnv = createRpcEnv("test")
- val tracker = new MapOutputTrackerMaster(conf)
+ val tracker = newTrackerMaster()
tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
// Setup 3 map tasks
@@ -242,4 +237,44 @@ class MapOutputTrackerSuite extends SparkFunSuite {
tracker.stop()
rpcEnv.shutdown()
}
+
+ test("remote fetch using broadcast") {
+ val newConf = new SparkConf
+ newConf.set("spark.rpc.message.maxSize", "1")
+ newConf.set("spark.rpc.askTimeout", "1") // Fail fast
+ newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", "10240") // 10 KB << 1MB framesize
+
+ // needs TorrentBroadcast so need a SparkContext
+ val sc = new SparkContext("local", "MapOutputTrackerSuite", newConf)
+ try {
+ val masterTracker = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+ val rpcEnv = sc.env.rpcEnv
+ val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf)
+ rpcEnv.stop(masterTracker.trackerEndpoint)
+ rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint)
+
+ // Serialized statuses should be ~1.1MB, well above minSizeForBroadcast, so broadcast is used.
+ // Note that the size is hand-selected here because map output statuses are compressed before
+ // being sent.
+ masterTracker.registerShuffle(20, 100)
+ (0 until 100).foreach { i =>
+ masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
+ BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0)))
+ }
+ val senderAddress = RpcAddress("localhost", 12345)
+ val rpcCallContext = mock(classOf[RpcCallContext])
+ when(rpcCallContext.senderAddress).thenReturn(senderAddress)
+ masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20))
+ // should succeed since majority of data is broadcast and actual serialized
+ // message size is small
+ verify(rpcCallContext, timeout(30000)).reply(any())
+ assert(1 == masterTracker.getNumCachedSerializedBroadcast)
+ masterTracker.unregisterShuffle(20)
+ assert(0 == masterTracker.getNumCachedSerializedBroadcast)
+
+ } finally {
+ LocalSparkContext.stop(sc)
+ }
+ }
+
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 844c780a3f..e3ed079e4e 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -28,6 +28,7 @@ import org.scalatest.concurrent.Timeouts
import org.scalatest.time.SpanSugar._
import org.apache.spark._
+import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
@@ -156,6 +157,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
}
var mapOutputTracker: MapOutputTrackerMaster = null
+ var broadcastManager: BroadcastManager = null
+ var securityMgr: SecurityManager = null
var scheduler: DAGScheduler = null
var dagEventProcessLoopTester: DAGSchedulerEventProcessLoop = null
@@ -207,7 +210,9 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
cancelledStages.clear()
cacheLocations.clear()
results.clear()
- mapOutputTracker = new MapOutputTrackerMaster(conf)
+ securityMgr = new SecurityManager(conf)
+ broadcastManager = new BroadcastManager(true, conf, securityMgr)
+ mapOutputTracker = new MapOutputTrackerMaster(conf, broadcastManager, true)
scheduler = new DAGScheduler(
sc,
taskScheduler,
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
index d14728cb50..31687e6147 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -27,6 +27,7 @@ import org.scalatest.{BeforeAndAfter, Matchers}
import org.scalatest.concurrent.Eventually._
import org.apache.spark._
+import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.memory.UnifiedMemoryManager
import org.apache.spark.network.BlockTransferService
import org.apache.spark.network.netty.NettyBlockTransferService
@@ -43,7 +44,8 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo
private var rpcEnv: RpcEnv = null
private var master: BlockManagerMaster = null
private val securityMgr = new SecurityManager(conf)
- private val mapOutputTracker = new MapOutputTrackerMaster(conf)
+ private val bcastManager = new BroadcastManager(true, conf, securityMgr)
+ private val mapOutputTracker = new MapOutputTrackerMaster(conf, bcastManager, true)
private val shuffleManager = new SortShuffleManager(conf)
// List of block manager created during an unit test, so that all of the them can be stopped
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index db1efaf2a2..a2580304c4 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -33,6 +33,7 @@ import org.scalatest.concurrent.Eventually._
import org.scalatest.concurrent.Timeouts._
import org.apache.spark._
+import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.executor.DataReadMethod
import org.apache.spark.memory.UnifiedMemoryManager
import org.apache.spark.network.{BlockDataManager, BlockTransferService}
@@ -59,7 +60,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
var rpcEnv: RpcEnv = null
var master: BlockManagerMaster = null
val securityMgr = new SecurityManager(new SparkConf(false))
- val mapOutputTracker = new MapOutputTrackerMaster(new SparkConf(false))
+ val bcastManager = new BroadcastManager(true, new SparkConf(false), securityMgr)
+ val mapOutputTracker = new MapOutputTrackerMaster(new SparkConf(false), bcastManager, true)
val shuffleManager = new SortShuffleManager(new SparkConf(false))
// Reuse a serializer across tests to avoid creating a new thread-local buffer on each test
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index 4be4882938..e97427991b 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -29,6 +29,7 @@ import org.scalatest.{BeforeAndAfter, Matchers}
import org.scalatest.concurrent.Eventually._
import org.apache.spark._
+import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.internal.Logging
import org.apache.spark.memory.StaticMemoryManager
import org.apache.spark.network.netty.NettyBlockTransferService
@@ -57,7 +58,8 @@ class ReceivedBlockHandlerSuite
val hadoopConf = new Configuration()
val streamId = 1
val securityMgr = new SecurityManager(conf)
- val mapOutputTracker = new MapOutputTrackerMaster(conf)
+ val broadcastManager = new BroadcastManager(true, conf, securityMgr)
+ val mapOutputTracker = new MapOutputTrackerMaster(conf, broadcastManager, true)
val shuffleManager = new SortShuffleManager(conf)
val serializer = new KryoSerializer(conf)
var serializerManager = new SerializerManager(serializer, conf)