From 49a2db09fb4b8cf2b2f207ad90a6c492f6952e87 Mon Sep 17 00:00:00 2001 From: Mosharaf Chowdhury Date: Sat, 6 Nov 2010 19:27:46 -0700 Subject: Graceful shutdown is working with dualMode=false. Probably will have to remove dualMode completely. Made BroadcastCS code more consistent with BT branches. --- src/scala/spark/Broadcast.scala | 223 ++++++++++++++++++++++++++++++---------- 1 file changed, 166 insertions(+), 57 deletions(-) diff --git a/src/scala/spark/Broadcast.scala b/src/scala/spark/Broadcast.scala index b5bcdde21b..fc83803350 100644 --- a/src/scala/spark/Broadcast.scala +++ b/src/scala/spark/Broadcast.scala @@ -2,13 +2,13 @@ package spark import java.io._ import java.net._ -import java.util.{Comparator, PriorityQueue, UUID} +import java.util.{Comparator, PriorityQueue, Random, UUID} import com.google.common.collect.MapMaker import java.util.concurrent.{Executors, ExecutorService, ThreadPoolExecutor} -import scala.collection.mutable.Map +import scala.collection.mutable.{Map, Set} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem} @@ -67,6 +67,7 @@ extends BroadcastRecipe with Logging { @transient var guidePort = -1 @transient var hasCopyInHDFS = false + @transient var stopBroadcast = false // Must call this after all the variables have been created/initialized if (!local) { @@ -110,12 +111,12 @@ extends BroadcastRecipe with Logging { pqOfSources = new PriorityQueue[SourceInfo] val masterSource_0 = - new SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes, 0) + SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes, 0) pqOfSources.add (masterSource_0) // Add one more time to have two replicas of any seeds in the PQ if (BroadcastCS.DualMode) { val masterSource_1 = - new SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes, 1) + SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes, 1) pqOfSources.add (masterSource_1) } @@ -171,12 +172,17 @@ extends BroadcastRecipe with Logging { totalBytes = -1 totalBlocks = -1 hasBlocks = 0 + listenPortLock = new Object totalBlocksLock = new Object hasBlocksLock = new Object + serveMR = null + hostAddress = InetAddress.getLocalHost.getHostAddress listenPort = -1 + + stopBroadcast = false } private def blockifyObject (obj: T, blockSize: Int): VariableInfo = { @@ -227,16 +233,12 @@ extends BroadcastRecipe with Logging { return retVal } - // masterListenPort aka guidePort value legend - // 0 = missed the broadcast, read from HDFS; - // <0 = hasn't started yet, wait & retry; - // >0 = Read from this port def getMasterListenPort (variableUUID: UUID): Int = { var clientSocketToTracker: Socket = null var oosTracker: ObjectOutputStream = null var oisTracker: ObjectInputStream = null - var masterListenPort: Int = -1 + var masterListenPort: Int = SourceInfo.TxOverGoToHDFS var retriesLeft = BroadcastCS.MaxRetryCount do { @@ -255,8 +257,7 @@ extends BroadcastRecipe with Logging { oosTracker.flush masterListenPort = oisTracker.readObject.asInstanceOf[Int] } catch { - // In case of any failure, set masterListenPort = 0 to read from HDFS - case e: Exception => (masterListenPort = 0) + case e: Exception => { } } finally { if (oisTracker != null) { oisTracker.close @@ -269,18 +270,25 @@ extends BroadcastRecipe with Logging { } } retriesLeft -= 1 - // TODO: Should wait before retrying - } while (retriesLeft > 0 && masterListenPort < 0) + + Thread.sleep (BroadcastCS.ranGen.nextInt ( + BroadcastCS.MaxKnockInterval - BroadcastCS.MinKnockInterval) + + BroadcastCS.MinKnockInterval) + + } while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry) + logInfo ("Got this guidePort from Tracker: " + masterListenPort) return masterListenPort } def receiveBroadcast (variableUUID: UUID): Boolean = { - // Get masterListenPort for this variable from the Tracker val masterListenPort = getMasterListenPort (variableUUID) - // If Tracker says that there is no guide for this object, read from HDFS - if (masterListenPort == 0) { - return false + + if (masterListenPort == SourceInfo.TxOverGoToHDFS || + masterListenPort == SourceInfo.TxNotStartedRetry) { + // TODO: SourceInfo.TxNotStartedRetry is not really in use because we go + // to HDFS anyway when receiveBroadcast returns false + return false } // Wait until hostAddress and listenPort are created by the @@ -291,22 +299,28 @@ extends BroadcastRecipe with Logging { } } + var clientSocketToMaster: Socket = null + var oosMaster: ObjectOutputStream = null + var oisMaster: ObjectInputStream = null + // Connect and receive broadcast from the specified source, retrying the // specified number of times in case of failures var retriesLeft = BroadcastCS.MaxRetryCount do { // Connect to Master and send this worker's Information - val clientSocketToMaster = + clientSocketToMaster = new Socket(BroadcastCS.MasterHostAddress, masterListenPort) - logInfo ("Connected to Master's guiding object") // TODO: Guiding object connection is reusable - val oosMaster = + oosMaster = new ObjectOutputStream (clientSocketToMaster.getOutputStream) oosMaster.flush - val oisMaster = + oisMaster = new ObjectInputStream (clientSocketToMaster.getInputStream) - oosMaster.writeObject(new SourceInfo (hostAddress, listenPort, -1, -1, 0)) + logInfo ("Connected to Master's guiding object") + + // Send local source information + oosMaster.writeObject(SourceInfo (hostAddress, listenPort, -1, -1, 0)) oosMaster.flush // Receive source information from Master @@ -333,12 +347,18 @@ extends BroadcastRecipe with Logging { // Send back statistics to the Master oosMaster.writeObject (sourceInfo) - oisMaster.close - oosMaster.close - clientSocketToMaster.close + if (oisMaster != null) { + oisMaster.close + } + if (oosMaster != null) { + oosMaster.close + } + if (clientSocketToMaster != null) { + clientSocketToMaster.close + } retriesLeft -= 1 - } while (retriesLeft > 0 && hasBlocks != totalBlocks) + } while (retriesLeft > 0 && hasBlocks < totalBlocks) return (hasBlocks == totalBlocks) } @@ -383,7 +403,6 @@ extends BroadcastRecipe with Logging { hasBlocksLock.notifyAll } } - logInfo ("After the receive loop...") } catch { case e: Exception => { logInfo ("receiveSingleTransmission had a " + e) @@ -405,8 +424,10 @@ extends BroadcastRecipe with Logging { class GuideMultipleRequests extends Thread with Logging { + // Keep track of sources that have completed reception + private var setOfCompletedSources = Set[SourceInfo] () + override def run = { - // TODO: Cached threadpool has 60 s keep alive timer var threadPool = Executors.newCachedThreadPool var serverSocket: ServerSocket = null @@ -418,10 +439,9 @@ extends BroadcastRecipe with Logging { guidePortLock.notifyAll } - var keepAccepting = true try { // Don't stop until there is a copy in HDFS - while (keepAccepting || !hasCopyInHDFS) { + while (!stopBroadcast || !hasCopyInHDFS) { var clientSocket: Socket = null try { serverSocket.setSoTimeout (BroadcastCS.ServerSocketTimeout) @@ -429,7 +449,14 @@ extends BroadcastRecipe with Logging { } catch { case e: Exception => { logInfo ("GuideMultipleRequests Timeout.") - keepAccepting = false + + // Stop broadcast if at least one worker has connected and + // everyone connected so far are done. Comparing with + // pqOfSources.size - 1, because it includes the Guide itself + if (pqOfSources.size > 1 && + setOfCompletedSources.size == pqOfSources.size - 1) { + stopBroadcast = true + } } } if (clientSocket != null) { @@ -442,6 +469,10 @@ extends BroadcastRecipe with Logging { } } } + + logInfo ("Sending stopBroadcast notifications...") + sendStopBroadcastNotifications + BroadcastCS.unregisterValue (uuid) } finally { if (serverSocket != null) { @@ -449,8 +480,52 @@ extends BroadcastRecipe with Logging { serverSocket.close } } + + // Shutdown the thread pool + threadPool.shutdown } + private def sendStopBroadcastNotifications = { + pqOfSources.synchronized { + var pqIter = pqOfSources.iterator + while (pqIter.hasNext) { + var sourceInfo = pqIter.next + + var guideSocketToSource: Socket = null + var gosSource: ObjectOutputStream = null + var gisSource: ObjectInputStream = null + + try { + // Connect to the source + guideSocketToSource = + new Socket (sourceInfo.hostAddress, sourceInfo.listenPort) + gosSource = + new ObjectOutputStream (guideSocketToSource.getOutputStream) + gosSource.flush + gisSource = + new ObjectInputStream (guideSocketToSource.getInputStream) + + // Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2 + gosSource.writeObject ((SourceInfo.StopBroadcast, + SourceInfo.StopBroadcast)) + gosSource.flush + } catch { + case e: Exception => { } + } finally { + if (gisSource != null) { + gisSource.close + } + if (gosSource != null) { + gosSource.close + } + if (guideSocketToSource != null) { + guideSocketToSource.close + } + } + } + } + } + class GuideSingleRequest (val clientSocket: Socket) extends Thread with Logging { private val oos = new ObjectOutputStream (clientSocket.getOutputStream) @@ -470,13 +545,13 @@ extends BroadcastRecipe with Logging { pqOfSources.synchronized { // Select a suitable source and send it back to the worker selectedSourceInfo = selectSuitableSource (sourceInfo) - logInfo ("Sending selectedSourceInfo:" + selectedSourceInfo) + logInfo ("Sending selectedSourceInfo: " + selectedSourceInfo) oos.writeObject (selectedSourceInfo) oos.flush // Add this new (if it can finish) source to the PQ of sources - thisWorkerInfo = new SourceInfo(sourceInfo.hostAddress, - sourceInfo.listenPort, totalBlocks, totalBytes, 0) + thisWorkerInfo = SourceInfo (sourceInfo.hostAddress, + sourceInfo.listenPort, totalBlocks, totalBytes, 0) logInfo ("Adding possible new source to pqOfSources: " + thisWorkerInfo) pqOfSources.add (thisWorkerInfo) } @@ -490,10 +565,14 @@ extends BroadcastRecipe with Logging { assert (pqOfSources.contains (selectedSourceInfo)) // Remove first - pqOfSources.remove (selectedSourceInfo) + pqOfSources.remove (selectedSourceInfo) // TODO: Removing a source based on just one failure notification! + // Update sourceInfo and put it back in, IF reception succeeded - if (!sourceInfo.receptionFailed) { + if (!sourceInfo.receptionFailed) { + // Add thisWorkerInfo to sources that have completed reception + setOfCompletedSources += thisWorkerInfo + selectedSourceInfo.currentLeechers -= 1 selectedSourceInfo.MBps = sourceInfo.MBps @@ -506,7 +585,7 @@ extends BroadcastRecipe with Logging { // No need to find and update thisWorkerInfo, but add its replica if (BroadcastCS.DualMode) { - pqOfSources.add (new SourceInfo (thisWorkerInfo.hostAddress, + pqOfSources.add (SourceInfo (thisWorkerInfo.hostAddress, thisWorkerInfo.listenPort, totalBlocks, totalBytes, 1)) } } @@ -574,9 +653,8 @@ extends BroadcastRecipe with Logging { listenPortLock.notifyAll } - var keepAccepting = true try { - while (keepAccepting) { + while (!stopBroadcast) { var clientSocket: Socket = null try { serverSocket.setSoTimeout (BroadcastCS.ServerSocketTimeout) @@ -584,7 +662,6 @@ extends BroadcastRecipe with Logging { } catch { case e: Exception => { logInfo ("ServeMultipleRequests Timeout.") - keepAccepting = false } } if (clientSocket != null) { @@ -603,6 +680,9 @@ extends BroadcastRecipe with Logging { serverSocket.close } } + + // Shutdown the thread pool + threadPool.shutdown } class ServeSingleRequest (val clientSocket: Socket) @@ -619,11 +699,17 @@ extends BroadcastRecipe with Logging { logInfo ("new ServeSingleRequest is running") // Receive range to send - var sendRange = ois.readObject.asInstanceOf[(Int, Int)] - sendFrom = sendRange._1 - sendUntil = sendRange._2 + var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)] + sendFrom = rangeToSend._1 + sendUntil = rangeToSend._2 - sendObject + if (sendFrom == SourceInfo.StopBroadcast && + sendUntil == SourceInfo.StopBroadcast) { + stopBroadcast = true + } else { + // Carry on + sendObject + } } catch { // TODO: Need to add better exception handling here // If something went wrong, e.g., the worker at the other end died etc. @@ -785,6 +871,8 @@ extends Comparable [SourceInfo] with Logging { var receptionFailed = false var MBps: Double = BroadcastCS.MaxMBps + var hasBlocks = 0 + // Ascending sort based on leecher count def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers) @@ -806,6 +894,15 @@ extends Comparable [SourceInfo] with Logging { // } } +object SourceInfo { + // Constants for special values of listenPort + val TxNotStartedRetry = -1 + val TxOverGoToHDFS = 0 + // Other constants + val StopBroadcast = -2 + val UnusedParam = 0 +} + @serializable case class BroadcastBlock (val blockID: Int, val byteArray: Array[Byte]) { } @@ -847,6 +944,9 @@ extends Logging { var sourceToSpeedMap = Map[String, Double] () + // Random number generator + var ranGen = new Random + private var initialized = false private var isMaster_ = false @@ -867,6 +967,9 @@ extends Logging { // 125.0 MBps = 1 Gbps link private val MaxMBps_ = 125.0 + private var MinKnockInterval_ = 500 + private var MaxKnockInterval_ = 999 + def initialize (isMaster__ : Boolean) { synchronized { if (!initialized) { @@ -884,6 +987,11 @@ extends Logging { ServerSocketTimeout_ = System.getProperty ("spark.broadcast.ServerSocketTimeout", "10000").toInt + MinKnockInterval_ = + System.getProperty ("spark.broadcast.MinKnockInterval", "500").toInt + MaxKnockInterval_ = + System.getProperty ("spark.broadcast.MaxKnockInterval", "999").toInt + DualMode_ = System.getProperty ("spark.broadcast.DualMode", "false").toBoolean @@ -913,6 +1021,9 @@ extends Logging { def isMaster = isMaster_ + def MinKnockInterval = MinKnockInterval_ + def MaxKnockInterval = MaxKnockInterval_ + def MaxMBps = MaxMBps_ def registerValue (uuid: UUID, guidePort: Int) = { @@ -924,8 +1035,7 @@ extends Logging { def unregisterValue (uuid: UUID) = { valueToGuidePortMap.synchronized { - // Set to 0 to make sure that people read it from HDFS - valueToGuidePortMap (uuid) = 0 + valueToGuidePortMap (uuid) = SourceInfo.TxOverGoToHDFS logInfo ("Value unregistered from the Tracker " + valueToGuidePortMap) } } @@ -946,7 +1056,7 @@ extends Logging { class TrackMultipleValues extends Thread with Logging { - var keepAccepting = true + var stopTracker = false override def run = { var threadPool = Executors.newCachedThreadPool @@ -956,37 +1066,33 @@ extends Logging { logInfo ("TrackMultipleValues" + serverSocket) try { - while (keepAccepting) { + while (!stopTracker) { var clientSocket: Socket = null try { // TODO: - serverSocket.setSoTimeout (ServerSocketTimeout) + serverSocket.setSoTimeout (TrackerSocketTimeout) clientSocket = serverSocket.accept } catch { case e: Exception => { logInfo ("TrackMultipleValues Timeout. Stopping listening...") // TODO: Tracking should be explicitly stopped by the SparkContext - keepAccepting = false + stopTracker = true } } if (clientSocket != null) { try { - threadPool.execute (new Runnable { + threadPool.execute (new Thread { override def run = { val oos = new ObjectOutputStream (clientSocket.getOutputStream) oos.flush val ois = new ObjectInputStream (clientSocket.getInputStream) try { val uuid = ois.readObject.asInstanceOf[UUID] - // masterListenPort/guidePort value legend - // 0 = missed the broadcast, read from HDFS; - // <0 = hasn't started yet, wait & retry; - // >0 = Read from this port var guidePort = if (valueToGuidePortMap.contains (uuid)) { valueToGuidePortMap (uuid) - } else -1 + } else SourceInfo.TxNotStartedRetry logInfo ("TrackMultipleValues:Got new request: " + clientSocket + " for " + uuid + " : " + guidePort) oos.writeObject (guidePort) } catch { @@ -1006,7 +1112,10 @@ extends Logging { } } finally { serverSocket.close - } + } + + // Shutdown the thread pool + threadPool.shutdown } } } -- cgit v1.2.3