diff options
author | Mosharaf Chowdhury <mosharaf@mosharaf-ubuntu.(none)> | 2010-11-04 22:09:14 -0700 |
---|---|---|
committer | Mosharaf Chowdhury <mosharaf@mosharaf-ubuntu.(none)> | 2010-11-04 22:09:14 -0700 |
commit | 878d157ce3f3b625b009ca53cf53c2f2481d2698 (patch) | |
tree | 642a717223bc4189304ecab0dbb48241ac40dc32 | |
parent | 10fc66b1c4d13af0fab4a964b133a6c9c02b272b (diff) | |
download | spark-878d157ce3f3b625b009ca53cf53c2f2481d2698.tar.gz spark-878d157ce3f3b625b009ca53cf53c2f2481d2698.tar.bz2 spark-878d157ce3f3b625b009ca53cf53c2f2481d2698.zip |
Graceful shutdown after a single transmission in the swarm is over.
There might still be a problem with the Tracker shutdown. It must be done explicitly by SparkContext.
-rw-r--r-- | conf/java-opts | 2 | ||||
-rw-r--r-- | src/scala/spark/Broadcast.scala | 237 |
2 files changed, 161 insertions, 78 deletions
diff --git a/conf/java-opts b/conf/java-opts index 39bdbb77bb..c3c7a9c0e3 100644 --- a/conf/java-opts +++ b/conf/java-opts @@ -1 +1 @@ --Dspark.broadcast.MasterHostAddress=127.0.0.1 -Dspark.broadcast.MasterTrackerPort=11111 -Dspark.broadcast.BlockSize=256 -Dspark.broadcast.MaxRetryCount=2 -Dspark.broadcast.ServerSocketTimout=50000 -Dspark.broadcast.MaxChatTime=500 +-Dspark.broadcast.MasterHostAddress=127.0.0.1 -Dspark.broadcast.MasterTrackerPort=11111 -Dspark.broadcast.BlockSize=256 -Dspark.broadcast.MaxRetryCount=2 -Dspark.broadcast.TrackerSocketTimeout=50000 -Dspark.broadcast.ServerSocketTimout=10000 -Dspark.broadcast.MaxChatTime=500 diff --git a/src/scala/spark/Broadcast.scala b/src/scala/spark/Broadcast.scala index e0934e6a6b..62b34d29a3 100644 --- a/src/scala/spark/Broadcast.scala +++ b/src/scala/spark/Broadcast.scala @@ -8,7 +8,7 @@ import com.google.common.collect.MapMaker import java.util.concurrent.{Executors, ExecutorService, ThreadPoolExecutor} -import scala.collection.mutable.{ListBuffer, Map} +import scala.collection.mutable.{ListBuffer, Map, Set} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem} @@ -66,6 +66,7 @@ extends BroadcastRecipe with Logging { @transient var guidePort = -1 @transient var hasCopyInHDFS = false + @transient var stopBroadcast = false // Must call this after all the variables have been created/initialized if (!local) { @@ -77,7 +78,8 @@ extends BroadcastRecipe with Logging { // TODO: Turned OFF for now // val out = new ObjectOutputStream (BroadcastCH.openFileForWriting(uuid)) // out.writeObject (value_) - // out.close + // out.close + // TODO: Fix this at some point hasCopyInHDFS = true // Create a variableInfo object and store it in valueInfos @@ -203,6 +205,8 @@ extends BroadcastRecipe with Logging { listenPort = -1 listOfSources = ListBuffer[SourceInfo] () + + stopBroadcast = false } private def blockifyObject (obj: T, blockSize: Int): VariableInfo = { @@ -270,6 +274,8 @@ extends BroadcastRecipe with Logging { var localSourceInfo = SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes) + + localSourceInfo.hasBlocks = hasBlocks hasBlocksBitVector.synchronized { localSourceInfo.hasBlocksBitVector = hasBlocksBitVector @@ -299,40 +305,48 @@ extends BroadcastRecipe with Logging { } } - class TalkToGuide (gInfo: SourceInfo) + class TalkToGuide (gInfo: SourceInfo) extends Thread with Logging { override def run = { - // Connect to Guide and send this worker's information - var clientSocketToGuide: Socket = null - var oosGuide: ObjectOutputStream = null - var oisGuide: ObjectInputStream = null + // Keep exchaning information until all blocks have been received while (hasBlocks < totalBlocks) { - clientSocketToGuide = new Socket(gInfo.hostAddress, gInfo.listenPort) - oosGuide = new ObjectOutputStream (clientSocketToGuide.getOutputStream) - oosGuide.flush - oisGuide = new ObjectInputStream (clientSocketToGuide.getInputStream) - - // Send local information - oosGuide.writeObject(getLocalSourceInfo) - oosGuide.flush - - // Receive source information from Guide - var suitableSources = - oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]] - logInfo("Received suitableSources from Master " + suitableSources) - - addToListOfSources (suitableSources) - - oisGuide.close - oosGuide.close - clientSocketToGuide.close - + talkOnce Thread.sleep ( BroadcastBT.ranGen.nextInt ( BroadcastBT.MaxKnockInterval - BroadcastBT.MinKnockInterval) + BroadcastBT.MinKnockInterval) - } - } + } + + // Talk one more time to let the Guide know of reception completion + talkOnce + } + + // Connect to Guide and send this worker's information + private def talkOnce = { + var clientSocketToGuide: Socket = null + var oosGuide: ObjectOutputStream = null + var oisGuide: ObjectInputStream = null + + clientSocketToGuide = new Socket(gInfo.hostAddress, gInfo.listenPort) + oosGuide = new ObjectOutputStream (clientSocketToGuide.getOutputStream) + oosGuide.flush + oisGuide = new ObjectInputStream (clientSocketToGuide.getInputStream) + + // Send local information + oosGuide.writeObject(getLocalSourceInfo) + oosGuide.flush + + // Receive source information from Guide + var suitableSources = + oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]] + logInfo("Received suitableSources from Master " + suitableSources) + + addToListOfSources (suitableSources) + + oisGuide.close + oosGuide.close + clientSocketToGuide.close + } } def getGuideInfo (variableUUID: UUID): SourceInfo = { @@ -427,7 +441,7 @@ extends BroadcastRecipe with Logging { // TODO: Must fix this. This might never break if broadcast fails. // We should be able to break and send false. Also need to kill threads - while (hasBlocks != totalBlocks) { + while (hasBlocks < totalBlocks) { Thread.sleep(1234) } @@ -448,7 +462,7 @@ extends BroadcastRecipe with Logging { Math.min (listOfSources.size, BroadcastBT.MaxTxPeers) - threadPool.getActiveCount - while(numThreadsToCreate > 0 && hasBlocks < totalBlocks) { + while (hasBlocks < totalBlocks && numThreadsToCreate > 0) { var peerToTalkTo = pickPeerToTalkTo if (peerToTalkTo != null) { threadPool.execute (new TalkToPeer (peerToTalkTo)) @@ -464,8 +478,11 @@ extends BroadcastRecipe with Logging { } // Sleep for a while before starting some more threads + // TODO: Whats up with this? Thread.sleep (500) } + // Shutdown the thread pool + threadPool.shutdown } // TODO: Right now picking the one that has the most blocks this peer wants @@ -510,7 +527,6 @@ extends BroadcastRecipe with Logging { private var oisSource: ObjectInputStream = null override def run = { - // Setup the timeout mechanism var timeOutTask = new TimerTask { override def run = { @@ -616,6 +632,9 @@ extends BroadcastRecipe with Logging { class GuideMultipleRequests extends Thread with Logging { + // Keep track of sources that have completed reception + private var setOfCompletedSources = Set[SourceInfo] () + override def run = { // TODO: Cached threadpool has 60s keep alive timer var threadPool = Executors.newCachedThreadPool @@ -629,18 +648,24 @@ extends BroadcastRecipe with Logging { guidePortLock.notifyAll } - var keepAccepting = true try { // Don't stop until there is a copy in HDFS - while (keepAccepting || !hasCopyInHDFS) { + while (!stopBroadcast || !hasCopyInHDFS) { var clientSocket: Socket = null try { - serverSocket.setSoTimeout (BroadcastBT.ServerSocketTimout) + serverSocket.setSoTimeout (BroadcastBT.ServerSocketTimeout) clientSocket = serverSocket.accept } catch { - case e: Exception => { - logInfo ("GuideMultipleRequests Timeout. Stopping listening..." + hasCopyInHDFS) - keepAccepting = false + case e: Exception => { + logInfo ("GuideMultipleRequests Timeout.") + + // Stop broadcast if at least one worker has connected and + // everyone connected so far are done. Comparing with + // listOfSources.size - 1, because it includes the Guide itself + if (listOfSources.size > 1 && + setOfCompletedSources.size == listOfSources.size - 1) { + stopBroadcast = true + } } } if (clientSocket != null) { @@ -655,9 +680,46 @@ extends BroadcastRecipe with Logging { } } } + + // Shutdown the thread pool + threadPool.shutdown + + logInfo ("Sending stopBroadcast notifications...") + sendStopBroadcastNotifications + BroadcastBT.unregisterValue (uuid) } finally { - serverSocket.close + if (serverSocket != null) { + logInfo ("GuideMultipleRequests now stopping...") + serverSocket.close + } + } + } + + private def sendStopBroadcastNotifications = { + listOfSources.synchronized { + listOfSources.foreach { sourceInfo => + // Connect to the source + var guideSocketToSource = + new Socket (sourceInfo.hostAddress, sourceInfo.listenPort) + var gosSource = + new ObjectOutputStream (guideSocketToSource.getOutputStream) + gosSource.flush + var gisSource = + new ObjectInputStream (guideSocketToSource.getInputStream) + + // Throw away whatever comes in + gisSource.readObject.asInstanceOf[SourceInfo] + + // Sent stopBroadcast signal. listenPort = SourceInfo.StopBroadcast + gosSource.writeObject(SourceInfo("", SourceInfo.StopBroadcast, + SourceInfo.UnusedParam, SourceInfo.UnusedParam)) + gosSource.flush + + gisSource.close + gosSource.close + guideSocketToSource.close + } } } @@ -670,6 +732,7 @@ extends BroadcastRecipe with Logging { private var sourceInfo: SourceInfo = null private var selectedSources: ListBuffer[SourceInfo] = null + // Used to select a rolling window of peers from listOfSources private var rollOverIndex = 0 override def run = { @@ -689,7 +752,7 @@ extends BroadcastRecipe with Logging { } catch { case e: Exception => { // Assuming exception caused by receiver failure: remove - if (listOfSources != null) { + if (listOfSources != null) { listOfSources.synchronized { listOfSources = listOfSources - sourceInfo } @@ -703,11 +766,20 @@ extends BroadcastRecipe with Logging { } // TODO: Randomly select some sources to send back. - // Right now just rolls over the listOfSources to send back + // Right now just rolls over the listOfSources to send back // BroadcastBT.MaxPeersInGuideResponse number of possible sources - private def selectSuitableSources(skipSourceInfo: SourceInfo): ListBuffer[SourceInfo] = { - var curIndex = rollOverIndex + private def selectSuitableSources(skipSourceInfo: SourceInfo): ListBuffer[SourceInfo] = { var selectedSources = ListBuffer[SourceInfo] () + + // If skipSourceInfo.hasBlocksBitVector has all bits set to 'true' + // then add skipSourceInfo to setOfCompletedSources. Return blank. + if (skipSourceInfo.hasBlocks == totalBlocks) { + setOfCompletedSources += skipSourceInfo + return selectedSources + } + + var curIndex = rollOverIndex + listOfSources.synchronized { do { if (listOfSources(curIndex) != skipSourceInfo) { @@ -726,12 +798,10 @@ extends BroadcastRecipe with Logging { class ServeMultipleRequests extends Thread with Logging { override def run = { - // TODO: Look into ExecutorService shutdown and shutdownNow methods // TODO: Not sure if this will be able to fix the number of outgoing links // We should have a timeout mechanism on the receiver side var threadPool = - Executors.newFixedThreadPool( - BroadcastBT.MaxRxPeers).asInstanceOf[ThreadPoolExecutor] + Executors.newFixedThreadPool(BroadcastBT.MaxRxPeers) var serverSocket = new ServerSocket (0) listenPort = serverSocket.getLocalPort @@ -742,17 +812,15 @@ extends BroadcastRecipe with Logging { listenPortLock.notifyAll } - var keepAccepting = true try { - while (keepAccepting) { + while (!stopBroadcast) { var clientSocket: Socket = null try { - serverSocket.setSoTimeout (BroadcastBT.ServerSocketTimout) + serverSocket.setSoTimeout (BroadcastBT.ServerSocketTimeout) clientSocket = serverSocket.accept } catch { case e: Exception => { - logInfo ("ServeMultipleRequests Timeout. Stopping listening...") - keepAccepting = false + logInfo ("ServeMultipleRequests Timeout.") } } if (clientSocket != null) { @@ -768,10 +836,13 @@ extends BroadcastRecipe with Logging { } } } finally { - if (serverSocket != null) { + if (serverSocket != null) { + logInfo ("ServeMultipleRequests now stopping...") serverSocket.close } - } + } + // Shutdown the thread pool + threadPool.shutdown } class ServeSingleRequest (val clientSocket: Socket) @@ -791,18 +862,22 @@ extends BroadcastRecipe with Logging { oos.flush // Receive latest SourceInfo from the receiver - var rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo] + var rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo] logInfo("rxSourceInfo: " + rxSourceInfo + " with " + rxSourceInfo.hasBlocksBitVector) - addToListOfSources (rxSourceInfo) - // TODO: NOT the most efficient way to do time-based break; - // but using timer can cause a break in the middle :-S + if (rxSourceInfo.listenPort == SourceInfo.StopBroadcast) { + stopBroadcast = true + } else { + // Carry on + addToListOfSources (rxSourceInfo) + } + val startTime = System.currentTimeMillis var curTime = startTime var keepSending = true var blocksToSend = BroadcastBT.MaxChatBlocks - while (keepSending && blocksToSend > 0 && + while (!stopBroadcast && keepSending && blocksToSend > 0 && (curTime - startTime) < BroadcastBT.MaxChatTime) { val sentBlock = pickAndSendBlock (rxSourceInfo.hasBlocksBitVector) if (sentBlock < 0) { @@ -930,6 +1005,7 @@ case class SourceInfo (val hostAddress: String, val listenPort: Int, var currentLeechers = 0 var receptionFailed = false + var hasBlocks = 0 var hasBlocksBitVector: BitSet = new BitSet (totalBlocks) } @@ -938,6 +1014,7 @@ object SourceInfo { val TxNotStartedRetry = -1 val TxOverGoToHDFS = 0 // Other constants + val StopBroadcast = -2 val UnusedParam = 0 } @@ -1021,7 +1098,9 @@ extends Logging { private var MasterTrackerPort_ : Int = 11111 private var BlockSize_ : Int = 512 * 1024 private var MaxRetryCount_ : Int = 2 - private var ServerSocketTimout_ : Int = 50000 + + private var TrackerSocketTimeout_ : Int = 50000 + private var ServerSocketTimeout_ : Int = 10000 private var trackMV: TrackMultipleValues = null @@ -1055,8 +1134,11 @@ extends Logging { System.getProperty ("spark.broadcast.BlockSize", "512").toInt * 1024 MaxRetryCount_ = System.getProperty ("spark.broadcast.MaxRetryCount", "2").toInt - ServerSocketTimout_ = - System.getProperty ("spark.broadcast.ServerSocketTimout", "50000").toInt + + TrackerSocketTimeout_ = + System.getProperty ("spark.broadcast.TrackerSocketTimeout", "50000").toInt + ServerSocketTimeout_ = + System.getProperty ("spark.broadcast.ServerSocketTimeout", "10000").toInt MinKnockInterval_ = System.getProperty ("spark.broadcast.MinKnockInterval", "500").toInt @@ -1094,7 +1176,9 @@ extends Logging { def MasterTrackerPort = MasterTrackerPort_ def BlockSize = BlockSize_ def MaxRetryCount = MaxRetryCount_ - def ServerSocketTimout = ServerSocketTimout_ + + def TrackerSocketTimeout = TrackerSocketTimeout_ + def ServerSocketTimeout = ServerSocketTimeout_ def isMaster = isMaster_ @@ -1122,41 +1206,38 @@ extends Logging { valueToGuideMap.synchronized { valueToGuideMap (uuid) = SourceInfo ("", SourceInfo.TxOverGoToHDFS, SourceInfo.UnusedParam, SourceInfo.UnusedParam) - logInfo ("Value unregistered from the Tracker " + valueToGuideMap) + logInfo ("Value unregistered from the Tracker " + valueToGuideMap) } } -// def startMultiTracker -// def stopMultiTracker - class TrackMultipleValues extends Thread with Logging { - var keepAccepting = true + var stopTracker = false override def run = { var threadPool = Executors.newCachedThreadPool var serverSocket: ServerSocket = null serverSocket = new ServerSocket (BroadcastBT.MasterTrackerPort) - logInfo ("TrackMultipleValues" + serverSocket) + logInfo ("TrackMultipleValues" + serverSocket) try { - while (keepAccepting) { + while (!stopTracker) { var clientSocket: Socket = null try { // TODO: - serverSocket.setSoTimeout (ServerSocketTimout) + serverSocket.setSoTimeout (TrackerSocketTimeout) clientSocket = serverSocket.accept } catch { - case e: Exception => { + case e: Exception => { logInfo ("TrackMultipleValues Timeout. Stopping listening...") // TODO: Tracking should be explicitly stopped by the SparkContext - keepAccepting = false + stopTracker = true } } if (clientSocket != null) { - try { + try { threadPool.execute (new Thread { override def run = { val oos = new ObjectOutputStream (clientSocket.getOutputStream) @@ -1164,12 +1245,12 @@ extends Logging { val ois = new ObjectInputStream (clientSocket.getInputStream) try { val uuid = ois.readObject.asInstanceOf[UUID] - var gInfo = + var gInfo = if (valueToGuideMap.contains (uuid)) { valueToGuideMap (uuid) - } else SourceInfo ("", SourceInfo.TxNotStartedRetry, + } else SourceInfo ("", SourceInfo.TxNotStartedRetry, SourceInfo.UnusedParam, SourceInfo.UnusedParam) - logInfo ("TrackMultipleValues:Got new request: " + clientSocket + " for " + uuid + " : " + gInfo.listenPort) + logInfo ("TrackMultipleValues:Got new request: " + clientSocket + " for " + uuid + " : " + gInfo.listenPort) oos.writeObject (gInfo) } catch { case e: Exception => { } @@ -1190,7 +1271,9 @@ extends Logging { } } finally { serverSocket.close - } + } + // Shutdown the thread pool + threadPool.shutdown } } } |