diff options
Diffstat (limited to 'core')
10 files changed, 564 insertions, 1814 deletions
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index c3e93f964e..bfd3e8d732 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -65,7 +65,7 @@ class SparkContext( System.setProperty("spark.master.port", "0") } - private val isLocal = (master == "local" || master.startsWith("local[")) + private val isLocal = (master == "local" || master.startsWith("local[")) && !master.startsWith("localhost") // Create the Spark execution environment (cache, map output tracker, etc) val env = SparkEnv.createFromSystemProperties( diff --git a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala index e009d4e7db..3466d663af 100644 --- a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala @@ -2,7 +2,7 @@ package spark.broadcast import java.io._ import java.net._ -import java.util.{BitSet, Comparator, Random, Timer, TimerTask, UUID} +import java.util.{BitSet, Comparator, Timer, TimerTask, UUID} import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ListBuffer, Map, Set} @@ -15,8 +15,8 @@ extends Broadcast[T] with Logging with Serializable { def value = value_ - BitTorrentBroadcast.synchronized { - BitTorrentBroadcast.values.put(uuid, 0, value_) + Broadcast.synchronized { + Broadcast.values.put(uuid, 0, value_) } @transient var arrayOfBlocks: Array[BroadcastBlock] = null @@ -25,8 +25,6 @@ extends Broadcast[T] with Logging with Serializable { @transient var totalBytes = -1 @transient var totalBlocks = -1 @transient var hasBlocks = new AtomicInteger(0) - // CHANGED: BlockSize in the Broadcast object is expected to change over time - @transient var blockSize = Broadcast.BlockSize // Used ONLY by Master to track how many unique blocks have been sent out @transient var sentBlocks = new AtomicInteger(0) @@ -45,14 +43,10 @@ extends Broadcast[T] with Logging with Serializable { // Used only in Workers @transient var ttGuide: TalkToGuide = null - @transient var rxSpeeds = new SpeedTracker - @transient var txSpeeds = new SpeedTracker - @transient var hostAddress = Utils.localIpAddress @transient var listenPort = -1 @transient var guidePort = -1 - @transient var hasCopyInHDFS = false @transient var stopBroadcast = false // Must call this after all the variables have been created/initialized @@ -63,19 +57,10 @@ extends Broadcast[T] with Logging with Serializable { def sendBroadcast() { logInfo("Local host address: " + hostAddress) - // Store a persistent copy in HDFS - // TODO: Turned OFF for now. Related to persistence - // val out = new ObjectOutputStream(BroadcastCH.openFileForWriting(uuid)) - // out.writeObject(value_) - // out.close() - // FIXME: Fix this at some point - hasCopyInHDFS = true - // Create a variableInfo object and store it in valueInfos - var variableInfo = Broadcast.blockifyObject(value_) + var variableInfo = MultiTracker.blockifyObject(value_) // Prepare the value being broadcasted - // TODO: Refactoring and clean-up required here arrayOfBlocks = variableInfo.arrayOfBlocks totalBytes = variableInfo.totalBytes totalBlocks = variableInfo.totalBlocks @@ -95,9 +80,7 @@ extends Broadcast[T] with Logging with Serializable { // Must always come AFTER guideMR is created while (guidePort == -1) { - guidePortLock.synchronized { - guidePortLock.wait() - } + guidePortLock.synchronized { guidePortLock.wait() } } serveMR = new ServeMultipleRequests @@ -107,14 +90,12 @@ extends Broadcast[T] with Logging with Serializable { // Must always come AFTER serveMR is created while (listenPort == -1) { - listenPortLock.synchronized { - listenPortLock.wait() - } + listenPortLock.synchronized { listenPortLock.wait() } } // Must always come AFTER listenPort is created val masterSource = - SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes, blockSize) + SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes) hasBlocksBitVector.synchronized { masterSource.hasBlocksBitVector = hasBlocksBitVector } @@ -123,19 +104,20 @@ extends Broadcast[T] with Logging with Serializable { listOfSources += masterSource // Register with the Tracker - registerBroadcast(uuid, - SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes, blockSize)) + MultiTracker.registerBroadcast(uuid, + SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes)) } private def readObject(in: ObjectInputStream) { in.defaultReadObject() - BitTorrentBroadcast.synchronized { - val cachedVal = BitTorrentBroadcast.values.get(uuid, 0) + Broadcast.synchronized { + val cachedVal = Broadcast.values.get(uuid, 0) if (cachedVal != null) { value_ = cachedVal.asInstanceOf[T] } else { - // Only the first worker in a node can ever be inside this 'else' + // Initializing everything because Master will only send null/0 values + // Only the 1st worker in a node can be here. Others will get from cache initializeWorkerVariables logInfo("Local host address: " + hostAddress) @@ -149,16 +131,11 @@ extends Broadcast[T] with Logging with Serializable { val start = System.nanoTime val receptionSucceeded = receiveBroadcast(uuid) - // If does not succeed, then get from HDFS copy if (receptionSucceeded) { - value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - BitTorrentBroadcast.values.put(uuid, 0, value_) + value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) + Broadcast.values.put(uuid, 0, value_) } else { - // TODO: This part won't work, cause HDFS writing is turned OFF - val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid)) - value_ = fileIn.readObject.asInstanceOf[T] - BitTorrentBroadcast.values.put(uuid, 0, value_) - fileIn.close() + logError("Reading Broadcasted variable " + uuid + " failed") } val time = (System.nanoTime - start) / 1e9 @@ -175,7 +152,6 @@ extends Broadcast[T] with Logging with Serializable { totalBytes = -1 totalBlocks = -1 hasBlocks = new AtomicInteger(0) - blockSize = -1 listenPortLock = new Object totalBlocksLock = new Object @@ -183,9 +159,6 @@ extends Broadcast[T] with Logging with Serializable { serveMR = null ttGuide = null - rxSpeeds = new SpeedTracker - txSpeeds = new SpeedTracker - hostAddress = Utils.localIpAddress listenPort = -1 @@ -194,75 +167,19 @@ extends Broadcast[T] with Logging with Serializable { stopBroadcast = false } - private def registerBroadcast(uuid: UUID, gInfo: SourceInfo) { - val socket = new Socket(Broadcast.MasterHostAddress, - Broadcast.MasterTrackerPort) - val oosST = new ObjectOutputStream(socket.getOutputStream) - oosST.flush() - val oisST = new ObjectInputStream(socket.getInputStream) - - // Send messageType/intention - oosST.writeObject(Broadcast.REGISTER_BROADCAST_TRACKER) - oosST.flush() - - // Send UUID of this broadcast - oosST.writeObject(uuid) - oosST.flush() - - // Send this tracker's information - oosST.writeObject(gInfo) - oosST.flush() - - // Receive ACK and throw it away - oisST.readObject.asInstanceOf[Int] - - // Shut stuff down - oisST.close() - oosST.close() - socket.close() - } - - private def unregisterBroadcast(uuid: UUID) { - val socket = new Socket(Broadcast.MasterHostAddress, - Broadcast.MasterTrackerPort) - val oosST = new ObjectOutputStream(socket.getOutputStream) - oosST.flush() - val oisST = new ObjectInputStream(socket.getInputStream) - - // Send messageType/intention - oosST.writeObject(Broadcast.UNREGISTER_BROADCAST_TRACKER) - oosST.flush() - - // Send UUID of this broadcast - oosST.writeObject(uuid) - oosST.flush() - - // Receive ACK and throw it away - oisST.readObject.asInstanceOf[Int] - - // Shut stuff down - oisST.close() - oosST.close() - socket.close() - } - private def getLocalSourceInfo: SourceInfo = { // Wait till hostName and listenPort are OK while (listenPort == -1) { - listenPortLock.synchronized { - listenPortLock.wait() - } + listenPortLock.synchronized { listenPortLock.wait() } } // Wait till totalBlocks and totalBytes are OK while (totalBlocks == -1) { - totalBlocksLock.synchronized { - totalBlocksLock.wait() - } + totalBlocksLock.synchronized { totalBlocksLock.wait() } } var localSourceInfo = SourceInfo( - hostAddress, listenPort, totalBlocks, totalBytes, blockSize) + hostAddress, listenPort, totalBlocks, totalBytes) localSourceInfo.hasBlocks = hasBlocks.get @@ -274,7 +191,7 @@ extends Broadcast[T] with Logging with Serializable { } // Add new SourceInfo to the listOfSources. Update if it exists already. - // TODO: Optimizing just by OR-ing the BitVectors was BAD for performance + // Optimizing just by OR-ing the BitVectors was BAD for performance private def addToListOfSources(newSourceInfo: SourceInfo) { listOfSources.synchronized { if (listOfSources.contains(newSourceInfo)) { @@ -297,9 +214,9 @@ extends Broadcast[T] with Logging with Serializable { // Keep exchaning information until all blocks have been received while (hasBlocks.get < totalBlocks) { talkOnce - Thread.sleep(BitTorrentBroadcast.ranGen.nextInt( - Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) + - Broadcast.MinKnockInterval) + Thread.sleep(MultiTracker.ranGen.nextInt( + MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) + + MultiTracker.MinKnockInterval) } // Talk one more time to let the Guide know of reception completion @@ -334,76 +251,17 @@ extends Broadcast[T] with Logging with Serializable { } } - def getGuideInfo(variableUUID: UUID): SourceInfo = { - var clientSocketToTracker: Socket = null - var oosTracker: ObjectOutputStream = null - var oisTracker: ObjectInputStream = null - - var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxOverGoToHDFS) - - var retriesLeft = Broadcast.MaxRetryCount - do { - try { - // Connect to the tracker to find out GuideInfo - clientSocketToTracker = - new Socket(Broadcast.MasterHostAddress, Broadcast.MasterTrackerPort) - oosTracker = - new ObjectOutputStream(clientSocketToTracker.getOutputStream) - oosTracker.flush() - oisTracker = - new ObjectInputStream(clientSocketToTracker.getInputStream) - - // Send messageType/intention - oosTracker.writeObject(Broadcast.FIND_BROADCAST_TRACKER) - oosTracker.flush() - - // Send UUID and receive GuideInfo - oosTracker.writeObject(uuid) - oosTracker.flush() - gInfo = oisTracker.readObject.asInstanceOf[SourceInfo] - } catch { - case e: Exception => { - logInfo("getGuideInfo had a " + e) - } - } finally { - if (oisTracker != null) { - oisTracker.close() - } - if (oosTracker != null) { - oosTracker.close() - } - if (clientSocketToTracker != null) { - clientSocketToTracker.close() - } - } - - Thread.sleep(BitTorrentBroadcast.ranGen.nextInt( - Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) + - Broadcast.MinKnockInterval) - - retriesLeft -= 1 - } while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry) - - logInfo("Got this guidePort from Tracker: " + gInfo.listenPort) - return gInfo - } - def receiveBroadcast(variableUUID: UUID): Boolean = { - val gInfo = getGuideInfo(variableUUID) + val gInfo = MultiTracker.getGuideInfo(variableUUID) - if (gInfo.listenPort == SourceInfo.TxOverGoToHDFS || - gInfo.listenPort == SourceInfo.TxNotStartedRetry) { - // TODO: SourceInfo.TxNotStartedRetry is not really in use because we go - // to HDFS anyway when receiveBroadcast returns false + if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) { return false } // Wait until hostAddress and listenPort are created by the // ServeMultipleRequests thread while (listenPort == -1) { - listenPortLock.synchronized { - listenPortLock.wait() - } + listenPortLock.synchronized { listenPortLock.wait() } } // Setup initial states of variables @@ -411,11 +269,8 @@ extends Broadcast[T] with Logging with Serializable { arrayOfBlocks = new Array[BroadcastBlock](totalBlocks) hasBlocksBitVector = new BitSet(totalBlocks) numCopiesSent = new Array[Int](totalBlocks) - totalBlocksLock.synchronized { - totalBlocksLock.notifyAll() - } + totalBlocksLock.synchronized { totalBlocksLock.notifyAll() } totalBytes = gInfo.totalBytes - blockSize = gInfo.blockSize // Start ttGuide to periodically talk to the Guide var ttGuide = new TalkToGuide(gInfo) @@ -432,7 +287,7 @@ extends Broadcast[T] with Logging with Serializable { // FIXME: Must fix this. This might never break if broadcast fails. // We should be able to break and send false. Also need to kill threads while (hasBlocks.get < totalBlocks) { - Thread.sleep(Broadcast.MaxKnockInterval) + Thread.sleep(MultiTracker.MaxKnockInterval) } return true @@ -446,11 +301,11 @@ extends Broadcast[T] with Logging with Serializable { private var blocksInRequestBitVector = new BitSet(totalBlocks) override def run() { - var threadPool = Utils.newDaemonFixedThreadPool(Broadcast.MaxRxSlots) + var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots) while (hasBlocks.get < totalBlocks) { var numThreadsToCreate = - math.min(listOfSources.size, Broadcast.MaxRxSlots) - + math.min(listOfSources.size, MultiTracker.MaxChatSlots) - threadPool.getActiveCount while (hasBlocks.get < totalBlocks && numThreadsToCreate > 0) { @@ -466,16 +321,14 @@ extends Broadcast[T] with Logging with Serializable { // Add to peersNowTalking. Remove in the thread. We have to do this // ASAP, otherwise pickPeerToTalkTo picks the same peer more than once - peersNowTalking.synchronized { - peersNowTalking += peerToTalkTo - } + peersNowTalking.synchronized { peersNowTalking += peerToTalkTo } } numThreadsToCreate = numThreadsToCreate - 1 } // Sleep for a while before starting some more threads - Thread.sleep(Broadcast.MinKnockInterval) + Thread.sleep(MultiTracker.MinKnockInterval) } // Shutdown the thread pool threadPool.shutdown() @@ -512,11 +365,10 @@ extends Broadcast[T] with Logging with Serializable { } } - // TODO: Always pick randomly or randomly pick randomly? - // Now always picking randomly + // Always picking randomly if (curPeer == null && peersNotInUse.size > 0) { // Pick uniformly the i'th required peer - var i = BitTorrentBroadcast.ranGen.nextInt(peersNotInUse.size) + var i = MultiTracker.ranGen.nextInt(peersNotInUse.size) var peerIter = peersNotInUse.iterator curPeer = peerIter.next @@ -552,8 +404,8 @@ extends Broadcast[T] with Logging with Serializable { } } - // TODO: A block is rare if there are at most 2 copies of that block - // TODO: This CONSTANT could be a function of the neighborhood size + // A block is considered rare if there are at most 2 copies of that block + // This CONSTANT could be a function of the neighborhood size var rareBlocksIndices = ListBuffer[Int]() for (i <- 0 until totalBlocks) { if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) <= 2) { @@ -587,7 +439,7 @@ extends Broadcast[T] with Logging with Serializable { // Sort the peers based on how many rare blocks they have peersWithRareBlocks.sortBy(_._2) - var randomNumber = BitTorrentBroadcast.ranGen.nextDouble + var randomNumber = MultiTracker.ranGen.nextDouble var tempSum = 0.0 var i = 0 @@ -625,7 +477,7 @@ extends Broadcast[T] with Logging with Serializable { } var timeOutTimer = new Timer - timeOutTimer.schedule(timeOutTask, Broadcast.MaxKnockInterval) + timeOutTimer.schedule(timeOutTask, MultiTracker.MaxKnockInterval) logInfo("TalkToPeer started... => " + peerToTalkTo) @@ -688,8 +540,6 @@ extends Broadcast[T] with Logging with Serializable { hasBlocks.getAndIncrement } - rxSpeeds.addDataPoint(peerToTalkTo, receptionTime) - // Some block(may NOT be blockToAskFor) has arrived. // In any case, blockToAskFor is not in request any more blocksInRequestBitVector.synchronized { @@ -741,8 +591,8 @@ extends Broadcast[T] with Logging with Serializable { } // Include blocks already in transmission ONLY IF - // BitTorrentBroadcast.EndGameFraction has NOT been achieved - if ((1.0 * hasBlocks.get / totalBlocks) < Broadcast.EndGameFraction) { + // MultiTracker.EndGameFraction has NOT been achieved + if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) { blocksInRequestBitVector.synchronized { needBlocksBitVector.or(blocksInRequestBitVector) } @@ -758,7 +608,7 @@ extends Broadcast[T] with Logging with Serializable { return -1 } else { // Pick uniformly the i'th required block - var i = BitTorrentBroadcast.ranGen.nextInt(needBlocksBitVector.cardinality) + var i = MultiTracker.ranGen.nextInt(needBlocksBitVector.cardinality) var pickedBlockIndex = needBlocksBitVector.nextSetBit(0) while (i > 0) { @@ -781,8 +631,8 @@ extends Broadcast[T] with Logging with Serializable { } // Include blocks already in transmission ONLY IF - // BitTorrentBroadcast.EndGameFraction has NOT been achieved - if ((1.0 * hasBlocks.get / totalBlocks) < Broadcast.EndGameFraction) { + // MultiTracker.EndGameFraction has NOT been achieved + if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) { blocksInRequestBitVector.synchronized { needBlocksBitVector.or(blocksInRequestBitVector) } @@ -830,7 +680,7 @@ extends Broadcast[T] with Logging with Serializable { return -1 } else { // Pick uniformly the i'th index - var i = BitTorrentBroadcast.ranGen.nextInt(minBlocksIndices.size) + var i = MultiTracker.ranGen.nextInt(minBlocksIndices.size) return minBlocksIndices(i) } } @@ -848,9 +698,7 @@ extends Broadcast[T] with Logging with Serializable { } // Delete from peersNowTalking - peersNowTalking.synchronized { - peersNowTalking = peersNowTalking - peerToTalkTo - } + peersNowTalking.synchronized { peersNowTalking -= peerToTalkTo } } } } @@ -868,16 +716,14 @@ extends Broadcast[T] with Logging with Serializable { guidePort = serverSocket.getLocalPort logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort) - guidePortLock.synchronized { - guidePortLock.notifyAll() - } + guidePortLock.synchronized { guidePortLock.notifyAll() } try { // Don't stop until there is a copy in HDFS - while (!stopBroadcast || !hasCopyInHDFS) { + while (!stopBroadcast) { var clientSocket: Socket = null try { - serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout) + serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) clientSocket = serverSocket.accept() } catch { case e: Exception => { @@ -911,7 +757,7 @@ extends Broadcast[T] with Logging with Serializable { logInfo("Sending stopBroadcast notifications...") sendStopBroadcastNotifications - unregisterBroadcast(uuid) + MultiTracker.unregisterBroadcast(uuid) } finally { if (serverSocket != null) { logInfo("GuideMultipleRequests now stopping...") @@ -930,13 +776,10 @@ extends Broadcast[T] with Logging with Serializable { try { // Connect to the source - guideSocketToSource = - new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - gosSource = - new ObjectOutputStream(guideSocketToSource.getOutputStream) + guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) + gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream) gosSource.flush() - gisSource = - new ObjectInputStream(guideSocketToSource.getInputStream) + gisSource = new ObjectInputStream(guideSocketToSource.getInputStream) // Throw away whatever comes in gisSource.readObject.asInstanceOf[SourceInfo] @@ -990,9 +833,7 @@ extends Broadcast[T] with Logging with Serializable { case e: Exception => { // Assuming exception caused by receiver failure: remove if (listOfSources != null) { - listOfSources.synchronized { - listOfSources = listOfSources - sourceInfo - } + listOfSources.synchronized { listOfSources -= sourceInfo } } } } finally { @@ -1009,24 +850,22 @@ extends Broadcast[T] with Logging with Serializable { // If skipSourceInfo.hasBlocksBitVector has all bits set to 'true' // then add skipSourceInfo to setOfCompletedSources. Return blank. if (skipSourceInfo.hasBlocks == totalBlocks) { - setOfCompletedSources.synchronized { - setOfCompletedSources += skipSourceInfo - } + setOfCompletedSources.synchronized { setOfCompletedSources += skipSourceInfo } return selectedSources } listOfSources.synchronized { - if (listOfSources.size <= Broadcast.MaxPeersInGuideResponse) { + if (listOfSources.size <= MultiTracker.MaxPeersInGuideResponse) { selectedSources = listOfSources.clone } else { - var picksLeft = Broadcast.MaxPeersInGuideResponse + var picksLeft = MultiTracker.MaxPeersInGuideResponse var alreadyPicked = new BitSet(listOfSources.size) while (picksLeft > 0) { var i = -1 do { - i = BitTorrentBroadcast.ranGen.nextInt(listOfSources.size) + i = MultiTracker.ranGen.nextInt(listOfSources.size) } while (alreadyPicked.get(i)) var peerIter = listOfSources.iterator @@ -1057,8 +896,8 @@ extends Broadcast[T] with Logging with Serializable { class ServeMultipleRequests extends Thread with Logging { - // Server at most Broadcast.MaxTxSlots peers - var threadPool = Utils.newDaemonFixedThreadPool(Broadcast.MaxTxSlots) + // Server at most MultiTracker.MaxChatSlots peers + var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots) override def run() { var serverSocket = new ServerSocket(0) @@ -1066,15 +905,13 @@ extends Broadcast[T] with Logging with Serializable { logInfo("ServeMultipleRequests started with " + serverSocket) - listenPortLock.synchronized { - listenPortLock.notifyAll() - } + listenPortLock.synchronized { listenPortLock.notifyAll() } try { while (!stopBroadcast) { var clientSocket: Socket = null try { - serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout) + serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) clientSocket = serverSocket.accept() } catch { case e: Exception => { @@ -1087,9 +924,7 @@ extends Broadcast[T] with Logging with Serializable { threadPool.execute(new ServeSingleRequest(clientSocket)) } catch { // In failure, close socket here; else, the thread will close it - case ioe: IOException => { - clientSocket.close() - } + case ioe: IOException => clientSocket.close() } } } @@ -1125,14 +960,13 @@ extends Broadcast[T] with Logging with Serializable { if (rxSourceInfo.listenPort == SourceInfo.StopBroadcast) { stopBroadcast = true } else { - // Carry on addToListOfSources(rxSourceInfo) } val startTime = System.currentTimeMillis var curTime = startTime var keepSending = true - var numBlocksToSend = Broadcast.MaxChatBlocks + var numBlocksToSend = MultiTracker.MaxChatBlocks while (!stopBroadcast && keepSending && numBlocksToSend > 0) { // Receive which block to send @@ -1140,7 +974,7 @@ extends Broadcast[T] with Logging with Serializable { // If it is master AND at least one copy of each block has not been // sent out already, MODIFY blockToSend - if (BitTorrentBroadcast.isMaster && sentBlocks.get < totalBlocks) { + if (MultiTracker.isMaster && sentBlocks.get < totalBlocks) { blockToSend = sentBlocks.getAndIncrement } @@ -1157,22 +991,16 @@ extends Broadcast[T] with Logging with Serializable { curTime = System.currentTimeMillis // Revoke sending only if there is anyone waiting in the queue - if (curTime - startTime >= Broadcast.MaxChatTime && + if (curTime - startTime >= MultiTracker.MaxChatTime && threadPool.getQueue.size > 0) { keepSending = false } } } catch { - // If something went wrong, e.g., the worker at the other end died etc. - // then close everything up - // Exception can happen if the receiver stops receiving - case e: Exception => { - logInfo("ServeSingleRequest had a " + e) - } + case e: Exception => logInfo("ServeSingleRequest had a " + e) } finally { logInfo("ServeSingleRequest is closing streams and sockets") ois.close() - // TODO: The following line causes a "java.net.SocketException: Socket closed" oos.close() clientSocket.close() } @@ -1183,9 +1011,7 @@ extends Broadcast[T] with Logging with Serializable { oos.writeObject(arrayOfBlocks(blockToSend)) oos.flush() } catch { - case e: Exception => { - logInfo("sendBlock had a " + e) - } + case e: Exception => logInfo("sendBlock had a " + e) } logInfo("Sent block: " + blockToSend + " to " + clientSocket) } @@ -1195,161 +1021,7 @@ extends Broadcast[T] with Logging with Serializable { class BitTorrentBroadcastFactory extends BroadcastFactory { - def initialize(isMaster: Boolean) { - BitTorrentBroadcast.initialize(isMaster) - } - - def newBroadcast[T](value_ : T, isLocal: Boolean) = { - new BitTorrentBroadcast[T](value_, isLocal) - } -} - -private object BitTorrentBroadcast -extends Logging { - val values = SparkEnv.get.cache.newKeySpace() - - var valueToGuideMap = Map[UUID, SourceInfo]() - - // Random number generator - var ranGen = new Random - - private var initialized = false - private var isMaster_ = false - - private var trackMV: TrackMultipleValues = null - - def initialize(isMaster__ : Boolean) { - synchronized { - if (!initialized) { - - isMaster_ = isMaster__ - - if (isMaster) { - trackMV = new TrackMultipleValues - trackMV.setDaemon(true) - trackMV.start() - // TODO: Logging the following line makes the Spark framework ID not - // getting logged, cause it calls logInfo before log4j is initialized - logInfo("TrackMultipleValues started...") - } - - // Initialize DfsBroadcast to be used for broadcast variable persistence - // TODO: Think about persistence - DfsBroadcast.initialize - - initialized = true - } - } - } - - def isMaster = isMaster_ - - class TrackMultipleValues - extends Thread with Logging { - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(Broadcast.MasterTrackerPort) - logInfo("TrackMultipleValues" + serverSocket) - - try { - while (true) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(Broadcast.TrackerSocketTimeout) - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { - logInfo("TrackMultipleValues Timeout. Stopping listening...") - } - } - - if (clientSocket != null) { - try { - threadPool.execute(new Thread { - override def run() { - val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - val ois = new ObjectInputStream(clientSocket.getInputStream) - - try { - // First, read message type - val messageType = ois.readObject.asInstanceOf[Int] - - if (messageType == Broadcast.REGISTER_BROADCAST_TRACKER) { - // Receive UUID - val uuid = ois.readObject.asInstanceOf[UUID] - // Receive hostAddress and listenPort - val gInfo = ois.readObject.asInstanceOf[SourceInfo] - - // Add to the map - valueToGuideMap.synchronized { - valueToGuideMap += (uuid -> gInfo) - } - - logInfo ("New broadcast registered with TrackMultipleValues " + uuid + " " + valueToGuideMap) - - // Send dummy ACK - oos.writeObject(-1) - oos.flush() - } else if (messageType == Broadcast.UNREGISTER_BROADCAST_TRACKER) { - // Receive UUID - val uuid = ois.readObject.asInstanceOf[UUID] - - // Remove from the map - valueToGuideMap.synchronized { - valueToGuideMap(uuid) = SourceInfo("", SourceInfo.TxOverGoToHDFS) - logInfo("Value unregistered from the Tracker " + valueToGuideMap) - } - - logInfo ("Broadcast unregistered from TrackMultipleValues " + uuid + " " + valueToGuideMap) - - // Send dummy ACK - oos.writeObject(-1) - oos.flush() - } else if (messageType == Broadcast.FIND_BROADCAST_TRACKER) { - // Receive UUID - val uuid = ois.readObject.asInstanceOf[UUID] - - var gInfo = - if (valueToGuideMap.contains(uuid)) valueToGuideMap(uuid) - else SourceInfo("", SourceInfo.TxNotStartedRetry) - - logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + gInfo.listenPort) - - // Send reply back - oos.writeObject(gInfo) - oos.flush() - } else if (messageType == Broadcast.GET_UPDATED_SHARE) { - // TODO: Not implemented - } else { - throw new SparkException("Undefined messageType at TrackMultipleValues") - } - } catch { - case e: Exception => { - logInfo("TrackMultipleValues had a " + e) - } - } finally { - ois.close() - oos.close() - clientSocket.close() - } - } - }) - } catch { - // In failure, close socket here; else, client thread will close - case ioe: IOException => { - clientSocket.close() - } - } - } - } - } finally { - serverSocket.close() - } - // Shutdown the thread pool - threadPool.shutdown() - } - } + def initialize(isMaster: Boolean) = MultiTracker.initialize(isMaster) + def newBroadcast[T](value_ : T, isLocal: Boolean) = new BitTorrentBroadcast[T](value_, isLocal) + def stop() = MultiTracker.stop } diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala index eaa9153279..135bc31b72 100644 --- a/core/src/main/scala/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/spark/broadcast/Broadcast.scala @@ -5,6 +5,8 @@ import java.net._ import java.util.{BitSet, UUID} import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor} +import scala.collection.mutable.Map + import spark._ trait Broadcast[T] extends Serializable { @@ -13,24 +15,22 @@ trait Broadcast[T] extends Serializable { def value: T // We cannot have an abstract readObject here due to some weird issues with - // readObject having to be 'private' in sub-classes. Possibly a Scala bug! + // readObject having to be 'private' in sub-classes. override def toString = "spark.Broadcast(" + uuid + ")" } object Broadcast extends Logging with Serializable { - // Messages - val REGISTER_BROADCAST_TRACKER = 0 - val UNREGISTER_BROADCAST_TRACKER = 1 - val FIND_BROADCAST_TRACKER = 2 - val GET_UPDATED_SHARE = 3 private var initialized = false private var isMaster_ = false private var broadcastFactory: BroadcastFactory = null + // Cache of broadcasted objects + val values = SparkEnv.get.cache.newKeySpace() + // Called by SparkContext or Executor before using Broadcast - def initialize (isMaster__ : Boolean) { + def initialize(isMaster__ : Boolean) { synchronized { if (!initialized) { val broadcastFactoryClass = System.getProperty( @@ -55,6 +55,10 @@ object Broadcast extends Logging with Serializable { } } + def stop() { + broadcastFactory.stop() + } + def getBroadcastFactory: BroadcastFactory = { if (broadcastFactory == null) { throw new SparkException ("Broadcast.getBroadcastFactory called before initialize") @@ -62,163 +66,10 @@ object Broadcast extends Logging with Serializable { broadcastFactory } - // Load common broadcast-related config parameters private var MasterHostAddress_ = System.getProperty( "spark.broadcast.masterHostAddress", "") - private var MasterTrackerPort_ = System.getProperty( - "spark.broadcast.masterTrackerPort", "11111").toInt - private var BlockSize_ = System.getProperty( - "spark.broadcast.blockSize", "4096").toInt * 1024 - private var MaxRetryCount_ = System.getProperty( - "spark.broadcast.maxRetryCount", "2").toInt - - private var TrackerSocketTimeout_ = System.getProperty( - "spark.broadcast.trackerSocketTimeout", "50000").toInt - private var ServerSocketTimeout_ = System.getProperty( - "spark.broadcast.serverSocketTimeout", "10000").toInt - - private var MinKnockInterval_ = System.getProperty( - "spark.broadcast.minKnockInterval", "500").toInt - private var MaxKnockInterval_ = System.getProperty( - "spark.broadcast.maxKnockInterval", "999").toInt - - // Load ChainedBroadcast config params - - // Load TreeBroadcast config params - private var MaxDegree_ = System.getProperty("spark.broadcast.maxDegree", "2").toInt - - // Load BitTorrentBroadcast config params - private var MaxPeersInGuideResponse_ = System.getProperty( - "spark.broadcast.maxPeersInGuideResponse", "4").toInt - - private var MaxRxSlots_ = System.getProperty("spark.broadcast.maxRxSlots", "4").toInt - private var MaxTxSlots_ = System.getProperty("spark.broadcast.maxTxSlots", "4").toInt - - private var MaxChatTime_ = System.getProperty("spark.broadcast.maxChatTime", "500").toInt - private var MaxChatBlocks_ = System.getProperty("spark.broadcast.maxChatBlocks", "1024").toInt - - private var EndGameFraction_ = System.getProperty( - "spark.broadcast.endGameFraction", "0.95").toDouble def isMaster = isMaster_ - - // Common config params - def MasterHostAddress = MasterHostAddress_ - def MasterTrackerPort = MasterTrackerPort_ - def BlockSize = BlockSize_ - def MaxRetryCount = MaxRetryCount_ - - def TrackerSocketTimeout = TrackerSocketTimeout_ - def ServerSocketTimeout = ServerSocketTimeout_ - - def MinKnockInterval = MinKnockInterval_ - def MaxKnockInterval = MaxKnockInterval_ - - // ChainedBroadcast configs - - // TreeBroadcast configs - def MaxDegree = MaxDegree_ - - // BitTorrentBroadcast configs - def MaxPeersInGuideResponse = MaxPeersInGuideResponse_ - - def MaxRxSlots = MaxRxSlots_ - def MaxTxSlots = MaxTxSlots_ - - def MaxChatTime = MaxChatTime_ - def MaxChatBlocks = MaxChatBlocks_ - - def EndGameFraction = EndGameFraction_ - - // Helper functions to convert an object to Array[BroadcastBlock] - def blockifyObject[IN](obj: IN): VariableInfo = { - val baos = new ByteArrayOutputStream - val oos = new ObjectOutputStream(baos) - oos.writeObject(obj) - oos.close() - baos.close() - val byteArray = baos.toByteArray - val bais = new ByteArrayInputStream(byteArray) - - var blockNum = (byteArray.length / Broadcast.BlockSize) - if (byteArray.length % Broadcast.BlockSize != 0) - blockNum += 1 - - var retVal = new Array[BroadcastBlock](blockNum) - var blockID = 0 - - for (i <- 0 until (byteArray.length, Broadcast.BlockSize)) { - val thisBlockSize = math.min(Broadcast.BlockSize, byteArray.length - i) - var tempByteArray = new Array[Byte](thisBlockSize) - val hasRead = bais.read(tempByteArray, 0, thisBlockSize) - - retVal(blockID) = new BroadcastBlock(blockID, tempByteArray) - blockID += 1 - } - bais.close() - - var variableInfo = VariableInfo(retVal, blockNum, byteArray.length) - variableInfo.hasBlocks = blockNum - - return variableInfo - } - - // Helper function to convert Array[BroadcastBlock] to object - def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock], - totalBytes: Int, - totalBlocks: Int): OUT = { - - var retByteArray = new Array[Byte](totalBytes) - for (i <- 0 until totalBlocks) { - System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, - i * Broadcast.BlockSize, arrayOfBlocks(i).byteArray.length) - } - byteArrayToObject(retByteArray) - } - - private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = { - val in = new ObjectInputStream (new ByteArrayInputStream (bytes)) { - override def resolveClass(desc: ObjectStreamClass) = - Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader) - } - val retVal = in.readObject.asInstanceOf[OUT] - in.close() - return retVal - } -} - -case class BroadcastBlock (blockID: Int, byteArray: Array[Byte]) extends Serializable - -case class VariableInfo (@transient arrayOfBlocks : Array[BroadcastBlock], - totalBlocks: Int, - totalBytes: Int) - extends Serializable { - @transient - var hasBlocks = 0 -} - -class SpeedTracker extends Serializable { - // Mapping 'source' to '(totalTime, numBlocks)' - private var sourceToSpeedMap = Map[SourceInfo, (Long, Int)] () - - def addDataPoint (srcInfo: SourceInfo, timeInMillis: Long) { - sourceToSpeedMap.synchronized { - if (!sourceToSpeedMap.contains(srcInfo)) { - sourceToSpeedMap += (srcInfo -> (timeInMillis, 1)) - } else { - val tTnB = sourceToSpeedMap (srcInfo) - sourceToSpeedMap += (srcInfo -> (tTnB._1 + timeInMillis, tTnB._2 + 1)) - } - } - } - - def getTimePerBlock (srcInfo: SourceInfo): Double = { - sourceToSpeedMap.synchronized { - val tTnB = sourceToSpeedMap (srcInfo) - return tTnB._1 / tTnB._2 - } - } - - override def toString = sourceToSpeedMap.toString + def MasterHostAddress = MasterHostAddress_ } diff --git a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala index b18908f789..e341d556bf 100644 --- a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala @@ -9,4 +9,5 @@ package spark.broadcast trait BroadcastFactory { def initialize(isMaster: Boolean): Unit def newBroadcast[T](value_ : T, isLocal: Boolean): Broadcast[T] + def stop(): Unit } diff --git a/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala b/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala deleted file mode 100644 index 43290c241f..0000000000 --- a/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala +++ /dev/null @@ -1,794 +0,0 @@ -package spark.broadcast - -import java.io._ -import java.net._ -import java.util.{Comparator, PriorityQueue, Random, UUID} - -import scala.collection.mutable.{Map, Set} -import scala.math - -import spark._ - -class ChainedBroadcast[T](@transient var value_ : T, isLocal: Boolean) -extends Broadcast[T] with Logging with Serializable { - - def value = value_ - - ChainedBroadcast.synchronized { - ChainedBroadcast.values.put(uuid, 0, value_) - } - - @transient var arrayOfBlocks: Array[BroadcastBlock] = null - @transient var totalBytes = -1 - @transient var totalBlocks = -1 - @transient var hasBlocks = 0 - // CHANGED: BlockSize in the Broadcast object is expected to change over time - @transient var blockSize = Broadcast.BlockSize - - @transient var listenPortLock = new Object - @transient var guidePortLock = new Object - @transient var totalBlocksLock = new Object - @transient var hasBlocksLock = new Object - - @transient var pqOfSources = new PriorityQueue[SourceInfo] - - @transient var serveMR: ServeMultipleRequests = null - @transient var guideMR: GuideMultipleRequests = null - - @transient var hostAddress = Utils.localIpAddress - @transient var listenPort = -1 - @transient var guidePort = -1 - - @transient var hasCopyInHDFS = false - @transient var stopBroadcast = false - - // Must call this after all the variables have been created/initialized - if (!isLocal) { - sendBroadcast - } - - def sendBroadcast() { - logInfo("Local host address: " + hostAddress) - - // Store a persistent copy in HDFS - // TODO: Turned OFF for now - // val out = new ObjectOutputStream(DfsBroadcast.openFileForWriting(uuid)) - // out.writeObject(value_) - // out.close() - // TODO: Fix this at some point - hasCopyInHDFS = true - - // Create a variableInfo object and store it in valueInfos - var variableInfo = Broadcast.blockifyObject(value_) - - guideMR = new GuideMultipleRequests - guideMR.setDaemon(true) - guideMR.start() - logInfo("GuideMultipleRequests started...") - - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - // Prepare the value being broadcasted - // TODO: Refactoring and clean-up required here - arrayOfBlocks = variableInfo.arrayOfBlocks - totalBytes = variableInfo.totalBytes - totalBlocks = variableInfo.totalBlocks - hasBlocks = variableInfo.totalBlocks - - while (listenPort == -1) { - listenPortLock.synchronized { - listenPortLock.wait() - } - } - - pqOfSources = new PriorityQueue[SourceInfo] - val masterSource = - SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes, blockSize) - pqOfSources.add(masterSource) - - // Register with the Tracker - while (guidePort == -1) { - guidePortLock.synchronized { - guidePortLock.wait() - } - } - ChainedBroadcast.registerValue(uuid, guidePort) - } - - private def readObject(in: ObjectInputStream) { - in.defaultReadObject() - ChainedBroadcast.synchronized { - val cachedVal = ChainedBroadcast.values.get(uuid, 0) - if (cachedVal != null) { - value_ = cachedVal.asInstanceOf[T] - } else { - // Initializing everything because Master will only send null/0 values - initializeSlaveVariables - - logInfo("Local host address: " + hostAddress) - - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - val start = System.nanoTime - - val receptionSucceeded = receiveBroadcast(uuid) - // If does not succeed, then get from HDFS copy - if (receptionSucceeded) { - value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - ChainedBroadcast.values.put(uuid, 0, value_) - } else { - val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid)) - value_ = fileIn.readObject.asInstanceOf[T] - ChainedBroadcast.values.put(uuid, 0, value_) - fileIn.close() - } - - val time =(System.nanoTime - start) / 1e9 - logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s") - } - } - } - - private def initializeSlaveVariables() { - arrayOfBlocks = null - totalBytes = -1 - totalBlocks = -1 - hasBlocks = 0 - blockSize = -1 - - listenPortLock = new Object - totalBlocksLock = new Object - hasBlocksLock = new Object - - serveMR = null - - hostAddress = Utils.localIpAddress - listenPort = -1 - - stopBroadcast = false - } - - def getMasterListenPort(variableUUID: UUID): Int = { - var clientSocketToTracker: Socket = null - var oosTracker: ObjectOutputStream = null - var oisTracker: ObjectInputStream = null - - var masterListenPort: Int = SourceInfo.TxOverGoToHDFS - - var retriesLeft = Broadcast.MaxRetryCount - do { - try { - // Connect to the tracker to find out the guide - clientSocketToTracker = - new Socket(Broadcast.MasterHostAddress, Broadcast.MasterTrackerPort) - oosTracker = - new ObjectOutputStream(clientSocketToTracker.getOutputStream) - oosTracker.flush() - oisTracker = - new ObjectInputStream(clientSocketToTracker.getInputStream) - - // Send UUID and receive masterListenPort - oosTracker.writeObject(uuid) - oosTracker.flush() - masterListenPort = oisTracker.readObject.asInstanceOf[Int] - } catch { - case e: Exception => { - logInfo("getMasterListenPort had a " + e) - } - } finally { - if (oisTracker != null) { - oisTracker.close() - } - if (oosTracker != null) { - oosTracker.close() - } - if (clientSocketToTracker != null) { - clientSocketToTracker.close() - } - } - retriesLeft -= 1 - - Thread.sleep(ChainedBroadcast.ranGen.nextInt( - Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) + - Broadcast.MinKnockInterval) - - } while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry) - - logInfo("Got this guidePort from Tracker: " + masterListenPort) - return masterListenPort - } - - def receiveBroadcast(variableUUID: UUID): Boolean = { - val masterListenPort = getMasterListenPort(variableUUID) - - if (masterListenPort == SourceInfo.TxOverGoToHDFS || - masterListenPort == SourceInfo.TxNotStartedRetry) { - // TODO: SourceInfo.TxNotStartedRetry is not really in use because we go - // to HDFS anyway when receiveBroadcast returns false - return false - } - - // Wait until hostAddress and listenPort are created by the - // ServeMultipleRequests thread - while (listenPort == -1) { - listenPortLock.synchronized { - listenPortLock.wait() - } - } - - var clientSocketToMaster: Socket = null - var oosMaster: ObjectOutputStream = null - var oisMaster: ObjectInputStream = null - - // Connect and receive broadcast from the specified source, retrying the - // specified number of times in case of failures - var retriesLeft = Broadcast.MaxRetryCount - do { - // Connect to Master and send this worker's Information - clientSocketToMaster = - new Socket(Broadcast.MasterHostAddress, masterListenPort) - // TODO: Guiding object connection is reusable - oosMaster = - new ObjectOutputStream(clientSocketToMaster.getOutputStream) - oosMaster.flush() - oisMaster = - new ObjectInputStream(clientSocketToMaster.getInputStream) - - logInfo("Connected to Master's guiding object") - - // Send local source information - oosMaster.writeObject(SourceInfo(hostAddress, listenPort)) - oosMaster.flush() - - // Receive source information from Master - var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo] - totalBlocks = sourceInfo.totalBlocks - arrayOfBlocks = new Array[BroadcastBlock](totalBlocks) - totalBlocksLock.synchronized { - totalBlocksLock.notifyAll() - } - totalBytes = sourceInfo.totalBytes - - logInfo("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort) - - val start = System.nanoTime - val receptionSucceeded = receiveSingleTransmission(sourceInfo) - val time =(System.nanoTime - start) / 1e9 - - // Updating some statistics in sourceInfo. Master will be using them later - if (!receptionSucceeded) { - sourceInfo.receptionFailed = true - } - - // Send back statistics to the Master - oosMaster.writeObject(sourceInfo) - - if (oisMaster != null) { - oisMaster.close() - } - if (oosMaster != null) { - oosMaster.close() - } - if (clientSocketToMaster != null) { - clientSocketToMaster.close() - } - - retriesLeft -= 1 - } while (retriesLeft > 0 && hasBlocks < totalBlocks) - - return(hasBlocks == totalBlocks) - } - - // Tries to receive broadcast from the source and returns Boolean status. - // This might be called multiple times to retry a defined number of times. - private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = { - var clientSocketToSource: Socket = null - var oosSource: ObjectOutputStream = null - var oisSource: ObjectInputStream = null - - var receptionSucceeded = false - try { - // Connect to the source to get the object itself - clientSocketToSource = - new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - oosSource = - new ObjectOutputStream(clientSocketToSource.getOutputStream) - oosSource.flush() - oisSource = - new ObjectInputStream(clientSocketToSource.getInputStream) - - logInfo("Inside receiveSingleTransmission") - logInfo("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks) - - // Send the range - oosSource.writeObject((hasBlocks, totalBlocks)) - oosSource.flush() - - for (i <- hasBlocks until totalBlocks) { - val recvStartTime = System.currentTimeMillis - val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] - val receptionTime =(System.currentTimeMillis - recvStartTime) - - logInfo("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.") - - arrayOfBlocks(hasBlocks) = bcBlock - hasBlocks += 1 - // Set to true if at least one block is received - receptionSucceeded = true - hasBlocksLock.synchronized { - hasBlocksLock.notifyAll() - } - } - } catch { - case e: Exception => { - logInfo("receiveSingleTransmission had a " + e) - } - } finally { - if (oisSource != null) { - oisSource.close() - } - if (oosSource != null) { - oosSource.close() - } - if (clientSocketToSource != null) { - clientSocketToSource.close() - } - } - - return receptionSucceeded - } - - class GuideMultipleRequests - extends Thread with Logging { - // Keep track of sources that have completed reception - private var setOfCompletedSources = Set[SourceInfo]() - - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(0) - guidePort = serverSocket.getLocalPort - logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort) - - guidePortLock.synchronized { - guidePortLock.notifyAll() - } - - try { - // Don't stop until there is a copy in HDFS - while (!stopBroadcast || !hasCopyInHDFS) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout) - clientSocket = serverSocket.accept - } catch { - case e: Exception => { - logInfo("GuideMultipleRequests Timeout.") - - // Stop broadcast if at least one worker has connected and - // everyone connected so far are done. Comparing with - // pqOfSources.size - 1, because it includes the Guide itself - if (pqOfSources.size > 1 && - setOfCompletedSources.size == pqOfSources.size - 1) { - stopBroadcast = true - } - } - } - if (clientSocket != null) { - logInfo("Guide: Accepted new client connection: " + clientSocket) - try { - threadPool.execute(new GuideSingleRequest(clientSocket)) - } catch { - // In failure, close the socket here; else, the thread will close it - case ioe: IOException => clientSocket.close() - } - } - } - - logInfo("Sending stopBroadcast notifications...") - sendStopBroadcastNotifications - - ChainedBroadcast.unregisterValue(uuid) - } finally { - if (serverSocket != null) { - logInfo("GuideMultipleRequests now stopping...") - serverSocket.close() - } - } - - // Shutdown the thread pool - threadPool.shutdown() - } - - private def sendStopBroadcastNotifications() { - pqOfSources.synchronized { - var pqIter = pqOfSources.iterator - while (pqIter.hasNext) { - var sourceInfo = pqIter.next - - var guideSocketToSource: Socket = null - var gosSource: ObjectOutputStream = null - var gisSource: ObjectInputStream = null - - try { - // Connect to the source - guideSocketToSource = - new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - gosSource = - new ObjectOutputStream(guideSocketToSource.getOutputStream) - gosSource.flush() - gisSource = - new ObjectInputStream(guideSocketToSource.getInputStream) - - // Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2 - gosSource.writeObject((SourceInfo.StopBroadcast, - SourceInfo.StopBroadcast)) - gosSource.flush() - } catch { - case e: Exception => { - logInfo("sendStopBroadcastNotifications had a " + e) - } - } finally { - if (gisSource != null) { - gisSource.close() - } - if (gosSource != null) { - gosSource.close() - } - if (guideSocketToSource != null) { - guideSocketToSource.close() - } - } - } - } - } - - class GuideSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - private var selectedSourceInfo: SourceInfo = null - private var thisWorkerInfo:SourceInfo = null - - override def run() { - try { - logInfo("new GuideSingleRequest is running") - // Connecting worker is sending in its hostAddress and listenPort it will - // be listening to. Other fields are invalid(SourceInfo.UnusedParam) - var sourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - pqOfSources.synchronized { - // Select a suitable source and send it back to the worker - selectedSourceInfo = selectSuitableSource(sourceInfo) - logInfo("Sending selectedSourceInfo: " + selectedSourceInfo) - oos.writeObject(selectedSourceInfo) - oos.flush() - - // Add this new(if it can finish) source to the PQ of sources - thisWorkerInfo = SourceInfo(sourceInfo.hostAddress, - sourceInfo.listenPort, totalBlocks, totalBytes, blockSize) - logInfo("Adding possible new source to pqOfSources: " + thisWorkerInfo) - pqOfSources.add(thisWorkerInfo) - } - - // Wait till the whole transfer is done. Then receive and update source - // statistics in pqOfSources - sourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - pqOfSources.synchronized { - // This should work since SourceInfo is a case class - assert(pqOfSources.contains(selectedSourceInfo)) - - // Remove first - pqOfSources.remove(selectedSourceInfo) - // TODO: Removing a source based on just one failure notification! - - // Update sourceInfo and put it back in, IF reception succeeded - if (!sourceInfo.receptionFailed) { - // Add thisWorkerInfo to sources that have completed reception - setOfCompletedSources.synchronized { - setOfCompletedSources += thisWorkerInfo - } - - selectedSourceInfo.currentLeechers -= 1 - - // Put it back - pqOfSources.add(selectedSourceInfo) - } - } - } catch { - // If something went wrong, e.g., the worker at the other end died etc. - // then close everything up - case e: Exception => { - // Assuming that exception caused due to receiver worker failure. - // Remove failed worker from pqOfSources and update leecherCount of - // corresponding source worker - pqOfSources.synchronized { - if (selectedSourceInfo != null) { - // Remove first - pqOfSources.remove(selectedSourceInfo) - // Update leecher count and put it back in - selectedSourceInfo.currentLeechers -= 1 - pqOfSources.add(selectedSourceInfo) - } - - // Remove thisWorkerInfo - if (pqOfSources != null) { - pqOfSources.remove(thisWorkerInfo) - } - } - } - } finally { - ois.close() - oos.close() - clientSocket.close() - } - } - - // FIXME: Caller must have a synchronized block on pqOfSources - // FIXME: If a worker fails to get the broadcasted variable from a source and - // comes back to Master, this function might choose the worker itself as a - // source tp create a dependency cycle(this worker was put into pqOfSources - // as a streming source when it first arrived). The length of this cycle can - // be arbitrarily long. - private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = { - // Select one based on the ordering strategy(e.g., least leechers etc.) - // take is a blocking call removing the element from PQ - var selectedSource = pqOfSources.poll - assert(selectedSource != null) - // Update leecher count - selectedSource.currentLeechers += 1 - // Add it back and then return - pqOfSources.add(selectedSource) - return selectedSource - } - } - } - - class ServeMultipleRequests - extends Thread with Logging { - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(0) - listenPort = serverSocket.getLocalPort - logInfo("ServeMultipleRequests started with " + serverSocket) - - listenPortLock.synchronized { - listenPortLock.notifyAll() - } - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout) - clientSocket = serverSocket.accept - } catch { - case e: Exception => { - logInfo("ServeMultipleRequests Timeout.") - } - } - if (clientSocket != null) { - logInfo("Serve: Accepted new client connection: " + clientSocket) - try { - threadPool.execute(new ServeSingleRequest(clientSocket)) - } catch { - // In failure, close socket here; else, the thread will close it - case ioe: IOException => clientSocket.close() - } - } - } - } finally { - if (serverSocket != null) { - logInfo("ServeMultipleRequests now stopping...") - serverSocket.close() - } - } - - // Shutdown the thread pool - threadPool.shutdown() - } - - class ServeSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - private var sendFrom = 0 - private var sendUntil = totalBlocks - - override def run() { - try { - logInfo("new ServeSingleRequest is running") - - // Receive range to send - var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)] - sendFrom = rangeToSend._1 - sendUntil = rangeToSend._2 - - if (sendFrom == SourceInfo.StopBroadcast && - sendUntil == SourceInfo.StopBroadcast) { - stopBroadcast = true - } else { - // Carry on - sendObject - } - } catch { - // If something went wrong, e.g., the worker at the other end died etc. - // then close everything up - case e: Exception => { - logInfo("ServeSingleRequest had a " + e) - } - } finally { - logInfo("ServeSingleRequest is closing streams and sockets") - ois.close() - oos.close() - clientSocket.close() - } - } - - private def sendObject() { - // Wait till receiving the SourceInfo from Master - while (totalBlocks == -1) { - totalBlocksLock.synchronized { - totalBlocksLock.wait() - } - } - - for (i <- sendFrom until sendUntil) { - while (i == hasBlocks) { - hasBlocksLock.synchronized { - hasBlocksLock.wait() - } - } - try { - oos.writeObject(arrayOfBlocks(i)) - oos.flush() - } catch { - case e: Exception => { - logInfo("sendObject had a " + e) - } - } - logInfo("Sent block: " + i + " to " + clientSocket) - } - } - } - } -} - -class ChainedBroadcastFactory -extends BroadcastFactory { - def initialize(isMaster: Boolean) { - ChainedBroadcast.initialize(isMaster) - } - def newBroadcast[T](value_ : T, isLocal: Boolean) = { - new ChainedBroadcast[T](value_, isLocal) - } -} - -private object ChainedBroadcast -extends Logging { - val values = SparkEnv.get.cache.newKeySpace() - - var valueToGuidePortMap = Map[UUID, Int]() - - // Random number generator - var ranGen = new Random - - private var initialized = false - private var isMaster_ = false - - private var trackMV: TrackMultipleValues = null - - def initialize(isMaster__ : Boolean) { - synchronized { - if (!initialized) { - isMaster_ = isMaster__ - - if (isMaster) { - trackMV = new TrackMultipleValues - trackMV.setDaemon(true) - trackMV.start() - // TODO: Logging the following line makes the Spark framework ID not - // getting logged, cause it calls logInfo before log4j is initialized - logInfo("TrackMultipleValues started...") - } - - // Initialize DfsBroadcast to be used for broadcast variable persistence - DfsBroadcast.initialize - - initialized = true - } - } - } - - def isMaster = isMaster_ - - def registerValue(uuid: UUID, guidePort: Int) { - valueToGuidePortMap.synchronized { - valueToGuidePortMap +=(uuid -> guidePort) - logInfo("New value registered with the Tracker " + valueToGuidePortMap) - } - } - - def unregisterValue(uuid: UUID) { - valueToGuidePortMap.synchronized { - valueToGuidePortMap(uuid) = SourceInfo.TxOverGoToHDFS - logInfo("Value unregistered from the Tracker " + valueToGuidePortMap) - } - } - - class TrackMultipleValues - extends Thread with Logging { - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(Broadcast.MasterTrackerPort) - logInfo("TrackMultipleValues" + serverSocket) - - try { - while (true) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(Broadcast.TrackerSocketTimeout) - clientSocket = serverSocket.accept - } catch { - case e: Exception => { - logInfo("TrackMultipleValues Timeout. Stopping listening...") - } - } - - if (clientSocket != null) { - try { - threadPool.execute(new Thread { - override def run() { - val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - val ois = new ObjectInputStream(clientSocket.getInputStream) - try { - val uuid = ois.readObject.asInstanceOf[UUID] - var guidePort = - if (valueToGuidePortMap.contains(uuid)) { - valueToGuidePortMap(uuid) - } else SourceInfo.TxNotStartedRetry - logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort) - oos.writeObject(guidePort) - } catch { - case e: Exception => { - logInfo("TrackMultipleValues had a " + e) - } - } finally { - ois.close() - oos.close() - clientSocket.close() - } - } - }) - } catch { - // In failure, close socket here; else, client thread will close - case ioe: IOException => clientSocket.close() - } - } - } - } finally { - serverSocket.close() - } - - // Shutdown the thread pool - threadPool.shutdown() - } - } -} diff --git a/core/src/main/scala/spark/broadcast/DfsBroadcast.scala b/core/src/main/scala/spark/broadcast/DfsBroadcast.scala deleted file mode 100644 index d18dfb8963..0000000000 --- a/core/src/main/scala/spark/broadcast/DfsBroadcast.scala +++ /dev/null @@ -1,135 +0,0 @@ -package spark.broadcast - -import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} - -import java.io._ -import java.net._ -import java.util.UUID - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem} - -import spark._ - -class DfsBroadcast[T](@transient var value_ : T, isLocal: Boolean) -extends Broadcast[T] with Logging with Serializable { - - def value = value_ - - DfsBroadcast.synchronized { - DfsBroadcast.values.put(uuid, 0, value_) - } - - if (!isLocal) { - sendBroadcast - } - - def sendBroadcast () { - val out = new ObjectOutputStream (DfsBroadcast.openFileForWriting(uuid)) - out.writeObject (value_) - out.close() - } - - // Called by JVM when deserializing an object - private def readObject(in: ObjectInputStream) { - in.defaultReadObject() - DfsBroadcast.synchronized { - val cachedVal = DfsBroadcast.values.get(uuid, 0) - if (cachedVal != null) { - value_ = cachedVal.asInstanceOf[T] - } else { - logInfo( "Started reading Broadcasted variable " + uuid) - val start = System.nanoTime - - val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid)) - value_ = fileIn.readObject.asInstanceOf[T] - DfsBroadcast.values.put(uuid, 0, value_) - fileIn.close() - - val time = (System.nanoTime - start) / 1e9 - logInfo( "Reading Broadcasted variable " + uuid + " took " + time + " s") - } - } - } -} - -class DfsBroadcastFactory -extends BroadcastFactory { - def initialize (isMaster: Boolean) { - DfsBroadcast.initialize - } - def newBroadcast[T] (value_ : T, isLocal: Boolean) = - new DfsBroadcast[T] (value_, isLocal) -} - -private object DfsBroadcast -extends Logging { - val values = SparkEnv.get.cache.newKeySpace() - - private var initialized = false - - private var fileSystem: FileSystem = null - private var workDir: String = null - private var compress: Boolean = false - private var bufferSize: Int = 65536 - - def initialize() { - synchronized { - if (!initialized) { - bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - val dfs = System.getProperty("spark.dfs", "file:///") - if (!dfs.startsWith("file://")) { - val conf = new Configuration() - conf.setInt("io.file.buffer.size", bufferSize) - val rep = System.getProperty("spark.dfs.replication", "3").toInt - conf.setInt("dfs.replication", rep) - fileSystem = FileSystem.get(new URI(dfs), conf) - } - workDir = System.getProperty("spark.dfs.workDir", "/tmp") - compress = System.getProperty("spark.compress", "false").toBoolean - - initialized = true - } - } - } - - private def getPath(uuid: UUID) = new Path(workDir + "/broadcast-" + uuid) - - def openFileForReading(uuid: UUID): InputStream = { - val fileStream = if (fileSystem != null) { - fileSystem.open(getPath(uuid)) - } else { - // Local filesystem - new FileInputStream(getPath(uuid).toString) - } - - if (compress) { - // LZF stream does its own buffering - new LZFInputStream(fileStream) - } else if (fileSystem == null) { - new BufferedInputStream(fileStream, bufferSize) - } else { - // Hadoop streams do their own buffering - fileStream - } - } - - def openFileForWriting(uuid: UUID): OutputStream = { - val fileStream = if (fileSystem != null) { - fileSystem.create(getPath(uuid)) - } else { - // Local filesystem - new FileOutputStream(getPath(uuid).toString) - } - - if (compress) { - // LZF stream does its own buffering - new LZFOutputStream(fileStream) - } else if (fileSystem == null) { - new BufferedOutputStream(fileStream, bufferSize) - } else { - // Hadoop streams do their own buffering - fileStream - } - } -} diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala index 6e3dde76bd..ec8749c4a5 100644 --- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala @@ -16,8 +16,8 @@ extends Broadcast[T] with Logging with Serializable { def value = value_ - HttpBroadcast.synchronized { - HttpBroadcast.values.put(uuid, 0, value_) + Broadcast.synchronized { + Broadcast.values.put(uuid, 0, value_) } if (!isLocal) { @@ -28,14 +28,14 @@ extends Broadcast[T] with Logging with Serializable { private def readObject(in: ObjectInputStream) { in.defaultReadObject() HttpBroadcast.synchronized { - val cachedVal = HttpBroadcast.values.get(uuid, 0) + val cachedVal = Broadcast.values.get(uuid, 0) if (cachedVal != null) { value_ = cachedVal.asInstanceOf[T] } else { logInfo("Started reading broadcast variable " + uuid) val start = System.nanoTime value_ = HttpBroadcast.read[T](uuid) - HttpBroadcast.values.put(uuid, 0, value_) + Broadcast.values.put(uuid, 0, value_) val time = (System.nanoTime - start) / 1e9 logInfo("Reading broadcast variable " + uuid + " took " + time + " s") } @@ -44,15 +44,12 @@ extends Broadcast[T] with Logging with Serializable { } class HttpBroadcastFactory extends BroadcastFactory { - def initialize(isMaster: Boolean) { - HttpBroadcast.initialize(isMaster) - } + def initialize(isMaster: Boolean) = HttpBroadcast.initialize(isMaster) def newBroadcast[T](value_ : T, isLocal: Boolean) = new HttpBroadcast[T](value_, isLocal) + def stop() = HttpBroadcast.stop() } private object HttpBroadcast extends Logging { - val values = SparkEnv.get.cache.newKeySpace() - private var initialized = false private var broadcastDir: File = null @@ -74,6 +71,12 @@ private object HttpBroadcast extends Logging { } } } + + def stop() { + if (server != null) { + server.stop() + } + } private def createServer() { broadcastDir = Utils.createTempDir() diff --git a/core/src/main/scala/spark/broadcast/MultiTracker.scala b/core/src/main/scala/spark/broadcast/MultiTracker.scala new file mode 100644 index 0000000000..10b90526e8 --- /dev/null +++ b/core/src/main/scala/spark/broadcast/MultiTracker.scala @@ -0,0 +1,389 @@ +package spark.broadcast + +import java.io._ +import java.net._ +import java.util.{UUID, Random} +import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor} + +import scala.collection.mutable.Map + +import spark._ + +private object MultiTracker +extends Logging { + + // Tracker Messages + val REGISTER_BROADCAST_TRACKER = 0 + val UNREGISTER_BROADCAST_TRACKER = 1 + val FIND_BROADCAST_TRACKER = 2 + + // Map to keep track of guides of ongoing broadcasts + var valueToGuideMap = Map[UUID, SourceInfo]() + + // Random number generator + var ranGen = new Random + + private var initialized = false + private var isMaster_ = false + + private var stopBroadcast = false + + private var trackMV: TrackMultipleValues = null + + def initialize(isMaster__ : Boolean) { + synchronized { + if (!initialized) { + + isMaster_ = isMaster__ + + if (isMaster) { + trackMV = new TrackMultipleValues + trackMV.setDaemon(true) + trackMV.start() + } + + initialized = true + } + } + } + + def stop() { + stopBroadcast = true + } + + // Load common parameters + private var MasterTrackerPort_ = System.getProperty( + "spark.broadcast.masterTrackerPort", "11111").toInt + private var BlockSize_ = System.getProperty( + "spark.broadcast.blockSize", "4096").toInt * 1024 + private var MaxRetryCount_ = System.getProperty( + "spark.broadcast.maxRetryCount", "2").toInt + + private var TrackerSocketTimeout_ = System.getProperty( + "spark.broadcast.trackerSocketTimeout", "50000").toInt + private var ServerSocketTimeout_ = System.getProperty( + "spark.broadcast.serverSocketTimeout", "10000").toInt + + private var MinKnockInterval_ = System.getProperty( + "spark.broadcast.minKnockInterval", "500").toInt + private var MaxKnockInterval_ = System.getProperty( + "spark.broadcast.maxKnockInterval", "999").toInt + + // Load TreeBroadcast config params + private var MaxDegree_ = System.getProperty( + "spark.broadcast.maxDegree", "2").toInt + + // Load BitTorrentBroadcast config params + private var MaxPeersInGuideResponse_ = System.getProperty( + "spark.broadcast.maxPeersInGuideResponse", "4").toInt + + private var MaxChatSlots_ = System.getProperty( + "spark.broadcast.maxChatSlots", "4").toInt + private var MaxChatTime_ = System.getProperty( + "spark.broadcast.maxChatTime", "500").toInt + private var MaxChatBlocks_ = System.getProperty( + "spark.broadcast.maxChatBlocks", "1024").toInt + + private var EndGameFraction_ = System.getProperty( + "spark.broadcast.endGameFraction", "0.95").toDouble + + def isMaster = isMaster_ + + // Common config params + def MasterTrackerPort = MasterTrackerPort_ + def BlockSize = BlockSize_ + def MaxRetryCount = MaxRetryCount_ + + def TrackerSocketTimeout = TrackerSocketTimeout_ + def ServerSocketTimeout = ServerSocketTimeout_ + + def MinKnockInterval = MinKnockInterval_ + def MaxKnockInterval = MaxKnockInterval_ + + // TreeBroadcast configs + def MaxDegree = MaxDegree_ + + // BitTorrentBroadcast configs + def MaxPeersInGuideResponse = MaxPeersInGuideResponse_ + + def MaxChatSlots = MaxChatSlots_ + def MaxChatTime = MaxChatTime_ + def MaxChatBlocks = MaxChatBlocks_ + + def EndGameFraction = EndGameFraction_ + + class TrackMultipleValues + extends Thread with Logging { + override def run() { + var threadPool = Utils.newDaemonCachedThreadPool() + var serverSocket: ServerSocket = null + + serverSocket = new ServerSocket(MasterTrackerPort) + logInfo("TrackMultipleValues" + serverSocket) + + try { + while (!stopBroadcast) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout(TrackerSocketTimeout) + clientSocket = serverSocket.accept() + } catch { + case e: Exception => { + if (stopBroadcast) { + logInfo("Stopping TrackMultipleValues...") + } + } + } + + if (clientSocket != null) { + try { + threadPool.execute(new Thread { + override def run() { + val oos = new ObjectOutputStream(clientSocket.getOutputStream) + oos.flush() + val ois = new ObjectInputStream(clientSocket.getInputStream) + + try { + // First, read message type + val messageType = ois.readObject.asInstanceOf[Int] + + if (messageType == REGISTER_BROADCAST_TRACKER) { + // Receive UUID + val uuid = ois.readObject.asInstanceOf[UUID] + // Receive hostAddress and listenPort + val gInfo = ois.readObject.asInstanceOf[SourceInfo] + + // Add to the map + valueToGuideMap.synchronized { + valueToGuideMap += (uuid -> gInfo) + } + + logInfo ("New broadcast registered with TrackMultipleValues " + uuid + " " + valueToGuideMap) + + // Send dummy ACK + oos.writeObject(-1) + oos.flush() + } else if (messageType == UNREGISTER_BROADCAST_TRACKER) { + // Receive UUID + val uuid = ois.readObject.asInstanceOf[UUID] + + // Remove from the map + valueToGuideMap.synchronized { + valueToGuideMap(uuid) = SourceInfo("", SourceInfo.TxOverGoToDefault) + logInfo("Value unregistered from the Tracker " + valueToGuideMap) + } + + logInfo ("Broadcast unregistered from TrackMultipleValues " + uuid + " " + valueToGuideMap) + + // Send dummy ACK + oos.writeObject(-1) + oos.flush() + } else if (messageType == FIND_BROADCAST_TRACKER) { + // Receive UUID + val uuid = ois.readObject.asInstanceOf[UUID] + + var gInfo = + if (valueToGuideMap.contains(uuid)) valueToGuideMap(uuid) + else SourceInfo("", SourceInfo.TxNotStartedRetry) + + logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + gInfo.listenPort) + + // Send reply back + oos.writeObject(gInfo) + oos.flush() + } else { + throw new SparkException("Undefined messageType at TrackMultipleValues") + } + } catch { + case e: Exception => { + logInfo("TrackMultipleValues had a " + e) + } + } finally { + ois.close() + oos.close() + clientSocket.close() + } + } + }) + } catch { + // In failure, close socket here; else, client thread will close + case ioe: IOException => clientSocket.close() + } + } + } + } finally { + serverSocket.close() + } + // Shutdown the thread pool + threadPool.shutdown() + } + } + + def getGuideInfo(variableUUID: UUID): SourceInfo = { + var clientSocketToTracker: Socket = null + var oosTracker: ObjectOutputStream = null + var oisTracker: ObjectInputStream = null + + var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxOverGoToDefault) + + var retriesLeft = MultiTracker.MaxRetryCount + do { + try { + // Connect to the tracker to find out GuideInfo + clientSocketToTracker = + new Socket(Broadcast.MasterHostAddress, MultiTracker.MasterTrackerPort) + oosTracker = + new ObjectOutputStream(clientSocketToTracker.getOutputStream) + oosTracker.flush() + oisTracker = + new ObjectInputStream(clientSocketToTracker.getInputStream) + + // Send messageType/intention + oosTracker.writeObject(MultiTracker.FIND_BROADCAST_TRACKER) + oosTracker.flush() + + // Send UUID and receive GuideInfo + oosTracker.writeObject(variableUUID) + oosTracker.flush() + gInfo = oisTracker.readObject.asInstanceOf[SourceInfo] + } catch { + case e: Exception => logInfo("getGuideInfo had a " + e) + } finally { + if (oisTracker != null) { + oisTracker.close() + } + if (oosTracker != null) { + oosTracker.close() + } + if (clientSocketToTracker != null) { + clientSocketToTracker.close() + } + } + + Thread.sleep(MultiTracker.ranGen.nextInt( + MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) + + MultiTracker.MinKnockInterval) + + retriesLeft -= 1 + } while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry) + + logInfo("Got this guidePort from Tracker: " + gInfo.listenPort) + return gInfo + } + + def registerBroadcast(uuid: UUID, gInfo: SourceInfo) { + val socket = new Socket(Broadcast.MasterHostAddress, MasterTrackerPort) + val oosST = new ObjectOutputStream(socket.getOutputStream) + oosST.flush() + val oisST = new ObjectInputStream(socket.getInputStream) + + // Send messageType/intention + oosST.writeObject(REGISTER_BROADCAST_TRACKER) + oosST.flush() + + // Send UUID of this broadcast + oosST.writeObject(uuid) + oosST.flush() + + // Send this tracker's information + oosST.writeObject(gInfo) + oosST.flush() + + // Receive ACK and throw it away + oisST.readObject.asInstanceOf[Int] + + // Shut stuff down + oisST.close() + oosST.close() + socket.close() + } + + def unregisterBroadcast(uuid: UUID) { + val socket = new Socket(Broadcast.MasterHostAddress, MasterTrackerPort) + val oosST = new ObjectOutputStream(socket.getOutputStream) + oosST.flush() + val oisST = new ObjectInputStream(socket.getInputStream) + + // Send messageType/intention + oosST.writeObject(UNREGISTER_BROADCAST_TRACKER) + oosST.flush() + + // Send UUID of this broadcast + oosST.writeObject(uuid) + oosST.flush() + + // Receive ACK and throw it away + oisST.readObject.asInstanceOf[Int] + + // Shut stuff down + oisST.close() + oosST.close() + socket.close() + } + + // Helper method to convert an object to Array[BroadcastBlock] + def blockifyObject[IN](obj: IN): VariableInfo = { + val baos = new ByteArrayOutputStream + val oos = new ObjectOutputStream(baos) + oos.writeObject(obj) + oos.close() + baos.close() + val byteArray = baos.toByteArray + val bais = new ByteArrayInputStream(byteArray) + + var blockNum = (byteArray.length / BlockSize) + if (byteArray.length % BlockSize != 0) + blockNum += 1 + + var retVal = new Array[BroadcastBlock](blockNum) + var blockID = 0 + + for (i <- 0 until (byteArray.length, BlockSize)) { + val thisBlockSize = math.min(BlockSize, byteArray.length - i) + var tempByteArray = new Array[Byte](thisBlockSize) + val hasRead = bais.read(tempByteArray, 0, thisBlockSize) + + retVal(blockID) = new BroadcastBlock(blockID, tempByteArray) + blockID += 1 + } + bais.close() + + var variableInfo = VariableInfo(retVal, blockNum, byteArray.length) + variableInfo.hasBlocks = blockNum + + return variableInfo + } + + // Helper method to convert Array[BroadcastBlock] to object + def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock], + totalBytes: Int, + totalBlocks: Int): OUT = { + + var retByteArray = new Array[Byte](totalBytes) + for (i <- 0 until totalBlocks) { + System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, + i * BlockSize, arrayOfBlocks(i).byteArray.length) + } + byteArrayToObject(retByteArray) + } + + private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = { + val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){ + override def resolveClass(desc: ObjectStreamClass) = + Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader) + } + val retVal = in.readObject.asInstanceOf[OUT] + in.close() + return retVal + } +} + +case class BroadcastBlock(blockID: Int, byteArray: Array[Byte]) +extends Serializable + +case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock], + totalBlocks: Int, + totalBytes: Int) +extends Serializable { + @transient var hasBlocks = 0 +} diff --git a/core/src/main/scala/spark/broadcast/SourceInfo.scala b/core/src/main/scala/spark/broadcast/SourceInfo.scala index 09907f4ee7..f90385fd47 100644 --- a/core/src/main/scala/spark/broadcast/SourceInfo.scala +++ b/core/src/main/scala/spark/broadcast/SourceInfo.scala @@ -6,15 +6,11 @@ import spark._ /** * Used to keep and pass around information of peers involved in a broadcast - * - * CHANGED: Keep track of the blockSize for THIS broadcast variable. - * Broadcast.BlockSize is expected to be updated across different broadcasts */ case class SourceInfo (hostAddress: String, listenPort: Int, totalBlocks: Int = SourceInfo.UnusedParam, - totalBytes: Int = SourceInfo.UnusedParam, - blockSize: Int = Broadcast.BlockSize) + totalBytes: Int = SourceInfo.UnusedParam) extends Comparable[SourceInfo] with Logging { var currentLeechers = 0 @@ -33,8 +29,8 @@ extends Comparable[SourceInfo] with Logging { object SourceInfo { // Constants for special values of listenPort val TxNotStartedRetry = -1 - val TxOverGoToHDFS = 0 + val TxOverGoToDefault = 0 // Other constants val StopBroadcast = -2 val UnusedParam = 0 -}
\ No newline at end of file +} diff --git a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala index f5527b6ec9..e2b6fcfc7d 100644 --- a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala @@ -14,16 +14,14 @@ extends Broadcast[T] with Logging with Serializable { def value = value_ - TreeBroadcast.synchronized { - TreeBroadcast.values.put(uuid, 0, value_) + Broadcast.synchronized { + Broadcast.values.put(uuid, 0, value_) } @transient var arrayOfBlocks: Array[BroadcastBlock] = null @transient var totalBytes = -1 @transient var totalBlocks = -1 @transient var hasBlocks = 0 - // CHANGED: BlockSize in the Broadcast object is expected to change over time - @transient var blockSize = Broadcast.BlockSize @transient var listenPortLock = new Object @transient var guidePortLock = new Object @@ -39,7 +37,6 @@ extends Broadcast[T] with Logging with Serializable { @transient var listenPort = -1 @transient var guidePort = -1 - @transient var hasCopyInHDFS = false @transient var stopBroadcast = false // Must call this after all the variables have been created/initialized @@ -50,19 +47,10 @@ extends Broadcast[T] with Logging with Serializable { def sendBroadcast() { logInfo("Local host address: " + hostAddress) - // Store a persistent copy in HDFS - // TODO: Turned OFF for now - // val out = new ObjectOutputStream(DfsBroadcast.openFileForWriting(uuid)) - // out.writeObject(value_) - // out.close() - // TODO: Fix this at some point - hasCopyInHDFS = true - // Create a variableInfo object and store it in valueInfos - var variableInfo = Broadcast.blockifyObject(value_) + var variableInfo = MultiTracker.blockifyObject(value_) // Prepare the value being broadcasted - // TODO: Refactoring and clean-up required here arrayOfBlocks = variableInfo.arrayOfBlocks totalBytes = variableInfo.totalBytes totalBlocks = variableInfo.totalBlocks @@ -75,9 +63,7 @@ extends Broadcast[T] with Logging with Serializable { // Must always come AFTER guideMR is created while (guidePort == -1) { - guidePortLock.synchronized { - guidePortLock.wait() - } + guidePortLock.synchronized { guidePortLock.wait() } } serveMR = new ServeMultipleRequests @@ -87,29 +73,29 @@ extends Broadcast[T] with Logging with Serializable { // Must always come AFTER serveMR is created while (listenPort == -1) { - listenPortLock.synchronized { - listenPortLock.wait() - } + listenPortLock.synchronized { listenPortLock.wait() } } // Must always come AFTER listenPort is created val masterSource = - SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes, blockSize) + SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes) listOfSources += masterSource // Register with the Tracker - TreeBroadcast.registerValue(uuid, guidePort) + MultiTracker.registerBroadcast(uuid, + SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes)) } private def readObject(in: ObjectInputStream) { in.defaultReadObject() - TreeBroadcast.synchronized { - val cachedVal = TreeBroadcast.values.get(uuid, 0) + Broadcast.synchronized { + val cachedVal = Broadcast.values.get(uuid, 0) if (cachedVal != null) { value_ = cachedVal.asInstanceOf[T] } else { // Initializing everything because Master will only send null/0 values - initializeSlaveVariables + // Only the 1st worker in a node can be here. Others will get from cache + initializeWorkerVariables logInfo("Local host address: " + hostAddress) @@ -121,15 +107,11 @@ extends Broadcast[T] with Logging with Serializable { val start = System.nanoTime val receptionSucceeded = receiveBroadcast(uuid) - // If does not succeed, then get from HDFS copy if (receptionSucceeded) { - value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - TreeBroadcast.values.put(uuid, 0, value_) + value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) + Broadcast.values.put(uuid, 0, value_) } else { - val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid)) - value_ = fileIn.readObject.asInstanceOf[T] - TreeBroadcast.values.put(uuid, 0, value_) - fileIn.close() + logError("Reading Broadcasted variable " + uuid + " failed") } val time = (System.nanoTime - start) / 1e9 @@ -138,12 +120,11 @@ extends Broadcast[T] with Logging with Serializable { } } - private def initializeSlaveVariables() { + private def initializeWorkerVariables() { arrayOfBlocks = null totalBytes = -1 totalBlocks = -1 hasBlocks = 0 - blockSize = -1 listenPortLock = new Object totalBlocksLock = new Object @@ -157,72 +138,17 @@ extends Broadcast[T] with Logging with Serializable { stopBroadcast = false } - def getMasterListenPort(variableUUID: UUID): Int = { - var clientSocketToTracker: Socket = null - var oosTracker: ObjectOutputStream = null - var oisTracker: ObjectInputStream = null - - var masterListenPort: Int = SourceInfo.TxOverGoToHDFS - - var retriesLeft = Broadcast.MaxRetryCount - do { - try { - // Connect to the tracker to find out the guide - clientSocketToTracker = - new Socket(Broadcast.MasterHostAddress, Broadcast.MasterTrackerPort) - oosTracker = - new ObjectOutputStream(clientSocketToTracker.getOutputStream) - oosTracker.flush() - oisTracker = - new ObjectInputStream(clientSocketToTracker.getInputStream) - - // Send UUID and receive masterListenPort - oosTracker.writeObject(uuid) - oosTracker.flush() - masterListenPort = oisTracker.readObject.asInstanceOf[Int] - } catch { - case e: Exception => { - logInfo("getMasterListenPort had a " + e) - } - } finally { - if (oisTracker != null) { - oisTracker.close() - } - if (oosTracker != null) { - oosTracker.close() - } - if (clientSocketToTracker != null) { - clientSocketToTracker.close() - } - } - retriesLeft -= 1 - - Thread.sleep(TreeBroadcast.ranGen.nextInt( - Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) + - Broadcast.MinKnockInterval) - - } while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry) - - logInfo("Got this guidePort from Tracker: " + masterListenPort) - return masterListenPort - } - def receiveBroadcast(variableUUID: UUID): Boolean = { - val masterListenPort = getMasterListenPort(variableUUID) - - if (masterListenPort == SourceInfo.TxOverGoToHDFS || - masterListenPort == SourceInfo.TxNotStartedRetry) { - // TODO: SourceInfo.TxNotStartedRetry is not really in use because we go - // to HDFS anyway when receiveBroadcast returns false + val gInfo = MultiTracker.getGuideInfo(variableUUID) + + if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) { return false } // Wait until hostAddress and listenPort are created by the // ServeMultipleRequests thread while (listenPort == -1) { - listenPortLock.synchronized { - listenPortLock.wait() - } + listenPortLock.synchronized { listenPortLock.wait() } } var clientSocketToMaster: Socket = null @@ -231,17 +157,13 @@ extends Broadcast[T] with Logging with Serializable { // Connect and receive broadcast from the specified source, retrying the // specified number of times in case of failures - var retriesLeft = Broadcast.MaxRetryCount + var retriesLeft = MultiTracker.MaxRetryCount do { // Connect to Master and send this worker's Information - clientSocketToMaster = - new Socket(Broadcast.MasterHostAddress, masterListenPort) - // TODO: Guiding object connection is reusable - oosMaster = - new ObjectOutputStream(clientSocketToMaster.getOutputStream) + clientSocketToMaster = new Socket(Broadcast.MasterHostAddress, gInfo.listenPort) + oosMaster = new ObjectOutputStream(clientSocketToMaster.getOutputStream) oosMaster.flush() - oisMaster = - new ObjectInputStream(clientSocketToMaster.getInputStream) + oisMaster = new ObjectInputStream(clientSocketToMaster.getInputStream) logInfo("Connected to Master's guiding object") @@ -253,11 +175,8 @@ extends Broadcast[T] with Logging with Serializable { var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo] totalBlocks = sourceInfo.totalBlocks arrayOfBlocks = new Array[BroadcastBlock](totalBlocks) - totalBlocksLock.synchronized { - totalBlocksLock.notifyAll() - } + totalBlocksLock.synchronized { totalBlocksLock.notifyAll() } totalBytes = sourceInfo.totalBytes - blockSize = sourceInfo.blockSize logInfo("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort) @@ -289,8 +208,10 @@ extends Broadcast[T] with Logging with Serializable { return (hasBlocks == totalBlocks) } - // Tries to receive broadcast from the source and returns Boolean status. - // This might be called multiple times to retry a defined number of times. + /** + * Tries to receive broadcast from the source and returns Boolean status. + * This might be called multiple times to retry a defined number of times. + */ private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = { var clientSocketToSource: Socket = null var oosSource: ObjectOutputStream = null @@ -299,13 +220,10 @@ extends Broadcast[T] with Logging with Serializable { var receptionSucceeded = false try { // Connect to the source to get the object itself - clientSocketToSource = - new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - oosSource = - new ObjectOutputStream(clientSocketToSource.getOutputStream) + clientSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) + oosSource = new ObjectOutputStream(clientSocketToSource.getOutputStream) oosSource.flush() - oisSource = - new ObjectInputStream(clientSocketToSource.getInputStream) + oisSource = new ObjectInputStream(clientSocketToSource.getInputStream) logInfo("Inside receiveSingleTransmission") logInfo("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks) @@ -323,16 +241,13 @@ extends Broadcast[T] with Logging with Serializable { arrayOfBlocks(hasBlocks) = bcBlock hasBlocks += 1 + // Set to true if at least one block is received receptionSucceeded = true - hasBlocksLock.synchronized { - hasBlocksLock.notifyAll() - } + hasBlocksLock.synchronized { hasBlocksLock.notifyAll() } } } catch { - case e: Exception => { - logInfo("receiveSingleTransmission had a " + e) - } + case e: Exception => logInfo("receiveSingleTransmission had a " + e) } finally { if (oisSource != null) { oisSource.close() @@ -361,24 +276,22 @@ extends Broadcast[T] with Logging with Serializable { guidePort = serverSocket.getLocalPort logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort) - guidePortLock.synchronized { - guidePortLock.notifyAll() - } + guidePortLock.synchronized { guidePortLock.notifyAll() } try { - // Don't stop until there is a copy in HDFS - while (!stopBroadcast || !hasCopyInHDFS) { + while (!stopBroadcast) { var clientSocket: Socket = null try { - serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout) + serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) clientSocket = serverSocket.accept } catch { case e: Exception => { logInfo("GuideMultipleRequests Timeout.") // Stop broadcast if at least one worker has connected and - // everyone connected so far are done. Comparing with - // listOfSources.size - 1, because it includes the Guide itself + // everyone connected so far are done. + // Comparing with listOfSources.size - 1, because the Guide itself + // is included if (listOfSources.size > 1 && setOfCompletedSources.size == listOfSources.size - 1) { stopBroadcast = true @@ -399,14 +312,13 @@ extends Broadcast[T] with Logging with Serializable { logInfo("Sending stopBroadcast notifications...") sendStopBroadcastNotifications - TreeBroadcast.unregisterValue(uuid) + MultiTracker.unregisterBroadcast(uuid) } finally { if (serverSocket != null) { logInfo("GuideMultipleRequests now stopping...") serverSocket.close() } } - // Shutdown the thread pool threadPool.shutdown() } @@ -423,17 +335,13 @@ extends Broadcast[T] with Logging with Serializable { try { // Connect to the source - guideSocketToSource = - new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - gosSource = - new ObjectOutputStream(guideSocketToSource.getOutputStream) + guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) + gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream) gosSource.flush() - gisSource = - new ObjectInputStream(guideSocketToSource.getInputStream) + gisSource = new ObjectInputStream(guideSocketToSource.getInputStream) - // Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2 - gosSource.writeObject((SourceInfo.StopBroadcast, - SourceInfo.StopBroadcast)) + // Send stopBroadcast signal + gosSource.writeObject((SourceInfo.StopBroadcast, SourceInfo.StopBroadcast)) gosSource.flush() } catch { case e: Exception => { @@ -479,7 +387,7 @@ extends Broadcast[T] with Logging with Serializable { // Add this new (if it can finish) source to the list of sources thisWorkerInfo = SourceInfo(sourceInfo.hostAddress, - sourceInfo.listenPort, totalBlocks, totalBytes, blockSize) + sourceInfo.listenPort, totalBlocks, totalBytes) logInfo("Adding possible new source to listOfSources: " + thisWorkerInfo) listOfSources += thisWorkerInfo } @@ -492,9 +400,9 @@ extends Broadcast[T] with Logging with Serializable { // This should work since SourceInfo is a case class assert(listOfSources.contains(selectedSourceInfo)) - // Remove first + // Remove first + // (Currently removing a source based on just one failure notification!) listOfSources = listOfSources - selectedSourceInfo - // TODO: Removing a source based on just one failure notification! // Update sourceInfo and put it back in, IF reception succeeded if (!sourceInfo.receptionFailed) { @@ -503,17 +411,13 @@ extends Broadcast[T] with Logging with Serializable { setOfCompletedSources += thisWorkerInfo } + // Update leecher count and put it back in selectedSourceInfo.currentLeechers -= 1 - - // Put it back listOfSources += selectedSourceInfo } } } catch { - // If something went wrong, e.g., the worker at the other end died etc. - // then close() everything up case e: Exception => { - // Assuming that exception caused due to receiver worker failure. // Remove failed worker from listOfSources and update leecherCount of // corresponding source worker listOfSources.synchronized { @@ -538,21 +442,16 @@ extends Broadcast[T] with Logging with Serializable { } } - // FIXME: Caller must have a synchronized block on listOfSources - // FIXME: If a worker fails to get the broadcasted variable from a source - // and comes back to the Master, this function might choose the worker - // itself as a source to create a dependency cycle (this worker was put - // into listOfSources as a streming source when it first arrived). The - // length of this cycle can be arbitrarily long. + // Assuming the caller to have a synchronized block on listOfSources + // Select one with the most leechers. This will level-wise fill the tree private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = { - // Select one with the most leechers. This will level-wise fill the tree - var maxLeechers = -1 var selectedSource: SourceInfo = null listOfSources.foreach { source => - if (source != skipSourceInfo && - source.currentLeechers < Broadcast.MaxDegree && + if ((source.hostAddress != skipSourceInfo.hostAddress || + source.listenPort != skipSourceInfo.listenPort) && + source.currentLeechers < MultiTracker.MaxDegree && source.currentLeechers > maxLeechers) { selectedSource = source maxLeechers = source.currentLeechers @@ -561,7 +460,6 @@ extends Broadcast[T] with Logging with Serializable { // Update leecher count selectedSource.currentLeechers += 1 - return selectedSource } } @@ -569,35 +467,33 @@ extends Broadcast[T] with Logging with Serializable { class ServeMultipleRequests extends Thread with Logging { - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(0) + + var threadPool = Utils.newDaemonCachedThreadPool() + + override def run() { + var serverSocket = new ServerSocket(0) listenPort = serverSocket.getLocalPort + logInfo("ServeMultipleRequests started with " + serverSocket) - listenPortLock.synchronized { - listenPortLock.notifyAll() - } + listenPortLock.synchronized { listenPortLock.notifyAll() } try { while (!stopBroadcast) { var clientSocket: Socket = null try { - serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout) + serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) clientSocket = serverSocket.accept } catch { - case e: Exception => { - logInfo("ServeMultipleRequests Timeout.") - } + case e: Exception => logInfo("ServeMultipleRequests Timeout.") } + if (clientSocket != null) { logInfo("Serve: Accepted new client connection: " + clientSocket) try { threadPool.execute(new ServeSingleRequest(clientSocket)) } catch { - // In failure, close() socket here; else, the thread will close() it + // In failure, close socket here; else, the thread will close it case ioe: IOException => clientSocket.close() } } @@ -608,7 +504,6 @@ extends Broadcast[T] with Logging with Serializable { serverSocket.close() } } - // Shutdown the thread pool threadPool.shutdown() } @@ -631,19 +526,14 @@ extends Broadcast[T] with Logging with Serializable { sendFrom = rangeToSend._1 sendUntil = rangeToSend._2 - if (sendFrom == SourceInfo.StopBroadcast && - sendUntil == SourceInfo.StopBroadcast) { + // If not a valid range, stop broadcast + if (sendFrom == SourceInfo.StopBroadcast && sendUntil == SourceInfo.StopBroadcast) { stopBroadcast = true } else { - // Carry on sendObject } } catch { - // If something went wrong, e.g., the worker at the other end died etc. - // then close() everything up - case e: Exception => { - logInfo("ServeSingleRequest had a " + e) - } + case e: Exception => logInfo("ServeSingleRequest had a " + e) } finally { logInfo("ServeSingleRequest is closing streams and sockets") ois.close() @@ -655,24 +545,18 @@ extends Broadcast[T] with Logging with Serializable { private def sendObject() { // Wait till receiving the SourceInfo from Master while (totalBlocks == -1) { - totalBlocksLock.synchronized { - totalBlocksLock.wait() - } + totalBlocksLock.synchronized { totalBlocksLock.wait() } } for (i <- sendFrom until sendUntil) { while (i == hasBlocks) { - hasBlocksLock.synchronized { - hasBlocksLock.wait() - } + hasBlocksLock.synchronized { hasBlocksLock.wait() } } try { oos.writeObject(arrayOfBlocks(i)) oos.flush() } catch { - case e: Exception => { - logInfo("sendObject had a " + e) - } + case e: Exception => logInfo("sendObject had a " + e) } logInfo("Sent block: " + i + " to " + clientSocket) } @@ -683,124 +567,7 @@ extends Broadcast[T] with Logging with Serializable { class TreeBroadcastFactory extends BroadcastFactory { - def initialize(isMaster: Boolean) = TreeBroadcast.initialize(isMaster) - def newBroadcast[T](value_ : T, isLocal: Boolean) = - new TreeBroadcast[T](value_, isLocal) -} - -private object TreeBroadcast -extends Logging { - val values = SparkEnv.get.cache.newKeySpace() - - var valueToGuidePortMap = Map[UUID, Int]() - - // Random number generator - var ranGen = new Random - - private var initialized = false - private var isMaster_ = false - - private var trackMV: TrackMultipleValues = null - - private var MaxDegree_ : Int = 2 - - def initialize(isMaster__ : Boolean) { - synchronized { - if (!initialized) { - isMaster_ = isMaster__ - - if (isMaster) { - trackMV = new TrackMultipleValues - trackMV.setDaemon(true) - trackMV.start() - // TODO: Logging the following line makes the Spark framework ID not - // getting logged, cause it calls logInfo before log4j is initialized - logInfo("TrackMultipleValues started...") - } - - // Initialize DfsBroadcast to be used for broadcast variable persistence - DfsBroadcast.initialize - - initialized = true - } - } - } - - def isMaster = isMaster_ - - def registerValue(uuid: UUID, guidePort: Int) { - valueToGuidePortMap.synchronized { - valueToGuidePortMap += (uuid -> guidePort) - logInfo("New value registered with the Tracker " + valueToGuidePortMap) - } - } - - def unregisterValue(uuid: UUID) { - valueToGuidePortMap.synchronized { - valueToGuidePortMap(uuid) = SourceInfo.TxOverGoToHDFS - logInfo("Value unregistered from the Tracker " + valueToGuidePortMap) - } - } - - class TrackMultipleValues - extends Thread with Logging { - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(Broadcast.MasterTrackerPort) - logInfo("TrackMultipleValues" + serverSocket) - - try { - while (true) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(Broadcast.TrackerSocketTimeout) - clientSocket = serverSocket.accept - } catch { - case e: Exception => { - logInfo("TrackMultipleValues Timeout. Stopping listening...") - } - } - - if (clientSocket != null) { - try { - threadPool.execute(new Thread { - override def run() { - val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - val ois = new ObjectInputStream(clientSocket.getInputStream) - try { - val uuid = ois.readObject.asInstanceOf[UUID] - var guidePort = - if (valueToGuidePortMap.contains(uuid)) { - valueToGuidePortMap(uuid) - } else SourceInfo.TxNotStartedRetry - logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort) - oos.writeObject(guidePort) - } catch { - case e: Exception => { - logInfo("TrackMultipleValues had a " + e) - } - } finally { - ois.close() - oos.close() - clientSocket.close() - } - } - }) - } catch { - // In failure, close() socket here; else, client thread will close() - case ioe: IOException => clientSocket.close() - } - } - } - } finally { - serverSocket.close() - } - - // Shutdown the thread pool - threadPool.shutdown() - } - } + def initialize(isMaster: Boolean) = MultiTracker.initialize(isMaster) + def newBroadcast[T](value_ : T, isLocal: Boolean) = new TreeBroadcast[T](value_, isLocal) + def stop() = MultiTracker.stop } |