aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorThomas Graves <tgraves@staydecay.corp.gq1.yahoo.com>2016-05-06 19:31:26 -0700
committerDavies Liu <davies.liu@gmail.com>2016-05-06 19:31:26 -0700
commitcc95f1ed5fdf2566bcefe8d10116eee544cf9184 (patch)
tree7cd0f48681e96228c3c60d30fabc87a95ac0342c /core
parentf7b7ef41662d7d02fc4f834f3c6c4ee8802e949c (diff)
downloadspark-cc95f1ed5fdf2566bcefe8d10116eee544cf9184.tar.gz
spark-cc95f1ed5fdf2566bcefe8d10116eee544cf9184.tar.bz2
spark-cc95f1ed5fdf2566bcefe8d10116eee544cf9184.zip
[SPARK-1239] Improve fetching of map output statuses
The main issue we are trying to solve is the memory bloat of the Driver when tasks request the map output statuses. This means with a large number of tasks you either need a huge amount of memory on Driver or you have to repartition to smaller number. This makes it really difficult to run over say 50000 tasks. The main issues that cause the memory bloat are: 1) no flow control on sending the map output status responses. We serialize the map status output and then hand off to netty to send. netty is sending asynchronously and it can't send them fast enough to keep up with incoming requests so we end up with lots of copies of the serialized map output statuses sitting there and this causes huge bloat when you have 10's of thousands of tasks and map output status is in the 10's of MB. 2) When initial reduce tasks are started up, they all request the map output statuses from the Driver. These requests are handled by multiple threads in parallel so even though we check to see if we have a cached version, initially when we don't have a cached version yet, many of initial requests can all end up serializing the exact same map output statuses. This patch does a couple of things: - When the map output status size is over a threshold (default 512K) then it uses broadcast to send the map statuses. This means we no longer serialize a large map output status and thus we don't have issues with memory bloat. the messages sizes are now in the 300-400 byte range and the map status output are broadcast. If its under the threadshold it sends it as before, the message contains the DIRECT indicator now. - synchronize the incoming requests to allow one thread to cache the serialized output and broadcast the map output status that can then be used by everyone else. This ensures we don't create multiple broadcast variables when we don't need to. To ensure this happens I added a second thread pool which the Dispatcher hands the requests to so that those threads can block without blocking the main dispatcher threads (which would cause things like heartbeats and such not to come through) Note that some of design and code was contributed by mridulm ## How was this patch tested? Unit tests and a lot of manually testing. Ran with akka and netty rpc. Ran with both dynamic allocation on and off. one of the large jobs I used to test this was a join of 15TB of data. it had 200,000 map tasks, and 20,000 reduce tasks. Executors ranged from 200 to 2000. This job ran successfully with 5GB of memory on the driver with these changes. Without these changes I was using 20GB and only had 500 reduce tasks. The job has 50mb of serialized map output statuses and took roughly the same amount of time for the executors to get the map output statuses as before. Ran a variety of other jobs, from large wordcounts to small ones not using broadcasts. Author: Thomas Graves <tgraves@staydecay.corp.gq1.yahoo.com> Closes #12113 from tgravescs/SPARK-1239.
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/org/apache/spark/MapOutputTracker.scala250
-rw-r--r--core/src/main/scala/org/apache/spark/SparkEnv.scala6
-rw-r--r--core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala99
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala7
-rw-r--r--core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala4
-rw-r--r--core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala4
6 files changed, 287 insertions, 83 deletions
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 3a5caa3510..6bd950205f 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -18,13 +18,15 @@
package org.apache.spark
import java.io._
-import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolExecutor}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import scala.reflect.ClassTag
+import scala.util.control.NonFatal
+import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
import org.apache.spark.scheduler.MapStatus
@@ -37,31 +39,18 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int)
extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
+private[spark] case class GetMapOutputMessage(shuffleId: Int, context: RpcCallContext)
+
/** RpcEndpoint class for MapOutputTrackerMaster */
private[spark] class MapOutputTrackerMasterEndpoint(
override val rpcEnv: RpcEnv, tracker: MapOutputTrackerMaster, conf: SparkConf)
extends RpcEndpoint with Logging {
- val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case GetMapOutputStatuses(shuffleId: Int) =>
val hostPort = context.senderAddress.hostPort
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
- val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)
- val serializedSize = mapOutputStatuses.length
- if (serializedSize > maxRpcMessageSize) {
-
- val msg = s"Map output statuses were $serializedSize bytes which " +
- s"exceeds spark.rpc.message.maxSize ($maxRpcMessageSize bytes)."
-
- /* For SPARK-1244 we'll opt for just logging an error and then sending it to the sender.
- * A bigger refactoring (SPARK-1239) will ultimately remove this entire code path. */
- val exception = new SparkException(msg)
- logError(msg, exception)
- context.sendFailure(exception)
- } else {
- context.reply(mapOutputStatuses)
- }
+ val mapOutputStatuses = tracker.post(new GetMapOutputMessage(shuffleId, context))
case StopMapOutputTracker =>
logInfo("MapOutputTrackerMasterEndpoint stopped!")
@@ -270,12 +259,17 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
/**
* MapOutputTracker for the driver.
*/
-private[spark] class MapOutputTrackerMaster(conf: SparkConf)
+private[spark] class MapOutputTrackerMaster(conf: SparkConf,
+ broadcastManager: BroadcastManager, isLocal: Boolean)
extends MapOutputTracker(conf) {
/** Cache a serialized version of the output statuses for each shuffle to send them out faster */
private var cacheEpoch = epoch
+ // The size at which we use Broadcast to send the map output statuses to the executors
+ private val minSizeForBroadcast =
+ conf.getSizeAsBytes("spark.shuffle.mapOutput.minSizeForBroadcast", "512k").toInt
+
/** Whether to compute locality preferences for reduce tasks */
private val shuffleLocalityEnabled = conf.getBoolean("spark.shuffle.reduceLocality.enabled", true)
@@ -296,10 +290,86 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
protected val mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala
private val cachedSerializedStatuses = new ConcurrentHashMap[Int, Array[Byte]]().asScala
+ private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)
+
+ // Kept in sync with cachedSerializedStatuses explicitly
+ // This is required so that the Broadcast variable remains in scope until we remove
+ // the shuffleId explicitly or implicitly.
+ private val cachedSerializedBroadcast = new HashMap[Int, Broadcast[Array[Byte]]]()
+
+ // This is to prevent multiple serializations of the same shuffle - which happens when
+ // there is a request storm when shuffle start.
+ private val shuffleIdLocks = new ConcurrentHashMap[Int, AnyRef]()
+
+ // requests for map output statuses
+ private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage]
+
+ // Thread pool used for handling map output status requests. This is a separate thread pool
+ // to ensure we don't block the normal dispatcher threads.
+ private val threadpool: ThreadPoolExecutor = {
+ val numThreads = conf.getInt("spark.shuffle.mapOutput.dispatcher.numThreads", 8)
+ val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "map-output-dispatcher")
+ for (i <- 0 until numThreads) {
+ pool.execute(new MessageLoop)
+ }
+ pool
+ }
+
+ // Make sure that that we aren't going to exceed the max RPC message size by making sure
+ // we use broadcast to send large map output statuses.
+ if (minSizeForBroadcast > maxRpcMessageSize) {
+ val msg = s"spark.shuffle.mapOutput.minSizeForBroadcast ($minSizeForBroadcast bytes) must " +
+ s"be <= spark.rpc.message.maxSize ($maxRpcMessageSize bytes) to prevent sending an rpc " +
+ "message that is to large."
+ logError(msg)
+ throw new IllegalArgumentException(msg)
+ }
+
+ def post(message: GetMapOutputMessage): Unit = {
+ mapOutputRequests.offer(message)
+ }
+
+ /** Message loop used for dispatching messages. */
+ private class MessageLoop extends Runnable {
+ override def run(): Unit = {
+ try {
+ while (true) {
+ try {
+ val data = mapOutputRequests.take()
+ if (data == PoisonPill) {
+ // Put PoisonPill back so that other MessageLoops can see it.
+ mapOutputRequests.offer(PoisonPill)
+ return
+ }
+ val context = data.context
+ val shuffleId = data.shuffleId
+ val hostPort = context.senderAddress.hostPort
+ logDebug("Handling request to send map output locations for shuffle " + shuffleId +
+ " to " + hostPort)
+ val mapOutputStatuses = getSerializedMapOutputStatuses(shuffleId)
+ context.reply(mapOutputStatuses)
+ } catch {
+ case NonFatal(e) => logError(e.getMessage, e)
+ }
+ }
+ } catch {
+ case ie: InterruptedException => // exit
+ }
+ }
+ }
+
+ /** A poison endpoint that indicates MessageLoop should exit its message loop. */
+ private val PoisonPill = new GetMapOutputMessage(-99, null)
+
+ // Exposed for testing
+ private[spark] def getNumCachedSerializedBroadcast = cachedSerializedBroadcast.size
+
def registerShuffle(shuffleId: Int, numMaps: Int) {
if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
+ // add in advance
+ shuffleIdLocks.putIfAbsent(shuffleId, new Object())
}
def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
@@ -337,6 +407,8 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
override def unregisterShuffle(shuffleId: Int) {
mapStatuses.remove(shuffleId)
cachedSerializedStatuses.remove(shuffleId)
+ cachedSerializedBroadcast.remove(shuffleId).foreach(v => removeBroadcast(v))
+ shuffleIdLocks.remove(shuffleId)
}
/** Check if the given shuffle is being tracked */
@@ -428,40 +500,89 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
}
}
+ private def removeBroadcast(bcast: Broadcast[_]): Unit = {
+ if (null != bcast) {
+ broadcastManager.unbroadcast(bcast.id,
+ removeFromDriver = true, blocking = false)
+ }
+ }
+
+ private def clearCachedBroadcast(): Unit = {
+ for (cached <- cachedSerializedBroadcast) removeBroadcast(cached._2)
+ cachedSerializedBroadcast.clear()
+ }
+
def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
var statuses: Array[MapStatus] = null
+ var retBytes: Array[Byte] = null
var epochGotten: Long = -1
- epochLock.synchronized {
- if (epoch > cacheEpoch) {
- cachedSerializedStatuses.clear()
- cacheEpoch = epoch
- }
- cachedSerializedStatuses.get(shuffleId) match {
- case Some(bytes) =>
- return bytes
- case None =>
- statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]())
- epochGotten = epoch
+
+ // Check to see if we have a cached version, returns true if it does
+ // and has side effect of setting retBytes. If not returns false
+ // with side effect of setting statuses
+ def checkCachedStatuses(): Boolean = {
+ epochLock.synchronized {
+ if (epoch > cacheEpoch) {
+ cachedSerializedStatuses.clear()
+ clearCachedBroadcast()
+ cacheEpoch = epoch
+ }
+ cachedSerializedStatuses.get(shuffleId) match {
+ case Some(bytes) =>
+ retBytes = bytes
+ true
+ case None =>
+ logDebug("cached status not found for : " + shuffleId)
+ statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]())
+ epochGotten = epoch
+ false
+ }
}
}
- // If we got here, we failed to find the serialized locations in the cache, so we pulled
- // out a snapshot of the locations as "statuses"; let's serialize and return that
- val bytes = MapOutputTracker.serializeMapStatuses(statuses)
- logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
- // Add them into the table only if the epoch hasn't changed while we were working
- epochLock.synchronized {
- if (epoch == epochGotten) {
- cachedSerializedStatuses(shuffleId) = bytes
+
+ if (checkCachedStatuses()) return retBytes
+ var shuffleIdLock = shuffleIdLocks.get(shuffleId)
+ if (null == shuffleIdLock) {
+ val newLock = new Object()
+ // in general, this condition should be false - but good to be paranoid
+ val prevLock = shuffleIdLocks.putIfAbsent(shuffleId, newLock)
+ shuffleIdLock = if (null != prevLock) prevLock else newLock
+ }
+ // synchronize so we only serialize/broadcast it once since multiple threads call
+ // in parallel
+ shuffleIdLock.synchronized {
+ // double check to make sure someone else didn't serialize and cache the same
+ // mapstatus while we were waiting on the synchronize
+ if (checkCachedStatuses()) return retBytes
+
+ // If we got here, we failed to find the serialized locations in the cache, so we pulled
+ // out a snapshot of the locations as "statuses"; let's serialize and return that
+ val (bytes, bcast) = MapOutputTracker.serializeMapStatuses(statuses, broadcastManager,
+ isLocal, minSizeForBroadcast)
+ logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
+ // Add them into the table only if the epoch hasn't changed while we were working
+ epochLock.synchronized {
+ if (epoch == epochGotten) {
+ cachedSerializedStatuses(shuffleId) = bytes
+ if (null != bcast) cachedSerializedBroadcast(shuffleId) = bcast
+ } else {
+ logInfo("Epoch changed, not caching!")
+ removeBroadcast(bcast)
+ }
}
+ bytes
}
- bytes
}
override def stop() {
+ mapOutputRequests.offer(PoisonPill)
+ threadpool.shutdown()
sendTracker(StopMapOutputTracker)
mapStatuses.clear()
trackerEndpoint = null
cachedSerializedStatuses.clear()
+ clearCachedBroadcast()
+ shuffleIdLocks.clear()
}
}
@@ -477,12 +598,16 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
private[spark] object MapOutputTracker extends Logging {
val ENDPOINT_NAME = "MapOutputTracker"
+ private val DIRECT = 0
+ private val BROADCAST = 1
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
// generally be pretty compressible because many map outputs will be on the same hostname.
- def serializeMapStatuses(statuses: Array[MapStatus]): Array[Byte] = {
+ def serializeMapStatuses(statuses: Array[MapStatus], broadcastManager: BroadcastManager,
+ isLocal: Boolean, minBroadcastSize: Int): (Array[Byte], Broadcast[Array[Byte]]) = {
val out = new ByteArrayOutputStream
+ out.write(DIRECT)
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
Utils.tryWithSafeFinally {
// Since statuses can be modified in parallel, sync on it
@@ -492,16 +617,51 @@ private[spark] object MapOutputTracker extends Logging {
} {
objOut.close()
}
- out.toByteArray
+ val arr = out.toByteArray
+ if (arr.length >= minBroadcastSize) {
+ // Use broadcast instead.
+ // Important arr(0) is the tag == DIRECT, ignore that while deserializing !
+ val bcast = broadcastManager.newBroadcast(arr, isLocal)
+ // toByteArray creates copy, so we can reuse out
+ out.reset()
+ out.write(BROADCAST)
+ val oos = new ObjectOutputStream(new GZIPOutputStream(out))
+ oos.writeObject(bcast)
+ oos.close()
+ val outArr = out.toByteArray
+ logInfo("Broadcast mapstatuses size = " + outArr.length + ", actual size = " + arr.length)
+ (outArr, bcast)
+ } else {
+ (arr, null)
+ }
}
// Opposite of serializeMapStatuses.
def deserializeMapStatuses(bytes: Array[Byte]): Array[MapStatus] = {
- val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
- Utils.tryWithSafeFinally {
- objIn.readObject().asInstanceOf[Array[MapStatus]]
- } {
- objIn.close()
+ assert (bytes.length > 0)
+
+ def deserializeObject(arr: Array[Byte], off: Int, len: Int): AnyRef = {
+ val objIn = new ObjectInputStream(new GZIPInputStream(
+ new ByteArrayInputStream(arr, off, len)))
+ Utils.tryWithSafeFinally {
+ objIn.readObject()
+ } {
+ objIn.close()
+ }
+ }
+
+ bytes(0) match {
+ case DIRECT =>
+ deserializeObject(bytes, 1, bytes.length - 1).asInstanceOf[Array[MapStatus]]
+ case BROADCAST =>
+ // deserialize the Broadcast, pull .value array out of it, and then deserialize that
+ val bcast = deserializeObject(bytes, 1, bytes.length - 1).
+ asInstanceOf[Broadcast[Array[Byte]]]
+ logInfo("Broadcast mapstatuses size = " + bytes.length +
+ ", actual size = " + bcast.value.length)
+ // Important - ignore the DIRECT tag ! Start from offset 1
+ deserializeObject(bcast.value, 1, bcast.value.length - 1).asInstanceOf[Array[MapStatus]]
+ case _ => throw new IllegalArgumentException("Unexpected byte tag = " + bytes(0))
}
}
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 27497e21b8..4bf8890c05 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -284,8 +284,10 @@ object SparkEnv extends Logging {
}
}
+ val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
+
val mapOutputTracker = if (isDriver) {
- new MapOutputTrackerMaster(conf)
+ new MapOutputTrackerMaster(conf, broadcastManager, isLocal)
} else {
new MapOutputTrackerWorker(conf)
}
@@ -325,8 +327,6 @@ object SparkEnv extends Logging {
serializerManager, conf, memoryManager, mapOutputTracker, shuffleManager,
blockTransferService, securityManager, numUsableCores)
- val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
-
val metricsSystem = if (isDriver) {
// Don't start metrics system right now for Driver.
// We need to wait for the task scheduler to give us an app ID.
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index ddf48765ec..c6aebc19fd 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer
import org.mockito.Matchers.{any, isA}
import org.mockito.Mockito._
+import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv}
import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus}
import org.apache.spark.shuffle.FetchFailedException
@@ -30,6 +31,12 @@ import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId}
class MapOutputTrackerSuite extends SparkFunSuite {
private val conf = new SparkConf
+ private def newTrackerMaster(sparkConf: SparkConf = conf) = {
+ val broadcastManager = new BroadcastManager(true, sparkConf,
+ new SecurityManager(sparkConf))
+ new MapOutputTrackerMaster(sparkConf, broadcastManager, true)
+ }
+
def createRpcEnv(name: String, host: String = "localhost", port: Int = 0,
securityManager: SecurityManager = new SecurityManager(conf)): RpcEnv = {
RpcEnv.create(name, host, port, conf, securityManager)
@@ -37,7 +44,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
test("master start and stop") {
val rpcEnv = createRpcEnv("test")
- val tracker = new MapOutputTrackerMaster(conf)
+ val tracker = newTrackerMaster()
tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
tracker.stop()
@@ -46,7 +53,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
test("master register shuffle and fetch") {
val rpcEnv = createRpcEnv("test")
- val tracker = new MapOutputTrackerMaster(conf)
+ val tracker = newTrackerMaster()
tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
tracker.registerShuffle(10, 2)
@@ -62,13 +69,14 @@ class MapOutputTrackerSuite extends SparkFunSuite {
Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))),
(BlockManagerId("b", "hostB", 1000), ArrayBuffer((ShuffleBlockId(10, 1, 0), size10000))))
.toSet)
+ assert(0 == tracker.getNumCachedSerializedBroadcast)
tracker.stop()
rpcEnv.shutdown()
}
test("master register and unregister shuffle") {
val rpcEnv = createRpcEnv("test")
- val tracker = new MapOutputTrackerMaster(conf)
+ val tracker = newTrackerMaster()
tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
tracker.registerShuffle(10, 2)
@@ -80,6 +88,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
Array(compressedSize10000, compressedSize1000)))
assert(tracker.containsShuffle(10))
assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty)
+ assert(0 == tracker.getNumCachedSerializedBroadcast)
tracker.unregisterShuffle(10)
assert(!tracker.containsShuffle(10))
assert(tracker.getMapSizesByExecutorId(10, 0).isEmpty)
@@ -90,7 +99,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
test("master register shuffle and unregister map output and fetch") {
val rpcEnv = createRpcEnv("test")
- val tracker = new MapOutputTrackerMaster(conf)
+ val tracker = newTrackerMaster()
tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
tracker.registerShuffle(10, 2)
@@ -101,6 +110,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
Array(compressedSize10000, compressedSize1000, compressedSize1000)))
+ assert(0 == tracker.getNumCachedSerializedBroadcast)
// As if we had two simultaneous fetch failures
tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
@@ -118,7 +128,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
val hostname = "localhost"
val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf))
- val masterTracker = new MapOutputTrackerMaster(conf)
+ val masterTracker = newTrackerMaster()
masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf))
@@ -139,6 +149,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
slaveTracker.updateEpoch(masterTracker.getEpoch)
assert(slaveTracker.getMapSizesByExecutorId(10, 0) ===
Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000)))))
+ assert(0 == masterTracker.getNumCachedSerializedBroadcast)
masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
masterTracker.incrementEpoch()
@@ -147,6 +158,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
// failure should be cached
intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) }
+ assert(0 == masterTracker.getNumCachedSerializedBroadcast)
masterTracker.stop()
slaveTracker.stop()
@@ -158,8 +170,9 @@ class MapOutputTrackerSuite extends SparkFunSuite {
val newConf = new SparkConf
newConf.set("spark.rpc.message.maxSize", "1")
newConf.set("spark.rpc.askTimeout", "1") // Fail fast
+ newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", "1048576")
- val masterTracker = new MapOutputTrackerMaster(conf)
+ val masterTracker = newTrackerMaster(newConf)
val rpcEnv = createRpcEnv("spark")
val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf)
rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint)
@@ -172,45 +185,27 @@ class MapOutputTrackerSuite extends SparkFunSuite {
val rpcCallContext = mock(classOf[RpcCallContext])
when(rpcCallContext.senderAddress).thenReturn(senderAddress)
masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(10))
- verify(rpcCallContext).reply(any())
- verify(rpcCallContext, never()).sendFailure(any())
+ // Default size for broadcast in this testsuite is set to -1 so should not cause broadcast
+ // to be used.
+ verify(rpcCallContext, timeout(30000)).reply(any())
+ assert(0 == masterTracker.getNumCachedSerializedBroadcast)
// masterTracker.stop() // this throws an exception
rpcEnv.shutdown()
}
- test("remote fetch exceeds max RPC message size") {
+ test("min broadcast size exceeds max RPC message size") {
val newConf = new SparkConf
newConf.set("spark.rpc.message.maxSize", "1")
newConf.set("spark.rpc.askTimeout", "1") // Fail fast
+ newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", Int.MaxValue.toString)
- val masterTracker = new MapOutputTrackerMaster(conf)
- val rpcEnv = createRpcEnv("test")
- val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf)
- rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint)
-
- // Message size should be ~1.1MB, and MapOutputTrackerMasterEndpoint should throw exception.
- // Note that the size is hand-selected here because map output statuses are compressed before
- // being sent.
- masterTracker.registerShuffle(20, 100)
- (0 until 100).foreach { i =>
- masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
- BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0)))
- }
- val senderAddress = RpcAddress("localhost", 12345)
- val rpcCallContext = mock(classOf[RpcCallContext])
- when(rpcCallContext.senderAddress).thenReturn(senderAddress)
- masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20))
- verify(rpcCallContext, never()).reply(any())
- verify(rpcCallContext).sendFailure(isA(classOf[SparkException]))
-
-// masterTracker.stop() // this throws an exception
- rpcEnv.shutdown()
+ intercept[IllegalArgumentException] { newTrackerMaster(newConf) }
}
test("getLocationsWithLargestOutputs with multiple outputs in same machine") {
val rpcEnv = createRpcEnv("test")
- val tracker = new MapOutputTrackerMaster(conf)
+ val tracker = newTrackerMaster()
tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
// Setup 3 map tasks
@@ -242,4 +237,44 @@ class MapOutputTrackerSuite extends SparkFunSuite {
tracker.stop()
rpcEnv.shutdown()
}
+
+ test("remote fetch using broadcast") {
+ val newConf = new SparkConf
+ newConf.set("spark.rpc.message.maxSize", "1")
+ newConf.set("spark.rpc.askTimeout", "1") // Fail fast
+ newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", "10240") // 10 KB << 1MB framesize
+
+ // needs TorrentBroadcast so need a SparkContext
+ val sc = new SparkContext("local", "MapOutputTrackerSuite", newConf)
+ try {
+ val masterTracker = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+ val rpcEnv = sc.env.rpcEnv
+ val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf)
+ rpcEnv.stop(masterTracker.trackerEndpoint)
+ rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint)
+
+ // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception.
+ // Note that the size is hand-selected here because map output statuses are compressed before
+ // being sent.
+ masterTracker.registerShuffle(20, 100)
+ (0 until 100).foreach { i =>
+ masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
+ BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0)))
+ }
+ val senderAddress = RpcAddress("localhost", 12345)
+ val rpcCallContext = mock(classOf[RpcCallContext])
+ when(rpcCallContext.senderAddress).thenReturn(senderAddress)
+ masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20))
+ // should succeed since majority of data is broadcast and actual serialized
+ // message size is small
+ verify(rpcCallContext, timeout(30000)).reply(any())
+ assert(1 == masterTracker.getNumCachedSerializedBroadcast)
+ masterTracker.unregisterShuffle(20)
+ assert(0 == masterTracker.getNumCachedSerializedBroadcast)
+
+ } finally {
+ LocalSparkContext.stop(sc)
+ }
+ }
+
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 844c780a3f..e3ed079e4e 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -28,6 +28,7 @@ import org.scalatest.concurrent.Timeouts
import org.scalatest.time.SpanSugar._
import org.apache.spark._
+import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
@@ -156,6 +157,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
}
var mapOutputTracker: MapOutputTrackerMaster = null
+ var broadcastManager: BroadcastManager = null
+ var securityMgr: SecurityManager = null
var scheduler: DAGScheduler = null
var dagEventProcessLoopTester: DAGSchedulerEventProcessLoop = null
@@ -207,7 +210,9 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
cancelledStages.clear()
cacheLocations.clear()
results.clear()
- mapOutputTracker = new MapOutputTrackerMaster(conf)
+ securityMgr = new SecurityManager(conf)
+ broadcastManager = new BroadcastManager(true, conf, securityMgr)
+ mapOutputTracker = new MapOutputTrackerMaster(conf, broadcastManager, true)
scheduler = new DAGScheduler(
sc,
taskScheduler,
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
index d14728cb50..31687e6147 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -27,6 +27,7 @@ import org.scalatest.{BeforeAndAfter, Matchers}
import org.scalatest.concurrent.Eventually._
import org.apache.spark._
+import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.memory.UnifiedMemoryManager
import org.apache.spark.network.BlockTransferService
import org.apache.spark.network.netty.NettyBlockTransferService
@@ -43,7 +44,8 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo
private var rpcEnv: RpcEnv = null
private var master: BlockManagerMaster = null
private val securityMgr = new SecurityManager(conf)
- private val mapOutputTracker = new MapOutputTrackerMaster(conf)
+ private val bcastManager = new BroadcastManager(true, conf, securityMgr)
+ private val mapOutputTracker = new MapOutputTrackerMaster(conf, bcastManager, true)
private val shuffleManager = new SortShuffleManager(conf)
// List of block manager created during an unit test, so that all of the them can be stopped
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index db1efaf2a2..a2580304c4 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -33,6 +33,7 @@ import org.scalatest.concurrent.Eventually._
import org.scalatest.concurrent.Timeouts._
import org.apache.spark._
+import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.executor.DataReadMethod
import org.apache.spark.memory.UnifiedMemoryManager
import org.apache.spark.network.{BlockDataManager, BlockTransferService}
@@ -59,7 +60,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
var rpcEnv: RpcEnv = null
var master: BlockManagerMaster = null
val securityMgr = new SecurityManager(new SparkConf(false))
- val mapOutputTracker = new MapOutputTrackerMaster(new SparkConf(false))
+ val bcastManager = new BroadcastManager(true, new SparkConf(false), securityMgr)
+ val mapOutputTracker = new MapOutputTrackerMaster(new SparkConf(false), bcastManager, true)
val shuffleManager = new SortShuffleManager(new SparkConf(false))
// Reuse a serializer across tests to avoid creating a new thread-local buffer on each test