aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMosharaf Chowdhury <mosharaf@mosharaf-ubuntu.(none)>2010-11-26 23:15:33 -0800
committerMosharaf Chowdhury <mosharaf@mosharaf-ubuntu.(none)>2010-11-26 23:15:33 -0800
commitc25914228c542ed2900310b82926d10b95df83d6 (patch)
tree698a411c81d63084f1472efc5cd16467716158d6
parent98542f81bb30381148759f414f9c2ca679d3bd63 (diff)
downloadspark-c25914228c542ed2900310b82926d10b95df83d6.tar.gz
spark-c25914228c542ed2900310b82926d10b95df83d6.tar.bz2
spark-c25914228c542ed2900310b82926d10b95df83d6.zip
Moved ChaninedStreaming to a separate file and renamed to ChainedBroadcast.
-rw-r--r--src/scala/spark/Broadcast.scala922
-rw-r--r--src/scala/spark/ChainedBroadcast.scala897
-rw-r--r--src/scala/spark/SparkContext.scala2
3 files changed, 912 insertions, 909 deletions
diff --git a/src/scala/spark/Broadcast.scala b/src/scala/spark/Broadcast.scala
index 08648d2ef4..05f214e48e 100644
--- a/src/scala/spark/Broadcast.scala
+++ b/src/scala/spark/Broadcast.scala
@@ -1,14 +1,6 @@
package spark
-import java.io._
-import java.net._
-import java.util.{Comparator, PriorityQueue, Random, UUID}
-
-import com.google.common.collect.MapMaker
-
-import java.util.concurrent.{Executors, ThreadPoolExecutor, ThreadFactory}
-
-import scala.collection.mutable.{Map, Set}
+import java.util.UUID
@serializable
trait Broadcast {
@@ -21,708 +13,24 @@ trait Broadcast {
override def toString = "spark.Broadcast(" + uuid + ")"
}
-@serializable
-class ChainedStreamingBroadcast[T] (@transient var value_ : T, local: Boolean)
-extends Broadcast with Logging {
-
- def value = value_
-
- BroadcastCS.synchronized {
- BroadcastCS.values.put (uuid, value_)
- }
-
- @transient var arrayOfBlocks: Array[BroadcastBlock] = null
- @transient var totalBytes = -1
- @transient var totalBlocks = -1
- @transient var hasBlocks = 0
-
- @transient var listenPortLock = new Object
- @transient var guidePortLock = new Object
- @transient var totalBlocksLock = new Object
- @transient var hasBlocksLock = new Object
-
- @transient var pqOfSources = new PriorityQueue[SourceInfo]
-
- @transient var serveMR: ServeMultipleRequests = null
- @transient var guideMR: GuideMultipleRequests = null
-
- @transient var hostAddress = InetAddress.getLocalHost.getHostAddress
- @transient var listenPort = -1
- @transient var guidePort = -1
-
- @transient var hasCopyInHDFS = false
- @transient var stopBroadcast = false
-
- // Must call this after all the variables have been created/initialized
- if (!local) {
- sendBroadcast
- }
-
- def sendBroadcast (): Unit = {
- logInfo ("Local host address: " + hostAddress)
-
- // Store a persistent copy in HDFS
- // TODO: Turned OFF for now
- // val out = new ObjectOutputStream (DfsBroadcast.openFileForWriting(uuid))
- // out.writeObject (value_)
- // out.close
- // TODO: Fix this at some point
- hasCopyInHDFS = true
-
- // Create a variableInfo object and store it in valueInfos
- var variableInfo = blockifyObject (value_, BroadcastCS.BlockSize)
-
- guideMR = new GuideMultipleRequests
- guideMR.setDaemon (true)
- guideMR.start
- logInfo ("GuideMultipleRequests started...")
-
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon (true)
- serveMR.start
- logInfo ("ServeMultipleRequests started...")
-
- // Prepare the value being broadcasted
- // TODO: Refactoring and clean-up required here
- arrayOfBlocks = variableInfo.arrayOfBlocks
- totalBytes = variableInfo.totalBytes
- totalBlocks = variableInfo.totalBlocks
- hasBlocks = variableInfo.totalBlocks
-
- while (listenPort == -1) {
- listenPortLock.synchronized {
- listenPortLock.wait
- }
- }
-
- pqOfSources = new PriorityQueue[SourceInfo]
- val masterSource_0 =
- SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes, 0)
- pqOfSources.add (masterSource_0)
-
- // Register with the Tracker
- while (guidePort == -1) {
- guidePortLock.synchronized {
- guidePortLock.wait
- }
- }
- BroadcastCS.registerValue (uuid, guidePort)
- }
-
- private def readObject (in: ObjectInputStream): Unit = {
- in.defaultReadObject
- BroadcastCS.synchronized {
- val cachedVal = BroadcastCS.values.get (uuid)
- if (cachedVal != null) {
- value_ = cachedVal.asInstanceOf[T]
- } else {
- // Initializing everything because Master will only send null/0 values
- initializeSlaveVariables
-
- logInfo ("Local host address: " + hostAddress)
-
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon (true)
- serveMR.start
- logInfo ("ServeMultipleRequests started...")
-
- val start = System.nanoTime
-
- val receptionSucceeded = receiveBroadcast (uuid)
- // If does not succeed, then get from HDFS copy
- if (receptionSucceeded) {
- value_ = unBlockifyObject[T]
- BroadcastCS.values.put (uuid, value_)
- } else {
- val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
- value_ = fileIn.readObject.asInstanceOf[T]
- BroadcastCS.values.put(uuid, value_)
- fileIn.close
- }
-
- val time = (System.nanoTime - start) / 1e9
- logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
- }
- }
- }
-
- private def initializeSlaveVariables: Unit = {
- arrayOfBlocks = null
- totalBytes = -1
- totalBlocks = -1
- hasBlocks = 0
-
- listenPortLock = new Object
- totalBlocksLock = new Object
- hasBlocksLock = new Object
-
- serveMR = null
-
- hostAddress = InetAddress.getLocalHost.getHostAddress
- listenPort = -1
-
- stopBroadcast = false
- }
-
- private def blockifyObject (obj: T, blockSize: Int): VariableInfo = {
- val baos = new ByteArrayOutputStream
- val oos = new ObjectOutputStream (baos)
- oos.writeObject (obj)
- oos.close
- baos.close
- val byteArray = baos.toByteArray
- val bais = new ByteArrayInputStream (byteArray)
-
- var blockNum = (byteArray.length / blockSize)
- if (byteArray.length % blockSize != 0)
- blockNum += 1
-
- var retVal = new Array[BroadcastBlock] (blockNum)
- var blockID = 0
-
- for (i <- 0 until (byteArray.length, blockSize)) {
- val thisBlockSize = Math.min (blockSize, byteArray.length - i)
- var tempByteArray = new Array[Byte] (thisBlockSize)
- val hasRead = bais.read (tempByteArray, 0, thisBlockSize)
-
- retVal (blockID) = new BroadcastBlock (blockID, tempByteArray)
- blockID += 1
- }
- bais.close
-
- var variableInfo = VariableInfo (retVal, blockNum, byteArray.length)
- variableInfo.hasBlocks = blockNum
-
- return variableInfo
- }
-
- private def unBlockifyObject[A]: A = {
- var retByteArray = new Array[Byte] (totalBytes)
- for (i <- 0 until totalBlocks) {
- System.arraycopy (arrayOfBlocks(i).byteArray, 0, retByteArray,
- i * BroadcastCS.BlockSize, arrayOfBlocks(i).byteArray.length)
- }
- byteArrayToObject (retByteArray)
- }
-
- private def byteArrayToObject[A] (bytes: Array[Byte]): A = {
- val in = new ObjectInputStream (new ByteArrayInputStream (bytes))
- val retVal = in.readObject.asInstanceOf[A]
- in.close
- return retVal
- }
-
- def getMasterListenPort (variableUUID: UUID): Int = {
- var clientSocketToTracker: Socket = null
- var oosTracker: ObjectOutputStream = null
- var oisTracker: ObjectInputStream = null
-
- var masterListenPort: Int = SourceInfo.TxOverGoToHDFS
-
- var retriesLeft = BroadcastCS.MaxRetryCount
- do {
- try {
- // Connect to the tracker to find out the guide
- val clientSocketToTracker =
- new Socket(BroadcastCS.MasterHostAddress, BroadcastCS.MasterTrackerPort)
- val oosTracker =
- new ObjectOutputStream (clientSocketToTracker.getOutputStream)
- oosTracker.flush
- val oisTracker =
- new ObjectInputStream (clientSocketToTracker.getInputStream)
-
- // Send UUID and receive masterListenPort
- oosTracker.writeObject (uuid)
- oosTracker.flush
- masterListenPort = oisTracker.readObject.asInstanceOf[Int]
- } catch {
- case e: Exception => {
- logInfo ("getMasterListenPort had a " + e)
- }
- } finally {
- if (oisTracker != null) {
- oisTracker.close
- }
- if (oosTracker != null) {
- oosTracker.close
- }
- if (clientSocketToTracker != null) {
- clientSocketToTracker.close
- }
- }
- retriesLeft -= 1
-
- Thread.sleep (BroadcastCS.ranGen.nextInt (
- BroadcastCS.MaxKnockInterval - BroadcastCS.MinKnockInterval) +
- BroadcastCS.MinKnockInterval)
-
- } while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry)
-
- logInfo ("Got this guidePort from Tracker: " + masterListenPort)
- return masterListenPort
- }
-
- def receiveBroadcast (variableUUID: UUID): Boolean = {
- val masterListenPort = getMasterListenPort (variableUUID)
-
- if (masterListenPort == SourceInfo.TxOverGoToHDFS ||
- masterListenPort == SourceInfo.TxNotStartedRetry) {
- // TODO: SourceInfo.TxNotStartedRetry is not really in use because we go
- // to HDFS anyway when receiveBroadcast returns false
- return false
- }
-
- // Wait until hostAddress and listenPort are created by the
- // ServeMultipleRequests thread
- while (listenPort == -1) {
- listenPortLock.synchronized {
- listenPortLock.wait
- }
- }
-
- var clientSocketToMaster: Socket = null
- var oosMaster: ObjectOutputStream = null
- var oisMaster: ObjectInputStream = null
-
- // Connect and receive broadcast from the specified source, retrying the
- // specified number of times in case of failures
- var retriesLeft = BroadcastCS.MaxRetryCount
- do {
- // Connect to Master and send this worker's Information
- clientSocketToMaster =
- new Socket(BroadcastCS.MasterHostAddress, masterListenPort)
- // TODO: Guiding object connection is reusable
- oosMaster =
- new ObjectOutputStream (clientSocketToMaster.getOutputStream)
- oosMaster.flush
- oisMaster =
- new ObjectInputStream (clientSocketToMaster.getInputStream)
-
- logInfo ("Connected to Master's guiding object")
-
- // Send local source information
- oosMaster.writeObject(SourceInfo (hostAddress, listenPort, -1, -1, 0))
- oosMaster.flush
-
- // Receive source information from Master
- var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo]
- totalBlocks = sourceInfo.totalBlocks
- arrayOfBlocks = new Array[BroadcastBlock] (totalBlocks)
- totalBlocksLock.synchronized {
- totalBlocksLock.notifyAll
- }
- totalBytes = sourceInfo.totalBytes
-
- logInfo ("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort)
-
- val start = System.nanoTime
- val receptionSucceeded = receiveSingleTransmission (sourceInfo)
- val time = (System.nanoTime - start) / 1e9
-
- // Updating some statistics in sourceInfo. Master will be using them later
- if (!receptionSucceeded) {
- sourceInfo.receptionFailed = true
- }
-
- // Send back statistics to the Master
- oosMaster.writeObject (sourceInfo)
-
- if (oisMaster != null) {
- oisMaster.close
- }
- if (oosMaster != null) {
- oosMaster.close
- }
- if (clientSocketToMaster != null) {
- clientSocketToMaster.close
- }
-
- retriesLeft -= 1
- } while (retriesLeft > 0 && hasBlocks < totalBlocks)
-
- return (hasBlocks == totalBlocks)
- }
-
- // Tries to receive broadcast from the source and returns Boolean status.
- // This might be called multiple times to retry a defined number of times.
- private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
- var clientSocketToSource: Socket = null
- var oosSource: ObjectOutputStream = null
- var oisSource: ObjectInputStream = null
-
- var receptionSucceeded = false
- try {
- // Connect to the source to get the object itself
- clientSocketToSource =
- new Socket (sourceInfo.hostAddress, sourceInfo.listenPort)
- oosSource =
- new ObjectOutputStream (clientSocketToSource.getOutputStream)
- oosSource.flush
- oisSource =
- new ObjectInputStream (clientSocketToSource.getInputStream)
-
- logInfo ("Inside receiveSingleTransmission")
- logInfo ("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
-
- // Send the range
- oosSource.writeObject((hasBlocks, totalBlocks))
- oosSource.flush
-
- for (i <- hasBlocks until totalBlocks) {
- val recvStartTime = System.currentTimeMillis
- val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
- val receptionTime = (System.currentTimeMillis - recvStartTime)
-
- logInfo ("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
-
- arrayOfBlocks(hasBlocks) = bcBlock
- hasBlocks += 1
- // Set to true if at least one block is received
- receptionSucceeded = true
- hasBlocksLock.synchronized {
- hasBlocksLock.notifyAll
- }
- }
- } catch {
- case e: Exception => {
- logInfo ("receiveSingleTransmission had a " + e)
- }
- } finally {
- if (oisSource != null) {
- oisSource.close
- }
- if (oosSource != null) {
- oosSource.close
- }
- if (clientSocketToSource != null) {
- clientSocketToSource.close
- }
- }
-
- return receptionSucceeded
- }
-
- class GuideMultipleRequests
- extends Thread with Logging {
- // Keep track of sources that have completed reception
- private var setOfCompletedSources = Set[SourceInfo] ()
-
- override def run: Unit = {
- var threadPool = BroadcastCS.newDaemonCachedThreadPool
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket (0)
- guidePort = serverSocket.getLocalPort
- logInfo ("GuideMultipleRequests => " + serverSocket + " " + guidePort)
-
- guidePortLock.synchronized {
- guidePortLock.notifyAll
- }
-
- try {
- // Don't stop until there is a copy in HDFS
- while (!stopBroadcast || !hasCopyInHDFS) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout (BroadcastCS.ServerSocketTimeout)
- clientSocket = serverSocket.accept
- } catch {
- case e: Exception => {
- logInfo ("GuideMultipleRequests Timeout.")
-
- // Stop broadcast if at least one worker has connected and
- // everyone connected so far are done. Comparing with
- // pqOfSources.size - 1, because it includes the Guide itself
- if (pqOfSources.size > 1 &&
- setOfCompletedSources.size == pqOfSources.size - 1) {
- stopBroadcast = true
- }
- }
- }
- if (clientSocket != null) {
- logInfo ("Guide: Accepted new client connection: " + clientSocket)
- try {
- threadPool.execute (new GuideSingleRequest (clientSocket))
- } catch {
- // In failure, close the socket here; else, the thread will close it
- case ioe: IOException => clientSocket.close
- }
- }
- }
-
- logInfo ("Sending stopBroadcast notifications...")
- sendStopBroadcastNotifications
-
- BroadcastCS.unregisterValue (uuid)
- } finally {
- if (serverSocket != null) {
- logInfo ("GuideMultipleRequests now stopping...")
- serverSocket.close
- }
- }
-
- // Shutdown the thread pool
- threadPool.shutdown
- }
-
- private def sendStopBroadcastNotifications: Unit = {
- pqOfSources.synchronized {
- var pqIter = pqOfSources.iterator
- while (pqIter.hasNext) {
- var sourceInfo = pqIter.next
+private object Broadcast
+extends Logging {
+ private var initialized = false
- var guideSocketToSource: Socket = null
- var gosSource: ObjectOutputStream = null
- var gisSource: ObjectInputStream = null
+ // Called by SparkContext or Executor before using Broadcast
+ // Calls all other initializers here
+ def initialize (isMaster: Boolean): Unit = {
+ synchronized {
+ if (!initialized) {
+ // Initialization for DfsBroadcast
+ DfsBroadcast.initialize
+ // Initialization for ChainedStreamingBroadcast
+ ChainedBroadcast.initialize (isMaster)
- try {
- // Connect to the source
- guideSocketToSource =
- new Socket (sourceInfo.hostAddress, sourceInfo.listenPort)
- gosSource =
- new ObjectOutputStream (guideSocketToSource.getOutputStream)
- gosSource.flush
- gisSource =
- new ObjectInputStream (guideSocketToSource.getInputStream)
-
- // Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2
- gosSource.writeObject ((SourceInfo.StopBroadcast,
- SourceInfo.StopBroadcast))
- gosSource.flush
- } catch {
- case e: Exception => {
- logInfo ("sendStopBroadcastNotifications had a " + e)
- }
- } finally {
- if (gisSource != null) {
- gisSource.close
- }
- if (gosSource != null) {
- gosSource.close
- }
- if (guideSocketToSource != null) {
- guideSocketToSource.close
- }
- }
- }
+ initialized = true
}
}
-
- class GuideSingleRequest (val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream (clientSocket.getOutputStream)
- oos.flush
- private val ois = new ObjectInputStream (clientSocket.getInputStream)
-
- private var selectedSourceInfo: SourceInfo = null
- private var thisWorkerInfo:SourceInfo = null
-
- override def run: Unit = {
- try {
- logInfo ("new GuideSingleRequest is running")
- // Connecting worker is sending in its hostAddress and listenPort it will
- // be listening to. ReplicaID is 0 and other fields are invalid (-1)
- var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- pqOfSources.synchronized {
- // Select a suitable source and send it back to the worker
- selectedSourceInfo = selectSuitableSource (sourceInfo)
- logInfo ("Sending selectedSourceInfo: " + selectedSourceInfo)
- oos.writeObject (selectedSourceInfo)
- oos.flush
-
- // Add this new (if it can finish) source to the PQ of sources
- thisWorkerInfo = SourceInfo (sourceInfo.hostAddress,
- sourceInfo.listenPort, totalBlocks, totalBytes, 0)
- logInfo ("Adding possible new source to pqOfSources: " + thisWorkerInfo)
- pqOfSources.add (thisWorkerInfo)
- }
-
- // Wait till the whole transfer is done. Then receive and update source
- // statistics in pqOfSources
- sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- pqOfSources.synchronized {
- // This should work since SourceInfo is a case class
- assert (pqOfSources.contains (selectedSourceInfo))
-
- // Remove first
- pqOfSources.remove (selectedSourceInfo)
- // TODO: Removing a source based on just one failure notification!
-
- // Update sourceInfo and put it back in, IF reception succeeded
- if (!sourceInfo.receptionFailed) {
- // Add thisWorkerInfo to sources that have completed reception
- setOfCompletedSources += thisWorkerInfo
-
- selectedSourceInfo.currentLeechers -= 1
-
- // Put it back
- pqOfSources.add (selectedSourceInfo)
- }
- }
- } catch {
- // If something went wrong, e.g., the worker at the other end died etc.
- // then close everything up
- case e: Exception => {
- // Assuming that exception caused due to receiver worker failure.
- // Remove failed worker from pqOfSources and update leecherCount of
- // corresponding source worker
- pqOfSources.synchronized {
- if (selectedSourceInfo != null) {
- // Remove first
- pqOfSources.remove (selectedSourceInfo)
- // Update leecher count and put it back in
- selectedSourceInfo.currentLeechers -= 1
- pqOfSources.add (selectedSourceInfo)
- }
-
- // Remove thisWorkerInfo
- if (pqOfSources != null) {
- pqOfSources.remove (thisWorkerInfo)
- }
- }
- }
- } finally {
- ois.close
- oos.close
- clientSocket.close
- }
- }
-
- // TODO: Caller must have a synchronized block on pqOfSources
- // TODO: If a worker fails to get the broadcasted variable from a source and
- // comes back to Master, this function might choose the worker itself as a
- // source tp create a dependency cycle (this worker was put into pqOfSources
- // as a streming source when it first arrived). The length of this cycle can
- // be arbitrarily long.
- private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
- // Select one based on the ordering strategy (e.g., least leechers etc.)
- // take is a blocking call removing the element from PQ
- var selectedSource = pqOfSources.poll
- assert (selectedSource != null)
- // Update leecher count
- selectedSource.currentLeechers += 1
- // Add it back and then return
- pqOfSources.add (selectedSource)
- return selectedSource
- }
- }
}
-
- class ServeMultipleRequests
- extends Thread with Logging {
- override def run: Unit = {
- var threadPool = BroadcastCS.newDaemonCachedThreadPool
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket (0)
- listenPort = serverSocket.getLocalPort
- logInfo ("ServeMultipleRequests started with " + serverSocket)
-
- listenPortLock.synchronized {
- listenPortLock.notifyAll
- }
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout (BroadcastCS.ServerSocketTimeout)
- clientSocket = serverSocket.accept
- } catch {
- case e: Exception => {
- logInfo ("ServeMultipleRequests Timeout.")
- }
- }
- if (clientSocket != null) {
- logInfo ("Serve: Accepted new client connection: " + clientSocket)
- try {
- threadPool.execute (new ServeSingleRequest (clientSocket))
- } catch {
- // In failure, close socket here; else, the thread will close it
- case ioe: IOException => clientSocket.close
- }
- }
- }
- } finally {
- if (serverSocket != null) {
- logInfo ("ServeMultipleRequests now stopping...")
- serverSocket.close
- }
- }
-
- // Shutdown the thread pool
- threadPool.shutdown
- }
-
- class ServeSingleRequest (val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream (clientSocket.getOutputStream)
- oos.flush
- private val ois = new ObjectInputStream (clientSocket.getInputStream)
-
- private var sendFrom = 0
- private var sendUntil = totalBlocks
-
- override def run: Unit = {
- try {
- logInfo ("new ServeSingleRequest is running")
-
- // Receive range to send
- var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)]
- sendFrom = rangeToSend._1
- sendUntil = rangeToSend._2
-
- if (sendFrom == SourceInfo.StopBroadcast &&
- sendUntil == SourceInfo.StopBroadcast) {
- stopBroadcast = true
- } else {
- // Carry on
- sendObject
- }
- } catch {
- // TODO: Need to add better exception handling here
- // If something went wrong, e.g., the worker at the other end died etc.
- // then close everything up
- case e: Exception => {
- logInfo ("ServeSingleRequest had a " + e)
- }
- } finally {
- logInfo ("ServeSingleRequest is closing streams and sockets")
- ois.close
- oos.close
- clientSocket.close
- }
- }
-
- private def sendObject: Unit = {
- // Wait till receiving the SourceInfo from Master
- while (totalBlocks == -1) {
- totalBlocksLock.synchronized {
- totalBlocksLock.wait
- }
- }
-
- for (i <- sendFrom until sendUntil) {
- while (i == hasBlocks) {
- hasBlocksLock.synchronized {
- hasBlocksLock.wait
- }
- }
- try {
- oos.writeObject (arrayOfBlocks(i))
- oos.flush
- } catch {
- case e: Exception => {
- logInfo ("sendObject had a " + e)
- }
- }
- logInfo ("Sent block: " + i + " to " + clientSocket)
- }
- }
- }
- }
}
@serializable
@@ -756,206 +64,4 @@ case class VariableInfo (@transient val arrayOfBlocks : Array[BroadcastBlock],
val totalBlocks: Int, val totalBytes: Int) {
@transient var hasBlocks = 0
-}
-
-private object Broadcast
-extends Logging {
- private var initialized = false
-
- // Will be called by SparkContext or Executor before using Broadcast
- // Calls all other initializers here
- def initialize (isMaster: Boolean): Unit = {
- synchronized {
- if (!initialized) {
- // Initialization for DfsBroadcast
- DfsBroadcast.initialize
- // Initialization for ChainedStreamingBroadcast
- BroadcastCS.initialize (isMaster)
-
- initialized = true
- }
- }
- }
-}
-
-private object BroadcastCS
-extends Logging {
- val values = new MapMaker ().softValues ().makeMap[UUID, Any]
-
- var valueToGuidePortMap = Map[UUID, Int] ()
-
- // Random number generator
- var ranGen = new Random
-
- private var initialized = false
- private var isMaster_ = false
-
- private var MasterHostAddress_ = "127.0.0.1"
- private var MasterTrackerPort_ : Int = 22222
- private var BlockSize_ : Int = 512 * 1024
- private var MaxRetryCount_ : Int = 2
-
- private var TrackerSocketTimeout_ : Int = 50000
- private var ServerSocketTimeout_ : Int = 10000
-
- private var trackMV: TrackMultipleValues = null
-
- private var MinKnockInterval_ = 500
- private var MaxKnockInterval_ = 999
-
- def initialize (isMaster__ : Boolean): Unit = {
- synchronized {
- if (!initialized) {
- MasterHostAddress_ =
- System.getProperty ("spark.broadcast.MasterHostAddress", "127.0.0.1")
- MasterTrackerPort_ =
- System.getProperty ("spark.broadcast.MasterTrackerPort", "22222").toInt
- BlockSize_ =
- System.getProperty ("spark.broadcast.BlockSize", "512").toInt * 1024
- MaxRetryCount_ =
- System.getProperty ("spark.broadcast.MaxRetryCount", "2").toInt
-
- TrackerSocketTimeout_ =
- System.getProperty ("spark.broadcast.TrackerSocketTimeout", "50000").toInt
- ServerSocketTimeout_ =
- System.getProperty ("spark.broadcast.ServerSocketTimeout", "10000").toInt
-
- MinKnockInterval_ =
- System.getProperty ("spark.broadcast.MinKnockInterval", "500").toInt
- MaxKnockInterval_ =
- System.getProperty ("spark.broadcast.MaxKnockInterval", "999").toInt
-
- isMaster_ = isMaster__
-
- if (isMaster) {
- trackMV = new TrackMultipleValues
- trackMV.setDaemon (true)
- trackMV.start
- logInfo ("TrackMultipleValues started...")
- }
-
- initialized = true
- }
- }
- }
-
- def MasterHostAddress = MasterHostAddress_
- def MasterTrackerPort = MasterTrackerPort_
- def BlockSize = BlockSize_
- def MaxRetryCount = MaxRetryCount_
-
- def TrackerSocketTimeout = TrackerSocketTimeout_
- def ServerSocketTimeout = ServerSocketTimeout_
-
- def isMaster = isMaster_
-
- def MinKnockInterval = MinKnockInterval_
- def MaxKnockInterval = MaxKnockInterval_
-
- def registerValue (uuid: UUID, guidePort: Int): Unit = {
- valueToGuidePortMap.synchronized {
- valueToGuidePortMap += (uuid -> guidePort)
- logInfo ("New value registered with the Tracker " + valueToGuidePortMap)
- }
- }
-
- def unregisterValue (uuid: UUID): Unit = {
- valueToGuidePortMap.synchronized {
- valueToGuidePortMap (uuid) = SourceInfo.TxOverGoToHDFS
- logInfo ("Value unregistered from the Tracker " + valueToGuidePortMap)
- }
- }
-
- // Returns a standard ThreadFactory except all threads are daemons
- private def newDaemonThreadFactory: ThreadFactory = {
- new ThreadFactory {
- def newThread(r: Runnable): Thread = {
- var t = Executors.defaultThreadFactory.newThread (r)
- t.setDaemon (true)
- return t
- }
- }
- }
-
- // Wrapper over newCachedThreadPool
- def newDaemonCachedThreadPool: ThreadPoolExecutor = {
- var threadPool =
- Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor]
-
- threadPool.setThreadFactory (newDaemonThreadFactory)
-
- return threadPool
- }
-
- // Wrapper over newFixedThreadPool
- def newDaemonFixedThreadPool (nThreads: Int): ThreadPoolExecutor = {
- var threadPool =
- Executors.newFixedThreadPool (nThreads).asInstanceOf[ThreadPoolExecutor]
-
- threadPool.setThreadFactory (newDaemonThreadFactory)
-
- return threadPool
- }
-
- class TrackMultipleValues
- extends Thread with Logging {
- override def run: Unit = {
- var threadPool = BroadcastCS.newDaemonCachedThreadPool
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket (BroadcastCS.MasterTrackerPort)
- logInfo ("TrackMultipleValues" + serverSocket)
-
- try {
- while (true) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout (TrackerSocketTimeout)
- clientSocket = serverSocket.accept
- } catch {
- case e: Exception => {
- logInfo ("TrackMultipleValues Timeout. Stopping listening...")
- }
- }
-
- if (clientSocket != null) {
- try {
- threadPool.execute (new Thread {
- override def run: Unit = {
- val oos = new ObjectOutputStream (clientSocket.getOutputStream)
- oos.flush
- val ois = new ObjectInputStream (clientSocket.getInputStream)
- try {
- val uuid = ois.readObject.asInstanceOf[UUID]
- var guidePort =
- if (valueToGuidePortMap.contains (uuid)) {
- valueToGuidePortMap (uuid)
- } else SourceInfo.TxNotStartedRetry
- logInfo ("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort)
- oos.writeObject (guidePort)
- } catch {
- case e: Exception => {
- logInfo ("TrackMultipleValues had a " + e)
- }
- } finally {
- ois.close
- oos.close
- clientSocket.close
- }
- }
- })
- } catch {
- // In failure, close socket here; else, client thread will close
- case ioe: IOException => clientSocket.close
- }
- }
- }
- } finally {
- serverSocket.close
- }
-
- // Shutdown the thread pool
- threadPool.shutdown
- }
- }
}
diff --git a/src/scala/spark/ChainedBroadcast.scala b/src/scala/spark/ChainedBroadcast.scala
new file mode 100644
index 0000000000..6b34843abe
--- /dev/null
+++ b/src/scala/spark/ChainedBroadcast.scala
@@ -0,0 +1,897 @@
+package spark
+
+import java.io._
+import java.net._
+import java.util.{Comparator, PriorityQueue, Random, UUID}
+
+import com.google.common.collect.MapMaker
+
+import java.util.concurrent.{Executors, ThreadPoolExecutor, ThreadFactory}
+
+import scala.collection.mutable.{Map, Set}
+
+@serializable
+class ChainedBroadcast[T] (@transient var value_ : T, local: Boolean)
+extends Broadcast with Logging {
+
+ def value = value_
+
+ ChainedBroadcast.synchronized {
+ ChainedBroadcast.values.put (uuid, value_)
+ }
+
+ @transient var arrayOfBlocks: Array[BroadcastBlock] = null
+ @transient var totalBytes = -1
+ @transient var totalBlocks = -1
+ @transient var hasBlocks = 0
+
+ @transient var listenPortLock = new Object
+ @transient var guidePortLock = new Object
+ @transient var totalBlocksLock = new Object
+ @transient var hasBlocksLock = new Object
+
+ @transient var pqOfSources = new PriorityQueue[SourceInfo]
+
+ @transient var serveMR: ServeMultipleRequests = null
+ @transient var guideMR: GuideMultipleRequests = null
+
+ @transient var hostAddress = InetAddress.getLocalHost.getHostAddress
+ @transient var listenPort = -1
+ @transient var guidePort = -1
+
+ @transient var hasCopyInHDFS = false
+ @transient var stopBroadcast = false
+
+ // Must call this after all the variables have been created/initialized
+ if (!local) {
+ sendBroadcast
+ }
+
+ def sendBroadcast (): Unit = {
+ logInfo ("Local host address: " + hostAddress)
+
+ // Store a persistent copy in HDFS
+ // TODO: Turned OFF for now
+ // val out = new ObjectOutputStream (DfsBroadcast.openFileForWriting(uuid))
+ // out.writeObject (value_)
+ // out.close
+ // TODO: Fix this at some point
+ hasCopyInHDFS = true
+
+ // Create a variableInfo object and store it in valueInfos
+ var variableInfo = blockifyObject (value_, ChainedBroadcast.BlockSize)
+
+ guideMR = new GuideMultipleRequests
+ guideMR.setDaemon (true)
+ guideMR.start
+ logInfo ("GuideMultipleRequests started...")
+
+ serveMR = new ServeMultipleRequests
+ serveMR.setDaemon (true)
+ serveMR.start
+ logInfo ("ServeMultipleRequests started...")
+
+ // Prepare the value being broadcasted
+ // TODO: Refactoring and clean-up required here
+ arrayOfBlocks = variableInfo.arrayOfBlocks
+ totalBytes = variableInfo.totalBytes
+ totalBlocks = variableInfo.totalBlocks
+ hasBlocks = variableInfo.totalBlocks
+
+ while (listenPort == -1) {
+ listenPortLock.synchronized {
+ listenPortLock.wait
+ }
+ }
+
+ pqOfSources = new PriorityQueue[SourceInfo]
+ val masterSource_0 =
+ SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes, 0)
+ pqOfSources.add (masterSource_0)
+
+ // Register with the Tracker
+ while (guidePort == -1) {
+ guidePortLock.synchronized {
+ guidePortLock.wait
+ }
+ }
+ ChainedBroadcast.registerValue (uuid, guidePort)
+ }
+
+ private def readObject (in: ObjectInputStream): Unit = {
+ in.defaultReadObject
+ ChainedBroadcast.synchronized {
+ val cachedVal = ChainedBroadcast.values.get (uuid)
+ if (cachedVal != null) {
+ value_ = cachedVal.asInstanceOf[T]
+ } else {
+ // Initializing everything because Master will only send null/0 values
+ initializeSlaveVariables
+
+ logInfo ("Local host address: " + hostAddress)
+
+ serveMR = new ServeMultipleRequests
+ serveMR.setDaemon (true)
+ serveMR.start
+ logInfo ("ServeMultipleRequests started...")
+
+ val start = System.nanoTime
+
+ val receptionSucceeded = receiveBroadcast (uuid)
+ // If does not succeed, then get from HDFS copy
+ if (receptionSucceeded) {
+ value_ = unBlockifyObject[T]
+ ChainedBroadcast.values.put (uuid, value_)
+ } else {
+ val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
+ value_ = fileIn.readObject.asInstanceOf[T]
+ ChainedBroadcast.values.put(uuid, value_)
+ fileIn.close
+ }
+
+ val time = (System.nanoTime - start) / 1e9
+ logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
+ }
+ }
+ }
+
+ private def initializeSlaveVariables: Unit = {
+ arrayOfBlocks = null
+ totalBytes = -1
+ totalBlocks = -1
+ hasBlocks = 0
+
+ listenPortLock = new Object
+ totalBlocksLock = new Object
+ hasBlocksLock = new Object
+
+ serveMR = null
+
+ hostAddress = InetAddress.getLocalHost.getHostAddress
+ listenPort = -1
+
+ stopBroadcast = false
+ }
+
+ private def blockifyObject (obj: T, blockSize: Int): VariableInfo = {
+ val baos = new ByteArrayOutputStream
+ val oos = new ObjectOutputStream (baos)
+ oos.writeObject (obj)
+ oos.close
+ baos.close
+ val byteArray = baos.toByteArray
+ val bais = new ByteArrayInputStream (byteArray)
+
+ var blockNum = (byteArray.length / blockSize)
+ if (byteArray.length % blockSize != 0)
+ blockNum += 1
+
+ var retVal = new Array[BroadcastBlock] (blockNum)
+ var blockID = 0
+
+ for (i <- 0 until (byteArray.length, blockSize)) {
+ val thisBlockSize = Math.min (blockSize, byteArray.length - i)
+ var tempByteArray = new Array[Byte] (thisBlockSize)
+ val hasRead = bais.read (tempByteArray, 0, thisBlockSize)
+
+ retVal (blockID) = new BroadcastBlock (blockID, tempByteArray)
+ blockID += 1
+ }
+ bais.close
+
+ var variableInfo = VariableInfo (retVal, blockNum, byteArray.length)
+ variableInfo.hasBlocks = blockNum
+
+ return variableInfo
+ }
+
+ private def unBlockifyObject[A]: A = {
+ var retByteArray = new Array[Byte] (totalBytes)
+ for (i <- 0 until totalBlocks) {
+ System.arraycopy (arrayOfBlocks(i).byteArray, 0, retByteArray,
+ i * ChainedBroadcast.BlockSize, arrayOfBlocks(i).byteArray.length)
+ }
+ byteArrayToObject (retByteArray)
+ }
+
+ private def byteArrayToObject[A] (bytes: Array[Byte]): A = {
+ val in = new ObjectInputStream (new ByteArrayInputStream (bytes))
+ val retVal = in.readObject.asInstanceOf[A]
+ in.close
+ return retVal
+ }
+
+ def getMasterListenPort (variableUUID: UUID): Int = {
+ var clientSocketToTracker: Socket = null
+ var oosTracker: ObjectOutputStream = null
+ var oisTracker: ObjectInputStream = null
+
+ var masterListenPort: Int = SourceInfo.TxOverGoToHDFS
+
+ var retriesLeft = ChainedBroadcast.MaxRetryCount
+ do {
+ try {
+ // Connect to the tracker to find out the guide
+ val clientSocketToTracker =
+ new Socket(ChainedBroadcast.MasterHostAddress, ChainedBroadcast.MasterTrackerPort)
+ val oosTracker =
+ new ObjectOutputStream (clientSocketToTracker.getOutputStream)
+ oosTracker.flush
+ val oisTracker =
+ new ObjectInputStream (clientSocketToTracker.getInputStream)
+
+ // Send UUID and receive masterListenPort
+ oosTracker.writeObject (uuid)
+ oosTracker.flush
+ masterListenPort = oisTracker.readObject.asInstanceOf[Int]
+ } catch {
+ case e: Exception => {
+ logInfo ("getMasterListenPort had a " + e)
+ }
+ } finally {
+ if (oisTracker != null) {
+ oisTracker.close
+ }
+ if (oosTracker != null) {
+ oosTracker.close
+ }
+ if (clientSocketToTracker != null) {
+ clientSocketToTracker.close
+ }
+ }
+ retriesLeft -= 1
+
+ Thread.sleep (ChainedBroadcast.ranGen.nextInt (
+ ChainedBroadcast.MaxKnockInterval - ChainedBroadcast.MinKnockInterval) +
+ ChainedBroadcast.MinKnockInterval)
+
+ } while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry)
+
+ logInfo ("Got this guidePort from Tracker: " + masterListenPort)
+ return masterListenPort
+ }
+
+ def receiveBroadcast (variableUUID: UUID): Boolean = {
+ val masterListenPort = getMasterListenPort (variableUUID)
+
+ if (masterListenPort == SourceInfo.TxOverGoToHDFS ||
+ masterListenPort == SourceInfo.TxNotStartedRetry) {
+ // TODO: SourceInfo.TxNotStartedRetry is not really in use because we go
+ // to HDFS anyway when receiveBroadcast returns false
+ return false
+ }
+
+ // Wait until hostAddress and listenPort are created by the
+ // ServeMultipleRequests thread
+ while (listenPort == -1) {
+ listenPortLock.synchronized {
+ listenPortLock.wait
+ }
+ }
+
+ var clientSocketToMaster: Socket = null
+ var oosMaster: ObjectOutputStream = null
+ var oisMaster: ObjectInputStream = null
+
+ // Connect and receive broadcast from the specified source, retrying the
+ // specified number of times in case of failures
+ var retriesLeft = ChainedBroadcast.MaxRetryCount
+ do {
+ // Connect to Master and send this worker's Information
+ clientSocketToMaster =
+ new Socket(ChainedBroadcast.MasterHostAddress, masterListenPort)
+ // TODO: Guiding object connection is reusable
+ oosMaster =
+ new ObjectOutputStream (clientSocketToMaster.getOutputStream)
+ oosMaster.flush
+ oisMaster =
+ new ObjectInputStream (clientSocketToMaster.getInputStream)
+
+ logInfo ("Connected to Master's guiding object")
+
+ // Send local source information
+ oosMaster.writeObject(SourceInfo (hostAddress, listenPort, -1, -1, 0))
+ oosMaster.flush
+
+ // Receive source information from Master
+ var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo]
+ totalBlocks = sourceInfo.totalBlocks
+ arrayOfBlocks = new Array[BroadcastBlock] (totalBlocks)
+ totalBlocksLock.synchronized {
+ totalBlocksLock.notifyAll
+ }
+ totalBytes = sourceInfo.totalBytes
+
+ logInfo ("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort)
+
+ val start = System.nanoTime
+ val receptionSucceeded = receiveSingleTransmission (sourceInfo)
+ val time = (System.nanoTime - start) / 1e9
+
+ // Updating some statistics in sourceInfo. Master will be using them later
+ if (!receptionSucceeded) {
+ sourceInfo.receptionFailed = true
+ }
+
+ // Send back statistics to the Master
+ oosMaster.writeObject (sourceInfo)
+
+ if (oisMaster != null) {
+ oisMaster.close
+ }
+ if (oosMaster != null) {
+ oosMaster.close
+ }
+ if (clientSocketToMaster != null) {
+ clientSocketToMaster.close
+ }
+
+ retriesLeft -= 1
+ } while (retriesLeft > 0 && hasBlocks < totalBlocks)
+
+ return (hasBlocks == totalBlocks)
+ }
+
+ // Tries to receive broadcast from the source and returns Boolean status.
+ // This might be called multiple times to retry a defined number of times.
+ private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
+ var clientSocketToSource: Socket = null
+ var oosSource: ObjectOutputStream = null
+ var oisSource: ObjectInputStream = null
+
+ var receptionSucceeded = false
+ try {
+ // Connect to the source to get the object itself
+ clientSocketToSource =
+ new Socket (sourceInfo.hostAddress, sourceInfo.listenPort)
+ oosSource =
+ new ObjectOutputStream (clientSocketToSource.getOutputStream)
+ oosSource.flush
+ oisSource =
+ new ObjectInputStream (clientSocketToSource.getInputStream)
+
+ logInfo ("Inside receiveSingleTransmission")
+ logInfo ("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
+
+ // Send the range
+ oosSource.writeObject((hasBlocks, totalBlocks))
+ oosSource.flush
+
+ for (i <- hasBlocks until totalBlocks) {
+ val recvStartTime = System.currentTimeMillis
+ val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
+ val receptionTime = (System.currentTimeMillis - recvStartTime)
+
+ logInfo ("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
+
+ arrayOfBlocks(hasBlocks) = bcBlock
+ hasBlocks += 1
+ // Set to true if at least one block is received
+ receptionSucceeded = true
+ hasBlocksLock.synchronized {
+ hasBlocksLock.notifyAll
+ }
+ }
+ } catch {
+ case e: Exception => {
+ logInfo ("receiveSingleTransmission had a " + e)
+ }
+ } finally {
+ if (oisSource != null) {
+ oisSource.close
+ }
+ if (oosSource != null) {
+ oosSource.close
+ }
+ if (clientSocketToSource != null) {
+ clientSocketToSource.close
+ }
+ }
+
+ return receptionSucceeded
+ }
+
+ class GuideMultipleRequests
+ extends Thread with Logging {
+ // Keep track of sources that have completed reception
+ private var setOfCompletedSources = Set[SourceInfo] ()
+
+ override def run: Unit = {
+ var threadPool = ChainedBroadcast.newDaemonCachedThreadPool
+ var serverSocket: ServerSocket = null
+
+ serverSocket = new ServerSocket (0)
+ guidePort = serverSocket.getLocalPort
+ logInfo ("GuideMultipleRequests => " + serverSocket + " " + guidePort)
+
+ guidePortLock.synchronized {
+ guidePortLock.notifyAll
+ }
+
+ try {
+ // Don't stop until there is a copy in HDFS
+ while (!stopBroadcast || !hasCopyInHDFS) {
+ var clientSocket: Socket = null
+ try {
+ serverSocket.setSoTimeout (ChainedBroadcast.ServerSocketTimeout)
+ clientSocket = serverSocket.accept
+ } catch {
+ case e: Exception => {
+ logInfo ("GuideMultipleRequests Timeout.")
+
+ // Stop broadcast if at least one worker has connected and
+ // everyone connected so far are done. Comparing with
+ // pqOfSources.size - 1, because it includes the Guide itself
+ if (pqOfSources.size > 1 &&
+ setOfCompletedSources.size == pqOfSources.size - 1) {
+ stopBroadcast = true
+ }
+ }
+ }
+ if (clientSocket != null) {
+ logInfo ("Guide: Accepted new client connection: " + clientSocket)
+ try {
+ threadPool.execute (new GuideSingleRequest (clientSocket))
+ } catch {
+ // In failure, close the socket here; else, the thread will close it
+ case ioe: IOException => clientSocket.close
+ }
+ }
+ }
+
+ logInfo ("Sending stopBroadcast notifications...")
+ sendStopBroadcastNotifications
+
+ ChainedBroadcast.unregisterValue (uuid)
+ } finally {
+ if (serverSocket != null) {
+ logInfo ("GuideMultipleRequests now stopping...")
+ serverSocket.close
+ }
+ }
+
+ // Shutdown the thread pool
+ threadPool.shutdown
+ }
+
+ private def sendStopBroadcastNotifications: Unit = {
+ pqOfSources.synchronized {
+ var pqIter = pqOfSources.iterator
+ while (pqIter.hasNext) {
+ var sourceInfo = pqIter.next
+
+ var guideSocketToSource: Socket = null
+ var gosSource: ObjectOutputStream = null
+ var gisSource: ObjectInputStream = null
+
+ try {
+ // Connect to the source
+ guideSocketToSource =
+ new Socket (sourceInfo.hostAddress, sourceInfo.listenPort)
+ gosSource =
+ new ObjectOutputStream (guideSocketToSource.getOutputStream)
+ gosSource.flush
+ gisSource =
+ new ObjectInputStream (guideSocketToSource.getInputStream)
+
+ // Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2
+ gosSource.writeObject ((SourceInfo.StopBroadcast,
+ SourceInfo.StopBroadcast))
+ gosSource.flush
+ } catch {
+ case e: Exception => {
+ logInfo ("sendStopBroadcastNotifications had a " + e)
+ }
+ } finally {
+ if (gisSource != null) {
+ gisSource.close
+ }
+ if (gosSource != null) {
+ gosSource.close
+ }
+ if (guideSocketToSource != null) {
+ guideSocketToSource.close
+ }
+ }
+ }
+ }
+ }
+
+ class GuideSingleRequest (val clientSocket: Socket)
+ extends Thread with Logging {
+ private val oos = new ObjectOutputStream (clientSocket.getOutputStream)
+ oos.flush
+ private val ois = new ObjectInputStream (clientSocket.getInputStream)
+
+ private var selectedSourceInfo: SourceInfo = null
+ private var thisWorkerInfo:SourceInfo = null
+
+ override def run: Unit = {
+ try {
+ logInfo ("new GuideSingleRequest is running")
+ // Connecting worker is sending in its hostAddress and listenPort it will
+ // be listening to. ReplicaID is 0 and other fields are invalid (-1)
+ var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
+
+ pqOfSources.synchronized {
+ // Select a suitable source and send it back to the worker
+ selectedSourceInfo = selectSuitableSource (sourceInfo)
+ logInfo ("Sending selectedSourceInfo: " + selectedSourceInfo)
+ oos.writeObject (selectedSourceInfo)
+ oos.flush
+
+ // Add this new (if it can finish) source to the PQ of sources
+ thisWorkerInfo = SourceInfo (sourceInfo.hostAddress,
+ sourceInfo.listenPort, totalBlocks, totalBytes, 0)
+ logInfo ("Adding possible new source to pqOfSources: " + thisWorkerInfo)
+ pqOfSources.add (thisWorkerInfo)
+ }
+
+ // Wait till the whole transfer is done. Then receive and update source
+ // statistics in pqOfSources
+ sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
+
+ pqOfSources.synchronized {
+ // This should work since SourceInfo is a case class
+ assert (pqOfSources.contains (selectedSourceInfo))
+
+ // Remove first
+ pqOfSources.remove (selectedSourceInfo)
+ // TODO: Removing a source based on just one failure notification!
+
+ // Update sourceInfo and put it back in, IF reception succeeded
+ if (!sourceInfo.receptionFailed) {
+ // Add thisWorkerInfo to sources that have completed reception
+ setOfCompletedSources += thisWorkerInfo
+
+ selectedSourceInfo.currentLeechers -= 1
+
+ // Put it back
+ pqOfSources.add (selectedSourceInfo)
+ }
+ }
+ } catch {
+ // If something went wrong, e.g., the worker at the other end died etc.
+ // then close everything up
+ case e: Exception => {
+ // Assuming that exception caused due to receiver worker failure.
+ // Remove failed worker from pqOfSources and update leecherCount of
+ // corresponding source worker
+ pqOfSources.synchronized {
+ if (selectedSourceInfo != null) {
+ // Remove first
+ pqOfSources.remove (selectedSourceInfo)
+ // Update leecher count and put it back in
+ selectedSourceInfo.currentLeechers -= 1
+ pqOfSources.add (selectedSourceInfo)
+ }
+
+ // Remove thisWorkerInfo
+ if (pqOfSources != null) {
+ pqOfSources.remove (thisWorkerInfo)
+ }
+ }
+ }
+ } finally {
+ ois.close
+ oos.close
+ clientSocket.close
+ }
+ }
+
+ // TODO: Caller must have a synchronized block on pqOfSources
+ // TODO: If a worker fails to get the broadcasted variable from a source and
+ // comes back to Master, this function might choose the worker itself as a
+    // source to create a dependency cycle (this worker was put into pqOfSources
+    // as a streaming source when it first arrived). The length of this cycle can
+ // be arbitrarily long.
+ private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
+ // Select one based on the ordering strategy (e.g., least leechers etc.)
+      // poll removes and returns the head of the PQ (non-blocking)
+ var selectedSource = pqOfSources.poll
+ assert (selectedSource != null)
+ // Update leecher count
+ selectedSource.currentLeechers += 1
+ // Add it back and then return
+ pqOfSources.add (selectedSource)
+ return selectedSource
+ }
+ }
+ }
+
+ class ServeMultipleRequests
+ extends Thread with Logging {
+ override def run: Unit = {
+ var threadPool = ChainedBroadcast.newDaemonCachedThreadPool
+ var serverSocket: ServerSocket = null
+
+ serverSocket = new ServerSocket (0)
+ listenPort = serverSocket.getLocalPort
+ logInfo ("ServeMultipleRequests started with " + serverSocket)
+
+ listenPortLock.synchronized {
+ listenPortLock.notifyAll
+ }
+
+ try {
+ while (!stopBroadcast) {
+ var clientSocket: Socket = null
+ try {
+ serverSocket.setSoTimeout (ChainedBroadcast.ServerSocketTimeout)
+ clientSocket = serverSocket.accept
+ } catch {
+ case e: Exception => {
+ logInfo ("ServeMultipleRequests Timeout.")
+ }
+ }
+ if (clientSocket != null) {
+ logInfo ("Serve: Accepted new client connection: " + clientSocket)
+ try {
+ threadPool.execute (new ServeSingleRequest (clientSocket))
+ } catch {
+ // In failure, close socket here; else, the thread will close it
+ case ioe: IOException => clientSocket.close
+ }
+ }
+ }
+ } finally {
+ if (serverSocket != null) {
+ logInfo ("ServeMultipleRequests now stopping...")
+ serverSocket.close
+ }
+ }
+
+ // Shutdown the thread pool
+ threadPool.shutdown
+ }
+
+ class ServeSingleRequest (val clientSocket: Socket)
+ extends Thread with Logging {
+ private val oos = new ObjectOutputStream (clientSocket.getOutputStream)
+ oos.flush
+ private val ois = new ObjectInputStream (clientSocket.getInputStream)
+
+ private var sendFrom = 0
+ private var sendUntil = totalBlocks
+
+ override def run: Unit = {
+ try {
+ logInfo ("new ServeSingleRequest is running")
+
+ // Receive range to send
+ var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)]
+ sendFrom = rangeToSend._1
+ sendUntil = rangeToSend._2
+
+ if (sendFrom == SourceInfo.StopBroadcast &&
+ sendUntil == SourceInfo.StopBroadcast) {
+ stopBroadcast = true
+ } else {
+ // Carry on
+ sendObject
+ }
+ } catch {
+ // TODO: Need to add better exception handling here
+ // If something went wrong, e.g., the worker at the other end died etc.
+ // then close everything up
+ case e: Exception => {
+ logInfo ("ServeSingleRequest had a " + e)
+ }
+ } finally {
+ logInfo ("ServeSingleRequest is closing streams and sockets")
+ ois.close
+ oos.close
+ clientSocket.close
+ }
+ }
+
+ private def sendObject: Unit = {
+ // Wait till receiving the SourceInfo from Master
+ while (totalBlocks == -1) {
+ totalBlocksLock.synchronized {
+ totalBlocksLock.wait
+ }
+ }
+
+ for (i <- sendFrom until sendUntil) {
+ while (i == hasBlocks) {
+ hasBlocksLock.synchronized {
+ hasBlocksLock.wait
+ }
+ }
+ try {
+ oos.writeObject (arrayOfBlocks(i))
+ oos.flush
+ } catch {
+ case e: Exception => {
+ logInfo ("sendObject had a " + e)
+ }
+ }
+ logInfo ("Sent block: " + i + " to " + clientSocket)
+ }
+ }
+ }
+ }
+}
+
+private object ChainedBroadcast
+extends Logging {
+ val values = new MapMaker ().softValues ().makeMap[UUID, Any]
+
+ var valueToGuidePortMap = Map[UUID, Int] ()
+
+ // Random number generator
+ var ranGen = new Random
+
+ private var initialized = false
+ private var isMaster_ = false
+
+ private var MasterHostAddress_ = "127.0.0.1"
+ private var MasterTrackerPort_ : Int = 22222
+ private var BlockSize_ : Int = 512 * 1024
+ private var MaxRetryCount_ : Int = 2
+
+ private var TrackerSocketTimeout_ : Int = 50000
+ private var ServerSocketTimeout_ : Int = 10000
+
+ private var trackMV: TrackMultipleValues = null
+
+ private var MinKnockInterval_ = 500
+ private var MaxKnockInterval_ = 999
+
+ def initialize (isMaster__ : Boolean): Unit = {
+ synchronized {
+ if (!initialized) {
+ MasterHostAddress_ =
+ System.getProperty ("spark.broadcast.MasterHostAddress", "127.0.0.1")
+ MasterTrackerPort_ =
+ System.getProperty ("spark.broadcast.MasterTrackerPort", "22222").toInt
+ BlockSize_ =
+ System.getProperty ("spark.broadcast.BlockSize", "512").toInt * 1024
+ MaxRetryCount_ =
+ System.getProperty ("spark.broadcast.MaxRetryCount", "2").toInt
+
+ TrackerSocketTimeout_ =
+ System.getProperty ("spark.broadcast.TrackerSocketTimeout", "50000").toInt
+ ServerSocketTimeout_ =
+ System.getProperty ("spark.broadcast.ServerSocketTimeout", "10000").toInt
+
+ MinKnockInterval_ =
+ System.getProperty ("spark.broadcast.MinKnockInterval", "500").toInt
+ MaxKnockInterval_ =
+ System.getProperty ("spark.broadcast.MaxKnockInterval", "999").toInt
+
+ isMaster_ = isMaster__
+
+ if (isMaster) {
+ trackMV = new TrackMultipleValues
+ trackMV.setDaemon (true)
+ trackMV.start
+ logInfo ("TrackMultipleValues started...")
+ }
+
+ initialized = true
+ }
+ }
+ }
+
+ def MasterHostAddress = MasterHostAddress_
+ def MasterTrackerPort = MasterTrackerPort_
+ def BlockSize = BlockSize_
+ def MaxRetryCount = MaxRetryCount_
+
+ def TrackerSocketTimeout = TrackerSocketTimeout_
+ def ServerSocketTimeout = ServerSocketTimeout_
+
+ def isMaster = isMaster_
+
+ def MinKnockInterval = MinKnockInterval_
+ def MaxKnockInterval = MaxKnockInterval_
+
+ def registerValue (uuid: UUID, guidePort: Int): Unit = {
+ valueToGuidePortMap.synchronized {
+ valueToGuidePortMap += (uuid -> guidePort)
+ logInfo ("New value registered with the Tracker " + valueToGuidePortMap)
+ }
+ }
+
+ def unregisterValue (uuid: UUID): Unit = {
+ valueToGuidePortMap.synchronized {
+ valueToGuidePortMap (uuid) = SourceInfo.TxOverGoToHDFS
+ logInfo ("Value unregistered from the Tracker " + valueToGuidePortMap)
+ }
+ }
+
+ // Returns a standard ThreadFactory except all threads are daemons
+ private def newDaemonThreadFactory: ThreadFactory = {
+ new ThreadFactory {
+ def newThread(r: Runnable): Thread = {
+ var t = Executors.defaultThreadFactory.newThread (r)
+ t.setDaemon (true)
+ return t
+ }
+ }
+ }
+
+ // Wrapper over newCachedThreadPool
+ def newDaemonCachedThreadPool: ThreadPoolExecutor = {
+ var threadPool =
+ Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor]
+
+ threadPool.setThreadFactory (newDaemonThreadFactory)
+
+ return threadPool
+ }
+
+ // Wrapper over newFixedThreadPool
+ def newDaemonFixedThreadPool (nThreads: Int): ThreadPoolExecutor = {
+ var threadPool =
+ Executors.newFixedThreadPool (nThreads).asInstanceOf[ThreadPoolExecutor]
+
+ threadPool.setThreadFactory (newDaemonThreadFactory)
+
+ return threadPool
+ }
+
+ class TrackMultipleValues
+ extends Thread with Logging {
+ override def run: Unit = {
+ var threadPool = ChainedBroadcast.newDaemonCachedThreadPool
+ var serverSocket: ServerSocket = null
+
+ serverSocket = new ServerSocket (ChainedBroadcast.MasterTrackerPort)
+ logInfo ("TrackMultipleValues" + serverSocket)
+
+ try {
+ while (true) {
+ var clientSocket: Socket = null
+ try {
+ serverSocket.setSoTimeout (TrackerSocketTimeout)
+ clientSocket = serverSocket.accept
+ } catch {
+ case e: Exception => {
+ logInfo ("TrackMultipleValues Timeout. Stopping listening...")
+ }
+ }
+
+ if (clientSocket != null) {
+ try {
+ threadPool.execute (new Thread {
+ override def run: Unit = {
+ val oos = new ObjectOutputStream (clientSocket.getOutputStream)
+ oos.flush
+ val ois = new ObjectInputStream (clientSocket.getInputStream)
+ try {
+ val uuid = ois.readObject.asInstanceOf[UUID]
+ var guidePort =
+ if (valueToGuidePortMap.contains (uuid)) {
+ valueToGuidePortMap (uuid)
+ } else SourceInfo.TxNotStartedRetry
+ logInfo ("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort)
+ oos.writeObject (guidePort)
+ } catch {
+ case e: Exception => {
+ logInfo ("TrackMultipleValues had a " + e)
+ }
+ } finally {
+ ois.close
+ oos.close
+ clientSocket.close
+ }
+ }
+ })
+ } catch {
+ // In failure, close socket here; else, client thread will close
+ case ioe: IOException => clientSocket.close
+ }
+ }
+ }
+ } finally {
+ serverSocket.close
+ }
+
+ // Shutdown the thread pool
+ threadPool.shutdown
+ }
+ }
+}
diff --git a/src/scala/spark/SparkContext.scala b/src/scala/spark/SparkContext.scala
index 75efd9d1fb..98149dc1b7 100644
--- a/src/scala/spark/SparkContext.scala
+++ b/src/scala/spark/SparkContext.scala
@@ -20,7 +20,7 @@ class SparkContext(master: String, frameworkName: String) extends Logging {
// TODO: Keep around a weak hash map of values to Cached versions?
// def broadcast[T](value: T) = new DfsBroadcast(value, local)
- def broadcast[T](value: T) = new ChainedStreamingBroadcast(value, local)
+ def broadcast[T](value: T) = new ChainedBroadcast(value, local)
def textFile(path: String) = new HdfsTextFile(this, path)