From 29617c27a155cf5c55345e9f2596c6e9628e4c90 Mon Sep 17 00:00:00 2001 From: Mosharaf Chowdhury Date: Fri, 18 Oct 2013 23:54:11 -0700 Subject: Removed BitTorrentBroadcast and TreeBroadcast. TorrentBroadcast is replacing both. --- .../spark/broadcast/BitTorrentBroadcast.scala | 1060 -------------------- .../org/apache/spark/broadcast/MultiTracker.scala | 410 -------- .../org/apache/spark/broadcast/SourceInfo.scala | 54 - .../org/apache/spark/broadcast/TreeBroadcast.scala | 601 ----------- 4 files changed, 2125 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala delete mode 100644 core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala delete mode 100644 core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala delete mode 100644 core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala (limited to 'core') diff --git a/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala deleted file mode 100644 index 5332510e87..0000000000 --- a/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala +++ /dev/null @@ -1,1060 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.broadcast - -import java.io._ -import java.net._ -import java.util.{BitSet, Comparator, Timer, TimerTask, UUID} -import java.util.concurrent.atomic.AtomicInteger - -import scala.collection.mutable.{ListBuffer, Map, Set} -import scala.math - -import org.apache.spark._ -import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} -import org.apache.spark.util.Utils - -private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) - extends Broadcast[T](id) - with Logging - with Serializable { - - def value = value_ - - def blockId = BroadcastBlockId(id) - - MultiTracker.synchronized { - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - } - - @transient var arrayOfBlocks: Array[BroadcastBlock] = null - @transient var hasBlocksBitVector: BitSet = null - @transient var numCopiesSent: Array[Int] = null - @transient var totalBytes = -1 - @transient var totalBlocks = -1 - @transient var hasBlocks = new AtomicInteger(0) - - // Used ONLY by driver to track how many unique blocks have been sent out - @transient var sentBlocks = new AtomicInteger(0) - - @transient var listenPortLock = new Object - @transient var guidePortLock = new Object - @transient var totalBlocksLock = new Object - - @transient var listOfSources = ListBuffer[SourceInfo]() - - @transient var serveMR: ServeMultipleRequests = null - - // Used only in driver - @transient var guideMR: GuideMultipleRequests = null - - // Used only in Workers - @transient var ttGuide: TalkToGuide = null - - @transient var hostAddress = Utils.localIpAddress - @transient var listenPort = -1 - @transient var guidePort = -1 - - @transient var stopBroadcast = false - - // Must call this after all the variables have been created/initialized - if (!isLocal) { - sendBroadcast() - } - - def sendBroadcast() { - logInfo("Local host address: " + hostAddress) - - // Create a variableInfo object and store it in valueInfos - var variableInfo = MultiTracker.blockifyObject(value_) - - // Prepare the value being broadcasted - arrayOfBlocks = variableInfo.arrayOfBlocks - totalBytes = variableInfo.totalBytes - totalBlocks = variableInfo.totalBlocks - hasBlocks.set(variableInfo.totalBlocks) - - // Guide has all the blocks - hasBlocksBitVector = new BitSet(totalBlocks) - hasBlocksBitVector.set(0, totalBlocks) - - // Guide still hasn't sent any block - numCopiesSent = new Array[Int](totalBlocks) - - guideMR = new GuideMultipleRequests - guideMR.setDaemon(true) - guideMR.start() - logInfo("GuideMultipleRequests started...") - - // Must always come AFTER guideMR is created - while (guidePort == -1) { - guidePortLock.synchronized { guidePortLock.wait() } - } - - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - // Must always come AFTER serveMR is created - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - // Must always come AFTER listenPort is created - val driverSource = - SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes) - hasBlocksBitVector.synchronized { - driverSource.hasBlocksBitVector = hasBlocksBitVector - } - - // In the beginning, this is the only known source to Guide - listOfSources += driverSource - - // Register with the Tracker - MultiTracker.registerBroadcast(id, - SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes)) - } - - private def readObject(in: ObjectInputStream) { - in.defaultReadObject() - MultiTracker.synchronized { - SparkEnv.get.blockManager.getSingle(blockId) match { - case Some(x) => - value_ = x.asInstanceOf[T] - - case None => - logInfo("Started reading broadcast variable " + id) - // Initializing everything because driver will only send null/0 values - // Only the 1st worker in a node can be here. Others will get from cache - initializeWorkerVariables() - - logInfo("Local host address: " + hostAddress) - - // Start local ServeMultipleRequests thread first - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - val start = System.nanoTime - - val receptionSucceeded = receiveBroadcast(id) - if (receptionSucceeded) { - value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - SparkEnv.get.blockManager.putSingle( - blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - } else { - logError("Reading broadcast variable " + id + " failed") - } - - val time = (System.nanoTime - start) / 1e9 - logInfo("Reading broadcast variable " + id + " took " + time + " s") - } - } - } - - // Initialize variables in the worker node. Driver sends everything as 0/null - private def initializeWorkerVariables() { - arrayOfBlocks = null - hasBlocksBitVector = null - numCopiesSent = null - totalBytes = -1 - totalBlocks = -1 - hasBlocks = new AtomicInteger(0) - - listenPortLock = new Object - totalBlocksLock = new Object - - serveMR = null - ttGuide = null - - hostAddress = Utils.localIpAddress - listenPort = -1 - - listOfSources = ListBuffer[SourceInfo]() - - stopBroadcast = false - } - - private def getLocalSourceInfo: SourceInfo = { - // Wait till hostName and listenPort are OK - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - // Wait till totalBlocks and totalBytes are OK - while (totalBlocks == -1) { - totalBlocksLock.synchronized { totalBlocksLock.wait() } - } - - var localSourceInfo = SourceInfo( - hostAddress, listenPort, totalBlocks, totalBytes) - - localSourceInfo.hasBlocks = hasBlocks.get - - hasBlocksBitVector.synchronized { - localSourceInfo.hasBlocksBitVector = hasBlocksBitVector - } - - return localSourceInfo - } - - // Add new SourceInfo to the listOfSources. Update if it exists already. - // Optimizing just by OR-ing the BitVectors was BAD for performance - private def addToListOfSources(newSourceInfo: SourceInfo) { - listOfSources.synchronized { - if (listOfSources.contains(newSourceInfo)) { - listOfSources = listOfSources - newSourceInfo - } - listOfSources += newSourceInfo - } - } - - private def addToListOfSources(newSourceInfos: ListBuffer[SourceInfo]) { - newSourceInfos.foreach { newSourceInfo => - addToListOfSources(newSourceInfo) - } - } - - class TalkToGuide(gInfo: SourceInfo) - extends Thread with Logging { - override def run() { - - // Keep exchaning information until all blocks have been received - while (hasBlocks.get < totalBlocks) { - talkOnce - Thread.sleep(MultiTracker.ranGen.nextInt( - MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) + - MultiTracker.MinKnockInterval) - } - - // Talk one more time to let the Guide know of reception completion - talkOnce - } - - // Connect to Guide and send this worker's information - private def talkOnce { - var clientSocketToGuide: Socket = null - var oosGuide: ObjectOutputStream = null - var oisGuide: ObjectInputStream = null - - clientSocketToGuide = new Socket(gInfo.hostAddress, gInfo.listenPort) - oosGuide = new ObjectOutputStream(clientSocketToGuide.getOutputStream) - oosGuide.flush() - oisGuide = new ObjectInputStream(clientSocketToGuide.getInputStream) - - // Send local information - oosGuide.writeObject(getLocalSourceInfo) - oosGuide.flush() - - // Receive source information from Guide - var suitableSources = - oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]] - logDebug("Received suitableSources from Driver " + suitableSources) - - addToListOfSources(suitableSources) - - oisGuide.close() - oosGuide.close() - clientSocketToGuide.close() - } - } - - def receiveBroadcast(variableID: Long): Boolean = { - val gInfo = MultiTracker.getGuideInfo(variableID) - - if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) { - return false - } - - // Wait until hostAddress and listenPort are created by the - // ServeMultipleRequests thread - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - // Setup initial states of variables - totalBlocks = gInfo.totalBlocks - arrayOfBlocks = new Array[BroadcastBlock](totalBlocks) - hasBlocksBitVector = new BitSet(totalBlocks) - numCopiesSent = new Array[Int](totalBlocks) - totalBlocksLock.synchronized { totalBlocksLock.notifyAll() } - totalBytes = gInfo.totalBytes - - // Start ttGuide to periodically talk to the Guide - var ttGuide = new TalkToGuide(gInfo) - ttGuide.setDaemon(true) - ttGuide.start() - logInfo("TalkToGuide started...") - - // Start pController to run TalkToPeer threads - var pcController = new PeerChatterController - pcController.setDaemon(true) - pcController.start() - logInfo("PeerChatterController started...") - - // FIXME: Must fix this. This might never break if broadcast fails. - // We should be able to break and send false. Also need to kill threads - while (hasBlocks.get < totalBlocks) { - Thread.sleep(MultiTracker.MaxKnockInterval) - } - - return true - } - - class PeerChatterController - extends Thread with Logging { - private var peersNowTalking = ListBuffer[SourceInfo]() - // TODO: There is a possible bug with blocksInRequestBitVector when a - // certain bit is NOT unset upon failure resulting in an infinite loop. - private var blocksInRequestBitVector = new BitSet(totalBlocks) - - override def run() { - var threadPool = Utils.newDaemonFixedThreadPool( - MultiTracker.MaxChatSlots, "Bit Torrent Chatter") - - while (hasBlocks.get < totalBlocks) { - var numThreadsToCreate = 0 - listOfSources.synchronized { - numThreadsToCreate = math.min(listOfSources.size, MultiTracker.MaxChatSlots) - - threadPool.getActiveCount - } - - while (hasBlocks.get < totalBlocks && numThreadsToCreate > 0) { - var peerToTalkTo = pickPeerToTalkToRandom - - if (peerToTalkTo != null) - logDebug("Peer chosen: " + peerToTalkTo + " with " + peerToTalkTo.hasBlocksBitVector) - else - logDebug("No peer chosen...") - - if (peerToTalkTo != null) { - threadPool.execute(new TalkToPeer(peerToTalkTo)) - - // Add to peersNowTalking. Remove in the thread. We have to do this - // ASAP, otherwise pickPeerToTalkTo picks the same peer more than once - peersNowTalking.synchronized { peersNowTalking += peerToTalkTo } - } - - numThreadsToCreate = numThreadsToCreate - 1 - } - - // Sleep for a while before starting some more threads - Thread.sleep(MultiTracker.MinKnockInterval) - } - // Shutdown the thread pool - threadPool.shutdown() - } - - // Right now picking the one that has the most blocks this peer wants - // Also picking peer randomly if no one has anything interesting - private def pickPeerToTalkToRandom: SourceInfo = { - var curPeer: SourceInfo = null - var curMax = 0 - - logDebug("Picking peers to talk to...") - - // Find peers that are not connected right now - var peersNotInUse = ListBuffer[SourceInfo]() - listOfSources.synchronized { - peersNowTalking.synchronized { - peersNotInUse = listOfSources -- peersNowTalking - } - } - - // Select the peer that has the most blocks that this receiver does not - peersNotInUse.foreach { eachSource => - var tempHasBlocksBitVector: BitSet = null - hasBlocksBitVector.synchronized { - tempHasBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet] - } - tempHasBlocksBitVector.flip(0, tempHasBlocksBitVector.size) - tempHasBlocksBitVector.and(eachSource.hasBlocksBitVector) - - if (tempHasBlocksBitVector.cardinality > curMax) { - curPeer = eachSource - curMax = tempHasBlocksBitVector.cardinality - } - } - - // Always picking randomly - if (curPeer == null && peersNotInUse.size > 0) { - // Pick uniformly the i'th required peer - var i = MultiTracker.ranGen.nextInt(peersNotInUse.size) - - var peerIter = peersNotInUse.iterator - curPeer = peerIter.next - - while (i > 0) { - curPeer = peerIter.next - i = i - 1 - } - } - - return curPeer - } - - // Picking peer with the weight of rare blocks it has - private def pickPeerToTalkToRarestFirst: SourceInfo = { - // Find peers that are not connected right now - var peersNotInUse = ListBuffer[SourceInfo]() - listOfSources.synchronized { - peersNowTalking.synchronized { - peersNotInUse = listOfSources -- peersNowTalking - } - } - - // Count the number of copies of each block in the neighborhood - var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0) - - listOfSources.synchronized { - listOfSources.foreach { eachSource => - for (i <- 0 until totalBlocks) { - numCopiesPerBlock(i) += - ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 ) - } - } - } - - // A block is considered rare if there are at most 2 copies of that block - // This CONSTANT could be a function of the neighborhood size - var rareBlocksIndices = ListBuffer[Int]() - for (i <- 0 until totalBlocks) { - if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) <= 2) { - rareBlocksIndices += i - } - } - - // Find peers with rare blocks - var peersWithRareBlocks = ListBuffer[(SourceInfo, Int)]() - var totalRareBlocks = 0 - - peersNotInUse.foreach { eachPeer => - var hasRareBlocks = 0 - rareBlocksIndices.foreach { rareBlock => - if (eachPeer.hasBlocksBitVector.get(rareBlock)) { - hasRareBlocks += 1 - } - } - - if (hasRareBlocks > 0) { - peersWithRareBlocks += ((eachPeer, hasRareBlocks)) - } - totalRareBlocks += hasRareBlocks - } - - // Select a peer from peersWithRareBlocks based on weight calculated from - // unique rare blocks - var selectedPeerToTalkTo: SourceInfo = null - - if (peersWithRareBlocks.size > 0) { - // Sort the peers based on how many rare blocks they have - peersWithRareBlocks.sortBy(_._2) - - var randomNumber = MultiTracker.ranGen.nextDouble - var tempSum = 0.0 - - var i = 0 - do { - tempSum += (1.0 * peersWithRareBlocks(i)._2 / totalRareBlocks) - if (tempSum >= randomNumber) { - selectedPeerToTalkTo = peersWithRareBlocks(i)._1 - } - i += 1 - } while (i < peersWithRareBlocks.size && selectedPeerToTalkTo == null) - } - - if (selectedPeerToTalkTo == null) { - selectedPeerToTalkTo = pickPeerToTalkToRandom - } - - return selectedPeerToTalkTo - } - - class TalkToPeer(peerToTalkTo: SourceInfo) - extends Thread with Logging { - private var peerSocketToSource: Socket = null - private var oosSource: ObjectOutputStream = null - private var oisSource: ObjectInputStream = null - - override def run() { - // TODO: There is a possible bug here regarding blocksInRequestBitVector - var blockToAskFor = -1 - - // Setup the timeout mechanism - var timeOutTask = new TimerTask { - override def run() { - cleanUpConnections() - } - } - - var timeOutTimer = new Timer - timeOutTimer.schedule(timeOutTask, MultiTracker.MaxKnockInterval) - - logInfo("TalkToPeer started... => " + peerToTalkTo) - - try { - // Connect to the source - peerSocketToSource = - new Socket(peerToTalkTo.hostAddress, peerToTalkTo.listenPort) - oosSource = - new ObjectOutputStream(peerSocketToSource.getOutputStream) - oosSource.flush() - oisSource = - new ObjectInputStream(peerSocketToSource.getInputStream) - - // Receive latest SourceInfo from peerToTalkTo - var newPeerToTalkTo = oisSource.readObject.asInstanceOf[SourceInfo] - // Update listOfSources - addToListOfSources(newPeerToTalkTo) - - // Turn the timer OFF, if the sender responds before timeout - timeOutTimer.cancel() - - // Send the latest SourceInfo - oosSource.writeObject(getLocalSourceInfo) - oosSource.flush() - - var keepReceiving = true - - while (hasBlocks.get < totalBlocks && keepReceiving) { - blockToAskFor = - pickBlockRandom(newPeerToTalkTo.hasBlocksBitVector) - - // No block to request - if (blockToAskFor < 0) { - // Nothing to receive from newPeerToTalkTo - keepReceiving = false - } else { - // Let other threads know that blockToAskFor is being requested - blocksInRequestBitVector.synchronized { - blocksInRequestBitVector.set(blockToAskFor) - } - - // Start with sending the blockID - oosSource.writeObject(blockToAskFor) - oosSource.flush() - - // CHANGED: Driver might send some other block than the one - // requested to ensure fast spreading of all blocks. - val recvStartTime = System.currentTimeMillis - val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] - val receptionTime = (System.currentTimeMillis - recvStartTime) - - logDebug("Received block: " + bcBlock.blockID + " from " + peerToTalkTo + " in " + receptionTime + " millis.") - - if (!hasBlocksBitVector.get(bcBlock.blockID)) { - arrayOfBlocks(bcBlock.blockID) = bcBlock - - // Update the hasBlocksBitVector first - hasBlocksBitVector.synchronized { - hasBlocksBitVector.set(bcBlock.blockID) - hasBlocks.getAndIncrement - } - - // Some block(may NOT be blockToAskFor) has arrived. - // In any case, blockToAskFor is not in request any more - blocksInRequestBitVector.synchronized { - blocksInRequestBitVector.set(blockToAskFor, false) - } - - // Reset blockToAskFor to -1. Else it will be considered missing - blockToAskFor = -1 - } - - // Send the latest SourceInfo - oosSource.writeObject(getLocalSourceInfo) - oosSource.flush() - } - } - } catch { - // EOFException is expected to happen because sender can break - // connection due to timeout - case eofe: java.io.EOFException => { } - case e: Exception => { - logError("TalktoPeer had a " + e) - // FIXME: Remove 'newPeerToTalkTo' from listOfSources - // We probably should have the following in some form, but not - // really here. This exception can happen if the sender just breaks connection - // listOfSources.synchronized { - // logInfo("Exception in TalkToPeer. Removing source: " + peerToTalkTo) - // listOfSources = listOfSources - peerToTalkTo - // } - } - } finally { - // blockToAskFor != -1 => there was an exception - if (blockToAskFor != -1) { - blocksInRequestBitVector.synchronized { - blocksInRequestBitVector.set(blockToAskFor, false) - } - } - - cleanUpConnections() - } - } - - // Right now it picks a block uniformly that this peer does not have - private def pickBlockRandom(txHasBlocksBitVector: BitSet): Int = { - var needBlocksBitVector: BitSet = null - - // Blocks already present - hasBlocksBitVector.synchronized { - needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet] - } - - // Include blocks already in transmission ONLY IF - // MultiTracker.EndGameFraction has NOT been achieved - if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) { - blocksInRequestBitVector.synchronized { - needBlocksBitVector.or(blocksInRequestBitVector) - } - } - - // Find blocks that are neither here nor in transit - needBlocksBitVector.flip(0, needBlocksBitVector.size) - - // Blocks that should/can be requested - needBlocksBitVector.and(txHasBlocksBitVector) - - if (needBlocksBitVector.cardinality == 0) { - return -1 - } else { - // Pick uniformly the i'th required block - var i = MultiTracker.ranGen.nextInt(needBlocksBitVector.cardinality) - var pickedBlockIndex = needBlocksBitVector.nextSetBit(0) - - while (i > 0) { - pickedBlockIndex = - needBlocksBitVector.nextSetBit(pickedBlockIndex + 1) - i -= 1 - } - - return pickedBlockIndex - } - } - - // Pick the block that seems to be the rarest across sources - private def pickBlockRarestFirst(txHasBlocksBitVector: BitSet): Int = { - var needBlocksBitVector: BitSet = null - - // Blocks already present - hasBlocksBitVector.synchronized { - needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet] - } - - // Include blocks already in transmission ONLY IF - // MultiTracker.EndGameFraction has NOT been achieved - if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) { - blocksInRequestBitVector.synchronized { - needBlocksBitVector.or(blocksInRequestBitVector) - } - } - - // Find blocks that are neither here nor in transit - needBlocksBitVector.flip(0, needBlocksBitVector.size) - - // Blocks that should/can be requested - needBlocksBitVector.and(txHasBlocksBitVector) - - if (needBlocksBitVector.cardinality == 0) { - return -1 - } else { - // Count the number of copies for each block across all sources - var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0) - - listOfSources.synchronized { - listOfSources.foreach { eachSource => - for (i <- 0 until totalBlocks) { - numCopiesPerBlock(i) += - ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 ) - } - } - } - - // Find the minimum - var minVal = Integer.MAX_VALUE - for (i <- 0 until totalBlocks) { - if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) < minVal) { - minVal = numCopiesPerBlock(i) - } - } - - // Find the blocks with the least copies that this peer does not have - var minBlocksIndices = ListBuffer[Int]() - for (i <- 0 until totalBlocks) { - if (needBlocksBitVector.get(i) && numCopiesPerBlock(i) == minVal) { - minBlocksIndices += i - } - } - - // Now select a random index from minBlocksIndices - if (minBlocksIndices.size == 0) { - return -1 - } else { - // Pick uniformly the i'th index - var i = MultiTracker.ranGen.nextInt(minBlocksIndices.size) - return minBlocksIndices(i) - } - } - } - - private def cleanUpConnections() { - if (oisSource != null) { - oisSource.close() - } - if (oosSource != null) { - oosSource.close() - } - if (peerSocketToSource != null) { - peerSocketToSource.close() - } - - // Delete from peersNowTalking - peersNowTalking.synchronized { peersNowTalking -= peerToTalkTo } - } - } - } - - class GuideMultipleRequests - extends Thread with Logging { - // Keep track of sources that have completed reception - private var setOfCompletedSources = Set[SourceInfo]() - - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool("Bit torrent guide multiple requests") - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(0) - guidePort = serverSocket.getLocalPort - logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort) - - guidePortLock.synchronized { guidePortLock.notifyAll() } - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { - // Stop broadcast if at least one worker has connected and - // everyone connected so far are done. Comparing with - // listOfSources.size - 1, because it includes the Guide itself - listOfSources.synchronized { - setOfCompletedSources.synchronized { - if (listOfSources.size > 1 && - setOfCompletedSources.size == listOfSources.size - 1) { - stopBroadcast = true - logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.") - } - } - } - } - } - if (clientSocket != null) { - logDebug("Guide: Accepted new client connection:" + clientSocket) - try { - threadPool.execute(new GuideSingleRequest(clientSocket)) - } catch { - // In failure, close the socket here; else, thread will close it - case ioe: IOException => { - clientSocket.close() - } - } - } - } - - // Shutdown the thread pool - threadPool.shutdown() - - logInfo("Sending stopBroadcast notifications...") - sendStopBroadcastNotifications - - MultiTracker.unregisterBroadcast(id) - } finally { - if (serverSocket != null) { - logInfo("GuideMultipleRequests now stopping...") - serverSocket.close() - } - } - } - - private def sendStopBroadcastNotifications() { - listOfSources.synchronized { - listOfSources.foreach { sourceInfo => - - var guideSocketToSource: Socket = null - var gosSource: ObjectOutputStream = null - var gisSource: ObjectInputStream = null - - try { - // Connect to the source - guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream) - gosSource.flush() - gisSource = new ObjectInputStream(guideSocketToSource.getInputStream) - - // Throw away whatever comes in - gisSource.readObject.asInstanceOf[SourceInfo] - - // Send stopBroadcast signal. listenPort = SourceInfo.StopBroadcast - gosSource.writeObject(SourceInfo("", SourceInfo.StopBroadcast)) - gosSource.flush() - } catch { - case e: Exception => { - logError("sendStopBroadcastNotifications had a " + e) - } - } finally { - if (gisSource != null) { - gisSource.close() - } - if (gosSource != null) { - gosSource.close() - } - if (guideSocketToSource != null) { - guideSocketToSource.close() - } - } - } - } - } - - class GuideSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - private var sourceInfo: SourceInfo = null - private var selectedSources: ListBuffer[SourceInfo] = null - - override def run() { - try { - logInfo("new GuideSingleRequest is running") - // Connecting worker is sending in its information - sourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - // Select a suitable source and send it back to the worker - selectedSources = selectSuitableSources(sourceInfo) - logDebug("Sending selectedSources:" + selectedSources) - oos.writeObject(selectedSources) - oos.flush() - - // Add this source to the listOfSources - addToListOfSources(sourceInfo) - } catch { - case e: Exception => { - // Assuming exception caused by receiver failure: remove - if (listOfSources != null) { - listOfSources.synchronized { listOfSources -= sourceInfo } - } - } - } finally { - logInfo("GuideSingleRequest is closing streams and sockets") - ois.close() - oos.close() - clientSocket.close() - } - } - - // Randomly select some sources to send back - private def selectSuitableSources(skipSourceInfo: SourceInfo): ListBuffer[SourceInfo] = { - var selectedSources = ListBuffer[SourceInfo]() - - // If skipSourceInfo.hasBlocksBitVector has all bits set to 'true' - // then add skipSourceInfo to setOfCompletedSources. Return blank. - if (skipSourceInfo.hasBlocks == totalBlocks) { - setOfCompletedSources.synchronized { setOfCompletedSources += skipSourceInfo } - return selectedSources - } - - listOfSources.synchronized { - if (listOfSources.size <= MultiTracker.MaxPeersInGuideResponse) { - selectedSources = listOfSources.clone - } else { - var picksLeft = MultiTracker.MaxPeersInGuideResponse - var alreadyPicked = new BitSet(listOfSources.size) - - while (picksLeft > 0) { - var i = -1 - - do { - i = MultiTracker.ranGen.nextInt(listOfSources.size) - } while (alreadyPicked.get(i)) - - var peerIter = listOfSources.iterator - var curPeer = peerIter.next - - // Set the BitSet before i is decremented - alreadyPicked.set(i) - - while (i > 0) { - curPeer = peerIter.next - i = i - 1 - } - - selectedSources += curPeer - - picksLeft = picksLeft - 1 - } - } - } - - // Remove the receiving source (if present) - selectedSources = selectedSources - skipSourceInfo - - return selectedSources - } - } - } - - class ServeMultipleRequests - extends Thread with Logging { - // Server at most MultiTracker.MaxChatSlots peers - var threadPool = Utils.newDaemonFixedThreadPool( - MultiTracker.MaxChatSlots, "Bit torrent serve multiple requests") - - override def run() { - var serverSocket = new ServerSocket(0) - listenPort = serverSocket.getLocalPort - - logInfo("ServeMultipleRequests started with " + serverSocket) - - listenPortLock.synchronized { listenPortLock.notifyAll() } - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { } - } - if (clientSocket != null) { - logDebug("Serve: Accepted new client connection:" + clientSocket) - try { - threadPool.execute(new ServeSingleRequest(clientSocket)) - } catch { - // In failure, close socket here; else, the thread will close it - case ioe: IOException => clientSocket.close() - } - } - } - } finally { - if (serverSocket != null) { - logInfo("ServeMultipleRequests now stopping...") - serverSocket.close() - } - } - // Shutdown the thread pool - threadPool.shutdown() - } - - class ServeSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - logInfo("new ServeSingleRequest is running") - - override def run() { - try { - // Send latest local SourceInfo to the receiver - // In the case of receiver timeout and connection close, this will - // throw a java.net.SocketException: Broken pipe - oos.writeObject(getLocalSourceInfo) - oos.flush() - - // Receive latest SourceInfo from the receiver - var rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - if (rxSourceInfo.listenPort == SourceInfo.StopBroadcast) { - stopBroadcast = true - } else { - addToListOfSources(rxSourceInfo) - } - - val startTime = System.currentTimeMillis - var curTime = startTime - var keepSending = true - var numBlocksToSend = MultiTracker.MaxChatBlocks - - while (!stopBroadcast && keepSending && numBlocksToSend > 0) { - // Receive which block to send - var blockToSend = ois.readObject.asInstanceOf[Int] - - // If it is driver AND at least one copy of each block has not been - // sent out already, MODIFY blockToSend - if (MultiTracker.isDriver && sentBlocks.get < totalBlocks) { - blockToSend = sentBlocks.getAndIncrement - } - - // Send the block - sendBlock(blockToSend) - rxSourceInfo.hasBlocksBitVector.set(blockToSend) - - numBlocksToSend -= 1 - - // Receive latest SourceInfo from the receiver - rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo] - logDebug("rxSourceInfo: " + rxSourceInfo + " with " + rxSourceInfo.hasBlocksBitVector) - addToListOfSources(rxSourceInfo) - - curTime = System.currentTimeMillis - // Revoke sending only if there is anyone waiting in the queue - if (curTime - startTime >= MultiTracker.MaxChatTime && - threadPool.getQueue.size > 0) { - keepSending = false - } - } - } catch { - case e: Exception => logError("ServeSingleRequest had a " + e) - } finally { - logInfo("ServeSingleRequest is closing streams and sockets") - ois.close() - oos.close() - clientSocket.close() - } - } - - private def sendBlock(blockToSend: Int) { - try { - oos.writeObject(arrayOfBlocks(blockToSend)) - oos.flush() - } catch { - case e: Exception => logError("sendBlock had a " + e) - } - logDebug("Sent block: " + blockToSend + " to " + clientSocket) - } - } - } -} - -private[spark] class BitTorrentBroadcastFactory -extends BroadcastFactory { - def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) } - - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new BitTorrentBroadcast[T](value_, isLocal, id) - - def stop() { MultiTracker.stop() } -} diff --git a/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala b/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala deleted file mode 100644 index 82ed64f190..0000000000 --- a/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala +++ /dev/null @@ -1,410 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.broadcast - -import java.io._ -import java.net._ -import java.util.Random - -import scala.collection.mutable.Map - -import org.apache.spark._ -import org.apache.spark.util.Utils - -private object MultiTracker -extends Logging { - - // Tracker Messages - val REGISTER_BROADCAST_TRACKER = 0 - val UNREGISTER_BROADCAST_TRACKER = 1 - val FIND_BROADCAST_TRACKER = 2 - - // Map to keep track of guides of ongoing broadcasts - var valueToGuideMap = Map[Long, SourceInfo]() - - // Random number generator - var ranGen = new Random - - private var initialized = false - private var _isDriver = false - - private var stopBroadcast = false - - private var trackMV: TrackMultipleValues = null - - def initialize(__isDriver: Boolean) { - synchronized { - if (!initialized) { - _isDriver = __isDriver - - if (isDriver) { - trackMV = new TrackMultipleValues - trackMV.setDaemon(true) - trackMV.start() - - // Set DriverHostAddress to the driver's IP address for the slaves to read - System.setProperty("spark.MultiTracker.DriverHostAddress", Utils.localIpAddress) - } - - initialized = true - } - } - } - - def stop() { - stopBroadcast = true - } - - // Load common parameters - private var DriverHostAddress_ = System.getProperty( - "spark.MultiTracker.DriverHostAddress", "") - private var DriverTrackerPort_ = System.getProperty( - "spark.broadcast.driverTrackerPort", "11111").toInt - private var BlockSize_ = System.getProperty( - "spark.broadcast.blockSize", "4096").toInt * 1024 - private var MaxRetryCount_ = System.getProperty( - "spark.broadcast.maxRetryCount", "2").toInt - - private var TrackerSocketTimeout_ = System.getProperty( - "spark.broadcast.trackerSocketTimeout", "50000").toInt - private var ServerSocketTimeout_ = System.getProperty( - "spark.broadcast.serverSocketTimeout", "10000").toInt - - private var MinKnockInterval_ = System.getProperty( - "spark.broadcast.minKnockInterval", "500").toInt - private var MaxKnockInterval_ = System.getProperty( - "spark.broadcast.maxKnockInterval", "999").toInt - - // Load TreeBroadcast config params - private var MaxDegree_ = System.getProperty( - "spark.broadcast.maxDegree", "2").toInt - - // Load BitTorrentBroadcast config params - private var MaxPeersInGuideResponse_ = System.getProperty( - "spark.broadcast.maxPeersInGuideResponse", "4").toInt - - private var MaxChatSlots_ = System.getProperty( - "spark.broadcast.maxChatSlots", "4").toInt - private var MaxChatTime_ = System.getProperty( - "spark.broadcast.maxChatTime", "500").toInt - private var MaxChatBlocks_ = System.getProperty( - "spark.broadcast.maxChatBlocks", "1024").toInt - - private var EndGameFraction_ = System.getProperty( - "spark.broadcast.endGameFraction", "0.95").toDouble - - def isDriver = _isDriver - - // Common config params - def DriverHostAddress = DriverHostAddress_ - def DriverTrackerPort = DriverTrackerPort_ - def BlockSize = BlockSize_ - def MaxRetryCount = MaxRetryCount_ - - def TrackerSocketTimeout = TrackerSocketTimeout_ - def ServerSocketTimeout = ServerSocketTimeout_ - - def MinKnockInterval = MinKnockInterval_ - def MaxKnockInterval = MaxKnockInterval_ - - // TreeBroadcast configs - def MaxDegree = MaxDegree_ - - // BitTorrentBroadcast configs - def MaxPeersInGuideResponse = MaxPeersInGuideResponse_ - - def MaxChatSlots = MaxChatSlots_ - def MaxChatTime = MaxChatTime_ - def MaxChatBlocks = MaxChatBlocks_ - - def EndGameFraction = EndGameFraction_ - - class TrackMultipleValues - extends Thread with Logging { - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool("Track multiple values") - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(DriverTrackerPort) - logInfo("TrackMultipleValues started at " + serverSocket) - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(TrackerSocketTimeout) - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { - if (stopBroadcast) { - logInfo("Stopping TrackMultipleValues...") - } - } - } - - if (clientSocket != null) { - try { - threadPool.execute(new Thread { - override def run() { - val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - val ois = new ObjectInputStream(clientSocket.getInputStream) - - try { - // First, read message type - val messageType = ois.readObject.asInstanceOf[Int] - - if (messageType == REGISTER_BROADCAST_TRACKER) { - // Receive Long - val id = ois.readObject.asInstanceOf[Long] - // Receive hostAddress and listenPort - val gInfo = ois.readObject.asInstanceOf[SourceInfo] - - // Add to the map - valueToGuideMap.synchronized { - valueToGuideMap += (id -> gInfo) - } - - logInfo ("New broadcast " + id + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap) - - // Send dummy ACK - oos.writeObject(-1) - oos.flush() - } else if (messageType == UNREGISTER_BROADCAST_TRACKER) { - // Receive Long - val id = ois.readObject.asInstanceOf[Long] - - // Remove from the map - valueToGuideMap.synchronized { - valueToGuideMap(id) = SourceInfo("", SourceInfo.TxOverGoToDefault) - } - - logInfo ("Broadcast " + id + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap) - - // Send dummy ACK - oos.writeObject(-1) - oos.flush() - } else if (messageType == FIND_BROADCAST_TRACKER) { - // Receive Long - val id = ois.readObject.asInstanceOf[Long] - - var gInfo = - if (valueToGuideMap.contains(id)) valueToGuideMap(id) - else SourceInfo("", SourceInfo.TxNotStartedRetry) - - logDebug("Got new request: " + clientSocket + " for " + id + " : " + gInfo.listenPort) - - // Send reply back - oos.writeObject(gInfo) - oos.flush() - } else { - throw new SparkException("Undefined messageType at TrackMultipleValues") - } - } catch { - case e: Exception => { - logError("TrackMultipleValues had a " + e) - } - } finally { - ois.close() - oos.close() - clientSocket.close() - } - } - }) - } catch { - // In failure, close socket here; else, client thread will close - case ioe: IOException => clientSocket.close() - } - } - } - } finally { - serverSocket.close() - } - // Shutdown the thread pool - threadPool.shutdown() - } - } - - def getGuideInfo(variableLong: Long): SourceInfo = { - var clientSocketToTracker: Socket = null - var oosTracker: ObjectOutputStream = null - var oisTracker: ObjectInputStream = null - - var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxNotStartedRetry) - - var retriesLeft = MultiTracker.MaxRetryCount - do { - try { - // Connect to the tracker to find out GuideInfo - clientSocketToTracker = - new Socket(MultiTracker.DriverHostAddress, MultiTracker.DriverTrackerPort) - oosTracker = - new ObjectOutputStream(clientSocketToTracker.getOutputStream) - oosTracker.flush() - oisTracker = - new ObjectInputStream(clientSocketToTracker.getInputStream) - - // Send messageType/intention - oosTracker.writeObject(MultiTracker.FIND_BROADCAST_TRACKER) - oosTracker.flush() - - // Send Long and receive GuideInfo - oosTracker.writeObject(variableLong) - oosTracker.flush() - gInfo = oisTracker.readObject.asInstanceOf[SourceInfo] - } catch { - case e: Exception => logError("getGuideInfo had a " + e) - } finally { - if (oisTracker != null) { - oisTracker.close() - } - if (oosTracker != null) { - oosTracker.close() - } - if (clientSocketToTracker != null) { - clientSocketToTracker.close() - } - } - - Thread.sleep(MultiTracker.ranGen.nextInt( - MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) + - MultiTracker.MinKnockInterval) - - retriesLeft -= 1 - } while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry) - - logDebug("Got this guidePort from Tracker: " + gInfo.listenPort) - return gInfo - } - - def registerBroadcast(id: Long, gInfo: SourceInfo) { - val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort) - val oosST = new ObjectOutputStream(socket.getOutputStream) - oosST.flush() - val oisST = new ObjectInputStream(socket.getInputStream) - - // Send messageType/intention - oosST.writeObject(REGISTER_BROADCAST_TRACKER) - oosST.flush() - - // Send Long of this broadcast - oosST.writeObject(id) - oosST.flush() - - // Send this tracker's information - oosST.writeObject(gInfo) - oosST.flush() - - // Receive ACK and throw it away - oisST.readObject.asInstanceOf[Int] - - // Shut stuff down - oisST.close() - oosST.close() - socket.close() - } - - def unregisterBroadcast(id: Long) { - val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort) - val oosST = new ObjectOutputStream(socket.getOutputStream) - oosST.flush() - val oisST = new ObjectInputStream(socket.getInputStream) - - // Send messageType/intention - oosST.writeObject(UNREGISTER_BROADCAST_TRACKER) - oosST.flush() - - // Send Long of this broadcast - oosST.writeObject(id) - oosST.flush() - - // Receive ACK and throw it away - oisST.readObject.asInstanceOf[Int] - - // Shut stuff down - oisST.close() - oosST.close() - socket.close() - } - - // Helper method to convert an object to Array[BroadcastBlock] - def blockifyObject[IN](obj: IN): VariableInfo = { - val baos = new ByteArrayOutputStream - val oos = new ObjectOutputStream(baos) - oos.writeObject(obj) - oos.close() - baos.close() - val byteArray = baos.toByteArray - val bais = new ByteArrayInputStream(byteArray) - - var blockNum = (byteArray.length / BlockSize) - if (byteArray.length % BlockSize != 0) - blockNum += 1 - - var retVal = new Array[BroadcastBlock](blockNum) - var blockID = 0 - - for (i <- 0 until (byteArray.length, BlockSize)) { - val thisBlockSize = math.min(BlockSize, byteArray.length - i) - var tempByteArray = new Array[Byte](thisBlockSize) - val hasRead = bais.read(tempByteArray, 0, thisBlockSize) - - retVal(blockID) = new BroadcastBlock(blockID, tempByteArray) - blockID += 1 - } - bais.close() - - var variableInfo = VariableInfo(retVal, blockNum, byteArray.length) - variableInfo.hasBlocks = blockNum - - return variableInfo - } - - // Helper method to convert Array[BroadcastBlock] to object - def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock], - totalBytes: Int, - totalBlocks: Int): OUT = { - - var retByteArray = new Array[Byte](totalBytes) - for (i <- 0 until totalBlocks) { - System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, - i * BlockSize, arrayOfBlocks(i).byteArray.length) - } - byteArrayToObject(retByteArray) - } - - private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = { - val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){ - override def resolveClass(desc: ObjectStreamClass) = - Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader) - } - val retVal = in.readObject.asInstanceOf[OUT] - in.close() - return retVal - } -} - -private[spark] case class BroadcastBlock(blockID: Int, byteArray: Array[Byte]) -extends Serializable - -private[spark] case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock], - totalBlocks: Int, - totalBytes: Int) -extends Serializable { - @transient var hasBlocks = 0 -} diff --git a/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala b/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala deleted file mode 100644 index baa1fd6da4..0000000000 --- a/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.broadcast - -import java.util.BitSet - -import org.apache.spark._ - -/** - * Used to keep and pass around information of peers involved in a broadcast - */ -private[spark] case class SourceInfo (hostAddress: String, - listenPort: Int, - totalBlocks: Int = SourceInfo.UnusedParam, - totalBytes: Int = SourceInfo.UnusedParam) -extends Comparable[SourceInfo] with Logging { - - var currentLeechers = 0 - var receptionFailed = false - - var hasBlocks = 0 - var hasBlocksBitVector: BitSet = new BitSet (totalBlocks) - - // Ascending sort based on leecher count - def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers) -} - -/** - * Helper Object of SourceInfo for its constants - */ -private[spark] object SourceInfo { - // Broadcast has not started yet! Should never happen. - val TxNotStartedRetry = -1 - // Broadcast has already finished. Try default mechanism. - val TxOverGoToDefault = -3 - // Other constants - val StopBroadcast = -2 - val UnusedParam = 0 -} diff --git a/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala deleted file mode 100644 index 51af80a35e..0000000000 --- a/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala +++ /dev/null @@ -1,601 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.broadcast - -import java.io._ -import java.net._ - -import scala.collection.mutable.{ListBuffer, Set} - -import org.apache.spark._ -import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} -import org.apache.spark.util.Utils - -private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) -extends Broadcast[T](id) with Logging with Serializable { - - def value = value_ - - def blockId = BroadcastBlockId(id) - - MultiTracker.synchronized { - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - } - - @transient var arrayOfBlocks: Array[BroadcastBlock] = null - @transient var totalBytes = -1 - @transient var totalBlocks = -1 - @transient var hasBlocks = 0 - - @transient var listenPortLock = new Object - @transient var guidePortLock = new Object - @transient var totalBlocksLock = new Object - @transient var hasBlocksLock = new Object - - @transient var listOfSources = ListBuffer[SourceInfo]() - - @transient var serveMR: ServeMultipleRequests = null - @transient var guideMR: GuideMultipleRequests = null - - @transient var hostAddress = Utils.localIpAddress - @transient var listenPort = -1 - @transient var guidePort = -1 - - @transient var stopBroadcast = false - - // Must call this after all the variables have been created/initialized - if (!isLocal) { - sendBroadcast() - } - - def sendBroadcast() { - logInfo("Local host address: " + hostAddress) - - // Create a variableInfo object and store it in valueInfos - var variableInfo = MultiTracker.blockifyObject(value_) - - // Prepare the value being broadcasted - arrayOfBlocks = variableInfo.arrayOfBlocks - totalBytes = variableInfo.totalBytes - totalBlocks = variableInfo.totalBlocks - hasBlocks = variableInfo.totalBlocks - - guideMR = new GuideMultipleRequests - guideMR.setDaemon(true) - guideMR.start() - logInfo("GuideMultipleRequests started...") - - // Must always come AFTER guideMR is created - while (guidePort == -1) { - guidePortLock.synchronized { guidePortLock.wait() } - } - - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - // Must always come AFTER serveMR is created - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - // Must always come AFTER listenPort is created - val masterSource = - SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes) - listOfSources += masterSource - - // Register with the Tracker - MultiTracker.registerBroadcast(id, - SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes)) - } - - private def readObject(in: ObjectInputStream) { - in.defaultReadObject() - MultiTracker.synchronized { - SparkEnv.get.blockManager.getSingle(blockId) match { - case Some(x) => - value_ = x.asInstanceOf[T] - - case None => - logInfo("Started reading broadcast variable " + id) - // Initializing everything because Driver will only send null/0 values - // Only the 1st worker in a node can be here. Others will get from cache - initializeWorkerVariables() - - logInfo("Local host address: " + hostAddress) - - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - val start = System.nanoTime - - val receptionSucceeded = receiveBroadcast(id) - if (receptionSucceeded) { - value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - SparkEnv.get.blockManager.putSingle( - blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - } else { - logError("Reading broadcast variable " + id + " failed") - } - - val time = (System.nanoTime - start) / 1e9 - logInfo("Reading broadcast variable " + id + " took " + time + " s") - } - } - } - - private def initializeWorkerVariables() { - arrayOfBlocks = null - totalBytes = -1 - totalBlocks = -1 - hasBlocks = 0 - - listenPortLock = new Object - totalBlocksLock = new Object - hasBlocksLock = new Object - - serveMR = null - - hostAddress = Utils.localIpAddress - listenPort = -1 - - stopBroadcast = false - } - - def receiveBroadcast(variableID: Long): Boolean = { - val gInfo = MultiTracker.getGuideInfo(variableID) - - if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) { - return false - } - - // Wait until hostAddress and listenPort are created by the - // ServeMultipleRequests thread - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - var clientSocketToDriver: Socket = null - var oosDriver: ObjectOutputStream = null - var oisDriver: ObjectInputStream = null - - // Connect and receive broadcast from the specified source, retrying the - // specified number of times in case of failures - var retriesLeft = MultiTracker.MaxRetryCount - do { - // Connect to Driver and send this worker's Information - clientSocketToDriver = new Socket(MultiTracker.DriverHostAddress, gInfo.listenPort) - oosDriver = new ObjectOutputStream(clientSocketToDriver.getOutputStream) - oosDriver.flush() - oisDriver = new ObjectInputStream(clientSocketToDriver.getInputStream) - - logDebug("Connected to Driver's guiding object") - - // Send local source information - oosDriver.writeObject(SourceInfo(hostAddress, listenPort)) - oosDriver.flush() - - // Receive source information from Driver - var sourceInfo = oisDriver.readObject.asInstanceOf[SourceInfo] - totalBlocks = sourceInfo.totalBlocks - arrayOfBlocks = new Array[BroadcastBlock](totalBlocks) - totalBlocksLock.synchronized { totalBlocksLock.notifyAll() } - totalBytes = sourceInfo.totalBytes - - logDebug("Received SourceInfo from Driver:" + sourceInfo + " My Port: " + listenPort) - - val start = System.nanoTime - val receptionSucceeded = receiveSingleTransmission(sourceInfo) - val time = (System.nanoTime - start) / 1e9 - - // Updating some statistics in sourceInfo. Driver will be using them later - if (!receptionSucceeded) { - sourceInfo.receptionFailed = true - } - - // Send back statistics to the Driver - oosDriver.writeObject(sourceInfo) - - if (oisDriver != null) { - oisDriver.close() - } - if (oosDriver != null) { - oosDriver.close() - } - if (clientSocketToDriver != null) { - clientSocketToDriver.close() - } - - retriesLeft -= 1 - } while (retriesLeft > 0 && hasBlocks < totalBlocks) - - return (hasBlocks == totalBlocks) - } - - /** - * Tries to receive broadcast from the source and returns Boolean status. - * This might be called multiple times to retry a defined number of times. - */ - private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = { - var clientSocketToSource: Socket = null - var oosSource: ObjectOutputStream = null - var oisSource: ObjectInputStream = null - - var receptionSucceeded = false - try { - // Connect to the source to get the object itself - clientSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - oosSource = new ObjectOutputStream(clientSocketToSource.getOutputStream) - oosSource.flush() - oisSource = new ObjectInputStream(clientSocketToSource.getInputStream) - - logDebug("Inside receiveSingleTransmission") - logDebug("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks) - - // Send the range - oosSource.writeObject((hasBlocks, totalBlocks)) - oosSource.flush() - - for (i <- hasBlocks until totalBlocks) { - val recvStartTime = System.currentTimeMillis - val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] - val receptionTime = (System.currentTimeMillis - recvStartTime) - - logDebug("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.") - - arrayOfBlocks(hasBlocks) = bcBlock - hasBlocks += 1 - - // Set to true if at least one block is received - receptionSucceeded = true - hasBlocksLock.synchronized { hasBlocksLock.notifyAll() } - } - } catch { - case e: Exception => logError("receiveSingleTransmission had a " + e) - } finally { - if (oisSource != null) { - oisSource.close() - } - if (oosSource != null) { - oosSource.close() - } - if (clientSocketToSource != null) { - clientSocketToSource.close() - } - } - - return receptionSucceeded - } - - class GuideMultipleRequests - extends Thread with Logging { - // Keep track of sources that have completed reception - private var setOfCompletedSources = Set[SourceInfo]() - - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool("Tree broadcast guide multiple requests") - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(0) - guidePort = serverSocket.getLocalPort - logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort) - - guidePortLock.synchronized { guidePortLock.notifyAll() } - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) - clientSocket = serverSocket.accept - } catch { - case e: Exception => { - // Stop broadcast if at least one worker has connected and - // everyone connected so far are done. Comparing with - // listOfSources.size - 1, because it includes the Guide itself - listOfSources.synchronized { - setOfCompletedSources.synchronized { - if (listOfSources.size > 1 && - setOfCompletedSources.size == listOfSources.size - 1) { - stopBroadcast = true - logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.") - } - } - } - } - } - if (clientSocket != null) { - logDebug("Guide: Accepted new client connection: " + clientSocket) - try { - threadPool.execute(new GuideSingleRequest(clientSocket)) - } catch { - // In failure, close() the socket here; else, the thread will close() it - case ioe: IOException => clientSocket.close() - } - } - } - - logInfo("Sending stopBroadcast notifications...") - sendStopBroadcastNotifications - - MultiTracker.unregisterBroadcast(id) - } finally { - if (serverSocket != null) { - logInfo("GuideMultipleRequests now stopping...") - serverSocket.close() - } - } - // Shutdown the thread pool - threadPool.shutdown() - } - - private def sendStopBroadcastNotifications() { - listOfSources.synchronized { - var listIter = listOfSources.iterator - while (listIter.hasNext) { - var sourceInfo = listIter.next - - var guideSocketToSource: Socket = null - var gosSource: ObjectOutputStream = null - var gisSource: ObjectInputStream = null - - try { - // Connect to the source - guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream) - gosSource.flush() - gisSource = new ObjectInputStream(guideSocketToSource.getInputStream) - - // Send stopBroadcast signal - gosSource.writeObject((SourceInfo.StopBroadcast, SourceInfo.StopBroadcast)) - gosSource.flush() - } catch { - case e: Exception => { - logError("sendStopBroadcastNotifications had a " + e) - } - } finally { - if (gisSource != null) { - gisSource.close() - } - if (gosSource != null) { - gosSource.close() - } - if (guideSocketToSource != null) { - guideSocketToSource.close() - } - } - } - } - } - - class GuideSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - private var selectedSourceInfo: SourceInfo = null - private var thisWorkerInfo:SourceInfo = null - - override def run() { - try { - logInfo("new GuideSingleRequest is running") - // Connecting worker is sending in its hostAddress and listenPort it will - // be listening to. Other fields are invalid (SourceInfo.UnusedParam) - var sourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - listOfSources.synchronized { - // Select a suitable source and send it back to the worker - selectedSourceInfo = selectSuitableSource(sourceInfo) - logDebug("Sending selectedSourceInfo: " + selectedSourceInfo) - oos.writeObject(selectedSourceInfo) - oos.flush() - - // Add this new (if it can finish) source to the list of sources - thisWorkerInfo = SourceInfo(sourceInfo.hostAddress, - sourceInfo.listenPort, totalBlocks, totalBytes) - logDebug("Adding possible new source to listOfSources: " + thisWorkerInfo) - listOfSources += thisWorkerInfo - } - - // Wait till the whole transfer is done. Then receive and update source - // statistics in listOfSources - sourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - listOfSources.synchronized { - // This should work since SourceInfo is a case class - assert(listOfSources.contains(selectedSourceInfo)) - - // Remove first - // (Currently removing a source based on just one failure notification!) - listOfSources = listOfSources - selectedSourceInfo - - // Update sourceInfo and put it back in, IF reception succeeded - if (!sourceInfo.receptionFailed) { - // Add thisWorkerInfo to sources that have completed reception - setOfCompletedSources.synchronized { - setOfCompletedSources += thisWorkerInfo - } - - // Update leecher count and put it back in - selectedSourceInfo.currentLeechers -= 1 - listOfSources += selectedSourceInfo - } - } - } catch { - case e: Exception => { - // Remove failed worker from listOfSources and update leecherCount of - // corresponding source worker - listOfSources.synchronized { - if (selectedSourceInfo != null) { - // Remove first - listOfSources = listOfSources - selectedSourceInfo - // Update leecher count and put it back in - selectedSourceInfo.currentLeechers -= 1 - listOfSources += selectedSourceInfo - } - - // Remove thisWorkerInfo - if (listOfSources != null) { - listOfSources = listOfSources - thisWorkerInfo - } - } - } - } finally { - logInfo("GuideSingleRequest is closing streams and sockets") - ois.close() - oos.close() - clientSocket.close() - } - } - - // Assuming the caller to have a synchronized block on listOfSources - // Select one with the most leechers. This will level-wise fill the tree - private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = { - var maxLeechers = -1 - var selectedSource: SourceInfo = null - - listOfSources.foreach { source => - if ((source.hostAddress != skipSourceInfo.hostAddress || - source.listenPort != skipSourceInfo.listenPort) && - source.currentLeechers < MultiTracker.MaxDegree && - source.currentLeechers > maxLeechers) { - selectedSource = source - maxLeechers = source.currentLeechers - } - } - - // Update leecher count - selectedSource.currentLeechers += 1 - return selectedSource - } - } - } - - class ServeMultipleRequests - extends Thread with Logging { - - var threadPool = Utils.newDaemonCachedThreadPool("Tree broadcast serve multiple requests") - - override def run() { - var serverSocket = new ServerSocket(0) - listenPort = serverSocket.getLocalPort - - logInfo("ServeMultipleRequests started with " + serverSocket) - - listenPortLock.synchronized { listenPortLock.notifyAll() } - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) - clientSocket = serverSocket.accept - } catch { - case e: Exception => { } - } - - if (clientSocket != null) { - logDebug("Serve: Accepted new client connection: " + clientSocket) - try { - threadPool.execute(new ServeSingleRequest(clientSocket)) - } catch { - // In failure, close socket here; else, the thread will close it - case ioe: IOException => clientSocket.close() - } - } - } - } finally { - if (serverSocket != null) { - logInfo("ServeMultipleRequests now stopping...") - serverSocket.close() - } - } - // Shutdown the thread pool - threadPool.shutdown() - } - - class ServeSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - private var sendFrom = 0 - private var sendUntil = totalBlocks - - override def run() { - try { - logInfo("new ServeSingleRequest is running") - - // Receive range to send - var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)] - sendFrom = rangeToSend._1 - sendUntil = rangeToSend._2 - - // If not a valid range, stop broadcast - if (sendFrom == SourceInfo.StopBroadcast && sendUntil == SourceInfo.StopBroadcast) { - stopBroadcast = true - } else { - sendObject - } - } catch { - case e: Exception => logError("ServeSingleRequest had a " + e) - } finally { - logInfo("ServeSingleRequest is closing streams and sockets") - ois.close() - oos.close() - clientSocket.close() - } - } - - private def sendObject() { - // Wait till receiving the SourceInfo from Driver - while (totalBlocks == -1) { - totalBlocksLock.synchronized { totalBlocksLock.wait() } - } - - for (i <- sendFrom until sendUntil) { - while (i == hasBlocks) { - hasBlocksLock.synchronized { hasBlocksLock.wait() } - } - try { - oos.writeObject(arrayOfBlocks(i)) - oos.flush() - } catch { - case e: Exception => logError("sendObject had a " + e) - } - logDebug("Sent block: " + i + " to " + clientSocket) - } - } - } - } -} - -private[spark] class TreeBroadcastFactory -extends BroadcastFactory { - def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) } - - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new TreeBroadcast[T](value_, isLocal, id) - - def stop() { MultiTracker.stop() } -} -- cgit v1.2.3