aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatei Zaharia <matei@eecs.berkeley.edu>2012-08-23 19:38:28 -0700
committerMatei Zaharia <matei@eecs.berkeley.edu>2012-08-23 19:38:28 -0700
commit7310a6f4993ba7539706446d3bda0a803ee4cb81 (patch)
treea5859b7d1a24cebc859016aac363202dfea63b8d
parent25a6a39e6d6db1264ab7633d1dcfe886415fbf1a (diff)
parentd821dd3ccc37f9c06c76da08c464daac38a1f045 (diff)
downloadspark-7310a6f4993ba7539706446d3bda0a803ee4cb81.tar.gz
spark-7310a6f4993ba7539706446d3bda0a803ee4cb81.tar.bz2
spark-7310a6f4993ba7539706446d3bda0a803ee4cb81.zip
Merge pull request #147 from mosharaf/dev
Broadcast refactoring/cleaning up
-rw-r--r--core/src/main/scala/spark/SparkContext.scala7
-rw-r--r--core/src/main/scala/spark/SparkEnv.scala8
-rw-r--r--core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala542
-rw-r--r--core/src/main/scala/spark/broadcast/Broadcast.scala188
-rw-r--r--core/src/main/scala/spark/broadcast/BroadcastFactory.scala1
-rw-r--r--core/src/main/scala/spark/broadcast/ChainedBroadcast.scala794
-rw-r--r--core/src/main/scala/spark/broadcast/DfsBroadcast.scala135
-rw-r--r--core/src/main/scala/spark/broadcast/HttpBroadcast.scala38
-rw-r--r--core/src/main/scala/spark/broadcast/MultiTracker.scala394
-rw-r--r--core/src/main/scala/spark/broadcast/SourceInfo.scala10
-rw-r--r--core/src/main/scala/spark/broadcast/TreeBroadcast.scala452
-rw-r--r--core/src/main/scala/spark/executor/Executor.scala2
12 files changed, 661 insertions, 1910 deletions
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index b0f5e12a76..43414d2e41 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -65,7 +65,7 @@ class SparkContext(
System.setProperty("spark.master.port", "0")
}
- private val isLocal = (master == "local" || master.startsWith("local["))
+ private val isLocal = (master == "local" || master.startsWith("local[")) && !master.startsWith("localhost")
// Create the Spark execution environment (cache, map output tracker, etc)
val env = SparkEnv.createFromSystemProperties(
@@ -74,7 +74,6 @@ class SparkContext(
true,
isLocal)
SparkEnv.set(env)
- Broadcast.initialize(true)
// Create and start the scheduler
private var taskScheduler: TaskScheduler = {
@@ -295,14 +294,14 @@ class SparkContext(
// Keep around a weak hash map of values to Cached versions?
- def broadcast[T](value: T) = Broadcast.getBroadcastFactory.newBroadcast[T] (value, isLocal)
+ def broadcast[T](value: T) = SparkEnv.get.broadcastManager.newBroadcast[T] (value, isLocal)
// Stop the SparkContext
def stop() {
dagScheduler.stop()
dagScheduler = null
taskScheduler = null
- // TODO: Broadcast.stop(), Cache.stop()?
+ // TODO: Cache.stop()?
env.stop()
SparkEnv.set(null)
ShuffleMapTask.clearCache()
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index 694db6b2a3..add8fcec51 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -2,6 +2,7 @@ package spark
import akka.actor.ActorSystem
+import spark.broadcast.BroadcastManager
import spark.storage.BlockManager
import spark.storage.BlockManagerMaster
import spark.network.ConnectionManager
@@ -16,13 +17,14 @@ class SparkEnv (
val mapOutputTracker: MapOutputTracker,
val shuffleFetcher: ShuffleFetcher,
val shuffleManager: ShuffleManager,
+ val broadcastManager: BroadcastManager,
val blockManager: BlockManager,
val connectionManager: ConnectionManager
) {
/** No-parameter constructor for unit tests. */
def this() = {
- this(null, null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null)
+ this(null, null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null)
}
def stop() {
@@ -30,6 +32,7 @@ class SparkEnv (
cacheTracker.stop()
shuffleFetcher.stop()
shuffleManager.stop()
+ broadcastManager.stop()
blockManager.stop()
blockManager.master.stop()
actorSystem.shutdown()
@@ -74,6 +77,8 @@ object SparkEnv {
val shuffleManager = new ShuffleManager()
+ val broadcastManager = new BroadcastManager(isMaster)
+
val closureSerializerClass =
System.getProperty("spark.closure.serializer", "spark.JavaSerializer")
val closureSerializer =
@@ -119,6 +124,7 @@ object SparkEnv {
mapOutputTracker,
shuffleFetcher,
shuffleManager,
+ broadcastManager,
blockManager,
connectionManager)
}
diff --git a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
index e009d4e7db..473d080044 100644
--- a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
@@ -2,21 +2,22 @@ package spark.broadcast
import java.io._
import java.net._
-import java.util.{BitSet, Comparator, Random, Timer, TimerTask, UUID}
+import java.util.{BitSet, Comparator, Timer, TimerTask, UUID}
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.{ListBuffer, Map, Set}
import scala.math
import spark._
+import spark.storage.StorageLevel
class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean)
extends Broadcast[T] with Logging with Serializable {
def value = value_
- BitTorrentBroadcast.synchronized {
- BitTorrentBroadcast.values.put(uuid, 0, value_)
+ MultiTracker.synchronized {
+ SparkEnv.get.blockManager.putSingle(uuid.toString, value_, StorageLevel.MEMORY_ONLY, false)
}
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@@ -25,8 +26,6 @@ extends Broadcast[T] with Logging with Serializable {
@transient var totalBytes = -1
@transient var totalBlocks = -1
@transient var hasBlocks = new AtomicInteger(0)
- // CHANGED: BlockSize in the Broadcast object is expected to change over time
- @transient var blockSize = Broadcast.BlockSize
// Used ONLY by Master to track how many unique blocks have been sent out
@transient var sentBlocks = new AtomicInteger(0)
@@ -45,14 +44,10 @@ extends Broadcast[T] with Logging with Serializable {
// Used only in Workers
@transient var ttGuide: TalkToGuide = null
- @transient var rxSpeeds = new SpeedTracker
- @transient var txSpeeds = new SpeedTracker
-
@transient var hostAddress = Utils.localIpAddress
@transient var listenPort = -1
@transient var guidePort = -1
- @transient var hasCopyInHDFS = false
@transient var stopBroadcast = false
// Must call this after all the variables have been created/initialized
@@ -63,19 +58,10 @@ extends Broadcast[T] with Logging with Serializable {
def sendBroadcast() {
logInfo("Local host address: " + hostAddress)
- // Store a persistent copy in HDFS
- // TODO: Turned OFF for now. Related to persistence
- // val out = new ObjectOutputStream(BroadcastCH.openFileForWriting(uuid))
- // out.writeObject(value_)
- // out.close()
- // FIXME: Fix this at some point
- hasCopyInHDFS = true
-
// Create a variableInfo object and store it in valueInfos
- var variableInfo = Broadcast.blockifyObject(value_)
+ var variableInfo = MultiTracker.blockifyObject(value_)
// Prepare the value being broadcasted
- // TODO: Refactoring and clean-up required here
arrayOfBlocks = variableInfo.arrayOfBlocks
totalBytes = variableInfo.totalBytes
totalBlocks = variableInfo.totalBlocks
@@ -95,9 +81,7 @@ extends Broadcast[T] with Logging with Serializable {
// Must always come AFTER guideMR is created
while (guidePort == -1) {
- guidePortLock.synchronized {
- guidePortLock.wait()
- }
+ guidePortLock.synchronized { guidePortLock.wait() }
}
serveMR = new ServeMultipleRequests
@@ -107,14 +91,12 @@ extends Broadcast[T] with Logging with Serializable {
// Must always come AFTER serveMR is created
while (listenPort == -1) {
- listenPortLock.synchronized {
- listenPortLock.wait()
- }
+ listenPortLock.synchronized { listenPortLock.wait() }
}
// Must always come AFTER listenPort is created
val masterSource =
- SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes, blockSize)
+ SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes)
hasBlocksBitVector.synchronized {
masterSource.hasBlocksBitVector = hasBlocksBitVector
}
@@ -123,46 +105,42 @@ extends Broadcast[T] with Logging with Serializable {
listOfSources += masterSource
// Register with the Tracker
- registerBroadcast(uuid,
- SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes, blockSize))
+ MultiTracker.registerBroadcast(uuid,
+ SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
}
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
- BitTorrentBroadcast.synchronized {
- val cachedVal = BitTorrentBroadcast.values.get(uuid, 0)
-
- if (cachedVal != null) {
- value_ = cachedVal.asInstanceOf[T]
- } else {
- // Only the first worker in a node can ever be inside this 'else'
- initializeWorkerVariables
-
- logInfo("Local host address: " + hostAddress)
-
- // Start local ServeMultipleRequests thread first
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon(true)
- serveMR.start()
- logInfo("ServeMultipleRequests started...")
-
- val start = System.nanoTime
-
- val receptionSucceeded = receiveBroadcast(uuid)
- // If does not succeed, then get from HDFS copy
- if (receptionSucceeded) {
- value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
- BitTorrentBroadcast.values.put(uuid, 0, value_)
- } else {
- // TODO: This part won't work, cause HDFS writing is turned OFF
- val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
- value_ = fileIn.readObject.asInstanceOf[T]
- BitTorrentBroadcast.values.put(uuid, 0, value_)
- fileIn.close()
- }
+ MultiTracker.synchronized {
+ SparkEnv.get.blockManager.getSingle(uuid.toString) match {
+ case Some(x) => x.asInstanceOf[T]
+ case None => {
+ logInfo("Started reading broadcast variable " + uuid)
+ // Initializing everything because Master will only send null/0 values
+ // Only the 1st worker in a node can be here. Others will get from cache
+ initializeWorkerVariables
+
+ logInfo("Local host address: " + hostAddress)
+
+ // Start local ServeMultipleRequests thread first
+ serveMR = new ServeMultipleRequests
+ serveMR.setDaemon(true)
+ serveMR.start()
+ logInfo("ServeMultipleRequests started...")
+
+ val start = System.nanoTime
+
+ val receptionSucceeded = receiveBroadcast(uuid)
+ if (receptionSucceeded) {
+ value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
+ SparkEnv.get.blockManager.putSingle(uuid.toString, value_, StorageLevel.MEMORY_ONLY, false)
+ } else {
+ logError("Reading Broadcasted variable " + uuid + " failed")
+ }
- val time = (System.nanoTime - start) / 1e9
- logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
+ val time = (System.nanoTime - start) / 1e9
+ logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
+ }
}
}
}
@@ -175,7 +153,6 @@ extends Broadcast[T] with Logging with Serializable {
totalBytes = -1
totalBlocks = -1
hasBlocks = new AtomicInteger(0)
- blockSize = -1
listenPortLock = new Object
totalBlocksLock = new Object
@@ -183,9 +160,6 @@ extends Broadcast[T] with Logging with Serializable {
serveMR = null
ttGuide = null
- rxSpeeds = new SpeedTracker
- txSpeeds = new SpeedTracker
-
hostAddress = Utils.localIpAddress
listenPort = -1
@@ -194,75 +168,19 @@ extends Broadcast[T] with Logging with Serializable {
stopBroadcast = false
}
- private def registerBroadcast(uuid: UUID, gInfo: SourceInfo) {
- val socket = new Socket(Broadcast.MasterHostAddress,
- Broadcast.MasterTrackerPort)
- val oosST = new ObjectOutputStream(socket.getOutputStream)
- oosST.flush()
- val oisST = new ObjectInputStream(socket.getInputStream)
-
- // Send messageType/intention
- oosST.writeObject(Broadcast.REGISTER_BROADCAST_TRACKER)
- oosST.flush()
-
- // Send UUID of this broadcast
- oosST.writeObject(uuid)
- oosST.flush()
-
- // Send this tracker's information
- oosST.writeObject(gInfo)
- oosST.flush()
-
- // Receive ACK and throw it away
- oisST.readObject.asInstanceOf[Int]
-
- // Shut stuff down
- oisST.close()
- oosST.close()
- socket.close()
- }
-
- private def unregisterBroadcast(uuid: UUID) {
- val socket = new Socket(Broadcast.MasterHostAddress,
- Broadcast.MasterTrackerPort)
- val oosST = new ObjectOutputStream(socket.getOutputStream)
- oosST.flush()
- val oisST = new ObjectInputStream(socket.getInputStream)
-
- // Send messageType/intention
- oosST.writeObject(Broadcast.UNREGISTER_BROADCAST_TRACKER)
- oosST.flush()
-
- // Send UUID of this broadcast
- oosST.writeObject(uuid)
- oosST.flush()
-
- // Receive ACK and throw it away
- oisST.readObject.asInstanceOf[Int]
-
- // Shut stuff down
- oisST.close()
- oosST.close()
- socket.close()
- }
-
private def getLocalSourceInfo: SourceInfo = {
// Wait till hostName and listenPort are OK
while (listenPort == -1) {
- listenPortLock.synchronized {
- listenPortLock.wait()
- }
+ listenPortLock.synchronized { listenPortLock.wait() }
}
// Wait till totalBlocks and totalBytes are OK
while (totalBlocks == -1) {
- totalBlocksLock.synchronized {
- totalBlocksLock.wait()
- }
+ totalBlocksLock.synchronized { totalBlocksLock.wait() }
}
var localSourceInfo = SourceInfo(
- hostAddress, listenPort, totalBlocks, totalBytes, blockSize)
+ hostAddress, listenPort, totalBlocks, totalBytes)
localSourceInfo.hasBlocks = hasBlocks.get
@@ -274,7 +192,7 @@ extends Broadcast[T] with Logging with Serializable {
}
// Add new SourceInfo to the listOfSources. Update if it exists already.
- // TODO: Optimizing just by OR-ing the BitVectors was BAD for performance
+ // Optimizing just by OR-ing the BitVectors was BAD for performance
private def addToListOfSources(newSourceInfo: SourceInfo) {
listOfSources.synchronized {
if (listOfSources.contains(newSourceInfo)) {
@@ -297,9 +215,9 @@ extends Broadcast[T] with Logging with Serializable {
// Keep exchaning information until all blocks have been received
while (hasBlocks.get < totalBlocks) {
talkOnce
- Thread.sleep(BitTorrentBroadcast.ranGen.nextInt(
- Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) +
- Broadcast.MinKnockInterval)
+ Thread.sleep(MultiTracker.ranGen.nextInt(
+ MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) +
+ MultiTracker.MinKnockInterval)
}
// Talk one more time to let the Guide know of reception completion
@@ -324,7 +242,7 @@ extends Broadcast[T] with Logging with Serializable {
// Receive source information from Guide
var suitableSources =
oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]]
- logInfo("Received suitableSources from Master " + suitableSources)
+ logDebug("Received suitableSources from Master " + suitableSources)
addToListOfSources(suitableSources)
@@ -334,76 +252,17 @@ extends Broadcast[T] with Logging with Serializable {
}
}
- def getGuideInfo(variableUUID: UUID): SourceInfo = {
- var clientSocketToTracker: Socket = null
- var oosTracker: ObjectOutputStream = null
- var oisTracker: ObjectInputStream = null
-
- var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxOverGoToHDFS)
-
- var retriesLeft = Broadcast.MaxRetryCount
- do {
- try {
- // Connect to the tracker to find out GuideInfo
- clientSocketToTracker =
- new Socket(Broadcast.MasterHostAddress, Broadcast.MasterTrackerPort)
- oosTracker =
- new ObjectOutputStream(clientSocketToTracker.getOutputStream)
- oosTracker.flush()
- oisTracker =
- new ObjectInputStream(clientSocketToTracker.getInputStream)
-
- // Send messageType/intention
- oosTracker.writeObject(Broadcast.FIND_BROADCAST_TRACKER)
- oosTracker.flush()
-
- // Send UUID and receive GuideInfo
- oosTracker.writeObject(uuid)
- oosTracker.flush()
- gInfo = oisTracker.readObject.asInstanceOf[SourceInfo]
- } catch {
- case e: Exception => {
- logInfo("getGuideInfo had a " + e)
- }
- } finally {
- if (oisTracker != null) {
- oisTracker.close()
- }
- if (oosTracker != null) {
- oosTracker.close()
- }
- if (clientSocketToTracker != null) {
- clientSocketToTracker.close()
- }
- }
-
- Thread.sleep(BitTorrentBroadcast.ranGen.nextInt(
- Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) +
- Broadcast.MinKnockInterval)
-
- retriesLeft -= 1
- } while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry)
-
- logInfo("Got this guidePort from Tracker: " + gInfo.listenPort)
- return gInfo
- }
-
def receiveBroadcast(variableUUID: UUID): Boolean = {
- val gInfo = getGuideInfo(variableUUID)
+ val gInfo = MultiTracker.getGuideInfo(variableUUID)
- if (gInfo.listenPort == SourceInfo.TxOverGoToHDFS ||
- gInfo.listenPort == SourceInfo.TxNotStartedRetry) {
- // TODO: SourceInfo.TxNotStartedRetry is not really in use because we go
- // to HDFS anyway when receiveBroadcast returns false
+ if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
return false
}
// Wait until hostAddress and listenPort are created by the
// ServeMultipleRequests thread
while (listenPort == -1) {
- listenPortLock.synchronized {
- listenPortLock.wait()
- }
+ listenPortLock.synchronized { listenPortLock.wait() }
}
// Setup initial states of variables
@@ -411,11 +270,8 @@ extends Broadcast[T] with Logging with Serializable {
arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
hasBlocksBitVector = new BitSet(totalBlocks)
numCopiesSent = new Array[Int](totalBlocks)
- totalBlocksLock.synchronized {
- totalBlocksLock.notifyAll()
- }
+ totalBlocksLock.synchronized { totalBlocksLock.notifyAll() }
totalBytes = gInfo.totalBytes
- blockSize = gInfo.blockSize
// Start ttGuide to periodically talk to the Guide
var ttGuide = new TalkToGuide(gInfo)
@@ -432,7 +288,7 @@ extends Broadcast[T] with Logging with Serializable {
// FIXME: Must fix this. This might never break if broadcast fails.
// We should be able to break and send false. Also need to kill threads
while (hasBlocks.get < totalBlocks) {
- Thread.sleep(Broadcast.MaxKnockInterval)
+ Thread.sleep(MultiTracker.MaxKnockInterval)
}
return true
@@ -446,36 +302,34 @@ extends Broadcast[T] with Logging with Serializable {
private var blocksInRequestBitVector = new BitSet(totalBlocks)
override def run() {
- var threadPool = Utils.newDaemonFixedThreadPool(Broadcast.MaxRxSlots)
+ var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots)
while (hasBlocks.get < totalBlocks) {
var numThreadsToCreate =
- math.min(listOfSources.size, Broadcast.MaxRxSlots) -
+ math.min(listOfSources.size, MultiTracker.MaxChatSlots) -
threadPool.getActiveCount
while (hasBlocks.get < totalBlocks && numThreadsToCreate > 0) {
var peerToTalkTo = pickPeerToTalkToRandom
if (peerToTalkTo != null)
- logInfo("Peer chosen: " + peerToTalkTo + " with " + peerToTalkTo.hasBlocksBitVector)
+ logDebug("Peer chosen: " + peerToTalkTo + " with " + peerToTalkTo.hasBlocksBitVector)
else
- logInfo("No peer chosen...")
+ logDebug("No peer chosen...")
if (peerToTalkTo != null) {
threadPool.execute(new TalkToPeer(peerToTalkTo))
// Add to peersNowTalking. Remove in the thread. We have to do this
// ASAP, otherwise pickPeerToTalkTo picks the same peer more than once
- peersNowTalking.synchronized {
- peersNowTalking += peerToTalkTo
- }
+ peersNowTalking.synchronized { peersNowTalking += peerToTalkTo }
}
numThreadsToCreate = numThreadsToCreate - 1
}
// Sleep for a while before starting some more threads
- Thread.sleep(Broadcast.MinKnockInterval)
+ Thread.sleep(MultiTracker.MinKnockInterval)
}
// Shutdown the thread pool
threadPool.shutdown()
@@ -487,7 +341,7 @@ extends Broadcast[T] with Logging with Serializable {
var curPeer: SourceInfo = null
var curMax = 0
- logInfo("Picking peers to talk to...")
+ logDebug("Picking peers to talk to...")
// Find peers that are not connected right now
var peersNotInUse = ListBuffer[SourceInfo]()
@@ -512,11 +366,10 @@ extends Broadcast[T] with Logging with Serializable {
}
}
- // TODO: Always pick randomly or randomly pick randomly?
- // Now always picking randomly
+ // Always picking randomly
if (curPeer == null && peersNotInUse.size > 0) {
// Pick uniformly the i'th required peer
- var i = BitTorrentBroadcast.ranGen.nextInt(peersNotInUse.size)
+ var i = MultiTracker.ranGen.nextInt(peersNotInUse.size)
var peerIter = peersNotInUse.iterator
curPeer = peerIter.next
@@ -552,8 +405,8 @@ extends Broadcast[T] with Logging with Serializable {
}
}
- // TODO: A block is rare if there are at most 2 copies of that block
- // TODO: This CONSTANT could be a function of the neighborhood size
+ // A block is considered rare if there are at most 2 copies of that block
+ // This CONSTANT could be a function of the neighborhood size
var rareBlocksIndices = ListBuffer[Int]()
for (i <- 0 until totalBlocks) {
if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) <= 2) {
@@ -587,7 +440,7 @@ extends Broadcast[T] with Logging with Serializable {
// Sort the peers based on how many rare blocks they have
peersWithRareBlocks.sortBy(_._2)
- var randomNumber = BitTorrentBroadcast.ranGen.nextDouble
+ var randomNumber = MultiTracker.ranGen.nextDouble
var tempSum = 0.0
var i = 0
@@ -625,7 +478,7 @@ extends Broadcast[T] with Logging with Serializable {
}
var timeOutTimer = new Timer
- timeOutTimer.schedule(timeOutTask, Broadcast.MaxKnockInterval)
+ timeOutTimer.schedule(timeOutTask, MultiTracker.MaxKnockInterval)
logInfo("TalkToPeer started... => " + peerToTalkTo)
@@ -677,7 +530,7 @@ extends Broadcast[T] with Logging with Serializable {
val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
val receptionTime = (System.currentTimeMillis - recvStartTime)
- logInfo("Received block: " + bcBlock.blockID + " from " + peerToTalkTo + " in " + receptionTime + " millis.")
+ logDebug("Received block: " + bcBlock.blockID + " from " + peerToTalkTo + " in " + receptionTime + " millis.")
if (!hasBlocksBitVector.get(bcBlock.blockID)) {
arrayOfBlocks(bcBlock.blockID) = bcBlock
@@ -688,8 +541,6 @@ extends Broadcast[T] with Logging with Serializable {
hasBlocks.getAndIncrement
}
- rxSpeeds.addDataPoint(peerToTalkTo, receptionTime)
-
// Some block(may NOT be blockToAskFor) has arrived.
// In any case, blockToAskFor is not in request any more
blocksInRequestBitVector.synchronized {
@@ -710,7 +561,7 @@ extends Broadcast[T] with Logging with Serializable {
// connection due to timeout
case eofe: java.io.EOFException => { }
case e: Exception => {
- logInfo("TalktoPeer had a " + e)
+ logError("TalktoPeer had a " + e)
// FIXME: Remove 'newPeerToTalkTo' from listOfSources
// We probably should have the following in some form, but not
// really here. This exception can happen if the sender just breaks connection
@@ -741,8 +592,8 @@ extends Broadcast[T] with Logging with Serializable {
}
// Include blocks already in transmission ONLY IF
- // BitTorrentBroadcast.EndGameFraction has NOT been achieved
- if ((1.0 * hasBlocks.get / totalBlocks) < Broadcast.EndGameFraction) {
+ // MultiTracker.EndGameFraction has NOT been achieved
+ if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) {
blocksInRequestBitVector.synchronized {
needBlocksBitVector.or(blocksInRequestBitVector)
}
@@ -758,7 +609,7 @@ extends Broadcast[T] with Logging with Serializable {
return -1
} else {
// Pick uniformly the i'th required block
- var i = BitTorrentBroadcast.ranGen.nextInt(needBlocksBitVector.cardinality)
+ var i = MultiTracker.ranGen.nextInt(needBlocksBitVector.cardinality)
var pickedBlockIndex = needBlocksBitVector.nextSetBit(0)
while (i > 0) {
@@ -781,8 +632,8 @@ extends Broadcast[T] with Logging with Serializable {
}
// Include blocks already in transmission ONLY IF
- // BitTorrentBroadcast.EndGameFraction has NOT been achieved
- if ((1.0 * hasBlocks.get / totalBlocks) < Broadcast.EndGameFraction) {
+ // MultiTracker.EndGameFraction has NOT been achieved
+ if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) {
blocksInRequestBitVector.synchronized {
needBlocksBitVector.or(blocksInRequestBitVector)
}
@@ -830,7 +681,7 @@ extends Broadcast[T] with Logging with Serializable {
return -1
} else {
// Pick uniformly the i'th index
- var i = BitTorrentBroadcast.ranGen.nextInt(minBlocksIndices.size)
+ var i = MultiTracker.ranGen.nextInt(minBlocksIndices.size)
return minBlocksIndices(i)
}
}
@@ -848,9 +699,7 @@ extends Broadcast[T] with Logging with Serializable {
}
// Delete from peersNowTalking
- peersNowTalking.synchronized {
- peersNowTalking = peersNowTalking - peerToTalkTo
- }
+ peersNowTalking.synchronized { peersNowTalking -= peerToTalkTo }
}
}
}
@@ -868,20 +717,18 @@ extends Broadcast[T] with Logging with Serializable {
guidePort = serverSocket.getLocalPort
logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
- guidePortLock.synchronized {
- guidePortLock.notifyAll()
- }
+ guidePortLock.synchronized { guidePortLock.notifyAll() }
try {
// Don't stop until there is a copy in HDFS
- while (!stopBroadcast || !hasCopyInHDFS) {
+ while (!stopBroadcast) {
var clientSocket: Socket = null
try {
- serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
+ serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
clientSocket = serverSocket.accept()
} catch {
case e: Exception => {
- logInfo("GuideMultipleRequests Timeout.")
+ logError("GuideMultipleRequests Timeout.")
// Stop broadcast if at least one worker has connected and
// everyone connected so far are done. Comparing with
@@ -893,7 +740,7 @@ extends Broadcast[T] with Logging with Serializable {
}
}
if (clientSocket != null) {
- logInfo("Guide: Accepted new client connection:" + clientSocket)
+ logDebug("Guide: Accepted new client connection:" + clientSocket)
try {
threadPool.execute(new GuideSingleRequest(clientSocket))
} catch {
@@ -911,7 +758,7 @@ extends Broadcast[T] with Logging with Serializable {
logInfo("Sending stopBroadcast notifications...")
sendStopBroadcastNotifications
- unregisterBroadcast(uuid)
+ MultiTracker.unregisterBroadcast(uuid)
} finally {
if (serverSocket != null) {
logInfo("GuideMultipleRequests now stopping...")
@@ -930,13 +777,10 @@ extends Broadcast[T] with Logging with Serializable {
try {
// Connect to the source
- guideSocketToSource =
- new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
- gosSource =
- new ObjectOutputStream(guideSocketToSource.getOutputStream)
+ guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
+ gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream)
gosSource.flush()
- gisSource =
- new ObjectInputStream(guideSocketToSource.getInputStream)
+ gisSource = new ObjectInputStream(guideSocketToSource.getInputStream)
// Throw away whatever comes in
gisSource.readObject.asInstanceOf[SourceInfo]
@@ -946,7 +790,7 @@ extends Broadcast[T] with Logging with Serializable {
gosSource.flush()
} catch {
case e: Exception => {
- logInfo("sendStopBroadcastNotifications had a " + e)
+ logError("sendStopBroadcastNotifications had a " + e)
}
} finally {
if (gisSource != null) {
@@ -980,7 +824,7 @@ extends Broadcast[T] with Logging with Serializable {
// Select a suitable source and send it back to the worker
selectedSources = selectSuitableSources(sourceInfo)
- logInfo("Sending selectedSources:" + selectedSources)
+ logDebug("Sending selectedSources:" + selectedSources)
oos.writeObject(selectedSources)
oos.flush()
@@ -990,12 +834,11 @@ extends Broadcast[T] with Logging with Serializable {
case e: Exception => {
// Assuming exception caused by receiver failure: remove
if (listOfSources != null) {
- listOfSources.synchronized {
- listOfSources = listOfSources - sourceInfo
- }
+ listOfSources.synchronized { listOfSources -= sourceInfo }
}
}
} finally {
+ logInfo("GuideSingleRequest is closing streams and sockets")
ois.close()
oos.close()
clientSocket.close()
@@ -1009,24 +852,22 @@ extends Broadcast[T] with Logging with Serializable {
// If skipSourceInfo.hasBlocksBitVector has all bits set to 'true'
// then add skipSourceInfo to setOfCompletedSources. Return blank.
if (skipSourceInfo.hasBlocks == totalBlocks) {
- setOfCompletedSources.synchronized {
- setOfCompletedSources += skipSourceInfo
- }
+ setOfCompletedSources.synchronized { setOfCompletedSources += skipSourceInfo }
return selectedSources
}
listOfSources.synchronized {
- if (listOfSources.size <= Broadcast.MaxPeersInGuideResponse) {
+ if (listOfSources.size <= MultiTracker.MaxPeersInGuideResponse) {
selectedSources = listOfSources.clone
} else {
- var picksLeft = Broadcast.MaxPeersInGuideResponse
+ var picksLeft = MultiTracker.MaxPeersInGuideResponse
var alreadyPicked = new BitSet(listOfSources.size)
while (picksLeft > 0) {
var i = -1
do {
- i = BitTorrentBroadcast.ranGen.nextInt(listOfSources.size)
+ i = MultiTracker.ranGen.nextInt(listOfSources.size)
} while (alreadyPicked.get(i))
var peerIter = listOfSources.iterator
@@ -1057,8 +898,8 @@ extends Broadcast[T] with Logging with Serializable {
class ServeMultipleRequests
extends Thread with Logging {
- // Server at most Broadcast.MaxTxSlots peers
- var threadPool = Utils.newDaemonFixedThreadPool(Broadcast.MaxTxSlots)
+ // Server at most MultiTracker.MaxChatSlots peers
+ var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots)
override def run() {
var serverSocket = new ServerSocket(0)
@@ -1066,30 +907,26 @@ extends Broadcast[T] with Logging with Serializable {
logInfo("ServeMultipleRequests started with " + serverSocket)
- listenPortLock.synchronized {
- listenPortLock.notifyAll()
- }
+ listenPortLock.synchronized { listenPortLock.notifyAll() }
try {
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
- serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
+ serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
clientSocket = serverSocket.accept()
} catch {
case e: Exception => {
- logInfo("ServeMultipleRequests Timeout.")
+ logError("ServeMultipleRequests Timeout.")
}
}
if (clientSocket != null) {
- logInfo("Serve: Accepted new client connection:" + clientSocket)
+ logDebug("Serve: Accepted new client connection:" + clientSocket)
try {
threadPool.execute(new ServeSingleRequest(clientSocket))
} catch {
// In failure, close socket here; else, the thread will close it
- case ioe: IOException => {
- clientSocket.close()
- }
+ case ioe: IOException => clientSocket.close()
}
}
}
@@ -1125,14 +962,13 @@ extends Broadcast[T] with Logging with Serializable {
if (rxSourceInfo.listenPort == SourceInfo.StopBroadcast) {
stopBroadcast = true
} else {
- // Carry on
addToListOfSources(rxSourceInfo)
}
val startTime = System.currentTimeMillis
var curTime = startTime
var keepSending = true
- var numBlocksToSend = Broadcast.MaxChatBlocks
+ var numBlocksToSend = MultiTracker.MaxChatBlocks
while (!stopBroadcast && keepSending && numBlocksToSend > 0) {
// Receive which block to send
@@ -1140,7 +976,7 @@ extends Broadcast[T] with Logging with Serializable {
// If it is master AND at least one copy of each block has not been
// sent out already, MODIFY blockToSend
- if (BitTorrentBroadcast.isMaster && sentBlocks.get < totalBlocks) {
+ if (MultiTracker.isMaster && sentBlocks.get < totalBlocks) {
blockToSend = sentBlocks.getAndIncrement
}
@@ -1152,27 +988,21 @@ extends Broadcast[T] with Logging with Serializable {
// Receive latest SourceInfo from the receiver
rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo]
- // logInfo("rxSourceInfo: " + rxSourceInfo + " with " + rxSourceInfo.hasBlocksBitVector)
+ logDebug("rxSourceInfo: " + rxSourceInfo + " with " + rxSourceInfo.hasBlocksBitVector)
addToListOfSources(rxSourceInfo)
curTime = System.currentTimeMillis
// Revoke sending only if there is anyone waiting in the queue
- if (curTime - startTime >= Broadcast.MaxChatTime &&
+ if (curTime - startTime >= MultiTracker.MaxChatTime &&
threadPool.getQueue.size > 0) {
keepSending = false
}
}
} catch {
- // If something went wrong, e.g., the worker at the other end died etc.
- // then close everything up
- // Exception can happen if the receiver stops receiving
- case e: Exception => {
- logInfo("ServeSingleRequest had a " + e)
- }
+ case e: Exception => logError("ServeSingleRequest had a " + e)
} finally {
logInfo("ServeSingleRequest is closing streams and sockets")
ois.close()
- // TODO: The following line causes a "java.net.SocketException: Socket closed"
oos.close()
clientSocket.close()
}
@@ -1183,11 +1013,9 @@ extends Broadcast[T] with Logging with Serializable {
oos.writeObject(arrayOfBlocks(blockToSend))
oos.flush()
} catch {
- case e: Exception => {
- logInfo("sendBlock had a " + e)
- }
+ case e: Exception => logError("sendBlock had a " + e)
}
- logInfo("Sent block: " + blockToSend + " to " + clientSocket)
+ logDebug("Sent block: " + blockToSend + " to " + clientSocket)
}
}
}
@@ -1195,161 +1023,7 @@ extends Broadcast[T] with Logging with Serializable {
class BitTorrentBroadcastFactory
extends BroadcastFactory {
- def initialize(isMaster: Boolean) {
- BitTorrentBroadcast.initialize(isMaster)
- }
-
- def newBroadcast[T](value_ : T, isLocal: Boolean) = {
- new BitTorrentBroadcast[T](value_, isLocal)
- }
-}
-
-private object BitTorrentBroadcast
-extends Logging {
- val values = SparkEnv.get.cache.newKeySpace()
-
- var valueToGuideMap = Map[UUID, SourceInfo]()
-
- // Random number generator
- var ranGen = new Random
-
- private var initialized = false
- private var isMaster_ = false
-
- private var trackMV: TrackMultipleValues = null
-
- def initialize(isMaster__ : Boolean) {
- synchronized {
- if (!initialized) {
-
- isMaster_ = isMaster__
-
- if (isMaster) {
- trackMV = new TrackMultipleValues
- trackMV.setDaemon(true)
- trackMV.start()
- // TODO: Logging the following line makes the Spark framework ID not
- // getting logged, cause it calls logInfo before log4j is initialized
- logInfo("TrackMultipleValues started...")
- }
-
- // Initialize DfsBroadcast to be used for broadcast variable persistence
- // TODO: Think about persistence
- DfsBroadcast.initialize
-
- initialized = true
- }
- }
- }
-
- def isMaster = isMaster_
-
- class TrackMultipleValues
- extends Thread with Logging {
- override def run() {
- var threadPool = Utils.newDaemonCachedThreadPool()
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket(Broadcast.MasterTrackerPort)
- logInfo("TrackMultipleValues" + serverSocket)
-
- try {
- while (true) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(Broadcast.TrackerSocketTimeout)
- clientSocket = serverSocket.accept()
- } catch {
- case e: Exception => {
- logInfo("TrackMultipleValues Timeout. Stopping listening...")
- }
- }
-
- if (clientSocket != null) {
- try {
- threadPool.execute(new Thread {
- override def run() {
- val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- try {
- // First, read message type
- val messageType = ois.readObject.asInstanceOf[Int]
-
- if (messageType == Broadcast.REGISTER_BROADCAST_TRACKER) {
- // Receive UUID
- val uuid = ois.readObject.asInstanceOf[UUID]
- // Receive hostAddress and listenPort
- val gInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- // Add to the map
- valueToGuideMap.synchronized {
- valueToGuideMap += (uuid -> gInfo)
- }
-
- logInfo ("New broadcast registered with TrackMultipleValues " + uuid + " " + valueToGuideMap)
-
- // Send dummy ACK
- oos.writeObject(-1)
- oos.flush()
- } else if (messageType == Broadcast.UNREGISTER_BROADCAST_TRACKER) {
- // Receive UUID
- val uuid = ois.readObject.asInstanceOf[UUID]
-
- // Remove from the map
- valueToGuideMap.synchronized {
- valueToGuideMap(uuid) = SourceInfo("", SourceInfo.TxOverGoToHDFS)
- logInfo("Value unregistered from the Tracker " + valueToGuideMap)
- }
-
- logInfo ("Broadcast unregistered from TrackMultipleValues " + uuid + " " + valueToGuideMap)
-
- // Send dummy ACK
- oos.writeObject(-1)
- oos.flush()
- } else if (messageType == Broadcast.FIND_BROADCAST_TRACKER) {
- // Receive UUID
- val uuid = ois.readObject.asInstanceOf[UUID]
-
- var gInfo =
- if (valueToGuideMap.contains(uuid)) valueToGuideMap(uuid)
- else SourceInfo("", SourceInfo.TxNotStartedRetry)
-
- logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + gInfo.listenPort)
-
- // Send reply back
- oos.writeObject(gInfo)
- oos.flush()
- } else if (messageType == Broadcast.GET_UPDATED_SHARE) {
- // TODO: Not implemented
- } else {
- throw new SparkException("Undefined messageType at TrackMultipleValues")
- }
- } catch {
- case e: Exception => {
- logInfo("TrackMultipleValues had a " + e)
- }
- } finally {
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
- })
- } catch {
- // In failure, close socket here; else, client thread will close
- case ioe: IOException => {
- clientSocket.close()
- }
- }
- }
- }
- } finally {
- serverSocket.close()
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
- }
+ def initialize(isMaster: Boolean) = MultiTracker.initialize(isMaster)
+ def newBroadcast[T](value_ : T, isLocal: Boolean) = new BitTorrentBroadcast[T](value_, isLocal)
+ def stop() = MultiTracker.stop
}
diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala
index eaa9153279..d68e56a114 100644
--- a/core/src/main/scala/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/spark/broadcast/Broadcast.scala
@@ -5,6 +5,8 @@ import java.net._
import java.util.{BitSet, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
+import scala.collection.mutable.Map
+
import spark._
trait Broadcast[T] extends Serializable {
@@ -13,24 +15,20 @@ trait Broadcast[T] extends Serializable {
def value: T
// We cannot have an abstract readObject here due to some weird issues with
- // readObject having to be 'private' in sub-classes. Possibly a Scala bug!
+ // readObject having to be 'private' in sub-classes.
override def toString = "spark.Broadcast(" + uuid + ")"
}
-object Broadcast extends Logging with Serializable {
- // Messages
- val REGISTER_BROADCAST_TRACKER = 0
- val UNREGISTER_BROADCAST_TRACKER = 1
- val FIND_BROADCAST_TRACKER = 2
- val GET_UPDATED_SHARE = 3
+class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializable {
private var initialized = false
- private var isMaster_ = false
private var broadcastFactory: BroadcastFactory = null
+ initialize()
+
// Called by SparkContext or Executor before using Broadcast
- def initialize (isMaster__ : Boolean) {
+ private def initialize() {
synchronized {
if (!initialized) {
val broadcastFactoryClass = System.getProperty(
@@ -39,14 +37,6 @@ object Broadcast extends Logging with Serializable {
broadcastFactory =
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
- // Setup isMaster before using it
- isMaster_ = isMaster__
-
- // Set masterHostAddress to the master's IP address for the slaves to read
- if (isMaster) {
- System.setProperty("spark.broadcast.masterHostAddress", Utils.localIpAddress)
- }
-
// Initialize appropriate BroadcastFactory and BroadcastObject
broadcastFactory.initialize(isMaster)
@@ -55,170 +45,18 @@ object Broadcast extends Logging with Serializable {
}
}
- def getBroadcastFactory: BroadcastFactory = {
+ def stop() {
+ broadcastFactory.stop()
+ }
+
+ private def getBroadcastFactory: BroadcastFactory = {
if (broadcastFactory == null) {
throw new SparkException ("Broadcast.getBroadcastFactory called before initialize")
}
broadcastFactory
}
- // Load common broadcast-related config parameters
- private var MasterHostAddress_ = System.getProperty(
- "spark.broadcast.masterHostAddress", "")
- private var MasterTrackerPort_ = System.getProperty(
- "spark.broadcast.masterTrackerPort", "11111").toInt
- private var BlockSize_ = System.getProperty(
- "spark.broadcast.blockSize", "4096").toInt * 1024
- private var MaxRetryCount_ = System.getProperty(
- "spark.broadcast.maxRetryCount", "2").toInt
-
- private var TrackerSocketTimeout_ = System.getProperty(
- "spark.broadcast.trackerSocketTimeout", "50000").toInt
- private var ServerSocketTimeout_ = System.getProperty(
- "spark.broadcast.serverSocketTimeout", "10000").toInt
-
- private var MinKnockInterval_ = System.getProperty(
- "spark.broadcast.minKnockInterval", "500").toInt
- private var MaxKnockInterval_ = System.getProperty(
- "spark.broadcast.maxKnockInterval", "999").toInt
-
- // Load ChainedBroadcast config params
-
- // Load TreeBroadcast config params
- private var MaxDegree_ = System.getProperty("spark.broadcast.maxDegree", "2").toInt
-
- // Load BitTorrentBroadcast config params
- private var MaxPeersInGuideResponse_ = System.getProperty(
- "spark.broadcast.maxPeersInGuideResponse", "4").toInt
-
- private var MaxRxSlots_ = System.getProperty("spark.broadcast.maxRxSlots", "4").toInt
- private var MaxTxSlots_ = System.getProperty("spark.broadcast.maxTxSlots", "4").toInt
-
- private var MaxChatTime_ = System.getProperty("spark.broadcast.maxChatTime", "500").toInt
- private var MaxChatBlocks_ = System.getProperty("spark.broadcast.maxChatBlocks", "1024").toInt
-
- private var EndGameFraction_ = System.getProperty(
- "spark.broadcast.endGameFraction", "0.95").toDouble
+ def newBroadcast[T](value_ : T, isLocal: Boolean) = broadcastFactory.newBroadcast[T](value_, isLocal)
def isMaster = isMaster_
-
- // Common config params
- def MasterHostAddress = MasterHostAddress_
- def MasterTrackerPort = MasterTrackerPort_
- def BlockSize = BlockSize_
- def MaxRetryCount = MaxRetryCount_
-
- def TrackerSocketTimeout = TrackerSocketTimeout_
- def ServerSocketTimeout = ServerSocketTimeout_
-
- def MinKnockInterval = MinKnockInterval_
- def MaxKnockInterval = MaxKnockInterval_
-
- // ChainedBroadcast configs
-
- // TreeBroadcast configs
- def MaxDegree = MaxDegree_
-
- // BitTorrentBroadcast configs
- def MaxPeersInGuideResponse = MaxPeersInGuideResponse_
-
- def MaxRxSlots = MaxRxSlots_
- def MaxTxSlots = MaxTxSlots_
-
- def MaxChatTime = MaxChatTime_
- def MaxChatBlocks = MaxChatBlocks_
-
- def EndGameFraction = EndGameFraction_
-
- // Helper functions to convert an object to Array[BroadcastBlock]
- def blockifyObject[IN](obj: IN): VariableInfo = {
- val baos = new ByteArrayOutputStream
- val oos = new ObjectOutputStream(baos)
- oos.writeObject(obj)
- oos.close()
- baos.close()
- val byteArray = baos.toByteArray
- val bais = new ByteArrayInputStream(byteArray)
-
- var blockNum = (byteArray.length / Broadcast.BlockSize)
- if (byteArray.length % Broadcast.BlockSize != 0)
- blockNum += 1
-
- var retVal = new Array[BroadcastBlock](blockNum)
- var blockID = 0
-
- for (i <- 0 until (byteArray.length, Broadcast.BlockSize)) {
- val thisBlockSize = math.min(Broadcast.BlockSize, byteArray.length - i)
- var tempByteArray = new Array[Byte](thisBlockSize)
- val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
-
- retVal(blockID) = new BroadcastBlock(blockID, tempByteArray)
- blockID += 1
- }
- bais.close()
-
- var variableInfo = VariableInfo(retVal, blockNum, byteArray.length)
- variableInfo.hasBlocks = blockNum
-
- return variableInfo
- }
-
- // Helper function to convert Array[BroadcastBlock] to object
- def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock],
- totalBytes: Int,
- totalBlocks: Int): OUT = {
-
- var retByteArray = new Array[Byte](totalBytes)
- for (i <- 0 until totalBlocks) {
- System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
- i * Broadcast.BlockSize, arrayOfBlocks(i).byteArray.length)
- }
- byteArrayToObject(retByteArray)
- }
-
- private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = {
- val in = new ObjectInputStream (new ByteArrayInputStream (bytes)) {
- override def resolveClass(desc: ObjectStreamClass) =
- Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader)
- }
- val retVal = in.readObject.asInstanceOf[OUT]
- in.close()
- return retVal
- }
-}
-
-case class BroadcastBlock (blockID: Int, byteArray: Array[Byte]) extends Serializable
-
-case class VariableInfo (@transient arrayOfBlocks : Array[BroadcastBlock],
- totalBlocks: Int,
- totalBytes: Int)
- extends Serializable {
-
- @transient
- var hasBlocks = 0
-}
-
-class SpeedTracker extends Serializable {
- // Mapping 'source' to '(totalTime, numBlocks)'
- private var sourceToSpeedMap = Map[SourceInfo, (Long, Int)] ()
-
- def addDataPoint (srcInfo: SourceInfo, timeInMillis: Long) {
- sourceToSpeedMap.synchronized {
- if (!sourceToSpeedMap.contains(srcInfo)) {
- sourceToSpeedMap += (srcInfo -> (timeInMillis, 1))
- } else {
- val tTnB = sourceToSpeedMap (srcInfo)
- sourceToSpeedMap += (srcInfo -> (tTnB._1 + timeInMillis, tTnB._2 + 1))
- }
- }
- }
-
- def getTimePerBlock (srcInfo: SourceInfo): Double = {
- sourceToSpeedMap.synchronized {
- val tTnB = sourceToSpeedMap (srcInfo)
- return tTnB._1 / tTnB._2
- }
- }
-
- override def toString = sourceToSpeedMap.toString
}
diff --git a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala
index b18908f789..e341d556bf 100644
--- a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala
+++ b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala
@@ -9,4 +9,5 @@ package spark.broadcast
trait BroadcastFactory {
def initialize(isMaster: Boolean): Unit
def newBroadcast[T](value_ : T, isLocal: Boolean): Broadcast[T]
+ def stop(): Unit
}
diff --git a/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala b/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala
deleted file mode 100644
index 43290c241f..0000000000
--- a/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala
+++ /dev/null
@@ -1,794 +0,0 @@
-package spark.broadcast
-
-import java.io._
-import java.net._
-import java.util.{Comparator, PriorityQueue, Random, UUID}
-
-import scala.collection.mutable.{Map, Set}
-import scala.math
-
-import spark._
-
-class ChainedBroadcast[T](@transient var value_ : T, isLocal: Boolean)
-extends Broadcast[T] with Logging with Serializable {
-
- def value = value_
-
- ChainedBroadcast.synchronized {
- ChainedBroadcast.values.put(uuid, 0, value_)
- }
-
- @transient var arrayOfBlocks: Array[BroadcastBlock] = null
- @transient var totalBytes = -1
- @transient var totalBlocks = -1
- @transient var hasBlocks = 0
- // CHANGED: BlockSize in the Broadcast object is expected to change over time
- @transient var blockSize = Broadcast.BlockSize
-
- @transient var listenPortLock = new Object
- @transient var guidePortLock = new Object
- @transient var totalBlocksLock = new Object
- @transient var hasBlocksLock = new Object
-
- @transient var pqOfSources = new PriorityQueue[SourceInfo]
-
- @transient var serveMR: ServeMultipleRequests = null
- @transient var guideMR: GuideMultipleRequests = null
-
- @transient var hostAddress = Utils.localIpAddress
- @transient var listenPort = -1
- @transient var guidePort = -1
-
- @transient var hasCopyInHDFS = false
- @transient var stopBroadcast = false
-
- // Must call this after all the variables have been created/initialized
- if (!isLocal) {
- sendBroadcast
- }
-
- def sendBroadcast() {
- logInfo("Local host address: " + hostAddress)
-
- // Store a persistent copy in HDFS
- // TODO: Turned OFF for now
- // val out = new ObjectOutputStream(DfsBroadcast.openFileForWriting(uuid))
- // out.writeObject(value_)
- // out.close()
- // TODO: Fix this at some point
- hasCopyInHDFS = true
-
- // Create a variableInfo object and store it in valueInfos
- var variableInfo = Broadcast.blockifyObject(value_)
-
- guideMR = new GuideMultipleRequests
- guideMR.setDaemon(true)
- guideMR.start()
- logInfo("GuideMultipleRequests started...")
-
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon(true)
- serveMR.start()
- logInfo("ServeMultipleRequests started...")
-
- // Prepare the value being broadcasted
- // TODO: Refactoring and clean-up required here
- arrayOfBlocks = variableInfo.arrayOfBlocks
- totalBytes = variableInfo.totalBytes
- totalBlocks = variableInfo.totalBlocks
- hasBlocks = variableInfo.totalBlocks
-
- while (listenPort == -1) {
- listenPortLock.synchronized {
- listenPortLock.wait()
- }
- }
-
- pqOfSources = new PriorityQueue[SourceInfo]
- val masterSource =
- SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes, blockSize)
- pqOfSources.add(masterSource)
-
- // Register with the Tracker
- while (guidePort == -1) {
- guidePortLock.synchronized {
- guidePortLock.wait()
- }
- }
- ChainedBroadcast.registerValue(uuid, guidePort)
- }
-
- private def readObject(in: ObjectInputStream) {
- in.defaultReadObject()
- ChainedBroadcast.synchronized {
- val cachedVal = ChainedBroadcast.values.get(uuid, 0)
- if (cachedVal != null) {
- value_ = cachedVal.asInstanceOf[T]
- } else {
- // Initializing everything because Master will only send null/0 values
- initializeSlaveVariables
-
- logInfo("Local host address: " + hostAddress)
-
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon(true)
- serveMR.start()
- logInfo("ServeMultipleRequests started...")
-
- val start = System.nanoTime
-
- val receptionSucceeded = receiveBroadcast(uuid)
- // If does not succeed, then get from HDFS copy
- if (receptionSucceeded) {
- value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
- ChainedBroadcast.values.put(uuid, 0, value_)
- } else {
- val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
- value_ = fileIn.readObject.asInstanceOf[T]
- ChainedBroadcast.values.put(uuid, 0, value_)
- fileIn.close()
- }
-
- val time =(System.nanoTime - start) / 1e9
- logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
- }
- }
- }
-
- private def initializeSlaveVariables() {
- arrayOfBlocks = null
- totalBytes = -1
- totalBlocks = -1
- hasBlocks = 0
- blockSize = -1
-
- listenPortLock = new Object
- totalBlocksLock = new Object
- hasBlocksLock = new Object
-
- serveMR = null
-
- hostAddress = Utils.localIpAddress
- listenPort = -1
-
- stopBroadcast = false
- }
-
- def getMasterListenPort(variableUUID: UUID): Int = {
- var clientSocketToTracker: Socket = null
- var oosTracker: ObjectOutputStream = null
- var oisTracker: ObjectInputStream = null
-
- var masterListenPort: Int = SourceInfo.TxOverGoToHDFS
-
- var retriesLeft = Broadcast.MaxRetryCount
- do {
- try {
- // Connect to the tracker to find out the guide
- clientSocketToTracker =
- new Socket(Broadcast.MasterHostAddress, Broadcast.MasterTrackerPort)
- oosTracker =
- new ObjectOutputStream(clientSocketToTracker.getOutputStream)
- oosTracker.flush()
- oisTracker =
- new ObjectInputStream(clientSocketToTracker.getInputStream)
-
- // Send UUID and receive masterListenPort
- oosTracker.writeObject(uuid)
- oosTracker.flush()
- masterListenPort = oisTracker.readObject.asInstanceOf[Int]
- } catch {
- case e: Exception => {
- logInfo("getMasterListenPort had a " + e)
- }
- } finally {
- if (oisTracker != null) {
- oisTracker.close()
- }
- if (oosTracker != null) {
- oosTracker.close()
- }
- if (clientSocketToTracker != null) {
- clientSocketToTracker.close()
- }
- }
- retriesLeft -= 1
-
- Thread.sleep(ChainedBroadcast.ranGen.nextInt(
- Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) +
- Broadcast.MinKnockInterval)
-
- } while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry)
-
- logInfo("Got this guidePort from Tracker: " + masterListenPort)
- return masterListenPort
- }
-
- def receiveBroadcast(variableUUID: UUID): Boolean = {
- val masterListenPort = getMasterListenPort(variableUUID)
-
- if (masterListenPort == SourceInfo.TxOverGoToHDFS ||
- masterListenPort == SourceInfo.TxNotStartedRetry) {
- // TODO: SourceInfo.TxNotStartedRetry is not really in use because we go
- // to HDFS anyway when receiveBroadcast returns false
- return false
- }
-
- // Wait until hostAddress and listenPort are created by the
- // ServeMultipleRequests thread
- while (listenPort == -1) {
- listenPortLock.synchronized {
- listenPortLock.wait()
- }
- }
-
- var clientSocketToMaster: Socket = null
- var oosMaster: ObjectOutputStream = null
- var oisMaster: ObjectInputStream = null
-
- // Connect and receive broadcast from the specified source, retrying the
- // specified number of times in case of failures
- var retriesLeft = Broadcast.MaxRetryCount
- do {
- // Connect to Master and send this worker's Information
- clientSocketToMaster =
- new Socket(Broadcast.MasterHostAddress, masterListenPort)
- // TODO: Guiding object connection is reusable
- oosMaster =
- new ObjectOutputStream(clientSocketToMaster.getOutputStream)
- oosMaster.flush()
- oisMaster =
- new ObjectInputStream(clientSocketToMaster.getInputStream)
-
- logInfo("Connected to Master's guiding object")
-
- // Send local source information
- oosMaster.writeObject(SourceInfo(hostAddress, listenPort))
- oosMaster.flush()
-
- // Receive source information from Master
- var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo]
- totalBlocks = sourceInfo.totalBlocks
- arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
- totalBlocksLock.synchronized {
- totalBlocksLock.notifyAll()
- }
- totalBytes = sourceInfo.totalBytes
-
- logInfo("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort)
-
- val start = System.nanoTime
- val receptionSucceeded = receiveSingleTransmission(sourceInfo)
- val time =(System.nanoTime - start) / 1e9
-
- // Updating some statistics in sourceInfo. Master will be using them later
- if (!receptionSucceeded) {
- sourceInfo.receptionFailed = true
- }
-
- // Send back statistics to the Master
- oosMaster.writeObject(sourceInfo)
-
- if (oisMaster != null) {
- oisMaster.close()
- }
- if (oosMaster != null) {
- oosMaster.close()
- }
- if (clientSocketToMaster != null) {
- clientSocketToMaster.close()
- }
-
- retriesLeft -= 1
- } while (retriesLeft > 0 && hasBlocks < totalBlocks)
-
- return(hasBlocks == totalBlocks)
- }
-
- // Tries to receive broadcast from the source and returns Boolean status.
- // This might be called multiple times to retry a defined number of times.
- private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
- var clientSocketToSource: Socket = null
- var oosSource: ObjectOutputStream = null
- var oisSource: ObjectInputStream = null
-
- var receptionSucceeded = false
- try {
- // Connect to the source to get the object itself
- clientSocketToSource =
- new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
- oosSource =
- new ObjectOutputStream(clientSocketToSource.getOutputStream)
- oosSource.flush()
- oisSource =
- new ObjectInputStream(clientSocketToSource.getInputStream)
-
- logInfo("Inside receiveSingleTransmission")
- logInfo("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
-
- // Send the range
- oosSource.writeObject((hasBlocks, totalBlocks))
- oosSource.flush()
-
- for (i <- hasBlocks until totalBlocks) {
- val recvStartTime = System.currentTimeMillis
- val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
- val receptionTime =(System.currentTimeMillis - recvStartTime)
-
- logInfo("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
-
- arrayOfBlocks(hasBlocks) = bcBlock
- hasBlocks += 1
- // Set to true if at least one block is received
- receptionSucceeded = true
- hasBlocksLock.synchronized {
- hasBlocksLock.notifyAll()
- }
- }
- } catch {
- case e: Exception => {
- logInfo("receiveSingleTransmission had a " + e)
- }
- } finally {
- if (oisSource != null) {
- oisSource.close()
- }
- if (oosSource != null) {
- oosSource.close()
- }
- if (clientSocketToSource != null) {
- clientSocketToSource.close()
- }
- }
-
- return receptionSucceeded
- }
-
- class GuideMultipleRequests
- extends Thread with Logging {
- // Keep track of sources that have completed reception
- private var setOfCompletedSources = Set[SourceInfo]()
-
- override def run() {
- var threadPool = Utils.newDaemonCachedThreadPool()
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket(0)
- guidePort = serverSocket.getLocalPort
- logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
-
- guidePortLock.synchronized {
- guidePortLock.notifyAll()
- }
-
- try {
- // Don't stop until there is a copy in HDFS
- while (!stopBroadcast || !hasCopyInHDFS) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
- clientSocket = serverSocket.accept
- } catch {
- case e: Exception => {
- logInfo("GuideMultipleRequests Timeout.")
-
- // Stop broadcast if at least one worker has connected and
- // everyone connected so far are done. Comparing with
- // pqOfSources.size - 1, because it includes the Guide itself
- if (pqOfSources.size > 1 &&
- setOfCompletedSources.size == pqOfSources.size - 1) {
- stopBroadcast = true
- }
- }
- }
- if (clientSocket != null) {
- logInfo("Guide: Accepted new client connection: " + clientSocket)
- try {
- threadPool.execute(new GuideSingleRequest(clientSocket))
- } catch {
- // In failure, close the socket here; else, the thread will close it
- case ioe: IOException => clientSocket.close()
- }
- }
- }
-
- logInfo("Sending stopBroadcast notifications...")
- sendStopBroadcastNotifications
-
- ChainedBroadcast.unregisterValue(uuid)
- } finally {
- if (serverSocket != null) {
- logInfo("GuideMultipleRequests now stopping...")
- serverSocket.close()
- }
- }
-
- // Shutdown the thread pool
- threadPool.shutdown()
- }
-
- private def sendStopBroadcastNotifications() {
- pqOfSources.synchronized {
- var pqIter = pqOfSources.iterator
- while (pqIter.hasNext) {
- var sourceInfo = pqIter.next
-
- var guideSocketToSource: Socket = null
- var gosSource: ObjectOutputStream = null
- var gisSource: ObjectInputStream = null
-
- try {
- // Connect to the source
- guideSocketToSource =
- new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
- gosSource =
- new ObjectOutputStream(guideSocketToSource.getOutputStream)
- gosSource.flush()
- gisSource =
- new ObjectInputStream(guideSocketToSource.getInputStream)
-
- // Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2
- gosSource.writeObject((SourceInfo.StopBroadcast,
- SourceInfo.StopBroadcast))
- gosSource.flush()
- } catch {
- case e: Exception => {
- logInfo("sendStopBroadcastNotifications had a " + e)
- }
- } finally {
- if (gisSource != null) {
- gisSource.close()
- }
- if (gosSource != null) {
- gosSource.close()
- }
- if (guideSocketToSource != null) {
- guideSocketToSource.close()
- }
- }
- }
- }
- }
-
- class GuideSingleRequest(val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- private var selectedSourceInfo: SourceInfo = null
- private var thisWorkerInfo:SourceInfo = null
-
- override def run() {
- try {
- logInfo("new GuideSingleRequest is running")
- // Connecting worker is sending in its hostAddress and listenPort it will
- // be listening to. Other fields are invalid(SourceInfo.UnusedParam)
- var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- pqOfSources.synchronized {
- // Select a suitable source and send it back to the worker
- selectedSourceInfo = selectSuitableSource(sourceInfo)
- logInfo("Sending selectedSourceInfo: " + selectedSourceInfo)
- oos.writeObject(selectedSourceInfo)
- oos.flush()
-
- // Add this new(if it can finish) source to the PQ of sources
- thisWorkerInfo = SourceInfo(sourceInfo.hostAddress,
- sourceInfo.listenPort, totalBlocks, totalBytes, blockSize)
- logInfo("Adding possible new source to pqOfSources: " + thisWorkerInfo)
- pqOfSources.add(thisWorkerInfo)
- }
-
- // Wait till the whole transfer is done. Then receive and update source
- // statistics in pqOfSources
- sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- pqOfSources.synchronized {
- // This should work since SourceInfo is a case class
- assert(pqOfSources.contains(selectedSourceInfo))
-
- // Remove first
- pqOfSources.remove(selectedSourceInfo)
- // TODO: Removing a source based on just one failure notification!
-
- // Update sourceInfo and put it back in, IF reception succeeded
- if (!sourceInfo.receptionFailed) {
- // Add thisWorkerInfo to sources that have completed reception
- setOfCompletedSources.synchronized {
- setOfCompletedSources += thisWorkerInfo
- }
-
- selectedSourceInfo.currentLeechers -= 1
-
- // Put it back
- pqOfSources.add(selectedSourceInfo)
- }
- }
- } catch {
- // If something went wrong, e.g., the worker at the other end died etc.
- // then close everything up
- case e: Exception => {
- // Assuming that exception caused due to receiver worker failure.
- // Remove failed worker from pqOfSources and update leecherCount of
- // corresponding source worker
- pqOfSources.synchronized {
- if (selectedSourceInfo != null) {
- // Remove first
- pqOfSources.remove(selectedSourceInfo)
- // Update leecher count and put it back in
- selectedSourceInfo.currentLeechers -= 1
- pqOfSources.add(selectedSourceInfo)
- }
-
- // Remove thisWorkerInfo
- if (pqOfSources != null) {
- pqOfSources.remove(thisWorkerInfo)
- }
- }
- }
- } finally {
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
-
- // FIXME: Caller must have a synchronized block on pqOfSources
- // FIXME: If a worker fails to get the broadcasted variable from a source and
- // comes back to Master, this function might choose the worker itself as a
- // source tp create a dependency cycle(this worker was put into pqOfSources
- // as a streming source when it first arrived). The length of this cycle can
- // be arbitrarily long.
- private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
- // Select one based on the ordering strategy(e.g., least leechers etc.)
- // take is a blocking call removing the element from PQ
- var selectedSource = pqOfSources.poll
- assert(selectedSource != null)
- // Update leecher count
- selectedSource.currentLeechers += 1
- // Add it back and then return
- pqOfSources.add(selectedSource)
- return selectedSource
- }
- }
- }
-
- class ServeMultipleRequests
- extends Thread with Logging {
- override def run() {
- var threadPool = Utils.newDaemonCachedThreadPool()
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket(0)
- listenPort = serverSocket.getLocalPort
- logInfo("ServeMultipleRequests started with " + serverSocket)
-
- listenPortLock.synchronized {
- listenPortLock.notifyAll()
- }
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
- clientSocket = serverSocket.accept
- } catch {
- case e: Exception => {
- logInfo("ServeMultipleRequests Timeout.")
- }
- }
- if (clientSocket != null) {
- logInfo("Serve: Accepted new client connection: " + clientSocket)
- try {
- threadPool.execute(new ServeSingleRequest(clientSocket))
- } catch {
- // In failure, close socket here; else, the thread will close it
- case ioe: IOException => clientSocket.close()
- }
- }
- }
- } finally {
- if (serverSocket != null) {
- logInfo("ServeMultipleRequests now stopping...")
- serverSocket.close()
- }
- }
-
- // Shutdown the thread pool
- threadPool.shutdown()
- }
-
- class ServeSingleRequest(val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- private var sendFrom = 0
- private var sendUntil = totalBlocks
-
- override def run() {
- try {
- logInfo("new ServeSingleRequest is running")
-
- // Receive range to send
- var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)]
- sendFrom = rangeToSend._1
- sendUntil = rangeToSend._2
-
- if (sendFrom == SourceInfo.StopBroadcast &&
- sendUntil == SourceInfo.StopBroadcast) {
- stopBroadcast = true
- } else {
- // Carry on
- sendObject
- }
- } catch {
- // If something went wrong, e.g., the worker at the other end died etc.
- // then close everything up
- case e: Exception => {
- logInfo("ServeSingleRequest had a " + e)
- }
- } finally {
- logInfo("ServeSingleRequest is closing streams and sockets")
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
-
- private def sendObject() {
- // Wait till receiving the SourceInfo from Master
- while (totalBlocks == -1) {
- totalBlocksLock.synchronized {
- totalBlocksLock.wait()
- }
- }
-
- for (i <- sendFrom until sendUntil) {
- while (i == hasBlocks) {
- hasBlocksLock.synchronized {
- hasBlocksLock.wait()
- }
- }
- try {
- oos.writeObject(arrayOfBlocks(i))
- oos.flush()
- } catch {
- case e: Exception => {
- logInfo("sendObject had a " + e)
- }
- }
- logInfo("Sent block: " + i + " to " + clientSocket)
- }
- }
- }
- }
-}
-
-class ChainedBroadcastFactory
-extends BroadcastFactory {
- def initialize(isMaster: Boolean) {
- ChainedBroadcast.initialize(isMaster)
- }
- def newBroadcast[T](value_ : T, isLocal: Boolean) = {
- new ChainedBroadcast[T](value_, isLocal)
- }
-}
-
-private object ChainedBroadcast
-extends Logging {
- val values = SparkEnv.get.cache.newKeySpace()
-
- var valueToGuidePortMap = Map[UUID, Int]()
-
- // Random number generator
- var ranGen = new Random
-
- private var initialized = false
- private var isMaster_ = false
-
- private var trackMV: TrackMultipleValues = null
-
- def initialize(isMaster__ : Boolean) {
- synchronized {
- if (!initialized) {
- isMaster_ = isMaster__
-
- if (isMaster) {
- trackMV = new TrackMultipleValues
- trackMV.setDaemon(true)
- trackMV.start()
- // TODO: Logging the following line makes the Spark framework ID not
- // getting logged, cause it calls logInfo before log4j is initialized
- logInfo("TrackMultipleValues started...")
- }
-
- // Initialize DfsBroadcast to be used for broadcast variable persistence
- DfsBroadcast.initialize
-
- initialized = true
- }
- }
- }
-
- def isMaster = isMaster_
-
- def registerValue(uuid: UUID, guidePort: Int) {
- valueToGuidePortMap.synchronized {
- valueToGuidePortMap +=(uuid -> guidePort)
- logInfo("New value registered with the Tracker " + valueToGuidePortMap)
- }
- }
-
- def unregisterValue(uuid: UUID) {
- valueToGuidePortMap.synchronized {
- valueToGuidePortMap(uuid) = SourceInfo.TxOverGoToHDFS
- logInfo("Value unregistered from the Tracker " + valueToGuidePortMap)
- }
- }
-
- class TrackMultipleValues
- extends Thread with Logging {
- override def run() {
- var threadPool = Utils.newDaemonCachedThreadPool()
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket(Broadcast.MasterTrackerPort)
- logInfo("TrackMultipleValues" + serverSocket)
-
- try {
- while (true) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(Broadcast.TrackerSocketTimeout)
- clientSocket = serverSocket.accept
- } catch {
- case e: Exception => {
- logInfo("TrackMultipleValues Timeout. Stopping listening...")
- }
- }
-
- if (clientSocket != null) {
- try {
- threadPool.execute(new Thread {
- override def run() {
- val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- val ois = new ObjectInputStream(clientSocket.getInputStream)
- try {
- val uuid = ois.readObject.asInstanceOf[UUID]
- var guidePort =
- if (valueToGuidePortMap.contains(uuid)) {
- valueToGuidePortMap(uuid)
- } else SourceInfo.TxNotStartedRetry
- logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort)
- oos.writeObject(guidePort)
- } catch {
- case e: Exception => {
- logInfo("TrackMultipleValues had a " + e)
- }
- } finally {
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
- })
- } catch {
- // In failure, close socket here; else, client thread will close
- case ioe: IOException => clientSocket.close()
- }
- }
- }
- } finally {
- serverSocket.close()
- }
-
- // Shutdown the thread pool
- threadPool.shutdown()
- }
- }
-}
diff --git a/core/src/main/scala/spark/broadcast/DfsBroadcast.scala b/core/src/main/scala/spark/broadcast/DfsBroadcast.scala
deleted file mode 100644
index d18dfb8963..0000000000
--- a/core/src/main/scala/spark/broadcast/DfsBroadcast.scala
+++ /dev/null
@@ -1,135 +0,0 @@
-package spark.broadcast
-
-import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
-
-import java.io._
-import java.net._
-import java.util.UUID
-
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem}
-
-import spark._
-
-class DfsBroadcast[T](@transient var value_ : T, isLocal: Boolean)
-extends Broadcast[T] with Logging with Serializable {
-
- def value = value_
-
- DfsBroadcast.synchronized {
- DfsBroadcast.values.put(uuid, 0, value_)
- }
-
- if (!isLocal) {
- sendBroadcast
- }
-
- def sendBroadcast () {
- val out = new ObjectOutputStream (DfsBroadcast.openFileForWriting(uuid))
- out.writeObject (value_)
- out.close()
- }
-
- // Called by JVM when deserializing an object
- private def readObject(in: ObjectInputStream) {
- in.defaultReadObject()
- DfsBroadcast.synchronized {
- val cachedVal = DfsBroadcast.values.get(uuid, 0)
- if (cachedVal != null) {
- value_ = cachedVal.asInstanceOf[T]
- } else {
- logInfo( "Started reading Broadcasted variable " + uuid)
- val start = System.nanoTime
-
- val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
- value_ = fileIn.readObject.asInstanceOf[T]
- DfsBroadcast.values.put(uuid, 0, value_)
- fileIn.close()
-
- val time = (System.nanoTime - start) / 1e9
- logInfo( "Reading Broadcasted variable " + uuid + " took " + time + " s")
- }
- }
- }
-}
-
-class DfsBroadcastFactory
-extends BroadcastFactory {
- def initialize (isMaster: Boolean) {
- DfsBroadcast.initialize
- }
- def newBroadcast[T] (value_ : T, isLocal: Boolean) =
- new DfsBroadcast[T] (value_, isLocal)
-}
-
-private object DfsBroadcast
-extends Logging {
- val values = SparkEnv.get.cache.newKeySpace()
-
- private var initialized = false
-
- private var fileSystem: FileSystem = null
- private var workDir: String = null
- private var compress: Boolean = false
- private var bufferSize: Int = 65536
-
- def initialize() {
- synchronized {
- if (!initialized) {
- bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
- val dfs = System.getProperty("spark.dfs", "file:///")
- if (!dfs.startsWith("file://")) {
- val conf = new Configuration()
- conf.setInt("io.file.buffer.size", bufferSize)
- val rep = System.getProperty("spark.dfs.replication", "3").toInt
- conf.setInt("dfs.replication", rep)
- fileSystem = FileSystem.get(new URI(dfs), conf)
- }
- workDir = System.getProperty("spark.dfs.workDir", "/tmp")
- compress = System.getProperty("spark.compress", "false").toBoolean
-
- initialized = true
- }
- }
- }
-
- private def getPath(uuid: UUID) = new Path(workDir + "/broadcast-" + uuid)
-
- def openFileForReading(uuid: UUID): InputStream = {
- val fileStream = if (fileSystem != null) {
- fileSystem.open(getPath(uuid))
- } else {
- // Local filesystem
- new FileInputStream(getPath(uuid).toString)
- }
-
- if (compress) {
- // LZF stream does its own buffering
- new LZFInputStream(fileStream)
- } else if (fileSystem == null) {
- new BufferedInputStream(fileStream, bufferSize)
- } else {
- // Hadoop streams do their own buffering
- fileStream
- }
- }
-
- def openFileForWriting(uuid: UUID): OutputStream = {
- val fileStream = if (fileSystem != null) {
- fileSystem.create(getPath(uuid))
- } else {
- // Local filesystem
- new FileOutputStream(getPath(uuid).toString)
- }
-
- if (compress) {
- // LZF stream does its own buffering
- new LZFOutputStream(fileStream)
- } else if (fileSystem == null) {
- new BufferedOutputStream(fileStream, bufferSize)
- } else {
- // Hadoop streams do their own buffering
- fileStream
- }
- }
-}
diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
index 6e3dde76bd..e4b1356448 100644
--- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
@@ -10,14 +10,15 @@ import it.unimi.dsi.fastutil.io.FastBufferedInputStream
import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import spark._
+import spark.storage.StorageLevel
class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean)
extends Broadcast[T] with Logging with Serializable {
def value = value_
- HttpBroadcast.synchronized {
- HttpBroadcast.values.put(uuid, 0, value_)
+ HttpBroadcast.synchronized {
+ SparkEnv.get.blockManager.putSingle(uuid.toString, value_, StorageLevel.MEMORY_ONLY, false)
}
if (!isLocal) {
@@ -28,31 +29,28 @@ extends Broadcast[T] with Logging with Serializable {
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
HttpBroadcast.synchronized {
- val cachedVal = HttpBroadcast.values.get(uuid, 0)
- if (cachedVal != null) {
- value_ = cachedVal.asInstanceOf[T]
- } else {
- logInfo("Started reading broadcast variable " + uuid)
- val start = System.nanoTime
- value_ = HttpBroadcast.read[T](uuid)
- HttpBroadcast.values.put(uuid, 0, value_)
- val time = (System.nanoTime - start) / 1e9
- logInfo("Reading broadcast variable " + uuid + " took " + time + " s")
+ SparkEnv.get.blockManager.getSingle(uuid.toString) match {
+ case Some(x) => value_ = x.asInstanceOf[T]
+ case None => {
+ logInfo("Started reading broadcast variable " + uuid)
+ val start = System.nanoTime
+ value_ = HttpBroadcast.read[T](uuid)
+ SparkEnv.get.blockManager.putSingle(uuid.toString, value_, StorageLevel.MEMORY_ONLY, false)
+ val time = (System.nanoTime - start) / 1e9
+ logInfo("Reading broadcast variable " + uuid + " took " + time + " s")
+ }
}
}
}
}
class HttpBroadcastFactory extends BroadcastFactory {
- def initialize(isMaster: Boolean) {
- HttpBroadcast.initialize(isMaster)
- }
+ def initialize(isMaster: Boolean) = HttpBroadcast.initialize(isMaster)
def newBroadcast[T](value_ : T, isLocal: Boolean) = new HttpBroadcast[T](value_, isLocal)
+ def stop() = HttpBroadcast.stop()
}
private object HttpBroadcast extends Logging {
- val values = SparkEnv.get.cache.newKeySpace()
-
private var initialized = false
private var broadcastDir: File = null
@@ -74,6 +72,12 @@ private object HttpBroadcast extends Logging {
}
}
}
+
+ def stop() {
+ if (server != null) {
+ server.stop()
+ }
+ }
private def createServer() {
broadcastDir = Utils.createTempDir()
diff --git a/core/src/main/scala/spark/broadcast/MultiTracker.scala b/core/src/main/scala/spark/broadcast/MultiTracker.scala
new file mode 100644
index 0000000000..d5f5b22461
--- /dev/null
+++ b/core/src/main/scala/spark/broadcast/MultiTracker.scala
@@ -0,0 +1,394 @@
+package spark.broadcast
+
+import java.io._
+import java.net._
+import java.util.{UUID, Random}
+import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
+
+import scala.collection.mutable.Map
+
+import spark._
+
+private object MultiTracker
+extends Logging {
+
+ // Tracker Messages
+ val REGISTER_BROADCAST_TRACKER = 0
+ val UNREGISTER_BROADCAST_TRACKER = 1
+ val FIND_BROADCAST_TRACKER = 2
+
+ // Map to keep track of guides of ongoing broadcasts
+ var valueToGuideMap = Map[UUID, SourceInfo]()
+
+ // Random number generator
+ var ranGen = new Random
+
+ private var initialized = false
+ private var isMaster_ = false
+
+ private var stopBroadcast = false
+
+ private var trackMV: TrackMultipleValues = null
+
+ def initialize(isMaster__ : Boolean) {
+ synchronized {
+ if (!initialized) {
+
+ isMaster_ = isMaster__
+
+ if (isMaster) {
+ trackMV = new TrackMultipleValues
+ trackMV.setDaemon(true)
+ trackMV.start()
+
+ // Set masterHostAddress to the master's IP address for the slaves to read
+ System.setProperty("spark.MultiTracker.MasterHostAddress", Utils.localIpAddress)
+ }
+
+ initialized = true
+ }
+ }
+ }
+
+ def stop() {
+ stopBroadcast = true
+ }
+
+ // Load common parameters
+ private var MasterHostAddress_ = System.getProperty(
+ "spark.MultiTracker.MasterHostAddress", "")
+ private var MasterTrackerPort_ = System.getProperty(
+ "spark.broadcast.masterTrackerPort", "11111").toInt
+ private var BlockSize_ = System.getProperty(
+ "spark.broadcast.blockSize", "4096").toInt * 1024
+ private var MaxRetryCount_ = System.getProperty(
+ "spark.broadcast.maxRetryCount", "2").toInt
+
+ private var TrackerSocketTimeout_ = System.getProperty(
+ "spark.broadcast.trackerSocketTimeout", "50000").toInt
+ private var ServerSocketTimeout_ = System.getProperty(
+ "spark.broadcast.serverSocketTimeout", "10000").toInt
+
+ private var MinKnockInterval_ = System.getProperty(
+ "spark.broadcast.minKnockInterval", "500").toInt
+ private var MaxKnockInterval_ = System.getProperty(
+ "spark.broadcast.maxKnockInterval", "999").toInt
+
+ // Load TreeBroadcast config params
+ private var MaxDegree_ = System.getProperty(
+ "spark.broadcast.maxDegree", "2").toInt
+
+ // Load BitTorrentBroadcast config params
+ private var MaxPeersInGuideResponse_ = System.getProperty(
+ "spark.broadcast.maxPeersInGuideResponse", "4").toInt
+
+ private var MaxChatSlots_ = System.getProperty(
+ "spark.broadcast.maxChatSlots", "4").toInt
+ private var MaxChatTime_ = System.getProperty(
+ "spark.broadcast.maxChatTime", "500").toInt
+ private var MaxChatBlocks_ = System.getProperty(
+ "spark.broadcast.maxChatBlocks", "1024").toInt
+
+ private var EndGameFraction_ = System.getProperty(
+ "spark.broadcast.endGameFraction", "0.95").toDouble
+
+ def isMaster = isMaster_
+
+ // Common config params
+ def MasterHostAddress = MasterHostAddress_
+ def MasterTrackerPort = MasterTrackerPort_
+ def BlockSize = BlockSize_
+ def MaxRetryCount = MaxRetryCount_
+
+ def TrackerSocketTimeout = TrackerSocketTimeout_
+ def ServerSocketTimeout = ServerSocketTimeout_
+
+ def MinKnockInterval = MinKnockInterval_
+ def MaxKnockInterval = MaxKnockInterval_
+
+ // TreeBroadcast configs
+ def MaxDegree = MaxDegree_
+
+ // BitTorrentBroadcast configs
+ def MaxPeersInGuideResponse = MaxPeersInGuideResponse_
+
+ def MaxChatSlots = MaxChatSlots_
+ def MaxChatTime = MaxChatTime_
+ def MaxChatBlocks = MaxChatBlocks_
+
+ def EndGameFraction = EndGameFraction_
+
+ class TrackMultipleValues
+ extends Thread with Logging {
+ override def run() {
+ var threadPool = Utils.newDaemonCachedThreadPool()
+ var serverSocket: ServerSocket = null
+
+ serverSocket = new ServerSocket(MasterTrackerPort)
+ logInfo("TrackMultipleValues started at " + serverSocket)
+
+ try {
+ while (!stopBroadcast) {
+ var clientSocket: Socket = null
+ try {
+ serverSocket.setSoTimeout(TrackerSocketTimeout)
+ clientSocket = serverSocket.accept()
+ } catch {
+ case e: Exception => {
+ if (stopBroadcast) {
+ logInfo("Stopping TrackMultipleValues...")
+ }
+ }
+ }
+
+ if (clientSocket != null) {
+ try {
+ threadPool.execute(new Thread {
+ override def run() {
+ val oos = new ObjectOutputStream(clientSocket.getOutputStream)
+ oos.flush()
+ val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+ try {
+ // First, read message type
+ val messageType = ois.readObject.asInstanceOf[Int]
+
+ if (messageType == REGISTER_BROADCAST_TRACKER) {
+ // Receive UUID
+ val uuid = ois.readObject.asInstanceOf[UUID]
+ // Receive hostAddress and listenPort
+ val gInfo = ois.readObject.asInstanceOf[SourceInfo]
+
+ // Add to the map
+ valueToGuideMap.synchronized {
+ valueToGuideMap += (uuid -> gInfo)
+ }
+
+ logInfo ("New broadcast " + uuid + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
+
+ // Send dummy ACK
+ oos.writeObject(-1)
+ oos.flush()
+ } else if (messageType == UNREGISTER_BROADCAST_TRACKER) {
+ // Receive UUID
+ val uuid = ois.readObject.asInstanceOf[UUID]
+
+ // Remove from the map
+ valueToGuideMap.synchronized {
+ valueToGuideMap(uuid) = SourceInfo("", SourceInfo.TxOverGoToDefault)
+ }
+
+ logInfo ("Broadcast " + uuid + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
+
+ // Send dummy ACK
+ oos.writeObject(-1)
+ oos.flush()
+ } else if (messageType == FIND_BROADCAST_TRACKER) {
+ // Receive UUID
+ val uuid = ois.readObject.asInstanceOf[UUID]
+
+ var gInfo =
+ if (valueToGuideMap.contains(uuid)) valueToGuideMap(uuid)
+ else SourceInfo("", SourceInfo.TxNotStartedRetry)
+
+ logDebug("Got new request: " + clientSocket + " for " + uuid + " : " + gInfo.listenPort)
+
+ // Send reply back
+ oos.writeObject(gInfo)
+ oos.flush()
+ } else {
+ throw new SparkException("Undefined messageType at TrackMultipleValues")
+ }
+ } catch {
+ case e: Exception => {
+ logError("TrackMultipleValues had a " + e)
+ }
+ } finally {
+ ois.close()
+ oos.close()
+ clientSocket.close()
+ }
+ }
+ })
+ } catch {
+ // In failure, close socket here; else, client thread will close
+ case ioe: IOException => clientSocket.close()
+ }
+ }
+ }
+ } finally {
+ serverSocket.close()
+ }
+ // Shutdown the thread pool
+ threadPool.shutdown()
+ }
+ }
+
+ def getGuideInfo(variableUUID: UUID): SourceInfo = {
+ var clientSocketToTracker: Socket = null
+ var oosTracker: ObjectOutputStream = null
+ var oisTracker: ObjectInputStream = null
+
+ var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxOverGoToDefault)
+
+ var retriesLeft = MultiTracker.MaxRetryCount
+ do {
+ try {
+ // Connect to the tracker to find out GuideInfo
+ clientSocketToTracker =
+ new Socket(MultiTracker.MasterHostAddress, MultiTracker.MasterTrackerPort)
+ oosTracker =
+ new ObjectOutputStream(clientSocketToTracker.getOutputStream)
+ oosTracker.flush()
+ oisTracker =
+ new ObjectInputStream(clientSocketToTracker.getInputStream)
+
+ // Send messageType/intention
+ oosTracker.writeObject(MultiTracker.FIND_BROADCAST_TRACKER)
+ oosTracker.flush()
+
+ // Send UUID and receive GuideInfo
+ oosTracker.writeObject(variableUUID)
+ oosTracker.flush()
+ gInfo = oisTracker.readObject.asInstanceOf[SourceInfo]
+ } catch {
+ case e: Exception => logError("getGuideInfo had a " + e)
+ } finally {
+ if (oisTracker != null) {
+ oisTracker.close()
+ }
+ if (oosTracker != null) {
+ oosTracker.close()
+ }
+ if (clientSocketToTracker != null) {
+ clientSocketToTracker.close()
+ }
+ }
+
+ Thread.sleep(MultiTracker.ranGen.nextInt(
+ MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) +
+ MultiTracker.MinKnockInterval)
+
+ retriesLeft -= 1
+ } while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry)
+
+ logDebug("Got this guidePort from Tracker: " + gInfo.listenPort)
+ return gInfo
+ }
+
+ def registerBroadcast(uuid: UUID, gInfo: SourceInfo) {
+ val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort)
+ val oosST = new ObjectOutputStream(socket.getOutputStream)
+ oosST.flush()
+ val oisST = new ObjectInputStream(socket.getInputStream)
+
+ // Send messageType/intention
+ oosST.writeObject(REGISTER_BROADCAST_TRACKER)
+ oosST.flush()
+
+ // Send UUID of this broadcast
+ oosST.writeObject(uuid)
+ oosST.flush()
+
+ // Send this tracker's information
+ oosST.writeObject(gInfo)
+ oosST.flush()
+
+ // Receive ACK and throw it away
+ oisST.readObject.asInstanceOf[Int]
+
+ // Shut stuff down
+ oisST.close()
+ oosST.close()
+ socket.close()
+ }
+
+ def unregisterBroadcast(uuid: UUID) {
+ val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort)
+ val oosST = new ObjectOutputStream(socket.getOutputStream)
+ oosST.flush()
+ val oisST = new ObjectInputStream(socket.getInputStream)
+
+ // Send messageType/intention
+ oosST.writeObject(UNREGISTER_BROADCAST_TRACKER)
+ oosST.flush()
+
+ // Send UUID of this broadcast
+ oosST.writeObject(uuid)
+ oosST.flush()
+
+ // Receive ACK and throw it away
+ oisST.readObject.asInstanceOf[Int]
+
+ // Shut stuff down
+ oisST.close()
+ oosST.close()
+ socket.close()
+ }
+
+ // Helper method to convert an object to Array[BroadcastBlock]
+ def blockifyObject[IN](obj: IN): VariableInfo = {
+ val baos = new ByteArrayOutputStream
+ val oos = new ObjectOutputStream(baos)
+ oos.writeObject(obj)
+ oos.close()
+ baos.close()
+ val byteArray = baos.toByteArray
+ val bais = new ByteArrayInputStream(byteArray)
+
+ var blockNum = (byteArray.length / BlockSize)
+ if (byteArray.length % BlockSize != 0)
+ blockNum += 1
+
+ var retVal = new Array[BroadcastBlock](blockNum)
+ var blockID = 0
+
+ for (i <- 0 until (byteArray.length, BlockSize)) {
+ val thisBlockSize = math.min(BlockSize, byteArray.length - i)
+ var tempByteArray = new Array[Byte](thisBlockSize)
+ val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
+
+ retVal(blockID) = new BroadcastBlock(blockID, tempByteArray)
+ blockID += 1
+ }
+ bais.close()
+
+ var variableInfo = VariableInfo(retVal, blockNum, byteArray.length)
+ variableInfo.hasBlocks = blockNum
+
+ return variableInfo
+ }
+
+ // Helper method to convert Array[BroadcastBlock] to object
+ def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock],
+ totalBytes: Int,
+ totalBlocks: Int): OUT = {
+
+ var retByteArray = new Array[Byte](totalBytes)
+ for (i <- 0 until totalBlocks) {
+ System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
+ i * BlockSize, arrayOfBlocks(i).byteArray.length)
+ }
+ byteArrayToObject(retByteArray)
+ }
+
+ private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = {
+ val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){
+ override def resolveClass(desc: ObjectStreamClass) =
+ Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader)
+ }
+ val retVal = in.readObject.asInstanceOf[OUT]
+ in.close()
+ return retVal
+ }
+}
+
+case class BroadcastBlock(blockID: Int, byteArray: Array[Byte])
+extends Serializable
+
+case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock],
+ totalBlocks: Int,
+ totalBytes: Int)
+extends Serializable {
+ @transient var hasBlocks = 0
+}
diff --git a/core/src/main/scala/spark/broadcast/SourceInfo.scala b/core/src/main/scala/spark/broadcast/SourceInfo.scala
index 09907f4ee7..f90385fd47 100644
--- a/core/src/main/scala/spark/broadcast/SourceInfo.scala
+++ b/core/src/main/scala/spark/broadcast/SourceInfo.scala
@@ -6,15 +6,11 @@ import spark._
/**
* Used to keep and pass around information of peers involved in a broadcast
- *
- * CHANGED: Keep track of the blockSize for THIS broadcast variable.
- * Broadcast.BlockSize is expected to be updated across different broadcasts
*/
case class SourceInfo (hostAddress: String,
listenPort: Int,
totalBlocks: Int = SourceInfo.UnusedParam,
- totalBytes: Int = SourceInfo.UnusedParam,
- blockSize: Int = Broadcast.BlockSize)
+ totalBytes: Int = SourceInfo.UnusedParam)
extends Comparable[SourceInfo] with Logging {
var currentLeechers = 0
@@ -33,8 +29,8 @@ extends Comparable[SourceInfo] with Logging {
object SourceInfo {
// Constants for special values of listenPort
val TxNotStartedRetry = -1
- val TxOverGoToHDFS = 0
+ val TxOverGoToDefault = 0
// Other constants
val StopBroadcast = -2
val UnusedParam = 0
-} \ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala
index f5527b6ec9..6928253537 100644
--- a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala
@@ -8,22 +8,21 @@ import scala.collection.mutable.{ListBuffer, Map, Set}
import scala.math
import spark._
+import spark.storage.StorageLevel
class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean)
extends Broadcast[T] with Logging with Serializable {
def value = value_
- TreeBroadcast.synchronized {
- TreeBroadcast.values.put(uuid, 0, value_)
+ MultiTracker.synchronized {
+ SparkEnv.get.blockManager.putSingle(uuid.toString, value_, StorageLevel.MEMORY_ONLY, false)
}
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@transient var totalBytes = -1
@transient var totalBlocks = -1
@transient var hasBlocks = 0
- // CHANGED: BlockSize in the Broadcast object is expected to change over time
- @transient var blockSize = Broadcast.BlockSize
@transient var listenPortLock = new Object
@transient var guidePortLock = new Object
@@ -39,7 +38,6 @@ extends Broadcast[T] with Logging with Serializable {
@transient var listenPort = -1
@transient var guidePort = -1
- @transient var hasCopyInHDFS = false
@transient var stopBroadcast = false
// Must call this after all the variables have been created/initialized
@@ -50,19 +48,10 @@ extends Broadcast[T] with Logging with Serializable {
def sendBroadcast() {
logInfo("Local host address: " + hostAddress)
- // Store a persistent copy in HDFS
- // TODO: Turned OFF for now
- // val out = new ObjectOutputStream(DfsBroadcast.openFileForWriting(uuid))
- // out.writeObject(value_)
- // out.close()
- // TODO: Fix this at some point
- hasCopyInHDFS = true
-
// Create a variableInfo object and store it in valueInfos
- var variableInfo = Broadcast.blockifyObject(value_)
+ var variableInfo = MultiTracker.blockifyObject(value_)
// Prepare the value being broadcasted
- // TODO: Refactoring and clean-up required here
arrayOfBlocks = variableInfo.arrayOfBlocks
totalBytes = variableInfo.totalBytes
totalBlocks = variableInfo.totalBlocks
@@ -75,9 +64,7 @@ extends Broadcast[T] with Logging with Serializable {
// Must always come AFTER guideMR is created
while (guidePort == -1) {
- guidePortLock.synchronized {
- guidePortLock.wait()
- }
+ guidePortLock.synchronized { guidePortLock.wait() }
}
serveMR = new ServeMultipleRequests
@@ -87,63 +74,59 @@ extends Broadcast[T] with Logging with Serializable {
// Must always come AFTER serveMR is created
while (listenPort == -1) {
- listenPortLock.synchronized {
- listenPortLock.wait()
- }
+ listenPortLock.synchronized { listenPortLock.wait() }
}
// Must always come AFTER listenPort is created
val masterSource =
- SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes, blockSize)
+ SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes)
listOfSources += masterSource
// Register with the Tracker
- TreeBroadcast.registerValue(uuid, guidePort)
+ MultiTracker.registerBroadcast(uuid,
+ SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
}
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
- TreeBroadcast.synchronized {
- val cachedVal = TreeBroadcast.values.get(uuid, 0)
- if (cachedVal != null) {
- value_ = cachedVal.asInstanceOf[T]
- } else {
- // Initializing everything because Master will only send null/0 values
- initializeSlaveVariables
-
- logInfo("Local host address: " + hostAddress)
-
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon(true)
- serveMR.start()
- logInfo("ServeMultipleRequests started...")
-
- val start = System.nanoTime
-
- val receptionSucceeded = receiveBroadcast(uuid)
- // If does not succeed, then get from HDFS copy
- if (receptionSucceeded) {
- value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
- TreeBroadcast.values.put(uuid, 0, value_)
- } else {
- val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
- value_ = fileIn.readObject.asInstanceOf[T]
- TreeBroadcast.values.put(uuid, 0, value_)
- fileIn.close()
- }
+ MultiTracker.synchronized {
+ SparkEnv.get.blockManager.getSingle(uuid.toString) match {
+ case Some(x) => x.asInstanceOf[T]
+ case None => {
+ logInfo("Started reading broadcast variable " + uuid)
+ // Initializing everything because Master will only send null/0 values
+ // Only the 1st worker in a node can be here. Others will get from cache
+ initializeWorkerVariables
+
+ logInfo("Local host address: " + hostAddress)
+
+ serveMR = new ServeMultipleRequests
+ serveMR.setDaemon(true)
+ serveMR.start()
+ logInfo("ServeMultipleRequests started...")
+
+ val start = System.nanoTime
+
+ val receptionSucceeded = receiveBroadcast(uuid)
+ if (receptionSucceeded) {
+ value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
+ SparkEnv.get.blockManager.putSingle(uuid.toString, value_, StorageLevel.MEMORY_ONLY, false)
+ } else {
+ logError("Reading Broadcasted variable " + uuid + " failed")
+ }
- val time = (System.nanoTime - start) / 1e9
- logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
+ val time = (System.nanoTime - start) / 1e9
+ logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
+ }
}
}
}
- private def initializeSlaveVariables() {
+ private def initializeWorkerVariables() {
arrayOfBlocks = null
totalBytes = -1
totalBlocks = -1
hasBlocks = 0
- blockSize = -1
listenPortLock = new Object
totalBlocksLock = new Object
@@ -157,72 +140,17 @@ extends Broadcast[T] with Logging with Serializable {
stopBroadcast = false
}
- def getMasterListenPort(variableUUID: UUID): Int = {
- var clientSocketToTracker: Socket = null
- var oosTracker: ObjectOutputStream = null
- var oisTracker: ObjectInputStream = null
-
- var masterListenPort: Int = SourceInfo.TxOverGoToHDFS
-
- var retriesLeft = Broadcast.MaxRetryCount
- do {
- try {
- // Connect to the tracker to find out the guide
- clientSocketToTracker =
- new Socket(Broadcast.MasterHostAddress, Broadcast.MasterTrackerPort)
- oosTracker =
- new ObjectOutputStream(clientSocketToTracker.getOutputStream)
- oosTracker.flush()
- oisTracker =
- new ObjectInputStream(clientSocketToTracker.getInputStream)
-
- // Send UUID and receive masterListenPort
- oosTracker.writeObject(uuid)
- oosTracker.flush()
- masterListenPort = oisTracker.readObject.asInstanceOf[Int]
- } catch {
- case e: Exception => {
- logInfo("getMasterListenPort had a " + e)
- }
- } finally {
- if (oisTracker != null) {
- oisTracker.close()
- }
- if (oosTracker != null) {
- oosTracker.close()
- }
- if (clientSocketToTracker != null) {
- clientSocketToTracker.close()
- }
- }
- retriesLeft -= 1
-
- Thread.sleep(TreeBroadcast.ranGen.nextInt(
- Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) +
- Broadcast.MinKnockInterval)
-
- } while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry)
-
- logInfo("Got this guidePort from Tracker: " + masterListenPort)
- return masterListenPort
- }
-
def receiveBroadcast(variableUUID: UUID): Boolean = {
- val masterListenPort = getMasterListenPort(variableUUID)
-
- if (masterListenPort == SourceInfo.TxOverGoToHDFS ||
- masterListenPort == SourceInfo.TxNotStartedRetry) {
- // TODO: SourceInfo.TxNotStartedRetry is not really in use because we go
- // to HDFS anyway when receiveBroadcast returns false
+ val gInfo = MultiTracker.getGuideInfo(variableUUID)
+
+ if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
return false
}
// Wait until hostAddress and listenPort are created by the
// ServeMultipleRequests thread
while (listenPort == -1) {
- listenPortLock.synchronized {
- listenPortLock.wait()
- }
+ listenPortLock.synchronized { listenPortLock.wait() }
}
var clientSocketToMaster: Socket = null
@@ -231,19 +159,15 @@ extends Broadcast[T] with Logging with Serializable {
// Connect and receive broadcast from the specified source, retrying the
// specified number of times in case of failures
- var retriesLeft = Broadcast.MaxRetryCount
+ var retriesLeft = MultiTracker.MaxRetryCount
do {
// Connect to Master and send this worker's Information
- clientSocketToMaster =
- new Socket(Broadcast.MasterHostAddress, masterListenPort)
- // TODO: Guiding object connection is reusable
- oosMaster =
- new ObjectOutputStream(clientSocketToMaster.getOutputStream)
+ clientSocketToMaster = new Socket(MultiTracker.MasterHostAddress, gInfo.listenPort)
+ oosMaster = new ObjectOutputStream(clientSocketToMaster.getOutputStream)
oosMaster.flush()
- oisMaster =
- new ObjectInputStream(clientSocketToMaster.getInputStream)
+ oisMaster = new ObjectInputStream(clientSocketToMaster.getInputStream)
- logInfo("Connected to Master's guiding object")
+ logDebug("Connected to Master's guiding object")
// Send local source information
oosMaster.writeObject(SourceInfo(hostAddress, listenPort))
@@ -253,13 +177,10 @@ extends Broadcast[T] with Logging with Serializable {
var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo]
totalBlocks = sourceInfo.totalBlocks
arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
- totalBlocksLock.synchronized {
- totalBlocksLock.notifyAll()
- }
+ totalBlocksLock.synchronized { totalBlocksLock.notifyAll() }
totalBytes = sourceInfo.totalBytes
- blockSize = sourceInfo.blockSize
- logInfo("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort)
+ logDebug("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort)
val start = System.nanoTime
val receptionSucceeded = receiveSingleTransmission(sourceInfo)
@@ -289,8 +210,10 @@ extends Broadcast[T] with Logging with Serializable {
return (hasBlocks == totalBlocks)
}
- // Tries to receive broadcast from the source and returns Boolean status.
- // This might be called multiple times to retry a defined number of times.
+ /**
+ * Tries to receive broadcast from the source and returns Boolean status.
+ * This might be called multiple times to retry a defined number of times.
+ */
private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
var clientSocketToSource: Socket = null
var oosSource: ObjectOutputStream = null
@@ -299,16 +222,13 @@ extends Broadcast[T] with Logging with Serializable {
var receptionSucceeded = false
try {
// Connect to the source to get the object itself
- clientSocketToSource =
- new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
- oosSource =
- new ObjectOutputStream(clientSocketToSource.getOutputStream)
+ clientSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
+ oosSource = new ObjectOutputStream(clientSocketToSource.getOutputStream)
oosSource.flush()
- oisSource =
- new ObjectInputStream(clientSocketToSource.getInputStream)
+ oisSource = new ObjectInputStream(clientSocketToSource.getInputStream)
- logInfo("Inside receiveSingleTransmission")
- logInfo("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
+ logDebug("Inside receiveSingleTransmission")
+ logDebug("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
// Send the range
oosSource.writeObject((hasBlocks, totalBlocks))
@@ -319,20 +239,17 @@ extends Broadcast[T] with Logging with Serializable {
val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
val receptionTime = (System.currentTimeMillis - recvStartTime)
- logInfo("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
+ logDebug("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
arrayOfBlocks(hasBlocks) = bcBlock
hasBlocks += 1
+
// Set to true if at least one block is received
receptionSucceeded = true
- hasBlocksLock.synchronized {
- hasBlocksLock.notifyAll()
- }
+ hasBlocksLock.synchronized { hasBlocksLock.notifyAll() }
}
} catch {
- case e: Exception => {
- logInfo("receiveSingleTransmission had a " + e)
- }
+ case e: Exception => logError("receiveSingleTransmission had a " + e)
} finally {
if (oisSource != null) {
oisSource.close()
@@ -361,24 +278,22 @@ extends Broadcast[T] with Logging with Serializable {
guidePort = serverSocket.getLocalPort
logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
- guidePortLock.synchronized {
- guidePortLock.notifyAll()
- }
+ guidePortLock.synchronized { guidePortLock.notifyAll() }
try {
- // Don't stop until there is a copy in HDFS
- while (!stopBroadcast || !hasCopyInHDFS) {
+ while (!stopBroadcast) {
var clientSocket: Socket = null
try {
- serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
+ serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
- logInfo("GuideMultipleRequests Timeout.")
+ logError("GuideMultipleRequests Timeout.")
// Stop broadcast if at least one worker has connected and
- // everyone connected so far are done. Comparing with
- // listOfSources.size - 1, because it includes the Guide itself
+ // everyone connected so far are done.
+ // Comparing with listOfSources.size - 1, because the Guide itself
+ // is included
if (listOfSources.size > 1 &&
setOfCompletedSources.size == listOfSources.size - 1) {
stopBroadcast = true
@@ -386,7 +301,7 @@ extends Broadcast[T] with Logging with Serializable {
}
}
if (clientSocket != null) {
- logInfo("Guide: Accepted new client connection: " + clientSocket)
+ logDebug("Guide: Accepted new client connection: " + clientSocket)
try {
threadPool.execute(new GuideSingleRequest(clientSocket))
} catch {
@@ -399,14 +314,13 @@ extends Broadcast[T] with Logging with Serializable {
logInfo("Sending stopBroadcast notifications...")
sendStopBroadcastNotifications
- TreeBroadcast.unregisterValue(uuid)
+ MultiTracker.unregisterBroadcast(uuid)
} finally {
if (serverSocket != null) {
logInfo("GuideMultipleRequests now stopping...")
serverSocket.close()
}
}
-
// Shutdown the thread pool
threadPool.shutdown()
}
@@ -423,21 +337,17 @@ extends Broadcast[T] with Logging with Serializable {
try {
// Connect to the source
- guideSocketToSource =
- new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
- gosSource =
- new ObjectOutputStream(guideSocketToSource.getOutputStream)
+ guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
+ gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream)
gosSource.flush()
- gisSource =
- new ObjectInputStream(guideSocketToSource.getInputStream)
+ gisSource = new ObjectInputStream(guideSocketToSource.getInputStream)
- // Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2
- gosSource.writeObject((SourceInfo.StopBroadcast,
- SourceInfo.StopBroadcast))
+ // Send stopBroadcast signal
+ gosSource.writeObject((SourceInfo.StopBroadcast, SourceInfo.StopBroadcast))
gosSource.flush()
} catch {
case e: Exception => {
- logInfo("sendStopBroadcastNotifications had a " + e)
+ logError("sendStopBroadcastNotifications had a " + e)
}
} finally {
if (gisSource != null) {
@@ -473,14 +383,14 @@ extends Broadcast[T] with Logging with Serializable {
listOfSources.synchronized {
// Select a suitable source and send it back to the worker
selectedSourceInfo = selectSuitableSource(sourceInfo)
- logInfo("Sending selectedSourceInfo: " + selectedSourceInfo)
+ logDebug("Sending selectedSourceInfo: " + selectedSourceInfo)
oos.writeObject(selectedSourceInfo)
oos.flush()
// Add this new (if it can finish) source to the list of sources
thisWorkerInfo = SourceInfo(sourceInfo.hostAddress,
- sourceInfo.listenPort, totalBlocks, totalBytes, blockSize)
- logInfo("Adding possible new source to listOfSources: " + thisWorkerInfo)
+ sourceInfo.listenPort, totalBlocks, totalBytes)
+ logDebug("Adding possible new source to listOfSources: " + thisWorkerInfo)
listOfSources += thisWorkerInfo
}
@@ -492,9 +402,9 @@ extends Broadcast[T] with Logging with Serializable {
// This should work since SourceInfo is a case class
assert(listOfSources.contains(selectedSourceInfo))
- // Remove first
+ // Remove first
+ // (Currently removing a source based on just one failure notification!)
listOfSources = listOfSources - selectedSourceInfo
- // TODO: Removing a source based on just one failure notification!
// Update sourceInfo and put it back in, IF reception succeeded
if (!sourceInfo.receptionFailed) {
@@ -503,17 +413,13 @@ extends Broadcast[T] with Logging with Serializable {
setOfCompletedSources += thisWorkerInfo
}
+ // Update leecher count and put it back in
selectedSourceInfo.currentLeechers -= 1
-
- // Put it back
listOfSources += selectedSourceInfo
}
}
} catch {
- // If something went wrong, e.g., the worker at the other end died etc.
- // then close() everything up
case e: Exception => {
- // Assuming that exception caused due to receiver worker failure.
// Remove failed worker from listOfSources and update leecherCount of
// corresponding source worker
listOfSources.synchronized {
@@ -532,27 +438,23 @@ extends Broadcast[T] with Logging with Serializable {
}
}
} finally {
+ logInfo("GuideSingleRequest is closing streams and sockets")
ois.close()
oos.close()
clientSocket.close()
}
}
- // FIXME: Caller must have a synchronized block on listOfSources
- // FIXME: If a worker fails to get the broadcasted variable from a source
- // and comes back to the Master, this function might choose the worker
- // itself as a source to create a dependency cycle (this worker was put
- // into listOfSources as a streming source when it first arrived). The
- // length of this cycle can be arbitrarily long.
+ // Assuming the caller to have a synchronized block on listOfSources
+ // Select one with the most leechers. This will level-wise fill the tree
private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
- // Select one with the most leechers. This will level-wise fill the tree
-
var maxLeechers = -1
var selectedSource: SourceInfo = null
listOfSources.foreach { source =>
- if (source != skipSourceInfo &&
- source.currentLeechers < Broadcast.MaxDegree &&
+ if ((source.hostAddress != skipSourceInfo.hostAddress ||
+ source.listenPort != skipSourceInfo.listenPort) &&
+ source.currentLeechers < MultiTracker.MaxDegree &&
source.currentLeechers > maxLeechers) {
selectedSource = source
maxLeechers = source.currentLeechers
@@ -561,7 +463,6 @@ extends Broadcast[T] with Logging with Serializable {
// Update leecher count
selectedSource.currentLeechers += 1
-
return selectedSource
}
}
@@ -569,35 +470,33 @@ extends Broadcast[T] with Logging with Serializable {
class ServeMultipleRequests
extends Thread with Logging {
- override def run() {
- var threadPool = Utils.newDaemonCachedThreadPool()
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket(0)
+
+ var threadPool = Utils.newDaemonCachedThreadPool()
+
+ override def run() {
+ var serverSocket = new ServerSocket(0)
listenPort = serverSocket.getLocalPort
+
logInfo("ServeMultipleRequests started with " + serverSocket)
- listenPortLock.synchronized {
- listenPortLock.notifyAll()
- }
+ listenPortLock.synchronized { listenPortLock.notifyAll() }
try {
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
- serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
+ serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
- case e: Exception => {
- logInfo("ServeMultipleRequests Timeout.")
- }
+ case e: Exception => logError("ServeMultipleRequests Timeout.")
}
+
if (clientSocket != null) {
- logInfo("Serve: Accepted new client connection: " + clientSocket)
+ logDebug("Serve: Accepted new client connection: " + clientSocket)
try {
threadPool.execute(new ServeSingleRequest(clientSocket))
} catch {
- // In failure, close() socket here; else, the thread will close() it
+ // In failure, close socket here; else, the thread will close it
case ioe: IOException => clientSocket.close()
}
}
@@ -608,7 +507,6 @@ extends Broadcast[T] with Logging with Serializable {
serverSocket.close()
}
}
-
// Shutdown the thread pool
threadPool.shutdown()
}
@@ -631,19 +529,14 @@ extends Broadcast[T] with Logging with Serializable {
sendFrom = rangeToSend._1
sendUntil = rangeToSend._2
- if (sendFrom == SourceInfo.StopBroadcast &&
- sendUntil == SourceInfo.StopBroadcast) {
+ // If not a valid range, stop broadcast
+ if (sendFrom == SourceInfo.StopBroadcast && sendUntil == SourceInfo.StopBroadcast) {
stopBroadcast = true
} else {
- // Carry on
sendObject
}
} catch {
- // If something went wrong, e.g., the worker at the other end died etc.
- // then close() everything up
- case e: Exception => {
- logInfo("ServeSingleRequest had a " + e)
- }
+ case e: Exception => logError("ServeSingleRequest had a " + e)
} finally {
logInfo("ServeSingleRequest is closing streams and sockets")
ois.close()
@@ -655,26 +548,20 @@ extends Broadcast[T] with Logging with Serializable {
private def sendObject() {
// Wait till receiving the SourceInfo from Master
while (totalBlocks == -1) {
- totalBlocksLock.synchronized {
- totalBlocksLock.wait()
- }
+ totalBlocksLock.synchronized { totalBlocksLock.wait() }
}
for (i <- sendFrom until sendUntil) {
while (i == hasBlocks) {
- hasBlocksLock.synchronized {
- hasBlocksLock.wait()
- }
+ hasBlocksLock.synchronized { hasBlocksLock.wait() }
}
try {
oos.writeObject(arrayOfBlocks(i))
oos.flush()
} catch {
- case e: Exception => {
- logInfo("sendObject had a " + e)
- }
+ case e: Exception => logError("sendObject had a " + e)
}
- logInfo("Sent block: " + i + " to " + clientSocket)
+ logDebug("Sent block: " + i + " to " + clientSocket)
}
}
}
@@ -683,124 +570,7 @@ extends Broadcast[T] with Logging with Serializable {
class TreeBroadcastFactory
extends BroadcastFactory {
- def initialize(isMaster: Boolean) = TreeBroadcast.initialize(isMaster)
- def newBroadcast[T](value_ : T, isLocal: Boolean) =
- new TreeBroadcast[T](value_, isLocal)
-}
-
-private object TreeBroadcast
-extends Logging {
- val values = SparkEnv.get.cache.newKeySpace()
-
- var valueToGuidePortMap = Map[UUID, Int]()
-
- // Random number generator
- var ranGen = new Random
-
- private var initialized = false
- private var isMaster_ = false
-
- private var trackMV: TrackMultipleValues = null
-
- private var MaxDegree_ : Int = 2
-
- def initialize(isMaster__ : Boolean) {
- synchronized {
- if (!initialized) {
- isMaster_ = isMaster__
-
- if (isMaster) {
- trackMV = new TrackMultipleValues
- trackMV.setDaemon(true)
- trackMV.start()
- // TODO: Logging the following line makes the Spark framework ID not
- // getting logged, cause it calls logInfo before log4j is initialized
- logInfo("TrackMultipleValues started...")
- }
-
- // Initialize DfsBroadcast to be used for broadcast variable persistence
- DfsBroadcast.initialize
-
- initialized = true
- }
- }
- }
-
- def isMaster = isMaster_
-
- def registerValue(uuid: UUID, guidePort: Int) {
- valueToGuidePortMap.synchronized {
- valueToGuidePortMap += (uuid -> guidePort)
- logInfo("New value registered with the Tracker " + valueToGuidePortMap)
- }
- }
-
- def unregisterValue(uuid: UUID) {
- valueToGuidePortMap.synchronized {
- valueToGuidePortMap(uuid) = SourceInfo.TxOverGoToHDFS
- logInfo("Value unregistered from the Tracker " + valueToGuidePortMap)
- }
- }
-
- class TrackMultipleValues
- extends Thread with Logging {
- override def run() {
- var threadPool = Utils.newDaemonCachedThreadPool()
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket(Broadcast.MasterTrackerPort)
- logInfo("TrackMultipleValues" + serverSocket)
-
- try {
- while (true) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(Broadcast.TrackerSocketTimeout)
- clientSocket = serverSocket.accept
- } catch {
- case e: Exception => {
- logInfo("TrackMultipleValues Timeout. Stopping listening...")
- }
- }
-
- if (clientSocket != null) {
- try {
- threadPool.execute(new Thread {
- override def run() {
- val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- val ois = new ObjectInputStream(clientSocket.getInputStream)
- try {
- val uuid = ois.readObject.asInstanceOf[UUID]
- var guidePort =
- if (valueToGuidePortMap.contains(uuid)) {
- valueToGuidePortMap(uuid)
- } else SourceInfo.TxNotStartedRetry
- logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort)
- oos.writeObject(guidePort)
- } catch {
- case e: Exception => {
- logInfo("TrackMultipleValues had a " + e)
- }
- } finally {
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
- })
- } catch {
- // In failure, close() socket here; else, client thread will close()
- case ioe: IOException => clientSocket.close()
- }
- }
- }
- } finally {
- serverSocket.close()
- }
-
- // Shutdown the thread pool
- threadPool.shutdown()
- }
- }
+ def initialize(isMaster: Boolean) = MultiTracker.initialize(isMaster)
+ def newBroadcast[T](value_ : T, isLocal: Boolean) = new TreeBroadcast[T](value_, isLocal)
+ def stop() = MultiTracker.stop
}
diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala
index e3958cec51..9e335c25f7 100644
--- a/core/src/main/scala/spark/executor/Executor.scala
+++ b/core/src/main/scala/spark/executor/Executor.scala
@@ -35,8 +35,6 @@ class Executor extends Logging {
// Initialize Spark environment (using system properties read above)
env = SparkEnv.createFromSystemProperties(slaveHostname, 0, false, false)
SparkEnv.set(env)
- // Old stuff that isn't yet using env
- Broadcast.initialize(false)
// Create our ClassLoader (using spark properties) and set it on this thread
classLoader = createClassLoader()