author    Tathagata Das <tathagata.das1565@gmail.com>  2011-06-27 13:43:44 -0700
committer Tathagata Das <tathagata.das1565@gmail.com>  2011-06-27 13:43:44 -0700
commit    3f08e1129f092cf80a136a1aa7d0134976c9e7fe (patch)
tree      2d212d2dcaaa944e9481d8bd335d9d0d6dfb20c2 /core
parent    ad842ac823e5cbc8e9e66a4fcaa057c07bc0a291 (diff)
parent    b187675b686a74c754e5d502b000ec007b5d4e48 (diff)
download  spark-3f08e1129f092cf80a136a1aa7d0134976c9e7fe.tar.gz
          spark-3f08e1129f092cf80a136a1aa7d0134976c9e7fe.tar.bz2
          spark-3f08e1129f092cf80a136a1aa7d0134976c9e7fe.zip
Merge branch 'master' into td-rdd-save
Conflicts: core/src/main/scala/spark/SparkContext.scala
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/spark/BitTorrentBroadcast.scala            | 1236
-rw-r--r--  core/src/main/scala/spark/Broadcast.scala                      |  140
-rw-r--r--  core/src/main/scala/spark/ChainedBroadcast.scala               |  873
-rw-r--r--  core/src/main/scala/spark/Executor.scala                      |    2
-rw-r--r--  core/src/main/scala/spark/LocalFileShuffle.scala               |    7
-rw-r--r--  core/src/main/scala/spark/RDD.scala                            |   46
-rw-r--r--  core/src/main/scala/spark/Shuffle.scala                        |   15
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala                   |    2
-rw-r--r--  core/src/main/scala/spark/Utils.scala                          |   44
-rw-r--r--  core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala  | 1355
-rw-r--r--  core/src/main/scala/spark/broadcast/Broadcast.scala            |  228
-rw-r--r--  core/src/main/scala/spark/broadcast/BroadcastFactory.scala     |   12
-rw-r--r--  core/src/main/scala/spark/broadcast/ChainedBroadcast.scala     |  792
-rw-r--r--  core/src/main/scala/spark/broadcast/DfsBroadcast.scala (renamed from core/src/main/scala/spark/DfsBroadcast.scala) | 6
-rw-r--r--  core/src/main/scala/spark/broadcast/SourceInfo.scala           |   41
-rw-r--r--  core/src/main/scala/spark/broadcast/TreeBroadcast.scala        |  807
-rw-r--r--  core/src/test/scala/spark/ShuffleSuite.scala                   |   50
17 files changed, 3381 insertions(+), 2275 deletions(-)
diff --git a/core/src/main/scala/spark/BitTorrentBroadcast.scala b/core/src/main/scala/spark/BitTorrentBroadcast.scala
deleted file mode 100644
index 126e61dc7d..0000000000
--- a/core/src/main/scala/spark/BitTorrentBroadcast.scala
+++ /dev/null
@@ -1,1236 +0,0 @@
-package spark
-
-import java.io._
-import java.net._
-import java.util.{BitSet, Comparator, Random, Timer, TimerTask, UUID}
-import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
-
-import scala.collection.mutable.{ListBuffer, Map, Set}
-
-@serializable
-class BitTorrentBroadcast[T] (@transient var value_ : T, isLocal: Boolean)
-extends Broadcast[T] with Logging {
-
- def value = value_
-
- BitTorrentBroadcast.synchronized {
- BitTorrentBroadcast.values.put (uuid, value_)
- }
-
- @transient var arrayOfBlocks: Array[BroadcastBlock] = null
- @transient var hasBlocksBitVector: BitSet = null
- @transient var numCopiesSent: Array[Int] = null
- @transient var totalBytes = -1
- @transient var totalBlocks = -1
- @transient var hasBlocks = 0
-
- @transient var listenPortLock = new Object
- @transient var guidePortLock = new Object
- @transient var totalBlocksLock = new Object
-
- @transient var listOfSources = ListBuffer[SourceInfo] ()
-
- @transient var serveMR: ServeMultipleRequests = null
-
- // Used only in Master
- @transient var guideMR: GuideMultipleRequests = null
-
- // Used only in Workers
- @transient var ttGuide: TalkToGuide = null
-
- @transient var rxSpeeds = new SpeedTracker
- @transient var txSpeeds = new SpeedTracker
-
- @transient var hostAddress = InetAddress.getLocalHost.getHostAddress
- @transient var listenPort = -1
- @transient var guidePort = -1
-
- @transient var hasCopyInHDFS = false
- @transient var stopBroadcast = false
-
- // Must call this after all the variables have been created/initialized
- if (!isLocal) {
- sendBroadcast
- }
-
- def sendBroadcast (): Unit = {
- logInfo ("Local host address: " + hostAddress)
-
- // Store a persistent copy in HDFS
- // TODO: Turned OFF for now
- // val out = new ObjectOutputStream (BroadcastCH.openFileForWriting(uuid))
- // out.writeObject (value_)
- // out.close
- // TODO: Fix this at some point
- hasCopyInHDFS = true
-
- // Create a variableInfo object and store it in valueInfos
- var variableInfo = blockifyObject (value_, BitTorrentBroadcast.BlockSize)
-
- // Prepare the value being broadcasted
- // TODO: Refactoring and clean-up required here
- arrayOfBlocks = variableInfo.arrayOfBlocks
- totalBytes = variableInfo.totalBytes
- totalBlocks = variableInfo.totalBlocks
- hasBlocks = variableInfo.totalBlocks
-
- // Guide has all the blocks
- hasBlocksBitVector = new BitSet (totalBlocks)
- hasBlocksBitVector.set (0, totalBlocks)
-
- // Guide still hasn't sent any block
- numCopiesSent = new Array[Int] (totalBlocks)
-
- guideMR = new GuideMultipleRequests
- guideMR.setDaemon (true)
- guideMR.start
- logInfo ("GuideMultipleRequests started...")
-
- // Must always come AFTER guideMR is created
- while (guidePort == -1) {
- guidePortLock.synchronized {
- guidePortLock.wait
- }
- }
-
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon (true)
- serveMR.start
- logInfo ("ServeMultipleRequests started...")
-
- // Must always come AFTER serveMR is created
- while (listenPort == -1) {
- listenPortLock.synchronized {
- listenPortLock.wait
- }
- }
-
- // Must always come AFTER listenPort is created
- val masterSource =
- SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes)
- hasBlocksBitVector.synchronized {
- masterSource.hasBlocksBitVector = hasBlocksBitVector
- }
-
- // In the beginning, this is the only known source to Guide
- listOfSources += masterSource
-
- // Register with the Tracker
- BitTorrentBroadcast.registerValue (uuid,
- SourceInfo (hostAddress, guidePort, totalBlocks, totalBytes))
- }
-
- private def readObject (in: ObjectInputStream): Unit = {
- in.defaultReadObject
- BitTorrentBroadcast.synchronized {
- val cachedVal = BitTorrentBroadcast.values.get (uuid)
-
- if (cachedVal != null) {
- value_ = cachedVal.asInstanceOf[T]
- } else {
- // Only the first worker in a node can ever be inside this 'else'
- initializeWorkerVariables
-
- logInfo ("Local host address: " + hostAddress)
-
- // Start local ServeMultipleRequests thread first
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon (true)
- serveMR.start
- logInfo ("ServeMultipleRequests started...")
-
- val start = System.nanoTime
-
- val receptionSucceeded = receiveBroadcast (uuid)
-    // If that does not succeed, then get it from the HDFS copy
- if (receptionSucceeded) {
- value_ = unBlockifyObject[T]
- BitTorrentBroadcast.values.put (uuid, value_)
- } else {
-        // TODO: This part won't work, because HDFS writing is turned OFF
- val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
- value_ = fileIn.readObject.asInstanceOf[T]
- BitTorrentBroadcast.values.put(uuid, value_)
- fileIn.close
- }
-
- val time = (System.nanoTime - start) / 1e9
- logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
- }
- }
- }
-
- // Initialize variables in the worker node. Master sends everything as 0/null
- private def initializeWorkerVariables: Unit = {
- arrayOfBlocks = null
- hasBlocksBitVector = null
- numCopiesSent = null
- totalBytes = -1
- totalBlocks = -1
- hasBlocks = 0
-
- listenPortLock = new Object
- totalBlocksLock = new Object
-
- serveMR = null
- ttGuide = null
-
- rxSpeeds = new SpeedTracker
- txSpeeds = new SpeedTracker
-
- hostAddress = InetAddress.getLocalHost.getHostAddress
- listenPort = -1
-
- listOfSources = ListBuffer[SourceInfo] ()
-
- stopBroadcast = false
- }
-
- private def blockifyObject (obj: T, blockSize: Int): VariableInfo = {
- val baos = new ByteArrayOutputStream
- val oos = new ObjectOutputStream (baos)
- oos.writeObject (obj)
- oos.close
- baos.close
- val byteArray = baos.toByteArray
- val bais = new ByteArrayInputStream (byteArray)
-
- var blockNum = (byteArray.length / blockSize)
- if (byteArray.length % blockSize != 0)
- blockNum += 1
-
- var retVal = new Array[BroadcastBlock] (blockNum)
- var blockID = 0
-
- for (i <- 0 until (byteArray.length, blockSize)) {
- val thisBlockSize = math.min (blockSize, byteArray.length - i)
- var tempByteArray = new Array[Byte] (thisBlockSize)
- val hasRead = bais.read (tempByteArray, 0, thisBlockSize)
-
- retVal (blockID) = new BroadcastBlock (blockID, tempByteArray)
- blockID += 1
- }
- bais.close
-
- var variableInfo = VariableInfo (retVal, blockNum, byteArray.length)
- variableInfo.hasBlocks = blockNum
-
- return variableInfo
- }
-
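-  // Reassemble the serialized byte array from the received blocks, then
-  // deserialize it back into the original object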
- private def unBlockifyObject[A]: A = {
- var retByteArray = new Array[Byte] (totalBytes)
- for (i <- 0 until totalBlocks) {
- System.arraycopy (arrayOfBlocks(i).byteArray, 0, retByteArray,
- i * BitTorrentBroadcast.BlockSize, arrayOfBlocks(i).byteArray.length)
- }
- byteArrayToObject (retByteArray)
- }
-
- private def byteArrayToObject[A] (bytes: Array[Byte]): A = {
- val in = new ObjectInputStream (new ByteArrayInputStream (bytes))
- val retVal = in.readObject.asInstanceOf[A]
- in.close
- return retVal
- }
-
- private def getLocalSourceInfo: SourceInfo = {
- // Wait till hostName and listenPort are OK
- while (listenPort == -1) {
- listenPortLock.synchronized {
- listenPortLock.wait
- }
- }
-
- // Wait till totalBlocks and totalBytes are OK
- while (totalBlocks == -1) {
- totalBlocksLock.synchronized {
- totalBlocksLock.wait
- }
- }
-
- var localSourceInfo = SourceInfo (hostAddress, listenPort, totalBlocks,
- totalBytes)
-
- localSourceInfo.hasBlocks = hasBlocks
-
- hasBlocksBitVector.synchronized {
- localSourceInfo.hasBlocksBitVector = hasBlocksBitVector
- }
-
- return localSourceInfo
- }
-
- // Add new SourceInfo to the listOfSources. Update if it exists already.
- // TODO: Optimizing just by OR-ing the BitVectors was BAD for performance
- private def addToListOfSources (newSourceInfo: SourceInfo): Unit = {
- listOfSources.synchronized {
- if (listOfSources.contains(newSourceInfo)) {
- listOfSources = listOfSources - newSourceInfo
- }
- listOfSources += newSourceInfo
- }
- }
-
- private def addToListOfSources (newSourceInfos: ListBuffer[SourceInfo]): Unit = {
- newSourceInfos.foreach { newSourceInfo =>
- addToListOfSources (newSourceInfo)
- }
- }
-
- class TalkToGuide (gInfo: SourceInfo)
- extends Thread with Logging {
- override def run: Unit = {
-
-      // Keep exchanging information until all blocks have been received
- while (hasBlocks < totalBlocks) {
- talkOnce
- Thread.sleep (BitTorrentBroadcast.ranGen.nextInt (
- BitTorrentBroadcast.MaxKnockInterval - BitTorrentBroadcast.MinKnockInterval) +
- BitTorrentBroadcast.MinKnockInterval)
- }
-
- // Talk one more time to let the Guide know of reception completion
- talkOnce
- }
-
- // Connect to Guide and send this worker's information
- private def talkOnce: Unit = {
- var clientSocketToGuide: Socket = null
- var oosGuide: ObjectOutputStream = null
- var oisGuide: ObjectInputStream = null
-
- clientSocketToGuide = new Socket(gInfo.hostAddress, gInfo.listenPort)
- oosGuide = new ObjectOutputStream (clientSocketToGuide.getOutputStream)
- oosGuide.flush
- oisGuide = new ObjectInputStream (clientSocketToGuide.getInputStream)
-
- // Send local information
- oosGuide.writeObject(getLocalSourceInfo)
- oosGuide.flush
-
- // Receive source information from Guide
- var suitableSources =
- oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]]
- logInfo("Received suitableSources from Master " + suitableSources)
-
- addToListOfSources (suitableSources)
-
- oisGuide.close
- oosGuide.close
- clientSocketToGuide.close
- }
- }
-
- def getGuideInfo (variableUUID: UUID): SourceInfo = {
- var clientSocketToTracker: Socket = null
- var oosTracker: ObjectOutputStream = null
- var oisTracker: ObjectInputStream = null
-
- var gInfo: SourceInfo = SourceInfo ("", SourceInfo.TxOverGoToHDFS,
- SourceInfo.UnusedParam, SourceInfo.UnusedParam)
-
- var retriesLeft = BitTorrentBroadcast.MaxRetryCount
- do {
- try {
- // Connect to the tracker to find out GuideInfo
- val clientSocketToTracker =
- new Socket(BitTorrentBroadcast.MasterHostAddress, BitTorrentBroadcast.MasterTrackerPort)
- val oosTracker =
- new ObjectOutputStream (clientSocketToTracker.getOutputStream)
- oosTracker.flush
- val oisTracker =
- new ObjectInputStream (clientSocketToTracker.getInputStream)
-
- // Send UUID and receive GuideInfo
- oosTracker.writeObject (uuid)
- oosTracker.flush
- gInfo = oisTracker.readObject.asInstanceOf[SourceInfo]
- } catch {
- case e: Exception => {
- logInfo ("getGuideInfo had a " + e)
- }
- } finally {
- if (oisTracker != null) {
- oisTracker.close
- }
- if (oosTracker != null) {
- oosTracker.close
- }
- if (clientSocketToTracker != null) {
- clientSocketToTracker.close
- }
- }
-
- Thread.sleep (BitTorrentBroadcast.ranGen.nextInt (
- BitTorrentBroadcast.MaxKnockInterval - BitTorrentBroadcast.MinKnockInterval) +
- BitTorrentBroadcast.MinKnockInterval)
-
- retriesLeft -= 1
- } while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry)
-
- logInfo ("Got this guidePort from Tracker: " + gInfo.listenPort)
- return gInfo
- }
-
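-  // Fetch the Guide's info from the Tracker, then pull blocks from peers until
-  // all of them have arrived. Returns false if the caller should fall back to HDFS.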
- def receiveBroadcast (variableUUID: UUID): Boolean = {
- val gInfo = getGuideInfo (variableUUID)
-
- if (gInfo.listenPort == SourceInfo.TxOverGoToHDFS ||
- gInfo.listenPort == SourceInfo.TxNotStartedRetry) {
- // TODO: SourceInfo.TxNotStartedRetry is not really in use because we go
- // to HDFS anyway when receiveBroadcast returns false
- return false
- }
-
- // Wait until hostAddress and listenPort are created by the
- // ServeMultipleRequests thread
- while (listenPort == -1) {
- listenPortLock.synchronized {
- listenPortLock.wait
- }
- }
-
- // Setup initial states of variables
- totalBlocks = gInfo.totalBlocks
- arrayOfBlocks = new Array[BroadcastBlock] (totalBlocks)
- hasBlocksBitVector = new BitSet (totalBlocks)
- numCopiesSent = new Array[Int] (totalBlocks)
- totalBlocksLock.synchronized {
- totalBlocksLock.notifyAll
- }
- totalBytes = gInfo.totalBytes
-
- // Start ttGuide to periodically talk to the Guide
- var ttGuide = new TalkToGuide (gInfo)
- ttGuide.setDaemon (true)
- ttGuide.start
- logInfo ("TalkToGuide started...")
-
- // Start pController to run TalkToPeer threads
- var pcController = new PeerChatterController
- pcController.setDaemon (true)
- pcController.start
- logInfo ("PeerChatterController started...")
-
- // TODO: Must fix this. This might never break if broadcast fails.
- // We should be able to break and send false. Also need to kill threads
- while (hasBlocks < totalBlocks) {
- Thread.sleep(BitTorrentBroadcast.MaxKnockInterval)
- }
-
- return true
- }
-
- class PeerChatterController
- extends Thread with Logging {
- private var peersNowTalking = ListBuffer[SourceInfo] ()
- // TODO: There is a possible bug with blocksInRequestBitVector when a
- // certain bit is NOT unset upon failure resulting in an infinite loop.
- private var blocksInRequestBitVector = new BitSet (totalBlocks)
-
- override def run: Unit = {
- var threadPool =
- Broadcast.newDaemonFixedThreadPool (BitTorrentBroadcast.MaxTxPeers)
-
- while (hasBlocks < totalBlocks) {
- var numThreadsToCreate =
- math.min (listOfSources.size, BitTorrentBroadcast.MaxTxPeers) -
- threadPool.getActiveCount
-
- while (hasBlocks < totalBlocks && numThreadsToCreate > 0) {
- var peerToTalkTo = pickPeerToTalkTo
- if (peerToTalkTo != null) {
- threadPool.execute (new TalkToPeer (peerToTalkTo))
-
- // Add to peersNowTalking. Remove in the thread. We have to do this
- // ASAP, otherwise pickPeerToTalkTo picks the same peer more than once
- peersNowTalking.synchronized {
- peersNowTalking += peerToTalkTo
- }
- }
-
- numThreadsToCreate = numThreadsToCreate - 1
- }
-
- // Sleep for a while before starting some more threads
- Thread.sleep (BitTorrentBroadcast.MinKnockInterval)
- }
- // Shutdown the thread pool
- threadPool.shutdown
- }
-
-    // Right now picks the peer that has the most blocks this peer wants
-    // Falls back to a random peer if no one has anything interesting
- private def pickPeerToTalkTo: SourceInfo = {
- var curPeer: SourceInfo = null
- var curMax = 0
-
- logInfo ("Picking peers to talk to...")
-
- // Find peers that are not connected right now
- var peersNotInUse = ListBuffer[SourceInfo] ()
- synchronized {
- peersNotInUse = listOfSources -- peersNowTalking
- }
-
- peersNotInUse.foreach { eachSource =>
- var tempHasBlocksBitVector: BitSet = null
- hasBlocksBitVector.synchronized {
- tempHasBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet]
- }
- tempHasBlocksBitVector.flip (0, tempHasBlocksBitVector.size)
- tempHasBlocksBitVector.and (eachSource.hasBlocksBitVector)
-
- if (tempHasBlocksBitVector.cardinality > curMax) {
- curPeer = eachSource
- curMax = tempHasBlocksBitVector.cardinality
- }
- }
-
-      // If no peer has an interesting block, fall back to picking one
-      // uniformly at random
- if (curPeer == null && peersNotInUse.size > 0) {
- // Pick uniformly the i'th required peer
- var i = BitTorrentBroadcast.ranGen.nextInt (peersNotInUse.size)
-
- var peerIter = peersNotInUse.iterator
- curPeer = peerIter.next
-
- while (i > 0) {
- curPeer = peerIter.next
- i = i - 1
- }
- }
-
- if (curPeer != null)
- logInfo ("Peer chosen: " + curPeer + " with " + curPeer.hasBlocksBitVector)
- else
- logInfo ("No peer chosen...")
-
- return curPeer
- }
-
- class TalkToPeer (peerToTalkTo: SourceInfo)
- extends Thread with Logging {
- private var peerSocketToSource: Socket = null
- private var oosSource: ObjectOutputStream = null
- private var oisSource: ObjectInputStream = null
-
- override def run: Unit = {
- // TODO: There is a possible bug here regarding blocksInRequestBitVector
- var blockToAskFor = -1
-
- // Setup the timeout mechanism
- var timeOutTask = new TimerTask {
- override def run: Unit = {
- cleanUpConnections
- }
- }
-
- var timeOutTimer = new Timer
- timeOutTimer.schedule (timeOutTask, BitTorrentBroadcast.MaxKnockInterval)
-
- logInfo ("TalkToPeer started... => " + peerToTalkTo)
-
- try {
- // Connect to the source
- peerSocketToSource =
- new Socket (peerToTalkTo.hostAddress, peerToTalkTo.listenPort)
- oosSource =
- new ObjectOutputStream (peerSocketToSource.getOutputStream)
- oosSource.flush
- oisSource =
- new ObjectInputStream (peerSocketToSource.getInputStream)
-
- // Receive latest SourceInfo from peerToTalkTo
- var newPeerToTalkTo = oisSource.readObject.asInstanceOf[SourceInfo]
- // Update listOfSources
- addToListOfSources (newPeerToTalkTo)
-
- // Turn the timer OFF, if the sender responds before timeout
- timeOutTimer.cancel
-
- // Send the latest SourceInfo
- oosSource.writeObject(getLocalSourceInfo)
- oosSource.flush
-
- var keepReceiving = true
-
- while (hasBlocks < totalBlocks && keepReceiving) {
- blockToAskFor =
- pickBlockToRequest (newPeerToTalkTo.hasBlocksBitVector)
-
- // No block to request
- if (blockToAskFor < 0) {
- // Nothing to receive from newPeerToTalkTo
- keepReceiving = false
- } else {
- // Let other thread know that blockToAskFor is being requested
- blocksInRequestBitVector.synchronized {
- blocksInRequestBitVector.set (blockToAskFor)
- }
-
- // Start with sending the blockID
- oosSource.writeObject(blockToAskFor)
- oosSource.flush
-
- // Receive the requested block
- val recvStartTime = System.currentTimeMillis
- val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
- val receptionTime = (System.currentTimeMillis - recvStartTime)
-
- // Expecting sender to send the block that was asked for
- assert (bcBlock.blockID == blockToAskFor)
-
- logInfo ("Received block: " + bcBlock.blockID + " from " + peerToTalkTo + " in " + receptionTime + " millis.")
-
- if (!hasBlocksBitVector.get(bcBlock.blockID)) {
- arrayOfBlocks(bcBlock.blockID) = bcBlock
-
- // Update the hasBlocksBitVector first
- hasBlocksBitVector.synchronized {
- hasBlocksBitVector.set (bcBlock.blockID)
- }
- hasBlocks += 1
-
- rxSpeeds.addDataPoint (peerToTalkTo, receptionTime)
-
- // blockToAskFor has arrived. Not in request any more
- // Probably no need to update it though
- blocksInRequestBitVector.synchronized {
- blocksInRequestBitVector.set (bcBlock.blockID, false)
- }
-
- // Reset blockToAskFor to -1. Else it will be considered missing
- blockToAskFor = -1
- }
-
- // Send the latest SourceInfo
- oosSource.writeObject(getLocalSourceInfo)
- oosSource.flush
- }
- }
- } catch {
- // EOFException is expected to happen because sender can break
- // connection due to timeout
- case eofe: java.io.EOFException => { }
- case e: Exception => {
- logInfo ("TalktoPeer had a " + e)
- // TODO: Remove 'newPeerToTalkTo' from listOfSources
- // We probably should have the following in some form, but not
- // really here. This exception can happen if the sender just breaks connection
- // listOfSources.synchronized {
- // logInfo ("Exception in TalkToPeer. Removing source: " + peerToTalkTo)
- // listOfSources = listOfSources - peerToTalkTo
- // }
- }
- } finally {
- // blockToAskFor != -1 => there was an exception
- if (blockToAskFor != -1) {
- blocksInRequestBitVector.synchronized {
- blocksInRequestBitVector.set (blockToAskFor, false)
- }
- }
-
- cleanUpConnections
- }
- }
-
-      // Right now it picks, uniformly at random, a block that this peer does not have
- // TODO: Implement more intelligent block selection policies
- private def pickBlockToRequest (txHasBlocksBitVector: BitSet): Int = {
- var needBlocksBitVector: BitSet = null
-
- // Blocks already present
- hasBlocksBitVector.synchronized {
- needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet]
- }
-
- // Include blocks already in transmission ONLY IF
- // BitTorrentBroadcast.EndGameFraction has NOT been achieved
- if ((1.0 * hasBlocks / totalBlocks) < BitTorrentBroadcast.EndGameFraction) {
- blocksInRequestBitVector.synchronized {
- needBlocksBitVector.or (blocksInRequestBitVector)
- }
- }
-
- // Find blocks that are neither here nor in transit
- needBlocksBitVector.flip (0, needBlocksBitVector.size)
-
- // Blocks that should be requested
- needBlocksBitVector.and (txHasBlocksBitVector)
-
- if (needBlocksBitVector.cardinality == 0) {
- return -1
- } else {
- // Pick uniformly the i'th required block
- var i = BitTorrentBroadcast.ranGen.nextInt (needBlocksBitVector.cardinality)
- var pickedBlockIndex = needBlocksBitVector.nextSetBit (0)
-
- while (i > 0) {
- pickedBlockIndex =
- needBlocksBitVector.nextSetBit (pickedBlockIndex + 1)
- i = i - 1
- }
-
- return pickedBlockIndex
- }
- }
-
- private def cleanUpConnections: Unit = {
- if (oisSource != null) {
- oisSource.close
- }
- if (oosSource != null) {
- oosSource.close
- }
- if (peerSocketToSource != null) {
- peerSocketToSource.close
- }
-
- // Delete from peersNowTalking
- peersNowTalking.synchronized {
- peersNowTalking = peersNowTalking - peerToTalkTo
- }
- }
- }
- }
-
- class GuideMultipleRequests
- extends Thread with Logging {
- // Keep track of sources that have completed reception
- private var setOfCompletedSources = Set[SourceInfo] ()
-
- override def run: Unit = {
- var threadPool = Broadcast.newDaemonCachedThreadPool
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket (0)
- guidePort = serverSocket.getLocalPort
- logInfo ("GuideMultipleRequests => " + serverSocket + " " + guidePort)
-
- guidePortLock.synchronized {
- guidePortLock.notifyAll
- }
-
- try {
- // Don't stop until there is a copy in HDFS
- while (!stopBroadcast || !hasCopyInHDFS) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout (BitTorrentBroadcast.ServerSocketTimeout)
- clientSocket = serverSocket.accept
- } catch {
- case e: Exception => {
- logInfo ("GuideMultipleRequests Timeout.")
-
- // Stop broadcast if at least one worker has connected and
-            // everyone connected so far is done. Comparing with
- // listOfSources.size - 1, because it includes the Guide itself
- if (listOfSources.size > 1 &&
- setOfCompletedSources.size == listOfSources.size - 1) {
- stopBroadcast = true
- }
- }
- }
- if (clientSocket != null) {
- logInfo ("Guide: Accepted new client connection:" + clientSocket)
- try {
- threadPool.execute (new GuideSingleRequest (clientSocket))
- } catch {
- // In failure, close the socket here; else, thread will close it
- case ioe: IOException => {
- clientSocket.close
- }
- }
- }
- }
-
- // Shutdown the thread pool
- threadPool.shutdown
-
- logInfo ("Sending stopBroadcast notifications...")
- sendStopBroadcastNotifications
-
- BitTorrentBroadcast.unregisterValue (uuid)
- } finally {
- if (serverSocket != null) {
- logInfo ("GuideMultipleRequests now stopping...")
- serverSocket.close
- }
- }
- }
-
- private def sendStopBroadcastNotifications: Unit = {
- listOfSources.synchronized {
- listOfSources.foreach { sourceInfo =>
-
- var guideSocketToSource: Socket = null
- var gosSource: ObjectOutputStream = null
- var gisSource: ObjectInputStream = null
-
- try {
- // Connect to the source
- guideSocketToSource =
- new Socket (sourceInfo.hostAddress, sourceInfo.listenPort)
- gosSource =
- new ObjectOutputStream (guideSocketToSource.getOutputStream)
- gosSource.flush
- gisSource =
- new ObjectInputStream (guideSocketToSource.getInputStream)
-
- // Throw away whatever comes in
- gisSource.readObject.asInstanceOf[SourceInfo]
-
- // Send stopBroadcast signal. listenPort = SourceInfo.StopBroadcast
- gosSource.writeObject(SourceInfo("", SourceInfo.StopBroadcast,
- SourceInfo.UnusedParam, SourceInfo.UnusedParam))
- gosSource.flush
- } catch {
- case e: Exception => {
- logInfo ("sendStopBroadcastNotifications had a " + e)
- }
- } finally {
- if (gisSource != null) {
- gisSource.close
- }
- if (gosSource != null) {
- gosSource.close
- }
- if (guideSocketToSource != null) {
- guideSocketToSource.close
- }
- }
- }
- }
- }
-
- class GuideSingleRequest (val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream (clientSocket.getOutputStream)
- oos.flush
- private val ois = new ObjectInputStream (clientSocket.getInputStream)
-
- private var sourceInfo: SourceInfo = null
- private var selectedSources: ListBuffer[SourceInfo] = null
-
- override def run: Unit = {
- try {
- logInfo ("new GuideSingleRequest is running")
-          // The connecting worker sends in its information
- sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- // Select a suitable source and send it back to the worker
- selectedSources = selectSuitableSources (sourceInfo)
- logInfo ("Sending selectedSources:" + selectedSources)
- oos.writeObject (selectedSources)
- oos.flush
-
- // Add this source to the listOfSources
- addToListOfSources (sourceInfo)
- } catch {
- case e: Exception => {
-          // Assuming the exception was caused by receiver failure: remove the source
- if (listOfSources != null) {
- listOfSources.synchronized {
- listOfSources = listOfSources - sourceInfo
- }
- }
- }
- } finally {
- ois.close
- oos.close
- clientSocket.close
- }
- }
-
- // Randomly select some sources to send back
- private def selectSuitableSources(skipSourceInfo: SourceInfo): ListBuffer[SourceInfo] = {
- var selectedSources = ListBuffer[SourceInfo] ()
-
-        // If skipSourceInfo.hasBlocksBitVector has all bits set to 'true',
-        // add skipSourceInfo to setOfCompletedSources and return an empty list.
- if (skipSourceInfo.hasBlocks == totalBlocks) {
- setOfCompletedSources += skipSourceInfo
- return selectedSources
- }
-
- listOfSources.synchronized {
- if (listOfSources.size <= BitTorrentBroadcast.MaxPeersInGuideResponse) {
- selectedSources = listOfSources.clone
- } else {
- var picksLeft = BitTorrentBroadcast.MaxPeersInGuideResponse
- var alreadyPicked = new BitSet (listOfSources.size)
-
- while (picksLeft > 0) {
- var i = -1
-
-            do {
-              i = BitTorrentBroadcast.ranGen.nextInt (listOfSources.size)
-            } while (alreadyPicked.get(i))
-
-            // Mark this index as picked now; the walk below decrements i to 0
-            alreadyPicked.set (i)
-
-            var peerIter = listOfSources.iterator
-            var curPeer = peerIter.next
-
-            while (i > 0) {
-              curPeer = peerIter.next
-              i = i - 1
-            }
-
-            selectedSources += curPeer
-
- picksLeft = picksLeft - 1
- }
- }
- }
-
- // Remove the receiving source (if present)
- selectedSources = selectedSources - skipSourceInfo
-
- return selectedSources
- }
- }
- }
-
- class ServeMultipleRequests
- extends Thread with Logging {
-    // Serve at most BitTorrentBroadcast.MaxRxPeers peers
- var threadPool =
- Broadcast.newDaemonFixedThreadPool(BitTorrentBroadcast.MaxRxPeers)
-
- override def run: Unit = {
- var serverSocket = new ServerSocket (0)
- listenPort = serverSocket.getLocalPort
-
- logInfo ("ServeMultipleRequests started with " + serverSocket)
-
- listenPortLock.synchronized {
- listenPortLock.notifyAll
- }
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout (BitTorrentBroadcast.ServerSocketTimeout)
- clientSocket = serverSocket.accept
- } catch {
- case e: Exception => {
- logInfo ("ServeMultipleRequests Timeout.")
- }
- }
- if (clientSocket != null) {
- logInfo ("Serve: Accepted new client connection:" + clientSocket)
- try {
- threadPool.execute (new ServeSingleRequest (clientSocket))
- } catch {
- // In failure, close socket here; else, the thread will close it
- case ioe: IOException => {
- clientSocket.close
- }
- }
- }
- }
- } finally {
- if (serverSocket != null) {
- logInfo ("ServeMultipleRequests now stopping...")
- serverSocket.close
- }
- }
- // Shutdown the thread pool
- threadPool.shutdown
- }
-
- class ServeSingleRequest (val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream (clientSocket.getOutputStream)
- oos.flush
- private val ois = new ObjectInputStream (clientSocket.getInputStream)
-
- logInfo ("new ServeSingleRequest is running")
-
- override def run: Unit = {
- try {
- // Send latest local SourceInfo to the receiver
- // In the case of receiver timeout and connection close, this will
- // throw a java.net.SocketException: Broken pipe
- oos.writeObject(getLocalSourceInfo)
- oos.flush
-
- // Receive latest SourceInfo from the receiver
- var rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo]
- // logInfo("rxSourceInfo: " + rxSourceInfo + " with " + rxSourceInfo.hasBlocksBitVector)
-
- if (rxSourceInfo.listenPort == SourceInfo.StopBroadcast) {
- stopBroadcast = true
- } else {
- // Carry on
- addToListOfSources (rxSourceInfo)
- }
-
- val startTime = System.currentTimeMillis
- var curTime = startTime
- var keepSending = true
- var numBlocksToSend = BitTorrentBroadcast.MaxChatBlocks
-
- while (!stopBroadcast && keepSending && numBlocksToSend > 0) {
- // Receive which block to send
- val blockToSend = ois.readObject.asInstanceOf[Int]
-
- // Send the block
- sendBlock (blockToSend)
- rxSourceInfo.hasBlocksBitVector.set (blockToSend)
-
- numBlocksToSend = numBlocksToSend - 1
-
- // Receive latest SourceInfo from the receiver
- rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo]
- // logInfo("rxSourceInfo: " + rxSourceInfo + " with " + rxSourceInfo.hasBlocksBitVector)
- addToListOfSources (rxSourceInfo)
-
- curTime = System.currentTimeMillis
-          // Stop sending only if there is someone waiting in the queue
- if (curTime - startTime >= BitTorrentBroadcast.MaxChatTime &&
- threadPool.getQueue.size > 0) {
- keepSending = false
- }
- }
- } catch {
- // If something went wrong, e.g., the worker at the other end died etc.
- // then close everything up
- // Exception can happen if the receiver stops receiving
- case e: Exception => {
- logInfo ("ServeSingleRequest had a " + e)
- }
- } finally {
- logInfo ("ServeSingleRequest is closing streams and sockets")
- ois.close
- // TODO: The following line causes a "java.net.SocketException: Socket closed"
- oos.close
- clientSocket.close
- }
- }
-
- private def sendBlock (blockToSend: Int): Unit = {
- try {
- oos.writeObject (arrayOfBlocks(blockToSend))
- oos.flush
- } catch {
- case e: Exception => {
- logInfo ("sendBlock had a " + e)
- }
- }
- logInfo ("Sent block: " + blockToSend + " to " + clientSocket)
- }
- }
- }
-}
-
-class BitTorrentBroadcastFactory
-extends BroadcastFactory {
- def initialize (isMaster: Boolean) = BitTorrentBroadcast.initialize (isMaster)
- def newBroadcast[T] (value_ : T, isLocal: Boolean) =
- new BitTorrentBroadcast[T] (value_, isLocal)
-}
-
-private object BitTorrentBroadcast
-extends Logging {
- val values = SparkEnv.get.cache.newKeySpace()
-
- var valueToGuideMap = Map[UUID, SourceInfo] ()
-
- // Random number generator
- var ranGen = new Random
-
- private var initialized = false
- private var isMaster_ = false
-
- private var MasterHostAddress_ = InetAddress.getLocalHost.getHostAddress
- private var MasterTrackerPort_ : Int = 11111
- private var BlockSize_ : Int = 512 * 1024
- private var MaxRetryCount_ : Int = 2
-
- private var TrackerSocketTimeout_ : Int = 50000
- private var ServerSocketTimeout_ : Int = 10000
-
- private var trackMV: TrackMultipleValues = null
-
-  // A peer syncs back to the Guide after waiting randomly within the following limits.
-  // Also used throughout the code for small and large waits/timeouts
- private var MinKnockInterval_ = 500
- private var MaxKnockInterval_ = 999
-
- private var MaxPeersInGuideResponse_ = 4
-
- // Maximum number of receiving and sending threads of a peer
- private var MaxRxPeers_ = 4
- private var MaxTxPeers_ = 4
-
-  // Peers can chat for at most this many milliseconds or transfer this many blocks
- private var MaxChatTime_ = 250
- private var MaxChatBlocks_ = 1024
-
- // Fraction of blocks to receive before entering the end game
- private var EndGameFraction_ = 1.0
-
-
- def initialize (isMaster__ : Boolean): Unit = {
- synchronized {
- if (!initialized) {
- // Fix for issue #42
- MasterHostAddress_ =
- System.getProperty ("spark.broadcast.masterHostAddress", "")
- MasterTrackerPort_ =
- System.getProperty ("spark.broadcast.masterTrackerPort", "11111").toInt
- BlockSize_ =
- System.getProperty ("spark.broadcast.blockSize", "512").toInt * 1024
- MaxRetryCount_ =
- System.getProperty ("spark.broadcast.maxRetryCount", "2").toInt
-
- TrackerSocketTimeout_ =
- System.getProperty ("spark.broadcast.trackerSocketTimeout", "50000").toInt
- ServerSocketTimeout_ =
- System.getProperty ("spark.broadcast.serverSocketTimeout", "10000").toInt
-
- MinKnockInterval_ =
- System.getProperty ("spark.broadcast.minKnockInterval", "500").toInt
- MaxKnockInterval_ =
- System.getProperty ("spark.broadcast.maxKnockInterval", "999").toInt
-
- MaxPeersInGuideResponse_ =
- System.getProperty ("spark.broadcast.maxPeersInGuideResponse", "4").toInt
-
- MaxRxPeers_ =
- System.getProperty ("spark.broadcast.maxRxPeers", "4").toInt
- MaxTxPeers_ =
- System.getProperty ("spark.broadcast.maxTxPeers", "4").toInt
-
- MaxChatTime_ =
- System.getProperty ("spark.broadcast.maxChatTime", "250").toInt
- MaxChatBlocks_ =
- System.getProperty ("spark.broadcast.maxChatBlocks", "1024").toInt
-
- EndGameFraction_ =
- System.getProperty ("spark.broadcast.endGameFraction", "1.0").toDouble
-
- isMaster_ = isMaster__
-
- if (isMaster) {
- trackMV = new TrackMultipleValues
- trackMV.setDaemon (true)
- trackMV.start
- logInfo ("TrackMultipleValues started...")
- }
-
- // Initialize DfsBroadcast to be used for broadcast variable persistence
- DfsBroadcast.initialize
-
- initialized = true
- }
- }
- }
-
- def MasterHostAddress = MasterHostAddress_
- def MasterTrackerPort = MasterTrackerPort_
- def BlockSize = BlockSize_
- def MaxRetryCount = MaxRetryCount_
-
- def TrackerSocketTimeout = TrackerSocketTimeout_
- def ServerSocketTimeout = ServerSocketTimeout_
-
- def isMaster = isMaster_
-
- def MinKnockInterval = MinKnockInterval_
- def MaxKnockInterval = MaxKnockInterval_
-
- def MaxPeersInGuideResponse = MaxPeersInGuideResponse_
-
- def MaxRxPeers = MaxRxPeers_
- def MaxTxPeers = MaxTxPeers_
-
- def MaxChatTime = MaxChatTime_
- def MaxChatBlocks = MaxChatBlocks_
-
- def EndGameFraction = EndGameFraction_
-
- def registerValue (uuid: UUID, gInfo: SourceInfo): Unit = {
- valueToGuideMap.synchronized {
- valueToGuideMap += (uuid -> gInfo)
- logInfo ("New value registered with the Tracker " + valueToGuideMap)
- }
- }
-
- def unregisterValue (uuid: UUID): Unit = {
- valueToGuideMap.synchronized {
- valueToGuideMap (uuid) = SourceInfo ("", SourceInfo.TxOverGoToHDFS,
- SourceInfo.UnusedParam, SourceInfo.UnusedParam)
- logInfo ("Value unregistered from the Tracker " + valueToGuideMap)
- }
- }
-
- class TrackMultipleValues
- extends Thread with Logging {
- override def run: Unit = {
- var threadPool = Broadcast.newDaemonCachedThreadPool
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket (BitTorrentBroadcast.MasterTrackerPort)
-      logInfo ("TrackMultipleValues => " + serverSocket)
-
- try {
- while (true) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout (TrackerSocketTimeout)
- clientSocket = serverSocket.accept
- } catch {
- case e: Exception => {
- logInfo ("TrackMultipleValues Timeout. Stopping listening...")
- }
- }
-
- if (clientSocket != null) {
- try {
- threadPool.execute (new Thread {
- override def run: Unit = {
- val oos = new ObjectOutputStream (clientSocket.getOutputStream)
- oos.flush
- val ois = new ObjectInputStream (clientSocket.getInputStream)
- try {
- val uuid = ois.readObject.asInstanceOf[UUID]
- var gInfo =
- if (valueToGuideMap.contains (uuid)) {
- valueToGuideMap (uuid)
- } else SourceInfo ("", SourceInfo.TxNotStartedRetry,
- SourceInfo.UnusedParam, SourceInfo.UnusedParam)
- logInfo ("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + gInfo.listenPort)
- oos.writeObject (gInfo)
- } catch {
- case e: Exception => {
- logInfo ("TrackMultipleValues had a " + e)
- }
- } finally {
- ois.close
- oos.close
- clientSocket.close
- }
- }
- })
- } catch {
-            // In failure, close the socket here; else, the client thread will close it
- case ioe: IOException => {
- clientSocket.close
- }
- }
- }
- }
- } finally {
- serverSocket.close
- }
- // Shutdown the thread pool
- threadPool.shutdown
- }
- }
-}
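
Aside: a minimal, self-contained Scala sketch of the block-selection idea
implemented by pickBlockToRequest above, kept separate from the diff itself.
Flip a copy of the local "has blocks" bit vector, AND it with the sender's bit
vector, then pick one of the surviving set bits uniformly at random. All names
below are illustrative stand-ins, not identifiers from this commit.

import java.util.{BitSet, Random}

object PickBlockSketch {
  val ranGen = new Random

  // Returns a block index the sender has and we lack, or -1 if there is none
  def pickBlock(myBlocks: BitSet, senderBlocks: BitSet, totalBlocks: Int): Int = {
    val needed = myBlocks.clone.asInstanceOf[BitSet]
    needed.flip(0, totalBlocks)   // blocks we do NOT have...
    needed.and(senderBlocks)      // ...that the sender DOES have
    if (needed.cardinality == 0) return -1

    // Walk to the i'th set bit, with i drawn uniformly
    var i = ranGen.nextInt(needed.cardinality)
    var picked = needed.nextSetBit(0)
    while (i > 0) {
      picked = needed.nextSetBit(picked + 1)
      i -= 1
    }
    picked
  }

  def main(args: Array[String]): Unit = {
    val mine = new BitSet(8); mine.set(0); mine.set(3)
    val theirs = new BitSet(8); theirs.set(0, 8)
    println(pickBlock(mine, theirs, 8))   // prints one of 1, 2, 4, 5, 6, 7
  }
}
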
diff --git a/core/src/main/scala/spark/Broadcast.scala b/core/src/main/scala/spark/Broadcast.scala
deleted file mode 100644
index fe2ab1ebf0..0000000000
--- a/core/src/main/scala/spark/Broadcast.scala
+++ /dev/null
@@ -1,140 +0,0 @@
-package spark
-
-import java.util.{BitSet, UUID}
-import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
-
-@serializable
-trait Broadcast[T] {
- val uuid = UUID.randomUUID
-
- def value: T
-
- // We cannot have an abstract readObject here due to some weird issues with
- // readObject having to be 'private' in sub-classes. Possibly a Scala bug!
-
- override def toString = "spark.Broadcast(" + uuid + ")"
-}
-
-trait BroadcastFactory {
- def initialize (isMaster: Boolean): Unit
- def newBroadcast[T] (value_ : T, isLocal: Boolean): Broadcast[T]
-}
-
-private object Broadcast
-extends Logging {
- private var initialized = false
- private var broadcastFactory: BroadcastFactory = null
-
- // Called by SparkContext or Executor before using Broadcast
- def initialize (isMaster: Boolean): Unit = synchronized {
- if (!initialized) {
- val broadcastFactoryClass = System.getProperty("spark.broadcast.factory",
- "spark.DfsBroadcastFactory")
- val booleanArgs = Array[AnyRef] (isMaster.asInstanceOf[AnyRef])
-
- broadcastFactory =
- Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
-
- // Initialize appropriate BroadcastFactory and BroadcastObject
- broadcastFactory.initialize(isMaster)
-
- initialized = true
- }
- }
-
- def getBroadcastFactory: BroadcastFactory = {
- if (broadcastFactory == null) {
- throw new SparkException ("Broadcast.getBroadcastFactory called before initialize")
- }
- broadcastFactory
- }
-
- // Returns a standard ThreadFactory except all threads are daemons
- private def newDaemonThreadFactory: ThreadFactory = {
- new ThreadFactory {
- def newThread(r: Runnable): Thread = {
- var t = Executors.defaultThreadFactory.newThread (r)
- t.setDaemon (true)
- return t
- }
- }
- }
-
- // Wrapper over newCachedThreadPool
- def newDaemonCachedThreadPool: ThreadPoolExecutor = {
- var threadPool =
- Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor]
-
- threadPool.setThreadFactory (newDaemonThreadFactory)
-
- return threadPool
- }
-
- // Wrapper over newFixedThreadPool
- def newDaemonFixedThreadPool (nThreads: Int): ThreadPoolExecutor = {
- var threadPool =
- Executors.newFixedThreadPool (nThreads).asInstanceOf[ThreadPoolExecutor]
-
- threadPool.setThreadFactory (newDaemonThreadFactory)
-
- return threadPool
- }
-}
-
-@serializable
-case class SourceInfo (val hostAddress: String, val listenPort: Int,
- val totalBlocks: Int, val totalBytes: Int)
-extends Comparable[SourceInfo] with Logging {
-
- var currentLeechers = 0
- var receptionFailed = false
-
- var hasBlocks = 0
- var hasBlocksBitVector: BitSet = new BitSet (totalBlocks)
-
- // Ascending sort based on leecher count
-  def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers)
-}
-
-object SourceInfo {
- // Constants for special values of listenPort
- val TxNotStartedRetry = -1
- val TxOverGoToHDFS = 0
- // Other constants
- val StopBroadcast = -2
- val UnusedParam = 0
-}
-
-@serializable
-case class BroadcastBlock (val blockID: Int, val byteArray: Array[Byte]) { }
-
-@serializable
-case class VariableInfo (@transient val arrayOfBlocks : Array[BroadcastBlock],
- val totalBlocks: Int, val totalBytes: Int) {
- @transient var hasBlocks = 0
-}
-
-@serializable
-class SpeedTracker {
- // Mapping 'source' to '(totalTime, numBlocks)'
- private var sourceToSpeedMap = Map[SourceInfo, (Long, Int)] ()
-
- def addDataPoint (srcInfo: SourceInfo, timeInMillis: Long): Unit = {
- sourceToSpeedMap.synchronized {
- if (!sourceToSpeedMap.contains(srcInfo)) {
- sourceToSpeedMap += (srcInfo -> (timeInMillis, 1))
- } else {
- val tTnB = sourceToSpeedMap (srcInfo)
- sourceToSpeedMap += (srcInfo -> (tTnB._1 + timeInMillis, tTnB._2 + 1))
- }
- }
- }
-
- def getTimePerBlock (srcInfo: SourceInfo): Double = {
- sourceToSpeedMap.synchronized {
- val tTnB = sourceToSpeedMap (srcInfo)
-      return tTnB._1.toDouble / tTnB._2
- }
- }
-
- override def toString = sourceToSpeedMap.toString
-}
diff --git a/core/src/main/scala/spark/ChainedBroadcast.scala b/core/src/main/scala/spark/ChainedBroadcast.scala
deleted file mode 100644
index 6f2cc3f6f0..0000000000
--- a/core/src/main/scala/spark/ChainedBroadcast.scala
+++ /dev/null
@@ -1,873 +0,0 @@
-package spark
-
-import java.io._
-import java.net._
-import java.util.{Comparator, PriorityQueue, Random, UUID}
-
-import scala.collection.mutable.{Map, Set}
-
-@serializable
-class ChainedBroadcast[T] (@transient var value_ : T, isLocal: Boolean)
-extends Broadcast[T] with Logging {
-
- def value = value_
-
- ChainedBroadcast.synchronized {
- ChainedBroadcast.values.put (uuid, value_)
- }
-
- @transient var arrayOfBlocks: Array[BroadcastBlock] = null
- @transient var totalBytes = -1
- @transient var totalBlocks = -1
- @transient var hasBlocks = 0
-
- @transient var listenPortLock = new Object
- @transient var guidePortLock = new Object
- @transient var totalBlocksLock = new Object
- @transient var hasBlocksLock = new Object
-
- @transient var pqOfSources = new PriorityQueue[SourceInfo]
-
- @transient var serveMR: ServeMultipleRequests = null
- @transient var guideMR: GuideMultipleRequests = null
-
- @transient var hostAddress = InetAddress.getLocalHost.getHostAddress
- @transient var listenPort = -1
- @transient var guidePort = -1
-
- @transient var hasCopyInHDFS = false
- @transient var stopBroadcast = false
-
- // Must call this after all the variables have been created/initialized
- if (!isLocal) {
- sendBroadcast
- }
-
- def sendBroadcast (): Unit = {
- logInfo ("Local host address: " + hostAddress)
-
- // Store a persistent copy in HDFS
- // TODO: Turned OFF for now
- // val out = new ObjectOutputStream (DfsBroadcast.openFileForWriting(uuid))
- // out.writeObject (value_)
- // out.close
- // TODO: Fix this at some point
- hasCopyInHDFS = true
-
- // Create a variableInfo object and store it in valueInfos
- var variableInfo = blockifyObject (value_, ChainedBroadcast.BlockSize)
-
- guideMR = new GuideMultipleRequests
- guideMR.setDaemon (true)
- guideMR.start
- logInfo ("GuideMultipleRequests started...")
-
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon (true)
- serveMR.start
- logInfo ("ServeMultipleRequests started...")
-
- // Prepare the value being broadcasted
- // TODO: Refactoring and clean-up required here
- arrayOfBlocks = variableInfo.arrayOfBlocks
- totalBytes = variableInfo.totalBytes
- totalBlocks = variableInfo.totalBlocks
- hasBlocks = variableInfo.totalBlocks
-
- while (listenPort == -1) {
- listenPortLock.synchronized {
- listenPortLock.wait
- }
- }
-
- pqOfSources = new PriorityQueue[SourceInfo]
- val masterSource_0 =
- SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes)
- pqOfSources.add (masterSource_0)
-
- // Register with the Tracker
- while (guidePort == -1) {
- guidePortLock.synchronized {
- guidePortLock.wait
- }
- }
- ChainedBroadcast.registerValue (uuid, guidePort)
- }
-
- private def readObject (in: ObjectInputStream): Unit = {
- in.defaultReadObject
- ChainedBroadcast.synchronized {
- val cachedVal = ChainedBroadcast.values.get (uuid)
- if (cachedVal != null) {
- value_ = cachedVal.asInstanceOf[T]
- } else {
- // Initializing everything because Master will only send null/0 values
- initializeSlaveVariables
-
- logInfo ("Local host address: " + hostAddress)
-
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon (true)
- serveMR.start
- logInfo ("ServeMultipleRequests started...")
-
- val start = System.nanoTime
-
- val receptionSucceeded = receiveBroadcast (uuid)
-        // If that does not succeed, then get it from the HDFS copy
- if (receptionSucceeded) {
- value_ = unBlockifyObject[T]
- ChainedBroadcast.values.put (uuid, value_)
- } else {
- val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
- value_ = fileIn.readObject.asInstanceOf[T]
- ChainedBroadcast.values.put(uuid, value_)
- fileIn.close
- }
-
- val time = (System.nanoTime - start) / 1e9
- logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
- }
- }
- }
-
- private def initializeSlaveVariables: Unit = {
- arrayOfBlocks = null
- totalBytes = -1
- totalBlocks = -1
- hasBlocks = 0
-
- listenPortLock = new Object
- totalBlocksLock = new Object
- hasBlocksLock = new Object
-
- serveMR = null
-
- hostAddress = InetAddress.getLocalHost.getHostAddress
- listenPort = -1
-
- stopBroadcast = false
- }
-
- private def blockifyObject (obj: T, blockSize: Int): VariableInfo = {
- val baos = new ByteArrayOutputStream
- val oos = new ObjectOutputStream (baos)
- oos.writeObject (obj)
- oos.close
- baos.close
- val byteArray = baos.toByteArray
- val bais = new ByteArrayInputStream (byteArray)
-
- var blockNum = (byteArray.length / blockSize)
- if (byteArray.length % blockSize != 0)
- blockNum += 1
-
- var retVal = new Array[BroadcastBlock] (blockNum)
- var blockID = 0
-
- for (i <- 0 until (byteArray.length, blockSize)) {
- val thisBlockSize = math.min (blockSize, byteArray.length - i)
- var tempByteArray = new Array[Byte] (thisBlockSize)
- val hasRead = bais.read (tempByteArray, 0, thisBlockSize)
-
- retVal (blockID) = new BroadcastBlock (blockID, tempByteArray)
- blockID += 1
- }
- bais.close
-
- var variableInfo = VariableInfo (retVal, blockNum, byteArray.length)
- variableInfo.hasBlocks = blockNum
-
- return variableInfo
- }
-
- private def unBlockifyObject[A]: A = {
- var retByteArray = new Array[Byte] (totalBytes)
- for (i <- 0 until totalBlocks) {
- System.arraycopy (arrayOfBlocks(i).byteArray, 0, retByteArray,
- i * ChainedBroadcast.BlockSize, arrayOfBlocks(i).byteArray.length)
- }
- byteArrayToObject (retByteArray)
- }
-
- private def byteArrayToObject[A] (bytes: Array[Byte]): A = {
- val in = new ObjectInputStream (new ByteArrayInputStream (bytes))
- val retVal = in.readObject.asInstanceOf[A]
- in.close
- return retVal
- }
-
- def getMasterListenPort (variableUUID: UUID): Int = {
- var clientSocketToTracker: Socket = null
- var oosTracker: ObjectOutputStream = null
- var oisTracker: ObjectInputStream = null
-
- var masterListenPort: Int = SourceInfo.TxOverGoToHDFS
-
- var retriesLeft = ChainedBroadcast.MaxRetryCount
- do {
- try {
- // Connect to the tracker to find out the guide
- val clientSocketToTracker =
- new Socket(ChainedBroadcast.MasterHostAddress, ChainedBroadcast.MasterTrackerPort)
- val oosTracker =
- new ObjectOutputStream (clientSocketToTracker.getOutputStream)
- oosTracker.flush
- val oisTracker =
- new ObjectInputStream (clientSocketToTracker.getInputStream)
-
- // Send UUID and receive masterListenPort
- oosTracker.writeObject (uuid)
- oosTracker.flush
- masterListenPort = oisTracker.readObject.asInstanceOf[Int]
- } catch {
- case e: Exception => {
- logInfo ("getMasterListenPort had a " + e)
- }
- } finally {
- if (oisTracker != null) {
- oisTracker.close
- }
- if (oosTracker != null) {
- oosTracker.close
- }
- if (clientSocketToTracker != null) {
- clientSocketToTracker.close
- }
- }
- retriesLeft -= 1
-
- Thread.sleep (ChainedBroadcast.ranGen.nextInt (
- ChainedBroadcast.MaxKnockInterval - ChainedBroadcast.MinKnockInterval) +
- ChainedBroadcast.MinKnockInterval)
-
- } while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry)
-
- logInfo ("Got this guidePort from Tracker: " + masterListenPort)
- return masterListenPort
- }
-
- def receiveBroadcast (variableUUID: UUID): Boolean = {
- val masterListenPort = getMasterListenPort (variableUUID)
-
- if (masterListenPort == SourceInfo.TxOverGoToHDFS ||
- masterListenPort == SourceInfo.TxNotStartedRetry) {
- // TODO: SourceInfo.TxNotStartedRetry is not really in use because we go
- // to HDFS anyway when receiveBroadcast returns false
- return false
- }
-
- // Wait until hostAddress and listenPort are created by the
- // ServeMultipleRequests thread
- while (listenPort == -1) {
- listenPortLock.synchronized {
- listenPortLock.wait
- }
- }
-
- var clientSocketToMaster: Socket = null
- var oosMaster: ObjectOutputStream = null
- var oisMaster: ObjectInputStream = null
-
- // Connect and receive broadcast from the specified source, retrying the
- // specified number of times in case of failures
- var retriesLeft = ChainedBroadcast.MaxRetryCount
- do {
- // Connect to Master and send this worker's Information
- clientSocketToMaster =
- new Socket(ChainedBroadcast.MasterHostAddress, masterListenPort)
- // TODO: Guiding object connection is reusable
- oosMaster =
- new ObjectOutputStream (clientSocketToMaster.getOutputStream)
- oosMaster.flush
- oisMaster =
- new ObjectInputStream (clientSocketToMaster.getInputStream)
-
- logInfo ("Connected to Master's guiding object")
-
- // Send local source information
- oosMaster.writeObject(SourceInfo (hostAddress, listenPort,
- SourceInfo.UnusedParam, SourceInfo.UnusedParam))
- oosMaster.flush
-
- // Receive source information from Master
- var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo]
- totalBlocks = sourceInfo.totalBlocks
- arrayOfBlocks = new Array[BroadcastBlock] (totalBlocks)
- totalBlocksLock.synchronized {
- totalBlocksLock.notifyAll
- }
- totalBytes = sourceInfo.totalBytes
-
- logInfo ("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort)
-
- val start = System.nanoTime
- val receptionSucceeded = receiveSingleTransmission (sourceInfo)
- val time = (System.nanoTime - start) / 1e9
-
- // Updating some statistics in sourceInfo. Master will be using them later
- if (!receptionSucceeded) {
- sourceInfo.receptionFailed = true
- }
-
- // Send back statistics to the Master
- oosMaster.writeObject (sourceInfo)
-
- if (oisMaster != null) {
- oisMaster.close
- }
- if (oosMaster != null) {
- oosMaster.close
- }
- if (clientSocketToMaster != null) {
- clientSocketToMaster.close
- }
-
- retriesLeft -= 1
- } while (retriesLeft > 0 && hasBlocks < totalBlocks)
-
- return (hasBlocks == totalBlocks)
- }
-
- // Tries to receive broadcast from the source and returns Boolean status.
- // This might be called multiple times to retry a defined number of times.
- private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
- var clientSocketToSource: Socket = null
- var oosSource: ObjectOutputStream = null
- var oisSource: ObjectInputStream = null
-
- var receptionSucceeded = false
- try {
- // Connect to the source to get the object itself
- clientSocketToSource =
- new Socket (sourceInfo.hostAddress, sourceInfo.listenPort)
- oosSource =
- new ObjectOutputStream (clientSocketToSource.getOutputStream)
- oosSource.flush
- oisSource =
- new ObjectInputStream (clientSocketToSource.getInputStream)
-
- logInfo ("Inside receiveSingleTransmission")
- logInfo ("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
-
- // Send the range
- oosSource.writeObject((hasBlocks, totalBlocks))
- oosSource.flush
-
- for (i <- hasBlocks until totalBlocks) {
- val recvStartTime = System.currentTimeMillis
- val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
- val receptionTime = (System.currentTimeMillis - recvStartTime)
-
- logInfo ("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
-
- arrayOfBlocks(hasBlocks) = bcBlock
- hasBlocks += 1
- // Set to true if at least one block is received
- receptionSucceeded = true
- hasBlocksLock.synchronized {
- hasBlocksLock.notifyAll
- }
- }
- } catch {
- case e: Exception => {
- logInfo ("receiveSingleTransmission had a " + e)
- }
- } finally {
- if (oisSource != null) {
- oisSource.close
- }
- if (oosSource != null) {
- oosSource.close
- }
- if (clientSocketToSource != null) {
- clientSocketToSource.close
- }
- }
-
- return receptionSucceeded
- }
-
- class GuideMultipleRequests
- extends Thread with Logging {
- // Keep track of sources that have completed reception
- private var setOfCompletedSources = Set[SourceInfo] ()
-
- override def run: Unit = {
- var threadPool = Broadcast.newDaemonCachedThreadPool
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket (0)
- guidePort = serverSocket.getLocalPort
- logInfo ("GuideMultipleRequests => " + serverSocket + " " + guidePort)
-
- guidePortLock.synchronized {
- guidePortLock.notifyAll
- }
-
- try {
- // Don't stop until there is a copy in HDFS
- while (!stopBroadcast || !hasCopyInHDFS) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout (ChainedBroadcast.ServerSocketTimeout)
- clientSocket = serverSocket.accept
- } catch {
- case e: Exception => {
- logInfo ("GuideMultipleRequests Timeout.")
-
- // Stop broadcast if at least one worker has connected and
-              // everyone connected so far is done. Comparing with
- // pqOfSources.size - 1, because it includes the Guide itself
- if (pqOfSources.size > 1 &&
- setOfCompletedSources.size == pqOfSources.size - 1) {
- stopBroadcast = true
- }
- }
- }
- if (clientSocket != null) {
- logInfo ("Guide: Accepted new client connection: " + clientSocket)
- try {
- threadPool.execute (new GuideSingleRequest (clientSocket))
- } catch {
- // In failure, close the socket here; else, the thread will close it
- case ioe: IOException => clientSocket.close
- }
- }
- }
-
- logInfo ("Sending stopBroadcast notifications...")
- sendStopBroadcastNotifications
-
- ChainedBroadcast.unregisterValue (uuid)
- } finally {
- if (serverSocket != null) {
- logInfo ("GuideMultipleRequests now stopping...")
- serverSocket.close
- }
- }
-
- // Shutdown the thread pool
- threadPool.shutdown
- }
-
- private def sendStopBroadcastNotifications: Unit = {
- pqOfSources.synchronized {
- var pqIter = pqOfSources.iterator
- while (pqIter.hasNext) {
- var sourceInfo = pqIter.next
-
- var guideSocketToSource: Socket = null
- var gosSource: ObjectOutputStream = null
- var gisSource: ObjectInputStream = null
-
- try {
- // Connect to the source
- guideSocketToSource =
- new Socket (sourceInfo.hostAddress, sourceInfo.listenPort)
- gosSource =
- new ObjectOutputStream (guideSocketToSource.getOutputStream)
- gosSource.flush
- gisSource =
- new ObjectInputStream (guideSocketToSource.getInputStream)
-
- // Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2
- gosSource.writeObject ((SourceInfo.StopBroadcast,
- SourceInfo.StopBroadcast))
- gosSource.flush
- } catch {
- case e: Exception => {
- logInfo ("sendStopBroadcastNotifications had a " + e)
- }
- } finally {
- if (gisSource != null) {
- gisSource.close
- }
- if (gosSource != null) {
- gosSource.close
- }
- if (guideSocketToSource != null) {
- guideSocketToSource.close
- }
- }
- }
- }
- }
-
- class GuideSingleRequest (val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream (clientSocket.getOutputStream)
- oos.flush
- private val ois = new ObjectInputStream (clientSocket.getInputStream)
-
- private var selectedSourceInfo: SourceInfo = null
-      private var thisWorkerInfo: SourceInfo = null
-
- override def run: Unit = {
- try {
- logInfo ("new GuideSingleRequest is running")
-          // The connecting worker sends in the hostAddress and listenPort it will
-          // be listening on. Other fields are invalid (SourceInfo.UnusedParam)
- var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- pqOfSources.synchronized {
- // Select a suitable source and send it back to the worker
- selectedSourceInfo = selectSuitableSource (sourceInfo)
- logInfo ("Sending selectedSourceInfo: " + selectedSourceInfo)
- oos.writeObject (selectedSourceInfo)
- oos.flush
-
-          // Add this new source (which may still fail) to the PQ of sources
- thisWorkerInfo = SourceInfo (sourceInfo.hostAddress,
- sourceInfo.listenPort, totalBlocks, totalBytes)
- logInfo ("Adding possible new source to pqOfSources: " + thisWorkerInfo)
- pqOfSources.add (thisWorkerInfo)
- }
-
- // Wait till the whole transfer is done. Then receive and update source
- // statistics in pqOfSources
- sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- pqOfSources.synchronized {
- // This should work since SourceInfo is a case class
- assert (pqOfSources.contains (selectedSourceInfo))
-
- // Remove first
- pqOfSources.remove (selectedSourceInfo)
- // TODO: Removing a source based on just one failure notification!
-
- // Update sourceInfo and put it back in, IF reception succeeded
- if (!sourceInfo.receptionFailed) {
- // Add thisWorkerInfo to sources that have completed reception
- setOfCompletedSources += thisWorkerInfo
-
- selectedSourceInfo.currentLeechers -= 1
-
- // Put it back
- pqOfSources.add (selectedSourceInfo)
- }
- }
- } catch {
- // If something went wrong, e.g., the worker at the other end died etc.
- // then close everything up
- case e: Exception => {
- // Assuming that exception caused due to receiver worker failure.
- // Remove failed worker from pqOfSources and update leecherCount of
- // corresponding source worker
- pqOfSources.synchronized {
- if (selectedSourceInfo != null) {
- // Remove first
- pqOfSources.remove (selectedSourceInfo)
- // Update leecher count and put it back in
- selectedSourceInfo.currentLeechers -= 1
- pqOfSources.add (selectedSourceInfo)
- }
-
- // Remove thisWorkerInfo
- if (pqOfSources != null) {
- pqOfSources.remove (thisWorkerInfo)
- }
- }
- }
- } finally {
- ois.close
- oos.close
- clientSocket.close
- }
- }
-
- // TODO: Caller must have a synchronized block on pqOfSources
-    // TODO: If a worker fails to get the broadcasted variable from a source and
-    // comes back to Master, this function might choose the worker itself as a
-    // source to create a dependency cycle (this worker was put into pqOfSources
-    // as a streaming source when it first arrived). The length of this cycle can
-    // be arbitrarily long.
- private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
-      // Select one based on the ordering strategy (e.g., least leechers etc.)
-      // poll removes and returns the head of the PQ, which is assumed non-empty
- var selectedSource = pqOfSources.poll
- assert (selectedSource != null)
- // Update leecher count
- selectedSource.currentLeechers += 1
- // Add it back and then return
- pqOfSources.add (selectedSource)
- return selectedSource
- }
- }
- }
-
- class ServeMultipleRequests
- extends Thread with Logging {
- override def run: Unit = {
- var threadPool = Broadcast.newDaemonCachedThreadPool
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket (0)
- listenPort = serverSocket.getLocalPort
- logInfo ("ServeMultipleRequests started with " + serverSocket)
-
- listenPortLock.synchronized {
- listenPortLock.notifyAll
- }
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout (ChainedBroadcast.ServerSocketTimeout)
- clientSocket = serverSocket.accept
- } catch {
- case e: Exception => {
- logInfo ("ServeMultipleRequests Timeout.")
- }
- }
- if (clientSocket != null) {
- logInfo ("Serve: Accepted new client connection: " + clientSocket)
- try {
- threadPool.execute (new ServeSingleRequest (clientSocket))
- } catch {
- // In failure, close socket here; else, the thread will close it
- case ioe: IOException => clientSocket.close
- }
- }
- }
- } finally {
- if (serverSocket != null) {
- logInfo ("ServeMultipleRequests now stopping...")
- serverSocket.close
- }
- }
-
- // Shutdown the thread pool
- threadPool.shutdown
- }
-
- class ServeSingleRequest (val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream (clientSocket.getOutputStream)
- oos.flush
- private val ois = new ObjectInputStream (clientSocket.getInputStream)
-
- private var sendFrom = 0
- private var sendUntil = totalBlocks
-
- override def run: Unit = {
- try {
- logInfo ("new ServeSingleRequest is running")
-
- // Receive range to send
- var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)]
- sendFrom = rangeToSend._1
- sendUntil = rangeToSend._2
-
- if (sendFrom == SourceInfo.StopBroadcast &&
- sendUntil == SourceInfo.StopBroadcast) {
- stopBroadcast = true
- } else {
- // Carry on
- sendObject
- }
- } catch {
- // If something went wrong, e.g., the worker at the other end died etc.
- // then close everything up
- case e: Exception => {
- logInfo ("ServeSingleRequest had a " + e)
- }
- } finally {
- logInfo ("ServeSingleRequest is closing streams and sockets")
- ois.close
- oos.close
- clientSocket.close
- }
- }
-
- private def sendObject: Unit = {
- // Wait till receiving the SourceInfo from Master
- while (totalBlocks == -1) {
- totalBlocksLock.synchronized {
- totalBlocksLock.wait
- }
- }
-
- for (i <- sendFrom until sendUntil) {
- while (i == hasBlocks) {
- hasBlocksLock.synchronized {
- hasBlocksLock.wait
- }
- }
- try {
- oos.writeObject (arrayOfBlocks(i))
- oos.flush
- } catch {
- case e: Exception => {
- logInfo ("sendObject had a " + e)
- }
- }
- logInfo ("Sent block: " + i + " to " + clientSocket)
- }
- }
- }
- }
-}
-
-class ChainedBroadcastFactory
-extends BroadcastFactory {
- def initialize (isMaster: Boolean) = ChainedBroadcast.initialize (isMaster)
- def newBroadcast[T] (value_ : T, isLocal: Boolean) =
- new ChainedBroadcast[T] (value_, isLocal)
-}
-
-private object ChainedBroadcast
-extends Logging {
- val values = SparkEnv.get.cache.newKeySpace()
-
- var valueToGuidePortMap = Map[UUID, Int] ()
-
- // Random number generator
- var ranGen = new Random
-
- private var initialized = false
- private var isMaster_ = false
-
- private var MasterHostAddress_ = InetAddress.getLocalHost.getHostAddress
- private var MasterTrackerPort_ : Int = 22222
- private var BlockSize_ : Int = 512 * 1024
- private var MaxRetryCount_ : Int = 2
-
- private var TrackerSocketTimeout_ : Int = 50000
- private var ServerSocketTimeout_ : Int = 10000
-
- private var trackMV: TrackMultipleValues = null
-
- private var MinKnockInterval_ = 500
- private var MaxKnockInterval_ = 999
-
- def initialize (isMaster__ : Boolean): Unit = {
- synchronized {
- if (!initialized) {
- // Fix for issue #42
- MasterHostAddress_ =
- System.getProperty ("spark.broadcast.masterHostAddress", "")
- MasterTrackerPort_ =
- System.getProperty ("spark.broadcast.masterTrackerPort", "22222").toInt
- BlockSize_ =
- System.getProperty ("spark.broadcast.blockSize", "512").toInt * 1024
- MaxRetryCount_ =
- System.getProperty ("spark.broadcast.maxRetryCount", "2").toInt
-
- TrackerSocketTimeout_ =
- System.getProperty ("spark.broadcast.trackerSocketTimeout", "50000").toInt
- ServerSocketTimeout_ =
- System.getProperty ("spark.broadcast.serverSocketTimeout", "10000").toInt
-
- MinKnockInterval_ =
- System.getProperty ("spark.broadcast.minKnockInterval", "500").toInt
- MaxKnockInterval_ =
- System.getProperty ("spark.broadcast.maxKnockInterval", "999").toInt
-
- isMaster_ = isMaster__
-
- if (isMaster) {
- trackMV = new TrackMultipleValues
- trackMV.setDaemon (true)
- trackMV.start
- logInfo ("TrackMultipleValues started...")
- }
-
- // Initialize DfsBroadcast to be used for broadcast variable persistence
- DfsBroadcast.initialize
-
- initialized = true
- }
- }
- }
-
- def MasterHostAddress = MasterHostAddress_
- def MasterTrackerPort = MasterTrackerPort_
- def BlockSize = BlockSize_
- def MaxRetryCount = MaxRetryCount_
-
- def TrackerSocketTimeout = TrackerSocketTimeout_
- def ServerSocketTimeout = ServerSocketTimeout_
-
- def isMaster = isMaster_
-
- def MinKnockInterval = MinKnockInterval_
- def MaxKnockInterval = MaxKnockInterval_
-
- def registerValue (uuid: UUID, guidePort: Int): Unit = {
- valueToGuidePortMap.synchronized {
- valueToGuidePortMap += (uuid -> guidePort)
- logInfo ("New value registered with the Tracker " + valueToGuidePortMap)
- }
- }
-
- def unregisterValue (uuid: UUID): Unit = {
- valueToGuidePortMap.synchronized {
- valueToGuidePortMap (uuid) = SourceInfo.TxOverGoToHDFS
- logInfo ("Value unregistered from the Tracker " + valueToGuidePortMap)
- }
- }
-
- class TrackMultipleValues
- extends Thread with Logging {
- override def run: Unit = {
- var threadPool = Broadcast.newDaemonCachedThreadPool
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket (ChainedBroadcast.MasterTrackerPort)
-    logInfo ("TrackMultipleValues " + serverSocket)
-
- try {
- while (true) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout (TrackerSocketTimeout)
- clientSocket = serverSocket.accept
- } catch {
- case e: Exception => {
- logInfo ("TrackMultipleValues Timeout. Stopping listening...")
- }
- }
-
- if (clientSocket != null) {
- try {
- threadPool.execute (new Thread {
- override def run: Unit = {
- val oos = new ObjectOutputStream (clientSocket.getOutputStream)
- oos.flush
- val ois = new ObjectInputStream (clientSocket.getInputStream)
- try {
- val uuid = ois.readObject.asInstanceOf[UUID]
- var guidePort =
- if (valueToGuidePortMap.contains (uuid)) {
- valueToGuidePortMap (uuid)
- } else SourceInfo.TxNotStartedRetry
- logInfo ("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort)
- oos.writeObject (guidePort)
- } catch {
- case e: Exception => {
- logInfo ("TrackMultipleValues had a " + e)
- }
- } finally {
- ois.close
- oos.close
- clientSocket.close
- }
- }
- })
- } catch {
- // In failure, close socket here; else, client thread will close
- case ioe: IOException => clientSocket.close
- }
- }
- }
- } finally {
- serverSocket.close
- }
-
- // Shutdown the thread pool
- threadPool.shutdown
- }
- }
-}
diff --git a/core/src/main/scala/spark/Executor.scala b/core/src/main/scala/spark/Executor.scala
index 9c89e34749..0942fecff3 100644
--- a/core/src/main/scala/spark/Executor.scala
+++ b/core/src/main/scala/spark/Executor.scala
@@ -9,6 +9,8 @@ import scala.collection.mutable.ArrayBuffer
import mesos.{ExecutorArgs, ExecutorDriver, MesosExecutorDriver}
import mesos.{TaskDescription, TaskState, TaskStatus}
+import spark.broadcast._
+
/**
* The Mesos executor for Spark.
*/
diff --git a/core/src/main/scala/spark/LocalFileShuffle.scala b/core/src/main/scala/spark/LocalFileShuffle.scala
index da3897b117..6c7f3dede2 100644
--- a/core/src/main/scala/spark/LocalFileShuffle.scala
+++ b/core/src/main/scala/spark/LocalFileShuffle.scala
@@ -7,6 +7,7 @@ import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable.{ArrayBuffer, HashMap}
+import spark._
object LocalFileShuffle extends Logging {
private var initialized = false
@@ -29,9 +30,9 @@ object LocalFileShuffle extends Logging {
while (!foundLocalDir && tries < 10) {
tries += 1
try {
- localDirUuid = UUID.randomUUID()
+ localDirUuid = UUID.randomUUID
localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
- if (!localDir.exists()) {
+ if (!localDir.exists) {
localDir.mkdirs()
foundLocalDir = true
}
@@ -47,6 +48,7 @@ object LocalFileShuffle extends Logging {
shuffleDir = new File(localDir, "shuffle")
shuffleDir.mkdirs()
logInfo("Shuffle dir: " + shuffleDir)
+
val extServerPort = System.getProperty(
"spark.localFileShuffle.external.server.port", "-1").toInt
if (extServerPort != -1) {
@@ -65,6 +67,7 @@ object LocalFileShuffle extends Logging {
serverUri = server.uri
}
initialized = true
+ logInfo("Local URI: " + serverUri)
}
}
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 15f15ba5b2..5ea5fa9b6e 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -263,6 +263,44 @@ extends RDD[Array[T]](prev.context) {
}
}
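+  // Both outer joins below tag values with Either (Left from this RDD, Right
+  // from `other`) so that a single groupByKey can co-locate both sides.
+  // Illustrative example: {(1, "a")}.leftOuterJoin({(2, 9)}) yields
+  // {(1, ("a", None))}, because key 1 has no match in `other`.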
+ def leftOuterJoin[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (V, Option[W]))] = {
+ val vs: RDD[(K, Either[V, W])] = self.map { case (k, v) => (k, Left(v)) }
+ val ws: RDD[(K, Either[V, W])] = other.map { case (k, w) => (k, Right(w)) }
+ (vs ++ ws).groupByKey(numSplits).flatMap {
+ case (k, seq) => {
+ val vbuf = new ArrayBuffer[V]
+ val wbuf = new ArrayBuffer[Option[W]]
+ seq.foreach(_ match {
+ case Left(v) => vbuf += v
+ case Right(w) => wbuf += Some(w)
+ })
+ if (wbuf.isEmpty) {
+ wbuf += None
+ }
+ for (v <- vbuf; w <- wbuf) yield (k, (v, w))
+ }
+ }
+ }
+
+ def rightOuterJoin[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (Option[V], W))] = {
+ val vs: RDD[(K, Either[V, W])] = self.map { case (k, v) => (k, Left(v)) }
+ val ws: RDD[(K, Either[V, W])] = other.map { case (k, w) => (k, Right(w)) }
+ (vs ++ ws).groupByKey(numSplits).flatMap {
+ case (k, seq) => {
+ val vbuf = new ArrayBuffer[Option[V]]
+ val wbuf = new ArrayBuffer[W]
+ seq.foreach(_ match {
+ case Left(v) => vbuf += Some(v)
+ case Right(w) => wbuf += w
+ })
+ if (vbuf.isEmpty) {
+ vbuf += None
+ }
+ for (v <- vbuf; w <- wbuf) yield (k, (v, w))
+ }
+ }
+ }
+
def combineByKey[C](createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C)
@@ -282,6 +320,14 @@ extends RDD[Array[T]](prev.context) {
join(other, numCores)
}
+ def leftOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (V, Option[W]))] = {
+ leftOuterJoin(other, numCores)
+ }
+
+ def rightOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], W))] = {
+ rightOuterJoin(other, numCores)
+ }
+
def numCores = self.context.numCores
def collectAsMap(): Map[K, V] = HashMap(self.collect(): _*)
diff --git a/core/src/main/scala/spark/Shuffle.scala b/core/src/main/scala/spark/Shuffle.scala
deleted file mode 100644
index 4c5649b537..0000000000
--- a/core/src/main/scala/spark/Shuffle.scala
+++ /dev/null
@@ -1,15 +0,0 @@
-package spark
-
-/**
- * A trait for shuffle system. Given an input RDD and combiner functions
- * for PairRDDExtras.combineByKey(), returns an output RDD.
- */
-@serializable
-trait Shuffle[K, V, C] {
- def compute(input: RDD[(K, V)],
- numOutputSplits: Int,
- createCombiner: V => C,
- mergeValue: (C, V) => C,
- mergeCombiners: (C, C) => C)
- : RDD[(K, C)]
-}
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index cb80506923..5f12b247a7 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -18,6 +18,8 @@ import org.apache.hadoop.io.ArrayWritable
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text
+import spark.broadcast._
+
class SparkContext(
master: String,
frameworkName: String,
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index 4476fe44c1..1d77c357c3 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -3,6 +3,7 @@ package spark
import java.io._
import java.net.InetAddress
import java.util.UUID
+import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
@@ -117,12 +118,43 @@ object Utils {
/**
* Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4)
*/
- def localIpAddress(): String = {
- // Get local IP as an array of four bytes
- val bytes = InetAddress.getLocalHost().getAddress()
- // Convert the bytes to ints (keeping in mind that they may be negative)
- // and join them into a string
- return bytes.map(b => (b.toInt + 256) % 256).mkString(".")
+ def localIpAddress(): String = InetAddress.getLocalHost.getHostAddress
+
+ /**
+ * Returns a standard ThreadFactory except all threads are daemons
+ */
+ private def newDaemonThreadFactory: ThreadFactory = {
+ new ThreadFactory {
+ def newThread(r: Runnable): Thread = {
+ var t = Executors.defaultThreadFactory.newThread (r)
+ t.setDaemon (true)
+ return t
+ }
+ }
+ }
+
+ /**
+ * Wrapper over newCachedThreadPool
+ */
+ def newDaemonCachedThreadPool(): ThreadPoolExecutor = {
+ var threadPool =
+ Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor]
+
+ threadPool.setThreadFactory (newDaemonThreadFactory)
+
+ return threadPool
+ }
+
+ /**
+ * Wrapper over newFixedThreadPool
+ */
+ def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
+ var threadPool =
+ Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
+
+ threadPool.setThreadFactory(newDaemonThreadFactory)
+
+ return threadPool
}
/**
diff --git a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
new file mode 100644
index 0000000000..220456a210
--- /dev/null
+++ b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
@@ -0,0 +1,1355 @@
+package spark.broadcast
+
+import java.io._
+import java.net._
+import java.util.{BitSet, Comparator, Random, Timer, TimerTask, UUID}
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.collection.mutable.{ListBuffer, Map, Set}
+import scala.math
+
+import spark._
+
+@serializable
+class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean)
+extends Broadcast[T] with Logging {
+
+ def value = value_
+
+ BitTorrentBroadcast.synchronized {
+ BitTorrentBroadcast.values.put(uuid, value_)
+ }
+
+ @transient var arrayOfBlocks: Array[BroadcastBlock] = null
+ @transient var hasBlocksBitVector: BitSet = null
+ @transient var numCopiesSent: Array[Int] = null
+ @transient var totalBytes = -1
+ @transient var totalBlocks = -1
+ @transient var hasBlocks = new AtomicInteger(0)
+ // CHANGED: BlockSize in the Broadcast object is expected to change over time
+ @transient var blockSize = Broadcast.BlockSize
+
+ // Used ONLY by Master to track how many unique blocks have been sent out
+ @transient var sentBlocks = new AtomicInteger(0)
+
+ @transient var listenPortLock = new Object
+ @transient var guidePortLock = new Object
+ @transient var totalBlocksLock = new Object
+
+ @transient var listOfSources = ListBuffer[SourceInfo]()
+
+ @transient var serveMR: ServeMultipleRequests = null
+
+ // Used only in Master
+ @transient var guideMR: GuideMultipleRequests = null
+
+ // Used only in Workers
+ @transient var ttGuide: TalkToGuide = null
+
+ @transient var rxSpeeds = new SpeedTracker
+ @transient var txSpeeds = new SpeedTracker
+
+ @transient var hostAddress = Utils.localIpAddress
+ @transient var listenPort = -1
+ @transient var guidePort = -1
+
+ @transient var hasCopyInHDFS = false
+ @transient var stopBroadcast = false
+
+ // Must call this after all the variables have been created/initialized
+ if (!isLocal) {
+ sendBroadcast
+ }
+
+ def sendBroadcast(): Unit = {
+ logInfo("Local host address: " + hostAddress)
+
+ // Store a persistent copy in HDFS
+ // TODO: Turned OFF for now. Related to persistence
+ // val out = new ObjectOutputStream(BroadcastCH.openFileForWriting(uuid))
+ // out.writeObject(value_)
+ // out.close()
+ // FIXME: Fix this at some point
+ hasCopyInHDFS = true
+
+ // Create a variableInfo object and store it in valueInfos
+ var variableInfo = Broadcast.blockifyObject(value_)
+
+ // Prepare the value being broadcasted
+ // TODO: Refactoring and clean-up required here
+ arrayOfBlocks = variableInfo.arrayOfBlocks
+ totalBytes = variableInfo.totalBytes
+ totalBlocks = variableInfo.totalBlocks
+ hasBlocks.set(variableInfo.totalBlocks)
+
+ // Guide has all the blocks
+ hasBlocksBitVector = new BitSet(totalBlocks)
+ hasBlocksBitVector.set(0, totalBlocks)
+
+ // Guide still hasn't sent any block
+ numCopiesSent = new Array[Int](totalBlocks)
+
+ guideMR = new GuideMultipleRequests
+ guideMR.setDaemon(true)
+ guideMR.start()
+ logInfo("GuideMultipleRequests started...")
+
+ // Must always come AFTER guideMR is created
+ while (guidePort == -1) {
+ guidePortLock.synchronized {
+ guidePortLock.wait
+ }
+ }
+
+ serveMR = new ServeMultipleRequests
+ serveMR.setDaemon(true)
+ serveMR.start()
+ logInfo("ServeMultipleRequests started...")
+
+ // Must always come AFTER serveMR is created
+ while (listenPort == -1) {
+ listenPortLock.synchronized {
+ listenPortLock.wait
+ }
+ }
+
+ // Must always come AFTER listenPort is created
+ val masterSource =
+ SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes, blockSize)
+ hasBlocksBitVector.synchronized {
+ masterSource.hasBlocksBitVector = hasBlocksBitVector
+ }
+
+ // In the beginning, this is the only known source to Guide
+ listOfSources += masterSource
+
+ // Register with the Tracker
+ registerBroadcast(uuid,
+ SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes, blockSize))
+ }
+
+ private def readObject(in: ObjectInputStream): Unit = {
+ in.defaultReadObject
+ BitTorrentBroadcast.synchronized {
+ val cachedVal = BitTorrentBroadcast.values.get(uuid)
+
+ if (cachedVal != null) {
+ value_ = cachedVal.asInstanceOf[T]
+ } else {
+ // Only the first worker in a node can ever be inside this 'else'
+ initializeWorkerVariables
+
+ logInfo("Local host address: " + hostAddress)
+
+ // Start local ServeMultipleRequests thread first
+ serveMR = new ServeMultipleRequests
+ serveMR.setDaemon(true)
+ serveMR.start()
+ logInfo("ServeMultipleRequests started...")
+
+ val start = System.nanoTime
+
+ val receptionSucceeded = receiveBroadcast(uuid)
+ // If does not succeed, then get from HDFS copy
+ if (receptionSucceeded) {
+ value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
+ BitTorrentBroadcast.values.put(uuid, value_)
+ } else {
+ // TODO: This part won't work, cause HDFS writing is turned OFF
+ val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
+ value_ = fileIn.readObject.asInstanceOf[T]
+ BitTorrentBroadcast.values.put(uuid, value_)
+ fileIn.close()
+ }
+
+ val time = (System.nanoTime - start) / 1e9
+ logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
+ }
+ }
+ }
+
+ // Initialize variables in the worker node. Master sends everything as 0/null
+ private def initializeWorkerVariables: Unit = {
+ arrayOfBlocks = null
+ hasBlocksBitVector = null
+ numCopiesSent = null
+ totalBytes = -1
+ totalBlocks = -1
+ hasBlocks = new AtomicInteger(0)
+ blockSize = -1
+
+ listenPortLock = new Object
+ totalBlocksLock = new Object
+
+ serveMR = null
+ ttGuide = null
+
+ rxSpeeds = new SpeedTracker
+ txSpeeds = new SpeedTracker
+
+ hostAddress = Utils.localIpAddress
+ listenPort = -1
+
+ listOfSources = ListBuffer[SourceInfo]()
+
+ stopBroadcast = false
+ }
+
+ private def registerBroadcast(uuid: UUID, gInfo: SourceInfo): Unit = {
+ val socket = new Socket(Broadcast.MasterHostAddress,
+ Broadcast.MasterTrackerPort)
+ val oosST = new ObjectOutputStream(socket.getOutputStream)
+ oosST.flush()
+ val oisST = new ObjectInputStream(socket.getInputStream)
+
+ // Send messageType/intention
+ oosST.writeObject(Broadcast.REGISTER_BROADCAST_TRACKER)
+ oosST.flush()
+
+ // Send UUID of this broadcast
+ oosST.writeObject(uuid)
+ oosST.flush()
+
+ // Send this tracker's information
+ oosST.writeObject(gInfo)
+ oosST.flush()
+
+ // Receive ACK and throw it away
+ oisST.readObject.asInstanceOf[Int]
+
+ // Shut stuff down
+ oisST.close()
+ oosST.close()
+ socket.close()
+ }
+
+ private def unregisterBroadcast(uuid: UUID): Unit = {
+ val socket = new Socket(Broadcast.MasterHostAddress,
+ Broadcast.MasterTrackerPort)
+ val oosST = new ObjectOutputStream(socket.getOutputStream)
+ oosST.flush()
+ val oisST = new ObjectInputStream(socket.getInputStream)
+
+ // Send messageType/intention
+ oosST.writeObject(Broadcast.UNREGISTER_BROADCAST_TRACKER)
+ oosST.flush()
+
+ // Send UUID of this broadcast
+ oosST.writeObject(uuid)
+ oosST.flush()
+
+ // Receive ACK and throw it away
+ oisST.readObject.asInstanceOf[Int]
+
+ // Shut stuff down
+ oisST.close()
+ oosST.close()
+ socket.close()
+ }
+
+ private def getLocalSourceInfo: SourceInfo = {
+ // Wait till hostName and listenPort are OK
+ while (listenPort == -1) {
+ listenPortLock.synchronized {
+ listenPortLock.wait
+ }
+ }
+
+ // Wait till totalBlocks and totalBytes are OK
+ while (totalBlocks == -1) {
+ totalBlocksLock.synchronized {
+ totalBlocksLock.wait
+ }
+ }
+
+ var localSourceInfo = SourceInfo(
+ hostAddress, listenPort, totalBlocks, totalBytes, blockSize)
+
+ localSourceInfo.hasBlocks = hasBlocks.get
+
+ hasBlocksBitVector.synchronized {
+ localSourceInfo.hasBlocksBitVector = hasBlocksBitVector
+ }
+
+ return localSourceInfo
+ }
+
+ // Add new SourceInfo to the listOfSources. Update if it exists already.
+ // TODO: Optimizing just by OR-ing the BitVectors was BAD for performance
+ private def addToListOfSources(newSourceInfo: SourceInfo): Unit = {
+ listOfSources.synchronized {
+ if (listOfSources.contains(newSourceInfo)) {
+ listOfSources = listOfSources - newSourceInfo
+ }
+ listOfSources += newSourceInfo
+ }
+ }
+
+ private def addToListOfSources(newSourceInfos: ListBuffer[SourceInfo]): Unit = {
+ newSourceInfos.foreach { newSourceInfo =>
+ addToListOfSources(newSourceInfo)
+ }
+ }
+
+ class TalkToGuide(gInfo: SourceInfo)
+ extends Thread with Logging {
+ override def run: Unit = {
+
+      // Keep exchanging information until all blocks have been received
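+      // The sleep below picks a random interval in
+      // [MinKnockInterval, MaxKnockInterval) ms (defaults: 500 to 999) so
+      // that workers do not all knock on the Guide at the same instant.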
+ while (hasBlocks.get < totalBlocks) {
+ talkOnce
+ Thread.sleep(BitTorrentBroadcast.ranGen.nextInt(
+ Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) +
+ Broadcast.MinKnockInterval)
+ }
+
+ // Talk one more time to let the Guide know of reception completion
+ talkOnce
+ }
+
+ // Connect to Guide and send this worker's information
+ private def talkOnce: Unit = {
+ var clientSocketToGuide: Socket = null
+ var oosGuide: ObjectOutputStream = null
+ var oisGuide: ObjectInputStream = null
+
+ clientSocketToGuide = new Socket(gInfo.hostAddress, gInfo.listenPort)
+ oosGuide = new ObjectOutputStream(clientSocketToGuide.getOutputStream)
+ oosGuide.flush()
+ oisGuide = new ObjectInputStream(clientSocketToGuide.getInputStream)
+
+ // Send local information
+ oosGuide.writeObject(getLocalSourceInfo)
+ oosGuide.flush()
+
+ // Receive source information from Guide
+ var suitableSources =
+ oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]]
+ logInfo("Received suitableSources from Master " + suitableSources)
+
+ addToListOfSources(suitableSources)
+
+ oisGuide.close()
+ oosGuide.close()
+ clientSocketToGuide.close()
+ }
+ }
+
+ def getGuideInfo(variableUUID: UUID): SourceInfo = {
+ var clientSocketToTracker: Socket = null
+ var oosTracker: ObjectOutputStream = null
+ var oisTracker: ObjectInputStream = null
+
+ var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxOverGoToHDFS)
+
+ var retriesLeft = Broadcast.MaxRetryCount
+ do {
+ try {
+ // Connect to the tracker to find out GuideInfo
+ clientSocketToTracker =
+ new Socket(Broadcast.MasterHostAddress, Broadcast.MasterTrackerPort)
+ oosTracker =
+ new ObjectOutputStream(clientSocketToTracker.getOutputStream)
+ oosTracker.flush()
+ oisTracker =
+ new ObjectInputStream(clientSocketToTracker.getInputStream)
+
+ // Send messageType/intention
+ oosTracker.writeObject(Broadcast.FIND_BROADCAST_TRACKER)
+ oosTracker.flush()
+
+ // Send UUID and receive GuideInfo
+ oosTracker.writeObject(uuid)
+ oosTracker.flush()
+ gInfo = oisTracker.readObject.asInstanceOf[SourceInfo]
+ } catch {
+ case e: Exception => {
+ logInfo("getGuideInfo had a " + e)
+ }
+ } finally {
+ if (oisTracker != null) {
+ oisTracker.close()
+ }
+ if (oosTracker != null) {
+ oosTracker.close()
+ }
+ if (clientSocketToTracker != null) {
+ clientSocketToTracker.close()
+ }
+ }
+
+ Thread.sleep(BitTorrentBroadcast.ranGen.nextInt(
+ Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) +
+ Broadcast.MinKnockInterval)
+
+ retriesLeft -= 1
+ } while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry)
+
+ logInfo("Got this guidePort from Tracker: " + gInfo.listenPort)
+ return gInfo
+ }
+
+ def receiveBroadcast(variableUUID: UUID): Boolean = {
+ val gInfo = getGuideInfo(variableUUID)
+
+ if (gInfo.listenPort == SourceInfo.TxOverGoToHDFS ||
+ gInfo.listenPort == SourceInfo.TxNotStartedRetry) {
+ // TODO: SourceInfo.TxNotStartedRetry is not really in use because we go
+ // to HDFS anyway when receiveBroadcast returns false
+ return false
+ }
+
+ // Wait until hostAddress and listenPort are created by the
+ // ServeMultipleRequests thread
+ while (listenPort == -1) {
+ listenPortLock.synchronized {
+ listenPortLock.wait
+ }
+ }
+
+ // Setup initial states of variables
+ totalBlocks = gInfo.totalBlocks
+ arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
+ hasBlocksBitVector = new BitSet(totalBlocks)
+ numCopiesSent = new Array[Int](totalBlocks)
+ totalBlocksLock.synchronized {
+ totalBlocksLock.notifyAll
+ }
+ totalBytes = gInfo.totalBytes
+ blockSize = gInfo.blockSize
+
+ // Start ttGuide to periodically talk to the Guide
+ var ttGuide = new TalkToGuide(gInfo)
+ ttGuide.setDaemon(true)
+ ttGuide.start()
+ logInfo("TalkToGuide started...")
+
+    // Start pcController to run TalkToPeer threads
+ var pcController = new PeerChatterController
+ pcController.setDaemon(true)
+ pcController.start()
+ logInfo("PeerChatterController started...")
+
+    // FIXME: This loop might never exit if the broadcast fails. We should be
+    // able to break out and return false. Also need to kill the threads.
+ while (hasBlocks.get < totalBlocks) {
+ Thread.sleep(Broadcast.MaxKnockInterval)
+ }
+
+ return true
+ }
+
+ class PeerChatterController
+ extends Thread with Logging {
+ private var peersNowTalking = ListBuffer[SourceInfo]()
+ // TODO: There is a possible bug with blocksInRequestBitVector when a
+ // certain bit is NOT unset upon failure resulting in an infinite loop.
+ private var blocksInRequestBitVector = new BitSet(totalBlocks)
+
+ override def run: Unit = {
+ var threadPool = Utils.newDaemonFixedThreadPool(Broadcast.MaxRxSlots)
+
+ while (hasBlocks.get < totalBlocks) {
+ var numThreadsToCreate =
+ math.min(listOfSources.size, Broadcast.MaxRxSlots) -
+ threadPool.getActiveCount
+
+ while (hasBlocks.get < totalBlocks && numThreadsToCreate > 0) {
+ var peerToTalkTo = pickPeerToTalkToRandom
+
+ if (peerToTalkTo != null)
+ logInfo("Peer chosen: " + peerToTalkTo + " with " + peerToTalkTo.hasBlocksBitVector)
+ else
+ logInfo("No peer chosen...")
+
+ if (peerToTalkTo != null) {
+ threadPool.execute(new TalkToPeer(peerToTalkTo))
+
+ // Add to peersNowTalking. Remove in the thread. We have to do this
+ // ASAP, otherwise pickPeerToTalkTo picks the same peer more than once
+ peersNowTalking.synchronized {
+ peersNowTalking += peerToTalkTo
+ }
+ }
+
+ numThreadsToCreate = numThreadsToCreate - 1
+ }
+
+ // Sleep for a while before starting some more threads
+ Thread.sleep(Broadcast.MinKnockInterval)
+ }
+ // Shutdown the thread pool
+ threadPool.shutdown()
+ }
+
+ // Right now picking the one that has the most blocks this peer wants
+ // Also picking peer randomly if no one has anything interesting
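+    // Illustrative example: if this peer has blocks {0, 1} and a candidate
+    // has {0, 2, 3}, flipping our bit vector and AND-ing it with the
+    // candidate's leaves {2, 3}, so the candidate scores 2; the highest
+    // scorer wins.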
+ private def pickPeerToTalkToRandom: SourceInfo = {
+ var curPeer: SourceInfo = null
+ var curMax = 0
+
+ logInfo("Picking peers to talk to...")
+
+ // Find peers that are not connected right now
+ var peersNotInUse = ListBuffer[SourceInfo]()
+ listOfSources.synchronized {
+ peersNowTalking.synchronized {
+ peersNotInUse = listOfSources -- peersNowTalking
+ }
+ }
+
+ // Select the peer that has the most blocks that this receiver does not
+ peersNotInUse.foreach { eachSource =>
+ var tempHasBlocksBitVector: BitSet = null
+ hasBlocksBitVector.synchronized {
+ tempHasBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet]
+ }
+ tempHasBlocksBitVector.flip(0, tempHasBlocksBitVector.size)
+ tempHasBlocksBitVector.and(eachSource.hasBlocksBitVector)
+
+ if (tempHasBlocksBitVector.cardinality > curMax) {
+ curPeer = eachSource
+ curMax = tempHasBlocksBitVector.cardinality
+ }
+ }
+
+      // TODO: Always pick randomly, or pick randomly only some of the time?
+      // For now, always picking randomly
+ if (curPeer == null && peersNotInUse.size > 0) {
+ // Pick uniformly the i'th required peer
+ var i = BitTorrentBroadcast.ranGen.nextInt(peersNotInUse.size)
+
+ var peerIter = peersNotInUse.iterator
+ curPeer = peerIter.next
+
+ while (i > 0) {
+ curPeer = peerIter.next
+ i = i - 1
+ }
+ }
+
+ return curPeer
+ }
+
+ // Picking peer with the weight of rare blocks it has
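+    // Sketch of the weighting: with blocks counted as rare when they have at
+    // most 2 known copies, a peer holding 3 of 4 rare-block occurrences is
+    // picked with probability 3/4 by the cumulative-sum draw below.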
+ private def pickPeerToTalkToRarestFirst: SourceInfo = {
+ // Find peers that are not connected right now
+ var peersNotInUse = ListBuffer[SourceInfo]()
+ listOfSources.synchronized {
+ peersNowTalking.synchronized {
+ peersNotInUse = listOfSources -- peersNowTalking
+ }
+ }
+
+ // Count the number of copies of each block in the neighborhood
+ var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0)
+
+ listOfSources.synchronized {
+ listOfSources.foreach { eachSource =>
+ for (i <- 0 until totalBlocks) {
+ numCopiesPerBlock(i) +=
+ ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 )
+ }
+ }
+ }
+
+ // TODO: A block is rare if there are at most 2 copies of that block
+ // TODO: This CONSTANT could be a function of the neighborhood size
+ var rareBlocksIndices = ListBuffer[Int]()
+ for (i <- 0 until totalBlocks) {
+ if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) <= 2) {
+ rareBlocksIndices += i
+ }
+ }
+
+ // Find peers with rare blocks
+ var peersWithRareBlocks = ListBuffer[(SourceInfo, Int)]()
+ var totalRareBlocks = 0
+
+ peersNotInUse.foreach { eachPeer =>
+ var hasRareBlocks = 0
+ rareBlocksIndices.foreach { rareBlock =>
+ if (eachPeer.hasBlocksBitVector.get(rareBlock)) {
+ hasRareBlocks += 1
+ }
+ }
+
+ if (hasRareBlocks > 0) {
+ peersWithRareBlocks += ((eachPeer, hasRareBlocks))
+ }
+ totalRareBlocks += hasRareBlocks
+ }
+
+ // Select a peer from peersWithRareBlocks based on weight calculated from
+ // unique rare blocks
+ var selectedPeerToTalkTo: SourceInfo = null
+
+ if (peersWithRareBlocks.size > 0) {
+ // Sort the peers based on how many rare blocks they have
+        peersWithRareBlocks = peersWithRareBlocks.sortBy(_._2)
+
+ var randomNumber = BitTorrentBroadcast.ranGen.nextDouble
+ var tempSum = 0.0
+
+ var i = 0
+ do {
+ tempSum += (1.0 * peersWithRareBlocks(i)._2 / totalRareBlocks)
+ if (tempSum >= randomNumber) {
+ selectedPeerToTalkTo = peersWithRareBlocks(i)._1
+ }
+ i += 1
+ } while (i < peersWithRareBlocks.size && selectedPeerToTalkTo == null)
+ }
+
+ if (selectedPeerToTalkTo == null) {
+ selectedPeerToTalkTo = pickPeerToTalkToRandom
+ }
+
+ return selectedPeerToTalkTo
+ }
+
+ class TalkToPeer(peerToTalkTo: SourceInfo)
+ extends Thread with Logging {
+ private var peerSocketToSource: Socket = null
+ private var oosSource: ObjectOutputStream = null
+ private var oisSource: ObjectInputStream = null
+
+ override def run: Unit = {
+ // TODO: There is a possible bug here regarding blocksInRequestBitVector
+ var blockToAskFor = -1
+
+ // Setup the timeout mechanism
+ var timeOutTask = new TimerTask {
+ override def run: Unit = {
+ cleanUpConnections()
+ }
+ }
+
+ var timeOutTimer = new Timer
+ timeOutTimer.schedule(timeOutTask, Broadcast.MaxKnockInterval)
+
+ logInfo("TalkToPeer started... => " + peerToTalkTo)
+
+ try {
+ // Connect to the source
+ peerSocketToSource =
+ new Socket(peerToTalkTo.hostAddress, peerToTalkTo.listenPort)
+ oosSource =
+ new ObjectOutputStream(peerSocketToSource.getOutputStream)
+ oosSource.flush()
+ oisSource =
+ new ObjectInputStream(peerSocketToSource.getInputStream)
+
+ // Receive latest SourceInfo from peerToTalkTo
+ var newPeerToTalkTo = oisSource.readObject.asInstanceOf[SourceInfo]
+ // Update listOfSources
+ addToListOfSources(newPeerToTalkTo)
+
+ // Turn the timer OFF, if the sender responds before timeout
+ timeOutTimer.cancel
+
+ // Send the latest SourceInfo
+ oosSource.writeObject(getLocalSourceInfo)
+ oosSource.flush()
+
+ var keepReceiving = true
+
+ while (hasBlocks.get < totalBlocks && keepReceiving) {
+ blockToAskFor =
+ pickBlockRandom(newPeerToTalkTo.hasBlocksBitVector)
+
+ // No block to request
+ if (blockToAskFor < 0) {
+ // Nothing to receive from newPeerToTalkTo
+ keepReceiving = false
+ } else {
+ // Let other threads know that blockToAskFor is being requested
+ blocksInRequestBitVector.synchronized {
+ blocksInRequestBitVector.set(blockToAskFor)
+ }
+
+ // Start with sending the blockID
+ oosSource.writeObject(blockToAskFor)
+ oosSource.flush()
+
+ // CHANGED: Master might send some other block than the one
+ // requested to ensure fast spreading of all blocks.
+ val recvStartTime = System.currentTimeMillis
+ val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
+ val receptionTime = (System.currentTimeMillis - recvStartTime)
+
+ logInfo("Received block: " + bcBlock.blockID + " from " + peerToTalkTo + " in " + receptionTime + " millis.")
+
+ if (!hasBlocksBitVector.get(bcBlock.blockID)) {
+ arrayOfBlocks(bcBlock.blockID) = bcBlock
+
+ // Update the hasBlocksBitVector first
+ hasBlocksBitVector.synchronized {
+ hasBlocksBitVector.set(bcBlock.blockID)
+ hasBlocks.getAndIncrement
+ }
+
+ rxSpeeds.addDataPoint(peerToTalkTo, receptionTime)
+
+            // Some block (may NOT be blockToAskFor) has arrived.
+ // In any case, blockToAskFor is not in request any more
+ blocksInRequestBitVector.synchronized {
+ blocksInRequestBitVector.set(blockToAskFor, false)
+ }
+
+ // Reset blockToAskFor to -1. Else it will be considered missing
+ blockToAskFor = -1
+ }
+
+ // Send the latest SourceInfo
+ oosSource.writeObject(getLocalSourceInfo)
+ oosSource.flush()
+ }
+ }
+ } catch {
+ // EOFException is expected to happen because sender can break
+ // connection due to timeout
+ case eofe: java.io.EOFException => { }
+ case e: Exception => {
+          logInfo("TalkToPeer had a " + e)
+ // FIXME: Remove 'newPeerToTalkTo' from listOfSources
+ // We probably should have the following in some form, but not
+ // really here. This exception can happen if the sender just breaks connection
+ // listOfSources.synchronized {
+ // logInfo("Exception in TalkToPeer. Removing source: " + peerToTalkTo)
+ // listOfSources = listOfSources - peerToTalkTo
+ // }
+ }
+ } finally {
+ // blockToAskFor != -1 => there was an exception
+ if (blockToAskFor != -1) {
+ blocksInRequestBitVector.synchronized {
+ blocksInRequestBitVector.set(blockToAskFor, false)
+ }
+ }
+
+ cleanUpConnections()
+ }
+ }
+
+ // Right now it picks a block uniformly that this peer does not have
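+      // Illustrative example: if this peer still needs blocks {2, 5, 7} and
+      // the sender has {2, 7}, one of {2, 7} is chosen uniformly at random.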
+ private def pickBlockRandom(txHasBlocksBitVector: BitSet): Int = {
+ var needBlocksBitVector: BitSet = null
+
+ // Blocks already present
+ hasBlocksBitVector.synchronized {
+ needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet]
+ }
+
+ // Include blocks already in transmission ONLY IF
+ // BitTorrentBroadcast.EndGameFraction has NOT been achieved
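+      // For example, with the default EndGameFraction of 0.95 and 100 blocks,
+      // once 95 blocks have arrived the in-flight blocks become requestable
+      // again, so a single slow sender cannot stall the last few blocks.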
+ if ((1.0 * hasBlocks.get / totalBlocks) < Broadcast.EndGameFraction) {
+ blocksInRequestBitVector.synchronized {
+ needBlocksBitVector.or(blocksInRequestBitVector)
+ }
+ }
+
+ // Find blocks that are neither here nor in transit
+ needBlocksBitVector.flip(0, needBlocksBitVector.size)
+
+ // Blocks that should/can be requested
+ needBlocksBitVector.and(txHasBlocksBitVector)
+
+ if (needBlocksBitVector.cardinality == 0) {
+ return -1
+ } else {
+ // Pick uniformly the i'th required block
+ var i = BitTorrentBroadcast.ranGen.nextInt(needBlocksBitVector.cardinality)
+ var pickedBlockIndex = needBlocksBitVector.nextSetBit(0)
+
+ while (i > 0) {
+ pickedBlockIndex =
+ needBlocksBitVector.nextSetBit(pickedBlockIndex + 1)
+ i -= 1
+ }
+
+ return pickedBlockIndex
+ }
+ }
+
+ // Pick the block that seems to be the rarest across sources
+ private def pickBlockRarestFirst(txHasBlocksBitVector: BitSet): Int = {
+ var needBlocksBitVector: BitSet = null
+
+ // Blocks already present
+ hasBlocksBitVector.synchronized {
+ needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet]
+ }
+
+ // Include blocks already in transmission ONLY IF
+ // BitTorrentBroadcast.EndGameFraction has NOT been achieved
+ if ((1.0 * hasBlocks.get / totalBlocks) < Broadcast.EndGameFraction) {
+ blocksInRequestBitVector.synchronized {
+ needBlocksBitVector.or(blocksInRequestBitVector)
+ }
+ }
+
+ // Find blocks that are neither here nor in transit
+ needBlocksBitVector.flip(0, needBlocksBitVector.size)
+
+ // Blocks that should/can be requested
+ needBlocksBitVector.and(txHasBlocksBitVector)
+
+ if (needBlocksBitVector.cardinality == 0) {
+ return -1
+ } else {
+ // Count the number of copies for each block across all sources
+ var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0)
+
+ listOfSources.synchronized {
+ listOfSources.foreach { eachSource =>
+ for (i <- 0 until totalBlocks) {
+ numCopiesPerBlock(i) +=
+ ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 )
+ }
+ }
+ }
+
+ // Find the minimum
+ var minVal = Integer.MAX_VALUE
+ for (i <- 0 until totalBlocks) {
+ if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) < minVal) {
+ minVal = numCopiesPerBlock(i)
+ }
+ }
+
+ // Find the blocks with the least copies that this peer does not have
+ var minBlocksIndices = ListBuffer[Int]()
+ for (i <- 0 until totalBlocks) {
+ if (needBlocksBitVector.get(i) && numCopiesPerBlock(i) == minVal) {
+ minBlocksIndices += i
+ }
+ }
+
+ // Now select a random index from minBlocksIndices
+ if (minBlocksIndices.size == 0) {
+ return -1
+ } else {
+ // Pick uniformly the i'th index
+ var i = BitTorrentBroadcast.ranGen.nextInt(minBlocksIndices.size)
+ return minBlocksIndices(i)
+ }
+ }
+ }
+
+ private def cleanUpConnections(): Unit = {
+ if (oisSource != null) {
+ oisSource.close()
+ }
+ if (oosSource != null) {
+ oosSource.close()
+ }
+ if (peerSocketToSource != null) {
+ peerSocketToSource.close()
+ }
+
+ // Delete from peersNowTalking
+ peersNowTalking.synchronized {
+ peersNowTalking = peersNowTalking - peerToTalkTo
+ }
+ }
+ }
+ }
+
+ class GuideMultipleRequests
+ extends Thread with Logging {
+ // Keep track of sources that have completed reception
+ private var setOfCompletedSources = Set[SourceInfo]()
+
+ override def run: Unit = {
+ var threadPool = Utils.newDaemonCachedThreadPool()
+ var serverSocket: ServerSocket = null
+
+ serverSocket = new ServerSocket(0)
+ guidePort = serverSocket.getLocalPort
+ logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
+
+ guidePortLock.synchronized {
+ guidePortLock.notifyAll
+ }
+
+ try {
+ // Don't stop until there is a copy in HDFS
+ while (!stopBroadcast || !hasCopyInHDFS) {
+ var clientSocket: Socket = null
+ try {
+ serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
+ clientSocket = serverSocket.accept()
+ } catch {
+ case e: Exception => {
+ logInfo("GuideMultipleRequests Timeout.")
+
+ // Stop broadcast if at least one worker has connected and
+ // everyone connected so far are done. Comparing with
+ // listOfSources.size - 1, because it includes the Guide itself
+ if (listOfSources.size > 1 &&
+ setOfCompletedSources.size == listOfSources.size - 1) {
+ stopBroadcast = true
+ }
+ }
+ }
+ if (clientSocket != null) {
+            logInfo("Guide: Accepted new client connection: " + clientSocket)
+ try {
+ threadPool.execute(new GuideSingleRequest(clientSocket))
+ } catch {
+ // In failure, close the socket here; else, thread will close it
+ case ioe: IOException => {
+ clientSocket.close()
+ }
+ }
+ }
+ }
+
+ // Shutdown the thread pool
+ threadPool.shutdown()
+
+ logInfo("Sending stopBroadcast notifications...")
+ sendStopBroadcastNotifications
+
+ unregisterBroadcast(uuid)
+ } finally {
+ if (serverSocket != null) {
+ logInfo("GuideMultipleRequests now stopping...")
+ serverSocket.close()
+ }
+ }
+ }
+
+ private def sendStopBroadcastNotifications: Unit = {
+ listOfSources.synchronized {
+ listOfSources.foreach { sourceInfo =>
+
+ var guideSocketToSource: Socket = null
+ var gosSource: ObjectOutputStream = null
+ var gisSource: ObjectInputStream = null
+
+ try {
+ // Connect to the source
+ guideSocketToSource =
+ new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
+ gosSource =
+ new ObjectOutputStream(guideSocketToSource.getOutputStream)
+ gosSource.flush()
+ gisSource =
+ new ObjectInputStream(guideSocketToSource.getInputStream)
+
+ // Throw away whatever comes in
+ gisSource.readObject.asInstanceOf[SourceInfo]
+
+ // Send stopBroadcast signal. listenPort = SourceInfo.StopBroadcast
+ gosSource.writeObject(SourceInfo("", SourceInfo.StopBroadcast))
+ gosSource.flush()
+ } catch {
+ case e: Exception => {
+ logInfo("sendStopBroadcastNotifications had a " + e)
+ }
+ } finally {
+ if (gisSource != null) {
+ gisSource.close()
+ }
+ if (gosSource != null) {
+ gosSource.close()
+ }
+ if (guideSocketToSource != null) {
+ guideSocketToSource.close()
+ }
+ }
+ }
+ }
+ }
+
+ class GuideSingleRequest(val clientSocket: Socket)
+ extends Thread with Logging {
+ private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
+ oos.flush()
+ private val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+ private var sourceInfo: SourceInfo = null
+ private var selectedSources: ListBuffer[SourceInfo] = null
+
+ override def run: Unit = {
+ try {
+ logInfo("new GuideSingleRequest is running")
+ // Connecting worker is sending in its information
+ sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
+
+ // Select a suitable source and send it back to the worker
+ selectedSources = selectSuitableSources(sourceInfo)
+          logInfo("Sending selectedSources: " + selectedSources)
+ oos.writeObject(selectedSources)
+ oos.flush()
+
+ // Add this source to the listOfSources
+ addToListOfSources(sourceInfo)
+ } catch {
+ case e: Exception => {
+ // Assuming exception caused by receiver failure: remove
+ if (listOfSources != null) {
+ listOfSources.synchronized {
+ listOfSources = listOfSources - sourceInfo
+ }
+ }
+ }
+ } finally {
+ ois.close()
+ oos.close()
+ clientSocket.close()
+ }
+ }
+
+ // Randomly select some sources to send back
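+    // E.g., with the default MaxPeersInGuideResponse of 4 and 10 known
+    // sources, 4 distinct sources are sampled uniformly without replacement.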
+ private def selectSuitableSources(skipSourceInfo: SourceInfo): ListBuffer[SourceInfo] = {
+ var selectedSources = ListBuffer[SourceInfo]()
+
+      // If skipSourceInfo.hasBlocksBitVector has all bits set to 'true',
+      // add skipSourceInfo to setOfCompletedSources and return an empty list.
+ if (skipSourceInfo.hasBlocks == totalBlocks) {
+ setOfCompletedSources.synchronized {
+ setOfCompletedSources += skipSourceInfo
+ }
+ return selectedSources
+ }
+
+ listOfSources.synchronized {
+ if (listOfSources.size <= Broadcast.MaxPeersInGuideResponse) {
+ selectedSources = listOfSources.clone
+ } else {
+ var picksLeft = Broadcast.MaxPeersInGuideResponse
+ var alreadyPicked = new BitSet(listOfSources.size)
+
+ while (picksLeft > 0) {
+ var i = -1
+
+ do {
+ i = BitTorrentBroadcast.ranGen.nextInt(listOfSources.size)
+ } while (alreadyPicked.get(i))
+
+ var peerIter = listOfSources.iterator
+ var curPeer = peerIter.next
+
+ // Set the BitSet before i is decremented
+ alreadyPicked.set(i)
+
+ while (i > 0) {
+ curPeer = peerIter.next
+ i = i - 1
+ }
+
+ selectedSources += curPeer
+
+ picksLeft = picksLeft - 1
+ }
+ }
+ }
+
+ // Remove the receiving source (if present)
+ selectedSources = selectedSources - skipSourceInfo
+
+ return selectedSources
+ }
+ }
+ }
+
+ class ServeMultipleRequests
+ extends Thread with Logging {
+    // Serve at most Broadcast.MaxTxSlots peers at a time
+ var threadPool = Utils.newDaemonFixedThreadPool(Broadcast.MaxTxSlots)
+
+ override def run: Unit = {
+ var serverSocket = new ServerSocket(0)
+ listenPort = serverSocket.getLocalPort
+
+ logInfo("ServeMultipleRequests started with " + serverSocket)
+
+ listenPortLock.synchronized {
+ listenPortLock.notifyAll
+ }
+
+ try {
+ while (!stopBroadcast) {
+ var clientSocket: Socket = null
+ try {
+ serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
+ clientSocket = serverSocket.accept()
+ } catch {
+ case e: Exception => {
+ logInfo("ServeMultipleRequests Timeout.")
+ }
+ }
+ if (clientSocket != null) {
+          logInfo("Serve: Accepted new client connection: " + clientSocket)
+ try {
+ threadPool.execute(new ServeSingleRequest(clientSocket))
+ } catch {
+ // In failure, close socket here; else, the thread will close it
+ case ioe: IOException => {
+ clientSocket.close()
+ }
+ }
+ }
+ }
+ } finally {
+ if (serverSocket != null) {
+ logInfo("ServeMultipleRequests now stopping...")
+ serverSocket.close()
+ }
+ }
+ // Shutdown the thread pool
+ threadPool.shutdown()
+ }
+
+ class ServeSingleRequest(val clientSocket: Socket)
+ extends Thread with Logging {
+ private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
+ oos.flush()
+ private val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+ logInfo("new ServeSingleRequest is running")
+
+ override def run: Unit = {
+ try {
+ // Send latest local SourceInfo to the receiver
+ // In the case of receiver timeout and connection close, this will
+ // throw a java.net.SocketException: Broken pipe
+ oos.writeObject(getLocalSourceInfo)
+ oos.flush()
+
+ // Receive latest SourceInfo from the receiver
+ var rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo]
+
+ if (rxSourceInfo.listenPort == SourceInfo.StopBroadcast) {
+ stopBroadcast = true
+ } else {
+ // Carry on
+ addToListOfSources(rxSourceInfo)
+ }
+
+ val startTime = System.currentTimeMillis
+ var curTime = startTime
+ var keepSending = true
+ var numBlocksToSend = Broadcast.MaxChatBlocks
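+        // One "chat" session serves at most MaxChatBlocks blocks (default
+        // 1024); after MaxChatTime ms (default 500), the loop below ends the
+        // session if other receivers are queued for a transmission slot.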
+
+ while (!stopBroadcast && keepSending && numBlocksToSend > 0) {
+ // Receive which block to send
+ var blockToSend = ois.readObject.asInstanceOf[Int]
+
+ // If it is master AND at least one copy of each block has not been
+ // sent out already, MODIFY blockToSend
+ if (BitTorrentBroadcast.isMaster && sentBlocks.get < totalBlocks) {
+ blockToSend = sentBlocks.getAndIncrement
+ }
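+          // The override above makes the Master push blocks out in index
+          // order until each block has been sent at least once, so every
+          // block enters the swarm before duplicates are served.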
+
+ // Send the block
+ sendBlock(blockToSend)
+ rxSourceInfo.hasBlocksBitVector.set(blockToSend)
+
+ numBlocksToSend -= 1
+
+ // Receive latest SourceInfo from the receiver
+ rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo]
+ // logInfo("rxSourceInfo: " + rxSourceInfo + " with " + rxSourceInfo.hasBlocksBitVector)
+ addToListOfSources(rxSourceInfo)
+
+ curTime = System.currentTimeMillis
+          // Stop sending only if someone is waiting in the queue
+ if (curTime - startTime >= Broadcast.MaxChatTime &&
+ threadPool.getQueue.size > 0) {
+ keepSending = false
+ }
+ }
+ } catch {
+ // If something went wrong, e.g., the worker at the other end died etc.
+ // then close everything up
+ // Exception can happen if the receiver stops receiving
+ case e: Exception => {
+ logInfo("ServeSingleRequest had a " + e)
+ }
+ } finally {
+ logInfo("ServeSingleRequest is closing streams and sockets")
+ ois.close()
+ // TODO: The following line causes a "java.net.SocketException: Socket closed"
+ oos.close()
+ clientSocket.close()
+ }
+ }
+
+ private def sendBlock(blockToSend: Int): Unit = {
+ try {
+ oos.writeObject(arrayOfBlocks(blockToSend))
+ oos.flush()
+ } catch {
+ case e: Exception => {
+ logInfo("sendBlock had a " + e)
+ }
+ }
+ logInfo("Sent block: " + blockToSend + " to " + clientSocket)
+ }
+ }
+ }
+}
+
+class BitTorrentBroadcastFactory
+extends BroadcastFactory {
+ def initialize(isMaster: Boolean) = {
+ BitTorrentBroadcast.initialize(isMaster)
+ }
+
+ def newBroadcast[T](value_ : T, isLocal: Boolean) =
+ new BitTorrentBroadcast[T](value_, isLocal)
+}
+
+private object BitTorrentBroadcast
+extends Logging {
+ val values = SparkEnv.get.cache.newKeySpace()
+
+ var valueToGuideMap = Map[UUID, SourceInfo]()
+
+ // Random number generator
+ var ranGen = new Random
+
+ private var initialized = false
+ private var isMaster_ = false
+
+ private var trackMV: TrackMultipleValues = null
+
+ def initialize(isMaster__ : Boolean): Unit = {
+ synchronized {
+ if (!initialized) {
+
+ isMaster_ = isMaster__
+
+ if (isMaster) {
+ trackMV = new TrackMultipleValues
+ trackMV.setDaemon(true)
+ trackMV.start()
+        // TODO: Logging the following line prevents the Spark framework ID
+        // from being logged, because it calls logInfo before log4j is initialized
+ logInfo("TrackMultipleValues started...")
+ }
+
+ // Initialize DfsBroadcast to be used for broadcast variable persistence
+ // TODO: Think about persistence
+ DfsBroadcast.initialize
+
+ initialized = true
+ }
+ }
+ }
+
+ def isMaster = isMaster_
+
+ class TrackMultipleValues
+ extends Thread with Logging {
+ override def run: Unit = {
+ var threadPool = Utils.newDaemonCachedThreadPool()
+ var serverSocket: ServerSocket = null
+
+ serverSocket = new ServerSocket(Broadcast.MasterTrackerPort)
+      logInfo("TrackMultipleValues " + serverSocket)
+
+ try {
+ while (true) {
+ var clientSocket: Socket = null
+ try {
+ serverSocket.setSoTimeout(Broadcast.TrackerSocketTimeout)
+ clientSocket = serverSocket.accept()
+ } catch {
+ case e: Exception => {
+ logInfo("TrackMultipleValues Timeout. Stopping listening...")
+ }
+ }
+
+ if (clientSocket != null) {
+ try {
+ threadPool.execute(new Thread {
+ override def run: Unit = {
+ val oos = new ObjectOutputStream(clientSocket.getOutputStream)
+ oos.flush()
+ val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+ try {
+ // First, read message type
+ val messageType = ois.readObject.asInstanceOf[Int]
+
+ if (messageType == Broadcast.REGISTER_BROADCAST_TRACKER) {
+ // Receive UUID
+ val uuid = ois.readObject.asInstanceOf[UUID]
+ // Receive hostAddress and listenPort
+ val gInfo = ois.readObject.asInstanceOf[SourceInfo]
+
+ // Add to the map
+ valueToGuideMap.synchronized {
+ valueToGuideMap += (uuid -> gInfo)
+ }
+
+ logInfo ("New broadcast registered with TrackMultipleValues " + uuid + " " + valueToGuideMap)
+
+ // Send dummy ACK
+ oos.writeObject(-1)
+ oos.flush()
+ } else if (messageType == Broadcast.UNREGISTER_BROADCAST_TRACKER) {
+ // Receive UUID
+ val uuid = ois.readObject.asInstanceOf[UUID]
+
+ // Remove from the map
+ valueToGuideMap.synchronized {
+ valueToGuideMap(uuid) = SourceInfo("", SourceInfo.TxOverGoToHDFS)
+ logInfo("Value unregistered from the Tracker " + valueToGuideMap)
+ }
+
+ logInfo ("Broadcast unregistered from TrackMultipleValues " + uuid + " " + valueToGuideMap)
+
+ // Send dummy ACK
+ oos.writeObject(-1)
+ oos.flush()
+ } else if (messageType == Broadcast.FIND_BROADCAST_TRACKER) {
+ // Receive UUID
+ val uuid = ois.readObject.asInstanceOf[UUID]
+
+ var gInfo =
+ if (valueToGuideMap.contains(uuid)) valueToGuideMap(uuid)
+ else SourceInfo("", SourceInfo.TxNotStartedRetry)
+
+ logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + gInfo.listenPort)
+
+ // Send reply back
+ oos.writeObject(gInfo)
+ oos.flush()
+ } else if (messageType == Broadcast.GET_UPDATED_SHARE) {
+ // TODO: Not implemented
+ } else {
+ throw new SparkException("Undefined messageType at TrackMultipleValues")
+ }
+ } catch {
+ case e: Exception => {
+ logInfo("TrackMultipleValues had a " + e)
+ }
+ } finally {
+ ois.close()
+ oos.close()
+ clientSocket.close()
+ }
+ }
+ })
+ } catch {
+ // In failure, close socket here; else, client thread will close
+ case ioe: IOException => {
+ clientSocket.close()
+ }
+ }
+ }
+ }
+ } finally {
+ serverSocket.close()
+ }
+ // Shutdown the thread pool
+ threadPool.shutdown()
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala
new file mode 100644
index 0000000000..f39fb9de69
--- /dev/null
+++ b/core/src/main/scala/spark/broadcast/Broadcast.scala
@@ -0,0 +1,228 @@
+package spark.broadcast
+
+import java.io._
+import java.net._
+import java.util.{BitSet, UUID}
+import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
+
+import spark._
+
+@serializable
+trait Broadcast[T] {
+ val uuid = UUID.randomUUID
+
+ def value: T
+
+ // We cannot have an abstract readObject here due to some weird issues with
+ // readObject having to be 'private' in sub-classes. Possibly a Scala bug!
+
+ override def toString = "spark.Broadcast(" + uuid + ")"
+}
+
+object Broadcast
+extends Logging {
+ // Messages
+ val REGISTER_BROADCAST_TRACKER = 0
+ val UNREGISTER_BROADCAST_TRACKER = 1
+ val FIND_BROADCAST_TRACKER = 2
+ val GET_UPDATED_SHARE = 3
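+
+  // Tracker wire protocol, as implemented by TrackMultipleValues: a client
+  // writes one of the message codes above, then the broadcast UUID, then any
+  // payload (a SourceInfo when registering), and reads back either a
+  // SourceInfo (FIND_BROADCAST_TRACKER) or a dummy Int ACK otherwise.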
+
+ private var initialized = false
+ private var isMaster_ = false
+ private var broadcastFactory: BroadcastFactory = null
+
+ // Called by SparkContext or Executor before using Broadcast
+ def initialize (isMaster__ : Boolean): Unit = synchronized {
+ if (!initialized) {
+ val broadcastFactoryClass = System.getProperty(
+ "spark.broadcast.factory", "spark.broadcast.DfsBroadcastFactory")
+
+ broadcastFactory =
+ Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
+
+ // Setup isMaster before using it
+ isMaster_ = isMaster__
+
+ // Set masterHostAddress to the master's IP address for the slaves to read
+ if (isMaster) {
+ System.setProperty("spark.broadcast.masterHostAddress", Utils.localIpAddress)
+ }
+
+ // Initialize appropriate BroadcastFactory and BroadcastObject
+ broadcastFactory.initialize(isMaster)
+
+ initialized = true
+ }
+ }
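+
+  // Minimal usage sketch (assuming the property is set before the
+  // SparkContext is created, since SparkContext triggers initialize):
+  //
+  //   System.setProperty("spark.broadcast.factory",
+  //     "spark.broadcast.BitTorrentBroadcastFactory")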
+
+ def getBroadcastFactory: BroadcastFactory = {
+ if (broadcastFactory == null) {
+ throw new SparkException ("Broadcast.getBroadcastFactory called before initialize")
+ }
+ broadcastFactory
+ }
+
+ // Load common broadcast-related config parameters
+ private var MasterHostAddress_ = System.getProperty(
+ "spark.broadcast.masterHostAddress", "")
+ private var MasterTrackerPort_ = System.getProperty(
+ "spark.broadcast.masterTrackerPort", "11111").toInt
+ private var BlockSize_ = System.getProperty(
+ "spark.broadcast.blockSize", "4096").toInt * 1024
+ private var MaxRetryCount_ = System.getProperty(
+ "spark.broadcast.maxRetryCount", "2").toInt
+
+ private var TrackerSocketTimeout_ = System.getProperty(
+ "spark.broadcast.trackerSocketTimeout", "50000").toInt
+ private var ServerSocketTimeout_ = System.getProperty(
+ "spark.broadcast.serverSocketTimeout", "10000").toInt
+
+ private var MinKnockInterval_ = System.getProperty(
+ "spark.broadcast.minKnockInterval", "500").toInt
+ private var MaxKnockInterval_ = System.getProperty(
+ "spark.broadcast.maxKnockInterval", "999").toInt
+
+ // Load ChainedBroadcast config params
+
+ // Load TreeBroadcast config params
+ private var MaxDegree_ = System.getProperty("spark.broadcast.maxDegree", "2").toInt
+
+ // Load BitTorrentBroadcast config params
+ private var MaxPeersInGuideResponse_ = System.getProperty(
+ "spark.broadcast.maxPeersInGuideResponse", "4").toInt
+
+ private var MaxRxSlots_ = System.getProperty(
+ "spark.broadcast.maxRxSlots", "4").toInt
+ private var MaxTxSlots_ = System.getProperty(
+ "spark.broadcast.maxTxSlots", "4").toInt
+
+ private var MaxChatTime_ = System.getProperty(
+ "spark.broadcast.maxChatTime", "500").toInt
+ private var MaxChatBlocks_ = System.getProperty(
+ "spark.broadcast.maxChatBlocks", "1024").toInt
+
+ private var EndGameFraction_ = System.getProperty(
+ "spark.broadcast.endGameFraction", "0.95").toDouble
+
+ def isMaster = isMaster_
+
+ // Common config params
+ def MasterHostAddress = MasterHostAddress_
+ def MasterTrackerPort = MasterTrackerPort_
+ def BlockSize = BlockSize_
+ def MaxRetryCount = MaxRetryCount_
+
+ def TrackerSocketTimeout = TrackerSocketTimeout_
+ def ServerSocketTimeout = ServerSocketTimeout_
+
+ def MinKnockInterval = MinKnockInterval_
+ def MaxKnockInterval = MaxKnockInterval_
+
+ // ChainedBroadcast configs
+
+ // TreeBroadcast configs
+ def MaxDegree = MaxDegree_
+
+ // BitTorrentBroadcast configs
+ def MaxPeersInGuideResponse = MaxPeersInGuideResponse_
+
+ def MaxRxSlots = MaxRxSlots_
+ def MaxTxSlots = MaxTxSlots_
+
+ def MaxChatTime = MaxChatTime_
+ def MaxChatBlocks = MaxChatBlocks_
+
+ def EndGameFraction = EndGameFraction_
+
+ // Helper functions to convert an object to Array[BroadcastBlock]
+ def blockifyObject[IN](obj: IN): VariableInfo = {
+ val baos = new ByteArrayOutputStream
+ val oos = new ObjectOutputStream(baos)
+ oos.writeObject(obj)
+ oos.close()
+ baos.close()
+ val byteArray = baos.toByteArray
+ val bais = new ByteArrayInputStream(byteArray)
+
+ var blockNum = (byteArray.length / Broadcast.BlockSize)
+ if (byteArray.length % Broadcast.BlockSize != 0)
+ blockNum += 1
+
+ var retVal = new Array[BroadcastBlock](blockNum)
+ var blockID = 0
+
+ for (i <- 0 until (byteArray.length, Broadcast.BlockSize)) {
+ val thisBlockSize = math.min(Broadcast.BlockSize, byteArray.length - i)
+      val tempByteArray = new Array[Byte](thisBlockSize)
+      bais.read(tempByteArray, 0, thisBlockSize)
+
+ retVal(blockID) = new BroadcastBlock(blockID, tempByteArray)
+ blockID += 1
+ }
+ bais.close()
+
+ var variableInfo = VariableInfo(retVal, blockNum, byteArray.length)
+ variableInfo.hasBlocks = blockNum
+
+ return variableInfo
+ }
+
+ // Helper function to convert Array[BroadcastBlock] to object
+ def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock],
+ totalBytes: Int,
+ totalBlocks: Int): OUT = {
+
+ var retByteArray = new Array[Byte](totalBytes)
+ for (i <- 0 until totalBlocks) {
+ System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
+ i * Broadcast.BlockSize, arrayOfBlocks(i).byteArray.length)
+ }
+ byteArrayToObject(retByteArray)
+ }
+
+ private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = {
+ val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){
+ override def resolveClass(desc: ObjectStreamClass) =
+ Class.forName(desc.getName, false, currentThread.getContextClassLoader)
+ }
+ val retVal = in.readObject.asInstanceOf[OUT]
+ in.close()
+ return retVal
+ }
+}
+
+@serializable
+case class BroadcastBlock (val blockID: Int, val byteArray: Array[Byte]) { }
+
+@serializable
+case class VariableInfo (@transient val arrayOfBlocks : Array[BroadcastBlock],
+ val totalBlocks: Int,
+ val totalBytes: Int) {
+ @transient var hasBlocks = 0
+}
+
+@serializable
+class SpeedTracker {
+ // Mapping 'source' to '(totalTime, numBlocks)'
+ private var sourceToSpeedMap = Map[SourceInfo, (Long, Int)] ()
+
+ def addDataPoint (srcInfo: SourceInfo, timeInMillis: Long): Unit = {
+ sourceToSpeedMap.synchronized {
+ if (!sourceToSpeedMap.contains(srcInfo)) {
+ sourceToSpeedMap += (srcInfo -> (timeInMillis, 1))
+ } else {
+ val tTnB = sourceToSpeedMap (srcInfo)
+ sourceToSpeedMap += (srcInfo -> (tTnB._1 + timeInMillis, tTnB._2 + 1))
+ }
+ }
+ }
+
+ def getTimePerBlock (srcInfo: SourceInfo): Double = {
+ sourceToSpeedMap.synchronized {
+ val tTnB = sourceToSpeedMap (srcInfo)
+      return tTnB._1.toDouble / tTnB._2
+ }
+ }
+
+ override def toString = sourceToSpeedMap.toString
+} \ No newline at end of file
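For illustration (not part of this patch), blockifyObject serializes a value, slices the bytes into BlockSize-sized BroadcastBlocks, and records the totals in a VariableInfo; unBlockifyObject reverses the process. A quick round-trip sketch using only the helpers defined above:

    // With the default BlockSize (4096 KB) a small array fits in one block.
    val payload = Array.fill(10)(42)
    val info: VariableInfo = Broadcast.blockifyObject(payload)
    assert(info.totalBlocks == info.arrayOfBlocks.length)
    val back = Broadcast.unBlockifyObject[Array[Int]](
      info.arrayOfBlocks, info.totalBytes, info.totalBlocks)
    assert(back.sameElements(payload))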
diff --git a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala
new file mode 100644
index 0000000000..341746d18e
--- /dev/null
+++ b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala
@@ -0,0 +1,12 @@
+package spark.broadcast
+
+/**
+ * An interface for all the broadcast implementations in Spark (to allow
+ * multiple broadcast implementations). SparkContext uses a user-specified
+ * BroadcastFactory implementation to instantiate a particular broadcast for the
+ * entire Spark job.
+ */
+trait BroadcastFactory {
+ def initialize (isMaster: Boolean): Unit
+ def newBroadcast[T] (value_ : T, isLocal: Boolean): Broadcast[T]
+} \ No newline at end of file
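For illustration (not part of this patch), the concrete factory is chosen by the spark.broadcast.factory system property that Broadcast.initialize reads, so switching implementations is a one-line setting made before the SparkContext or Executor starts:

    // The class name is loaded reflectively by Broadcast.initialize;
    // DfsBroadcastFactory is the default if the property is unset.
    System.setProperty("spark.broadcast.factory",
                       "spark.broadcast.TreeBroadcastFactory")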
diff --git a/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala b/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala
new file mode 100644
index 0000000000..3afe923bae
--- /dev/null
+++ b/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala
@@ -0,0 +1,792 @@
+package spark.broadcast
+
+import java.io._
+import java.net._
+import java.util.{Comparator, PriorityQueue, Random, UUID}
+
+import scala.collection.mutable.{Map, Set}
+import scala.math
+
+import spark._
+
+@serializable
+class ChainedBroadcast[T](@transient var value_ : T, isLocal: Boolean)
+extends Broadcast[T] with Logging {
+
+ def value = value_
+
+ ChainedBroadcast.synchronized {
+ ChainedBroadcast.values.put(uuid, value_)
+ }
+
+ @transient var arrayOfBlocks: Array[BroadcastBlock] = null
+ @transient var totalBytes = -1
+ @transient var totalBlocks = -1
+ @transient var hasBlocks = 0
+ // CHANGED: BlockSize in the Broadcast object is expected to change over time
+ @transient var blockSize = Broadcast.BlockSize
+
+ @transient var listenPortLock = new Object
+ @transient var guidePortLock = new Object
+ @transient var totalBlocksLock = new Object
+ @transient var hasBlocksLock = new Object
+
+ @transient var pqOfSources = new PriorityQueue[SourceInfo]
+
+ @transient var serveMR: ServeMultipleRequests = null
+ @transient var guideMR: GuideMultipleRequests = null
+
+ @transient var hostAddress = Utils.localIpAddress
+ @transient var listenPort = -1
+ @transient var guidePort = -1
+
+ @transient var hasCopyInHDFS = false
+ @transient var stopBroadcast = false
+
+ // Must call this after all the variables have been created/initialized
+ if (!isLocal) {
+ sendBroadcast
+ }
+
+ def sendBroadcast(): Unit = {
+ logInfo("Local host address: " + hostAddress)
+
+ // Store a persistent copy in HDFS
+ // TODO: Turned OFF for now
+ // val out = new ObjectOutputStream(DfsBroadcast.openFileForWriting(uuid))
+ // out.writeObject(value_)
+ // out.close()
+ // TODO: Fix this at some point
+ hasCopyInHDFS = true
+
+ // Create a variableInfo object and store it in valueInfos
+ var variableInfo = Broadcast.blockifyObject(value_)
+
+ guideMR = new GuideMultipleRequests
+ guideMR.setDaemon(true)
+ guideMR.start()
+ logInfo("GuideMultipleRequests started...")
+
+ serveMR = new ServeMultipleRequests
+ serveMR.setDaemon(true)
+ serveMR.start()
+ logInfo("ServeMultipleRequests started...")
+
+ // Prepare the value being broadcasted
+ // TODO: Refactoring and clean-up required here
+ arrayOfBlocks = variableInfo.arrayOfBlocks
+ totalBytes = variableInfo.totalBytes
+ totalBlocks = variableInfo.totalBlocks
+ hasBlocks = variableInfo.totalBlocks
+
+ while (listenPort == -1) {
+ listenPortLock.synchronized {
+ listenPortLock.wait
+ }
+ }
+
+ pqOfSources = new PriorityQueue[SourceInfo]
+ val masterSource =
+ SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes, blockSize)
+ pqOfSources.add(masterSource)
+
+ // Register with the Tracker
+ while (guidePort == -1) {
+ guidePortLock.synchronized {
+ guidePortLock.wait
+ }
+ }
+ ChainedBroadcast.registerValue(uuid, guidePort)
+ }
+
+ private def readObject(in: ObjectInputStream): Unit = {
+ in.defaultReadObject
+ ChainedBroadcast.synchronized {
+ val cachedVal = ChainedBroadcast.values.get(uuid)
+ if (cachedVal != null) {
+ value_ = cachedVal.asInstanceOf[T]
+ } else {
+ // Initializing everything because Master will only send null/0 values
+ initializeSlaveVariables
+
+ logInfo("Local host address: " + hostAddress)
+
+ serveMR = new ServeMultipleRequests
+ serveMR.setDaemon(true)
+ serveMR.start()
+ logInfo("ServeMultipleRequests started...")
+
+ val start = System.nanoTime
+
+ val receptionSucceeded = receiveBroadcast(uuid)
+        // If it does not succeed, then get it from the HDFS copy
+ if (receptionSucceeded) {
+ value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
+ ChainedBroadcast.values.put(uuid, value_)
+ } else {
+ val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
+ value_ = fileIn.readObject.asInstanceOf[T]
+ ChainedBroadcast.values.put(uuid, value_)
+ fileIn.close()
+ }
+
+        val time = (System.nanoTime - start) / 1e9
+ logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
+ }
+ }
+ }
+
+ private def initializeSlaveVariables: Unit = {
+ arrayOfBlocks = null
+ totalBytes = -1
+ totalBlocks = -1
+ hasBlocks = 0
+ blockSize = -1
+
+ listenPortLock = new Object
+ totalBlocksLock = new Object
+ hasBlocksLock = new Object
+
+ serveMR = null
+
+ hostAddress = Utils.localIpAddress
+ listenPort = -1
+
+ stopBroadcast = false
+ }
+
+ def getMasterListenPort(variableUUID: UUID): Int = {
+ var clientSocketToTracker: Socket = null
+ var oosTracker: ObjectOutputStream = null
+ var oisTracker: ObjectInputStream = null
+
+ var masterListenPort: Int = SourceInfo.TxOverGoToHDFS
+
+ var retriesLeft = Broadcast.MaxRetryCount
+ do {
+ try {
+ // Connect to the tracker to find out the guide
+ clientSocketToTracker =
+ new Socket(Broadcast.MasterHostAddress, Broadcast.MasterTrackerPort)
+ oosTracker =
+ new ObjectOutputStream(clientSocketToTracker.getOutputStream)
+ oosTracker.flush()
+ oisTracker =
+ new ObjectInputStream(clientSocketToTracker.getInputStream)
+
+ // Send UUID and receive masterListenPort
+ oosTracker.writeObject(uuid)
+ oosTracker.flush()
+ masterListenPort = oisTracker.readObject.asInstanceOf[Int]
+ } catch {
+ case e: Exception => {
+ logInfo("getMasterListenPort had a " + e)
+ }
+ } finally {
+ if (oisTracker != null) {
+ oisTracker.close()
+ }
+ if (oosTracker != null) {
+ oosTracker.close()
+ }
+ if (clientSocketToTracker != null) {
+ clientSocketToTracker.close()
+ }
+ }
+ retriesLeft -= 1
+
+ Thread.sleep(ChainedBroadcast.ranGen.nextInt(
+ Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) +
+ Broadcast.MinKnockInterval)
+
+ } while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry)
+
+ logInfo("Got this guidePort from Tracker: " + masterListenPort)
+ return masterListenPort
+ }
+
+ def receiveBroadcast(variableUUID: UUID): Boolean = {
+ val masterListenPort = getMasterListenPort(variableUUID)
+
+ if (masterListenPort == SourceInfo.TxOverGoToHDFS ||
+ masterListenPort == SourceInfo.TxNotStartedRetry) {
+ // TODO: SourceInfo.TxNotStartedRetry is not really in use because we go
+ // to HDFS anyway when receiveBroadcast returns false
+ return false
+ }
+
+ // Wait until hostAddress and listenPort are created by the
+ // ServeMultipleRequests thread
+ while (listenPort == -1) {
+ listenPortLock.synchronized {
+ listenPortLock.wait
+ }
+ }
+
+ var clientSocketToMaster: Socket = null
+ var oosMaster: ObjectOutputStream = null
+ var oisMaster: ObjectInputStream = null
+
+ // Connect and receive broadcast from the specified source, retrying the
+ // specified number of times in case of failures
+ var retriesLeft = Broadcast.MaxRetryCount
+ do {
+ // Connect to Master and send this worker's Information
+ clientSocketToMaster =
+ new Socket(Broadcast.MasterHostAddress, masterListenPort)
+ // TODO: Guiding object connection is reusable
+ oosMaster =
+ new ObjectOutputStream(clientSocketToMaster.getOutputStream)
+ oosMaster.flush()
+ oisMaster =
+ new ObjectInputStream(clientSocketToMaster.getInputStream)
+
+ logInfo("Connected to Master's guiding object")
+
+ // Send local source information
+ oosMaster.writeObject(SourceInfo(hostAddress, listenPort))
+ oosMaster.flush()
+
+ // Receive source information from Master
+ var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo]
+ totalBlocks = sourceInfo.totalBlocks
+ arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
+ totalBlocksLock.synchronized {
+ totalBlocksLock.notifyAll
+ }
+ totalBytes = sourceInfo.totalBytes
+
+      logInfo("Received SourceInfo from Master: " + sourceInfo + " My Port: " + listenPort)
+
+ val start = System.nanoTime
+ val receptionSucceeded = receiveSingleTransmission(sourceInfo)
+ val time =(System.nanoTime - start) / 1e9
+
+ // Updating some statistics in sourceInfo. Master will be using them later
+ if (!receptionSucceeded) {
+ sourceInfo.receptionFailed = true
+ }
+
+ // Send back statistics to the Master
+ oosMaster.writeObject(sourceInfo)
+
+ if (oisMaster != null) {
+ oisMaster.close()
+ }
+ if (oosMaster != null) {
+ oosMaster.close()
+ }
+ if (clientSocketToMaster != null) {
+ clientSocketToMaster.close()
+ }
+
+ retriesLeft -= 1
+ } while (retriesLeft > 0 && hasBlocks < totalBlocks)
+
+    return (hasBlocks == totalBlocks)
+ }
+
+ // Tries to receive broadcast from the source and returns Boolean status.
+ // This might be called multiple times to retry a defined number of times.
+ private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
+ var clientSocketToSource: Socket = null
+ var oosSource: ObjectOutputStream = null
+ var oisSource: ObjectInputStream = null
+
+ var receptionSucceeded = false
+ try {
+ // Connect to the source to get the object itself
+ clientSocketToSource =
+ new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
+ oosSource =
+ new ObjectOutputStream(clientSocketToSource.getOutputStream)
+ oosSource.flush()
+ oisSource =
+ new ObjectInputStream(clientSocketToSource.getInputStream)
+
+ logInfo("Inside receiveSingleTransmission")
+      logInfo("totalBlocks: " + totalBlocks + " hasBlocks: " + hasBlocks)
+
+ // Send the range
+ oosSource.writeObject((hasBlocks, totalBlocks))
+ oosSource.flush()
+
+ for (i <- hasBlocks until totalBlocks) {
+ val recvStartTime = System.currentTimeMillis
+ val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
+        val receptionTime = (System.currentTimeMillis - recvStartTime)
+
+ logInfo("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
+
+ arrayOfBlocks(hasBlocks) = bcBlock
+ hasBlocks += 1
+ // Set to true if at least one block is received
+ receptionSucceeded = true
+ hasBlocksLock.synchronized {
+ hasBlocksLock.notifyAll
+ }
+ }
+ } catch {
+ case e: Exception => {
+ logInfo("receiveSingleTransmission had a " + e)
+ }
+ } finally {
+ if (oisSource != null) {
+ oisSource.close()
+ }
+ if (oosSource != null) {
+ oosSource.close()
+ }
+ if (clientSocketToSource != null) {
+ clientSocketToSource.close()
+ }
+ }
+
+ return receptionSucceeded
+ }
+
+ class GuideMultipleRequests
+ extends Thread with Logging {
+ // Keep track of sources that have completed reception
+ private var setOfCompletedSources = Set[SourceInfo]()
+
+ override def run: Unit = {
+ var threadPool = Utils.newDaemonCachedThreadPool()
+ var serverSocket: ServerSocket = null
+
+ serverSocket = new ServerSocket(0)
+ guidePort = serverSocket.getLocalPort
+ logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
+
+ guidePortLock.synchronized {
+ guidePortLock.notifyAll
+ }
+
+ try {
+ // Don't stop until there is a copy in HDFS
+ while (!stopBroadcast || !hasCopyInHDFS) {
+ var clientSocket: Socket = null
+ try {
+ serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
+ clientSocket = serverSocket.accept
+ } catch {
+ case e: Exception => {
+ logInfo("GuideMultipleRequests Timeout.")
+
+ // Stop broadcast if at least one worker has connected and
+                // everyone connected so far is done. Comparing with
+ // pqOfSources.size - 1, because it includes the Guide itself
+ if (pqOfSources.size > 1 &&
+ setOfCompletedSources.size == pqOfSources.size - 1) {
+ stopBroadcast = true
+ }
+ }
+ }
+ if (clientSocket != null) {
+ logInfo("Guide: Accepted new client connection: " + clientSocket)
+ try {
+ threadPool.execute(new GuideSingleRequest(clientSocket))
+ } catch {
+ // In failure, close the socket here; else, the thread will close it
+ case ioe: IOException => clientSocket.close()
+ }
+ }
+ }
+
+ logInfo("Sending stopBroadcast notifications...")
+ sendStopBroadcastNotifications
+
+ ChainedBroadcast.unregisterValue(uuid)
+ } finally {
+ if (serverSocket != null) {
+ logInfo("GuideMultipleRequests now stopping...")
+ serverSocket.close()
+ }
+ }
+
+ // Shutdown the thread pool
+ threadPool.shutdown()
+ }
+
+ private def sendStopBroadcastNotifications: Unit = {
+ pqOfSources.synchronized {
+ var pqIter = pqOfSources.iterator
+ while (pqIter.hasNext) {
+ var sourceInfo = pqIter.next
+
+ var guideSocketToSource: Socket = null
+ var gosSource: ObjectOutputStream = null
+ var gisSource: ObjectInputStream = null
+
+ try {
+ // Connect to the source
+ guideSocketToSource =
+ new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
+ gosSource =
+ new ObjectOutputStream(guideSocketToSource.getOutputStream)
+ gosSource.flush()
+ gisSource =
+ new ObjectInputStream(guideSocketToSource.getInputStream)
+
+ // Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2
+ gosSource.writeObject((SourceInfo.StopBroadcast,
+ SourceInfo.StopBroadcast))
+ gosSource.flush()
+ } catch {
+ case e: Exception => {
+ logInfo("sendStopBroadcastNotifications had a " + e)
+ }
+ } finally {
+ if (gisSource != null) {
+ gisSource.close()
+ }
+ if (gosSource != null) {
+ gosSource.close()
+ }
+ if (guideSocketToSource != null) {
+ guideSocketToSource.close()
+ }
+ }
+ }
+ }
+ }
+
+ class GuideSingleRequest(val clientSocket: Socket)
+ extends Thread with Logging {
+ private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
+ oos.flush()
+ private val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+ private var selectedSourceInfo: SourceInfo = null
+      private var thisWorkerInfo: SourceInfo = null
+
+ override def run: Unit = {
+ try {
+ logInfo("new GuideSingleRequest is running")
+          // The connecting worker sends the hostAddress and listenPort it
+          // will be listening on. Other fields are invalid (SourceInfo.UnusedParam)
+ var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
+
+ pqOfSources.synchronized {
+ // Select a suitable source and send it back to the worker
+ selectedSourceInfo = selectSuitableSource(sourceInfo)
+ logInfo("Sending selectedSourceInfo: " + selectedSourceInfo)
+ oos.writeObject(selectedSourceInfo)
+ oos.flush()
+
+            // Add this new (if it can finish) source to the PQ of sources
+ thisWorkerInfo = SourceInfo(sourceInfo.hostAddress,
+ sourceInfo.listenPort, totalBlocks, totalBytes, blockSize)
+ logInfo("Adding possible new source to pqOfSources: " + thisWorkerInfo)
+ pqOfSources.add(thisWorkerInfo)
+ }
+
+ // Wait till the whole transfer is done. Then receive and update source
+ // statistics in pqOfSources
+ sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
+
+ pqOfSources.synchronized {
+ // This should work since SourceInfo is a case class
+ assert(pqOfSources.contains(selectedSourceInfo))
+
+ // Remove first
+ pqOfSources.remove(selectedSourceInfo)
+ // TODO: Removing a source based on just one failure notification!
+
+ // Update sourceInfo and put it back in, IF reception succeeded
+ if (!sourceInfo.receptionFailed) {
+ // Add thisWorkerInfo to sources that have completed reception
+ setOfCompletedSources.synchronized {
+ setOfCompletedSources += thisWorkerInfo
+ }
+
+ selectedSourceInfo.currentLeechers -= 1
+
+ // Put it back
+ pqOfSources.add(selectedSourceInfo)
+ }
+ }
+ } catch {
+ // If something went wrong, e.g., the worker at the other end died etc.
+ // then close everything up
+ case e: Exception => {
+          // Assuming the exception was caused by a receiver worker failure,
+          // remove the failed worker from pqOfSources and update the leecher
+          // count of the corresponding source worker
+ pqOfSources.synchronized {
+ if (selectedSourceInfo != null) {
+ // Remove first
+ pqOfSources.remove(selectedSourceInfo)
+ // Update leecher count and put it back in
+ selectedSourceInfo.currentLeechers -= 1
+ pqOfSources.add(selectedSourceInfo)
+ }
+
+ // Remove thisWorkerInfo
+ if (pqOfSources != null) {
+ pqOfSources.remove(thisWorkerInfo)
+ }
+ }
+ }
+ } finally {
+ ois.close()
+ oos.close()
+ clientSocket.close()
+ }
+ }
+
+ // FIXME: Caller must have a synchronized block on pqOfSources
+ // FIXME: If a worker fails to get the broadcasted variable from a source and
+      // comes back to the Master, this function might choose the worker itself
+      // as a source, creating a dependency cycle (this worker was put into
+      // pqOfSources as a streaming source when it first arrived). The length
+      // of this cycle can be arbitrarily long.
+ private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
+        // Select one based on the ordering strategy (e.g., fewest leechers).
+        // poll removes and returns the head of the PQ (it does not block)
+ var selectedSource = pqOfSources.poll
+ assert(selectedSource != null)
+ // Update leecher count
+ selectedSource.currentLeechers += 1
+ // Add it back and then return
+ pqOfSources.add(selectedSource)
+ return selectedSource
+ }
+ }
+ }
+
+ class ServeMultipleRequests
+ extends Thread with Logging {
+ override def run: Unit = {
+ var threadPool = Utils.newDaemonCachedThreadPool()
+ var serverSocket: ServerSocket = null
+
+ serverSocket = new ServerSocket(0)
+ listenPort = serverSocket.getLocalPort
+ logInfo("ServeMultipleRequests started with " + serverSocket)
+
+ listenPortLock.synchronized {
+ listenPortLock.notifyAll
+ }
+
+ try {
+ while (!stopBroadcast) {
+ var clientSocket: Socket = null
+ try {
+ serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
+ clientSocket = serverSocket.accept
+ } catch {
+ case e: Exception => {
+ logInfo("ServeMultipleRequests Timeout.")
+ }
+ }
+ if (clientSocket != null) {
+ logInfo("Serve: Accepted new client connection: " + clientSocket)
+ try {
+ threadPool.execute(new ServeSingleRequest(clientSocket))
+ } catch {
+ // In failure, close socket here; else, the thread will close it
+ case ioe: IOException => clientSocket.close()
+ }
+ }
+ }
+ } finally {
+ if (serverSocket != null) {
+ logInfo("ServeMultipleRequests now stopping...")
+ serverSocket.close()
+ }
+ }
+
+ // Shutdown the thread pool
+ threadPool.shutdown()
+ }
+
+ class ServeSingleRequest(val clientSocket: Socket)
+ extends Thread with Logging {
+ private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
+ oos.flush()
+ private val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+ private var sendFrom = 0
+ private var sendUntil = totalBlocks
+
+ override def run: Unit = {
+ try {
+ logInfo("new ServeSingleRequest is running")
+
+ // Receive range to send
+ var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)]
+ sendFrom = rangeToSend._1
+ sendUntil = rangeToSend._2
+
+ if (sendFrom == SourceInfo.StopBroadcast &&
+ sendUntil == SourceInfo.StopBroadcast) {
+ stopBroadcast = true
+ } else {
+ // Carry on
+ sendObject
+ }
+ } catch {
+ // If something went wrong, e.g., the worker at the other end died etc.
+ // then close everything up
+ case e: Exception => {
+ logInfo("ServeSingleRequest had a " + e)
+ }
+ } finally {
+ logInfo("ServeSingleRequest is closing streams and sockets")
+ ois.close()
+ oos.close()
+ clientSocket.close()
+ }
+ }
+
+ private def sendObject: Unit = {
+        // Wait until the SourceInfo has been received from the Master
+ while (totalBlocks == -1) {
+ totalBlocksLock.synchronized {
+ totalBlocksLock.wait
+ }
+ }
+
+ for (i <- sendFrom until sendUntil) {
+ while (i == hasBlocks) {
+ hasBlocksLock.synchronized {
+ hasBlocksLock.wait
+ }
+ }
+ try {
+ oos.writeObject(arrayOfBlocks(i))
+ oos.flush()
+ } catch {
+ case e: Exception => {
+ logInfo("sendObject had a " + e)
+ }
+ }
+ logInfo("Sent block: " + i + " to " + clientSocket)
+ }
+ }
+ }
+ }
+}
+
+class ChainedBroadcastFactory
+extends BroadcastFactory {
+ def initialize(isMaster: Boolean) = ChainedBroadcast.initialize(isMaster)
+ def newBroadcast[T](value_ : T, isLocal: Boolean) =
+ new ChainedBroadcast[T](value_, isLocal)
+}
+
+private object ChainedBroadcast
+extends Logging {
+ val values = SparkEnv.get.cache.newKeySpace()
+
+ var valueToGuidePortMap = Map[UUID, Int]()
+
+ // Random number generator
+ var ranGen = new Random
+
+ private var initialized = false
+ private var isMaster_ = false
+
+ private var trackMV: TrackMultipleValues = null
+
+ def initialize(isMaster__ : Boolean): Unit = {
+ synchronized {
+ if (!initialized) {
+ isMaster_ = isMaster__
+
+ if (isMaster) {
+ trackMV = new TrackMultipleValues
+ trackMV.setDaemon(true)
+ trackMV.start()
+ // TODO: Logging the following line makes the Spark framework ID not
+          // TODO: Logging the following line prevents the Spark framework ID
+          // from being logged, because it calls logInfo before log4j is initialized
+ }
+
+ // Initialize DfsBroadcast to be used for broadcast variable persistence
+ DfsBroadcast.initialize
+
+ initialized = true
+ }
+ }
+ }
+
+ def isMaster = isMaster_
+
+ def registerValue(uuid: UUID, guidePort: Int): Unit = {
+ valueToGuidePortMap.synchronized {
+      valueToGuidePortMap += (uuid -> guidePort)
+ logInfo("New value registered with the Tracker " + valueToGuidePortMap)
+ }
+ }
+
+ def unregisterValue(uuid: UUID): Unit = {
+ valueToGuidePortMap.synchronized {
+ valueToGuidePortMap(uuid) = SourceInfo.TxOverGoToHDFS
+ logInfo("Value unregistered from the Tracker " + valueToGuidePortMap)
+ }
+ }
+
+ class TrackMultipleValues
+ extends Thread with Logging {
+ override def run: Unit = {
+ var threadPool = Utils.newDaemonCachedThreadPool()
+ var serverSocket: ServerSocket = null
+
+ serverSocket = new ServerSocket(Broadcast.MasterTrackerPort)
+      logInfo("TrackMultipleValues " + serverSocket)
+
+ try {
+ while (true) {
+ var clientSocket: Socket = null
+ try {
+ serverSocket.setSoTimeout(Broadcast.TrackerSocketTimeout)
+ clientSocket = serverSocket.accept
+ } catch {
+ case e: Exception => {
+ logInfo("TrackMultipleValues Timeout. Stopping listening...")
+ }
+ }
+
+ if (clientSocket != null) {
+ try {
+ threadPool.execute(new Thread {
+ override def run: Unit = {
+ val oos = new ObjectOutputStream(clientSocket.getOutputStream)
+ oos.flush()
+ val ois = new ObjectInputStream(clientSocket.getInputStream)
+ try {
+ val uuid = ois.readObject.asInstanceOf[UUID]
+ var guidePort =
+ if (valueToGuidePortMap.contains(uuid)) {
+ valueToGuidePortMap(uuid)
+ } else SourceInfo.TxNotStartedRetry
+ logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort)
+ oos.writeObject(guidePort)
+ } catch {
+ case e: Exception => {
+ logInfo("TrackMultipleValues had a " + e)
+ }
+ } finally {
+ ois.close()
+ oos.close()
+ clientSocket.close()
+ }
+ }
+ })
+ } catch {
+ // In failure, close socket here; else, client thread will close
+ case ioe: IOException => clientSocket.close()
+ }
+ }
+ }
+ } finally {
+ serverSocket.close()
+ }
+
+ // Shutdown the thread pool
+ threadPool.shutdown()
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/DfsBroadcast.scala b/core/src/main/scala/spark/broadcast/DfsBroadcast.scala
index 895f55ca22..e541c09216 100644
--- a/core/src/main/scala/spark/DfsBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/DfsBroadcast.scala
@@ -1,4 +1,6 @@
-package spark
+package spark.broadcast
+
+import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
import java.io._
import java.net._
@@ -7,7 +9,7 @@ import java.util.UUID
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem}
-import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
+import spark._
@serializable
class DfsBroadcast[T](@transient var value_ : T, isLocal: Boolean)
diff --git a/core/src/main/scala/spark/broadcast/SourceInfo.scala b/core/src/main/scala/spark/broadcast/SourceInfo.scala
new file mode 100644
index 0000000000..064142590a
--- /dev/null
+++ b/core/src/main/scala/spark/broadcast/SourceInfo.scala
@@ -0,0 +1,41 @@
+package spark.broadcast
+
+import java.util.BitSet
+
+import spark._
+
+/**
+ * Used to keep and pass around information of peers involved in a broadcast
+ *
+ * CHANGED: Keep track of the blockSize for THIS broadcast variable.
+ * Broadcast.BlockSize is expected to be updated across different broadcasts
+ */
+@serializable
+case class SourceInfo (val hostAddress: String,
+ val listenPort: Int,
+ val totalBlocks: Int = SourceInfo.UnusedParam,
+ val totalBytes: Int = SourceInfo.UnusedParam,
+ val blockSize: Int = Broadcast.BlockSize)
+extends Comparable[SourceInfo] with Logging {
+
+ var currentLeechers = 0
+ var receptionFailed = false
+
+ var hasBlocks = 0
+ var hasBlocksBitVector: BitSet = new BitSet (totalBlocks)
+
+ // Ascending sort based on leecher count
+ def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers)
+}
+
+/**
+ * Helper Object of SourceInfo for its constants
+ */
+object SourceInfo {
+ // Constants for special values of listenPort
+ val TxNotStartedRetry = -1
+ val TxOverGoToHDFS = 0
+ // Other constants
+ val StopBroadcast = -2
+ val UnusedParam = 0
+} \ No newline at end of file
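For illustration (not part of this patch), compareTo sorts sources in ascending order of leecher count, which is exactly what ChainedBroadcast relies on when it polls its PriorityQueue for the least-loaded source. A small sketch with hypothetical peer addresses:

    import java.util.PriorityQueue

    val busy = SourceInfo("10.0.0.1", 5000) // hypothetical peers
    busy.currentLeechers = 3
    val idle = SourceInfo("10.0.0.2", 5001)
    idle.currentLeechers = 1

    val pq = new PriorityQueue[SourceInfo]()
    pq.add(busy)
    pq.add(idle)
    assert(pq.poll() eq idle) // the least-loaded source comes out first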
diff --git a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala
new file mode 100644
index 0000000000..79dcd317ec
--- /dev/null
+++ b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala
@@ -0,0 +1,807 @@
+package spark.broadcast
+
+import java.io._
+import java.net._
+import java.util.{Comparator, Random, UUID}
+
+import scala.collection.mutable.{ListBuffer, Map, Set}
+import scala.math
+
+import spark._
+
+@serializable
+class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean)
+extends Broadcast[T] with Logging {
+
+ def value = value_
+
+ TreeBroadcast.synchronized {
+ TreeBroadcast.values.put(uuid, value_)
+ }
+
+ @transient var arrayOfBlocks: Array[BroadcastBlock] = null
+ @transient var totalBytes = -1
+ @transient var totalBlocks = -1
+ @transient var hasBlocks = 0
+ // CHANGED: BlockSize in the Broadcast object is expected to change over time
+ @transient var blockSize = Broadcast.BlockSize
+
+ @transient var listenPortLock = new Object
+ @transient var guidePortLock = new Object
+ @transient var totalBlocksLock = new Object
+ @transient var hasBlocksLock = new Object
+
+ @transient var listOfSources = ListBuffer[SourceInfo]()
+
+ @transient var serveMR: ServeMultipleRequests = null
+ @transient var guideMR: GuideMultipleRequests = null
+
+ @transient var hostAddress = Utils.localIpAddress
+ @transient var listenPort = -1
+ @transient var guidePort = -1
+
+ @transient var hasCopyInHDFS = false
+ @transient var stopBroadcast = false
+
+ // Must call this after all the variables have been created/initialized
+ if (!isLocal) {
+ sendBroadcast
+ }
+
+ def sendBroadcast(): Unit = {
+ logInfo("Local host address: " + hostAddress)
+
+ // Store a persistent copy in HDFS
+ // TODO: Turned OFF for now
+ // val out = new ObjectOutputStream(DfsBroadcast.openFileForWriting(uuid))
+ // out.writeObject(value_)
+ // out.close()
+ // TODO: Fix this at some point
+ hasCopyInHDFS = true
+
+ // Create a variableInfo object and store it in valueInfos
+ var variableInfo = Broadcast.blockifyObject(value_)
+
+ // Prepare the value being broadcasted
+ // TODO: Refactoring and clean-up required here
+ arrayOfBlocks = variableInfo.arrayOfBlocks
+ totalBytes = variableInfo.totalBytes
+ totalBlocks = variableInfo.totalBlocks
+ hasBlocks = variableInfo.totalBlocks
+
+ guideMR = new GuideMultipleRequests
+ guideMR.setDaemon(true)
+ guideMR.start
+ logInfo("GuideMultipleRequests started...")
+
+ // Must always come AFTER guideMR is created
+ while (guidePort == -1) {
+ guidePortLock.synchronized {
+ guidePortLock.wait
+ }
+ }
+
+ serveMR = new ServeMultipleRequests
+ serveMR.setDaemon(true)
+ serveMR.start
+ logInfo("ServeMultipleRequests started...")
+
+ // Must always come AFTER serveMR is created
+ while (listenPort == -1) {
+ listenPortLock.synchronized {
+ listenPortLock.wait
+ }
+ }
+
+ // Must always come AFTER listenPort is created
+ val masterSource =
+ SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes, blockSize)
+ listOfSources += masterSource
+
+ // Register with the Tracker
+ TreeBroadcast.registerValue(uuid, guidePort)
+ }
+
+ private def readObject(in: ObjectInputStream): Unit = {
+ in.defaultReadObject
+ TreeBroadcast.synchronized {
+ val cachedVal = TreeBroadcast.values.get(uuid)
+ if (cachedVal != null) {
+ value_ = cachedVal.asInstanceOf[T]
+ } else {
+ // Initializing everything because Master will only send null/0 values
+ initializeSlaveVariables
+
+ logInfo("Local host address: " + hostAddress)
+
+ serveMR = new ServeMultipleRequests
+ serveMR.setDaemon(true)
+ serveMR.start
+ logInfo("ServeMultipleRequests started...")
+
+ val start = System.nanoTime
+
+ val receptionSucceeded = receiveBroadcast(uuid)
+        // If it does not succeed, then get it from the HDFS copy
+ if (receptionSucceeded) {
+ value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
+ TreeBroadcast.values.put(uuid, value_)
+ } else {
+ val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
+ value_ = fileIn.readObject.asInstanceOf[T]
+ TreeBroadcast.values.put(uuid, value_)
+ fileIn.close()
+ }
+
+ val time = (System.nanoTime - start) / 1e9
+ logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
+ }
+ }
+ }
+
+ private def initializeSlaveVariables: Unit = {
+ arrayOfBlocks = null
+ totalBytes = -1
+ totalBlocks = -1
+ hasBlocks = 0
+ blockSize = -1
+
+ listenPortLock = new Object
+ totalBlocksLock = new Object
+ hasBlocksLock = new Object
+
+ serveMR = null
+
+ hostAddress = Utils.localIpAddress
+ listenPort = -1
+
+ stopBroadcast = false
+ }
+
+ def getMasterListenPort(variableUUID: UUID): Int = {
+ var clientSocketToTracker: Socket = null
+ var oosTracker: ObjectOutputStream = null
+ var oisTracker: ObjectInputStream = null
+
+ var masterListenPort: Int = SourceInfo.TxOverGoToHDFS
+
+ var retriesLeft = Broadcast.MaxRetryCount
+ do {
+ try {
+ // Connect to the tracker to find out the guide
+ clientSocketToTracker =
+ new Socket(Broadcast.MasterHostAddress, Broadcast.MasterTrackerPort)
+ oosTracker =
+ new ObjectOutputStream(clientSocketToTracker.getOutputStream)
+ oosTracker.flush()
+ oisTracker =
+ new ObjectInputStream(clientSocketToTracker.getInputStream)
+
+ // Send UUID and receive masterListenPort
+ oosTracker.writeObject(uuid)
+ oosTracker.flush()
+ masterListenPort = oisTracker.readObject.asInstanceOf[Int]
+ } catch {
+ case e: Exception => {
+ logInfo("getMasterListenPort had a " + e)
+ }
+ } finally {
+ if (oisTracker != null) {
+ oisTracker.close()
+ }
+ if (oosTracker != null) {
+ oosTracker.close()
+ }
+ if (clientSocketToTracker != null) {
+ clientSocketToTracker.close()
+ }
+ }
+ retriesLeft -= 1
+
+ Thread.sleep(TreeBroadcast.ranGen.nextInt(
+ Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) +
+ Broadcast.MinKnockInterval)
+
+ } while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry)
+
+ logInfo("Got this guidePort from Tracker: " + masterListenPort)
+ return masterListenPort
+ }
+
+ def receiveBroadcast(variableUUID: UUID): Boolean = {
+ val masterListenPort = getMasterListenPort(variableUUID)
+
+ if (masterListenPort == SourceInfo.TxOverGoToHDFS ||
+ masterListenPort == SourceInfo.TxNotStartedRetry) {
+ // TODO: SourceInfo.TxNotStartedRetry is not really in use because we go
+ // to HDFS anyway when receiveBroadcast returns false
+ return false
+ }
+
+ // Wait until hostAddress and listenPort are created by the
+ // ServeMultipleRequests thread
+ while (listenPort == -1) {
+ listenPortLock.synchronized {
+ listenPortLock.wait
+ }
+ }
+
+ var clientSocketToMaster: Socket = null
+ var oosMaster: ObjectOutputStream = null
+ var oisMaster: ObjectInputStream = null
+
+ // Connect and receive broadcast from the specified source, retrying the
+ // specified number of times in case of failures
+ var retriesLeft = Broadcast.MaxRetryCount
+ do {
+ // Connect to Master and send this worker's Information
+ clientSocketToMaster =
+ new Socket(Broadcast.MasterHostAddress, masterListenPort)
+ // TODO: Guiding object connection is reusable
+ oosMaster =
+ new ObjectOutputStream(clientSocketToMaster.getOutputStream)
+ oosMaster.flush()
+ oisMaster =
+ new ObjectInputStream(clientSocketToMaster.getInputStream)
+
+ logInfo("Connected to Master's guiding object")
+
+ // Send local source information
+ oosMaster.writeObject(SourceInfo(hostAddress, listenPort))
+ oosMaster.flush()
+
+ // Receive source information from Master
+ var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo]
+ totalBlocks = sourceInfo.totalBlocks
+ arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
+ totalBlocksLock.synchronized {
+ totalBlocksLock.notifyAll
+ }
+ totalBytes = sourceInfo.totalBytes
+ blockSize = sourceInfo.blockSize
+
+      logInfo("Received SourceInfo from Master: " + sourceInfo + " My Port: " + listenPort)
+
+ val start = System.nanoTime
+ val receptionSucceeded = receiveSingleTransmission(sourceInfo)
+ val time = (System.nanoTime - start) / 1e9
+
+ // Updating some statistics in sourceInfo. Master will be using them later
+ if (!receptionSucceeded) {
+ sourceInfo.receptionFailed = true
+ }
+
+ // Send back statistics to the Master
+ oosMaster.writeObject(sourceInfo)
+
+ if (oisMaster != null) {
+ oisMaster.close()
+ }
+ if (oosMaster != null) {
+ oosMaster.close()
+ }
+ if (clientSocketToMaster != null) {
+ clientSocketToMaster.close()
+ }
+
+ retriesLeft -= 1
+ } while (retriesLeft > 0 && hasBlocks < totalBlocks)
+
+ return (hasBlocks == totalBlocks)
+ }
+
+ // Tries to receive broadcast from the source and returns Boolean status.
+ // This might be called multiple times to retry a defined number of times.
+ private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
+ var clientSocketToSource: Socket = null
+ var oosSource: ObjectOutputStream = null
+ var oisSource: ObjectInputStream = null
+
+ var receptionSucceeded = false
+ try {
+ // Connect to the source to get the object itself
+ clientSocketToSource =
+ new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
+ oosSource =
+ new ObjectOutputStream(clientSocketToSource.getOutputStream)
+ oosSource.flush()
+ oisSource =
+ new ObjectInputStream(clientSocketToSource.getInputStream)
+
+ logInfo("Inside receiveSingleTransmission")
+      logInfo("totalBlocks: " + totalBlocks + " hasBlocks: " + hasBlocks)
+
+ // Send the range
+ oosSource.writeObject((hasBlocks, totalBlocks))
+ oosSource.flush()
+
+ for (i <- hasBlocks until totalBlocks) {
+ val recvStartTime = System.currentTimeMillis
+ val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
+ val receptionTime = (System.currentTimeMillis - recvStartTime)
+
+ logInfo("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
+
+ arrayOfBlocks(hasBlocks) = bcBlock
+ hasBlocks += 1
+ // Set to true if at least one block is received
+ receptionSucceeded = true
+ hasBlocksLock.synchronized {
+ hasBlocksLock.notifyAll
+ }
+ }
+ } catch {
+ case e: Exception => {
+ logInfo("receiveSingleTransmission had a " + e)
+ }
+ } finally {
+ if (oisSource != null) {
+ oisSource.close()
+ }
+ if (oosSource != null) {
+ oosSource.close()
+ }
+ if (clientSocketToSource != null) {
+ clientSocketToSource.close()
+ }
+ }
+
+ return receptionSucceeded
+ }
+
+ class GuideMultipleRequests
+ extends Thread with Logging {
+ // Keep track of sources that have completed reception
+ private var setOfCompletedSources = Set[SourceInfo]()
+
+ override def run: Unit = {
+ var threadPool = Utils.newDaemonCachedThreadPool()
+ var serverSocket: ServerSocket = null
+
+ serverSocket = new ServerSocket(0)
+ guidePort = serverSocket.getLocalPort
+ logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
+
+ guidePortLock.synchronized {
+ guidePortLock.notifyAll
+ }
+
+ try {
+ // Don't stop until there is a copy in HDFS
+ while (!stopBroadcast || !hasCopyInHDFS) {
+ var clientSocket: Socket = null
+ try {
+ serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
+ clientSocket = serverSocket.accept
+ } catch {
+ case e: Exception => {
+ logInfo("GuideMultipleRequests Timeout.")
+
+ // Stop broadcast if at least one worker has connected and
+                // everyone connected so far is done. Comparing with
+ // listOfSources.size - 1, because it includes the Guide itself
+ if (listOfSources.size > 1 &&
+ setOfCompletedSources.size == listOfSources.size - 1) {
+ stopBroadcast = true
+ }
+ }
+ }
+ if (clientSocket != null) {
+ logInfo("Guide: Accepted new client connection: " + clientSocket)
+ try {
+ threadPool.execute(new GuideSingleRequest(clientSocket))
+ } catch {
+              // In failure, close the socket here; else, the thread will close it
+ case ioe: IOException => clientSocket.close()
+ }
+ }
+ }
+
+ logInfo("Sending stopBroadcast notifications...")
+ sendStopBroadcastNotifications
+
+ TreeBroadcast.unregisterValue(uuid)
+ } finally {
+ if (serverSocket != null) {
+ logInfo("GuideMultipleRequests now stopping...")
+ serverSocket.close()
+ }
+ }
+
+ // Shutdown the thread pool
+ threadPool.shutdown
+ }
+
+ private def sendStopBroadcastNotifications: Unit = {
+ listOfSources.synchronized {
+ var listIter = listOfSources.iterator
+ while (listIter.hasNext) {
+ var sourceInfo = listIter.next
+
+ var guideSocketToSource: Socket = null
+ var gosSource: ObjectOutputStream = null
+ var gisSource: ObjectInputStream = null
+
+ try {
+ // Connect to the source
+ guideSocketToSource =
+ new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
+ gosSource =
+ new ObjectOutputStream(guideSocketToSource.getOutputStream)
+ gosSource.flush()
+ gisSource =
+ new ObjectInputStream(guideSocketToSource.getInputStream)
+
+ // Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2
+ gosSource.writeObject((SourceInfo.StopBroadcast,
+ SourceInfo.StopBroadcast))
+ gosSource.flush()
+ } catch {
+ case e: Exception => {
+ logInfo("sendStopBroadcastNotifications had a " + e)
+ }
+ } finally {
+ if (gisSource != null) {
+ gisSource.close()
+ }
+ if (gosSource != null) {
+ gosSource.close()
+ }
+ if (guideSocketToSource != null) {
+ guideSocketToSource.close()
+ }
+ }
+ }
+ }
+ }
+
+ class GuideSingleRequest(val clientSocket: Socket)
+ extends Thread with Logging {
+ private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
+ oos.flush()
+ private val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+ private var selectedSourceInfo: SourceInfo = null
+      private var thisWorkerInfo: SourceInfo = null
+
+ override def run: Unit = {
+ try {
+ logInfo("new GuideSingleRequest is running")
+          // The connecting worker sends the hostAddress and listenPort it
+          // will be listening on. Other fields are invalid (SourceInfo.UnusedParam)
+ var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
+
+ listOfSources.synchronized {
+ // Select a suitable source and send it back to the worker
+ selectedSourceInfo = selectSuitableSource(sourceInfo)
+ logInfo("Sending selectedSourceInfo: " + selectedSourceInfo)
+ oos.writeObject(selectedSourceInfo)
+ oos.flush()
+
+ // Add this new (if it can finish) source to the list of sources
+ thisWorkerInfo = SourceInfo(sourceInfo.hostAddress,
+ sourceInfo.listenPort, totalBlocks, totalBytes, blockSize)
+ logInfo("Adding possible new source to listOfSources: " + thisWorkerInfo)
+ listOfSources += thisWorkerInfo
+ }
+
+ // Wait till the whole transfer is done. Then receive and update source
+ // statistics in listOfSources
+ sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
+
+ listOfSources.synchronized {
+ // This should work since SourceInfo is a case class
+ assert(listOfSources.contains(selectedSourceInfo))
+
+ // Remove first
+ listOfSources = listOfSources - selectedSourceInfo
+ // TODO: Removing a source based on just one failure notification!
+
+ // Update sourceInfo and put it back in, IF reception succeeded
+ if (!sourceInfo.receptionFailed) {
+ // Add thisWorkerInfo to sources that have completed reception
+ setOfCompletedSources.synchronized {
+ setOfCompletedSources += thisWorkerInfo
+ }
+
+ selectedSourceInfo.currentLeechers -= 1
+
+ // Put it back
+ listOfSources += selectedSourceInfo
+ }
+ }
+ } catch {
+ // If something went wrong, e.g., the worker at the other end died etc.
+        // then close everything up
+ case e: Exception => {
+          // Assuming the exception was caused by a receiver worker failure,
+          // remove the failed worker from listOfSources and update the
+          // leecher count of the corresponding source worker
+ listOfSources.synchronized {
+ if (selectedSourceInfo != null) {
+ // Remove first
+ listOfSources = listOfSources - selectedSourceInfo
+ // Update leecher count and put it back in
+ selectedSourceInfo.currentLeechers -= 1
+ listOfSources += selectedSourceInfo
+ }
+
+ // Remove thisWorkerInfo
+ if (listOfSources != null) {
+ listOfSources = listOfSources - thisWorkerInfo
+ }
+ }
+ }
+ } finally {
+ ois.close()
+ oos.close()
+ clientSocket.close()
+ }
+ }
+
+ // FIXME: Caller must have a synchronized block on listOfSources
+ // FIXME: If a worker fails to get the broadcasted variable from a source
+ // and comes back to the Master, this function might choose the worker
+      // itself as a source, creating a dependency cycle (this worker was put
+      // into listOfSources as a streaming source when it first arrived). The
+ // length of this cycle can be arbitrarily long.
+ private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
+        // Select the one with the most leechers; this fills the tree
+        // level by level
+
+ var maxLeechers = -1
+ var selectedSource: SourceInfo = null
+
+ listOfSources.foreach { source =>
+ if (source != skipSourceInfo &&
+ source.currentLeechers < Broadcast.MaxDegree &&
+ source.currentLeechers > maxLeechers) {
+ selectedSource = source
+ maxLeechers = source.currentLeechers
+ }
+ }
+
+ // Update leecher count
+ selectedSource.currentLeechers += 1
+
+ return selectedSource
+ }
+ }
+ }
+
+ class ServeMultipleRequests
+ extends Thread with Logging {
+ override def run: Unit = {
+ var threadPool = Utils.newDaemonCachedThreadPool()
+ var serverSocket: ServerSocket = null
+
+ serverSocket = new ServerSocket(0)
+ listenPort = serverSocket.getLocalPort
+ logInfo("ServeMultipleRequests started with " + serverSocket)
+
+ listenPortLock.synchronized {
+ listenPortLock.notifyAll
+ }
+
+ try {
+ while (!stopBroadcast) {
+ var clientSocket: Socket = null
+ try {
+ serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
+ clientSocket = serverSocket.accept
+ } catch {
+ case e: Exception => {
+ logInfo("ServeMultipleRequests Timeout.")
+ }
+ }
+ if (clientSocket != null) {
+ logInfo("Serve: Accepted new client connection: " + clientSocket)
+ try {
+ threadPool.execute(new ServeSingleRequest(clientSocket))
+ } catch {
+              // In failure, close socket here; else, the thread will close it
+ case ioe: IOException => clientSocket.close()
+ }
+ }
+ }
+ } finally {
+ if (serverSocket != null) {
+ logInfo("ServeMultipleRequests now stopping...")
+ serverSocket.close()
+ }
+ }
+
+ // Shutdown the thread pool
+ threadPool.shutdown
+ }
+
+ class ServeSingleRequest(val clientSocket: Socket)
+ extends Thread with Logging {
+ private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
+ oos.flush()
+ private val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+ private var sendFrom = 0
+ private var sendUntil = totalBlocks
+
+ override def run: Unit = {
+ try {
+ logInfo("new ServeSingleRequest is running")
+
+ // Receive range to send
+ var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)]
+ sendFrom = rangeToSend._1
+ sendUntil = rangeToSend._2
+
+ if (sendFrom == SourceInfo.StopBroadcast &&
+ sendUntil == SourceInfo.StopBroadcast) {
+ stopBroadcast = true
+ } else {
+ // Carry on
+ sendObject
+ }
+ } catch {
+ // If something went wrong, e.g., the worker at the other end died etc.
+          // then close everything up
+ case e: Exception => {
+ logInfo("ServeSingleRequest had a " + e)
+ }
+ } finally {
+ logInfo("ServeSingleRequest is closing streams and sockets")
+ ois.close()
+ oos.close()
+ clientSocket.close()
+ }
+ }
+
+ private def sendObject: Unit = {
+        // Wait until the SourceInfo has been received from the Master
+ while (totalBlocks == -1) {
+ totalBlocksLock.synchronized {
+ totalBlocksLock.wait
+ }
+ }
+
+ for (i <- sendFrom until sendUntil) {
+ while (i == hasBlocks) {
+ hasBlocksLock.synchronized {
+ hasBlocksLock.wait
+ }
+ }
+ try {
+ oos.writeObject(arrayOfBlocks(i))
+ oos.flush()
+ } catch {
+ case e: Exception => {
+ logInfo("sendObject had a " + e)
+ }
+ }
+ logInfo("Sent block: " + i + " to " + clientSocket)
+ }
+ }
+ }
+ }
+}
+
+class TreeBroadcastFactory
+extends BroadcastFactory {
+ def initialize(isMaster: Boolean) = TreeBroadcast.initialize(isMaster)
+ def newBroadcast[T](value_ : T, isLocal: Boolean) =
+ new TreeBroadcast[T](value_, isLocal)
+}
+
+private object TreeBroadcast
+extends Logging {
+ val values = SparkEnv.get.cache.newKeySpace()
+
+ var valueToGuidePortMap = Map[UUID, Int]()
+
+ // Random number generator
+ var ranGen = new Random
+
+ private var initialized = false
+ private var isMaster_ = false
+
+ private var trackMV: TrackMultipleValues = null
+
+  private var MaxDegree_ : Int = 2 // note: unused here; Broadcast.MaxDegree is read instead
+
+ def initialize(isMaster__ : Boolean): Unit = {
+ synchronized {
+ if (!initialized) {
+ isMaster_ = isMaster__
+
+ if (isMaster) {
+ trackMV = new TrackMultipleValues
+ trackMV.setDaemon(true)
+ trackMV.start
+          // TODO: Logging the following line prevents the Spark framework ID
+          // from being logged, because it calls logInfo before log4j is initialized
+ logInfo("TrackMultipleValues started...")
+ }
+
+ // Initialize DfsBroadcast to be used for broadcast variable persistence
+ DfsBroadcast.initialize
+
+ initialized = true
+ }
+ }
+ }
+
+ def isMaster = isMaster_
+
+ def registerValue(uuid: UUID, guidePort: Int): Unit = {
+ valueToGuidePortMap.synchronized {
+ valueToGuidePortMap += (uuid -> guidePort)
+ logInfo("New value registered with the Tracker " + valueToGuidePortMap)
+ }
+ }
+
+ def unregisterValue(uuid: UUID): Unit = {
+ valueToGuidePortMap.synchronized {
+ valueToGuidePortMap(uuid) = SourceInfo.TxOverGoToHDFS
+ logInfo("Value unregistered from the Tracker " + valueToGuidePortMap)
+ }
+ }
+
+ class TrackMultipleValues
+ extends Thread with Logging {
+ override def run: Unit = {
+ var threadPool = Utils.newDaemonCachedThreadPool()
+ var serverSocket: ServerSocket = null
+
+ serverSocket = new ServerSocket(Broadcast.MasterTrackerPort)
+      logInfo("TrackMultipleValues " + serverSocket)
+
+ try {
+ while (true) {
+ var clientSocket: Socket = null
+ try {
+ serverSocket.setSoTimeout(Broadcast.TrackerSocketTimeout)
+ clientSocket = serverSocket.accept
+ } catch {
+ case e: Exception => {
+ logInfo("TrackMultipleValues Timeout. Stopping listening...")
+ }
+ }
+
+ if (clientSocket != null) {
+ try {
+ threadPool.execute(new Thread {
+ override def run: Unit = {
+ val oos = new ObjectOutputStream(clientSocket.getOutputStream)
+ oos.flush()
+ val ois = new ObjectInputStream(clientSocket.getInputStream)
+ try {
+ val uuid = ois.readObject.asInstanceOf[UUID]
+ var guidePort =
+ if (valueToGuidePortMap.contains(uuid)) {
+ valueToGuidePortMap(uuid)
+ } else SourceInfo.TxNotStartedRetry
+ logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort)
+ oos.writeObject(guidePort)
+ } catch {
+ case e: Exception => {
+ logInfo("TrackMultipleValues had a " + e)
+ }
+ } finally {
+ ois.close()
+ oos.close()
+ clientSocket.close()
+ }
+ }
+ })
+ } catch {
+              // In failure, close socket here; else, client thread will close
+ case ioe: IOException => clientSocket.close()
+ }
+ }
+ }
+ } finally {
+ serverSocket.close()
+ }
+
+ // Shutdown the thread pool
+ threadPool.shutdown
+ }
+ }
+}
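For illustration (not part of this patch), selectSuitableSource prefers the most-loaded source that still has a free slot, which fills the tree one level at a time. A standalone sketch of that choice rule (pickSource is a hypothetical helper, not code from this patch):

    // Among sources with a free slot, pick the one with the most leechers.
    def pickSource(sources: Seq[SourceInfo], maxDegree: Int): Option[SourceInfo] = {
      val open = sources.filter(_.currentLeechers < maxDegree)
      if (open.isEmpty) None else Some(open.maxBy(_.currentLeechers))
    }

With the default MaxDegree of 2, workers 1 and 2 both attach to the master; worker 3 then attaches to worker 1, starting the next level, and so on.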
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 3089360756..d14d313b48 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -5,8 +5,8 @@ import org.scalatest.prop.Checkers
import org.scalacheck.Arbitrary._
import org.scalacheck.Gen
import org.scalacheck.Prop._
-
import SparkContext._
+import scala.collection.mutable.ArrayBuffer
class ShuffleSuite extends FunSuite {
test("groupByKey") {
@@ -115,6 +115,38 @@ class ShuffleSuite extends FunSuite {
sc.stop()
}
+ test("leftOuterJoin") {
+ val sc = new SparkContext("local", "test")
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.leftOuterJoin(rdd2).collect()
+ assert(joined.size === 5)
+ assert(joined.toSet === Set(
+ (1, (1, Some('x'))),
+ (1, (2, Some('x'))),
+ (2, (1, Some('y'))),
+ (2, (1, Some('z'))),
+ (3, (1, None))
+ ))
+ sc.stop()
+ }
+
+ test("rightOuterJoin") {
+ val sc = new SparkContext("local", "test")
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.rightOuterJoin(rdd2).collect()
+ assert(joined.size === 5)
+ assert(joined.toSet === Set(
+ (1, (Some(1), 'x')),
+ (1, (Some(2), 'x')),
+ (2, (Some(1), 'y')),
+ (2, (Some(1), 'z')),
+ (4, (None, 'w'))
+ ))
+ sc.stop()
+ }
+
test("join with no matches") {
val sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
@@ -138,4 +170,20 @@ class ShuffleSuite extends FunSuite {
))
sc.stop()
}
+
+ test("groupWith") {
+ val sc = new SparkContext("local", "test")
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.groupWith(rdd2).collect()
+ assert(joined.size === 4)
+ assert(joined.toSet === Set(
+ (1, (ArrayBuffer(1, 2), ArrayBuffer('x'))),
+ (2, (ArrayBuffer(1), ArrayBuffer('y', 'z'))),
+ (3, (ArrayBuffer(1), ArrayBuffer())),
+ (4, (ArrayBuffer(), ArrayBuffer('w')))
+ ))
+ sc.stop()
+ }
+
}