author    | Matei Zaharia <matei@eecs.berkeley.edu> | 2011-06-26 19:22:27 -0700
committer | Matei Zaharia <matei@eecs.berkeley.edu> | 2011-06-26 19:22:27 -0700
commit    | bae8a9796816cde998b04aff9b4cddaff2864b17 (patch)
tree      | 12ed824b98354d2e863a40e364452bd95a62c0da /core
parent    | 23b42af70ad1ee0bfb3bd1936bba3dc224f9eb42 (diff)
parent    | b187675b686a74c754e5d502b000ec007b5d4e48 (diff)
Merge branch 'master' into scala-2.9
Conflicts:
repl/src/main/scala/spark/repl/SparkInterpreterLoop.scala
Diffstat (limited to 'core')
17 files changed, 3397 insertions, 2275 deletions
diff --git a/core/src/main/scala/spark/BitTorrentBroadcast.scala b/core/src/main/scala/spark/BitTorrentBroadcast.scala
deleted file mode 100644
index 126e61dc7d..0000000000
--- a/core/src/main/scala/spark/BitTorrentBroadcast.scala
+++ /dev/null
@@ -1,1236 +0,0 @@
[1,236 lines removed: the BitTorrent-style broadcast implementation (BitTorrentBroadcast[T] with its TalkToGuide, PeerChatterController, GuideMultipleRequests, and ServeMultipleRequests threads, plus BitTorrentBroadcastFactory and the private BitTorrentBroadcast tracker object)]

diff --git a/core/src/main/scala/spark/Broadcast.scala b/core/src/main/scala/spark/Broadcast.scala
deleted file mode 100644
index fe2ab1ebf0..0000000000
--- a/core/src/main/scala/spark/Broadcast.scala
+++ /dev/null
@@ -1,140 +0,0 @@
[140 lines removed: the Broadcast[T] trait, the BroadcastFactory trait, the private Broadcast object with its daemon thread-pool helpers, and the SourceInfo, BroadcastBlock, VariableInfo, and SpeedTracker classes]

diff --git a/core/src/main/scala/spark/ChainedBroadcast.scala b/core/src/main/scala/spark/ChainedBroadcast.scala
deleted file mode 100644
index 6f2cc3f6f0..0000000000
--- a/core/src/main/scala/spark/ChainedBroadcast.scala
+++ /dev/null
@@ -1,873 +0,0 @@
[873 lines removed: the chain-based broadcast implementation (ChainedBroadcast[T] with its GuideMultipleRequests and ServeMultipleRequests threads, plus ChainedBroadcastFactory and the private ChainedBroadcast tracker object)]

diff --git a/core/src/main/scala/spark/Executor.scala b/core/src/main/scala/spark/Executor.scala
index 54e169a1a1..b5e34629a2 100644
--- a/core/src/main/scala/spark/Executor.scala
+++ b/core/src/main/scala/spark/Executor.scala
@@ -9,6 +9,8 @@ import scala.collection.mutable.ArrayBuffer
 import mesos.{ExecutorArgs, ExecutorDriver, MesosExecutorDriver}
 import mesos.{TaskDescription, TaskState, TaskStatus}
 
+import spark.broadcast._
+
 /**
  * The Mesos executor for Spark.
*/ diff --git a/core/src/main/scala/spark/LocalFileShuffle.scala b/core/src/main/scala/spark/LocalFileShuffle.scala index da3897b117..6c7f3dede2 100644 --- a/core/src/main/scala/spark/LocalFileShuffle.scala +++ b/core/src/main/scala/spark/LocalFileShuffle.scala @@ -7,6 +7,7 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable.{ArrayBuffer, HashMap} +import spark._ object LocalFileShuffle extends Logging { private var initialized = false @@ -29,9 +30,9 @@ object LocalFileShuffle extends Logging { while (!foundLocalDir && tries < 10) { tries += 1 try { - localDirUuid = UUID.randomUUID() + localDirUuid = UUID.randomUUID localDir = new File(localDirRoot, "spark-local-" + localDirUuid) - if (!localDir.exists()) { + if (!localDir.exists) { localDir.mkdirs() foundLocalDir = true } @@ -47,6 +48,7 @@ object LocalFileShuffle extends Logging { shuffleDir = new File(localDir, "shuffle") shuffleDir.mkdirs() logInfo("Shuffle dir: " + shuffleDir) + val extServerPort = System.getProperty( "spark.localFileShuffle.external.server.port", "-1").toInt if (extServerPort != -1) { @@ -65,6 +67,7 @@ object LocalFileShuffle extends Logging { serverUri = server.uri } initialized = true + logInfo("Local URI: " + serverUri) } } diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index ee9d747de6..265fbae60d 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -242,6 +242,44 @@ extends RDD[Array[T]](prev.context) { } } + def leftOuterJoin[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (V, Option[W]))] = { + val vs: RDD[(K, Either[V, W])] = self.map { case (k, v) => (k, Left(v)) } + val ws: RDD[(K, Either[V, W])] = other.map { case (k, w) => (k, Right(w)) } + (vs ++ ws).groupByKey(numSplits).flatMap { + case (k, seq) => { + val vbuf = new ArrayBuffer[V] + val wbuf = new ArrayBuffer[Option[W]] + seq.foreach(_ match { + case Left(v) => vbuf += v + case Right(w) => wbuf += Some(w) + }) + if (wbuf.isEmpty) { + wbuf += None + } + for (v <- vbuf; w <- wbuf) yield (k, (v, w)) + } + } + } + + def rightOuterJoin[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (Option[V], W))] = { + val vs: RDD[(K, Either[V, W])] = self.map { case (k, v) => (k, Left(v)) } + val ws: RDD[(K, Either[V, W])] = other.map { case (k, w) => (k, Right(w)) } + (vs ++ ws).groupByKey(numSplits).flatMap { + case (k, seq) => { + val vbuf = new ArrayBuffer[Option[V]] + val wbuf = new ArrayBuffer[W] + seq.foreach(_ match { + case Left(v) => vbuf += Some(v) + case Right(w) => wbuf += w + }) + if (vbuf.isEmpty) { + vbuf += None + } + for (v <- vbuf; w <- wbuf) yield (k, (v, w)) + } + } + } + def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) @@ -261,6 +299,14 @@ extends RDD[Array[T]](prev.context) { join(other, numCores) } + def leftOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (V, Option[W]))] = { + leftOuterJoin(other, numCores) + } + + def rightOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], W))] = { + rightOuterJoin(other, numCores) + } + def numCores = self.context.numCores def collectAsMap(): Map[K, V] = HashMap(self.collect(): _*) @@ -301,6 +347,23 @@ extends RDD[Array[T]](prev.context) { (k, (vs.asInstanceOf[Seq[V]], w1s.asInstanceOf[Seq[W1]], w2s.asInstanceOf[Seq[W2]])) } } + + def lookup(key: K): Seq[V] = { + self.partitioner match { + case Some(p) => + val index = p.getPartition(key) + def process(it: Iterator[(K, V)]): Seq[V] = { + val buf = new ArrayBuffer[V] + for ((k, v) <- it if k == key) + buf += 
v + buf + } + val res = self.context.runJob(self, process, Array(index)) + res(0) + case None => + throw new UnsupportedOperationException("lookup() called on an RDD without a partitioner") + } + } } class MappedValuesRDD[K, V, U]( diff --git a/core/src/main/scala/spark/Shuffle.scala b/core/src/main/scala/spark/Shuffle.scala deleted file mode 100644 index 4c5649b537..0000000000 --- a/core/src/main/scala/spark/Shuffle.scala +++ /dev/null @@ -1,15 +0,0 @@ -package spark - -/** - * A trait for shuffle system. Given an input RDD and combiner functions - * for PairRDDExtras.combineByKey(), returns an output RDD. - */ -@serializable -trait Shuffle[K, V, C] { - def compute(input: RDD[(K, V)], - numOutputSplits: Int, - createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiners: (C, C) => C) - : RDD[(K, C)] -} diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 5d27a3d46b..0bc3abf114 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -8,6 +8,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.mapred.SequenceFileInputFormat +import spark.broadcast._ class SparkContext( master: String, diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 4476fe44c1..1d77c357c3 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -3,6 +3,7 @@ package spark import java.io._ import java.net.InetAddress import java.util.UUID +import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor} import scala.collection.mutable.ArrayBuffer import scala.util.Random @@ -117,12 +118,43 @@ object Utils { /** * Get the local host's IP address in dotted-quad format (e.g. 
1.2.3.4) */ - def localIpAddress(): String = { - // Get local IP as an array of four bytes - val bytes = InetAddress.getLocalHost().getAddress() - // Convert the bytes to ints (keeping in mind that they may be negative) - // and join them into a string - return bytes.map(b => (b.toInt + 256) % 256).mkString(".") + def localIpAddress(): String = InetAddress.getLocalHost.getHostAddress + + /** + * Returns a standard ThreadFactory except all threads are daemons + */ + private def newDaemonThreadFactory: ThreadFactory = { + new ThreadFactory { + def newThread(r: Runnable): Thread = { + var t = Executors.defaultThreadFactory.newThread (r) + t.setDaemon (true) + return t + } + } + } + + /** + * Wrapper over newCachedThreadPool + */ + def newDaemonCachedThreadPool(): ThreadPoolExecutor = { + var threadPool = + Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor] + + threadPool.setThreadFactory (newDaemonThreadFactory) + + return threadPool + } + + /** + * Wrapper over newFixedThreadPool + */ + def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = { + var threadPool = + Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor] + + threadPool.setThreadFactory(newDaemonThreadFactory) + + return threadPool } /** diff --git a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala new file mode 100644 index 0000000000..220456a210 --- /dev/null +++ b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala @@ -0,0 +1,1355 @@ +package spark.broadcast + +import java.io._ +import java.net._ +import java.util.{BitSet, Comparator, Random, Timer, TimerTask, UUID} +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.mutable.{ListBuffer, Map, Set} +import scala.math + +import spark._ + +@serializable +class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean) +extends Broadcast[T] with Logging { + + def value = value_ + + BitTorrentBroadcast.synchronized { + BitTorrentBroadcast.values.put(uuid, value_) + } + + @transient var arrayOfBlocks: Array[BroadcastBlock] = null + @transient var hasBlocksBitVector: BitSet = null + @transient var numCopiesSent: Array[Int] = null + @transient var totalBytes = -1 + @transient var totalBlocks = -1 + @transient var hasBlocks = new AtomicInteger(0) + // CHANGED: BlockSize in the Broadcast object is expected to change over time + @transient var blockSize = Broadcast.BlockSize + + // Used ONLY by Master to track how many unique blocks have been sent out + @transient var sentBlocks = new AtomicInteger(0) + + @transient var listenPortLock = new Object + @transient var guidePortLock = new Object + @transient var totalBlocksLock = new Object + + @transient var listOfSources = ListBuffer[SourceInfo]() + + @transient var serveMR: ServeMultipleRequests = null + + // Used only in Master + @transient var guideMR: GuideMultipleRequests = null + + // Used only in Workers + @transient var ttGuide: TalkToGuide = null + + @transient var rxSpeeds = new SpeedTracker + @transient var txSpeeds = new SpeedTracker + + @transient var hostAddress = Utils.localIpAddress + @transient var listenPort = -1 + @transient var guidePort = -1 + + @transient var hasCopyInHDFS = false + @transient var stopBroadcast = false + + // Must call this after all the variables have been created/initialized + if (!isLocal) { + sendBroadcast + } + + def sendBroadcast(): Unit = { + logInfo("Local host address: " + hostAddress) + + // Store a persistent copy in HDFS + // TODO: 
Turned OFF for now. Related to persistence + // val out = new ObjectOutputStream(BroadcastCH.openFileForWriting(uuid)) + // out.writeObject(value_) + // out.close() + // FIXME: Fix this at some point + hasCopyInHDFS = true + + // Create a variableInfo object and store it in valueInfos + var variableInfo = Broadcast.blockifyObject(value_) + + // Prepare the value being broadcasted + // TODO: Refactoring and clean-up required here + arrayOfBlocks = variableInfo.arrayOfBlocks + totalBytes = variableInfo.totalBytes + totalBlocks = variableInfo.totalBlocks + hasBlocks.set(variableInfo.totalBlocks) + + // Guide has all the blocks + hasBlocksBitVector = new BitSet(totalBlocks) + hasBlocksBitVector.set(0, totalBlocks) + + // Guide still hasn't sent any block + numCopiesSent = new Array[Int](totalBlocks) + + guideMR = new GuideMultipleRequests + guideMR.setDaemon(true) + guideMR.start() + logInfo("GuideMultipleRequests started...") + + // Must always come AFTER guideMR is created + while (guidePort == -1) { + guidePortLock.synchronized { + guidePortLock.wait + } + } + + serveMR = new ServeMultipleRequests + serveMR.setDaemon(true) + serveMR.start() + logInfo("ServeMultipleRequests started...") + + // Must always come AFTER serveMR is created + while (listenPort == -1) { + listenPortLock.synchronized { + listenPortLock.wait + } + } + + // Must always come AFTER listenPort is created + val masterSource = + SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes, blockSize) + hasBlocksBitVector.synchronized { + masterSource.hasBlocksBitVector = hasBlocksBitVector + } + + // In the beginning, this is the only known source to Guide + listOfSources += masterSource + + // Register with the Tracker + registerBroadcast(uuid, + SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes, blockSize)) + } + + private def readObject(in: ObjectInputStream): Unit = { + in.defaultReadObject + BitTorrentBroadcast.synchronized { + val cachedVal = BitTorrentBroadcast.values.get(uuid) + + if (cachedVal != null) { + value_ = cachedVal.asInstanceOf[T] + } else { + // Only the first worker in a node can ever be inside this 'else' + initializeWorkerVariables + + logInfo("Local host address: " + hostAddress) + + // Start local ServeMultipleRequests thread first + serveMR = new ServeMultipleRequests + serveMR.setDaemon(true) + serveMR.start() + logInfo("ServeMultipleRequests started...") + + val start = System.nanoTime + + val receptionSucceeded = receiveBroadcast(uuid) + // If does not succeed, then get from HDFS copy + if (receptionSucceeded) { + value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) + BitTorrentBroadcast.values.put(uuid, value_) + } else { + // TODO: This part won't work, cause HDFS writing is turned OFF + val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid)) + value_ = fileIn.readObject.asInstanceOf[T] + BitTorrentBroadcast.values.put(uuid, value_) + fileIn.close() + } + + val time = (System.nanoTime - start) / 1e9 + logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s") + } + } + } + + // Initialize variables in the worker node. 
Master sends everything as 0/null + private def initializeWorkerVariables: Unit = { + arrayOfBlocks = null + hasBlocksBitVector = null + numCopiesSent = null + totalBytes = -1 + totalBlocks = -1 + hasBlocks = new AtomicInteger(0) + blockSize = -1 + + listenPortLock = new Object + totalBlocksLock = new Object + + serveMR = null + ttGuide = null + + rxSpeeds = new SpeedTracker + txSpeeds = new SpeedTracker + + hostAddress = Utils.localIpAddress + listenPort = -1 + + listOfSources = ListBuffer[SourceInfo]() + + stopBroadcast = false + } + + private def registerBroadcast(uuid: UUID, gInfo: SourceInfo): Unit = { + val socket = new Socket(Broadcast.MasterHostAddress, + Broadcast.MasterTrackerPort) + val oosST = new ObjectOutputStream(socket.getOutputStream) + oosST.flush() + val oisST = new ObjectInputStream(socket.getInputStream) + + // Send messageType/intention + oosST.writeObject(Broadcast.REGISTER_BROADCAST_TRACKER) + oosST.flush() + + // Send UUID of this broadcast + oosST.writeObject(uuid) + oosST.flush() + + // Send this tracker's information + oosST.writeObject(gInfo) + oosST.flush() + + // Receive ACK and throw it away + oisST.readObject.asInstanceOf[Int] + + // Shut stuff down + oisST.close() + oosST.close() + socket.close() + } + + private def unregisterBroadcast(uuid: UUID): Unit = { + val socket = new Socket(Broadcast.MasterHostAddress, + Broadcast.MasterTrackerPort) + val oosST = new ObjectOutputStream(socket.getOutputStream) + oosST.flush() + val oisST = new ObjectInputStream(socket.getInputStream) + + // Send messageType/intention + oosST.writeObject(Broadcast.UNREGISTER_BROADCAST_TRACKER) + oosST.flush() + + // Send UUID of this broadcast + oosST.writeObject(uuid) + oosST.flush() + + // Receive ACK and throw it away + oisST.readObject.asInstanceOf[Int] + + // Shut stuff down + oisST.close() + oosST.close() + socket.close() + } + + private def getLocalSourceInfo: SourceInfo = { + // Wait till hostName and listenPort are OK + while (listenPort == -1) { + listenPortLock.synchronized { + listenPortLock.wait + } + } + + // Wait till totalBlocks and totalBytes are OK + while (totalBlocks == -1) { + totalBlocksLock.synchronized { + totalBlocksLock.wait + } + } + + var localSourceInfo = SourceInfo( + hostAddress, listenPort, totalBlocks, totalBytes, blockSize) + + localSourceInfo.hasBlocks = hasBlocks.get + + hasBlocksBitVector.synchronized { + localSourceInfo.hasBlocksBitVector = hasBlocksBitVector + } + + return localSourceInfo + } + + // Add new SourceInfo to the listOfSources. Update if it exists already. 
+ // TODO: Optimizing just by OR-ing the BitVectors was BAD for performance + private def addToListOfSources(newSourceInfo: SourceInfo): Unit = { + listOfSources.synchronized { + if (listOfSources.contains(newSourceInfo)) { + listOfSources = listOfSources - newSourceInfo + } + listOfSources += newSourceInfo + } + } + + private def addToListOfSources(newSourceInfos: ListBuffer[SourceInfo]): Unit = { + newSourceInfos.foreach { newSourceInfo => + addToListOfSources(newSourceInfo) + } + } + + class TalkToGuide(gInfo: SourceInfo) + extends Thread with Logging { + override def run: Unit = { + + // Keep exchaning information until all blocks have been received + while (hasBlocks.get < totalBlocks) { + talkOnce + Thread.sleep(BitTorrentBroadcast.ranGen.nextInt( + Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) + + Broadcast.MinKnockInterval) + } + + // Talk one more time to let the Guide know of reception completion + talkOnce + } + + // Connect to Guide and send this worker's information + private def talkOnce: Unit = { + var clientSocketToGuide: Socket = null + var oosGuide: ObjectOutputStream = null + var oisGuide: ObjectInputStream = null + + clientSocketToGuide = new Socket(gInfo.hostAddress, gInfo.listenPort) + oosGuide = new ObjectOutputStream(clientSocketToGuide.getOutputStream) + oosGuide.flush() + oisGuide = new ObjectInputStream(clientSocketToGuide.getInputStream) + + // Send local information + oosGuide.writeObject(getLocalSourceInfo) + oosGuide.flush() + + // Receive source information from Guide + var suitableSources = + oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]] + logInfo("Received suitableSources from Master " + suitableSources) + + addToListOfSources(suitableSources) + + oisGuide.close() + oosGuide.close() + clientSocketToGuide.close() + } + } + + def getGuideInfo(variableUUID: UUID): SourceInfo = { + var clientSocketToTracker: Socket = null + var oosTracker: ObjectOutputStream = null + var oisTracker: ObjectInputStream = null + + var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxOverGoToHDFS) + + var retriesLeft = Broadcast.MaxRetryCount + do { + try { + // Connect to the tracker to find out GuideInfo + clientSocketToTracker = + new Socket(Broadcast.MasterHostAddress, Broadcast.MasterTrackerPort) + oosTracker = + new ObjectOutputStream(clientSocketToTracker.getOutputStream) + oosTracker.flush() + oisTracker = + new ObjectInputStream(clientSocketToTracker.getInputStream) + + // Send messageType/intention + oosTracker.writeObject(Broadcast.FIND_BROADCAST_TRACKER) + oosTracker.flush() + + // Send UUID and receive GuideInfo + oosTracker.writeObject(uuid) + oosTracker.flush() + gInfo = oisTracker.readObject.asInstanceOf[SourceInfo] + } catch { + case e: Exception => { + logInfo("getGuideInfo had a " + e) + } + } finally { + if (oisTracker != null) { + oisTracker.close() + } + if (oosTracker != null) { + oosTracker.close() + } + if (clientSocketToTracker != null) { + clientSocketToTracker.close() + } + } + + Thread.sleep(BitTorrentBroadcast.ranGen.nextInt( + Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) + + Broadcast.MinKnockInterval) + + retriesLeft -= 1 + } while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry) + + logInfo("Got this guidePort from Tracker: " + gInfo.listenPort) + return gInfo + } + + def receiveBroadcast(variableUUID: UUID): Boolean = { + val gInfo = getGuideInfo(variableUUID) + + if (gInfo.listenPort == SourceInfo.TxOverGoToHDFS || + gInfo.listenPort == SourceInfo.TxNotStartedRetry) { + // TODO: 
SourceInfo.TxNotStartedRetry is not really in use because we go + // to HDFS anyway when receiveBroadcast returns false + return false + } + + // Wait until hostAddress and listenPort are created by the + // ServeMultipleRequests thread + while (listenPort == -1) { + listenPortLock.synchronized { + listenPortLock.wait + } + } + + // Setup initial states of variables + totalBlocks = gInfo.totalBlocks + arrayOfBlocks = new Array[BroadcastBlock](totalBlocks) + hasBlocksBitVector = new BitSet(totalBlocks) + numCopiesSent = new Array[Int](totalBlocks) + totalBlocksLock.synchronized { + totalBlocksLock.notifyAll + } + totalBytes = gInfo.totalBytes + blockSize = gInfo.blockSize + + // Start ttGuide to periodically talk to the Guide + var ttGuide = new TalkToGuide(gInfo) + ttGuide.setDaemon(true) + ttGuide.start() + logInfo("TalkToGuide started...") + + // Start pController to run TalkToPeer threads + var pcController = new PeerChatterController + pcController.setDaemon(true) + pcController.start() + logInfo("PeerChatterController started...") + + // FIXME: Must fix this. This might never break if broadcast fails. + // We should be able to break and send false. Also need to kill threads + while (hasBlocks.get < totalBlocks) { + Thread.sleep(Broadcast.MaxKnockInterval) + } + + return true + } + + class PeerChatterController + extends Thread with Logging { + private var peersNowTalking = ListBuffer[SourceInfo]() + // TODO: There is a possible bug with blocksInRequestBitVector when a + // certain bit is NOT unset upon failure resulting in an infinite loop. + private var blocksInRequestBitVector = new BitSet(totalBlocks) + + override def run: Unit = { + var threadPool = Utils.newDaemonFixedThreadPool(Broadcast.MaxRxSlots) + + while (hasBlocks.get < totalBlocks) { + var numThreadsToCreate = + math.min(listOfSources.size, Broadcast.MaxRxSlots) - + threadPool.getActiveCount + + while (hasBlocks.get < totalBlocks && numThreadsToCreate > 0) { + var peerToTalkTo = pickPeerToTalkToRandom + + if (peerToTalkTo != null) + logInfo("Peer chosen: " + peerToTalkTo + " with " + peerToTalkTo.hasBlocksBitVector) + else + logInfo("No peer chosen...") + + if (peerToTalkTo != null) { + threadPool.execute(new TalkToPeer(peerToTalkTo)) + + // Add to peersNowTalking. Remove in the thread. 
We have to do this + // ASAP, otherwise pickPeerToTalkTo picks the same peer more than once + peersNowTalking.synchronized { + peersNowTalking += peerToTalkTo + } + } + + numThreadsToCreate = numThreadsToCreate - 1 + } + + // Sleep for a while before starting some more threads + Thread.sleep(Broadcast.MinKnockInterval) + } + // Shutdown the thread pool + threadPool.shutdown() + } + + // Right now picking the one that has the most blocks this peer wants + // Also picking peer randomly if no one has anything interesting + private def pickPeerToTalkToRandom: SourceInfo = { + var curPeer: SourceInfo = null + var curMax = 0 + + logInfo("Picking peers to talk to...") + + // Find peers that are not connected right now + var peersNotInUse = ListBuffer[SourceInfo]() + listOfSources.synchronized { + peersNowTalking.synchronized { + peersNotInUse = listOfSources -- peersNowTalking + } + } + + // Select the peer that has the most blocks that this receiver does not + peersNotInUse.foreach { eachSource => + var tempHasBlocksBitVector: BitSet = null + hasBlocksBitVector.synchronized { + tempHasBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet] + } + tempHasBlocksBitVector.flip(0, tempHasBlocksBitVector.size) + tempHasBlocksBitVector.and(eachSource.hasBlocksBitVector) + + if (tempHasBlocksBitVector.cardinality > curMax) { + curPeer = eachSource + curMax = tempHasBlocksBitVector.cardinality + } + } + + // TODO: Always pick randomly or randomly pick randomly? + // Now always picking randomly + if (curPeer == null && peersNotInUse.size > 0) { + // Pick uniformly the i'th required peer + var i = BitTorrentBroadcast.ranGen.nextInt(peersNotInUse.size) + + var peerIter = peersNotInUse.iterator + curPeer = peerIter.next + + while (i > 0) { + curPeer = peerIter.next + i = i - 1 + } + } + + return curPeer + } + + // Picking peer with the weight of rare blocks it has + private def pickPeerToTalkToRarestFirst: SourceInfo = { + // Find peers that are not connected right now + var peersNotInUse = ListBuffer[SourceInfo]() + listOfSources.synchronized { + peersNowTalking.synchronized { + peersNotInUse = listOfSources -- peersNowTalking + } + } + + // Count the number of copies of each block in the neighborhood + var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0) + + listOfSources.synchronized { + listOfSources.foreach { eachSource => + for (i <- 0 until totalBlocks) { + numCopiesPerBlock(i) += + ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 ) + } + } + } + + // TODO: A block is rare if there are at most 2 copies of that block + // TODO: This CONSTANT could be a function of the neighborhood size + var rareBlocksIndices = ListBuffer[Int]() + for (i <- 0 until totalBlocks) { + if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) <= 2) { + rareBlocksIndices += i + } + } + + // Find peers with rare blocks + var peersWithRareBlocks = ListBuffer[(SourceInfo, Int)]() + var totalRareBlocks = 0 + + peersNotInUse.foreach { eachPeer => + var hasRareBlocks = 0 + rareBlocksIndices.foreach { rareBlock => + if (eachPeer.hasBlocksBitVector.get(rareBlock)) { + hasRareBlocks += 1 + } + } + + if (hasRareBlocks > 0) { + peersWithRareBlocks += ((eachPeer, hasRareBlocks)) + } + totalRareBlocks += hasRareBlocks + } + + // Select a peer from peersWithRareBlocks based on weight calculated from + // unique rare blocks + var selectedPeerToTalkTo: SourceInfo = null + + if (peersWithRareBlocks.size > 0) { + // Sort the peers based on how many rare blocks they have + peersWithRareBlocks.sortBy(_._2) + + var 
randomNumber = BitTorrentBroadcast.ranGen.nextDouble + var tempSum = 0.0 + + var i = 0 + do { + tempSum += (1.0 * peersWithRareBlocks(i)._2 / totalRareBlocks) + if (tempSum >= randomNumber) { + selectedPeerToTalkTo = peersWithRareBlocks(i)._1 + } + i += 1 + } while (i < peersWithRareBlocks.size && selectedPeerToTalkTo == null) + } + + if (selectedPeerToTalkTo == null) { + selectedPeerToTalkTo = pickPeerToTalkToRandom + } + + return selectedPeerToTalkTo + } + + class TalkToPeer(peerToTalkTo: SourceInfo) + extends Thread with Logging { + private var peerSocketToSource: Socket = null + private var oosSource: ObjectOutputStream = null + private var oisSource: ObjectInputStream = null + + override def run: Unit = { + // TODO: There is a possible bug here regarding blocksInRequestBitVector + var blockToAskFor = -1 + + // Setup the timeout mechanism + var timeOutTask = new TimerTask { + override def run: Unit = { + cleanUpConnections() + } + } + + var timeOutTimer = new Timer + timeOutTimer.schedule(timeOutTask, Broadcast.MaxKnockInterval) + + logInfo("TalkToPeer started... => " + peerToTalkTo) + + try { + // Connect to the source + peerSocketToSource = + new Socket(peerToTalkTo.hostAddress, peerToTalkTo.listenPort) + oosSource = + new ObjectOutputStream(peerSocketToSource.getOutputStream) + oosSource.flush() + oisSource = + new ObjectInputStream(peerSocketToSource.getInputStream) + + // Receive latest SourceInfo from peerToTalkTo + var newPeerToTalkTo = oisSource.readObject.asInstanceOf[SourceInfo] + // Update listOfSources + addToListOfSources(newPeerToTalkTo) + + // Turn the timer OFF, if the sender responds before timeout + timeOutTimer.cancel + + // Send the latest SourceInfo + oosSource.writeObject(getLocalSourceInfo) + oosSource.flush() + + var keepReceiving = true + + while (hasBlocks.get < totalBlocks && keepReceiving) { + blockToAskFor = + pickBlockRandom(newPeerToTalkTo.hasBlocksBitVector) + + // No block to request + if (blockToAskFor < 0) { + // Nothing to receive from newPeerToTalkTo + keepReceiving = false + } else { + // Let other threads know that blockToAskFor is being requested + blocksInRequestBitVector.synchronized { + blocksInRequestBitVector.set(blockToAskFor) + } + + // Start with sending the blockID + oosSource.writeObject(blockToAskFor) + oosSource.flush() + + // CHANGED: Master might send some other block than the one + // requested to ensure fast spreading of all blocks. + val recvStartTime = System.currentTimeMillis + val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] + val receptionTime = (System.currentTimeMillis - recvStartTime) + + logInfo("Received block: " + bcBlock.blockID + " from " + peerToTalkTo + " in " + receptionTime + " millis.") + + if (!hasBlocksBitVector.get(bcBlock.blockID)) { + arrayOfBlocks(bcBlock.blockID) = bcBlock + + // Update the hasBlocksBitVector first + hasBlocksBitVector.synchronized { + hasBlocksBitVector.set(bcBlock.blockID) + hasBlocks.getAndIncrement + } + + rxSpeeds.addDataPoint(peerToTalkTo, receptionTime) + + // Some block(may NOT be blockToAskFor) has arrived. + // In any case, blockToAskFor is not in request any more + blocksInRequestBitVector.synchronized { + blocksInRequestBitVector.set(blockToAskFor, false) + } + + // Reset blockToAskFor to -1. 
Else it will be considered missing + blockToAskFor = -1 + } + + // Send the latest SourceInfo + oosSource.writeObject(getLocalSourceInfo) + oosSource.flush() + } + } + } catch { + // EOFException is expected to happen because sender can break + // connection due to timeout + case eofe: java.io.EOFException => { } + case e: Exception => { + logInfo("TalktoPeer had a " + e) + // FIXME: Remove 'newPeerToTalkTo' from listOfSources + // We probably should have the following in some form, but not + // really here. This exception can happen if the sender just breaks connection + // listOfSources.synchronized { + // logInfo("Exception in TalkToPeer. Removing source: " + peerToTalkTo) + // listOfSources = listOfSources - peerToTalkTo + // } + } + } finally { + // blockToAskFor != -1 => there was an exception + if (blockToAskFor != -1) { + blocksInRequestBitVector.synchronized { + blocksInRequestBitVector.set(blockToAskFor, false) + } + } + + cleanUpConnections() + } + } + + // Right now it picks a block uniformly that this peer does not have + private def pickBlockRandom(txHasBlocksBitVector: BitSet): Int = { + var needBlocksBitVector: BitSet = null + + // Blocks already present + hasBlocksBitVector.synchronized { + needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet] + } + + // Include blocks already in transmission ONLY IF + // BitTorrentBroadcast.EndGameFraction has NOT been achieved + if ((1.0 * hasBlocks.get / totalBlocks) < Broadcast.EndGameFraction) { + blocksInRequestBitVector.synchronized { + needBlocksBitVector.or(blocksInRequestBitVector) + } + } + + // Find blocks that are neither here nor in transit + needBlocksBitVector.flip(0, needBlocksBitVector.size) + + // Blocks that should/can be requested + needBlocksBitVector.and(txHasBlocksBitVector) + + if (needBlocksBitVector.cardinality == 0) { + return -1 + } else { + // Pick uniformly the i'th required block + var i = BitTorrentBroadcast.ranGen.nextInt(needBlocksBitVector.cardinality) + var pickedBlockIndex = needBlocksBitVector.nextSetBit(0) + + while (i > 0) { + pickedBlockIndex = + needBlocksBitVector.nextSetBit(pickedBlockIndex + 1) + i -= 1 + } + + return pickedBlockIndex + } + } + + // Pick the block that seems to be the rarest across sources + private def pickBlockRarestFirst(txHasBlocksBitVector: BitSet): Int = { + var needBlocksBitVector: BitSet = null + + // Blocks already present + hasBlocksBitVector.synchronized { + needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet] + } + + // Include blocks already in transmission ONLY IF + // BitTorrentBroadcast.EndGameFraction has NOT been achieved + if ((1.0 * hasBlocks.get / totalBlocks) < Broadcast.EndGameFraction) { + blocksInRequestBitVector.synchronized { + needBlocksBitVector.or(blocksInRequestBitVector) + } + } + + // Find blocks that are neither here nor in transit + needBlocksBitVector.flip(0, needBlocksBitVector.size) + + // Blocks that should/can be requested + needBlocksBitVector.and(txHasBlocksBitVector) + + if (needBlocksBitVector.cardinality == 0) { + return -1 + } else { + // Count the number of copies for each block across all sources + var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0) + + listOfSources.synchronized { + listOfSources.foreach { eachSource => + for (i <- 0 until totalBlocks) { + numCopiesPerBlock(i) += + ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 ) + } + } + } + + // Find the minimum + var minVal = Integer.MAX_VALUE + for (i <- 0 until totalBlocks) { + if (numCopiesPerBlock(i) > 0 && 
numCopiesPerBlock(i) < minVal) { + minVal = numCopiesPerBlock(i) + } + } + + // Find the blocks with the least copies that this peer does not have + var minBlocksIndices = ListBuffer[Int]() + for (i <- 0 until totalBlocks) { + if (needBlocksBitVector.get(i) && numCopiesPerBlock(i) == minVal) { + minBlocksIndices += i + } + } + + // Now select a random index from minBlocksIndices + if (minBlocksIndices.size == 0) { + return -1 + } else { + // Pick uniformly the i'th index + var i = BitTorrentBroadcast.ranGen.nextInt(minBlocksIndices.size) + return minBlocksIndices(i) + } + } + } + + private def cleanUpConnections(): Unit = { + if (oisSource != null) { + oisSource.close() + } + if (oosSource != null) { + oosSource.close() + } + if (peerSocketToSource != null) { + peerSocketToSource.close() + } + + // Delete from peersNowTalking + peersNowTalking.synchronized { + peersNowTalking = peersNowTalking - peerToTalkTo + } + } + } + } + + class GuideMultipleRequests + extends Thread with Logging { + // Keep track of sources that have completed reception + private var setOfCompletedSources = Set[SourceInfo]() + + override def run: Unit = { + var threadPool = Utils.newDaemonCachedThreadPool() + var serverSocket: ServerSocket = null + + serverSocket = new ServerSocket(0) + guidePort = serverSocket.getLocalPort + logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort) + + guidePortLock.synchronized { + guidePortLock.notifyAll + } + + try { + // Don't stop until there is a copy in HDFS + while (!stopBroadcast || !hasCopyInHDFS) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout) + clientSocket = serverSocket.accept() + } catch { + case e: Exception => { + logInfo("GuideMultipleRequests Timeout.") + + // Stop broadcast if at least one worker has connected and + // everyone connected so far are done. Comparing with + // listOfSources.size - 1, because it includes the Guide itself + if (listOfSources.size > 1 && + setOfCompletedSources.size == listOfSources.size - 1) { + stopBroadcast = true + } + } + } + if (clientSocket != null) { + logInfo("Guide: Accepted new client connection:" + clientSocket) + try { + threadPool.execute(new GuideSingleRequest(clientSocket)) + } catch { + // In failure, close the socket here; else, thread will close it + case ioe: IOException => { + clientSocket.close() + } + } + } + } + + // Shutdown the thread pool + threadPool.shutdown() + + logInfo("Sending stopBroadcast notifications...") + sendStopBroadcastNotifications + + unregisterBroadcast(uuid) + } finally { + if (serverSocket != null) { + logInfo("GuideMultipleRequests now stopping...") + serverSocket.close() + } + } + } + + private def sendStopBroadcastNotifications: Unit = { + listOfSources.synchronized { + listOfSources.foreach { sourceInfo => + + var guideSocketToSource: Socket = null + var gosSource: ObjectOutputStream = null + var gisSource: ObjectInputStream = null + + try { + // Connect to the source + guideSocketToSource = + new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) + gosSource = + new ObjectOutputStream(guideSocketToSource.getOutputStream) + gosSource.flush() + gisSource = + new ObjectInputStream(guideSocketToSource.getInputStream) + + // Throw away whatever comes in + gisSource.readObject.asInstanceOf[SourceInfo] + + // Send stopBroadcast signal. 
listenPort = SourceInfo.StopBroadcast + gosSource.writeObject(SourceInfo("", SourceInfo.StopBroadcast)) + gosSource.flush() + } catch { + case e: Exception => { + logInfo("sendStopBroadcastNotifications had a " + e) + } + } finally { + if (gisSource != null) { + gisSource.close() + } + if (gosSource != null) { + gosSource.close() + } + if (guideSocketToSource != null) { + guideSocketToSource.close() + } + } + } + } + } + + class GuideSingleRequest(val clientSocket: Socket) + extends Thread with Logging { + private val oos = new ObjectOutputStream(clientSocket.getOutputStream) + oos.flush() + private val ois = new ObjectInputStream(clientSocket.getInputStream) + + private var sourceInfo: SourceInfo = null + private var selectedSources: ListBuffer[SourceInfo] = null + + override def run: Unit = { + try { + logInfo("new GuideSingleRequest is running") + // Connecting worker is sending in its information + sourceInfo = ois.readObject.asInstanceOf[SourceInfo] + + // Select a suitable source and send it back to the worker + selectedSources = selectSuitableSources(sourceInfo) + logInfo("Sending selectedSources:" + selectedSources) + oos.writeObject(selectedSources) + oos.flush() + + // Add this source to the listOfSources + addToListOfSources(sourceInfo) + } catch { + case e: Exception => { + // Assuming exception caused by receiver failure: remove + if (listOfSources != null) { + listOfSources.synchronized { + listOfSources = listOfSources - sourceInfo + } + } + } + } finally { + ois.close() + oos.close() + clientSocket.close() + } + } + + // Randomly select some sources to send back + private def selectSuitableSources(skipSourceInfo: SourceInfo): ListBuffer[SourceInfo] = { + var selectedSources = ListBuffer[SourceInfo]() + + // If skipSourceInfo.hasBlocksBitVector has all bits set to 'true' + // then add skipSourceInfo to setOfCompletedSources. Return blank. 
+ if (skipSourceInfo.hasBlocks == totalBlocks) { + setOfCompletedSources.synchronized { + setOfCompletedSources += skipSourceInfo + } + return selectedSources + } + + listOfSources.synchronized { + if (listOfSources.size <= Broadcast.MaxPeersInGuideResponse) { + selectedSources = listOfSources.clone + } else { + var picksLeft = Broadcast.MaxPeersInGuideResponse + var alreadyPicked = new BitSet(listOfSources.size) + + while (picksLeft > 0) { + var i = -1 + + do { + i = BitTorrentBroadcast.ranGen.nextInt(listOfSources.size) + } while (alreadyPicked.get(i)) + + var peerIter = listOfSources.iterator + var curPeer = peerIter.next + + // Set the BitSet before i is decremented + alreadyPicked.set(i) + + while (i > 0) { + curPeer = peerIter.next + i = i - 1 + } + + selectedSources += curPeer + + picksLeft = picksLeft - 1 + } + } + } + + // Remove the receiving source (if present) + selectedSources = selectedSources - skipSourceInfo + + return selectedSources + } + } + } + + class ServeMultipleRequests + extends Thread with Logging { + // Server at most Broadcast.MaxTxSlots peers + var threadPool = Utils.newDaemonFixedThreadPool(Broadcast.MaxTxSlots) + + override def run: Unit = { + var serverSocket = new ServerSocket(0) + listenPort = serverSocket.getLocalPort + + logInfo("ServeMultipleRequests started with " + serverSocket) + + listenPortLock.synchronized { + listenPortLock.notifyAll + } + + try { + while (!stopBroadcast) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout) + clientSocket = serverSocket.accept() + } catch { + case e: Exception => { + logInfo("ServeMultipleRequests Timeout.") + } + } + if (clientSocket != null) { + logInfo("Serve: Accepted new client connection:" + clientSocket) + try { + threadPool.execute(new ServeSingleRequest(clientSocket)) + } catch { + // In failure, close socket here; else, the thread will close it + case ioe: IOException => { + clientSocket.close() + } + } + } + } + } finally { + if (serverSocket != null) { + logInfo("ServeMultipleRequests now stopping...") + serverSocket.close() + } + } + // Shutdown the thread pool + threadPool.shutdown() + } + + class ServeSingleRequest(val clientSocket: Socket) + extends Thread with Logging { + private val oos = new ObjectOutputStream(clientSocket.getOutputStream) + oos.flush() + private val ois = new ObjectInputStream(clientSocket.getInputStream) + + logInfo("new ServeSingleRequest is running") + + override def run: Unit = { + try { + // Send latest local SourceInfo to the receiver + // In the case of receiver timeout and connection close, this will + // throw a java.net.SocketException: Broken pipe + oos.writeObject(getLocalSourceInfo) + oos.flush() + + // Receive latest SourceInfo from the receiver + var rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo] + + if (rxSourceInfo.listenPort == SourceInfo.StopBroadcast) { + stopBroadcast = true + } else { + // Carry on + addToListOfSources(rxSourceInfo) + } + + val startTime = System.currentTimeMillis + var curTime = startTime + var keepSending = true + var numBlocksToSend = Broadcast.MaxChatBlocks + + while (!stopBroadcast && keepSending && numBlocksToSend > 0) { + // Receive which block to send + var blockToSend = ois.readObject.asInstanceOf[Int] + + // If it is master AND at least one copy of each block has not been + // sent out already, MODIFY blockToSend + if (BitTorrentBroadcast.isMaster && sentBlocks.get < totalBlocks) { + blockToSend = sentBlocks.getAndIncrement + } + + // Send the block + 
sendBlock(blockToSend) + rxSourceInfo.hasBlocksBitVector.set(blockToSend) + + numBlocksToSend -= 1 + + // Receive latest SourceInfo from the receiver + rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo] + // logInfo("rxSourceInfo: " + rxSourceInfo + " with " + rxSourceInfo.hasBlocksBitVector) + addToListOfSources(rxSourceInfo) + + curTime = System.currentTimeMillis + // Revoke sending only if there is anyone waiting in the queue + if (curTime - startTime >= Broadcast.MaxChatTime && + threadPool.getQueue.size > 0) { + keepSending = false + } + } + } catch { + // If something went wrong, e.g., the worker at the other end died etc. + // then close everything up + // Exception can happen if the receiver stops receiving + case e: Exception => { + logInfo("ServeSingleRequest had a " + e) + } + } finally { + logInfo("ServeSingleRequest is closing streams and sockets") + ois.close() + // TODO: The following line causes a "java.net.SocketException: Socket closed" + oos.close() + clientSocket.close() + } + } + + private def sendBlock(blockToSend: Int): Unit = { + try { + oos.writeObject(arrayOfBlocks(blockToSend)) + oos.flush() + } catch { + case e: Exception => { + logInfo("sendBlock had a " + e) + } + } + logInfo("Sent block: " + blockToSend + " to " + clientSocket) + } + } + } +} + +class BitTorrentBroadcastFactory +extends BroadcastFactory { + def initialize(isMaster: Boolean) = { + BitTorrentBroadcast.initialize(isMaster) + } + + def newBroadcast[T](value_ : T, isLocal: Boolean) = + new BitTorrentBroadcast[T](value_, isLocal) +} + +private object BitTorrentBroadcast +extends Logging { + val values = SparkEnv.get.cache.newKeySpace() + + var valueToGuideMap = Map[UUID, SourceInfo]() + + // Random number generator + var ranGen = new Random + + private var initialized = false + private var isMaster_ = false + + private var trackMV: TrackMultipleValues = null + + def initialize(isMaster__ : Boolean): Unit = { + synchronized { + if (!initialized) { + + isMaster_ = isMaster__ + + if (isMaster) { + trackMV = new TrackMultipleValues + trackMV.setDaemon(true) + trackMV.start() + // TODO: Logging the following line makes the Spark framework ID not + // getting logged, cause it calls logInfo before log4j is initialized + logInfo("TrackMultipleValues started...") + } + + // Initialize DfsBroadcast to be used for broadcast variable persistence + // TODO: Think about persistence + DfsBroadcast.initialize + + initialized = true + } + } + } + + def isMaster = isMaster_ + + class TrackMultipleValues + extends Thread with Logging { + override def run: Unit = { + var threadPool = Utils.newDaemonCachedThreadPool() + var serverSocket: ServerSocket = null + + serverSocket = new ServerSocket(Broadcast.MasterTrackerPort) + logInfo("TrackMultipleValues" + serverSocket) + + try { + while (true) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout(Broadcast.TrackerSocketTimeout) + clientSocket = serverSocket.accept() + } catch { + case e: Exception => { + logInfo("TrackMultipleValues Timeout. 
Stopping listening...") + } + } + + if (clientSocket != null) { + try { + threadPool.execute(new Thread { + override def run: Unit = { + val oos = new ObjectOutputStream(clientSocket.getOutputStream) + oos.flush() + val ois = new ObjectInputStream(clientSocket.getInputStream) + + try { + // First, read message type + val messageType = ois.readObject.asInstanceOf[Int] + + if (messageType == Broadcast.REGISTER_BROADCAST_TRACKER) { + // Receive UUID + val uuid = ois.readObject.asInstanceOf[UUID] + // Receive hostAddress and listenPort + val gInfo = ois.readObject.asInstanceOf[SourceInfo] + + // Add to the map + valueToGuideMap.synchronized { + valueToGuideMap += (uuid -> gInfo) + } + + logInfo ("New broadcast registered with TrackMultipleValues " + uuid + " " + valueToGuideMap) + + // Send dummy ACK + oos.writeObject(-1) + oos.flush() + } else if (messageType == Broadcast.UNREGISTER_BROADCAST_TRACKER) { + // Receive UUID + val uuid = ois.readObject.asInstanceOf[UUID] + + // Remove from the map + valueToGuideMap.synchronized { + valueToGuideMap(uuid) = SourceInfo("", SourceInfo.TxOverGoToHDFS) + logInfo("Value unregistered from the Tracker " + valueToGuideMap) + } + + logInfo ("Broadcast unregistered from TrackMultipleValues " + uuid + " " + valueToGuideMap) + + // Send dummy ACK + oos.writeObject(-1) + oos.flush() + } else if (messageType == Broadcast.FIND_BROADCAST_TRACKER) { + // Receive UUID + val uuid = ois.readObject.asInstanceOf[UUID] + + var gInfo = + if (valueToGuideMap.contains(uuid)) valueToGuideMap(uuid) + else SourceInfo("", SourceInfo.TxNotStartedRetry) + + logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + gInfo.listenPort) + + // Send reply back + oos.writeObject(gInfo) + oos.flush() + } else if (messageType == Broadcast.GET_UPDATED_SHARE) { + // TODO: Not implemented + } else { + throw new SparkException("Undefined messageType at TrackMultipleValues") + } + } catch { + case e: Exception => { + logInfo("TrackMultipleValues had a " + e) + } + } finally { + ois.close() + oos.close() + clientSocket.close() + } + } + }) + } catch { + // In failure, close socket here; else, client thread will close + case ioe: IOException => { + clientSocket.close() + } + } + } + } + } finally { + serverSocket.close() + } + // Shutdown the thread pool + threadPool.shutdown() + } + } +} diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala new file mode 100644 index 0000000000..f39fb9de69 --- /dev/null +++ b/core/src/main/scala/spark/broadcast/Broadcast.scala @@ -0,0 +1,228 @@ +package spark.broadcast + +import java.io._ +import java.net._ +import java.util.{BitSet, UUID} +import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor} + +import spark._ + +@serializable +trait Broadcast[T] { + val uuid = UUID.randomUUID + + def value: T + + // We cannot have an abstract readObject here due to some weird issues with + // readObject having to be 'private' in sub-classes. Possibly a Scala bug! 
+ + override def toString = "spark.Broadcast(" + uuid + ")" +} + +object Broadcast +extends Logging { + // Messages + val REGISTER_BROADCAST_TRACKER = 0 + val UNREGISTER_BROADCAST_TRACKER = 1 + val FIND_BROADCAST_TRACKER = 2 + val GET_UPDATED_SHARE = 3 + + private var initialized = false + private var isMaster_ = false + private var broadcastFactory: BroadcastFactory = null + + // Called by SparkContext or Executor before using Broadcast + def initialize (isMaster__ : Boolean): Unit = synchronized { + if (!initialized) { + val broadcastFactoryClass = System.getProperty( + "spark.broadcast.factory", "spark.broadcast.DfsBroadcastFactory") + + broadcastFactory = + Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] + + // Setup isMaster before using it + isMaster_ = isMaster__ + + // Set masterHostAddress to the master's IP address for the slaves to read + if (isMaster) { + System.setProperty("spark.broadcast.masterHostAddress", Utils.localIpAddress) + } + + // Initialize appropriate BroadcastFactory and BroadcastObject + broadcastFactory.initialize(isMaster) + + initialized = true + } + } + + def getBroadcastFactory: BroadcastFactory = { + if (broadcastFactory == null) { + throw new SparkException ("Broadcast.getBroadcastFactory called before initialize") + } + broadcastFactory + } + + // Load common broadcast-related config parameters + private var MasterHostAddress_ = System.getProperty( + "spark.broadcast.masterHostAddress", "") + private var MasterTrackerPort_ = System.getProperty( + "spark.broadcast.masterTrackerPort", "11111").toInt + private var BlockSize_ = System.getProperty( + "spark.broadcast.blockSize", "4096").toInt * 1024 + private var MaxRetryCount_ = System.getProperty( + "spark.broadcast.maxRetryCount", "2").toInt + + private var TrackerSocketTimeout_ = System.getProperty( + "spark.broadcast.trackerSocketTimeout", "50000").toInt + private var ServerSocketTimeout_ = System.getProperty( + "spark.broadcast.serverSocketTimeout", "10000").toInt + + private var MinKnockInterval_ = System.getProperty( + "spark.broadcast.minKnockInterval", "500").toInt + private var MaxKnockInterval_ = System.getProperty( + "spark.broadcast.maxKnockInterval", "999").toInt + + // Load ChainedBroadcast config params + + // Load TreeBroadcast config params + private var MaxDegree_ = System.getProperty("spark.broadcast.maxDegree", "2").toInt + + // Load BitTorrentBroadcast config params + private var MaxPeersInGuideResponse_ = System.getProperty( + "spark.broadcast.maxPeersInGuideResponse", "4").toInt + + private var MaxRxSlots_ = System.getProperty( + "spark.broadcast.maxRxSlots", "4").toInt + private var MaxTxSlots_ = System.getProperty( + "spark.broadcast.maxTxSlots", "4").toInt + + private var MaxChatTime_ = System.getProperty( + "spark.broadcast.maxChatTime", "500").toInt + private var MaxChatBlocks_ = System.getProperty( + "spark.broadcast.maxChatBlocks", "1024").toInt + + private var EndGameFraction_ = System.getProperty( + "spark.broadcast.endGameFraction", "0.95").toDouble + + def isMaster = isMaster_ + + // Common config params + def MasterHostAddress = MasterHostAddress_ + def MasterTrackerPort = MasterTrackerPort_ + def BlockSize = BlockSize_ + def MaxRetryCount = MaxRetryCount_ + + def TrackerSocketTimeout = TrackerSocketTimeout_ + def ServerSocketTimeout = ServerSocketTimeout_ + + def MinKnockInterval = MinKnockInterval_ + def MaxKnockInterval = MaxKnockInterval_ + + // ChainedBroadcast configs + + // TreeBroadcast configs + def MaxDegree = MaxDegree_ + + // 
BitTorrentBroadcast configs + def MaxPeersInGuideResponse = MaxPeersInGuideResponse_ + + def MaxRxSlots = MaxRxSlots_ + def MaxTxSlots = MaxTxSlots_ + + def MaxChatTime = MaxChatTime_ + def MaxChatBlocks = MaxChatBlocks_ + + def EndGameFraction = EndGameFraction_ + + // Helper functions to convert an object to Array[BroadcastBlock] + def blockifyObject[IN](obj: IN): VariableInfo = { + val baos = new ByteArrayOutputStream + val oos = new ObjectOutputStream(baos) + oos.writeObject(obj) + oos.close() + baos.close() + val byteArray = baos.toByteArray + val bais = new ByteArrayInputStream(byteArray) + + var blockNum = (byteArray.length / Broadcast.BlockSize) + if (byteArray.length % Broadcast.BlockSize != 0) + blockNum += 1 + + var retVal = new Array[BroadcastBlock](blockNum) + var blockID = 0 + + for (i <- 0 until (byteArray.length, Broadcast.BlockSize)) { + val thisBlockSize = math.min(Broadcast.BlockSize, byteArray.length - i) + var tempByteArray = new Array[Byte](thisBlockSize) + val hasRead = bais.read(tempByteArray, 0, thisBlockSize) + + retVal(blockID) = new BroadcastBlock(blockID, tempByteArray) + blockID += 1 + } + bais.close() + + var variableInfo = VariableInfo(retVal, blockNum, byteArray.length) + variableInfo.hasBlocks = blockNum + + return variableInfo + } + + // Helper function to convert Array[BroadcastBlock] to object + def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock], + totalBytes: Int, + totalBlocks: Int): OUT = { + + var retByteArray = new Array[Byte](totalBytes) + for (i <- 0 until totalBlocks) { + System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, + i * Broadcast.BlockSize, arrayOfBlocks(i).byteArray.length) + } + byteArrayToObject(retByteArray) + } + + private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = { + val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){ + override def resolveClass(desc: ObjectStreamClass) = + Class.forName(desc.getName, false, currentThread.getContextClassLoader) + } + val retVal = in.readObject.asInstanceOf[OUT] + in.close() + return retVal + } +} + +@serializable +case class BroadcastBlock (val blockID: Int, val byteArray: Array[Byte]) { } + +@serializable +case class VariableInfo (@transient val arrayOfBlocks : Array[BroadcastBlock], + val totalBlocks: Int, + val totalBytes: Int) { + @transient var hasBlocks = 0 +} + +@serializable +class SpeedTracker { + // Mapping 'source' to '(totalTime, numBlocks)' + private var sourceToSpeedMap = Map[SourceInfo, (Long, Int)] () + + def addDataPoint (srcInfo: SourceInfo, timeInMillis: Long): Unit = { + sourceToSpeedMap.synchronized { + if (!sourceToSpeedMap.contains(srcInfo)) { + sourceToSpeedMap += (srcInfo -> (timeInMillis, 1)) + } else { + val tTnB = sourceToSpeedMap (srcInfo) + sourceToSpeedMap += (srcInfo -> (tTnB._1 + timeInMillis, tTnB._2 + 1)) + } + } + } + + def getTimePerBlock (srcInfo: SourceInfo): Double = { + sourceToSpeedMap.synchronized { + val tTnB = sourceToSpeedMap (srcInfo) + return tTnB._1 / tTnB._2 + } + } + + override def toString = sourceToSpeedMap.toString +}
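As a side note on the blockifyObject/unBlockifyObject helpers above: they serialize the value with an ObjectOutputStream, cut the resulting byte array into BlockSize-sized chunks, and later reassemble and deserialize it on the receiver. A self-contained sketch of that round trip, with a simplified Block type standing in for BroadcastBlock/VariableInfo and a deliberately tiny block size so the demo produces several blocks (the patch's default is 4096 * 1024 bytes), might look like:

    import java.io._

    // Simplified stand-in for BroadcastBlock; VariableInfo's bookkeeping is reduced to (blocks, totalBytes).
    case class Block(id: Int, bytes: Array[Byte])

    object BlockifyDemo {
      val BlockSizeBytes = 4096 // tiny on purpose; the patch default is 4096 * 1024

      def blockify(obj: AnyRef): (Array[Block], Int) = {
        val baos = new ByteArrayOutputStream
        val oos = new ObjectOutputStream(baos)
        oos.writeObject(obj)
        oos.close()
        val bytes = baos.toByteArray
        // Cut the serialized form into fixed-size chunks; the last one may be shorter.
        val blocks = (0 until bytes.length by BlockSizeBytes).zipWithIndex.map { case (offset, id) =>
          Block(id, bytes.slice(offset, math.min(offset + BlockSizeBytes, bytes.length)))
        }.toArray
        (blocks, bytes.length)
      }

      def unBlockify[T](blocks: Array[Block], totalBytes: Int): T = {
        val out = new Array[Byte](totalBytes)
        for (b <- blocks) System.arraycopy(b.bytes, 0, out, b.id * BlockSizeBytes, b.bytes.length)
        val in = new ObjectInputStream(new ByteArrayInputStream(out))
        val value = in.readObject.asInstanceOf[T]
        in.close()
        value
      }

      def main(args: Array[String]): Unit = {
        val original = (1 to 10000).toArray
        val (blocks, totalBytes) = blockify(original)
        val roundTripped = unBlockify[Array[Int]](blocks, totalBytes)
        println("blocks = " + blocks.length + ", intact = " + roundTripped.sameElements(original))
      }
    }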
\ No newline at end of file diff --git a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala new file mode 100644 index 0000000000..341746d18e --- /dev/null +++ b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala @@ -0,0 +1,12 @@ +package spark.broadcast + +/** + * An interface for all the broadcast implementations in Spark (to allow + * multiple broadcast implementations). SparkContext uses a user-specified + * BroadcastFactory implementation to instantiate a particular broadcast for the + * entire Spark job. + */ +trait BroadcastFactory { + def initialize (isMaster: Boolean): Unit + def newBroadcast[T] (value_ : T, isLocal: Boolean): Broadcast[T] +}
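Broadcast.initialize (earlier in this diff) reflectively instantiates whichever class the spark.broadcast.factory system property names, so adding a broadcast strategy only requires implementing this two-method trait. ChainedBroadcastFactory and TreeBroadcastFactory appear later in this diff; as a rough sketch of the same pattern for the DFS-backed default, something along these lines would work (the body below is a guess modeled on those two factories, not code taken from this patch):

package spark.broadcast

// Hypothetical sketch: delegates to the DfsBroadcast class/object shown
// elsewhere in this diff; the real DfsBroadcastFactory may differ.
class DfsBroadcastFactory extends BroadcastFactory {
  def initialize(isMaster: Boolean): Unit = DfsBroadcast.initialize
  def newBroadcast[T](value_ : T, isLocal: Boolean): Broadcast[T] =
    new DfsBroadcast[T](value_, isLocal)
}

Selecting an implementation is then just a matter of setting the property before SparkContext or Executor calls Broadcast.initialize, e.g. System.setProperty("spark.broadcast.factory", "spark.broadcast.TreeBroadcastFactory").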
\ No newline at end of file diff --git a/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala b/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala new file mode 100644 index 0000000000..3afe923bae --- /dev/null +++ b/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala @@ -0,0 +1,792 @@ +package spark.broadcast + +import java.io._ +import java.net._ +import java.util.{Comparator, PriorityQueue, Random, UUID} + +import scala.collection.mutable.{Map, Set} +import scala.math + +import spark._ + +@serializable +class ChainedBroadcast[T](@transient var value_ : T, isLocal: Boolean) +extends Broadcast[T] with Logging { + + def value = value_ + + ChainedBroadcast.synchronized { + ChainedBroadcast.values.put(uuid, value_) + } + + @transient var arrayOfBlocks: Array[BroadcastBlock] = null + @transient var totalBytes = -1 + @transient var totalBlocks = -1 + @transient var hasBlocks = 0 + // CHANGED: BlockSize in the Broadcast object is expected to change over time + @transient var blockSize = Broadcast.BlockSize + + @transient var listenPortLock = new Object + @transient var guidePortLock = new Object + @transient var totalBlocksLock = new Object + @transient var hasBlocksLock = new Object + + @transient var pqOfSources = new PriorityQueue[SourceInfo] + + @transient var serveMR: ServeMultipleRequests = null + @transient var guideMR: GuideMultipleRequests = null + + @transient var hostAddress = Utils.localIpAddress + @transient var listenPort = -1 + @transient var guidePort = -1 + + @transient var hasCopyInHDFS = false + @transient var stopBroadcast = false + + // Must call this after all the variables have been created/initialized + if (!isLocal) { + sendBroadcast + } + + def sendBroadcast(): Unit = { + logInfo("Local host address: " + hostAddress) + + // Store a persistent copy in HDFS + // TODO: Turned OFF for now + // val out = new ObjectOutputStream(DfsBroadcast.openFileForWriting(uuid)) + // out.writeObject(value_) + // out.close() + // TODO: Fix this at some point + hasCopyInHDFS = true + + // Create a variableInfo object and store it in valueInfos + var variableInfo = Broadcast.blockifyObject(value_) + + guideMR = new GuideMultipleRequests + guideMR.setDaemon(true) + guideMR.start() + logInfo("GuideMultipleRequests started...") + + serveMR = new ServeMultipleRequests + serveMR.setDaemon(true) + serveMR.start() + logInfo("ServeMultipleRequests started...") + + // Prepare the value being broadcasted + // TODO: Refactoring and clean-up required here + arrayOfBlocks = variableInfo.arrayOfBlocks + totalBytes = variableInfo.totalBytes + totalBlocks = variableInfo.totalBlocks + hasBlocks = variableInfo.totalBlocks + + while (listenPort == -1) { + listenPortLock.synchronized { + listenPortLock.wait + } + } + + pqOfSources = new PriorityQueue[SourceInfo] + val masterSource = + SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes, blockSize) + pqOfSources.add(masterSource) + + // Register with the Tracker + while (guidePort == -1) { + guidePortLock.synchronized { + guidePortLock.wait + } + } + ChainedBroadcast.registerValue(uuid, guidePort) + } + + private def readObject(in: ObjectInputStream): Unit = { + in.defaultReadObject + ChainedBroadcast.synchronized { + val cachedVal = ChainedBroadcast.values.get(uuid) + if (cachedVal != null) { + value_ = cachedVal.asInstanceOf[T] + } else { + // Initializing everything because Master will only send null/0 values + initializeSlaveVariables + + logInfo("Local host address: " + hostAddress) + + serveMR = new ServeMultipleRequests + 
serveMR.setDaemon(true) + serveMR.start() + logInfo("ServeMultipleRequests started...") + + val start = System.nanoTime + + val receptionSucceeded = receiveBroadcast(uuid) + // If does not succeed, then get from HDFS copy + if (receptionSucceeded) { + value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) + ChainedBroadcast.values.put(uuid, value_) + } else { + val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid)) + value_ = fileIn.readObject.asInstanceOf[T] + ChainedBroadcast.values.put(uuid, value_) + fileIn.close() + } + + val time =(System.nanoTime - start) / 1e9 + logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s") + } + } + } + + private def initializeSlaveVariables: Unit = { + arrayOfBlocks = null + totalBytes = -1 + totalBlocks = -1 + hasBlocks = 0 + blockSize = -1 + + listenPortLock = new Object + totalBlocksLock = new Object + hasBlocksLock = new Object + + serveMR = null + + hostAddress = Utils.localIpAddress + listenPort = -1 + + stopBroadcast = false + } + + def getMasterListenPort(variableUUID: UUID): Int = { + var clientSocketToTracker: Socket = null + var oosTracker: ObjectOutputStream = null + var oisTracker: ObjectInputStream = null + + var masterListenPort: Int = SourceInfo.TxOverGoToHDFS + + var retriesLeft = Broadcast.MaxRetryCount + do { + try { + // Connect to the tracker to find out the guide + clientSocketToTracker = + new Socket(Broadcast.MasterHostAddress, Broadcast.MasterTrackerPort) + oosTracker = + new ObjectOutputStream(clientSocketToTracker.getOutputStream) + oosTracker.flush() + oisTracker = + new ObjectInputStream(clientSocketToTracker.getInputStream) + + // Send UUID and receive masterListenPort + oosTracker.writeObject(uuid) + oosTracker.flush() + masterListenPort = oisTracker.readObject.asInstanceOf[Int] + } catch { + case e: Exception => { + logInfo("getMasterListenPort had a " + e) + } + } finally { + if (oisTracker != null) { + oisTracker.close() + } + if (oosTracker != null) { + oosTracker.close() + } + if (clientSocketToTracker != null) { + clientSocketToTracker.close() + } + } + retriesLeft -= 1 + + Thread.sleep(ChainedBroadcast.ranGen.nextInt( + Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) + + Broadcast.MinKnockInterval) + + } while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry) + + logInfo("Got this guidePort from Tracker: " + masterListenPort) + return masterListenPort + } + + def receiveBroadcast(variableUUID: UUID): Boolean = { + val masterListenPort = getMasterListenPort(variableUUID) + + if (masterListenPort == SourceInfo.TxOverGoToHDFS || + masterListenPort == SourceInfo.TxNotStartedRetry) { + // TODO: SourceInfo.TxNotStartedRetry is not really in use because we go + // to HDFS anyway when receiveBroadcast returns false + return false + } + + // Wait until hostAddress and listenPort are created by the + // ServeMultipleRequests thread + while (listenPort == -1) { + listenPortLock.synchronized { + listenPortLock.wait + } + } + + var clientSocketToMaster: Socket = null + var oosMaster: ObjectOutputStream = null + var oisMaster: ObjectInputStream = null + + // Connect and receive broadcast from the specified source, retrying the + // specified number of times in case of failures + var retriesLeft = Broadcast.MaxRetryCount + do { + // Connect to Master and send this worker's Information + clientSocketToMaster = + new Socket(Broadcast.MasterHostAddress, masterListenPort) + // TODO: Guiding object connection is reusable + oosMaster = + new 
ObjectOutputStream(clientSocketToMaster.getOutputStream) + oosMaster.flush() + oisMaster = + new ObjectInputStream(clientSocketToMaster.getInputStream) + + logInfo("Connected to Master's guiding object") + + // Send local source information + oosMaster.writeObject(SourceInfo(hostAddress, listenPort)) + oosMaster.flush() + + // Receive source information from Master + var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo] + totalBlocks = sourceInfo.totalBlocks + arrayOfBlocks = new Array[BroadcastBlock](totalBlocks) + totalBlocksLock.synchronized { + totalBlocksLock.notifyAll + } + totalBytes = sourceInfo.totalBytes + + logInfo("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort) + + val start = System.nanoTime + val receptionSucceeded = receiveSingleTransmission(sourceInfo) + val time =(System.nanoTime - start) / 1e9 + + // Updating some statistics in sourceInfo. Master will be using them later + if (!receptionSucceeded) { + sourceInfo.receptionFailed = true + } + + // Send back statistics to the Master + oosMaster.writeObject(sourceInfo) + + if (oisMaster != null) { + oisMaster.close() + } + if (oosMaster != null) { + oosMaster.close() + } + if (clientSocketToMaster != null) { + clientSocketToMaster.close() + } + + retriesLeft -= 1 + } while (retriesLeft > 0 && hasBlocks < totalBlocks) + + return(hasBlocks == totalBlocks) + } + + // Tries to receive broadcast from the source and returns Boolean status. + // This might be called multiple times to retry a defined number of times. + private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = { + var clientSocketToSource: Socket = null + var oosSource: ObjectOutputStream = null + var oisSource: ObjectInputStream = null + + var receptionSucceeded = false + try { + // Connect to the source to get the object itself + clientSocketToSource = + new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) + oosSource = + new ObjectOutputStream(clientSocketToSource.getOutputStream) + oosSource.flush() + oisSource = + new ObjectInputStream(clientSocketToSource.getInputStream) + + logInfo("Inside receiveSingleTransmission") + logInfo("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks) + + // Send the range + oosSource.writeObject((hasBlocks, totalBlocks)) + oosSource.flush() + + for (i <- hasBlocks until totalBlocks) { + val recvStartTime = System.currentTimeMillis + val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] + val receptionTime =(System.currentTimeMillis - recvStartTime) + + logInfo("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.") + + arrayOfBlocks(hasBlocks) = bcBlock + hasBlocks += 1 + // Set to true if at least one block is received + receptionSucceeded = true + hasBlocksLock.synchronized { + hasBlocksLock.notifyAll + } + } + } catch { + case e: Exception => { + logInfo("receiveSingleTransmission had a " + e) + } + } finally { + if (oisSource != null) { + oisSource.close() + } + if (oosSource != null) { + oosSource.close() + } + if (clientSocketToSource != null) { + clientSocketToSource.close() + } + } + + return receptionSucceeded + } + + class GuideMultipleRequests + extends Thread with Logging { + // Keep track of sources that have completed reception + private var setOfCompletedSources = Set[SourceInfo]() + + override def run: Unit = { + var threadPool = Utils.newDaemonCachedThreadPool() + var serverSocket: ServerSocket = null + + serverSocket = new ServerSocket(0) + guidePort = serverSocket.getLocalPort + 
logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort) + + guidePortLock.synchronized { + guidePortLock.notifyAll + } + + try { + // Don't stop until there is a copy in HDFS + while (!stopBroadcast || !hasCopyInHDFS) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout) + clientSocket = serverSocket.accept + } catch { + case e: Exception => { + logInfo("GuideMultipleRequests Timeout.") + + // Stop broadcast if at least one worker has connected and + // everyone connected so far are done. Comparing with + // pqOfSources.size - 1, because it includes the Guide itself + if (pqOfSources.size > 1 && + setOfCompletedSources.size == pqOfSources.size - 1) { + stopBroadcast = true + } + } + } + if (clientSocket != null) { + logInfo("Guide: Accepted new client connection: " + clientSocket) + try { + threadPool.execute(new GuideSingleRequest(clientSocket)) + } catch { + // In failure, close the socket here; else, the thread will close it + case ioe: IOException => clientSocket.close() + } + } + } + + logInfo("Sending stopBroadcast notifications...") + sendStopBroadcastNotifications + + ChainedBroadcast.unregisterValue(uuid) + } finally { + if (serverSocket != null) { + logInfo("GuideMultipleRequests now stopping...") + serverSocket.close() + } + } + + // Shutdown the thread pool + threadPool.shutdown() + } + + private def sendStopBroadcastNotifications: Unit = { + pqOfSources.synchronized { + var pqIter = pqOfSources.iterator + while (pqIter.hasNext) { + var sourceInfo = pqIter.next + + var guideSocketToSource: Socket = null + var gosSource: ObjectOutputStream = null + var gisSource: ObjectInputStream = null + + try { + // Connect to the source + guideSocketToSource = + new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) + gosSource = + new ObjectOutputStream(guideSocketToSource.getOutputStream) + gosSource.flush() + gisSource = + new ObjectInputStream(guideSocketToSource.getInputStream) + + // Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2 + gosSource.writeObject((SourceInfo.StopBroadcast, + SourceInfo.StopBroadcast)) + gosSource.flush() + } catch { + case e: Exception => { + logInfo("sendStopBroadcastNotifications had a " + e) + } + } finally { + if (gisSource != null) { + gisSource.close() + } + if (gosSource != null) { + gosSource.close() + } + if (guideSocketToSource != null) { + guideSocketToSource.close() + } + } + } + } + } + + class GuideSingleRequest(val clientSocket: Socket) + extends Thread with Logging { + private val oos = new ObjectOutputStream(clientSocket.getOutputStream) + oos.flush() + private val ois = new ObjectInputStream(clientSocket.getInputStream) + + private var selectedSourceInfo: SourceInfo = null + private var thisWorkerInfo:SourceInfo = null + + override def run: Unit = { + try { + logInfo("new GuideSingleRequest is running") + // Connecting worker is sending in its hostAddress and listenPort it will + // be listening to. 
Other fields are invalid (SourceInfo.UnusedParam) + var sourceInfo = ois.readObject.asInstanceOf[SourceInfo] + + pqOfSources.synchronized { + // Select a suitable source and send it back to the worker + selectedSourceInfo = selectSuitableSource(sourceInfo) + logInfo("Sending selectedSourceInfo: " + selectedSourceInfo) + oos.writeObject(selectedSourceInfo) + oos.flush() + + // Add this new (if it can finish) source to the PQ of sources + thisWorkerInfo = SourceInfo(sourceInfo.hostAddress, + sourceInfo.listenPort, totalBlocks, totalBytes, blockSize) + logInfo("Adding possible new source to pqOfSources: " + thisWorkerInfo) + pqOfSources.add(thisWorkerInfo) + } + + // Wait till the whole transfer is done. Then receive and update source + // statistics in pqOfSources + sourceInfo = ois.readObject.asInstanceOf[SourceInfo] + + pqOfSources.synchronized { + // This should work since SourceInfo is a case class + assert(pqOfSources.contains(selectedSourceInfo)) + + // Remove first + pqOfSources.remove(selectedSourceInfo) + // TODO: Removing a source based on just one failure notification! + + // Update sourceInfo and put it back in, IF reception succeeded + if (!sourceInfo.receptionFailed) { + // Add thisWorkerInfo to sources that have completed reception + setOfCompletedSources.synchronized { + setOfCompletedSources += thisWorkerInfo + } + + selectedSourceInfo.currentLeechers -= 1 + + // Put it back + pqOfSources.add(selectedSourceInfo) + } + } + } catch { + // If something went wrong, e.g., the worker at the other end died etc. + // then close everything up + case e: Exception => { + // Assuming that the exception was caused by a receiver worker failure. + // Remove failed worker from pqOfSources and update leecherCount of + // corresponding source worker + pqOfSources.synchronized { + if (selectedSourceInfo != null) { + // Remove first + pqOfSources.remove(selectedSourceInfo) + // Update leecher count and put it back in + selectedSourceInfo.currentLeechers -= 1 + pqOfSources.add(selectedSourceInfo) + } + + // Remove thisWorkerInfo + if (pqOfSources != null) { + pqOfSources.remove(thisWorkerInfo) + } + } + } + } finally { + ois.close() + oos.close() + clientSocket.close() + } + } + + // FIXME: Caller must have a synchronized block on pqOfSources + // FIXME: If a worker fails to get the broadcasted variable from a source and + // comes back to Master, this function might choose the worker itself as a + // source to create a dependency cycle (this worker was put into pqOfSources + // as a streaming source when it first arrived). The length of this cycle can + // be arbitrarily long. + private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = { + // Select one based on the ordering strategy (e.g., least leechers etc.) 
+ // take is a blocking call removing the element from PQ + var selectedSource = pqOfSources.poll + assert(selectedSource != null) + // Update leecher count + selectedSource.currentLeechers += 1 + // Add it back and then return + pqOfSources.add(selectedSource) + return selectedSource + } + } + } + + class ServeMultipleRequests + extends Thread with Logging { + override def run: Unit = { + var threadPool = Utils.newDaemonCachedThreadPool() + var serverSocket: ServerSocket = null + + serverSocket = new ServerSocket(0) + listenPort = serverSocket.getLocalPort + logInfo("ServeMultipleRequests started with " + serverSocket) + + listenPortLock.synchronized { + listenPortLock.notifyAll + } + + try { + while (!stopBroadcast) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout) + clientSocket = serverSocket.accept + } catch { + case e: Exception => { + logInfo("ServeMultipleRequests Timeout.") + } + } + if (clientSocket != null) { + logInfo("Serve: Accepted new client connection: " + clientSocket) + try { + threadPool.execute(new ServeSingleRequest(clientSocket)) + } catch { + // In failure, close socket here; else, the thread will close it + case ioe: IOException => clientSocket.close() + } + } + } + } finally { + if (serverSocket != null) { + logInfo("ServeMultipleRequests now stopping...") + serverSocket.close() + } + } + + // Shutdown the thread pool + threadPool.shutdown() + } + + class ServeSingleRequest(val clientSocket: Socket) + extends Thread with Logging { + private val oos = new ObjectOutputStream(clientSocket.getOutputStream) + oos.flush() + private val ois = new ObjectInputStream(clientSocket.getInputStream) + + private var sendFrom = 0 + private var sendUntil = totalBlocks + + override def run: Unit = { + try { + logInfo("new ServeSingleRequest is running") + + // Receive range to send + var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)] + sendFrom = rangeToSend._1 + sendUntil = rangeToSend._2 + + if (sendFrom == SourceInfo.StopBroadcast && + sendUntil == SourceInfo.StopBroadcast) { + stopBroadcast = true + } else { + // Carry on + sendObject + } + } catch { + // If something went wrong, e.g., the worker at the other end died etc. 
+ // then close everything up + case e: Exception => { + logInfo("ServeSingleRequest had a " + e) + } + } finally { + logInfo("ServeSingleRequest is closing streams and sockets") + ois.close() + oos.close() + clientSocket.close() + } + } + + private def sendObject: Unit = { + // Wait till receiving the SourceInfo from Master + while (totalBlocks == -1) { + totalBlocksLock.synchronized { + totalBlocksLock.wait + } + } + + for (i <- sendFrom until sendUntil) { + while (i == hasBlocks) { + hasBlocksLock.synchronized { + hasBlocksLock.wait + } + } + try { + oos.writeObject(arrayOfBlocks(i)) + oos.flush() + } catch { + case e: Exception => { + logInfo("sendObject had a " + e) + } + } + logInfo("Sent block: " + i + " to " + clientSocket) + } + } + } + } +} + +class ChainedBroadcastFactory +extends BroadcastFactory { + def initialize(isMaster: Boolean) = ChainedBroadcast.initialize(isMaster) + def newBroadcast[T](value_ : T, isLocal: Boolean) = + new ChainedBroadcast[T](value_, isLocal) +} + +private object ChainedBroadcast +extends Logging { + val values = SparkEnv.get.cache.newKeySpace() + + var valueToGuidePortMap = Map[UUID, Int]() + + // Random number generator + var ranGen = new Random + + private var initialized = false + private var isMaster_ = false + + private var trackMV: TrackMultipleValues = null + + def initialize(isMaster__ : Boolean): Unit = { + synchronized { + if (!initialized) { + isMaster_ = isMaster__ + + if (isMaster) { + trackMV = new TrackMultipleValues + trackMV.setDaemon(true) + trackMV.start() + // TODO: Logging the following line makes the Spark framework ID not + // getting logged, cause it calls logInfo before log4j is initialized + logInfo("TrackMultipleValues started...") + } + + // Initialize DfsBroadcast to be used for broadcast variable persistence + DfsBroadcast.initialize + + initialized = true + } + } + } + + def isMaster = isMaster_ + + def registerValue(uuid: UUID, guidePort: Int): Unit = { + valueToGuidePortMap.synchronized { + valueToGuidePortMap +=(uuid -> guidePort) + logInfo("New value registered with the Tracker " + valueToGuidePortMap) + } + } + + def unregisterValue(uuid: UUID): Unit = { + valueToGuidePortMap.synchronized { + valueToGuidePortMap(uuid) = SourceInfo.TxOverGoToHDFS + logInfo("Value unregistered from the Tracker " + valueToGuidePortMap) + } + } + + class TrackMultipleValues + extends Thread with Logging { + override def run: Unit = { + var threadPool = Utils.newDaemonCachedThreadPool() + var serverSocket: ServerSocket = null + + serverSocket = new ServerSocket(Broadcast.MasterTrackerPort) + logInfo("TrackMultipleValues" + serverSocket) + + try { + while (true) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout(Broadcast.TrackerSocketTimeout) + clientSocket = serverSocket.accept + } catch { + case e: Exception => { + logInfo("TrackMultipleValues Timeout. 
Stopping listening...") + } + } + + if (clientSocket != null) { + try { + threadPool.execute(new Thread { + override def run: Unit = { + val oos = new ObjectOutputStream(clientSocket.getOutputStream) + oos.flush() + val ois = new ObjectInputStream(clientSocket.getInputStream) + try { + val uuid = ois.readObject.asInstanceOf[UUID] + var guidePort = + if (valueToGuidePortMap.contains(uuid)) { + valueToGuidePortMap(uuid) + } else SourceInfo.TxNotStartedRetry + logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort) + oos.writeObject(guidePort) + } catch { + case e: Exception => { + logInfo("TrackMultipleValues had a " + e) + } + } finally { + ois.close() + oos.close() + clientSocket.close() + } + } + }) + } catch { + // In failure, close socket here; else, client thread will close + case ioe: IOException => clientSocket.close() + } + } + } + } finally { + serverSocket.close() + } + + // Shutdown the thread pool + threadPool.shutdown() + } + } +} diff --git a/core/src/main/scala/spark/DfsBroadcast.scala b/core/src/main/scala/spark/broadcast/DfsBroadcast.scala index 895f55ca22..e541c09216 100644 --- a/core/src/main/scala/spark/DfsBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/DfsBroadcast.scala @@ -1,4 +1,6 @@ -package spark +package spark.broadcast + +import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import java.io._ import java.net._ @@ -7,7 +9,7 @@ import java.util.UUID import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem} -import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} +import spark._ @serializable class DfsBroadcast[T](@transient var value_ : T, isLocal: Boolean) diff --git a/core/src/main/scala/spark/broadcast/SourceInfo.scala b/core/src/main/scala/spark/broadcast/SourceInfo.scala new file mode 100644 index 0000000000..064142590a --- /dev/null +++ b/core/src/main/scala/spark/broadcast/SourceInfo.scala @@ -0,0 +1,41 @@ +package spark.broadcast + +import java.util.BitSet + +import spark._ + +/** + * Used to keep and pass around information of peers involved in a broadcast + * + * CHANGED: Keep track of the blockSize for THIS broadcast variable. + * Broadcast.BlockSize is expected to be updated across different broadcasts + */ +@serializable +case class SourceInfo (val hostAddress: String, + val listenPort: Int, + val totalBlocks: Int = SourceInfo.UnusedParam, + val totalBytes: Int = SourceInfo.UnusedParam, + val blockSize: Int = Broadcast.BlockSize) +extends Comparable[SourceInfo] with Logging { + + var currentLeechers = 0 + var receptionFailed = false + + var hasBlocks = 0 + var hasBlocksBitVector: BitSet = new BitSet (totalBlocks) + + // Ascending sort based on leecher count + def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers) +} + +/** + * Helper Object of SourceInfo for its constants + */ +object SourceInfo { + // Constants for special values of listenPort + val TxNotStartedRetry = -1 + val TxOverGoToHDFS = 0 + // Other constants + val StopBroadcast = -2 + val UnusedParam = 0 +}
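SourceInfo.compareTo orders peers by ascending leecher count, which is what lets ChainedBroadcast keep its candidate sources in a java.util.PriorityQueue (pqOfSources above) and poll the least-loaded one in selectSuitableSource. A small sketch of that behaviour follows (the addresses, ports and leecher counts are made up; this is illustration, not code from the patch):

import java.util.PriorityQueue
import spark.broadcast.SourceInfo

object SourceOrderingDemo {
  def main(args: Array[String]): Unit = {
    val busy = SourceInfo("10.0.0.1", 5001)   // hypothetical peers
    val idle = SourceInfo("10.0.0.2", 5002)
    busy.currentLeechers = 3
    idle.currentLeechers = 1

    val pq = new PriorityQueue[SourceInfo]
    pq.add(busy)
    pq.add(idle)

    // Ascending compareTo means the source with the fewest current leechers
    // is polled first, which is exactly what selectSuitableSource relies on.
    assert(pq.poll() eq idle)
    assert(pq.poll() eq busy)
  }
}

Note that the queue only orders elements when they are inserted, which is why ChainedBroadcast removes a source, updates currentLeechers, and adds it back rather than mutating it in place.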
\ No newline at end of file diff --git a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala new file mode 100644 index 0000000000..79dcd317ec --- /dev/null +++ b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala @@ -0,0 +1,807 @@ +package spark.broadcast + +import java.io._ +import java.net._ +import java.util.{Comparator, Random, UUID} + +import scala.collection.mutable.{ListBuffer, Map, Set} +import scala.math + +import spark._ + +@serializable +class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean) +extends Broadcast[T] with Logging { + + def value = value_ + + TreeBroadcast.synchronized { + TreeBroadcast.values.put(uuid, value_) + } + + @transient var arrayOfBlocks: Array[BroadcastBlock] = null + @transient var totalBytes = -1 + @transient var totalBlocks = -1 + @transient var hasBlocks = 0 + // CHANGED: BlockSize in the Broadcast object is expected to change over time + @transient var blockSize = Broadcast.BlockSize + + @transient var listenPortLock = new Object + @transient var guidePortLock = new Object + @transient var totalBlocksLock = new Object + @transient var hasBlocksLock = new Object + + @transient var listOfSources = ListBuffer[SourceInfo]() + + @transient var serveMR: ServeMultipleRequests = null + @transient var guideMR: GuideMultipleRequests = null + + @transient var hostAddress = Utils.localIpAddress + @transient var listenPort = -1 + @transient var guidePort = -1 + + @transient var hasCopyInHDFS = false + @transient var stopBroadcast = false + + // Must call this after all the variables have been created/initialized + if (!isLocal) { + sendBroadcast + } + + def sendBroadcast(): Unit = { + logInfo("Local host address: " + hostAddress) + + // Store a persistent copy in HDFS + // TODO: Turned OFF for now + // val out = new ObjectOutputStream(DfsBroadcast.openFileForWriting(uuid)) + // out.writeObject(value_) + // out.close() + // TODO: Fix this at some point + hasCopyInHDFS = true + + // Create a variableInfo object and store it in valueInfos + var variableInfo = Broadcast.blockifyObject(value_) + + // Prepare the value being broadcasted + // TODO: Refactoring and clean-up required here + arrayOfBlocks = variableInfo.arrayOfBlocks + totalBytes = variableInfo.totalBytes + totalBlocks = variableInfo.totalBlocks + hasBlocks = variableInfo.totalBlocks + + guideMR = new GuideMultipleRequests + guideMR.setDaemon(true) + guideMR.start + logInfo("GuideMultipleRequests started...") + + // Must always come AFTER guideMR is created + while (guidePort == -1) { + guidePortLock.synchronized { + guidePortLock.wait + } + } + + serveMR = new ServeMultipleRequests + serveMR.setDaemon(true) + serveMR.start + logInfo("ServeMultipleRequests started...") + + // Must always come AFTER serveMR is created + while (listenPort == -1) { + listenPortLock.synchronized { + listenPortLock.wait + } + } + + // Must always come AFTER listenPort is created + val masterSource = + SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes, blockSize) + listOfSources += masterSource + + // Register with the Tracker + TreeBroadcast.registerValue(uuid, guidePort) + } + + private def readObject(in: ObjectInputStream): Unit = { + in.defaultReadObject + TreeBroadcast.synchronized { + val cachedVal = TreeBroadcast.values.get(uuid) + if (cachedVal != null) { + value_ = cachedVal.asInstanceOf[T] + } else { + // Initializing everything because Master will only send null/0 values + initializeSlaveVariables + + logInfo("Local host address: 
" + hostAddress) + + serveMR = new ServeMultipleRequests + serveMR.setDaemon(true) + serveMR.start + logInfo("ServeMultipleRequests started...") + + val start = System.nanoTime + + val receptionSucceeded = receiveBroadcast(uuid) + // If does not succeed, then get from HDFS copy + if (receptionSucceeded) { + value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) + TreeBroadcast.values.put(uuid, value_) + } else { + val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid)) + value_ = fileIn.readObject.asInstanceOf[T] + TreeBroadcast.values.put(uuid, value_) + fileIn.close() + } + + val time = (System.nanoTime - start) / 1e9 + logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s") + } + } + } + + private def initializeSlaveVariables: Unit = { + arrayOfBlocks = null + totalBytes = -1 + totalBlocks = -1 + hasBlocks = 0 + blockSize = -1 + + listenPortLock = new Object + totalBlocksLock = new Object + hasBlocksLock = new Object + + serveMR = null + + hostAddress = Utils.localIpAddress + listenPort = -1 + + stopBroadcast = false + } + + def getMasterListenPort(variableUUID: UUID): Int = { + var clientSocketToTracker: Socket = null + var oosTracker: ObjectOutputStream = null + var oisTracker: ObjectInputStream = null + + var masterListenPort: Int = SourceInfo.TxOverGoToHDFS + + var retriesLeft = Broadcast.MaxRetryCount + do { + try { + // Connect to the tracker to find out the guide + clientSocketToTracker = + new Socket(Broadcast.MasterHostAddress, Broadcast.MasterTrackerPort) + oosTracker = + new ObjectOutputStream(clientSocketToTracker.getOutputStream) + oosTracker.flush() + oisTracker = + new ObjectInputStream(clientSocketToTracker.getInputStream) + + // Send UUID and receive masterListenPort + oosTracker.writeObject(uuid) + oosTracker.flush() + masterListenPort = oisTracker.readObject.asInstanceOf[Int] + } catch { + case e: Exception => { + logInfo("getMasterListenPort had a " + e) + } + } finally { + if (oisTracker != null) { + oisTracker.close() + } + if (oosTracker != null) { + oosTracker.close() + } + if (clientSocketToTracker != null) { + clientSocketToTracker.close() + } + } + retriesLeft -= 1 + + Thread.sleep(TreeBroadcast.ranGen.nextInt( + Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) + + Broadcast.MinKnockInterval) + + } while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry) + + logInfo("Got this guidePort from Tracker: " + masterListenPort) + return masterListenPort + } + + def receiveBroadcast(variableUUID: UUID): Boolean = { + val masterListenPort = getMasterListenPort(variableUUID) + + if (masterListenPort == SourceInfo.TxOverGoToHDFS || + masterListenPort == SourceInfo.TxNotStartedRetry) { + // TODO: SourceInfo.TxNotStartedRetry is not really in use because we go + // to HDFS anyway when receiveBroadcast returns false + return false + } + + // Wait until hostAddress and listenPort are created by the + // ServeMultipleRequests thread + while (listenPort == -1) { + listenPortLock.synchronized { + listenPortLock.wait + } + } + + var clientSocketToMaster: Socket = null + var oosMaster: ObjectOutputStream = null + var oisMaster: ObjectInputStream = null + + // Connect and receive broadcast from the specified source, retrying the + // specified number of times in case of failures + var retriesLeft = Broadcast.MaxRetryCount + do { + // Connect to Master and send this worker's Information + clientSocketToMaster = + new Socket(Broadcast.MasterHostAddress, masterListenPort) + // TODO: Guiding object 
connection is reusable + oosMaster = + new ObjectOutputStream(clientSocketToMaster.getOutputStream) + oosMaster.flush() + oisMaster = + new ObjectInputStream(clientSocketToMaster.getInputStream) + + logInfo("Connected to Master's guiding object") + + // Send local source information + oosMaster.writeObject(SourceInfo(hostAddress, listenPort)) + oosMaster.flush() + + // Receive source information from Master + var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo] + totalBlocks = sourceInfo.totalBlocks + arrayOfBlocks = new Array[BroadcastBlock](totalBlocks) + totalBlocksLock.synchronized { + totalBlocksLock.notifyAll + } + totalBytes = sourceInfo.totalBytes + blockSize = sourceInfo.blockSize + + logInfo("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort) + + val start = System.nanoTime + val receptionSucceeded = receiveSingleTransmission(sourceInfo) + val time = (System.nanoTime - start) / 1e9 + + // Updating some statistics in sourceInfo. Master will be using them later + if (!receptionSucceeded) { + sourceInfo.receptionFailed = true + } + + // Send back statistics to the Master + oosMaster.writeObject(sourceInfo) + + if (oisMaster != null) { + oisMaster.close() + } + if (oosMaster != null) { + oosMaster.close() + } + if (clientSocketToMaster != null) { + clientSocketToMaster.close() + } + + retriesLeft -= 1 + } while (retriesLeft > 0 && hasBlocks < totalBlocks) + + return (hasBlocks == totalBlocks) + } + + // Tries to receive broadcast from the source and returns Boolean status. + // This might be called multiple times to retry a defined number of times. + private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = { + var clientSocketToSource: Socket = null + var oosSource: ObjectOutputStream = null + var oisSource: ObjectInputStream = null + + var receptionSucceeded = false + try { + // Connect to the source to get the object itself + clientSocketToSource = + new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) + oosSource = + new ObjectOutputStream(clientSocketToSource.getOutputStream) + oosSource.flush() + oisSource = + new ObjectInputStream(clientSocketToSource.getInputStream) + + logInfo("Inside receiveSingleTransmission") + logInfo("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks) + + // Send the range + oosSource.writeObject((hasBlocks, totalBlocks)) + oosSource.flush() + + for (i <- hasBlocks until totalBlocks) { + val recvStartTime = System.currentTimeMillis + val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] + val receptionTime = (System.currentTimeMillis - recvStartTime) + + logInfo("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.") + + arrayOfBlocks(hasBlocks) = bcBlock + hasBlocks += 1 + // Set to true if at least one block is received + receptionSucceeded = true + hasBlocksLock.synchronized { + hasBlocksLock.notifyAll + } + } + } catch { + case e: Exception => { + logInfo("receiveSingleTransmission had a " + e) + } + } finally { + if (oisSource != null) { + oisSource.close() + } + if (oosSource != null) { + oosSource.close() + } + if (clientSocketToSource != null) { + clientSocketToSource.close() + } + } + + return receptionSucceeded + } + + class GuideMultipleRequests + extends Thread with Logging { + // Keep track of sources that have completed reception + private var setOfCompletedSources = Set[SourceInfo]() + + override def run: Unit = { + var threadPool = Utils.newDaemonCachedThreadPool() + var serverSocket: ServerSocket = null + + 
serverSocket = new ServerSocket(0) + guidePort = serverSocket.getLocalPort + logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort) + + guidePortLock.synchronized { + guidePortLock.notifyAll + } + + try { + // Don't stop until there is a copy in HDFS + while (!stopBroadcast || !hasCopyInHDFS) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout) + clientSocket = serverSocket.accept + } catch { + case e: Exception => { + logInfo("GuideMultipleRequests Timeout.") + + // Stop broadcast if at least one worker has connected and + // everyone connected so far are done. Comparing with + // listOfSources.size - 1, because it includes the Guide itself + if (listOfSources.size > 1 && + setOfCompletedSources.size == listOfSources.size - 1) { + stopBroadcast = true + } + } + } + if (clientSocket != null) { + logInfo("Guide: Accepted new client connection: " + clientSocket) + try { + threadPool.execute(new GuideSingleRequest(clientSocket)) + } catch { + // In failure, close() the socket here; else, the thread will close() it + case ioe: IOException => clientSocket.close() + } + } + } + + logInfo("Sending stopBroadcast notifications...") + sendStopBroadcastNotifications + + TreeBroadcast.unregisterValue(uuid) + } finally { + if (serverSocket != null) { + logInfo("GuideMultipleRequests now stopping...") + serverSocket.close() + } + } + + // Shutdown the thread pool + threadPool.shutdown + } + + private def sendStopBroadcastNotifications: Unit = { + listOfSources.synchronized { + var listIter = listOfSources.iterator + while (listIter.hasNext) { + var sourceInfo = listIter.next + + var guideSocketToSource: Socket = null + var gosSource: ObjectOutputStream = null + var gisSource: ObjectInputStream = null + + try { + // Connect to the source + guideSocketToSource = + new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) + gosSource = + new ObjectOutputStream(guideSocketToSource.getOutputStream) + gosSource.flush() + gisSource = + new ObjectInputStream(guideSocketToSource.getInputStream) + + // Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2 + gosSource.writeObject((SourceInfo.StopBroadcast, + SourceInfo.StopBroadcast)) + gosSource.flush() + } catch { + case e: Exception => { + logInfo("sendStopBroadcastNotifications had a " + e) + } + } finally { + if (gisSource != null) { + gisSource.close() + } + if (gosSource != null) { + gosSource.close() + } + if (guideSocketToSource != null) { + guideSocketToSource.close() + } + } + } + } + } + + class GuideSingleRequest(val clientSocket: Socket) + extends Thread with Logging { + private val oos = new ObjectOutputStream(clientSocket.getOutputStream) + oos.flush() + private val ois = new ObjectInputStream(clientSocket.getInputStream) + + private var selectedSourceInfo: SourceInfo = null + private var thisWorkerInfo:SourceInfo = null + + override def run: Unit = { + try { + logInfo("new GuideSingleRequest is running") + // Connecting worker is sending in its hostAddress and listenPort it will + // be listening to. 
Other fields are invalid (SourceInfo.UnusedParam) + var sourceInfo = ois.readObject.asInstanceOf[SourceInfo] + + listOfSources.synchronized { + // Select a suitable source and send it back to the worker + selectedSourceInfo = selectSuitableSource(sourceInfo) + logInfo("Sending selectedSourceInfo: " + selectedSourceInfo) + oos.writeObject(selectedSourceInfo) + oos.flush() + + // Add this new (if it can finish) source to the list of sources + thisWorkerInfo = SourceInfo(sourceInfo.hostAddress, + sourceInfo.listenPort, totalBlocks, totalBytes, blockSize) + logInfo("Adding possible new source to listOfSources: " + thisWorkerInfo) + listOfSources += thisWorkerInfo + } + + // Wait till the whole transfer is done. Then receive and update source + // statistics in listOfSources + sourceInfo = ois.readObject.asInstanceOf[SourceInfo] + + listOfSources.synchronized { + // This should work since SourceInfo is a case class + assert(listOfSources.contains(selectedSourceInfo)) + + // Remove first + listOfSources = listOfSources - selectedSourceInfo + // TODO: Removing a source based on just one failure notification! + + // Update sourceInfo and put it back in, IF reception succeeded + if (!sourceInfo.receptionFailed) { + // Add thisWorkerInfo to sources that have completed reception + setOfCompletedSources.synchronized { + setOfCompletedSources += thisWorkerInfo + } + + selectedSourceInfo.currentLeechers -= 1 + + // Put it back + listOfSources += selectedSourceInfo + } + } + } catch { + // If something went wrong, e.g., the worker at the other end died etc. + // then close() everything up + case e: Exception => { + // Assuming that the exception was caused by a receiver worker failure. + // Remove failed worker from listOfSources and update leecherCount of + // corresponding source worker + listOfSources.synchronized { + if (selectedSourceInfo != null) { + // Remove first + listOfSources = listOfSources - selectedSourceInfo + // Update leecher count and put it back in + selectedSourceInfo.currentLeechers -= 1 + listOfSources += selectedSourceInfo + } + + // Remove thisWorkerInfo + if (listOfSources != null) { + listOfSources = listOfSources - thisWorkerInfo + } + } + } + } finally { + ois.close() + oos.close() + clientSocket.close() + } + } + + // FIXME: Caller must have a synchronized block on listOfSources + // FIXME: If a worker fails to get the broadcasted variable from a source + // and comes back to the Master, this function might choose the worker + // itself as a source to create a dependency cycle (this worker was put + // into listOfSources as a streaming source when it first arrived). The + // length of this cycle can be arbitrarily long. + private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = { + // Select one with the most leechers. 
This will level-wise fill the tree + + var maxLeechers = -1 + var selectedSource: SourceInfo = null + + listOfSources.foreach { source => + if (source != skipSourceInfo && + source.currentLeechers < Broadcast.MaxDegree && + source.currentLeechers > maxLeechers) { + selectedSource = source + maxLeechers = source.currentLeechers + } + } + + // Update leecher count + selectedSource.currentLeechers += 1 + + return selectedSource + } + } + } + + class ServeMultipleRequests + extends Thread with Logging { + override def run: Unit = { + var threadPool = Utils.newDaemonCachedThreadPool() + var serverSocket: ServerSocket = null + + serverSocket = new ServerSocket(0) + listenPort = serverSocket.getLocalPort + logInfo("ServeMultipleRequests started with " + serverSocket) + + listenPortLock.synchronized { + listenPortLock.notifyAll + } + + try { + while (!stopBroadcast) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout) + clientSocket = serverSocket.accept + } catch { + case e: Exception => { + logInfo("ServeMultipleRequests Timeout.") + } + } + if (clientSocket != null) { + logInfo("Serve: Accepted new client connection: " + clientSocket) + try { + threadPool.execute(new ServeSingleRequest(clientSocket)) + } catch { + // In failure, close() socket here; else, the thread will close() it + case ioe: IOException => clientSocket.close() + } + } + } + } finally { + if (serverSocket != null) { + logInfo("ServeMultipleRequests now stopping...") + serverSocket.close() + } + } + + // Shutdown the thread pool + threadPool.shutdown + } + + class ServeSingleRequest(val clientSocket: Socket) + extends Thread with Logging { + private val oos = new ObjectOutputStream(clientSocket.getOutputStream) + oos.flush() + private val ois = new ObjectInputStream(clientSocket.getInputStream) + + private var sendFrom = 0 + private var sendUntil = totalBlocks + + override def run: Unit = { + try { + logInfo("new ServeSingleRequest is running") + + // Receive range to send + var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)] + sendFrom = rangeToSend._1 + sendUntil = rangeToSend._2 + + if (sendFrom == SourceInfo.StopBroadcast && + sendUntil == SourceInfo.StopBroadcast) { + stopBroadcast = true + } else { + // Carry on + sendObject + } + } catch { + // If something went wrong, e.g., the worker at the other end died etc. 
+ // then close() everything up + case e: Exception => { + logInfo("ServeSingleRequest had a " + e) + } + } finally { + logInfo("ServeSingleRequest is closing streams and sockets") + ois.close() + oos.close() + clientSocket.close() + } + } + + private def sendObject: Unit = { + // Wait till receiving the SourceInfo from Master + while (totalBlocks == -1) { + totalBlocksLock.synchronized { + totalBlocksLock.wait + } + } + + for (i <- sendFrom until sendUntil) { + while (i == hasBlocks) { + hasBlocksLock.synchronized { + hasBlocksLock.wait + } + } + try { + oos.writeObject(arrayOfBlocks(i)) + oos.flush() + } catch { + case e: Exception => { + logInfo("sendObject had a " + e) + } + } + logInfo("Sent block: " + i + " to " + clientSocket) + } + } + } + } +} + +class TreeBroadcastFactory +extends BroadcastFactory { + def initialize(isMaster: Boolean) = TreeBroadcast.initialize(isMaster) + def newBroadcast[T](value_ : T, isLocal: Boolean) = + new TreeBroadcast[T](value_, isLocal) +} + +private object TreeBroadcast +extends Logging { + val values = SparkEnv.get.cache.newKeySpace() + + var valueToGuidePortMap = Map[UUID, Int]() + + // Random number generator + var ranGen = new Random + + private var initialized = false + private var isMaster_ = false + + private var trackMV: TrackMultipleValues = null + + private var MaxDegree_ : Int = 2 + + def initialize(isMaster__ : Boolean): Unit = { + synchronized { + if (!initialized) { + isMaster_ = isMaster__ + + if (isMaster) { + trackMV = new TrackMultipleValues + trackMV.setDaemon(true) + trackMV.start + // TODO: Logging the following line makes the Spark framework ID not + // getting logged, cause it calls logInfo before log4j is initialized + logInfo("TrackMultipleValues started...") + } + + // Initialize DfsBroadcast to be used for broadcast variable persistence + DfsBroadcast.initialize + + initialized = true + } + } + } + + def isMaster = isMaster_ + + def registerValue(uuid: UUID, guidePort: Int): Unit = { + valueToGuidePortMap.synchronized { + valueToGuidePortMap += (uuid -> guidePort) + logInfo("New value registered with the Tracker " + valueToGuidePortMap) + } + } + + def unregisterValue(uuid: UUID): Unit = { + valueToGuidePortMap.synchronized { + valueToGuidePortMap(uuid) = SourceInfo.TxOverGoToHDFS + logInfo("Value unregistered from the Tracker " + valueToGuidePortMap) + } + } + + class TrackMultipleValues + extends Thread with Logging { + override def run: Unit = { + var threadPool = Utils.newDaemonCachedThreadPool() + var serverSocket: ServerSocket = null + + serverSocket = new ServerSocket(Broadcast.MasterTrackerPort) + logInfo("TrackMultipleValues" + serverSocket) + + try { + while (true) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout(Broadcast.TrackerSocketTimeout) + clientSocket = serverSocket.accept + } catch { + case e: Exception => { + logInfo("TrackMultipleValues Timeout. 
Stopping listening...") + } + } + + if (clientSocket != null) { + try { + threadPool.execute(new Thread { + override def run: Unit = { + val oos = new ObjectOutputStream(clientSocket.getOutputStream) + oos.flush() + val ois = new ObjectInputStream(clientSocket.getInputStream) + try { + val uuid = ois.readObject.asInstanceOf[UUID] + var guidePort = + if (valueToGuidePortMap.contains(uuid)) { + valueToGuidePortMap(uuid) + } else SourceInfo.TxNotStartedRetry + logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort) + oos.writeObject(guidePort) + } catch { + case e: Exception => { + logInfo("TrackMultipleValues had a " + e) + } + } finally { + ois.close() + oos.close() + clientSocket.close() + } + } + }) + } catch { + // In failure, close() socket here; else, client thread will close() + case ioe: IOException => clientSocket.close() + } + } + } + } finally { + serverSocket.close() + } + + // Shutdown the thread pool + threadPool.shutdown + } + } +} diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index 3089360756..d14d313b48 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -5,8 +5,8 @@ import org.scalatest.prop.Checkers import org.scalacheck.Arbitrary._ import org.scalacheck.Gen import org.scalacheck.Prop._ - import SparkContext._ +import scala.collection.mutable.ArrayBuffer class ShuffleSuite extends FunSuite { test("groupByKey") { @@ -115,6 +115,38 @@ class ShuffleSuite extends FunSuite { sc.stop() } + test("leftOuterJoin") { + val sc = new SparkContext("local", "test") + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.leftOuterJoin(rdd2).collect() + assert(joined.size === 5) + assert(joined.toSet === Set( + (1, (1, Some('x'))), + (1, (2, Some('x'))), + (2, (1, Some('y'))), + (2, (1, Some('z'))), + (3, (1, None)) + )) + sc.stop() + } + + test("rightOuterJoin") { + val sc = new SparkContext("local", "test") + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.rightOuterJoin(rdd2).collect() + assert(joined.size === 5) + assert(joined.toSet === Set( + (1, (Some(1), 'x')), + (1, (Some(2), 'x')), + (2, (Some(1), 'y')), + (2, (Some(1), 'z')), + (4, (None, 'w')) + )) + sc.stop() + } + test("join with no matches") { val sc = new SparkContext("local", "test") val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) @@ -138,4 +170,20 @@ class ShuffleSuite extends FunSuite { )) sc.stop() } + + test("groupWith") { + val sc = new SparkContext("local", "test") + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.groupWith(rdd2).collect() + assert(joined.size === 4) + assert(joined.toSet === Set( + (1, (ArrayBuffer(1, 2), ArrayBuffer('x'))), + (2, (ArrayBuffer(1), ArrayBuffer('y', 'z'))), + (3, (ArrayBuffer(1), ArrayBuffer())), + (4, (ArrayBuffer(), ArrayBuffer('w'))) + )) + sc.stop() + } + } |
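The new leftOuterJoin and rightOuterJoin tests pin down the expected semantics: every pair on the outer side survives, and the side that may be missing is wrapped in Option. A plain-Scala sketch of the same left-outer-join semantics on local collections (illustrative only, not Spark API), reproducing the five pairs the test asserts:

object LocalJoinDemo {
  // Mirrors the semantics exercised by the leftOuterJoin test above,
  // but on plain Seqs instead of RDDs.
  def leftOuterJoin[K, V, W](left: Seq[(K, V)], right: Seq[(K, W)]): Seq[(K, (V, Option[W]))] = {
    val rightByKey: Map[K, Seq[W]] =
      right.groupBy(_._1).map { case (k, ws) => (k, ws.map(_._2)) }
    left.flatMap { case (k, v) =>
      val matches = rightByKey.getOrElse(k, Seq.empty)
      if (matches.isEmpty) Seq((k, (v, None: Option[W])))
      else matches.map(w => (k, (v, Some(w): Option[W])))
    }
  }

  def main(args: Array[String]): Unit = {
    val pairs1 = Seq((1, 1), (1, 2), (2, 1), (3, 1))          // same data as rdd1 in the test
    val pairs2 = Seq((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))  // same data as rdd2 in the test
    val joined = leftOuterJoin(pairs1, pairs2)
    assert(joined.size == 5)
    assert(joined.toSet == Set(
      (1, (1, Some('x'))),
      (1, (2, Some('x'))),
      (2, (1, Some('y'))),
      (2, (1, Some('z'))),
      (3, (1, None))))  // the unmatched left key keeps its value, with None on the right
    println(joined)
  }
}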